Exponential Moving Average (EMA)
- mlip.training.ema.exponentially_moving_average(decay: float = 0.99) → EMAParameterTransformation
Creates an exponential moving average (EMA) transformation of the model parameters.
- Parameters:
decay – The decay factor for the EMA. Defaults to 0.99.
- Returns:
A named tuple containing the init and update functions of the EMA.
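To make the init/update contract concrete, here is a minimal sketch in plain Python. It is not the library's implementation. It assumes the standard recurrence ema = decay * ema + (1 - decay) * params with no bias correction, that init starts the average from the initial parameters, and that update takes (params, state) in that order. The names sketch_exponential_moving_average, SketchEMAState, SketchEMATransformation and the tree_map helper are all hypothetical. In real JAX code, jax.tree_util.tree_map would handle the leaf-wise mapping over jax.Array leaves.

```python
from typing import Any, Callable, NamedTuple

# Hypothetical parameter tree: nested dicts whose float leaves stand in for jax.Array leaves.
Params = dict


def tree_map(fn: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    """Apply fn leaf-wise over nested dicts with matching keys (stand-in for jax.tree_util.tree_map)."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, tree[k], *(r[k] for r in rest)) for k in tree}
    return fn(tree, *rest)


class SketchEMAState(NamedTuple):
    """Mirrors the fields of EMAState: the averaged parameters and a step counter."""
    params_ema: Params
    step: int


class SketchEMATransformation(NamedTuple):
    """Hypothetical analogue of EMAParameterTransformation: a pair of pure functions."""
    init: Callable[[Params], SketchEMAState]
    update: Callable[[Params, SketchEMAState], SketchEMAState]


def sketch_exponential_moving_average(decay: float = 0.99) -> SketchEMATransformation:
    def init(params: Params) -> SketchEMAState:
        # Assumption: the average starts from a copy of the initial parameters at step 0.
        return SketchEMAState(params_ema=tree_map(lambda p: p, params), step=0)

    def update(params: Params, state: SketchEMAState) -> SketchEMAState:
        # Standard EMA recurrence applied to every leaf (assumed, no bias correction):
        # ema <- decay * ema + (1 - decay) * params
        new_ema = tree_map(
            lambda e, p: decay * e + (1.0 - decay) * p, state.params_ema, params
        )
        # Return a new state instead of mutating the old one.
        return SketchEMAState(params_ema=new_ema, step=state.step + 1)

    return SketchEMATransformation(init=init, update=update)


# Usage: decay=0.5 keeps the arithmetic exact.
ema = sketch_exponential_moving_average(decay=0.5)
state = ema.init({"layer": {"w": 0.0}})
state = ema.update({"layer": {"w": 1.0}}, state)
print(state.params_ema, state.step)  # {'layer': {'w': 0.5}} 1
```

Both functions are pure and return a fresh state. That design makes this kind of transformation easy to call inside a jax.jit-compiled training step, with the state threaded through the loop alongside the optimizer state.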
- class mlip.training.ema.EMAState(params_ema: dict[str, dict[str, Array | dict]], step: int)
Container for the exponential moving average (EMA) state.
- params_ema
Exponential moving average of the parameters.
- Type:
dict[str, dict[str, jax.Array | dict]]
- step
The current step.
- Type:
int
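To show what params_ema represents after step = t updates, the recurrence can be unrolled. This assumes the plain EMA update with decay d and no bias correction, which this page does not confirm. Here θ_k denotes the parameters passed at update k and θ̄_0 is the initial value of params_ema:

```latex
\bar{\theta}_t = d\,\bar{\theta}_{t-1} + (1-d)\,\theta_t
\quad\Longrightarrow\quad
\bar{\theta}_t = d^{t}\,\bar{\theta}_0 + (1-d)\sum_{k=1}^{t} d^{\,t-k}\,\theta_k
```

The weights sum to one, because (1 - d) times the geometric sum of d^(t-k) for k = 1, …, t equals 1 - d^t. With the default d = 0.99, parameters from 100 updates earlier carry about 0.99^100 ≈ 0.37 of the weight of the newest ones. As a result, the average effectively tracks roughly the last 1/(1 - d) = 100 steps.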