diff --git a/tests/integration/model_bridge/test_cohere_adapter.py b/tests/integration/model_bridge/test_cohere_adapter.py
index 5e92c1616..3a29c196a 100644
--- a/tests/integration/model_bridge/test_cohere_adapter.py
+++ b/tests/integration/model_bridge/test_cohere_adapter.py
@@ -207,20 +207,36 @@ def test_embed_weight_equals_hf_embed_tokens(
             max_diff < 1e-6
         ), f"embed.W_E was corrupted (possibly by logit_scale fold): max_diff={max_diff:.6f}"
 
-    def test_embed_and_unembed_weights_differ(
-        self, cohere_bridge_processed: TransformerBridge
-    ) -> None:
-        # After the logit_scale fold, embed.W_E and unembed.weight must NOT be identical.
-        # If they are, the untie or fold did not take effect.
-        logit_scale = getattr(cohere_bridge_processed.cfg, "logit_scale")
-        if logit_scale == 1.0:
-            pytest.skip("logit_scale=1.0 — fold is a no-op, skip this check")
-        tl_embed = cohere_bridge_processed.embed.W_E
-        tl_unembed = cohere_bridge_processed.unembed.original_component.weight
-        assert not torch.allclose(tl_embed, tl_unembed), (
-            "embed.W_E and unembed.weight are identical — "
-            "logit_scale fold may not have been applied or untied correctly"
+    @pytest.mark.parametrize("logit_scale", [0.0625, 1.0])
+    def test_embed_and_unembed_weights_differ(self, logit_scale: float) -> None:
+        # After the logit_scale fold, embed.W_E and unembed.weight must NOT be
+        # identical for a non-trivial scale. logit_scale=1.0 is kept as a regression
+        # guard for the no-op case, where the two weights stay tied.
+        #
+        # cfg.logit_scale is set before process_weights so the fold (which reads it
+        # inside preprocess_weights) runs with the parametrized value.
+        bridge = TransformerBridge.boot_transformers(MODEL, device="cpu")
+        bridge.cfg.logit_scale = logit_scale  # type: ignore[attr-defined]
+        bridge.process_weights(
+            fold_ln=False,
+            center_writing_weights=False,
+            center_unembed=False,
+            fold_value_biases=False,
+            refactor_factored_attn_matrices=False,
         )
+        tl_embed = bridge.embed.W_E
+        tl_unembed = bridge.unembed.original_component.weight
+        weights_identical = torch.allclose(tl_embed, tl_unembed)
+        if logit_scale == 1.0:
+            assert weights_identical, (
+                "embed.W_E and unembed.weight should remain tied when logit_scale=1.0 "
+                "(the fold is a no-op)"
+            )
+        else:
+            assert not weights_identical, (
+                "embed.W_E and unembed.weight are identical — "
+                "logit_scale fold may not have been applied or untied correctly"
+            )
 
 
 # ---------------------------------------------------------------------------