def _emit(on_commit: Optional[Callable[[List[int]], None]],
          generated: List[int]) -> None:
    """Best-effort delivery of the committed tokens to a streaming callback.

    Args:
        on_commit: Callback invoked with the tokens committed so far, or
            ``None`` to disable streaming (the call is then a no-op).
        generated: All tokens committed so far. The callback receives a
            shallow copy, so it cannot mutate the caller's list.

    Returns:
        None. Any ``Exception`` raised by the callback is logged (with its
        traceback) and suppressed, so token streaming can never break
        generation.
    """
    if on_commit is None:
        return
    try:
        on_commit(list(generated))
    except Exception:  # pragma: no cover - streaming must never break decode
        # Keep decode alive, but leave a trace: a silently failing callback
        # is otherwise indistinguishable from "streaming is just slow".
        import logging
        logging.getLogger(__name__).warning(
            "on_commit streaming callback raised; ignoring", exc_info=True)
@@ -587,6 +602,7 @@ def fused_specdecode_generate_mlx( commit = check[:accepted] generated += commit accepts.append(accepted) + _emit(on_commit, generated) tail_logits = next_row adapter.next_token_logits = next_row aux_rows = adapter._last_aux_mx # rows for positions base_fwd..base_fwd+k+L @@ -672,6 +688,7 @@ def fused_specdecode_generate( arange_fn: Callable[[int, int], Any], cat_aux_fn: Callable[[Sequence[Any]], Any], allow_greedy_fallback: bool = True, + on_commit: Optional[Callable[[List[int]], None]] = None, ) -> Dict[str, Any]: """Run the fused engine. ``adapter`` must already be prefilled. Per block: draft from the cached drafter context (B), verify+capture-aux incrementally @@ -772,6 +789,7 @@ def fused_specdecode_generate( commit = candidate[:accepted] + [correction] generated += commit accepts.append(accepted) + _emit(on_commit, generated) if any(t in eos for t in commit): break if (allow_greedy_fallback and len(accepts) >= 2 @@ -789,6 +807,7 @@ def fused_specdecode_generate( tok = int(argmax_fn(adapter.next_token_logits)) adapter.append_token(tok) generated.append(tok) + _emit(on_commit, generated) if tok in eos: break timing["fallback_greedy_s"] += time.perf_counter() - t_fb diff --git a/inference_engine/bridge/manifest.py b/inference_engine/bridge/manifest.py index 8b67da7..a167efb 100644 --- a/inference_engine/bridge/manifest.py +++ b/inference_engine/bridge/manifest.py @@ -749,6 +749,32 @@ def _harness_preset( }, validate_reports=True, # §4 liveness gate: asserts f_theta_ran on-device ), + Preset( + name="mlx-kakeya-chat-stream-probe", + description="Reproduce + validate the 'CLI looks frozen on a code " + "prompt' report: full f_θ chat on the user's exact prompt " + "(根据pow的机制,给出完整的c代码实现). With token streaming the log " + "shows incremental '[stream] blk=.. 
t=..s' lines as tokens " + "commit (proving the engine is generating, not deadlocked) " + "and the answer text builds up over time rather than after " + "a long silence.", + command_templates=( + ( + "python3", "scripts/research/k3_integrated_niah_eval_mac.py", + "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", + "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", + "--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}", + "--s5-exact-full-attn", "--fused-specdecode", "--force-f-theta", + "--sink-size", "4", "--window-size", "64", "--block-size", "4", + "--max-new-tokens", "{max_new_tokens}", "--ignore-turn-stop", + "--chat", "--chat-scripted", "根据pow的机制,给出完整的c代码实现", + "--output", "results/research/chat_stream_probe_2815.json", + ), + ), + timeout_minutes=90, + params={"max_new_tokens": ("int:max_new_tokens", "200")}, + validate_reports=False, + ), Preset( name="mlx-kakeya-launcher-smoke", description="Verify the one-command local launcher " diff --git a/scripts/research/k3_integrated_niah_eval_mac.py b/scripts/research/k3_integrated_niah_eval_mac.py index ecde21d..a2e1651 100644 --- a/scripts/research/k3_integrated_niah_eval_mac.py +++ b/scripts/research/k3_integrated_niah_eval_mac.py @@ -769,7 +769,7 @@ def _encode_chat(history: List[Dict[str, str]]) -> List[int]: history, add_generation_prompt=True) return list(cids.tolist() if hasattr(cids, "tolist") else cids) - def _gen_turn(pid: List[int]) -> Dict[str, Any]: + def _gen_turn(pid: List[int], on_commit=None) -> Dict[str, Any]: # Opt-in A/B control (--chat-native-ref): a plain NATIVE greedy # AR decode of the SAME prompt for --max-new-tokens. 
def _make_stream_cb(to_stdout: bool):
    """Build an ``on_commit`` callback that reports streaming progress.

    Args:
        to_stdout: When true, each call decodes the committed tokens and
            writes only the not-yet-printed text to stdout, so the answer
            builds up live instead of appearing after the whole generation.

    Returns:
        A callable ``cb(toks)`` taking every token committed so far. Each
        call writes one ``[stream] blk=.. tok=.. t=..s`` timing line to
        stderr (proves streaming / rules out a hang), whether or not
        ``to_stdout`` is set.

    Notes:
        * Tokens are decoded only when ``to_stdout`` is set; the decoded text
          has no other consumer, so re-decoding the full list every block
          would be wasted work.
        * A trailing U+FFFD is held back rather than printed. Byte-level
          tokenizers typically produce it when a multi-byte character is
          split across blocks (assumption - verify for this tokenizer).
          Printing it would advance the offset past the placeholder, and the
          completed character would never be shown. A U+FFFD that is
          genuinely the final character of the answer is therefore not
          streamed.
    """
    st = {"chars": 0, "blk": 0, "t0": time.perf_counter()}

    def _decode(toks: List[int]) -> str:
        # Some tokenizers do not accept skip_special_tokens; fall back.
        try:
            return tokenizer.decode(toks, skip_special_tokens=True)
        except TypeError:
            return tokenizer.decode(toks)

    def cb(toks: List[int]) -> None:
        st["blk"] += 1
        if to_stdout:
            # Only the stable prefix is eligible for printing.
            stable = _decode(toks).rstrip("\ufffd")
            delta = stable[st["chars"]:]
            if delta:
                sys.stdout.write(delta)
                sys.stdout.flush()
            # Never move the offset backwards, or already-printed text
            # would be emitted again on a later block.
            st["chars"] = max(st["chars"], len(stable))
        sys.stderr.write(
            f"[stream] blk={st['blk']} tok={len(toks)} "
            f"t={time.perf_counter() - st['t0']:.1f}s\n")
        sys.stderr.flush()

    return cb
+ sys.stdout.write("gemma-4> ") + sys.stdout.flush() + res = _gen_turn(_encode_chat(history), + on_commit=_make_stream_cb(to_stdout=True)) history.append({"role": "assistant", "content": res["text"]}) tps = (res["decode_tokens"] / res["decode_s"] if res["decode_s"] > 0 else 0.0) - sys.stdout.write("gemma-4> " + res["text"] + "\n") + sys.stdout.write("\n") sys.stdout.flush() print(f"[chat] blocks={res['blocks']} accept_len=" f"{res['mean_accept_len']} {round(tps,2)} tok/s " diff --git a/tests/inference_engine/bridge/test_manifest.py b/tests/inference_engine/bridge/test_manifest.py index 31ce0ec..090f189 100644 --- a/tests/inference_engine/bridge/test_manifest.py +++ b/tests/inference_engine/bridge/test_manifest.py @@ -81,6 +81,7 @@ def test_allowlist_contains_exactly_the_documented_presets(): "mlx-batched-pad-decode", "mlx-env-probe", "mlx-kakeya-chat-smoke", + "mlx-kakeya-chat-stream-probe", "mlx-kakeya-degen-probe", "mlx-kakeya-fused-chat-ftheta", "mlx-kakeya-fused-chat-smoke",