feat(examples): L3 ring allreduce (chunked RS+AG, a2a3 verified)#975

Open
georgebisbas wants to merge 4 commits into
hw-native-sys:mainfrom
georgebisbas:feat/l3-ring-allreduce-skeleton

@georgebisbas
Contributor

@georgebisbas georgebisbas commented Jun 2, 2026

Summary

Reopens the work from #972 (that PR cannot be reopened after branch rebase).

Adds examples/workers/l3/allreduce_ring_distributed/ — chunked ring allreduce
(RS + AG on a logical ring), separate from mesh allreduce_distributed/.

Closes / supersedes: #972

Ring uses +10624B HCCL window (chunked + per-round signals); mesh uses 4096B.

Algorithm

  • Stage-in: P chunk slots in HCCL window
  • Reduce-scatter: (P-1) ring steps, per-round TNOTIFY/TWAIT
  • Allgather: (P-1) ring steps
  • Recv from left neighbour chunks[] via CommRemotePtr after each barrier
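The steps above can be sketched as a host-side simulation. This is a minimal illustration, not the kernel: the chunk-index schedule (rank r pulls chunk (r - step - 1) mod P during reduce-scatter and chunk (r - step) mod P during allgather) is the textbook ring schedule and is an assumption here, and `ring_allreduce` is a hypothetical name. The per-step snapshot stands in for the per-round TNOTIFY/TWAIT barrier.

```python
def ring_allreduce(inputs):
    """Simulate a chunked ring allreduce; inputs[r] is rank r's input vector."""
    p = len(inputs)
    n = len(inputs[0])
    if p < 2 or n % p != 0:
        raise ValueError("need at least 2 ranks and a length divisible by the rank count")
    c = n // p

    # Stage-in: every rank splits its input into P chunk slots.
    chunks = [[list(v[i * c:(i + 1) * c]) for i in range(p)] for v in inputs]

    def snapshot():
        # Stand-in for the barrier: all ranks have published their chunks.
        return [[list(ch) for ch in rank] for rank in chunks]

    # Reduce-scatter: (P-1) steps, each rank accumulates one chunk read from its left neighbour.
    for step in range(p - 1):
        published = snapshot()
        for r in range(p):
            left = (r - 1) % p
            idx = (r - step - 1) % p
            chunks[r][idx] = [a + b for a, b in zip(chunks[r][idx], published[left][idx])]

    # Allgather: (P-1) steps, each rank copies one fully reduced chunk from its left neighbour.
    for step in range(p - 1):
        published = snapshot()
        for r in range(p):
            left = (r - 1) % p
            idx = (r - step) % p
            chunks[r][idx] = list(published[left][idx])

    # Stage-out: concatenate the chunk slots back into one vector per rank.
    return [[x for ch in rank for x in ch] for rank in chunks]


if __name__ == "__main__":
    data = [[float(r * 10 + i) for i in range(8)] for r in range(4)]
    result = ring_allreduce(data)
    print(result[0])
```

After the reduce-scatter phase in this schedule, rank r holds the fully reduced chunk (r + 1) mod P, which is why the allgather starts by pulling chunk r from the left neighbour.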

Test plan

  • python examples/workers/l3/allreduce_ring_distributed/main.py -p a2a3 -d 5-6
  • python examples/workers/l3/allreduce_ring_distributed/main.py -p a2a3 -d 0-3
  • python examples/workers/l3/allreduce_distributed/main.py -p a2a3 -d 0-3
  • CI green

@coderabbitai

coderabbitai Bot commented Jun 2, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: f2a5b71c-654e-4172-888f-4f4ceb18299d

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

This pull request adds a complete new L3 worker example implementing distributed ring AllReduce with chunked reduce-scatter and allgather phases. It includes AICORE kernel primitives, a full kernel implementation, orchestration wiring, Python runtime setup with golden-output validation, and integration tests.

Changes

Ring Allreduce Distributed Example

Layer / File(s) Summary
Documentation and Package Setup
examples/workers/l3/README.md, examples/workers/l3/allreduce_ring_distributed/__init__.py
Example entry added to the L3 README table, and package structure created with license header and relative import marker.
Ring Allreduce Primitives and Helpers
examples/workers/l3/allreduce_ring_distributed/kernels/aiv/allreduce_ring_common.hpp
Compile-time constants (kAllReduceCount, kMaxSupportedRanks, kChunkMax), dynamic tensor/tile type aliases, and inline AICORE helpers for remote pointer computation, peer barrier synchronization, chunk copy with MTE2/MTE3 flag ordering, left-neighbor receipt, and scratch memory layout binding.
AIV Kernel Reduce-Scatter and Allgather Logic
examples/workers/l3/allreduce_ring_distributed/kernels/aiv/allreduce_ring_kernel.cpp
Kernel entrypoint validates parameters, binds scratch layout, stages input into chunks, executes (nranks−1) reduce-scatter steps with remote exchange and tiled accumulation, then (nranks−1) allgather steps to disseminate reduced chunks, and stages the final result back to output with pipeline synchronization.
Orchestration and Task Submission
examples/workers/l3/allreduce_ring_distributed/kernels/orchestration/allreduce_ring_orch.cpp
Orchestration config specifies 5 expected arguments; orchestration entry extracts tensors and scalars, packages them into an Arg payload, and submits the AIV task.
Python Compilation and Helper Configuration
examples/workers/l3/allreduce_ring_distributed/main.py (lines 1–134)
Top-level constants define ALLREDUCE_COUNT and sizing calculations; scratch_float_elems() and parse_device_range() validate inputs with divisibility and rank-count bounds; build_chip_callable() compiles kernel and orchestration binaries and returns a ChipCallable hierarchy; expected_output() generates golden output.
Runtime Execution and Golden Validation
examples/workers/l3/allreduce_ring_distributed/main.py (lines 136–246)
run() allocates per-rank shared-memory tensors, initializes the worker, defines an orchestration function that allocates ring domains and scratch buffers, wires tensor/scalar arguments, submits the execution DAG, validates each rank's result against golden output (1e-3 tolerance), and ensures worker cleanup. main() provides CLI parsing for platform, device range, and optional PTO ISA commit.
Integration Tests
examples/workers/l3/allreduce_ring_distributed/test_allreduce.py
Two parameterized pytest tests: test_ring_allreduce_distributed() for 2-device and test_ring_allreduce_distributed_multi_rank() for 4-device configurations, both asserting run() returns exit code 0.

Sequence Diagram

sequenceDiagram
  participant Python as Python Runtime
  participant Worker
  participant Orch as Orchestration Layer
  participant AIVKernel as AIV Kernel
  participant Rank0
  participant Rank1
  Python->>Worker: Initialize worker
  Python->>Python: Allocate per-rank input/output tensors
  Python->>Python: Allocate ring domain window and scratch buffer
  Python->>Orch: Submit orchestration DAG with tensor/scalar args
  Orch->>AIVKernel: rt_submit_aiv_task (3 tensors + 2 scalars)
  AIVKernel->>Rank0: Validate nranks, bind scratch layout
  Rank0->>Rank0: Stage input into chunk slots
  par Reduce-Scatter Phase
    Rank0->>Rank1: Publish chunk, barrier signal
    Rank1-->>Rank0: Send left-neighbor chunk
    Rank0->>Rank0: Load/accumulate tile with MTE flags
  end
  par Allgather Phase
    Rank0->>Rank1: Publish reduced chunk for dissemination
    Rank1-->>Rank0: Send chunk from previous round
    Rank0->>Rank0: Store chunk to output slot
  end
  AIVKernel->>AIVKernel: Stage concatenated chunks to output tensor
  AIVKernel->>Worker: Return
  Python->>Python: Compute golden expected output
  Python->>Python: Validate each rank output vs golden (1e-3 tolerance)
  Python->>Worker: Close worker

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

In rings we gather, chunk by chunk,
Each rank reduces with a hunch,
Scatter down, then gather round,
AllReduce magic, distributed sound! 🐰✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 35.29% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and concisely describes the main change: adding a new L3 ring allreduce example implementation with chunked reduce-scatter and allgather, including platform verification. It accurately reflects the primary purpose of the changeset. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The PR description clearly describes the addition of a new ring allreduce example implementation distinct from the mesh version, including algorithm details, window size comparisons, and test verification. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a distributed ring AllReduce implementation, featuring chunked reduce-scatter and allgather algorithms. The code review feedback highlights several critical optimization and correctness improvements. First, the exchange buffer is completely unused by remote ranks and should be removed along with its redundant memory copies across the kernel, helper functions, and Python host code to improve performance and reduce scratch memory usage. Second, the kernel must explicitly zero-initialize the local signal slots to prevent undefined behavior, as device memory is not guaranteed to be zero-initialized. Finally, the unnecessary from __future__ import annotations import in main.py should be removed.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

__gm__ int32_t *signal_base = nullptr;
RingBindScratch(scratch, nranks, chunk_elems, chunks, exchange, signal_base);

// Signal rows rely on fresh-window zero init (per-round rows, used once).

high

Device memory (such as the HCCL window scratch space) is not guaranteed to be zero-initialized. Relying on 'fresh-window zero init' can lead to undefined behavior or state leakage between runs. Explicitly zero-initialize the local signal slots at the start of the kernel to ensure correctness.

    // Explicitly zero-initialize the local signal slots to prevent garbage values
    // from causing undefined behavior, as device memory is not guaranteed to be zero-initialized.
    const int total_rounds = 2 * (nranks - 1);
    for (int r = 0; r < total_rounds; ++r) {
        for (int peer = 0; peer < nranks; ++peer) {
            signal_base[r * kMaxSupportedRanks + peer] = 0;
        }
    }
    pipe_barrier(PIPE_ALL);
References
  1. Do not assume that allocated shared memory or device memory is zero-initialized. Always explicitly initialize all fields (such as thread/core counts and mapping arrays) to prevent garbage values from causing segmentation faults or undefined behavior.

Comment on lines +49 to +52
__gm__ float *chunks = nullptr;
__gm__ float *exchange = nullptr;
__gm__ int32_t *signal_base = nullptr;
RingBindScratch(scratch, nranks, chunk_elems, chunks, exchange, signal_base);

high

The exchange buffer is completely unused by any remote rank, and copying chunks to it is redundant. Removing the exchange buffer and its associated copies will save 2 * (P - 1) redundant global memory copies (TLOAD + TSTORE) per rank, significantly improving performance and reducing scratch memory usage.

    __gm__ float *chunks = nullptr;
    __gm__ int32_t *signal_base = nullptr;
    RingBindScratch(scratch, nranks, chunk_elems, chunks, signal_base);

Comment on lines +80 to +81
RingCopyChunkGm(exchange, chunks + static_cast<size_t>(send_idx * chunk_elems), chunk_elems, chunkTile);
pipe_barrier(PIPE_ALL);

high

Remove the redundant copy to the unused exchange buffer to optimize performance on the hot path.

Comment on lines +108 to +109
RingCopyChunkGm(exchange, chunks + static_cast<size_t>(send_idx * chunk_elems), chunk_elems, chunkTile);
pipe_barrier(PIPE_ALL);

high

Remove the redundant copy to the unused exchange buffer to optimize performance on the hot path.

Comment on lines +99 to +106
AICORE inline void RingBindScratch(
    __gm__ float *scratch, int nranks, int chunk_elems, __gm__ float *&chunks, __gm__ float *&exchange,
    __gm__ int32_t *&signal_base
) {
    chunks = scratch;
    exchange = scratch + static_cast<size_t>(nranks * chunk_elems);
    signal_base = reinterpret_cast<__gm__ int32_t *>(scratch + static_cast<size_t>((nranks + 1) * chunk_elems));
}

high

Simplify RingBindScratch by removing the unused exchange buffer parameter and calculation.

AICORE inline void RingBindScratch(
    __gm__ float *scratch, int nranks, int chunk_elems, __gm__ float *&chunks,
    __gm__ int32_t *&signal_base
) {
    chunks = scratch;
    signal_base = reinterpret_cast<__gm__ int32_t *>(scratch + static_cast<size_t>(nranks * chunk_elems));
}
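
For readers comparing the two versions of RingBindScratch, the layouts can be written out as element offsets. The sketch below is a host-side illustration in Python, assuming 4-byte floats and 4-byte int32 signals so that one offset unit covers both; `scratch_layout` is a hypothetical helper, not part of the PR.

```python
def scratch_layout(nranks, chunk_elems, with_exchange=True):
    """Return the offset of each scratch region, in 4-byte elements.

    with_exchange=True mirrors the original binding: P chunk slots, one
    exchange slot, then the signal rows. with_exchange=False mirrors the
    suggested binding, where the signal rows follow the P chunk slots directly.
    """
    chunks_off = 0
    exchange_off = nranks * chunk_elems if with_exchange else None
    extra_slots = 1 if with_exchange else 0
    signals_off = (nranks + extra_slots) * chunk_elems
    return {"chunks": chunks_off, "exchange": exchange_off, "signals": signals_off}


if __name__ == "__main__":
    print(scratch_layout(4, 64))
    print(scratch_layout(4, 64, with_exchange=False))
```

The only difference between the two layouts is where the signal rows start, so host-side sizing and the kernel-side binding have to change together.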

Comment on lines +67 to +69
SCRATCH_FLOAT_ELEMS = (K_MAX_SUPPORTED_RANKS + 1) * CHUNK_MAX
SIGNAL_SLOTS = 2 * (K_MAX_SUPPORTED_RANKS - 1) * K_MAX_SUPPORTED_RANKS
SCRATCH_NBYTES = SCRATCH_FLOAT_ELEMS * 4 + SIGNAL_SLOTS * 4

high

Since the unused exchange buffer has been removed, the scratch float elements can be simplified to just ALLREDUCE_COUNT floats, reducing memory usage.

Suggested change
-SCRATCH_FLOAT_ELEMS = (K_MAX_SUPPORTED_RANKS + 1) * CHUNK_MAX
+SCRATCH_FLOAT_ELEMS = ALLREDUCE_COUNT
 SIGNAL_SLOTS = 2 * (K_MAX_SUPPORTED_RANKS - 1) * K_MAX_SUPPORTED_RANKS
 SCRATCH_NBYTES = SCRATCH_FLOAT_ELEMS * 4 + SIGNAL_SLOTS * 4
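
The byte counts behind this sizing can be checked with a few lines of arithmetic. The constant values below are assumptions, not taken from the diff: K_MAX_SUPPORTED_RANKS = 16 follows the "between 2 and 16" range mentioned in the review, while CHUNK_MAX = 128 and ALLREDUCE_COUNT = 256 are hypothetical values chosen because they reproduce the 10624 B window quoted in the PR summary.

```python
# Assumed constants (not confirmed by the diff shown in this conversation).
K_MAX_SUPPORTED_RANKS = 16  # from the "between 2 and 16" range in the review
CHUNK_MAX = 128             # hypothetical: largest chunk, i.e. ALLREDUCE_COUNT // 2
ALLREDUCE_COUNT = 256       # hypothetical element count

BYTES_PER_SLOT = 4  # float32 data and int32 signals are both 4 bytes


def scratch_nbytes(float_elems):
    """Data slots plus one signal row per round (2 * (P_max - 1) rounds, P_max peers each)."""
    signal_slots = 2 * (K_MAX_SUPPORTED_RANKS - 1) * K_MAX_SUPPORTED_RANKS
    return float_elems * BYTES_PER_SLOT + signal_slots * BYTES_PER_SLOT


with_exchange = scratch_nbytes((K_MAX_SUPPORTED_RANKS + 1) * CHUNK_MAX)
without_exchange = scratch_nbytes(ALLREDUCE_COUNT)

if __name__ == "__main__":
    print(with_exchange, without_exchange)  # 10624 2944
```

Under these assumed values the signal rows account for 1920 B in both cases, and the remaining difference comes entirely from sizing the data region for the worst-case rank count plus one exchange slot.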

Comment on lines +72 to +77
def scratch_float_elems(nranks: int) -> int:
    """Float slots in the HCCL window for this rank count: P chunk slots + 1 exchange."""
    if ALLREDUCE_COUNT % nranks != 0:
        raise ValueError(f"ALLREDUCE_COUNT={ALLREDUCE_COUNT} must divide nranks={nranks}")
    chunk = ALLREDUCE_COUNT // nranks
    return (nranks + 1) * chunk

high

Simplify scratch_float_elems to return ALLREDUCE_COUNT directly since the exchange buffer is removed.

Suggested change
 def scratch_float_elems(nranks: int) -> int:
-    """Float slots in the HCCL window for this rank count: P chunk slots + 1 exchange."""
+    """Float slots in the HCCL window for this rank count: P chunk slots."""
     if ALLREDUCE_COUNT % nranks != 0:
         raise ValueError(f"ALLREDUCE_COUNT={ALLREDUCE_COUNT} must divide nranks={nranks}")
-    chunk = ALLREDUCE_COUNT // nranks
-    return (nranks + 1) * chunk
+    return ALLREDUCE_COUNT


"""

from __future__ import annotations

medium

For Python projects targeting Python 3.9 or higher, PEP 585 generic collections (like list[str]) are fully supported at runtime, so adding from __future__ import annotations is unnecessary.

References
  1. For Python projects targeting Python 3.9 or higher, PEP 585 generic collections (like list[str]) are fully supported at runtime, so adding from __future__ import annotations is unnecessary to prevent runtime errors during module load.

New L3 example separate from mesh allreduce_distributed: stage-in,
(P-1) reduce-scatter and (P-1) allgather ring rounds over HCCL window
chunks with per-round TNOTIFY/TWAIT barriers. Same golden as mesh.
P=2/P=4 pytest; default CLI devices 0-3.
@georgebisbas georgebisbas force-pushed the feat/l3-ring-allreduce-skeleton branch from 75c3152 to 497ae58 Compare June 2, 2026 14:41

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/workers/l3/allreduce_ring_distributed/main.py`:
- Around line 136-145: Validate the device_ids input at the top of run(): check
that device_ids is non-empty and that nranks = len(device_ids) is within the
supported range (e.g., between 2 and 16 as the example expects); if not, raise a
ValueError with a clear message so downstream calls (like scratch_float_elems)
don't hit ZeroDivisionError or unsupported configurations. Add this check
directly in run() before calling scratch_float_elems() or computing window_size,
referencing run() and scratch_float_elems() in the message so the caller can see
which entrypoint enforces the constraint.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: a173f9be-0a2b-4f97-bf62-6560ebde9a86

📥 Commits

Reviewing files that changed from the base of the PR and between d61dee4 and 75c3152.

📒 Files selected for processing (7)
  • examples/workers/l3/README.md
  • examples/workers/l3/allreduce_ring_distributed/__init__.py
  • examples/workers/l3/allreduce_ring_distributed/kernels/aiv/allreduce_ring_common.hpp
  • examples/workers/l3/allreduce_ring_distributed/kernels/aiv/allreduce_ring_kernel.cpp
  • examples/workers/l3/allreduce_ring_distributed/kernels/orchestration/allreduce_ring_orch.cpp
  • examples/workers/l3/allreduce_ring_distributed/main.py
  • examples/workers/l3/allreduce_ring_distributed/test_allreduce.py

Comment on lines +136 to +145
def run(
    device_ids: list[int],
    platform: str = "a2a3",
    pto_isa_commit: str | None = None,
) -> int:
    """Core logic — callable from both CLI and pytest."""
    nranks = len(device_ids)
    float_elems = scratch_float_elems(nranks)
    window_size = max(SCRATCH_NBYTES, 4 * 1024)


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate device_ids inside run(), not just in the CLI parser.

run() is the public entrypoint used by pytest as well as the CLI, so invalid rank counts bypass parse_device_range(). device_ids=[] hits a ZeroDivisionError in scratch_float_elems(), and 1 or >16 devices fall into configs the example explicitly does not support.

Suggested fix
 def run(
     device_ids: list[int],
     platform: str = "a2a3",
     pto_isa_commit: str | None = None,
 ) -> int:
     """Core logic — callable from both CLI and pytest."""
     nranks = len(device_ids)
+    if not (2 <= nranks <= K_MAX_SUPPORTED_RANKS):
+        raise ValueError(
+            f"allreduce_ring_distributed needs between 2 and {K_MAX_SUPPORTED_RANKS} devices, "
+            f"got {nranks} ({device_ids})"
+        )
     float_elems = scratch_float_elems(nranks)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/workers/l3/allreduce_ring_distributed/main.py` around lines 136 -
145, Validate the device_ids input at the top of run(): check that device_ids is
non-empty and that nranks = len(device_ids) is within the supported range (e.g.,
between 2 and 16 as the example expects); if not, raise a ValueError with a
clear message so downstream calls (like scratch_float_elems) don't hit
ZeroDivisionError or unsupported configurations. Add this check directly in
run() before calling scratch_float_elems() or computing window_size, referencing
run() and scratch_float_elems() in the message so the caller can see which
entrypoint enforces the constraint.
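
A standalone version of the requested guard can be exercised without hardware. This is a sketch under stated assumptions: `check_device_ids` is a hypothetical helper name, K_MAX_SUPPORTED_RANKS = 16 follows the range quoted in the review, and ALLREDUCE_COUNT = 256 is a placeholder value. It also folds in the divisibility check from scratch_float_elems(), with the error message reworded to say which quantity must be divisible.

```python
K_MAX_SUPPORTED_RANKS = 16  # assumption: upper bound quoted in the review
ALLREDUCE_COUNT = 256       # assumption: placeholder element count


def check_device_ids(device_ids):
    """Validate the rank count before any sizing maths runs; return nranks."""
    nranks = len(device_ids)
    if not (2 <= nranks <= K_MAX_SUPPORTED_RANKS):
        raise ValueError(
            f"allreduce_ring_distributed needs between 2 and {K_MAX_SUPPORTED_RANKS} devices, "
            f"got {nranks} ({device_ids})"
        )
    if ALLREDUCE_COUNT % nranks != 0:
        raise ValueError(
            f"ALLREDUCE_COUNT={ALLREDUCE_COUNT} must be divisible by nranks={nranks}"
        )
    return nranks


if __name__ == "__main__":
    print(check_device_ids([5, 6]))
    for bad in ([], [0], [0, 1, 2]):
        try:
            check_device_ids(bad)
        except ValueError as exc:
            print("rejected:", exc)
```

Because the range check runs first, an empty list is rejected with a ValueError before any modulo or division by the rank count can raise ZeroDivisionError.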

Drop RingZeroSignals (per-round barrier rows used once; zeroing raced
peer notify and caused AICPU 507018 timeout). Recv via left neighbour
chunks[] after barrier, not local exchange mirror (max golden diff 99
on second chunk). Size scratch CommBufferSpec to (P+1)*chunk elements.

Align ring example with mesh L3 style: single allreduce_ring_kernel.cpp
(no common header), phase banners, and matching orch/main.py comments.
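
The ordering hazard that motivated dropping RingZeroSignals can be illustrated with a toy replay of events on a single signal slot. This is a host-side model only: it assumes a notify increments the slot and a wait needs the slot to be at least 1, which may not match the real TNOTIFY/TWAIT semantics, and `wait_satisfied` is a hypothetical name.

```python
def wait_satisfied(events, initial=0):
    """Replay 'zero' and 'notify' events on one signal slot, in arrival order.

    Returns True if a waiter that needs slot >= 1 would be released afterwards.
    'zero' is the local rank clearing its slot; 'notify' is the peer's signal.
    """
    slot = initial
    for event in events:
        if event == "zero":
            slot = 0
        elif event == "notify":
            slot += 1  # assumed notify semantics: increment
        else:
            raise ValueError(f"unknown event: {event}")
    return slot >= 1


if __name__ == "__main__":
    # Local zeroing lands before the peer's notify: the wait completes.
    print(wait_satisfied(["zero", "notify"]))
    # Peer's notify lands first and is then wiped: the wait never completes.
    print(wait_satisfied(["notify", "zero"]))
    # No zeroing at all, relying on a fresh (zeroed) window: the wait completes.
    print(wait_satisfied(["notify"]))
```

In this model the second ordering loses the notification, which is consistent with the timeout described in the commit message. The third ordering only works if the slot really starts at zero, which is the assumption the earlier review comment questioned.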
@georgebisbas georgebisbas force-pushed the feat/l3-ring-allreduce-skeleton branch from 497ae58 to 690efbc Compare June 2, 2026 15:12
Mirror parse_device_range() so pytest/CLI callers cannot pass an empty
list or unsupported rank count into scratch_float_elems().