Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
593 changes: 593 additions & 0 deletions examples/tutorials/nudged_elastic_band.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ include = ["docs/**/*.py", "docs/**/*.ipynb", "examples/**/*.py"]
[tool.ty.overrides.rules]
invalid-argument-type = "ignore"
invalid-assignment = "ignore"
invalid-attribute-override = "ignore"
not-iterable = "ignore"
not-subscriptable = "ignore"
unresolved-attribute = "ignore"
Expand Down
293 changes: 293 additions & 0 deletions tests/test_autobatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,269 @@ def test_binning_auto_batcher_restore_order_with_split_states(
assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers)


def test_binning_auto_batcher_keeps_group_together(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
lj_model: LennardJonesModel,
) -> None:
"""A multi-system group is treated as one unit and stays in a single batch."""
grouped = ts.concatenate_states([si_sim_state, fe_supercell_sim_state])
grouped.group_idx = torch.zeros(
grouped.n_systems, device=grouped.device, dtype=torch.long
)

batcher = BinningAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=float(grouped.n_atoms),
)
batcher.load_states(grouped)

# one group => one unit whose scaler is the sum over its systems
assert len(batcher.memory_scalers) == 1
assert batcher.memory_scalers[0] == grouped.n_atoms

batches = [batch for batch, _ in batcher]
assert len(batches) == 1
assert batches[0].n_systems == grouped.n_systems
assert batches[0].n_groups == 1

restored = batcher.restore_original_order(batches)
assert len(restored) == 1
assert restored[0].n_systems == grouped.n_systems


def test_binning_auto_batcher_packs_multiple_groups(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
lj_model: LennardJonesModel,
) -> None:
"""Multiple groups pack into one bin when memory allows (no throttling)."""
grouped = ts.concatenate_states([si_sim_state, si_sim_state, fe_supercell_sim_state])
grouped.group_idx = torch.tensor([0, 0, 1], device=grouped.device, dtype=torch.long)
group0 = 2 * si_sim_state.n_atoms
group1 = fe_supercell_sim_state.n_atoms

batcher = BinningAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=float(group0 + group1),
)
batcher.load_states(grouped)

assert batcher.memory_scalers == [group0, group1]

batches = [batch for batch, _ in batcher]
assert len(batches) == 1
assert batches[0].n_systems == 3
assert batches[0].n_groups == 2

restored = batcher.restore_original_order(batches)
assert len(restored) == 2
assert restored[0].n_systems == 2
assert restored[1].n_systems == 1


def test_binning_auto_batcher_restores_unsorted_group_bins(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
lj_model: LennardJonesModel,
monkeypatch: pytest.MonkeyPatch,
) -> None:
grouped = ts.concatenate_states([si_sim_state, si_sim_state, fe_supercell_sim_state])
grouped.group_idx = torch.tensor([0, 0, 1], device=grouped.device, dtype=torch.long)

def unsorted_bins(
index_to_scaler: dict[int, float], *, max_volume: float
) -> list[dict[int, float]]:
assert max_volume > 0
return [{1: index_to_scaler[1], 0: index_to_scaler[0]}]

monkeypatch.setattr("torch_sim.autobatching.to_constant_volume_bins", unsorted_bins)
batcher = BinningAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=float(grouped.n_atoms),
)
batcher.load_states(grouped)

batches = [batch for batch, _ in batcher]
restored = batcher.restore_original_order(batches)

assert batcher.index_bins == [[0, 1]]
assert len(restored) == 2
assert restored[0].n_systems == 2
assert restored[1].n_systems == 1


def test_binning_auto_batcher_does_not_split_group(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
lj_model: LennardJonesModel,
) -> None:
"""A group never spans bins; tight memory packs one group per batch."""
grouped = ts.concatenate_states([si_sim_state, si_sim_state, fe_supercell_sim_state])
grouped.group_idx = torch.tensor([0, 0, 1], device=grouped.device, dtype=torch.long)
group0 = 2 * si_sim_state.n_atoms
group1 = fe_supercell_sim_state.n_atoms

batcher = BinningAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=float(max(group0, group1)),
)
batcher.load_states(grouped)

