diff --git a/tests/backends/mlx/test_fused_specdecode.py b/tests/backends/mlx/test_fused_specdecode.py index de92aa7a..dd860a43 100644 --- a/tests/backends/mlx/test_fused_specdecode.py +++ b/tests/backends/mlx/test_fused_specdecode.py @@ -51,6 +51,12 @@ def append_token(self, token_id): self.appends.append(token_id) return self.next_token_logits + def last_aux_torch_slice(self, start=0, end=None): + # Mirror MLXRestoredIncrementalVerifier.last_aux_torch_slice: per-aux-layer + # torch rows of the most recent forward_block, sliced [start:end]. + aux = self._last_aux or [torch.zeros(1, self.hidden)] + return [a[start:end] for a in aux] + class _FakeDrafter: def __init__(self, drafts): @@ -89,13 +95,16 @@ def test_fused_loop_full_acceptance(): res = fsd.fused_specdecode_generate( adapter, drafter, gen_tokens=5, block_size=4, eos_ids=(), **_loop_kwargs(drafter)) - # Block1: candidate=[100,101,102] all accepted (3) + correction 103. - # Block2: candidate=[104] accepted (1) + correction 105 -> truncated to 5. + # Block1: candidate=[100,101,102] fully accepted (3). On FULL acceptance the + # loop reuses block_logits[-1] (=103) as the next distribution and does NOT + # append a correction token. next=103. + # Block2: L=2 -> candidate=[103,200]; accept 103 (1), reject 200, correction + # =104 appended -> commit [103,104]; total 5 tokens. assert res["tokens"] == [100, 101, 102, 103, 104] assert res["blocks"] == 2 assert res["mean_accept_len"] == 2.0 # (3 + 1) / 2 assert adapter.commits[0] == (3, 3) # block1 verify-commit - assert adapter.appends == [103, 105] # one correction per block + assert adapter.appends == [104] # only block2's correction # capture flag toggled on during loop, off after. assert adapter._capture_aux is False # context K/V extended once per block. @@ -122,9 +131,27 @@ def test_fused_loop_stops_on_eos(): res = fsd.fused_specdecode_generate( adapter, drafter, gen_tokens=50, block_size=4, eos_ids=(103,), **_loop_kwargs(drafter)) - # correction 103 is EOS -> stop after first block. + # Block1 fully accepts [100,101,102] (no correction appended on full accept), + # leaving next=103. Block2's bonus is then 103 (EOS), committed and stopped. assert res["tokens"] == [100, 101, 102, 103] - assert res["blocks"] == 1 + assert res["blocks"] == 2 + + +def test_fused_loop_greedy_fallback_on_low_acceptance(): + adapter = _FakeAdapter(prompt_len=5, first_token=100) + # Each block accepts only the bonus (drafts mismatch the verifier), so after + # 2 blocks mean acceptance = 1.0 < 1.5 and the loop switches to plain greedy + # to finish the budget (no aux capture, no drafter extension past the blocks). + drafter = _FakeDrafter(drafts=[[999, 999, 999], [999, 999, 999]]) + res = fsd.fused_specdecode_generate( + adapter, drafter, gen_tokens=10, block_size=4, eos_ids=(), + **_loop_kwargs(drafter)) + # blocks 1-2 commit [100,101] then [102,103]; greedy fallback adds 104..109. + assert res["tokens"] == [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] + assert res["blocks"] == 2 # only the speculative blocks are counted + assert res["mean_accept_len"] == 1.0 # (1 + 1) / 2 + assert adapter._capture_aux is False # turned off for the greedy tail + assert drafter.extend_calls == 2 # extended only during the spec blocks # =========================================================================== # @@ -273,8 +300,10 @@ def as_linear(self, h): return "L" adapter._capture_aux = True logits = adapter.forward_block([7, 8]) assert logits == "ROW" # _Model returns row at [0] - # aux = [hs[1]] = [layer-0 output]; bridged after stripping batch ([0]). - assert adapter._last_aux == [("torch", ("row", 0))] + # aux = [hs[1]] = [layer-0 output], captured LAZILY in MX (_last_aux_mx); + # _last_aux stays None and the torch bridge happens on demand. + assert adapter._last_aux is None + assert adapter.last_aux_torch_slice() == [("torch", ("row", 0))] # commit_or_truncate trims by (forwarded - accepted) and advances _past_len adapter.commit_or_truncate(forwarded=2, accepted=1)