Training step

mlip.training.training_step.make_train_step(predictor: ForceFieldPredictor, loss_fun: Callable[[Prediction, GraphsTuple, int, bool], tuple[Array, dict[str, Array]]], optimizer: GradientTransformation, ema_fun: EMAParameterTransformation, num_gradient_accumulation_steps: int | None = 1, should_parallelize: bool = True) → Callable

Create a training step function to optimize model params using gradients.

Parameters:
  • predictor – The force field predictor, an instance of nn.Module.

  • loss_fun – A function that computes the loss from the predictions, the reference labelled graph, and the epoch number. It returns the loss together with a dictionary of metrics.

  • optimizer – An optimizer for updating model params based on computed gradients.

  • ema_fun – A function for updating the exponential moving average (EMA) of the model params.

  • num_gradient_accumulation_steps – The number of gradient accumulation steps before a parameter update is performed. Defaults to 1, meaning the parameters are updated after every step.

  • should_parallelize – Whether to parallelize the training step across devices with jax.pmap. Defaults to True.
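The idea behind num_gradient_accumulation_steps can be illustrated with a small sketch. It is not mlip's implementation: it assumes a hypothetical linear model with a mean-squared-error loss standing in for loss_fun, and shows that averaging the gradients of equally sized micro-batches reproduces the full-batch gradient. This is why the parameter update can be deferred until all accumulation steps have run.

```python
import jax
import jax.numpy as jnp


def loss(params, x, y):
    # Hypothetical stand-in for loss_fun: MSE of a linear model.
    return jnp.mean((x @ params - y) ** 2)


def accumulated_grad(params, x, y, num_steps):
    """Average the gradients of `num_steps` equally sized micro-batches."""
    xs = jnp.split(x, num_steps)
    ys = jnp.split(y, num_steps)
    grads = [jax.grad(loss)(params, xb, yb) for xb, yb in zip(xs, ys)]
    # A single parameter update would then use this averaged gradient.
    return sum(grads) / num_steps


params = jnp.array([1.0, -2.0])
x = jnp.arange(16.0).reshape(8, 2)
y = jnp.ones(8)

full = jax.grad(loss)(params, x, y)
accum = accumulated_grad(params, x, y, num_steps=4)
```

The equality only holds when the loss is a mean and the micro-batches have the same size; with uneven batches the per-batch gradients would need to be weighted by batch size.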

Returns:

A function that takes the current training state and a batch of data as input, and returns the updated training state along with training metrics.
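The contract of the returned function (state and batch in, updated state and metrics out) can be sketched in plain JAX. Everything here is hypothetical: the TrainState container, the toy loss, plain SGD in place of the optax GradientTransformation, and a simple parameter EMA in place of ema_fun. It only shows the general shape of such a step, not mlip's actual code.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class TrainState(NamedTuple):
    # Hypothetical state container; mlip's own training state may differ.
    params: jax.Array
    ema_params: jax.Array
    step: jax.Array


def toy_loss(params, batch):
    x, y = batch
    loss = jnp.mean((x @ params - y) ** 2)
    # Return (loss, metrics), mirroring the return type of loss_fun.
    return loss, {"mse": loss}


def make_toy_train_step(learning_rate=0.1, ema_decay=0.99):
    @jax.jit
    def train_step(state, batch):
        (_, metrics), grads = jax.value_and_grad(toy_loss, has_aux=True)(
            state.params, batch
        )
        # Plain SGD stands in for the optimizer's update.
        new_params = state.params - learning_rate * grads
        # Exponential moving average of the parameters, standing in for ema_fun.
        new_ema = ema_decay * state.ema_params + (1.0 - ema_decay) * new_params
        return TrainState(new_params, new_ema, state.step + 1), metrics

    return train_step


# Usage: fit y = x @ [1, 2] with the toy step.
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])
params0 = jnp.zeros(2)
state = TrainState(params0, params0, jnp.asarray(0))
train_step = make_toy_train_step()

first_loss = None
for _ in range(200):
    state, metrics = train_step(state, (x, y))
    if first_loss is None:
        first_loss = float(metrics["mse"])
last_loss = float(metrics["mse"])
```

Keeping the step a pure function of (state, batch) is what makes it straightforward to wrap with jax.jit or, for multi-device training, jax.pmap.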