lit_mlflow.callback

source module lit_mlflow.callback

source class MlFlowAutoCallback(patch_device_monitor: bool = True)

Bases: Callback

Attributes

  • state_key : str Identifier for the state of the callback.
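A minimal usage sketch, assuming the callback is attached through the Trainer's callbacks argument like any other Lightning callback. The helper name build_trainer is hypothetical, and the choice of lightning.pytorch (rather than pytorch_lightning) as pl is an assumption that this page does not confirm. Imports are deferred into the function so the snippet can be loaded without either package installed.

```python
def build_trainer(max_epochs: int = 1):
    """Return a Trainer with MlFlowAutoCallback attached (hypothetical helper)."""
    # Assumption: lit_mlflow targets the `lightning` package. If it was built
    # against `pytorch_lightning`, import that package as `pl` instead.
    import lightning.pytorch as pl
    from lit_mlflow.callback import MlFlowAutoCallback

    # patch_device_monitor=True is the default shown in the class signature.
    callback = MlFlowAutoCallback(patch_device_monitor=True)
    return pl.Trainer(max_epochs=max_epochs, callbacks=[callback])
```

Calling build_trainer().fit(model) with your own LightningModule would then run training with the callback's hooks active.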

Methods

source property MlFlowAutoCallback.client: MlflowClient | None

source method MlFlowAutoCallback.setup(trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None

Called when fit, validate, test, predict, or tune begins.

source method MlFlowAutoCallback.teardown(trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None

Called when fit, validate, test, predict, or tune ends.

source method MlFlowAutoCallback.on_fit_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when fit begins.

source method MlFlowAutoCallback.on_fit_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when fit ends.

source method MlFlowAutoCallback.on_sanity_check_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the validation sanity check starts.

source method MlFlowAutoCallback.on_sanity_check_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the validation sanity check ends.
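The order in which these hooks fire during a fit run can be sketched without Lightning installed. The block below is a simulation under stated assumptions, not Lightning's actual loop: HookRecorder and simulate_fit are hypothetical names, the call order reflects the usual Lightning fit sequence as generally documented, and validation hooks are left out for brevity.

```python
class HookRecorder:
    """Duck-typed stand-in for a callback that records each hook name it receives."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Only reached for attributes that are not defined normally.
        if name in ("setup", "teardown") or name.startswith("on_"):
            def hook(*args, **kwargs):
                self.calls.append(name)
            return hook
        raise AttributeError(name)


def simulate_fit(callback, epochs=1, batches=2):
    """Call the fit-related hooks in the assumed order (validation omitted)."""
    trainer = pl_module = None  # placeholders; a real run passes live objects
    callback.setup(trainer, pl_module, "fit")
    callback.on_fit_start(trainer, pl_module)
    callback.on_sanity_check_start(trainer, pl_module)
    callback.on_sanity_check_end(trainer, pl_module)
    callback.on_train_start(trainer, pl_module)
    for _ in range(epochs):
        callback.on_train_epoch_start(trainer, pl_module)
        for batch_idx in range(batches):
            callback.on_train_batch_start(trainer, pl_module, None, batch_idx)
            callback.on_train_batch_end(trainer, pl_module, None, None, batch_idx)
        callback.on_train_epoch_end(trainer, pl_module)
    callback.on_train_end(trainer, pl_module)
    callback.on_fit_end(trainer, pl_module)
    callback.teardown(trainer, pl_module, "fit")


recorder = HookRecorder()
simulate_fit(recorder, epochs=1, batches=1)
print(recorder.calls)
```

To check the order against a real run, a callback that logs its hook names can be passed to the Trainer in the same way.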

source method MlFlowAutoCallback.on_train_batch_start(trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int) -> None

Called when the train batch begins.

source method MlFlowAutoCallback.on_train_batch_end(trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None

Called when the train batch ends.

Note

The value outputs["loss"] here is the loss returned from training_step, normalized with respect to accumulate_grad_batches (that is, divided by it).
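The normalization described in the note comes down to a single division. The sketch below shows the arithmetic only; normalized_loss and recover_raw_loss are hypothetical helper names and are not part of lit_mlflow or Lightning.

```python
def normalized_loss(raw_loss: float, accumulate_grad_batches: int) -> float:
    """Loss as seen in outputs["loss"]: the training_step loss divided by the accumulation factor."""
    if accumulate_grad_batches < 1:
        raise ValueError("accumulate_grad_batches must be >= 1")
    return raw_loss / accumulate_grad_batches


def recover_raw_loss(logged_loss: float, accumulate_grad_batches: int) -> float:
    """Undo the normalization to get back the value returned by training_step."""
    return logged_loss * accumulate_grad_batches


# With accumulate_grad_batches=4, a training_step loss of 2.0 appears as 0.5.
print(normalized_loss(2.0, 4))
```

A callback that wants to report the unscaled loss would multiply outputs["loss"] by trainer.accumulate_grad_batches before logging.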

source method MlFlowAutoCallback.on_train_epoch_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the train epoch begins.

source method MlFlowAutoCallback.on_train_epoch_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule (lightning.pytorch.core.LightningModule) and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        # per-step losses, filled by training_step and read by the callback
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss tensor for this batch
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

source method MlFlowAutoCallback.on_validation_epoch_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the validation epoch begins.

source method MlFlowAutoCallback.on_validation_epoch_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the validation epoch ends.

source method MlFlowAutoCallback.on_test_epoch_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the test epoch begins.

source method MlFlowAutoCallback.on_test_epoch_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the test epoch ends.

source method MlFlowAutoCallback.on_predict_epoch_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the predict epoch begins.

source method MlFlowAutoCallback.on_predict_epoch_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the predict epoch ends.

source method MlFlowAutoCallback.on_validation_batch_start(trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None

Called when the validation batch begins.

source method MlFlowAutoCallback.on_validation_batch_end(trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None

Called when the validation batch ends.

source method MlFlowAutoCallback.on_test_batch_start(trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None

Called when the test batch begins.

source method MlFlowAutoCallback.on_test_batch_end(trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None

Called when the test batch ends.

source method MlFlowAutoCallback.on_predict_batch_start(trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None

Called when the predict batch begins.

source method MlFlowAutoCallback.on_predict_batch_end(trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None

Called when the predict batch ends.

source method MlFlowAutoCallback.on_train_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when training begins.

source method MlFlowAutoCallback.on_train_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when training ends.

source method MlFlowAutoCallback.on_validation_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the validation loop begins.

source method MlFlowAutoCallback.on_validation_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the validation loop ends.

source method MlFlowAutoCallback.on_test_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when testing begins.

source method MlFlowAutoCallback.on_test_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when testing ends.

source method MlFlowAutoCallback.on_predict_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when prediction begins.

source method MlFlowAutoCallback.on_predict_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when prediction ends.

source method MlFlowAutoCallback.on_exception(trainer: pl.Trainer, pl_module: pl.LightningModule, exception: BaseException) -> None

Called when any trainer execution is interrupted by an exception.
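One thing a tracking callback could plausibly do in this hook is translate the exception into a terminal run status. The sketch below is illustrative only: run_status_for is a hypothetical helper, the mapping is an assumption, and this page does not say whether MlFlowAutoCallback behaves this way. The strings "FAILED" and "KILLED" are taken from MLflow's run status names.

```python
def run_status_for(exception: BaseException) -> str:
    """Map an exception to an MLflow-style terminal run status (illustrative mapping)."""
    # The hook's parameter is typed BaseException, so KeyboardInterrupt is in scope.
    # Assumption: a user interrupt is treated as a killed run rather than a failure.
    if isinstance(exception, KeyboardInterrupt):
        return "KILLED"
    return "FAILED"


print(run_status_for(KeyboardInterrupt()))
print(run_status_for(RuntimeError("CUDA out of memory")))
```

The returned string could then be handed to whatever call the tracking client uses to close a run.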