Training Loop

class mlip.training.training_loop.TrainingLoop(train_dataset: GraphDataset | PrefetchIterator, validation_dataset: GraphDataset | PrefetchIterator, force_field: ForceField, loss: Loss, optimizer: GradientTransformation, config: TrainingLoopConfig, io_handler: TrainingIOHandler | None = None, should_parallelize: bool = False)

Training loop class.

It implements only the loop based on its inputs and does not construct any auxiliary objects itself. For example, the model, datasets, and optimizer must be created outside and passed to the constructor.
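The design is a form of dependency injection: the caller builds every collaborator and the loop only iterates. The following is a minimal pure-Python sketch of that pattern, not mlip code. `ToyTrainingLoop`, `ToyTrainingState`, `toy_sgd`, and `ToyGradientTransformation` are hypothetical names, and the `init`/`update` pair only mimics the general shape of an optax `GradientTransformation`. As with `TrainingLoop.run()`, `run()` returns None and the result is read from the `training_state` attribute.

```python
from dataclasses import dataclass
from typing import Callable, NamedTuple, Tuple


# Hypothetical stand-in that mimics the shape of an optax optimizer:
# a pair of pure functions (init, update).
class ToyGradientTransformation(NamedTuple):
    init: Callable[[float], float]
    update: Callable[[float, float], Tuple[float, float]]


def toy_sgd(learning_rate: float) -> ToyGradientTransformation:
    """Plain SGD; the optimizer state is unused and kept at 0.0."""

    def init(params: float) -> float:
        return 0.0

    def update(grad: float, state: float) -> Tuple[float, float]:
        # Return the parameter step and the (unchanged) optimizer state.
        return -learning_rate * grad, state

    return ToyGradientTransformation(init, update)


@dataclass
class ToyTrainingState:
    params: float
    opt_state: float
    epoch: int = 0


class ToyTrainingLoop:
    """Hypothetical loop: stores injected objects, builds none of them."""

    def __init__(self, train_dataset, initial_params, loss_and_grad, optimizer, num_epochs):
        self.train_dataset = train_dataset
        self.loss_and_grad = loss_and_grad
        self.optimizer = optimizer
        self.num_epochs = num_epochs
        self.training_state = ToyTrainingState(
            params=initial_params, opt_state=optimizer.init(initial_params)
        )

    def run(self) -> None:
        # Mutates self.training_state in place; nothing is returned.
        state = self.training_state
        for _ in range(self.num_epochs):
            for x, y in self.train_dataset:
                _, grad = self.loss_and_grad(state.params, x, y)
                step, state.opt_state = self.optimizer.update(grad, state.opt_state)
                state.params += step
            state.epoch += 1


# The caller constructs the data, loss, and optimizer, then injects them.
data = [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0)]  # samples of y = 3x


def squared_error(w, x, y):
    residual = w * x - y
    return residual ** 2, 2.0 * residual * x  # (loss, d loss / d w)


loop = ToyTrainingLoop(data, 0.0, squared_error, toy_sgd(0.05), num_epochs=50)
loop.run()
print(round(loop.training_state.params, 3))
```

Because the loop never builds its inputs, swapping the optimizer or dataset means changing only the caller, not the loop.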

training_state

The training state.

__init__(train_dataset: GraphDataset | PrefetchIterator, validation_dataset: GraphDataset | PrefetchIterator, force_field: ForceField, loss: Loss, optimizer: GradientTransformation, config: TrainingLoopConfig, io_handler: TrainingIOHandler | None = None, should_parallelize: bool = False) → None

Constructor.

Parameters:
  • train_dataset – The training dataset as either a GraphDataset or a PrefetchIterator.

  • validation_dataset – The validation dataset as either a GraphDataset or a PrefetchIterator.

  • force_field – The force field model holding at least the initial parameters and a dataset info object.

  • loss – The loss, which is derived from the Loss base class.

  • optimizer – The optimizer, an optax GradientTransformation.

  • config – The training loop pydantic config.

  • io_handler – The IO handler, which handles checkpointing and (specialized) logging. This argument is optional. If it is None (the default), a default IO handler is set up that performs only very basic metrics logging and no checkpointing.

  • should_parallelize – Whether to use data parallelism across multiple devices. The default is False.
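Data parallelism, as enabled by should_parallelize, generally means that each device holds a replica of the parameters, computes gradients on its own shard of the batch, and then averages the gradients across devices. The sketch below simulates that idea in pure Python; it does not show how mlip implements it, and `shard`, `mean_grad`, and `data_parallel_grad` are hypothetical helpers. It assumes equally sized shards, in which case the average of per-shard mean gradients equals the full-batch mean gradient.

```python
from typing import List, Sequence, Tuple

Sample = Tuple[float, float]


def mean_grad(w: float, batch: Sequence[Sample]) -> float:
    """Mean gradient of the squared error (w * x - y) ** 2 over a batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def shard(batch: Sequence[Sample], num_devices: int) -> List[List[Sample]]:
    """Split a batch into equal contiguous shards, one per simulated device."""
    if len(batch) % num_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    size = len(batch) // num_devices
    return [list(batch[i * size:(i + 1) * size]) for i in range(num_devices)]


def data_parallel_grad(w: float, batch: Sequence[Sample], num_devices: int) -> float:
    # Each simulated device computes the gradient on its own shard...
    per_device = [mean_grad(w, s) for s in shard(batch, num_devices)]
    # ...then an all-reduce mean (in JAX, e.g. jax.lax.pmean) combines them.
    return sum(per_device) / num_devices


batch = [(1.0, 2.0), (2.0, 5.0), (3.0, 5.0), (4.0, 9.0)]
print(data_parallel_grad(0.5, batch, 2), mean_grad(0.5, batch))
```

The divisibility check matters: with unequal shards, a plain average of per-shard means would weight samples unevenly.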

run() → None

Runs the training loop.

The final training state can then be accessed via the training_state member variable.

test(test_dataset: GraphDataset | PrefetchIterator) → None

Run the evaluation on the test dataset with the best parameters seen so far.

Parameters:

test_dataset – The test dataset as either a GraphDataset or a PrefetchIterator.