Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion inference_engine/bridge/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,8 @@ def _harness_preset(
"--s5-exact-full-attn", "--fused-specdecode", "--force-f-theta",
"--sink-size", "4", "--window-size", "64", "--block-size", "4",
"--max-new-tokens", "{max_new_tokens}", "--ignore-turn-stop",
"--chat", "--chat-scripted", "根据pow的机制,给出完整的c代码实现",
"--chat", "--chat-stream-stdout",
"--chat-scripted", "根据pow的机制,给出完整的c代码实现",
"--output", "results/research/chat_stream_probe_2815.json",
),
),
Expand Down
29 changes: 23 additions & 6 deletions scripts/research/k3_integrated_niah_eval_mac.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,11 @@ def parse_args() -> argparse.Namespace:
ap.add_argument("--chat-scripted", default=None,
help="Non-interactive chat: '||'-separated user turns "
"(for Mac-bridge verification); writes a transcript.")
ap.add_argument("--chat-stream-stdout", action="store_true",
help="In scripted chat, stream the clean answer delta to "
"stdout (as the interactive CLI does) instead of the "
"per-block [stream] timing lines — lets a non-tty bridge "
"run capture the exact live output format.")
ap.add_argument("--chat-native-ref", action="store_true",
help="DIAGNOSTIC opt-in: before each chat turn, also run a "
"plain NATIVE greedy AR decode of the SAME prompt for "
Expand Down Expand Up @@ -870,15 +875,21 @@ def cb(toks: List[int]) -> None:
except TypeError:
txt = tokenizer.decode(toks)
if to_stdout:
# Interactive: emit ONLY the clean answer delta to stdout.
# (No per-block progress line here — stderr would interleave
# with the streamed text in the terminal and mangle it.)
delta = txt[st["chars"]:]
if delta:
sys.stdout.write(delta)
sys.stdout.flush()
st["chars"] = len(txt)
sys.stderr.write(
f"[stream] blk={st['blk']} tok={len(toks)} "
f"t={time.perf_counter() - st['t0']:.1f}s\n")
sys.stderr.flush()
else:
# Non-interactive (bridge/scripted): timing-only progress to
# stderr (proves streaming / liveness in the captured log).
sys.stderr.write(
f"[stream] blk={st['blk']} tok={len(toks)} "
f"t={time.perf_counter() - st['t0']:.1f}s\n")
sys.stderr.flush()
return cb

print(f"[chat] FULL fused engine: verifier={args.verifier_path} "
Expand All @@ -893,8 +904,14 @@ def cb(toks: List[int]) -> None:
transcript = []
for u in turns:
history.append({"role": "user", "content": u})
res = _gen_turn(_encode_chat(history),
on_commit=_make_stream_cb(to_stdout=False))
if args.chat_stream_stdout:
sys.stdout.write(f"\ngemma-4 [{u[:24]}]> ")
sys.stdout.flush()
res = _gen_turn(_encode_chat(history), on_commit=_make_stream_cb(
to_stdout=args.chat_stream_stdout))
if args.chat_stream_stdout:
sys.stdout.write("\n")
sys.stdout.flush()
history.append({"role": "assistant", "content": res["text"]})
tps = (res["decode_tokens"] / res["decode_s"]
if res["decode_s"] > 0 else 0.0)
Expand Down
Loading