Batched inference
- mlip.inference.batched_inference.run_batched_inference(structures: list[Atoms] | GraphDataset | PrefetchIterator | CombinedGraphDataset, force_field: ForceField, batch_size: int = 16, max_n_node: int | None = None, max_n_edge: int | None = None, set_none_charges_to_zero: bool = True) → list[Prediction]
Runs batched inference on the given structures.
Computes energies and forces and, if the given force field supports it, stress tensors. The results are returned as a list of Prediction objects, one for each input structure.
Note: When using batch_size=1, we recommend setting max_n_node and max_n_edge explicitly to avoid edge cases in the automatic computation of these parameters, which may cause errors.
- Parameters:
structures – The list of ase.Atoms to iterate over and compute predictions for. Optionally, an already processed GraphDataset or PrefetchIterator object may be passed.
force_field – The force field object to compute the predictions with.
batch_size – The batch size. Default is 16. Ignored if structures are passed as a GraphDataset or PrefetchIterator.
max_n_node – This value is multiplied by the batch size to determine the maximum number of nodes allowed in a batch. Note that a batch always contains max_n_node * batch_size nodes, as the remaining slots are filled with dummy nodes. The default is None, which means an optimal value is computed automatically for the dataset. Ignored if structures are passed as a GraphDataset or PrefetchIterator.
max_n_edge – This value is multiplied by the batch size to determine the maximum number of edges allowed in a batch. Note that a batch always contains max_n_edge * batch_size edges, as the remaining slots are filled with dummy edges. The default is None, which means an optimal value is computed automatically for the dataset. Ignored if structures are passed as a GraphDataset or PrefetchIterator.
set_none_charges_to_zero – Whether to set None total charges to zero during preprocessing. Default is True. Ignored if structures are passed as a GraphDataset or PrefetchIterator.
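To make the node and edge budgets concrete, here is a minimal pure-Python sketch of the multiplication rule described for max_n_node and max_n_edge, plus one way to derive explicit per-graph values (for example when using batch_size=1). It does not reproduce mlip's automatic computation. Structures are modelled as plain lists of positions (a hypothetical stand-in for ase.Atoms), edges are assumed to be directed neighbour pairs within a radial cutoff with no periodic images, and the helpers `count_directed_edges`, `padding_budget` and `padded_sizes` are illustrative names, not part of mlip.

```python
import math
from itertools import permutations

# Hypothetical helpers (not part of mlip). A structure is modelled as a
# plain list of Cartesian positions, standing in for an ase.Atoms object.


def count_directed_edges(positions, cutoff):
    """Count ordered pairs (i, j), i != j, whose distance is below `cutoff`.

    Assumptions: no periodic images, and each neighbour pair yields two
    directed edges (i -> j and j -> i).
    """
    return sum(1 for a, b in permutations(positions, 2) if math.dist(a, b) < cutoff)


def padding_budget(structures, cutoff):
    """Per-graph (max_n_node, max_n_edge) lower bounds over a dataset."""
    max_n_node = max(len(pos) for pos in structures)
    max_n_edge = max(count_directed_edges(pos, cutoff) for pos in structures)
    return max_n_node, max_n_edge


def padded_sizes(batch_size, max_n_node, max_n_edge):
    """Total node/edge slots per batch: max_n_* multiplied by batch_size."""
    return batch_size * max_n_node, batch_size * max_n_edge


# Toy data: a water molecule and an H2 molecule (positions in Angstrom).
water = [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)]
h2 = [(0.0, 0.0, 0.0), (0.74, 0.0, 0.0)]

n_node, n_edge = padding_budget([water, h2], cutoff=5.0)   # (3, 6)
node_slots, edge_slots = padded_sizes(16, n_node, n_edge)  # (48, 96)

# A batch holding only these two molecules still has 48 node slots,
# of which 48 - (3 + 2) = 43 are dummy nodes.
dummy_nodes = node_slots - (len(water) + len(h2))
```

Values obtained this way could be passed as max_n_node and max_n_edge, but treat them as lower bounds: the library may reserve extra room (for example for padding), and its exact edge-counting convention is not stated here.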
- Returns:
A list of predictions, one for each structure. Each of these dataclasses holds a float for the energy and a NumPy array of forces with shape (num_atoms, 3). Optionally, it also contains a stress array of shape (3, 3) and a partial charge array of shape (num_atoms,).
- Raises:
ValueError – if any of the input systems has only one atom.
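Because single-atom systems raise a ValueError and each structure yields its own prediction, a caller may want to set such systems aside first and map the results back afterwards. The following is a minimal sketch assuming predictions come back in input order. `PredictionStub`, its field names (`energy`, `forces`, `stress`), `split_by_size` and `summarize` are hypothetical and only mirror the fields described under Returns; check the actual Prediction class before relying on them. The predictions are faked so the sketch runs without mlip.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class PredictionStub:
    """Hypothetical stand-in for mlip's Prediction; field names are assumed."""
    energy: float
    forces: np.ndarray                   # shape (num_atoms, 3)
    stress: Optional[np.ndarray] = None  # shape (3, 3), only if supported


def split_by_size(structures):
    """Return indices of multi-atom structures and of single-atom ones.

    Single-atom systems would make run_batched_inference raise ValueError,
    so they are set aside before the call.
    """
    keep = [i for i, s in enumerate(structures) if len(s) > 1]
    skip = [i for i, s in enumerate(structures) if len(s) <= 1]
    return keep, skip


def summarize(predictions):
    """Stack energies and the largest per-atom force norm of each structure."""
    energies = np.array([p.energy for p in predictions])
    max_forces = np.array(
        [np.linalg.norm(p.forces, axis=1).max() for p in predictions]
    )
    return energies, max_forces


# Toy inputs: lists of positions stand in for ase.Atoms (len() = atom count).
structures = [
    [(0.0, 0.0, 0.0), (0.74, 0.0, 0.0)],
    [(0.0, 0.0, 0.0)],  # single atom: would trigger ValueError
    [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)],
]
keep, skip = split_by_size(structures)  # keep == [0, 2], skip == [1]

# In real use this list would come from something like
# run_batched_inference([structures[i] for i in keep], force_field).
predictions = [
    PredictionStub(energy=-1.0, forces=np.zeros((2, 3))),
    PredictionStub(
        energy=-2.5,
        forces=np.array([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]),
    ),
]
by_index = dict(zip(keep, predictions))  # map results back to input positions
energies, max_forces = summarize(predictions)
```

Filtering up front keeps a single unusual system from aborting the whole run, and the kept index list is what lets each result be attached back to its original structure.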