IO handling during training¶
- class mlip.training.training_io_handler.TrainingIOHandler(config: TrainingIOHandlerConfig | None = None, data_upload_fun: Callable[[str | PathLike], Future | None] | None = None)¶
An IO handler class for the training loop.
This handles checkpointing as well as specialized logging, e.g., to an external logger that a user can provide. Checkpointing is delegated to OrbaxCheckpointer instances for post-epoch and intra-epoch saves.
- __init__(config: TrainingIOHandlerConfig | None = None, data_upload_fun: Callable[[str | PathLike], Future | None] | None = None) None¶
Constructor.
- Parameters:
config – The training IO handler pydantic config. Can be None, in which case the default config will be used. Default is None.
data_upload_fun – A function that uploads data to remote storage. This is optional and defaults to None. The function should take only a source path; the upload location can be defined by the user within the function. The function can be asynchronous, in which case it should return a Future.
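As an illustration of the data_upload_fun contract described above, the following is a minimal sketch of an asynchronous upload function that takes only a source path and returns a Future. The names upload_async, UPLOAD_ROOT and _copy are hypothetical, and a local temporary directory stands in for remote storage; a real function would call a storage client instead.

```python
import shutil
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor
from os import PathLike
from pathlib import Path
from typing import Union

# Hypothetical destination: a local temp directory standing in for remote storage.
UPLOAD_ROOT = Path(tempfile.mkdtemp(prefix="fake_remote_"))
_executor = ThreadPoolExecutor(max_workers=1)


def _copy(source: Path) -> None:
    """Copy a file or a directory tree into UPLOAD_ROOT (stand-in for an upload)."""
    target = UPLOAD_ROOT / source.name
    if source.is_dir():
        shutil.copytree(source, target, dirs_exist_ok=True)
    else:
        shutil.copy2(source, target)


def upload_async(source: Union[str, PathLike]) -> Future:
    """Take only the source path; the destination is decided inside this function."""
    return _executor.submit(_copy, Path(source))
```

A function of this shape could then be passed as data_upload_fun=upload_async when constructing the TrainingIOHandler. A synchronous variant would perform the copy directly and return None.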
- attach_logger(logger: Callable[[LogCategory, dict[str, Any], int], None]) None¶
Attaches one training loop logging function to the IO handler.
The logging function must take three parameters and should not return anything. The three parameters are a logging category describing what type of data is logged (an enum), the data dictionary to log, and the current epoch number.
- Parameters:
logger – The logging function to add.
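A custom logging function with the three-parameter signature described above might look like the following sketch. To keep the example self-contained, LogCategory is redefined here as a stand-in enum with arbitrary member values; in real use it would be imported from mlip.training.training_io_handler. The names history and log_eval_to_history are hypothetical.

```python
from enum import Enum
from typing import Any


class LogCategory(Enum):
    """Stand-in for mlip's LogCategory; the member values here are arbitrary."""

    TRAIN_METRICS = 0
    EVAL_METRICS = 1


# Hypothetical in-memory sink; a real logger might write to a file or a tracker.
history: list = []


def log_eval_to_history(
    category: LogCategory, to_log: dict[str, Any], epoch_number: int
) -> None:
    """Record validation metrics only and ignore every other category."""
    if category is not LogCategory.EVAL_METRICS:
        return
    history.append({"epoch": epoch_number, **to_log})
```

Given a TrainingIOHandler instance named io_handler, such a function would be registered with io_handler.attach_logger(log_eval_to_history).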
- log(category: LogCategory, to_log: dict[str, Any], epoch_number: int) None¶
Logs data via the logging functions stored in this class.
- Parameters:
category – A logging category which describes what type of data is logged (it is an enum).
to_log – A data dictionary to log (typically, metrics).
epoch_number – The current epoch number.
- save_checkpoint(training_state: TrainingState, epoch_number: int, dataset_state: GraphDatasetState | None = None, metrics: dict[str, float] | None = None) None¶
Saves a model checkpoint using the post-epoch checkpointer.
- Parameters:
training_state – The training state to save.
epoch_number – The current epoch number.
dataset_state – Iterator state of the training dataset, persisted as a sibling item so it can be restored separately. In multi-host mode the caller must pass a globally-addressable pytree here.
metrics – Scalar metrics used for best-checkpoint ranking.
- save_dataset_info(dataset_info: DatasetInfo) None¶
Save the dataset information class to disk in JSON format.
Will also upload with data upload function if it exists.
- Parameters:
dataset_info – The dataset information class to save.
- restore_checkpoint(training_state: TrainingState, dataset_state: GraphDatasetState | None = None, mesh: Mesh | None = None) tuple[TrainingState, GraphDatasetState | None, int, list]¶
Restores a training state from disk locally.
When intra-epoch checkpointing is enabled, both the post-epoch and intra-epoch checkpoint managers are inspected. The checkpoint with the higher num_steps (read from Orbax custom metadata) wins. After restoring, any intra-epoch checkpoints are cleared so the next run starts fresh.
- Parameters:
training_state – An instance of training state, which will serve as a template for the restoration.
dataset_state – Optional dataset state template. When provided, the dataset state is restored from the checkpoint and returned alongside the training state.
mesh – Optional JAX Mesh for hardware-agnostic restoration.
- Returns:
A tuple of (restored_training_state, restored_dataset_state, last_completed_epoch, metrics). restored_dataset_state is None when no template is provided or no checkpoint exists. The epoch number is always read from the checkpoint metadata. For post-epoch checkpoints it is the epoch that was completed. For intra-epoch checkpoints it is epoch_number - 1 (the last fully completed epoch). metrics is the list of per-step training metrics accumulated before the intra-epoch checkpoint was taken (empty for post-epoch or fresh starts).
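The selection rule and the epoch bookkeeping described for restore_checkpoint can be sketched in plain Python as follows. This is not the library's implementation: CheckpointMeta, pick_checkpoint and last_completed_epoch are hypothetical names, and the tie-breaking behaviour (post-epoch wins on equal num_steps) is an assumption of this sketch that the documentation does not specify.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class CheckpointMeta:
    """Hypothetical view of the metadata consulted during restoration."""

    num_steps: int
    epoch_number: int
    is_intra_epoch: bool
    metrics: list = field(default_factory=list)


def pick_checkpoint(
    post_epoch: Optional[CheckpointMeta], intra_epoch: Optional[CheckpointMeta]
) -> Optional[CheckpointMeta]:
    """Return the candidate with the higher num_steps, or None if neither exists."""
    candidates = [c for c in (post_epoch, intra_epoch) if c is not None]
    if not candidates:
        return None
    # max() keeps the first maximal element, so ties go to post_epoch (assumption).
    return max(candidates, key=lambda c: c.num_steps)


def last_completed_epoch(meta: CheckpointMeta) -> int:
    """Intra-epoch checkpoints report the epoch before the one in progress."""
    return meta.epoch_number - 1 if meta.is_intra_epoch else meta.epoch_number
```

For example, a post-epoch checkpoint at epoch 2 with 200 steps and an intra-epoch checkpoint taken during epoch 3 at 250 steps would resolve to the intra-epoch checkpoint, with epoch 2 reported as the last completed epoch.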
- class mlip.training.training_io_handler.TrainingIOHandlerConfig(*, checkpoint_dir: str | PathLike | None = None, restore_dir: str | PathLike | None = None, max_to_keep: Annotated[int, Gt(gt=0)] = 5, save_debiased_ema: bool = True, ema_decay: Annotated[float, Gt(gt=0.0), Le(le=1.0)] = 0.99, use_single_host_patch: bool = False, enable_async_checkpointing: bool = False, use_intra_epoch_checkpointing: bool = False, intra_epoch_save_every_n_steps: Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Gt(gt=0)])] | None = 100, intra_epoch_max_to_keep: Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Gt(gt=0)])] | None = 5, restore_checkpoint_if_exists: bool = False, epoch_to_restore: Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Gt(gt=0)])] | None = None, restore_optimizer_state: bool = False, clear_previous_checkpoints: bool = False)¶
Pydantic config holding all settings relevant for the training IO handler.
- checkpoint_dir¶
Root checkpoint directory. Supports local paths as well as other paths supported by Orbax. When None, dataset-info saving and the clear_previous_checkpoints guard are skipped. Defaults to None.
- Type:
str | os.PathLike | None
- restore_dir¶
Directory to restore checkpoints from. If None, will default to checkpoint_dir.
- Type:
str | os.PathLike | None
- max_to_keep¶
Maximum number of post-epoch checkpoints to retain. The default is 5.
- Type:
int
- save_debiased_ema¶
Whether to also save the EMA parameters. The default is True.
- Type:
bool
- ema_decay¶
The EMA decay rate. The default is 0.99.
- Type:
float
- use_intra_epoch_checkpointing¶
Whether to also use intra-epoch checkpointing. The default is False.
- Type:
bool
- intra_epoch_save_every_n_steps¶
Save an intra-epoch checkpoint every N training steps. None disables intra-epoch checkpointing.
- Type:
int | None
- intra_epoch_max_to_keep¶
Maximum intra-epoch checkpoints to retain.
- Type:
int | None
- restore_checkpoint_if_exists¶
Whether to restore a previous checkpoint if it exists. By default, this is False.
- Type:
bool
- epoch_to_restore¶
The epoch number to restore. The default is None, which means the latest epoch will be restored.
- Type:
int | None
- restore_optimizer_state¶
Whether to also restore the optimizer state. Default is False.
- Type:
bool
- clear_previous_checkpoints¶
Whether to clear the previous checkpoints if any exist. Note that this setting cannot be set to True if one selects to restore a checkpoint. The default is False.
- Type:
bool
- use_single_host_patch¶
Whether to use single-host (process-0-only) checkpointing. When True, Orbax writes without multi-host coordination and async checkpointing is disabled. Defaults to False.
- Type:
bool
- enable_async_checkpointing¶
Whether Orbax should write checkpoints asynchronously. Defaults to False. Incompatible with use_single_host_patch; see CheckpointerConfig.
- Type:
bool
- __init__(**data: Any) None¶
Create a new model by parsing and validating input data from keyword arguments.
Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model. self is explicitly positional-only to allow self as a field name.
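The field descriptions above mention two combinations of settings that do not go together: clear_previous_checkpoints with restoring a checkpoint, and enable_async_checkpointing with use_single_host_patch. The sketch below checks a plain dictionary of settings for these combinations before a config is built. It is illustrative only: find_conflicts is a hypothetical helper, mapping "restoring a checkpoint" to restore_checkpoint_if_exists is this sketch's reading of the text, and how the actual config class reacts to these combinations is not stated here.

```python
def find_conflicts(settings: dict) -> list:
    """Return human-readable descriptions of documented setting conflicts."""
    conflicts = []
    if settings.get("clear_previous_checkpoints") and settings.get(
        "restore_checkpoint_if_exists"
    ):
        conflicts.append(
            "clear_previous_checkpoints cannot be combined with restoring a checkpoint"
        )
    if settings.get("enable_async_checkpointing") and settings.get(
        "use_single_host_patch"
    ):
        conflicts.append(
            "enable_async_checkpointing is incompatible with use_single_host_patch"
        )
    return conflicts
```

An empty list means that neither of the two documented conflicts is present; other validation, such as the positivity constraints on max_to_keep, is left to the pydantic config itself.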
- class mlip.training.training_io_handler.LogCategory(*values)¶
Enum class for logging categories.
These values provide a signal to a logging function what type of data is being logged.
- BEST_MODEL¶
Information about the current best model is logged.
- TRAIN_METRICS¶
Metrics for the training set are logged.
- EVAL_METRICS¶
Metrics for the validation set are logged.
- TEST_METRICS¶
Metrics for the test set are logged.
- SYSTEM_METRICS¶
Per-process system metrics (runtime, throughput) are logged.
- CLEANUP_AFTER_CKPT_RESTORATION¶
Allows the logger to clean itself up after a checkpoint has been restored.
- mlip.training.training_loggers.log_metrics_to_table(category: LogCategory, to_log: dict[str, Any], epoch_number: int) None¶
Logging function for the training loop which logs the metrics to a nice table.
The table will be printed to the command line.
This function also converts MSE metrics to RMSE before logging them.
- Parameters:
category – The logging category describing what type of data is currently logged.
to_log – The data to log (typically, the metrics).
epoch_number – The current epoch number.
- mlip.training.training_loggers.log_metrics_to_line(category: LogCategory, to_log: dict[str, Any], epoch_number: int) None¶
Logging function for the training loop which logs the metrics to a single line.
This function also converts MSE metrics to RMSE before logging them.
- Parameters:
category – The logging category describing what type of data is currently logged.
to_log – The data to log (typically, the metrics).
epoch_number – The current epoch number.
- mlip.training.training_loggers.convert_mse_to_rmse_in_logs(to_log: dict[str, Any]) dict[str, Any]¶
Helper function that converts all MSE values to RMSE values in a given metrics dictionary to log. To compute a correct RMSE, the square root must be taken at the very end, after all averaging has happened. For this reason, the logged metrics objects contain only MSE values, which this function converts.
- Parameters:
to_log – The metrics dictionary.
- Returns:
The metrics dictionary with any MSE entries converted to RMSE.
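The reason for taking the square root only after averaging can be seen in a small numerical sketch. The function mse_to_rmse below is a hypothetical stand-in, not the library's convert_mse_to_rmse_in_logs, and it assumes that MSE entries are identified by the substring "mse" in the key, which may differ from the actual naming scheme. The averaging example assumes equally sized batches.

```python
import math
from typing import Any


def mse_to_rmse(to_log: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of the metrics with MSE entries replaced by RMSE entries."""
    converted = {}
    for key, value in to_log.items():
        # Assumed naming scheme: keys containing "mse" (but not "rmse") hold MSE values.
        if "mse" in key and "rmse" not in key:
            converted[key.replace("mse", "rmse")] = math.sqrt(value)
        else:
            converted[key] = value
    return converted


# Two equally sized batches with per-batch MSE values of 1.0 and 9.0.
batch_mses = [1.0, 9.0]
mean_mse = sum(batch_mses) / len(batch_mses)

# Correct: square root of the averaged MSE, i.e. sqrt(5.0).
correct_rmse = math.sqrt(mean_mse)

# Incorrect: average of per-batch RMSE values, i.e. (1.0 + 3.0) / 2 = 2.0.
wrong_rmse = sum(math.sqrt(m) for m in batch_mses) / len(batch_mses)
```

Averaging per-batch RMSE values underestimates the true RMSE whenever the batch errors differ, which is why only MSE is accumulated and the conversion happens at logging time.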