Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"h5py>=3.15.1",
"numpy>=1.26,<3; python_version < '3.13'",
"numpy>=2.3.2,<3; python_version >= '3.13'",
"nvalchemi-toolkit-ops[torch]>=0.3.0",
"nvalchemi-toolkit-ops[torch]>=0.3.1",
"tables>=3.11.1",
"torch>=2",
"tqdm>=4.67",
Expand Down
105 changes: 105 additions & 0 deletions tests/models/test_dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
import torch

import torch_sim as ts
from tests.conftest import DEVICE, DTYPE
from tests.models.conftest import make_validate_model_outputs_test

Expand Down Expand Up @@ -69,6 +70,110 @@ def d3_model_r2scan() -> D3DispersionModel:
)


def test_d3_stress_matches_finite_strain_sign() -> None:
"""Stress should match dE/dstrain/V, not the opposite virial sign."""
row_cell = torch.tensor(
[[4.2, 0.3, 0.1], [0.2, 4.8, 0.4], [0.15, 0.35, 5.1]],
dtype=DTYPE,
device=DEVICE,
)
positions = torch.tensor(
[[0.4, 0.5, 0.6], [1.9, 1.4, 2.3], [3.1, 2.6, 1.7]],
dtype=DTYPE,
device=DEVICE,
)
state = ts.SimState(
positions=positions,
masses=torch.ones(3, dtype=DTYPE, device=DEVICE),
cell=row_cell.mT.unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([6, 8, 14], dtype=torch.int64, device=DEVICE),
)

gen = torch.Generator(device=DEVICE)
gen.manual_seed(1234)
max_z = 14
mesh = 5
rcov = torch.rand(max_z + 1, generator=gen, device=DEVICE) + 0.5
r4r2 = torch.rand(max_z + 1, generator=gen, device=DEVICE) + 0.5
c6ab = 20.0 * (
torch.rand(
max_z + 1,
max_z + 1,
mesh,
mesh,
generator=gen,
device=DEVICE,
)
+ 0.1
)
c6ab = 0.5 * (c6ab + c6ab.permute(1, 0, 3, 2))
cn_ref = 4.0 * torch.rand(
max_z + 1,
max_z + 1,
mesh,
mesh,
generator=gen,
device=DEVICE,
)
cn_ref = 0.5 * (cn_ref + cn_ref.permute(1, 0, 3, 2))

model = D3DispersionModel(
**PBE_BJ,
d3_params=D3Parameters(rcov=rcov, r4r2=r4r2, c6ab=c6ab, cn_ref=cn_ref),
cutoff=6.0,
device=DEVICE,
dtype=DTYPE,
compute_forces=True,
compute_stress=True,
)

stress = model(state)["stress"][0]
volume = state.volume[0]
frac_positions = torch.linalg.solve(row_cell.mT, positions.mT).mT
identity = torch.eye(3, dtype=DTYPE, device=DEVICE)

def strained_energy(strain: torch.Tensor) -> torch.Tensor:
strained_row_cell = row_cell @ (identity + strain)
strained_state = ts.SimState(
positions=frac_positions @ strained_row_cell,
masses=state.masses,
cell=strained_row_cell.mT.unsqueeze(0),
pbc=state.pbc,
atomic_numbers=state.atomic_numbers,
system_idx=state.system_idx,
)
return model(strained_state)["energy"][0]

step = 1e-3
finite_diff_stress = torch.zeros((3, 3), dtype=DTYPE, device=DEVICE)
for idx_i in range(3):
for idx_j in range(idx_i, 3):
strain = torch.zeros((3, 3), dtype=DTYPE, device=DEVICE)
if idx_i == idx_j:
strain[idx_i, idx_i] = step
energy_plus = strained_energy(strain)
strain[idx_i, idx_i] = -step
energy_minus = strained_energy(strain)
else:
strain[idx_i, idx_j] = 0.5 * step
strain[idx_j, idx_i] = 0.5 * step
energy_plus = strained_energy(strain)
strain[idx_i, idx_j] = -0.5 * step
strain[idx_j, idx_i] = -0.5 * step
energy_minus = strained_energy(strain)
stress_component = (energy_plus - energy_minus) / (2 * step * volume)
finite_diff_stress[idx_i, idx_j] = stress_component
finite_diff_stress[idx_j, idx_i] = stress_component

torch.testing.assert_close(
stress,
finite_diff_stress,
rtol=5e-3,
atol=5e-7,
)


test_d3_pbe_outputs = make_validate_model_outputs_test(
model_fixture_name="d3_model_pbe", device=DEVICE, dtype=DTYPE
)
Expand Down
92 changes: 92 additions & 0 deletions tests/models/test_electrostatics.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,98 @@ def test_dsf_nonzero_energy() -> None:
assert out["energy"].abs().item() > 0


