Improve TE Group MLP CPU Overhead #2991
for more information, see https://pre-commit.ci
Signed-off-by: zhongboz <zhongboz@nvidia.com>
/te-ci pytorch L1
Greptile Summary
This PR reduces CPU overhead in the TE grouped MLP path by replacing four sequential Python/CUDA operations (type cast,
Confidence Score: 4/5
Safe to merge after fixing the missing .contiguous() call in the CUDA input path of prepare_grouped_splits.
The CUDA path in misc.cpp assigns the input tensor directly to split_sizes_for_kernel without calling .contiguous(). makeTransformerEngineTensor then passes tensor.data_ptr() to the kernel, which reads elements at consecutive raw memory positions with no stride awareness. A non-contiguous CUDA tensor would silently produce wrong split sizes and corrupt the entire grouped GEMM dispatch. The sibling splits_to_offsets guards against this with an explicit .contiguous() call.
transformer_engine/pytorch/csrc/extensions/misc.cpp - the CUDA input branch of prepare_grouped_splits needs a .contiguous() call before passing the tensor to makeTransformerEngineTensor.
Important Files Changed
m.def("prepare_grouped_splits", &transformer_engine::pytorch::prepare_grouped_splits,
      "Prepare grouped split metadata from int32 or int64 split sizes", py::arg("split_sizes"),
      py::arg("num_groups"), py::arg("logical_last_dim"));
prepare_grouped_splits is registered without py::call_guard<py::gil_scoped_release>(), unlike the analogous splits_to_offsets binding above it. The GIL is released only for the CUDA kernel launch inside NVTE_SCOPED_GIL_RELEASE, but the input validation and bulk_allocate call still hold the GIL. Given that this PR's purpose is to reduce CPU overhead, holding the GIL over the bulk_allocate + makeTransformerEngineTensor setup work unnecessarily serialises Python threads. NVTE_SCOPED_GIL_RELEASE handles the "already released" case via PyGILState_Check(), so adding the call guard here is safe and consistent.
Suggested change
- m.def("prepare_grouped_splits", &transformer_engine::pytorch::prepare_grouped_splits,
-       "Prepare grouped split metadata from int32 or int64 split sizes", py::arg("split_sizes"),
-       py::arg("num_groups"), py::arg("logical_last_dim"));
+ m.def("prepare_grouped_splits", &transformer_engine::pytorch::prepare_grouped_splits,
+       "Prepare grouped split metadata from int32 or int64 split sizes", py::arg("split_sizes"),
+       py::arg("num_groups"), py::arg("logical_last_dim"),
+       py::call_guard<py::gil_scoped_release>());
base_offsets[idx + 1] = prefix;
// cuDNN grouped GEMM expects padded split end offsets as int32. TE
// GroupedTensor metadata keeps the full int64 base_offsets/tensor_offsets.
split_points[idx] = static_cast<int32_t>(prefix);
Silent int32 truncation for large token counts: prefix is int64_t, and the cast to int32_t silently wraps when the cumulative row count exceeds INT32_MAX (~2.1 B tokens). cuDNN itself consumes int32 padded offsets, so the limit is real, but it would be better to surface it as an explicit error than to pass corrupted offsets to cuDNN and get a hard-to-diagnose runtime failure. A guard of the form NVTE_CHECK(prefix <= INT32_MAX, "split_points overflow int32: total tokens exceed INT32_MAX") placed before the cast would make the failure mode visible.
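To illustrate the guard described in this comment, here is a minimal host-side sketch using only the standard library. It is not the PR's code: `SplitMetadata` and `build_split_metadata` are hypothetical names, padding is ignored, and a thrown `std::overflow_error` stands in for `NVTE_CHECK`. Whether the real check belongs on the host or inside a kernel depends on where the loop runs, which the excerpt does not show.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>

// Hypothetical container for the two outputs discussed in the review:
// full-width int64 base offsets and int32 split end points.
struct SplitMetadata {
  std::vector<int64_t> base_offsets;  // size num_groups + 1, first entry is 0
  std::vector<int32_t> split_points;  // size num_groups, cumulative end offsets
};

// Builds the running prefix sum and refuses to narrow to int32 once the
// cumulative count no longer fits, instead of letting the cast wrap silently.
inline SplitMetadata build_split_metadata(const std::vector<int64_t>& split_sizes) {
  SplitMetadata out;
  out.base_offsets.assign(split_sizes.size() + 1, 0);
  out.split_points.assign(split_sizes.size(), 0);
  int64_t prefix = 0;
  for (std::size_t idx = 0; idx < split_sizes.size(); ++idx) {
    prefix += split_sizes[idx];
    out.base_offsets[idx + 1] = prefix;
    // Guard placed before the cast, as the review comment suggests.
    if (prefix > std::numeric_limits<int32_t>::max()) {
      throw std::overflow_error(
          "split_points overflow int32: total tokens exceed INT32_MAX");
    }
    out.split_points[idx] = static_cast<int32_t>(prefix);
  }
  return out;
}
```

With split sizes {3, 5, 2} this yields base offsets {0, 3, 8, 10} and split points {3, 8, 10}; two groups of 2,000,000,000 rows each trigger the error on the second group rather than producing a wrapped offset.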
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
if (split_sizes.is_cuda()) {
  NVTE_CHECK(split_sizes.device() == device, "CUDA split_sizes must be on current CUDA device ",
             device.index(), ", but got CUDA device ", split_sizes.device().index(), ".");
  split_sizes_for_kernel = split_sizes;
Missing .contiguous() call for CUDA input. makeTransformerEngineTensor(at::Tensor) uses tensor.data_ptr() directly without respecting strides (confirmed at common.cpp:140). For a non-contiguous 1D CUDA tensor (e.g., a strided view or slice), data_ptr() is passed to the kernel, which reads elements at positions [0, 1, 2, ...] in raw memory, silently producing wrong split_sizes_i64, base_offsets, and all downstream outputs. The sibling splits_to_offsets avoids this with an explicit .contiguous() call.
Suggested change
- if (split_sizes.is_cuda()) {
-   NVTE_CHECK(split_sizes.device() == device, "CUDA split_sizes must be on current CUDA device ",
-              device.index(), ", but got CUDA device ", split_sizes.device().index(), ".");
-   split_sizes_for_kernel = split_sizes;
+ if (split_sizes.is_cuda()) {
+   NVTE_CHECK(split_sizes.device() == device, "CUDA split_sizes must be on current CUDA device ",
+              device.index(), ", but got CUDA device ", split_sizes.device().index(), ".");
+   split_sizes_for_kernel = split_sizes.contiguous();
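The failure mode behind this suggestion can be reproduced without any tensor library. The sketch below is an illustration only: `StridedView`, `read_raw`, and `make_contiguous` are hypothetical stand-ins for a 1D tensor view, a stride-unaware consumer of data_ptr(), and the effect of calling .contiguous(), respectively.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a 1D tensor view: a base pointer plus an element
// stride, roughly what data_ptr() and stride(0) expose on a real tensor.
struct StridedView {
  const int64_t* data;
  std::size_t size;    // number of logical elements
  std::size_t stride;  // distance between elements; 1 means contiguous
};

// What a stride-unaware consumer sees: `size` consecutive values starting at
// the base pointer, regardless of the view's stride.
inline std::vector<int64_t> read_raw(const StridedView& v) {
  return std::vector<int64_t>(v.data, v.data + v.size);
}

// What materializing a contiguous copy does: gather the logical elements by
// honoring the stride, so a raw read of the copy is then correct.
inline std::vector<int64_t> make_contiguous(const StridedView& v) {
  std::vector<int64_t> out(v.size);
  for (std::size_t i = 0; i < v.size; ++i) {
    out[i] = v.data[i * v.stride];
  }
  return out;
}
```

For storage {4, 0, 6, 0, 8, 0} viewed with size 3 and stride 2, the logical split sizes are {4, 6, 8}, but a raw read returns {4, 0, 6}. That is the kind of silently wrong input the review comment warns about.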
Description
Improve TE grouped MLP CPU overhead, assuming CUDA graphs are not enabled.
This is for issue: #2897
E2E model: Qwen3.5 35B-A3B; a nano-scale model is more prone to CPU overhead.
What we measure: with CUDA graphs turned off, split_sizes lives on the CPU. We measure the CPU-side time from the end of the H2D copy of split_sizes to the launch of the grouped quantize kernel.
before: 355us
after: 167us