From 389a3948d305ede99b7548f44cb6912c0322bbc7 Mon Sep 17 00:00:00 2001 From: fluffy314 Date: Thu, 11 Jun 2026 18:42:29 +0800 Subject: [PATCH] Experiment with proposer KV full-attn restoration on Mac Adds a disabled S5 experiment that builds full-attention restored K/V from proposer K/V via f_theta instead of an extra verifier capture forward. The Mac ctx70 evidence shows the fixed build cost drops to about 2s, but recall regresses to 0/1 and decode slows, so this is evaluation evidence for retraining/next-step design rather than a merge-ready optimization. Co-authored-by: Cursor --- ..._fullattn_niah_ctx70_mac_no_kl_oracle.json | 439 ++++++++++++++++++ ...ftheta_fullattn_niah_ctx70_mac_oracle.json | 439 ++++++++++++++++++ .../research/k3_integrated_niah_eval_mac.py | 44 +- 3 files changed, 921 insertions(+), 1 deletion(-) create mode 100644 results/research/k3_s5_ftheta_fullattn_niah_ctx70_mac_no_kl_oracle.json create mode 100644 results/research/k3_s5_ftheta_fullattn_niah_ctx70_mac_oracle.json diff --git a/results/research/k3_s5_ftheta_fullattn_niah_ctx70_mac_no_kl_oracle.json b/results/research/k3_s5_ftheta_fullattn_niah_ctx70_mac_no_kl_oracle.json new file mode 100644 index 00000000..13cf3271 --- /dev/null +++ b/results/research/k3_s5_ftheta_fullattn_niah_ctx70_mac_no_kl_oracle.json @@ -0,0 +1,439 @@ +{ + "schema_version": 1, + "kind": "k3_integrated_niah_acceptance_mac", + "config": { + "verifier_path": "models/gemma-4-26B-A4B-it-mlx-4bit", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "n_samples": 1, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "max_new_tokens": 16, + "seed": 42, + "chat_template": false, + "eval_mode": "free_gen", + "teacher_forced": false, + "s5_exact_full_attn": true, + "s5_f_theta_restored_full_attn": true, + "identity_restore": false, + "compress_full_attn": false, + "kl_lattice": null, + "kl_q_range": null, + "kl_bits_per_token_per_head": null, + "full_attention_layers": [ + 5, + 11, + 17, + 23, + 29 + ], + "prompt_token_lens": [ + 1625 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model_mac", + "samples_total": 1, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 49.83435483300127, + "median_latency_s": 49.83435483300127, + "per_sample_decoded": [ + " BETA-01-10-01-10-01" + ], + "per_sample_correct": [ + false + ], + "per_sample_decode_tokens": [ + 16 + ], + "per_sample_throughput_tokens_per_sec": [ + 0.32106365284786414 + ], + "mean_throughput_tokens_per_sec": 0.32106365284786414, + "median_throughput_tokens_per_sec": 0.32106365284786414, + "min_throughput_tokens_per_sec": 0.32106365284786414, + "max_throughput_tokens_per_sec": 0.32106365284786414 + }, + "oracle": { + "name": "oracle_mac", + "samples_total": 1, + "samples_correct": 1, + "recall": 1.0, + "mean_latency_s": 14.481073249829933, + "median_latency_s": 14.481073249829933, + "per_sample_decoded": [ + " BETA-1409. Question: what is the secret code" + ], + "per_sample_correct": [ + true + ], + "per_sample_decode_tokens": [ + 16 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.1048904818009884 + ], + "mean_throughput_tokens_per_sec": 1.1048904818009884, + "median_throughput_tokens_per_sec": 1.1048904818009884, + "min_throughput_tokens_per_sec": 1.1048904818009884, + "max_throughput_tokens_per_sec": 1.1048904818009884 + } + }, + "gate": { + "recall_cross_model": 0.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 100.0, + "recall_delta_within_5pp": false + }, + "memory": { + "s5": { + "seq_len": 1625, + "kv_dtype_bytes": 2, + "sink_window": 68, + "exact_layer_indices": [ + 5, + 11, + 17, + 23, + 29 + ], + "compress_full_bits_per_token_per_head": null, + "sliding_resident_bytes": 13926400, + "full_resident_bytes": 33280000, + "total_resident_bytes": 47206400, + "total_resident_mb": 47.21, + "per_token_growth_bytes": 20480, + "per_token_growth_kb": 20.0, + "per_layer": [ + { + "layer": 0, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 1, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 2, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 3, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 4, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 5, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1625, + "bytes_per_token": 4096, + "resident_bytes": 6656000 + }, + { + "layer": 6, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 7, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 8, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 9, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 10, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 11, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1625, + "bytes_per_token": 4096, + "resident_bytes": 6656000 + }, + { + "layer": 12, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 13, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 14, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 15, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 16, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 17, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1625, + "bytes_per_token": 4096, + "resident_bytes": 6656000 + }, + { + "layer": 18, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 19, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 20, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 21, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 22, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 23, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1625, + "bytes_per_token": 4096, + "resident_bytes": 6656000 + }, + { + "layer": 24, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 25, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 26, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 27, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 28, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 29, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1625, + "bytes_per_token": 4096, + "resident_bytes": 6656000 + } + ] + }, + "naive_full_kv": { + "total_resident_mb": 366.08, + "per_token_growth_kb": 220.0 + }, + "savings_vs_naive_pct": 87.1 + }, + "throughput": { + "k3_cross_model": { + "tokens": 16, + "wall_seconds": 49.834, + "tokens_per_second": 0.3211, + "mean_latency_per_sample_s": 49.834, + "eval_mode": "free_gen", + "restored_forwards_per_sample": 1, + "incremental_decode": true, + "stage_timings": [ + { + "sample": 0, + "build_restoration_s": 2.154, + "prefill_attach_s": 21.971, + "decode_s": 27.862 + } + ] + } + } +} \ No newline at end of file diff --git a/results/research/k3_s5_ftheta_fullattn_niah_ctx70_mac_oracle.json b/results/research/k3_s5_ftheta_fullattn_niah_ctx70_mac_oracle.json new file mode 100644 index 00000000..d742df91 --- /dev/null +++ b/results/research/k3_s5_ftheta_fullattn_niah_ctx70_mac_oracle.json @@ -0,0 +1,439 @@ +{ + "schema_version": 1, + "kind": "k3_integrated_niah_acceptance_mac", + "config": { + "verifier_path": "models/gemma-4-26B-A4B-it-mlx-4bit", + "drafter_id": "models/dflash-kakeya-baseline", + "f_theta_dir": "results/research/f_theta_v5_s5_sliding", + "n_samples": 1, + "sink_size": 4, + "window_size": 64, + "haystack_min_lines": 60, + "haystack_max_lines": 81, + "max_new_tokens": 16, + "seed": 42, + "chat_template": false, + "eval_mode": "free_gen", + "teacher_forced": false, + "s5_exact_full_attn": true, + "s5_f_theta_restored_full_attn": true, + "identity_restore": false, + "compress_full_attn": true, + "kl_lattice": "D4", + "kl_q_range": 38, + "kl_bits_per_token_per_head": 3232.0, + "full_attention_layers": [ + 5, + 11, + 17, + 23, + 29 + ], + "prompt_token_lens": [ + 1625 + ] + }, + "results": { + "k3_cross_model": { + "name": "k3_cross_model_mac", + "samples_total": 1, + "samples_correct": 0, + "recall": 0.0, + "mean_latency_s": 66.55492287501693, + "median_latency_s": 66.55492287501693, + "per_sample_decoded": [ + " BETA-01-10-10-10-10" + ], + "per_sample_correct": [ + false + ], + "per_sample_decode_tokens": [ + 16 + ], + "per_sample_throughput_tokens_per_sec": [ + 0.24040295306248494 + ], + "mean_throughput_tokens_per_sec": 0.24040295306248494, + "median_throughput_tokens_per_sec": 0.24040295306248494, + "min_throughput_tokens_per_sec": 0.24040295306248494, + "max_throughput_tokens_per_sec": 0.24040295306248494 + }, + "oracle": { + "name": "oracle_mac", + "samples_total": 1, + "samples_correct": 1, + "recall": 1.0, + "mean_latency_s": 12.182880334090441, + "median_latency_s": 12.182880334090441, + "per_sample_decoded": [ + " BETA-1409. Question: what is the secret code" + ], + "per_sample_correct": [ + true + ], + "per_sample_decode_tokens": [ + 16 + ], + "per_sample_throughput_tokens_per_sec": [ + 1.3133183254889567 + ], + "mean_throughput_tokens_per_sec": 1.3133183254889567, + "median_throughput_tokens_per_sec": 1.3133183254889567, + "min_throughput_tokens_per_sec": 1.3133183254889567, + "max_throughput_tokens_per_sec": 1.3133183254889567 + } + }, + "gate": { + "recall_cross_model": 0.0, + "recall_oracle": 1.0, + "recall_delta_vs_oracle_pp": 100.0, + "recall_delta_within_5pp": false + }, + "memory": { + "s5": { + "seq_len": 1625, + "kv_dtype_bytes": 2, + "sink_window": 68, + "exact_layer_indices": [ + 5, + 11, + 17, + 23, + 29 + ], + "compress_full_bits_per_token_per_head": 3232.0, + "sliding_resident_bytes": 13926400, + "full_resident_bytes": 13130000, + "total_resident_bytes": 27056400, + "total_resident_mb": 27.06, + "per_token_growth_bytes": 8080, + "per_token_growth_kb": 7.89, + "per_layer": [ + { + "layer": 0, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 1, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 2, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 3, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 4, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 5, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1625, + "bytes_per_token": 1616, + "resident_bytes": 2626000 + }, + { + "layer": 6, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 7, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 8, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 9, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 10, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 11, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1625, + "bytes_per_token": 1616, + "resident_bytes": 2626000 + }, + { + "layer": 12, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 13, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 14, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 15, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 16, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 17, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1625, + "bytes_per_token": 1616, + "resident_bytes": 2626000 + }, + { + "layer": 18, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 19, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 20, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 21, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 22, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 23, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1625, + "bytes_per_token": 1616, + "resident_bytes": 2626000 + }, + { + "layer": 24, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 25, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 26, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 27, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 28, + "layer_type": "sliding_attention", + "n_kv_heads": 8, + "head_dim": 256, + "exact": false, + "resident_positions": 68, + "bytes_per_token": 8192, + "resident_bytes": 557056 + }, + { + "layer": 29, + "layer_type": "full_attention", + "n_kv_heads": 2, + "head_dim": 512, + "exact": true, + "resident_positions": 1625, + "bytes_per_token": 1616, + "resident_bytes": 2626000 + } + ] + }, + "naive_full_kv": { + "total_resident_mb": 366.08, + "per_token_growth_kb": 220.0 + }, + "savings_vs_naive_pct": 92.6 + }, + "throughput": { + "k3_cross_model": { + "tokens": 16, + "wall_seconds": 66.555, + "tokens_per_second": 0.2404, + "mean_latency_per_sample_s": 66.555, + "eval_mode": "free_gen", + "restored_forwards_per_sample": 1, + "incremental_decode": true, + "stage_timings": [ + { + "sample": 0, + "build_restoration_s": 1.962, + "prefill_attach_s": 24.096, + "decode_s": 42.455 + } + ] + } + } +} \ No newline at end of file diff --git a/scripts/research/k3_integrated_niah_eval_mac.py b/scripts/research/k3_integrated_niah_eval_mac.py index 4f24d857..07f71d22 100644 --- a/scripts/research/k3_integrated_niah_eval_mac.py +++ b/scripts/research/k3_integrated_niah_eval_mac.py @@ -86,6 +86,11 @@ def parse_args() -> argparse.Namespace: "layers' K/V (lossy round-trip) to shrink the O(T) " "linear term. Reports the compression ratio + recall " "under compression.") + ap.add_argument("--s5-f-theta-restored-full-attn", action="store_true", + help="Optimization experiment: in S5 mode, build full-attn " + "restored K/V from proposer K/V via f_theta instead " + "of running an extra verifier prompt forward to " + "capture exact full-attn K/V.") ap.add_argument("--kl-lattice", default="D4", choices=["D4", "E8"]) ap.add_argument("--kl-q-range", type=int, default=38) ap.add_argument("--skip-oracle", action="store_true") @@ -230,6 +235,28 @@ def build_restoration(prompt_ids: List[int]): evicted = compute_evicted_positions( len(prompt_ids), args.sink_size, args.window_size, ) + if ( + args.s5_exact_full_attn + and args.s5_f_theta_restored_full_attn + and not args.identity_restore + and evicted + ): + d_k, d_v = capture_drafter_kv(prompt_ids) + with torch.no_grad(): + vk, vv = f_theta.forward_kv_pack(d_k, d_v) + rk: Dict[int, Any] = {} + rv: Dict[int, Any] = {} + start, end = int(evicted[0]), int(evicted[-1]) + 1 + for li in full_attn_idx: + k_mx = torch_to_mx(vk[li][:, start:end]).astype(mx.bfloat16) + v_mx = torch_to_mx(vv[li][:, start:end]).astype(mx.bfloat16) + if li in compressors: + k_mx, v_mx = _compress_roundtrip(li, k_mx, v_mx) + rk[li], rv[li] = _pre_norm_slice_to_cache_bank( + li, k_mx, v_mx, offset=start, + ) + return rk, rv, len(prompt_ids) + if args.s5_exact_full_attn and not args.identity_restore and evicted: own = capture_own_kv_cache_slice_detached( mlx_model, @@ -317,6 +344,17 @@ def _post_rope_restored_bank( v = attn.v_norm(v).transpose(0, 2, 1, 3) return k, v + def _pre_norm_slice_to_cache_bank( + layer_idx: int, k_mx: Any, v_mx: Any, *, offset: int, + ) -> Tuple[Any, Any]: + """Convert a pre-norm restored slice to MLX cache-layout K/V.""" + layer = text_model.layers[layer_idx] + attn = layer.self_attn + k = attn.k_norm(k_mx).transpose(0, 2, 1, 3) + k = attn.rope(k, offset=int(offset)) + v = attn.v_norm(v_mx).transpose(0, 2, 1, 3) + return mx.stop_gradient(k), mx.stop_gradient(v) + def attach_restored_banks(cache, rk, rv, prompt_len: int) -> None: evicted = compute_evicted_positions( prompt_len, args.sink_size, args.window_size, @@ -512,7 +550,9 @@ def oracle_logits_all(prompt_ids, full_ids): out = mlx_model(mx.array([full_ids])); mx.eval(out); return out[0] label = "identity" if args.identity_restore else ( - "s5" if args.s5_exact_full_attn else "f_theta_all") + "s5_f_theta_full_attn" if ( + args.s5_exact_full_attn and args.s5_f_theta_restored_full_attn + ) else "s5" if args.s5_exact_full_attn else "f_theta_all") eval_mode = "teacher_forced" if args.teacher_forced else "free_gen" print(f"[mac] running restored cross-model verifier ({label}, {eval_mode})", file=sys.stderr, flush=True) @@ -593,6 +633,8 @@ def _tps(lats, toks): "eval_mode": eval_mode, "teacher_forced": bool(args.teacher_forced), "s5_exact_full_attn": bool(args.s5_exact_full_attn), + "s5_f_theta_restored_full_attn": bool( + args.s5_f_theta_restored_full_attn), "identity_restore": bool(args.identity_restore), "compress_full_attn": bool(args.compress_full_attn), "kl_lattice": args.kl_lattice if args.compress_full_attn else None,