Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion inference_engine/backends/mlx/cross_model_dlm_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def restored_prefill_cache(
restored_v_per_layer: Dict[int, Any],
evicted_positions: Sequence[int],
prefill_chunk_size: int = 0,
cache_factory: Optional[Callable[[Any], Any]] = None,
):
"""Prefill ONCE with restoration, capturing the restored K/V into a
persistent mlx_lm prompt cache; return ``(cache, last_logits)``.
Expand All @@ -397,7 +398,15 @@ def restored_prefill_cache(
text_model = resolve_mlx_text_model(mlx_model)
T = len(list(input_ids))
evicted = set(int(p) for p in evicted_positions if 0 <= int(p) < T)
cache = make_prompt_cache(mlx_model)
# cache_factory lets the caller swap the model's native hybrid cache for an
# all-`KVCache` layout (full store for sliding layers too) so that the
# spec-decode accept/reject rollback can use mlx_lm's native, SOUND
# `trim_prompt_cache` (keep accepted K/V, drop only rejected) instead of the
# full re-forward carry — `RotatingKVCache` is not trimmable once wrapped.
# Sliding attention stays byte-exact: the window mask is applied regardless
# of cache capacity. (Costs O(T) sliding KV during decode; fine for the
# short-context code/agent workloads this targets.)
cache = (cache_factory or make_prompt_cache)(mlx_model)

def _slice_restored(a, start: int, end: int):
if a is None:
Expand Down
144 changes: 144 additions & 0 deletions inference_engine/backends/mlx/fused_specdecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def __init__(
self._last_aux_mx: Optional[List[Any]] = None
self._capture_aux = False
self._block_snapshot: Optional[List[Dict[str, Any]]] = None
self._full_kv = False

def reset(self) -> None:
self._cache = None
Expand All @@ -174,16 +175,22 @@ def prefill(
restored_v_per_layer: Dict[int, Any],
evicted_positions: Sequence[int],
prefill_chunk_size: int = 0,
full_kv: bool = False,
) -> None:
if not prompt_ids:
raise ValueError("prompt_ids must be non-empty")
self.reset()
# full_kv=True → all-`KVCache` layout so accept/reject rollback can use
# SOUND native trim (keep accepted, drop rejected) with no re-forward.
self._full_kv = bool(full_kv)
factory = make_full_kv_prompt_cache if full_kv else None
self._cache, self.next_token_logits = restored_prefill_cache(
self.mlx_model, list(prompt_ids),
restored_k_per_layer=restored_k_per_layer,
restored_v_per_layer=restored_v_per_layer,
evicted_positions=evicted_positions,
prefill_chunk_size=prefill_chunk_size,
cache_factory=factory,
)
self._past_len = len(prompt_ids)

Expand Down Expand Up @@ -336,6 +343,143 @@ def lm_head_fn(h: Any) -> Any:
# Single-sync all-MLX fused loop (levers ① ② ③ of the Step-2 throughput plan;
# docs/mlx-port-lessons.md "Step-2 rescue status").
# --------------------------------------------------------------------------- #
def make_full_kv_prompt_cache(mlx_model: Any) -> List[Any]:
"""Build a prompt cache that uses a full append-only ``KVCache`` for EVERY
layer (including the sliding-attention ones, which the model's native
``make_cache`` would give a ``RotatingKVCache``).

Why: ``RotatingKVCache`` is not trimmable once the ring has wrapped
(``is_trimmable`` → ``offset < max_size``), so spec-decode accept/reject
rollback cannot keep the accepted K/V via a cheap trim — it must re-forward
(the v3 carry penalty). With an all-``KVCache`` layout, ``trim_prompt_cache``
is a sound O(1) slice on every layer, so the loop keeps accepted K/V and
drops only the rejected tail (CUDA `DynamicCache` parity). Sliding attention
remains byte-exact because the per-layer window mask is applied regardless
of cache capacity; the only cost is O(T) sliding KV during decode.
"""
from mlx_lm.models.cache import make_prompt_cache, KVCache # type: ignore

n = len(make_prompt_cache(mlx_model))
return [KVCache() for _ in range(n)]


def fused_specdecode_generate_mlx_trim(
adapter: "MLXRestoredIncrementalVerifier",
drafter: Any,
*,
aux_prompt: Sequence[Any],
embed_fn: Callable[[Any], Any],
lm_head_fn: Callable[[Any], Any],
gen_tokens: int,
block_size: int,
eos_ids: Sequence[int] = (),
single_fused: bool = False,
) -> Dict[str, Any]:
"""CUDA-parity fused spec decode: KEEP accepted K/V, TRIM only the rejected
tail (no rollback, no carry re-forward). Requires the adapter to be
prefilled with ``full_kv=True`` (all-``KVCache`` layout) so the native
``trim_prompt_cache`` is sound. Levers ①②③ retained (lazy draft+verify
single graph, in-graph cumprod acceptance, carried correction).

