egrecho.core.teacher#

class egrecho.core.teacher.Teacher[source]#

Bases: ABC

Base class for a teacher that provides access to the criterion and metrics for the model and configures training details.

Usually, the step details in the fit (train + validation) stage depend on the data and the objectives (criterion), combined with the model’s forward(). A teacher class is designed to encapsulate this training logic and gives the model an entry point for stepping. Key methods:

  • training_step() (must be implemented) and validation_step():

    define the detailed step logic.

  • setup_model() (must be implemented):

    an interface for building the linked model, e.g., loss functions and metrics can be configured here.

  • configure_optimizers() (must be implemented):

    an interface for configuring optimizers and, optionally, learning rate schedulers.

  • setup():

    • called in self.model.setup.

    • build models dynamically or adjust them at the beginning of the fit stage.

    see setup().

See egrecho.models.architecture.speaker.asv_task.SVTeacher for an example usage.

Example:

class LitTeacher(Teacher):
    def __init__(
        self,
        num_classes,
        optimizer='adam',
        lr_scheduler='warm_cosine',
        lr=0.01,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.lr = lr

    def setup_model(self):
        self.setup_loss_fn_dict({"loss": F.cross_entropy})

    def configure_optimizers(self):
        return self.configure_single_optimizer(
            self.optimizer, self.lr_scheduler, self.lr
        )

    def training_step(self, batch, batch_idx):
        # step logic goes here, e.g., forward the batch and compute the loss
        ...


class LitModel(TopVirtualModule):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28, 1000)
        self.classifier = nn.Linear(1000, 42)


model = LitModel()
teacher = LitTeacher(42)
# link teacher & pre-configure training details.
model.setup_teacher(teacher)
dataloader = ...
trainer = pl.Trainer(...)
trainer.fit(model, dataloader)

Note

In this class, setup_loss_fn_dict() and setup_train_metrics() are adapted from flash. Feel free to modify them in derived implementations.

Class attributes (overridden by derived classes):
  • task_name (str) -- name of task, e.g., “automatic-speaker-verification”.

attach_model(model)[source]#

Attach the model; do this once before stepping.

setup()[source]#

Called by the linked model’s setup() hook in the fit stage; gives a chance to set up in the accelerated environment.

setup_model()[source]#

An interface for building the linked model.

setup_loss_fn_dict(loss_fn)[source]#

Uses the attribute loss_fn_dict in the model as the loss container.

setup_train_metrics(metrics)[source]#

Uses the attribute train_metrics in the model as the metrics container (torchmetrics recommended).

setup_val_metrics(metrics)[source]#

Uses the attribute val_metrics in the model as the metrics container (torchmetrics recommended).
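
A hedged sketch of a derived teacher’s setup_model() wiring these containers onto the linked model; the torchmetrics.Accuracy choice and its task/num_classes arguments are assumptions for illustration, not part of this API:

import torch.nn.functional as F
import torchmetrics

def setup_model(self):
    # register the criterion under the "loss" key (matches the class example above)
    self.setup_loss_fn_dict({"loss": F.cross_entropy})
    # attach train/val metric containers on the linked model
    acc = torchmetrics.Accuracy(task="multiclass", num_classes=self.num_classes)
    self.setup_train_metrics({"acc": acc})
    self.setup_val_metrics({"acc": acc.clone()})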

training_step(batch, batch_idx)[source]#

Core method to be implemented.

validation_step(batch, batch_idx)[source]#

Core method to be implemented.
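
A hedged sketch of the two step methods, assuming the batch is an (inputs, labels) pair, the linked model’s forward() returns logits, and the criterion was registered under the "loss" key as in the example above:

def training_step(self, batch, batch_idx):
    inputs, labels = batch
    logits = self.model(inputs)
    # pick the criterion out of the container configured in setup_model()
    loss = self.model.loss_fn_dict["loss"](logits, labels)
    return loss

def validation_step(self, batch, batch_idx):
    inputs, labels = batch
    logits = self.model(inputs)
    return self.model.loss_fn_dict["loss"](logits, labels)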

on_train_start()[source]#

Called at the very beginning of training.

If on DDP, it is called on every process.

Return type:

None

on_train_end()[source]#

Called at the very end of training.

If on DDP, it is called on every process.

Return type:

None

classmethod available_optimizers()[source]#

Returns a list containing the keys of the available Optimizers.

Return type:

List[str]

classmethod available_lr_schedulers()[source]#

Returns a list containing the keys of the available LR schedulers.

Return type:

List[str]
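
A quick way to inspect the registered keys; the printed values below are illustrative, not guaranteed:

from egrecho.core.teacher import Teacher

print(Teacher.available_optimizers())     # e.g. ['adam', 'sgd', ...]
print(Teacher.available_lr_schedulers())  # e.g. ['warm_cosine', ...]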

configure_optimizers()[source]#

Implement this method in subclasses. See lightning.pytorch.LightningModule.configure_optimizers().

configure_single_optimizer(optimizer, lr_scheduler=None, learning_rate=None)[source]#

Implement how optimizer and optionally learning rate schedulers should be configured.

Return type:

Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]

instantiate_lr_scheduler(optimizer, lr_scheduler)[source]#

Instantiates lr_scheduler into Lightning’s lr_scheduler config.

Return type:

Dict[str, Any]
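
A hedged sketch of a derived teacher’s configure_optimizers() that builds the optimizer by hand and lets instantiate_lr_scheduler() produce the Lightning-style lr_scheduler config; the 'warm_cosine' key, the learning rate, and passing the scheduler as a registered key string are assumptions:

import torch

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
    lr_scheduler_config = self.instantiate_lr_scheduler(optimizer, "warm_cosine")
    return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}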

classmethod get_lr_scheduler_total_steps_name(lr_scheduler_key)[source]#

Try to get the key name for the number of training steps for lr_scheduler if needed.

  • Use the metadata total_steps_key=... registered in the registry of that scheduler key.

  • Find the signature of the registered fn and return the signature param key that appears in lr_total_steps_key_registry of this teacher class.

  • Return None if all fail.

You can dynamically register keys in lr_total_steps_key_registry.

Return type:

Optional[str]
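
A hedged sketch inside a derived teacher’s configure_optimizers(): look up which kwarg (if any) the registered scheduler expects for its total number of training steps and fill it from the teacher’s own estimate; the 'warm_cosine' key is an assumption, and how the kwargs are then forwarded to the scheduler is elided:

def configure_optimizers(self):
    total_steps_key = self.get_lr_scheduler_total_steps_name("warm_cosine")
    scheduler_kwargs = {}
    if total_steps_key is not None:
        # e.g. fills {'total_steps': ...}; the exact key depends on the registry
        scheduler_kwargs[total_steps_key] = self.get_num_training_steps()
    ...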

configure_teacher_callbacks()[source]#

Configure teacher-specific callbacks.

Manually call this function to get callbacks and add them to your trainer.

Example:

class LitTeacher(Teacher):
    def __init__(self):
        super().__init__()

    def setup_model(self):
        self.model.classifier = nn.Linear(1000, 42)

    def configure_teacher_callbacks(self):
        return [PrintCallback()]

    def training_step(self, batch, batch_idx):
        pass

class PrintCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is started!")

    def on_train_end(self, trainer, pl_module):
        print("Training is done.")


class LitModel(TopVirtualModule):
    def __init__(self, teacher):
        super().__init__()
        self.l1 = None
        self.teacher = teacher
        self.teacher.setup_model()


teacher = LitTeacher()
model = LitModel(teacher)
t_callbacks = model.configure_teacher_callbacks()
trainer = Trainer(accelerator="gpu", devices=2, callbacks=t_callbacks)

property estimated_num_steps_per_epoch: int | float#

The estimated number of optimizer.step() calls during training in one epoch.

This accounts for gradient accumulation and the current trainer configuration. This might set up your training dataloader if it hasn’t been set up already.

def configure_optimizers(self):
    optimizer = ...
    steps_per_epoch = self.estimated_num_steps_per_epoch
    num_epoch = 10
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-3, total_steps=steps_per_epoch * num_epoch
    )
    return [optimizer], [scheduler]

Raises:

MisconfigurationException -- If estimated stepping batches cannot be computed due to different accumulate_grad_batches at different epochs.

get_num_training_steps()[source]#

Total training steps inferred from datamodule and devices.

Return type:

int

egrecho.core.teacher.normalize_callable_dict(fn)[source]#

Normalize a class/function into a dict.

Return type:

Union[Dict, Mapping]
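
A minimal sketch of the intent, assuming a bare callable is wrapped into a one-entry mapping while a dict passed in is returned as a mapping; the exact key naming is an implementation detail:

import torch.nn.functional as F
from egrecho.core.teacher import normalize_callable_dict

loss_fns = normalize_callable_dict(F.cross_entropy)
# -> a mapping like {'cross_entropy': F.cross_entropy} (key name assumed)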