From e7117a926d5d439dc95d7405192dfc65dae23976 Mon Sep 17 00:00:00 2001 From: axelray-dev Date: Tue, 2 Jun 2026 04:43:03 +0800 Subject: [PATCH] test: use supports_generation flag for generation compatibility test Replace the GENERATING_MODELS allowlist with a supports_generation flag on ArchitectureAdapter. The flag defaults to True and is set to False on encoder-only adapters (BERT, HuBERT) that don't support text generation. The test now checks model.adapter.supports_generation at runtime and skips via pytest.skip instead of maintaining a static model list. Fixes #1328 --- tests/unit/model_bridge/compatibility/test_utils.py | 12 ++++-------- .../model_bridge/architecture_adapter.py | 4 ++++ .../model_bridge/supported_architectures/bert.py | 2 ++ .../model_bridge/supported_architectures/hubert.py | 2 ++ 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/tests/unit/model_bridge/compatibility/test_utils.py b/tests/unit/model_bridge/compatibility/test_utils.py index 956c4f44f..9aa997c9a 100644 --- a/tests/unit/model_bridge/compatibility/test_utils.py +++ b/tests/unit/model_bridge/compatibility/test_utils.py @@ -93,15 +93,11 @@ def test_device_compatibility(self, model): def test_generation_compatibility(self, model): """Test that generation works correctly with TransformerBridge.""" + if not model.adapter.supports_generation: + pytest.skip("Generation not supported for this architecture") prompt = "Once upon a time" - - # Test basic generation if supported - try: - generated = model.generate(prompt, max_new_tokens=5) - assert isinstance(generated, (str, list, torch.Tensor)) - except (AttributeError, RuntimeError): - # Generation might not be implemented yet for all bridge models - pytest.skip("Generation not supported for this TransformerBridge model") + generated = model.generate(prompt, max_new_tokens=5) + assert isinstance(generated, (str, list, torch.Tensor)) @pytest.mark.parametrize("method", ["to_tokens", "to_string", "to_str_tokens"]) def test_tokenization_methods(self, model, method): diff --git a/transformer_lens/model_bridge/architecture_adapter.py b/transformer_lens/model_bridge/architecture_adapter.py index f6cdbafb3..6ed183c80 100644 --- a/transformer_lens/model_bridge/architecture_adapter.py +++ b/transformer_lens/model_bridge/architecture_adapter.py @@ -44,6 +44,10 @@ class ArchitectureAdapter: # documented in ~/.claude/plans/ssm-verification-compatibility.md. applicable_phases: list[int] = [1, 2, 3, 4] + # Whether this architecture supports text generation via generate(). + # Encoder-only models (e.g. BERT, HuBERT) should set this to False. + supports_generation: bool = True + def __init__(self, cfg: TransformerBridgeConfig) -> None: """Initialize the architecture adapter. diff --git a/transformer_lens/model_bridge/supported_architectures/bert.py b/transformer_lens/model_bridge/supported_architectures/bert.py index bcd4a4877..4a12e1c47 100644 --- a/transformer_lens/model_bridge/supported_architectures/bert.py +++ b/transformer_lens/model_bridge/supported_architectures/bert.py @@ -25,6 +25,8 @@ class BertArchitectureAdapter(ArchitectureAdapter): """Architecture adapter for BERT models.""" + supports_generation: bool = False + def __init__(self, cfg: Any) -> None: """Initialize the BERT architecture adapter. diff --git a/transformer_lens/model_bridge/supported_architectures/hubert.py b/transformer_lens/model_bridge/supported_architectures/hubert.py index 2f73b311a..0e11e8682 100644 --- a/transformer_lens/model_bridge/supported_architectures/hubert.py +++ b/transformer_lens/model_bridge/supported_architectures/hubert.py @@ -38,6 +38,8 @@ class HubertArchitectureAdapter(ArchitectureAdapter): prepare_model() detects this and adjusts component paths. """ + supports_generation: bool = False + def __init__(self, cfg: Any) -> None: super().__init__(cfg)