diff --git a/crypto/math-cuda/build.rs b/crypto/math-cuda/build.rs index 25c16a634..bc8fdeb58 100644 --- a/crypto/math-cuda/build.rs +++ b/crypto/math-cuda/build.rs @@ -112,4 +112,7 @@ fn main() { compile_ptx("ntt.cu", "ntt.ptx", have_nvcc); compile_ptx("keccak.cu", "keccak.ptx", have_nvcc); compile_ptx("barycentric.cu", "barycentric.ptx", have_nvcc); + compile_ptx("inverse.cu", "inverse.ptx", have_nvcc); + compile_ptx("deep.cu", "deep.ptx", have_nvcc); + compile_ptx("fri.cu", "fri.ptx", have_nvcc); } diff --git a/crypto/math-cuda/kernels/deep.cu b/crypto/math-cuda/kernels/deep.cu new file mode 100644 index 000000000..4d67579c6 --- /dev/null +++ b/crypto/math-cuda/kernels/deep.cu @@ -0,0 +1,117 @@ +// R4 deep composition polynomial evaluations. +// +// For each trace-size row i in 0..domain_size, accumulate: +// result_i = sum over j of gamma_j * (H_j(x_i) - H_j(z^K)) * inv_h[i] (H terms) +// + sum over j,k of gamma'_{j,k} * (t_j(x_i) - t_j(z*w^k)) * inv_t[k,i] (trace) +// +// where x_i = LDE coset point at stride `blowup_factor` (so the kernel +// reads LDE column data at `i * blowup_factor`). `j` ranges over +// num_parts for H-terms and num_total_cols (= num_main + num_aux) for +// trace terms. `k` ranges over num_eval_points. +// +// Buffer layouts (ALL on device): +// main_lde base, row-major per column: main_lde[c * lde_stride + r] +// aux_lde ext3 de-interleaved: aux_lde[(c*3 + k) * lde_stride + r] +// h_lde ext3 de-interleaved: h_lde[(p*3 + k) * lde_stride + r] +// h_ood num_parts * 3 (ext3 interleaved) +// trace_ood num_total_cols * num_eval_points * 3 (ext3 interleaved, +// indexed as (col_idx * num_eval_points + k) * 3 + comp) +// gammas_h num_parts * 3 +// gammas_tr num_total_cols * num_eval_points * 3 +// inv_h domain_size * 3 +// inv_t num_eval_points * domain_size * 3 +// deep_out domain_size * 3 (ext3 interleaved; caller reinterprets) + +#include "goldilocks.cuh" +#include "ext3.cuh" + +extern "C" __global__ void deep_composition_ext3_row( + const uint64_t *main_lde, + const uint64_t *aux_lde, + const uint64_t *h_lde, + uint64_t lde_stride, + uint64_t num_main, + uint64_t num_aux, + uint64_t num_parts, + uint64_t num_eval_points, + uint64_t blowup_factor, + uint64_t domain_size, + const uint64_t *h_ood, + const uint64_t *trace_ood, + const uint64_t *gammas_h, + const uint64_t *gammas_tr, + const uint64_t *inv_h, + const uint64_t *inv_t, + uint64_t *deep_out) { + uint64_t i = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (i >= domain_size) return; + uint64_t row = i * blowup_factor; + + ext3::Fe3 result = ext3::zero(); + ext3::Fe3 inv_h_i = {inv_h[i * 3], inv_h[i * 3 + 1], inv_h[i * 3 + 2]}; + + // H-terms + for (uint64_t j = 0; j < num_parts; ++j) { + ext3::Fe3 h_val = { + h_lde[(j * 3 + 0) * lde_stride + row], + h_lde[(j * 3 + 1) * lde_stride + row], + h_lde[(j * 3 + 2) * lde_stride + row], + }; + ext3::Fe3 h_ood_j = {h_ood[j * 3], h_ood[j * 3 + 1], h_ood[j * 3 + 2]}; + ext3::Fe3 num = ext3::sub(h_val, h_ood_j); + ext3::Fe3 gamma = {gammas_h[j * 3], gammas_h[j * 3 + 1], gammas_h[j * 3 + 2]}; + ext3::Fe3 tmp = ext3::mul(gamma, num); + tmp = ext3::mul(tmp, inv_h_i); + result = ext3::add(result, tmp); + } + + uint64_t num_total_cols = num_main + num_aux; + + // Main trace terms (base column - ext3 OOD) + for (uint64_t j = 0; j < num_main; ++j) { + uint64_t t_val = main_lde[j * lde_stride + row]; + for (uint64_t k = 0; k < num_eval_points; ++k) { + uint64_t idx = (j * num_eval_points + k) * 3; + ext3::Fe3 t_ood = {trace_ood[idx], trace_ood[idx + 1], trace_ood[idx + 2]}; + ext3::Fe3 num = { + goldilocks::sub(t_val, t_ood.a), + goldilocks::neg(t_ood.b), + goldilocks::neg(t_ood.c), + }; + ext3::Fe3 gamma = {gammas_tr[idx], gammas_tr[idx + 1], gammas_tr[idx + 2]}; + uint64_t inv_t_idx = (k * domain_size + i) * 3; + ext3::Fe3 inv_t_ki = {inv_t[inv_t_idx], inv_t[inv_t_idx + 1], inv_t[inv_t_idx + 2]}; + ext3::Fe3 tmp = ext3::mul(gamma, num); + tmp = ext3::mul(tmp, inv_t_ki); + result = ext3::add(result, tmp); + } + } + + // Aux trace terms (ext3 column - ext3 OOD) + for (uint64_t j = 0; j < num_aux; ++j) { + ext3::Fe3 t_val = { + aux_lde[(j * 3 + 0) * lde_stride + row], + aux_lde[(j * 3 + 1) * lde_stride + row], + aux_lde[(j * 3 + 2) * lde_stride + row], + }; + uint64_t trace_j = num_main + j; + for (uint64_t k = 0; k < num_eval_points; ++k) { + uint64_t idx = (trace_j * num_eval_points + k) * 3; + ext3::Fe3 t_ood = {trace_ood[idx], trace_ood[idx + 1], trace_ood[idx + 2]}; + ext3::Fe3 num = ext3::sub(t_val, t_ood); + ext3::Fe3 gamma = {gammas_tr[idx], gammas_tr[idx + 1], gammas_tr[idx + 2]}; + uint64_t inv_t_idx = (k * domain_size + i) * 3; + ext3::Fe3 inv_t_ki = {inv_t[inv_t_idx], inv_t[inv_t_idx + 1], inv_t[inv_t_idx + 2]}; + ext3::Fe3 tmp = ext3::mul(gamma, num); + tmp = ext3::mul(tmp, inv_t_ki); + result = ext3::add(result, tmp); + } + } + + uint64_t out_idx = i * 3; + deep_out[out_idx + 0] = result.a; + deep_out[out_idx + 1] = result.b; + deep_out[out_idx + 2] = result.c; + // Suppress unused param warning when num_total_cols not referenced. + (void)num_total_cols; +} diff --git a/crypto/math-cuda/kernels/fri.cu b/crypto/math-cuda/kernels/fri.cu new file mode 100644 index 000000000..da4e7cc4d --- /dev/null +++ b/crypto/math-cuda/kernels/fri.cu @@ -0,0 +1,59 @@ +// R4 FRI fold + twiddle-update kernels on device. The host orchestrator +// loops log2(N) times: sample zeta on host, fold on device, keccak leaves +// + tree on device, D2H the root, transcript-append on host, update +// twiddles on device. +// +// Layout: ext3 evaluations are stored INTERLEAVED as +// `[a0,b0,c0, a1,b1,c1, ...]`, same layout the deep-poly LDE output +// already produces. Twiddles are base-field, one u64 per entry. + +#include "goldilocks.cuh" +#include "ext3.cuh" + +// fold_evaluations_in_place: +// out[j] = (lo + hi) + inv_tw[j] * zeta * (lo - hi) +// where lo = evals[2j], hi = evals[2j+1]. Both lo/hi and zeta are ext3. +// inv_tw[j] is a base-field twiddle (F * E -> E). +// +// Writes N/2 ext3 outputs (3 * n_out u64 total) into `out`. `in` is the +// previous layer of 2 * n_out ext3 values (6 * n_out u64 total). +extern "C" __global__ void fri_fold_ext3( + const uint64_t *in, // 3 * 2*n_out u64 (ext3 interleaved) + uint64_t n_out, // number of output ext3 elements (= N/2) + const uint64_t *inv_tw, // n_out base-field twiddles + const uint64_t *zeta, // 3 u64 (ext3) + uint64_t *out) { // 3 * n_out u64 (ext3 interleaved) + uint64_t j = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (j >= n_out) return; + + const uint64_t *lo_p = in + 2 * j * 3; + const uint64_t *hi_p = lo_p + 3; + + ext3::Fe3 lo = ext3::make(lo_p[0], lo_p[1], lo_p[2]); + ext3::Fe3 hi = ext3::make(hi_p[0], hi_p[1], hi_p[2]); + ext3::Fe3 sum = ext3::add(lo, hi); + ext3::Fe3 diff = ext3::sub(lo, hi); + + ext3::Fe3 z = ext3::make(zeta[0], zeta[1], zeta[2]); + ext3::Fe3 zd = ext3::mul(z, diff); // ext3 * ext3 = ext3 + uint64_t tw = inv_tw[j]; + ext3::Fe3 tzd = ext3::mul_base(zd, tw); // base * ext3 = ext3 (componentwise) + ext3::Fe3 res = ext3::add(sum, tzd); + + uint64_t *out_p = out + j * 3; + out_p[0] = res.a; + out_p[1] = res.b; + out_p[2] = res.c; +} + +// update_twiddles_in_place: new[j] = old[2j]^2. Writes in-place. Caller +// must ensure the kernel is not reading the same index concurrently. Since +// we read `old[2j]` and write `new[j]` with j < 2j, there's no aliasing. +extern "C" __global__ void fri_update_twiddles( + uint64_t *tw, + uint64_t n_out) { + uint64_t j = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (j >= n_out) return; + uint64_t old = tw[2 * j]; + tw[j] = goldilocks::mul(old, old); +} diff --git a/crypto/math-cuda/kernels/inverse.cu b/crypto/math-cuda/kernels/inverse.cu new file mode 100644 index 000000000..67805ef6d --- /dev/null +++ b/crypto/math-cuda/kernels/inverse.cu @@ -0,0 +1,296 @@ +// Parallel Montgomery batch inverse over ext3, plus a compute-denoms +// helper for R3 OOD / R4 DEEP preludes. +// +// Batch inverse strategy (chunk-based parallel scan): +// +// 1. Chunk-local forward scan: each thread serially computes the +// prefix product of its chunk of `C = ceil(N / K)` ext3 values; +// writes the chunk output in place and posts its chunk total to +// `chunk_totals[thread_id]`. +// 2. Single-block scan of `chunk_totals` (K <= 1024 for our shapes, +// fits one block). +// 3. Chunk-local apply: each thread multiplies its chunk's local +// prefix by the exclusive-scan offset from step 2, producing the +// global forward prefix. +// 4. Mirror (1-3) in reverse for the suffix. +// 5. Single-thread kernel inverts total = prefix[N-1]. +// 6. Pointwise combine: `inv[i] = prefix[i-1] * suffix[i+1] * inv_total` +// (with prefix[-1] = suffix[N] = 1). One thread per element. +// +// Ext3 multiply is commutative in the field (it's a field, not just a +// ring), so prefix-product scans are well-defined. Layout is ext3 +// INTERLEAVED: one u64 triple per element, 3*N u64s total. + +#include "goldilocks.cuh" +#include "ext3.cuh" + +#define INV_BLOCK 256 + +// --------------------------------------------------------------------------- +// B.1: compute denoms for R4 DEEP and R3 OOD. +// +// denoms[k*n + i] = x[i * stride] - z[k] +// where `x` is a base-field coset (read at stride `stride`), `z` is an +// ext3 array of `k_scalars` entries (z^K and/or z * omega^k), and `n` is the +// trace-size count. Output is flat ext3 interleaved. +// --------------------------------------------------------------------------- +extern "C" __global__ void compute_denoms_ext3( + const uint64_t *x_base, // base-field LDE coset points + uint64_t stride, // read stride (blowup_factor for R4) + const uint64_t *z_scalars, // k_scalars * 3 u64 (ext3 interleaved) + uint64_t k_scalars, + uint64_t n, + uint64_t *denoms_out) { // k_scalars * n * 3 u64 + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t total = k_scalars * n; + if (tid >= total) return; + + uint64_t k = tid / n; + uint64_t i = tid - k * n; + + uint64_t x_i = x_base[i * stride]; + uint64_t z_a = z_scalars[k * 3 + 0]; + uint64_t z_b = z_scalars[k * 3 + 1]; + uint64_t z_c = z_scalars[k * 3 + 2]; + + // base - ext3 = ext3 ( (x_i - z_a), -z_b, -z_c ) + uint64_t out_a = goldilocks::sub(x_i, z_a); + uint64_t out_b = goldilocks::neg(z_b); + uint64_t out_c = goldilocks::neg(z_c); + + uint64_t out_idx = tid * 3; + denoms_out[out_idx + 0] = out_a; + denoms_out[out_idx + 1] = out_b; + denoms_out[out_idx + 2] = out_c; +} + +// --------------------------------------------------------------------------- +// B.2 chunk-scan primitives for batch inverse. +// +// `a_in` is the input array of N ext3 elements (3*N u64, interleaved). +// `prefix_out` receives prefix[i] = prod(a[0..=i]) for all i. +// `chunk_totals` receives the per-chunk total (one ext3 per chunk). +// +// Each thread owns a contiguous chunk of C elements. With K=256 threads +// per block and a single block, we can handle up to 256*C elements. +// For N up to ~1M, C is around 4096, so one thread does ~4k ext3 multiplies +// serially in shmem-free fashion. Depth = O(C) + O(K) + O(C); with +// K=256 threads running in parallel, the `O(C)` phases parallelise +// perfectly across threads. +// +// For cleanliness, we launch as grid=1, block=K=256. For N up to 2^20 +// that's fine; if we ever need N > 256 * C_max, we'd recurse. +// --------------------------------------------------------------------------- + +// Phase 1 & 3 fused into one kernel would require shmem across phases. +// Splitting makes each kernel simpler. + +// Phase 1: chunk-local forward scan. Also emits chunk_totals. +extern "C" __global__ void chunk_prefix_scan_ext3( + const uint64_t *a_in, // 3 * n u64 (ext3 interleaved) + uint64_t n, + uint64_t c_per_thread, // C = ceil(n / K) + uint64_t *prefix_out, // 3 * n u64 + uint64_t *chunk_totals) { // 3 * K u64 + uint32_t tid = threadIdx.x; + uint64_t start = (uint64_t)tid * c_per_thread; + uint64_t end = min(start + c_per_thread, n); + + ext3::Fe3 acc = ext3::one(); + for (uint64_t i = start; i < end; ++i) { + ext3::Fe3 e = {a_in[i * 3 + 0], a_in[i * 3 + 1], a_in[i * 3 + 2]}; + acc = ext3::mul(acc, e); + prefix_out[i * 3 + 0] = acc.a; + prefix_out[i * 3 + 1] = acc.b; + prefix_out[i * 3 + 2] = acc.c; + } + chunk_totals[tid * 3 + 0] = acc.a; + chunk_totals[tid * 3 + 1] = acc.b; + chunk_totals[tid * 3 + 2] = acc.c; +} + +// Phase 2: exclusive prefix scan of chunk_totals, single-threaded. +// scan_out[0] = 1, scan_out[i] = prod(chunk_totals[0..i]). +extern "C" __global__ void exclusive_scan_of_totals_ext3( + const uint64_t *chunk_totals, // 3 * K u64 + uint64_t k, + uint64_t *scan_out) { // 3 * K u64 + if (threadIdx.x != 0 || blockIdx.x != 0) return; + ext3::Fe3 acc = ext3::one(); + scan_out[0] = acc.a; + scan_out[1] = acc.b; + scan_out[2] = acc.c; + for (uint64_t i = 1; i < k; ++i) { + ext3::Fe3 ct = { + chunk_totals[(i - 1) * 3 + 0], + chunk_totals[(i - 1) * 3 + 1], + chunk_totals[(i - 1) * 3 + 2], + }; + acc = ext3::mul(acc, ct); + scan_out[i * 3 + 0] = acc.a; + scan_out[i * 3 + 1] = acc.b; + scan_out[i * 3 + 2] = acc.c; + } +} + +// Phase 3: apply per-chunk offset to local scan result. +// global_prefix[i] = offsets[thread] * local_prefix[i] +extern "C" __global__ void apply_scan_offsets_ext3( + uint64_t *prefix_inout, // 3 * n u64 (written in phase 1, rewritten here) + uint64_t n, + uint64_t c_per_thread, + const uint64_t *offsets) { // 3 * K u64 + uint32_t tid = threadIdx.x; + uint64_t start = (uint64_t)tid * c_per_thread; + uint64_t end = min(start + c_per_thread, n); + + ext3::Fe3 off = { + offsets[tid * 3 + 0], + offsets[tid * 3 + 1], + offsets[tid * 3 + 2], + }; + for (uint64_t i = start; i < end; ++i) { + ext3::Fe3 local = { + prefix_inout[i * 3 + 0], + prefix_inout[i * 3 + 1], + prefix_inout[i * 3 + 2], + }; + ext3::Fe3 g = ext3::mul(off, local); + prefix_inout[i * 3 + 0] = g.a; + prefix_inout[i * 3 + 1] = g.b; + prefix_inout[i * 3 + 2] = g.c; + } +} + +// Reverse-scan phase 1: chunk-local reverse prefix. +// suffix_out[i] = prod(a[i..chunk_end]) (within chunk only) +// chunk_totals[tid] = suffix_out[chunk_start] (= full chunk product) +extern "C" __global__ void chunk_suffix_scan_ext3( + const uint64_t *a_in, + uint64_t n, + uint64_t c_per_thread, + uint64_t *suffix_out, + uint64_t *chunk_totals) { + uint32_t tid = threadIdx.x; + uint64_t start = (uint64_t)tid * c_per_thread; + // Walk backward; acc starts at 1 and accumulates a[end-1], a[end-2], ... + // Empty chunks (start >= n) fall through with acc = 1 so that + // chunk_totals receives the identity, matching the prefix-scan kernel. + ext3::Fe3 acc = ext3::one(); + if (start < n) { + uint64_t end = min(start + c_per_thread, n); + for (uint64_t ri = end; ri > start; --ri) { + uint64_t i = ri - 1; + ext3::Fe3 e = {a_in[i * 3 + 0], a_in[i * 3 + 1], a_in[i * 3 + 2]}; + acc = ext3::mul(acc, e); + suffix_out[i * 3 + 0] = acc.a; + suffix_out[i * 3 + 1] = acc.b; + suffix_out[i * 3 + 2] = acc.c; + } + } + chunk_totals[tid * 3 + 0] = acc.a; + chunk_totals[tid * 3 + 1] = acc.b; + chunk_totals[tid * 3 + 2] = acc.c; +} + +// Exclusive reverse scan of chunk totals. +// scan_out[K-1] = 1 +// scan_out[k] = prod(chunk_totals[k+1..K]) +extern "C" __global__ void exclusive_reverse_scan_of_totals_ext3( + const uint64_t *chunk_totals, + uint64_t k, + uint64_t *scan_out) { + if (threadIdx.x != 0 || blockIdx.x != 0) return; + ext3::Fe3 acc = ext3::one(); + if (k == 0) return; + scan_out[(k - 1) * 3 + 0] = acc.a; + scan_out[(k - 1) * 3 + 1] = acc.b; + scan_out[(k - 1) * 3 + 2] = acc.c; + for (int64_t i = (int64_t)k - 2; i >= 0; --i) { + ext3::Fe3 ct = { + chunk_totals[(i + 1) * 3 + 0], + chunk_totals[(i + 1) * 3 + 1], + chunk_totals[(i + 1) * 3 + 2], + }; + acc = ext3::mul(acc, ct); + scan_out[i * 3 + 0] = acc.a; + scan_out[i * 3 + 1] = acc.b; + scan_out[i * 3 + 2] = acc.c; + } +} + +// Apply reverse offsets. +extern "C" __global__ void apply_reverse_scan_offsets_ext3( + uint64_t *suffix_inout, + uint64_t n, + uint64_t c_per_thread, + const uint64_t *offsets) { + uint32_t tid = threadIdx.x; + uint64_t start = (uint64_t)tid * c_per_thread; + if (start >= n) return; + uint64_t end = min(start + c_per_thread, n); + + ext3::Fe3 off = { + offsets[tid * 3 + 0], + offsets[tid * 3 + 1], + offsets[tid * 3 + 2], + }; + for (uint64_t i = start; i < end; ++i) { + ext3::Fe3 local = { + suffix_inout[i * 3 + 0], + suffix_inout[i * 3 + 1], + suffix_inout[i * 3 + 2], + }; + ext3::Fe3 g = ext3::mul(off, local); + suffix_inout[i * 3 + 0] = g.a; + suffix_inout[i * 3 + 1] = g.b; + suffix_inout[i * 3 + 2] = g.c; + } +} + +// Same fix for the forward apply_scan_offsets: threads whose chunks are +// empty must not write past end-of-array. (chunk_prefix_scan already +// behaves correctly because the start..end range is empty; apply just +// needs to handle start >= n gracefully. It already does by the same +// empty-range logic. No change needed there, just documenting.) + +// Final combine: inv[i] = pre_excl[i] * suf_excl[i] * inv_total +// where pre_excl[i] = prefix[i-1] (with prefix[-1] = 1) and +// suf_excl[i] = suffix[i+1] (with suffix[N] = 1). +// +// Instead of creating separate pre_excl / suf_excl arrays, we pass the +// inclusive prefix / suffix arrays and shift the index here. +extern "C" __global__ void batch_inverse_combine_ext3( + const uint64_t *prefix_incl, // 3 * n u64; prefix_incl[i] = prod(a[0..=i]) + const uint64_t *suffix_incl, // 3 * n u64; suffix_incl[i] = prod(a[i..n-1]) + const uint64_t *inv_total_ptr, // 3 u64 + uint64_t n, + uint64_t *inv_out) { // 3 * n u64 + uint64_t i = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n) return; + + ext3::Fe3 pre; + if (i == 0) { + pre = ext3::one(); + } else { + pre.a = prefix_incl[(i - 1) * 3 + 0]; + pre.b = prefix_incl[(i - 1) * 3 + 1]; + pre.c = prefix_incl[(i - 1) * 3 + 2]; + } + ext3::Fe3 suf; + if (i + 1 >= n) { + suf = ext3::one(); + } else { + suf.a = suffix_incl[(i + 1) * 3 + 0]; + suf.b = suffix_incl[(i + 1) * 3 + 1]; + suf.c = suffix_incl[(i + 1) * 3 + 2]; + } + ext3::Fe3 inv_tot = {inv_total_ptr[0], inv_total_ptr[1], inv_total_ptr[2]}; + + ext3::Fe3 r = ext3::mul(pre, suf); + r = ext3::mul(r, inv_tot); + + inv_out[i * 3 + 0] = r.a; + inv_out[i * 3 + 1] = r.b; + inv_out[i * 3 + 2] = r.c; +} diff --git a/crypto/math-cuda/src/deep.rs b/crypto/math-cuda/src/deep.rs new file mode 100644 index 000000000..0a9d28c8d --- /dev/null +++ b/crypto/math-cuda/src/deep.rs @@ -0,0 +1,217 @@ +//! R4 deep-composition polynomial evaluations on GPU. +//! +//! Mirrors `Self::compute_deep_composition_poly_evaluations` in +//! `crypto/stark/src/prover.rs`. Accepts the main/aux LDEs as device +//! handles (populated by the R1 fused path in `LDETraceTable`) and +//! takes every other tensor (composition parts LDE, OOD evals, +//! gammas, inv-denoms) from host. Returns a `Vec` of +//! `domain_size * 3` u64s, ext3 interleaved (ready to `transmute` to +//! `FieldElement` when the caller promises layout compatibility). + +use cudarc::driver::{LaunchConfig, PushKernelArg}; + +use crate::Result; +use crate::device::backend; +use crate::lde::{GpuLdeBase, GpuLdeExt3}; + +/// Compute deep-composition evaluations on device. +/// +/// `num_eval_points = trace_terms_gammas_interleaved.len() / ((num_main + +/// num_aux) * 3)`. The caller is responsible for packing each Vec +/// into interleaved u64 slices (`[a0, a1, a2, b0, b1, b2, ...]`). +#[allow(clippy::too_many_arguments)] +pub fn deep_composition_ext3( + main_lde: &GpuLdeBase, + aux_lde: Option<&GpuLdeExt3>, + // Host-side inputs (H2D'd internally) + h_parts_deinterleaved: &[u64], // num_parts * 3 * lde_stride u64 + h_ood: &[u64], // num_parts * 3 + trace_ood: &[u64], // num_total_cols * num_eval_points * 3 + gammas_h: &[u64], // num_parts * 3 + gammas_tr: &[u64], // num_total_cols * num_eval_points * 3 + inv_h: &[u64], // domain_size * 3 + inv_t: &[u64], // num_eval_points * domain_size * 3 + // Shape params + num_parts: usize, + num_main: usize, + num_aux: usize, + num_eval_points: usize, + blowup_factor: usize, + domain_size: usize, +) -> Result> { + deep_composition_ext3_impl( + main_lde, + aux_lde, + None, + h_parts_deinterleaved, + h_ood, + trace_ood, + gammas_h, + gammas_tr, + inv_h, + inv_t, + num_parts, + num_main, + num_aux, + num_eval_points, + blowup_factor, + domain_size, + ) +} + +/// Same as [`deep_composition_ext3`] but reads the composition-parts LDE +/// from a device handle (`GpuLdeExt3`) populated by the R2 fused path, +/// skipping the `num_parts * 3 * lde_size * 8` byte H2D of +/// `h_parts_deinterleaved`. +#[allow(clippy::too_many_arguments)] +pub fn deep_composition_ext3_with_dev_parts( + main_lde: &GpuLdeBase, + aux_lde: Option<&GpuLdeExt3>, + h_parts_dev: &GpuLdeExt3, + h_ood: &[u64], + trace_ood: &[u64], + gammas_h: &[u64], + gammas_tr: &[u64], + inv_h: &[u64], + inv_t: &[u64], + num_parts: usize, + num_main: usize, + num_aux: usize, + num_eval_points: usize, + blowup_factor: usize, + domain_size: usize, +) -> Result> { + deep_composition_ext3_impl( + main_lde, + aux_lde, + Some(h_parts_dev), + &[], + h_ood, + trace_ood, + gammas_h, + gammas_tr, + inv_h, + inv_t, + num_parts, + num_main, + num_aux, + num_eval_points, + blowup_factor, + domain_size, + ) +} + +#[allow(clippy::too_many_arguments)] +fn deep_composition_ext3_impl( + main_lde: &GpuLdeBase, + aux_lde: Option<&GpuLdeExt3>, + h_parts_dev: Option<&GpuLdeExt3>, + h_parts_host: &[u64], + h_ood: &[u64], + trace_ood: &[u64], + gammas_h: &[u64], + gammas_tr: &[u64], + inv_h: &[u64], + inv_t: &[u64], + num_parts: usize, + num_main: usize, + num_aux: usize, + num_eval_points: usize, + blowup_factor: usize, + domain_size: usize, +) -> Result> { + assert_eq!(main_lde.m, num_main); + if let Some(a) = aux_lde { + assert_eq!(a.m, num_aux); + assert_eq!(a.lde_size, main_lde.lde_size); + } else { + assert_eq!(num_aux, 0); + } + if let Some(h) = h_parts_dev { + assert_eq!(h.m, num_parts); + assert_eq!(h.lde_size, main_lde.lde_size); + } else { + assert_eq!(h_parts_host.len(), num_parts * 3 * main_lde.lde_size); + } + assert_eq!(h_ood.len(), num_parts * 3); + let num_total_cols = num_main + num_aux; + assert_eq!(trace_ood.len(), num_total_cols * num_eval_points * 3); + assert_eq!(gammas_h.len(), num_parts * 3); + assert_eq!(gammas_tr.len(), num_total_cols * num_eval_points * 3); + assert_eq!(inv_h.len(), domain_size * 3); + assert_eq!(inv_t.len(), num_eval_points * domain_size * 3); + + let be = backend()?; + let stream = be.next_stream(); + + // H2D only the scalar arrays. h_parts comes from a device handle + // when available. + let h_ood_dev = stream.clone_htod(h_ood)?; + let trace_ood_dev = stream.clone_htod(trace_ood)?; + let gammas_h_dev = stream.clone_htod(gammas_h)?; + let gammas_tr_dev = stream.clone_htod(gammas_tr)?; + let inv_h_dev = stream.clone_htod(inv_h)?; + let inv_t_dev = stream.clone_htod(inv_t)?; + + // Keep the owned H2D of h_lde alive until kernel completes. Only + // populated in the host-parts path. + let h_lde_host_dev; + + let mut deep_out = stream.alloc_zeros::(domain_size * 3)?; + + let dummy_aux; + let aux_slice = if let Some(a) = aux_lde { + a.buf.as_ref() + } else { + dummy_aux = stream.alloc_zeros::(1)?; + &dummy_aux + }; + + let h_lde_slice = if let Some(h) = h_parts_dev { + h.buf.as_ref() + } else { + h_lde_host_dev = stream.clone_htod(h_parts_host)?; + &h_lde_host_dev + }; + + let lde_stride = main_lde.lde_size as u64; + let num_main_u = num_main as u64; + let num_aux_u = num_aux as u64; + let num_parts_u = num_parts as u64; + let num_eval_points_u = num_eval_points as u64; + let blowup_u = blowup_factor as u64; + let domain_size_u = domain_size as u64; + + let grid = (domain_size as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.deep_composition_ext3_row) + .arg(main_lde.buf.as_ref()) + .arg(aux_slice) + .arg(h_lde_slice) + .arg(&lde_stride) + .arg(&num_main_u) + .arg(&num_aux_u) + .arg(&num_parts_u) + .arg(&num_eval_points_u) + .arg(&blowup_u) + .arg(&domain_size_u) + .arg(&h_ood_dev) + .arg(&trace_ood_dev) + .arg(&gammas_h_dev) + .arg(&gammas_tr_dev) + .arg(&inv_h_dev) + .arg(&inv_t_dev) + .arg(&mut deep_out) + .launch(cfg)?; + } + + let out = stream.clone_dtoh(&deep_out)?; + stream.synchronize()?; + Ok(out) +} diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index 353932ba6..41bf973ae 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -94,6 +94,10 @@ const ARITH_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/arith.ptx")); const NTT_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/ntt.ptx")); const KECCAK_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/keccak.ptx")); const BARY_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/barycentric.ptx")); +const INVERSE_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/inverse.ptx")); +const DEEP_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/deep.ptx")); +const FRI_PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/fri.ptx")); + /// Number of CUDA streams in the pool. Larger pools let many rayon-parallel /// callers overlap on the GPU without serializing on stream ownership. The /// default stream is deliberately excluded because it synchronises with all @@ -150,6 +154,23 @@ pub struct Backend { pub barycentric_base_batched_strided: CudaFunction, pub barycentric_ext3_batched_strided: CudaFunction, + // inverse.ptx + pub compute_denoms_ext3: CudaFunction, + pub chunk_prefix_scan_ext3: CudaFunction, + pub exclusive_scan_of_totals_ext3: CudaFunction, + pub apply_scan_offsets_ext3: CudaFunction, + pub chunk_suffix_scan_ext3: CudaFunction, + pub exclusive_reverse_scan_of_totals_ext3: CudaFunction, + pub apply_reverse_scan_offsets_ext3: CudaFunction, + pub batch_inverse_combine_ext3: CudaFunction, + + // deep.ptx + pub deep_composition_ext3_row: CudaFunction, + + // fri.ptx + pub fri_fold_ext3: CudaFunction, + pub fri_update_twiddles: CudaFunction, + // Twiddle caches keyed by log_n. fwd_twiddles: Mutex>>>>, inv_twiddles: Mutex>>>>, @@ -168,6 +189,9 @@ impl Backend { let ntt = ctx.load_module(Ptx::from_src(NTT_PTX))?; let keccak = ctx.load_module(Ptx::from_src(KECCAK_PTX))?; let bary = ctx.load_module(Ptx::from_src(BARY_PTX))?; + let inverse = ctx.load_module(Ptx::from_src(INVERSE_PTX))?; + let deep = ctx.load_module(Ptx::from_src(DEEP_PTX))?; + let fri = ctx.load_module(Ptx::from_src(FRI_PTX))?; let mut streams = Vec::with_capacity(STREAM_POOL_SIZE); for _ in 0..STREAM_POOL_SIZE { @@ -226,6 +250,20 @@ impl Backend { .load_function("barycentric_base_batched_strided")?, barycentric_ext3_batched_strided: bary .load_function("barycentric_ext3_batched_strided")?, + compute_denoms_ext3: inverse.load_function("compute_denoms_ext3")?, + chunk_prefix_scan_ext3: inverse.load_function("chunk_prefix_scan_ext3")?, + exclusive_scan_of_totals_ext3: inverse + .load_function("exclusive_scan_of_totals_ext3")?, + apply_scan_offsets_ext3: inverse.load_function("apply_scan_offsets_ext3")?, + chunk_suffix_scan_ext3: inverse.load_function("chunk_suffix_scan_ext3")?, + exclusive_reverse_scan_of_totals_ext3: inverse + .load_function("exclusive_reverse_scan_of_totals_ext3")?, + apply_reverse_scan_offsets_ext3: inverse + .load_function("apply_reverse_scan_offsets_ext3")?, + batch_inverse_combine_ext3: inverse.load_function("batch_inverse_combine_ext3")?, + deep_composition_ext3_row: deep.load_function("deep_composition_ext3_row")?, + fri_fold_ext3: fri.load_function("fri_fold_ext3")?, + fri_update_twiddles: fri.load_function("fri_update_twiddles")?, fwd_twiddles: Mutex::new(vec![None; max_log]), inv_twiddles: Mutex::new(vec![None; max_log]), ctx, diff --git a/crypto/math-cuda/src/fri.rs b/crypto/math-cuda/src/fri.rs new file mode 100644 index 000000000..4dc88868a --- /dev/null +++ b/crypto/math-cuda/src/fri.rs @@ -0,0 +1,277 @@ +//! Fully-device-resident FRI commit phase orchestration. +//! +//! The host loop (in the stark crate) samples each layer's `zeta` from the +//! transcript and feeds it in; this module keeps the folded evaluations, +//! twiddles, and per-layer Merkle trees on device, only D2H'ing each +//! layer's root (to append to the transcript), plus its full evals and +//! tree nodes (to plug into `FriLayer` for the query phase). +//! +//! Mirrors `commit_phase_from_evaluations` at +//! `crypto/stark/src/fri/mod.rs`. + +use cudarc::driver::{CudaSlice, CudaStream, LaunchConfig, PushKernelArg}; +use std::sync::Arc; + +use crate::Result; +use crate::device::backend; + +/// Device-side state across FRI commit iterations. Owns two ext3 eval +/// buffers (flip-flopped as layer input / output) and the inv_twiddles +/// buffer. Freed when dropped. +pub struct FriCommitState { + pub stream: Arc, + // Ping-pong evaluation buffers. Both sized `3 * n0` u64 at init; each + // successive fold uses half the space. Cheap to pre-allocate vs. per- + // layer alloc. + evals_a: CudaSlice, + evals_b: CudaSlice, + /// Base-field inv_twiddles; `n0 / 2` u64 at init, halved each layer. + inv_tw: CudaSlice, + /// Number of ext3 elements currently in the "input" buffer. + pub current_n: usize, + /// Which buffer holds the current layer's input. Toggles each fold. + a_is_input: bool, +} + +impl FriCommitState { + /// H2D the starting evals (ext3 interleaved, 3 * n0 u64) and the + /// initial inv_twiddles (base field, n0/2 u64). `n0` must be a power of + /// two and >= 2. + pub fn new(evals_host: &[u64], inv_tw_host: &[u64], n0: usize) -> Result { + assert!(n0 >= 2 && n0.is_power_of_two()); + assert_eq!(evals_host.len(), 3 * n0); + assert_eq!(inv_tw_host.len(), n0 / 2); + + let be = backend()?; + let stream = be.next_stream(); + + // SAFETY: every byte of evals_a is overwritten by the H2D below. + // evals_b is written by the first fold before it is read. + let mut evals_a = unsafe { stream.alloc::(3 * n0) }?; + let evals_b = unsafe { stream.alloc::(3 * n0) }?; + stream.memcpy_htod(evals_host, &mut evals_a)?; + let inv_tw = stream.clone_htod(inv_tw_host)?; + + Ok(Self { + stream, + evals_a, + evals_b, + inv_tw, + current_n: n0, + a_is_input: true, + }) + } + + /// Fold the current layer using `zeta`, run the row-pair Keccak leaves + /// + pair-hash Merkle tree kernels on the result, and D2H: + /// - the new root (32 bytes) + /// - the new layer's evals (3 * (current_n / 2) u64s) + /// - the new layer's Merkle tree nodes (standard layout, byte-packed) + /// + /// Also updates `inv_twiddles` in place to shrink for the next layer. + pub fn fold_and_commit_layer( + &mut self, + zeta_raw: [u64; 3], + ) -> Result<(Vec, Vec, Vec)> { + let be = backend()?; + let n_in = self.current_n; + let n_out = n_in / 2; + // fold_final handles the n_out == 1 last layer (no Merkle commit). + assert!( + n_out >= 2, + "fold_and_commit_layer requires n_out >= 2; use fold_final" + ); + + // Row-pair leaves: each leaf hashes two consecutive ext3 evals. + let num_leaves = n_out / 2; + let tight_total_nodes = 2 * num_leaves - 1; + + // H2D zeta. + let zeta_dev = self.stream.clone_htod(&zeta_raw)?; + + let cfg = LaunchConfig { + grid_dim: ((n_out as u32).div_ceil(128), 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + let n_out_u64 = n_out as u64; + + if self.a_is_input { + unsafe { + self.stream + .launch_builder(&be.fri_fold_ext3) + .arg(&self.evals_a) + .arg(&n_out_u64) + .arg(&self.inv_tw) + .arg(&zeta_dev) + .arg(&mut self.evals_b) + .launch(cfg)?; + } + } else { + unsafe { + self.stream + .launch_builder(&be.fri_fold_ext3) + .arg(&self.evals_b) + .arg(&n_out_u64) + .arg(&self.inv_tw) + .arg(&zeta_dev) + .arg(&mut self.evals_a) + .launch(cfg)?; + } + } + + // Keccak leaves + pair-hash tree into fresh device buffer. + let mut nodes_dev = unsafe { self.stream.alloc::(tight_total_nodes * 32) }?; + let leaves_offset_bytes = (num_leaves - 1) * 32; + { + let mut leaves_view = + nodes_dev.slice_mut(leaves_offset_bytes..leaves_offset_bytes + num_leaves * 32); + let num_leaves_u64 = num_leaves as u64; + let grid = (num_leaves as u32).div_ceil(128); + let kcfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + // Leaves read from the layer's OUTPUT eval buffer. + if self.a_is_input { + unsafe { + self.stream + .launch_builder(&be.keccak_fri_leaves_ext3) + .arg(&self.evals_b) + .arg(&num_leaves_u64) + .arg(&mut leaves_view) + .launch(kcfg)?; + } + } else { + unsafe { + self.stream + .launch_builder(&be.keccak_fri_leaves_ext3) + .arg(&self.evals_a) + .arg(&num_leaves_u64) + .arg(&mut leaves_view) + .launch(kcfg)?; + } + } + } + { + let mut level_begin: u64 = (num_leaves - 1) as u64; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + let grid = (n_pairs as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + self.stream + .launch_builder(&be.keccak_merkle_level) + .arg(&mut nodes_dev) + .arg(&new_begin) + .arg(&n_pairs) + .launch(cfg)?; + } + level_begin = new_begin; + } + } + + // Update inv_twiddles for the next layer: `new[j] = old[2j]^2` for + // j in 0..n_out/2. (If n_out == 1, skip; no next fold.) + let tw_next = n_out / 2; + if tw_next > 0 { + let grid = (tw_next as u32).div_ceil(128); + let cfg = LaunchConfig { + grid_dim: (grid, 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + let tw_next_u64 = tw_next as u64; + unsafe { + self.stream + .launch_builder(&be.fri_update_twiddles) + .arg(&mut self.inv_tw) + .arg(&tw_next_u64) + .launch(cfg)?; + } + } + + // Sync and D2H. + self.stream.synchronize()?; + + // Layer evals: 3 * n_out u64 from the output buffer. + let layer_evals: Vec = if self.a_is_input { + let view = self.evals_b.slice(0..3 * n_out); + self.stream.clone_dtoh(&view)? + } else { + let view = self.evals_a.slice(0..3 * n_out); + self.stream.clone_dtoh(&view)? + }; + + // Tree nodes. + let nodes_bytes: Vec = self.stream.clone_dtoh(&nodes_dev)?; + debug_assert_eq!(nodes_bytes.len(), tight_total_nodes * 32); + + let mut root = vec![0u8; 32]; + root.copy_from_slice(&nodes_bytes[0..32]); + + self.a_is_input = !self.a_is_input; + self.current_n = n_out; + + Ok((root, layer_evals, nodes_bytes)) + } + + /// Final fold, no Merkle commit. Returns the single ext3 output + /// element (the FRI last_value). + pub fn fold_final(&mut self, zeta_raw: [u64; 3]) -> Result<[u64; 3]> { + let be = backend()?; + let n_in = self.current_n; + let n_out = n_in / 2; + assert!(n_out >= 1); + + let zeta_dev = self.stream.clone_htod(&zeta_raw)?; + let cfg = LaunchConfig { + grid_dim: ((n_out as u32).div_ceil(128), 1, 1), + block_dim: (128, 1, 1), + shared_mem_bytes: 0, + }; + let n_out_u64 = n_out as u64; + + if self.a_is_input { + unsafe { + self.stream + .launch_builder(&be.fri_fold_ext3) + .arg(&self.evals_a) + .arg(&n_out_u64) + .arg(&self.inv_tw) + .arg(&zeta_dev) + .arg(&mut self.evals_b) + .launch(cfg)?; + } + } else { + unsafe { + self.stream + .launch_builder(&be.fri_fold_ext3) + .arg(&self.evals_b) + .arg(&n_out_u64) + .arg(&self.inv_tw) + .arg(&zeta_dev) + .arg(&mut self.evals_a) + .launch(cfg)?; + } + } + + self.stream.synchronize()?; + let out_first: Vec = if self.a_is_input { + let view = self.evals_b.slice(0..3); + self.stream.clone_dtoh(&view)? + } else { + let view = self.evals_a.slice(0..3); + self.stream.clone_dtoh(&view)? + }; + self.a_is_input = !self.a_is_input; + self.current_n = n_out; + Ok([out_first[0], out_first[1], out_first[2]]) + } +} diff --git a/crypto/math-cuda/src/inverse.rs b/crypto/math-cuda/src/inverse.rs new file mode 100644 index 000000000..bec38f940 --- /dev/null +++ b/crypto/math-cuda/src/inverse.rs @@ -0,0 +1,421 @@ +//! Parallel Montgomery batch inverse on the GPU for ext3 elements, plus +//! the R3 OOD / R4 DEEP `compute-denoms + invert` convenience fn. + +use cudarc::driver::{CudaSlice, LaunchConfig, PushKernelArg}; + +use crate::Result; +use crate::device::backend; + +const SCAN_THREADS: u32 = 256; +const COMBINE_BLOCK: u32 = 256; + +/// Parallel batch inverse over ext3 elements. `a` is 3 * n u64s +/// (interleaved). Returns a fresh Vec with 3 * n inverses. +/// +/// Mirrors `FieldElement::inplace_batch_inverse` semantically; parity +/// is gated by the prove+verify round-trip in the stark test suite. +pub fn batch_inverse_ext3(a: &[u64]) -> Result> { + assert!(a.len().is_multiple_of(3)); + let n = a.len() / 3; + if n == 0 { + return Ok(Vec::new()); + } + if n == 1 { + // Below GPU break-even (one element). Invert on host via Fermat. + let inv = invert_ext3_host([a[0], a[1], a[2]]); + return Ok(inv.to_vec()); + } + + let be = backend()?; + let stream = be.next_stream(); + + // H2D input. + let a_dev = stream.clone_htod(a)?; + + // Scratch buffers. + let mut prefix_dev = stream.alloc_zeros::(n * 3)?; + let mut suffix_dev = stream.alloc_zeros::(n * 3)?; + + // Chunk sizing: SCAN_THREADS threads, one chunk per thread. + let k: u32 = SCAN_THREADS; + let c_per_thread: u64 = (n as u64).div_ceil(k as u64); + let mut chunk_totals = stream.alloc_zeros::((k as usize) * 3)?; + let mut chunk_offsets = stream.alloc_zeros::((k as usize) * 3)?; + let n_u64 = n as u64; + let k_u64 = k as u64; + + // Phase 1: chunk prefix scan. + let cfg_scan = LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (k, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.chunk_prefix_scan_ext3) + .arg(&a_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&mut prefix_dev) + .arg(&mut chunk_totals) + .launch(cfg_scan)?; + } + + // Phase 2: exclusive scan of chunk totals (single thread). + unsafe { + stream + .launch_builder(&be.exclusive_scan_of_totals_ext3) + .arg(&chunk_totals) + .arg(&k_u64) + .arg(&mut chunk_offsets) + .launch(LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (1, 1, 1), + shared_mem_bytes: 0, + })?; + } + + // Phase 3: apply offsets. + unsafe { + stream + .launch_builder(&be.apply_scan_offsets_ext3) + .arg(&mut prefix_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&chunk_offsets) + .launch(cfg_scan)?; + } + + // Mirror for suffix. + let mut suffix_chunk_totals = stream.alloc_zeros::((k as usize) * 3)?; + let mut suffix_chunk_offsets = stream.alloc_zeros::((k as usize) * 3)?; + unsafe { + stream + .launch_builder(&be.chunk_suffix_scan_ext3) + .arg(&a_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&mut suffix_dev) + .arg(&mut suffix_chunk_totals) + .launch(cfg_scan)?; + } + unsafe { + stream + .launch_builder(&be.exclusive_reverse_scan_of_totals_ext3) + .arg(&suffix_chunk_totals) + .arg(&k_u64) + .arg(&mut suffix_chunk_offsets) + .launch(LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (1, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.apply_reverse_scan_offsets_ext3) + .arg(&mut suffix_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&suffix_chunk_offsets) + .launch(cfg_scan)?; + } + + // Compute total = prefix[n-1], invert on host. + let total = { + let last_view = prefix_dev.slice((n - 1) * 3..n * 3); + let last_host: Vec = stream.clone_dtoh(&last_view)?; + stream.synchronize()?; + invert_ext3_host([last_host[0], last_host[1], last_host[2]]) + }; + let mut inv_total_dev = stream.alloc_zeros::(3)?; + stream.memcpy_htod(&total, &mut inv_total_dev)?; + + // Combine. + let mut out_dev = stream.alloc_zeros::(n * 3)?; + let cfg_combine = LaunchConfig { + grid_dim: ((n as u32).div_ceil(COMBINE_BLOCK), 1, 1), + block_dim: (COMBINE_BLOCK, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.batch_inverse_combine_ext3) + .arg(&prefix_dev) + .arg(&suffix_dev) + .arg(&inv_total_dev) + .arg(&n_u64) + .arg(&mut out_dev) + .launch(cfg_combine)?; + } + + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Same as [`batch_inverse_ext3`] but the input is already on device +/// (typically from `compute_denoms_ext3`). Avoids one H2D round-trip. +pub fn batch_inverse_ext3_dev(a_dev: &CudaSlice, n: usize) -> Result> { + if n == 0 { + return Ok(Vec::new()); + } + let be = backend()?; + let stream = be.next_stream(); + + let mut prefix_dev = stream.alloc_zeros::(n * 3)?; + let mut suffix_dev = stream.alloc_zeros::(n * 3)?; + + let k: u32 = SCAN_THREADS; + let c_per_thread: u64 = (n as u64).div_ceil(k as u64); + let mut chunk_totals = stream.alloc_zeros::((k as usize) * 3)?; + let mut chunk_offsets = stream.alloc_zeros::((k as usize) * 3)?; + let n_u64 = n as u64; + let k_u64 = k as u64; + + let cfg_scan = LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (k, 1, 1), + shared_mem_bytes: 0, + }; + + unsafe { + stream + .launch_builder(&be.chunk_prefix_scan_ext3) + .arg(a_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&mut prefix_dev) + .arg(&mut chunk_totals) + .launch(cfg_scan)?; + } + unsafe { + stream + .launch_builder(&be.exclusive_scan_of_totals_ext3) + .arg(&chunk_totals) + .arg(&k_u64) + .arg(&mut chunk_offsets) + .launch(LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (1, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.apply_scan_offsets_ext3) + .arg(&mut prefix_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&chunk_offsets) + .launch(cfg_scan)?; + } + + let mut suffix_chunk_totals = stream.alloc_zeros::((k as usize) * 3)?; + let mut suffix_chunk_offsets = stream.alloc_zeros::((k as usize) * 3)?; + unsafe { + stream + .launch_builder(&be.chunk_suffix_scan_ext3) + .arg(a_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&mut suffix_dev) + .arg(&mut suffix_chunk_totals) + .launch(cfg_scan)?; + } + unsafe { + stream + .launch_builder(&be.exclusive_reverse_scan_of_totals_ext3) + .arg(&suffix_chunk_totals) + .arg(&k_u64) + .arg(&mut suffix_chunk_offsets) + .launch(LaunchConfig { + grid_dim: (1, 1, 1), + block_dim: (1, 1, 1), + shared_mem_bytes: 0, + })?; + } + unsafe { + stream + .launch_builder(&be.apply_reverse_scan_offsets_ext3) + .arg(&mut suffix_dev) + .arg(&n_u64) + .arg(&c_per_thread) + .arg(&suffix_chunk_offsets) + .launch(cfg_scan)?; + } + + let total = { + let last_view = prefix_dev.slice((n - 1) * 3..n * 3); + let last_host: Vec = stream.clone_dtoh(&last_view)?; + stream.synchronize()?; + invert_ext3_host([last_host[0], last_host[1], last_host[2]]) + }; + let mut inv_total_dev = stream.alloc_zeros::(3)?; + stream.memcpy_htod(&total, &mut inv_total_dev)?; + + let mut out_dev = stream.alloc_zeros::(n * 3)?; + let cfg_combine = LaunchConfig { + grid_dim: ((n as u32).div_ceil(COMBINE_BLOCK), 1, 1), + block_dim: (COMBINE_BLOCK, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.batch_inverse_combine_ext3) + .arg(&prefix_dev) + .arg(&suffix_dev) + .arg(&inv_total_dev) + .arg(&n_u64) + .arg(&mut out_dev) + .launch(cfg_combine)?; + } + + let out = stream.clone_dtoh(&out_dev)?; + stream.synchronize()?; + Ok(out) +} + +/// Compute `denoms[k*n + i] = x[i * stride] - z_scalars[k]` for all i, k, +/// then batch-invert in place. Fuses B.1 + B.2 to avoid an intermediate +/// D2H + H2D of the denominator array. +/// +/// `x_base` is the LDE coset (base-field, at least `n * stride` u64s). +/// `z_scalars` is `k * 3` u64s (ext3 interleaved). Returns `k * n * 3` +/// u64s (the inverted denoms), flat in k-major then i-major order. +pub fn compute_and_invert_denoms_ext3( + x_base: &[u64], + stride: usize, + z_scalars: &[u64], + k_scalars: usize, + n: usize, +) -> Result> { + assert!(x_base.len() >= n * stride); + assert_eq!(z_scalars.len(), k_scalars * 3); + let total = k_scalars * n; + + let be = backend()?; + let stream = be.next_stream(); + + let x_dev = stream.clone_htod(&x_base[..n * stride])?; + let z_dev = stream.clone_htod(z_scalars)?; + let mut denoms_dev = stream.alloc_zeros::(total * 3)?; + + let stride_u64 = stride as u64; + let n_u64 = n as u64; + let k_u64 = k_scalars as u64; + + // Compute denoms. + let cfg = LaunchConfig { + grid_dim: ((total as u32).div_ceil(256), 1, 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.compute_denoms_ext3) + .arg(&x_dev) + .arg(&stride_u64) + .arg(&z_dev) + .arg(&k_u64) + .arg(&n_u64) + .arg(&mut denoms_dev) + .launch(cfg)?; + } + stream.synchronize()?; + + // Batch-invert in place (reuses the device buffer). + batch_inverse_ext3_dev(&denoms_dev, total) +} + +// ============================================================================= +// Host-side ext3 inverse (used once, for the total of the GPU prefix product). +// ============================================================================= + +const GOLDILOCKS_P: u128 = (1u128 << 64) - (1u128 << 32) + 1; + +fn gl_mul(a: u64, b: u64) -> u64 { + let prod = (a as u128) * (b as u128); + (prod % GOLDILOCKS_P) as u64 +} + +fn gl_add(a: u64, b: u64) -> u64 { + let s = (a as u128) + (b as u128); + (s % GOLDILOCKS_P) as u64 +} + +fn gl_sub(a: u64, b: u64) -> u64 { + let a128 = a as u128; + let b128 = b as u128; + if a128 >= b128 { + ((a128 - b128) % GOLDILOCKS_P) as u64 + } else { + (((GOLDILOCKS_P - b128) + a128) % GOLDILOCKS_P) as u64 + } +} + +fn gl_pow(mut base: u64, mut exp: u64) -> u64 { + let mut acc: u64 = 1; + while exp != 0 { + if exp & 1 != 0 { + acc = gl_mul(acc, base); + } + base = gl_mul(base, base); + exp >>= 1; + } + acc +} + +fn gl_inv(a: u64) -> u64 { + // Fermat: a^{p-2} + gl_pow(a, GOLDILOCKS_P as u64 - 2) +} + +/// Invert one ext3 element on the host. Used once per batch inverse to +/// invert the total product; the main batch inverse work stays on GPU. +fn invert_ext3_host(x: [u64; 3]) -> [u64; 3] { + // x = a + b*w + c*w^2 where w^3 = 2. + // Compute x^{-1} using the extension field's norm: + // norm(x) = x * x_conj1 * x_conj2 (where conjugates are Frobenius images) + // For Fp[w]/(w^3-2) over Fp, the norm lives in Fp. + // + // Simpler: do the full ext3 multiplication inverse via + // classical adjugate over Fp[w]. + // + // Use the closed-form adjugate for degree-3 extension: + // Let x = (a, b, c) representing a + b*w + c*w^2 + // Then x^{-1} = (d, e, f) / N + // where (Newton's identities / cofactor method): + // d = a^2 - 2*b*c + // e = 2*c^2 - a*b + // f = b^2 - a*c + // N = a*d + 2*b*f + 2*c*e + // + // (This matches the cpu `Degree3GoldilocksExtensionField::inv`.) + let a = x[0]; + let b = x[1]; + let c = x[2]; + + let bc = gl_mul(b, c); + let d = gl_sub(gl_mul(a, a), gl_add(bc, bc)); // a^2 - 2bc + let cc = gl_mul(c, c); + let ab = gl_mul(a, b); + let e = gl_sub(gl_add(cc, cc), ab); // 2c^2 - ab + let bb = gl_mul(b, b); + let ac = gl_mul(a, c); + let f = gl_sub(bb, ac); // b^2 - ac + + let ad = gl_mul(a, d); + let bf = gl_mul(b, f); + let ce = gl_mul(c, e); + let two_bf = gl_add(bf, bf); + let two_ce = gl_add(ce, ce); + let norm = gl_add(ad, gl_add(two_bf, two_ce)); + + let inv_norm = gl_inv(norm); + [ + gl_mul(d, inv_norm), + gl_mul(e, inv_norm), + gl_mul(f, inv_norm), + ] +} diff --git a/crypto/math-cuda/src/lib.rs b/crypto/math-cuda/src/lib.rs index a5bb8defb..a7a6ca583 100644 --- a/crypto/math-cuda/src/lib.rs +++ b/crypto/math-cuda/src/lib.rs @@ -5,7 +5,10 @@ //! parity test suite. pub mod barycentric; +pub mod deep; pub mod device; +pub mod fri; +pub mod inverse; pub mod lde; pub mod merkle; pub mod ntt; diff --git a/crypto/math-cuda/tests/batch_inverse.rs b/crypto/math-cuda/tests/batch_inverse.rs new file mode 100644 index 000000000..5edaea876 --- /dev/null +++ b/crypto/math-cuda/tests/batch_inverse.rs @@ -0,0 +1,93 @@ +//! Parity: GPU parallel batch inverse matches CPU +//! `FieldElement::inplace_batch_inverse` on ext3 elements. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsPrimeField; +use math_cuda::inverse::batch_inverse_ext3; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn rand_fp(rng: &mut ChaCha8Rng) -> Fp { + loop { + let v = rng.r#gen::(); + // Avoid zero. Batch inverse requires all non-zero. + if v != 0 { + return Fp::from_raw(v); + } + } +} +fn rand_fp3_nonzero(rng: &mut ChaCha8Rng) -> Fp3 { + // Random non-zero ext3: at least one component non-zero, all in [1, p). + Fp3::new([rand_fp(rng), rand_fp(rng), rand_fp(rng)]) +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +fn canon3(a: &[u64]) -> Vec { + a.iter() + .enumerate() + .map(|(i, v)| { + // Each u64 is canonicalised independently (ext3 = 3 base coords). + let _ = i; + GoldilocksField::canonical(v) + }) + .collect() +} + +fn run(n: usize, seed: u64) { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let xs: Vec = (0..n).map(|_| rand_fp3_nonzero(&mut rng)).collect(); + + // CPU reference: inplace_batch_inverse. + let mut cpu = xs.clone(); + FieldElement::inplace_batch_inverse(&mut cpu).expect("batch inverse non-zero"); + + // GPU. + let input_u64 = ext3_to_u64s(&xs); + let gpu_u64 = batch_inverse_ext3(&input_u64).unwrap(); + + let cpu_u64 = ext3_to_u64s(&cpu); + let gpu_canon = canon3(&gpu_u64); + let cpu_canon = canon3(&cpu_u64); + + for i in 0..n { + let g = &gpu_canon[i * 3..(i + 1) * 3]; + let c = &cpu_canon[i * 3..(i + 1) * 3]; + assert_eq!(g, c, "mismatch at i={i} n={n}"); + } +} + +#[test] +fn batch_inverse_small() { + for n in [2usize, 3, 5, 16, 63, 255, 256, 257] { + run(n, 100 + n as u64); + } +} + +#[test] +fn batch_inverse_medium() { + for n in [1024usize, 4096, 8192] { + run(n, 500 + n as u64); + } +} + +#[test] +fn batch_inverse_large() { + // Matches R3 OOD / R4 DEEP sizes for fib_1M (domain_size = 2^18, + // num_denoms_max = 2^18 * 4). + run(1 << 18, 999); + run(1 << 20, 12345); +} diff --git a/crypto/math-cuda/tests/deep.rs b/crypto/math-cuda/tests/deep.rs new file mode 100644 index 000000000..1c5609261 --- /dev/null +++ b/crypto/math-cuda/tests/deep.rs @@ -0,0 +1,303 @@ +//! Parity: GPU deep_composition_ext3 vs a direct CPU port of the same +//! row-wise summation. Uses random inputs, not the full stark LDE path. + +use std::sync::Arc; + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::IsPrimeField; +use math_cuda::deep::deep_composition_ext3; +use math_cuda::device::backend; +use math_cuda::lde::{GpuLdeBase, GpuLdeExt3}; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn rand_fp(rng: &mut ChaCha8Rng) -> Fp { + Fp::from_raw(rng.r#gen::()) +} +fn rand_fp3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([rand_fp(rng), rand_fp(rng), rand_fp(rng)]) +} + +fn ext3_to_raw(e: &Fp3) -> [u64; 3] { + [ + *e.value()[0].value(), + *e.value()[1].value(), + *e.value()[2].value(), + ] +} + +fn canon3(e: &Fp3) -> [u64; 3] { + [ + GoldilocksField::canonical(e.value()[0].value()), + GoldilocksField::canonical(e.value()[1].value()), + GoldilocksField::canonical(e.value()[2].value()), + ] +} + +/// CPU reference: exact port of `compute_deep_composition_poly_evaluations`. +#[allow(clippy::too_many_arguments)] +fn cpu_deep( + main_lde: &[Vec], // num_main cols * lde_size + aux_lde: &[Vec], // num_aux cols * lde_size + h_lde: &[Vec], // num_parts * lde_size + h_ood: &[Fp3], // num_parts + trace_ood: &[Vec], // num_total_cols * num_eval_points + gammas_h: &[Fp3], // num_parts + gammas_tr: &[Vec], // num_total_cols * num_eval_points + inv_h: &[Fp3], // domain_size + inv_t: &[Vec], // num_eval_points * domain_size + blowup_factor: usize, + domain_size: usize, +) -> Vec { + let num_parts = h_lde.len(); + let num_main = main_lde.len(); + let num_aux = aux_lde.len(); + let num_eval_points = if trace_ood.is_empty() { + 0 + } else { + trace_ood[0].len() + }; + + (0..domain_size) + .map(|i| { + let row = i * blowup_factor; + let mut result = Fp3::zero(); + // H-terms + for j in 0..num_parts { + let num = &h_lde[j][row] - &h_ood[j]; + result += &gammas_h[j] * &num * &inv_h[i]; + } + // Main + for j in 0..num_main { + for k in 0..num_eval_points { + let t_val = &main_lde[j][row]; + let t_ood = &trace_ood[j][k]; + let num = t_val - t_ood; // base - ext3 = ext3 + result += &gammas_tr[j][k] * &num * &inv_t[k][i]; + } + } + // Aux + for (j, aux_col) in aux_lde.iter().enumerate().take(num_aux) { + let trace_j = num_main + j; + for k in 0..num_eval_points { + let t_val = &aux_col[row]; + let t_ood = &trace_ood[trace_j][k]; + let num = t_val - t_ood; + result += &gammas_tr[trace_j][k] * &num * &inv_t[k][i]; + } + } + result + }) + .collect() +} + +fn run_parity( + log_domain_size: u32, + blowup_factor: usize, + num_main: usize, + num_aux: usize, + num_parts: usize, + num_eval_points: usize, + seed: u64, +) { + let domain_size = 1usize << log_domain_size; + let lde_size = domain_size * blowup_factor; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + + let main_lde: Vec> = (0..num_main) + .map(|_| (0..lde_size).map(|_| rand_fp(&mut rng)).collect()) + .collect(); + let aux_lde: Vec> = (0..num_aux) + .map(|_| (0..lde_size).map(|_| rand_fp3(&mut rng)).collect()) + .collect(); + let h_lde: Vec> = (0..num_parts) + .map(|_| (0..lde_size).map(|_| rand_fp3(&mut rng)).collect()) + .collect(); + let h_ood: Vec = (0..num_parts).map(|_| rand_fp3(&mut rng)).collect(); + let num_total_cols = num_main + num_aux; + let trace_ood: Vec> = (0..num_total_cols) + .map(|_| (0..num_eval_points).map(|_| rand_fp3(&mut rng)).collect()) + .collect(); + let gammas_h: Vec = (0..num_parts).map(|_| rand_fp3(&mut rng)).collect(); + let gammas_tr: Vec> = (0..num_total_cols) + .map(|_| (0..num_eval_points).map(|_| rand_fp3(&mut rng)).collect()) + .collect(); + let inv_h: Vec = (0..domain_size).map(|_| rand_fp3(&mut rng)).collect(); + let inv_t: Vec> = (0..num_eval_points) + .map(|_| (0..domain_size).map(|_| rand_fp3(&mut rng)).collect()) + .collect(); + + // CPU reference. + let cpu_out = cpu_deep( + &main_lde, + &aux_lde, + &h_lde, + &h_ood, + &trace_ood, + &gammas_h, + &gammas_tr, + &inv_h, + &inv_t, + blowup_factor, + domain_size, + ); + + // GPU: upload main & aux LDEs into device buffers and wrap in handles. + let be = backend().unwrap(); + let stream = be.next_stream(); + + // main_lde to col-major u64: m * lde_size + let mut main_flat = vec![0u64; num_main * lde_size]; + for (c, col) in main_lde.iter().enumerate() { + for (r, v) in col.iter().enumerate() { + main_flat[c * lde_size + r] = *v.value(); + } + } + let main_dev = stream.clone_htod(&main_flat).unwrap(); + + // aux_lde to de-interleaved: (m*3) * lde_size + let mut aux_flat = vec![0u64; num_aux * 3 * lde_size]; + for (c, col) in aux_lde.iter().enumerate() { + for (r, v) in col.iter().enumerate() { + let [a, b, c0] = ext3_to_raw(v); + aux_flat[(c * 3) * lde_size + r] = a; + aux_flat[(c * 3 + 1) * lde_size + r] = b; + aux_flat[(c * 3 + 2) * lde_size + r] = c0; + } + } + let aux_dev = stream.clone_htod(&aux_flat).unwrap(); + stream.synchronize().unwrap(); + + let main_handle = GpuLdeBase { + buf: Arc::new(main_dev), + m: num_main, + lde_size, + }; + let aux_handle = if num_aux > 0 { + Some(GpuLdeExt3 { + buf: Arc::new(aux_dev), + m: num_aux, + lde_size, + }) + } else { + drop(aux_dev); + None + }; + + // h_parts to de-interleaved: num_parts*3 * lde_size + let mut h_flat = vec![0u64; num_parts * 3 * lde_size]; + for (p, col) in h_lde.iter().enumerate() { + for (r, v) in col.iter().enumerate() { + let [a, b, c0] = ext3_to_raw(v); + h_flat[(p * 3) * lde_size + r] = a; + h_flat[(p * 3 + 1) * lde_size + r] = b; + h_flat[(p * 3 + 2) * lde_size + r] = c0; + } + } + + let mut h_ood_flat = vec![0u64; num_parts * 3]; + for (j, e) in h_ood.iter().enumerate() { + let [a, b, c] = ext3_to_raw(e); + h_ood_flat[j * 3] = a; + h_ood_flat[j * 3 + 1] = b; + h_ood_flat[j * 3 + 2] = c; + } + let mut trace_ood_flat = vec![0u64; num_total_cols * num_eval_points * 3]; + for (j, col) in trace_ood.iter().enumerate() { + for (k, e) in col.iter().enumerate() { + let idx = (j * num_eval_points + k) * 3; + let [a, b, c] = ext3_to_raw(e); + trace_ood_flat[idx] = a; + trace_ood_flat[idx + 1] = b; + trace_ood_flat[idx + 2] = c; + } + } + let mut gammas_h_flat = vec![0u64; num_parts * 3]; + for (j, e) in gammas_h.iter().enumerate() { + let [a, b, c] = ext3_to_raw(e); + gammas_h_flat[j * 3] = a; + gammas_h_flat[j * 3 + 1] = b; + gammas_h_flat[j * 3 + 2] = c; + } + let mut gammas_tr_flat = vec![0u64; num_total_cols * num_eval_points * 3]; + for (j, col) in gammas_tr.iter().enumerate() { + for (k, e) in col.iter().enumerate() { + let idx = (j * num_eval_points + k) * 3; + let [a, b, c] = ext3_to_raw(e); + gammas_tr_flat[idx] = a; + gammas_tr_flat[idx + 1] = b; + gammas_tr_flat[idx + 2] = c; + } + } + let mut inv_h_flat = vec![0u64; domain_size * 3]; + for (i, e) in inv_h.iter().enumerate() { + let [a, b, c] = ext3_to_raw(e); + inv_h_flat[i * 3] = a; + inv_h_flat[i * 3 + 1] = b; + inv_h_flat[i * 3 + 2] = c; + } + let mut inv_t_flat = vec![0u64; num_eval_points * domain_size * 3]; + for (k, layer) in inv_t.iter().enumerate() { + for (i, e) in layer.iter().enumerate() { + let idx = (k * domain_size + i) * 3; + let [a, b, c] = ext3_to_raw(e); + inv_t_flat[idx] = a; + inv_t_flat[idx + 1] = b; + inv_t_flat[idx + 2] = c; + } + } + + let gpu_raw = deep_composition_ext3( + &main_handle, + aux_handle.as_ref(), + &h_flat, + &h_ood_flat, + &trace_ood_flat, + &gammas_h_flat, + &gammas_tr_flat, + &inv_h_flat, + &inv_t_flat, + num_parts, + num_main, + num_aux, + num_eval_points, + blowup_factor, + domain_size, + ) + .unwrap(); + + for i in 0..domain_size { + let gpu = [gpu_raw[i * 3], gpu_raw[i * 3 + 1], gpu_raw[i * 3 + 2]]; + let gpu_canon = [ + GoldilocksField::canonical(&gpu[0]), + GoldilocksField::canonical(&gpu[1]), + GoldilocksField::canonical(&gpu[2]), + ]; + let cpu_canon = canon3(&cpu_out[i]); + assert_eq!( + gpu_canon, cpu_canon, + "row {i} mismatch at log_ds={log_domain_size} main={num_main} aux={num_aux} parts={num_parts}" + ); + } +} + +#[test] +fn deep_parity_small() { + run_parity(4, 2, 3, 2, 2, 1, 100); + run_parity(6, 4, 5, 3, 2, 2, 200); +} + +#[test] +fn deep_parity_medium() { + run_parity(10, 2, 10, 5, 4, 3, 1000); +} + +#[test] +fn deep_parity_no_aux() { + run_parity(8, 2, 5, 0, 2, 2, 5000); +} diff --git a/crypto/math-cuda/tests/fri_layer_tree.rs b/crypto/math-cuda/tests/fri_layer_tree.rs new file mode 100644 index 000000000..229ea2bb6 --- /dev/null +++ b/crypto/math-cuda/tests/fri_layer_tree.rs @@ -0,0 +1,111 @@ +//! Parity: GPU `build_fri_layer_tree_from_evals_ext3` vs CPU +//! `FriLayerMerkleTree::build` (PairKeccak256 backend over ext3 pairs). + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::traits::ByteConversion; +use math_cuda::merkle::build_fri_layer_tree_from_evals_ext3; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use sha3::{Digest, Keccak256}; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn rand_ext3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([ + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + Fp::from_raw(rng.r#gen::()), + ]) +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +fn cpu_hash_pair_bytes(a: &Fp3, b: &Fp3) -> [u8; 32] { + let mut buf = [0u8; 48]; + a.write_bytes_be(&mut buf[0..24]); + b.write_bytes_be(&mut buf[24..48]); + let mut h = Keccak256::new(); + h.update(buf); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out +} + +fn cpu_hash_pair_nodes(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] { + let mut h = Keccak256::new(); + h.update(left); + h.update(right); + let mut out = [0u8; 32]; + out.copy_from_slice(&h.finalize()); + out +} + +fn cpu_fri_layer_nodes(evals: &[Fp3]) -> Vec<[u8; 32]> { + let num_leaves = evals.len() / 2; + assert!(num_leaves.is_power_of_two() && num_leaves >= 1); + let total = 2 * num_leaves - 1; + let mut nodes: Vec<[u8; 32]> = vec![[0u8; 32]; total]; + for j in 0..num_leaves { + nodes[num_leaves - 1 + j] = cpu_hash_pair_bytes(&evals[2 * j], &evals[2 * j + 1]); + } + let mut level_begin = num_leaves - 1; + while level_begin != 0 { + let new_begin = level_begin / 2; + let n_pairs = level_begin - new_begin; + for k in 0..n_pairs { + let l = nodes[level_begin + 2 * k]; + let r = nodes[level_begin + 2 * k + 1]; + nodes[new_begin + k] = cpu_hash_pair_nodes(&l, &r); + } + level_begin = new_begin; + } + nodes +} + +fn run_parity(log_num_leaves: u32, seed: u64) { + let num_leaves = 1usize << log_num_leaves; + let num_evals = num_leaves * 2; + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let evals: Vec = (0..num_evals).map(|_| rand_ext3(&mut rng)).collect(); + let evals_u64 = ext3_to_u64s(&evals); + + let cpu_nodes = cpu_fri_layer_nodes(&evals); + let gpu_bytes = build_fri_layer_tree_from_evals_ext3(&evals_u64).unwrap(); + + assert_eq!(cpu_nodes.len() * 32, gpu_bytes.len()); + for i in 0..cpu_nodes.len() { + let g = &gpu_bytes[i * 32..(i + 1) * 32]; + let c = &cpu_nodes[i]; + assert_eq!(g, c, "node {i} mismatch at log_num_leaves={log_num_leaves}"); + } +} + +#[test] +fn fri_layer_tree_small() { + for log in 1u32..=6 { + run_parity(log, 100 + log as u64); + } +} + +#[test] +fn fri_layer_tree_medium() { + for log in [10u32, 12, 14] { + run_parity(log, 500 + log as u64); + } +} + +#[test] +fn fri_layer_tree_large() { + run_parity(18, 9999); +} diff --git a/crypto/stark/src/fri/mod.rs b/crypto/stark/src/fri/mod.rs index bbb988bd1..9325913f1 100644 --- a/crypto/stark/src/fri/mod.rs +++ b/crypto/stark/src/fri/mod.rs @@ -18,7 +18,10 @@ use self::fri_functions::{ /// FRI commit phase from pre-computed bit-reversed evaluations, skipping the /// initial FFT. Use this when the caller already has the evaluation vector /// (e.g. from a fused LDE pipeline). -pub fn commit_phase_from_evaluations, E: IsField>( +pub fn commit_phase_from_evaluations< + F: IsFFTField + IsSubFieldOf + 'static, + E: IsField + 'static, +>( number_layers: usize, mut evals: Vec>, transcript: &mut impl IsStarkTranscript, @@ -32,6 +35,24 @@ where FieldElement: AsBytes + Sync + Send, FieldElement: AsBytes + Sync + Send, { + // GPU fast path: drives the entire commit phase device-side (per-layer + // fold + Keccak leaves + pair-hash tree, only D2H'ing each layer's root + // + evals + nodes for FriLayer construction). Falls back to the CPU + // loop below on any precondition miss; a cudarc failure mid-loop panics + // since the transcript is already advanced (see `try_fri_commit_gpu`). + #[cfg(feature = "cuda")] + { + if let Some(result) = crate::gpu_lde::try_fri_commit_gpu::( + number_layers, + &evals, + transcript, + coset_offset, + domain_size, + ) { + return result; + } + } + // Inverse twiddle factors for evaluation-form folding. let mut inv_twiddles = compute_coset_twiddles_inv(coset_offset, domain_size); diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs index 530e2f6b9..522e8d21f 100644 --- a/crypto/stark/src/gpu_lde.rs +++ b/crypto/stark/src/gpu_lde.rs @@ -11,14 +11,19 @@ use std::slice::{from_raw_parts, from_raw_parts_mut}; use std::sync::OnceLock; use std::sync::atomic::{AtomicU64, Ordering}; +use crypto::fiat_shamir::is_transcript::IsStarkTranscript; use crypto::merkle_tree::merkle::MerkleTree; use crypto::merkle_tree::traits::IsMerkleTreeBackend; use math::field::element::FieldElement; use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; use math::field::goldilocks::GoldilocksField; use math::field::traits::{IsFFTField, IsField, IsSubFieldOf}; +use math::traits::AsBytes; +use crate::config::FriLayerMerkleTreeBackend; use crate::domain::Domain; +use crate::fri::fri_commitment::FriLayer; +use crate::fri::fri_functions::compute_coset_twiddles_inv; use crate::trace::LDETraceTable; /// Break-even LDE size. For LDE sizes smaller than this, the CPU @@ -63,6 +68,8 @@ pub fn reset_all_gpu_call_counters() { GPU_PARTS_LDE_CALLS.store(0, Ordering::Relaxed); GPU_BARY_CALLS.store(0, Ordering::Relaxed); GPU_COMP_POLY_TREE_CALLS.store(0, Ordering::Relaxed); + GPU_DEEP_CALLS.store(0, Ordering::Relaxed); + GPU_FRI_CALLS.store(0, Ordering::Relaxed); } pub(crate) static GPU_EXTEND_HALVES_CALLS: AtomicU64 = AtomicU64::new(0); @@ -614,7 +621,7 @@ pub fn gpu_merkle_tree_calls() -> u64 { // ============================================================================ /// R2 dispatch counter: incremented once per -/// [`try_evaluate_parts_on_lde_gpu`] call that actually routed to the GPU. +/// [`try_evaluate_parts_on_lde_gpu_keep`] call that actually routed to the GPU. pub(crate) static GPU_PARTS_LDE_CALLS: AtomicU64 = AtomicU64::new(0); pub fn gpu_parts_lde_calls() -> u64 { GPU_PARTS_LDE_CALLS.load(Ordering::Relaxed) @@ -718,102 +725,10 @@ where out } -/// R2 GPU dispatch: batched ext3 LDE over `parts_coefs` (composition-poly -/// coefficient parts). Returns the LDE evaluations as `Vec>>` -/// of length `lde_size` per part on success, `None` to fall through to the CPU -/// path. Used by `round_2_compute_composition_polynomial` in the -/// `number_of_parts > 2` branch. -/// -/// Inputs are immutable (`&[&[FieldElement]]`) and outputs are fresh, so -/// there is no `restore_columns_on_err` needed. `Err` just returns `None` -/// and the caller's coefficient slices are left untouched. -pub(crate) fn try_evaluate_parts_on_lde_gpu( - parts_coefs: &[&[FieldElement]], - blowup_factor: usize, - domain_size: usize, - offset: &FieldElement, -) -> Option>>> -where - F: IsFFTField + IsField + IsSubFieldOf + 'static, - E: IsField + 'static, -{ - if parts_coefs.is_empty() { - return Some(Vec::new()); - } - if !domain_size.is_power_of_two() || !blowup_factor.is_power_of_two() { - return None; - } - let lde_size = domain_size.checked_mul(blowup_factor)?; - if lde_size < gpu_lde_threshold() { - return None; - } - if TypeId::of::() != TypeId::of::() { - return None; - } - if TypeId::of::() != TypeId::of::() { - return None; - } - let m = parts_coefs.len(); - - // Weights: `offset^k` for k in 0..domain_size. F is Goldilocks by check above. - let mut weights_u64 = Vec::with_capacity(domain_size); - let mut w = FieldElement::::one(); - for _ in 0..domain_size { - // SAFETY: F == Goldilocks per TypeId check. FieldElement is - // #[repr(transparent)] over u64. - let v: u64 = unsafe { *(w.value() as *const _ as *const u64) }; - weights_u64.push(v); - w *= offset; - } - - // Pack parts into per-part `3 * domain_size` u64 buffers (zero-padded). - let mut part_bufs: Vec> = Vec::with_capacity(m); - for part in parts_coefs.iter() { - let mut buf = vec![0u64; 3 * domain_size]; - let len = part.len().min(domain_size); - // SAFETY: E == Ext3; backing is `[FieldElement; 3]` = `[u64; 3]`. - let src_ptr = part.as_ptr() as *const u64; - let src_len = len.checked_mul(3).expect("part src len overflow"); - let src = unsafe { from_raw_parts(src_ptr, src_len) }; - buf[..src_len].copy_from_slice(src); - part_bufs.push(buf); - } - let input_slices: Vec<&[u64]> = part_bufs.iter().map(|v| v.as_slice()).collect(); - - let mut outputs: Vec>> = (0..m) - .map(|_| vec![FieldElement::::zero(); lde_size]) - .collect(); - let gpu_result = { - let mut out_slices: Vec<&mut [u64]> = outputs - .iter_mut() - .map(|o| { - let ptr = o.as_mut_ptr() as *mut u64; - let byte_len = lde_size.checked_mul(3).expect("ext3 out len overflow"); - // SAFETY: E == Ext3 per TypeId check; Vec> of - // length `lde_size` is layout-equivalent to `[u64; 3 * lde_size]`. - unsafe { from_raw_parts_mut(ptr, byte_len) } - }) - .collect(); - math_cuda::lde::evaluate_poly_coset_batch_ext3_into( - &input_slices, - domain_size, - blowup_factor, - &weights_u64, - &mut out_slices, - ) - }; - if gpu_result.is_err() { - // Outputs are local and dropped on return; caller's inputs are - // immutable, so no restore is needed. - return None; - } - GPU_PARTS_LDE_CALLS.fetch_add(1, Ordering::Relaxed); - Some(outputs) -} - /// R2 GPU dispatch: build the composition-polynomial Merkle tree from the -/// host-side ext3 LDE eval Vecs produced by [`try_evaluate_parts_on_lde_gpu`] -/// (or the CPU path). Uses the same row-pair leaf pattern as the CPU +/// host-side ext3 LDE eval Vecs produced by +/// [`try_evaluate_parts_on_lde_gpu_keep`] (or the CPU path). Uses the same +/// row-pair leaf pattern as the CPU /// `commit_composition_polynomial`: each leaf hashes 2 consecutive /// bit-reversed rows. /// @@ -1003,3 +918,427 @@ where let scalar = ood_ext3_scalar::(coset_offset_pow_n, n_inv, g_n_inv, z_pow_n); Some(apply_ext3_scalar::(&sums_raw, scalar, num_cols)) } + +// ============================================================================ +// R2 keep-handle variant, R4 DEEP composition, FRI commit dispatches +// ============================================================================ + +/// R4 DEEP-composition dispatch counter. +pub(crate) static GPU_DEEP_CALLS: AtomicU64 = AtomicU64::new(0); +pub fn gpu_deep_calls() -> u64 { + GPU_DEEP_CALLS.load(Ordering::Relaxed) +} + +/// FRI commit-phase dispatch counter (one per `try_fri_commit_gpu` call, +/// not per layer). +pub(crate) static GPU_FRI_CALLS: AtomicU64 = AtomicU64::new(0); +pub fn gpu_fri_calls() -> u64 { + GPU_FRI_CALLS.load(Ordering::Relaxed) +} + +/// R2 GPU dispatch: batched ext3 LDE over `parts_coefs` (composition-poly +/// coefficient parts). Returns both the host LDE eval Vecs (needed for the +/// R2 Merkle commit and R3 OOD path) and a device-resident `GpuLdeExt3` +/// handle to the same de-interleaved buffer, so R4 DEEP can skip the +/// `num_parts * 3 * lde_size * 8` byte H2D. +pub(crate) fn try_evaluate_parts_on_lde_gpu_keep( + parts_coefs: &[&[FieldElement]], + blowup_factor: usize, + domain_size: usize, + offset: &FieldElement, +) -> Option<(Vec>>, math_cuda::lde::GpuLdeExt3)> +where + F: IsFFTField + IsField + IsSubFieldOf + 'static, + E: IsField + 'static, +{ + if parts_coefs.is_empty() { + return None; + } + if !domain_size.is_power_of_two() || !blowup_factor.is_power_of_two() { + return None; + } + let lde_size = domain_size.checked_mul(blowup_factor)?; + if lde_size < gpu_lde_threshold() { + return None; + } + if TypeId::of::() != TypeId::of::() { + return None; + } + if TypeId::of::() != TypeId::of::() { + return None; + } + let m = parts_coefs.len(); + + let mut weights_u64 = Vec::with_capacity(domain_size); + let mut w = FieldElement::::one(); + for _ in 0..domain_size { + // SAFETY: F == Goldilocks per TypeId check. + let v: u64 = unsafe { *(w.value() as *const _ as *const u64) }; + weights_u64.push(v); + w *= offset; + } + + let mut part_bufs: Vec> = Vec::with_capacity(m); + for part in parts_coefs.iter() { + let mut buf = vec![0u64; 3 * domain_size]; + let len = part.len().min(domain_size); + // SAFETY: E == Ext3 per TypeId check; backing is `[u64; 3]`. + let src_ptr = part.as_ptr() as *const u64; + let src_len = len.checked_mul(3).expect("part src len overflow"); + let src = unsafe { from_raw_parts(src_ptr, src_len) }; + buf[..src_len].copy_from_slice(src); + part_bufs.push(buf); + } + let input_slices: Vec<&[u64]> = part_bufs.iter().map(|v| v.as_slice()).collect(); + + let mut outputs: Vec>> = (0..m) + .map(|_| vec![FieldElement::::zero(); lde_size]) + .collect(); + let handle_result = { + let mut out_slices: Vec<&mut [u64]> = outputs + .iter_mut() + .map(|o| { + let ptr = o.as_mut_ptr() as *mut u64; + let byte_len = lde_size.checked_mul(3).expect("ext3 out len overflow"); + // SAFETY: E == Ext3; Vec of `lde_size` ext3 is `3*lde_size` u64s. + unsafe { from_raw_parts_mut(ptr, byte_len) } + }) + .collect(); + math_cuda::lde::evaluate_poly_coset_batch_ext3_into_keep( + &input_slices, + domain_size, + blowup_factor, + &weights_u64, + &mut out_slices, + ) + }; + let handle = match handle_result { + Ok(h) => h, + Err(_) => return None, + }; + GPU_PARTS_LDE_CALLS.fetch_add(1, Ordering::Relaxed); + Some((outputs, handle)) +} + +/// Reinterpret a slice of ext3 `FieldElement`s as a raw `&[u64]` of length +/// `3 * col.len()`. Caller must have established `E == Ext3` (TypeId check). +/// +/// SAFETY: `E == Degree3GoldilocksExtensionField` so each element is +/// `[FieldElement; 3]` = `[u64; 3]`. +unsafe fn ext3_slice_to_u64(col: &[FieldElement]) -> &[u64] { + let len = col.len().checked_mul(3).expect("ext3 u64 len overflow"); + let ptr = col.as_ptr() as *const u64; + unsafe { from_raw_parts(ptr, len) } +} + +/// Convert ext3 evals (3*n u64s, interleaved) into a freshly allocated +/// `Vec>` of length `n`. Caller must have established +/// `E == Ext3`. +fn u64_to_ext3_vec(raw: &[u64]) -> Vec> +where + E: IsField + 'static, +{ + type Gl = GoldilocksField; + type Ext3 = Degree3GoldilocksExtensionField; + assert_eq!(TypeId::of::(), TypeId::of::()); + assert!(raw.len().is_multiple_of(3)); + let n = raw.len() / 3; + let mut out: Vec> = Vec::with_capacity(n); + for i in 0..n { + let v: FieldElement = FieldElement::::new([ + FieldElement::::from_raw(raw[i * 3]), + FieldElement::::from_raw(raw[i * 3 + 1]), + FieldElement::::from_raw(raw[i * 3 + 2]), + ]); + // SAFETY: TypeId-checked above. E == Ext3, identical layout. + out.push(unsafe { transmute_copy::, FieldElement>(&v) }); + } + out +} + +/// R4 GPU dispatch: per-row DEEP composition over the full LDE domain. +/// Reuses the device-resident main + (optional) aux LDE handles from R1 +/// and, when supplied, the device-resident composition-parts LDE handle +/// from the R2 `_keep` path. +/// +/// Returns the `lde_size` ext3 evaluations of the DEEP polynomial on +/// success, or `None` to let the caller run its existing CPU loop. The +/// caller's `inv_denoms` must be `inv_denoms[0..lde_size]` for the H-term +/// and `inv_denoms[(1+k)*lde_size..(2+k)*lde_size]` for trace term k +/// (matching `compute_deep_composition_poly_evaluations`). +#[allow(clippy::too_many_arguments)] +pub(crate) fn try_deep_composition_gpu( + lde_trace: &LDETraceTable, + parts_dev: Option<&math_cuda::lde::GpuLdeExt3>, + parts_host: &[Vec>], + h_ood: &[FieldElement], + trace_ood_columns: &[Vec>], + composition_poly_gammas: &[FieldElement], + trace_terms_gammas: &[Vec>], + inv_denoms: &[FieldElement], + num_eval_points: usize, +) -> Option>> +where + F: IsField + IsSubFieldOf + 'static, + E: IsField + 'static, +{ + if TypeId::of::() != TypeId::of::() { + return None; + } + if TypeId::of::() != TypeId::of::() { + return None; + } + let main = lde_trace.gpu_main()?; + let lde_size = main.lde_size; + if lde_size < gpu_lde_threshold() { + return None; + } + if !lde_size.is_power_of_two() { + return None; + } + let num_main = main.m; + let aux_handle = lde_trace.gpu_aux(); + let num_aux = aux_handle.map(|a| a.m).unwrap_or(0); + let num_total_cols = num_main + num_aux; + let num_parts = composition_poly_gammas.len(); + if h_ood.len() != num_parts { + return None; + } + if trace_ood_columns.len() != num_total_cols + || trace_ood_columns.iter().any(|c| c.len() != num_eval_points) + { + return None; + } + if trace_terms_gammas.len() != num_total_cols + || trace_terms_gammas + .iter() + .any(|c| c.len() != num_eval_points) + { + return None; + } + let expected_inv_denoms = lde_size.checked_mul(1 + num_eval_points)?; + if inv_denoms.len() != expected_inv_denoms { + return None; + } + + // Validate the host parts when we don't have a device handle, since + // the math-cuda call will assert these. + if parts_dev.is_none() { + if parts_host.len() != num_parts { + return None; + } + if parts_host.iter().any(|p| p.len() != lde_size) { + return None; + } + } else if let Some(p) = parts_dev + && (p.m != num_parts || p.lde_size != lde_size) + { + return None; + } + + // Pack host buffers. SAFETY for ext3 transmutes: E == Ext3 by TypeId check. + let h_ood_raw: &[u64] = unsafe { ext3_slice_to_u64::(h_ood) }; + + // trace_ood: num_total_cols * num_eval_points * 3 (ext3 interleaved, + // (col * num_eval_points + k) layout). + let mut trace_ood_raw: Vec = Vec::with_capacity(num_total_cols * num_eval_points * 3); + for col in trace_ood_columns { + let slice = unsafe { ext3_slice_to_u64::(col) }; + trace_ood_raw.extend_from_slice(slice); + } + + let gammas_h_raw: &[u64] = unsafe { ext3_slice_to_u64::(composition_poly_gammas) }; + + let mut gammas_tr_raw: Vec = Vec::with_capacity(num_total_cols * num_eval_points * 3); + for col in trace_terms_gammas { + let slice = unsafe { ext3_slice_to_u64::(col) }; + gammas_tr_raw.extend_from_slice(slice); + } + + // inv_denoms is laid out as (1 + num_eval_points) blocks of lde_size + // each. Split the H-term block and the trace blocks (concatenated). + let inv_h_raw: &[u64] = unsafe { ext3_slice_to_u64::(&inv_denoms[0..lde_size]) }; + let inv_t_raw: &[u64] = + unsafe { ext3_slice_to_u64::(&inv_denoms[lde_size..lde_size * (1 + num_eval_points)]) }; + + // domain_size == lde_size here: R4 DEEP evaluates at every LDE point + // (Plonky3-style direct LDE). Calling the kernel with blowup_factor = 1 + // makes its `row = i * blowup_factor` index every row. + let domain_size_kernel = lde_size; + let blowup_kernel = 1usize; + + // Pack parts host path if no device handle. + let parts_host_packed: Vec; + let result = if let Some(parts) = parts_dev { + math_cuda::deep::deep_composition_ext3_with_dev_parts( + main, + aux_handle, + parts, + h_ood_raw, + &trace_ood_raw, + gammas_h_raw, + &gammas_tr_raw, + inv_h_raw, + inv_t_raw, + num_parts, + num_main, + num_aux, + num_eval_points, + blowup_kernel, + domain_size_kernel, + ) + } else { + // De-interleave each ext3 part column into 3 contiguous base-field + // slabs of length `lde_size` (the math-cuda kernel reads the parts + // buffer with layout `h_lde[(p*3 + k) * lde_stride + r]`). + let mut packed = vec![0u64; num_parts * 3 * lde_size]; + for (p, col) in parts_host.iter().enumerate() { + let slice = unsafe { ext3_slice_to_u64::(col) }; + for (r, chunk) in slice.chunks_exact(3).enumerate() { + packed[(p * 3) * lde_size + r] = chunk[0]; + packed[(p * 3 + 1) * lde_size + r] = chunk[1]; + packed[(p * 3 + 2) * lde_size + r] = chunk[2]; + } + } + parts_host_packed = packed; + math_cuda::deep::deep_composition_ext3( + main, + aux_handle, + &parts_host_packed, + h_ood_raw, + &trace_ood_raw, + gammas_h_raw, + &gammas_tr_raw, + inv_h_raw, + inv_t_raw, + num_parts, + num_main, + num_aux, + num_eval_points, + blowup_kernel, + domain_size_kernel, + ) + }; + + let deep_raw = match result { + Ok(v) => v, + Err(_) => return None, + }; + GPU_DEEP_CALLS.fetch_add(1, Ordering::Relaxed); + debug_assert_eq!(deep_raw.len(), lde_size * 3); + Some(u64_to_ext3_vec::(&deep_raw)) +} + +/// R4 FRI dispatch: drive the full FRI commit phase device-side. Mirrors +/// [`crate::fri::commit_phase_from_evaluations`]: per-layer transcript +/// ping-pong (sample zeta, fold, build Merkle tree, append root). +/// +/// Falls back via `None` only at preconditions; once the loop starts the +/// transcript is mutated layer-by-layer, so a mid-flight cudarc failure +/// panics (CPU restart from the same evals would re-sample zeta_0 against +/// an already-advanced transcript and produce a different proof). +#[allow(clippy::type_complexity)] +pub(crate) fn try_fri_commit_gpu( + number_layers: usize, + evals: &[FieldElement], + transcript: &mut impl IsStarkTranscript, + coset_offset: &FieldElement, + domain_size: usize, +) -> Option<( + FieldElement, + Vec>>, +)> +where + F: IsFFTField + IsField + IsSubFieldOf + 'static, + E: IsField + 'static, + FieldElement: AsBytes, + FieldElement: AsBytes, +{ + if TypeId::of::() != TypeId::of::() { + return None; + } + if TypeId::of::() != TypeId::of::() { + return None; + } + let n0 = evals.len(); + if n0 != domain_size || !n0.is_power_of_two() || n0 < 2 { + return None; + } + if n0 < gpu_lde_threshold() { + return None; + } + + // Pre-compute inv_twiddles on CPU (matches commit_phase_from_evaluations) + // and pack to u64 before any transcript mutation, so on H2D / state + // construction failure the caller's transcript is untouched. + let inv_twiddles = compute_coset_twiddles_inv::(coset_offset, domain_size); + let mut inv_tw_u64: Vec = Vec::with_capacity(inv_twiddles.len()); + for t in &inv_twiddles { + // SAFETY: F == Goldilocks per TypeId check; FieldElement is + // #[repr(transparent)] over u64. + let v: u64 = unsafe { *(t.value() as *const _ as *const u64) }; + inv_tw_u64.push(v); + } + + // SAFETY: E == Ext3; FieldElement backing is [u64; 3]. + let evals_u64: &[u64] = unsafe { ext3_slice_to_u64::(evals) }; + + let mut state = match math_cuda::fri::FriCommitState::new(evals_u64, &inv_tw_u64, n0) { + Ok(s) => s, + Err(_) => return None, + }; + + let num_committed_layers = number_layers.saturating_sub(1); + let mut fri_layer_list: Vec>> = + Vec::with_capacity(num_committed_layers); + + for _ in 0..num_committed_layers { + // <<<< Receive challenge zeta_k + let zeta: FieldElement = transcript.sample_field_element(); + // SAFETY: E == Ext3. + let zeta_ptr = &zeta as *const FieldElement as *const u64; + let zeta_raw: [u64; 3] = unsafe { [*zeta_ptr, *zeta_ptr.add(1), *zeta_ptr.add(2)] }; + + let (root, layer_evals_u64, nodes_bytes) = state + .fold_and_commit_layer(zeta_raw) + .expect("FRI commit: GPU fold+tree must not fail mid-phase (transcript advanced)"); + + // Build the FriLayer: ext3 evals + Merkle tree from precomputed nodes. + let evaluation = u64_to_ext3_vec::(&layer_evals_u64); + + debug_assert!(nodes_bytes.len().is_multiple_of(32)); + let nodes: Vec<[u8; 32]> = nodes_bytes + .chunks_exact(32) + .map(|c| c.try_into().expect("chunks_exact(32) yields 32 bytes")) + .collect(); + let merkle_tree = MerkleTree::>::from_precomputed_nodes(nodes) + .expect("FRI commit: precomputed nodes form a valid tree"); + + fri_layer_list.push(FriLayer::new(&evaluation, merkle_tree)); + + // >>>> Send commitment: [p_k] + let mut root_arr = [0u8; 32]; + root_arr.copy_from_slice(&root); + transcript.append_bytes(&root_arr); + } + + // <<<< Receive challenge zeta_{n-1} + let zeta_last: FieldElement = transcript.sample_field_element(); + let zeta_ptr = &zeta_last as *const FieldElement as *const u64; + let zeta_raw: [u64; 3] = unsafe { [*zeta_ptr, *zeta_ptr.add(1), *zeta_ptr.add(2)] }; + + let last_raw = state + .fold_final(zeta_raw) + .expect("FRI commit: GPU final fold must not fail mid-phase (transcript advanced)"); + let last_vec = u64_to_ext3_vec::(&last_raw); + let last_value = last_vec + .into_iter() + .next() + .expect("fold_final returns 1 elt"); + + // >>>> Send value: p_n + transcript.append_field_element(&last_value); + + GPU_FRI_CALLS.fetch_add(1, Ordering::Relaxed); + Some((last_value, fri_layer_list)) +} diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index 2c02eacf5..e50bf2279 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -328,6 +328,13 @@ where pub(crate) composition_poly_merkle_tree: BatchedMerkleTree, /// The commitment to the composition polynomial parts. pub(crate) composition_poly_root: Commitment, + /// Device-resident de-interleaved LDE handle from the R2 fused GPU path + /// (`try_evaluate_parts_on_lde_gpu_keep`). When present, R4 DEEP skips + /// the `num_parts * 3 * lde_size * 8` byte H2D and reads parts on + /// device. `None` when the GPU R2 path didn't run (number_of_parts <= 2, + /// below threshold, or any CPU fallback). + #[cfg(feature = "cuda")] + pub(crate) gpu_composition_parts: Option, } /// A container for the results of the third round of the STARK Prove protocol. @@ -1004,6 +1011,8 @@ pub trait IsStarkProver< #[cfg(feature = "instruments")] let t_sub = Instant::now(); + #[cfg(feature = "cuda")] + let mut gpu_composition_parts: Option = None; let lde_composition_poly_parts_evaluations = if number_of_parts == 2 { // Direct quotient decomposition: avoid full-size iFFT by algebraically // splitting H(x) = H₀(x²) + x·H₁(x²) using: @@ -1037,23 +1046,27 @@ pub trait IsStarkProver< }; // GPU fast path: batched ext3 LDE for all parts in one call. - // Non-`_keep` variant. The device buffer is freed at end of - // call; downstream R3 reads the host-side evals. PR-4 will - // upgrade to the `_keep` variant to retain the handle for R4 - // DEEP composition. + // `_keep` variant retains the de-interleaved device buffer as a + // `GpuLdeExt3` handle stored on Round2 so R4 DEEP can skip the + // `num_parts * 3 * lde_size * 8` byte H2D. #[cfg(feature = "cuda")] { let parts_slices: Vec<&[FieldElement]> = composition_poly_parts .iter() .map(|p| p.coefficients.as_slice()) .collect(); - crate::gpu_lde::try_evaluate_parts_on_lde_gpu::( + match crate::gpu_lde::try_evaluate_parts_on_lde_gpu_keep::( &parts_slices, domain.blowup_factor, domain.interpolation_domain_size, &domain.coset_offset, - ) - .unwrap_or_else(cpu_eval) + ) { + Some((evals, handle)) => { + gpu_composition_parts = Some(handle); + evals + } + None => cpu_eval(), + } } #[cfg(not(feature = "cuda"))] cpu_eval() @@ -1091,6 +1104,8 @@ pub trait IsStarkProver< lde_composition_poly_evaluations: lde_composition_poly_parts_evaluations, composition_poly_merkle_tree, composition_poly_root, + #[cfg(feature = "cuda")] + gpu_composition_parts, }) } @@ -1359,6 +1374,30 @@ pub trait IsStarkProver< let trace_ood_columns = round_3_result.trace_ood_evaluations.columns(); let num_total_cols = num_main_cols + num_aux_cols; + // GPU fast path: device-resident DEEP composition. Reuses the R1 + // main/aux LDE handles on `lde_trace` and (when the R2 fused path + // ran) the parts handle on `round_2_result.gpu_composition_parts`. + // Falls back to the CPU rayon loop below on any precondition miss + // or kernel failure. + #[cfg(feature = "cuda")] + { + if let Some(deep_evals) = + crate::gpu_lde::try_deep_composition_gpu::( + lde_trace, + round_2_result.gpu_composition_parts.as_ref(), + &round_2_result.lde_composition_poly_evaluations, + h_ood, + &trace_ood_columns, + composition_poly_gammas, + trace_terms_gammas, + &denoms, + num_eval_points, + ) + { + return deep_evals; + } + } + // OOD column compression (Plonky3-style): precompute one value per eval point, // ood_compressed_k = Σ_j gamma[j][k] * ood[j][k]. // The per-LDE-point trace column sums are NOT precomputed — they are fused diff --git a/prover/tests/cuda_path_integration.rs b/prover/tests/cuda_path_integration.rs index e54becb85..e6f109ab3 100644 --- a/prover/tests/cuda_path_integration.rs +++ b/prover/tests/cuda_path_integration.rs @@ -11,8 +11,8 @@ use lambda_vm_prover::prove; use lambda_vm_prover::test_utils::asm_elf_bytes; use stark::gpu_lde::{ - gpu_bary_calls, gpu_comp_poly_tree_calls, gpu_lde_calls, gpu_parts_lde_calls, - reset_all_gpu_call_counters, + gpu_bary_calls, gpu_comp_poly_tree_calls, gpu_deep_calls, gpu_fri_calls, gpu_lde_calls, + gpu_parts_lde_calls, reset_all_gpu_call_counters, }; #[test] @@ -46,4 +46,14 @@ fn gpu_path_fires_end_to_end() { gpu_comp_poly_tree_calls() > 0, "R2 GPU comp-poly tree did not fire" ); + + // R4 DEEP composition. Reuses the R1 main/aux handles and (when R2 + // parts LDE took the keep path) the parts handle on Round2. Fires + // once per table whose lde_size clears the threshold. + assert!(gpu_deep_calls() > 0, "R4 GPU DEEP composition did not fire"); + + // R4 FRI commit phase. One call per table (per + // `commit_phase_from_evaluations`); each call drives all FRI layers + // device-side internally. + assert!(gpu_fri_calls() > 0, "R4 GPU FRI commit did not fire"); }