IO handling during training

class mlip.training.training_io_handler.TrainingIOHandler(config: TrainingIOHandlerConfig | None = None, data_upload_fun: Callable[[str | PathLike], Future | None] | None = None)

An IO handler class for the training loop.

This class handles checkpointing as well as specialized logging, e.g., to an external logger that a user can provide. If the config's local checkpointing directory is None, then this class will only perform logging, but no checkpointing.

__init__(config: TrainingIOHandlerConfig | None = None, data_upload_fun: Callable[[str | PathLike], Future | None] | None = None) None

Constructor.

Parameters:
  • config – The training IO handler pydantic config. Can be None, in which case the default config is used. Defaults to None.

  • data_upload_fun – A data upload function to remote storage. This is optional and defaults to None. The function should take in a source path; the upload destination can be defined by the user within the function itself. The function may be asynchronous, in which case it should return a Future.
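
A minimal construction sketch; the upload_to_remote helper and the "outputs" directory below are illustrative placeholders, not part of the library:

    from pathlib import Path

    from mlip.training.training_io_handler import (
        TrainingIOHandler,
        TrainingIOHandlerConfig,
    )

    def upload_to_remote(source_path):
        # Hypothetical upload helper: push source_path to remote storage.
        # Returning None marks the upload as synchronous; an asynchronous
        # implementation should return a Future instead.
        print(f"Would upload {source_path} to remote storage here.")
        return None

    config = TrainingIOHandlerConfig(local_model_output_dir=Path("outputs"))
    io_handler = TrainingIOHandler(config=config, data_upload_fun=upload_to_remote)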

attach_logger(logger: Callable[[LogCategory, dict[str, Any], int], None]) None

Attaches one training loop logging function to the IO handler.

The logging function must take in three parameters and should not return anything. The three parameters are a logging category (an enum) describing what type of data is logged, the data dictionary to log, and the current epoch number.

Parameters:

logger – The logging function to add.
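
For example, a custom logging function matching this signature might look as follows; printing stands in for a real experiment tracker, and io_handler is the instance constructed above:

    from typing import Any

    from mlip.training.training_io_handler import LogCategory

    def print_eval_metrics(
        category: LogCategory, to_log: dict[str, Any], epoch_number: int
    ) -> None:
        # Only act on validation metrics; ignore all other categories.
        if category == LogCategory.EVAL_METRICS:
            print(f"[epoch {epoch_number}] validation metrics: {to_log}")

    io_handler.attach_logger(print_eval_metrics)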

log(category: LogCategory, to_log: dict[str, Any], epoch_number: int) None

Logs data via the logging functions stored in this class.

Parameters:
  • category – A logging category (an enum) describing what type of data is logged.

  • to_log – A data dictionary to log (typically, metrics).

  • epoch_number – The current epoch number.
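
A sketch of a direct call; the "loss" key is illustrative, and the training loop normally invokes this method itself:

    io_handler.log(LogCategory.TRAIN_METRICS, {"loss": 0.123}, epoch_number=1)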

save_checkpoint(training_state: TrainingState, epoch_number: int) None

Saves a model checkpoint to disk.

Also uses the data upload function, if one was provided.

Parameters:
  • training_state – The training state to save.

  • epoch_number – The current epoch number.
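
A minimal sketch, assuming training_state is the TrainingState instance held by the training loop (its construction is outside the scope of this class):

    # training_state: the current TrainingState from the training loop.
    io_handler.save_checkpoint(training_state, epoch_number=10)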

save_dataset_info(dataset_info: DatasetInfo) None

Saves the dataset information to disk in JSON format.

Also uploads it via the data upload function, if one was provided.

Parameters:

dataset_info – The dataset information class to save.
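
Analogously, a sketch assuming dataset_info is the DatasetInfo instance produced by the data preprocessing pipeline:

    # dataset_info: the DatasetInfo describing the training data.
    io_handler.save_dataset_info(dataset_info)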

restore_training_state(training_state: TrainingState) TrainingState

Restores a training state from local disk.

Note that to restore from a remote location, the state must first be downloaded outside of this function.

Parameters:

training_state – A TrainingState instance, which will serve as a template for the restoration.

Returns:

The restored training state.
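
A sketch of restoration, again assuming a freshly initialized TrainingState serves as the template:

    # template_state: a freshly initialized TrainingState used as a template.
    restored_state = io_handler.restore_training_state(template_state)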

class mlip.training.training_io_handler.TrainingIOHandlerConfig(*, local_model_output_dir: str | PathLike | None = None, max_checkpoints_to_keep: Annotated[int, Gt(gt=0)] = 5, save_debiased_ema: bool = True, ema_decay: Annotated[float, Gt(gt=0.0), Le(le=1.0)] = 0.99, restore_checkpoint_if_exists: bool = False, epoch_to_restore: Annotated[int, Gt(gt=0)] | None = None, restore_optimizer_state: bool = False, clear_previous_checkpoints: bool = False)

Pydantic config holding all settings relevant for the training IO handler.

local_model_output_dir

Path to the output directory (local filesystem) where the model/dataset information and checkpoints are stored. If None, then local checkpointing will be disabled. Defaults to None.

Type:

str | os.PathLike | None

max_checkpoints_to_keep

Maximum number of old checkpoints to keep. The default is 5.

Type:

int

save_debiased_ema

Whether to also save the debiased exponential moving average (EMA) of the parameters. The default is True.

Type:

bool

ema_decay

The EMA decay rate. The default is 0.99.

Type:

float

restore_checkpoint_if_exists

Whether to restore a previous checkpoint if it exists. By default, this is False.

Type:

bool

epoch_to_restore

The epoch number to restore. The default is None, which means the latest epoch will be restored.

Type:

int | None

restore_optimizer_state

Whether to also restore the optimizer state. Default is False.

Type:

bool

clear_previous_checkpoints

Whether to clear the previous checkpoints if any exist. Note that this setting cannot be set to True if restoring a checkpoint is also selected. The default is False.

Type:

bool

__init__(**data: Any) None

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.
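
A sketch of a config for restoring the latest checkpoint of a previous run; the directory name and chosen values are illustrative:

    from mlip.training.training_io_handler import TrainingIOHandlerConfig

    config = TrainingIOHandlerConfig(
        local_model_output_dir="outputs",
        max_checkpoints_to_keep=3,
        restore_checkpoint_if_exists=True,
        # epoch_to_restore stays None, so the latest epoch is restored.
        restore_optimizer_state=True,
        # clear_previous_checkpoints must remain False when restoring.
    )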

class mlip.training.training_io_handler.LogCategory(value)

Enum class for logging categories.

These values signal to a logging function what type of data is being logged.

BEST_MODEL

Information about the current best model is logged.

TRAIN_METRICS

Metrics for the training set are logged.

EVAL_METRICS

Metrics for the validation set are logged.

TEST_METRICS

Metrics for the test set are logged.

CLEANUP_AFTER_CKPT_RESTORATION

Allows the logger to clean itself up after a checkpoint has been restored.
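
A logging function can branch on these values. A minimal sketch, using a hypothetical in-memory buffer that is discarded after a checkpoint restoration:

    from typing import Any

    from mlip.training.training_io_handler import LogCategory

    metrics_buffer: list[dict[str, Any]] = []  # hypothetical in-memory store

    def buffering_logger(
        category: LogCategory, to_log: dict[str, Any], epoch_number: int
    ) -> None:
        if category == LogCategory.CLEANUP_AFTER_CKPT_RESTORATION:
            # Drop anything buffered for epochs beyond the restored checkpoint.
            metrics_buffer.clear()
        elif category in (LogCategory.TRAIN_METRICS, LogCategory.EVAL_METRICS):
            metrics_buffer.append({"epoch": epoch_number, **to_log})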

mlip.training.training_loggers.log_metrics_to_table(category: LogCategory, to_log: dict[str, Any], epoch_number: int) None

Logging function for the training loop which logs the metrics as a formatted table.

The table is printed to the command line.

Parameters:
  • category – The logging category describing what type of data is currently logged.

  • to_log – The data to log (typically, the metrics).

  • epoch_number – The current epoch number.

mlip.training.training_loggers.log_metrics_to_line(category: LogCategory, to_log: dict[str, Any], epoch_number: int) None

Logging function for the training loop which logs the metrics to a single line.

Parameters:
  • category – The logging category describing what type of data is currently logged.

  • to_log – The data to log (typically, the metrics).

  • epoch_number – The current epoch number.
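
Both built-in loggers match the signature expected by attach_logger and can be attached directly; io_handler refers to a TrainingIOHandler instance as constructed earlier:

    from mlip.training.training_loggers import (
        log_metrics_to_line,
        log_metrics_to_table,
    )

    io_handler.attach_logger(log_metrics_to_table)  # formatted table output
    io_handler.attach_logger(log_metrics_to_line)   # compact single-line output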