Graph Dataset
- class mlip.data.graph_dataset.GraphDatasetState(rng: PRNGKey = <factory>, num_graphs_processed: Array = <factory>)
State of a graph dataset, checkpointable through the wrapper chain.
- rng
Random key for shuffling. It stays at the pre-split value during an epoch so that a mid-epoch checkpoint can replay the same jax.random.split call and reproduce the shuffle permutation. It is advanced to the post-split value at the end of the epoch.
- Type:
jax._src.random.PRNGKey
- num_graphs_processed
Number of graphs yielded so far in the current epoch. Used to resume iteration from the correct position after a checkpoint.
- Type:
jax.Array
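The replay scheme described for rng and num_graphs_processed can be sketched in plain Python. This is a minimal sketch under stated assumptions, not mlip's code: the hypothetical `split` helper built on `random.Random` stands in for `jax.random.split`, and the hypothetical `EpochState` stands in for GraphDatasetState. Because the key is only advanced at the end of an epoch, a state checkpointed mid-epoch rebuilds the same permutation and skips the graphs already yielded.

```python
from __future__ import annotations

import random
from dataclasses import dataclass, replace


def split(key: int) -> tuple[int, int]:
    """Hypothetical stand-in for jax.random.split: derive (next_key, subkey)."""
    gen = random.Random(key)
    return gen.getrandbits(32), gen.getrandbits(32)


def permutation(subkey: int, n: int) -> list[int]:
    """Deterministic shuffle of range(n) driven only by subkey."""
    order = list(range(n))
    random.Random(subkey).shuffle(order)
    return order


@dataclass(frozen=True)
class EpochState:
    rng: int                       # pre-split key, unchanged during an epoch
    num_graphs_processed: int = 0  # graphs yielded so far in this epoch


def run_epoch(graphs: list, state: EpochState, max_steps: int | None = None):
    """Process graphs starting at state.num_graphs_processed.

    Stops after max_steps graphs to simulate an interruption and returns
    (graphs produced in this call, state to checkpoint).
    """
    next_key, subkey = split(state.rng)  # identical split on every replay
    order = permutation(subkey, len(graphs))
    out = []
    pos = state.num_graphs_processed
    while pos < len(graphs):
        if max_steps is not None and len(out) == max_steps:
            # Mid-epoch checkpoint: keep the pre-split key, record progress.
            return out, replace(state, num_graphs_processed=pos)
        out.append(graphs[order[pos]])
        pos += 1
    # End of epoch: advance to the post-split key and reset the counter.
    return out, EpochState(rng=next_key, num_graphs_processed=0)
```

Running one epoch in full, or interrupting it after four graphs and resuming from the checkpointed state, yields the same sequence of graphs and the same end-of-epoch state.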
- replace(**updates)
Returns a new object replacing the specified fields with new values.
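Since replace returns a new object, a previously saved state is not affected by later updates. A minimal sketch of these semantics, assuming a frozen dataclass and the standard-library `dataclasses.replace` as a stand-in for the method documented here; `DemoState` is hypothetical:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class DemoState:
    # Hypothetical stand-in for GraphDatasetState with plain Python fields.
    rng: int = 0
    num_graphs_processed: int = 0


saved = DemoState(rng=42)
# replace() builds a new object; the saved state is left untouched.
advanced = replace(saved, num_graphs_processed=8)
```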
- class mlip.data.graph_dataset.GraphDataset(graphs: list[Graph], batch_size: int, max_n_node: int, max_n_edge: int, min_n_node: int = 1, min_n_edge: int = 1, min_n_graph: int = 1, max_n_edge_long_range: int | None = None, shuffle: bool = True, shuffle_between_epochs: bool = True, skip_last_batch: bool = False, raise_exc_if_graphs_discarded: bool = False, graph_postprocessing: list[Callable[[Graph], Graph]] | None = None, seed: int = 0, homogenize: bool = False)
Class for holding a dataset of graphs, i.e., Graph objects, and managing batching.
- __init__(graphs: list[Graph], batch_size: int, max_n_node: int, max_n_edge: int, min_n_node: int = 1, min_n_edge: int = 1, min_n_graph: int = 1, max_n_edge_long_range: int | None = None, shuffle: bool = True, shuffle_between_epochs: bool = True, skip_last_batch: bool = False, raise_exc_if_graphs_discarded: bool = False, graph_postprocessing: list[Callable[[Graph], Graph]] | None = None, seed: int = 0, homogenize: bool = False)
Constructor.
- Parameters:
graphs – The graphs to store and manage in this class.
batch_size – The batch size.
max_n_node – The maximum number of nodes contributed by one graph in a batch.
max_n_edge – The maximum number of edges contributed by one graph in a batch.
min_n_node – The minimum number of nodes in a batch, defaults to 1.
min_n_edge – The minimum number of edges in a batch, defaults to 1.
min_n_graph – The minimum number of graphs in a batch, defaults to 1.
max_n_edge_long_range – The maximum number of long-range edges contributed by one graph in a batch. If None, long-range interactions are not considered.
shuffle – Whether to shuffle the graphs before iterating, defaults to True.
shuffle_between_epochs – If True, reshuffle the data between epochs; this only takes effect if shuffle is also True. Defaults to True.
skip_last_batch – Whether to skip the last batch, defaults to False.
raise_exc_if_graphs_discarded – Whether to raise an exception if some graphs must be discarded due to size constraints. Defaults to False, in which case only a warning is logged.
graph_postprocessing – Optional list of functions applied to each batched graph when iterating. Defaults to None.
seed – The random seed used for shuffling. Defaults to 0.
homogenize – If True, pad missing prediction-targeted optional fields (e.g., stress, forces) with NaN so that graphs from heterogeneous datasets share the same pytree structure and can be batched together. If False, the dataset instead validates that the provided graphs are already batch-compatible and raises a clear error otherwise. Defaults to False.
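The homogenize option can be illustrated with graphs modeled as plain dictionaries. This is a hypothetical sketch, not mlip's implementation: the field names `forces` and `stress`, the per-node `(n_node, 3)` shape for forces, the `3 x 3` shape for stress, and the `homogenize_graphs` helper are all assumptions, with nested lists standing in for arrays. With padding enabled, every graph ends up with the same keys; without it, the same key comparison raises an error instead.

```python
import math

# Assumed optional prediction targets, each mapped to a builder for a
# NaN placeholder of the (assumed) right shape.
NAN_PLACEHOLDERS = {
    "forces": lambda g: [[math.nan] * 3 for _ in range(g["n_node"])],  # per node
    "stress": lambda g: [[math.nan] * 3 for _ in range(3)],            # 3 x 3
}


def homogenize_graphs(graphs: list, homogenize: bool) -> list:
    # Optional fields present in at least one graph.
    present = {key for g in graphs for key in NAN_PLACEHOLDERS if key in g}
    if not homogenize:
        # Validate only: every graph must already carry the same optional fields.
        for idx, g in enumerate(graphs):
            missing = sorted(present - set(g))
            if missing:
                raise ValueError(
                    f"graph {idx} lacks {missing}; graphs are not batch-compatible"
                )
        return graphs
    padded = []
    for g in graphs:
        new_g = dict(g)  # copy so the input graphs are not mutated
        for key in present - set(g):
            new_g[key] = NAN_PLACEHOLDERS[key](g)
        padded.append(new_g)
    return padded
```

For a graph that has forces but no stress mixed with one that has stress but no forces, homogenize=True fills the gaps with NaN, while homogenize=False raises ValueError.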
- __iter__()
- __len__()
Returns the number of batches without recomputing the batches on each call.
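One way to compute a batch count once and reuse it is sketched below with `functools.cached_property`. It is a hypothetical illustration: `BatchCounter` is not part of mlip, and the formula assumes batch_size counts graphs per batch with a possibly smaller final batch, whereas the real batching also depends on the node and edge limits.

```python
import math
from functools import cached_property


class BatchCounter:
    """Hypothetical sketch: the number of batches is computed once and cached."""

    def __init__(self, num_graphs: int, batch_size: int):
        self.num_graphs = num_graphs
        self.batch_size = batch_size
        self.computations = 0  # how often the batch count was actually built

    @cached_property
    def _num_batches(self) -> int:
        self.computations += 1
        # Assumption: batch_size graphs per batch; the last batch may be smaller.
        return math.ceil(self.num_graphs / self.batch_size)

    def __len__(self) -> int:
        return self._num_batches
```

Calling len() repeatedly on the same instance returns the cached value; computations stays at 1.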
- subset(i: slice | int | list | float) → GraphDataset
Constructs and returns a new graph dataset containing a subset of the graphs of the current one, selected according to the slicing information i.
- Parameters:
i – The slicing information. See source code for options.
- Returns:
A new graph dataset containing only a subset of the graphs of the current one.
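The accepted types of i are listed in the signature, but their exact meaning is defined in the source code. The sketch below shows only one plausible interpretation, as a hypothetical `select_subset` helper on a plain list: a slice or a list of indices selects those graphs, an int keeps the first i graphs, and a float in (0, 1] keeps that fraction from the front. These rules are assumptions; check the mlip source before relying on any of them.

```python
from __future__ import annotations

import math


def select_subset(items: list, i: slice | int | list | float) -> list:
    """Hypothetical interpretation of the slicing information i."""
    if isinstance(i, slice):
        return items[i]                  # assumed: ordinary slicing
    if isinstance(i, list):
        return [items[j] for j in i]     # assumed: explicit indices
    if isinstance(i, float):
        if not 0.0 < i <= 1.0:
            raise ValueError("fraction must be in (0, 1]")
        return items[: math.floor(i * len(items))]  # assumed: leading fraction
    if isinstance(i, int):
        return items[:i]                 # assumed: first i items
    raise TypeError(f"unsupported slicing information: {type(i).__name__}")
```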