Data processing
Set up graph dataset builder
In order to train a model or run batched inference, one needs to process the data into objects of type GraphDataset. This can be achieved by using the GraphDatasetBuilder class, which can be instantiated from its associated Pydantic config and a chemical systems reader that is derived from the ChemicalSystemsReader base class:
from mlip.data import GraphDatasetBuilder
reader = _get_chemical_systems_reader() # this is a placeholder for the moment
builder_config = GraphDatasetBuilder.Config(
    graph_cutoff_angstrom=5.0,
    max_n_node=None,
    max_n_edge=None,
    batch_size=16,
)
graph_dataset_builder = GraphDatasetBuilder(reader, builder_config)
In the example above, we set some example values for the settings in the GraphDatasetBuilderConfig. For convenience, this config class is also accessible directly as GraphDatasetBuilder.Config. Check out the API reference of the class for the full set of configurable values and for which of them defaults are available.
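To make graph_cutoff_angstrom more concrete: in a cutoff-based atomic graph, two atoms are connected by a directed edge whenever their distance is below the cutoff. The sketch below illustrates that idea in plain Python for a non-periodic system only. It is not the library's graph construction, which for periodic systems also has to account for periodic images, and the function name build_radius_graph is hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Position = Tuple[float, float, float]  # Cartesian coordinates in Angstrom


def build_radius_graph(
    positions: Sequence[Position], cutoff_angstrom: float
) -> Tuple[List[int], List[int]]:
    """Return (senders, receivers) for all ordered atom pairs i != j whose
    distance is strictly below the cutoff. Periodic images are ignored."""
    senders: List[int] = []
    receivers: List[int] = []
    for i, p_i in enumerate(positions):
        for j, p_j in enumerate(positions):
            if i == j:
                continue  # no self-loops
            if math.dist(p_i, p_j) < cutoff_angstrom:
                senders.append(i)
                receivers.append(j)
    return senders, receivers


# Three atoms on a line: atoms 0 and 1 are 1.5 Angstrom apart, atom 2 is 6.0 Angstrom from atom 1.
positions = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (7.5, 0.0, 0.0)]
senders, receivers = build_radius_graph(positions, cutoff_angstrom=5.0)
print(list(zip(senders, receivers)))  # [(0, 1), (1, 0)]
```

At fixed atomic density, the number of neighbours per atom grows roughly with the cube of the cutoff (the volume of the cutoff sphere). For that reason, the edge count, and with it memory use, rises quickly as graph_cutoff_angstrom is increased.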
The chemical systems reader is an instance of a class derived from ChemicalSystemsReader. Its load() member function reads a dataset into lists of ChemicalSystem objects. You can either implement your own derived class to do this for your custom dataset format, or you can employ one of the built-in implementations, for example, the ExtxyzReader for datasets stored in extended XYZ format:
from mlip.data import ExtxyzReader
reader_config = ExtxyzReader.Config(
    train_dataset_paths="...",
    valid_dataset_paths="...",
    test_dataset_paths="...",
)
# If the data is stored locally
reader = ExtxyzReader(reader_config)
# If the data is on remote storage, one can also provide a data download function
reader = ExtxyzReader(reader_config, data_download_fun)
The configuration object used here is the ChemicalSystemsReaderConfig, again accessible via ExtxyzReader.Config to reduce the number of required imports.
In the example above, data_download_fun is a simple function that takes a source and a target path and performs the download operation. Our helper functions for splitting a dataset are documented here.
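The description above only specifies that the download function receives a source and a target path. A minimal stand-in, assuming the "remote" source is reachable as an ordinary file path (for example, a mounted bucket) and that the reader expects the file to exist at the target afterwards, could simply copy the file. The name copy_download_fun is hypothetical, and a real implementation would call your storage client instead.

```python
import os
import shutil


def copy_download_fun(source: str, target: str) -> None:
    """Hypothetical download function that 'downloads' by copying a file.

    Replace the body with a call to your storage client when the data
    lives on genuinely remote storage.
    """
    target_dir = os.path.dirname(target)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)  # make sure the destination directory exists
    shutil.copyfile(source, target)
```

Following the call pattern shown earlier, it would then be passed as ExtxyzReader(reader_config, copy_download_fun).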
If you have multiple datasets in different formats and would like to combine them, you can do so by using the CombinedReader instead:
from mlip.data import CombinedReader
readers = _get_list_of_individual_chemical_system_readers() # placeholder
combined_reader = CombinedReader(readers)
This combined reader can then also be used as the reader input to the GraphDatasetBuilder.
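Implementing your own reader for a custom dataset format comes down to a class whose load() returns lists of systems. The sketch below uses a hypothetical minimal base class (SimpleReaderBase) and a hypothetical record (SimpleSystem) in place of mlip's ChemicalSystemsReader and ChemicalSystem, whose exact abstract methods and fields are given in the API reference and are not reproduced here. It also assumes that load() returns the (train, validation, test) lists. The concat_readers helper illustrates one plausible meaning of "combining" readers, concatenating their splits; it is an assumption, not the CombinedReader implementation.

```python
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class SimpleSystem:
    """Hypothetical stand-in for a ChemicalSystem-like record."""
    atomic_numbers: List[int]
    positions: List[List[float]]  # Angstrom
    energy: float  # eV


Splits = Tuple[List[SimpleSystem], List[SimpleSystem], List[SimpleSystem]]


class SimpleReaderBase(ABC):
    """Hypothetical reader interface: load() returns (train, valid, test)."""

    @abstractmethod
    def load(self) -> Splits:
        ...


class JsonStringReader(SimpleReaderBase):
    """Custom-format reader for {"train": [...], "valid": [...], "test": [...]}."""

    def __init__(self, text: str) -> None:
        self._text = text

    def load(self) -> Splits:
        data = json.loads(self._text)

        def parse(items: list) -> List[SimpleSystem]:
            return [
                SimpleSystem(d["numbers"], d["positions"], float(d["energy"]))
                for d in items
            ]

        return parse(data["train"]), parse(data["valid"]), parse(data["test"])


def concat_readers(readers: Sequence[SimpleReaderBase]) -> Splits:
    """Combine several readers by concatenating their splits, split by split."""
    train: List[SimpleSystem] = []
    valid: List[SimpleSystem] = []
    test: List[SimpleSystem] = []
    for reader in readers:
        r_train, r_valid, r_test = reader.load()
        train.extend(r_train)
        valid.extend(r_valid)
        test.extend(r_test)
    return train, valid, test


# Dummy values for illustration only.
dummy = {
    "numbers": [1, 1],
    "positions": [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]],
    "energy": -1.0,
}
text = json.dumps({"train": [dummy, dummy], "valid": [dummy], "test": []})
train, valid, test = concat_readers([JsonStringReader(text), JsonStringReader(text)])
print(len(train), len(valid), len(test))  # 4 2 0
```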
Built-in graph dataset readers: data formats
Two built-in core readers are currently provided: ExtxyzReader and Hdf5Reader.
Each supports its own data format. To train an MLIP model, we need a dataset of atomic systems with the following features per system, in the specified units:
- the positions (i.e., coordinates) of the atoms in the structure, in Angstrom
- the element numbers of the atoms
- the forces on the atoms, in eV / Angstrom
- the energy of the structure, in eV
- (optional) the stress of the structure, in eV / Angstrom³
- (optional) the periodic boundary conditions
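These unit requirements are easy to violate when data comes from codes that report stress in GPa or kbar. The sketch below shows a hypothetical record type (LabeledSystem, not mlip's ChemicalSystem) with basic consistency checks, together with the standard conversion 1 eV / Angstrom³ = 160.2176634 GPa, which follows from the exact SI value of the electronvolt. Since 1 kbar = 0.1 GPa, kbar values can be converted by first dividing by 10.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

# 1 eV / Angstrom^3 = 1.602176634e-19 J / 1e-30 m^3 = 160.2176634 GPa
GPA_PER_EV_PER_A3 = 160.2176634


def gpa_to_ev_per_a3(stress_gpa: float) -> float:
    """Convert a stress component from GPa to eV / Angstrom^3."""
    return stress_gpa / GPA_PER_EV_PER_A3


@dataclass
class LabeledSystem:
    """Hypothetical record holding the per-system quantities listed above."""
    positions: List[Tuple[float, float, float]]  # Angstrom
    atomic_numbers: List[int]  # element numbers, e.g. 8 for oxygen
    forces: List[Tuple[float, float, float]]  # eV / Angstrom
    energy: float  # eV
    stress: Optional[List[List[float]]] = None  # optional 3x3, eV / Angstrom^3
    pbc: Optional[Tuple[bool, bool, bool]] = None  # optional

    def validate(self) -> None:
        n_atoms = len(self.atomic_numbers)
        if len(self.positions) != n_atoms or len(self.forces) != n_atoms:
            raise ValueError("positions and forces need one entry per atom")
        if any(z < 1 for z in self.atomic_numbers):
            raise ValueError("element numbers must be positive integers")
        if self.stress is not None and (
            len(self.stress) != 3 or any(len(row) != 3 for row in self.stress)
        ):
            raise ValueError("stress must be a 3x3 matrix")


water = LabeledSystem(
    positions=[(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)],
    atomic_numbers=[8, 1, 1],
    forces=[(0.0, 0.0, 0.0)] * 3,
    energy=-14.2,  # placeholder value, not a real reference energy
)
water.validate()  # passes: one position and one force per atom
print(round(gpa_to_ev_per_a3(1.0), 6))  # 0.006242
```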
For a detailed description of the data format that the ExtxyzReader requires, see here.
For a detailed description of the data format that the Hdf5Reader requires, see here.
Start preprocessing
Once you have set up the graph_dataset_builder, you can start the preprocessing and fetch the resulting datasets:
graph_dataset_builder.prepare_datasets()
splits = graph_dataset_builder.get_splits()
train_set, validation_set, test_set = splits
The resulting datasets are of type GraphDataset, as mentioned above. For example, to process the batches in the training set, one can execute:
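num_graphs = len(train_set.graphs)  # number of individual graphs in the split
num_batches = len(train_set)  # number of batches the dataset yields
for batch in train_set:
    _process_batch_in_some_way(batch)  # placeholder for your own processing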
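How many batches len(train_set) yields depends on how graphs are grouped together. One common strategy for padded graph batching, similar in spirit to jraph.dynamically_batch, greedily fills a batch until the next graph would exceed a node budget, an edge budget or a graph-count limit. The plain-Python sketch below shows that idea. Whether GraphDataset uses exactly this strategy, and how it derives its budgets from max_n_node, max_n_edge and batch_size, is an assumption not confirmed here, and greedy_batches is a hypothetical name.

```python
from typing import Iterable, Iterator, List, Tuple

GraphSize = Tuple[int, int]  # (num_nodes, num_edges) of one graph


def greedy_batches(
    sizes: Iterable[GraphSize], max_nodes: int, max_edges: int, max_graphs: int
) -> Iterator[List[int]]:
    """Yield lists of graph indices such that every batch stays within all budgets."""
    batch: List[int] = []
    nodes = edges = 0
    for idx, (n_nodes, n_edges) in enumerate(sizes):
        if n_nodes > max_nodes or n_edges > max_edges:
            raise ValueError(f"graph {idx} alone exceeds the node/edge budget")
        if batch and (
            nodes + n_nodes > max_nodes
            or edges + n_edges > max_edges
            or len(batch) == max_graphs
        ):
            yield batch  # current batch is full: emit it and start a new one
            batch, nodes, edges = [], 0, 0
        batch.append(idx)
        nodes += n_nodes
        edges += n_edges
    if batch:
        yield batch  # emit the last, possibly partial, batch


sizes = [(3, 6), (4, 12), (2, 2), (5, 20)]
print(list(greedy_batches(sizes, max_nodes=8, max_edges=20, max_graphs=16)))
# [[0, 1], [2], [3]]
```

Fixed budgets let each batch be padded to the same shape, which avoids recompilation under JAX's jit. The trade-off is that a greedy fill can leave some batches well below the budget, as batch [2] shows.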
Get sharded batches
If you want to generate batches that are sharded across devices and prefetched, pass the following arguments to the get_splits() member of the GraphDatasetBuilder:
import jax
splits = graph_dataset_builder.get_splits(
    prefetch=True, devices=jax.local_devices()
)
train_set, valid_set, test_set = splits
Now, the datasets are no longer of type GraphDataset but of type PrefetchIterator, which implements batch prefetching on top of the ParallelGraphDataset class. It can be iterated over in the same way to obtain the sharded batches; note, however, that it does not have a graphs member that can be accessed directly.
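Two ideas are combined in these datasets: each batch is split across devices, and batches are produced ahead of time in the background so that the accelerators do not wait for data. The plain-Python sketch below illustrates both with the hypothetical helpers shard_batch and prefetch. It is not the PrefetchIterator or ParallelGraphDataset implementation, and it omits the actual transfer of each shard to its device.

```python
import queue
import threading
from typing import Iterable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def shard_batch(items: Sequence[T], num_devices: int) -> List[List[T]]:
    """Split one batch into num_devices equally sized, contiguous shards."""
    if num_devices < 1 or len(items) % num_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    per_device = len(items) // num_devices
    return [list(items[d * per_device:(d + 1) * per_device]) for d in range(num_devices)]


def prefetch(source: Iterable[T], buffer_size: int = 2) -> Iterator[T]:
    """Consume source in a background thread (started on first iteration),
    keeping up to buffer_size items ready for the consumer."""
    buffer: "queue.Queue" = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        try:
            for item in source:
                buffer.put((False, item))  # blocks while the buffer is full
        except Exception as exc:  # forward errors to the consuming thread
            buffer.put((True, exc))
            return
        buffer.put((True, None))  # end-of-data marker

    threading.Thread(target=producer, daemon=True).start()
    while True:
        is_control, payload = buffer.get()
        if is_control:
            if payload is not None:
                raise payload
            return
        yield payload


batches = ([10 * b + i for i in range(4)] for b in range(3))  # 3 batches of 4 "graphs"
for sharded in prefetch(shard_batch(batch, num_devices=2) for batch in batches):
    print(sharded)
# [[0, 1], [2, 3]]
# [[10, 11], [12, 13]]
# [[20, 21], [22, 23]]
```

The bounded queue caps memory at buffer_size batches. If the consumer stops early, the daemon producer thread simply stays blocked until the program exits, which is acceptable for a sketch but worth handling in production code.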
Get dataset info
Furthermore, the builder class also populates a dataclass of type DatasetInfo, which contains metadata about the dataset that is relevant to the models during training and must be stored together with the models for them to be usable. The populated instance of this dataclass can be accessed like this:
# Note: this will raise an exception if accessed
# before prepare_datasets() is run
dataset_info = graph_dataset_builder.dataset_info
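Because the dataset info has to travel with the trained model, a natural pattern is to serialize it next to the model parameters. The sketch below uses a hypothetical ToyDatasetInfo dataclass whose fields (cutoff_distance_angstrom, avg_num_neighbors) are illustrative assumptions, not the actual DatasetInfo fields listed in the API reference. It also uses a hypothetical ToyBuilder that mimics the access rule noted above by raising if dataset_info is read before prepare_datasets(); the exception type used here is an assumption.

```python
import json
from dataclasses import asdict, dataclass
from typing import List, Optional


@dataclass(frozen=True)
class ToyDatasetInfo:
    """Hypothetical metadata record; the real DatasetInfo fields may differ."""
    cutoff_distance_angstrom: float
    avg_num_neighbors: float


class ToyBuilder:
    """Hypothetical builder mimicking the 'prepare first, then read' rule."""

    def __init__(self, cutoff_angstrom: float) -> None:
        self._cutoff = cutoff_angstrom
        self._dataset_info: Optional[ToyDatasetInfo] = None

    def prepare_datasets(self, neighbor_counts: List[int]) -> None:
        # In this toy version, "preprocessing" only averages neighbour counts.
        avg = sum(neighbor_counts) / len(neighbor_counts)
        self._dataset_info = ToyDatasetInfo(self._cutoff, avg)

    @property
    def dataset_info(self) -> ToyDatasetInfo:
        if self._dataset_info is None:
            raise RuntimeError("dataset_info is only available after prepare_datasets()")
        return self._dataset_info


builder = ToyBuilder(cutoff_angstrom=5.0)
try:
    _ = builder.dataset_info
except RuntimeError as err:
    print(err)  # dataset_info is only available after prepare_datasets()

builder.prepare_datasets(neighbor_counts=[12, 14, 10])
info = builder.dataset_info
serialized = json.dumps(asdict(info))  # e.g. write this next to the model parameters
restored = ToyDatasetInfo(**json.loads(serialized))
print(restored == info)  # True
```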