Batched inference

mlip.inference.batched_inference.run_batched_inference(structures: list[Atoms] | GraphDataset | PrefetchIterator | CombinedGraphDataset, force_field: ForceField, batch_size: int = 16, max_n_node: int | None = None, max_n_edge: int | None = None, set_none_charges_to_zero: bool = True) → list[Prediction]

Runs batched inference on the given structures.

Computes energies, forces and, if the given force field supports it, stress tensors. The results are returned as a list of Prediction objects, one for each input structure.

Note: When using batch_size=1, we recommend setting max_n_node and max_n_edge explicitly, because edge cases in their automatic computation may otherwise cause errors.
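For batch_size=1, one way to pick explicit values is to take the largest node and edge counts found in the dataset. The sketch below assumes a simple radius graph: one directed edge for every ordered pair of distinct atoms closer than a cutoff, with no periodic boundary conditions. The helpers count_edges and estimate_graph_limits, the cutoff and the headroom margin are hypothetical and not part of mlip; mlip's own graph construction (for example for periodic cells) may differ, so treat the result as a starting point only.

```python
import numpy as np


def count_edges(positions: np.ndarray, cutoff: float) -> int:
    """Count directed edges (i, j), i != j, with |r_i - r_j| < cutoff.

    Assumption: no periodic boundary conditions (brute-force O(N^2) distances).
    """
    diff = positions[:, None, :] - positions[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    # Exclude self-pairs on the diagonal, which always have distance zero.
    mask = (dist < cutoff) & ~np.eye(len(positions), dtype=bool)
    return int(mask.sum())


def estimate_graph_limits(
    all_positions: list[np.ndarray], cutoff: float, headroom: int = 1
) -> tuple[int, int]:
    """Return hypothetical (max_n_node, max_n_edge) values for batch_size=1.

    `headroom` is an arbitrary safety margin added on top of the largest
    node and edge counts observed in the dataset.
    """
    max_nodes = max(len(p) for p in all_positions)
    max_edges = max(count_edges(p, cutoff) for p in all_positions)
    return max_nodes + headroom, max_edges + headroom
```

With ASE, the positions can be collected via atoms.get_positions() for each structure, and the two returned values passed as max_n_node and max_n_edge together with batch_size=1.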

Parameters:
  • structures – The list of ase.Atoms objects to compute predictions for. Alternatively, an already processed GraphDataset, PrefetchIterator or CombinedGraphDataset object may be passed.

  • force_field – The force field object to compute the predictions with.

  • batch_size – The batch size. Default is 16. Ignored if structures are passed as a GraphDataset or PrefetchIterator.

  • max_n_node – Multiplied by the batch size to give the maximum number of nodes allowed in a batch. Every batch always contains exactly max_n_node * batch_size nodes, because any unused slots are filled with dummy nodes. The default, None, means that an optimal value is computed automatically for the dataset. Ignored if structures are passed as a GraphDataset or PrefetchIterator.

  • max_n_edge – Multiplied by the batch size to give the maximum number of edges allowed in a batch. Every batch always contains exactly max_n_edge * batch_size edges, because any unused slots are filled with dummy edges. The default, None, means that an optimal value is computed automatically for the dataset. Ignored if structures are passed as a GraphDataset or PrefetchIterator.

  • set_none_charges_to_zero – Whether to set None total charges to zero during preprocessing. Default is True. Ignored if structures are passed as a GraphDataset or PrefetchIterator.
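The padding rule for max_n_node and max_n_edge can be illustrated with plain arithmetic. The sketch below assumes that structures are grouped into consecutive batches of batch_size and that each batch is padded to exactly max_n_node * batch_size nodes, as stated above. The helper padded_node_counts is hypothetical and only mirrors that description, not mlip's internal batching code; the same arithmetic applies to edges with max_n_edge.

```python
def padded_node_counts(
    n_nodes: list[int], batch_size: int, max_n_node: int
) -> list[tuple[int, int]]:
    """Return (real_nodes, dummy_nodes) for each consecutive batch.

    Assumption: every batch, including a final partial one, is padded to
    exactly max_n_node * batch_size nodes.
    Raises ValueError if the real nodes of a batch exceed that budget.
    """
    budget = max_n_node * batch_size
    result = []
    for start in range(0, len(n_nodes), batch_size):
        real = sum(n_nodes[start : start + batch_size])
        if real > budget:
            raise ValueError(
                f"batch starting at index {start} has {real} nodes, "
                f"more than the budget of {budget}"
            )
        result.append((real, budget - real))
    return result


# Five structures, two per batch, budget of 6 * 2 = 12 nodes per batch.
print(padded_node_counts([3, 5, 4, 2, 6], batch_size=2, max_n_node=6))
```

Here the three batches hold 8, 6 and 6 real nodes and are padded with 4, 6 and 6 dummy nodes respectively, which shows why a max_n_node that is much larger than needed wastes compute on dummy nodes.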

Returns:

A list of predictions, one for each structure. Each Prediction dataclass holds the energy as a float and the forces as a NumPy array of shape (num_atoms, 3). Optionally, it also contains a stress array of shape (3, 3) and a partial charge array of shape (num_atoms,).

Raises:

ValueError – If any of the input structures contains only one atom.
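Because a single one-atom structure makes the whole call fail, it can be convenient to filter such structures out beforehand while remembering where each kept structure came from. The helper split_single_atom below is a hypothetical sketch, not part of mlip; it relies only on len(), which for an ase.Atoms object returns the number of atoms.

```python
from collections.abc import Sequence
from typing import TypeVar

T = TypeVar("T")


def split_single_atom(
    structures: Sequence[T],
) -> tuple[list[T], list[int], list[int]]:
    """Separate structures with more than one atom from the rest.

    Returns (kept, kept_indices, dropped_indices); the indices refer to
    positions in the original sequence. Structures with one atom (or none)
    are dropped.
    """
    kept: list[T] = []
    kept_indices: list[int] = []
    dropped_indices: list[int] = []
    for i, structure in enumerate(structures):
        if len(structure) > 1:
            kept.append(structure)
            kept_indices.append(i)
        else:
            dropped_indices.append(i)
    return kept, kept_indices, dropped_indices
```

The kept list can then be passed as structures, and, assuming the predictions come back in input order, zipping kept_indices with the returned list maps each prediction back to its position in the original input.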