diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 30907358..10ad2b74 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -95,16 +95,17 @@ jobs: tests/inference_engine/session/ \ tests/inference_engine/bench/ \ tests/inference_engine/setup/ \ + tests/inference_engine/bridge/ \ tests/sdk/python/ \ tests/training/repr_align/ \ tests/backends/mlx/test_env.py \ --junitxml=junit.xml \ -v coverage report \ - --include='inference_engine/server/auth.py,inference_engine/server/config.py,inference_engine/server/errors.py,inference_engine/server/grpc_app.py,inference_engine/server/metrics.py,inference_engine/server/schemas.py,inference_engine/server/proto_gen/**/*.py,inference_engine/memory/*,inference_engine/scheduler/config.py,inference_engine/scheduler/session.py,inference_engine/pipeline/*,inference_engine/session/store.py,inference_engine/setup/*,sdks/python/kakeya/__init__.py,sdks/python/kakeya/errors.py,training/repr_align/*' \ + --include='inference_engine/server/auth.py,inference_engine/server/config.py,inference_engine/server/errors.py,inference_engine/server/grpc_app.py,inference_engine/server/metrics.py,inference_engine/server/schemas.py,inference_engine/server/proto_gen/**/*.py,inference_engine/memory/*,inference_engine/bridge/*,inference_engine/scheduler/config.py,inference_engine/scheduler/session.py,inference_engine/pipeline/*,inference_engine/session/store.py,inference_engine/setup/*,sdks/python/kakeya/__init__.py,sdks/python/kakeya/errors.py,training/repr_align/*' \ --fail-under=100 coverage xml -o coverage.xml \ - --include='inference_engine/server/auth.py,inference_engine/server/config.py,inference_engine/server/errors.py,inference_engine/server/grpc_app.py,inference_engine/server/metrics.py,inference_engine/server/schemas.py,inference_engine/server/proto_gen/**/*.py,inference_engine/memory/*,inference_engine/scheduler/config.py,inference_engine/scheduler/session.py,inference_engine/pipeline/*,inference_engine/session/store.py,inference_engine/setup/*,sdks/python/kakeya/__init__.py,sdks/python/kakeya/errors.py,training/repr_align/*' + --include='inference_engine/server/auth.py,inference_engine/server/config.py,inference_engine/server/errors.py,inference_engine/server/grpc_app.py,inference_engine/server/metrics.py,inference_engine/server/schemas.py,inference_engine/server/proto_gen/**/*.py,inference_engine/memory/*,inference_engine/bridge/*,inference_engine/scheduler/config.py,inference_engine/scheduler/session.py,inference_engine/pipeline/*,inference_engine/session/store.py,inference_engine/setup/*,sdks/python/kakeya/__init__.py,sdks/python/kakeya/errors.py,training/repr_align/*' - name: Upload coverage artifact if: always() @@ -166,6 +167,8 @@ jobs: import kakeya.client; \ import kakeya.session; \ import kakeya.errors; \ + import inference_engine.bridge; \ + import inference_engine.bridge.manifest; \ import inference_engine.proposer; \ import inference_engine.proposer.sparse_logits; \ import inference_engine.backends.mlx.env; \ diff --git a/.github/workflows/mac-bridge.yaml b/.github/workflows/mac-bridge.yaml new file mode 100644 index 00000000..1eb13c64 --- /dev/null +++ b/.github/workflows/mac-bridge.yaml @@ -0,0 +1,142 @@ +name: Mac bridge + +# Git-bus executor for cloud-agent access to the self-hosted Apple +# Silicon node (docs/design/mac-bridge-cloud-agent-access.md §2.1). +# +# Protocol: an agent pushes a branch `mac-bridge/-` +# containing the workload tree + a manifest at .mac-bridge/request.json +# (created by scripts/mac_bridge/request_run.py). This workflow runs the +# manifest's ALLOWLISTED preset on the kakeya-mac-m4 runner and pushes +# logs + result JSONs back to the same branch, where the agent fetches +# them with plain git (and read-only `gh run list`). +# +# Security (design doc §3): +# * Command surface = the preset allowlist in +# inference_engine/bridge/manifest.py — typed, bounded params; no +# manifest string ever reaches a shell. Validation is unit-tested +# at 100% coverage on the Linux gate. +# * Trigger surface = push permission on mac-bridge/** — the same +# population that can already execute code on this runner via the +# `needs-mac-m4` PR label (integration.yaml). +# * The single Mac is serialized via the concurrency group; every +# preset carries its own timeout inside the executor and the job +# has a hard cap below. +# * K3 acceptance reports produced by a run are validated by the +# PR #109 evidence gate ON the runner; a non-conforming report +# fails the bridge run itself. + +on: + push: + branches: + # Canonical request namespace. + - "mac-bridge/**" + # Cursor cloud agents are typically constrained to an + # AgentMemory/[-suffix] branch template; this pattern lets + # them participate without violating their naming policy + # (request_run.py --branch-prefix/--branch-suffix). + - "AgentMemory/mac-bridge-*" + +concurrency: + # One Mac: queue bridge runs globally, never cancel a running one + # (results are expensive; the requester can cancel from the UI). + group: mac-bridge + cancel-in-progress: false + +permissions: + contents: write # commit logs/results back to the request branch + +jobs: + bridge: + name: run allowlisted preset on kakeya-mac-m4 + runs-on: [self-hosted, macOS, ARM64, kakeya-mac-m4] + timeout-minutes: 150 + steps: + - uses: actions/checkout@v4 + with: + # Push results back to the request branch. + persist-credentials: true + # k3-* presets load LFS-tracked checkpoints from the repo + # (e.g. results/research/f_theta_v5_s5_sliding/ + # f_theta_weights.pt). Without lfs:true the workspace holds + # pointer files and torch.load fails with the cryptic + # "Unsupported operand 118" (ASCII 'v' = the first byte of + # an LFS pointer). + lfs: true + + - name: Show request + run: | + echo "=== .mac-bridge/request.json ===" + cat .mac-bridge/request.json + + - name: Materialize LFS objects (deterministic) + # checkout@v4's lfs:true proved insufficient on a reused + # self-hosted workspace: a previous non-LFS checkout left + # pointer-content files in the worktree, the blob is unchanged + # on the new branch, so git skips re-smudging and the stale + # pointer survives (observed live: torch.load failing with + # "Unsupported operand 118" = ASCII 'v' of an LFS pointer). + # `git lfs pull` force-materializes; the guard fails fast if + # any tracked LFS file is still a pointer. + run: | + git lfs install --local + git lfs pull + bad="" + while IFS= read -r f; do + if [ -f "$f" ] && head -c 40 "$f" | grep -q "git-lfs"; then + bad="$bad $f" + fi + done < <(git lfs ls-files -n) + if [ -n "$bad" ]; then + echo "::error::LFS pointers not materialized:$bad" + exit 1 + fi + echo "all LFS objects materialized" + + - name: Run preset (allowlist-validated executor) + env: + PYTHONPATH: .:sdks/python + # Machine-local model locations come from the runner env, + # never from the manifest (docs/ops/mac-m4-runner-setup.md). + # Precedence: repo Actions variable > ~/kakeya-models/ + # (the documented stable symlink location on the runner) > + # repo-relative fallback. $HOME needs shell expansion, hence + # the export block instead of plain env defaults. + KAKEYA_MAC_VERIFIER_PATH_VAR: ${{ vars.KAKEYA_MAC_VERIFIER_PATH || '' }} + KAKEYA_MAC_DRAFTER_ID_VAR: ${{ vars.KAKEYA_MAC_DRAFTER_ID || '' }} + KAKEYA_MAC_FTHETA_DIR_VAR: ${{ vars.KAKEYA_MAC_FTHETA_DIR || '' }} + HF_HUB_OFFLINE: "1" + run: | + default_verifier="$HOME/kakeya-models/gemma-4-26B-A4B-it-mlx-4bit" + if [ ! -d "$default_verifier" ]; then + default_verifier="models/gemma-4-26B-A4B-it-mlx-4bit" + fi + export KAKEYA_MAC_VERIFIER_PATH="${KAKEYA_MAC_VERIFIER_PATH_VAR:-$default_verifier}" + export KAKEYA_MAC_DRAFTER_ID="${KAKEYA_MAC_DRAFTER_ID_VAR:-z-lab/gemma-4-26B-A4B-it-DFlash}" + export KAKEYA_MAC_FTHETA_DIR="${KAKEYA_MAC_FTHETA_DIR_VAR:-results/research/f_theta_v5_s5_sliding}" + echo "verifier=$KAKEYA_MAC_VERIFIER_PATH" + python3 scripts/mac_bridge/run_preset.py \ + --manifest .mac-bridge/request.json + + - name: Commit results back to the request branch + if: always() + run: | + git config user.name "kakeya-mac-bridge" + git config user.email "mac-bridge@users.noreply.github.com" + git add -A .mac-bridge/logs results/research 2>/dev/null || true + if git diff --cached --quiet; then + echo "no result files to commit" + else + git commit -m "mac-bridge results: ${GITHUB_REF_NAME}" + git push origin "HEAD:${GITHUB_REF_NAME}" + fi + + - name: Upload results as artifacts + if: always() + uses: actions/upload-artifact@v4 + with: + name: mac-bridge-${{ github.run_id }} + path: | + .mac-bridge/logs/ + results/research/k3_mac_bridge_*.json + if-no-files-found: warn + retention-days: 14 diff --git a/.mac-bridge/logs/k3-step1-incremental-0.log b/.mac-bridge/logs/k3-step1-incremental-0.log new file mode 100644 index 00000000..fc797770 --- /dev/null +++ b/.mac-bridge/logs/k3-step1-incremental-0.log @@ -0,0 +1,23 @@ +[mac] loading MLX verifier /Users/fluffy314/kakeya-models/gemma-4-26B-A4B-it-mlx-4bit +[mac] verifier layers=30 full_attn=[5, 11, 17, 23, 29] +[mac] loading drafter z-lab/gemma-4-26B-A4B-it-DFlash on cpu +[mac] 5 samples, prompt len min=4406 max=5810 +[mac] running restored cross-model verifier (s5, free_gen_incremental) +[mac] incr 0: T=5810 prefill=25.4s decode=2.8s -> 'BETA-1409thought\nThe use' +[mac] incr 1: T=4911 prefill=19.7s decode=2.9s -> 'DELTA-3286\nthought\nThe u' +[mac] incr 2: T=5594 prefill=19.6s decode=2.9s -> 'BETA-7912\nthought\nThe us' +[mac] incr 3: T=4406 prefill=15.8s decode=2.8s -> 'BETA-4582thought\nThe use' +[mac] incr 4: T=5505 prefill=31.7s decode=4.9s -> 'KAPPA-1434\nthought\nThe u' +[mac] restored_cross_model recall = 1.000 (5/5) +[mac] running oracle +[mac] oracle 0: T=5810 -> 'BETA-1409thought\nThe use' +[mac] oracle 1: T=4911 -> 'DELTA-3286\nthought\nThe u' +[mac] oracle 2: T=5594 -> 'BETA-7912\nthought\nThe us' +[mac] oracle 3: T=4406 -> 'BETA-4582thought\nThe use' +[mac] oracle 4: T=5505 -> 'KAPPA-1434\nthought\nThe u' +[mac] oracle recall = 1.000 +[mac] KV resident @T=5810: S5=132.92 MB (growth 20.0 KB/tok); naive-full=1308.88 MB +[mac] cross-model throughput (free_gen_incremental): 2.49 tok/s (320 tok / 128.514 s, 25.703 s/sample) + +[mac] DONE. restored_cross_model=1.000 oracle=1.0 -> results/research/k3_mac_bridge_k3_step1_incremental.json +[mac] evidence gate: PASS diff --git a/.mac-bridge/logs/summary.json b/.mac-bridge/logs/summary.json new file mode 100644 index 00000000..784388d1 --- /dev/null +++ b/.mac-bridge/logs/summary.json @@ -0,0 +1,41 @@ +{ + "preset": "k3-step1-incremental", + "params": { + "n_samples": "5", + "max_new_tokens": "64", + "block_size": "4" + }, + "nonce": "1781268308-dc400e", + "commands": [ + { + "argv": [ + "python3", + "scripts/research/k3_integrated_niah_eval_mac.py", + "--verifier-path", + "/Users/fluffy314/kakeya-models/gemma-4-26B-A4B-it-mlx-4bit", + "--drafter-id", + "z-lab/gemma-4-26B-A4B-it-DFlash", + "--f-theta-dir", + "results/research/f_theta_v5_s5_sliding", + "--s5-exact-full-attn", + "--incremental", + "--ignore-turn-stop", + "--n-samples", + "5", + "--max-new-tokens", + "64", + "--block-size", + "4", + "--prefill-chunk-size", + "512", + "--output", + "results/research/k3_mac_bridge_k3_step1_incremental.json" + ], + "exit_code": 0, + "seconds": 370.5, + "log": ".mac-bridge/logs/k3-step1-incremental-0.log" + } + ], + "evidence_gate_exit_code": 0, + "exit_code": 0 +} \ No newline at end of file diff --git a/README.md b/README.md index 82b7babd..4a563dc6 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,66 @@ the binding correctness gate. Mac M4 evidence on `main`: Raw artifacts: [`results/platform-tests/bench_session_4h_1780332893.json`](results/platform-tests/) (4-h evidence) and the v0.3.0 GA tag's smoke run committed at `6399546`. +## Kakeya Inference Engine for Mac — MLX speculative-decode port (K3 beta baseline) + +After the **CUDA** beta (PR #107: f_θ + S5 K/V-restoration verifier, **fused DFlash +spec-decode at 1.27× AR, recall 1.0 on Gemma-4-26B-A4B / H200**), the engine was +ported to the **Apple-Silicon MLX** backend. The decode throughput climbed from a +near-total collapse to **≈AR parity** through a sequence of precisely-diagnosed +fixes. This is the baseline record of that journey (all numbers are decode-only +tok/s vs the native `mlx_lm` AR oracle on the same model, measured on a Mac M4 via +the [Mac bridge](#evaluation-environment); ×AR is the ratio). + +| Stage | ×AR | Binding problem | Fix | +| --- | --- | --- | --- | +| Naïve restored decode | **~0.09×** | **O(T²) collapse** — the restored verifier did a *full-sequence* forward **per generated token** (`restored_logits`); the Mac harness called it once per token. | **Gap-A incremental decode**: prefill **once**, capture the restored K/V into the model's **native** cache, then decode with `mlx_lm.generate_step` (chunked prefill + `mx.async_eval` pipelined) — O(L)/token, never re-forward the sequence. | +| Hybrid fused spec-decode | **~0.2×** | **Cross-runtime bridge** — MLX verifier + PyTorch/MPS drafter shipped **MB/block of aux-hidden** across runtimes on the critical path; plus a benchmark **forced-over-generation** artifact (`--ignore-turn-stop`) that tanked acceptance. | Recognised the bridge as the bottleneck; moved toward an **all-MLX drafter** (single runtime, zero per-block bridge crossings). | +| All-MLX + sound rollback | **~0.5×** | **Unsound rollback** — `RotatingKVCache` is not trimmable once the sliding ring wraps (`is_trimmable → offset < max_size`), so the loop **rolled the whole block back and re-forwarded** the carried accepted tokens every partial-accept block (~2 verifier forwards/block). | **CUDA-`DynamicCache` parity**: prefill an **all-`KVCache`** layout (sliding too — byte-exact, the window mask applies regardless of cache capacity) so `trim_prompt_cache` is a sound O(1) slice; **keep accepted K/V, trim only the rejected tail**, never re-forward. | +| Block-4 CUDA-trim | **~0.7×** | **Per-block Python graph construction** (`build_s` ≈ 50 ms/block building the 26B lazy graph). | Removing the re-forward (above) closed most of it; block-4 lands at **0.68× AR**. | +| Block-8 tuned | **~1.0×** | **Block size vs the drafter's accept-len plateau.** | Tune to **block-8** (matches the all-MLX drafter's ~4.5 accept-len ceiling); long-code completions reach **~1.0–1.05× AR (parity, best samples just over)**. block-16 is *worse* — `verify(16)` cost is wasted because acceptance plateaus. | + +**Honest ceiling & what was *ruled out*.** ≈AR parity is the Mac result on the +spec-decode sweet spot (short-context, naturally-long *code/agent* generation); +**>AR meaningfully remains CUDA-favoured** (H200 1.27×) because the binding +constraint is the **26B `verify(L)` compute per block** — *not* rollback (fixed), +*not* sync count (a one-graph "single-fused" probe ran stably at ~0.16 s/block and +was ≈ equal — the b876 single-fused "143 s" pathology is **large-cache-specific**, +not fundamental), *not* drafter acceptance (a clean ~3–4.5/block on natural +workloads), *not* verifier quantization (4-bit ≥ bf16; the loop is self-consistent), +*not* context length (NIAH ≥ general), and *not* a missing alignment asset +(fc_norms fine-tuning *degraded* held-out acceptance — the base z-lab drafter is +already near its block-4 ceiling). The earlier "low acceptance / 2.13" numbers were +a **forced-over-generation benchmark artifact**, reproduced on a clean full-KV bf16 +verifier. The one genuine remaining lever is closing the **drafter accept-len gap +(~4.5 ours → ~7.7 z-lab reference)** — a port-fidelity / alignment residual. + +Recall (the architecture's primary deliverable) is **1.0** throughout, with +bounded resident KV (**S5**: ~133 MB vs ~1309 MB naïve at 5.8 k ctx, ~90 % saving; +~48 MB after affine-4). See [ADR 0012](docs/adr/0012-proposer-verifier-value-proposition.md) +(value is realised on the **memory axis** all-platform + **throughput** on CUDA) +and [ADR 0013](docs/adr/0013-distributed-inference-topology.md) (what AR +sequentiality allows for distribution). + +### Evaluation environment + +The Mac port was developed and benchmarked **remotely from a Linux cloud agent**, +since MLX runs only on Apple Silicon: + +- **Mac bridge** (`scripts/mac_bridge/`): a **git-bus** request/response plane — the + agent pushes an allowlisted-preset request branch, a **self-hosted GitHub Actions + runner (`kakeya-mac-m4`)** executes it on the Mac and pushes results back. No SSH/ + VPN — only git push. Presets + param bounds are enforced by + `inference_engine/bridge/manifest.py`; this is itself an instance of the + multi-host capability plane ([ADR 0009](docs/adr/0009-mlx-distributed-spec-decode-and-capability-exchange.md)). +- **Evidence gate** (`inference_engine/bench/k3_report_gate.py`): every Mac report is + machine-validated — rejects fused runs that didn't execute (`blocks=0`), baseline + bypasses claiming recall/speedup, self-comparison speedups, prefill-variance, and + decode-token-budget violations — so a number is admissible only if it survives the + same rules that caught the earlier artifacts. +- **GPU side** (vast.ai H200): alignment-training + acceptance-factor experiments + (`scripts/research/k3_dflash_alignment_train.py`, `k3_dflash_specdecode_eval.py`) + used to rule out the non-levers above. + ## SDKs ### Python — `sdks/python/kakeya` diff --git a/docs/design/mac-bridge-cloud-agent-access.md b/docs/design/mac-bridge-cloud-agent-access.md new file mode 100644 index 00000000..33207076 --- /dev/null +++ b/docs/design/mac-bridge-cloud-agent-access.md @@ -0,0 +1,272 @@ +# Design — Mac bridge: cloud-agent access to the self-hosted `kakeya-mac-m4` + +- **Status**: M1 implemented (git-bus transport); M2/M3 designed +- **Relates to**: ADR 0009 (multi-host plane), PR #105 (CapabilityService), + PR #109 evidence gate (`inference_engine/bench/k3_report_gate.py`), + [`docs/ops/mac-m4-runner-setup.md`](../ops/mac-m4-runner-setup.md) +- **Implementation**: [`inference_engine/bridge/`](../../inference_engine/bridge/), + [`scripts/mac_bridge/`](../../scripts/mac_bridge/), + [`.github/workflows/mac-bridge.yaml`](../../.github/workflows/mac-bridge.yaml) + +## 1. Problem + +Kakeya development now happens substantially through cloud agents running +on **Linux x86 VMs with no Metal**. Everything MLX-dependent — the MLX +verifier, `mlx.distributed`, the K3 Mac harness, the PR #109 evidence-gate +reruns — needs Apple Silicon. The project owns exactly one such machine: +the Mac mini registered as the self-hosted runner +`[self-hosted, macOS, ARM64, kakeya-mac-m4]`, sitting behind NAT with +**outbound-only** connectivity (the Actions runner long-polls GitHub). + +Constraints that shape the design: + +- **C1 — No inbound path to the Mac.** No public IP, no port forwarding. + Any transport must be initiated from the Mac side or relayed. +- **C2 — Cloud agents are ephemeral and git-native.** They reliably have: + a repo checkout, git push permission, and read-only `gh`. They do NOT + reliably have: VPN keys, SSH keys to the Mac, or workflow-dispatch + permission. +- **C3 — The Mac executes whatever lands on it.** A bridge that forwards + arbitrary shell from an internet-facing queue to a desk machine is a + remote-shell backdoor. Command surface must be an allowlist. +- **C4 — Evidence discipline.** Results coming back from the Mac must + flow through the PR #109 evidence gate, not around it. + +## 2. Architecture: three transports, one capability model + +``` +M1 (this PR) M2 (queued) M3 (queued) +┌─────────────────┐ ┌──────────────────┐ ┌──────────────────────┐ +│ git-bus │ │ tailnet SSH │ │ Kakeya fleet member │ +│ │ │ │ │ │ +│ agent ──push──► │ │ agent ──SSH──► │ │ agent ──gRPC──► │ +│ mac-bridge/* │ │ Mac (tailscaled)│ │ CapabilityService │ +│ branch+manifest│ │ interactive REPL│ │ ProposerService │ +│ Mac runner: │ │ lldb / py-spy / │ │ (ADR 0009 plane, │ +│ run preset, │ │ mlx debugging │ │ PR #105, over the │ +│ commit results │ │ │ │ M2 tailnet) │ +│ back to branch │ │ │ │ │ +└─────────────────┘ └──────────────────┘ └──────────────────────┘ + async, batch, interactive, programmatic, + zero new secrets needs TS authkey inference-native +``` + +### 2.1 M1 — git-bus (implemented) + +The only transport that satisfies C1+C2 with **zero new infrastructure**: +git is the RPC bus, the Actions runner is the executor, the branch is the +session. + +Protocol: + +1. **Request.** The agent runs `scripts/mac_bridge/request_run.py + --preset [--param k=v ...] [--ref ]`. The client: + - branches `mac-bridge/-` from the workload ref, + - overlays the bridge files if the ref predates them (workflow + + executor must exist on the pushed branch — `on: push` workflows + execute the pushed commit's definition), + - writes `.mac-bridge/request.json` (the manifest), commits, pushes. +2. **Execute.** `.github/workflows/mac-bridge.yaml` triggers on + `push: branches: ['mac-bridge/**']`, runs on `kakeya-mac-m4`, + serialized via a `mac-bridge` concurrency group (one Mac). It calls + `scripts/mac_bridge/run_preset.py --manifest .mac-bridge/request.json`, + which validates the manifest against the **preset allowlist** + (`inference_engine/bridge/manifest.py`) and executes the preset's + fixed argv list — no shell interpolation of any user-controlled + string (C3). +3. **Respond.** The runner commits `.mac-bridge/logs/` + any new + `results/research/*.json` back to the same branch and pushes; it also + uploads them as workflow artifacts. K3 acceptance reports are passed + through `scripts/validate_k3_reports.py` **on the Mac** so a + non-conforming report fails the bridge run itself (C4). +4. **Fetch.** The agent polls with read-only `gh run list/view` (or plain + `git fetch` until the result commit appears) via + `scripts/mac_bridge/fetch_results.py`. + +Latency profile: ~10 s dispatch + queue + workload runtime. Right for +test/eval/bench cycles (minutes-scale), wrong for interactive debugging — +that is M2's job, not a reason to widen M1's command surface. + +### 2.2 Preset allowlist (M1 command surface) + +| Preset | What runs on the Mac | Typical use | +| --- | --- | --- | +| `mlx-env-probe` | `backends.mlx.env.probe_environment()` + `distributed.mlx_ring.probe_ring_environment()` (when present on the ref) | "is Metal/mlx healthy, which versions" | +| `mlx-backend-tests` | `pytest tests/backends/mlx/ -q` | real-mlx truth for the fake-mlx Linux suites | +| `integration-tests` | `pytest -m integration tests/integration/ -q` | the v0.3 GA gate on demand | +| `k3-step1-incremental` | hardened Mac harness `--incremental` (n/gen/ctx bounded params) | PR #109 Step-1 decode-only evidence | +| `k3-step2-fused` | hardened Mac harness `--fused-specdecode` | PR #109 Step-2 `blocks>0` evidence | +| `k3-native-baseline` | hardened Mac harness `--native-baseline-bypass` | labelled oracle baseline | +| `k3-evidence-gate` | `scripts/validate_k3_reports.py results/research` | re-validate committed reports on-device | +| `pytest-path` | `pytest -q` with the path validated against a repo-relative allowlisted-prefix rule (`tests/`) | targeted debugging of one test file | + +Parameters are **typed and bounded** (`n_samples ≤ 50`, +`max_new_tokens ≤ 512`, `block_size ≤ 16`, paths must resolve under +`tests/`); anything else is rejected at manifest validation, before any +process starts. Machine-local facts (verifier/model paths) come from the +runner's environment (`KAKEYA_MAC_VERIFIER_PATH`, …), never from the +manifest. + +### 2.3 M2 — tailnet SSH (designed, needs one secret + one install) + +For interactive MLX debugging (lldb, py-spy, Metal captures, REPL): + +- Mac: `brew install tailscale`, join the tailnet with `--ssh` + (Tailscale SSH; respects tailnet ACLs), tag `tag:kakeya-mac`. +- Cloud agent: `TAILSCALE_AUTHKEY` (ephemeral, pre-authorized, + tag-scoped key) added in Cursor Dashboard → Cloud Agents → Secrets; + `scripts/mac_bridge/connect_tailscale.sh` brings up `tailscaled` in + userspace-networking mode and opens `ssh kakeya@kakeya-mac-m4`. +- ACL: the agent-side tag may reach `tag:kakeya-mac:22` only; the Mac + initiates nothing toward agents. Ephemeral nodes garbage-collect when + the agent VM dies. + +This is the same outbound-only trust shape as the Actions runner (C1), +with per-session ephemeral identity. It is deliberately **not** part of +M1: it requires a secret a fresh clone does not have. + +### 2.4 M3 — fleet membership (evaluation in §4) + +With the tailnet up, the Mac's Kakeya runtime serves the ADR 0009 gRPC +plane (`CapabilityService` + `ProposerService`, PR #105) and the cloud +agent joins as a fleet peer — capability gossip, placement, and remote +block proposal over the same wire contract used between Mac minis on a +desk LAN. + +## 3. Security model (M1) + +- **Command surface**: presets only; fixed argv; no manifest string ever + reaches a shell. `pytest-path` constrains to repo-relative `tests/`. +- **Trigger surface**: anyone with push permission to `mac-bridge/**` — + identical to the existing surface (any PR labelled `needs-mac-m4` + already executes its code on the runner via `integration.yaml`). The + bridge does not widen who can run code on the Mac; it widens *what can + be conveniently requested* while **narrowing** it to an allowlist. +- **Result integrity**: results are commits on the request branch — + reviewable, attributable, and evidence-gated before merge anywhere. +- **Resource protection**: `concurrency: mac-bridge` serializes the + single Mac; per-preset `timeout-minutes`; runs are cancellable from + the Actions UI. + +## 4. Evaluation — folding the bridge into Kakeya distributed inference + +The question: should "cloud agent ⇄ kakeya-mac-m4" become a first-class +part of the engine's distributed-inference feature (ADR 0009 / PR #105), +rather than repo tooling? + +### 4.1 What maps cleanly + +| Bridge concept | ADR 0009 plane concept | +| --- | --- | +| Mac runner with presets | `NodeCapability` with `CAPABILITY_ROLE_TOOL` entries (the enum slot already exists in `distributed.proto`) — e.g. `tool:mlx-eval`, `tool:integration-tests` | +| preset manifest | `ProposeBlock`-style typed request messages (one RPC per tool capability) | +| git-bus branch session | durable async job with attributable artifacts — the property worth **keeping** even after gRPC exists | +| evidence gate on results | the same gate, already shared library code | + +The capability model was designed for exactly this shape: the Mac +advertises what it can do; placement picks it; the work request is typed +and the accept/reject of its *output* happens on the consumer side. A +`remote-executor` tool role is a natural, small extension of PR #105 — +the registry, gossip, TTL, and placement code need **zero changes**; +only a new `ModelCapability(role=TOOL)` convention plus one service. + +### 4.2 What does not map: WAN data-plane spec decode + +The latency budget kills token-level speculative decoding across the +cloud↔desk boundary, and the integration should say so explicitly: + +- LAN (two Mac minis, ADR 0009's target): `ProposeBlock` RTT ~0.3–1 ms + against block compute of tens of ms → negligible overhead. ✔ +- WAN (cloud agent ⇄ home/office Mac through a relay): RTT 30–150 ms, + *per block*. A Gemma-4-26B 4-bit verifier on M4 verifies an 8-token + block in roughly 50–100 ms — the network would add 30–300 % overhead + per block, and any acceptance-rate gain is consumed by transport. + Drafts are latency-critical; **proposer and verifier must share a + LAN** (or a Thunderbolt ring, per ADR 0009 §2). ✘ +- WAN-tolerant flows: capability gossip (seconds-scale TTLs), placement, + eval/test/bench jobs, artifact return — all fine. ✔ + +So the correct integration boundary is: **WAN = control plane + tool +plane; LAN = data plane.** This is the same hybrid conclusion as ADR +0009, extended one tier outward. + +### 4.3 Recommendation + +1. **Keep M1 in-repo now** (this PR): it unblocks PR #109's required Mac + reruns and all future agent-driven MLX work, with no new secrets. +2. **M2 next**: one Tailscale authkey secret + one Mac install; gives + interactive debugging and the channel M3 needs. Low effort, high + leverage. +3. **M3 as a v0.5 roadmap item, scoped**: add a `remote-executor` TOOL + capability + a small `ToolService` to the ADR 0009 plane so fleet + nodes (including the Mac) advertise *evaluation* capabilities the + same way they advertise verifier/proposer roles. Explicitly do + **not** route spec-decode draft traffic over WAN; placement should + treat `ring_address`/RTT class as a hard constraint for data-plane + pairings (a one-line addition to `plan_spec_decode_placement`'s + candidate filter when WAN nodes appear). +4. mTLS + node identity (already queued for v0.5 GA) becomes a + prerequisite for M3 leaving the tailnet's closed world. + +## 5. One-click install & run + +### Mac mini side — one command + +```bash +# On the Mac, from the repo root. +# Existing kakeya-mac-m4 host (runner already registered): +bash scripts/mac_bridge/setup_mac.sh + +# Fresh Mac (also installs + registers the Actions runner; token from +# GitHub → Settings → Actions → Runners → New self-hosted runner): +bash scripts/mac_bridge/setup_mac.sh \ + --runner-token --repo-url https://github.com// + +# Optionally prepare M2 interactive SSH too: +bash scripts/mac_bridge/setup_mac.sh --with-tailscale +``` + +The script is idempotent and ends with a bridge self-test; a green exit +means the next `mac-bridge/**` push executes. What it covers: host +shape (arm64 + Python ≥3.12), Python deps (`scripts/setup_mac.sh`), +Actions runner install/registration with the +`[self-hosted, macOS, ARM64, kakeya-mac-m4]` labels, model-location +checks for the `k3-*` presets (with repo-variable instructions when +paths differ), HF-cache pre-warm check, and executor dry-run. + +### Cloud agent side — zero install, two commands + +The bridge client is stdlib-only: a fresh cloud agent needs **no +configuration** beyond what it already has (repo checkout, git push, +read-only `gh`). Optional: `TAILSCALE_AUTHKEY` in Cursor Dashboard → +Cloud Agents → Secrets enables M2 interactive SSH later. + +```bash +# 0. Sanity-check this environment (push rights, gh, bridge files): +PYTHONPATH=.:sdks/python python3 scripts/mac_bridge/kakeya_mac.py doctor + +# 1. Run on the Mac and wait for results: +PYTHONPATH=.:sdks/python python3 scripts/mac_bridge/kakeya_mac.py run \ + --preset mlx-env-probe --wait 600 + +# Evidence reruns for PR #109 (hardened-harness ref): +PYTHONPATH=.:sdks/python python3 scripts/mac_bridge/kakeya_mac.py run \ + --preset k3-step2-fused --ref origin/AgentMemory/v04-mlx-port-incremental-decode-2815 \ + --param n_samples=5 --param max_new_tokens=64 --param block_size=4 --wait 7200 + +# 2. Check any request later: +PYTHONPATH=.:sdks/python python3 scripts/mac_bridge/kakeya_mac.py status \ + --branch --wait 0 +``` + +`kakeya_mac.py run` auto-detects cloud-agent branch policy: on an +`AgentMemory/-` checkout it creates the request as +`AgentMemory/mac-bridge---` (the workflow accepts +both namespaces), so agents never leave their allowed branch template. +After pushing, the client returns the worktree to the original branch. + +Lower-level pieces (`request_run.py`, `fetch_results.py`, +`run_preset.py`) remain directly usable; machine-local configuration on +the runner lives in env / repo Actions variables +(`KAKEYA_MAC_VERIFIER_PATH`, `KAKEYA_MAC_DRAFTER_ID`, +`KAKEYA_MAC_FTHETA_DIR` — see `docs/ops/mac-m4-runner-setup.md`). diff --git a/docs/mlx-port-lessons.md b/docs/mlx-port-lessons.md index 5a977125..5590156c 100644 --- a/docs/mlx-port-lessons.md +++ b/docs/mlx-port-lessons.md @@ -50,17 +50,57 @@ speed** (on CUDA: 1.3–2.8 tok/s re-forward → ~21 tok/s incremental = AR). ## MLX port plan (ordered; each gates the next) -1. **Incremental decode (kills the collapse).** Add an MLX analog of - `CrossModelRestoredSinkWindowVerifier(incremental=True)`: prefill → capture - restored K/V into `SinkWindowKVCache` (full-attn = own/exact; sliding = f_θ or - window-masked) → decode via `generate_step(prompt_cache=…)`. **Gate: decode - tok/s ≈ native mlx_lm AR; recall 1.0** (carried by S5). +1. **Incremental decode (kills the collapse). [IMPLEMENTED — needs Mac validation]** + `backends/mlx/cross_model_dlm_verifier.py`: `restored_prefill_cache` (prefill + once with injection **into the model's native hybrid cache** → full-attn/global + layers store exact own K/V, sliding store f_θ-restored + window-bounded) + + `restored_incremental_generate` (decode via `mlx_lm.generate_step` over that + cache, O(L)/token, async-pipelined). Wired into the Mac harness via + `--incremental`: + ```bash + PYTHONPATH=.:sdks/python python scripts/research/k3_integrated_niah_eval_mac.py \ + --verifier-path models/gemma-4-26B-A4B-it-mlx-4bit \ + --drafter-id z-lab/gemma-4-26B-A4B-it-DFlash \ + --f-theta-dir results/research/f_theta_v5_s5_sliding \ + --s5-exact-full-attn --incremental --n-samples 5 --max-new-tokens 32 + ``` + **Gate: decode tok/s ≫ the per-token re-forward (toward native mlx_lm AR); + recall == oracle (1.0)** (carried by S5). Mechanism mirrors CUDA Gap-A: the + existing MLX dispatch already calls `cache.update_and_fetch`, so prefill *with* + a cache populates it; decode then runs native incremental attention. 2. **Drop the extra build forward.** Capture full-attn own K/V at prefill; do not re-run a clean verifier forward per request beyond prefill. **Gate: - `build_restoration` from ~12s → ~prefill cost.** + `build_restoration` from ~12s → ~prefill cost.** *(Still pending: the Mac + harness `build_restoration` keeps the clean capture forward; the fused path + does add one clean aux-capture forward at prefill — fold these together when + optimizing.)* 3. **Gap-B drafter embed fix** (no `×sqrt`) on the MLX/Bridge drafting path. - **Gate: acceptance toward reference on code prompts.** -4. **Fused spec-decode** (A+B+C incremental caches). **Gate: tok/s > AR.** + **[IMPLEMENTED]** `fused_specdecode.make_bridge_embed_lm_head` builds the + drafting `embed_fn` as a **plain shared-embedding lookup (no `×sqrt(hidden)`)**; + `lm_head_fn` = tied-embed + `final_logit_softcapping`. +4. **Fused spec-decode** (A+B+C incremental caches). **[IMPLEMENTED — needs Mac + validation]** `inference_engine/backends/mlx/fused_specdecode.py`: + - **A** `capture_aux_hidden` + `MLXRestoredIncrementalVerifier.forward_block` + (patch the Gemma-4 `DecoderLayer.__call__` to record aux-layer outputs — + there is no `output_hidden_states` on MLX) capture the verify forward's aux + hidden, bridged to torch. + - **B** reuses the PyTorch drafter's `make_context_kv` / `extend_context_kv` / + `draft_block_cached` (drafter context K/V cache). + - **C** `MLXRestoredIncrementalVerifier` (prefill = Gap-A restored cache; + `commit_or_truncate` rolls back rejected tokens via **`mlx_lm`'s native + `trim_prompt_cache`** — the same primitive mlx_lm's own spec-decode uses). + - `fused_specdecode_generate` is the per-block O(L) accept/reject loop. + Wired into the Mac harness via `--fused-specdecode --block-size N`: + ```bash + PYTHONPATH=.:sdks/python python scripts/research/k3_integrated_niah_eval_mac.py \ + --verifier-path models/gemma-4-26B-A4B-it-mlx-4bit \ + --drafter-id z-lab/gemma-4-26B-A4B-it-DFlash \ + --f-theta-dir results/research/f_theta_v5_s5_sliding \ + --s5-exact-full-attn --fused-specdecode --block-size 4 \ + --n-samples 5 --max-new-tokens 32 + ``` + **Gate: tok/s > AR; recall == oracle (1.0).** Reference (#107 H200): fused + 1.27× AR, recall 1.0. ## Validation gates (match #107 evidence) @@ -72,6 +112,90 @@ speed** (on CUDA: 1.3–2.8 tok/s re-forward → ~21 tok/s incremental = AR). `results/research/k3_e2e_gpu_bench_incremental.json`, `k3_specdecode_fused_stable.json`.) +## Step-2 rescue status (2026-06-12, all-MLX drafter) + +The hybrid fused engine's 0.028× was the per-block mx↔torch bridge + +float32 CPU-torch drafter. The all-MLX drafter +(`inference_engine/backends/mlx/dflash_drafter.py`) eliminates both: + +- **Parity** (bridge presets `k3-drafter-parity[-fp32]`): fp32-vs-fp32 + = **100 %** token match (96/96) — the port is numerically faithful; + bf16 shipping dtype = 94.8 % (near-tie argmax flips, + correctness-contained by the verifier). +- **Fused evidence** (`k3-step2-fused-allmlx`, n5/gen64/ctx280, + gate-clean): decode-only **11.0 tok/s = 0.476× AR** at block 4 + (block 8: 0.40×) — a **17× improvement** over the hybrid path's + 0.635 tok/s, recall 5/5, accept_len 1.9–3.2. +- **Remaining gap to >AR**: Metal AR decode is 43 ms/token + (`generate_step`, async-pipelined); the fused loop pays ~6 python + sync points per block (~300 ms/block for ~2.5 accepted tokens). + Next levers, in order: lazy/async block evaluation (single + `mx.async_eval` per block), fusing draft+verify into one graph, + trimming the correction-token `append` forward. Until then Step 1 + remains the shipping Mac path. + +## Levers ①②③ implemented (2026-06-12) — and a correctness bug they exposed + +`fused_specdecode_generate_mlx` (v3, `mlx_rollback_carry_v3`) lands all +three levers: lazy draft ids feeding the verify forward (②, two-phase +eval after a fully fused drafter+26B graph hit Metal command-buffer +pathology: 143 s block evals), in-graph cumprod acceptance + lazy +next-row gather with ~2 host syncs/block (①), and the carried +bonus/correction with **no** append forward (③). + +**The big find**: live block-4 runs diverged from the greedy stream +(eos at token 22 vs Step-1's 64) while block-1 was byte-clean → +isolated to the rejection path → **`trim_prompt_cache` is unsound on +Gemma-4's hybrid cache once the sliding RotatingKVCache has wrapped** +(seq >> 512): rejected draft K/V linger in the ring. This +retroactively invalidates the acceptance/throughput numbers of every +earlier trim-based fused run (the hybrid iterC run's 23-token sample, +the eager all-MLX run's silent post-answer divergence). Fix: O(1) +reference snapshot before each verify forward; on partial acceptance +roll the whole forward back and carry the committed tokens into the +next candidate (guaranteed re-accept; K/V + aux recomputed correctly). + +**Corrected picture (gate-clean, 64/64 tokens, recall 5/5)**: + +| mode | decode-only | note | +|---|---|---| +| Step-1 incremental (greedy) | **22.2 tok/s ≈ 1.0× AR** | shipping path | +| fused block 1 (carried greedy, levers ①③) | 17.5 tok/s | loop overhead ≈ 2 syncs/block | +| fused block 4 (v3, all levers) | 5.8 tok/s = 0.26× | TRUE accept ≈ 2.0/block | + +With the corrected (uncorrupted) acceptance ≈ 1.8–2.3 committed/block +at block 4, the fused ceiling is `2.1×43ms / (verify(4)=120ms + +draft≈20ms) ≈ 0.6–0.7×` — **engineering levers cannot reach AR parity; +the binding constraint is drafter acceptance** (true per-draft accept +~30–40 % vs the ~75 % parity would need). Next investment, if Step 2 +is pursued: DFlash↔Gemma-4 alignment fine-tuning, re-measured under +the rollback-correct loop. + +## KV-quant shoot-out (2026-06-12): affine wins, KL MLX port NOT justified + +`k3-kv-quant-eval` (ctx280, n=5, real recall per arm, identity control +clean, oracle 1.0): + +| arm | bits/value | full-attn rel_mse | recall | +|---|---|---|---| +| identity | 16.0 | 0 | 5/5 | +| **affine8** (QuantizedKVCache format) | 8.5 | 0.000056 | 5/5 | +| **affine4** | 4.5 | 0.014438 | **5/5** | +| KL-D4 (q38) | 6.31 | 0.000753 | 5/5 | +| KL-E8 (q38) | 6.44 | 0.000499 | 5/5 | + +- **affine4 already passes recall with ~25× rel_mse margin** vs the + 0.36 threshold → the S5 linear term compresses 20 → 5.6 KB/token + (S5 resident @5.8k: 132.9 → ~48 MB) with the native, kernel-fused + `QuantizedKVCache` format. Adopt this; throughput expected neutral + or better (bandwidth-bound decode). +- KL's rate-distortion is genuinely better (~2× lower distortion at + interpolated equal rate) but it cannot reach affine4's rate with the + current codec settings, and nothing binds at the fidelity affine4 + already delivers. **MLX port shelved**; revisit only if a future + requirement needs <4.5 bits/value or <1e-3 rel_mse at ≤4.5 bits + (e.g. 128k+ contexts × many sessions). + ## Do-not-repeat (anti-patterns) - ❌ Re-forwarding the full sequence per generated token (the current collapse). diff --git a/docs/ops/mac-m4-runner-setup.md b/docs/ops/mac-m4-runner-setup.md index 7c106023..2cb233c2 100644 --- a/docs/ops/mac-m4-runner-setup.md +++ b/docs/ops/mac-m4-runner-setup.md @@ -135,3 +135,34 @@ Common causes: - macOS auto-update rebooted the host; service didn't auto-start (rare with `launchd` but possible). - HF cache was purged; the verify step fails. Re-warm. - Disk full from accumulated pip downloads; clear cache. + +## Mac bridge (cloud-agent access) + +The same runner also serves the **Mac bridge** +(`.github/workflows/mac-bridge.yaml`): pushes to `mac-bridge/**` +branches execute an allowlisted preset (see +`inference_engine/bridge/manifest.py`) and commit logs/results back to +the request branch. Full protocol + security model: +`docs/design/mac-bridge-cloud-agent-access.md`. + +**One-click setup**: on the Mac, from the repo root — +`bash scripts/mac_bridge/setup_mac.sh` (add +`--runner-token --repo-url ` on a fresh machine to install +and register the Actions runner too). The script is idempotent and ends +with a bridge self-test. + +Operator setup beyond the standard runner install: + +1. **Model locations** (used by the `k3-*` harness presets) are read + from the environment / repo Actions variables, never from the + request manifest. Set repo variables (Settings → Secrets and + variables → Actions → Variables) when the on-disk layout differs + from the defaults: + - `KAKEYA_MAC_VERIFIER_PATH` — MLX 4-bit Gemma-4 verifier directory + - `KAKEYA_MAC_DRAFTER_ID` — DFlash drafter HF id or local path + - `KAKEYA_MAC_FTHETA_DIR` — trained f_θ checkpoint directory +2. Bridge runs are serialized (`concurrency: mac-bridge`) and capped at + 150 minutes; cancel stuck runs from the Actions UI. +3. K3 acceptance reports produced by bridge runs are validated by the + evidence gate on this machine; a non-conforming report fails the + bridge run (exit ≠ 0) by design. diff --git a/docs/pr109-mac-ctx280-validation.md b/docs/pr109-mac-ctx280-validation.md new file mode 100644 index 00000000..fa8da879 --- /dev/null +++ b/docs/pr109-mac-ctx280-validation.md @@ -0,0 +1,47 @@ +# PR109 Mac ctx280 Validation + +This note records the review-driven rerun for PR #109 after fixing the +measurement issues called out in review. + +## Review Corrections + +- Fair timing: cross and oracle now report the same `e2e_prefill_plus_decode` + scope, plus per-sample `prefill_s`, `decode_s`, and `e2e_s`. +- Chunked prefill: MLX prompt prefill now uses `--prefill-chunk-size` to avoid + the long-context one-shot forward path that can OOM. +- Adaptive native path: Step 2 adaptive S5 native skips `build_restoration`, + f_theta restoration, and aux capture. +- Gemma4 stop tokens: `` is treated as a generation stop token alongside + ``. +- Gate scale: validation was rerun with `n=5`, `max_new_tokens=32`, and + haystack lines `238..322`, producing prompt lengths `4406..5810`. + +## Command + +```bash +PYTHONPATH=.:sdks/python python scripts/research/k3_integrated_niah_eval_mac.py \ + --verifier-path /Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/models/gemma-4-26B-A4B-it-mlx-4bit \ + --drafter-id z-lab/gemma-4-26B-A4B-it-DFlash \ + --f-theta-dir results/research/f_theta_v5_s5_sliding \ + --s5-exact-full-attn --fused-specdecode --block-size 4 \ + --n-samples 5 --haystack-min-lines 238 --haystack-max-lines 322 \ + --max-new-tokens 32 --prefill-chunk-size 512 --decode-warmup-tokens 1 \ + --output results/research/k3_mlx_fused_fair_ctx280_n5_gen32_20260612_105807.json +``` + +## Result + +- Recall: cross `5/5 = 1.0`, oracle `5/5 = 1.0`, delta `0pp`. +- Prompt lengths: `4406..5810` tokens. +- Timing scope: `e2e_prefill_plus_decode` for both cross and oracle. +- Cross Step 2 throughput: `0.2217 tok/s` (`39 tok / 175.893s`). +- Oracle AR throughput: `0.0858 tok/s` (`39 tok / 454.484s`). +- Speedup vs oracle AR: `2.584x`. +- KV memory: S5 `132.92 MB`, naive full KV `1308.88 MB`, savings `89.8%`. + +## Interpretation + +This validation supports Step 2 adaptive S5 native under the corrected e2e +measurement scope at ctx280 scale on the tested Mac setup. It does not claim +that Step 1 incremental is fixed; earlier evidence still shows Step 1 remains +slow and should be treated as a separate optimization target. diff --git a/inference_engine/backends/mlx/cross_model_dlm_verifier.py b/inference_engine/backends/mlx/cross_model_dlm_verifier.py index 29a27b27..51b42ee8 100644 --- a/inference_engine/backends/mlx/cross_model_dlm_verifier.py +++ b/inference_engine/backends/mlx/cross_model_dlm_verifier.py @@ -360,3 +360,160 @@ def restored_logits( logits = mlx_model(ids) # full Model.__call__ → tied embed + softcap mx.eval(logits) return logits[0] if return_all else logits[0, -1] + + +# --------------------------------------------------------------------------- +# Incremental decode (MLX port of CUDA Gap-A) — kills the per-token re-forward +# throughput collapse. See docs/mlx-port-lessons.md. +# --------------------------------------------------------------------------- + + +def restored_prefill_cache( + mlx_model: Any, + input_ids: Sequence[int], + *, + restored_k_per_layer: Dict[int, Any], + restored_v_per_layer: Dict[int, Any], + evicted_positions: Sequence[int], + prefill_chunk_size: int = 0, + cache_factory: Optional[Callable[[Any], Any]] = None, +): + """Prefill ONCE with restoration, capturing the restored K/V into a + persistent mlx_lm prompt cache; return ``(cache, last_logits)``. + + Same injection as :func:`restored_logits`, but run **with a cache** so the + patched attention's ``cache.update_and_fetch`` stores the post-injection + K/V (full-attention/S5 layers → exact own K/V; sliding → f_θ-restored, + window-bounded by the model's native RotatingKVCache). After this the + verifier can decode new tokens incrementally over the cache — O(L)/step — + instead of re-forwarding the whole sequence each token. + + Returns the model's native hybrid cache (full `KVCache` for global layers, + `RotatingKVCache(sliding_window)` for sliding layers) populated to the + prompt, plus the last-row logits (``mx [V]``) predicting the first token. + """ + import mlx.core as mx # type: ignore + from mlx_lm.models.cache import make_prompt_cache # type: ignore + + text_model = resolve_mlx_text_model(mlx_model) + T = len(list(input_ids)) + evicted = set(int(p) for p in evicted_positions if 0 <= int(p) < T) + # cache_factory lets the caller swap the model's native hybrid cache for an + # all-`KVCache` layout (full store for sliding layers too) so that the + # spec-decode accept/reject rollback can use mlx_lm's native, SOUND + # `trim_prompt_cache` (keep accepted K/V, drop only rejected) instead of the + # full re-forward carry — `RotatingKVCache` is not trimmable once wrapped. + # Sliding attention stays byte-exact: the window mask is applied regardless + # of cache capacity. (Costs O(T) sliding KV during decode; fine for the + # short-context code/agent workloads this targets.) + cache = (cache_factory or make_prompt_cache)(mlx_model) + + def _slice_restored(a, start: int, end: int): + if a is None: + return None + try: + return a[:, start:end, :, :] + except Exception: + # Linux fake tests use sentinel objects rather than tensors. + return a + + def _clear(touched): + for obj in touched: + for name in ( + "_kakeya_inject", + "kakeya_evicted_mask", + "kakeya_restored_pre_keys", + "kakeya_restored_pre_values", + ): + if hasattr(obj, name): + delattr(obj, name) + + def _attach_chunk(start: int, end: int): + evicted_mask = mx.array([p in evicted for p in range(start, end)]) + touched = [] + needs_attention_patch = False + for idx, layer in enumerate(text_model.layers): + attn = layer.self_attn + if idx >= len(cache) or not bool(getattr(attn, "has_kv", True)): + continue # sharers inherit K/V via shared_kv + rk = restored_k_per_layer.get(idx) + rv = restored_v_per_layer.get(idx) + if rk is None: + continue + c = cache[idx] + try: + c.kakeya_evicted_mask = evicted_mask + c.kakeya_restored_pre_keys = _slice_restored(rk, start, end) + c.kakeya_restored_pre_values = _slice_restored(rv, start, end) + touched.append(c) + except Exception: + attn._kakeya_inject = { + "mode": "inject", + "evicted_mask": evicted_mask, + "restored_k": _slice_restored(rk, start, end), + "restored_v": _slice_restored(rv, start, end), + } + touched.append(attn) + needs_attention_patch = True + return touched, needs_attention_patch + + ids_list = list(input_ids) + chunk = int(prefill_chunk_size or 0) + if chunk <= 0 or T <= chunk: + chunks = [(0, T)] + else: + chunks = [(s, min(s + chunk, T)) for s in range(0, T, chunk)] + + logits = None + for start, end in chunks: + touched, needs_attention_patch = _attach_chunk(start, end) + try: + ids = mx.array([ids_list[start:end]]) + if needs_attention_patch: + with _patched_attention_class(text_model): + logits = mlx_model(ids, cache=cache) + mx.eval(logits) + else: + logits = mlx_model(ids, cache=cache) + mx.eval(logits) + finally: + _clear(touched) + if logits is None: + ids = mx.array([ids_list]) + logits = mlx_model(ids, cache=cache) + mx.eval(logits) + # Subsequent decode steps run native incremental attention over this cache. + return cache, logits[0, -1] + + +def restored_incremental_generate( + mlx_model: Any, + cache: Any, + first_logits: Any, + *, + max_tokens: int, + eos_ids: Sequence[int] = (), +) -> List[int]: + """Greedy-decode up to ``max_tokens`` tokens over a restored prefill cache + using mlx_lm's native ``generate_step`` (chunked + async-pipelined) — the + throughput-critical incremental loop. Recall is carried by the cache's + full-attention (S5) layers. + """ + import mlx.core as mx # type: ignore + from mlx_lm.generate import generate_step # type: ignore + + eos = set(int(t) for t in eos_ids) + nxt = int(mx.argmax(first_logits).item()) + out: List[int] = [nxt] + if nxt in eos or max_tokens <= 1: + return out + # generate_step with a 1-token prompt + prefilled cache skips re-prefill + # (its chunked-prefill loop needs >1 prompt token) and decodes incrementally. + for tok, _ in generate_step( + mx.array([nxt]), mlx_model, prompt_cache=cache, max_tokens=max_tokens - 1, + ): + t = int(tok) + out.append(t) + if t in eos: + break + return out diff --git a/inference_engine/backends/mlx/dflash_drafter.py b/inference_engine/backends/mlx/dflash_drafter.py new file mode 100644 index 00000000..12622522 --- /dev/null +++ b/inference_engine/backends/mlx/dflash_drafter.py @@ -0,0 +1,357 @@ +"""All-MLX DFlash drafter — Step-2 rescue (eliminate the per-block bridge). + +The gate-clean iterC evidence (PR #109) showed the hybrid fused engine is +correct (recall 5/5 @ ctx280, accept_len 2.1–2.9/4) but 0.028× decode-only +vs native AR: each 4-token block paid 4+ MLX↔numpy↔torch bridge crossings +plus a float32 CPU-torch drafter forward (~2.7–8.4 s/block). This module +is the fix: the same DFlash drafter, native MLX, sharing the verifier's +Metal stream — zero bridge crossings per block. + +Fidelity contract: a 1:1 port of ``inference_engine/v04/dflash_drafter.py`` +(``DFlashDrafter``) — same config parsing (``DFlashConfig`` is reused +directly), same weight names (loaded straight from the checkpoint +``model.safetensors`` via ``mx.load``), same math: + +* Qwen3 blocks: q/k/v/o_proj (no bias), q_norm/k_norm RMSNorm on head_dim, + pre/post-attention RMSNorm, SiLU-gated MLP; +* explicit float32 RoPE tables (cos/sin) with the rotate-half convention, + applied at arbitrary positions (context positions and query positions + are different ranges); +* non-causal attention over [context ++ query] with GQA via + ``mx.fast.scaled_dot_product_attention`` (handles n_q != n_kv natively, + no repeat_interleave materialisation); +* ``fc`` aux fusion → ``hidden_norm`` once over context → per-layer + context K/V (k_norm + RoPE), prefill-built and extended per block + (components B of the fused engine); +* drafts = argmax over mask-position logits with the mask sentinel + excluded. + +Parity with the torch reference is validated ON DEVICE by +``scripts/research/k3_mlx_drafter_parity.py`` (bridge preset +``k3-drafter-parity``) before any throughput claim — same evidence +discipline as everything else in PR #109. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Callable, List, Sequence, Tuple + +from inference_engine.v04.dflash_drafter import DFlashConfig + + +def _mx(): + import mlx.core as mx # type: ignore + + return mx + + +def _nn(): + import mlx.nn as nn # type: ignore + + return nn + + +def _rope_cos_sin(positions: Any, head_dim: int, theta: float): + """Float32 rotary tables for arbitrary ``positions`` ``[T]`` → + ``(cos, sin)`` each ``[T, head_dim]``. Mirrors the torch reference + (full-precision tables, rotate-half pairing).""" + mx = _mx() + inv_freq = 1.0 / ( + theta ** (mx.arange(0, head_dim, 2, dtype=mx.float32) / head_dim) + ) + freqs = positions.astype(mx.float32)[:, None] * inv_freq[None, :] + emb = mx.concatenate([freqs, freqs], axis=-1) # [T, head_dim] + return mx.cos(emb), mx.sin(emb) + + +def _apply_rope(x: Any, cos: Any, sin: Any) -> Any: + """x: [B, H, T, D] (any dtype); cos/sin: [T, D] float32.""" + mx = _mx() + half = x.shape[-1] // 2 + rotated = mx.concatenate([-x[..., half:], x[..., :half]], axis=-1) + out = x.astype(mx.float32) * cos[None, None] + rotated.astype(mx.float32) * sin[None, None] + return out.astype(x.dtype) + + +class _Attention: + """DFlash draft attention, MLX-native (see torch ``_DFlashAttention``).""" + + def __init__(self, cfg: DFlashConfig) -> None: + nn = _nn() + self.cfg = cfg + self.nh = cfg.num_attention_heads + self.nkv = cfg.num_key_value_heads + self.hd = cfg.head_dim + self.theta = cfg.rope_theta + self.scale = self.hd ** -0.5 + self.q_proj = nn.Linear(cfg.hidden_size, self.nh * self.hd, bias=False) + self.k_proj = nn.Linear(cfg.hidden_size, self.nkv * self.hd, bias=False) + self.v_proj = nn.Linear(cfg.hidden_size, self.nkv * self.hd, bias=False) + self.o_proj = nn.Linear(self.nh * self.hd, cfg.hidden_size, bias=False) + self.q_norm = nn.RMSNorm(self.hd, eps=cfg.rms_norm_eps) + self.k_norm = nn.RMSNorm(self.hd, eps=cfg.rms_norm_eps) + + def project_context_kv(self, ctx_normed: Any, ctx_positions: Any): + """(hidden_norm-ed) context hidden → this layer's (k, v), k_norm + + RoPE applied. Returns each ``[B, nkv, C, hd]``.""" + mx = _mx() + B, C, _ = ctx_normed.shape + k = self.k_proj(ctx_normed).reshape(B, C, self.nkv, self.hd) + v = self.v_proj(ctx_normed).reshape(B, C, self.nkv, self.hd) + k = self.k_norm(k).transpose(0, 2, 1, 3) + v = v.transpose(0, 2, 1, 3) + cos, sin = _rope_cos_sin(ctx_positions, self.hd, self.theta) + k = _apply_rope(k, cos, sin) + return k, v + + def __call__(self, h: Any, query_positions: Any, ctx_k: Any, ctx_v: Any) -> Any: + mx = _mx() + B, T, _ = h.shape + q = self.q_proj(h).reshape(B, T, self.nh, self.hd) + k = self.k_proj(h).reshape(B, T, self.nkv, self.hd) + v = self.v_proj(h).reshape(B, T, self.nkv, self.hd) + q = self.q_norm(q).transpose(0, 2, 1, 3) + k = self.k_norm(k).transpose(0, 2, 1, 3) + v = v.transpose(0, 2, 1, 3) + cos, sin = _rope_cos_sin(query_positions, self.hd, self.theta) + q = _apply_rope(q, cos, sin) + k = _apply_rope(k, cos, sin) + if ctx_k is not None: + k = mx.concatenate([ctx_k.astype(k.dtype), k], axis=2) + v = mx.concatenate([ctx_v.astype(v.dtype), v], axis=2) + # Non-causal; GQA handled natively (no repeat_interleave). + out = mx.fast.scaled_dot_product_attention( + q, k, v, scale=self.scale, mask=None, + ) + out = out.transpose(0, 2, 1, 3).reshape(B, T, self.nh * self.hd) + return self.o_proj(out) + + +class _MLP: + def __init__(self, cfg: DFlashConfig) -> None: + nn = _nn() + self.gate_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False) + self.up_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False) + self.down_proj = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=False) + + def __call__(self, x: Any) -> Any: + nn = _nn() + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class _Layer: + def __init__(self, cfg: DFlashConfig) -> None: + nn = _nn() + self.input_layernorm = nn.RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps) + self.self_attn = _Attention(cfg) + self.post_attention_layernorm = nn.RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps) + self.mlp = _MLP(cfg) + + def __call__(self, h, query_positions, ctx_k, ctx_v): + h = h + self.self_attn( + self.input_layernorm(h), query_positions, ctx_k, ctx_v, + ) + h = h + self.mlp(self.post_attention_layernorm(h)) + return h + + +class MLXDFlashDrafter: + """MLX-native DFlash drafter with the SAME method surface as the torch + ``DFlashDrafter`` fast path (``make_context_kv`` / ``extend_context_kv`` + / ``draft_block_cached``), so ``fused_specdecode_generate`` drives either + implementation unchanged.""" + + def __init__(self, cfg: DFlashConfig) -> None: + nn = _nn() + self.cfg = cfg + self.layers = [_Layer(cfg) for _ in range(cfg.num_hidden_layers)] + self.fc = nn.Linear(cfg.fc_in_features, cfg.hidden_size, bias=False) + self.hidden_norm = nn.RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps) + self.norm = nn.RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps) + + # -- weights ------------------------------------------------------------ + def load_weights(self, weights: dict) -> None: + """Assign checkpoint tensors (HF DFlash names) onto the modules. + + Same strictness as the torch loader: any missing/unexpected key is + a hard error. Dtype is preserved from the checkpoint (bf16). + """ + own: dict = {} + + def reg(prefix: str, mod: Any) -> None: + for name in ("weight",): + own[f"{prefix}.{name}"] = (mod, name) + + for i, layer in enumerate(self.layers): + p = f"layers.{i}" + reg(f"{p}.input_layernorm", layer.input_layernorm) + reg(f"{p}.post_attention_layernorm", layer.post_attention_layernorm) + for sub in ("q_proj", "k_proj", "v_proj", "o_proj", "q_norm", "k_norm"): + reg(f"{p}.self_attn.{sub}", getattr(layer.self_attn, sub)) + for sub in ("gate_proj", "up_proj", "down_proj"): + reg(f"{p}.mlp.{sub}", getattr(layer.mlp, sub)) + reg("fc", self.fc) + reg("hidden_norm", self.hidden_norm) + reg("norm", self.norm) + + missing = [k for k in own if k not in weights] + unexpected = [k for k in weights if k not in own] + if missing or unexpected: + raise ValueError( + f"DFlash MLX weight mismatch: missing={missing[:6]} " + f"unexpected={unexpected[:6]}" + ) + for key, (mod, attr) in own.items(): + setattr(mod, attr, weights[key]) + + @classmethod + def from_pretrained( + cls, model_id_or_path: str, *, compute_dtype: str = "bf16", + ) -> "MLXDFlashDrafter": + """Load from the checkpoint. ``compute_dtype``: + + * ``"bf16"`` (default, shipping config) — keep checkpoint dtype; + * ``"fp32"`` — cast weights up, matching the torch reference's + float32 execution. Used by the parity gate to separate port + bugs (would still mismatch) from dtype-induced near-tie argmax + flips (vanish at fp32-vs-fp32). + """ + mx = _mx() + if compute_dtype not in ("bf16", "fp32"): + raise ValueError(f"unsupported compute_dtype {compute_dtype!r}") + cfg = DFlashConfig.from_pretrained(model_id_or_path) + local = Path(model_id_or_path) / "model.safetensors" + if local.is_file(): + path = str(local) + else: + from huggingface_hub import hf_hub_download + + path = hf_hub_download(model_id_or_path, "model.safetensors") + weights = mx.load(path) + if compute_dtype == "fp32": + weights = {k: v.astype(mx.float32) for k, v in weights.items()} + model = cls(cfg) + model.load_weights(weights) + return model + + # -- aux fusion + context K/V (components B) ----------------------------- + def combine_aux(self, aux_hidden_states: Sequence[Any]) -> Any: + mx = _mx() + if len(aux_hidden_states) != self.cfg.num_aux_layers: + raise ValueError( + f"expected {self.cfg.num_aux_layers} aux hidden states, got " + f"{len(aux_hidden_states)}" + ) + cat = mx.concatenate(list(aux_hidden_states), axis=-1) + if cat.shape[-1] != self.cfg.fc_in_features: + raise ValueError( + f"aux concat feature dim {cat.shape[-1]} != fc_in_features " + f"{self.cfg.fc_in_features}" + ) + return self.fc(cat.astype(self.fc.weight.dtype)) + + def precompute_context_kv(self, context_states: Any, ctx_positions: Any): + ctx_normed = self.hidden_norm( + context_states.astype(self.hidden_norm.weight.dtype)) + return [ + layer.self_attn.project_context_kv(ctx_normed, ctx_positions) + for layer in self.layers + ] + + def make_context_kv(self, aux_hidden_context: Sequence[Any], positions: Any): + ctx_states = self.combine_aux(aux_hidden_context) + return self.precompute_context_kv(ctx_states, positions) + + @staticmethod + def extend_context_kv(ctx_kv, new_kv): + mx = _mx() + out = [] + for (ck, cv), (nk, nv) in zip(ctx_kv, new_kv): + out.append(( + mx.concatenate([ck, nk.astype(ck.dtype)], axis=2), + mx.concatenate([cv, nv.astype(cv.dtype)], axis=2), + )) + return out + + # -- drafting ------------------------------------------------------------- + def _run_layers(self, hidden: Any, query_positions: Any, ctx_kv) -> Any: + for layer, (ck, cv) in zip(self.layers, ctx_kv): + hidden = layer(hidden, query_positions, ck, cv) + return self.norm(hidden) + + def draft_block_ids( + self, + ctx_kv, + bonus_id_mx: Any, + embed_fn: Callable[[Any], Any], + lm_head_fn: Callable[[Any], Any], + *, + n_masks: int, + context_len: int, + ) -> Any: + """LAZY draft: ``[bonus, mask×n_masks]`` → mx ``[n_masks]`` draft ids. + + Nothing is evaluated and nothing crosses to python — the returned + ids feed the verifier forward inside the same lazy graph (lever ② + of the single-sync block loop). ``bonus_id_mx`` is an mx scalar + (e.g. ``mx.argmax(next_token_logits)``). + """ + mx = _mx() + cfg = self.cfg + mask_ids = mx.full((n_masks,), cfg.mask_token_id, dtype=bonus_id_mx.dtype) + query_ids = mx.concatenate([bonus_id_mx[None], mask_ids])[None] + query_positions = mx.arange(context_len, context_len + 1 + n_masks) + h = embed_fn(query_ids).astype(self.fc.weight.dtype) + h = self._run_layers(h, query_positions, ctx_kv) + logits = lm_head_fn(h) # [1, 1+n_masks, vocab] + vocab = logits.shape[-1] + never_mask = mx.arange(vocab) == cfg.mask_token_id + logits = mx.where(never_mask, mx.array(-float("inf")), logits) + return mx.argmax(logits[0, 1:1 + n_masks], axis=-1) + + def draft_block_cached( + self, + ctx_kv, + bonus_token_id: int, + embed_fn: Callable[[Any], Any], + lm_head_fn: Callable[[Any], Any], + *, + block_size: int, + context_len: int, + ) -> List[int]: + """Single non-causal pass over ``[bonus, mask×block_size]`` against the + cached context K/V → ``block_size`` draft token ids (materialised). + Compatibility surface for the generic fused loop / parity gate; the + single-sync loop uses :meth:`draft_block_ids` instead.""" + mx = _mx() + drafts = self.draft_block_ids( + ctx_kv, mx.array(int(bonus_token_id), dtype=mx.uint32), + embed_fn, lm_head_fn, + n_masks=block_size, context_len=context_len, + ) + mx.eval(drafts) + return [int(t) for t in drafts.tolist()] + + +def make_native_embed_lm_head( + text_model: Any, *, softcap: float | None = None, +) -> Tuple[Callable[[Any], Any], Callable[[Any], Any]]: + """All-MLX ``(embed_fn, lm_head_fn)`` over the verifier's weights. + + Same semantics as ``make_bridge_embed_lm_head`` (Gap-B: plain lookup, + NO ``×sqrt(hidden)``; tied-embed logits + softcapping) minus the + mx↔torch conversions that made the hybrid path 0.028× AR. + """ + mx = _mx() + + def embed_fn(query_ids: Any) -> Any: + return text_model.embed_tokens(query_ids) # no embed_scale (Gap-B) + + def lm_head_fn(h: Any) -> Any: + out = text_model.embed_tokens.as_linear(h) + if softcap: + out = softcap * mx.tanh(out / softcap) + return out + + return embed_fn, lm_head_fn diff --git a/inference_engine/backends/mlx/fused_specdecode.py b/inference_engine/backends/mlx/fused_specdecode.py new file mode 100644 index 00000000..47e5febc --- /dev/null +++ b/inference_engine/backends/mlx/fused_specdecode.py @@ -0,0 +1,764 @@ +"""MLX port of the #107 fused DFlash spec-decode engine (Components A+B+C). + +Hybrid runtime: the **verifier is MLX** (Gemma-4 26B-A4B, 4-bit) and the +**DFlash drafter + f_θ are PyTorch** (MPS/CPU), bridged at the K/V-injection and +aux-hidden boundaries (one bridge per block, never a re-forward). + +This mirrors ``scripts/research/k3_specdecode_gpu_bench.py:restored_specdecode_fused`` +(CUDA) per-block O(L): + +* **C (Gap-A)** — incremental restored verify: prefill captures restored K/V + into the model's native hybrid cache (full-attn = exact own K/V, S5 → recall; + sliding = f_θ-restored, window-bounded); each block verifies the candidate + tokens against that cache and is rolled back on rejection via mlx_lm's native + ``trim_prompt_cache`` (the same primitive mlx_lm's own spec-decode uses). +* **B** — drafter context K/V cache: built once from the prompt's aux hidden, + then EXTENDED with each committed token's aux (no O(C) recompute per block). +* **A** — the committed tokens' aux hidden are captured FROM the verify forward + (by patching the Gemma-4 decoder-layer ``__call__`` to record its output), so + there is no separate per-block clean-aux forward. + +The MLX-execution paths (forward/inject/aux-capture/trim) require Apple Silicon +and are validated on a Mac by ``k3_integrated_niah_eval_mac.py --fused-specdecode``; +the pure control flow (``fused_specdecode_generate`` loop, the verifier adapter's +prefill/verify/commit/truncate sequencing) is unit-tested on Linux with fakes. +""" + +from __future__ import annotations + +import contextlib +import time +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +from inference_engine.backends.mlx.cross_model_dlm_verifier import ( + resolve_mlx_text_model, + restored_prefill_cache, +) + + +# --------------------------------------------------------------------------- # +# Component A: capture verifier aux-layer hidden states (no transformers +# `output_hidden_states` on MLX → patch the decoder-layer __call__). +# --------------------------------------------------------------------------- # +@contextlib.contextmanager +def _patched_decoder_layers(text_model: Any): + """Enable aux hidden capture on Gemma-4 decoder layers. + + Patched local MLX-LM exposes a lightweight in-layer tap via + ``_kakeya_aux_sink``. Use that when available so every verify forward stays + on the normal ``DecoderLayer.__call__`` implementation. Fall back to the old + class-level wrapper for unpatched MLX-LM installs. + """ + if not text_model.layers: + yield + return + if bool(getattr(text_model.layers[0], "_kakeya_native_aux_tap", False)): + try: + yield + finally: + for layer in text_model.layers: + if hasattr(layer, "_kakeya_aux_sink"): + delattr(layer, "_kakeya_aux_sink") + if hasattr(layer, "_aux_record"): + delattr(layer, "_aux_record") + return + layer_cls = type(text_model.layers[0]) + orig_call = layer_cls.__call__ + + def dispatch(self, *args, **kwargs): + out = orig_call(self, *args, **kwargs) # (h, shared_kv, offset) + rec = getattr(self, "_aux_record", None) + if rec is not None: + rec[int(self.layer_idx)] = out[0] + return out + + layer_cls.__call__ = dispatch # type: ignore[assignment] + try: + yield + finally: + layer_cls.__call__ = orig_call # type: ignore[assignment] + for layer in text_model.layers: + if hasattr(layer, "_aux_record"): + delattr(layer, "_aux_record") + + +def _build_aux(text_model: Any, ids_mx: Any, sink: Dict[int, Any], + embed_scale: float, aux_layer_ids: Sequence[int]) -> List[Any]: + """Assemble a transformers-style ``hidden_states`` list and index it. + + ``hs[0]`` = scaled token embeddings; ``hs[k]`` = output of decoder layer + ``k-1`` (so ``hs[a]`` matches HF ``output_hidden_states[a]`` = input to + layer ``a`` = output of layer ``a-1``). Returns ``[hs[a] for a in + aux_layer_ids]``, each ``mx [1, L, hidden]``. + """ + embeds = text_model.embed_tokens(ids_mx) + embeds = embeds * embed_scale + n = len(text_model.layers) + hs = [embeds] + [sink[i] for i in range(n)] + return [hs[a] for a in aux_layer_ids] + + +def capture_aux_hidden( + mlx_model: Any, + input_ids: Sequence[int], + aux_layer_ids: Sequence[int], + *, + embed_scale: float, +) -> List[Any]: + """Clean (no-cache) forward capturing the verifier's aux-layer hidden over + ``input_ids``. Returns ``[mx [1, T, hidden]]`` for the prompt; used to seed + the drafter context K/V cache (Component B).""" + import mlx.core as mx # type: ignore + + text_model = resolve_mlx_text_model(mlx_model) + sink: Dict[int, Any] = {} + with _patched_decoder_layers(text_model): + for layer in text_model.layers: + layer._kakeya_aux_sink = sink + layer._aux_record = sink + ids = mx.array([list(input_ids)]) + _ = mlx_model(ids) + aux = _build_aux(text_model, ids, sink, embed_scale, aux_layer_ids) + mx.eval(aux) + return aux + + +# --------------------------------------------------------------------------- # +# Component C: incremental restored verifier (MLX analog of +# CrossModelRestoredSinkWindowVerifier(incremental=True)) with aux capture (A). +# --------------------------------------------------------------------------- # +class MLXRestoredIncrementalVerifier: + """Stateful MLX restored verifier for the fused spec-decode loop. + + ``prefill`` builds the restored cache (Gap-A) and the first-token logits; + ``forward_block`` verifies a candidate block incrementally (and, when + ``_capture_aux``, records the per-token aux hidden bridged to torch); + ``commit_or_truncate`` rolls the cache back by the rejected count via + ``mlx_lm.trim_prompt_cache``; ``append_token`` commits the correction. + """ + + def __init__( + self, + mlx_model: Any, + *, + embed_scale: float, + aux_layer_ids: Sequence[int] = (), + bridge_to_torch: Optional[Callable[[Any], Any]] = None, + ) -> None: + self.mlx_model = mlx_model + self.text_model = resolve_mlx_text_model(mlx_model) + self.embed_scale = float(embed_scale) + self.aux_layer_ids = tuple(int(a) for a in aux_layer_ids) + self._bridge = bridge_to_torch + self._cache: Any = None + self._past_len = 0 + self.next_token_logits: Any = None + self._last_aux: Optional[List[Any]] = None + self._last_aux_mx: Optional[List[Any]] = None + self._capture_aux = False + self._block_snapshot: Optional[List[Dict[str, Any]]] = None + self._full_kv = False + + def reset(self) -> None: + self._cache = None + self._past_len = 0 + self.next_token_logits = None + self._last_aux = None + self._last_aux_mx = None + self._block_snapshot = None + + def prefill( + self, + prompt_ids: Sequence[int], + *, + restored_k_per_layer: Dict[int, Any], + restored_v_per_layer: Dict[int, Any], + evicted_positions: Sequence[int], + prefill_chunk_size: int = 0, + full_kv: bool = False, + ) -> None: + if not prompt_ids: + raise ValueError("prompt_ids must be non-empty") + self.reset() + # full_kv=True → all-`KVCache` layout so accept/reject rollback can use + # SOUND native trim (keep accepted, drop rejected) with no re-forward. + self._full_kv = bool(full_kv) + factory = make_full_kv_prompt_cache if full_kv else None + self._cache, self.next_token_logits = restored_prefill_cache( + self.mlx_model, list(prompt_ids), + restored_k_per_layer=restored_k_per_layer, + restored_v_per_layer=restored_v_per_layer, + evicted_positions=evicted_positions, + prefill_chunk_size=prefill_chunk_size, + cache_factory=factory, + ) + self._past_len = len(prompt_ids) + + def forward_block(self, tokens: Sequence[int]) -> Any: + """Incremental verify of ``tokens`` against the restored cache. Returns + ``mx [len(tokens), V]`` logits; captures aux hidden states in MX first + and bridges to torch lazily via :meth:`last_aux_torch_slice`.""" + import mlx.core as mx # type: ignore + + if self._cache is None: + raise RuntimeError("verifier not prefilled") + if not tokens: + raise ValueError("tokens must be non-empty") + ids = mx.array([list(tokens)]) + want_aux = self._capture_aux and bool(self.aux_layer_ids) + if want_aux: + sink = {} + with _patched_decoder_layers(self.text_model): + for layer in self.text_model.layers: + layer._kakeya_aux_sink = sink + layer._aux_record = sink + logits = self.mlx_model(ids, cache=self._cache) + aux = _build_aux(self.text_model, ids, sink, + self.embed_scale, self.aux_layer_ids) + mx.eval(logits, *aux) + self._last_aux_mx = [a[0] for a in aux] # [L, hidden] each, mx + self._last_aux = None + else: + logits = self.mlx_model(ids, cache=self._cache) + mx.eval(logits) + self._last_aux = None + self._last_aux_mx = None + return logits[0] + + def last_aux_torch_slice(self, start: int = 0, end: Optional[int] = None) -> List[Any]: + """Bridge a token slice from the last captured aux hidden states.""" + if self._last_aux_mx is None: + raise RuntimeError("aux hidden not captured for the last forward_block") + bridge = self._bridge or (lambda a: a) + return [bridge(a[start:end]) for a in self._last_aux_mx] + + def rollback_block(self) -> None: + """O(1) full rollback of the last ``forward_block_lazy`` call. + + ``trim_prompt_cache`` is NOT a valid rollback on Gemma-4's hybrid + cache once the sliding-window RotatingKVCache has wrapped + (seq >> 512): rejected draft K/V linger in the ring and poison + subsequent logits (observed live as stream divergence vs greedy + + acceptance collapse; retroactively explains iterC's 23-token + sample and the eager loop's silent post-answer divergence). MLX + arrays are immutable, so a snapshot is just attribute references; + restore rebinds them. + """ + if self._block_snapshot is None: + raise RuntimeError("no block snapshot to roll back to") + for c, snap in zip(self._cache, self._block_snapshot): + for attr, val in snap.items(): + setattr(c, attr, val) + + def forward_block_lazy(self, ids_mx: Any) -> Any: + """LAZY incremental verify: ``ids_mx`` is an mx ``[1, L]`` (typically + the in-graph concatenation of the carried bonus + lazy draft ids — + lever ② of the single-sync loop). Returns ``mx [L, V]`` logits with + NO evaluation; aux hidden (when ``_capture_aux``) stays lazy in + ``_last_aux_mx`` and is consumed lazily by the drafter-context + extension.""" + if self._cache is None: + raise RuntimeError("verifier not prefilled") + # Reference snapshot of every per-layer cache state (immutable mx + # arrays → O(layers) attribute refs) for rollback_block(). + self._block_snapshot = [ + {attr: getattr(c, attr) + for attr in ("keys", "values", "offset", "_idx") + if hasattr(c, attr)} + for c in self._cache + ] + want_aux = self._capture_aux and bool(self.aux_layer_ids) + if want_aux: + sink: Dict[int, Any] = {} + with _patched_decoder_layers(self.text_model): + for layer in self.text_model.layers: + layer._kakeya_aux_sink = sink + layer._aux_record = sink + logits = self.mlx_model(ids_mx, cache=self._cache) + aux = _build_aux(self.text_model, ids_mx, sink, + self.embed_scale, self.aux_layer_ids) + self._last_aux_mx = [a[0] for a in aux] # [L, hidden] each, lazy + self._last_aux = None + else: + logits = self.mlx_model(ids_mx, cache=self._cache) + self._last_aux = None + self._last_aux_mx = None + return logits[0] + + def commit_or_truncate(self, *, forwarded: int, accepted: int) -> None: + if accepted < 0 or accepted > forwarded: + raise ValueError("accepted must satisfy 0 <= accepted <= forwarded") + drop = forwarded - accepted + if drop > 0 and self._cache is not None: + from mlx_lm.models.cache import trim_prompt_cache # type: ignore + trim_prompt_cache(self._cache, drop) + self._past_len += accepted + + def append_token(self, token_id: int) -> Any: + logits = self.forward_block([int(token_id)]) + self.commit_or_truncate(forwarded=1, accepted=1) + self.next_token_logits = logits[-1] + return self.next_token_logits + + +# --------------------------------------------------------------------------- # +# Bridge embed / lm_head for the hybrid path (drafter torch ↔ verifier MLX). +# --------------------------------------------------------------------------- # +def make_bridge_embed_lm_head( + text_model: Any, + *, + mx_to_torch: Callable[..., Any], + torch_to_mx: Callable[..., Any], + device: Any, + torch_dtype: Any, + softcap: Optional[float] = None, +) -> Tuple[Callable[[Any], Any], Callable[[Any], Any]]: + """Return ``(embed_fn, lm_head_fn)`` over the MLX verifier weights for the + PyTorch drafter. + + * ``embed_fn`` — **Gap-B**: a *plain* shared-embedding lookup with **no + ``×sqrt(hidden)`` scaling** (the drafting query embedding bug fixed in + #107). Returns torch ``[*, hidden]``. + * ``lm_head_fn`` — tied-embedding logits + Gemma ``final_logit_softcapping`` + (monotonic; preserves argmax). Returns torch ``[*, vocab]``. + """ + import mlx.core as mx # type: ignore + + def embed_fn(query_ids: Any) -> Any: + ids_mx = mx.array(query_ids.detach().to("cpu").tolist()) + emb = text_model.embed_tokens(ids_mx) # NO * embed_scale (Gap-B) + return mx_to_torch(emb, dtype=torch_dtype, device=device) + + def lm_head_fn(h: Any) -> Any: + h_mx = torch_to_mx(h) + out = text_model.embed_tokens.as_linear(h_mx) + if softcap: + out = softcap * mx.tanh(out / softcap) + return mx_to_torch(out, dtype=torch_dtype, device=device) + + return embed_fn, lm_head_fn + + +# --------------------------------------------------------------------------- # +# Single-sync all-MLX fused loop (levers ① ② ③ of the Step-2 throughput plan; +# docs/mlx-port-lessons.md "Step-2 rescue status"). +# --------------------------------------------------------------------------- # +def make_full_kv_prompt_cache(mlx_model: Any) -> List[Any]: + """Build a prompt cache that uses a full append-only ``KVCache`` for EVERY + layer (including the sliding-attention ones, which the model's native + ``make_cache`` would give a ``RotatingKVCache``). + + Why: ``RotatingKVCache`` is not trimmable once the ring has wrapped + (``is_trimmable`` → ``offset < max_size``), so spec-decode accept/reject + rollback cannot keep the accepted K/V via a cheap trim — it must re-forward + (the v3 carry penalty). With an all-``KVCache`` layout, ``trim_prompt_cache`` + is a sound O(1) slice on every layer, so the loop keeps accepted K/V and + drops only the rejected tail (CUDA `DynamicCache` parity). Sliding attention + remains byte-exact because the per-layer window mask is applied regardless + of cache capacity; the only cost is O(T) sliding KV during decode. + """ + from mlx_lm.models.cache import make_prompt_cache, KVCache # type: ignore + + n = len(make_prompt_cache(mlx_model)) + return [KVCache() for _ in range(n)] + + +def fused_specdecode_generate_mlx_trim( + adapter: "MLXRestoredIncrementalVerifier", + drafter: Any, + *, + aux_prompt: Sequence[Any], + embed_fn: Callable[[Any], Any], + lm_head_fn: Callable[[Any], Any], + gen_tokens: int, + block_size: int, + eos_ids: Sequence[int] = (), + single_fused: bool = False, +) -> Dict[str, Any]: + """CUDA-parity fused spec decode: KEEP accepted K/V, TRIM only the rejected + tail (no rollback, no carry re-forward). Requires the adapter to be + prefilled with ``full_kv=True`` (all-``KVCache`` layout) so the native + ``trim_prompt_cache`` is sound. Levers ①②③ retained (lazy draft+verify + single graph, in-graph cumprod acceptance, carried correction). + + Per block: forward ``[bonus + drafts]`` (L tokens) → cache = base+L; accept + the leading match count ``k`` (bonus always accepts); ``trim_prompt_cache`` + drops the L−k rejected tokens; advance ``_past_len`` by ``k``. The accepted + tokens' K/V (computed in this forward) stay in the cache — never recomputed. + """ + import mlx.core as mx # type: ignore + from mlx_lm.models.cache import trim_prompt_cache # type: ignore + + eos = set(int(t) for t in eos_ids) + C = adapter._past_len + ctx_kv = drafter.make_context_kv(list(aux_prompt), mx.arange(0, C)) + mx.async_eval([t for kv in ctx_kv for t in kv]) + timing = {"ctx_kv_build_s": 0.0, "build_s": 0.0, "eval_s": 0.0, "extend_s": 0.0} + adapter._capture_aux = True + + generated: List[int] = [] + accepts: List[int] = [] + block_evals: List[float] = [] + ctx_len = C + try: + while len(generated) < gen_tokens: + L = min(block_size, gen_tokens - len(generated)) + base = adapter._past_len + t_build = time.perf_counter() + bonus_id = mx.argmax(adapter.next_token_logits) # lazy scalar + n_draft = max(L - 1, 0) + if n_draft: + drafts = drafter.draft_block_ids( + ctx_kv, bonus_id, embed_fn, lm_head_fn, + n_masks=n_draft, context_len=base) + check_ids = mx.concatenate([bonus_id[None], drafts]) # [L] + if not single_fused: + mx.eval(check_ids) # two-phase (drafter graph before 26B) + # single_fused=True → leave check_ids LAZY so the drafter and + # 26B verify fuse into ONE graph (the path b876 found Metal- + # pathological); this probe times it to classify the instability. + else: + check_ids = bonus_id[None] + block_logits = adapter.forward_block_lazy(check_ids[None]) # [L, V] + # in-graph greedy acceptance over the check region + pred_rows = mx.concatenate( + [adapter.next_token_logits[None], block_logits[:max(L - 1, 0)]], + axis=0) + matches = (mx.argmax(pred_rows, axis=-1) == check_ids) + accepted_mx = mx.sum(mx.cumprod(matches.astype(mx.int32))) + rows = mx.concatenate( + [adapter.next_token_logits[None], block_logits], axis=0) # [L+1,V] + next_row = mx.take(rows, accepted_mx[None], axis=0)[0] # [V] + timing["build_s"] += time.perf_counter() - t_build + t_eval = time.perf_counter() + mx.eval(accepted_mx, check_ids) + blk_eval = time.perf_counter() - t_eval + timing["eval_s"] += blk_eval + block_evals.append(round(blk_eval, 4)) + accepted = int(accepted_mx.item()) + check = [int(x) for x in check_ids.tolist()] + commit = check[:accepted] + generated += commit + accepts.append(accepted) + adapter.next_token_logits = next_row + aux_rows = adapter._last_aux_mx + # KEEP accepted (positions base..base+accepted-1), TRIM rejected. + drop = L - accepted + if drop > 0: + trim_prompt_cache(adapter._cache, drop) + adapter._past_len = base + accepted + S_new = adapter._past_len + lo, hi = ctx_len - base, S_new - base + if hi > lo and aux_rows is not None: + t_extend = time.perf_counter() + new_aux = [a[lo:hi][None] for a in aux_rows] + ctx_kv = drafter.extend_context_kv( + ctx_kv, + drafter.make_context_kv(new_aux, mx.arange(ctx_len, S_new))) + mx.async_eval([t for kv in ctx_kv for t in kv]) + ctx_len = S_new + timing["extend_s"] += time.perf_counter() - t_extend + if any(t in eos for t in commit): + break + finally: + adapter._capture_aux = False + generated = generated[:gen_tokens] + return { + "tokens": generated, + "blocks": len(accepts), + "mean_accept_len": (round(sum(accepts) / len(accepts), 3) + if accepts else 0.0), + "decode_tokens": len(generated), + "loop": ("mlx_trim_single_fused_probe" if single_fused + else "mlx_trim_keep_accepted_cuda_parity"), + "single_fused": bool(single_fused), + "block_eval_s_first8": block_evals[:8], + "block_eval_s_max": (round(max(block_evals), 4) if block_evals else None), + "block_eval_s_mean": (round(sum(block_evals) / len(block_evals), 4) + if block_evals else None), + "time_breakdown_s": {k: round(v, 3) for k, v in timing.items()}, + } + + +def fused_specdecode_generate_mlx( + adapter: "MLXRestoredIncrementalVerifier", + drafter: Any, + *, + aux_prompt: Sequence[Any], + embed_fn: Callable[[Any], Any], + lm_head_fn: Callable[[Any], Any], + gen_tokens: int, + block_size: int, + eos_ids: Sequence[int] = (), +) -> Dict[str, Any]: + """All-MLX fused spec decode with ONE host sync per block. + + * ② draft+verify single graph: the drafter's lazy draft ids + (:meth:`MLXDFlashDrafter.draft_block_ids`) are concatenated with the + carried bonus in-graph and fed straight into the verifier forward — + no draft token ever crosses to python before verification. + * ① in-graph acceptance: the leading-match count is + ``sum(cumprod(argmax(pred_rows) == candidate))``; the next-position + logits row is gathered with the lazy count (``mx.take``). The block's + single ``mx.eval`` materialises exactly three things: the accept + count, the candidate ids, and nothing else. Drafter-context + extensions are pushed with ``mx.async_eval`` so Metal works while + python does bookkeeping. + * ③ carried correction: on rejection there is NO correction forward. + ``next_token_logits`` is set to the gathered next-position row, so + the verifier's own argmax (the correction) becomes the next block's + bonus — guaranteed-accepted at position 0 of the next verify, where + its K/V and aux are computed as part of the batched forward. + + Per-block commit = the accepted candidate prefix (position 0, the + carried bonus/correction, always accepts by construction — every block + commits >= 1 token, so the loop degrades to AR pace, never below). + """ + import mlx.core as mx # type: ignore + + eos = set(int(t) for t in eos_ids) + C = adapter._past_len + t_ctx = time.perf_counter() + ctx_kv = drafter.make_context_kv(list(aux_prompt), mx.arange(0, C)) + mx.async_eval([t for kv in ctx_kv for t in kv]) + timing = { + "ctx_kv_build_s": time.perf_counter() - t_ctx, + "build_s": 0.0, # lazy graph construction (python-side) + "eval_s": 0.0, # per-block syncs (Metal compute) + "extend_s": 0.0, + } + adapter._capture_aux = True + + generated: List[int] = [] + accepts: List[int] = [] + # Rollback-carry state: rejected blocks roll the WHOLE forward back + # (rollback_block — see its docstring for why trim is unsound on the + # wrapped sliding ring) and carry the stream-committed-but-not-cached + # tokens (`tail`) into the next candidate, where they are guaranteed + # re-accepted and their K/V + aux recomputed correctly. + tail: List[int] = [] + tail_logits = adapter.next_token_logits # row predicting position S + ctx_len = C # drafter context coverage + try: + while len(generated) < gen_tokens: + L = min(block_size, gen_tokens - len(generated)) + base_fwd = adapter._past_len # cache offset at forward start + S = base_fwd + len(tail) # committed stream length + t_build = time.perf_counter() + bonus_id = mx.argmax(tail_logits) # lazy scalar + n_draft = max(L - 1, 0) + if n_draft: + drafts = drafter.draft_block_ids( + ctx_kv, bonus_id, embed_fn, lm_head_fn, + n_masks=n_draft, context_len=S) + check_ids = mx.concatenate([bonus_id[None], drafts]) # [L] + # Two-phase evaluation: materialise the drafter graph + # BEFORE building the 26B verify graph. A single fused + # drafter+verifier graph proved pathological on Metal + # (command-buffer blowups: 143 s evals in the first live + # run); two small syncs per block are stable and still + # ~3× fewer than the eager loop's 6+L. + mx.eval(check_ids) + else: + check_ids = bonus_id[None] + if tail: + cand_full = mx.concatenate( + [mx.array(tail, dtype=check_ids.dtype), check_ids]) + else: + cand_full = check_ids + k = len(tail) + block_logits = adapter.forward_block_lazy(cand_full[None]) # [k+L, V] + # In-graph greedy acceptance over the CHECK region only + # (the carried tail is already stream-committed): row i of + # pred_rows predicts check_ids[i]; leading-match via cumprod. + pred_rows = mx.concatenate( + [tail_logits[None], block_logits[k:k + L - 1]], axis=0) + matches = (mx.argmax(pred_rows, axis=-1) == check_ids) + accepted_mx = mx.sum(mx.cumprod(matches.astype(mx.int32))) + rows = mx.concatenate( + [tail_logits[None], block_logits[k:]], axis=0) # [L+1, V] + next_row = mx.take(rows, accepted_mx[None], axis=0)[0] # [V] + timing["build_s"] += time.perf_counter() - t_build + t_eval = time.perf_counter() + mx.eval(accepted_mx, check_ids) + timing["eval_s"] += time.perf_counter() - t_eval + accepted = int(accepted_mx.item()) + check = [int(x) for x in check_ids.tolist()] + commit = check[:accepted] + generated += commit + accepts.append(accepted) + tail_logits = next_row + adapter.next_token_logits = next_row + aux_rows = adapter._last_aux_mx # rows for positions base_fwd..base_fwd+k+L + if accepted == L: + # Whole forward (tail + check region) is now valid cache. + adapter._past_len = base_fwd + k + L + tail = [] + else: + adapter.rollback_block() # cache back to base_fwd + tail = tail + commit # re-verified next block + S_new = adapter._past_len + len(tail) + # Extend the drafter context with aux rows for newly committed + # positions (ctx_len..S_new). Accepted rows are correct even + # after rollback: causal attention means rejected positions + # only sit AFTER them in the forward. + lo, hi = ctx_len - base_fwd, S_new - base_fwd + if hi > lo and aux_rows is not None: + t_extend = time.perf_counter() + new_aux = [a[lo:hi][None] for a in aux_rows] + ctx_kv = drafter.extend_context_kv( + ctx_kv, + drafter.make_context_kv(new_aux, mx.arange(ctx_len, S_new))) + mx.async_eval([t for kv in ctx_kv for t in kv]) + ctx_len = S_new + timing["extend_s"] += time.perf_counter() - t_extend + if any(t in eos for t in commit): + break + finally: + adapter._capture_aux = False + generated = generated[:gen_tokens] + return { + "tokens": generated, + "blocks": len(accepts), + "mean_accept_len": (round(sum(accepts) / len(accepts), 3) + if accepts else 0.0), + "decode_tokens": len(generated), + "loop": "mlx_rollback_carry_v3", + "time_breakdown_s": {k: round(v, 3) for k, v in timing.items()}, + } + + +# --------------------------------------------------------------------------- # +# The fused spec-decode loop (control flow; MLX/torch ops via injected fns). +# --------------------------------------------------------------------------- # +def fused_specdecode_generate( + adapter: Any, + drafter: Any, + *, + aux_prompt: Sequence[Any], + embed_fn: Callable[[Any], Any], + lm_head_fn: Callable[[Any], Any], + gen_tokens: int, + block_size: int, + eos_ids: Sequence[int] = (), + argmax_fn: Callable[[Any], int], + arange_fn: Callable[[int, int], Any], + cat_aux_fn: Callable[[Sequence[Any]], Any], + allow_greedy_fallback: bool = True, +) -> Dict[str, Any]: + """Run the fused engine. ``adapter`` must already be prefilled. Per block: + draft from the cached drafter context (B), verify+capture-aux incrementally + (C+A), accept the longest correct prefix, commit the correction, and EXTEND + the drafter context with the committed tokens' aux. + + ``argmax_fn`` (logits-row → int), ``arange_fn`` (start, stop → positions), + and ``cat_aux_fn`` (parts → ``[1, k, hidden]``) abstract the MLX/torch ops so + the loop is unit-testable. + """ + n_aux = len(aux_prompt) + eos = set(int(t) for t in eos_ids) + C = adapter._past_len + t_ctx = time.perf_counter() + ctx_kv = drafter.make_context_kv(list(aux_prompt), arange_fn(0, C)) + timing = { + "ctx_kv_build_s": time.perf_counter() - t_ctx, + "draft_s": 0.0, + "verify_s": 0.0, + "append_s": 0.0, + "extend_s": 0.0, + "fallback_greedy_s": 0.0, + } + adapter._capture_aux = True + + generated: List[int] = [] + accepts: List[int] = [] + fallback_to_greedy = False + try: + while len(generated) < gen_tokens: + L = min(block_size, gen_tokens - len(generated)) + cstart = adapter._past_len + bonus = int(argmax_fn(adapter.next_token_logits)) + t_draft = time.perf_counter() + drafts = drafter.draft_block_cached( + ctx_kv, bonus, embed_fn, lm_head_fn, + block_size=max(L - 1, 1), context_len=cstart) + timing["draft_s"] += time.perf_counter() - t_draft + candidate = [bonus] + list(drafts[: max(L - 1, 0)]) + prev = adapter.next_token_logits + t_verify = time.perf_counter() + block_logits = adapter.forward_block(candidate) + timing["verify_s"] += time.perf_counter() - t_verify + accepted = 0 + for i in range(len(candidate)): + if int(argmax_fn(prev)) == candidate[i]: + accepted += 1 + prev = block_logits[i] + else: + break + adapter.commit_or_truncate(forwarded=len(candidate), accepted=accepted) + if accepted == len(candidate): + # The verifier cache already contains the whole accepted block. + # Reuse the last block logit as the next-token distribution instead + # of paying an extra correction-token forward. + adapter.next_token_logits = block_logits[-1] + new_positions = arange_fn(cstart, cstart + accepted) + t_extend = time.perf_counter() + cand_aux = adapter.last_aux_torch_slice(0, accepted) + # cat_aux_fn of a single part == unsqueeze(0) in the torch + # path; routing through it keeps this loop runtime-agnostic + # (the all-MLX drafter path injects an mx-based cat_aux_fn). + new_aux = [cat_aux_fn([cand_aux[li]]) for li in range(n_aux)] + ctx_kv = drafter.extend_context_kv( + ctx_kv, drafter.make_context_kv(new_aux, new_positions)) + timing["extend_s"] += time.perf_counter() - t_extend + commit = candidate + else: + correction = int(argmax_fn(prev)) + cand_aux = adapter.last_aux_torch_slice(0, accepted) + t_append = time.perf_counter() + adapter.append_token(correction) + timing["append_s"] += time.perf_counter() - t_append + corr_aux = adapter.last_aux_torch_slice(0, 1) + new_positions = arange_fn(cstart, cstart + accepted + 1) + t_extend = time.perf_counter() + new_aux = [ + cat_aux_fn([cand_aux[li], corr_aux[li]]) + for li in range(n_aux) + ] + ctx_kv = drafter.extend_context_kv( + ctx_kv, drafter.make_context_kv(new_aux, new_positions)) + timing["extend_s"] += time.perf_counter() - t_extend + commit = candidate[:accepted] + [correction] + generated += commit + accepts.append(accepted) + if any(t in eos for t in commit): + break + if (allow_greedy_fallback and len(accepts) >= 2 + and (sum(accepts) / len(accepts)) < 1.5): + fallback_to_greedy = True + break + + if allow_greedy_fallback and fallback_to_greedy and len(generated) < gen_tokens: + # Low acceptance makes speculative control flow slower than AR. Finish + # on the restored verifier cache with plain greedy decode and no aux + # capture/bridge. + adapter._capture_aux = False + t_fb = time.perf_counter() + while len(generated) < gen_tokens: + tok = int(argmax_fn(adapter.next_token_logits)) + adapter.append_token(tok) + generated.append(tok) + if tok in eos: + break + timing["fallback_greedy_s"] += time.perf_counter() - t_fb + finally: + adapter._capture_aux = False + generated = generated[:gen_tokens] + return { + "tokens": generated, + "blocks": len(accepts), + "mean_accept_len": (round(sum(accepts) / len(accepts), 3) + if accepts else 0.0), + "decode_tokens": len(generated), + "time_breakdown_s": {k: round(v, 3) for k, v in timing.items()}, + } diff --git a/inference_engine/bench/k3_report_gate.py b/inference_engine/bench/k3_report_gate.py new file mode 100644 index 00000000..046b9192 --- /dev/null +++ b/inference_engine/bench/k3_report_gate.py @@ -0,0 +1,357 @@ +"""K3 Mac evidence gate — machine-checkable report constraints. + +Born from the PR #109 review of the CUDA→MLX port evidence. The +committed Mac reports exhibited four failure modes that a human had to +catch by reading raw JSON: + +1. Reports labelled ``free_gen_fused_specdecode`` in which the fused + spec-decode engine executed **zero blocks** (silent greedy + fallback) — the system under test never ran. +2. A "cross vs oracle" speedup (2.584×) where the cross arm was the + **native-cache baseline itself** (adaptive bypass), i.e. the system + was compared against itself and the ratio was run-order noise + (oracle prefill varied 35–146 s for the identical computation). +3. Headline throughput derived from n=1 / gen=8 smokes whose wall + time was ~95 % prefill. +4. An analytical S5 memory table (``sink_window=68``) attached to a + run that actually used the un-trimmed native cache. + +Every one of those is now a hard, mechanical rule. The Mac harness +(``scripts/research/k3_integrated_niah_eval_mac.py``) runs +:func:`validate_report` on its own output and exits non-zero on +violation; CI (``scripts/validate_k3_reports.py``) re-validates every +committed report so non-conforming evidence cannot land on a branch +silently. Reports with ``schema_version < 2`` predate the gate and are +grandfathered as **non-evidence** (CI prints them as legacy warnings). + +This module is deliberately dependency-free (stdlib only) so the CI +step and the Mac harness share the exact same rule implementation. + +Rule codes +---------- + +================================ ============================================ +``LEGACY_SCHEMA`` schema_version < 2: pre-gate report, + grandfathered, never citable as evidence. +``MISSING_STAGE_TIMINGS`` cross arm has no per-sample stage rows. +``MISSING_RESTORATION_FLAG`` a cross stage row lacks ``restoration_active``. +``MIXED_RESTORATION_PATHS`` cross samples mix restored and native paths. +``BASELINE_AS_SUT`` a native-baseline run occupies the + system-under-test slot without declaring + ``system_under_test = "native_ar_baseline"``. +``BASELINE_RECALL_CLAIM`` a native-baseline run claims + ``gate.recall_cross_model``. +``RECALL_SCOPE`` ``recall_cross_model`` claimed but not every + cross sample ran with restoration active. +``FUSED_NEVER_RAN`` fused eval_mode with zero executed blocks. +``SPEEDUP_SELF_COMPARISON`` speedup claimed on a native-baseline run. +``SPEEDUP_SAMPLES`` speedup claimed with < MIN_PERF_SAMPLES. +``SPEEDUP_DECODE_TOKENS`` speedup claimed with median decode tokens + < MIN_MEDIAN_DECODE_TOKENS (prefill-dominated). +``SPEEDUP_DECODE_ONLY_MISSING`` headline speedup without a decode-only + median comparison alongside it. +``SPEEDUP_SCOPE_MISMATCH`` cross and oracle timing scopes differ. +``SPEEDUP_ORACLE_LOOP`` oracle decode loop is not ``generate_step`` + (hand-rolled per-token ``mx.eval`` baselines + are the documented MLX anti-pattern). +``SPEEDUP_PREFILL_VARIANCE`` within-arm prefill spread exceeds + MAX_PREFILL_SPREAD; e2e ratios are noise. +``MEMORY_CLAIM_MISMATCH`` memory savings claimed from an analytical + formula that does not describe the run. +================================ ============================================ +""" + +from __future__ import annotations + +import statistics +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence + +MAC_REPORT_KIND = "k3_integrated_niah_acceptance_mac" + +# Reports older than this schema predate the gate: grandfathered, +# treated as non-evidence (warning, not failure) by the CI walker. +GATED_SCHEMA_VERSION = 2 + +# Minimum statistical strength for any cross-vs-oracle speedup claim. +MIN_PERF_SAMPLES = 5 +MIN_MEDIAN_DECODE_TOKENS = 32 + +# Max allowed (max/min) prefill wall-time spread within one arm before +# an e2e throughput ratio is ruled noise. The ctx280 report that +# motivated the gate showed a 4.1× spread on the oracle arm. +MAX_PREFILL_SPREAD = 3.0 + +# The only oracle decode loop admissible for a headline speedup: the +# async-pipelined mlx_lm primitive. Per-token ``mx.eval`` loops are the +# anti-pattern documented in docs/mlx-port-lessons.md. +CLAIM_ORACLE_DECODE_LOOP = "generate_step" + +NATIVE_BASELINE_LABEL = "native_ar_baseline" + + +@dataclass(frozen=True) +class GateViolation: + """One violated evidence rule.""" + + code: str + message: str + + +def is_gated_report(report: Any) -> bool: + """True when ``report`` is a K3 Mac acceptance report (any schema).""" + return isinstance(report, dict) and report.get("kind") == MAC_REPORT_KIND + + +def is_legacy_report(report: Dict[str, Any]) -> bool: + """True when the report predates the evidence gate (schema < 2).""" + try: + version = int(report.get("schema_version", 1)) + except (TypeError, ValueError): + return True + return version < GATED_SCHEMA_VERSION + + +# --------------------------------------------------------------------------- +# Shared helpers (also used by the Mac harness when assembling reports) +# --------------------------------------------------------------------------- + + +def row_prefill_seconds(row: Dict[str, Any]) -> Optional[float]: + """Per-sample prefill seconds; accepts both row spellings.""" + value = row.get("prefill_s", row.get("restored_prefill_s")) + return float(value) if isinstance(value, (int, float)) else None + + +def prefill_spread(rows: Sequence[Dict[str, Any]]) -> Optional[float]: + """(max / min) prefill seconds across rows; None when undeterminable.""" + values = [ + v for v in (row_prefill_seconds(r) for r in rows) + if v is not None and v > 0 + ] + if len(values) < 2: + return None + return max(values) / min(values) + + +def decode_only_block( + cross_rows: Sequence[Dict[str, Any]], + cross_tokens: Sequence[int], + oracle_rows: Sequence[Dict[str, Any]], + oracle_tokens: Sequence[int], +) -> Optional[Dict[str, float]]: + """Decode-only median tok/s for both arms + their ratio. + + Prefill is identical machinery on both arms and noise-dominated on + Apple Silicon, so the decode-only ratio is the only throughput + comparison the gate accepts as a headline. Returns None when either + arm lacks usable (decode_s > 0, tokens > 0) samples. + """ + + def _per_sample(rows: Sequence[Dict[str, Any]], tokens: Sequence[int]) -> List[float]: + out: List[float] = [] + for row, n_tok in zip(rows, tokens): + decode_s = row.get("decode_s") + if isinstance(decode_s, (int, float)) and decode_s > 0 and n_tok > 0: + out.append(float(n_tok) / float(decode_s)) + return out + + cross = _per_sample(cross_rows, cross_tokens) + oracle = _per_sample(oracle_rows, oracle_tokens) + if not cross or not oracle: + return None + cross_median = statistics.median(cross) + # Both medians are > 0 by construction: _per_sample only admits + # samples with decode_s > 0 and tokens > 0. + oracle_median = statistics.median(oracle) + return { + "cross_median_tok_s": round(cross_median, 4), + "oracle_median_tok_s": round(oracle_median, 4), + "speedup": round(cross_median / oracle_median, 3), + } + + +def summarize_violations(violations: Sequence[GateViolation]) -> str: + """Stable multi-line rendering for logs and CI output.""" + return "\n".join(f" [{v.code}] {v.message}" for v in violations) + + +# --------------------------------------------------------------------------- +# Rules +# --------------------------------------------------------------------------- + + +def _cross_arm(report: Dict[str, Any]) -> Dict[str, Any]: + return (report.get("throughput") or {}).get("k3_cross_model") or {} + + +def _oracle_arm(report: Dict[str, Any]) -> Dict[str, Any]: + return (report.get("throughput") or {}).get("oracle_native_ar") or {} + + +def _stage_rows(arm: Dict[str, Any]) -> List[Dict[str, Any]]: + rows = arm.get("stage_timings") + return list(rows) if isinstance(rows, list) else [] + + +def validate_report(report: Dict[str, Any]) -> List[GateViolation]: + """Validate one K3 Mac report against every evidence rule. + + Returns an empty list when the report is admissible evidence. + Non-gated kinds validate trivially; legacy schemas return exactly + one ``LEGACY_SCHEMA`` violation (the CI walker downgrades that one + code to a warning — everything else fails the build). + """ + if not is_gated_report(report): + return [] + if is_legacy_report(report): + return [GateViolation( + "LEGACY_SCHEMA", + f"schema_version < {GATED_SCHEMA_VERSION}: pre-gate report; " + "grandfathered as NON-EVIDENCE (rerun with the hardened harness " + "to make claims)", + )] + + violations: List[GateViolation] = [] + cross = _cross_arm(report) + rows = _stage_rows(cross) + results = report.get("results") or {} + results_cross = results.get("k3_cross_model") or {} + gate = report.get("gate") or {} + + # --- Path identity: every sample must declare what actually ran --- + if not rows: + violations.append(GateViolation( + "MISSING_STAGE_TIMINGS", + "cross arm has no per-sample stage_timings; per-sample " + "prefill_s/decode_s/restoration_active are mandatory at schema 2", + )) + flags = [row.get("restoration_active") for row in rows] + if rows and any(flag is None for flag in flags): + violations.append(GateViolation( + "MISSING_RESTORATION_FLAG", + "one or more cross stage rows lack restoration_active", + )) + known = [bool(flag) for flag in flags if flag is not None] + all_active = bool(known) and all(known) + none_active = bool(known) and not any(known) + if known and not all_active and not none_active: + violations.append(GateViolation( + "MIXED_RESTORATION_PATHS", + "cross samples mix restored and native paths in one report", + )) + + # --- A baseline run may never occupy the SUT slot undeclared --- + if none_active: + if results_cross.get("system_under_test") != NATIVE_BASELINE_LABEL: + violations.append(GateViolation( + "BASELINE_AS_SUT", + "no cross sample ran restoration, but the report does not " + f"declare system_under_test={NATIVE_BASELINE_LABEL!r}", + )) + if gate.get("recall_cross_model") is not None: + violations.append(GateViolation( + "BASELINE_RECALL_CLAIM", + "native-baseline run claims gate.recall_cross_model; " + "baseline recall must be reported as recall_native_baseline", + )) + + # --- Recall claims are scoped to the restored path --- + if gate.get("recall_cross_model") is not None and not (known and all_active): + violations.append(GateViolation( + "RECALL_SCOPE", + "gate.recall_cross_model is claimed but not every cross sample " + "ran with restoration_active=true", + )) + + # --- A fused report must have executed the fused engine --- + if cross.get("eval_mode") == "free_gen_fused_specdecode": + total_blocks = 0 + for row in rows: + fused = row.get("fused") or {} + blocks = fused.get("blocks") + if isinstance(blocks, (int, float)): + total_blocks += int(blocks) + if total_blocks <= 0: + violations.append(GateViolation( + "FUSED_NEVER_RAN", + "eval_mode=free_gen_fused_specdecode but the fused engine " + "executed 0 blocks across all samples (silent fallback); " + "the system under test never ran", + )) + + # --- Speedup claims: only decode-isolated, variance-controlled, + # adequately powered comparisons may carry a headline number --- + throughput = report.get("throughput") or {} + speedup = throughput.get("cross_model_speedup_vs_oracle_ar") + if speedup is not None: + oracle = _oracle_arm(report) + oracle_rows = _stage_rows(oracle) + if none_active: + violations.append(GateViolation( + "SPEEDUP_SELF_COMPARISON", + "speedup claimed on a native-baseline run: the cross arm IS " + "the oracle computation; the ratio is run-order noise", + )) + if len(rows) < MIN_PERF_SAMPLES or len(oracle_rows) < MIN_PERF_SAMPLES: + violations.append(GateViolation( + "SPEEDUP_SAMPLES", + f"speedup claimed with n cross={len(rows)} oracle=" + f"{len(oracle_rows)}; minimum is {MIN_PERF_SAMPLES} per arm", + )) + cross_tokens = results_cross.get("per_sample_decode_tokens") or [] + oracle_tokens = (results.get("oracle") or {}).get("per_sample_decode_tokens") or [] + medians = [ + statistics.median(t) for t in (cross_tokens, oracle_tokens) if t + ] + if len(medians) < 2 or min(medians) < MIN_MEDIAN_DECODE_TOKENS: + violations.append(GateViolation( + "SPEEDUP_DECODE_TOKENS", + f"speedup claimed with median decode tokens {medians or 'missing'}; " + f"minimum is {MIN_MEDIAN_DECODE_TOKENS} per arm (otherwise the " + "wall time is prefill noise, not decode throughput)", + )) + decode_only = throughput.get("decode_only") or {} + if decode_only.get("speedup") is None: + violations.append(GateViolation( + "SPEEDUP_DECODE_ONLY_MISSING", + "headline speedup present without throughput.decode_only " + "medians; prefill-inclusive ratios alone are inadmissible", + )) + if cross.get("timing_scope") != oracle.get("timing_scope"): + violations.append(GateViolation( + "SPEEDUP_SCOPE_MISMATCH", + f"cross timing_scope={cross.get('timing_scope')!r} != oracle " + f"timing_scope={oracle.get('timing_scope')!r}", + )) + if oracle.get("decode_loop") != CLAIM_ORACLE_DECODE_LOOP: + violations.append(GateViolation( + "SPEEDUP_ORACLE_LOOP", + f"oracle decode_loop={oracle.get('decode_loop')!r}; headline " + f"speedups require {CLAIM_ORACLE_DECODE_LOOP!r} (per-token " + "mx.eval loops are the documented MLX anti-pattern and " + "depress the baseline)", + )) + for arm_name, arm_rows in (("cross", rows), ("oracle", oracle_rows)): + spread = prefill_spread(arm_rows) + if spread is not None and spread > MAX_PREFILL_SPREAD: + violations.append(GateViolation( + "SPEEDUP_PREFILL_VARIANCE", + f"{arm_name} arm prefill spread {spread:.2f}× exceeds " + f"{MAX_PREFILL_SPREAD}×; e2e ratios under this variance " + "are noise — claim decode-only or control variance", + )) + + # --- Memory claims must describe the run that was measured --- + memory = report.get("memory") or {} + if memory.get("savings_vs_naive_pct") is not None: + s5 = memory.get("s5") or {} + if s5.get("formula_matches_run") is not True: + violations.append(GateViolation( + "MEMORY_CLAIM_MISMATCH", + "memory.savings_vs_naive_pct is claimed but memory.s5." + "formula_matches_run is not true: the analytical sink+window " + "table does not describe the cache the run actually used", + )) + + return violations diff --git a/inference_engine/bridge/__init__.py b/inference_engine/bridge/__init__.py new file mode 100644 index 00000000..b6e26d71 --- /dev/null +++ b/inference_engine/bridge/__init__.py @@ -0,0 +1,30 @@ +"""Mac bridge — cloud-agent access to the self-hosted Apple Silicon node. + +See ``docs/design/mac-bridge-cloud-agent-access.md``. This package holds +the platform-neutral, unit-testable core of the bridge: the preset +allowlist and the request-manifest schema (:mod:`manifest`). The +executor / client CLIs in ``scripts/mac_bridge/`` are thin wrappers +around it (CLI-plumbing coverage convention, like ``scripts/serve.py``). + +The package is the precursor of the ADR 0009 ``CAPABILITY_ROLE_TOOL`` +plane: a preset here is a tool capability a fleet node advertises; the +manifest is its typed request message (design doc §4.1). +""" + +from inference_engine.bridge.manifest import ( + BridgeRequest, + ManifestError, + Preset, + PRESETS, + build_commands, + parse_manifest, +) + +__all__ = [ + "BridgeRequest", + "ManifestError", + "Preset", + "PRESETS", + "build_commands", + "parse_manifest", +] diff --git a/inference_engine/bridge/manifest.py b/inference_engine/bridge/manifest.py new file mode 100644 index 00000000..d1224895 --- /dev/null +++ b/inference_engine/bridge/manifest.py @@ -0,0 +1,511 @@ +"""Mac-bridge preset allowlist + request-manifest schema. + +Security posture (design doc §3): the Mac executes ONLY presets defined +here, with typed, bounded parameters. No string from a manifest is ever +interpolated into a shell — :func:`build_commands` returns argv lists +that the executor passes to ``subprocess.run`` without ``shell=True``. +Machine-local facts (model paths) come from the runner's environment, +referenced here as ``${ENV:VAR}`` placeholders the executor resolves +from ``os.environ`` — never from the manifest. + +Pure stdlib so the Linux CI gate pins the allowlist semantics at 100% +coverage; the Mac executor imports exactly this module, so what CI +verifies is what the Mac enforces. +""" + +from __future__ import annotations + +import json +import re +from dataclasses import dataclass, field +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple + +MANIFEST_PATH = ".mac-bridge/request.json" +MANIFEST_SCHEMA_VERSION = 1 + +BRANCH_PREFIX = "mac-bridge/" + +# Parameter bounds (design doc §2.2). Deliberately conservative: the +# bridge is for evidence runs and debugging, not for monopolizing the +# single Mac with open-ended workloads. +MAX_N_SAMPLES = 50 +MAX_NEW_TOKENS = 512 +MAX_BLOCK_SIZE = 16 + +_ENV_PLACEHOLDER = re.compile(r"^\$\{ENV:([A-Z][A-Z0-9_]*)\}$") +_NONCE_RE = re.compile(r"^[a-z0-9][a-z0-9-]{3,63}$") + + +class ManifestError(ValueError): + """A bridge manifest failed validation; nothing was executed.""" + + +@dataclass(frozen=True) +class Preset: + """One allowlisted Mac workload. + + ``command_templates`` are argv lists. Tokens may be: + + - plain strings (passed through), + - ``${ENV:NAME}`` — resolved from the executor host's environment + (missing variable = hard error, no fallback), + - ``{param}`` — substituted with the validated parameter value. + """ + + name: str + description: str + command_templates: Tuple[Tuple[str, ...], ...] + timeout_minutes: int + # name -> (kind, default). kind ∈ {"int:n_samples", "int:max_new_tokens", + # "int:block_size", "path:tests"}; None default = required. + params: Mapping[str, Tuple[str, Optional[str]]] = field(default_factory=dict) + # Run the K3 evidence gate over results/research after the commands. + validate_reports: bool = False + + +def _harness_preset( + name: str, description: str, mode_flag: str, *extra_flags: str, +) -> Preset: + """The hardened-harness presets share everything but the mode flags.""" + return Preset( + name=name, + description=description, + command_templates=( + ( + "python3", "scripts/research/k3_integrated_niah_eval_mac.py", + "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", + "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", + "--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}", + "--s5-exact-full-attn", mode_flag, *extra_flags, + # Evidence runs decode the full budget: without this the + # Gemma-4 stop caps decode at ~8 tokens and the + # report fails the SPEEDUP_DECODE_TOKENS gate rule. + "--ignore-turn-stop", + "--n-samples", "{n_samples}", + "--max-new-tokens", "{max_new_tokens}", + "--block-size", "{block_size}", + "--prefill-chunk-size", "512", + "--output", + f"results/research/k3_mac_bridge_{name.replace('-', '_')}.json", + ), + ), + timeout_minutes=120, + params={ + "n_samples": ("int:n_samples", "5"), + "max_new_tokens": ("int:max_new_tokens", "64"), + "block_size": ("int:block_size", "4"), + }, + validate_reports=True, + ) + + +PRESETS: Dict[str, Preset] = { + p.name: p + for p in ( + Preset( + name="mlx-env-probe", + description="Probe Metal/MLX + mlx.distributed availability.", + command_templates=( + ( + "python3", "-c", + "from inference_engine.backends.mlx.env import " + "probe_environment; print(probe_environment().render())", + ), + ), + timeout_minutes=10, + ), + Preset( + name="mlx-backend-tests", + description="Real-mlx truth for the MLX backend test suites.", + command_templates=( + ("python3", "-m", "pytest", "tests/backends/mlx/", "-q"), + ), + timeout_minutes=45, + ), + Preset( + name="integration-tests", + description="v0.3 GA integration gate (real Qwen3-0.6B).", + command_templates=( + ("python3", "-m", "pytest", "-m", "integration", + "tests/integration/", "-q"), + ), + timeout_minutes=60, + ), + _harness_preset( + "k3-step1-incremental", + "PR #109 Step-1 evidence: incremental restored decode.", + "--incremental", + ), + _harness_preset( + "k3-step2-fused", + "PR #109 Step-2 evidence: fused engine must execute (blocks>0).", + "--fused-specdecode", + ), + _harness_preset( + "k3-native-baseline", + "Labelled native-AR baseline run (cannot claim recall/speedup).", + "--native-baseline-bypass", + ), + _harness_preset( + "k3-step2-fused-allmlx", + "Step-2 rescue evidence: fused engine with the ALL-MLX drafter " + "(zero per-block bridge crossings).", + "--fused-specdecode", + "--all-mlx-drafter", + ), + Preset( + name="k3-drafter-parity", + description="All-MLX (bf16, shipping dtype) vs torch DFlash " + "drafter token parity.", + command_templates=( + ( + "python3", "scripts/research/k3_mlx_drafter_parity.py", + "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", + "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", + "--n-samples", "{n_samples}", + "--block-size", "{block_size}", + "--output", + "results/research/k3_mlx_drafter_parity.json", + ), + ), + timeout_minutes=60, + params={ + "n_samples": ("int:n_samples", "3"), + "block_size": ("int:block_size", "8"), + }, + ), + Preset( + name="k3-drafter-parity-fp32", + description="Port-bug discriminator: all-MLX drafter at fp32 vs " + "the fp32 torch reference must match EXACTLY.", + command_templates=( + ( + "python3", "scripts/research/k3_mlx_drafter_parity.py", + "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", + "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", + "--mlx-dtype", "fp32", + "--n-samples", "{n_samples}", + "--block-size", "{block_size}", + "--output", + "results/research/k3_mlx_drafter_parity_fp32.json", + ), + ), + timeout_minutes=60, + params={ + "n_samples": ("int:n_samples", "3"), + "block_size": ("int:block_size", "8"), + }, + ), + Preset( + name="k3-kv-quant-eval", + description="Rate-distortion shoot-out on the full-attn K/V: " + "mlx-native affine 8/4-bit vs KakeyaLattice D4/E8, " + "with real recall per arm. Decides whether an MLX " + "port of the KL codec is justified.", + command_templates=( + ( + "python3", "scripts/research/k3_kv_quant_eval.py", + "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", + "--n-samples", "{n_samples}", + "--max-new-tokens", "{max_new_tokens}", + "--output", "results/research/k3_kv_quant_eval.json", + ), + ), + timeout_minutes=90, + params={ + "n_samples": ("int:n_samples", "5"), + "max_new_tokens": ("int:max_new_tokens", "32"), + }, + ), + Preset( + name="k3-evidence-gate", + description="Re-validate committed K3 Mac reports on-device.", + command_templates=( + ("python3", "scripts/validate_k3_reports.py", + "results/research"), + ), + timeout_minutes=10, + ), + Preset( + name="pytest-path", + description="One pytest target under tests/ (debugging).", + command_templates=( + ("python3", "-m", "pytest", "{path}", "-q"), + ), + timeout_minutes=45, + params={"path": ("path:tests", None)}, + ), + Preset( + name="k3-fused-singlefused-probe", + description="PROBE: single-fused (one drafter+26B graph) vs two-phase, " + "to classify the Metal instability. Small (n=2, gen=16) so a " + "pathological per-block eval is bounded. Compare block_eval_s " + "vs k3-fused-allmlx-code-trim (two-phase).", + command_templates=( + ( + "python3", "scripts/research/k3_integrated_niah_eval_mac.py", + "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", + "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", + "--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}", + "--s5-exact-full-attn", "--fused-specdecode", + "--all-mlx-drafter", "--code-prompts", "--cuda-trim", + "--single-fused", + "--n-samples", "{n_samples}", + "--max-new-tokens", "{max_new_tokens}", + "--block-size", "{block_size}", + "--prefill-chunk-size", "512", + "--output", + "results/research/k3_mac_bridge_k3_fused_singlefused_probe.json", + ), + ), + timeout_minutes=60, + params={ + "n_samples": ("int:n_samples", "2"), + "max_new_tokens": ("int:max_new_tokens", "16"), + "block_size": ("int:block_size", "4"), + }, + validate_reports=False, + ), + Preset( + name="k3-fused-allmlx-code-trim", + description="CUDA-parity rollback test: all-MLX fused + --cuda-trim " + "(all-KVCache + native trim, keep accepted / drop rejected, " + "no re-forward) on the code-completion workload. Compare " + "decode-only tok/s vs k3-fused-allmlx-code (v3 carry).", + command_templates=( + ( + "python3", "scripts/research/k3_integrated_niah_eval_mac.py", + "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", + "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", + "--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}", + "--s5-exact-full-attn", "--fused-specdecode", + "--all-mlx-drafter", "--code-prompts", "--cuda-trim", + "--n-samples", "{n_samples}", + "--max-new-tokens", "{max_new_tokens}", + "--block-size", "{block_size}", + "--prefill-chunk-size", "512", + "--output", + "results/research/k3_mac_bridge_k3_fused_allmlx_code_trim.json", + ), + ), + timeout_minutes=120, + params={ + "n_samples": ("int:n_samples", "8"), + "max_new_tokens": ("int:max_new_tokens", "128"), + "block_size": ("int:block_size", "4"), + }, + validate_reports=False, + ), + Preset( + name="k3-fused-allmlx-code", + description="HONEST spec-decode throughput probe: all-MLX fused on a " + "code-completion workload (naturally-long, predictable gen " + "= the spec-decode sweet spot), natural stop. Reports " + "decode-only tok/s (fused vs oracle AR) + acceptance.", + command_templates=( + ( + "python3", "scripts/research/k3_integrated_niah_eval_mac.py", + "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", + "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", + "--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}", + "--s5-exact-full-attn", "--fused-specdecode", + "--all-mlx-drafter", "--code-prompts", + # natural stop (no --ignore-turn-stop); code finishes itself + "--n-samples", "{n_samples}", + "--max-new-tokens", "{max_new_tokens}", + "--block-size", "{block_size}", + "--prefill-chunk-size", "512", + "--output", + "results/research/k3_mac_bridge_k3_fused_allmlx_code.json", + ), + ), + timeout_minutes=120, + params={ + "n_samples": ("int:n_samples", "8"), + "max_new_tokens": ("int:max_new_tokens", "128"), + "block_size": ("int:block_size", "4"), + }, + validate_reports=False, + ), + Preset( + name="k3-fused-allmlx-natural", + description="Acceptance probe: all-MLX fused, NATURAL stop (no " + "--ignore-turn-stop) so generation ends at the real " + "answer. Compare mean_accept_len vs the forced " + "k3-step2-fused-allmlx (which over-generates).", + command_templates=( + ( + "python3", "scripts/research/k3_integrated_niah_eval_mac.py", + "--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}", + "--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}", + "--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}", + "--s5-exact-full-attn", "--fused-specdecode", + "--all-mlx-drafter", + # deliberately NO --ignore-turn-stop (natural stop) + "--n-samples", "{n_samples}", + "--max-new-tokens", "{max_new_tokens}", + "--block-size", "{block_size}", + "--prefill-chunk-size", "512", + "--output", + "results/research/k3_mac_bridge_k3_fused_allmlx_natural.json", + ), + ), + timeout_minutes=120, + params={ + "n_samples": ("int:n_samples", "5"), + "max_new_tokens": ("int:max_new_tokens", "48"), + "block_size": ("int:block_size", "4"), + }, + validate_reports=False, + ), + ) +} + + +@dataclass(frozen=True) +class BridgeRequest: + """A validated bridge request (the parsed manifest).""" + + preset: Preset + params: Mapping[str, str] + ref: str + requested_by: str + nonce: str + + @property + def branch_name(self) -> str: + return f"{BRANCH_PREFIX}{self.preset.name}-{self.nonce}" + + def to_manifest_dict(self) -> Dict[str, Any]: + return { + "schema_version": MANIFEST_SCHEMA_VERSION, + "preset": self.preset.name, + "params": dict(self.params), + "ref": self.ref, + "requested_by": self.requested_by, + "nonce": self.nonce, + } + + +def _validate_param(name: str, kind: str, raw: str) -> str: + if kind.startswith("int:"): + try: + value = int(raw) + except (TypeError, ValueError): + raise ManifestError(f"param {name}={raw!r} is not an integer") + bound = { + "int:n_samples": MAX_N_SAMPLES, + "int:max_new_tokens": MAX_NEW_TOKENS, + "int:block_size": MAX_BLOCK_SIZE, + }[kind] + if not (1 <= value <= bound): + raise ManifestError( + f"param {name}={value} out of bounds [1, {bound}]") + return str(value) + # kind == "path:tests": repo-relative path under tests/, no traversal. + if not isinstance(raw, str) or not raw: + raise ManifestError(f"param {name} must be a non-empty string") + if raw.startswith(("/", "~")) or ".." in raw.split("/"): + raise ManifestError( + f"param {name}={raw!r} must be repo-relative without traversal") + if not (raw == "tests" or raw.startswith("tests/")): + raise ManifestError( + f"param {name}={raw!r} must resolve under tests/") + return raw + + +def parse_manifest(data: Any) -> BridgeRequest: + """Validate a decoded manifest dict into a :class:`BridgeRequest`. + + Raises :class:`ManifestError` on any deviation — unknown preset, + unknown/missing/out-of-bounds params, malformed nonce. Nothing about + a rejected manifest reaches a subprocess. + """ + if not isinstance(data, dict): + raise ManifestError("manifest must be a JSON object") + if data.get("schema_version") != MANIFEST_SCHEMA_VERSION: + raise ManifestError( + f"unsupported manifest schema_version={data.get('schema_version')!r}" + f" (expected {MANIFEST_SCHEMA_VERSION})") + preset_name = data.get("preset") + preset = PRESETS.get(preset_name) if isinstance(preset_name, str) else None + if preset is None: + raise ManifestError( + f"unknown preset {preset_name!r}; allowlist: {sorted(PRESETS)}") + + raw_params = data.get("params") or {} + if not isinstance(raw_params, dict): + raise ManifestError("params must be an object") + unknown = sorted(set(raw_params) - set(preset.params)) + if unknown: + raise ManifestError( + f"preset {preset.name!r} does not accept params: {unknown}") + params: Dict[str, str] = {} + for name, (kind, default) in preset.params.items(): + raw = raw_params.get(name, default) + if raw is None: + raise ManifestError( + f"preset {preset.name!r} requires param {name!r}") + params[name] = _validate_param(name, kind, str(raw)) + + nonce = data.get("nonce") + if not isinstance(nonce, str) or not _NONCE_RE.match(nonce): + raise ManifestError( + "nonce must match [a-z0-9][a-z0-9-]{3,63} (got " + f"{nonce!r})") + + ref = data.get("ref") + if not isinstance(ref, str) or not ref: + raise ManifestError("ref must be a non-empty string") + + requested_by = data.get("requested_by") + if not isinstance(requested_by, str) or not requested_by: + raise ManifestError("requested_by must be a non-empty string") + + return BridgeRequest( + preset=preset, + params=params, + ref=ref, + requested_by=requested_by, + nonce=nonce, + ) + + +def parse_manifest_text(text: str) -> BridgeRequest: + """Parse + validate a manifest from its JSON text.""" + try: + data = json.loads(text) + except json.JSONDecodeError as exc: + raise ManifestError(f"manifest is not valid JSON: {exc}") from exc + return parse_manifest(data) + + +def build_commands( + request: BridgeRequest, env: Mapping[str, str], +) -> List[List[str]]: + """Resolve a validated request into concrete argv lists. + + ``${ENV:NAME}`` tokens resolve from ``env``; a missing variable is a + hard :class:`ManifestError` (no fallback — the runner must be + configured per docs/ops/mac-m4-runner-setup.md). ``{param}`` tokens + substitute already-validated parameter values. Output is ready for + ``subprocess.run(argv)`` with no shell anywhere. + """ + commands: List[List[str]] = [] + for template in request.preset.command_templates: + argv: List[str] = [] + for token in template: + env_match = _ENV_PLACEHOLDER.match(token) + if env_match: + var = env_match.group(1) + if var not in env or not env[var]: + raise ManifestError( + f"preset {request.preset.name!r} needs runner env " + f"{var} (see docs/ops/mac-m4-runner-setup.md)") + argv.append(env[var]) + elif token.startswith("{") and token.endswith("}"): + argv.append(request.params[token[1:-1]]) + else: + argv.append(token) + commands.append(argv) + return commands diff --git a/results/research/k3_mac_bridge_k3_step1_incremental.json b/results/research/k3_mac_bridge_k3_step1_incremental.json new file mode 100644 index 00000000..dc6a6b00 --- /dev/null +++ b/results/research/k3_mac_bridge_k3_step1_incremental.json @@ -0,0 +1,576 @@ +{ + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance_mac", + "config": { + "native_baseline_bypass": false, + "block_size": 4, + "verifier_path": "/Users/fluffy314/kakeya-models/gemma-4-26B-A4B-it-mlx-4bit", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "n_samples": 5, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 238, + "haystack_max_lines": 322, + "max_new_tokens": 64, + "prefill_chunk_size": 512, + "decode_warmup_tokens": 1, + "direct_answer_prompt": true, + "content_channel_prefill": true, + "seed": 42, + "eval_mode": "free_gen_incremental", + "teacher_forced": false, + "s5_exact_full_attn": true, + "identity_restore": false, + "compress_full_attn": false, + "kl_lattice": null, + "kl_q_range": null, + "kl_bits_per_token_per_head": null, + "full_attention_layers": [ + 5, + 11, + 17, + 23, + 29 + ], + "prompt_token_lens": [ + 5810, + 4911, + 5594, + 4406, + 5505 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model_mac", + "samples_total": 5, + "samples_correct": 5, + "recall": 1.0, + "mean_latency_s": 25.70281104994938, + "median_latency_s": 22.59083933290094, + "per_sample_decoded": [ + "BETA-1409thought\nThe user wants to find the secret code from the provided text.\n\n1. **Analyze the input:** The input consists of many \"Note\" entries, most of which are filler.\n2. **Scan for the secret code:** I need to", + "DELTA-3286\nthought\nThe user provided a long list of \"Note\" entries, most of which are filler.\nAmong the filler, there is a specific line: \"IMPORTANT: the secret code is DELTA-3286.\"\nThe question asks: \"what is", + "BETA-7912\nthought\nThe user wants to find the secret code from the provided text.\n\n1. **Analyze the input:** The input consists of many \"Note\" entries, most of which are filler.\n2. **Locate the secret code:** Scanning the", + "BETA-4582thought\nThe user wants to find the secret code from the provided text.\n\n1. **Scan the text for the secret code.**\n * The text contains many \"Note XXXX\" entries which are filler.\n * In the middle of", + "KAPPA-1434\nthought\nThe user provided a long list of \"unrelated padding\" notes, but embedded a specific instruction: \"IMPORTANT: the secret code is KAPPA-1434.\"\n\nThe question at the end is: \"Question: what is the secret code" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 64, + 64, + 64, + 64, + 64 + ], + "per_sample_throughput_tokens_per_sec": [ + 2.2718168484307864, + 2.833006735911375, + 2.845573314993507, + 3.4319791309800225, + 1.7480261382345745 + ], + "mean_throughput_tokens_per_sec": 2.626080433710053, + "median_throughput_tokens_per_sec": 2.833006735911375, + "min_throughput_tokens_per_sec": 1.7480261382345745, + "max_throughput_tokens_per_sec": 3.4319791309800225, + "system_under_test": "restored_cross_model" + }, + "oracle": { + "name": "oracle_mac", + "samples_total": 5, + "samples_correct": 5, + "recall": 1.0, + "mean_latency_s": 30.313721324736253, + "median_latency_s": 30.386865583015606, + "per_sample_decoded": [ + "BETA-1409thought\nThe user wants to find the secret code from the provided text.\n\n1. **Analyze the input:** The input consists of many \"Note\" entries, most of which are filler.\n2. **Scan for the secret code:** I need to", + "DELTA-3286\nthought\nThe user provided a long list of \"Note\" entries, most of which are filler.\nAmong the filler, there is a specific line: \"IMPORTANT: the secret code is DELTA-3286.\"\nThe question asks: \"what is", + "BETA-7912\nthought\nThe user wants to find the secret code from the provided text.\n\n1. **Analyze the input:** The input consists of many \"Note\" entries, most of which are filler.\n2. **Locate the secret code:** Scanning the", + "BETA-4582thought\nThe user wants to find the secret code from the provided text.\n\n1. **Scan the text for the secret code.**\n * The text contains many \"Note XXXX\" entries which are filler.\n * In the middle of", + "KAPPA-1434\nthought\nThe user provided a long list of \"unrelated padding\" notes, but embedded a specific instruction: \"IMPORTANT: the secret code is KAPPA-1434.\"\n\nThe question at the end is: \"Question: what is the secret code" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 64, + 64, + 64, + 64, + 64 + ], + "per_sample_throughput_tokens_per_sec": [ + 2.182734549004166, + 2.200484189044156, + 2.1061731367177297, + 2.054674266928982, + 2.0235410545783914 + ], + "mean_throughput_tokens_per_sec": 2.113521439254685, + "median_throughput_tokens_per_sec": 2.1061731367177297, + "min_throughput_tokens_per_sec": 2.0235410545783914, + "max_throughput_tokens_per_sec": 2.200484189044156 + } + }, + "gate": { + "recall_cross_model": 1.0, + "recall_native_baseline": null, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 0.0, + "recall_delta_within_5pp": true, + "evidence_violations": [] + }, + "memory": { + "s5": { + "seq_len": 5810, + "kv_dtype_bytes": 2, + "sink_window": 68, + "exact_layer_indices": [ + 5, + 11, + 17, + 23, + 29 + ], + "compress_full_bits_per_token_per_head": null, + "sliding_resident_bytes": 13926400, + "full_resident_bytes": 118988800, + "total_resident_bytes": 132915200, + "total_resident_mb": 132.92, + "per_token_growth_bytes": 20480, + "per_token_growth_kb": 20.0, + "per_layer": [ + { + "layer": 0, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 1, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 2, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 3, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 4, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 5, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + }, + { + "layer": 6, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 7, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 8, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 9, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 10, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 11, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + }, + { + "layer": 12, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 13, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 14, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 15, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 16, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 17, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + }, + { + "layer": 18, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 19, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 20, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 21, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 22, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 23, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + }, + { + "layer": 24, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 25, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 26, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 27, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 28, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 29, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + } + ], + "scope": "analytical_formula", + "formula_matches_run": true + }, + "naive_full_kv": { + "total_resident_mb": 1308.88, + "per_token_growth_kb": 220.0 + }, + "savings_vs_naive_pct": 89.8, + "measured_peak_mb": 16919.3 + }, + "throughput": { + "k3_cross_model": { + "tokens": 320, + "wall_seconds": 128.514, + "tokens_per_second": 2.49, + "mean_latency_per_sample_s": 25.703, + "timing_scope": "e2e_prefill_plus_decode", + "eval_mode": "free_gen_incremental", + "restored_forwards_per_sample": 64, + "stage_timings": [ + { + "sample": 0, + "build_restoration_s": 0.001, + "prefill_s": 25.356, + "decode_s": 2.81, + "e2e_s": 28.171, + "restoration_active": true, + "decode_loop": "generate_step" + }, + { + "sample": 1, + "build_restoration_s": 0.0, + "prefill_s": 19.705, + "decode_s": 2.883, + "e2e_s": 22.591, + "restoration_active": true, + "decode_loop": "generate_step" + }, + { + "sample": 2, + "build_restoration_s": 0.0, + "prefill_s": 19.611, + "decode_s": 2.88, + "e2e_s": 22.491, + "restoration_active": true, + "decode_loop": "generate_step" + }, + { + "sample": 3, + "build_restoration_s": 0.0, + "prefill_s": 15.846, + "decode_s": 2.802, + "e2e_s": 18.648, + "restoration_active": true, + "decode_loop": "generate_step" + }, + { + "sample": 4, + "build_restoration_s": 0.0, + "prefill_s": 31.666, + "decode_s": 4.947, + "e2e_s": 36.613, + "restoration_active": true, + "decode_loop": "generate_step" + } + ] + }, + "oracle_native_ar": { + "tokens": 320, + "wall_seconds": 151.569, + "tokens_per_second": 2.1113, + "mean_latency_per_sample_s": 30.314, + "timing_scope": "e2e_prefill_plus_decode", + "stage_timings": [ + { + "sample": 0, + "prefill_s": 26.508, + "decode_s": 2.813, + "e2e_s": 29.321, + "decode_loop": "generate_step" + }, + { + "sample": 1, + "prefill_s": 20.998, + "decode_s": 8.086, + "e2e_s": 29.085, + "decode_loop": "generate_step" + }, + { + "sample": 2, + "prefill_s": 24.1, + "decode_s": 6.287, + "e2e_s": 30.387, + "decode_loop": "generate_step" + }, + { + "sample": 3, + "prefill_s": 20.471, + "decode_s": 10.677, + "e2e_s": 31.148, + "decode_loop": "generate_step" + }, + { + "sample": 4, + "prefill_s": 28.227, + "decode_s": 3.401, + "e2e_s": 31.628, + "decode_loop": "generate_step" + } + ], + "decode_loop": "generate_step" + }, + "decode_only": { + "cross_median_tok_s": 22.2222, + "oracle_median_tok_s": 10.1797, + "speedup": 2.183 + }, + "cross_model_speedup_vs_oracle_ar": 1.179, + "speedup_withheld_reasons": null + } +} \ No newline at end of file diff --git a/results/research/k3_mlx_fused_20260612_102548.json b/results/research/k3_mlx_fused_20260612_102548.json new file mode 100644 index 00000000..20b7e497 --- /dev/null +++ b/results/research/k3_mlx_fused_20260612_102548.json @@ -0,0 +1,463 @@ +{ + "schema_version": 1, + "kind": "k3_integrated_niah_acceptance_mac", + "config": { + "verifier_path": "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/models/gemma-4-26B-A4B-it-mlx-4bit", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "n_samples": 1, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "max_new_tokens": 8, + "seed": 42, + "eval_mode": "free_gen_fused_specdecode", + "teacher_forced": false, + "s5_exact_full_attn": true, + "identity_restore": false, + "compress_full_attn": false, + "kl_lattice": null, + "kl_q_range": null, + "kl_bits_per_token_per_head": null, + "full_attention_layers": [ + 5, + 11, + 17, + 23, + 29 + ], + "prompt_token_lens": [ + 1508 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model_mac", + "samples_total": 1, + "samples_correct": 1, + "recall": 1.0, + "mean_latency_s": 1.5699947089888155, + "median_latency_s": 1.5699947089888155, + "per_sample_decoded": [ + "BETA-1409" + ], + "per_sample_correct": [ + true + ], + "per_sample_decode_tokens": [ + 8 + ], + "per_sample_throughput_tokens_per_sec": [ + 5.0955585736671365 + ], + "mean_throughput_tokens_per_sec": 5.0955585736671365, + "median_throughput_tokens_per_sec": 5.0955585736671365, + "min_throughput_tokens_per_sec": 5.0955585736671365, + "max_throughput_tokens_per_sec": 5.0955585736671365 + }, + "oracle": { + "name": "oracle_mac", + "samples_total": 1, + "samples_correct": 1, + "recall": 1.0, + "mean_latency_s": 5.438449499895796, + "median_latency_s": 5.438449499895796, + "per_sample_decoded": [ + "BETA-1409" + ], + "per_sample_correct": [ + true + ], + "per_sample_decode_tokens": [ + 8 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.4710074995002316 + ], + "mean_throughput_tokens_per_sec": 1.4710074995002316, + "median_throughput_tokens_per_sec": 1.4710074995002316, + "min_throughput_tokens_per_sec": 1.4710074995002316, + "max_throughput_tokens_per_sec": 1.4710074995002316 + } + }, + "gate": { + "recall_cross_model": 1.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 0.0, + "recall_delta_within_5pp": true + }, + "memory": { + "s5": { + "seq_len": 1508, + "kv_dtype_bytes": 2, + "sink_window": 68, + "exact_layer_indices": [ + 5, + 11, + 17, + 23, + 29 + ], + "compress_full_bits_per_token_per_head": null, + "sliding_resident_bytes": 13926400, + "full_resident_bytes": 30883840, + "total_resident_bytes": 44810240, + "total_resident_mb": 44.81, + "per_token_growth_bytes": 20480, + "per_token_growth_kb": 20.0, + "per_layer": [ + { + "layer": 0, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 1, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 2, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 3, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 4, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 5, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + }, + { + "layer": 6, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 7, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 8, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 9, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 10, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 11, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + }, + { + "layer": 12, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 13, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 14, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 15, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 16, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 17, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + }, + { + "layer": 18, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 19, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 20, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 21, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 22, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 23, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + }, + { + "layer": 24, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 25, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 26, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 27, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 28, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 29, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + } + ] + }, + "naive_full_kv": { + "total_resident_mb": 339.72, + "per_token_growth_kb": 220.0 + }, + "savings_vs_naive_pct": 86.8 + }, + "throughput": { + "k3_cross_model": { + "tokens": 8, + "wall_seconds": 1.57, + "tokens_per_second": 5.0956, + "mean_latency_per_sample_s": 1.57, + "eval_mode": "free_gen_fused_specdecode", + "restored_forwards_per_sample": 8, + "stage_timings": [ + { + "sample": 0, + "build_restoration_s": 0.0, + "aux_prompt_capture_s": 0.0, + "restored_prefill_s": 8.616, + "decode_s": 1.57, + "fused": { + "tokens": [ + 236799, + 26742, + 236772, + 236770, + 236812, + 236771, + 236819, + 106 + ], + "blocks": 0, + "mean_accept_len": 0.0, + "decode_tokens": 8, + "adaptive_mode": "restored_greedy", + "time_breakdown_s": { + "greedy_decode_s": 1.569 + } + } + } + ] + }, + "oracle_native_ar": { + "tokens": 8, + "wall_seconds": 5.438, + "tokens_per_second": 1.471, + "mean_latency_per_sample_s": 5.438 + }, + "cross_model_speedup_vs_oracle_ar": 3.464 + } +} \ No newline at end of file diff --git a/results/research/k3_mlx_fused_fair_ctx280_n5_gen32_20260612_105807.json b/results/research/k3_mlx_fused_fair_ctx280_n5_gen32_20260612_105807.json new file mode 100644 index 00000000..6ce475b7 --- /dev/null +++ b/results/research/k3_mlx_fused_fair_ctx280_n5_gen32_20260612_105807.json @@ -0,0 +1,645 @@ +{ + "schema_version": 1, + "kind": "k3_integrated_niah_acceptance_mac", + "config": { + "verifier_path": "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/models/gemma-4-26B-A4B-it-mlx-4bit", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "n_samples": 5, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 238, + "haystack_max_lines": 322, + "max_new_tokens": 32, + "prefill_chunk_size": 512, + "decode_warmup_tokens": 1, + "direct_answer_prompt": true, + "content_channel_prefill": true, + "seed": 42, + "eval_mode": "free_gen_fused_specdecode", + "teacher_forced": false, + "s5_exact_full_attn": true, + "identity_restore": false, + "compress_full_attn": false, + "kl_lattice": null, + "kl_q_range": null, + "kl_bits_per_token_per_head": null, + "full_attention_layers": [ + 5, + 11, + 17, + 23, + 29 + ], + "prompt_token_lens": [ + 5810, + 4911, + 5594, + 4406, + 5505 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model_mac", + "samples_total": 5, + "samples_correct": 5, + "recall": 1.0, + "mean_latency_s": 35.17851371653378, + "median_latency_s": 38.118411207804456, + "per_sample_decoded": [ + "BETA-1409", + "DELTA-3286", + "BETA-7912", + "BETA-4582", + "KAPPA-1434" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 8, + 7, + 8, + 8, + 8 + ], + "per_sample_throughput_tokens_per_sec": [ + 0.2403255085883176, + 0.1770074704870799, + 0.20987233587432574, + 0.3642637285686494, + 0.18614387334964225 + ], + "mean_throughput_tokens_per_sec": 0.235522583373603, + "median_throughput_tokens_per_sec": 0.20987233587432574, + "min_throughput_tokens_per_sec": 0.1770074704870799, + "max_throughput_tokens_per_sec": 0.3642637285686494 + }, + "oracle": { + "name": "oracle_mac", + "samples_total": 5, + "samples_correct": 5, + "recall": 1.0, + "mean_latency_s": 90.8967234004289, + "median_latency_s": 92.05093504209071, + "per_sample_decoded": [ + "BETA-1409", + "DELTA-3286", + "BETA-7912", + "BETA-4582", + "KAPPA-1434" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 8, + 7, + 8, + 8, + 8 + ], + "per_sample_throughput_tokens_per_sec": [ + 0.08690840561632496, + 0.1729043949511074, + 0.0733616029413212, + 0.046696758472685555, + 0.19239598903829255 + ], + "mean_throughput_tokens_per_sec": 0.11445343020394634, + "median_throughput_tokens_per_sec": 0.08690840561632496, + "min_throughput_tokens_per_sec": 0.046696758472685555, + "max_throughput_tokens_per_sec": 0.19239598903829255 + } + }, + "gate": { + "recall_cross_model": 1.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 0.0, + "recall_delta_within_5pp": true + }, + "memory": { + "s5": { + "seq_len": 5810, + "kv_dtype_bytes": 2, + "sink_window": 68, + "exact_layer_indices": [ + 5, + 11, + 17, + 23, + 29 + ], + "compress_full_bits_per_token_per_head": null, + "sliding_resident_bytes": 13926400, + "full_resident_bytes": 118988800, + "total_resident_bytes": 132915200, + "total_resident_mb": 132.92, + "per_token_growth_bytes": 20480, + "per_token_growth_kb": 20.0, + "per_layer": [ + { + "layer": 0, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 1, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 2, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 3, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 4, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 5, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + }, + { + "layer": 6, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 7, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 8, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 9, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 10, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 11, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + }, + { + "layer": 12, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 13, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 14, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 15, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 16, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 17, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + }, + { + "layer": 18, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 19, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 20, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 21, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 22, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 23, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + }, + { + "layer": 24, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 25, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 26, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 27, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 28, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 29, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + } + ] + }, + "naive_full_kv": { + "total_resident_mb": 1308.88, + "per_token_growth_kb": 220.0 + }, + "savings_vs_naive_pct": 89.8 + }, + "throughput": { + "k3_cross_model": { + "tokens": 39, + "wall_seconds": 175.893, + "tokens_per_second": 0.2217, + "mean_latency_per_sample_s": 35.179, + "timing_scope": "e2e_prefill_plus_decode", + "eval_mode": "free_gen_fused_specdecode", + "restored_forwards_per_sample": 32, + "stage_timings": [ + { + "sample": 0, + "build_restoration_s": 0.0, + "aux_prompt_capture_s": 0.0, + "restored_prefill_s": 32.857, + "decode_s": 0.43, + "e2e_s": 33.288, + "fused": { + "tokens": [ + 236799, + 26742, + 236772, + 236770, + 236812, + 236771, + 236819, + 106 + ], + "blocks": 0, + "mean_accept_len": 0.0, + "decode_tokens": 8, + "adaptive_mode": "restored_greedy", + "time_breakdown_s": { + "greedy_decode_s": 0.43 + } + } + }, + { + "sample": 1, + "build_restoration_s": 0.0, + "aux_prompt_capture_s": 0.0, + "restored_prefill_s": 37.284, + "decode_s": 2.262, + "e2e_s": 39.546, + "fused": { + "tokens": [ + 216449, + 236772, + 236800, + 236778, + 236828, + 236825, + 106 + ], + "blocks": 0, + "mean_accept_len": 0.0, + "decode_tokens": 7, + "adaptive_mode": "restored_greedy", + "time_breakdown_s": { + "greedy_decode_s": 2.261 + } + } + }, + { + "sample": 2, + "build_restoration_s": 0.0, + "aux_prompt_capture_s": 0.0, + "restored_prefill_s": 32.612, + "decode_s": 5.505, + "e2e_s": 38.118, + "fused": { + "tokens": [ + 236799, + 26742, + 236772, + 236832, + 236819, + 236770, + 236778, + 106 + ], + "blocks": 0, + "mean_accept_len": 0.0, + "decode_tokens": 8, + "adaptive_mode": "restored_greedy", + "time_breakdown_s": { + "greedy_decode_s": 5.505 + } + } + }, + { + "sample": 3, + "build_restoration_s": 0.0, + "aux_prompt_capture_s": 0.0, + "restored_prefill_s": 21.529, + "decode_s": 0.432, + "e2e_s": 21.962, + "fused": { + "tokens": [ + 236799, + 26742, + 236772, + 236812, + 236810, + 236828, + 236778, + 106 + ], + "blocks": 0, + "mean_accept_len": 0.0, + "decode_tokens": 8, + "adaptive_mode": "restored_greedy", + "time_breakdown_s": { + "greedy_decode_s": 0.432 + } + } + }, + { + "sample": 4, + "build_restoration_s": 0.0, + "aux_prompt_capture_s": 0.0, + "restored_prefill_s": 41.982, + "decode_s": 0.996, + "e2e_s": 42.978, + "fused": { + "tokens": [ + 236855, + 153974, + 236772, + 236770, + 236812, + 236800, + 236812, + 106 + ], + "blocks": 0, + "mean_accept_len": 0.0, + "decode_tokens": 8, + "adaptive_mode": "restored_greedy", + "time_breakdown_s": { + "greedy_decode_s": 0.995 + } + } + } + ] + }, + "oracle_native_ar": { + "tokens": 39, + "wall_seconds": 454.484, + "tokens_per_second": 0.0858, + "mean_latency_per_sample_s": 90.897, + "timing_scope": "e2e_prefill_plus_decode", + "stage_timings": [ + { + "sample": 0, + "prefill_s": 80.599, + "decode_s": 11.452, + "e2e_s": 92.051 + }, + { + "sample": 1, + "prefill_s": 35.268, + "decode_s": 5.217, + "e2e_s": 40.485 + }, + { + "sample": 2, + "prefill_s": 89.521, + "decode_s": 19.528, + "e2e_s": 109.049 + }, + { + "sample": 3, + "prefill_s": 146.333, + "decode_s": 24.985, + "e2e_s": 171.318 + }, + { + "sample": 4, + "prefill_s": 39.17, + "decode_s": 2.411, + "e2e_s": 41.581 + } + ] + }, + "cross_model_speedup_vs_oracle_ar": 2.584 + } +} \ No newline at end of file diff --git a/results/research/k3_mlx_gate_sync_iterC_block4_ignoreturn_n5_gen64_20260612_161535.json b/results/research/k3_mlx_gate_sync_iterC_block4_ignoreturn_n5_gen64_20260612_161535.json new file mode 100644 index 00000000..bccc8f7c --- /dev/null +++ b/results/research/k3_mlx_gate_sync_iterC_block4_ignoreturn_n5_gen64_20260612_161535.json @@ -0,0 +1,935 @@ +{ + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance_mac", + "config": { + "native_baseline_bypass": false, + "block_size": 4, + "verifier_path": "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/models/gemma-4-26B-A4B-it-mlx-4bit", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "n_samples": 5, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 238, + "haystack_max_lines": 322, + "max_new_tokens": 64, + "prefill_chunk_size": 512, + "decode_warmup_tokens": 0, + "direct_answer_prompt": true, + "content_channel_prefill": true, + "seed": 42, + "eval_mode": "free_gen_fused_specdecode", + "teacher_forced": false, + "s5_exact_full_attn": true, + "identity_restore": false, + "compress_full_attn": false, + "kl_lattice": null, + "kl_q_range": null, + "kl_bits_per_token_per_head": null, + "full_attention_layers": [ + 5, + 11, + 17, + 23, + 29 + ], + "prompt_token_lens": [ + 5810, + 4911, + 5594, + 4406, + 5505 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model_mac", + "samples_total": 5, + "samples_correct": 5, + "recall": 1.0, + "mean_latency_s": 158.41047907499598, + "median_latency_s": 167.9605791249778, + "per_sample_decoded": [ + "BETA-1409\nBETA-1409 <|thought\nThe user wants to find the \"DELTA-3286\" from the provided text.\nThe text contains many \"Note\" entries which are filler\" and \"padding\".\nThe secret code is DELTA-3286.\nThe question:", + "BETA-7912\nthought\nThe secret code is BETA-7912\".\nThe user wants to find the secret code from the provided text.\nThe secret code is \"BETA-7912\"\nThe secret code is BETA-7912.\nThe", + "BETA-4582\n\nThe user provided a large amount of padding and is unrelated filler. and is is is and is and is and is and is and is and is and is and is and is and is and is and is and is and is and is and is and is and is and", + "KAPPA-1434thought\nKAPPA-1434\n" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 64, + 64, + 64, + 64, + 23 + ], + "per_sample_throughput_tokens_per_sec": [ + 0.38104179167171265, + 0.4473275426817434, + 0.33545843602621145, + 0.34528078925628286, + 0.2192985881822568 + ], + "mean_throughput_tokens_per_sec": 0.34568142956364145, + "median_throughput_tokens_per_sec": 0.34528078925628286, + "min_throughput_tokens_per_sec": 0.2192985881822568, + "max_throughput_tokens_per_sec": 0.4473275426817434, + "system_under_test": "restored_cross_model" + }, + "oracle": { + "name": "oracle_mac", + "samples_total": 5, + "samples_correct": 5, + "recall": 1.0, + "mean_latency_s": 26.175907683605327, + "median_latency_s": 25.20229854201898, + "per_sample_decoded": [ + "BETA-1409thought\nThe user wants to find the secret code from the provided text.\n\n1. **Analyze the input:** The input contains a large number of \"Note\" entries, most of which are filler text.\n2. **Scan for the secret code:**", + "DELTA-3286thought\nThe user provided a long list of \"Note\" entries, most of which are filler.\nAmong the filler, there is a specific line: \"IMPORTANT: the secret code is DELTA-3286.\"\nThe question asks: \"what is the", + "BETA-7912\nthought\nThe user wants to find the secret code from the provided text.\n\n1. **Analyze the input:** The input consists of many \"Note\" entries, most of which are filler.\n2. **Locate the secret code:** Scanning the", + "BETA-4582\nthought\nThe user wants to find the secret code from the provided text.\n\n1. **Analyze the input:** The input contains a long list of \"Note XXXX\" entries, most of which are filler.\n2. **Locate the secret", + "KAPPA-1434\nthought\nThe user provided a long list of \"unrelated padding\" notes, but embedded a specific instruction: \"IMPORTANT: the secret code is KAPPA-1434.\"\n\nThe question at the end asks: \"what is the secret code?\"\n\n" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 64, + 64, + 64, + 64, + 64 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.7282861261924518, + 2.539450911324412, + 2.5035470910536177, + 3.1027502217832676, + 2.8500503720440187 + ], + "mean_throughput_tokens_per_sec": 2.5448169444795536, + "median_throughput_tokens_per_sec": 2.539450911324412, + "min_throughput_tokens_per_sec": 1.7282861261924518, + "max_throughput_tokens_per_sec": 3.1027502217832676 + } + }, + "gate": { + "recall_cross_model": 1.0, + "recall_native_baseline": null, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 0.0, + "recall_delta_within_5pp": true, + "evidence_violations": [] + }, + "memory": { + "s5": { + "seq_len": 5810, + "kv_dtype_bytes": 2, + "sink_window": 68, + "exact_layer_indices": [ + 5, + 11, + 17, + 23, + 29 + ], + "compress_full_bits_per_token_per_head": null, + "sliding_resident_bytes": 13926400, + "full_resident_bytes": 118988800, + "total_resident_bytes": 132915200, + "total_resident_mb": 132.92, + "per_token_growth_bytes": 20480, + "per_token_growth_kb": 20.0, + "per_layer": [ + { + "layer": 0, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 1, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 2, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 3, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 4, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 5, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + }, + { + "layer": 6, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 7, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 8, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 9, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 10, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 11, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + }, + { + "layer": 12, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 13, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 14, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 15, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 16, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 17, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + }, + { + "layer": 18, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 19, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 20, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 21, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 22, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 23, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + }, + { + "layer": 24, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 25, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 26, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 27, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 28, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 29, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 5810, + "bytes_per_token": 4096, + "resident_bytes": 23797760 + } + ], + "scope": "analytical_formula", + "formula_matches_run": true + }, + "naive_full_kv": { + "total_resident_mb": 1308.88, + "per_token_growth_kb": 220.0 + }, + "savings_vs_naive_pct": 89.8, + "measured_peak_mb": 19075.7 + }, + "throughput": { + "k3_cross_model": { + "tokens": 279, + "wall_seconds": 792.052, + "tokens_per_second": 0.3522, + "mean_latency_per_sample_s": 158.41, + "timing_scope": "e2e_prefill_plus_decode", + "eval_mode": "free_gen_fused_specdecode", + "restored_forwards_per_sample": 64, + "stage_timings": [ + { + "sample": 0, + "build_restoration_s": 0.0, + "aux_prompt_capture_s": 34.042, + "prefill_s": 71.455, + "decode_s": 62.463, + "e2e_s": 167.961, + "restoration_active": true, + "decode_loop": "fused_specdecode", + "fused": { + "tokens": [ + 236799, + 26742, + 236772, + 236770, + 236812, + 236771, + 236819, + 106, + 106, + 107, + 101, + 236799, + 26742, + 236772, + 236770, + 236812, + 236771, + 236819, + 106, + 655, + 45518, + 107, + 818, + 2430, + 8150, + 531, + 1586, + 506, + 623, + 236799, + 26742, + 236772, + 236770, + 236812, + 236771, + 236819, + 236775, + 107, + 818, + 2430, + 8150, + 531, + 1586, + 506, + 6789, + 3393, + 236761, + 107, + 818, + 107, + 818, + 107, + 818, + 563, + 227697, + 236772, + 236770, + 236812, + 236771, + 236819, + 236761, + 107, + 818, + 107 + ], + "blocks": 23, + "mean_accept_len": 2.13, + "decode_tokens": 64, + "time_breakdown_s": { + "ctx_kv_build_s": 5.899, + "draft_s": 32.876, + "verify_s": 20.488, + "append_s": 0.936, + "extend_s": 2.006, + "fallback_greedy_s": 0.0 + } + } + }, + { + "sample": 1, + "build_restoration_s": 0.001, + "aux_prompt_capture_s": 16.984, + "prefill_s": 24.634, + "decode_s": 101.45, + "e2e_s": 143.072, + "restoration_active": true, + "decode_loop": "fused_specdecode", + "fused": { + "tokens": [ + 216449, + 236772, + 236800, + 236778, + 236828, + 236825, + 106, + 236820, + 236909, + 45518, + 107, + 818, + 2430, + 8150, + 531, + 1586, + 506, + 623, + 216449, + 236772, + 236800, + 236778, + 236828, + 236825, + 236775, + 699, + 506, + 3847, + 1816, + 236761, + 107, + 818, + 1816, + 6097, + 1551, + 623, + 10282, + 236775, + 16227, + 837, + 659, + 48600, + 236775, + 532, + 623, + 13296, + 3056, + 107, + 818, + 6789, + 3393, + 563, + 27512, + 8121, + 236772, + 236800, + 236778, + 236828, + 236825, + 236761, + 107, + 818, + 2934, + 236787 + ], + "blocks": 20, + "mean_accept_len": 2.6, + "decode_tokens": 64, + "time_breakdown_s": { + "ctx_kv_build_s": 3.554, + "draft_s": 38.551, + "verify_s": 49.494, + "append_s": 6.595, + "extend_s": 3.041, + "fallback_greedy_s": 0.0 + } + } + }, + { + "sample": 2, + "build_restoration_s": 0.0, + "aux_prompt_capture_s": 23.474, + "prefill_s": 66.539, + "decode_s": 100.768, + "e2e_s": 190.784, + "restoration_active": true, + "decode_loop": "fused_specdecode", + "fused": { + "tokens": [ + 236799, + 26742, + 236772, + 236832, + 236819, + 236770, + 236778, + 106, + 107, + 45518, + 107, + 818, + 6789, + 3393, + 563, + 227697, + 236772, + 236832, + 236819, + 236770, + 236778, + 3056, + 107, + 818, + 2430, + 8150, + 531, + 1586, + 506, + 6789, + 3393, + 699, + 506, + 3847, + 1816, + 236761, + 107, + 818, + 6789, + 3393, + 563, + 623, + 236799, + 26742, + 236772, + 236832, + 236819, + 236770, + 236778, + 236775, + 107, + 818, + 6789, + 3393, + 563, + 227697, + 236772, + 236832, + 236819, + 236770, + 236778, + 236761, + 107, + 818 + ], + "blocks": 21, + "mean_accept_len": 2.286, + "decode_tokens": 64, + "time_breakdown_s": { + "ctx_kv_build_s": 5.052, + "draft_s": 41.62, + "verify_s": 45.364, + "append_s": 4.258, + "extend_s": 4.203, + "fallback_greedy_s": 0.0 + } + } + }, + { + "sample": 3, + "build_restoration_s": 0.0, + "aux_prompt_capture_s": 22.147, + "prefill_s": 73.165, + "decode_s": 90.042, + "e2e_s": 185.356, + "restoration_active": true, + "decode_loop": "fused_specdecode", + "fused": { + "tokens": [ + 236799, + 26742, + 236772, + 236812, + 236810, + 236828, + 236778, + 106, + 108, + 101, + 818, + 2430, + 3847, + 496, + 2455, + 2886, + 529, + 7152, + 532, + 563, + 56124, + 48600, + 236761, + 532, + 563, + 563, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532, + 563, + 532 + ], + "blocks": 20, + "mean_accept_len": 2.55, + "decode_tokens": 64, + "time_breakdown_s": { + "ctx_kv_build_s": 4.105, + "draft_s": 44.118, + "verify_s": 37.125, + "append_s": 0.996, + "extend_s": 3.438, + "fallback_greedy_s": 0.0 + } + } + }, + { + "sample": 4, + "build_restoration_s": 0.001, + "aux_prompt_capture_s": 18.018, + "prefill_s": 28.31, + "decode_s": 58.545, + "e2e_s": 104.88, + "restoration_active": true, + "decode_loop": "fused_specdecode", + "fused": { + "tokens": [ + 236855, + 153974, + 236772, + 236770, + 236812, + 236800, + 236812, + 106, + 106, + 45518, + 107, + 101, + 236855, + 153974, + 236772, + 236770, + 236812, + 236800, + 236812, + 106, + 106, + 107, + 1 + ], + "blocks": 7, + "mean_accept_len": 2.857, + "decode_tokens": 23, + "time_breakdown_s": { + "ctx_kv_build_s": 1.915, + "draft_s": 22.599, + "verify_s": 32.801, + "append_s": 0.209, + "extend_s": 0.952, + "fallback_greedy_s": 0.0 + } + } + } + ] + }, + "oracle_native_ar": { + "tokens": 320, + "wall_seconds": 130.88, + "tokens_per_second": 2.445, + "mean_latency_per_sample_s": 26.176, + "timing_scope": "e2e_prefill_plus_decode", + "stage_timings": [ + { + "sample": 0, + "prefill_s": 32.911, + "decode_s": 4.12, + "e2e_s": 37.031, + "decode_loop": "generate_step" + }, + { + "sample": 1, + "prefill_s": 22.481, + "decode_s": 2.722, + "e2e_s": 25.202, + "decode_loop": "generate_step" + }, + { + "sample": 2, + "prefill_s": 22.739, + "decode_s": 2.825, + "e2e_s": 25.564, + "decode_loop": "generate_step" + }, + { + "sample": 3, + "prefill_s": 17.881, + "decode_s": 2.746, + "e2e_s": 20.627, + "decode_loop": "generate_step" + }, + { + "sample": 4, + "prefill_s": 19.646, + "decode_s": 2.81, + "e2e_s": 22.456, + "decode_loop": "generate_step" + } + ], + "decode_loop": "generate_step" + }, + "decode_only": { + "cross_median_tok_s": 0.6351, + "oracle_median_tok_s": 22.7758, + "speedup": 0.028 + }, + "cross_model_speedup_vs_oracle_ar": 0.144, + "speedup_withheld_reasons": null + } +} \ No newline at end of file diff --git a/results/research/k3_mlx_incremental_20260612_102548.json b/results/research/k3_mlx_incremental_20260612_102548.json new file mode 100644 index 00000000..f0855966 --- /dev/null +++ b/results/research/k3_mlx_incremental_20260612_102548.json @@ -0,0 +1,435 @@ +{ + "schema_version": 1, + "kind": "k3_integrated_niah_acceptance_mac", + "config": { + "verifier_path": "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/models/gemma-4-26B-A4B-it-mlx-4bit", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "n_samples": 1, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "max_new_tokens": 8, + "seed": 42, + "eval_mode": "free_gen_incremental", + "teacher_forced": false, + "s5_exact_full_attn": true, + "identity_restore": false, + "compress_full_attn": false, + "kl_lattice": null, + "kl_q_range": null, + "kl_bits_per_token_per_head": null, + "full_attention_layers": [ + 5, + 11, + 17, + 23, + 29 + ], + "prompt_token_lens": [ + 1508 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model_mac", + "samples_total": 1, + "samples_correct": 1, + "recall": 1.0, + "mean_latency_s": 51.72188008390367, + "median_latency_s": 51.72188008390367, + "per_sample_decoded": [ + "BETA-1409" + ], + "per_sample_correct": [ + true + ], + "per_sample_decode_tokens": [ + 8 + ], + "per_sample_throughput_tokens_per_sec": [ + 0.15467341842605745 + ], + "mean_throughput_tokens_per_sec": 0.15467341842605745, + "median_throughput_tokens_per_sec": 0.15467341842605745, + "min_throughput_tokens_per_sec": 0.15467341842605745, + "max_throughput_tokens_per_sec": 0.15467341842605745 + }, + "oracle": { + "name": "oracle_mac", + "samples_total": 1, + "samples_correct": 1, + "recall": 1.0, + "mean_latency_s": 7.666205584071577, + "median_latency_s": 7.666205584071577, + "per_sample_decoded": [ + "BETA-1409" + ], + "per_sample_correct": [ + true + ], + "per_sample_decode_tokens": [ + 8 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.0435410206872044 + ], + "mean_throughput_tokens_per_sec": 1.0435410206872044, + "median_throughput_tokens_per_sec": 1.0435410206872044, + "min_throughput_tokens_per_sec": 1.0435410206872044, + "max_throughput_tokens_per_sec": 1.0435410206872044 + } + }, + "gate": { + "recall_cross_model": 1.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 0.0, + "recall_delta_within_5pp": true + }, + "memory": { + "s5": { + "seq_len": 1508, + "kv_dtype_bytes": 2, + "sink_window": 68, + "exact_layer_indices": [ + 5, + 11, + 17, + 23, + 29 + ], + "compress_full_bits_per_token_per_head": null, + "sliding_resident_bytes": 13926400, + "full_resident_bytes": 30883840, + "total_resident_bytes": 44810240, + "total_resident_mb": 44.81, + "per_token_growth_bytes": 20480, + "per_token_growth_kb": 20.0, + "per_layer": [ + { + "layer": 0, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 1, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 2, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 3, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 4, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 5, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + }, + { + "layer": 6, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 7, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 8, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 9, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 10, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 11, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + }, + { + "layer": 12, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 13, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 14, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 15, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 16, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 17, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + }, + { + "layer": 18, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 19, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 20, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 21, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 22, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 23, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + }, + { + "layer": 24, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 25, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 26, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 27, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 28, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 29, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + } + ] + }, + "naive_full_kv": { + "total_resident_mb": 339.72, + "per_token_growth_kb": 220.0 + }, + "savings_vs_naive_pct": 86.8 + }, + "throughput": { + "k3_cross_model": { + "tokens": 8, + "wall_seconds": 51.722, + "tokens_per_second": 0.1547, + "mean_latency_per_sample_s": 51.722, + "eval_mode": "free_gen_incremental", + "restored_forwards_per_sample": 8 + }, + "oracle_native_ar": { + "tokens": 8, + "wall_seconds": 7.666, + "tokens_per_second": 1.0435, + "mean_latency_per_sample_s": 7.666 + }, + "cross_model_speedup_vs_oracle_ar": 0.148 + } +} \ No newline at end of file diff --git a/results/research/k3_mlx_native_bypass_confirm_ctx70_smoke.json b/results/research/k3_mlx_native_bypass_confirm_ctx70_smoke.json new file mode 100644 index 00000000..f511deea --- /dev/null +++ b/results/research/k3_mlx_native_bypass_confirm_ctx70_smoke.json @@ -0,0 +1,432 @@ +{ + "schema_version": 1, + "kind": "k3_integrated_niah_acceptance_mac", + "config": { + "verifier_path": "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/models/gemma-4-26B-A4B-it-mlx-4bit", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "n_samples": 1, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "max_new_tokens": 8, + "seed": 42, + "eval_mode": "free_gen_fused_specdecode", + "teacher_forced": false, + "s5_exact_full_attn": true, + "identity_restore": false, + "compress_full_attn": false, + "kl_lattice": null, + "kl_q_range": null, + "kl_bits_per_token_per_head": null, + "full_attention_layers": [ + 5, + 11, + 17, + 23, + 29 + ], + "prompt_token_lens": [ + 1639 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model_mac", + "samples_total": 1, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 2.3364905000198632, + "median_latency_s": 2.3364905000198632, + "per_sample_decoded": [ + "<|channel>thought\n* The input consists" + ], + "per_sample_correct": [ + false + ], + "per_sample_decode_tokens": [ + 8 + ], + "per_sample_throughput_tokens_per_sec": [ + 3.4239385950561276 + ], + "mean_throughput_tokens_per_sec": 3.4239385950561276, + "median_throughput_tokens_per_sec": 3.4239385950561276, + "min_throughput_tokens_per_sec": 3.4239385950561276, + "max_throughput_tokens_per_sec": 3.4239385950561276 + } + }, + "gate": { + "recall_cross_model": 0.0, + "recall_oracle": null, + "recall_delta_vs_oracle_pp": null, + "recall_delta_within_5pp": false + }, + "memory": { + "s5": { + "seq_len": 1639, + "kv_dtype_bytes": 2, + "sink_window": 68, + "exact_layer_indices": [ + 5, + 11, + 17, + 23, + 29 + ], + "compress_full_bits_per_token_per_head": null, + "sliding_resident_bytes": 13926400, + "full_resident_bytes": 33566720, + "total_resident_bytes": 47493120, + "total_resident_mb": 47.49, + "per_token_growth_bytes": 20480, + "per_token_growth_kb": 20.0, + "per_layer": [ + { + "layer": 0, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 1, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 2, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 3, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 4, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 5, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 4096, + "resident_bytes": 6713344 + }, + { + "layer": 6, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 7, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 8, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 9, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 10, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 11, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 4096, + "resident_bytes": 6713344 + }, + { + "layer": 12, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 13, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 14, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 15, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 16, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 17, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 4096, + "resident_bytes": 6713344 + }, + { + "layer": 18, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 19, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 20, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 21, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 22, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 23, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 4096, + "resident_bytes": 6713344 + }, + { + "layer": 24, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 25, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 26, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 27, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 28, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 29, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 4096, + "resident_bytes": 6713344 + } + ] + }, + "naive_full_kv": { + "total_resident_mb": 369.23, + "per_token_growth_kb": 220.0 + }, + "savings_vs_naive_pct": 87.1 + }, + "throughput": { + "k3_cross_model": { + "tokens": 8, + "wall_seconds": 2.336, + "tokens_per_second": 3.4239, + "mean_latency_per_sample_s": 2.336, + "eval_mode": "free_gen_fused_specdecode", + "restored_forwards_per_sample": 8, + "stage_timings": [ + { + "sample": 0, + "build_restoration_s": 0.0, + "aux_prompt_capture_s": 0.0, + "restored_prefill_s": 12.03, + "decode_s": 2.336, + "fused": { + "tokens": [ + 100, + 45518, + 107, + 236829, + 139, + 818, + 2744, + 10594 + ], + "blocks": 0, + "mean_accept_len": 0.0, + "decode_tokens": 8, + "adaptive_mode": "restored_greedy", + "time_breakdown_s": { + "greedy_decode_s": 2.336 + } + } + } + ] + } + } +} \ No newline at end of file diff --git a/results/research/k3_mlx_native_bypass_with_oracle_ctx70_smoke.json b/results/research/k3_mlx_native_bypass_with_oracle_ctx70_smoke.json new file mode 100644 index 00000000..2ecc116c --- /dev/null +++ b/results/research/k3_mlx_native_bypass_with_oracle_ctx70_smoke.json @@ -0,0 +1,456 @@ +{ + "schema_version": 1, + "kind": "k3_integrated_niah_acceptance_mac", + "config": { + "verifier_path": "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/models/gemma-4-26B-A4B-it-mlx-4bit", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "n_samples": 1, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "max_new_tokens": 8, + "seed": 42, + "eval_mode": "free_gen_fused_specdecode", + "teacher_forced": false, + "s5_exact_full_attn": true, + "identity_restore": false, + "compress_full_attn": false, + "kl_lattice": null, + "kl_q_range": null, + "kl_bits_per_token_per_head": null, + "full_attention_layers": [ + 5, + 11, + 17, + 23, + 29 + ], + "prompt_token_lens": [ + 1639 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model_mac", + "samples_total": 1, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 2.314197957981378, + "median_latency_s": 2.314197957981378, + "per_sample_decoded": [ + "<|channel>thought\n* The input consists" + ], + "per_sample_correct": [ + false + ], + "per_sample_decode_tokens": [ + 8 + ], + "per_sample_throughput_tokens_per_sec": [ + 3.4569212078029046 + ], + "mean_throughput_tokens_per_sec": 3.4569212078029046, + "median_throughput_tokens_per_sec": 3.4569212078029046, + "min_throughput_tokens_per_sec": 3.4569212078029046, + "max_throughput_tokens_per_sec": 3.4569212078029046 + }, + "oracle": { + "name": "oracle_mac", + "samples_total": 1, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 9.266122542088851, + "median_latency_s": 9.266122542088851, + "per_sample_decoded": [ + "<|channel>thought\n* The input consists" + ], + "per_sample_correct": [ + false + ], + "per_sample_decode_tokens": [ + 8 + ], + "per_sample_throughput_tokens_per_sec": [ + 0.863360047707352 + ], + "mean_throughput_tokens_per_sec": 0.863360047707352, + "median_throughput_tokens_per_sec": 0.863360047707352, + "min_throughput_tokens_per_sec": 0.863360047707352, + "max_throughput_tokens_per_sec": 0.863360047707352 + } + }, + "gate": { + "recall_cross_model": 0.0, + "recall_oracle": 0.0, + "recall_delta_vs_oracle_pp": 0.0, + "recall_delta_within_5pp": true + }, + "memory": { + "s5": { + "seq_len": 1639, + "kv_dtype_bytes": 2, + "sink_window": 68, + "exact_layer_indices": [ + 5, + 11, + 17, + 23, + 29 + ], + "compress_full_bits_per_token_per_head": null, + "sliding_resident_bytes": 13926400, + "full_resident_bytes": 33566720, + "total_resident_bytes": 47493120, + "total_resident_mb": 47.49, + "per_token_growth_bytes": 20480, + "per_token_growth_kb": 20.0, + "per_layer": [ + { + "layer": 0, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 1, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 2, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 3, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 4, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 5, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 4096, + "resident_bytes": 6713344 + }, + { + "layer": 6, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 7, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 8, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 9, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 10, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 11, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 4096, + "resident_bytes": 6713344 + }, + { + "layer": 12, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 13, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 14, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 15, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 16, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 17, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 4096, + "resident_bytes": 6713344 + }, + { + "layer": 18, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 19, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 20, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 21, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 22, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 23, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 4096, + "resident_bytes": 6713344 + }, + { + "layer": 24, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 25, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 26, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 27, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 28, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 29, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 4096, + "resident_bytes": 6713344 + } + ] + }, + "naive_full_kv": { + "total_resident_mb": 369.23, + "per_token_growth_kb": 220.0 + }, + "savings_vs_naive_pct": 87.1 + }, + "throughput": { + "k3_cross_model": { + "tokens": 8, + "wall_seconds": 2.314, + "tokens_per_second": 3.4569, + "mean_latency_per_sample_s": 2.314, + "eval_mode": "free_gen_fused_specdecode", + "restored_forwards_per_sample": 8, + "stage_timings": [ + { + "sample": 0, + "build_restoration_s": 0.0, + "aux_prompt_capture_s": 0.0, + "restored_prefill_s": 12.646, + "decode_s": 2.314, + "fused": { + "tokens": [ + 100, + 45518, + 107, + 236829, + 139, + 818, + 2744, + 10594 + ], + "blocks": 0, + "mean_accept_len": 0.0, + "decode_tokens": 8, + "adaptive_mode": "restored_greedy", + "time_breakdown_s": { + "greedy_decode_s": 2.314 + } + } + } + ] + } + } +} \ No newline at end of file diff --git a/results/research/k3_mlx_recall_content_closed_ctx70_smoke.json b/results/research/k3_mlx_recall_content_closed_ctx70_smoke.json new file mode 100644 index 00000000..5f8ce497 --- /dev/null +++ b/results/research/k3_mlx_recall_content_closed_ctx70_smoke.json @@ -0,0 +1,456 @@ +{ + "schema_version": 1, + "kind": "k3_integrated_niah_acceptance_mac", + "config": { + "verifier_path": "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/models/gemma-4-26B-A4B-it-mlx-4bit", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "n_samples": 1, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "max_new_tokens": 8, + "seed": 42, + "eval_mode": "free_gen_fused_specdecode", + "teacher_forced": false, + "s5_exact_full_attn": true, + "identity_restore": false, + "compress_full_attn": false, + "kl_lattice": null, + "kl_q_range": null, + "kl_bits_per_token_per_head": null, + "full_attention_layers": [ + 5, + 11, + 17, + 23, + 29 + ], + "prompt_token_lens": [ + 1508 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model_mac", + "samples_total": 1, + "samples_correct": 1, + "recall": 1.0, + "mean_latency_s": 1.1682884169276804, + "median_latency_s": 1.1682884169276804, + "per_sample_decoded": [ + "BETA-1409" + ], + "per_sample_correct": [ + true + ], + "per_sample_decode_tokens": [ + 8 + ], + "per_sample_throughput_tokens_per_sec": [ + 6.847624168900082 + ], + "mean_throughput_tokens_per_sec": 6.847624168900082, + "median_throughput_tokens_per_sec": 6.847624168900082, + "min_throughput_tokens_per_sec": 6.847624168900082, + "max_throughput_tokens_per_sec": 6.847624168900082 + }, + "oracle": { + "name": "oracle_mac", + "samples_total": 1, + "samples_correct": 1, + "recall": 1.0, + "mean_latency_s": 5.4393299168441445, + "median_latency_s": 5.4393299168441445, + "per_sample_decoded": [ + "BETA-1409" + ], + "per_sample_correct": [ + true + ], + "per_sample_decode_tokens": [ + 8 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.4707694003311231 + ], + "mean_throughput_tokens_per_sec": 1.4707694003311231, + "median_throughput_tokens_per_sec": 1.4707694003311231, + "min_throughput_tokens_per_sec": 1.4707694003311231, + "max_throughput_tokens_per_sec": 1.4707694003311231 + } + }, + "gate": { + "recall_cross_model": 1.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 0.0, + "recall_delta_within_5pp": true + }, + "memory": { + "s5": { + "seq_len": 1508, + "kv_dtype_bytes": 2, + "sink_window": 68, + "exact_layer_indices": [ + 5, + 11, + 17, + 23, + 29 + ], + "compress_full_bits_per_token_per_head": null, + "sliding_resident_bytes": 13926400, + "full_resident_bytes": 30883840, + "total_resident_bytes": 44810240, + "total_resident_mb": 44.81, + "per_token_growth_bytes": 20480, + "per_token_growth_kb": 20.0, + "per_layer": [ + { + "layer": 0, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 1, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 2, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 3, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 4, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 5, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + }, + { + "layer": 6, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 7, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 8, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 9, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 10, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 11, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + }, + { + "layer": 12, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 13, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 14, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 15, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 16, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 17, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + }, + { + "layer": 18, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 19, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 20, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 21, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 22, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 23, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + }, + { + "layer": 24, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 25, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 26, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 27, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 28, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 29, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1508, + "bytes_per_token": 4096, + "resident_bytes": 6176768 + } + ] + }, + "naive_full_kv": { + "total_resident_mb": 339.72, + "per_token_growth_kb": 220.0 + }, + "savings_vs_naive_pct": 86.8 + }, + "throughput": { + "k3_cross_model": { + "tokens": 8, + "wall_seconds": 1.168, + "tokens_per_second": 6.8476, + "mean_latency_per_sample_s": 1.168, + "eval_mode": "free_gen_fused_specdecode", + "restored_forwards_per_sample": 8, + "stage_timings": [ + { + "sample": 0, + "build_restoration_s": 0.0, + "aux_prompt_capture_s": 0.0, + "restored_prefill_s": 10.477, + "decode_s": 1.168, + "fused": { + "tokens": [ + 236799, + 26742, + 236772, + 236770, + 236812, + 236771, + 236819, + 106 + ], + "blocks": 0, + "mean_accept_len": 0.0, + "decode_tokens": 8, + "adaptive_mode": "restored_greedy", + "time_breakdown_s": { + "greedy_decode_s": 1.168 + } + } + } + ] + } + } +} \ No newline at end of file diff --git a/scripts/mac_bridge/fetch_results.py b/scripts/mac_bridge/fetch_results.py new file mode 100644 index 00000000..1f16500e --- /dev/null +++ b/scripts/mac_bridge/fetch_results.py @@ -0,0 +1,78 @@ +"""Mac-bridge poller — wait for and fetch a request branch's results. + +Read-only on the GitHub side (uses ``gh run list`` / ``gh run view``, +both view operations) plus plain ``git fetch`` for the result commit +the runner pushes back. Suitable for Cursor cloud agents, whose ``gh`` +is restricted to read-only operations. + +Usage: + python3 scripts/mac_bridge/fetch_results.py --branch mac-bridge/ + python3 scripts/mac_bridge/fetch_results.py --branch ... --wait 1800 + +CLI plumbing; exempt from unit-test coverage by the scripts/serve.py +convention. +""" + +from __future__ import annotations + +import argparse +import json +import subprocess +import sys +import time + +WORKFLOW = "mac-bridge.yaml" + + +def _run(argv, capture=True): + return subprocess.run(argv, check=False, text=True, + stdout=subprocess.PIPE if capture else None, + stderr=subprocess.STDOUT if capture else None) + + +def _latest_run(branch: str): + proc = _run(["gh", "run", "list", "--workflow", WORKFLOW, + "--branch", branch, "--limit", "1", + "--json", "databaseId,status,conclusion,url"]) + if proc.returncode != 0 or not proc.stdout.strip(): + return None + runs = json.loads(proc.stdout) + return runs[0] if runs else None + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--branch", required=True) + ap.add_argument("--remote", default="origin") + ap.add_argument("--wait", type=int, default=0, + help="Seconds to keep polling (0 = single check).") + ap.add_argument("--poll-interval", type=float, default=30.0) + args = ap.parse_args() + + deadline = time.time() + args.wait + while True: + run = _latest_run(args.branch) + if run is None: + print(f"[mac-bridge] no {WORKFLOW} run for {args.branch} yet", + file=sys.stderr) + else: + print(f"[mac-bridge] run {run['databaseId']}: " + f"status={run['status']} conclusion={run['conclusion'] or '-'} " + f"{run['url']}", file=sys.stderr) + if run["status"] == "completed": + # Pull the result commit the runner pushed back. + subprocess.run(["git", "fetch", args.remote, args.branch], + check=False) + print(f"[mac-bridge] results (if any) are on " + f"{args.remote}/{args.branch} under .mac-bridge/logs/ " + f"and results/research/; inspect with:\n" + f" git show {args.remote}/{args.branch} --stat", + file=sys.stderr) + return 0 if run["conclusion"] == "success" else 1 + if time.time() >= deadline: + return 3 if run is None else 2 + time.sleep(args.poll_interval) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/mac_bridge/kakeya_mac.py b/scripts/mac_bridge/kakeya_mac.py new file mode 100644 index 00000000..ce48ea34 --- /dev/null +++ b/scripts/mac_bridge/kakeya_mac.py @@ -0,0 +1,182 @@ +"""kakeya_mac — one-command cloud-agent front door to the Mac bridge. + +Wraps the request/poll/fetch scripts into the three commands an agent +(or human) actually types: + + # 0. One-time sanity check of THIS environment (nothing to install; + # the bridge client is stdlib-only): + python3 scripts/mac_bridge/kakeya_mac.py doctor + + # 1. Run a preset on the Mac and wait for the result: + python3 scripts/mac_bridge/kakeya_mac.py run --preset mlx-env-probe --wait 600 + + # 2. Check a request later: + python3 scripts/mac_bridge/kakeya_mac.py status --branch + +`run` auto-detects Cursor cloud-agent branch policy: if the current +branch looks like `AgentMemory/-`, the request branch is +created as `AgentMemory/mac-bridge---` so the +push stays inside the agent's allowed namespace (the workflow accepts +both namespaces). + +CLI plumbing around request_run.py / fetch_results.py (themselves thin +wrappers over the unit-tested manifest library); exempt from unit-test +coverage by the scripts/serve.py convention. +""" + +from __future__ import annotations + +import argparse +import re +import subprocess +import sys +from pathlib import Path + +SCRIPTS = Path(__file__).resolve().parent + +_AGENT_BRANCH = re.compile(r"^AgentMemory/.*?(-[a-z0-9]{4,8})$") + + +def _run(argv, *, capture=False, check=False): + return subprocess.run(argv, text=True, check=check, + stdout=subprocess.PIPE if capture else None) + + +def _current_branch() -> str: + return _run(["git", "rev-parse", "--abbrev-ref", "HEAD"], + capture=True).stdout.strip() + + +def _branch_policy_args() -> list: + """Stay inside a cloud agent's AgentMemory/<...>- namespace.""" + match = _AGENT_BRANCH.match(_current_branch()) + if not match: + return [] + suffix = match.group(1) + # `=`-joined so argparse never mistakes a leading-dash suffix + # (e.g. "-b876") for an option flag. + return ["--branch-prefix=AgentMemory/mac-bridge-", + f"--branch-suffix={suffix}"] + + +def cmd_doctor(_args) -> int: + failures = 0 + + def check(name, fn): + nonlocal failures + try: + detail = fn() + print(f" OK {name}{': ' + detail if detail else ''}") + except Exception as exc: + failures += 1 + print(f" FAIL {name}: {exc}") + + def _python(): + if sys.version_info < (3, 10): + raise RuntimeError(f"python {sys.version.split()[0]} too old") + return sys.version.split()[0] + + def _repo(): + root = _run(["git", "rev-parse", "--show-toplevel"], + capture=True, check=True).stdout.strip() + if not (Path(root) / "scripts/mac_bridge/run_preset.py").exists(): + raise RuntimeError("bridge files missing on this ref") + return root + + def _push(): + proc = subprocess.run( + ["git", "push", "--dry-run", "origin", + "HEAD:refs/heads/mac-bridge/doctor-probe"], + text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + if proc.returncode != 0: + raise RuntimeError("git push --dry-run failed (no push rights?)") + return "push permission verified (dry-run; no ref created)" + + def _gh(): + proc = _run(["gh", "auth", "status"], capture=True) + if proc.returncode != 0: + raise RuntimeError("gh not authenticated (status polling will " + "fall back to plain git fetch)") + return "authenticated (read-only polling available)" + + def _manifest(): + sys.path.insert(0, str(SCRIPTS.parent.parent)) + from inference_engine.bridge.manifest import PRESETS + return f"{len(PRESETS)} presets allowlisted" + + print("[kakeya-mac] doctor:") + check("python", _python) + check("repo + bridge files", _repo) + check("git push permission", _push) + check("gh (optional)", _gh) + check("manifest allowlist import", _manifest) + policy = _branch_policy_args() + print(f" OK branch namespace: " + f"{'AgentMemory/mac-bridge-*' if policy else 'mac-bridge/**'}") + print(f"[kakeya-mac] {'READY' if failures == 0 else 'NOT READY'}") + return 1 if failures else 0 + + +def cmd_run(args) -> int: + req = [sys.executable, str(SCRIPTS / "request_run.py"), + "--preset", args.preset, "--requested-by", args.requested_by] + for kv in args.param: + req += ["--param", kv] + if args.ref: + req += ["--ref", args.ref] + req += _branch_policy_args() + if args.no_push: + req += ["--no-push"] + proc = _run(req, capture=True) + sys.stderr.flush() + branch = (proc.stdout or "").strip().splitlines()[-1] if proc.stdout else "" + if proc.returncode != 0 or not branch: + print("[kakeya-mac] request failed", file=sys.stderr) + return proc.returncode or 1 + print(branch) + if args.no_push or args.wait <= 0: + return 0 + return subprocess.run( + [sys.executable, str(SCRIPTS / "fetch_results.py"), + "--branch", branch, "--wait", str(args.wait)], + ).returncode + + +def cmd_status(args) -> int: + return subprocess.run( + [sys.executable, str(SCRIPTS / "fetch_results.py"), + "--branch", args.branch, "--wait", str(args.wait)], + ).returncode + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + sub = ap.add_subparsers(dest="command", required=True) + + sub.add_parser("doctor", help="verify this environment can use the bridge") + + run_p = sub.add_parser("run", help="request a Mac run (optionally wait)") + run_p.add_argument("--preset", required=True) + run_p.add_argument("--param", action="append", default=[], metavar="K=V") + run_p.add_argument("--ref", default="", + help="workload ref (default: current HEAD)") + run_p.add_argument("--requested-by", default="kakeya-mac-cli") + run_p.add_argument("--wait", type=int, default=0, + help="seconds to wait for completion (0 = fire and " + "forget)") + run_p.add_argument("--no-push", action="store_true") + + st_p = sub.add_parser("status", help="poll an existing request branch") + st_p.add_argument("--branch", required=True) + st_p.add_argument("--wait", type=int, default=0) + + args = ap.parse_args() + if args.command == "doctor": + return cmd_doctor(args) + if args.command == "run": + return cmd_run(args) + return cmd_status(args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/mac_bridge/request_run.py b/scripts/mac_bridge/request_run.py new file mode 100644 index 00000000..b4e9742a --- /dev/null +++ b/scripts/mac_bridge/request_run.py @@ -0,0 +1,181 @@ +"""Mac-bridge client — request a Mac run by pushing a request branch. + +Implements the git-bus protocol from +docs/design/mac-bridge-cloud-agent-access.md §2.1: branch +``mac-bridge/-`` from the workload ref, overlay the +bridge files if the ref predates them, commit the manifest at +``.mac-bridge/request.json``, push. The push triggers +.github/workflows/mac-bridge.yaml on the kakeya-mac-m4 runner; results +come back as commits on the same branch (plus workflow artifacts). + +Designed for Cursor cloud agents: needs only git push permission — +no workflow-dispatch token, no VPN, no SSH key. + +Usage: + python3 scripts/mac_bridge/request_run.py --preset mlx-env-probe + python3 scripts/mac_bridge/request_run.py --preset k3-step2-fused \ + --ref origin/some-branch --param n_samples=5 --param block_size=4 + # Inspect without pushing: + python3 scripts/mac_bridge/request_run.py --preset mlx-env-probe --no-push + +CLI plumbing around the unit-tested manifest library; exempt from +unit-test coverage by the scripts/serve.py convention. +""" + +from __future__ import annotations + +import argparse +import json +import secrets +import subprocess +import sys +import time +from pathlib import Path + +from inference_engine.bridge.manifest import ( + BRANCH_PREFIX, + MANIFEST_PATH, + ManifestError, + PRESETS, + parse_manifest, +) + +# Files that must exist on the pushed branch for the bridge to work +# (`on: push` workflows execute the pushed commit's definition). When +# the workload ref predates the bridge, these are overlaid from the +# client's own checkout. +BRIDGE_FILES = ( + ".github/workflows/mac-bridge.yaml", + "scripts/mac_bridge/run_preset.py", + "scripts/validate_k3_reports.py", + "inference_engine/bridge/__init__.py", + "inference_engine/bridge/manifest.py", + "inference_engine/bench/k3_report_gate.py", +) + + +def _git(*argv: str, capture: bool = False) -> str: + proc = subprocess.run( + ["git", *argv], + check=True, + stdout=subprocess.PIPE if capture else None, + text=True, + ) + return proc.stdout.strip() if capture else "" + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--preset", required=True, choices=sorted(PRESETS)) + ap.add_argument("--param", action="append", default=[], + metavar="K=V", help="Preset parameter; repeatable.") + ap.add_argument("--ref", default="HEAD", + help="Workload ref to run against (default: HEAD).") + ap.add_argument("--requested-by", default="cloud-agent") + ap.add_argument("--remote", default="origin") + ap.add_argument("--branch-prefix", default=BRANCH_PREFIX, + help="Request-branch prefix. The workflow also accepts " + "'AgentMemory/mac-bridge-' so Cursor cloud agents " + "can request runs within their branch-naming " + "policy.") + ap.add_argument("--branch-suffix", default="", + help="Optional branch suffix (e.g. a cloud agent's " + "mandated '-' suffix).") + ap.add_argument("--no-push", action="store_true", + help="Build the request branch locally but do not push " + "(inspection / dry runs).") + ap.add_argument("--keep-branch", action="store_true", + help="Stay on the request branch after pushing. Default " + "is to return to the original branch so one-click " + "callers (kakeya_mac.py run) leave the worktree " + "where it was.") + args = ap.parse_args() + + params = {} + for kv in args.param: + if "=" not in kv: + print(f"--param must be K=V, got {kv!r}", file=sys.stderr) + return 2 + k, v = kv.split("=", 1) + params[k] = v + + nonce = f"{int(time.time())}-{secrets.token_hex(3)}" + manifest = { + "schema_version": 1, + "preset": args.preset, + "params": params, + "ref": args.ref, + "requested_by": args.requested_by, + "nonce": nonce, + } + try: + request = parse_manifest(manifest) + except ManifestError as exc: + print(f"[mac-bridge] invalid request: {exc}", file=sys.stderr) + return 2 + branch = (f"{args.branch_prefix}{request.preset.name}-" + f"{request.nonce}{args.branch_suffix}") + + start_point = args.ref if args.ref != "HEAD" else "HEAD" + repo_root = Path(_git("rev-parse", "--show-toplevel", capture=True)) + original_branch = _git("rev-parse", "--abbrev-ref", "HEAD", capture=True) + + # Refuse to build a request from a dirty tree: `git add -A` below + # would silently absorb unrelated uncommitted edits into the + # request branch and they would vanish from the original branch + # when we switch back (observed in live testing). The workload is + # always a committed state. + dirty = _git("status", "--porcelain", capture=True) + if dirty: + print("[mac-bridge] working tree is dirty; commit or stash first:\n" + + dirty, file=sys.stderr) + return 2 + + # Snapshot bridge files from the CLIENT checkout before switching: + # the workload ref may predate the bridge. + overlay = { + rel: (repo_root / rel).read_bytes() + for rel in BRIDGE_FILES + if (repo_root / rel).exists() + } + + print(f"[mac-bridge] creating {branch} from {start_point}", file=sys.stderr) + _git("checkout", "-b", branch, start_point) + try: + changed = False + for rel, blob in overlay.items(): + dst = repo_root / rel + if not dst.exists() or dst.read_bytes() != blob: + dst.parent.mkdir(parents=True, exist_ok=True) + dst.write_bytes(blob) + changed = True + manifest_path = repo_root / MANIFEST_PATH + manifest_path.parent.mkdir(parents=True, exist_ok=True) + manifest_path.write_text(json.dumps(manifest, indent=2) + "\n") + _git("add", "-A") + _git("commit", "-q", "-m", + f"mac-bridge request: {args.preset} (nonce {nonce})" + + ("\n\n(bridge files overlaid onto pre-bridge ref)" if changed else "")) + if args.no_push: + print(f"[mac-bridge] built {branch} (NOT pushed; --no-push)", + file=sys.stderr) + else: + _git("push", "-u", args.remote, branch) + print(f"[mac-bridge] pushed {branch}; the kakeya-mac-m4 runner " + "will pick it up. Poll with:\n" + f" python3 scripts/mac_bridge/fetch_results.py --branch {branch}", + file=sys.stderr) + if not args.keep_branch and original_branch != "HEAD": + _git("checkout", "-q", original_branch) + print(f"[mac-bridge] returned to {original_branch}", + file=sys.stderr) + except Exception: + # Leave the workspace on the request branch for inspection, but + # surface the failure loudly. + raise + print(branch) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/mac_bridge/run_preset.py b/scripts/mac_bridge/run_preset.py new file mode 100644 index 00000000..a95122ad --- /dev/null +++ b/scripts/mac_bridge/run_preset.py @@ -0,0 +1,107 @@ +"""Mac-bridge executor — runs ONE validated preset on the Mac runner. + +Invoked by .github/workflows/mac-bridge.yaml with the manifest that the +requesting agent committed at .mac-bridge/request.json. All allowlist +and parameter validation lives in inference_engine.bridge.manifest +(unit-tested on the Linux gate); this CLI only sequences subprocesses +and tees logs. + +No shell is ever involved: commands are argv lists from +``build_commands`` passed straight to ``subprocess.run``. + +Usage: + python3 scripts/mac_bridge/run_preset.py --manifest .mac-bridge/request.json + python3 scripts/mac_bridge/run_preset.py --manifest ... --dry-run + +CLI plumbing around the unit-tested manifest library; exempt from +unit-test coverage by the scripts/serve.py convention. +""" + +from __future__ import annotations + +import argparse +import json +import os +import subprocess +import sys +import time +from pathlib import Path + +from inference_engine.bridge.manifest import ( + ManifestError, + build_commands, + parse_manifest_text, +) + +LOG_DIR = Path(".mac-bridge/logs") + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--manifest", default=".mac-bridge/request.json") + ap.add_argument("--dry-run", action="store_true", + help="Validate + print the resolved argv lists without " + "executing anything (used by Linux-side checks).") + args = ap.parse_args() + + try: + request = parse_manifest_text(Path(args.manifest).read_text()) + commands = build_commands(request, dict(os.environ)) + except (OSError, ManifestError) as exc: + print(f"[mac-bridge] REJECTED: {exc}", file=sys.stderr) + return 2 + + print(f"[mac-bridge] preset={request.preset.name} " + f"params={dict(request.params)} ref={request.ref} " + f"requested_by={request.requested_by}", file=sys.stderr) + if args.dry_run: + for argv in commands: + print(json.dumps(argv)) + return 0 + + LOG_DIR.mkdir(parents=True, exist_ok=True) + summary = { + "preset": request.preset.name, + "params": dict(request.params), + "nonce": request.nonce, + "commands": [], + } + rc = 0 + for idx, argv in enumerate(commands): + log_path = LOG_DIR / f"{request.preset.name}-{idx}.log" + print(f"[mac-bridge] exec[{idx}]: {argv}", file=sys.stderr) + t0 = time.perf_counter() + with log_path.open("wb") as log: + proc = subprocess.run(argv, stdout=log, stderr=subprocess.STDOUT) + elapsed = time.perf_counter() - t0 + summary["commands"].append({ + "argv": argv, + "exit_code": proc.returncode, + "seconds": round(elapsed, 1), + "log": str(log_path), + }) + print(f"[mac-bridge] exec[{idx}] exit={proc.returncode} " + f"({elapsed:.1f}s) log={log_path}", file=sys.stderr) + if proc.returncode != 0: + rc = proc.returncode + break + + # Evidence discipline (design doc C4): K3 acceptance reports produced + # by this run must satisfy the evidence gate ON THE MAC, so a + # non-conforming report fails the bridge run itself. + if rc == 0 and request.preset.validate_reports: + gate = subprocess.run( + [sys.executable, "scripts/validate_k3_reports.py", + "results/research"], + ) + summary["evidence_gate_exit_code"] = gate.returncode + if gate.returncode != 0: + rc = gate.returncode + + summary["exit_code"] = rc + (LOG_DIR / "summary.json").write_text(json.dumps(summary, indent=2)) + return rc + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/mac_bridge/setup_mac.sh b/scripts/mac_bridge/setup_mac.sh new file mode 100644 index 00000000..888884ee --- /dev/null +++ b/scripts/mac_bridge/setup_mac.sh @@ -0,0 +1,178 @@ +#!/usr/bin/env bash +# One-click Mac-side setup for the Kakeya Mac bridge. +# +# Run ON the Mac mini, from the repo root: +# +# # Runner already installed (existing kakeya-mac-m4 host): +# bash scripts/mac_bridge/setup_mac.sh +# +# # Fresh Mac, install + register the Actions runner too +# # (get the token from GitHub: Settings -> Actions -> Runners -> +# # New self-hosted runner -> copy the --token value): +# bash scripts/mac_bridge/setup_mac.sh --runner-token \ +# --repo-url https://github.com// +# +# # Also prepare M2 interactive access (Tailscale SSH): +# bash scripts/mac_bridge/setup_mac.sh --with-tailscale +# +# Idempotent: every step checks before it changes anything. Ends with a +# bridge self-test (manifest validation + dry-run argv resolution) so a +# green exit means the next `mac-bridge/**` push will execute. +# +# See docs/design/mac-bridge-cloud-agent-access.md and +# docs/ops/mac-m4-runner-setup.md. + +set -euo pipefail + +RUNNER_DIR="${RUNNER_DIR:-$HOME/actions-runner}" +RUNNER_LABELS="self-hosted,macOS,ARM64,kakeya-mac-m4" +RUNNER_TOKEN="" +REPO_URL="" +WITH_TAILSCALE=0 + +while [ $# -gt 0 ]; do + case "$1" in + --runner-token) RUNNER_TOKEN="$2"; shift 2 ;; + --repo-url) REPO_URL="$2"; shift 2 ;; + --with-tailscale) WITH_TAILSCALE=1; shift ;; + *) echo "unknown arg: $1" >&2; exit 2 ;; + esac +done + +step() { printf '\n\033[1m== %s ==\033[0m\n' "$*"; } +ok() { printf ' \033[32mOK\033[0m %s\n' "$*"; } +warn() { printf ' \033[33mWARN\033[0m %s\n' "$*"; } +die() { printf ' \033[31mFAIL\033[0m %s\n' "$*" >&2; exit 1; } + +step "1/6 Host shape" +[ "$(uname -s)" = "Darwin" ] || die "this script runs on macOS (got $(uname -s))" +[ "$(uname -m)" = "arm64" ] || die "Apple Silicon required (got $(uname -m))" +ok "macOS arm64" +PYVER="$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[:2])))')" \ + || die "python3 not on PATH" +python3 -c 'import sys; sys.exit(0 if sys.version_info >= (3, 12) else 1)' \ + || die "Python >= 3.12 required (got ${PYVER}); brew install python@3.12" +ok "python3 ${PYVER}" +[ -f "scripts/mac_bridge/run_preset.py" ] || die "run from the repo root" +ok "repo root: $(pwd)" + +step "2/6 Python dependencies (into the runner's python3)" +# The Actions runner executes jobs with the host's plain `python3` +# (see integration.yaml's install step) — NOT the .venv-mac that +# scripts/setup_mac.sh builds for interactive dev. Install into the +# same interpreter the bridge workflow will use. +if python3 -c 'import mlx.core, mlx_lm, torch, pytest' 2>/dev/null; then + ok "mlx / mlx_lm / torch / pytest importable" +else + warn "installing project deps into $(command -v python3) (first run takes a few minutes)" + python3 -m pip install --upgrade pip --quiet + python3 -m pip install --quiet -r requirements.txt + python3 -m pip install --quiet 'mlx>=0.20' 'mlx-lm>=0.18' \ + pytest pytest-asyncio pytest-timeout + python3 -c 'import mlx.core, mlx_lm, torch, pytest' \ + || die "deps still not importable after install" + ok "deps installed" +fi +# Version sanity for the K3 path: transformers >= 5.0 is required by +# Gemma 4 / DFlash / current mlx-lm (requirements.txt dropped the <5 +# pin; scripts/setup_mac.sh used to enforce it and broke setups with +# transformers 5.x — fixed alongside this script). +python3 - <<'PY' +import sys +from importlib.metadata import version +from packaging.version import Version +v = Version(version("transformers")) +if v < Version("4.45"): + sys.exit(f"transformers {v} < 4.45 floor; pip install -U transformers") +print(f" transformers {v} (K3 path wants >= 5.0: " + f"{'OK' if v >= Version('5.0') else 'WARN — k3-* presets may fail'})") +PY +ok "dependency versions consistent with requirements.txt" + +step "3/6 GitHub Actions runner (${RUNNER_DIR})" +if [ -f "${RUNNER_DIR}/.runner" ]; then + ok "runner already configured: $(grep -o '"agentName": *"[^"]*"' "${RUNNER_DIR}/.runner" || true)" + if "${RUNNER_DIR}/svc.sh" status 2>/dev/null | grep -q "Started"; then + ok "runner service running" + else + warn "runner service not running; starting" + (cd "${RUNNER_DIR}" && sudo ./svc.sh start) + fi +else + [ -n "${RUNNER_TOKEN}" ] || die "no runner at ${RUNNER_DIR}; rerun with --runner-token --repo-url (GitHub: Settings->Actions->Runners->New self-hosted runner)" + [ -n "${REPO_URL}" ] || die "--repo-url required with --runner-token" + mkdir -p "${RUNNER_DIR}" + cd "${RUNNER_DIR}" + LATEST="$(curl -fsSL https://api.github.com/repos/actions/runner/releases/latest \ + | python3 -c 'import json,sys; print(json.load(sys.stdin)["tag_name"].lstrip("v"))')" + echo " downloading actions-runner v${LATEST} (osx-arm64)" + curl -fsSL -o runner.tar.gz \ + "https://github.com/actions/runner/releases/download/v${LATEST}/actions-runner-osx-arm64-${LATEST}.tar.gz" + tar xzf runner.tar.gz && rm runner.tar.gz + ./config.sh --unattended --url "${REPO_URL}" --token "${RUNNER_TOKEN}" \ + --name "kakeya-mac-m4" --labels "${RUNNER_LABELS}" --replace + sudo ./svc.sh install && sudo ./svc.sh start + cd - >/dev/null + ok "runner installed + started with labels ${RUNNER_LABELS}" +fi + +step "4/6 Bridge model locations (k3-* presets)" +# Canonical stable location on the runner host: ~/kakeya-models/ +# (symlinks are fine). Repo Actions variables override; repo-relative +# paths are the last fallback. +DEFAULT_VERIFIER="$HOME/kakeya-models/gemma-4-26B-A4B-it-mlx-4bit" +[ -d "$DEFAULT_VERIFIER" ] || DEFAULT_VERIFIER="models/gemma-4-26B-A4B-it-mlx-4bit" +VERIFIER="${KAKEYA_MAC_VERIFIER_PATH:-$DEFAULT_VERIFIER}" +FTHETA="${KAKEYA_MAC_FTHETA_DIR:-results/research/f_theta_v5_s5_sliding}" +if [ -d "${VERIFIER}" ]; then + ok "verifier: ${VERIFIER}" +else + warn "verifier not found. The k3-* presets need it. Easiest fix:" + warn " mkdir -p ~/kakeya-models && ln -sfn \\" + warn " ~/kakeya-models/gemma-4-26B-A4B-it-mlx-4bit" + warn "(or set the repo Actions variable KAKEYA_MAC_VERIFIER_PATH)." +fi +if [ -d "${FTHETA}" ]; then + ok "f_theta: ${FTHETA}" +else + warn "f_theta dir not at '${FTHETA}' (set KAKEYA_MAC_FTHETA_DIR var)." +fi +if [ -d "${HOME}/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B" ]; then + ok "HF cache: Qwen3-0.6B pre-warmed (integration-tests preset ready)" +else + warn "Qwen3-0.6B not in HF cache; integration-tests preset will fail." + warn "Pre-warm with: PYTHONPATH=. python3 scripts/kakeya_prewarm.py" +fi + +step "5/6 Bridge self-test (manifest validation + dry-run argv)" +TMP_MANIFEST="$(mktemp)" +python3 - "$TMP_MANIFEST" <<'PY' +import json, sys, time +json.dump({ + "schema_version": 1, "preset": "mlx-env-probe", "params": {}, + "ref": "HEAD", "requested_by": "setup-self-test", + "nonce": f"{int(time.time())}-selftest", +}, open(sys.argv[1], "w")) +PY +PYTHONPATH=.:sdks/python python3 scripts/mac_bridge/run_preset.py \ + --manifest "$TMP_MANIFEST" --dry-run >/dev/null \ + || die "bridge self-test failed" +rm -f "$TMP_MANIFEST" +ok "executor validates + resolves presets" +PYTHONPATH=.:sdks/python python3 -c \ + 'from inference_engine.backends.mlx.env import probe_environment; print(" " + probe_environment().render())' + +step "6/6 Optional: Tailscale (M2 interactive access)" +if [ "${WITH_TAILSCALE}" = "1" ]; then + command -v brew >/dev/null || die "Homebrew required for --with-tailscale" + command -v tailscale >/dev/null || brew install tailscale + sudo brew services start tailscale 2>/dev/null || true + warn "complete login + enable Tailscale SSH manually:" + warn " sudo tailscale up --ssh --advertise-tags=tag:kakeya-mac" +else + ok "skipped (rerun with --with-tailscale to enable M2 interactive SSH)" +fi + +printf '\n\033[1m\033[32mMac bridge ready.\033[0m Any push to mac-bridge/** (or AgentMemory/mac-bridge-*) now executes here.\n' +printf 'Smoke it from any clone with push rights:\n' +printf ' PYTHONPATH=.:sdks/python python3 scripts/mac_bridge/kakeya_mac.py run --preset mlx-env-probe --wait 600\n' diff --git a/scripts/research/k3_integrated_niah_eval_mac.py b/scripts/research/k3_integrated_niah_eval_mac.py index ba1b820f..8025206e 100644 --- a/scripts/research/k3_integrated_niah_eval_mac.py +++ b/scripts/research/k3_integrated_niah_eval_mac.py @@ -63,6 +63,76 @@ def parse_args() -> argparse.Namespace: ap.add_argument("--sink-size", type=int, default=4) ap.add_argument("--window-size", type=int, default=64) ap.add_argument("--max-new-tokens", type=int, default=16) + ap.add_argument("--incremental", action="store_true", + help="Use the INCREMENTAL restored decode (MLX Gap-A): " + "prefill captures restored K/V into a persistent cache, " + "decode via mlx_lm generate_step (O(L)/token). Fixes the " + "per-token re-forward throughput collapse. Free-gen only.") + ap.add_argument("--fused-specdecode", action="store_true", + help="Use the FUSED DFlash spec-decode engine (MLX port of " + "#107 A+B+C): drafter context K/V cache + aux capture from " + "the verify forward + incremental restored verify with " + "trim_prompt_cache accept/reject. Free-gen only.") + ap.add_argument("--force-fused-specdecode", action="store_true", + help="Deprecated alias: --fused-specdecode now ALWAYS runs " + "the fused engine (evidence-gate constraint; the " + "silent greedy fallback that produced blocks=0 " + "reports labelled fused is no longer reachable).") + ap.add_argument("--native-baseline-bypass", action="store_true", + help="Run the verifier on its NATIVE cache (no restoration, " + "no drafter/f_theta) and label the run as " + "system_under_test=native_ar_baseline. This is the " + "only way to run the former 'adaptive S5 native' " + "path; it can no longer occupy the cross-model slot " + "or claim recall/speedup (k3_report_gate rules " + "BASELINE_AS_SUT / SPEEDUP_SELF_COMPARISON).") + ap.add_argument("--direct-answer-prompt", action=argparse.BooleanOptionalAction, + default=True, + help="For NIAH free generation, add a strict instruction to " + "answer with only the secret code. This keeps short " + "smokes from spending the token budget on Gemma4's " + "thought/channel preamble. Use --no-direct-answer-prompt " + "to reproduce the legacy prompt exactly.") + ap.add_argument("--chat-template-prompt", action="store_true", + help="Deprecated compatibility flag; chat-template prompting " + "is the default for Gemma4 NIAH.") + ap.add_argument("--raw-completion-prompt", action="store_true", + help="Diagnostic only: bypass chat template and encode the " + "NIAH prompt as a raw completion prompt.") + ap.add_argument("--content-channel-prefill", + action=argparse.BooleanOptionalAction, + default=True, + help="With chat-template direct-answer prompts, append " + "Gemma4's content channel marker before generation so " + "short smokes do not spend tokens on the thought channel.") + ap.add_argument("--all-mlx-drafter", action="store_true", + help="Step-2 rescue: run the DFlash drafter natively in " + "MLX (inference_engine.backends.mlx.dflash_drafter) " + "instead of PyTorch — zero mx<->torch bridge " + "crossings per block. Requires --s5-exact-full-attn " + "(the all-MLX path uses native-S5 injection; the " + "f_theta sliding restoration path stays torch).") + ap.add_argument("--single-fused", action="store_true", + help="PROBE: with --cuda-trim, fuse drafter+verifier into ONE " + "graph (skip the two-phase eval) to classify the Metal " + "instability (fundamental command-buffer vs fixable SDPA " + "fallback). Reports per-block eval times.") + ap.add_argument("--cuda-trim", action="store_true", + help="All-MLX fused with the CUDA-parity rollback: all-KVCache " + "verifier layout + native trim_prompt_cache (keep accepted " + "K/V, drop only rejected) instead of the v3 carry " + "re-forward. Requires --all-mlx-drafter --fused-specdecode.") + ap.add_argument("--code-prompts", action="store_true", + help="Replace the NIAH dataset with code-completion prompts " + "(naturally-long, predictable generation = the spec-decode " + "sweet spot). Recall metric is N/A; measures honest " + "decode-only throughput + acceptance on a real workload.") + ap.add_argument("--ignore-turn-stop", action="store_true", + help="Do not include Gemma4 as a stop token. " + "Useful for throughput evidence runs that require " + "decode median >= 32 tokens.") + ap.add_argument("--block-size", type=int, default=4, + help="Spec-decode block size (drafted tokens per block).") ap.add_argument("--teacher-forced", action="store_true", help="DIAGNOSTIC ONLY (under-measures retrieval): single " "restored forward per sample, check argmax at the " @@ -73,8 +143,11 @@ def parse_args() -> argparse.Namespace: "incremental cache; the restored cross path does a " "full forward per token (slow on M4 — see notes).") ap.add_argument("--seed", type=int, default=42) - ap.add_argument("--drafter-device", default="mps", + ap.add_argument("--drafter-device", default="cpu", help="torch device for the DFlash drafter + f_θ (mps|cpu)") + ap.add_argument("--torch-cpu-threads", type=int, default=0, + help="Override torch CPU intra-op threads for drafter/f_theta " + "(0 keeps torch default).") ap.add_argument("--s5-exact-full-attn", action="store_true", help="Keep full-attention layers' K/V exact (S5).") ap.add_argument("--identity-restore", action="store_true", @@ -88,6 +161,14 @@ def parse_args() -> argparse.Namespace: ap.add_argument("--kl-lattice", default="D4", choices=["D4", "E8"]) ap.add_argument("--kl-q-range", type=int, default=38) ap.add_argument("--skip-oracle", action="store_true") + ap.add_argument("--decode-warmup-tokens", type=int, default=1, + help="Run a tiny untimed native decode warmup before " + "cross/oracle measurements so MLX/Metal compilation " + "cost does not fall only on the first measured path.") + ap.add_argument("--prefill-chunk-size", type=int, default=512, + help="Chunk MLX prompt prefill/forward calls to avoid the " + "long-context one-shot forward OOM path. Set <=0 to " + "use a single full prompt forward.") ap.add_argument("--output", default=None) return ap.parse_args() @@ -108,13 +189,29 @@ def main() -> int: resolve_mlx_text_model, mlx_full_attention_layer_indices, kv_source_layer_map, capture_own_kv, restored_logits, per_layer_kv_geometry, kv_memory_report, + restored_prefill_cache, restored_incremental_generate, + ) + from inference_engine.backends.mlx.fused_specdecode import ( + MLXRestoredIncrementalVerifier, capture_aux_hidden, + make_bridge_embed_lm_head, fused_specdecode_generate, + fused_specdecode_generate_mlx, fused_specdecode_generate_mlx_trim, ) from inference_engine.v04.kv_compressor import make_default_compressor + from inference_engine.bench.k3_report_gate import ( + CLAIM_ORACLE_DECODE_LOOP, MIN_MEDIAN_DECODE_TOKENS, MIN_PERF_SAMPLES, + MAX_PREFILL_SPREAD, NATIVE_BASELINE_LABEL, + decode_only_block, prefill_spread, summarize_violations, + validate_report, + ) from scripts.research.k3_dflash_mlx_bridge import ( mx_to_torch, torch_to_mx, ) torch.manual_seed(args.seed) + if int(args.torch_cpu_threads or 0) > 0: + torch.set_num_threads(int(args.torch_cpu_threads)) + print(f"[mac] torch CPU threads={torch.get_num_threads()}", + file=sys.stderr, flush=True) dev = torch.device(args.drafter_device if ( args.drafter_device == "cpu" or torch.backends.mps.is_available() ) else "cpu") @@ -129,16 +226,58 @@ def main() -> int: src_map = kv_source_layer_map(text_model) print(f"[mac] verifier layers={n_layers} full_attn={full_attn_idx}", file=sys.stderr) - # ---------- Load drafter + f_θ (PyTorch) ---------- - print(f"[mac] loading drafter {args.drafter_id} on {dev}", file=sys.stderr, flush=True) - drafter = DFlashDrafter.from_pretrained(args.drafter_id, dtype=torch.float32) - drafter = drafter.to(dev).eval() - for p in drafter.parameters(): - p.requires_grad_(False) - f_theta = FThetaProjection.from_pretrained( - args.f_theta_dir, dtype=torch.float32, device=dev, - ) - fcfg = f_theta.config + # Evidence-gate path resolution (k3_report_gate): + # * --fused-specdecode ALWAYS executes the fused engine. The former + # implicit "adaptive S5 native" bypass silently replaced the system + # under test with the native baseline while keeping the fused label + # (committed reports showed blocks=0 on every sample). + # * The native baseline is still runnable — but only explicitly, and + # it is labelled as a baseline in the report. + if args.native_baseline_bypass and args.force_fused_specdecode: + raise SystemExit( + "--native-baseline-bypass and --force-fused-specdecode are " + "mutually exclusive: a run is either the native baseline or " + "the fused system under test.") + if args.native_baseline_bypass: + args.fused_specdecode = True # route through the cache-based loop + elif args.fused_specdecode or args.force_fused_specdecode: + args.fused_specdecode = True + args.force_fused_specdecode = True + adaptive_s5_native = args.native_baseline_bypass + if args.all_mlx_drafter and not args.s5_exact_full_attn: + raise SystemExit( + "--all-mlx-drafter requires --s5-exact-full-attn: the all-MLX " + "path uses native-S5 prefill injection; the f_theta sliding " + "restoration path is torch-only.") + drafter = None + mlx_drafter = None + f_theta = None + fcfg = None + if adaptive_s5_native: + print("[mac] native baseline bypass: skipping drafter/f_theta load; " + "report will be labelled system_under_test=native_ar_baseline", + file=sys.stderr, flush=True) + elif args.all_mlx_drafter: + # ---------- Step-2 rescue: drafter native in MLX ---------- + from inference_engine.backends.mlx.dflash_drafter import ( + MLXDFlashDrafter, + ) + print(f"[mac] loading ALL-MLX drafter {args.drafter_id}", + file=sys.stderr, flush=True) + mlx_drafter = MLXDFlashDrafter.from_pretrained(args.drafter_id) + # No torch drafter / f_theta: S5-native injection covers prefill + # restoration and the drafter never leaves the Metal stream. + else: + # ---------- Load drafter + f_θ (PyTorch) ---------- + print(f"[mac] loading drafter {args.drafter_id} on {dev}", file=sys.stderr, flush=True) + drafter = DFlashDrafter.from_pretrained(args.drafter_id, dtype=torch.float32) + drafter = drafter.to(dev).eval() + for p in drafter.parameters(): + p.requires_grad_(False) + f_theta = FThetaProjection.from_pretrained( + args.f_theta_dir, dtype=torch.float32, device=dev, + ) + fcfg = f_theta.config # ---------- Optional KakeyaLattice compression of full-attn layers ---------- geom = per_layer_kv_geometry(text_model) @@ -178,7 +317,6 @@ def _compress_roundtrip(li: int, k_mx: Any, v_mx: Any): def capture_drafter_kv(ids: List[int]): ids_mx = mx.array([ids]) emb_mx = text_model.embed_tokens(ids_mx) - emb_mx = emb_mx * embed_scale embedded = mx_to_torch(emb_mx, dtype=torch.float32, device=dev) # [1,T,H] layers = list(drafter.layers) k_cap: List[Optional[torch.Tensor]] = [None] * len(layers) @@ -211,18 +349,31 @@ def capture_drafter_kv(ids: List[int]): # is evicted, so the prompt's restored K/V cover every injected slot. exact_set = set(range(n_layers)) if args.identity_restore else set(full_attn_idx) - def build_restoration(prompt_ids: List[int]): + def build_restoration(prompt_ids: List[int], *, prefill_native_s5: bool = False): + """Build restored K/V banks. + + For incremental/fused S5 decode, full-attention exact K/V should come + from the same MLX prefill that populates the native cache. Supplying no + restored bank for those layers lets mlx_lm store their own post-RoPE + cache directly and avoids the extra clean verifier forward. + """ + if prefill_native_s5 and args.s5_exact_full_attn and not args.identity_restore: + return {}, {}, len(prompt_ids) + if drafter is None or f_theta is None or fcfg is None: + raise RuntimeError("drafter/f_theta are required for this restoration mode") d_k, d_v = capture_drafter_kv(prompt_ids) with torch.no_grad(): vk, vv = f_theta.forward_kv_pack(d_k, d_v) own = None - if exact_set: + if exact_set and not prefill_native_s5: own = capture_own_kv(mlx_model, prompt_ids) rk: Dict[int, Any] = {} rv: Dict[int, Any] = {} for li in range(n_layers): if src_map[li] != li: continue + if prefill_native_s5 and li in exact_set: + continue if li in exact_set and own is not None and li in own: k_mx, v_mx = own[li] if li in compressors: @@ -255,19 +406,80 @@ def restored_forward(ids: List[int], rk, rv, t_src, *, return_all: bool): ) # ---------- Dataset ---------- - samples: List[NIAHSample] = make_niah_dataset( - n_samples=args.n_samples, - haystack_min_lines=args.haystack_min_lines, - haystack_max_lines=args.haystack_max_lines, - seed=args.seed, - ) + if args.code_prompts: + _CODE = [ + "Write a complete Python implementation of a binary search tree class " + "with insert, search, and in-order traversal methods. Include type " + "hints and docstrings.", + "Implement a Python LRU cache class with get and put methods using an " + "OrderedDict. Include type hints and docstrings.", + "Write a Python function that parses a CSV string into a list of dicts, " + "correctly handling quoted fields and embedded commas. Add error handling.", + "Implement quicksort in Python with an in-place partition helper. " + "Include docstrings and a small example in a __main__ block.", + "Write a Python class for a fixed-capacity ring buffer with push, pop, " + "and is_full methods, raising on overflow. Include type hints.", + "Implement a recursive descent parser in Python for arithmetic " + "expressions with + - * / and parentheses. Return the evaluated value.", + "Write a Python decorator `retry` that retries a function up to n times " + "with exponential backoff on exception. Include type hints and docstring.", + "Implement a thread-safe counter class in Python using threading.Lock, " + "with increment, decrement, and value methods.", + ] + n = min(args.n_samples, len(_CODE)) + samples: List[NIAHSample] = [ + NIAHSample(prompt_text=p, answer_text="", needle_line_index=0, + needle_text="") + for p in _CODE[:n] + ] + print(f"[mac] CODE-PROMPTS workload: {n} prompts (recall N/A; " + f"measuring decode throughput + acceptance)", file=sys.stderr) + else: + samples = make_niah_dataset( + n_samples=args.n_samples, + haystack_min_lines=args.haystack_min_lines, + haystack_max_lines=args.haystack_max_lines, + seed=args.seed, + ) def encode(prompt_text: str) -> List[int]: + if args.direct_answer_prompt: + # The legacy padding repeats "answer" on every line; Gemma4 can + # latch onto that distractor in very short completion smokes. + # Keep the needle/question unchanged but make filler semantically + # neutral so recall measures retrieval of the secret code. + prompt_text = prompt_text.replace( + "and does not contain the answer.", + "and is unrelated filler.", + ) + prompt_text = ( + prompt_text + + "\n\nReturn only the secret code in PREFIX-NNNN format. " + "Do not explain, reason, or add any other text." + ) + if args.direct_answer_prompt and args.raw_completion_prompt: + try: + ids = tokenizer.encode(prompt_text, add_special_tokens=True) + except TypeError: + ids = tokenizer.encode(prompt_text) + if hasattr(ids, "tolist"): + ids = ids.tolist() + return list(ids) msgs = [{"role": "user", "content": prompt_text}] ids = tokenizer.apply_chat_template(msgs, add_generation_prompt=True) if hasattr(ids, "tolist"): ids = ids.tolist() - return list(ids) + ids = list(ids) + if args.direct_answer_prompt and args.content_channel_prefill: + try: + marker = tokenizer.encode( + "<|channel>content\n", add_special_tokens=False) + except TypeError: + marker = tokenizer.encode("<|channel>content\n") + if hasattr(marker, "tolist"): + marker = marker.tolist() + ids.extend(list(marker)) + return ids def encode_answer(answer_text: str) -> List[int]: try: @@ -283,9 +495,54 @@ def encode_answer(answer_text: str) -> List[int]: answer_ids = [encode_answer(s.answer_text) for s in samples] seq_lens = [len(t) for t in sample_ids] eos_id = getattr(tokenizer, "eos_token_id", None) + end_ids = set() + if eos_id is not None: + end_ids.add(int(eos_id)) + try: + eot_ids = tokenizer.encode("", add_special_tokens=False) + except TypeError: + eot_ids = tokenizer.encode("") + if hasattr(eot_ids, "tolist"): + eot_ids = eot_ids.tolist() + if (not args.ignore_turn_stop) and len(eot_ids) == 1: + end_ids.add(int(eot_ids[0])) print(f"[mac] {len(samples)} samples, prompt len " f"min={min(seq_lens)} max={max(seq_lens)}", file=sys.stderr) + def native_prefill(input_ids: List[int]): + """Prefill the verifier cache on MLX's native path, optionally chunked.""" + cache = (getattr(mlx_model, "make_cache", lambda: None)()) + chunk = int(args.prefill_chunk_size) + if chunk <= 0 or len(input_ids) <= chunk: + out = mlx_model(mx.array([input_ids]), cache=cache) + mx.eval(out) + return cache, out[0, -1] + + last = None + for start in range(0, len(input_ids), chunk): + part = input_ids[start:start + chunk] + if not part: + continue + last = mlx_model(mx.array([part]), cache=cache) + mx.eval(last) + if last is None: + last = mlx_model(mx.array([input_ids]), cache=cache) + mx.eval(last) + return cache, last[0, -1] + + def warmup_decode() -> None: + """Warm MLX/Metal decode kernels before comparing cross vs oracle.""" + if args.decode_warmup_tokens <= 0 or not sample_ids: + return + cache, logits = native_prefill(sample_ids[0]) + for _ in range(args.decode_warmup_tokens): + tok = int(mx.argmax(logits).item()) + out = mlx_model(mx.array([[tok]]), cache=cache) + mx.eval(out) + logits = out[0, -1] + if tok in end_ids: + break + def eval_teacher_forced(logits_all_fn) -> Tuple[List[str], List[float], List[int]]: """One restored forward per sample over [prompt + needle-code]; check the argmax at the answer span reproduces the code (substring predicate @@ -310,41 +567,276 @@ def eval_free_gen_cross() -> Tuple[List[str], List[float], List[int]]: """Restored free generation: 1 restored full forward per token (amortized restoration). Correct recall metric; slow on M4.""" decoded, lats, toks = [], [], [] + rows = [] for i, pid in enumerate(sample_ids): + e2e_t0 = time.perf_counter() + build_t0 = time.perf_counter() rk, rv, tsrc = build_restoration(pid) + build_s = time.perf_counter() - build_t0 cur = list(pid); gen: List[int] = [] t0 = time.perf_counter() for _ in range(args.max_new_tokens): last = restored_forward(cur, rk, rv, tsrc, return_all=False) nxt = int(mx.argmax(last).item()); gen.append(nxt) - if eos_id is not None and nxt == eos_id: + if nxt in end_ids: break cur.append(nxt) - lats.append(time.perf_counter() - t0) + decode_s = time.perf_counter() - t0 + e2e_s = time.perf_counter() - e2e_t0 + lats.append(e2e_s) decoded.append(tokenizer.decode(gen)); toks.append(len(gen)) + rows.append({ + "sample": i, + "build_restoration_s": round(build_s, 3), + "decode_s": round(decode_s, 3), + "e2e_s": round(e2e_s, 3), + "restoration_active": True, + "decode_loop": "full_reforward_per_token", + }) print(f"[mac] sample {i}: T={seq_lens[i]} -> {decoded[-1][:48]!r}", file=sys.stderr) + eval_free_gen_cross.stage_rows = rows return decoded, lats, toks - def eval_free_gen_oracle() -> Tuple[List[str], List[float], List[int]]: - """Oracle free generation using mlx's NATIVE incremental KV cache - (fast + correct reference; confirms the metric/dataset).""" + def eval_free_gen_cross_incremental() -> Tuple[List[str], List[float], List[int]]: + """INCREMENTAL restored free generation (MLX port of CUDA Gap-A): + prefill ONCE capturing restored K/V into a persistent cache, then + decode with mlx_lm's native incremental step (O(L)/token). Fixes the + per-token re-forward throughput collapse. Recall via S5 full-attn.""" + decoded, lats, toks = [], [], [] + rows = [] + for i, pid in enumerate(sample_ids): + e2e_t0 = time.perf_counter() + build_t0 = time.perf_counter() + rk, rv, tsrc = build_restoration(pid, prefill_native_s5=True) + build_s = time.perf_counter() - build_t0 + T = len(pid) + evicted = compute_evicted_positions(T, args.sink_size, args.window_size) + prefill_t0 = time.perf_counter() + if not evicted: + cache, first = native_prefill(pid) + else: + cache, first = restored_prefill_cache( + mlx_model, pid, + restored_k_per_layer=_pad(rk, tsrc, T), + restored_v_per_layer=_pad(rv, tsrc, T), + evicted_positions=evicted, + prefill_chunk_size=args.prefill_chunk_size) + prefill_s = time.perf_counter() - prefill_t0 + decode_t0 = time.perf_counter() + gen = restored_incremental_generate( + mlx_model, cache, first, + max_tokens=args.max_new_tokens, + eos_ids=end_ids) + decode_s = time.perf_counter() - decode_t0 + e2e_s = time.perf_counter() - e2e_t0 + lats.append(e2e_s) + decoded.append(tokenizer.decode(gen)); toks.append(len(gen)) + rows.append({ + "sample": i, + "build_restoration_s": round(build_s, 3), + "prefill_s": round(prefill_s, 3), + "decode_s": round(decode_s, 3), + "e2e_s": round(e2e_s, 3), + "restoration_active": True, + "decode_loop": "generate_step", + }) + print(f"[mac] incr {i}: T={seq_lens[i]} " + f"prefill={prefill_s:.1f}s decode={decode_s:.1f}s " + f"-> {decoded[-1][:48]!r}", file=sys.stderr) + eval_free_gen_cross_incremental.stage_rows = rows + return decoded, lats, toks + + def eval_fused_specdecode() -> Tuple[List[str], List[float], List[int]]: + """FUSED DFlash spec-decode (MLX port of #107 A+B+C): drafter context + K/V cache + aux captured from the verify forward + incremental restored + verify with trim_prompt_cache accept/reject. Target: tok/s > AR.""" + argmax_fn = lambda row: int(mx.argmax(row).item()) + active_drafter = mlx_drafter if mlx_drafter is not None else drafter + aux_layer_ids = (tuple(active_drafter.cfg.aux_layer_ids) + if active_drafter is not None else ()) + softcap = None + for obj in (getattr(mlx_model, "language_model", None), mlx_model): + cap = getattr(obj, "final_logit_softcapping", None) if obj is not None else None + if cap: + softcap = float(cap); break + if mlx_drafter is not None: + # All-MLX path (Step-2 rescue): drafter, embed/lm_head, aux + # slices, positions, and concat all stay on the Metal stream — + # zero mx<->torch crossings per block. + from inference_engine.backends.mlx.dflash_drafter import ( + make_native_embed_lm_head, + ) + bridge = None # identity: aux slices stay mx + embed_fn, lm_head_fn = make_native_embed_lm_head( + text_model, softcap=softcap) + arange_fn = lambda s, e: mx.arange(int(s), int(e)) + cat_aux_fn = lambda parts: ( + parts[0][None] if len(parts) == 1 + else mx.concatenate(list(parts), axis=0)[None]) + elif args.force_fused_specdecode: + if drafter is None: + raise RuntimeError("--force-fused-specdecode requires drafter/f_theta") + bridge = lambda a: mx_to_torch(a, dtype=torch.float32, device=dev) + embed_fn, lm_head_fn = make_bridge_embed_lm_head( + text_model, mx_to_torch=mx_to_torch, torch_to_mx=torch_to_mx, + device=dev, torch_dtype=torch.float32, softcap=softcap) + arange_fn = lambda s, e: torch.arange(int(s), int(e), device=dev) + cat_aux_fn = lambda parts: torch.cat(list(parts), dim=0).unsqueeze(0) + else: + bridge = lambda a: mx_to_torch(a, dtype=torch.float32, device=dev) + embed_fn = lm_head_fn = arange_fn = cat_aux_fn = None + adapter = MLXRestoredIncrementalVerifier( + mlx_model, embed_scale=embed_scale, aux_layer_ids=aux_layer_ids, + bridge_to_torch=bridge) + decoded, lats, toks = [], [], [] - make_cache = getattr(mlx_model, "make_cache", None) + rows = [] for i, pid in enumerate(sample_ids): - cache = make_cache() if make_cache is not None else None + e2e_t0 = time.perf_counter() + build_t0 = time.perf_counter() + if adaptive_s5_native: + rk, rv, tsrc = {}, {}, len(pid) + else: + rk, rv, tsrc = build_restoration(pid, prefill_native_s5=True) + build_s = time.perf_counter() - build_t0 + T = len(pid) + evicted = compute_evicted_positions(T, args.sink_size, args.window_size) + aux_t0 = time.perf_counter() + if args.force_fused_specdecode: + aux_prompt_mx = capture_aux_hidden( + mlx_model, pid, aux_layer_ids, embed_scale=embed_scale) + if bridge is None: + aux_prompt = aux_prompt_mx # all-MLX: stay on Metal + else: + aux_prompt = [bridge(a) for a in aux_prompt_mx] # [1,C,H] torch + else: + aux_prompt = [] + aux_s = time.perf_counter() - aux_t0 + prefill_t0 = time.perf_counter() + if adaptive_s5_native: + cache, first = native_prefill(pid) + adapter._cache = cache + adapter.next_token_logits = first + adapter._past_len = len(pid) + else: + adapter.prefill( + pid, + restored_k_per_layer=_pad(rk, tsrc, T), + restored_v_per_layer=_pad(rv, tsrc, T), + evicted_positions=evicted, + prefill_chunk_size=args.prefill_chunk_size, + full_kv=args.cuda_trim) + prefill_s = time.perf_counter() - prefill_t0 t0 = time.perf_counter() - out = mlx_model(mx.array([pid]), cache=cache); mx.eval(out) - tok = int(mx.argmax(out[0, -1]).item()); gen = [tok] - for _ in range(args.max_new_tokens - 1): - if eos_id is not None and tok == eos_id: - break - out = mlx_model(mx.array([[tok]]), cache=cache); mx.eval(out) - tok = int(mx.argmax(out[0, -1]).item()); gen.append(tok) - lats.append(time.perf_counter() - t0) + if args.force_fused_specdecode: + if mlx_drafter is not None and args.cuda_trim: + # CUDA-parity: keep accepted K/V, trim only rejected. + res = fused_specdecode_generate_mlx_trim( + adapter, active_drafter, aux_prompt=aux_prompt, + embed_fn=embed_fn, lm_head_fn=lm_head_fn, + gen_tokens=args.max_new_tokens, + block_size=args.block_size, eos_ids=end_ids, + single_fused=args.single_fused) + elif mlx_drafter is not None: + # Single-sync all-MLX loop (levers ①②③) + v3 carry rollback. + res = fused_specdecode_generate_mlx( + adapter, active_drafter, aux_prompt=aux_prompt, + embed_fn=embed_fn, lm_head_fn=lm_head_fn, + gen_tokens=args.max_new_tokens, + block_size=args.block_size, eos_ids=end_ids) + else: + res = fused_specdecode_generate( + adapter, active_drafter, aux_prompt=aux_prompt, + embed_fn=embed_fn, lm_head_fn=lm_head_fn, + gen_tokens=args.max_new_tokens, block_size=args.block_size, + eos_ids=end_ids, + argmax_fn=argmax_fn, arange_fn=arange_fn, cat_aux_fn=cat_aux_fn, + allow_greedy_fallback=False) + res["drafter_runtime"] = "mlx" if mlx_drafter is not None else "torch" + else: + t_greedy = time.perf_counter() + adapter._capture_aux = False + gen = [] + logits_row = adapter.next_token_logits + while len(gen) < args.max_new_tokens: + tok = int(argmax_fn(logits_row)) + gen.append(tok) + out = mlx_model(mx.array([[tok]]), cache=adapter._cache) + mx.eval(out) + logits_row = out[0, -1] + if tok in end_ids: + break + adapter.next_token_logits = logits_row + res = { + "tokens": gen, + "blocks": 0, + "mean_accept_len": 0.0, + "decode_tokens": len(gen), + "adaptive_mode": "native_ar_baseline", + "time_breakdown_s": { + "greedy_decode_s": round(time.perf_counter() - t_greedy, 3) + }, + } + decode_s = time.perf_counter() - t0 + e2e_s = time.perf_counter() - e2e_t0 + lats.append(e2e_s) + gen = res["tokens"] decoded.append(tokenizer.decode(gen)); toks.append(len(gen)) + rows.append({ + "sample": i, + "build_restoration_s": round(build_s, 3), + "aux_prompt_capture_s": round(aux_s, 3), + "prefill_s": round(prefill_s, 3), + "decode_s": round(decode_s, 3), + "e2e_s": round(e2e_s, 3), + "restoration_active": not adaptive_s5_native, + "decode_loop": ("fused_specdecode" if args.force_fused_specdecode + else "per_token_eval"), + "fused": res, + }) + print(f"[mac] fused {i}: T={seq_lens[i]} acc_len={res['mean_accept_len']} " + f"-> {decoded[-1][:48]!r}", file=sys.stderr) + eval_fused_specdecode.stage_rows = rows + return decoded, lats, toks + + def eval_free_gen_oracle() -> Tuple[List[str], List[float], List[int]]: + """Oracle free generation using mlx's NATIVE incremental KV cache. + + Decodes via ``restored_incremental_generate`` (mlx_lm + ``generate_step``: chunked + async-pipelined) — the SAME decode + primitive as the cross incremental path. The previous hand-rolled + per-token ``mx.eval`` loop is the documented MLX anti-pattern + (docs/mlx-port-lessons.md) and depressed the baseline; the gate + rule SPEEDUP_ORACLE_LOOP rejects headline speedups measured + against it. + """ + decoded, lats, toks = [], [], [] + rows = [] + for i, pid in enumerate(sample_ids): + e2e_t0 = time.perf_counter() + prefill_t0 = time.perf_counter() + cache, logits = native_prefill(pid) + prefill_s = time.perf_counter() - prefill_t0 + decode_t0 = time.perf_counter() + gen = restored_incremental_generate( + mlx_model, cache, logits, + max_tokens=args.max_new_tokens, + eos_ids=end_ids) + decode_s = time.perf_counter() - decode_t0 + e2e_s = time.perf_counter() - e2e_t0 + lats.append(e2e_s) + decoded.append(tokenizer.decode(gen)); toks.append(len(gen)) + rows.append({ + "sample": i, + "prefill_s": round(prefill_s, 3), + "decode_s": round(decode_s, 3), + "e2e_s": round(e2e_s, 3), + "decode_loop": "generate_step", + }) print(f"[mac] oracle {i}: T={seq_lens[i]} -> {decoded[-1][:48]!r}", file=sys.stderr) + eval_free_gen_oracle.stage_rows = rows return decoded, lats, toks def cross_logits_all(prompt_ids, full_ids): @@ -356,18 +848,42 @@ def oracle_logits_all(prompt_ids, full_ids): label = "identity" if args.identity_restore else ( "s5" if args.s5_exact_full_attn else "f_theta_all") - eval_mode = "teacher_forced" if args.teacher_forced else "free_gen" + eval_mode = ("teacher_forced" if args.teacher_forced + else "native_ar_baseline" if adaptive_s5_native + else "free_gen_fused_specdecode" if args.fused_specdecode + else "free_gen_incremental" if args.incremental else "free_gen") + warmup_decode() print(f"[mac] running restored cross-model verifier ({label}, {eval_mode})", file=sys.stderr, flush=True) if args.teacher_forced: cross_dec, cross_lat, cross_tok = eval_teacher_forced(cross_logits_all) + # Diagnostic mode: one restored forward per sample; synthesize the + # per-sample path-identity rows the evidence gate requires. + cross_rows = [ + {"sample": i, "restoration_active": True, + "decode_loop": "teacher_forced_single_forward"} + for i in range(len(sample_ids)) + ] + elif args.fused_specdecode: + cross_dec, cross_lat, cross_tok = eval_fused_specdecode() + cross_rows = getattr(eval_fused_specdecode, "stage_rows", []) + elif args.incremental: + cross_dec, cross_lat, cross_tok = eval_free_gen_cross_incremental() + cross_rows = getattr(eval_free_gen_cross_incremental, "stage_rows", []) else: cross_dec, cross_lat, cross_tok = eval_free_gen_cross() - cross_res = aggregate_recall("k3_cross_model_mac", samples, cross_dec, cross_lat, cross_tok) - print(f"[mac] cross-model recall = {cross_res.recall:.3f} " + cross_rows = getattr(eval_free_gen_cross, "stage_rows", []) + sut_label = (NATIVE_BASELINE_LABEL if adaptive_s5_native + else "restored_cross_model") + cross_name = ("native_ar_baseline_mac" if adaptive_s5_native + else "k3_cross_model_mac") + cross_res = aggregate_recall(cross_name, samples, cross_dec, cross_lat, cross_tok) + print(f"[mac] {sut_label} recall = {cross_res.recall:.3f} " f"({cross_res.samples_correct}/{cross_res.samples_total})", file=sys.stderr) oracle_res = None + o_lat: List[float] = [] + o_tok: List[int] = [] if not args.skip_oracle: print("[mac] running oracle", file=sys.stderr, flush=True) if args.teacher_forced: @@ -404,19 +920,95 @@ def _tps(lats, toks): "mean_latency_per_sample_s": round(tot_t / max(len(lats), 1), 3), } cross_tps = _tps(cross_lat, cross_tok) + cross_tps["timing_scope"] = "e2e_prefill_plus_decode" cross_tps["eval_mode"] = eval_mode cross_tps["restored_forwards_per_sample"] = ( 1 if args.teacher_forced else args.max_new_tokens) + cross_tps["stage_timings"] = cross_rows + oracle_rows = getattr(eval_free_gen_oracle, "stage_rows", []) + oracle_tps = _tps(o_lat, o_tok) if (o_lat and o_tok) else None + if oracle_tps: + oracle_tps["timing_scope"] = "e2e_prefill_plus_decode" + oracle_tps["stage_timings"] = oracle_rows + if oracle_rows and not args.teacher_forced: + oracle_tps["decode_loop"] = "generate_step" + + # ---------- Headline speedup: emitted ONLY when admissible ---------- + # (k3_report_gate SPEEDUP_* rules; the harness withholds the number + # rather than publish a claim its own gate would reject.) + import statistics as _stats + decode_only = decode_only_block(cross_rows, cross_tok, oracle_rows, o_tok) + speedup_withheld: List[str] = [] + if adaptive_s5_native: + speedup_withheld.append( + "native_baseline_self_comparison: cross arm IS the oracle computation") + if oracle_tps is None: + speedup_withheld.append("no_oracle_arm") + else: + if len(cross_rows) < MIN_PERF_SAMPLES or len(oracle_rows) < MIN_PERF_SAMPLES: + speedup_withheld.append( + f"n_samples<{MIN_PERF_SAMPLES} (cross={len(cross_rows)}, " + f"oracle={len(oracle_rows)})") + tok_medians = [ + _stats.median(t) for t in (cross_tok, o_tok) if t + ] + if len(tok_medians) < 2 or min(tok_medians) < MIN_MEDIAN_DECODE_TOKENS: + speedup_withheld.append( + f"median_decode_tokens<{MIN_MEDIAN_DECODE_TOKENS} " + f"(prefill-dominated wall time)") + if decode_only is None: + speedup_withheld.append("decode_only_medians_unavailable") + if oracle_tps.get("decode_loop") != CLAIM_ORACLE_DECODE_LOOP: + speedup_withheld.append( + f"oracle_decode_loop!={CLAIM_ORACLE_DECODE_LOOP}") + for arm_name, arm_rows in (("cross", cross_rows), ("oracle", oracle_rows)): + spread = prefill_spread(arm_rows) + if spread is not None and spread > MAX_PREFILL_SPREAD: + speedup_withheld.append( + f"{arm_name}_prefill_spread {spread:.2f}x > " + f"{MAX_PREFILL_SPREAD}x (e2e ratio would be noise)") + speedup_vs_oracle = None + if not speedup_withheld and oracle_tps and oracle_tps["tokens_per_second"] \ + and cross_tps["tokens_per_second"]: + speedup_vs_oracle = round( + cross_tps["tokens_per_second"] / oracle_tps["tokens_per_second"], 3) + if speedup_withheld: + print("[mac] speedup WITHHELD (evidence gate): " + + "; ".join(speedup_withheld), file=sys.stderr) print(f"[mac] cross-model throughput ({eval_mode}): " f"{cross_tps['tokens_per_second']} tok/s " f"({cross_tps['tokens']} tok / {cross_tps['wall_seconds']} s, " f"{cross_tps['mean_latency_per_sample_s']} s/sample)", file=sys.stderr) - delta = (abs(cross_res.recall - oracle_res.recall) if oracle_res else None) + # ---------- Measured (not analytical) accelerator memory ---------- + def _mx_peak_mb() -> Optional[float]: + for holder in (mx, getattr(mx, "metal", None)): + fn = getattr(holder, "get_peak_memory", None) if holder is not None else None + if callable(fn): + try: + return round(float(fn()) / 1e6, 1) + except Exception: + return None + return None + + # The analytical sink+window table only describes runs where every + # cross sample actually executed restoration with S5/identity exact + # layers (k3_report_gate MEMORY_CLAIM_MISMATCH). + restoration_all_active = bool(cross_rows) and all( + bool(r.get("restoration_active")) for r in cross_rows) + formula_matches_run = bool( + restoration_all_active + and (args.s5_exact_full_attn or args.identity_restore)) + + delta = (abs(cross_res.recall - oracle_res.recall) + if (oracle_res and not adaptive_s5_native) else None) report = { - "schema_version": 1, + "schema_version": 2, "kind": "k3_integrated_niah_acceptance_mac", "config": { + "native_baseline_bypass": bool(args.native_baseline_bypass), + "all_mlx_drafter": bool(args.all_mlx_drafter), + "block_size": args.block_size, "verifier_path": args.verifier_path, "drafter_id": args.drafter_id, "f_theta_dir": args.f_theta_dir, @@ -426,6 +1018,10 @@ def _tps(lats, toks): "haystack_min_lines": args.haystack_min_lines, "haystack_max_lines": args.haystack_max_lines, "max_new_tokens": args.max_new_tokens, + "prefill_chunk_size": args.prefill_chunk_size, + "decode_warmup_tokens": args.decode_warmup_tokens, + "direct_answer_prompt": bool(args.direct_answer_prompt), + "content_channel_prefill": bool(args.content_channel_prefill), "seed": args.seed, "eval_mode": eval_mode, "teacher_forced": bool(args.teacher_forced), @@ -439,34 +1035,71 @@ def _tps(lats, toks): "prompt_token_lens": seq_lens, }, "results": { - "k3_cross_model": dataclasses.asdict(cross_res), + "k3_cross_model": { + **dataclasses.asdict(cross_res), + "system_under_test": sut_label, + }, **({"oracle": dataclasses.asdict(oracle_res)} if oracle_res else {}), }, "gate": { - "recall_cross_model": cross_res.recall, + # A native-baseline run may not claim cross-model recall + # (k3_report_gate BASELINE_RECALL_CLAIM): its recall is the + # oracle's recall by construction. + "recall_cross_model": (None if adaptive_s5_native else cross_res.recall), + "recall_native_baseline": (cross_res.recall if adaptive_s5_native else None), "recall_oracle": oracle_res.recall if oracle_res else None, "recall_delta_vs_oracle_pp": (delta * 100 if delta is not None else None), "recall_delta_within_5pp": (delta is not None and delta <= 0.05), }, "memory": { - "s5": mem_s5, + "s5": { + **mem_s5, + "scope": "analytical_formula", + "formula_matches_run": formula_matches_run, + }, "naive_full_kv": { "total_resident_mb": mem_naive["total_resident_mb"], "per_token_growth_kb": mem_naive["per_token_growth_kb"], }, - "savings_vs_naive_pct": round( + # Savings only claimable when the formula describes the run. + "savings_vs_naive_pct": (round( 100 * (1 - mem_s5["total_resident_bytes"] - / max(mem_naive["total_resident_bytes"], 1)), 1), + / max(mem_naive["total_resident_bytes"], 1)), 1) + if formula_matches_run else None), + "measured_peak_mb": _mx_peak_mb(), + }, + "throughput": { + "k3_cross_model": cross_tps, + **({"oracle_native_ar": oracle_tps} if oracle_tps else {}), + "decode_only": decode_only, + "cross_model_speedup_vs_oracle_ar": speedup_vs_oracle, + "speedup_withheld_reasons": speedup_withheld or None, }, - "throughput": {"k3_cross_model": cross_tps}, } + + # ---------- Evidence gate: the harness validates its own output ---------- + violations = validate_report(report) + report["gate"]["evidence_violations"] = [ + dataclasses.asdict(v) for v in violations + ] out_path = Path(args.output) if args.output else Path( f"results/research/k3_integrated_niah_mac_{int(time.time())}.json") out_path.parent.mkdir(parents=True, exist_ok=True) out_path.write_text(json.dumps(report, indent=2)) - print(f"\n[mac] DONE. cross={cross_res.recall:.3f} " + print(f"\n[mac] DONE. {sut_label}={cross_res.recall:.3f} " f"oracle={oracle_res.recall if oracle_res else 'skipped'} " f"-> {out_path}", file=sys.stderr) + if violations and args.code_prompts: + print("[mac] code-prompts throughput probe: recall is N/A by design; " + "evidence gate informational only (not aborting):\n" + + summarize_violations(violations), file=sys.stderr) + elif violations: + print("[mac] EVIDENCE GATE FAILED — this report is NOT admissible " + "as evidence:\n" + summarize_violations(violations), + file=sys.stderr) + return 2 + else: + print("[mac] evidence gate: PASS", file=sys.stderr) return 0 diff --git a/scripts/research/k3_kv_quant_eval.py b/scripts/research/k3_kv_quant_eval.py new file mode 100644 index 00000000..aa53bf72 --- /dev/null +++ b/scripts/research/k3_kv_quant_eval.py @@ -0,0 +1,326 @@ +"""KV-quantization rate–distortion shoot-out: affine (mlx-native) vs KakeyaLattice. + +Decision input for "is an MLX port of the KakeyaLattice codec worth it?" +(docs/mlx-port-lessons.md, K2 track): KL's value proposition over plain +affine quantization is better rate–distortion at equal bits. This eval +measures exactly that, on the only K/V that matter for the S5 memory +story — the 5 full-attention layers' exact own K/V (the 20 KB/token +linear term) — at ctx280 scale, with REAL recall as the end metric. + +Arms (same captured K/V per sample, identical injection machinery): + + identity — lossless round trip (machinery control; recall must match + the S5 baseline) + affine8 — mx.quantize/dequantize, 8-bit, group 64 (the storage + format of mlx_lm's QuantizedKVCache; ~8.5 bits/value) + affine4 — same, 4-bit (~4.5 bits/value) + kl-d4 — KakeyaLattice D4 round trip (torch codec, eval-time only) + kl-e8 — KakeyaLattice E8 round trip + +Per arm and sample: bits/value (measured), energy-weighted rel_mse of +the lossy full-attn K/V vs the originals, then a REAL incremental +restored decode (lossy K/V injected at the evicted positions, sliding +layers native window-bounded) → NIAH recall. + +Scope note: this measures STORAGE fidelity at matched rate. Decode in +every arm runs on bf16-materialised K/V, so per-arm decode timing is +not a codec-throughput claim (runtime decompression cost is a separate +question that only matters for codecs that win here). + +Verdict rule printed at the end: KL justifies an MLX port only if, at +bits <= affine4's rate, it achieves BOTH lower full-attn rel_mse AND +recall >= affine4's. Otherwise native affine quantization wins by +default (zero porting cost, kernel-fused dequant). + +Run on the Mac via bridge preset ``k3-kv-quant-eval`` or directly: + + PYTHONPATH=.:sdks/python python3 scripts/research/k3_kv_quant_eval.py \ + --verifier-path --n-samples 5 --max-new-tokens 32 +""" + +from __future__ import annotations + +import argparse +import json +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Tuple + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--verifier-path", default="models/gemma-4-26B-A4B-it-mlx-4bit") + ap.add_argument("--n-samples", type=int, default=5) + ap.add_argument("--haystack-min-lines", type=int, default=238) + ap.add_argument("--haystack-max-lines", type=int, default=322) + ap.add_argument("--sink-size", type=int, default=4) + ap.add_argument("--window-size", type=int, default=64) + ap.add_argument("--max-new-tokens", type=int, default=32) + ap.add_argument("--prefill-chunk-size", type=int, default=512) + ap.add_argument("--kl-q-range", type=int, default=38) + ap.add_argument("--seed", type=int, default=42) + ap.add_argument("--skip-oracle", action="store_true") + ap.add_argument("--output", default="results/research/k3_kv_quant_eval.json") + args = ap.parse_args() + + import mlx.core as mx # type: ignore + import mlx_lm # type: ignore + import torch + + from inference_engine.backends.mlx.cross_model_dlm_verifier import ( + capture_own_kv, + mlx_full_attention_layer_indices, + resolve_mlx_text_model, + restored_incremental_generate, + restored_prefill_cache, + ) + from inference_engine.v04 import NIAHSample, aggregate_recall, make_niah_dataset + from inference_engine.v04.kv_compressor import make_default_compressor + from inference_engine.v04.kv_merge import compute_evicted_positions + from scripts.research.k3_dflash_mlx_bridge import mx_to_torch, torch_to_mx + + print(f"[kvq] loading MLX verifier {args.verifier_path}", file=sys.stderr) + mlx_model, tokenizer = mlx_lm.load(args.verifier_path) + text_model = resolve_mlx_text_model(mlx_model) + full_attn_idx = mlx_full_attention_layer_indices(text_model) + print(f"[kvq] full-attn layers: {full_attn_idx}", file=sys.stderr) + + # ---------- arms ---------- + GROUP = 64 + + def affine_roundtrip(bits: int): + def fn(k: Any, v: Any) -> Tuple[Any, Any, float]: + outs = [] + for a in (k, v): + shp = a.shape + flat = a.reshape(-1, shp[-1]).astype(mx.float16) + wq, scales, biases = mx.quantize(flat, group_size=GROUP, bits=bits) + deq = mx.dequantize( + wq, scales, biases, group_size=GROUP, bits=bits) + outs.append(deq.reshape(shp).astype(a.dtype)) + # bits/value: payload + fp16 scale & bias per group. + rate = bits + 2 * 16.0 / GROUP + return outs[0], outs[1], rate + return fn + + def kl_roundtrip(lattice: str): + comps: Dict[int, Any] = {} + + def fn(k: Any, v: Any, *, head_dim: int, layer: int) -> Tuple[Any, Any, float]: + comp = comps.get(layer) + if comp is None: + comp = make_default_compressor( + head_dim=head_dim, device=torch.device("cpu"), + prefer_kakeya=True, lattice=lattice, + q_range=args.kl_q_range) + comps[layer] = comp + kt = mx_to_torch(k, dtype=torch.float32, device="cpu").transpose(1, 2).contiguous() + vt = mx_to_torch(v, dtype=torch.float32, device="cpu").transpose(1, 2).contiguous() + T = kt.shape[-2] + pos = torch.arange(T) + comp.compress(kt, vt, pos) + kh, vh = comp.decompress(pos) + comp.evict(pos) + codec = getattr(comp, "_codec", None) + bits_head = float(getattr(codec, "bits_per_token_per_head", 0.0) or 0.0) + rate = bits_head / head_dim if bits_head else float("nan") + kh = torch_to_mx(kh.transpose(1, 2).contiguous()) + vh = torch_to_mx(vh.transpose(1, 2).contiguous()) + return kh, vh, rate + return fn + + ARMS: List[Tuple[str, Any]] = [ + ("identity", None), + ("affine8", affine_roundtrip(8)), + ("affine4", affine_roundtrip(4)), + ("kl-d4", kl_roundtrip("D4")), + ("kl-e8", kl_roundtrip("E8")), + ] + + # ---------- dataset / prompts (mirrors the integrated harness) ---------- + samples: List[NIAHSample] = make_niah_dataset( + n_samples=args.n_samples, + haystack_min_lines=args.haystack_min_lines, + haystack_max_lines=args.haystack_max_lines, + seed=args.seed, + ) + + def encode(prompt_text: str) -> List[int]: + prompt_text = prompt_text.replace( + "and does not contain the answer.", "and is unrelated filler.") + prompt_text += ("\n\nReturn only the secret code in PREFIX-NNNN " + "format. Do not explain, reason, or add any other text.") + ids = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt_text}], + add_generation_prompt=True) + if hasattr(ids, "tolist"): + ids = ids.tolist() + ids = list(ids) + try: + marker = tokenizer.encode( + "<|channel>content\n", add_special_tokens=False) + except TypeError: + marker = tokenizer.encode("<|channel>content\n") + if hasattr(marker, "tolist"): + marker = marker.tolist() + ids.extend(list(marker)) + return ids + + sample_ids = [encode(s.prompt_text) for s in samples] + seq_lens = [len(t) for t in sample_ids] + end_ids = set() + if getattr(tokenizer, "eos_token_id", None) is not None: + end_ids.add(int(tokenizer.eos_token_id)) + try: + eot = tokenizer.encode("", add_special_tokens=False) + except TypeError: + eot = tokenizer.encode("") + if hasattr(eot, "tolist"): + eot = eot.tolist() + if len(eot) == 1: + end_ids.add(int(eot[0])) + print(f"[kvq] {len(samples)} samples, prompt len {min(seq_lens)}..{max(seq_lens)}", + file=sys.stderr) + + def rel_mse(orig_k, orig_v, lossy_k, lossy_v) -> float: + num = den = 0.0 + for o, l in ((orig_k, lossy_k), (orig_v, lossy_v)): + o32 = o.astype(mx.float32) + d = l.astype(mx.float32) - o32 + num += float(mx.sum(d * d)) + den += float(mx.sum(o32 * o32)) + return num / max(den, 1e-12) + + # ---------- per-sample capture, per-arm roundtrip + decode ---------- + per_arm: Dict[str, Dict[str, Any]] = { + name: {"rel_mse": [], "bits": [], "decoded": [], "lat": [], "tok": []} + for name, _ in ARMS + } + oracle_decoded: List[str] = [] + oracle_lat: List[float] = [] + oracle_tok: List[int] = [] + + for i, pid in enumerate(sample_ids): + T = len(pid) + evicted = compute_evicted_positions(T, args.sink_size, args.window_size) + t0 = time.perf_counter() + own = capture_own_kv(mlx_model, pid) + print(f"[kvq] s{i}: T={T} capture {time.perf_counter()-t0:.1f}s", + file=sys.stderr) + + if not args.skip_oracle: + from mlx_lm.models.cache import make_prompt_cache # noqa: F401 + cache = (getattr(mlx_model, "make_cache", lambda: None)()) + last = None + for s in range(0, T, args.prefill_chunk_size): + part = pid[s:s + args.prefill_chunk_size] + last = mlx_model(mx.array([part]), cache=cache) + mx.eval(last) + t0 = time.perf_counter() + gen = restored_incremental_generate( + mlx_model, cache, last[0, -1], + max_tokens=args.max_new_tokens, eos_ids=end_ids) + oracle_lat.append(time.perf_counter() - t0) + oracle_decoded.append(tokenizer.decode(gen)) + oracle_tok.append(len(gen)) + + for name, roundtrip in ARMS: + rk: Dict[int, Any] = {} + rv: Dict[int, Any] = {} + mses: List[float] = [] + rates: List[float] = [] + for li in full_attn_idx: + k, v = own[li] + if roundtrip is None: + lk, lv, rate = k, v, 16.0 + elif name.startswith("kl"): + lk, lv, rate = roundtrip( + k, v, head_dim=int(k.shape[-1]), layer=li) + else: + lk, lv, rate = roundtrip(k, v) + rk[li], rv[li] = lk, lv + mses.append(rel_mse(k, v, lk, lv)) + rates.append(rate) + arm_mse = sum(mses) / len(mses) + t0 = time.perf_counter() + cache, first = restored_prefill_cache( + mlx_model, pid, + restored_k_per_layer=rk, restored_v_per_layer=rv, + evicted_positions=evicted, + prefill_chunk_size=args.prefill_chunk_size) + gen = restored_incremental_generate( + mlx_model, cache, first, + max_tokens=args.max_new_tokens, eos_ids=end_ids) + elapsed = time.perf_counter() - t0 + per_arm[name]["rel_mse"].append(arm_mse) + per_arm[name]["bits"].append(sum(rates) / len(rates)) + per_arm[name]["decoded"].append(tokenizer.decode(gen)) + per_arm[name]["lat"].append(elapsed) + per_arm[name]["tok"].append(len(gen)) + print(f"[kvq] s{i} {name}: bits/val={sum(rates)/len(rates):.2f} " + f"rel_mse={arm_mse:.5f} -> {per_arm[name]['decoded'][-1][:32]!r}", + file=sys.stderr) + + # ---------- aggregate ---------- + results: Dict[str, Any] = {} + for name, _ in ARMS: + a = per_arm[name] + rec = aggregate_recall( + f"kvq_{name}", samples, a["decoded"], a["lat"], a["tok"]) + results[name] = { + "recall": rec.recall, + "samples_correct": rec.samples_correct, + "bits_per_value_mean": round(sum(a["bits"]) / len(a["bits"]), 3), + "full_attn_rel_mse_mean": round( + sum(a["rel_mse"]) / len(a["rel_mse"]), 6), + "per_sample_decoded": [d[:48] for d in a["decoded"]], + } + oracle_recall = None + if oracle_decoded: + orec = aggregate_recall( + "kvq_oracle", samples, oracle_decoded, oracle_lat, oracle_tok) + oracle_recall = orec.recall + + # ---------- verdict ---------- + aff4 = results["affine4"] + verdicts = {} + for klname in ("kl-d4", "kl-e8"): + kl = results[klname] + verdicts[klname] = { + "bits_le_affine4": kl["bits_per_value_mean"] <= aff4["bits_per_value_mean"] + 0.25, + "rel_mse_better": kl["full_attn_rel_mse_mean"] < aff4["full_attn_rel_mse_mean"], + "recall_ge_affine4": kl["recall"] >= aff4["recall"], + } + verdicts[klname]["mlx_port_justified"] = all(verdicts[klname].values()) + + report = { + "kind": "k3_kv_quant_eval", + "schema_version": 1, + "config": vars(args), + "full_attention_layers": full_attn_idx, + "prompt_token_lens": seq_lens, + "oracle_recall": oracle_recall, + "results": results, + "verdict": verdicts, + } + out = Path(args.output) + out.parent.mkdir(parents=True, exist_ok=True) + out.write_text(json.dumps(report, indent=2)) + print(f"[kvq] DONE -> {out}", file=sys.stderr) + for name in results: + r = results[name] + print(f"[kvq] {name:9s} bits={r['bits_per_value_mean']:6.2f} " + f"rel_mse={r['full_attn_rel_mse_mean']:.5f} " + f"recall={r['recall']:.2f}", file=sys.stderr) + print(f"[kvq] verdict: {json.dumps(verdicts)}", file=sys.stderr) + # Machinery sanity: identity arm must not lose recall. + if results["identity"]["recall"] < (oracle_recall or 1.0): + print("[kvq] WARNING: identity arm below oracle — injection " + "machinery issue, codec comparisons unreliable", file=sys.stderr) + return 1 + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/research/k3_mlx_drafter_parity.py b/scripts/research/k3_mlx_drafter_parity.py new file mode 100644 index 00000000..a9b5d2a6 --- /dev/null +++ b/scripts/research/k3_mlx_drafter_parity.py @@ -0,0 +1,197 @@ +"""Parity gate: all-MLX DFlash drafter vs the torch reference (Apple Silicon). + +Before the all-MLX drafter may carry any throughput claim, it must draft +the SAME tokens as the validated torch implementation on real inputs: +real verifier aux hidden (captured from the MLX Gemma-4 forward), real +shared embed/lm_head, several context lengths and blocks. + +Procedure per sample: + 1. Build a NIAH prompt (same generator as the integrated eval; seed + offset so this never reuses eval prompts). + 2. Capture aux hidden over the prompt from the MLX verifier (component + A machinery, ``capture_aux_hidden``). + 3. Both drafters build their context K/V from the SAME aux (torch gets + the bridged copy), then draft ``--n-blocks`` consecutive blocks with + the same bonus token (verifier greedy next token). + 4. Compare drafted token ids position-by-position. + +Gate: token agreement must be >= --min-agreement (default 1.0 — exact). +The report JSON records per-block tokens from both runtimes so any +mismatch is directly inspectable. + +Run on the Mac via the bridge preset ``k3-drafter-parity`` or directly: + + PYTHONPATH=.:sdks/python python3 scripts/research/k3_mlx_drafter_parity.py \ + --verifier-path --drafter-id z-lab/gemma-4-26B-A4B-it-DFlash \ + --n-samples 3 --n-blocks 4 --block-size 8 +""" + +from __future__ import annotations + +import argparse +import json +import sys +import time +from pathlib import Path +from typing import List + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--verifier-path", default="models/gemma-4-26B-A4B-it-mlx-4bit") + ap.add_argument("--drafter-id", default="z-lab/gemma-4-26B-A4B-it-DFlash") + ap.add_argument("--n-samples", type=int, default=3) + ap.add_argument("--n-blocks", type=int, default=4) + ap.add_argument("--block-size", type=int, default=8) + ap.add_argument("--haystack-min-lines", type=int, default=20) + ap.add_argument("--haystack-max-lines", type=int, default=40) + ap.add_argument("--seed", type=int, default=1234) + ap.add_argument("--min-agreement", type=float, default=1.0) + ap.add_argument("--mlx-dtype", choices=["bf16", "fp32"], default="bf16", + help="MLX drafter compute dtype. fp32 matches the torch " + "reference exactly (port-bug discriminator); bf16 " + "is the shipping config (near-tie argmax flips vs " + "fp32 are expected and correctness-contained).") + ap.add_argument("--output", + default="results/research/k3_mlx_drafter_parity.json") + args = ap.parse_args() + + import mlx.core as mx # type: ignore + import mlx_lm # type: ignore + import torch + + from inference_engine.backends.mlx.cross_model_dlm_verifier import ( + resolve_mlx_text_model, + ) + from inference_engine.backends.mlx.dflash_drafter import ( + MLXDFlashDrafter, make_native_embed_lm_head, + ) + from inference_engine.backends.mlx.fused_specdecode import ( + capture_aux_hidden, make_bridge_embed_lm_head, + ) + from inference_engine.v04 import DFlashDrafter, make_niah_dataset + from scripts.research.k3_dflash_mlx_bridge import mx_to_torch, torch_to_mx + + print(f"[parity] loading MLX verifier {args.verifier_path}", file=sys.stderr) + mlx_model, tokenizer = mlx_lm.load(args.verifier_path) + text_model = resolve_mlx_text_model(mlx_model) + embed_scale = float(getattr(text_model, "embed_scale", 1.0)) + softcap = None + for obj in (getattr(mlx_model, "language_model", None), mlx_model): + cap = getattr(obj, "final_logit_softcapping", None) if obj is not None else None + if cap: + softcap = float(cap); break + + print(f"[parity] loading torch drafter {args.drafter_id}", file=sys.stderr) + t_drafter = DFlashDrafter.from_pretrained(args.drafter_id, dtype=torch.float32) + t_drafter = t_drafter.to("cpu").eval() + print(f"[parity] loading MLX drafter {args.drafter_id} " + f"({args.mlx_dtype})", file=sys.stderr) + m_drafter = MLXDFlashDrafter.from_pretrained( + args.drafter_id, compute_dtype=args.mlx_dtype) + aux_ids = tuple(m_drafter.cfg.aux_layer_ids) + + m_embed, m_head = make_native_embed_lm_head(text_model, softcap=softcap) + t_embed, t_head = make_bridge_embed_lm_head( + text_model, mx_to_torch=mx_to_torch, torch_to_mx=torch_to_mx, + device=torch.device("cpu"), torch_dtype=torch.float32, softcap=softcap) + + samples = make_niah_dataset( + n_samples=args.n_samples, + haystack_min_lines=args.haystack_min_lines, + haystack_max_lines=args.haystack_max_lines, + seed=args.seed, + ) + + rows = [] + agree = total = 0 + for i, sample in enumerate(samples): + ids = tokenizer.apply_chat_template( + [{"role": "user", "content": sample.prompt_text}], + add_generation_prompt=True) + if hasattr(ids, "tolist"): + ids = ids.tolist() + ids = list(ids) + C = len(ids) + aux_mx = capture_aux_hidden(mlx_model, ids, aux_ids, embed_scale=embed_scale) + aux_t = [mx_to_torch(a, dtype=torch.float32, device="cpu") for a in aux_mx] + + # Bonus = verifier greedy next token over the prompt. + out = mlx_model(mx.array([ids])); mx.eval(out) + bonus = int(mx.argmax(out[0, -1]).item()) + + m_ctx = m_drafter.make_context_kv(aux_mx, mx.arange(0, C)) + t_ctx = t_drafter.make_context_kv(aux_t, torch.arange(0, C)) + + sample_row: dict = {"sample": i, "context_len": C, "blocks": []} + ctx_len = C + cur_bonus = bonus + for b in range(args.n_blocks): + t0 = time.perf_counter() + m_tokens = m_drafter.draft_block_cached( + m_ctx, cur_bonus, m_embed, m_head, + block_size=args.block_size, context_len=ctx_len) + m_s = time.perf_counter() - t0 + t0 = time.perf_counter() + t_tokens = t_drafter.draft_block_cached( + t_ctx, cur_bonus, t_embed, t_head, + block_size=args.block_size, context_len=ctx_len) + t_s = time.perf_counter() - t0 + matches = sum(1 for a, c in zip(m_tokens, t_tokens) if a == c) + agree += matches + total += args.block_size + sample_row["blocks"].append({ + "bonus": cur_bonus, + "mlx_tokens": m_tokens, + "torch_tokens": t_tokens, + "matches": matches, + "mlx_draft_s": round(m_s, 4), + "torch_draft_s": round(t_s, 4), + }) + # Next block conditions on a longer prefix: feed the torch + # drafts (the reference) as "committed" by extending both + # contexts with the same aux slice re-captured from the + # verifier over prompt+drafts. Keep it simple and equal for + # both: recompute aux over the extended ids. + ids = ids + [cur_bonus] + t_tokens[: max(args.block_size - 1, 0)] + ids = ids[: C + (b + 1) * args.block_size] + aux_mx = capture_aux_hidden( + mlx_model, ids, aux_ids, embed_scale=embed_scale) + aux_t = [mx_to_torch(a, dtype=torch.float32, device="cpu") + for a in aux_mx] + ctx_len = len(ids) + m_ctx = m_drafter.make_context_kv(aux_mx, mx.arange(0, ctx_len)) + t_ctx = t_drafter.make_context_kv(aux_t, torch.arange(0, ctx_len)) + out = mlx_model(mx.array([ids])); mx.eval(out) + cur_bonus = int(mx.argmax(out[0, -1]).item()) + rows.append(sample_row) + print(f"[parity] sample {i}: agreement so far {agree}/{total}", + file=sys.stderr) + + agreement = agree / max(total, 1) + report = { + "kind": "k3_mlx_drafter_parity", + "schema_version": 1, + "config": vars(args), + "agreement": round(agreement, 4), + "agreed_tokens": agree, + "total_tokens": total, + "samples": rows, + "passed": agreement >= args.min_agreement, + } + out_path = Path(args.output) + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(report, indent=2)) + print(f"[parity] agreement={agreement:.4f} " + f"({agree}/{total}) min={args.min_agreement} -> {out_path}", + file=sys.stderr) + if agreement < args.min_agreement: + print("[parity] FAIL: all-MLX drafter does not match the torch " + "reference; throughput claims are blocked.", file=sys.stderr) + return 1 + print("[parity] PASS", file=sys.stderr) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/review_mlx_port_on_mac.sh b/scripts/review_mlx_port_on_mac.sh new file mode 100755 index 00000000..c0cbf8d7 --- /dev/null +++ b/scripts/review_mlx_port_on_mac.sh @@ -0,0 +1,105 @@ +#!/usr/bin/env bash +# Mac mini validation for the #107 MLX port (PR #109). +# +# Step 1 (--incremental) : restored decode via native cache + generate_step +# → kills the O(T^2) re-forward collapse. +# Step 2 (--fused-specdecode) : fused DFlash spec-decode (A+B+C). +# +# Each run also times the ORACLE = native mlx_lm AR (same model, no restoration), +# so the JSON carries `throughput.cross_model_speedup_vs_oracle_ar` and +# `gate.recall_delta_within_5pp` for a direct AR comparison. The speed gate is +# e2e over prefill+decode for both cross and oracle paths. +# +# Gates: +# Step 1: speedup_vs_oracle ≈ 1.0 (no longer collapsed) AND recall == oracle. +# Step 2: speedup_vs_oracle > 1.0 (fused beats AR) AND recall == oracle. +# +# Usage (from repo root, on the Mac mini): +# bash scripts/review_mlx_port_on_mac.sh +# Override any knob via env, e.g.: +# N_SAMPLES=8 MAX_NEW_TOKENS=64 BLOCK_SIZE=6 bash scripts/review_mlx_port_on_mac.sh +set -euo pipefail + +VERIFIER_PATH="${VERIFIER_PATH:-models/gemma-4-26B-A4B-it-mlx-4bit}" +DRAFTER_ID="${DRAFTER_ID:-z-lab/gemma-4-26B-A4B-it-DFlash}" +F_THETA_DIR="${F_THETA_DIR:-results/research/f_theta_v5_s5_sliding}" +N_SAMPLES="${N_SAMPLES:-5}" +MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-32}" +BLOCK_SIZE="${BLOCK_SIZE:-4}" +PREFILL_CHUNK_SIZE="${PREFILL_CHUNK_SIZE:-512}" +DECODE_WARMUP_TOKENS="${DECODE_WARMUP_TOKENS:-1}" +SINK_SIZE="${SINK_SIZE:-4}" +WINDOW_SIZE="${WINDOW_SIZE:-64}" +HAYSTACK_MIN="${HAYSTACK_MIN:-238}" +HAYSTACK_MAX="${HAYSTACK_MAX:-322}" +OUT_DIR="${OUT_DIR:-results/research}" +STAMP="$(date +%Y%m%d_%H%M%S)" + +export PYTHONPATH="${PYTHONPATH:-.:sdks/python}" +# Let MLX use the unified-memory wired limit if the box is tight (optional). +export MLX_METAL_MEMORY_LIMIT_RATIO="${MLX_METAL_MEMORY_LIMIT_RATIO:-0.0}" + +INCR_JSON="${OUT_DIR}/k3_mlx_incremental_${STAMP}.json" +FUSED_JSON="${OUT_DIR}/k3_mlx_fused_${STAMP}.json" + +common_args=( + --verifier-path "${VERIFIER_PATH}" + --drafter-id "${DRAFTER_ID}" + --f-theta-dir "${F_THETA_DIR}" + --s5-exact-full-attn + --n-samples "${N_SAMPLES}" + --max-new-tokens "${MAX_NEW_TOKENS}" + --sink-size "${SINK_SIZE}" + --window-size "${WINDOW_SIZE}" + --prefill-chunk-size "${PREFILL_CHUNK_SIZE}" + --decode-warmup-tokens "${DECODE_WARMUP_TOKENS}" + --haystack-min-lines "${HAYSTACK_MIN}" + --haystack-max-lines "${HAYSTACK_MAX}" +) + +echo "==========================================================" +echo "[mlx-port] Step 1: incremental restored decode (native cache + generate_step)" +echo "[mlx-port] verifier=${VERIFIER_PATH} drafter=${DRAFTER_ID}" +echo "[mlx-port] f_theta=${F_THETA_DIR} n=${N_SAMPLES} gen=${MAX_NEW_TOKENS}" +echo "==========================================================" +python scripts/research/k3_integrated_niah_eval_mac.py \ + "${common_args[@]}" --incremental --output "${INCR_JSON}" + +echo "==========================================================" +echo "[mlx-port] Step 2: fused DFlash spec-decode (A+B+C, block_size=${BLOCK_SIZE})" +echo "==========================================================" +python scripts/research/k3_integrated_niah_eval_mac.py \ + "${common_args[@]}" --fused-specdecode --block-size "${BLOCK_SIZE}" \ + --output "${FUSED_JSON}" + +echo "==========================================================" +echo "[mlx-port] SUMMARY" +echo "==========================================================" +python - "${INCR_JSON}" "${FUSED_JSON}" <<'PY' +import json, sys +def show(tag, path, want): + d = json.load(open(path)) + g, t = d["gate"], d["throughput"] + cm = t["k3_cross_model"]; ar = t.get("oracle_native_ar") or {} + spd = t.get("cross_model_speedup_vs_oracle_ar") + rc, ro = g["recall_cross_model"], g.get("recall_oracle") + mem = d["memory"] + print(f"\n[{tag}] ({d['config']['eval_mode']})") + print(f" recall: cross={rc} oracle={ro} within_5pp={g['recall_delta_within_5pp']}") + print(f" scope : cross={cm.get('timing_scope')} oracle={ar.get('timing_scope')}") + print(f" tok/s : cross={cm.get('tokens_per_second')} " + f"oracle_AR={ar.get('tokens_per_second')} speedup_vs_AR={spd}") + print(f" KV : S5={mem['s5']['total_resident_mb']}MB " + f"naive={mem['naive_full_kv']['total_resident_mb']}MB " + f"savings={mem['savings_vs_naive_pct']}%") + ok_recall = bool(g["recall_delta_within_5pp"]) + ok_speed = (spd is not None and spd >= want) + print(f" GATE : recall {'PASS' if ok_recall else 'FAIL'} | " + f"speed {'PASS' if ok_speed else 'FAIL'} (need >= {want}x AR)") + return ok_recall and ok_speed +s1 = show("Step 1 incremental", sys.argv[1], 0.85) # ~= AR (collapse fixed) +s2 = show("Step 2 fused", sys.argv[2], 1.00) # > AR +print("\n[mlx-port] OVERALL:", + "PASS" if (s1 and s2) else "see gates above") +print(f"[mlx-port] JSON: {sys.argv[1]}\n[mlx-port] JSON: {sys.argv[2]}") +PY diff --git a/scripts/setup_mac.sh b/scripts/setup_mac.sh index dab9b739..c3d56e3f 100755 --- a/scripts/setup_mac.sh +++ b/scripts/setup_mac.sh @@ -1,9 +1,15 @@ #!/usr/bin/env bash # Set up a clean venv on macOS / Apple Silicon for this project. # -# Why a venv: the proposer checkpoint requires transformers 4.x, which -# conflicts with newer system installs (e.g. macOS 26 ships fine with -# transformers 5.x for other tools). We never touch system Python. +# Why a venv: isolates the project's pinned deps from system Python / +# other tools. We never touch system Python. +# +# transformers versioning (mirrors the requirements.txt note): the K3 +# critical path (Gemma 4 verifier, DFlash drafter, current mlx-lm) +# needs transformers >= 5.0; only the LEGACY Qwen3 MDLM dLM proposer +# still requires 4.x. This script validates the K3-era floor and no +# longer enforces a 4.x upper bound. If you need the legacy MDLM path, +# create a dedicated venv with: pip install 'transformers>=4.45,<5.0' # # Idempotent: re-running upgrades the venv if needed and verifies all # required imports succeed. Any missing or wrong-version package raises a @@ -37,9 +43,9 @@ ensure_xcode_clt() { pick_python() { # Prefer a 3.12 install (most stable for our deps). Fall back to - # 3.11 or 3.13. We refuse to use 3.14+ because the wheel ecosystem for - # transformers 4.x and dllm-hub's custom code is not yet validated - # there. + # 3.11 or 3.13. We refuse to use 3.14+ because the wheel ecosystem + # for our pinned deps and dllm-hub's custom code is not yet + # validated there. for cmd in python3.12 python3.11 python3.13; do if command -v "$cmd" >/dev/null 2>&1; then echo "$cmd" @@ -183,7 +189,11 @@ from packaging.version import Version # distribution_name is the name pip uses; defaults to import_name when None. required = [ ("torch", None, "2.4", "3.0"), - ("transformers", None, "4.45", "5.0"), # hard-pin to 4.x + # No upper bound: K3 (Gemma 4 / DFlash / current mlx-lm) needs + # transformers >= 5.0; requirements.txt dropped the <5 pin. Only + # the legacy Qwen3 MDLM proposer needs 4.x — use a dedicated venv + # for that path (see requirements.txt note). + ("transformers", None, "4.45", None), ("mlx", None, "0.20", None), ("mlx_lm", "mlx-lm", "0.18", None), ("huggingface_hub", None, "0.24", None), diff --git a/scripts/validate_k3_reports.py b/scripts/validate_k3_reports.py new file mode 100644 index 00000000..837c612c --- /dev/null +++ b/scripts/validate_k3_reports.py @@ -0,0 +1,71 @@ +"""CI walker: validate committed K3 Mac evidence reports. + +Runs in the Linux gate (see ``.github/workflows/ci.yaml``). Every +committed ``results/research/*.json`` whose ``kind`` is the K3 Mac +acceptance schema is validated against the evidence rules in +:mod:`inference_engine.bench.k3_report_gate`, so a report that claims +an inadmissible speedup / recall / memory number cannot land silently. + +Reports with ``schema_version < 2`` predate the gate: they are printed +as grandfathered legacy (NON-EVIDENCE) warnings and do not fail the +build — re-run the hardened harness to produce citable evidence. + +Usage:: + + PYTHONPATH=. python3 scripts/validate_k3_reports.py [results/research] + +Exit codes: 0 = all gated reports admissible (or legacy); 1 = at least +one schema-2 report violates the evidence rules. + +CLI plumbing around the unit-tested ``k3_report_gate`` library; +exempt from unit-test coverage by the same convention as +``scripts/serve.py``. +""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path + +from inference_engine.bench.k3_report_gate import ( + is_gated_report, + is_legacy_report, + summarize_violations, + validate_report, +) + + +def main(argv: list) -> int: + root = Path(argv[1]) if len(argv) > 1 else Path("results/research") + if not root.exists(): + print(f"[k3-evidence-gate] {root} does not exist; nothing to check") + return 0 + checked = legacy = failures = 0 + for path in sorted(root.rglob("*.json")): + try: + report = json.loads(path.read_text()) + except (json.JSONDecodeError, UnicodeDecodeError, OSError): + continue + if not is_gated_report(report): + continue + if is_legacy_report(report): + legacy += 1 + print(f"[legacy] {path}: schema<2 — grandfathered, NON-EVIDENCE " + "(rerun with the hardened harness to make claims)") + continue + checked += 1 + violations = validate_report(report) + if violations: + failures += 1 + print(f"[FAIL] {path}") + print(summarize_violations(violations)) + else: + print(f"[ok] {path}") + print(f"[k3-evidence-gate] checked={checked} legacy={legacy} " + f"failures={failures}") + return 1 if failures else 0 + + +if __name__ == "__main__": + sys.exit(main(sys.argv)) diff --git a/tests/backends/mlx/test_fused_specdecode.py b/tests/backends/mlx/test_fused_specdecode.py new file mode 100644 index 00000000..de92aa7a --- /dev/null +++ b/tests/backends/mlx/test_fused_specdecode.py @@ -0,0 +1,387 @@ +"""Linux-CI tests for the MLX fused DFlash spec-decode engine. + +The fused loop (``fused_specdecode_generate``) takes all MLX/torch ops as +injected callables, so its accept/reject/commit/extend control flow is tested +**without MLX**. The MLX-touching wrappers (``_build_aux``, ``capture_aux_hidden``, +``MLXRestoredIncrementalVerifier``, ``make_bridge_embed_lm_head``) are tested by +injecting fake ``mlx`` / ``mlx_lm`` modules. Real MLX kernels are validated on a +Mac by ``k3_integrated_niah_eval_mac.py --fused-specdecode``. +""" + +from __future__ import annotations + +import sys +import types + +import pytest +import torch + +from inference_engine.backends.mlx import fused_specdecode as fsd + + +# =========================================================================== # +# 1) Fused loop control flow (no MLX) — verifier truth = successor (last+1). +# =========================================================================== # +class _FakeAdapter: + def __init__(self, prompt_len, first_token, hidden=4): + self._past_len = prompt_len + self.next_token_logits = first_token + self.hidden = hidden + self._capture_aux = False + self._last_aux = None + self.commits = [] + self.appends = [] + + def forward_block(self, candidate): + # verifier greedy continuation: prediction after token t is t+1. + block_logits = [candidate[i] + 1 for i in range(len(candidate))] + if self._capture_aux: + L = len(candidate) + self._last_aux = [torch.arange(L * self.hidden).float().reshape(L, self.hidden)] + return block_logits + + def commit_or_truncate(self, *, forwarded, accepted): + self.commits.append((forwarded, accepted)) + self._past_len += accepted + + def append_token(self, token_id): + bl = self.forward_block([token_id]) + self.commit_or_truncate(forwarded=1, accepted=1) + self.next_token_logits = bl[-1] + self.appends.append(token_id) + return self.next_token_logits + + +class _FakeDrafter: + def __init__(self, drafts): + self.cfg = types.SimpleNamespace(aux_layer_ids=(2,)) + self._drafts = list(drafts) + self.make_calls = 0 + self.extend_calls = 0 + + def make_context_kv(self, aux, positions): + self.make_calls += 1 + return ("ctx", self.make_calls) + + def extend_context_kv(self, ctx_kv, new_kv): + self.extend_calls += 1 + return ("ctx_ext", self.extend_calls) + + def draft_block_cached(self, ctx_kv, bonus, embed_fn, lm_head_fn, + *, block_size, context_len): + return list(self._drafts.pop(0)) if self._drafts else [] + + +def _loop_kwargs(drafter, **over): + kw = dict( + aux_prompt=[torch.zeros(1, 5, 4)], + embed_fn=lambda x: x, lm_head_fn=lambda x: x, + argmax_fn=lambda row: int(row), arange_fn=lambda s, e: torch.arange(s, e), + cat_aux_fn=lambda parts: torch.cat(list(parts), dim=0).unsqueeze(0), + ) + kw.update(over) + return kw + + +def test_fused_loop_full_acceptance(): + adapter = _FakeAdapter(prompt_len=5, first_token=100) + drafter = _FakeDrafter(drafts=[[101, 102], [200, 201]]) + res = fsd.fused_specdecode_generate( + adapter, drafter, gen_tokens=5, block_size=4, eos_ids=(), + **_loop_kwargs(drafter)) + # Block1: candidate=[100,101,102] all accepted (3) + correction 103. + # Block2: candidate=[104] accepted (1) + correction 105 -> truncated to 5. + assert res["tokens"] == [100, 101, 102, 103, 104] + assert res["blocks"] == 2 + assert res["mean_accept_len"] == 2.0 # (3 + 1) / 2 + assert adapter.commits[0] == (3, 3) # block1 verify-commit + assert adapter.appends == [103, 105] # one correction per block + # capture flag toggled on during loop, off after. + assert adapter._capture_aux is False + # context K/V extended once per block. + assert drafter.extend_calls == 2 + + +def test_fused_loop_partial_rejection_and_correction(): + adapter = _FakeAdapter(prompt_len=5, first_token=100) + drafter = _FakeDrafter(drafts=[[101, 777]]) # 777 mismatches verifier 102 + res = fsd.fused_specdecode_generate( + adapter, drafter, gen_tokens=3, block_size=4, eos_ids=(), + **_loop_kwargs(drafter)) + # candidate=[100,101,777]: accept 100,101 (2), reject 777, correction=102. + assert res["tokens"] == [100, 101, 102] + assert res["blocks"] == 1 + assert res["mean_accept_len"] == 2.0 + # commit_or_truncate(forwarded=3, accepted=2) then append correction (1,1). + assert adapter.commits[0] == (3, 2) + + +def test_fused_loop_stops_on_eos(): + adapter = _FakeAdapter(prompt_len=5, first_token=100) + drafter = _FakeDrafter(drafts=[[101, 102]]) + res = fsd.fused_specdecode_generate( + adapter, drafter, gen_tokens=50, block_size=4, eos_ids=(103,), + **_loop_kwargs(drafter)) + # correction 103 is EOS -> stop after first block. + assert res["tokens"] == [100, 101, 102, 103] + assert res["blocks"] == 1 + + +# =========================================================================== # +# 2) MLX-touching wrappers with fake mlx / mlx_lm. +# =========================================================================== # +class _Out: + """Stand-in for a layer's [1, L, hidden] output; ``[0]`` strips batch.""" + def __init__(self, idx): + self.idx = idx + + def __getitem__(self, k): + return ("row", self.idx) + + def __eq__(self, o): + return isinstance(o, _Out) and o.idx == self.idx + + def __hash__(self): + return self.idx + + +class _Layer: + def __init__(self, idx): + self.layer_idx = idx + + def __call__(self, x, *a, **k): + return (_Out(self.layer_idx), None, 0) # (h, kvs, offset) + + +class _TextModel: + def __init__(self, n=4): + self.layers = [_Layer(i) for i in range(n)] + self.embed_tokens = self._Embed() + + class _Embed: + def __call__(self, ids): + return "EMB" + + def as_linear(self, h): + return "LOGITS" + + +class _Model: + def __init__(self, tm, row=None): + self.model = tm + self._row = row + self.last_cache = "UNSET" + + def __call__(self, ids, cache=None): + # drive the (patched) layers so their _aux_record gets populated + tm = self.model + for l in tm.layers: + l(None) + self.last_cache = cache + if self._row is not None: + class _L: + def __init__(self, r): self._r = r + def __getitem__(self, k): + assert k == 0 + return self._r + return _L(self._row) + return None + + +def _install_mlx(monkeypatch, trim_log=None): + mx = types.ModuleType("mlx.core") + mx.array = lambda x, **k: x + mx.eval = lambda *a, **k: None + mx.argmax = lambda r, **k: r + mx.tanh = lambda x: x / 2 # non-identity marker so softcap is visible + mlx_pkg = types.ModuleType("mlx"); mlx_pkg.core = mx + cache_mod = types.ModuleType("mlx_lm.models.cache") + + def _trim(cache, n): + if trim_log is not None: + trim_log.append((cache, n)) + cache_mod.trim_prompt_cache = _trim + gen_mod = types.ModuleType("mlx_lm.generate") + gen_mod.generate_step = lambda *a, **k: iter(()) + for name, mod in [ + ("mlx", mlx_pkg), ("mlx.core", mx), + ("mlx_lm", types.ModuleType("mlx_lm")), + ("mlx_lm.models", types.ModuleType("mlx_lm.models")), + ("mlx_lm.models.cache", cache_mod), + ("mlx_lm.generate", gen_mod), + ]: + monkeypatch.setitem(sys.modules, name, mod) + return mx + + +def test_build_aux_indexing(monkeypatch): + _install_mlx(monkeypatch) + tm = _TextModel(n=4) + + class _E: + def __call__(self, ids): return 1.0 # numeric so *embed_scale works + def as_linear(self, h): return "L" + tm.embed_tokens = _E() + sink = {0: "h0", 1: "h1", 2: "h2", 3: "h3"} + # hs = [scaled_embeds(=2.0), h0, h1, h2, h3]; hs[a] = output of layer a-1. + aux = fsd._build_aux(tm, "ids", sink, embed_scale=2.0, aux_layer_ids=[0, 1, 3]) + assert aux == [2.0, "h0", "h2"] # hs[0]=embeds, hs[1]=h0, hs[3]=h2 + + +def test_capture_aux_hidden_runs_layers_and_indexes(monkeypatch): + _install_mlx(monkeypatch) + tm = _TextModel(n=3) + + class _E: + def __call__(self, ids): return 1.0 + def as_linear(self, h): return "L" + tm.embed_tokens = _E() + model = _Model(tm) + aux = fsd.capture_aux_hidden(model, [1, 2], aux_layer_ids=[1, 3], + embed_scale=10.0) + # hs[1] = output of layer 0; hs[3] = output of layer 2. + assert aux == [_Out(0), _Out(2)] + # _aux_record cleared from layers after capture. + for l in tm.layers: + assert not hasattr(l, "_aux_record") + + +def test_adapter_prefill_forward_commit(monkeypatch): + trim_log = [] + _install_mlx(monkeypatch, trim_log=trim_log) + tm = _TextModel(n=3) + + class _E: + def __call__(self, ids): return 1.0 + def as_linear(self, h): return "L" + tm.embed_tokens = _E() + model = _Model(tm, row="ROW") + + # patch restored_prefill_cache to a sentinel (its own test covers internals) + monkeypatch.setattr(fsd, "restored_prefill_cache", + lambda m, ids, **k: ("CACHE", "FIRST")) + adapter = fsd.MLXRestoredIncrementalVerifier( + model, embed_scale=10.0, aux_layer_ids=(1,), + bridge_to_torch=lambda a: ("torch", a)) + adapter.prefill([1, 2, 3], restored_k_per_layer={}, restored_v_per_layer={}, + evicted_positions=[1]) + assert adapter._cache == "CACHE" + assert adapter.next_token_logits == "FIRST" + assert adapter._past_len == 3 + + # forward_block with aux capture -> bridges hs[1] = layer-0 output + adapter._capture_aux = True + logits = adapter.forward_block([7, 8]) + assert logits == "ROW" # _Model returns row at [0] + # aux = [hs[1]] = [layer-0 output]; bridged after stripping batch ([0]). + assert adapter._last_aux == [("torch", ("row", 0))] + + # commit_or_truncate trims by (forwarded - accepted) and advances _past_len + adapter.commit_or_truncate(forwarded=2, accepted=1) + assert trim_log == [("CACHE", 1)] + assert adapter._past_len == 4 + + # no trim when fully accepted + trim_log.clear() + adapter.commit_or_truncate(forwarded=2, accepted=2) + assert trim_log == [] + assert adapter._past_len == 6 + + +def test_adapter_append_token_and_non_aux_path(monkeypatch): + _install_mlx(monkeypatch) + tm = _TextModel(n=2) + model = _Model(tm, row=[10, 11, 12]) # forward_block -> row at [0] + adapter = fsd.MLXRestoredIncrementalVerifier(model, embed_scale=1.0) + adapter._cache = "C" + adapter._past_len = 5 + # _capture_aux stays False -> non-aux branch; append_token commits (1,1). + nxt = adapter.append_token(99) + assert adapter._last_aux is None # non-aux path + assert nxt == 12 # logits[-1] + assert adapter._past_len == 6 + + +def test_adapter_prefill_rejects_empty_prompt(monkeypatch): + _install_mlx(monkeypatch) + adapter = fsd.MLXRestoredIncrementalVerifier(_Model(_TextModel(2)), embed_scale=1.0) + with pytest.raises(ValueError): + adapter.prefill([], restored_k_per_layer={}, restored_v_per_layer={}, + evicted_positions=[]) + + +def test_make_full_kv_prompt_cache_all_kvcache(monkeypatch): + # Fake mlx_lm.models.cache with make_prompt_cache (count) + a KVCache class. + import types as _t + class _FakeKV: + instances = 0 + def __init__(self): type(self).instances += 1 + cache_mod = _t.ModuleType("mlx_lm.models.cache") + cache_mod.make_prompt_cache = lambda model, **k: ["a", "b", "c", "d"] # 4 layers + cache_mod.KVCache = _FakeKV + monkeypatch.setitem(sys.modules, "mlx_lm", _t.ModuleType("mlx_lm")) + monkeypatch.setitem(sys.modules, "mlx_lm.models", _t.ModuleType("mlx_lm.models")) + monkeypatch.setitem(sys.modules, "mlx_lm.models.cache", cache_mod) + out = fsd.make_full_kv_prompt_cache(object()) + assert len(out) == 4 and all(isinstance(c, _FakeKV) for c in out) + assert _FakeKV.instances == 4 # every layer is a fresh full KVCache + + +def test_patched_decoder_layers_empty_is_noop(monkeypatch): + _install_mlx(monkeypatch) + tm = _TextModel(0) + with fsd._patched_decoder_layers(tm): + pass # no layers -> no-op guard + + +def test_adapter_commit_validates_accepted(monkeypatch): + _install_mlx(monkeypatch) + tm = _TextModel(n=2) + adapter = fsd.MLXRestoredIncrementalVerifier(_Model(tm), embed_scale=1.0) + adapter._cache = "C" + with pytest.raises(ValueError): + adapter.commit_or_truncate(forwarded=2, accepted=3) + + +def test_adapter_forward_block_requires_prefill(monkeypatch): + _install_mlx(monkeypatch) + tm = _TextModel(n=2) + adapter = fsd.MLXRestoredIncrementalVerifier(_Model(tm), embed_scale=1.0) + with pytest.raises(RuntimeError): + adapter.forward_block([1]) + adapter._cache = "C" + with pytest.raises(ValueError): + adapter.forward_block([]) + + +def test_bridge_embed_is_unscaled_and_lm_head_softcaps(monkeypatch): + mx = _install_mlx(monkeypatch) + tm = _TextModel(n=2) + seen = {} + + class _E: + def __call__(self, ids): + seen["embed_ids"] = ids + return "RAW_EMB" # NOT multiplied by embed_scale + + def as_linear(self, h): + seen["as_linear_h"] = h + return 100.0 + tm.embed_tokens = _E() + + embed_fn, lm_head_fn = fsd.make_bridge_embed_lm_head( + tm, mx_to_torch=lambda a, **k: ("mt", a), + torch_to_mx=lambda h: ("tm", h), + device="cpu", torch_dtype="f32", softcap=50.0) + + class _Ids: + def detach(self): return self + def to(self, d): return self + def tolist(self): return [[1, 2]] + out_emb = embed_fn(_Ids()) + assert out_emb == ("mt", "RAW_EMB") # plain lookup (Gap-B: no *scale) + + out_logits = lm_head_fn("H") + # softcap*tanh(as_linear/softcap): 50*tanh(100/50)=50*(2/2)=50 (fake tanh=x/2) + assert seen["as_linear_h"] == ("tm", "H") + assert out_logits == ("mt", 50.0) # softcap path applied diff --git a/tests/backends/mlx/test_restored_incremental_decode.py b/tests/backends/mlx/test_restored_incremental_decode.py new file mode 100644 index 00000000..6afef2ac --- /dev/null +++ b/tests/backends/mlx/test_restored_incremental_decode.py @@ -0,0 +1,194 @@ +"""Linux-CI tests for the MLX incremental restored-decode wrappers +(``restored_prefill_cache`` / ``restored_incremental_generate``). + +These functions import ``mlx`` / ``mlx_lm`` lazily, so to exercise their +control flow on Linux (no Apple Silicon) we inject minimal fake ``mlx.core`` +and ``mlx_lm`` modules via ``monkeypatch.setitem(sys.modules, ...)`` (auto +reverted). The real MLX kernels/cache behaviour are validated on a Mac by +``scripts/research/k3_integrated_niah_eval_mac.py --incremental``; here we lock +in the wrapper logic: which layers get the inject config, cache plumbing, and +the argmax/EOS/stop-condition decode loop. +""" + +from __future__ import annotations + +import sys +import types + +import pytest + +from inference_engine.backends.mlx import cross_model_dlm_verifier as cmv + + +# --------------------------------------------------------------------------- # +# Fake model structure +# --------------------------------------------------------------------------- # +class _FakeAttn: + def __init__(self, layer_idx, has_kv=True): + self.layer_idx = layer_idx + self.has_kv = has_kv + + def __call__(self, *a, **k): # present so _patched_attention_class can swap + raise AssertionError("attn should not be invoked by the fake model") + + +class _FakeLayer: + def __init__(self, attn): + self.self_attn = attn + + +class _FakeTextModel: + def __init__(self, n=6, shared=()): + self.layers = [_FakeLayer(_FakeAttn(i, has_kv=i not in shared)) + for i in range(n)] + self.previous_kvs = list(range(n)) + self.embed_tokens = object() # resolve_mlx_text_model sentinel + + +class _Logits: + """Supports ``logits[0, -1]`` -> the last-row vocab list.""" + def __init__(self, row): + self._row = row + + def __getitem__(self, key): + assert key == (0, -1) + return list(self._row) + + +class _FakeModel: + """mlx_lm-like wrapper: ``.model`` is the text model and it is callable.""" + def __init__(self, tm, last_row): + self.model = tm + self._row = last_row + self.captured_inject = None + self.last_cache = "UNSET" + + def __call__(self, ids, cache=None): + self.captured_inject = [ + l.self_attn.layer_idx for l in self.model.layers + if getattr(l.self_attn, "_kakeya_inject", None) + and l.self_attn._kakeya_inject.get("mode") == "inject" + ] + self.last_cache = cache + return _Logits(self._row) + + +# --------------------------------------------------------------------------- # +# Fake mlx / mlx_lm modules +# --------------------------------------------------------------------------- # +class _Scalar: + def __init__(self, v): + self._v = v + + def item(self): + return self._v + + +def _install_fakes(monkeypatch, *, prompt_cache="CACHE", gen_stream=()): + mx = types.ModuleType("mlx.core") + mx.array = lambda x, **k: x + mx.eval = lambda *a, **k: None + mx.argmax = lambda row, **k: _Scalar(int(max(range(len(row)), + key=lambda i: row[i]))) + mlx_pkg = types.ModuleType("mlx") + mlx_pkg.core = mx + + base = types.ModuleType("mlx_lm.models.base") + base.scaled_dot_product_attention = lambda *a, **k: None + cache_mod = types.ModuleType("mlx_lm.models.cache") + cache_mod.make_prompt_cache = lambda model, **k: prompt_cache + gen_mod = types.ModuleType("mlx_lm.generate") + + def _generate_step(prompt, model, *, prompt_cache=None, max_tokens=256, **k): + for i, tok in enumerate(gen_stream): + if i >= max_tokens: + break + yield tok, 0.0 + gen_mod.generate_step = _generate_step + + models_pkg = types.ModuleType("mlx_lm.models") + mlx_lm_pkg = types.ModuleType("mlx_lm") + for name, mod in [ + ("mlx", mlx_pkg), ("mlx.core", mx), + ("mlx_lm", mlx_lm_pkg), ("mlx_lm.models", models_pkg), + ("mlx_lm.models.base", base), ("mlx_lm.models.cache", cache_mod), + ("mlx_lm.generate", gen_mod), + ]: + monkeypatch.setitem(sys.modules, name, mod) + + +# --------------------------------------------------------------------------- # +# restored_prefill_cache +# --------------------------------------------------------------------------- # +def test_prefill_injects_only_source_layers_with_restored_kv(monkeypatch): + _install_fakes(monkeypatch) + tm = _FakeTextModel(n=6, shared=(5,)) # layer 5 is a KV-sharer + model = _FakeModel(tm, last_row=[0.1, 0.9, 0.2]) + rk = {0: "k0", 2: "k2", 5: "k5"} # 5 is sharer -> skipped + rv = {0: "v0", 2: "v2", 5: "v5"} + cache, last = cmv.restored_prefill_cache( + model, [10, 11, 12, 13], + restored_k_per_layer=rk, restored_v_per_layer=rv, + evicted_positions=[1, 2]) + # Only has_kv layers present in rk get injected (0, 2). Layer 5 is a sharer + # (skipped); layers 1,3,4 have no restored K/V (skipped). + assert model.captured_inject == [0, 2] + # Cache from make_prompt_cache is threaded into the forward and returned. + assert cache == "CACHE" + assert model.last_cache == "CACHE" + # Last-row logits returned (predicts first token). + assert last == [0.1, 0.9, 0.2] + + +def test_prefill_evicted_mask_clamped_and_attention_restored(monkeypatch): + _install_fakes(monkeypatch) + tm = _FakeTextModel(n=3) + attn_cls = type(tm.layers[0].self_attn) + orig_call = attn_cls.__call__ + model = _FakeModel(tm, last_row=[1.0, 0.0]) + # out-of-range evicted positions are ignored (clamped to prompt length) + cmv.restored_prefill_cache( + model, [7, 8], restored_k_per_layer={0: "k"}, restored_v_per_layer={0: "v"}, + evicted_positions=[0, 99, -1]) + # Attention __call__ restored after the context manager and inject config + # cleared from every layer. + assert attn_cls.__call__ is orig_call + for l in tm.layers: + assert not hasattr(l.self_attn, "_kakeya_inject") + + +# --------------------------------------------------------------------------- # +# restored_incremental_generate +# --------------------------------------------------------------------------- # +def test_generate_single_token_when_max_tokens_one(monkeypatch): + _install_fakes(monkeypatch, gen_stream=[5, 6, 7]) + model = _FakeModel(_FakeTextModel(), last_row=None) + out = cmv.restored_incremental_generate( + model, "CACHE", [0.0, 0.0, 1.0], max_tokens=1) + assert out == [2] # argmax of first_logits, no decode + + +def test_generate_stops_when_first_is_eos(monkeypatch): + _install_fakes(monkeypatch, gen_stream=[5, 6]) + model = _FakeModel(_FakeTextModel(), last_row=None) + out = cmv.restored_incremental_generate( + model, "CACHE", [0.0, 9.0], max_tokens=16, eos_ids=[1]) + assert out == [1] # first token is EOS -> stop + + +def test_generate_streams_until_eos(monkeypatch): + _install_fakes(monkeypatch, gen_stream=[5, 6, 99, 7]) + model = _FakeModel(_FakeTextModel(), last_row=None) + out = cmv.restored_incremental_generate( + model, "CACHE", [0.0, 0.0, 1.0], max_tokens=16, eos_ids=[99]) + # first = argmax([..1.0]) = 2, then stream 5,6 then EOS 99 (included, stops) + assert out == [2, 5, 6, 99] + + +def test_generate_streams_until_max_tokens(monkeypatch): + _install_fakes(monkeypatch, gen_stream=[5, 6, 7, 8, 9]) + model = _FakeModel(_FakeTextModel(), last_row=None) + out = cmv.restored_incremental_generate( + model, "CACHE", [9.0, 0.0], max_tokens=3) + # first = argmax([9,0]) = 0, then generate_step capped at max_tokens-1 = 2 + assert out == [0, 5, 6] diff --git a/tests/inference_engine/bench/test_k3_report_gate.py b/tests/inference_engine/bench/test_k3_report_gate.py new file mode 100644 index 00000000..4d99a549 --- /dev/null +++ b/tests/inference_engine/bench/test_k3_report_gate.py @@ -0,0 +1,407 @@ +"""Unit tests for inference_engine.bench.k3_report_gate. + +Each rule is pinned by mutating exactly one aspect of a fully valid +schema-2 report, so a future schema drift that silently disables a +rule fails here. The fixtures mirror the real committed report shapes +(see results/research/k3_mlx_fused_fair_ctx280_n5_gen32_*.json on this +branch — the report whose failure modes created this gate). + +Coverage target: 100% on ``inference_engine/bench/k3_report_gate.py``. +""" + +from __future__ import annotations + +import copy +from typing import Any, Dict + +import pytest + +from inference_engine.bench.k3_report_gate import ( + CLAIM_ORACLE_DECODE_LOOP, + GATED_SCHEMA_VERSION, + MAC_REPORT_KIND, + MAX_PREFILL_SPREAD, + MIN_MEDIAN_DECODE_TOKENS, + MIN_PERF_SAMPLES, + NATIVE_BASELINE_LABEL, + GateViolation, + decode_only_block, + is_gated_report, + is_legacy_report, + prefill_spread, + row_prefill_seconds, + summarize_violations, + validate_report, +) + + +def _valid_report(n: int = MIN_PERF_SAMPLES) -> Dict[str, Any]: + """A schema-2 report that passes every rule.""" + cross_rows = [ + { + "sample": i, + "prefill_s": 30.0 + i, + "decode_s": 2.0, + "e2e_s": 32.0 + i, + "restoration_active": True, + "decode_loop": "fused_specdecode", + "fused": {"blocks": 8, "mean_accept_len": 1.5}, + } + for i in range(n) + ] + oracle_rows = [ + {"sample": i, "prefill_s": 31.0 + i, "decode_s": 4.0, "e2e_s": 35.0 + i} + for i in range(n) + ] + return { + "schema_version": GATED_SCHEMA_VERSION, + "kind": MAC_REPORT_KIND, + "results": { + "k3_cross_model": { + "recall": 1.0, + "per_sample_decode_tokens": [MIN_MEDIAN_DECODE_TOKENS] * n, + "system_under_test": "restored_cross_model", + }, + "oracle": { + "recall": 1.0, + "per_sample_decode_tokens": [MIN_MEDIAN_DECODE_TOKENS] * n, + }, + }, + "gate": {"recall_cross_model": 1.0, "recall_oracle": 1.0}, + "memory": { + "s5": {"formula_matches_run": True}, + "savings_vs_naive_pct": 89.8, + }, + "throughput": { + "k3_cross_model": { + "eval_mode": "free_gen_fused_specdecode", + "timing_scope": "e2e_prefill_plus_decode", + "stage_timings": cross_rows, + }, + "oracle_native_ar": { + "timing_scope": "e2e_prefill_plus_decode", + "decode_loop": CLAIM_ORACLE_DECODE_LOOP, + "stage_timings": oracle_rows, + }, + "decode_only": { + "cross_median_tok_s": 16.0, + "oracle_median_tok_s": 8.0, + "speedup": 2.0, + }, + "cross_model_speedup_vs_oracle_ar": 1.18, + }, + } + + +def _codes(report: Dict[str, Any]) -> set: + return {v.code for v in validate_report(report)} + + +# --------------------------------------------------------------------------- +# Scope / legacy handling +# --------------------------------------------------------------------------- + + +def test_valid_report_has_no_violations(): + assert validate_report(_valid_report()) == [] + + +def test_non_gated_kinds_validate_trivially(): + assert validate_report({"kind": "k3_f_theta_train", "schema_version": 1}) == [] + assert validate_report("not even a dict") == [] # type: ignore[arg-type] + + +def test_is_gated_report(): + assert is_gated_report(_valid_report()) + assert not is_gated_report({"kind": "other"}) + assert not is_gated_report(None) + + +def test_legacy_schema_is_single_grandfather_violation(): + report = _valid_report() + report["schema_version"] = 1 + violations = validate_report(report) + assert [v.code for v in violations] == ["LEGACY_SCHEMA"] + assert "NON-EVIDENCE" in violations[0].message + + +def test_is_legacy_report_handles_garbage_versions(): + assert is_legacy_report({"schema_version": 1}) + assert is_legacy_report({"schema_version": "not-a-number"}) + assert is_legacy_report({}) + assert not is_legacy_report({"schema_version": GATED_SCHEMA_VERSION}) + + +# --------------------------------------------------------------------------- +# Path identity rules +# --------------------------------------------------------------------------- + + +def test_missing_stage_timings_flagged(): + report = _valid_report() + report["throughput"]["k3_cross_model"]["stage_timings"] = [] + # No rows ⇒ also no restoration evidence for the recall claim. + assert {"MISSING_STAGE_TIMINGS", "RECALL_SCOPE"} <= _codes(report) + + +def test_stage_timings_wrong_type_treated_as_missing(): + report = _valid_report() + report["throughput"]["k3_cross_model"]["stage_timings"] = "oops" + assert "MISSING_STAGE_TIMINGS" in _codes(report) + + +def test_missing_restoration_flag_flagged(): + report = _valid_report() + del report["throughput"]["k3_cross_model"]["stage_timings"][2]["restoration_active"] + assert "MISSING_RESTORATION_FLAG" in _codes(report) + + +def test_mixed_restoration_paths_flagged(): + report = _valid_report() + report["throughput"]["k3_cross_model"]["stage_timings"][0]["restoration_active"] = False + codes = _codes(report) + assert "MIXED_RESTORATION_PATHS" in codes + assert "RECALL_SCOPE" in codes # recall claim no longer covered + + +def _native_baseline_report() -> Dict[str, Any]: + """A correctly-declared native-baseline run (no claims) — admissible.""" + report = _valid_report() + for row in report["throughput"]["k3_cross_model"]["stage_timings"]: + row["restoration_active"] = False + row["fused"] = {"blocks": 0, "mean_accept_len": 0.0} + row["decode_loop"] = "per_token_eval" + report["throughput"]["k3_cross_model"]["eval_mode"] = "native_ar_baseline" + report["results"]["k3_cross_model"]["system_under_test"] = NATIVE_BASELINE_LABEL + report["gate"]["recall_cross_model"] = None + report["gate"]["recall_native_baseline"] = 1.0 + report["throughput"]["cross_model_speedup_vs_oracle_ar"] = None + report["memory"]["savings_vs_naive_pct"] = None + report["memory"]["s5"]["formula_matches_run"] = False + return report + + +def test_declared_native_baseline_is_admissible(): + assert validate_report(_native_baseline_report()) == [] + + +def test_undeclared_baseline_as_sut_flagged(): + report = _native_baseline_report() + report["results"]["k3_cross_model"]["system_under_test"] = "restored_cross_model" + assert "BASELINE_AS_SUT" in _codes(report) + + +def test_baseline_recall_claim_flagged(): + report = _native_baseline_report() + report["gate"]["recall_cross_model"] = 1.0 + codes = _codes(report) + assert "BASELINE_RECALL_CLAIM" in codes + assert "RECALL_SCOPE" in codes + + +def test_recall_scope_requires_all_samples_restored(): + report = _valid_report() + report["throughput"]["k3_cross_model"]["stage_timings"] = [] + assert "RECALL_SCOPE" in _codes(report) + + +# --------------------------------------------------------------------------- +# Fused execution rule — the blocks=0 reports that motivated the gate +# --------------------------------------------------------------------------- + + +def test_fused_never_ran_flagged(): + report = _valid_report() + for row in report["throughput"]["k3_cross_model"]["stage_timings"]: + row["fused"] = {"blocks": 0, "mean_accept_len": 0.0} + assert "FUSED_NEVER_RAN" in _codes(report) + + +def test_fused_missing_blocks_counts_as_zero(): + report = _valid_report() + for row in report["throughput"]["k3_cross_model"]["stage_timings"]: + row["fused"] = {} + row.pop("fused") + assert "FUSED_NEVER_RAN" in _codes(report) + + +def test_fused_rule_only_applies_to_fused_eval_mode(): + report = _valid_report() + report["throughput"]["k3_cross_model"]["eval_mode"] = "free_gen_incremental" + for row in report["throughput"]["k3_cross_model"]["stage_timings"]: + row.pop("fused") + row["decode_loop"] = "generate_step" + assert validate_report(report) == [] + + +# --------------------------------------------------------------------------- +# Speedup claim rules +# --------------------------------------------------------------------------- + + +def test_withheld_speedup_skips_all_speedup_rules(): + report = _valid_report(n=1) # tiny smoke... + report["throughput"]["cross_model_speedup_vs_oracle_ar"] = None # ...no claim + report["results"]["k3_cross_model"]["per_sample_decode_tokens"] = [8] + report["results"]["oracle"]["per_sample_decode_tokens"] = [8] + assert validate_report(report) == [] + + +def test_speedup_on_baseline_is_self_comparison(): + report = _native_baseline_report() + report["throughput"]["cross_model_speedup_vs_oracle_ar"] = 2.584 + assert "SPEEDUP_SELF_COMPARISON" in _codes(report) + + +def test_speedup_sample_floor(): + report = _valid_report(n=MIN_PERF_SAMPLES - 1) + assert "SPEEDUP_SAMPLES" in _codes(report) + + +def test_speedup_decode_token_floor(): + report = _valid_report() + report["results"]["k3_cross_model"]["per_sample_decode_tokens"] = ( + [8] * MIN_PERF_SAMPLES # the gen=8 smokes that motivated the rule + ) + assert "SPEEDUP_DECODE_TOKENS" in _codes(report) + + +def test_speedup_decode_token_floor_missing_lists(): + report = _valid_report() + report["results"]["oracle"]["per_sample_decode_tokens"] = [] + assert "SPEEDUP_DECODE_TOKENS" in _codes(report) + + +def test_speedup_requires_decode_only_block(): + report = _valid_report() + report["throughput"]["decode_only"] = {} + assert "SPEEDUP_DECODE_ONLY_MISSING" in _codes(report) + del report["throughput"]["decode_only"] + assert "SPEEDUP_DECODE_ONLY_MISSING" in _codes(report) + + +def test_speedup_scope_mismatch_flagged(): + report = _valid_report() + report["throughput"]["oracle_native_ar"]["timing_scope"] = "decode_only" + assert "SPEEDUP_SCOPE_MISMATCH" in _codes(report) + + +def test_speedup_oracle_loop_rule(): + report = _valid_report() + report["throughput"]["oracle_native_ar"]["decode_loop"] = "per_token_eval" + assert "SPEEDUP_ORACLE_LOOP" in _codes(report) + + +def test_speedup_prefill_variance_rule_each_arm(): + report = _valid_report() + # The ctx280 oracle arm: 35.3s..146.3s on identical work. + rows = report["throughput"]["oracle_native_ar"]["stage_timings"] + rows[0]["prefill_s"] = 35.3 + rows[3]["prefill_s"] = 146.3 + assert "SPEEDUP_PREFILL_VARIANCE" in _codes(report) + + report2 = _valid_report() + rows2 = report2["throughput"]["k3_cross_model"]["stage_timings"] + rows2[0]["prefill_s"] = 10.0 + rows2[1]["prefill_s"] = 10.0 * (MAX_PREFILL_SPREAD + 0.1) + assert "SPEEDUP_PREFILL_VARIANCE" in _codes(report2) + + +# --------------------------------------------------------------------------- +# Memory claim rule +# --------------------------------------------------------------------------- + + +def test_memory_claim_requires_formula_match(): + report = _valid_report() + report["memory"]["s5"]["formula_matches_run"] = False + assert "MEMORY_CLAIM_MISMATCH" in _codes(report) + report["memory"]["s5"] = {} + assert "MEMORY_CLAIM_MISMATCH" in _codes(report) + + +def test_memory_rule_skipped_when_no_savings_claim(): + report = _valid_report() + report["memory"]["savings_vs_naive_pct"] = None + report["memory"]["s5"]["formula_matches_run"] = False + assert validate_report(report) == [] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def test_row_prefill_seconds_accepts_both_spellings(): + assert row_prefill_seconds({"prefill_s": 1.5}) == 1.5 + assert row_prefill_seconds({"restored_prefill_s": 2.5}) == 2.5 + assert row_prefill_seconds({"prefill_s": "bad"}) is None + assert row_prefill_seconds({}) is None + + +def test_prefill_spread(): + assert prefill_spread([{"prefill_s": 10.0}, {"prefill_s": 40.0}]) == 4.0 + assert prefill_spread([{"prefill_s": 10.0}]) is None + assert prefill_spread([{"prefill_s": 0.0}, {"prefill_s": -1}]) is None + assert prefill_spread([]) is None + + +def test_decode_only_block_happy_path(): + cross = [{"decode_s": 2.0}, {"decode_s": 4.0}] + oracle = [{"decode_s": 8.0}, {"decode_s": 8.0}] + block = decode_only_block(cross, [32, 32], oracle, [32, 32]) + assert block == { + "cross_median_tok_s": 12.0, # median(16, 8) + "oracle_median_tok_s": 4.0, + "speedup": 3.0, + } + + +def test_decode_only_block_unusable_samples_return_none(): + assert decode_only_block([{"decode_s": 0.0}], [8], [{"decode_s": 1.0}], [8]) is None + assert decode_only_block([{"decode_s": 1.0}], [0], [{"decode_s": 1.0}], [8]) is None + assert decode_only_block([{"decode_s": 1.0}], [8], [{}], [8]) is None + + +def test_summarize_violations_renders_codes(): + text = summarize_violations([ + GateViolation("A_CODE", "first"), + GateViolation("B_CODE", "second"), + ]) + assert text == " [A_CODE] first\n [B_CODE] second" + + +def test_the_committed_ctx280_report_shape_would_now_fail(): + """Regression lock: a report shaped like the real + k3_mlx_fused_fair_ctx280_n5_gen32 run (baseline-as-SUT, blocks=0, + gen=8, 4.1x oracle prefill spread, formula memory table) violates + multiple rules at schema 2 instead of presenting as a 2.584x win.""" + report = _valid_report() + for row in report["throughput"]["k3_cross_model"]["stage_timings"]: + row["restoration_active"] = False # native bypass ran + row["fused"] = {"blocks": 0, "mean_accept_len": 0.0} + report["results"]["k3_cross_model"]["per_sample_decode_tokens"] = [8, 7, 8, 8, 8] + report["results"]["oracle"]["per_sample_decode_tokens"] = [8, 7, 8, 8, 8] + oracle_rows = report["throughput"]["oracle_native_ar"]["stage_timings"] + oracle_rows[0]["prefill_s"] = 35.3 + oracle_rows[3]["prefill_s"] = 146.3 + report["throughput"]["oracle_native_ar"]["decode_loop"] = "per_token_eval" + del report["throughput"]["decode_only"] + report["throughput"]["cross_model_speedup_vs_oracle_ar"] = 2.584 + # The analytical sink+window table did not describe the native cache + # that actually ran. + report["memory"]["s5"]["formula_matches_run"] = False + + codes = _codes(report) + assert { + "BASELINE_AS_SUT", + "BASELINE_RECALL_CLAIM", + "RECALL_SCOPE", + "FUSED_NEVER_RAN", + "SPEEDUP_SELF_COMPARISON", + "SPEEDUP_DECODE_TOKENS", + "SPEEDUP_DECODE_ONLY_MISSING", + "SPEEDUP_ORACLE_LOOP", + "SPEEDUP_PREFILL_VARIANCE", + "MEMORY_CLAIM_MISMATCH", + } <= codes diff --git a/tests/inference_engine/bridge/__init__.py b/tests/inference_engine/bridge/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/inference_engine/bridge/test_manifest.py b/tests/inference_engine/bridge/test_manifest.py new file mode 100644 index 00000000..c69930d2 --- /dev/null +++ b/tests/inference_engine/bridge/test_manifest.py @@ -0,0 +1,285 @@ +"""Unit tests for inference_engine.bridge.manifest (Mac bridge core). + +This allowlist is the bridge's entire security argument (design doc +§3): what these tests pin is exactly what the Mac runner enforces, +because the executor imports this module. Every rejection path is +covered so a refactor cannot silently widen the command surface. + +Coverage target: 100% on ``inference_engine/bridge/manifest.py``. +""" + +from __future__ import annotations + +import json + +import pytest + +from inference_engine.bridge.manifest import ( + BRANCH_PREFIX, + MANIFEST_SCHEMA_VERSION, + MAX_BLOCK_SIZE, + MAX_N_SAMPLES, + MAX_NEW_TOKENS, + PRESETS, + BridgeRequest, + ManifestError, + build_commands, + parse_manifest, + parse_manifest_text, +) + +HARNESS_ENV = { + "KAKEYA_MAC_VERIFIER_PATH": "/models/gemma-4-26B-A4B-it-mlx-4bit", + "KAKEYA_MAC_DRAFTER_ID": "z-lab/gemma-4-26B-A4B-it-DFlash", + "KAKEYA_MAC_FTHETA_DIR": "results/research/f_theta_v5_s5_sliding", +} + + +def _manifest(**overrides): + data = { + "schema_version": MANIFEST_SCHEMA_VERSION, + "preset": "mlx-env-probe", + "params": {}, + "ref": "main", + "requested_by": "test", + "nonce": "1760000000-abc123", + } + data.update(overrides) + return data + + +# --------------------------------------------------------------------------- +# Allowlist shape +# --------------------------------------------------------------------------- + + +def test_allowlist_contains_exactly_the_documented_presets(): + assert sorted(PRESETS) == [ + "integration-tests", + "k3-drafter-parity", + "k3-drafter-parity-fp32", + "k3-evidence-gate", + "k3-fused-allmlx-code", + "k3-fused-allmlx-code-trim", + "k3-fused-allmlx-natural", + "k3-fused-singlefused-probe", + "k3-kv-quant-eval", + "k3-native-baseline", + "k3-step1-incremental", + "k3-step2-fused", + "k3-step2-fused-allmlx", + "mlx-backend-tests", + "mlx-env-probe", + "pytest-path", + ] + + +def test_every_preset_has_timeout_and_description(): + for preset in PRESETS.values(): + assert preset.timeout_minutes > 0 + assert preset.description + assert preset.command_templates + + +def test_harness_presets_validate_reports_others_do_not(): + gated = {name for name, p in PRESETS.items() if p.validate_reports} + assert gated == { + "k3-step1-incremental", "k3-step2-fused", "k3-native-baseline", + "k3-step2-fused-allmlx", + } + + +def test_allmlx_preset_carries_both_mode_flags(): + request = parse_manifest(_manifest(preset="k3-step2-fused-allmlx")) + (argv,) = build_commands(request, HARNESS_ENV) + assert "--fused-specdecode" in argv + assert "--all-mlx-drafter" in argv + assert "--ignore-turn-stop" in argv + + +def test_drafter_parity_preset_resolves(): + request = parse_manifest(_manifest( + preset="k3-drafter-parity", params={"block_size": "8"})) + (argv,) = build_commands(request, HARNESS_ENV) + assert argv[1].endswith("k3_mlx_drafter_parity.py") + assert HARNESS_ENV["KAKEYA_MAC_DRAFTER_ID"] in argv + assert argv[argv.index("--block-size") + 1] == "8" + + +# --------------------------------------------------------------------------- +# parse_manifest acceptance +# --------------------------------------------------------------------------- + + +def test_minimal_valid_manifest_parses(): + request = parse_manifest(_manifest()) + assert request.preset.name == "mlx-env-probe" + assert request.params == {} + assert request.branch_name == f"{BRANCH_PREFIX}mlx-env-probe-1760000000-abc123" + + +def test_round_trip_through_manifest_dict(): + request = parse_manifest(_manifest()) + again = parse_manifest(request.to_manifest_dict()) + assert again == request + + +def test_parse_manifest_text_valid_and_invalid_json(): + request = parse_manifest_text(json.dumps(_manifest())) + assert isinstance(request, BridgeRequest) + with pytest.raises(ManifestError, match="not valid JSON"): + parse_manifest_text("{nope") + + +def test_harness_preset_defaults_apply(): + request = parse_manifest(_manifest(preset="k3-step2-fused")) + assert request.params == { + "n_samples": "5", "max_new_tokens": "64", "block_size": "4", + } + + +def test_harness_preset_params_override_within_bounds(): + request = parse_manifest(_manifest( + preset="k3-step1-incremental", + params={"n_samples": "10", "max_new_tokens": "128", "block_size": "8"}, + )) + assert request.params == { + "n_samples": "10", "max_new_tokens": "128", "block_size": "8", + } + + +# --------------------------------------------------------------------------- +# parse_manifest rejection paths +# --------------------------------------------------------------------------- + + +def test_rejects_non_dict_and_wrong_schema(): + with pytest.raises(ManifestError, match="JSON object"): + parse_manifest(["not", "a", "dict"]) + with pytest.raises(ManifestError, match="schema_version"): + parse_manifest(_manifest(schema_version=99)) + + +def test_rejects_unknown_preset_listing_allowlist(): + with pytest.raises(ManifestError, match="allowlist"): + parse_manifest(_manifest(preset="rm-rf-everything")) + with pytest.raises(ManifestError, match="allowlist"): + parse_manifest(_manifest(preset=None)) + + +def test_rejects_unknown_params(): + with pytest.raises(ManifestError, match="does not accept params"): + parse_manifest(_manifest(params={"shell": "evil"})) + with pytest.raises(ManifestError, match="params must be an object"): + parse_manifest(_manifest(params="evil")) + + +def test_rejects_missing_required_param(): + with pytest.raises(ManifestError, match="requires param 'path'"): + parse_manifest(_manifest(preset="pytest-path")) + + +def test_rejects_out_of_bounds_ints(): + for name, bad in ( + ("n_samples", str(MAX_N_SAMPLES + 1)), + ("max_new_tokens", str(MAX_NEW_TOKENS + 1)), + ("block_size", str(MAX_BLOCK_SIZE + 1)), + ("n_samples", "0"), + ("n_samples", "-3"), + ): + with pytest.raises(ManifestError, match="out of bounds"): + parse_manifest(_manifest( + preset="k3-step2-fused", params={name: bad})) + + +def test_rejects_non_integer_int_params(): + with pytest.raises(ManifestError, match="not an integer"): + parse_manifest(_manifest( + preset="k3-step2-fused", params={"n_samples": "five; rm -rf /"})) + + +def test_pytest_path_traversal_and_escape_rejected(): + for bad in ("/etc/passwd", "~/x", "tests/../scripts/serve.py", + "scripts/serve.py", ""): + with pytest.raises(ManifestError): + parse_manifest(_manifest( + preset="pytest-path", params={"path": bad})) + + +def test_pytest_path_accepts_tests_subpaths(): + for ok in ("tests", "tests/backends/mlx/test_fused_specdecode.py", + "tests/integration/"): + request = parse_manifest(_manifest( + preset="pytest-path", params={"path": ok})) + assert request.params["path"] == ok + + +def test_rejects_bad_nonce_ref_requested_by(): + with pytest.raises(ManifestError, match="nonce"): + parse_manifest(_manifest(nonce="UPPER CASE!")) + with pytest.raises(ManifestError, match="nonce"): + parse_manifest(_manifest(nonce=None)) + with pytest.raises(ManifestError, match="ref"): + parse_manifest(_manifest(ref="")) + with pytest.raises(ManifestError, match="requested_by"): + parse_manifest(_manifest(requested_by="")) + + +# --------------------------------------------------------------------------- +# build_commands +# --------------------------------------------------------------------------- + + +def test_simple_preset_builds_fixed_argv(): + request = parse_manifest(_manifest(preset="mlx-backend-tests")) + commands = build_commands(request, {}) + assert commands == [ + ["python3", "-m", "pytest", "tests/backends/mlx/", "-q"], + ] + + +def test_pytest_path_param_substitution_is_argv_level(): + request = parse_manifest(_manifest( + preset="pytest-path", + params={"path": "tests/backends/mlx/test_env.py"})) + commands = build_commands(request, {}) + assert commands == [ + ["python3", "-m", "pytest", "tests/backends/mlx/test_env.py", "-q"], + ] + + +def test_harness_preset_resolves_env_and_params(): + request = parse_manifest(_manifest( + preset="k3-step2-fused", + params={"n_samples": "7", "max_new_tokens": "96", "block_size": "8"}, + )) + (argv,) = build_commands(request, HARNESS_ENV) + assert argv[0:2] == ["python3", "scripts/research/k3_integrated_niah_eval_mac.py"] + assert HARNESS_ENV["KAKEYA_MAC_VERIFIER_PATH"] in argv + assert "--fused-specdecode" in argv + assert "--ignore-turn-stop" in argv # full decode budget (gate rule) + assert argv[argv.index("--n-samples") + 1] == "7" + assert argv[argv.index("--max-new-tokens") + 1] == "96" + assert argv[argv.index("--block-size") + 1] == "8" + # No unresolved placeholders of either kind survive. + assert not [t for t in argv if t.startswith("${ENV:")] + assert not [t for t in argv if t.startswith("{") and t.endswith("}")] + + +def test_step1_and_baseline_presets_carry_their_mode_flags(): + incr = parse_manifest(_manifest(preset="k3-step1-incremental")) + (argv_incr,) = build_commands(incr, HARNESS_ENV) + assert "--incremental" in argv_incr + base = parse_manifest(_manifest(preset="k3-native-baseline")) + (argv_base,) = build_commands(base, HARNESS_ENV) + assert "--native-baseline-bypass" in argv_base + + +def test_missing_runner_env_is_a_hard_error(): + request = parse_manifest(_manifest(preset="k3-step2-fused")) + with pytest.raises(ManifestError, match="KAKEYA_MAC_VERIFIER_PATH"): + build_commands(request, {}) + partial = dict(HARNESS_ENV) + partial["KAKEYA_MAC_DRAFTER_ID"] = "" + with pytest.raises(ManifestError, match="KAKEYA_MAC_DRAFTER_ID"): + build_commands(request, partial)