feat(examples): add L3 ring allreduce skeleton (chunked RS+AG)#972
feat(examples): add L3 ring allreduce skeleton (chunked RS+AG)#972georgebisbas wants to merge 2 commits into
Conversation
New allreduce_ring_distributed example: P-1 reduce-scatter and P-1 allgather ring rounds over HCCL window exchange slots, with per-round notify/wait barriers. Same golden as mesh allreduce_distributed. P=2 and P=4 pytest; default CLI device 0-3.
|
Important Review skippedDraft detected. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Organization UI Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces a distributed ring allreduce implementation under examples/workers/l3/, consisting of a chunked reduce-scatter and allgather kernel, orchestration code, a Python driver, and tests. Feedback on the changes highlights critical issues in the kernel: signal slots in the scratch buffer must be explicitly zero-initialized to prevent barrier failures, and additional barriers are required in both the reduce-scatter and allgather loops to resolve Write-After-Read (WAR) hazards. Consequently, the SIGNAL_SLOTS allocation in the Python driver needs to be doubled to prevent out-of-bounds indexing. Additionally, the all-to-all barrier can be optimized to a ring-specific neighbor synchronization to reduce overhead.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| __gm__ int32_t *signal_base = | ||
| reinterpret_cast<__gm__ int32_t *>(scratch + static_cast<size_t>((nranks + 1) * chunk_elems)); |
There was a problem hiding this comment.
Following the general rule:
Do not assume that allocated shared memory or device memory is zero-initialized. Always explicitly initialize all fields (such as thread/core counts and mapping arrays) to prevent garbage values from causing segmentation faults or undefined behavior.
The signal slots in the scratch buffer are not guaranteed to be zero-initialized. If they contain positive garbage values, TWAIT with WaitCmp::GE to 1 will immediately succeed, completely breaking the barrier synchronization. We must explicitly initialize the local signal slots to 0 at the start of the kernel.
__gm__ int32_t *signal_base =
reinterpret_cast<__gm__ int32_t *>(scratch + static_cast<size_t>((nranks + 1) * chunk_elems));
// Explicitly initialize local signal slots to 0 to prevent garbage values from causing undefined behavior in barriers.
const int total_signal_slots = 4 * (nranks - 1) * kMaxSupportedRanks;
for (int i = 0; i < total_signal_slots; ++i) {
signal_base[i] = 0;
}
pipe_barrier(PIPE_ALL);References
- Do not assume that allocated shared memory or device memory is zero-initialized. Always explicitly initialize all fields (such as thread/core counts and mapping arrays) to prevent garbage values from causing segmentation faults or undefined behavior.
| set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); | ||
| wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); | ||
| pipe_barrier(PIPE_ALL); |
There was a problem hiding this comment.
There is a critical Write-After-Read (WAR) hazard (anti-dependency) across ranks in the reduce-scatter loop.
In step exchange slot. However, if Rank exchange slot before Rank RoundBarrier at the start of the step only ensures that the write from the previous step is complete, but does not prevent a fast writer from overwriting the buffer before a slow reader finishes.
To resolve this, we must add a second barrier at the end of each step to ensure all reads are complete before any rank proceeds to overwrite its exchange slot in the next step.
set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
pipe_barrier(PIPE_ALL);
RoundBarrier(commCtx, signal_base, my_rank, nranks, round++);| set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); | ||
| wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); | ||
| pipe_barrier(PIPE_ALL); |
There was a problem hiding this comment.
Similarly to the reduce-scatter loop, there is a Write-After-Read (WAR) hazard in the allgather loop. A fast rank can proceed to the next step and overwrite its exchange slot before its right neighbor has finished reading it.
Adding a second barrier at the end of each step ensures that all ranks have finished reading the current step's exchange slots before any rank proceeds to overwrite them.
set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);
pipe_barrier(PIPE_ALL);
RoundBarrier(commCtx, signal_base, my_rank, nranks, round++);| K_MAX_SUPPORTED_RANKS = 16 | ||
| CHUNK_MAX = ALLREDUCE_COUNT // 2 # largest chunk (P=2) | ||
| SCRATCH_FLOAT_ELEMS = (K_MAX_SUPPORTED_RANKS + 1) * CHUNK_MAX | ||
| SIGNAL_SLOTS = 2 * (K_MAX_SUPPORTED_RANKS - 1) * K_MAX_SUPPORTED_RANKS |
There was a problem hiding this comment.
Since we need to introduce a second barrier per step in both the reduce-scatter and allgather loops to prevent WAR hazards, the total number of barrier rounds per execution doubles from
We must double the allocated SIGNAL_SLOTS accordingly to prevent out-of-bounds signal indexing in the scratch buffer.
| SIGNAL_SLOTS = 2 * (K_MAX_SUPPORTED_RANKS - 1) * K_MAX_SUPPORTED_RANKS | |
| SIGNAL_SLOTS = 4 * (K_MAX_SUPPORTED_RANKS - 1) * K_MAX_SUPPORTED_RANKS |
| AICORE inline void RoundBarrier(__gm__ CommContext *ctx, __gm__ int32_t *signal_base, int my_rank, int nranks, | ||
| int round) { | ||
| __gm__ int32_t *round_signals = signal_base + round * kMaxSupportedRanks; | ||
| for (int peer = 0; peer < nranks; ++peer) { | ||
| if (peer == my_rank) { | ||
| continue; | ||
| } | ||
| __gm__ int32_t *remote_signal = CommRemotePtr(ctx, round_signals + my_rank, peer); | ||
| pto::comm::Signal sig(remote_signal); | ||
| pto::comm::TNOTIFY(sig, (int32_t)1, pto::comm::NotifyOp::AtomicAdd); | ||
| } | ||
| for (int peer = 0; peer < nranks; ++peer) { | ||
| if (peer == my_rank) { | ||
| continue; | ||
| } | ||
| pto::comm::Signal sig(round_signals + peer); | ||
| pto::comm::TWAIT(sig, (int32_t)1, pto::comm::WaitCmp::GE); | ||
| } | ||
| pipe_barrier(PIPE_ALL); | ||
| } |
There was a problem hiding this comment.
In a ring topology, each rank only interacts with its immediate left and right neighbors (reading from left, writing to right).
Using an all-to-all barrier (RoundBarrier with all peers) introduces unnecessary synchronization overhead of
AICORE inline void RoundBarrier(__gm__ CommContext *ctx, __gm__ int32_t *signal_base, int my_rank, int nranks,
int round) {
__gm__ int32_t *round_signals = signal_base + round * kMaxSupportedRanks;
int right = (my_rank + 1) % nranks;
int left = (my_rank - 1 + nranks) % nranks;
// Notify right neighbor
__gm__ int32_t *remote_signal = CommRemotePtr(ctx, round_signals + my_rank, right);
pto::comm::Signal sig_remote(remote_signal);
pto::comm::TNOTIFY(sig_remote, (int32_t)1, pto::comm::NotifyOp::AtomicAdd);
// Wait for left neighbor
pto::comm::Signal sig_local(round_signals + left);
pto::comm::TWAIT(sig_local, (int32_t)1, pto::comm::WaitCmp::GE);
pipe_barrier(PIPE_ALL);
}50aff1b to
991c9db
Compare
- Use one signal row with cumulative TWAIT generations instead of per-round rows or mid-barrier slot resets that race with TNOTIFY. - Add stage-in device barrier and zero-init signals once per run. - Match mesh local-write, barrier, then remote-read for exchange chunks. - Allocate a single kMaxSupportedRanks int32 tail in main.py.
991c9db to
cb7ab81
Compare
Adds a new L3 example
allreduce_ring_distributed/implementing chunked ring allreduce (P−1 reduce-scatter + P−1 allgather rounds) over the HCCL communication window.Verification status
8fb6316e(a2a3, devices 5–6; temporary mesh-binary fallback)allreduce_ring_common.hppwired;main.pycompiles ring sources; NPU verify pendingWhat changed
New:
examples/workers/l3/allreduce_ring_distributed/kernels/aiv/allreduce_ring_kernel.cppkernels/aiv/allreduce_ring_common.hppkernels/orchestration/allreduce_ring_orch.cppmain.pytest_allreduce.pyEdit:
examples/workers/l3/README.md— documents the new example.Scratch layout (per rank, in window):
Pchunk slots (partitioned input / working buffers)2(P−1) × kMaxSupportedRanksint32 slots (one row per round)Golden: identical to mesh —
output[i] = sum_r (i + r×100)for all ranks (256-element float32 vectors;ALLREDUCE_COUNTmust dividenranks).NPU re-verify (after ring kernel wired)
python examples/workers/l3/allreduce_distributed/main.py -p a2a3 -d 5-6 # baseline python examples/workers/l3/allreduce_ring_distributed/main.py -p a2a3 -d 5-6 python examples/workers/l3/allreduce_ring_distributed/main.py -p a2a3sim -d 0-3