Fine-tuning utilities

mlip.models.params_transfer.transfer_params(params_source: dict[str, dict[str, Array | dict]], params_destination: dict[str, dict[str, Array | dict]], scale_factor: float = 1.0) → dict[str, dict[str, Array | dict]]

Transfer parameters from a source to a destination.

Typically, the destination will be newly initialized parameters that contain some additional blocks compared to the source, which holds the parameters of an already trained model. This function raises an exception if the two parameter trees differ in any way other than the destination having additional blocks.
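The compatibility rule above can be illustrated with a minimal sketch in plain Python, assuming array leaves expose a NumPy-compatible shape (JAX arrays do). `check_transferable` and the local `ParameterTransferImpossibleError` stand-in are hypothetical names for illustration, not mlip's implementation. The sketch accepts a destination that only adds blocks and raises on anything else, such as a source key absent from the destination or a shape mismatch.

```python
import numpy as np


class ParameterTransferImpossibleError(Exception):
    """Hypothetical local stand-in for mlip's exception of the same name."""


def check_transferable(source: dict, destination: dict, path: str = "") -> None:
    """Raise unless `source` is a structural subset of `destination`.

    Every key in `source` must exist in `destination`, subtrees must line up,
    and array leaves must have identical shapes. Extra keys in `destination`
    (newly added blocks) are allowed.
    """
    for key, src_value in source.items():
        where = f"{path}/{key}"
        if key not in destination:
            raise ParameterTransferImpossibleError(f"{where} is missing from the destination")
        dst_value = destination[key]
        if isinstance(src_value, dict) != isinstance(dst_value, dict):
            raise ParameterTransferImpossibleError(f"{where}: subtree in one tree, leaf in the other")
        if isinstance(src_value, dict):
            check_transferable(src_value, dst_value, where)  # recurse into the subtree
        elif np.shape(src_value) != np.shape(dst_value):
            raise ParameterTransferImpossibleError(
                f"{where}: shape {np.shape(src_value)} vs {np.shape(dst_value)}"
            )


# Usage: the destination only adds a readout block, so the check passes.
src = {"params": {"Embed_0": {"kernel": np.zeros((4, 8))}}}
dst = {"params": {"Embed_0": {"kernel": np.ones((4, 8))},
                  "LinearReadoutBlock_1": {"kernel": np.ones((8, 1))}}}
check_transferable(src, dst)
```

Swapping the arguments would fail, because the "destination" would then be missing `LinearReadoutBlock_1`.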

For new readout heads (*ReadoutBlock_N with N >= 1 that don’t exist in params_source), the params are warm-started by deep-copying *ReadoutBlock_0 from the source, so fine-tuning begins from the pretrained readout weights rather than a random init. Pass scale_factor=0.0 to restore the original “scaled-init” behaviour (useful for tests that want reproducible zero-initialised new blocks).
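The warm-start rule can be sketched in plain Python as follows. This is an illustration under stated assumptions, not mlip's implementation: `warm_start_readouts` and `_scale` are hypothetical names, a new head is assumed to be copied from the head-0 block with the same prefix (for example `LinearReadoutBlock_1` from `LinearReadoutBlock_0`), leaves are assumed to be NumPy-compatible arrays, and the `scale_factor=0.0` behaviour for readout heads is not covered.

```python
import copy
import re

import numpy as np

# Assumed key convention: "<prefix>ReadoutBlock_<N>", e.g. "LinearReadoutBlock_1".
_READOUT_KEY = re.compile(r"^(?P<prefix>.*ReadoutBlock)_(?P<n>\d+)$")


def _scale(tree, factor: float):
    """Multiply every array leaf of a nested dict by `factor`."""
    if isinstance(tree, dict):
        return {k: _scale(v, factor) for k, v in tree.items()}
    return np.asarray(tree) * factor


def warm_start_readouts(source: dict, destination: dict, scale_factor: float = 1.0) -> dict:
    """Hypothetical sketch: merge `source` into `destination`.

    - keys present in `source` take the pretrained values;
    - a new head `<prefix>ReadoutBlock_N` (N >= 1) becomes a deep copy of the
      source's `<prefix>ReadoutBlock_0`;
    - any other new block keeps its fresh init, multiplied by `scale_factor`.
    Assumes the two trees have already been checked for compatibility.
    """
    result = {}
    for key, dst_value in destination.items():
        if key in source:
            src_value = source[key]
            if isinstance(dst_value, dict):
                result[key] = warm_start_readouts(src_value, dst_value, scale_factor)
            else:
                result[key] = copy.deepcopy(src_value)
            continue
        match = _READOUT_KEY.match(key)
        if match and int(match["n"]) >= 1 and f"{match['prefix']}_0" in source:
            # New readout head: start from the pretrained head-0 weights.
            result[key] = copy.deepcopy(source[f"{match['prefix']}_0"])
        else:
            # Genuinely new block: keep the destination's init, scaled.
            result[key] = _scale(dst_value, scale_factor)
    return result
```

The deep copy keeps each new head independent of head 0 in the returned tree; with immutable JAX arrays a shared reference would also be safe, so the copy mainly matters for mutable leaves.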

Parameters:
  • params_source – The parameters to transfer into the destination.

  • params_destination – The destination parameters that may contain additional blocks compared to the source.

  • scale_factor – Scale factor applied to the destination’s random init for any new block that’s neither in the source nor an added readout head. Default is 1.0.

Returns:

The updated destination parameters.

Raises:

ParameterTransferImpossibleError – if the source and destination parameters are incompatible with each other.

mlip.models.params_transfer.count_readout_heads(params: dict[str, dict[str, Array | dict]]) → int

Count the number of distinct readout heads in a parameters tree.

A readout head is any *ReadoutBlock_N entry (e.g. LinearReadoutBlock_0, NonLinearReadoutBlock_1). Blocks that share the same N but have different prefixes (e.g. a linear and a non-linear readout at different layers) are counted as one head.
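The counting rule can be sketched as a recursive walk that collects the index N from every key ending in `ReadoutBlock_N`, so blocks sharing an index but differing in prefix collapse into a single head. `count_heads_sketch` is a hypothetical name, and the key pattern is an assumption based on the examples above, not mlip's code.

```python
import re

# Assumed pattern: any key ending in "ReadoutBlock_<N>".
_READOUT_KEY = re.compile(r"ReadoutBlock_(\d+)$")


def count_heads_sketch(params: dict) -> int:
    """Count distinct head indices N over all '*ReadoutBlock_N' keys at any depth."""
    indices = set()

    def visit(tree: dict) -> None:
        for key, value in tree.items():
            match = _READOUT_KEY.search(key)
            if match:
                indices.add(int(match.group(1)))
            if isinstance(value, dict):
                visit(value)  # readout blocks may sit at any nesting level

    visit(params)
    return len(indices)


params = {"params": {
    "LinearReadoutBlock_0": {"kernel": 0.0},
    "NonLinearReadoutBlock_0": {"Dense_0": {"kernel": 0.0}},
    "LinearReadoutBlock_1": {"kernel": 0.0},
}}
print(count_heads_sketch(params))  # 2: indices {0, 1}
```

Collecting indices in a set, rather than counting matching keys, is what makes a linear and a non-linear block with the same N count once.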

Parameters:

params – A parameters pytree, typically nested dicts of arrays.

Returns:

The number of distinct head indices found anywhere in params.

class mlip.models.params_transfer.ParameterTransferImpossibleError

Exception raised when the structures of the source and destination parameters differ by more than the source missing some blocks that the destination has.