Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions tests/integration/model_bridge/test_cohere_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,20 +207,36 @@ def test_embed_weight_equals_hf_embed_tokens(
max_diff < 1e-6
), f"embed.W_E was corrupted (possibly by logit_scale fold): max_diff={max_diff:.6f}"

def test_embed_and_unembed_weights_differ(
self, cohere_bridge_processed: TransformerBridge
) -> None:
# After the logit_scale fold, embed.W_E and unembed.weight must NOT be identical.
# If they are, the untie or fold did not take effect.
logit_scale = getattr(cohere_bridge_processed.cfg, "logit_scale")
if logit_scale == 1.0:
pytest.skip("logit_scale=1.0 — fold is a no-op, skip this check")
tl_embed = cohere_bridge_processed.embed.W_E
tl_unembed = cohere_bridge_processed.unembed.original_component.weight
assert not torch.allclose(tl_embed, tl_unembed), (
"embed.W_E and unembed.weight are identical — "
"logit_scale fold may not have been applied or untied correctly"
@pytest.mark.parametrize("logit_scale", [0.0625, 1.0])
def test_embed_and_unembed_weights_differ(self, logit_scale: float) -> None:
# After the logit_scale fold, embed.W_E and unembed.weight must NOT be
# identical for a non-trivial scale. logit_scale=1.0 is kept as a regression
# guard for the no-op case, where the two weights stay tied.
#
# cfg.logit_scale is set before process_weights so the fold (which reads it
# inside preprocess_weights) runs with the parametrized value.
bridge = TransformerBridge.boot_transformers(MODEL, device="cpu")
bridge.cfg.logit_scale = logit_scale # type: ignore[attr-defined]
bridge.process_weights(
fold_ln=False,
center_writing_weights=False,
center_unembed=False,
fold_value_biases=False,
refactor_factored_attn_matrices=False,
)
tl_embed = bridge.embed.W_E
tl_unembed = bridge.unembed.original_component.weight
weights_identical = torch.allclose(tl_embed, tl_unembed)
if logit_scale == 1.0:
assert weights_identical, (
"embed.W_E and unembed.weight should remain tied when logit_scale=1.0 "
"(the fold is a no-op)"
)
else:
assert not weights_identical, (
"embed.W_E and unembed.weight are identical — "
"logit_scale fold may not have been applied or untied correctly"
)


# ---------------------------------------------------------------------------
Expand Down
Loading