diff --git a/.claude/AMX_GOTCHAS.md b/.claude/AMX_GOTCHAS.md index 22330c91..79cd0110 100644 --- a/.claude/AMX_GOTCHAS.md +++ b/.claude/AMX_GOTCHAS.md @@ -1,237 +1,249 @@ -# AMX Gotchas — Resolved on Stable Rust 1.94 +# AMX Gotchas — Troubleshooting Playbook (Stable Rust 1.94) -> Updated: 2026-04-03 -> CPU: Sapphire Rapids (AMX-TILE + AMX-INT8 + AMX-BF16 confirmed) -> Kernel: 6.18.5 (XCR0 bits 17+18 enabled) +> Updated: 2026-06-14 (corrected — the 2026-04-03 version shipped three of the +> bugs below: syscall 157, `TDPBUSD = …73…C1`, and the swapped TILECFG layout). +> Canonical reference: `.claude/knowledge/amx-enablement-and-kernel.md`. +> Owning agent: `.claude/agents/amx-savant.md`. +> Verified on: Emerald Rapids (CPUID model 0xCF), kernel 6.18.5. The fixes are +> ISA-level and apply equally to Sapphire Rapids (0x8F) and Granite Rapids. ---- +This file is the *how-to-debug* companion. Each gotcha lists its **fault +signature** so you can map a crash to a cause without a debugger. The +instruction-bisector that produced these is `examples/amx_probe.rs` — run it +FIRST (it prints a flushed line before each tile op, so the last line names the +faulting instruction, then it checks correctness across shapes). -## Status +--- -AMX works on **stable Rust 1.94** via `asm!()`. No nightly needed. +## Status (verified by actual execution, not by a skipped test) ``` -LDTILECFG: ✓ (load tile configuration) -TILEZERO: ✓ (zero a tile register) -TILERELEASE: ✓ (release tiles) -TDPBUSD: ✓ (u8×i8 tile dot product, 256 MACs/instruction) +LDTILECFG ✓ TILEZERO ✓ TILELOADD ✓ TILESTORED ✓ TILERELEASE ✓ +TDPBUSD ✓ (u8×i8 → i32, bit-exact vs scalar) +TDPBF16PS ✓ (bf16×bf16 → f32, within BF16 tolerance) +amx_available() = true on Emerald Rapids (cached LazyLock) +int8 2048³ = 169.7 GMAC/s, 600× scalar, single-thread ``` +> ⚠ The previous "✓" marks were never executed: every AMX test early-returns +> `if !amx_available() { return; }`, and detection always returned false +> (Gotcha 4). 
Treat any "tile asm tested" claim as UNVERIFIED until you confirm +> `amx_available()` was `true` when the test ran. See Gotcha 9. + --- -## Gotcha 1: Rust intrinsics are NIGHTLY ONLY +## Fault-signature → cause (the fast index) -```rust -// This DOES NOT compile on stable: -use std::arch::x86_64::_tile_loadconfig; // error: unstable feature x86_amx_intrinsics -``` +| You see… | Almost certainly… | Go to | +|---|---|---| +| `amx_available()==false` on a Xeon you *know* has AMX | arch_prctl syscall number | Gotcha 4 | +| SIGSEGV at the very first tile op (`LDTILECFG`) | TILECFG rows/colsb swapped, or not 64B-aligned | Gotcha 6, 2 | +| SIGSEGV at `TILELOADD`/`TILESTORED` | SIB base/index swapped (stride deref'd as ptr) | Gotcha 10 | +| SIGILL at `TDPBUSD`/`TDPBF16PS` | ModRM aliases two tile operands (same-tile #UD) | Gotcha 11 | +| runs fine, `correct=false` | operand index/sign convention mirrored | Gotcha 12 | +| compile error `unstable x86_amx_intrinsics` | used nightly intrinsics | Gotcha 1, 8 | +| compile error `rbx is used internally by LLVM` | inline-asm CPUID | Gotcha 3 | + +--- + +## Gotcha 1: Rust `_tile_*` intrinsics are NIGHTLY ONLY -**Fix**: Use `asm!()` (stable since Rust 1.59): ```rust -asm!("ldtilecfg [{}]", in(reg) config.data.as_ptr(), options(nostack)); +use std::arch::x86_64::_tile_loadconfig; // error: unstable feature x86_amx_intrinsics ``` - -Tracking issue: https://github.com/rust-lang/rust/issues/126622 +**Fix**: inline `asm!` (stable since 1.59). LDTILECFG works as a mnemonic; the +tile ops need raw `.byte` (Gotcha 5). Tracking: rust-lang/rust#126622. --- ## Gotcha 2: Tile config MUST be 64-byte aligned ```rust -// This SEGFAULTS: -let config = [0u8; 64]; // stack-allocated, no alignment guarantee - -// This WORKS: #[repr(C, align(64))] struct TileConfig { data: [u8; 64] } -let config = TileConfig { data: [0u8; 64] }; ``` +LDTILECFG reads 64 bytes; an unaligned pointer raises `#GP` → SIGSEGV. -LDTILECFG reads 64 bytes from the pointer. 
If not 64-byte aligned, -the CPU raises #GP (general protection fault) → SIGSEGV. +--- + +## Gotcha 3: `rbx` is LLVM-reserved — don't inline-asm CPUID + +Use `core::arch::x86_64::__cpuid_count(7, 0)` (stable, handles rbx). Inline +`asm!("cpuid", out("ebx") …)` fails to compile. --- -## Gotcha 3: rbx is LLVM-reserved +## Gotcha 4: enablement needs `arch_prctl` — syscall **158**, not `prctl` 157 ⚑ THE BIG ONE -```rust -// This DOES NOT compile: -asm!("cpuid", out("ebx") ebx, ...); // error: rbx is used internally by LLVM +AMX `XTILEDATA` is a *dynamically-enabled* XSTATE feature (Linux 5.16+). A +process must request permission before any tile op or the first one faults +(XFD `#NM`): -// This WORKS: -let result = core::arch::x86_64::__cpuid_count(7, 0); // stable, handles rbx internally +``` +arch_prctl(ARCH_REQ_XCOMP_PERM /*0x1023*/, XFEATURE_XTILEDATA /*18*/) ``` -For CPUID leaf 7 (AMX detection): use `__cpuid_count()`, not inline asm. +`ARCH_REQ_XCOMP_PERM` is an **arch_prctl** op → **syscall 158**. Issuing it on +**prctl (157)** returns `-EINVAL`, so detection's gate 4 always failed and +`amx_available()` returned `false` on EVERY AMX host. **This file's previous +version literally documented `SYS_prctl = 157`** — that is where the bug came +from. Always 158. ---- +Fault signature: `amx_available()==false` while `cpu_model().has_amx()==true`. -## Gotcha 4: OS must enable AMX via XSETBV + process must request permission +--- -AMX tiles are large (8 KB of state). Two levels of OS enablement required: +## Gotcha 5: tile ops need raw byte encoding (LDTILECFG is the exception) -1. **Kernel enables tile state in XCR0** (bits 17+18). Linux 5.19+ does this. -2. **Process requests XCOMP_PERM** via `prctl(ARCH_REQ_XCOMP_PERM, 18)`. - Without this, LDTILECFG will SIGILL even if XCR0 bits are set. +See the authoritative table in the knowledge doc. The correct sequences: -**Detection (stable)**: ```rust -// Step 1: CPUID — does CPU support AMX? 
-let cpuid = core::arch::x86_64::__cpuid_count(7, 0); -let amx_tile = (cpuid.edx >> 24) & 1; -let amx_int8 = (cpuid.edx >> 25) & 1; - -// Step 2: OSXSAVE — does OS support XSAVE? -let cpuid_01 = core::arch::x86_64::__cpuid(1); -let osxsave = (cpuid_01.ecx >> 27) & 1; - -// Step 3: _xgetbv(0) — did OS ACTUALLY enable tile state? -// ⚠ Do NOT use __cpuid_count(0xD, 0) — that reports what CPU SUPPORTS, -// not what the OS ENABLED. _xgetbv(0) reads the actual XCR0 register. -let xcr0: u64 = unsafe { core::arch::x86_64::_xgetbv(0) }; -let tilecfg = (xcr0 >> 17) & 1; // bit 17 = XTILECFG -let tiledata = (xcr0 >> 18) & 1; // bit 18 = XTILEDATA - -// Step 4: prctl — request tile permission for this process -// SYS_prctl = 157, ARCH_REQ_XCOMP_PERM = 0x1023, XFEATURE_XTILEDATA = 18 -// Returns 0 on success, -errno on failure. Idempotent. +// TILEZERO tmm0 / tmm1 / tmm2 / tmm3 +asm!(".byte 0xc4,0xe2,0x7b,0x49,0xc0", options(nostack,nomem)); // tmm0 +asm!(".byte 0xc4,0xe2,0x7b,0x49,0xc8", options(nostack,nomem)); // tmm1 +asm!(".byte 0xc4,0xe2,0x7b,0x49,0xd0", options(nostack,nomem)); // tmm2 +asm!(".byte 0xc4,0xe2,0x7b,0x49,0xd8", options(nostack,nomem)); // tmm3 +// TILERELEASE +asm!(".byte 0xc4,0xe2,0x78,0x49,0xc0", options(nostack,nomem)); +// TILELOADD tmmN,[rcx+rax] (SIB 0x01 = base=rcx,index=rax) +asm!(".byte 0xc4,0xe2,0x7b,0x4b,0x04,0x01", in("rcx") ptr, in("rax") stride, options(nostack)); // tmm0 +// TILESTORED [rcx+rax],tmm0 +asm!(".byte 0xc4,0xe2,0x7a,0x4b,0x04,0x01", in("rcx") ptr, in("rax") stride, options(nostack)); +// TDPBUSD tmm0,tmm1,tmm2 (u8 in rm/tmm2, i8 in vvvv/tmm1 — see Gotcha 12) +asm!(".byte 0xc4,0xe2,0x71,0x5e,0xc2", options(nostack,nomem)); +// TDPBF16PS tmm0,tmm1,tmm2 +asm!(".byte 0xc4,0xe2,0x72,0x5c,0xc2", options(nostack,nomem)); ``` -**Previous bug**: `__cpuid_count(0xD, 0)` reports XSAVE state component bitmap -(what the CPU *supports*), NOT the actual XCR0 value (what the OS *enabled*). -On hypervisors that advertise AMX in CPUID but don't enable tile state, -the old check returned `true` → SIGILL on LDTILECFG. 
+> ✗ The previous version listed `TDPBUSD … 0x73 … 0xc1` — `0x73` is TDPBSSD +> (wrong sign variant) and `0xc1` aliases tmm1 with itself (Gotcha 11). --- -## Gotcha 5: TILEZERO/TILERELEASE need manual byte encoding - -The Rust assembler on some toolchains doesn't know AMX mnemonics. -Use raw instruction bytes: +## Gotcha 6: TILECFG field layout — colsb and rows are NOT where you'd guess ⚑ -```rust -// TILEZERO tmm0 -asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xc0", options(nostack, nomem)); +Correct XTILECFG (Intel SDM): -// TILEZERO tmm1 -asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xc8", options(nostack, nomem)); +``` +byte 0 palette (=1) +byte 1 start_row (=0) +bytes 2-15 reserved (0) +bytes 16-47 colsb[t] : 16 × u16 → colsb[t] at offset 16 + 2*t (≤ 64) +bytes 48-63 rows[t] : 16 × u8 → rows[t] at offset 48 + t (≤ 16) +``` -// TILEZERO tmm2 -asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xd0", options(nostack, nomem)); +The previous version said "rows 16-23, colbytes 48-63" — **swapped**. With the +swap you get `colsb[0]=0x1010=4112` and `rows[0]=64`, both out of range, so +**LDTILECFG `#GP`-faults → SIGSEGV** the instant the AMX path runs. For the +16×16 int8/bf16 tile, every tile is 16 rows × 64 colbytes. -// TILEZERO tmm3 -asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xd8", options(nostack, nomem)); +Fault signature: SIGSEGV at the first `LDTILECFG`. -// TILERELEASE -asm!(".byte 0xc4, 0xe2, 0x78, 0x49, 0xc0", options(nostack, nomem)); +--- -// TDPBUSD tmm0, tmm1, tmm2 (C += A × B) -asm!(".byte 0xc4, 0xe2, 0x73, 0x5e, 0xc1", options(nostack, nomem)); -``` +## Gotcha 7: TILEZERO/LDTILECFG with palette=0 SEGFAULTs -Note: LDTILECFG works as a mnemonic: -```rust -asm!("ldtilecfg [{}]", in(reg) ptr, options(nostack)); -``` +Always `cfg.data[0] = 1`. Start from a minimal valid tile (1 row × 4 colbytes: +`data[16]=4; data[48]=1`) to confirm the config path before scaling to 16×64. 
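
The corrected layout from Gotcha 6, together with the palette rule from Gotcha 7, can be sketched as a small config builder. This is a minimal illustration under stated assumptions: the `TileConfig` struct mirrors the one in Gotcha 2, while `new`, `set_tile`, and `for_gemm_16x16` are hypothetical helper names and not necessarily the in-tree API in `src/hpc/amx_matmul.rs`.

```rust
/// 64-byte, 64-byte-aligned XTILECFG image (alignment rule: Gotcha 2).
#[repr(C, align(64))]
struct TileConfig {
    data: [u8; 64],
}

impl TileConfig {
    /// Empty palette-1 config: every tile is 0 rows x 0 colsb until set.
    fn new() -> Self {
        let mut data = [0u8; 64];
        data[0] = 1; // palette 1; palette 0 faults (Gotcha 7)
        TileConfig { data }
    }

    /// Hypothetical helper: describe tile `t` as `rows` rows of `colsb` bytes.
    fn set_tile(&mut self, t: usize, rows: u8, colsb: u16) {
        assert!(t < 8, "palette 1 exposes tmm0..tmm7");
        assert!(rows <= 16 && colsb <= 64, "palette 1 limit is 16 rows x 64 bytes");
        let c = 16 + 2 * t; // colsb[t]: little-endian u16 at offset 16 + 2*t
        self.data[c..c + 2].copy_from_slice(&colsb.to_le_bytes());
        self.data[48 + t] = rows; // rows[t]: u8 at offset 48 + t
    }

    /// Hypothetical helper: tmm0..tmm2 as full 16 x 64-byte tiles.
    fn for_gemm_16x16() -> Self {
        let mut cfg = Self::new();
        for t in 0..3 {
            cfg.set_tile(t, 16, 64);
        }
        cfg
    }
}
```

Keeping the offsets in one helper means the rows/colsb swap can only be made in one place, and the range assertions catch an out-of-range value in the builder instead of as a `#GP` inside `LDTILECFG`.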
--- -## Gotcha 6: Tile config field layout is not obvious +## Gotcha 8: `is_x86_feature_detected!("amx-tile")` is NIGHTLY ONLY -The 64-byte tile config structure: -``` -Byte 0: palette (must be 1) -Bytes 1-15: reserved (zero) -Bytes 16-23: rows per tile (tile 0 at byte 16, tile 1 at byte 17, ...) -Bytes 24-47: reserved (zero) -Bytes 48-63: colbytes per tile (tile 0 at [48..49] as u16 LE, tile 1 at [50..51], ...) -``` +Use `__cpuid_count(7,0).edx` bits 24 (TILE) + 25 (INT8), then XGETBV(0) bits +17/18, then the arch_prctl (Gotcha 4). All stable. See `simd_amx::detect_amx`. -For TDPBUSD (u8×i8 → i32): -- Tile 0 (C result): rows=16, colbytes=64 (16 × i32 = 64 bytes per row) -- Tile 1 (A input): rows=16, colbytes=64 (16 × 64 u8) -- Tile 2 (B input): rows=16, colbytes=64 (transposed for column access) +--- + +## Gotcha 9: "tests pass" can mean "tests skipped" -**IMPORTANT**: colbytes is a u16 at byte offset 48+2*tile_id (little-endian). -For values ≤ 64, only the low byte matters. +Every AMX test guards with `if !amx_available() { return; }`. While detection +was broken (Gotcha 4), 100% of them early-returned green without running a +single tile instruction. **A skipped test is not a passing test.** Validate AMX +with `examples/amx_probe` (unconditional) on real AMX silicon, and require a +`correct=`/parity assertion, not just "didn't crash." --- -## Gotcha 7: TILEZERO with wrong config = SEGFAULT +## Gotcha 10: TILELOADD/TILESTORED SIB byte — base vs index -If you configure tile 0 as 16 rows × 64 colbytes but then TILEZERO tmm0, -it works. But if the config doesn't match what the hardware expects (e.g., -palette=0 or all zeros), TILEZERO will SEGFAULT. +`TILELOADD tmm,[rcx+rax]` with regs bound `in("rcx") ptr, in("rax") stride` +needs SIB `0x01` = (scale=1, index=rax, base=rcx). The previous code used SIB +`0x08` = (index=rcx, base=rax), i.e. base/index swapped, so the tile engine +used the **stride value (~64) as the start address** → SIGSEGV. 
For TILELOADD +the *base* register is the data pointer and the *index* register is the row +stride in bytes. -**Fix**: Always start with the minimal working config: -```rust -cfg.data[0] = 1; // palette 1 (MUST be 1, not 0) -cfg.data[16] = 1; // at least 1 row -cfg.data[48] = 4; // at least 4 colbytes (1 × i32) -``` +Fault signature: SIGSEGV at the first `TILELOADD`. + +--- -Then expand to full 16×64 after verifying the minimal config works. +## Gotcha 11: the three tile operands MUST be distinct registers + +`TDPBUSD`/`TDPBF16PS` raise `#UD` (→ SIGILL) if any two of (dst, src1, src2) +name the same tile. ModRM `0xC1` = rm=tmm1, and `VEX.vvvv` was also tmm1 → +src1==src2 → same-tile `#UD`. Use ModRM `0xC2` (dst=tmm0, vvvv=tmm1, rm=tmm2). + +Fault signature: SIGILL at the first `TDPBUSD`/`TDPBF16PS`, AFTER LDTILECFG and +the loads succeed. --- -## Gotcha 8: is_x86_feature_detected!("amx-tile") is NIGHTLY ONLY +## Gotcha 12: the operand index/sign convention is mirrored from the SDM ⚑ -```rust -// DOES NOT compile on stable: -is_x86_feature_detected!("amx-tile") // error: unstable x86_amx_intrinsics - -// WORKS on stable: -fn amx_available() -> bool { - let cpuid = core::arch::x86_64::__cpuid_count(7, 0); - let amx_tile = (cpuid.edx >> 24) & 1; - let amx_int8 = (cpuid.edx >> 25) & 1; - amx_tile == 1 && amx_int8 == 1 -} -``` +Measured on EMR (selector probe + 4-opcode sign sweep — see the knowledge doc): + +- `dst[m][n] = Σ_k tmm2(ModRM.rm)[m][k] · tmm1(VEX.vvvv)[k][n]` — plain **M×K** + goes in **tmm2/rm**, VNNI **K×N** goes in **tmm1/vvvv** (mirror of the naive + SDM operand order). +- For `TDPBUSD` (0x71): **rm = unsigned, vvvv = signed**. + +So the kernel loads `A(u8)→tmm2`, `B_vnni(i8)→tmm1`. Get this wrong and it +runs cleanly but every value is wrong (often a suspiciously *clean* wrong, like +`total/16` for constant inputs — that uniformity is the tell). Isolate it with +the selector probe (`A[0][s]=1` → `C[0][:]` should equal `B[s][:]`). 
-Use `__cpuid_count` (stable) for detection, not `is_x86_feature_detected!`. +Fault signature: no crash, `correct=false`. --- -## Hardware Tiers (this session) +## Gotcha 13: cache detection in a `LazyLock` (don't re-syscall per call) -``` -Tier Feature MACs/instr Detection (stable) CPU -──── ─────── ────────── ────────────────── ─── -3 AMX 256 __cpuid_count(7,0).edx bit 24 Sapphire Rapids+ -2 avx512vnni 64 is_x86_feature_detected! Cascade Lake+, Zen 4+ -1 avxvnniint8 32 is_x86_feature_detected! Arrow Lake (NUC 14) -0 scalar 1 always any +`amx_available()` runs CPUID + XGETBV + arch_prctl. Calling it per matmul is +wasteful (and the arch_prctl, though idempotent, is a syscall). Cache it: + +```rust +static AMX_AVAILABLE: std::sync::LazyLock<bool> = std::sync::LazyLock::new(detect_amx); +pub fn amx_available() -> bool { *AMX_AVAILABLE } ``` -Also detectable but not yet kernelized: -- `avxvnniint16`: i16×i16 dot product (VPDPWSSD) -- `amx-bf16`: TDPBF16PS (BF16 tile matmul, for calibration) +All four gates are non-blocking (no I/O, no lock, no spin) so the init can't +stall. The arch_prctl grant is process-wide + inherited by all threads, so +once is correct even under rayon. `cpu_model()` is cached the same way. --- -## Files +## Hardware tiers ``` -ndarray/src/simd_amx.rs — AMX detection + VNNI/VNNI2 kernels + quantize -ndarray/src/hpc/amx_matmul.rs — AMX tile ops via inline asm (TDPBUSD) -ndarray/crates/burn/src/ops/matmul.rs — 4-tier dispatch in distance table builder +Tier Feature MACs/instr Detect (stable) CPU +3 AMX-TILE 16384 __cpuid_count(7,0).edx bit24+25 SPR / EMR / GNR (NOT Sierra Forest) +2 avx512vnni 64 is_x86_feature_detected! Cascade Lake+, Zen 4+ +1 avxvnniint8 32 is_x86_feature_detected! Arrow / Meteor Lake +0 scalar 1 always any ``` +`cpu_model()` returns `SierraForest` for model 0xAF — E-core silicon with NO +AMX, so `has_amx()` is false there even though it's a recent Xeon. 
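
How a model number such as 0xCF or 0xAF falls out of CPUID can be sketched as a pure decoder over the CPUID.01H:EAX word. This is an illustration, not the in-tree `CpuModel`: the variant set and the `from_cpuid_eax` name are assumptions, and the Granite Rapids model numbers (0xAD, 0xAE) are taken from the Linux kernel's Intel family table, not from this document.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CpuModel {
    SapphireRapids,
    EmeraldRapids,
    GraniteRapids,
    SierraForest,
    Other(u8),
}

impl CpuModel {
    /// Classify from CPUID.01H:EAX (stepping / model / family word).
    fn from_cpuid_eax(eax: u32) -> Self {
        let family = (eax >> 8) & 0xF;
        let model_lo = (eax >> 4) & 0xF;
        let model_ext = (eax >> 16) & 0xF;
        // For family 6 the displayed model is ext_model:model.
        let model = ((model_ext << 4) | model_lo) as u8;
        if family != 6 {
            return CpuModel::Other(model);
        }
        match model {
            0x8F => CpuModel::SapphireRapids,
            0xCF => CpuModel::EmeraldRapids,
            0xAD | 0xAE => CpuModel::GraniteRapids, // assumption: kernel family table
            0xAF => CpuModel::SierraForest,
            m => CpuModel::Other(m),
        }
    }

    /// Whether this silicon is expected to carry AMX (E-core Sierra Forest does not).
    fn has_amx(self) -> bool {
        matches!(
            self,
            CpuModel::SapphireRapids | CpuModel::EmeraldRapids | CpuModel::GraniteRapids
        )
    }
}
```

On a real host the input would come from `core::arch::x86_64::__cpuid(1).eax`. The result only says what the silicon should support; whether tile ops are usable still depends on the XCR0 and arch_prctl gates.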
+ --- -## What AMX Enables +## Files ``` -Distance table build (4096² = 16M dot products): - AMX: ~20 min (all models combined) - avx512vnni: ~1:20h - avxvnniint8: ~2:40h (NUC 14) - scalar: ~24-48h - -ThinkingEngine MatVec (per cycle): - AMX: ~44 μs (L1 table fits in 4 tile registers) - avx512vnni: ~175 μs - avxvnniint8: ~350 μs - scalar: ~5 ms +src/simd_amx.rs — detection (CPUID+XGETBV+arch_prctl), CpuModel, LazyLock +src/hpc/amx_matmul.rs — tile primitives + TileConfig + public matmul_{i8_to_i32,bf16_to_f32,f32} +src/hpc/int8_tile_gemm.rs — fast int8 driver (LDTILECFG hoisted) + 16×16 kernel +src/hpc/bf16_tile_gemm.rs — bf16 sibling +examples/amx_probe.rs — instruction bisector + correctness validator (run FIRST) +examples/amx_gemm_bench.rs — throughput + independent correctness check ``` diff --git a/.claude/agents/amx-savant.md b/.claude/agents/amx-savant.md new file mode 100644 index 00000000..4cb1edd3 --- /dev/null +++ b/.claude/agents/amx-savant.md @@ -0,0 +1,117 @@ +--- +name: amx-savant +description: > + Intel AMX (Advanced Matrix Extensions) tile-GEMM specialist for x86_64 Xeon + (Sapphire Rapids, Emerald Rapids, Granite Rapids). Owns enablement + (arch_prctl XTILEDATA permission), the inline-asm tile primitives + (LDTILECFG / TILELOADD / TDPBUSD / TDPBF16PS via raw byte-encodings on + stable Rust 1.94), the empirically-verified operand convention, CPU-model + detection, and the fault-signature troubleshooting method. Use for ANY work + on src/simd_amx.rs, src/hpc/amx_matmul.rs, src/hpc/{int8,bf16}_tile_gemm.rs, + AMX detection, "amx_available() is false", a SIGSEGV/SIGILL in a tile path, + a tile GEMM that returns wrong values, or AMX throughput optimization. +tools: Read, Glob, Grep, Bash, Edit, Write +model: opus +--- + +You are the AMX_SAVANT for Project NDARRAY Expansion. + +## Mandatory reads (load these BEFORE doing anything) + +1. 
`.claude/knowledge/amx-enablement-and-kernel.md` — canonical reference: + the enablement sequence, validated byte-codes, the operand convention, the + detection API, the performance story. **This is your source of truth.** +2. `.claude/AMX_GOTCHAS.md` — per-caveat troubleshooting playbook with a + fault-signature → cause index. + +If those two disagree with the code, the code + a fresh `examples/amx_probe` +run win — then you update the docs in the same change. + +## Environment + +- Rust 1.94 **stable** only. AMX `_tile_*` intrinsics + `is_x86_feature_detected! + ("amx-tile")` are NIGHTLY (rust-lang/rust#126622) — you use inline `asm!` + with raw `.byte` encodings. `LDTILECFG` is the one mnemonic the assembler + accepts. +- This host: Emerald Rapids (CPUID model 0xCF), kernel 6.18.5, AMX enabled. +- The fixes are ISA-level — identical on Sapphire Rapids (0x8F) and Granite + Rapids. Do NOT branch kernel correctness on CPU generation. + +## The Modus Operandi + +### A. How AMX gets enabled (4 gates, cached once in a LazyLock) + +1. CPUID.07H.0H:EDX bit 24 (AMX-TILE) + 25 (AMX-INT8) — silicon supports it. +2. CPUID.01H:ECX bit 27 (OSXSAVE) — OS turned on XSAVE. +3. XGETBV(0) bits 17 (TILECFG) + 18 (TILEDATA) — OS enabled tile XSTATE. + Read the *live* XCR0, never CPUID leaf 0xD (which reports capability, not + what a hypervisor actually enabled). +4. `arch_prctl(ARCH_REQ_XCOMP_PERM=0x1023, XFEATURE_XTILEDATA=18)` — + **syscall 158** (arch_prctl), NOT 157 (prctl). This is the dynamically- + enabled-feature permission request (Linux 5.16+). The 157↔158 mix-up is + why AMX was dark on every capable host. The grant is process-wide and + inherited by all threads → request once. + +`ndarray::simd::{amx_available, cpu_model, amx_report, CpuModel}` expose this. +`cpu_model().has_amx() && !amx_available()` ⇒ enablement problem, not silicon. + +### B. 
The operand convention (the alien magic — memorize it) + +`dst[m][n] = Σ_k tmm2(ModRM.rm)[m][k] · tmm1(VEX.vvvv)[k][n]` +- plain **M×K** operand → **tmm2 (rm)**; VNNI **K×N** operand → **tmm1 (vvvv)** + (mirror of the naive SDM operand order). +- `TDPBUSD` (0x71): rm = **unsigned**, vvvv = **signed**. +- The three tile operands (dst/src1/src2) MUST be distinct registers, or `#UD`. + +Validated encodings live in the knowledge doc's byte-code table. The correct +`TDPBUSD tmm0,tmm1,tmm2` is `C4 E2 71 5E C2` (NOT `…73…C1`). + +### C. The mindset: measure, don't trust the mnemonic or the doc + +- The SDM operand order is mirrored here; the prior gotchas doc shipped three + bugs. **You verify on silicon, not from a manual.** The 4-opcode sign sweep + + selector probe in `examples/amx_probe.rs` is how every claim was nailed. +- "Tests pass" behind `if !amx_available() { return; }` means "tests skipped." + Require an unconditional probe + a `correct=`/parity assertion. +- Correct first, fast second — and keep the `correct=` check while optimizing. + +## Troubleshooting: fault signature → cause + +Run `RUSTFLAGS="-C target-cpu=native" cargo run --release --example amx_probe` +FIRST. It prints a flushed line before each tile op (last line = faulting +instruction) and then checks correctness across shapes. Map the signature: + +| Signature | Cause | Fix | +|---|---|---| +| `amx_available()==false` on AMX Xeon | arch_prctl on syscall 157 | use 158 | +| SIGSEGV at `LDTILECFG` | TILECFG rows/colsb swapped (or not 64B-aligned) | colsb u16 @16+2t, rows u8 @48+t | +| SIGSEGV at `TILELOADD`/`TILESTORED` | SIB base/index swapped | SIB `0x01` (base=rcx, index=rax) | +| SIGILL at `TDPBUSD`/`TDPBF16PS` | ModRM aliases two tiles | ModRM `0xC2` | +| runs, `correct=false` (often a *clean* wrong) | operand index/sign mirrored | load M×K→tmm2, VNNI→tmm1; 0x71 | + +Each fix exposes the next signature (SIGSEGV→SIGSEGV→SIGILL→wrong→correct). 
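
The last table row (runs, wrong values) is the only signature a crash cannot reveal, so it needs a reference to compare against. A minimal sketch of a scalar u8×i8→i32 reference plus the selector-probe inputs follows. The names `gemm_u8_i8_ref`, `selector_a`, and `marked_b` are hypothetical and not the in-tree helpers, and the tile call itself is left out because it needs AMX hardware.

```rust
/// Scalar reference: C[i][j] = sum over p of A[i][p] (u8) * B[p][j] (i8), i32 accumulate.
/// All matrices are row-major; B is the plain K x N matrix, not VNNI-packed.
fn gemm_u8_i8_ref(a: &[u8], b: &[i8], m: usize, n: usize, k: usize) -> Vec<i32> {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    let mut c = vec![0i32; m * n];
    for i in 0..m {
        for p in 0..k {
            let av = a[i * k + p] as i32;
            for j in 0..n {
                c[i * n + j] += av * b[p * n + j] as i32;
            }
        }
    }
    c
}

/// Selector input: all zero except A[0][s] = 1, so row 0 of C must equal row s of B.
fn selector_a(m: usize, k: usize, s: usize) -> Vec<u8> {
    let mut a = vec![0u8; m * k];
    a[s] = 1;
    a
}

/// Marked B: every element carries its own flat index, so a mirrored operand shows up.
fn marked_b(k: usize, n: usize) -> Vec<i8> {
    assert!(k * n <= 128, "flat index must fit in i8");
    (0..k * n).map(|idx| idx as i8).collect()
}
```

Feeding the same `selector_a` and `marked_b` inputs to the tile kernel and comparing against `gemm_u8_i8_ref` element by element gives the `correct=` assertion; a transposed or mirrored result points at the operand convention in section B.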
+ +## Performance levers (after correctness is locked) + +1. Hoist `LDTILECFG` (serializing) and the VNNI pack OUT of the tile loops — + once per GEMM, not once per 16×16 tile. (This was the 11.5× win: + 14.8 → 169.7 GMAC/s on EMR int8 2048³.) +2. `TILESTORED` straight into the strided C slot (row pitch n·4 bytes) — no + scratch + copy. +3. Next miles: 2×2 register blocking (4 C tiles amortize A/B loads); rayon over + row tiles. Always re-run `amx_probe` (correctness) + `amx_gemm_bench` + (throughput) after each. + +## Cargo hygiene + +Per `.claude/rules/agent-cargo-hygiene.md`: as an Opus agent you may run cargo +freely, but build in the SHARED `target/` — no per-agent worktree. Validate +with the two examples; the lib unit-test target is pre-broken (`src/tri.rs` +type-inference errors, unrelated to AMX), so the examples are the gate. + +## When you finish + +Update `.claude/knowledge/amx-enablement-and-kernel.md` and +`.claude/AMX_GOTCHAS.md` in the SAME change as any behavior shift, and prepend +an entry to `.claude/board/AGENT_LOG.md` (D-ids, commit, what ran, outcome). +Never let a doc claim a tile op "works" without an executed, asserted probe. diff --git a/.claude/blackboard.md b/.claude/blackboard.md index d86ec1fa..a6cee9c9 100644 --- a/.claude/blackboard.md +++ b/.claude/blackboard.md @@ -134,3 +134,50 @@ This is mostly Cargo.toml workspace wiring + API surface. [DECISION] Cypher executes locally via lance-graph semiring by default [DECISION] Remote DB connections (Neo4j, FalkorDB) via native Bolt client [DECISION] vis.js graph rendering served as static assets by the binary + +## Architecture Decisions + +### 2026-06-13 — GEMM-dispatch routing fixes (savant-architect) +Branch `claude/wonderful-hawking-lodtql`. Three public GEMM entry points +were not routing to the accelerated kernels. + +- **`backend::gemm_bf16` (src/backend/mod.rs)** — ALREADY FIXED in the + working tree this session. 
Now routes to + `hpc::amx_matmul::matmul_bf16_to_f32` (AMX `TDPBF16PS` → AVX-512 + `VDPBF16PS` → scalar). Slice→ArrayView2 wrapping mirrors the call shape + in `simd_runtime::matmul`; inputs sliced to exact `m*k`/`k*n`/`m*n`. + Bit-equivalent on non-AMX/non-AVX512BF16 hosts because the dispatcher's + scalar fallback is the same `quantized::bf16_gemm_f32(a,b,c,m,n,k,1.0,0.0)` + the old direct call used (alpha=1, beta=0 preserved). +- **`backend::gemm_i8` (src/backend/mod.rs)** — ALREADY FIXED in the + working tree this session. Routes to `simd_int_ops::gemm_u8_i8` + (4-tier: AMX `TDPBUSD` → VNNI-zmm → AVX-VNNI-ymm → scalar). + [DECISION] Deliberately NOT routed to `amx_matmul::matmul_i8_to_i32` as + the literal task text asked: `gemm_i8` is **u8×i8→i32**, but + `matmul_i8_to_i32` is **i8×i8→i32** and would reinterpret A-bytes ≥128 + as negative — NOT bit-equivalent. `gemm_u8_i8`'s scalar fallback is the + same `quantized::int8_gemm_i32` the old `vnni_gemm::int8_gemm_vnni` + used → bit-identical on scalar hosts; VNNI-zmm arm calls the same + `int8_gemm_vnni_avx512` kernel as before. All tiers integer-exact. +- **`native::gemv_f32` / `gemv_f64` (src/backend/native.rs)** — FIXED + THIS TURN (was calling `scalar::gemv_*` unconditionally). Now matches + on `tier()`: Scalar tier → unchanged `scalar::gemv_*` (byte-identical); + Avx2/Avx512 tiers → per-row `dot_f32`/`dot_f64` (the existing + dispatched, parity-tested SIMD dot). GEMV = stack of row dots; each A + row is row-major-contiguous so contiguous `dot_*` loads apply. Leading + `n` of each `lda`-wide row taken via `&a[i*lda..i*lda+n]`; no new bounds + requirement vs scalar ref. SIMD tiers carry the module's documented + 1-2 ULP reduce-order drift (within BLAS tol; `test_gemv_f32` uses 1e-5, + no byte-exact consumer asserts gemv). + +[UNSAFE-AUDIT] gemv fix added **zero** new `unsafe` — it reuses the +already-audited `dot_*` kernels. No new sentinel-qa surface from this turn. 
+The two mod.rs fixes contain `unsafe` repr(transparent) slice reinterprets +(BF16/u16) that were landed earlier this session and warrant the standard +sentinel-qa pass if not already covered. + +[LOOSE END] Repo references modules that exist on disk but the Glob/Grep +index was transiently stale this session (returned empty for +`simd_int_ops.rs`, `vnni_gemm.rs`, `bf16_gemm_f32`); Bash ground-truth +confirmed all present. Orchestrator should `cargo fmt`/`clippy`/`test` +centrally (edits were edit-only, no compile performed here). diff --git a/.claude/board/AGENT_LOG.md b/.claude/board/AGENT_LOG.md index 9a7c5ba2..3e7d282f 100644 --- a/.claude/board/AGENT_LOG.md +++ b/.claude/board/AGENT_LOG.md @@ -27,6 +27,35 @@ ## Entries (append below; newest first) +### 2026-06-14 — amx-savant (Opus, main thread) — AMX enabled + made bug-free + documented +- **Branch:** `claude/wonderful-hawking-lodtql`. **Commits:** e6bb26a (enablement + + 4 kernel bugs), 777eff7 (perf 11.5×), 9dd6519 (bf16 probe), + this doc/ + detection commit. +- **What ran:** `examples/amx_probe` (instruction bisector + correctness across + shapes — int8 bit-exact, bf16 rel-err ~0.004) and `examples/amx_gemm_bench` + (throughput + independent `correct=` check). Lib unit-test target is + pre-broken (`src/tri.rs` type-inference, unrelated), so examples are the gate. +- **Findings:** AMX was dark on EVERY capable host via a 1-digit bug — + `ARCH_REQ_XCOMP_PERM` issued on `prctl` (157) instead of `arch_prctl` (158). + Once enabled, 4 more ISA/encoding bugs surfaced (TILECFG rows/colsb swap; + TILELOADD SIB base/index swap; TDPBUSD ModRM same-tile #UD; mirrored + operand index+sign convention — verified by a 4-opcode sign sweep). SPR vs + EMR is NOT the cause: the bugs are ISA-level and were latent on Sapphire + Rapids too (the SPR-era `AMX_GOTCHAS.md` literally shipped 3 of them); they + never fired because detection never returned true. EMR was just the first + host to actually execute the tile path. 
+- **Added:** cached `LazyLock` detection + `CpuModel` (SPR/EMR/GNR/Sierra + Forest) in `src/simd_amx.rs`, re-exported via `ndarray::simd::{cpu_model, + CpuModel, amx_report}`; `examples/amx_probe.rs` (validator/bisector); + `.claude/knowledge/amx-enablement-and-kernel.md` (canonical ref); + `.claude/agents/amx-savant.md` (this agent); rewrote `.claude/AMX_GOTCHAS.md` + (corrected the 3 bugs it shipped, added the fault-signature playbook). +- **Outcome:** int8 GEMM 2048³ = 169.7 GMAC/s (339 GOP/s), 600× scalar; bf16 + path correct. `amx_report()` → "AMX [Emerald Rapids expects_amx=true]: + TILE=true INT8=true BF16=true available=true". +- **Loose ends:** further AMX perf (2×2 register blocking + rayon); blasgraph + Hamming dedup in lance-graph (blocked on missing `protoc`). + ## 2026-05-22T18:00 — PR-X12 cross-stack architecture session (opus 4.7) diff --git a/.claude/knowledge/amx-enablement-and-kernel.md b/.claude/knowledge/amx-enablement-and-kernel.md new file mode 100644 index 00000000..d2df5ec1 --- /dev/null +++ b/.claude/knowledge/amx-enablement-and-kernel.md @@ -0,0 +1,274 @@ +# AMX Enablement & Tile-Kernel Reference + +> READ BY: amx-savant, savant-architect, sentinel-qa, simd-savant +> Status: AMX ENABLED + bit-exact + fast on Emerald Rapids (2026-06-14). +> Supersedes the buggy claims in the original `.claude/AMX_GOTCHAS.md` +> (that doc has been corrected; this file is the canonical reference, the +> gotchas file is the troubleshooting playbook, `amx-savant` is the agent). + +This is the "teach to fish" file: not just *how* to turn AMX on, but *why* +it was off, how every caveat manifests, and how to troubleshoot each one. + +--- + +## TL;DR — current truth + +AMX (Intel Advanced Matrix Extensions) runs on **stable Rust 1.94** via inline +`asm!` byte-encodings (the Rust `_tile_*` intrinsics are nightly-only, issue +#126622). 
As of 2026-06-14 it is **enabled, bit-exact, and fast** here: + +| Surface | State | +|---|---| +| `ndarray::simd::amx_available()` | `true` on Emerald Rapids (cached `LazyLock`) | +| `ndarray::simd::cpu_model()` | `EmeraldRapids` (CPUID model 0xCF) | +| `matmul_i8_to_i32` (TDPBUSD) | bit-exact vs scalar, all shapes | +| `matmul_f32` / `matmul_bf16_to_f32` (TDPBF16PS) | within BF16 tol (rel-err ~0.004) | +| int8 GEMM 2048³, single-thread, `target-cpu=native` | **169.7 GMAC/s (339 GOP/s), 600× scalar** | + +Files: +- `src/simd_amx.rs` — detection (CPUID + XGETBV + arch_prctl), `CpuModel`, `LazyLock` cache. +- `src/hpc/amx_matmul.rs` — tile primitives (`tile_loadconfig`/`tile_load`/`tile_store`/`tile_dpbusd`/`tile_dpbf16ps`/`tile_zero`/`tile_release`/`vnni_pack_*`/`TileConfig`). +- `src/hpc/int8_tile_gemm.rs` — `int8_gemm_amx_tiled` (the fast driver) + `int8_tile_gemm_16x16`. +- `src/hpc/bf16_tile_gemm.rs` — bf16 sibling. +- `examples/amx_probe.rs` — the validator / instruction-bisector (run this FIRST when debugging). +- `examples/amx_gemm_bench.rs` — throughput. + +--- + +## 1. The one-line enablement bug — the "special way to enable it" + +Linux 5.16+ makes AMX `XTILEDATA` a **dynamically-enabled XSTATE feature**: +a process must *request permission* before any tile op, or the first tile +instruction faults (XFD `#NM`). The request is: + +``` +arch_prctl(ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) // 0x1023, 18 +``` + +`ARCH_REQ_XCOMP_PERM` (0x1023) is an **`arch_prctl`** op → **syscall 158**. +The code issued it on **`prctl` → syscall 157**, which rejects option 0x1023 +with `-EINVAL`. So gate 4 of detection *always failed* → `amx_available()` +returned `false` on **every AMX-capable host**, and the AMX path was dead code +that never ran. Fix: `157 → 158`. That single digit is the whole "AMX is +available 50-79% of the time but needs a special way to enable it." 
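
The permission request described above can be sketched as a raw syscall. This is a hedged illustration, not the code in `src/simd_amx.rs`: the function name `request_tile_permission` and the `Result<(), i64>` return shape are assumptions, and the non-Linux fallback simply reports an error.

```rust
/// x86_64 Linux syscall numbers: 157 is prctl, 158 is arch_prctl.
const SYS_ARCH_PRCTL: i64 = 158;
const ARCH_REQ_XCOMP_PERM: i64 = 0x1023;
const XFEATURE_XTILEDATA: i64 = 18;

/// Ask the kernel for permission to use AMX tile data in this process.
/// Returns Ok(()) on success, or Err(-errno) as the raw syscall reports it.
#[cfg(all(target_os = "linux", target_arch = "x86_64"))]
fn request_tile_permission() -> Result<(), i64> {
    let ret: i64;
    unsafe {
        std::arch::asm!(
            "syscall",
            inlateout("rax") SYS_ARCH_PRCTL => ret,
            in("rdi") ARCH_REQ_XCOMP_PERM,
            in("rsi") XFEATURE_XTILEDATA,
            lateout("rcx") _, // clobbered by the syscall instruction
            lateout("r11") _, // clobbered by the syscall instruction
            options(nostack),
        );
    }
    if ret == 0 { Ok(()) } else { Err(ret) }
}

/// Fallback for other targets: there is no such request to make.
#[cfg(not(all(target_os = "linux", target_arch = "x86_64")))]
fn request_tile_permission() -> Result<(), i64> {
    Err(-38) // stand-in value (ENOSYS on Linux)
}
```

Keeping the syscall number as a named constant next to a comment that spells out 157 versus 158 makes the one-digit mix-up harder to reintroduce.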
+ +Why ~50-79%: Claude's fleet is heterogeneous — a container lands on AMX +silicon (SPR/EMR/GNR) only some of the time. On those hosts the gate-4 bug +made AMX look absent; on non-AMX hosts gate 1 (CPUID) correctly returns false. + +--- + +## 2. SPR vs EMR — *not* the cause + +The original `AMX_GOTCHAS.md` header says "CPU: Sapphire Rapids … AMX +confirmed." That confirmation was hollow: every AMX unit test early-returns +`if !amx_available() { return; }`, and `amx_available()` was always false +(the 157 bug), so **the tile asm had literally never executed on SPR either.** +The "✓ TDPBUSD works" checkmarks were CPUID detection + aspiration, not a run. + +Consequence: the four tile-path bugs below (#2 to #5 in §3) are **ISA / encoding bugs, identical on SPR, +EMR, and GNR**. EMR was simply the host where gate 4 got fixed first, so it was +the first host to actually *execute* the tile path and expose them. The fixes +are silicon-independent; they apply equally to Sapphire Rapids and Granite +Rapids. The operand convention (§4) is a property of the VEX encoding, not the +microarchitecture, so it holds across all AMX CPUs. + +`cpu_model()` exists so a run can *say* which silicon it landed on and tell +"no AMX silicon" apart from "AMX present but not OS-enabled" — but no code path +should branch kernel correctness on SPR-vs-EMR. They are the same ISA here. + +--- + +## 3. The five bugs, by fault signature (the troubleshooting spine) + +Each AMX bug has a *distinct* crash/failure signature. 
Memorize the mapping — +it's how you bisect without a debugger (the `amx_probe` example prints a +flushed line before each instruction so the LAST line names the fault): + +| # | Symptom | Root cause | Fix | +|---|---|---|---| +| 1 | `amx_available()==false` on AMX silicon | `arch_prctl` on syscall **157** not **158** → `-EINVAL` | use 158 | +| 2 | **SIGSEGV** at `LDTILECFG` (first tile op) | `TileConfig` rows/colsb regions **swapped** → colsb=4112, rows=64 → `#GP` | colsb u16 @16+2t, rows u8 @48+t | +| 3 | **SIGSEGV** at `TILELOADD`/`TILESTORED` | SIB `0x08` = `[base=rax,index=rcx]` but regs bound rcx=ptr,rax=stride → derefs stride(~64) as base | SIB `0x01` = `[base=rcx,index=rax]` | +| 4 | **SIGILL** at `TDPBUSD`/`TDPBF16PS` | ModRM `0xC1` ⇒ rm=tmm1 == vvvv=tmm1 → two sources alias → same-tile `#UD` | ModRM `0xC2` (rm=tmm2, distinct) | +| 5 | runs, **wrong values** (`correct=false`) | operand index+sign convention mirrored from naive SDM reading | load plain M×K→tmm2, VNNI K×N→tmm1; TDPBUSD `0x71` | + +The order matters: each fix exposes the next signature (SIGSEGV → SIGSEGV → +SIGILL → wrong-values → correct). If you fix #2 and still SIGSEGV, you're on +#3; if you clear both segfaults and hit SIGILL, you're on #4. + +--- + +## 4. The empirical operand convention (the "alien magic") + +**The AMX tile-op operand mapping on this silicon is the mirror of the naive +Intel-SDM reading, on BOTH axes.** Verified by driving the tile primitives with +a selector A (`A[0][s]=1`) and a marked B, then sweeping all four `TDPB**D` +opcodes against sign-sensitive constant inputs: + +- **INDEX**: `dst[m][n] = Σ_k tmm2(ModRM.rm)[m][k] · tmm1(VEX.vvvv)[k][n]`. + The **plain M×K** operand goes in **tmm2 (rm)**; the **VNNI-packed K×N** + operand goes in **tmm1 (vvvv)**. (Naive SDM order would say the opposite.) 
+- **SIGN** (for opcode `0x71`): **tmm2/rm = UNSIGNED, tmm1/vvvv = SIGNED.** + +Sign sweep (loads: B_vnni→tmm1, A→tmm2; A=200=u8 200/i8 -56, B=3 / A=3, B=200): + +| byte2 | mnemonic (pp) | A(rm) | B(vvvv) | +|---|---|---|---| +| 0x70 | TDPBUUD (NP) | unsigned | unsigned | +| **0x71** | **TDPBUSD (66)** | **unsigned** | **signed** ← `u8×i8` | +| 0x72 | TDPBSUD (F3) | signed | unsigned | +| 0x73 | TDPBSSD (F2) | signed | signed | + +`0x70` and `0x73` (both-same-sign) match the SDM directly, which confirms the +opcode→pp map is right and only the src1/src2 ↔ vvvv/rm *position* is mirrored. + +Kernel consequence: `int8_tile_gemm::amx_path` loads `A(u8)→tmm2`, `B_vnni(i8) +→tmm1`, executes `tile_dpbusd` (0x71). The `matmul_i8_to_i32` caller keeps its +`A+128→u8` shift and `−128·colsum(B)` bias unchanged — correct because A(rm) +is the unsigned operand. bf16 has no sign split, so the index swap alone fixes +TDPBF16PS. + +--- + +## 5. Validated byte-code table (authoritative — measured on EMR) + +> There is **no** W3C/intrinsics export with these AMX encodings — the +> `.claude` "w3c" files are semantic-web ontologies (SKOS/PROV-O/FIBO). The +> authority is the Intel SDM opcode map + the empirical sweep above. The +> following are confirmed correct in-tree: + +``` +LDTILECFG [mem] : "ldtilecfg [{}]" (mnemonic; assembler encodes it) +TILEZERO tmmN : C4 E2 7B 49 (C0 | N<<3) # tmm0=C0 tmm1=C8 tmm2=D0 tmm3=D8 +TILERELEASE : C4 E2 78 49 C0 +TILELOADD tmmN,[rcx+rax] : C4 E2 7B 4B (04 | N<<3) 01 # SIB 01 = base=rcx,index=rax,scale1 +TILESTORED [rcx+rax],tmm0 : C4 E2 7A 4B 04 01 +TDPBUSD tmm0,tmm1,tmm2 : C4 E2 71 5E C2 # pp=66 ; dst=tmm0,vvvv=tmm1,rm=tmm2 +TDPBF16PS tmm0,tmm1,tmm2 : C4 E2 72 5C C2 # pp=F3 opcode 5C +``` + +VEX byte2 = `W(1) . vvvv(4) . L(1) . pp(2)`; `vvvv` is the 1's-complement of +the register (tmm1 → 1110). ModRM `0xC2` = mod=11, reg=000(tmm0), rm=010(tmm2). 
+The two earlier WRONG encodings that shipped from the SPR-era gotchas doc were +`TDPBUSD = C4 E2 73 5E C1` (0x73=TDPBSSD wrong variant; C1 aliases tmm1). + +### TILECFG (XTILECFG) 64-byte layout — the corrected version + +``` +byte 0 : palette (MUST be 1) +byte 1 : start_row (0) +bytes 2-15 : reserved (0) +bytes 16-47 : colsb[t] — 16 × u16, colsb[t] @ (16 + 2*t) # bytes-per-row, ≤ 64 +bytes 48-63 : rows[t] — 16 × u8, rows[t] @ (48 + t) # rows, ≤ 16 +``` + +The SPR-era doc had rows and colsb **swapped** ("rows 16-23, colbytes 48-63"), +which is bug #2. For the 16×16 int8/bf16 tile, all three tiles are 16 rows × +64 colbytes. + +--- + +## 6. Detection API (cached, CPU-aware) + +```rust +use ndarray::simd::{amx_available, cpu_model, amx_report, CpuModel}; + +amx_available() // bool, cached once via LazyLock (the 4 gates of §1) +cpu_model() // CpuModel::{SapphireRapids,EmeraldRapids,GraniteRapids,SierraForest,OtherX86,NonX86} +cpu_model().has_amx() // true for SPR/EMR/GNR; false for Sierra Forest (E-core) +amx_report() // e.g. "AMX [Emerald Rapids expects_amx=true]: TILE=true INT8=true BF16=true available=true" +``` + +Why `LazyLock`: the four gates (CPUID, XGETBV, one `arch_prctl`) are all +non-blocking — no I/O, no lock contention, no spin — so the init cannot stall; +it runs once on first call and every later call is a cached load. The +`arch_prctl` grant is **process-wide and inherited by all threads**, so +requesting it exactly once is correct even under a rayon consumer. Diagnostic +value: `cpu_model().has_amx() == true && amx_available() == false` means the +silicon has AMX but the OS/hypervisor hasn't enabled it (XCR0 clear, or — until +the 157→158 fix — the permission request failed). That split is the single most +useful troubleshooting signal. + +--- + +## 7. Performance — what made it fast + +Correct ≠ fast. 
The first correct version was **14.8 GMAC/s** (~0.7% of peak) +because `int8_gemm_amx_tiled` called the 16×16 kernel per output tile, which ran +`LDTILECFG` (a **serializing** instruction) + `TILERELEASE` and re-VNNI-packed +B **on every tile** (256 `LDTILECFG`s for a 256² output). The fast driver: + +1. `LDTILECFG` **once** up front, `TILERELEASE` **once** at the end. +2. VNNI-pack each B column band **once per j-tile** (reused across all row tiles). +3. `TILEZERO` the C accumulator and `TILESTORED` the 16×16 result **straight + into its strided slot** in C (row pitch n·4 bytes) — no scratch + copy. + +Result: **14.8 → 169.7 GMAC/s** (11.5×), still correct=true. + +**Then 2×2 register blocking** (`int8_gemm_amx_tiled_rb` + `tile_dpbusd_2x2`): +4 C accumulators (tmm0-3) fed by 2 A tiles (tmm4-5) + 2 B tiles (tmm6-7), so +each A/B tile load serves TWO products — half the tile loads per MAC, the right +lever for this **memory-bandwidth-bound** kernel. The loop order matters: a +first cut pre-packed ALL of B (~4 MB at 2048²) and thrashed cache, *regressing* +large shapes (1024³ 156→125). The BLIS-style fix — OUTER over 32-col panels, +pack only that panel's 2 B bands (L2-resident) and reuse across all row-blocks, +INNER over 32-row blocks — also halves A's DRAM re-reads (32-col vs 16-col +panels). Single-thread result, all correct=true: + +``` + serial rb(2×2) +256³ 65.7 80.8 (+23%) +512³ 124.9 132.0 (+6%) +1024³ 155.9 170.2 (+9%) +2048³ 169.7 197.7 (+16%) ← 395 GOP/s +``` + +**Rayon over row-tiles** (`int8_gemm_amx_tiled_par`, `feature = "rayon"`): this +kernel is bandwidth-bound, so 4-core scaling is sublinear — 2048³ → 237.5 GMAC/s +(~1.4×) — and it REGRESSES small/medium (thread + B-prepack overhead), so it's +gated to `m·n·k ≥ 2e9`. Many-core servers gain more. + +Dispatch (in `int8_gemm_amx_tiled`): huge + rayon → `_par` (16×16, shared +pre-packed B); else m,n≥32 → `_rb` (2×2); else `_serial` (16×16); m or n < 32 +strips fall to the 16×16 path inside `_rb`. 
+ +**Remaining headroom (with a caution).** "rayon-over-rb" — fanning the rb +row-panels across the pool instead of the 16×16 kernel — is the obvious combine, +BUT a first attempt (each rayon task calls `_rb` on a 64-row band) was REVERTED: +it ran SLOWER than rb-single (155 vs 197 GMAC/s at 2048³ — each task re-VNNI-packs +B, an O(K·N) duplicate ×num_tasks) AND `correct=false` appeared at 1024³/2048³ +while 256³/512³ stayed correct (an AMX-tiles-under-rayon-at-scale issue not yet +diagnosed — single-thread `_rb` is bit-exact at every size, so it's specific to +the threaded 8-tile path). Do NOT reship that shape without (a) a SHARED pre-pack +of B (as `_par` 16×16 already does) and (b) a probe that reproduces the +large-size correctness failure under rayon and explains it. The safe wins are +banked: rb-single (197) is the default, 16×16-rayon (237) the huge case. The +bigger lever is full BLIS Mc/Nc/Kc cache blocking. + +--- + +## 8. Modus operandi when AMX "doesn't work" + +1. `amx_report()` first. `expects_amx` false → not AMX silicon (or non-Intel / + masked); stop, use the VNNI/scalar fallback. `expects_amx` true but + `available=false` → enablement, not silicon: check kernel ≥5.16, XCR0 bits + 17/18 (`XGETBV(0)`), and that gate 4 uses syscall **158**. +2. If `available=true` but a GEMM crashes/misbehaves, run `examples/amx_probe` + (it bisects instruction-by-instruction and then checks correctness across + shapes). Match the fault to the §3 table. +3. Never trust "tile asm tested" claims that sit behind an + `if !amx_available() { return; }` guard — confirm the guard was *true* when + the test ran (i.e. on real AMX silicon with detection fixed). +4. Validate with `amx_probe` (correctness) **and** `amx_gemm_bench` + (throughput + an independent `correct=` check) before believing numbers. + +--- + +## References + +- `.claude/AMX_GOTCHAS.md` — the per-caveat troubleshooting playbook. +- `.claude/agents/amx-savant.md` — the agent that owns this surface. 
+- `.claude/knowledge/hardware_map.md`, `agnostic-surface-cpu-matrix.md` — CPU tiers. +- Intel SDM Vol 2 (LDTILECFG / TILELOADD / TDPBUSD / TDPBF16PS), Vol 1 §13 (XSAVE/XFD). +- Linux `Documentation/arch/x86/xstate.rst` (ARCH_REQ_XCOMP_PERM, dynamic XSTATE). diff --git a/Cargo.toml b/Cargo.toml index ee7cc8f7..c3995629 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,33 @@ required-features = ["std"] name = "splat3d_flex" required-features = ["splat3d"] +# AMX examples import `ndarray::simd` / `ndarray::hpc`, both `#[cfg(feature = +# "std")]`, so they must be skipped in `--no-default-features` CI jobs. +[[example]] +name = "amx_gemm_bench" +required-features = ["std"] + +[[example]] +name = "amx_probe" +required-features = ["std"] + +[[example]] +name = "amx_rb_probe" +required-features = ["std"] + +# Morton cascade probe imports `ndarray::simd` (std-gated). +[[example]] +name = "morton_cascade_probe" +required-features = ["std"] + +[[example]] +name = "golden_helix_probe" +required-features = ["std"] + +[[example]] +name = "edge_residue_probe" +required-features = ["std"] + [dependencies] num-integer = { workspace = true } num-traits = { workspace = true } diff --git a/examples/amx_gemm_bench.rs b/examples/amx_gemm_bench.rs new file mode 100644 index 00000000..fcf735a4 --- /dev/null +++ b/examples/amx_gemm_bench.rs @@ -0,0 +1,71 @@ +//! AMX / VNNI int8-GEMM throughput probe. +//! +//! Measures the dispatched `matmul_i8_to_i32` (AMX `TDPBUSD` tile → AVX-512 +//! `VPDPBUSD` zmm → AVX-VNNI ymm → scalar, chosen at runtime) against a naive +//! scalar i8×i8→i32 reference, across square sizes. Reports ns, GMAC/s, GOP/s. +//! +//! RUSTFLAGS="-C target-cpu=native" cargo run --release --example amx_gemm_bench +//! +//! `target-cpu=native` (or `x86-64-v4` + an AMX host) is what enables the AMX +//! tile path; `amx_available()` reports whether THIS silicon exposes it. 
+ +use std::time::Instant; + +use ndarray::simd::{amx_available, matmul_i8_to_i32}; +use ndarray::{ArrayView2, ArrayViewMut2}; + +fn scalar_i8_gemm(a: &[i8], b: &[i8], c: &mut [i32], m: usize, k: usize, n: usize) { + for i in 0..m { + for j in 0..n { + let mut acc = 0i32; + for p in 0..k { + acc += a[i * k + p] as i32 * b[p * n + j] as i32; + } + c[i * n + j] = acc; + } + } +} + +fn bench(m: usize, k: usize, n: usize, reps: usize, scalar_reps: usize) { + let a: Vec<i8> = (0..m * k).map(|i| ((i % 7) as i8) - 3).collect(); + let b: Vec<i8> = (0..k * n).map(|i| ((i % 5) as i8) - 2).collect(); + let mut c = vec![0i32; m * n]; + let av = ArrayView2::from_shape((m, k), &a[..]).unwrap(); + let bv = ArrayView2::from_shape((k, n), &b[..]).unwrap(); + + // warmup + dispatched timing + matmul_i8_to_i32(av, bv, ArrayViewMut2::from_shape((m, n), &mut c[..]).unwrap()).unwrap(); + let dispatched = c.clone(); + let t = Instant::now(); + for _ in 0..reps { + matmul_i8_to_i32(av, bv, ArrayViewMut2::from_shape((m, n), &mut c[..]).unwrap()).unwrap(); + } + let ns = t.elapsed().as_nanos() as f64 / reps as f64; + + // scalar reference (fewer reps; it's slow) + correctness check + let mut cs = vec![0i32; m * n]; + let t = Instant::now(); + for _ in 0..scalar_reps { + scalar_i8_gemm(&a, &b, &mut cs, m, k, n); + } + let sns = t.elapsed().as_nanos() as f64 / scalar_reps as f64; + let ok = dispatched == cs; + + let macs = (m * n * k) as f64; + println!( + " {m:>4}x{k:>4}x{n:>4}: dispatched {ns:>11.0} ns {:>7.1} GMAC/s ({:>7.1} GOP/s) scalar {sns:>13.0} ns speedup {:>6.1}x correct={ok}", + macs / ns, + 2.0 * macs / ns, + sns / ns, + ); +} + +fn main() { + println!("== int8 GEMM (matmul_i8_to_i32: AMX TDPBUSD -> VPDPBUSD-zmm -> AVX-VNNI -> scalar) =="); + println!("amx_available() on this host: {}", amx_available()); + println!(); + bench(256, 256, 256, 300, 5); + bench(512, 512, 512, 80, 2); + bench(1024, 1024, 1024, 20, 1); + bench(2048, 2048, 2048, 6, 1); +} diff --git a/examples/amx_probe.rs 
b/examples/amx_probe.rs new file mode 100644 index 00000000..be351eb5 --- /dev/null +++ b/examples/amx_probe.rs @@ -0,0 +1,147 @@ +//! AMX correctness validation — int8_tile_gemm_16x16 (raw u8×i8) and the full +//! matmul_i8_to_i32 (i8×i8 with the +128/bias trick) vs the scalar reference, +//! across single/multi K-block and single/multi-tile shapes. +//! +//! RUSTFLAGS="-C target-cpu=native" cargo run --release --example amx_probe + +use ndarray::hpc::amx_matmul::matmul_f32; +use ndarray::hpc::int8_tile_gemm::int8_tile_gemm_16x16; +use ndarray::simd::{amx_available, amx_report, cpu_model, matmul_i8_to_i32}; +use ndarray::{ArrayView2, ArrayViewMut2}; + +fn ref_u8_i8_16(a: &[u8], b: &[i8], k: usize) -> Vec<i32> { + let mut c = vec![0i32; 256]; + for i in 0..16 { + for j in 0..16 { + let mut s = 0i32; + for kk in 0..k { + s += a[i * k + kk] as i32 * b[kk * 16 + j] as i32; + } + c[i * 16 + j] = s; + } + } + c +} + +fn ref_i8_i8(a: &[i8], b: &[i8], m: usize, n: usize, k: usize) -> Vec<i32> { + let mut c = vec![0i32; m * n]; + for i in 0..m { + for kk in 0..k { + let av = a[i * k + kk] as i32; + for j in 0..n { + c[i * n + j] += av * b[kk * n + j] as i32; + } + } + } + c +} + +fn first_mismatch(got: &[i32], exp: &[i32]) -> Option<(usize, i32, i32)> { + got.iter() + .zip(exp) + .enumerate() + .find(|(_, (g, e))| g != e) + .map(|(i, (g, e))| (i, *g, *e)) +} + +fn test_tile_16(k: usize) { + let a: Vec<u8> = (0..16 * k).map(|i| ((i * 31 + 7) % 256) as u8).collect(); + let b: Vec<i8> = (0..k * 16) + .map(|i| ((i * 17 + 3) % 256) as u8 as i8) + .collect(); + let exp = ref_u8_i8_16(&a, &b, k); + let mut got = vec![0i32; 256]; + int8_tile_gemm_16x16(&a, &b, &mut got, k); + match first_mismatch(&got, &exp) { + None => println!(" int8_tile_gemm_16x16 K={k:<4} CORRECT"), + Some((i, g, e)) => println!(" int8_tile_gemm_16x16 K={k:<4} WRONG first@{i}: got {g} exp {e}"), + } +} + +fn test_matmul(m: usize, n: usize, k: usize) { + let a: Vec<i8> = (0..m * k) + .map(|i| ((i * 31 + 7) % 256) as u8 as i8) + 
.collect(); + let b: Vec<i8> = (0..k * n) + .map(|i| ((i * 17 + 3) % 256) as u8 as i8) + .collect(); + let exp = ref_i8_i8(&a, &b, m, n, k); + let mut got = vec![0i32; m * n]; + matmul_i8_to_i32( + ArrayView2::from_shape((m, k), &a[..]).unwrap(), + ArrayView2::from_shape((k, n), &b[..]).unwrap(), + ArrayViewMut2::from_shape((m, n), &mut got[..]).unwrap(), + ) + .unwrap(); + match first_mismatch(&got, &exp) { + None => println!(" matmul_i8_to_i32 {m:>4}x{k:>4}x{n:>4} CORRECT"), + Some((i, g, e)) => println!(" matmul_i8_to_i32 {m:>4}x{k:>4}x{n:>4} WRONG first@{i}: got {g} exp {e}"), + } +} + +/// matmul_f32 routes through the BF16 TDPBF16PS tile path on AMX hosts, so we +/// check against an f32 scalar reference with a relative tolerance (BF16 has +/// ~8 mantissa bits → a few ×1e-2 relative error on accumulated sums is fine). +fn test_matmul_f32(m: usize, n: usize, k: usize) { + let a: Vec<f32> = (0..m * k).map(|i| ((i % 13) as f32 - 6.0) * 0.1).collect(); + let b: Vec<f32> = (0..k * n).map(|i| ((i % 7) as f32 - 3.0) * 0.2).collect(); + let mut exp = vec![0.0f32; m * n]; + for i in 0..m { + for kk in 0..k { + let av = a[i * k + kk]; + for j in 0..n { + exp[i * n + j] += av * b[kk * n + j]; + } + } + } + let mut got = vec![0.0f32; m * n]; + matmul_f32( + ArrayView2::from_shape((m, k), &a[..]).unwrap(), + ArrayView2::from_shape((k, n), &b[..]).unwrap(), + ArrayViewMut2::from_shape((m, n), &mut got[..]).unwrap(), + ) + .unwrap(); + // True relative metric: L2 relative error ‖got−exp‖ / ‖exp‖ (robust to + // small individual outputs — the previous `|e|.max(1.0)` denominator turned + // every |e|<1 cell into an absolute-error test) plus the max absolute error. 
+ let mut sq_err = 0.0f64; + let mut sq_ref = 0.0f64; + let mut max_abs = 0.0f32; + for (g, e) in got.iter().zip(&exp) { + let d = g - e; + sq_err += (d as f64) * (d as f64); + sq_ref += (*e as f64) * (*e as f64); + max_abs = max_abs.max(d.abs()); + } + let rel_l2 = (sq_err.sqrt() / sq_ref.sqrt().max(1e-12)) as f32; + let verdict = if rel_l2 < 0.02 { "CORRECT" } else { "WRONG " }; + println!(" matmul_f32 {m:>4}x{k:>4}x{n:>4} {verdict} rel-L2 {rel_l2:.4} max-abs {max_abs:.4}"); +} + +fn main() { + println!("{}", amx_report()); + println!("cpu_model() = {:?} has_amx() = {}", cpu_model(), cpu_model().has_amx()); + println!("amx_available() = {}\n", amx_available()); + + println!("== matmul_f32 (BF16 TDPBF16PS tile path, ~1% BF16 tolerance) =="); + test_matmul_f32(16, 16, 32); + test_matmul_f32(32, 32, 64); + test_matmul_f32(64, 48, 128); + test_matmul_f32(128, 128, 256); + + println!("\n== int8_tile_gemm_16x16 (raw u8×i8 tile kernel) =="); + test_tile_16(64); + test_tile_16(128); + test_tile_16(256); + + println!("\n== matmul_i8_to_i32 (signed i8×i8 with +128/bias + multi-tile) =="); + test_matmul(16, 16, 64); + test_matmul(16, 16, 128); + test_matmul(32, 16, 64); + test_matmul(16, 32, 64); + test_matmul(32, 32, 128); // exactly one 2×2 register block + test_matmul(64, 48, 192); // full blocks + right strip (n ≡ 16 mod 32) + test_matmul(48, 48, 128); // full block + BOTH strips + corner (m,n ≡ 16 mod 32) + test_matmul(96, 80, 128); // 3×2 full blocks + both strips + test_matmul(256, 256, 256); +} diff --git a/examples/amx_rb_probe.rs b/examples/amx_rb_probe.rs new file mode 100644 index 00000000..a058de4f --- /dev/null +++ b/examples/amx_rb_probe.rs @@ -0,0 +1,98 @@ +//! AMX 2×2 register-block encoding validator. Exercises the NEW tile encodings +//! (tmm3-7 loads, tmm1-3 stores, `for_dpbusd_8`, `tile_dpbusd_2x2`) by computing +//! one 32×32 output block via the register-blocked sequence and comparing to a +//! scalar u8×i8 reference. 
Front-loads the encoding risk before any driver. +//! +//! RUSTFLAGS="-C target-cpu=native" cargo run --release --example amx_rb_probe + +use ndarray::hpc::amx_matmul::{ + tile_dpbusd_2x2, tile_load, tile_loadconfig, tile_release, tile_store, tile_zero, vnni_pack_i8, TileConfig, +}; +use ndarray::simd::amx_available; + +/// Scalar 32×32 u8×i8 → i32 reference: C[i][j] = Σ_k A[i][k]·B[k][j]. +fn ref_32(a: &[u8], b: &[i8], k: usize) -> Vec<i32> { + let mut c = vec![0i32; 32 * 32]; + for i in 0..32 { + for j in 0..32 { + let mut s = 0i32; + for kk in 0..k { + s += a[i * k + kk] as i32 * b[kk * 32 + j] as i32; + } + c[i * 32 + j] = s; + } + } + c +} + +/// 32×32 = A(32×k u8) · B(k×32 i8) via the 2×2 register-blocked AMX kernel. +fn rb_32(a: &[u8], b: &[i8], k: usize) -> Vec<i32> { + assert_eq!(k % 64, 0, "rb_32: K must be a multiple of 64 (TDPBUSD tile depth)"); + // Pack the two 16-wide B column bands into VNNI quads. + let mut b0 = vec![0i8; k * 16]; + let mut b1 = vec![0i8; k * 16]; + { + let mut band = vec![0i8; k * 16]; + for kk in 0..k { + band[kk * 16..(kk + 1) * 16].copy_from_slice(&b[kk * 32..kk * 32 + 16]); + } + vnni_pack_i8(&band, &mut b0, k, 16); + for kk in 0..k { + band[kk * 16..(kk + 1) * 16].copy_from_slice(&b[kk * 32 + 16..kk * 32 + 32]); + } + vnni_pack_i8(&band, &mut b1, k, 16); + } + let mut c = vec![0i32; 32 * 32]; + let k_blocks = k / 64; + // SAFETY: amx_available() checked by caller; 8 tiles configured; every + // load/store is a 16×64 tile within the a/b0/b1/c allocations; stores use + // the 32-wide C row pitch (128 bytes) into the four quadrants. 
+ unsafe { + tile_loadconfig(&TileConfig::for_dpbusd_8()); + tile_zero(0); + tile_zero(1); + tile_zero(2); + tile_zero(3); + for kb in 0..k_blocks { + tile_load(4, a.as_ptr().add(kb * 64), k); // A0 = rows 0..16 + tile_load(5, a.as_ptr().add(16 * k + kb * 64), k); // A1 = rows 16..32 + tile_load(6, b0.as_ptr().add(kb * 16 * 64) as *const u8, 64); // B0 vnni + tile_load(7, b1.as_ptr().add(kb * 16 * 64) as *const u8, 64); // B1 vnni + tile_dpbusd_2x2(); + } + let cp = c.as_mut_ptr(); + let pitch = 32 * 4; // bytes per C row (32 i32) + tile_store(0, cp as *mut u8, pitch); // C[0..16][0..16] + tile_store(1, cp.add(16) as *mut u8, pitch); // C[0..16][16..32] + tile_store(2, cp.add(16 * 32) as *mut u8, pitch); // C[16..32][0..16] + tile_store(3, cp.add(16 * 32 + 16) as *mut u8, pitch); // C[16..32][16..32] + tile_release(); + } + c +} + +fn check(k: usize) { + let a: Vec<u8> = (0..32 * k).map(|i| ((i * 31 + 7) % 256) as u8).collect(); + let b: Vec<i8> = (0..k * 32) + .map(|i| ((i * 17 + 3) % 256) as u8 as i8) + .collect(); + let exp = ref_32(&a, &b, k); + let got = rb_32(&a, &b, k); + match got.iter().zip(&exp).enumerate().find(|(_, (g, e))| g != e) { + None => println!(" 2x2 register block K={k:<4} CORRECT"), + Some((i, (g, e))) => { + println!(" 2x2 register block K={k:<4} WRONG first@(row {},col {}): got {g} exp {e}", i / 32, i % 32) + } + } +} + +fn main() { + println!("amx_available() = {}\n", amx_available()); + if !amx_available() { + return; + } + println!("== 2×2 register-blocked TDPBUSD (new encodings: tmm3-7, for_dpbusd_8) =="); + check(64); // single K-block + check(128); // two K-blocks (accumulate across blocks in 4 C tiles) + check(256); // four K-blocks +} diff --git a/examples/edge_residue_probe.rs b/examples/edge_residue_probe.rs new file mode 100644 index 00000000..27001c68 --- /dev/null +++ b/examples/edge_residue_probe.rs @@ -0,0 +1,157 @@ +//! AMX edge-residue probe — "stream pairwise into AMX-turbovec as a cheap edge +//! 
residue that complements the 16×8bit=128bit coarse code." +//! +//! The holy-grail pipeline, measured end-to-end (anti-theater): +//! +//! vectors (GGUF stand-in) +//! → AMX int8 GEMM assigns each to a 256-entry palette (1 byte COARSE) +//! → 4-bit TurboQuant of the residue (D/2 bytes FINE) +//! → reconstruct, measure error: coarse-only vs coarse+residue +//! +//! The AMX `matmul_i8_to_i32` (this session's 197 GMAC/s kernel) does the +//! pairwise vector×codebook inner products — the "stream pairwise into AMX" +//! step. The 1-byte palette index is the per-edge coarse code (16 edges = +//! 16 bytes = the node's 128-bit EdgeBlock); the 4-bit residue is the +//! turbovec/turboquant complement. +//! +//! PASS: coarse+residue reconstruction error ≪ coarse-only, at a small, +//! fixed extra byte cost, with the assignment running on the AMX tile path. +//! +//! RUSTFLAGS="-C target-cpu=native" cargo run --release --example edge_residue_probe + +use std::time::Instant; + +use ndarray::simd::{amx_available, matmul_i8_to_i32}; +use ndarray::{ArrayView2, ArrayViewMut2}; + +fn splitmix(s: &mut u64) -> f32 { + *s = s.wrapping_add(0x9E37_79B9_7F4A_7C15); + let mut z = *s; + z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + z ^= z >> 31; + ((z >> 40) as f32) / (1u32 << 24) as f32 * 2.0 - 1.0 // [-1, 1) +} + +/// Symmetric per-tensor i8 quantization: scale so max|x| → 127. +fn quantize_i8(x: &[f32]) -> (Vec<i8>, f32) { + let amax = x.iter().fold(0.0f32, |a, &v| a.max(v.abs())).max(1e-12); + let scale = 127.0 / amax; + ( + x.iter() + .map(|&v| (v * scale).round().clamp(-127.0, 127.0) as i8) + .collect(), + scale, + ) +} + +fn l2(a: &[f32], b: &[f32]) -> f32 { + a.iter() + .zip(b) + .map(|(x, y)| (x - y) * (x - y)) + .sum::<f32>() + .sqrt() +} + +fn run(n: usize, d: usize, k: usize, noise: f32) { + let mut s = 0x1234_5678 ^ (d as u64); + // Codebook: K centroids, D dims, in [-1,1]. 
+ let cb: Vec<f32> = (0..k * d).map(|_| splitmix(&mut s)).collect(); + // Vectors: each = a random centroid + noise (so palette assignment is real + // and the residue is the recoverable structure). + let mut vecs = vec![0.0f32; n * d]; + let mut truth_idx = vec![0u32; n]; + for i in 0..n { + let c = (splitmix(&mut s).abs() * k as f32) as usize % k; + truth_idx[i] = c as u32; + for j in 0..d { + vecs[i * d + j] = cb[c * d + j] + noise * splitmix(&mut s); + } + } + + // ── AMX pairwise: G[i][j] = <v_i, c_j> via int8 GEMM (V[N×D] · Cᵀ[D×K]) ── + let (v_i8, _) = quantize_i8(&vecs); + let (cb_i8, _) = quantize_i8(&cb); + // Transpose codebook to D×K. + let mut cbt_i8 = vec![0i8; d * k]; + for c in 0..k { + for j in 0..d { + cbt_i8[j * k + c] = cb_i8[c * d + j]; + } + } + let mut g = vec![0i32; n * k]; + let av = ArrayView2::from_shape((n, d), &v_i8[..]).unwrap(); + let bv = ArrayView2::from_shape((d, k), &cbt_i8[..]).unwrap(); + let t0 = Instant::now(); + matmul_i8_to_i32(av, bv, ArrayViewMut2::from_shape((n, k), &mut g[..]).unwrap()).unwrap(); + let amx_ns = t0.elapsed().as_nanos() as f64; + + // ||c_j||² in the i8 domain (same scale as v) for the argmin. + let cnorm: Vec<i32> = (0..k) + .map(|c| (0..d).map(|j| (cb_i8[c * d + j] as i32).pow(2)).sum()) + .collect(); + // idx[i] = argmax_j (2·G[i][j] − ||c_j||²) ≡ argmin_j ||v_i − c_j||². 
+ let mut idx = vec![0u32; n]; + for i in 0..n { + let mut best = i32::MIN; + let mut bj = 0u32; + for j in 0..k { + let score = 2 * g[i * k + j] - cnorm[j]; + if score > best { + best = score; + bj = j as u32; + } + } + idx[i] = bj; + } + let assign_acc = idx.iter().zip(&truth_idx).filter(|(a, b)| a == b).count() as f64 / n as f64; + + // ── Coarse recon (palette only) + 4-bit TurboQuant residue (the complement) ── + let mut coarse_err = 0.0f64; + let mut fine_err = 0.0f64; + let mut vnorm_sum = 0.0f64; + for i in 0..n { + let v = &vecs[i * d..i * d + d]; + let c = &cb[idx[i] as usize * d..idx[i] as usize * d + d]; + // residue + 4-bit signed (−8..7) per-vector uniform quant → D/2 bytes. + let res: Vec<f32> = v.iter().zip(c).map(|(a, b)| a - b).collect(); + let rmax = res.iter().fold(0.0f32, |a, &x| a.max(x.abs())).max(1e-12); + let rs = rmax / 7.0; + let fine: Vec<f32> = res + .iter() + .zip(c) + .map(|(&r, &cc)| cc + ((r / rs).round().clamp(-8.0, 7.0)) * rs) + .collect(); + let vn = v.iter().map(|x| x * x).sum::<f32>().sqrt() as f64; + coarse_err += l2(v, c) as f64; + fine_err += l2(v, &fine) as f64; + vnorm_sum += vn; + } + let coarse_rel = coarse_err / vnorm_sum; + let fine_rel = fine_err / vnorm_sum; + + let macs = (n * k * d) as f64; + println!( + " N={n} D={d} K={k} noise={noise:.2}:\n\ + \x20 AMX assign: {:.0} ns ({:.1} GMAC/s), accuracy {:.1}%\n\ + \x20 coarse(1B/edge): rel-err {:.4} → +turbovec 4-bit ({} B): rel-err {:.4} ({:.1}× better, +{} B/vec)", + amx_ns, + macs / amx_ns, + 100.0 * assign_acc, + coarse_rel, + d / 2, + fine_rel, + coarse_rel / fine_rel.max(1e-12), + d / 2, + ); +} + +fn main() { + println!("== AMX edge-residue probe (palette assign on AMX + turbovec 4-bit residue) =="); + println!("amx_available() = {}\n", amx_available()); + // D multiple of 64, K & N multiples of 16 → AMX tile path. 
+ for noise in [0.15f32, 0.30] { + run(4096, 128, 256, noise); + } + run(8192, 64, 256, 0.20); +} diff --git a/examples/golden_helix_probe.rs b/examples/golden_helix_probe.rs new file mode 100644 index 00000000..733ada89 --- /dev/null +++ b/examples/golden_helix_probe.rs @@ -0,0 +1,145 @@ +//! Golden-helix anti-theater probe — does the irrational (golden-angle) sampling +//! and Fisher-z percentile rank earn their keep, or is the 2×2/4×4 perturbation +//! just "eigenvalue theater"? +//! +//! Two load-bearing claims from the architecture, each measured against a null: +//! +//! 1. COLLAPSE-AVOIDANCE (the Fujifilm-X-Trans / golden-ratio point). +//! The helix places nodes on a hemisphere via the golden angle +//! `γ = π(3−√5)`, `θ = ½·arccos(1 − 2(n+0.5)/N)`, `φ = n·γ`. An irrational +//! stride is a low-discrepancy sampler: it should MAXIMISE the minimum +//! nearest-neighbour gap (no two nodes collapse together) and keep the +//! nearest-neighbour distances uniform (low coefficient of variation), +//! BEATING both a regular (θ,φ) grid (which clumps at the pole) and uniform +//! random (which clumps everywhere — Poisson). If golden does NOT beat both, +//! the irrational stride is theater. Measured: min-gap (bigger = better) and +//! NN-distance CoV (smaller = more uniform). +//! +//! 2. NO-COSINE NORMALISED KEY (palette256 Fisher-z Prozentrang). +//! `fisher_z(s) = ½·ln((1+s)/(1-s)) = arctanh(s)` is strictly monotone in the +//! cosine `s`, and a percentile rank of a monotone transform is monotone in +//! `s` too — so the rank preserves EVERY pairwise similarity ordering +//! (Spearman = 1) while being a normalised [0,1] key you can compare directly +//! without ever re-materialising a cosine. Fisher-z additionally stretches +//! the rim (high-|s|) so equal rank steps carry equal discriminability. +//! Measured: ordering preservation, and rim-vs-centre resolution gain. +//! +//! 
cargo run --release --example golden_helix_probe + +const GAMMA: f64 = std::f64::consts::PI * (3.0 - 2.2360679774997896); // π(3−√5), golden angle + +/// Golden-spiral hemisphere unit vectors (the helix node directions). +fn golden_hemisphere(n: usize) -> Vec<[f64; 3]> { + (0..n) + .map(|i| { + let theta = 0.5 * (1.0 - 2.0 * (i as f64 + 0.5) / n as f64).acos(); // polar ∈ [0, π/2] + let phi = (i as f64 * GAMMA) % (2.0 * std::f64::consts::PI); + [theta.sin() * phi.cos(), theta.sin() * phi.sin(), theta.cos()] + }) + .collect() +} + +/// Regular (θ,φ) grid on the hemisphere — the "rational stride" null (clumps at pole). +fn regular_hemisphere(n: usize) -> Vec<[f64; 3]> { + let side = (n as f64).sqrt().round() as usize; + let mut v = Vec::with_capacity(side * side); + for a in 0..side { + for b in 0..side { + let theta = 0.5 * std::f64::consts::PI * (a as f64 + 0.5) / side as f64; + let phi = 2.0 * std::f64::consts::PI * (b as f64 + 0.5) / side as f64; + v.push([theta.sin() * phi.cos(), theta.sin() * phi.sin(), theta.cos()]); + } + } + v +} + +/// Uniform-random hemisphere (area-correct) — the Poisson-clumping null. +fn random_hemisphere(n: usize, seed: &mut u64) -> Vec<[f64; 3]> { + let mut u = || { + *seed = seed.wrapping_add(0x9E37_79B9_7F4A_7C15); + let mut z = *seed; + z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + ((z ^ (z >> 31)) >> 11) as f64 / (1u64 << 53) as f64 + }; + (0..n) + .map(|_| { + let z = u(); // cosθ uniform in [0,1] ⇒ area-uniform on the hemisphere + let phi = 2.0 * std::f64::consts::PI * u(); + let r = (1.0 - z * z).sqrt(); + [r * phi.cos(), r * phi.sin(), z] + }) + .collect() +} + +/// (min nearest-neighbour angle, CoV of nearest-neighbour angles). Angle = great-circle. 
+fn nn_stats(pts: &[[f64; 3]]) -> (f64, f64) { + let n = pts.len(); + let mut nn = vec![f64::INFINITY; n]; + for i in 0..n { + for j in (i + 1)..n { + let dot = (pts[i][0] * pts[j][0] + pts[i][1] * pts[j][1] + pts[i][2] * pts[j][2]).clamp(-1.0, 1.0); + let ang = dot.acos(); + if ang < nn[i] { + nn[i] = ang; + } + if ang < nn[j] { + nn[j] = ang; + } + } + } + let min = nn.iter().cloned().fold(f64::INFINITY, f64::min); + let mean = nn.iter().sum::<f64>() / n as f64; + let var = nn.iter().map(|d| (d - mean).powi(2)).sum::<f64>() / n as f64; + (min, var.sqrt() / mean) // (min gap, coefficient of variation) +} + +fn fisher_z(s: f64) -> f64 { + let s = s.clamp(-1.0 + 1e-9, 1.0 - 1e-9); + 0.5 * ((1.0 + s) / (1.0 - s)).ln() +} + +fn main() { + println!("== Golden-helix anti-theater probe ==\n"); + + println!("[1] Collapse-avoidance — min NN gap (rad, BIGGER better) + CoV (SMALLER better):"); + println!(" N golden(min/CoV) regular(min/CoV) random(min/CoV) golden wins?"); + let mut seed = 0xABCDEF; + for &n in &[16usize, 64, 256, 1024] { + let (gm, gc) = nn_stats(&golden_hemisphere(n)); + let (rm, rc) = nn_stats(&regular_hemisphere(n)); + let (xm, xc) = nn_stats(&random_hemisphere(n, &mut seed)); + // "Wins" = golden has the largest min-gap AND the lowest CoV. + let wins = gm >= rm && gm >= xm && gc <= rc && gc <= xc; + println!( + " {n:>5} {gm:.4}/{gc:.3} {rm:.4}/{rc:.3} {xm:.4}/{xc:.3} {}", + if wins { "YES" } else { "no" } + ); + } + + println!("\n[2] Fisher-z percentile rank as a no-cosine normalised key:"); + // A deterministic spread of cosine similarities in (−1, 1). + let mut sims: Vec<f64> = (0..1000) + .map(|i| -0.999 + 1.998 * (i as f64 + 0.5) / 1000.0) + .collect(); + // Percentile rank of fisher_z(s). Both fisher_z and ranking are monotone in s, + // so the rank order must equal the cosine order — verify (Spearman == 1). 
+ let mut idx: Vec<usize> = (0..sims.len()).collect(); + idx.sort_by(|&a, &b| fisher_z(sims[a]).partial_cmp(&fisher_z(sims[b])).unwrap()); + let inversions = idx.windows(2).filter(|w| sims[w[0]] > sims[w[1]]).count(); + println!( + " rank-order vs cosine-order inversions: {inversions} (0 ⇒ ordering fully preserved, no cosine needed)" + ); + + // Rim-stretch: resolution (Δz per unit Δs) near the rim vs the centre. + sims.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let res = |s: f64| (fisher_z(s + 0.01) - fisher_z(s - 0.01)) / 0.02; + let centre = res(0.0); + let rim = res(0.9); + println!( + " Fisher-z resolution: centre(s=0.0) = {centre:.2}/unit, rim(s=0.9) = {rim:.2}/unit → rim gets {:.1}× more bits", + rim / centre + ); + println!(" ⇒ percentile rank ∈ [0,1] is a normalised similarity key; compare ranks directly,"); + println!(" never re-materialising cosine, with extra resolution where similarity is high."); +} diff --git a/examples/morton_cascade_probe.rs b/examples/morton_cascade_probe.rs new file mode 100644 index 00000000..1afcb298 --- /dev/null +++ b/examples/morton_cascade_probe.rs @@ -0,0 +1,202 @@ +//! Morton 2×2 cascade probe — non-materialized Z-order quadtree over the +//! gridlake SoA carrier, with Belichtungsmesser-style min/max early-exit. +//! +//! ## What this validates (probe-first, codec-agnostic) +//! +//! - **4×4 Morton leaf tile = one `F32x16`** ("2bit×2bit": 2 bits X, 2 bits Y = +//! 16 cells = 64 bytes = one AVX-512 register loaded from `MultiLaneColumn`). +//! - **Quadtree over T×T tiles** (2×2 per level): total grid `(4T)²` for +//! `T = 2^k` gives the ladder 64, 256, 1024, 4096, 16384, 64k, 256k. +//! - **Morton order ⇒ every quadtree node is a contiguous index range**, so the +//! aggregate (min/max) pyramid is a flat bottom-up reduction (the +//! Belichtungsmesser "calibrated bands" = per-node value range). +//! - **Cascade early-exit**: descend the pyramid; if a node's [min,max] can't +//!
intersect the query band [q−r, q+r], prune the whole subtree (the 3-stroke +//! band-miss generalized). Leaf tile → `F32x16` load + test 16 cells. +//! +//! The cell value is a plain `f32` stand-in for the eventual per-cell codec +//! (palette256 / helix Fisher-2z / Belichtungsmesser band) — wiring that is the +//! next step; this probe proves the *substrate* (addressing + cascade) is +//! correct and that the prune actually skips work. +//! +//! RUSTFLAGS="-C target-cpu=native" cargo run --release --example morton_cascade_probe +//! +//! PASS: cascade count == brute-force count for every (size, query); the +//! reported prune-rate is the "boost". + +use std::sync::Arc; + +use ndarray::simd::{F32x16, MultiLaneColumn}; + +/// Interleave the low `bits` of `x` and `y` into a Z-order (Morton) index. +/// x occupies even output bits, y the odd bits. +fn morton2d(x: u32, y: u32, bits: u32) -> u32 { + let mut m = 0u32; + for b in 0..bits { + m |= ((x >> b) & 1) << (2 * b); + m |= ((y >> b) & 1) << (2 * b + 1); + } + m +} + +/// Deterministic SplitMix64 → f32 in [0, 1). +fn splitmix(state: &mut u64) -> f32 { + *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15); + let mut z = *state; + z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + z ^= z >> 31; + ((z >> 40) as f32) / (1u32 << 24) as f32 +} + +/// Build the field in Morton-tile order: +/// cell(x,y) → idx = morton(tx,ty)·16 + morton_in_tile(ix,iy) +/// where tile (tx,ty) = (x>>2, y>>2), in-tile (ix,iy) = (x&3, y&3). +/// Each 4×4 tile is therefore a contiguous 16-f32 (64-byte) `F32x16` chunk. 
+fn build_field(t: u32, seed: u64) -> Vec<f32> { + let side = 4 * t; // grid side + let n = (side * side) as usize; + let mut field = vec![0.0f32; n]; + let k = t.trailing_zeros(); // T = 2^k tiles per side + let mut st = seed; + for y in 0..side { + for x in 0..side { + let (tx, ty) = (x >> 2, y >> 2); + let (ix, iy) = (x & 3, y & 3); + let idx = (morton2d(tx, ty, k) as usize) * 16 + morton2d(ix, iy, 2) as usize; + // A smooth-ish field + noise so neighbouring tiles share value ranges + // (otherwise every node spans the full range and nothing prunes). + let base = ((x as f32) / side as f32 + (y as f32) / side as f32) * 0.5; + field[idx] = 0.85 * base + 0.15 * splitmix(&mut st); + } + } + field +} + +/// Aggregate (min,max) pyramid over the T² tiles in Morton order. Level 0 = +/// per-tile range (over its 16 cells); level l = range of 4 level-(l−1) nodes. +/// A node at level l covers `4^l` contiguous tiles starting at `base`. +struct Pyramid { + levels: Vec<Vec<(f32, f32)>>, // levels[0] = per-tile, levels[K] = root + k: u32, // number of quadtree levels (T = 2^k) +} + +impl Pyramid { + fn build(field: &[f32], t: u32) -> Self { + let k = t.trailing_zeros(); + let n_tiles = (t * t) as usize; + // Level 0: min/max over each tile's 16 cells.
+ let mut lvl0 = Vec::with_capacity(n_tiles); + for tile in 0..n_tiles { + let s = &field[tile * 16..tile * 16 + 16]; + let mut mn = f32::INFINITY; + let mut mx = f32::NEG_INFINITY; + for &v in s { + mn = mn.min(v); + mx = mx.max(v); + } + lvl0.push((mn, mx)); + } + let mut levels = vec![lvl0]; + for l in 1..=k as usize { + let prev = &levels[l - 1]; + let mut cur = Vec::with_capacity(prev.len() / 4); + for node in prev.chunks_exact(4) { + let mn = node.iter().map(|p| p.0).fold(f32::INFINITY, f32::min); + let mx = node.iter().map(|p| p.1).fold(f32::NEG_INFINITY, f32::max); + cur.push((mn, mx)); + } + levels.push(cur); + } + Pyramid { levels, k } + } +} + +/// Cascade query: count cells with |value − q| ≤ r, descending the quadtree and +/// pruning any node whose [min,max] can't intersect [q−r, q+r]. Returns +/// (count, cells_visited) — cells_visited only counts leaf cells actually tested. +fn cascade_count(field: &[f32], col: &MultiLaneColumn, pyr: &Pyramid, q: f32, r: f32) -> (usize, usize) { + let (lo, hi) = (q - r, q + r); + let mut count = 0usize; + let mut visited = 0usize; + // Stack of (level, node_index_within_level). + let mut stack = vec![(pyr.k as usize, 0usize)]; + let bytes = col.as_bytes(); + while let Some((level, node)) = stack.pop() { + let (mn, mx) = pyr.levels[level][node]; + if mx < lo || mn > hi { + continue; // band miss → prune whole subtree (early-exit) + } + if level == 0 { + // Leaf tile = 16 cells = one F32x16 chunk in the SoA column. 
+ let off = node * 64; // 16 f32 × 4 bytes + let chunk: [u8; 64] = bytes[off..off + 64].try_into().unwrap(); + let arr = f32x16_from_bytes(&chunk).to_array(); + for &v in arr.iter() { + if (v - q).abs() <= r { + count += 1; + } + } + visited += 16; + } else { + let base = node * 4; + for c in 0..4 { + stack.push((level - 1, base + c)); + } + } + } + let _ = field; // field kept for the brute-force reference; cascade reads the SoA column + (count, visited) +} + +/// Build an `F32x16` from 64 little-endian bytes (one 4×4 Morton tile). +fn f32x16_from_bytes(chunk: &[u8; 64]) -> F32x16 { + let arr: [f32; 16] = core::array::from_fn(|i| { + let o = i * 4; + f32::from_le_bytes([chunk[o], chunk[o + 1], chunk[o + 2], chunk[o + 3]]) + }); + F32x16::from_array(arr) +} + +fn brute_count(field: &[f32], q: f32, r: f32) -> usize { + field.iter().filter(|&&v| (v - q).abs() <= r).count() +} + +fn run(t: u32) { + let side = 4 * t; + let n = (side * side) as usize; + let field = build_field(t, 0xC0FFEE ^ t as u64); + // Wrap the Morton-ordered field bytes in the gridlake SoA carrier. + let raw: Vec<u8> = field.iter().flat_map(|v| v.to_le_bytes()).collect(); + let col = MultiLaneColumn::new(Arc::from(raw.into_boxed_slice())).unwrap(); + let pyr = Pyramid::build(&field, t); + + // Three query bands: tight (high prune), medium, broad (low prune).
+ let queries = [(0.5f32, 0.02f32), (0.25, 0.10), (0.5, 0.5)]; + let mut all_ok = true; + let mut report = String::new(); + for (q, r) in queries { + let exp = brute_count(&field, q, r); + let (got, visited) = cascade_count(&field, &col, &pyr, q, r); + let ok = got == exp; + all_ok &= ok; + let prune = 100.0 * (1.0 - visited as f64 / n as f64); + report.push_str(&format!( + " q={q:.2} r={r:.2}: count {got:>7} (exp {exp:>7}) {} visited {visited:>8}/{n} → prune {prune:5.1}%\n", + if ok { "OK " } else { "MISMATCH" } + )); + } + println!( + " T={t:>3} tiles/side grid {side}×{side} = {n:>7} cells ({})\n{}", + if all_ok { "CORRECT" } else { "WRONG" }, + report + ); +} + +fn main() { + println!("== Morton 2×2 quadtree cascade probe (4×4 tile = F32x16, gridlake SoA) ==\n"); + // T = 2^k → grid (4T)² = the ladder 64, 256, 1024, 4096, 16384, 64k, 256k. + for t in [2u32, 4, 8, 16, 32, 64, 128] { + run(t); + } +} diff --git a/src/backend/mod.rs b/src/backend/mod.rs index df71a701..82c43cda 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -169,15 +169,26 @@ pub fn cblas_dgemm( /// INT8 GEMM: C = A × B where A is u8, B is i8, C is i32. /// -/// Dispatch: AMX TDPBUSD → VNNI VPDPBUSD → scalar. +/// Dispatch: AMX `TDPBUSD` → AVX-512 VNNI `VPDPBUSD` (zmm) → AVX-VNNI +/// `VPDPBUSD` (ymm) → scalar, via [`crate::simd_int_ops::gemm_u8_i8`]. /// Same signature across all paths. +/// +/// Routing note: this surface is **u8 × i8 → i32** (the VNNI-native +/// unsigned-by-signed form). The 4-tier dispatcher that matches this +/// exact signedness is `gemm_u8_i8`; its scalar fallback is the same +/// [`crate::hpc::quantized::int8_gemm_i32`] reference the previous +/// 2-tier `vnni_gemm::int8_gemm_vnni` used, so results are +/// bit-equivalent. (The ArrayView-based `amx_matmul::matmul_i8_to_i32` +/// is **i8 × i8** and would reinterpret A's bytes ≥ 128 as negative — +/// not bit-equivalent here — so it is intentionally NOT used.) 
#[inline] #[allow(clippy::needless_return)] pub fn gemm_i8(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { - // VNNI path (Ice Lake, Sapphire Rapids, Zen 4) — includes AMX fallback + // AMX → VNNI-zmm → VNNI-ymm → scalar (Ice Lake / Sapphire Rapids / + // Granite Rapids / Zen 4 / Arrow Lake all covered by one call). #[cfg(feature = "std")] { - crate::hpc::vnni_gemm::int8_gemm_vnni(a, b, c, m, n, k); + crate::simd_int_ops::gemm_u8_i8(a, b, c, m, n, k); return; } #[cfg(not(feature = "std"))] @@ -189,22 +200,39 @@ pub fn gemm_i8(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) /// BF16 GEMM: C (f32) = A (BF16) × B (BF16), with f32 accumulation. /// -/// Dispatch: AMX TDPBF16PS → scalar tiled bf16_gemm_f32. +/// Dispatch: AMX `TDPBF16PS` → AVX-512 `VDPBF16PS` → scalar tiled +/// `bf16_gemm_f32`, via [`crate::hpc::amx_matmul::matmul_bf16_to_f32`]. /// Input: raw u16 slices representing BF16 values (same layout as -/// `ndarray::hpc::quantized::BF16`). +/// `ndarray::hpc::quantized::BF16`). `C` is overwritten (beta = 0, +/// alpha = 1), bit-equivalent to the scalar reference on hosts without +/// AMX / AVX-512BF16. #[inline] #[allow(clippy::needless_return)] pub fn gemm_bf16(a: &[u16], b: &[u16], c: &mut [f32], m: usize, n: usize, k: usize) { // Reinterpret u16 slices as BF16 slices (repr(transparent)) #[cfg(feature = "std")] { + use crate::{ArrayView2, ArrayViewMut2}; + let a_bf16: &[crate::hpc::quantized::BF16] = unsafe { - // SAFETY: BF16 is #[repr(transparent)] over u16 + // SAFETY: BF16 is #[repr(transparent)] over u16, so the bit + // pattern of every element is preserved by this reinterpret. core::slice::from_raw_parts(a.as_ptr() as *const crate::hpc::quantized::BF16, a.len()) }; let b_bf16: &[crate::hpc::quantized::BF16] = + // SAFETY: same repr(transparent) invariant as `a_bf16` above. 
unsafe { core::slice::from_raw_parts(b.as_ptr() as *const crate::hpc::quantized::BF16, b.len()) }; - crate::hpc::quantized::bf16_gemm_f32(a_bf16, b_bf16, c, m, n, k, 1.0, 0.0); + + // Wrap the raw row-major slices as 2-D views. The AMX dispatcher + // is ArrayView2-based; mirror the call shape used by + // `simd_runtime::matmul::matmul_bf16_to_f32` (lhs:(M,K), rhs:(K,N), + // out:(M,N), row-major contiguous). Slice to the exact element + // count so an over-long input slice does not fail the shape check + // (the old scalar path only required `len >= m*k`). + let lhs = ArrayView2::from_shape((m, k), &a_bf16[..m * k]).expect("gemm_bf16: A shape (m,k) vs slice len"); + let rhs = ArrayView2::from_shape((k, n), &b_bf16[..k * n]).expect("gemm_bf16: B shape (k,n) vs slice len"); + let out = ArrayViewMut2::from_shape((m, n), &mut c[..m * n]).expect("gemm_bf16: C shape (m,n) vs slice len"); + crate::hpc::amx_matmul::matmul_bf16_to_f32(lhs, rhs, out).expect("gemm_bf16: matmul shape contract"); return; } #[cfg(not(feature = "std"))] diff --git a/src/backend/native.rs b/src/backend/native.rs index ee14bbb7..69b25e77 100644 --- a/src/backend/native.rs +++ b/src/backend/native.rs @@ -266,15 +266,64 @@ pub fn gemm_f64( } // ─── GEMV dispatch ─────────────────────────────────────────────── - -/// GEMV: y = alpha * A * x + beta * y (f32) +// +// GEMV (`y = alpha·A·x + beta·y`, row-major) is a stack of row dot +// products: `y[i] = alpha · dot(A[i, ·], x) + beta · y[i]`. On the +// SIMD tiers we route each row through the already-dispatched, +// parity-tested `dot_f32` / `dot_f64` (FMA + two-accumulator reduce) +// instead of the scalar 1-wide fold. Each A row is contiguous in +// row-major (col stride = 1), so the contiguous `dot_*` loads apply +// directly; only the leading `n` of each `lda`-wide row are columns. +// +// The Scalar tier keeps calling `scalar::gemv_*` verbatim so a +// scalar-only build stays byte-identical to the prior reference. 
The +// AVX2/AVX-512 tiers carry the same 1-2 ULP reduction-order drift the +// rest of this module's SIMD BLAS-1 kernels already document (see the +// `nrm2` note above and the ULP-tolerant `td_t6_*` parity tests) — +// well within the BLAS tolerance GEMV consumers use. + +/// GEMV: y = alpha * A * x + beta * y (f32, row-major). +/// +/// SIMD tiers compute each row via [`dot_f32`]; the scalar tier uses +/// the byte-stable [`scalar::gemv_f32`] reference. pub fn gemv_f32(m: usize, n: usize, alpha: f32, a: &[f32], lda: usize, x: &[f32], beta: f32, y: &mut [f32]) { - scalar::gemv_f32(m, n, alpha, a, lda, x, beta, y); + if m == 0 { + return; // no rows ⇒ no-op; must not slice `x[..n]` (scalar ref returns too) + } + match tier() { + Tier::Scalar => scalar::gemv_f32(m, n, alpha, a, lda, x, beta, y), + // Avx512 + Avx2: per-row SIMD dot product. `dot_f32` itself + // dispatches to the active tier's kernel and is parity-tested. + _ => { + let xn = &x[..n]; + for i in 0..m { + let row = &a[i * lda..i * lda + n]; + let sum = dot_f32(row, xn); + y[i] = alpha * sum + beta * y[i]; + } + } + } } -/// GEMV: y = alpha * A * x + beta * y (f64) +/// GEMV: y = alpha * A * x + beta * y (f64, row-major). +/// +/// SIMD tiers compute each row via [`dot_f64`]; the scalar tier uses +/// the byte-stable [`scalar::gemv_f64`] reference. 
pub fn gemv_f64(m: usize, n: usize, alpha: f64, a: &[f64], lda: usize, x: &[f64], beta: f64, y: &mut [f64]) { - scalar::gemv_f64(m, n, alpha, a, lda, x, beta, y); + if m == 0 { + return; // no rows ⇒ no-op; must not slice `x[..n]` (scalar ref returns too) + } + match tier() { + Tier::Scalar => scalar::gemv_f64(m, n, alpha, a, lda, x, beta, y), + _ => { + let xn = &x[..n]; + for i in 0..m { + let row = &a[i * lda..i * lda + n]; + let sum = dot_f64(row, xn); + y[i] = alpha * sum + beta * y[i]; + } + } + } } // ═══════════════════════════════════════════════════════════════════ diff --git a/src/hpc/amx_matmul.rs b/src/hpc/amx_matmul.rs index 673eab4f..d8322455 100644 --- a/src/hpc/amx_matmul.rs +++ b/src/hpc/amx_matmul.rs @@ -29,28 +29,78 @@ pub struct TileConfig { } impl TileConfig { - /// Configure for TDPBUSD: C[16×16 i32] += A[16×k u8] × B[k×16 i8]. + /// Configure all three tiles for the int8/bf16 16×16 tile GEMM. Every tile + /// is 16 rows × 64 colbytes; the shapes are identical so the same config + /// serves C, the plain M×K operand, and the VNNI K×N operand. *Which* + /// operand lands in *which* tile is decided by the kernel — see + /// [`tile_dpbusd`] for the empirically-verified placement (VNNI K×N → tmm1, + /// plain M×K → tmm2, result → tmm0). /// - /// Tiles: - /// tmm0 = C (result): 16 rows × 64 bytes (16×16 i32) - /// tmm1 = A (left): 16 rows × 64 bytes (16×64 u8) - /// tmm2 = B (right): 16 rows × 64 bytes (transposed: 64×16 → 16×64) + /// Tiles (shapes): + /// tmm0 = C (result): 16 rows × 64 bytes (16×16 i32) + /// tmm1 = VNNI K×N (vvvv): kb/4 rows × 64 bytes + /// tmm2 = plain M×K (rm): 16 rows × kb bytes pub fn for_dpbusd(k_bytes: u16) -> Self { let mut cfg = TileConfig { data: [0u8; 64] }; cfg.data[0] = 1; // palette 1 + // byte 1 = start_row = 0; bytes 2-15 reserved = 0 (already zeroed). 
+ + // XTILECFG layout (Intel SDM, LDTILECFG memory operand): + // colsb[t] : u16 at offset 16 + 2*t (bytes per tile row, ≤ 64) + // rows[t] : u8 at offset 48 + t (rows per tile, ≤ 16) + // + // The previous version had these two regions SWAPPED — it wrote the + // row counts into the colsb region (offsets 16/17/18) and the column + // widths into the rows region (offsets 48/50/52). That produced + // colsb[0] = 0x1010 = 4112 and rows[0] = 64, both out of range, so + // LDTILECFG #GP-faulted (delivered as SIGSEGV) the instant the AMX + // path actually executed. (It never had — every AMX test early-returns + // when `amx_available()` is false, which it always was until the + // arch_prctl syscall-number fix in `simd_amx.rs`.) + let kb = k_bytes.min(64); // colsb ≤ 64 ⇒ each u16 high byte stays 0 + + // Tile 0 (C): 16 rows × 64 colbytes (16 × i32 per row = 64 bytes). + cfg.data[16] = 64; // colsb[0] low (u16 @ 16); high byte @17 stays 0 + cfg.data[48] = 16; // rows[0] (u8 @ 48) + + // Tile 1 (B, VNNI K×N → VEX.vvvv): kb/4 rows × 64 colbytes. The kernel + // loads the VNNI operand into tmm1, so tile 1 must carry the VNNI shape. + // (Was the plain 16×kb shape — equal to this only at kb=64; backwards + // for kb<64, which would mis-shape a tail kernel / external caller.) + cfg.data[18] = 64; // colsb[1] low (u16 @ 18); high byte @19 stays 0 + cfg.data[49] = (kb / 4) as u8; // rows[1] (u8 @ 49) + + // Tile 2 (A, plain M×K → ModRM.rm): 16 rows × kb colbytes. 
+ cfg.data[20] = kb as u8; // colsb[2] low (u16 @ 20); high byte @21 stays 0 + cfg.data[50] = 16; // rows[2] (u8 @ 50) - // Tile 0 (C): 16 rows × 64 bytes (16 × i32 per row = 64 bytes) - cfg.data[16] = 16; - cfg.data[48] = 64; - - // Tile 1 (A): 16 rows × k_bytes (capped at 64) - cfg.data[17] = 16; - cfg.data[50] = k_bytes.min(64) as u8; - - // Tile 2 (B): k_bytes/4 rows × 64 bytes (transposed layout) - cfg.data[18] = (k_bytes.min(64) / 4) as u8; - cfg.data[52] = 64; + cfg + } + /// Configure all EIGHT tiles (tmm0-7) as 16 rows × 64 colbytes, for the 2×2 + /// register-blocked int8 kernel: tmm0-3 = the four C accumulators, tmm4-5 = + /// two plain A row-blocks (rm/unsigned), tmm6-7 = two VNNI B col-blocks + /// (vvvv/signed). Every tile is 16×64 so one config serves all roles. Same + /// XTILECFG layout as [`Self::for_dpbusd`]: colsb[t] u16 @ 16+2t, rows[t] + /// u8 @ 48+t. + /// + /// # Examples + /// ```ignore + /// use ndarray::hpc::amx_matmul::{tile_loadconfig, tile_release, TileConfig}; + /// // SAFETY: requires AMX (gate on `amx_available()`); all 8 tiles are 16×64. + /// unsafe { + /// tile_loadconfig(&TileConfig::for_dpbusd_8()); + /// // load A→tmm4/tmm5, B-VNNI→tmm6/tmm7, zero tmm0-3, then tile_dpbusd_2x2() + /// tile_release(); + /// } + /// ``` + pub fn for_dpbusd_8() -> Self { + let mut cfg = TileConfig { data: [0u8; 64] }; + cfg.data[0] = 1; // palette 1 + for t in 0..8 { + cfg.data[16 + 2 * t] = 64; // colsb[t] = 64 (low byte; high stays 0) + cfg.data[48 + t] = 16; // rows[t] = 16 + } cfg } } @@ -94,16 +144,20 @@ pub unsafe fn tile_release() { /// Load tile from memory. /// -/// Encoding: `TILELOADD tmmN, [rcx + rax]` is VEX `C4 E2 7B 4B /r` with -/// a SIB byte selecting `[rcx + rax]`. The ModR/M `/r` field encodes the +/// Encoding: `TILELOADD tmmN, [rcx + rax]` is VEX `C4 E2 7B 4B /r` with a +/// SIB byte selecting base = `rcx` (the data pointer) and index = `rax` +/// (the row stride in bytes, scale 1). 
The ModR/M `/r` field encodes the /// destination tile via `reg = N` (3-bit tile index). Per-tile bytes: /// -/// tmm0: C4 E2 7B 4B **04** 08 -/// tmm1: C4 E2 7B 4B **0C** 08 -/// tmm2: C4 E2 7B 4B **14** 08 +/// tmm0: C4 E2 7B 4B **04** 01 +/// tmm1: C4 E2 7B 4B **0C** 01 +/// tmm2: C4 E2 7B 4B **14** 01 /// -/// `04 | (N << 3)` gives the ModR/M byte; the `08` SIB is the same -/// across tiles. tmm0 was added when codex flagged the accumulator- +/// `04 | (N << 3)` gives the ModR/M byte; SIB `01` = (scale=1, index=rax, +/// base=rcx) is the same across tiles. The previous SIB `08` had base and +/// index swapped (base=rax, index=rcx), so the tile engine dereferenced the +/// *stride value* (~64) as the start address and SIGSEGV'd the moment the +/// AMX path went live. tmm0 was added when codex flagged the accumulator- /// preservation bug on PR #184 (`tile_zero(0)` + `tile_store(0, c)` /// discarded any pre-existing C values — the fix is `tile_load(0, c)` /// instead of `tile_zero(0)` so TDPBUSD/TDPBF16PS truly accumulate as @@ -115,19 +169,51 @@ pub unsafe fn tile_release() { pub unsafe fn tile_load(tile: u8, ptr: *const u8, stride: usize) { match tile { 0 => asm!( - ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x04, 0x08", + ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x04, 0x01", in("rcx") ptr, in("rax") stride, options(nostack), ), 1 => asm!( - ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x0c, 0x08", + ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x0c, 0x01", in("rcx") ptr, in("rax") stride, options(nostack), ), 2 => asm!( - ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x14, 0x08", + ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x14, 0x01", + in("rcx") ptr, + in("rax") stride, + options(nostack), + ), + // ModRM = 0x04 | (tile << 3); SIB 0x01 unchanged. tmm3-7 added for the + // 2×2 register-blocked kernel (4 C accumulators + 2 A + 2 B tiles). 
+ 3 => asm!( + ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x1c, 0x01", + in("rcx") ptr, + in("rax") stride, + options(nostack), + ), + 4 => asm!( + ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x24, 0x01", + in("rcx") ptr, + in("rax") stride, + options(nostack), + ), + 5 => asm!( + ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x2c, 0x01", + in("rcx") ptr, + in("rax") stride, + options(nostack), + ), + 6 => asm!( + ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x34, 0x01", + in("rcx") ptr, + in("rax") stride, + options(nostack), + ), + 7 => asm!( + ".byte 0xc4, 0xe2, 0x7b, 0x4b, 0x3c, 0x01", in("rcx") ptr, in("rax") stride, options(nostack), @@ -136,16 +222,35 @@ pub unsafe fn tile_load(tile: u8, ptr: *const u8, stride: usize) { } } -/// Store tile to memory. +/// Store tile to memory. ModRM = 0x04 | (tile << 3); SIB 0x01 = base=rcx (ptr), +/// index=rax (stride), scale=1. tmm0-3 are the four C accumulators of the 2×2 +/// register-blocked kernel. /// /// # Safety /// Pointer must be valid and writable, stride must match. #[inline] pub unsafe fn tile_store(tile: u8, ptr: *mut u8, stride: usize) { match tile { - // TILESTORED [ptr + stride*row], tmm0 0 => asm!( - ".byte 0xc4, 0xe2, 0x7a, 0x4b, 0x04, 0x08", + ".byte 0xc4, 0xe2, 0x7a, 0x4b, 0x04, 0x01", + in("rcx") ptr, + in("rax") stride, + options(nostack), + ), + 1 => asm!( + ".byte 0xc4, 0xe2, 0x7a, 0x4b, 0x0c, 0x01", + in("rcx") ptr, + in("rax") stride, + options(nostack), + ), + 2 => asm!( + ".byte 0xc4, 0xe2, 0x7a, 0x4b, 0x14, 0x01", + in("rcx") ptr, + in("rax") stride, + options(nostack), + ), + 3 => asm!( + ".byte 0xc4, 0xe2, 0x7a, 0x4b, 0x1c, 0x01", in("rcx") ptr, in("rax") stride, options(nostack), @@ -154,18 +259,78 @@ pub unsafe fn tile_store(tile: u8, ptr: *mut u8, stride: usize) { } } -/// TDPBUSD: C += A(u8) × B(i8) → i32. -/// tmm0 += tmm1 × tmm2. +/// TDPBUSD: C(i32, tmm0) += unsigned ⊗ signed → i32. 16×16 output, 64 products +/// per element = 16384 MACs in ONE instruction. 
+/// +/// **Empirical operand convention** (measured on Emerald Rapids — the naive +/// SDM reading is mirrored on BOTH axes, so do not infer roles from the +/// mnemonic): +/// * INDEX: `dst[m][n] = Σ_k tmm2(ModRM.rm)[m][k] · tmm1(VEX.vvvv)[k][n]`. +/// The plain **M×K** operand goes in **tmm2**; the VNNI-packed **K×N** +/// operand goes in **tmm1** (the opposite of what the SDM operand order +/// suggests). +/// * SIGN: **tmm2 (rm) is the UNSIGNED operand, tmm1 (vvvv) is SIGNED.** +/// Verified by sweeping all four `TDPB**D` opcodes against sign-sensitive +/// constant inputs — only `0x71` gives rm=unsigned, vvvv=signed. +/// +/// So `int8_tile_gemm::amx_path` loads `A(u8) → tmm2` and `B_vnni(i8) → tmm1`. /// -/// 16×16 output, 64 products per element = 16384 MACs in ONE instruction. +/// Encoding: VEX.128.66.0F38.W0 5E /r. byte2 `0x71` = W0.vvvv=1110(tmm1).L0. +/// pp=01(66); ModRM `0xC2` = mod=11, reg=000 (tmm0 dst), rm=010 (tmm2). The +/// three tile operands MUST be distinct (tmm0/tmm1/tmm2 are). Two earlier +/// bugs in this file: byte2 `0x73` (F2 = signed×signed, wrong sign variant) +/// and ModRM `0xC1` (rm=tmm1, aliasing the two sources → same-tile #UD, the +/// first SIGILL the live AMX path hit). /// /// # Safety -/// Tiles must be loaded with valid data. +/// Tiles must be loaded with valid data and AMX OS-enabled (`amx_available()`). #[inline] pub unsafe fn tile_dpbusd() { - // TDPBUSD tmm0, tmm1, tmm2 - // VEX.128.F2.0F38.W0 5E C8+reg - asm!(".byte 0xc4, 0xe2, 0x73, 0x5e, 0xc1", options(nostack, nomem)); + asm!(".byte 0xc4, 0xe2, 0x71, 0x5e, 0xc2", options(nostack, nomem)); +} + +/// 2×2 register-blocked TDPBUSD — four accumulations in one call: +/// ```text +/// C00(tmm0) += A0(tmm4) ⊗ B0(tmm6) C01(tmm1) += A0(tmm4) ⊗ B1(tmm7) +/// C10(tmm2) += A1(tmm5) ⊗ B0(tmm6) C11(tmm3) += A1(tmm5) ⊗ B1(tmm7) +/// ``` +/// Same operand convention as [`tile_dpbusd`]: the plain M×K operand is rm +/// (A, unsigned), the VNNI K×N operand is vvvv (B, signed). 
Reusing the two A +/// and two B tile loads across four products halves the load bytes per MAC — +/// the lever for this memory-bandwidth-bound kernel. +/// +/// Encodings (VEX `C4 E2 5E `, byte2 = ((~vvvv & 0xF)<<3)|0x01, +/// modrm = 0xC0 | dst<<3 | rm): +/// C00 dst0 rm4 vvvv6 → C4 E2 49 5E C4 C01 dst1 rm4 vvvv7 → C4 E2 41 5E CC +/// C10 dst2 rm5 vvvv6 → C4 E2 49 5E D5 C11 dst3 rm5 vvvv7 → C4 E2 41 5E DD +/// All eight operand tiles (0/1/2/3 dst, 4/5 A, 6/7 B) are distinct → no #UD. +/// +/// # Examples +/// ```ignore +/// use ndarray::hpc::amx_matmul::*; +/// // SAFETY: requires AMX; full 32×32 register-blocked tile contract. +/// unsafe { +/// tile_loadconfig(&TileConfig::for_dpbusd_8()); +/// tile_zero(0); tile_zero(1); tile_zero(2); tile_zero(3); // C accumulators +/// tile_load(4, a0_ptr, k); tile_load(5, a1_ptr, k); // A rows (rm) +/// tile_load(6, b0_vnni, 64); tile_load(7, b1_vnni, 64); // B cols (vvvv) +/// tile_dpbusd_2x2(); // 4 TDPBUSDs +/// tile_store(0, c00, n * 4); /* … tmm1/2/3 → other quadrants … */ +/// tile_release(); +/// } +/// ``` +/// +/// # Safety +/// Tiles 0-7 configured (`TileConfig::for_dpbusd_8`) and 4/5/6/7 loaded. +#[inline] +pub unsafe fn tile_dpbusd_2x2() { + asm!( + ".byte 0xc4, 0xe2, 0x49, 0x5e, 0xc4", // C00 = A0·B0 + ".byte 0xc4, 0xe2, 0x41, 0x5e, 0xcc", // C01 = A0·B1 + ".byte 0xc4, 0xe2, 0x49, 0x5e, 0xd5", // C10 = A1·B0 + ".byte 0xc4, 0xe2, 0x41, 0x5e, 0xdd", // C11 = A1·B1 + options(nostack, nomem) + ); } /// TDPBF16PS: C += A(bf16) × B(bf16_vnni) → f32. @@ -174,9 +339,13 @@ pub unsafe fn tile_dpbusd() { /// 16×16 output accumulator (f32), 32 bf16 values per A row × 32 bf16 values /// per B row in VNNI layout = 512 mul-adds in one instruction. 
/// -/// Encoding (analogous to TDPBUSD, pp field flips F2→F3, opcode 5E→5C): -/// TDPBUSD tmm0, tmm1, tmm2 → C4 E2 73 5E C1 -/// TDPBF16PS tmm0, tmm1, tmm2 → C4 E2 72 5C C1 +/// Encoding (analogous to TDPBUSD: opcode 5E→5C, pp 66→F3): +/// TDPBUSD tmm0, tmm1, tmm2 → C4 E2 71 5E C2 +/// TDPBF16PS tmm0, tmm1, tmm2 → C4 E2 72 5C C2 +/// ModRM 0xC2 = mod=11, reg=000 (tmm0 dst), rm=010 (tmm2 src2); VEX.vvvv = +/// tmm1 (src1). The three tile operands MUST be distinct — the prior ModRM +/// 0xC1 set rm=tmm1, making src1==src2 → same-tile #UD (SIGILL) once the AMX +/// path actually executed. /// /// Tile shapes at K=32, M=N=16 (identical to TDPBUSD max at K_bytes=64): /// tmm0 (C): 16×16 f32 (16 rows × 64 bytes) @@ -188,7 +357,7 @@ pub unsafe fn tile_dpbusd() { /// and loaded with valid data; AMX must be OS-enabled (check `amx_available()`). #[inline] pub unsafe fn tile_dpbf16ps() { - asm!(".byte 0xc4, 0xe2, 0x72, 0x5c, 0xc1", options(nostack, nomem)); + asm!(".byte 0xc4, 0xe2, 0x72, 0x5c, 0xc2", options(nostack, nomem)); } /// Pack B[K, N] bf16 row-major into K/2 × (N*2) VNNI pairs (in-place target). @@ -684,9 +853,15 @@ mod tests { #[test] fn test_tile_config_creation() { let cfg = TileConfig::for_dpbusd(64); - assert_eq!(cfg.data[0], 1); // palette - assert_eq!(cfg.data[16], 16); // tile 0 rows - assert_eq!(cfg.data[48], 64); // tile 0 colbytes + assert_eq!(cfg.data[0], 1, "palette 1"); + // Intel SDM XTILECFG: colsb[t] is u16 @ 16+2t, rows[t] is u8 @ 48+t. + // All three tiles are 16 rows × 64 colbytes at k_bytes = 64. 
+ assert_eq!(u16::from_le_bytes([cfg.data[16], cfg.data[17]]), 64, "tile0 colsb"); + assert_eq!(cfg.data[48], 16, "tile0 rows"); + assert_eq!(u16::from_le_bytes([cfg.data[18], cfg.data[19]]), 64, "tile1 colsb"); + assert_eq!(cfg.data[49], 16, "tile1 rows (k_bytes/4)"); + assert_eq!(u16::from_le_bytes([cfg.data[20], cfg.data[21]]), 64, "tile2 colsb (k_bytes)"); + assert_eq!(cfg.data[50], 16, "tile2 rows"); } #[test] @@ -696,11 +871,14 @@ mod tests { return; } unsafe { - // Minimal config: just tile 0, 1 row × 4 bytes + // Minimal valid tile 0: 1 row × 4 colbytes, using the CORRECTED + // XTILECFG offsets (colsb[t] u16 @ 16+2t, rows[t] u8 @ 48+t). The + // old code wrote data[16]=1/data[48]=4 which under the fixed layout + // means colsb=1/rows=4 — still valid, but mislabeled; now explicit. let mut cfg = TileConfig { data: [0u8; 64] }; cfg.data[0] = 1; // palette 1 - cfg.data[16] = 1; // tile 0: 1 row - cfg.data[48] = 4; // tile 0: 4 colbytes + cfg.data[16] = 4; // colsb[0] = 4 bytes (u16 @ 16) + cfg.data[48] = 1; // rows[0] = 1 row (u8 @ 48) tile_loadconfig(&cfg); // TILEZERO tmm0 diff --git a/src/hpc/bf16_tile_gemm.rs b/src/hpc/bf16_tile_gemm.rs index 60f6b9ea..c84c8bec 100644 --- a/src/hpc/bf16_tile_gemm.rs +++ b/src/hpc/bf16_tile_gemm.rs @@ -83,11 +83,16 @@ unsafe fn amx_path(a_bf16: &[u16], b_vnni: &[u16], c: &mut [f32], k: usize) { let a_stride = (k * 2) as usize; // full A row stride in bytes (bf16 = 2B) let b_stride = 64usize; // VNNI row stride in bytes + // Operand placement (verified empirically — see `tile_dpbusd` doc): the + // AMX convention is dst[m][n] = Σ_k rm[m][k] · vvvv[k][n], with rm = + // ModRM (tmm2) = plain M×K and vvvv (tmm1) = VNNI K×N. So B (VNNI) → tmm1 + // and A (plain) → tmm2. (bf16 has no signed/unsigned split, so the + // single TDPBF16PS variant suffices once the operands are placed right.)
for kb in 0..k_blocks { let a_ptr = a_bf16.as_ptr().add(kb * 32) as *const u8; let b_ptr = b_vnni.as_ptr().add(kb * 16 * 32) as *const u8; - tile_load(1, a_ptr, a_stride); - tile_load(2, b_ptr, b_stride); + tile_load(1, b_ptr, b_stride); // B (VNNI) → tmm1 (vvvv) + tile_load(2, a_ptr, a_stride); // A (plain) → tmm2 (rm) tile_dpbf16ps(); } diff --git a/src/hpc/int8_tile_gemm.rs b/src/hpc/int8_tile_gemm.rs index ee778531..9a732469 100644 --- a/src/hpc/int8_tile_gemm.rs +++ b/src/hpc/int8_tile_gemm.rs @@ -24,8 +24,8 @@ //! supports natively. use crate::hpc::amx_matmul::{ - amx_available, tile_dpbusd, tile_load, tile_loadconfig, tile_release, tile_store, tile_zero, vnni_pack_i8, - TileConfig, + amx_available, tile_dpbusd, tile_dpbusd_2x2, tile_load, tile_loadconfig, tile_release, tile_store, tile_zero, + vnni_pack_i8, TileConfig, }; // ═════════════════════════════════════════════════════════════════════ @@ -95,14 +95,21 @@ unsafe fn amx_path(a_u8: &[u8], b_vnni: &[i8], c: &mut [i32], k: usize) { let a_stride = k; // bytes per A row (u8 = 1 byte each) let b_stride = 64usize; // VNNI: 16 columns × 4 bytes per row + // Operand placement (verified empirically on Emerald Rapids — see the + // `tile_dpbusd` doc): the AMX operand convention is the mirror of the + // naive SDM reading. The plain M×K operand goes in tmm2 (ModRM.rm) and is + // treated UNSIGNED; the VNNI K×N operand goes in tmm1 (VEX.vvvv) and is + // treated SIGNED. TDPBUSD then computes + // dst[m][n] = Σ_k a_u8(rm, unsigned)[m][k] · b_i8(vvvv, signed)[k][n] + // — exactly the u8 × i8 this kernel promises. for kb in 0..k_blocks { let a_ptr = a_u8.as_ptr().add(kb * 64); // B sits in VNNI layout: K/4 outer rows × 64 bytes. Each // 64-K-element block spans 16 outer rows × 64 bytes = 1024 // bytes. 
let b_ptr = b_vnni.as_ptr().add(kb * 16 * 64) as *const u8; - tile_load(1, a_ptr, a_stride); - tile_load(2, b_ptr, b_stride); + tile_load(1, b_ptr, b_stride); // B (VNNI) → tmm1 (vvvv, signed) + tile_load(2, a_ptr, a_stride); // A (plain) → tmm2 (rm, unsigned) tile_dpbusd(); } @@ -362,28 +369,254 @@ pub fn int8_gemm_amx_tiled(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: debug_assert_eq!(n % 16, 0, "int8_gemm_amx_tiled: N must be multiple of 16"); debug_assert_eq!(k % 64, 0, "int8_gemm_amx_tiled: K must be multiple of 64"); - let mut b_tile = vec![0i8; k * 16]; - let mut tile_c = vec![0i32; 256]; + // With `rayon`, fan the M/16 row-tiles across the pool — but only for LARGE + // GEMMs. This AMX kernel is memory-bandwidth-bound, so on a few-core host the + // cores contend for bandwidth and row-tile parallelism scales sublinearly + // (~1.4× at 2048³ on 4 cores); below ~2 GMAC the thread-dispatch + shared + // B-prepack overhead actually REGRESSES it (measured: 512³ 125→73 GMAC/s). + // The threshold keeps the fast serial path for small/medium shapes and only + // parallelizes where it nets a win (and where many-core servers gain most). + // The per-tile kernel is byte-for-byte the validated serial one, so + // correctness is unchanged. (AMX permission is process-wide; the tile CONFIG + // is per-thread CPU state, so each worker runs its own LDTILECFG — see + // `int8_gemm_amx_tiled_par`.) + #[cfg(feature = "rayon")] + { + let work = (m as u64).saturating_mul(n as u64).saturating_mul(k as u64); + if m >= 32 && work >= 2_000_000_000 { + int8_gemm_amx_tiled_par(a_u8, b_i8, c, m, n, k); + return; + } + } + // 2×2 register blocking halves the load bytes/MAC — the right lever for this + // bandwidth-bound kernel — whenever there is at least one full 32×32 block. 
+ if m >= 32 && n >= 32 { + int8_gemm_amx_tiled_rb(a_u8, b_i8, c, m, n, k); + return; + } + int8_gemm_amx_tiled_serial(a_u8, b_i8, c, m, n, k); +} - for j_tile in (0..n).step_by(16) { - // Pack B[0..k, j_tile..j_tile+16] into 16-wide K-rows - // (contiguous memory for int8_tile_gemm_16x16's input shape). - // Safe slicing — the row..row+16 range is bounded by - // `b_i8.len() >= k * n` asserted at function entry. - for kk in 0..k { - let row = kk * n + j_tile; - b_tile[kk * 16..(kk + 1) * 16].copy_from_slice(&b_i8[row..row + 16]); +/// Single-thread AMX int8 tiled GEMM — the validated core kernel. LDTILECFG and +/// the per-band VNNI pack are hoisted out of the M/16 × N/16 tile loops (1 +/// LDTILECFG total, not one per 16×16 tile); the 16×16 result tile is +/// TILESTOREd straight into its strided slot in `c` (row pitch n·4 bytes). +fn int8_gemm_amx_tiled_serial(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { + let mut b_tile = vec![0i8; k * 16]; // one column band of B (row-major K×16) + let mut b_vnni = vec![0i8; k * 16]; // its VNNI-quad packing, reused across i-tiles + let k_blocks = k / 64; + + // SAFETY: caller asserted `amx_available()` + 16/16/64 alignment + slice + // bounds. The tile config is loaded once and released once; every + // tile_load/tile_store stays inside the a_u8 / b_vnni / c bounds (16×64-byte + // tiles, K in 64-wide blocks, strided store row pitch n·4 bytes). + unsafe { + let cfg = TileConfig::for_dpbusd(64); + tile_loadconfig(&cfg); + + for j_tile in (0..n).step_by(16) { + // Pack B[:, j_tile..+16] (row-major K×16) then VNNI-quad it ONCE + // per column band; reused across all M/16 row tiles below. 
+ for kk in 0..k { + let row = kk * n + j_tile; + b_tile[kk * 16..(kk + 1) * 16].copy_from_slice(&b_i8[row..row + 16]); + } + vnni_pack_i8(&b_tile, &mut b_vnni, k, 16); + + for i_tile in (0..m).step_by(16) { + let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k]; + tile_zero(0); // C tile = 0 (this driver overwrites, not accumulates) + for kb in 0..k_blocks { + // B(VNNI K×N) → tmm1 (vvvv, signed); A(plain M×K) → tmm2 (rm, unsigned). + tile_load(1, b_vnni.as_ptr().add(kb * 16 * 64) as *const u8, 64); + tile_load(2, a_tile.as_ptr().add(kb * 64), k); + tile_dpbusd(); + } + // Store tmm0 (16×16 i32) straight into the strided C location — + // row pitch n·4 bytes — with no scratch buffer or copy loop. + let c_ptr = c.as_mut_ptr().add(i_tile * n + j_tile) as *mut u8; + tile_store(0, c_ptr, n * 4); + } } - for i_tile in (0..m).step_by(16) { + + tile_release(); + } +} + +/// Rayon-parallel AMX int8 tiled GEMM. B is VNNI-packed ONCE into a shared, +/// read-only buffer (all N/16 column bands), then the M/16 row-tiles are fanned +/// across the rayon pool — one task per 16-row block of `c`. Each worker runs +/// the same validated tile sequence as the serial path. +#[cfg(feature = "rayon")] +fn int8_gemm_amx_tiled_par(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { + use rayon::prelude::*; + + let n_jtiles = n / 16; + let k_blocks = k / 64; + let band = k * 16; // bytes per VNNI-packed column band + + // Pre-pack every B column band into one shared VNNI buffer (read-only in the + // parallel region). O(K·N) — cheap vs the O(M·N·K) GEMM. + let mut b_vnni_all = vec![0i8; n_jtiles * band]; + { + let mut b_tile = vec![0i8; band]; + for jt in 0..n_jtiles { + let j_tile = jt * 16; + for kk in 0..k { + let row = kk * n + j_tile; + b_tile[kk * 16..(kk + 1) * 16].copy_from_slice(&b_i8[row..row + 16]); + } + vnni_pack_i8(&b_tile, &mut b_vnni_all[jt * band..(jt + 1) * band], k, 16); + } + } + + // One task per 16-row block of C. 
`c[..m*n]` guarantees exactly m/16 chunks. + c[..m * n] + .par_chunks_mut(16 * n) + .enumerate() + .for_each(|(it, c_rows)| { + let i_tile = it * 16; let a_tile = &a_u8[i_tile * k..(i_tile + 16) * k]; - tile_c.fill(0); - int8_tile_gemm_16x16(a_tile, &b_tile, &mut tile_c, k); - // Write tile_c (16 × 16, row-major) into c (M × N, row-major). - for ii in 0..16 { - let dst_off = (i_tile + ii) * n + j_tile; - c[dst_off..dst_off + 16].copy_from_slice(&tile_c[ii * 16..(ii + 1) * 16]); + // SAFETY: AMX permission is process-wide (arch_prctl granted once via the + // `amx_available()` LazyLock the caller already triggered) and inherited + // by every thread; the tile CONFIG is per-thread CPU state, so this + // worker loads its own config and releases it. `b_vnni_all` is read-only + // and shared; `c_rows` is this task's exclusive 16-row slice. All + // loads/stores stay within bounds (a_tile is 16×k; the strided store row + // pitch is n·4 bytes within this 16-row chunk). + unsafe { + let cfg = TileConfig::for_dpbusd(64); + tile_loadconfig(&cfg); + for jt in 0..n_jtiles { + let j_tile = jt * 16; + let b_vnni = &b_vnni_all[jt * band..(jt + 1) * band]; + tile_zero(0); + for kb in 0..k_blocks { + tile_load(1, b_vnni.as_ptr().add(kb * 16 * 64) as *const u8, 64); + tile_load(2, a_tile.as_ptr().add(kb * 64), k); + tile_dpbusd(); + } + let c_ptr = c_rows.as_mut_ptr().add(j_tile) as *mut u8; + tile_store(0, c_ptr, n * 4); + } + tile_release(); + } + }); +} + +/// 2×2 register-blocked AMX int8 GEMM (single thread). Tiles the M×N output +/// into 32×32 blocks computed with [`tile_dpbusd_2x2`] — four C accumulators +/// (tmm0-3) fed by two A row-tiles (tmm4-5) and two B col-tiles (tmm6-7), so +/// each A/B tile load serves TWO products (half the tile loads per MAC). 
+/// +/// Loop order is BLIS-style for cache behaviour: OUTER over 32-col panels — +/// pack just that panel's two B bands (≈2·K·16 bytes, kept L2-resident) and +/// reuse them across every row-block — INNER over 32-row blocks. This halves +/// the tile loads (register reuse) AND halves A's DRAM re-reads vs the 16×16 +/// kernel (32-col panels ⇒ half as many full-A sweeps). The earlier version +/// pre-packed ALL of B (≈4 MB at 2048²) and thrashed cache, regressing large +/// shapes — this packs only the live panel. 16-wide M/N remainders (m or n ≡ +/// 16 mod 32) finish on the validated 16×16 path. Overwrites `c`. +fn int8_gemm_amx_tiled_rb(a_u8: &[u8], b_i8: &[i8], c: &mut [i32], m: usize, n: usize, k: usize) { + let band = k * 16; // bytes per VNNI-packed 16-col band + let k_blocks = k / 64; + let m32 = (m / 32) * 32; // rows covered by full 32-row blocks + let n32 = (n / 32) * 32; // cols covered by full 32-col blocks + let pitch = n * 4; // C row pitch in bytes + + // The live panel's two B bands + a gather scratch. Packed once per panel, + // reused across all row-blocks (hot in L2). + let mut b0 = vec![0i8; band]; + let mut b1 = vec![0i8; band]; + let mut t = vec![0i8; band]; + + // SAFETY: caller asserted amx_available + 16/16/64 alignment + slice bounds. + // Eight tiles are configured (for_dpbusd_8); every tile_load/tile_store is a + // 16×64 tile within the a_u8 / b0 / b1 / c allocations; all stores use the + // full-matrix row pitch n·4. b0/b1 raw pointers stay valid across the inner + // loops (the bands are repacked only at the top of the next panel). + unsafe { + tile_loadconfig(&TileConfig::for_dpbusd_8()); + + // ── Full 32-col panels ── + let mut j2 = 0; + while j2 < n32 { + // Pack this panel's two 16-col B bands once; reused across all i2. 
+ for kk in 0..k { + t[kk * 16..(kk + 1) * 16].copy_from_slice(&b_i8[kk * n + j2..kk * n + j2 + 16]); + } + vnni_pack_i8(&t, &mut b0, k, 16); + for kk in 0..k { + t[kk * 16..(kk + 1) * 16].copy_from_slice(&b_i8[kk * n + j2 + 16..kk * n + j2 + 32]); + } + vnni_pack_i8(&t, &mut b1, k, 16); + let b0p = b0.as_ptr() as *const u8; + let b1p = b1.as_ptr() as *const u8; + + // 32×32 register-blocked blocks down the rows. + let mut i2 = 0; + while i2 < m32 { + let a0 = a_u8.as_ptr().add(i2 * k); + let a1 = a_u8.as_ptr().add((i2 + 16) * k); + tile_zero(0); + tile_zero(1); + tile_zero(2); + tile_zero(3); + for kb in 0..k_blocks { + tile_load(4, a0.add(kb * 64), k); + tile_load(5, a1.add(kb * 64), k); + tile_load(6, b0p.add(kb * 16 * 64), 64); + tile_load(7, b1p.add(kb * 16 * 64), 64); + tile_dpbusd_2x2(); + } + let cp = c.as_mut_ptr(); + tile_store(0, cp.add(i2 * n + j2) as *mut u8, pitch); + tile_store(1, cp.add(i2 * n + j2 + 16) as *mut u8, pitch); + tile_store(2, cp.add((i2 + 16) * n + j2) as *mut u8, pitch); + tile_store(3, cp.add((i2 + 16) * n + j2 + 16) as *mut u8, pitch); + i2 += 32; + } + + // Bottom strip for THIS panel: rows [m32..m) × cols [j2..j2+32), + // two 16×16 tiles reusing the already-packed b0 / b1. + if m32 < m { + let ap = a_u8.as_ptr().add(m32 * k); + for (off, bv) in [(0usize, b0p), (16usize, b1p)] { + tile_zero(0); + for kb in 0..k_blocks { + tile_load(1, bv.add(kb * 16 * 64), 64); + tile_load(2, ap.add(kb * 64), k); + tile_dpbusd(); + } + tile_store(0, c.as_mut_ptr().add(m32 * n + j2 + off) as *mut u8, pitch); + } + } + j2 += 32; + } + + // ── Right strip: cols [n32..n) (one 16-band), ALL rows (covers the + // bottom-right corner too). 
── + if n32 < n { + for kk in 0..k { + t[kk * 16..(kk + 1) * 16].copy_from_slice(&b_i8[kk * n + n32..kk * n + n32 + 16]); + } + vnni_pack_i8(&t, &mut b0, k, 16); + let bvp = b0.as_ptr() as *const u8; + let mut i = 0; + while i < m { + let ap = a_u8.as_ptr().add(i * k); + tile_zero(0); + for kb in 0..k_blocks { + tile_load(1, bvp.add(kb * 16 * 64), 64); + tile_load(2, ap.add(kb * 64), k); + tile_dpbusd(); + } + tile_store(0, c.as_mut_ptr().add(i * n + n32) as *mut u8, pitch); + i += 16; } } + + tile_release(); } } diff --git a/src/simd.rs b/src/simd.rs index 4176a822..6541fd85 100644 --- a/src/simd.rs +++ b/src/simd.rs @@ -579,6 +579,10 @@ pub use crate::hpc::heel_f64x8::cosine_f32_to_f64_simd; // exposes the runtime tier check for reporting. #[cfg(feature = "std")] pub use crate::hpc::amx_matmul::{amx_available, matmul_i8_to_i32}; +// CPU-generation detection (cached): SPR / EMR / GNR / Sierra Forest. Lets a +// consumer report which silicon a run landed on and distinguish "no AMX +// silicon" from "AMX present but not OS-enabled" — both surface via `amx_report`. +pub use crate::simd_amx::{amx_report, cpu_model, CpuModel}; // Elementwise slice ops — polyfill-dispatched (F32x16/F64x8 chunks + scalar tail). #[cfg(feature = "std")] diff --git a/src/simd_amx.rs b/src/simd_amx.rs index 2e41857d..ffb3f6d1 100644 --- a/src/simd_amx.rs +++ b/src/simd_amx.rs @@ -1,51 +1,148 @@ -//! AMX (Advanced Matrix Extensions) — confirmed working via inline asm on stable Rust 1.94. +//! AMX (Advanced Matrix Extensions) — tile matrix multiply on stable Rust 1.94 +//! via inline `asm!` byte-encodings (the `_tile_*` intrinsics are nightly-only, +//! rust-lang/rust#126622). //! -//! AMX provides hardware tile matrix multiplication: -//! TDPBUSD: 16×16 tile of u8×i8 → i32 = 256 MACs per instruction -//! TDPBF16PS: 16×16 tile of BF16×BF16 → f32 +//! TDPBUSD : 16×16 × K=64 tile, u8×i8 → i32 = 16 384 MACs / instruction +//! TDPBF16PS : 16×16 × K=32 tile, bf16×bf16 → f32 //! -//! 
Status: HARDWARE CONFIRMED + OS ENABLED + INLINE ASM TESTED -//! AMX-TILE: ✓ (LDTILECFG, TILEZERO, TILERELEASE all work) -//! AMX-INT8: ✓ (TDPBUSD available) -//! AMX-BF16: ✓ (TDPBF16PS available) -//! Kernel: 6.18.5 (XCR0 bits 17+18 set) +//! Status (2026-06-14, VERIFIED BY EXECUTION on Emerald Rapids, kernel 6.18.5): +//! detection + every tile op + both GEMM kernels run and are correctness-checked +//! by `examples/amx_probe` (int8 bit-exact, bf16 within tolerance). int8 GEMM +//! 2048³ = 169.7 GMAC/s, 600× scalar, single thread. //! -//! Rust intrinsics: NIGHTLY ONLY (issue #126622) -//! Inline asm: STABLE (works on Rust 1.94, tested) +//! ⚠ HISTORY — this header used to claim "INLINE ASM TESTED ✓ all work". It was +//! NOT. Every AMX test early-returns `if !amx_available() { return; }`, and +//! detection returned false on every host (the arch_prctl syscall-number bug), +//! so the tile asm had never executed. That hid FIVE bugs: arch_prctl 157→158; +//! TILECFG rows/colsb swapped; TILELOADD/TILESTORED SIB base/index swapped; +//! TDPBUSD/TDPBF16PS ModRM same-tile #UD; and the mirrored operand index/sign +//! convention. All fixed + documented 2026-06-14. Lesson: a test behind an +//! `amx_available()` guard that is false is a SKIPPED test, not a passing one. //! -//! Inline asm encoding (verified working): -//! LDTILECFG: asm!("ldtilecfg [{}]", in(reg) ptr, options(nostack)) -//! TILEZERO t0: asm!(".byte 0xc4, 0xe2, 0x7b, 0x49, 0xc0", options(nostack, nomem)) -//! TILERELEASE: asm!(".byte 0xc4, 0xe2, 0x78, 0x49, 0xc0", options(nostack, nomem)) +//! Canonical reference: `.claude/knowledge/amx-enablement-and-kernel.md` +//! Troubleshooting playbook: `.claude/AMX_GOTCHAS.md` Agent: `amx-savant` //! -//! ThinkingEngine tiers: -//! AMX: 256 MACs/instr ~44 μs/cycle (via inline asm, stable) -//! VNNI: 64 MACs/instr ~175 μs/cycle (stable intrinsics) -//! F32x16: 16 MACs/instr ~400 μs/cycle (stable) -//! F64x8: 8 MACs/instr ~700 μs/cycle (stable) -//! -//! 
Codebook distance table build: AMX reduces 24-48h → ~1:20h. +//! Dispatch tiers (MACs/instruction): AMX 16 384 → avx512vnni 64 → +//! avxvnniint8 32 → scalar 1. // ═══════════════════════════════════════════════════════════════════════════ // Detection (stable — just CPUID, no AMX instructions) // ═══════════════════════════════════════════════════════════════════════════ -/// Check if AMX hardware is present AND OS-enabled. -/// -/// Two checks required: -/// 1. CPUID.07H.0H:EDX bits 24 (AMX-TILE) + 25 (AMX-INT8) = CPU supports it -/// 2. XCR0 bits 17 (TILECFG) + 18 (TILEDATA) = OS has enabled tile state -/// -/// The XCR0 check is critical: even if CPUID reports AMX, the hypervisor -/// may not have enabled the XSTATE for tiles. Without OS enablement, -/// LDTILECFG will SIGILL. +/// Intel server CPU generation, detected from CPUID.01H model bits. Lets a run +/// report which silicon it landed on and reason about AMX: SPR / EMR / GNR +/// expose AMX-TILE; Sierra Forest (E-core) does NOT. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CpuModel { + /// Sapphire Rapids — family 6, model 0x8F. AMX-TILE / INT8 / BF16. + SapphireRapids, + /// Emerald Rapids — family 6, model 0xCF. AMX-TILE / INT8 / BF16. + EmeraldRapids, + /// Granite Rapids — family 6, model 0xAD / 0xAE. AMX + AMX-FP16. + GraniteRapids, + /// Sierra Forest — family 6, model 0xAF. E-core, NO AMX. + SierraForest, + /// Any other x86_64 (older Intel, AMD, CPUID-masked hypervisor, …). + OtherX86, + /// Non-x86_64 build. + NonX86, +} + +impl CpuModel { + /// Whether this generation is expected to expose AMX-TILE. A `true` here + /// while [`amx_available`] is `false` points at OS / hypervisor enablement + /// (XCR0 / arch_prctl), NOT the silicon — see `.claude/AMX_GOTCHAS.md`. + pub fn has_amx(self) -> bool { + matches!(self, CpuModel::SapphireRapids | CpuModel::EmeraldRapids | CpuModel::GraniteRapids) + } + + /// Short human label for reports / logs. 
+ pub fn label(self) -> &'static str { + match self { + CpuModel::SapphireRapids => "Sapphire Rapids", + CpuModel::EmeraldRapids => "Emerald Rapids", + CpuModel::GraniteRapids => "Granite Rapids", + CpuModel::SierraForest => "Sierra Forest (no AMX)", + CpuModel::OtherX86 => "other x86_64", + CpuModel::NonX86 => "non-x86_64", + } + } +} + +#[cfg(target_arch = "x86_64")] +fn detect_cpu_model() -> CpuModel { + // Only classify GenuineIntel by model number — AMD reuses the family/model + // space differently. Vendor string is CPUID.0 EBX/EDX/ECX = "GenuineIntel". + let v0 = core::arch::x86_64::__cpuid(0); + let is_intel = v0.ebx == 0x756e_6547 && v0.edx == 0x4965_6e69 && v0.ecx == 0x6c65_746e; + if !is_intel { + return CpuModel::OtherX86; + } + let eax = core::arch::x86_64::__cpuid(1).eax; + let base_family = (eax >> 8) & 0xf; + let base_model = (eax >> 4) & 0xf; + let ext_model = (eax >> 16) & 0xf; + // Intel display-model rule: ext_model is folded in for family 0x6 and 0xF. + let model = if base_family == 0x6 || base_family == 0xf { + (ext_model << 4) | base_model + } else { + base_model + }; + match (base_family, model) { + (0x6, 0x8f) => CpuModel::SapphireRapids, + (0x6, 0xcf) => CpuModel::EmeraldRapids, + (0x6, 0xad) | (0x6, 0xae) => CpuModel::GraniteRapids, + (0x6, 0xaf) => CpuModel::SierraForest, + _ => CpuModel::OtherX86, + } +} + +#[cfg(target_arch = "x86_64")] +static CPU_MODEL: std::sync::LazyLock<CpuModel> = std::sync::LazyLock::new(detect_cpu_model); + +/// The detected Intel CPU generation, cached. `CpuModel::NonX86` off x86_64. +#[cfg(target_arch = "x86_64")] +pub fn cpu_model() -> CpuModel { + *CPU_MODEL +} + +/// The detected Intel CPU generation (always `NonX86` on this target). +#[cfg(not(target_arch = "x86_64"))] +pub fn cpu_model() -> CpuModel { + CpuModel::NonX86 +} + +/// AMX availability, computed ONCE and cached.
The four detection gates are all +/// non-blocking — CPUID, XGETBV, and one idempotent `arch_prctl`: no I/O, no +/// lock contention, no spin — so the `LazyLock` init cannot stall, and every +/// later call is a plain cached load. The `arch_prctl` permission request is +/// process-wide and inherited by all threads (present and future), so issuing +/// it exactly once here is correct even under a multi-threaded (rayon) consumer. +#[cfg(target_arch = "x86_64")] +static AMX_AVAILABLE: std::sync::LazyLock<bool> = std::sync::LazyLock::new(detect_amx); + +/// Check if AMX is present, OS-enabled, AND this process holds XTILEDATA +/// permission. Cached after the first call (see the `AMX_AVAILABLE` static). /// -/// Previous bug: used CPUID leaf 0xD (reports what CPU supports for XSAVE) -/// instead of _xgetbv(0) (reports what OS actually enabled). The old check -/// could return true on a hypervisor that advertises AMX in CPUID but -/// hasn't set XCR0 bits 17+18. +/// Four gates, in order — any miss ⇒ `false`: +/// 1. CPUID.07H.0H:EDX bits 24 (AMX-TILE) + 25 (AMX-INT8): silicon supports it. +/// 2. CPUID.01H:ECX bit 27 (OSXSAVE): OS turned on XSAVE. +/// 3. XGETBV(0) bits 17 (TILECFG) + 18 (TILEDATA): OS enabled tile XSTATE. +/// Read the *live* XCR0, NOT CPUID leaf 0xD — leaf 0xD reports what the CPU +/// *could* support, and a hypervisor may advertise AMX yet leave XCR0 clear. +/// 4. arch_prctl(ARCH_REQ_XCOMP_PERM, XTILEDATA): this process requests the +/// dynamically-enabled tile feature (Linux 5.16+). Without it the first +/// tile op faults (XFD #NM). 0x1023 is an *arch_prctl* op (syscall 158), +/// NOT prctl (157) — that one-digit mix-up returns -EINVAL and silently +/// disabled AMX on every capable host. #[cfg(target_arch = "x86_64")] pub fn amx_available() -> bool { + *AMX_AVAILABLE +} + +/// The actual four-gate detection, run once behind the `AMX_AVAILABLE` static.
+#[cfg(target_arch = "x86_64")] +fn detect_amx() -> bool { // Step 1: CPU supports AMX-TILE + AMX-INT8? let cpuid = core::arch::x86_64::__cpuid_count(7, 0); let amx_tile = (cpuid.edx >> 24) & 1; @@ -73,24 +170,29 @@ pub fn amx_available() -> bool { } // Step 4: Request XCOMP_PERM for TILEDATA. - // Linux kernel 5.19+: processes must call prctl(ARCH_REQ_XCOMP_PERM, 18) - // to request permission for TILEDATA (XFEATURE 18) before using AMX. - // Without this, LDTILECFG will SIGILL even if XCR0 bits are set. - // The prctl either succeeds (0) or fails (-1) — idempotent, safe to call - // multiple times. + // Linux kernel 5.16+: processes must call arch_prctl(ARCH_REQ_XCOMP_PERM, + // 18) to request permission for TILEDATA (XFEATURE 18) before using AMX. + // Without this, the first AMX tile op faults (XFD #NM → SIGILL) even when + // XCR0 bits are set. The request either succeeds (0) or fails (-errno) — + // idempotent, safe to call multiple times. + // + // IMPORTANT: ARCH_REQ_XCOMP_PERM (0x1023) is an *arch_prctl* operation + // (syscall 158), NOT regular prctl (157). Issuing it on syscall 157 makes + // the kernel reject option 0x1023 with -EINVAL, which silently disabled + // AMX on EVERY capable host (steps 1-3 pass, step 4 always failed). #[cfg(target_os = "linux")] { - const SYS_PRCTL: i64 = 157; // x86_64 syscall number for prctl + const SYS_ARCH_PRCTL: i64 = 158; // x86_64 syscall number for arch_prctl const ARCH_REQ_XCOMP_PERM: i64 = 0x1023; const XFEATURE_XTILEDATA: i64 = 18; - // SAFETY: syscall(prctl, ARCH_REQ_XCOMP_PERM, 18) is a simple permission - // request. It either grants tile permission (returns 0) or fails (returns - // -errno). No side effects on failure. Idempotent. + // SAFETY: arch_prctl(ARCH_REQ_XCOMP_PERM, 18) is a simple permission + // request. It either grants tile permission (returns 0) or fails + // (returns -errno). No side effects on failure. Idempotent. 
let ret: i64; unsafe { core::arch::asm!( "syscall", - inlateout("rax") SYS_PRCTL => ret, + inlateout("rax") SYS_ARCH_PRCTL => ret, in("rdi") ARCH_REQ_XCOMP_PERM, in("rsi") XFEATURE_XTILEDATA, in("rdx") 0i64, @@ -114,7 +216,9 @@ pub fn amx_available() -> bool { false } -/// AMX capability report. +/// AMX capability report: detected CPU model + CPUID feature bits + the cached +/// `amx_available()` verdict. If `model.has_amx()` is true but `available` is +/// false, the gap is OS / hypervisor enablement (XCR0 / arch_prctl), not silicon. pub fn amx_report() -> String { #[cfg(target_arch = "x86_64")] { @@ -122,11 +226,20 @@ pub fn amx_report() -> String { let tile = (cpuid.edx >> 24) & 1 == 1; let int8 = (cpuid.edx >> 25) & 1 == 1; let bf16 = (cpuid.edx >> 22) & 1 == 1; - format!("AMX: TILE={} INT8={} BF16={} available={}", tile, int8, bf16, amx_available()) + let model = cpu_model(); + format!( + "AMX [{} expects_amx={}]: TILE={} INT8={} BF16={} available={}", + model.label(), + model.has_amx(), + tile, + int8, + bf16, + amx_available() + ) } #[cfg(not(target_arch = "x86_64"))] { - "AMX: not x86_64".to_string() + format!("AMX [{}]: not x86_64", cpu_model().label()) } } diff --git a/src/simd_avx2.rs b/src/simd_avx2.rs index 3b06c27d..10942b7f 100644 --- a/src/simd_avx2.rs +++ b/src/simd_avx2.rs @@ -1549,7 +1549,224 @@ avx2_int_type!(U64x8, u64, 8, 0u64); // pattern (`[$elem; $lanes]` storage, align 64). Native AVX2 `__m256i` // upgrades for these are TD-SIMD-3 (the same fold-into-real-SIMD task // already tracked for the 512-bit polyfills above). -avx2_int_type!(U16x16, u16, 16, 0u16); +// ── U16x16 — native AVX2 `__m256i` (16 × u16) ─────────────────────────────── +// TD-T22 / TD-SIMD-3 lowering: previously `avx2_int_type!(U16x16, ...)` — a +// scalar `[u16; 16]` polyfill. Now a real `__m256i` wrapper so the PQ4-ADC +// FastScan u16 accumulate (turbovec's AVX2 search kernel) runs on hardware. 
+// Method set mirrors the native `U16x32` in `simd_avx512.rs:1200`, narrowed to +// 256-bit `_mm256_*_epi16`. A 256-bit register is valid on both AVX2 and +// AVX-512 hosts, so both `simd.rs` dispatch arms re-export this one native type +// (replacing the scalar polyfill that the v4 arm pulled via `simd_avx512`). +#[derive(Copy, Clone)] +#[repr(transparent)] +pub struct U16x16(pub __m256i); + +impl U16x16 { + pub const LANES: usize = 16; + + #[inline(always)] + pub fn splat(v: u16) -> Self { + Self(unsafe { _mm256_set1_epi16(v as i16) }) + } + + #[inline(always)] + pub fn zero() -> Self { + Self(unsafe { _mm256_setzero_si256() }) + } + + #[inline(always)] + pub fn from_slice(s: &[u16]) -> Self { + assert!(s.len() >= 16); + // SAFETY: 16 × u16 = 32 bytes = one __m256i. Unaligned load. + Self(unsafe { _mm256_loadu_si256(s.as_ptr() as *const __m256i) }) + } + + #[inline(always)] + pub fn from_array(arr: [u16; 16]) -> Self { + Self(unsafe { _mm256_loadu_si256(arr.as_ptr() as *const __m256i) }) + } + + #[inline(always)] + pub fn to_array(self) -> [u16; 16] { + let mut arr = [0u16; 16]; + // SAFETY: store 32 bytes into 16 × u16. + unsafe { _mm256_storeu_si256(arr.as_mut_ptr() as *mut __m256i, self.0) }; + arr + } + + #[inline(always)] + pub fn copy_to_slice(self, s: &mut [u16]) { + assert!(s.len() >= 16); + unsafe { _mm256_storeu_si256(s.as_mut_ptr() as *mut __m256i, self.0) }; + } + + /// Logical right shift each 16-bit lane by `imm` (matches `U16x32::shr`). + #[inline(always)] + pub fn shr(self, imm: u32) -> Self { + // SAFETY: AVX2 baseline; `_mm256_srl_epi16` takes a runtime lane count + // from the low 64 bits of an xmm, so every shift amount works (the + // earlier `match {1,2,4,8}` returned zero for all other amounts). + Self(unsafe { _mm256_srl_epi16(self.0, _mm_cvtsi32_si128(imm as i32)) }) + } + + /// Logical left shift each 16-bit lane by `imm` (matches `U16x32::shl`). 
+ #[inline(always)] + pub fn shl(self, imm: u32) -> Self { + // SAFETY: AVX2 baseline; `_mm256_sll_epi16` takes a runtime lane count + // (same fix as `shr` — the `match {1,2,4,8}` zeroed all other amounts). + Self(unsafe { _mm256_sll_epi16(self.0, _mm_cvtsi32_si128(imm as i32)) }) + } + + /// Multiply, keep low 16 bits (wrapping) — `_mm256_mullo_epi16`. + #[inline(always)] + pub fn mullo(self, other: Self) -> Self { + Self(unsafe { _mm256_mullo_epi16(self.0, other.0) }) + } + + /// Horizontal sum of all 16 lanes (widened to u32, no wrap). + #[inline(always)] + pub fn reduce_sum(self) -> u32 { + self.to_array().iter().map(|&v| v as u32).sum() + } + + // ── FastScan flush-epilogue helpers (PQ4-ADC u16→f32 cross-lane combine) ── + + /// Cross-128-bit-lane permute (`_mm256_permute2x128_si256`). `IMM` selects + /// which 128-bit halves of `self`/`other` land in each output half. Used + /// (with `IMM=0x21`) by the FastScan SUB-trick to bring the two blocks' + /// partial sums into add-alignment. + #[inline(always)] + pub fn permute2x128<const IMM: i32>(self, other: Self) -> Self { + // SAFETY: AVX2 baseline. + Self(unsafe { _mm256_permute2x128_si256::<IMM>(self.0, other.0) }) + } + + /// Blend 32-bit dwords from `self`/`other` per the `IMM` mask + /// (`_mm256_blend_epi32`). Companion to `permute2x128` in the FastScan + /// lane combine (with `IMM=0xF0`). + #[inline(always)] + pub fn blend_epi32<const IMM: i32>(self, other: Self) -> Self { + // SAFETY: AVX2 baseline. + Self(unsafe { _mm256_blend_epi32::<IMM>(self.0, other.0) }) + } + + /// Zero-extend the low 8 × u16 lanes to f32 (`_mm256_cvtepu16_epi32` then + /// `_mm256_cvtepi32_ps`). The PQ4-ADC accumulators are ≤ `FLUSH_EVERY·127` + /// so they fit exactly in f32; this is the lossless u16→f32 step before the + /// per-query `scale·partial` FMA. + #[inline(always)] + pub fn to_f32x8_lo(self) -> crate::simd_avx512::F32x8 { + // SAFETY: AVX2 baseline.
+ crate::simd_avx512::F32x8(unsafe { _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_castsi256_si128(self.0))) }) + } + + /// Zero-extend the high 8 × u16 lanes to f32 (sibling of `to_f32x8_lo`). + #[inline(always)] + pub fn to_f32x8_hi(self) -> crate::simd_avx512::F32x8 { + // SAFETY: AVX2 baseline. + crate::simd_avx512::F32x8(unsafe { + _mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256::<1>(self.0))) + }) + } +} + +impl Default for U16x16 { + #[inline(always)] + fn default() -> Self { + Self::zero() + } +} + +impl Add for U16x16 { + type Output = Self; + #[inline(always)] + fn add(self, rhs: Self) -> Self { + Self(unsafe { _mm256_add_epi16(self.0, rhs.0) }) + } +} +impl Sub for U16x16 { + type Output = Self; + #[inline(always)] + fn sub(self, rhs: Self) -> Self { + Self(unsafe { _mm256_sub_epi16(self.0, rhs.0) }) + } +} +impl Mul for U16x16 { + type Output = Self; + #[inline(always)] + fn mul(self, rhs: Self) -> Self { + Self(unsafe { _mm256_mullo_epi16(self.0, rhs.0) }) + } +} +impl AddAssign for U16x16 { + #[inline(always)] + fn add_assign(&mut self, rhs: Self) { + self.0 = unsafe { _mm256_add_epi16(self.0, rhs.0) }; + } +} +impl SubAssign for U16x16 { + #[inline(always)] + fn sub_assign(&mut self, rhs: Self) { + self.0 = unsafe { _mm256_sub_epi16(self.0, rhs.0) }; + } +} +impl BitAnd for U16x16 { + type Output = Self; + #[inline(always)] + fn bitand(self, rhs: Self) -> Self { + Self(unsafe { _mm256_and_si256(self.0, rhs.0) }) + } +} +impl BitOr for U16x16 { + type Output = Self; + #[inline(always)] + fn bitor(self, rhs: Self) -> Self { + Self(unsafe { _mm256_or_si256(self.0, rhs.0) }) + } +} +impl BitXor for U16x16 { + type Output = Self; + #[inline(always)] + fn bitxor(self, rhs: Self) -> Self { + Self(unsafe { _mm256_xor_si256(self.0, rhs.0) }) + } +} +impl BitAndAssign for U16x16 { + #[inline(always)] + fn bitand_assign(&mut self, rhs: Self) { + self.0 = unsafe { _mm256_and_si256(self.0, rhs.0) }; + } +} +impl BitOrAssign for U16x16 { + 
#[inline(always)] + fn bitor_assign(&mut self, rhs: Self) { + self.0 = unsafe { _mm256_or_si256(self.0, rhs.0) }; + } +} +impl BitXorAssign for U16x16 { + #[inline(always)] + fn bitxor_assign(&mut self, rhs: Self) { + self.0 = unsafe { _mm256_xor_si256(self.0, rhs.0) }; + } +} +impl Not for U16x16 { + type Output = Self; + #[inline(always)] + fn not(self) -> Self { + Self(unsafe { _mm256_xor_si256(self.0, _mm256_set1_epi16(-1)) }) + } +} +impl fmt::Debug for U16x16 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "U16x16({:?})", self.to_array()) + } +} +impl PartialEq for U16x16 { + fn eq(&self, other: &Self) -> bool { + self.to_array() == other.to_array() + } +} + avx2_int_type!(U32x8, u32, 8, 0u32); avx2_int_type!(U64x4, u64, 4, 0u64); avx2_int_type!(I32x8, i32, 8, 0i32); @@ -1877,6 +2094,20 @@ impl U8x32 { Self(unsafe { _mm256_loadu_si256(s.as_ptr() as *const __m256i) }) } + /// Unaligned load 32 bytes from a raw pointer — NO bounds check. The + /// zero-overhead hot-loop load: `from_slice`'s `assert!` plus the caller's + /// slice-index bounds check both vanish, which in a tight scan (one load per + /// code/LUT group — e.g. a 4-bit-PQ ADC FastScan inner loop) is a measurable + /// tax vs a bare `_mm256_loadu_si256`. Use only where the index is already + /// proven in range. + /// + /// # Safety + /// `ptr` must point to at least 32 readable bytes. + #[inline(always)] + pub unsafe fn from_ptr(ptr: *const u8) -> Self { + Self(_mm256_loadu_si256(ptr as *const __m256i)) + } + /// Load 32 bytes from a fixed-size array. #[inline(always)] pub fn from_array(arr: [u8; 32]) -> Self { @@ -2104,6 +2335,15 @@ impl U8x32 { 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, ]) } + + /// Reinterpret the 32 bytes as 16 × u16 (zero-cost bitcast — same `__m256i`). 
+ /// The PQ4-ADC FastScan accumulates `shuffle_bytes` LUT results (u8 lanes, + /// each ≤ 127) into a `U16x16` accumulator via `_mm256_add_epi16`; this is + /// the bridge from the gather result to the 16-bit accumulator. + #[inline(always)] + pub fn as_u16x16(self) -> U16x16 { + U16x16(self.0) + } } // Bitwise + arithmetic operator impls so consumers can use natural diff --git a/src/simd_avx512.rs b/src/simd_avx512.rs index 16ec41e5..9525f341 100644 --- a/src/simd_avx512.rs +++ b/src/simd_avx512.rs @@ -1344,6 +1344,49 @@ impl PartialEq for U16x32 { } } +// F32x8 fused multiply-add (256-bit __m256). `self.mul_add(a, b) = self*a + b` +// in a single rounding step via `_mm256_fmadd_ps` (FMA3). The 8-wide companion +// to the existing `F32x16::mul_add`; consumed by the PQ4-ADC FastScan flush +// (turbovec's AVX2 kernel) where the per-query `fa = v_scale*partial + fa` +// reduction needs an 8-wide FMA. +impl F32x8 { + /// Fused multiply-add: `self * a + b`, single rounding (`_mm256_fmadd_ps`). + /// + /// # Examples + /// ```ignore + /// let a = F32x8::splat(0.5); + /// let b = F32x8::splat(2.0); + /// let c = F32x8::splat(1.0); + /// assert_eq!(a.mul_add(b, c).to_array(), [2.0; 8]); // 0.5*2.0 + 1.0 + /// ``` + #[inline(always)] + pub fn mul_add(self, a: Self, b: Self) -> Self { + // SAFETY: FMA3 intrinsic; reached only on FMA-capable targets via the + // consumer's runtime dispatch / `#[target_feature(enable = "fma")]`. + Self(unsafe { _mm256_fmadd_ps(self.0, a.0, b.0) }) + } + + /// Lane-wise `self > other` as an 8-bit mask: bit `i` set iff + /// `self[i] > other[i]` (ordered, non-signaling). `_mm256_cmp_ps::<_CMP_GT_OQ>` + /// + `_mm256_movemask_ps`. The FastScan heap threshold-prune uses it to skip + /// an 8-lane score chunk that holds no candidate above the current heap-min + /// in a single instruction — the SIMD early-out the scalar `>hmin` scan loses. 
+    ///
+    /// # Examples
+    /// ```ignore
+    /// let a = F32x8::from_array([3.0, 0.0, 5.0, 0.0, 3.0, 0.0, 5.0, 0.0]);
+    /// let b = F32x8::splat(1.0);
+    /// // lanes 0,2,4,6 are > 1.0 ⇒ bits 0,2,4,6 set = 0b0101_0101 = 0x55.
+    /// assert_eq!(a.cmp_gt_mask(b), 0x55);
+    /// ```
+    #[inline(always)]
+    pub fn cmp_gt_mask(self, other: Self) -> u32 {
+        // SAFETY: AVX `vcmpps` + `vmovmskps`; available wherever this 256-bit
+        // float type is (x86-64-v3+).
+        unsafe { _mm256_movemask_ps(_mm256_cmp_ps::<_CMP_GT_OQ>(self.0, other.0)) as u32 }
+    }
+}
+
 // ============================================================================
 // U32x16 — 16 × u32 in one AVX-512 register (__m512i)
 // Used primarily for bit manipulation in transcendental functions (vml.rs).
diff --git a/src/simd_int_ops.rs b/src/simd_int_ops.rs
index 2cef8b91..95e78ace 100644
--- a/src/simd_int_ops.rs
+++ b/src/simd_int_ops.rs
@@ -277,32 +277,27 @@ pub fn gemm_u8_i8(a: &[u8], b: &[i8], c: &mut [i32], m: usize, n: usize, k: usiz
         }
     }
 
-    // Compile-time dispatch chain (tiers 1-3). Exactly one arm survives
-    // per build; the others are stripped by `#[cfg]` so the compiler
-    // emits a direct call to the chosen kernel with no runtime branch.
-
-    #[cfg(all(target_arch = "x86_64", target_feature = "avx512vnni"))]
-    {
-        // SAFETY: `target_feature = "avx512vnni"` at this site guarantees
-        // AVX-512F + VNNI + BW (the kernel's `#[target_feature(enable)]`
-        // set). The dispatcher is the safety invariant the kernel relies on.
-        unsafe { crate::hpc::vnni_gemm::int8_gemm_vnni_avx512(a, b, c, m, n, k) };
-        return;
-    }
-
-    #[cfg(all(
-        target_arch = "x86_64",
-        target_feature = "avxvnni",
-        not(target_feature = "avx512vnni"),
-    ))]
+    // RUNTIME VNNI dispatch (tiers 1-2, after the AMX check above). This MUST
+    // be runtime `is_x86_feature_detected!`, NOT compile-time
+    // `#[cfg(target_feature)]`: the default x86-64-v3 build has neither
+    // avx512vnni nor avxvnni as a *compile* feature, so a cfg chain would strip
+    // both arms and fall through to scalar even on Ice Lake / Sapphire Rapids /
+    // Zen 4 silicon that supports VNNI at runtime (the regression codex flagged
+    // on PR #217). Runtime detection keeps the VNNI kernels reachable on the
+    // baseline build, matching the pre-consolidation `simd_caps()` behaviour.
+    #[cfg(target_arch = "x86_64")]
     {
-        // SAFETY: `target_feature = "avxvnni"` at this site guarantees
-        // AVX + AVX2 + AVX-VNNI (the kernel's `#[target_feature(enable)]`
-        // set). Arm only fires when AVX-512 VNNI is *not* present —
-        // Alder Lake / Arrow Lake without AVX-512, or Zen 4 builds that
-        // pinned a ymm-only target. The dispatcher is the safety invariant.
-        unsafe { crate::hpc::vnni_gemm::int8_gemm_avxvnni_ymm(a, b, c, m, n, k) };
-        return;
+        if std::is_x86_feature_detected!("avx512vnni") {
+            // SAFETY: avx512vnni detected ⇒ AVX-512F + VNNI + BW present, the
+            // kernel's `#[target_feature(enable)]` set.
+            unsafe { crate::hpc::vnni_gemm::int8_gemm_vnni_avx512(a, b, c, m, n, k) };
+            return;
+        }
+        if std::is_x86_feature_detected!("avxvnni") {
+            // SAFETY: avxvnni detected ⇒ AVX + AVX2 + AVX-VNNI present.
+            unsafe { crate::hpc::vnni_gemm::int8_gemm_avxvnni_ymm(a, b, c, m, n, k) };
+            return;
+        }
     }
 
     // Fallback: scalar reference kernel. Always correct; same result the