Optimizer¶
- mlip.training.optimizer.get_default_mlip_optimizer(config: OptimizerConfig | None = None) → GradientTransformation¶
Get a default optimizer for training MLIP models.
This is a specialized optimizer setup that originated in the MACE torch repository: https://github.com/ACEsuit/mace. It is customizable to an extent via the OptimizerConfig.
This optimizer is based on the optax.amsgrad base optimizer and adds to it an optional weight decay transform (recommended only for MACE), a gradient clipping step, and a learning rate schedule with an optional warm-up period. Furthermore, gradient accumulation is supported if requested.
The learning rate schedule works as follows: first, there is a period of warm-up steps during which the learning rate increases linearly from the “initial” to the “peak” learning rate. After that, the learning rate transitions linearly from the “peak” to the “final” learning rate over the transition steps.
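As an illustration, a schedule with this shape can be built from plain optax primitives. This is a sketch of the described behavior, not necessarily the library’s exact internals; the learning-rate values are hypothetical and the step counts are the defaults from OptimizerConfig:

import optax

init_lr, peak_lr, final_lr = 1e-3, 1e-2, 1e-3      # hypothetical values
warmup_steps, transition_steps = 4000, 360_000     # OptimizerConfig defaults

# Linear warm-up from "initial" to "peak", then a linear transition
# from "peak" to "final" over the transition steps.
learning_rate_schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(init_lr, peak_lr, transition_steps=warmup_steps),
        optax.linear_schedule(peak_lr, final_lr, transition_steps=transition_steps),
    ],
    boundaries=[warmup_steps],
)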
See the optimizer config’s documentation for how to customize this default MLIP optimizer.
This function internally uses get_mlip_optimizer_chain_with_flexible_base_optimizer(), which can also be used directly to build an analogous optimizer chain but with a different base optimizer.
- Parameters:
config – The optimizer config. Default is None, which leads to the pure default config being used.
- Returns:
The default optimizer.
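A minimal usage sketch, assuming a typical JAX/optax training setup (params and grads below are placeholder pytrees, not real model parameters):

import jax.numpy as jnp
import optax

from mlip.training.optimizer import get_default_mlip_optimizer
from mlip.training.optimizer_config import OptimizerConfig

config = OptimizerConfig(init_learning_rate=1e-3, peak_learning_rate=1e-2)
optimizer = get_default_mlip_optimizer(config)  # passing None uses the pure defaults

params = {"w": jnp.ones((3, 3))}        # placeholder parameter pytree
opt_state = optimizer.init(params)

grads = {"w": jnp.full((3, 3), 0.1)}    # placeholder gradients
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)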
- mlip.training.optimizer.get_mlip_optimizer_chain_with_flexible_base_optimizer(base_optimizer_factory_fun: Callable[[float], GradientTransformation], config: OptimizerConfig) → GradientTransformation¶
Initializes an optimizer (based on optax) as a chain that is derived from a base optimizer class, e.g., optax.amsgrad.
The initialization happens from a base optimizer function, for example, optax.adam. This base optimizer function must be able to take in the learning rate as a single parameter.
The return value is a full optimizer pipeline consisting of gradient clipping, warm-up, etc.
- Parameters:
base_optimizer_factory_fun – The base optimizer function which must be able to take in the learning rate as a single parameter.
config – The optimizer pydantic config.
- Returns:
The full optimizer pipeline constructed based on the provided base optimizer function.
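For example, the same chain could be built around optax.adam instead of the default optax.amsgrad; the factory only needs to accept the learning rate as its single argument. A sketch with default config values:

import optax

from mlip.training.optimizer import (
    get_mlip_optimizer_chain_with_flexible_base_optimizer,
)
from mlip.training.optimizer_config import OptimizerConfig

config = OptimizerConfig()  # default settings
optimizer = get_mlip_optimizer_chain_with_flexible_base_optimizer(
    base_optimizer_factory_fun=lambda lr: optax.adam(learning_rate=lr),
    config=config,
)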
- class mlip.training.optimizer_config.OptimizerConfig(*, apply_weight_decay_mask: bool = True, weight_decay: Annotated[float, Ge(ge=0)] = 0.0, grad_norm: Annotated[float, Ge(ge=0)] = 500, num_gradient_accumulation_steps: Annotated[int, Gt(gt=0)] = 1, init_learning_rate: Annotated[float, Gt(gt=0)] = 0.01, peak_learning_rate: Annotated[float, Gt(gt=0)] = 0.01, final_learning_rate: Annotated[float, Gt(gt=0)] = 0.01, warmup_steps: Annotated[int, Ge(ge=0)] = 4000, transition_steps: Annotated[int, Ge(ge=0)] = 360000)¶
Pydantic config holding all settings that are relevant for the optimizer.
- apply_weight_decay_mask¶
Whether to apply a weight decay mask. If set to False, weight decay is applied to all parameters. If set to True (default), only the parameters of the model blocks “linear_down” and “SymmetricContraction” are assigned a weight decay. These blocks only exist for MACE models, and for MACE it is recommended to set this to True. If it is set to True but neither of these blocks exists in the model (as for ViSNet or NequIP), weight decay is applied to all parameters.
- Type:
bool
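Conceptually, a weight decay mask in optax is a boolean pytree (or a callable returning one) with the same structure as the parameters; only entries marked True receive the decay. The sketch below illustrates the idea using the block names mentioned above; it is not the library’s exact masking logic:

import jax
import optax

def mace_weight_decay_mask(params):
    # Mark only parameters whose path contains one of the MACE-specific blocks.
    def mark(path, _leaf):
        path_str = "/".join(str(key) for key in path)
        return "linear_down" in path_str or "SymmetricContraction" in path_str
    return jax.tree_util.tree_map_with_path(mark, params)

# Decay is only applied where the mask is True.
weight_decay_transform = optax.add_decayed_weights(
    weight_decay=1e-7, mask=mace_weight_decay_mask
)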
- weight_decay¶
The weight decay with a default of zero.
- Type:
float
- grad_norm¶
Gradient norm used for gradient clipping.
- Type:
float
- num_gradient_accumulation_steps¶
Number of gradient steps to accumulate before taking an optimizer step. Default is 1.
- Type:
int
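In plain optax, gradient accumulation is commonly implemented with optax.MultiSteps, which collects gradients over a number of calls before applying a real optimizer step. This sketch shows the general mechanism; it may not mirror the library’s internals exactly:

import optax

base_optimizer = optax.amsgrad(learning_rate=1e-3)

# Accumulate gradients over 4 calls, then apply a single optimizer step.
accumulating = optax.MultiSteps(base_optimizer, every_k_schedule=4)
optimizer = accumulating.gradient_transformation()  # behaves like a regular GradientTransformation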
- init_learning_rate¶
Initial learning rate (default is 0.01).
- Type:
float
- peak_learning_rate¶
Peak learning rate (default is 0.01).
- Type:
float
- final_learning_rate¶
Final learning rate (default is 0.01).
- Type:
float
- warmup_steps¶
Number of optimizer warm-up steps (default is 4000). Check optax’s linear_schedule() function for more info.
- Type:
int
- transition_steps¶
Number of optimizer transition steps (default is 360000). Check optax’s linear_schedule() function for more info.
- Type:
int
- __init__(**data: Any) None¶
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError (pydantic_core.ValidationError) if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
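For reference, a sketch of constructing the config directly; the field names follow the signature above, the values are illustrative, and violating a field constraint raises a pydantic ValidationError:

from pydantic import ValidationError

from mlip.training.optimizer_config import OptimizerConfig

config = OptimizerConfig(
    weight_decay=5e-7,
    grad_norm=100.0,
    init_learning_rate=1e-3,
    peak_learning_rate=1e-2,
    final_learning_rate=1e-5,
    warmup_steps=2000,
    transition_steps=100_000,
)

try:
    OptimizerConfig(warmup_steps=-1)  # violates the Ge(ge=0) constraint
except ValidationError as error:
    print(error)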