[fused_router][pytorch] Optimize naive topk path and add perf benchmark #2776

Open
XiaomingFun233 wants to merge 8 commits into NVIDIA:main from XiaomingFun233:pr/fused-router-topk-opt

@XiaomingFun233

Summary

This PR ports and keeps a focused set of CUDA fused-router optimizations that showed consistent gains on the tested workload, while avoiding heavier variants that regressed performance.

1. Add fused-router performance benchmark test

  • Add tests/pytorch/test_fused_router_perf.py.
  • Benchmark coverage:
    • fused_topk_with_score_function
    • fused_compute_score_for_moe_aux_loss
    • fused_moe_aux_loss

2. Keep low-risk fused-router CUDA optimizations

  • transformer_engine/common/fused_router/utils.h
    • Add warp-level sum helper used in backward normalization path.
  • transformer_engine/common/fused_router/fused_topk_with_score_function.cu
    • Use warp-level sum reduction in backward normalization.
    • Add safe expert_bias.has_data() handling in forward to avoid invalid dtype switch when bias is absent.
  • transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu
    • Use warp-level sum reduction in backward normalization.

3. Optimize naive_topk_and_mask for small-k

  • transformer_engine/common/fused_router/utils.h
    • Add lightweight specialization for topk <= 8.
    • Keep generic fallback for compatibility.

Performance (A/B)

Measured with:

  • TE_RUN_PERF_TESTS=1 pytest -q tests/pytorch/test_fused_router_perf.py -s

Before

  • topk_router[softmax]: fused 0.029562 ms, speedup 8.3067x
  • topk_router[sigmoid]: fused 0.030138 ms, speedup 7.2715x
  • scores_for_aux_loss[softmax]: fused 0.026183 ms, speedup 3.8721x
  • scores_for_aux_loss[sigmoid]: fused 0.025872 ms, speedup 3.8892x
  • moe_aux_loss: fused 0.015680 ms, speedup 1.8884x

After

  • topk_router[softmax]: fused 0.022384 ms, speedup 11.1324x
  • topk_router[sigmoid]: fused 0.022840 ms, speedup 9.7714x
  • scores_for_aux_loss[softmax]: fused 0.017230 ms, speedup 5.9707x
  • scores_for_aux_loss[sigmoid]: fused 0.017049 ms, speedup 6.0205x
  • moe_aux_loss: fused 0.015412 ms, speedup 1.8424x
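
As a rough cross-check of the figures above, the Python sketch below derives two quantities from the reported numbers only: the fused-kernel time ratio (before / after) and the reference time implied by `fused_ms * speedup`. It assumes the reported speedup is reference time divided by fused time, which the PR text does not state explicitly; the dictionary and function names are illustrative.

```python
# Figures copied from the Before/After lists: (fused time in ms, reported speedup).
before = {
    "topk_router[softmax]": (0.029562, 8.3067),
    "topk_router[sigmoid]": (0.030138, 7.2715),
    "scores_for_aux_loss[softmax]": (0.026183, 3.8721),
    "scores_for_aux_loss[sigmoid]": (0.025872, 3.8892),
    "moe_aux_loss": (0.015680, 1.8884),
}
after = {
    "topk_router[softmax]": (0.022384, 11.1324),
    "topk_router[sigmoid]": (0.022840, 9.7714),
    "scores_for_aux_loss[softmax]": (0.017230, 5.9707),
    "scores_for_aux_loss[sigmoid]": (0.017049, 6.0205),
    "moe_aux_loss": (0.015412, 1.8424),
}


def fused_time_ratio(name):
    """How many times faster the fused kernel itself became (before / after)."""
    return before[name][0] / after[name][0]


def implied_reference_ms(table, name):
    """Reference time implied by a table row, ASSUMING speedup = reference_ms / fused_ms."""
    fused_ms, speedup = table[name]
    return fused_ms * speedup


for name in before:
    print(
        f"{name}: fused kernel {fused_time_ratio(name):.3f}x faster, "
        f"implied reference {implied_reference_ms(before, name):.4f} ms -> "
        f"{implied_reference_ms(after, name):.4f} ms"
    )
```

Under that assumption, the moe_aux_loss fused time is slightly lower after the change even though its reported speedup is lower, which would be consistent with run-to-run variation in the reference measurement rather than a kernel regression.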

Notes

  • This PR intentionally avoids the larger full-port variant that previously regressed topk_router/scores_for_aux_loss performance on this setup.

@XiaomingFun233
Author

Tested on H200, CUDA version 13.0.

@greptile-apps
Contributor

greptile-apps Bot commented Mar 18, 2026

Greptile Summary

This PR adds a performance benchmark test suite for fused-router kernels and applies three focused CUDA optimizations: a warp_reduce_sum_float helper replacing XOR-butterfly reductions in two backward kernels, a safe expert_bias.has_data() guard in the forward dispatch to avoid accessing uninitialized dtype when no bias is present, and a template-specialized naive_topk_and_mask_smallk<K> path for topk ≤ 8 that enables compile-time loop unrolling of the masking inner loop.

  • Benchmark test (test_fused_router_perf.py): Covers all three fused-router entry points with warmup, timing, and numerical correctness checks; gated behind TE_RUN_PERF_TESTS=1 so it is skipped in normal CI.
  • expert_bias null safety (fused_topk_with_score_function.cu): Wraps the inner BiasType switch in a has_data() check, passing nullptr to the kernel template when no bias tensor is provided — fixing a latent crash when the bias path was absent.
  • Small-K topk specialization (utils.h): Adds naive_topk_and_mask_smallk<K> with compile-time-unrolled masking using a register selected[] array, dispatched via a switch(topk) in naive_topk_and_mask; the generic fallback is preserved for topk > 8.
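
The control flow of the expert_bias guard described above can be sketched in Python. This is only an illustration of "check for data before switching on dtype", not the CUDA dispatch code from the PR: `BiasTensor`, `launch_kernel`, `forward`, and the dtype names are hypothetical stand-ins.

```python
from dataclasses import dataclass
from typing import Optional, Sequence

# Hypothetical set of bias dtypes the inner switch knows how to handle.
SUPPORTED_BIAS_DTYPES = {"float32", "bfloat16"}


@dataclass
class BiasTensor:
    """Hypothetical stand-in for the bias argument; dtype is meaningless when data is None."""
    data: Optional[Sequence[float]] = None
    dtype: Optional[str] = None

    def has_data(self) -> bool:
        return self.data is not None


def launch_kernel(scores, bias, bias_type):
    """Stand-in for the kernel launch: adds the bias when one is given."""
    if bias is None:
        return list(scores), "no-bias"
    return [s + b for s, b in zip(scores, bias)], bias_type


def forward(scores, expert_bias):
    if expert_bias.has_data():
        # Only inspect the dtype when a bias tensor is actually present.
        if expert_bias.dtype not in SUPPORTED_BIAS_DTYPES:
            raise TypeError(f"unsupported bias dtype: {expert_bias.dtype}")
        return launch_kernel(scores, expert_bias.data, expert_bias.dtype)
    # No bias: skip the dtype switch entirely and pass a null bias.
    return launch_kernel(scores, None, None)
```

Calling `forward([1.0, 2.0], BiasTensor())` takes the no-bias branch without ever reading `dtype`, which is the property the guard is meant to provide.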

Confidence Score: 5/5

Safe to merge; all three CUDA changes are logically correct and the benchmark gating ensures it does not affect normal CI.

The expert_bias.has_data() guard is a genuine correctness fix for a previously latent crash path. The warp_reduce_sum_float refactor preserves the all-lanes broadcast semantics of the original XOR butterfly, and the naive_topk_and_mask_smallk dispatch produces identical outputs to the generic path. No shared-memory visibility, race, or masking correctness issues were found.

No files require special attention; minor clarity gaps in utils.h were already flagged in earlier review threads.

Important Files Changed

Filename | Overview
transformer_engine/common/fused_router/utils.h | Adds warp_reduce_sum_float helper and a new naive_topk_and_mask_smallk template specialization dispatched from the wrapper naive_topk_and_mask; logic is correct but has minor clarity gaps already noted in previous review threads.
transformer_engine/common/fused_router/fused_topk_with_score_function.cu | Forward dispatch now correctly guards the inner BiasType dtype switch behind expert_bias.has_data(), and backward normalization adopts the new warp_reduce_sum_float helper; both changes are correct and safe.
transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu | Backward normalization reduction replaced with warp_reduce_sum_float; the broadcast-to-all-lanes semantics are preserved and all downstream uses of sum_Output_x_Grad remain correct.
tests/pytorch/test_fused_router_perf.py | New benchmark suite with correct warmup/timing pattern using CUDA events; gated behind an environment variable. Numerical correctness is verified against PyTorch reference implementations.
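
The benchmark file itself is not shown in this conversation, so the following is only a host-side sketch of the warmup-then-time pattern the overview mentions. It uses `time.perf_counter` as a stand-in for CUDA events so that it runs without a GPU; real GPU timing would also need device synchronization, which is not modeled here. `benchmark_ms`, `speedup`, and `work` are hypothetical names.

```python
import time


def benchmark_ms(fn, warmup=10, iters=100):
    """Average wall-clock time per call in milliseconds, after untimed warmup calls."""
    for _ in range(warmup):
        fn()  # warmup: excluded from the measurement
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    elapsed_s = time.perf_counter() - start
    return elapsed_s * 1000.0 / iters


def speedup(reference_ms, candidate_ms):
    """Ratio of reference time to candidate time (larger means the candidate is faster)."""
    return reference_ms / candidate_ms


calls = {"n": 0}


def work():
    """Toy workload that also counts how often it was invoked."""
    calls["n"] += 1
    return sum(i * i for i in range(1000))


per_call_ms = benchmark_ms(work, warmup=5, iters=20)
print(f"work: {per_call_ms:.6f} ms per call over 20 timed iterations")
```

Keeping the helper free of assertions on the measured time matches the reviewer's later request that the benchmark stay available for manual runs without perf-style assertions.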

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["naive_topk_and_mask()"] --> B{topk}
    B -->|1 to 8| C["naive_topk_and_mask_smallk K\nCompile-time unrolled k-loop\nRegister-based selected array\nXOR butterfly max reduction"]
    B -->|greater than 8| D["naive_topk_and_mask_generic\nRuntime loop over k\nShared-memory topk_indices mask"]
    E["fused_topk_with_score_function_forward"] --> F{"expert_bias.has_data()"}
    F -->|Yes| G["Type-switch on bias dtype\nLaunch kernel with BiasType ptr"]
    F -->|No| H["Launch kernel with bias=nullptr"]
    I["Backward normalization"] --> J["warp_reduce_sum_float\n__shfl_down_sync to lane 0\n__shfl_sync broadcast to all lanes"]

Reviews (6): Last reviewed commit: "Merge branch 'main' into pr/fused-router..."

Comment thread tests/pytorch/test_fused_router_perf.py
Comment on lines +41 to +46
__device__ inline float warp_reduce_sum_float(float val) {
  for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return __shfl_sync(0xffffffff, val, 0);
}
Contributor

P2 Broadcast from lane 0 should be documented

__shfl_down_sync accumulates the sum only in lane 0 after all steps — unlike the __shfl_xor_sync butterfly approach (used elsewhere in this file) which gives the correct sum to every lane simultaneously. The subsequent __shfl_sync(…, 0) is therefore load-bearing for correctness, not just an optimisation.

Adding a short comment here prevents a future reader from accidentally removing it thinking it's redundant:

Suggested change
__device__ inline float warp_reduce_sum_float(float val) {
  // __shfl_down_sync accumulates the total only in lane 0;
  // the broadcast below is required for all lanes to see the result.
  for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return __shfl_sync(0xffffffff, val, 0);
}
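
The difference between the two reduction styles can be checked on the host with a small Python model of a 32-lane warp. This is a simulation, not CUDA: it assumes the usual shuffle semantics, where a lane whose `__shfl_down_sync` source falls outside the warp reads its own value, and the function names are illustrative.

```python
WARP = 32  # simulated warp width (kThreadsPerWarp in the excerpt above)


def reduce_shfl_down(vals):
    """Model of the __shfl_down_sync loop, WITHOUT the final broadcast."""
    vals = list(vals)
    offset = WARP // 2
    while offset > 0:
        # A lane whose source (lane + offset) is outside the warp reads its own value.
        src = [vals[lane + offset] if lane + offset < WARP else vals[lane]
               for lane in range(WARP)]
        vals = [v + s for v, s in zip(vals, src)]
        offset //= 2
    return vals


def reduce_shfl_xor(vals):
    """Model of an XOR-butterfly sum: every lane exchanges with lane ^ offset."""
    vals = list(vals)
    offset = WARP // 2
    while offset > 0:
        src = [vals[lane ^ offset] for lane in range(WARP)]
        vals = [v + s for v, s in zip(vals, src)]
        offset //= 2
    return vals


lane_values = list(range(WARP))  # integers, so the sums are exact
total = sum(lane_values)

down = reduce_shfl_down(lane_values)
xor = reduce_shfl_xor(lane_values)
broadcast = [down[0]] * WARP  # models __shfl_sync(..., val, 0)

print("lane 0 after shfl_down:", down[0], "expected total:", total)
print("lanes with the total after shfl_down:", sum(v == total for v in down))
print("lanes with the total after shfl_xor:", sum(v == total for v in xor))
```

In this model only lane 0 is guaranteed to hold the full sum after the down-shuffle loop, so the broadcast step is what makes the helper match the all-lanes result of the butterfly.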

Comment on lines +224 to +248
#pragma unroll
  for (int k = 0; k < K; ++k) {
    CompType val = -std::numeric_limits<CompType>::infinity();
    int index = (lane_id < data_size) ? lane_id : 0;
    for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
      bool masked = false;
#pragma unroll
      for (int j = 0; j < k; ++j) {
        masked |= (selected[j] == i);
      }
      if (masked) continue;
      CompType cur_val = scores[i];
      if (cur_val > val) {
        val = cur_val;
        index = i;
      }
    }
    for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) {
      auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s);
      auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s);
      if (shuffled_val > val) {
        val = shuffled_val;
        index = shuffled_index;
      }
    }
Contributor

P2 OOB thread index initialised to 0 could shadow earlier selections

For threads where lane_id >= data_size, the inner loop body never executes, so val stays at -inf and index is set to 0:

int index = (lane_id < data_size) ? lane_id : 0;

0 is a valid data index that may already have been placed in selected by a previous k iteration. During the XOR-reduction phase, these OOB threads participate (they shuffle -inf values, which can never win the shuffled_val > val comparison), so the final chosen_index remains correct. However, after the broadcast:

selected[k] = chosen_index;

every thread — including OOB ones — writes chosen_index to their register copy of selected, keeping all threads in sync. The algorithm is therefore correct, but initialising the fallback index to a sentinel value (e.g., -1 or data_size - 1) would make the intent clearer and avoid confusion with a real element:

Suggested change
int index = (lane_id < data_size) ? lane_id : -1; // -1: sentinel for out-of-range lane

This is purely a readability / defensive-programming concern given that invalid index values can never propagate to the output.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
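
The argument in this thread, that the fallback index of an out-of-range lane never reaches the output, can be exercised with a host-side Python model of the excerpt. It is a simulation under stated assumptions: scores are distinct (ties are not modeled), the post-reduction broadcast is modeled by reading lane 0, and `smallk_topk` and `oob_index` are hypothetical names.

```python
import math

WARP = 32  # simulated warp width (kThreadsPerWarp in the excerpt)


def smallk_topk(scores, k, oob_index=-1):
    """Pick the k largest scores one at a time, masking earlier picks.

    oob_index is the fallback index given to lanes with lane_id >= len(scores).
    """
    n = len(scores)
    selected = []  # models the per-thread register array selected[]
    for _ in range(k):
        vals = [-math.inf] * WARP
        idxs = [lane if lane < n else oob_index for lane in range(WARP)]
        # Strided scan: each lane visits lane, lane + WARP, ... and skips masked entries.
        for lane in range(WARP):
            for i in range(lane, n, WARP):
                if i in selected:
                    continue
                if scores[i] > vals[lane]:
                    vals[lane], idxs[lane] = scores[i], i
        # XOR-butterfly max reduction carrying (value, index) pairs.
        s = WARP // 2
        while s > 0:
            partner_vals = [vals[lane ^ s] for lane in range(WARP)]
            partner_idxs = [idxs[lane ^ s] for lane in range(WARP)]
            for lane in range(WARP):
                if partner_vals[lane] > vals[lane]:
                    vals[lane], idxs[lane] = partner_vals[lane], partner_idxs[lane]
            s //= 2
        selected.append(idxs[0])  # models broadcasting the chosen index to all lanes
    return selected


small = [0.1, 0.9, 0.3, 0.7, 0.5, 0.2]    # fewer scores than lanes
large = [(i * 7) % 41 for i in range(40)]  # more scores than lanes, all distinct

print(smallk_topk(small, 3, oob_index=-1))
print(smallk_topk(small, 3, oob_index=0))
```

With distinct scores, both fallback values give the same selection in this model, which agrees with the comment's conclusion that the sentinel is a readability change rather than a behavioral one.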

Comment on lines +17 to +21

seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
Contributor

P2 Module-level seed is fragile for tokens_per_expert randomness

The random seed is set once at import time. If pytest collects or runs other tests before this module's tests execute, the global random state will have advanced and test_fused_moe_aux_loss_perf_against_torch (which calls torch.randint for tokens_per_expert) will use an unknown seed. While the numerical correctness check in that test (torch.testing.assert_close(torch_loss, fused_loss)) passes regardless of the specific random values, reproducible benchmarks are easier to debug.

Consider moving the seed setup into each individual test function, or using a pytest fixture to ensure a consistent state per test.
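
The effect described here can be reproduced without a GPU. The sketch below uses Python's `random` module as a stand-in for the torch global RNG (an assumption made to keep it dependency-free); in the real test file the reseeding calls would be the `torch.manual_seed` / `torch.cuda.manual_seed` lines from the excerpt, moved into each test or a fixture. `draw_tokens_per_expert` is a hypothetical helper.

```python
import random

SEED = 42


def draw_tokens_per_expert(num_experts=4, reseed=False):
    """Draw per-expert token counts; optionally reseed first, as a per-test setup would."""
    if reseed:
        random.seed(SEED)
    return [random.randint(0, 1024) for _ in range(num_experts)]


# Module-level seeding: set once, then other code advances the global state.
random.seed(SEED)
first = draw_tokens_per_expert()
random.random()  # stands in for an unrelated test consuming random numbers
second = draw_tokens_per_expert()

# Per-test seeding: the draw no longer depends on what ran earlier.
stable_a = draw_tokens_per_expert(reseed=True)
random.random()
stable_b = draw_tokens_per_expert(reseed=True)

print("module-level seed, same draw twice:", first == second)
print("per-test seed, same draw twice:", stable_a == stable_b)
```

Reseeding inside the test, or in a fixture that runs before each test, makes the inputs independent of collection and execution order.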

@denera self-requested a review April 14, 2026 19:41
Collaborator

@denera left a comment

@XiaomingFun233 I went through the Greptile comments and I agree with the suggested changes except for the speedup assertion — we should keep the benchmark test in the suite for manual verification but our CI pipeline is for functional testing, not benchmarking, so the assertion does not make sense. Could you address the remaining issues and rebase the branch on latest TE/main? We can launch the CI on our end for testing.

Please also check out the contributing guidelines, particularly regarding the sign-off for your commits and the license information that needs to be added to the source files.

Thanks!

@XiaomingFun233
Author

OK, I will complete this work.

- restore forward hot paths to baseline behavior for topk/scores kernels
- keep warp-level reduction helper for backward normalization
- handle empty expert_bias safely in fused topk forward

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>
Add a lightweight register-based small-k path and keep the generic fallback for compatibility.

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>
Add CUDA perf benchmark for fused topk router, aux-loss score, and moe aux-loss kernels.

Signed-off-by: Xinhao Wei <xiaomingchinafun@outlook.com>
@XiaomingFun233 force-pushed the pr/fused-router-topk-opt branch from 066b9ea to 3ad3ad2 May 15, 2026 06:17
@XiaomingFun233
Author

Addressed the remaining issues from the review.

  • Kept the benchmark test for manual verification and removed perf-style assertions.
  • Rebased the branch onto the latest TE/main.
  • Added sign-off to the commits.

Please let me know if there is anything else you would like me to adjust.

@XiaomingFun233
Author

Please check out this new change, @denera.

@github-actions (bot) added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) May 22, 2026
@XiaomingFun233
Author

XiaomingFun233 commented May 22, 2026

@hartsock @tabo, please review all these new changes.
