egrecho.core.teacher#
- class egrecho.core.teacher.Teacher[source]#
Bases: ABC
Base class for a teacher that provides access to the criterion and metrics for the model and configures training details.
Usually, the step details in the fit (train + validation) stage are related to the data and the objectives (criterion), and are further combined with the model's forward(). A teacher class is designed to extract this training logic; it gives the model an entry point for stepping. Key methods:
- training_step() (must be implemented) and validation_step(): detailed step logic.
- setup_model() (must be implemented): an interface for building the linked model, e.g., loss functions and metrics can be configured here.
- configure_optimizers() (must be implemented): an interface for configuring optimizers and, optionally, learning rate schedulers.
- Called in self.model.setup: build models dynamically or adjust something about them at the beginning of the fit stage. See setup().
See egrecho.models.architecture.speaker.asv_task.SVTeacher for an example use.
Example:
class LitTeacher(Teacher):
    def __init__(
        self,
        num_classes,
        optimizer='adam',
        lr_scheduler='warm_cosine',
        lr=0.01,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.lr = lr

    def setup_model(self):
        self.setup_loss_fn_dict({"loss": F.cross_entropy})

    def configure_optimizers(self):
        return self.configure_single_optimizer(
            self.optimizer, self.lr_scheduler, self.lr
        )

    def training_step(self, batch, batch_idx):
        ...  # some codes


class LitModel(TopVirtualModule):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28, 1000)
        self.classifier = nn.Linear(1000, 42)


model = LitModel()
teacher = LitTeacher(42)

# link teacher & pre-configure training details.
model.setup_teacher(teacher)

dataloader = ...
trainer = pl.Trainer(...)
trainer.fit(model, dataloader)
Note
In this class, setup_loss_fn_dict() and setup_train_metrics() are modified from flash. Feel free to modify them in derived implementations.
- Class attributes (overridden by derived classes):
task_name (str) -- name of the task, e.g., "automatic-speaker-verification".
- setup()[source]#
Called by the linked model's setup() hook in the fit stage, giving a chance to set up in the accelerated environment.
- setup_train_metrics(metrics)[source]#
Uses the model's train_metrics attribute as the metrics container (torchmetrics is recommended).
- setup_val_metrics(metrics)[source]#
Uses the model's val_metrics attribute as the metrics container (torchmetrics is recommended).
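A minimal sketch of wiring metrics inside setup_model(). It assumes the metrics argument accepts a dict of torchmetrics objects and that the teacher stores num_classes as in the class example above; check the actual signatures before relying on this:

from egrecho.core.teacher import Teacher
import torch.nn.functional as F
import torchmetrics


class LitTeacher(Teacher):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def setup_model(self):
        # Criterion goes into a loss-fn dict on the linked model.
        self.setup_loss_fn_dict({"loss": F.cross_entropy})
        # Metrics are kept on the model's train_metrics / val_metrics attributes.
        self.setup_train_metrics(
            {"acc": torchmetrics.Accuracy(task="multiclass", num_classes=self.num_classes)}
        )
        self.setup_val_metrics(
            {"acc": torchmetrics.Accuracy(task="multiclass", num_classes=self.num_classes)}
        )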
- on_train_start()[source]#
Called at the very beginning of training.
If on DDP, it is called on every process.
- Return type:
None
- on_train_end()[source]#
Called at the very end of training.
If on DDP, it is called on every process.
- Return type:
None
- classmethod available_optimizers()[source]#
Returns a list containing the keys of the available Optimizers.
- Return type:
List[str]
- classmethod available_lr_schedulers()[source]#
Returns a list containing the keys of the available LR schedulers.
- Return type:
List[str]
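The registry keys that can be passed as optimizer / lr_scheduler strings can be listed with these classmethods; the values shown in the comments below are illustrative only:

from egrecho.core.teacher import Teacher

print(Teacher.available_optimizers())     # e.g. ['adam', 'sgd', ...]
print(Teacher.available_lr_schedulers())  # e.g. ['warm_cosine', ...]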
- configure_optimizers()[source]#
Implement this method in subclasses. See lightning.pytorch.LightningModule.configure_optimizers().
- configure_single_optimizer(optimizer, lr_scheduler=None, learning_rate=None)[source]#
Implement how optimizer and optionally learning rate schedulers should be configured.
- Return type:
Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]
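A minimal sketch of delegating configure_optimizers() to this helper, assuming 'adam' and 'warm_cosine' are registered keys as in the class example above:

class LitTeacher(Teacher):
    def configure_optimizers(self):
        # Registry keys are resolved into an optimizer and (optionally) an lr scheduler.
        return self.configure_single_optimizer(
            optimizer='adam',
            lr_scheduler='warm_cosine',
            learning_rate=0.01,
        )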
- instantiate_lr_scheduler(optimizer, lr_scheduler)[source]#
Instantiates lr_scheduler into Lightning's lr_scheduler config.
- Return type:
Dict[str, Any]
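A hedged sketch, assuming lr_scheduler here is an already-constructed torch scheduler instance and that model and teacher follow the class example above; the returned dict is in Lightning's lr_scheduler config format:

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Wrap the scheduler into a Lightning-style lr scheduler config dict.
lr_scheduler_config = teacher.instantiate_lr_scheduler(optimizer, scheduler)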
- classmethod get_lr_scheduler_total_steps_name(lr_scheduler_key)[source]#
Tries to get the key name for the number of training steps for the lr_scheduler, if needed.
- Use the metadata total_steps_key=... registered in the registry of that scheduler key.
- Find the signature of the registry fn and return the signature param key that is in lr_total_steps_key_registry of this teacher class.
- Return None if all fail.
You can dynamically register keys in lr_total_steps_key_registry.
- Return type:
Optional[str]
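A minimal sketch, assuming 'warm_cosine' is a registered scheduler key whose builder takes a total-steps argument; the actual key name depends on how that scheduler was registered:

from egrecho.core.teacher import Teacher

key_name = Teacher.get_lr_scheduler_total_steps_name('warm_cosine')
if key_name is not None:
    # Forward an estimated step count to the scheduler under that parameter name.
    total_steps = 10_000  # illustrative value
    scheduler_kwargs = {key_name: total_steps}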
- configure_teacher_callbacks()[source]#
Configure teacher-specific callbacks.
Manually call this function to get callbacks and add them to your trainer.
Example:
class LitTeacher(Teacher):
    def __init__(self):
        super().__init__()

    def setup_model(self):
        self.model.classifier = nn.Linear(1000, 42)

    def configure_teacher_callbacks(self):
        return [PrintCallback()]

    def training_step(self, batch, batch_idx):
        pass


class PrintCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is started!")

    def on_train_end(self, trainer, pl_module):
        print("Training is done.")


class LitModel(TopVirtualModule):
    def __init__(self, teacher):
        super().__init__()
        self.l1 = None
        self.teacher = teacher
        self.teacher.setup_model()


teacher = LitTeacher()
model = LitModel(teacher)
t_callbacks = model.configure_teacher_callbacks()
trainer = Trainer(accelerator="gpu", devices=2, callbacks=t_callbacks)
- property estimated_num_steps_per_epoch: int | float#
The estimated number of steps that will call optimizer.step() during training in one epoch.
This accounts for gradient accumulation and the current trainer configuration. This might set up your training dataloader if it hasn't been set up already.
def configure_optimizers(self):
    optimizer = ...
    stepping_batches = self.estimated_num_steps_per_epoch
    num_epoch = 10
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-3, total_steps=stepping_batches * num_epoch
    )
    return [optimizer], [scheduler]
- Raises:
MisconfigurationException -- If estimated stepping batches cannot be computed due to different accumulate_grad_batches at different epochs.