Training state¶
- class mlip.training.training_state.TrainingState(params: dict[str, dict[str, jax.Array | dict]], optimizer_state: jax.Array | numpy.ndarray | numpy.bool_ | numpy.number | Iterable[ArrayTree] | Mapping[Any, ArrayTree], ema_state: mlip.training.ema.EMAState, num_steps: jax.Array, acc_steps: jax.Array, key: jax.Array, extras: dict | None = <factory>)¶
Represents the state of training; a usage sketch follows the attribute list below.
- params¶
Model parameters.
- Type:
dict[str, dict[str, jax.Array | dict]]
- optimizer_state¶
State of the optimizer.
- Type:
jax.Array | numpy.ndarray | numpy.bool_ | numpy.number | Iterable[ArrayTree] | Mapping[Any, ArrayTree]
- ema_state¶
Exponential moving average (EMA) state.
- Type:
EMAState
- num_steps¶
The number of training steps taken.
- Type:
jax.Array
- acc_steps¶
The number of gradient accumulation steps taken; resets to 0 after each optimizer step.
- Type:
jax.Array
- key¶
Pseudo-random number generator key.
- Type:
jax.Array
- extras¶
Additional auxiliary information in the form of a dictionary.
- Type:
dict | None
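The sketch below shows one way to read the fields of a TrainingState inside a training loop. It is illustrative only: the describe_and_refresh helper is hypothetical and not part of mlip, and treating TrainingState as a regular dataclass (suggested by the <factory> default in the signature) is an assumption.

```python
import dataclasses

import jax

from mlip.training.training_state import TrainingState


def describe_and_refresh(state: TrainingState) -> TrainingState:
    """Hypothetical helper (not part of mlip): inspect a state and refresh its PRNG key."""
    # params is a nested dict whose leaves are jax.Array objects.
    n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(state.params))
    print(f"number of parameter elements: {n_params}")

    # num_steps and acc_steps are scalar jax.Array counters.
    print(f"optimizer steps taken: {int(state.num_steps)}")
    print(f"accumulation steps since last optimizer step: {int(state.acc_steps)}")

    # key is a PRNG key; split it before drawing fresh randomness.
    new_key, _ = jax.random.split(state.key)

    # Assumption: TrainingState behaves like a regular dataclass, so fields can be
    # swapped immutably with dataclasses.replace.
    return dataclasses.replace(state, key=new_key)
```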