Exponential Moving Average (EMA)

mlip.training.ema.exponentially_moving_average(decay: float = 0.99) → EMAParameterTransformation

Creates an exponential moving average (EMA) transformation for model parameters.

Parameters:

decay – The decay factor for the EMA. Defaults to 0.99.

Returns:

An EMAParameterTransformation named tuple containing the init and update functions for the EMA.
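The docstring does not spell out the update rule. Assuming the standard EMA recurrence with a zero-initialised average (the situation a separate debiasing step is normally meant to correct), and writing β for decay and θ_t for the parameters after training step t, the average evolves as follows. The unrolled form shows that the weights only sum to 1 − β^t, so early averages are biased toward zero.

```latex
\bar{\theta}_0 = 0, \qquad
\bar{\theta}_t = \beta\,\bar{\theta}_{t-1} + (1-\beta)\,\theta_t
\;\Longrightarrow\;
\bar{\theta}_t = (1-\beta)\sum_{k=1}^{t} \beta^{\,t-k}\,\theta_k,
\qquad
(1-\beta)\sum_{k=1}^{t} \beta^{\,t-k} = 1-\beta^{t}.
```

Under this assumption, the default decay = 0.99 makes the average dominated by roughly the last 1/(1 − 0.99) = 100 updates.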

class mlip.training.ema.EMAState(params_ema: dict[str, dict[str, Array | dict]], step: int)

Container for the exponential moving average (EMA) state.

params_ema

Exponential moving average of the parameters.

Type:

dict[str, dict[str, jax.Array | dict]]

step

The current step.

Type:

int
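The documented fields pair a nested parameter dict with a step counter. If EMAState is a NamedTuple (an assumption; the signature alone does not say), it is automatically a JAX pytree, so it can be flattened and passed through jax.jit like the parameters themselves. The sketch below uses a hypothetical stand-in class, SketchEMAState, with the same two fields and a made-up "encoder" parameter dict.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SketchEMAState(NamedTuple):
    # Hypothetical stand-in mirroring the documented EMAState fields.
    params_ema: dict  # nested dict of jax.Array leaves, e.g. {"module": {"kernel": ...}}
    step: int


state = SketchEMAState(
    params_ema={"encoder": {"kernel": jnp.zeros((3, 4)), "bias": jnp.zeros((4,))}},
    step=0,
)

# NamedTuples are pytrees: the state flattens into its array leaves plus the step.
leaves = jax.tree_util.tree_leaves(state)
print(len(leaves))  # 3: bias, kernel, step


# Because it is a pytree, the whole state can cross a jax.jit boundary.
@jax.jit
def bump_step(s: SketchEMAState) -> SketchEMAState:
    return s._replace(step=s.step + 1)


new_state = bump_step(state)
```

Inside jax.jit the integer step is traced, so after the call it comes back as a scalar jax.Array rather than a Python int.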

class mlip.training.ema.EMAParameterTransformation(init: Callable, update: Callable)

Container for parameter transformation functions.

init

Function that initializes the EMA state.

Type:

Callable

update

Function that updates the EMA state.

Type:

Callable
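An init/update pair of this kind is typically driven from a training loop: call init once on the initial parameters, then call update after every optimizer step. The sketch below is a minimal self-contained version of that pattern. The names SketchEMAState, SketchTransformation and ema_sketch are hypothetical, and the argument order update(params, state), the zero initialisation, and the toy loss are assumptions for illustration, not the documented mlip signatures.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class SketchEMAState(NamedTuple):
    # Hypothetical stand-in for EMAState.
    params_ema: dict
    step: int


class SketchTransformation(NamedTuple):
    # Hypothetical stand-in for EMAParameterTransformation: two pure functions.
    init: Callable
    update: Callable


def ema_sketch(decay: float = 0.99) -> SketchTransformation:
    def init(params) -> SketchEMAState:
        # Zero-initialised average (assumption; debiasing corrects for it).
        zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
        return SketchEMAState(params_ema=zeros, step=0)

    def update(params, state: SketchEMAState) -> SketchEMAState:
        # Leaf-wise recurrence: ema <- decay * ema + (1 - decay) * params.
        new_ema = jax.tree_util.tree_map(
            lambda e, p: decay * e + (1.0 - decay) * p, state.params_ema, params
        )
        return SketchEMAState(params_ema=new_ema, step=state.step + 1)

    return SketchTransformation(init=init, update=update)


# Toy loop: gradient descent on sum(w**2) while tracking an EMA of the parameters.
def loss(params):
    return jnp.sum(params["model"]["w"] ** 2)


params = {"model": {"w": jnp.array([1.0, -2.0])}}
ema = ema_sketch(decay=0.9)
ema_state = ema.init(params)
for _ in range(5):
    grads = jax.grad(loss)(params)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    ema_state = ema.update(params, ema_state)  # once per optimizer step
```

Keeping init and update as pure functions in a named tuple (the same shape as optax's gradient transformations) lets the EMA state live alongside the optimizer state and be checkpointed or jitted with it.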

mlip.training.ema.get_debiased_params(state: EMAState, decay: float) → dict[str, dict[str, Array | dict]]

Gets the debiased parameters from the EMAState.

Parameters:
  • state – The current state.

  • decay – The decay factor for the EMA.

Returns:

The debiased parameters.
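Taking decay separately from the state is consistent with the standard zero-debiasing correction used by Adam: a zero-initialised EMA after step updates carries total weight 1 − decay^step, so dividing by that factor removes the bias toward zero. Whether get_debiased_params applies exactly this formula is an assumption. The hypothetical helper debias_sketch below illustrates it on constant parameters, where the debiased average should recover the parameters themselves.

```python
import jax
import jax.numpy as jnp


def debias_sketch(params_ema, step, decay):
    # Hypothetical helper. A zero-initialised EMA after `step` updates has total
    # weight 1 - decay**step; dividing rescales it to an unbiased average.
    # Requires step >= 1 (at step 0 the correction factor is zero).
    correction = 1.0 - decay**step
    return jax.tree_util.tree_map(lambda e: e / correction, params_ema)


decay = 0.99
num_steps = 3
theta = {"layer": {"w": jnp.array([2.0, -1.0])}}  # parameters held constant

# Run the plain EMA recurrence from zero for a few steps.
ema = jax.tree_util.tree_map(jnp.zeros_like, theta)
for _ in range(num_steps):
    ema = jax.tree_util.tree_map(lambda e, p: decay * e + (1 - decay) * p, ema, theta)

debiased = debias_sketch(ema, step=num_steps, decay=decay)
print(ema["layer"]["w"])       # about 0.0297 * [2, -1]: heavily biased toward zero
print(debiased["layer"]["w"])  # approximately [2, -1]
```

The correction matters most early in training; once decay**step is negligible the raw and debiased averages coincide.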