Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions tests/backends/mlx/test_fused_specdecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def append_token(self, token_id):
self.appends.append(token_id)
return self.next_token_logits

def last_aux_torch_slice(self, start=0, end=None):
    """Return the ``[start:end]`` slice of every captured aux-layer tensor.

    Test double for ``MLXRestoredIncrementalVerifier.last_aux_torch_slice``:
    it hands back the per-aux-layer torch rows recorded in ``self._last_aux``
    (the most recent ``forward_block`` capture), each sliced ``[start:end]``.

    Args:
        start: First row to keep in each tensor (defaults to 0).
        end: One past the last row to keep; ``None`` keeps through the end.

    Returns:
        A new list with one sliced entry per element of the captured aux list.
    """
    captured = self._last_aux
    if not captured:
        # Nothing captured yet (None or empty): stand in a single all-zero
        # tensor of shape (1, self.hidden) so callers still get one entry.
        captured = [torch.zeros(1, self.hidden)]
    window = slice(start, end)
    return [layer[window] for layer in captured]


class _FakeDrafter:
def __init__(self, drafts):
Expand Down Expand Up @@ -89,13 +95,16 @@ def test_fused_loop_full_acceptance():
res = fsd.fused_specdecode_generate(
adapter, drafter, gen_tokens=5, block_size=4, eos_ids=(),
**_loop_kwargs(drafter))
# Block1: candidate=[100,101,102] all accepted (3) + correction 103.
# Block2: candidate=[104] accepted (1) + correction 105 -> truncated to 5.
# Block1: candidate=[100,101,102] fully accepted (3). On FULL acceptance the
# loop reuses block_logits[-1] (=103) as the next distribution and does NOT
# append a correction token. next=103.
# Block2: L=2 -> candidate=[103,200]; accept 103 (1), reject 200, correction
# =104 appended -> commit [103,104]; total 5 tokens.
assert res["tokens"] == [100, 101, 102, 103, 104]
assert res["blocks"] == 2
assert res["mean_accept_len"] == 2.0 # (3 + 1) / 2
assert adapter.commits[0] == (3, 3) # block1 verify-commit
assert adapter.appends == [103, 105] # one correction per block
assert adapter.appends == [104] # only block2's correction
# capture flag toggled on during loop, off after.
assert adapter._capture_aux is False
# context K/V extended once per block.
Expand All @@ -122,9 +131,27 @@ def test_fused_loop_stops_on_eos():
res = fsd.fused_specdecode_generate(
adapter, drafter, gen_tokens=50, block_size=4, eos_ids=(103,),
**_loop_kwargs(drafter))
# correction 103 is EOS -> stop after first block.
# Block1 fully accepts [100,101,102] (no correction appended on full accept),
# leaving next=103. Block2's bonus is then 103 (EOS), committed and stopped.
assert res["tokens"] == [100, 101, 102, 103]
assert res["blocks"] == 1
assert res["blocks"] == 2


def test_fused_loop_greedy_fallback_on_low_acceptance():
    """Low acceptance switches the fused loop to plain greedy decoding.

    Every draft mismatches the verifier, so each speculative block accepts only
    the bonus token. After two blocks the mean acceptance is 1.0, which is
    below the 1.5 cut-off, and the remaining budget is expected to be filled
    greedily: no aux capture and no further drafter extension.
    """
    fake_adapter = _FakeAdapter(prompt_len=5, first_token=100)
    # Two blocks' worth of drafts that can never match the verifier's tokens.
    rejected_drafts = [[999] * 3 for _ in range(2)]
    fake_drafter = _FakeDrafter(drafts=rejected_drafts)

    outcome = fsd.fused_specdecode_generate(
        fake_adapter,
        fake_drafter,
        gen_tokens=10,
        block_size=4,
        eos_ids=(),
        **_loop_kwargs(fake_drafter),
    )

    # Blocks 1-2 commit [100, 101] then [102, 103]; the greedy tail contributes
    # 104..109, giving ten consecutive token ids overall.
    assert outcome["tokens"] == list(range(100, 110))
    # Only the speculative blocks are counted, not the greedy tail.
    assert outcome["blocks"] == 2
    # (1 + 1) / 2 accepted tokens per block.
    assert outcome["mean_accept_len"] == 1.0
    # Aux capture is switched off for the greedy tail.
    assert fake_adapter._capture_aux is False
    # The drafter context was extended only during the two speculative blocks.
    assert fake_drafter.extend_calls == 2


# =========================================================================== #
Expand Down Expand Up @@ -273,8 +300,10 @@ def as_linear(self, h): return "L"
adapter._capture_aux = True
logits = adapter.forward_block([7, 8])
assert logits == "ROW" # _Model returns row at [0]
# aux = [hs[1]] = [layer-0 output]; bridged after stripping batch ([0]).
assert adapter._last_aux == [("torch", ("row", 0))]
# aux = [hs[1]] = [layer-0 output], captured LAZILY in MX (_last_aux_mx);
# _last_aux stays None and the torch bridge happens on demand.
assert adapter._last_aux is None
assert adapter.last_aux_torch_slice() == [("torch", ("row", 0))]

# commit_or_truncate trims by (forwarded - accepted) and advances _past_len
adapter.commit_or_truncate(forwarded=2, accepted=1)
Expand Down
Loading