CUDA-parity rollback for the all-MLX fused loop (keep accepted K/V, trim only rejected): +33% on code, ~AR parity #115
Draft
FluffyAIcode wants to merge 4 commits into
All-MLX fused but WITHOUT --ignore-turn-stop, so generation ends at the real answer. For comparing mean_accept_len (natural-stop) vs the forced over-generation of k3-step2-fused-allmlx, to confirm on the real Mac that the low '2.13' accept is a forced-over-gen artifact, not a drafter/quant/restoration deficiency. Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
…preset Honest spec-decode throughput probe: all-MLX fused on naturally-long, predictable code-completion prompts (the spec-decode sweet spot), natural stop. Reports decode-only tok/s (fused vs oracle AR) + acceptance. --code-prompts skips the NIAH recall gate (recall N/A by design). Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
…pted, drop rejected) Eliminates the v3 carry re-forward. Root cause: RotatingKVCache not trimmable once wrapped (is_trimmable -> offset<max_size), so v3 rolls the block back + re-forwards carried accepted tokens. Fix: prefill all-KVCache layout (sliding on full KVCache too -- byte-exact, window mask applies regardless of capacity) -> trim_prompt_cache is a sound O(1) slice on every layer. - restored_prefill_cache: +cache_factory; fused_specdecode.make_full_kv_prompt_cache; fused_specdecode_generate_mlx_trim (forward L, keep accepted, trim L-k, no carry); adapter.prefill +full_kv; harness --cuda-trim; manifest k3-fused-allmlx-code-trim. Linux: compiles; +1 UT; 4 pre-existing b876 failures unchanged. Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
fused_specdecode_generate_mlx_trim(single_fused=True): skip the two-phase eval so drafter+26B fuse into ONE graph (the b876-pathological path); report per-block eval times (first8/max/mean). Harness --single-fused + preset k3-fused-singlefused-probe (n=2,gen=16 so a pathological block is bounded). Classifies fundamental command-buffer limit (eval scales w/ graph) vs fixable SDPA fallback (eval huge even at small scale). Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
This was referenced Jun 13, 2026
Closed
What
Replaces the v3 carry re-forward rollback with a CUDA-DynamicCache-parity rollback: keep the accepted tokens' K/V, trim only the rejected tail, no re-forward. Adds a code-completion workload + acceptance/instability probes.
Why (root cause)
RotatingKVCache.is_trimmable() → offset < max_size, so once the sliding ring wraps it can no longer be trimmed; v3 therefore rolls the whole block back and re-forwards the carried accepted tokens on every partial-accept block (~2 verifier forwards/block). That re-forward is a real, removable penalty: GPU and Mac experiments ruled out acceptance, quantization, and length as the cause.
Fix
- make_full_kv_prompt_cache: full KVCache for every layer (sliding layers too). Byte-exact (the per-layer window mask applies regardless of cache capacity); the cost is O(T) sliding-layer KV during decode, which is fine for short-context code/agent workloads.
- fused_specdecode_generate_mlx_trim: forward [bonus+drafts], accept the in-graph cumprod k, trim_prompt_cache(L−k) (a sound O(1) slice on all-KVCache), advance _past_len += k. Accepted K/V stay cached and are never recomputed. Levers ①②③ retained. single_fused=True probe path.
- restored_prefill_cache (+cache_factory), prefill (+full_kv), harness --cuda-trim / --single-fused / --code-prompts, presets k3-fused-allmlx-code[-trim], k3-fused-allmlx-natural, k3-fused-singlefused-probe.
Evidence (real Mac, kakeya-mac-m4, all-MLX fused, code workload)
Rollback fix (block-4): v3 carry 13.94 tok/s (0.51× AR) → CUDA-trim 17.31 tok/s (0.68× AR), +33% in AR-relative throughput (0.51× → 0.68×; raw tok/s rises by about 24%); long-completion samples ~0.85–0.96×; output bit-identical to v3 (6/8 byte-exact vs oracle; the 2/8 are pre-existing bf16 drift, identical in v3).
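The keep-accepted, trim-rejected rollback described under Fix can be sketched with a toy cache. This is a minimal illustration, not the PR's code: `ToyKVCache`, `accepted_len`, and `trim_rollback_step` are hypothetical names, tokens are plain integers, and the acceptance rule (bonus token always kept, drafts kept while they match the verifier's greedy choice) is an assumption about what the in-graph cumprod computes.

```python
from itertools import accumulate
from operator import mul


class ToyKVCache:
    """Hypothetical stand-in for a full (non-rotating) KV cache: one entry per token."""

    def __init__(self, prompt_tokens):
        self.entries = list(prompt_tokens)

    def append(self, tokens):
        self.entries.extend(tokens)

    def trim(self, n):
        """Drop the last n entries: a tail slice, nothing is recomputed."""
        n = max(0, min(n, len(self.entries)))
        if n:
            del self.entries[-n:]
        return n


def accepted_len(block, verifier_next):
    """Return k, the number of block tokens to keep.

    block         = [bonus, d1, ..., d_{L-1}]
    verifier_next = verifier's greedy next token after each block position (length L)
    Assumption: the bonus is always kept; draft d_i is kept only if every earlier
    draft matched and d_i equals the verifier's prediction at position i-1.
    """
    matches = [int(d == v) for d, v in zip(block[1:], verifier_next)]
    # The running product stays 1 until the first mismatch, then is 0 for the rest.
    return 1 + sum(accumulate(matches, mul))


def trim_rollback_step(cache, past_len, block, verifier_next):
    """One verify step: cache all L tokens, then trim only the rejected tail."""
    cache.append(block)               # the verifier forward wrote K/V for all L tokens
    k = accepted_len(block, verifier_next)
    cache.trim(len(block) - k)        # drop L - k rejected entries, keep accepted K/V
    next_bonus = verifier_next[k - 1] # verifier's token after the last accepted one
    return past_len + k, k, next_bonus


cache = ToyKVCache([1, 2, 3])
past_len, k, bonus = trim_rollback_step(cache, 3, [10, 11, 12, 99], [11, 12, 13, 14])
```

In this toy run the third draft (99) is rejected, so only one cache entry is dropped and the three accepted entries stay in place; no second verifier forward is needed to rebuild them.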
Block-size sweep (CUDA-trim, long-completion best sample):
verify(16) cost wasted; accept plateaus
Acceptance gap: our all-MLX drafter's accept-len plateaus at ~4.5 regardless of block size; z-lab's reference reaches ~7.7 at block-16. That 4.5-vs-7.7 gap is the port-fidelity / alignment residual, so z-lab's block-16 choice only pays off with a reference-quality drafter; on our port block-8 is optimal and block-16 just wastes verify(16) compute.
Metal two-phase-sync probe (single_fused): at the short-context code workload, single-fused runs stably (~0.16 s/block, no 143 s blowup), so the b876 "single-fused-graph pathology" is scale-dependent (large-cache / ctx280), not fundamental to fusing the graphs. But collapsing to one sync is not the throughput lever: build_s 1.4 s → 0.15 s yet eval_s 3.5 → 4.8 s, so net ≈ equal (~1.05× best). The binding constraint is the 26B verify(L) compute per block, not rollback (fixed) or sync count (≈equal).
Net
CUDA-trim + block-8 brings Mac long-code spec-decode from 0.26–0.51× up to ~AR parity (best ~1.05×), a real, correct improvement. Meaningfully exceeding AR remains CUDA-favored (H200 1.27×) because the limit is the verify compute (H200's verify-batch is far cheaper relative to its decode); on Mac neither the rollback fix nor a one-sync fused graph overcomes it. The two open levers are higher drafter acceptance (close the 4.5→7.7 gap) and cheaper verify.
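A back-of-envelope model makes the two remaining levers concrete. It is a sketch under stated assumptions, not a measurement: `specdecode_speedup` is a hypothetical helper, accept-len is taken to be tokens emitted per block, and the block cost (draft + verify) is expressed in units of one AR decode step.

```python
def specdecode_speedup(accept_len, block_cost_in_ar_steps):
    """Speedup over AR = tokens emitted per block / block cost in AR-step units.

    Assumption: AR emits exactly one token per step, and accept_len counts
    every token a block contributes to the output.
    """
    if block_cost_in_ar_steps <= 0:
        raise ValueError("block cost must be positive")
    return accept_len / block_cost_in_ar_steps


# At ~AR parity with accept-len ~4.5, one block costs about 4.5 AR steps.
parity = specdecode_speedup(4.5, 4.5)

# Lever 1: higher acceptance at the same block cost (optimistic, since a
# larger block also makes verify more expensive).
higher_accept = specdecode_speedup(7.7, 4.5)

# Lever 2: cheaper verify at the same acceptance.
cheaper_verify = specdecode_speedup(4.5, 3.5)
```

Under this model a drafter that reached 7.7 would only deliver its full gain if the block cost stayed near 4.5 AR steps; that is the optimistic bound, and the block-16 result above suggests the verify cost grows enough to erase it on Mac.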
Note (separate, pre-existing)
All-MLX fused is not fully lossless vs greedy AR (6/8 byte-exact) due to bf16 drift (present in v3 too); fp32 verify accumulation and explicit tie handling are worth trying to restore byte-exactness.
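One way bf16 rounding can flip a greedy choice is by collapsing two nearby logits into a tie. The sketch below emulates bf16 with the standard library only; `to_bf16` and `argmax` are hypothetical helpers, the logit values are invented for illustration, and nothing here diagnoses the two mismatching samples.

```python
import struct


def to_bf16(x):
    """Round a float to the nearest bf16 value (round-to-nearest-even), as a float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a rounding bias to the 16 low bits that bf16 discards, then truncate.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def argmax(values):
    """Index of the largest value; the first index wins on ties."""
    return max(range(len(values)), key=values.__getitem__)


logits = [10.01, 10.02]                  # illustrative values, distinct at full precision
logits_bf16 = [to_bf16(v) for v in logits]

pick_full = argmax(logits)               # the larger logit wins
pick_bf16 = argmax(logits_bf16)          # both round to the same bf16 value
```

Near 10.0 adjacent bf16 values are 0.0625 apart, so both logits round to 10.0 and the winner is decided by tie-breaking order instead of the logits. Accumulating the verify logits in fp32, or fixing a deterministic tie rule shared with the AR oracle, targets exactly this case.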
Testing
test_fused_specdecode.py fixture failures unchanged (stash-verified).
evidence_violations: [].