Note
Go to the end to download the full example code.
Learning Capabilities with torchpme
- Authors:
Egor Rumiantsev @E-Rum; Philip Loche @PicoCentauri
This example demonstrates the capabilities of the torchpme package, focusing on
learning target charges and using the CombinedPotential class to evaluate
potentials that combine multiple pairwise interactions with optimizable weights.
The weights are optimized to reproduce the energy of a system interacting purely
through Coulomb forces.
from typing import Dict
import ase.io
import ase.visualize.plot
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchpme import CombinedPotential, EwaldCalculator, InversePowerLawPotential
from vesin import NeighborList
Select computation device
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Floating-point precision used for every tensor in this example.
dtype = torch.float32
# Unit conversion prefactor passed to the Ewald calculator.
prefactor = 0.5292
Load the dataset
The dataset consists of atomic configurations with reference energies.
# Read the first ten configurations (slice ":10") of the training set.
frames = ase.io.read("data/point_charges_Training_set.xyz", index=":10")
Define model parameters
# The real-space cutoff is derived from the first frame only: half the length
# of its shortest cell vector, minus a small epsilon to stay strictly below it.
# NOTE(review): this assumes every frame shares (roughly) the same cell and that
# the cell is close to orthorhombic; confirm for other datasets.
cell = frames[0].get_cell().array
cell_dimensions = np.linalg.norm(cell, axis=1)
cutoff = cell_dimensions.min() / 2 - 1e-6
# Smearing of the interaction potentials, tied to the cutoff.
smearing = cutoff / 6.0
# Wavelength used for the long-range (reciprocal-space) part.
lr_wavelength = smearing * 0.5
# Extra keyword arguments forwarded to EwaldCalculator.
params = dict(lr_wavelength=lr_wavelength)
Build the neighbor list
The neighbor list is used to identify interacting pairs within the cutoff distance.
# Half neighbor list (each pair stored once) within the real-space cutoff.
nl = NeighborList(cutoff=cutoff, full_list=False)

# Per-frame inputs for the calculator, precomputed once and reused at every
# optimization step.
l_positions = []
l_cell = []
l_neighbor_indices = []
l_neighbor_distances = []
l_ref_energy = torch.zeros(len(frames), device=device, dtype=dtype)

for i_atoms, atoms in enumerate(frames):
    # Pair indices (i, j) and pair distances d for this frame.
    i, j, d = nl.compute(
        points=atoms.positions, box=atoms.cell.array, periodic=True, quantities="ijd"
    )
    # Use an explicit 64-bit integer type: plain ``int`` maps to int32 on some
    # platforms (e.g. Windows with NumPy < 2), whereas torch's scatter-style
    # index operations expect int64 (``torch.long``) indices.
    i = torch.from_numpy(i.astype(np.int64))
    j = torch.from_numpy(j.astype(np.int64))

    # Store atom positions, cell information, neighbor indices, and distances.
    l_positions.append(torch.tensor(atoms.positions, device=device, dtype=dtype))
    l_cell.append(torch.tensor(atoms.cell.array, device=device, dtype=dtype))
    # (n_pairs, 2) tensor whose rows are [i, j]; stacking along dim=1 gives the
    # same values as vstack(...).T but as a contiguous tensor.
    l_neighbor_indices.append(torch.stack([i, j], dim=1).to(device=device))
    l_neighbor_distances.append(torch.from_numpy(d).to(device=device, dtype=dtype))

    # Store reference energy (cast to ``dtype`` on assignment).
    l_ref_energy[i_atoms] = atoms.get_potential_energy()
