Batching Helpers
- mlip.graph.batching_helpers.batch_graphs(graphs: list[Graph]) → Graph
Returns a single batched graph built from a list of graphs.
Adapted from jraph.utils.batch_np.
- Parameters:
graphs – The list of graphs to batch.
- Returns:
The batched graph.
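The idea behind batching can be sketched with plain Python lists. This is an illustration only, not the mlip implementation: `ToyGraph` and `batch_toy_graphs` are hypothetical names, and the real `Graph` class carries more fields. The sketch assumes the usual jraph-style convention in which per-node and per-edge data are concatenated and edge indices are shifted by the number of nodes that precede each graph.

```python
from dataclasses import dataclass


@dataclass
class ToyGraph:
    """Hypothetical stand-in for a graph; not mlip's Graph class."""

    nodes: list[float]    # one feature per node
    senders: list[int]    # source node index of each edge
    receivers: list[int]  # target node index of each edge
    n_node: list[int]     # node count of each sub-graph
    n_edge: list[int]     # edge count of each sub-graph


def batch_toy_graphs(graphs: list[ToyGraph]) -> ToyGraph:
    """Concatenate graphs, shifting edge indices by the running node offset."""
    nodes, senders, receivers, n_node, n_edge = [], [], [], [], []
    offset = 0
    for g in graphs:
        nodes.extend(g.nodes)
        # Edge indices are local to each graph, so shift them into the
        # index space of the concatenated node list.
        senders.extend(s + offset for s in g.senders)
        receivers.extend(r + offset for r in g.receivers)
        n_node.extend(g.n_node)
        n_edge.extend(g.n_edge)
        offset += sum(g.n_node)
    return ToyGraph(nodes, senders, receivers, n_node, n_edge)


a = ToyGraph(nodes=[1.0, 2.0], senders=[0], receivers=[1], n_node=[2], n_edge=[1])
b = ToyGraph(
    nodes=[3.0, 4.0, 5.0], senders=[0, 2], receivers=[1, 0], n_node=[3], n_edge=[2]
)
batched = batch_toy_graphs([a, b])
print(batched.senders, batched.receivers)  # [0, 2, 4] [1, 3, 2]
```

The `n_node` and `n_edge` lists keep one entry per original graph, so the batch can later be split back into its constituent graphs.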
- mlip.graph.batching_helpers.pad_with_graphs(graph: Graph, n_node: int, n_edge: int, n_graph: int, n_edge_long_range: int | None = None) → Graph
Pads a graph to the requested size by adding computation-preserving graphs.
Adapted from jraph.utils.pad_with_graphs.
The graph is padded by first adding a padding graph, which contains the padding nodes and edges, and then adding empty graphs without nodes or edges.
The empty graphs and the padding graph do not interfere with the MLIP calculations on the original graph, and so are computation preserving.
The padding graph requires at least one node and one graph.
This function does not support jax.jit, because the shape of the output is data-dependent.
- Parameters:
graph – Graph object to be padded with a padding graph and empty graphs.
n_node – The number of nodes in the padded graph.
n_edge – The number of edges in the padded graph.
n_graph – The number of graphs in the padded graph. The minimum value is 2, because the input always contains at least one graph and one additional padding graph is required.
n_edge_long_range – The number of long-range edges in the padded graph. If None, long-range edges are not padded.
- Raises:
ValueError – If the passed n_graph is smaller than 2.
RuntimeError – If the given graph is too large for the given padding.
- Returns:
The padded graph.
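The size bookkeeping described above can be sketched as follows. This is a minimal illustration, not the library code: `padded_counts` is a hypothetical helper, and the exact conditions under which mlip raises are assumed from the description (the padding graph needs at least one node and one graph slot). It only computes the per-graph node and edge counts after padding and ignores long-range edges.

```python
def padded_counts(
    n_node: list[int],
    n_edge: list[int],
    target_n_node: int,
    target_n_edge: int,
    target_n_graph: int,
) -> tuple[list[int], list[int]]:
    """Return per-graph node and edge counts after padding (hypothetical helper)."""
    if target_n_graph < 2:
        raise ValueError("n_graph must be at least 2: one real graph plus one padding graph.")

    pad_n_node = target_n_node - sum(n_node)
    pad_n_edge = target_n_edge - sum(n_edge)
    pad_n_graph = target_n_graph - len(n_node)

    # Assumed check: the padding graph needs at least one node and one graph slot.
    if pad_n_node <= 0 or pad_n_edge < 0 or pad_n_graph <= 0:
        raise RuntimeError("Graph is too large for the requested padding.")

    # The first extra graph absorbs all padding nodes and edges;
    # any remaining graph slots are filled with empty graphs.
    empty = [0] * (pad_n_graph - 1)
    return n_node + [pad_n_node] + empty, n_edge + [pad_n_edge] + empty


# Two real graphs (2 and 3 nodes) padded to 8 nodes, 6 edges, and 4 graphs.
new_n_node, new_n_edge = padded_counts([2, 3], [1, 2], 8, 6, 4)
print(new_n_node, new_n_edge)  # [2, 3, 3, 0] [1, 2, 3, 0]
```

Padding every batch to the same `n_node`, `n_edge`, and `n_graph` keeps array shapes constant across batches, which is what allows downstream jitted functions to avoid recompilation.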
- mlip.graph.batching_helpers.homogenize_graph_fields(graphs: list[Graph]) → list[Graph]
Fill missing Prediction-targeted fields with NaN so graphs from heterogeneous datasets share the same pytree structure.
NaN is used as a sentinel so loss functions can detect and mask out samples whose dataset did not provide the field. Only fields that live in Prediction are homogenized here.
- Parameters:
graphs – List of graphs that may have heterogeneous optional fields.
- Returns:
A new list of graphs with uniform pytree structure.
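The NaN-sentinel pattern can be illustrated with a small, dependency-free sketch. Everything here is hypothetical: plain dictionaries stand in for the Prediction fields, the field names `energy` and `dipole` are made up for the example, and `masked_mse` is only one possible way a loss might skip NaN targets.

```python
import math


def homogenize_toy_targets(samples: list[dict[str, float]]) -> list[dict[str, float]]:
    """Give every sample the same keys, filling absent ones with NaN."""
    all_keys = sorted({key for sample in samples for key in sample})
    return [{key: sample.get(key, math.nan) for key in all_keys} for sample in samples]


def masked_mse(predictions: list[float], targets: list[float]) -> float:
    """Mean squared error over the entries whose target is not NaN."""
    errors = [(p - t) ** 2 for p, t in zip(predictions, targets) if not math.isnan(t)]
    return sum(errors) / len(errors) if errors else 0.0


# The second sample comes from a dataset without a "dipole" label.
samples = [{"energy": -1.0, "dipole": 0.5}, {"energy": -2.0}]
uniform = homogenize_toy_targets(samples)

dipole_targets = [s["dipole"] for s in uniform]  # [0.5, nan]
loss = masked_mse([0.0, 9.9], dipole_targets)
print(loss)  # 0.25
```

Only the first sample contributes to the loss; the prediction for the second sample is ignored because its target is NaN.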