Model evaluation
- mlip.training.evaluation.make_evaluation_step(predictor: ForceFieldPredictor, eval_loss_fun: Callable[[Graph, Graph, int, bool], tuple[Array, dict[str, Array]]], avg_n_graphs_per_batch: float, should_parallelize: bool = False, in_shardings: NamedSharding | tuple[NamedSharding, ...] | None = None, out_shardings: NamedSharding | tuple[NamedSharding, ...] | None = None) → Callable[[dict[str, dict[str, Array | dict]], Graph, int], dict[str, ndarray]]
Creates the evaluation step function.
- Parameters:
predictor – The predictor to use.
eval_loss_fun – The loss function for the evaluation.
avg_n_graphs_per_batch – Average number of graphs per batch used for reweighting of metrics.
in_shardings – Optional in_shardings for jax.jit.
out_shardings – Optional out_shardings for jax.jit.
- Returns:
The evaluation step function.
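The description of avg_n_graphs_per_batch mentions reweighting of metrics but does not give the formula. The sketch below shows one common scheme, offered as an assumption rather than as the library's actual implementation: each per-batch mean is scaled by the ratio of that batch's graph count to the average graph count, so that a plain mean over batches equals the mean over all graphs. The helper name reweight_batch_metric and the toy numbers are hypothetical.

```python
def reweight_batch_metric(
    batch_mean: float, n_graphs_in_batch: int, avg_n_graphs_per_batch: float
) -> float:
    """Scale a per-batch mean by the batch's share of graphs.

    Hypothetical helper: this is an assumed reweighting scheme, not code
    taken from mlip.
    """
    return batch_mean * n_graphs_in_batch / avg_n_graphs_per_batch


# Toy data: (number of graphs in the batch, mean metric over that batch).
batches = [(2, 1.0), (6, 3.0)]

avg_n_graphs = sum(n for n, _ in batches) / len(batches)

reweighted = [reweight_batch_metric(m, n, avg_n_graphs) for n, m in batches]

# A plain average over the reweighted batch values ...
mean_of_reweighted = sum(reweighted) / len(reweighted)

# ... matches the average taken over every graph individually.
per_graph_mean = sum(n * m for n, m in batches) / sum(n for n, _ in batches)
```

Without such a correction, a small final batch would count as much as a full one when batch metrics are averaged.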
- mlip.training.evaluation.run_evaluation(evaluation_step: Callable[[dict[str, dict[str, Array | dict]], Graph, int], dict[str, ndarray]], eval_dataset: GraphDataset | PrefetchIterator | CombinedGraphDataset, params: dict[str, dict[str, Array | dict]], epoch_number: int, io_handler: TrainingIOHandler, is_test_set: bool = False, subset_name: str | None = None) → float
Runs a model evaluation on a given dataset.
- Parameters:
evaluation_step – The evaluation step function.
eval_dataset – The dataset on which to evaluate the model.
params – The parameters to use for the evaluation.
epoch_number – The current epoch number.
io_handler – The IO handler class that handles the logging of the result.
is_test_set – Whether the evaluation is done on the test set, i.e., not during a training run. Defaults to False.
subset_name – Subset name for that dataset. If given, then the metric names will be prefixed by it.
- Returns:
The mean loss.
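To make the control flow concrete, the following is a minimal pure-Python sketch of an evaluation loop with the same overall shape: call the step function on every batch, average each metric, optionally prefix the metric names, and return the mean loss. It is not the library's implementation. The function evaluate_sketch, the metric key "loss", the underscore separator for the prefix, and the toy step function are all assumptions made for illustration, and logging through an IO handler is replaced by returning the metrics dictionary.

```python
from collections import defaultdict
from statistics import fmean
from typing import Callable, Iterable, Optional


def evaluate_sketch(
    evaluation_step: Callable[[dict, float, int], dict],
    batches: Iterable[float],
    params: dict,
    epoch_number: int,
    subset_name: Optional[str] = None,
) -> tuple[float, dict]:
    """Average per-batch metrics and return (mean loss, metrics to log)."""
    collected = defaultdict(list)
    for batch in batches:
        # One call per batch; each call returns a dict of metric values.
        for name, value in evaluation_step(params, batch, epoch_number).items():
            collected[name].append(float(value))
    if not collected:
        raise ValueError("The evaluation dataset produced no batches.")

    means = {name: fmean(values) for name, values in collected.items()}
    mean_loss = means["loss"]  # Assumed key name for the loss metric.

    # Assumed prefix format: "<subset_name>_<metric_name>".
    prefix = f"{subset_name}_" if subset_name else ""
    to_log = {prefix + name: value for name, value in means.items()}
    return mean_loss, to_log


def toy_step(params: dict, batch: float, epoch_number: int) -> dict:
    """Hypothetical stand-in for a real evaluation step."""
    return {"loss": params["scale"] * batch, "mae": batch / 2}


mean_loss, to_log = evaluate_sketch(
    toy_step, [1.0, 3.0], {"scale": 2.0}, epoch_number=0, subset_name="qm9"
)
```

Here the batches are plain floats for brevity; in real use they would be graph batches drawn from the evaluation dataset.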