Optimizer¶
- mlip.training.optimizer.get_default_mlip_optimizer(config: OptimizerConfig | None = None) → GradientTransformation¶
Get a default optimizer for training MLIP models.
This is a specialized optimizer setup that originated in the MACE torch repository: https://github.com/ACEsuit/mace. It is customizable to an extent via the OptimizerConfig.
This optimizer is based on the optax.amsgrad base optimizer and optionally adds a weight decay transform to it (recommended only for MACE), a gradient clipping step, and a learning rate schedule with an optional warm-up period. Furthermore, we allow for gradient accumulation if requested.
The learning rate schedule works as follows: first, there is a warm-up period during which the learning rate increases linearly from the “initial” to the “peak” learning rate. After that, the learning rate transitions linearly from the “peak” to the “final” learning rate over the transition steps.
See the optimizer config’s documentation for how to customize this default MLIP optimizer.
This function internally uses get_mlip_optimizer_chain_with_flexible_base_optimizer(), which can also be used directly to build an analogous optimizer chain with a different base optimizer.
- Parameters:
config – The optimizer config. Default is None, which means the pure default config is used.
- Returns:
The default optimizer.
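A minimal usage sketch (the import paths are assumed from the module names above, and the training-loop integration shown in the comments is only indicative):

from mlip.training.optimizer import get_default_mlip_optimizer
from mlip.training.optimizer_config import OptimizerConfig

# Customize the schedule: linear warm-up from the initial to the peak learning
# rate, then a linear transition from the peak to the final learning rate.
config = OptimizerConfig(
    init_learning_rate=1e-4,
    peak_learning_rate=1e-2,
    final_learning_rate=1e-3,
    warmup_steps=2000,
    transition_steps=100_000,
)
optimizer = get_default_mlip_optimizer(config)

# The result is a standard optax GradientTransformation:
# opt_state = optimizer.init(params)
# updates, opt_state = optimizer.update(grads, opt_state, params)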
- mlip.training.optimizer.get_mlip_optimizer_chain_with_flexible_base_optimizer(base_optimizer_factory_fun: Callable[[float], GradientTransformation], config: OptimizerConfig) → GradientTransformation¶
Initializes an optimizer (based on optax) as a chain derived from a base optimizer, e.g., optax.amsgrad.
The initialization happens from a base optimizer factory function, for example, optax.adam. This base optimizer function must be able to take in the learning rate as a single parameter.
The return value is a full optimizer pipeline consisting of gradient clipping, warm-up, etc.
- Parameters:
base_optimizer_factory_fun – The base optimizer function which must be able to take in the learning rate as a single parameter.
config – The optimizer pydantic config.
- Returns:
The full optimizer pipeline constructed based on the provided base optimizer function.
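For illustration, a sketch of building the same kind of chain around a different base optimizer, here optax.adam, which accepts the learning rate as its single required argument (import paths assumed from the module names above):

import optax
from mlip.training.optimizer import get_mlip_optimizer_chain_with_flexible_base_optimizer
from mlip.training.optimizer_config import OptimizerConfig

config = OptimizerConfig(apply_weight_decay_mask=False)

# Same chain as the default optimizer (gradient clipping, learning rate
# schedule, optional weight decay, gradient accumulation), but with Adam
# instead of AMSGrad as the base optimizer.
optimizer = get_mlip_optimizer_chain_with_flexible_base_optimizer(
    base_optimizer_factory_fun=optax.adam,
    config=config,
)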
- class mlip.training.optimizer_config.OptimizerConfig(*, apply_weight_decay_mask: bool = True, weight_decay: Annotated[float, Ge(ge=0)] = 0.0, grad_norm: Annotated[float, Ge(ge=0)] = 500, num_gradient_accumulation_steps: Annotated[int, Gt(gt=0)] = 1, init_learning_rate: Annotated[float, Gt(gt=0)] = 0.01, peak_learning_rate: Annotated[float, Gt(gt=0)] = 0.01, final_learning_rate: Annotated[float, Gt(gt=0)] = 0.01, warmup_steps: Annotated[int, Ge(ge=0)] = 4000, transition_steps: Annotated[int, Ge(ge=0)] = 360000)¶
Pydantic config holding all settings that are relevant for the optimizer.
- apply_weight_decay_mask¶
Whether to apply a weight decay mask. If set to False, weight decay is applied to all parameters. If set to True (default), only the parameters of the model blocks “linear_down” and “SymmetricContraction” are assigned a weight decay. These blocks only exist for MACE models, and it is recommended to set this to True for MACE. If it is set to True but neither of these blocks exists in the model (as for ViSNet or NequIP), weight decay is applied to all parameters.
- Type:
bool
- weight_decay¶
The weight decay with a default of zero.
- Type:
float
- grad_norm¶
Maximum gradient norm used for gradient clipping (default is 500).
- Type:
float
- num_gradient_accumulation_steps¶
Number of gradient steps to accumulate before taking an optimizer step. Default is 1.
- Type:
int
- init_learning_rate¶
Initial learning rate (default is 0.01).
- Type:
float
- peak_learning_rate¶
Peak learning rate (default is 0.01).
- Type:
float
- final_learning_rate¶
Final learning rate (default is 0.01).
- Type:
float
- warmup_steps¶
Number of optimizer warm-up steps (default is 4000). Check optax’s linear_schedule() function for more info.
- Type:
int
- transition_steps¶
Number of optimizer transition steps (default is 360000). Check optax’s linear_schedule() function for more info.
- Type:
int
- __init__(**data: Any) → None¶
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError (pydantic_core.ValidationError) if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
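To make the learning rate schedule concrete, below is a standalone sketch of an equivalent schedule built directly with optax (illustrative only, not the library’s internal code; the values are the config defaults):

import optax

init_lr, peak_lr, final_lr = 0.01, 0.01, 0.01
warmup_steps, transition_steps = 4000, 360_000

lr_schedule = optax.join_schedules(
    schedules=[
        # Warm-up: linear ramp from the initial to the peak learning rate.
        optax.linear_schedule(init_lr, peak_lr, transition_steps=warmup_steps),
        # Main phase: linear transition from the peak to the final learning rate.
        optax.linear_schedule(peak_lr, final_lr, transition_steps=transition_steps),
    ],
    boundaries=[warmup_steps],
)

# Evaluate the schedule at the start, at the end of warm-up, and at the end of
# the transition phase.
print(lr_schedule(0), lr_schedule(warmup_steps), lr_schedule(warmup_steps + transition_steps))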