From 3968493f91f57fb505481c5dd608d5a3352c0021 Mon Sep 17 00:00:00 2001 From: Nirbhai Date: Fri, 29 May 2026 16:18:58 +0530 Subject: [PATCH 1/2] Fix run_with_cache(device=...) permanently moving the model The single-device branch moved the model and inputs to cache_device with no restore, leaving non-CPU models silently migrated and cfg.device stale. The move was redundant since make_cache_hook already offloads each captured activation, matching ActivationCache.to and the legacy get_caching_hooks contract. Flatten the conditional, add a regression test asserting original_model.to is not invoked, and document the device kwarg. --- .../test_bridge_run_with_cache_device.py | 36 +++++++++++++++++++ transformer_lens/model_bridge/bridge.py | 34 ++++++++---------- 2 files changed, 51 insertions(+), 19 deletions(-) create mode 100644 tests/integration/model_bridge/test_bridge_run_with_cache_device.py diff --git a/tests/integration/model_bridge/test_bridge_run_with_cache_device.py b/tests/integration/model_bridge/test_bridge_run_with_cache_device.py new file mode 100644 index 000000000..d0da4f51b --- /dev/null +++ b/tests/integration/model_bridge/test_bridge_run_with_cache_device.py @@ -0,0 +1,36 @@ +"""Regression coverage for `TransformerBridge.run_with_cache(device=...)`. + +The `device=` kwarg is a cache-offload knob: cached activations are stored on +that device, but the underlying model and inputs must stay where the caller +put them, matching `ActivationCache.to` and the legacy `get_caching_hooks` +("device to store on") contract. +""" + +from unittest.mock import patch + +import pytest + + +@pytest.fixture() +def bridge(distilgpt2_bridge): + """Alias the session fixture for concise test signatures.""" + return distilgpt2_bridge + + +def test_run_with_cache_device_does_not_move_model(bridge): + """`run_with_cache(device=...)` must not relocate the underlying model. + + CPU runners cannot reproduce the original cross-device crash directly + (`to('cpu')` is a no-op there), so we spy on `original_model.to` with + `Mock(wraps=...)` and assert it isn't invoked during the call. That + catches the regression on any platform. + """ + with patch.object(bridge.original_model, "to", wraps=bridge.original_model.to) as to_spy: + _, cache = bridge.run_with_cache(bridge.to_tokens("hello"), device="cpu") + + assert to_spy.call_count == 0, ( + f"run_with_cache(device=...) moved the underlying model " + f"({to_spy.call_count} call(s): {to_spy.call_args_list})." + ) + # And the cache itself still lands on the requested device. + assert next(iter(cache.values())).device.type == "cpu" diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index e0ff8b3c0..938906123 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -1968,6 +1968,8 @@ def run_with_cache( remove_batch_dim: Whether to remove batch dimension names_filter: Filter for which activations to cache (str, list of str, or callable) stop_at_layer: Layer to stop forward pass at (uses StopAtLayerException; cleans up KV cache on stop) + device: Where to store cached activations (matches ActivationCache.to; + does not move the model). Defaults to per-layer storage. **kwargs: Additional arguments # type: ignore[name-defined] Returns: @@ -2075,25 +2077,19 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor: hook_dict[block_hook_name].add_hook(stop_hook) hooks.append((hook_dict[block_hook_name], block_hook_name)) filtered_kwargs = kwargs.copy() - if cache_device is not None: - if getattr(self.cfg, "n_devices", 1) > 1: - # Moving a dispatched model to a single device collapses accelerate's - # split and breaks its routing hooks. The cache will stay spread across - # the per-layer devices; callers can .to(cache_device) on cache entries - # after the fact if they need a single-device cache. - warnings.warn( - f"run_with_cache(device={cache_device!r}) ignored: model is dispatched " - f"across {self.cfg.n_devices} devices via device_map. Cached activations " - "will remain on their per-layer devices.", - stacklevel=2, - ) - else: - self.original_model = self.original_model.to(cache_device) - if processed_args and isinstance(processed_args[0], torch.Tensor): - processed_args = [processed_args[0].to(cache_device)] + list(processed_args[1:]) - for key, value in filtered_kwargs.items(): - if isinstance(value, torch.Tensor): - filtered_kwargs[key] = value.to(cache_device) + # `cache_device` is honored by `make_cache_hook` above (`tensor.detach().to(cache_device)`); + # the model and inputs stay where the caller put them, matching `ActivationCache.to`. + if cache_device is not None and getattr(self.cfg, "n_devices", 1) > 1: + # Moving a dispatched model to a single device collapses accelerate's + # split and breaks its routing hooks. The cache will stay spread across + # the per-layer devices; callers can .to(cache_device) on cache entries + # after the fact if they need a single-device cache. + warnings.warn( + f"run_with_cache(device={cache_device!r}) ignored: model is dispatched " + f"across {self.cfg.n_devices} devices via device_map. Cached activations " + "will remain on their per-layer devices.", + stacklevel=2, + ) try: if "output_attentions" not in filtered_kwargs: filtered_kwargs["output_attentions"] = True From 480c26e4ae560c3e0f45917b7e1b5d0f4d9d3136 Mon Sep 17 00:00:00 2001 From: Nirbhai Date: Fri, 29 May 2026 16:32:48 +0530 Subject: [PATCH 2/2] Retire cache_dict workaround in return_cache device offload With the run_with_cache model-move fixed, TransformerBridge.generate return_cache device offload can use a run_with_cache(device=device) passthrough. The offload now happens at capture time, reducing peak memory. Drop the cache_dict direct-write and its justifying comment, simplify the offload test to a device-landing check. --- .../test_bridge_generate_return_cache.py | 31 +++++++------------ transformer_lens/model_bridge/bridge.py | 7 +---- 2 files changed, 13 insertions(+), 25 deletions(-) diff --git a/tests/integration/model_bridge/test_bridge_generate_return_cache.py b/tests/integration/model_bridge/test_bridge_generate_return_cache.py index dacf28f89..c6323818b 100644 --- a/tests/integration/model_bridge/test_bridge_generate_return_cache.py +++ b/tests/integration/model_bridge/test_bridge_generate_return_cache.py @@ -10,8 +10,6 @@ this does not change what is being tested. """ -import warnings - import pytest import torch @@ -149,25 +147,20 @@ def test_names_filter_scopes_cache(self, bridge): assert set(cache.cache_dict) == set(ref.cache_dict) assert len(cache.cache_dict) < 20 - def test_device_offload_no_spurious_warning(self, bridge): - """device= offloads cache tensors (cpu here) without ActivationCache.to's move_model warning.""" + def test_device_offload_lands_on_requested_device(self, bridge): + """device= offloads cache tensors to the requested device.""" tokens = bridge.to_tokens("The quick brown") - with warnings.catch_warnings(record=True) as caught: - warnings.simplefilter("always") - with torch.no_grad(): - _, cache = bridge.generate( - tokens, - max_new_tokens=4, - do_sample=False, - return_type="tokens", - return_cache=True, - device="cpu", - use_past_kv_cache=False, - ) + with torch.no_grad(): + _, cache = bridge.generate( + tokens, + max_new_tokens=4, + do_sample=False, + return_type="tokens", + return_cache=True, + device="cpu", + use_past_kv_cache=False, + ) assert str(cache["blocks.0.hook_resid_post"].device) == "cpu" - assert not any("move_model" in str(w.message) for w in caught), [ - str(w.message) for w in caught - ] class TestGenerateReturnCacheGuards: diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 938906123..acf5c74a0 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -2854,12 +2854,7 @@ def _logits_to_tuple(logits_list: list[torch.Tensor]) -> tuple[torch.Tensor, ... # cache is identical to run_with_cache(output_tokens) - all hook points, including # attention patterns. The guards above restrict this to single-sequence, decoder-only # text generation (see issue #697). - _, cache = self.run_with_cache(output_tokens, names_filter=names_filter) - if device is not None: - # Offload the cached activations to `device`. We move cache_dict directly rather - # than calling ActivationCache.to(device), which currently emits a spurious - # move_model DeprecationWarning. - cache.cache_dict = {key: value.to(device) for key, value in cache.cache_dict.items()} + _, cache = self.run_with_cache(output_tokens, names_filter=names_filter, device=device) return result, cache @torch.no_grad()