JAX-MD Simulation Engine
- class mlip.simulation.jax_md.jax_md_simulation_engine.JaxMDSimulationEngine(atoms: Atoms, force_field: ForceField, config: JaxMDSimulationConfig)
Simulation engine handling simulations with the JAX-MD backend.
For MD, the NVT-Langevin algorithm is used. For energy minimization, the FIRE algorithm is used. See the JAX-MD documentation for details on both algorithms.
- __init__(atoms: Atoms, force_field: ForceField, config: JaxMDSimulationConfig) → None
Constructor that initializes the simulation state and an empty list of loggers. Engine-specific initialization is then delegated to ._initialize().
- Parameters:
atoms – The atoms of the system to simulate.
force_field – The force field to use in the simulation.
config – The configuration/settings of the simulation.
- run() → None
See documentation of abstract parent class.
For the JAX-MD backend, the simulation run is divided into episodes so that the MD or minimization steps within each episode can be JIT-compiled for optimal performance.
Important: The state of the simulation is updated and the loggers are called during this function.
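The episode structure can be illustrated with a minimal sketch. This is not the engine's actual implementation: ToyState, run_episode, run_in_episodes, and the parameters num_steps and num_episodes are hypothetical names, the update rule is a toy, and it is an assumption here that loggers are called once after each episode. In a JAX-MD setting, the inner loop of an episode is the part that would be JIT-compiled.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ToyState:
    """Hypothetical stand-in for the simulation state (not the mlip class)."""

    step: int = 0
    position: float = 1.0


def run_episode(state: ToyState, steps: int) -> ToyState:
    # With JAX-MD, this inner loop is what would be jit-compiled.
    # Plain Python is used here so the sketch has no dependencies.
    position = state.position
    for _ in range(steps):
        position *= 0.5  # toy update rule, not a real integrator
    return ToyState(step=state.step + steps, position=position)


def run_in_episodes(
    num_steps: int,
    num_episodes: int,
    loggers: List[Callable[[ToyState], None]],
) -> ToyState:
    if num_steps % num_episodes != 0:
        raise ValueError("num_steps must be divisible by num_episodes")
    steps_per_episode = num_steps // num_episodes

    state = ToyState()
    for _ in range(num_episodes):
        # Many steps run inside one episode; the state is updated once
        # per episode and then handed to every attached logger.
        state = run_episode(state, steps_per_episode)
        for logger in loggers:
            logger(state)
    return state


logged_steps: List[int] = []
final_state = run_in_episodes(
    num_steps=8,
    num_episodes=4,
    loggers=[lambda s: logged_steps.append(s.step)],
)
```

Under these assumptions, loggers see the state only at episode boundaries, so the number of episodes trades logging granularity against the amount of work done per compiled call.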
- attach_logger(logger: Callable[[SimulationState], None]) → None
Adds a logger to the list of loggers of the simulation engine.
The logger function must take exactly one argument, the simulation state, and must not return anything.
- Parameters:
logger – The logger to add.
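A logger that satisfies this contract could look like the following sketch. To keep it runnable without the library, StubState and StubEngine are hypothetical stand-ins for SimulationState and the engine; the step field and the emit helper are assumptions made for illustration and are not part of the documented API.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class StubState:
    """Hypothetical stand-in for SimulationState; the field is an assumption."""

    step: int = 0


class StubEngine:
    """Hypothetical stand-in that mimics only the attach_logger contract."""

    def __init__(self) -> None:
        self.state = StubState()
        self.loggers: List[Callable[[StubState], None]] = []

    def attach_logger(self, logger: Callable[[StubState], None]) -> None:
        self.loggers.append(logger)

    def emit(self) -> None:
        # Hypothetical helper: call every logger with the current state,
        # as the real engine is described to do during run().
        for logger in self.loggers:
            logger(self.state)


seen_steps: List[int] = []


def record_step(state: StubState) -> None:
    # Takes exactly one argument (the state) and returns nothing.
    seen_steps.append(state.step)


engine = StubEngine()
engine.attach_logger(record_step)
engine.emit()
```

Because a logger returns nothing, anything it collects has to be stored through a side effect, such as appending to a list or writing to a file.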