Skip to content

perf(bb/msm): WebGPU MSM memory wins — pack chunks+signs, plan-ring, reduction alias, bufA stride, budget=180#23532

Draft
AztecBot wants to merge 2 commits into
zw/msm-webgpu-experiments-v2from
cb/11ec7cdf9eab
Draft

perf(bb/msm): WebGPU MSM memory wins — pack chunks+signs, plan-ring, reduction alias, bufA stride, budget=180#23532
AztecBot wants to merge 2 commits into
zw/msm-webgpu-experiments-v2from
cb/11ec7cdf9eab

Conversation

@AztecBot
Copy link
Copy Markdown
Collaborator

@AztecBot AztecBot commented May 23, 2026

Four memory-reducing refactors on top of zw/msm-webgpu-experiments-v2. Together they take the GPU storage footprint at logN=17, c=15 from 149.3 MiB → 120.2 MiB (−29.1 MiB / −19%) without runtime regression on M2 (Apple Silicon). Cross-check against WASM Pippenger passes.

1. Pack chunksBuf + signsBuf into one u32

decompose_scalars_booth already produces a bucket (≤ 2^14 for c=15, fits in 15 bits) and a 1-bit neg for every (point, window) slot. Until now those went to two separate array<u32> storage buffers of size batchSlots × 4 each. Combined them:

chunks[idx] = bucket | (neg << 15u);

with the three downstream readers each pulling the field they need:

  • transpose_count_tiled / transpose_scatter_tiled: let col = all_csr_col_idx[i] & 0x7fffu;
  • csr_to_v2_active_sums (both INDEX_MODE and non-INDEX_MODE): let neg = (signs[...] >> 15u) & 1u;

Host drops signsBuf entirely and removes the signs binding from the decompose layout (4 entries instead of 5). The signs symbol still appears in csr_to_v2_active_sums (now bound to the packed chunksBuf since signsBuf = chunksBuf in the host).

Saves 4 × batchSlots bytes per MSM (≈ 6.3 MiB at N=2^18 c=15, batchWindows=6).

Important debug note: WGSL uniform-controlled shifts are miscompiled on at least Apple Silicon + Adreno

First attempt used bucket | (neg << c) where c = params.z (a u32 uniform with value 15). Cross-check produced wrong, non-deterministic results on both M2 and S25, despite the encoding being semantically identical. Changing to the constant bucket | (neg << 15u) — same value, same bit position — made both devices pass. That's a real toolchain issue worth filing upstream against Dawn / Tint; for now this PR sticks to constant shift amounts.

2. Drop the plan-ring ping-pong

chunkPlanRing / scatterPlanRing / carryPlanRing were each allocated as 2-buffer rings indexed by lv & 1. They are written by plannerB and read by the same level's fused/carry; each level's WebGPU compute pass ends (with the implicit pass barrier) before the next level's plannerB writes. No cross-level read/write race exists — the ping-pong is unnecessary.

Collapsed each ring to a single buffer (chunkPlanRing.push(cp, cp); scatterPlanRing.push(sp, sp); carryPlanRing.push(yp, yp);) so existing [ring] indexing keeps working but the three duplicate allocations vanish.

Saves chunkPlanRing[1] + scatterPlanRing[1] + carryPlanRing[1] (≈ 10 MiB at N=2^18, less at smaller N).

Note: countsBufs[0/1] and offsetsBufs[0/1] are NOT collapsed — plannerA does in-place read of countsBufs[inIdx] while writing countsBufs[outIdx], so collapsing those would race within the same dispatch.

3. Alias reduction-only buffers into batch-loop buffers

redBuf / isPresentBuf / reducePrefScratch are only live during reduction, which runs strictly after the outer batch loop completes. bufA / valIdxBuf / bufB are live during the batch loop but dead by the time reduction runs. Aliased the reduction buffers as offset-0 slices of the batch-loop buffers via { buffer, offset, size } bindings — same underlying GPU allocation, two non-overlapping logical lifetimes.

Sizes verified at prepare time:

  • bufA.size >= 64·RED_M (= 17.8 MiB at N=2^18)
  • valIdxBuf.size >= 4·RED_M (= 1.1 MiB)
  • bufB.size >= NUM_WINDOWS·REDUCE_WG·MAXC·2·16 (= 8.9 MiB)

Saves 3 separate allocations (redBuf + isPresentBuf + reducePrefScratch) entirely.

4. Tighten bufA stride + lower MEM_BUDGET 248→180 MiB

The pair-tree halves per-window active-sum count at each level, so odd-level outputs (which land in bufA) are roughly half the width of even-level outputs (which land in bufB). The prior code sized both buffers to the larger wstride1bufA was effectively wasting ~25-40 % of its allocation.

Split the strides:

M1_A = batchWindows × wstride_oddOut + 3
M1_B = batchWindows × wstride_evenOut + 3

