Model fine-tuningΒΆ

This guide describes how to run multi-head fine-tuning (MHFT) with the mlip library.

Note

As of version 0.2.0, the fine-tuning methodology has been revised and is now available for all implemented architectures: MACE, NequIP, ViSNet, and eSEN.

There are the following steps to the process:

  • Load the pre-trained force field.

  • Process the dataset consisting of replay data (from the original dataset) and fine-tuning dataset(s). Important: reuse the DatasetInfo from the pre-trained model for the replay dataset.

  • Initialize a new force field with multiple readout heads and transfer pre-trained parameters into it.

  • Train the model.

  • Run inference with the model with a specific readout head.

Below, the details for each step are explained.

First, we load the pre-trained force field, e.g., MACE:

from mlip.models import Mace
from mlip.models.model_io import load_model_from_zip

pretrained_force_field = load_model_from_zip(Mace, "path/to/model.zip")

Then, we process the dataset with the BuilderMode.MULTI mode as described in the data processing guide here. Make sure to pass the pretrained_force_field.dataset_info of the pre-trained model to the graph dataset builder, because for the replay dataset, the dataset info will be reused from the pre-trained model.

Next, we instantiate the new force field with multiple readout heads. The rule is one head per dataset key passed to the MULTI builder (i.e. len(readers)): for example, two heads if we fine-tune on one dataset while still having the replay data. To inspect how many heads a pretrained model has, use count_readout_heads().

from mlip.models import Mace, ForceField

dataset_info = graph_dataset_builder.dataset_info

mace = Mace(Mace.Config(num_readout_heads=2), dataset_info)
force_field = ForceField.from_mlip_network(mace)

Transferring the pre-trained parameters to this new force field can be done like this:

from dataclasses import replace
from mlip.models.params_transfer import transfer_params

transferred_params = transfer_params(
    pretrained_force_field.params,
    force_field.params,
)

force_field = replace(force_field, params=transferred_params)

The above example uses the function transfer_params(), see its API reference for details. By default, newly added readout heads are warm-started by deep-copying the pretrained head-0 weights, so fine-tuning begins from pretrained readout values rather than a random init. Pass scale_factor=0.0 to fall back to a scaled random init for new blocks instead.

As a next step, we train the model as is described in the model training user guide.

Running inference with the trained model, for example, after it has been saved to zip format, is straightforward via the InferenceContext concept. In the example below, we assume that the name given to the fine-tuning dataset was "ft" and the model was saved to "path/to/model.zip". Note that the dataset_name value must match one of the keys passed to readers when building the dataset (it is resolved against DatasetInfo.dataset_name):

from mlip.models import InferenceContext, Mace
from mlip.models.model_io import load_model_from_zip

force_field_ft = load_model_from_zip(
    Mace,
    "path/to/model.zip",
    inference_context=InferenceContext(dataset_name="ft"),
)

graph = _get_graph_from_somewhere()  # placeholder
prediction = force_field_ft(graph)

Of course, technically, the force field can also be loaded with the context dataset_name="replay". After being loaded like above, in addition to single graph inference, this model can also be used in simulations and batched inference.