Fine-tuning the PET-MAD universal potential

Authors:

Davide Tisi @DavideTisi, Zhiyi Wang @0WellyWang0, Cesare Malosso @cesaremalosso, Sofiia Chorna @sofiia-chorna

This example demonstrates how to fine-tune the PET-MAD universal ML potential on a new dataset using metatrain. This allows adapting the model to a specialized task by retraining it on a more focused, domain-specific dataset.

PET-MAD is a universal machine-learning force field trained on the MAD dataset, which is designed to cover a very high degree of structural diversity. The model itself is the Point-Edge Transformer (PET), an unconstrained architecture that achieves symmetry compliance through data augmentation during training. You can see an overview of its usage in this introductory example.

The goal of this recipe is to demonstrate the process, not to reach high accuracy. Adjust the dataset size and hyperparameters accordingly if adapting this for an actual application.

# sphinx_gallery_thumbnail_number = 2
import os
import subprocess
from collections import Counter
from glob import glob
from urllib.request import urlretrieve

import ase.io
import matplotlib.pyplot as plt
import numpy as np
import torch
from metatrain.pet import PET
from sklearn.linear_model import LinearRegression


if hasattr(__import__("builtins"), "get_ipython"):
    get_ipython().run_line_magic("matplotlib", "inline")  # noqa: F821

Preparing for fine-tuning

While PET-MAD is trained as a universal model capable of handling a broad range of atomic environments, fine-tuning allows you to adapt this general-purpose model to a more specialized task. First, we need to get a checkpoint of the pre-trained PET-MAD model to start training from. The checkpoint is stored in the lab-cosmo Hugging Face repository and can be fetched using wget or curl:

wget https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/models/pet-mad-v1.1.0.ckpt # noqa: E501

We’ll download it directly:

url = (
    "https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/models/pet-mad-v1.1.0.ckpt"
)
checkpoint_path = "pet-mad-v1.1.0.ckpt"

if not os.path.exists(checkpoint_path):
    urlretrieve(url, checkpoint_path)

Applying an atomic energy correction

DFT-calculated energies often contain systematic shifts due to the choice of functional, basis set, or pseudopotentials. If left uncorrected, such shifts can mislead the fine-tuning process.

In this example we use a sampled subset of ethanol structures from the rMD17 dataset, computed at the PBE/def2-SVP level of theory, which differs from the MAD dataset, computed with PBEsol and a plane-wave basis set. We apply a linear correction based on atomic compositions to align our fine-tuning dataset with the PET-MAD energy reference. First, we define a helper function to load the reference energies from PET-MAD.
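Concretely, if n_Z is the number of atoms of element Z in a structure, the corrected energy used as the fine-tuning target is

E_corrected = E_DFT - (sum_Z n_Z * eps_Z^DFT + b) + sum_Z n_Z * eps_Z^MAD

where the per-element energies eps_Z^DFT and the intercept b come from the linear fit below, and eps_Z^MAD are the atomic reference energies stored in the PET-MAD checkpoint.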

def load_reference_energies(checkpoint_path):
    """
    Extract atomic reference energies from the PET-MAD checkpoint.

    It returns a mapping of atomic numbers to their reference energies (eV), e.g.:
    {1: -1.23, 6: -5.67, ...}
    """
    checkpoint = torch.load(checkpoint_path, weights_only=False)
    pet_model = PET.load_checkpoint(checkpoint, "finetune")

    energy_values = pet_model.additive_models[0].weights["energy"].block().values
    atomic_numbers = checkpoint["model_data"]["dataset_info"].atomic_types

    return dict(zip(atomic_numbers, energy_values))
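As an optional sanity check, we can call this function and print the reference energies for H, C, and O, the elements present in ethanol:

# Optional check: print the PET-MAD reference energies for the elements in ethanol
reference = load_reference_energies(checkpoint_path)
for Z in (1, 6, 8):  # H, C, O
    print(f"Z={Z}: {float(reference[Z]):.3f} eV")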

For demonstration purposes, the dataset consists of only 100 ethanol structures. We fit a linear model based on atomic compositions, which we then use as the energy correction.

dataset = ase.io.read("data/ethanol.xyz", index=":", format="extxyz")

# Extract DFT energies and compositions
dft_energies = [atoms.get_potential_energy() for atoms in dataset]
compositions = [Counter(atoms.get_atomic_numbers()) for atoms in dataset]
elements = sorted({element for composition in compositions for element in composition})

X = np.array([[comp.get(elem, 0) for elem in elements] for comp in compositions])
y = np.array(dft_energies)

# Fit linear model to estimate DFT per-element energy
correction_model = LinearRegression()
correction_model.fit(X, y)

coeffs = dict(zip(elements, correction_model.coef_))
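It can be instructive to print the fitted coefficients, which approximate the per-element DFT energies (optional check):

# Inspect the fitted per-element DFT energies (eV) and the intercept
for Z, c in coeffs.items():
    print(f"Z={Z}: {c:.3f} eV")
print(f"intercept: {correction_model.intercept_:.3f} eV")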

Apply the correction to each structure

def get_compositional_energy(atoms, energy_per_atom):
    """Calculates total energy from atomic composition and per-atom energies"""
    counts = Counter(atoms.get_atomic_numbers())
    return sum(energy_per_atom[Z] * count for Z, count in counts.items())


# Get reference energies from PET-MAD
ref_energies = load_reference_energies(checkpoint_path)

# Apply correction
for atoms, E_dft in zip(dataset, dft_energies):
    E_comp_dft = get_compositional_energy(atoms, coeffs)
    E_comp_ref = get_compositional_energy(atoms, ref_energies)

    corrected_energy = E_dft - E_comp_dft + E_comp_ref - correction_model.intercept_

    atoms.info["energy-corrected"] = corrected_energy.item()


# Split corrected dataset and save it
np.random.seed(42)
indices = np.random.permutation(len(dataset))
n = len(dataset)
n_val = n_test = int(0.1 * n)
n_train = n - n_val - n_test

train = [dataset[i] for i in indices[:n_train]]
val = [dataset[i] for i in indices[n_train : n_train + n_val]]
test = [dataset[i] for i in indices[n_train + n_val :]]

ase.io.write("data/ethanol_train.xyz", train, format="extxyz")
ase.io.write("data/ethanol_val.xyz", val, format="extxyz")
ase.io.write("data/ethanol_test.xyz", test, format="extxyz")

Define some helper functions

We also define a few helper functions to visualize the training results. Each training run generates a log, stored in CSV format in the outputs folder.

def parse_training_log(csv_file):
    with open(csv_file, encoding="utf-8") as f:
        headers = f.readline().strip().split(",")

    cleaned_names = [h.strip().replace(" ", "_") for h in headers]

    train_log = np.genfromtxt(
        csv_file,
        delimiter=",",
        skip_header=2,
        names=cleaned_names,
    )

    return train_log


def display_training_curves(train_log, ax=None, style="-", label=""):
    """Plots training and validation losses from the training log"""

    if ax is None:
        _, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
    else:
        ax1, ax2 = ax

    ax1.loglog(
        train_log["Epoch"],
        train_log["training_energy_MAE_per_atom"],
        f"r{style}",
        label=f"Train. {label}",
    )
    ax1.loglog(
        train_log["Epoch"],
        train_log["validation_energy_MAE_per_atom"],
        f"b{style}",
        label=f"Valid. {label}",
    )

    ax2.loglog(
        train_log["Epoch"],
        train_log["training_forces_MAE"],
        f"r{style}",
        label="Training, F",
    )
    ax2.loglog(
        train_log["Epoch"],
        train_log["validation_forces_MAE"],
        f"b{style}",
        label="Validation, F",
    )

    ax1.set_xlabel("Epoch")
    ax2.set_xlabel("Epoch")
    ax1.set_ylabel("Energy MAE (meV)")
    ax2.set_ylabel("Force MAE (meV/Å)")
    ax1.legend()

    return ax1, ax2

Train a model from scratch

We first fit a PET model to the corrected dataset to establish a baseline. We use the metatrain utility to train the model. The training is configured in a YAML file, which specifies the training, validation, and test sets, the model architecture, optimizer settings, etc. You can learn more about the different settings in the metatrain documentation.

seed: 42

architecture:
  name: pet

  training:
    batch_size: 16
    num_epochs: 50
    num_epochs_warmup: 10
    learning_rate: 1e-4
    

# this needs to be specified based on the specific dataset
training_set:
  systems: 
    read_from: "data/ethanol_train.xyz"  # path to the finetuning dataset
    length_unit: angstrom
  targets:
    energy:
      key: "energy-corrected"  # name of the target value
      unit: "eV"

validation_set:
  systems: 
    read_from: "data/ethanol_val.xyz"
    length_unit: angstrom
  targets:
    energy:
      key: "energy-corrected"
      unit: "eV"

test_set:
  systems: 
    read_from: "data/ethanol_test.xyz"
    length_unit: angstrom
  targets:
    energy:
      key: "energy-corrected"
      unit: "eV"

To launch training, you just need to run the following command in the terminal:

mtt train <options.yaml> [-o <output.pt>]

Or from Python:

subprocess.run(
    ["mtt", "train", "from_scratch_options.yaml", "-o", "from_scratch-model.pt"],
    check=True,
)
CompletedProcess(args=['mtt', 'train', 'from_scratch_options.yaml', '-o', 'from_scratch-model.pt'], returncode=0)

The training logs are stored in the outputs/ directory, with a subdirectory named by the date and time of the training run. The model checkpoint is saved as model.ckpt and the exported model as model.pt, unless specified otherwise with the -o option.

We can load the latest training log and visualize the training curves – we will use them later to compare with the fine-tuning results. It is clear that the training is not converged and the learning rate is not optimal – you can try to adjust the parameters and train for longer.

csv_file = sorted(glob("outputs/*/*/train.csv"))[-1]
from_scratch_log = parse_training_log(csv_file)
display_training_curves(from_scratch_log)
[Figure: energy and force MAE training curves for the model trained from scratch]

Simple model fine-tuning

Having prepared the dataset and fitted a baseline model “from scratch”, we proceed with training a fine-tuned model. To this end, we again use the metatrain package. There are multiple fine-tuning strategies, each described in the documentation. In this example we demonstrate a basic full fine-tuning strategy, which adapts all model weights to the new dataset starting from the pre-trained PET-MAD checkpoint. The process is configured through the appropriate options in the YAML file.

seed: 42

architecture:
  name: "pet"
  training:
    num_epochs: 50  # very short run, for demonstration only
    num_epochs_warmup: 1
    learning_rate: 1e-5  # small learning rate to stabilize training
    finetune:
      method: "full"  # use fine-tuning strategy
      read_from: pet-mad-v1.1.0.ckpt  # path to the pretrained checkpoint to start from

training_set:
  systems: 
    read_from: "data/ethanol_train.xyz"  # path to the finetuning dataset
    length_unit: angstrom
  targets:
    energy:
      key: "energy-corrected"  # name of the target value
      unit: "eV"

validation_set:
  systems: 
    read_from: "data/ethanol_val.xyz"
    length_unit: angstrom
  targets:
    energy:
      key: "energy-corrected"
      unit: "eV"

test_set:
  systems: 
    read_from: "data/ethanol_test.xyz"
    length_unit: angstrom
  targets:
    energy:
      key: "energy-corrected"
      unit: "eV"
subprocess.run(
    ["mtt", "train", "full_ft_options.yaml", "-o", "fine_tune-model.pt"], check=True
)
CompletedProcess(args=['mtt', 'train', 'full_ft_options.yaml', '-o', 'fine_tune-model.pt'], returncode=0)

Comparing the model trained from scratch (dashed lines) and the fine-tuned one (solid lines), it is clear that fine-tuning from the PET-MAD weights leads to much better zero-shot accuracy and more consistent learning dynamics. Of course, the hyperparameters could be tuned differently, and it is not unlikely that a large single-purpose training set and a long training time would lead to a lower validation error than fine-tuning.

csv_file = sorted(glob("outputs/*/*/train.csv"))[-1]
fine_tune_log = parse_training_log(csv_file)

ax = display_training_curves(fine_tune_log, label="Fine tuning")
display_training_curves(from_scratch_log, ax=ax, style="--", label="From scratch")
ax[0].set_ylim(1, 1000)
[Figure: training curves comparing fine-tuning (solid lines) and training from scratch (dashed lines)]

Model evaluation

After training, mtt train writes the fine_tune-model.ckpt (checkpoint) and fine_tune-model.pt (exported fine-tuned model) files, both in the current directory and in outputs/YYYY-MM-DD/HH-MM-SS/.

These can be used together with metatrain to evaluate the model on a (potentially different) dataset. The evaluation is configured in a YAML file, which specifies the dataset to use, and the metrics to compute.

systems: 
  read_from: data/ethanol_test.xyz
targets: 
  energy:
    key: energy-corrected
    unit: eV

The evaluation can be run from the command line:

mtt eval fine_tune-model.pt model_eval.yaml

Or from Python:

subprocess.run(["mtt", "eval", "fine_tune-model.pt", "model_eval.yaml"], check=True)
CompletedProcess(args=['mtt', 'eval', 'fine_tune-model.pt', 'model_eval.yaml'], returncode=0)
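The exported model can also be used directly as an ASE calculator. The sketch below assumes the metatomic-torch package is installed and exposes MetatomicCalculator, which loads an exported .pt model:

# Sketch (assumes metatomic-torch is available): evaluate one test structure
from metatomic.torch.ase_calculator import MetatomicCalculator

calc = MetatomicCalculator("fine_tune-model.pt")
atoms = test[0].copy()
atoms.calc = calc
print(atoms.get_potential_energy())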