Per block: forward ``[bonus + drafts]`` (L tokens) → cache = base+L; accept
the leading match count ``k`` (bonus always accepts); ``trim_prompt_cache``
drops the L−k rejected tokens; advance ``_past_len`` by ``k``. The accepted
tokens' K/V (computed in this forward) stay in the cache — never recomputed.
"""
import mlx.core as mx # type: ignore
from mlx_lm.models.cache import trim_prompt_cache # type: ignore

eos = set(int(t) for t in eos_ids)
C = adapter._past_len
ctx_kv = drafter.make_context_kv(list(aux_prompt), mx.arange(0, C))
mx.async_eval([t for kv in ctx_kv for t in kv])
timing = {"ctx_kv_build_s": 0.0, "build_s": 0.0, "eval_s": 0.0, "extend_s": 0.0}
adapter._capture_aux = True

generated: List[int] = []
accepts: List[int] = []
block_evals: List[float] = []
ctx_len = C
try:
while len(generated) < gen_tokens:
L = min(block_size, gen_tokens - len(generated))
base = adapter._past_len
t_build = time.perf_counter()
bonus_id = mx.argmax(adapter.next_token_logits) # lazy scalar
n_draft = max(L - 1, 0)
if n_draft:
drafts = drafter.draft_block_ids(
ctx_kv, bonus_id, embed_fn, lm_head_fn,
n_masks=n_draft, context_len=base)
check_ids = mx.concatenate([bonus_id[None], drafts]) # [L]
if not single_fused:
mx.eval(check_ids) # two-phase (drafter graph before 26B)
# single_fused=True → leave check_ids LAZY so the drafter and
# 26B verify fuse into ONE graph (the path b876 found Metal-
# pathological); this probe times it to classify the instability.
else:
check_ids = bonus_id[None]
block_logits = adapter.forward_block_lazy(check_ids[None]) # [L, V]
# in-graph greedy acceptance over the check region
pred_rows = mx.concatenate(
[adapter.next_token_logits[None], block_logits[:max(L - 1, 0)]],
axis=0)
matches = (mx.argmax(pred_rows, axis=-1) == check_ids)
accepted_mx = mx.sum(mx.cumprod(matches.astype(mx.int32)))
rows = mx.concatenate(
[adapter.next_token_logits[None], block_logits], axis=0) # [L+1,V]
next_row = mx.take(rows, accepted_mx[None], axis=0)[0] # [V]
timing["build_s"] += time.perf_counter() - t_build
t_eval = time.perf_counter()
mx.eval(accepted_mx, check_ids)
blk_eval = time.perf_counter() - t_eval
timing["eval_s"] += blk_eval
block_evals.append(round(blk_eval, 4))
accepted = int(accepted_mx.item())
check = [int(x) for x in check_ids.tolist()]
commit = check[:accepted]
generated += commit
accepts.append(accepted)
adapter.next_token_logits = next_row
aux_rows = adapter._last_aux_mx
# KEEP accepted (positions base..base+accepted-1), TRIM rejected.
drop = L - accepted
if drop > 0:
trim_prompt_cache(adapter._cache, drop)
adapter._past_len = base + accepted
S_new = adapter._past_len
lo, hi = ctx_len - base, S_new - base
if hi > lo and aux_rows is not None:
t_extend = time.perf_counter()
new_aux = [a[lo:hi][None] for a in aux_rows]
ctx_kv = drafter.extend_context_kv(
ctx_kv,
drafter.make_context_kv(new_aux, mx.arange(ctx_len, S_new)))
mx.async_eval([t for kv in ctx_kv for t in kv])
ctx_len = S_new
timing["extend_s"] += time.perf_counter() - t_extend
if any(t in eos for t in commit):
break
finally:
adapter._capture_aux = False
generated = generated[:gen_tokens]
return {
"tokens": generated,
"blocks": len(accepts),
"mean_accept_len": (round(sum(accepts) / len(accepts), 3)
if accepts else 0.0),
"decode_tokens": len(generated),
"loop": ("mlx_trim_single_fused_probe" if single_fused
else "mlx_trim_keep_accepted_cuda_parity"),
"single_fused": bool(single_fused),
"block_eval_s_first8": block_evals[:8],
"block_eval_s_max": (round(max(block_evals), 4) if block_evals else None),
"block_eval_s_mean": (round(sum(block_evals) / len(block_evals), 4)
if block_evals else None),
"time_breakdown_s": {k: round(v, 3) for k, v in timing.items()},
}


def fused_specdecode_generate_mlx(
adapter: "MLXRestoredIncrementalVerifier",
drafter: Any,
Expand Down
123 changes: 123 additions & 0 deletions inference_engine/bridge/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,129 @@ def _harness_preset(
timeout_minutes=45,
params={"path": ("path:tests", None)},
),
Preset(
name="k3-fused-singlefused-probe",
description="PROBE: single-fused (one drafter+26B graph) vs two-phase, "
"to classify the Metal instability. Small (n=2, gen=16) so a "
"pathological per-block eval is bounded. Compare block_eval_s "
"vs k3-fused-allmlx-code-trim (two-phase).",
command_templates=(
(
"python3", "scripts/research/k3_integrated_niah_eval_mac.py",
"--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}",
"--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}",
"--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}",
"--s5-exact-full-attn", "--fused-specdecode",
"--all-mlx-drafter", "--code-prompts", "--cuda-trim",
"--single-fused",
"--n-samples", "{n_samples}",
"--max-new-tokens", "{max_new_tokens}",
"--block-size", "{block_size}",
"--prefill-chunk-size", "512",
"--output",
"results/research/k3_mac_bridge_k3_fused_singlefused_probe.json",
),
),
timeout_minutes=60,
params={
"n_samples": ("int:n_samples", "2"),
"max_new_tokens": ("int:max_new_tokens", "16"),
"block_size": ("int:block_size", "4"),
},
validate_reports=False,
),
Preset(
name="k3-fused-allmlx-code-trim",
description="CUDA-parity rollback test: all-MLX fused + --cuda-trim "
"(all-KVCache + native trim, keep accepted / drop rejected, "
"no re-forward) on the code-completion workload. Compare "
"decode-only tok/s vs k3-fused-allmlx-code (v3 carry).",
command_templates=(
(
"python3", "scripts/research/k3_integrated_niah_eval_mac.py",
"--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}",
"--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}",
"--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}",
"--s5-exact-full-attn", "--fused-specdecode",
"--all-mlx-drafter", "--code-prompts", "--cuda-trim",
"--n-samples", "{n_samples}",
"--max-new-tokens", "{max_new_tokens}",
"--block-size", "{block_size}",
"--prefill-chunk-size", "512",
"--output",
"results/research/k3_mac_bridge_k3_fused_allmlx_code_trim.json",
),
),
timeout_minutes=120,
params={
"n_samples": ("int:n_samples", "8"),
"max_new_tokens": ("int:max_new_tokens", "128"),
"block_size": ("int:block_size", "4"),
},
validate_reports=False,
),
Preset(
name="k3-fused-allmlx-code",
description="HONEST spec-decode throughput probe: all-MLX fused on a "
"code-completion workload (naturally-long, predictable gen "
"= the spec-decode sweet spot), natural stop. Reports "
"decode-only tok/s (fused vs oracle AR) + acceptance.",
command_templates=(
(
"python3", "scripts/research/k3_integrated_niah_eval_mac.py",
"--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}",
"--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}",
"--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}",
"--s5-exact-full-attn", "--fused-specdecode",
"--all-mlx-drafter", "--code-prompts",
# natural stop (no --ignore-turn-stop); code finishes itself
"--n-samples", "{n_samples}",
"--max-new-tokens", "{max_new_tokens}",
"--block-size", "{block_size}",
"--prefill-chunk-size", "512",
"--output",
"results/research/k3_mac_bridge_k3_fused_allmlx_code.json",
),
),
timeout_minutes=120,
params={
"n_samples": ("int:n_samples", "8"),
"max_new_tokens": ("int:max_new_tokens", "128"),
"block_size": ("int:block_size", "4"),
},
validate_reports=False,
),
Preset(
name="k3-fused-allmlx-natural",
description="Acceptance probe: all-MLX fused, NATURAL stop (no "
"--ignore-turn-stop) so generation ends at the real "
"answer. Compare mean_accept_len vs the forced "
"k3-step2-fused-allmlx (which over-generates).",
command_templates=(
(
"python3", "scripts/research/k3_integrated_niah_eval_mac.py",
"--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}",
"--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}",
"--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}",
"--s5-exact-full-attn", "--fused-specdecode",
"--all-mlx-drafter",
# deliberately NO --ignore-turn-stop (natural stop)
"--n-samples", "{n_samples}",
"--max-new-tokens", "{max_new_tokens}",
"--block-size", "{block_size}",
"--prefill-chunk-size", "512",
"--output",
"results/research/k3_mac_bridge_k3_fused_allmlx_natural.json",
),
),
timeout_minutes=120,
params={
"n_samples": ("int:n_samples", "5"),
"max_new_tokens": ("int:max_new_tokens", "48"),
"block_size": ("int:block_size", "4"),
},
validate_reports=False,
),
)
}

Expand Down
Loading