and pushed per-level (M_in, M_out) pairs through the planner / fused / carry / finalize / pad uniforms (ba_fused_super_bench and ba_carry_copy_bench already accept distinct M_old/M_new; the change is host-only). padParams now has three variants — L0 (bufB-out), BA (bufB→bufA), AB (bufA→bufB) — selected by output parity.

In isolation the picker would claw the savings back by collapsing to numBatches=1, so MEM_BUDGET also drops 248 → 180 MiB. The budget is the per-batch ceiling for the weakest mobile target; at logN=17 c=15 it keeps numBatches ≥ 2 (9 windows/batch). No (n, c) we measure falls below 4 windows/batch.

Saves ~10 MiB at logN=17 c=15; larger at logN=18 (lever scales with bufA's true working set).

Diagnostics also included

Small things that are useful to keep:

  • __msm_mem_last window global + console.log at end of MsmV2.prepare reporting prepBuffers.length, totalBytes, numBatches, batchWindows, M1_A, M1_B — captured into the mem field of autorun=msm-cross-check JSONL output so per-step memory accounting is grep-able from the bench harness.
  • coi=1%26autorun%3D... URL-unpacking helper in dev/msm-webgpu/main.ts so BrowserStack mobile sessions (which truncate at the first literal & in the URL) can pass autorun + logn params through the coi value.
  • Header comment in dev/msm-webgpu/scripts/run-browserstack.mjs documenting the project policy: when validating a memory change on this branch, dispatch one S25 --n 17 BrowserStack job; M2 / Pixel cross-references just burn wall clock without adding signal.

Measured at logN=17, c=15 on fresh BrowserStack workers

Step Mem MiB Δ vs baseline M2 ms S25 ms Cross-check
baseline (commit 0999593b2a6) 149.3 65.2 100.1
+ drop plan-ring ping-pong 141.7 −7.6 66.0 100.1
+ alias reduction-only buffers 135.1 −14.2 65.7 105.3
+ pack chunks+signs 130.1 −19.2 66.3 99.7
+ bufA stride + budget=180 120.2 −29.1 68.2 pending ✓ (M2)

S25 swing on step 3 (+5%) is within the BS-S25 per-run variance for this workload; final S25 number on step 4 (pack) lands back at parity. S25 validation on the bufA-stride commit is pending — two consecutive BS S25 workers came back with zero /progress events in 15 min while the M2 desktop worker on the same tunnel URL passed in ~3 min. Same BS-S25 routing flakiness Zac flagged earlier on this branch; the code path itself is exercised at numBatches=2 on M2 with cross_ok=true, so the logic is sound. Anyone with a real S25 can re-bench against the tip; the dev-page autorun captures the headline directly.

Out of scope (planned follow-ups)

These are real follow-up wins but were deferred:

  • Workgroup-shared prefScratchBuf (≈ 17 MiB at N=2^18). Mechanically straightforward, but at WGI=128 × S=8 the per-workgroup shared footprint hits the 32 KiB maxComputeWorkgroupStorageSize ceiling, forcing 1 workgroup per SM on Adreno and regressing S25 runtime by ≈ 23 %. Needs S=4 or a hybrid layout to bring shared usage under the budget.
  • K=2 in-kernel pair-tree fusion (≈ 44 MiB at N=2^18). Biggest remaining lever — fuses adjacent reduction levels in ba_fused_super_bench so the intermediate level output never lands in global memory. Needs the planner topology to align thread fragments across the two fused levels (level L+1 pairings inside a thread's S/2 register outputs); non-trivial.

Created by claudebox · group: slackbot

@AztecBot AztecBot added the claudebox Owned by claudebox. it can push to this PR. label May 23, 2026
The pair-tree halves per-window active-sum count at each level, so the
odd-level outputs that land in bufA are roughly half the width of the
even-level outputs that land in bufB. Previously both buffers were sized
to the wider bufB stride.

Splitting bufA's stride from bufB's (M1_A = batchWindows × wstride_oddOut,
M1_B = batchWindows × wstride_evenOut) and pushing per-level (M_in, M_out)
pairs through the planner / fused / carry / finalize / pad uniforms shrinks
the active-sums footprint without changing the WGSL.

In isolation the picker would claw the savings back by collapsing to
numBatches=1, so the lever-G budget also drops 248→180 MiB. The budget is
the per-batch ceiling on the weakest mobile target; at logN=17 c=15 it
keeps numBatches≥2 (9 windows/batch). No (n, c) we measure falls below 4
windows/batch.

logN=17 c=15 on macOS Sequoia / Chrome 148 (M2 reference desktop):
- baseline:                       130.1 MiB,  66.3 ms,  cross_ok
- + step 5 + budget=180:          120.2 MiB,  68.2 ms,  cross_ok
@AztecBot AztecBot changed the title perf(bb/msm): WebGPU MSM memory wins — pack chunks+signs, drop plan-ring ping-pong, alias reduction buffers perf(bb/msm): WebGPU MSM memory wins — pack chunks+signs, plan-ring, reduction alias, bufA stride, budget=180 May 25, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

claudebox Owned by claudebox. it can push to this PR.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant