Model evaluation

mlip.training.evaluation.make_evaluation_step(predictor: ForceFieldPredictor, eval_loss_fun: Callable[[Prediction, GraphsTuple, int, bool], tuple[Array, dict[str, Array]]], should_parallelize: bool = True) → Callable[[dict[str, dict[str, Array | dict]], GraphsTuple, int], dict[str, ndarray]]

Creates the evaluation step function.

Parameters:
  • predictor – The force field predictor to use for the evaluation.

  • eval_loss_fun – The loss function for the evaluation.

  • should_parallelize – Whether to apply data parallelization across multiple devices.

Returns:

The evaluation step function.
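
For example, the step function can be created from an existing predictor and loss function. The following is a minimal sketch: predictor and eval_loss_fun are assumed to have been constructed elsewhere and are placeholders, not part of this module.

    from mlip.training.evaluation import make_evaluation_step

    # Placeholders assumed to exist already: `predictor` (a ForceFieldPredictor) and
    # `eval_loss_fun` (a loss callable matching the documented signature).
    evaluation_step = make_evaluation_step(
        predictor=predictor,
        eval_loss_fun=eval_loss_fun,
        should_parallelize=False,  # single-device evaluation
    )

    # The returned function maps (params, graph batch, epoch number) to a dict of metrics:
    # metrics = evaluation_step(params, batch, epoch_number)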

mlip.training.evaluation.run_evaluation(evaluation_step: Callable[[dict[str, dict[str, Array | dict]], GraphsTuple, int], dict[str, ndarray]], eval_dataset: GraphDataset | PrefetchIterator, params: dict[str, dict[str, Array | dict]], epoch_number: int, io_handler: TrainingIOHandler, devices: list[Device] | None = None, is_test_set: bool = False) → float

Runs a model evaluation on a given dataset.

Parameters:
  • evaluation_step – The evaluation step function.

  • eval_dataset – The dataset on which to evaluate the model.

  • params – The parameters to use for the evaluation.

  • epoch_number – The current epoch number.

  • io_handler – The IO handler that takes care of logging the results.

  • devices – The JAX devices. Can be None (default) if the evaluation is not run in parallel.

  • is_test_set – Whether the evaluation is run on the test set, i.e., outside a training run. By default, this is False.

Returns:

The mean loss.
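
As a sketch of how the two functions fit together, a test-set evaluation might look as follows. Here evaluation_step comes from make_evaluation_step above, while test_dataset, params, and io_handler are placeholders assumed to have been set up beforehand.

    from mlip.training.evaluation import run_evaluation

    # Placeholders assumed to exist: `evaluation_step` (from make_evaluation_step),
    # `test_dataset` (a GraphDataset), `params`, and `io_handler` (a TrainingIOHandler).
    mean_loss = run_evaluation(
        evaluation_step=evaluation_step,
        eval_dataset=test_dataset,
        params=params,
        epoch_number=0,
        io_handler=io_handler,
        devices=None,      # not parallelized across devices
        is_test_set=True,  # evaluating on a held-out test set, outside a training run
    )
    print(f"Mean test loss: {mean_loss:.4f}")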