[Common] Comm+GEMM overlap API updated to support cuBlasMp backend (incl. framework API)#2443

Open
denera wants to merge 57 commits into NVIDIA:main from denera:common/tp-overlap-cublasmp

denera (Collaborator) commented Dec 2, 2025

Description

This PR adds support for the NVTE cuBLASMp bindings in the Comm+GEMM overlap API.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera self-assigned this Dec 2, 2025
@denera denera force-pushed the common/tp-overlap-cublasmp branch 2 times, most recently from 908bbc2 to 69cf235 Compare December 2, 2025 20:12
Comment thread transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
@@ -17,6 +18,12 @@

#define NVTE_COMM_OVERLAP_MAX_STREAMS 3

/* \brief Check if TE is built with cuBlasMp.
Member

nit: cuBLASMp

@@ -526,6 +514,11 @@ class CommOverlapHelper : torch::CustomClassHolder {
ExtComm comm);

void ub_barrier(ExtComm comm);

int64_t get_nccl_comm_ptr(std::string comm_name) {
NVTE_CHECK(backend_is_nccl, "Cannot get nccComm_t ptr if backend is not NCCL.");
Member

This error message could be more descriptive - e.g. something like "chosen backend for the communication-computation overlap (cuBLASMp) requires NCCL communicator, but the passed ProcessGroup uses a different backend."

@denera denera force-pushed the common/tp-overlap-cublasmp branch from 4596411 to b4ad546 Compare December 16, 2025 19:04
@denera denera marked this pull request as ready for review December 16, 2025 22:58
greptile-apps Bot (Contributor) commented Dec 16, 2025

Greptile Summary

This PR extends the Comm+GEMM overlap API in TransformerEngine to support a cuBLASMp backend alongside the existing Userbuffers backend, covering both PyTorch and JAX framework paths. The cuBLASMp path uses its own NCCL communicator (initialized in CommOverlapHelper) to drive all-gather and reduce-scatter fused GEMMs, bypassing the Userbuffers transport layer entirely.

  • New cuBLASMp constructors added to CommOverlap, CommOverlapP2P, and their base classes; a construction-time cublasmp_capture_warmup pre-registers lazy NCCL windows and workspaces before any CUDA-graph capture.
  • Python API changes: initialize_ub gains with_cublasmp; CommOverlapHelper now builds dedicated NCCL communicators for the cuBLASMp path; the CommOverlap default for num_splits changed silently from 3 to 4, altering Userbuffers pipeline depth for callers that omit the argument.
  • JAX support: InitializeCgemmCommunicator gains use_cublasmp and CollectiveGemmPlanRegistry builds a bare CommOverlapP2PBase with the cuBLASMp constructor when enabled.

Confidence Score: 3/5

Not safe to merge: the num_splits default change silently alters Userbuffers pipeline behaviour for existing callers, and several carry-over issues remain unresolved.

The CommOverlap Python binding now exposes a different default for num_splits (3 to 4) without any changelog entry, silently breaking backward-compatible usage. Carry-over defects — absent intra NCCL comm for single-node cuBLASMp, unconditional nccl.h in public headers, exception-unsafe warmup allocations, and the NCCL file-path isolation regression in the JAX helper — represent a cluster of correctness and portability risks across the PR.

transformer_engine/pytorch/csrc/extensions/pybind.cpp (num_splits default change), transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp (warmup workspace exception safety), transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h and transformer_engine/pytorch/csrc/common.h (unconditional nccl.h in public headers).

Important Files Changed

Filename: Overview

  • transformer_engine/pytorch/csrc/extensions/pybind.cpp: Adds use_cublasmp dispatch lambdas for CommOverlap and CommOverlapP2P. The num_splits default for CommOverlap silently changed from 3 to 4, altering pipeline depth on the Userbuffers path for any caller that omits the argument.
  • transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp: Adds cuBLASMp NCCL communicator setup, cublasmp_capture_warmup, and new cuBLASMp constructors. Warmup workspace is allocated/freed inside the helper via a by-value pointer so _warmup_workspace is never set and the destructor guard is dead code; exceptions between malloc and free leak the allocation.
  • transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h: Adds cuBLASMp constructors and member variables without #ifdef NVTE_WITH_CUBLASMP guards, making the public header unconditionally depend on nccl.h and comm_gemm.h.
  • transformer_engine/pytorch/csrc/common.h: Adds unconditional #include nccl.h pulled into every PyTorch extension source file regardless of NVTE_WITH_CUBLASMP, breaking non-NCCL builds.
  • transformer_engine/pytorch/module/base.py: Adds with_cublasmp flag to initialize_ub and using_cublasmp_backend() helper. The helper uses assert for its pre-condition (silently disabled under Python -O) and is called at init time requiring initialize_ub to be called first.
  • transformer_engine/jax/csrc/extensions/cgemm_helper.cpp: Adds cuBLASMp executor branch in CollectiveGemmPlanRegistry; correctly includes use_cublasmp in the plan_id hash. Removes pgid from NCCL file path, removing per-process-group isolation for concurrent runs.
  • transformer_engine/pytorch/module/linear.py: Adds cuBLASMp-aware output routing, backward all-gather for wgrad, and bulk_available gating. Logic is consistent with layernorm_mlp.
  • tests/pytorch/distributed/run_gemm_with_overlap.py: Rewrites the reference computation to per-rank local comparisons and adds cuBLASMp test support. Fixes torch.transpose and general_gemm unpack bugs from the earlier version.
  • examples/jax/collective_gemm/run_test_cgemm.sh: Adds per-backend test loop with wait between backends, correct -c spacing, scoped log deletion, and dynamic cuBLASMp detection. Fixes most concurrency and log-path bugs from prior review.

Reviews (27): Last reviewed commit: "handling ncclComm_t via shared pointers ..."

greptile-apps Bot (Contributor) left a comment

Additional Comments (8)

  1. transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 335 (link)

    logic: Variable shadowing bug: k is assigned k * _tp_size where k appears on both sides. Should be k = k_local * _tp_size.

  2. transformer_engine/jax/csrc/extensions/cgemm_helper.cpp, line 135 (link)

    logic: Invalid reinterpret_cast: cannot cast an int* (pointer) to int (value). Should be reinterpret_cast<void**>(&handler._device_barrier).

  3. transformer_engine/pytorch/csrc/extensions.h, line 517 (link)

    syntax: Stray character a that will cause compilation failure.

  4. transformer_engine/pytorch/csrc/extensions.h, line 537-540 (link)

    logic: Constructor parameter mismatch: CommOverlapBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but called with (nccl_comm_ptr, tp_size, tp_rank, ...). Order of tp_rank and tp_size is swapped.

  5. transformer_engine/pytorch/csrc/extensions.h, line 563-566 (link)

    logic: Constructor parameter mismatch: CommOverlapP2PBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but called with (nccl_comm_ptr, tp_size, tp_rank, ...). Order of tp_rank and tp_size is swapped.

  6. transformer_engine/jax/csrc/extensions/cgemm_helper.cpp, line 196-199 (link)

    logic: Constructor parameter mismatch: CommOverlapP2PBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but called with (nccl_comm_ptr, tp_size, tp_domain_id, ...). Should use tp_rank instead of tp_domain_id.

    Should this use tp_rank (local_device_id_within_tp_domain) or node_id (tp_domain_id)?

  7. tests/pytorch/distributed/run_gemm_with_overlap.py, line 416-418 (link)

    style: Unconditional initialization: local_kernel2_t_shape is only used when ub_obj2 is not None, but it's always initialized here. This creates an unused variable in most cases.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  8. tests/pytorch/distributed/run_gemm_with_overlap.py, line 490 (link)

    style: Incorrect initialization: ref2_g should only be assigned when ub_obj2 is not None, but it's unconditionally assigned to (0,) which doesn't match the expected tensor type.

18 files reviewed, 8 comments

@denera denera force-pushed the common/tp-overlap-cublasmp branch from 147036f to c5471f8 Compare December 17, 2025 02:15
denera and others added 6 commits December 17, 2025 02:16
…rk extensions

Signed-off-by: Alp Dener <adener@nvidia.com>
…entirely

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from c5471f8 to d79bf21 Compare December 17, 2025 02:16
greptile-apps Bot (Contributor) left a comment

Additional Comments (3)

  1. transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 87 (link)

    logic: return value from nvte_comm_gemm_ctx_create() is not assigned to _cublasmp_ctx, causing null pointer when destructor calls nvte_comm_gemm_ctx_destroy(_cublasmp_ctx) on line 173

  2. transformer_engine/jax/cpp_extensions/gemm.py, line 819 (link)

    style: parameter order changed - collective_op moved after transpose_batch_sequence, sequence_dim, is_outer. Verify this matches the calling convention and doesn't break compatibility. Was this parameter reordering intentional, and have all call sites been verified?

  3. transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 89-91 (link)

    logic: _tp_size field is not initialized in the cuBlasMp constructor, but it's used in cublasmp_ag_gemm() and cublasmp_gemm_rs() methods (lines 321, 332). Add initialization

18 files reviewed, 3 comments

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from 364b416 to ee517d3 Compare December 17, 2025 02:50
greptile-apps Bot (Contributor) left a comment

Additional Comments (6)

  1. transformer_engine/pytorch/module/base.py, line 415-417 (link)

    logic: Parameter order is incorrect - the C++ constructor signature is (helper, tp_rank, tp_size, ...) but Python is passing (helper, tp_size, local_rank, ...). This swaps tp_rank and tp_size, causing incorrect initialization.

  2. transformer_engine/pytorch/module/base.py, line 387-389 (link)

    logic: Parameter order is incorrect - the C++ constructor signature is (helper, tp_rank, tp_size, ...) but Python is passing (helper, tp_size, local_rank, ...). This swaps tp_rank and tp_size, causing incorrect initialization.

  3. tests/pytorch/distributed/run_gemm_with_overlap.py, line 340-344 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.

  4. tests/pytorch/distributed/run_gemm_with_overlap.py, line 355-359 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.

  5. tests/pytorch/distributed/run_gemm_with_overlap.py, line 383 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.

  6. tests/pytorch/distributed/run_gemm_with_overlap.py, line 394 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.
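The swapped `tp_rank`/`tp_size` arguments flagged in these comments are easy to introduce because both values are plain integers. A minimal sketch, in Python, of one way a wrapper could make such swaps fail loudly: keyword-only parameters plus a range check. `make_overlap` and its return value are hypothetical and are not part of the Transformer Engine API.

```python
def make_overlap(helper, *, tp_rank, tp_size):
    """Hypothetical wrapper: rank and size must be passed by keyword."""
    if tp_size <= 0:
        raise ValueError(f"tp_size must be positive, got {tp_size}")
    if not 0 <= tp_rank < tp_size:
        # A swapped call such as tp_rank=8, tp_size=3 is rejected here.
        raise ValueError(f"tp_rank={tp_rank} is out of range for tp_size={tp_size}")
    return {"helper": helper, "tp_rank": tp_rank, "tp_size": tp_size}


overlap = make_overlap("helper", tp_rank=3, tp_size=8)
```

Because a valid rank is always smaller than the size, any swapped pair has the rank larger than the size and fails the range check; a positional call fails earlier with a `TypeError`.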

19 files reviewed, 6 comments

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from 5cb8204 to 51b64fb Compare December 17, 2025 03:36
Comment thread transformer_engine/common/CMakeLists.txt Outdated
Comment thread transformer_engine/jax/cpp_extensions/gemm.py Outdated
for (int64_t col = 0; col < cols; ++col) {
const auto* send = reinterpret_cast<const uint8_t*>(local.data() + col * local_rows);
auto* recv = reinterpret_cast<uint8_t*>(gathered.data() + col * global_rows);
CHECK_MPI(MPI_Allgatherv(send, static_cast<int>(local_rows * sizeof(T)), MPI_BYTE, recv,
Member

Do we actually support different number of rows on each GPU to use this scheme both here and in the column allgather?

Collaborator Author

I don't think TE ever has a use case (on the framework side) where the distributed matrix dim is not divisible by the TP size, but I believe cuBLASMp supports arbitrary global matrix dims without assuming they divide evenly. The original DistributeTensors fixture in the CPP test suite sets a row/col start and a row/col count for each operand on each calling rank, and my modifications to the CPP tests respect that existing convention. They now compare overlapped vs. non-overlapped distributed GEMM for AG/RS/AR, instead of overlapped distributed GEMM vs. global GEMM, which was causing numerical discrepancies before because the reference output was missing the reduction.
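To make the uneven-partition case concrete, here is a small sketch of how per-rank row counts and displacements for an `MPI_Allgatherv`-style gather could be computed when the global dimension is not divisible by the number of ranks. The function name and the convention that the first ranks receive the extra rows are assumptions for illustration; the fixture in the test suite may distribute rows differently.

```python
def partition_rows(global_rows, nranks):
    """Split global_rows across nranks; the first (global_rows % nranks) ranks get one extra row."""
    base, extra = divmod(global_rows, nranks)
    counts = [base + (1 if r < extra else 0) for r in range(nranks)]
    # Displacement of rank r is the total number of rows owned by ranks before it.
    displs = [sum(counts[:r]) for r in range(nranks)]
    return counts, displs


counts, displs = partition_rows(10, 4)
print(counts, displs)  # [3, 3, 2, 2] [0, 3, 6, 8]
```

When the dimension divides evenly, every count is identical and the gather reduces to the fixed-size case.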

Comment thread tests/cpp_distributed/test_comm_gemm.cu Outdated
Comment on lines +365 to +366
// - AllReduce: BIAS applied per-rank before AR, output = sum_r(A_r@B_r + bias) = sum_r(A_r@B_r) + nranks*bias
// -> no bias correction needed in ref
Member

Huh? Isn't this wrong?

Collaborator Author

It's not wrong at least for now. This is a known bug in cuBLASMp and the fix is rolling out with v0.9.

I can gate the cuBLASMp versions in the CPP test and apply the bias correction to AR for v0.9+ so the tests don't break when our CI container bases update.

Member

Wait, but isn't this behavior wrong? So we should guard against using this with bias and instead do bias manually for cuBLASMp < 0.9, right?

Collaborator Author

Yeah, the behavior is wrong, but the bug is only in the all-reduce overlap, which we don't actually use or expose in the framework bindings.

Reduce-scatter overlap in cuBLASMp correctly applies the bias only once after the reduction so there's no need to guard against it.
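The arithmetic behind this thread can be checked with plain Python lists. The sketch below uses made-up per-rank partial GEMM outputs (the values and helper names are illustrative only) to show that adding the bias on every rank before an all-reduce counts it `nranks` times, whereas adding it once after the reduction does not.

```python
def all_reduce_sum(per_rank):
    """Element-wise sum across ranks, standing in for an all-reduce."""
    return [sum(vals) for vals in zip(*per_rank)]


# Made-up partial outputs A_r @ B_r for 4 ranks, 3 output elements each.
partials = [[1.0, 2.0, 3.0], [0.5, 0.5, 0.5], [2.0, 0.0, 1.0], [1.5, 1.5, 1.5]]
bias = [10.0, 20.0, 30.0]
nranks = len(partials)

# Bias applied on every rank before the reduction.
biased_before = all_reduce_sum([[p + b for p, b in zip(part, bias)] for part in partials])

# Bias applied once after the reduction.
reduced = all_reduce_sum(partials)
biased_after = [r + b for r, b in zip(reduced, bias)]

# The two results differ by exactly (nranks - 1) * bias.
excess = [x - y for x, y in zip(biased_before, biased_after)]
```

A reference that also applies the bias per rank before reducing carries the same `nranks * bias` term, which is why such a comparison can pass without a correction.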

Comment on lines +1042 to +1045
ctx.requires_wgrad
and ctx.ub_overlap_ag
and ctx.ub_obj_gradout is not None
and ctx.ub_obj_gradout.with_cublasmp()
Member

ctx is not used here anymore - it should reference bwd_args.

Collaborator Author

Yeah, I just saw this on my end too and fixed it. Catching branch up with main again and running tests. Will push up the fixes when everything's coming back green.

#ifdef NVTE_WITH_CUBLASMP
NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ",
"with valid process groups!");
NVTE_CHECK(backend_is_nccl,
Member

Codex raised an issue (which I think is correct) that in the if statements in base.py:238 we default to mpi or gloo if those are available and not NCCL if you do not set the backend explicitly. This check though requires backend to be NCCL - we should reconcile that.

Collaborator Author

Yeah, this is a valid catch, but the fix is removing the NCCL requirement entirely. The check was leftover from when the bootstrapping was attempting to get a NCCL communicator directly out of the PyTorch PG, but we don't do that anymore. We simply create a new NCCL comm that spans the same devices/ranks as the PyTorch PG instead, and the PyTorch PG itself can have any backend in that case.

I'll make the fix on my end so it pushes up with the rest. Thanks!

Member

Ok, so should this check then be relaxed? It is still here.

}
TensorWrapper pre_gelu_out_(get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING));
// Match GemmV2FFI's operand swap: rhs becomes A, lhs becomes B.
cudaStream_t prepare_stream = cudaStreamPerThread;
Member

You should synchronize that stream after calling that warmup iteration.

…sed and tests modified to account for them

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera requested a review from ksivaman as a code owner May 15, 2026 17:27
pre-commit-ci Bot and others added 5 commits May 15, 2026 17:28
…ling in cuBLASMp bindings, added atol and rtol args to TE/PyTorch comm+GEMM runner scripts

Signed-off-by: Alp Dener <adener@nvidia.com>
xla_flags = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = (
xla_flags
+ " --xla_gpu_enable_command_buffer=+COLLECTIVES"
Collaborator

Here, we need to make sure that CUSTOM_CALL is captured by default.
Would it be better to add a check for that?

Collaborator Author

CUSTOM_CALL is already part of the defaults for command buffers, so I don't believe we need to add it ourselves here.

The issue is that cuBLASMp has NCCL calls in it for device-initiated communication, and XLA exempts NCCL collectives from graph capture, so when NCCL collectives appear inside a CUSTOM_CALL chunk, we end up with a CUDA_ERROR_STREAM_CAPTURE_INVALIDATED error.

Adding COLLECTIVES to the command buffers is the only way I'm aware that gets around this issue. I don't know of any XLA API or flag that selectively enables command buffers for NCCL collectives only inside a CUSTOM_CALL.
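For reference, a minimal sketch of appending this flag to `XLA_FLAGS` without duplicating it on repeated calls. The flag string is the one from the diff above; the helper name is hypothetical, and the sketch deliberately does not try to merge with a different `--xla_gpu_enable_command_buffer` value that a user may already have set.

```python
import os

COLLECTIVES_FLAG = "--xla_gpu_enable_command_buffer=+COLLECTIVES"


def with_collectives_command_buffer(xla_flags):
    """Return xla_flags with the COLLECTIVES command-buffer flag appended at most once."""
    if COLLECTIVES_FLAG in xla_flags.split():
        return xla_flags
    return f"{xla_flags} {COLLECTIVES_FLAG}".strip()


os.environ["XLA_FLAGS"] = with_collectives_command_buffer(os.environ.get("XLA_FLAGS", ""))
```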

Collaborator

But then don't you need to mark this custom op with "collective"?
Something like this tensorflow/tensorflow@104721c

Collaborator Author (@denera, May 19, 2026)

Commit message for this says:

Currently all FFIs are treated as opaque kernels which are further treated as compute kernels in LHS.

...

It introduces a annotation to specific compute_on="gpu_stream:collective", then it can be scheduled in the same way as native collectives and stream assignment will give it the collective stream which is of high priority and wont overlap with other collectives.

This doesn't sound like what we want for CollectiveGemm. This is not a custom collective. It's a compute op that has some internal communication (strictly P2P in TE/JAX, not collective) with specific order-of-operations dependencies on the compute. We want XLA to treat it opaquely, and we want it to be invoked on the low-priority compute stream that's used for the GEMM chunks, while the internal high-priority collective streams in Userbuffers or cuBLASMp are used for the overlapped communication (XLA does not know about these internal streams and doesn't need to).

Comment on lines +25 to +30
use_cublasmp = request.config.getoption("--use-cublasmp")
if use_cublasmp and not nvte_built_with_cublasmp():
pytest.skip(
"Collective GEMM cuBLASMp backend tests require Transformer Engine to be built "
"with NVTE_WITH_CUBLASMP=1."
)
Collaborator

Are we going to build CUBLASMP in our CI images by default?

Collaborator Author

I would like to but only starting with 26.05.

We need NCCL 2.30+ for cuBLASMp to be graph-safe. I'm told JAX containers are going to satisfy that requirement starting with 26.05 so we can safely install cuBLASMp without breaking anything else.

auto *executor = CollectiveGemmPlanRegistry::getInstance().get_executor(
buffer_shape, buffer_dtype, config.collective_op);

// Run a dummy cuBLASMp matmul in the prepare stage so its lazy NCCL
Collaborator

Could we avoid doing any real GEMM computation here and only initialize the NCCL memory/windows?

Collaborator Author

Ideally yes, but currently no. There's a known bug in ncclPutSignal and ncclWaitSignal: they lazily initialize the RMA CE signal buffers, which means they're not CUDA graph-safe on first invocation. That forces us to do a dummy cuBLASMp GEMM call in the prepare phase just to trigger that lazy init so that invocations in the execute phase are graph-safe.

NCCL is aware of the bug and they're working on it. We can remove this dummy matmul here as soon as that bug is fixed.

denera and others added 4 commits May 18, 2026 21:30
…same mean/std as TE/PyTorch when generating random operands

Signed-off-by: Alp Dener <adener@nvidia.com>
…after dummy cuBLASMp call, version-guarded C++ all-reduce tests against CUBLASMP_VERSION >= 900

Signed-off-by: Alp Dener <adener@nvidia.com>
initialized = false;
#ifdef NVTE_WITH_CUBLASMP
for (auto &comm : nccl_comms) {
NVTE_CHECK_NCCL(ncclCommDestroy(comm.second));
Member

Are we sure that those communicators will not be used after the initialize_ub call finishes? Since the helper is going to be destroyed then and so the communicator also is going to be destroyed.

Collaborator Author

You're right, these communicators are absolutely used by cuBLASMp after initialize_ub() and keeping them in the helper risks losing them.

I haven't seen this actually happen in practice, but I suspect that's because PyBind increments the reference count when the CommOverlapHelper is passed into the initialization for CommOverlap or CommOverlapP2P. Garbage collection does not wipe the helper out as a result even when it goes out of scope at the end of initialize_ub().

It's still not a good design though. It means that we actually leak this helper object — we can no longer access it in Python once it goes out of scope in initialize_ub() but it also doesn't get garbage collected because the reference count does not reach zero.

I'll see about doing the NCCL communicators in the CommOverlap/CommOverlapP2P init so that they get destroyed only when the overlap object is destroyed.
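The lifetime behavior described here can be illustrated in pure Python. The sketch below uses stand-in classes (`Helper` and `Overlap` are hypothetical, not the TE classes) and a weak reference to observe that an object created inside a function stays alive for as long as a consumer holds a reference to it, and is released only once that consumer goes away. It says nothing about how PyBind actually manages the real `CommOverlapHelper`.

```python
import gc
import weakref


class Helper:
    """Stand-in for an object that owns communicators (hypothetical)."""


class Overlap:
    """Stand-in consumer that keeps a reference to the helper it was built with."""

    def __init__(self, helper):
        self._helper = helper


def initialize():
    helper = Helper()
    # The local name `helper` goes out of scope on return, but the
    # Overlap instance still holds a reference to the same object.
    return Overlap(helper), weakref.ref(helper)


overlap, helper_ref = initialize()
alive_while_held = helper_ref() is not None

del overlap
gc.collect()
alive_after_release = helper_ref() is not None
```

In this sketch the helper is unreachable by name after `initialize()` returns yet is still alive, which mirrors the concern that the object can neither be accessed nor collected while a consumer keeps it referenced.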

denera and others added 5 commits May 19, 2026 19:09
…up in CommOverlap destructor to avoid leaking on exception during warmup

Signed-off-by: Alp Dener <adener@nvidia.com>
…rematurely destroyed when CommOverlapHelper goes out of scope

Signed-off-by: Alp Dener <adener@nvidia.com>

class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase {
private:
void *_warmup_workspace{nullptr};
Member

Why do we keep this pointer throughout the execution if it is only needed during warmup?

NVTE_CHECK_CUDA(cudaMalloc(&a_ptr, a_bytes));
NVTE_CHECK_CUDA(cudaMalloc(&b_ptr, b_bytes));
NVTE_CHECK_CUDA(cudaMalloc(&d_ptr, d_bytes));
NVTE_CHECK_CUDA(cudaMalloc(&warmup_workspace, a_bytes + b_bytes + d_bytes));
Member

Ok, this is not right (well, it is actually resulting in a ~ok execution, but the intent is not correctly coded).
This warmup_workspace here is a local variable, so the cudaMalloc is modifying it rather than the _warmup_workspace from the class -> that variable stays as nullptr. If you really need to keep this pointer around then it should be passed as a pointer to pointer in order to be writable.
Now, if you just did that, it would expose another problem: at the end of this function you call cudaFree without setting the pointer back to nullptr, so the CommOverlap destructor would hit a double free.
I don't see why you need to keep the workspace pointer, so probably just getting rid of that and keeping the pointers here completely local is the right choice, but please let me know why you wanted to keep that pointer in the first place.

Collaborator Author

> I don't see why you need to keep the workspace pointer, so probably just getting rid of that and keeping the pointers here completely local is the right choice, but please let me know why you wanted to keep that pointer in the first place.

This is what I had originally but Greptile complained about it leaking if the warmup GEMM call throws an exception.

I agree with you that I'd rather just have it be purely local, allocated before warmup GEMM and freed after. If you're okay with ignoring Greptile's warning about it, I can revert back to what it was before.
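One way to reconcile the two concerns (keep the buffers purely local, but do not leak them if the warmup GEMM throws) is scoped cleanup. The sketch below shows the idea in Python with a fake allocator; `FakeAllocator`, `scoped_buffer`, and `warmup` are hypothetical names, and in the C++ code the counterpart would be an RAII guard rather than a context manager.

```python
from contextlib import contextmanager


class FakeAllocator:
    """Tracks live allocations so leaks are observable (stand-in for cudaMalloc/cudaFree)."""

    def __init__(self):
        self.live = set()
        self._next = 0

    def malloc(self, nbytes):
        # nbytes is ignored; only the handle bookkeeping matters in this sketch.
        self._next += 1
        self.live.add(self._next)
        return self._next

    def free(self, handle):
        self.live.remove(handle)


@contextmanager
def scoped_buffer(alloc, nbytes):
    handle = alloc.malloc(nbytes)
    try:
        yield handle
    finally:
        # Runs on normal exit and when the body raises.
        alloc.free(handle)


def warmup(alloc, fail=False):
    with scoped_buffer(alloc, 1024) as workspace:
        if fail:
            raise RuntimeError("warmup GEMM failed")
        return workspace


alloc = FakeAllocator()
warmup(alloc)
try:
    warmup(alloc, fail=True)
except RuntimeError:
    pass
leaked = len(alloc.live)
```

With this shape no member pointer is needed: the buffer's lifetime is tied to the warmup scope, and the failure path frees it just like the success path.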

// the last reference (held by the helper and/or any CommOverlap consumers)
// is released, so the communicator outlives whichever owner is destroyed
// first.
using NcclCommSharedPtr = std::shared_ptr<std::remove_pointer<ncclComm_t>::type>;
Member

This file does not include nccl.h directly.

# Bulk overlaps require the Userbuffers backend; the cuBLASMp backend
# falls back to async NCCL ops via torch.distributed.
- bulk_available = not is_userbuffer_cublasmp_backend()
+ bulk_available = not using_cublasmp_backend()
Member

This will fail since it requires the UB communicators to be initialized, but this is called in every case, including single gpu. It should be guarded via the UB usage options.

cudaEventDestroy(_comm_launch_event);
}
if (_with_cublasmp) {
nvte_comm_gemm_ctx_destroy(_cublasmp_ctx);
Member

Does this touch the nccl communicator? If so, then the shared pointer is not good enough since its destructor is going to be called before the destructor of the base class and so the communicator is going to be destroyed before this function gets called.

Collaborator Author

This does not touch the communicator. See here:

void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) {
  NVTE_API_CALL(nvte_comm_gemm_ctx_destroy);
  if (ctx->workspace) {
    NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
    NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
  }
  delete ctx;
}

The structure is deleted but it never invokes ncclCommDestroy.

cublasMpFree also does not touch the communicator. cuBLASMp assumes the communicator is an external resource. So it should be okay to keep the shared pointers at the TE/PyTorch level and clean them up where they're created in the first place.


// Build a TensorWrapper for the prepare stage. Operand contents are
// uninitialized at this point and no kernels are launched.
static TensorWrapper prepare_operand_tensor(Buffer_Type buf, Buffer_Type scale_inv,
Member

From Codex: this might not work for Hopper FP8 since it only creates the rowwise tensor, whereas Hopper can pass columnwise tensors to cuBLASMp. Not sure if this is a valid concern - if we only need to run the warmup on anything in order to get everything ready and there is no input dependency then passing only rowwise is fine.

Collaborator Author

I've been testing on Hopper and Blackwell, and it's been working on both, so I don't think there's a dependency there on input dtype. The missing allocation we're using the warmup GEMM for is a CopyEngine RMA signal buffer, and the signal dtype is independent of the operands.

raise AssertionError(result.stderr.decode())


@pytest.mark.parametrize("use_cublasmp", (False, True))
Member

You should skip cublasmp if support for it was not built in.
