Simulations
This library supports two types of simulations, molecular dynamics (MD) and energy minimization, and two backends, JAX-MD and ASE. Simulations are handled by simulation engine classes, which implement the abstract base class SimulationEngine. One can either use one of our two implemented engines (JaxMDSimulationEngine and ASESimulationEngine) or implement custom ones. Each engine comes with its own pydantic config that inherits from SimulationConfig.
Important note on units: The system of units for the inputs and outputs of all simulation types is the ASE unit system.
Important note on logging: The JAX-MD and ASE backends differ subtly in which steps they log. Both engines run for the same number of steps, but for N logging steps they record a different number of snapshots. JAX-MD records N snapshots: the first corresponds to the initial (zeroth) state and the last to the (N-1)-th logging step. ASE instead records N+1 snapshots: the first corresponds to the initial (zeroth) state and the last to the N-th logging step.
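The off-by-one difference is easiest to see by listing which logging steps end up in the recorded trajectory. The helper below is hypothetical and not part of the library. It simply encodes the convention stated above. When comparing a JAX-MD run with an ASE run of the same length, dropping the last ASE snapshot gives two trajectories of equal length.

```python
def logged_step_indices(num_logging_steps: int, backend: str) -> list[int]:
    """Return the logging steps present in the recorded trajectory.

    Hypothetical helper encoding the documented convention:
    JAX-MD records steps 0..N-1, ASE records steps 0..N (one extra snapshot).
    """
    if backend == "jax_md":
        return list(range(num_logging_steps))
    if backend == "ase":
        return list(range(num_logging_steps + 1))
    raise ValueError(f"unknown backend: {backend!r}")


print(logged_step_indices(3, "jax_md"))  # [0, 1, 2]
print(logged_step_indices(3, "ase"))     # [0, 1, 2, 3]
```

For example, assuming the first axis of the positions arrays indexes snapshots, `ase_positions[:len(jax_positions)]` would line the two trajectories up snapshot by snapshot.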
Simulations with JAX-MD
To run a simulation (for example, an MD) with the JAX-MD backend, one can use the following code:
from ase.io import read as ase_read
from mlip.simulation.jax_md import JaxMDSimulationEngine
atoms = ase_read("/path/to/xyz/or/pdb/file")
force_field = _get_a_trained_force_field_from_somewhere() # placeholder
md_config = JaxMDSimulationEngine.Config(**config_kwargs)  # config_kwargs: placeholder for your settings
md_engine = JaxMDSimulationEngine(atoms, force_field, md_config)
md_engine.run()
Note that in the example above, _get_a_trained_force_field_from_somewhere() is a placeholder for a function that loads a trained force field, as described either here (Option 1) or here (Option 2). Likewise, config_kwargs stands for the settings you want to pass to the config.
The config class for JAX-MD simulations is JaxMDSimulationConfig. It can also be accessed via JaxMDSimulationEngine.Config, so that fewer imports are needed. The input structure is given as an instance of the commonly used ase.Atoms class (see the ASE docs here).
The results of the simulation are stored in a SimulationState object, which can be accessed like this:
md_state = md_engine.state
# Print some data from the simulation:
print(md_state.positions)
print(md_state.temperature)
print(md_state.compute_time_seconds)
We recommend that you take note of the units of the computed properties, as described in the SimulationState reference. See our Jupyter notebook on simulations here for more information on how to convert these raw NumPy arrays into file formats that popular MD visualization tools can read.
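If you only need a quick file for a viewer, plain multi-frame XYZ is simple enough to write by hand. The function below is a hypothetical helper, not a library function. It assumes that positions has shape (num_snapshots, num_atoms, 3) in Å and that the chemical symbols come from the input structure, for example via atoms.get_chemical_symbols(). Check the SimulationState reference for the actual array layout before relying on it. The helper works with nested lists as well as NumPy arrays.

```python
def trajectory_to_xyz(symbols, positions, comment="trajectory"):
    """Serialize a trajectory as multi-frame plain XYZ text.

    symbols:   chemical symbols, one per atom.
    positions: nested sequence of shape (num_snapshots, num_atoms, 3), in Angstrom
               (assumed layout).
    """
    lines = []
    for frame_idx, frame in enumerate(positions):
        if len(frame) != len(symbols):
            raise ValueError("number of atoms does not match number of symbols")
        lines.append(str(len(symbols)))               # atom count line
        lines.append(f"{comment} frame={frame_idx}")  # free-form comment line
        for sym, (x, y, z) in zip(symbols, frame):
            lines.append(f"{sym} {x:.6f} {y:.6f} {z:.6f}")
    return "\n".join(lines) + "\n"
```

Writing the returned string to a file with the .xyz extension produces a file that common viewers can open as an animation.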
Energy minimizations can be run in exactly the same way, possibly with slightly different settings. See the documentation of the JaxMDSimulationConfig class for more details. Most importantly, simulation_type needs to be set to SimulationType.MINIMIZATION (see SimulationType).
Algorithms: For MD, the NVT-Langevin algorithm is used (see here). For energy minimization, the FIRE algorithm is used (see here). We plan to provide more options in future versions of the library.
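To give an intuition for what FIRE does, here is an independent pure-Python sketch of its update rules (Bitzek et al., 2006). It uses commonly quoted default parameters, unit masses and a toy quadratic energy. It only illustrates the algorithm and is not the code path used by the JAX-MD backend. The function name and its parameters are our own.

```python
import math


def fire_minimize(force_fn, x0, dt=0.1, dt_max=1.0, n_min=5, f_inc=1.1,
                  f_dec=0.5, alpha_start=0.1, f_alpha=0.99, fmax=1e-6,
                  max_steps=10_000):
    """Minimize an energy with FIRE, given a function returning forces (-dE/dx).

    Positions are a flat list of floats and all masses are 1 (toy setting).
    """
    x = list(x0)
    v = [0.0] * len(x)
    alpha = alpha_start
    n_positive = 0  # consecutive steps with positive power
    for _ in range(max_steps):
        f = force_fn(x)
        if max(abs(fi) for fi in f) < fmax:
            break  # converged: all force components are small
        power = sum(fi * vi for fi, vi in zip(f, v))  # P = F . v
        if power > 0:
            # Downhill: steer the velocity toward the force direction.
            v_norm = math.sqrt(sum(vi * vi for vi in v))
            f_norm = math.sqrt(sum(fi * fi for fi in f))
            v = [(1 - alpha) * vi + alpha * v_norm * fi / f_norm
                 for vi, fi in zip(v, f)]
            n_positive += 1
            if n_positive > n_min:
                # Sustained progress: take larger steps and mix less.
                dt = min(dt * f_inc, dt_max)
                alpha *= f_alpha
        elif power < 0:
            # Uphill: stop, shrink the time step and reset the mixing.
            v = [0.0] * len(x)
            dt *= f_dec
            alpha = alpha_start
            n_positive = 0
        # power == 0 essentially only at the first step, where v is still zero.
        # Semi-implicit Euler MD step with unit masses.
        v = [vi + dt * fi for vi, fi in zip(v, f)]
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


# Toy energy E(x) = sum_i (x_i - c_i)^2, whose forces are -2 (x_i - c_i).
center = [1.0, -2.0, 0.5]
x_min = fire_minimize(lambda x: [-2.0 * (xi - ci) for xi, ci in zip(x, center)],
                      [0.0, 0.0, 0.0])
```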
Note
A special feature of the JAX-MD backend is that a simulation is divided into multiple episodes. Within one episode, the simulation runs in a fully jitted way. After each episode, the neighbor lists can be reallocated, the simulation state can be populated and loggers can be called.
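Conceptually, the episode structure can be pictured as two nested loops. The sketch below is plain Python with hypothetical names (run_in_episodes, step_fn, needs_realloc, realloc). It is simplified to store one snapshot per episode. In the actual backend, the inner loop is the part that runs as one jitted computation. The outer loop is ordinary Python, and that is where neighbor lists can be reallocated, the state populated and loggers called.

```python
from typing import Any, Callable, List, Tuple


def run_in_episodes(
    state: Any,
    step_fn: Callable[[Any], Any],
    num_episodes: int,
    steps_per_episode: int,
    loggers: List[Callable[[Any], None]],
    needs_realloc: Callable[[Any], bool] = lambda s: False,
    realloc: Callable[[Any], Any] = lambda s: s,
) -> Tuple[Any, List[Any]]:
    snapshots = []
    for _ in range(num_episodes):
        # Inner loop: in the real backend this part is compiled and run as one
        # jitted computation, so no Python code runs between these steps.
        for _ in range(steps_per_episode):
            state = step_fn(state)
        # Between episodes we are back in ordinary Python.
        if needs_realloc(state):  # e.g. neighbor list capacity exceeded
            state = realloc(state)
        snapshots.append(state)  # simplified: one snapshot per episode
        for log in loggers:  # attached loggers run once per episode
            log(state)
    return state, snapshots


calls = []
final, snaps = run_in_episodes(0, lambda s: s + 1, 4, 10, [calls.append])
```

This is also why custom loggers attached to a JAX-MD engine are invoked once per episode rather than at every step.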
Simulations with ASE
With ASE, MD simulations and energy minimizations are run analogously to the JAX-MD case described above, for example with the following code:
from ase.io import read as ase_read
from mlip.simulation.ase.ase_simulation_engine import ASESimulationEngine
atoms = ase_read("/path/to/xyz/or/pdb/file")
force_field = _get_a_trained_force_field_from_somewhere() # placeholder
md_config = ASESimulationEngine.Config(**config_kwargs)  # config_kwargs: placeholder for your settings
md_engine = ASESimulationEngine(atoms, force_field, md_config)
md_engine.run()
The config class for ASE simulations is ASESimulationConfig (also accessible via ASESimulationEngine.Config). As in the JAX-MD case, the input structure is given as an ase.Atoms object (see the ASE docs here).
The results of the simulation are stored in a SimulationState object, as described for the JAX-MD case above. Again, take note of the units of the computed properties, as described in the SimulationState reference.
For the settings required for energy minimizations, see the documentation of the ASESimulationConfig class. Most importantly, simulation_type needs to be set to SimulationType.MINIMIZATION (see SimulationType).
Algorithms: For MD, the NVT-Langevin algorithm is used (see here). For energy minimization, the BFGS algorithm is used (see here). We plan to provide more options in future versions of the library.
Temperature Scheduling
A temperature schedule can be added to both simulation engines. To do so, create an instance of TemperatureScheduleConfig and pass it as temperature_schedule_config to either ASESimulationConfig or JaxMDSimulationConfig. See the documentation of the TemperatureScheduleConfig class for more details.
By default, the method is CONSTANT. This means that the target temperature is set at the start of the simulation and kept constant throughout. Two other methods are available: LINEAR and TRIANGLE. To use one of them, set the method attribute to the corresponding member of the TemperatureScheduleMethod enum. Also make sure that any other parameters required by that method are set appropriately. The temperature schedule methods are described in more detail here.
Below we provide an example of how to use a linear schedule that will heat the system from 300 K to 600 K when using the JAX-MD simulation backend:
from mlip.simulation.configs import TemperatureScheduleConfig
from mlip.simulation.jax_md import JaxMDSimulationEngine
from mlip.simulation.enums import TemperatureScheduleMethod
temp_schedule_config = TemperatureScheduleConfig(
method=TemperatureScheduleMethod.LINEAR,
start_temperature=300.0,
end_temperature=600.0
)
md_config = JaxMDSimulationEngine.Config(
temperature_schedule_config=temp_schedule_config,
**config_kwargs
)
# Go on to initialize a simulation with this config
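To make the three schedule shapes concrete, the functions below sketch how a target temperature could evolve with the step number. They are illustrative only and use hypothetical names and parameters. The exact formulas and parameter names used by TemperatureScheduleConfig may differ. In particular, the triangle shape shown here is an assumption: within each period it rises linearly from a minimum to a maximum temperature and falls back.

```python
def constant_schedule(step: int, num_steps: int, temperature: float) -> float:
    """CONSTANT: the same target temperature at every step."""
    return temperature


def linear_schedule(step: int, num_steps: int,
                    start_temperature: float, end_temperature: float) -> float:
    """LINEAR (sketch): interpolate from start to end over num_steps."""
    frac = step / num_steps
    return start_temperature + frac * (end_temperature - start_temperature)


def triangle_schedule(step: int, period: int,
                      min_temperature: float, max_temperature: float) -> float:
    """TRIANGLE (assumed shape): ramp up during the first half of each period,
    ramp down during the second half."""
    phase = (step % period) / period  # position within the period, in [0, 1)
    up = 2 * phase if phase < 0.5 else 2 * (1 - phase)
    return min_temperature + up * (max_temperature - min_temperature)


# Heating from 300 K to 600 K over 100 steps, as in the example above.
temperatures = [linear_schedule(s, 100, 300.0, 600.0) for s in (0, 50, 100)]
```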
Advanced logging
The SimulationEngine allows you to attach custom loggers to a simulation:
from mlip.simulation.state import SimulationState
def logging_fun(state: SimulationState) -> None:
"""You can do anything with the given state here"""
_log_something() # placeholder
md_engine.attach_logger(logging_fun)
The logger must be attached before the simulation is started. With the ASE backend, the logging function is called according to the configured logging interval. With the JAX-MD backend, it is called after every episode.
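Since the logging function is called repeatedly during a run, a closure is a convenient way to keep information across calls. The sketch below uses hypothetical names and Python's standard logging module. It reports the temperature at each call and counts the calls. It only formats state.temperature, so it works whether that attribute holds a single value or an array of per-snapshot values. A SimpleNamespace stands in for a real SimulationState so the function can be tried without running a simulation.

```python
import logging
from types import SimpleNamespace

log = logging.getLogger("md_monitor")


def make_monitor():
    """Create a logging function that reports the temperature and counts calls."""
    calls = {"count": 0}

    def logging_fun(state) -> None:
        calls["count"] += 1
        # Only formatted, so scalars and arrays are both fine here.
        log.info("call %d: temperature=%s", calls["count"], state.temperature)

    return logging_fun, calls


if __name__ == "__main__":
    # Quick check outside an engine, using a stand-in state object.
    fn, calls = make_monitor()
    fn(SimpleNamespace(temperature=300.0))
    print(calls["count"])  # 1
```

With a real engine, attach the returned function with md_engine.attach_logger(...) before calling md_engine.run().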
Batched inference
Instead of running MD simulations or energy minimizations, one can also use the function run_batched_inference(), which takes a list of ase.Atoms objects and returns a list of Prediction objects:
from mlip.inference import run_batched_inference
structures = _get_list_of_ase_atoms_from_somewhere() # placeholder
force_field = _get_a_trained_force_field_from_somewhere() # placeholder
predictions = run_batched_inference(structures, force_field, batch_size=8)
# Example: get the energy and forces of the 8th structure (index 7, since indexing starts at 0)
energy = predictions[7].energy
forces = predictions[7].forces
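A common next step is to rank the structures by predicted energy or to screen them by their largest force. The hypothetical helpers below assume that each Prediction exposes energy as a scalar and forces with shape (num_atoms, 3), which is typical but should be checked against the Prediction reference. They work with nested lists and NumPy arrays alike. The last lines use SimpleNamespace objects as stand-ins for real predictions.

```python
import math
from types import SimpleNamespace


def rank_by_energy(predictions):
    """Return structure indices sorted from lowest to highest predicted energy."""
    return sorted(range(len(predictions)), key=lambda i: predictions[i].energy)


def max_force_norm(forces):
    """Largest per-atom force magnitude, assuming forces has shape (num_atoms, 3)."""
    return max(math.sqrt(fx * fx + fy * fy + fz * fz) for fx, fy, fz in forces)


# Stand-ins for Prediction objects, only to illustrate the helpers.
fake_predictions = [
    SimpleNamespace(energy=-1.0, forces=[[3.0, 4.0, 0.0], [0.0, 0.0, 0.5]]),
    SimpleNamespace(energy=-2.5, forces=[[0.0, 0.0, 1.0]]),
]
order = rank_by_energy(fake_predictions)
```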