Skip to content
Merged
73 changes: 72 additions & 1 deletion inference_engine/backends/mlx/fused_specdecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
restored_prefill_cache,
)


# --------------------------------------------------------------------------- #
# Component A: capture verifier aux-layer hidden states (no transformers
# `output_hidden_states` on MLX → patch the decoder-layer __call__).
Expand Down Expand Up @@ -374,6 +373,7 @@ def fused_specdecode_generate_mlx_trim(
block_size: int,
eos_ids: Sequence[int] = (),
single_fused: bool = False,
stop_on_runaway: bool = True,
) -> Dict[str, Any]:
"""CUDA-parity fused spec decode: KEEP accepted K/V, TRIM only the rejected
tail (no rollback, no carry re-forward). Requires the adapter to be
Expand All @@ -399,6 +399,7 @@ def fused_specdecode_generate_mlx_trim(
generated: List[int] = []
accepts: List[int] = []
block_evals: List[float] = []
stopped_on_runaway = False
ctx_len = C
try:
while len(generated) < gen_tokens:
Expand Down Expand Up @@ -460,6 +461,12 @@ def fused_specdecode_generate_mlx_trim(
timing["extend_s"] += time.perf_counter() - t_extend
if any(t in eos for t in commit):
break
if stop_on_runaway:
drop = _trailing_runaway_drop(generated)
if drop > 0:
del generated[len(generated) - drop:]
stopped_on_runaway = True
break
finally:
adapter._capture_aux = False
generated = generated[:gen_tokens]
Expand All @@ -469,6 +476,7 @@ def fused_specdecode_generate_mlx_trim(
"mean_accept_len": (round(sum(accepts) / len(accepts), 3)
if accepts else 0.0),
"decode_tokens": len(generated),
"stopped_on_runaway": stopped_on_runaway,
"loop": ("mlx_trim_single_fused_probe" if single_fused
else "mlx_trim_keep_accepted_cuda_parity"),
"single_fused": bool(single_fused),
Expand All @@ -490,6 +498,7 @@ def fused_specdecode_generate_mlx(
gen_tokens: int,
block_size: int,
eos_ids: Sequence[int] = (),
stop_on_runaway: bool = True,
) -> Dict[str, Any]:
"""All-MLX fused spec decode with ONE host sync per block.

