Training step¶
- mlip.training.training_step.make_train_step(predictor: ForceFieldPredictor, loss_fun: Callable[[Prediction, GraphsTuple, int, bool], tuple[Array, dict[str, Array]]], optimizer: GradientTransformation, ema_fun: EMAParameterTransformation, num_gradient_accumulation_steps: int | None = 1, should_parallelize: bool = True) → Callable¶
Create a training step function to optimize model params using gradients.
- Parameters:
predictor – The force field predictor, an instance of nn.Module.
loss_fun – A function that computes the loss from predictions, a reference labelled graph, and the epoch number.
optimizer – An optimizer for updating model params based on computed gradients.
ema_fun – A function for updating the exponential moving average (EMA) of the model params.
num_gradient_accumulation_steps – The number of gradient accumulation steps before a parameter update is performed. Defaults to 1, implying immediate updates.
should_parallelize – Whether to apply jax.pmap so the step runs across devices. Defaults to True.
- Returns:
A function that takes the current training state and a batch of data as input, and returns the updated training state along with training metrics.
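Below is a minimal usage sketch. The predictor, loss function, EMA transformation, training state, and batches are placeholders assumed to be produced elsewhere in the mlip training pipeline; only make_train_step and its signature come from this page, and optax.adam is used as an example optimizer.

```python
import optax

from mlip.training.training_step import make_train_step

# Hypothetical setup objects (not constructed here):
#   predictor      - ForceFieldPredictor
#   loss_fun       - maps (Prediction, GraphsTuple, epoch, ...) to (loss, metrics dict)
#   ema_fun        - EMAParameterTransformation
#   training_state - the current training state
#   batches        - an iterable of batched graphs

optimizer = optax.adam(learning_rate=1e-3)

train_step = make_train_step(
    predictor=predictor,
    loss_fun=loss_fun,
    optimizer=optimizer,
    ema_fun=ema_fun,
    num_gradient_accumulation_steps=1,   # update parameters every step
    should_parallelize=False,            # single-device run in this sketch
)

# The returned callable consumes the current training state and a batch,
# and returns the updated state together with a dict of training metrics.
for batch in batches:
    training_state, metrics = train_step(training_state, batch)
```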