diff --git a/inference_engine/backends/mlx/fused_specdecode.py b/inference_engine/backends/mlx/fused_specdecode.py index c14cfbc..3d44d79 100644 --- a/inference_engine/backends/mlx/fused_specdecode.py +++ b/inference_engine/backends/mlx/fused_specdecode.py @@ -35,7 +35,6 @@ restored_prefill_cache, ) - # --------------------------------------------------------------------------- # # Component A: capture verifier aux-layer hidden states (no transformers # `output_hidden_states` on MLX → patch the decoder-layer __call__). @@ -374,6 +373,7 @@ def fused_specdecode_generate_mlx_trim( block_size: int, eos_ids: Sequence[int] = (), single_fused: bool = False, + stop_on_runaway: bool = True, ) -> Dict[str, Any]: """CUDA-parity fused spec decode: KEEP accepted K/V, TRIM only the rejected tail (no rollback, no carry re-forward). Requires the adapter to be @@ -399,6 +399,7 @@ def fused_specdecode_generate_mlx_trim( generated: List[int] = [] accepts: List[int] = [] block_evals: List[float] = [] + stopped_on_runaway = False ctx_len = C try: while len(generated) < gen_tokens: @@ -460,6 +461,12 @@ def fused_specdecode_generate_mlx_trim( timing["extend_s"] += time.perf_counter() - t_extend if any(t in eos for t in commit): break + if stop_on_runaway: + drop = _trailing_runaway_drop(generated) + if drop > 0: + del generated[len(generated) - drop:] + stopped_on_runaway = True + break finally: adapter._capture_aux = False generated = generated[:gen_tokens] @@ -469,6 +476,7 @@ def fused_specdecode_generate_mlx_trim( "mean_accept_len": (round(sum(accepts) / len(accepts), 3) if accepts else 0.0), "decode_tokens": len(generated), + "stopped_on_runaway": stopped_on_runaway, "loop": ("mlx_trim_single_fused_probe" if single_fused else "mlx_trim_keep_accepted_cuda_parity"), "single_fused": bool(single_fused), @@ -490,6 +498,7 @@ def fused_specdecode_generate_mlx( gen_tokens: int, block_size: int, eos_ids: Sequence[int] = (), + stop_on_runaway: bool = True, ) -> Dict[str, Any]: """All-MLX fused spec decode with ONE host sync per block. @@ -531,6 +540,7 @@ def fused_specdecode_generate_mlx( generated: List[int] = [] accepts: List[int] = [] + stopped_on_runaway = False # Rollback-carry state: rejected blocks roll the WHOLE forward back # (rollback_block — see its docstring for why trim is unsound on the # wrapped sliding ring) and carry the stream-committed-but-not-cached @@ -614,6 +624,12 @@ def fused_specdecode_generate_mlx( timing["extend_s"] += time.perf_counter() - t_extend if any(t in eos for t in commit): break + if stop_on_runaway: + drop = _trailing_runaway_drop(generated) + if drop > 0: + del generated[len(generated) - drop:] + stopped_on_runaway = True + break finally: adapter._capture_aux = False generated = generated[:gen_tokens] @@ -623,6 +639,7 @@ def fused_specdecode_generate_mlx( "mean_accept_len": (round(sum(accepts) / len(accepts), 3) if accepts else 0.0), "decode_tokens": len(generated), + "stopped_on_runaway": stopped_on_runaway, "loop": "mlx_rollback_carry_v3", "time_breakdown_s": {k: round(v, 3) for k, v in timing.items()}, } @@ -655,6 +672,40 @@ def _sliding_ring_would_wrap(cache: Any, n_new: int) -> bool: return False +def _trailing_runaway_drop( + ids: Sequence[int], + *, + max_period: int = 8, + min_reps: int = 12, + keep_reps: int = 3, +) -> int: + """Return how many TRAILING tokens to drop if ``ids`` ends in a runaway + short-period loop, else 0. + + A runaway loop is a unit of ``1..max_period`` tokens repeated ``>= min_reps`` + times back-to-back at the tail (e.g. the ``**``/``.2``/``*`` markdown-marker + collapse greedy decoding falls into on code prompts). When found, we keep + ``keep_reps`` instances and drop the rest, so callers can stop generation + with a clean tail instead of emitting an unbounded wall of repeats. + + Deliberately CONSERVATIVE (>= 12 back-to-back repeats of a <= 8-token unit) + so legitimately repetitive text — numbered lists, ``矿工 A/B/C`` enumerations, + structured code — is never trimmed. Returns 0 when no runaway is present.""" + n = len(ids) + for p in range(1, max_period + 1): + if n < p * min_reps: + continue + unit = list(ids[n - p:]) + reps = 0 + i = n + while i - p >= 0 and list(ids[i - p:i]) == unit: + reps += 1 + i -= p + if reps >= min_reps: + return max((reps - keep_reps) * p, 0) + return 0 + + # --------------------------------------------------------------------------- # # The fused spec-decode loop (control flow; MLX/torch ops via injected fns). # --------------------------------------------------------------------------- # @@ -672,6 +723,7 @@ def fused_specdecode_generate( arange_fn: Callable[[int, int], Any], cat_aux_fn: Callable[[Sequence[Any]], Any], allow_greedy_fallback: bool = True, + stop_on_runaway: bool = True, ) -> Dict[str, Any]: """Run the fused engine. ``adapter`` must already be prefilled. Per block: draft from the cached drafter context (B), verify+capture-aux incrementally @@ -700,6 +752,7 @@ def fused_specdecode_generate( generated: List[int] = [] accepts: List[int] = [] fallback_to_greedy = False + stopped_on_runaway = False try: while len(generated) < gen_tokens: L = min(block_size, gen_tokens - len(generated)) @@ -774,6 +827,17 @@ def fused_specdecode_generate( accepts.append(accepted) if any(t in eos for t in commit): break + # Greedy decoding can collapse into a runaway short-period loop (e.g. + # the **/.2/* markdown-marker wall on code prompts); the drafter then + # trivially predicts the repeats and the greedy verifier accepts them, + # so acceptance stays HIGH while the output is garbage. Stop on it + # instead of emitting an unbounded wall (keeps a short clean tail). + if stop_on_runaway: + drop = _trailing_runaway_drop(generated) + if drop > 0: + del generated[len(generated) - drop:] + stopped_on_runaway = True + break if (allow_greedy_fallback and len(accepts) >= 2 and (sum(accepts) / len(accepts)) < 1.5): fallback_to_greedy = True @@ -791,6 +855,12 @@ def fused_specdecode_generate( generated.append(tok) if tok in eos: break + if stop_on_runaway: + drop = _trailing_runaway_drop(generated) + if drop > 0: + del generated[len(generated) - drop:] + stopped_on_runaway = True + break timing["fallback_greedy_s"] += time.perf_counter() - t_fb finally: adapter._capture_aux = False @@ -801,5 +871,6 @@ def fused_specdecode_generate( "mean_accept_len": (round(sum(accepts) / len(accepts), 3) if accepts else 0.0), "decode_tokens": len(generated), + "stopped_on_runaway": stopped_on_runaway, "time_breakdown_s": {k: round(v, 3) for k, v in timing.items()}, } diff --git a/inference_engine/bridge/manifest.py b/inference_engine/bridge/manifest.py index 8b67da7..fbcdbb0 100644 --- a/inference_engine/bridge/manifest.py +++ b/inference_engine/bridge/manifest.py @@ -771,6 +771,63 @@ def _harness_preset( params={"max_new_tokens": ("int:max_new_tokens", "64")}, validate_reports=True, # §4 liveness gate on-device ), + Preset( + name="mlx-kakeya-codegen-degen-probe", + description="Regression probe (guard DISABLED): full f_θ fused engine " + "on the multi-turn 'explain PoW || write PoW in C' chat " + "that originally degenerated, with --fused-no-loop-guard so " + "any greedy markdown-marker collapse is observable. Pairs " + "with mlx-kakeya-codegen-guard-validate (guard ENABLED) to " + "show the guard is what keeps the answer clean. On current " + "code (post wrap-fix) both turns stay coherent.", + command_templates=( + ( + "python3", "scripts/research/k3_integrated_niah_eval_mac.py", + "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", + "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", + "--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}", + "--s5-exact-full-attn", "--fused-specdecode", "--force-f-theta", + "--sink-size", "4", "--window-size", "64", "--block-size", "4", + "--max-new-tokens", "{max_new_tokens}", "--ignore-turn-stop", + "--chat", "--fused-no-loop-guard", + "--chat-scripted", + "请详细解释POW的工作原理||实现一个PoW的代码,用c语言完成", + "--output", "results/research/codegen_degen_2815_longprompt.json", + ), + ), + timeout_minutes=120, + params={"max_new_tokens": ("int:max_new_tokens", "900")}, + validate_reports=False, + ), + Preset( + name="mlx-kakeya-codegen-guard-validate", + description="Validate the runaway-loop guard end-to-end: full f_θ fused " + "engine on the multi-turn 'explain PoW || write PoW in C' " + "chat with the guard ENABLED (production default). The " + "answer must stay coherent and never collapse into a marker " + "wall — if a runaway starts, the guard stops it " + "(stopped_on_runaway) leaving a clean tail. Confirmed " + "coherent on current code; byte-identical to the guard-off " + "probe (the guard is inert on healthy output).", + command_templates=( + ( + "python3", "scripts/research/k3_integrated_niah_eval_mac.py", + "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", + "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", + "--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}", + "--s5-exact-full-attn", "--fused-specdecode", "--force-f-theta", + "--sink-size", "4", "--window-size", "64", "--block-size", "4", + "--max-new-tokens", "{max_new_tokens}", "--ignore-turn-stop", + "--chat", + "--chat-scripted", + "请详细解释POW的工作原理||实现一个PoW的代码,用c语言完成", + "--output", "results/research/codegen_guard_validate_2815.json", + ), + ), + timeout_minutes=120, + params={"max_new_tokens": ("int:max_new_tokens", "900")}, + validate_reports=False, + ), Preset( name="mlx-kakeya-degen-probe", description="Long-decode regression probe: full f_θ fused engine on a " diff --git a/scripts/research/k3_integrated_niah_eval_mac.py b/scripts/research/k3_integrated_niah_eval_mac.py index ecde21d..a6fc2eb 100644 --- a/scripts/research/k3_integrated_niah_eval_mac.py +++ b/scripts/research/k3_integrated_niah_eval_mac.py @@ -180,6 +180,15 @@ def parse_args() -> argparse.Namespace: ap.add_argument("--chat-scripted", default=None, help="Non-interactive chat: '||'-separated user turns " "(for Mac-bridge verification); writes a transcript.") + ap.add_argument("--chat-scripted-file", default=None, + help="Like --chat-scripted but reads the (possibly long, " + "'||'-separated) scripted prompt from a UTF-8 file. Lets " + "a long context be a committed fixture instead of a giant " + "manifest argv. Overrides --chat-scripted when set.") + ap.add_argument("--fused-no-loop-guard", action="store_true", + help="DIAGNOSTIC: disable the fused engine's runaway-loop stop " + "(default ON) so a degeneration probe can observe the full " + "collapse. Production chat keeps the guard enabled.") ap.add_argument("--chat-native-ref", action="store_true", help="DIAGNOSTIC opt-in: before each chat turn, also run a " "plain NATIVE greedy AR decode of the SAME prompt for " @@ -810,25 +819,28 @@ def _gen_turn(pid: List[int]) -> Dict[str, Any]: evicted_positions=evicted, prefill_chunk_size=args.prefill_chunk_size, full_kv=args.cuda_trim) t0 = time.perf_counter() + _guard = not args.fused_no_loop_guard if mlx_drafter is not None and args.cuda_trim: res = fused_specdecode_generate_mlx_trim( adapter, active_drafter, aux_prompt=aux_prompt, embed_fn=embed_fn, lm_head_fn=lm_head_fn, gen_tokens=args.max_new_tokens, block_size=args.block_size, - eos_ids=chat_eos, single_fused=args.single_fused) + eos_ids=chat_eos, single_fused=args.single_fused, + stop_on_runaway=_guard) elif mlx_drafter is not None: res = fused_specdecode_generate_mlx( adapter, active_drafter, aux_prompt=aux_prompt, embed_fn=embed_fn, lm_head_fn=lm_head_fn, gen_tokens=args.max_new_tokens, block_size=args.block_size, - eos_ids=chat_eos) + eos_ids=chat_eos, stop_on_runaway=_guard) else: res = fused_specdecode_generate( adapter, active_drafter, aux_prompt=aux_prompt, embed_fn=embed_fn, lm_head_fn=lm_head_fn, gen_tokens=args.max_new_tokens, block_size=args.block_size, eos_ids=chat_eos, argmax_fn=argmax_fn, arange_fn=arange_fn, - cat_aux_fn=cat_aux_fn, allow_greedy_fallback=False) + cat_aux_fn=cat_aux_fn, allow_greedy_fallback=False, + stop_on_runaway=_guard) res["decode_s"] = round(time.perf_counter() - t0, 3) res["f_theta_ran"] = f_theta_ran res["f_theta_layers"] = sorted(rk.keys()) if rk else [] @@ -860,8 +872,12 @@ def _gen_turn(pid: List[int]) -> Dict[str, Any]: file=sys.stderr, flush=True) history: List[Dict[str, str]] = [] - if args.chat_scripted is not None: - turns = [t for t in args.chat_scripted.split("||") if t.strip()] + scripted = args.chat_scripted + if args.chat_scripted_file is not None: + with open(args.chat_scripted_file, encoding="utf-8") as _f: + scripted = _f.read() + if scripted is not None: + turns = [t for t in scripted.split("||") if t.strip()] transcript = [] for u in turns: history.append({"role": "user", "content": u}) diff --git a/tests/backends/mlx/test_fused_specdecode.py b/tests/backends/mlx/test_fused_specdecode.py index ddf099b..f9c37a4 100644 --- a/tests/backends/mlx/test_fused_specdecode.py +++ b/tests/backends/mlx/test_fused_specdecode.py @@ -170,6 +170,68 @@ def __init__(self, offset): self.max_size = None +def test_trailing_runaway_drop_detects_and_trims_loops(): + # 1-token unit repeated 20x -> drop all but keep_reps (default 3). + ids = [1, 2, 3] + [9] * 20 + drop = fsd._trailing_runaway_drop(ids) + assert drop == 17 # 20 - 3 kept + # multi-token unit (period 3) repeated 12x -> drop (12-3)*3 = 27. + ids2 = [5, 6] + [7, 8, 9] * 12 + assert fsd._trailing_runaway_drop(ids2) == 27 + + +def test_trailing_runaway_drop_is_conservative(): + # fewer than min_reps (12) back-to-back -> no trim. + assert fsd._trailing_runaway_drop([9] * 11) == 0 + # legitimate non-repeating tail -> no trim. + assert fsd._trailing_runaway_drop(list(range(40))) == 0 + # a period that does not tile the very tail -> no trim. + assert fsd._trailing_runaway_drop([1, 2] * 10 + [3]) == 0 + # empty / short -> no trim. + assert fsd._trailing_runaway_drop([]) == 0 + + +def test_fused_loop_stops_on_runaway_repeat(): + # Drafter keeps proposing the same token; the fake verifier's "+1" truth is + # defeated by making the bonus re-loop: we feed a drafter that always drafts + # the marker token and a verifier that greedily agrees, so the committed + # stream becomes a runaway single-token loop the guard must cut. + class _LoopAdapter(_FakeAdapter): + def forward_block(self, candidate): + # verifier greedily predicts the SAME marker token (42) forever. + if self._capture_aux: + L = len(candidate) + self._last_aux = [torch.zeros(L, self.hidden)] + return [42 for _ in candidate] + + adapter = _LoopAdapter(prompt_len=5, first_token=42) + drafter = _FakeDrafter(drafts=[[42, 42, 42]] * 60) + res = fsd.fused_specdecode_generate( + adapter, drafter, gen_tokens=400, block_size=4, eos_ids=(), + allow_greedy_fallback=False, **_loop_kwargs(drafter)) + assert res["stopped_on_runaway"] is True + # stopped early with a short clean tail, nowhere near the 400 budget. + assert len(res["tokens"]) < 40 + assert set(res["tokens"]) == {42} + + +def test_fused_loop_runaway_guard_can_be_disabled(): + class _LoopAdapter(_FakeAdapter): + def forward_block(self, candidate): + if self._capture_aux: + self._last_aux = [torch.zeros(len(candidate), self.hidden)] + return [42 for _ in candidate] + + adapter = _LoopAdapter(prompt_len=5, first_token=42) + drafter = _FakeDrafter(drafts=[[42, 42, 42]] * 200) + res = fsd.fused_specdecode_generate( + adapter, drafter, gen_tokens=120, block_size=4, eos_ids=(), + allow_greedy_fallback=False, stop_on_runaway=False, + **_loop_kwargs(drafter)) + assert res["stopped_on_runaway"] is False + assert len(res["tokens"]) == 120 # ran to the full budget + + def test_sliding_ring_would_wrap_detects_wrap(): # offset + n_new >= max_size -> the rotating ring becomes non-trimmable. cache = [_FakeRotating(offset=1022, max_size=1024)] diff --git a/tests/inference_engine/bridge/test_manifest.py b/tests/inference_engine/bridge/test_manifest.py index 31ce0ec..cfea538 100644 --- a/tests/inference_engine/bridge/test_manifest.py +++ b/tests/inference_engine/bridge/test_manifest.py @@ -81,6 +81,8 @@ def test_allowlist_contains_exactly_the_documented_presets(): "mlx-batched-pad-decode", "mlx-env-probe", "mlx-kakeya-chat-smoke", + "mlx-kakeya-codegen-degen-probe", + "mlx-kakeya-codegen-guard-validate", "mlx-kakeya-degen-probe", "mlx-kakeya-fused-chat-ftheta", "mlx-kakeya-fused-chat-smoke",