Training state

class mlip.training.training_state.TrainingState(params: dict[str, dict[str, jax.Array | dict]], optimizer_state: jax.Array | numpy.ndarray | numpy.bool_ | numpy.number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], ema_state: mlip.training.ema.EMAState, num_steps: jax.Array, acc_steps: jax.Array, key: jax.Array, extras: dict | None = <factory>)

Represents the state of training.

params

Model parameters.

Type:

dict[str, dict[str, jax.Array | dict]]
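The parameter type is a nested dictionary: outer keys name submodules, inner keys name arrays, and an inner value may itself be a further dict. The sketch below is hypothetical: the module names (`embedding`, `mlp`) and the helper `count_leaf_values` are invented for illustration, and plain nested lists stand in for `jax.Array` leaves so it runs without JAX.

```python
# Hypothetical parameter tree mirroring dict[str, dict[str, jax.Array | dict]].
# Nested lists stand in for jax.Array leaves so the sketch runs without JAX.
from typing import Any


def count_leaf_values(tree: Any) -> int:
    """Count scalar entries in a nested dict whose leaves are (nested) lists."""
    if isinstance(tree, dict):
        return sum(count_leaf_values(v) for v in tree.values())
    if isinstance(tree, list):
        return sum(count_leaf_values(v) for v in tree)
    return 1  # a single scalar leaf


params = {
    "embedding": {"table": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]},
    "mlp": {
        "kernel": [[1.0, 0.0], [0.0, 1.0]],
        "bias": [0.0, 0.0],
        "norm": {"scale": [1.0, 1.0]},  # a dict nested one level deeper
    },
}

print(count_leaf_values(params))  # 6 + 4 + 2 + 2 = 14
```

With real arrays and JAX installed, `sum(x.size for x in jax.tree_util.tree_leaves(params))` performs the same count.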

optimizer_state

State of the optimizer.

Type:

jax.Array | numpy.ndarray | numpy.bool_ | numpy.number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]

ema_state

Exponential moving average (EMA) state.

Type:

mlip.training.ema.EMAState
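An exponential moving average keeps a running copy of each value, updated as ema ← decay · ema + (1 − decay) · value. The sketch below shows only that generic update rule; it is not the `EMAState` API, and the function name `ema_update` and the `decay` value are assumptions. Some EMA implementations also apply bias correction or a warm-up schedule for `decay`; whether `EMAState` does so is not stated here.

```python
# Generic exponential moving average over a flat dict of parameters.
# This illustrates the update rule only; it is NOT the mlip EMAState API.


def ema_update(ema: dict[str, float], params: dict[str, float], decay: float) -> dict[str, float]:
    """Return ema_new[k] = decay * ema[k] + (1 - decay) * params[k] for each key."""
    return {k: decay * ema[k] + (1.0 - decay) * params[k] for k in ema}


# Starting from 0 and feeding a constant 1.0, the average approaches 1.0
# as 1 - decay**n after n updates.
ema = {"w": 0.0}
for step_params in ({"w": 1.0}, {"w": 1.0}, {"w": 1.0}):
    ema = ema_update(ema, step_params, decay=0.9)

print(ema)
```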

num_steps

The number of training steps taken.

Type:

jax.Array

acc_steps

The number of gradient accumulation steps taken; resets to 0 after each optimizer step.
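The interplay between `acc_steps` and `num_steps` can be sketched as a small counter update. This is a hypothetical illustration: the `Counters` class, the `advance` helper and the `accumulation` parameter are invented, plain ints replace the `jax.Array` counters, and it assumes `num_steps` counts optimizer updates rather than individual micro-batches, which the description above does not state explicitly.

```python
# Hypothetical sketch of how num_steps and acc_steps could evolve under
# gradient accumulation; plain ints stand in for the jax.Array counters.
from dataclasses import dataclass


@dataclass
class Counters:
    num_steps: int = 0  # optimizer updates applied (assumption)
    acc_steps: int = 0  # micro-batches accumulated since the last update


def advance(c: Counters, accumulation: int) -> tuple[Counters, bool]:
    """Record one micro-batch; return new counters and whether an update fired."""
    acc = c.acc_steps + 1
    if acc == accumulation:
        # Enough gradients accumulated: apply the optimizer step and reset.
        return Counters(num_steps=c.num_steps + 1, acc_steps=0), True
    return Counters(num_steps=c.num_steps, acc_steps=acc), False


c = Counters()
updates = []
for _ in range(7):
    c, applied = advance(c, accumulation=3)
    updates.append(applied)

print(c, updates)
# Counters(num_steps=2, acc_steps=1) [False, False, True, False, False, True, False]
```

Inside jit-compiled JAX code the branch would typically be written with `jax.lax.cond` or `jnp.where` instead of a Python `if`, since the counters are traced arrays.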

Type:

jax.Array

key

Pseudo-random number generator key.

Type:

jax.Array

extras

Auxiliary information stored as a dictionary.

Type:

dict | None
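The `= <factory>` default in the signature is how Sphinx renders a dataclass field declared with `default_factory`, so each state gets its own fresh `extras` dict rather than sharing one mutable default (a plain `= {}` default is rejected by `dataclasses` with a `ValueError`). The stand-in below is hypothetical: `TrainingStateSketch` reuses the field names of `TrainingState` but simplifies the types to plain Python values, and `frozen=True` is a choice of this sketch, not a claim about the real class.

```python
# Hypothetical stand-in with the same field names as TrainingState. Types are
# simplified (no jax.Array / EMAState) so the sketch runs without JAX or mlip.
from __future__ import annotations

import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)  # frozen is an assumption of this sketch
class TrainingStateSketch:
    params: dict[str, dict[str, Any]]
    optimizer_state: Any
    ema_state: Any
    num_steps: int
    acc_steps: int
    key: Any
    # Rendered as "= <factory>": every instance receives a new empty dict.
    extras: dict | None = dataclasses.field(default_factory=dict)


a = TrainingStateSketch(params={}, optimizer_state=None, ema_state=None,
                        num_steps=0, acc_steps=0, key=None)
b = TrainingStateSketch(params={}, optimizer_state=None, ema_state=None,
                        num_steps=0, acc_steps=0, key=None)
a.extras["note"] = "only on a"
print(b.extras)  # {}

# Functional update: build a new state instead of mutating the old one.
a2 = dataclasses.replace(a, num_steps=a.num_steps + 1)
print(a.num_steps, a2.num_steps)  # 0 1
```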