[Common] Comm+GEMM overlap API updated to support cuBlasMp backend (incl. framework API) #2443
denera wants to merge 57 commits
@@ -17,6 +18,12 @@
#define NVTE_COMM_OVERLAP_MAX_STREAMS 3
/* \brief Check if TE is built with cuBlasMp.
@@ -526,6 +514,11 @@ class CommOverlapHelper : torch::CustomClassHolder {
                      ExtComm comm);
  void ub_barrier(ExtComm comm);
  int64_t get_nccl_comm_ptr(std::string comm_name) {
    NVTE_CHECK(backend_is_nccl, "Cannot get nccComm_t ptr if backend is not NCCL.");
This error message could be more descriptive - e.g. something like "chosen backend for the communication-computation overlap (cuBLASMp) requires NCCL communicator, but the passed ProcessGroup uses a different backend."
Greptile Summary
This PR extends the Comm+GEMM overlap API in TransformerEngine to support a cuBLASMp backend alongside the existing Userbuffers backend, covering both PyTorch and JAX framework paths. The cuBLASMp path uses its own NCCL communicator (initialized in
Confidence Score: 3/5
Not safe to merge: the num_splits default change silently alters Userbuffers pipeline behaviour for existing callers, and several carry-over issues remain unresolved.
The CommOverlap Python binding now exposes a different default for num_splits (3 to 4) without any changelog entry, silently breaking backward-compatible usage. Carry-over defects (absent intra NCCL comm for single-node cuBLASMp, unconditional nccl.h in public headers, exception-unsafe warmup allocations, and the NCCL file-path isolation regression in the JAX helper) represent a cluster of correctness and portability risks across the PR.
transformer_engine/pytorch/csrc/extensions/pybind.cpp (num_splits default change), transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp (warmup workspace exception safety), transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h and transformer_engine/pytorch/csrc/common.h (unconditional nccl.h in public headers).
Reviews (27): Last reviewed commit: "handling ncclComm_t via shared pointers ..."
Additional Comments (8)
- transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 335 (logic): Variable shadowing bug: k is assigned k * _tp_size where k appears on both sides. Should be k = k_local * _tp_size.
- transformer_engine/jax/csrc/extensions/cgemm_helper.cpp, line 135 (logic): Invalid reinterpret_cast: cannot cast an int* (pointer) to int (value). Should be reinterpret_cast<void**>(&handler._device_barrier).
- transformer_engine/pytorch/csrc/extensions.h, line 517 (syntax): Stray character "a" that will cause compilation failure.
- transformer_engine/pytorch/csrc/extensions.h, line 537-540 (logic): Constructor parameter mismatch: CommOverlapBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but called with (nccl_comm_ptr, tp_size, tp_rank, ...). Order of tp_rank and tp_size is swapped.
- transformer_engine/pytorch/csrc/extensions.h, line 563-566 (logic): Constructor parameter mismatch: CommOverlapP2PBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but called with (nccl_comm_ptr, tp_size, tp_rank, ...). Order of tp_rank and tp_size is swapped.
- transformer_engine/jax/csrc/extensions/cgemm_helper.cpp, line 196-199 (logic): Constructor parameter mismatch: CommOverlapP2PBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but called with (nccl_comm_ptr, tp_size, tp_domain_id, ...). Should use tp_rank instead of tp_domain_id. Should this use tp_rank (local_device_id_within_tp_domain) or node_id (tp_domain_id)?
- tests/pytorch/distributed/run_gemm_with_overlap.py, line 416-418 (style): Unconditional initialization: local_kernel2_t_shape is only used when ub_obj2 is not None, but it's always initialized here. This creates an unused variable in most cases.
- tests/pytorch/distributed/run_gemm_with_overlap.py, line 490 (style): Incorrect initialization: ref2_g should only be assigned when ub_obj2 is not None, but it's unconditionally assigned to (0,) which doesn't match the expected tensor type.
18 files reviewed, 8 comments
…rk extensions Signed-off-by: Alp Dener <adener@nvidia.com>
…entirely Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Additional Comments (3)
- transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 87 (logic): return value from nvte_comm_gemm_ctx_create() is not assigned to _cublasmp_ctx, causing null pointer when destructor calls nvte_comm_gemm_ctx_destroy(_cublasmp_ctx) on line 173
- transformer_engine/jax/cpp_extensions/gemm.py, line 819 (style): parameter order changed - collective_op moved after transpose_batch_sequence, sequence_dim, is_outer. Verify this matches the calling convention and doesn't break compatibility. Was this parameter reordering intentional, and have all call sites been verified?
- transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 89-91 (logic): _tp_size field is not initialized in the cuBlasMp constructor, but it's used in cublasmp_ag_gemm() and cublasmp_gemm_rs() methods (lines 321, 332). Add initialization
18 files reviewed, 3 comments
Additional Comments (6)
- transformer_engine/pytorch/module/base.py, line 415-417 (logic): Parameter order is incorrect - the C++ constructor signature is (helper, tp_rank, tp_size, ...) but Python is passing (helper, tp_size, local_rank, ...). This swaps tp_rank and tp_size, causing incorrect initialization.
- transformer_engine/pytorch/module/base.py, line 387-389 (logic): Parameter order is incorrect - the C++ constructor signature is (helper, tp_rank, tp_size, ...) but Python is passing (helper, tp_size, local_rank, ...). This swaps tp_rank and tp_size, causing incorrect initialization.
- tests/pytorch/distributed/run_gemm_with_overlap.py, line 340-344 (logic): Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.
- tests/pytorch/distributed/run_gemm_with_overlap.py, line 355-359 (logic): Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.
- tests/pytorch/distributed/run_gemm_with_overlap.py, line 383 (logic): Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.
- tests/pytorch/distributed/run_gemm_with_overlap.py, line 394 (logic): Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.
19 files reviewed, 6 comments
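The repeated tp_rank/tp_size swaps flagged above are positional-argument mistakes. As a minimal sketch of one way such swaps can be made harder to write, the following uses keyword-only arguments plus a range check. OverlapConfig and make_config are hypothetical names for illustration, not part of the TE API, and whether the real bindings accept keyword arguments depends on how they are declared, which this thread does not show.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class OverlapConfig:
    """Hypothetical stand-in for the (tp_rank, tp_size) pair of an overlap object."""
    tp_rank: int
    tp_size: int

    def __post_init__(self):
        # A rank must index into the group, so a swapped pair of otherwise
        # valid values (rank < size) always fails this check.
        if not 0 <= self.tp_rank < self.tp_size:
            raise ValueError(
                f"tp_rank={self.tp_rank} is out of range for tp_size={self.tp_size}"
            )


def make_config(*, tp_rank, tp_size):
    """Keyword-only arguments: every call site has to name each value."""
    return OverlapConfig(tp_rank=tp_rank, tp_size=tp_size)


# The order at the call site no longer matters once the arguments are named.
cfg = make_config(tp_size=8, tp_rank=3)
```

A positional call such as make_config(3, 8) raises TypeError, and constructing OverlapConfig with the two values swapped raises ValueError.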
Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
for (int64_t col = 0; col < cols; ++col) {
  const auto* send = reinterpret_cast<const uint8_t*>(local.data() + col * local_rows);
  auto* recv = reinterpret_cast<uint8_t*>(gathered.data() + col * global_rows);
  CHECK_MPI(MPI_Allgatherv(send, static_cast<int>(local_rows * sizeof(T)), MPI_BYTE, recv,
Do we actually support a different number of rows on each GPU, such that we need this scheme both here and in the column allgather?
I don't think we ever have a use case in TE (on the framework side) where the distributed matrix dim is not divisible by the TP size, but I believe cuBLASMp supports arbitrary global matrix dims without assuming even divisibility. The original DistributeTensors fixture in the CPP test suite sets a row/col start and a row/col count for each operand on each calling rank. My modifications to the CPP tests respect that existing convention while making sure we compare overlapped vs. non-overlapped distributed GEMM for AG/RS/AR, instead of overlapped distributed GEMM vs. global GEMM (which was causing numerical discrepancies before because the reference output was missing the reduction).
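To make the uneven-rows case concrete, here is a small pure-Python simulation of a per-column variable-count gather over column-major blocks. The partition rule (extra rows go to the lowest ranks) and the helper names partition_rows and allgatherv_columns are assumptions for illustration; they are not taken from the DistributeTensors fixture or from MPI.

```python
def partition_rows(global_rows, nranks):
    """Split global_rows across nranks; the first (global_rows % nranks) ranks get one extra row."""
    base, extra = divmod(global_rows, nranks)
    counts = [base + (1 if r < extra else 0) for r in range(nranks)]
    displs = [sum(counts[:r]) for r in range(nranks)]
    return counts, displs


def allgatherv_columns(local_blocks, cols):
    """Simulate a per-column variable-count gather of column-major local blocks.

    local_blocks[r] is rank r's block, with element (i, j) stored at j * local_rows + i.
    Returns the gathered column-major global matrix as a flat list.
    """
    counts = [len(block) // cols for block in local_blocks]
    displs = [sum(counts[:r]) for r in range(len(counts))]
    global_rows = sum(counts)
    gathered = [None] * (global_rows * cols)
    for col in range(cols):
        for r, block in enumerate(local_blocks):
            src = col * counts[r]                # start of this column in the local block
            dst = col * global_rows + displs[r]  # start of rank r's rows in the global column
            gathered[dst:dst + counts[r]] = block[src:src + counts[r]]
    return gathered


# 5 rows over 2 ranks: rank 0 holds 3 rows, rank 1 holds 2.
rows, cols, nranks = 5, 2, 2
counts, displs = partition_rows(rows, nranks)
global_cm = [10 * i + j for j in range(cols) for i in range(rows)]  # column-major
local_blocks = [
    [global_cm[j * rows + displs[r] + i] for j in range(cols) for i in range(counts[r])]
    for r in range(nranks)
]
gathered = allgatherv_columns(local_blocks, cols)
```

Because each rank contributes a different number of rows per column, a fixed-count gather would not place the blocks correctly; the per-rank counts and displacements do that work.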
// - AllReduce: BIAS applied per-rank before AR, output = sum_r(A_r@B_r + bias) = sum_r(A_r@B_r) + nranks*bias
//   -> no bias correction needed in ref
It's not wrong at least for now. This is a known bug in cuBLASMp and the fix is rolling out with v0.9.
I can gate the cuBLASMp versions in the CPP test and apply the bias correction to AR for v0.9+ so the tests don't break when our CI container bases update.
Wait, but isn't this behavior wrong? So we should guard against using this with bias and instead do bias manually for cuBLASMp < 0.9, right?
Yeah, the behavior is wrong, but the bug is only in the all-reduce overlap, which we don't actually use or expose in the framework bindings.
Reduce-scatter overlap in cuBLASMp correctly applies the bias only once after the reduction so there's no need to guard against it.
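The identity quoted in the test comment, sum_r(A_r@B_r + bias) = sum_r(A_r@B_r) + nranks*bias, can be checked with a small pure-Python sketch. The shapes, values, and helper names below are made up for illustration and do not call cuBLASMp; they only contrast applying the bias on every rank before the reduction with applying it once afterwards.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add_bias(m, bias):
    """Add a per-column bias vector to every row of m."""
    return [[v + bv for v, bv in zip(row, bias)] for row in m]


def all_reduce_sum(per_rank):
    """Element-wise sum of equally shaped matrices, one per rank."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*per_rank)]


nranks = 2
a_shards = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # A_r: M x K_local
b_shards = [[[1, 0], [0, 1]], [[2, 1], [1, 2]]]  # B_r: K_local x N
bias = [10, 20]

partials = [matmul(a, b) for a, b in zip(a_shards, b_shards)]

# Bias applied once, after the reduction.
expected = add_bias(all_reduce_sum(partials), bias)

# Bias applied on every rank before the reduction (the behaviour described as a bug above).
per_rank_bias = all_reduce_sum([add_bias(p, bias) for p in partials])

# The difference is (nranks - 1) * bias in every row.
excess = [[p - e for p, e in zip(pr, er)] for pr, er in zip(per_rank_bias, expected)]
```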
ctx.requires_wgrad
and ctx.ub_overlap_ag
and ctx.ub_obj_gradout is not None
and ctx.ub_obj_gradout.with_cublasmp()
ctx is not used here anymore - it should reference bwd_args.
Yeah, I just saw this on my end too and fixed it. Catching branch up with main again and running tests. Will push up the fixes when everything's coming back green.
#ifdef NVTE_WITH_CUBLASMP
NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ",
           "with valid process groups!");
NVTE_CHECK(backend_is_nccl,
Codex raised an issue (which I think is correct): in the if statements at base.py:238 we default to MPI or Gloo, not NCCL, when those are available and the backend is not set explicitly. This check, though, requires the backend to be NCCL, so we should reconcile the two.
Yeah, this is a valid catch, but the fix is removing the NCCL requirement entirely. The check was leftover from when the bootstrapping was attempting to get a NCCL communicator directly out of the PyTorch PG, but we don't do that anymore. We simply create a new NCCL comm that spans the same devices/ranks as the PyTorch PG instead, and the PyTorch PG itself can have any backend in that case.
I'll make the fix on my end so it pushes up with the rest. Thanks!
Ok, so should this check then be relaxed? It is still here.
}
TensorWrapper pre_gelu_out_(get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING));
// Match GemmV2FFI's operand swap: rhs becomes A, lhs becomes B.
cudaStream_t prepare_stream = cudaStreamPerThread;
You should synchronize that stream after calling that warmup iteration.
…sed and tests modified to account for them Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
…ling in cuBLASMp bindings, added atol and rtol args to TE/PyTorch comm+GEMM runner scripts Signed-off-by: Alp Dener <adener@nvidia.com>
…ormerEngine into common/tp-overlap-cublasmp
for more information, see https://pre-commit.ci
xla_flags = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = (
    xla_flags
    + " --xla_gpu_enable_command_buffer=+COLLECTIVES"
Here, we need to make sure that CUSTOM_CALL is captured by default.
Would it be better to add a check for that?
CUSTOM_CALL is already part of the defaults for command buffers, so I don't believe we need to add it ourselves here.
The issue is that cuBLASMp has NCCL calls in it for device-initiated communication, and XLA exempts NCCL collectives from graph capture, so when NCCL collectives appear inside a CUSTOM_CALL chunk, we end up with a CUDA_ERROR_STREAM_CAPTURE_INVALIDATED error.
Adding COLLECTIVES to the command buffers is the only way I'm aware that gets around this issue. I don't know of any XLA API or flag that selectively enables command buffers for NCCL collectives only inside a CUSTOM_CALL.
But then don't you need to mark this custom op with "collective"?
Something like this tensorflow/tensorflow@104721c
Commit message for this says:
Currently all FFIs are treated as opaque kernels which are further treated as compute kernels in LHS.
...
It introduces a annotation to specific compute_on="gpu_stream:collective", then it can be scheduled in the same way as native collectives and stream assignment will give it the collective stream which is of high priority and wont overlap with other collectives.
This doesn't sound like what we want for CollectiveGemm. This is not a custom collective. It's a compute op that has some internal communication (strictly P2P in TE/JAX, not collective) with specific order-of-operations dependencies on the compute. We want XLA to treat it opaquely, and we want it to be invoked on the low-priority compute stream that's used for the GEMM chunks, while the internal high-priority collective streams in Userbuffers or cuBLASMp are used for the overlapped communication (XLA does not know about these internal streams and doesn't need to).
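For reference, the flag handling discussed in this thread can be sketched as a small helper that appends the command-buffer flag only when it is not already present, so repeated test setup does not keep growing XLA_FLAGS. The helper name with_flag is hypothetical, the flag string is the one from the diff above, and the variable typically has to be set before JAX initializes its backend.

```python
import os

COLLECTIVES_FLAG = "--xla_gpu_enable_command_buffer=+COLLECTIVES"


def with_flag(existing, flag):
    """Return the flag string with `flag` appended unless it is already present."""
    tokens = existing.split()
    if flag in tokens:
        return existing
    return " ".join(tokens + [flag])


# Preserve whatever the user already set and add the flag at most once.
os.environ["XLA_FLAGS"] = with_flag(os.environ.get("XLA_FLAGS", ""), COLLECTIVES_FLAG)
```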
use_cublasmp = request.config.getoption("--use-cublasmp")
if use_cublasmp and not nvte_built_with_cublasmp():
    pytest.skip(
        "Collective GEMM cuBLASMp backend tests require Transformer Engine to be built "
        "with NVTE_WITH_CUBLASMP=1."
    )
Are we going to build CUBLASMP in our CI images by default?
I would like to but only starting with 26.05.
We need NCCL 2.30+ for cuBLASMp to be graph-safe. I'm told JAX containers are going to satisfy that requirement starting with 26.05 so we can safely install cuBLASMp without breaking anything else.
auto *executor = CollectiveGemmPlanRegistry::getInstance().get_executor(
    buffer_shape, buffer_dtype, config.collective_op);

// Run a dummy cuBLASMp matmul in the prepare stage so its lazy NCCL
Could we avoid doing any real GEMM computation here and only initialize the NCCL memory/windows?
Ideally yes, but currently no. ncclPutSignal and ncclWaitSignal have a known bug: they lazily initialize the RMA CE signal buffers, so they are not CUDA graph-safe on first invocation. That forces us to do a dummy cuBLASMp GEMM call in the prepare phase just to trigger that lazy init, so that invocations in the execute phase are graph-safe.
NCCL is aware of the bug and they're working on it. We can remove this dummy matmul here as soon as that bug is fixed.
…same mean/std as TE/PyTorch when generating random operands Signed-off-by: Alp Dener <adener@nvidia.com>
…after dummy cuBLASMp call, version-guarded C++ all-reduce tests against CUBLASMP_VERSION >= 900 Signed-off-by: Alp Dener <adener@nvidia.com>
…ormerEngine into common/tp-overlap-cublasmp
for more information, see https://pre-commit.ci
initialized = false;
#ifdef NVTE_WITH_CUBLASMP
for (auto &comm : nccl_comms) {
  NVTE_CHECK_NCCL(ncclCommDestroy(comm.second));
Are we sure that those communicators will not be used after the initialize_ub call finishes? The helper is going to be destroyed at that point, and the communicators along with it.
You're right, these communicators are absolutely used by cuBLASMp after initialize_ub() and keeping them in the helper risks losing them.
I haven't seen this actually happen in practice, but I suspect that's because PyBind increments the reference count when the CommOverlapHelper is passed into the initialization for CommOverlap or CommOverlapP2P. Garbage collection does not wipe the helper out as a result even when it goes out of scope at the end of initialize_ub().
It's still not a good design, though. It means that we actually leak this helper object: we can no longer access it in Python once it goes out of scope in initialize_ub(), but it also doesn't get garbage collected because the reference count does not reach zero.
I'll see about creating the NCCL communicators in the CommOverlap/CommOverlapP2P init so that they get destroyed only when the overlap object is destroyed.
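The lifetime behaviour suspected in this reply can be illustrated with plain Python objects. This is only an analogy under stated assumptions: Helper and Overlap are hypothetical stand-ins, the overlap object is assumed to hold a reference to its helper, and the deterministic timing relies on CPython reference counting. It does not show what the PyBind bindings actually do.

```python
import weakref


class Helper:
    """Hypothetical stand-in for CommOverlapHelper; owns a mock communicator."""

    def __init__(self):
        self.comm_alive = True

    def __del__(self):
        self.comm_alive = False  # stands in for ncclCommDestroy


class Overlap:
    """Hypothetical stand-in for CommOverlap; keeps a reference to its helper."""

    def __init__(self, helper):
        self._helper = helper


def initialize():
    helper = Helper()
    # The local name `helper` goes out of scope on return, but the Overlap
    # object still references the Helper, so it is not collected.
    return Overlap(helper), weakref.ref(helper)


overlap, helper_ref = initialize()
alive_after_init = helper_ref() is not None  # kept alive by overlap._helper
del overlap
alive_after_del = helper_ref() is not None   # the last reference is gone
```

In this sketch the helper is unreachable by name after initialize() returns, yet it survives until the overlap object is released, which matches the situation described above.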
…up in CommOverlap destructor to avoid leaking on exception during warmup Signed-off-by: Alp Dener <adener@nvidia.com>
…ormerEngine into common/tp-overlap-cublasmp
for more information, see https://pre-commit.ci
…rematurely destroyed when CommOverlapHelper goes out of scope Signed-off-by: Alp Dener <adener@nvidia.com>
class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase {
 private:
  void *_warmup_workspace{nullptr};
Why do we keep this pointer throughout the execution if it is only needed during warmup?
NVTE_CHECK_CUDA(cudaMalloc(&a_ptr, a_bytes));
NVTE_CHECK_CUDA(cudaMalloc(&b_ptr, b_bytes));
NVTE_CHECK_CUDA(cudaMalloc(&d_ptr, d_bytes));
NVTE_CHECK_CUDA(cudaMalloc(&warmup_workspace, a_bytes + b_bytes + d_bytes));
Ok, this is not right (well, it is actually resulting in a ~ok execution, but the intent is not correctly coded).
This warmup_workspace here is a local variable, so the cudaMalloc is modifying it rather than the _warmup_workspace from the class -> that variable stays as nullptr. If you really need to keep this pointer around then it should be passed as a pointer to pointer in order to be writable.
Now, if you just did that then that would expose another problem - at the end of this function you do cudaFree without setting the pointer back to nullptr, and so in the commoverlap destructor you would have a double free issue.
I don't see why you need to keep the workspace pointer, so probably just getting rid of that and keeping the pointers here completely local is the right choice, but please let me know why you wanted to keep that pointer in the first place.
> I don't see why you need to keep the workspace pointer, so probably just getting rid of that and keeping the pointers here completely local is the right choice, but please let me know why you wanted to keep that pointer in the first place.
This is what I had originally but Greptile complained about it leaking if the warmup GEMM call throws an exception.
I agree with you that I'd rather just have it be purely local, allocated before warmup GEMM and freed after. If you're okay with ignoring Greptile's warning about it, I can revert back to what it was before.
// the last reference (held by the helper and/or any CommOverlap consumers)
// is released, so the communicator outlives whichever owner is destroyed
// first.
using NcclCommSharedPtr = std::shared_ptr<std::remove_pointer<ncclComm_t>::type>;
This file does not include nccl.h directly.
# Bulk overlaps require the Userbuffers backend; the cuBLASMp backend
# falls back to async NCCL ops via torch.distributed.
bulk_available = not is_userbuffer_cublasmp_backend()
bulk_available = not using_cublasmp_backend()
This will fail since it requires the UB communicators to be initialized, but this is called in every case, including single gpu. It should be guarded via the UB usage options.
  cudaEventDestroy(_comm_launch_event);
}
if (_with_cublasmp) {
  nvte_comm_gemm_ctx_destroy(_cublasmp_ctx);
Does this touch the nccl communicator? If so, then the shared pointer is not good enough since its destructor is going to be called before the destructor of the base class and so the communicator is going to be destroyed before this function gets called.
This does not touch the communicator. See here:
void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) {
  NVTE_API_CALL(nvte_comm_gemm_ctx_destroy);
  if (ctx->workspace) {
    NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
    NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
  }
  delete ctx;
}
The structure is deleted but it never invokes ncclCommDestroy.
cublasMpFree also does not touch the communicator. cuBLASMp assumes the communicator is an external resource. So it should be okay to keep the shared pointers at the TE/PyTorch level and clean them up where they're created in the first place.
// Build a TensorWrapper for the prepare stage. Operand contents are
// uninitialized at this point and no kernels are launched.
static TensorWrapper prepare_operand_tensor(Buffer_Type buf, Buffer_Type scale_inv,
From Codex: this might not work for Hopper FP8 since it only creates the rowwise tensor, whereas Hopper can pass columnwise tensors to cuBLASMp. Not sure if this is a valid concern - if we only need to run the warmup on anything in order to get everything ready and there is no input dependency then passing only rowwise is fine.
I've been testing on Hopper and Blackwell, and it's been working on both, so I don't think there's a dependency there on input dtype. The missing allocation we're using the warmup GEMM for is a CopyEngine RMA signal buffer, and the signal dtype is independent of the operands.
raise AssertionError(result.stderr.decode())

@pytest.mark.parametrize("use_cublasmp", (False, True))
You should skip cublasmp if support for it was not built in.
Description
This PR adds support for the NVTE cuBlasMp bindings in the Comm+GEMM overlap API.