Skip to content

[PyTorch] Make modules.GroupedLinear graph-safe#3038

Open
yaox12 wants to merge 2 commits into
NVIDIA:mainfrom
yaox12:xiny/enable-grouped-quantize-cublaslt
Open

[PyTorch] Make modules.GroupedLinear graph-safe#3038
yaox12 wants to merge 2 commits into
NVIDIA:mainfrom
yaox12:xiny/enable-grouped-quantize-cublaslt

Conversation

@yaox12
Copy link
Copy Markdown
Member

@yaox12 yaox12 commented May 22, 2026

Description

  • Enable grouped quantization and cuBLASLt grouped gemm for modules.GroupedLinear to benefit cases where cuteDSL fused grouped gemm is not available.

    1. Reduce CPU overhead by reducing number of kernels.
    2. Be CUDA-Graph-safe.
    3. Improve kernel performance.
  • Move grouped gemm and grouped linear related tests to a standalone file.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR introduces a new GroupedTensor-backed code path in modules.GroupedLinear that uses cuBLASLt grouped GEMM on SM100+ hardware, making the forward and backward graph-safe when m_splits is provided as a CUDA tensor. The new path is off by default (NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM=0) and activates when MXFP8 or BF16/FP16 conditions are met on Blackwell GPUs.

  • Adds _forward_grouped_tensor / _backward_grouped_tensor implementing a cuBLASLt grouped GEMM path that avoids per-GEMM Python dispatch loops, with a _is_grouped_tensor_path_supported gate that filters out debug, cpu-offloading, save_original_input, and non-MXFP8 FP8 recipes.
  • Fixes a bias-gradient crash in backward_dw when single_grouped_bias=True + delay_wgrad_compute=True via a new has_grad_biases guard, and moves grouped-linear tests into a standalone file with dedicated graph-capture and path-comparison tests.

Confidence Score: 5/5

Safe to merge. The new grouped-tensor GEMM path is off by default, well-gated behind NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM and the SM100+ capability check, and the existing non-grouped path is unchanged.

The bias-gradient crash in backward_dw (previously flagged) is correctly fixed by the has_grad_biases guard, and a dedicated regression test is added. The _forward_grouped_tensor / _backward_grouped_tensor logic correctly handles dgrad, wgrad, deferred wgrad, and bias across fp8 and non-fp8 paths. The only gaps are a redundant ctx.m_splits reference and missing fuse_wgrad_accumulation=True coverage in the new grouped-tensor comparison tests.

transformer_engine/pytorch/module/grouped_linear.py and tests/pytorch/test_grouped_linear.py — specifically the fuse_wgrad_accumulation=True branch in _backward_grouped_tensor and the unused ctx.m_splits assignment.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/grouped_linear.py Core change: adds ~300 lines of new grouped-tensor forward/backward path. Correct for the common cases tested; the fuse_wgrad_accumulation=True branch in _backward_grouped_tensor lacks end-to-end test coverage in the new tests. ctx.m_splits is set but never read in the grouped-tensor backward.
tests/pytorch/test_grouped_linear.py New standalone test file with graph-capture, path-comparison, and crash-regression tests. The grouped-tensor path comparison test is not parameterized over fuse_wgrad_accumulation, leaving that backward branch uncovered.
benchmarks/linear/benchmark_grouped_linear.py Converts m_splits to a CUDA tensor when NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM=1 is set, matching the module's own default. Minimal, correct change.
tests/pytorch/test_numerics.py Grouped-linear tests moved to the standalone file. Net deletions only; no new logic added.
qa/L0_pytorch_unittest/test.sh Adds the new test file to the test runner. Single-line, correct.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["GroupedLinear.forward(inp, m_splits)"] --> B{"_is_grouped_tensor_path_supported?"}
    B -- "Yes (SM100+, MXFP8/BF16, env=1)" --> C["_forward_grouped_tensor"]
    B -- "No (default)" --> D["existing loop-based forward"]
    C --> C1["torch.as_tensor(m_splits) → split_sizes"]
    C1 --> C2["tex.splits_to_offsets → base_split_offsets"]
    C2 --> C3{"fp8?"}
    C3 -- Yes --> C4["tex.group_quantize(x) → grouped_x"]
    C3 -- No --> C5["_make_grouped_tensor(x) → grouped_x"]
    C4 & C5 --> C6["_prepare_weights_for_grouped_tensor_gemm"]
    C6 --> C7["general_grouped_gemm_for_grouped_tensor (TN)"]
    C7 --> C8["save_for_backward: grouped_x, weights, split_sizes, offsets"]
    D --> D1["existing save_for_backward path"]
    C8 & D1 --> E{"ctx.use_grouped_tensor_path?"}
    E -- True --> F["_backward_grouped_tensor"]
    E -- False --> G["existing grouped backward"]
    F --> F1["cast grad_output → dy_2d"]
    F1 --> F2{"fp8?"}
    F2 -- Yes --> F3["tex.bgrad_group_quantize → grouped_dy + dbias"]
    F2 -- No --> F4["_make_grouped_tensor + compute_grouped_dbias"]
    F3 & F4 --> F5{"requires_dgrad?"}
    F5 -- Yes --> F6["general_grouped_gemm_for_grouped_tensor (NN, dgrad)"]
    F5 -- No --> F7{"weights_requires_grad?"}
    F6 --> F7
    F7 -- Yes --> F8{"delay_wgrad_compute?"}
    F8 -- No --> F9["general_grouped_gemm_for_grouped_tensor (NT, immediate)"]
    F8 -- Yes --> F10["wgrad_store.put → deferred to backward_dw()"]
    F7 -- No --> F11["wgrad_list = None x N"]
    F9 & F10 & F11 --> F12["return dgrad, wgrad_list, grad_biases"]
Loading

Reviews (5): Last reviewed commit: "fix tests" | Re-trigger Greptile

Comment thread tests/pytorch/test_grouped_linear.py
Comment thread tests/pytorch/test_grouped_linear.py
Comment thread transformer_engine/pytorch/module/grouped_linear.py
Comment thread benchmarks/linear/benchmark_grouped_linear.py
Comment thread tests/pytorch/test_grouped_linear.py Outdated
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/enable-grouped-quantize-cublaslt branch from d176247 to 698383e Compare May 25, 2026 03:56
@yaox12
Copy link
Copy Markdown
Member Author

yaox12 commented May 25, 2026

/te-ci pytorch

Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12
Copy link
Copy Markdown
Member Author

yaox12 commented May 26, 2026

/te-ci pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant