egrecho.core.module#
- class egrecho.core.module.TopVirtualModel(config, *args, **kwargs)[source]#
Bases: LightningModule, GenericFileMixin
A lightning module for training, validation, and testing.
In the fit (train + validate) stage, you need to set
self.teacher
, which configures step logic, dataloaders, criterion, etc.
- setup(stage)[source]#
Hook of
lightning.pytorch.core.hooks.DataHooks.setup()
. Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage (str) -- either 'fit', 'validate', 'test', or 'predict'
- Return type:
None
Example
>>> class LitModel(TopVirtualModel):
...     def __init__(self):
...         super().__init__()
...         self.l1 = None
...
...     def setup(self, stage):
...         if stage == 'fit':
...             self.l1 = nn.Linear(28, 1000)
...             self.teacher.setup_model()
...
>>> class LitTeacher(Teacher):
...     def __init__(self):
...         super().__init__()
...
...     def setup_model(self):
...         self.model.classifier = nn.Linear(1000, 42)
...
...     def training_step(self, batch, batch_idx):
...         pass
...
>>> model = LitModel()
>>> teacher = LitTeacher()
>>> model.teacher = teacher
>>> model.setup("fit")
>>> assert model.l1 is not None
>>> assert model.classifier is not None
- training_step(*args, **kwargs)[source]#
Redirection to
training_step()
in teacher.
- validation_step(*args, **kwargs)[source]#
Redirection to
validation_step()
in teacher.
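A minimal sketch of this delegation pattern; the criterion, layer sizes, and class names below are illustrative assumptions, not part of the API:

import torch.nn.functional as F
from torch import nn

class ClsTeacher(Teacher):
    def setup_model(self):
        # attach a task head to the wrapped model
        self.model.classifier = nn.Linear(1000, 42)

    def training_step(self, batch, batch_idx):
        # hypothetical classification criterion
        x, y = batch
        logits = self.model(x)
        return F.cross_entropy(logits, y)

model.teacher = ClsTeacher()  # model: an instance of a TopVirtualModel subclass
# Lightning calls model.training_step(batch, batch_idx), which forwards to the teacher.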
- classmethod pipeline_out(model_out)[source]#
Transform output (
forward()
) to a dict for pipeline use. Override it for your specific model.
- Return type:
Dict
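For example, a hedged sketch of an override, assuming a model whose forward() returns a (logits, embeddings) tuple (the names and output layout are illustrative):

from typing import Dict

class MyModel(TopVirtualModel):
    @classmethod
    def pipeline_out(cls, model_out) -> Dict:
        # assumed forward() output layout: (logits, embeddings)
        logits, embd = model_out
        return {"logits": logits, "embd": embd}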
- on_train_start()[source]#
Called at the very beginning of training.
When using DDP, it is called on every process.
- Return type:
None
- on_train_end()[source]#
Called at the very end of training.
When using DDP, it is called on every process.
- Return type:
None
- configure_optimizers()[source]#
Redirection to
configure_optimizers()
in teacher.
- configure_teacher_callbacks()[source]#
Configure teacher-specific callbacks.
Manually call this function to get the callbacks and add them to your trainer.
- Example::
>>> class LitTeacher(Teacher):
...     def __init__(self):
...         super().__init__()
...
...     def setup_model(self):
...         self.model.classifier = nn.Linear(1000, 42)
...
...     def configure_teacher_callbacks(self):
...         return [PrintCallback()]
...
...     def training_step(self, batch, batch_idx):
...         pass
...
>>> class PrintCallback(Callback):
...     def on_train_start(self, trainer, pl_module):
...         print("Training is started!")
...
...     def on_train_end(self, trainer, pl_module):
...         print("Training is done.")
...
>>> class LitModel(TopVirtualModel):
...     def __init__(self, teacher):
...         super().__init__()
...         self.l1 = None
...         self.teacher = teacher
...         self.teacher.setup_model()
...
>>> teacher = LitTeacher()
>>> model = LitModel(teacher)
>>> t_callbacks = model.configure_teacher_callbacks()
>>> trainer = Trainer(accelerator="gpu", devices=2, callbacks=t_callbacks)
- classmethod from_pretrained(checkpoint_path=None, map_location='cpu', hparams_file=None, ignore_mismatched_sizes=False, strict=True, **kwargs)[source]#
Load pretrained model from checkpoint.
This is currently a rough implementation, to be refined.
- Parameters:
checkpoint_path (Optional[str]) -- Path to the checkpoint. This can also be a URL or a file-like object.
map_location (Optional[Any]) -- MAP_LOCATION_TYPE as in torch.load(). Defaults to 'cpu'.
If you prefer to load a checkpoint saved from a GPU model onto a GPU, set it to None (so it is not moved to another device) or set a specific device.
hparams_file (Union[str, Path, None]) -- Optional path to a .yaml file with hierarchical structure as in this example:

    num_classes: 5994
    config:
      channels: 1024

You most likely won't need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights do not have the hyperparameters saved, use this method to pass in a .yaml file with the hparams you would like to use. These will be converted into a dict and passed into your Model for use.
ignore_mismatched_sizes (bool) -- Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model (if, for instance, you are instantiating a model with 10 labels from a checkpoint with 3 labels). Defaults to False.
strict (bool) -- Whether to strictly enforce that the keys in the checkpoint match the keys returned by this module's state dict. Defaults to True.
kwargs -- Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values.
- Return type:
TopVirtualModel
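A usage sketch; the checkpoint path here is hypothetical:

from egrecho.models.ecapa.model import EcapaModel

# 'exp/ecapa/last.ckpt' is a hypothetical checkpoint path
model = EcapaModel.from_pretrained(
    "exp/ecapa/last.ckpt",
    map_location="cpu",
    strict=True,
)
model.eval()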
- save_to(savedir, **kwargs)[source]#
Saves to a directory.
- Parameters:
savedir -- Target directory path.
**kwargs -- Arguments passed to save_to() of the SaveLoadHelper object.
Example:
from egrecho.models.ecapa.model import EcapaModel
from egrecho.data.features.feature_extractor_audio import KaldiFeatureExtractor

extractor = KaldiFeatureExtractor()
model = EcapaModel()
dirpath = 'testdir/ecapa'
model.save_to(dirpath, components=extractor)
$ tree testdir/ecapa
testdir/ecapa/
├── config
│   ├── feature_config.yaml
│   ├── model_config.yaml
│   └── types.yaml
└── model_weight.ckpt
model = EcapaModel.fetch_from(dirpath)
assert isinstance(model, EcapaModel)

# fetch extractor
hresults = model.save_load_helper.fetch_from(dirpath, skip_keys='model')
assert isinstance(hresults.feature_extractor, KaldiFeatureExtractor)

# base model instantiation
model = TopVirtualModel.fetch_from(dirpath)
assert isinstance(model, EcapaModel)

# now remove types.yaml
# rm -f testdir/ecapa/config/types.yaml
model = TopVirtualModel.fetch_from(dirpath)  # Error: would instantiate TopVirtualModel

# Let's complete the model type
model_cls = 'egrecho.models.ecapa.model.EcapaModel'
model = TopVirtualModel.fetch_from(dirpath, _cls_=model_cls)
assert isinstance(model, EcapaModel)

# Type is ok
model_cls = EcapaModel
model = TopVirtualModel.fetch_from(dirpath, _cls_=model_cls)
assert isinstance(model, EcapaModel)

# classname string is ok as EcapaModel is already imported
model_cls = 'EcapaModel'
model = TopVirtualModel.fetch_from(dirpath, _cls_=model_cls)
assert isinstance(model, EcapaModel)

model_cls = 'Valle'
# Error as 'Valle' is not registered.
# Only load the model without weights and bypass the model_config.yaml of the Ecapa model directory.
model = TopVirtualModel.fetch_from(
    dirpath,
    _cls_=model_cls,
    init_weight="random",
    config_fname='anyinvalid.yaml',
    config=None,
)

from egrecho.models.valle.model import Valle

# Try again.
model = TopVirtualModel.fetch_from(
    dirpath,
    _cls_=model_cls,
    init_weight="random",
    config_fname='???.yaml',
    config=None,
)
assert isinstance(model, Valle)
- classmethod fetch_from(dirpath, config=None, init_weight='pretrained', map_location='cpu', strict=True, save_load_helper=None, **kwargs)[source]#
Fetch a pretrained model from a directory.
- Parameters:
dirpath -- Source directory.
config (Union[str, Path, Dict[str, Any], None]) -- Config path/dict which can override the underlying model init cfg (model_config.yaml).
init_weight (Union[Literal['pretrained', 'random'], None, str]) -- Init weight from ('pretrained'|'random'), a string ckpt name (model_weight.ckpt), or a full path to a ckpt /path/to/model_weight.ckpt. Default: 'pretrained'.
map_location (Optional[device]) -- MAP_LOCATION_TYPE as in torch.load(). Defaults to 'cpu'.
strict (bool) -- Whether to strictly enforce that the keys in the checkpoint match the keys returned by this module's state dict. Default: True
save_load_helper (Optional[SaveLoadHelper]) -- A save_load_helper object. Default: None, which will initiate a default SaveLoadHelper.
**kwargs (Dict[str, Any]) -- Additional parameters of model cfg.
Example:
from egrecho.models.ecapa.model import EcapaModel
from egrecho.data.features.feature_extractor_audio import KaldiFeatureExtractor

extractor = KaldiFeatureExtractor()
model = EcapaModel()
dirpath = 'testdir/ecapa'
model.save_to(dirpath, components=extractor)
$ tree testdir/ecapa
testdir/ecapa/
├── config
│   ├── feature_config.yaml
│   ├── model_config.yaml
│   └── types.yaml
└── model_weight.ckpt
model = EcapaModel.fetch_from(dirpath)
assert isinstance(model, EcapaModel)

# fetch extractor
hresults = model.save_load_helper.fetch_from(dirpath, skip_keys='model')
assert isinstance(hresults.feature_extractor, KaldiFeatureExtractor)

# base model instantiation
model = TopVirtualModel.fetch_from(dirpath)
assert isinstance(model, EcapaModel)

# now remove types.yaml
# rm -f testdir/ecapa/config/types.yaml
model = TopVirtualModel.fetch_from(dirpath)  # Error: would instantiate TopVirtualModel

# Let's complete the model type
model_cls = 'egrecho.models.ecapa.model.EcapaModel'
model = TopVirtualModel.fetch_from(dirpath, _cls_=model_cls)
assert isinstance(model, EcapaModel)

# Type is ok
model_cls = EcapaModel
model = TopVirtualModel.fetch_from(dirpath, _cls_=model_cls)
assert isinstance(model, EcapaModel)

# classname string is ok as EcapaModel is already imported
model_cls = 'EcapaModel'
model = TopVirtualModel.fetch_from(dirpath, _cls_=model_cls)
assert isinstance(model, EcapaModel)

model_cls = 'Valle'
# Error as 'Valle' is not registered.
# Only load the model without weights and bypass the model_config.yaml of the Ecapa model directory.
model = TopVirtualModel.fetch_from(
    dirpath,
    _cls_=model_cls,
    init_weight="random",
    config_fname='anyinvalid.yaml',
    config=None,
)

from egrecho.models.valle.model import Valle

# Try again.
model = TopVirtualModel.fetch_from(
    dirpath,
    _cls_=model_cls,
    init_weight="random",
    config_fname='???.yaml',
    config=None,
)
assert isinstance(model, Valle)
- export_onnx(file_path, input_sample=None, **kwargs)[source]#
Exports the model in ONNX format in tracing mode.
- Parameters:
file_path (Union[str, Path]) -- The path of the file the ONNX model should be saved to.
input_sample (Optional[Any]) -- An input for tracing. Default: None (uses self.example_input_array).
**kwargs (Any) -- Will be passed to torch.onnx.export().
- Return type:
None
Note
This general method may not be appropriate for every model; you can override it for your specific model. If you want a scripted ONNX model, you should override this method.
Example:
class SimpleModel(TopVirtualModel):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

model = SimpleModel()
input_sample = torch.randn(1, 64)
model.export_onnx("export.onnx", input_sample, export_params=True)
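After exporting, the file can be sanity-checked with onnxruntime (assuming onnxruntime is installed; the shapes match the example above):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("export.onnx")
input_name = sess.get_inputs()[0].name
# run a dummy batch through the exported graph
(out,) = sess.run(None, {input_name: np.random.randn(1, 64).astype(np.float32)})
assert out.shape == (1, 4)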
- export_jit(file_path=None, method='script', input_sample=None, **kwargs)[source]#
Exports the model to a TorchScript representation for inference or saving.
By default, compiles the entire model to a
ScriptModule
. If you prefer to use tracing, provide the argument method='trace'
and ensure that either the input_sample
argument is provided or the model has example_input_array
set for tracing. To customize which modules are scripted, you can override this method. To return multiple modules, use a dictionary.
- Parameters:
file_path (Optional[Union[str, Path]]) -- Path to save the TorchScript representation. Default: None (no file saved).
method (Optional[str]) -- Choose between ‘script’ (default) and ‘trace’ for TorchScript compilation methods.
input_sample (Optional[Any]) -- An input to be used for tracing when method is set to ‘trace’.
Default: None (uses example_input_array if available).
**kwargs (Any) -- Additional arguments passed to torch.jit.script() or torch.jit.trace().
Note
The exported script will be set to evaluation mode.
It is recommended to install the latest supported version of PyTorch to use this feature without limitations.
Refer to the torch.jit documentation for supported features.
Example:
class SimpleModel(TopVirtualModel):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

model = SimpleModel()
model.export_jit("exported_model.pt")
- Returns:
The converted TorchScript representation.
- Return type:
Union[ScriptModule, Dict[str, ScriptModule]]
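The saved artifact can then be loaded for inference with plain torch.jit, e.g. (shapes follow the example above):

import torch

scripted = torch.jit.load("exported_model.pt")
# export_jit already puts the script in eval mode; calling eval() again is just defensive
scripted.eval()
with torch.inference_mode():
    out = scripted(torch.randn(1, 64))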
- class egrecho.core.module.DataModule(builder, batch_size=None, num_workers=0, prefetch_factor=None, val_num_workers=0, pin_memory=True, fullsync=True, **extra_dl_kwargs)[source]#
Bases: LightningDataModule
A simple lightning datamodule wrapper for dataloaders. A usage sketch follows the parameter list.
The iterable dataset in
IterabelDatasetWrapper
automatically shards samples across ranks, so we should load the dataset in the setup()
hook, as it is called on every process when using DDP.
- Parameters:
builder (DataBuilder) -- The data builder instance of DataBuilder responsible for creating the dataset.
batch_size (Optional[int]) -- The batch size for DataLoader. Default is None for iterable datasets.
num_workers (int) -- The number of workers for DataLoader. Default is 0.
prefetch_factor (Optional[int]) -- The prefetch factor for DataLoader. Default is None.
val_num_workers (Optional[int]) -- The number of workers for the validation DataLoader. Defaults to 0. If set to None, it will use the same number of workers as num_workers.
pin_memory (bool) -- Whether to pin memory in DataLoader. Default is True.
fullsync (bool) -- Whether to use SyncDataLoader. Default is True.
**extra_dl_kwargs -- Additional keyword arguments to pass to DataLoader.
- setup_data()[source]#
Builds datasets and assigns dataloader functions to the lightning datamodule.
- Return type:
None
- property train_dataset: Dataset | None#
This property returns the train dataset.
- property val_dataset: Dataset | None#
This property returns the validation dataset.
- property test_dataset: Dataset | None#
This property returns the test dataset.
- class egrecho.core.model_base.ModuleConfig[source]#
Bases: DataclassConfig
Base class for model configuration.
- class egrecho.core.model_base.ModelBase(*args, **kwargs)[source]#
Bases: ModuleUtilMixin, Module
A virtual backbone that provides common utilities.
Implementations aggregate submodule components into a model.
- post_init()[source]#
Gives a chance to perform additional operations at the end of the model’s initialization process.
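A minimal sketch of one use; that subclasses call post_init() at the end of their __init__, and the init scheme itself, are illustrative assumptions:

import torch

class MyBackbone(ModelBase):
    def __init__(self, dim: int = 256):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)
        self.post_init()  # assumed call site: end of __init__

    def post_init(self):
        # e.g., apply a custom weight initialization once all submodules exist
        torch.nn.init.xavier_uniform_(self.proj.weight)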
- property dummy_inputs: Tensor | Tuple | Dict | None#
Dummy inputs to do a forward pass in the network.
The return type is interpreted as follows:
Single tensor: It is assumed the model takes a single argument, i.e., model.forward(model.dummy_inputs).
Tuple: The input is interpreted as a sequence of positional arguments, i.e., model.forward(*model.dummy_inputs).
Dict: The input represents named keyword arguments, i.e., model.forward(**model.dummy_inputs).
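A sketch of the tuple convention; overriding the property, and the layer sizes, are illustrative assumptions:

import torch

class TwoArgModel(ModelBase):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 4)

    def forward(self, x, scale):
        return self.proj(x) * scale

    @property
    def dummy_inputs(self):
        # tuple -> called as model.forward(*model.dummy_inputs)
        return (torch.randn(2, 16), 2.0)

model = TwoArgModel()
out = model(*model.dummy_inputs)  # shape: (2, 4)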
- export_onnx(file_path, input_sample=None, **kwargs)[source]#
Exports the model in ONNX format in tracing mode.
- Parameters:
file_path (Union[str, Path]) -- The path of the file the ONNX model should be saved to.
input_sample (Optional[Any]) -- An input for tracing. Default: None (uses dummy_inputs).
**kwargs (Any) -- Will be passed to torch.onnx.export().
- Return type:
None
Note
This general method may not be appropriate for every model; you can override it for your specific model.
Example:
class SimpleModel(TopVirtualModel):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

model = SimpleModel()
input_sample = torch.randn(1, 64)
model.export_onnx("export.onnx", input_sample, export_params=True)
- export_jit(file_path=None, method='script', input_sample=None, **kwargs)[source]#
Exports the model to a TorchScript representation for inference or saving.
By default, compiles the entire model to a
ScriptModule
. If you prefer to use tracing, provide the argument method='trace'
and ensure that either the input_sample
argument is provided or the model has dummy_inputs
set for tracing. To customize which modules are scripted, you can override this method. To return multiple modules, use a dictionary.
- Parameters:
file_path (Optional[Union[str, Path]]) -- Path to save the TorchScript representation. Default: None (no file saved).
method (Optional[str]) -- Choose between ‘script’ (default) and ‘trace’ for TorchScript compilation methods.
input_sample (Optional[Any]) -- An input to be used for tracing when method is set to 'trace'. Default: None (uses dummy_inputs if available).
**kwargs (Any) -- Additional arguments passed to torch.jit.script() or torch.jit.trace().
Note
The exported script will be set to evaluation mode.
It is recommended to install the latest supported version of PyTorch to use this feature without limitations.
Refer to the PyTorch torch.jit documentation for supported features.
Example:
class SimpleModel(TopVirtualModel):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

model = SimpleModel()
model.export_jit("exported_model.pt")
- Returns:
The converted TorchScript representation.
- Return type:
Union[ScriptModule, Dict[str, ScriptModule]]