Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 1 addition & 15 deletions transformer_lens/ActivationCache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ class first, including the examples, and then skimming the available methods. Yo
from __future__ import annotations

import logging
import warnings
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -220,7 +219,7 @@ def __len__(self) -> int:
"""
return len(self.cache_dict)

def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCache:
def to(self, device: Union[str, torch.device]) -> ActivationCache:
"""Move the Cache to a Device.

Mostly useful for moving the cache to the CPU after model computation finishes to save GPU
Expand All @@ -231,23 +230,10 @@ def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCa
Args:
device:
The device to move the cache to (e.g. `torch.device.cpu`).
move_model:
Whether to also move the model to the same device. @deprecated

"""
# Move model is deprecated as we plan on de-coupling the classes
if move_model is not None:
warnings.warn(
"The 'move_model' parameter is deprecated.",
DeprecationWarning,
)

warn_if_mps(device)
self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()}

if move_model:
self.model.to(device)

return self

def toggle_autodiff(self, mode: bool = False):
Expand Down
Loading