feat(examples): L3 ring allreduce (chunked RS+AG, a2a3 verified)#975
georgebisbas wants to merge 4 commits into
Walkthrough
This pull request adds a complete new L3 worker example implementing distributed ring AllReduce with chunked reduce-scatter and allgather phases. It includes AICORE kernel primitives, a full kernel implementation, orchestration wiring, Python runtime setup with golden-output validation, and integration tests.
Sequence Diagram
sequenceDiagram
participant Python as Python Runtime
participant Worker
participant Orch as Orchestration Layer
participant AIVKernel as AIV Kernel
participant Rank0
participant Rank1
Python->>Worker: Initialize worker
Python->>Python: Allocate per-rank input/output tensors
Python->>Python: Allocate ring domain window and scratch buffer
Python->>Orch: Submit orchestration DAG with tensor/scalar args
Orch->>AIVKernel: rt_submit_aiv_task (3 tensors + 2 scalars)
AIVKernel->>Rank0: Validate nranks, bind scratch layout
Rank0->>Rank0: Stage input into chunk slots
par Reduce-Scatter Phase
Rank0->>Rank1: Publish chunk, barrier signal
Rank1-->>Rank0: Send left-neighbor chunk
Rank0->>Rank0: Load/accumulate tile with MTE flags
end
par Allgather Phase
Rank0->>Rank1: Publish reduced chunk for dissemination
Rank1-->>Rank0: Send chunk from previous round
Rank0->>Rank0: Store chunk to output slot
end
AIVKernel->>AIVKernel: Stage concatenated chunks to output tensor
AIVKernel->>Worker: Return
Python->>Python: Compute golden expected output
Python->>Python: Validate each rank output vs golden (1e-3 tolerance)
Python->>Worker: Close worker
Estimated code review effort: 3 (Moderate), ~25 minutes
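The last steps of the sequence diagram (compute a golden output, then compare each rank's output against it with a 1e-3 tolerance) can be sketched in plain Python. This is only an illustration: the helper names `golden_allreduce`, `max_abs_diff`, and `check_outputs` are hypothetical, the input pattern is made up, and treating the tolerance as an absolute per-element bound is an assumption, since the page does not show how main.py applies it.

```python
def golden_allreduce(inputs):
    """Element-wise sum across ranks; after allreduce every rank should hold this."""
    return [sum(column) for column in zip(*inputs)]


def max_abs_diff(got, want):
    """Largest absolute element-wise difference between two equal-length vectors."""
    return max(abs(g - w) for g, w in zip(got, want))


def check_outputs(outputs, inputs, tol=1e-3):
    """True if every rank's output is within `tol` of the golden sum (assumed absolute)."""
    want = golden_allreduce(inputs)
    return all(max_abs_diff(out, want) <= tol for out in outputs)


# Hypothetical inputs: 4 ranks, 8 elements each.
inputs = [[float(rank * 10 + i) for i in range(8)] for rank in range(4)]
good = [golden_allreduce(inputs) for _ in range(4)]

# Corrupt one element on one rank to show the check failing.
bad = [row[:] for row in good]
bad[2][5] += 99.0

print(check_outputs(good, inputs))
print(check_outputs(bad, inputs))
```

Reporting the maximum difference per rank, rather than only a pass/fail flag, makes it easier to see which chunk went wrong when a ring round reads stale data.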
Code Review
This pull request introduces a distributed ring AllReduce implementation built from chunked reduce-scatter and allgather phases. The review feedback raises three points. First, the exchange buffer is never read by remote ranks, so it and its redundant memory copies should be removed from the kernel, the helper functions, and the Python host code to improve performance and reduce scratch memory usage. Second, the kernel should explicitly zero-initialize the local signal slots, because device memory is not guaranteed to be zero-initialized and stale values can cause undefined behavior. Finally, the unnecessary from __future__ import annotations line in main.py should be removed.
__gm__ int32_t *signal_base = nullptr;
RingBindScratch(scratch, nranks, chunk_elems, chunks, exchange, signal_base);

// Signal rows rely on fresh-window zero init (per-round rows, used once).
Device memory (such as the HCCL window scratch space) is not guaranteed to be zero-initialized. Relying on 'fresh-window zero init' can lead to undefined behavior or state leakage between runs. Explicitly zero-initialize the local signal slots at the start of the kernel to ensure correctness.
// Explicitly zero-initialize the local signal slots to prevent garbage values
// from causing undefined behavior, as device memory is not guaranteed to be zero-initialized.
const int total_rounds = 2 * (nranks - 1);
for (int r = 0; r < total_rounds; ++r) {
    for (int peer = 0; peer < nranks; ++peer) {
        signal_base[r * kMaxSupportedRanks + peer] = 0;
    }
}
pipe_barrier(PIPE_ALL);
References
- Do not assume that allocated shared memory or device memory is zero-initialized. Always explicitly initialize all fields (such as thread/core counts and mapping arrays) to prevent garbage values from causing segmentation faults or undefined behavior.
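The indexing in the suggested loop, `r * kMaxSupportedRanks + peer`, treats the signal area as a table with one row per ring round and one column per peer. A small Python sketch of that addressing follows. The value 16 for the maximum rank count is an assumption taken from the "between 2 and 16" range mentioned later in the review, and the helper names are hypothetical.

```python
K_MAX_SUPPORTED_RANKS = 16  # assumed; the page only implies this via "between 2 and 16"


def total_rounds(nranks):
    """(P-1) reduce-scatter rounds followed by (P-1) allgather rounds."""
    return 2 * (nranks - 1)


def signal_index(round_idx, peer):
    """Row-major slot index: one row per round, one column per peer."""
    return round_idx * K_MAX_SUPPORTED_RANKS + peer


def touched_slots(nranks):
    """Every slot the zeroing loop in the suggestion would write for this rank count."""
    return [
        signal_index(r, peer)
        for r in range(total_rounds(nranks))
        for peer in range(nranks)
    ]


slots = touched_slots(4)
capacity = 2 * (K_MAX_SUPPORTED_RANKS - 1) * K_MAX_SUPPORTED_RANKS

print(len(slots), max(slots), capacity)
```

Because each round owns its own row, no slot is shared between rounds, which is the property the "per-round rows, used once" comment in the kernel relies on.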
__gm__ float *chunks = nullptr;
__gm__ float *exchange = nullptr;
__gm__ int32_t *signal_base = nullptr;
RingBindScratch(scratch, nranks, chunk_elems, chunks, exchange, signal_base);
The exchange buffer is completely unused by any remote rank, and copying chunks to it is redundant. Removing the exchange buffer and its associated copies will save 2 * (P - 1) redundant global memory copies (TLOAD + TSTORE) per rank, significantly improving performance and reducing scratch memory usage.
__gm__ float *chunks = nullptr;
__gm__ int32_t *signal_base = nullptr;
RingBindScratch(scratch, nranks, chunk_elems, chunks, signal_base);

RingCopyChunkGm(exchange, chunks + static_cast<size_t>(send_idx * chunk_elems), chunk_elems, chunkTile);
pipe_barrier(PIPE_ALL);

RingCopyChunkGm(exchange, chunks + static_cast<size_t>(send_idx * chunk_elems), chunk_elems, chunkTile);
pipe_barrier(PIPE_ALL);

AICORE inline void RingBindScratch(
    __gm__ float *scratch, int nranks, int chunk_elems, __gm__ float *&chunks, __gm__ float *&exchange,
    __gm__ int32_t *&signal_base
) {
    chunks = scratch;
    exchange = scratch + static_cast<size_t>(nranks * chunk_elems);
    signal_base = reinterpret_cast<__gm__ int32_t *>(scratch + static_cast<size_t>((nranks + 1) * chunk_elems));
}
Simplify RingBindScratch by removing the unused exchange buffer parameter and calculation.
AICORE inline void RingBindScratch(
    __gm__ float *scratch, int nranks, int chunk_elems, __gm__ float *&chunks,
    __gm__ int32_t *&signal_base
) {
    chunks = scratch;
    signal_base = reinterpret_cast<__gm__ int32_t *>(scratch + static_cast<size_t>(nranks * chunk_elems));
}

SCRATCH_FLOAT_ELEMS = (K_MAX_SUPPORTED_RANKS + 1) * CHUNK_MAX
SIGNAL_SLOTS = 2 * (K_MAX_SUPPORTED_RANKS - 1) * K_MAX_SUPPORTED_RANKS
SCRATCH_NBYTES = SCRATCH_FLOAT_ELEMS * 4 + SIGNAL_SLOTS * 4
Since the unused exchange buffer has been removed, the scratch float elements can be simplified to just ALLREDUCE_COUNT floats, reducing memory usage.
- SCRATCH_FLOAT_ELEMS = (K_MAX_SUPPORTED_RANKS + 1) * CHUNK_MAX
+ SCRATCH_FLOAT_ELEMS = ALLREDUCE_COUNT
  SIGNAL_SLOTS = 2 * (K_MAX_SUPPORTED_RANKS - 1) * K_MAX_SUPPORTED_RANKS
  SCRATCH_NBYTES = SCRATCH_FLOAT_ELEMS * 4 + SIGNAL_SLOTS * 4
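To make the two sizing formulas and the `RingBindScratch` offsets concrete, here is a rough Python sketch of the arithmetic. The constants `ALLREDUCE_COUNT = 256`, `CHUNK_MAX = 128`, and `K_MAX_SUPPORTED_RANKS = 16` are assumptions picked for illustration and are not stated on this page; with them, the original formula yields 10624 bytes, which is consistent with the window size quoted in the PR summary. The function names are hypothetical.

```python
FLOAT_BYTES = 4
INT32_BYTES = 4

# Assumed values, not confirmed by the PR page.
K_MAX_SUPPORTED_RANKS = 16
ALLREDUCE_COUNT = 256
CHUNK_MAX = ALLREDUCE_COUNT // 2  # largest chunk occurs at the smallest ring, P = 2


def signal_slots(max_ranks):
    """One slot per (round, peer) pair: 2 * (P - 1) rounds, P peers, at worst-case P."""
    return 2 * (max_ranks - 1) * max_ranks


def window_bytes_with_exchange(max_ranks, chunk_max):
    """Original sizing: P chunk slots plus one exchange slot, at worst-case chunk size."""
    return (max_ranks + 1) * chunk_max * FLOAT_BYTES + signal_slots(max_ranks) * INT32_BYTES


def window_bytes_without_exchange(max_ranks, allreduce_count):
    """Reviewer's sizing: P * chunk always equals ALLREDUCE_COUNT when P divides it."""
    return allreduce_count * FLOAT_BYTES + signal_slots(max_ranks) * INT32_BYTES


def bind_offsets(nranks, allreduce_count, with_exchange=True):
    """Float offsets of each region, mirroring the pointer math in RingBindScratch."""
    if nranks <= 0 or allreduce_count % nranks != 0:
        raise ValueError("allreduce_count must be divisible by a positive nranks")
    chunk = allreduce_count // nranks
    data_slots = nranks + 1 if with_exchange else nranks
    return {
        "chunks": 0,
        "exchange": nranks * chunk if with_exchange else None,
        "signal_base": data_slots * chunk,
    }


print(window_bytes_with_exchange(K_MAX_SUPPORTED_RANKS, CHUNK_MAX))
print(window_bytes_without_exchange(K_MAX_SUPPORTED_RANKS, ALLREDUCE_COUNT))
print(bind_offsets(4, ALLREDUCE_COUNT))
```

Note that the signal region starts at an offset that depends on the rank count, so the host-side size and the kernel-side `RingBindScratch` must agree on whether the exchange slot exists.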
def scratch_float_elems(nranks: int) -> int:
    """Float slots in the HCCL window for this rank count: P chunk slots + 1 exchange."""
    if ALLREDUCE_COUNT % nranks != 0:
        raise ValueError(f"ALLREDUCE_COUNT={ALLREDUCE_COUNT} must divide nranks={nranks}")
    chunk = ALLREDUCE_COUNT // nranks
    return (nranks + 1) * chunk
Simplify scratch_float_elems to return ALLREDUCE_COUNT directly since the exchange buffer is removed.
  def scratch_float_elems(nranks: int) -> int:
-     """Float slots in the HCCL window for this rank count: P chunk slots + 1 exchange."""
+     """Float slots in the HCCL window for this rank count: P chunk slots."""
      if ALLREDUCE_COUNT % nranks != 0:
          raise ValueError(f"ALLREDUCE_COUNT={ALLREDUCE_COUNT} must divide nranks={nranks}")
-     chunk = ALLREDUCE_COUNT // nranks
-     return (nranks + 1) * chunk
+     return ALLREDUCE_COUNT
"""

from __future__ import annotations
For Python projects targeting Python 3.9 or higher, PEP 585 generic collections (like list[str]) are fully supported at runtime, so adding from __future__ import annotations is unnecessary.
References
- For Python projects targeting Python 3.9 or higher, PEP 585 generic collections (like list[str]) are fully supported at runtime, so adding from __future__ import annotations is unnecessary to prevent runtime errors during module load.
New L3 example separate from mesh allreduce_distributed: stage-in, (P-1) reduce-scatter and (P-1) allgather ring rounds over HCCL window chunks with per-round TNOTIFY/TWAIT barriers. Same golden as mesh. P=2/P=4 pytest; default CLI devices 0-3.
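The round structure described in this commit message (stage-in, then P-1 reduce-scatter rounds and P-1 allgather rounds around a logical ring) can be simulated on the host in a few lines of Python. This is a generic textbook ring allreduce, not the kernel from this PR: the chunk schedule, the function name `ring_allreduce`, and the use of a per-round snapshot to stand in for the TNOTIFY/TWAIT barrier are all assumptions made for illustration.

```python
def ring_allreduce(inputs):
    """Simulate chunked ring allreduce (sum) over P ranks; returns each rank's result."""
    p = len(inputs)
    n = len(inputs[0])
    if p < 2 or n % p != 0:
        raise ValueError("need at least 2 ranks and a length divisible by the rank count")
    chunk = n // p

    def span(c):
        """Slice covering chunk index c."""
        return slice(c * chunk, (c + 1) * chunk)

    bufs = [list(v) for v in inputs]  # stage-in: each rank copies its input

    # Reduce-scatter: in round s, rank r adds its left neighbour's chunk (r - s - 1) mod P.
    for s in range(p - 1):
        snapshot = [b[:] for b in bufs]  # stands in for the per-round barrier
        for r in range(p):
            left = (r - 1) % p
            seg = span((r - s - 1) % p)
            bufs[r][seg] = [a + b for a, b in zip(bufs[r][seg], snapshot[left][seg])]

    # Allgather: in round s, rank r copies the finished chunk (r - s) mod P from its left neighbour.
    for s in range(p - 1):
        snapshot = [b[:] for b in bufs]
        for r in range(p):
            left = (r - 1) % p
            seg = span((r - s) % p)
            bufs[r][seg] = snapshot[left][seg]

    return bufs


inputs = [[rank * 10 + i for i in range(8)] for rank in range(4)]
golden = [sum(col) for col in zip(*inputs)]
outputs = ring_allreduce(inputs)
print(all(out == golden for out in outputs))
```

After the reduce-scatter phase in this schedule, rank r holds the fully reduced chunk (r + 1) mod P; the allgather phase then passes finished chunks around the ring until every rank has all P of them.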
75c3152 to 497ae58
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/workers/l3/allreduce_ring_distributed/main.py`:
- Around line 136-145: Validate the device_ids input at the top of run(): check
that device_ids is non-empty and that nranks = len(device_ids) is within the
supported range (e.g., between 2 and 16 as the example expects); if not, raise a
ValueError with a clear message so downstream calls (like scratch_float_elems)
don't hit ZeroDivisionError or unsupported configurations. Add this check
directly in run() before calling scratch_float_elems() or computing window_size,
referencing run() and scratch_float_elems() in the message so the caller can see
which entrypoint enforces the constraint.
Files selected for processing (7)
- examples/workers/l3/README.md
- examples/workers/l3/allreduce_ring_distributed/__init__.py
- examples/workers/l3/allreduce_ring_distributed/kernels/aiv/allreduce_ring_common.hpp
- examples/workers/l3/allreduce_ring_distributed/kernels/aiv/allreduce_ring_kernel.cpp
- examples/workers/l3/allreduce_ring_distributed/kernels/orchestration/allreduce_ring_orch.cpp
- examples/workers/l3/allreduce_ring_distributed/main.py
- examples/workers/l3/allreduce_ring_distributed/test_allreduce.py
def run(
    device_ids: list[int],
    platform: str = "a2a3",
    pto_isa_commit: str | None = None,
) -> int:
    """Core logic — callable from both CLI and pytest."""
    nranks = len(device_ids)
    float_elems = scratch_float_elems(nranks)
    window_size = max(SCRATCH_NBYTES, 4 * 1024)
Validate device_ids inside run(), not just in the CLI parser.
run() is the public entrypoint used by pytest as well as the CLI, so invalid rank counts bypass parse_device_range(). device_ids=[] hits a ZeroDivisionError in scratch_float_elems(), and 1 or >16 devices fall into configs the example explicitly does not support.
Suggested fix
  def run(
      device_ids: list[int],
      platform: str = "a2a3",
      pto_isa_commit: str | None = None,
  ) -> int:
      """Core logic — callable from both CLI and pytest."""
      nranks = len(device_ids)
+     if not (2 <= nranks <= K_MAX_SUPPORTED_RANKS):
+         raise ValueError(
+             f"allreduce_ring_distributed needs between 2 and {K_MAX_SUPPORTED_RANKS} devices, "
+             f"got {nranks} ({device_ids})"
+         )
      float_elems = scratch_float_elems(nranks)
Drop RingZeroSignals (per-round barrier rows used once; zeroing raced peer notify and caused AICPU 507018 timeout). Recv via left neighbour chunks[] after barrier, not local exchange mirror (max golden diff 99 on second chunk). Size scratch CommBufferSpec to (P+1)*chunk elements. Align ring example with mesh L3 style: single allreduce_ring_kernel.cpp (no common header), phase banners, and matching orch/main.py comments.
497ae58 to 690efbc
Mirror parse_device_range() so pytest/CLI callers cannot pass an empty list or unsupported rank count into scratch_float_elems().
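The page does not show `parse_device_range()` itself, so the following is only a guess at what "mirroring" its checks inside `run()` might look like. The `A-B` inclusive range syntax is inferred from the `-d 0-3` and `-d 5-6` invocations in the test plan, the limit of 16 ranks is taken from the review comment, and `validate_device_ids` is a hypothetical helper name.

```python
K_MAX_SUPPORTED_RANKS = 16  # assumed upper bound, per the "between 2 and 16" review comment


def parse_device_range(spec):
    """Hypothetical parser for 'A-B' (inclusive) or a single id 'A'."""
    lo_text, sep, hi_text = spec.partition("-")
    lo = int(lo_text)
    hi = int(hi_text) if sep else lo
    if hi < lo:
        raise ValueError(f"device range {spec!r} is reversed")
    return validate_device_ids(list(range(lo, hi + 1)))


def validate_device_ids(device_ids):
    """Shared check so CLI and direct callers reject the same inputs."""
    nranks = len(device_ids)
    if not (2 <= nranks <= K_MAX_SUPPORTED_RANKS):
        raise ValueError(
            f"need between 2 and {K_MAX_SUPPORTED_RANKS} devices, got {nranks} ({device_ids})"
        )
    return device_ids


print(parse_device_range("0-3"))
print(parse_device_range("5-6"))
```

Putting the rank-count check in one helper that both entry points call keeps the CLI and pytest paths from drifting apart, which is the failure mode the review comment describes.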
Summary
Reopens the work from #972 (that PR cannot be reopened after branch rebase).
Adds
examples/workers/l3/allreduce_ring_distributed/: chunked ring allreduce (RS + AG on a logical ring), separate from mesh allreduce_distributed/.
Closes / supersedes: #972
Ring uses +10624B HCCL window (chunked + per-round signals); mesh uses 4096B.
Algorithm
- Per-round TNOTIFY/TWAIT barriers
- Receive from the left neighbour's chunks[] via CommRemotePtr after each barrier
Test plan
- python examples/workers/l3/allreduce_ring_distributed/main.py -p a2a3 -d 5-6
- python examples/workers/l3/allreduce_ring_distributed/main.py -p a2a3 -d 0-3
- python examples/workers/l3/allreduce_distributed/main.py -p a2a3 -d 0-3