From 879bb78e3a41b9750708605919514a8150e8b911 Mon Sep 17 00:00:00 2001 From: Xinhao Wei Date: Tue, 17 Mar 2026 09:30:35 +0000 Subject: [PATCH 1/5] fused_router: keep low-risk CUDA optimizations - restore forward hot paths to baseline behavior for topk/scores kernels\n- keep warp-level reduction helper for backward normalization\n- handle empty expert_bias safely in fused topk forward Signed-off-by: Xinhao Wei --- .../fused_score_for_moe_aux_loss.cu | 6 +--- .../fused_topk_with_score_function.cu | 33 +++++++++++-------- .../common/fused_router/utils.h | 7 ++++ 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 4eb4240d7c..675f071aba 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -270,11 +270,7 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_sum_Output_x_Grad += local_grad[i] * act_output[i]; } - // Warp reduce the sum - for (int s = 16; s > 0; s /= 2) { - local_sum_Output_x_Grad += __shfl_xor_sync(0xffffffff, local_sum_Output_x_Grad, s); - } - CompType sum_Output_x_Grad = local_sum_Output_x_Grad; + CompType sum_Output_x_Grad = warp_reduce_sum_float(local_sum_Output_x_Grad); // In-place update for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_grad[i] = local_grad[i] / (sum_fwd_input + epsilon) - diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 9f7a830546..9b8d8b9299 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -284,15 +284,24 @@ void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, Tensor intermediate_output, cudaStream_t stream) { TE_ROUTER_PROBS_TYPE_SWITCH_ALL( logits.data.dtype, DataType, - TE_ROUTER_PROBS_TYPE_SWITCH_ALL( - expert_bias.data.dtype, BiasType, - fused_topk_with_score_function_forward_kernel_launcher( - reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, - use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, - reinterpret_cast(expert_bias.data.dptr), - reinterpret_cast(probs.data.dptr), - reinterpret_cast(routing_map.data.dptr), - reinterpret_cast(intermediate_output.data.dptr), stream););); + if (expert_bias.has_data()) { + TE_ROUTER_PROBS_TYPE_SWITCH_ALL( + expert_bias.data.dtype, BiasType, + fused_topk_with_score_function_forward_kernel_launcher( + reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, + use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, + reinterpret_cast(expert_bias.data.dptr), + reinterpret_cast(probs.data.dptr), + reinterpret_cast(routing_map.data.dptr), + reinterpret_cast(intermediate_output.data.dptr), stream);); + } else { + fused_topk_with_score_function_forward_kernel_launcher( + reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, + use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, nullptr, + reinterpret_cast(probs.data.dptr), + reinterpret_cast(routing_map.data.dptr), + reinterpret_cast(intermediate_output.data.dptr), stream); + }); } template @@ -399,11 +408,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( local_sum_Output_x_Grad += local_grad[i] * act_output[i]; } } - // Warp reduce the sum - for (int s = 16; s > 0; s /= 2) { - local_sum_Output_x_Grad += __shfl_xor_sync(0xffffffff, local_sum_Output_x_Grad, s); - } - CompType sum_Output_x_Grad = local_sum_Output_x_Grad; + CompType sum_Output_x_Grad = warp_reduce_sum_float(local_sum_Output_x_Grad); // In-place update for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { if (local_routing_map[i]) { diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 08ad3d16a6..48c33b8040 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -53,6 +53,13 @@ enum ReduceFuncType { MAX, }; +__device__ inline float warp_reduce_sum_float(float val) { + for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) { + val += __shfl_down_sync(0xffffffff, val, offset); + } + return __shfl_sync(0xffffffff, val, 0); +} + template __device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncType type, int lane_id) { From 24cdc4192f12b41b932455ea1871ed85c223dfcd Mon Sep 17 00:00:00 2001 From: Xinhao Wei Date: Tue, 17 Mar 2026 11:17:38 +0000 Subject: [PATCH 2/5] fused_router: specialize naive_topk_and_mask for topk<=8 Add a lightweight register-based small-k path and keep the generic fallback for compatibility. Signed-off-by: Xinhao Wei --- .../common/fused_router/utils.h | 85 ++++++++++++++++++- 1 file changed, 83 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 48c33b8040..948dbc7fa4 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -436,8 +436,56 @@ __device__ inline void radix_topk_and_mask(CompType *scores, int data_size, int __syncwarp(); } -__device__ inline void naive_topk_and_mask(CompType *scores, int data_size, int topk, - int *topk_indices, CompType *topk_scores, int lane_id) { +template +__device__ inline void naive_topk_and_mask_smallk(CompType *scores, int data_size, int *topk_indices, + CompType *topk_scores, int lane_id) { + static_assert(K > 0 && K <= 8, "K must be in [1, 8]"); + int selected[K]; +#pragma unroll + for (int i = 0; i < K; ++i) { + selected[i] = -1; + } + +#pragma unroll + for (int k = 0; k < K; ++k) { + CompType val = -std::numeric_limits::infinity(); + int index = (lane_id < data_size) ? lane_id : 0; + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + bool masked = false; +#pragma unroll + for (int j = 0; j < k; ++j) { + masked |= (selected[j] == i); + } + if (masked) continue; + CompType cur_val = scores[i]; + if (cur_val > val) { + val = cur_val; + index = i; + } + } + for (int s = kThreadsPerWarp / 2; s > 0; s /= 2) { + auto shuffled_val = __shfl_xor_sync(0xffffffff, val, s); + auto shuffled_index = __shfl_xor_sync(0xffffffff, index, s); + if (shuffled_val > val) { + val = shuffled_val; + index = shuffled_index; + } + } + + CompType chosen_val = __shfl_sync(0xffffffff, val, 0); + int chosen_index = __shfl_sync(0xffffffff, index, 0); + if (lane_id == 0) { + topk_indices[k] = chosen_index; + topk_scores[k] = chosen_val; + } + selected[k] = chosen_index; + __syncwarp(); + } +} + +__device__ inline void naive_topk_and_mask_generic(CompType *scores, int data_size, int topk, + int *topk_indices, CompType *topk_scores, + int lane_id) { // Check if the index is masked by the later iteration auto is_masked = [&topk_indices](int k, int index) { if (k == 0) return false; @@ -482,6 +530,39 @@ __device__ inline void naive_topk_and_mask(CompType *scores, int data_size, int } } +__device__ inline void naive_topk_and_mask(CompType *scores, int data_size, int topk, + int *topk_indices, CompType *topk_scores, int lane_id) { + switch (topk) { + case 1: + naive_topk_and_mask_smallk<1>(scores, data_size, topk_indices, topk_scores, lane_id); + break; + case 2: + naive_topk_and_mask_smallk<2>(scores, data_size, topk_indices, topk_scores, lane_id); + break; + case 3: + naive_topk_and_mask_smallk<3>(scores, data_size, topk_indices, topk_scores, lane_id); + break; + case 4: + naive_topk_and_mask_smallk<4>(scores, data_size, topk_indices, topk_scores, lane_id); + break; + case 5: + naive_topk_and_mask_smallk<5>(scores, data_size, topk_indices, topk_scores, lane_id); + break; + case 6: + naive_topk_and_mask_smallk<6>(scores, data_size, topk_indices, topk_scores, lane_id); + break; + case 7: + naive_topk_and_mask_smallk<7>(scores, data_size, topk_indices, topk_scores, lane_id); + break; + case 8: + naive_topk_and_mask_smallk<8>(scores, data_size, topk_indices, topk_scores, lane_id); + break; + default: + naive_topk_and_mask_generic(scores, data_size, topk, topk_indices, topk_scores, lane_id); + break; + } +} + template __device__ __forceinline__ void topk_and_mask(CompType *scores, int data_size, int topk, int *topk_indices, CompType *topk_scores, From 3ad3ad216ed72d2dfc38d109247765353e292536 Mon Sep 17 00:00:00 2001 From: Xinhao Wei Date: Wed, 18 Mar 2026 06:47:02 +0000 Subject: [PATCH 3/5] tests: add fused router performance benchmark Add CUDA perf benchmark for fused topk router, aux-loss score, and moe aux-loss kernels. Signed-off-by: Xinhao Wei --- tests/pytorch/test_fused_router_perf.py | 359 ++++++++++++++++++++++++ 1 file changed, 359 insertions(+) create mode 100644 tests/pytorch/test_fused_router_perf.py diff --git a/tests/pytorch/test_fused_router_perf.py b/tests/pytorch/test_fused_router_perf.py new file mode 100644 index 0000000000..e7c9bb7f2b --- /dev/null +++ b/tests/pytorch/test_fused_router_perf.py @@ -0,0 +1,359 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +from typing import Callable, Optional, Tuple + +import pytest +import torch + +from transformer_engine.pytorch.router import ( + fused_compute_score_for_moe_aux_loss, + fused_moe_aux_loss, + fused_topk_with_score_function, +) + + +seed = 42 +torch.manual_seed(seed) +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available() or os.getenv("TE_RUN_PERF_TESTS", "0") != "1", + reason="Benchmark test - run with: TE_RUN_PERF_TESTS=1 pytest tests/pytorch/test_fused_router_perf.py", +) + + +def _benchmark_cuda_kernel(fn: Callable[[], object], warmup: int = 20, iters: int = 100) -> float: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + start_event.record() + for _ in range(iters): + fn() + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / iters + + +def group_limited_topk( + scores: torch.Tensor, + topk: int, + num_tokens: int, + num_experts: int, + num_groups: int, + group_topk: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + group_scores = ( + scores.view(num_tokens, num_groups, -1).topk(topk // group_topk, dim=-1)[0].sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_tokens, num_groups, num_experts // num_groups) + .reshape(num_tokens, -1) + ) + masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) + probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1) + return probs, top_indices + + +def topk_softmax_sigmoid_pytorch( + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool = False, + num_groups: Optional[int] = None, + group_topk: Optional[int] = None, + scaling_factor: Optional[float] = None, + score_function: str = "softmax", + expert_bias: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + num_tokens, num_experts = logits.shape + + def compute_topk(scores, topk_value, num_groups_value=None, group_topk_value=None): + if group_topk_value: + assert num_groups_value is not None + return group_limited_topk( + scores=scores, + topk=topk_value, + num_tokens=num_tokens, + num_experts=num_experts, + num_groups=num_groups_value, + group_topk=group_topk_value, + ) + return torch.topk(scores, k=topk_value, dim=1) + + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) + probs, top_indices = compute_topk(scores, topk, num_groups, group_topk) + else: + scores, top_indices = compute_topk(logits, topk, num_groups, group_topk) + probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits) + elif score_function == "sigmoid": + scores = torch.sigmoid(logits.float()).type_as(logits) + if expert_bias is not None: + scores_for_routing = scores + expert_bias + _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) + scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits) + else: + scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) + probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + else: + raise ValueError(f"Invalid score_function: {score_function}") + + if scaling_factor: + probs = probs * scaling_factor + + topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs) + topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool() + return topk_masked_gates, topk_map + + +def compute_scores_for_aux_loss_pytorch( + logits: torch.Tensor, topk: int, score_function: str +) -> Tuple[torch.Tensor, torch.Tensor]: + if score_function == "softmax": + scores = torch.softmax(logits, dim=-1, dtype=torch.float32) + elif score_function == "sigmoid": + scores = torch.sigmoid(logits) + scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + else: + raise ValueError(f"Invalid score_function: {score_function}") + + _, top_indices = torch.topk(scores, k=topk, dim=1) + routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool() + return routing_map, scores + + +def aux_loss_pytorch( + probs: torch.Tensor, + tokens_per_expert: torch.Tensor, + total_num_tokens: int, + topk: int, + num_experts: int, + moe_aux_loss_coeff: float, +) -> torch.Tensor: + aggregated_probs_per_expert = probs.sum(dim=0) + return torch.sum(aggregated_probs_per_expert * tokens_per_expert) * ( + num_experts * moe_aux_loss_coeff / (topk * total_num_tokens * total_num_tokens) + ) + + +def _make_router_logits( + dtype: torch.dtype, num_tokens: int, num_experts: int, score_function: str +) -> torch.Tensor: + if score_function == "sigmoid": + offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 + logits = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 + return logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) + + logits = ( + torch.arange( + -num_tokens * num_experts // 2, + num_tokens * num_experts // 2, + device="cuda", + dtype=dtype, + ) + * 1e-4 + ) + return logits.view(num_tokens, num_experts) + + +def _make_router_bias(num_experts: int) -> torch.Tensor: + bias = torch.arange(num_experts, device="cuda", dtype=torch.float32) * 0.1 + return torch.flip(bias, dims=[0]) + + +def _print_perf_result(case_name: str, torch_ms: float, fused_ms: float) -> None: + speedup = torch_ms / fused_ms + print( + f"{case_name}: torch={torch_ms:.6f} ms, fused={fused_ms:.6f} ms, speedup={speedup:.4f}x" + ) + + +@pytest.mark.parametrize( + "score_function,use_pre_softmax,enable_bias", + [("softmax", False, False), ("sigmoid", False, True)], + ids=["softmax", "sigmoid_with_bias"], +) +def test_fused_topk_router_perf_against_torch( + score_function, use_pre_softmax, enable_bias, record_property +): + dtype = torch.float32 + num_tokens = 4096 + num_experts = 192 + topk = 8 + num_groups = 8 + group_topk = 4 + scaling_factor = 1.2 + + logits = _make_router_logits(dtype, num_tokens, num_experts, score_function) + expert_bias = _make_router_bias(num_experts) if enable_bias else None + + torch_probs, torch_map = topk_softmax_sigmoid_pytorch( + logits=logits, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=expert_bias, + ) + fused_probs, fused_map = fused_topk_with_score_function( + logits=logits, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=expert_bias, + ) + + torch_ms = _benchmark_cuda_kernel( + lambda: topk_softmax_sigmoid_pytorch( + logits=logits, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=expert_bias, + ) + ) + fused_ms = _benchmark_cuda_kernel( + lambda: fused_topk_with_score_function( + logits=logits, + topk=topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=expert_bias, + ) + ) + + record_property("torch_ms", round(torch_ms, 6)) + record_property("fused_ms", round(fused_ms, 6)) + record_property("speedup", round(torch_ms / fused_ms, 6)) + _print_perf_result(f"topk_router[{score_function}]", torch_ms, fused_ms) + + torch.testing.assert_close(torch_probs, fused_probs) + torch.testing.assert_close(torch_map, fused_map) + + +@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) +def test_fused_scores_for_aux_loss_perf_against_torch(score_function, record_property): + dtype = torch.float32 + num_tokens = 8192 + num_experts = 128 + topk = 8 + logits = _make_router_logits(dtype, num_tokens, num_experts, score_function) + + torch_map, torch_scores = compute_scores_for_aux_loss_pytorch( + logits=logits, + topk=topk, + score_function=score_function, + ) + fused_map, fused_scores = fused_compute_score_for_moe_aux_loss( + logits=logits, + topk=topk, + score_function=score_function, + ) + + torch_ms = _benchmark_cuda_kernel( + lambda: compute_scores_for_aux_loss_pytorch( + logits=logits, + topk=topk, + score_function=score_function, + ) + ) + fused_ms = _benchmark_cuda_kernel( + lambda: fused_compute_score_for_moe_aux_loss( + logits=logits, + topk=topk, + score_function=score_function, + ) + ) + + record_property("torch_ms", round(torch_ms, 6)) + record_property("fused_ms", round(fused_ms, 6)) + record_property("speedup", round(torch_ms / fused_ms, 6)) + _print_perf_result(f"scores_for_aux_loss[{score_function}]", torch_ms, fused_ms) + + torch.testing.assert_close(torch_scores, fused_scores) + torch.testing.assert_close(torch_map, fused_map) + + +def test_fused_moe_aux_loss_perf_against_torch(record_property): + dtype = torch.float32 + num_tokens = 8192 + num_experts = 128 + topk = 4 + coeff = 0.01 + + offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 + probs = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 + probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) + probs = probs.view(num_tokens, num_experts) + tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32) + + torch_loss = aux_loss_pytorch( + probs=probs, + tokens_per_expert=tokens_per_expert, + total_num_tokens=num_tokens, + topk=topk, + num_experts=num_experts, + moe_aux_loss_coeff=coeff, + ) + fused_loss = fused_moe_aux_loss( + probs=probs, + tokens_per_expert=tokens_per_expert, + total_num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ) + + torch_ms = _benchmark_cuda_kernel( + lambda: aux_loss_pytorch( + probs=probs, + tokens_per_expert=tokens_per_expert, + total_num_tokens=num_tokens, + topk=topk, + num_experts=num_experts, + moe_aux_loss_coeff=coeff, + ) + ) + fused_ms = _benchmark_cuda_kernel( + lambda: fused_moe_aux_loss( + probs=probs, + tokens_per_expert=tokens_per_expert, + total_num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + coeff=coeff, + ) + ) + + record_property("torch_ms", round(torch_ms, 6)) + record_property("fused_ms", round(fused_ms, 6)) + record_property("speedup", round(torch_ms / fused_ms, 6)) + _print_perf_result("moe_aux_loss", torch_ms, fused_ms) + + torch.testing.assert_close(torch_loss, fused_loss) From f228dd41e8ae70a2c47b75423757c5ee0b610470 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 May 2026 06:21:46 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_fused_router_perf.py | 13 ++++++++----- transformer_engine/common/fused_router/utils.h | 5 +++-- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/pytorch/test_fused_router_perf.py b/tests/pytorch/test_fused_router_perf.py index e7c9bb7f2b..a853b87449 100644 --- a/tests/pytorch/test_fused_router_perf.py +++ b/tests/pytorch/test_fused_router_perf.py @@ -23,7 +23,10 @@ pytestmark = pytest.mark.skipif( not torch.cuda.is_available() or os.getenv("TE_RUN_PERF_TESTS", "0") != "1", - reason="Benchmark test - run with: TE_RUN_PERF_TESTS=1 pytest tests/pytorch/test_fused_router_perf.py", + reason=( + "Benchmark test - run with: TE_RUN_PERF_TESTS=1 pytest" + " tests/pytorch/test_fused_router_perf.py" + ), ) @@ -156,7 +159,9 @@ def _make_router_logits( ) -> torch.Tensor: if score_function == "sigmoid": offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 - logits = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 + logits = ( + torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 + ) return logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) logits = ( @@ -178,9 +183,7 @@ def _make_router_bias(num_experts: int) -> torch.Tensor: def _print_perf_result(case_name: str, torch_ms: float, fused_ms: float) -> None: speedup = torch_ms / fused_ms - print( - f"{case_name}: torch={torch_ms:.6f} ms, fused={fused_ms:.6f} ms, speedup={speedup:.4f}x" - ) + print(f"{case_name}: torch={torch_ms:.6f} ms, fused={fused_ms:.6f} ms, speedup={speedup:.4f}x") @pytest.mark.parametrize( diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 948dbc7fa4..270301d028 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -437,8 +437,9 @@ __device__ inline void radix_topk_and_mask(CompType *scores, int data_size, int } template -__device__ inline void naive_topk_and_mask_smallk(CompType *scores, int data_size, int *topk_indices, - CompType *topk_scores, int lane_id) { +__device__ inline void naive_topk_and_mask_smallk(CompType *scores, int data_size, + int *topk_indices, CompType *topk_scores, + int lane_id) { static_assert(K > 0 && K <= 8, "K must be in [1, 8]"); int selected[K]; #pragma unroll From 4ff7cc9cc294269d1027cc84ee605886a871752d Mon Sep 17 00:00:00 2001 From: Xinhao Wei Date: Thu, 28 May 2026 09:47:37 +0000 Subject: [PATCH 5/5] fused_router: address review feedback Signed-off-by: Xinhao Wei --- tests/pytorch/test_fused_router_perf.py | 17 +++++++++++++---- transformer_engine/common/fused_router/utils.h | 4 +++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_fused_router_perf.py b/tests/pytorch/test_fused_router_perf.py index a853b87449..122d19dd76 100644 --- a/tests/pytorch/test_fused_router_perf.py +++ b/tests/pytorch/test_fused_router_perf.py @@ -15,10 +15,13 @@ ) -seed = 42 -torch.manual_seed(seed) -if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) +SEED = 42 + + +def _set_seed() -> None: + torch.manual_seed(SEED) + if torch.cuda.is_available(): + torch.cuda.manual_seed(SEED) pytestmark = pytest.mark.skipif( @@ -194,6 +197,8 @@ def _print_perf_result(case_name: str, torch_ms: float, fused_ms: float) -> None def test_fused_topk_router_perf_against_torch( score_function, use_pre_softmax, enable_bias, record_property ): + _set_seed() + dtype = torch.float32 num_tokens = 4096 num_experts = 192 @@ -262,6 +267,8 @@ def test_fused_topk_router_perf_against_torch( @pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) def test_fused_scores_for_aux_loss_perf_against_torch(score_function, record_property): + _set_seed() + dtype = torch.float32 num_tokens = 8192 num_experts = 128 @@ -304,6 +311,8 @@ def test_fused_scores_for_aux_loss_perf_against_torch(score_function, record_pro def test_fused_moe_aux_loss_perf_against_torch(record_property): + _set_seed() + dtype = torch.float32 num_tokens = 8192 num_experts = 128 diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 270301d028..09087aff98 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -54,6 +54,8 @@ enum ReduceFuncType { }; __device__ inline float warp_reduce_sum_float(float val) { + // __shfl_down_sync accumulates the total only in lane 0; + // the broadcast below is required so every lane sees the final sum. for (int offset = kThreadsPerWarp / 2; offset > 0; offset /= 2) { val += __shfl_down_sync(0xffffffff, val, offset); } @@ -450,7 +452,7 @@ __device__ inline void naive_topk_and_mask_smallk(CompType *scores, int data_siz #pragma unroll for (int k = 0; k < K; ++k) { CompType val = -std::numeric_limits::infinity(); - int index = (lane_id < data_size) ? lane_id : 0; + int index = (lane_id < data_size) ? lane_id : -1; for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { bool masked = false; #pragma unroll