Expand Down Expand Up @@ -531,6 +540,7 @@ def fused_specdecode_generate_mlx(

generated: List[int] = []
accepts: List[int] = []
stopped_on_runaway = False
# Rollback-carry state: rejected blocks roll the WHOLE forward back
# (rollback_block — see its docstring for why trim is unsound on the
# wrapped sliding ring) and carry the stream-committed-but-not-cached
Expand Down Expand Up @@ -614,6 +624,12 @@ def fused_specdecode_generate_mlx(
timing["extend_s"] += time.perf_counter() - t_extend
if any(t in eos for t in commit):
break
if stop_on_runaway:
drop = _trailing_runaway_drop(generated)
if drop > 0:
del generated[len(generated) - drop:]
stopped_on_runaway = True
break
finally:
adapter._capture_aux = False
generated = generated[:gen_tokens]
Expand All @@ -623,6 +639,7 @@ def fused_specdecode_generate_mlx(
"mean_accept_len": (round(sum(accepts) / len(accepts), 3)
if accepts else 0.0),
"decode_tokens": len(generated),
"stopped_on_runaway": stopped_on_runaway,
"loop": "mlx_rollback_carry_v3",
"time_breakdown_s": {k: round(v, 3) for k, v in timing.items()},
}
Expand Down Expand Up @@ -655,6 +672,40 @@ def _sliding_ring_would_wrap(cache: Any, n_new: int) -> bool:
return False


def _trailing_runaway_drop(
ids: Sequence[int],
*,
max_period: int = 8,
min_reps: int = 12,
keep_reps: int = 3,
) -> int:
"""Return how many TRAILING tokens to drop if ``ids`` ends in a runaway
short-period loop, else 0.

A runaway loop is a unit of ``1..max_period`` tokens repeated ``>= min_reps``
times back-to-back at the tail (e.g. the ``**``/``.2``/``*`` markdown-marker
collapse greedy decoding falls into on code prompts). When found, we keep
``keep_reps`` instances and drop the rest, so callers can stop generation
with a clean tail instead of emitting an unbounded wall of repeats.

Deliberately CONSERVATIVE (>= 12 back-to-back repeats of a <= 8-token unit)
so legitimately repetitive text — numbered lists, ``矿工 A/B/C`` enumerations,
structured code — is never trimmed. Returns 0 when no runaway is present."""
n = len(ids)
for p in range(1, max_period + 1):
if n < p * min_reps:
continue
unit = list(ids[n - p:])
reps = 0
i = n
while i - p >= 0 and list(ids[i - p:i]) == unit:
reps += 1
i -= p
if reps >= min_reps:
return max((reps - keep_reps) * p, 0)
return 0


# --------------------------------------------------------------------------- #
# The fused spec-decode loop (control flow; MLX/torch ops via injected fns).
# --------------------------------------------------------------------------- #
Expand All @@ -672,6 +723,7 @@ def fused_specdecode_generate(
arange_fn: Callable[[int, int], Any],
cat_aux_fn: Callable[[Sequence[Any]], Any],
allow_greedy_fallback: bool = True,
stop_on_runaway: bool = True,
) -> Dict[str, Any]:
"""Run the fused engine. ``adapter`` must already be prefilled. Per block:
draft from the cached drafter context (B), verify+capture-aux incrementally
Expand Down Expand Up @@ -700,6 +752,7 @@ def fused_specdecode_generate(
generated: List[int] = []
accepts: List[int] = []
fallback_to_greedy = False
stopped_on_runaway = False
try:
while len(generated) < gen_tokens:
L = min(block_size, gen_tokens - len(generated))
Expand Down Expand Up @@ -774,6 +827,17 @@ def fused_specdecode_generate(
accepts.append(accepted)
if any(t in eos for t in commit):
break
# Greedy decoding can collapse into a runaway short-period loop (e.g.
# the **/.2/* markdown-marker wall on code prompts); the drafter then
# trivially predicts the repeats and the greedy verifier accepts them,
# so acceptance stays HIGH while the output is garbage. Stop on it
# instead of emitting an unbounded wall (keeps a short clean tail).
if stop_on_runaway:
drop = _trailing_runaway_drop(generated)
if drop > 0:
del generated[len(generated) - drop:]
stopped_on_runaway = True
break
if (allow_greedy_fallback and len(accepts) >= 2
and (sum(accepts) / len(accepts)) < 1.5):
fallback_to_greedy = True
Expand All @@ -791,6 +855,12 @@ def fused_specdecode_generate(
generated.append(tok)
if tok in eos:
break
if stop_on_runaway:
drop = _trailing_runaway_drop(generated)
if drop > 0:
del generated[len(generated) - drop:]
stopped_on_runaway = True
break
timing["fallback_greedy_s"] += time.perf_counter() - t_fb
finally:
adapter._capture_aux = False
Expand All @@ -801,5 +871,6 @@ def fused_specdecode_generate(
"mean_accept_len": (round(sum(accepts) / len(accepts), 3)
if accepts else 0.0),
"decode_tokens": len(generated),
"stopped_on_runaway": stopped_on_runaway,
"time_breakdown_s": {k: round(v, 3) for k, v in timing.items()},
}
57 changes: 57 additions & 0 deletions inference_engine/bridge/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,63 @@ def _harness_preset(
params={"max_new_tokens": ("int:max_new_tokens", "64")},
validate_reports=True, # §4 liveness gate on-device
),
Preset(
name="mlx-kakeya-codegen-degen-probe",
description="Regression probe (guard DISABLED): full f_θ fused engine "
"on the multi-turn 'explain PoW || write PoW in C' chat "
"that originally degenerated, with --fused-no-loop-guard so "
"any greedy markdown-marker collapse is observable. Pairs "
"with mlx-kakeya-codegen-guard-validate (guard ENABLED) to "
"show the guard is what keeps the answer clean. On current "
"code (post wrap-fix) both turns stay coherent.",
command_templates=(
(
"python3", "scripts/research/k3_integrated_niah_eval_mac.py",
"--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}",
"--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}",
"--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}",
"--s5-exact-full-attn", "--fused-specdecode", "--force-f-theta",
"--sink-size", "4", "--window-size", "64", "--block-size", "4",
"--max-new-tokens", "{max_new_tokens}", "--ignore-turn-stop",
"--chat", "--fused-no-loop-guard",
"--chat-scripted",
"请详细解释POW的工作原理||实现一个PoW的代码,用c语言完成",
"--output", "results/research/codegen_degen_2815_longprompt.json",
),
),
timeout_minutes=120,
params={"max_new_tokens": ("int:max_new_tokens", "900")},
validate_reports=False,
),
Preset(
name="mlx-kakeya-codegen-guard-validate",
description="Validate the runaway-loop guard end-to-end: full f_θ fused "
"engine on the multi-turn 'explain PoW || write PoW in C' "
"chat with the guard ENABLED (production default). The "
"answer must stay coherent and never collapse into a marker "
"wall — if a runaway starts, the guard stops it "
"(stopped_on_runaway) leaving a clean tail. Confirmed "
"coherent on current code; byte-identical to the guard-off "
"probe (the guard is inert on healthy output).",
command_templates=(
(
"python3", "scripts/research/k3_integrated_niah_eval_mac.py",
"--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}",
"--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}",
"--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}",
"--s5-exact-full-attn", "--fused-specdecode", "--force-f-theta",
"--sink-size", "4", "--window-size", "64", "--block-size", "4",
"--max-new-tokens", "{max_new_tokens}", "--ignore-turn-stop",
"--chat",
"--chat-scripted",
"请详细解释POW的工作原理||实现一个PoW的代码,用c语言完成",
"--output", "results/research/codegen_guard_validate_2815.json",
),
),
timeout_minutes=120,
params={"max_new_tokens": ("int:max_new_tokens", "900")},
validate_reports=False,
),
Preset(
name="mlx-kakeya-degen-probe",
description="Long-decode regression probe: full f_θ fused engine on a "
Expand Down
26 changes: 21 additions & 5 deletions scripts/research/k3_integrated_niah_eval_mac.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ def parse_args() -> argparse.Namespace:
ap.add_argument("--chat-scripted", default=None,
help="Non-interactive chat: '||'-separated user turns "
"(for Mac-bridge verification); writes a transcript.")
ap.add_argument("--chat-scripted-file", default=None,
help="Like --chat-scripted but reads the (possibly long, "
"'||'-separated) scripted prompt from a UTF-8 file. Lets "
"a long context be a committed fixture instead of a giant "
"manifest argv. Overrides --chat-scripted when set.")
ap.add_argument("--fused-no-loop-guard", action="store_true",
help="DIAGNOSTIC: disable the fused engine's runaway-loop stop "
"(default ON) so a degeneration probe can observe the full "
"collapse. Production chat keeps the guard enabled.")
ap.add_argument("--chat-native-ref", action="store_true",
help="DIAGNOSTIC opt-in: before each chat turn, also run a "
"plain NATIVE greedy AR decode of the SAME prompt for "
Expand Down Expand Up @@ -810,25 +819,28 @@ def _gen_turn(pid: List[int]) -> Dict[str, Any]:
evicted_positions=evicted,
prefill_chunk_size=args.prefill_chunk_size, full_kv=args.cuda_trim)
t0 = time.perf_counter()
_guard = not args.fused_no_loop_guard
if mlx_drafter is not None and args.cuda_trim:
res = fused_specdecode_generate_mlx_trim(
adapter, active_drafter, aux_prompt=aux_prompt,
embed_fn=embed_fn, lm_head_fn=lm_head_fn,
gen_tokens=args.max_new_tokens, block_size=args.block_size,
eos_ids=chat_eos, single_fused=args.single_fused)
eos_ids=chat_eos, single_fused=args.single_fused,
stop_on_runaway=_guard)
elif mlx_drafter is not None:
res = fused_specdecode_generate_mlx(
adapter, active_drafter, aux_prompt=aux_prompt,
embed_fn=embed_fn, lm_head_fn=lm_head_fn,
gen_tokens=args.max_new_tokens, block_size=args.block_size,
eos_ids=chat_eos)
eos_ids=chat_eos, stop_on_runaway=_guard)
else:
res = fused_specdecode_generate(
adapter, active_drafter, aux_prompt=aux_prompt,
embed_fn=embed_fn, lm_head_fn=lm_head_fn,
gen_tokens=args.max_new_tokens, block_size=args.block_size,
eos_ids=chat_eos, argmax_fn=argmax_fn, arange_fn=arange_fn,
cat_aux_fn=cat_aux_fn, allow_greedy_fallback=False)
cat_aux_fn=cat_aux_fn, allow_greedy_fallback=False,
stop_on_runaway=_guard)
res["decode_s"] = round(time.perf_counter() - t0, 3)
res["f_theta_ran"] = f_theta_ran
res["f_theta_layers"] = sorted(rk.keys()) if rk else []
Expand Down Expand Up @@ -860,8 +872,12 @@ def _gen_turn(pid: List[int]) -> Dict[str, Any]:
file=sys.stderr, flush=True)

history: List[Dict[str, str]] = []
if args.chat_scripted is not None:
turns = [t for t in args.chat_scripted.split("||") if t.strip()]
scripted = args.chat_scripted
if args.chat_scripted_file is not None:
with open(args.chat_scripted_file, encoding="utf-8") as _f:
scripted = _f.read()
if scripted is not None:
turns = [t for t in scripted.split("||") if t.strip()]
transcript = []
for u in turns:
history.append({"role": "user", "content": u})
Expand Down
62 changes: 62 additions & 0 deletions tests/backends/mlx/test_fused_specdecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,68 @@ def __init__(self, offset):
self.max_size = None


def test_trailing_runaway_drop_detects_and_trims_loops():
    """A long enough trailing loop is trimmed down to the default 3 kept reps."""
    # Period 1: the token 9 repeated 20x after a short prefix -> 20 - 3 = 17.
    single_token_tail = [1, 2, 3] + [9] * 20
    assert fsd._trailing_runaway_drop(single_token_tail) == 17

    # Period 3: the unit (7, 8, 9) repeated 12x -> (12 - 3) * 3 = 27.
    three_token_tail = [5, 6] + [7, 8, 9] * 12
    assert fsd._trailing_runaway_drop(three_token_tail) == 27


def test_trailing_runaway_drop_is_conservative():
    """Inputs that are not a runaway loop must never be trimmed."""
    untouched_inputs = (
        [9] * 11,             # fewer than min_reps (12) back-to-back repeats
        list(range(40)),      # legitimate non-repeating tail
        [1, 2] * 10 + [3],    # the period does not tile the very tail
        [],                   # empty / short input
    )
    for ids in untouched_inputs:
        assert fsd._trailing_runaway_drop(ids) == 0


def test_fused_loop_stops_on_runaway_repeat():
    """The fused loop must cut a runaway single-token stream short.

    The drafter always drafts the marker token and the verifier stub below
    always returns that same marker, so the committed stream is one token
    repeated indefinitely -- exactly the loop the guard exists to stop.
    """
    marker = 42

    class _LoopAdapter(_FakeAdapter):
        def forward_block(self, candidate):
            # The verifier "predicts" the marker for every position, forever.
            if self._capture_aux:
                self._last_aux = [torch.zeros(len(candidate), self.hidden)]
            return [marker for _ in candidate]

    looping_adapter = _LoopAdapter(prompt_len=5, first_token=marker)
    marker_drafter = _FakeDrafter(drafts=[[marker, marker, marker]] * 60)
    result = fsd.fused_specdecode_generate(
        looping_adapter, marker_drafter, gen_tokens=400, block_size=4,
        eos_ids=(), allow_greedy_fallback=False,
        **_loop_kwargs(marker_drafter))

    assert result["stopped_on_runaway"] is True
    # Stopped early with a short clean tail, nowhere near the 400 budget.
    assert len(result["tokens"]) < 40
    assert set(result["tokens"]) == {marker}


def test_fused_loop_runaway_guard_can_be_disabled():
    """With ``stop_on_runaway=False`` the same looping stream runs to budget."""
    marker = 42
    budget = 120

    class _LoopAdapter(_FakeAdapter):
        def forward_block(self, candidate):
            # The verifier stub returns the marker for every candidate position.
            if self._capture_aux:
                self._last_aux = [torch.zeros(len(candidate), self.hidden)]
            return [marker for _ in candidate]

    looping_adapter = _LoopAdapter(prompt_len=5, first_token=marker)
    marker_drafter = _FakeDrafter(drafts=[[marker, marker, marker]] * 200)
    result = fsd.fused_specdecode_generate(
        looping_adapter, marker_drafter, gen_tokens=budget, block_size=4,
        eos_ids=(), allow_greedy_fallback=False, stop_on_runaway=False,
        **_loop_kwargs(marker_drafter))

    assert result["stopped_on_runaway"] is False
    assert len(result["tokens"]) == budget  # ran to the full budget


def test_sliding_ring_would_wrap_detects_wrap():
# offset + n_new >= max_size -> the rotating ring becomes non-trimmable.
cache = [_FakeRotating(offset=1022, max_size=1024)]
Expand Down
2 changes: 2 additions & 0 deletions tests/inference_engine/bridge/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def test_allowlist_contains_exactly_the_documented_presets():
"mlx-batched-pad-decode",
"mlx-env-probe",
"mlx-kakeya-chat-smoke",
"mlx-kakeya-codegen-degen-probe",
"mlx-kakeya-codegen-guard-validate",
"mlx-kakeya-degen-probe",
"mlx-kakeya-fused-chat-ftheta",
"mlx-kakeya-fused-chat-smoke",
Expand Down
Loading