diff --git a/crypto/crypto/src/merkle_tree/merkle.rs b/crypto/crypto/src/merkle_tree/merkle.rs index d31e71311..4d0865c0c 100644 --- a/crypto/crypto/src/merkle_tree/merkle.rs +++ b/crypto/crypto/src/merkle_tree/merkle.rs @@ -55,6 +55,11 @@ pub(crate) struct MmapNodeBacking { pub struct MerkleTree { pub root: B::Node, nodes: Vec, + /// Number of leaves in the tree (always a power of two). Stored explicitly + /// so the leaf count is still known after [`MerkleTree::drop_leaves`] frees + /// the leaf half (after which `nodes.len() == leaves_len - 1`, so the usual + /// `(node_count + 1) / 2` recovery no longer applies). + leaves_len: usize, #[cfg(feature = "disk-spill")] #[cfg_attr(feature = "serde", serde(skip))] mmap_backing: Option, @@ -79,13 +84,14 @@ where { fn serialize(&self, serializer: S) -> Result { use serde::ser::SerializeStruct; - let mut s = serializer.serialize_struct("MerkleTree", 2)?; + let mut s = serializer.serialize_struct("MerkleTree", 3)?; s.serialize_field("root", &self.root)?; if self.mmap_backing.is_some() { s.serialize_field("nodes", &MmapNodesSeq(self))?; } else { s.serialize_field("nodes", &self.nodes)?; } + s.serialize_field("leaves_len", &self.leaves_len)?; s.end() } } @@ -154,6 +160,7 @@ where Some(MerkleTree { root: nodes[ROOT].clone(), nodes, + leaves_len, #[cfg(feature = "disk-spill")] mmap_backing: None, }) @@ -211,6 +218,93 @@ where self.create_proof(merkle_path) } + /// Number of leaves in the tree (always a power of two). Valid both before + /// and after [`drop_leaves`](Self::drop_leaves). + pub fn leaves_len(&self) -> usize { + self.leaves_len + } + + /// Whether the leaf half of the node buffer has been freed by + /// [`drop_leaves`](Self::drop_leaves). When `true`, the leaf-level sibling + /// must be supplied by the caller to build an opening (see + /// [`get_proof_by_pos_with_leaf_sibling`](Self::get_proof_by_pos_with_leaf_sibling)). + pub fn leaves_dropped(&self) -> bool { + self.node_count() == self.leaves_len - 1 + } + + /// Free the leaf half of the node buffer, keeping only the inner nodes + /// (`nodes[0..leaves_len - 1]`, root at index 0). This roughly halves the + /// tree's memory footprint. The root and every inner node are retained, so + /// the only path node that must be regenerated at open time is the + /// leaf-level sibling — see + /// [`get_proof_by_pos_with_leaf_sibling`](Self::get_proof_by_pos_with_leaf_sibling). + /// + /// Idempotent: a no-op if the leaves were already dropped. A single-leaf + /// tree (`leaves_len == 1`) has no leaf half to drop and is left unchanged. + pub fn drop_leaves(&mut self) { + if self.leaves_len <= 1 { + return; + } + let inner_count = self.leaves_len - 1; + // `disk-spill` mmap backing is read-only and never populated together + // with leaf-dropping in the prover, so only the heap path is handled. + #[cfg(feature = "disk-spill")] + if self.mmap_backing.is_some() { + return; + } + if self.nodes.len() > inner_count { + self.nodes.truncate(inner_count); + self.nodes.shrink_to_fit(); + } + } + + /// Builds the same opening that [`get_proof_by_pos`](Self::get_proof_by_pos) + /// would, but sources the leaf-level sibling from `leaf_sibling` instead of + /// the (possibly dropped) leaf buffer. All higher levels read the retained + /// inner nodes. The resulting [`Proof`] is byte-identical to the full-tree + /// `get_proof_by_pos(pos)` provided `leaf_sibling` equals the hash the + /// builder stored for the leaf at `sibling_leaf_position(pos)`. + /// + /// Works whether or not the leaves have been dropped. + pub fn get_proof_by_pos_with_leaf_sibling( + &self, + pos: usize, + leaf_sibling: B::Node, + ) -> Option> { + let leaf_node = pos + (self.leaves_len - 1); + // Single-leaf tree: the leaf is the root, the path is empty. + if leaf_node == ROOT { + return Some(Proof { + merkle_path: Vec::new(), + }); + } + + let tree_depth = self.leaves_len.ilog2() as usize; + let mut merkle_path = Vec::with_capacity(tree_depth); + + // Bottom level: the leaf-level sibling, supplied by the caller (its + // node lives in the dropped leaf half). + merkle_path.push(leaf_sibling); + + // Higher levels: every sibling here is an inner node, still resident. + let mut node = parent_index(leaf_node); + while node != ROOT { + let sibling = self.node_get(sibling_index(node))?; + merkle_path.push(sibling.clone()); + node = parent_index(node); + } + + Some(Proof { merkle_path }) + } + + /// 0-based leaf position of the leaf-level sibling needed to open `pos`. + /// The caller regenerates that leaf's hash (e.g. by re-hashing the LDE row) + /// to pass into [`get_proof_by_pos_with_leaf_sibling`](Self::get_proof_by_pos_with_leaf_sibling). + pub fn sibling_leaf_position(&self, pos: usize) -> usize { + let leaf_node = pos + (self.leaves_len - 1); + sibling_index(leaf_node) - (self.leaves_len - 1) + } + /// Creates a proof from a Merkle pasth fn create_proof(&self, merkle_path: Vec) -> Option> { Some(Proof { merkle_path }) diff --git a/crypto/crypto/src/tests/merkle_proof_tests.rs b/crypto/crypto/src/tests/merkle_proof_tests.rs index 458d33c0c..e072a4f6a 100644 --- a/crypto/crypto/src/tests/merkle_proof_tests.rs +++ b/crypto/crypto/src/tests/merkle_proof_tests.rs @@ -333,3 +333,42 @@ fn batch_proof_verify_sparse_leaves_across_tree() { 16 )); } + +use crate::merkle_tree::traits::IsMerkleTreeBackend; +use crate::merkle_tree::utils::complete_until_power_of_two; + +/// Leaf-drop opener must produce byte-identical proofs to the full-tree opener. +/// This is the core correctness invariant for the streaming "leaf-drop" mode. +#[test] +fn leaf_dropped_opening_is_byte_identical_to_full_tree_opening() { + type B = TestBackend; + let values: Vec = (1..1000).map(Ecgfp5FE::new).collect(); + let leaves = complete_until_power_of_two(::hash_leaves(&values)); + + let full = TestMerkleTreeEcgfp::build(&values).unwrap(); + let mut dropped = TestMerkleTreeEcgfp::build(&values).unwrap(); + dropped.drop_leaves(); + assert!(dropped.leaves_dropped()); + assert_eq!(dropped.leaves_len(), leaves.len()); + + for pos in [0usize, 1, 2, 7, 9349 % leaves.len(), leaves.len() - 1] { + let full_proof = full.get_proof_by_pos(pos).unwrap(); + let sib_leaf = leaves[dropped.sibling_leaf_position(pos)]; + let dropped_proof = dropped + .get_proof_by_pos_with_leaf_sibling(pos, sib_leaf) + .unwrap(); + assert_eq!( + full_proof.merkle_path, dropped_proof.merkle_path, + "leaf-dropped proof differs from full-tree proof at pos {pos}" + ); + } +} + +#[test] +fn drop_leaves_is_idempotent() { + let values: Vec = (1..100).map(Ecgfp5FE::new).collect(); + let mut tree = TestMerkleTreeEcgfp::build(&values).unwrap(); + tree.drop_leaves(); + tree.drop_leaves(); + assert!(tree.leaves_dropped()); +} diff --git a/crypto/stark/src/instruments.rs b/crypto/stark/src/instruments.rs index 16ff95082..02fd5996c 100644 --- a/crypto/stark/src/instruments.rs +++ b/crypto/stark/src/instruments.rs @@ -86,6 +86,8 @@ thread_local! { static R2_SUB: RefCell> = const { RefCell::new(None) }; /// Round 4 sub-timings: (fft, merkle, deep_comp, queries) static R4_SUB: RefCell> = const { RefCell::new(None) }; + /// Round 3 OOD evaluation timing. + static R3_OOD: RefCell> = const { RefCell::new(None) }; /// Assembled sub-ops from prove_rounds_2_to_4 (without reconstruct_round1 LDE time). static ROUND_SUB_OPS: RefCell> = const { RefCell::new(None) }; } @@ -141,6 +143,9 @@ pub fn reset_all() { R4_SUB.with(|cell| { cell.borrow_mut().take(); }); + R3_OOD.with(|cell| { + cell.borrow_mut().take(); + }); ROUND_SUB_OPS.with(|cell| { cell.borrow_mut().take(); }); @@ -154,6 +159,14 @@ pub fn take_r2_sub() -> Option<(Duration, Duration, Duration)> { R2_SUB.with(|cell| cell.borrow_mut().take()) } +pub fn store_r3_ood(d: Duration) { + R3_OOD.with(|cell| *cell.borrow_mut() = Some(d)); +} + +pub fn take_r3_ood() -> Option { + R3_OOD.with(|cell| cell.borrow_mut().take()) +} + pub fn store_r4_sub(fft: Duration, merkle: Duration, deep_comp: Duration, queries: Duration) { R4_SUB.with(|cell| *cell.borrow_mut() = Some((fft, merkle, deep_comp, queries))); } diff --git a/crypto/stark/src/proof/stark.rs b/crypto/stark/src/proof/stark.rs index 1751d60fe..40569deea 100644 --- a/crypto/stark/src/proof/stark.rs +++ b/crypto/stark/src/proof/stark.rs @@ -30,6 +30,40 @@ pub struct DeepPolynomialOpening, E: IsField> { pub type DeepPolynomialOpenings = Vec>; +/// Per-(chunk, lde_size) batched FRI instance (Approach 1, batched FRI within a +/// chunk). +/// +/// One per height bucket inside a chunk: every bucket-mate's individual DEEP +/// composition polynomial is linearly combined with successive powers of the +/// bucket's `delta_fri` challenge (sampled from the chunk-shared `bucket_seed`), +/// and a single FRI commit + grinding + query is run on the combined +/// polynomial. The `members` list pins the canonical bucket-local order used to +/// derive `delta_fri^i` on the verifier side; reordering the list rejects the +/// proof. +/// +/// `decommitments` length equals `air.options().fri_number_of_queries` (one +/// decommitment per shared iota). `nonce` is `Some` when the AIR's grinding +/// factor > 0 (`None` otherwise). +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(bound = "")] +pub struct ChunkBucketFri { + /// LDE size shared by every bucket-mate. Equal to + /// `trace_length * blowup_factor` for each member. + pub lde_size: u32, + /// Chunk-local indices of the bucket-mates, in canonical (chunk-local + /// index ascending) order. Index `i` here corresponds to `delta_fri^i` + /// in the linear combination. + pub members: Vec, + /// `[pₖ]` for the committed FRI layers. + pub layer_roots: Vec, + /// `pₙ` — the final folded constant. + pub last_value: FieldElement, + /// One FRI decommitment per shared iota. + pub decommitments: Vec>, + /// Grinding nonce, when `grinding_factor > 0`. + pub nonce: Option, +} + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[serde(bound = "PI: serde::Serialize + serde::de::DeserializeOwned")] pub struct StarkProof, E: IsField, PI> { @@ -50,17 +84,14 @@ pub struct StarkProof, E: IsField, PI> { pub composition_poly_root: Commitment, // Hᵢ(z^N) pub composition_poly_parts_ood_evaluation: Vec>, - // [pₖ] - pub fri_layers_merkle_roots: Vec, - // pₙ - pub fri_last_value: FieldElement, - // Open(pₖ(Dₖ), −𝜐ₛ^(2ᵏ)) - pub query_list: Vec>, // Open(H₁(D_LDE, 𝜐ᵢ), Open(H₂(D_LDE, 𝜐ᵢ), Open(tⱼ(D_LDE), 𝜐ᵢ) // Open(H₁(D_LDE, -𝜐ᵢ), Open(H₂(D_LDE, -𝜐ᵢ), Open(tⱼ(D_LDE), -𝜐ᵢ) + // + // FRI for this table is no longer per-table: it is run once per + // (chunk, lde_size) bucket and lives in + // [`MultiProof::fri_chunk_buckets`]. These DEEP openings are evaluated + // at the bucket-shared query indices (iotas). pub deep_poly_openings: DeepPolynomialOpenings, - // nonce obtained from grinding - pub nonce: Option, // Bus interaction public inputs for the accumulated column. // Contains the table contribution (L), used for: // 1. Circular constraint offset: L/N per row @@ -77,4 +108,12 @@ pub struct StarkProof, E: IsField, PI> { #[serde(bound = "PI: serde::Serialize + serde::de::DeserializeOwned")] pub struct MultiProof, E: IsField, PI> { pub proofs: Vec>, + /// Per-(chunk, lde_size-bucket) batched FRI instances. Outer Vec is indexed + /// by chunk (chunks of `chunk_size` tables in proof order); inner Vec lists + /// buckets in canonical first-encounter (chunk-local-index ascending) order. + pub fri_chunk_buckets: Vec>>, + /// Pinned chunk size (= the prover's `table_parallelism()` at proving time). + /// The verifier uses this to chunk the proof slice into the same per-chunk + /// grouping the prover used. + pub chunk_size: u32, } diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 53af372ec..912fab3bf 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -36,7 +36,6 @@ use crate::trace::LDETraceTable; use super::config::{BatchedMerkleTree, BatchedMerkleTreeBackend, Commitment}; use super::constraints::evaluator::ConstraintEvaluator; use super::domain::{Domain, DomainConstants}; -use super::fri::fri_decommit::FriDecommitment; use super::grinding; use super::lookup::BusPublicInputs; use super::proof::stark::{DeepPolynomialOpening, MultiProof, StarkProof}; @@ -106,7 +105,15 @@ where FieldElement: AsBytes, { /// Build a `TableCommit` for a plain (non-preprocessed) table. - fn plain(tree: BatchedMerkleTree, root: Commitment) -> Self { + /// + /// In streaming mode (`leaf_drop = true`) the tree's leaf half is freed + /// before wrapping in `Arc` (the only point we still have `&mut` access). + /// The leaf siblings needed at open time are regenerated from the + /// recomputed LDE — see [`IsStarkProver::open_polys_with`]. + fn plain(mut tree: BatchedMerkleTree, root: Commitment, leaf_drop: bool) -> Self { + if leaf_drop { + tree.drop_leaves(); + } Self { tree: Arc::new(tree), root, @@ -118,12 +125,17 @@ where /// Build a `TableCommit` for a preprocessed table. fn preprocessed( - tree: BatchedMerkleTree, + mut tree: BatchedMerkleTree, root: Commitment, - precomputed_tree: BatchedMerkleTree, + mut precomputed_tree: BatchedMerkleTree, precomputed_root: Commitment, num_precomputed_cols: usize, + leaf_drop: bool, ) -> Self { + if leaf_drop { + tree.drop_leaves(); + precomputed_tree.drop_leaves(); + } Self { tree: Arc::new(tree), root, @@ -263,6 +275,77 @@ impl LdeTwiddles { } } +/// Streaming "retire-LDE" mode (Approach 1, Milestone 1). +/// +/// When `true`, the main/aux LDE columns are dropped immediately after committing +/// (Phase A / Phase C) instead of being cached, and each table's LDE is recomputed +/// on demand in Rounds 2-4 via [`reconstruct_round1`]. This trades extra FFT work +/// (one LDE expansion per table) for a large drop in peak working memory: the +/// `O(N × cols × lde_size)` cache of all tables' LDE columns is no longer held +/// simultaneously between Phase A/C and Rounds 2-4 (only the resident traces and +/// the Merkle trees remain). No re-execution of the VM is involved — the LDE is +/// rebuilt from the still-resident trace. +/// +/// # Leaf-drop (T1) +/// +/// In this same mode the main and aux Merkle trees are additionally +/// *leaf-dropped*: after committing, [`MerkleTree::drop_leaves`] frees the leaf +/// half of each tree (about half its memory), keeping only the inner nodes and +/// root. At open time the single leaf-level sibling each query needs is +/// regenerated by re-hashing the recomputed LDE row (see +/// [`IsStarkProver::open_polys_with`] / [`keccak_leaf_from_row`]); all higher +/// path nodes are read from the retained inner nodes. The resulting `Proof`s are +/// byte-identical to the full-tree ones, so the verifier is unchanged. +/// +/// The per-table *composition* Merkle tree (built transiently in Round 2 and +/// dropped at the end of that table's Rounds 2-4) is **left full** — it does not +/// contribute to the cross-table peak, so leaf-dropping it would add complexity +/// (row-pair leaf regeneration) for no peak-memory win. Its +/// `open_composition_poly` path therefore keeps using the full-tree opener. +/// +/// Opt-in via `LAMBDA_STREAM_LDE=1` (or `true`). Default: `false` (cache for speed). +pub fn streaming_retire_lde() -> bool { + std::env::var("LAMBDA_STREAM_LDE") + .map(|v| v == "1" || v.eq_ignore_ascii_case("true")) + .unwrap_or(false) +} + +/// On-demand trace source for streaming "retire-traces" mode (C.2b). +/// +/// In the default (non-streaming) path every table's main+aux trace is resident +/// in `air_trace_pairs` for the whole of [`IsStarkProver::multi_prove`], and no +/// provider is used (`None`). In the streaming path (`LAMBDA_STREAM_LDE=1`) the +/// caller may instead pass a provider that builds the *log-derived* tables' +/// traces on demand from a compact routed intermediate: the prover asks the +/// provider for table `idx`'s freshly-built **main-only** trace at each point it +/// is needed (Phase A main commit, Phase C aux build, Rounds 2-4 reconstruct), +/// uses it transiently, and drops it. The auxiliary trace is rebuilt by the +/// prover on top of the freshly-built main via `AIR::build_auxiliary_trace`. +/// +/// Because the underlying build is deterministic (C.2a), the trace produced for +/// `idx` is byte-identical across all three phases (and to the pre-built trace +/// the non-streaming path uses), so the resulting proof is byte-identical. +/// +/// Tables for which [`TraceProvider::is_retired`] returns `false` (preprocessed +/// tables, PAGE, etc.) are *not* built on demand: the prover uses the resident +/// trace borrowed in `air_trace_pairs` exactly as in the non-streaming path. +pub trait TraceProvider: Sync +where + Field: IsSubFieldOf + IsField, + FieldExtension: IsField, +{ + /// Whether table `idx` is retired (built on demand) rather than resident. + fn is_retired(&self, idx: usize) -> bool; + + /// Number of rows of the main trace for table `idx`. Cheap; used in the + /// pre-pass to size the LDE domain without materializing the trace. + fn num_rows(&self, idx: usize) -> usize; + + /// Build the **main-only** trace (no auxiliary columns) for retired table + /// `idx`. Must be deterministic: byte-identical across repeated calls. + fn build_main(&self, idx: usize) -> TraceTable; +} + /// Number of tables to process concurrently in `multi_prove`. /// Default: num_cores / 3 (benchmarked optimal on both M3 Pro and EPYC 9454P). /// Override with `TABLE_PARALLELISM` env var. @@ -307,20 +390,13 @@ pub(crate) struct Round3 { composition_poly_parts_ood_evaluation: Vec>, } -/// A container for the results of the fourth round of the STARK Prove protocol. -pub(crate) struct Round4, E: IsField> { - /// The final value resulting from folding the Deep composition polynomial all the way down to a constant value. - fri_last_value: FieldElement, - /// The commitments to the fold polynomials of the inner layers of FRI. - fri_layers_merkle_roots: Vec, - /// The values and proofs of validity of the evaluations of the trace polynomials and the composition polynomials - /// parts at the domain values corresponding to the FRI query challenges and their symmetric counterparts. - deep_poly_openings: DeepPolynomialOpenings, - /// The values and proofs of validity of the evaluations of the fold polynomials of the inner - /// layers of FRI at the values corresponding to the symmetrics of the FRI query challenges. - query_list: Vec>, - /// The proof of work nonce. - nonce: Option, +/// DEEP composition coefficients derived from the challenge 𝛾, sampled at the end +/// of Rounds 2-3 so Round 4 only builds the DEEP LDE and runs FRI. +pub(crate) struct DeepCoeffs { + /// Coefficients for the composition-polynomial-part terms. + gammas: Vec>, + /// Per-trace-column coefficient chunks for the trace terms. + trace_term_coeffs: Vec>>, } /// Returns the evaluations of the polynomial `p` over the lde domain defined by the given @@ -402,6 +478,27 @@ where result } +/// Hash a single already-gathered LDE row into its leaf [`Commitment`], using +/// the exact byte layout and hasher of [`keccak_leaves_bit_reversed`]: BE +/// concatenation of the row's columns, then `hash_bytes`. +/// +/// Used by the streaming leaf-drop opener to regenerate the one leaf-level +/// sibling needed per opened position (the rest of the path is read from the +/// retained inner nodes). Producing a different digest here than the build-time +/// `keccak_leaves_bit_reversed` would yield a wrong (non-byte-identical) proof. +pub fn keccak_leaf_from_row(row: &[FieldElement]) -> Commitment +where + E: IsField, + FieldElement: AsBytes + ByteConversion, +{ + let byte_len = as ByteConversion>::BYTE_LEN; + let mut buf = vec![0u8; row.len() * byte_len]; + for (col_idx, value) in row.iter().enumerate() { + value.write_bytes_be(&mut buf[col_idx * byte_len..(col_idx + 1) * byte_len]); + } + BatchedMerkleTreeBackend::::hash_bytes(&buf) +} + /// Compute Keccak-256 leaf hashes for `commit_composition_polynomial`: one /// leaf per row-pair, where leaf `i` hashes the BE concatenation of /// `parts[..][br_0] ++ parts[..][br_1]` with @@ -615,6 +712,7 @@ pub trait IsStarkProver< domain: &Domain, twiddles: &LdeTwiddles, precomputed: Option<(Commitment, usize)>, + leaf_drop: bool, #[cfg(feature = "disk-spill")] storage_mode: StorageMode, ) -> Result<(TableCommit, Vec>>), ProvingError> where @@ -646,7 +744,7 @@ pub trait IsStarkProver< tree.spill_nodes_to_disk() .map_err(|e| ProvingError::DiskSpill(format!("main Merkle tree: {e}")))?; } - TableCommit::plain(tree, root) + TableCommit::plain(tree, root, leaf_drop) } Some((expected_precomputed_root, num_cols)) => { #[allow(unused_mut)] @@ -676,6 +774,7 @@ pub trait IsStarkProver< precomputed_tree, precomputed_root, num_cols, + leaf_drop, ) } }; @@ -688,9 +787,9 @@ pub trait IsStarkProver< /// Recompute Round1 from the trace, reusing the Merkle trees stored in commitments. /// - /// Only used by `run_debug_checks` — Phase D consumes the cached LDE - /// directly and does not go through this path. - #[cfg(feature = "debug-checks")] + /// Used by `run_debug_checks` and by streaming "retire-LDE" mode + /// ([`streaming_retire_lde`]), where the cached LDE was dropped after commit + /// and is rebuilt on demand from the still-resident trace. fn reconstruct_round1( air: &dyn AIR, trace: &TraceTable, @@ -720,25 +819,56 @@ pub trait IsStarkProver< /// Reconstruct Round1 for every table, print the bus balance report, and /// validate each trace. Called once after Phase C commits. #[cfg(feature = "debug-checks")] + #[allow(clippy::too_many_arguments)] fn run_debug_checks( air_trace_pairs: &[AirTracePair<'_, Field, FieldExtension, PI>], commitments: &[Round1Commitments], domains: &[Arc>], twiddle_caches: &[Arc>], + provider: Option<&dyn TraceProvider>, + stream_traces: bool, + lookup_challenges: &[FieldElement], ) where FieldElement: AsBytes, FieldElement: AsBytes, PI: Send + Sync + Clone, { + // For retired tables the resident trace is an empty placeholder; rebuild + // each retired table's main+aux trace here so debug checks see the real + // data (matching what the proving rounds reconstruct). + let rebuilt: Vec>> = air_trace_pairs + .iter() + .enumerate() + .map(|(idx, (air, _, _))| { + if stream_traces && provider.is_some_and(|p| p.is_retired(idx)) { + let mut t = provider.unwrap().build_main(idx); + if air.has_aux_trace() { + air.build_auxiliary_trace(&mut t, lookup_challenges); + } + Some(t) + } else { + None + } + }) + .collect(); + let trace_at = |idx: usize| -> &TraceTable { + match &rebuilt[idx] { + Some(t) => t, + None => air_trace_pairs[idx].1, + } + }; + let mut temp_results: Vec> = Vec::with_capacity(air_trace_pairs.len()); - for (((air, trace, _), commitment), (domain, twiddles)) in air_trace_pairs + for (idx, (((air, _, _), commitment), (domain, twiddles))) in air_trace_pairs .iter() .zip(commitments.iter()) .zip(domains.iter().zip(twiddle_caches.iter())) + .enumerate() { - let result = Self::reconstruct_round1(*air, *trace, domain, commitment, twiddles) - .expect("reconstruct_round1 failed in debug-checks"); + let result = + Self::reconstruct_round1(*air, trace_at(idx), domain, commitment, twiddles) + .expect("reconstruct_round1 failed in debug-checks"); temp_results.push(result); } @@ -748,15 +878,16 @@ pub trait IsStarkProver< .collect(); print_bus_balance_report(&all_bus_public_inputs); - for (((air, trace, pub_inputs), round_1_result), domain) in air_trace_pairs + for (idx, (((air, _, pub_inputs), round_1_result), domain)) in air_trace_pairs .iter() .zip(temp_results.iter()) .zip(domains.iter()) + .enumerate() { validate_trace( *air, *pub_inputs, - *trace, + trace_at(idx), domain, &round_1_result.rap_challenges, round_1_result.bus_public_inputs.as_ref(), @@ -1018,29 +1149,24 @@ pub trait IsStarkProver< } /// Returns the result of the fourth round of the STARK Prove protocol. - fn round_4_compute_and_run_fri_on_the_deep_composition_polynomial( + /// Sample the DEEP composition challenge 𝛾 and expand it into the per-trace-term + /// and per-composition-part coefficients. Done at the end of Rounds 2-3 (from the + /// table's fork) so Round 4 only builds the DEEP LDE and runs FRI. + fn sample_deep_coeffs( air: &dyn AIR, - domain: &Domain, - round_1_result: &Round1, round_2_result: &Round2, - round_3_result: &Round3, - z: &FieldElement, transcript: &mut impl IsStarkTranscript, - ) -> Round4 + ) -> DeepCoeffs where FieldElement: AsBytes, FieldElement: AsBytes, { - let coset_offset_u64 = air.context().proof_options.coset_offset; - let coset_offset = FieldElement::::from(coset_offset_u64); - let gamma = transcript.sample_field_element(); let n_terms_composition_poly = round_2_result.lde_composition_poly_evaluations.len(); let num_terms_trace = air.context().transition_offsets.len() * air.step_size() * air.context().trace_columns; - // <<<< Receive challenges: 𝛾, 𝛾' let mut deep_composition_coefficients: Vec<_> = core::iter::successors(Some(FieldElement::one()), |x| Some(x * &gamma)) .take(n_terms_composition_poly + num_terms_trace) @@ -1053,86 +1179,11 @@ pub trait IsStarkProver< .map(|chunk| chunk.to_vec()) .collect(); - // <<<< Receive challenges: 𝛾ⱼ, 𝛾ⱼ' let gammas = deep_composition_coefficients; - // Compute p₀ (deep composition polynomial) as N evaluations on trace-size coset - #[cfg(feature = "instruments")] - let t_sub = Instant::now(); - let deep_evals = Self::compute_deep_composition_poly_evaluations( - &round_1_result.lde_trace, - round_2_result, - round_3_result, - z, - domain, - &domain.trace_primitive_root, - &gammas, - &trace_term_coeffs, - ); - #[cfg(feature = "instruments")] - let other_dur_1 = t_sub.elapsed(); - - // DEEP evaluations are already at 2N LDE points — just bit-reverse for FRI. - // No iFFT+FFT extension needed (Plonky3-style direct LDE computation). - let domain_size = domain.lde_roots_of_unity_coset.len(); - #[cfg(feature = "instruments")] - let t_sub = Instant::now(); - let mut lde_evals = deep_evals; - in_place_bit_reverse_permute(&mut lde_evals); - #[cfg(feature = "instruments")] - let r4_fft_dur = t_sub.elapsed(); - - // FRI commit phase from pre-computed evaluations - #[cfg(feature = "instruments")] - let t_sub = Instant::now(); - let (fri_last_value, fri_layers) = - fri::commit_phase_from_evaluations::( - domain.root_order as usize, - lde_evals, - transcript, - &coset_offset, - domain_size, - ); - #[cfg(feature = "instruments")] - let r4_merkle_dur = t_sub.elapsed(); - - // grinding: generate nonce and append it to the transcript - #[cfg(feature = "instruments")] - let t_sub = Instant::now(); - let security_bits = air.context().proof_options.grinding_factor; - let mut nonce = None; - if security_bits > 0 { - let nonce_value = grinding::generate_nonce(&transcript.state(), security_bits) - .expect("nonce not found"); - transcript.append_bytes(&nonce_value.to_be_bytes()); - nonce = Some(nonce_value); - } - - let number_of_queries = air.options().fri_number_of_queries; - let iotas = Self::sample_query_indexes(number_of_queries, domain, transcript); - - let query_list = fri::query_phase(&fri_layers, &iotas); - - let fri_layers_merkle_roots: Vec<_> = fri_layers - .iter() - .map(|layer| layer.merkle_tree.root) - .collect(); - - let deep_poly_openings = - Self::open_deep_composition_poly(domain, round_1_result, round_2_result, &iotas); - - #[cfg(feature = "instruments")] - { - let queries_dur = t_sub.elapsed(); - crate::instruments::store_r4_sub(r4_fft_dur, r4_merkle_dur, other_dur_1, queries_dur); - } - - Round4 { - fri_last_value, - fri_layers_merkle_roots, - deep_poly_openings, - query_list, - nonce, + DeepCoeffs { + gammas, + trace_term_coeffs, } } @@ -1347,15 +1398,31 @@ pub trait IsStarkProver< ) -> PolynomialOpenings where C: IsField, - FieldElement: AsBytes + Sync + Send, + FieldElement: AsBytes + Sync + Send + ByteConversion, G: Fn(usize) -> Vec>, { let domain_size = domain.lde_roots_of_unity_coset.len() as u64; let index = challenge * 2; let index_sym = challenge * 2 + 1; + // `proof_for` produces the same `Proof` whether or not the tree's leaves + // were dropped: with a full tree we read the stored leaf sibling; with a + // leaf-dropped tree (streaming mode) we regenerate that one sibling leaf + // by re-hashing the LDE row at `sibling_leaf_position`, matching the + // build-time `keccak_leaves_bit_reversed` byte layout exactly. + let proof_for = |pos: usize| { + if tree.leaves_dropped() { + let sib_leaf_pos = tree.sibling_leaf_position(pos); + let sib_row = gather(reverse_index(sib_leaf_pos, domain_size)); + let sibling = keccak_leaf_from_row(&sib_row); + tree.get_proof_by_pos_with_leaf_sibling(pos, sibling) + .unwrap() + } else { + tree.get_proof_by_pos(pos).unwrap() + } + }; PolynomialOpenings { - proof: tree.get_proof_by_pos(index).unwrap(), - proof_sym: tree.get_proof_by_pos(index_sym).unwrap(), + proof: proof_for(index), + proof_sym: proof_for(index_sym), evaluations: gather(reverse_index(index, domain_size)), evaluations_sym: gather(reverse_index(index_sym, domain_size)), } @@ -1444,7 +1511,41 @@ pub trait IsStarkProver< /// /// The transcript must be safely initialized before passing it to this method. fn multi_prove( + air_trace_pairs: Vec>, + transcript: &mut (impl IsStarkTranscript + Clone + Send), + #[cfg(feature = "disk-spill")] storage_mode: StorageMode, + ) -> Result, ProvingError> + where + FieldElement: AsBytes, + FieldElement: AsBytes, + PI: Send + Sync + Clone, + Field: Copy + 'static, + FieldExtension: Copy + 'static, + ::BaseType: SpillSafe, + ::BaseType: SpillSafe, + { + Self::multi_prove_with_provider( + air_trace_pairs, + None, + transcript, + #[cfg(feature = "disk-spill")] + storage_mode, + ) + } + + /// Like [`IsStarkProver::multi_prove`], but with an optional on-demand trace + /// [`TraceProvider`] for streaming "retire-traces" mode (C.2b). + /// + /// When `provider` is `None` this is exactly [`IsStarkProver::multi_prove`]. + /// When `provider` is `Some` *and* [`streaming_retire_lde`] is enabled, the + /// log-derived tables (those for which the provider reports `is_retired`) are + /// built on demand from the provider at each phase (main commit, aux build, + /// reconstruct) and dropped afterwards, so the resident traces in + /// `air_trace_pairs` for those indices may be empty placeholders. Tables the + /// provider does not retire (preprocessed, PAGE) use their resident traces. + fn multi_prove_with_provider( mut air_trace_pairs: Vec>, + provider: Option<&dyn TraceProvider>, transcript: &mut (impl IsStarkTranscript + Clone + Send), #[cfg(feature = "disk-spill")] storage_mode: StorageMode, ) -> Result, ProvingError> @@ -1488,8 +1589,22 @@ pub trait IsStarkProver< let mut domains = Vec::with_capacity(num_airs); let mut twiddle_caches: Vec>> = Vec::with_capacity(num_airs); - for (air, trace, _pub_inputs) in &*air_trace_pairs { - let trace_length = trace.num_rows(); + // Streaming "retire-traces" is active only when a provider is supplied + // AND the streaming env flag is on. Off-path (`provider == None`) every + // expression below collapses to today's resident-trace behaviour. + let stream_traces = provider.is_some() && streaming_retire_lde(); + // Returns the main-trace row count for table `idx`, sourced from the + // provider for retired tables (whose resident placeholder is empty) and + // from the resident trace otherwise. + let rows_of = |idx: usize| -> usize { + match provider { + Some(p) if stream_traces && p.is_retired(idx) => p.num_rows(idx), + _ => air_trace_pairs[idx].1.num_rows(), + } + }; + + for (idx, (air, _trace, _pub_inputs)) in air_trace_pairs.iter().enumerate() { + let trace_length = rows_of(idx); let blowup = air.options().blowup_factor as usize; let coset_offset = air.options().coset_offset; let key = (trace_length, blowup, coset_offset); @@ -1518,6 +1633,7 @@ pub trait IsStarkProver< drop(domain_cache); let k = table_parallelism().min(num_airs).max(1); + let stream_lde = streaming_retire_lde(); // Spill main traces to mmap before Round 1 LDE. #[cfg(feature = "disk-spill")] @@ -1568,14 +1684,28 @@ pub trait IsStarkProver< let domain = &domains[idx]; let twiddles = &twiddle_caches[idx]; + // Retire-traces: build this table's main trace on demand and + // commit it, then drop it at the end of this closure. Other + // tables (preprocessed/PAGE, or non-streaming) use the + // resident trace borrowed above. + let retired_main; + let main_trace: &TraceTable = + if stream_traces && provider.is_some_and(|p| p.is_retired(idx)) { + retired_main = provider.unwrap().build_main(idx); + &retired_main + } else { + trace + }; + let precomputed = air .is_preprocessed() .then(|| (air.precomputed_commitment(), air.num_precomputed_columns())); Self::commit_main_trace( - *trace, + main_trace, domain, twiddles, precomputed, + stream_lde, #[cfg(feature = "disk-spill")] storage_mode, ) @@ -1590,7 +1720,9 @@ pub trait IsStarkProver< } transcript.append_bytes(&commit.root); main_commits.push(commit); - main_ldes.push(cached_main); + // Streaming retire-LDE: drop the main LDE now; Rounds 2-4 recompute + // it on demand from the resident trace. + main_ldes.push(if stream_lde { Vec::new() } else { cached_main }); } } @@ -1630,13 +1762,18 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] let phase_start = Instant::now(); + // For retired tables the aux trace is built on demand inside the Pass-2 + // loop (so it can be committed and dropped per table); their bus inputs + // are filled there. Non-retired tables build aux into their resident + // trace here as before. #[cfg(feature = "parallel")] - let aux_iter = air_trace_pairs.par_iter_mut(); + let aux_iter = air_trace_pairs.par_iter_mut().enumerate(); #[cfg(not(feature = "parallel"))] - let aux_iter = air_trace_pairs.iter_mut(); - let bus_inputs_vec: Vec>> = aux_iter - .map(|(air, trace, _)| { - if air.has_aux_trace() { + let aux_iter = air_trace_pairs.iter_mut().enumerate(); + let mut bus_inputs_vec: Vec>> = aux_iter + .map(|(idx, (air, trace, _))| { + let retired = stream_traces && provider.is_some_and(|p| p.is_retired(idx)); + if air.has_aux_trace() && !retired { air.build_auxiliary_trace(*trace, &lookup_challenges) } else { None @@ -1673,6 +1810,16 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] let phase_start = Instant::now(); + // Capture the pre-fork shared transcript state. Phase D (batched FRI) + // clones this per chunk and replays chunk-local data + // (table_contributions, composition roots, all chunk-mate OOD + // evaluations) canonically to derive each bucket's `delta_fri` and + // query iotas. The verifier reconstructs an identical seed from proof + // data only. This is the shared state after Phase B (LogUp challenges + // sampled), before any per-table fork — it does NOT include per-table + // aux roots (those live only in the per-table forks below). + let pre_fork_transcript = transcript.clone(); + // Pre-fork all transcripts (cheap, sequential — must match verifier ordering) let mut table_transcripts: Vec<_> = (0..num_airs) .map(|idx| { @@ -1708,6 +1855,25 @@ pub trait IsStarkProver< let domain = &domains[idx]; let twiddles = &twiddle_caches[idx]; + // Retire-traces: rebuild this table's main trace, build its + // aux on top (the resident trace is an empty placeholder), and + // capture the bus inputs that Pass 1 skipped. The owned trace + // (main+aux) is dropped at the end of this closure. Other + // tables read aux columns from their resident trace. + let retired = stream_traces && provider.is_some_and(|p| p.is_retired(idx)); + let mut retired_bus: Option> = None; + let retired_trace; + let trace: &TraceTable = if retired { + let mut t = provider.unwrap().build_main(idx); + if air.has_aux_trace() { + retired_bus = air.build_auxiliary_trace(&mut t, &lookup_challenges); + } + retired_trace = t; + &retired_trace + } else { + trace + }; + if air.has_aux_trace() { let lde_size = domain.interpolation_domain_size * domain.blowup_factor; let mut columns = trace.extract_columns_aux(lde_size); @@ -1738,20 +1904,24 @@ pub trait IsStarkProver< ProvingError::DiskSpill(format!("aux Merkle tree: {e}")) })?; } - Ok((Some(TableCommit::plain(tree, root)), columns)) + Ok((Some(TableCommit::plain(tree, root, stream_lde)), columns, retired_bus)) } else { - Ok((None, Vec::new())) + Ok((None, Vec::new(), retired_bus)) } }) .collect(); // Sequential: append aux roots to forked transcripts for (j, result) in chunk_aux.into_iter().enumerate() { - let (aux_commit, cached_aux) = result?; + let (aux_commit, cached_aux, retired_bus) = result?; if let Some(ref c) = aux_commit { table_transcripts[chunk_start + j].append_bytes(&c.root); } - aux_results.push((aux_commit, cached_aux)); + // Retired tables compute their bus inputs here (Pass 1 skipped them). + if retired_bus.is_some() { + bus_inputs_vec[chunk_start + j] = retired_bus; + } + aux_results.push((aux_commit, if stream_lde { Vec::new() } else { cached_aux })); } } @@ -1786,7 +1956,15 @@ pub trait IsStarkProver< } #[cfg(feature = "debug-checks")] - Self::run_debug_checks(&air_trace_pairs, &commitments, &domains, &twiddle_caches); + Self::run_debug_checks( + &air_trace_pairs, + &commitments, + &domains, + &twiddle_caches, + provider, + stream_traces, + &lookup_challenges, + ); // ===================================================================== // Rounds 2-4: Parallel per-table proving in chunks of K @@ -1806,6 +1984,9 @@ pub trait IsStarkProver< )> = Vec::with_capacity(num_airs); let mut proofs = Vec::with_capacity(num_airs); + // Per-(chunk, lde_size-bucket) batched FRI instances, outer index = chunk. + let mut fri_chunk_buckets: Vec>> = + Vec::with_capacity(num_airs.div_ceil(k)); let mut lde_drain = cached_ldes.into_iter(); for chunk_start in (0..num_airs).step_by(k) { let chunk_end = (chunk_start + k).min(num_airs); @@ -1814,40 +1995,73 @@ pub trait IsStarkProver< let chunk_ldes: Vec> = lde_drain.by_ref().take(chunk_size).collect(); let chunk_commitments = &commitments[chunk_start..chunk_end]; - let chunk_transcripts = &mut table_transcripts[chunk_start..chunk_end]; - #[cfg(feature = "parallel")] - let iter = chunk_ldes - .into_par_iter() - .zip(chunk_commitments.par_iter()) - .zip(chunk_transcripts.par_iter_mut()) - .enumerate(); - #[cfg(not(feature = "parallel"))] - let iter = chunk_ldes - .into_iter() - .zip(chunk_commitments.iter()) - .zip(chunk_transcripts.iter_mut()) - .enumerate(); - - let chunk_results: Vec> = iter - .map(|(j, ((lde, commitment), table_transcript))| { + // ---- Pass 1: per-table Rounds 2-3 ---------------------------------- + // Advance each table's forked transcript through the OOD evaluations. + // No FRI yet: the intermediate results are collected so Round 4 can run + // as a separate per-chunk phase (which a later step batches across the + // chunk's tables into a single FRI per lde_size bucket). + let pass1: Vec> = { + let chunk_transcripts = &mut table_transcripts[chunk_start..chunk_end]; + + #[cfg(feature = "parallel")] + let iter = chunk_ldes + .into_par_iter() + .zip(chunk_commitments.par_iter()) + .zip(chunk_transcripts.par_iter_mut()) + .enumerate(); + #[cfg(not(feature = "parallel"))] + let iter = chunk_ldes + .into_iter() + .zip(chunk_commitments.iter()) + .zip(chunk_transcripts.iter_mut()) + .enumerate(); + + iter.map(|(j, ((lde, commitment), table_transcript))| { let idx = chunk_start + j; let (air, trace, pub_inputs) = &air_trace_pairs[idx]; - let _ = trace; // used by instruments let domain = &domains[idx]; #[cfg(feature = "instruments")] let table_start = Instant::now(); - // Build Round1 from cached LDE (consumed by value, no recomputation). - let round_1_result = - commitment.build_round1(lde, air.step_size(), domain.blowup_factor); + // Build Round1 from the cached LDE (consumed by value), or, in + // streaming retire-LDE mode, recompute it from the trace. The + // trace is the resident one unless this table is retired, in + // which case it is rebuilt on demand (main + aux) and dropped + // after the reconstruct. Determinism (C.2a) makes this trace + // byte-identical to the Phase A/C builds, so the LDE matches. + let round_1_result = if stream_lde { + let _ = lde; // empty placeholder when streaming + let retired = + stream_traces && provider.is_some_and(|p| p.is_retired(idx)); + let retired_trace; + let recon_trace: &TraceTable = if retired { + let mut t = provider.unwrap().build_main(idx); + if air.has_aux_trace() { + air.build_auxiliary_trace(&mut t, &lookup_challenges); + } + retired_trace = t; + &retired_trace + } else { + trace + }; + Self::reconstruct_round1( + *air, + recon_trace, + domain, + commitment, + &twiddle_caches[idx], + )? + } else { + commitment.build_round1(lde, air.step_size(), domain.blowup_factor) + }; if let Some(ref bpi) = round_1_result.bus_public_inputs { table_transcript.append_field_element(&bpi.table_contribution); } - let proof = Self::prove_rounds_2_to_4( + let (round_2_result, round_3_result, z, deep_coeffs) = Self::prove_rounds_2_to_3( *air, *pub_inputs, &round_1_result, @@ -1856,32 +2070,304 @@ pub trait IsStarkProver< )?; #[cfg(feature = "instruments")] - let table_timing = { - let sub_ops = crate::instruments::take_round_sub_ops().unwrap_or_default(); + let instr1 = { + let zero = std::time::Duration::ZERO; ( air.name().to_string(), trace.num_rows(), table_start.elapsed(), - sub_ops, + crate::instruments::take_r2_sub().unwrap_or((zero, zero, zero)), + crate::instruments::take_r3_ood().unwrap_or(zero), ) }; #[cfg(feature = "instruments")] - return Ok((proof, table_timing)); + return Ok((round_1_result, round_2_result, round_3_result, z, deep_coeffs, instr1)); #[cfg(not(feature = "instruments"))] - Ok(proof) + Ok((round_1_result, round_2_result, round_3_result, z, deep_coeffs)) }) - .collect(); + .collect() + }; + let intermediates = pass1.into_iter().collect::, ProvingError>>()?; + + // ---- Pass 2: per-(chunk, lde_size) batched FRI -------------------- + // Group the chunk's tables into lde_size buckets. For each bucket, + // derive a single `delta_fri` from a shared `bucket_seed`, fold each + // member's DEEP-composition LDE with successive powers of `delta_fri` + // into one polynomial, run ONE FRI commit + grinding + query, and + // produce per-table DEEP openings at the bucket-shared iotas. + // + // The per-table forks (`table_transcripts`) are NOT advanced past the + // OOD evaluations here: FRI challenges come exclusively from the + // chunk-shared `bucket_seed` below, so prover and verifier derive + // byte-identical `delta_fri`/iotas. + + // Unpack `intermediates` (chunk-local order) into parallel vectors. + #[cfg(feature = "instruments")] + let mut chunk_instr1: Vec<_> = Vec::with_capacity(chunk_size); + let mut chunk_round1: Vec> = + Vec::with_capacity(chunk_size); + let mut chunk_round2: Vec> = Vec::with_capacity(chunk_size); + let mut chunk_round3: Vec> = Vec::with_capacity(chunk_size); + let mut chunk_z: Vec> = Vec::with_capacity(chunk_size); + let mut chunk_deep_coeffs: Vec> = + Vec::with_capacity(chunk_size); + for intermediate in intermediates { + #[cfg(feature = "instruments")] + let (round_1_result, round_2_result, round_3_result, z, deep_coeffs, instr1) = + intermediate; + #[cfg(not(feature = "instruments"))] + let (round_1_result, round_2_result, round_3_result, z, deep_coeffs) = intermediate; + chunk_round1.push(round_1_result); + chunk_round2.push(round_2_result); + chunk_round3.push(round_3_result); + chunk_z.push(z); + chunk_deep_coeffs.push(deep_coeffs); + #[cfg(feature = "instruments")] + chunk_instr1.push(instr1); + } - for result in chunk_results { + // ---- Build the chunk-shared bucket seed ---------------------------- + // bucket_seed byte order (must match the verifier exactly): + // 1. pre-fork shared transcript state (after Phase B) + // 2. for each chunk-local j (ascending): table_contribution (if any) + // 3. for each chunk-local j (ascending): composition_poly_root + // 4. for each chunk-local j (ascending): all trace_ood_evaluations + // columns (column-major), then composition_poly_parts_ood + let mut bucket_seed = pre_fork_transcript.clone(); + for r1 in chunk_round1.iter() { + if let Some(ref bpi) = r1.bus_public_inputs { + bucket_seed.append_field_element(&bpi.table_contribution); + } + } + for r2 in chunk_round2.iter() { + bucket_seed.append_bytes(&r2.composition_poly_root); + } + for r3 in chunk_round3.iter() { + for col in r3.trace_ood_evaluations.columns().iter() { + for elem in col.iter() { + bucket_seed.append_field_element(elem); + } + } + for elem in r3.composition_poly_parts_ood_evaluation.iter() { + bucket_seed.append_field_element(elem); + } + } + + // ---- Bucket by lde_size (first-encounter order) -------------------- + let mut bucket_members: Vec> = Vec::new(); + let mut bucket_lde_sizes: Vec = Vec::new(); + for j in 0..chunk_size { + let lde_size = domains[chunk_start + j].lde_roots_of_unity_coset.len(); + match bucket_lde_sizes.iter().position(|&s| s == lde_size) { + Some(b) => bucket_members[b].push(j), + None => { + bucket_lde_sizes.push(lde_size); + bucket_members.push(vec![j]); + } + } + } + + let mut chunk_buckets: Vec> = + Vec::with_capacity(bucket_members.len()); + // Per chunk-local index: the bucket-shared iotas used for openings. + let mut iotas_for: Vec>> = (0..chunk_size).map(|_| None).collect(); + #[cfg(feature = "instruments")] + let mut fri_sub_for: Vec< + Option<(std::time::Duration, std::time::Duration, std::time::Duration)>, + > = (0..chunk_size).map(|_| None).collect(); + + for (members, &lde_size) in bucket_members.iter().zip(bucket_lde_sizes.iter()) { + let mut bt = bucket_seed.clone(); + bt.append_bytes(&(lde_size as u64).to_le_bytes()); + let delta_fri: FieldElement = bt.sample_field_element(); + + let leader_idx = chunk_start + members[0]; + let (leader_air, _, _) = &air_trace_pairs[leader_idx]; + let leader_domain = &domains[leader_idx]; + let coset_offset = FieldElement::::from( + leader_air.context().proof_options.coset_offset, + ); + + // Streaming bucket combine: build each member's DEEP LDE one at a + // time, fold into the accumulator with delta_fri^i, then drop. + // Peak DEEP memory inside this loop: 2 × |LDE|. + #[cfg(feature = "instruments")] + let mut deep_comp_dur = std::time::Duration::ZERO; + #[cfg(feature = "instruments")] + let mut deep_extend_dur = std::time::Duration::ZERO; + let mut combined: Vec> = Vec::new(); + let mut delta_power = FieldElement::::one(); + for (i_local, &j) in members.iter().enumerate() { + let idx = chunk_start + j; + let domain_j = &domains[idx]; + + #[cfg(feature = "instruments")] + let t_sub = Instant::now(); + let deep_evals = Self::compute_deep_composition_poly_evaluations( + &chunk_round1[j].lde_trace, + &chunk_round2[j], + &chunk_round3[j], + &chunk_z[j], + domain_j, + &domain_j.trace_primitive_root, + &chunk_deep_coeffs[j].gammas, + &chunk_deep_coeffs[j].trace_term_coeffs, + ); + #[cfg(feature = "instruments")] + { + deep_comp_dur += t_sub.elapsed(); + } + + // DEEP evaluations are already at the LDE points; bit-reverse + // for FRI (no FFT extension needed). + #[cfg(feature = "instruments")] + let t_sub = Instant::now(); + let mut deep_lde = deep_evals; + in_place_bit_reverse_permute(&mut deep_lde); + #[cfg(feature = "instruments")] + { + deep_extend_dur += t_sub.elapsed(); + } + + debug_assert_eq!(deep_lde.len(), lde_size); + if i_local == 0 { + // First member: assign directly to avoid mul-by-one. + combined = deep_lde; + } else { + for (acc, src) in combined.iter_mut().zip(deep_lde.iter()) { + *acc = &*acc + &delta_power * src; + } + } + delta_power = &delta_power * &delta_fri; + } + + #[cfg(feature = "instruments")] + let t_sub = Instant::now(); + let (last_value, fri_layers) = + fri::commit_phase_from_evaluations::( + leader_domain.root_order as usize, + combined, + &mut bt, + &coset_offset, + lde_size, + ); + #[cfg(feature = "instruments")] + let fri_commit_dur = t_sub.elapsed(); + + // grinding: generate nonce and append it to the bucket transcript. + let security_bits = leader_air.context().proof_options.grinding_factor; + let nonce = if security_bits > 0 { + let nonce_value = grinding::generate_nonce(&bt.state(), security_bits) + .expect("bucket-FRI grinding nonce not found"); + bt.append_bytes(&nonce_value.to_be_bytes()); + Some(nonce_value) + } else { + None + }; + + let number_of_queries = leader_air.options().fri_number_of_queries; + #[cfg(feature = "instruments")] + let t_sub = Instant::now(); + let iotas = + Self::sample_query_indexes(number_of_queries, leader_domain, &mut bt); + let decommitments = fri::query_phase(&fri_layers, &iotas); + #[cfg(feature = "instruments")] + let queries_dur = t_sub.elapsed(); + let layer_roots: Vec = fri_layers + .iter() + .map(|layer| layer.merkle_tree.root) + .collect(); + + chunk_buckets.push(crate::proof::stark::ChunkBucketFri { + lde_size: lde_size as u32, + members: members.clone(), + layer_roots, + last_value, + decommitments, + nonce, + }); + + for &j in members.iter() { + iotas_for[j] = Some(iotas.clone()); + } + // Attribute the bucket's FRI sub-timings to the leader member. #[cfg(feature = "instruments")] { - let (proof, timing) = result?; - proofs.push(proof); - table_timings.push(timing); + fri_sub_for[members[0]] = Some(( + deep_comp_dur, + deep_extend_dur, + fri_commit_dur + queries_dur, + )); + for &j in members.iter().skip(1) { + fri_sub_for[j] = Some(( + std::time::Duration::ZERO, + std::time::Duration::ZERO, + std::time::Duration::ZERO, + )); + } + } + } + fri_chunk_buckets.push(chunk_buckets); + + // ---- Per chunk-mate: open DEEP at bucket-shared iotas + assemble --- + for j in 0..chunk_size { + let idx = chunk_start + j; + let (_air, _trace, pub_inputs) = &air_trace_pairs[idx]; + let domain = &domains[idx]; + let round_1_result = &chunk_round1[j]; + let round_2_result = &chunk_round2[j]; + let round_3_result = &chunk_round3[j]; + let iotas = iotas_for[j] + .as_ref() + .expect("every chunk-mate belongs to a bucket"); + + let deep_poly_openings = Self::open_deep_composition_poly( + domain, + round_1_result, + round_2_result, + iotas, + ); + + let proof = StarkProof { + lde_trace_main_merkle_root: round_1_result.main.root, + lde_trace_aux_merkle_root: round_1_result.aux.as_ref().map(|x| x.root), + lde_trace_precomputed_merkle_root: round_1_result.main.precomputed_root, + trace_ood_evaluations: round_3_result.trace_ood_evaluations.clone(), + composition_poly_root: round_2_result.composition_poly_root, + composition_poly_parts_ood_evaluation: round_3_result + .composition_poly_parts_ood_evaluation + .clone(), + deep_poly_openings, + bus_public_inputs: round_1_result.bus_public_inputs.clone(), + public_inputs: (*pub_inputs).clone(), + trace_length: domain.interpolation_domain_size, + }; + proofs.push(proof); + + #[cfg(feature = "instruments")] + { + let (name, rows, pass1_dur, (r2_constraints, r2_fft, r2_merkle), r3_ood) = + chunk_instr1[j].clone(); + let (deep_comp, deep_extend, fri_dur) = fri_sub_for[j] + .clone() + .unwrap_or(( + std::time::Duration::ZERO, + std::time::Duration::ZERO, + std::time::Duration::ZERO, + )); + let sub_ops = crate::instruments::TableSubOps { + constraints: r2_constraints, + comp_decompose: r2_fft, + comp_commit: r2_merkle, + ood: r3_ood, + deep_comp, + deep_extend, + fri_commit: fri_dur, + queries: std::time::Duration::ZERO, + }; + table_timings.push((name, rows, pass1_dur, sub_ops)); } - #[cfg(not(feature = "instruments"))] - proofs.push(result?); } } @@ -1901,17 +2387,23 @@ pub trait IsStarkProver< }); } - Ok(MultiProof { proofs }) + Ok(MultiProof { + proofs, + fri_chunk_buckets, + chunk_size: k as u32, + }) } - /// Generate a STARK proof for a single AIR/trace. - /// This is equivalent to calling `multi_prove` with a single-element slice. + /// Generate a STARK proof for a single AIR/trace, returned as a one-element + /// [`MultiProof`]. The batched-FRI bucket data lives at the multi-proof level + /// (`MultiProof::fri_chunk_buckets`), so single-table callers consume the + /// wrapper directly (chunk of size 1 = one bucket = one FRI instance). fn prove( air: &dyn AIR, trace: &mut TraceTable, pub_inputs: &PI, transcript: &mut (impl IsStarkTranscript + Clone + Send), - ) -> Result, ProvingError> + ) -> Result, ProvingError> where FieldElement: AsBytes, FieldElement: AsBytes, @@ -1928,26 +2420,35 @@ pub trait IsStarkProver< #[cfg(feature = "disk-spill")] StorageMode::Ram, ) - .map(|mut multi_proof| multi_proof.proofs.remove(0)) } // TODO: propagate errors instead of unwrap() in open_deep_composition_poly and FRI operations - /// Executes rounds 2-4 and generates a STARK proof for the trace `main_trace` with public inputs `pub_inputs`. - /// Warning: the transcript must be safely initializated before passing it to this method. - fn prove_rounds_2_to_4( + /// Rounds 2-3 (per-table): build the composition polynomial (Round 2) and + /// evaluate it plus the trace at the out-of-domain point `z` (Round 3), + /// appending the composition root and OOD evaluations to `transcript`. + /// No FRI is run here — Round 4 (per-table today, per-chunk batched FRI in + /// streaming) consumes the returned `(Round2, Round3, z)`. + #[allow(clippy::type_complexity)] + fn prove_rounds_2_to_3( air: &dyn AIR, pub_inputs: &PI, round_1_result: &Round1, transcript: &mut impl IsStarkTranscript, domain: &Domain, - ) -> Result, ProvingError> + ) -> Result< + ( + Round2, + Round3, + FieldElement, + DeepCoeffs, + ), + ProvingError, + > where FieldElement: AsBytes, FieldElement: AsBytes, PI: Send + Sync + Clone, { - info!("Started proof generation..."); - // =================================== // ==========| Round 2 |========== // =================================== @@ -2008,7 +2509,7 @@ pub trait IsStarkProver< &z, ); #[cfg(feature = "instruments")] - let round_3_dur = t_r3.elapsed(); + crate::instruments::store_r3_ood(t_r3.elapsed()); // >>>> Send values: tⱼ(zgᵏ) let trace_ood_evaluations_columns = round_3_result.trace_ood_evaluations.columns(); @@ -2023,75 +2524,11 @@ pub trait IsStarkProver< transcript.append_field_element(element); } - // =================================== - // ==========| Round 4 |========== - // =================================== + // <<<< Receive challenge: 𝛾 (DEEP) — sampled here so Round 4 only builds the + // DEEP LDE and runs FRI. Same transcript position as before the split. + let deep_coeffs = Self::sample_deep_coeffs(air, &round_2_result, transcript); - // Part of this round is running FRI, which is an interactive - // protocol on its own. Therefore we pass it the transcript - // to simulate the interactions with the verifier. - let round_4_result = Self::round_4_compute_and_run_fri_on_the_deep_composition_polynomial( - air, - domain, - round_1_result, - &round_2_result, - &round_3_result, - &z, - transcript, - ); - - #[cfg(feature = "instruments")] - { - let zero = std::time::Duration::ZERO; - let (r2_constraints, r2_fft, r2_merkle) = - crate::instruments::take_r2_sub().unwrap_or((zero, zero, zero)); - let (r4_fft, r4_merkle, r4_deep_comp, r4_queries) = - crate::instruments::take_r4_sub().unwrap_or((zero, zero, zero, zero)); - crate::instruments::store_round_sub_ops(crate::instruments::TableSubOps { - constraints: r2_constraints, - comp_decompose: r2_fft, - comp_commit: r2_merkle, - ood: round_3_dur, - deep_comp: r4_deep_comp, - deep_extend: r4_fft, - fri_commit: r4_merkle, - queries: r4_queries, - }); - } - - info!("End proof generation"); - - Ok(StarkProof { - // [t] - lde_trace_main_merkle_root: round_1_result.main.root, - // [t] - lde_trace_aux_merkle_root: round_1_result.aux.as_ref().map(|x| x.root), - // For preprocessed tables: commitment to precomputed columns only - lde_trace_precomputed_merkle_root: round_1_result.main.precomputed_root, - // tⱼ(zgᵏ) - trace_ood_evaluations: round_3_result.trace_ood_evaluations, - // [H₁] and [H₂] - composition_poly_root: round_2_result.composition_poly_root, - // Hᵢ(z^N) - composition_poly_parts_ood_evaluation: round_3_result - .composition_poly_parts_ood_evaluation, - // [pₖ] - fri_layers_merkle_roots: round_4_result.fri_layers_merkle_roots, - // pₙ - fri_last_value: round_4_result.fri_last_value, - // Open(p₀(D₀), 𝜐ₛ), Open(pₖ(Dₖ), −𝜐ₛ^(2ᵏ)) - query_list: round_4_result.query_list, - // Open(H₁(D_LDE, 𝜐₀), Open(H₂(D_LDE, 𝜐₀), Open(tⱼ(D_LDE), 𝜐₀) - // Open(H₁(D_LDE, -𝜐ᵢ), Open(H₂(D_LDE, -𝜐ᵢ), Open(tⱼ(D_LDE), -𝜐ᵢ) - deep_poly_openings: round_4_result.deep_poly_openings, - // nonce obtained from grinding - nonce: round_4_result.nonce, - // Bus interaction public inputs (for boundary constraints and bus balance check) - bus_public_inputs: round_1_result.bus_public_inputs.clone(), - // Public inputs for boundary constraints - public_inputs: pub_inputs.clone(), - trace_length: domain.interpolation_domain_size, - }) + Ok((round_2_result, round_3_result, z, deep_coeffs)) } } diff --git a/crypto/stark/src/tests/bus_tests/soundness_tests.rs b/crypto/stark/src/tests/bus_tests/soundness_tests.rs index fc718bf7c..bcedc7211 100644 --- a/crypto/stark/src/tests/bus_tests/soundness_tests.rs +++ b/crypto/stark/src/tests/bus_tests/soundness_tests.rs @@ -875,7 +875,7 @@ fn test_injected_bus_public_inputs_on_non_logup_air_rejected() { // Inject fake bus_public_inputs into a non-LogUp proof. // DummyAIR has has_trace_interaction() = false, so this must be rejected. - proof.bus_public_inputs = Some(BusPublicInputs { + proof.proofs[0].bus_public_inputs = Some(BusPublicInputs { table_contribution: FieldElement::::from(42u64), #[cfg(feature = "debug-checks")] per_bus_sums: Default::default(), diff --git a/crypto/stark/src/tests/prove_verify_roundtrip_tests.rs b/crypto/stark/src/tests/prove_verify_roundtrip_tests.rs index 4059ed481..da2f82106 100644 --- a/crypto/stark/src/tests/prove_verify_roundtrip_tests.rs +++ b/crypto/stark/src/tests/prove_verify_roundtrip_tests.rs @@ -248,3 +248,94 @@ fn create_mul_air( transition_constraints, ) } + +/// THE leaf-drop correctness invariant: a proof produced with streaming +/// leaf-drop ON must be byte-identical to one produced with full trees +/// (`LAMBDA_STREAM_LDE=0`). Leaf-drop only changes *how* the prover produces +/// the Merkle paths (regenerating the one leaf-level sibling on demand), never +/// the path contents — so the serialized `MultiProof` bytes must match exactly. +/// +/// The env var is process-global, so this test serializes its two proving runs +/// under a mutex and restores the prior value. Other tests default to the unset +/// (full-tree) behavior; a transient leaf-drop window in a concurrent test would +/// still yield valid, byte-identical proofs (that is exactly what this asserts). +#[test] +fn leaf_drop_proof_is_byte_identical_to_full_tree() { + use std::sync::Mutex; + static ENV_LOCK: Mutex<()> = Mutex::new(()); + + fn prove_once() -> Vec { + let cpu_main_columns = { + let add_column = vec![ + FE::one(), FE::zero(), FE::one(), FE::zero(), + FE::one(), FE::one(), FE::zero(), FE::zero(), + ]; + let mul_column = vec![ + FE::zero(), FE::one(), FE::zero(), FE::one(), + FE::zero(), FE::zero(), FE::one(), FE::one(), + ]; + let a_column = (1..=8u64).map(FE::from).collect::>(); + let b_column = (1..=8u64).map(|i| FE::from(i * 10)).collect::>(); + let c_column = vec![ + FE::from(11), FE::from(40), FE::from(33), FE::from(160), + FE::from(55), FE::from(66), FE::from(490), FE::from(640), + ]; + vec![add_column, mul_column, a_column, b_column, c_column] + }; + let mut cpu_trace = crate::trace::TraceTable::from_columns_main(cpu_main_columns, 1); + + let add_a = vec![FE::from(1), FE::from(3), FE::from(5), FE::from(6)]; + let add_b = vec![FE::from(10), FE::from(30), FE::from(50), FE::from(60)]; + let add_c = vec![FE::from(11), FE::from(33), FE::from(55), FE::from(66)]; + let add_m = vec![FE::one(), FE::one(), FE::one(), FE::one()]; + let mut add_trace = + crate::trace::TraceTable::from_columns_main(vec![add_a, add_b, add_c, add_m], 1); + + let mul_a = vec![FE::from(2), FE::from(4), FE::from(7), FE::from(8)]; + let mul_b = vec![FE::from(20), FE::from(40), FE::from(70), FE::from(80)]; + let mul_c = vec![FE::from(40), FE::from(160), FE::from(490), FE::from(640)]; + let mul_m = vec![FE::one(), FE::one(), FE::one(), FE::one()]; + let mut mul_trace = + crate::trace::TraceTable::from_columns_main(vec![mul_a, mul_b, mul_c, mul_m], 1); + + let proof_options = ProofOptions::default_test_options(); + let cpu_air = create_cpu_air(&proof_options); + let add_air = create_add_air(&proof_options); + let mul_air = create_mul_air(&proof_options); + + #[allow(clippy::type_complexity)] + let air_trace_pairs: Vec<( + &dyn AIR, + &mut crate::trace::TraceTable, + &(), + )> = vec![ + (&cpu_air, &mut cpu_trace, &()), + (&add_air, &mut add_trace, &()), + (&mul_air, &mut mul_trace, &()), + ]; + + let proofs = + multi_prove_ram(air_trace_pairs, &mut DefaultTranscript::::new(&[])).unwrap(); + serde_cbor::to_vec(&proofs).expect("serialize proofs") + } + + let _guard = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner()); + let prev = std::env::var("LAMBDA_STREAM_LDE").ok(); + + // SAFETY: single-threaded section guarded by ENV_LOCK; restored below. + unsafe { std::env::set_var("LAMBDA_STREAM_LDE", "0") }; + let full_tree_bytes = prove_once(); + + unsafe { std::env::set_var("LAMBDA_STREAM_LDE", "1") }; + let leaf_drop_bytes = prove_once(); + + match prev { + Some(v) => unsafe { std::env::set_var("LAMBDA_STREAM_LDE", v) }, + None => unsafe { std::env::remove_var("LAMBDA_STREAM_LDE") }, + } + + assert_eq!( + full_tree_bytes, leaf_drop_bytes, + "streaming leaf-drop proof must be byte-identical to the full-tree proof" + ); +} diff --git a/crypto/stark/src/tests/small_trace_tests.rs b/crypto/stark/src/tests/small_trace_tests.rs index 8373ae9d6..8495ce96c 100644 --- a/crypto/stark/src/tests/small_trace_tests.rs +++ b/crypto/stark/src/tests/small_trace_tests.rs @@ -19,7 +19,7 @@ type Felt = FieldElement; fn make_valid_simple_proof() -> ( SimpleAdditionAIR, - crate::proof::stark::StarkProof< + crate::proof::stark::MultiProof< GoldilocksField, GoldilocksField, SimpleAdditionPublicInputs, @@ -99,7 +99,7 @@ fn test_verify_fails_with_wrong_inputs() { let (air, mut proof) = make_valid_simple_proof(); // Tamper with the proof's public inputs - proof.public_inputs = SimpleAdditionPublicInputs { + proof.proofs[0].public_inputs = SimpleAdditionPublicInputs { a: Felt::from(99u64), // Wrong value - doesn't match trace b: Felt::from(2u64), }; @@ -124,11 +124,13 @@ fn test_verify_rejects_truncated_composition_poly_parts_ood() { let (air, mut proof) = make_valid_simple_proof(); assert!( - !proof.composition_poly_parts_ood_evaluation.is_empty(), + !proof.proofs[0] + .composition_poly_parts_ood_evaluation + .is_empty(), "test precondition: a valid proof has at least one composition poly part", ); // Drop one entry so the per-query opening has more parts than the header. - proof.composition_poly_parts_ood_evaluation.pop(); + proof.proofs[0].composition_poly_parts_ood_evaluation.pop(); assert!( !Verifier::verify( @@ -150,7 +152,7 @@ fn test_verify_rejects_opening_column_count_mismatch() { // Append a phantom extra evaluation column to the first query's // main-trace opening so the (base + aux) count exceeds `ood_evaluations_table_width`. - if let Some(opening) = proof.deep_poly_openings.first_mut() { + if let Some(opening) = proof.proofs[0].deep_poly_openings.first_mut() { let extra = opening .main_trace_polys .evaluations diff --git a/crypto/stark/src/verifier.rs b/crypto/stark/src/verifier.rs index 8091c8b32..ddc972d88 100644 --- a/crypto/stark/src/verifier.rs +++ b/crypto/stark/src/verifier.rs @@ -27,8 +27,6 @@ use math::{ }; use std::collections::HashMap; use std::marker::PhantomData; -#[cfg(feature = "instruments")] -use std::time::Instant; /// A default STARK verifier implementing `IsStarkVerifier`. pub struct Verifier< @@ -49,6 +47,7 @@ impl< /// A container holding the complete list of challenges sent to the prover along with the seed used /// to validate the proof-of-work nonce. +#[derive(Clone)] pub struct Challenges where FieldExtension: Send + Sync + IsField, @@ -238,55 +237,6 @@ pub trait IsStarkVerifier< composition_poly_claimed_ood_evaluation == composition_poly_ood_evaluation } - /// Reconstructs the Deep composition polynomial evaluations at the challenge indices values using the provided - /// openings of the trace polynomials and the composition polynomial parts. It then uses these to verify that the - /// FRI decommitments are valid and correspond to the Deep composition polynomial. - fn step_3_verify_fri( - proof: &StarkProof, - domain: &VerifierDomain, - challenges: &Challenges, - ) -> bool - where - FieldElement: AsBytes + Sync + Send, - FieldElement: AsBytes + Sync + Send, - { - let (deep_poly_evaluations, deep_poly_evaluations_sym) = - match Self::reconstruct_deep_composition_poly_evaluations_for_all_queries( - challenges, domain, proof, - ) { - Some(pair) => pair, - None => return false, - }; - - // verify FRI - let mut evaluation_point_inverse = challenges - .iotas - .iter() - .map(|iota| Self::query_challenge_to_evaluation_point(*iota, false, domain)) - .collect::>>(); - // Any zero evaluation point means a malformed query index, reject. - if FieldElement::inplace_batch_inverse(&mut evaluation_point_inverse).is_err() { - return false; - } - - proof - .query_list - .iter() - .zip(&challenges.iotas) - .zip(evaluation_point_inverse) - .enumerate() - .all(|(i, ((proof_s, iota_s), eval))| { - Self::verify_query_and_sym_openings( - proof, - &challenges.zetas, - *iota_s, - proof_s, - eval, - &deep_poly_evaluations[i], - &deep_poly_evaluations_sym[i], - ) - }) - } /// Returns the field element element of the domain `domain` corresponding to the given FRI query index challenge `iota`. /// Returns the LDE-coset element for FRI query challenge `iota`. The @@ -458,94 +408,6 @@ pub trait IsStarkVerifier< ) } - /// Verify a single FRI query - /// `zetas`: the vector of all challenges sent by the verifier to the prover at the commit - /// phase to fold polynomials. - /// `iota`: the index challenge of this FRI query. This index uniquely determines two elements 𝜐 and -𝜐 - /// of the evaluation domain of FRI layer 0. - /// `evaluation_point_inv`: precomputed value of 𝜐⁻¹. - /// `deep_composition_evaluation`: precomputed value of p₀(𝜐), where p₀ is the deep composition polynomial. - /// `deep_composition_evaluation_sym`: precomputed value of p₀(-𝜐), where p₀ is the deep composition polynomial. - fn verify_query_and_sym_openings( - proof: &StarkProof, - zetas: &[FieldElement], - iota: usize, - fri_decommitment: &FriDecommitment, - evaluation_point_inv: FieldElement, - deep_composition_evaluation: &FieldElement, - deep_composition_evaluation_sym: &FieldElement, - ) -> bool - where - FieldElement: AsBytes + Sync + Send, - FieldElement: AsBytes + Sync + Send, - { - let fri_layers_merkle_roots = &proof.fri_layers_merkle_roots; - let evaluation_point_vec: Vec> = - core::iter::successors(Some(evaluation_point_inv.square()), |evaluation_point| { - Some(evaluation_point.square()) - }) - .take(fri_layers_merkle_roots.len()) - .collect(); - - let p0_eval = deep_composition_evaluation; - let p0_eval_sym = deep_composition_evaluation_sym; - - // Reconstruct p₁(𝜐²) - let mut v = - (p0_eval + p0_eval_sym) + evaluation_point_inv * &zetas[0] * (p0_eval - p0_eval_sym); - let mut index = iota; - - // Handle case with 0 FRI layers (trace_length <= 2) - // In this case, the fold loop below doesn't iterate, so we need to verify - // the final value directly here. - if fri_layers_merkle_roots.is_empty() { - return v == proof.fri_last_value; - } - - // For each FRI layer, starting from the layer 1: use the proof to verify the validity of values pᵢ(−𝜐^(2ⁱ)) (given by the prover) and - // pᵢ(𝜐^(2ⁱ)) (computed on the previous iteration by the verifier). Then use them to obtain pᵢ₊₁(𝜐^(2ⁱ⁺¹)). - // Finally, check that the final value coincides with the given by the prover. - fri_layers_merkle_roots - .iter() - .enumerate() - .zip(&fri_decommitment.layers_auth_paths) - .zip(&fri_decommitment.layers_evaluations_sym) - .zip(evaluation_point_vec) - .fold( - true, - |result, - ( - (((i, merkle_root), auth_path_sym), evaluation_sym), - evaluation_point_inv, - )| { - // Verify opening Open(pᵢ(Dₖ), −𝜐^(2ⁱ)) and Open(pᵢ(Dₖ), 𝜐^(2ⁱ)). - // `v` is pᵢ(𝜐^(2ⁱ)). - // `evaluation_sym` is pᵢ(−𝜐^(2ⁱ)). - let openings_ok = Self::verify_fri_layer_openings( - merkle_root, - auth_path_sym, - &v, - evaluation_sym, - index, - ); - - // Update `v` with next value pᵢ₊₁(𝜐^(2ⁱ⁺¹)). - v = (&v + evaluation_sym) + evaluation_point_inv * &zetas[i + 1] * (&v - evaluation_sym); - - // Update index for next iteration. The index of the squares in the next layer - // is obtained by halving the current index. This is due to the bit-reverse - // ordering of the elements in the Merkle tree. - index >>= 1; - - if i < fri_decommitment.layers_evaluations_sym.len() - 1 { - result & openings_ok - } else { - // Check that final value is the given by the prover - result & (v == proof.fri_last_value) & openings_ok - } - }, - ) - } fn reconstruct_deep_composition_poly_evaluations_for_all_queries( challenges: &Challenges, @@ -810,47 +672,14 @@ pub trait IsStarkVerifier< } } - // ===================================================================== - // Phase C + Rounds 2-4: Forked per table - // ===================================================================== - // Each table gets an independent transcript fork (cloned from the shared - // state after Phase B, domain-separated by table index). This matches - // the prover's forking and makes per-table verification independent. - - for (idx, (air, proof)) in airs.iter().zip(&multi_proof.proofs).enumerate() { - // Must match prover: fork with domain separator for multi-table, - // use original transcript directly for single-table. - let num_tables = airs.len(); - let mut table_transcript = transcript.clone(); - if num_tables > 1 { - table_transcript.append_bytes(&(idx as u64).to_le_bytes()); - } - - // Phase C: replay aux commitment - if let Some(root) = proof.lde_trace_aux_merkle_root { - table_transcript.append_bytes(&root); - } - - // Bind table_contribution (L) to transcript, matching prover. - if let Some(ref bpi) = proof.bus_public_inputs { - table_transcript.append_field_element(&bpi.table_contribution); - } - - // Rounds 2-4: verify - if !Self::verify_rounds_2_to_4( - *air, - proof, - &mut table_transcript, - lookup_challenges.clone(), - ) { - error!( - "Table {} failed verify_rounds_2_to_4 (num_constraints={}, trace_cols={})", - idx, - air.context().num_transition_constraints, - air.context().trace_columns - ); - return false; - } + // Phase C + Rounds 2-3 (no FRI) per table, then Phase D batched FRI. + if !Self::verify_chunks_phase_c_d( + airs, + multi_proof, + transcript, + &lookup_challenges, + ) { + return false; } // ===================================================================== @@ -888,10 +717,12 @@ pub trait IsStarkVerifier< true } - /// Verify a single STARK proof. - /// This is equivalent to calling `multi_verify` with a single-element slice. + /// Verify a single-table proof, supplied as a one-element [`MultiProof`] + /// (the shape returned by `Prover::prove`). The batched-FRI bucket data + /// lives at the multi-proof level, so single-table verification consumes the + /// wrapper directly. fn verify( - proof: &StarkProof, + proof: &MultiProof, air: &dyn AIR, transcript: &mut (impl IsStarkTranscript + Clone), ) -> bool @@ -900,15 +731,16 @@ pub trait IsStarkVerifier< FieldElement: AsBytes + Sync + Send, PI: Clone, { - let multi_proof = MultiProof { - proofs: vec![proof.clone()], - }; - Self::multi_verify(&[air], &multi_proof, transcript, &FieldElement::zero()) + Self::multi_verify(&[air], proof, transcript, &FieldElement::zero()) } - /// Replays rounds 2, 3 and 4 of the protocol for a given proof, assuming round 1 has - /// already been replayed and the RAP challenges are known. - fn replay_rounds_after_round_1( + /// Replays rounds 2 and 3 of the protocol for a given proof, assuming round 1 + /// has already been replayed and the RAP challenges are known. Stops right + /// after sampling the DEEP gamma coefficients — FRI challenges (zetas, iotas, + /// grinding) are NOT derived here; they come from the chunk-shared bucket + /// seed in Phase D. The returned `Challenges` has empty `zetas`/`iotas` and a + /// zeroed `grinding_seed`, filled in per bucket by the caller. + fn replay_rounds_2_to_3( air: &dyn AIR, proof: &StarkProof, domain: &VerifierDomain, @@ -971,8 +803,11 @@ pub trait IsStarkVerifier< } // =================================== - // ==========| Round 4 |========== + // ==========| Round 4 (DEEP coeffs only) |========== // =================================== + // Sample the DEEP gamma coefficients. FRI commit/grinding/query sampling + // does NOT happen on this per-table fork — it is done per bucket from the + // chunk-shared seed in Phase D. let num_terms_composition_poly = proof.composition_poly_parts_ood_evaluation.len(); let num_terms_trace = @@ -995,158 +830,375 @@ pub trait IsStarkVerifier< // <<<< Receive challenges: 𝛾ⱼ, 𝛾ⱼ' let gammas = deep_composition_coefficients; - // FRI commit phase - let merkle_roots = &proof.fri_layers_merkle_roots; - let mut zetas = merkle_roots - .iter() - .map(|root| { - // >>>> Send challenge 𝜁ₖ - let element = transcript.sample_field_element(); - // <<<< Receive commitment: [pₖ] (the first one is [p₀]) - transcript.append_bytes(root); - element - }) - .collect::>>(); - - // >>>> Send challenge 𝜁ₙ₋₁ - zetas.push(transcript.sample_field_element()); - - // <<<< Receive value: pₙ - transcript.append_field_element(&proof.fri_last_value); - - // Receive grinding value - let security_bits = air.context().proof_options.grinding_factor; - let mut grinding_seed = [0u8; 32]; - if security_bits > 0 - && let Some(nonce_value) = proof.nonce - { - grinding_seed = transcript.state(); - transcript.append_bytes(&nonce_value.to_be_bytes()); - } - - // FRI query phase - // <<<< Send challenges 𝜄ₛ (iota_s) - let number_of_queries = air.options().fri_number_of_queries; - let iotas = Self::sample_query_indexes(number_of_queries, domain, transcript); - Challenges { z, boundary_coeffs, transition_coeffs, trace_term_coeffs, gammas, - zetas, - iotas, + zetas: Vec::new(), + iotas: Vec::new(), rap_challenges, - grinding_seed, + grinding_seed: [0u8; 32], } } - /// Verifies a single table after round 1 has been replayed. - fn verify_rounds_2_to_4( - air: &dyn AIR, - proof: &StarkProof, - transcript: &mut impl IsStarkTranscript, - rap_challenges: Vec>, + fn verify_chunks_phase_c_d( + airs: &[&dyn AIR], + multi_proof: &MultiProof, + transcript: &(impl IsStarkTranscript + Clone), + lookup_challenges: &[FieldElement], ) -> bool where FieldElement: AsBytes + Sync + Send, FieldElement: AsBytes + Sync + Send, { - let domain = new_verifier_domain(air, proof.trace_length); + let num_tables = airs.len(); + let pre_fork_transcript = transcript.clone(); - // Verify there are enough queries - if proof.query_list.len() < air.options().fri_number_of_queries { + // -- Phase C: per-table Rounds 2-3 (no FRI) -- + let mut domains: Vec> = Vec::with_capacity(num_tables); + let mut table_challenges: Vec> = Vec::with_capacity(num_tables); + for (idx, (air, proof)) in airs.iter().zip(&multi_proof.proofs).enumerate() { + let domain = new_verifier_domain(*air, proof.trace_length); + let mut table_transcript = transcript.clone(); + if num_tables > 1 { + table_transcript.append_bytes(&(idx as u64).to_le_bytes()); + } + if let Some(root) = proof.lde_trace_aux_merkle_root { + table_transcript.append_bytes(&root); + } + if let Some(ref bpi) = proof.bus_public_inputs { + table_transcript.append_field_element(&bpi.table_contribution); + } + let challenges = Self::replay_rounds_2_to_3( + *air, + proof, + &domain, + &mut table_transcript, + lookup_challenges.to_vec(), + ); + domains.push(domain); + table_challenges.push(challenges); + } + + // -- Phase D: per-(chunk, lde_size) batched FRI -- + let k = (multi_proof.chunk_size as usize).max(1); + let expected_num_chunks = num_tables.div_ceil(k); + if multi_proof.fri_chunk_buckets.len() != expected_num_chunks { + error!("fri_chunk_buckets chunk count mismatch"); return false; } + for (chunk_idx, chunk_start) in (0..num_tables).step_by(k).enumerate() { + let chunk_end = (chunk_start + k).min(num_tables); + let chunk_size = chunk_end - chunk_start; + let bucket_seed = + Self::build_bucket_seed(&pre_fork_transcript, multi_proof, chunk_start, chunk_size); + + let mut bucket_members: Vec> = Vec::new(); + let mut bucket_lde_sizes: Vec = Vec::new(); + for j in 0..chunk_size { + let lde_size = domains[chunk_start + j].lde_length; + match bucket_lde_sizes.iter().position(|&s| s == lde_size) { + Some(b) => bucket_members[b].push(j), + None => { + bucket_lde_sizes.push(lde_size); + bucket_members.push(vec![j]); + } + } + } - #[cfg(feature = "instruments")] - println!("- Started step 1: Recover challenges"); - #[cfg(feature = "instruments")] - let timer1 = Instant::now(); + let proof_buckets = &multi_proof.fri_chunk_buckets[chunk_idx]; + if proof_buckets.len() != bucket_members.len() { + error!("chunk {chunk_idx}: bucket count mismatch"); + return false; + } + for (b, (members, &lde_size)) in + bucket_members.iter().zip(bucket_lde_sizes.iter()).enumerate() + { + if !Self::verify_one_bucket( + airs, + multi_proof, + &domains, + &table_challenges, + &bucket_seed, + chunk_start, + chunk_idx, + b, + members, + lde_size, + ) { + return false; + } + } + } + true + } - let challenges = - Self::replay_rounds_after_round_1(air, proof, &domain, transcript, rap_challenges); + /// Build the chunk-shared `bucket_seed`. Byte order MUST match the prover: + /// pre-fork shared state, then for each chunk-local index (ascending): + /// table_contribution (if any), then composition_poly_root, then OOD evals + /// (all trace_ood columns column-major, then composition_poly_parts_ood). + fn build_bucket_seed( + pre_fork_transcript: &T, + multi_proof: &MultiProof, + chunk_start: usize, + chunk_size: usize, + ) -> T + where + T: IsStarkTranscript + Clone, + FieldElement: AsBytes, + FieldElement: AsBytes, + { + let mut bucket_seed = pre_fork_transcript.clone(); + for j in 0..chunk_size { + let proof = &multi_proof.proofs[chunk_start + j]; + if let Some(ref bpi) = proof.bus_public_inputs { + bucket_seed.append_field_element(&bpi.table_contribution); + } + } + for j in 0..chunk_size { + let proof = &multi_proof.proofs[chunk_start + j]; + bucket_seed.append_bytes(&proof.composition_poly_root); + } + for j in 0..chunk_size { + let proof = &multi_proof.proofs[chunk_start + j]; + for col in proof.trace_ood_evaluations.columns().iter() { + for elem in col.iter() { + bucket_seed.append_field_element(elem); + } + } + for elem in proof.composition_poly_parts_ood_evaluation.iter() { + bucket_seed.append_field_element(elem); + } + } + bucket_seed + } - // verify grinding - let security_bits = air.context().proof_options.grinding_factor; + /// Verify one (chunk, lde_size) bucket: derive its FRI challenges from the + /// chunk-shared `bucket_seed`, reconstruct each member's DEEP evaluations at + /// the bucket-shared iotas, combine them with `delta_fri` powers, verify the + /// batched FRI fold, and verify each member's per-table openings. + #[allow(clippy::too_many_arguments)] + fn verify_one_bucket( + airs: &[&dyn AIR], + multi_proof: &MultiProof, + domains: &[VerifierDomain], + table_challenges: &[Challenges], + bucket_seed: &(impl IsStarkTranscript + Clone), + chunk_start: usize, + chunk_idx: usize, + b: usize, + members: &[usize], + lde_size: usize, + ) -> bool + where + FieldElement: AsBytes + Sync + Send, + FieldElement: AsBytes + Sync + Send, + { + let bucket = &multi_proof.fri_chunk_buckets[chunk_idx][b]; + if bucket.members != *members || bucket.lde_size as usize != lde_size { + error!("chunk {chunk_idx} bucket {b}: members/lde_size mismatch"); + return false; + } + + let leader_idx = chunk_start + members[0]; + let leader_air = airs[leader_idx]; + let leader_domain = &domains[leader_idx]; + + let mut bt = bucket_seed.clone(); + bt.append_bytes(&(lde_size as u64).to_le_bytes()); + let delta_fri: FieldElement = bt.sample_field_element(); + + let mut zetas = bucket + .layer_roots + .iter() + .map(|root| { + let element = bt.sample_field_element(); + bt.append_bytes(root); + element + }) + .collect::>>(); + zetas.push(bt.sample_field_element()); + bt.append_field_element(&bucket.last_value); + + let security_bits = leader_air.context().proof_options.grinding_factor; if security_bits > 0 { - let nonce_is_valid = proof.nonce.is_some_and(|nonce_value| { - grinding::is_valid_nonce(&challenges.grinding_seed, nonce_value, security_bits) + let grinding_seed = bt.state(); + let nonce_is_valid = bucket.nonce.is_some_and(|nonce_value| { + grinding::is_valid_nonce(&grinding_seed, nonce_value, security_bits) }); - if !nonce_is_valid { #[cfg(not(feature = "test_fiat_shamir"))] - error!("Grinding factor not satisfied"); + error!("chunk {chunk_idx} bucket {b}: grinding factor not satisfied"); return false; } + if let Some(nonce_value) = bucket.nonce { + bt.append_bytes(&nonce_value.to_be_bytes()); + } } - #[cfg(feature = "instruments")] - let elapsed1 = timer1.elapsed(); - #[cfg(feature = "instruments")] - println!(" Time spent: {:?}", elapsed1); - - #[cfg(feature = "instruments")] - println!("- Started step 2: Verify claimed polynomial"); - #[cfg(feature = "instruments")] - let timer2 = Instant::now(); - - if !Self::step_2_verify_claimed_composition_polynomial(air, proof, &domain, &challenges) { - #[cfg(not(feature = "test_fiat_shamir"))] - error!("Composition Polynomial verification failed"); + let number_of_queries = leader_air.options().fri_number_of_queries; + let iotas = Self::sample_query_indexes(number_of_queries, leader_domain, &mut bt); + if bucket.decommitments.len() < number_of_queries { + error!("chunk {chunk_idx} bucket {b}: too few FRI decommitments"); return false; } - #[cfg(feature = "instruments")] - let elapsed2 = timer2.elapsed(); - #[cfg(feature = "instruments")] - println!(" Time spent: {:?}", elapsed2); - #[cfg(feature = "instruments")] - println!("- Started step 3: Verify FRI"); - #[cfg(feature = "instruments")] - let timer3 = Instant::now(); - - if !Self::step_3_verify_fri(proof, &domain, &challenges) { - #[cfg(not(feature = "test_fiat_shamir"))] - error!("FRI verification failed"); + let mut evaluation_point_inverse = iotas + .iter() + .map(|iota| Self::query_challenge_to_evaluation_point(*iota, false, leader_domain)) + .collect::>>(); + if FieldElement::inplace_batch_inverse(&mut evaluation_point_inverse).is_err() { + error!("chunk {chunk_idx} bucket {b}: zero FRI evaluation point"); return false; } - #[cfg(feature = "instruments")] - let elapsed3 = timer3.elapsed(); - #[cfg(feature = "instruments")] - println!(" Time spent: {:?}", elapsed3); + let num_queries = iotas.len(); + let mut combined_eval: Vec> = + vec![FieldElement::zero(); num_queries]; + let mut combined_eval_sym: Vec> = + vec![FieldElement::zero(); num_queries]; + let mut delta_power = FieldElement::::one(); + for (i_local, &j) in members.iter().enumerate() { + let idx = chunk_start + j; + let air = airs[idx]; + let proof = &multi_proof.proofs[idx]; + let domain = &domains[idx]; + + let mut challenges = table_challenges[idx].clone(); + challenges.iotas = iotas.clone(); + challenges.zetas = zetas.clone(); + + if !Self::step_2_verify_claimed_composition_polynomial(air, proof, domain, &challenges) { + #[cfg(not(feature = "test_fiat_shamir"))] + error!("chunk {chunk_idx} bucket {b}: table {idx} composition poly failed"); + return false; + } - #[cfg(feature = "instruments")] - println!("- Started step 4: Verify deep composition polynomial"); - #[cfg(feature = "instruments")] - let timer4 = Instant::now(); + let (member_eval, member_eval_sym) = + match Self::reconstruct_deep_composition_poly_evaluations_for_all_queries( + &challenges, + domain, + proof, + ) { + Some(pair) => pair, + None => { + error!("chunk {chunk_idx} bucket {b}: table {idx} DEEP reconstruct failed"); + return false; + } + }; - #[allow(clippy::let_and_return)] - if !Self::step_4_verify_trace_and_composition_openings(proof, &challenges) { - #[cfg(not(feature = "test_fiat_shamir"))] - error!("DEEP Composition Polynomial verification failed"); - return false; - } + if !Self::step_4_verify_trace_and_composition_openings(proof, &challenges) { + #[cfg(not(feature = "test_fiat_shamir"))] + error!("chunk {chunk_idx} bucket {b}: table {idx} openings failed"); + return false; + } - #[cfg(feature = "instruments")] - let elapsed4 = timer4.elapsed(); - #[cfg(feature = "instruments")] - println!(" Time spent: {:?}", elapsed4); + if i_local == 0 { + combined_eval = member_eval; + combined_eval_sym = member_eval_sym; + } else { + for q in 0..num_queries { + combined_eval[q] = &combined_eval[q] + &delta_power * &member_eval[q]; + combined_eval_sym[q] = + &combined_eval_sym[q] + &delta_power * &member_eval_sym[q]; + } + } + delta_power = &delta_power * &delta_fri; + } - #[cfg(feature = "instruments")] + for (q, ((iota, decommitment), eval_point_inv)) in iotas + .iter() + .zip(&bucket.decommitments) + .zip(evaluation_point_inverse) + .enumerate() { - let total_time = elapsed1 + elapsed2 + elapsed3 + elapsed4; - println!( - " Fraction of verifying time per step: {:.4} {:.4} {:.4} {:.4}", - elapsed1.as_nanos() as f64 / total_time.as_nanos() as f64, - elapsed2.as_nanos() as f64 / total_time.as_nanos() as f64, - elapsed3.as_nanos() as f64 / total_time.as_nanos() as f64, - elapsed4.as_nanos() as f64 / total_time.as_nanos() as f64 - ); + if !Self::verify_bucket_fri_query( + &bucket.layer_roots, + &bucket.last_value, + &zetas, + *iota, + decommitment, + eval_point_inv, + &combined_eval[q], + &combined_eval_sym[q], + ) { + #[cfg(not(feature = "test_fiat_shamir"))] + error!("chunk {chunk_idx} bucket {b}: FRI query {q} failed"); + return false; + } } - true } + + /// Verify a single batched-FRI query for a bucket. The combined DEEP + /// evaluations `D(𝜐)` / `D(-𝜐)` (already linearly combined across + /// bucket-mates with `delta_fri` powers) are folded against the bucket's + /// `layer_roots` / `last_value`. MMCS-free port of the per-table FRI query + /// verification. + #[allow(clippy::too_many_arguments)] + fn verify_bucket_fri_query( + layer_roots: &[Commitment], + last_value: &FieldElement, + zetas: &[FieldElement], + iota: usize, + fri_decommitment: &FriDecommitment, + evaluation_point_inv: FieldElement, + deep_composition_evaluation: &FieldElement, + deep_composition_evaluation_sym: &FieldElement, + ) -> bool + where + FieldElement: AsBytes + Sync + Send, + FieldElement: AsBytes + Sync + Send, + { + let evaluation_point_vec: Vec> = + core::iter::successors(Some(evaluation_point_inv.square()), |evaluation_point| { + Some(evaluation_point.square()) + }) + .take(layer_roots.len()) + .collect(); + + let p0_eval = deep_composition_evaluation; + let p0_eval_sym = deep_composition_evaluation_sym; + + let mut v = + (p0_eval + p0_eval_sym) + evaluation_point_inv * &zetas[0] * (p0_eval - p0_eval_sym); + let mut index = iota; + + if layer_roots.is_empty() { + return v == *last_value; + } + + layer_roots + .iter() + .enumerate() + .zip(&fri_decommitment.layers_auth_paths) + .zip(&fri_decommitment.layers_evaluations_sym) + .zip(evaluation_point_vec) + .fold( + true, + |result, + ( + (((i, merkle_root), auth_path_sym), evaluation_sym), + evaluation_point_inv, + )| { + let openings_ok = Self::verify_fri_layer_openings( + merkle_root, + auth_path_sym, + &v, + evaluation_sym, + index, + ); + v = (&v + evaluation_sym) + + evaluation_point_inv * &zetas[i + 1] * (&v - evaluation_sym); + index >>= 1; + if i < fri_decommitment.layers_evaluations_sym.len() - 1 { + result & openings_ok + } else { + result & (v == *last_value) & openings_ok + } + }, + ) + } } diff --git a/executor/src/tests/checkpoint_tests.rs b/executor/src/tests/checkpoint_tests.rs new file mode 100644 index 000000000..f9931dfbb --- /dev/null +++ b/executor/src/tests/checkpoint_tests.rs @@ -0,0 +1,80 @@ +//! Tests for executor checkpoints: snapshot the VM state mid-execution, +//! recreate an `Executor` from it, and resume — the concatenated logs must be +//! byte-identical to a straight run. This is the determinism property that +//! makes memory-eviction re-execution (Approach 1) sound. +//! +//! The program is built by hand (no ELF fixture) so the test is hermetic: a run +//! of 100_005 `ADDI x5, x5, 1` instructions followed by `JALR x0, 0(x0)` (jump +//! to address 0, which halts the VM). No syscalls are used, and the >100_000 +//! instruction count guarantees the snapshot is taken mid-execution (across a +//! `resume()` 100_000-instruction chunk boundary). + +use crate::elf::{Elf, Segment}; +use crate::vm::execution::Executor; + +const ADDI_X5_X5_1: u32 = 0x0012_8293; // addi x5, x5, 1 +const JALR_X0_0_X0: u32 = 0x0000_0067; // jalr x0, 0(x0) -> pc = 0 -> halt +const N_ADDI: usize = 100_005; +const BASE: u64 = 0x1000; + +fn long_program() -> Elf { + let mut values = vec![ADDI_X5_X5_1; N_ADDI]; + values.push(JALR_X0_0_X0); + Elf { + entry_point: BASE, + data: vec![Segment { + base_addr: BASE, + values, + is_executable: true, + }], + } +} + +#[test] +fn snapshot_resume_produces_identical_logs() { + let elf = long_program(); + + // Straight run. + let full = Executor::new(&elf, vec![]).unwrap().run().unwrap().logs; + assert_eq!(full.len(), N_ADDI + 1, "all instructions should log once"); + assert!(full.len() > 100_000, "must span multiple resume() chunks"); + + // Run one chunk (100_000 instructions), snapshot mid-execution, recreate, finish. + let mut exec = Executor::new(&elf, vec![]).unwrap(); + let mut logs = Vec::new(); + { + let chunk0 = exec.resume().unwrap().expect("at least one chunk"); + logs.extend_from_slice(chunk0); + } + assert!( + logs.len() < full.len(), + "snapshot must be taken before the program finishes (got {} of {})", + logs.len(), + full.len() + ); + + let snapshot = exec.snapshot(); + let mut resumed = Executor::from_snapshot(&elf, snapshot).expect("recreate from snapshot"); + while let Some(chunk) = resumed.resume().unwrap() { + logs.extend_from_slice(chunk); + } + + assert_eq!(logs.len(), full.len(), "log count mismatch after snapshot+resume"); + assert_eq!(logs, full, "snapshot+resume logs must equal the straight run"); +} + +#[test] +fn from_snapshot_at_start_equals_fresh_run() { + let elf = long_program(); + let full = Executor::new(&elf, vec![]).unwrap().run().unwrap().logs; + + // Snapshot before running anything, recreate, run: identical to fresh. + let snapshot = Executor::new(&elf, vec![]).unwrap().snapshot(); + let resumed_logs = Executor::from_snapshot(&elf, snapshot) + .unwrap() + .run() + .unwrap() + .logs; + + assert_eq!(resumed_logs, full); +} diff --git a/executor/src/tests/mod.rs b/executor/src/tests/mod.rs index 448a05dee..66c0f85ff 100644 --- a/executor/src/tests/mod.rs +++ b/executor/src/tests/mod.rs @@ -1,3 +1,4 @@ +pub mod checkpoint_tests; pub mod flamegraph_tests; pub mod keccak_tests; pub mod memory_tests; diff --git a/executor/src/vm/execution.rs b/executor/src/vm/execution.rs index 37dcf0198..8b2cbc507 100644 --- a/executor/src/vm/execution.rs +++ b/executor/src/vm/execution.rs @@ -30,6 +30,18 @@ pub struct ExecutionResult { /// Size of each log chunk - balances memory usage vs callback overhead const CHUNK_SIZE: usize = 100_000; +/// A snapshot of the mutable VM state at a cycle boundary: enough to recreate an +/// [`Executor`] (together with the program ELF) that resumes execution +/// byte-identically. The immutable instruction cache is rebuilt from the ELF, not +/// stored. Determinism holds because all nondeterministic input (private inputs) +/// is pre-loaded into `memory`, so replay reproduces identical logs. +#[derive(Clone)] +pub struct VmSnapshot { + memory: Memory, + registers: Registers, + pc: u64, +} + /// Executor state for chunked execution pub struct Executor { memory: Memory, @@ -55,6 +67,31 @@ impl Executor { }) } + /// Capture the current VM state (registers, pc, memory) as a [`VmSnapshot`]. + /// Cheap apart from the memory clone (a clone of the touched-cells map). + pub fn snapshot(&self) -> VmSnapshot { + VmSnapshot { + memory: self.memory.clone(), + registers: self.registers.clone(), + pc: self.pc, + } + } + + /// Recreate an `Executor` positioned at a previously captured [`VmSnapshot`]. + /// The instruction cache is rebuilt from `program` (the same ELF the snapshot + /// was taken under); the snapshot's memory already holds the loaded program + /// and execution state, so the program is NOT reloaded. + pub fn from_snapshot(program: &Elf, snapshot: VmSnapshot) -> Result { + let instructions = InstructionCache::new(&program.data)?; + Ok(Self { + memory: snapshot.memory, + registers: snapshot.registers, + pc: snapshot.pc, + instructions, + logs: Vec::with_capacity(CHUNK_SIZE), + }) + } + /// Resume execution and return next logs. Returns None when program is finished. pub fn resume(&mut self) -> Result, ExecutorError> { if self.pc == 0 { diff --git a/executor/src/vm/logs.rs b/executor/src/vm/logs.rs index de6b73d0b..a5aa426b0 100644 --- a/executor/src/vm/logs.rs +++ b/executor/src/vm/logs.rs @@ -11,7 +11,7 @@ /// - `src1_val` = syscall number (from x17): 64=Commit, 93=Halt, etc. /// - `src2_val` = buf_addr (x11) for Commit, 0 otherwise /// - `dst_val` = count (x12) for Commit, 0 otherwise -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Log { /// PC before instruction execution (use this to look up the instruction) pub current_pc: u64, diff --git a/executor/src/vm/memory.rs b/executor/src/vm/memory.rs index c94376e76..ebb10e5f0 100644 --- a/executor/src/vm/memory.rs +++ b/executor/src/vm/memory.rs @@ -50,7 +50,7 @@ pub const MAX_PRIVATE_INPUT_SIZE: u64 = 6700000; /// Must match `PRIVATE_INPUT_START` in `syscalls/src/syscalls.rs`. pub const PRIVATE_INPUT_START_INDEX: u64 = 0xFF000000; -#[derive(Default, Debug)] +#[derive(Default, Debug, Clone)] pub struct Memory { cells: U64HashMap<[u8; 4]>, /// Bytes committed to public output via `commit_public_output`. The diff --git a/executor/src/vm/registers.rs b/executor/src/vm/registers.rs index 61945b732..a82ef44f1 100644 --- a/executor/src/vm/registers.rs +++ b/executor/src/vm/registers.rs @@ -2,7 +2,7 @@ use std::fmt::Display; pub const STACK_TOP: u64 = 0xFFFFFFFFFFFFFFF0; // 64-bit max (Multiple of 16 for RV64 ABI) -#[derive(Debug)] +#[derive(Debug, Clone)] /// Holds the current value of all 32 registers /// Register zero is implicit as it cannot hold any value other than zero pub struct Registers([u64; 31]); diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 14f35cdf8..751a50a30 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -272,6 +272,29 @@ impl VmAirs { pairs } + /// Build `(air, trace, public_inputs)` triples by pairing the AIRs (in + /// [`Self::air_refs`] order) with an air-ordered trace slice. + /// + /// Used by streaming "retire-traces" mode, where the trace slice holds the + /// resident preprocessed/PAGE traces and empty placeholders for the retired + /// log-derived chunks (the prover sources those from a `TraceProvider`). The + /// AIR order MUST match the slice order produced by `route_for_streaming`. + pub fn air_trace_pairs_from_slice<'a>( + &'a self, + traces: &'a mut [stark::trace::TraceTable], + ) -> Vec> { + let airs = self.air_refs(); + debug_assert_eq!( + airs.len(), + traces.len(), + "AIR count must match the air-ordered trace slice length" + ); + airs.into_iter() + .zip(traces.iter_mut()) + .map(|(air, trace)| (air, trace, &())) + .collect() + } + /// Collect AIR references for [`Verifier::multi_verify`]. pub fn air_refs(&self) -> Vec<&dyn AIR> { let mut refs: Vec<&dyn AIR> = vec![ @@ -616,93 +639,188 @@ pub fn prove_with_options_and_inputs( auto_storage::decide(&lengths, proof_options.blowup_factor) }; - let mut traces = Traces::from_elf_and_logs( - &program, - &result.logs, - max_rows, - private_inputs, - #[cfg(feature = "disk-spill")] - storage_mode, - )?; - debug_assert_eq!( - traces.public_output_bytes, result.return_values.memory_values, - "public output diverged between executor view and trace reconstruction" - ); - drop(result); - - #[cfg(feature = "instruments")] - let trace_build_elapsed = phase_start.elapsed(); - #[cfg(feature = "instruments")] - let heap_after_trace = stark::instruments::heap_bytes(); - - // Phase 3: AIR construction - #[cfg(feature = "instruments")] - let phase_start = std::time::Instant::now(); + let stream_traces = stark::prover::streaming_retire_lde(); + + // Streaming "retire-traces" mode: route into a compact intermediate and + // build each log-derived table chunk on demand during proving, instead of + // materializing a full resident `Traces`. The proof is byte-identical to the + // non-streaming path (deterministic trace build, C.2a). + let proof; + let runtime_page_ranges; + let table_counts; + let public_output_bytes; + let num_private_input_pages; + if stream_traces { + let (provider, mut resident) = crate::tables::trace_builder::route_for_streaming( + &program, + &result.logs, + max_rows, + private_inputs, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + debug_assert_eq!( + resident.public_output_bytes, result.return_values.memory_values, + "public output diverged between executor view and trace reconstruction" + ); + drop(result); + + #[cfg(feature = "instruments")] + let trace_build_elapsed = phase_start.elapsed(); + #[cfg(feature = "instruments")] + let heap_after_trace = stark::instruments::heap_bytes(); + + // Phase 3: AIR construction + #[cfg(feature = "instruments")] + let phase_start = std::time::Instant::now(); + + table_counts = resident.table_counts.clone(); + let airs = VmAirs::new( + &program, + proof_options, + false, + &resident.page_configs, + &table_counts, + ); - let table_counts = traces.table_counts(); - let airs = VmAirs::new( - &program, - proof_options, - false, - &traces.page_configs, - &table_counts, - ); + #[cfg(feature = "instruments")] + let air_elapsed = phase_start.elapsed(); + #[cfg(feature = "instruments")] + let heap_after_air = stark::instruments::heap_bytes(); - #[cfg(feature = "instruments")] - let air_elapsed = phase_start.elapsed(); - #[cfg(feature = "instruments")] - let heap_after_air = stark::instruments::heap_bytes(); + runtime_page_ranges = resident.runtime_page_ranges(); + num_private_input_pages = resident + .page_configs + .iter() + .filter(|c| c.is_private_input) + .count(); + public_output_bytes = resident.public_output_bytes.clone(); + + let mut transcript = DefaultTranscript::::new(&[]); + absorb_statement( + &mut transcript, + elf_bytes, + &public_output_bytes, + &table_counts, + num_private_input_pages, + &runtime_page_ranges, + ); - let runtime_page_ranges = traces.runtime_page_ranges(); + // Phase 4: Prove (multi_prove with on-demand trace provider) + let pairs = airs.air_trace_pairs_from_slice(&mut resident.traces); + proof = Prover::multi_prove_with_provider( + pairs, + Some(&provider), + &mut transcript, + #[cfg(feature = "disk-spill")] + storage_mode, + ) + .map_err(|e| Error::Prover(format!("{e:?}")))?; + + #[cfg(feature = "instruments")] + { + instruments::print_report( + execute_elapsed, + trace_build_elapsed, + air_elapsed, + total_start.elapsed(), + &stark::instruments::ProveHeapProfile { + before: heap_before, + after_execute: heap_after_execute, + after_trace_build: heap_after_trace, + after_air: heap_after_air, + }, + ); + } + } else { + let mut traces = Traces::from_elf_and_logs( + &program, + &result.logs, + max_rows, + private_inputs, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + debug_assert_eq!( + traces.public_output_bytes, result.return_values.memory_values, + "public output diverged between executor view and trace reconstruction" + ); + drop(result); + + #[cfg(feature = "instruments")] + let trace_build_elapsed = phase_start.elapsed(); + #[cfg(feature = "instruments")] + let heap_after_trace = stark::instruments::heap_bytes(); + + // Phase 3: AIR construction + #[cfg(feature = "instruments")] + let phase_start = std::time::Instant::now(); + + table_counts = traces.table_counts(); + let airs = VmAirs::new( + &program, + proof_options, + false, + &traces.page_configs, + &table_counts, + ); - let num_private_input_pages = traces - .page_configs - .iter() - .filter(|c| c.is_private_input) - .count(); + #[cfg(feature = "instruments")] + let air_elapsed = phase_start.elapsed(); + #[cfg(feature = "instruments")] + let heap_after_air = stark::instruments::heap_bytes(); - // Bind the full statement (program, public output, table layout) into the - // Fiat-Shamir transcript so every challenge depends on it. - let mut transcript = DefaultTranscript::::new(&[]); - absorb_statement( - &mut transcript, - elf_bytes, - &traces.public_output_bytes, - &table_counts, - num_private_input_pages, - &runtime_page_ranges, - ); + runtime_page_ranges = traces.runtime_page_ranges(); + num_private_input_pages = traces + .page_configs + .iter() + .filter(|c| c.is_private_input) + .count(); + public_output_bytes = traces.public_output_bytes.clone(); + + // Bind the full statement (program, public output, table layout) into the + // Fiat-Shamir transcript so every challenge depends on it. + let mut transcript = DefaultTranscript::::new(&[]); + absorb_statement( + &mut transcript, + elf_bytes, + &public_output_bytes, + &table_counts, + num_private_input_pages, + &runtime_page_ranges, + ); - // Phase 4: Prove (multi_prove) - let proof = Prover::multi_prove( - airs.air_trace_pairs(&mut traces), - &mut transcript, - #[cfg(feature = "disk-spill")] - storage_mode, - ) - .map_err(|e| Error::Prover(format!("{e:?}")))?; + // Phase 4: Prove (multi_prove) + proof = Prover::multi_prove( + airs.air_trace_pairs(&mut traces), + &mut transcript, + #[cfg(feature = "disk-spill")] + storage_mode, + ) + .map_err(|e| Error::Prover(format!("{e:?}")))?; - #[cfg(feature = "instruments")] - { - instruments::print_report( - execute_elapsed, - trace_build_elapsed, - air_elapsed, - total_start.elapsed(), - &stark::instruments::ProveHeapProfile { - before: heap_before, - after_execute: heap_after_execute, - after_trace_build: heap_after_trace, - after_air: heap_after_air, - }, - ); + #[cfg(feature = "instruments")] + { + instruments::print_report( + execute_elapsed, + trace_build_elapsed, + air_elapsed, + total_start.elapsed(), + &stark::instruments::ProveHeapProfile { + before: heap_before, + after_execute: heap_after_execute, + after_trace_build: heap_after_trace, + after_air: heap_after_air, + }, + ); + } } Ok(VmProof { proof, runtime_page_ranges, table_counts, - public_output: traces.public_output_bytes.clone(), + public_output: public_output_bytes, num_private_input_pages, }) } diff --git a/prover/src/tables/branch.rs b/prover/src/tables/branch.rs index 1a4cff20c..c02d9b4a4 100644 --- a/prover/src/tables/branch.rs +++ b/prover/src/tables/branch.rs @@ -107,7 +107,7 @@ const MASK_254: u64 = 254; /// A single BRANCH operation to be added to the trace. /// /// Derives Hash and Eq so it can be used as a HashMap key for deduplication. -#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct BranchOperation { /// Current program counter (64-bit) pub pc: u64, @@ -166,7 +166,12 @@ pub fn generate_branch_trace( *op_map.entry(op.clone()).or_insert(0) += 1; } - let unique_ops: Vec<_> = op_map.into_iter().collect(); + let mut unique_ops: Vec<_> = op_map.into_iter().collect(); + // Deterministic row order: HashMap iteration order is randomized per + // instance, so sort by the canonical operation key. This makes repeated + // builds (e.g. streaming on-demand trace rebuild) produce byte-identical + // traces. + unique_ops.sort_unstable_by(|a, b| a.0.cmp(&b.0)); let num_rows = unique_ops.len().next_power_of_two().max(4); let mut data = vec![FE::zero(); num_rows * cols::NUM_COLUMNS]; diff --git a/prover/src/tables/dvrm.rs b/prover/src/tables/dvrm.rs index 30352e125..b0da0b719 100644 --- a/prover/src/tables/dvrm.rs +++ b/prover/src/tables/dvrm.rs @@ -156,7 +156,7 @@ const SIGN_FILL: u64 = 0xFFFF; /// A single DVRM operation to be added to the trace. /// /// Derives Hash and Eq for HashMap-based deduplication. -#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct DvrmOperation { /// Numerator (64-bit) pub n: u64, @@ -299,7 +299,12 @@ pub fn generate_dvrm_trace( } } - let unique_ops: Vec<_> = op_map.into_iter().collect(); + let mut unique_ops: Vec<_> = op_map.into_iter().collect(); + // Deterministic row order: HashMap iteration order is randomized per + // instance, so sort by the canonical operation key. This makes repeated + // builds (e.g. streaming on-demand trace rebuild) produce byte-identical + // traces. + unique_ops.sort_unstable_by(|a, b| a.0.cmp(&b.0)); let num_rows = unique_ops.len().next_power_of_two().max(4); let mut data = vec![FE::zero(); num_rows * cols::NUM_COLUMNS]; diff --git a/prover/src/tables/lt.rs b/prover/src/tables/lt.rs index da1bc948e..23ae54713 100644 --- a/prover/src/tables/lt.rs +++ b/prover/src/tables/lt.rs @@ -95,7 +95,7 @@ pub mod cols { /// A single LT operation to be added to the trace. /// /// Derives Hash and Eq so it can be used as a HashMap key for deduplication. -#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct LtOperation { /// Left operand (64-bit value) pub lhs: u64, @@ -136,7 +136,12 @@ pub fn generate_lt_trace( *op_map.entry(op.clone()).or_insert(0) += 1; } - let unique_ops: Vec<_> = op_map.into_iter().collect(); + let mut unique_ops: Vec<_> = op_map.into_iter().collect(); + // Deterministic row order: HashMap iteration order is randomized per + // instance, so sort by the canonical operation key. This makes repeated + // builds (e.g. streaming on-demand trace rebuild) produce byte-identical + // traces. + unique_ops.sort_unstable_by(|a, b| a.0.cmp(&b.0)); let num_rows = unique_ops.len().next_power_of_two().max(4); let mut data = vec![FE::zero(); num_rows * cols::NUM_COLUMNS]; diff --git a/prover/src/tables/mul.rs b/prover/src/tables/mul.rs index ecb72a4d1..89882f063 100644 --- a/prover/src/tables/mul.rs +++ b/prover/src/tables/mul.rs @@ -136,7 +136,7 @@ const SIGN_FILL: u64 = 0xFFFF; /// A single MUL operation to be added to the trace. /// /// Derives Hash and Eq for HashMap-based deduplication. -#[derive(Debug, Clone, Hash, PartialEq, Eq)] +#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct MulOperation { /// Left operand (64-bit) pub lhs: u64, @@ -295,7 +295,12 @@ pub fn generate_mul_trace( } } - let unique_ops: Vec<_> = op_map.into_iter().collect(); + let mut unique_ops: Vec<_> = op_map.into_iter().collect(); + // Deterministic row order: HashMap iteration order is randomized per + // instance, so sort by the canonical operation key. This makes repeated + // builds (e.g. streaming on-demand trace rebuild) produce byte-identical + // traces. + unique_ops.sort_unstable_by(|a, b| a.0.cmp(&b.0)); let num_rows = unique_ops.len().next_power_of_two().max(4); let mut data = vec![FE::zero(); num_rows * cols::NUM_COLUMNS]; diff --git a/prover/src/tables/trace_builder.rs b/prover/src/tables/trace_builder.rs index 76535484b..400e5fd12 100644 --- a/prover/src/tables/trace_builder.rs +++ b/prover/src/tables/trace_builder.rs @@ -2052,6 +2052,22 @@ pub struct Traces { pub memw_registers: Vec>, } +/// TEST ONLY: per-table chunk Vecs produced by [`Traces::chunked_tables_via_route`], +/// in the same field order as the log-derived tables of [`Traces`]. +#[cfg(test)] +pub(crate) struct ChunkedTablesViaRoute { + pub cpus: Vec>, + pub lts: Vec>, + pub shifts: Vec>, + pub memws: Vec>, + pub memw_aligneds: Vec>, + pub loads: Vec>, + pub muls: Vec>, + pub dvrms: Vec>, + pub branches: Vec>, + pub memw_registers: Vec>, +} + /// Intermediate state from Phase 2: all ops collected from CPU, ready for /// Phases 3-5 (LT extension, bitwise, trace generation). struct CollectedOps { @@ -2070,9 +2086,381 @@ struct CollectedOps { keccak_ops: Vec, } +/// Fully-routed intermediate produced by [`route`] (PHASES 1-4). Holds every +/// piece of state that the PHASE-5 per-table fill consumes, so a single table's +/// trace can be built on demand via [`build_table`] without re-running routing. +/// +/// All cross-table coupling (op routing in PHASES 1-2, MEMW->LT extension in +/// PHASE 3, and the accumulated bitwise lookup multiplicities in PHASE 4) is +/// fully resolved before this struct is returned. In particular `lt_ops` and +/// `bitwise_ops` here are the FINAL lists: every table's contribution to the +/// shared bitwise multiplicities has already been folded in, so any per-table +/// fill that reads them (the BITWISE table itself) sees the complete counts. +/// +/// `elf`, `memory_state` and `private_input` are retained only so the PAGE +/// table fill can be performed per-table identically to the monolithic build. +struct RoutedTraceData<'a> { + // --- routed op-lists (PHASE 5 inputs, all extensions applied) --- + cpu_ops: Vec, + memw_ops: Vec, + memw_aligned_ops: Vec, + memw_register_ops: Vec, + load_ops: Vec, + lt_ops: Vec, + shift_ops: Vec, + /// Final accumulated bitwise lookups from every table (incl. padding). + bitwise_ops: Vec, + branch_ops: Vec, + mul_ops: Vec<(MulOperation, bool)>, + dvrm_ops: Vec<(DvrmOperation, bool)>, + commit_ops: Vec, + keccak_ops: Vec, + + // --- derived scalars / preprocessed inputs PHASE 5 consumes --- + /// Total CPU padding rows across all CPU chunks (drives padding lookups). + num_padding_rows: usize, + /// Timestamp of the final ECALL (HALT) instruction. + halt_timestamp: u64, + /// Committed public output bytes recovered during routing. + public_output_bytes: Vec, + /// DECODE trace (program-derived) and its pc->row index. + decode_trace: TraceTable, + decode_pc_to_row: HashMap, + /// Finalized register state map for the REGISTER table. + register_final_state: FinalRegisterStateMap, + entry_point: u64, + + // --- borrowed inputs needed only by the PAGE table fill --- + elf: Option<&'a Elf>, + memory_state: MemoryState, + private_input: &'a [u8], +} + +/// Selector for the log-derived, chunked execution tables built in PHASE 5. +/// +/// Ordering matches the table ordering used by the prover (see `prover/src/lib.rs`): +/// CPU, LT, SHIFT, MEMW, MEMW_A, LOAD, MUL, DVRM, BRANCH, MEMW_R. +/// +/// Program-independent / preprocessed tables (BITWISE 2^20, DECODE, PAGE, +/// REGISTER, HALT, COMMIT, KECCAK*) are intentionally NOT in this selector; +/// they are assembled directly in [`Traces::from_elf_and_logs`] as before. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TableKind { + Cpu, + Lt, + Shift, + Memw, + MemwAligned, + Load, + Mul, + Dvrm, + Branch, + MemwRegister, +} + +impl RoutedTraceData<'_> { + /// Build the chunked trace tables for a SINGLE log-derived execution table + /// (PHASE-5 fill for one `chunk_and_generate` call). The returned `Vec` + /// contains one `TraceTable` per `max_rows`-sized chunk, byte-identical to + /// what the monolithic [`build_traces`] produces for that table. + fn build_table( + &self, + which: TableKind, + max_rows: &super::MaxRowsConfig, + #[cfg(feature = "disk-spill")] storage_mode: StorageMode, + ) -> Result>, Error> { + match which { + TableKind::Cpu => chunk_and_generate( + &self.cpu_ops, + max_rows.cpu, + cpu::generate_cpu_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + ), + TableKind::Lt => chunk_and_generate( + &self.lt_ops, + max_rows.lt, + lt::generate_lt_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + ), + TableKind::Shift => chunk_and_generate( + &self.shift_ops, + max_rows.shift, + shift::generate_shift_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + ), + TableKind::Memw => chunk_and_generate( + &self.memw_ops, + max_rows.memw, + memw::generate_memw_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + ), + TableKind::MemwAligned => chunk_and_generate( + &self.memw_aligned_ops, + max_rows.memw_aligned, + memw_aligned::generate_memw_aligned_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + ), + TableKind::Load => chunk_and_generate( + &self.load_ops, + max_rows.load, + load::generate_load_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + ), + TableKind::Mul => chunk_and_generate( + &self.mul_ops, + max_rows.mul, + mul::generate_mul_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + ), + TableKind::Dvrm => chunk_and_generate( + &self.dvrm_ops, + max_rows.dvrm, + dvrm::generate_dvrm_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + ), + TableKind::Branch => chunk_and_generate( + &self.branch_ops, + max_rows.branch, + branch::generate_branch_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + ), + TableKind::MemwRegister => chunk_and_generate( + &self.memw_register_ops, + max_rows.memw_register, + memw_register::generate_memw_register_trace, + #[cfg(feature = "disk-spill")] + storage_mode, + ), + } + } + + /// Per-`TableKind` `max_rows` chunk size. + fn max_rows_for(which: TableKind, max_rows: &super::MaxRowsConfig) -> usize { + match which { + TableKind::Cpu => max_rows.cpu, + TableKind::Lt => max_rows.lt, + TableKind::Shift => max_rows.shift, + TableKind::Memw => max_rows.memw, + TableKind::MemwAligned => max_rows.memw_aligned, + TableKind::Load => max_rows.load, + TableKind::Mul => max_rows.mul, + TableKind::Dvrm => max_rows.dvrm, + TableKind::Branch => max_rows.branch, + TableKind::MemwRegister => max_rows.memw_register, + } + } + + /// Number of ops routed to this `TableKind` (drives the chunk count). + fn ops_len(&self, which: TableKind) -> usize { + match which { + TableKind::Cpu => self.cpu_ops.len(), + TableKind::Lt => self.lt_ops.len(), + TableKind::Shift => self.shift_ops.len(), + TableKind::Memw => self.memw_ops.len(), + TableKind::MemwAligned => self.memw_aligned_ops.len(), + TableKind::Load => self.load_ops.len(), + TableKind::Mul => self.mul_ops.len(), + TableKind::Dvrm => self.dvrm_ops.len(), + TableKind::Branch => self.branch_ops.len(), + TableKind::MemwRegister => self.memw_register_ops.len(), + } + } + + /// Number of `max_rows`-sized chunks this `TableKind` splits into. Matches + /// [`build_table`]'s chunking: an empty op-list still yields one chunk. + fn num_chunks(&self, which: TableKind, max_rows: &super::MaxRowsConfig) -> usize { + let len = self.ops_len(which); + if len == 0 { + 1 + } else { + len.div_ceil(Self::max_rows_for(which, max_rows)) + } + } + + /// Build the trace table for a SINGLE chunk of one log-derived table. + /// Byte-identical to `build_table(which, ..)[chunk_idx]`, but rebuilds only + /// the requested chunk so retired-trace mode never materializes a whole + /// table at once. + fn build_chunk( + &self, + which: TableKind, + chunk_idx: usize, + max_rows: &super::MaxRowsConfig, + #[cfg(feature = "disk-spill")] storage_mode: StorageMode, + ) -> Result, Error> { + let mr = Self::max_rows_for(which, max_rows); + match which { + TableKind::Cpu => chunk_one( + &self.cpu_ops, mr, chunk_idx, cpu::generate_cpu_trace, + #[cfg(feature = "disk-spill")] storage_mode, + ), + TableKind::Lt => chunk_one( + &self.lt_ops, mr, chunk_idx, lt::generate_lt_trace, + #[cfg(feature = "disk-spill")] storage_mode, + ), + TableKind::Shift => chunk_one( + &self.shift_ops, mr, chunk_idx, shift::generate_shift_trace, + #[cfg(feature = "disk-spill")] storage_mode, + ), + TableKind::Memw => chunk_one( + &self.memw_ops, mr, chunk_idx, memw::generate_memw_trace, + #[cfg(feature = "disk-spill")] storage_mode, + ), + TableKind::MemwAligned => chunk_one( + &self.memw_aligned_ops, mr, chunk_idx, + memw_aligned::generate_memw_aligned_trace, + #[cfg(feature = "disk-spill")] storage_mode, + ), + TableKind::Load => chunk_one( + &self.load_ops, mr, chunk_idx, load::generate_load_trace, + #[cfg(feature = "disk-spill")] storage_mode, + ), + TableKind::Mul => chunk_one( + &self.mul_ops, mr, chunk_idx, mul::generate_mul_trace, + #[cfg(feature = "disk-spill")] storage_mode, + ), + TableKind::Dvrm => chunk_one( + &self.dvrm_ops, mr, chunk_idx, dvrm::generate_dvrm_trace, + #[cfg(feature = "disk-spill")] storage_mode, + ), + TableKind::Branch => chunk_one( + &self.branch_ops, mr, chunk_idx, branch::generate_branch_trace, + #[cfg(feature = "disk-spill")] storage_mode, + ), + TableKind::MemwRegister => chunk_one( + &self.memw_register_ops, mr, chunk_idx, + memw_register::generate_memw_register_trace, + #[cfg(feature = "disk-spill")] storage_mode, + ), + } + } +} + +/// PHASES 3-4: from collected/routed ops, finish all cross-table coupling +/// (MEMW->LT extension, accumulated bitwise lookups, padding lookups) and the +/// remaining derived scalars, producing a [`RoutedTraceData`] from which any +/// single table can be filled on demand. +/// +/// This is the routing half of the old [`build_traces`]; the PHASE-5 fill is +/// done per-table by [`RoutedTraceData::build_table`] (and the fixed/ +/// preprocessed tables are assembled by the caller). +#[allow(clippy::too_many_arguments)] +fn route<'a>( + ops: CollectedOps, + elf: Option<&'a Elf>, + memory_state: MemoryState, + entry_point: u64, + decode_trace: TraceTable, + decode_pc_to_row: HashMap, + register_state: RegisterState, + max_rows: &super::MaxRowsConfig, + private_input: &'a [u8], +) -> Result, Error> { + let CollectedOps { + cpu_ops, + memw_ops, + memw_aligned_ops, + memw_register_ops, + load_ops, + mut lt_ops, + shift_ops, + mut bitwise_ops, + branch_ops, + mul_ops, + dvrm_ops, + commit_ops, + keccak_ops, + } = ops; + + // ===================================================================== + // PHASE 3: MEMW -> LT (timestamp ordering and overflow checks) + // ===================================================================== + lt_ops.extend(collect_lt_from_memw(&memw_ops)); + lt_ops.extend(collect_lt_from_memw_aligned(&memw_aligned_ops)); + + // ===================================================================== + // PHASE 4: All -> Bitwise lookups + // ===================================================================== + bitwise_ops.extend(collect_bitwise_from_lt(<_ops)); + bitwise_ops.extend(collect_bitwise_from_mul(&mul_ops)); + bitwise_ops.extend(collect_bitwise_from_dvrm(&dvrm_ops)); + bitwise_ops.extend(collect_bitwise_from_branch(&branch_ops)); + bitwise_ops.extend(shift::collect_bitwise_from_shift(&shift_ops)); + bitwise_ops.extend(collect_bitwise_from_memw_aligned(&memw_aligned_ops)); + // MEMW_R sends IS_HALFWORD[timestamp_0 - old_timestamp_lo - 1] + bitwise_ops.extend(collect_bitwise_from_memw_register(&memw_register_ops)); + // PAGE tables do a batched ARE_BYTES[init, fini] lookup per row (C1+C2) + if let Some(elf) = elf { + bitwise_ops.extend(collect_bitwise_from_page(elf, &memory_state, private_input)); + } + + let public_output_bytes: Vec = commit_ops + .iter() + .filter(|op| !op.end) + .map(|op| op.value) + .collect(); + // COMMIT table sends AreBytes and IsHalfword lookups + bitwise_ops.extend(collect_bitwise_from_commit(&commit_ops)); + // KECCAK_RND sends XOR/AND/ARE_BYTES/HWSL; KECCAK core sends IS_HALF + bitwise_ops.extend(collect_bitwise_from_keccak(&keccak_ops)); + + // CPU padding rows send ARE_BYTES with all-zero values. + // Add corresponding ops so the bitwise table multiplicities balance. + let num_padding_rows: usize = cpu_ops + .chunks(max_rows.cpu) + .map(|chunk| chunk.len().next_power_of_two().max(4) - chunk.len()) + .sum(); + bitwise_ops.extend(collect_byte_check_ops_for_padding(num_padding_rows)); + + // Extract halt timestamp from the last ECALL instruction. + let halt_timestamp = cpu_ops + .iter() + .rev() + .find(|op| op.decode.op_ecall) + .ok_or(Error::MissingHaltEcall)? + .timestamp; + + let register_final_state = register_state.to_final_state_map(); + + Ok(RoutedTraceData { + cpu_ops, + memw_ops, + memw_aligned_ops, + memw_register_ops, + load_ops, + lt_ops, + shift_ops, + bitwise_ops, + branch_ops, + mul_ops, + dvrm_ops, + commit_ops, + keccak_ops, + num_padding_rows, + halt_timestamp, + public_output_bytes, + decode_trace, + decode_pc_to_row, + register_final_state, + entry_point, + elf, + memory_state, + private_input, + }) +} + /// Chunk raw ops and generate one trace table per chunk. When `storage_mode` /// is `Disk`, each chunk's main table is spilled to mmap before the next chunk -/// is built so peak heap usage stays bounded. +/// is built so peak heap usage stays bounded fn chunk_and_generate( ops: &[T], max_rows: usize, @@ -2099,6 +2487,35 @@ fn chunk_and_generate( Ok(tables) } +/// Build the trace table for a single `chunk_idx` of an op-list, matching the +/// chunking and padding `chunk_and_generate` performs. Used by retired-trace +/// mode to rebuild exactly one chunk on demand. +fn chunk_one( + ops: &[T], + max_rows: usize, + chunk_idx: usize, + generate: impl Fn(&[T]) -> TraceTable, + #[cfg(feature = "disk-spill")] storage_mode: StorageMode, +) -> Result, Error> { + let chunk: &[T] = if ops.is_empty() { + debug_assert_eq!(chunk_idx, 0, "empty table has exactly one chunk"); + &[] + } else { + let start = chunk_idx * max_rows; + let end = (start + max_rows).min(ops.len()); + &ops[start..end] + }; + #[allow(unused_mut)] + let mut t = generate(chunk); + #[cfg(feature = "disk-spill")] + if storage_mode == StorageMode::Disk { + t.main_table + .spill_to_disk() + .map_err(|e| Error::Prover(format!("disk-spill trace: {e}")))?; + } + Ok(t) +} + /// Phase 2: Collect and route all operations from CPU ops. /// /// Takes the raw output of `collect_ops_from_cpu` plus `register_state` @@ -2210,11 +2627,18 @@ fn collect_all_ops( /// /// `elf` controls PAGE table generation: `Some(elf)` generates real PAGE tables /// and PAGE bitwise lookups; `None` produces empty page tables. +/// +/// This is now a thin orchestrator over [`route`] (PHASES 3-4) and +/// [`RoutedTraceData::build_table`] (the PHASE-5 fill of each log-derived, +/// chunked execution table). The fixed-size / preprocessed tables (BITWISE, +/// DECODE, REGISTER, HALT, COMMIT, KECCAK*, PAGE) are assembled here directly +/// from the routed data, exactly as before. The produced `Traces` is +/// byte-identical to the old monolithic build. #[allow(clippy::too_many_arguments)] fn build_traces( ops: CollectedOps, elf: Option<&Elf>, - memory_state: &MemoryState, + memory_state: MemoryState, entry_point: u64, decode_trace: TraceTable, decode_pc_to_row: HashMap, @@ -2223,167 +2647,162 @@ fn build_traces( #[cfg(feature = "disk-spill")] storage_mode: StorageMode, private_input: &[u8], ) -> Result { - let CollectedOps { - cpu_ops, - memw_ops, - memw_aligned_ops, - memw_register_ops, - load_ops, - mut lt_ops, - shift_ops, - mut bitwise_ops, - branch_ops, - mul_ops, - dvrm_ops, - commit_ops, - keccak_ops, - } = ops; - - // ===================================================================== - // PHASE 3: MEMW → LT (timestamp ordering and overflow checks) - // ===================================================================== - lt_ops.extend(collect_lt_from_memw(&memw_ops)); - lt_ops.extend(collect_lt_from_memw_aligned(&memw_aligned_ops)); - - // ===================================================================== - // PHASE 4: All → Bitwise lookups - // ===================================================================== - bitwise_ops.extend(collect_bitwise_from_lt(<_ops)); - bitwise_ops.extend(collect_bitwise_from_mul(&mul_ops)); - bitwise_ops.extend(collect_bitwise_from_dvrm(&dvrm_ops)); - bitwise_ops.extend(collect_bitwise_from_branch(&branch_ops)); - bitwise_ops.extend(shift::collect_bitwise_from_shift(&shift_ops)); - bitwise_ops.extend(collect_bitwise_from_memw_aligned(&memw_aligned_ops)); - // MEMW_R sends IS_HALFWORD[timestamp_0 - old_timestamp_lo - 1] - bitwise_ops.extend(collect_bitwise_from_memw_register(&memw_register_ops)); - // PAGE tables do a batched ARE_BYTES[init, fini] lookup per row (C1+C2) - if let Some(elf) = elf { - bitwise_ops.extend(collect_bitwise_from_page(elf, memory_state, private_input)); - } - - let public_output_bytes: Vec = commit_ops - .iter() - .filter(|op| !op.end) - .map(|op| op.value) - .collect(); - // COMMIT table sends AreBytes and IsHalfword lookups - bitwise_ops.extend(collect_bitwise_from_commit(&commit_ops)); - // KECCAK_RND sends XOR/AND/ARE_BYTES/HWSL; KECCAK core sends IS_HALF - bitwise_ops.extend(collect_bitwise_from_keccak(&keccak_ops)); - - // CPU padding rows send ARE_BYTES with all-zero values. - // Add corresponding ops so the bitwise table multiplicities balance. - let num_padding_rows: usize = cpu_ops - .chunks(max_rows.cpu) - .map(|chunk| chunk.len().next_power_of_two().max(4) - chunk.len()) - .sum(); - bitwise_ops.extend(collect_byte_check_ops_for_padding(num_padding_rows)); + // PHASES 3-4: finish all cross-table coupling and derived scalars. + let routed = route( + ops, + elf, + memory_state, + entry_point, + decode_trace, + decode_pc_to_row, + register_state, + max_rows, + private_input, + )?; // ===================================================================== - // PHASE 5: Generate final traces (parallelized) + // PHASE 5: Generate final traces. // ===================================================================== - // Extract halt timestamp from the last ECALL instruction - let halt_op = cpu_ops - .iter() - .rev() - .find(|op| op.decode.op_ecall) - .ok_or(Error::MissingHaltEcall)?; - let halt_timestamp = halt_op.timestamp; - - let cpus = chunk_and_generate( - &cpu_ops, - max_rows.cpu, - cpu::generate_cpu_trace, + // Per-table fill of the log-derived, chunked execution tables. + let cpus = routed.build_table( + TableKind::Cpu, + max_rows, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + let memws = routed.build_table( + TableKind::Memw, + max_rows, #[cfg(feature = "disk-spill")] storage_mode, )?; - let memws = chunk_and_generate( - &memw_ops, - max_rows.memw, - memw::generate_memw_trace, + let memw_aligneds = routed.build_table( + TableKind::MemwAligned, + max_rows, #[cfg(feature = "disk-spill")] storage_mode, )?; - let memw_aligneds = chunk_and_generate( - &memw_aligned_ops, - max_rows.memw_aligned, - memw_aligned::generate_memw_aligned_trace, + let memw_registers = routed.build_table( + TableKind::MemwRegister, + max_rows, #[cfg(feature = "disk-spill")] storage_mode, )?; - let memw_registers = chunk_and_generate( - &memw_register_ops, - max_rows.memw_register, - memw_register::generate_memw_register_trace, + let loads = routed.build_table( + TableKind::Load, + max_rows, #[cfg(feature = "disk-spill")] storage_mode, )?; - let loads = chunk_and_generate( - &load_ops, - max_rows.load, - load::generate_load_trace, + let lts = routed.build_table( + TableKind::Lt, + max_rows, #[cfg(feature = "disk-spill")] storage_mode, )?; - let lts = chunk_and_generate( - <_ops, - max_rows.lt, - lt::generate_lt_trace, + let shifts = routed.build_table( + TableKind::Shift, + max_rows, #[cfg(feature = "disk-spill")] storage_mode, )?; - let shifts = chunk_and_generate( - &shift_ops, - max_rows.shift, - shift::generate_shift_trace, + let muls = routed.build_table( + TableKind::Mul, + max_rows, #[cfg(feature = "disk-spill")] storage_mode, )?; - let muls = chunk_and_generate( - &mul_ops, - max_rows.mul, - mul::generate_mul_trace, + let dvrms = routed.build_table( + TableKind::Dvrm, + max_rows, #[cfg(feature = "disk-spill")] storage_mode, )?; - let dvrms = chunk_and_generate( - &dvrm_ops, - max_rows.dvrm, - dvrm::generate_dvrm_trace, + let branches = routed.build_table( + TableKind::Branch, + max_rows, #[cfg(feature = "disk-spill")] storage_mode, )?; - let branches = chunk_and_generate( - &branch_ops, - max_rows.branch, - branch::generate_branch_trace, + + // Fixed-size / preprocessed tables assembled directly from routed data. + let pre = build_preprocessed_tables( + &routed, #[cfg(feature = "disk-spill")] storage_mode, )?; + Ok(Traces { + cpus, + bitwise: pre.bitwise, + lts, + shifts, + memws, + memw_aligneds, + loads, + decode: pre.decode, + muls, + dvrms, + pages: pre.pages, + page_configs: pre.page_configs, + register: pre.register, + public_output_bytes: routed.public_output_bytes.clone(), + branches, + halt: pre.halt, + commit: pre.commit, + keccak: pre.keccak, + keccak_rnd: pre.keccak_rnd, + keccak_rc: pre.keccak_rc, + memw_registers, + }) +} + +/// Preprocessed / fixed-size / per-page tables assembled from a +/// [`RoutedTraceData`]. These are program-independent or derived only from the +/// (compact) routed intermediate, so they stay resident even in retired-trace +/// mode (only the chunked log-derived tables are retired). +struct PreprocessedTables { + bitwise: TraceTable, + decode: TraceTable, + commit: TraceTable, + keccak: TraceTable, + keccak_rnd: TraceTable, + keccak_rc: TraceTable, + pages: Vec>, + page_configs: Vec, + register: TraceTable, + halt: TraceTable, +} + +/// Build all preprocessed / fixed-size / per-page tables from routed data. +/// Byte-identical to the inline assembly the monolithic [`build_traces`] used. +fn build_preprocessed_tables( + routed: &RoutedTraceData<'_>, + #[cfg(feature = "disk-spill")] storage_mode: StorageMode, +) -> Result { let mut bitwise = bitwise::generate_bitwise_trace(); - bitwise::update_multiplicities(&mut bitwise, &bitwise_ops); + bitwise::update_multiplicities(&mut bitwise, &routed.bitwise_ops); // Update DECODE multiplicities // Each CPU operation looks up the DECODE table once // Padding rows also look up pc=1 (the CPU padding entry) // When CPU is split, each chunk pads independently - let mut decode = decode_trace; - let mut decode_lookups: Vec = cpu_ops.iter().map(|op| op.decode.pc).collect(); - decode_lookups.extend(std::iter::repeat_n(cpu::CPU_PADDING_PC, num_padding_rows)); - decode::update_multiplicities(&mut decode, &decode_pc_to_row, &decode_lookups); - - // Prepare register final state before scope (needs register_state ownership) - let register_final_state = register_state.to_final_state_map(); - - // Generate remaining traces in parallel (page, register, halt, commit). - // chunk_and_generate already handled cpu, lt, memw, load, mul, dvrm, branch above. + let mut decode = routed.decode_trace.clone(); + let mut decode_lookups: Vec = routed.cpu_ops.iter().map(|op| op.decode.pc).collect(); + decode_lookups.extend(std::iter::repeat_n( + cpu::CPU_PADDING_PC, + routed.num_padding_rows, + )); + decode::update_multiplicities(&mut decode, &routed.decode_pc_to_row, &decode_lookups); + + // Generate remaining traces (page, register, halt, commit). #[allow(unused_mut)] - let mut commit_trace = commit::generate_commit_trace(&commit_ops); + let mut commit_trace = commit::generate_commit_trace(&routed.commit_ops); // Generate keccak traces (core table + per-round table + preprocessed RC) - let keccak_rnd_ops: Vec = keccak_ops + let keccak_rnd_ops: Vec = routed + .keccak_ops .iter() .map(|op| KeccakRoundOperation { timestamp: op.timestamp, @@ -2391,10 +2810,10 @@ fn build_traces( output: op.output, }) .collect(); - let keccak_trace = keccak::generate_keccak_trace(&keccak_ops); + let keccak_trace = keccak::generate_keccak_trace(&routed.keccak_ops); let keccak_rnd_trace = keccak_rnd::generate_keccak_rnd_trace(&keccak_rnd_ops); let mut keccak_rc_trace = keccak_rc::generate_keccak_rc_trace(); - keccak_rc::update_multiplicities(&mut keccak_rc_trace, keccak_ops.len()); + keccak_rc::update_multiplicities(&mut keccak_rc_trace, routed.keccak_ops.len()); #[allow(unused_mut)] let (mut pages, page_configs, mut register_trace, mut halt_trace); @@ -2403,14 +2822,21 @@ fn build_traces( let ((pages_val, register_val), halt_val) = rayon::join( || { rayon::join( - || match elf { - Some(elf) => generate_page_tables(elf, memory_state, private_input), + || match routed.elf { + Some(elf) => { + generate_page_tables(elf, &routed.memory_state, routed.private_input) + } None => (Vec::new(), Vec::new()), }, - || register::generate_register_trace(®ister_final_state, entry_point), + || { + register::generate_register_trace( + &routed.register_final_state, + routed.entry_point, + ) + }, ) }, - || halt::generate_halt_trace(halt_timestamp), + || halt::generate_halt_trace(routed.halt_timestamp), ); let (pages_v, page_configs_v) = pages_val; pages = pages_v; @@ -2420,9 +2846,9 @@ fn build_traces( } #[cfg(not(feature = "parallel"))] { - match elf { + match routed.elf { Some(elf) => { - let (p, c) = generate_page_tables(elf, memory_state, private_input); + let (p, c) = generate_page_tables(elf, &routed.memory_state, routed.private_input); pages = p; page_configs = c; } @@ -2431,8 +2857,9 @@ fn build_traces( page_configs = Vec::new(); } } - register_trace = register::generate_register_trace(®ister_final_state, entry_point); - halt_trace = halt::generate_halt_trace(halt_timestamp); + register_trace = + register::generate_register_trace(&routed.register_final_state, routed.entry_point); + halt_trace = halt::generate_halt_trace(routed.halt_timestamp); } // Fixed-size and per-page tables aren't built through `chunk_and_generate`, @@ -2466,31 +2893,273 @@ fn build_traces( } } - Ok(Traces { - cpus, + Ok(PreprocessedTables { bitwise, - lts, - shifts, - memws, - memw_aligneds, - loads, decode, - muls, - dvrms, - pages, - page_configs, - register: register_trace, - public_output_bytes, - branches, - halt: halt_trace, commit: commit_trace, keccak: keccak_trace, keccak_rnd: keccak_rnd_trace, keccak_rc: keccak_rc_trace, - memw_registers, + pages, + page_configs, + register: register_trace, + halt: halt_trace, }) } +/// Run-length encode the non-ELF (zero-init) page bases of `configs` into +/// `(base, count)` runtime page ranges. Shared by [`Traces`] and +/// [`StreamingTraces`]. +fn runtime_page_ranges_from_configs(configs: &[PageConfig]) -> Vec { + let page_size = page::DEFAULT_PAGE_SIZE as u64; + + let runtime_bases: Vec = configs + .iter() + .filter(|config| config.init_values.is_none()) + .map(|config| config.page_base) + .collect(); + + let mut ranges = Vec::new(); + if runtime_bases.is_empty() { + return ranges; + } + + let mut start = runtime_bases[0]; + let mut count = 1u64; + + for &base in &runtime_bases[1..] { + if base == start + count * page_size { + count += 1; + } else { + ranges.push(crate::RuntimePageRange { base: start, count }); + start = base; + count = 1; + } + } + ranges.push(crate::RuntimePageRange { base: start, count }); + + ranges +} + +/// One slot in the prover's air ordering: a resident (preprocessed/PAGE) table +/// or a retired log-derived chunk built on demand from the routed intermediate. +#[derive(Clone, Copy)] +enum StreamingSlot { + /// Resident table — built once, lives in the air-ordered resident-trace Vec. + /// Never built on demand; the prover uses the resident trace borrowed in + /// `air_trace_pairs`. + Resident, + /// Retired log-derived table chunk `(kind, chunk_idx)`, built on demand. + Retired(TableKind, usize), +} + +/// The on-demand half of streaming "retire-traces" mode (C.2b): the routed +/// intermediate plus the air-order slot map. Implements +/// [`stark::prover::TraceProvider`] so the prover can (re)build each retired +/// log-derived chunk on demand and drop it. Owns `routed` (lifetime `'a`) but no +/// materialized log-derived traces. +pub struct StreamingProvider<'a> { + routed: RoutedTraceData<'a>, + max_rows: super::MaxRowsConfig, + /// Air-order slot map (length = number of AIRs). + slots: Vec, +} + +/// The resident half of streaming mode: the air-ordered trace Vec (real +/// preprocessed / PAGE traces at resident slots, empty placeholders at retired +/// slots) plus the derived metadata the proving entry point needs. The prover +/// borrows `traces` mutably to form `air_trace_pairs`. +pub struct StreamingResident { + /// Air-ordered traces: real at resident slots, empty placeholder otherwise. + pub traces: Vec>, + pub table_counts: crate::TableCounts, + pub page_configs: Vec, + pub public_output_bytes: Vec, +} + +impl StreamingResident { + pub fn runtime_page_ranges(&self) -> Vec { + runtime_page_ranges_from_configs(&self.page_configs) + } +} + +/// Run PHASES 0-4 (routing) and build the preprocessed/PAGE tables, then return +/// the split (`provider`, `resident`) for streaming "retire-traces" proving. +/// Mirrors [`Traces::from_elf_and_logs`] up to (but not including) the PHASE-5 +/// fill of the chunked log-derived tables. +/// +/// The air-order slot map MUST match `VmAirs::air_trace_pairs`: 8 preprocessed +/// tables, then cpus, lts, shifts, memws, memw_aligneds, loads, muls, dvrms, +/// branches, pages, memw_registers (PAGE is resident, not retired). +pub fn route_for_streaming<'a>( + elf: &'a Elf, + logs: &[Log], + max_rows: &super::MaxRowsConfig, + private_input: &'a [u8], + #[cfg(feature = "disk-spill")] storage_mode: StorageMode, +) -> Result<(StreamingProvider<'a>, StreamingResident), Error> { + let instructions = decode::instructions_from_elf(elf) + .map_err(|e| Error::Execution(format!("Failed to parse instructions: {e}")))?; + let (decode_trace, decode_pc_to_row) = decode::generate_decode_trace(&instructions); + + let cpu_ops = collect_cpu_ops(logs, &instructions)?; + + let mut memory_state = MemoryState::from_elf(elf); + memory_state.add_private_input(private_input); + let mut register_state = RegisterState::new(elf.entry_point); + let (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops, keccak_ops) = + collect_ops_from_cpu(&cpu_ops, &mut memory_state, &mut register_state); + + let ops = collect_all_ops( + cpu_ops, + memw_ops, + load_ops, + lt_ops, + shift_ops, + bitwise_ops, + commit_ops, + keccak_ops, + &mut register_state, + ); + + let routed = route( + ops, + Some(elf), + memory_state, + elf.entry_point, + decode_trace, + decode_pc_to_row, + register_state, + max_rows, + private_input, + )?; + + let pre = build_preprocessed_tables( + &routed, + #[cfg(feature = "disk-spill")] + storage_mode, + )?; + + let public_output_bytes = routed.public_output_bytes.clone(); + + // Per-kind chunk counts (drive table_counts and the slot ranges). + let count = |k: TableKind| routed.num_chunks(k, max_rows); + let table_counts = crate::TableCounts { + cpu: count(TableKind::Cpu), + lt: count(TableKind::Lt), + shift: count(TableKind::Shift), + memw: count(TableKind::Memw), + memw_aligned: count(TableKind::MemwAligned), + load: count(TableKind::Load), + mul: count(TableKind::Mul), + dvrm: count(TableKind::Dvrm), + branch: count(TableKind::Branch), + memw_register: count(TableKind::MemwRegister), + }; + + // Destructure the preprocessed tables; they move into the resident Vec in + // air order, interleaved with empty placeholders for the retired chunks. + let PreprocessedTables { + bitwise, + decode, + commit, + keccak, + keccak_rnd, + keccak_rc, + pages, + page_configs, + register, + halt, + } = pre; + + let mut slots: Vec = Vec::new(); + let mut traces: Vec> = Vec::new(); + + // 8 preprocessed tables, in air order (bitwise, decode, halt, commit, + // keccak, keccak_rnd, keccak_rc, register) — all resident. + for t in [bitwise, decode, halt, commit, keccak, keccak_rnd, keccak_rc, register] { + slots.push(StreamingSlot::Resident); + traces.push(t); + } + + // Helper: append all retired chunks of a kind (empty placeholder traces). + let push_kind = |slots: &mut Vec, + traces: &mut Vec>, + kind: TableKind| { + for chunk in 0..count(kind) { + slots.push(StreamingSlot::Retired(kind, chunk)); + traces.push(TraceTable::new_main(Vec::new(), 1, 1)); + } + }; + push_kind(&mut slots, &mut traces, TableKind::Cpu); + push_kind(&mut slots, &mut traces, TableKind::Lt); + push_kind(&mut slots, &mut traces, TableKind::Shift); + push_kind(&mut slots, &mut traces, TableKind::Memw); + push_kind(&mut slots, &mut traces, TableKind::MemwAligned); + push_kind(&mut slots, &mut traces, TableKind::Load); + push_kind(&mut slots, &mut traces, TableKind::Mul); + push_kind(&mut slots, &mut traces, TableKind::Dvrm); + push_kind(&mut slots, &mut traces, TableKind::Branch); + // PAGE tables (resident, one slot per page). + for page in pages { + slots.push(StreamingSlot::Resident); + traces.push(page); + } + push_kind(&mut slots, &mut traces, TableKind::MemwRegister); + + let provider = StreamingProvider { + routed, + max_rows: max_rows.clone(), + slots, + }; + let resident = StreamingResident { + traces, + table_counts, + page_configs, + public_output_bytes, + }; + Ok((provider, resident)) +} + +impl stark::prover::TraceProvider + for StreamingProvider<'_> +{ + fn is_retired(&self, idx: usize) -> bool { + matches!(self.slots[idx], StreamingSlot::Retired(..)) + } + + fn num_rows(&self, idx: usize) -> usize { + match self.slots[idx] { + // Several log-derived tables (LT/MUL/DVRM/BRANCH) pad to the count of + // their *deduplicated* ops, which a raw op-count cannot predict, so + // build the chunk to read its true padded row count (then drop it). + // One extra transient chunk build; cheap relative to the LDE work. + StreamingSlot::Retired(..) => self.build_main(idx).num_rows(), + // For resident slots the prover reads num_rows from the resident + // trace directly; this path is unused but kept total. + StreamingSlot::Resident => 0, + } + } + + fn build_main(&self, idx: usize) -> TraceTable { + match self.slots[idx] { + StreamingSlot::Retired(kind, chunk) => self + .routed + .build_chunk( + kind, + chunk, + &self.max_rows, + #[cfg(feature = "disk-spill")] + StorageMode::Ram, + ) + .expect("retired chunk build is infallible in RAM mode"), + StreamingSlot::Resident => { + unreachable!("build_main called on a resident (non-retired) slot {idx}") + } + } + } +} + /// Padded row count after chunking. #[cfg(feature = "disk-spill")] fn padded_chunked_rows(ops_count: usize, max_rows: usize) -> u64 { @@ -3014,37 +3683,7 @@ impl Traces { /// Runtime (non-ELF) pages are identified by `init_values == None` /// (zero-init), avoiding a redundant ELF segment scan. pub fn runtime_page_ranges(&self) -> Vec { - let page_size = page::DEFAULT_PAGE_SIZE as u64; - - // Collect sorted non-ELF page bases (zero-init pages are runtime pages) - let runtime_bases: Vec = self - .page_configs - .iter() - .filter(|config| config.init_values.is_none()) - .map(|config| config.page_base) - .collect(); - - // Run-length encode contiguous pages into (base, count) ranges - let mut ranges = Vec::new(); - if runtime_bases.is_empty() { - return ranges; - } - - let mut start = runtime_bases[0]; - let mut count = 1u64; - - for &base in &runtime_bases[1..] { - if base == start + count * page_size { - count += 1; - } else { - ranges.push(crate::RuntimePageRange { base: start, count }); - start = base; - count = 1; - } - } - ranges.push(crate::RuntimePageRange { base: start, count }); - - ranges + runtime_page_ranges_from_configs(&self.page_configs) } /// Generates all traces from ELF and execution logs using phased collection. @@ -3096,7 +3735,7 @@ impl Traces { build_traces( ops, Some(elf), - &memory_state, + memory_state, elf.entry_point, decode_trace, decode_pc_to_row, @@ -3148,7 +3787,7 @@ impl Traces { build_traces( ops, None, - &memory_state, + memory_state, entry_point, decode_trace, decode_pc_to_row, @@ -3160,6 +3799,75 @@ impl Traces { ) } + /// TEST ONLY: build every log-derived chunked execution table via the + /// `route` + per-table `build_table` split (PHASES 1-4 then per-table + /// PHASE-5 fill), returning the per-table chunk Vecs in the same field + /// order as [`Traces`]. Used to assert byte-identity against the monolithic + /// [`from_logs`]/[`build_traces`] path. + #[cfg(test)] + #[allow(clippy::type_complexity)] + pub(crate) fn chunked_tables_via_route( + logs: &[Log], + instructions: U64HashMap, + max_rows: &super::MaxRowsConfig, + ) -> Result { + // PHASES 1-2 (mirror of `from_logs`). + let cpu_ops = collect_cpu_ops(logs, &instructions)?; + let mut memory_state = MemoryState::new(); + let entry_point = cpu_ops.first().map_or(0, |op| op.decode.pc); + let mut register_state = RegisterState::new(entry_point); + let (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops, keccak_ops) = + collect_ops_from_cpu(&cpu_ops, &mut memory_state, &mut register_state); + let ops = collect_all_ops( + cpu_ops, + memw_ops, + load_ops, + lt_ops, + shift_ops, + bitwise_ops, + commit_ops, + keccak_ops, + &mut register_state, + ); + let (decode_trace, decode_pc_to_row) = decode::generate_decode_trace(&instructions); + + // PHASES 3-4: route once. All cross-table multiplicity coupling is + // resolved here before any per-table fill reads it. + let routed = route( + ops, + None, + memory_state, + entry_point, + decode_trace, + decode_pc_to_row, + register_state, + max_rows, + &[], + )?; + + // PHASE 5: per-table fill, one `build_table` call per table. + let build = |which: TableKind| { + routed.build_table( + which, + max_rows, + #[cfg(feature = "disk-spill")] + StorageMode::Ram, + ) + }; + Ok(ChunkedTablesViaRoute { + cpus: build(TableKind::Cpu)?, + lts: build(TableKind::Lt)?, + shifts: build(TableKind::Shift)?, + memws: build(TableKind::Memw)?, + memw_aligneds: build(TableKind::MemwAligned)?, + loads: build(TableKind::Load)?, + muls: build(TableKind::Mul)?, + dvrms: build(TableKind::Dvrm)?, + branches: build(TableKind::Branch)?, + memw_registers: build(TableKind::MemwRegister)?, + }) + } + /// Generates all traces with a trimmed bitwise table (TEST ONLY). /// /// # WARNING: UNSOUND FOR PRODUCTION diff --git a/prover/src/tests/mod.rs b/prover/src/tests/mod.rs index 4fcdba7f4..76cef638d 100644 --- a/prover/src/tests/mod.rs +++ b/prover/src/tests/mod.rs @@ -52,3 +52,5 @@ pub mod statement_tests; pub mod templates_tests; #[cfg(test)] pub mod trace_builder_tests; + +pub mod streaming_retire_tests; diff --git a/prover/src/tests/streaming_retire_tests.rs b/prover/src/tests/streaming_retire_tests.rs new file mode 100644 index 000000000..e95001742 --- /dev/null +++ b/prover/src/tests/streaming_retire_tests.rs @@ -0,0 +1,76 @@ +//! Byte-identical proof test for streaming "retire-traces" mode (C.2b). +//! +//! The decisive correctness invariant: a proof produced with the full streaming +//! mode ON (`LAMBDA_STREAM_LDE=1`, which retires the LDE, Merkle leaves AND the +//! log-derived traces, rebuilding each on demand from a compact routed +//! intermediate) must be byte-identical to one produced with it OFF. C.2a made +//! the trace build deterministic, so the on-demand-rebuilt trace equals the +//! pre-built one and the proof matches exactly. If they differ, the on-demand +//! rebuild has diverged from the pre-built trace. +//! +//! Marked `#[ignore]` because the only execution fixture that runs end-to-end on +//! this branch's executor (`fib_iterative_1200k`) is heavy (~1.2M cycles → many +//! chunks per table). Run explicitly with: +//! `cargo test -p lambda-vm-prover --lib -- --ignored streaming_retire` + +use std::sync::Mutex; + +use stark::proof::options::GoldilocksCubicProofOptions; + +use crate::MaxRowsConfig; +use crate::test_utils::asm_elf_bytes; + +/// `LAMBDA_STREAM_LDE` is process-global; serialize the two proving runs. +static ENV_LOCK: Mutex<()> = Mutex::new(()); + +/// Prove `elf_bytes` once with streaming `on`/off, returning the serialized +/// `VmProof` bytes. Restores the prior env-var value. Also asserts the proof +/// verifies regardless of which mode produced it. +fn prove_serialized(elf_bytes: &[u8], max_rows: &MaxRowsConfig, on: bool) -> Vec { + let _guard = ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner()); + let prev = std::env::var("LAMBDA_STREAM_LDE").ok(); + + // SAFETY: single-threaded section guarded by ENV_LOCK; restored below. + unsafe { std::env::set_var("LAMBDA_STREAM_LDE", if on { "1" } else { "0" }) }; + + // Grinding is disabled (factor 0): the parallel grinding nonce search uses + // `find_any`, which is non-deterministic, so any grinding>0 makes the proof + // bytes vary run-to-run regardless of streaming mode. With grinding=0 the + // proof is fully deterministic, isolating the trace-rebuild invariant. + let options = GoldilocksCubicProofOptions::with_params(2, 128, 0) + .expect("grinding=0 options are valid"); + let proof = crate::prove_with_options(elf_bytes, &options, max_rows) + .expect("prove must succeed"); + + match prev { + Some(v) => unsafe { std::env::set_var("LAMBDA_STREAM_LDE", v) }, + None => unsafe { std::env::remove_var("LAMBDA_STREAM_LDE") }, + } + + assert!( + crate::verify_with_options(&proof, elf_bytes, &options).expect("verify must run"), + "proof failed to verify" + ); + + bincode::serialize(&proof).expect("serialize VmProof") +} + +/// Assert the streaming-ON and streaming-OFF proofs are byte-identical for a +/// given ELF / max_rows configuration. +fn assert_byte_identical(name: &str, max_rows: MaxRowsConfig) { + let elf_bytes = asm_elf_bytes(name); + let off = prove_serialized(&elf_bytes, &max_rows, false); + let on = prove_serialized(&elf_bytes, &max_rows, true); + assert_eq!( + off, on, + "streaming retire-traces proof for `{name}` must be byte-identical to the resident proof" + ); +} + +/// Multi-chunk case: many chunks per log-derived table (PAGE interleaved between +/// BRANCH and MEMW_R), exercising the air-order -> (TableKind, chunk) mapping. +#[test] +#[ignore = "heavy execution fixture (~1.2M cycles); run with --ignored"] +fn streaming_retire_proof_is_byte_identical_chunked() { + assert_byte_identical("fib_iterative_1200k", MaxRowsConfig::default()); +} diff --git a/prover/src/tests/trace_builder_tests.rs b/prover/src/tests/trace_builder_tests.rs index 199ce71db..e2d41b857 100644 --- a/prover/src/tests/trace_builder_tests.rs +++ b/prover/src/tests/trace_builder_tests.rs @@ -822,3 +822,207 @@ mod routing_tests { ); } } + + +// ============================================================================= +// route + per-table build_table byte-identity (C.1 trace-retire step) +// ============================================================================= + +/// THE invariant for C.1: building each log-derived execution table via +/// `route` + per-table `build_table` reproduces byte-identical traces to the +/// monolithic `from_logs`/`build_traces` path. +#[test] +fn test_route_then_build_table_matches_monolithic() { + // A mixed program touching multiple tables: Add (CPU/MEMW), SLT (LT), + // AND (bitwise), BLT (branch+LT), plus the halting ecall. + let mut logs = vec![ + make_add_log(0x1000, 100, 200, 300), + make_slt_log(0x1004, 5, 10, 1), + make_and_log(0x1008, 0xFF, 0xF0, 0xF0), + make_blt_log(0x100c, 1, 2, true), + make_add_log(0x1010, 7, 8, 15), + ]; + let mut instrs = vec![ + Instruction::Arith { + dst: 1, + src1: 2, + src2: 3, + op: ArithOp::Add, + }, + Instruction::Arith { + dst: 1, + src1: 2, + src2: 3, + op: ArithOp::SetLessThan, + }, + Instruction::Arith { + dst: 1, + src1: 2, + src2: 3, + op: ArithOp::And, + }, + Instruction::Branch { + src1: 2, + src2: 3, + cond: Comparison::LessThan, + offset: 8, + }, + Instruction::Arith { + dst: 1, + src1: 2, + src2: 3, + op: ArithOp::Add, + }, + ]; + append_ecall(&mut logs, &mut instrs); + let instructions = make_instructions(&logs, &instrs); + let max_rows = Default::default(); + + // Monolithic build (reference). + let mono = Traces::from_logs(&logs, instructions.clone(), &max_rows).unwrap(); + + // route() once, then per-table build_table() for every log-derived table. + let split = Traces::chunked_tables_via_route(&logs, instructions, &max_rows).unwrap(); + + // Compare each per-table chunk against the monolithic output. + // + // `TraceTable`'s derived `PartialEq` is not usable here (the extension-field + // marker type param does not implement `PartialEq`), so we compare the + // underlying flat column data plus dimensions directly. + // + // NOTE: several `generate_*_trace` fns deduplicate rows via a + // `HashMap` and emit `op_map.into_iter()`, whose order is + // randomized per `HashMap` instance. That row ordering is therefore + // pre-existing nondeterminism in the monolithic build itself (two + // back-to-back `from_logs` calls already disagree on LT row order). Since + // the route/build split must reproduce the same ROWS (not a particular + // ordering), we compare each table as a multiset of rows: sort the rows of + // both sides, then assert byte-equality. This is decisive for "same trace + // content" while being immune to the pre-existing ordering nondeterminism. + type TT = stark::trace::TraceTable< + crate::tables::types::GoldilocksField, + crate::tables::types::GoldilocksExtension, + >; + fn sorted_rows(t: &TT) -> Vec> { + let w = t.main_table.width; + let mut rows: Vec> = t + .main_table + .data + .chunks(w.max(1)) + .map(|row| row.iter().map(|fe| *fe.value()).collect()) + .collect(); + rows.sort(); + rows + } + fn assert_tables_eq(split: &[TT], mono: &[TT], name: &str) { + assert_eq!(split.len(), mono.len(), "{name}: chunk count differs"); + for (i, (s, m)) in split.iter().zip(mono.iter()).enumerate() { + assert_eq!( + s.main_table.width, m.main_table.width, + "{name} chunk {i}: width differs" + ); + assert_eq!( + s.main_table.height, m.main_table.height, + "{name} chunk {i}: height differs" + ); + assert_eq!( + s.num_main_columns, m.num_main_columns, + "{name} chunk {i}: num_main_columns differs" + ); + assert_eq!( + s.num_aux_columns, m.num_aux_columns, + "{name} chunk {i}: num_aux_columns differs" + ); + assert_eq!( + s.step_size, m.step_size, + "{name} chunk {i}: step_size differs" + ); + assert_eq!( + sorted_rows(s), + sorted_rows(m), + "{name} chunk {i}: row multiset differs" + ); + } + } + assert_tables_eq(&split.cpus, &mono.cpus, "CPU"); + assert_tables_eq(&split.lts, &mono.lts, "LT"); + assert_tables_eq(&split.shifts, &mono.shifts, "SHIFT"); + assert_tables_eq(&split.memws, &mono.memws, "MEMW"); + assert_tables_eq(&split.memw_aligneds, &mono.memw_aligneds, "MEMW_A"); + assert_tables_eq(&split.loads, &mono.loads, "LOAD"); + assert_tables_eq(&split.muls, &mono.muls, "MUL"); + assert_tables_eq(&split.dvrms, &mono.dvrms, "DVRM"); + assert_tables_eq(&split.branches, &mono.branches, "BRANCH"); + assert_tables_eq(&split.memw_registers, &mono.memw_registers, "MEMW_R"); + + // Sanity: the program actually populated several of these tables. + assert!(!mono.cpus.is_empty()); + assert!(mono.cpus[0].main_table.height >= 6); + assert!(mono.lts[0].main_table.height >= 2, "expected LT rows"); + assert!(mono.branches[0].main_table.height >= 1, "expected BRANCH rows"); +} + +/// Two `from_logs` builds of the same input must produce byte-identical traces. +/// +/// Before the determinism fix, the LT/MUL/DVRM/BRANCH builders emitted +/// `HashMap::into_iter()` order (randomized per instance), so two back-to-back +/// builds disagreed on row order. They now sort the deduplicated ops by a +/// canonical key, which is a prerequisite for streaming on-demand trace rebuild +/// (Phase-A commit vs Rounds-2-4 rebuild must agree, or the committed root won't +/// match the reconstructed trace). +#[test] +fn trace_build_is_deterministic_across_builds() { + type TT = stark::trace::TraceTable< + crate::tables::types::GoldilocksField, + crate::tables::types::GoldilocksExtension, + >; + + // Several DISTINCT LT and branch ops so each dedup'd `unique_ops` has + // multiple elements whose order previously varied per HashMap instance. + let mut logs = vec![ + make_slt_log(0x1000, 5, 10, 1), + make_slt_log(0x1004, 200, 7, 0), + make_slt_log(0x1008, 42, 42, 0), + make_slt_log(0x100c, 1, 999, 1), + make_blt_log(0x1010, 3, 4, true), + make_blt_log(0x1014, 50, 9, false), + make_blt_log(0x1018, 77, 77, false), + ]; + let mut instrs = vec![ + Instruction::Arith { dst: 1, src1: 2, src2: 3, op: ArithOp::SetLessThan }, + Instruction::Arith { dst: 1, src1: 2, src2: 3, op: ArithOp::SetLessThan }, + Instruction::Arith { dst: 1, src1: 2, src2: 3, op: ArithOp::SetLessThan }, + Instruction::Arith { dst: 1, src1: 2, src2: 3, op: ArithOp::SetLessThan }, + Instruction::Branch { src1: 2, src2: 3, cond: Comparison::LessThan, offset: 8 }, + Instruction::Branch { src1: 2, src2: 3, cond: Comparison::LessThan, offset: 8 }, + Instruction::Branch { src1: 2, src2: 3, cond: Comparison::LessThan, offset: 8 }, + ]; + append_ecall(&mut logs, &mut instrs); + let instructions = make_instructions(&logs, &instrs); + let max_rows = Default::default(); + + let a = Traces::from_logs(&logs, instructions.clone(), &max_rows).unwrap(); + let b = Traces::from_logs(&logs, instructions, &max_rows).unwrap(); + + fn flat(t: &TT) -> Vec { + t.main_table.data.iter().map(|fe| *fe.value()).collect() + } + fn eq(x: &[TT], y: &[TT], name: &str) { + assert_eq!(x.len(), y.len(), "{name}: chunk count differs across builds"); + for (i, (s, m)) in x.iter().zip(y.iter()).enumerate() { + assert_eq!( + flat(s), + flat(m), + "{name} chunk {i}: trace data differs across builds (non-deterministic order)" + ); + } + } + eq(&a.lts, &b.lts, "LT"); + eq(&a.muls, &b.muls, "MUL"); + eq(&a.dvrms, &b.dvrms, "DVRM"); + eq(&a.branches, &b.branches, "BRANCH"); + eq(&a.cpus, &b.cpus, "CPU"); + eq(&a.memws, &b.memws, "MEMW"); + eq(&a.shifts, &b.shifts, "SHIFT"); + eq(&a.loads, &b.loads, "LOAD"); +}