Skip to content

CUDA-parity rollback for the all-MLX fused loop (keep accepted K/V, trim only rejected) — +33% on code, ~AR parity#115

Draft
FluffyAIcode wants to merge 4 commits into
AgentMemory/k3-all-mlx-drafter-b876from
AgentMemory/mac-niah-accept-compare-2815
Draft

CUDA-parity rollback for the all-MLX fused loop (keep accepted K/V, trim only rejected) — +33% on code, ~AR parity#115
FluffyAIcode wants to merge 4 commits into
AgentMemory/k3-all-mlx-drafter-b876from
AgentMemory/mac-niah-accept-compare-2815

Conversation

@FluffyAIcode

@FluffyAIcode FluffyAIcode commented Jun 13, 2026

Copy link
Copy Markdown
Owner

What

Replaces the v3 carry re-forward rollback with a CUDA-DynamicCache-parity rollback: keep the accepted tokens' K/V, trim only the rejected tail — no re-forward. Adds a code-completion workload + acceptance/instability probes.

Why (root cause)

RotatingKVCache.is_trimmable()offset < max_size, so once the sliding ring wraps it can't be trimmed; v3 rolls the whole block back and re-forwards the carried accepted tokens every partial-accept block (~2 verifier forwards/block). That re-forward is a real, removable penalty — acceptance/quantization/length all ruled out by GPU+Mac experiments.

Fix

  • make_full_kv_prompt_cache: full KVCache for every layer (sliding too). Byte-exact (per-layer window mask applies regardless of cache capacity); cost is O(T) sliding KV during decode — fine for short-context code/agent.
  • fused_specdecode_generate_mlx_trim: forward [bonus+drafts], accept the in-graph cumprod k, trim_prompt_cache(L−k) (sound O(1) slice on all-KVCache), advance _past_len += k. Accepted K/V stay cached — never recomputed. Levers ①②③ retained. single_fused=True probe path.
  • restored_prefill_cache(+cache_factory), prefill(+full_kv), harness --cuda-trim / --single-fused / --code-prompts, presets k3-fused-allmlx-code[-trim], k3-fused-allmlx-natural, k3-fused-singlefused-probe.

Evidence (real Mac, kakeya-mac-m4, all-MLX fused, code workload)

Rollback fix (block-4): v3 carry 13.94 tok/s (0.51× AR) → CUDA-trim 17.31 tok/s (0.68× AR), +33%; long-completion samples ~0.85–0.96×; output bit-identical to v3 (6/8 byte-exact vs oracle; the 2/8 are pre-existing bf16 drift, identical in v3).

Block-size sweep (CUDA-trim, long-completion best sample):

block accept-len best long-sample note
4 3.1 ~0.96× AR
8 4.5 ~1.02–1.06× (>AR) sweet spot
16 4.5 ~0.74–0.89× worse — verify(16) cost wasted; accept plateaus

Acceptance gap: our all-MLX drafter's accept-len plateaus at ~4.5 regardless of block size; z-lab's reference reaches ~7.7 at block-16. That 4.5-vs-7.7 gap is the port-fidelity / alignment residual — so z-lab's block-16 choice only pays off with a reference-quality drafter; on our port block-8 is optimal and block-16 just wastes verify(16) compute.

Metal two-phase-sync probe (single_fused): at the short-context code workload, single-fused runs stably (~0.16s/block, no 143s blowup) → the b876 "single-fused-graph pathology" is scale-dependent (large-cache / ctx280), not fundamental to fusing the graphs. But collapsing to one sync is not the throughput lever: build_s 1.4s→0.15s yet eval_s 3.5→4.8s → net ≈ equal (~1.05× best). The binding constraint is the 26B verify(L) compute per block, not rollback (fixed) or sync count (≈equal).

Net

CUDA-trim + block-8 brings Mac long-code spec-decode from 0.26–0.51× up to ~AR parity (best ~1.05×) — a real, correct improvement. >AR meaningfully remains CUDA-favored (H200 1.27×) because the limit is the verify compute (H200's verify-batch is far cheaper relative to its decode); on Mac neither the rollback fix nor a one-sync fused graph overcomes it. The two open levers are higher drafter acceptance (close the 4.5→7.7 gap) and cheaper verify.

Note (separate, pre-existing)

All-MLX fused is not fully lossless (6/8) vs greedy AR due to bf16 drift (present in v3 too); worth fp32 verify accumulation / tie handling to restore byte-exactness.

Testing

  • ✅ Linux: compiles; +1 UT; 4 pre-existing test_fused_specdecode.py fixture failures unchanged (stash-verified).
  • ✅ Real Mac (bridge): CUDA-trim block-4 (0.68×, +33% vs v3), block-8 (~1.05× best long), block-16 (worse); single-fused probe (stable ~0.16s/block); all recall 1.0, output bit-identical to v3, evidence_violations: [].
Open in Web Open in Cursor 

cursoragent and others added 4 commits June 13, 2026 08:36
All-MLX fused but WITHOUT --ignore-turn-stop, so generation ends at the real
answer. For comparing mean_accept_len (natural-stop) vs the forced over-generation
of k3-step2-fused-allmlx, to confirm on the real Mac that the low '2.13' accept is
a forced-over-gen artifact, not a drafter/quant/restoration deficiency.

Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
…preset

Honest spec-decode throughput probe: all-MLX fused on naturally-long, predictable
code-completion prompts (the spec-decode sweet spot), natural stop. Reports
decode-only tok/s (fused vs oracle AR) + acceptance. --code-prompts skips the
NIAH recall gate (recall N/A by design).

Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
…pted, drop rejected)

Eliminates the v3 carry re-forward. Root cause: RotatingKVCache not trimmable once
wrapped (is_trimmable -> offset<max_size), so v3 rolls the block back + re-forwards
carried accepted tokens. Fix: prefill all-KVCache layout (sliding on full KVCache too
-- byte-exact, window mask applies regardless of capacity) -> trim_prompt_cache is a
sound O(1) slice on every layer.

- restored_prefill_cache: +cache_factory; fused_specdecode.make_full_kv_prompt_cache;
  fused_specdecode_generate_mlx_trim (forward L, keep accepted, trim L-k, no carry);
  adapter.prefill +full_kv; harness --cuda-trim; manifest k3-fused-allmlx-code-trim.
Linux: compiles; +1 UT; 4 pre-existing b876 failures unchanged.

Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
fused_specdecode_generate_mlx_trim(single_fused=True): skip the two-phase eval so
drafter+26B fuse into ONE graph (the b876-pathological path); report per-block eval
times (first8/max/mean). Harness --single-fused + preset k3-fused-singlefused-probe
(n=2,gen=16 so a pathological block is bounded). Classifies fundamental command-buffer
limit (eval scales w/ graph) vs fixable SDPA fallback (eval huge even at small scale).

Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants