diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py index 946c567c8..e820b64d3 100644 --- a/transformer_lens/ActivationCache.py +++ b/transformer_lens/ActivationCache.py @@ -14,7 +14,6 @@ class first, including the examples, and then skimming the available methods. Yo from __future__ import annotations import logging -import warnings from typing import ( TYPE_CHECKING, Any, @@ -220,7 +219,7 @@ def __len__(self) -> int: """ return len(self.cache_dict) - def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCache: + def to(self, device: Union[str, torch.device]) -> ActivationCache: """Move the Cache to a Device. Mostly useful for moving the cache to the CPU after model computation finishes to save GPU @@ -231,23 +230,10 @@ def to(self, device: Union[str, torch.device], move_model=False) -> ActivationCa Args: device: The device to move the cache to (e.g. `torch.device.cpu`). - move_model: - Whether to also move the model to the same device. @deprecated """ - # Move model is deprecated as we plan on de-coupling the classes - if move_model is not None: - warnings.warn( - "The 'move_model' parameter is deprecated.", - DeprecationWarning, - ) - warn_if_mps(device) self.cache_dict = {key: value.to(device) for key, value in self.cache_dict.items()} - - if move_model: - self.model.to(device) - return self def toggle_autodiff(self, mode: bool = False):