IO handling during training

class mlip.training.training_io_handler.TrainingIOHandler(config: TrainingIOHandlerConfig | None = None, data_upload_fun: Callable[[str | PathLike], Future | None] | None = None)

An IO handler class for the training loop.

It handles checkpointing as well as specialized logging, for example to an external logger provided by the user. Checkpointing is delegated to separate OrbaxCheckpointer instances for post-epoch and intra-epoch saves.

__init__(config: TrainingIOHandlerConfig | None = None, data_upload_fun: Callable[[str | PathLike], Future | None] | None = None) → None

Constructor.

Parameters:
  • config – The training IO handler pydantic config. If None, the default config is used. Default is None.

  • data_upload_fun – An optional function that uploads data to remote storage. Default is None. The function takes a single source path; the upload destination is defined by the user inside the function. If the upload runs asynchronously, the function should return a Future; otherwise it returns None.
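A minimal sketch of an asynchronous upload function matching the documented `Callable[[str | PathLike], Future | None]` signature. The "remote storage" here is assumed to be another directory (for example a mounted bucket), and the names `make_async_uploader` and `_copy_to_remote` are hypothetical helpers, not part of mlip.

```python
from __future__ import annotations

import shutil
from concurrent.futures import Future, ThreadPoolExecutor
from os import PathLike
from pathlib import Path
from typing import Callable

# A small shared pool so uploads run in the background while training continues.
_UPLOAD_POOL = ThreadPoolExecutor(max_workers=2, thread_name_prefix="upload")


def _copy_to_remote(source: str | PathLike, remote_root: Path) -> Path:
    """Copies a file or directory into remote_root, keeping its base name."""
    src = Path(source)
    dst = remote_root / src.name
    remote_root.mkdir(parents=True, exist_ok=True)
    if src.is_dir():
        shutil.copytree(src, dst, dirs_exist_ok=True)
    else:
        shutil.copy2(src, dst)
    return dst


def make_async_uploader(
    remote_root: str | PathLike,
) -> Callable[[str | PathLike], Future]:
    """Builds a function with the data_upload_fun signature (hypothetical helper).

    The upload destination is fixed inside the returned function, as the
    documentation suggests; swap _copy_to_remote for a real storage client.
    """
    root = Path(remote_root)

    def upload(source: str | PathLike) -> Future:
        # Returning a Future signals that the upload runs asynchronously.
        return _UPLOAD_POOL.submit(_copy_to_remote, source, root)

    return upload
```

It could then be passed as `TrainingIOHandler(config, data_upload_fun=make_async_uploader("/mnt/bucket/run_01"))`, where the mount path is a placeholder. A synchronous variant would simply do the copy and return None.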

attach_logger(logger: Callable[[LogCategory, dict[str, Any], int], None]) → None

Attaches one training loop logging function to the IO handler.

The logging function must take three parameters and should not return anything. The three parameters are a logging category (a LogCategory enum value) describing what type of data is logged, the data dictionary to log, and the current epoch number.

Parameters:

logger – The logging function to add.
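A minimal sketch of a user-supplied external logger with the required three-parameter signature, here appending one JSON record per call to a file. The `LogCategory` class below is a hypothetical stand-in containing only the members needed here (its values are made up); with mlip installed, the real enum lives in mlip.training.training_io_handler. `make_jsonl_logger` is a hypothetical helper name.

```python
from __future__ import annotations

import json
from enum import Enum
from os import PathLike
from pathlib import Path
from typing import Any, Callable


class LogCategory(Enum):
    """Hypothetical stand-in for mlip's LogCategory (member values are made up)."""

    TRAIN_METRICS = "train_metrics"
    EVAL_METRICS = "eval_metrics"


def make_jsonl_logger(
    path: str | PathLike,
) -> Callable[[LogCategory, dict[str, Any], int], None]:
    """Returns a logger that appends one JSON line per call to `path`."""
    log_path = Path(path)

    def logger(category: LogCategory, to_log: dict[str, Any], epoch_number: int) -> None:
        # Metrics are nested under "data" so they cannot clash with the header keys.
        record = {"category": category.name, "epoch": epoch_number, "data": to_log}
        with log_path.open("a", encoding="utf-8") as f:
            # default=float converts scalar types json cannot handle (e.g. NumPy float32).
            f.write(json.dumps(record, default=float) + "\n")

    return logger
```

Usage would be along the lines of `io_handler.attach_logger(make_jsonl_logger("train_log.jsonl"))`. Since the logger only relies on `category.name`, it works with any Enum.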

log(category: LogCategory, to_log: dict[str, Any], epoch_number: int) → None

Logs data via all logging functions attached to this handler.

Parameters:
  • category – A logging category (a LogCategory enum value) describing what type of data is logged.

  • to_log – A data dictionary to log (typically, metrics).

  • epoch_number – The current epoch number.
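The attach_logger/log contract amounts to a fan-out: every attached logger receives the same category, data dictionary, and epoch number. The sketch below illustrates that contract with a hypothetical `MiniIOHandler` stand-in and a two-member `LogCategory` stand-in; it is not mlip's implementation, and calling loggers in attachment order is an assumption.

```python
from __future__ import annotations

from enum import Enum
from typing import Any, Callable


class LogCategory(Enum):
    """Hypothetical stand-in with only the members used here."""

    TRAIN_METRICS = 1
    EVAL_METRICS = 2


Logger = Callable[[LogCategory, dict[str, Any], int], None]


class MiniIOHandler:
    """Hypothetical stand-in showing only the logger fan-out behavior."""

    def __init__(self) -> None:
        self._loggers: list[Logger] = []

    def attach_logger(self, logger: Logger) -> None:
        self._loggers.append(logger)

    def log(self, category: LogCategory, to_log: dict[str, Any], epoch_number: int) -> None:
        # Every attached logger sees exactly the same arguments.
        for logger in self._loggers:
            logger(category, to_log, epoch_number)
```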

save_checkpoint(training_state: TrainingState, epoch_number: int, dataset_state: GraphDatasetState | None = None, metrics: dict[str, float] | None = None) → None

Saves a model checkpoint using the post-epoch checkpointer.

Parameters:
  • training_state – The training state to save.

  • epoch_number – The current epoch number.

  • dataset_state – Iterator state of the training dataset, persisted as a sibling item so it can be restored separately. In multi-host mode the caller must pass a globally-addressable pytree here.

  • metrics – Scalar metrics used for best-checkpoint ranking.
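To illustrate what best-checkpoint ranking with a scalar metric means, here is a pure-Python sketch that keeps the `max_to_keep` best checkpoints by one metric. Orbax's CheckpointManagerOptions exposes comparable `best_fn` and `best_mode` options, but which metric mlip ranks on, and whether it lower-is-better, are not stated here; the metric key `"loss"` and the helper `checkpoints_to_keep` are hypothetical.

```python
from __future__ import annotations


def checkpoints_to_keep(
    metrics_by_epoch: dict[int, dict[str, float]],
    metric_key: str,
    max_to_keep: int,
    mode: str = "min",
) -> list[int]:
    """Returns the epochs whose checkpoints survive best-metric pruning."""
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    ranked = sorted(
        metrics_by_epoch,
        key=lambda epoch: metrics_by_epoch[epoch][metric_key],
        reverse=(mode == "max"),
    )
    # Keep the best `max_to_keep` entries, reported in chronological order.
    return sorted(ranked[:max_to_keep])
```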

save_dataset_info(dataset_info: DatasetInfo) → None

Saves the dataset information object to disk in JSON format.

If a data upload function was provided, the saved file is also uploaded with it.

Parameters:

dataset_info – The dataset information object to save.
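A minimal sketch of the save-then-upload flow, assuming a dataclass stand-in for DatasetInfo. The fields of `DatasetInfoStandIn`, the helper `save_dataset_info_sketch`, and the filename `dataset_info.json` are all hypothetical. The early return mirrors the documented rule that dataset-info saving is skipped when checkpoint_dir is None.

```python
from __future__ import annotations

import dataclasses
import json
from concurrent.futures import Future
from os import PathLike
from pathlib import Path
from typing import Callable, Optional


@dataclasses.dataclass
class DatasetInfoStandIn:
    """Hypothetical stand-in; the fields are illustrative, not mlip's DatasetInfo."""

    cutoff_distance_angstrom: float
    atomic_numbers: list[int]


def save_dataset_info_sketch(
    info: DatasetInfoStandIn,
    checkpoint_dir: str | PathLike | None,
    data_upload_fun: Optional[Callable[[Path], Optional[Future]]] = None,
    filename: str = "dataset_info.json",
) -> Path | None:
    # No checkpoint directory configured: nothing is saved.
    if checkpoint_dir is None:
        return None
    path = Path(checkpoint_dir) / filename
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(dataclasses.asdict(info), indent=2), encoding="utf-8")
    if data_upload_fun is not None:
        data_upload_fun(path)  # may return a Future if the upload is asynchronous
    return path
```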

restore_checkpoint(training_state: TrainingState, dataset_state: GraphDatasetState | None = None, mesh: Mesh | None = None) → tuple[TrainingState, GraphDatasetState | None, int, list]

Restores a training state from a local checkpoint on disk.

When intra-epoch checkpointing is enabled, both the post-epoch and intra-epoch checkpoint managers are inspected, and the checkpoint with the higher num_steps (read from the Orbax custom metadata) is restored. After restoring, any intra-epoch checkpoints are cleared so that the next run starts fresh.

Parameters:
  • training_state – An instance of training state, which will serve as a template for the restoration.

  • dataset_state – Optional dataset state template. When provided, the dataset state is restored from the checkpoint and returned alongside the training state.

  • mesh – Optional JAX Mesh for hardware-agnostic restoration.

Returns:

A tuple of (restored_training_state, restored_dataset_state, last_completed_epoch, metrics), where:
  • restored_dataset_state is None when no dataset state template is provided or no checkpoint exists.

  • last_completed_epoch is always read from the checkpoint metadata. For post-epoch checkpoints it is the epoch that was completed; for intra-epoch checkpoints it is epoch_number - 1, the last fully completed epoch.

  • metrics is the list of per-step training metrics accumulated before the intra-epoch checkpoint was taken. It is empty for post-epoch checkpoints and fresh starts.
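The selection rule and the epoch bookkeeping described above can be sketched as a small pure function. `CkptMeta` is a hypothetical summary of the metadata stored with each checkpoint, and preferring the post-epoch checkpoint when both have the same num_steps is an assumption (the documentation only says the higher num_steps wins). Returning None when neither checkpoint exists stands in for a fresh start.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class CkptMeta:
    """Hypothetical summary of a checkpoint's custom metadata."""

    num_steps: int
    epoch_number: int
    metrics: list = field(default_factory=list)


def choose_restore_source(
    post_epoch: CkptMeta | None, intra_epoch: CkptMeta | None
) -> tuple[str, int, list] | None:
    """Returns (source, last_completed_epoch, metrics), or None for a fresh start."""
    if post_epoch is None and intra_epoch is None:
        return None
    # The checkpoint with more training steps wins; ties go to post-epoch (assumption).
    if intra_epoch is not None and (post_epoch is None or intra_epoch.num_steps > post_epoch.num_steps):
        # An intra-epoch checkpoint belongs to an unfinished epoch.
        return "intra_epoch", intra_epoch.epoch_number - 1, list(intra_epoch.metrics)
    # Post-epoch checkpoints store the completed epoch and carry no step metrics.
    return "post_epoch", post_epoch.epoch_number, []
```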

class mlip.training.training_io_handler.TrainingIOHandlerConfig(*, checkpoint_dir: str | PathLike | None = None, restore_dir: str | PathLike | None = None, max_to_keep: Annotated[int, Gt(gt=0)] = 5, save_debiased_ema: bool = True, ema_decay: Annotated[float, Gt(gt=0.0), Le(le=1.0)] = 0.99, use_single_host_patch: bool = False, enable_async_checkpointing: bool = False, use_intra_epoch_checkpointing: bool = False, intra_epoch_save_every_n_steps: Annotated[int, Gt(gt=0)] | None = 100, intra_epoch_max_to_keep: Annotated[int, Gt(gt=0)] | None = 5, restore_checkpoint_if_exists: bool = False, epoch_to_restore: Annotated[int, Gt(gt=0)] | None = None, restore_optimizer_state: bool = False, clear_previous_checkpoints: bool = False)

Pydantic config holding all settings relevant for the training IO handler.

checkpoint_dir

Root checkpoint directory. Supports local paths as well as other path types supported by Orbax. When None, dataset-info saving and the clear_previous_checkpoints guard are skipped. Defaults to None.

Type:

str | os.PathLike | None

restore_dir

Directory to restore checkpoints from. If None, defaults to checkpoint_dir.

Type:

str | os.PathLike | None

max_to_keep

Maximum number of post-epoch checkpoints to retain. The default is 5.

Type:

int

save_debiased_ema

Whether to also save the debiased EMA parameters. The default is True.

Type:

bool

ema_decay

The EMA decay rate. The default is 0.99.

Type:

float
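For intuition about ema_decay and the "debiased" EMA, here is a sketch of the standard Adam-style bias correction: the running average starts at zero, so early values are divided by 1 - decay**t. Whether mlip computes its debiased EMA exactly this way is an assumption, and `debiased_ema` is a hypothetical helper. The sketch requires decay < 1, because the correction factor is zero at decay = 1.

```python
from __future__ import annotations


def debiased_ema(values: list[float], decay: float = 0.99) -> list[float]:
    """Returns the bias-corrected exponential moving average after each step."""
    if not 0.0 < decay < 1.0:
        raise ValueError("this sketch requires 0 < decay < 1")
    ema = 0.0  # zero initialization is what introduces the bias
    out = []
    for t, x in enumerate(values, start=1):
        ema = decay * ema + (1.0 - decay) * x
        # Dividing by (1 - decay**t) removes the pull toward the zero initial value.
        out.append(ema / (1.0 - decay**t))
    return out
```

A quick sanity check: for a constant input, the debiased EMA equals that constant from the first step onward, while the raw EMA would start near zero.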

use_intra_epoch_checkpointing

Whether to also use intra-epoch checkpointing. The default is False.

Type:

bool

intra_epoch_save_every_n_steps

Save an intra-epoch checkpoint every N training steps. None disables intra-epoch checkpointing.

Type:

int | None

intra_epoch_max_to_keep

Maximum number of intra-epoch checkpoints to retain.

Type:

int | None
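How use_intra_epoch_checkpointing, intra_epoch_save_every_n_steps, and intra_epoch_max_to_keep interact can be simulated in a few lines. This is a sketch under assumptions: a save is taken whenever the step count is a multiple of N, None for the retention limit means "keep all", and the oldest checkpoints are dropped first. `simulate_intra_epoch_saves` is a hypothetical helper, not mlip code.

```python
from __future__ import annotations


def simulate_intra_epoch_saves(
    total_steps: int,
    enabled: bool,
    save_every_n_steps: int | None,
    max_to_keep: int | None,
) -> tuple[list[int], list[int]]:
    """Returns (steps that trigger a save, steps still retained afterwards)."""
    # Either switch turns intra-epoch checkpointing off.
    if not enabled or save_every_n_steps is None:
        return [], []
    saved = [s for s in range(1, total_steps + 1) if s % save_every_n_steps == 0]
    # Assumption: None means no limit; otherwise only the newest N are retained.
    retained = saved if max_to_keep is None else saved[-max_to_keep:]
    return saved, retained
```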

restore_checkpoint_if_exists

Whether to restore a previous checkpoint if it exists. By default, this is False.

Type:

bool

epoch_to_restore

The epoch number to restore. The default is None, which means the latest epoch will be restored.

Type:

int | None

restore_optimizer_state

Whether to also restore the optimizer state. Default is False.

Type:

bool

clear_previous_checkpoints

Whether to clear previous checkpoints if any exist. This setting cannot be set to True when a checkpoint is to be restored. The default is False.

Type:

bool

use_single_host_patch

Whether to use single-host (process-0-only) checkpointing. When True, Orbax writes without multi-host coordination and async checkpointing is disabled. Defaults to False.

Type:

bool

enable_async_checkpointing

Whether Orbax should write checkpoints asynchronously. Defaults to False. Incompatible with use_single_host_patch; see CheckpointerConfig.

Type:

bool
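The cross-field rules stated above can be captured in a small validated stand-in. `IOConfigSketch` is a hypothetical dataclass, not the pydantic TrainingIOHandlerConfig. It assumes that "selecting to restore a checkpoint" means restore_checkpoint_if_exists=True, and it models single-host mode as silently disabling async writes (per the use_single_host_patch description); the real config may warn or raise instead, see CheckpointerConfig.

```python
from __future__ import annotations

from dataclasses import dataclass
from os import PathLike


@dataclass(frozen=True)
class IOConfigSketch:
    """Hypothetical stand-in for a subset of TrainingIOHandlerConfig."""

    checkpoint_dir: str | PathLike | None = None
    restore_dir: str | PathLike | None = None
    restore_checkpoint_if_exists: bool = False
    clear_previous_checkpoints: bool = False
    use_single_host_patch: bool = False
    enable_async_checkpointing: bool = False

    def __post_init__(self) -> None:
        # Clearing old checkpoints would destroy the one we are asked to restore.
        if self.clear_previous_checkpoints and self.restore_checkpoint_if_exists:
            raise ValueError(
                "clear_previous_checkpoints cannot be True when restoring a checkpoint"
            )

    @property
    def effective_restore_dir(self) -> str | PathLike | None:
        # restore_dir falls back to checkpoint_dir when unset.
        return self.restore_dir if self.restore_dir is not None else self.checkpoint_dir

    @property
    def effective_async(self) -> bool:
        # Single-host (process-0-only) mode writes synchronously.
        return self.enable_async_checkpointing and not self.use_single_host_patch
```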

__init__(**data: Any) → None

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.

class mlip.training.training_io_handler.LogCategory(*values)

Enum class for logging categories.

These values signal to a logging function what type of data is being logged.

BEST_MODEL

Information about the current best model is logged.

TRAIN_METRICS

Metrics for the training set are logged.

EVAL_METRICS

Metrics for the validation set are logged.

TEST_METRICS

Metrics for the test set are logged.

SYSTEM_METRICS

Per-process system metrics (runtime, throughput) are logged.

CLEANUP_AFTER_CKPT_RESTORATION

Allows the logger to clean itself up after a checkpoint has been restored.
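A logger can branch on the category it receives. The sketch below prefixes metric keys by category and keeps an in-memory history; on CLEANUP_AFTER_CKPT_RESTORATION it drops entries newer than the given epoch. That cleanup behavior, and the assumption that epoch_number carries the last completed epoch at that point, are illustrative guesses, not documented semantics. The `LogCategory` class is a hypothetical stand-in that mirrors the member names above with made-up values, and `PrefixingLogger` is a hypothetical name.

```python
from __future__ import annotations

from enum import Enum, auto
from typing import Any


class LogCategory(Enum):
    """Hypothetical stand-in mirroring the documented member names."""

    BEST_MODEL = auto()
    TRAIN_METRICS = auto()
    EVAL_METRICS = auto()
    TEST_METRICS = auto()
    SYSTEM_METRICS = auto()
    CLEANUP_AFTER_CKPT_RESTORATION = auto()


class PrefixingLogger:
    """Callable logger that stores (epoch, prefixed metrics) pairs in memory."""

    _PREFIX = {
        LogCategory.BEST_MODEL: "best",
        LogCategory.TRAIN_METRICS: "train",
        LogCategory.EVAL_METRICS: "eval",
        LogCategory.TEST_METRICS: "test",
        LogCategory.SYSTEM_METRICS: "system",
    }

    def __init__(self) -> None:
        self.history: list[tuple[int, dict[str, Any]]] = []

    def __call__(self, category: LogCategory, to_log: dict[str, Any], epoch_number: int) -> None:
        if category is LogCategory.CLEANUP_AFTER_CKPT_RESTORATION:
            # Assumed semantics: forget epochs that will be re-run after the restore.
            self.history = [(e, m) for e, m in self.history if e <= epoch_number]
            return
        prefix = self._PREFIX[category]
        self.history.append((epoch_number, {f"{prefix}/{k}": v for k, v in to_log.items()}))
```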

mlip.training.training_loggers.log_metrics_to_table(category: LogCategory, to_log: dict[str, Any], epoch_number: int) → None

Logging function for the training loop which logs the metrics as a formatted table.

The table is printed to the command line.

This function also converts MSE metrics to RMSE before logging them.

Parameters:
  • category – The logging category describing what type of data is currently logged.

  • to_log – The data to log (typically, the metrics).

  • epoch_number – The current epoch number.

mlip.training.training_loggers.log_metrics_to_line(category: LogCategory, to_log: dict[str, Any], epoch_number: int) → None

Logging function for the training loop which logs the metrics to a single line.

This function also converts MSE metrics to RMSE before logging them.

Parameters:
  • category – The logging category describing what type of data is currently logged.

  • to_log – The data to log (typically, the metrics).

  • epoch_number – The current epoch number.

mlip.training.training_loggers.convert_mse_to_rmse_in_logs(to_log: dict[str, Any]) → dict[str, Any]

Helper function that converts all MSE values in a metrics dictionary to RMSE values. A correct RMSE requires taking the square root at the very end, after all averaging has happened. The logged metrics dictionaries therefore contain only MSE values, which this function converts to RMSE just before logging.

Parameters:

to_log – The metrics dictionary.

Returns:

The metrics dictionary with any MSE entries converted to RMSE.
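Why the square root must come last can be checked with two equally sized batches of errors: averaging per-batch RMSEs underestimates the true RMSE, while averaging MSEs first and taking one square root recovers it. The conversion helper `mse_to_rmse` assumes a key convention (any key containing "mse" is renamed with "rmse"); the real naming used by convert_mse_to_rmse_in_logs is not stated here.

```python
from __future__ import annotations

import math
from typing import Any

# Two batches of per-sample errors with equal batch sizes.
batch_errors = [[1.0, 1.0], [3.0, 3.0]]
batch_mse = [sum(e * e for e in b) / len(b) for b in batch_errors]  # [1.0, 9.0]

# Correct: average the MSEs, then take a single square root.
correct_rmse = math.sqrt(sum(batch_mse) / len(batch_mse))  # sqrt(5) ≈ 2.236
# Wrong: take square roots per batch, then average.
naive_rmse = sum(math.sqrt(m) for m in batch_mse) / len(batch_mse)  # 2.0


def mse_to_rmse(to_log: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: renames '*mse*' keys to '*rmse*' and takes the root."""
    out: dict[str, Any] = {}
    for key, value in to_log.items():
        if "mse" in key and "rmse" not in key:
            out[key.replace("mse", "rmse")] = math.sqrt(value)
        else:
            out[key] = value
    return out
```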