diff --git a/.gitattributes b/.gitattributes index 52856e3f..9bee89b0 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,6 @@ models/dflash-kakeya-baseline/*.safetensors filter=lfs diff=lfs merge=lfs -text +results/research/f_theta_v1/f_theta_weights.pt filter=lfs diff=lfs merge=lfs -text +results/research/f_theta_v3_attn_distill/f_theta_weights.pt filter=lfs diff=lfs merge=lfs -text +results/research/f_theta_v4a_warmstart_hybrid/f_theta_weights.pt filter=lfs diff=lfs merge=lfs -text +results/research/f_theta_v4b_fresh_hybrid/f_theta_weights.pt filter=lfs diff=lfs merge=lfs -text +results/research/f_theta_v5_s5_sliding/f_theta_weights.pt filter=lfs diff=lfs merge=lfs -text diff --git a/docs/k3-gpu-beta.md b/docs/k3-gpu-beta.md new file mode 100644 index 00000000..70998b2c --- /dev/null +++ b/docs/k3-gpu-beta.md @@ -0,0 +1,109 @@ +# K3 GPU beta — Kakeya inference (f_θ + S5 K/V-Restoration) + +Status: beta, GPU-validated on NVIDIA H200 with `google/gemma-4-26B-A4B-it` +(verifier) + `z-lab/gemma-4-26B-A4B-it-DFlash` (drafter) + the trained f_θ v5 +checkpoint (`results/research/f_theta_v5_s5_sliding/`). Recall 1.0 throughout. + +## What it is + +The verifier keeps only a **sink+window** local KV cache; at every *evicted* +position its attention reads **reconstructed** K/V, so it attends over the full +context while holding `O(sink+window)` resident KV (ADR 0008 §11). + + verifier (Gemma 4 26B-A4B): sink+window resident KV + ├─ sliding layers → evicted K/V restored via f_θ(drafter K/V) + └─ full-attn layers (S5: [5,11,17,23,29]) → verifier's OWN exact K/V + (recall-critical; f_θ cannot reconstruct these — + proven by the α-sweep, eval rel_mse floor ~1.4) + + drafter (DFlash 0.4B): no KV cache; constant-memory K/V reconstruction + source (its K/V are projected into verifier space by f_θ). 
+ +## Components (this branch) + +| piece | file | +|---|---| +| DFlash drafter (block diffusion, faithful to z-lab `qwen3_dflash`) | `inference_engine/v04/dflash_drafter.py` | +| f_θ projection (drafter K/V → verifier K/V) | `inference_engine/v04/f_theta.py` | +| Cross-model restored verifier (CUDA) + S5 | `inference_engine/v04/cross_model_dlm_verifier.py` | +| Cross-model restored verifier (MLX / Apple Silicon) | `inference_engine/backends/mlx/cross_model_dlm_verifier.py` | +| Incremental restored verifier (`SinkWindowVerifier` API) | `inference_engine/v04/restored_sink_window_verifier.py` | +| Served-path factories + gRPC `--backend restored` | `inference_engine/v04/build_restored.py`, `scripts/start_grpc_runtime_server.py` | + +## Three engines (decode modes) + +* **Re-forward** (`incremental=False`) — memory-optimal, eval-grade; recomputes + restoration each step (O(T)/step). Bit-equivalent reference for the gate. +* **Gap-A incremental** (`incremental=True`) — capture restored K/V into a + `DynamicCache` at prefill, decode natively (O(L)/block). **= AR decode speed**, + KV 16.9×–43.9× smaller, recall 1.0. +* **Fused spec-decode** (`restored_specdecode_fused`) — DFlash block draft + + incremental verify, with three prefill-built, incrementally-extended caches: + (A) verifier aux hidden captured from the verify forward, (B) drafter context + K/V cache, (C) Gap-A restored KV. Per-block O(L). **> AR** (see below). + +## Validated results (H200, ctx 1238, gemma-4-26B-A4B) + +| path | decode tok/s | vs AR | recall | +|---|---|---|---| +| standalone AR | 21.1 | 1.0× | 1.0 | +| Gap-A incremental restored | 21.7 | 1.03× | 1.0 | +| fused DFlash spec-decode (aggregate) | 26.8 | **1.27×** | 1.0 | + +KV memory: restored resident KV constant **16.71 MB** vs AR 282 MB @1238 tok → +733 MB @3238 tok (**16.9× → 43.9×**, grows with context). DFlash acceptance on +HumanEval ≈ official gemma-4-26B parity (length ~3.9 ≈ official 3.3× speedup). 
+ +## Run + +```bash +# Incremental restored decode vs AR (memory + tok/s + recall) +PYTHONPATH=.:sdks/python python scripts/research/k3_e2e_gpu_bench.py \ + --verifier-id google/gemma-4-26B-A4B-it \ + --drafter-id z-lab/gemma-4-26B-A4B-it-DFlash \ + --f-theta-dir results/research/f_theta_v5_s5_sliding \ + --incremental --haystack-lines 60,160 + +# Fused DFlash spec-decode vs AR +PYTHONPATH=.:sdks/python python scripts/research/k3_specdecode_gpu_bench.py \ + --drafter-id z-lab/gemma-4-26B-A4B-it-DFlash --skip-unfused + +# gRPC server with the restored backend +PYTHONPATH=.:sdks/python python scripts/start_grpc_runtime_server.py \ + --backend restored --device cuda \ + --verifier-id google/gemma-4-26B-A4B-it \ + --drafter-id z-lab/gemma-4-26B-A4B-it-DFlash \ + --f-theta-dir results/research/f_theta_v5_s5_sliding --sink 4 --window 64 +``` + +## Canonical proposer + +The proposer/drafter is **`z-lab/gemma-4-26B-A4B-it-DFlash`** (the official +checkpoint, with the Gap-B embed-scale fix) — used uniformly for both drafting +and as the f_θ restoration K/V source across all entry points. The earlier +`models/dflash-kakeya-baseline` was alignment-trained against a buggy +(`×sqrt(hidden)`-scaled) embed pipeline and is not the beta drafter. + +f_θ v5 was trained against the kakeya-baseline drafter, so its **sliding-layer** +restoration is technically off for z-lab K/V — but this is **harmless for +recall**: recall is carried by the S5 exact full-attention layers, and the +sliding-layer restored K/V are window-masked during decode. Both incremental +decode and fused spec-decode measure **recall 1.0** with z-lab. (If pure +sliding-layer restoration is ever needed, retrain f_θ on z-lab K/V.) + +All **inference/eval** entry points default to z-lab (`k3_e2e_gpu_bench`, +`k3_specdecode_gpu_bench`, `k3_integrated_niah_eval`(+`_mac`), +`k3_dflash_specdecode_eval`(+`_mac`); the gRPC server takes an explicit +`--drafter-id`). 
The **f_θ training** script (`k3_f_theta_train.py`) and its +orchestration `.sh` keep `models/dflash-kakeya-baseline` because that is how the +shipped v5 checkpoint was historically trained. + +## Notes / scope + +* Drafting conditions on the restored verifier hidden for committed decode tokens + (clean aux for the prompt) — resolves the bounded-KV vs clean-aux tension + natively; no SGLang/vLLM dependency. +* Stable decode requires loading the verifier without `device_map` (no accelerate + per-forward hooks; the 26B-A4B fits on one H200) + a full-length warmup. +* f_θ v5 restores the sliding layers; recall is carried by the S5 exact + full-attention layers, so f_θ fidelity is not the recall bottleneck. diff --git a/docs/mlx-port-lessons.md b/docs/mlx-port-lessons.md new file mode 100644 index 00000000..5a977125 --- /dev/null +++ b/docs/mlx-port-lessons.md @@ -0,0 +1,82 @@ +# Porting the K3 GPU beta (#107) to MLX — lessons & plan + +Audience: whoever ports the validated CUDA restored-verifier engine +(`inference_engine/v04/…`, PR #107) to the Apple-Silicon MLX backend +(`inference_engine/backends/mlx/…`). The current MLX blocker is **decode +token-throughput collapse**. This doc distills *why* #107 is fast and exactly +which mechanisms must be reproduced in MLX. + +## TL;DR — the throughput collapse is the O(T²) re-forward + +On MLX today, `restored_logits` (`backends/mlx/cross_model_dlm_verifier.py`) does +a **full-position forward over the whole sequence every step**, and the Mac +harness calls it per generated token → **O(T²)** → collapse (the same harness +also shows the *oracle* is fast because it uses mlx_lm's **native incremental KV +cache**). 
The fix is the #107 **Gap-A** trick, ported verbatim: + +> **Capture the restored K/V into a persistent (sink+window) cache at prefill, +> then decode with mlx_lm's native incremental step (O(L)/block) — never +> re-forward the whole sequence per token.** + +This alone takes the restored path from "collapsed" to **= native AR decode +speed** (on CUDA: 1.3–2.8 tok/s re-forward → ~21 tok/s incremental = AR). + +## What makes #107 fast — and the MLX analog of each + +| # | #107 (CUDA) mechanism | MLX analog / gotcha | +|---|---|---| +| 1 | **Gap-A incremental decode**: capture restored K/V (per layer, post-norm/RoPE) into a `transformers.DynamicCache` at prefill; decode L new tokens against it. | Capture into `inference_engine/backends/mlx/cache.SinkWindowKVCache` (already exists) and decode via **`mlx_lm.generate.generate_step`** with `prompt_cache=` — its **chunked prefill + `mx.async_eval` pipelined decode** is the throughput-critical part. A hand-rolled per-token loop with `mx.eval` each step is itself a collapse cause. | +| 2 | **S5 carries recall** via the 5 full-attention layers' **exact own K/V**; f_θ restores only the sliding layers (masked at decode). | Same: store the 5 full-attn evicted own K/V (KakeyaLattice-compressible); **do not** invest in f_θ sliding fidelity for recall. The needle reaches output through the full-attn layers only. | +| 3 | **Eliminate the extra `capture_own_kv` forward**: in #107 the full-attn own K/V are captured once at prefill (not recomputed per step). PR #108 showed removing it via *f_θ full-attn* breaks recall — wrong fix. | The Mac harness's 12.4s `build_restoration` is this extra forward. Right fix: capture own K/V from the **prefill** forward / store as positions evict — **not** f_θ-restore the full-attn layers. | +| 4 | **Fused spec-decode (>AR)** = three prefill-built, incrementally-extended caches: (A) verifier aux hidden from the verify forward, (B) drafter context K/V cache, (C) Gap-A restored KV. Per-block O(L). 
| Port `draft_block_cached` + `make/extend_context_kv` semantics to the MLX drafter path; capture aux from the MLX verify forward. Only after #1 works. | +| 5 | **Stabilization**: load verifier **without `device_map`** (no accelerate per-forward hooks) + **full-length warmup** (pre-size the allocator) → removed per-block variance. | MLX analog of the variance source is **graph (re)compilation + lazy eval**: warm up the *exact* shapes (prefill chunk size + 1-token decode) before timing; avoid shape churn; force `mx.eval` only where measuring. | +| 6 | **Gap-B drafter fidelity**: drafter query embedding is a **plain lookup — no Gemma `×sqrt(hidden)`** (port bug; fixed). | Same fix on the MLX drafting path: do not scale the shared embedding fed to the drafter. (z-lab acceptance 0.05→reference parity.) | + +## MLX-specific gotchas already learned + +- **MPS/MLX SDPA materializes scores** (no flash kernel for some shapes) → OOM at + long context. Use **bounded attention** (decode only attends sink+window+restored + evicted, not a transient full O(T) matrix) and/or **query-chunked SDPA** + (`KAKEYA_DFLASH_ATTN_QCHUNK`). Bounded decode (Gap-A) avoids the transient full + cache that OOM'd the ctx280 runs. +- **Lazy eval**: MLX is lazy; throughput depends on `mx.async_eval` pipelining + (mlx_lm's `generate_step` does this). Per-token `mx.eval().item()` serializes → + collapse. Mirror the native loop. +- **`make_sink_window_cache(model, *, sink_size, window_size)`** is keyword-only + (a past bug was positional args). The cache is a drop-in `_BaseCache`. +- **Cross-runtime bridge**: verifier in MLX, drafter+f_θ in PyTorch (MPS/CPU) is + workable, but the per-step tensor bridging must not re-forward; bridge only at + the K/V-injection boundary, once per block. + +## MLX port plan (ordered; each gates the next) + +1. 
**Incremental decode (kills the collapse).** Add an MLX analog of + `CrossModelRestoredSinkWindowVerifier(incremental=True)`: prefill → capture + restored K/V into `SinkWindowKVCache` (full-attn = own/exact; sliding = f_θ or + window-masked) → decode via `generate_step(prompt_cache=…)`. **Gate: decode + tok/s ≈ native mlx_lm AR; recall 1.0** (carried by S5). +2. **Drop the extra build forward.** Capture full-attn own K/V at prefill; do not + re-run a clean verifier forward per request beyond prefill. **Gate: + `build_restoration` from ~12s → ~prefill cost.** +3. **Gap-B drafter embed fix** (no `×sqrt`) on the MLX/Bridge drafting path. + **Gate: acceptance toward reference on code prompts.** +4. **Fused spec-decode** (A+B+C incremental caches). **Gate: tok/s > AR.** + +## Validation gates (match #107 evidence) + +- Recall **1.0** vs oracle (S5). +- Bounded resident KV (sink+window), reported via `kv_memory_report`. +- Decode tok/s: incremental **≥ native AR**; fused **> AR**. +- Reference: #107 on H200 — incremental = 1.0× AR (KV 16.9–43.9× smaller), + fused 1.27× AR, recall 1.0. (`docs/k3-gpu-beta.md`, + `results/research/k3_e2e_gpu_bench_incremental.json`, + `k3_specdecode_fused_stable.json`.) + +## Do-not-repeat (anti-patterns) + +- ❌ Re-forwarding the full sequence per generated token (the current collapse). +- ❌ A custom decode loop with per-token `mx.eval` (no async pipelining). +- ❌ f_θ-restoring the **full-attention** layers (PR #108: breaks recall; those + K/V are not reconstructable from the shallow drafter — α-sweep proven). Keep S5. +- ❌ Scaling the drafter's shared embedding by `×sqrt(hidden)` (Gap-B port bug). +- ❌ Materializing a transient full-T attention score matrix on MPS (OOM). 
diff --git a/inference_engine/backends/mlx/cross_model_dlm_verifier.py b/inference_engine/backends/mlx/cross_model_dlm_verifier.py new file mode 100644 index 00000000..29a27b27 --- /dev/null +++ b/inference_engine/backends/mlx/cross_model_dlm_verifier.py @@ -0,0 +1,362 @@ +"""MLX cross-model DLM-restored verifier (K3 Mac path). + +Apple-Silicon (MLX) counterpart of +:class:`inference_engine.v04.cross_model_dlm_verifier.CrossModelDLMRestoredVerifier` +(the validated CUDA/transformers implementation). Same architecture: + + * verifier = Gemma 4 26B-A4B (MLX 4-bit, ``mlx_lm``) + * proposer/drafter = DFlash 0.4B (PyTorch ``DFlashDrafter``, on MPS/CPU) + * f_θ = trained K/V projection (PyTorch ``FThetaProjection``) + +The verifier holds only a sink+window local cache; at *evicted* positions +its attention reads **restored** K/V: + + * **sliding-attention layers** → f_θ projection of the drafter's K/V + * **full-attention (global) layers** → the verifier's OWN true K/V + (**S5**). These are the recall-critical layers f_θ cannot reconstruct + (CUDA eval rel_mse floor ~1.4). For long context the needle is outside + the sliding window, so it reaches the output only through these layers + — exact K/V there gives oracle-parity recall (CUDA ctx280: 10/10). + +Cross-runtime design (mirrors ``scripts/research/k3_dflash_mlx_bridge.py``): +verifier in MLX; drafter + f_θ in PyTorch; tensors bridged at the +K/V-injection boundary. f_θ weights stay PyTorch (no MLX port for v1). + +Mechanism: MLX ``nn.Module`` resolves ``__call__`` on the *class*, so we +temporarily patch ``Attention.__call__`` (verified against +mlx_lm.models.gemma4_text 0.31.3) with a dispatcher that, for layers +carrying a per-instance ``_kakeya_inject`` config, replaces evicted-position +K/V (via ``mx.where``) with restored pre-norm K/V before k_norm + RoPE; +other layers fall through to the original. This is the MLX analogue of the +CUDA ``_make_patched_forward`` (which patches ``layer.self_attn.forward``). 
def resolve_mlx_text_model(mlx_model: Any) -> Any:
    """Locate the MLX Gemma text model inside ``mlx_model``.

    Two wrapper layouts are unwrapped: the multimodal one
    (``model.language_model.model``) and the text-only one
    (``model.model``). An object that has no inner ``model`` attribute but
    already carries ``embed_tokens`` is returned unchanged.

    Args:
        mlx_model: the loaded MLX model (or the text model itself).

    Returns:
        The object exposing ``embed_tokens``.

    Raises:
        AttributeError: if no candidate exposing ``embed_tokens`` is found.
    """
    outer = getattr(mlx_model, "language_model", mlx_model)
    candidate = getattr(outer, "model", None)
    if candidate is None:
        # No inner ``.model``: the wrapper may itself be the text model.
        candidate = outer if hasattr(outer, "embed_tokens") else None
    if candidate is not None and hasattr(candidate, "embed_tokens"):
        return candidate
    raise AttributeError(
        "Could not locate MLX Gemma text model "
        "(expected model.language_model.model or model.model)"
    )
def mlx_full_attention_layer_indices(text_model: Any) -> List[int]:
    """Return the indices of the full-attention (global) layers.

    When the layers' ``head_dim`` values differ, the layers holding the
    largest ``head_dim`` are reported. When every layer has the same
    ``head_dim``, the layers labelled ``layer_type == 'full_attention'`` are
    reported instead (which may be an empty list).
    """
    attn_modules = [layer.self_attn for layer in text_model.layers]
    dims = [int(getattr(attn, "head_dim", 0)) for attn in attn_modules]
    if len(set(dims)) > 1:
        widest = max(dims)
        return [idx for idx, dim in enumerate(dims) if dim == widest]
    labels = [str(getattr(attn, "layer_type", "")) for attn in attn_modules]
    return [idx for idx, label in enumerate(labels) if label == "full_attention"]


def per_layer_kv_geometry(text_model: Any) -> List[Tuple[int, int, str]]:
    """Return one ``(n_kv_heads, head_dim, layer_type)`` tuple per layer.

    Missing attributes fall back to ``0`` (numeric fields) or ``""`` (type).
    """
    return [
        (
            int(getattr(layer.self_attn, "n_kv_heads", 0)),
            int(getattr(layer.self_attn, "head_dim", 0)),
            str(getattr(layer.self_attn, "layer_type", "")),
        )
        for layer in text_model.layers
    ]


def kv_memory_report(
    text_model: Any,
    *,
    sink_size: int,
    window_size: int,
    seq_len: int,
    kv_dtype_bytes: int = 2,
    exact_layer_indices: Optional[Sequence[int]] = None,
    compress_full_bits_per_token_per_head: Optional[float] = None,
) -> Dict[str, Any]:
    """Compute an analytical resident-KV byte budget for the S5 engine.

    Pure arithmetic over :func:`per_layer_kv_geometry`:

    * layers listed in ``exact_layer_indices`` keep ``seq_len`` positions;
    * every other layer keeps ``min(sink_size + window_size, seq_len)``.

    A position costs ``2 * n_kv_heads * head_dim * kv_dtype_bytes`` bytes
    (K and V both counted). For exact layers, when
    ``compress_full_bits_per_token_per_head`` is given, the cost is
    ``round(2 * n_kv_heads * bits / 8)`` instead.

    Returns:
        A dict with the per-category byte totals, the grand total (bytes and
        MB), the per-token growth (bytes and KB) and a ``per_layer`` list of
        per-layer breakdown dicts.
    """
    exact_set = set(exact_layer_indices or [])
    bounded_positions = sink_size + window_size
    bits = compress_full_bits_per_token_per_head
    # Non-exact layers only contribute to growth while the window is not full.
    window_not_exceeded = seq_len <= bounded_positions

    def bytes_per_position(n_kv: int, hd: int, compressed: bool) -> int:
        # K and V are both counted, even where the model may share them.
        if compressed and bits is not None:
            return int(round(2 * n_kv * (bits / 8.0)))
        return 2 * n_kv * hd * kv_dtype_bytes

    rows = []
    sliding_bytes = 0
    full_bytes = 0
    growth_full = 0
    growth_sliding = 0
    for idx, (n_kv, hd, layer_type) in enumerate(per_layer_kv_geometry(text_model)):
        exact_here = idx in exact_set
        per_pos = bytes_per_position(n_kv, hd, exact_here)
        if exact_here:
            resident = seq_len
            full_bytes += resident * per_pos
            growth_full += per_pos
        else:
            resident = min(bounded_positions, seq_len)
            sliding_bytes += resident * per_pos
            if window_not_exceeded:
                growth_sliding += per_pos
        rows.append({
            "layer": idx, "layer_type": layer_type, "n_kv_heads": n_kv,
            "head_dim": hd, "exact": exact_here,
            "resident_positions": resident,
            "bytes_per_token": per_pos,
            "resident_bytes": resident * per_pos,
        })

    grand_total = sliding_bytes + full_bytes
    growth = growth_full + growth_sliding
    return {
        "seq_len": seq_len,
        "kv_dtype_bytes": kv_dtype_bytes,
        "sink_window": bounded_positions,
        "exact_layer_indices": sorted(exact_set),
        "compress_full_bits_per_token_per_head": bits,
        "sliding_resident_bytes": sliding_bytes,
        "full_resident_bytes": full_bytes,
        "total_resident_bytes": grand_total,
        "total_resident_mb": round(grand_total / 1e6, 2),
        "per_token_growth_bytes": growth,
        "per_token_growth_kb": round(growth / 1024, 2),
        "per_layer": rows,
    }
K/V. + + For KV-shared layers (``has_kv == False``) the source is the earlier + same-type layer in ``text_model.previous_kvs``; otherwise self. + """ + n = len(text_model.layers) + prev = list(getattr(text_model, "previous_kvs", list(range(n)))) + src: List[int] = [] + for i in range(n): + attn = text_model.layers[i].self_attn + has_kv = bool(getattr(attn, "has_kv", True)) + src.append(i if has_kv else int(prev[i])) + return src + + +# --------------------------------------------------------------------------- +# MLX attention dispatcher (class-level patch with per-instance config) +# --------------------------------------------------------------------------- + + +def _build_dispatch(orig_call: Callable) -> Callable: + """Build a replacement ``Attention.__call__`` (mlx_lm 0.31.3) that honors + a per-instance ``self._kakeya_inject`` config when present, else delegates + to ``orig_call``. + + Config dict keys: ``mode`` ("capture"|"inject"), ``restored_k``, + ``restored_v`` (mx [B,T,n_kv,hd] pre-norm), ``evicted_mask`` (mx bool [T]), + ``sink`` (dict for capture mode). 
+ """ + import mlx.core as mx # type: ignore + from mlx_lm.models.base import scaled_dot_product_attention as _sdpa # type: ignore + + def dispatch(self, x, mask=None, cache=None, shared_kv=None, offset=None): + cfg = getattr(self, "_kakeya_inject", None) + if cfg is None: + return orig_call(self, x, mask, cache, shared_kv, offset) + + B, L, _ = x.shape + queries = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim) + queries = self.q_norm(queries) + + if shared_kv is not None: + keys, values = shared_kv + elif not getattr(self, "has_kv", True): + raise ValueError( + f"Layer {self.layer_idx} is KV-shared but received no shared_kv" + ) + else: + keys = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim) + values = keys + if not self.use_k_eq_v: + values = self.v_proj(x).reshape( + B, L, self.n_kv_heads, self.head_dim + ) + + mode = cfg.get("mode") + if mode == "capture": + cfg["sink"][self.layer_idx] = (keys, values) + elif mode == "inject": + em = cfg.get("evicted_mask") + if em is not None: + m = em.reshape(1, L, 1, 1) + rk = cfg.get("restored_k") + if rk is not None: + keys = mx.where(m, rk.astype(keys.dtype), keys) + if self.use_k_eq_v: + values = keys + else: + rv = cfg.get("restored_v") + if rv is not None: + values = mx.where(m, rv.astype(values.dtype), values) + + offset = mx.array(cache.offset) if cache is not None else 0 + keys = self.k_norm(keys) + keys = keys.transpose(0, 2, 1, 3) + keys = self.rope(keys, offset=offset) + values = self.v_norm(values) + values = values.transpose(0, 2, 1, 3) + + queries = queries.transpose(0, 2, 1, 3) + queries = self.rope(queries, offset=offset) + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + output = _sdpa(queries, keys, values, cache=cache, scale=self.scale, mask=mask) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (keys, values), offset + + return dispatch + + +@contextlib.contextmanager +def _patched_attention_class(text_model: 
Any): + """Temporarily replace the Attention class ``__call__`` with the + Kakeya dispatcher; restore on exit. Clears any ``_kakeya_inject`` configs. + """ + if not text_model.layers: + yield + return + attn_cls = type(text_model.layers[0].self_attn) + orig_call = attn_cls.__call__ + attn_cls.__call__ = _build_dispatch(orig_call) # type: ignore[assignment] + try: + yield + finally: + attn_cls.__call__ = orig_call # type: ignore[assignment] + for layer in text_model.layers: + if hasattr(layer.self_attn, "_kakeya_inject"): + delattr(layer.self_attn, "_kakeya_inject") + + +def capture_own_kv(mlx_model: Any, input_ids: Sequence[int]) -> Dict[int, Tuple[Any, Any]]: + """Clean forward recording each source layer's pre-norm K/V (mx arrays). + + Mirrors CUDA ``capture_verifier_own_kv``: ``{layer_idx: (k, v)}`` for + ``has_kv`` layers, each ``[B, T, n_kv, head_dim]`` pre-norm. Supplies the + S5 exact K/V for the full-attention layers. + """ + import mlx.core as mx # type: ignore + + text_model = resolve_mlx_text_model(mlx_model) + sink: Dict[int, Tuple[Any, Any]] = {} + with _patched_attention_class(text_model): + for layer in text_model.layers: + layer.self_attn._kakeya_inject = {"mode": "capture", "sink": sink} + ids = mx.array([list(input_ids)]) + _ = text_model(ids, cache=None) + mx.eval([t for kv in sink.values() for t in kv]) + return sink + + +def restored_logits( + mlx_model: Any, + input_ids: Sequence[int], + *, + restored_k_per_layer: Dict[int, Any], # source_layer_idx -> mx [B,T,n_kv,hd] pre-norm + restored_v_per_layer: Dict[int, Any], + evicted_positions: Sequence[int], + return_all: bool = False, +) -> Any: + """Run the verifier with evicted-position K/V restoration. + + Returns the last-row logits (mx.array ``[V]``) by default, or all-position + logits (``[T, V]``) when ``return_all=True`` (used by the teacher-forced + single-forward recall eval). Injects only at ``has_kv`` source layers + (sharers inherit via ``shared_kv``). 
+ """ + import mlx.core as mx # type: ignore + + text_model = resolve_mlx_text_model(mlx_model) + T = len(list(input_ids)) + mask_bool = [False] * T + for p in evicted_positions: + if 0 <= p < T: + mask_bool[p] = True + evicted_mask = mx.array(mask_bool) + + with _patched_attention_class(text_model): + for idx, layer in enumerate(text_model.layers): + attn = layer.self_attn + if not bool(getattr(attn, "has_kv", True)): + continue # sharers inherit injected K/V via shared_kv + rk = restored_k_per_layer.get(idx) + rv = restored_v_per_layer.get(idx) + if rk is None: + continue + attn._kakeya_inject = { + "mode": "inject", + "evicted_mask": evicted_mask, + "restored_k": rk, + "restored_v": rv, + } + ids = mx.array([list(input_ids)]) + logits = mlx_model(ids) # full Model.__call__ → tied embed + softcap + mx.eval(logits) + return logits[0] if return_all else logits[0, -1] diff --git a/inference_engine/v04/__init__.py b/inference_engine/v04/__init__.py index f63efca6..0d2ed86f 100644 --- a/inference_engine/v04/__init__.py +++ b/inference_engine/v04/__init__.py @@ -49,6 +49,18 @@ DFlashDrafter, DFlashProposer, ) +from inference_engine.v04.f_theta import FThetaConfig, FThetaProjection +from inference_engine.v04.cross_model_dlm_verifier import ( + CrossModelDLMRestoredVerifier, + CrossModelLayerMapping, +) +from inference_engine.v04.restored_sink_window_verifier import ( + CrossModelRestoredSinkWindowVerifier, +) +from inference_engine.v04.build_restored import ( + build_restored_speculative_decoder, + load_restored_verifier, +) from inference_engine.v04.kv_compressor import ( IdentityCompressor, KakeyaLatticeCompressor, @@ -122,4 +134,15 @@ "DFlashConfig", "DFlashDrafter", "DFlashProposer", + # K3 Block C — f_θ K/V projection + "FThetaConfig", + "FThetaProjection", + # K3 Block B — cross-model DLMRestoredVerifier with f_θ-mediated + # K/V Restoration (the integrated Kakeya inference architecture) + "CrossModelDLMRestoredVerifier", + "CrossModelLayerMapping", + # Gap 1 + 
def build_restored_speculative_decoder(
    proposer: Any,
    verifier: CrossModelRestoredSinkWindowVerifier,
    *,
    block_size: int = 16,
    num_diffusion_steps: int = 16,
):
    """Return a :class:`SpeculativeDecoder` over ``proposer`` + restored
    ``verifier`` (the f_θ + S5 K/V-Restoration verifier).

    Pure plumbing: the four arguments are forwarded unchanged to
    ``kv_cache_proposer.speculative.SpeculativeDecoder``, which is imported
    lazily so this module can be imported without that package loaded.
    """
    from kv_cache_proposer.speculative import SpeculativeDecoder

    return SpeculativeDecoder(
        proposer=proposer,
        verifier=verifier,
        block_size=block_size,
        num_diffusion_steps=num_diffusion_steps,
    )


def load_restored_verifier(
    *,
    verifier_id: str,
    drafter_id: str,
    f_theta_dir: str,
    sink_size: int = 4,
    window_size: int = 64,
    s5_exact_full_attn: bool = True,
    device: str = "cpu",
    dtype: Optional[Any] = None,
    incremental: bool = True,
    device_map: Optional[Any] = "auto",
) -> CrossModelRestoredSinkWindowVerifier:  # pragma: no cover - heavy model loader
    """Load Gemma 4 verifier + DFlash drafter + f_θ and build the restored
    sink+window verifier adapter.

    Coverage-exempt (model-loading plumbing): validated by GPU integration
    runs, mirroring ``scripts/research/k3_integrated_niah_eval.py``.

    Args:
        verifier_id: model id or path for ``AutoModelForCausalLM``.
        drafter_id: model id or path for ``DFlashDrafter.from_pretrained``.
        f_theta_dir: directory for ``FThetaProjection.from_pretrained``.
        sink_size: sink positions kept by the restored verifier.
        window_size: window positions kept by the restored verifier.
        s5_exact_full_attn: when true, the verifier's full-attention layers
            are passed as ``exact_layer_indices``; otherwise ``None``.
        device: torch device string for the drafter, f_θ and the adapter.
        dtype: model dtype. Defaults to bfloat16 on CUDA, float32 elsewhere.
        incremental: forwarded to ``CrossModelRestoredSinkWindowVerifier``.
        device_map: accelerate placement for the verifier, honored only when
            ``device`` is a CUDA device (default ``"auto"``, the previous
            behavior). Pass ``None`` to load the whole verifier on ``device``
            without accelerate hooks — the configuration
            ``docs/k3-gpu-beta.md`` describes as required for stable decode.

    Returns:
        The restored sink+window verifier adapter.
    """
    import torch
    from transformers import AutoModelForCausalLM
    from transformers.models.gemma4.modeling_gemma4 import (  # type: ignore
        ALL_ATTENTION_FUNCTIONS,
        apply_rotary_pos_emb,
        eager_attention_forward,
    )

    from inference_engine.v04 import DFlashDrafter, FThetaProjection
    from inference_engine.v04.cross_model_dlm_verifier import (
        CrossModelDLMRestoredVerifier,
        full_attention_layer_indices,
    )

    dev = torch.device(device)
    if dtype is None:
        dtype = torch.bfloat16 if dev.type == "cuda" else torch.float32

    # accelerate placement is only ever used on CUDA (as before).
    effective_device_map = device_map if dev.type == "cuda" else None
    verifier = AutoModelForCausalLM.from_pretrained(
        verifier_id,
        dtype=dtype,
        attn_implementation="eager",
        device_map=effective_device_map,
    )
    if effective_device_map is None:
        # Without a device_map the weights load on CPU; move them explicitly
        # so the verifier is on the same device as the drafter and f_θ
        # (previously a non-CUDA, non-CPU device left the verifier on CPU).
        verifier = verifier.to(dev)
    verifier = verifier.eval()
    for p in verifier.parameters():
        p.requires_grad_(False)

    drafter = DFlashDrafter.from_pretrained(drafter_id, dtype=dtype).to(dev).eval()
    for p in drafter.parameters():
        p.requires_grad_(False)

    f_theta = FThetaProjection.from_pretrained(
        f_theta_dir, dtype=torch.float32, device=dev,
    )

    exact_layers = full_attention_layer_indices(verifier) if s5_exact_full_attn else None

    restored = CrossModelDLMRestoredVerifier(
        verifier_model=verifier,
        drafter=drafter,
        f_theta=f_theta,
        sink_size=sink_size,
        window_size=window_size,
        exact_layer_indices=exact_layers,
    )
    return CrossModelRestoredSinkWindowVerifier(
        restored,
        apply_rotary_pos_emb=apply_rotary_pos_emb,
        eager_attention_forward=eager_attention_forward,
        all_attention_functions=ALL_ATTENTION_FUNCTIONS,
        device=device,
        incremental=incremental,
    )
+ +The integrated Kakeya inference architecture per ADR 0008 §11.3: + + verifier (Gemma 4 26B-A4B): + ├─ holds only sink+window local KV cache + └─ at evicted positions, takes K/V supplied by the proposer + (via `f_θ` projection) — verifier attends over full context + despite holding O(sink+window) memory + + drafter (DFlash 0.4B, K3 alignment-trained baseline): + ├─ runs full forward over committed prefix per step (no cache) + ├─ K/V at every layer at every position captured via + │ `inference_engine.v04.capture_proposer_kv` + └─ K/V projected through `f_θ` into verifier K/V space, injected + at evicted positions + +This module implements the cross-model integration. The same-checkpoint +`DLMRestoredVerifier` (`inference_engine.v04.dlm_restored_verifier. +DLMRestoredVerifier`) covers the K1 / K2.A path; this module covers K3. + +Differences from same-checkpoint DLMRestoredVerifier +----------------------------------------------------- + +1. **Drafter ≠ verifier**: drafter is a separate model object + (a `DFlashDrafter` or any `nn.Module` whose attention layers + expose `k_proj` / `v_proj`). Same-checkpoint version assumes + drafter is the verifier itself. + +2. **`f_θ` projection mediates**: drafter K/V dim ≠ verifier K/V dim + in cross-model setup. `f_θ` projects drafter K/V into verifier + K/V space at every (layer, position) before injection. + +3. **Layer-count mismatch handled**: drafter typically has fewer + layers than verifier (DFlash 5 vs Gemma 4 26B-A4B 30). `f_θ` + handles the projection from `drafter_num_layers`-concat input + to `verifier_num_layers` outputs. + +4. **K/V are pre-norm pre-RoPE on capture**: same as same-checkpoint + path. The verifier's attention forward is patched to call + `prepare_restored_attention_kv` which applies `k_norm` + RoPE + to the projected K/V at evicted positions — matching the + standard verifier's own K/V transformation pipeline. 
def resolve_text_config(config: Any) -> Any:
    """Return the text-decoder sub-config for multimodal HF configs.

    Gemma 4 (``Gemma4Config``) is a multimodal composite whose decoder
    dimensions live under ``config.text_config`` (with sibling
    ``vision_config`` / ``audio_config``). Flat text-only configs
    (e.g. Gemma 3) expose the attributes directly.

    Args:
        config: any config object.

    Returns:
        ``config.text_config`` when that attribute exists and is truthy,
        otherwise ``config`` itself, so callers can read
        ``num_hidden_layers`` / ``num_key_value_heads`` / ``head_dim``
        uniformly.
    """
    return getattr(config, "text_config", None) or config


def get_verifier_decoder(model: Any) -> Any:
    """Locate the decoder module exposing ``.layers``.

    Probing order (first match wins), starting from ``model.model`` when
    present and from ``model`` itself otherwise:

    1. ``<base>.language_model`` — multimodal layout (Gemma 4 conditional
       generation: text decoder nested beside vision/audio towers).
    2. ``<base>`` itself — flat layout (Gemma 3, Llama, ...).
    3. ``<base>.text_model`` / ``<base>.decoder`` — other nestings.

    Args:
        model: the verifier model (or an already-unwrapped base module).

    Returns:
        The first probed object that has a ``layers`` attribute.

    Raises:
        AttributeError: if none of the probed objects has ``layers``.
    """
    base = getattr(model, "model", model)

    nested_lm = getattr(base, "language_model", None)
    if nested_lm is not None and hasattr(nested_lm, "layers"):
        return nested_lm
    if hasattr(base, "layers"):
        return base
    # ``language_model`` was already probed above, so only the remaining
    # container names need checking here.
    for attr in ("text_model", "decoder"):
        candidate = getattr(base, attr, None)
        if candidate is not None and hasattr(candidate, "layers"):
            return candidate
    raise AttributeError(
        "could not locate decoder layers on verifier model "
        f"(type={type(model).__name__})"
    )


def full_attention_layer_indices(model: Any) -> List[int]:
    """Return indices of the verifier's full-attention (global) layers.

    Gemma 4 interleaves sliding-attention layers (head_dim 256) with a few
    full-attention layers (head_dim = global_head_dim 512). The
    full-attention layers are detected as the layers whose
    ``self_attn.head_dim`` is the maximum across layers.

    Args:
        model: verifier model; its decoder is located with
            :func:`get_verifier_decoder`.

    Returns:
        Sorted layer indices whose head_dim equals the maximum. An empty
        list when every layer shares one head_dim (no interleaving) or when
        the decoder has no layers at all.
    """
    layers = get_verifier_decoder(model).layers
    head_dims = [int(layer.self_attn.head_dim) for layer in layers]
    # Zero or one distinct head_dim means there is nothing to single out.
    # Checking this first also keeps ``max()`` off an empty sequence.
    if len(set(head_dims)) <= 1:
        return []
    widest = max(head_dims)
    return [idx for idx, hd in enumerate(head_dims) if hd == widest]


@dataclasses.dataclass
class CrossModelLayerMapping:
    """How drafter K/V layers project to verifier K/V layers under f_θ.

    f_θ takes ALL drafter layers' K/V (concat) per position and
    outputs ALL verifier layers' K/V per position. So the layer
    mapping is fixed by the f_θ architecture; this dataclass is
    informational only — it records which drafter / verifier layer
    counts the f_θ was trained against, so we can validate at
    construction time.
    """
    # Number of drafter layers whose K/V feed f_θ.
    drafter_num_layers: int
    # Number of verifier layers f_θ produces K/V for.
    verifier_num_layers: int
class CrossModelDLMRestoredVerifier:
    """K3 cross-model verifier wrapper with f_θ-mediated K/V Restoration.

    Construction
    ------------

    >>> verifier = CrossModelDLMRestoredVerifier(
    ...     verifier_model=hf_gemma_4,           # transformers Gemma4ForCausalLM
    ...     drafter=dflash_drafter,              # DFlashDrafter from PR #93
    ...     f_theta=FThetaProjection(...),       # trained f_θ
    ...     sink_size=4,
    ...     window_size=64,
    ... )

    Forward
    -------

    >>> output = verifier.forward(
    ...     input_ids=...,
    ...     apply_rotary_pos_emb=...,     # transformers Gemma 4 RoPE helper
    ...     eager_attention_forward=...,  # transformers Gemma 4 eager attn
    ... )

    Each forward:
      1. Drafter runs full forward over input_ids → KVCapture (per
         drafter layer, per position, pre-norm pre-RoPE).
      2. f_θ projects drafter K/V to verifier K/V at every (verifier
         layer, position).
      3. Verifier attention modules patched: at every layer, at every
         evicted position, the attention takes injected K/V from f_θ
         output (via prepare_restored_attention_kv to apply k_norm +
         RoPE) instead of computing K/V from the verifier's local
         hidden state.
      4. Verifier sink+window cache holds only resident K/V; evicted
         K/V come from f_θ each forward (transient, no memory cost).
    """

    def __init__(
        self,
        *,
        verifier_model: nn.Module,
        drafter: Any,  # DFlashDrafter or any nn.Module with .model
        f_theta: "FThetaProjection",
        sink_size: int = 4,
        window_size: int = 64,
        exact_layer_indices: Optional[Sequence[int]] = None,
    ) -> None:
        """Store the three models and validate them against the f_θ config.

        Args:
            verifier_model: the verifier whose attention is patched.
            drafter: source of the K/V that f_θ projects.
            f_theta: trained projection (must expose ``.config`` and
                ``forward_kv_pack``).
            sink_size: number of leading positions kept resident.
            window_size: number of trailing positions kept resident.
            exact_layer_indices: verifier layers whose evicted K/V are
                taken from the verifier itself rather than from f_θ.

        Raises:
            ValueError: on a negative ``sink_size`` / ``window_size`` or
                when model dimensions disagree with the f_θ config.
        """
        if sink_size < 0 or window_size < 0:
            raise ValueError("sink_size and window_size must be non-negative")
        self.verifier_model = verifier_model
        self.drafter = drafter
        self.f_theta = f_theta
        self.sink_size = sink_size
        self.window_size = window_size
        # S5: layers whose evicted-position K/V are kept EXACT (the verifier's
        # own true K/V) instead of f_θ-restored. Used for the full-attention
        # layers — the recall-critical ones that f_θ cannot reconstruct.
        # In a bounded-memory engine these layers' K/V would be stored
        # (optionally KakeyaLattice-compressed); here they are recomputed by
        # ``capture_verifier_own_kv`` on every forward.
        self.exact_layer_indices = set(exact_layer_indices or ())
        self._validate_dimensions()

    # -----------------------------------------------------------------
    # Dimension validation at construction time
    # -----------------------------------------------------------------

    def _validate_dimensions(self) -> None:
        """Check verifier and drafter shapes against ``self.f_theta.config``.

        Any dimension that cannot be read from a model config (attribute
        missing / ``None``) is skipped rather than treated as a mismatch.

        Raises:
            ValueError: when a readable layer count, KV-head count or head
                dim differs from what f_θ was trained for.
        """
        cfg = self.f_theta.config
        # Verifier dimensions (resolve multimodal text sub-config, e.g.
        # Gemma 4's config.text_config)
        v_cfg = resolve_text_config(self.verifier_model.config)
        v_layers = getattr(v_cfg, "num_hidden_layers", None)
        v_kv_heads = getattr(v_cfg, "num_key_value_heads", None)
        v_head_dim = getattr(v_cfg, "head_dim", None)
        if v_head_dim is None:
            # Fall back to hidden_size // num_attention_heads when the
            # config does not carry an explicit head_dim.
            hidden = getattr(v_cfg, "hidden_size", None)
            num_q_heads = getattr(v_cfg, "num_attention_heads", None)
            if hidden is not None and num_q_heads:
                v_head_dim = hidden // num_q_heads

        if v_layers is not None and v_layers != cfg.verifier_num_layers:
            raise ValueError(
                f"f_θ trained for verifier_num_layers={cfg.verifier_num_layers} "
                f"but verifier has {v_layers} layers"
            )
        if v_kv_heads is not None and v_kv_heads != cfg.verifier_num_kv_heads:
            raise ValueError(
                f"f_θ trained for verifier_num_kv_heads={cfg.verifier_num_kv_heads} "
                f"but verifier has {v_kv_heads}"
            )
        if v_head_dim is not None and v_head_dim != cfg.verifier_head_dim:
            raise ValueError(
                f"f_θ trained for verifier_head_dim={cfg.verifier_head_dim} "
                f"but verifier has {v_head_dim}"
            )

        # Drafter dimensions
        drafter_cfg = getattr(self.drafter, "cfg", None) or getattr(self.drafter, "config", None)
        if drafter_cfg is None:
            return  # cannot validate; trust the caller
        d_layers = getattr(drafter_cfg, "num_hidden_layers", None)
        d_kv_heads = getattr(drafter_cfg, "num_key_value_heads", None)
        d_head_dim = getattr(drafter_cfg, "head_dim", None)

        if d_layers is not None and d_layers != cfg.drafter_num_layers:
            raise ValueError(
                f"f_θ trained for drafter_num_layers={cfg.drafter_num_layers} "
                f"but drafter has {d_layers}"
            )
        if d_kv_heads is not None and d_kv_heads != cfg.drafter_num_kv_heads:
            raise ValueError(
                f"f_θ trained for drafter_num_kv_heads={cfg.drafter_num_kv_heads} "
                f"but drafter has {d_kv_heads}"
            )
        if d_head_dim is not None and d_head_dim != cfg.drafter_head_dim:
            raise ValueError(
                f"f_θ trained for drafter_head_dim={cfg.drafter_head_dim} "
                f"but drafter has {d_head_dim}"
            )

    # -----------------------------------------------------------------
    # Drafter capture + f_θ projection
    # -----------------------------------------------------------------

    @torch.no_grad()
    def project_drafter_kv(
        self, input_ids: torch.Tensor,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Run the drafter forward over input_ids, project K/V through f_θ.

        Returns
        -------
        (verifier_k, verifier_v): per-layer lists of length
            ``verifier_num_layers``, element ``i`` shaped
            ``[B, T, layer_kv_heads[i], verifier_head_dim]`` on the f_θ
            device. Per-layer KV-head counts can differ (Gemma 4).

        These are the per-position-per-verifier-layer K/V that the
        cross-model verifier injects at evicted positions during its
        attention forward.
        """
        capture = _capture_drafter_kv(
            verifier_model=self.verifier_model,
            drafter=self.drafter,
            input_ids=input_ids,
        )
        # capture.keys[i] shape: [B, T, num_d_kv_heads, head_dim]
        verifier_k, verifier_v = self.f_theta.forward_kv_pack(
            capture.keys, capture.values,
        )
        return verifier_k, verifier_v

    # -----------------------------------------------------------------
    # Forward (with K/V Restoration)
    # -----------------------------------------------------------------

    @staticmethod
    def _restore_attention_forwards(layers: Any, originals: Sequence[Callable]) -> None:
        """Put saved ``forward`` callables back on the layers that were patched.

        ``originals`` holds one entry per layer reached by the patch loop,
        in layer order. Zipping (instead of indexing by every layer) means
        a patch loop that failed partway restores exactly the layers it
        touched and never raises ``IndexError`` from a ``finally`` block.
        """
        for layer, original in zip(layers, originals):
            layer.self_attn.forward = original

    def forward(
        self,
        input_ids: torch.Tensor,
        *,
        apply_rotary_pos_emb: Callable,
        eager_attention_forward: Callable,
        all_attention_functions: Optional[Any] = None,
        capture_kv: Optional[list] = None,
    ):
        """Run a verifier forward with f_θ-mediated K/V Restoration.

        Steps:
          1. Compute evicted positions from sink+window per ADR §11.7.
          2. Drafter forward + f_θ projection → verifier K/V at every
             evicted position at every verifier layer.
          3. Patch verifier attention: at evicted positions, K/V come
             from the f_θ output (via prepare_restored_attention_kv);
             at resident positions, K/V come from the verifier's own
             k_proj / v_proj on its hidden state.
          4. Run verifier forward; collect logits.
          5. Restore original attention forwards (also on failure).

        Args:
            input_ids: token ids; dimension 1 is the sequence length.
            apply_rotary_pos_emb: RoPE helper used by the patched attention.
            eager_attention_forward: attention kernel used when the
                configured implementation is ``"eager"`` or no registry
                is supplied.
            all_attention_functions: optional registry indexed by the
                attention implementation name.
            capture_kv: optional per-layer list; when given, entry
                ``layer_idx`` is set to the post-injection ``(K, V)`` pair.
                It is NOT filled on the no-eviction fast path.

        Returns the verifier's output (typically with .logits).
        """
        T = int(input_ids.size(1))
        evicted_positions = compute_evicted_positions(
            T, self.sink_size, self.window_size,
        )

        # If nothing is evicted (T <= sink+window), no K/V Restoration
        # needed — run the verifier directly. This is the trivial case
        # for short prompts, e.g. T=8 with sink=4 + window=64.
        if not evicted_positions:
            return self.verifier_model(input_ids=input_ids, use_cache=False)

        # f_θ projection → per-layer lists (layers can have different
        # KV-head counts on Gemma 4).
        verifier_k_layers, verifier_v_layers = self.project_drafter_kv(input_ids)

        # S5: for exact layers, replace the f_θ-restored K/V with the
        # verifier's OWN true K/V (the recall-critical full-attention
        # layers f_θ cannot reconstruct). All layers stay bounded
        # (sink+window local cache + restored evicted K/V); only the
        # SOURCE of the evicted K/V differs (true vs f_θ). In a real
        # bounded engine those layers' K/V would be stored
        # (optionally KakeyaLattice-compressed) rather than recomputed.
        if self.exact_layer_indices:
            true_k_layers, true_v_layers = capture_verifier_own_kv(
                self.verifier_model, input_ids,
            )
            for li in self.exact_layer_indices:
                verifier_k_layers[li] = true_k_layers[li]
                verifier_v_layers[li] = true_v_layers[li]

        # Patch verifier attention forwards to inject K/V at evicted
        # positions. Restore originals after the forward.
        layers = get_verifier_decoder(self.verifier_model).layers
        originals: List[Callable] = []
        try:
            for layer_idx, layer in enumerate(layers):
                attn = layer.self_attn
                # Save before patching so the restore step can undo this
                # layer even if building the patched forward raises.
                originals.append(attn.forward)
                attn.forward = self._make_patched_forward(
                    attn,
                    layer_idx=layer_idx,
                    evicted_positions=evicted_positions,
                    verifier_k_at_layer=verifier_k_layers[layer_idx],
                    verifier_v_at_layer=verifier_v_layers[layer_idx],
                    apply_rotary_pos_emb=apply_rotary_pos_emb,
                    eager_attention_forward=eager_attention_forward,
                    all_attention_functions=all_attention_functions,
                    capture_kv=capture_kv,
                )
            return self.verifier_model(input_ids=input_ids, use_cache=False)
        finally:
            self._restore_attention_forwards(layers, originals)

    def _make_patched_forward(
        self, attn_module: nn.Module, *,
        layer_idx: int,
        evicted_positions: List[int],
        verifier_k_at_layer: torch.Tensor,
        verifier_v_at_layer: torch.Tensor,
        apply_rotary_pos_emb: Callable,
        eager_attention_forward: Callable,
        all_attention_functions: Optional[Any] = None,
        capture_kv: Optional[list] = None,
    ) -> Callable:
        """Build a patched attention forward that injects K/V at evicted
        positions from `verifier_k_at_layer` / `verifier_v_at_layer`
        instead of using the verifier's own k_proj / v_proj at those
        positions.

        Mirrors Gemma 4's ``Gemma4TextAttention.forward`` (RoPE
        applied per-tensor on the ``[B, T, H, D]`` layout with
        ``unsqueeze_dim=2``; ``q_norm`` / ``k_norm`` / ``v_norm``;
        ``v_proj`` is ``None`` on full-attention layers where V = raw K)
        with one change: at evicted positions K/V come from the
        supplied per-layer tensors (after k_norm + RoPE for K, after
        v_norm for V), so the verifier attends over full context despite
        holding only sink+window in its local cache.

        Returns:
            A callable with the attention-forward signature that returns
            ``(attn_output, attn_weights)``.
        """
        def _patched_forward(
            hidden_states: torch.Tensor,
            position_embeddings: Tuple[torch.Tensor, torch.Tensor],
            attention_mask: Optional[torch.Tensor] = None,
            shared_kv_states: Any = None,
            past_key_values=None,
            cache_position=None,
            **kwargs,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
            input_shape = hidden_states.shape[:-1]  # [B, T]
            head_dim = attn_module.head_dim
            hidden_shape = (*input_shape, -1, head_dim)
            cos, sin = position_embeddings

            # Query (Gemma 4: norm → RoPE on [B,T,Hq,D] → transpose)
            query_states = attn_module.q_proj(hidden_states).view(hidden_shape)
            query_states = attn_module.q_norm(query_states)
            query_states = apply_rotary_pos_emb(
                query_states, cos, sin, unsqueeze_dim=2,
            )
            query_states = query_states.transpose(1, 2)  # [B, Hq, T, D]

            # Key / Value. Full-attention layers have v_proj=None ⇒ the
            # value is the raw k_proj output (pre k_norm), per Gemma 4.
            k_lin = attn_module.k_proj(hidden_states).view(hidden_shape)
            if getattr(attn_module, "v_proj", None) is not None:
                v_lin = attn_module.v_proj(hidden_states).view(hidden_shape)
            else:
                v_lin = k_lin
            key_states = attn_module.k_norm(k_lin)
            key_states = apply_rotary_pos_emb(
                key_states, cos, sin, unsqueeze_dim=2,
            )
            key_states = key_states.transpose(1, 2)  # [B, Hkv, T, D]
            value_states = attn_module.v_norm(v_lin).transpose(1, 2)  # [B, Hkv, T, D]

            # Inject f_θ K/V at evicted positions.
            # verifier_k_at_layer shape: [B, T, num_kv_heads_v, head_dim_v]
            # (pre-norm pre-RoPE). prepare_restored_attention_kv applies
            # k_norm + RoPE to K; V must be v_norm'd here to match the
            # local branch (Gemma 4 runs V through v_norm).
            if evicted_positions:
                # Build the gather index on the SOURCE tensor's device: the
                # captured K/V live on the f_θ device, which need not be
                # this layer's device when the verifier is sharded. The
                # gathered slices are moved to the layer device afterwards.
                k_idx = torch.tensor(
                    evicted_positions,
                    device=verifier_k_at_layer.device,
                    dtype=torch.long,
                )
                v_idx = k_idx.to(verifier_v_at_layer.device)
                cap_k_pre = verifier_k_at_layer.index_select(1, k_idx).to(
                    device=key_states.device, dtype=key_states.dtype,
                )
                cap_v_pre = verifier_v_at_layer.index_select(1, v_idx).to(
                    device=value_states.device, dtype=value_states.dtype,
                )
                cap_v_norm = attn_module.v_norm(cap_v_pre)
                key_states, value_states = prepare_restored_attention_kv(
                    K_local=key_states,
                    V_local=value_states,
                    captured_K_pre_norm=cap_k_pre,
                    captured_V=cap_v_norm,
                    evicted_positions=evicted_positions,
                    k_norm=attn_module.k_norm,
                    position_embeddings=(cos, sin),
                )

            # Capture the post-norm/RoPE/injection K/V so an incremental
            # decoder can reuse them instead of re-running this O(T)
            # restored forward each step.
            if capture_kv is not None:
                capture_kv[layer_idx] = (
                    key_states.detach(), value_states.detach(),
                )

            # Standard attention path
            attn_impl = getattr(
                attn_module.config, "_attn_implementation", "eager",
            )
            if attn_impl == "eager" or all_attention_functions is None:
                attention_interface = eager_attention_forward
            else:
                attention_interface = all_attention_functions[attn_impl]

            attn_output, attn_weights = attention_interface(
                attn_module,
                query_states,
                key_states,
                value_states,
                attention_mask,
                dropout=getattr(attn_module, "attention_dropout", 0.0),
                scaling=getattr(attn_module, "scaling", attn_module.head_dim ** -0.5),
                sliding_window=getattr(attn_module, "sliding_window", None),
                **kwargs,
            )

            attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            attn_output = attn_module.o_proj(attn_output)
            return attn_output, attn_weights

        return _patched_forward
# ---------------------------------------------------------------------------
# Verifier own-K/V capture and drafter K/V capture (forward-hook based)
# ---------------------------------------------------------------------------


@torch.no_grad()
def capture_verifier_own_kv(
    verifier_model: Any, input_ids: torch.Tensor,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Capture the verifier's OWN pre-norm per-layer K/V via k_proj /
    v_proj forward hooks (identity-restoration diagnostic).

    Runs one full ``verifier_model(input_ids=..., use_cache=False)``
    forward with temporary hooks attached; hooks are always removed,
    even when the forward raises.

    Returns ``(k_layers, v_layers)``: per-layer lists where element
    ``i`` is ``[B, T, kv_heads_i, head_dim_i]`` (heterogeneous per
    layer, matching the f_θ output layout). Layers whose ``v_proj`` is
    ``None`` (Gemma 4 full-attention K==V) take V from the k_proj
    output, mirroring the model's own behaviour.

    Injecting these at evicted positions reproduces exactly what the
    verifier would have computed under full attention — so cross-model
    recall under identity restoration should match the oracle. This
    isolates "is the K/V Restoration machinery correct?" (this helper)
    from "is f_θ accurate enough?" (the trained projection).

    Raises:
        RuntimeError: if any layer's K or V was not captured.
    """
    layers = get_verifier_decoder(verifier_model).layers
    n = len(layers)
    k_cap: List[Optional[torch.Tensor]] = [None] * n
    v_cap: List[Optional[torch.Tensor]] = [None] * n
    # Indices of layers without a v_proj; their V is copied from K below.
    v_shared: List[int] = []
    handles = []
    for i, layer in enumerate(layers):
        attn = layer.self_attn

        # ``idx=i`` binds the layer index at definition time so each hook
        # writes to its own slot (avoids the late-binding closure pitfall).
        def _kh(_m, _inp, out, idx=i):
            k_cap[idx] = out.detach()

        def _vh(_m, _inp, out, idx=i):
            v_cap[idx] = out.detach()

        handles.append(attn.k_proj.register_forward_hook(_kh))
        if getattr(attn, "v_proj", None) is not None:
            handles.append(attn.v_proj.register_forward_hook(_vh))
        else:
            v_shared.append(i)
    try:
        verifier_model(input_ids=input_ids, use_cache=False)
    finally:
        # Never leave hooks on the model, whatever the forward did.
        for h in handles:
            h.remove()
    for i in v_shared:
        v_cap[i] = k_cap[i]
    if any(k is None for k in k_cap) or any(v is None for v in v_cap):
        raise RuntimeError("verifier own-K/V capture missing some layers")
    k_layers: List[torch.Tensor] = []
    v_layers: List[torch.Tensor] = []
    for i, layer in enumerate(layers):
        hd = layer.self_attn.head_dim
        b, t, kvdim = k_cap[i].shape
        k_layers.append(k_cap[i].view(b, t, kvdim // hd, hd))
        # NOTE(review): V is reshaped with K's (b, t, kvdim); this assumes
        # v_proj and k_proj outputs have identical shapes — confirm.
        v_layers.append(v_cap[i].view(b, t, kvdim // hd, hd))
    return k_layers, v_layers


def _capture_drafter_kv(
    *, verifier_model: Any, drafter: Any, input_ids: torch.Tensor,
) -> KVCapture:
    """Capture pre-norm pre-RoPE K/V from the DFlash drafter at every
    drafter layer at every position.

    DFlashDrafter (PR #93) has a non-standard structure: it doesn't
    follow the embed → layers → norm → lm_head pattern. It's a flat
    ``nn.Module`` with ``.layers`` directly + an architectural choice
    that **embed_tokens are shared with the verifier** (DFlash design:
    no own embeddings or lm_head).

    Capture strategy:

      1. ``verifier_model.get_input_embeddings()(input_ids) * scale``
         → real embedded hiddens (Gemma scaling: ``× sqrt(hidden_size)``,
         with ``hidden_size`` read from ``drafter.cfg``).
      2. Pass these embedded hiddens through ``drafter.layers`` with
         ``ctx_k = ctx_v = None`` per layer (no aux conditioning).
      3. Forward hooks on each layer's ``self_attn.k_proj`` /
         ``self_attn.v_proj`` capture pre-norm pre-RoPE K, V values
         per layer per position.

    This produces K/V values from drafter layers operating on REAL
    embedded hiddens (not synthetic zero) but WITHOUT aux conditioning
    on verifier mid-layer hiddens. For f_θ first-iteration training,
    this is the correct level: f_θ learns to project drafter K/V
    (computed without aux) into verifier K/V space. Adding aux
    conditioning is a follow-up that can be plumbed into both training
    and inference paths once first-iteration f_θ validates the
    architecture.

    Required because :func:`capture_proposer_kv` doesn't support the
    DFlashDrafter shape — it looks for ``model.model.layers`` or
    ``model.transformer.h`` and DFlashDrafter has neither.

    Args:
        verifier_model: supplies the shared input-embedding module.
        drafter: module exposing ``.layers``, ``.cfg`` and parameters.
        input_ids: token ids to embed and run through the drafter layers.

    Returns:
        A ``KVCapture`` whose ``keys`` / ``values`` are per-drafter-layer
        tensors viewed as ``[B, T, num_key_value_heads, head_dim]``.

    Raises:
        RuntimeError: if a layer's K or V hook never fired, or a k_proj
            output width differs from ``num_key_value_heads * head_dim``.
    """
    # Capture K, V via forward hooks on each drafter layer's k_proj / v_proj.
    layers = list(drafter.layers)
    num_layers = len(layers)
    k_capture: List[Optional[torch.Tensor]] = [None] * num_layers
    v_capture: List[Optional[torch.Tensor]] = [None] * num_layers
    handles = []

    for i, layer in enumerate(layers):
        attn = layer.self_attn

        # Factory functions bind ``idx`` per layer so every hook writes to
        # its own slot.
        def _make_k_hook(idx):
            def hook(_mod, _inp, output):
                k_capture[idx] = output.detach()
            return hook

        def _make_v_hook(idx):
            def hook(_mod, _inp, output):
                v_capture[idx] = output.detach()
            return hook

        handles.append(attn.k_proj.register_forward_hook(_make_k_hook(i)))
        handles.append(attn.v_proj.register_forward_hook(_make_v_hook(i)))

    try:
        # Embed input_ids using the verifier's embed_tokens (DFlash
        # design: shares verifier embeddings, no own lookup table).
        # Apply Gemma's × sqrt(hidden) scaling per the alignment
        # training pipeline convention.
        # NOTE(review): if the verifier's embedding module already scales
        # its output internally, this multiplies twice — confirm it matches
        # what f_θ was trained on.
        cfg = drafter.cfg
        verifier_embed = verifier_model.get_input_embeddings()
        embed_scale = math.sqrt(cfg.hidden_size)

        # The first drafter parameter decides the dtype/device the drafter
        # layers are run in.
        drafter_param = next(drafter.parameters())
        drafter_dtype = drafter_param.dtype
        drafter_device = drafter_param.device

        with torch.no_grad():
            input_ids_for_embed = input_ids.to(verifier_embed.weight.device)
            embedded = verifier_embed(input_ids_for_embed) * embed_scale
            embedded = embedded.to(device=drafter_device, dtype=drafter_dtype)
            T = embedded.size(1)
            query_positions = torch.arange(T, device=drafter_device)
            # Run each drafter layer with NO aux conditioning (ctx_k =
            # ctx_v = None). The k_proj / v_proj hooks fire on the
            # query hidden states' projection, capturing pre-norm
            # pre-RoPE K/V at every layer at every position.
            h = embedded
            for layer in layers:
                h = layer(h, query_positions, ctx_k=None, ctx_v=None)
    finally:
        # ``h`` is reused as the loop variable here; the hidden state it
        # held above is not needed again (only the hook captures are).
        for h in handles:
            h.remove()

    if any(k is None for k in k_capture):
        raise RuntimeError("drafter K capture missing some layers")
    if any(v is None for v in v_capture):
        raise RuntimeError("drafter V capture missing some layers")

    keys = []
    values = []
    for k_raw, v_raw in zip(k_capture, v_capture):
        # k_raw shape: [B, T, num_d_kv_heads * head_dim] (k_proj output)
        b, t, last = k_raw.shape
        if last != cfg.num_key_value_heads * cfg.head_dim:
            raise RuntimeError(
                f"drafter k_proj output last-dim {last} != "
                f"num_kv_heads * head_dim "
                f"({cfg.num_key_value_heads * cfg.head_dim})"
            )
        keys.append(k_raw.view(b, t, cfg.num_key_value_heads, cfg.head_dim))
        values.append(v_raw.view(b, t, cfg.num_key_value_heads, cfg.head_dim))

    return KVCapture(
        keys=keys,
        values=values,
        num_layers=len(keys),
        seq_len=keys[0].shape[1] if keys else 0,
        num_kv_heads=cfg.num_key_value_heads,
        head_dim=cfg.head_dim,
    )
+++ b/inference_engine/v04/dflash_drafter.py @@ -202,6 +202,39 @@ def _apply_rope( return (x * cos) + (_rotate_half(x) * sin) +# Query-chunk size for the drafter's non-causal attention. Bounds peak +# attention memory to O(q_chunk × (C+T)); tune down on tight-memory hosts +# (e.g. 24 GB Mac at long context). 0 ⇒ no chunking (single SDPA call). +# Override at runtime with KAKEYA_DFLASH_ATTN_QCHUNK (e.g. 256 on a 24 GB Mac). +import os as _os +_ATTN_Q_CHUNK = int(_os.environ.get("KAKEYA_DFLASH_ATTN_QCHUNK", "1024") or "1024") + + +def _chunked_sdpa( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + *, scale: float, q_chunk: Optional[int] = None, +) -> torch.Tensor: + """Non-causal SDPA computed in query-dimension chunks. + + ``q`` is ``[B, nh, T, hd]``; ``k``/``v`` ``[B, nh, C+T, hd]``. Returns + ``[B, nh, T, hd]``. Chunking the query dim keeps the (possibly + materialised) score tensor at ``[B, nh, q_chunk, C+T]`` so long-context + attention does not OOM on hosts/kernels without a flash path (MPS). + """ + T = q.shape[-2] + if not q_chunk or q_chunk <= 0 or T <= q_chunk: + return F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=False, scale=scale, + ) + outs = [] + for start in range(0, T, q_chunk): + qc = q[:, :, start:start + q_chunk, :] + outs.append(F.scaled_dot_product_attention( + qc, k, v, attn_mask=None, is_causal=False, scale=scale, + )) + return torch.cat(outs, dim=2) + + class _DFlashAttention(nn.Module): """DFlash draft attention (faithful to vLLM ``DFlashQwen3Attention``). @@ -271,10 +304,13 @@ def forward( rep = self.nh // self.nkv k = k.repeat_interleave(rep, dim=1) v = v.repeat_interleave(rep, dim=1) - scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale # [B,nh,T,C+T] - # Non-causal: queries see all context + all query positions. - attn = torch.softmax(scores.float(), dim=-1).to(q.dtype) - out = torch.matmul(attn, v) # [B, nh, T, hd] + # Non-causal (queries see all context + all query positions), no mask. 
+ # Use SDPA, and **chunk over the query dimension** so peak attention + # memory stays O(chunk × (C+T)) instead of O(T × (C+T)). The full + # materialisation OOMs at long context (≈5 GB at T≈6k, nh=32) — and + # MPS's SDPA has no flash kernel for this shape, so it materialises + # too; query-chunking bounds it on every device/kernel. + out = _chunked_sdpa(q, k, v, scale=self.scale, q_chunk=_ATTN_Q_CHUNK) out = out.transpose(1, 2).contiguous().view(B, T, self.nh * self.hd) return self.o_proj(out) @@ -468,6 +504,67 @@ def draft_block( logits[..., self.cfg.mask_token_id] = float("-inf") # never draft the sentinel return torch.argmax(logits[0], dim=-1).tolist() + # -- fused-engine fast path: draft from a PRECOMPUTED context K/V cache -- + def make_context_kv( + self, aux_hidden_context: Sequence[torch.Tensor], positions: torch.Tensor, + ): + """Per-layer context K/V for ``positions`` from their aux hidden. + + ``combine_aux`` (fc) + ``precompute_context_kv`` (hidden_norm + per-layer + k/v_proj + k_norm + RoPE). Returns a per-layer list of ``(k, v)``, each + ``[B, nkv, len(positions), hd]``. Use once at prefill, then + :meth:`extend_context_kv` incrementally for newly-committed tokens — so + the drafter never re-scans the whole committed prefix (O(L)/block, not + O(C)). This is component B of the fused spec-decode engine. 
+ """ + ctx_states = self.combine_aux(aux_hidden_context) + return self.precompute_context_kv(ctx_states, positions) + + @staticmethod + def extend_context_kv(ctx_kv, new_kv): + """Append per-layer ``new_kv`` (from :meth:`make_context_kv`) to the + running ``ctx_kv`` cache along the sequence axis.""" + out = [] + for (ck, cv), (nk, nv) in zip(ctx_kv, new_kv): + out.append(( + torch.cat([ck, nk.to(ck.dtype)], dim=2), + torch.cat([cv, nv.to(cv.dtype)], dim=2), + )) + return out + + @torch.no_grad() + def draft_block_cached( + self, + ctx_kv, + bonus_token_id: int, + embed_fn: Callable[[torch.Tensor], torch.Tensor], + lm_head_fn: Callable[[torch.Tensor], torch.Tensor], + *, + block_size: int, + context_len: int, + ) -> List[int]: + """Draft ``block_size`` tokens using a PRECOMPUTED per-layer context + K/V cache (``ctx_kv`` covering positions ``0..context_len-1``). + + Same single non-causal pass as :meth:`draft_block`, but skips the + O(C) context K/V recompute — the caller maintains ``ctx_kv`` + incrementally. Cost is O(block_size) on the drafter. + """ + cfg = self.cfg + device = ctx_kv[0][0].device + query_ids = torch.tensor( + [[int(bonus_token_id)] + [cfg.mask_token_id] * block_size], + dtype=torch.long, device=device, + ) + query_positions = torch.arange( + context_len, context_len + 1 + block_size, device=device, + ) + h = embed_fn(query_ids).to(self.fc.weight.dtype) + h = self._run_layers(h, query_positions, ctx_kv) + logits = lm_head_fn(h).clone() + logits[..., cfg.mask_token_id] = float("-inf") + return torch.argmax(logits[0, 1:1 + block_size], dim=-1).tolist() + def draft_logits( self, aux_hidden_context: Sequence[torch.Tensor], diff --git a/inference_engine/v04/f_theta.py b/inference_engine/v04/f_theta.py new file mode 100644 index 00000000..e37a8769 --- /dev/null +++ b/inference_engine/v04/f_theta.py @@ -0,0 +1,393 @@ +"""K3 Block C — `f_θ` K/V projection: drafter K/V → verifier K/V space. 
+ +Per ADR 0008 §11.5 (v0.4 GA dLM K/V Restoration architecture), the +verifier maintains only a sink+window local KV cache and accepts +**reconstructed** K/V at every evicted position from the proposer's +transient K/V. In the K3 cross-model setup (drafter = DFlash 0.4B, +verifier = Gemma 4 26B-A4B), the drafter's K/V live in a different +space than the verifier's: + + drafter K, V shape (per layer, per position): + [num_kv_heads_drafter * head_dim_drafter] (e.g. 2 * 128 = 256) + + verifier K, V shape (per layer, per position): + [num_kv_heads_verifier * head_dim_verifier] (e.g. 8 * 256 = 2048) + +`f_θ` is the trainable projection that bridges these spaces. Its +contract: for every position p, take the drafter's K/V at p across +ALL drafter layers (concatenated along the feature dim) and produce +the verifier's K/V at p at EVERY verifier layer. + +Architecture (chosen 2026-06-09 for K3 first-iteration training) +---------------------------------------------------------------- + +Shared encoder + per-verifier-layer decoder, low-rank factorisation: + + drafter_kv_input [B, T, drafter_layers * drafter_kv_dim] + ↓ + shared encoder Linear(drafter_layers*drafter_kv_dim, rank) + ↓ + rep [B, T, rank] + ↓ + per-verifier-layer decoder K: Linear(rank, verifier_kv_dim) × num_verifier_layers + per-verifier-layer decoder V: Linear(rank, verifier_kv_dim) × num_verifier_layers + ↓ + output [B, T, num_verifier_layers, num_kv_heads_v, head_dim_v] + for K, and same shape for V + +Total params (default rank=256): + encoder: drafter_layers * drafter_kv_dim × rank = 5 * 256 × 256 ≈ 327k + decoders: 2 (K+V) × num_verifier_layers × rank × verifier_kv_dim + = 2 × 30 × 256 × 2048 ≈ 31.5M + Total: ~31.8M params (vs drafter 430M, verifier 26B → small) + +Why this architecture +--------------------- + +1. **Per-verifier-layer decoders**: each verifier layer has its own + K/V distribution; one shared output projection is too lossy. 30 + separate decoders give per-layer capacity. + +2. 
**Shared encoder**: forces the drafter K/V representation to + capture position-level features that generalise across verifier + layers. Reduces parameter count vs full per-(drafter,verifier)-pair + matrices (which would be 30 × 5 × 2 × full_dim²). + +3. **Low-rank**: rank=256 is a tunable hyperparameter. Smaller rank = fewer params + + faster training but less capacity; larger rank approaches the + shared encoder being identity. 256 was chosen as the smallest + rank that keeps verifier_kv_dim/rank ratio reasonable (2048/256=8) + without crushing capacity at the encoder bottleneck. + +4. **Separate K and V decoders**: K and V have different roles + downstream (Q·K dot product vs attention-weighted sum of V); their + per-layer distributions differ. Separate decoders capture this. + +Training contract (per :mod:`scripts.research.k3_f_theta_train`) +---------------------------------------------------------------- + +* Inputs: paired (drafter_kv, verifier_kv) over a long-context corpus + collected by running both models on the same input sequences and + recording K/V at every layer at every position. + +* Loss: MSE between f_θ(drafter_kv) and verifier_kv, averaged over + layers and positions. Weighted equally across layers; weighting + schemes are a hyperparameter. + +* Optimiser: AdamW with lr=1e-3, weight_decay=0.01. + +Loadable checkpoint +------------------- + +The trained `f_θ` is saved as a state_dict. The +:meth:`FThetaProjection.from_pretrained` classmethod loads it from a +local directory holding the config and weights files (HF Hub ids are +not supported yet). The cross-model DLMRestoredVerifier +(:mod:`inference_engine.v04.cross_model_dlm_verifier`) consumes this +state_dict at construction time. + +This module is engine API surface (not research scaffolding), so +imports are minimal and tests cover the shape contract + load/save ++ device dispatch. 
+""" + +from __future__ import annotations + +import dataclasses +import json +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn + + +@dataclasses.dataclass(frozen=True) +class FThetaConfig: + """Configuration for :class:`FThetaProjection`. + + Stored alongside the trained state_dict as ``f_theta_config.json`` + so the cross-model verifier can reconstruct the projection at load + time without inferring shapes from the state_dict alone. + + Heterogeneous verifier KV heads + ------------------------------- + Production verifiers do not always use a uniform KV-head count + across layers. Gemma 4 26B-A4B, for example, uses 8 KV heads on its + sliding-attention layers and 2 KV heads on its full-attention + layers (whose head_dim is 512 rather than 256; see + ``verifier_layer_head_dims``). ``verifier_layer_kv_heads`` + captures the per-layer count; when ``None`` every layer uses + ``verifier_num_kv_heads`` (the legacy uniform behaviour, kept for + backward compatibility and the same-head-count case). + """ + + drafter_num_layers: int # e.g. DFlash drafter 5 layers + drafter_num_kv_heads: int # e.g. DFlash 2 kv heads + drafter_head_dim: int # e.g. DFlash 128 head dim + verifier_num_layers: int # e.g. Gemma 4 26B-A4B 30 layers + verifier_num_kv_heads: int # representative / uniform KV head count + verifier_head_dim: int # e.g. Gemma 4 256 head dim + rank: int = 256 # encoder bottleneck + # Per-layer KV head counts (len == verifier_num_layers). None ⇒ + # uniform verifier_num_kv_heads for every layer. + verifier_layer_kv_heads: Optional[Tuple[int, ...]] = None + # Per-layer head dims (len == verifier_num_layers). None ⇒ uniform + # verifier_head_dim. Gemma 4 uses 256 on sliding layers and 512 + # (global_head_dim) on its full-attention layers. 
+ verifier_layer_head_dims: Optional[Tuple[int, ...]] = None + + @property + def drafter_kv_dim(self) -> int: + return self.drafter_num_kv_heads * self.drafter_head_dim + + @property + def verifier_kv_dim(self) -> int: + return self.verifier_num_kv_heads * self.verifier_head_dim + + @property + def layer_kv_heads(self) -> Tuple[int, ...]: + """Per-layer KV head counts (always length ``verifier_num_layers``).""" + if self.verifier_layer_kv_heads is None: + return tuple( + self.verifier_num_kv_heads + for _ in range(self.verifier_num_layers) + ) + return tuple(int(h) for h in self.verifier_layer_kv_heads) + + @property + def layer_head_dims(self) -> Tuple[int, ...]: + """Per-layer head dims (always length ``verifier_num_layers``).""" + if self.verifier_layer_head_dims is None: + return tuple( + self.verifier_head_dim + for _ in range(self.verifier_num_layers) + ) + return tuple(int(d) for d in self.verifier_layer_head_dims) + + @property + def layer_kv_dims(self) -> Tuple[int, ...]: + """Per-layer K (or V) feature dim = kv_heads[i] * head_dim[i].""" + return tuple( + h * d for h, d in zip(self.layer_kv_heads, self.layer_head_dims) + ) + + @property + def encoder_in_features(self) -> int: + """Concat dim across all drafter layers' K (or V) per position.""" + return self.drafter_num_layers * self.drafter_kv_dim + + def to_json_dict(self) -> dict: + d = dataclasses.asdict(self) + if self.verifier_layer_kv_heads is not None: + d["verifier_layer_kv_heads"] = list(self.verifier_layer_kv_heads) + if self.verifier_layer_head_dims is not None: + d["verifier_layer_head_dims"] = list(self.verifier_layer_head_dims) + return d + + @classmethod + def from_json_dict(cls, d: dict) -> "FThetaConfig": + list_fields = {"verifier_layer_kv_heads", "verifier_layer_head_dims"} + kwargs: dict = {} + for k, v in d.items(): + if k in list_fields: + kwargs[k] = None if v is None else tuple(int(x) for x in v) + else: + kwargs[k] = int(v) + return cls(**kwargs) + + +class 
FThetaProjection(nn.Module): + """`f_θ`: projects drafter K/V into verifier K/V space. + + Forward contract: + + forward_k(drafter_k_concat: torch.Tensor) + Input shape: [B, T, drafter_num_layers * drafter_kv_dim] + Output: list of verifier_num_layers tensors, element i shaped + [B, T, layer_kv_heads[i], layer_head_dims[i]] + + forward_v(drafter_v_concat: torch.Tensor) + Same shapes as forward_k but separate weights (K and V have + different downstream roles → separate projections). + + Helper :meth:`forward_kv_pack` accepts the unpacked drafter + KVCapture format (list of drafter_num_layers + [B, T, num_kv_heads_d, head_dim_d] + tensors) and runs the concat + project + reshape pipeline in one + call — what the cross-model verifier uses. + """ + + def __init__(self, config: FThetaConfig) -> None: + super().__init__() + self.config = config + + # Shared encoder: drafter K/V (concat across drafter layers) → rank-d rep + self.encoder_k = nn.Linear( + config.encoder_in_features, config.rank, bias=False, + ) + self.encoder_v = nn.Linear( + config.encoder_in_features, config.rank, bias=False, + ) + + # Per-verifier-layer decoders, each sized to its own layer's KV + # feature dim (heterogeneous KV-head counts are supported). 
+ self.decoders_k = nn.ModuleList([ + nn.Linear(config.rank, kv_dim, bias=False) + for kv_dim in config.layer_kv_dims + ]) + self.decoders_v = nn.ModuleList([ + nn.Linear(config.rank, kv_dim, bias=False) + for kv_dim in config.layer_kv_dims + ]) + + # ----------------------------------------------------------------- + # Forward primitives + # ----------------------------------------------------------------- + + def _project( + self, + drafter_concat: torch.Tensor, + encoder: nn.Module, + decoders: nn.ModuleList, + ) -> List[torch.Tensor]: + if drafter_concat.dim() != 3: + raise ValueError( + f"expected [B, T, encoder_in_features]; got shape " + f"{tuple(drafter_concat.shape)}" + ) + if drafter_concat.size(-1) != self.config.encoder_in_features: + raise ValueError( + f"last dim {drafter_concat.size(-1)} != " + f"encoder_in_features {self.config.encoder_in_features}" + ) + # f_θ weights may be a different dtype than the captured drafter + # K/V (e.g. f_θ in fp32, drafter in bf16). Cast the input to the + # encoder's weight dtype so matmul dtypes agree. + drafter_concat = drafter_concat.to(encoder.weight.dtype) + rep = encoder(drafter_concat) # [B, T, rank] + kv_heads = self.config.layer_kv_heads + head_dims = self.config.layer_head_dims + outs: List[torch.Tensor] = [] + for li, dec in enumerate(decoders): + o = dec(rep) # [B, T, kv_heads[li] * head_dims[li]] + B, T, _ = o.shape + outs.append(o.view(B, T, kv_heads[li], head_dims[li])) + return outs + + def forward_k(self, drafter_k_concat: torch.Tensor) -> List[torch.Tensor]: + """Project drafter K (concat across drafter layers) to per-verifier-layer K. + + Parameters + ---------- + drafter_k_concat + [B, T, drafter_num_layers * drafter_kv_dim] + + Returns + ------- + List of ``verifier_num_layers`` tensors, each shape + ``[B, T, layer_kv_heads[i], verifier_head_dim]`` (per-layer KV + head counts can differ). 
+ """ + return self._project(drafter_k_concat, self.encoder_k, self.decoders_k) + + def forward_v(self, drafter_v_concat: torch.Tensor) -> List[torch.Tensor]: + """V counterpart of :meth:`forward_k`.""" + return self._project(drafter_v_concat, self.encoder_v, self.decoders_v) + + # ----------------------------------------------------------------- + # KVCapture-aware helper + # ----------------------------------------------------------------- + + def forward_kv_pack( + self, + drafter_k_per_layer: Sequence[torch.Tensor], + drafter_v_per_layer: Sequence[torch.Tensor], + ) -> tuple: + """Take unpacked KVCapture tensors and project to verifier K/V. + + Parameters + ---------- + drafter_k_per_layer + List of ``drafter_num_layers`` tensors, each shape + ``[B, T, drafter_num_kv_heads, drafter_head_dim]`` (the + natural KVCapture layout from + :class:`inference_engine.v04.KVCapture`). + + drafter_v_per_layer + Same as ``drafter_k_per_layer`` but for V tensors. + + Returns + ------- + (verifier_k, verifier_v) where each is a list of + ``verifier_num_layers`` tensors, element ``i`` shaped + ``[B, T, layer_kv_heads[i], layer_head_dims[i]]``. + """ + if len(drafter_k_per_layer) != self.config.drafter_num_layers: + raise ValueError( + f"expected {self.config.drafter_num_layers} drafter layers, " + f"got {len(drafter_k_per_layer)}" + ) + if len(drafter_v_per_layer) != self.config.drafter_num_layers: + raise ValueError( + f"expected {self.config.drafter_num_layers} drafter layers " + f"for V, got {len(drafter_v_per_layer)}" + ) + # Concat along the kv-feature dim to get [B, T, drafter_layers * kv_dim] + # Each layer tensor is [B, T, num_kv_heads, head_dim] → flatten last two + # dims → [B, T, kv_dim], then concat across layers. 
+ k_flat = [k.flatten(-2, -1) for k in drafter_k_per_layer] + v_flat = [v.flatten(-2, -1) for v in drafter_v_per_layer] + k_concat = torch.cat(k_flat, dim=-1) # [B, T, drafter_layers * kv_dim] + v_concat = torch.cat(v_flat, dim=-1) + return self.forward_k(k_concat), self.forward_v(v_concat) + + # ----------------------------------------------------------------- + # Persistence + # ----------------------------------------------------------------- + + def save_pretrained(self, output_dir: str | Path) -> None: + """Save config + state_dict to ``output_dir``. + + Layout:: + + output_dir/ + f_theta_config.json # FThetaConfig.to_json_dict() + f_theta_weights.pt # torch state_dict (bf16 by default) + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + (output_dir / "f_theta_config.json").write_text( + json.dumps(self.config.to_json_dict(), indent=2), + ) + torch.save(self.state_dict(), output_dir / "f_theta_weights.pt") + + @classmethod + def from_pretrained( + cls, source: str | Path, *, dtype: Any = None, device: Any = None, + ) -> "FThetaProjection": + """Load f_θ from a directory containing config + weights. + + ``source`` is a local directory. HF Hub support deferred until + a public f_θ checkpoint is hosted (training is internal first). 
+ """ + source = Path(source) + if not source.is_dir(): + raise FileNotFoundError( + f"f_θ source must be a directory; got {source}" + ) + config_path = source / "f_theta_config.json" + weights_path = source / "f_theta_weights.pt" + if not config_path.is_file(): + raise FileNotFoundError(f"missing {config_path}") + if not weights_path.is_file(): + raise FileNotFoundError(f"missing {weights_path}") + + config = FThetaConfig.from_json_dict( + json.loads(config_path.read_text()), + ) + model = cls(config) + state = torch.load(weights_path, map_location="cpu") + model.load_state_dict(state, strict=True) + if dtype is not None: + model = model.to(dtype) + if device is not None: + model = model.to(device) + model.eval() + return model diff --git a/inference_engine/v04/restored_sink_window_verifier.py b/inference_engine/v04/restored_sink_window_verifier.py new file mode 100644 index 00000000..7c43a05b --- /dev/null +++ b/inference_engine/v04/restored_sink_window_verifier.py @@ -0,0 +1,375 @@ +"""Gap 1 — incremental, stateful verifier adapter for K/V Restoration. + +This module bridges the *validated* (but full-forward / eval-only) +:class:`inference_engine.v04.cross_model_dlm_verifier.CrossModelDLMRestoredVerifier` +to the **stateful, incremental** verifier contract that the speculative +decoder (:class:`kv_cache_proposer.speculative.SpeculativeDecoder`) and the +gRPC session coordinators expect — i.e. 
the public surface of +:class:`kv_cache_proposer.verifier.SinkWindowVerifier`: + + * ``prefill(prompt_ids)`` + * ``forward_block(tokens) -> [L, V]`` + * ``commit_or_truncate(forwarded, accepted)`` + * ``append_token(token_id) -> next_token_logits`` + * ``next_token_logits`` / ``next_global_position`` / ``cached_token_sequence`` + * ``cache_logical_size`` / ``cache`` + * ``k_seq_length(session)`` / ``kv_live_bytes(session)`` / ``live_kv_bytes()`` + * ``stats`` (:class:`kv_cache_proposer.verifier.VerifierStats`) + * ``model`` (the verifier ``nn.Module``, for KV-dim resolution) + +Once an instance of this adapter is constructed, it is a drop-in +replacement for ``SinkWindowVerifier`` everywhere those callers use it — +that is *both* Gap 1 (the speculative accept/reject loop) and Gap 2 (the +server: ``SessionStore`` / ``AppendTokensCoordinator`` / +``GenerationCoordinator`` only depend on this contract). + +Beta semantics (honest) +----------------------- + +In the default mode (``incremental=False``), each ``prefill`` / +``forward_block`` / ``append_token`` re-runs the +restored full-forward over the committed prefix (+ the block being +verified). This is **bit-equivalent to the validated gate forward** — it +*is* that forward — and realizes the headline Kakeya property: the +verifier holds only a sink+window resident cache (``cache_logical_size`` +is bounded by ``sink+window``), and the evicted-position K/V are +reconstructed each step from the cache-free drafter (ADR 0008 §11.3: the +proposer is a constant-memory K/V reconstruction source) plus the S5 +exact full-attention layers. With ``incremental=True`` the restored +forward runs once at ``prefill`` and later blocks are decoded by the +verifier's native forward against a persistent cache of the captured K/V. + +The re-forward compute is O(T)/step (O(T^2) per generation), same as the eval +harness. The per-step O(1) persistent-cache optimization (reusing the +verifier's resident sink+window K/V across steps and amortizing the +drafter forward with the proposer's) is the K2.A.2 follow-up — it does +not change *outputs*, only speed. 
Keeping this adapter a thin, +provably-equivalent wrapper is deliberate: it lets the recall gate that +passed on the full-forward path carry over to the served path unchanged. +""" + +from __future__ import annotations + +from typing import Any, Callable, List, Optional + +import torch + +from kv_cache_proposer.verifier import VerifierStats + +from inference_engine.v04.cross_model_dlm_verifier import ( + CrossModelDLMRestoredVerifier, + get_verifier_decoder, + resolve_text_config, +) + + +class CrossModelRestoredSinkWindowVerifier: + """Stateful sink+window verifier backed by f_θ + S5 K/V Restoration. + + Wraps a constructed :class:`CrossModelDLMRestoredVerifier` and exposes + the :class:`~kv_cache_proposer.verifier.SinkWindowVerifier` public API. + """ + + def __init__( + self, + restored: CrossModelDLMRestoredVerifier, + *, + apply_rotary_pos_emb: Callable, + eager_attention_forward: Callable, + all_attention_functions: Optional[Any] = None, + device: str = "cpu", + incremental: bool = False, + ) -> None: + self._restored = restored + self._apply_rotary_pos_emb = apply_rotary_pos_emb + self._eager_attention_forward = eager_attention_forward + self._all_attention_functions = all_attention_functions + self._device = torch.device(device) + # Incremental decode mode (Gap-A throughput optimization): capture the + # restored K/V into a persistent KV cache at prefill, then decode the + # new tokens with the verifier's NATIVE incremental forward (O(L)/block) + # instead of re-running the O(T) restored forward each step. Recall is + # carried by the full-attention (S5) layers whose captured K/V are the + # verifier's own at every position (== native AR for those layers). 
+ self._incremental = bool(incremental) + self._past = None # transformers Cache holding restored K/V + self._past_len = 0 # number of positions in the cache + self._num_layers_cache = None # resolved lazily (incremental path only) + # Fused-engine (component A): optionally capture the verifier's aux-layer + # hidden states DURING the incremental verify forward, so the DFlash + # drafter's context can be extended incrementally instead of via a + # separate O(C) clean-aux forward each block. Gated off by default so + # the plain Gap-A decode path pays no overhead. + drafter_cfg = getattr(getattr(restored, "drafter", None), "cfg", None) + self._aux_layer_ids = tuple(getattr(drafter_cfg, "aux_layer_ids", ()) or ()) + self._capture_aux = False + self._last_aux = None # list[Tensor [L, hidden]] from the last verify + + self.sink_size = restored.sink_size + self.window_size = restored.window_size + + # No persistent DynamicCache: the bounded sink+window resident K/V + # are conceptual here (re-derived each forward). ``cache is None`` + # makes SpeculativeDecoder._kv_bytes return 0; the bounded-KV story + # is carried by stats.peak_kv_bytes / kv_live_bytes instead. + self.cache = None + self.cache_logical_size: int = 0 + self.next_global_position: int = 0 + self.next_token_logits: Optional[torch.Tensor] = None + self.cached_token_sequence: List[int] = [] + + # Full committed prefix (prompt + accepted/correction tokens). This + # is what drives restoration; it is NOT bounded (it is the logical + # sequence), while ``cached_token_sequence`` is the bounded resident + # mirror used by the CacheInspector accessors. + self._committed: List[int] = [] + # Tokens passed to the most recent forward_block, pending a + # commit_or_truncate decision. 
+ self._pending: List[int] = [] + + self.stats = VerifierStats(weight_bytes=self._compute_weight_bytes()) + self._bytes_per_kv_token = self._compute_bytes_per_kv_token() + + # ------------------------------------------------------------------ # + # Introspection used by the server (scripts/start_grpc_runtime_server) + # ------------------------------------------------------------------ # + @property + def model(self): + """The verifier ``nn.Module`` (exposes ``.config`` for KV dims).""" + return self._restored.verifier_model + + # ------------------------------------------------------------------ # + # Construction-time accounting + # ------------------------------------------------------------------ # + def _compute_weight_bytes(self) -> int: + total = 0 + for module in ( + getattr(self._restored, "verifier_model", None), + getattr(self._restored, "drafter", None), + getattr(self._restored, "f_theta", None), + ): + params = getattr(module, "parameters", None) + if params is None: + continue + for p in params(): + total += p.numel() * p.element_size() + return total + + def _compute_bytes_per_kv_token(self) -> int: + cfg = resolve_text_config(self._restored.verifier_model.config) + num_layers = int(getattr(cfg, "num_hidden_layers", 0) or 0) + num_kv_heads = int( + getattr(cfg, "num_key_value_heads", None) + or getattr(cfg, "num_attention_heads", 0) + or 0 + ) + head_dim = getattr(cfg, "head_dim", None) + if head_dim is None: + hidden = getattr(cfg, "hidden_size", 0) or 0 + num_q = getattr(cfg, "num_attention_heads", 0) or 0 + head_dim = (hidden // num_q) if num_q else 0 + head_dim = int(head_dim) + # itemsize from the verifier's own parameters (fp32 on CPU / bf16 GPU) + itemsize = 4 + for p in self._restored.verifier_model.parameters(): + itemsize = p.element_size() + break + # ``× 2`` = K + V + return num_layers * num_kv_heads * head_dim * itemsize * 2 + + # ------------------------------------------------------------------ # + # Core: run the restored forward 
over a sequence → per-position logits + # ------------------------------------------------------------------ # + @torch.no_grad() + def _restored_logits(self, seq_ids: List[int]) -> torch.Tensor: + """Return ``[T, V]`` logits for ``seq_ids`` via the restored forward.""" + input_ids = torch.tensor( + [seq_ids], dtype=torch.long, device=self._device + ) + out = self._restored.forward( + input_ids, + apply_rotary_pos_emb=self._apply_rotary_pos_emb, + eager_attention_forward=self._eager_attention_forward, + all_attention_functions=self._all_attention_functions, + ) + logits = out.logits if hasattr(out, "logits") else out + return logits[0] # [T, V] + + # ------------------------------------------------------------------ # + # SinkWindowVerifier public API + # ------------------------------------------------------------------ # + def reset(self) -> None: + self._committed = [] + self._pending = [] + self.cached_token_sequence = [] + self.cache_logical_size = 0 + self.next_global_position = 0 + self.next_token_logits = None + self._past = None + self._past_len = 0 + + # ------------------------------------------------------------------ # + # Incremental-decode helpers (Gap-A throughput path) + # ------------------------------------------------------------------ # + @torch.no_grad() + def _build_restored_cache(self, prompt_ids): + """Run the restored forward over the prompt ONCE, capturing the + per-layer post-norm/RoPE/injection K/V into a transformers + ``DynamicCache``. 
Returns (cache, last_logits).""" + from transformers.cache_utils import DynamicCache + if self._num_layers_cache is None: + self._num_layers_cache = len( + get_verifier_decoder(self._restored.verifier_model).layers) + input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=self._device) + capture: list = [None] * self._num_layers_cache + out = self._restored.forward( + input_ids, + apply_rotary_pos_emb=self._apply_rotary_pos_emb, + eager_attention_forward=self._eager_attention_forward, + all_attention_functions=self._all_attention_functions, + capture_kv=capture, + ) + logits = (out.logits if hasattr(out, "logits") else out)[0] + if any(c is None for c in capture): + raise RuntimeError( + "Incremental prefill requires an evicted-position restored " + "forward (prompt must exceed sink+window); some layers were " + "not captured. Use a longer prompt or incremental=False." + ) + cache = DynamicCache() + for li, (k, v) in enumerate(capture): + cache.update(k, v, li) + return cache, logits + + @torch.no_grad() + def _native_forward(self, tokens): + """Native incremental verifier forward over ``tokens`` against the + persistent restored cache. Appends tokens' K/V to the cache. 
+ Returns ``[len(tokens), V]`` logits.""" + L = len(tokens) + ids = torch.tensor([tokens], dtype=torch.long, device=self._device) + pos = torch.arange(self._past_len, self._past_len + L, device=self._device) + want_aux = self._capture_aux and bool(self._aux_layer_ids) + out = self._restored.verifier_model( + input_ids=ids, + position_ids=pos.unsqueeze(0), + cache_position=pos, + past_key_values=self._past, + use_cache=True, + output_hidden_states=want_aux, + ) + self._past = out.past_key_values + if want_aux: + hs = out.hidden_states # tuple; hs[a] = [B, L, hidden] + self._last_aux = [hs[a][0].detach() for a in self._aux_layer_ids] + return out.logits[0] + + @torch.no_grad() + def prefill(self, prompt_ids: List[int]) -> None: + if not prompt_ids: + raise ValueError("prompt_ids must be non-empty") + self.reset() + self._committed = list(prompt_ids) + if self._incremental: + self._past, logits = self._build_restored_cache(self._committed) + self._past_len = len(self._committed) + else: + logits = self._restored_logits(self._committed) # [L, V] + self.next_token_logits = logits[-1].clone() + self._sync_bounded_state() + self.stats.forward_calls += 1 + self.stats.tokens_consumed += len(prompt_ids) + self._record_peak_activation(logits) + self._record_peak_kv() + + @torch.no_grad() + def forward_block(self, tokens: List[int]) -> torch.Tensor: + if not self._committed: + raise RuntimeError("Verifier not prefilled.") + if not tokens: + raise ValueError("tokens must be non-empty") + self._pending = list(tokens) + if self._incremental: + block = self._native_forward(self._pending).clone() # [L, V] + else: + seq = self._committed + self._pending + logits = self._restored_logits(seq) # [len(seq), V] + start = len(self._committed) + block = logits[start : start + len(tokens)].clone() # [L, V] + # Provisional resident size mirrors SinkWindowVerifier (un-trimmed + # until commit_or_truncate); _sync_bounded_state re-bounds on commit. 
+ self.cache_logical_size = len(self._committed) + len(tokens) + self.stats.forward_calls += 1 + self.stats.tokens_consumed += len(tokens) + self._record_peak_activation(block) + return block + + def commit_or_truncate(self, forwarded: int, accepted: int) -> None: + if accepted < 0 or accepted > forwarded: + raise ValueError("accepted must satisfy 0 <= accepted <= forwarded") + if self._incremental and self._past is not None: + # forward_block appended `forwarded` tokens' K/V to the cache; + # drop the rejected tail so the cache reflects only committed. + drop = forwarded - accepted + if drop > 0: + keep = self._past_len + forwarded - drop # == _past_len + accepted + for layer in self._past.layers: + if getattr(layer, "keys", None) is not None: + layer.keys = layer.keys[:, :, :keep, :].contiguous() + layer.values = layer.values[:, :, :keep, :].contiguous() + self._past_len += accepted + if accepted: + self._committed.extend(self._pending[:accepted]) + self._pending = [] + self._sync_bounded_state() + self._record_peak_kv() + + @torch.no_grad() + def append_token(self, token_id: int) -> torch.Tensor: + logits = self.forward_block([token_id]) + self.commit_or_truncate(forwarded=1, accepted=1) + self.next_token_logits = logits[-1].clone() + return self.next_token_logits + + # ------------------------------------------------------------------ # + # CacheInspector protocol (used by SessionStore / coordinators) + # ------------------------------------------------------------------ # + def k_seq_length(self, session: object) -> int: + del session # single-tenant: one verifier per bound session + return len(self.cached_token_sequence) + + def kv_live_bytes(self, session: object) -> int: + del session + return len(self.cached_token_sequence) * self._bytes_per_kv_token + + def live_kv_bytes(self) -> int: + return len(self.cached_token_sequence) * self._bytes_per_kv_token + + # ------------------------------------------------------------------ # + # Internal helpers + # 
------------------------------------------------------------------ # + def _budget(self) -> int: + return self.sink_size + self.window_size + + def _sync_bounded_state(self) -> None: + """Recompute the bounded sink+window resident mirror + counters.""" + budget = self._budget() + seq = self._committed + if len(seq) <= budget: + self.cached_token_sequence = list(seq) + else: + keep_window = budget - self.sink_size + self.cached_token_sequence = ( + seq[: self.sink_size] + seq[-keep_window:] + if keep_window > 0 + else seq[: self.sink_size] + ) + self.cache_logical_size = len(self.cached_token_sequence) + self.next_global_position = len(self._committed) + + def _record_peak_activation(self, logits: torch.Tensor) -> None: + n = int(logits.numel() * logits.element_size()) + if n > self.stats.peak_activation_bytes: + self.stats.peak_activation_bytes = n + + def _record_peak_kv(self) -> None: + self.stats.peak_kv_bytes = max( + self.stats.peak_kv_bytes, self.live_kv_bytes() + ) diff --git a/results/platform-tests/bench_gemma4_26b_mac.json b/results/platform-tests/bench_gemma4_26b_mac.json new file mode 100644 index 00000000..e2005bbf --- /dev/null +++ b/results/platform-tests/bench_gemma4_26b_mac.json @@ -0,0 +1,61 @@ +{ + "kind": "mlx_kakeya_deployment_benchmark", + "config": { + "model_id": "models/gemma-4-26B-A4B-it-mlx-4bit", + "context_lengths": [ + 512, + 2048, + 8192 + ], + "gen_tokens": 64, + "sink_size": 4, + "window_size": 64 + }, + "env": { + "mlx_version": "0.31.2" + }, + "results": [ + { + "context_length": 512, + "kakeya": { + "error": "TypeError: make_sink_window_cache() takes 1 positional argument but 3 were given" + }, + "vanilla": { + "prefill_s": 9.1962, + "decode_s": 4.4363, + "decode_tokens": 63, + "decode_tokens_per_s": 14.201, + "kv_bytes": 129536000, + "peak_memory_bytes": 16024951996 + } + }, + { + "context_length": 2048, + "kakeya": { + "error": "TypeError: make_sink_window_cache() takes 1 positional argument but 3 were given" + }, + "vanilla": 
{ + "prefill_s": 7.6443, + "decode_s": 5.9676, + "decode_tokens": 63, + "decode_tokens_per_s": 10.557, + "kv_bytes": 475566080, + "peak_memory_bytes": 17317280424 + } + }, + { + "context_length": 8192, + "kakeya": { + "error": "TypeError: make_sink_window_cache() takes 1 positional argument but 3 were given" + }, + "vanilla": { + "prefill_s": 45.3409, + "decode_s": 20.705, + "decode_tokens": 63, + "decode_tokens_per_s": 3.043, + "kv_bytes": 1859686400, + "peak_memory_bytes": 22526543070 + } + } + ] +} \ No newline at end of file diff --git a/results/platform-tests/bench_gemma4_26b_mac.log b/results/platform-tests/bench_gemma4_26b_mac.log new file mode 100644 index 00000000..7261ffba --- /dev/null +++ b/results/platform-tests/bench_gemma4_26b_mac.log @@ -0,0 +1,26 @@ +Command: +source "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/bin/activate" && PYTHONPATH=.:sdks/python python3 scripts/bench_mlx_kakeya_deployment.py --model-id models/gemma-4-26B-A4B-it-mlx-4bit --context-lengths 512,2048,8192 --gen-tokens 64 --sink-size 4 --window-size 64 --output results/platform-tests/bench_gemma4_26b_mac.json + +Commit under test: +2b6851c Mac deployment bench: default to gemma-4-26B-A4B-it-mlx-4bit; measure REAL native incremental-decode tok/s. + +Started: 2026-06-11T07:01:22.167Z +Ended: 2026-06-11T07:03:04.279Z +Elapsed: 102.111s +Exit code: 0 + +[bench] loading MLX model models/gemma-4-26B-A4B-it-mlx-4bit +[bench] L=512: Kakeya sink+window ... +[bench] L=512: kakeya path failed: make_sink_window_cache() takes 1 positional argument but 3 were given +[bench] L=512: vanilla full-KV ... +[bench] L=512: vanilla 14.201 tok/s (prefill 9.1962s, KV 129.54 MB, peak 16.02 GB) +[bench] L=2048: Kakeya sink+window ... +[bench] L=2048: kakeya path failed: make_sink_window_cache() takes 1 positional argument but 3 were given +[bench] L=2048: vanilla full-KV ... 
+[bench] L=2048: vanilla 10.557 tok/s (prefill 7.6443s, KV 475.57 MB, peak 17.32 GB) +[bench] L=8192: Kakeya sink+window ... +[bench] L=8192: kakeya path failed: make_sink_window_cache() takes 1 positional argument but 3 were given +[bench] L=8192: vanilla full-KV ... +[bench] L=8192: vanilla 3.043 tok/s (prefill 45.3409s, KV 1859.69 MB, peak 22.53 GB) + +[bench] wrote results/platform-tests/bench_gemma4_26b_mac.json diff --git a/results/platform-tests/bench_gemma4_26b_mac_kakeya.json b/results/platform-tests/bench_gemma4_26b_mac_kakeya.json new file mode 100644 index 00000000..e8e891af --- /dev/null +++ b/results/platform-tests/bench_gemma4_26b_mac_kakeya.json @@ -0,0 +1,88 @@ +{ + "kind": "mlx_kakeya_deployment_benchmark", + "config": { + "model_id": "models/gemma-4-26B-A4B-it-mlx-4bit", + "context_lengths": [ + 512, + 2048, + 8192 + ], + "gen_tokens": 64, + "sink_size": 4, + "window_size": 64 + }, + "env": { + "mlx_version": "0.31.2" + }, + "results": [ + { + "context_length": 512, + "kakeya": { + "prefill_s": 9.6883, + "decode_s": 3.5036, + "decode_tokens": 63, + "decode_tokens_per_s": 17.981, + "kv_bytes": 15319040, + "peak_memory_bytes": 16024947900 + }, + "vanilla": { + "prefill_s": 1.498, + "decode_s": 2.5221, + "decode_tokens": 63, + "decode_tokens_per_s": 24.98, + "kv_bytes": 129536000, + "peak_memory_bytes": 16041187510 + }, + "kakeya_vs_vanilla": { + "decode_speedup_x": 0.72, + "kv_bytes_ratio_x": 8.5 + } + }, + { + "context_length": 2048, + "kakeya": { + "prefill_s": 7.5367, + "decode_s": 7.1653, + "decode_tokens": 63, + "decode_tokens_per_s": 8.792, + "kv_bytes": 15319040, + "peak_memory_bytes": 17490295530 + }, + "vanilla": { + "prefill_s": 6.3451, + "decode_s": 9.5656, + "decode_tokens": 63, + "decode_tokens_per_s": 6.586, + "kv_bytes": 252948480, + "peak_memory_bytes": 17467718378 + }, + "kakeya_vs_vanilla": { + "decode_speedup_x": 1.335, + "kv_bytes_ratio_x": 16.5 + } + }, + { + "context_length": 8192, + "kakeya": { + "prefill_s": 43.6301, + 
"decode_s": 22.1921, + "decode_tokens": 63, + "decode_tokens_per_s": 2.839, + "kv_bytes": 15319040, + "peak_memory_bytes": 22783444190 + }, + "vanilla": { + "prefill_s": 53.3461, + "decode_s": 22.9514, + "decode_tokens": 63, + "decode_tokens_per_s": 2.745, + "kv_bytes": 378777600, + "peak_memory_bytes": 22542763230 + }, + "kakeya_vs_vanilla": { + "decode_speedup_x": 1.034, + "kv_bytes_ratio_x": 24.7 + } + } + ] +} \ No newline at end of file diff --git a/results/platform-tests/bench_gemma4_26b_mac_kakeya.log b/results/platform-tests/bench_gemma4_26b_mac_kakeya.log new file mode 100644 index 00000000..2cb20077 --- /dev/null +++ b/results/platform-tests/bench_gemma4_26b_mac_kakeya.log @@ -0,0 +1,26 @@ +Command: +source "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/bin/activate" && PYTHONPATH=.:sdks/python python3 scripts/bench_mlx_kakeya_deployment.py --model-id models/gemma-4-26B-A4B-it-mlx-4bit --context-lengths 512,2048,8192 --gen-tokens 64 --sink-size 4 --window-size 64 --output results/platform-tests/bench_gemma4_26b_mac_kakeya.json + +Commit under test: +85b9c5a Fix Kakeya path in Mac deployment bench: make_sink_window_cache() takes keyword-only sink_size/window_size. + +Started: 2026-06-11T07:21:53.206Z +Ended: 2026-06-11T07:25:12.257Z +Elapsed: 199.051s +Exit code: 0 + +[bench] loading MLX model models/gemma-4-26B-A4B-it-mlx-4bit +[bench] L=512: Kakeya sink+window ... +[bench] L=512: vanilla full-KV ... +[bench] L=512: kakeya 17.981 tok/s (prefill 9.6883s, KV 15.32 MB, peak 16.02 GB) +[bench] L=512: vanilla 24.98 tok/s (prefill 1.498s, KV 129.54 MB, peak 16.04 GB) +[bench] L=2048: Kakeya sink+window ... +[bench] L=2048: vanilla full-KV ... +[bench] L=2048: kakeya 8.792 tok/s (prefill 7.5367s, KV 15.32 MB, peak 17.49 GB) +[bench] L=2048: vanilla 6.586 tok/s (prefill 6.3451s, KV 252.95 MB, peak 17.47 GB) +[bench] L=8192: Kakeya sink+window ... +[bench] L=8192: vanilla full-KV ... 
+[bench] L=8192: kakeya 2.839 tok/s (prefill 43.6301s, KV 15.32 MB, peak 22.78 GB) +[bench] L=8192: vanilla 2.745 tok/s (prefill 53.3461s, KV 378.78 MB, peak 22.54 GB) + +[bench] wrote results/platform-tests/bench_gemma4_26b_mac_kakeya.json diff --git a/results/research/f_theta_v5_s5_sliding.json b/results/research/f_theta_v5_s5_sliding.json new file mode 100644 index 00000000..6d65b876 --- /dev/null +++ b/results/research/f_theta_v5_s5_sliding.json @@ -0,0 +1,124 @@ +{ + "kind": "k3_f_theta_train", + "schema_version": 2, + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "steps": 10000, + "lr": 0.001, + "lr_schedule": "cosine", + "warmup_steps": 500, + "weight_decay": 0.01, + "n_prompts": 62, + "n_niah_prompts": 96, + "no_niah_prompts": false, + "niah_min_lines": 30, + "niah_max_lines": 140, + "gen_len": 256, + "sample_positions": 0, + "loss_type": "attn_distill_hybrid", + "lambda_k_dir": 1.0, + "lambda_v_dir": 1.0, + "lambda_k_mag": 0.1, + "lambda_v_mag": 0.1, + "s5_exact_full_attn": true, + "init_from": null, + "rank": 768, + "save": "results/research/f_theta_v5_s5_sliding", + "seed": 0, + "log_every": 50, + "eval_every": 500 + }, + "f_theta_config": { + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + "verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 768, + "verifier_layer_kv_heads": [ + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] + }, + "n_params": 94371840, + "n_sequences": 158, + "n_general_prompts": 62, + "n_niah_prompts": 96, + "collect_seconds": 2245.6606731540523, + 
"train_seconds": 2807.3437344370177, + "initial_loss": 4.634022235870361, + "final_loss": 0.5480128973722458, + "loss_reduction_factor": 8.45604593996012, + "final_diagnostic": { + "mse_O_mean": 0.0797144242748618, + "abs_O_target_mean": 0.6430309748649597, + "k_dir_mean": 0.17511656086891889, + "v_dir_mean": 0.19396256901323794, + "k_mag_mean": 0.07836286470293999, + "v_mag_mean": 0.2563087701052427 + }, + "loss_type": "attn_distill_hybrid", + "lr_schedule": "cosine" +} \ No newline at end of file diff --git a/results/research/f_theta_v5_s5_sliding/f_theta_config.json b/results/research/f_theta_v5_s5_sliding/f_theta_config.json new file mode 100644 index 00000000..a7e565ee --- /dev/null +++ b/results/research/f_theta_v5_s5_sliding/f_theta_config.json @@ -0,0 +1,73 @@ +{ + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + "verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 768, + "verifier_layer_kv_heads": [ + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] +} \ No newline at end of file diff --git a/results/research/f_theta_v5_s5_sliding/f_theta_weights.pt b/results/research/f_theta_v5_s5_sliding/f_theta_weights.pt new file mode 100644 index 00000000..0030d5d9 --- /dev/null +++ b/results/research/f_theta_v5_s5_sliding/f_theta_weights.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1af1b0cdd25bd087f263be4c763db2ea2906ec6d4c421f080f28d39d669e2be +size 377510345 diff --git a/results/research/k3_alpha_sweep_attn_distill.json b/results/research/k3_alpha_sweep_attn_distill.json new file mode 100644 index 00000000..02ae5159 
--- /dev/null +++ b/results/research/k3_alpha_sweep_attn_distill.json @@ -0,0 +1,87 @@ +{ + "kind": "k3_s6_fidelity_recall_sweep", + "config": { + "f_theta_dir": "results/research/f_theta_v3_attn_distill", + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "max_new_tokens": 24, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "seed": 42, + "prompt_token_lens": [ + 1637, + 1378, + 1299, + 1597, + 1278, + 1617, + 1498, + 1238, + 1498, + 1358 + ] + }, + "f_theta_baseline_rel_mse": { + "overall": 1331.9405048759936, + "full_attn": 18254.26540890811 + }, + "sweep": [ + { + "alpha": 0.0, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 1331.9405048759936, + "eff_rel_mse_full_attn": 18254.26540890811 + }, + { + "alpha": 0.1, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 1078.8718089495549, + "eff_rel_mse_full_attn": 14785.954981215571 + }, + { + "alpha": 0.25, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 749.2165339927465, + "eff_rel_mse_full_attn": 10268.024292510812 + }, + { + "alpha": 0.5, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 332.9851262189984, + "eff_rel_mse_full_attn": 4563.566352227028 + }, + { + "alpha": 0.75, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 83.2462815547496, + "eff_rel_mse_full_attn": 1140.891588056757 + }, + { + "alpha": 0.9, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 13.31940504875993, + "eff_rel_mse_full_attn": 182.54265408908103 + }, + { + "alpha": 1.0, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 0.0, + "eff_rel_mse_full_attn": 0.0 + } + ] +} \ No newline at end of file diff --git a/results/research/k3_alpha_sweep_relmse.json b/results/research/k3_alpha_sweep_relmse.json new file mode 100644 index 00000000..46949ca5 --- 
/dev/null +++ b/results/research/k3_alpha_sweep_relmse.json @@ -0,0 +1,95 @@ +{ + "kind": "k3_s6_fidelity_recall_sweep", + "config": { + "f_theta_dir": "/tmp/f_theta_v3_relmse", + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "max_new_tokens": 24, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "seed": 42, + "prompt_token_lens": [ + 1637, + 1378, + 1299, + 1597, + 1278, + 1617, + 1498, + 1238, + 1498, + 1358 + ] + }, + "f_theta_baseline_rel_mse": { + "overall": 0.2313521382961503, + "full_attn": 1.4451023524463593 + }, + "sweep": [ + { + "alpha": 0.0, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 0.2313521382961503, + "eff_rel_mse_full_attn": 1.4451023524463593 + }, + { + "alpha": 0.5, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 0.05783803457403758, + "eff_rel_mse_full_attn": 0.36127558811158983 + }, + { + "alpha": 0.75, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 0.014459508643509394, + "eff_rel_mse_full_attn": 0.09031889702789746 + }, + { + "alpha": 0.9, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 0.002313521382961502, + "eff_rel_mse_full_attn": 0.014451023524463586 + }, + { + "alpha": 0.95, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 0.0005783803457403768, + "eff_rel_mse_full_attn": 0.0036127558811159047 + }, + { + "alpha": 0.98, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 9.25408553184603e-05, + "eff_rel_mse_full_attn": 0.0005780409409785448 + }, + { + "alpha": 0.99, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 2.3135213829615074e-05, + "eff_rel_mse_full_attn": 0.0001445102352446362 + }, + { + "alpha": 1.0, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 0.0, + "eff_rel_mse_full_attn": 
0.0 + } + ] +} \ No newline at end of file diff --git a/results/research/k3_alpha_sweep_relmse_knee.json b/results/research/k3_alpha_sweep_relmse_knee.json new file mode 100644 index 00000000..a0bb85c0 --- /dev/null +++ b/results/research/k3_alpha_sweep_relmse_knee.json @@ -0,0 +1,63 @@ +{ + "kind": "k3_s6_fidelity_recall_sweep", + "config": { + "f_theta_dir": "/tmp/f_theta_v3_relmse", + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "max_new_tokens": 24, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "seed": 42, + "prompt_token_lens": [ + 1637, + 1378, + 1299, + 1597, + 1278, + 1617, + 1498, + 1238, + 1498, + 1358 + ] + }, + "f_theta_baseline_rel_mse": { + "overall": 0.2313521382961503, + "full_attn": 1.4451023524463593 + }, + "sweep": [ + { + "alpha": 0.1, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 0.18739523201988176, + "eff_rel_mse_full_attn": 1.1705329054815512 + }, + { + "alpha": 0.2, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 0.14806536850953622, + "eff_rel_mse_full_attn": 0.9248655055656702 + }, + { + "alpha": 0.3, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 0.11336254776511363, + "eff_rel_mse_full_attn": 0.708100152698716 + }, + { + "alpha": 0.4, + "recall": 0.6, + "samples_correct": 6, + "samples_total": 10, + "eff_rel_mse_overall": 0.0832867697866141, + "eff_rel_mse_full_attn": 0.5202368468806894 + } + ] +} \ No newline at end of file diff --git a/results/research/k3_alpha_sweep_v4a.json b/results/research/k3_alpha_sweep_v4a.json new file mode 100644 index 00000000..27e1cbb8 --- /dev/null +++ b/results/research/k3_alpha_sweep_v4a.json @@ -0,0 +1,71 @@ +{ + "kind": "k3_s6_fidelity_recall_sweep", + "config": { + "f_theta_dir": "results/research/f_theta_v4a_warmstart_hybrid", + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "max_new_tokens": 24, + "haystack_min_lines": 60, + 
"haystack_max_lines": 80, + "seed": 42, + "prompt_token_lens": [ + 1637, + 1378, + 1299, + 1597, + 1278, + 1617, + 1498, + 1238, + 1498, + 1358 + ] + }, + "f_theta_baseline_rel_mse": { + "overall": 0.19859662496240946, + "full_attn": 1.4201765954772976 + }, + "sweep": [ + { + "alpha": 0.0, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 0.19859662496240946, + "eff_rel_mse_full_attn": 1.4201765954772976 + }, + { + "alpha": 0.25, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 0.11171060154135531, + "eff_rel_mse_full_attn": 0.7988493349559799 + }, + { + "alpha": 0.5, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 0.049649156240602364, + "eff_rel_mse_full_attn": 0.3550441488693244 + }, + { + "alpha": 0.75, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 0.012412289060150591, + "eff_rel_mse_full_attn": 0.0887610372173311 + }, + { + "alpha": 1.0, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 0.0, + "eff_rel_mse_full_attn": 0.0 + } + ] +} \ No newline at end of file diff --git a/results/research/k3_alpha_sweep_v4b.json b/results/research/k3_alpha_sweep_v4b.json new file mode 100644 index 00000000..4b69945e --- /dev/null +++ b/results/research/k3_alpha_sweep_v4b.json @@ -0,0 +1,71 @@ +{ + "kind": "k3_s6_fidelity_recall_sweep", + "config": { + "f_theta_dir": "results/research/f_theta_v4b_fresh_hybrid", + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "max_new_tokens": 24, + "haystack_min_lines": 60, + "haystack_max_lines": 80, + "seed": 42, + "prompt_token_lens": [ + 1637, + 1378, + 1299, + 1597, + 1278, + 1617, + 1498, + 1238, + 1498, + 1358 + ] + }, + "f_theta_baseline_rel_mse": { + "overall": 0.20886406627984377, + "full_attn": 1.5155949128956676 + }, + "sweep": [ + { + "alpha": 0.0, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + 
"eff_rel_mse_overall": 0.20886406627984377, + "eff_rel_mse_full_attn": 1.5155949128956676 + }, + { + "alpha": 0.25, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 0.11748603728241212, + "eff_rel_mse_full_attn": 0.8525221385038131 + }, + { + "alpha": 0.5, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 0.052216016569960944, + "eff_rel_mse_full_attn": 0.3788987282239169 + }, + { + "alpha": 0.75, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 0.013054004142490236, + "eff_rel_mse_full_attn": 0.09472468205597923 + }, + { + "alpha": 1.0, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 0.0, + "eff_rel_mse_full_attn": 0.0 + } + ] +} \ No newline at end of file diff --git a/results/research/k3_dflash_accept_b15.json b/results/research/k3_dflash_accept_b15.json new file mode 100644 index 00000000..a37825e3 --- /dev/null +++ b/results/research/k3_dflash_accept_b15.json @@ -0,0 +1,132 @@ +{ + "schema_version": 1, + "kind": "k3_dflash_specdecode_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "block_size": 15, + "num_steps": 8, + "max_new_tokens": 48, + "n_prompts": 4, + "aux_layer_ids": [ + 2, + 7, + 12, + 18, + 23, + 28 + ] + }, + "aggregate": { + "acceptance_rate": 0.16160220994475138, + "acceptance_length": 3.1666666666666665, + "total_accepted": 117, + "total_drafted": 724, + "total_blocks": 54, + "lossless_vs_ar": true, + "reference_humaneval": { + "acceptance_length": 7.7, + "acceptance_rate": 0.447 + } + }, + "per_prompt": [ + { + "prompt": "Write a Python function that returns the n-th Fibonacci number.", + "blocks": 12, + "block_accepts": [ + 4, + 2, + 2, + 0, + 1, + 4, + 5, + 7, + 0, + 4, + 4, + 4 + ], + "mean_accepted_per_block": 3.0833333333333335, + "tokens_generated": 48, + "verifier_forwards_spec": 12, + 
"lossless_vs_ar": true, + "decoded": "There are several ways to implement this depending on whether you prioritize readability, memory, or speed. Below are the three most common approaches.\n\n### 1. The Efficient Approach (Iterative)\nThis " + }, + { + "prompt": "Explain in two sentences why the sky is blue.", + "blocks": 22, + "block_accepts": [ + 1, + 2, + 1, + 6, + 2, + 0, + 2, + 4, + 0, + 0, + 1, + 1, + 0, + 2, + 0, + 0, + 2, + 1, + 0, + 1, + 0, + 0 + ], + "mean_accepted_per_block": 1.1818181818181819, + "tokens_generated": 48, + "verifier_forwards_spec": 22, + "lossless_vs_ar": true, + "decoded": "The sky appears blue because of a phenomenon called Rayleigh scattering, where sunlight interacts with the gases and particles in Earth's atmosphere. As sunlight reaches the atmosphere, shorter blue w" + }, + { + "prompt": "List three prime numbers greater than 100.", + "blocks": 5, + "block_accepts": [ + 0, + 10, + 3, + 14, + 8 + ], + "mean_accepted_per_block": 7.0, + "tokens_generated": 40, + "verifier_forwards_spec": 5, + "lossless_vs_ar": true, + "decoded": "Here are three prime numbers greater than 100:\n\n1. **101**\n2. **103**\n3. **107**" + }, + { + "prompt": "Summarize the plot of Romeo and Juliet in one sentence.", + "blocks": 15, + "block_accepts": [ + 0, + 0, + 3, + 0, + 0, + 1, + 1, + 1, + 2, + 3, + 1, + 2, + 1, + 2, + 2 + ], + "mean_accepted_per_block": 1.2666666666666666, + "tokens_generated": 34, + "verifier_forwards_spec": 15, + "lossless_vs_ar": true, + "decoded": "Two star-crossed lovers from feuding noble families take their own lives in a tragic misunderstanding, ultimately transforming their deaths into a catalyst for peace between their households." 
+ } + ] +} \ No newline at end of file diff --git a/results/research/k3_dflash_accept_baseline.json b/results/research/k3_dflash_accept_baseline.json new file mode 100644 index 00000000..34543a71 --- /dev/null +++ b/results/research/k3_dflash_accept_baseline.json @@ -0,0 +1,143 @@ +{ + "schema_version": 1, + "kind": "k3_dflash_specdecode_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "block_size": 16, + "num_steps": 8, + "max_new_tokens": 48, + "n_prompts": 4, + "aux_layer_ids": [ + 2, + 7, + 12, + 18, + 23, + 28 + ] + }, + "aggregate": { + "acceptance_rate": 0.11169652265542676, + "acceptance_length": 2.6307692307692307, + "total_accepted": 106, + "total_drafted": 949, + "total_blocks": 65, + "lossless_vs_ar": true, + "reference_humaneval": { + "acceptance_length": 7.7, + "acceptance_rate": 0.447 + } + }, + "per_prompt": [ + { + "prompt": "Write a Python function that returns the n-th Fibonacci number.", + "blocks": 7, + "block_accepts": [ + 7, + 1, + 11, + 8, + 1, + 9, + 5 + ], + "mean_accepted_per_block": 6.0, + "tokens_generated": 48, + "verifier_forwards_spec": 7, + "lossless_vs_ar": true, + "decoded": "There are several ways to implement this depending on whether you prioritize readability, memory, or speed. Below are the three most common approaches.\n\n### 1. The Efficient Approach (Iterative)\nThis " + }, + { + "prompt": "Explain in two sentences why the sky is blue.", + "blocks": 27, + "block_accepts": [ + 0, + 0, + 0, + 1, + 0, + 0, + 2, + 3, + 2, + 0, + 0, + 1, + 4, + 1, + 1, + 1, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 2, + 1, + 1 + ], + "mean_accepted_per_block": 0.7777777777777778, + "tokens_generated": 48, + "verifier_forwards_spec": 27, + "lossless_vs_ar": true, + "decoded": "The sky appears blue because of a phenomenon called Rayleigh scattering, where sunlight interacts with the gases and particles in Earth's atmosphere. 
As sunlight reaches the atmosphere, shorter blue w" + }, + { + "prompt": "List three prime numbers greater than 100.", + "blocks": 10, + "block_accepts": [ + 5, + 10, + 1, + 2, + 1, + 3, + 2, + 3, + 3, + 0 + ], + "mean_accepted_per_block": 3.0, + "tokens_generated": 40, + "verifier_forwards_spec": 10, + "lossless_vs_ar": true, + "decoded": "Here are three prime numbers greater than 100:\n\n1. **101**\n2. **103**\n3. **107**" + }, + { + "prompt": "Summarize the plot of Romeo and Juliet in one sentence.", + "blocks": 21, + "block_accepts": [ + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 2, + 0, + 1, + 1, + 1, + 0, + 1, + 2, + 1, + 2 + ], + "mean_accepted_per_block": 0.6190476190476191, + "tokens_generated": 34, + "verifier_forwards_spec": 21, + "lossless_vs_ar": true, + "decoded": "Two star-crossed lovers from feuding noble families take their own lives in a tragic misunderstanding, ultimately transforming their deaths into a catalyst for peace between their households." + } + ] +} \ No newline at end of file diff --git a/results/research/k3_dflash_accept_code.json b/results/research/k3_dflash_accept_code.json new file mode 100644 index 00000000..08373301 --- /dev/null +++ b/results/research/k3_dflash_accept_code.json @@ -0,0 +1,193 @@ +{ + "schema_version": 1, + "kind": "k3_dflash_specdecode_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "block_size": 15, + "num_steps": 8, + "max_new_tokens": 64, + "n_prompts": 6, + "aux_layer_ids": [ + 2, + 7, + 12, + 18, + 23, + 28 + ] + }, + "aggregate": { + "acceptance_rate": 0.22741194486983154, + "acceptance_length": 4.193548387096774, + "total_accepted": 297, + "total_drafted": 1306, + "total_blocks": 93, + "lossless_vs_ar": true, + "reference_humaneval": { + "acceptance_length": 7.7, + "acceptance_rate": 0.447 + } + }, + "per_prompt": [ + { + "prompt": "Complete this Python function:\n\ndef has_close_elements(numbers: list[float], 
threshold: float) -> bool:\n \"\"\"Return True if any two numbers are closer than threshold.\"\"\"\n", + "blocks": 19, + "block_accepts": [ + 0, + 1, + 4, + 0, + 1, + 1, + 3, + 1, + 0, + 4, + 3, + 1, + 1, + 1, + 1, + 8, + 7, + 4, + 5 + ], + "mean_accepted_per_block": 2.4210526315789473, + "tokens_generated": 64, + "verifier_forwards_spec": 19, + "lossless_vs_ar": true, + "decoded": "To complete this function, the most efficient way is to sort the list first. Once sorted, the two closest numbers must be adjacent to each other, allowing you to find the result in $O(n \\log n)$ time " + }, + { + "prompt": "Complete this Python function:\n\ndef is_palindrome(s: str) -> bool:\n \"\"\"Return True if s reads the same forwards and backwards, ignoring case and non-alphanumeric chars.\"\"\"\n", + "blocks": 15, + "block_accepts": [ + 0, + 3, + 2, + 1, + 3, + 6, + 0, + 1, + 0, + 3, + 0, + 7, + 3, + 13, + 8 + ], + "mean_accepted_per_block": 3.3333333333333335, + "tokens_generated": 64, + "verifier_forwards_spec": 15, + "lossless_vs_ar": true, + "decoded": "To complete this function, you can use a generator expression to filter out non-alphanumeric characters and convert the remaining ones to lowercase, then compare the resulting string with its reverse." + }, + { + "prompt": "Complete this Python function:\n\ndef merge_sort(arr: list[int]) -> list[int]:\n \"\"\"Return a new list with the elements of arr sorted ascending using merge sort.\"\"\"\n", + "blocks": 11, + "block_accepts": [ + 2, + 8, + 1, + 2, + 0, + 6, + 4, + 0, + 3, + 13, + 15 + ], + "mean_accepted_per_block": 4.909090909090909, + "tokens_generated": 64, + "verifier_forwards_spec": 11, + "lossless_vs_ar": true, + "decoded": "Here is the complete implementation of the `merge_sort` algorithm. 
This implementation follows the divide-and-conquer paradigm and returns a new list to avoid mutating the original input.\n\n```python\nd" + }, + { + "prompt": "Complete this Python function:\n\ndef gcd(a: int, b: int) -> int:\n \"\"\"Return the greatest common divisor of a and b using the Euclidean algorithm.\"\"\"\n", + "blocks": 6, + "block_accepts": [ + 3, + 10, + 15, + 15, + 11, + 5 + ], + "mean_accepted_per_block": 9.833333333333334, + "tokens_generated": 64, + "verifier_forwards_spec": 6, + "lossless_vs_ar": true, + "decoded": "Here is the completed function using the Euclidean algorithm:\n\n```python\ndef gcd(a: int, b: int) -> int:\n \"\"\"Return the greatest common divisor of a and b using the Euclidean algorithm.\"\"\"\n whil" + }, + { + "prompt": "Complete this Python function:\n\ndef flatten(nested: list) -> list:\n \"\"\"Flatten an arbitrarily nested list of integers into a single flat list.\"\"\"\n", + "blocks": 14, + "block_accepts": [ + 3, + 4, + 4, + 2, + 0, + 3, + 4, + 4, + 3, + 1, + 4, + 5, + 3, + 11 + ], + "mean_accepted_per_block": 3.642857142857143, + "tokens_generated": 64, + "verifier_forwards_spec": 14, + "lossless_vs_ar": true, + "decoded": "To flatten an arbitrarily nested list, the most effective approach is to use **recursion**. 
We iterate through each element: if the element is a list, we call the function again; if it is an integer, " + }, + { + "prompt": "Complete this Python function:\n\ndef count_words(text: str) -> dict[str, int]:\n \"\"\"Return a dict mapping each lowercased word in text to its frequency.\"\"\"\n", + "blocks": 28, + "block_accepts": [ + 0, + 3, + 1, + 2, + 5, + 2, + 4, + 2, + 0, + 1, + 1, + 1, + 0, + 0, + 1, + 1, + 2, + 2, + 0, + 1, + 2, + 1, + 0, + 1, + 0, + 0, + 1, + 3 + ], + "mean_accepted_per_block": 1.3214285714285714, + "tokens_generated": 64, + "verifier_forwards_spec": 28, + "lossless_vs_ar": true, + "decoded": "To complete this function, you should normalize the text by converting it to lowercase, splitting it into individual words, and then counting the occurrences. \n\nUsing `re.findall` is the most robust a" + } + ] +} \ No newline at end of file diff --git a/results/research/k3_dflash_accept_humaneval.json b/results/research/k3_dflash_accept_humaneval.json new file mode 100644 index 00000000..ecd93a16 --- /dev/null +++ b/results/research/k3_dflash_accept_humaneval.json @@ -0,0 +1,383 @@ +{ + "schema_version": 1, + "kind": "k3_dflash_specdecode_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "block_size": 15, + "num_steps": 8, + "max_new_tokens": 96, + "n_prompts": 10, + "aux_layer_ids": [ + 2, + 7, + 12, + 18, + 23, + 28 + ] + }, + "aggregate": { + "acceptance_rate": 0.199245939675174, + "acceptance_length": 3.8744769874476988, + "total_accepted": 687, + "total_drafted": 3448, + "total_blocks": 239, + "lossless_vs_ar": false, + "reference_humaneval": { + "acceptance_length": 7.7, + "acceptance_rate": 0.447 + } + }, + "per_prompt": [ + { + "prompt": "from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n given threshold.\n >>> 
has_close_elements([1.0, 2.0, 3.0], 0.5)\n False\n >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n True\n \"\"\"\n", + "blocks": 16, + "block_accepts": [ + 0, + 6, + 0, + 1, + 6, + 5, + 0, + 1, + 3, + 1, + 6, + 5, + 2, + 2, + 0, + 0 + ], + "mean_accepted_per_block": 2.375, + "tokens_generated": 54, + "verifier_forwards_spec": 16, + "lossless_vs_ar": true, + "decoded": " for i in range(len(numbers)):\n for j in range(i + 1, len(numbers)):\n if abs(numbers[i] - numbers[j]) < threshold:\n return True\n return False\n```" + }, + { + "prompt": "from typing import List\n\n\ndef separate_paren_groups(paren_string: str) -> List[str]:\n \"\"\" Input to this function is a string containing multiple groups of nested parentheses. Your goal is to\n separate those group into separate strings and return the list of those.\n Separate groups are balanced (each open brace is properly closed) and not nested within each other\n Ignore any spaces in the input string.\n >>> separate_paren_groups('( ) (( )) (( )( ))')\n ['()', '(())', '(()())']\n \"\"\"\n", + "blocks": 26, + "block_accepts": [ + 1, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 3, + 0, + 2, + 0, + 5, + 3, + 4, + 0, + 6, + 7, + 11, + 10, + 11, + 5 + ], + "mean_accepted_per_block": 2.730769230769231, + "tokens_generated": 96, + "verifier_forwards_spec": 26, + "lossless_vs_ar": false, + "decoded": " # This is a bit of a function to a part of the list of those.\n # This is a part of the list of those.\n # This is a part of thes.\n # This is a part of thes.\n # This is a part of thes.\n " + }, + { + "prompt": "\n\ndef truncate_number(number: float) -> float:\n \"\"\" Given a positive floating point number, it can be decomposed into\n and integer part (largest integer smaller than given number) and decimals\n (leftover part always smaller than 1).\n\n Return the decimal part of the number.\n >>> truncate_number(3.5)\n 0.5\n \"\"\"\n", + "blocks": 38, + "block_accepts": [ + 0, + 0, + 3, + 0, + 0, + 0, + 
1, + 0, + 1, + 0, + 0, + 1, + 1, + 1, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 3, + 1, + 3, + 3, + 3, + 3, + 1, + 3, + 1 + ], + "mean_accepted_per_block": 1.5263157894736843, + "tokens_generated": 96, + "verifier_forwards_spec": 38, + "lossless_vs_ar": true, + "decoded": " return 0.5\n```\n\n* * *\n\n* * *\n\n* * *\n\n* * *\n\n* * *\n\n* * *\n\n* * *\n\n* * *\n\n* * *\n\n* * *\n\n* * *\n\n* * *\n\n* * *\n\n* * *\n\n* *" + }, + { + "prompt": "from typing import List\n\n\ndef below_zero(operations: List[int]) -> bool:\n \"\"\" You're given a list of deposit and withdrawal operations on a bank account that starts with\n zero balance. Your task is to detect if at any point the balance of account fallls below zero, and\n at that point function should return True. Otherwise it should return False.\n >>> below_zero([1, 2, 3])\n False\n >>> below_zero([1, 2, -4, 5])\n True\n \"\"\"\n", + "blocks": 29, + "block_accepts": [ + 0, + 0, + 5, + 1, + 4, + 2, + 0, + 6, + 7, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 2, + 3, + 9, + 11, + 15 + ], + "mean_accepted_per_block": 2.3448275862068964, + "tokens_generated": 96, + "verifier_forwards_spec": 29, + "lossless_vs_ar": true, + "decoded": " balance = 0\n for amount in amount:\n balance += amount\n if balance < 0:\n return True\n return False\n```\n\n***\n\n\n\n#### Final Result:\n\nThe code has several errors in the " + }, + { + "prompt": "from typing import List\n\n\ndef mean_absolute_deviation(numbers: List[float]) -> float:\n \"\"\" For a given list of input numbers, calculate Mean Absolute Deviation\n around the mean of this dataset.\n Mean Absolute Deviation is the average absolute difference between each\n element and a centerpoint (mean in this case):\n MAD = average | x - x_mean |\n >>> mean_absolute_deviation([1.0, 2.0, 3.0, 4.0])\n 1.0\n \"\"\"\n", + "blocks": 29, + "block_accepts": [ + 0, + 0, + 1, + 2, + 2, + 0, + 0, + 0, + 3, + 3, + 0, + 1, + 0, + 4, 
+ 6, + 6, + 0, + 5, + 7, + 7, + 2, + 8, + 0, + 1, + 1, + 1, + 3, + 2, + 3 + ], + "mean_accepted_per_block": 2.3448275862068964, + "tokens_generated": 96, + "verifier_forwards_spec": 29, + "lossless_vs_ar": true, + "decoded": " # Calculate the mean of the list\n # (In this case, we'll use the sum of the actual list)\n # (In this case, we'll use the sum of the actual list)\n # (In this case, we'll use the sum of the" + }, + { + "prompt": "from typing import List\n\n\ndef intersperse(numbers: List[int], delimeter: int) -> List[int]:\n \"\"\" Insert a number 'delimeter' between every two consecutive elements of input list `numbers'\n >>> intersperse([], 4)\n []\n >>> intersperse([1, 2, 3], 4)\n [1, 4, 2, 4, 3]\n \"\"\"\n", + "blocks": 17, + "block_accepts": [ + 0, + 0, + 0, + 7, + 2, + 1, + 2, + 3, + 5, + 8, + 5, + 9, + 4, + 10, + 9, + 10, + 5 + ], + "mean_accepted_per_block": 4.705882352941177, + "tokens_generated": 96, + "verifier_forwards_spec": 17, + "lossless_vs_ar": true, + "decoded": " # Insert a number 'delimeter' between every two consecutive elements of input list `numbers'\n # Insert a number 'delimeter' between every two consecutive elements of input list `numbers'\n # " + }, + { + "prompt": "from typing import List\n\n\ndef parse_nested_parens(paren_string: str) -> List[int]:\n \"\"\" Input to this function is a string represented multiple groups for nested parentheses separated by spaces.\n For each of the group, output the deepest level of nesting of parentheses.\n E.g. 
(()()) has maximum two levels of nesting while ((())) has three.\n\n >>> parse_nested_parens('(()()) ((())) () ((())()())')\n [2, 3, 1, 3]\n \"\"\"\n", + "blocks": 32, + "block_accepts": [ + 1, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 2, + 4, + 0, + 1, + 1, + 2, + 4, + 2, + 7, + 6, + 10, + 4, + 5, + 0, + 0, + 0, + 1, + 0, + 0, + 3, + 3, + 4, + 4 + ], + "mean_accepted_per_block": 2.03125, + "tokens_generated": 96, + "verifier_forwards_spec": 32, + "lossless_vs_ar": true, + "decoded": " # This function is a bit of a complex for nested parentheses.\n # We'll use a bit of a complex for nested parentheses.\n # We'll use a bit of a complex for nested parentheses.\n # We'll use " + }, + { + "prompt": "from typing import List\n\n\ndef filter_by_substring(strings: List[str], substring: str) -> List[str]:\n \"\"\" Filter an input list of strings only for ones that contain given substring\n >>> filter_by_substring([], 'a')\n []\n >>> filter_by_substring(['abc', 'bacd', 'cde', 'array'], 'a')\n ['abc', 'bacd', 'array']\n \"\"\"\n", + "blocks": 15, + "block_accepts": [ + 1, + 0, + 3, + 3, + 4, + 5, + 9, + 7, + 7, + 6, + 11, + 6, + 8, + 6, + 6 + ], + "mean_accepted_per_block": 5.466666666666667, + "tokens_generated": 96, + "verifier_forwards_spec": 15, + "lossless_vs_ar": true, + "decoded": " return [s for s in [s for s in [s for s in [s for s in [s for s in [s for s in [s for s in [s for s in [s for s in [s for s in [s for s in [s for s in [s for s in [s for s in [s for s in [s for s " + }, + { + "prompt": "from typing import List, Tuple\n\n\ndef sum_product(numbers: List[int]) -> Tuple[int, int]:\n \"\"\" For a given list of integers, return a tuple consisting of a sum and a product of all the integers in a list.\n Empty sum should be equal to 0 and empty product should be equal to 1.\n >>> sum_product([])\n (0, 1)\n >>> sum_product([1, 2, 3, 4])\n (10, 24)\n \"\"\"\n", + "blocks": 16, + "block_accepts": [ + 1, + 0, + 1, + 0, + 2, + 1, + 2, + 4, + 9, + 2, + 9, + 12, + 12, + 
12, + 12, + 2 + ], + "mean_accepted_per_block": 5.0625, + "tokens_generated": 96, + "verifier_forwards_spec": 16, + "lossless_vs_ar": true, + "decoded": " # The sum of all the integers in a list.\n # The product of all the integers in a list.\n # The sum of all the integers in a list.\n # The sum of all the integers in a list.\n # The sum of" + }, + { + "prompt": "from typing import List, Tuple\n\n\ndef rolling_max(numbers: List[int]) -> List[int]:\n \"\"\" From a given list of integers, generate a list of rolling maximum element found until given moment\n in the sequence.\n >>> rolling_max([1, 2, 3, 2, 3, 4, 2])\n [1, 2, 3, 3, 3, 4, 4]\n \"\"\"\n", + "blocks": 21, + "block_accepts": [ + 1, + 1, + 0, + 0, + 0, + 1, + 1, + 0, + 4, + 4, + 0, + 3, + 6, + 7, + 9, + 9, + 4, + 9, + 4, + 9, + 4 + ], + "mean_accepted_per_block": 3.619047619047619, + "tokens_generated": 96, + "verifier_forwards_spec": 21, + "lossless_vs_ar": true, + "decoded": " # This is not a given moment, not a given list, not a given sequence, not a given moment, not a given list, not a given sequence, not a given moment, not a given list, not a given sequence, not a " + } + ] +} \ No newline at end of file diff --git a/results/research/k3_dflash_accept_noscale.json b/results/research/k3_dflash_accept_noscale.json new file mode 100644 index 00000000..07e68125 --- /dev/null +++ b/results/research/k3_dflash_accept_noscale.json @@ -0,0 +1,131 @@ +{ + "schema_version": 1, + "kind": "k3_dflash_specdecode_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "block_size": 16, + "num_steps": 8, + "max_new_tokens": 48, + "n_prompts": 4, + "aux_layer_ids": [ + 2, + 7, + 12, + 18, + 23, + 28 + ] + }, + "aggregate": { + "acceptance_rate": 0.157543391188251, + "acceptance_length": 3.2264150943396226, + "total_accepted": 118, + "total_drafted": 749, + "total_blocks": 53, + "lossless_vs_ar": true, + "reference_humaneval": { + "acceptance_length": 
7.7, + "acceptance_rate": 0.447 + } + }, + "per_prompt": [ + { + "prompt": "Write a Python function that returns the n-th Fibonacci number.", + "blocks": 12, + "block_accepts": [ + 4, + 2, + 2, + 0, + 1, + 4, + 5, + 7, + 0, + 4, + 4, + 4 + ], + "mean_accepted_per_block": 3.0833333333333335, + "tokens_generated": 48, + "verifier_forwards_spec": 12, + "lossless_vs_ar": true, + "decoded": "There are several ways to implement this depending on whether you prioritize readability, memory, or speed. Below are the three most common approaches.\n\n### 1. The Efficient Approach (Iterative)\nThis " + }, + { + "prompt": "Explain in two sentences why the sky is blue.", + "blocks": 22, + "block_accepts": [ + 1, + 2, + 0, + 7, + 2, + 0, + 2, + 4, + 0, + 0, + 1, + 1, + 0, + 2, + 0, + 0, + 2, + 1, + 0, + 1, + 0, + 0 + ], + "mean_accepted_per_block": 1.1818181818181819, + "tokens_generated": 48, + "verifier_forwards_spec": 22, + "lossless_vs_ar": true, + "decoded": "The sky appears blue because of a phenomenon called Rayleigh scattering, where sunlight interacts with the gases and particles in Earth's atmosphere. As sunlight reaches the atmosphere, shorter blue w" + }, + { + "prompt": "List three prime numbers greater than 100.", + "blocks": 4, + "block_accepts": [ + 11, + 3, + 14, + 8 + ], + "mean_accepted_per_block": 9.0, + "tokens_generated": 40, + "verifier_forwards_spec": 4, + "lossless_vs_ar": true, + "decoded": "Here are three prime numbers greater than 100:\n\n1. **101**\n2. **103**\n3. 
**107**" + }, + { + "prompt": "Summarize the plot of Romeo and Juliet in one sentence.", + "blocks": 15, + "block_accepts": [ + 0, + 0, + 3, + 0, + 0, + 1, + 1, + 1, + 2, + 3, + 1, + 2, + 1, + 2, + 2 + ], + "mean_accepted_per_block": 1.2666666666666666, + "tokens_generated": 34, + "verifier_forwards_spec": 15, + "lossless_vs_ar": true, + "decoded": "Two star-crossed lovers from feuding noble families take their own lives in a tragic misunderstanding, ultimately transforming their deaths into a catalyst for peace between their households." + } + ] +} \ No newline at end of file diff --git a/results/research/k3_e2e_gpu_bench.json b/results/research/k3_e2e_gpu_bench.json new file mode 100644 index 00000000..7810a6ea --- /dev/null +++ b/results/research/k3_e2e_gpu_bench.json @@ -0,0 +1,92 @@ +{ + "kind": "k3_e2e_gpu_bench", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "sink_size": 4, + "window_size": 64, + "gen_tokens": 16, + "n_samples": 3, + "haystack_lines": [ + 60, + 160 + ] + }, + "verifier_dims": { + "num_hidden_layers": 30, + "num_key_value_heads": 8, + "head_dim": 256, + "sliding_window": 1024 + }, + "env": { + "gpu": "NVIDIA H200", + "torch": "2.11.0+cu128" + }, + "results": [ + { + "haystack_lines": 60, + "prompt_tokens": { + "min": 1238, + "max": 1238 + }, + "ar": { + "decode_tokens_per_s": 21.514, + "prefill_s_mean": 0.1802, + "kv_bytes_final": 282501120, + "peak_mem_bytes": 54761089024, + "recall": 1.0, + "decode_tokens": 48 + }, + "restored": { + "decode_tokens_per_s": 2.259, + "prefill_s_mean": 0.4505, + "resident_kv_bytes": 16711680, + "resident_window_tokens": 68, + "effective_context_tokens": 1254, + "peak_mem_bytes": 55049542656, + "recall": 1.0, + "decode_tokens": 48 + }, + "comparison": { + "kv_memory_saving_x": 16.9, + "ar_kv_mb": 282.5, + "restored_resident_kv_mb": 16.71, + "context_compression_x": 18.4, + 
"throughput_ratio_restored_over_ar": 0.105 + } + }, + { + "haystack_lines": 160, + "prompt_tokens": { + "min": 3238, + "max": 3238 + }, + "ar": { + "decode_tokens_per_s": 21.92, + "prefill_s_mean": 0.2378, + "kv_bytes_final": 733061120, + "peak_mem_bytes": 57768487424, + "recall": 1.0, + "decode_tokens": 48 + }, + "restored": { + "decode_tokens_per_s": 1.273, + "prefill_s_mean": 0.7683, + "resident_kv_bytes": 16711680, + "resident_window_tokens": 68, + "effective_context_tokens": 3254, + "peak_mem_bytes": 58487434240, + "recall": 1.0, + "decode_tokens": 48 + }, + "comparison": { + "kv_memory_saving_x": 43.9, + "ar_kv_mb": 733.06, + "restored_resident_kv_mb": 16.71, + "context_compression_x": 47.9, + "throughput_ratio_restored_over_ar": 0.058 + } + } + ] +} \ No newline at end of file diff --git a/results/research/k3_e2e_gpu_bench_incremental.json b/results/research/k3_e2e_gpu_bench_incremental.json new file mode 100644 index 00000000..e4d4a4e2 --- /dev/null +++ b/results/research/k3_e2e_gpu_bench_incremental.json @@ -0,0 +1,92 @@ +{ + "kind": "k3_e2e_gpu_bench", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "sink_size": 4, + "window_size": 64, + "gen_tokens": 16, + "n_samples": 3, + "haystack_lines": [ + 60, + 160 + ] + }, + "verifier_dims": { + "num_hidden_layers": 30, + "num_key_value_heads": 8, + "head_dim": 256, + "sliding_window": 1024 + }, + "env": { + "gpu": "NVIDIA H200", + "torch": "2.11.0+cu128" + }, + "results": [ + { + "haystack_lines": 60, + "prompt_tokens": { + "min": 1238, + "max": 1238 + }, + "ar": { + "decode_tokens_per_s": 17.454, + "prefill_s_mean": 0.2179, + "kv_bytes_final": 282501120, + "peak_mem_bytes": 54761089024, + "recall": 1.0, + "decode_tokens": 48 + }, + "restored": { + "decode_tokens_per_s": 21.705, + "prefill_s_mean": 0.4646, + "resident_kv_bytes": 16711680, + "resident_window_tokens": 68, + 
"effective_context_tokens": 1254, + "peak_mem_bytes": 55252261888, + "recall": 1.0, + "decode_tokens": 48 + }, + "comparison": { + "kv_memory_saving_x": 16.9, + "ar_kv_mb": 282.5, + "restored_resident_kv_mb": 16.71, + "context_compression_x": 18.4, + "throughput_ratio_restored_over_ar": 1.244 + } + }, + { + "haystack_lines": 160, + "prompt_tokens": { + "min": 3238, + "max": 3238 + }, + "ar": { + "decode_tokens_per_s": 21.206, + "prefill_s_mean": 0.238, + "kv_bytes_final": 733061120, + "peak_mem_bytes": 58049087488, + "recall": 1.0, + "decode_tokens": 48 + }, + "restored": { + "decode_tokens_per_s": 21.63, + "prefill_s_mean": 0.7713, + "resident_kv_bytes": 16711680, + "resident_window_tokens": 68, + "effective_context_tokens": 3254, + "peak_mem_bytes": 59053939712, + "recall": 1.0, + "decode_tokens": 48 + }, + "comparison": { + "kv_memory_saving_x": 43.9, + "ar_kv_mb": 733.06, + "restored_resident_kv_mb": 16.71, + "context_compression_x": 47.9, + "throughput_ratio_restored_over_ar": 1.02 + } + } + ] +} \ No newline at end of file diff --git a/results/research/k3_fidelity_f_theta_v4a_warmstart_hybrid.json b/results/research/k3_fidelity_f_theta_v4a_warmstart_hybrid.json new file mode 100644 index 00000000..d1ab3ec8 --- /dev/null +++ b/results/research/k3_fidelity_f_theta_v4a_warmstart_hybrid.json @@ -0,0 +1,47 @@ +{ + "kind": "k3_s6_fidelity_recall_sweep", + "config": { + "f_theta_dir": "results/research/f_theta_v4a_warmstart_hybrid", + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "max_new_tokens": 16, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "seed": 42, + "prompt_token_lens": [ + 1637, + 1378, + 1299, + 1597, + 1278, + 1617, + 1498, + 1238, + 1498, + 1358 + ] + }, + "f_theta_baseline_rel_mse": { + "overall": 0.19859662496240946, + "full_attn": 1.4201765954772976 + }, + "sweep": [ + { + "alpha": 0.0, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 0.19859662496240946, + "eff_rel_mse_full_attn": 
1.4201765954772976 + }, + { + "alpha": 0.5, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 0.049649156240602364, + "eff_rel_mse_full_attn": 0.3550441488693244 + } + ] +} \ No newline at end of file diff --git a/results/research/k3_fidelity_f_theta_v4b_fresh_hybrid.json b/results/research/k3_fidelity_f_theta_v4b_fresh_hybrid.json new file mode 100644 index 00000000..fb3e8ed7 --- /dev/null +++ b/results/research/k3_fidelity_f_theta_v4b_fresh_hybrid.json @@ -0,0 +1,47 @@ +{ + "kind": "k3_s6_fidelity_recall_sweep", + "config": { + "f_theta_dir": "results/research/f_theta_v4b_fresh_hybrid", + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "max_new_tokens": 16, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "seed": 42, + "prompt_token_lens": [ + 1637, + 1378, + 1299, + 1597, + 1278, + 1617, + 1498, + 1238, + 1498, + 1358 + ] + }, + "f_theta_baseline_rel_mse": { + "overall": 0.20886406627984377, + "full_attn": 1.5155949128956676 + }, + "sweep": [ + { + "alpha": 0.0, + "recall": 0.0, + "samples_correct": 0, + "samples_total": 10, + "eff_rel_mse_overall": 0.20886406627984377, + "eff_rel_mse_full_attn": 1.5155949128956676 + }, + { + "alpha": 0.5, + "recall": 1.0, + "samples_correct": 10, + "samples_total": 10, + "eff_rel_mse_overall": 0.052216016569960944, + "eff_rel_mse_full_attn": 0.3788987282239169 + } + ] +} \ No newline at end of file diff --git a/results/research/k3_identity_restore_ctx70.json b/results/research/k3_identity_restore_ctx70.json new file mode 100644 index 00000000..6922555e --- /dev/null +++ b/results/research/k3_identity_restore_ctx70.json @@ -0,0 +1,272 @@ +{ + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v1", + "f_theta_config": { + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + 
"verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 256, + "verifier_layer_kv_heads": [ + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] + }, + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "max_new_tokens": 24, + "seed": 42, + "skip_oracle": true, + "identity_restore": true, + "prompt_token_lens": [ + 1637, + 1378, + 1299, + 1597, + 1278, + 1617, + 1498, + 1238, + 1498, + 1358 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 5.108704237313941, + "median_latency_s": 5.052127701987047, + "per_sample_decoded": [ + "BETA-1409\n//thought\nBETA-1409", + "DELTA-3286\n\nthought\nDELTA-3286\n", + "ORCHID-9935\n10. 
**Final Answer:** ORCHID-99", + "ETA-1520\n\n\nthought\nthought\nETA-15", + "DELTA-4811\n\n\u66f8\u304d\u8fbc\u307f\u5185\u5bb9\u306b\u57fa\u3065\u304f\u3068\u3001\u79d8\u5bc6\u306e\u30b3\u30fc\u30c9\u306f **DELTA-4", + "ALPHA-4257\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n", + "DELTA-8359\nnormative_thought\nDELTA-8359", + "MAPLE-3615\n\u092c\u52a9\u3051\u304c\u5fc5\u8981\u306a\u5834\u5408\u306f\u3001\u3044\u3064\u3067\u3082\u304a\u77e5\u3089\u305b\u304f\u3060\u3055\u3044\u3002\n", + "ZETA-5552\n\n\n", + "MAPLE-6514\n\ninclude_thought\nMAPLE-6514" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 16, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 3.8454941553293023, + 4.737514101160626, + 4.94522928603496, + 4.179844858322699, + 4.995243891981496, + 4.171802989047664, + 4.451344750865639, + 5.142973584540285, + 4.53123556298915, + 4.763504484631587 + ], + "mean_throughput_tokens_per_sec": 4.576418766490341, + "median_throughput_tokens_per_sec": 4.634374832074888, + "min_throughput_tokens_per_sec": 3.8454941553293023, + "max_throughput_tokens_per_sec": 5.142973584540285 + } + }, + "attention_window": { + "per_config": { + "k3_cross_model": { + "config": "v04_dlm_restored", + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)", + "samples_total": 10, + "effective_keys_at_last_query_mean": 1439.8, + "effective_keys_at_last_query_min": 1238, + "effective_keys_at_last_query_max": 1637, + "effective_keys_at_last_query_median": 1438.0, + "effective_attention_fraction_mean": 1.0, + "effective_attention_fraction_min": 1.0, + "effective_attention_fraction_max": 1.0, + "effective_attention_fraction_median": 1.0, + "per_sample": [ + { + "config": "v04_dlm_restored", + "seq_len": 1637, + "effective_keys_at_last_query": 1637, + "effective_attention_fraction": 1.0, + 
"structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1378, + "effective_keys_at_last_query": 1378, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1299, + "effective_keys_at_last_query": 1299, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1597, + "effective_keys_at_last_query": 1597, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1278, + "effective_keys_at_last_query": 1278, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1617, + "effective_keys_at_last_query": 1617, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1498, + "effective_keys_at_last_query": 1498, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1238, + "effective_keys_at_last_query": 1238, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1498, + "effective_keys_at_last_query": 1498, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1358, + 
"effective_keys_at_last_query": 1358, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + } + ] + } + } + }, + "memory": { + "k3_cross_model": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52630830080, + "current_reserved_bytes": 56555995136, + "peak_allocated_bytes": 55607588864, + "peak_reserved_bytes": 56555995136 + } + }, + "gate": { + "architectural_correctness": true, + "recall_cross_model": 1.0, + "recall_oracle": null, + "recall_delta_vs_oracle_pp": null, + "recall_delta_within_5pp": false + } +} \ No newline at end of file diff --git a/results/research/k3_integrated_niah_ctx280_1781062484.json b/results/research/k3_integrated_niah_ctx280_1781062484.json new file mode 100644 index 00000000..c8844c1b --- /dev/null +++ b/results/research/k3_integrated_niah_ctx280_1781062484.json @@ -0,0 +1,340 @@ +{ + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v1", + "f_theta_config": { + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + "verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 256, + "verifier_layer_kv_heads": [ + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] + }, + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 238, + "haystack_max_lines": 322, + "max_new_tokens": 24, + "seed": 42, + 
"skip_oracle": false, + "prompt_token_lens": [ + 6417, + 5418, + 6177, + 4857, + 6078, + 5298, + 5358, + 5198, + 5497, + 5057 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model", + "samples_total": 10, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 15.52308544210391, + "median_latency_s": 14.698599348543212, + "per_sample_decoded": [ + "The prompt asks for a single sentence, but the text provided is a distraction. Based on the instructions provided in the prompt", + "The prompt asks for a single sentence, but the instructions require a single sentence. Therefore, I will provide the answer in", + "The prompt asks for a single sentence, but the text provided is a distraction. Based on the text provided, there is", + "The prompt asks for a single sentence, but the text provided is a distraction. There is no \"secret code\" mentioned", + "The prompt asks for a single sentence, but the instructions are contradictory. However, based on the text provided, there is", + "The prompt asks for a single sentence, but the text provided is a distraction. Based on the instructions, here is the", + "The prompt asks for a single sentence, but the text provided is a distraction. Based on the instructions, here is the", + "The prompt asks for a single sentence, but the text provided is a distraction. Based on the text provided, there is", + "The prompt asks for a single sentence, but the instructions are contradictory. However, based on the text provided, there is", + "The prompt asks for a single sentence, but the text provided is a distraction. 
There is no \"secret code\" mentioned" + ], + "per_sample_correct": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.1779341009672453, + 1.6165093588852733, + 1.246971030602152, + 1.93168938465748, + 1.387724273255223, + 1.6722198430620756, + 1.6494399590304862, + 1.7330882555455362, + 1.587746651599798, + 1.8212931246640527 + ], + "mean_throughput_tokens_per_sec": 1.582461598226932, + "median_throughput_tokens_per_sec": 1.6329746589578797, + "min_throughput_tokens_per_sec": 1.1779341009672453, + "max_throughput_tokens_per_sec": 1.93168938465748 + }, + "oracle": { + "name": "oracle", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 11.51858025569236, + "median_latency_s": 11.570857462531421, + "per_sample_decoded": [ + "BETA-1409\n\nthought\nBETA-1409\n", + "The secret code is DELTA-3286.thought\nThe secret code is DELTA-", + "BETA-7912\nDESCRIBE THE PROCESS OF CREATING A CUSTOMIZED LEARNING PATH FOR", + "BETA-4582\n\nThe secret code is BETA-4582.", + "KAPPA-1434\n\n\n\n", + "PINE-9928\n\n\n", + "THETA-5557\n\nThe secret code is THETA-5557", + "PINE-7924\n\nthought\nPINE-7924\n", + "GAMMA-4527\nDESCRIBE THE PROCESS OF CREATING A CUSTOMIZED LEARNING PATH FOR", + "BETA-7224\n\nThe secret code is BETA-7224" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 15, + 13, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.4687567015431287, + 1.985795171268025, + 1.5669987421788882, + 2.358122932225276, + 1.7097474439227078, + 2.0548270151626356, + 2.0239335454596032, + 2.1269773436426123, + 1.9501700529181205, + 2.2392242031387943 + ], + "mean_throughput_tokens_per_sec": 1.9484553151459791, + 
"median_throughput_tokens_per_sec": 2.004864358363814, + "min_throughput_tokens_per_sec": 1.4687567015431287, + "max_throughput_tokens_per_sec": 2.358122932225276 + } + }, + "attention_window": { + "per_config": { + "k3_cross_model": { + "config": "v04_dlm_restored", + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)", + "samples_total": 10, + "effective_keys_at_last_query_mean": 5535.5, + "effective_keys_at_last_query_min": 4857, + "effective_keys_at_last_query_max": 6417, + "effective_keys_at_last_query_median": 5388.0, + "effective_attention_fraction_mean": 1.0, + "effective_attention_fraction_min": 1.0, + "effective_attention_fraction_max": 1.0, + "effective_attention_fraction_median": 1.0, + "per_sample": [ + { + "config": "v04_dlm_restored", + "seq_len": 6417, + "effective_keys_at_last_query": 6417, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5418, + "effective_keys_at_last_query": 5418, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 6177, + "effective_keys_at_last_query": 6177, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 4857, + "effective_keys_at_last_query": 4857, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 6078, + "effective_keys_at_last_query": 6078, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5298, + "effective_keys_at_last_query": 
5298, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5358, + "effective_keys_at_last_query": 5358, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5198, + "effective_keys_at_last_query": 5198, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5497, + "effective_keys_at_last_query": 5497, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5057, + "effective_keys_at_last_query": 5057, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + } + ] + } + } + }, + "memory": { + "k3_cross_model": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52631157248, + "current_reserved_bytes": 74801217536, + "peak_allocated_bytes": 69680971264, + "peak_reserved_bytes": 74801217536 + }, + "oracle": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52631157248, + "current_reserved_bytes": 66244837376, + "peak_allocated_bytes": 63333219840, + "peak_reserved_bytes": 66244837376 + } + }, + "gate": { + "architectural_correctness": true, + "recall_cross_model": 0.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 100.0, + "recall_delta_within_5pp": false + } +} \ No newline at end of file diff --git a/results/research/k3_integrated_niah_ctx280_1781076342.json b/results/research/k3_integrated_niah_ctx280_1781076342.json new 
file mode 100644 index 00000000..6e3c6d5d --- /dev/null +++ b/results/research/k3_integrated_niah_ctx280_1781076342.json @@ -0,0 +1,341 @@ +{ + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v3_attn_distill", + "f_theta_config": { + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + "verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 768, + "verifier_layer_kv_heads": [ + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] + }, + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 238, + "haystack_max_lines": 322, + "max_new_tokens": 24, + "seed": 42, + "skip_oracle": false, + "identity_restore": false, + "prompt_token_lens": [ + 6417, + 5418, + 6177, + 4857, + 6078, + 5298, + 5358, + 5198, + 5497, + 5057 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model", + "samples_total": 10, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 14.563482611696235, + "median_latency_s": 14.795165188028477, + "per_sample_decoded": [ + "There is no secret code.\nthought\nThere is no secret code.thought\nThere", + "There is no secret code in the text provided.\nthought\nThere is no secret code in the text", + "There is no secret code.\n[instruction]\nThe user wants me to extract a \"secret code\" from", + "There is no secret code.edits\n", + "The secret code is 42.\nthought\nThe secret code is 42.", + "There is no secret code provided in the 
text.\nthought\nThere is no secret code provided in the", + "There is no secret code.\nthought\nThere is no secret code.\nthought\nThere", + "There is no secret code.\nthought\nThere is no secret code.\nthought\nThere", + "There is no secret code in the text provided.", + "The secret code is not provided in the text.\nthought\nThe secret code is not provided in" + ], + "per_sample_correct": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 12, + 24, + 24, + 24, + 24, + 14, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.14942097630887, + 1.5802154299700242, + 1.252196973590312, + 1.8775123626557038, + 1.3580740592803073, + 1.6334039027748986, + 1.6110531501786365, + 1.6911834735050268, + 1.5477307357541588, + 1.7764574460473759 + ], + "mean_throughput_tokens_per_sec": 1.5477248510065313, + "median_throughput_tokens_per_sec": 1.5956342900743303, + "min_throughput_tokens_per_sec": 1.14942097630887, + "max_throughput_tokens_per_sec": 1.8775123626557038 + }, + "oracle": { + "name": "oracle", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 11.526917663309723, + "median_latency_s": 11.566776791994926, + "per_sample_decoded": [ + "BETA-1409\n\nthought\nBETA-1409\n", + "The secret code is DELTA-3286.thought\nThe secret code is DELTA-", + "BETA-7912\nDESCRIBE THE PROCESS OF CREATING A CUSTOMIZED LEARNING PATH FOR", + "BETA-4582\n\nThe secret code is BETA-4582.", + "KAPPA-1434\n\n\n\n", + "PINE-9928\n\n\n", + "THETA-5557\n\nThe secret code is THETA-5557", + "PINE-7924\n\nthought\nPINE-7924\n", + "GAMMA-4527\nDESCRIBE THE PROCESS OF CREATING A CUSTOMIZED LEARNING PATH FOR", + "BETA-7224\n\nThe secret code is BETA-7224" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 15, + 13, + 24, + 24, + 24, + 24 + ], + 
"per_sample_throughput_tokens_per_sec": [ + 1.467854605766299, + 1.9837474386023999, + 1.563867600579857, + 2.354837759564395, + 1.707692367917557, + 2.0562460937815645, + 2.023706185169926, + 2.1287684033179115, + 1.9468546450809168, + 2.23985020775989 + ], + "mean_throughput_tokens_per_sec": 1.9473425307540715, + "median_throughput_tokens_per_sec": 2.003726811886163, + "min_throughput_tokens_per_sec": 1.467854605766299, + "max_throughput_tokens_per_sec": 2.354837759564395 + } + }, + "attention_window": { + "per_config": { + "k3_cross_model": { + "config": "v04_dlm_restored", + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)", + "samples_total": 10, + "effective_keys_at_last_query_mean": 5535.5, + "effective_keys_at_last_query_min": 4857, + "effective_keys_at_last_query_max": 6417, + "effective_keys_at_last_query_median": 5388.0, + "effective_attention_fraction_mean": 1.0, + "effective_attention_fraction_min": 1.0, + "effective_attention_fraction_max": 1.0, + "effective_attention_fraction_median": 1.0, + "per_sample": [ + { + "config": "v04_dlm_restored", + "seq_len": 6417, + "effective_keys_at_last_query": 6417, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5418, + "effective_keys_at_last_query": 5418, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 6177, + "effective_keys_at_last_query": 6177, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 4857, + "effective_keys_at_last_query": 4857, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + 
{ + "config": "v04_dlm_restored", + "seq_len": 6078, + "effective_keys_at_last_query": 6078, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5298, + "effective_keys_at_last_query": 5298, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5358, + "effective_keys_at_last_query": 5358, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5198, + "effective_keys_at_last_query": 5198, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5497, + "effective_keys_at_last_query": 5497, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5057, + "effective_keys_at_last_query": 5057, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + } + ] + } + } + }, + "memory": { + "k3_cross_model": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52882815488, + "current_reserved_bytes": 75063361536, + "peak_allocated_bytes": 69932629504, + "peak_reserved_bytes": 75063361536 + }, + "oracle": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52882815488, + "current_reserved_bytes": 66527952896, + "peak_allocated_bytes": 63584878080, + "peak_reserved_bytes": 66527952896 + } + }, + "gate": { + 
"architectural_correctness": true, + "recall_cross_model": 0.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 100.0, + "recall_delta_within_5pp": false + } +} \ No newline at end of file diff --git a/results/research/k3_integrated_niah_ctx280_1781129939.json b/results/research/k3_integrated_niah_ctx280_1781129939.json new file mode 100644 index 00000000..86116807 --- /dev/null +++ b/results/research/k3_integrated_niah_ctx280_1781129939.json @@ -0,0 +1,341 @@ +{ + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v4a_warmstart_hybrid", + "f_theta_config": { + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + "verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 256, + "verifier_layer_kv_heads": [ + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] + }, + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 238, + "haystack_max_lines": 322, + "max_new_tokens": 24, + "seed": 42, + "skip_oracle": false, + "identity_restore": false, + "prompt_token_lens": [ + 6417, + 5418, + 6177, + 4857, + 6078, + 5298, + 5358, + 5198, + 5497, + 5057 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model", + "samples_total": 10, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 15.648320842999965, + "median_latency_s": 14.836793845519423, + "per_sample_decoded": [ + "The provided text does not contain a \"secret code.\" It appears to be a prompt 
designed to test the model's", + "The provided text does not contain a secret code. It appears to be a prompt designed to test the model's ability", + "The provided text does not contain a \"secret code.\" It appears to be a prompt designed to test the model's", + "I cannot provide the secret code as it was not provided in your prompt. However, based on the pattern of your message", + "The provided text does not contain a \"secret code.\" It appears to be a prompt designed to test the model's", + "The provided text does not contain a secret code. It appears to be a prompt designed to test the model's ability", + "The provided text does not contain a \"secret code.\" It appears to be a prompt designed to test the model's", + "The provided text does not contain a \"secret code.\" It appears to be a prompt designed to test the model's", + "I cannot provide the \"secret code\" as it was not provided in your prompt. However, based on the pattern of", + "I cannot answer that question because the text provided does not contain a \"secret code.\" It appears to be a prompt designed" + ], + "per_sample_correct": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.1516221634360644, + 1.6030021372427365, + 1.2686161525220545, + 1.9106466027469282, + 1.373647136200167, + 1.6545386903360852, + 1.632466491731646, + 1.7178282937774716, + 1.576567784657136, + 1.8019848055293328 + ], + "mean_throughput_tokens_per_sec": 1.5690920258179621, + "median_throughput_tokens_per_sec": 1.6177343144871914, + "min_throughput_tokens_per_sec": 1.1516221634360644, + "max_throughput_tokens_per_sec": 1.9106466027469282 + }, + "oracle": { + "name": "oracle", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 11.522616181732156, + "median_latency_s": 11.576587127055973, + 
"per_sample_decoded": [ + "BETA-1409\n\nthought\nBETA-1409\n", + "The secret code is DELTA-3286.thought\nThe secret code is DELTA-", + "BETA-7912\nDESCRIBE THE PROCESS OF CREATING A CUSTOMIZED LEARNING PATH FOR", + "BETA-4582\n\nThe secret code is BETA-4582.", + "KAPPA-1434\n\n\n\n", + "PINE-9928\n\n\n", + "THETA-5557\n\nThe secret code is THETA-5557", + "PINE-7924\n\nthought\nPINE-7924\n", + "GAMMA-4527\nDESCRIBE THE PROCESS OF CREATING A CUSTOMIZED LEARNING PATH FOR", + "BETA-7224\n\nThe secret code is BETA-7224" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 15, + 13, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.4682201908033004, + 1.985880099241351, + 1.567080677413877, + 2.3580853146306486, + 1.708136893685914, + 2.05324855215129, + 2.0220255590997196, + 2.126926447416288, + 1.9478445272701876, + 2.240006735978415 + ], + "mean_throughput_tokens_per_sec": 1.947745499769099, + "median_throughput_tokens_per_sec": 2.0039528291705353, + "min_throughput_tokens_per_sec": 1.4682201908033004, + "max_throughput_tokens_per_sec": 2.3580853146306486 + } + }, + "attention_window": { + "per_config": { + "k3_cross_model": { + "config": "v04_dlm_restored", + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)", + "samples_total": 10, + "effective_keys_at_last_query_mean": 5535.5, + "effective_keys_at_last_query_min": 4857, + "effective_keys_at_last_query_max": 6417, + "effective_keys_at_last_query_median": 5388.0, + "effective_attention_fraction_mean": 1.0, + "effective_attention_fraction_min": 1.0, + "effective_attention_fraction_max": 1.0, + "effective_attention_fraction_median": 1.0, + "per_sample": [ + { + "config": "v04_dlm_restored", + "seq_len": 6417, + "effective_keys_at_last_query": 6417, + "effective_attention_fraction": 1.0, + "structural_constraint": 
"causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5418, + "effective_keys_at_last_query": 5418, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 6177, + "effective_keys_at_last_query": 6177, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 4857, + "effective_keys_at_last_query": 4857, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 6078, + "effective_keys_at_last_query": 6078, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5298, + "effective_keys_at_last_query": 5298, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5358, + "effective_keys_at_last_query": 5358, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5198, + "effective_keys_at_last_query": 5198, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5497, + "effective_keys_at_last_query": 5497, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5057, + "effective_keys_at_last_query": 
5057, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + } + ] + } + } + }, + "memory": { + "k3_cross_model": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52631157248, + "current_reserved_bytes": 74801217536, + "peak_allocated_bytes": 69680971264, + "peak_reserved_bytes": 74801217536 + }, + "oracle": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52631157248, + "current_reserved_bytes": 66265808896, + "peak_allocated_bytes": 63333219840, + "peak_reserved_bytes": 66265808896 + } + }, + "gate": { + "architectural_correctness": true, + "recall_cross_model": 0.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 100.0, + "recall_delta_within_5pp": false + } +} \ No newline at end of file diff --git a/results/research/k3_integrated_niah_ctx280_1781130321.json b/results/research/k3_integrated_niah_ctx280_1781130321.json new file mode 100644 index 00000000..4cf5f8df --- /dev/null +++ b/results/research/k3_integrated_niah_ctx280_1781130321.json @@ -0,0 +1,341 @@ +{ + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v4b_fresh_hybrid", + "f_theta_config": { + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + "verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 768, + "verifier_layer_kv_heads": [ + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 
512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] + }, + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 238, + "haystack_max_lines": 322, + "max_new_tokens": 24, + "seed": 42, + "skip_oracle": false, + "identity_restore": false, + "prompt_token_lens": [ + 6417, + 5418, + 6177, + 4857, + 6078, + 5298, + 5358, + 5198, + 5497, + 5057 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model", + "samples_total": 10, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 16.003178110602313, + "median_latency_s": 15.24593824299518, + "per_sample_decoded": [ + "The secret code is not explicitly provided in your prompt, as the text provided appears to be a template or a placeholder.", + "The secret code is not explicitly provided in your prompt, but based on the structure of your question, it appears to be", + "The secret code is not explicitly provided in your prompt, but based on the structure of your question, it appears you are", + "The secret code is not provided in your prompt, as the text provided is a template/example.---", + "The secret code is 300.\nthought\nThe secret code is 300.", + "The secret code is not explicitly provided in your prompt, but based on the structure of your question, it appears to be", + "The secret code is not explicitly provided in your prompt, as the text provided appears to be a template or a test string", + "The secret code is not explicitly provided in your prompt, but based on the structure of your question, it appears to be", + "The secret code is not explicitly provided in your prompt, as the text provided appears to be a template or a test string", + "The secret code is not explicitly provided in your prompt, but based on the structure of your question, it appears to be" + ], + "per_sample_correct": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false + ], + "per_sample_decode_tokens": [ + 24, + 24, + 
24, + 24, + 24, + 24, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.142546750830522, + 1.5627995068619924, + 1.2432586993957289, + 1.8577431847919563, + 1.345858454907196, + 1.6119901317596121, + 1.5857472829139039, + 1.6758164168077347, + 1.539434995747024, + 1.754544401809637 + ], + "mean_throughput_tokens_per_sec": 1.5319739825825307, + "median_throughput_tokens_per_sec": 1.5742733948879482, + "min_throughput_tokens_per_sec": 1.142546750830522, + "max_throughput_tokens_per_sec": 1.8577431847919563 + }, + "oracle": { + "name": "oracle", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 11.527656969381496, + "median_latency_s": 11.573496360972058, + "per_sample_decoded": [ + "BETA-1409\n\nthought\nBETA-1409\n", + "The secret code is DELTA-3286.thought\nThe secret code is DELTA-", + "BETA-7912\nDESCRIBE THE PROCESS OF CREATING A CUSTOMIZED LEARNING PATH FOR", + "BETA-4582\n\nThe secret code is BETA-4582.", + "KAPPA-1434\n\n\n\n", + "PINE-9928\n\n\n", + "THETA-5557\n\nThe secret code is THETA-5557", + "PINE-7924\n\nthought\nPINE-7924\n", + "GAMMA-4527\nDESCRIBE THE PROCESS OF CREATING A CUSTOMIZED LEARNING PATH FOR", + "BETA-7224\n\nThe secret code is BETA-7224" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 15, + 13, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.4678708392282118, + 1.9847416727172429, + 1.5656035999751874, + 2.3576435084777563, + 1.707854165802729, + 2.0523756010690235, + 2.0230769675412787, + 2.126928892944776, + 1.9450597336744353, + 2.2384141119547327 + ], + "mean_throughput_tokens_per_sec": 1.9469569093385375, + "median_throughput_tokens_per_sec": 2.003909320129261, + "min_throughput_tokens_per_sec": 1.4678708392282118, + "max_throughput_tokens_per_sec": 2.3576435084777563 + } + }, + "attention_window": { + "per_config": { + 
"k3_cross_model": { + "config": "v04_dlm_restored", + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)", + "samples_total": 10, + "effective_keys_at_last_query_mean": 5535.5, + "effective_keys_at_last_query_min": 4857, + "effective_keys_at_last_query_max": 6417, + "effective_keys_at_last_query_median": 5388.0, + "effective_attention_fraction_mean": 1.0, + "effective_attention_fraction_min": 1.0, + "effective_attention_fraction_max": 1.0, + "effective_attention_fraction_median": 1.0, + "per_sample": [ + { + "config": "v04_dlm_restored", + "seq_len": 6417, + "effective_keys_at_last_query": 6417, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5418, + "effective_keys_at_last_query": 5418, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 6177, + "effective_keys_at_last_query": 6177, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 4857, + "effective_keys_at_last_query": 4857, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 6078, + "effective_keys_at_last_query": 6078, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5298, + "effective_keys_at_last_query": 5298, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5358, + 
"effective_keys_at_last_query": 5358, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5198, + "effective_keys_at_last_query": 5198, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5497, + "effective_keys_at_last_query": 5497, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5057, + "effective_keys_at_last_query": 5057, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + } + ] + } + } + }, + "memory": { + "k3_cross_model": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52882815488, + "current_reserved_bytes": 75063361536, + "peak_allocated_bytes": 69932629504, + "peak_reserved_bytes": 75063361536 + }, + "oracle": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52882815488, + "current_reserved_bytes": 66486009856, + "peak_allocated_bytes": 63584878080, + "peak_reserved_bytes": 66486009856 + } + }, + "gate": { + "architectural_correctness": true, + "recall_cross_model": 0.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 100.0, + "recall_delta_within_5pp": false + } +} \ No newline at end of file diff --git a/results/research/k3_integrated_niah_ctx70_1781062484.json b/results/research/k3_integrated_niah_ctx70_1781062484.json new file mode 100644 index 00000000..f015d905 --- /dev/null +++ b/results/research/k3_integrated_niah_ctx70_1781062484.json @@ -0,0 +1,340 @@ +{ + "schema_version": 2, + "kind": 
"k3_integrated_niah_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v1", + "f_theta_config": { + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + "verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 256, + "verifier_layer_kv_heads": [ + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] + }, + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "max_new_tokens": 24, + "seed": 42, + "skip_oracle": false, + "prompt_token_lens": [ + 1637, + 1378, + 1299, + 1597, + 1278, + 1617, + 1498, + 1238, + 1498, + 1358 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model", + "samples_total": 10, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 3.1370791353052483, + "median_latency_s": 3.055915425531566, + "per_sample_decoded": [ + "The answer is not provided in the text, but here is the requested sentence:\n\n**The quick brown fox jumps over", + "The following is a single sentence that contains the answer to your question: The secret code is \"The answer is hidden in", + "The student's request for a single sentence is met by the following:\n\n**The student's request for a", + "The answer is not provided in the text, but here is the requested sentence:\n\n**The quick brown fox jumps over", + "The following is a single sentence that contains the answer to your question: The secret code is \"The answer is hidden in", + "The secret code is not provided in the 
text, but the requested sentence is:\n\n**The quick brown fox jumps over", + "The answer is not provided in the text, but if you are looking for a summary of the text provided, it is", + "The following is a single sentence that contains the answer to your question: the secret code is \"the answer.\"", + "The answer is not provided in the text, but if you are looking for a summary of the text provided, it is", + "The student's request for a single sentence was met with a long, nonsensical preamble, but the answer is:" + ], + "per_sample_correct": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 6.336640326986822, + 7.958648931952716, + 8.317171521066571, + 7.048946141903309, + 8.439155600567629, + 7.1006956864268576, + 7.5198063027824205, + 8.665990006756195, + 7.751328013001251, + 8.010016953397578 + ], + "mean_throughput_tokens_per_sec": 7.714839948484135, + "median_throughput_tokens_per_sec": 7.854988472476983, + "min_throughput_tokens_per_sec": 6.336640326986822, + "max_throughput_tokens_per_sec": 8.665990006756195 + }, + "oracle": { + "name": "oracle", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 2.3139534484944306, + "median_latency_s": 2.2834689265000634, + "per_sample_decoded": [ + "BETA-1409\n//thought\nBETA-1409", + "DELTA-3286\n\nthought\nDELTA-3286\n", + "ORCHID-9935\n10. 
**Final Answer:** ORCHID-99", + "ETA-1520\n\n\nthought\nthought\nETA-15", + "DELTA-4811\n\n\u66f8\u304d\u8fbc\u307f\u5185\u5bb9\u306b\u57fa\u3065\u304f\u3068\u3001\u79d8\u5bc6\u306e\u30b3\u30fc\u30c9\u306f **DELTA-4", + "ALPHA-4257\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n", + "DELTA-8359\nnormative_thought\nDELTA-8359", + "MAPLE-3615\n\u092c\u52a9\u3051\u304c\u5fc5\u8981\u306a\u5834\u5408\u306f\u3001\u3044\u3064\u3067\u3082\u304a\u77e5\u3089\u305b\u304f\u3060\u3055\u3044\u3002\n", + "ZETA-5552\n\n\n", + "MAPLE-6514\n\ninclude_thought\nMAPLE-6514" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 16, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 8.9242195402743, + 10.455683388186033, + 10.906713299860224, + 9.114937460967408, + 10.989454909727511, + 9.047302167485174, + 9.738628508248345, + 11.361974195992733, + 9.777514311472999, + 10.56554024618647 + ], + "mean_throughput_tokens_per_sec": 10.08819680284012, + "median_throughput_tokens_per_sec": 10.116598849829515, + "min_throughput_tokens_per_sec": 8.9242195402743, + "max_throughput_tokens_per_sec": 11.361974195992733 + } + }, + "attention_window": { + "per_config": { + "k3_cross_model": { + "config": "v04_dlm_restored", + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)", + "samples_total": 10, + "effective_keys_at_last_query_mean": 1439.8, + "effective_keys_at_last_query_min": 1238, + "effective_keys_at_last_query_max": 1637, + "effective_keys_at_last_query_median": 1438.0, + "effective_attention_fraction_mean": 1.0, + "effective_attention_fraction_min": 1.0, + "effective_attention_fraction_max": 1.0, + "effective_attention_fraction_median": 1.0, + "per_sample": [ + { + "config": "v04_dlm_restored", + "seq_len": 1637, + "effective_keys_at_last_query": 1637, + "effective_attention_fraction": 1.0, + 
"structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1378, + "effective_keys_at_last_query": 1378, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1299, + "effective_keys_at_last_query": 1299, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1597, + "effective_keys_at_last_query": 1597, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1278, + "effective_keys_at_last_query": 1278, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1617, + "effective_keys_at_last_query": 1617, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1498, + "effective_keys_at_last_query": 1498, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1238, + "effective_keys_at_last_query": 1238, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1498, + "effective_keys_at_last_query": 1498, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1358, + 
"effective_keys_at_last_query": 1358, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + } + ] + } + } + }, + "memory": { + "k3_cross_model": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52630830080, + "current_reserved_bytes": 56975425536, + "peak_allocated_bytes": 55998552064, + "peak_reserved_bytes": 56975425536 + }, + "oracle": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52630830080, + "current_reserved_bytes": 56178507776, + "peak_allocated_bytes": 55250622464, + "peak_reserved_bytes": 56178507776 + } + }, + "gate": { + "architectural_correctness": true, + "recall_cross_model": 0.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 100.0, + "recall_delta_within_5pp": false + } +} \ No newline at end of file diff --git a/results/research/k3_integrated_niah_ctx70_1781076342.json b/results/research/k3_integrated_niah_ctx70_1781076342.json new file mode 100644 index 00000000..79f51b22 --- /dev/null +++ b/results/research/k3_integrated_niah_ctx70_1781076342.json @@ -0,0 +1,341 @@ +{ + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v3_attn_distill", + "f_theta_config": { + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + "verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 768, + "verifier_layer_kv_heads": [ + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 
256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] + }, + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "max_new_tokens": 24, + "seed": 42, + "skip_oracle": false, + "identity_restore": false, + "prompt_token_lens": [ + 1637, + 1378, + 1299, + 1597, + 1278, + 1617, + 1498, + 1238, + 1498, + 1358 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model", + "samples_total": 10, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 2.517939622409176, + "median_latency_s": 2.6519388455199078, + "per_sample_decoded": [ + "There is no secret code provided in the text.", + "The provided text does not contain a secret code.", + "There is no secret code provided in the text.\nthought\nThere is no secret code provided in the", + "The provided text does not contain a secret code.", + "There is no secret code provided in the text.", + "The provided text does not contain a secret code.\n\n", + "There is no secret code provided in the text.", + "The provided text does not contain a secret code.\n\n[instruction]\nThe user wants me to find", + "The provided text does not contain a secret code.\nthought\nThe provided text does not contain", + "I do not know the secret code.\nthought\nI do not know the secret code." 
+ ], + "per_sample_correct": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false + ], + "per_sample_decode_tokens": [ + 14, + 14, + 24, + 14, + 14, + 20, + 14, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 5.736363681116686, + 7.716042599545884, + 7.9512812752174655, + 6.722051725289905, + 7.934755345152868, + 6.78769157640746, + 7.210029225580785, + 8.381915886677966, + 7.411872786128763, + 7.819242974306768 + ], + "mean_throughput_tokens_per_sec": 7.367124707542454, + "median_throughput_tokens_per_sec": 7.563957692837324, + "min_throughput_tokens_per_sec": 5.736363681116686, + "max_throughput_tokens_per_sec": 8.381915886677966 + }, + "oracle": { + "name": "oracle", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 2.3246362071717157, + "median_latency_s": 2.303811051941011, + "per_sample_decoded": [ + "BETA-1409\n//thought\nBETA-1409", + "DELTA-3286\n\nthought\nDELTA-3286\n", + "ORCHID-9935\n10. 
**Final Answer:** ORCHID-99", + "ETA-1520\n\n\nthought\nthought\nETA-15", + "DELTA-4811\n\n\u66f8\u304d\u8fbc\u307f\u5185\u5bb9\u306b\u57fa\u3065\u304f\u3068\u3001\u79d8\u5bc6\u306e\u30b3\u30fc\u30c9\u306f **DELTA-4", + "ALPHA-4257\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n", + "DELTA-8359\nnormative_thought\nDELTA-8359", + "MAPLE-3615\n\u092c\u52a9\u3051\u304c\u5fc5\u8981\u306a\u5834\u5408\u306f\u3001\u3044\u3064\u3067\u3082\u304a\u77e5\u3089\u305b\u304f\u3060\u3055\u3044\u3002\n", + "ZETA-5552\n\n\n", + "MAPLE-6514\n\ninclude_thought\nMAPLE-6514" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 16, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 8.844667197447956, + 10.288136253081014, + 10.902090637350318, + 9.067728829991385, + 10.895934033945267, + 9.046133853405836, + 9.73770945624712, + 11.333942908827085, + 9.760995855287959, + 10.550201496616841 + ], + "mean_throughput_tokens_per_sec": 10.04275405222008, + "median_throughput_tokens_per_sec": 10.024566054184486, + "min_throughput_tokens_per_sec": 8.844667197447956, + "max_throughput_tokens_per_sec": 11.333942908827085 + } + }, + "attention_window": { + "per_config": { + "k3_cross_model": { + "config": "v04_dlm_restored", + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)", + "samples_total": 10, + "effective_keys_at_last_query_mean": 1439.8, + "effective_keys_at_last_query_min": 1238, + "effective_keys_at_last_query_max": 1637, + "effective_keys_at_last_query_median": 1438.0, + "effective_attention_fraction_mean": 1.0, + "effective_attention_fraction_min": 1.0, + "effective_attention_fraction_max": 1.0, + "effective_attention_fraction_median": 1.0, + "per_sample": [ + { + "config": "v04_dlm_restored", + "seq_len": 1637, + "effective_keys_at_last_query": 1637, + "effective_attention_fraction": 1.0, + 
"structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1378, + "effective_keys_at_last_query": 1378, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1299, + "effective_keys_at_last_query": 1299, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1597, + "effective_keys_at_last_query": 1597, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1278, + "effective_keys_at_last_query": 1278, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1617, + "effective_keys_at_last_query": 1617, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1498, + "effective_keys_at_last_query": 1498, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1238, + "effective_keys_at_last_query": 1238, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1498, + "effective_keys_at_last_query": 1498, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1358, + 
"effective_keys_at_last_query": 1358, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + } + ] + } + } + }, + "memory": { + "k3_cross_model": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52882488320, + "current_reserved_bytes": 57216598016, + "peak_allocated_bytes": 56229919744, + "peak_reserved_bytes": 57216598016 + }, + "oracle": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52882488320, + "current_reserved_bytes": 56440651776, + "peak_allocated_bytes": 55502280704, + "peak_reserved_bytes": 56440651776 + } + }, + "gate": { + "architectural_correctness": true, + "recall_cross_model": 0.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 100.0, + "recall_delta_within_5pp": false + } +} \ No newline at end of file diff --git a/results/research/k3_integrated_niah_ctx70_1781129939.json b/results/research/k3_integrated_niah_ctx70_1781129939.json new file mode 100644 index 00000000..282cfdb5 --- /dev/null +++ b/results/research/k3_integrated_niah_ctx70_1781129939.json @@ -0,0 +1,341 @@ +{ + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v4a_warmstart_hybrid", + "f_theta_config": { + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + "verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 256, + "verifier_layer_kv_heads": [ + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 
256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] + }, + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "max_new_tokens": 24, + "seed": 42, + "skip_oracle": false, + "identity_restore": false, + "prompt_token_lens": [ + 1637, + 1378, + 1299, + 1597, + 1278, + 1617, + 1498, + 1238, + 1498, + 1358 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model", + "samples_total": 10, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 3.3142378951073623, + "median_latency_s": 3.2595019129803404, + "per_sample_decoded": [ + "The provided text does not contain a secret code; it appears to be a prompt containing instructional text and a placeholder question.", + "The provided text does not contain a secret code. It appears to be a prompt containing instructions, a pattern of numbers,", + "The provided text does not contain a \"secret code.\" It appears to be a prompt containing a series of instructions, a", + "The provided text does not contain a secret code; it appears to be a prompt containing various instructions and examples.", + "The provided text does not contain a secret code; it appears to be a prompt containing instructional text and a placeholder question.", + "The provided text does not contain a secret code; it appears to be a prompt containing instructional text and a placeholder question.", + "The provided text does not contain a secret code; it appears to be a prompt containing instructional text and a placeholder question.", + "The provided text does not contain a secret code. It appears to be a prompt containing instructional text and a placeholder question.", + "The provided text does not contain a secret code. It appears to be a prompt containing instructional text and a placeholder question.", + "The provided text does not contain a secret code. 
It appears to be a prompt containing instructional text and a placeholder question." + ], + "per_sample_correct": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 6.158678376134239, + 7.10090388473375, + 7.424078257868762, + 6.236516091894493, + 8.124545174588512, + 6.915023123806111, + 7.303092019892089, + 8.403624325722832, + 7.616933818115398, + 7.825795769741083 + ], + "mean_throughput_tokens_per_sec": 7.310919084249727, + "median_throughput_tokens_per_sec": 7.363585138880426, + "min_throughput_tokens_per_sec": 6.158678376134239, + "max_throughput_tokens_per_sec": 8.403624325722832 + }, + "oracle": { + "name": "oracle", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 2.319360340514686, + "median_latency_s": 2.291858280019369, + "per_sample_decoded": [ + "BETA-1409\n//thought\nBETA-1409", + "DELTA-3286\n\nthought\nDELTA-3286\n", + "ORCHID-9935\n10. 
**Final Answer:** ORCHID-99", + "ETA-1520\n\n\nthought\nthought\nETA-15", + "DELTA-4811\n\n\u66f8\u304d\u8fbc\u307f\u5185\u5bb9\u306b\u57fa\u3065\u304f\u3068\u3001\u79d8\u5bc6\u306e\u30b3\u30fc\u30c9\u306f **DELTA-4", + "ALPHA-4257\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n", + "DELTA-8359\nnormative_thought\nDELTA-8359", + "MAPLE-3615\n\u092c\u52a9\u3051\u304c\u5fc5\u8981\u306a\u5834\u5408\u306f\u3001\u3044\u3064\u3067\u3082\u304a\u77e5\u3089\u305b\u304f\u3060\u3055\u3044\u3002\n", + "ZETA-5552\n\n\n", + "MAPLE-6514\n\ninclude_thought\nMAPLE-6514" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 16, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 8.924847071158883, + 10.420643598378328, + 10.88049672145145, + 9.116012437745775, + 10.984900257247734, + 9.014853564479848, + 9.707780285239126, + 11.31113673466426, + 9.749458559937297, + 10.523565547932998 + ], + "mean_throughput_tokens_per_sec": 10.06336947782357, + "median_throughput_tokens_per_sec": 10.085051079157813, + "min_throughput_tokens_per_sec": 8.924847071158883, + "max_throughput_tokens_per_sec": 11.31113673466426 + } + }, + "attention_window": { + "per_config": { + "k3_cross_model": { + "config": "v04_dlm_restored", + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)", + "samples_total": 10, + "effective_keys_at_last_query_mean": 1439.8, + "effective_keys_at_last_query_min": 1238, + "effective_keys_at_last_query_max": 1637, + "effective_keys_at_last_query_median": 1438.0, + "effective_attention_fraction_mean": 1.0, + "effective_attention_fraction_min": 1.0, + "effective_attention_fraction_max": 1.0, + "effective_attention_fraction_median": 1.0, + "per_sample": [ + { + "config": "v04_dlm_restored", + "seq_len": 1637, + "effective_keys_at_last_query": 1637, + "effective_attention_fraction": 1.0, + 
"structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1378, + "effective_keys_at_last_query": 1378, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1299, + "effective_keys_at_last_query": 1299, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1597, + "effective_keys_at_last_query": 1597, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1278, + "effective_keys_at_last_query": 1278, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1617, + "effective_keys_at_last_query": 1617, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1498, + "effective_keys_at_last_query": 1498, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1238, + "effective_keys_at_last_query": 1238, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1498, + "effective_keys_at_last_query": 1498, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1358, + 
"effective_keys_at_last_query": 1358, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + } + ] + } + } + }, + "memory": { + "k3_cross_model": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52630830080, + "current_reserved_bytes": 56975425536, + "peak_allocated_bytes": 55998552064, + "peak_reserved_bytes": 56975425536 + }, + "oracle": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52630830080, + "current_reserved_bytes": 56178507776, + "peak_allocated_bytes": 55250622464, + "peak_reserved_bytes": 56178507776 + } + }, + "gate": { + "architectural_correctness": true, + "recall_cross_model": 0.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 100.0, + "recall_delta_within_5pp": false + } +} \ No newline at end of file diff --git a/results/research/k3_integrated_niah_ctx70_1781130321.json b/results/research/k3_integrated_niah_ctx70_1781130321.json new file mode 100644 index 00000000..0cbf28d7 --- /dev/null +++ b/results/research/k3_integrated_niah_ctx70_1781130321.json @@ -0,0 +1,341 @@ +{ + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v4b_fresh_hybrid", + "f_theta_config": { + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + "verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 768, + "verifier_layer_kv_heads": [ + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, 
+ 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] + }, + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "max_new_tokens": 24, + "seed": 42, + "skip_oracle": false, + "identity_restore": false, + "prompt_token_lens": [ + 1637, + 1378, + 1299, + 1597, + 1278, + 1617, + 1498, + 1238, + 1498, + 1358 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model", + "samples_total": 10, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 3.3216539990971796, + "median_latency_s": 3.226994057011325, + "per_sample_decoded": [ + "The secret code is not explicitly provided in your text, but based on the context of your prompt, it appears you are", + "The provided text does not contain a \"secret code.\" It appears to be a prompt containing placeholder text and instructions for an", + "The provided text does not contain a \"secret code.\" It appears to be a template or a placeholder text used for instructional", + "The provided text does not contain a \"secret code.\" It appears to be a prompt containing instructions and a placeholder question.", + "The provided text does not contain a \"secret code.\" It appears to be a prompt containing instructional text and placeholder content.", + "The secret code is not explicitly provided in your prompt, but based on the text provided, the \"answer\" or \"", + "The secret code is not explicitly provided in your text, but based on the context of your prompt, it appears you are", + "The secret code is **42** (based on the context of the prompt's structure, though the prompt itself", + "The provided text does not contain a \"secret code.\" It appears to be a prompt containing instructional text and placeholder content.", + "The provided text does not contain a \"secret code.\" It appears to be a prompt containing instructions and a placeholder question." 
+ ], + "per_sample_correct": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 5.967912739780349, + 7.542007157797493, + 7.774064418142494, + 6.6451698412390545, + 7.9202519990829945, + 6.711000809456378, + 7.095308937422936, + 8.183627189414267, + 7.3353862735672095, + 7.690864422644549 + ], + "mean_throughput_tokens_per_sec": 7.286559378854773, + "median_throughput_tokens_per_sec": 7.438696715682351, + "min_throughput_tokens_per_sec": 5.967912739780349, + "max_throughput_tokens_per_sec": 8.183627189414267 + }, + "oracle": { + "name": "oracle", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 2.32026572660543, + "median_latency_s": 2.2918679150170647, + "per_sample_decoded": [ + "BETA-1409\n//thought\nBETA-1409", + "DELTA-3286\n\nthought\nDELTA-3286\n", + "ORCHID-9935\n10. 
**Final Answer:** ORCHID-99", + "ETA-1520\n\n\nthought\nthought\nETA-15", + "DELTA-4811\n\n\u66f8\u304d\u8fbc\u307f\u5185\u5bb9\u306b\u57fa\u3065\u304f\u3068\u3001\u79d8\u5bc6\u306e\u30b3\u30fc\u30c9\u306f **DELTA-4", + "ALPHA-4257\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n", + "DELTA-8359\nnormative_thought\nDELTA-8359", + "MAPLE-3615\n\u092c\u52a9\u3051\u304c\u5fc5\u8981\u306a\u5834\u5408\u306f\u3001\u3044\u3064\u3067\u3082\u304a\u77e5\u3089\u305b\u304f\u3060\u3055\u3044\u3002\n", + "ZETA-5552\n\n\n", + "MAPLE-6514\n\ninclude_thought\nMAPLE-6514" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 16, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 8.913799988953627, + 10.414824133191882, + 10.894933512247459, + 9.097945921139678, + 11.003153865703663, + 9.014252302265898, + 9.699037559338668, + 11.289935392668216, + 9.745713526999928, + 10.529418185512096 + ], + "mean_throughput_tokens_per_sec": 10.06030143880211, + "median_throughput_tokens_per_sec": 10.080268830095905, + "min_throughput_tokens_per_sec": 8.913799988953627, + "max_throughput_tokens_per_sec": 11.289935392668216 + } + }, + "attention_window": { + "per_config": { + "k3_cross_model": { + "config": "v04_dlm_restored", + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)", + "samples_total": 10, + "effective_keys_at_last_query_mean": 1439.8, + "effective_keys_at_last_query_min": 1238, + "effective_keys_at_last_query_max": 1637, + "effective_keys_at_last_query_median": 1438.0, + "effective_attention_fraction_mean": 1.0, + "effective_attention_fraction_min": 1.0, + "effective_attention_fraction_max": 1.0, + "effective_attention_fraction_median": 1.0, + "per_sample": [ + { + "config": "v04_dlm_restored", + "seq_len": 1637, + "effective_keys_at_last_query": 1637, + "effective_attention_fraction": 1.0, + 
"structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1378, + "effective_keys_at_last_query": 1378, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1299, + "effective_keys_at_last_query": 1299, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1597, + "effective_keys_at_last_query": 1597, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1278, + "effective_keys_at_last_query": 1278, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1617, + "effective_keys_at_last_query": 1617, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1498, + "effective_keys_at_last_query": 1498, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1238, + "effective_keys_at_last_query": 1238, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1498, + "effective_keys_at_last_query": 1498, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1358, + 
"effective_keys_at_last_query": 1358, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + } + ] + } + } + }, + "memory": { + "k3_cross_model": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52882488320, + "current_reserved_bytes": 57237569536, + "peak_allocated_bytes": 56250210304, + "peak_reserved_bytes": 57237569536 + }, + "oracle": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52882488320, + "current_reserved_bytes": 56440651776, + "peak_allocated_bytes": 55502280704, + "peak_reserved_bytes": 56440651776 + } + }, + "gate": { + "architectural_correctness": true, + "recall_cross_model": 0.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 100.0, + "recall_delta_within_5pp": false + } +} \ No newline at end of file diff --git a/results/research/k3_integrated_niah_v4a.json b/results/research/k3_integrated_niah_v4a.json new file mode 100644 index 00000000..ff964fec --- /dev/null +++ b/results/research/k3_integrated_niah_v4a.json @@ -0,0 +1,341 @@ +{ + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v4a_warmstart_hybrid", + "f_theta_config": { + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + "verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 256, + "verifier_layer_kv_heads": [ + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, 
+ 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] + }, + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 80, + "max_new_tokens": 24, + "seed": 42, + "skip_oracle": false, + "identity_restore": false, + "prompt_token_lens": [ + 1637, + 1378, + 1299, + 1597, + 1278, + 1617, + 1498, + 1238, + 1498, + 1358 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model", + "samples_total": 10, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 3.27170688129263, + "median_latency_s": 3.1429821790661663, + "per_sample_decoded": [ + "The provided text does not contain a secret code; it appears to be a prompt containing instructional text and a placeholder question.", + "The provided text does not contain a secret code. It appears to be a prompt containing instructions, a pattern of numbers,", + "The provided text does not contain a \"secret code.\" It appears to be a prompt containing a series of instructions, a", + "The provided text does not contain a secret code; it appears to be a prompt containing various instructions and examples.", + "The provided text does not contain a secret code; it appears to be a prompt containing instructional text and a placeholder question.", + "The provided text does not contain a secret code; it appears to be a prompt containing instructional text and a placeholder question.", + "The provided text does not contain a secret code; it appears to be a prompt containing instructional text and a placeholder question.", + "The provided text does not contain a secret code. It appears to be a prompt containing instructional text and a placeholder question.", + "The provided text does not contain a secret code. It appears to be a prompt containing instructional text and a placeholder question.", + "The provided text does not contain a secret code. It appears to be a prompt containing instructional text and a placeholder question." 
+ ], + "per_sample_correct": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 5.534455487584129, + 7.657139320000613, + 8.03883354693255, + 6.856273232887896, + 8.164993140820844, + 6.914881927705856, + 7.312892354600498, + 8.426305691157888, + 7.615096104459873, + 7.821586035301249 + ], + "mean_throughput_tokens_per_sec": 7.434245684145139, + "median_throughput_tokens_per_sec": 7.636117712230243, + "min_throughput_tokens_per_sec": 5.534455487584129, + "max_throughput_tokens_per_sec": 8.426305691157888 + }, + "oracle": { + "name": "oracle", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 2.3213504247833043, + "median_latency_s": 2.292157870484516, + "per_sample_decoded": [ + "BETA-1409\n//thought\nBETA-1409", + "DELTA-3286\n\nthought\nDELTA-3286\n", + "ORCHID-9935\n10. 
**Final Answer:** ORCHID-99", + "ETA-1520\n\n\nthought\nthought\nETA-15", + "DELTA-4811\n\n\u66f8\u304d\u8fbc\u307f\u5185\u5bb9\u306b\u57fa\u3065\u304f\u3068\u3001\u79d8\u5bc6\u306e\u30b3\u30fc\u30c9\u306f **DELTA-4", + "ALPHA-4257\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n", + "DELTA-8359\nnormative_thought\nDELTA-8359", + "MAPLE-3615\n\u092c\u52a9\u3051\u304c\u5fc5\u8981\u306a\u5834\u5408\u306f\u3001\u3044\u3064\u3067\u3082\u304a\u77e5\u3089\u305b\u304f\u3060\u3055\u3044\u3002\n", + "ZETA-5552\n\n\n", + "MAPLE-6514\n\ninclude_thought\nMAPLE-6514" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 16, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 8.909067309201422, + 10.422277058980027, + 10.866733832828013, + 9.108979985715147, + 10.98298312150047, + 9.01360558246596, + 9.692341809489132, + 11.29396081271051, + 9.73879046002474, + 10.519136930745425 + ], + "mean_throughput_tokens_per_sec": 10.054787690366085, + "median_throughput_tokens_per_sec": 10.080533759502384, + "min_throughput_tokens_per_sec": 8.909067309201422, + "max_throughput_tokens_per_sec": 11.29396081271051 + } + }, + "attention_window": { + "per_config": { + "k3_cross_model": { + "config": "v04_dlm_restored", + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)", + "samples_total": 10, + "effective_keys_at_last_query_mean": 1439.8, + "effective_keys_at_last_query_min": 1238, + "effective_keys_at_last_query_max": 1637, + "effective_keys_at_last_query_median": 1438.0, + "effective_attention_fraction_mean": 1.0, + "effective_attention_fraction_min": 1.0, + "effective_attention_fraction_max": 1.0, + "effective_attention_fraction_median": 1.0, + "per_sample": [ + { + "config": "v04_dlm_restored", + "seq_len": 1637, + "effective_keys_at_last_query": 1637, + "effective_attention_fraction": 1.0, + 
"structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1378, + "effective_keys_at_last_query": 1378, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1299, + "effective_keys_at_last_query": 1299, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1597, + "effective_keys_at_last_query": 1597, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1278, + "effective_keys_at_last_query": 1278, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1617, + "effective_keys_at_last_query": 1617, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1498, + "effective_keys_at_last_query": 1498, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1238, + "effective_keys_at_last_query": 1238, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1498, + "effective_keys_at_last_query": 1498, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1358, + 
"effective_keys_at_last_query": 1358, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + } + ] + } + } + }, + "memory": { + "k3_cross_model": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52630830080, + "current_reserved_bytes": 56975425536, + "peak_allocated_bytes": 55998552064, + "peak_reserved_bytes": 56975425536 + }, + "oracle": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52630830080, + "current_reserved_bytes": 56178507776, + "peak_allocated_bytes": 55250622464, + "peak_reserved_bytes": 56178507776 + } + }, + "gate": { + "architectural_correctness": true, + "recall_cross_model": 0.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 100.0, + "recall_delta_within_5pp": false + } +} \ No newline at end of file diff --git a/results/research/k3_integrated_niah_v4b.json b/results/research/k3_integrated_niah_v4b.json new file mode 100644 index 00000000..8fd75ce1 --- /dev/null +++ b/results/research/k3_integrated_niah_v4b.json @@ -0,0 +1,341 @@ +{ + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v4b_fresh_hybrid", + "f_theta_config": { + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + "verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 768, + "verifier_layer_kv_heads": [ + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 
256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] + }, + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 80, + "max_new_tokens": 24, + "seed": 42, + "skip_oracle": false, + "identity_restore": false, + "prompt_token_lens": [ + 1637, + 1378, + 1299, + 1597, + 1278, + 1617, + 1498, + 1238, + 1498, + 1358 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model", + "samples_total": 10, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 3.3272163896705025, + "median_latency_s": 3.2052694669691846, + "per_sample_decoded": [ + "The secret code is not explicitly provided in your text, but based on the context of your prompt, it appears you are", + "The provided text does not contain a \"secret code.\" It appears to be a prompt containing placeholder text and instructions for an", + "The provided text does not contain a \"secret code.\" It appears to be a template or a placeholder text used for instructional", + "The provided text does not contain a \"secret code.\" It appears to be a prompt containing instructions and a placeholder question.", + "The provided text does not contain a \"secret code.\" It appears to be a prompt containing instructional text and placeholder content.", + "The secret code is not explicitly provided in your prompt, but based on the text provided, the \"answer\" or \"", + "The secret code is not explicitly provided in your text, but based on the context of your prompt, it appears you are", + "The secret code is **42** (based on the context of the prompt's structure, though the prompt itself", + "The provided text does not contain a \"secret code.\" It appears to be a prompt containing instructional text and placeholder content.", + "The provided text does not contain a \"secret code.\" It appears to be a prompt containing instructions and a placeholder question." 
+ ], + "per_sample_correct": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 5.8353472978979894, + 7.63703747067071, + 7.831732652604615, + 6.6790883907465135, + 7.898515115029887, + 6.688526857201077, + 7.110387029376355, + 8.14582362380606, + 7.361297575366125, + 7.618457099251185 + ], + "mean_throughput_tokens_per_sec": 7.280621311195051, + "median_throughput_tokens_per_sec": 7.489877337308656, + "min_throughput_tokens_per_sec": 5.8353472978979894, + "max_throughput_tokens_per_sec": 8.14582362380606 + }, + "oracle": { + "name": "oracle", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 2.3210248557967135, + "median_latency_s": 2.2930795179563574, + "per_sample_decoded": [ + "BETA-1409\n//thought\nBETA-1409", + "DELTA-3286\n\nthought\nDELTA-3286\n", + "ORCHID-9935\n10. 
**Final Answer:** ORCHID-99", + "ETA-1520\n\n\nthought\nthought\nETA-15", + "DELTA-4811\n\n\u66f8\u304d\u8fbc\u307f\u5185\u5bb9\u306b\u57fa\u3065\u304f\u3068\u3001\u79d8\u5bc6\u306e\u30b3\u30fc\u30c9\u306f **DELTA-4", + "ALPHA-4257\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n \n\n\n", + "DELTA-8359\nnormative_thought\nDELTA-8359", + "MAPLE-3615\n\u092c\u52a9\u3051\u304c\u5fc5\u8981\u306a\u5834\u5408\u306f\u3001\u3044\u3064\u3067\u3082\u304a\u77e5\u3089\u305b\u304f\u3060\u3055\u3044\u3002\n", + "ZETA-5552\n\n\n", + "MAPLE-6514\n\ninclude_thought\nMAPLE-6514" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 16, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 8.915089090853835, + 10.413994555253957, + 10.868327694836768, + 9.091441563797241, + 11.002728292634039, + 9.007467950750152, + 9.700031834310884, + 11.319432302889112, + 9.733968493712515, + 10.5190822780884 + ], + "mean_throughput_tokens_per_sec": 10.057156405712691, + "median_throughput_tokens_per_sec": 10.073981524483237, + "min_throughput_tokens_per_sec": 8.915089090853835, + "max_throughput_tokens_per_sec": 11.319432302889112 + } + }, + "attention_window": { + "per_config": { + "k3_cross_model": { + "config": "v04_dlm_restored", + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)", + "samples_total": 10, + "effective_keys_at_last_query_mean": 1439.8, + "effective_keys_at_last_query_min": 1238, + "effective_keys_at_last_query_max": 1637, + "effective_keys_at_last_query_median": 1438.0, + "effective_attention_fraction_mean": 1.0, + "effective_attention_fraction_min": 1.0, + "effective_attention_fraction_max": 1.0, + "effective_attention_fraction_median": 1.0, + "per_sample": [ + { + "config": "v04_dlm_restored", + "seq_len": 1637, + "effective_keys_at_last_query": 1637, + "effective_attention_fraction": 1.0, + 
"structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1378, + "effective_keys_at_last_query": 1378, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1299, + "effective_keys_at_last_query": 1299, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1597, + "effective_keys_at_last_query": 1597, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1278, + "effective_keys_at_last_query": 1278, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1617, + "effective_keys_at_last_query": 1617, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1498, + "effective_keys_at_last_query": 1498, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1238, + "effective_keys_at_last_query": 1238, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1498, + "effective_keys_at_last_query": 1498, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 1358, + 
"effective_keys_at_last_query": 1358, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + } + ] + } + } + }, + "memory": { + "k3_cross_model": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52882488320, + "current_reserved_bytes": 57237569536, + "peak_allocated_bytes": 56250210304, + "peak_reserved_bytes": 57237569536 + }, + "oracle": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52882488320, + "current_reserved_bytes": 56440651776, + "peak_allocated_bytes": 55502280704, + "peak_reserved_bytes": 56440651776 + } + }, + "gate": { + "architectural_correctness": true, + "recall_cross_model": 0.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 100.0, + "recall_delta_within_5pp": false + } +} \ No newline at end of file diff --git a/results/research/k3_s5_kl_niah_ctx70_mac.json b/results/research/k3_s5_kl_niah_ctx70_mac.json new file mode 100644 index 00000000..2c103a26 --- /dev/null +++ b/results/research/k3_s5_kl_niah_ctx70_mac.json @@ -0,0 +1,509 @@ +{ + "schema_version": 1, + "kind": "k3_integrated_niah_acceptance_mac", + "config": { + "verifier_path": "models/gemma-4-26B-A4B-it-mlx-4bit", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "max_new_tokens": 12, + "seed": 42, + "eval_mode": "teacher_forced", + "free_generation": false, + "s5_exact_full_attn": true, + "identity_restore": false, + "compress_full_attn": true, + "kl_lattice": "D4", + "kl_q_range": 38, + "kl_bits_per_token_per_head": 3232.0, + "full_attention_layers": [ + 5, + 11, + 17, + 23, + 29 + ], + "prompt_token_lens": [ + 1639, + 1380, + 1301, + 1599, + 1280, + 1619, + 1500, + 1240, + 
1500, + 1360 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model_mac", + "samples_total": 10, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 70.89755283326376, + "median_latency_s": 68.05938027054071, + "per_sample_decoded": [ + "<|channel>ETA-1409", + "<|channel>-3286", + "<|channel>CHID-9935", + "<|channel>-1520", + "<|channel>-4811", + "<|channel>-4257", + "<|channel>-8359", + "<|channel>LE-3615", + "<|channel>ETA-5552", + "<|channel>LE-6514" + ], + "per_sample_correct": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false + ], + "per_sample_decode_tokens": [ + 7, + 6, + 8, + 6, + 6, + 6, + 6, + 7, + 7, + 7 + ], + "per_sample_throughput_tokens_per_sec": [ + 0.08387508852115706, + 0.06706377930919796, + 0.12686183519569158, + 0.09299567815023022, + 0.11067840526101776, + 0.10049817440520767, + 0.14145675424915535, + 0.0871139344999729, + 0.06986899281441022, + 0.09776586396279 + ], + "mean_throughput_tokens_per_sec": 0.09781785063688307, + "median_throughput_tokens_per_sec": 0.0953807710565101, + "min_throughput_tokens_per_sec": 0.06706377930919796, + "max_throughput_tokens_per_sec": 0.14145675424915535 + }, + "oracle": { + "name": "oracle_mac", + "samples_total": 10, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 26.936620991583915, + "median_latency_s": 20.539821645943448, + "per_sample_decoded": [ + "<|channel>ETA-1409", + "<|channel>-3286", + "<|channel>CHID-9935", + "<|channel>-1520", + "<|channel>-4811", + "<|channel>-4257", + "<|channel>-8359", + "<|channel>LE-3615", + "<|channel>ETA-5552", + "<|channel>LE-6514" + ], + "per_sample_correct": [ + false, + false, + false, + false, + false, + false, + false, + false, + false, + false + ], + "per_sample_decode_tokens": [ + 7, + 6, + 8, + 6, + 6, + 6, + 6, + 7, + 7, + 7 + ], + "per_sample_throughput_tokens_per_sec": [ + 0.39918706125231773, + 0.10825221220240881, + 0.13477552930321893, + 0.49229455940587685, + 
0.16196820378112536, + 0.23043185194980292, + 0.26949747058883566, + 0.6550377528729731, + 0.6992835344410542, + 0.372024135759359 + ], + "mean_throughput_tokens_per_sec": 0.35227523115569725, + "median_throughput_tokens_per_sec": 0.32076080317409733, + "min_throughput_tokens_per_sec": 0.10825221220240881, + "max_throughput_tokens_per_sec": 0.6992835344410542 + } + }, + "gate": { + "recall_cross_model": 0.0, + "recall_oracle": 0.0, + "recall_delta_vs_oracle_pp": 0.0, + "recall_delta_within_5pp": true + }, + "memory": { + "s5": { + "seq_len": 1639, + "kv_dtype_bytes": 2, + "sink_window": 68, + "exact_layer_indices": [ + 5, + 11, + 17, + 23, + 29 + ], + "compress_full_bits_per_token_per_head": 3232.0, + "sliding_resident_bytes": 13926400, + "full_resident_bytes": 13243120, + "total_resident_bytes": 27169520, + "total_resident_mb": 27.17, + "per_token_growth_bytes": 8080, + "per_token_growth_kb": 7.89, + "per_layer": [ + { + "layer": 0, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 1, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 2, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 3, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 4, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 5, + "layer_type": "full_attention", + "n_kv_heads": 2, + 
"head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 1616, + "resident_bytes": 2648624 + }, + { + "layer": 6, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 7, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 8, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 9, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 10, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 11, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 1616, + "resident_bytes": 2648624 + }, + { + "layer": 12, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 13, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 14, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 15, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 
256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 16, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 17, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 1616, + "resident_bytes": 2648624 + }, + { + "layer": 18, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 19, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 20, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 21, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 22, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 23, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 1616, + "resident_bytes": 2648624 + }, + { + "layer": 24, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 25, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + 
"exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 26, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 27, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 28, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 29, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1639, + "bytes_per_token": 1616, + "resident_bytes": 2648624 + } + ] + }, + "naive_full_kv": { + "total_resident_mb": 369.23, + "per_token_growth_kb": 220.0 + }, + "savings_vs_naive_pct": 92.6 + }, + "throughput": { + "k3_cross_model": { + "tokens": 66, + "wall_seconds": 708.976, + "tokens_per_second": 0.0931, + "mean_latency_per_sample_s": 70.898, + "eval_mode": "teacher_forced", + "restored_forwards_per_sample": 1 + } + } +} \ No newline at end of file diff --git a/results/research/k3_s5_niah_ctx280_v4b.json b/results/research/k3_s5_niah_ctx280_v4b.json new file mode 100644 index 00000000..cceffee5 --- /dev/null +++ b/results/research/k3_s5_niah_ctx280_v4b.json @@ -0,0 +1,349 @@ +{ + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v4b_fresh_hybrid", + "f_theta_config": { + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + "verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 768, + "verifier_layer_kv_heads": [ 
+ 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] + }, + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 238, + "haystack_max_lines": 322, + "max_new_tokens": 24, + "seed": 42, + "skip_oracle": false, + "identity_restore": false, + "s5_exact_full_attn": true, + "s5_exact_layers": [ + 5, + 11, + 17, + 23, + 29 + ], + "prompt_token_lens": [ + 6417, + 5418, + 6177, + 4857, + 6078, + 5298, + 5358, + 5198, + 5497, + 5057 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 27.283239149395378, + "median_latency_s": 26.6922971814638, + "per_sample_decoded": [ + "The secret code is BETA-1409.\nthought\nThe secret code is BETA-1", + "The secret code is **DELTA-3286**.\nthought\nThe secret code is **DELTA", + "The secret code is BETA-7912.", + "The secret code is **BETA-4582**.\nthought\nThe secret code is **", + "The secret code is KAPPA-1434.\nthought\nThe secret code is KAPPA", + "The secret code is PINE-9928.\nthought\nThe secret code is PINE", + "The secret code is THETA-5557.\nthought\nThe secret code is THETA", + "The secret code is **PINE-7924**.\nthought\nThe secret code is **", + "The secret code is **GAMMA-4527**.\nthought\nThe secret code is **GAMMA", + "BETA-7224\nthought\nBETA-7224" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 15, + 24, + 24, + 24, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 0.6434404414499331, + 0.8747554736456215, + 
0.686778007678562, + 1.0406105031926975, + 0.7505387230661207, + 0.9050305389795371, + 0.8933173540360834, + 0.9372392356620551, + 0.8617929431105361, + 0.984635767483979 + ], + "mean_throughput_tokens_per_sec": 0.8578138988305126, + "median_throughput_tokens_per_sec": 0.8840364138408524, + "min_throughput_tokens_per_sec": 0.6434404414499331, + "max_throughput_tokens_per_sec": 1.0406105031926975 + }, + "oracle": { + "name": "oracle", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 11.518795565795154, + "median_latency_s": 11.5701638009632, + "per_sample_decoded": [ + "BETA-1409\n\nthought\nBETA-1409\n", + "The secret code is DELTA-3286.thought\nThe secret code is DELTA-", + "BETA-7912\nDESCRIBE THE PROCESS OF CREATING A CUSTOMIZED LEARNING PATH FOR", + "BETA-4582\n\nThe secret code is BETA-4582.", + "KAPPA-1434\n\n\n\n", + "PINE-9928\n\n\n", + "THETA-5557\n\nThe secret code is THETA-5557", + "PINE-7924\n\nthought\nPINE-7924\n", + "GAMMA-4527\nDESCRIBE THE PROCESS OF CREATING A CUSTOMIZED LEARNING PATH FOR", + "BETA-7224\n\nThe secret code is BETA-7224" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 15, + 13, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.4695448763525356, + 1.9862464457106155, + 1.5677651753169861, + 2.356162436075314, + 1.7102381277088805, + 2.0540986877745913, + 2.0225585336395846, + 2.1287599680475906, + 1.9472104969174977, + 2.240135803560332 + ], + "mean_throughput_tokens_per_sec": 1.9482720551103927, + "median_throughput_tokens_per_sec": 2.0044024896751003, + "min_throughput_tokens_per_sec": 1.4695448763525356, + "max_throughput_tokens_per_sec": 2.356162436075314 + } + }, + "attention_window": { + "per_config": { + "k3_cross_model": { + "config": "v04_dlm_restored", + "structural_constraint": "causal_with_dlm_reconstruction 
(local_cache=sink=4+window=64)", + "samples_total": 10, + "effective_keys_at_last_query_mean": 5535.5, + "effective_keys_at_last_query_min": 4857, + "effective_keys_at_last_query_max": 6417, + "effective_keys_at_last_query_median": 5388.0, + "effective_attention_fraction_mean": 1.0, + "effective_attention_fraction_min": 1.0, + "effective_attention_fraction_max": 1.0, + "effective_attention_fraction_median": 1.0, + "per_sample": [ + { + "config": "v04_dlm_restored", + "seq_len": 6417, + "effective_keys_at_last_query": 6417, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5418, + "effective_keys_at_last_query": 5418, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 6177, + "effective_keys_at_last_query": 6177, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 4857, + "effective_keys_at_last_query": 4857, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 6078, + "effective_keys_at_last_query": 6078, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5298, + "effective_keys_at_last_query": 5298, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5358, + "effective_keys_at_last_query": 5358, + "effective_attention_fraction": 1.0, + "structural_constraint": 
"causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5198, + "effective_keys_at_last_query": 5198, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5497, + "effective_keys_at_last_query": 5497, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5057, + "effective_keys_at_last_query": 5057, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + } + ] + } + } + }, + "memory": { + "k3_cross_model": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52882815488, + "current_reserved_bytes": 76279709696, + "peak_allocated_bytes": 69932629504, + "peak_reserved_bytes": 76279709696 + }, + "oracle": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52882815488, + "current_reserved_bytes": 66527952896, + "peak_allocated_bytes": 63584878080, + "peak_reserved_bytes": 66527952896 + } + }, + "gate": { + "architectural_correctness": true, + "recall_cross_model": 1.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 0.0, + "recall_delta_within_5pp": true + } +} \ No newline at end of file diff --git a/results/research/k3_s5_niah_ctx280_v5.json b/results/research/k3_s5_niah_ctx280_v5.json new file mode 100644 index 00000000..7158cd09 --- /dev/null +++ b/results/research/k3_s5_niah_ctx280_v5.json @@ -0,0 +1,349 @@ +{ + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": 
"results/research/f_theta_v5_s5_sliding", + "f_theta_config": { + "drafter_num_layers": 5, + "drafter_num_kv_heads": 8, + "drafter_head_dim": 128, + "verifier_num_layers": 30, + "verifier_num_kv_heads": 8, + "verifier_head_dim": 256, + "rank": 768, + "verifier_layer_kv_heads": [ + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2, + 8, + 8, + 8, + 8, + 8, + 2 + ], + "verifier_layer_head_dims": [ + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512, + 256, + 256, + 256, + 256, + 256, + 512 + ] + }, + "n_samples": 10, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 238, + "haystack_max_lines": 322, + "max_new_tokens": 24, + "seed": 42, + "skip_oracle": false, + "identity_restore": false, + "s5_exact_full_attn": true, + "s5_exact_layers": [ + 5, + 11, + 17, + 23, + 29 + ], + "prompt_token_lens": [ + 6417, + 5418, + 6177, + 4857, + 6078, + 5298, + 5358, + 5198, + 5497, + 5057 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 28.56916334920097, + "median_latency_s": 27.102085192978848, + "per_sample_decoded": [ + "BETA-1409thought\nThe secret code is **BETA-1409", + "The secret code is **DELTA-3286**.thought\nThe secret code is **DELTA-", + "BETA-7912thought\nThe secret code is **BETA-7912", + "The secret code is **BETA-4582**.---\n**Note 0000", + "KAPPA-1434thought\nThe secret code is **KAPPA-1434", + "The secret code is **PINE-9928**.thought\nThe secret code is **P", + "THETA-5557thought\nThe secret code is **THETA-5557", + "The secret code is **PINE-7924**.---\n**Note 0000", + "The secret code is **GAMMA-4527**.thought\nThe secret code is **GAMMA-", + "The secret code is **BETA-7224**.---\n**Note 0000" + ], + "per_sample_correct": [ + true, + true, + true, + true, + 
true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 0.6406166012672451, + 0.876105903111179, + 0.6942713547020577, + 1.0404998964149432, + 0.7511945238497238, + 0.9062251508588193, + 0.8951809541869689, + 0.9421073620708614, + 0.8504310807457415, + 0.9892533991573305 + ], + "mean_throughput_tokens_per_sec": 0.858588622636487, + "median_throughput_tokens_per_sec": 0.885643428649074, + "min_throughput_tokens_per_sec": 0.6406166012672451, + "max_throughput_tokens_per_sec": 1.0404998964149432 + }, + "oracle": { + "name": "oracle", + "samples_total": 10, + "samples_correct": 10, + "recall": 1.0, + "mean_latency_s": 11.529694065195509, + "median_latency_s": 11.582987871486694, + "per_sample_decoded": [ + "BETA-1409\n\nthought\nBETA-1409\n", + "The secret code is DELTA-3286.thought\nThe secret code is DELTA-", + "BETA-7912\nDESCRIBE THE PROCESS OF CREATING A CUSTOMIZED LEARNING PATH FOR", + "BETA-4582\n\nThe secret code is BETA-4582.", + "KAPPA-1434\n\n\n\n", + "PINE-9928\n\n\n", + "THETA-5557\n\nThe secret code is THETA-5557", + "PINE-7924\n\nthought\nPINE-7924\n", + "GAMMA-4527\nDESCRIBE THE PROCESS OF CREATING A CUSTOMIZED LEARNING PATH FOR", + "BETA-7224\n\nThe secret code is BETA-7224" + ], + "per_sample_correct": [ + true, + true, + true, + true, + true, + true, + true, + true, + true, + true + ], + "per_sample_decode_tokens": [ + 24, + 24, + 24, + 24, + 15, + 13, + 24, + 24, + 24, + 24 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.4681406567048605, + 1.9834711048788969, + 1.5669723192672609, + 2.355862777298239, + 1.7074733703163254, + 2.053136081258618, + 2.020997185464885, + 2.1256526508211393, + 1.9447056674307321, + 2.238313235986172 + ], + "mean_throughput_tokens_per_sec": 1.946472504942713, + "median_throughput_tokens_per_sec": 2.002234145171891, + "min_throughput_tokens_per_sec": 1.4681406567048605, + 
"max_throughput_tokens_per_sec": 2.355862777298239 + } + }, + "attention_window": { + "per_config": { + "k3_cross_model": { + "config": "v04_dlm_restored", + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)", + "samples_total": 10, + "effective_keys_at_last_query_mean": 5535.5, + "effective_keys_at_last_query_min": 4857, + "effective_keys_at_last_query_max": 6417, + "effective_keys_at_last_query_median": 5388.0, + "effective_attention_fraction_mean": 1.0, + "effective_attention_fraction_min": 1.0, + "effective_attention_fraction_max": 1.0, + "effective_attention_fraction_median": 1.0, + "per_sample": [ + { + "config": "v04_dlm_restored", + "seq_len": 6417, + "effective_keys_at_last_query": 6417, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5418, + "effective_keys_at_last_query": 5418, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 6177, + "effective_keys_at_last_query": 6177, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 4857, + "effective_keys_at_last_query": 4857, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 6078, + "effective_keys_at_last_query": 6078, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5298, + "effective_keys_at_last_query": 5298, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction 
(local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5358, + "effective_keys_at_last_query": 5358, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5198, + "effective_keys_at_last_query": 5198, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5497, + "effective_keys_at_last_query": 5497, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + }, + { + "config": "v04_dlm_restored", + "seq_len": 5057, + "effective_keys_at_last_query": 5057, + "effective_attention_fraction": 1.0, + "structural_constraint": "causal_with_dlm_reconstruction (local_cache=sink=4+window=64)" + } + ] + } + } + }, + "memory": { + "k3_cross_model": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52882815488, + "current_reserved_bytes": 76300681216, + "peak_allocated_bytes": 69932629504, + "peak_reserved_bytes": 76300681216 + }, + "oracle": { + "device_kind": "cuda", + "device_name": "NVIDIA H200", + "device_total_bytes": 150109880320, + "current_allocated_bytes": 52882815488, + "current_reserved_bytes": 66486009856, + "peak_allocated_bytes": 63584878080, + "peak_reserved_bytes": 66486009856 + } + }, + "gate": { + "architectural_correctness": true, + "recall_cross_model": 1.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 0.0, + "recall_delta_within_5pp": true + } +} \ No newline at end of file diff --git a/results/research/k3_s5_niah_mac_step1_diag.json b/results/research/k3_s5_niah_mac_step1_diag.json new file mode 100644 index 00000000..97df29ee --- /dev/null +++ b/results/research/k3_s5_niah_mac_step1_diag.json @@ -0,0 +1,60 @@ +{ + 
"schema_version": 1, + "kind": "k3_integrated_niah_acceptance_mac", + "config": { + "verifier_path": "models/gemma-4-26B-A4B-it-mlx-4bit", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "n_samples": 1, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "max_new_tokens": 1, + "seed": 42, + "s5_exact_full_attn": true, + "identity_restore": false, + "full_attention_layers": [ + 5, + 11, + 17, + 23, + 29 + ], + "prompt_token_lens": [ + 1639 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model_mac", + "samples_total": 1, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 111.98745520808734, + "median_latency_s": 111.98745520808734, + "per_sample_decoded": [ + "<|channel>" + ], + "per_sample_correct": [ + false + ], + "per_sample_decode_tokens": [ + 1 + ], + "per_sample_throughput_tokens_per_sec": [ + 0.008929571603728911 + ], + "mean_throughput_tokens_per_sec": 0.008929571603728911, + "median_throughput_tokens_per_sec": 0.008929571603728911, + "min_throughput_tokens_per_sec": 0.008929571603728911, + "max_throughput_tokens_per_sec": 0.008929571603728911 + } + }, + "gate": { + "recall_cross_model": 0.0, + "recall_oracle": null, + "recall_delta_vs_oracle_pp": null, + "recall_delta_within_5pp": false + } +} \ No newline at end of file diff --git a/results/research/k3_specdecode_fused.json b/results/research/k3_specdecode_fused.json new file mode 100644 index 00000000..25e6ea18 --- /dev/null +++ b/results/research/k3_specdecode_fused.json @@ -0,0 +1,499 @@ +{ + "kind": "k3_specdecode_gpu_bench", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "haystack_lines": 60, + "n_samples": 5, + "max_new_tokens": 64, + "block_size": 15, + "sink": 4, + "window": 64, + "seed": 0, + "skip_unfused": true, + "output": 
"results/research/k3_specdecode_fused.json" + }, + "env": { + "gpu": "NVIDIA H200", + "torch": "2.11.0+cu128" + }, + "prompt_tokens": { + "min": 1238, + "max": 1238 + }, + "ar_incremental": { + "decode_tokens_per_s_mean": 21.155, + "recall": 1.0 + }, + "restored_pertoken": { + "decode_tokens_per_s_mean": 21.902, + "recall": 1.0 + }, + "restored_specdecode": { + "skipped": true, + "decode_tokens_per_s_mean": null, + "mean_accept_len": 0.0, + "recall": 0.0, + "per_sample": [ + { + "decode_tokens_per_s": null, + "mean_accept_len": 0.0, + "time_breakdown_s": { + "aux_clean_forward": 0.0, + "drafter": 0.0, + "incremental_verify": 0.0 + }, + "tokens": [] + }, + { + "decode_tokens_per_s": null, + "mean_accept_len": 0.0, + "time_breakdown_s": { + "aux_clean_forward": 0.0, + "drafter": 0.0, + "incremental_verify": 0.0 + }, + "tokens": [] + }, + { + "decode_tokens_per_s": null, + "mean_accept_len": 0.0, + "time_breakdown_s": { + "aux_clean_forward": 0.0, + "drafter": 0.0, + "incremental_verify": 0.0 + }, + "tokens": [] + }, + { + "decode_tokens_per_s": null, + "mean_accept_len": 0.0, + "time_breakdown_s": { + "aux_clean_forward": 0.0, + "drafter": 0.0, + "incremental_verify": 0.0 + }, + "tokens": [] + }, + { + "decode_tokens_per_s": null, + "mean_accept_len": 0.0, + "time_breakdown_s": { + "aux_clean_forward": 0.0, + "drafter": 0.0, + "incremental_verify": 0.0 + }, + "tokens": [] + } + ] + }, + "restored_specdecode_fused": { + "decode_tokens_per_s_mean": 16.559, + "mean_accept_len": 4.02, + "time_breakdown_s_mean": { + "drafter_cached": 2.924, + "incremental_verify": 1.254, + "ctx_kv_extend": 0.023 + }, + "recall": 1.0, + "per_sample": [ + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236832, + 236828, + 236819, + 236771, + 84750, + 106, + 106, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236832, + 236828, + 236819, + 236771, + 84750, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 
3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236832, + 236828, + 236819, + 236771, + 84750, + 106, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618 + ], + "decode_s": 5.6331069769803435, + "prefill_s": 0.281, + "decode_tokens_per_s": 11.361, + "time_breakdown_s": { + "drafter_cached": 4.424, + "incremental_verify": 1.183, + "ctx_kv_extend": 0.026 + }, + "blocks": 12, + "mean_accept_len": 4.42, + "decode_tokens": 64 + }, + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 236777, + 59790, + 236772, + 236828, + 236819, + 236825, + 236770, + 84750, + 106, + 106, + 45518, + 107, + 101, + 236777, + 59790, + 236772, + 236828, + 236819, + 236825, + 236770, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 236777, + 59790, + 236772, + 236828, + 236819, + 236825, + 236770, + 84750, + 106, + 106, + 106, + 106, + 45518, + 107, + 101, + 236777, + 59790, + 236772, + 236828, + 236819, + 236825, + 236770, + 106, + 106, + 100, + 45518, + 107, + 101 + ], + "decode_s": 4.643828391912393, + "prefill_s": 0.282, + "decode_tokens_per_s": 13.782, + "time_breakdown_s": { + "drafter_cached": 3.289, + "incremental_verify": 1.332, + "ctx_kv_extend": 0.023 + }, + "blocks": 14, + "mean_accept_len": 3.64, + "decode_tokens": 64 + }, + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 4989, + 26742, + 236772, + 236825, + 236828, + 236825, + 236825, + 84750, + 106, + 106, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 4989, + 26742, + 236772, + 236825, + 236828, + 236825, + 236825, + 84750, + 106, + 106, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 4989, + 26742, + 236772, + 236825, + 236828, + 236825, + 236825, + 84750, + 106, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213 + ], + "decode_s": 2.7147572399117053, + "prefill_s": 0.28, + "decode_tokens_per_s": 23.575, + "time_breakdown_s": { + "drafter_cached": 
1.478, + "incremental_verify": 1.215, + "ctx_kv_extend": 0.022 + }, + "blocks": 13, + "mean_accept_len": 4.0, + "decode_tokens": 64 + }, + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 236777, + 59790, + 236772, + 236800, + 236778, + 236828, + 236770, + 84750, + 106, + 106, + 107, + 106, + 101, + 236777, + 59790, + 236772, + 236800, + 236778, + 236828, + 236770, + 106, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 236777, + 59790, + 236772, + 236800, + 236778, + 236828, + 236770, + 84750, + 106, + 106, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 236777, + 59790, + 236772, + 236800, + 236778, + 236828 + ], + "decode_s": 5.006628326955251, + "prefill_s": 0.28, + "decode_tokens_per_s": 12.783, + "time_breakdown_s": { + "drafter_cached": 3.565, + "incremental_verify": 1.417, + "ctx_kv_extend": 0.024 + }, + "blocks": 14, + "mean_accept_len": 3.64, + "decode_tokens": 64 + }, + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236778, + 236810, + 236810, + 236800, + 84750, + 106, + 106, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236778, + 236810, + 236810, + 236800, + 84750, + 106, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236778, + 236810, + 236810, + 236800, + 84750, + 106, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 28487 + ], + "decode_s": 3.0056109990691766, + "prefill_s": 0.281, + "decode_tokens_per_s": 21.294, + "time_breakdown_s": { + "drafter_cached": 1.864, + "incremental_verify": 1.122, + "ctx_kv_extend": 0.019 + }, + "blocks": 12, + "mean_accept_len": 4.42, + "decode_tokens": 64 + } + ], + "speedup_over_ar_x": 0.78 + } +} \ No newline at end of file diff --git a/results/research/k3_specdecode_fused_stable.json b/results/research/k3_specdecode_fused_stable.json new file mode 100644 index 
00000000..a732f248 --- /dev/null +++ b/results/research/k3_specdecode_fused_stable.json @@ -0,0 +1,499 @@ +{ + "kind": "k3_specdecode_gpu_bench", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "haystack_lines": 60, + "n_samples": 5, + "max_new_tokens": 64, + "block_size": 15, + "sink": 4, + "window": 64, + "seed": 0, + "skip_unfused": true, + "output": "results/research/k3_specdecode_fused_stable.json" + }, + "env": { + "gpu": "NVIDIA H200", + "torch": "2.11.0+cu128" + }, + "prompt_tokens": { + "min": 1238, + "max": 1238 + }, + "ar_incremental": { + "decode_tokens_per_s_mean": 21.138, + "recall": 1.0 + }, + "restored_pertoken": { + "decode_tokens_per_s_mean": 21.093, + "recall": 1.0 + }, + "restored_specdecode": { + "skipped": true, + "decode_tokens_per_s_mean": null, + "mean_accept_len": 0.0, + "recall": 0.0, + "per_sample": [ + { + "decode_tokens_per_s": null, + "mean_accept_len": 0.0, + "time_breakdown_s": { + "aux_clean_forward": 0.0, + "drafter": 0.0, + "incremental_verify": 0.0 + }, + "tokens": [] + }, + { + "decode_tokens_per_s": null, + "mean_accept_len": 0.0, + "time_breakdown_s": { + "aux_clean_forward": 0.0, + "drafter": 0.0, + "incremental_verify": 0.0 + }, + "tokens": [] + }, + { + "decode_tokens_per_s": null, + "mean_accept_len": 0.0, + "time_breakdown_s": { + "aux_clean_forward": 0.0, + "drafter": 0.0, + "incremental_verify": 0.0 + }, + "tokens": [] + }, + { + "decode_tokens_per_s": null, + "mean_accept_len": 0.0, + "time_breakdown_s": { + "aux_clean_forward": 0.0, + "drafter": 0.0, + "incremental_verify": 0.0 + }, + "tokens": [] + }, + { + "decode_tokens_per_s": null, + "mean_accept_len": 0.0, + "time_breakdown_s": { + "aux_clean_forward": 0.0, + "drafter": 0.0, + "incremental_verify": 0.0 + }, + "tokens": [] + } + ] + }, + "restored_specdecode_fused": { + "decode_tokens_per_s_mean": 26.753, + "mean_accept_len": 4.27, + 
"time_breakdown_s_mean": { + "drafter_cached": 1.572, + "incremental_verify": 1.223, + "ctx_kv_extend": 0.022 + }, + "recall": 1.0, + "per_sample": [ + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236832, + 236828, + 236819, + 236771, + 84750, + 106, + 106, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236832, + 236828, + 236819, + 236771, + 84750, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236832, + 236828, + 236819, + 236771, + 84750, + 106, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618 + ], + "decode_s": 1.2425236969720572, + "prefill_s": 0.283, + "decode_tokens_per_s": 51.508, + "time_breakdown_s": { + "drafter_cached": 0.051, + "incremental_verify": 1.17, + "ctx_kv_extend": 0.021 + }, + "blocks": 12, + "mean_accept_len": 4.42, + "decode_tokens": 64 + }, + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 236777, + 59790, + 236772, + 236828, + 236819, + 236825, + 236770, + 84750, + 106, + 106, + 45518, + 107, + 101, + 236777, + 59790, + 236772, + 236828, + 236819, + 236825, + 236770, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 236777, + 59790, + 236772, + 236828, + 236819, + 236825, + 236770, + 84750, + 106, + 106, + 106, + 106, + 45518, + 107, + 101, + 236777, + 59790, + 236772, + 236828, + 236819, + 236825, + 236770, + 106, + 106, + 45518, + 107, + 101, + 236777 + ], + "decode_s": 4.286014890065417, + "prefill_s": 0.282, + "decode_tokens_per_s": 14.932, + "time_breakdown_s": { + "drafter_cached": 2.971, + "incremental_verify": 1.288, + "ctx_kv_extend": 0.026 + }, + "blocks": 13, + "mean_accept_len": 4.0, + "decode_tokens": 64 + }, + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 4989, + 26742, + 236772, + 236825, + 236828, + 236825, + 236825, + 84750, + 106, + 106, + 45518, + 107, + 101, + 818, + 6789, + 
3393, + 563, + 5213, + 4989, + 26742, + 236772, + 236825, + 236828, + 236825, + 236825, + 84750, + 106, + 106, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 4989, + 26742, + 236772, + 236825, + 236828, + 236825, + 236825, + 84750, + 106, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213 + ], + "decode_s": 2.777758297044784, + "prefill_s": 0.282, + "decode_tokens_per_s": 23.04, + "time_breakdown_s": { + "drafter_cached": 1.478, + "incremental_verify": 1.278, + "ctx_kv_extend": 0.021 + }, + "blocks": 13, + "mean_accept_len": 4.0, + "decode_tokens": 64 + }, + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 236777, + 59790, + 236772, + 236800, + 236778, + 236828, + 236770, + 84750, + 106, + 106, + 107, + 106, + 101, + 236777, + 59790, + 236772, + 236800, + 236778, + 236828, + 236770, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 236777, + 59790, + 236772, + 236800, + 236778, + 236828, + 236770, + 84750, + 106, + 106, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 236777, + 59790, + 236772, + 236800, + 236778, + 236828, + 236770 + ], + "decode_s": 2.809437029995024, + "prefill_s": 0.28, + "decode_tokens_per_s": 22.78, + "time_breakdown_s": { + "drafter_cached": 1.5, + "incremental_verify": 1.286, + "ctx_kv_extend": 0.023 + }, + "blocks": 13, + "mean_accept_len": 4.0, + "decode_tokens": 64 + }, + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236778, + 236810, + 236810, + 236800, + 84750, + 106, + 106, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236778, + 236810, + 236810, + 236800, + 84750, + 106, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236778, + 236810, + 236810, + 236800, + 84750, + 106, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, 
+ 28487 + ], + "decode_s": 2.9762936460319906, + "prefill_s": 0.282, + "decode_tokens_per_s": 21.503, + "time_breakdown_s": { + "drafter_cached": 1.861, + "incremental_verify": 1.095, + "ctx_kv_extend": 0.02 + }, + "blocks": 11, + "mean_accept_len": 4.91, + "decode_tokens": 64 + } + ], + "speedup_over_ar_x": 1.27 + } +} \ No newline at end of file diff --git a/results/research/k3_specdecode_gpu_bench.json b/results/research/k3_specdecode_gpu_bench.json new file mode 100644 index 00000000..0813d312 --- /dev/null +++ b/results/research/k3_specdecode_gpu_bench.json @@ -0,0 +1,191 @@ +{ + "kind": "k3_specdecode_gpu_bench", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "haystack_lines": 60, + "n_samples": 3, + "max_new_tokens": 48, + "block_size": 16, + "sink": 4, + "window": 64, + "seed": 0, + "output": "results/research/k3_specdecode_gpu_bench.json" + }, + "env": { + "gpu": "NVIDIA H200", + "torch": "2.11.0+cu128" + }, + "prompt_tokens": { + "min": 1238, + "max": 1238 + }, + "ar_incremental": { + "decode_tokens_per_s_mean": 17.286, + "recall": 1.0 + }, + "restored_pertoken": { + "decode_tokens_per_s_mean": 3.47, + "recall": 1.0 + }, + "restored_specdecode": { + "decode_tokens_per_s_mean": 6.778, + "mean_accept_len": 2.38, + "verifier_forwards_total": 68, + "recall": 1.0, + "per_sample": [ + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236832, + 236828, + 236819, + 236771, + 84750, + 106, + 106, + 7243, + 107, + 100, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236832, + 236828, + 236819, + 236771, + 84750, + 106, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618 + ], + "decode_s": 10.724500498035923, + "decode_tokens_per_s": 4.476, + "verifier_forwards": 28, + "drafter_forwards": 14, + "blocks": 14, + 
"mean_accept_len": 2.43, + "decode_tokens": 48 + }, + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 236777, + 59790, + 236772, + 236828, + 236819, + 236825, + 236770, + 84750, + 106, + 106, + 236797, + 236786, + 236776, + 106, + 106, + 1 + ], + "decode_s": 3.5449153430527076, + "decode_tokens_per_s": 5.924, + "verifier_forwards": 16, + "drafter_forwards": 8, + "blocks": 8, + "mean_accept_len": 1.62, + "decode_tokens": 21 + }, + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 4989, + 26742, + 236772, + 236825, + 236828, + 236825, + 236825, + 84750, + 106, + 106, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 4989, + 26742, + 236772, + 236825, + 236828, + 236825, + 236825, + 84750, + 106, + 106, + 106, + 107, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 4989, + 26742, + 236772, + 236825 + ], + "decode_s": 4.832237196038477, + "decode_tokens_per_s": 9.933, + "verifier_forwards": 24, + "drafter_forwards": 12, + "blocks": 12, + "mean_accept_len": 3.08, + "decode_tokens": 48 + } + ], + "speedup_over_pertoken_x": 1.95 + } +} \ No newline at end of file diff --git a/results/research/k3_specdecode_integrated.json b/results/research/k3_specdecode_integrated.json new file mode 100644 index 00000000..2e175938 --- /dev/null +++ b/results/research/k3_specdecode_integrated.json @@ -0,0 +1,223 @@ +{ + "kind": "k3_specdecode_gpu_bench", + "config": { + "verifier_id": "google/gemma-4-26B-A4B-it", + "drafter_id": "z-lab/gemma-4-26B-A4B-it-DFlash", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "haystack_lines": 60, + "n_samples": 3, + "max_new_tokens": 48, + "block_size": 15, + "sink": 4, + "window": 64, + "seed": 0, + "output": "results/research/k3_specdecode_integrated.json" + }, + "env": { + "gpu": "NVIDIA H200", + "torch": "2.11.0+cu128" + }, + "prompt_tokens": { + "min": 1238, + "max": 1238 + }, + "ar_incremental": { + "decode_tokens_per_s_mean": 20.884, + "recall": 1.0 + }, + "restored_pertoken": { + 
"decode_tokens_per_s_mean": 20.928, + "recall": 1.0 + }, + "restored_specdecode": { + "decode_tokens_per_s_mean": 10.623, + "mean_accept_len": 3.33, + "time_breakdown_s_mean": { + "aux_clean_forward": 0.984, + "drafter": 2.346, + "incremental_verify": 1.086 + }, + "recall": 1.0, + "per_sample": [ + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236832, + 236828, + 236819, + 236771, + 84750, + 106, + 106, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236832, + 236828, + 236819, + 236771, + 84750, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 28487, + 1618, + 236772, + 236832, + 236828, + 236819 + ], + "decode_s": 5.677021626965143, + "decode_tokens_per_s": 8.455, + "time_breakdown_s": { + "aux_clean_forward": 0.956, + "drafter": 3.667, + "incremental_verify": 1.054 + }, + "blocks": 10, + "mean_accept_len": 3.9, + "decode_tokens": 48 + }, + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 236777, + 59790, + 236772, + 236828, + 236819, + 236825, + 236770, + 84750, + 106, + 106, + 107, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 236777, + 59790, + 236772, + 236828, + 236819, + 236825, + 236770, + 84750, + 106, + 106, + 107, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 236777, + 59790, + 236772 + ], + "decode_s": 4.319566503050737, + "decode_tokens_per_s": 11.112, + "time_breakdown_s": { + "aux_clean_forward": 1.004, + "drafter": 2.215, + "incremental_verify": 1.1 + }, + "blocks": 11, + "mean_accept_len": 3.45, + "decode_tokens": 48 + }, + { + "tokens": [ + 818, + 6789, + 3393, + 563, + 5213, + 4989, + 26742, + 236772, + 236825, + 236828, + 236825, + 236825, + 84750, + 106, + 106, + 45518, + 107, + 101, + 818, + 6789, + 3393, + 563, + 5213, + 4989, + 26742, + 236772, + 236825, + 236828, + 236825, + 236825, + 84750, + 106, + 106, + 106, + 107, + 101, + 107, + 1, + 106, + 45518 + ], + "decode_s": 
3.251323492033407, + "decode_tokens_per_s": 12.303, + "time_breakdown_s": { + "aux_clean_forward": 0.992, + "drafter": 1.155, + "incremental_verify": 1.105 + }, + "blocks": 11, + "mean_accept_len": 2.64, + "decode_tokens": 40 + } + ], + "speedup_over_pertoken_x": 0.51 + } +} \ No newline at end of file diff --git a/results/research/logs/k3_e2e_gpu_bench.log b/results/research/logs/k3_e2e_gpu_bench.log new file mode 100644 index 00000000..7d82bd78 --- /dev/null +++ b/results/research/logs/k3_e2e_gpu_bench.log @@ -0,0 +1,31 @@ +[e2e] loading verifier google/gemma-4-26B-A4B-it +Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads. + Loading weights: 0%| | 0/1013 [00:00 16.9x saving; ctx 1254 tok over 68-tok window (18.4x); tok/s restored=2.259 ar=21.514; recall restored=1.0 ar=1.0 + +[e2e] === rung: 160 haystack lines | prompt tokens min=3238 max=3238 === +[e2e] running standalone AR baseline ... +[e2e] AR sample 0: gen=16 recall=Y out[:40]='The secret code is MAPLE-7890.\nthought' +[e2e] AR sample 1: gen=16 recall=Y out[:40]='IOTA-8961\nDESCRIBE THE PROCESS OF CRE' +[e2e] AR sample 2: gen=16 recall=Y out[:40]='THETA-6866\nout of the box\nTH' +[e2e] running Kakeya restored path ... 
+[e2e] restored sample 0: gen=16 recall=Y resident_kv_tok=68 eff_ctx=3254 out[:40]='The secret code is **MAPLE-7890**.---' +[e2e] restored sample 1: gen=16 recall=Y resident_kv_tok=68 eff_ctx=3254 out[:40]='The secret code is **IOTA-8961**.---' +[e2e] restored sample 2: gen=16 recall=Y resident_kv_tok=68 eff_ctx=3254 out[:40]='The secret code is **THETA-6866**.---' +[e2e] rung 160: KV 733.06MB(AR) vs 16.71MB(restored) -> 43.9x saving; ctx 3254 tok over 68-tok window (47.9x); tok/s restored=1.273 ar=21.92; recall restored=1.0 ar=1.0 + +[e2e] wrote results/research/k3_e2e_gpu_bench.json +EXIT_CODE=0 diff --git a/results/research/logs/k3_e2e_gpu_bench_incremental.log b/results/research/logs/k3_e2e_gpu_bench_incremental.log new file mode 100644 index 00000000..e4931778 --- /dev/null +++ b/results/research/logs/k3_e2e_gpu_bench_incremental.log @@ -0,0 +1,31 @@ +[e2e] loading verifier google/gemma-4-26B-A4B-it + Loading weights: 0%| | 0/1013 [00:00 16.9x saving; ctx 1254 tok over 68-tok window (18.4x); tok/s restored=21.68 ar=21.121; recall restored=1.0 ar=1.0 + +[e2e] === rung: 160 haystack lines | prompt tokens min=3238 max=3238 === +[e2e] running standalone AR baseline ... +[e2e] AR sample 0: gen=16 recall=Y out[:40]='The secret code is MAPLE-7890.\nthought' +[e2e] AR sample 1: gen=16 recall=Y out[:40]='IOTA-8961\nDESCRIBE THE PROCESS OF CRE' +[e2e] AR sample 2: gen=16 recall=Y out[:40]='THETA-6866\nout of the box\nTH' +[e2e] running Kakeya restored path ... 
+[e2e] restored sample 0: gen=16 recall=Y resident_kv_tok=68 eff_ctx=3254 out[:40]='The secret code is **MAPLE-7890**.though' +[e2e] restored sample 1: gen=16 recall=Y resident_kv_tok=68 eff_ctx=3254 out[:40]='The secret code is **IOTA-8961**.\n' +[e2e] restored sample 2: gen=16 recall=Y resident_kv_tok=68 eff_ctx=3254 out[:40]='The secret code is **THETA-6866**.though' +[e2e] rung 160: KV 733.06MB(AR) vs 16.71MB(restored) -> 43.9x saving; ctx 3254 tok over 68-tok window (47.9x); tok/s restored=20.98 ar=21.943; recall restored=1.0 ar=1.0 + +[e2e] wrote results/research/k3_e2e_gpu_bench_incremental.json +EXIT=0 diff --git a/results/research/logs/k3_s5_kl_niah_ctx280_mac_2d855ba_rerun_oom.log b/results/research/logs/k3_s5_kl_niah_ctx280_mac_2d855ba_rerun_oom.log new file mode 100644 index 00000000..f9fefaa1 --- /dev/null +++ b/results/research/logs/k3_s5_kl_niah_ctx280_mac_2d855ba_rerun_oom.log @@ -0,0 +1,53 @@ +Command: +source "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/bin/activate" && PYTHONPATH=.:sdks/python python3 scripts/research/k3_integrated_niah_eval_mac.py --f-theta-dir results/research/f_theta_v5_s5_sliding --s5-exact-full-attn --compress-full-attn --n-samples 10 --haystack-min-lines 238 --haystack-max-lines 322 --sink-size 4 --window-size 64 --max-new-tokens 24 --output results/research/k3_s5_kl_niah_ctx280_mac.json + +Commit under test: +2d855ba Mac M4 K3 S5 KL ctx280 SDPA OOM evidence + +Started: 2026-06-11T04:49:25.658Z +Ended: 2026-06-11T04:49:44.924Z +Elapsed: 19.266s +Exit code: 1 + +[mac] loading MLX verifier models/gemma-4-26B-A4B-it-mlx-4bit +[mac] verifier layers=30 full_attn=[5, 11, 17, 23, 29] +[mac] loading drafter models/dflash-kakeya-baseline on mps +[mac] KakeyaLattice compression ON for full-attn layers (D4 Q38); bits/token/head=3232.0 +[mac] 10 samples, prompt len min=4859 max=6419 +[mac] running restored cross-model verifier (s5) +Traceback (most recent call last): + File 
"/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 389, in + sys.exit(main()) + ~~~~^^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 290, in main + cross_dec, cross_lat, cross_tok = greedy(restored_next_logits) + ~~~~~~^^^^^^^^^^^^^^^^^^^^^^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 275, in greedy + nxt = step_fn(cur) + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 209, in restored_next_logits + d_k, d_v = capture_drafter_kv(ids) + ~~~~~~~~~~~~~~~~~~^^^^^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 190, in capture_drafter_kv + h = layer(h, qpos, ctx_k=None, ctx_v=None) + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl + return forward_call(*args, **kwargs) + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/inference_engine/v04/dflash_drafter.py", line 307, in forward + h = h + self.self_attn( + ~~~~~~~~~~~~~~^ + self.input_layernorm(h), query_positions, ctx_k, ctx_v, + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + ) + ^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ + File 
"/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl + return forward_call(*args, **kwargs) + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/inference_engine/v04/dflash_drafter.py", line 280, in forward + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=False, scale=self.scale, + ) # [B, nh, T, hd] +RuntimeError: MPS backend out of memory (MPS allocated: 3.31 GiB, other allocations: 24.15 GiB, max allowed: 30.19 GiB). Tried to allocate 4.91 GiB on shared pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure). diff --git a/results/research/logs/k3_s5_kl_niah_ctx280_mac_8452c5a_sdpa_oom.log b/results/research/logs/k3_s5_kl_niah_ctx280_mac_8452c5a_sdpa_oom.log new file mode 100644 index 00000000..8ddf9182 --- /dev/null +++ b/results/research/logs/k3_s5_kl_niah_ctx280_mac_8452c5a_sdpa_oom.log @@ -0,0 +1,53 @@ +Command: +source "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/bin/activate" && PYTHONPATH=.:sdks/python python3 scripts/research/k3_integrated_niah_eval_mac.py --f-theta-dir results/research/f_theta_v5_s5_sliding --s5-exact-full-attn --compress-full-attn --n-samples 10 --haystack-min-lines 238 --haystack-max-lines 322 --sink-size 4 --window-size 64 --max-new-tokens 24 --output results/research/k3_s5_kl_niah_ctx280_mac.json + +Commit under test: +8452c5a K3 fix MPS OOM: DFlash attention uses memory-efficient SDPA instead of materializing full fp32 [B,nh,T,C+T] score matrix (~5GB at T~6k, nh=32) — was OOMing the ctx280 S5+KL Mac run in drafter K/V capture. Numerically equivalent (max diff 7e-7), 28 drafter tests pass. 
+ +Started: 2026-06-11T04:25:18.978Z +Ended: 2026-06-11T04:25:37.364Z +Elapsed: 18.384s +Exit code: 1 + +[mac] loading MLX verifier models/gemma-4-26B-A4B-it-mlx-4bit +[mac] verifier layers=30 full_attn=[5, 11, 17, 23, 29] +[mac] loading drafter models/dflash-kakeya-baseline on mps +[mac] KakeyaLattice compression ON for full-attn layers (D4 Q38); bits/token/head=3232.0 +[mac] 10 samples, prompt len min=4859 max=6419 +[mac] running restored cross-model verifier (s5) +Traceback (most recent call last): + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 389, in + sys.exit(main()) + ~~~~^^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 290, in main + cross_dec, cross_lat, cross_tok = greedy(restored_next_logits) + ~~~~~~^^^^^^^^^^^^^^^^^^^^^^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 275, in greedy + nxt = step_fn(cur) + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 209, in restored_next_logits + d_k, d_v = capture_drafter_kv(ids) + ~~~~~~~~~~~~~~~~~~^^^^^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 190, in capture_drafter_kv + h = layer(h, qpos, ctx_k=None, ctx_v=None) + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl + return forward_call(*args, **kwargs) + File 
"/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/inference_engine/v04/dflash_drafter.py", line 307, in forward + h = h + self.self_attn( + ~~~~~~~~~~~~~~^ + self.input_layernorm(h), query_positions, ctx_k, ctx_v, + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + ) + ^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl + return forward_call(*args, **kwargs) + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/inference_engine/v04/dflash_drafter.py", line 280, in forward + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=False, scale=self.scale, + ) # [B, nh, T, hd] +RuntimeError: MPS backend out of memory (MPS allocated: 3.31 GiB, other allocations: 24.15 GiB, max allowed: 30.19 GiB). Tried to allocate 4.91 GiB on shared pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure). 
diff --git a/results/research/logs/k3_s5_kl_niah_ctx280_mac_oom.log b/results/research/logs/k3_s5_kl_niah_ctx280_mac_oom.log new file mode 100644 index 00000000..f58ca9e6 --- /dev/null +++ b/results/research/logs/k3_s5_kl_niah_ctx280_mac_oom.log @@ -0,0 +1,49 @@ +Command: +source "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/bin/activate" && PYTHONPATH=.:sdks/python python3 scripts/research/k3_integrated_niah_eval_mac.py --f-theta-dir results/research/f_theta_v5_s5_sliding --s5-exact-full-attn --compress-full-attn --n-samples 10 --haystack-min-lines 238 --haystack-max-lines 322 --sink-size 4 --window-size 64 --max-new-tokens 24 --output results/research/k3_s5_kl_niah_ctx280_mac.json + +Started: 2026-06-11T03:56:31.618Z +Ended: 2026-06-11T03:56:48.132Z +Elapsed: 16.514s +Exit code: 1 + +[mac] loading MLX verifier models/gemma-4-26B-A4B-it-mlx-4bit +[mac] verifier layers=30 full_attn=[5, 11, 17, 23, 29] +[mac] loading drafter models/dflash-kakeya-baseline on mps +[mac] KakeyaLattice compression ON for full-attn layers (D4 Q38); bits/token/head=3232.0 +[mac] 10 samples, prompt len min=4859 max=6419 +[mac] running restored cross-model verifier (s5) +Traceback (most recent call last): + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 389, in + sys.exit(main()) + ~~~~^^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 290, in main + cross_dec, cross_lat, cross_tok = greedy(restored_next_logits) + ~~~~~~^^^^^^^^^^^^^^^^^^^^^^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 275, in greedy + nxt = step_fn(cur) + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 209, in restored_next_logits + d_k, d_v = capture_drafter_kv(ids) + 
~~~~~~~~~~~~~~~~~~^^^^^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/scripts/research/k3_integrated_niah_eval_mac.py", line 190, in capture_drafter_kv + h = layer(h, qpos, ctx_k=None, ctx_v=None) + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl + return forward_call(*args, **kwargs) + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/inference_engine/v04/dflash_drafter.py", line 302, in forward + h = h + self.self_attn( + ~~~~~~~~~~~~~~^ + self.input_layernorm(h), query_positions, ctx_k, ctx_v, + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + ) + ^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1778, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1789, in _call_impl + return forward_call(*args, **kwargs) + File "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine-pr94-resolve/inference_engine/v04/dflash_drafter.py", line 276, in forward + attn = torch.softmax(scores.float(), dim=-1).to(q.dtype) + ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^ +RuntimeError: MPS backend out of memory (MPS allocated: 8.12 GiB, other allocations: 14.31 GiB, max allowed: 30.19 GiB). Tried to allocate 4.91 GiB on shared pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure). 
diff --git a/results/research/logs/k3_s5_kl_niah_ctx70_cpu_no_sample_timeout.log b/results/research/logs/k3_s5_kl_niah_ctx70_cpu_no_sample_timeout.log new file mode 100644 index 00000000..71cd9fe7 --- /dev/null +++ b/results/research/logs/k3_s5_kl_niah_ctx70_cpu_no_sample_timeout.log @@ -0,0 +1,22 @@ +Command: +source "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/bin/activate" && PYTHONPATH=.:sdks/python python3 scripts/research/k3_integrated_niah_eval_mac.py --f-theta-dir results/research/f_theta_v5_s5_sliding --s5-exact-full-attn --compress-full-attn --drafter-device cpu --n-samples 10 --haystack-min-lines 60 --haystack-max-lines 81 --max-new-tokens 16 --output results/research/k3_s5_kl_niah_ctx70_mac.json + +Commit under test: +91ecaa1 K3: make DFlash attention query-chunk env-tunable (KAKEYA_DFLASH_ATTN_QCHUNK) for tight-memory Macs + +Started: 2026-06-11T04:53:45.439Z +Stopped: 2026-06-11T05:05:59.705Z +Elapsed: 734.266s +Exit code: unknown (manually stopped after no sample output) + +[mac] loading MLX verifier models/gemma-4-26B-A4B-it-mlx-4bit +[mac] verifier layers=30 full_attn=[5, 11, 17, 23, 29] +[mac] loading drafter models/dflash-kakeya-baseline on cpu +[mac] KakeyaLattice compression ON for full-attn layers (D4 Q38); bits/token/head=3232.0 +[mac] 10 samples, prompt len min=1240 max=1639 +[mac] running restored cross-model verifier (s5) + +No first sample completed after more than 12 minutes. Process remained alive +but low-utilization (~7% CPU, ~2 GB RSS). This avoids the MPS OOM by moving +the drafter/f_theta path to CPU, but the resulting product path is not usable +for even ctx70. 
diff --git a/results/research/logs/k3_s5_kl_niah_ctx70_freegen_8dcb1d0_slow.log b/results/research/logs/k3_s5_kl_niah_ctx70_freegen_8dcb1d0_slow.log new file mode 100644 index 00000000..ac514978 --- /dev/null +++ b/results/research/logs/k3_s5_kl_niah_ctx70_freegen_8dcb1d0_slow.log @@ -0,0 +1,22 @@ +Command: +source "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/bin/activate" && PYTHONPATH=.:sdks/python python3 scripts/research/k3_integrated_niah_eval_mac.py --f-theta-dir results/research/f_theta_v5_s5_sliding --s5-exact-full-attn --compress-full-attn --drafter-device cpu --n-samples 3 --haystack-min-lines 60 --haystack-max-lines 81 --max-new-tokens 16 --output results/research/k3_s5_kl_niah_ctx70_mac_freegen.json + +Commit under test: +8dcb1d0 K3 MLX harness: fix recall metric — default to free-generation (teacher-forced misses the model's preamble -> read 0/10 even for oracle). Oracle now uses mlx NATIVE incremental KV cache (fast + correct reference, expect ~10/10). --teacher-forced kept as labeled diagnostic. Cross = restored free-gen (correct; full-forward/token, slow on M4). + +Started: 2026-06-11T05:43:44.858Z +Stopped: 2026-06-11T05:53:02.023Z +Elapsed: 557.165s +Exit code: unknown (manually stopped as unusably slow) + +[mac] loading MLX verifier models/gemma-4-26B-A4B-it-mlx-4bit +[mac] verifier layers=30 full_attn=[5, 11, 17, 23, 29] +[mac] loading drafter models/dflash-kakeya-baseline on cpu +[mac] KakeyaLattice compression ON for full-attn layers (D4 Q38); bits/token/head=3232.0 +[mac] 3 samples, prompt len min=1301 max=1639 +[mac] running restored cross-model verifier (s5, free_gen) +[mac] sample 0: T=1639 -> '<|channel>thought\nThe user wants to find the "se' + +Only the first sample completed after more than 9 minutes, and it generated a +thought/preamble fragment rather than the needle answer. The current restored +free-generation Mac path remains unusable for product evaluation. 
diff --git a/results/research/logs/k3_s5_kl_niah_ctx70_mac_95613ed.log b/results/research/logs/k3_s5_kl_niah_ctx70_mac_95613ed.log new file mode 100644 index 00000000..90b06d45 --- /dev/null +++ b/results/research/logs/k3_s5_kl_niah_ctx70_mac_95613ed.log @@ -0,0 +1,44 @@ +Command: +source "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/bin/activate" && PYTHONPATH=.:sdks/python python3 scripts/research/k3_integrated_niah_eval_mac.py --f-theta-dir results/research/f_theta_v5_s5_sliding --s5-exact-full-attn --compress-full-attn --drafter-device cpu --n-samples 10 --haystack-min-lines 60 --haystack-max-lines 81 --output results/research/k3_s5_kl_niah_ctx70_mac.json + +Commit under test: +95613ed K3 MLX harness refactor (usability): amortize restoration and default to teacher-forced recall. + +Started: 2026-06-11T05:17:30.200Z +Ended: 2026-06-11T05:34:03.301Z +Elapsed: 993.101s +Exit code: 0 + +[mac] loading MLX verifier models/gemma-4-26B-A4B-it-mlx-4bit +[mac] verifier layers=30 full_attn=[5, 11, 17, 23, 29] +[mac] loading drafter models/dflash-kakeya-baseline on cpu +[mac] KakeyaLattice compression ON for full-attn layers (D4 Q38); bits/token/head=3232.0 +[mac] 10 samples, prompt len min=1240 max=1639 +[mac] running restored cross-model verifier (s5, teacher_forced) +[mac] sample 0: T=1639 pred[:48]='<|channel>ETA-1409' +[mac] sample 1: T=1380 pred[:48]='<|channel>-3286' +[mac] sample 2: T=1301 pred[:48]='<|channel>CHID-9935' +[mac] sample 3: T=1599 pred[:48]='<|channel>-1520' +[mac] sample 4: T=1280 pred[:48]='<|channel>-4811' +[mac] sample 5: T=1619 pred[:48]='<|channel>-4257' +[mac] sample 6: T=1500 pred[:48]='<|channel>-8359' +[mac] sample 7: T=1240 pred[:48]='<|channel>LE-3615' +[mac] sample 8: T=1500 pred[:48]='<|channel>ETA-5552' +[mac] sample 9: T=1360 pred[:48]='<|channel>LE-6514' +[mac] cross-model recall = 0.000 (0/10) +[mac] running oracle (full MLX forward) +[mac] sample 0: T=1639 pred[:48]='<|channel>ETA-1409' +[mac] sample 1: T=1380 
pred[:48]='<|channel>-3286' +[mac] sample 2: T=1301 pred[:48]='<|channel>CHID-9935' +[mac] sample 3: T=1599 pred[:48]='<|channel>-1520' +[mac] sample 4: T=1280 pred[:48]='<|channel>-4811' +[mac] sample 5: T=1619 pred[:48]='<|channel>-4257' +[mac] sample 6: T=1500 pred[:48]='<|channel>-8359' +[mac] sample 7: T=1240 pred[:48]='<|channel>LE-3615' +[mac] sample 8: T=1500 pred[:48]='<|channel>ETA-5552' +[mac] sample 9: T=1360 pred[:48]='<|channel>LE-6514' +[mac] oracle recall = 0.000 +[mac] KV resident @T=1639: S5=27.17 MB (growth 7.89 KB/tok); naive-full=369.23 MB +[mac] cross-model throughput (teacher_forced): 0.0931 tok/s (66 tok / 708.976 s, 70.898 s/sample) + +[mac] DONE. cross=0.000 oracle=0.0 -> results/research/k3_s5_kl_niah_ctx70_mac.json diff --git a/results/research/logs/k3_s5_niah_mac_smoke_timeout.log b/results/research/logs/k3_s5_niah_mac_smoke_timeout.log new file mode 100644 index 00000000..0d6fdd40 --- /dev/null +++ b/results/research/logs/k3_s5_niah_mac_smoke_timeout.log @@ -0,0 +1,16 @@ +Command: +source "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/bin/activate" && PYTHONPATH=.:sdks/python python3 scripts/research/k3_integrated_niah_eval_mac.py --f-theta-dir results/research/f_theta_v5_s5_sliding --s5-exact-full-attn --n-samples 4 --haystack-min-lines 60 --haystack-max-lines 81 --max-new-tokens 16 --output results/research/k3_s5_niah_mac_smoke.json + +Started: 2026-06-11T03:34:06.986Z +Stopped: 2026-06-11T03:49:33.104Z +Elapsed: 926.118s +Exit code: unknown (manually stopped after no sample output) + +[mac] loading MLX verifier models/gemma-4-26B-A4B-it-mlx-4bit +[mac] verifier layers=30 full_attn=[5, 11, 17, 23, 29] +[mac] loading drafter models/dflash-kakeya-baseline on mps +[mac] 4 samples, prompt len min=1301 max=1639 +[mac] running restored cross-model verifier (s5) + +No sample completed after roughly 15 minutes. Follow-up one-token diagnostic +showed a single restored cross-model token took about 112s on this Mac path. 
diff --git a/results/research/logs/k3_s5_niah_mac_step1_diag.log b/results/research/logs/k3_s5_niah_mac_step1_diag.log new file mode 100644 index 00000000..a88756ec --- /dev/null +++ b/results/research/logs/k3_s5_niah_mac_step1_diag.log @@ -0,0 +1,17 @@ +Command: +source "/Users/fluffy314/Documents/Kakeya-LLM-Inference-engine/.venv-mac/bin/activate" && PYTHONPATH=.:sdks/python python3 scripts/research/k3_integrated_niah_eval_mac.py --f-theta-dir results/research/f_theta_v5_s5_sliding --s5-exact-full-attn --n-samples 1 --haystack-min-lines 60 --haystack-max-lines 81 --max-new-tokens 1 --skip-oracle --output results/research/k3_s5_niah_mac_step1_diag.json + +Started: 2026-06-11T03:49:41.976Z +Ended: 2026-06-11T03:51:47.709Z +Elapsed: 125.733s +Exit code: 0 + +[mac] loading MLX verifier models/gemma-4-26B-A4B-it-mlx-4bit +[mac] verifier layers=30 full_attn=[5, 11, 17, 23, 29] +[mac] loading drafter models/dflash-kakeya-baseline on mps +[mac] 1 samples, prompt len min=1639 max=1639 +[mac] running restored cross-model verifier (s5) +[mac] sample 0: T=1639 -> '<|channel>' +[mac] cross-model recall = 0.000 (0/1) + +[mac] DONE. cross=0.000 oracle=skipped -> results/research/k3_s5_niah_mac_step1_diag.json diff --git a/scripts/bench_mlx_kakeya_deployment.py b/scripts/bench_mlx_kakeya_deployment.py new file mode 100644 index 00000000..0f410c38 --- /dev/null +++ b/scripts/bench_mlx_kakeya_deployment.py @@ -0,0 +1,294 @@ +"""Mac (MLX) high-performance deployment benchmark for the Kakeya engine. + +Goal: demonstrate, on Apple Silicon (M-series), that the Kakeya **sink+window +bounded-KV** inference path delivers *sustained high throughput at constant +memory* as context grows — vs a vanilla full-KV baseline whose KV cache (and +per-token attention cost) grow with context. + +This is the local-deployment benchmark for the **Gemma 4 26B-A4B** verifier on +a 24 GB M4: 4-bit weights are ~16 GB resident, leaving ~8 GB for KV + activations. 
+With a vanilla full-KV cache, per-token attention cost and KV memory grow with +context, so decode tok/s collapses and peak memory climbs toward the 24 GB +ceiling at long context. The Kakeya sink+window cache bounds both: persistent KV +is O(sink+window) and per-token attention is over the bounded window, so decode +throughput and peak memory stay ~flat as context grows. (Long-range *recall* +needs the separate K/V-Restoration path; this benchmark measures the throughput ++ memory envelope.) + +Both paths run through mlx_lm's **own** ``generate_step`` engine (chunked +prefill + pipelined async decode) — only the KV cache differs. This is the +apples-to-apples test: from first principles Kakeya is just MLX + a tighter +cache, so it should be *faster + lighter* than vanilla, never slower. If it is +slower, the cache implementation has a bug. For each context length L, on the +SAME model + SAME engine: + + * **Vanilla** — the model's native cache (``make_prompt_cache`` → + ``model.make_cache()``: full ``KVCache`` for the 5 global layers, + ``RotatingKVCache(sliding_window)`` for the 25 sliding layers). The 5 + global layers' KV grows with L; per-token attention there is over all L + keys. + * **Kakeya** — sink+window bounded cache (``make_sink_window_cache``) for + every layer: persistent KV is O(sink+window) and per-token attention is + over the bounded window for *all* layers (incl. the global ones). (Note: + this is the bounded-KV / StreamingLLM-class fast path — long-range *recall* + needs the separate K/V-Restoration; this benchmark measures the throughput + + memory envelope.) + +Reports, per L: time-to-first-token (incl. prefill), decode tok/s, resident KV +bytes, peak memory, and the kakeya/vanilla decode-speedup + KV-shrink ratios. 
+ +Run on the Mac (Apple Silicon): + + source .venv-mac/bin/activate # or your MLX venv + PYTHONPATH=.:sdks/python python3 scripts/bench_mlx_kakeya_deployment.py \ + --model-id models/gemma-4-26B-A4B-it-mlx-4bit \ + --context-lengths 512,2048,8192 \ + --gen-tokens 64 --sink-size 4 --window-size 64 \ + --output results/platform-tests/bench_mlx_kakeya_deployment.json + +The bounded-KV advantage grows with context: push --context-lengths higher +(e.g. 16384,32768) to widen the gap, as long as the vanilla full-KV prefill +still fits in memory. Use --skip-vanilla when vanilla would OOM at long context +so the Kakeya path can still be measured. +""" + +from __future__ import annotations + +import argparse +import json +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Tuple + + +def parse_args() -> argparse.Namespace: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--model-id", default="models/gemma-4-26B-A4B-it-mlx-4bit", + help="MLX 4-bit model id or local path (default: the " + "Gemma 4 26B-A4B 4-bit verifier).") + ap.add_argument("--context-lengths", default="512,2048,8192", + help="Comma-separated prompt token lengths to sweep.") + ap.add_argument("--skip-kakeya", action="store_true", + help="Skip the sink+window path (measure vanilla only).") + ap.add_argument("--gen-tokens", type=int, default=64) + ap.add_argument("--sink-size", type=int, default=4) + ap.add_argument("--window-size", type=int, default=64) + ap.add_argument("--skip-vanilla", action="store_true", + help="Skip the full-KV baseline (e.g. 
when it would OOM).") + ap.add_argument("--output", default=None) + return ap.parse_args() + + +def _peak_memory_bytes(mx) -> int: + for getter in ("get_peak_memory",): + fn = getattr(mx, getter, None) + if fn is not None: + try: + return int(fn()) + except Exception: + pass + metal = getattr(mx, "metal", None) + if metal is not None and hasattr(metal, "get_peak_memory"): + try: + return int(metal.get_peak_memory()) + except Exception: + pass + return -1 + + +def _reset_peak_memory(mx) -> None: + for name in ("reset_peak_memory",): + fn = getattr(mx, name, None) + if fn is not None: + try: + fn(); return + except Exception: + pass + metal = getattr(mx, "metal", None) + if metal is not None and hasattr(metal, "reset_peak_memory"): + try: + metal.reset_peak_memory() + except Exception: + pass + + +def _resident_kv_bytes(cache: list) -> int: + """Actual resident K+V bytes across a per-layer cache list. + + Uses each tensor's real ``.nbytes`` (the allocated buffer), so it is + correct and uniform across *every* cache type — full ``KVCache`` + (grows with context), ``RotatingKVCache`` (capped ring buffer) and our + ``SinkWindowKVCache`` (sink+window). This is the honest comparison: it + reflects what is physically held, not the unbounded global ``offset``. + """ + total = 0 + for c in cache: + for name in ("keys", "values"): + t = getattr(c, name, None) + nb = getattr(t, "nbytes", None) if t is not None else None + if nb is not None: + total += int(nb) + return total + + +def _run(mx, generate_step, model, prompt_ids: List[int], gen_tokens: int, + cache) -> Dict[str, Any]: + """Prefill + greedy-decode ``gen_tokens`` using mlx_lm's *native* + ``generate_step`` (chunked prefill + pipelined async decode), swapping + only the KV cache. This isolates the cache's effect on the native engine. + Returns timing + memory metrics. 
+ """ + _reset_peak_memory(mx) + prompt = mx.array(prompt_ids) + gen = generate_step(prompt, model, max_tokens=gen_tokens, prompt_cache=cache) + t0 = time.perf_counter() + first = next(gen) # prefill + first decoded token + _ = first[0] # already an int (generate_step yields y.item()) + ttft_s = time.perf_counter() - t0 + n = 0 + t1 = time.perf_counter() + for _tok, _lp in gen: + n += 1 + decode_s = time.perf_counter() - t1 + return { + "ttft_s": round(ttft_s, 4), # time to first token (incl. prefill) + "decode_s": round(decode_s, 4), + "decode_tokens": n, # tokens after the first + "decode_tokens_per_s": round(n / decode_s, 3) if decode_s > 0 and n > 0 else None, + "kv_bytes": int(_resident_kv_bytes(cache)), + "peak_memory_bytes": _peak_memory_bytes(mx), + } + + +def main() -> int: + args = parse_args() + + import mlx.core as mx # type: ignore + import mlx_lm # type: ignore + from mlx_lm.models.cache import make_prompt_cache # type: ignore + from mlx_lm.generate import generate_step # type: ignore + from inference_engine.backends.mlx.cache import make_sink_window_cache + + ctx_lengths = [int(x) for x in args.context_lengths.split(",") if x.strip()] + print(f"[bench] loading MLX model {args.model_id}", file=sys.stderr, flush=True) + model, tokenizer = mlx_lm.load(args.model_id) + + # A deterministic synthetic prompt of a given length (content is + # irrelevant for the throughput/memory envelope; we use a fixed filler + # token so prefill length == L). 
+ bos = getattr(tokenizer, "bos_token_id", None) + filler = tokenizer.encode("the ") + filler_tok = filler[-1] if filler else 1 + + def make_prompt(L: int) -> List[int]: + ids = ([bos] if bos is not None else []) + [filler_tok] * (L - (1 if bos is not None else 0)) + return ids[:L] if len(ids) >= L else ids + [filler_tok] * (L - len(ids)) + + def make_vanilla_cache(): + return make_prompt_cache(model) + + def make_kakeya_cache(): + return make_sink_window_cache( + model, sink_size=args.sink_size, window_size=args.window_size) + + # Warm up MLX kernel compilation for BOTH cache paths before timing. + # MLX compiles graphs lazily on first use; without this, whichever path + # runs first absorbs the (large, one-off) compile cost and looks slower. + # The 1-token decode graph compiled here is shared across all context + # lengths, so decode tok/s is measured fairly for both caches. + warm_prompt = make_prompt(64) + for label, mk in (("vanilla", make_vanilla_cache), ("kakeya", make_kakeya_cache)): + if (label == "vanilla" and args.skip_vanilla) or ( + label == "kakeya" and args.skip_kakeya): + continue + print(f"[bench] warmup ({label}) ...", file=sys.stderr, flush=True) + try: + wc = mk() + for _ in generate_step(mx.array(warm_prompt), model, + max_tokens=8, prompt_cache=wc): + pass + wc = None + mx.clear_cache() + except Exception as e: + print(f"[bench] warmup ({label}) failed: {e}", file=sys.stderr) + + rows: List[Dict[str, Any]] = [] + for L in ctx_lengths: + prompt_ids = make_prompt(L) + row: Dict[str, Any] = {"context_length": L} + if not args.skip_vanilla: + print(f"[bench] L={L}: vanilla (native make_prompt_cache) ...", + file=sys.stderr, flush=True) + try: + vcache = make_vanilla_cache() + row["vanilla"] = _run( + mx, generate_step, model, prompt_ids, args.gen_tokens, vcache) + except Exception as e: # OOM or unsupported → record and continue + row["vanilla"] = {"error": f"{type(e).__name__}: {e}"} + print(f"[bench] L={L}: vanilla path failed: {e}", 
file=sys.stderr) + finally: + # Free the (possibly large) vanilla KV before measuring kakeya, + # so its peak-memory reading isn't inflated by leftover state. + vcache = None + mx.clear_cache() + + if not args.skip_kakeya: + print(f"[bench] L={L}: Kakeya sink+window ...", file=sys.stderr, flush=True) + try: + kcache = make_kakeya_cache() + row["kakeya"] = _run( + mx, generate_step, model, prompt_ids, args.gen_tokens, kcache) + except Exception as e: + row["kakeya"] = {"error": f"{type(e).__name__}: {e}"} + print(f"[bench] L={L}: kakeya path failed: {e}", file=sys.stderr) + finally: + kcache = None + mx.clear_cache() + + k = row.get("kakeya", {}) + v = row.get("vanilla", {}) + k_ok = isinstance(k, dict) and "decode_tokens_per_s" in k + v_ok = isinstance(v, dict) and "decode_tokens_per_s" in v + if k_ok and v_ok: + sp = (k["decode_tokens_per_s"] or 0) / max(v["decode_tokens_per_s"] or 1e-9, 1e-9) + row["kakeya_vs_vanilla"] = { + "decode_speedup_x": round(sp, 3), + "kv_bytes_ratio_x": round(v.get("kv_bytes", 0) / max(k.get("kv_bytes", 1), 1), 1), + } + if v_ok: + print(f"[bench] L={L}: vanilla {v['decode_tokens_per_s']} tok/s " + f"(ttft {v['ttft_s']}s, KV {v['kv_bytes']/1e6:.2f} MB, " + f"peak {v['peak_memory_bytes']/1e9:.2f} GB)", file=sys.stderr) + if k_ok: + print(f"[bench] L={L}: kakeya {k['decode_tokens_per_s']} tok/s " + f"(ttft {k['ttft_s']}s, KV {k['kv_bytes']/1e6:.2f} MB, " + f"peak {k['peak_memory_bytes']/1e9:.2f} GB)", file=sys.stderr) + if k_ok and v_ok: + r = row["kakeya_vs_vanilla"] + print(f"[bench] L={L}: kakeya vs vanilla -> decode {r['decode_speedup_x']}x, " + f"KV {r['kv_bytes_ratio_x']}x smaller", file=sys.stderr) + rows.append(row) + + report = { + "kind": "mlx_kakeya_deployment_benchmark", + "config": { + "model_id": args.model_id, + "context_lengths": ctx_lengths, + "gen_tokens": args.gen_tokens, + "sink_size": args.sink_size, + "window_size": args.window_size, + }, + "env": {"mlx_version": getattr(mx, "__version__", "?")}, + "results": rows, + 
} + out_path = Path(args.output) if args.output else Path( + f"results/platform-tests/bench_mlx_kakeya_deployment_{int(time.time())}.json") + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(report, indent=2)) + print(f"\n[bench] wrote {out_path}", file=sys.stderr) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/research/k3_dflash_specdecode_eval.py b/scripts/research/k3_dflash_specdecode_eval.py index 1e863632..0ec11e60 100644 --- a/scripts/research/k3_dflash_specdecode_eval.py +++ b/scripts/research/k3_dflash_specdecode_eval.py @@ -54,6 +54,19 @@ "Write a haiku about speculative decoding.", ] +# HumanEval-style code-generation prompts — the regime the z-lab DFlash +# reference (~0.447 / 7.7) is measured on. DFlash drafts code/structured +# output far better than open-ended short Q&A, so this set characterizes +# acceptance on the reference's distribution. +CODE_PROMPTS = [ + "Complete this Python function:\n\ndef has_close_elements(numbers: list[float], threshold: float) -> bool:\n \"\"\"Return True if any two numbers are closer than threshold.\"\"\"\n", + "Complete this Python function:\n\ndef is_palindrome(s: str) -> bool:\n \"\"\"Return True if s reads the same forwards and backwards, ignoring case and non-alphanumeric chars.\"\"\"\n", + "Complete this Python function:\n\ndef merge_sort(arr: list[int]) -> list[int]:\n \"\"\"Return a new list with the elements of arr sorted ascending using merge sort.\"\"\"\n", + "Complete this Python function:\n\ndef gcd(a: int, b: int) -> int:\n \"\"\"Return the greatest common divisor of a and b using the Euclidean algorithm.\"\"\"\n", + "Complete this Python function:\n\ndef flatten(nested: list) -> list:\n \"\"\"Flatten an arbitrarily nested list of integers into a single flat list.\"\"\"\n", + "Complete this Python function:\n\ndef count_words(text: str) -> dict[str, int]:\n \"\"\"Return a dict mapping each lowercased word in text to its 
frequency.\"\"\"\n", +] + # Disjoint from the alignment trainer's prompt corpus — for honest held-out # acceptance after alignment training (no topic/phrasing near-duplicates). HELD_OUT_PROMPTS = [ @@ -91,14 +104,17 @@ def aux_hidden_context(self, committed_token_ids): return aux, bonus -def _build_embed_lm_head(model, hidden_size, softcap): +def _build_embed_lm_head(model, hidden_size, softcap, embed_scale=None): emb = model.get_input_embeddings() head = model.get_output_embeddings() - scale = math.sqrt(hidden_size) + # Reference DFlashQwen3Model.forward embeds with a PLAIN lookup + # (``self.embed_tokens(input_ids)``) — no Gemma ``×sqrt(hidden)`` + # normalizer (that scale is applied inside the Gemma model body, not in + # the shared embed the Qwen3 drafter consumes). Default to no scale to + # match the reference; ``embed_scale`` overrides for A/B testing. + scale = 1.0 if embed_scale is None else float(embed_scale) def embed_fn(ids: torch.Tensor) -> torch.Tensor: - # Gemma scales token embeddings by sqrt(hidden) (PR #41703: DFlash - # draft path applies the target embedding normalization). return emb(ids).float() * scale def lm_head_fn(h: torch.Tensor) -> torch.Tensor: @@ -156,12 +172,38 @@ def main() -> int: ap.add_argument("--drafter-state", default=None, help="optional .pt state_dict to load over the drafter " "(e.g. an alignment-trained checkpoint).") + ap.add_argument("--embed-scale", type=float, default=None, + help="Scale applied to the shared embedding fed to the " + "drafter. Default None = 1.0 (reference, no Gemma " + "sqrt(hidden) normalizer). Pass e.g. 53.06 to A/B the " + "old (incorrect) sqrt(2816) scaling.") ap.add_argument("--held-out", action="store_true", help="evaluate on HELD_OUT_PROMPTS (disjoint from the " "alignment trainer's prompts) for honest generalization.") + ap.add_argument("--prompt-set", choices=["default", "held-out", "code"], + default=None, + help="Which prompt set to use. 
'code' = HumanEval-style " + "(the z-lab reference regime). Overrides --held-out.") + ap.add_argument("--humaneval-jsonl", default=None, + help="Path to the canonical HumanEval .jsonl (each line a " + "problem with a 'prompt' field). Uses the first " + "--n-prompts problems' prompts. This is the exact " + "z-lab reference regime (~0.447 / 7.7).") + ap.add_argument("--raw-completion", action="store_true", + help="Feed the raw prompt tokens (no chat template) — the " + "native HumanEval code-completion setup.") ap.add_argument("--output", default=None) args = ap.parse_args() - prompts = HELD_OUT_PROMPTS if args.held_out else PROMPTS + if args.humaneval_jsonl: + with open(args.humaneval_jsonl) as fh: + rows = [json.loads(line) for line in fh if line.strip()] + prompts = [r["prompt"] for r in rows[: args.n_prompts]] + elif args.prompt_set == "code": + prompts = CODE_PROMPTS + elif args.prompt_set == "held-out" or args.held_out: + prompts = HELD_OUT_PROMPTS + else: + prompts = PROMPTS device = torch.device("cuda") dtype = torch.bfloat16 @@ -183,7 +225,8 @@ def main() -> int: cfg = drafter.cfg hidden = cfg.hidden_size softcap = cfg.final_logit_softcapping - embed_fn, lm_head_fn = _build_embed_lm_head(verifier, hidden, softcap) + embed_fn, lm_head_fn = _build_embed_lm_head( + verifier, hidden, softcap, embed_scale=args.embed_scale) provider = VerifierAuxProvider(verifier, cfg.aux_layer_ids, device) proposer = DFlashProposer(drafter, provider, embed_fn, lm_head_fn) @@ -198,14 +241,19 @@ def main() -> int: for pi in range(min(args.n_prompts, len(prompts))): prompt = prompts[pi] - msgs = [{"role": "user", "content": prompt}] - enc = tok.apply_chat_template( - msgs, add_generation_prompt=True, tokenize=True, return_tensors="pt", - ) - # transformers 5.x may return a Tensor or a BatchEncoding/dict. 
- if hasattr(enc, "keys"): - enc = enc["input_ids"] - ids = enc[0].tolist() + if args.raw_completion: + # Native HumanEval code-completion: feed the raw prompt tokens + # (no chat template); the verifier continues the function body. + ids = tok(prompt, return_tensors="pt").input_ids[0].tolist() + else: + msgs = [{"role": "user", "content": prompt}] + enc = tok.apply_chat_template( + msgs, add_generation_prompt=True, tokenize=True, return_tensors="pt", + ) + # transformers 5.x may return a Tensor or a BatchEncoding/dict. + if hasattr(enc, "keys"): + enc = enc["input_ids"] + ids = enc[0].tolist() committed = list(ids) generated: List[int] = [] blk_accepts = [] diff --git a/scripts/research/k3_dflash_specdecode_eval_mac.py b/scripts/research/k3_dflash_specdecode_eval_mac.py index 72f279df..fc38f584 100644 --- a/scripts/research/k3_dflash_specdecode_eval_mac.py +++ b/scripts/research/k3_dflash_specdecode_eval_mac.py @@ -146,7 +146,7 @@ def main() -> int: help="Local MLX 4-bit verifier directory (default: standard Mac path).", ) ap.add_argument( - "--drafter-id", default="models/dflash-kakeya-baseline", + "--drafter-id", default="z-lab/gemma-4-26B-A4B-it-DFlash", help="DFlash drafter source — local path or HF id. Default: the " "alignment-trained baseline on main (post PR #93 + #99 merge).", ) diff --git a/scripts/research/k3_e2e_gpu_bench.py b/scripts/research/k3_e2e_gpu_bench.py new file mode 100644 index 00000000..ff473c8b --- /dev/null +++ b/scripts/research/k3_e2e_gpu_bench.py @@ -0,0 +1,333 @@ +"""K3 end-to-end GPU benchmark — Kakeya restored verifier vs standalone AR. 
+ +Runs the served Kakeya inference path (the Gap 1 + Gap 2 +``CrossModelRestoredSinkWindowVerifier``: f_θ + S5 K/V Restoration over a +bounded sink+window cache) and the standalone Gemma 4 26B-A4B AR model on +the *same* NIAH prompts, and reports, per context rung: + + * **Memory** — resident KV bytes (restored = bounded sink+window; + AR = the model's own HF cache, which grows with context) + peak GPU. + * **Throughput** — decode tokens/s (excludes prefill). + * **Verifier attention context length** — restored: resident *window* + (sink+window) vs *effective* context (full prompt+gen, reconstructed + via restoration); AR: full resident context. + * **Recall** — fraction of NIAH needles recalled (correctness check + that the served restored path still answers). + +Run on a CUDA host (e.g. H200) inside the transformers-5.x venv:: + + HF_HOME=/workspace/.hf_home PYTHONPATH=.:sdks/python \ + .venv-k3/bin/python scripts/research/k3_e2e_gpu_bench.py \ + --verifier-id google/gemma-4-26B-A4B-it \ + --drafter-id z-lab/gemma-4-26B-A4B-it-DFlash \ + --f-theta-dir results/research/f_theta_v5_s5_sliding \ + --haystack-lines 60,160 --n-samples 3 --gen-tokens 16 \ + --output results/research/k3_e2e_gpu_bench.json +""" + +from __future__ import annotations + +import argparse +import json +import sys +import time +from pathlib import Path +from typing import Any, Dict, List + +import torch + + +def _kv_bytes_from_cache(cache: Any) -> int: + """Sum K+V bytes across an HF cache (DynamicCache / hybrid).""" + total = 0 + layers = getattr(cache, "layers", None) + if layers is None: + # Legacy tuple-of-tuples cache. 
+ try: + for k, v in cache: + total += k.numel() * k.element_size() + total += v.numel() * v.element_size() + return total + except TypeError: + return 0 + for layer in layers: + for name in ("keys", "values"): + t = getattr(layer, name, None) + if t is not None and hasattr(t, "numel"): + total += t.numel() * t.element_size() + return total + + +def _peak_gpu(device) -> int: + try: + return int(torch.cuda.max_memory_allocated(device)) + except Exception: + return -1 + + +@torch.no_grad() +def run_ar(model, ids_list, samples, gen_tokens, tokenizer, device) -> Dict[str, Any]: + """Standalone AR: incremental decode with the model's own KV cache.""" + n = len(ids_list) + tot_tok = 0 + tot_t = 0.0 + hits = 0 + kv_bytes = 0 + peak = 0 + prefill_t = 0.0 + for i, ids in enumerate(ids_list): + torch.cuda.reset_peak_memory_stats(device) + t_pf = time.perf_counter() + out = model(input_ids=ids, use_cache=True) + torch.cuda.synchronize(device) + prefill_t += time.perf_counter() - t_pf + cache = out.past_key_values + nxt = int(out.logits[0, -1].argmax().item()) + gen_ids: List[int] = [] + cur = torch.tensor([[nxt]], device=device, dtype=torch.long) + t0 = time.perf_counter() + for _ in range(gen_tokens): + gen_ids.append(nxt) + out = model(input_ids=cur, past_key_values=cache, use_cache=True) + cache = out.past_key_values + nxt = int(out.logits[0, -1].argmax().item()) + cur = torch.tensor([[nxt]], device=device, dtype=torch.long) + torch.cuda.synchronize(device) + tot_t += time.perf_counter() - t0 + tot_tok += len(gen_ids) + txt = tokenizer.decode(gen_ids, skip_special_tokens=True) + if samples[i].answer_text in txt: + hits += 1 + kv_bytes = _kv_bytes_from_cache(cache) # last sample (full context) + peak = max(peak, _peak_gpu(device)) + print(f"[e2e] AR sample {i}: gen={len(gen_ids)} " + f"recall={'Y' if samples[i].answer_text in txt else 'N'} " + f"out[:40]={txt[:40]!r}", file=sys.stderr, flush=True) + return { + "decode_tokens_per_s": round(tot_tok / tot_t, 3) if tot_t > 0 
else None, + "prefill_s_mean": round(prefill_t / n, 4), + "kv_bytes_final": kv_bytes, + "peak_mem_bytes": peak, + "recall": round(hits / n, 3), + "decode_tokens": tot_tok, + } + + +@torch.no_grad() +def run_restored(adapter, ids_list, samples, gen_tokens, tokenizer, device) -> Dict[str, Any]: + """Kakeya restored path: bounded sink+window cache + f_θ/S5 restoration.""" + n = len(ids_list) + tot_tok = 0 + tot_t = 0.0 + hits = 0 + peak = 0 + resident_kv = 0 + eff_ctx = 0 + prefill_t = 0.0 + for i, ids in enumerate(ids_list): + torch.cuda.reset_peak_memory_stats(device) + prompt = ids[0].tolist() + t_pf = time.perf_counter() + adapter.prefill(prompt) + torch.cuda.synchronize(device) + prefill_t += time.perf_counter() - t_pf + nxt = int(adapter.next_token_logits.argmax().item()) + gen_ids: List[int] = [] + t0 = time.perf_counter() + for _ in range(gen_tokens): + gen_ids.append(nxt) + adapter.append_token(nxt) + nxt = int(adapter.next_token_logits.argmax().item()) + torch.cuda.synchronize(device) + tot_t += time.perf_counter() - t0 + tot_tok += len(gen_ids) + txt = tokenizer.decode(gen_ids, skip_special_tokens=True) + if samples[i].answer_text in txt: + hits += 1 + resident_kv = adapter.live_kv_bytes() + eff_ctx = max(eff_ctx, len(adapter._committed)) + peak = max(peak, _peak_gpu(device)) + print(f"[e2e] restored sample {i}: gen={len(gen_ids)} " + f"recall={'Y' if samples[i].answer_text in txt else 'N'} " + f"resident_kv_tok={adapter.sink_size + adapter.window_size} " + f"eff_ctx={len(adapter._committed)} " + f"out[:40]={txt[:40]!r}", file=sys.stderr, flush=True) + return { + "decode_tokens_per_s": round(tot_tok / tot_t, 3) if tot_t > 0 else None, + "prefill_s_mean": round(prefill_t / n, 4), + "resident_kv_bytes": resident_kv, + "resident_window_tokens": adapter.sink_size + adapter.window_size, + "effective_context_tokens": eff_ctx, + "peak_mem_bytes": peak, + "recall": round(hits / n, 3), + "decode_tokens": tot_tok, + } + + +def main() -> int: + ap = 
argparse.ArgumentParser(description=__doc__) + ap.add_argument("--verifier-id", default="google/gemma-4-26B-A4B-it") + ap.add_argument("--drafter-id", default="z-lab/gemma-4-26B-A4B-it-DFlash") + ap.add_argument("--f-theta-dir", default="results/research/f_theta_v5_s5_sliding") + ap.add_argument("--haystack-lines", default="60,160", + help="Comma-separated haystack line counts (context rungs).") + ap.add_argument("--n-samples", type=int, default=3) + ap.add_argument("--gen-tokens", type=int, default=16) + ap.add_argument("--sink", type=int, default=4) + ap.add_argument("--window", type=int, default=64) + ap.add_argument("--seed", type=int, default=0) + ap.add_argument("--incremental", action="store_true", + help="Use the incremental-decode restored path (capture " + "restored K/V at prefill, then native O(L)/block " + "decode) instead of the O(T) re-forward per step.") + ap.add_argument("--output", default=None) + args = ap.parse_args() + + if not torch.cuda.is_available(): + print("[e2e] CUDA not available — this benchmark requires a GPU.", + file=sys.stderr) + return 2 + device = torch.device("cuda") + dtype = torch.bfloat16 + + from transformers import AutoModelForCausalLM, AutoTokenizer + from transformers.models.gemma4.modeling_gemma4 import ( # type: ignore + ALL_ATTENTION_FUNCTIONS, apply_rotary_pos_emb, eager_attention_forward, + ) + from inference_engine.v04 import ( + CrossModelRestoredSinkWindowVerifier, DFlashDrafter, FThetaProjection, + make_niah_dataset, + ) + from inference_engine.v04.cross_model_dlm_verifier import ( + CrossModelDLMRestoredVerifier, full_attention_layer_indices, + resolve_text_config, + ) + + print(f"[e2e] loading verifier {args.verifier_id}", file=sys.stderr, flush=True) + tokenizer = AutoTokenizer.from_pretrained(args.verifier_id) + verifier = AutoModelForCausalLM.from_pretrained( + args.verifier_id, dtype=dtype, attn_implementation="eager", + device_map="auto", + ).eval() + for p in verifier.parameters(): + 
p.requires_grad_(False) + + print(f"[e2e] loading drafter {args.drafter_id}", file=sys.stderr, flush=True) + drafter = DFlashDrafter.from_pretrained(args.drafter_id, dtype=dtype).to(device).eval() + for p in drafter.parameters(): + p.requires_grad_(False) + + print(f"[e2e] loading f_θ {args.f_theta_dir}", file=sys.stderr, flush=True) + f_theta = FThetaProjection.from_pretrained( + args.f_theta_dir, dtype=torch.float32, device=device, + ) + + exact_layers = full_attention_layer_indices(verifier) + print(f"[e2e] S5 exact full-attention layers: {exact_layers}", file=sys.stderr) + restored = CrossModelDLMRestoredVerifier( + verifier_model=verifier, drafter=drafter, f_theta=f_theta, + sink_size=args.sink, window_size=args.window, + exact_layer_indices=exact_layers, + ) + adapter = CrossModelRestoredSinkWindowVerifier( + restored, + apply_rotary_pos_emb=apply_rotary_pos_emb, + eager_attention_forward=eager_attention_forward, + all_attention_functions=ALL_ATTENTION_FUNCTIONS, + device="cuda", + incremental=args.incremental, + ) + print(f"[e2e] restored adapter incremental={args.incremental}", file=sys.stderr) + + v_cfg = resolve_text_config(verifier.config) + verifier_dims = { + "num_hidden_layers": int(getattr(v_cfg, "num_hidden_layers", 0)), + "num_key_value_heads": int(getattr(v_cfg, "num_key_value_heads", 0) or 0), + "head_dim": int(getattr(v_cfg, "head_dim", 0) or 0), + "sliding_window": int(getattr(v_cfg, "sliding_window", 0) or 0), + } + + def encode_chat(text: str) -> torch.Tensor: + ids = tokenizer.apply_chat_template( + [{"role": "user", "content": text}], + add_generation_prompt=True, tokenize=True, return_tensors="pt", + ) + if hasattr(ids, "keys"): + ids = ids["input_ids"] + elif isinstance(ids, list): + ids = torch.tensor([ids]) + return ids.to(device) + + rungs = [int(x) for x in args.haystack_lines.split(",") if x.strip()] + rows: List[Dict[str, Any]] = [] + for lines in rungs: + samples = make_niah_dataset( + n_samples=args.n_samples, + 
haystack_min_lines=lines, haystack_max_lines=lines, seed=args.seed, + ) + ids_list = [encode_chat(s.prompt_text) for s in samples] + seqlens = [int(t.size(1)) for t in ids_list] + print(f"\n[e2e] === rung: {lines} haystack lines | prompt tokens " + f"min={min(seqlens)} max={max(seqlens)} ===", file=sys.stderr, flush=True) + + print("[e2e] running standalone AR baseline ...", file=sys.stderr, flush=True) + ar = run_ar(verifier, ids_list, samples, args.gen_tokens, tokenizer, device) + print("[e2e] running Kakeya restored path ...", file=sys.stderr, flush=True) + rs = run_restored(adapter, ids_list, samples, args.gen_tokens, tokenizer, device) + + kv_saving = (ar["kv_bytes_final"] / rs["resident_kv_bytes"] + if rs["resident_kv_bytes"] else None) + ctx_compression = (rs["effective_context_tokens"] / rs["resident_window_tokens"] + if rs["resident_window_tokens"] else None) + row = { + "haystack_lines": lines, + "prompt_tokens": {"min": min(seqlens), "max": max(seqlens)}, + "ar": ar, + "restored": rs, + "comparison": { + "kv_memory_saving_x": round(kv_saving, 1) if kv_saving else None, + "ar_kv_mb": round(ar["kv_bytes_final"] / 1e6, 2), + "restored_resident_kv_mb": round(rs["resident_kv_bytes"] / 1e6, 2), + "context_compression_x": round(ctx_compression, 1) if ctx_compression else None, + "throughput_ratio_restored_over_ar": ( + round(rs["decode_tokens_per_s"] / ar["decode_tokens_per_s"], 3) + if ar["decode_tokens_per_s"] else None + ), + }, + } + rows.append(row) + c = row["comparison"] + print(f"[e2e] rung {lines}: KV {c['ar_kv_mb']}MB(AR) vs " + f"{c['restored_resident_kv_mb']}MB(restored) -> {c['kv_memory_saving_x']}x saving; " + f"ctx {rs['effective_context_tokens']} tok over {rs['resident_window_tokens']}-tok window " + f"({c['context_compression_x']}x); " + f"tok/s restored={rs['decode_tokens_per_s']} ar={ar['decode_tokens_per_s']}; " + f"recall restored={rs['recall']} ar={ar['recall']}", file=sys.stderr, flush=True) + + report = { + "kind": "k3_e2e_gpu_bench", + 
"config": { + "verifier_id": args.verifier_id, + "drafter_id": args.drafter_id, + "f_theta_dir": args.f_theta_dir, + "sink_size": args.sink, "window_size": args.window, + "gen_tokens": args.gen_tokens, "n_samples": args.n_samples, + "haystack_lines": rungs, + }, + "verifier_dims": verifier_dims, + "env": { + "gpu": torch.cuda.get_device_name(0), + "torch": torch.__version__, + }, + "results": rows, + } + out_path = Path(args.output) if args.output else Path( + f"results/research/k3_e2e_gpu_bench_{int(time.time())}.json") + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(report, indent=2)) + print(f"\n[e2e] wrote {out_path}", file=sys.stderr) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/research/k3_f_theta_train.py b/scripts/research/k3_f_theta_train.py new file mode 100644 index 00000000..86f0d3a5 --- /dev/null +++ b/scripts/research/k3_f_theta_train.py @@ -0,0 +1,1604 @@ +"""K3 Block C — Train ``f_θ`` K/V projection: drafter K/V → verifier K/V. + +v3 (2026-06-10) — one-shot principled trainer, attention-output distillation +=========================================================================== + +PR #103 v1 evidence: identity-restore recall = 1.0 (machinery correct); +f_θ-projected recall = 0.0 (training inadequate). + +Per user request 2026-06-10: "一步到位,不要中间态" — skip the v2 +intermediate (cos+mag) and ship the principled fix directly. + +The ONE-SHOT principled fix +--------------------------- + +**Attention-output distillation loss** (the v3 default +``--loss-type attn_distill``). 
For each verifier layer ℓ: + + K_pred_ℓ, V_pred_ℓ = f_θ(drafter_KV)[ℓ] + + Q_for_attn = q_norm(Q_raw_ℓ).view(B, T, H_q, D) → RoPE → transpose + K_for_attn = k_norm(K_pred_ℓ).view(B, T, H_kv, D) → RoPE → transpose + V_for_attn = v_norm(V_pred_ℓ).view(B, T, H_kv, D) → transpose + + O_pred_ℓ = o_proj(scaled_dot_product_attention(Q, K, V, mask, scale)) + + loss_ℓ = MSE(O_pred_ℓ, O_tgt_ℓ) # O_tgt is the verifier's + actual attn output captured + during data collection + + Total = mean over layers + +This is the **mathematically right loss for K/V projection**. It directly +optimises "f_θ-injected K/V produces equivalent verifier attention output", +accounting for: GQA grouping, RoPE positional encoding, causal/sliding +mask, k_norm/q_norm/v_norm, AND the layer's o_proj. Unlike pure MSE +(v1) or cos+mag (v2), this loss exposes the gradient to the actual +quantity that propagates through the residual stream at inference. + +To make this affordable on H200, data collection caches per layer per +sequence (Q_raw, O_tgt, cos, sin, attention_mask) on CPU bf16; training +streams these to GPU per step. Verifier forward is run ONCE per +sequence (not per training step). For 64 sequences × 30 layers × T=512, +cache is ~25 GB CPU RAM (fits comfortably). + +Three additional changes (carried over from v2 design) +------------------------------------------------------ + + (a) **Larger f_θ rank**: default 256 → 768 for ``attn_distill`` + (more capacity at the encoder bottleneck; ~88M params total + vs v1's 32M). Legacy losses keep rank=256. + + (b) **NIAH-style synthetic training prompts**: 64 prompts (50% of + corpus) match the eval's haystack+needle pattern with + independent seeds, so f_θ sees retrieval structure at training. + + (c) **Cosine LR schedule + 20000 steps**: linear warmup (500 steps) + then cosine decay to peak/100. v1's 4000 constant-lr steps was + grossly undertrained (59 s of training). 
+ +Reproducibility +--------------- + +v1 reproduction: + --loss-type mse --steps 4000 --gen-len 128 --lr-schedule const + --no-niah-prompts --rank 256 +v2 reproduction: + --loss-type combined --steps 20000 --gen-len 512 --lr-schedule cosine + (default in v2 — see git log of this file pre-v3) +v3 (default): --loss-type attn_distill (everything above tuned for it) + +Legacy notes (v1 / v2) +---------------------- + +v1 training is reproducible by passing +``--loss-type mse --steps 4000 --gen-len 128 --lr-schedule const +--no-niah-prompts``. + +The v2 defaults (pre-v3) were tuned for converging f_θ to a checkpoint that closes +the integrated NIAH gate (recall_delta_vs_oracle ≤ 5pp). + +Pipeline (CUDA, vast.ai H200/H100): + + 1. Load Gemma 4 26B-A4B verifier (transformers, bf16, sdpa) + 2. Load DFlash drafter (PR #93's DFlashDrafter.from_pretrained, + using models/dflash-kakeya-baseline) + 3. Build training corpus: + a. PROMPTS list (general / code / math / facts / creative) + b. (v2) synthetic NIAH-style prompts (haystack with random + marker_id + question — same pattern as the eval but with + independent seeds → no test contamination) + 4. For each training sequence in the corpus: + a. Run verifier forward; record K/V at every layer at every position + (extracted via attention forward hooks on each layer's k_proj + / v_proj — pre-norm pre-RoPE, matching what the cross-model + DLMRestoredVerifier needs to inject) + b. Run drafter forward via capture_proposer_kv; KVCapture has + K/V at every drafter layer at every position (pre-norm pre-RoPE) + c. f_θ targets: f_θ(drafter_kv) ≈ verifier_kv + 5.
Train f_θ with the configured loss (default attn_distill; v2 cosine+mag + combined and v1 mse-only via --loss-type), AdamW + cosine LR schedule + +Requires: + * HF_TOKEN (Gemma 4 is gated) + * transformers >= 5.0 (Gemma 4 support) + * drafter checkpoint at models/dflash-kakeya-baseline/ + +Outputs: + * Trained f_θ checkpoint at --save (default: results/research/f_theta/) + Format: f_theta_config.json + f_theta_weights.pt (per + FThetaProjection.save_pretrained contract) + * Training report (.json) with config, final_loss, + per-layer-loss breakdown, elapsed time + +Usage: + HF_TOKEN=hf_xxx PYTHONPATH=.:sdks/python python3 \ + scripts/research/k3_f_theta_train.py \ + --steps 4000 --lr 1e-3 --rank 256 --batch-prompts 4 --seq-len 512 \ + --save results/research/f_theta_v1 + +The training set is the same prompts the alignment_train.py corpus +uses (PR #93's PROMPTS list, expanded if --extended-corpus). Each +training step: pick a random sequence from the cache, sample a +random window of ``seq_len`` positions, compute f_θ predictions vs +verifier targets at those positions, MSE loss, AdamW step. + +Memory budget: + * verifier 26B bf16: ~52 GB (needs an 80 GB+ GPU such as H100/H200, or multi-GPU) + * drafter 0.43B bf16: ~0.9 GB + * f_θ rank=256 fp32: ~130 MB (tiny vs everything else) + * verifier K/V cache for 1 sequence at T=512: + 30 layers × 512 × 2048 × 2 (K+V) × 2 bytes = ~125 MB + * drafter K/V cache for 1 sequence at T=512: + 5 layers × 512 × 256 × 2 × 2 = ~2.5 MB + * Training takes a few hundred GB of K/V cache across the corpus; + we keep K/V in fp16 on GPU and stream from CPU when corpus > GPU. + +For the K3 first training run, we start with the same 64-prompt corpus +PR #93 used, ~512 tokens generated per prompt. Total cache: ~64 × 125 +MB = ~8 GB, fits in GPU memory comfortably. + +Validation gate: + * Final MSE loss ≤ 0.5× the initial random-init loss (proves f_θ + learned something meaningful; the 0.5× threshold is conservative + — actual converged f_θ will be much lower).
+ * Per-layer loss should be roughly uniform; outliers indicate + layer-specific issues that need investigation. + +After training, the cross-model DLMRestoredVerifier loads this +checkpoint and uses it for K/V Restoration in the integrated +Kakeya inference loop. +""" + +from __future__ import annotations + +import argparse +import json +import math +import random +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn.functional as F + +from inference_engine.v04.f_theta import FThetaConfig, FThetaProjection +from inference_engine.v04.cross_model_dlm_verifier import ( + _capture_drafter_kv, + get_verifier_decoder, + resolve_text_config, +) +from inference_engine.v04.dflash_drafter import DFlashDrafter + + +# Same training prompt corpus as PR #93's k3_dflash_alignment_train.py +# — direct comparability of evidence. +PROMPTS = [ + "Write a Python function that returns the n-th Fibonacci number.", + "Write a Python function to reverse a linked list.", + "Implement binary search in Python with comments.", + "Write a function to compute the factorial of n iteratively.", + "Write a Python class for a simple stack with push/pop/peek.", + "Implement quicksort in Python.", + "Write a regex that matches a valid IPv4 address and explain it.", + "Write a Python decorator that times a function and prints the duration.", + "Implement a function to merge two sorted lists.", + "Write a SQL query to find the second-highest salary in an Employees table.", + "Write a bash one-liner to count lines in all .py files under a directory.", + "Implement a debounce function in JavaScript.", + "Write a Python generator that yields prime numbers.", + "Explain and implement memoization for a recursive Fibonacci.", + "Write a function to detect a cycle in a directed graph.", + "Implement a least-recently-used (LRU) cache in Python.", + "Compute the sum of the first 100 
positive integers and show your reasoning.", + "If a train travels 60 km in 45 minutes, what is its speed in km/h?", + "Solve for x: 3x + 7 = 22.", + "What is the derivative of x^3 + 2x with respect to x?", + "Explain the Pythagorean theorem with an example.", + "A bag has 3 red and 2 blue balls; what is the probability of drawing red?", + "List the first eight powers of two.", + "Explain why the square root of 2 is irrational.", + "Convert 0.625 to a fraction and simplify.", + "What is 15% of 240? Show the steps.", + "What is the capital of Japan?", + "Who wrote the play Hamlet?", + "What is photosynthesis in one sentence?", + "Name the four fundamental forces of physics.", + "What gas do plants absorb from the atmosphere?", + "What is the largest planet in the solar system?", + "Who developed the theory of general relativity?", + "What is the chemical symbol for gold?", + "What year did the first human land on the moon?", + "What is the speed of light in a vacuum (approximate)?", + "Explain how a hash map works in one paragraph.", + "Explain the difference between a process and a thread.", + "Explain what a REST API is to a beginner.", + "Describe how TCP establishes a connection (three-way handshake).", + "Explain what overfitting is in machine learning.", + "Explain the concept of recursion with a simple analogy.", + "Describe what a transformer attention mechanism does at a high level.", + "Explain the difference between supervised and unsupervised learning.", + "What is a deadlock and how can it be avoided?", + "Explain garbage collection in managed languages.", + "Write a haiku about autumn leaves.", + "Write a two-sentence horror story.", + "Compose a short motivational quote about perseverance.", + "Write a limerick about a programmer who loves coffee.", + "Draft a one-line git commit message for a bug fix in the parser.", + "Summarize the water cycle in two sentences.", + "Write a polite email asking to reschedule a meeting.", + "Give three tips for 
writing clear documentation.", + "Write a short poem about the ocean at night.", + "Describe a sunset using vivid imagery in two sentences.", + "Explain why the sky appears blue.", + "Summarize the plot of Cinderella in one sentence.", + "List three benefits of regular exercise.", + "What causes the seasons on Earth?", + "Give two reasons why version control is important.", + "Write a tagline for a fictional eco-friendly water bottle.", +] + + +@dataclass +class AttentionTargetData: + """Per-layer attention-output distillation target data. + + Captured during data collection by running the verifier forward + once with hooks on every layer. Used by the attention-output + distillation loss (v3 / one-shot trainer) to evaluate + ``attention(Q, f_θ(K), f_θ(V))`` against the verifier's actual + attention output without needing to re-run the verifier at every + training step. + + Per-layer (length = num_verifier_layers): + + q_raw [T, num_heads × head_dim] — q_proj output, pre-norm + o_tgt [T, hidden_dim] — attn module output, post-o_proj + cos [1, T, head_dim] — RoPE cosine table + sin [1, T, head_dim] — RoPE sine table + attention_mask — captured causal/sliding mask + + For hybrid loss (loss_type=attn_distill_hybrid), additionally: + + k_raw [T, num_kv_heads × head_dim] — k_proj output, pre-norm + v_raw [T, num_kv_heads × head_dim] — v_proj output, pre-norm + + These are needed to compute K/V direction (cosine post-norm) + + magnitude (pre-norm) losses that prevent the f_θ-collapse + degeneracy exposed by the 2026-06-10 alpha-sweep diagnostic + (f_θ raw rel_mse = 1331× target — k_norm/v_norm normalised the + scale away so attn-output loss alone didn't constrain it). + + All tensors stored bf16 to halve memory (cast to fp32 on use). + Stored on CPU; transferred to GPU per training step. 
+ """ + q_raw: List[torch.Tensor] # per-layer pre-norm Q + o_tgt: List[torch.Tensor] # per-layer attn module output + cos: List[torch.Tensor] # per-layer RoPE cos + sin: List[torch.Tensor] # per-layer RoPE sin + attention_mask: Optional[torch.Tensor] + num_heads_per_layer: List[int] + head_dim_per_layer: List[int] + # Optional pre-norm K/V tgt for hybrid loss (None for legacy + # attn_distill which only needs Q + O_tgt). + k_raw_tgt: Optional[List[torch.Tensor]] = None + v_raw_tgt: Optional[List[torch.Tensor]] = None + + +@dataclass +class CapturedSequence: + """Paired drafter / verifier data over one training sequence. + + All tensors live on the device that produced them by default; the + attention distillation tensors are CPU bf16 to keep total cache + size manageable for 64-prompt corpora. + + Two paths populate this: + + legacy K/V path (loss_type ∈ mse, cos_mag, combined): + drafter_k, drafter_v, verifier_k, verifier_v + attn_target = None + + attention-output distillation (loss_type = attn_distill, default): + drafter_k, drafter_v, attn_target (verifier_k/verifier_v omitted) + + The attn_distill path is the one-shot principled trainer. The + legacy path is kept for v1/v2 reproducibility / ablation but is + not the default after v3. + """ + seq_len: int + drafter_k: torch.Tensor # [num_d_layers, T, drafter_kv_dim] + drafter_v: torch.Tensor # [num_d_layers, T, drafter_kv_dim] + # Legacy K/V (None when attn_distill captured instead) + verifier_k: Optional[List[torch.Tensor]] = None + verifier_v: Optional[List[torch.Tensor]] = None + # Attention-output distillation target data (None for legacy path) + attn_target: Optional[AttentionTargetData] = None + + +def _capture_verifier_kv( + verifier_model: torch.nn.Module, input_ids: torch.Tensor, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Run verifier forward and capture per-layer K, V via forward hooks + on each decoder layer's k_proj / v_proj. 
+ + Returns + ------- + (verifier_k, verifier_v): per-layer lists of length num_v_layers, + element ``i`` shaped ``[T, kv_dim_i]`` on the verifier's device. + Layers can have heterogeneous ``kv_dim_i`` (Gemma 4). + """ + layers = get_verifier_decoder(verifier_model).layers + num_layers = len(layers) + k_capture: List[torch.Tensor] = [None] * num_layers + v_capture: List[torch.Tensor] = [None] * num_layers + handles = [] + # Gemma 4 has KV-sharing layers where v_proj is None; there the + # value_states equal the raw k_proj output (pre k_norm / pre RoPE). + # Capture V from the k_proj output for those layers. + v_shared_from_k: List[int] = [] + + for i, layer in enumerate(layers): + attn = layer.self_attn + + def _make_k_hook(idx): + def hook(_mod, _inp, output): + k_capture[idx] = output.detach() + return hook + + def _make_v_hook(idx): + def hook(_mod, _inp, output): + v_capture[idx] = output.detach() + return hook + + handles.append(attn.k_proj.register_forward_hook(_make_k_hook(i))) + if getattr(attn, "v_proj", None) is not None: + handles.append(attn.v_proj.register_forward_hook(_make_v_hook(i))) + else: + v_shared_from_k.append(i) + + try: + with torch.no_grad(): + _ = verifier_model(input_ids=input_ids, use_cache=False) + finally: + for h in handles: + h.remove() + + if any(k is None for k in k_capture): + raise RuntimeError( + "verifier K capture missing some layers — hooks did not fire" + ) + # Fill V for v_proj-None layers with the captured k_proj output. + for i in v_shared_from_k: + v_capture[i] = k_capture[i] + if any(v is None for v in v_capture): + raise RuntimeError( + "verifier V capture missing some layers — hooks did not fire" + ) + # Per-layer lists; layers may have heterogeneous kv_dim (Gemma 4). + # Each k_capture[i] is [B, T, kv_dim_i]; assume B=1 and drop it. 
def _capture_attention_target_data(
    verifier_model: torch.nn.Module, input_ids: torch.Tensor,
    *, capture_raw_kv: bool = False,
) -> AttentionTargetData:
    """Run one verifier forward with hooks and capture distillation targets.

    Per decoder layer this records the ``q_proj`` output (Q_raw), the
    attention module's output (O_tgt), and the ``position_embeddings``
    (cos, sin) and ``attention_mask`` the attention module was called with.

    If ``capture_raw_kv`` is True, the ``k_proj`` / ``v_proj`` outputs are
    recorded as well (needed by the hybrid loss). Layers whose attention
    module has no ``v_proj`` reuse the captured K as V.

    Parameters
    ----------
    verifier_model : torch.nn.Module
        Model whose decoder layers expose ``self_attn`` with ``q_proj``,
        ``k_proj``, optional ``v_proj`` and ``head_dim``.
    input_ids : torch.Tensor
        Token ids; the leading (batch) dimension must be 1.
    capture_raw_kv : bool
        Also capture per-layer K_raw / V_raw.

    Returns
    -------
    AttentionTargetData
        Q/O/cos/sin (and optional K/V) copied to CPU as bf16; the mask is
        copied to CPU in its original dtype.

    Raises
    ------
    NotImplementedError
        If the batch dimension of ``input_ids`` is not 1.
    RuntimeError
        If any layer failed to produce a Q, O, cos (or, when requested,
        K/V) capture.
    """
    # Everything below keeps only batch element 0 (``q[0]``, ``o[0]`` ...),
    # so a larger batch would be silently truncated. Fail loudly instead.
    if input_ids.size(0) != 1:
        raise NotImplementedError(
            f"f_θ training currently assumes batch=1 (got {input_ids.size(0)})"
        )

    layers = get_verifier_decoder(verifier_model).layers
    num_layers = len(layers)

    q_capture: List[Optional[torch.Tensor]] = [None] * num_layers
    o_capture: List[Optional[torch.Tensor]] = [None] * num_layers
    cos_capture: List[Optional[torch.Tensor]] = [None] * num_layers
    sin_capture: List[Optional[torch.Tensor]] = [None] * num_layers
    mask_capture: List[Optional[torch.Tensor]] = [None] * num_layers
    k_raw_capture: List[Optional[torch.Tensor]] = [None] * num_layers
    v_raw_capture: List[Optional[torch.Tensor]] = [None] * num_layers
    v_shared_from_k: List[int] = []  # layers without v_proj: V_raw := K_raw
    handles = []

    # Hook factories are defined once (not per layer); each call binds the
    # destination list and layer index into a fresh closure.
    def _make_output_hook(store, idx):
        def hook(_mod, _inp, output):
            store[idx] = output.detach()
        return hook

    def _make_o_hook(idx):
        def hook(_mod, _inp, output):
            # attn module returns (attn_output, attn_weights)
            if isinstance(output, tuple):
                o_capture[idx] = output[0].detach()
            else:
                o_capture[idx] = output.detach()
        return hook

    def _make_pre_hook(idx):
        def hook(_mod, args, kwargs):
            # Gemma 4 attention.forward signature:
            #   (hidden_states, position_embeddings, attention_mask, ...)
            # Accept both keyword and positional call styles.
            pos_emb = None
            if "position_embeddings" in kwargs:
                pos_emb = kwargs["position_embeddings"]
            elif len(args) >= 2:
                pos_emb = args[1]
            if pos_emb is not None:
                cos, sin = pos_emb
                cos_capture[idx] = cos.detach()
                sin_capture[idx] = sin.detach()
            am = None
            if "attention_mask" in kwargs:
                am = kwargs["attention_mask"]
            elif len(args) >= 3:
                am = args[2]
            if am is not None:
                mask_capture[idx] = am.detach()
        return hook

    for i, layer in enumerate(layers):
        attn = layer.self_attn
        handles.append(
            attn.q_proj.register_forward_hook(_make_output_hook(q_capture, i))
        )
        handles.append(attn.register_forward_hook(_make_o_hook(i)))
        handles.append(
            attn.register_forward_pre_hook(_make_pre_hook(i), with_kwargs=True),
        )
        if capture_raw_kv:
            handles.append(
                attn.k_proj.register_forward_hook(
                    _make_output_hook(k_raw_capture, i)
                )
            )
            if getattr(attn, "v_proj", None) is not None:
                handles.append(
                    attn.v_proj.register_forward_hook(
                        _make_output_hook(v_raw_capture, i)
                    )
                )
            else:
                v_shared_from_k.append(i)

    # Always unregister the hooks, even if the forward pass raises.
    try:
        with torch.no_grad():
            _ = verifier_model(input_ids=input_ids, use_cache=False)
    finally:
        for h in handles:
            h.remove()

    if any(q is None for q in q_capture):
        raise RuntimeError("attention distill: Q capture missing some layers")
    if any(o is None for o in o_capture):
        raise RuntimeError("attention distill: O capture missing some layers")
    if any(c is None for c in cos_capture):
        raise RuntimeError("attention distill: cos capture missing some layers")

    num_heads_per_layer: List[int] = []
    head_dim_per_layer: List[int] = []
    for layer in layers:
        attn = layer.self_attn
        head_dim_per_layer.append(int(attn.head_dim))
        num_heads_per_layer.append(int(attn.q_proj.out_features // attn.head_dim))

    # Move to CPU bf16. Drop batch dim (B=1, enforced above).
    def _to_cpu_bf16(t: torch.Tensor) -> torch.Tensor:
        return t.to(dtype=torch.bfloat16, device="cpu", copy=True)

    q_list = [_to_cpu_bf16(q[0]) for q in q_capture]    # [T, n_heads*head_dim]
    o_list = [_to_cpu_bf16(o[0]) for o in o_capture]    # [T, hidden]
    cos_list = [_to_cpu_bf16(c) for c in cos_capture]   # batch dim kept
    sin_list = [_to_cpu_bf16(s) for s in sin_capture]
    # NOTE(review): only layer 0's mask is stored and later reused for every
    # layer. If layers receive different masks (e.g. sliding-window vs
    # full-attention layers), downstream consumers see layer 0's — confirm
    # this is intended for the sequence lengths used in training.
    mask_cpu = (
        mask_capture[0].to(device="cpu", copy=True) if mask_capture[0] is not None
        else None
    )

    k_raw_list: Optional[List[torch.Tensor]] = None
    v_raw_list: Optional[List[torch.Tensor]] = None
    if capture_raw_kv:
        if any(k is None for k in k_raw_capture):
            raise RuntimeError("hybrid capture: K_raw missing some layers")
        for i in v_shared_from_k:
            v_raw_capture[i] = k_raw_capture[i]
        if any(v is None for v in v_raw_capture):
            raise RuntimeError("hybrid capture: V_raw missing some layers")
        k_raw_list = [_to_cpu_bf16(k[0]) for k in k_raw_capture]
        v_raw_list = [_to_cpu_bf16(v[0]) for v in v_raw_capture]

    return AttentionTargetData(
        q_raw=q_list,
        o_tgt=o_list,
        cos=cos_list,
        sin=sin_list,
        attention_mask=mask_cpu,
        num_heads_per_layer=num_heads_per_layer,
        head_dim_per_layer=head_dim_per_layer,
        k_raw_tgt=k_raw_list,
        v_raw_tgt=v_raw_list,
    )


def _collect_sequence(
    verifier_model: torch.nn.Module,
    drafter: DFlashDrafter,
    input_ids: torch.Tensor,
    *,
    capture_legacy_kv: bool = False,
    capture_attn_target: bool = True,
    capture_raw_kv_in_attn_target: bool = False,
) -> CapturedSequence:
    """Capture paired drafter + verifier data for one input sequence.

    Parameters
    ----------
    capture_legacy_kv : bool
        If True, capture verifier K/V via k_proj/v_proj hooks (used by
        loss_type ∈ mse | cos_mag | combined). Default False — the
        attn_distill path doesn't need it.
    capture_attn_target : bool
        If True, capture per-layer Q + O_tgt + cos/sin/mask (used by the
        attn_distill losses).
    capture_raw_kv_in_attn_target : bool
        Forwarded as ``capture_raw_kv`` to the attention-target capture
        (required by the hybrid loss). Has no effect when
        ``capture_attn_target`` is False.

    Raises
    ------
    ValueError
        If both ``capture_legacy_kv`` and ``capture_attn_target`` are False.
    """
    if not (capture_legacy_kv or capture_attn_target):
        raise ValueError(
            "must capture at least one of legacy_kv or attn_target"
        )

    v_k = v_v = None
    if capture_legacy_kv:
        v_k, v_v = _capture_verifier_kv(verifier_model, input_ids)

    attn_target: Optional[AttentionTargetData] = None
    if capture_attn_target:
        attn_target = _capture_attention_target_data(
            verifier_model, input_ids,
            capture_raw_kv=capture_raw_kv_in_attn_target,
        )

    # Drafter K/V capture (always; cheap and small, ~5 MB per seq).
    capture = _capture_drafter_kv(
        verifier_model=verifier_model,
        drafter=drafter,
        input_ids=input_ids,
    )
    # Merge the last two dims of each per-layer tensor, stack the layers on
    # a new leading axis, then keep index 0 of the following (batch) axis.
    k_flat = [k.flatten(-2, -1) for k in capture.keys]
    v_flat = [v.flatten(-2, -1) for v in capture.values]
    d_k = torch.stack(k_flat, dim=0)[:, 0]
    d_v = torch.stack(v_flat, dim=0)[:, 0]

    return CapturedSequence(
        seq_len=int(input_ids.size(1)),
        drafter_k=d_k.detach(),
        drafter_v=d_v.detach(),
        verifier_k=[t.detach() for t in v_k] if v_k is not None else None,
        verifier_v=[t.detach() for t in v_v] if v_v is not None else None,
        attn_target=attn_target,
    )
+ capture_attn_target : bool + If True, capture per-layer Q + O_tgt + cos/sin/mask (used by + loss_type=attn_distill, the v3 default). + """ + if not (capture_legacy_kv or capture_attn_target): + raise ValueError( + "must capture at least one of legacy_kv or attn_target" + ) + + v_k = v_v = None + if capture_legacy_kv: + v_k, v_v = _capture_verifier_kv(verifier_model, input_ids) + + attn_target: Optional[AttentionTargetData] = None + if capture_attn_target: + attn_target = _capture_attention_target_data( + verifier_model, input_ids, + capture_raw_kv=capture_raw_kv_in_attn_target, + ) + + # Drafter K/V capture (always; cheap and small, ~5 MB per seq). + capture = _capture_drafter_kv( + verifier_model=verifier_model, + drafter=drafter, + input_ids=input_ids, + ) + k_flat = [k.flatten(-2, -1) for k in capture.keys] + v_flat = [v.flatten(-2, -1) for v in capture.values] + d_k = torch.stack(k_flat, dim=0)[:, 0] + d_v = torch.stack(v_flat, dim=0)[:, 0] + + return CapturedSequence( + seq_len=int(input_ids.size(1)), + drafter_k=d_k.detach(), + drafter_v=d_v.detach(), + verifier_k=[t.detach() for t in v_k] if v_k is not None else None, + verifier_v=[t.detach() for t in v_v] if v_v is not None else None, + attn_target=attn_target, + ) + + +def _attention_distillation_loss( + f_theta: FThetaProjection, + seq: CapturedSequence, + layers: Sequence[torch.nn.Module], + *, + apply_rotary_pos_emb: Any, + device: torch.device, + sample_positions: Optional[int] = None, + seed: Optional[int] = None, + diag_buf: Optional[Dict[str, float]] = None, + hybrid: bool = False, + lambda_k_dir: float = 1.0, + lambda_v_dir: float = 1.0, + lambda_k_mag: float = 0.1, + lambda_v_mag: float = 0.1, + skip_layer_indices: Optional[Sequence[int]] = None, +) -> torch.Tensor: + """Attention-output distillation loss (the v3 / one-shot principled loss). 
+ + For each verifier layer ℓ: + + K_pred_ℓ = f_θ_K(drafter_KV)[ℓ] + V_pred_ℓ = f_θ_V(drafter_KV)[ℓ] + + Q_for_attn = q_norm(Q_raw_ℓ).view(B, T, H_q, D) → RoPE → transpose + K_for_attn = k_norm(K_pred_ℓ).view(B, T, H_kv, D) → RoPE → transpose + V_for_attn = v_norm(V_pred_ℓ).view(B, T, H_kv, D) → transpose + + GQA repeat K_for_attn, V_for_attn to H_q + O_inner = scaled_dot_product_attention(Q, K, V, mask) + O_pred = o_proj(O_inner.reshape(B, T, H_q*D)) + + loss_ℓ = MSE(O_pred, O_tgt_ℓ) # O_tgt captured during data + collection (verifier's actual + attn module post-o_proj output) + + Total loss = mean over layers. + + This is the principled training objective for K/V replacement: it + directly optimises "f_θ-injected K/V produces equivalent verifier + attention output". Unlike pure-MSE-on-K/V (v1) or cos+mag (v2), + this loss accounts for: + + * GQA: same num_heads/num_kv_heads/head_dim per layer + * RoPE: same positional encoding the verifier uses at inference + * Causal mask (and sliding-window for sliding layers): captured + from the verifier's own forward + * o_proj: every layer's downstream projection that f_θ K/V + ultimately propagates through + + Memory: per training step, K_pred/V_pred at full T positions are + needed for attention's K, V dims. We sample only the OUTPUT side + (where loss is evaluated) when ``sample_positions`` < T to save + on attention output + o_proj memory; this reduces gradient noise + only marginally because the loss is averaged across positions. + Default ``None`` ⇒ use all T output positions (recommended for + short sequences T ≤ 1024). + """ + if seq.attn_target is None: + raise RuntimeError( + "attn_distill loss requires CapturedSequence.attn_target; " + "call _collect_sequence with capture_attn_target=True" + ) + target = seq.attn_target + cfg = f_theta.config + T = seq.seq_len + + # f_θ forward (drafter K/V on CPU/GPU, f_θ on GPU). We pull drafter + # K/V to f_θ's device + cast to f_θ encoder dtype. 
+ f_dtype = next(f_theta.parameters()).dtype + drafter_k = seq.drafter_k.to(device=device).unsqueeze(0) # [1, L_d, T, kv_dim] + drafter_v = seq.drafter_v.to(device=device).unsqueeze(0) + d_k_list = [] + d_v_list = [] + for li in range(cfg.drafter_num_layers): + k_per = drafter_k[:, li].view( + 1, T, cfg.drafter_num_kv_heads, cfg.drafter_head_dim, + ) + v_per = drafter_v[:, li].view( + 1, T, cfg.drafter_num_kv_heads, cfg.drafter_head_dim, + ) + d_k_list.append(k_per) + d_v_list.append(v_per) + pred_k_per_layer, pred_v_per_layer = f_theta.forward_kv_pack(d_k_list, d_v_list) + # pred_k_per_layer[ℓ]: [1, T, kv_heads_ℓ, head_dim_ℓ] in fp32 + + # Sample positions for output-side loss + if sample_positions is not None and sample_positions < T: + if seed is not None: + g = torch.Generator(device="cpu").manual_seed(seed) + else: + g = None + idx = torch.randperm(T, generator=g)[:sample_positions].to(device) + idx, _ = idx.sort() + else: + idx = None + + n_layers = cfg.verifier_num_layers + # S5 mode: exclude exact (full-attention) layers from the loss — at + # inference those layers use exact K/V, so f_θ need not fit them, and + # the freed capacity focuses on the sliding layers it must restore. + skip_set = {i for i in (skip_layer_indices or ()) if 0 <= i < n_layers} + n_used = max(n_layers - len(skip_set), 1) + loss = pred_k_per_layer[0].new_zeros(()) + diag = { + "mse_O_total": 0.0, "abs_O_target": 0.0, + # Hybrid-loss diagnostics (zero unless hybrid=True): + "k_dir_total": 0.0, "v_dir_total": 0.0, + "k_mag_total": 0.0, "v_mag_total": 0.0, + } + if hybrid and (target.k_raw_tgt is None or target.v_raw_tgt is None): + raise RuntimeError( + "hybrid=True requires AttentionTargetData with k_raw_tgt + v_raw_tgt; " + "set capture_raw_kv_in_attn_target=True in _collect_sequence." 
+ ) + + for li in range(n_layers): + if li in skip_set: + continue + layer = layers[li] + attn = layer.self_attn + + # Move per-layer cached tensors to GPU (bf16 cache → cast to compute dtype) + compute_dtype = next(layer.parameters()).dtype + q_raw = target.q_raw[li].to(device=device, dtype=compute_dtype).unsqueeze(0) + o_tgt = target.o_tgt[li].to(device=device, dtype=compute_dtype).unsqueeze(0) + cos = target.cos[li].to(device=device, dtype=compute_dtype) + sin = target.sin[li].to(device=device, dtype=compute_dtype) + if cos.ndim == 2: + cos = cos.unsqueeze(0) + if sin.ndim == 2: + sin = sin.unsqueeze(0) + + n_heads = target.num_heads_per_layer[li] + head_dim = target.head_dim_per_layer[li] + kv_heads = cfg.layer_kv_heads[li] + kv_head_dim = cfg.layer_head_dims[li] + if kv_head_dim != head_dim: + # Sanity: f_θ's per-layer head_dim must match verifier's + # actual head_dim. (Both come from the verifier config.) + raise RuntimeError( + f"layer {li}: f_θ head_dim {kv_head_dim} != verifier {head_dim}" + ) + + # Q pipeline: q_norm → RoPE → transpose + Q = q_raw.view(1, T, n_heads, head_dim) + Q = attn.q_norm(Q) + Q = apply_rotary_pos_emb(Q, cos, sin, unsqueeze_dim=2) + Q = Q.transpose(1, 2) # [1, n_heads, T, head_dim] + + # K pipeline (f_θ output → norm → RoPE → transpose) + K_pred_pre = pred_k_per_layer[li].to(dtype=compute_dtype) # [1, T, kv_heads, head_dim] + K_pred_normed = attn.k_norm(K_pred_pre) # post-k_norm, pre-RoPE + K = apply_rotary_pos_emb(K_pred_normed, cos, sin, unsqueeze_dim=2) + K = K.transpose(1, 2) # [1, kv_heads, T, head_dim] + + # V pipeline (f_θ output → v_norm → transpose, no RoPE) + V_pred_pre = pred_v_per_layer[li].to(dtype=compute_dtype) + V_pred_normed = attn.v_norm(V_pred_pre) # post-v_norm + V = V_pred_normed.transpose(1, 2) # [1, kv_heads, T, head_dim] + + # GQA: repeat K, V to match num_heads + if n_heads != kv_heads: + n_rep = n_heads // kv_heads + if n_heads % kv_heads != 0: + raise RuntimeError( + f"layer {li}: n_heads {n_heads} not 
divisible by " + f"kv_heads {kv_heads}" + ) + K = K.repeat_interleave(n_rep, dim=1) + V = V.repeat_interleave(n_rep, dim=1) + + # Attention with the verifier's actual mask + scaling + scale = float(getattr(attn, "scaling", head_dim ** -0.5)) + # Use scaled_dot_product_attention; if attention_mask is None, + # use causal=True. + attn_mask = target.attention_mask + if attn_mask is None: + O_inner = F.scaled_dot_product_attention( + Q, K, V, scale=scale, is_causal=True, + ) + else: + attn_mask_dev = attn_mask.to(device=device, dtype=compute_dtype) + # attention_mask shapes vary (B, 1, T, T) or (B, T, T); align + # to what scaled_dot_product_attention accepts. + if attn_mask_dev.ndim == 4 and attn_mask_dev.size(0) == 1: + pass + elif attn_mask_dev.ndim == 3: + attn_mask_dev = attn_mask_dev.unsqueeze(1) + elif attn_mask_dev.ndim == 2: + attn_mask_dev = attn_mask_dev.unsqueeze(0).unsqueeze(0) + O_inner = F.scaled_dot_product_attention( + Q, K, V, attn_mask=attn_mask_dev, scale=scale, + ) + + # o_proj (linear, frozen weights → no grad through it) + O_inner = O_inner.transpose(1, 2).reshape(1, T, n_heads * head_dim).contiguous() + O_pred = attn.o_proj(O_inner) + + if idx is not None: + O_pred_eval = O_pred.index_select(1, idx) + O_tgt_eval = o_tgt.index_select(1, idx) + else: + O_pred_eval = O_pred + O_tgt_eval = o_tgt + + l_o = F.mse_loss(O_pred_eval.float(), O_tgt_eval.float()) + loss = loss + l_o + diag["mse_O_total"] += float(l_o.detach().item()) + diag["abs_O_target"] += float(O_tgt_eval.float().abs().mean().item()) + + # Hybrid-loss extension: direct K/V supervision in post-norm + # space (cosine direction) + pre-norm space (magnitude). Prevents + # the f_θ degeneracy exposed by 2026-06-10 alpha-sweep evidence + # (raw K/V rel_mse 1331×; k_norm hides scale errors from + # attn_distill loss; alpha-sweep shows recall=0 for any alpha<1.0 + # because off-scale f_θ K/V dominate raw-space mixing). 
+ if hybrid: + k_raw_tgt = target.k_raw_tgt[li].to( + device=device, dtype=compute_dtype, + ).unsqueeze(0) + v_raw_tgt = target.v_raw_tgt[li].to( + device=device, dtype=compute_dtype, + ).unsqueeze(0) + # Reshape tgt to [1, T, kv_heads, head_dim] (same as pred) + K_tgt_pre = k_raw_tgt.view(1, T, kv_heads, head_dim) + V_tgt_pre = v_raw_tgt.view(1, T, kv_heads, head_dim) + + # Apply norm to tgt (same pipeline as pred) + K_tgt_normed = attn.k_norm(K_tgt_pre) + V_tgt_normed = attn.v_norm(V_tgt_pre) + + # Direction loss: cosine sim on POST-NORM K, V (per + # (position, head) vector, last-dim is head_dim). + cos_K = F.cosine_similarity( + K_pred_normed.float(), K_tgt_normed.float(), dim=-1, + ) + cos_V = F.cosine_similarity( + V_pred_normed.float(), V_tgt_normed.float(), dim=-1, + ) + l_k_dir = (1.0 - cos_K).mean() + l_v_dir = (1.0 - cos_V).mean() + + # Magnitude loss: MSE on PRE-NORM L2 norms, normalised by + # tgt's mean-square so loss is scale-comparable across + # layers with very different K/V magnitudes (sliding vs full). 
+ pred_k_mag = K_pred_pre.float().norm(dim=-1) + tgt_k_mag = K_tgt_pre.float().norm(dim=-1) + pred_v_mag = V_pred_pre.float().norm(dim=-1) + tgt_v_mag = V_tgt_pre.float().norm(dim=-1) + denom_k = tgt_k_mag.pow(2).mean().clamp(min=1e-6) + denom_v = tgt_v_mag.pow(2).mean().clamp(min=1e-6) + l_k_mag = F.mse_loss(pred_k_mag, tgt_k_mag) / denom_k + l_v_mag = F.mse_loss(pred_v_mag, tgt_v_mag) / denom_v + + loss = loss + ( + lambda_k_dir * l_k_dir + lambda_v_dir * l_v_dir + + lambda_k_mag * l_k_mag + lambda_v_mag * l_v_mag + ) + diag["k_dir_total"] += float(l_k_dir.detach().item()) + diag["v_dir_total"] += float(l_v_dir.detach().item()) + diag["k_mag_total"] += float(l_k_mag.detach().item()) + diag["v_mag_total"] += float(l_v_mag.detach().item()) + + del K_tgt_pre, V_tgt_pre, K_tgt_normed, V_tgt_normed + + # Free GPU memory of cached per-layer tensors before next layer + del q_raw, o_tgt, cos, sin, Q, K, V, O_inner, O_pred + del K_pred_pre, K_pred_normed, V_pred_pre, V_pred_normed + + if diag_buf is not None: + diag_buf["mse_O_mean"] = diag["mse_O_total"] / n_used + diag_buf["abs_O_target_mean"] = diag["abs_O_target"] / n_used + if hybrid: + diag_buf["k_dir_mean"] = diag["k_dir_total"] / n_used + diag_buf["v_dir_mean"] = diag["v_dir_total"] / n_used + diag_buf["k_mag_mean"] = diag["k_mag_total"] / n_used + diag_buf["v_mag_mean"] = diag["v_mag_total"] / n_used + return loss / n_used + + +def _per_vector_cosine_mag_loss( + pred: torch.Tensor, tgt: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Cosine-similarity + magnitude-MSE loss between paired K (or V) vectors. + + Each vector is a single-head K (or V) at a single position, shape + ``[..., head_dim]``. Loss components: + + cos: 1 − cosine_similarity(pred, tgt) ∈ [0, 2] + mag: MSE(‖pred‖, ‖tgt‖) / mean(‖tgt‖)² (scale-normalised) + + Why this loss is correct for K/V projection + ------------------------------------------- + + Attention is ``softmax(QK^T / √d) · V``. 
def _per_vector_cosine_mag_loss(
    pred: torch.Tensor, tgt: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Cosine-direction + normalised-magnitude loss between paired vectors.

    ``pred`` and ``tgt`` have the same shape ``[..., head_dim]``; each
    trailing vector is one head's K (or V) at one position. Components:

        cos: 1 − mean(cosine_similarity(pred, tgt))        ∈ [0, 2]
        mag: MSE(‖pred‖, ‖tgt‖) / mean(‖tgt‖²)             (scale-normalised)

    The direction term constrains the orientation of each vector; the
    magnitude term constrains its length, which a cosine term alone cannot
    see. Both are computed in fp32.

    Returns
    -------
    (combined, cos_component, mag_component)
        ``combined`` carries gradients; the two components are detached
        and intended for logging.
    """
    pred_f = pred.float()
    tgt_f = tgt.float()
    # Cosine on the last (head_dim) axis: [..., head_dim] → [...]
    cos = F.cosine_similarity(pred_f, tgt_f, dim=-1).mean()
    cos_loss = 1.0 - cos
    # Magnitude: one L2 norm per vector; the denominator is clamped so an
    # all-zero target cannot divide by zero.
    pred_mag = pred_f.norm(dim=-1)
    tgt_mag = tgt_f.norm(dim=-1)
    tgt_mag_mean_sq = tgt_mag.pow(2).mean().clamp(min=1e-6)
    mag_loss = F.mse_loss(pred_mag, tgt_mag) / tgt_mag_mean_sq
    return cos_loss + mag_loss, cos_loss.detach(), mag_loss.detach()


def _f_theta_loss(
    f_theta: FThetaProjection,
    seq: CapturedSequence,
    *,
    sample_positions: int = 256,
    seed: Optional[int] = None,
    loss_type: str = "attn_distill",
    diag_buf: Optional[Dict[str, float]] = None,
    layers: Optional[Sequence[torch.nn.Module]] = None,
    apply_rotary_pos_emb: Optional[Any] = None,
    device: Optional[torch.device] = None,
    skip_layer_indices: Optional[Sequence[int]] = None,
    lambda_k_dir: float = 1.0,
    lambda_v_dir: float = 1.0,
    lambda_k_mag: float = 0.1,
    lambda_v_mag: float = 0.1,
) -> torch.Tensor:
    """Compute the configured loss for one sequence.

    Parameters
    ----------
    sample_positions : int
        Number of positions to subsample. ``<= 0`` or ``>= seq_len`` means
        "use every position" for all loss types.
    seed : int, optional
        Seed for the position subsampling.
    loss_type : str
        ``"attn_distill"``        — attention-output distillation.
        ``"attn_distill_hybrid"`` — attn_distill + direct K/V direction and
                                    magnitude terms.
        Both require ``layers`` + ``apply_rotary_pos_emb`` + ``device``.
        ``"mse"``      — MSE on raw K and V.
        ``"cos_mag"``  — cosine + magnitude on K and V.
        ``"combined"`` — cosine + magnitude + 0.1× variance-normalised MSE.
        These three require ``seq.verifier_k`` / ``seq.verifier_v``.
    diag_buf : dict, optional
        Receives per-layer means of the loss components for logging
        (``cos_K_total`` … ``mse_V_total`` for the legacy losses; see the
        attention-distillation loss for its own keys).
    skip_layer_indices : sequence of int, optional
        Forwarded to the attention-distillation loss; not used by the
        legacy losses.
    lambda_k_dir, lambda_v_dir, lambda_k_mag, lambda_v_mag : float
        Hybrid-term weights, forwarded to the attention-distillation loss.
        Defaults equal that function's own defaults.

    Raises
    ------
    ValueError
        Unknown ``loss_type``, or an attn_distill variant without
        ``layers`` / ``apply_rotary_pos_emb`` / ``device``.
    RuntimeError
        Legacy loss requested but verifier K/V were not captured.
    """
    if loss_type in ("attn_distill", "attn_distill_hybrid"):
        if layers is None or apply_rotary_pos_emb is None or device is None:
            raise ValueError(
                f"{loss_type} requires layers + apply_rotary_pos_emb + device"
            )
        return _attention_distillation_loss(
            f_theta, seq, layers,
            apply_rotary_pos_emb=apply_rotary_pos_emb,
            device=device,
            sample_positions=(
                None if sample_positions <= 0 or sample_positions >= seq.seq_len
                else sample_positions
            ),
            seed=seed, diag_buf=diag_buf,
            hybrid=(loss_type == "attn_distill_hybrid"),
            lambda_k_dir=lambda_k_dir,
            lambda_v_dir=lambda_v_dir,
            lambda_k_mag=lambda_k_mag,
            lambda_v_mag=lambda_v_mag,
            skip_layer_indices=skip_layer_indices,
        )

    # Reject unknown loss types before doing any tensor work.
    if loss_type not in ("mse", "cos_mag", "combined"):
        raise ValueError(
            f"unknown loss_type {loss_type!r} "
            f"(want attn_distill | attn_distill_hybrid | mse | cos_mag | combined)"
        )

    if seq.verifier_k is None or seq.verifier_v is None:
        raise RuntimeError(
            f"loss_type={loss_type!r} requires legacy K/V capture; "
            "ensure data collection ran with capture_legacy_kv=True"
        )
    T = seq.seq_len
    if seed is not None:
        g = torch.Generator(device="cpu").manual_seed(seed)
    else:
        g = None
    # ``sample_positions <= 0`` means "all positions", as in the
    # attn_distill branch. An empty index would otherwise yield a NaN loss.
    if sample_positions <= 0 or T <= sample_positions:
        idx = torch.arange(T, device=seq.drafter_k.device)
    else:
        idx = torch.randperm(T, generator=g)[:sample_positions].to(
            seq.drafter_k.device,
        )

    d_k_sub = seq.drafter_k.index_select(1, idx).unsqueeze(0)
    d_v_sub = seq.drafter_v.index_select(1, idx).unsqueeze(0)
    cfg = f_theta.config
    d_k_list, d_v_list = [], []
    for li in range(cfg.drafter_num_layers):
        k_per = d_k_sub[:, li]
        v_per = d_v_sub[:, li]
        k_per = k_per.view(
            1, k_per.size(1), cfg.drafter_num_kv_heads, cfg.drafter_head_dim,
        )
        v_per = v_per.view(
            1, v_per.size(1), cfg.drafter_num_kv_heads, cfg.drafter_head_dim,
        )
        d_k_list.append(k_per)
        d_v_list.append(v_per)

    pred_k, pred_v = f_theta.forward_kv_pack(d_k_list, d_v_list)

    layer_kv_heads = cfg.layer_kv_heads
    layer_head_dims = cfg.layer_head_dims
    idx_pos = idx.to(seq.verifier_k[0].device)
    loss = pred_k[0].new_zeros(())
    n_layers = cfg.verifier_num_layers

    diag = {
        "cos_K_total": 0.0, "cos_V_total": 0.0,
        "mag_K_total": 0.0, "mag_V_total": 0.0,
        "mse_K_total": 0.0, "mse_V_total": 0.0,
    }

    for li in range(n_layers):
        v_k_sub = seq.verifier_k[li].index_select(0, idx_pos)
        v_v_sub = seq.verifier_v[li].index_select(0, idx_pos)
        tgt_k = v_k_sub.view(
            1, v_k_sub.size(0), layer_kv_heads[li], layer_head_dims[li],
        ).float()
        tgt_v = v_v_sub.view(
            1, v_v_sub.size(0), layer_kv_heads[li], layer_head_dims[li],
        ).float()
        pred_k_li = pred_k[li].float()
        pred_v_li = pred_v[li].float()

        if loss_type == "mse":
            l_k = F.mse_loss(pred_k_li, tgt_k)
            l_v = F.mse_loss(pred_v_li, tgt_v)
            loss = loss + l_k + l_v
            diag["mse_K_total"] += float(l_k.detach().item())
            diag["mse_V_total"] += float(l_v.detach().item())
        elif loss_type == "cos_mag":
            l_k, c_k, m_k = _per_vector_cosine_mag_loss(pred_k_li, tgt_k)
            l_v, c_v, m_v = _per_vector_cosine_mag_loss(pred_v_li, tgt_v)
            loss = loss + l_k + l_v
            diag["cos_K_total"] += float(c_k.item())
            diag["cos_V_total"] += float(c_v.item())
            diag["mag_K_total"] += float(m_k.item())
            diag["mag_V_total"] += float(m_v.item())
        else:  # "combined" (validated above)
            l_cm_k, c_k, m_k = _per_vector_cosine_mag_loss(pred_k_li, tgt_k)
            l_cm_v, c_v, m_v = _per_vector_cosine_mag_loss(pred_v_li, tgt_v)
            l_mse_k = F.mse_loss(pred_k_li, tgt_k)
            l_mse_v = F.mse_loss(pred_v_li, tgt_v)
            # Weight: cos+mag dominate (×1.0), MSE is a stability term (×0.1).
            # MSE is normalised by tgt's own variance so it doesn't blow up
            # for high-magnitude layers.
            tgt_var_k = tgt_k.var().clamp(min=1e-6)
            tgt_var_v = tgt_v.var().clamp(min=1e-6)
            mse_norm_k = l_mse_k / tgt_var_k
            mse_norm_v = l_mse_v / tgt_var_v
            loss = loss + l_cm_k + l_cm_v + 0.1 * (mse_norm_k + mse_norm_v)
            diag["cos_K_total"] += float(c_k.item())
            diag["cos_V_total"] += float(c_v.item())
            diag["mag_K_total"] += float(m_k.item())
            diag["mag_V_total"] += float(m_v.item())
            diag["mse_K_total"] += float(l_mse_k.detach().item())
            diag["mse_V_total"] += float(l_mse_v.detach().item())

    if diag_buf is not None:
        for k, v in diag.items():
            diag_buf[k] = v / max(n_layers, 1)
    # Each layer contributes a K term and a V term, hence the factor 2.
    return loss / (2.0 * n_layers)


# ---------------------------------------------------------------------------
# v2: synthetic NIAH-style training prompts
# ---------------------------------------------------------------------------

# Same vocabulary the eval uses, reproduced here so training corpus
# generation is independent of the eval module (avoid test contamination
# via shared seeds). PR #94's `make_niah_dataset` uses these patterns;
# we use distinct random seeds + extra word lists so training NIAH never
# reuses an eval-seed needle.
_NIAH_TRAIN_KEY_WORDS = (
    # Greek (overlaps with eval but seeds differ → independent samples)
    "ALPHA", "BETA", "GAMMA", "DELTA", "EPSILON", "ZETA", "ETA", "THETA",
    "IOTA", "KAPPA", "LAMBDA", "MU", "NU", "XI", "OMICRON", "PI",
    "RHO", "SIGMA", "TAU", "UPSILON", "PHI", "CHI", "PSI", "OMEGA",
    # Botanical (extra — different from eval's set so no needle reuse)
    "ROSE", "TULIP", "DAISY", "ORCHID", "JASMINE", "LILAC", "POPPY",
    "VIOLET", "IRIS", "PEONY", "DAHLIA", "ASTER", "SAGE", "BASIL",
    "MINT", "THYME", "OAK", "MAPLE", "PINE", "BIRCH", "CEDAR",
)

_NIAH_TRAIN_FILLER_LINES = (
    "The afternoon sun cast long shadows across the empty courtyard.",
    "She turned the pages slowly, savouring each illustration.",
    "Most of the equations balanced, though one stubbornly refused to.",
    "Light wind stirred the paper notes pinned to the corkboard.",
    "The hallway smelled faintly of old wood and lemon polish.",
    "Conversations drifted in from the kitchen but no one was listening.",
    "Three bicycles leaned against the fence in a perfect row.",
    "He paused, considered the diagram, and added another arrow.",
    "Outside, snow continued to fall gently and without urgency.",
    "The library catalogue was newer than the books it described.",
    "A single candle flickered on the mantelpiece beside the clock.",
    "Half the bookshelf was devoted entirely to volumes about birds.",
    "The path narrowed then widened then narrowed again unpredictably.",
    "Faint lines of older handwriting were visible beneath the print.",
    "Someone had circled a paragraph in red on the third page.",
    "The cat watched the rain from the windowsill without moving.",
    "Each measurement was double-checked but a few still seemed wrong.",
    "Brass instruments lay arranged along the wall in increasing size.",
    "The map was old but the labelling was unexpectedly precise.",
    "Footsteps echoed along the corridor before fading into silence.",
)


def _make_niah_training_prompts(
    n_prompts: int, *, seed: int,
    haystack_min_lines: int = 30, haystack_max_lines: int = 90,
) -> List[str]:
    """Generate synthetic NIAH-style (needle-in-a-haystack) training prompts.

    Each prompt has the shape:

        Read the following text carefully, then answer the question
        at the end.

        <filler line>
        <filler line>
        ...
        The secret code is <KEYWORD>-<NUMBER>.
        ...
        <filler line>

        Question: What is the secret code?

    The haystack length is drawn uniformly from
    ``[haystack_min_lines, haystack_max_lines]``, ``<KEYWORD>`` from
    ``_NIAH_TRAIN_KEY_WORDS`` and ``<NUMBER>`` from 1000–9999. The needle
    is placed at a random interior line (never first or last) whenever the
    haystack has at least three lines; shorter haystacks place it on any
    line.

    Output is fully determined by ``seed`` (a private ``random.Random`` is
    used, so the global RNG is untouched).

    Raises
    ------
    ValueError
        If ``haystack_min_lines < 1`` or
        ``haystack_max_lines < haystack_min_lines``.
    """
    if haystack_min_lines < 1 or haystack_max_lines < haystack_min_lines:
        raise ValueError(
            "haystack line bounds must satisfy 1 <= min <= max "
            f"(got min={haystack_min_lines}, max={haystack_max_lines})"
        )

    rng = random.Random(seed)
    prompts: List[str] = []
    for _ in range(n_prompts):
        # NOTE: the order of rng calls is part of the contract — changing
        # it changes every prompt generated for a given seed.
        n_lines = rng.randint(haystack_min_lines, haystack_max_lines)
        keyword = rng.choice(_NIAH_TRAIN_KEY_WORDS)
        number = rng.randint(1000, 9999)
        needle = f"The secret code is {keyword}-{number}."
        if n_lines >= 3:
            needle_pos = rng.randint(1, n_lines - 2)
        else:
            # No interior line exists; randint(1, n_lines - 2) would raise
            # an "empty range" ValueError here.
            needle_pos = rng.randrange(n_lines)
        lines = [
            needle if i == needle_pos else rng.choice(_NIAH_TRAIN_FILLER_LINES)
            for i in range(n_lines)
        ]
        body = "\n".join(lines)
        prompts.append(
            "Read the following text carefully, then answer the question "
            "at the end.\n\n"
            f"{body}\n\n"
            "Question: What is the secret code?"
        )
    return prompts


# ---------------------------------------------------------------------------
# v2: cosine LR schedule with linear warmup
# ---------------------------------------------------------------------------


def _lr_at_step(step: int, *, peak_lr: float, total_steps: int,
                warmup_steps: int, schedule: str,
                floor_ratio: float = 0.01) -> float:
    """Return the LR at ``step`` (1-indexed) for the configured schedule.

    schedule="const":  always ``peak_lr``.
    schedule="cosine": linear warmup from ``peak_lr / warmup_steps`` to
                       ``peak_lr`` over the first ``warmup_steps`` steps,
                       then cosine decay to ``peak_lr * floor_ratio`` at
                       ``total_steps``. Steps beyond ``total_steps`` stay
                       at the floor.

    ``floor_ratio`` defaults to 0.01, i.e. decay to ``peak_lr / 100``.

    Raises
    ------
    ValueError
        If ``schedule`` is neither ``"const"`` nor ``"cosine"``.
    """
    if schedule == "const":
        return peak_lr
    if schedule != "cosine":
        raise ValueError(f"unknown schedule {schedule!r}")
    if warmup_steps > 0 and step <= warmup_steps:
        return peak_lr * (step / warmup_steps)
    # max(..., 1) guards against total_steps <= warmup_steps.
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    progress = min(max(progress, 0.0), 1.0)
    floor_lr = peak_lr * floor_ratio
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return floor_lr + (peak_lr - floor_lr) * cosine
+ """ + if schedule == "const": + return peak_lr + if schedule != "cosine": + raise ValueError(f"unknown schedule {schedule!r}") + if warmup_steps > 0 and step <= warmup_steps: + return peak_lr * (step / max(warmup_steps, 1)) + progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1) + progress = min(max(progress, 0.0), 1.0) + floor_lr = peak_lr * 0.01 + cosine = 0.5 * (1.0 + math.cos(math.pi * progress)) + return floor_lr + (peak_lr - floor_lr) * cosine + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--verifier-id", default="google/gemma-4-26B-A4B-it") + ap.add_argument("--drafter-id", default="models/dflash-kakeya-baseline") + # v2 defaults: 5× more steps, 4× longer sequences, cosine LR, NIAH on. + # v1 reproduction: --steps 4000 --gen-len 128 --lr-schedule const + # --no-niah-prompts --loss-type mse + ap.add_argument("--steps", type=int, default=20000) + ap.add_argument("--lr", type=float, default=1e-3) + ap.add_argument( + "--lr-schedule", default="cosine", choices=["const", "cosine"], + help="LR schedule (v2 default cosine; v1 used const)", + ) + ap.add_argument( + "--warmup-steps", type=int, default=500, + help="Linear warmup steps for cosine schedule (ignored if const)", + ) + ap.add_argument("--weight-decay", type=float, default=0.01) + ap.add_argument("--n-prompts", type=int, default=64, + help="General prompts from PROMPTS list (capped at 62)") + ap.add_argument( + "--n-niah-prompts", type=int, default=64, + help="(v2) Synthetic NIAH-style prompts to add to the corpus. 
" + "Set 0 with --no-niah-prompts to reproduce v1.", + ) + ap.add_argument( + "--no-niah-prompts", action="store_true", + help="Disable NIAH synthetic prompts (v1 reproduction mode)", + ) + ap.add_argument("--niah-min-lines", type=int, default=30) + ap.add_argument("--niah-max-lines", type=int, default=90) + ap.add_argument("--gen-len", type=int, default=512, + help="Tokens generated per prompt during data collection") + ap.add_argument( + "--sample-positions", type=int, default=0, + help="Random output-side positions per training step. 0 (default) " + "= use all T positions. For legacy losses (mse/cos_mag/combined) " + "default falls back to 256 if 0 is passed.", + ) + ap.add_argument( + "--loss-type", default="attn_distill_hybrid", + choices=["attn_distill", "attn_distill_hybrid", "mse", "cos_mag", "combined"], + help="Training loss. RECOMMENDED post-2026-06-10 alpha-sweep evidence: " + "attn_distill_hybrid (= attn_distill + direct K/V direction + " + "magnitude constraints). Prevents f_θ collapse degeneracy where " + "raw K/V are 36× off-scale but k_norm hides it from attn_distill " + "alone. attn_distill is the v3 design (vulnerable to the " + "collapse). Others are legacy v1/v2.", + ) + ap.add_argument( + "--lambda-k-dir", type=float, default=1.0, + help="Hybrid loss weight on K direction (cosine post-norm).", + ) + ap.add_argument( + "--lambda-v-dir", type=float, default=1.0, + help="Hybrid loss weight on V direction (cosine post-norm).", + ) + ap.add_argument( + "--lambda-k-mag", type=float, default=0.1, + help="Hybrid loss weight on K magnitude (pre-norm L2 norm MSE, normalised).", + ) + ap.add_argument( + "--lambda-v-mag", type=float, default=0.1, + help="Hybrid loss weight on V magnitude.", + ) + ap.add_argument( + "--s5-exact-full-attn", action="store_true", + help="S5 mode: exclude the verifier's full-attention (global, max " + "head_dim) layers from the f_θ loss. 
At inference those layers " + "use exact K/V (--s5-exact-full-attn in the eval), so f_θ only " + "needs to restore the sliding layers; this focuses capacity on " + "them. Pairs with the S5 integrated-NIAH eval.", + ) + ap.add_argument( + "--init-from", default=None, + help="Optional path to existing f_θ checkpoint dir to warm-start from " + "(e.g. --init-from results/research/f_theta_v3_attn_distill). " + "Loads weights then continues training. Useful to fine-tune an " + "attn_distill checkpoint with hybrid loss for fewer steps.", + ) + ap.add_argument( + "--rank", type=int, default=None, + help="f_θ encoder bottleneck. Default 768 for attn_distill, 256 " + "for legacy losses (v1/v2). Override to override default.", + ) + ap.add_argument("--save", default="results/research/f_theta_v1") + ap.add_argument("--seed", type=int, default=0) + ap.add_argument("--log-every", type=int, default=50) + ap.add_argument("--eval-every", type=int, default=500) + args = ap.parse_args() + + random.seed(args.seed) + torch.manual_seed(args.seed) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device.type == "cpu": + print( + "[f_theta-train] WARNING: no CUDA detected; running on CPU. " + "This will be very slow on the production-scale verifier.", + file=sys.stderr, + ) + dtype = torch.bfloat16 if device.type == "cuda" else torch.float32 + + # Resolve rank default per loss type (rank ↑ for attn_distill = more + # f_θ capacity; legacy losses keep v1's 256 for direct comparability). + if args.rank is None: + args.rank = 768 if args.loss_type in ("attn_distill", "attn_distill_hybrid") else 256 + print( + f"[f_theta-train] using rank={args.rank} (loss_type={args.loss_type})", + file=sys.stderr, + ) + + from transformers import AutoModelForCausalLM, AutoTokenizer + # Eager attention is required for attn_distill so we can hook the + # attention module's pre/post forward and capture position_embeddings + # + attention_mask + post-o_proj output. 
SDPA fuses these and breaks + # the hook contract. For legacy losses, sdpa is fine (and faster). + attn_impl = "eager" if args.loss_type in ("attn_distill", "attn_distill_hybrid") else "sdpa" + apply_rotary_pos_emb = None + if args.loss_type in ("attn_distill", "attn_distill_hybrid"): + from transformers.models.gemma4.modeling_gemma4 import ( # type: ignore + apply_rotary_pos_emb, + ) + + print(f"[f_theta-train] loading verifier {args.verifier_id} (attn={attn_impl})", + file=sys.stderr, flush=True) + tok = AutoTokenizer.from_pretrained(args.verifier_id) + verifier = AutoModelForCausalLM.from_pretrained( + args.verifier_id, dtype=dtype, attn_implementation=attn_impl, + device_map="auto" if device.type == "cuda" else None, + ).eval() + for p in verifier.parameters(): + p.requires_grad_(False) + + print(f"[f_theta-train] loading drafter {args.drafter_id}", + file=sys.stderr, flush=True) + drafter = DFlashDrafter.from_pretrained(args.drafter_id, dtype=dtype) + drafter = drafter.to(device).eval() + for p in drafter.parameters(): + p.requires_grad_(False) + + # Derive f_θ config from drafter + verifier shapes. Gemma 4's config + # nests decoder dims under .text_config, so resolve it first. + v_cfg = resolve_text_config(verifier.config) + # Read per-layer (head_dim, KV-head count) directly off the decoder + # layers. Gemma 4 uses head_dim=256 / 8 KV heads on sliding layers + # and head_dim=512 (global_head_dim) / 2 KV heads on full-attention + # layers (where v_proj is None: K == V). 
+ v_layers = get_verifier_decoder(verifier).layers + layer_head_dims = tuple(int(layer.self_attn.head_dim) for layer in v_layers) + layer_kv_heads = tuple( + layer.self_attn.k_proj.out_features // hd + for layer, hd in zip(v_layers, layer_head_dims) + ) + uniform_heads = len(set(layer_kv_heads)) == 1 + uniform_dims = len(set(layer_head_dims)) == 1 + f_cfg = FThetaConfig( + drafter_num_layers=drafter.cfg.num_hidden_layers, + drafter_num_kv_heads=drafter.cfg.num_key_value_heads, + drafter_head_dim=drafter.cfg.head_dim, + verifier_num_layers=v_cfg.num_hidden_layers, + verifier_num_kv_heads=layer_kv_heads[0], + verifier_head_dim=layer_head_dims[0], + rank=args.rank, + verifier_layer_kv_heads=None if uniform_heads else layer_kv_heads, + verifier_layer_head_dims=None if uniform_dims else layer_head_dims, + ) + print(f"[f_theta-train] verifier per-layer kv heads: {layer_kv_heads}", + file=sys.stderr) + print(f"[f_theta-train] verifier per-layer head dims: {layer_head_dims}", + file=sys.stderr) + # S5 mode: full-attention (global) layers = those with the max head_dim. + s5_skip_layers = None + if args.s5_exact_full_attn and len(set(layer_head_dims)) > 1: + _max_hd = max(layer_head_dims) + s5_skip_layers = [i for i, hd in enumerate(layer_head_dims) if hd == _max_hd] + print(f"[f_theta-train] S5 mode: excluding full-attention layers " + f"{s5_skip_layers} from f_θ loss (kept exact at inference)", + file=sys.stderr) + print(f"[f_theta-train] f_θ config: {f_cfg}", file=sys.stderr) + + f_theta = FThetaProjection(f_cfg).to(device, dtype=torch.float32) + n_params = sum(p.numel() for p in f_theta.parameters()) + + # Warm-start from existing checkpoint (e.g. fine-tuning attn_distill + # with hybrid loss for fewer steps without re-collecting K/V). 
+ if args.init_from: + from inference_engine.v04.f_theta import FThetaProjection as _FT + print(f"[f_theta-train] warm-start: loading weights from {args.init_from}", + file=sys.stderr) + warm = _FT.from_pretrained(args.init_from, dtype=torch.float32, device=device) + # Validate config compatibility + if warm.config != f_cfg: + print( + f"[f_theta-train] WARNING: init-from config differs from " + f"current — this is OK if shapes match. \n" + f" init-from: {warm.config}\n" + f" current: {f_cfg}", + file=sys.stderr, + ) + f_theta.load_state_dict(warm.state_dict()) + del warm + print("[f_theta-train] warm-start loaded; continuing from those weights", + file=sys.stderr) + print(f"[f_theta-train] f_θ params: {n_params:,}", file=sys.stderr) + + # ---------------- Build training corpus (PROMPTS + optional NIAH) ---------------- + n_general = min(args.n_prompts, len(PROMPTS)) + n_niah = 0 if args.no_niah_prompts else max(args.n_niah_prompts, 0) + corpus_prompts: List[str] = list(PROMPTS[:n_general]) + if n_niah > 0: + # Use args.seed + 1000 so NIAH seed is reproducible but distinct + # from any other seed in the system. 
+ niah_prompts = _make_niah_training_prompts( + n_niah, seed=args.seed + 1000, + haystack_min_lines=args.niah_min_lines, + haystack_max_lines=args.niah_max_lines, + ) + corpus_prompts.extend(niah_prompts) + print( + f"[f_theta-train] corpus: {n_general} general + {n_niah} NIAH " + f"= {len(corpus_prompts)} prompts (NIAH seed={args.seed + 1000})", + file=sys.stderr, + ) + else: + print( + f"[f_theta-train] corpus: {n_general} general prompts " + f"(NIAH disabled — v1 reproduction mode)", + file=sys.stderr, + ) + + # ---------------- Data collection ---------------- + capture_legacy_kv = args.loss_type in ("mse", "cos_mag", "combined") + capture_attn_target = args.loss_type in ("attn_distill", "attn_distill_hybrid") + capture_raw_kv_in_attn_target = args.loss_type == "attn_distill_hybrid" + print( + f"[f_theta-train] data capture: legacy_kv={capture_legacy_kv} " + f"attn_target={capture_attn_target}", + file=sys.stderr, + ) + print(f"[f_theta-train] collecting from {len(corpus_prompts)} prompts ...", + file=sys.stderr, flush=True) + sequences: List[CapturedSequence] = [] + t0 = time.perf_counter() + eos_ids = {tok.eos_token_id} if tok.eos_token_id is not None else set() + for pi, prompt in enumerate(corpus_prompts): + msgs = [{"role": "user", "content": prompt}] + enc = tok.apply_chat_template( + msgs, add_generation_prompt=True, tokenize=True, return_tensors="pt", + ) + if hasattr(enc, "keys"): + enc = enc["input_ids"] + # Greedy AR extension. For NIAH prompts the haystack alone is + # already long; we still extend by gen_len to cover the answer + # region — the answer position is the lexically critical one + # for f_θ to reproduce. 
+ with torch.no_grad(): + cur = enc.to(device) + for _ in range(args.gen_len): + out = verifier(input_ids=cur, use_cache=False) + nxt = int(torch.argmax(out.logits[0, -1]).item()) + cur = torch.cat([cur, torch.tensor([[nxt]], device=device)], dim=1) + if nxt in eos_ids: + break + full_ids = cur + + seq = _collect_sequence( + verifier, drafter, full_ids, + capture_legacy_kv=capture_legacy_kv, + capture_attn_target=capture_attn_target, + capture_raw_kv_in_attn_target=capture_raw_kv_in_attn_target, + ) + sequences.append(seq) + if (pi + 1) % 10 == 0 or pi == len(corpus_prompts) - 1: + print( + f"[f_theta-train] collected {pi + 1}/{len(corpus_prompts)}, " + f"latest seq_len={seq.seq_len}", + file=sys.stderr, + ) + collect_elapsed = time.perf_counter() - t0 + print(f"[f_theta-train] data collection done in {collect_elapsed:.0f}s", + file=sys.stderr) + + # ---------------- Training ---------------- + # Resolve sample_positions: 0 ⇒ full-T for attn_distill (the design + # choice — every position contributes); fall back to 256 for legacy + # losses (memory reduction matters there). 
+ if args.sample_positions <= 0: + args.sample_positions = ( + 0 if args.loss_type in ("attn_distill", "attn_distill_hybrid") else 256 + ) + print( + f"[f_theta-train] training: loss_type={args.loss_type} " + f"schedule={args.lr_schedule} (warmup={args.warmup_steps}) " + f"steps={args.steps} peak_lr={args.lr} " + f"sample_positions={args.sample_positions}", + file=sys.stderr, + ) + optimizer = torch.optim.AdamW( + f_theta.parameters(), lr=args.lr, weight_decay=args.weight_decay, + ) + losses_window: List[float] = [] + initial_loss: Optional[float] = None + final_diag: Dict[str, float] = {} + f_theta.train() + t0 = time.perf_counter() + for step in range(1, args.steps + 1): + # Set per-step LR + cur_lr = _lr_at_step( + step, peak_lr=args.lr, total_steps=args.steps, + warmup_steps=args.warmup_steps, schedule=args.lr_schedule, + ) + for g in optimizer.param_groups: + g["lr"] = cur_lr + + seq = random.choice(sequences) + diag_buf: Dict[str, float] = {} + loss = _f_theta_loss( + f_theta, seq, + sample_positions=args.sample_positions, + loss_type=args.loss_type, + diag_buf=diag_buf, + layers=(v_layers if args.loss_type in ("attn_distill", "attn_distill_hybrid") else None), + apply_rotary_pos_emb=apply_rotary_pos_emb, + device=device, + skip_layer_indices=s5_skip_layers, + ) + if initial_loss is None: + initial_loss = float(loss.item()) + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(f_theta.parameters(), 1.0) + optimizer.step() + losses_window.append(float(loss.item())) + final_diag = diag_buf # last step's per-component breakdown + if step % args.log_every == 0: + recent = losses_window[-args.log_every:] + extra_msg = "" + if args.loss_type in ("cos_mag", "combined"): + extra_msg = ( + f" cosK={diag_buf.get('cos_K_total', 0):.4f}" + f" cosV={diag_buf.get('cos_V_total', 0):.4f}" + ) + elif args.loss_type in ("attn_distill", "attn_distill_hybrid"): + # mse_O_mean is the per-layer attn-output MSE; abs_O_target + # is the magnitude of O_tgt (so 
MSE/abs is "noise ratio"). + mse_o = diag_buf.get("mse_O_mean", 0) + abs_o = diag_buf.get("abs_O_target_mean", 1e-6) + extra_msg = ( + f" mseO={mse_o:.6f}" + f" |O_tgt|={abs_o:.4f}" + f" ratio={mse_o / max(abs_o ** 2, 1e-12):.4f}" + ) + if args.loss_type == "attn_distill_hybrid": + extra_msg += ( + f" kDir={diag_buf.get('k_dir_mean', 0):.4f}" + f" vDir={diag_buf.get('v_dir_mean', 0):.4f}" + f" kMag={diag_buf.get('k_mag_mean', 0):.4f}" + f" vMag={diag_buf.get('v_mag_mean', 0):.4f}" + ) + print( + f"[f_theta-train] step={step} lr={cur_lr:.2e} " + f"loss={sum(recent)/len(recent):.6f} " + f"(init={initial_loss:.6f}){extra_msg}", + file=sys.stderr, flush=True, + ) + train_elapsed = time.perf_counter() - t0 + + # ---------------- Save ---------------- + f_theta.eval() + f_theta.save_pretrained(args.save) + final_loss = sum(losses_window[-args.log_every:]) / max(len(losses_window[-args.log_every:]), 1) + + report = { + "kind": "k3_f_theta_train", + "schema_version": 2, + "config": vars(args), + "f_theta_config": f_cfg.to_json_dict(), + "n_params": n_params, + "n_sequences": len(sequences), + "n_general_prompts": n_general, + "n_niah_prompts": n_niah, + "collect_seconds": collect_elapsed, + "train_seconds": train_elapsed, + "initial_loss": initial_loss, + "final_loss": final_loss, + "loss_reduction_factor": ( + initial_loss / final_loss if final_loss > 0 else None + ), + # Per-component diagnostic at end of training. For combined / cos_mag + # losses, cosK_total close to 0.0 (≈ cos sim → 1.0) and cosV_total + # close to 0.0 indicates good direction alignment. For combined, + # mse_K_total + mse_V_total is the raw MSE for diff-ability with v1. 
+ "final_diagnostic": final_diag, + "loss_type": args.loss_type, + "lr_schedule": args.lr_schedule, + } + Path(args.save).mkdir(parents=True, exist_ok=True) + Path(f"{args.save}.json").write_text(json.dumps(report, indent=2)) + print( + f"[f_theta-train] DONE in {train_elapsed:.0f}s; " + f"initial_loss={initial_loss:.6f} final_loss={final_loss:.6f} " + f"reduction={report['loss_reduction_factor']:.2f}× " + f"-> {args.save}", + file=sys.stderr, + ) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/research/k3_integrated_niah_eval.py b/scripts/research/k3_integrated_niah_eval.py new file mode 100644 index 00000000..bd474123 --- /dev/null +++ b/scripts/research/k3_integrated_niah_eval.py @@ -0,0 +1,523 @@ +"""K3 Block B + C integrated NIAH eval — the complete Kakeya inference +engine product evidence on CUDA. + +This script is the **final K3 product gate**: it combines +:class:`inference_engine.v04.cross_model_dlm_verifier.CrossModelDLMRestoredVerifier` +(verifier with sink+window cache + drafter K/V Restoration via f_θ) +with the K1.E NIAH evaluation harness (effective_attention_window / +recall / memory metrics). + +Architecture under test: + + verifier (Gemma 4 26B-A4B): + └─ sink+window local KV cache (sink=4 + window=64 default) + └─ K/V at evicted positions injected via f_θ projection of + drafter K/V + + drafter (DFlash 0.4B, alignment-trained baseline at + models/dflash-kakeya-baseline/): + └─ runs full forward over input_ids with verifier embed_tokens + └─ K/V at every layer at every position captured + └─ projected to verifier K/V space via trained f_θ + +What this validates (per ADR 0008 §11.8 release gates): + + 1. **Architectural correctness**: + ``effective_attention_fraction = 1.0`` at every NIAH ladder rung. + Verifier "sees" the full context despite holding only sink+window + in its local cache. Falsifies "K/V Restoration is just + decoration"; proves the architecture's load-bearing claim. + + 2. 
**Memory bounded**: + Sustained verifier KV-cache memory ≤ O(sink+window) regardless + of input length. Compared against full-attention oracle's KV + cache size, the K3 cross-model path delivers the memory + savings claim. + + 3. **Recall preservation**: + Mid-context recall on NIAH samples vs the full-attention oracle. + ADR §11.8 1a: ``|recall_v04 - recall_oracle| ≤ 5pp`` at every + rung. This is the architecturally-meaningful gate (independent + of base-model long-context capability). + +This is the K3 production-scale evidence. It's the integrated test +that PR #102 (Mac MLX spec decode) doesn't perform. + +Usage (vast.ai H200 / H100): + + HF_TOKEN=hf_xxx PYTHONPATH=.:sdks/python python3 \\ + scripts/research/k3_integrated_niah_eval.py \\ + --verifier-id google/gemma-4-26B-A4B-it \\ + --drafter-id z-lab/gemma-4-26B-A4B-it-DFlash \\ + --f-theta-dir results/research/f_theta_v1 \\ + --n-samples 10 --haystack-min-lines 60 --haystack-max-lines 80 \\ + --sink-size 4 --window-size 64 \\ + --output results/research/k3_integrated_niah_.json + +JSON output mirrors the K1.E NIAH harness schema (per_config recall, +attention_window, memory) so it is diff-able against PR #94's +ladder evidence + PR #93's CUDA baselines. 
+""" + +from __future__ import annotations + +import argparse +import dataclasses +import json +import math +import random +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from inference_engine.v04 import ( + CrossModelDLMRestoredVerifier, + DFlashDrafter, + FThetaProjection, + NIAHSample, + aggregate_attention_window_metrics, + aggregate_recall, + compute_effective_attention_window, + make_niah_dataset, + recall_predicate, + record_memory, + reset_memory_peak, +) + + +def parse_args() -> argparse.Namespace: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--verifier-id", default="google/gemma-4-26B-A4B-it") + ap.add_argument("--drafter-id", default="z-lab/gemma-4-26B-A4B-it-DFlash") + ap.add_argument("--f-theta-dir", required=True, + help="Directory containing f_theta_config.json + f_theta_weights.pt") + ap.add_argument("--n-samples", type=int, default=10) + ap.add_argument("--haystack-min-lines", type=int, default=60) + ap.add_argument("--haystack-max-lines", type=int, default=80) + ap.add_argument("--sink-size", type=int, default=4) + ap.add_argument("--window-size", type=int, default=64) + ap.add_argument("--max-new-tokens", type=int, default=24) + ap.add_argument("--seed", type=int, default=42) + ap.add_argument("--output", default=None) + ap.add_argument( + "--skip-oracle", action="store_true", + help="Skip the full-attention oracle baseline (saves time but " + "loses the |delta vs oracle| gate signal).", + ) + ap.add_argument( + "--identity-restore", action="store_true", + help="Diagnostic: restore evicted positions with the verifier's " + "OWN true pre-norm K/V instead of the f_θ projection. Under " + "this mode cross-model recall should match the oracle — it " + "isolates 'is the restoration machinery correct?' 
from 'is " + "f_θ accurate enough?'.", + ) + ap.add_argument( + "--s5-exact-full-attn", action="store_true", + help="S5: keep the verifier's full-attention (global, head_dim 512) " + "layers' K/V EXACT at evicted positions (not f_θ-restored). " + "Only the sliding layers go through f_θ. The full-attention " + "layers are the recall-critical ones f_θ cannot reconstruct; " + "for long context (needle outside the sliding window) recall " + "flows only through them, so exact K/V there should restore " + "recall. Memory cost: store those ~5 layers' KV (~9% of full).", + ) + ap.add_argument( + "--mix-alpha-sweep", default="", + help="S6 fidelity→recall diagnostic. Comma-separated alphas in " + "[0,1]. At each alpha, evicted-position K/V = alpha*true + " + "(1-alpha)*f_θ. alpha=0 is pure f_θ, alpha=1 is identity " + "(oracle-equivalent). Maps recall vs residual K/V error so we " + "can read off the fidelity threshold recall needs. Runs the " + "sweep in a single model load and writes a sweep JSON, then " + "exits (skips the normal cross-model/oracle run).", + ) + return ap.parse_args() + + +def main() -> int: + args = parse_args() + random.seed(args.seed) + torch.manual_seed(args.seed) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device.type == "cpu": + print( + "[k3-integrated] WARNING: CUDA not available; " + "running on CPU will be very slow on production scale.", + file=sys.stderr, + ) + dtype = torch.bfloat16 if device.type == "cuda" else torch.float32 + + # ---------- Verifier (CUDA bf16) ---------- + from transformers import AutoModelForCausalLM, AutoTokenizer + from transformers.models.gemma4.modeling_gemma4 import ( # type: ignore + apply_rotary_pos_emb, eager_attention_forward, ALL_ATTENTION_FUNCTIONS, + ) + + print(f"[k3-integrated] loading verifier {args.verifier_id}", + file=sys.stderr, flush=True) + tokenizer = AutoTokenizer.from_pretrained(args.verifier_id) + verifier = AutoModelForCausalLM.from_pretrained( + 
args.verifier_id, dtype=dtype, attn_implementation="eager", + device_map="auto" if device.type == "cuda" else None, + ).eval() + for p in verifier.parameters(): + p.requires_grad_(False) + + # ---------- Drafter (CUDA bf16) ---------- + print(f"[k3-integrated] loading drafter {args.drafter_id}", + file=sys.stderr, flush=True) + drafter = DFlashDrafter.from_pretrained(args.drafter_id, dtype=dtype) + drafter = drafter.to(device).eval() + for p in drafter.parameters(): + p.requires_grad_(False) + + # ---------- f_θ checkpoint ---------- + print(f"[k3-integrated] loading f_θ from {args.f_theta_dir}", + file=sys.stderr, flush=True) + f_theta = FThetaProjection.from_pretrained( + args.f_theta_dir, dtype=torch.float32, device=device, + ) + + # ---------- Cross-model wrapper ---------- + exact_layers = None + if args.s5_exact_full_attn: + from inference_engine.v04.cross_model_dlm_verifier import ( + full_attention_layer_indices, + ) + exact_layers = full_attention_layer_indices(verifier) + print(f"[k3-integrated] S5: keeping full-attention layers exact: " + f"{exact_layers}", file=sys.stderr) + cross_verifier = CrossModelDLMRestoredVerifier( + verifier_model=verifier, + drafter=drafter, + f_theta=f_theta, + sink_size=args.sink_size, + window_size=args.window_size, + exact_layer_indices=exact_layers, + ) + print(f"[k3-integrated] cross-model verifier ready " + f"(sink={args.sink_size}, window={args.window_size})", + file=sys.stderr) + + if args.identity_restore: + # Diagnostic: restore evicted positions with the verifier's own + # true pre-norm K/V (not f_θ). Validates the restoration + # machinery independent of f_θ accuracy. 
+ from inference_engine.v04.cross_model_dlm_verifier import ( + capture_verifier_own_kv, + ) + cross_verifier.project_drafter_kv = ( + lambda ids: capture_verifier_own_kv(verifier, ids) + ) + print("[k3-integrated] IDENTITY-RESTORE diagnostic enabled " + "(evicted K/V come from verifier's own k_proj/v_proj)", + file=sys.stderr) + + # ---------- NIAH dataset ---------- + samples: List[NIAHSample] = make_niah_dataset( + n_samples=args.n_samples, + haystack_min_lines=args.haystack_min_lines, + haystack_max_lines=args.haystack_max_lines, + seed=args.seed, + ) + + # Encode prompts via chat template (ADR 0008 §2.4: the runtime is + # template-free; the harness applies the template), matching the + # K1.E NIAH harness convention. + def encode_chat(prompt_text: str) -> torch.Tensor: + messages = [{"role": "user", "content": prompt_text}] + ids = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, + return_tensors="pt", + ) + if hasattr(ids, "keys"): # BatchEncoding / dict + ids = ids["input_ids"] + elif isinstance(ids, list): + ids = torch.tensor([ids]) + return ids.to(device) + + sample_ids = [encode_chat(s.prompt_text) for s in samples] + seq_lens = [int(t.size(1)) for t in sample_ids] + eos_id = tokenizer.eos_token_id + print( + f"[k3-integrated] dataset: {len(samples)} samples, prompt token len " + f"min={min(seq_lens)} max={max(seq_lens)} " + f"mean={sum(seq_lens) // len(seq_lens)}", + file=sys.stderr, + ) + + def _greedy(decode_step) -> Tuple[List[str], List[float], List[int]]: + """Run greedy decode over all samples with a per-step callable + ``decode_step(cur_ids) -> logits[0, -1]``. 
Returns per-sample + (decoded_text, latency_s, decode_token_count).""" + decoded_all: List[str] = [] + lat_all: List[float] = [] + tok_all: List[int] = [] + for i in range(len(samples)): + cur = sample_ids[i] + gen: List[int] = [] + t0 = time.perf_counter() + for _ in range(args.max_new_tokens): + last_logits = decode_step(cur) + nxt = int(torch.argmax(last_logits).item()) + gen.append(nxt) + if eos_id is not None and nxt == eos_id: + break + cur = torch.cat( + [cur, torch.tensor([[nxt]], device=device, dtype=torch.long)], + dim=1, + ) + lat_all.append(time.perf_counter() - t0) + decoded_all.append(tokenizer.decode(gen, skip_special_tokens=True)) + tok_all.append(len(gen)) + print( + f"[k3-integrated] sample {i}: T={seq_lens[i]} tokens={len(gen)} " + f"decoded[:48]={decoded_all[-1][:48]!r}", + file=sys.stderr, + ) + return decoded_all, lat_all, tok_all + + # ---------- Run integrated cross-model verifier ---------- + print("[k3-integrated] running K3 cross-model verifier (f_θ restoration)", + file=sys.stderr, flush=True) + reset_memory_peak(device) + + def _cross_step(cur): + out = cross_verifier.forward( + cur, + apply_rotary_pos_emb=apply_rotary_pos_emb, + eager_attention_forward=eager_attention_forward, + all_attention_functions=ALL_ATTENTION_FUNCTIONS, + ) + return out.logits[0, -1] + + # ---------- S6: fidelity→recall sweep (alpha-interpolation) ---------- + if args.mix_alpha_sweep: + from inference_engine.v04.cross_model_dlm_verifier import ( + capture_verifier_own_kv, + ) + from inference_engine.v04.kv_merge import compute_evicted_positions + + orig_project = cross_verifier.project_drafter_kv + alphas = [float(x) for x in args.mix_alpha_sweep.split(",") if x.strip()] + cfg = f_theta.config + lhd = cfg.layer_head_dims + full_dim = max(lhd) + all_idx = list(range(cfg.verifier_num_layers)) + full_idx = [i for i in all_idx if lhd[i] == full_dim] + + # Baseline residual error of f_θ vs true (sample 0, evicted positions), + # so each alpha maps to an effective 
relative K/V error (1-alpha)^2. + def _rel_mse_layers(tk, tv, fk, fv, idx, layers): + num = den = 0.0 + for li in layers: + t = tk[li].index_select(1, idx.to(tk[li].device)).float() + f = fk[li].index_select(1, idx.to(fk[li].device)).to(t.device).float() + num += float(((f - t) ** 2).sum()); den += float((t ** 2).sum()) + tvv = tv[li].index_select(1, idx.to(tv[li].device)).float() + fvv = fv[li].index_select(1, idx.to(fv[li].device)).to(tvv.device).float() + num += float(((fvv - tvv) ** 2).sum()); den += float((tvv ** 2).sum()) + return num / max(den, 1e-9) + + with torch.no_grad(): + ids0 = sample_ids[0] + ev0 = compute_evicted_positions( + int(ids0.size(1)), args.sink_size, args.window_size) + idx0 = torch.tensor(ev0, dtype=torch.long) + tk0, tv0 = capture_verifier_own_kv(verifier, ids0) + fk0, fv0 = orig_project(ids0) + relmse0_all = _rel_mse_layers(tk0, tv0, fk0, fv0, idx0, all_idx) + relmse0_full = _rel_mse_layers(tk0, tv0, fk0, fv0, idx0, full_idx) + print(f"[s6] f_θ baseline rel_mse: overall={relmse0_all:.4f} " + f"full_attn={relmse0_full:.4f}", file=sys.stderr) + + def _make_mixed(alpha): + def _mixed(ids): + tk, tv = capture_verifier_own_kv(verifier, ids) + fk, fv = orig_project(ids) + mk, mv = [], [] + for i in range(len(fk)): + t = tk[i].to(device=fk[i].device, dtype=fk[i].dtype) + tvv = tv[i].to(device=fv[i].device, dtype=fv[i].dtype) + mk.append(alpha * t + (1.0 - alpha) * fk[i]) + mv.append(alpha * tvv + (1.0 - alpha) * fv[i]) + return mk, mv + return _mixed + + sweep_rows = [] + for a in alphas: + cross_verifier.project_drafter_kv = _make_mixed(a) + dec, lat, tok = _greedy(_cross_step) + res = aggregate_recall(f"mix_a{a}", samples, dec, lat, tok) + eff = (1.0 - a) ** 2 + row = { + "alpha": a, + "recall": res.recall, + "samples_correct": res.samples_correct, + "samples_total": res.samples_total, + "eff_rel_mse_overall": eff * relmse0_all, + "eff_rel_mse_full_attn": eff * relmse0_full, + } + sweep_rows.append(row) + print(f"[s6] alpha={a:.3f} 
recall={res.recall:.3f} " + f"({res.samples_correct}/{res.samples_total}) " + f"eff_rel_mse_full={eff * relmse0_full:.4f}", file=sys.stderr) + + sweep_report = { + "kind": "k3_s6_fidelity_recall_sweep", + "config": { + "f_theta_dir": args.f_theta_dir, + "n_samples": args.n_samples, + "sink_size": args.sink_size, + "window_size": args.window_size, + "max_new_tokens": args.max_new_tokens, + "haystack_min_lines": args.haystack_min_lines, + "haystack_max_lines": args.haystack_max_lines, + "seed": args.seed, + "prompt_token_lens": seq_lens, + }, + "f_theta_baseline_rel_mse": { + "overall": relmse0_all, "full_attn": relmse0_full, + }, + "sweep": sweep_rows, + } + out_path = Path(args.output) if args.output else Path( + f"results/research/k3_s6_fidelity_sweep_{int(time.time())}.json") + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(sweep_report, indent=2)) + print(f"\n[s6] DONE. sweep written to {out_path}", file=sys.stderr) + for r in sweep_rows: + print(f" alpha={r['alpha']:.3f} recall={r['recall']:.3f} " + f"eff_rel_mse_full={r['eff_rel_mse_full_attn']:.4f}", + file=sys.stderr) + return 0 + + cross_decoded, cross_lat, cross_tok = _greedy(_cross_step) + cross_res = aggregate_recall( + "k3_cross_model", samples, cross_decoded, cross_lat, cross_tok, + ) + cross_mem = record_memory(device) + cross_attn_agg = aggregate_attention_window_metrics( + "v04_dlm_restored", + prompt_token_lens=seq_lens, + sink_size=args.sink_size, + window_size=args.window_size, + ) + print( + f"[k3-integrated] cross-model recall={cross_res.recall:.3f} " + f"({cross_res.samples_correct}/{cross_res.samples_total})", + file=sys.stderr, + ) + + # ---------- Optional oracle baseline ---------- + oracle_res = None + oracle_mem = None + if not args.skip_oracle: + print("[k3-integrated] running full-attention oracle baseline", + file=sys.stderr, flush=True) + reset_memory_peak(device) + + def _oracle_step(cur): + with torch.no_grad(): + out = verifier(input_ids=cur, 
use_cache=False) + return out.logits[0, -1] + + oracle_decoded, oracle_lat, oracle_tok = _greedy(_oracle_step) + oracle_res = aggregate_recall( + "oracle", samples, oracle_decoded, oracle_lat, oracle_tok, + ) + oracle_mem = record_memory(device) + print( + f"[k3-integrated] oracle recall={oracle_res.recall:.3f} " + f"({oracle_res.samples_correct}/{oracle_res.samples_total})", + file=sys.stderr, + ) + + # ---------- Build report ---------- + recall_delta = ( + abs(cross_res.recall - oracle_res.recall) if oracle_res else None + ) + eff_frac_mean = cross_attn_agg.get("effective_attention_fraction_mean") + report = { + "schema_version": 2, + "kind": "k3_integrated_niah_acceptance", + "config": { + "verifier_id": args.verifier_id, + "drafter_id": args.drafter_id, + "f_theta_dir": args.f_theta_dir, + "f_theta_config": f_theta.config.to_json_dict(), + "n_samples": args.n_samples, + "sink_size": args.sink_size, + "window_size": args.window_size, + "haystack_min_lines": args.haystack_min_lines, + "haystack_max_lines": args.haystack_max_lines, + "max_new_tokens": args.max_new_tokens, + "seed": args.seed, + "skip_oracle": bool(args.skip_oracle), + "identity_restore": bool(args.identity_restore), + "s5_exact_full_attn": bool(args.s5_exact_full_attn), + "s5_exact_layers": exact_layers, + "prompt_token_lens": seq_lens, + }, + "results": { + "k3_cross_model": dataclasses.asdict(cross_res), + **({"oracle": dataclasses.asdict(oracle_res)} if oracle_res else {}), + }, + "attention_window": { + "per_config": {"k3_cross_model": cross_attn_agg}, + }, + "memory": { + "k3_cross_model": cross_mem, + **({"oracle": oracle_mem} if oracle_mem else {}), + }, + "gate": { + "architectural_correctness": (eff_frac_mean == 1.0), + "recall_cross_model": cross_res.recall, + "recall_oracle": oracle_res.recall if oracle_res else None, + "recall_delta_vs_oracle_pp": ( + recall_delta * 100 if recall_delta is not None else None + ), + "recall_delta_within_5pp": ( + recall_delta is not None and 
recall_delta <= 0.05 + ), + }, + } + + out_path = Path(args.output) if args.output else Path( + f"results/research/k3_integrated_niah_{int(time.time())}.json" + ) + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(report, indent=2)) + + print(f"\n[k3-integrated] DONE.", file=sys.stderr) + print( + f" cross-model recall: {cross_res.recall:.3f} " + f"({cross_res.samples_correct}/{cross_res.samples_total})", + file=sys.stderr, + ) + if oracle_res is not None: + print( + f" oracle recall: {oracle_res.recall:.3f} " + f"({oracle_res.samples_correct}/{oracle_res.samples_total})", + file=sys.stderr, + ) + print(f" |delta vs oracle|: {recall_delta * 100:.2f} pp", file=sys.stderr) + print( + f" ADR §11.8 1a gate (≤ 5pp): " + f"{'PASS' if recall_delta <= 0.05 else 'FAIL'}", + file=sys.stderr, + ) + else: + print(" oracle: skipped", file=sys.stderr) + print(f" effective_attention_fraction: {eff_frac_mean}", file=sys.stderr) + print(f" Report: {out_path}", file=sys.stderr) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/research/k3_integrated_niah_eval_mac.py b/scripts/research/k3_integrated_niah_eval_mac.py new file mode 100644 index 00000000..ba1b820f --- /dev/null +++ b/scripts/research/k3_integrated_niah_eval_mac.py @@ -0,0 +1,474 @@ +"""K3 integrated NIAH eval — **Mac (MLX) path**. + +Apple-Silicon counterpart of ``scripts/research/k3_integrated_niah_eval.py`` +(the validated CUDA K3 product gate). Wires: + + * verifier = Gemma 4 26B-A4B (MLX 4-bit, ``mlx_lm.load``) + * drafter = DFlash 0.4B (PyTorch ``DFlashDrafter``, MPS/CPU) + * f_θ = trained K/V proj (PyTorch ``FThetaProjection``) + * S5 = full-attention layers kept exact (``--s5-exact-full-attn``) + +Each generated token runs the **restored** verifier forward (sink+window +local cache + evicted-position K/V restoration), exactly mirroring the CUDA +``CrossModelDLMRestoredVerifier`` semantics. 
Cross-runtime tensors are +bridged via numpy (see ``scripts/research/k3_dflash_mlx_bridge.py``). + +Run on the Mac mini (Apple Silicon, ~24 GB): + + HF_TOKEN unnecessary for the local MLX 4-bit verifier. From repo root: + + PYTHONPATH=.:sdks/python python3 scripts/research/k3_integrated_niah_eval_mac.py \\ + --verifier-path models/gemma-4-26B-A4B-it-mlx-4bit \\ + --drafter-id z-lab/gemma-4-26B-A4B-it-DFlash \\ + --f-theta-dir results/research/f_theta_v5_s5_sliding \\ + --s5-exact-full-attn \\ + --n-samples 10 --haystack-min-lines 238 --haystack-max-lines 322 \\ + --sink-size 4 --window-size 64 --max-new-tokens 24 \\ + --output results/research/k3_s5_niah_ctx280_mac.json + +Quick sanity (smaller / faster): + + --n-samples 4 --haystack-min-lines 60 --haystack-max-lines 81 \\ + --max-new-tokens 16 + +Diagnostics: + --identity-restore restore ALL evicted K/V with the verifier's own true + K/V (should match oracle — validates the MLX + restoration machinery independent of f_θ / drafter). + +Output JSON mirrors the CUDA harness gate schema (recall_cross_model, +recall_oracle, recall_delta_vs_oracle_pp, architectural_correctness). 
+""" + +from __future__ import annotations + +import argparse +import dataclasses +import json +import math +import sys +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple + + +def parse_args() -> argparse.Namespace: + ap = argparse.ArgumentParser(description=__doc__) + ap.add_argument("--verifier-path", default="models/gemma-4-26B-A4B-it-mlx-4bit") + ap.add_argument("--drafter-id", default="z-lab/gemma-4-26B-A4B-it-DFlash") + ap.add_argument("--f-theta-dir", required=True) + ap.add_argument("--n-samples", type=int, default=10) + ap.add_argument("--haystack-min-lines", type=int, default=238) + ap.add_argument("--haystack-max-lines", type=int, default=322) + ap.add_argument("--sink-size", type=int, default=4) + ap.add_argument("--window-size", type=int, default=64) + ap.add_argument("--max-new-tokens", type=int, default=16) + ap.add_argument("--teacher-forced", action="store_true", + help="DIAGNOSTIC ONLY (under-measures retrieval): single " + "restored forward per sample, check argmax at the " + "needle-code span. Note this misses the model's " + "preamble so it reads ~0 even for the oracle — use " + "the default free-generation for a real recall " + "number. 
Free-gen oracle uses mlx's fast native " + "incremental cache; the restored cross path does a " + "full forward per token (slow on M4 — see notes).") + ap.add_argument("--seed", type=int, default=42) + ap.add_argument("--drafter-device", default="mps", + help="torch device for the DFlash drafter + f_θ (mps|cpu)") + ap.add_argument("--s5-exact-full-attn", action="store_true", + help="Keep full-attention layers' K/V exact (S5).") + ap.add_argument("--identity-restore", action="store_true", + help="Restore ALL evicted K/V with the verifier's own " + "true K/V (machinery check; should match oracle).") + ap.add_argument("--compress-full-attn", action="store_true", + help="KakeyaLattice-compress the exact full-attention " + "layers' K/V (lossy round-trip) to shrink the O(T) " + "linear term. Reports the compression ratio + recall " + "under compression.") + ap.add_argument("--kl-lattice", default="D4", choices=["D4", "E8"]) + ap.add_argument("--kl-q-range", type=int, default=38) + ap.add_argument("--skip-oracle", action="store_true") + ap.add_argument("--output", default=None) + return ap.parse_args() + + +def main() -> int: + args = parse_args() + + import mlx.core as mx # type: ignore + import mlx_lm # type: ignore + import torch + + from inference_engine.v04 import ( + DFlashDrafter, FThetaProjection, NIAHSample, + aggregate_recall, make_niah_dataset, recall_predicate, + ) + from inference_engine.v04.kv_merge import compute_evicted_positions + from inference_engine.backends.mlx.cross_model_dlm_verifier import ( + resolve_mlx_text_model, mlx_full_attention_layer_indices, + kv_source_layer_map, capture_own_kv, restored_logits, + per_layer_kv_geometry, kv_memory_report, + ) + from inference_engine.v04.kv_compressor import make_default_compressor + from scripts.research.k3_dflash_mlx_bridge import ( + mx_to_torch, torch_to_mx, + ) + + torch.manual_seed(args.seed) + dev = torch.device(args.drafter_device if ( + args.drafter_device == "cpu" or 
torch.backends.mps.is_available() + ) else "cpu") + + # ---------- Load verifier (MLX) ---------- + print(f"[mac] loading MLX verifier {args.verifier_path}", file=sys.stderr, flush=True) + mlx_model, tokenizer = mlx_lm.load(args.verifier_path) + text_model = resolve_mlx_text_model(mlx_model) + embed_scale = float(getattr(text_model, "embed_scale", 1.0)) + n_layers = len(text_model.layers) + full_attn_idx = mlx_full_attention_layer_indices(text_model) + src_map = kv_source_layer_map(text_model) + print(f"[mac] verifier layers={n_layers} full_attn={full_attn_idx}", file=sys.stderr) + + # ---------- Load drafter + f_θ (PyTorch) ---------- + print(f"[mac] loading drafter {args.drafter_id} on {dev}", file=sys.stderr, flush=True) + drafter = DFlashDrafter.from_pretrained(args.drafter_id, dtype=torch.float32) + drafter = drafter.to(dev).eval() + for p in drafter.parameters(): + p.requires_grad_(False) + f_theta = FThetaProjection.from_pretrained( + args.f_theta_dir, dtype=torch.float32, device=dev, + ) + fcfg = f_theta.config + + # ---------- Optional KakeyaLattice compression of full-attn layers ---------- + geom = per_layer_kv_geometry(text_model) + compressors: Dict[int, Any] = {} + kl_bits_per_head: Optional[float] = None + if args.compress_full_attn: + for li in full_attn_idx: + n_kv, hd, _ = geom[li] + comp = make_default_compressor( + head_dim=hd, device=torch.device("cpu"), + prefer_kakeya=True, lattice=args.kl_lattice, q_range=args.kl_q_range, + ) + compressors[li] = comp + codec = getattr(comp, "_codec", None) + if codec is not None and kl_bits_per_head is None: + kl_bits_per_head = float(getattr(codec, "bits_per_token_per_head", 0)) or None + print(f"[mac] KakeyaLattice compression ON for full-attn layers " + f"({args.kl_lattice} Q{args.kl_q_range}); " + f"bits/token/head={kl_bits_per_head}", file=sys.stderr) + + def _compress_roundtrip(li: int, k_mx: Any, v_mx: Any): + """Lossy KakeyaLattice round-trip of a full-attn layer's pre-norm K/V. 
+ mx [B,T,n_kv,hd] → torch [B,n_kv,T,hd] (positions=-2) → codec → back.""" + comp = compressors[li] + kt = mx_to_torch(k_mx, dtype=torch.float32, device="cpu").transpose(1, 2).contiguous() + vt = mx_to_torch(v_mx, dtype=torch.float32, device="cpu").transpose(1, 2).contiguous() + T = kt.shape[-2] + pos = torch.arange(T) + comp.compress(kt, vt, pos) + kh, vh = comp.decompress(pos) + comp.evict(pos) # keep state bounded between tokens + kh = kh.transpose(1, 2).contiguous() # [B,T,n_kv,hd] + vh = vh.transpose(1, 2).contiguous() + return torch_to_mx(kh), torch_to_mx(vh) + + # ---------- Drafter K/V capture (Mac): MLX embed → torch → drafter layers ---------- + def capture_drafter_kv(ids: List[int]): + ids_mx = mx.array([ids]) + emb_mx = text_model.embed_tokens(ids_mx) + emb_mx = emb_mx * embed_scale + embedded = mx_to_torch(emb_mx, dtype=torch.float32, device=dev) # [1,T,H] + layers = list(drafter.layers) + k_cap: List[Optional[torch.Tensor]] = [None] * len(layers) + v_cap: List[Optional[torch.Tensor]] = [None] * len(layers) + handles = [] + for i, layer in enumerate(layers): + a = layer.self_attn + handles.append(a.k_proj.register_forward_hook( + lambda m, inp, out, i=i: k_cap.__setitem__(i, out.detach()))) + handles.append(a.v_proj.register_forward_hook( + lambda m, inp, out, i=i: v_cap.__setitem__(i, out.detach()))) + try: + with torch.no_grad(): + T = embedded.size(1) + qpos = torch.arange(T, device=dev) + h = embedded + for layer in layers: + h = layer(h, qpos, ctx_k=None, ctx_v=None) + finally: + for hh in handles: + hh.remove() + dh, ddim = fcfg.drafter_num_kv_heads, fcfg.drafter_head_dim + d_k = [k_cap[i].view(1, -1, dh, ddim) for i in range(len(layers))] + d_v = [v_cap[i].view(1, -1, dh, ddim) for i in range(len(layers))] + return d_k, d_v + + # ---------- Per-sample restoration (amortized: captured ONCE over the + # prompt, reused for all decode steps). 
The evicted positions are the + # fixed prompt mid-context; with <= window generated tokens nothing else + # is evicted, so the prompt's restored K/V cover every injected slot. + exact_set = set(range(n_layers)) if args.identity_restore else set(full_attn_idx) + + def build_restoration(prompt_ids: List[int]): + d_k, d_v = capture_drafter_kv(prompt_ids) + with torch.no_grad(): + vk, vv = f_theta.forward_kv_pack(d_k, d_v) + own = None + if exact_set: + own = capture_own_kv(mlx_model, prompt_ids) + rk: Dict[int, Any] = {} + rv: Dict[int, Any] = {} + for li in range(n_layers): + if src_map[li] != li: + continue + if li in exact_set and own is not None and li in own: + k_mx, v_mx = own[li] + if li in compressors: + k_mx, v_mx = _compress_roundtrip(li, k_mx, v_mx) + rk[li], rv[li] = k_mx, v_mx + else: + rk[li], rv[li] = torch_to_mx(vk[li]), torch_to_mx(vv[li]) + return rk, rv, len(prompt_ids) + + def _pad(rdict, t_src, t_dst): + if t_dst <= t_src: + return rdict + out = {} + for li, a in rdict.items(): + pad = mx.zeros((a.shape[0], t_dst - t_src, a.shape[2], a.shape[3]), dtype=a.dtype) + out[li] = mx.concatenate([a, pad], axis=1) + return out + + def restored_forward(ids: List[int], rk, rv, t_src, *, return_all: bool): + T = len(ids) + evicted = compute_evicted_positions(T, args.sink_size, args.window_size) + if not evicted: + out = mlx_model(mx.array([ids])); mx.eval(out) + return out[0] if return_all else out[0, -1] + return restored_logits( + mlx_model, ids, + restored_k_per_layer=_pad(rk, t_src, T), + restored_v_per_layer=_pad(rv, t_src, T), + evicted_positions=evicted, return_all=return_all, + ) + + # ---------- Dataset ---------- + samples: List[NIAHSample] = make_niah_dataset( + n_samples=args.n_samples, + haystack_min_lines=args.haystack_min_lines, + haystack_max_lines=args.haystack_max_lines, + seed=args.seed, + ) + + def encode(prompt_text: str) -> List[int]: + msgs = [{"role": "user", "content": prompt_text}] + ids = tokenizer.apply_chat_template(msgs, 
add_generation_prompt=True) + if hasattr(ids, "tolist"): + ids = ids.tolist() + return list(ids) + + def encode_answer(answer_text: str) -> List[int]: + try: + aid = tokenizer.encode(answer_text, add_special_tokens=False) + except TypeError: + aid = list(tokenizer.encode(answer_text)) + bos = getattr(tokenizer, "bos_token_id", None) + if aid and bos is not None and aid[0] == bos: + aid = aid[1:] + return list(aid) + + sample_ids = [encode(s.prompt_text) for s in samples] + answer_ids = [encode_answer(s.answer_text) for s in samples] + seq_lens = [len(t) for t in sample_ids] + eos_id = getattr(tokenizer, "eos_token_id", None) + print(f"[mac] {len(samples)} samples, prompt len " + f"min={min(seq_lens)} max={max(seq_lens)}", file=sys.stderr) + + def eval_teacher_forced(logits_all_fn) -> Tuple[List[str], List[float], List[int]]: + """One restored forward per sample over [prompt + needle-code]; check + the argmax at the answer span reproduces the code (substring predicate + — same as CUDA). O(T) per sample, no autoregressive loop.""" + decoded, lats, toks = [], [], [] + for i, pid in enumerate(sample_ids): + aid = answer_ids[i] or [eos_id or 0] + full = pid + aid + t0 = time.perf_counter() + logits_all = logits_all_fn(pid, full) # [T_full, V] + Tp = len(pid) + preds = [int(mx.argmax(logits_all[Tp - 1 + j]).item()) + for j in range(len(aid))] + lats.append(time.perf_counter() - t0) + decoded.append(tokenizer.decode(preds)) + toks.append(len(aid)) + print(f"[mac] sample {i}: T={seq_lens[i]} pred[:48]={decoded[-1][:48]!r}", + file=sys.stderr) + return decoded, lats, toks + + def eval_free_gen_cross() -> Tuple[List[str], List[float], List[int]]: + """Restored free generation: 1 restored full forward per token + (amortized restoration). 
Correct recall metric; slow on M4.""" + decoded, lats, toks = [], [], [] + for i, pid in enumerate(sample_ids): + rk, rv, tsrc = build_restoration(pid) + cur = list(pid); gen: List[int] = [] + t0 = time.perf_counter() + for _ in range(args.max_new_tokens): + last = restored_forward(cur, rk, rv, tsrc, return_all=False) + nxt = int(mx.argmax(last).item()); gen.append(nxt) + if eos_id is not None and nxt == eos_id: + break + cur.append(nxt) + lats.append(time.perf_counter() - t0) + decoded.append(tokenizer.decode(gen)); toks.append(len(gen)) + print(f"[mac] sample {i}: T={seq_lens[i]} -> {decoded[-1][:48]!r}", + file=sys.stderr) + return decoded, lats, toks + + def eval_free_gen_oracle() -> Tuple[List[str], List[float], List[int]]: + """Oracle free generation using mlx's NATIVE incremental KV cache + (fast + correct reference; confirms the metric/dataset).""" + decoded, lats, toks = [], [], [] + make_cache = getattr(mlx_model, "make_cache", None) + for i, pid in enumerate(sample_ids): + cache = make_cache() if make_cache is not None else None + t0 = time.perf_counter() + out = mlx_model(mx.array([pid]), cache=cache); mx.eval(out) + tok = int(mx.argmax(out[0, -1]).item()); gen = [tok] + for _ in range(args.max_new_tokens - 1): + if eos_id is not None and tok == eos_id: + break + out = mlx_model(mx.array([[tok]]), cache=cache); mx.eval(out) + tok = int(mx.argmax(out[0, -1]).item()); gen.append(tok) + lats.append(time.perf_counter() - t0) + decoded.append(tokenizer.decode(gen)); toks.append(len(gen)) + print(f"[mac] oracle {i}: T={seq_lens[i]} -> {decoded[-1][:48]!r}", + file=sys.stderr) + return decoded, lats, toks + + def cross_logits_all(prompt_ids, full_ids): + rk, rv, tsrc = build_restoration(prompt_ids) + return restored_forward(full_ids, rk, rv, tsrc, return_all=True) + + def oracle_logits_all(prompt_ids, full_ids): + out = mlx_model(mx.array([full_ids])); mx.eval(out); return out[0] + + label = "identity" if args.identity_restore else ( + "s5" if 
args.s5_exact_full_attn else "f_theta_all") + eval_mode = "teacher_forced" if args.teacher_forced else "free_gen" + print(f"[mac] running restored cross-model verifier ({label}, {eval_mode})", + file=sys.stderr, flush=True) + if args.teacher_forced: + cross_dec, cross_lat, cross_tok = eval_teacher_forced(cross_logits_all) + else: + cross_dec, cross_lat, cross_tok = eval_free_gen_cross() + cross_res = aggregate_recall("k3_cross_model_mac", samples, cross_dec, cross_lat, cross_tok) + print(f"[mac] cross-model recall = {cross_res.recall:.3f} " + f"({cross_res.samples_correct}/{cross_res.samples_total})", file=sys.stderr) + + oracle_res = None + if not args.skip_oracle: + print("[mac] running oracle", file=sys.stderr, flush=True) + if args.teacher_forced: + o_dec, o_lat, o_tok = eval_teacher_forced(oracle_logits_all) + else: + o_dec, o_lat, o_tok = eval_free_gen_oracle() # fast native incremental + oracle_res = aggregate_recall("oracle_mac", samples, o_dec, o_lat, o_tok) + print(f"[mac] oracle recall = {oracle_res.recall:.3f}", file=sys.stderr) + + # ---------- KV-memory accounting (bounded S5 engine) ---------- + T_max = max(seq_lens) + exact_for_mem = full_attn_idx # S5: full-attn layers kept exact / compressed + mem_s5 = kv_memory_report( + text_model, sink_size=args.sink_size, window_size=args.window_size, + seq_len=T_max, exact_layer_indices=exact_for_mem, + compress_full_bits_per_token_per_head=( + kl_bits_per_head if args.compress_full_attn else None), + ) + # Baselines for the savings story: + mem_naive = kv_memory_report( # all layers O(T), no bound, no compress + text_model, sink_size=T_max, window_size=0, seq_len=T_max, + exact_layer_indices=list(range(n_layers))) + print(f"[mac] KV resident @T={T_max}: S5={mem_s5['total_resident_mb']} MB " + f"(growth {mem_s5['per_token_growth_kb']} KB/tok); " + f"naive-full={mem_naive['total_resident_mb']} MB", file=sys.stderr) + + # ---------- Throughput ---------- + def _tps(lats, toks): + tot_t = sum(lats) + tot_n = 
sum(toks) + return { + "tokens": tot_n, "wall_seconds": round(tot_t, 3), + "tokens_per_second": round(tot_n / tot_t, 4) if tot_t > 0 else None, + "mean_latency_per_sample_s": round(tot_t / max(len(lats), 1), 3), + } + cross_tps = _tps(cross_lat, cross_tok) + cross_tps["eval_mode"] = eval_mode + cross_tps["restored_forwards_per_sample"] = ( + 1 if args.teacher_forced else args.max_new_tokens) + print(f"[mac] cross-model throughput ({eval_mode}): " + f"{cross_tps['tokens_per_second']} tok/s " + f"({cross_tps['tokens']} tok / {cross_tps['wall_seconds']} s, " + f"{cross_tps['mean_latency_per_sample_s']} s/sample)", file=sys.stderr) + + delta = (abs(cross_res.recall - oracle_res.recall) if oracle_res else None) + report = { + "schema_version": 1, + "kind": "k3_integrated_niah_acceptance_mac", + "config": { + "verifier_path": args.verifier_path, + "drafter_id": args.drafter_id, + "f_theta_dir": args.f_theta_dir, + "n_samples": args.n_samples, + "sink_size": args.sink_size, + "window_size": args.window_size, + "haystack_min_lines": args.haystack_min_lines, + "haystack_max_lines": args.haystack_max_lines, + "max_new_tokens": args.max_new_tokens, + "seed": args.seed, + "eval_mode": eval_mode, + "teacher_forced": bool(args.teacher_forced), + "s5_exact_full_attn": bool(args.s5_exact_full_attn), + "identity_restore": bool(args.identity_restore), + "compress_full_attn": bool(args.compress_full_attn), + "kl_lattice": args.kl_lattice if args.compress_full_attn else None, + "kl_q_range": args.kl_q_range if args.compress_full_attn else None, + "kl_bits_per_token_per_head": kl_bits_per_head, + "full_attention_layers": full_attn_idx, + "prompt_token_lens": seq_lens, + }, + "results": { + "k3_cross_model": dataclasses.asdict(cross_res), + **({"oracle": dataclasses.asdict(oracle_res)} if oracle_res else {}), + }, + "gate": { + "recall_cross_model": cross_res.recall, + "recall_oracle": oracle_res.recall if oracle_res else None, + "recall_delta_vs_oracle_pp": (delta * 100 if delta is not 
None else None), + "recall_delta_within_5pp": (delta is not None and delta <= 0.05), + }, + "memory": { + "s5": mem_s5, + "naive_full_kv": { + "total_resident_mb": mem_naive["total_resident_mb"], + "per_token_growth_kb": mem_naive["per_token_growth_kb"], + }, + "savings_vs_naive_pct": round( + 100 * (1 - mem_s5["total_resident_bytes"] + / max(mem_naive["total_resident_bytes"], 1)), 1), + }, + "throughput": {"k3_cross_model": cross_tps}, + } + out_path = Path(args.output) if args.output else Path( + f"results/research/k3_integrated_niah_mac_{int(time.time())}.json") + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(report, indent=2)) + print(f"\n[mac] DONE. cross={cross_res.recall:.3f} " + f"oracle={oracle_res.recall if oracle_res else 'skipped'} " + f"-> {out_path}", file=sys.stderr) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/research/k3_specdecode_gpu_bench.py b/scripts/research/k3_specdecode_gpu_bench.py new file mode 100644 index 00000000..bed07520 --- /dev/null +++ b/scripts/research/k3_specdecode_gpu_bench.py @@ -0,0 +1,480 @@ +"""K3 speculative-decoding GPU bench for the *restored* verifier. + +Re-examines the spec-decode path for the Kakeya inference engine and +measures, on the same NIAH prompts: + + * **AR-incremental** — standalone Gemma 4 26B AR with the model's own KV + cache (the throughput target). + * **restored-pertoken** — the restored verifier decoded one token at a + time (the naive baseline; what k3_e2e_gpu_bench used). + * **restored-specdecode** — DFlash drafts a block, the **restored** + verifier verifies it in one pass, greedily accepting the longest + matching prefix (block-amortized verifier forwards). + +Reports per path: decode tok/s, verifier forward passes, and (for +spec-decode) acceptance length; plus NIAH recall (correctness). 
# --------------------------------------------------------------------------- #
# DFlash wiring helpers (mirrors scripts/research/k3_dflash_specdecode_eval.py)
# --------------------------------------------------------------------------- #
def _build_embed_lm_head(model, hidden_size, softcap):
    """Return ``(embed_fn, lm_head_fn)`` closures over the model's raw weights.

    The closures call ``F.embedding`` and a plain matmul on the detached
    weight tensors instead of the module forwards, so nothing attached to
    the modules' ``forward`` runs on the drafter's hot path. The embedding
    lookup is unscaled.

    Args:
        model: Object exposing ``get_input_embeddings()`` and
            ``get_output_embeddings()``, each returning a module with a
            ``weight`` tensor.
        hidden_size: Accepted for signature parity; not used by this function.
        softcap: Final-logit soft cap, or ``None`` to return raw logits.

    Returns:
        ``embed_fn(ids) -> float32 embeddings`` and
        ``lm_head_fn(hidden) -> float32 logits``.
    """
    embedding_table = model.get_input_embeddings().weight.detach()
    output_table = model.get_output_embeddings().weight.detach()

    def embed_fn(ids: torch.Tensor) -> torch.Tensor:
        looked_up = F.embedding(ids, embedding_table)
        return looked_up.float()

    def lm_head_fn(h: torch.Tensor) -> torch.Tensor:
        # Project in the head's own dtype, then widen to float32.
        projected = torch.matmul(h.to(output_table.dtype), output_table.t()).float()
        if softcap is None:
            return projected
        return softcap * torch.tanh(projected / softcap)

    return embed_fn, lm_head_fn


@torch.no_grad()
def ar_incremental(model, ids, gen_tokens, device) -> Tuple[List[int], float, int]:
    """Greedy autoregressive decode using the model's own KV cache.

    The prompt prefill happens before the clock starts, so only decode steps
    are timed. Each iteration records the pending token and then runs one
    single-token forward; the prediction of the final forward is discarded.

    Returns:
        ``(tokens, decode_seconds, decode_forward_count)``.
    """
    prefill_out = model(input_ids=ids, use_cache=True)
    kv_cache = prefill_out.past_key_values
    pending = int(prefill_out.logits[0, -1].argmax().item())
    step_input = torch.tensor([[pending]], device=device, dtype=torch.long)
    emitted: List[int] = []
    n_forward = 0

    torch.cuda.synchronize(device)
    started = time.perf_counter()
    for _step in range(gen_tokens):
        emitted.append(pending)
        step_out = model(input_ids=step_input, past_key_values=kv_cache, use_cache=True)
        n_forward += 1
        kv_cache = step_out.past_key_values
        pending = int(step_out.logits[0, -1].argmax().item())
        step_input = torch.tensor([[pending]], device=device, dtype=torch.long)
    torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - started
    return emitted, elapsed, n_forward


@torch.no_grad()
def restored_pertoken(adapter, prompt, gen_tokens, device) -> Tuple[List[int], float, int]:
    """Greedy decode through the restored verifier, one token at a time.

    ``adapter.prefill`` runs outside the timed region. The third element of
    the result is ``gen_tokens``, i.e. one ``append_token`` call per timed
    step (the prefill is not counted).

    Returns:
        ``(tokens, decode_seconds, gen_tokens)``.
    """
    adapter.prefill(prompt)
    pending = int(adapter.next_token_logits.argmax().item())
    emitted: List[int] = []

    torch.cuda.synchronize(device)
    started = time.perf_counter()
    for _step in range(gen_tokens):
        emitted.append(pending)
        adapter.append_token(pending)
        pending = int(adapter.next_token_logits.argmax().item())
    torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - started
    return emitted, elapsed, gen_tokens
Reports a per-component time + breakdown (aux / draft / verify) to expose the bottleneck.""" + assert adapter._incremental, "restored_specdecode needs incremental=True (Gap-A)" + adapter.prefill(prompt) # builds the restored KV cache once + generated: List[int] = [] + accepts: List[int] = [] + t_aux = t_draft = t_verify = 0.0 + torch.cuda.synchronize(device) + t0 = time.perf_counter() + while len(generated) < gen_tokens: + L = min(block_size, gen_tokens - len(generated)) + # DFlash drafts using the *clean* verifier aux hidden (EAGLE) + bonus. + ta = time.perf_counter() + aux_ctx, bonus = provider.aux_hidden_context(adapter._committed) + torch.cuda.synchronize(device); t_aux += time.perf_counter() - ta + td = time.perf_counter() + drafts = drafter.draft_block(aux_ctx, bonus, embed_fn, lm_head_fn, block_size=L) + torch.cuda.synchronize(device); t_draft += time.perf_counter() - td + candidate = [bonus] + drafts[: L - 1] if L > 1 else [bonus] + # Verify with the INCREMENTAL restored verifier (O(L)). 
@torch.no_grad()
def restored_specdecode_fused(
    adapter, drafter, verifier, aux_layer_ids, embed_fn, lm_head_fn,
    prompt, gen_tokens, block_size, device, eos_ids,
) -> Dict[str, Any]:
    """FUSED spec-decode engine (A+B+C): per-block O(L).

    * C (Gap-A): incremental restored verify (adapter, O(L)).
    * B: drafter context K/V cache — built once from the prompt's clean aux,
      then EXTENDED incrementally with each newly-committed token's aux
      (no O(C) recompute per block).
    * A: the newly-committed tokens' aux hidden are captured from the verify
      forward itself (restored hidden) — no separate per-block clean-aux O(C)
      forward.

    Args:
        adapter: Incremental restored verifier. This function uses its
            ``prefill``, ``forward_block``, ``commit_or_truncate`` and
            ``append_token`` methods and its ``next_token_logits``,
            ``_past_len``, ``_last_aux`` and ``_capture_aux`` attributes.
        drafter: Drafter exposing ``make_context_kv``, ``extend_context_kv``
            and ``draft_block_cached``.
        verifier: Causal LM called once on the full prompt with
            ``output_hidden_states=True`` to obtain the prompt's aux hidden.
        aux_layer_ids: Indices into ``hidden_states`` used as drafter context.
        embed_fn, lm_head_fn: Forwarded unchanged to the drafter.
        prompt: Prompt token ids (a list of ints).
        gen_tokens: Maximum number of tokens to return.
        block_size: Maximum candidate length per block.
        device: CUDA device used for tensors and ``torch.cuda.synchronize``.
        eos_ids: Container of token ids that end generation.

    Returns:
        Dict with the generated ``tokens`` (truncated to ``gen_tokens``),
        ``decode_s`` / ``prefill_s`` wall times, ``decode_tokens_per_s``, a
        per-component ``time_breakdown_s``, the number of ``blocks``,
        ``mean_accept_len`` and ``decode_tokens``.

    Note:
        ``adapter._capture_aux`` is reset to ``False`` in a ``finally`` block.
        The adapter is shared with the other measured paths, so an exception
        raised mid-decode must not leave it in aux-capture mode.
    """
    n_aux = len(aux_layer_ids)
    C = len(prompt)
    # --- one-time prefill: clean prompt aux -> drafter context K/V cache (B) ---
    t_prefill = time.perf_counter()
    ids = torch.tensor([prompt], dtype=torch.long, device=device)
    out = verifier(input_ids=ids, use_cache=False, output_hidden_states=True)
    aux_prompt = [out.hidden_states[a] for a in aux_layer_ids]  # each [1, C, hidden]
    ctx_kv = drafter.make_context_kv(aux_prompt, torch.arange(C, device=device))
    adapter.prefill(prompt)  # restored KV cache (C) + next_token_logits

    generated: List[int] = []
    accepts: List[int] = []
    t_draft = t_verify = t_extend = 0.0

    adapter._capture_aux = True
    try:
        torch.cuda.synchronize(device)
        t_prefill = time.perf_counter() - t_prefill

        torch.cuda.synchronize(device)
        t0 = time.perf_counter()
        while len(generated) < gen_tokens:
            L = min(block_size, gen_tokens - len(generated))
            cstart = adapter._past_len  # committed length at block start
            bonus = int(adapter.next_token_logits.argmax().item())

            td = time.perf_counter()
            drafts = drafter.draft_block_cached(
                ctx_kv, bonus, embed_fn, lm_head_fn,
                block_size=max(L - 1, 1), context_len=cstart)
            torch.cuda.synchronize(device)
            t_draft += time.perf_counter() - td
            # When L == 1 the slice is empty and only the bonus is verified.
            candidate = [bonus] + drafts[: L - 1]

            tv = time.perf_counter()
            prev = adapter.next_token_logits
            block_logits = adapter.forward_block(candidate)  # O(L) verify + aux capture
            cand_aux = adapter._last_aux  # [n_aux][len(candidate), hidden]
            # Greedy acceptance: keep candidates while the verifier's argmax
            # for that slot equals the proposed token.
            accepted = 0
            for i, proposed in enumerate(candidate):
                if int(prev.argmax().item()) != proposed:
                    break
                accepted += 1
                prev = block_logits[i]
            # First token the verifier disagrees on, or the token after a
            # fully accepted block.
            correction = int(prev.argmax().item())
            adapter.commit_or_truncate(forwarded=len(candidate), accepted=accepted)
            adapter.append_token(correction)  # commit correction + aux capture
            corr_aux = adapter._last_aux  # [n_aux][1, hidden]
            torch.cuda.synchronize(device)
            t_verify += time.perf_counter() - tv

            # --- extend drafter context K/V with the newly-committed tokens (B) ---
            te = time.perf_counter()
            new_positions = torch.arange(cstart, cstart + accepted + 1, device=device)
            new_aux = [
                torch.cat([cand_aux[li][:accepted], corr_aux[li][:1]], dim=0).unsqueeze(0)
                for li in range(n_aux)
            ]  # each [1, accepted+1, hidden]
            ctx_kv = drafter.extend_context_kv(
                ctx_kv, drafter.make_context_kv(new_aux, new_positions))
            torch.cuda.synchronize(device)
            t_extend += time.perf_counter() - te

            commit = candidate[:accepted] + [correction]
            generated += commit
            accepts.append(accepted)
            if any(t in eos_ids for t in commit):
                break
        torch.cuda.synchronize(device)
        dt = time.perf_counter() - t0
    finally:
        # Always leave the shared adapter out of aux-capture mode, even when
        # prefill or a block raised.
        adapter._capture_aux = False

    # A block commits accepted+1 tokens, so the last one can overshoot.
    generated = generated[:gen_tokens]
    return {
        "tokens": generated,
        "decode_s": dt,
        "prefill_s": round(t_prefill, 3),
        "decode_tokens_per_s": round(len(generated) / dt, 3) if dt > 0 else None,
        "time_breakdown_s": {
            "drafter_cached": round(t_draft, 3),
            "incremental_verify": round(t_verify, 3),
            "ctx_kv_extend": round(t_extend, 3),
        },
        "blocks": len(accepts),
        "mean_accept_len": round(sum(accepts) / len(accepts), 2) if accepts else 0.0,
        "decode_tokens": len(generated),
    }
characterized; removes GPU contention for a " + "clean fused-vs-AR steady-state measurement).") + ap.add_argument("--output", default=None) + args = ap.parse_args() + + if not torch.cuda.is_available(): + print("[sd] CUDA required.", file=sys.stderr) + return 2 + device = torch.device("cuda") + dtype = torch.bfloat16 + + from transformers import AutoModelForCausalLM, AutoTokenizer + from transformers.models.gemma4.modeling_gemma4 import ( # type: ignore + ALL_ATTENTION_FUNCTIONS, apply_rotary_pos_emb, eager_attention_forward, + ) + from inference_engine.v04 import ( + CrossModelRestoredSinkWindowVerifier, DFlashDrafter, FThetaProjection, + make_niah_dataset, + ) + from inference_engine.v04.cross_model_dlm_verifier import ( + CrossModelDLMRestoredVerifier, full_attention_layer_indices, + ) + from inference_engine.v04.dflash_drafter import DFlashProposer # noqa: F401 (kept for parity) + + class VerifierAuxProvider: + def __init__(self, model, aux_layer_ids, device): + self.model = model + self.aux_layer_ids = aux_layer_ids + self.device = device + + @torch.no_grad() + def aux_hidden_context(self, committed_token_ids): + inp = torch.tensor([committed_token_ids], dtype=torch.long, device=self.device) + out = self.model(input_ids=inp, use_cache=False, output_hidden_states=True) + hs = out.hidden_states + aux = [hs[a].float() for a in self.aux_layer_ids] + bonus = int(torch.argmax(out.logits[0, -1]).item()) + return aux, bonus + + print(f"[sd] loading verifier {args.verifier_id}", file=sys.stderr, flush=True) + tok = AutoTokenizer.from_pretrained(args.verifier_id) + # Load WITHOUT device_map (the model fits on a single H200): device_map + # wraps every module in accelerate AlignDevicesHook, adding variable + # per-forward dispatch latency that inflated/destabilized the drafter's + # per-block embed/lm_head calls. A plain .to(device) is hook-free. 
+ verifier = AutoModelForCausalLM.from_pretrained( + args.verifier_id, dtype=dtype, attn_implementation="eager", + ).to(device).eval() + for p in verifier.parameters(): + p.requires_grad_(False) + print(f"[sd] loading drafter {args.drafter_id}", file=sys.stderr, flush=True) + drafter = DFlashDrafter.from_pretrained(args.drafter_id, dtype=dtype).to(device).eval() + for p in drafter.parameters(): + p.requires_grad_(False) + print(f"[sd] loading f_θ {args.f_theta_dir}", file=sys.stderr, flush=True) + f_theta = FThetaProjection.from_pretrained(args.f_theta_dir, dtype=torch.float32, device=device) + + exact_layers = full_attention_layer_indices(verifier) + restored = CrossModelDLMRestoredVerifier( + verifier_model=verifier, drafter=drafter, f_theta=f_theta, + sink_size=args.sink, window_size=args.window, exact_layer_indices=exact_layers, + ) + adapter = CrossModelRestoredSinkWindowVerifier( + restored, apply_rotary_pos_emb=apply_rotary_pos_emb, + eager_attention_forward=eager_attention_forward, + all_attention_functions=ALL_ATTENTION_FUNCTIONS, device="cuda", + incremental=True, # Gap-A: O(L)/block incremental verify + ) + cfg = drafter.cfg + embed_fn, lm_head_fn = _build_embed_lm_head(verifier, cfg.hidden_size, cfg.final_logit_softcapping) + provider = VerifierAuxProvider(verifier, cfg.aux_layer_ids, device) + eos_ids = set(x for x in [tok.eos_token_id] if x is not None) + + def encode_chat(text): + ids = tok.apply_chat_template( + [{"role": "user", "content": text}], + add_generation_prompt=True, tokenize=True, return_tensors="pt") + if hasattr(ids, "keys"): + ids = ids["input_ids"] + return ids.to(device) + + samples = make_niah_dataset( + n_samples=args.n_samples, haystack_min_lines=args.haystack_lines, + haystack_max_lines=args.haystack_lines, seed=args.seed) + ids_list = [encode_chat(s.prompt_text) for s in samples] + seqlens = [int(t.size(1)) for t in ids_list] + print(f"[sd] prompt tokens min={min(seqlens)} max={max(seqlens)}", file=sys.stderr) + + def 
recall(tokens, ans): + return ans in tok.decode(tokens, skip_special_tokens=True) + + aux_layer_ids = drafter.cfg.aux_layer_ids + + # Warm up CUDA kernels for every measured path on the first prompt (a few + # tokens, discarded) so the timed samples reflect steady state, not the + # one-off kernel-compile cost (which otherwise inflates the first sample). + print("[sd] warmup ...", file=sys.stderr, flush=True) + _wp = ids_list[0][0].tolist() + try: + # Warm with the FULL gen length so the caching allocator pre-sizes + # every context-growth shape the timed samples will hit (otherwise the + # first sample eats first-time cudaMalloc for the long-context drafter + # attention buffers). Two passes to settle clocks/autotuning. + for _ in range(2): + ar_incremental(verifier, ids_list[0], args.max_new_tokens, device) + restored_pertoken(adapter, _wp, args.max_new_tokens, device) + restored_specdecode_fused(adapter, drafter, verifier, aux_layer_ids, + embed_fn, lm_head_fn, _wp, + args.max_new_tokens, args.block_size, + device, eos_ids) + except Exception as e: + print(f"[sd] warmup note: {e}", file=sys.stderr) + + ar_tps: List[float] = [] + pt_tps: List[float] = [] + sd_rows: List[Dict[str, Any]] = [] + fu_rows: List[Dict[str, Any]] = [] + ar_hits = pt_hits = sd_hits = fu_hits = 0 + for i, ids in enumerate(ids_list): + ans = samples[i].answer_text + prompt = ids[0].tolist() + g_ar, t_ar, _ = ar_incremental(verifier, ids, args.max_new_tokens, device) + ar_tps.append(len(g_ar) / t_ar) + ar_hits += int(recall(g_ar, ans)) + g_pt, t_pt, _ = restored_pertoken(adapter, prompt, args.max_new_tokens, device) + pt_tps.append(len(g_pt) / t_pt) + pt_hits += int(recall(g_pt, ans)) + if args.skip_unfused: + sd = {"decode_tokens_per_s": None, "mean_accept_len": 0.0, + "time_breakdown_s": {"aux_clean_forward": 0.0, "drafter": 0.0, + "incremental_verify": 0.0}, "tokens": []} + else: + sd = restored_specdecode( + adapter, drafter, provider, embed_fn, lm_head_fn, + prompt, args.max_new_tokens, 
args.block_size, device, eos_ids) + sd_hits += int(recall(sd["tokens"], ans)) + sd_rows.append(sd) + fu = restored_specdecode_fused( + adapter, drafter, verifier, aux_layer_ids, embed_fn, lm_head_fn, + prompt, args.max_new_tokens, args.block_size, device, eos_ids) + fu_rows.append(fu) + fu_hits += int(recall(fu["tokens"], ans)) + print(f"[sd] sample {i}: AR={ar_tps[-1]:.2f} | pertoken(GapA)={pt_tps[-1]:.2f} | " + f"specdecode(unfused)={sd['decode_tokens_per_s']} | " + f"FUSED={fu['decode_tokens_per_s']} tok/s " + f"(accept_len={fu['mean_accept_len']}, blocks={fu['blocks']}, " + f"draft={fu['time_breakdown_s']['drafter_cached']}s " + f"verify={fu['time_breakdown_s']['incremental_verify']}s " + f"ext={fu['time_breakdown_s']['ctx_kv_extend']}s) | recall ar/pt/sd/fused=" + f"{recall(g_ar, ans)}/{recall(g_pt, ans)}/{recall(sd['tokens'], ans)}/" + f"{recall(fu['tokens'], ans)}", file=sys.stderr, flush=True) + + n = len(ids_list) + report = { + "kind": "k3_specdecode_gpu_bench", + "config": vars(args), + "env": {"gpu": torch.cuda.get_device_name(0), "torch": torch.__version__}, + "prompt_tokens": {"min": min(seqlens), "max": max(seqlens)}, + "ar_incremental": { + "decode_tokens_per_s_mean": round(sum(ar_tps) / n, 3), "recall": round(ar_hits / n, 3)}, + "restored_pertoken": { + "decode_tokens_per_s_mean": round(sum(pt_tps) / n, 3), "recall": round(pt_hits / n, 3)}, + "restored_specdecode": { + "skipped": bool(args.skip_unfused), + "decode_tokens_per_s_mean": (None if args.skip_unfused else round( + sum(r["decode_tokens_per_s"] for r in sd_rows) / n, 3)), + "mean_accept_len": round(sum(r["mean_accept_len"] for r in sd_rows) / n, 2), + "recall": round(sd_hits / n, 3), + "per_sample": sd_rows, + }, + "restored_specdecode_fused": { + "decode_tokens_per_s_mean": round( + sum(r["decode_tokens_per_s"] for r in fu_rows) / n, 3), + "mean_accept_len": round(sum(r["mean_accept_len"] for r in fu_rows) / n, 2), + "time_breakdown_s_mean": { + k: round(sum(r["time_breakdown_s"][k] for r 
in fu_rows) / n, 3) + for k in ("drafter_cached", "incremental_verify", "ctx_kv_extend") + }, + "recall": round(fu_hits / n, 3), + "per_sample": fu_rows, + }, + } + ar_mean = report["ar_incremental"]["decode_tokens_per_s_mean"] + pt_mean = report["restored_pertoken"]["decode_tokens_per_s_mean"] + sd_tps = report["restored_specdecode"]["decode_tokens_per_s_mean"] + fu_tps = report["restored_specdecode_fused"]["decode_tokens_per_s_mean"] + report["restored_specdecode_fused"]["speedup_over_ar_x"] = ( + round(fu_tps / ar_mean, 2) if ar_mean else None) + out_path = Path(args.output) if args.output else Path( + f"results/research/k3_specdecode_gpu_bench_{int(time.time())}.json") + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(report, indent=2)) + print(f"\n[sd] AR={ar_mean} | pertoken(GapA)={pt_mean} | " + f"specdecode(unfused)={sd_tps} | FUSED={fu_tps} tok/s " + f"(fused/AR {report['restored_specdecode_fused']['speedup_over_ar_x']}x, " + f"accept_len={report['restored_specdecode_fused']['mean_accept_len']}) | " + f"recall ar/pt/sd/fused={report['ar_incremental']['recall']}/" + f"{report['restored_pertoken']['recall']}/{report['restored_specdecode']['recall']}/" + f"{report['restored_specdecode_fused']['recall']}", file=sys.stderr) + print(f"[sd] wrote {out_path}", file=sys.stderr) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/review_pr_k3_f_theta_train_on_vast.sh b/scripts/review_pr_k3_f_theta_train_on_vast.sh new file mode 100755 index 00000000..077a1947 --- /dev/null +++ b/scripts/review_pr_k3_f_theta_train_on_vast.sh @@ -0,0 +1,241 @@ +#!/usr/bin/env bash +# vast.ai (CUDA) reviewer aid for K3 Block C — f_θ K/V projection training. +# +# v3 (2026-06-10) — ONE-SHOT principled trainer. 
+# - --loss-type attn_distill (attention-output distillation — the +# mathematically right loss for K/V +# projection; v1 was raw MSE on K/V, +# v2 intermediate was cos+mag) +# - --rank 768 (3× v1's 256 capacity at f_θ bottleneck) +# - --steps 20000 (5× v1; v1 was 4k → 59s, undertrained) +# - --gen-len 512 (4× v1; v1 was 128) +# - --lr-schedule cosine (linear warmup → cosine decay to peak/100) +# - +64 NIAH-style synthetic prompts (v1 had zero retrieval data) +# v1 reproduction: STEPS=4000 GEN_LEN=128 LR_SCHEDULE=const LOSS_TYPE=mse +# N_NIAH_PROMPTS=0 RANK=256 +# +# Pre-flight: Gemma 4 26B-A4B-it verifier (gated, needs HF_TOKEN) + +# DFlash drafter from models/dflash-kakeya-baseline/ (Git LFS, in main +# post-PR-#93). Joint memory budget: ~52 GB verifier bf16 + 0.9 GB +# drafter bf16 + ~30 GB K/V cache for 128 sequences × 512 tokens + 130 MB +# f_θ. Fits H200 80 GB single GPU; H100 80 GB also works. +# +# Output: trained f_θ checkpoint at $SAVE_DIR (default +# results/research/f_theta_v4_hybrid/) containing f_theta_config.json + +# f_theta_weights.pt, plus a training report at $SAVE_DIR.json. 
+# +# Env knobs (current defaults): +# +# STEPS 20000 training steps (v3 = 5× v1) +# LR 1e-3 peak AdamW learning rate +# LR_SCHEDULE cosine const | cosine +# WARMUP_STEPS 500 +# LOSS_TYPE attn_distill_hybrid attn_distill_hybrid | attn_distill | mse | +# cos_mag | combined +# LAMBDA_K_DIR 1.0 --lambda-k-dir (passed for attn_distill_hybrid only) +# LAMBDA_V_DIR 1.0 --lambda-v-dir (passed for attn_distill_hybrid only) +# LAMBDA_K_MAG 0.1 --lambda-k-mag (passed for attn_distill_hybrid only) +# LAMBDA_V_MAG 0.1 --lambda-v-mag (passed for attn_distill_hybrid only) +# INIT_FROM (empty) optional warm-start checkpoint dir +# S5_EXACT_FULL_ATTN 0 1 = exclude full-attn layers from loss (S5) +# RANK (auto) empty = trainer auto-picks 768 for +# attn_distill / 256 for legacy losses +# N_PROMPTS 64 +# N_NIAH_PROMPTS 64 0 = pass --no-niah-prompts +# NIAH_MIN_LINES 30 +# NIAH_MAX_LINES 90 +# GEN_LEN 512 +# SAMPLE_POSITIONS 0 0 = full T (recommended for attn_distill) +# SAVE_DIR results/research/f_theta_v4_hybrid +# SEED 0 +# +# Usage (from vast.ai host with repo synced): +# +# HF_TOKEN=hf_xxx bash scripts/review_pr_k3_f_theta_train_on_vast.sh +# +# # Quick sanity (10 prompts, 200 steps, NIAH off, ~5 min): +# N_PROMPTS=10 N_NIAH_PROMPTS=0 STEPS=200 \ +# SAVE_DIR=results/research/f_theta_smoke \ +# HF_TOKEN=hf_xxx bash $0 +# +# # v1 reproduction (for direct comparability with PR #103 evidence): +# STEPS=4000 GEN_LEN=128 LR_SCHEDULE=const LOSS_TYPE=mse \ +# RANK=256 N_NIAH_PROMPTS=0 \ +# SAVE_DIR=results/research/f_theta_v1_repro \ +# HF_TOKEN=hf_xxx bash $0 +# +# Expected timing on H200: +# - Data collection: ~15-25 min (128 prompts × 512 gen_len each; +# NIAH prompts longer due to haystack; eager-attn forward is +# somewhat slower than sdpa) +# - Training 20k steps × ~80 ms/step (attention forward through +# all 30 layers per step) ≈ 25-30 min +# - Total wall: ~40-60 min +# +# Validation gates (printed at end): +# * loss_reduction_factor ≥ 5.0 +# * mseO/|O_tgt|^2 ratio < 0.05 → attention output preserved +# (v3 attn_distill diagnostic) +# * f_theta_weights.pt non-empty (~352 MB at rank=768) +# +# These are sanity gates, not product gates. Product gate is the +# integrated NIAH ladder evidence (separate reviewer aid: +# scripts/review_pr_k3_integrated_niah_on_vast.sh). + +set -euo pipefail + +ROOT="$(cd "$(dirname "$0")/.." 
&& pwd)" +cd "$ROOT" + +STEPS="${STEPS:-20000}" +LR="${LR:-1e-3}" +LR_SCHEDULE="${LR_SCHEDULE:-cosine}" +WARMUP_STEPS="${WARMUP_STEPS:-500}" +LOSS_TYPE="${LOSS_TYPE:-attn_distill_hybrid}" +INIT_FROM="${INIT_FROM:-}" # optional warm-start checkpoint dir +S5_EXACT_FULL_ATTN="${S5_EXACT_FULL_ATTN:-0}" # 1 = exclude full-attn layers from loss (S5) +LAMBDA_K_DIR="${LAMBDA_K_DIR:-1.0}" +LAMBDA_V_DIR="${LAMBDA_V_DIR:-1.0}" +LAMBDA_K_MAG="${LAMBDA_K_MAG:-0.1}" +LAMBDA_V_MAG="${LAMBDA_V_MAG:-0.1}" +RANK="${RANK:-}" # empty = trainer auto-picks (768 for attn_distill, else 256) +N_PROMPTS="${N_PROMPTS:-64}" +N_NIAH_PROMPTS="${N_NIAH_PROMPTS:-64}" +NIAH_MIN_LINES="${NIAH_MIN_LINES:-30}" +NIAH_MAX_LINES="${NIAH_MAX_LINES:-90}" +GEN_LEN="${GEN_LEN:-512}" +SAMPLE_POSITIONS="${SAMPLE_POSITIONS:-0}" # 0 = full T (attn_distill default) +SAVE_DIR="${SAVE_DIR:-results/research/f_theta_v4_hybrid}" +SEED="${SEED:-0}" + +stamp="$(date +%s)" +log_dir="results/research/logs" +mkdir -p "$log_dir" +log="${log_dir}/k3_f_theta_train_vast_${stamp}.log" + +attn_impl_msg="eager" +if [[ "$LOSS_TYPE" != "attn_distill" ]]; then attn_impl_msg="sdpa"; fi +rank_msg="$RANK" +if [[ -z "$RANK" ]]; then + if [[ "$LOSS_TYPE" == "attn_distill" ]]; then rank_msg="auto (768)"; else rank_msg="auto (256)"; fi +fi +echo "==> K3 Block C — f_θ K/V projection training (vast.ai CUDA, v3)" +echo " Verifier: google/gemma-4-26B-A4B-it (bf16, $attn_impl_msg)" +echo " Drafter: models/dflash-kakeya-baseline (in main, Git LFS)" +echo " Loss type: $LOSS_TYPE" +echo " Steps: $STEPS" +echo " Peak LR: $LR (schedule: $LR_SCHEDULE, warmup: $WARMUP_STEPS)" +echo " Rank: $rank_msg" +echo " N general prompts: $N_PROMPTS" +echo " N NIAH prompts: $N_NIAH_PROMPTS" +echo " NIAH lines: ${NIAH_MIN_LINES}-${NIAH_MAX_LINES}" +echo " Gen len: $GEN_LEN" +echo " Sample positions: $SAMPLE_POSITIONS (0 = full T)" +echo " Save dir: $SAVE_DIR" +echo " Log: $log" +echo + +# Pre-flight 1: HF token +if [[ -z "${HF_TOKEN:-}" ]] && ! 
huggingface-cli whoami > /dev/null 2>&1; then + echo "ERROR: no HF auth detected. Run:" + echo " huggingface-cli login # Gemma 4 is gated" + echo "or:" + echo " export HF_TOKEN=hf_xxx" + exit 1 +fi + +# Pre-flight 2: drafter checkpoint +if [[ ! -d "models/dflash-kakeya-baseline" ]]; then + echo "ERROR: models/dflash-kakeya-baseline/ missing." + echo " This is Git LFS-tracked; pull via:" + echo " git lfs install" + echo " git lfs pull" + exit 2 +fi +if [[ ! -f "models/dflash-kakeya-baseline/model.safetensors" ]]; then + echo "ERROR: models/dflash-kakeya-baseline/model.safetensors missing." + exit 2 +fi +size_bytes=$(stat -c%s "models/dflash-kakeya-baseline/model.safetensors" 2>/dev/null \ + || stat -f%z "models/dflash-kakeya-baseline/model.safetensors") +if [[ "$size_bytes" -lt 100000000 ]]; then + echo "ERROR: model.safetensors is only $size_bytes bytes (likely LFS pointer)." + echo " Run 'git lfs pull' to fetch the real 859 MB file." + exit 2 +fi + +# Pre-flight 3: torch + CUDA + transformers 5.x +if ! python3 -c " +import torch, sys +if not torch.cuda.is_available(): + print('ERROR: CUDA not available', file=sys.stderr); sys.exit(2) +print(f'torch {torch.__version__} cuda={torch.version.cuda}', file=sys.stderr) +"; then + exit 3 +fi +if ! 
python3 -c " +import transformers, sys +v = transformers.__version__.split('.') +if int(v[0]) < 5: + print(f'WARN: transformers {transformers.__version__} (need 5.x for Gemma 4)', + file=sys.stderr) +print(f'transformers {transformers.__version__}', file=sys.stderr) +"; then + exit 4 +fi + +# Run +echo "==> Running f_θ training" +extra_flags=() +if [[ "$N_NIAH_PROMPTS" -eq 0 ]]; then + extra_flags+=(--no-niah-prompts) +else + extra_flags+=(--niah-min-lines "$NIAH_MIN_LINES") + extra_flags+=(--niah-max-lines "$NIAH_MAX_LINES") +fi +if [[ -n "$RANK" ]]; then + extra_flags+=(--rank "$RANK") +fi +if [[ -n "$INIT_FROM" ]]; then + extra_flags+=(--init-from "$INIT_FROM") +fi +if [[ "$S5_EXACT_FULL_ATTN" == "1" ]]; then + extra_flags+=(--s5-exact-full-attn) +fi +if [[ "$LOSS_TYPE" == "attn_distill_hybrid" ]]; then + extra_flags+=(--lambda-k-dir "$LAMBDA_K_DIR") + extra_flags+=(--lambda-v-dir "$LAMBDA_V_DIR") + extra_flags+=(--lambda-k-mag "$LAMBDA_K_MAG") + extra_flags+=(--lambda-v-mag "$LAMBDA_V_MAG") +fi +# With 'set -e' + 'pipefail' a failing trainer makes the whole pipeline fail +# and the shell would exit right here, before the status is captured and +# before the FAILED summary below can print. Suspend errexit around the run. +set +e +PYTHONPATH=.:sdks/python python3 scripts/research/k3_f_theta_train.py \ + --steps "$STEPS" \ + --lr "$LR" \ + --lr-schedule "$LR_SCHEDULE" \ + --warmup-steps "$WARMUP_STEPS" \ + --loss-type "$LOSS_TYPE" \ + --n-prompts "$N_PROMPTS" \ + --n-niah-prompts "$N_NIAH_PROMPTS" \ + --gen-len "$GEN_LEN" \ + --sample-positions "$SAMPLE_POSITIONS" \ + --save "$SAVE_DIR" \ + --seed "$SEED" "${extra_flags[@]}" 2>&1 | tee "$log" +exit_code=${PIPESTATUS[0]} +set -e + +echo +if [[ "$exit_code" -eq 0 ]]; then + echo "==> f_θ training PASS" + echo " Checkpoint: $SAVE_DIR/{f_theta_config.json, f_theta_weights.pt}" + echo " Report: ${SAVE_DIR}.json" + echo " Log: $log" + echo + echo "Inspect training report:" + echo " python3 -c 'import json; r = json.load(open(\"${SAVE_DIR}.json\"));" + echo " print(\"initial_loss:\", r[\"initial_loss\"]);" + echo " print(\"final_loss:\", r[\"final_loss\"]);" + echo " print(\"reduction_factor:\", r[\"loss_reduction_factor\"]);" + echo " print(\"train_seconds:\", 
r[\"train_seconds\"])'" + echo + echo "Commit checkpoint + report:" + # Print 'git lfs track' BEFORE 'git add' of the checkpoint: a weights file + # staged before its LFS pattern exists is stored as a regular git blob. + echo " git lfs track \"$SAVE_DIR/f_theta_weights.pt\"" + echo " git add .gitattributes" + echo " git add $SAVE_DIR/ ${SAVE_DIR}.json" + echo " git commit -m 'K3 f_θ trained checkpoint v1'" + echo " git push" +else + echo "==> f_θ training FAILED (exit=$exit_code)" + echo " Log: $log" +fi + +exit "$exit_code" diff --git a/scripts/review_pr_k3_integrated_niah_on_vast.sh b/scripts/review_pr_k3_integrated_niah_on_vast.sh new file mode 100755 index 00000000..dd47111a --- /dev/null +++ b/scripts/review_pr_k3_integrated_niah_on_vast.sh @@ -0,0 +1,184 @@ +#!/usr/bin/env bash +# vast.ai (CUDA) reviewer aid for K3 integrated NIAH eval — +# the complete Kakeya inference engine product evidence on CUDA. +# +# Combines CrossModelDLMRestoredVerifier (verifier with sink+window +# cache + drafter K/V Restoration via f_θ) with the K1.E NIAH +# evaluation harness. This is the **K3 product gate**. +# +# Pre-flight requires: +# 1. HF_TOKEN (Gemma 4 is gated) +# 2. models/dflash-kakeya-baseline/ Git LFS pulled +# 3. f_θ checkpoint at $F_THETA_DIR (default +# results/research/f_theta_v1/) — produced by +# scripts/review_pr_k3_f_theta_train_on_vast.sh +# 4. CUDA + transformers 5.x (Gemma 4 support) +# +# Validates (per ADR 0008 §11.8 release gates): +# +# 1. Architectural correctness: +# effective_attention_fraction = 1.0 at every NIAH ladder rung. +# Verifier "sees" full context despite sink+window-only cache. +# +# 2. Memory bounded: +# Sustained cross-model verifier KV-cache memory ≤ O(sink+window). +# +# 3. Recall preservation: +# |recall_cross_model - recall_oracle| ≤ 5 pp at every rung +# (ADR §11.8 criterion 1a). This is the architecturally-meaningful +# gate (independent of base-model long-context capability). 
+# +# Env knobs (defaults): +# +# F_THETA_DIR results/research/f_theta_v1 +# N_SAMPLES 10 per ladder rung +# SINK_SIZE 4 +# WINDOW_SIZE 64 +# MAX_NEW_TOKENS 24 +# SEED 42 +# CONTEXT_LADDER '70 280' padding-line counts; '70'≈1.4k, '280'≈5.6k tokens +# SKIP_ORACLE 0 1 = skip the full-attention oracle baseline +# (saves ~50% time but loses recall_delta gate) +# +# Usage: +# +# HF_TOKEN=hf_xxx bash scripts/review_pr_k3_integrated_niah_on_vast.sh +# +# Quick sanity (1.4k context, 4 samples): +# +# N_SAMPLES=4 CONTEXT_LADDER='70' \ +# HF_TOKEN=hf_xxx bash $0 +# +# Output JSONs at: +# results/research/k3_integrated_niah_ctx${n_lines}_${stamp}.json (per rung) +# results/research/logs/k3_integrated_niah_vast_${stamp}.log (combined log) +# +# This is the production-evidence reviewer aid. After it passes: +# * ADR §11.8 K3 product gate is empirically closed +# * K3 production-scale Kakeya inference is validated on CUDA +# * Mac MLX path follows (separate PR — instrument mlx_lm directly) + +set -euo pipefail + +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +cd "$ROOT" + +F_THETA_DIR="${F_THETA_DIR:-results/research/f_theta_v1}" +N_SAMPLES="${N_SAMPLES:-10}" +SINK_SIZE="${SINK_SIZE:-4}" +WINDOW_SIZE="${WINDOW_SIZE:-64}" +MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-24}" +SEED="${SEED:-42}" +CONTEXT_LADDER="${CONTEXT_LADDER:-70 280}" +SKIP_ORACLE="${SKIP_ORACLE:-0}" + +stamp="$(date +%s)" +out_dir="results/research" +log_dir="${out_dir}/logs" +mkdir -p "$out_dir" "$log_dir" +log="${log_dir}/k3_integrated_niah_vast_${stamp}.log" + +echo "==> K3 integrated NIAH eval (vast.ai CUDA)" +echo " Verifier: google/gemma-4-26B-A4B-it" +echo " Drafter: models/dflash-kakeya-baseline" +echo " f_θ checkpoint: $F_THETA_DIR" +echo " N samples / rung: $N_SAMPLES" +echo " Sink × window: ${SINK_SIZE} × ${WINDOW_SIZE}" +echo " Context ladder: $CONTEXT_LADDER" +echo " Skip oracle: $SKIP_ORACLE" +echo " Log: $log" +echo + +# Pre-flight 1: HF token +if [[ -z "${HF_TOKEN:-}" ]] && ! 
huggingface-cli whoami > /dev/null 2>&1; then + echo "ERROR: no HF auth detected. Run 'huggingface-cli login' or 'export HF_TOKEN=...'." + exit 1 +fi + +# Pre-flight 2: f_θ checkpoint +if [[ ! -d "$F_THETA_DIR" ]]; then + echo "ERROR: f_θ directory '$F_THETA_DIR' missing." + echo " Train it first via:" + echo " HF_TOKEN=hf_xxx bash scripts/review_pr_k3_f_theta_train_on_vast.sh" + exit 2 +fi +if [[ ! -f "$F_THETA_DIR/f_theta_config.json" ]] || [[ ! -f "$F_THETA_DIR/f_theta_weights.pt" ]]; then + echo "ERROR: '$F_THETA_DIR' missing f_theta_config.json or f_theta_weights.pt." + ls -la "$F_THETA_DIR" 2>&1 | head -10 + exit 2 +fi + +# Pre-flight 3: drafter checkpoint +if [[ ! -f "models/dflash-kakeya-baseline/model.safetensors" ]]; then + echo "ERROR: models/dflash-kakeya-baseline/ missing or LFS not pulled." + echo " Run: git lfs install && git lfs pull" + exit 3 +fi + +# Pre-flight 4: CUDA +if ! python3 -c "import torch; assert torch.cuda.is_available(), 'no CUDA'" 2>&1; then + echo "ERROR: CUDA not available." 
+ exit 4 +fi + +flags=( + --f-theta-dir "$F_THETA_DIR" + --n-samples "$N_SAMPLES" + --sink-size "$SINK_SIZE" + --window-size "$WINDOW_SIZE" + --max-new-tokens "$MAX_NEW_TOKENS" + --seed "$SEED" +) +[[ "$SKIP_ORACLE" == "1" ]] && flags+=(--skip-oracle) + +# Run per-rung +exit_code=0 +for n_lines in $CONTEXT_LADDER; do + lo=$(( (n_lines * 85 + 50) / 100 )) + hi=$(( (n_lines * 115 + 50) / 100 )) + [[ $lo -lt 10 ]] && lo=10 + [[ $hi -lt $((lo + 1)) ]] && hi=$((lo + 1)) + rung_report="${out_dir}/k3_integrated_niah_ctx${n_lines}_${stamp}.json" + + echo "==> ctx${n_lines}: lines [$lo, $hi] → $rung_report" + # With 'set -e' + 'pipefail' a failing rung would terminate the script at + # the pipeline, before rc is read, so later rungs and the summary would + # never run. Suspend errexit just for the eval pipeline. + set +e + PYTHONPATH=.:sdks/python python3 scripts/research/k3_integrated_niah_eval.py \ + --haystack-min-lines "$lo" \ + --haystack-max-lines "$hi" \ + --output "$rung_report" \ + "${flags[@]}" 2>&1 | tee -a "$log" + rc=${PIPESTATUS[0]} + set -e + if [[ "$rc" -ne 0 ]]; then + echo "==> ctx${n_lines} FAILED (exit=$rc); continuing to next rung" + exit_code="$rc" + fi +done + +echo +if [[ "$exit_code" -eq 0 ]]; then + echo "==> K3 integrated NIAH eval PASS (all rungs)" + echo " Reports:" + for n_lines in $CONTEXT_LADDER; do + echo " ${out_dir}/k3_integrated_niah_ctx${n_lines}_${stamp}.json" + done + echo + echo "Inspect aggregates per rung:" + echo " for f in ${out_dir}/k3_integrated_niah_ctx*_${stamp}.json; do" + echo " python3 -c 'import json,sys; r=json.load(open(sys.argv[1]));" + echo " print(\"file:\", sys.argv[1])" + echo " print(\" cross-model recall:\", r[\"results\"][\"k3_cross_model\"][\"recall\"])" + echo " print(\" oracle recall: \", r[\"results\"].get(\"oracle\",{}).get(\"recall\"))" + echo " print(\" effective_attn: \", r[\"attention_window\"][\"per_config\"][\"k3_cross_model\"][\"effective_attention_fraction_mean\"])" + echo " print(\" recall_delta_pp: \", r[\"gate\"][\"recall_delta_vs_oracle_pp\"])" + echo " print(\" gate_5pp: \", r[\"gate\"][\"recall_delta_within_5pp\"])' \"\$f\"" + echo " done" + echo + echo "Commit evidence:" + echo " git add 
${out_dir}/k3_integrated_niah_ctx*_${stamp}.json $log" + echo " git commit -m 'K3 integrated NIAH evidence (cross-model + f_θ + sink+window)'" + echo " git push" +else + echo "==> Some rungs FAILED (last exit=$exit_code)" + echo " See $log for details" +fi + +exit "$exit_code" diff --git a/scripts/start_grpc_runtime_server.py b/scripts/start_grpc_runtime_server.py index e1b36348..7c633868 100755 --- a/scripts/start_grpc_runtime_server.py +++ b/scripts/start_grpc_runtime_server.py @@ -59,6 +59,8 @@ def _resolve_kv_dims(verifier) -> Tuple[int, int, int]: over gRPC match what the verifier is actually holding. """ cfg = verifier.model.config + # Gemma 4 is multimodal: decoder dims live under config.text_config. + cfg = getattr(cfg, "text_config", None) or cfg num_layers = int(getattr(cfg, "num_hidden_layers")) # Qwen3 / Gemma / DeepSeek all support GQA — kv-heads is the # dimension that matters for KV cache size, not attention-heads. @@ -79,6 +81,10 @@ def _build_verifier( verifier_id: str, sink: int, window: int, + drafter_id: str = "", + f_theta_dir: str = "", + s5_exact_full_attn: bool = True, + device: str = "cpu", ): cfg = VerifierConfig( model_id=verifier_id, @@ -99,6 +105,23 @@ def _build_verifier( sys.exit(2) from inference_engine.backends.mlx.verifier import MLXSinkWindowVerifier return MLXSinkWindowVerifier(cfg) + if backend == "restored": + # f_θ + S5 K/V-Restoration verifier (the Kakeya inference path). + # Requires the DFlash drafter + trained f_θ checkpoint. 
+ if not drafter_id or not f_theta_dir: + raise SystemExit( + "backend=restored requires --drafter-id and --f-theta-dir" + ) + from inference_engine.v04.build_restored import load_restored_verifier + return load_restored_verifier( + verifier_id=verifier_id, + drafter_id=drafter_id, + f_theta_dir=f_theta_dir, + sink_size=sink, + window_size=window, + s5_exact_full_attn=s5_exact_full_attn, + device=device, + ) raise SystemExit(f"unknown backend: {backend}") @@ -128,6 +151,9 @@ async def _serve(args: argparse.Namespace) -> int: verifier = _build_verifier( backend=args.backend, verifier_id=args.verifier_id, sink=args.sink, window=args.window, + drafter_id=args.drafter_id, f_theta_dir=args.f_theta_dir, + s5_exact_full_attn=not args.no_s5_exact_full_attn, + device=args.device, ) num_layers, num_kv_heads, head_dim = _resolve_kv_dims(verifier) @@ -189,8 +215,19 @@ def _on_signal(sig: int) -> None: def main() -> int: ap = argparse.ArgumentParser(description=__doc__) - ap.add_argument("--backend", choices=["cpu", "mlx"], default="cpu") + ap.add_argument("--backend", choices=["cpu", "mlx", "restored"], default="cpu") ap.add_argument("--verifier-id", default="Qwen/Qwen3-0.6B") + ap.add_argument("--device", default="cpu", + help="Torch device for the restored backend " + "(e.g. 'cuda' on a GPU host). Ignored by cpu/mlx.") + ap.add_argument("--drafter-id", default="", + help="DFlash drafter id/path (backend=restored).") + ap.add_argument("--f-theta-dir", default="", + help="Trained f_θ checkpoint dir (backend=restored).") + ap.add_argument("--no-s5-exact-full-attn", action="store_true", + help="Disable S5 (keep f_θ for full-attention layers too). " + "By default backend=restored uses S5 exact full-attn " + "layers for recall.") ap.add_argument("--bind", default=DEFAULT_BIND_ADDRESS, help=f"host:port to bind. 
Default: {DEFAULT_BIND_ADDRESS}") ap.add_argument("--capacity", type=int, default=4, diff --git a/tests/backends/mlx/test_cross_model_dlm_verifier.py b/tests/backends/mlx/test_cross_model_dlm_verifier.py new file mode 100644 index 00000000..add1e833 --- /dev/null +++ b/tests/backends/mlx/test_cross_model_dlm_verifier.py @@ -0,0 +1,162 @@ +"""Linux-CI tests for the MLX cross-model DLM-restored verifier helpers. + +Only the non-MLX (model-structure) helpers are exercised here — ``mlx`` is +imported lazily inside the MLX-touching functions, so this module imports and +these helpers run on Linux without Apple Silicon. The MLX forward/injection +path is validated on a Mac by +``scripts/research/k3_integrated_niah_eval_mac.py``. +""" + +from __future__ import annotations + +import pytest + +from inference_engine.backends.mlx import cross_model_dlm_verifier as cmv + + +class _Attn: + def __init__(self, head_dim, layer_type, has_kv, layer_idx): + self.head_dim = head_dim + self.layer_type = layer_type + self.has_kv = has_kv + self.layer_idx = layer_idx + + +class _Layer: + def __init__(self, attn): + self.self_attn = attn + + +class _TextModel: + def __init__(self, layers, previous_kvs=None): + self.layers = layers + self.embed_tokens = object() + if previous_kvs is not None: + self.previous_kvs = previous_kvs + + +def _gemma4_like(num_kv_shared=0): + """30 layers: full-attention (head_dim 512) at 5,11,17,23,29, else sliding + (256). 
KV sharing for the last `num_kv_shared` layers (same-type source).""" + n = 30 + full = {5, 11, 17, 23, 29} + layers = [] + for i in range(n): + hd = 512 if i in full else 256 + lt = "full_attention" if i in full else "sliding_attention" + has_kv = i < n - num_kv_shared + layers.append(_Layer(_Attn(hd, lt, has_kv, i))) + prev = list(range(n)) + if num_kv_shared > 0: + m = n - num_kv_shared + by_type = {} + for i in range(m): + by_type[layers[i].self_attn.layer_type] = i + for j in range(m, n): + prev[j] = by_type[layers[j].self_attn.layer_type] + return _TextModel(layers, prev) + + +class _Wrapper: + """Mimics mlx_lm wrapper: .model is the text model.""" + def __init__(self, tm): + self.model = tm + + +def test_resolve_text_model_via_model_attr(): + tm = _gemma4_like() + assert cmv.resolve_mlx_text_model(_Wrapper(tm)) is tm + + +def test_resolve_text_model_direct(): + tm = _gemma4_like() + # text-only wrapper: object whose .model is the text model + assert cmv.resolve_mlx_text_model(_Wrapper(tm)) is tm + + +def test_full_attention_layer_indices_gemma4(): + tm = _gemma4_like() + assert cmv.mlx_full_attention_layer_indices(tm) == [5, 11, 17, 23, 29] + + +def test_full_attention_layer_indices_uniform_returns_empty(): + layers = [_Layer(_Attn(256, "sliding_attention", True, i)) for i in range(4)] + tm = _TextModel(layers, list(range(4))) + assert cmv.mlx_full_attention_layer_indices(tm) == [] + + +def test_kv_source_map_no_sharing_is_identity(): + tm = _gemma4_like(num_kv_shared=0) + assert cmv.kv_source_layer_map(tm) == list(range(30)) + + +def test_kv_source_map_with_sharing_points_to_source(): + tm = _gemma4_like(num_kv_shared=10) + src = cmv.kv_source_layer_map(tm) + # The last 10 layers (20..29) are sharers; each maps to an earlier + # same-type source layer (< 20), never to itself. + for j in range(20, 30): + assert src[j] < 20 + assert src[j] != j + # has_kv layers map to themselves. 
+ for i in range(20): + assert src[i] == i + + +def test_resolve_raises_without_text_model(): + class Bad: + pass + with pytest.raises(AttributeError): + cmv.resolve_mlx_text_model(Bad()) + + +def _gemma4_geom_model(): + """Mock with n_kv_heads/head_dim per layer (8/256 sliding, 2/512 full).""" + full = {5, 11, 17, 23, 29} + layers = [] + for i in range(30): + if i in full: + layers.append(_Layer(_AttnGeom(2, 512, "full_attention", i))) + else: + layers.append(_Layer(_AttnGeom(8, 256, "sliding_attention", i))) + return _TextModel(layers, list(range(30))) + + +class _AttnGeom(_Attn): + def __init__(self, n_kv, head_dim, layer_type, layer_idx): + super().__init__(head_dim, layer_type, True, layer_idx) + self.n_kv_heads = n_kv + + +def test_kv_memory_report_s5_vs_naive(): + tm = _gemma4_geom_model() + full = [5, 11, 17, 23, 29] + s5 = cmv.kv_memory_report( + tm, sink_size=4, window_size=64, seq_len=5500, exact_layer_indices=full) + naive = cmv.kv_memory_report( + tm, sink_size=5500, window_size=0, seq_len=5500, + exact_layer_indices=list(range(30))) + # S5 dramatically smaller than naive full-KV; growth = 5 full layers only. + assert s5["total_resident_bytes"] < naive["total_resident_bytes"] / 5 + # per-token growth = 5 full layers * (2 * 2 kv * 512 * 2 bytes) = 20480 B + assert s5["per_token_growth_bytes"] == 5 * (2 * 2 * 512 * 2) + + +def test_kv_memory_report_compression_shrinks_slope(): + tm = _gemma4_geom_model() + full = [5, 11, 17, 23, 29] + exact = cmv.kv_memory_report( + tm, sink_size=4, window_size=64, seq_len=5500, exact_layer_indices=full) + comp = cmv.kv_memory_report( + tm, sink_size=4, window_size=64, seq_len=5500, exact_layer_indices=full, + compress_full_bits_per_token_per_head=3232.0) + # KakeyaLattice (~2.5x) shrinks the full-layer term + the linear slope. 
+ assert comp["total_resident_bytes"] < exact["total_resident_bytes"] + assert comp["per_token_growth_bytes"] < exact["per_token_growth_bytes"] + + +def test_per_layer_kv_geometry(): + tm = _gemma4_geom_model() + geom = cmv.per_layer_kv_geometry(tm) + assert geom[0] == (8, 256, "sliding_attention") + assert geom[5] == (2, 512, "full_attention") diff --git a/tests/inference_engine/v04/test_cross_model_dlm_verifier.py b/tests/inference_engine/v04/test_cross_model_dlm_verifier.py new file mode 100644 index 00000000..a8cdcd80 --- /dev/null +++ b/tests/inference_engine/v04/test_cross_model_dlm_verifier.py @@ -0,0 +1,381 @@ +"""Linux CI tests for inference_engine.v04.cross_model_dlm_verifier. + +Covers the testable surface: + +* CrossModelDLMRestoredVerifier construction + dimension validation +* project_drafter_kv shape contract +* forward end-to-end on a synthetic verifier + drafter +* No-evict short-prompt path (T <= sink+window) +* Patched attention forward correctness on a tiny synthetic verifier + +Real Gemma 4 26B-A4B + DFlash 0.4B integration is validated by the +training run + integration evidence (separate vast.ai runs). 
+""" + +from __future__ import annotations + +import pytest +import torch +import torch.nn as nn + +from inference_engine.v04 import ( + CrossModelDLMRestoredVerifier, + DFlashConfig, + DFlashDrafter, + FThetaConfig, + FThetaProjection, +) + + +def _tiny_drafter_config() -> DFlashConfig: + return DFlashConfig( + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=4, + intermediate_size=32, + vocab_size=64, + rms_norm_eps=1e-6, + rope_theta=10000.0, + max_position_embeddings=64, + block_size=4, + mask_token_id=3, + target_layer_ids=(1, 3), + final_logit_softcapping=30.0, + ) + + +def _tiny_f_theta_config() -> FThetaConfig: + """Aligned with _tiny_drafter_config + a 3-layer verifier.""" + return FThetaConfig( + drafter_num_layers=2, + drafter_num_kv_heads=2, + drafter_head_dim=4, + verifier_num_layers=3, + verifier_num_kv_heads=4, + verifier_head_dim=8, + rank=16, + ) + + +class _SyntheticVerifierConfig: + num_hidden_layers = 3 + num_key_value_heads = 4 + head_dim = 8 + hidden_size = 32 + num_attention_heads = 4 + _attn_implementation = "eager" + + +class _SyntheticVerifierAttention(nn.Module): + def __init__(self) -> None: + super().__init__() + self.q_proj = nn.Linear(32, 32, bias=False) + self.k_proj = nn.Linear(32, 4 * 8, bias=False) + self.v_proj = nn.Linear(32, 4 * 8, bias=False) + self.o_proj = nn.Linear(32, 32, bias=False) + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + self.v_norm = nn.Identity() # Gemma 4 runs V through v_norm + self.head_dim = 8 + self.num_key_value_groups = 1 + self.scaling = 8 ** -0.5 + self.attention_dropout = 0.0 + self.sliding_window = None + self.config = _SyntheticVerifierConfig() + + def forward(self, hidden_states, position_embeddings, attention_mask=None, **kw): + # Exact (unpatched) path: verifier's own K/V, Gemma 4-style. 
+ input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + cos, sin = position_embeddings + q = _gemma4_style_rope( + self.q_norm(self.q_proj(hidden_states).view(hidden_shape)), + cos, sin, 2).transpose(1, 2) + k = _gemma4_style_rope( + self.k_norm(self.k_proj(hidden_states).view(hidden_shape)), + cos, sin, 2).transpose(1, 2) + v = self.v_norm(self.v_proj(hidden_states).view(hidden_shape)).transpose(1, 2) + o, _ = _synthetic_eager(self, q, k, v, attention_mask, scaling=self.scaling) + return self.o_proj(o.reshape(*input_shape, -1)), None + + +class _SyntheticVerifierLayer(nn.Module): + def __init__(self) -> None: + super().__init__() + self.self_attn = _SyntheticVerifierAttention() + + +class _SyntheticVerifierInner(nn.Module): + def __init__(self) -> None: + super().__init__() + self.layers = nn.ModuleList([ + _SyntheticVerifierLayer() for _ in range(3) + ]) + + +class _SyntheticVerifier(nn.Module): + """A minimal HF-shaped verifier for the cross-model path. + + Structure mirrors `model.model.layers[i].self_attn.{q,k,v,o}_proj` + that CrossModelDLMRestoredVerifier patches. + """ + + def __init__(self) -> None: + super().__init__() + self.config = _SyntheticVerifierConfig() + self.model = _SyntheticVerifierInner() + + def forward(self, input_ids=None, **kwargs): + # Trivial forward: just iterate layers + return logits. + # The CrossModelDLMRestoredVerifier path patches each layer's + # attn.forward; this top-level forward iterates and calls the + # patched forward on each layer with synthetic hidden state. 
+ B, T = input_ids.shape + h = torch.randn(B, T, 32) + cos = torch.ones(B, T, 8) * 0.5 + sin = torch.ones(B, T, 8) * 0.5 + mask = torch.zeros(B, 1, T, T) + for layer in self.model.layers: + attn_out, _ = layer.self_attn.forward( + hidden_states=h, + position_embeddings=(cos, sin), + attention_mask=mask, + ) + h = attn_out # simplified + # Return a namespace with logits attribute for compatibility + class _Out: + logits = torch.zeros(B, T, 64) + return _Out() + + +class TestConstruction: + + def test_dimension_validation_rejects_mismatch(self): + f_cfg = _tiny_f_theta_config() + f_theta = FThetaProjection(f_cfg) + drafter = DFlashDrafter(_tiny_drafter_config()) + + # Verifier with 5 layers but f_θ trained for 3 → should reject + class WrongConfig: + num_hidden_layers = 5 + num_key_value_heads = 4 + head_dim = 8 + hidden_size = 32 + num_attention_heads = 4 + _attn_implementation = "eager" + + class WrongVerifier(nn.Module): + def __init__(self): + super().__init__() + self.config = WrongConfig() + self.model = nn.Module() + self.model.layers = nn.ModuleList() + + with pytest.raises(ValueError, match="verifier_num_layers"): + CrossModelDLMRestoredVerifier( + verifier_model=WrongVerifier(), + drafter=drafter, + f_theta=f_theta, + ) + + def test_construction_with_aligned_dimensions(self): + f_cfg = _tiny_f_theta_config() + f_theta = FThetaProjection(f_cfg) + drafter = DFlashDrafter(_tiny_drafter_config()) + verifier = _SyntheticVerifier() + v = CrossModelDLMRestoredVerifier( + verifier_model=verifier, + drafter=drafter, + f_theta=f_theta, + sink_size=2, + window_size=4, + ) + assert v.sink_size == 2 + assert v.window_size == 4 + + def test_negative_sink_or_window_raises(self): + f_cfg = _tiny_f_theta_config() + f_theta = FThetaProjection(f_cfg) + drafter = DFlashDrafter(_tiny_drafter_config()) + verifier = _SyntheticVerifier() + with pytest.raises(ValueError, match="non-negative"): + CrossModelDLMRestoredVerifier( + verifier_model=verifier, + drafter=drafter, + 
f_theta=f_theta, + sink_size=-1, + ) + + +class TestProjectDrafterKV: + """project_drafter_kv runs the drafter forward + f_θ projection + and returns verifier-K, verifier-V tensors of the right shape. + + Synthetic verifier needs a real ``get_input_embeddings()`` since + _capture_drafter_kv now uses verifier embed_tokens (corrected + 2026-06-09 to use real embedded hiddens, not synthetic zero). + """ + + def test_returns_correct_shape(self): + f_cfg = _tiny_f_theta_config() + f_theta = FThetaProjection(f_cfg) + drafter = DFlashDrafter(_tiny_drafter_config()) + verifier = _SyntheticVerifier() + # Synthetic verifier needs a real embed_tokens for the + # _capture_drafter_kv path (verifier_model.get_input_embeddings() + # is called). + verifier.embed_tokens = torch.nn.Embedding(64, 16) # vocab 64, hidden 16 + verifier.get_input_embeddings = lambda: verifier.embed_tokens + v = CrossModelDLMRestoredVerifier( + verifier_model=verifier, drafter=drafter, f_theta=f_theta, + ) + B, T = 1, 6 + ids = torch.randint(0, 64, (B, T), dtype=torch.long) + v_k, v_v = v.project_drafter_kv(ids) + # Per-layer list contract (layers may have heterogeneous KV heads). 
+ assert len(v_k) == f_cfg.verifier_num_layers + assert len(v_v) == f_cfg.verifier_num_layers + per_layer = ( + B, T, f_cfg.verifier_num_kv_heads, f_cfg.verifier_head_dim, + ) + for ko, vo in zip(v_k, v_v): + assert tuple(ko.shape) == per_layer + assert tuple(vo.shape) == per_layer + + +class TestNoEvictPath: + """When T <= sink+window, no positions are evicted and the + cross-model verifier path short-circuits to the underlying + verifier's plain forward.""" + + def test_short_prompt_skips_drafter_forward(self, monkeypatch): + f_cfg = _tiny_f_theta_config() + f_theta = FThetaProjection(f_cfg) + drafter = DFlashDrafter(_tiny_drafter_config()) + verifier = _SyntheticVerifier() + v = CrossModelDLMRestoredVerifier( + verifier_model=verifier, drafter=drafter, f_theta=f_theta, + sink_size=2, window_size=4, # sink+window = 6 + ) + + # Counter to verify drafter not invoked for short prompts + calls = {"drafter": 0} + original_project = v.project_drafter_kv + def _counted(ids): + calls["drafter"] += 1 + return original_project(ids) + v.project_drafter_kv = _counted + + ids = torch.randint(0, 64, (1, 5), dtype=torch.long) # T=5, all resident + + # Verifier's forward in the synthetic stub doesn't have + # apply_rotary_pos_emb wired so we just check the no-evict + # decision: when evicted_positions is empty, project_drafter_kv + # should not run. 
+ try: + v.forward( + ids, + apply_rotary_pos_emb=lambda q, k, c, s: (q, k), + eager_attention_forward=lambda *a, **kw: ( + torch.zeros(1, 4, 5, 8), None, + ), + ) + except Exception: + # We don't care about forward correctness here, only that + # project_drafter_kv was NOT called + pass + assert calls["drafter"] == 0 + + +def _gemma4_style_rope(x, cos, sin, unsqueeze_dim=2): + """Mirror Gemma 4's apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim).""" + c = cos.unsqueeze(unsqueeze_dim) + s = sin.unsqueeze(unsqueeze_dim) + half = x.shape[-1] // 2 + rot = torch.cat([-x[..., half:], x[..., :half]], dim=-1) + return x * c + rot * s + + +def _synthetic_eager(module, q, k, v, mask, dropout=0.0, scaling=1.0, + sliding_window=None, **kw): + """Minimal eager attention returning [B, T, H, D] like HF's.""" + attn = torch.matmul(q, k.transpose(-1, -2)) * scaling + if mask is not None: + attn = attn + mask[..., : k.shape[-2]] + attn = attn.softmax(dim=-1) + out = torch.matmul(attn, v) # [B, H, T, D] + return out.transpose(1, 2).contiguous(), None # [B, T, H, D] + + +class TestPatchedForwardRestore: + """Exercise the Gemma 4-style patched attention forward end-to-end on + the synthetic verifier with real evictions, so signature/shape/RoPE + regressions are caught without the 26B model.""" + + def test_forward_with_eviction_runs_and_injects(self): + f_cfg = _tiny_f_theta_config() + f_theta = FThetaProjection(f_cfg) + drafter = DFlashDrafter(_tiny_drafter_config()) + verifier = _SyntheticVerifier() + verifier.embed_tokens = torch.nn.Embedding(64, 16) + verifier.get_input_embeddings = lambda: verifier.embed_tokens + v = CrossModelDLMRestoredVerifier( + verifier_model=verifier, drafter=drafter, f_theta=f_theta, + sink_size=1, window_size=2, # sink+window = 3 + ) + B, T = 1, 6 # evicted positions [1, 2, 3] + ids = torch.randint(0, 64, (B, T), dtype=torch.long) + out = v.forward( + ids, + apply_rotary_pos_emb=_gemma4_style_rope, + eager_attention_forward=_synthetic_eager, + ) + assert 
out.logits.shape == (B, T, 64) + + +class TestS5ExactLayers: + """S5: exact_layer_indices layers are left unpatched (use the verifier's + own K/V) and originals are restored after the forward.""" + + def test_full_attention_layer_indices_uniform_returns_empty(self): + from inference_engine.v04.cross_model_dlm_verifier import ( + full_attention_layer_indices, + ) + verifier = _SyntheticVerifier() # all layers head_dim 8 (uniform) + assert full_attention_layer_indices(verifier) == [] + + def test_exact_layer_skipped_and_restored(self): + f_theta = FThetaProjection(_tiny_f_theta_config()) + drafter = DFlashDrafter(_tiny_drafter_config()) + verifier = _SyntheticVerifier() + verifier.embed_tokens = torch.nn.Embedding(64, 16) + verifier.get_input_embeddings = lambda: verifier.embed_tokens + v = CrossModelDLMRestoredVerifier( + verifier_model=verifier, drafter=drafter, f_theta=f_theta, + sink_size=1, window_size=2, exact_layer_indices=[1], + ) + layers = verifier.model.layers + orig_fwds = [l.self_attn.forward for l in layers] + ids = torch.randint(0, 64, (1, 6), dtype=torch.long) + out = v.forward( + ids, + apply_rotary_pos_emb=_gemma4_style_rope, + eager_attention_forward=_synthetic_eager, + ) + assert out.logits.shape == (1, 6, 64) + # All attn forwards restored to their originals after the call + for l, of in zip(layers, orig_fwds): + assert l.self_attn.forward == of + + +class TestExports: + + def test_module_exposes_classes(self): + from inference_engine.v04 import cross_model_dlm_verifier as m + assert hasattr(m, "CrossModelDLMRestoredVerifier") + assert hasattr(m, "CrossModelLayerMapping") + # And the inference_engine.v04 namespace re-exports them + from inference_engine import v04 + assert v04.CrossModelDLMRestoredVerifier is m.CrossModelDLMRestoredVerifier diff --git a/tests/inference_engine/v04/test_dflash_drafter.py b/tests/inference_engine/v04/test_dflash_drafter.py index 76f4dafe..d16415f2 100644 --- a/tests/inference_engine/v04/test_dflash_drafter.py +++ 
b/tests/inference_engine/v04/test_dflash_drafter.py @@ -97,6 +97,63 @@ def aux_hidden_context(self, committed_token_ids): # --------------------------------------------------------------------------- +class TestDraftBlockCached: + """Fused-engine fast path: draft_block_cached (precomputed context K/V) + must equal draft_block (recomputes context K/V each call).""" + + def test_cached_matches_draft_block(self): + cfg = _tiny_cfg() + torch.manual_seed(0) + drafter = DFlashDrafter(cfg).to(torch.float32).eval() + embed_fn, lm_head_fn = _synthetic_verifier_heads(cfg) + provider = _SyntheticAuxProvider(cfg) + committed = [1, 2, 3, 4, 5] + aux, bonus = provider.aux_hidden_context(committed) + C = len(committed) + L = 4 + std = drafter.draft_block(aux, bonus, embed_fn, lm_head_fn, block_size=L) + ctx_kv = drafter.make_context_kv(aux, torch.arange(C)) + cached = drafter.draft_block_cached( + ctx_kv, bonus, embed_fn, lm_head_fn, block_size=L, context_len=C) + assert std == cached + + def test_extend_context_kv_concatenates(self): + cfg = _tiny_cfg() + torch.manual_seed(0) + drafter = DFlashDrafter(cfg).to(torch.float32).eval() + provider = _SyntheticAuxProvider(cfg) + aux, _ = provider.aux_hidden_context([1, 2, 3]) + ck = drafter.make_context_kv(aux, torch.arange(3)) + new_aux = [a[:, :2] for a in aux] # 2 "new" positions + nk = drafter.make_context_kv(new_aux, torch.arange(3, 5)) + ext = drafter.extend_context_kv(ck, nk) + assert len(ext) == cfg.num_hidden_layers + assert ext[0][0].shape[2] == 5 # 3 + 2 along seq axis + assert ext[0][1].shape[2] == 5 + + def test_incremental_extend_matches_full_context(self): + """Building ctx_kv incrementally (prompt + extend) equals building it + in one shot — so draft_block_cached drafts identically.""" + cfg = _tiny_cfg() + torch.manual_seed(0) + drafter = DFlashDrafter(cfg).to(torch.float32).eval() + embed_fn, lm_head_fn = _synthetic_verifier_heads(cfg) + provider = _SyntheticAuxProvider(cfg) + full = [1, 2, 3, 4, 5, 6] + aux_full, 
bonus = provider.aux_hidden_context(full) + C = len(full) + full_kv = drafter.make_context_kv(aux_full, torch.arange(C)) + # incremental: first 4, then extend by 2 (same aux slices) + ck = drafter.make_context_kv([a[:, :4] for a in aux_full], torch.arange(4)) + ck = drafter.extend_context_kv( + ck, drafter.make_context_kv([a[:, 4:6] for a in aux_full], torch.arange(4, 6))) + d_full = drafter.draft_block_cached( + full_kv, bonus, embed_fn, lm_head_fn, block_size=4, context_len=C) + d_inc = drafter.draft_block_cached( + ck, bonus, embed_fn, lm_head_fn, block_size=4, context_len=C) + assert d_full == d_inc + + class TestDFlashConfig: def test_parses_core_fields(self): cfg = _tiny_cfg() diff --git a/tests/inference_engine/v04/test_f_theta.py b/tests/inference_engine/v04/test_f_theta.py new file mode 100644 index 00000000..3d8a4dab --- /dev/null +++ b/tests/inference_engine/v04/test_f_theta.py @@ -0,0 +1,313 @@ +"""Linux CI tests for inference_engine.v04.f_theta. + +Covers the shape contract, parameter counts, save/load round-trip, +and dtype/device dispatch. No actual training — that's exercised by +scripts/research/k3_f_theta_train.py + the integration evidence run. 
+""" + +from __future__ import annotations + +import tempfile +from pathlib import Path + +import pytest +import torch + +from inference_engine.v04.f_theta import FThetaConfig, FThetaProjection + + +def _gemma4_dflash_config(rank: int = 256) -> FThetaConfig: + """Production K3 config: Gemma 4 26B-A4B verifier + DFlash 0.4B drafter.""" + return FThetaConfig( + drafter_num_layers=5, + drafter_num_kv_heads=2, + drafter_head_dim=128, + verifier_num_layers=30, + verifier_num_kv_heads=8, + verifier_head_dim=256, + rank=rank, + ) + + +def _tiny_config() -> FThetaConfig: + """Tiny config for fast tests.""" + return FThetaConfig( + drafter_num_layers=2, + drafter_num_kv_heads=2, + drafter_head_dim=4, + verifier_num_layers=3, + verifier_num_kv_heads=4, + verifier_head_dim=8, + rank=16, + ) + + +class TestFThetaConfig: + + def test_drafter_kv_dim(self): + c = _tiny_config() + assert c.drafter_kv_dim == 2 * 4 + + def test_verifier_kv_dim(self): + c = _tiny_config() + assert c.verifier_kv_dim == 4 * 8 + + def test_encoder_in_features(self): + c = _tiny_config() + assert c.encoder_in_features == 2 * (2 * 4) + + def test_production_dimensions(self): + c = _gemma4_dflash_config() + assert c.drafter_kv_dim == 256 + assert c.verifier_kv_dim == 2048 + assert c.encoder_in_features == 5 * 256 + + def test_to_from_json_round_trip(self): + c1 = _gemma4_dflash_config(rank=128) + d = c1.to_json_dict() + c2 = FThetaConfig.from_json_dict(d) + assert c1 == c2 + + +class TestForwardShapes: + + def test_forward_k_shape(self): + c = _tiny_config() + m = FThetaProjection(c) + B, T = 2, 7 + x = torch.randn(B, T, c.encoder_in_features) + y = m.forward_k(x) + assert isinstance(y, list) + assert len(y) == c.verifier_num_layers + for layer_out in y: + assert tuple(layer_out.shape) == ( + B, T, c.verifier_num_kv_heads, c.verifier_head_dim, + ) + + def test_forward_v_shape(self): + c = _tiny_config() + m = FThetaProjection(c) + B, T = 1, 3 + x = torch.randn(B, T, c.encoder_in_features) + y = 
m.forward_v(x) + assert isinstance(y, list) + assert len(y) == c.verifier_num_layers + for layer_out in y: + assert tuple(layer_out.shape) == ( + B, T, c.verifier_num_kv_heads, c.verifier_head_dim, + ) + + def test_heterogeneous_layer_kv_heads(self): + """Per-layer KV-head counts (Gemma 4: 8 sliding, 4 full).""" + c = FThetaConfig( + drafter_num_layers=2, drafter_num_kv_heads=2, drafter_head_dim=4, + verifier_num_layers=4, verifier_num_kv_heads=8, verifier_head_dim=8, + rank=16, verifier_layer_kv_heads=(8, 4, 8, 4), + ) + assert c.layer_kv_dims == (64, 32, 64, 32) + m = FThetaProjection(c) + B, T = 1, 3 + x = torch.randn(B, T, c.encoder_in_features) + y = m.forward_k(x) + assert [t.shape[2] for t in y] == [8, 4, 8, 4] + # JSON round-trip preserves the per-layer field + c2 = FThetaConfig.from_json_dict(c.to_json_dict()) + assert c2 == c + + def test_forward_k_rejects_wrong_rank(self): + c = _tiny_config() + m = FThetaProjection(c) + with pytest.raises(ValueError, match="expected"): + m.forward_k(torch.randn(c.encoder_in_features)) # 1-D input + + def test_forward_k_rejects_wrong_feature_dim(self): + c = _tiny_config() + m = FThetaProjection(c) + with pytest.raises(ValueError, match="encoder_in_features"): + m.forward_k(torch.randn(2, 7, c.encoder_in_features + 1)) + + +class TestForwardKVPack: + """forward_kv_pack accepts the natural KVCapture layout + [B, T, num_kv_heads, head_dim] per layer (list of tensors). 
+ """ + + def test_returns_paired_k_v(self): + c = _tiny_config() + m = FThetaProjection(c) + B, T = 2, 5 + k_per_layer = [ + torch.randn(B, T, c.drafter_num_kv_heads, c.drafter_head_dim) + for _ in range(c.drafter_num_layers) + ] + v_per_layer = [ + torch.randn(B, T, c.drafter_num_kv_heads, c.drafter_head_dim) + for _ in range(c.drafter_num_layers) + ] + k_out, v_out = m.forward_kv_pack(k_per_layer, v_per_layer) + assert len(k_out) == c.verifier_num_layers + assert len(v_out) == c.verifier_num_layers + per_layer = (B, T, c.verifier_num_kv_heads, c.verifier_head_dim) + for ko, vo in zip(k_out, v_out): + assert tuple(ko.shape) == per_layer + assert tuple(vo.shape) == per_layer + + def test_rejects_wrong_layer_count(self): + c = _tiny_config() + m = FThetaProjection(c) + B, T = 1, 3 + k_per_layer = [ + torch.randn(B, T, c.drafter_num_kv_heads, c.drafter_head_dim) + for _ in range(c.drafter_num_layers - 1) # one short + ] + v_per_layer = [ + torch.randn(B, T, c.drafter_num_kv_heads, c.drafter_head_dim) + for _ in range(c.drafter_num_layers) + ] + with pytest.raises(ValueError, match="drafter layers"): + m.forward_kv_pack(k_per_layer, v_per_layer) + + def test_consistency_with_explicit_concat(self): + """forward_kv_pack must equal forward_k(flatten + concat) explicitly.""" + c = _tiny_config() + torch.manual_seed(0) + m = FThetaProjection(c) + m.eval() + B, T = 2, 4 + k_per_layer = [ + torch.randn(B, T, c.drafter_num_kv_heads, c.drafter_head_dim) + for _ in range(c.drafter_num_layers) + ] + v_per_layer = [ + torch.randn(B, T, c.drafter_num_kv_heads, c.drafter_head_dim) + for _ in range(c.drafter_num_layers) + ] + k_out_pack, v_out_pack = m.forward_kv_pack(k_per_layer, v_per_layer) + k_concat = torch.cat([k.flatten(-2, -1) for k in k_per_layer], dim=-1) + v_concat = torch.cat([v.flatten(-2, -1) for v in v_per_layer], dim=-1) + with torch.no_grad(): + k_out_direct = m.forward_k(k_concat) + v_out_direct = m.forward_v(v_concat) + for kp, kd in zip(k_out_pack, 
k_out_direct): + assert torch.allclose(kp, kd, atol=1e-6) + for vp, vd in zip(v_out_pack, v_out_direct): + assert torch.allclose(vp, vd, atol=1e-6) + + +class TestParameterCount: + """Lock the parameter-count contract so future architecture changes + are explicit (not silent regressions in training cost).""" + + def test_tiny_param_count(self): + c = _tiny_config() + m = FThetaProjection(c) + # encoder_k, encoder_v: 2 × (encoder_in × rank) = 2 × 16×16 = 512 + # decoders_k: 3 × (rank × verifier_kv_dim) = 3 × 16×32 = 1536 + # decoders_v: same = 1536 + # Total: 512 + 1536 + 1536 = 3584 + n = sum(p.numel() for p in m.parameters()) + assert n == 512 + 1536 + 1536 + + def test_production_param_count_in_expected_range(self): + """Production f_θ should be ~31.8M params (rank=256).""" + c = _gemma4_dflash_config(rank=256) + m = FThetaProjection(c) + n = sum(p.numel() for p in m.parameters()) + # encoder_k + encoder_v: 2 * 5 * 256 * 256 = 655,360 + # decoders_k: 30 * 256 * 2048 = 15,728,640 + # decoders_v: same = 15,728,640 + # total ≈ 32,112,640 + assert 30_000_000 < n < 35_000_000 + + +class TestSaveLoadRoundTrip: + + def test_save_and_load_preserves_outputs(self, tmp_path): + c = _tiny_config() + torch.manual_seed(42) + m1 = FThetaProjection(c).eval() + # Run a forward, snapshot output + B, T = 1, 3 + x_k = torch.randn(B, T, c.encoder_in_features) + x_v = torch.randn(B, T, c.encoder_in_features) + with torch.no_grad(): + y_k_1 = m1.forward_k(x_k) + y_v_1 = m1.forward_v(x_v) + + # Save + m1.save_pretrained(tmp_path) + assert (tmp_path / "f_theta_config.json").is_file() + assert (tmp_path / "f_theta_weights.pt").is_file() + + # Load + m2 = FThetaProjection.from_pretrained(tmp_path) + assert m2.config == m1.config + with torch.no_grad(): + y_k_2 = m2.forward_k(x_k) + y_v_2 = m2.forward_v(x_v) + + for a, b in zip(y_k_1, y_k_2): + assert torch.allclose(a, b) + for a, b in zip(y_v_1, y_v_2): + assert torch.allclose(a, b) + + def test_load_rejects_missing_config(self, tmp_path): 
+ # Write only weights, no config + m = FThetaProjection(_tiny_config()) + torch.save(m.state_dict(), tmp_path / "f_theta_weights.pt") + with pytest.raises(FileNotFoundError, match="f_theta_config.json"): + FThetaProjection.from_pretrained(tmp_path) + + def test_load_rejects_missing_weights(self, tmp_path): + c = _tiny_config() + import json + (tmp_path / "f_theta_config.json").write_text( + json.dumps(c.to_json_dict()), + ) + with pytest.raises(FileNotFoundError, match="f_theta_weights.pt"): + FThetaProjection.from_pretrained(tmp_path) + + def test_load_rejects_non_directory(self): + with pytest.raises(FileNotFoundError, match="must be a directory"): + FThetaProjection.from_pretrained("/tmp/not_a_real_directory") + + +class TestDeviceDtypeDispatch: + + def test_to_dtype(self): + m = FThetaProjection(_tiny_config()) + m_bf16 = m.to(torch.bfloat16) + for p in m_bf16.parameters(): + assert p.dtype == torch.bfloat16 + + def test_load_with_dtype_override(self, tmp_path): + m1 = FThetaProjection(_tiny_config()) + m1.save_pretrained(tmp_path) + m2 = FThetaProjection.from_pretrained(tmp_path, dtype=torch.bfloat16) + for p in m2.parameters(): + assert p.dtype == torch.bfloat16 + + +class TestGradientFlow: + """f_θ must be trainable end-to-end. 
Verify gradients flow through + encoder + decoders during a backward pass.""" + + def test_gradients_flow_for_k_path(self): + c = _tiny_config() + m = FThetaProjection(c) + B, T = 1, 3 + x = torch.randn(B, T, c.encoder_in_features, requires_grad=False) + out = m.forward_k(x) # list of [B, T, H, D] + loss = sum(((o) ** 2).mean() for o in out) + loss.backward() + # encoder_k should have a grad + assert m.encoder_k.weight.grad is not None + assert m.encoder_k.weight.grad.abs().sum() > 0 + # All K decoders should have grads + for dec in m.decoders_k: + assert dec.weight.grad is not None + assert dec.weight.grad.abs().sum() > 0 + # encoder_v / decoders_v should NOT (separate path) + assert m.encoder_v.weight.grad is None + for dec in m.decoders_v: + assert dec.weight.grad is None diff --git a/tests/inference_engine/v04/test_restored_sink_window_verifier.py b/tests/inference_engine/v04/test_restored_sink_window_verifier.py new file mode 100644 index 00000000..b2dae3de --- /dev/null +++ b/tests/inference_engine/v04/test_restored_sink_window_verifier.py @@ -0,0 +1,553 @@ +"""Unit tests for the Gap 1 + Gap 2 served-path integration. + +Covers, on CPU with tiny synthetic stand-ins (no real models): + +* :class:`CrossModelRestoredSinkWindowVerifier` — the full + ``SinkWindowVerifier`` public surface, with assertions that + ``forward_block`` is bit-equivalent to the underlying restored forward. +* End-to-end :class:`SpeculativeDecoder` integration over the restored + adapter (accept-all path and reject-all path), proving the served + output equals greedy restored-AR. +* :func:`build_restored_speculative_decoder` factory. + +The heavy ``load_restored_verifier`` model loader is coverage-exempt +(``# pragma: no cover``) and validated by GPU integration runs. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch +import torch.nn as nn + +from inference_engine.v04 import ( + CrossModelRestoredSinkWindowVerifier, + build_restored_speculative_decoder, +) + +V = 16 # synthetic vocab size + + +# --------------------------------------------------------------------------- # +# Synthetic stand-ins +# --------------------------------------------------------------------------- # +class _Cfg: + """Verifier text-config shape consumed by the adapter's KV accounting.""" + + def __init__( + self, + num_hidden_layers=3, + num_key_value_heads=4, + head_dim=8, + hidden_size=32, + num_attention_heads=4, + ): + self.num_hidden_layers = num_hidden_layers + if num_key_value_heads is not None: + self.num_key_value_heads = num_key_value_heads + if head_dim is not None: + self.head_dim = head_dim + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + + +class _Param(nn.Module): + def __init__(self): + super().__init__() + self.lin = nn.Linear(2, 2, bias=False) + + +class _FakeVerifierModel(nn.Module): + def __init__(self, cfg=None): + super().__init__() + self.config = cfg or _Cfg() + self.lin = nn.Linear(2, 2, bias=False) + + +class _FakeRestored: + """Deterministic stand-in for CrossModelDLMRestoredVerifier. + + Implements an "increment" language model: the predicted next token + after seeing token ``x`` is ``(x + 1) % V``. ``forward`` returns + ``[1, T, V]`` logits whose argmax at position ``t`` is + ``(seq[t] + 1) % V``. This makes greedy restored-AR fully predictable. 
+ """ + + def __init__(self, sink_size=2, window_size=4, bare_tensor=False, cfg=None): + self.sink_size = sink_size + self.window_size = window_size + self.verifier_model = _FakeVerifierModel(cfg) + self.drafter = _Param() + self.f_theta = _Param() + self._bare_tensor = bare_tensor + self.seen_helpers = [] + + def forward( + self, + input_ids, + *, + apply_rotary_pos_emb=None, + eager_attention_forward=None, + all_attention_functions=None, + ): + self.seen_helpers.append( + (apply_rotary_pos_emb, eager_attention_forward, all_attention_functions) + ) + seq = input_ids[0].tolist() + T = len(seq) + logits = torch.full((1, T, V), -10.0) + for t, tok in enumerate(seq): + logits[0, t, (int(tok) + 1) % V] = 10.0 + if self._bare_tensor: + return logits + return SimpleNamespace(logits=logits) + + +class _FakeProposer: + """Minimal DLMProposer stand-in: ``propose_block`` returns whatever + ``predict_fn(committed, L)`` yields. Carries the stats attributes the + SpeculativeDecoder resets/reads.""" + + def __init__(self, predict_fn): + self._predict = predict_fn + self.stats = SimpleNamespace( + total_blocks=0, + total_diffusion_steps=0, + total_forward_passes=0, + peak_activation_bytes=0, + weight_bytes=0, + ) + + def propose_block(self, committed_token_ids, block_size, num_steps): + toks = self._predict(list(committed_token_ids), block_size) + return SimpleNamespace(tokens=list(toks)) + + +def _make_adapter(**kw): + restored = _FakeRestored(**kw) + sentinel_aprp = object() + sentinel_eager = object() + sentinel_all = object() + adapter = CrossModelRestoredSinkWindowVerifier( + restored, + apply_rotary_pos_emb=sentinel_aprp, + eager_attention_forward=sentinel_eager, + all_attention_functions=sentinel_all, + device="cpu", + ) + return adapter, restored, (sentinel_aprp, sentinel_eager, sentinel_all) + + +# --------------------------------------------------------------------------- # +# Construction / accounting +# 
--------------------------------------------------------------------------- # +def test_construction_basic(): + adapter, restored, _ = _make_adapter(sink_size=2, window_size=4) + assert adapter.sink_size == 2 + assert adapter.window_size == 4 + assert adapter.cache is None + assert adapter.cache_logical_size == 0 + assert adapter.next_global_position == 0 + assert adapter.next_token_logits is None + assert adapter.cached_token_sequence == [] + assert adapter.model is restored.verifier_model + # weight_bytes sums verifier + drafter + f_theta params (>0). + assert adapter.stats.weight_bytes > 0 + assert adapter._bytes_per_kv_token > 0 + + +def test_weight_bytes_skips_module_without_parameters(): + restored = _FakeRestored() + restored.drafter = object() # no .parameters → exercised `continue` + adapter = CrossModelRestoredSinkWindowVerifier( + restored, + apply_rotary_pos_emb=None, + eager_attention_forward=None, + all_attention_functions=None, + ) + assert adapter.stats.weight_bytes > 0 # verifier + f_theta still counted + + +def test_bytes_per_kv_token_head_dim_present(): + adapter, _, _ = _make_adapter() + cfg = _Cfg(num_hidden_layers=3, num_key_value_heads=4, head_dim=8) + expected = 3 * 4 * 8 * 4 * 2 # layers*kv_heads*head_dim*itemsize(fp32)*2 + assert adapter._bytes_per_kv_token == expected + + +def test_bytes_per_kv_token_head_dim_derived_from_hidden(): + cfg = _Cfg(num_key_value_heads=2, head_dim=None, + hidden_size=32, num_attention_heads=4) + adapter, _, _ = _make_adapter(cfg=cfg) + # head_dim = hidden_size // num_attention_heads = 32 // 4 = 8 + expected = 3 * 2 * 8 * 4 * 2 + assert adapter._bytes_per_kv_token == expected + + +def test_bytes_per_kv_token_default_itemsize_when_no_params(): + # Verifier model with no parameters → itemsize loop does zero + # iterations → default itemsize (4) is used. 
+ class _NoParamVerifier(nn.Module): + def __init__(self): + super().__init__() + self.config = _Cfg(num_hidden_layers=3, num_key_value_heads=4, + head_dim=8) + + restored = _FakeRestored() + restored.verifier_model = _NoParamVerifier() + adapter = CrossModelRestoredSinkWindowVerifier( + restored, + apply_rotary_pos_emb=None, + eager_attention_forward=None, + all_attention_functions=None, + ) + assert adapter._bytes_per_kv_token == 3 * 4 * 8 * 4 * 2 + + +def test_bytes_per_kv_token_kv_heads_fallback_and_zero_qheads(): + # No num_key_value_heads → falls back to num_attention_heads; head_dim + # None and num_attention_heads=0 → head_dim resolves to 0. + cfg = _Cfg(num_key_value_heads=None, head_dim=None, + hidden_size=0, num_attention_heads=0) + adapter, _, _ = _make_adapter(cfg=cfg) + assert adapter._bytes_per_kv_token == 0 + + +# --------------------------------------------------------------------------- # +# prefill +# --------------------------------------------------------------------------- # +def test_prefill_empty_raises(): + adapter, _, _ = _make_adapter() + with pytest.raises(ValueError, match="prompt_ids must be non-empty"): + adapter.prefill([]) + + +def test_prefill_sets_next_token_logits_and_passes_helpers(): + adapter, restored, sentinels = _make_adapter(sink_size=2, window_size=4) + prompt = [5, 6, 7] + adapter.prefill(prompt) + # next_token_logits predicts (last_token + 1) % V + assert int(torch.argmax(adapter.next_token_logits)) == (7 + 1) % V + assert adapter.next_global_position == 3 + assert adapter.cached_token_sequence == [5, 6, 7] # <= budget=6 + assert adapter.cache_logical_size == 3 + assert adapter.stats.forward_calls == 1 + assert adapter.stats.tokens_consumed == 3 + assert adapter.stats.peak_activation_bytes > 0 + assert adapter.stats.peak_kv_bytes > 0 + # the configured HF helpers were threaded through to restored.forward + assert restored.seen_helpers[-1] == sentinels + + +def test_prefill_bounds_resident_cache_when_over_budget(): + 
adapter, _, _ = _make_adapter(sink_size=2, window_size=4) # budget 6 + prompt = list(range(10)) # length 10 > 6 + adapter.prefill(prompt) + # sink (first 2) + window (last 4) + assert adapter.cached_token_sequence == [0, 1, 6, 7, 8, 9] + assert adapter.cache_logical_size == 6 + assert adapter.next_global_position == 10 # logical length unbounded + + +# --------------------------------------------------------------------------- # +# forward_block +# --------------------------------------------------------------------------- # +def test_forward_block_requires_prefill(): + adapter, _, _ = _make_adapter() + with pytest.raises(RuntimeError, match="not prefilled"): + adapter.forward_block([1, 2]) + + +def test_forward_block_empty_raises(): + adapter, _, _ = _make_adapter() + adapter.prefill([1, 2, 3]) + with pytest.raises(ValueError, match="tokens must be non-empty"): + adapter.forward_block([]) + + +def test_forward_block_equivalent_to_restored_forward(): + adapter, restored, _ = _make_adapter(sink_size=2, window_size=4) + prompt = [3, 4, 5] + adapter.prefill(prompt) + block = [9, 1] + out = adapter.forward_block(block) # [2, V] + assert tuple(out.shape) == (2, V) + # Equivalence: forward_block rows == restored.forward(prompt+block) slice + ref = restored.forward( + torch.tensor([prompt + block]), + ).logits[0] + assert torch.equal(out, ref[len(prompt):len(prompt) + len(block)]) + # argmax rows predict (token+1)%V + assert int(torch.argmax(out[0])) == (9 + 1) % V + assert int(torch.argmax(out[1])) == (1 + 1) % V + # provisional resident size = committed + L (un-trimmed pre-commit) + assert adapter.cache_logical_size == 3 + 2 + assert adapter.stats.forward_calls == 2 # prefill + this block + + +# --------------------------------------------------------------------------- # +# commit_or_truncate +# --------------------------------------------------------------------------- # +def test_commit_invalid_accepted_raises(): + adapter, _, _ = _make_adapter() + 
adapter.prefill([1, 2, 3]) + adapter.forward_block([4, 5]) + with pytest.raises(ValueError, match="0 <= accepted <= forwarded"): + adapter.commit_or_truncate(forwarded=2, accepted=3) + + +def test_commit_accept_partial_extends_committed(): + adapter, _, _ = _make_adapter(sink_size=2, window_size=4) + adapter.prefill([1, 2, 3]) + adapter.forward_block([4, 5]) + adapter.commit_or_truncate(forwarded=2, accepted=1) # keep only 4 + assert adapter.next_global_position == 4 + assert adapter.cached_token_sequence == [1, 2, 3, 4] + assert adapter._committed == [1, 2, 3, 4] + + +def test_commit_accept_zero_keeps_committed(): + adapter, _, _ = _make_adapter() + adapter.prefill([1, 2, 3]) + adapter.forward_block([4, 5]) + adapter.commit_or_truncate(forwarded=2, accepted=0) + assert adapter._committed == [1, 2, 3] + assert adapter.next_global_position == 3 + + +# --------------------------------------------------------------------------- # +# append_token +# --------------------------------------------------------------------------- # +def test_append_token_advances_and_predicts(): + adapter, _, _ = _make_adapter(sink_size=2, window_size=4) + adapter.prefill([1, 2, 3]) + nt = adapter.append_token(8) + assert adapter._committed == [1, 2, 3, 8] + assert adapter.next_global_position == 4 + # predicts (8 + 1) % V + assert int(torch.argmax(nt)) == (8 + 1) % V + + +# --------------------------------------------------------------------------- # +# CacheInspector accessors +# --------------------------------------------------------------------------- # +def test_cache_inspector_accessors(): + adapter, _, _ = _make_adapter(sink_size=2, window_size=4) + adapter.prefill(list(range(10))) + assert adapter.k_seq_length(object()) == 6 + assert adapter.kv_live_bytes(object()) == 6 * adapter._bytes_per_kv_token + assert adapter.live_kv_bytes() == 6 * adapter._bytes_per_kv_token + + +# --------------------------------------------------------------------------- # +# _sync_bounded_state window 
edge + _restored_logits bare-tensor + peak +# --------------------------------------------------------------------------- # +def test_sync_zero_window_keeps_only_sink(): + # window_size = 0 → budget == sink, keep_window <= 0 branch. + adapter, _, _ = _make_adapter(sink_size=2, window_size=0) + adapter.prefill([1, 2, 3, 4, 5]) + assert adapter.cached_token_sequence == [1, 2] + assert adapter.cache_logical_size == 2 + + +def test_restored_forward_returns_bare_tensor(): + adapter, _, _ = _make_adapter(bare_tensor=True, sink_size=2, window_size=4) + adapter.prefill([2, 3, 4]) + assert int(torch.argmax(adapter.next_token_logits)) == (4 + 1) % V + + +def test_record_peak_activation_keeps_max(): + adapter, _, _ = _make_adapter() + big = torch.zeros(1, 100, V) + small = torch.zeros(1, 1, V) + adapter._record_peak_activation(big) + peak = adapter.stats.peak_activation_bytes + adapter._record_peak_activation(small) # not greater → unchanged + assert adapter.stats.peak_activation_bytes == peak + + +# --------------------------------------------------------------------------- # +# Incremental-decode path (Gap-A throughput) — exercised with a fake model +# that uses a real transformers DynamicCache so the cache bookkeeping +# (build / append / truncate / position tracking) is covered on CPU. 
+# --------------------------------------------------------------------------- # +class _FakeIncVerifierModel(nn.Module): + def __init__(self, n_layers, V): + super().__init__() + self.config = _Cfg() + self.lin = nn.Linear(2, 2, bias=False) + self.model = SimpleNamespace(layers=[object() for _ in range(n_layers)]) + self._n = n_layers + self._V = V + + def forward(self, input_ids=None, position_ids=None, cache_position=None, + past_key_values=None, use_cache=False, **kw): + seq = input_ids[0].tolist() + L = len(seq) + logits = torch.full((1, L, self._V), -10.0) + for t, tk in enumerate(seq): + logits[0, t, (int(tk) + 1) % self._V] = 10.0 + if past_key_values is not None: + for i in range(self._n): + past_key_values.update( + torch.zeros(1, 2, L, 4), torch.zeros(1, 2, L, 4), i) + return SimpleNamespace(logits=logits, past_key_values=past_key_values) + + +class _FakeRestoredInc: + def __init__(self, n_layers=3, V=16, sink=2, window=4, incomplete=False): + self.sink_size = sink + self.window_size = window + self.verifier_model = _FakeIncVerifierModel(n_layers, V) + self.drafter = _Param() + self.f_theta = _Param() + self._n = n_layers + self._V = V + self._incomplete = incomplete + + def forward(self, input_ids, *, apply_rotary_pos_emb=None, + eager_attention_forward=None, all_attention_functions=None, + capture_kv=None): + seq = input_ids[0].tolist() + T = len(seq) + logits = torch.full((1, T, self._V), -10.0) + for t, tk in enumerate(seq): + logits[0, t, (int(tk) + 1) % self._V] = 10.0 + if capture_kv is not None: + for i in range(self._n): + if self._incomplete and i == self._n - 1: + continue # leave a None to trigger the guard + capture_kv[i] = (torch.zeros(1, 2, T, 4), torch.zeros(1, 2, T, 4)) + return SimpleNamespace(logits=logits) + + +def _make_inc_adapter(**kw): + restored = _FakeRestoredInc(**kw) + return CrossModelRestoredSinkWindowVerifier( + restored, apply_rotary_pos_emb=None, eager_attention_forward=None, + all_attention_functions=None, 
incremental=True), restored + + +def test_incremental_prefill_builds_cache(): + a, _ = _make_inc_adapter(sink=2, window=4) + a.prefill([1, 2, 3, 4, 5, 6, 7, 8]) # T=8 > budget 6 → eviction → capture + assert a._past is not None + assert a._past_len == 8 + assert len(a._past.layers) == 3 + assert int(a.next_token_logits.argmax()) == (8 + 1) % V + + +def test_incremental_capture_incomplete_raises(): + a, _ = _make_inc_adapter(incomplete=True) + with pytest.raises(RuntimeError, match="not captured"): + a.prefill([1, 2, 3, 4, 5, 6, 7, 8]) + + +def test_incremental_forward_block_native_and_commit_accept_all(): + a, _ = _make_inc_adapter(sink=2, window=4) + a.prefill([1, 2, 3, 4, 5, 6, 7, 8]) + blk = a.forward_block([9, 1]) + assert int(blk[0].argmax()) == (9 + 1) % V + assert int(blk[1].argmax()) == (1 + 1) % V + assert a._past.layers[0].keys.shape[2] == 8 + 2 # appended + a.commit_or_truncate(forwarded=2, accepted=2) + assert a._past_len == 10 + assert a._committed[-2:] == [9, 1] + + +def test_incremental_commit_truncates_rejected_tail(): + a, _ = _make_inc_adapter(sink=2, window=4) + a.prefill([1, 2, 3, 4, 5, 6, 7, 8]) + a.forward_block([9, 1]) # cache → 10 + a.commit_or_truncate(forwarded=2, accepted=1) # drop 1 + assert a._past_len == 9 + assert a._past.layers[0].keys.shape[2] == 9 + assert a._committed[-1] == 9 + + +def test_incremental_append_token_advances(): + a, _ = _make_inc_adapter(sink=2, window=4) + a.prefill([1, 2, 3, 4, 5, 6, 7, 8]) + nt = a.append_token(5) + assert a._past_len == 9 + assert int(nt.argmax()) == (5 + 1) % V + + +def test_incremental_reset_clears_past(): + a, _ = _make_inc_adapter() + a.prefill([1, 2, 3, 4, 5, 6, 7, 8]) + a.reset() + assert a._past is None and a._past_len == 0 + + +def test_incremental_prefill_twice_reuses_num_layers(): + a, _ = _make_inc_adapter(sink=2, window=4) + a.prefill([1, 2, 3, 4, 5, 6, 7, 8]) + n1 = a._num_layers_cache + a.prefill([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) # _num_layers_cache already set + assert 
a._num_layers_cache == n1 == 3 + + +def test_incremental_commit_skips_empty_layer(): + a, _ = _make_inc_adapter(sink=2, window=4) + a.prefill([1, 2, 3, 4, 5, 6, 7, 8]) + a.forward_block([9, 1]) + # Defensive: a layer with keys=None must be skipped during truncation. + a._past.layers.append(SimpleNamespace(keys=None, values=None)) + a.commit_or_truncate(forwarded=2, accepted=1) + assert a._past_len == 9 + assert a._past.layers[0].keys.shape[2] == 9 + + +# --------------------------------------------------------------------------- # +# End-to-end SpeculativeDecoder integration (Gap 1 + factory) +# --------------------------------------------------------------------------- # +def _greedy_reference(prompt, n): + """Greedy restored-AR: predict (x+1)%V repeatedly from prompt[-1].""" + out = [] + x = prompt[-1] + for _ in range(n): + x = (x + 1) % V + out.append(x) + return out + + +def test_spec_decode_accept_all_matches_greedy(): + adapter, _, _ = _make_adapter(sink_size=4, window_size=64) + # Proposer that proposes the *correct* continuation → all accepted. + def predict(committed, L): + x = committed[-1] + toks = [] + for _ in range(L): + x = (x + 1) % V + toks.append(x) + return toks + + decoder = build_restored_speculative_decoder( + _FakeProposer(predict), adapter, block_size=4, num_diffusion_steps=2, + ) + prompt = [1, 2, 3] + res = decoder.generate(prompt_ids=prompt, max_new_tokens=10) + assert res.output_token_ids == _greedy_reference(prompt, 10) + assert res.acceptance_rate > 0.0 # tokens were accepted + + +def test_spec_decode_reject_all_still_matches_greedy(): + adapter, _, _ = _make_adapter(sink_size=4, window_size=64) + # Proposer that always proposes a token the verifier won't predict + # (offset by 2) → accepted=0 each block; verifier emits the correction. 
+ def predict(committed, L): + x = committed[-1] + return [(x + 2) % V] * L + + decoder = build_restored_speculative_decoder( + _FakeProposer(predict), adapter, block_size=4, num_diffusion_steps=2, + ) + prompt = [1, 2, 3] + res = decoder.generate(prompt_ids=prompt, max_new_tokens=6) + # Even with 0 acceptance, the verifier's correction token each block + # is the greedy next token → output still equals greedy restored-AR. + assert res.output_token_ids == _greedy_reference(prompt, 6) + assert res.total_accepted == 0 diff --git a/tests/research/test_k3_f_theta_train_v2.py b/tests/research/test_k3_f_theta_train_v2.py new file mode 100644 index 00000000..d151608a --- /dev/null +++ b/tests/research/test_k3_f_theta_train_v2.py @@ -0,0 +1,617 @@ +"""Linux CI tests for the v2 trainer pieces in +``scripts/research/k3_f_theta_train``: cosine+magnitude loss, NIAH +synthetic prompts, cosine LR schedule. + +These are the trainer-side fixes for the recall=0 evidence in PR #103 +(f_θ v1). See the script docstring for v2 motivation. + +The training loop itself requires CUDA + a 26B verifier and is +validated empirically via vast.ai (see scripts/review_pr_k3_f_theta_ +train_on_vast.sh). Linux CI verifies the building blocks. +""" + +from __future__ import annotations + +import math + +import pytest +import torch + +# Import the v2 helpers — the script is importable as a module via the +# scripts/research package convention used elsewhere in the codebase. +import importlib.util +import pathlib +import sys + +_SCRIPT = ( + pathlib.Path(__file__).resolve().parents[2] + / "scripts" / "research" / "k3_f_theta_train.py" +) +_spec = importlib.util.spec_from_file_location("k3_f_theta_train", _SCRIPT) +_mod = importlib.util.module_from_spec(_spec) +assert _spec.loader is not None +# Register in sys.modules BEFORE exec so @dataclass (which probes +# sys.modules[cls.__module__] for KW_ONLY type-id check) doesn't trip. 
+sys.modules["k3_f_theta_train"] = _mod +_spec.loader.exec_module(_mod) + + +# --------------------------------------------------------------------------- +# Per-vector cosine + magnitude loss +# --------------------------------------------------------------------------- + + +class TestPerVectorCosineMagLoss: + + def test_identical_vectors_give_zero_loss(self): + x = torch.randn(2, 5, 4, 8) + loss, cos, mag = _mod._per_vector_cosine_mag_loss(x, x) + assert float(loss) == pytest.approx(0.0, abs=1e-5) + assert float(cos) == pytest.approx(0.0, abs=1e-5) + assert float(mag) == pytest.approx(0.0, abs=1e-5) + + def test_negated_vectors_give_cos_loss_2(self): + x = torch.randn(2, 5, 4, 8) + loss, cos, mag = _mod._per_vector_cosine_mag_loss(x, -x) + # cos sim between x and -x is -1, so 1-cos = 2 + assert float(cos) == pytest.approx(2.0, abs=1e-3) + # magnitude is the same so mag_loss ≈ 0 + assert float(mag) == pytest.approx(0.0, abs=1e-3) + + def test_orthogonal_vectors_give_cos_loss_1(self): + # Two orthogonal unit vectors — cos sim = 0 → cos_loss = 1 + pred = torch.zeros(1, 1, 1, 4) + pred[..., 0] = 1.0 + tgt = torch.zeros(1, 1, 1, 4) + tgt[..., 1] = 1.0 + _, cos, mag = _mod._per_vector_cosine_mag_loss(pred, tgt) + assert float(cos) == pytest.approx(1.0, abs=1e-5) + assert float(mag) == pytest.approx(0.0, abs=1e-5) + + def test_scaled_vector_gives_zero_cos_nonzero_mag(self): + # pred = 2 * tgt → same direction (cos=1, cos_loss=0) + # but different magnitude (‖pred‖ = 2‖tgt‖) + tgt = torch.randn(2, 5, 4, 8) + pred = 2.0 * tgt + _, cos, mag = _mod._per_vector_cosine_mag_loss(pred, tgt) + assert float(cos) == pytest.approx(0.0, abs=1e-3) + assert float(mag) > 0.0 + + def test_loss_is_differentiable(self): + pred = torch.randn(2, 5, 4, 8, requires_grad=True) + tgt = torch.randn(2, 5, 4, 8) + loss, _, _ = _mod._per_vector_cosine_mag_loss(pred, tgt) + loss.backward() + assert pred.grad is not None + assert pred.grad.norm().item() > 0.0 + + +# 
--------------------------------------------------------------------------- +# Cosine LR schedule +# --------------------------------------------------------------------------- + + +class TestLRSchedule: + + def test_const_schedule_returns_peak(self): + for s in [1, 100, 10000]: + assert _mod._lr_at_step( + s, peak_lr=1e-3, total_steps=1000, + warmup_steps=100, schedule="const", + ) == 1e-3 + + def test_cosine_warmup_starts_below_peak(self): + lr_step1 = _mod._lr_at_step( + 1, peak_lr=1e-3, total_steps=1000, warmup_steps=100, + schedule="cosine", + ) + # step 1 of 100 warmup → lr = 1e-3 * 1/100 = 1e-5 + assert lr_step1 == pytest.approx(1e-5, rel=1e-6) + + def test_cosine_warmup_reaches_peak(self): + lr = _mod._lr_at_step( + 100, peak_lr=1e-3, total_steps=1000, warmup_steps=100, + schedule="cosine", + ) + assert lr == pytest.approx(1e-3, rel=1e-6) + + def test_cosine_decay_reaches_floor(self): + # At final step, cosine should be ≈ peak / 100 + lr_final = _mod._lr_at_step( + 1000, peak_lr=1e-3, total_steps=1000, warmup_steps=100, + schedule="cosine", + ) + assert lr_final == pytest.approx(1e-5, rel=1e-3) + + def test_cosine_midway_above_floor(self): + # halfway through decay (step ≈ 550), cosine factor = cos(π/2) = 0 + # → lr = floor + (peak - floor) * 0.5 ≈ 5e-4 + lr_mid = _mod._lr_at_step( + 550, peak_lr=1e-3, total_steps=1000, warmup_steps=100, + schedule="cosine", + ) + assert lr_mid == pytest.approx(5e-4, rel=0.05) + + def test_unknown_schedule_raises(self): + with pytest.raises(ValueError, match="unknown schedule"): + _mod._lr_at_step( + 1, peak_lr=1e-3, total_steps=1000, warmup_steps=100, + schedule="exponential", + ) + + +# --------------------------------------------------------------------------- +# NIAH-style synthetic training prompts +# --------------------------------------------------------------------------- + + +class TestNIAHTrainingPrompts: + + def test_returns_requested_count(self): + prompts = _mod._make_niah_training_prompts(8, seed=1234) + 
assert len(prompts) == 8 + assert all(isinstance(p, str) for p in prompts) + + def test_prompts_contain_needle(self): + prompts = _mod._make_niah_training_prompts(4, seed=42) + for p in prompts: + assert "secret code is" in p.lower(), \ + "needle pattern missing" + assert "Question: What is the secret code?" in p, \ + "question line missing" + + def test_seed_determinism(self): + a = _mod._make_niah_training_prompts(4, seed=99) + b = _mod._make_niah_training_prompts(4, seed=99) + assert a == b + + def test_different_seeds_give_different_prompts(self): + a = _mod._make_niah_training_prompts(4, seed=1) + b = _mod._make_niah_training_prompts(4, seed=2) + assert a != b + + def test_haystack_size_respected(self): + prompts = _mod._make_niah_training_prompts( + 4, seed=1, haystack_min_lines=10, haystack_max_lines=12, + ) + for p in prompts: + # Count haystack lines: split on newlines, drop the + # introductory + trailing question blocks. + body = p.split("\n\n", 1)[1].rsplit("\n\n", 1)[0] + n_lines = len(body.split("\n")) + assert 10 <= n_lines <= 12, f"got {n_lines}" + + def test_no_eval_seed_collision(self): + """The trainer uses seed = args.seed + 1000 to avoid colliding + with the eval's needle generator. 
Verify the trainer's NIAH + prompts at seed 1000+default are not byte-identical to a + trivially-seeded set the eval might use (seed 0 or 42).""" + train_seed_default = 0 + 1000 # default training NIAH seed + train_prompts = _mod._make_niah_training_prompts( + 10, seed=train_seed_default, + ) + eval_seed_42 = _mod._make_niah_training_prompts(10, seed=42) + eval_seed_0 = _mod._make_niah_training_prompts(10, seed=0) + assert train_prompts != eval_seed_42 + assert train_prompts != eval_seed_0 + + +# --------------------------------------------------------------------------- +# v3: attention-output distillation loss +# --------------------------------------------------------------------------- + + +class _StubAttn(torch.nn.Module): + """Minimal stand-in for a Gemma 4 self_attn module: q_norm, k_norm, + v_norm, q_proj (only out_features used), o_proj. Used by the + distillation loss to apply per-layer norms + o_proj. Nothing else + is needed since cos/sin/mask/attention are operator-level.""" + + def __init__(self, n_heads, n_kv_heads, head_dim, hidden_dim): + super().__init__() + from torch.nn import RMSNorm + self.head_dim = head_dim + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.scaling = head_dim ** -0.5 + self.q_norm = RMSNorm(head_dim) + self.k_norm = RMSNorm(head_dim) + self.v_norm = RMSNorm(head_dim) + # q_proj.out_features is read by the loss to compute n_heads + self.q_proj = torch.nn.Linear(hidden_dim, n_heads * head_dim, bias=False) + self.o_proj = torch.nn.Linear(n_heads * head_dim, hidden_dim, bias=False) + + +class _StubLayer(torch.nn.Module): + def __init__(self, n_heads, n_kv_heads, head_dim, hidden_dim): + super().__init__() + self.self_attn = _StubAttn(n_heads, n_kv_heads, head_dim, hidden_dim) + + +def _identity_rotary_pos_emb(x, cos, sin, unsqueeze_dim=2): + # Identity RoPE for tests — just return x unchanged. 
We're testing + # the wiring, not RoPE correctness (RoPE is applied via the same + # function signature the actual transformers helper uses). + return x + + +class TestAttentionDistillationLoss: + + def _build_synthetic(self, T=16): + from inference_engine.v04.f_theta import FThetaConfig, FThetaProjection + torch.manual_seed(0) + n_d_layers = 2 + d_kv_heads, d_head = 2, 4 + n_v_layers = 3 + n_heads, head_dim, hidden = 4, 4, 16 + n_kv_heads = 2 + + cfg = FThetaConfig( + drafter_num_layers=n_d_layers, + drafter_num_kv_heads=d_kv_heads, drafter_head_dim=d_head, + verifier_num_layers=n_v_layers, + verifier_num_kv_heads=n_kv_heads, verifier_head_dim=head_dim, + rank=8, + ) + f_theta = FThetaProjection(cfg).float() + + layers = [ + _StubLayer(n_heads, n_kv_heads, head_dim, hidden) + for _ in range(n_v_layers) + ] + + # Synthetic captured target data + target = _mod.AttentionTargetData( + q_raw=[torch.randn(T, n_heads * head_dim, dtype=torch.bfloat16) + for _ in range(n_v_layers)], + o_tgt=[torch.randn(T, hidden, dtype=torch.bfloat16) + for _ in range(n_v_layers)], + cos=[torch.randn(1, T, head_dim, dtype=torch.bfloat16) + for _ in range(n_v_layers)], + sin=[torch.randn(1, T, head_dim, dtype=torch.bfloat16) + for _ in range(n_v_layers)], + attention_mask=None, + num_heads_per_layer=[n_heads] * n_v_layers, + head_dim_per_layer=[head_dim] * n_v_layers, + ) + seq = _mod.CapturedSequence( + seq_len=T, + drafter_k=torch.randn(n_d_layers, T, d_kv_heads * d_head), + drafter_v=torch.randn(n_d_layers, T, d_kv_heads * d_head), + attn_target=target, + ) + return f_theta, layers, seq + + def test_attention_distill_loss_runs(self): + f_theta, layers, seq = self._build_synthetic() + diag = {} + loss = _mod._attention_distillation_loss( + f_theta, seq, layers, + apply_rotary_pos_emb=_identity_rotary_pos_emb, + device=torch.device("cpu"), + diag_buf=diag, + ) + assert torch.is_tensor(loss) + assert loss.dim() == 0 + assert float(loss) > 0.0 + assert "mse_O_mean" in diag + assert 
"abs_O_target_mean" in diag + + def test_s5_skip_layer_indices_excludes_layers(self): + """S5 mode: skip_layer_indices excludes those layers from the loss + (loss differs and is averaged over the remaining layers).""" + f_theta, layers, seq = self._build_synthetic() + full = _mod._attention_distillation_loss( + f_theta, seq, layers, + apply_rotary_pos_emb=_identity_rotary_pos_emb, + device=torch.device("cpu"), + ) + skipped = _mod._attention_distillation_loss( + f_theta, seq, layers, + apply_rotary_pos_emb=_identity_rotary_pos_emb, + device=torch.device("cpu"), + skip_layer_indices=[0], + ) + # Excluding a layer changes the (per-used-layer averaged) loss. + assert abs(float(full) - float(skipped)) > 1e-9 + + def test_loss_is_differentiable_through_f_theta(self): + f_theta, layers, seq = self._build_synthetic() + loss = _mod._attention_distillation_loss( + f_theta, seq, layers, + apply_rotary_pos_emb=_identity_rotary_pos_emb, + device=torch.device("cpu"), + ) + loss.backward() + any_grad = any( + p.grad is not None and p.grad.norm().item() > 0 + for p in f_theta.parameters() + ) + assert any_grad, "f_θ params should receive non-zero gradient" + + def test_o_proj_weights_remain_frozen_in_loss(self): + """o_proj is the verifier's frozen weight; gradient should NOT + accumulate on it through the loss (it's not registered to f_θ + optimizer, but we still check o_proj's grad is unset before + backprop and unset after, since we pass o_proj from + non-trainable verifier modules).""" + f_theta, layers, seq = self._build_synthetic() + # Freeze o_proj like the trainer does + for layer in layers: + for p in layer.parameters(): + p.requires_grad_(False) + loss = _mod._attention_distillation_loss( + f_theta, seq, layers, + apply_rotary_pos_emb=_identity_rotary_pos_emb, + device=torch.device("cpu"), + ) + loss.backward() + for layer in layers: + for p in layer.parameters(): + assert p.grad is None, "verifier params must not receive grad" + + def 
test_dispatch_through_f_theta_loss_function(self): + f_theta, layers, seq = self._build_synthetic() + diag = {} + loss = _mod._f_theta_loss( + f_theta, seq, sample_positions=0, + loss_type="attn_distill", + diag_buf=diag, + layers=layers, + apply_rotary_pos_emb=_identity_rotary_pos_emb, + device=torch.device("cpu"), + ) + assert torch.is_tensor(loss) and loss.dim() == 0 + assert "mse_O_mean" in diag + + def test_attn_distill_requires_layers_arg(self): + f_theta, layers, seq = self._build_synthetic() + with pytest.raises(ValueError, match="attn_distill requires"): + _mod._f_theta_loss( + f_theta, seq, sample_positions=0, + loss_type="attn_distill", + ) + + def test_legacy_loss_rejects_attn_only_capture(self): + """If loss_type=mse but seq has only attn_target (no verifier_k/v), + we should fail loud, not silently.""" + _, layers, seq = self._build_synthetic() + # seq has attn_target but no verifier_k/v + f_theta_legacy = None # not used past the dispatch check + from inference_engine.v04.f_theta import FThetaConfig, FThetaProjection + cfg = FThetaConfig( + drafter_num_layers=2, drafter_num_kv_heads=2, drafter_head_dim=4, + verifier_num_layers=3, verifier_num_kv_heads=2, verifier_head_dim=4, + rank=8, + ) + f_theta = FThetaProjection(cfg).float() + with pytest.raises(RuntimeError, match="legacy K/V capture"): + _mod._f_theta_loss( + f_theta, seq, sample_positions=64, loss_type="mse", + ) + + def test_sample_positions_subselects_output(self): + f_theta, layers, seq = self._build_synthetic(T=16) + # With sample=4, loss should still be a scalar but use only 4 + # output positions — verify the loss runs and is differentiable. 
+ loss_full = _mod._attention_distillation_loss( + f_theta, seq, layers, + apply_rotary_pos_emb=_identity_rotary_pos_emb, + device=torch.device("cpu"), + sample_positions=None, seed=42, + ) + loss_sub = _mod._attention_distillation_loss( + f_theta, seq, layers, + apply_rotary_pos_emb=_identity_rotary_pos_emb, + device=torch.device("cpu"), + sample_positions=4, seed=42, + ) + assert torch.is_tensor(loss_sub) and loss_sub.dim() == 0 + # Different sample sizes should generally give different scalars + # (it's an average over different sets of positions) + # Don't strictly assert they differ — small T might collide. + assert float(loss_full) > 0.0 and float(loss_sub) > 0.0 + + +# --------------------------------------------------------------------------- +# v3 dataclass surface +# --------------------------------------------------------------------------- + + +class TestAttentionDistillationHybridLoss: + """v3 hybrid loss — fixes the f_θ collapse degeneracy exposed by + the 2026-06-10 alpha-sweep diagnostic (raw K/V rel_mse 1331×; + k_norm hides scale errors from attn_distill alone).""" + + def _build_synthetic_with_raw_kv(self, T=16): + from inference_engine.v04.f_theta import FThetaConfig, FThetaProjection + torch.manual_seed(0) + n_d_layers, d_kv_heads, d_head = 2, 2, 4 + n_v_layers = 3 + n_heads, head_dim, hidden = 4, 4, 16 + n_kv_heads = 2 + + cfg = FThetaConfig( + drafter_num_layers=n_d_layers, + drafter_num_kv_heads=d_kv_heads, drafter_head_dim=d_head, + verifier_num_layers=n_v_layers, + verifier_num_kv_heads=n_kv_heads, verifier_head_dim=head_dim, + rank=8, + ) + f_theta = FThetaProjection(cfg).float() + + layers = [ + _StubLayer(n_heads, n_kv_heads, head_dim, hidden) + for _ in range(n_v_layers) + ] + + target = _mod.AttentionTargetData( + q_raw=[torch.randn(T, n_heads * head_dim, dtype=torch.bfloat16) + for _ in range(n_v_layers)], + o_tgt=[torch.randn(T, hidden, dtype=torch.bfloat16) + for _ in range(n_v_layers)], + cos=[torch.randn(1, T, head_dim, 
dtype=torch.bfloat16) + for _ in range(n_v_layers)], + sin=[torch.randn(1, T, head_dim, dtype=torch.bfloat16) + for _ in range(n_v_layers)], + attention_mask=None, + num_heads_per_layer=[n_heads] * n_v_layers, + head_dim_per_layer=[head_dim] * n_v_layers, + k_raw_tgt=[torch.randn(T, n_kv_heads * head_dim, dtype=torch.bfloat16) + for _ in range(n_v_layers)], + v_raw_tgt=[torch.randn(T, n_kv_heads * head_dim, dtype=torch.bfloat16) + for _ in range(n_v_layers)], + ) + seq = _mod.CapturedSequence( + seq_len=T, + drafter_k=torch.randn(n_d_layers, T, d_kv_heads * d_head), + drafter_v=torch.randn(n_d_layers, T, d_kv_heads * d_head), + attn_target=target, + ) + return f_theta, layers, seq + + def test_hybrid_runs_and_emits_full_diag(self): + f_theta, layers, seq = self._build_synthetic_with_raw_kv() + diag = {} + loss = _mod._attention_distillation_loss( + f_theta, seq, layers, + apply_rotary_pos_emb=_identity_rotary_pos_emb, + device=torch.device("cpu"), + hybrid=True, diag_buf=diag, + ) + assert torch.is_tensor(loss) and loss.dim() == 0 + for k in ("mse_O_mean", "k_dir_mean", "v_dir_mean", + "k_mag_mean", "v_mag_mean"): + assert k in diag, f"missing diag key: {k}" + + def test_hybrid_requires_raw_kv_tgt(self): + f_theta, layers, _ = self._build_synthetic_with_raw_kv() + # Build seq WITHOUT k_raw_tgt/v_raw_tgt — should fail loud + T = 16; n_v = 3; n_kv = 2; hd = 4 + target_no_raw = _mod.AttentionTargetData( + q_raw=[torch.randn(T, 4*hd, dtype=torch.bfloat16) for _ in range(n_v)], + o_tgt=[torch.randn(T, 16, dtype=torch.bfloat16) for _ in range(n_v)], + cos=[torch.randn(1, T, hd, dtype=torch.bfloat16) for _ in range(n_v)], + sin=[torch.randn(1, T, hd, dtype=torch.bfloat16) for _ in range(n_v)], + attention_mask=None, + num_heads_per_layer=[4]*n_v, head_dim_per_layer=[hd]*n_v, + ) + seq_no_raw = _mod.CapturedSequence( + seq_len=T, + drafter_k=torch.randn(2, T, 8), drafter_v=torch.randn(2, T, 8), + attn_target=target_no_raw, + ) + with pytest.raises(RuntimeError, 
match="k_raw_tgt"): + _mod._attention_distillation_loss( + f_theta, seq_no_raw, layers, + apply_rotary_pos_emb=_identity_rotary_pos_emb, + device=torch.device("cpu"), + hybrid=True, + ) + + def test_hybrid_dispatch_via_loss_type(self): + f_theta, layers, seq = self._build_synthetic_with_raw_kv() + diag = {} + loss = _mod._f_theta_loss( + f_theta, seq, sample_positions=0, + loss_type="attn_distill_hybrid", + diag_buf=diag, + layers=layers, + apply_rotary_pos_emb=_identity_rotary_pos_emb, + device=torch.device("cpu"), + ) + assert torch.is_tensor(loss) and loss.dim() == 0 + assert "k_dir_mean" in diag + + def test_hybrid_loss_strictly_higher_than_attn_distill_alone(self): + """Hybrid adds direction + magnitude terms; with random initial + f_θ, all four components are non-trivial → hybrid > attn_distill + (which only has the mse_O term). Verifies the additional terms + actually affect the loss, not silently zero.""" + f_theta, layers, seq = self._build_synthetic_with_raw_kv() + loss_attn_only = _mod._attention_distillation_loss( + f_theta, seq, layers, + apply_rotary_pos_emb=_identity_rotary_pos_emb, + device=torch.device("cpu"), + hybrid=False, + ) + loss_hybrid = _mod._attention_distillation_loss( + f_theta, seq, layers, + apply_rotary_pos_emb=_identity_rotary_pos_emb, + device=torch.device("cpu"), + hybrid=True, + ) + assert float(loss_hybrid.detach()) > float(loss_attn_only.detach()) + + def test_hybrid_grad_flows_to_f_theta(self): + f_theta, layers, seq = self._build_synthetic_with_raw_kv() + loss = _mod._attention_distillation_loss( + f_theta, seq, layers, + apply_rotary_pos_emb=_identity_rotary_pos_emb, + device=torch.device("cpu"), + hybrid=True, + ) + loss.backward() + any_grad = any( + p.grad is not None and p.grad.norm().item() > 0 + for p in f_theta.parameters() + ) + assert any_grad + + +class TestAttentionTargetDataDataclass: + + def test_fields_present(self): + td = _mod.AttentionTargetData( + q_raw=[], o_tgt=[], cos=[], sin=[], + attention_mask=None, + 
num_heads_per_layer=[], head_dim_per_layer=[], + ) + assert td.q_raw == [] + assert td.attention_mask is None + + def test_captured_sequence_optional_kv_and_attn(self): + seq = _mod.CapturedSequence( + seq_len=10, + drafter_k=torch.zeros(2, 10, 8), + drafter_v=torch.zeros(2, 10, 8), + ) + assert seq.verifier_k is None + assert seq.verifier_v is None + assert seq.attn_target is None + + def test_captured_sequence_attn_target_path(self): + td = _mod.AttentionTargetData( + q_raw=[], o_tgt=[], cos=[], sin=[], + attention_mask=None, + num_heads_per_layer=[], head_dim_per_layer=[], + ) + seq = _mod.CapturedSequence( + seq_len=10, + drafter_k=torch.zeros(2, 10, 8), + drafter_v=torch.zeros(2, 10, 8), + attn_target=td, + ) + assert seq.attn_target is td + + def test_attention_target_data_optional_raw_kv_for_hybrid(self): + """k_raw_tgt and v_raw_tgt fields default to None; populated + when capture_raw_kv=True is passed during data collection.""" + td_legacy = _mod.AttentionTargetData( + q_raw=[], o_tgt=[], cos=[], sin=[], + attention_mask=None, + num_heads_per_layer=[], head_dim_per_layer=[], + ) + assert td_legacy.k_raw_tgt is None + assert td_legacy.v_raw_tgt is None + + td_hybrid = _mod.AttentionTargetData( + q_raw=[torch.zeros(8, 16)], o_tgt=[torch.zeros(8, 32)], + cos=[torch.zeros(1, 8, 4)], sin=[torch.zeros(1, 8, 4)], + attention_mask=None, + num_heads_per_layer=[4], head_dim_per_layer=[4], + k_raw_tgt=[torch.randn(8, 8)], + v_raw_tgt=[torch.randn(8, 8)], + ) + assert td_hybrid.k_raw_tgt is not None + assert td_hybrid.v_raw_tgt is not None