Optimizer

mlip.training.optimizer.get_default_mlip_optimizer(config: OptimizerConfig | None = None) → GradientTransformation

Get a default optimizer for training MLIP models.

This is a specialized optimizer setup that originated in the MACE torch repo: https://github.com/ACEsuit/mace. It can be customized to some extent via the OptimizerConfig.

This optimizer is based on the optax.amsgrad base optimizer and optionally adds a weight decay transform (recommended only for MACE), a gradient clipping step, and a learning rate schedule with an optional warm-up period. Furthermore, gradient accumulation is supported if requested.

The learning rate schedule works as follows: first, during the warm-up steps, the learning rate increases linearly from the “initial” to the “peak” learning rate. After that, the learning rate transitions linearly from the “peak” to the “final” learning rate over the transition steps.
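As an illustration, a schedule of this shape can be sketched with plain optax primitives. This is only a sketch: the values mirror the config defaults, and the library's actual internal implementation may differ in detail.

    import optax

    # Values mirroring the OptimizerConfig defaults (for illustration only).
    init_lr, peak_lr, final_lr = 0.01, 0.01, 0.01
    warmup_steps, transition_steps = 4000, 360000

    # Warm-up phase: linear increase from the initial to the peak learning rate.
    warmup = optax.linear_schedule(
        init_value=init_lr, end_value=peak_lr, transition_steps=warmup_steps
    )
    # Main phase: linear transition from the peak to the final learning rate.
    main = optax.linear_schedule(
        init_value=peak_lr, end_value=final_lr, transition_steps=transition_steps
    )
    # Join the two phases at the end of the warm-up period.
    schedule = optax.join_schedules(schedules=[warmup, main], boundaries=[warmup_steps])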

See the optimizer config’s documentation for how to customize this default MLIP optimizer.

This function internally uses get_mlip_optimizer_chain_with_flexible_base_optimizer() which can also be used directly to build an analogous optimizer chain but with a different base optimizer.

Parameters:

config – The optimizer config. Default is None, in which case the default config is used.

Returns:

The default optimizer.
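A minimal usage sketch: the import paths and field names come from this documentation, while the parameter and gradient pytrees are placeholders. The returned optimizer behaves like any other optax GradientTransformation.

    import jax.numpy as jnp
    import optax

    from mlip.training.optimizer import get_default_mlip_optimizer
    from mlip.training.optimizer_config import OptimizerConfig

    # Customize a few settings; everything else keeps its default value.
    config = OptimizerConfig(init_learning_rate=1e-3, peak_learning_rate=1e-2)
    optimizer = get_default_mlip_optimizer(config)

    # Standard optax usage with placeholder parameters and gradients.
    params = {"w": jnp.zeros((3, 3))}
    grads = {"w": jnp.ones((3, 3))}
    opt_state = optimizer.init(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)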

mlip.training.optimizer.get_mlip_optimizer_chain_with_flexible_base_optimizer(base_optimizer_factory_fun: Callable[[float], GradientTransformation], config: OptimizerConfig) → GradientTransformation

Initializes an optax-based optimizer as a chain built on top of a base optimizer, e.g., optax.amsgrad.

The chain is constructed from a base optimizer factory function, for example, optax.adam, which must accept the learning rate as its single parameter.

The return value is a full optimizer pipeline consisting of gradient clipping, warm-up, etc.

Parameters:
  • base_optimizer_factory_fun – The base optimizer factory function, which must accept the learning rate as its single parameter.

  • config – The optimizer pydantic config.

Returns:

The full optimizer pipeline constructed based on the provided base optimizer function.
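For example, a sketch of plugging in optax.adam as the base optimizer. Since optax.adam accepts the learning rate as a single positional argument, it can serve directly as the factory function; the config values shown are arbitrary illustrations.

    import optax

    from mlip.training.optimizer import (
        get_mlip_optimizer_chain_with_flexible_base_optimizer,
    )
    from mlip.training.optimizer_config import OptimizerConfig

    config = OptimizerConfig(grad_norm=100.0, warmup_steps=2000)

    # optax.adam(learning_rate) returns a GradientTransformation, so it
    # satisfies the required factory signature.
    optimizer = get_mlip_optimizer_chain_with_flexible_base_optimizer(
        base_optimizer_factory_fun=optax.adam,
        config=config,
    )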

class mlip.training.optimizer_config.OptimizerConfig(*, apply_weight_decay_mask: bool = True, weight_decay: Annotated[float, Ge(ge=0)] = 0.0, grad_norm: Annotated[float, Ge(ge=0)] = 500, num_gradient_accumulation_steps: Annotated[int, Gt(gt=0)] = 1, init_learning_rate: Annotated[float, Gt(gt=0)] = 0.01, peak_learning_rate: Annotated[float, Gt(gt=0)] = 0.01, final_learning_rate: Annotated[float, Gt(gt=0)] = 0.01, warmup_steps: Annotated[int, Ge(ge=0)] = 4000, transition_steps: Annotated[int, Ge(ge=0)] = 360000)

Pydantic config holding all settings that are relevant for the optimizer.

apply_weight_decay_mask

Whether to apply a weight decay mask. If set to False, weight decay is applied to all parameters. If set to True (default), weight decay is applied only to the parameters of the model blocks “linear_down” and “SymmetricContraction”. These blocks exist only in MACE models, and it is recommended to set this option to True for MACE. If it is set to True but neither of these blocks exists in the model (as for ViSNet or NequIP), weight decay is applied to all parameters.

Type:

bool
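In optax terms, such a mask is typically passed to a weight decay transform such as optax.add_decayed_weights. The sketch below only illustrates the mechanism with a hypothetical path-based predicate (weight_decay_mask is not a library function); it does not reproduce the library's exact block matching or the documented fallback to decaying all parameters.

    import jax
    import optax

    def weight_decay_mask(params):
        # Hypothetical mask: True only for parameters whose path mentions
        # the MACE-specific "linear_down" or "SymmetricContraction" blocks.
        def keep(path, _):
            path_str = jax.tree_util.keystr(path)
            return "linear_down" in path_str or "SymmetricContraction" in path_str

        return jax.tree_util.tree_map_with_path(keep, params)

    # Masked weight decay transform, used as one link in a larger optax chain.
    weight_decay_transform = optax.add_decayed_weights(
        weight_decay=5e-7, mask=weight_decay_mask
    )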

weight_decay

The weight decay coefficient (default is 0.0).

Type:

float

grad_norm

Maximum gradient norm used for gradient clipping (default is 500).

Type:

float

num_gradient_accumulation_steps

Number of gradient steps to accumulate before taking an optimizer step. Default is 1.

Type:

int
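In optax, gradient accumulation of this kind is commonly expressed via optax.MultiSteps. The following is a hedged sketch of the idea, not necessarily the library's exact wiring:

    import optax

    base = optax.amsgrad(learning_rate=0.01)

    # Accumulate gradients over 4 update calls before applying one real step.
    multi_step = optax.MultiSteps(base, every_k_schedule=4)
    optimizer = multi_step.gradient_transformation()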

init_learning_rate

Initial learning rate (default is 0.01).

Type:

float

peak_learning_rate

Peak learning rate (default is 0.01).

Type:

float

final_learning_rate

Final learning rate (default is 0.01).

Type:

float

warmup_steps

Number of optimizer warm-up steps (default is 4000). Check optax’s linear_schedule() function for more info.

Type:

int

transition_steps

Number of optimizer transition steps (default is 360000). Check optax’s linear_schedule() function for more info.

Type:

int

__init__(**data: Any) → None

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.
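A brief construction example using the documented field names (the specific values are arbitrary illustrations):

    from mlip.training.optimizer_config import OptimizerConfig

    config = OptimizerConfig(
        apply_weight_decay_mask=True,
        weight_decay=5e-7,
        grad_norm=100.0,
        num_gradient_accumulation_steps=1,
        init_learning_rate=1e-3,
        peak_learning_rate=1e-2,
        final_learning_rate=1e-4,
        warmup_steps=2000,
        transition_steps=100000,
    )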