IO handling during training¶
- class mlip.training.training_io_handler.TrainingIOHandler(config: TrainingIOHandlerConfig | None = None, data_upload_fun: Callable[[str | PathLike], Future | None] | None = None)¶
An IO handler class for the training loop.
This handles checkpointing as well as specialized logging, e.g., to some external logger that a user can provide. If the config contains None for the local checkpointing directory, then this class will only do logging, but no checkpointing.
- __init__(config: TrainingIOHandlerConfig | None = None, data_upload_fun: Callable[[str | PathLike], Future | None] | None = None) None ¶
Constructor.
- Parameters:
config – The training IO handler pydantic config. Can be None, in which case the default config will be used. Defaults to None.
data_upload_fun – A data upload function to remote storage. This is optional and None by default. The function should take in only a source path; the upload destination can be user-defined within the function. The function can be asynchronous, in which case it should return a Future.
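A minimal construction sketch (the uploader below is a hypothetical stand-in for user-defined storage logic; only the handler and config classes are taken from this page):

```python
from concurrent.futures import Future, ThreadPoolExecutor
from os import PathLike

from mlip.training.training_io_handler import (
    TrainingIOHandler,
    TrainingIOHandlerConfig,
)

_executor = ThreadPoolExecutor(max_workers=1)


def upload_to_bucket(source_path: str | PathLike) -> Future:
    """Hypothetical asynchronous uploader; returns a Future as the handler allows."""

    def _upload() -> None:
        ...  # replace with your own remote-storage logic for source_path

    return _executor.submit(_upload)


config = TrainingIOHandlerConfig(local_model_output_dir="outputs")
io_handler = TrainingIOHandler(config, data_upload_fun=upload_to_bucket)
```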
- attach_logger(logger: Callable[[LogCategory, dict[str, Any], int], None]) None ¶
Attaches a training loop logging function to the IO handler.
The logging function must take in three parameters and should not return anything. The three parameters are a logging category that describes what type of data is logged (it is an enum), the data dictionary to log, and the current epoch number.
- Parameters:
logger – The logging function to add.
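For example, a custom logging function matching this signature can be attached as follows (a sketch, reusing the io_handler instance constructed above):

```python
from typing import Any

from mlip.training.training_io_handler import LogCategory


def print_logger(category: LogCategory, to_log: dict[str, Any], epoch_number: int) -> None:
    """Minimal custom logger; forward to an external tracker here if desired."""
    print(f"[{category.name}] epoch {epoch_number}: {to_log}")


io_handler.attach_logger(print_logger)
```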
- log(category: LogCategory, to_log: dict[str, Any], epoch_number: int) None ¶
Logs data via the logging functions stored in this class.
- Parameters:
category – A logging category which describes what type of data is logged (it is an enum).
to_log – A data dictionary to log (typically, metrics).
epoch_number – The current epoch number.
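The training loop normally calls this internally, but it can also be invoked directly (a sketch, reusing io_handler from above; the metric names are illustrative):

```python
from mlip.training.training_io_handler import LogCategory

# Dispatches the dictionary to every attached logging function.
io_handler.log(LogCategory.TRAIN_METRICS, {"loss": 0.123, "mae": 0.045}, epoch_number=1)
```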
- save_checkpoint(training_state: TrainingState, epoch_number: int) None ¶
Saves a model checkpoint to disk.
Uses the data upload function as well if it exists.
- Parameters:
training_state – The training state to save.
epoch_number – The current epoch number.
- save_dataset_info(dataset_info: DatasetInfo) None ¶
Saves the dataset information to disk in JSON format.
Also uploads it via the data upload function if one exists.
- Parameters:
dataset_info – The dataset information class to save.
- restore_training_state(training_state: TrainingState) TrainingState ¶
Restores a training state from disk locally.
Note that if one wants to restore from a remote location, the state must first be downloaded outside of this function.
- Parameters:
training_state – An instance of training state, which will serve as a template for the restoration.
- Returns:
The restored training state.
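A round-trip sketch covering save_checkpoint, save_dataset_info, and restore_training_state (assuming training_state and dataset_info are produced elsewhere by the library's training loop and dataset pipeline):

```python
# Persist a checkpoint and the dataset info; both are also uploaded
# if a data upload function was provided to the handler.
io_handler.save_checkpoint(training_state, epoch_number=10)
io_handler.save_dataset_info(dataset_info)

# After a restart: pass a freshly initialized state as a template
# and receive the restored state back.
restored_state = io_handler.restore_training_state(training_state)
```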
- class mlip.training.training_io_handler.TrainingIOHandlerConfig(*, local_model_output_dir: str | PathLike | None = None, max_checkpoints_to_keep: Annotated[int, Gt(gt=0)] = 5, save_debiased_ema: bool = True, ema_decay: Annotated[float, Gt(gt=0.0), Le(le=1.0)] = 0.99, restore_checkpoint_if_exists: bool = False, epoch_to_restore: Annotated[int, FieldInfo(annotation=NoneType, required=True, metadata=[Gt(gt=0)])] | None = None, restore_optimizer_state: bool = False, clear_previous_checkpoints: bool = False)¶
Pydantic config holding all settings relevant for the training IO handler.
- local_model_output_dir¶
Path to the output directory (local filesystem) where the model/dataset information and checkpoints are stored. If None, then local checkpointing will be disabled. Defaults to None.
- Type:
str | os.PathLike | None
- max_checkpoints_to_keep¶
Maximum number of old checkpoints to keep. The default is 5.
- Type:
int
- save_debiased_ema¶
Whether to also save the EMA parameters. The default is True.
- Type:
bool
- ema_decay¶
The EMA decay rate. The default is 0.99.
- Type:
float
- restore_checkpoint_if_exists¶
Whether to restore a previous checkpoint if it exists. By default, this is False.
- Type:
bool
- epoch_to_restore¶
The epoch number to restore. The default is None, which means the latest epoch will be restored.
- Type:
int | None
- restore_optimizer_state¶
Whether to also restore the optimizer state. Default is False.
- Type:
bool
- clear_previous_checkpoints¶
Whether to clear the previous checkpoints if any exist. Note that this setting cannot be set to True if one selects to restore a checkpoint. The default is False.
- Type:
bool
- __init__(**data: Any) None ¶
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
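For example, a config for a run that checkpoints locally and resumes from the latest checkpoint if one exists (a sketch; all fields as documented above):

```python
from mlip.training.training_io_handler import TrainingIOHandlerConfig

config = TrainingIOHandlerConfig(
    local_model_output_dir="outputs/run_01",
    max_checkpoints_to_keep=3,
    restore_checkpoint_if_exists=True,
    restore_optimizer_state=True,
    # clear_previous_checkpoints must stay False when restoring a checkpoint.
)
```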
- class mlip.training.training_io_handler.LogCategory(value)¶
Enum class for logging categories.
These values provide a signal to a logging function what type of data is being logged.
- BEST_MODEL¶
Information about the current best model is logged.
- TRAIN_METRICS¶
Metrics for the training set are logged.
- EVAL_METRICS¶
Metrics for the validation set are logged.
- TEST_METRICS¶
Metrics for the test set are logged.
- CLEANUP_AFTER_CKPT_RESTORATION¶
Allows the logger to clean itself up after a checkpoint has been restored.
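A custom logging function can dispatch on these categories, for example (a sketch):

```python
from typing import Any

from mlip.training.training_io_handler import LogCategory


def my_logger(category: LogCategory, to_log: dict[str, Any], epoch_number: int) -> None:
    if category == LogCategory.TRAIN_METRICS:
        print(f"train @ epoch {epoch_number}: {to_log}")
    elif category == LogCategory.EVAL_METRICS:
        print(f"validation @ epoch {epoch_number}: {to_log}")
    elif category == LogCategory.CLEANUP_AFTER_CKPT_RESTORATION:
        pass  # e.g., reset any internal logger state after a restore
```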
- mlip.training.training_loggers.log_metrics_to_table(category: LogCategory, to_log: dict[str, Any], epoch_number: int) None ¶
Logging function for the training loop which formats the metrics as a table.
The table will be printed to the command line.
- Parameters:
category – The logging category describing what type of data is currently logged.
to_log – The data to log (typically, the metrics).
epoch_number – The current epoch number.
- mlip.training.training_loggers.log_metrics_to_line(category: LogCategory, to_log: dict[str, Any], epoch_number: int) None ¶
Logging function for the training loop which logs the metrics to a single line.
- Parameters:
category – The logging category describing what type of data is currently logged.
to_log – The data to log (typically, the metrics).
epoch_number – The current epoch number.
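Both built-in loggers already match the required signature, so they can be attached directly (a sketch, reusing io_handler from above):

```python
from mlip.training.training_loggers import log_metrics_to_line, log_metrics_to_table

io_handler.attach_logger(log_metrics_to_table)  # multi-line table output
io_handler.attach_logger(log_metrics_to_line)   # compact single-line output
```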