Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 12 additions & 19 deletions tests/integration/model_bridge/test_bridge_generate_return_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
this does not change what is being tested.
"""

import warnings

import pytest
import torch

Expand Down Expand Up @@ -149,25 +147,20 @@ def test_names_filter_scopes_cache(self, bridge):
assert set(cache.cache_dict) == set(ref.cache_dict)
assert len(cache.cache_dict) < 20

def test_device_offload_no_spurious_warning(self, bridge):
"""device= offloads cache tensors (cpu here) without ActivationCache.to's move_model warning."""
def test_device_offload_lands_on_requested_device(self, bridge):
"""device= offloads cache tensors to the requested device."""
tokens = bridge.to_tokens("The quick brown")
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
with torch.no_grad():
_, cache = bridge.generate(
tokens,
max_new_tokens=4,
do_sample=False,
return_type="tokens",
return_cache=True,
device="cpu",
use_past_kv_cache=False,
)
with torch.no_grad():
_, cache = bridge.generate(
tokens,
max_new_tokens=4,
do_sample=False,
return_type="tokens",
return_cache=True,
device="cpu",
use_past_kv_cache=False,
)
assert str(cache["blocks.0.hook_resid_post"].device) == "cpu"
assert not any("move_model" in str(w.message) for w in caught), [
str(w.message) for w in caught
]


class TestGenerateReturnCacheGuards:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Regression coverage for `TransformerBridge.run_with_cache(device=...)`.

The `device=` kwarg is a cache-offload knob: cached activations are stored on
that device, but the underlying model and inputs must stay where the caller
put them, matching `ActivationCache.to` and the legacy `get_caching_hooks`
("device to store on") contract.
"""

from unittest.mock import patch

import pytest


@pytest.fixture()
def bridge(distilgpt2_bridge):
    """Expose the ``distilgpt2_bridge`` fixture under a shorter name.

    Tests in this module request ``bridge`` so their signatures stay compact;
    the object handed back is the very same one the upstream fixture provides.
    """
    short_alias = distilgpt2_bridge
    return short_alias


def test_run_with_cache_device_does_not_move_model(bridge):
    """Check that `run_with_cache(device=...)` leaves the wrapped model in place.

    On a CPU-only runner the original cross-device crash cannot be triggered
    directly, because `to('cpu')` changes nothing there. Instead,
    `original_model.to` is replaced by a spy that still forwards to the real
    method (`wraps=`), and the test asserts the spy was never called while
    `run_with_cache` ran. This detects the regression on any platform.
    """
    model = bridge.original_model
    real_to = model.to
    # Tokenisation stays inside the patched scope so the spy observes exactly
    # the same sequence of calls as a single inline expression would.
    with patch.object(model, "to", wraps=real_to) as to_spy:
        prompt_tokens = bridge.to_tokens("hello")
        run_result = bridge.run_with_cache(prompt_tokens, device="cpu")
    _, cache = run_result

    assert to_spy.call_count == 0, (
        f"run_with_cache(device=...) moved the underlying model "
        f"({to_spy.call_count} call(s): {to_spy.call_args_list})."
    )

    # The cached activations themselves must still end up on the requested device.
    first_cached = next(iter(cache.values()))
    assert first_cached.device.type == "cpu"
41 changes: 16 additions & 25 deletions transformer_lens/model_bridge/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1968,6 +1968,8 @@ def run_with_cache(
remove_batch_dim: Whether to remove batch dimension
names_filter: Filter for which activations to cache (str, list of str, or callable)
stop_at_layer: Layer to stop forward pass at (uses StopAtLayerException; cleans up KV cache on stop)
device: Where to store cached activations (matches ActivationCache.to;
does not move the model). Defaults to per-layer storage.
**kwargs: Additional arguments
# type: ignore[name-defined]
Returns:
Expand Down Expand Up @@ -2075,25 +2077,19 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
hook_dict[block_hook_name].add_hook(stop_hook)
hooks.append((hook_dict[block_hook_name], block_hook_name))
filtered_kwargs = kwargs.copy()
if cache_device is not None:
if getattr(self.cfg, "n_devices", 1) > 1:
# Moving a dispatched model to a single device collapses accelerate's
# split and breaks its routing hooks. The cache will stay spread across
# the per-layer devices; callers can .to(cache_device) on cache entries
# after the fact if they need a single-device cache.
warnings.warn(
f"run_with_cache(device={cache_device!r}) ignored: model is dispatched "
f"across {self.cfg.n_devices} devices via device_map. Cached activations "
"will remain on their per-layer devices.",
stacklevel=2,
)
else:
self.original_model = self.original_model.to(cache_device)
if processed_args and isinstance(processed_args[0], torch.Tensor):
processed_args = [processed_args[0].to(cache_device)] + list(processed_args[1:])
for key, value in filtered_kwargs.items():
if isinstance(value, torch.Tensor):
filtered_kwargs[key] = value.to(cache_device)
# `cache_device` is honored by `make_cache_hook` above (`tensor.detach().to(cache_device)`);
# the model and inputs stay where the caller put them, matching `ActivationCache.to`.
if cache_device is not None and getattr(self.cfg, "n_devices", 1) > 1:
# Moving a dispatched model to a single device collapses accelerate's
# split and breaks its routing hooks. The cache will stay spread across
# the per-layer devices; callers can .to(cache_device) on cache entries
# after the fact if they need a single-device cache.
warnings.warn(
f"run_with_cache(device={cache_device!r}) ignored: model is dispatched "
f"across {self.cfg.n_devices} devices via device_map. Cached activations "
"will remain on their per-layer devices.",
stacklevel=2,
)
try:
if "output_attentions" not in filtered_kwargs:
filtered_kwargs["output_attentions"] = True
Expand Down Expand Up @@ -2858,12 +2854,7 @@ def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ...
# cache is identical to run_with_cache(output_tokens) - all hook points, including
# attention patterns. The guards above restrict this to single-sequence, decoder-only
# text generation (see issue #697).
_, cache = self.run_with_cache(output_tokens, names_filter=names_filter)
if device is not None:
# Offload the cached activations to `device`. We move cache_dict directly rather
# than calling ActivationCache.to(device), which currently emits a spurious
# move_model DeprecationWarning.
cache.cache_dict = {key: value.to(device) for key, value in cache.cache_dict.items()}
_, cache = self.run_with_cache(output_tokens, names_filter=names_filter, device=device)
return result, cache

@torch.no_grad()
Expand Down
Loading