From c3fe89901f8841671968773638b37731a2492067 Mon Sep 17 00:00:00 2001 From: Gabriel Barreto Date: Mon, 18 May 2026 20:51:49 -0300 Subject: [PATCH 1/2] Stats now prints cache hits --- Ix/Aiur/Semantics/BytecodeFfi.lean | 12 ++++---- Ix/Aiur/Statistics.lean | 36 ++++++++++++++++-------- src/ffi/aiur/protocol.rs | 44 ++++++++++++++++++++---------- 3 files changed, 62 insertions(+), 30 deletions(-) diff --git a/Ix/Aiur/Semantics/BytecodeFfi.lean b/Ix/Aiur/Semantics/BytecodeFfi.lean index ecfcd555..183d5d59 100644 --- a/Ix/Aiur/Semantics/BytecodeFfi.lean +++ b/Ix/Aiur/Semantics/BytecodeFfi.lean @@ -56,16 +56,18 @@ namespace Bytecode.Toplevel private opaque execute' : @& Bytecode.Toplevel → @& Bytecode.FunIdx → @& Array G → (ioData : @& Array G) → (ioMap : @& Array (Array G × IOKeyInfo)) → - Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × Array Nat) + Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × Array (Nat × Nat)) /-- Executes the bytecode function `funIdx` with the given `args` and `ioBuffer`, returning the raw output of the function, the updated `IOBuffer`, and an array -of query counts (one per function circuit, then one per memory size). Returns -`Except.error msg` when execution fails (e.g. `assert_eq!` mismatch from a -typechecker rejecting a constant), so callers can recover instead of crashing. -/ +of per-circuit `(uniqueRows, totalHits)` pairs (one per function circuit, then +one per memory size). `uniqueRows` is the trace height; `totalHits` is the sum +of query multiplicities. Returns `Except.error msg` when execution fails +(e.g. `assert_eq!` mismatch from a typechecker rejecting a constant), so +callers can recover instead of crashing. -/ def execute (toplevel : @& Bytecode.Toplevel) (funIdx : @& Bytecode.FunIdx) (args : @& Array G) (ioBuffer : IOBuffer) : - Except String (Array G × IOBuffer × Array Nat) := + Except String (Array G × IOBuffer × Array (Nat × Nat)) := let ioData := ioBuffer.data let ioMap := ioBuffer.map match execute' toplevel funIdx args ioData ioMap.toArray with diff --git a/Ix/Aiur/Statistics.lean b/Ix/Aiur/Statistics.lean index f98d85d7..7e25c283 100644 --- a/Ix/Aiur/Statistics.lean +++ b/Ix/Aiur/Statistics.lean @@ -4,8 +4,9 @@ public import Ix.Aiur.Compiler /-! Circuit statistics for Aiur executions. -Given a `CompiledToplevel` and the query counts returned by `execute`, computes -per-circuit width, height (the query count), and the FFT cost +Given a `CompiledToplevel` and the per-circuit `(uniqueRows, totalHits)` pairs +returned by `execute`, computes per-circuit width, height (the unique-row +count), cache hits (sum of multiplicities), and the FFT cost (width × height × log2(height)) for every constrained function and memory circuit. The FFT cost is a Float to capture small changes continuously. Results are sorted by FFT cost in decreasing order and printed with cumulative @@ -20,11 +21,14 @@ structure CircuitStats where name : String width : Nat height : Nat + cacheHits : Nat fftCost : Float structure ExecutionStats where circuits : Array CircuitStats totalFftCost : Float + totalUncachedFftCost : Float + totalCacheHits : Nat -- Clamp to at least 2 so that log2 is at least 1, avoiding zero cost for h = 1 def fftCost (w h : Nat) : Float := @@ -34,7 +38,7 @@ def fftCost (w h : Nat) : Float := let hf := h.toFloat wf * hf * (max hf 2.0).log2 -def computeStats (compiled : CompiledToplevel) (queryCounts : Array Nat) : +def computeStats (compiled : CompiledToplevel) (queryCounts : Array (Nat × Nat)) : ExecutionStats := let t := compiled.bytecode -- Invert nameMap to get FunIdx → String @@ -46,18 +50,22 @@ def computeStats (compiled : CompiledToplevel) (queryCounts : Array Nat) : for i in [:nAllFuns] do if t.functions[i]!.constrained then let w := t.functions[i]!.layout.totalWidth - let h := queryCounts[i]! + let (h, totalMults) := queryCounts[i]! + let hits := totalMults - h let name := reverseMap[i]?.getD s!"" - acc := acc.push { name, width := w, height := h, fftCost := fftCost w h : CircuitStats } + acc := acc.push { name, width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats } acc let memoryCircuits := t.memorySizes.mapIdx fun i size => let w := size + 11 - let h := queryCounts[nAllFuns + i]! + let (h, totalMults) := queryCounts[nAllFuns + i]! + let hits := totalMults - h { name := s!"memory[{size}]", - width := w, height := h, fftCost := fftCost w h : CircuitStats } + width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats } let circuits := (functionCircuits ++ memoryCircuits).qsort (·.fftCost > ·.fftCost) let totalFftCost := circuits.foldl (· + ·.fftCost) 0.0 - { circuits, totalFftCost } + let totalUncachedFftCost := circuits.foldl (fun acc cs => acc + fftCost cs.width (cs.height + cs.cacheHits)) 0.0 + let totalCacheHits := circuits.foldl (· + ·.cacheHits) 0 + { circuits, totalFftCost, totalUncachedFftCost, totalCacheHits } private def padLeft (s : String) (n : Nat) : String := let pad := n - s.length @@ -81,28 +89,34 @@ def printStats (stats : ExecutionStats) : IO Unit := do let wName := stats.circuits.foldl (fun m cs => Nat.max m cs.name.length) 4 let wWidth := stats.circuits.foldl (fun m cs => Nat.max m (toString cs.width).length) 5 let wHeight := stats.circuits.foldl (fun m cs => Nat.max m (toString cs.height).length) 6 + let wHits := stats.circuits.foldl (fun m cs => Nat.max m (toString cs.cacheHits).length) 5 let formatCost (f : Float) : String := let n := f.round.toUInt64.toNat toString n let wFftCost := stats.circuits.foldl (fun m cs => Nat.max m (formatCost cs.fftCost).length) 7 let wPct := 7 let wCum := 7 - let totalW := wName + 1 + wWidth + 1 + wHeight + 1 + wFftCost + 1 + wPct + 1 + wCum + let totalW := wName + 1 + wWidth + 1 + wHeight + 1 + wHits + 1 + wFftCost + 1 + wPct + 1 + wCum let totalWidth := stats.circuits.foldl (· + ·.width) 0 + let savedPct := + if stats.totalUncachedFftCost == 0.0 then "0.00%" + else formatPercent (stats.totalUncachedFftCost - stats.totalFftCost) stats.totalUncachedFftCost let sep := String.ofList (List.replicate totalW '-') IO.println "=== Circuit Statistics ===" IO.println s!"Circuits: {stats.circuits.size}" IO.println s!"Total width: {totalWidth}" IO.println s!"Total FFT cost: {formatCost stats.totalFftCost}" + IO.println s!"Total cache hits: {stats.totalCacheHits}" + IO.println s!"Total saved cost: {savedPct}" IO.println sep - IO.println s!"{padRight "Name" wName} {padLeft "Width" wWidth} {padLeft "Height" wHeight} {padLeft "FFT cost" wFftCost} {padLeft "%" wPct} {padLeft "%++" wCum}" + IO.println s!"{padRight "Name" wName} {padLeft "Width" wWidth} {padLeft "Height" wHeight} {padLeft "Hits" wHits} {padLeft "FFT cost" wFftCost} {padLeft "%" wPct} {padLeft "%++" wCum}" IO.println sep let mut cumFftCost : Float := 0.0 for cs in stats.circuits do cumFftCost := cumFftCost + cs.fftCost let pct := formatPercent cs.fftCost stats.totalFftCost let cum := formatPercent cumFftCost stats.totalFftCost - IO.println s!"{padRight cs.name wName} {padLeft (toString cs.width) wWidth} {padLeft (toString cs.height) wHeight} {padLeft (formatCost cs.fftCost) wFftCost} {padLeft pct wPct} {padLeft cum wCum}" + IO.println s!"{padRight cs.name wName} {padLeft (toString cs.width) wWidth} {padLeft (toString cs.height) wHeight} {padLeft (toString cs.cacheHits) wHits} {padLeft (formatCost cs.fftCost) wFftCost} {padLeft pct wPct} {padLeft cum wCum}" end Aiur diff --git a/src/ffi/aiur/protocol.rs b/src/ffi/aiur/protocol.rs index c6a1f9fe..772b18ef 100644 --- a/src/ffi/aiur/protocol.rs +++ b/src/ffi/aiur/protocol.rs @@ -1,5 +1,5 @@ use multi_stark::{ - p3_field::{Field, PrimeField64}, + p3_field::PrimeField64, prover::Proof, types::{CommitmentParameters, FriParameters}, }; @@ -87,7 +87,9 @@ extern "C" fn rs_aiur_system_verify( } /// `Bytecode.Toplevel.execute`: runs execution only (no proof) and returns -/// `Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × Array Nat)`. +/// `Except String (Array G × (Array G × Array (Array G × IOKeyInfo)) × Array (Nat × Nat))`. +/// The trailing `Array (Nat × Nat)` is one `(uniqueRows, totalHits)` pair per +/// function circuit followed by one per memory size. /// On execution failure (e.g. assertion mismatch from a typechecker /// rejecting a constant), returns `Except.error msg` instead of panicking /// — letting Lean test runners (`KernelArena.lean`) classify failures. @@ -112,31 +114,45 @@ extern "C" fn rs_aiur_toplevel_execute( Err(err) => return LeanExcept::error_string(&err.to_string()), }; - // Build query counts: one per function, then one per memory size - let mut query_counts: Vec = Vec::with_capacity( + // Build per-circuit (unique_rows, total_hits) pairs: + // one per function, then one per memory size. `unique_rows` is the trace + // height (number of distinct queries); `total_hits` is the sum of + // multiplicities (how often those rows were hit). + let mut query_counts: Vec<(usize, usize)> = Vec::with_capacity( query_record.function_queries.len() + toplevel.memory_sizes.len(), ); + let summarize = |q: &crate::aiur::execute::QueryMap| -> (usize, usize) { + let mut rows = 0usize; + let mut hits = 0usize; + for (_, res) in q.iter() { + let m = usize::try_from(res.multiplicity.as_canonical_u64()) + .expect("multiplicity exceeds usize"); + if m != 0 { + rows += 1; + hits += m; + } + } + (rows, hits) + }; for queries in &query_record.function_queries { - let count = - queries.iter().filter(|(_, res)| !res.multiplicity.is_zero()).count(); - query_counts.push(count); + query_counts.push(summarize(queries)); } for size in &toplevel.memory_sizes { - let count = query_record.memory_queries.get(size).map_or(0, |q| { - q.iter().filter(|(_, res)| !res.multiplicity.is_zero()).count() - }); - query_counts.push(count); + let pair = query_record.memory_queries.get(size).map_or((0, 0), summarize); + query_counts.push(pair); } let lean_query_counts = { let arr = LeanArray::alloc(query_counts.len()); - for (i, &count) in query_counts.iter().enumerate() { - arr.set(i, LeanOwned::box_usize(count)); + for (i, &(rows, hits)) in query_counts.iter().enumerate() { + let pair = + LeanProd::new(LeanOwned::box_usize(rows), LeanOwned::box_usize(hits)); + arr.set(i, pair); } arr }; let lean_io = build_lean_io_buffer(&io_buffer); - // (Array G, (Array G × Array (Array G × IOKeyInfo), Array Nat)) + // (Array G, (Array G × Array (Array G × IOKeyInfo), Array (Nat × Nat))) let io_counts = LeanProd::new(lean_io, lean_query_counts); let result = LeanProd::new(build_g_array(&output), io_counts); LeanExcept::ok(result) From cdc2185dea54e05d7c5584e51c7fcff3a6824844 Mon Sep 17 00:00:00 2001 From: Arthur Paulino Date: Tue, 19 May 2026 04:38:13 -0700 Subject: [PATCH 2/2] =?UTF-8?q?Use=20QueryCount=20struct=20instead=20of=20?= =?UTF-8?q?(Nat=20=C3=97=20Nat)=20pair?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit execute now returns Array QueryCount with named uniqueRows/totalHits fields instead of opaque (Nat × Nat) pairs. computeStats consumes the struct directly. --- Ix/Aiur/Semantics/BytecodeFfi.lean | 18 +++++++++++++----- Ix/Aiur/Statistics.lean | 15 +++++++++------ 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/Ix/Aiur/Semantics/BytecodeFfi.lean b/Ix/Aiur/Semantics/BytecodeFfi.lean index 183d5d59..08a99894 100644 --- a/Ix/Aiur/Semantics/BytecodeFfi.lean +++ b/Ix/Aiur/Semantics/BytecodeFfi.lean @@ -50,6 +50,15 @@ instance : BEq IOBuffer where -- via `Std.HashMap.beq_iff_equiv` + `Std.HashMap.Equiv.{refl,symm,trans}`, -- bypassing the need for `LawfulBEq` on the outer `IOBuffer`. +/-- Per-circuit query counts for one circuit (one per function circuit, then +one per memory size). `uniqueRows` is the trace height; `totalHits` is the sum +of query multiplicities. The difference `totalHits - uniqueRows` is the number +of cache hits. -/ +structure QueryCount where + uniqueRows : Nat + totalHits : Nat + deriving Inhabited + namespace Bytecode.Toplevel @[extern "rs_aiur_toplevel_execute"] @@ -60,20 +69,19 @@ private opaque execute' : @& Bytecode.Toplevel → /-- Executes the bytecode function `funIdx` with the given `args` and `ioBuffer`, returning the raw output of the function, the updated `IOBuffer`, and an array -of per-circuit `(uniqueRows, totalHits)` pairs (one per function circuit, then -one per memory size). `uniqueRows` is the trace height; `totalHits` is the sum -of query multiplicities. Returns `Except.error msg` when execution fails -(e.g. `assert_eq!` mismatch from a typechecker rejecting a constant), so +of per-circuit `QueryCount`s. Returns `Except.error msg` when execution +fails (e.g. `assert_eq!` mismatch from a typechecker rejecting a constant), so callers can recover instead of crashing. -/ def execute (toplevel : @& Bytecode.Toplevel) (funIdx : @& Bytecode.FunIdx) (args : @& Array G) (ioBuffer : IOBuffer) : - Except String (Array G × IOBuffer × Array (Nat × Nat)) := + Except String (Array G × IOBuffer × Array QueryCount) := let ioData := ioBuffer.data let ioMap := ioBuffer.map match execute' toplevel funIdx args ioData ioMap.toArray with | .error e => .error e | .ok (output, (ioData, ioMap), queryCounts) => let ioMap := ioMap.foldl (fun acc (k, v) => acc.insert k v) ∅ + let queryCounts := queryCounts.map fun (uniqueRows, totalHits) => { uniqueRows, totalHits } .ok (output, ⟨ioData, ioMap⟩, queryCounts) end Bytecode.Toplevel diff --git a/Ix/Aiur/Statistics.lean b/Ix/Aiur/Statistics.lean index 7e25c283..ddcf082b 100644 --- a/Ix/Aiur/Statistics.lean +++ b/Ix/Aiur/Statistics.lean @@ -1,10 +1,11 @@ module public import Ix.Aiur.Compiler +public import Ix.Aiur.Semantics.BytecodeFfi /-! Circuit statistics for Aiur executions. -Given a `CompiledToplevel` and the per-circuit `(uniqueRows, totalHits)` pairs +Given a `CompiledToplevel` and the per-circuit `QueryCount`s returned by `execute`, computes per-circuit width, height (the unique-row count), cache hits (sum of multiplicities), and the FFT cost (width × height × log2(height)) for every constrained function and memory @@ -38,7 +39,7 @@ def fftCost (w h : Nat) : Float := let hf := h.toFloat wf * hf * (max hf 2.0).log2 -def computeStats (compiled : CompiledToplevel) (queryCounts : Array (Nat × Nat)) : +def computeStats (compiled : CompiledToplevel) (queryCounts : Array QueryCount) : ExecutionStats := let t := compiled.bytecode -- Invert nameMap to get FunIdx → String @@ -50,15 +51,17 @@ def computeStats (compiled : CompiledToplevel) (queryCounts : Array (Nat × Nat) for i in [:nAllFuns] do if t.functions[i]!.constrained then let w := t.functions[i]!.layout.totalWidth - let (h, totalMults) := queryCounts[i]! - let hits := totalMults - h + let qc := queryCounts[i]! + let h := qc.uniqueRows + let hits := qc.totalHits - qc.uniqueRows let name := reverseMap[i]?.getD s!"" acc := acc.push { name, width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats } acc let memoryCircuits := t.memorySizes.mapIdx fun i size => let w := size + 11 - let (h, totalMults) := queryCounts[nAllFuns + i]! - let hits := totalMults - h + let qc := queryCounts[nAllFuns + i]! + let h := qc.uniqueRows + let hits := qc.totalHits - qc.uniqueRows { name := s!"memory[{size}]", width := w, height := h, cacheHits := hits, fftCost := fftCost w h : CircuitStats } let circuits := (functionCircuits ++ memoryCircuits).qsort (·.fftCost > ·.fftCost)