Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions inference_engine/backends/mlx/fused_specdecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,18 @@ def make_full_kv_prompt_cache(mlx_model: Any) -> List[Any]:
return [KVCache() for _ in range(n)]


def _emit(on_commit: Optional[Callable[[List[int]], None]],
generated: List[int]) -> None:
"""Invoke a streaming callback with the tokens committed so far, swallowing
any error so token streaming can never break generation."""
if on_commit is None:
return
try:
on_commit(list(generated))
except Exception: # pragma: no cover - streaming must never break decode
pass


def fused_specdecode_generate_mlx_trim(
adapter: "MLXRestoredIncrementalVerifier",
drafter: Any,
Expand All @@ -374,6 +386,7 @@ def fused_specdecode_generate_mlx_trim(
block_size: int,
eos_ids: Sequence[int] = (),
single_fused: bool = False,
on_commit: Optional[Callable[[List[int]], None]] = None,
) -> Dict[str, Any]:
"""CUDA-parity fused spec decode: KEEP accepted K/V, TRIM only the rejected
tail (no rollback, no carry re-forward). Requires the adapter to be
Expand Down Expand Up @@ -440,6 +453,7 @@ def fused_specdecode_generate_mlx_trim(
commit = check[:accepted]
generated += commit
accepts.append(accepted)
_emit(on_commit, generated)
adapter.next_token_logits = next_row
aux_rows = adapter._last_aux_mx
# KEEP accepted (positions base..base+accepted-1), TRIM rejected.
Expand Down Expand Up @@ -490,6 +504,7 @@ def fused_specdecode_generate_mlx(
gen_tokens: int,
block_size: int,
eos_ids: Sequence[int] = (),
on_commit: Optional[Callable[[List[int]], None]] = None,
) -> Dict[str, Any]:
"""All-MLX fused spec decode with ONE host sync per block.

Expand Down Expand Up @@ -587,6 +602,7 @@ def fused_specdecode_generate_mlx(
commit = check[:accepted]
generated += commit
accepts.append(accepted)
_emit(on_commit, generated)
tail_logits = next_row
adapter.next_token_logits = next_row
aux_rows = adapter._last_aux_mx # rows for positions base_fwd..base_fwd+k+L
Expand Down Expand Up @@ -672,6 +688,7 @@ def fused_specdecode_generate(
arange_fn: Callable[[int, int], Any],
cat_aux_fn: Callable[[Sequence[Any]], Any],
allow_greedy_fallback: bool = True,
on_commit: Optional[Callable[[List[int]], None]] = None,
) -> Dict[str, Any]:
"""Run the fused engine. ``adapter`` must already be prefilled. Per block:
draft from the cached drafter context (B), verify+capture-aux incrementally
Expand Down Expand Up @@ -772,6 +789,7 @@ def fused_specdecode_generate(
commit = candidate[:accepted] + [correction]
generated += commit
accepts.append(accepted)
_emit(on_commit, generated)
if any(t in eos for t in commit):
break
if (allow_greedy_fallback and len(accepts) >= 2
Expand All @@ -789,6 +807,7 @@ def fused_specdecode_generate(
tok = int(argmax_fn(adapter.next_token_logits))
adapter.append_token(tok)
generated.append(tok)
_emit(on_commit, generated)
if tok in eos:
break
timing["fallback_greedy_s"] += time.perf_counter() - t_fb
Expand Down
26 changes: 26 additions & 0 deletions inference_engine/bridge/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,32 @@ def _harness_preset(
},
validate_reports=True, # §4 liveness gate: asserts f_theta_ran on-device
),
Preset(
name="mlx-kakeya-chat-stream-probe",
description="Reproduce + validate the 'CLI looks frozen on a code "
"prompt' report: full f_θ chat on the user's exact prompt "
"(根据pow的机制,给出完整的c代码实现). With token streaming the log "
"shows incremental '[stream] blk=.. t=..s' lines as tokens "
"commit (proving the engine is generating, not deadlocked) "
"and the answer text builds up over time rather than after "
"a long silence.",
command_templates=(
(
"python3", "scripts/research/k3_integrated_niah_eval_mac.py",
"--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}",
"--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}",
"--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}",
"--s5-exact-full-attn", "--fused-specdecode", "--force-f-theta",
"--sink-size", "4", "--window-size", "64", "--block-size", "4",
"--max-new-tokens", "{max_new_tokens}", "--ignore-turn-stop",
"--chat", "--chat-scripted", "根据pow的机制,给出完整的c代码实现",
"--output", "results/research/chat_stream_probe_2815.json",
),
),
timeout_minutes=90,
params={"max_new_tokens": ("int:max_new_tokens", "200")},
validate_reports=False,
),
Preset(
name="mlx-kakeya-launcher-smoke",
description="Verify the one-command local launcher "
Expand Down
49 changes: 42 additions & 7 deletions scripts/research/k3_integrated_niah_eval_mac.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ def _encode_chat(history: List[Dict[str, str]]) -> List[int]:
history, add_generation_prompt=True)
return list(cids.tolist() if hasattr(cids, "tolist") else cids)

def _gen_turn(pid: List[int]) -> Dict[str, Any]:
def _gen_turn(pid: List[int], on_commit=None) -> Dict[str, Any]:
# Opt-in A/B control (--chat-native-ref): a plain NATIVE greedy
# AR decode of the SAME prompt for --max-new-tokens. Captured as
# res["native_ref_text"] so the fused answer can be compared
Expand Down Expand Up @@ -815,20 +815,22 @@ def _gen_turn(pid: List[int]) -> Dict[str, Any]:
adapter, active_drafter, aux_prompt=aux_prompt,
embed_fn=embed_fn, lm_head_fn=lm_head_fn,
gen_tokens=args.max_new_tokens, block_size=args.block_size,
eos_ids=chat_eos, single_fused=args.single_fused)
eos_ids=chat_eos, single_fused=args.single_fused,
on_commit=on_commit)
elif mlx_drafter is not None:
res = fused_specdecode_generate_mlx(
adapter, active_drafter, aux_prompt=aux_prompt,
embed_fn=embed_fn, lm_head_fn=lm_head_fn,
gen_tokens=args.max_new_tokens, block_size=args.block_size,
eos_ids=chat_eos)
eos_ids=chat_eos, on_commit=on_commit)
else:
res = fused_specdecode_generate(
adapter, active_drafter, aux_prompt=aux_prompt,
embed_fn=embed_fn, lm_head_fn=lm_head_fn,
gen_tokens=args.max_new_tokens, block_size=args.block_size,
eos_ids=chat_eos, argmax_fn=argmax_fn, arange_fn=arange_fn,
cat_aux_fn=cat_aux_fn, allow_greedy_fallback=False)
cat_aux_fn=cat_aux_fn, allow_greedy_fallback=False,
on_commit=on_commit)
res["decode_s"] = round(time.perf_counter() - t0, 3)
res["f_theta_ran"] = f_theta_ran
res["f_theta_layers"] = sorted(rk.keys()) if rk else []
Expand All @@ -853,6 +855,32 @@ def _gen_turn(pid: List[int]) -> Dict[str, Any]:
sum(int(getattr(c, "nbytes", 0)) for c in (adapter._cache or [])))
return res

def _make_stream_cb(to_stdout: bool):
"""on_commit callback: incrementally decode the committed tokens
and (interactive) print the new delta to stdout so the user sees
the answer build LIVE instead of waiting for the whole generation.
Always logs a per-block timing line to stderr (proves streaming /
rules out a hang)."""
st = {"chars": 0, "blk": 0, "t0": time.perf_counter()}

def cb(toks: List[int]) -> None:
st["blk"] += 1
try:
txt = tokenizer.decode(toks, skip_special_tokens=True)
except TypeError:
txt = tokenizer.decode(toks)
if to_stdout:
delta = txt[st["chars"]:]
if delta:
sys.stdout.write(delta)
sys.stdout.flush()
st["chars"] = len(txt)
sys.stderr.write(
f"[stream] blk={st['blk']} tok={len(toks)} "
f"t={time.perf_counter() - st['t0']:.1f}s\n")
sys.stderr.flush()
return cb

print(f"[chat] FULL fused engine: verifier={args.verifier_path} "
f"drafter={args.drafter_id} f_theta={args.f_theta_dir} "
f"S5 sink={args.sink_size} window={args.window_size} "
Expand All @@ -865,7 +893,8 @@ def _gen_turn(pid: List[int]) -> Dict[str, Any]:
transcript = []
for u in turns:
history.append({"role": "user", "content": u})
res = _gen_turn(_encode_chat(history))
res = _gen_turn(_encode_chat(history),
on_commit=_make_stream_cb(to_stdout=False))
history.append({"role": "assistant", "content": res["text"]})
tps = (res["decode_tokens"] / res["decode_s"]
if res["decode_s"] > 0 else 0.0)
Expand Down Expand Up @@ -926,11 +955,17 @@ def _gen_turn(pid: List[int]) -> Dict[str, Any]:
if not u:
break
history.append({"role": "user", "content": u})
res = _gen_turn(_encode_chat(history))
# Stream the answer LIVE so the terminal shows progress as tokens
# commit (the f_θ path is slow; without this the CLI looks frozen
# for minutes on long answers like code generation).
sys.stdout.write("gemma-4> ")
sys.stdout.flush()
res = _gen_turn(_encode_chat(history),
on_commit=_make_stream_cb(to_stdout=True))
history.append({"role": "assistant", "content": res["text"]})
tps = (res["decode_tokens"] / res["decode_s"]
if res["decode_s"] > 0 else 0.0)
sys.stdout.write("gemma-4> " + res["text"] + "\n")
sys.stdout.write("\n")
sys.stdout.flush()
print(f"[chat] blocks={res['blocks']} accept_len="
f"{res['mean_accept_len']} {round(tps,2)} tok/s "
Expand Down
1 change: 1 addition & 0 deletions tests/inference_engine/bridge/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def test_allowlist_contains_exactly_the_documented_presets():
"mlx-batched-pad-decode",
"mlx-env-probe",
"mlx-kakeya-chat-smoke",
"mlx-kakeya-chat-stream-probe",
"mlx-kakeya-degen-probe",
"mlx-kakeya-fused-chat-ftheta",
"mlx-kakeya-fused-chat-smoke",
Expand Down
Loading