batches = [batch for batch, _ in batcher]
assert len(batches) == 2
assert sorted(batch.n_systems for batch in batches) == [1, 2]

multi = next(batch for batch in batches if batch.n_systems == 2)
assert multi.n_groups == 1


def test_in_flight_auto_batcher_keeps_group_together(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
lj_model: LennardJonesModel,
) -> None:
grouped = ts.concatenate_states([si_sim_state, fe_supercell_sim_state])
grouped.group_idx = torch.zeros(
grouped.n_systems, device=grouped.device, dtype=torch.long
)
batcher = InFlightAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=float(grouped.n_atoms),
)
batcher.load_states(grouped)

state, [] = batcher.next_batch(None, None)
assert state is not None
assert state.n_systems == grouped.n_systems
assert state.n_groups == 1

next_state, completed_states = batcher.next_batch(
state, torch.ones(state.n_systems, dtype=torch.bool)
)
assert next_state is None
assert len(completed_states) == 1
assert completed_states[0].n_systems == grouped.n_systems


def test_in_flight_auto_batcher_packs_multiple_groups(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
lj_model: LennardJonesModel,
) -> None:
grouped = ts.concatenate_states([si_sim_state, si_sim_state, fe_supercell_sim_state])
grouped.group_idx = torch.tensor([0, 0, 1], device=grouped.device, dtype=torch.long)
group0 = 2 * si_sim_state.n_atoms
group1 = fe_supercell_sim_state.n_atoms
batcher = InFlightAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=float(group0 + group1),
)
batcher.load_states(grouped)

assert batcher.memory_scalers == [group0, group1]
state, [] = batcher.next_batch(None, None)
assert state is not None
assert state.n_systems == 3
assert state.n_groups == 2


def test_in_flight_auto_batcher_does_not_split_group(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
lj_model: LennardJonesModel,
) -> None:
grouped = ts.concatenate_states([si_sim_state, si_sim_state, fe_supercell_sim_state])
grouped.group_idx = torch.tensor([0, 0, 1], device=grouped.device, dtype=torch.long)
group0 = 2 * si_sim_state.n_atoms
group1 = fe_supercell_sim_state.n_atoms
batcher = InFlightAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=float(max(group0, group1)),
)
batcher.load_states(grouped)

first_batch, [] = batcher.next_batch(None, None)
assert first_batch is not None
assert first_batch.n_systems == 2
assert first_batch.n_groups == 1

second_batch, completed_states = batcher.next_batch(
first_batch, torch.ones(first_batch.n_systems, dtype=torch.bool)
)
assert second_batch is not None
assert len(completed_states) == 1
assert completed_states[0].n_systems == 2
assert second_batch.n_systems == 1
assert second_batch.n_groups == 1


def test_in_flight_auto_batcher_restore_order_with_grouped_state(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
lj_model: LennardJonesModel,
) -> None:
grouped = ts.concatenate_states([si_sim_state, si_sim_state, fe_supercell_sim_state])
grouped.group_idx = torch.tensor([0, 0, 1], device=grouped.device, dtype=torch.long)
batcher = InFlightAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=float(grouped.n_atoms),
)
batcher.load_states(grouped)
state, [] = batcher.next_batch(None, None)
assert state is not None

convergence = torch.tensor([False, False, True], dtype=torch.bool)
state, completed_states = batcher.next_batch(state, convergence)
all_completed = [*completed_states]
assert state is not None
assert state.n_systems == 2

state, completed_states = batcher.next_batch(
state, torch.ones(state.n_systems, dtype=torch.bool)
)
assert state is None
all_completed.extend(completed_states)

restored = batcher.restore_original_order(all_completed)
assert len(restored) == 2
assert restored[0].n_systems == 2
assert restored[1].n_systems == 1


def test_in_flight_auto_batcher_loads_batched_state_without_split_groups(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
lj_model: LennardJonesModel,
monkeypatch: pytest.MonkeyPatch,
) -> None:
grouped = ts.concatenate_states([si_sim_state, si_sim_state, fe_supercell_sim_state])
grouped.group_idx = torch.tensor([0, 0, 1], device=grouped.device, dtype=torch.long)