@pytest.mark.parametrize(
("model_cls", "kwargs"),
[
pytest.param(DSFCoulombModel, {"cutoff": 8.0, "alpha": 0.2}, id="dsf"),
pytest.param(EwaldModel, {"cutoff": 8.0, "accuracy": 1e-6}, id="ewald"),
pytest.param(
PMEModel,
{"cutoff": 8.0, "accuracy": 1e-6, "mesh_spacing": 1.0},
id="pme",
),
],
)
def test_electrostatics_stress_matches_finite_strain_sign(
model_cls: type[DSFCoulombModel | EwaldModel | PMEModel],
kwargs: dict[str, float],
) -> None:
"""Electrostatic stress should match dE/dstrain/V, not the virial sign."""
row_cell = torch.tensor(
[[5.2, 0.3, 0.1], [0.2, 5.6, 0.4], [0.15, 0.35, 6.1]],
dtype=DTYPE,
device=DEVICE,
)
positions = torch.tensor(
[[0.4, 0.5, 0.6], [1.9, 1.4, 2.3], [3.1, 2.6, 1.7], [4.0, 3.4, 4.2]],
dtype=DTYPE,
device=DEVICE,
)
charges = torch.tensor([0.8, -0.7, 0.4, -0.5], dtype=DTYPE, device=DEVICE)
state = ts.SimState(
positions=positions,
masses=torch.ones(4, dtype=DTYPE, device=DEVICE),
cell=row_cell.mT.unsqueeze(0),
pbc=True,
atomic_numbers=torch.tensor([11, 17, 11, 17], dtype=torch.int64, device=DEVICE),
)
state._atom_extras["partial_charges"] = charges # noqa: SLF001

model = model_cls(
**kwargs,
device=DEVICE,
dtype=DTYPE,
compute_forces=True,
compute_stress=True,
)

stress = model(state)["stress"][0]
volume = state.volume[0]
frac_positions = torch.linalg.solve(row_cell.mT, positions.mT).mT
identity = torch.eye(3, dtype=DTYPE, device=DEVICE)

def strained_energy(strain: torch.Tensor) -> torch.Tensor:
strained_row_cell = row_cell @ (identity + strain)
strained_state = ts.SimState(
positions=frac_positions @ strained_row_cell,
masses=state.masses,
cell=strained_row_cell.mT.unsqueeze(0),
pbc=state.pbc,
atomic_numbers=state.atomic_numbers,
system_idx=state.system_idx,
)
strained_state._atom_extras["partial_charges"] = charges # noqa: SLF001
return model(strained_state)["energy"][0]

step = 1e-3
finite_diff_stress = torch.zeros((3, 3), dtype=DTYPE, device=DEVICE)
for idx_i in range(3):
for idx_j in range(idx_i, 3):
strain = torch.zeros((3, 3), dtype=DTYPE, device=DEVICE)
if idx_i == idx_j:
strain[idx_i, idx_i] = step
energy_plus = strained_energy(strain)
strain[idx_i, idx_i] = -step
energy_minus = strained_energy(strain)
else:
strain[idx_i, idx_j] = 0.5 * step
strain[idx_j, idx_i] = 0.5 * step
energy_plus = strained_energy(strain)
strain[idx_i, idx_j] = -0.5 * step
strain[idx_j, idx_i] = -0.5 * step
energy_minus = strained_energy(strain)
stress_component = (energy_plus - energy_minus) / (2 * step * volume)
finite_diff_stress[idx_i, idx_j] = stress_component
finite_diff_stress[idx_j, idx_i] = stress_component

torch.testing.assert_close(
stress,
finite_diff_stress,
rtol=5e-3,
atol=5e-7,
)


def test_ewald_pme_energy_agreement() -> None:
"""Ewald and PME should give the same converged Coulomb energy."""
state = _make_charged_state()
Expand Down
4 changes: 3 additions & 1 deletion torch_sim/models/dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
# d3_out[3] is only defined if compute_virial is True
# we use [-1] to index it to avoid typing errors.
volumes = state.volume.unsqueeze(-1).unsqueeze(-1)
stress = (d3_out[-1] * UnitConversion.Hartree_to_eV) / volumes
# nvalchemiops returns the negative strain-gradient virial; TorchSim stress
# follows dE / dstrain / volume, matching ASE and other TorchSim models.
stress = -(d3_out[-1] * UnitConversion.Hartree_to_eV) / volumes
results["stress"] = stress.to(self._dtype).detach()
return results
12 changes: 9 additions & 3 deletions torch_sim/models/electrostatics.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
results["forces"] = forces.to(self._dtype).detach()
if self._compute_stress:
volumes = state.volume.unsqueeze(-1).unsqueeze(-1)
stress = (out[-1] * UnitConversion.e2_per_Ang_to_eV) / volumes
# nvalchemiops returns the negative strain-gradient virial; TorchSim stress
# follows dE / dstrain / volume, matching ASE and other TorchSim models.
stress = -(out[-1] * UnitConversion.e2_per_Ang_to_eV) / volumes
results["stress"] = stress.to(self._dtype).detach()
return results

Expand Down Expand Up @@ -258,7 +260,9 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
results["forces"] = forces.to(self._dtype).detach()
if self._compute_stress:
volumes = state.volume.unsqueeze(-1).unsqueeze(-1)
stress = (out[-1] * UnitConversion.e2_per_Ang_to_eV) / volumes
# nvalchemiops returns the negative strain-gradient virial; TorchSim stress
# follows dE / dstrain / volume, matching ASE and other TorchSim models.
stress = -(out[-1] * UnitConversion.e2_per_Ang_to_eV) / volumes
results["stress"] = stress.to(self._dtype).detach()
return results

Expand Down Expand Up @@ -387,6 +391,8 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
results["forces"] = forces.to(self._dtype).detach()
if self._compute_stress:
volumes = state.volume.unsqueeze(-1).unsqueeze(-1)
stress = (out[-1] * UnitConversion.e2_per_Ang_to_eV) / volumes
# nvalchemiops returns the negative strain-gradient virial; TorchSim stress
# follows dE / dstrain / volume, matching ASE and other TorchSim models.
stress = -(out[-1] * UnitConversion.e2_per_Ang_to_eV) / volumes
results["stress"] = stress.to(self._dtype).detach()
return results
Loading