-
Notifications
You must be signed in to change notification settings - Fork 728
[Common] Comm+GEMM overlap API updated to support cuBlasMp backend (incl. framework API) #2443
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
177b2ec
7d46b0b
6d4a141
35d0f19
dd8eaf3
d79bf21
ee517d3
51b64fb
9be771c
4cec043
898cf30
422a654
6e42235
626dd1d
d44cfc4
e341a8b
6942d20
bef5c7e
81d6383
6c6cc4d
9ed2adf
ca913b9
c55626d
f863ba8
5a8c7ae
5b9df92
775df95
441472a
3df11fc
58f1e68
f05f849
f95f229
f84e8f9
c67c183
e9c79a3
1b8fb1e
caa741e
9cca8a9
ff4187c
218257f
c2af15b
f75d98e
c208d83
5bd8ff9
509c12e
a51bd3b
b0bbe6d
04c52ca
4ea7334
6d6c7b2
f4740ea
ee80f69
85292f3
0cdbd6a
80b0a71
cc25997
0b4ecba
f753353
cf54c14
deb0890
f959f34
89f5d8d
67521f7
6d5ca20
8bcdaff
e90498d
a77c914
8d254af
a5c9117
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,11 +4,14 @@ | |
| """Shared functions for the collective GEMM tests""" | ||
|
|
||
| import argparse | ||
| import glob | ||
| import os | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| import numpy as np | ||
| from jax.experimental import mesh_utils | ||
| from jax.experimental.multihost_utils import sync_global_devices | ||
|
|
||
| from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap | ||
|
|
||
|
|
@@ -56,9 +59,9 @@ def assert_allclose(actual, desired, rtol=None, atol=None, dtype=None, **kwargs) | |
| tols["atol"] = atol | ||
|
|
||
| if not isinstance(actual, float): | ||
| actual = actual.astype(jnp.float32) | ||
| actual = np.asarray(actual, dtype=np.float32) | ||
| if not isinstance(desired, float): | ||
| desired = desired.astype(jnp.float32) | ||
| desired = np.asarray(desired, dtype=np.float32) | ||
|
|
||
| np.testing.assert_allclose(actual, desired, **tols, **kwargs) | ||
|
|
||
|
|
@@ -96,6 +99,20 @@ def _initialize_distributed(args): | |
|
|
||
| assert args.num_devices_per_process == 1, "Only single process single GPU is supported!" | ||
|
|
||
| # cuBLASMp issues NCCL collectives on its own communication stream | ||
| # inside the GEMM custom call. Add COLLECTIVES so XLA captures those | ||
| # ops alongside the custom call instead of invalidating the capture. | ||
| # Lower the min-graph-size to 1 so single-matmul modules also get | ||
| # captured -- otherwise small test cases skip the captured path. | ||
| # Userbuffers does not need either flag. | ||
| if args.use_cublasmp: | ||
| xla_flags = os.environ.get("XLA_FLAGS", "") | ||
| os.environ["XLA_FLAGS"] = ( | ||
| xla_flags | ||
| + " --xla_gpu_enable_command_buffer=+COLLECTIVES" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here, we need to make sure that CUSTOM CALL is captured by default.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CUSTOM_CALL is already part of the defaults for command buffers, so I don't believe we need to add it ourselves here. The issue is that cuBLASMp has NCCL calls in it for device-initiated communication, and XLA exempts NCCL collectives from graph capture, so when NCCL collectives appear inside a CUSTOM_CALL chunk, we end up with a CUDA_ERROR_STREAM_CAPTURE_INVALIDATED error. Adding COLLECTIVES to the command buffers is the only way I'm aware that gets around this issue. I don't know of any XLA API or flag that selectively enables command buffers for NCCL collectives only inside a CUSTOM_CALL.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But then don't you need this mark this custom op with "collective"?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Commit message for this says: This doesn't sound like what we want for CollectiveGemm. This is not a custom collective. It's a compute op that has some internal communication (strictly P2P in TE/JAX, not collective) with a specific order-of-operations dependencies with the compute. We want XLA to treat it opaquely, and we want it to be invoked with the low priority compute stream that's used with the GEMM chunks while the internal high priority collective streams in Userbuffers or cuBLASMp are used for the overlapped communication (XLA does not know about these internal streams and doesn't need to). |
||
| + " --xla_gpu_graph_min_graph_size=1" | ||
| ) | ||
|
|
||
| print( | ||
| f"Initializing JAX distributed with coordinator={args.coordinator_address}, " | ||
| f"num_processes={args.num_processes}, process_id={args.process_id}" | ||
|
|
@@ -118,6 +135,20 @@ def _initialize_distributed(args): | |
| devices_per_process = 1 | ||
| num_total_devices = args.num_processes | ||
|
|
||
| # Remove stale NCCL unique ID files from previous (possibly crashed) runs. | ||
| # These files are used for one-time coordination during bootstrap; stale files | ||
| # cause non-leader processes to read an old unique ID, breaking NCCL init. | ||
| # Only process 0 performs the cleanup; a global barrier ensures all processes | ||
| # wait for the cleanup to complete before any TP leader writes a fresh file. | ||
| nccl_base_path = os.environ.get("NVTE_JAX_NCCL_FILE_PATH", "/tmp") | ||
| if args.process_id == 0: | ||
| for f in glob.glob(os.path.join(nccl_base_path, "nccl_*_unique_id_*.bin")): | ||
| try: | ||
| os.remove(f) | ||
| except OSError: | ||
| pass | ||
| sync_global_devices("nccl_id_cleanup") | ||
|
|
||
| print( | ||
| f"Initializing CGEMM communicator with num_total_devices={num_total_devices}," | ||
| f" devices_per_process={devices_per_process}, process_id={args.process_id}" | ||
|
|
@@ -128,6 +159,7 @@ def _initialize_distributed(args): | |
| num_devices_per_process=devices_per_process, | ||
| process_id=args.process_id, | ||
| tensor_parallel_size=args.tensor_parallel_size, | ||
| use_cublasmp=args.use_cublasmp, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -199,6 +231,16 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para | |
| parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing") | ||
| parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension") | ||
| parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension") | ||
| parser.add_argument( | ||
| "--std", | ||
| type=float, | ||
| default=0.023, | ||
| help=( | ||
| "Standard deviation for input/weight/bias tensors. Matches TE/PyTorch's" | ||
| " run_gemm_with_overlap.py default so both frameworks evaluate FP8 noise" | ||
| " on equal footing." | ||
| ), | ||
| ) | ||
| parser.add_argument( | ||
| "--collective-type", | ||
| type=str, | ||
|
|
@@ -224,5 +266,11 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para | |
| parser.add_argument( | ||
| "--enable-result-check", action="store_true", default=True, help="Enable result checking" | ||
| ) | ||
| parser.add_argument( | ||
| "--use-cublasmp", | ||
| action="store_true", | ||
| default=False, | ||
| help="Use the cuBLASMp backend for overlapping collective operations with GEMM computation", | ||
| ) | ||
|
|
||
| return parser | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,19 +5,31 @@ | |
| """config for collective_gemm tests""" | ||
| import pytest | ||
|
|
||
| import transformer_engine.jax # noqa: F401 - must load libtransformer_engine.so before transformer_engine_jax | ||
| from transformer_engine_jax import nvte_built_with_cublasmp | ||
|
|
||
|
|
||
| def pytest_addoption(parser): | ||
| """Pytest hook for collective_gemm tests""" | ||
| parser.addoption("--coordinator-address", action="store", default="localhost:12345") | ||
| parser.addoption("--num-processes", action="store", default=1) | ||
| parser.addoption("--process-id", action="store", default=0) | ||
| parser.addoption("--local-device-ids", action="store", default=None) | ||
| parser.addoption("--use-cublasmp", action="store_true", default=False) | ||
|
|
||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def distributed_args(request): | ||
| """Fixture for querying distributed initialization arguments""" | ||
| if request.cls: | ||
| use_cublasmp = request.config.getoption("--use-cublasmp") | ||
| if use_cublasmp and not nvte_built_with_cublasmp(): | ||
| pytest.skip( | ||
| "Collective GEMM cuBLASMp backend tests require Transformer Engine to be built " | ||
| "with NVTE_WITH_CUBLASMP=1." | ||
| ) | ||
|
Comment on lines
+25
to
+30
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we going to build CUBLASMP in our CI images by default?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would like to but only starting with 26.05. We need NCCL 2.30+ for cuBLASMp to be graph-safe. I'm told JAX containers are going to satisfy that requirement starting with 26.05 so we can safely install cuBLASMp without breaking anything else. |
||
| if use_cublasmp and "mxfp8" in request.node.name.lower(): | ||
| pytest.skip("MXFP8 is not supported by the cuBLASMp backend wrappers in TE/common.") | ||
| request.cls.coordinator_address = request.config.getoption("--coordinator-address") | ||
| request.cls.num_processes = int(request.config.getoption("--num-processes")) | ||
| request.cls.process_id = int(request.config.getoption("--process-id")) | ||
|
|
@@ -27,3 +39,4 @@ def distributed_args(request): | |
| if request.cls.local_device_ids is None | ||
| else len(request.cls.local_device_ids.split(",")) | ||
| ) | ||
| request.cls.use_cublasmp = use_cublasmp | ||
Uh oh!
There was an error while loading. Please reload this page.