Function to assign charges to atoms
def assign_charges(atoms, charge_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Build a column vector of per-atom charges from a symbol-to-charge map.

    Every atom whose chemical symbol appears in ``charge_dict`` receives the
    associated charge; atoms whose symbol is missing keep a charge of zero.

    Args:
        atoms: Configuration providing ``get_chemical_symbols()`` and ``len()``.
        charge_dict: Mapping from chemical symbol to charge (tensor or float).

    Returns:
        Tensor of shape ``(len(atoms), 1)`` on the module-level ``device``
        with the module-level ``dtype``.
    """
    symbols = np.asarray(atoms.get_chemical_symbols())
    per_atom = torch.zeros(len(atoms), dtype=dtype, device=device)
    for symbol in charge_dict:
        # Boolean mask selecting all atoms of this species.
        mask = symbols == symbol
        per_atom[mask] = charge_dict[symbol]
    return per_atom.view(-1, 1)
Define the energy computation
def compute_energy(charge_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Compute the energy of every loaded frame for the given charges.

    For each frame the charges are looked up with ``assign_charges``, passed
    together with the precomputed cell, positions and neighbor data to the
    module-level ``calculator``, and the frame energy is the sum of the
    elementwise product of charges and returned potential.

    Args:
        charge_dict: Mapping from chemical symbol to charge.

    Returns:
        Tensor of length ``len(frames)`` holding one energy per frame.
    """
    energies = torch.zeros(len(frames), device=device, dtype=dtype)
    per_frame_inputs = zip(l_cell, l_positions, l_neighbor_indices, l_neighbor_distances)
    for idx, (atoms, (cell_i, pos_i, nbr_idx, nbr_dist)) in enumerate(
        zip(frames, per_frame_inputs)
    ):
        q = assign_charges(atoms, charge_dict)
        # Potential returned by the calculator; presumably one value per atom,
        # matching the shape of ``q`` (TODO confirm against torchpme docs).
        phi = calculator(
            charges=q,
            cell=cell_i,
            positions=pos_i,
            neighbor_indices=nbr_idx,
            neighbor_distances=nbr_dist,
        )
        energies[idx] = torch.sum(q * phi)
    return energies
Define the loss function
def loss(charge_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Return the sum of squared errors between computed and reference energies.

    The squared per-frame errors are summed, not averaged, so this is a sum
    of squared errors rather than a mean squared error.

    Args:
        charge_dict: Mapping from chemical symbol to charge, forwarded to
            ``compute_energy``.

    Returns:
        Scalar tensor with the summed squared energy error over all frames.
    """
    energy = compute_energy(charge_dict)
    # Optionally add a charge penalty here for strict neutrality enforcement.
    return torch.sum((energy - l_ref_energy) ** 2)
Fit charge model
Set initial values for the potential
# Single Coulomb-like (1/r) kernel for the charge-fitting phase.
potential = CombinedPotential(
    potentials=[InversePowerLawPotential(exponent=1.0, smearing=smearing)],
    smearing=smearing,
)
calculator = EwaldCalculator(potential=potential, **params, prefactor=prefactor)
calculator.to(device=device, dtype=dtype)

# Learnable charges, created directly on the target device/dtype as leaf
# tensors that track gradients. Creating them in one call avoids rounding the
# initial values through float32 when ``dtype`` is wider, and avoids toggling
# ``requires_grad`` after the fact.
q_Na = torch.tensor(1e-5, device=device, dtype=dtype, requires_grad=True)
q_Cl = torch.tensor(-(1e-5 + 0.2), device=device, dtype=dtype, requires_grad=True)
charge_dict = {"Na": q_Na, "Cl": q_Cl}
Learning loop: optimize charges to minimize the loss function
# Optimize the two charges with Adam to minimize the energy loss.
optimizer = torch.optim.Adam([q_Na, q_Cl], lr=0.1)
q_Na_timeseries = []
q_Cl_timeseries = []
loss_timeseries = []
for step in range(1000):
    optimizer.zero_grad()
    charge_dict = {"Na": q_Na, "Cl": q_Cl}
    loss_value = loss(charge_dict)
    loss_value.backward()

    # Record (and print) the charges that produced ``loss_value`` *before* the
    # optimizer updates them, so every logged (loss, charge) pair is consistent.
    current_loss = loss_value.item()
    loss_timeseries.append(current_loss)
    q_Na_timeseries.append(q_Na.item())
    q_Cl_timeseries.append(q_Cl.item())

    if step % 10 == 0:
        print(
            f"Step: {step:>5}, Loss: {current_loss:>5.2e}, "
            + ", ".join(f"q_{k}: {v.item():>5.2f}" for k, v in charge_dict.items()),
            end="\r",
        )

    # Stop once converged, keeping the charges that reached the tolerance.
    if current_loss < 1e-10:
        break

    optimizer.step()
Step: 0, Loss: 4.78e+01, q_Na: -0.10, q_Cl: -0.30
Step: 10, Loss: 8.50e+00, q_Na: -0.48, q_Cl: -0.73
Step: 20, Loss: 6.67e+00, q_Na: -0.42, q_Cl: -0.75
Step: 30, Loss: 6.34e+00, q_Na: -0.35, q_Cl: -0.76
Step: 40, Loss: 6.28e+00, q_Na: -0.34, q_Cl: -0.85
Step: 50, Loss: 5.86e+00, q_Na: -0.23, q_Cl: -0.87
Step: 60, Loss: 5.22e+00, q_Na: -0.13, q_Cl: -0.95
Step: 70, Loss: 4.31e+00, q_Na: 0.02, q_Cl: -1.04
Step: 80, Loss: 2.95e+00, q_Na: 0.24, q_Cl: -1.12
Step: 90, Loss: 1.46e+00, q_Na: 0.50, q_Cl: -1.18
Step: 100, Loss: 4.43e-01, q_Na: 0.74, q_Cl: -1.16
Step: 110, Loss: 8.37e-02, q_Na: 0.90, q_Cl: -1.09
Step: 120, Loss: 6.79e-03, q_Na: 0.97, q_Cl: -1.03
Step: 130, Loss: 7.23e-05, q_Na: 1.00, q_Cl: -0.99
Step: 140, Loss: 1.50e-03, q_Na: 1.01, q_Cl: -0.98
Step: 150, Loss: 1.61e-03, q_Na: 1.01, q_Cl: -0.98
Step: 160, Loss: 9.30e-04, q_Na: 1.01, q_Cl: -0.99
Step: 170, Loss: 3.59e-04, q_Na: 1.01, q_Cl: -0.99
Step: 180, Loss: 9.25e-05, q_Na: 1.00, q_Cl: -1.00
Step: 190, Loss: 1.25e-05, q_Na: 1.00, q_Cl: -1.00
Step: 200, Loss: 1.08e-07, q_Na: 1.00, q_Cl: -1.00
Step: 210, Loss: 8.56e-07, q_Na: 1.00, q_Cl: -1.00
Step: 220, Loss: 1.18e-06, q_Na: 1.00, q_Cl: -1.00
Step: 230, Loss: 6.39e-07, q_Na: 1.00, q_Cl: -1.00
Step: 240, Loss: 1.93e-07, q_Na: 1.00, q_Cl: -1.00
Step: 250, Loss: 2.82e-08, q_Na: 1.00, q_Cl: -1.00
Step: 260, Loss: 4.25e-10, q_Na: 1.00, q_Cl: -1.00
Step: 270, Loss: 2.14e-09, q_Na: 1.00, q_Cl: -1.00
Step: 280, Loss: 2.73e-09, q_Na: 1.00, q_Cl: -1.00
Step: 290, Loss: 1.38e-09, q_Na: 1.00, q_Cl: -1.00
Step: 300, Loss: 4.85e-10, q_Na: 1.00, q_Cl: -1.00
Step: 310, Loss: 2.14e-10, q_Na: 1.00, q_Cl: -1.00
Step: 320, Loss: 2.00e-10, q_Na: 1.00, q_Cl: -1.00
Step: 330, Loss: 2.03e-10, q_Na: 1.00, q_Cl: -1.00
Step: 340, Loss: 1.95e-10, q_Na: 1.00, q_Cl: -1.00
Step: 350, Loss: 1.96e-10, q_Na: 1.00, q_Cl: -1.00
Step: 360, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 370, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 380, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 390, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 400, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 410, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 420, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 430, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 440, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 450, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 460, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 470, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 480, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 490, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 500, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 510, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 520, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 530, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 540, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 550, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 560, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 570, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 580, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 590, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 600, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 610, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 620, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 630, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 640, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 650, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 660, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 670, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 680, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 690, Loss: 1.87e-10, q_Na: 1.00, q_Cl: -1.00
Step: 700, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 710, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 720, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 730, Loss: 1.90e-10, q_Na: 1.00, q_Cl: -1.00
Step: 740, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 750, Loss: 1.87e-10, q_Na: 1.00, q_Cl: -1.00
Step: 760, Loss: 1.87e-10, q_Na: 1.00, q_Cl: -1.00
Step: 770, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 780, Loss: 1.90e-10, q_Na: 1.00, q_Cl: -1.00
Step: 790, Loss: 1.90e-10, q_Na: 1.00, q_Cl: -1.00
Step: 800, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 810, Loss: 1.87e-10, q_Na: 1.00, q_Cl: -1.00
Step: 820, Loss: 1.87e-10, q_Na: 1.00, q_Cl: -1.00
Step: 830, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 840, Loss: 1.90e-10, q_Na: 1.00, q_Cl: -1.00
Step: 850, Loss: 1.90e-10, q_Na: 1.00, q_Cl: -1.00
Step: 860, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 870, Loss: 1.87e-10, q_Na: 1.00, q_Cl: -1.00
Step: 880, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 890, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 900, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 910, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 920, Loss: 1.87e-10, q_Na: 1.00, q_Cl: -1.00
Step: 930, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 940, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 950, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 960, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 970, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 980, Loss: 1.92e-10, q_Na: 1.00, q_Cl: -1.00
Step: 990, Loss: 1.87e-10, q_Na: 1.00, q_Cl: -1.00
Fit kernel model
The second phase involves optimizing the weights of the combined potential kernels.
Set initial values for the kernel model
# Two candidate kernels, 1/r and 1/r^2, whose combination weights are learned.
kernels = [
    InversePowerLawPotential(exponent=1.0, smearing=smearing),
    InversePowerLawPotential(exponent=2.0, smearing=smearing),
]
potential = CombinedPotential(potentials=kernels, smearing=smearing)
calculator = EwaldCalculator(potential=potential, prefactor=prefactor, **params)
calculator.to(device=device, dtype=dtype)
EwaldCalculator(
(potential): CombinedPotential(
(potentials): ModuleList(
(0-1): 2 x InversePowerLawPotential()
)
)
)
Kernel optimization loop: optimize kernel weights to minimize the loss function
# Optimize the calculator's parameters (the kernel weights) with Adam.
optimizer = torch.optim.Adam(calculator.parameters(), lr=0.1)
weights_timeseries = []
loss_timeseries = []
# Fix charges to their ideal values for this phase.
fixed_charges = {"Na": 1.0, "Cl": -1.0}
for step in range(1000):
    optimizer.zero_grad()
    loss_value = loss(fixed_charges)
    loss_value.backward()

    # Record (and print) the weights that produced ``loss_value`` *before* the
    # optimizer updates them, so every logged (loss, weights) pair is consistent.
    current_loss = loss_value.item()
    current_weights = calculator.potential.weights.detach().cpu().tolist()
    loss_timeseries.append(current_loss)
    weights_timeseries.append(current_weights)

    if step % 10 == 0:
        print(
            f"Step: {step:>5}, Loss: {current_loss:>5.2e} "
            + ", ".join(
                f"w_{idx}: {float(w):>5.2f}" for idx, w in enumerate(current_weights)
            ),
            end="\r",
        )

    # Stop once converged, keeping the weights that reached the tolerance.
    if current_loss < 1e-10:
        break

    optimizer.step()
Step: 0, Loss: 1.50e+00 w_0: 0.90, w_1: 0.90
Step: 10, Loss: 3.62e-01 w_0: 0.96, w_1: 0.60
Step: 20, Loss: 2.67e-02 w_0: 0.98, w_1: 0.25
Step: 30, Loss: 2.93e-02 w_0: 0.98, w_1: 0.00
Step: 40, Loss: 5.33e-03 w_0: 1.00, w_1: -0.09
Step: 50, Loss: 2.46e-03 w_0: 1.01, w_1: -0.07
Step: 60, Loss: 1.44e-03 w_0: 1.01, w_1: -0.03
Step: 70, Loss: 5.38e-04 w_0: 1.00, w_1: 0.01
Step: 80, Loss: 2.73e-04 w_0: 1.00, w_1: 0.02
Step: 90, Loss: 9.66e-05 w_0: 1.00, w_1: 0.01
Step: 100, Loss: 2.46e-05 w_0: 1.00, w_1: -0.00
Step: 110, Loss: 9.57e-06 w_0: 1.00, w_1: -0.00
Step: 120, Loss: 1.53e-06 w_0: 1.00, w_1: -0.00
Step: 130, Loss: 1.32e-07 w_0: 1.00, w_1: 0.00
Step: 140, Loss: 5.43e-07 w_0: 1.00, w_1: 0.00
Step: 150, Loss: 1.18e-07 w_0: 1.00, w_1: 0.00
Step: 160, Loss: 1.05e-08 w_0: 1.00, w_1: -0.00
Step: 170, Loss: 2.65e-08 w_0: 1.00, w_1: -0.00
Step: 180, Loss: 2.66e-10 w_0: 1.00, w_1: 0.00
Step: 190, Loss: 3.08e-09 w_0: 1.00, w_1: 0.00
Step: 200, Loss: 1.49e-10 w_0: 1.00, w_1: 0.00
Plot results
Visualize the learning process for charges and kernel weights.
# Plot colors, in order: orange, blue, light blue, pink, red, teal, grey, black.
palette = "#EE7733 #0077BB #33BBEE #EE3377 #CC3311 #009988 #BBBBBB #000000".split()
def plot_results(fname=None, show_snapshot=True, atoms=None):
    """
    Plot the learning process for charges and kernel weights.

    Args:
        fname (str): File name to save the plot. If None, the plot is not saved.
        show_snapshot (bool): Whether to show a snapshot of the atomic configuration.
        atoms (ase.Atoms): Configuration drawn in the inset when ``show_snapshot``
            is True. If None, the last loaded frame (``frames[-1]``) is used.
    """
    fig, ax = plt.subplots(
        2,
        sharex=True,
        layout="constrained",
        dpi=200,
    )

    if show_snapshot:
        # Use an explicit configuration instead of relying on a leftover global
        # loop variable; frames[-1] is the frame that variable ended up holding.
        if atoms is None:
            atoms = frames[-1]
        ax_in = fig.add_axes([0.12, 0.14, 0.27, 0.27])
        ase.visualize.plot.plot_atoms(atoms, ax=ax_in, radii=0.75)
        ax_in.set_axis_off()

    # Epochs are numbered from 1 so that every point, including the first one,
    # can be shown on the logarithmic x-axis (log(0) is undefined).
    charge_epochs = np.arange(1, len(q_Na_timeseries) + 1)

    # Plot charge learning
    ax[0].plot(charge_epochs, np.array(q_Na_timeseries), c=palette[0], label=r"Na")
    ax[0].plot(charge_epochs, np.array(q_Cl_timeseries), c=palette[1], label=r"Cl")
    ax[0].set_ylim(-1.3, 1.3)
    ax[0].axhline(1, ls="dotted", c=palette[0])
    ax[0].axhline(-1, ls="dotted", c=palette[1])
    ax[0].legend()
    ax[0].set_ylabel(r"Charge / e")

    # Plot kernel weight learning
    ax[1].axhline(1, c=palette[2], ls="dotted")
    ax[1].axhline(0, c=palette[3], ls="dotted")
    weights_timeseries_array = np.array(weights_timeseries)
    weight_epochs = np.arange(1, len(weights_timeseries_array) + 1)
    ax[1].plot(weight_epochs, weights_timeseries_array[:, 0], label="p=1", c=palette[2])
    ax[1].plot(weight_epochs, weights_timeseries_array[:, 1], label="p=2", c=palette[3])
    ax[1].set_ylim(-0.2, 1.2)
    ax[1].legend()
    ax[1].set_ylabel("Kernel weight")

    for a in ax:
        a.set_xscale("log")
    ax[1].set_xlabel("Learning epoch")
    fig.align_labels()

    if fname is not None:
        fig.savefig(fname, transparent=True, bbox_inches="tight")
    plt.show()
# Call the plot function to visualize results
# Render the learning curves and also save them to a PDF file.
plot_results(fname="toy_model_learning.pdf", show_snapshot=True)
data:image/s3,"s3://crabby-images/fc931/fc931acc5b727a55f36c299219371eab95527414" alt="torchpme learning"
Total running time of the script: (1 minutes 26.515 seconds)