Fine-tuning utilities
- mlip.models.params_transfer.transfer_params(params_source: dict[str, dict[str, Array | dict]], params_destination: dict[str, dict[str, Array | dict]], scale_factor: float = 1.0) -> dict[str, dict[str, Array | dict]]
Transfer parameters from a source to a destination.
Typically, the destination holds newly initialized parameters that contain some additional blocks compared to the source, which comes from an already trained model. The function raises ParameterTransferImpossibleError if the two parameter trees differ by more than such additional blocks in the destination.
For new readout heads (*ReadoutBlock_N with N >= 1 that don’t exist in params_source), the params are warm-started by deep-copying *ReadoutBlock_0 from the source, so fine-tuning begins from the pretrained readout weights rather than a random init. Pass scale_factor=0.0 to restore the original “scaled-init” behaviour (useful for tests that want reproducible zero-initialised new blocks).
- Parameters:
params_source – The parameters to transfer into the destination.
params_destination – The destination parameters that may contain additional blocks compared to the source.
scale_factor – Scale factor applied to the destination’s random init for any new block that’s neither in the source nor an added readout head. Default is 1.0.
- Returns:
The updated destination parameters.
- Raises:
ParameterTransferImpossibleError – if the source and destination parameters are incompatible with each other.
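The warm-start rule described above can be illustrated with a small toy function. This is only a sketch under stated assumptions, not the mlip implementation: warm_start_sketch is a hypothetical helper, the block names sit at the top level of a flat dict (real parameter trees are nested), leaves are plain Python lists rather than arrays, and scale_factor is left out entirely.

```python
import copy
import re

# Hypothetical pattern for names such as "LinearReadoutBlock_1" (assumption).
_HEAD_PATTERN = re.compile(r"^(?P<prefix>.*ReadoutBlock)_(?P<index>\d+)$")


def warm_start_sketch(params_source, params_destination):
    """Toy illustration of the documented warm-start rule (not the mlip code)."""
    result = {}
    for name, dest_value in params_destination.items():
        if name in params_source:
            # Block exists in the trained model: take the trained values.
            result[name] = copy.deepcopy(params_source[name])
            continue
        match = _HEAD_PATTERN.match(name)
        head_zero = f"{match.group('prefix')}_0" if match else None
        if match and int(match.group("index")) >= 1 and head_zero in params_source:
            # New readout head: deep-copy head 0 from the source.
            result[name] = copy.deepcopy(params_source[head_zero])
        else:
            # Any other new block keeps its fresh initialisation.
            result[name] = copy.deepcopy(dest_value)
    return result


source = {"LinearReadoutBlock_0": {"w": [1.0, 2.0]}, "Embedding": {"w": [0.5]}}
destination = {
    "LinearReadoutBlock_0": {"w": [0.0, 0.0]},
    "LinearReadoutBlock_1": {"w": [9.0, 9.0]},
    "Embedding": {"w": [0.0]},
}
merged = warm_start_sketch(source, destination)
print(merged["LinearReadoutBlock_1"])  # copy of the source's head 0
```

In this toy run the new head LinearReadoutBlock_1 ends up with a copy of the trained head-0 weights, while the blocks shared by both trees take the source values.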
- mlip.models.params_transfer.count_readout_heads(params: dict[str, dict[str, Array | dict]]) -> int
Count the number of distinct readout heads in a parameters tree.
A readout head is any *ReadoutBlock_N entry (e.g. LinearReadoutBlock_0, NonLinearReadoutBlock_1); two blocks with the same N but different prefixes (e.g. linear + non-linear at different layers) are one head.
- Parameters:
params – A parameters pytree, typically nested dicts of arrays.
- Returns:
The number of distinct head indices found anywhere in params.
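The counting rule can be sketched in a few lines of plain Python. The function below is a hypothetical re-implementation for illustration only; it assumes string keys, nested dicts, and that a head index is the trailing integer of any key ending in ReadoutBlock_N. The layer names in the example tree are made up.

```python
import re

# Assumption: a head index is the integer suffix of a "...ReadoutBlock_N" key.
_HEAD_INDEX = re.compile(r"ReadoutBlock_(\d+)$")


def count_heads_sketch(params):
    """Count distinct head indices in a nested dict (illustrative only)."""
    indices = set()

    def visit(node):
        if not isinstance(node, dict):
            return  # leaves (arrays, lists, scalars) carry no block names
        for key, value in node.items():
            match = _HEAD_INDEX.search(key)
            if match:
                indices.add(int(match.group(1)))
            visit(value)

    visit(params)
    return len(indices)


example = {
    "layer_0": {
        "LinearReadoutBlock_0": {"w": [1.0]},
        "LinearReadoutBlock_1": {"w": [1.0]},
    },
    "layer_1": {
        "NonLinearReadoutBlock_0": {"w": [2.0]},
        "NonLinearReadoutBlock_1": {"w": [2.0]},
    },
}
print(count_heads_sketch(example))  # 2
```

The four blocks share only the indices 0 and 1, so the example counts two heads, matching the rule that blocks with the same N but different prefixes form one head.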
- class mlip.models.params_transfer.ParameterTransferImpossibleError
Exception raised when the destination and source parameters differ in structure by more than some blocks being absent from the source.
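One possible reading of this condition is that every entry of the source must also exist in the destination, while the destination may hold extra blocks. The sketch below checks that reading on nested dicts. It is an assumption about the semantics, not the library's actual check; StructureMismatchSketchError and check_source_fits_destination are hypothetical names introduced here.

```python
class StructureMismatchSketchError(Exception):
    """Stand-in for illustration; not the mlip exception class."""


def check_source_fits_destination(source, destination, path=""):
    """Raise if the source holds anything the destination lacks (assumed rule)."""
    for key, src_value in source.items():
        here = f"{path}/{key}"
        if key not in destination:
            raise StructureMismatchSketchError(
                f"{here} is missing from the destination"
            )
        dst_value = destination[key]
        if isinstance(src_value, dict) != isinstance(dst_value, dict):
            raise StructureMismatchSketchError(
                f"{here} is a sub-tree on one side and a leaf on the other"
            )
        if isinstance(src_value, dict):
            check_source_fits_destination(src_value, dst_value, here)


trained = {"Embedding": {"w": [0.5]}, "LinearReadoutBlock_0": {"w": [1.0]}}
fresh = {
    "Embedding": {"w": [0.0]},
    "LinearReadoutBlock_0": {"w": [0.0]},
    "LinearReadoutBlock_1": {"w": [0.0]},
}

# Extra blocks in the destination are fine under this reading.
check_source_fits_destination(trained, fresh)

# A source block that the destination lacks is treated as incompatible.
try:
    check_source_fits_destination(fresh, trained)
    incompatible = False
except StructureMismatchSketchError:
    incompatible = True
```

A real check would likely also compare array shapes; that is omitted here because the source does not describe it.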