JAX-MD Simulation Engine

class mlip.simulation.jax_md.jax_md_simulation_engine.JaxMDSimulationEngine(atoms: Atoms, force_field: ForceField, config: JaxMDSimulationConfig)

Simulation engine handling simulations with the JAX-MD backend.

For MD, the NVT-Langevin algorithm is used; for energy minimization, the FIRE algorithm is used.
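As a rough illustration of the two algorithms, the plain-Python sketch below implements one step of a BAOAB-style Langevin integrator (for NVT sampling) and a simplified FIRE minimizer on toy problems. It is not the engine's implementation. JAX-MD provides its own versions (for example `jax_md.simulate.nvt_langevin` and `jax_md.minimize.fire_descent`), whose splitting scheme and defaults may differ. The function names, the BAOAB splitting, the semi-implicit Euler steps in FIRE, and the time-step values are assumptions chosen for illustration; the FIRE constants (`n_min`, `f_inc`, `f_dec`, `alpha_start`, `f_alpha`) use commonly cited default values.

```python
import math
import random
from typing import Callable, List, Tuple


def langevin_step(
    x: float, v: float, force: Callable[[float], float],
    dt: float, kT: float, gamma: float, rng: random.Random, mass: float = 1.0,
) -> Tuple[float, float]:
    """One BAOAB Langevin (NVT) step for a single 1D degree of freedom."""
    c = math.exp(-gamma * dt)                     # velocity decay from friction
    sigma = math.sqrt((1.0 - c * c) * kT / mass)  # matching thermal noise
    v += 0.5 * dt * force(x) / mass               # B: half kick
    x += 0.5 * dt * v                             # A: half drift
    v = c * v + sigma * rng.gauss(0.0, 1.0)       # O: friction + noise
    x += 0.5 * dt * v                             # A: half drift
    v += 0.5 * dt * force(x) / mass               # B: half kick
    return x, v


def fire_minimize(
    x: List[float], force: Callable[[List[float]], List[float]],
    dt: float = 0.1, dt_max: float = 0.4, n_steps: int = 1000,
    n_min: int = 5, f_inc: float = 1.1, f_dec: float = 0.5,
    alpha_start: float = 0.1, f_alpha: float = 0.99,
) -> List[float]:
    """Simplified FIRE with unit masses and semi-implicit Euler steps."""
    v = [0.0] * len(x)
    alpha, since_reset = alpha_start, 0
    for _ in range(n_steps):
        f = force(x)
        power = sum(fi * vi for fi, vi in zip(f, v))
        if power < 0.0:
            # Moving uphill: stop, shrink the time step, reset the mixing.
            v = [0.0] * len(x)
            dt *= f_dec
            alpha, since_reset = alpha_start, 0
        else:
            # Steer the velocity toward the force direction.
            f_norm = math.sqrt(sum(fi * fi for fi in f))
            v_norm = math.sqrt(sum(vi * vi for vi in v))
            if f_norm > 0.0:
                v = [(1.0 - alpha) * vi + alpha * v_norm * fi / f_norm
                     for vi, fi in zip(v, f)]
            since_reset += 1
            if since_reset > n_min:
                dt = min(dt * f_inc, dt_max)
                alpha *= f_alpha
        v = [vi + dt * fi for vi, fi in zip(v, f)]  # kick (unit mass)
        x = [xi + dt * vi for xi, vi in zip(x, v)]  # drift
    return x
```

In an MD engine, the same update rules would act on arrays holding all atomic positions and velocities at once rather than on scalars and lists.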

__init__(atoms: Atoms, force_field: ForceField, config: JaxMDSimulationConfig) → None

Constructor that initializes the simulation state and an empty list of loggers, then delegates engine-specific initialization to ._initialize().

Parameters:
  • atoms – The atoms of the system to simulate.

  • force_field – The force field to use in the simulation.

  • config – The configuration/settings of the simulation.
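The constructor flow described above (shared setup of the state and the logger list, followed by an engine-specific hook) is an instance of the template-method pattern. The sketch below shows that pattern with hypothetical stand-in classes (`SketchEngine`, `SketchJaxMDEngine`); the dictionary state and the arguments passed to `_initialize()` are assumptions, not the library's actual signatures.

```python
from typing import Any, Callable, Dict, List


class SketchEngine:
    """Hypothetical base class mirroring the documented constructor flow."""

    def __init__(self, atoms: Any, force_field: Any, config: Any) -> None:
        # Shared setup: an initial simulation state and no loggers yet.
        self.state: Dict[str, Any] = {"step": 0, "atoms": atoms}
        self.loggers: List[Callable[[Dict[str, Any]], None]] = []
        # Backend-specific setup is delegated to a hook method.
        self._initialize(atoms, force_field, config)

    def _initialize(self, atoms: Any, force_field: Any, config: Any) -> None:
        raise NotImplementedError("subclasses provide backend-specific setup")


class SketchJaxMDEngine(SketchEngine):
    def _initialize(self, atoms: Any, force_field: Any, config: Any) -> None:
        # A real JAX-MD engine might build its energy and integrator
        # functions here; this stand-in just stores its inputs.
        self.force_field = force_field
        self.config = config
```

Keeping the shared setup in the base class means every backend starts with the same state and logger bookkeeping, and only the hook differs per backend.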

run() → None

See the documentation of the abstract parent class.

For the JAX-MD backend, the simulation run is divided into episodes so that the MD or minimization steps within each episode can be JIT-compiled for optimal performance.

Important: This method updates the simulation state and calls the attached loggers while it runs.
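The episode structure of run() can be sketched as follows. This is a hypothetical, framework-free illustration: `run_in_episodes`, `steps_per_episode`, and the dictionary state are assumptions, and the real engine's episode length and logging points may differ. The idea is that each episode runs a fixed number of steps, which is the unit one would JIT-compile in JAX (for example with `jax.jit` around `jax.lax.fori_loop`), so the compiled function can be reused across episodes while the loggers run in ordinary Python between episodes.

```python
from typing import Any, Callable, Dict, List

State = Dict[str, Any]  # hypothetical stand-in for SimulationState


def run_in_episodes(
    state: State,
    step_fn: Callable[[State], State],
    num_steps: int,
    steps_per_episode: int,
    loggers: List[Callable[[State], None]],
) -> State:
    """Run num_steps steps as fixed-length episodes, logging after each one."""
    if steps_per_episode <= 0 or num_steps % steps_per_episode != 0:
        raise ValueError("num_steps must be a multiple of steps_per_episode (> 0)")

    def run_episode(s: State) -> State:
        # Fixed-length inner loop: with JAX, this is the part one would jit
        # (e.g. jax.jit around jax.lax.fori_loop) so it compiles only once.
        for _ in range(steps_per_episode):
            s = step_fn(s)
        return s

    for _ in range(num_steps // steps_per_episode):
        state = run_episode(state)
        # Loggers run in plain Python, outside the compiled episode.
        for logger in loggers:
            logger(state)
    return state
```

With 12 steps and 4 steps per episode, the loggers see the state three times, after steps 4, 8, and 12.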

attach_logger(logger: Callable[[SimulationState], None]) → None

Adds a logger to the list of loggers of the simulation engine.

The logger must be a function that takes a single argument, the simulation state, and returns nothing.

Parameters:
  • logger – The logger to add.
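A logger is simply a one-argument callable that returns None. The sketch below uses a hypothetical `FakeState` in place of SimulationState (its fields are assumptions) and a minimal stand-in `LoggerHost` with an `attach_logger` method; with the real engine, the call would be made on the engine instance instead. Because the logger receives only the state, any extra context (here, the output list) is captured in a closure.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class FakeState:
    """Hypothetical stand-in for SimulationState; real fields may differ."""

    step: int
    temperature: float


class LoggerHost:
    """Minimal stand-in exposing the attach_logger contract."""

    def __init__(self) -> None:
        self.loggers: List[Callable[[FakeState], None]] = []

    def attach_logger(self, logger: Callable[[FakeState], None]) -> None:
        self.loggers.append(logger)


def make_temperature_logger(history: List[tuple]) -> Callable[[FakeState], None]:
    """Build a logger that records (step, temperature) into `history`."""

    def logger(state: FakeState) -> None:
        # One argument in, nothing returned: the work is a side effect.
        history.append((state.step, state.temperature))

    return logger


host = LoggerHost()
records: List[tuple] = []
host.attach_logger(make_temperature_logger(records))
```

`functools.partial` would work equally well for binding extra arguments, as long as the resulting callable accepts only the state.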