Training Loop
- class mlip.training.training_loop.TrainingLoop(train_dataset: GraphDataset | PrefetchIterator | CombinedGraphDataset, validation_dataset: GraphDataset | PrefetchIterator | CombinedGraphDataset | dict[str, GraphDataset | PrefetchIterator | CombinedGraphDataset], force_field: ForceField, loss: Loss, optimizer: GradientTransformation, config: TrainingLoopConfig, io_handler: TrainingIOHandler | None = None, mesh: Mesh | None = None)
Training loop class.
It implements only the loop based on its inputs and does not construct any auxiliary objects itself. For example, the model, datasets, and optimizer must be passed to this class from the outside.
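The design described above, where every component is built by the caller and only consumed by the loop, can be illustrated with a toy loop. This is a minimal sketch, not the mlip implementation: `ToyTrainingLoop`, `ToyState`, `mse_grad`, and `sgd_update` are hypothetical names, and plain Python floats stand in for the force field, datasets, loss, and optax optimizer. It also mirrors the documented pattern of `run()` returning `None` while the result is read from a `training_state` attribute.

```python
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass
class ToyState:
    """Hypothetical stand-in for the training state."""
    params: float
    step: int = 0


class ToyTrainingLoop:
    """Hypothetical loop that only consumes the components it is given."""

    def __init__(
        self,
        train_data: Sequence[float],
        grad_fn: Callable[[float, Sequence[float]], float],
        update_fn: Callable[[float, float], float],
        init_params: float,
        num_epochs: int,
    ) -> None:
        # Nothing is constructed here; all components come from the caller.
        self._train_data = train_data
        self._grad_fn = grad_fn
        self._update_fn = update_fn
        self._num_epochs = num_epochs
        self.training_state = ToyState(params=init_params)

    def run(self) -> None:
        # The result is written to self.training_state instead of being returned.
        for _ in range(self._num_epochs):
            grad = self._grad_fn(self.training_state.params, self._train_data)
            self.training_state.params = self._update_fn(self.training_state.params, grad)
            self.training_state.step += 1


def mse_grad(params: float, targets: Sequence[float]) -> float:
    # Gradient of mean((params - t)^2) with respect to params (full batch).
    return sum(2.0 * (params - t) for t in targets) / len(targets)


def sgd_update(params: float, grad: float, learning_rate: float = 0.1) -> float:
    # Plain gradient descent step, standing in for an optax optimizer.
    return params - learning_rate * grad


# The "model", data, loss gradient, and optimizer are all built outside the loop.
loop = ToyTrainingLoop(
    train_data=[1.0, 3.0],
    grad_fn=mse_grad,
    update_fn=sgd_update,
    init_params=0.0,
    num_epochs=50,
)
loop.run()
print(loop.training_state.step, round(loop.training_state.params, 3))
```

After 50 epochs the parameter approaches 2.0, the mean of the two targets. Swapping the optimizer or the data only means passing different objects to the constructor; the loop body itself stays unchanged.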
- training_state
The training state.
- __init__(train_dataset: GraphDataset | PrefetchIterator | CombinedGraphDataset, validation_dataset: GraphDataset | PrefetchIterator | CombinedGraphDataset | dict[str, GraphDataset | PrefetchIterator | CombinedGraphDataset], force_field: ForceField, loss: Loss, optimizer: GradientTransformation, config: TrainingLoopConfig, io_handler: TrainingIOHandler | None = None, mesh: Mesh | None = None) → None
Constructor.
Note: This constructor updates the add_atomic_energies config field of the MLIPNetwork class of the force field, if requested via the atomic_energies_removed field in the dataset info. Hence, accessing self.force_field may yield an updated force field. However, the method best_model() will return the original, unmodified force field (but with the best parameters).
- Parameters:
train_dataset – The training dataset (GraphDataset, PrefetchIterator, or CombinedGraphDataset).
validation_dataset – The validation dataset (GraphDataset, PrefetchIterator, or CombinedGraphDataset). This can also be given as a dictionary of validation datasets. In that case, the metric names during evaluation will be prefixed with the keys of that dictionary.
force_field – The force field model holding at least the initial parameters and a dataset info object.
loss – The loss, which is derived from the Loss base class.
optimizer – The optimizer, an optax GradientTransformation.
config – The training loop pydantic config.
io_handler – The IO handler, which handles checkpointing and (specialized) logging. This is an optional argument. The default is None, in which case a default IO handler is set up that performs only very basic metrics logging and no checkpointing.
mesh – The device mesh to use for training and evaluation. If not provided, a device mesh will be created automatically based on the available devices.
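When validation_dataset is passed as a dictionary, metric names are prefixed with the dictionary keys. The sketch below shows one way such a naming scheme can work. The helper `prefix_metrics`, the metric names, and the `"{key}_{metric}"` format with an underscore separator are all hypothetical; the exact separator mlip uses is not stated in this section, and the numbers are placeholders, not results.

```python
def prefix_metrics(metrics_per_dataset: dict[str, dict[str, float]]) -> dict[str, float]:
    """Hypothetical helper: flatten per-dataset metrics, prefixing each name with its dataset key."""
    flat: dict[str, float] = {}
    for dataset_name, metrics in metrics_per_dataset.items():
        for metric_name, value in metrics.items():
            # Assumed format "<key>_<metric>"; the real separator may differ.
            flat[f"{dataset_name}_{metric_name}"] = value
    return flat


# Placeholder values for two hypothetical validation sets.
val_metrics = {
    "bulk": {"energy_mae": 1.0, "forces_mae": 2.0},
    "surface": {"energy_mae": 3.0, "forces_mae": 4.0},
}
print(prefix_metrics(val_metrics))
```

Prefixing keeps metrics from different validation sets distinct in a single flat log, so two sets reporting the same metric name do not overwrite each other.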
- run() → None
Runs the training loop.
The final training state can then be accessed via the training_state member variable.
- test(test_dataset: GraphDataset | PrefetchIterator | CombinedGraphDataset) → None
Run the evaluation on the test dataset with the best parameters seen so far.
- Parameters:
test_dataset – The test dataset (GraphDataset, PrefetchIterator, or CombinedGraphDataset).
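test() evaluates with the best parameters seen so far. A common way to track them is to snapshot the parameters whenever the validation loss improves. This is a minimal sketch under the assumption that "best" means lowest validation loss; the criterion mlip actually uses is not specified in this section, and `BestParamsTracker` is a hypothetical name.

```python
import copy
import math
from typing import Any, Optional


class BestParamsTracker:
    """Hypothetical helper: keeps a copy of the parameters with the lowest validation loss."""

    def __init__(self) -> None:
        self.best_loss: float = math.inf
        self.best_params: Optional[Any] = None

    def update(self, params: Any, val_loss: float) -> bool:
        """Store a deep copy of params if val_loss improves; return True on improvement."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            # Deep copy so later in-place changes to params do not alter the snapshot.
            self.best_params = copy.deepcopy(params)
            return True
        return False


tracker = BestParamsTracker()
params = {"w": [0.0]}
for epoch, val_loss in enumerate([0.9, 0.5, 0.7, 0.6]):
    params["w"][0] = float(epoch)  # stand-in for an in-place training update
    tracker.update(params, val_loss)

print(tracker.best_loss, tracker.best_params)
```

Here the snapshot from the second epoch (loss 0.5) is kept even though training continued afterwards. With immutable JAX pytrees, storing a reference would be enough; the deep copy only guards against mutable containers like the dictionary used here.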