def fail_split_groups(_self: ts.SimState) -> list[ts.SimState]:
raise AssertionError("split_groups should not be called while loading")

monkeypatch.setattr(ts.SimState, "split_groups", fail_split_groups)
batcher = InFlightAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=float(grouped.n_atoms),
)
batcher.load_states(grouped)

state, [] = batcher.next_batch(None, None)
assert state is not None
assert state.n_groups == 2


def test_in_flight_max_metric_too_small(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
Expand Down Expand Up @@ -715,6 +978,36 @@ def test_in_flight_max_iterations(
assert batcher.iteration_count[idx] == max_iterations


def test_in_flight_max_iterations_completes_whole_group(
si_double_sim_state: ts.SimState,
lj_model: LennardJonesModel,
) -> None:
grouped_state = si_double_sim_state.clone()
grouped_state.group_idx = torch.zeros(
grouped_state.n_systems, device=grouped_state.device, dtype=torch.long
)
batcher = InFlightAutoBatcher(
model=lj_model,
memory_scales_with="n_atoms",
max_memory_scaler=800.0,
max_iterations=1,
)
batcher.load_states(grouped_state)

state, [] = batcher.next_batch(None, None)
assert state is not None
assert state.n_systems == grouped_state.n_systems
assert state.n_groups == 1

convergence_tensor = torch.zeros(state.n_systems, dtype=torch.bool)
next_state, completed_states = batcher.next_batch(state, convergence_tensor)

assert next_state is None
assert len(completed_states) == 1
assert completed_states[0].n_systems == grouped_state.n_systems
assert completed_states[0].n_groups == 1


@pytest.mark.parametrize(
"num_steps_per_batch",
[
Expand Down
59 changes: 59 additions & 0 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,65 @@ def test_fire_optimization(
)


@pytest.mark.parametrize("fire_flavor", get_args(FireFlavor))
def test_fire_uses_group_scoped_adaptive_state(
ar_double_sim_state: SimState, lj_model: ModelInterface, fire_flavor: FireFlavor
) -> None:
ar_double_sim_state.group_idx = torch.zeros(
ar_double_sim_state.n_systems,
device=ar_double_sim_state.device,
dtype=torch.int64,
)

state = ts.fire_init(
ar_double_sim_state,
lj_model,
fire_flavor=fire_flavor,
dt_start=0.1,
alpha_start=0.1,
)

assert state.n_groups == 1
assert state.dt.shape == (1,)
assert state.alpha.shape == (1,)
assert state.n_pos.shape == (1,)

updated = ts.fire_step(state=state, model=lj_model, dt_max=0.3)

assert updated.dt.shape == (1,)
assert updated.alpha.shape == (1,)
assert updated.n_pos.shape == (1,)


def test_fire_group_attributes_roundtrip_through_split_groups(
si_sim_state: SimState,
fe_supercell_sim_state: SimState,
lj_model: ModelInterface,
) -> None:
grouped = ts.concatenate_states([si_sim_state, si_sim_state, fe_supercell_sim_state])
grouped.group_idx = torch.tensor([0, 0, 1], device=grouped.device, dtype=torch.long)
dt_start = torch.tensor([0.1, 0.2], device=grouped.device)
alpha_start = torch.tensor([0.3, 0.4], device=grouped.device)
state = ts.fire_init(
grouped,
lj_model,
dt_start=dt_start,
alpha_start=alpha_start,
)
state.n_pos = torch.tensor([3, 4], device=grouped.device, dtype=torch.int32)

split_groups = state.split_groups()
roundtrip = ts.concatenate_states(split_groups)

assert len(split_groups) == 2
assert split_groups[0].n_systems == 2
assert split_groups[1].n_systems == 1
assert torch.allclose(roundtrip.dt, state.dt)
assert torch.allclose(roundtrip.alpha, state.alpha)
assert torch.equal(roundtrip.n_pos, state.n_pos)
assert torch.equal(roundtrip.group_idx, grouped.group_idx)


def test_bfgs_optimization(
ar_supercell_sim_state: SimState, lj_model: ModelInterface
) -> None:
Expand Down
Loading
Loading