
lit_mlflow

package lit_mlflow

class MlFlowAutoCallback(patch_device_monitor: bool = True)

Bases : Callback

Attributes

  • state_key : str Identifier for the state of the callback.
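As a rough illustration of what state_key is for: Lightning stores each callback's non-empty state_dict() in the checkpoint under checkpoint["callbacks"][state_key], and its base Callback uses the class's qualified name as the default key. The plain-Python sketch below models that bookkeeping. CallbackSketch, AutoCallbackStandIn and build_checkpoint are hypothetical names, and the sketch assumes MlFlowAutoCallback keeps the inherited default instead of overriding state_key.

```python
from typing import Any


class CallbackSketch:
    """Hypothetical stand-in for a Lightning-style Callback base class."""

    @property
    def state_key(self) -> str:
        # Default key: the class's qualified name.
        return type(self).__qualname__

    def state_dict(self) -> dict[str, Any]:
        # Stateless by default.
        return {}


class AutoCallbackStandIn(CallbackSketch):
    """Hypothetical callback carrying some state worth checkpointing."""

    def __init__(self, patch_device_monitor: bool = True) -> None:
        self.patch_device_monitor = patch_device_monitor

    def state_dict(self) -> dict[str, Any]:
        return {"patch_device_monitor": self.patch_device_monitor}


def build_checkpoint(callbacks: list[CallbackSketch]) -> dict[str, Any]:
    # Store each non-empty callback state under that callback's state_key.
    states: dict[str, Any] = {}
    for cb in callbacks:
        state = cb.state_dict()
        if state:
            states[cb.state_key] = state
    return {"callbacks": states}


ckpt = build_checkpoint([AutoCallbackStandIn(), CallbackSketch()])
print(ckpt)  # {'callbacks': {'AutoCallbackStandIn': {'patch_device_monitor': True}}}
```

Two stateful instances of the same class would share a key and overwrite each other, which is why Lightning asks callbacks that may be registered more than once to supply a unique state_key.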

Methods

property MlFlowAutoCallback.client: MlflowClient | None

method MlFlowAutoCallback.setup(trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None

Called when fit, validate, test, predict, or tune begins.

method MlFlowAutoCallback.teardown(trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None

Called when fit, validate, test, predict, or tune ends.

method MlFlowAutoCallback.on_fit_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when fit begins.

method MlFlowAutoCallback.on_fit_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when fit ends.

method MlFlowAutoCallback.on_sanity_check_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the validation sanity check starts.

method MlFlowAutoCallback.on_sanity_check_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the validation sanity check ends.

method MlFlowAutoCallback.on_train_batch_start(trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int) -> None

Called when the train batch begins.

method MlFlowAutoCallback.on_train_batch_end(trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None

Called when the train batch ends.

Note

Here, outputs["loss"] is the loss returned from training_step, normalized with respect to accumulate_grad_batches (that is, divided by it).
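To make the normalization concrete: with accumulate_grad_batches = k, each reported outputs["loss"] is the training_step loss divided by k, so summing the reported values over the k accumulated batches gives their mean loss. The plain-Python sketch below shows that arithmetic and how a callback could recover the raw value by multiplying back by k. normalized_loss and raw_loss are hypothetical helpers, not part of lit_mlflow or Lightning, and the loss values are made up for illustration.

```python
def normalized_loss(step_loss: float, accumulate_grad_batches: int) -> float:
    # The value described in the Note: the training_step loss scaled by 1 / k.
    return step_loss / accumulate_grad_batches


def raw_loss(outputs: dict[str, float], accumulate_grad_batches: int) -> float:
    # Undo the scaling to get back the value training_step returned.
    return outputs["loss"] * accumulate_grad_batches


k = 4
step_losses = [0.8, 0.4, 1.2, 0.4]  # illustrative losses returned by training_step
outputs_per_batch = [{"loss": normalized_loss(x, k)} for x in step_losses]

# Summing the k normalized values yields the mean of the raw losses.
accumulated = sum(o["loss"] for o in outputs_per_batch)
recovered = [raw_loss(o, k) for o in outputs_per_batch]
print(accumulated, recovered)
```

Scaling by 1 / k is what makes gradients accumulated over k batches match the gradient of their mean loss, so a callback that logs outputs["loss"] directly reports a value k times smaller than the loss training_step computed.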

method MlFlowAutoCallback.on_train_epoch_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the train epoch begins.

method MlFlowAutoCallback.on_train_epoch_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of your LightningModule (lightning.pytorch.core.LightningModule) and access them in this hook:

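import torch
import lightning as L


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)
        # cache each step's loss so it can be aggregated at epoch end
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).pow(2).mean()  # placeholder loss
        # detach so the cached values do not keep the autograd graph alive
        self.training_step_outputs.append(loss.detach())
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()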

method MlFlowAutoCallback.on_validation_epoch_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the val epoch begins.

method MlFlowAutoCallback.on_validation_epoch_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the val epoch ends.

method MlFlowAutoCallback.on_test_epoch_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the test epoch begins.

method MlFlowAutoCallback.on_test_epoch_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the test epoch ends.

method MlFlowAutoCallback.on_predict_epoch_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the predict epoch begins.

method MlFlowAutoCallback.on_predict_epoch_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the predict epoch ends.

method MlFlowAutoCallback.on_validation_batch_start(trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None

Called when the validation batch begins.

method MlFlowAutoCallback.on_validation_batch_end(trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None

Called when the validation batch ends.

method MlFlowAutoCallback.on_test_batch_start(trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None

Called when the test batch begins.

method MlFlowAutoCallback.on_test_batch_end(trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None

Called when the test batch ends.

method MlFlowAutoCallback.on_predict_batch_start(trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None

Called when the predict batch begins.

method MlFlowAutoCallback.on_predict_batch_end(trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None

Called when the predict batch ends.

method MlFlowAutoCallback.on_train_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when training begins.

method MlFlowAutoCallback.on_train_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when training ends.

method MlFlowAutoCallback.on_validation_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the validation loop begins.

method MlFlowAutoCallback.on_validation_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when the validation loop ends.

method MlFlowAutoCallback.on_test_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when testing begins.

method MlFlowAutoCallback.on_test_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when testing ends.

method MlFlowAutoCallback.on_predict_start(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when prediction begins.

method MlFlowAutoCallback.on_predict_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None

Called when prediction ends.

method MlFlowAutoCallback.on_exception(trainer: pl.Trainer, pl_module: pl.LightningModule, exception: BaseException) -> None

Called when any trainer execution is interrupted by an exception.

class DbxMLFlowLogger(run_name: str | None = None, tracking_uri: str | None = mlflow.get_tracking_uri(), tags: dict[str, Any] | None = None, save_dir: str | None = './mlruns', log_model: Literal[True, False, 'all'] = False, prefix: str = '', artifact_location: str | None = None, run_id: str | None = None)

Bases : MLFlowLogger
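The tracking_uri default is written as a call, mlflow.get_tracking_uri(). If it appears in the source exactly as shown, Python evaluates it once, when lit_mlflow is imported, rather than each time a logger is created, so calling mlflow.set_tracking_uri(...) afterwards would not change the default value. The sketch below demonstrates this general Python rule using hypothetical stand-ins (get_tracking_uri, set_tracking_uri, LoggerSketch) instead of the real mlflow functions. Passing tracking_uri explicitly avoids depending on import order.

```python
from typing import Optional

# Hypothetical stand-in for mlflow's process-wide tracking URI.
_current_uri = "file:./mlruns"


def get_tracking_uri() -> str:
    return _current_uri


def set_tracking_uri(uri: str) -> None:
    global _current_uri
    _current_uri = uri


class LoggerSketch:
    # The default below is evaluated once, when this class body executes.
    def __init__(self, tracking_uri: Optional[str] = get_tracking_uri()) -> None:
        self.tracking_uri = tracking_uri


set_tracking_uri("databricks")  # changed after the class was defined

print(LoggerSketch().tracking_uri)  # file:./mlruns (stale default)
print(LoggerSketch(tracking_uri=get_tracking_uri()).tracking_uri)  # databricks
```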

Attributes

  • name : Optional[str] The experiment ID, which this logger reports as its name.

  • version : Optional[str] The run ID, which this logger reports as its version.

  • root_dir : Optional[str] Return the root directory where all versions of an experiment get saved, or None if the logger does not save data locally.

  • log_dir : Optional[str] Return the directory where the current version of the experiment is saved, or None if the logger does not save data locally.

  • group_separator : str Return the default separator used by the logger to group the data into subfolders.

  • save_dir : Optional[str] The root file directory in which MLflow experiments are saved.

  • experiment_id : Optional[str] The experiment ID. Accessing it creates the experiment first if it does not already exist.
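Several of these attributes return None when nothing is stored locally. In Lightning's MLFlowLogger, which DbxMLFlowLogger subclasses, the local directory is, as far as we know, derived from the tracking URI: with no URI the logger falls back to a file: store under save_dir, a file: URI yields the path after the prefix, and anything else (for example a Databricks workspace or an HTTP tracking server) yields None. The sketch below reimplements that rule in plain Python for illustration only. local_save_dir is a hypothetical helper, and DbxMLFlowLogger may override this behavior.

```python
from typing import Optional

LOCAL_FILE_URI_PREFIX = "file:"


def local_save_dir(tracking_uri: Optional[str], save_dir: Optional[str] = "./mlruns") -> Optional[str]:
    # With no tracking URI, fall back to a local file store under save_dir.
    if not tracking_uri:
        tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"
    # Only file: URIs correspond to a local directory.
    if tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
        return tracking_uri[len(LOCAL_FILE_URI_PREFIX):]
    return None


print(local_save_dir("file:/tmp/mlruns"))  # /tmp/mlruns
print(local_save_dir("databricks"))        # None
print(local_save_dir(None))                # ./mlruns
```

Under this rule, a logger pointed at a Databricks workspace reports save_dir as None even though its runs are still recorded remotely.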

property DbxMLFlowLogger.experiment: MlflowClient

property DbxMLFlowLogger.run_id: str | None