.. _training:

Model training
==============

To start a model training, there are the following prerequisites:

* Loading and preprocessing a training and validation dataset, as described
  :ref:`here `.
* Initializing a force field, as described :ref:`here `.
* Setting up a loss (see :ref:`below <training_loss>` for details).
* Setting up an optimizer (see :ref:`below <training_optimizer>` for details).
* Creating an instance of :py:class:`TrainingLoopConfig ` (can also be
  accessed via ``TrainingLoop.Config``).
* Optionally (has a default): Creating an instance of
  :py:class:`TrainingIOHandler ` (see :ref:`below <training_io_handler>`
  for details).

Once these objects are set up, we can create an instance of
:py:class:`TrainingLoop ` and start the training run:

.. code-block:: python

    import jax

    from mlip.training import TrainingLoop

    # Prerequisites
    train_set, validation_set, dataset_info = _get_dataset()  # placeholder
    force_field = _get_force_field()  # placeholder
    loss = _get_loss()  # placeholder
    optimizer = _get_optimizer()  # placeholder
    io_handler = _get_training_loop_io_handler()  # placeholder
    config = TrainingLoop.Config(**config_kwargs)

    # Create the TrainingLoop instance
    training_loop = TrainingLoop(
        train_dataset=train_set,
        validation_dataset=validation_set,
        force_field=force_field,
        loss=loss,
        optimizer=optimizer,
        config=config,
        io_handler=io_handler,  # also has a default, does not need to be set
        should_parallelize=len(jax.devices()) > 1,  # has a default of False
    )

    # Start the model training
    training_loop.run()

The final :py:class:`TrainingState ` can be accessed after the run like this:

.. code-block:: python

    final_training_state = training_loop.training_state
    final_params = final_training_state.params

However, the final parameters are not always the ones with the best
performance on the validation set. Hence, you can also access the best
parameters via ``training_loop.best_params``, or directly use
``training_loop.best_model`` to get the :py:class:`ForceField ` instance
that holds the best parameters.

If you want to save a trained force field not only via the checkpointing API
described further below, you can also use the function
:py:func:`save_model_to_zip() ` to save it as a lightweight zip archive.
This is useful if you only want to use the model for inference tasks later,
as the archive does not include any training state.

Note that it is also possible to run an evaluation on a test dataset after
training by using the :py:func:`test() ` method of the
:py:class:`TrainingLoop ` instance.
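For example, such an evaluation could look like the following minimal sketch.
Note that ``_get_test_dataset()`` is a hypothetical placeholder for a test
dataset prepared in the same way as the validation set, and we assume here
that :py:func:`test() ` takes the dataset as its argument; check the API
reference for the exact signature:

.. code-block:: python

    # Hypothetical placeholder: a test set prepared like the validation set
    test_set = _get_test_dataset()

    # Assumed call signature: evaluate the trained model on the test data
    training_loop.test(test_set)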
In the following, we describe the prerequisites listed above in more detail.

.. _training_loss:

Loss
----

All losses must be implemented as derived classes of :py:class:`Loss `.
We currently implement two losses: the mean-squared-error loss
(:py:class:`MSELoss `) and the Huber loss (:py:class:`HuberLoss `). Both
are derived from a common base loss that computes errors for energies,
forces, and stress, and weights them according to a weighting schedule that
can depend on the epoch number (base class: :py:class:`Loss `).

If one wants to use the MSE loss for training, simply run this code to
initialize it:

.. code-block:: python

    import optax

    from mlip.models.loss import MSELoss

    # uses default weight schedules
    loss = MSELoss()

    # uses a weight flip schedule
    energy_weight_schedule = optax.piecewise_constant_schedule(1.0, {100: 25.0})
    forces_weight_schedule = optax.piecewise_constant_schedule(25.0, {100: 0.04})
    loss = MSELoss(energy_weight_schedule, forces_weight_schedule)

For our two implemented losses, we also allow for the computation of more
extended metrics by setting the ``extended_metrics`` argument to ``True`` in
the loss constructor (e.g., ``MSELoss(extended_metrics=True)``). By default,
it is ``False``.

.. _training_optimizer:

Optimizer
---------

The optimizer provided to the :py:class:`TrainingLoop ` can be any
`Optax optimizer `_. However, this library also offers a specialized
optimizer pipeline that has been inspired by `this `_ PyTorch MACE
implementation. It is configurable via an :py:class:`OptimizerConfig `
object that has sensible defaults set for training MLIP models. This
default MLIP optimizer can be set up like this:

.. code-block:: python

    from mlip.training import get_default_mlip_optimizer, OptimizerConfig

    # with default config
    optimizer = get_default_mlip_optimizer()

    # with modified config
    optimizer = get_default_mlip_optimizer(OptimizerConfig(**config_kwargs))

See the API reference for :py:func:`get_default_mlip_optimizer ` and
:py:class:`OptimizerConfig ` for further details on how this MLIP optimizer
works internally.

.. _training_io_handler:

IO handling and logging
-----------------------

During training, we want to allow for checkpointing of the training state
and logging of metrics. The :py:class:`TrainingIOHandler ` class manages
these tasks. It comes with its own config, the
:py:class:`TrainingIOHandlerConfig `, which, like most other configs in the
library, can be accessed via ``TrainingIOHandler.Config``.

The IO handler uses `Orbax Checkpointing `_ to save and restore model
checkpoints. Also, for loading a trained model for simulations or other
inference tasks, this library relies on loading these model checkpoints
(see :py:func:`load_parameters_from_checkpoint() `). The local checkpointing
location can be set in the config; uploading these checkpoints to remote
storage locations can be achieved via a provided data upload function:

.. code-block:: python

    import os

    from mlip.training import TrainingIOHandler

    io_config = TrainingIOHandler.Config(**config_kwargs)

    def remote_storage_sync_fun(source: str | os.PathLike) -> None:
        """Makes sure local data in source is uploaded to remote storage."""
        pass  # placeholder

    io_handler = TrainingIOHandler(io_config, remote_storage_sync_fun)
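As a minimal illustration of such an upload function, the sketch below simply
mirrors the local checkpoint directory to another (e.g., network-mounted)
location using only the standard library. The ``REMOTE_DIR`` target is
hypothetical; a real implementation would typically call a cloud storage SDK
or CLI instead:

.. code-block:: python

    import os
    import shutil

    # Hypothetical remote location, e.g., a network-mounted directory
    REMOTE_DIR = "/mnt/remote/checkpoints"

    def remote_storage_sync_fun(source: str | os.PathLike) -> None:
        """Mirrors the local checkpoint directory to the remote location."""
        destination = os.path.join(REMOTE_DIR, os.path.basename(str(source)))
        shutil.copytree(source, destination, dirs_exist_ok=True)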
Locally, after the training run has started, the checkpointing location will
contain a ``dataset_info.json`` file with the saved :py:class:`DatasetInfo `
object, and a ``model`` subdirectory with all the model checkpoints, one for
each epoch that had the best model up to that point, judged by the validation
set loss. It is recommended to also manually save other metadata in this
location, such as the applied model config.

For advanced logging, e.g., to an experiment tracking platform (such as
`Neptune `_), one can also attach custom logging functions to the IO handler:

.. code-block:: python

    from typing import Any

    from mlip.training.training_io_handler import LogCategory

    def train_logging_fun(
        category: LogCategory, to_log: dict[str, Any], epoch_number: int
    ) -> None:
        """Advanced logging function"""
        pass  # placeholder

    io_handler.attach_logger(train_logging_fun)

See the documentation of :py:class:`LogCategory ` for more details on what
type of data can be logged with such a logger during training.

Furthermore, this library provides built-in logging functions that can be
attached to the IO handler: :py:func:`log_metrics_to_table() `, which prints
the training metrics to the console in a nice table format (using
`Rich tables `_), and :py:func:`log_metrics_to_line() `, which logs the
metrics in a single line.

Note that it is possible to omit the ``io_handler`` argument in the
:py:class:`TrainingLoop ` class. In that case, a default IO handler is set
up internally and used. This IO handler does not include checkpointing, but
it does have the :py:func:`log_metrics_to_line() ` logging function attached
by default.
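If you prefer the table output instead, a built-in logger can be attached
explicitly. A minimal sketch, assuming the built-in logging functions are
importable from ``mlip.training`` (check the API reference for the exact
import path):

.. code-block:: python

    # Import path assumed; see the API reference for the exact location
    from mlip.training import log_metrics_to_table

    # Print the metrics of each epoch as a Rich table on the console
    io_handler.attach_logger(log_metrics_to_table)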