Batching Helpers

mlip.graph.batching_helpers.batch_graphs(graphs: list[Graph]) → Graph

Returns a single batched graph built from a list of graphs.

Adapted from jraph.utils.batch_np.

Parameters:

graphs – The list of graphs to batch.

Returns:

The batched graph.
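Batching in the style of jraph's batch_np mostly comes down to concatenating the per-graph arrays and shifting each graph's edge indices by the number of nodes that come before it. The sketch below shows only that index arithmetic. It uses a hypothetical `ToyGraph` dataclass with plain Python lists, which is not mlip's `Graph` class, and it leaves out node/edge features and globals.

```python
from dataclasses import dataclass


@dataclass
class ToyGraph:
    """Hypothetical minimal graph for illustration; NOT mlip's Graph class."""
    n_node: list[int]     # number of nodes in each sub-graph
    n_edge: list[int]     # number of edges in each sub-graph
    senders: list[int]    # sender node index of each edge
    receivers: list[int]  # receiver node index of each edge


def toy_batch(graphs: list[ToyGraph]) -> ToyGraph:
    """Concatenate graphs, shifting edge indices by the running node count."""
    n_node, n_edge, senders, receivers = [], [], [], []
    offset = 0  # total nodes of all graphs placed so far
    for g in graphs:
        n_node.extend(g.n_node)
        n_edge.extend(g.n_edge)
        senders.extend(s + offset for s in g.senders)
        receivers.extend(r + offset for r in g.receivers)
        offset += sum(g.n_node)
    return ToyGraph(n_node, n_edge, senders, receivers)


a = ToyGraph(n_node=[2], n_edge=[1], senders=[0], receivers=[1])
b = ToyGraph(n_node=[3], n_edge=[2], senders=[0, 1], receivers=[1, 2])
batched = toy_batch([a, b])
print(batched.n_node, batched.senders, batched.receivers)
# [2, 3] [0, 2, 3] [1, 3, 4]
```

The edges of `b` now point at global node indices 2 to 4. This means a single segment-sum over `receivers` can aggregate messages for every graph in the batch in one call.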

mlip.graph.batching_helpers.pad_with_graphs(graph: Graph, n_node: int, n_edge: int, n_graph: int, n_edge_long_range: int | None = None) → Graph

Pads a graph to a fixed size by adding computation-preserving graphs.

Adapted from jraph.utils.pad_with_graphs.

The graph is padded by first adding a padding graph which contains the padding nodes and edges, and then empty graphs without nodes or edges.

The empty graphs and the padding graph do not interfere with the MLIP calculations on the original graph, so they are computation-preserving.

The padding graph requires at least one node and one graph.

This function does not support jax.jit, because the shape of the output is data-dependent.
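One message-passing aggregation step shows why this padding preserves the original computation. The sketch below assumes that every padding edge connects the first padding node to itself, as in jraph's pad_with_graphs. Whether mlip wires its padding edges exactly this way is an assumption. Because no padding edge targets an original node, the aggregated messages of the original nodes stay the same. This is also why the padding graph needs at least one node: the padding edges must attach somewhere. The `segment_sum` helper is a hypothetical pure-Python stand-in for `jax.ops.segment_sum`.

```python
def segment_sum(values, segment_ids, num_segments):
    """Sum values into buckets given by segment_ids (stand-in for jax.ops.segment_sum)."""
    out = [0.0] * num_segments
    for v, s in zip(values, segment_ids):
        out[s] += v
    return out


# Original batched graph: 5 nodes, 3 edges.
senders = [0, 2, 3]
receivers = [1, 3, 4]
node_feat = [1.0, 2.0, 3.0, 4.0, 5.0]

# Assumed padding: 3 extra nodes (indices 5..7) and 3 extra edges,
# every padding edge connecting padding node 5 to itself.
pad_senders = senders + [5, 5, 5]
pad_receivers = receivers + [5, 5, 5]
pad_feat = node_feat + [0.0, 0.0, 0.0]

# Each edge sends its sender's feature to its receiver, then messages are summed.
agg = segment_sum([node_feat[s] for s in senders], receivers, len(node_feat))
pad_agg = segment_sum([pad_feat[s] for s in pad_senders], pad_receivers, len(pad_feat))

assert pad_agg[:5] == agg  # original nodes receive exactly the same messages
```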

Parameters:
  • graph – Graph object to be padded with the padding graph and empty graphs.

  • n_node – The number of nodes in the padded graph.

  • n_edge – The number of edges in the padded graph.

  • n_graph – The number of graphs in the padded graph. The minimum is 2: the original graph contributes at least one graph, and one more is needed for the padding graph.

  • n_edge_long_range – The number of long-range edges in the padded graph. If None, long-range interactions are not padded.

Raises:
  • ValueError – If the passed n_graph is smaller than 2.

  • RuntimeError – If the given graph is too large for the given padding.

Returns:

The padded graph.
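The size bookkeeping can be sketched using only the per-graph count lists. The hypothetical `padding_plan` helper below adds one padding graph that takes all the padding nodes and edges. It then adds empty graphs until the count reaches `n_graph`, and it raises the same exception types documented above. The exact error condition is an assumption modeled on jraph.utils.pad_with_graphs: at least one padding node, zero or more padding edges, and at least one free graph slot. Long-range edges are not covered.

```python
def padding_plan(n_node_per_graph, n_edge_per_graph, n_node, n_edge, n_graph):
    """Hypothetical helper: per-graph node/edge counts after padding."""
    if n_graph < 2:
        raise ValueError(f"n_graph is {n_graph}, but must be at least 2.")
    pad_nodes = n_node - sum(n_node_per_graph)
    pad_edges = n_edge - sum(n_edge_per_graph)
    pad_graphs = n_graph - len(n_node_per_graph)
    # The padding graph needs at least one node and one graph slot;
    # it may hold zero padding edges.
    if pad_nodes < 1 or pad_edges < 0 or pad_graphs < 1:
        raise RuntimeError("Graph is too large for the given padding.")
    # One padding graph absorbs all padding nodes/edges; the rest are empty.
    empty = [0] * (pad_graphs - 1)
    padded_n_node = list(n_node_per_graph) + [pad_nodes] + empty
    padded_n_edge = list(n_edge_per_graph) + [pad_edges] + empty
    return padded_n_node, padded_n_edge


n_nodes, n_edges = padding_plan([2, 3], [1, 2], n_node=8, n_edge=6, n_graph=4)
print(n_nodes, n_edges)  # [2, 3, 3, 0] [1, 2, 3, 0]
```

Fixing the totals `n_node`, `n_edge`, and `n_graph` across batches means every padded batch has the same array shapes. A jit-compiled model can then reuse one compilation instead of retracing for each batch size.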

mlip.graph.batching_helpers.homogenize_graph_fields(graphs: list[Graph]) → list[Graph]

Fill missing Prediction-targeted fields with NaN so graphs from heterogeneous datasets share the same pytree structure.

NaN is used as a sentinel so loss functions can detect and mask out samples whose dataset did not provide the field. Only fields that live in Prediction are homogenized here.

Parameters:

graphs – List of graphs that may have heterogeneous optional fields.

Returns:

A new list of graphs with uniform pytree structure.
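Below is a minimal sketch of the NaN-sentinel idea, with plain dicts standing in for graphs. Several details are assumptions made for illustration. The field names `energy` and `dipole` are hypothetical, and so is the rule "fill any Prediction field that at least one sample provides". The real set of homogenized fields is whatever mlip's `Prediction` defines. The hypothetical `masked_mse` shows how a loss can skip samples whose target is NaN.

```python
import math

# Hypothetical Prediction-targeted field names, for illustration only.
PREDICTION_FIELDS = ("energy", "dipole")


def homogenize(samples: list[dict]) -> list[dict]:
    """Give every sample the same Prediction fields, filling gaps with NaN."""
    # Assumption: fill any Prediction field that at least one sample provides.
    present = [k for k in PREDICTION_FIELDS if any(k in s for s in samples)]
    return [{**s, **{k: math.nan for k in present if k not in s}} for s in samples]


def masked_mse(preds: list[float], targets: list[float]) -> float:
    """Mean squared error over samples whose target is not the NaN sentinel."""
    pairs = [(p, t) for p, t in zip(preds, targets) if not math.isnan(t)]
    if not pairs:
        return 0.0
    return sum((p - t) ** 2 for p, t in pairs) / len(pairs)


samples = [{"energy": -1.0, "dipole": 0.5}, {"energy": -2.0}]  # 2nd dataset lacks dipole
uniform = homogenize(samples)
dipole_targets = [s["dipole"] for s in uniform]  # [0.5, nan]
loss = masked_mse([0.7, 0.1], dipole_targets)    # only the first sample contributes
```

This sketch filters at the Python level. Inside `jax.jit`, where filtering on values would change array shapes, the same masking is typically written with `jnp.isnan` and `jnp.where` so that shapes stay static.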