From bfcd7a0c404a93b4bc0f8b58a56005feba57890d Mon Sep 17 00:00:00 2001
From: Ephraiem Sarabamoun
Date: Sat, 30 May 2026 06:50:42 -0700
Subject: [PATCH] Parametrize Cohere logit_scale fold test with a non-trivial scale

The test_embed_and_unembed_weights_differ test guarded the fold
assertion behind an `if logit_scale == 1.0: pytest.skip(...)` branch.
The fixture model trl-internal-testing/tiny-CohereForCausalLM ships
logit_scale=0.0625, so the guard was dead code and the no-op case was
never covered.

Parametrize logit_scale over a non-trivial value (0.0625) and 1.0. The
non-trivial case sets cfg.logit_scale before process_weights so the fold
runs and the assertion that embed.W_E and unembed.weight differ fires.
The 1.0 case is kept as a regression guard that asserts the weights stay
tied when the fold is a no-op.

Closes #1325.
---
 .../model_bridge/test_cohere_adapter.py       | 42 +++++++++++++------
 1 file changed, 29 insertions(+), 13 deletions(-)

diff --git a/tests/integration/model_bridge/test_cohere_adapter.py b/tests/integration/model_bridge/test_cohere_adapter.py
index 5e92c1616..3a29c196a 100644
--- a/tests/integration/model_bridge/test_cohere_adapter.py
+++ b/tests/integration/model_bridge/test_cohere_adapter.py
@@ -207,20 +207,36 @@ def test_embed_weight_equals_hf_embed_tokens(
             max_diff < 1e-6
         ), f"embed.W_E was corrupted (possibly by logit_scale fold): max_diff={max_diff:.6f}"
 
-    def test_embed_and_unembed_weights_differ(
-        self, cohere_bridge_processed: TransformerBridge
-    ) -> None:
-        # After the logit_scale fold, embed.W_E and unembed.weight must NOT be identical.
-        # If they are, the untie or fold did not take effect.
-        logit_scale = getattr(cohere_bridge_processed.cfg, "logit_scale")
-        if logit_scale == 1.0:
-            pytest.skip("logit_scale=1.0 — fold is a no-op, skip this check")
-        tl_embed = cohere_bridge_processed.embed.W_E
-        tl_unembed = cohere_bridge_processed.unembed.original_component.weight
-        assert not torch.allclose(tl_embed, tl_unembed), (
-            "embed.W_E and unembed.weight are identical — "
-            "logit_scale fold may not have been applied or untied correctly"
+    @pytest.mark.parametrize("logit_scale", [0.0625, 1.0])
+    def test_embed_and_unembed_weights_differ(self, logit_scale: float) -> None:
+        # After the logit_scale fold, embed.W_E and unembed.weight must NOT be
+        # identical for a non-trivial scale. logit_scale=1.0 is kept as a regression
+        # guard for the no-op case, where the two weights stay tied.
+        #
+        # cfg.logit_scale is set before process_weights so the fold (which reads it
+        # inside preprocess_weights) runs with the parametrized value.
+        bridge = TransformerBridge.boot_transformers(MODEL, device="cpu")
+        bridge.cfg.logit_scale = logit_scale  # type: ignore[attr-defined]
+        bridge.process_weights(
+            fold_ln=False,
+            center_writing_weights=False,
+            center_unembed=False,
+            fold_value_biases=False,
+            refactor_factored_attn_matrices=False,
         )
+        tl_embed = bridge.embed.W_E
+        tl_unembed = bridge.unembed.original_component.weight
+        weights_identical = torch.allclose(tl_embed, tl_unembed)
+        if logit_scale == 1.0:
+            assert weights_identical, (
+                "embed.W_E and unembed.weight should remain tied when logit_scale=1.0 "
+                "(the fold is a no-op)"
+            )
+        else:
+            assert not weights_identical, (
+                "embed.W_E and unembed.weight are identical — "
+                "logit_scale fold may not have been applied or untied correctly"
+            )
 
 
 # ---------------------------------------------------------------------------