Models
Create a model and force field
This section discusses how to initialize an MLIP model for subsequent training. If you are just interested in loading a pre-trained model for application in simulations, please see the dedicated section below.
Our MLIP models exist at two levels of abstraction:
On the one hand, we have the pure neural networks, which are classes derived from MLIPNetwork. As a general rule, these raw models take a graph's edge vectors and node representations as input and output a vector of node energies.
On the other hand, we wrap these models into force fields, which compute properties such as total energy, forces, or stress from the MLIP network's output and take a jraph.GraphsTuple object from the jraph library as input. The flax module that implements this is ForceFieldPredictor. However, we recommend interacting mostly with the ForceField class, which makes it easier to handle a force field as a single object (one that is aware of its parameters) and is the main class for passing a model between training and simulation.
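To make the split between the two levels concrete, here is a minimal sketch in plain Python. It is not the library's implementation. A toy function stands in for the network and returns one energy per node. A thin wrapper then sums these into a total energy and derives forces as the negative gradient of that total with respect to the positions, approximated here with central finite differences. The names node_energies, total_energy, and forces, as well as the harmonic pair term, are illustrative assumptions.

```python
import math


def node_energies(positions):
    """Toy stand-in for a network: one energy per node from pair distances."""
    energies = []
    for i, p_i in enumerate(positions):
        e_i = 0.0
        for j, p_j in enumerate(positions):
            if i != j:
                r = math.dist(p_i, p_j)
                # Each node receives half of the harmonic pair term (r - 1)^2.
                e_i += 0.5 * (r - 1.0) ** 2
        energies.append(e_i)
    return energies


def total_energy(positions):
    """Wrapper level: the total energy is the sum of the node energies."""
    return sum(node_energies(positions))


def forces(positions, h=1e-5):
    """Wrapper level: forces as the negative gradient of the total energy.

    The gradient is approximated with central finite differences.
    """
    result = []
    for atom in range(len(positions)):
        f_atom = []
        for axis in range(3):
            plus = [list(p) for p in positions]
            minus = [list(p) for p in positions]
            plus[atom][axis] += h
            minus[atom][axis] -= h
            gradient = (total_energy(plus) - total_energy(minus)) / (2 * h)
            f_atom.append(-gradient)
        result.append(f_atom)
    return result


positions = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]
energy = total_energy(positions)
f = forces(positions)
```

In this toy setup, the two atoms are stretched beyond the equilibrium distance of 1, so the forces pull them towards each other and sum to zero.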
The library currently interfaces three MLIP model architectures, i.e., MLIP network implementations: MACE, NequIP, and ViSNet.
These networks can be created from their configuration (MaceConfig, NequipConfig, or VisnetConfig) and a DatasetInfo object that one obtained after the data processing step. For the sake of simplified usage, the config objects can be directly accessed from the network classes via their .Config attribute (see example below).
For example, to create a force field that uses MACE, one can simply execute:
from mlip.models import Mace, ForceField
dataset_info = _get_from_data_processing() # placeholder
# with default config
mace = Mace(Mace.Config(), dataset_info)
force_field = ForceField.from_mlip_network(mace)
# with modified config
mace = Mace(Mace.Config(num_channels=64), dataset_info)
force_field = ForceField.from_mlip_network(mace)
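The .Config attribute used above follows a common pattern in which a network class exposes its configuration class as a class attribute. The following toy sketch illustrates that pattern only. ToyConfig, ToyNetwork, and the fields hidden_size and depth are hypothetical and do not reflect the fields or defaults of any config class in the library.

```python
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Hypothetical hyperparameters, not the fields of any real config class.
    hidden_size: int = 128
    depth: int = 2


class ToyNetwork:
    # Exposing the config class on the network class means callers need
    # only one import to build both the config and the network.
    Config = ToyConfig

    def __init__(self, config, dataset_info):
        self.config = config
        self.dataset_info = dataset_info


# With the default config
default_net = ToyNetwork(ToyNetwork.Config(), dataset_info=None)
# With a modified config: only the overridden field changes
small_net = ToyNetwork(ToyNetwork.Config(hidden_size=64), dataset_info=None)
```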
The ForceField class stores the parameters of the model (random parameters after initialization) and acts as the input to all downstream tasks. However, advanced users can also interact with the underlying flax modules directly. We recommend visiting the flax documentation for more details on how to work with flax modules.
Make predictions
We can run a prediction with an MLIP force field like this:
graph = _get_jraph_graph_from_somewhere() # placeholder
prediction = force_field(graph)
The prediction is a dataclass of type Prediction and includes several properties. Properties other than energy and forces are predicted only optionally (see the predict_stress argument of ForceField.from_mlip_network).
If the input graph object (type: jraph.GraphsTuple) contains multiple subgraphs, for example, if it represents a batch, we can get the energy and forces of the i-th subgraph like this:
# For i-th energy
energy_i = float(prediction.energy[i])
# For i-th forces
num_nodes_before_i = sum(graph.n_node[j] for j in range(0, i))
forces_i = prediction.forces[num_nodes_before_i : num_nodes_before_i + graph.n_node[i]]
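When the forces of every subgraph are needed, the offsets can be computed once instead of re-summing the node counts for each i. The helper below is a minimal sketch of that idea. The name node_slices is our own and not part of the library, and the only assumption is that the per-node arrays are concatenated in subgraph order, as in the indexing above.

```python
from itertools import accumulate


def node_slices(n_node):
    """Return one slice per subgraph into the concatenated per-node arrays."""
    counts = [int(n) for n in n_node]
    ends = list(accumulate(counts))  # running totals mark where each subgraph ends
    starts = [0] + ends[:-1]  # each subgraph starts where the previous one ended
    return [slice(start, end) for start, end in zip(starts, ends)]


# Stand-in data: three subgraphs with 3, 1, and 2 nodes.
n_node = [3, 1, 2]
per_node_values = ["a0", "a1", "a2", "b0", "c0", "c1"]

slices = node_slices(n_node)
per_subgraph = [per_node_values[s] for s in slices]
```

With a real prediction, the same slices could be applied as prediction.forces[slices[i]], with graph.n_node passed as the argument.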
Important caveat: A ForceField can only process graphs (of type jraph.GraphsTuple) that contain at least two subgraphs. Calling the force field on a graph that is not formally a batch will result in a ValueError. This means that if you are working with these graph objects directly, make sure a single graph of interest is always batched with a minimal dummy graph. We recommend using the function create_graph_from_chemical_system() to prepare graphs, as it accepts the argument batch_it_with_minimal_dummy=True for convenience. An example is shown below:
import numpy as np
from mlip.data import ChemicalSystem
from mlip.data.helpers import create_graph_from_chemical_system
# Example H2O molecule:
# - H (Z=1) has species index 0
# - O (Z=8) has species index 3 (H, C, N come first)
system = ChemicalSystem(
    atomic_numbers=np.array([1, 8, 1]),
    atomic_species=np.array([0, 3, 0]),
    positions=np.array([[-0.5, 0.0, 0.0], [0.0, 0.2, 0.0], [0.5, 0.0, 0.0]]),
)
# Build the graph and batch it with a minimal dummy graph so that the
# force field accepts it.
graph = create_graph_from_chemical_system(
    chemical_system=system,
    distance_cutoff_angstrom=5,
    batch_it_with_minimal_dummy=True,
)
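If graphs are built by other means, it can help to check the number of subgraphs before calling the force field. The guard below is a sketch under the assumption that the number of subgraphs equals the length of the graph's n_node field, which is how jraph.GraphsTuple stores per-graph node counts. FakeGraph is a hypothetical stand-in used only to keep the example self-contained, and require_batch is not a library function.

```python
from collections import namedtuple

# Hypothetical stand-in exposing only the `n_node` field of jraph.GraphsTuple.
FakeGraph = namedtuple("FakeGraph", ["n_node"])


def require_batch(graph):
    """Raise early if the graph holds fewer than two subgraphs."""
    num_subgraphs = len(graph.n_node)
    if num_subgraphs < 2:
        raise ValueError(
            f"Expected at least two subgraphs, got {num_subgraphs}. "
            "Batch the graph with a minimal dummy graph first."
        )
    return graph


batched = require_batch(FakeGraph(n_node=[3, 1]))  # passes

try:
    require_batch(FakeGraph(n_node=[3]))  # a single graph is rejected
    rejected = False
except ValueError:
    rejected = True
```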
Load a model from a zip archive
To load a model (e.g., MACE) from the lightweight zip format that we ship our pre-trained models with, you can use the function load_model_from_zip:
from mlip.models import Mace
from mlip.models.model_io import load_model_from_zip
force_field = load_model_from_zip(Mace, "path/to/model.zip")
Subsequently, you can use the returned force field (type: ForceField) for any downstream tasks.
Load a trained model from an Orbax checkpoint
To load a trained model from an Orbax checkpoint, one can use the load_parameters_from_checkpoint() helper function:
from mlip.models import ForceField
from mlip.models.params_loading import load_parameters_from_checkpoint
initial_force_field = _create_initial_force_field() # placeholder
# Load parameters
loaded_params = load_parameters_from_checkpoint(
    local_checkpoint_dir="path/to/checkpoint/directory",  # must be local
    initial_params=initial_force_field.params,
    epoch_to_load=157,
    load_ema_params=False,
)
# Create new force field with those loaded parameters
force_field = ForceField(initial_force_field.predictor, loaded_params)