Simulations

This library supports two types of simulations, molecular dynamics (MD) and energy minimization, each with two backends: JAX-MD and ASE. Simulations are handled by simulation engine classes, which implement the abstract base class SimulationEngine. One can either use one of the two provided engines (JaxMDSimulationEngine and ASESimulationEngine) or implement a custom one. Each engine comes with its own pydantic config class that inherits from SimulationConfig.

Important note on units: The system of units for the inputs and outputs of all simulation types is the ASE unit system.

Important note on logging: There is a subtle difference in which steps the JAX-MD and ASE backends log. Suppose both engines run for the same number of steps, corresponding to N logging steps. JAX-MD logs N snapshots: the first corresponds to the initial (zeroth) state and the last to the (N-1)-th logging step. In contrast, ASE logs N+1 snapshots: the first corresponds to the initial (zeroth) state and the last to the N-th logging step.
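To make this off-by-one difference concrete, the following pure-Python sketch lists the MD step at which each snapshot is taken. It assumes a run of num_steps steps with one snapshot every log_interval steps, so that N = num_steps / log_interval. The helper and its parameter names are hypothetical and not part of the library.

```python
from typing import List


def logged_step_indices(num_steps: int, log_interval: int, backend: str) -> List[int]:
    """Return the MD step of every logged snapshot (hypothetical helper).

    Follows the convention described above: JAX-MD stores N snapshots
    (logging steps 0 .. N-1), ASE stores N+1 snapshots (logging steps 0 .. N).
    """
    # N: number of logging steps covered by the run (assumed to divide evenly).
    n_logs = num_steps // log_interval
    if backend == "jax_md":
        return [i * log_interval for i in range(n_logs)]
    if backend == "ase":
        return [i * log_interval for i in range(n_logs + 1)]
    raise ValueError(f"unknown backend: {backend!r}")


# For 1000 steps with a snapshot every 100 steps, JAX-MD ends at step 900
# (10 snapshots) while ASE also includes step 1000 (11 snapshots).
```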

Simulations with JAX-MD

To run a simulation (for example, an MD simulation) with the JAX-MD backend, one can use the following code:

from ase.io import read as ase_read
from mlip.simulation.jax_md import JaxMDSimulationEngine

atoms = ase_read("/path/to/xyz/or/pdb/file")
force_field = _get_a_trained_force_field_from_somewhere()  # placeholder
md_config = JaxMDSimulationEngine.Config(**config_kwargs)  # config_kwargs: placeholder for your settings

md_engine = JaxMDSimulationEngine(atoms, force_field, md_config)
md_engine.run()

Note that in the example above, _get_a_trained_force_field_from_somewhere() is a placeholder for a function that loads a trained force field, as described either here (Option 1) or here (Option 2). The config class for JAX-MD simulations is JaxMDSimulationConfig; it can also be accessed via JaxMDSimulationEngine.Config, which saves an import. The input structure is given as an ase.Atoms object, the commonly used ASE structure class (see the ASE docs here).

The results of the simulation are stored in a SimulationState object, which can be accessed like this:

md_state = md_engine.state

# Print some data from the simulation:
print(md_state.positions)
print(md_state.temperature)
print(md_state.compute_time_seconds)

We also recommend that you take note of the units of the computed properties, as described in the SimulationState reference. See our Jupyter notebook on simulations here for more information on how to convert these raw NumPy arrays into file formats that can be read by popular MD visualization tools.
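As a rough illustration of such a conversion, the sketch below serializes a trajectory to the plain XYZ text format, which most visualization tools can open. It assumes that md_state.positions has shape (n_snapshots, n_atoms, 3) and is in Angstrom (the ASE length unit); the helper name is hypothetical, and it ignores cell information, so for periodic systems the approach in the notebook (or ase.io.write with an extended XYZ format) is likely the better choice.

```python
from typing import Sequence


def to_xyz(symbols: Sequence[str], trajectory, comment: str = "frame") -> str:
    """Serialize a trajectory to plain XYZ text, one block per snapshot.

    symbols: chemical symbols, one per atom.
    trajectory: iterable of snapshots, each a sequence of (x, y, z) rows
        in Angstrom (nested lists or a NumPy array both work).
    """
    lines = []
    for i, frame in enumerate(trajectory):
        if len(frame) != len(symbols):
            raise ValueError(f"snapshot {i} has {len(frame)} atoms, expected {len(symbols)}")
        lines.append(str(len(symbols)))  # atom count line
        lines.append(f"{comment} {i}")  # free-form comment line
        for sym, (x, y, z) in zip(symbols, frame):
            lines.append(f"{sym} {x:.6f} {y:.6f} {z:.6f}")
    return "\n".join(lines) + "\n"


# Usage (assumed shapes, see above):
# with open("trajectory.xyz", "w") as f:
#     f.write(to_xyz(atoms.get_chemical_symbols(), md_state.positions))
```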

Energy minimizations can be run in exactly the same way, possibly with slightly different settings. See the documentation of the JaxMDSimulationConfig class for details. Most importantly, simulation_type must be set to SimulationType.MINIMIZATION (see SimulationType).

Algorithms: For MD, the NVT-Langevin algorithm is used (see here). For energy minimization, the FIRE algorithm is used (see here). We plan to provide more options in future versions of the library.
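For intuition about how FIRE (Fast Inertial Relaxation Engine) works, here is a minimal pure-Python sketch of its core update rule on a flat list of coordinates. It assumes unit masses, uses a simple semi-implicit Euler step, and uses parameter defaults common in the literature. It is not JAX-MD's implementation (which operates on JAX arrays and is configured through JaxMDSimulationConfig); the function name and parameters are hypothetical.

```python
import math


def fire_minimize(force_fn, x, dt=0.1, dt_max=1.0, n_min=5, f_inc=1.1,
                  f_dec=0.5, alpha_start=0.1, f_alpha=0.99,
                  max_steps=1000, f_tol=1e-6):
    """Minimal FIRE sketch with unit masses; returns the relaxed coordinates."""
    x = list(x)
    v = None
    alpha = alpha_start
    steps_since_reset = 0
    for _ in range(max_steps):
        f = force_fn(x)
        f_norm = math.sqrt(sum(fi * fi for fi in f))
        if f_norm < f_tol:
            break  # converged: forces are (almost) zero
        if v is None:
            v = [0.0] * len(x)  # start from rest
        else:
            # Power P = F . v tells whether the system still moves downhill.
            power = sum(fi * vi for fi, vi in zip(f, v))
            if power > 0:
                # Steer the velocity towards the force direction.
                v_norm = math.sqrt(sum(vi * vi for vi in v))
                v = [(1 - alpha) * vi + alpha * v_norm * fi / f_norm
                     for vi, fi in zip(v, f)]
                if steps_since_reset > n_min:
                    dt = min(dt * f_inc, dt_max)  # speed up
                    alpha *= f_alpha
                steps_since_reset += 1
            else:
                # Moving uphill: stop, shrink the time step, reset mixing.
                v = [0.0] * len(x)
                dt *= f_dec
                alpha = alpha_start
                steps_since_reset = 0
        # Semi-implicit Euler MD step.
        v = [vi + dt * fi for vi, fi in zip(v, f)]
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


# Example: relax a 2D harmonic well E = 0.5 * (x0**2 + 2 * x1**2).
# x_min = fire_minimize(lambda x: [-x[0], -2.0 * x[1]], [1.0, -2.0])
```

The key design idea is that FIRE behaves like damped MD: it accelerates while the power is positive and "freezes" the system as soon as it starts moving uphill.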

Note

A special feature of the JAX-MD backend is that a simulation is divided into multiple episodes. Within one episode, the simulation runs as a single jitted (JIT-compiled) computation. After each episode, the neighbor lists can be reallocated, the simulation state can be populated, and loggers can be called.
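The episode structure can be pictured with the following pure-Python sketch, which splits a run of num_steps steps into num_episodes equal episodes and calls every logger once per episode. All names are hypothetical; in JAX-MD the inner loop would be one jitted JAX computation rather than a Python loop, and host-side work such as neighbor-list reallocation would also happen between episodes.

```python
from typing import Any, Callable, Sequence


def run_in_episodes(state: Any,
                    step_fn: Callable[[Any], Any],
                    num_steps: int,
                    num_episodes: int,
                    loggers: Sequence[Callable[[Any], None]]) -> Any:
    """Hypothetical sketch of episode-based execution."""
    if num_steps % num_episodes != 0:
        raise ValueError("num_steps must be divisible by num_episodes")
    steps_per_episode = num_steps // num_episodes
    for _ in range(num_episodes):
        # Inside an episode: no host interaction (jitted in the real backend).
        for _ in range(steps_per_episode):
            state = step_fn(state)
        # Between episodes: neighbor-list updates and logging would go here.
        for logger in loggers:
            logger(state)
    return state


# Toy usage: the "state" is a step counter, logged after each of 5 episodes.
# seen = []
# run_in_episodes(0, lambda s: s + 1, num_steps=10, num_episodes=5, loggers=[seen.append])
# seen is then [2, 4, 6, 8, 10]
```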

Simulations with ASE

With ASE, MD simulations and energy minimizations are run analogously to the JAX-MD case described above. The following code can be used:

from ase.io import read as ase_read
from mlip.simulation.ase.ase_simulation_engine import ASESimulationEngine

atoms = ase_read("/path/to/xyz/or/pdb/file")
force_field = _get_a_trained_force_field_from_somewhere()  # placeholder
md_config = ASESimulationEngine.Config(**config_kwargs)  # config_kwargs: placeholder for your settings

md_engine = ASESimulationEngine(atoms, force_field, md_config)
md_engine.run()

The config class for ASE simulations is ASESimulationConfig (accessible via ASESimulationEngine.Config). As in the JAX-MD case, the format for the input structure is the ase.Atoms class (see the ASE docs here).

As in the JAX-MD case, the results of the simulation are stored in a SimulationState object, and we again recommend that you take note of the units of the computed properties as described in the SimulationState reference.

For the settings required for energy minimizations, see the documentation of the ASESimulationConfig class. Most importantly, simulation_type must be set to SimulationType.MINIMIZATION (see SimulationType).

Algorithms: For MD, the NVT-Langevin algorithm is used (see here). For energy minimization, the BFGS algorithm is used (see here). We plan to provide more options in future versions of the library.

Temperature Scheduling

It is also possible to add a temperature schedule to both simulation engines. This is done by creating an instance of TemperatureScheduleConfig and passing it as temperature_schedule_config to either ASESimulationConfig or JaxMDSimulationConfig; see the documentation of the TemperatureScheduleConfig class for details. By default, the method is CONSTANT, which means the target temperature is set at the start of the simulation and kept constant throughout. Two other methods are available: LINEAR and TRIANGLE. To use one of them, set the method attribute to the corresponding member of the TemperatureScheduleMethod enum and make sure that all parameters required by that method are set. The available temperature schedule methods are described here.

Below we provide an example of how to use a linear schedule that will heat the system from 300 K to 600 K when using the JAX-MD simulation backend:

from mlip.simulation.configs import TemperatureScheduleConfig
from mlip.simulation.jax_md import JaxMDSimulationEngine
from mlip.simulation.enums import TemperatureScheduleMethod

temp_schedule_config = TemperatureScheduleConfig(
    method=TemperatureScheduleMethod.LINEAR,
    start_temperature=300.0,
    end_temperature=600.0
)
md_config = JaxMDSimulationEngine.Config(
    temperature_schedule_config=temp_schedule_config,
    **config_kwargs
)

# Go on to initialize a simulation with this config
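The following sketch shows one plausible way of computing the target temperature at a given step for a linear and a triangle schedule. It is an illustration under stated assumptions, not the library's implementation: we assume LINEAR interpolates from a start to an end temperature over the run (as in the example above) and TRIANGLE periodically ramps between a minimum and a maximum temperature. The function names and the parameters t_min, t_max and period are hypothetical; the parameters the library actually uses are documented in TemperatureScheduleConfig.

```python
def linear_schedule(step: int, num_steps: int, t_start: float, t_end: float) -> float:
    """Target temperature moving linearly from t_start (step 0) to t_end (last step)."""
    frac = step / num_steps
    return t_start + (t_end - t_start) * frac


def triangle_schedule(step: int, period: int, t_min: float, t_max: float) -> float:
    """Assumed triangle wave: heat from t_min to t_max over the first half of
    each period, then cool back down over the second half."""
    phase = (step % period) / period  # position within the current period, in [0, 1)
    if phase < 0.5:
        return t_min + (t_max - t_min) * (2 * phase)
    return t_max - (t_max - t_min) * (2 * phase - 1)


# Heating from 300 K to 600 K over 1000 steps reaches 450 K at step 500.
# A triangle with period 100 peaks at step 50 and is back to t_min at step 100.
```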

Advanced logging

The SimulationEngine allows you to attach custom loggers to a simulation:

from mlip.simulation.state import SimulationState

def logging_fun(state: SimulationState) -> None:
    """You can do anything with the given state here"""
    _log_something()  # placeholder

md_engine.attach_logger(logging_fun)

The logger must be attached before starting the simulation. With ASE, the logging function is called according to the configured logging interval; with JAX-MD, it is called after every episode.
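A common pattern is a logger that records values for later inspection. The sketch below uses a closure to append one attribute of the passed state (for example temperature, which SimulationState provides as shown earlier) to a list on every call. The helper name is hypothetical; if the attribute is an array that the engine updates in place, you may want to store a copy instead of a reference.

```python
from typing import Any, Callable, List, Tuple


def make_recorder(attribute: str) -> Tuple[Callable[[Any], None], List[Any]]:
    """Build a logger that appends getattr(state, attribute) to a list on each call.

    Returns the logger and the list it writes to, so the recorded values can be
    inspected after the simulation has finished.
    """
    history: List[Any] = []

    def logger(state: Any) -> None:
        # The state is whatever the engine passes in (a SimulationState here);
        # only one attribute is read from it.
        history.append(getattr(state, attribute))

    return logger, history


# Usage with an engine created as above (attach before calling run()):
# logger, temperatures = make_recorder("temperature")
# md_engine.attach_logger(logger)
# md_engine.run()
```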

Batched inference

Besides MD simulations and energy minimizations, we also provide the function run_batched_inference(), which takes a list of ase.Atoms objects and returns a list of Prediction objects:

from mlip.inference import run_batched_inference

structures = _get_list_of_ase_atoms_from_somewhere()  # placeholder
force_field = _get_a_trained_force_field_from_somewhere()  # placeholder
predictions = run_batched_inference(structures, force_field, batch_size=8)

# Example: get energy and forces of the structure at index 7 (the eighth structure)
energy = predictions[7].energy
forces = predictions[7].forces
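A typical next step is to post-process the returned list. The sketch below picks the structure with the lowest predicted energy. It assumes, as the example above suggests, that predictions are returned in the same order as the input structures and that each Prediction exposes a scalar energy attribute (in eV, following the ASE unit system). The helper name is hypothetical.

```python
from typing import Any, List, Sequence, Tuple


def lowest_energy(predictions: Sequence[Any]) -> Tuple[int, float]:
    """Return (index, energy) of the prediction with the lowest energy.

    Hypothetical helper; assumes each item has a scalar `energy` attribute.
    """
    if len(predictions) == 0:
        raise ValueError("no predictions given")
    # float() also accepts single-element NumPy arrays.
    energies: List[float] = [float(p.energy) for p in predictions]
    best = min(range(len(energies)), key=energies.__getitem__)
    return best, energies[best]


# Usage with the objects from the example above:
# index, energy = lowest_energy(predictions)
# best_structure = structures[index]
```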