[PyTorch] Make modules.GroupedLinear graph-safe #3038
Greptile Summary
This PR introduces a new
Confidence Score: 5/5
Safe to merge. The new grouped-tensor GEMM path is off by default, well-gated behind
The bias-gradient crash in transformer_engine/pytorch/module/grouped_linear.py and tests/pytorch/test_grouped_linear.py, specifically the
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["GroupedLinear.forward(inp, m_splits)"] --> B{"_is_grouped_tensor_path_supported?"}
B -- "Yes (SM100+, MXFP8/BF16, env=1)" --> C["_forward_grouped_tensor"]
B -- "No (default)" --> D["existing loop-based forward"]
C --> C1["torch.as_tensor(m_splits) → split_sizes"]
C1 --> C2["tex.splits_to_offsets → base_split_offsets"]
C2 --> C3{"fp8?"}
C3 -- Yes --> C4["tex.group_quantize(x) → grouped_x"]
C3 -- No --> C5["_make_grouped_tensor(x) → grouped_x"]
C4 & C5 --> C6["_prepare_weights_for_grouped_tensor_gemm"]
C6 --> C7["general_grouped_gemm_for_grouped_tensor (TN)"]
C7 --> C8["save_for_backward: grouped_x, weights, split_sizes, offsets"]
D --> D1["existing save_for_backward path"]
C8 & D1 --> E{"ctx.use_grouped_tensor_path?"}
E -- True --> F["_backward_grouped_tensor"]
E -- False --> G["existing grouped backward"]
F --> F1["cast grad_output → dy_2d"]
F1 --> F2{"fp8?"}
F2 -- Yes --> F3["tex.bgrad_group_quantize → grouped_dy + dbias"]
F2 -- No --> F4["_make_grouped_tensor + compute_grouped_dbias"]
F3 & F4 --> F5{"requires_dgrad?"}
F5 -- Yes --> F6["general_grouped_gemm_for_grouped_tensor (NN, dgrad)"]
F5 -- No --> F7{"weights_requires_grad?"}
F6 --> F7
F7 -- Yes --> F8{"delay_wgrad_compute?"}
F8 -- No --> F9["general_grouped_gemm_for_grouped_tensor (NT, immediate)"]
F8 -- Yes --> F10["wgrad_store.put → deferred to backward_dw()"]
F7 -- No --> F11["wgrad_list = None x N"]
F9 & F10 & F11 --> F12["return dgrad, wgrad_list, grad_biases"]
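To make the backward half of the flowchart concrete, here is a minimal plain-Python sketch of two pieces of it: a per-group bias gradient (the column sum of the `grad_output` rows belonging to each group) and the dgrad/wgrad branch order from F5 through F11. The function names `grouped_dbias_reference` and `plan_backward` and the step labels are hypothetical and only illustrate the control flow; they are not the PR's `compute_grouped_dbias` or `_backward_grouped_tensor`, whose signatures are not shown here.

```python
from itertools import accumulate


def grouped_dbias_reference(dy, m_splits):
    """Per-group bias gradient: sum the rows of dy that belong to each group.

    dy is a list of rows (lists of numbers); m_splits[g] is the row count of group g.
    An empty group yields a zero vector.
    """
    if sum(m_splits) != len(dy):
        raise ValueError("m_splits must sum to the number of rows in dy")
    offsets = [0, *accumulate(m_splits)]  # assumed layout: exclusive prefix sum plus total
    n_cols = len(dy[0]) if dy else 0
    dbias = []
    for g in range(len(m_splits)):
        rows = dy[offsets[g]:offsets[g + 1]]
        dbias.append([sum(row[c] for row in rows) for c in range(n_cols)])
    return dbias


def plan_backward(requires_dgrad, weights_requires_grad, delay_wgrad_compute):
    """Return the backward steps in the order the flowchart visits them (labels are illustrative)."""
    steps = []
    if requires_dgrad:
        steps.append("dgrad_gemm_NN")      # F6
    if not weights_requires_grad:
        steps.append("wgrad_none")         # F11
    elif delay_wgrad_compute:
        steps.append("wgrad_deferred")     # F10: computed later in backward_dw()
    else:
        steps.append("wgrad_gemm_NT")      # F9: computed immediately
    return steps
```

For example, `plan_backward(True, True, True)` lists the dgrad GEMM followed by the deferred wgrad step, matching the F6 → F7 → F8 → F10 route.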
Reviews (5): Last reviewed commit: "fix tests"
Signed-off-by: Xin Yao <xiny@nvidia.com>
d176247 to 698383e
/te-ci pytorch
Description
Enable grouped quantization and cuBLASLt grouped GEMM for modules.GroupedLinear to benefit cases where the cuteDSL fused grouped GEMM is not available.
Move grouped GEMM and grouped linear related tests to a standalone file.
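For readers unfamiliar with the operation, the following is a minimal plain-Python reference of what a grouped linear layer computes: the input rows are partitioned by `m_splits`, and each group is multiplied by its own weight matrix. It is a sketch under stated assumptions, not the PR's implementation. The helper names are hypothetical, and the offset layout (an exclusive prefix sum with the total appended) is an assumption; the actual output of `tex.splits_to_offsets` may differ.

```python
from itertools import accumulate


def splits_to_offsets(m_splits):
    """Assumed layout: [0, m0, m0 + m1, ...], i.e. a prefix sum of the group sizes."""
    return [0, *accumulate(m_splits)]


def grouped_linear_reference(x, weights, m_splits):
    """Loop-based reference for y_g = x_g @ W_g^T, concatenated over groups.

    x:        list of input rows, each of length in_features
    weights:  one matrix per group, stored as [out_features][in_features]
    m_splits: number of rows of x that belong to each group
    """
    if len(weights) != len(m_splits):
        raise ValueError("need exactly one weight matrix per group")
    if sum(m_splits) != len(x):
        raise ValueError("m_splits must sum to the number of rows in x")
    offsets = splits_to_offsets(m_splits)
    out = []
    for g, w in enumerate(weights):
        for row in x[offsets[g]:offsets[g + 1]]:
            # One output element per weight row: dot(row, w_row).
            out.append([sum(a * b for a, b in zip(row, w_row)) for w_row in w])
    return out
```

A grouped GEMM produces the same result as this loop in a single launch instead of one GEMM per group, which is what avoids per-group host-side work.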
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: