egrecho.core.module#

class egrecho.core.module.TopVirtualModel(config, *args, **kwargs)[source]#

Bases: LightningModule, GenericFileMixin

A lightning module that handles training, validation, and testing.

In the fit (train + validate) stage, you need to set self.teacher, which configures the step logic, dataloaders, criterion, etc.

setup(stage)[source]#

Hook of lightning.pytorch.core.hooks.DataHooks.setup().

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters:

stage (str) -- either 'fit', 'validate', 'test', or 'predict'

Return type:

None

Example

>>> class LitModel(TopVirtualModel):
...     def __init__(self):
...         super().__init__()
...         self.l1 = None
...
...     def setup(self, stage):
...         if stage == 'fit':
...             self.l1 = nn.Linear(28, 1000)
...             self.teacher.setup_model()
...
...
>>> class LitTeacher(Teacher):
...     def __init__(self):
...         super().__init__()
...
...     def setup_model(self):
...         self.model.classifier = nn.Linear(1000, 42)
...
...     def training_step(self, batch, batch_idx):
...         pass
>>> model = LitModel()
>>> teacher = LitTeacher()
>>> model.teacher = teacher
>>> model.setup("fit")
>>> assert model.l1 is not None
>>> assert model.classifier is not None
training_step(*args, **kwargs)[source]#

Redirection to training_step() in teacher.
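
Conceptually, this redirection amounts to the following sketch (illustrative, not the actual implementation):

class MyModel(TopVirtualModel):
    def training_step(self, *args, **kwargs):
        # Delegate the step logic to the attached teacher.
        return self.teacher.training_step(*args, **kwargs)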

validation_step(*args, **kwargs)[source]#

Redirection to validation_step() in teacher.

classmethod pipeline_out(model_out)[source]#

Transforms the output of forward() into a dict for pipelines. Override it for your specific model.

Return type:

Dict
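
A minimal sketch of overriding this hook, assuming forward() returns a single embedding tensor (the key name "embd" is illustrative):

class MyModel(TopVirtualModel):
    @classmethod
    def pipeline_out(cls, model_out):
        # Wrap the raw forward() output into a dict for downstream pipelines.
        return {"embd": model_out}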

on_train_start()[source]#

Called at the very beginning of train.

On DDP, it is called on every process.

Return type:

None

on_train_end()[source]#

Called at the very end of train.

On DDP, it is called on every process.

Return type:

None

configure_optimizers()[source]#

Redirection to configure_optimizers() in teacher.
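
Since the optimizer setup lives in the teacher, here is a minimal teacher-side sketch (the optimizer choice is illustrative):

class MyTeacher(Teacher):
    def configure_optimizers(self):
        # Standard Lightning-style return value: a single optimizer.
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)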

configure_teacher_callbacks()[source]#

Configure teacher-specific callbacks.

Manually call this function to get the callbacks and add them to your trainer.

Example:

>>> class LitTeacher(Teacher):
...     def __init__(self):
...         super().__init__()
...
...     def setup_model(self):
...         self.model.classifier = nn.Linear(1000, 42)
...
...     def configure_teacher_callbacks(self):
...         return [PrintCallback()]
...
...     def training_step(self, batch, batch_idx):
...         pass
...
...
>>> class PrintCallback(Callback):
...     def on_train_start(self, trainer, pl_module):
...         print("Training is started!")
...
...     def on_train_end(self, trainer, pl_module):
...         print("Training is done.")
...
...
>>> class LitModel(TopVirtualModel):
...     def __init__(self, teacher):
...         super().__init__()
...         self.l1 = None
...         self.teacher = teacher
...         self.teacher.setup_model()
...
>>> teacher = LitTeacher()
>>> model = LitModel(teacher)
>>> t_callbacks = model.configure_teacher_callbacks()
>>> trainer = Trainer(accelerator="gpu", devices=2, callbacks=t_callbacks)
classmethod from_pretrained(checkpoint_path=None, map_location='cpu', hparams_file=None, ignore_mismatched_sizes=False, strict=True, **kwargs)[source]#

Load pretrained model from checkpoint.

This implementation is preliminary and may change.

Parameters:
  • checkpoint_path (Optional[str]) -- Path to checkpoint. This can also be a URL, or file-like object.

  • map_location (Optional[Any]) --

    MAP_LOCATION_TYPE as in torch.load(). Defaults to ‘cpu’.

If you prefer to load a checkpoint saved from a GPU model onto the GPU, set it to None (so weights are not moved to another device) or specify a target device.

  • hparams_file (Union[str, Path, None]) --

Path to a .yaml file with a hierarchical structure, as in this example:

    num_classes: 5994
    config:
        channels: 1024
    

You most likely won't need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights do not have the hyperparameters saved, use this argument to pass in a .yaml file with the hparams you would like to use. These will be converted into a dict and passed into your model for use.

  • ignore_mismatched_sizes (bool) -- Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model (for instance, when instantiating a model with 10 labels from a checkpoint with 3 labels). Defaults to False.

  • strict (bool) -- Whether to strictly enforce that the keys in the checkpoint match the keys returned by this module's state dict. Defaults to True.

  • **kwargs -- Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values.

Return type:

TopVirtualModel
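
A minimal usage sketch (the checkpoint path is hypothetical):

from egrecho.models.ecapa.model import EcapaModel
# Hypothetical checkpoint path; extra kwargs may override saved hyperparameters.
model = EcapaModel.from_pretrained('path/to/model_weight.ckpt', map_location='cpu')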

save_to(savedir, **kwargs)[source]#

Saves the model to a directory.

Parameters:
  • savedir -- The target directory.

  • **kwargs -- Extra keyword arguments, e.g., additional components to save alongside the model (components=extractor in the example below).
Example:

from egrecho.models.ecapa.model import EcapaModel
from egrecho.data.features.feature_extractor_audio import KaldiFeatureExtractor
extractor = KaldiFeatureExtractor()
model = EcapaModel()
dirpath = 'testdir/ecapa'
model.save_to(dirpath, components=extractor)

The resulting directory layout:

$ tree testdir/ecapa
testdir/ecapa/
├── config
│   ├── feature_config.yaml
│   ├── model_config.yaml
│   └── types.yaml
└── model_weight.ckpt

For loading the saved directory back, see the example in fetch_from() below.
classmethod fetch_from(dirpath, config=None, init_weight='pretrained', map_location='cpu', strict=True, save_load_helper=None, **kwargs)[source]#

Fetch pretrained from a directory.

Parameters:
  • dirpath -- The source directory.

  • config (Union[str, Path, Dict[str, Any], None]) -- A config path/dict that can override the underlying model init config (model_config.yaml).

  • init_weight (Union[Literal['pretrained', 'random'], None, str]) -- How to initialize weights: 'pretrained' or 'random', a checkpoint file name (model_weight.ckpt), or a full path to a checkpoint (/path/to/model_weight.ckpt). Default: 'pretrained'.

  • map_location (Optional[device]) -- MAP_LOCATION_TYPE as in torch.load(). Defaults to 'cpu'.

  • strict (bool) -- Whether to strictly enforce that the keys in the checkpoint match the keys returned by this module's state dict. Default: True.

  • save_load_helper (Optional[SaveLoadHelper]) -- An instance of SaveLoadHelper. Default: None, which creates a default SaveLoadHelper.

  • **kwargs (Dict[str,Any]) -- Additional parameters to override the model config.

Example:

from egrecho.models.ecapa.model import EcapaModel
from egrecho.data.features.feature_extractor_audio import KaldiFeatureExtractor
extractor = KaldiFeatureExtractor()
model = EcapaModel()
dirpath = 'testdir/ecapa'
model.save_to(dirpath, components=extractor)
$ tree testdir/ecapa
testdir/ecapa/
├── config
│   ├── feature_config.yaml
│   ├── model_config.yaml
│   └── types.yaml
└── model_weight.ckpt
model = EcapaModel.fetch_from(dirpath)
assert isinstance(model, EcapaModel)
# fetch extractor
hresults = model.save_load_helper.fetch_from(dirpath, skip_keys='model')
assert isinstance(hresults.feature_extractor, KaldiFeatureExtractor)
# instantiation via the base class works while types.yaml records the concrete type
model = TopVirtualModel.fetch_from(dirpath)
assert isinstance(model, EcapaModel)
# now remove types.yaml
# rm -f testdir/ecapa/config/types.yaml
model = TopVirtualModel.fetch_from(dirpath)
# Error: cannot instantiate TopVirtualModel directly
# Let's complete the model type
model_cls = 'egrecho.models.ecapa.model.EcapaModel'
model = TopVirtualModel.fetch_from(dirpath, _cls_=model_cls)
assert isinstance(model, EcapaModel)
# a class type is OK
model_cls = EcapaModel
model = TopVirtualModel.fetch_from(dirpath, _cls_=model_cls)
assert isinstance(model, EcapaModel)
# a class-name string is OK as EcapaModel is already imported
model_cls = 'EcapaModel'
model = TopVirtualModel.fetch_from(dirpath, _cls_=model_cls)
assert isinstance(model, EcapaModel)
model_cls = 'Valle'
# Error as 'Valle' is not registered
model = TopVirtualModel.fetch_from(
    dirpath,
    _cls_=model_cls,
    init_weight="random",
    config_fname='anyinvalid.yaml',
    config=None,
)  # load the model without weights and bypass the Ecapa directory's model_config.yaml
from egrecho.models.valle.model import Valle
# Try again.
model = TopVirtualModel.fetch_from(
    dirpath,
    _cls_=model_cls,
    init_weight="random",
    config_fname='???.yaml',
    config=None,
)
assert isinstance(model, Valle)
auto(layer, x)[source]#

A convenience for forward computation when layer may be None: the layer is applied if present, otherwise x is returned unchanged.
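
A minimal usage sketch (the encoder and dropout attributes are illustrative; dropout may be None):

def forward(self, x):
    x = self.encoder(x)
    # Apply dropout only if it is configured, otherwise pass x through.
    x = self.auto(self.dropout, x)
    return x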

export_onnx(file_path, input_sample=None, **kwargs)[source]#

Exports the model in ONNX format in tracing mode.

Parameters:
  • file_path (Union[str, Path]) -- The path of the file the onnx model should be saved to.

  • input_sample (Optional[Any]) -- An input for tracing. Default: None (Use self.example_input_array)

  • **kwargs (Any) -- Will be passed to torch.onnx.export().

Return type:

None

Note

This general method may not be appropriate for every model; you can override it for your specific model. If you want a scripted (rather than traced) ONNX model, you should override this method, e.g., exporting a torch.jit.script-ed module instead.

Example:

class SimpleModel(TopVirtualModel):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

model = SimpleModel()
input_sample = torch.randn(1, 64)
model.export_onnx("export.onnx", input_sample, export_params=True)
export_jit(file_path=None, method='script', input_sample=None, **kwargs)[source]#

Exports the model to a TorchScript representation for inference or saving.

By default, compiles the entire model to a ScriptModule. If you prefer to use tracing, provide the argument method='trace' and ensure that either the input_sample argument is provided or the model has example_input_array set for tracing. To customize which modules are scripted, you can override this method. To return multiple modules, use a dictionary.

Parameters:
  • file_path (Optional[Union[str, Path]]) -- Path to save the TorchScript representation. Default: None (no file saved).

  • method (Optional[str]) -- Choose between ‘script’ (default) and ‘trace’ for TorchScript compilation methods.

  • input_sample (Optional[Any]) -- An input to be used for tracing when method is set to 'trace'. Default: None (uses example_input_array if available).

  • **kwargs (Any) -- Additional arguments passed to torch.jit.script() or torch.jit.trace().

Note

  • The exported script will be set to evaluation mode.

  • It is recommended to install the latest supported version of PyTorch to use this feature without limitations.

Refer to the torch.jit documentation for supported features.

Example:

class SimpleModel(TopVirtualModel):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

model = SimpleModel()
model.export_jit("exported_model.pt")
Returns:

The converted TorchScript representation.

Return type:

Union[ScriptModule, Dict[str, ScriptModule]]

class egrecho.core.module.DataModule(builder, batch_size=None, num_workers=0, prefetch_factor=None, val_num_workers=0, pin_memory=True, fullsync=True, **extra_dl_kwargs)[source]#

Bases: LightningDataModule

A simple Lightning datamodule wrapper for dataloaders.

The iterable dataset in IterableDatasetWrapper automatically shards samples across ranks, so we should load the dataset in the setup() hook, as this hook is called on every process when using DDP.

Parameters:
  • builder (DataBuilder) -- The data builder instance of DataBuilder responsible for creating the dataset.

  • batch_size (Optional[int]) -- The batch size for the DataLoader. Defaults to None, which suits iterable datasets.

  • num_workers (int) -- The number of workers for DataLoader. Default is 0.

  • prefetch_factor (Optional[int]) -- The prefetch factor for DataLoader. Default is None.

  • val_num_workers (Optional[int]) -- The number of workers for the validation DataLoader. Defaults to 0. If set to None, it uses the same number of workers as num_workers.

  • pin_memory (bool) -- Whether to pin memory in DataLoader. Default is True.

  • fullsync (bool) -- Whether to use SyncDataLoader. Default is True.

  • **extra_dl_kwargs -- Additional keyword arguments to pass to DataLoader.
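
A minimal usage sketch (builder, model, and trainer are assumed to already exist):

datamodule = DataModule(builder, num_workers=4)
datamodule.setup_data()  # builds datasets and assigns dataloader funcs
trainer.fit(model, datamodule=datamodule)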

setup_data()[source]#

Builds the datasets and assigns the dataloader functions to the Lightning datamodule.

Return type:

None

property train_dataset: Dataset | None#

This property returns the train dataset.

property val_dataset: Dataset | None#

This property returns the validation dataset.

property test_dataset: Dataset | None#

This property returns the test dataset.

class egrecho.core.model_base.ModuleConfig[source]#

Bases: DataclassConfig

Base class for model configuration.

class egrecho.core.model_base.ModelBase(*args, **kwargs)[source]#

Bases: ModuleUtilMixin, Module

A virtual backbone that provides common utilities.

Its implementations aggregate components of submodules into a model.

post_init()[source]#

Gives a chance to perform additional operations at the end of the model’s initialization process.
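
A minimal sketch of overriding this hook (the weight-initialization routine is illustrative):

class MyBackbone(ModelBase):
    def post_init(self):
        # Runs at the end of initialization; a typical use is final weight init.
        for p in self.parameters():
            if p.dim() > 1:
                torch.nn.init.xavier_uniform_(p)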

property dummy_inputs: Tensor | Tuple | Dict | None#

Dummy inputs to do a forward pass in the network.

The return type is interpreted as follows:

  • Single tensor: It is assumed the model takes a single argument, i.e., model.forward(model.dummy_inputs).

  • Tuple: The inputs are interpreted as a sequence of positional arguments, i.e., model.forward(*model.dummy_inputs).

  • Dict: The inputs are interpreted as named keyword arguments, i.e., model.forward(**model.dummy_inputs).
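
A minimal sketch of providing dummy inputs in a subclass (the key "x" and the shape are illustrative):

class MyBackbone(ModelBase):
    @property
    def dummy_inputs(self):
        # Dict form: consumed as model.forward(**model.dummy_inputs).
        return {"x": torch.randn(1, 64)}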

export_onnx(file_path, input_sample=None, **kwargs)[source]#

Exports the model in ONNX format in tracing mode.

Parameters:
  • file_path (Union[str, Path]) -- The path of the file the onnx model should be saved to.

  • input_sample (Optional[Any]) -- An input for tracing. Default: None (Use dummy_inputs)

  • **kwargs (Any) -- Will be passed to torch.onnx.export().

Return type:

None

Note

This general method may not be appropriate for every model; you can override it for your specific model.

Example:

class SimpleModel(ModelBase):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

model = SimpleModel()
input_sample = torch.randn(1, 64)
model.export_onnx("export.onnx", input_sample, export_params=True)
export_jit(file_path=None, method='script', input_sample=None, **kwargs)[source]#

Exports the model to a TorchScript representation for inference or saving.

By default, compiles the entire model to a ScriptModule. If you prefer to use tracing, provide the argument method='trace' and ensure that either the input_sample argument is provided or the model has dummy_inputs set for tracing. To customize which modules are scripted, you can override this method. To return multiple modules, use a dictionary.

Parameters:
  • file_path (Optional[Union[str, Path]]) -- Path to save the TorchScript representation. Default: None (no file saved).

  • method (Optional[str]) -- Choose between ‘script’ (default) and ‘trace’ for TorchScript compilation methods.

  • input_sample (Optional[Any]) -- An input to be used for tracing when method is set to ‘trace’. Default: None (uses dummy_inputs) if available.

  • **kwargs (Any) -- Additional arguments passed to torch.jit.script() or torch.jit.trace().

Note

  • The exported script will be set to evaluation mode.

  • It is recommended to install the latest supported version of PyTorch for using this feature without limitations.

Refer to the torch.jit documentation for supported features.

Example:

class SimpleModel(ModelBase):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

model = SimpleModel()
model.export_jit("exported_model.pt")
Returns:

The converted TorchScript representation.

Return type:

Union[ScriptModule, Dict[str, ScriptModule]]