Model evaluation¶
- mlip.training.evaluation.make_evaluation_step(predictor: ForceFieldPredictor, eval_loss_fun: Callable[[Prediction, GraphsTuple, int, bool], tuple[Array, dict[str, Array]]], should_parallelize: bool = True) → Callable[[dict[str, dict[str, Array | dict]], GraphsTuple, int], dict[str, ndarray]] ¶
Creates the evaluation step function.
- Parameters:
predictor – The predictor to use.
eval_loss_fun – The loss function for the evaluation.
should_parallelize – Whether to apply data parallelization across multiple devices.
- Returns:
The evaluation step function.
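
A minimal usage sketch follows. The objects `predictor`, `eval_loss_fun`, `params`, and `graph_batch` are placeholders assumed to come from an earlier training or data-loading setup; they are not defined by this module.

```python
from mlip.training.evaluation import make_evaluation_step

# Placeholder objects assumed to exist from earlier setup (not defined here):
# `predictor` (ForceFieldPredictor), `eval_loss_fun` (evaluation loss callable),
# `params` (model parameters), `graph_batch` (a batched GraphsTuple).
eval_step = make_evaluation_step(
    predictor=predictor,
    eval_loss_fun=eval_loss_fun,
    should_parallelize=False,  # single-device evaluation
)

# The returned callable maps (params, batch, epoch_number) to a dict of
# per-batch evaluation metrics as numpy arrays.
metrics = eval_step(params, graph_batch, 0)
```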
- mlip.training.evaluation.run_evaluation(evaluation_step: Callable[[dict[str, dict[str, Array | dict]], GraphsTuple, int], dict[str, ndarray]], eval_dataset: GraphDataset | PrefetchIterator, params: dict[str, dict[str, Array | dict]], epoch_number: int, io_handler: TrainingIOHandler, devices: list[Device] | None = None, is_test_set: bool = False) → float ¶
Runs a model evaluation on a given dataset.
- Parameters:
evaluation_step – The evaluation step function.
eval_dataset – The dataset on which to evaluate the model.
params – The parameters to use for the evaluation.
epoch_number – The current epoch number.
io_handler – The IO handler that handles logging of the evaluation results.
devices – The JAX devices to use. Can be None (the default) if the evaluation is not run in parallel.
is_test_set – Whether the evaluation is done on the test set, i.e., not during a training run. Defaults to False.
- Returns:
The mean loss.
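
As an illustration, a hedged sketch of evaluating trained parameters on a test set. Here `eval_step`, `test_dataset`, `params`, and `io_handler` are assumed to come from an existing training setup and are not defined by this module; `jax.devices()` is only relevant when the step was created with `should_parallelize=True`.

```python
import jax

from mlip.training.evaluation import run_evaluation

# Placeholder objects assumed to exist from earlier setup:
# `eval_step` (from make_evaluation_step), `test_dataset` (GraphDataset),
# `params` (trained parameters), `io_handler` (TrainingIOHandler).
mean_test_loss = run_evaluation(
    evaluation_step=eval_step,
    eval_dataset=test_dataset,
    params=params,
    epoch_number=0,
    io_handler=io_handler,
    devices=jax.devices(),  # pass None instead for single-device evaluation
    is_test_set=True,
)
print(f"Mean test loss: {mean_test_loss:.4f}")
```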