diff --git a/build_tools/jax.py b/build_tools/jax.py index a7b200f915..5d9276b5e6 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -103,6 +103,9 @@ def setup_jax_extension( setup_mpi_flags(include_dirs, cxx_flags) + if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))): + cxx_flags.append("-DNVTE_WITH_CUBLASMP") + # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 533addaf53..e2e6d09c29 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -6,6 +6,7 @@ import os from pathlib import Path +from importlib import metadata import setuptools @@ -88,6 +89,9 @@ def setup_pytorch_extension( libraries.append("nvshmem_host") cxx_flags.append("-DNVTE_ENABLE_NVSHMEM") + if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))): + cxx_flags.append("-DNVTE_WITH_CUBLASMP") + # Construct PyTorch CUDA extension sources = [str(path) for path in sources] include_dirs = [str(path) for path in include_dirs] diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py index 6815932395..479f56dd9a 100644 --- a/examples/jax/collective_gemm/common.py +++ b/examples/jax/collective_gemm/common.py @@ -4,11 +4,14 @@ """Shared functions for the collective GEMM tests""" import argparse +import glob +import os import jax import jax.numpy as jnp import numpy as np from jax.experimental import mesh_utils +from jax.experimental.multihost_utils import sync_global_devices from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap @@ -56,9 +59,9 @@ def assert_allclose(actual, desired, rtol=None, atol=None, dtype=None, **kwargs) tols["atol"] = atol if not isinstance(actual, float): - actual = actual.astype(jnp.float32) + actual = np.asarray(actual, dtype=np.float32) if not isinstance(desired, float): - desired = desired.astype(jnp.float32) + desired = np.asarray(desired, dtype=np.float32) np.testing.assert_allclose(actual, desired, **tols, **kwargs) @@ -96,6 +99,20 @@ def _initialize_distributed(args): assert args.num_devices_per_process == 1, "Only single process single GPU is supported!" + # cuBLASMp issues NCCL collectives on its own communication stream + # inside the GEMM custom call. Add COLLECTIVES so XLA captures those + # ops alongside the custom call instead of invalidating the capture. + # Lower the min-graph-size to 1 so single-matmul modules also get + # captured -- otherwise small test cases skip the captured path. + # Userbuffers does not need either flag. + if args.use_cublasmp: + xla_flags = os.environ.get("XLA_FLAGS", "") + os.environ["XLA_FLAGS"] = ( + xla_flags + + " --xla_gpu_enable_command_buffer=+COLLECTIVES" + + " --xla_gpu_graph_min_graph_size=1" + ) + print( f"Initializing JAX distributed with coordinator={args.coordinator_address}, " f"num_processes={args.num_processes}, process_id={args.process_id}" @@ -118,6 +135,20 @@ def _initialize_distributed(args): devices_per_process = 1 num_total_devices = args.num_processes + # Remove stale NCCL unique ID files from previous (possibly crashed) runs. + # These files are used for one-time coordination during bootstrap; stale files + # cause non-leader processes to read an old unique ID, breaking NCCL init. + # Only process 0 performs the cleanup; a global barrier ensures all processes + # wait for the cleanup to complete before any TP leader writes a fresh file. + nccl_base_path = os.environ.get("NVTE_JAX_NCCL_FILE_PATH", "/tmp") + if args.process_id == 0: + for f in glob.glob(os.path.join(nccl_base_path, "nccl_*_unique_id_*.bin")): + try: + os.remove(f) + except OSError: + pass + sync_global_devices("nccl_id_cleanup") + print( f"Initializing CGEMM communicator with num_total_devices={num_total_devices}," f" devices_per_process={devices_per_process}, process_id={args.process_id}" @@ -128,6 +159,7 @@ def _initialize_distributed(args): num_devices_per_process=devices_per_process, process_id=args.process_id, tensor_parallel_size=args.tensor_parallel_size, + use_cublasmp=args.use_cublasmp, ) @@ -199,6 +231,16 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing") parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension") parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension") + parser.add_argument( + "--std", + type=float, + default=0.023, + help=( + "Standard deviation for input/weight/bias tensors. Matches TE/PyTorch's" + " run_gemm_with_overlap.py default so both frameworks evaluate FP8 noise" + " on equal footing." + ), + ) parser.add_argument( "--collective-type", type=str, @@ -224,5 +266,11 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para parser.add_argument( "--enable-result-check", action="store_true", default=True, help="Enable result checking" ) + parser.add_argument( + "--use-cublasmp", + action="store_true", + default=False, + help="Use the cuBLASMp backend for overlapping collective operations with GEMM computation", + ) return parser diff --git a/examples/jax/collective_gemm/conftest.py b/examples/jax/collective_gemm/conftest.py index 5be5709ba7..d8ffbc2853 100644 --- a/examples/jax/collective_gemm/conftest.py +++ b/examples/jax/collective_gemm/conftest.py @@ -5,6 +5,9 @@ """config for collective_gemm tests""" import pytest +import transformer_engine.jax # noqa: F401 - must load libtransformer_engine.so before transformer_engine_jax +from transformer_engine_jax import nvte_built_with_cublasmp + def pytest_addoption(parser): """Pytest hook for collective_gemm tests""" @@ -12,12 +15,21 @@ def pytest_addoption(parser): parser.addoption("--num-processes", action="store", default=1) parser.addoption("--process-id", action="store", default=0) parser.addoption("--local-device-ids", action="store", default=None) + parser.addoption("--use-cublasmp", action="store_true", default=False) @pytest.fixture(autouse=True) def distributed_args(request): """Fixture for querying distributed initialization arguments""" if request.cls: + use_cublasmp = request.config.getoption("--use-cublasmp") + if use_cublasmp and not nvte_built_with_cublasmp(): + pytest.skip( + "Collective GEMM cuBLASMp backend tests require Transformer Engine to be built " + "with NVTE_WITH_CUBLASMP=1." + ) + if use_cublasmp and "mxfp8" in request.node.name.lower(): + pytest.skip("MXFP8 is not supported by the cuBLASMp backend wrappers in TE/common.") request.cls.coordinator_address = request.config.getoption("--coordinator-address") request.cls.num_processes = int(request.config.getoption("--num-processes")) request.cls.process_id = int(request.config.getoption("--process-id")) @@ -27,3 +39,4 @@ def distributed_args(request): if request.cls.local_device_ids is None else len(request.cls.local_device_ids.split(",")) ) + request.cls.use_cublasmp = use_cublasmp diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index c0a095d6b5..6c2f8e9b8a 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -23,6 +23,30 @@ else echo "NVLINK support detected" fi +echo "*** Checking cuBLASMp support in TE build ***" +CUBLASMP_SUPPORT=$(python3 - <<'PY' +try: + import transformer_engine.jax + from transformer_engine_jax import nvte_built_with_cublasmp +except Exception as exc: + print(f"error:{exc}") + raise SystemExit(0) + +print("1" if nvte_built_with_cublasmp() else "0") +PY +) + +if [[ "$CUBLASMP_SUPPORT" == "1" ]]; then + echo "cuBLASMp backend support detected" + BACKENDS=("cublasmp" "userbuffers") +elif [[ "$CUBLASMP_SUPPORT" == "0" ]]; then + echo "cuBLASMp backend support not detected; skipping cuBLASMp backend tests" + BACKENDS=("userbuffers") +else + echo "Failed to query cuBLASMp support from transformer_engine_jax: $CUBLASMP_SUPPORT" + exit 1 +fi + # Define individual test cases to run (file::class::method) # DelayedScalingFP8 and CurrentScalingFP8 use the same GEMM so we don't need to test both cases all # the time. @@ -93,50 +117,62 @@ for TEST_CASE in "${TEST_CASES[@]}"; do # Clear PIDs array for this test case PIDS=() - for i in $(seq 0 $(($NUM_GPUS - 1))); do - # Define output file for logs - LOG_FILE="${TEST_NAME}_gpu_${i}.log" - - if [ $i -eq 0 ]; then - # For process 0: show live output AND save to log file using tee - echo "=== Live output from process 0 ===" - pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_NAME}.xml \ - "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \ - --num-processes=$NUM_GPUS \ - --process-id=$i 2>&1 | tee "$LOG_FILE" & - PID=$! - PIDS+=($PID) + for BACKEND in "${BACKENDS[@]}"; do + echo "Setting backend to $BACKEND for test $TEST_NAME" + + for i in $(seq 0 $(($NUM_GPUS - 1))); do + # Define output file for logs + LOG_FILE="${TEST_NAME}_gpu_${i}_${BACKEND}.log" + + test_args=( + "--num-processes=$NUM_GPUS" + "--process-id=$i" + ) + if [ "$BACKEND" == "cublasmp" ]; then + test_args+=("--use-cublasmp") + fi + + if [ $i -eq 0 ]; then + # For process 0: show live output AND save to log file using tee + echo "=== Live output from process 0 ===" + pytest -s -c "${TE_PATH}/tests/jax/pytest.ini" -vs \ + "--junitxml=${XML_LOG_DIR}/${TEST_NAME}_gpu_${i}_${BACKEND}.xml" \ + "${TE_PATH}/examples/jax/collective_gemm/${TEST_CASE}" \ + "${test_args[@]}" 2>&1 | tee "$LOG_FILE" & + PID=$! + PIDS+=($PID) + else + # For other processes: redirect to log files only + pytest -s -c "${TE_PATH}/tests/jax/pytest.ini" -vs \ + "${TE_PATH}/examples/jax/collective_gemm/${TEST_CASE}" \ + "${test_args[@]}" > "$LOG_FILE" 2>&1 & + PID=$! + PIDS+=($PID) + fi + done + + # Wait for all processes to finish + wait + + # Check and print the log content from process 0 + if grep -q "SKIPPED" "${TEST_NAME}_gpu_0_${BACKEND}.log"; then + echo "... $TEST_CASE SKIPPED" + elif grep -q "FAILED" "${TEST_NAME}_gpu_0_${BACKEND}.log"; then + echo "... $TEST_CASE FAILED" + HAS_FAILURE=1 + elif grep -q "PASSED" "${TEST_NAME}_gpu_0_${BACKEND}.log"; then + echo "... $TEST_CASE PASSED" else - # For other processes: redirect to log files only - pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \ - --num-processes=$NUM_GPUS \ - --process-id=$i > "$LOG_FILE" 2>&1 & - PID=$! - PIDS+=($PID) + echo "... $TEST_CASE INVALID" + HAS_FAILURE=1 fi - done - # Wait for all processes to finish - wait - - # Check and print the log content from process 0 - if grep -q "SKIPPED" "${TEST_NAME}_gpu_0.log"; then - echo "... $TEST_CASE SKIPPED" - elif grep -q "FAILED" "${TEST_NAME}_gpu_0.log"; then - echo "... $TEST_CASE FAILED" - HAS_FAILURE=1 - elif grep -q "PASSED" "${TEST_NAME}_gpu_0.log"; then - echo "... $TEST_CASE PASSED" - else - echo "... $TEST_CASE INVALID" - HAS_FAILURE=1 - fi - - # Remove the log files after processing them - wait - rm ${TEST_NAME}_gpu_*.log + + # Remove the log files after processing them + wait + rm ${TEST_NAME}_gpu_*_${BACKEND}.log + + done done wait diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index 1d300f8e90..66f2870d49 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -95,11 +95,14 @@ def run_dense_grad_tests(args, mesh=None): # Create test data rng = jax.random.PRNGKey(0) rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4) - x = jax.random.normal( + std = jnp.asarray(args.std, dtype=jnp.bfloat16) + x = std * jax.random.normal( x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 ) - weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16) - bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16) + weight = std * jax.random.normal( + weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16 + ) + bias = std * jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16) collective_op = ( CollectiveOp.ALL_GATHER @@ -183,6 +186,7 @@ def setUp(self): self.args.process_id = self.process_id self.args.local_device_ids = self.local_device_ids self.args.num_devices_per_process = self.num_devices_per_process + self.args.use_cublasmp = self.use_cublasmp self.args.enable_data_parallel = True self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] _initialize_distributed(self.args) diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index 8221d7bbfd..a82ee1c042 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -16,6 +16,7 @@ import os from functools import partial +import numpy as np import jax import jax.numpy as jnp from jax.sharding import PartitionSpec, NamedSharding @@ -86,11 +87,14 @@ def run_gemm_tests(args, mesh=None): # Create test data rng = jax.random.PRNGKey(0) rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4) - x = jax.random.normal( + std = jnp.asarray(args.std, dtype=jnp.bfloat16) + x = std * jax.random.normal( x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 ) - weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16) - bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16) + weight = std * jax.random.normal( + weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16 + ) + bias = std * jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16) collective_op = ( CollectiveOp.ALL_GATHER if args.collective_type == "all_gather" @@ -151,20 +155,27 @@ def run_gemm_tests(args, mesh=None): jax.block_until_ready(gathered_output) if args.enable_result_check and args.process_id == 0: - # CGEMM + RS + BF16 uses TE's reduce_bf16 kernel (sequential left-to-right in FP32). - # With catastrophic cancellation the output is near zero while the absolute diff can - # reach 1 ULP of the partial GEMM magnitude (~0.0625 for typical transformer - # activations at O(8) scale), which exceeds the previous atol=1e-5. The 2x - # margin (0.125) covers this worst-case 1-ULP absolute difference. - is_cgemm_rs_bf16 = collective_op == CollectiveOp.REDUCE_SCATTER and not use_quantization - rtol = 1e-2 if is_cgemm_rs_bf16 else None - atol = 0.125 if is_cgemm_rs_bf16 else None - assert_allclose( - gathered_ref_output, - gathered_output, - dtype=get_tolerance_dtype(quantizer_set), - rtol=rtol, - atol=atol, + if use_quantization: + # FP8 quantization noise on near-zero outputs can exceed the rtol + # gate; allow a small absolute tolerance. + rtol, atol = 0.125, 0.625 + else: + rtol, atol = 0.02, 0.002 + # Use NumPy (not JAX) for the result check to avoid triggering new XLA compilations + # on process 0 only, which would deadlock in multi-process JAX because XLA compilation + # of distributed arrays requires collective synchronization across all processes. + actual = np.asarray(gathered_output, dtype=np.float32) + desired = np.asarray(gathered_ref_output, dtype=np.float32) + diff = np.abs(actual - desired) + abs_desired = np.abs(desired) + failures = (diff > atol) & (diff > rtol * abs_desired) + num_failures = int(np.sum(failures)) + assert num_failures == 0, ( + f"NUMERICAL CHECK FAILED: {num_failures}/{diff.size} elements " + f"({100 * num_failures / diff.size:.4f}%) exceed tolerances " + f"(rtol={rtol}, atol={atol}). " + f"Max abs error: {float(np.max(diff)):.6f}, " + f"max rel error: {float(np.max(diff / np.maximum(abs_desired, 1e-5))):.6f}" ) @@ -180,6 +191,7 @@ def setUp(self): self.args.process_id = self.process_id self.args.local_device_ids = self.local_device_ids self.args.num_devices_per_process = self.num_devices_per_process + self.args.use_cublasmp = self.use_cublasmp self.args.enable_data_parallel = True self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] _initialize_distributed(self.args) diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index be94c68d37..a5ba370fd2 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -139,19 +139,26 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): rng, x_rng, weight_1_rng, bias_1_rng, weight_2_rng, bias_2_rng, gamma_rng = jax.random.split( rng, 7 ) - x = jax.random.normal( + std = jnp.asarray(args.std, dtype=jnp.bfloat16) + x = std * jax.random.normal( x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 ) - weight_1 = jax.random.normal( - weight_1_rng, (args.hidden_in, 1, args.hidden_out), dtype=jnp.bfloat16 - ) / jnp.sqrt(args.hidden_in) - bias_1 = jax.random.normal(bias_1_rng, (1, args.hidden_out), dtype=jnp.bfloat16) - weight_2 = jax.random.normal( - weight_2_rng, (args.hidden_out, args.hidden_in), dtype=jnp.bfloat16 - ) / jnp.sqrt(args.hidden_out) - bias_2 = jax.random.normal(bias_2_rng, (args.hidden_in,), dtype=jnp.bfloat16) - gamma = jax.random.normal(gamma_rng, (args.hidden_in,), dtype=jnp.bfloat16) / jnp.sqrt( - args.hidden_in + weight_1 = ( + std + * jax.random.normal(weight_1_rng, (args.hidden_in, 1, args.hidden_out), dtype=jnp.bfloat16) + / jnp.sqrt(args.hidden_in) + ) + bias_1 = std * jax.random.normal(bias_1_rng, (1, args.hidden_out), dtype=jnp.bfloat16) + weight_2 = ( + std + * jax.random.normal(weight_2_rng, (args.hidden_out, args.hidden_in), dtype=jnp.bfloat16) + / jnp.sqrt(args.hidden_out) + ) + bias_2 = std * jax.random.normal(bias_2_rng, (args.hidden_in,), dtype=jnp.bfloat16) + gamma = ( + std + * jax.random.normal(gamma_rng, (args.hidden_in,), dtype=jnp.bfloat16) + / jnp.sqrt(args.hidden_in) ) collective_op_set_1 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.ALL_GATHER) collective_op_set_2 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.REDUCE_SCATTER) @@ -249,6 +256,7 @@ def setUp(self): self.args.process_id = self.process_id self.args.local_device_ids = self.local_device_ids self.args.num_devices_per_process = self.num_devices_per_process + self.args.use_cublasmp = self.use_cublasmp self.args.enable_data_parallel = True self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] _initialize_distributed(self.args) diff --git a/tests/cpp_distributed/CMakeLists.txt b/tests/cpp_distributed/CMakeLists.txt index 0d7258a81d..44ad7c7384 100644 --- a/tests/cpp_distributed/CMakeLists.txt +++ b/tests/cpp_distributed/CMakeLists.txt @@ -21,6 +21,8 @@ project(transformer_engine_distributed_tests LANGUAGES CUDA CXX) add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest) +enable_testing() + include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) if(NOT DEFINED TE_LIB_PATH) @@ -30,28 +32,45 @@ if(NOT DEFINED TE_LIB_PATH) get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY) endif() -find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) - +find_library(TE_LIB + NAMES transformer_engine + PATHS "${TE_LIB_PATH}/.." ${TE_LIB_PATH} + ENV TE_LIB_PATH + REQUIRED) message(STATUS "Found transformer_engine library: ${TE_LIB}") -include_directories(../../transformer_engine/common/include) -include_directories(../../transformer_engine/common) -include_directories(../../transformer_engine) -include_directories(${CMAKE_SOURCE_DIR}) - -find_package(CUDAToolkit REQUIRED) add_executable(test_comm_gemm test_comm_gemm.cu ../cpp/test_common.cu) +list(APPEND test_comm_gemm_INCLUDES + ${CMAKE_SOURCE_DIR}/../../transformer_engine/common/include + ${CMAKE_SOURCE_DIR}/../../transformer_engine/common + ${CMAKE_SOURCE_DIR}/../../transformer_engine + ${CMAKE_SOURCE_DIR} + ${MPI_CXX_INCLUDE_PATH} + $ENV{CUBLASMP_HOME}/include) +target_include_directories(test_comm_gemm PRIVATE ${test_comm_gemm_INCLUDES}) + +find_package(CUDAToolkit REQUIRED) find_package(OpenMP REQUIRED) find_package(MPI REQUIRED) find_library(NCCL_LIB NAMES nccl libnccl PATH_SUFFIXES lib REQUIRED) -target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) -target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) +list(APPEND test_comm_gemm_LINKER_LIBS + CUDA::cuda_driver + CUDA::cudart + GTest::gtest_main + ${TE_LIB} + CUDA::nvrtc + ${NCCL_LIB} + OpenMP::OpenMP_CXX + MPI::MPI_CXX) +target_link_libraries(test_comm_gemm PUBLIC ${test_comm_gemm_LINKER_LIBS}) + +target_compile_options(test_comm_gemm PRIVATE -O2 -fopenmp) include(GoogleTest) gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp_distributed/test_comm_gemm.cu b/tests/cpp_distributed/test_comm_gemm.cu index 45f6664567..b1c4a9395a 100644 --- a/tests/cpp_distributed/test_comm_gemm.cu +++ b/tests/cpp_distributed/test_comm_gemm.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. ************************************************************************/ +#include #include #include #include @@ -104,6 +105,96 @@ std::vector CopyMatrix(const std::vector& data, size_t mstart, size_t nsta return ret; } +template +std::vector CastToFloat(const std::vector& in) { + std::vector out(in.size()); + for (size_t i = 0; i < in.size(); ++i) { + out[i] = static_cast(in[i]); + } + return out; +} + +template +ncclDataType_t NcclDataType(); + +template <> +ncclDataType_t NcclDataType() { + return ncclFloat16; +} + +template <> +ncclDataType_t NcclDataType() { + return ncclBfloat16; +} + +template <> +ncclDataType_t NcclDataType() { + return ncclFloat; +} + +template <> +ncclDataType_t NcclDataType() { + return ncclFloat8e4m3; +} + +template <> +ncclDataType_t NcclDataType() { + return ncclFloat8e5m2; +} + + + +template +std::vector AllGatherColsSharded(const std::vector& local, int64_t rows, int64_t local_cols, + int64_t global_cols) { + std::vector cols_per_rank{}; + int nranks{}; + CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &nranks)); + cols_per_rank.resize(nranks); + CHECK_MPI(MPI_Allgather(&local_cols, 1, MPI_INT64_T, cols_per_rank.data(), 1, MPI_INT64_T, + MPI_COMM_WORLD)); + + std::vector counts(nranks); + std::vector displs(nranks, 0); + for (int r = 0; r < nranks; ++r) { + counts[r] = static_cast(cols_per_rank[r] * rows * sizeof(T)); + if (r > 0) displs[r] = displs[r - 1] + counts[r - 1]; + } + + std::vector gathered(rows * global_cols); + CHECK_MPI(MPI_Allgatherv(local.data(), static_cast(local.size() * sizeof(T)), MPI_BYTE, + gathered.data(), counts.data(), displs.data(), MPI_BYTE, + MPI_COMM_WORLD)); + return gathered; +} + +template +std::vector AllGatherRowsSharded(const std::vector& local, int64_t local_rows, int64_t cols, + int64_t global_rows) { + std::vector rows_per_rank{}; + int nranks{}; + CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &nranks)); + rows_per_rank.resize(nranks); + CHECK_MPI(MPI_Allgather(&local_rows, 1, MPI_INT64_T, rows_per_rank.data(), 1, MPI_INT64_T, + MPI_COMM_WORLD)); + + std::vector counts(nranks); + std::vector displs(nranks, 0); + for (int r = 0; r < nranks; ++r) { + counts[r] = static_cast(rows_per_rank[r] * sizeof(T)); + if (r > 0) displs[r] = displs[r - 1] + counts[r - 1]; + } + + std::vector gathered(global_rows * cols); + for (int64_t col = 0; col < cols; ++col) { + const auto* send = reinterpret_cast(local.data() + col * local_rows); + auto* recv = reinterpret_cast(gathered.data() + col * global_rows); + CHECK_MPI(MPI_Allgatherv(send, static_cast(local_rows * sizeof(T)), MPI_BYTE, recv, + counts.data(), displs.data(), MPI_BYTE, MPI_COMM_WORLD)); + } + return gathered; +} + template test::Tensor Make(size_t m, size_t n, float scale) { test::Tensor ret("", std::vector{n, m}, TypeInfo::dtype); @@ -148,6 +239,12 @@ struct Params { class CommGemmFixure : public ::testing::TestWithParam { protected: + enum class OverlapType { + kAllGather, + kReduceScatter, + kAllReduce, + }; + CommGemmFixure() { CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &nranks_)); CHECK_MPI(MPI_Comm_rank(MPI_COMM_WORLD, &rank_)); @@ -180,6 +277,8 @@ class CommGemmFixure : public ::testing::TestWithParam { virtual PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) = 0; + virtual OverlapType overlap_type() const = 0; + virtual void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, @@ -213,14 +312,6 @@ class CommGemmFixure : public ::testing::TestWithParam { return static_cast(dist(rng) * bias_scale); }); - auto ga = transa ? MakeFromData(adata, 0, 0, k, m, k, a_scale) - : MakeFromData(adata, 0, 0, m, k, m, a_scale); - auto gb = transb ? MakeFromData(bdata, 0, 0, n, k, n, b_scale) - : MakeFromData(bdata, 0, 0, k, n, k, b_scale); - auto gbias = MakeFromData(biasdata, 0, 0, m, 1, m, bias_scale); - auto gd = Make(m, n, d_scale); - auto gaux = Make(m, n, d_scale); - auto dims = DistributeTensors(m, n, k); auto a = transa ? MakeFromData(adata, dims.a_rows_start, dims.a_cols_start, dims.a_rows_num, dims.a_cols_num, k, a_scale) @@ -240,24 +331,88 @@ class CommGemmFixure : public ::testing::TestWithParam { CommGemm(m, n, k, a.data(), b.data(), d.data(), bias.data(), aux.data(), transa, transb, grad, accumulate, 0 /*comm_sm_count*/, stream); auto workspace = Make(1, 32 << 20, 1.0); - nvte_cublas_gemm(ga.data(), gb.data(), gd.data(), gbias.data(), gaux.data(), transa, transb, - grad, workspace.data(), accumulate, true /* use_split_accumulator */, - 0 /* math_sm_count */, stream); + + std::vector out_golden{}; + if (overlap_type() == OverlapType::kAllGather) { + // Build AG reference input by explicitly all-gathering the sharded input operand. + std::vector b_global_data{}; + if (transb) { + auto b_local = CopyMatrix(bdata, dims.b_cols_start, dims.b_rows_start, dims.b_cols_num, + dims.b_rows_num, n); + b_global_data = AllGatherRowsSharded(b_local, dims.b_cols_num, k, n); + } else { + auto b_local = CopyMatrix(bdata, dims.b_rows_start, dims.b_cols_start, dims.b_rows_num, + dims.b_cols_num, k); + b_global_data = AllGatherColsSharded(b_local, k, dims.b_cols_num, n); + } + + auto b_ref = + transb ? MakeFromData(b_global_data, 0, 0, n, k, n, b_scale) + : MakeFromData(b_global_data, 0, 0, k, n, k, b_scale); + auto d_ref = Make(dims.d_rows_num, dims.d_cols_num, d_scale); + auto aux_ref = Make(dims.d_rows_num, dims.d_cols_num, d_scale); + nvte_cublas_gemm(a.data(), b_ref.data(), d_ref.data(), bias.data(), aux_ref.data(), transa, + transb, grad, workspace.data(), accumulate, + true /* use_split_accumulator */, 0 /* math_sm_count */, stream); + + std::vector out_ref(dims.d_rows_num * dims.d_cols_num); + NVTE_CHECK_CUDA(cudaMemcpy(out_ref.data(), d_ref.rowwise_dptr(), + out_ref.size() * sizeof(out_ref[0]), cudaMemcpyDefault)); + out_golden = CastToFloat(out_ref); + } else { + transformer_engine::TensorWrapper empty_bias(nullptr, std::vector{0}, + TypeInfo::dtype); + auto d_partial = Make(m, n, d_scale); + auto aux_partial = Make(m, n, d_scale); + nvte_cublas_gemm(a.data(), b.data(), d_partial.data(), empty_bias.data(), + aux_partial.data(), transa, transb, grad, workspace.data(), accumulate, + true /* use_split_accumulator */, 0 /* math_sm_count */, stream); + + std::vector partial_host(m * n); + NVTE_CHECK_CUDA(cudaMemcpy(partial_host.data(), d_partial.rowwise_dptr(), + partial_host.size() * sizeof(partial_host[0]), + cudaMemcpyDefault)); + std::vector partial_float = CastToFloat(partial_host); + + auto d_partial_float = Make(m, n, 1.0f); + NVTE_CHECK_CUDA(cudaMemcpy(d_partial_float.rowwise_dptr(), partial_float.data(), + partial_float.size() * sizeof(float), cudaMemcpyDefault)); + + std::vector reduced; + if (overlap_type() == OverlapType::kReduceScatter) { + auto d_reduced_float = Make(dims.d_rows_num, dims.d_cols_num, 1.0f); + CHECK_NCCL(ncclReduceScatter(d_partial_float.rowwise_dptr(), d_reduced_float.rowwise_dptr(), + dims.d_rows_num * dims.d_cols_num, ncclFloat, ncclSum, + comm_, stream)); + + reduced.resize(dims.d_rows_num * dims.d_cols_num); + NVTE_CHECK_CUDA(cudaMemcpy(reduced.data(), d_reduced_float.rowwise_dptr(), + reduced.size() * sizeof(float), cudaMemcpyDefault)); + } else { + CHECK_NCCL(ncclAllReduce(d_partial_float.rowwise_dptr(), d_partial_float.rowwise_dptr(), + m * n, ncclFloat, ncclSum, comm_, stream)); + + reduced.resize(m * n); + NVTE_CHECK_CUDA(cudaMemcpy(reduced.data(), d_partial_float.rowwise_dptr(), + reduced.size() * sizeof(float), cudaMemcpyDefault)); + } + + for (size_t col = 0; col < dims.d_cols_num; ++col) { + for (size_t row = 0; row < m; ++row) { + reduced[col * m + row] += static_cast(biasdata[row]); + } + } + out_golden = std::move(reduced); + } + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); std::vector out(dims.d_rows_num * dims.d_cols_num); NVTE_CHECK_CUDA( cudaMemcpy(out.data(), d.rowwise_dptr(), out.size() * sizeof out[0], cudaMemcpyDefault)); - std::vector out_golden_global(m * n); - NVTE_CHECK_CUDA(cudaMemcpy(out_golden_global.data(), gd.rowwise_dptr(), - out_golden_global.size() * sizeof out_golden_global[0], - cudaMemcpyDefault)); - - auto out_golden = CopyMatrix(out_golden_global, dims.d_rows_start, dims.d_cols_start, - dims.d_rows_num, dims.d_cols_num, m); NVTE_CHECK(out.size() == out_golden.size()); for (size_t i = 0; i < out.size(); ++i) { - EXPECT_NEAR(static_cast(out[i]), static_cast(out_golden[i]), tol); + EXPECT_NEAR(static_cast(out[i]), out_golden[i], tol); } } @@ -268,6 +423,8 @@ class CommGemmFixure : public ::testing::TestWithParam { }; struct AgGemm : public CommGemmFixure { + OverlapType overlap_type() const override { return OverlapType::kAllGather; } + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { auto a_cols_num = nvte_comm_gemm_numroc(ctx_, m); auto b_cols_num = nvte_comm_gemm_numroc(ctx_, n); @@ -303,6 +460,8 @@ struct AgGemm : public CommGemmFixure { }; struct GemmRs : public CommGemmFixure { + OverlapType overlap_type() const override { return OverlapType::kReduceScatter; } + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { auto rows_num = nvte_comm_gemm_numroc(ctx_, k); auto d_cols_num = nvte_comm_gemm_numroc(ctx_, n); @@ -337,7 +496,10 @@ struct GemmRs : public CommGemmFixure { } }; +#if CUBLASMP_VERSION >= 900 struct GemmAr : public CommGemmFixure { + OverlapType overlap_type() const override { return OverlapType::kAllReduce; } + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { auto rows_num = nvte_comm_gemm_numroc(ctx_, k); @@ -368,6 +530,7 @@ struct GemmAr : public CommGemmFixure { accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault); } }; +#endif // CUBLASMP_VERSION >= 900 TEST_P(AgGemm, Gemm) { auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam(); @@ -389,6 +552,7 @@ TEST_P(GemmRs, Gemm) { d_type, DType, Run(transa, transb, m, n, k, tol);))); } +#if CUBLASMP_VERSION >= 900 TEST_P(GemmAr, Gemm) { auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam(); TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( @@ -398,6 +562,7 @@ TEST_P(GemmAr, Gemm) { TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( d_type, DType, Run(transa, transb, m, n, k, tol);))); } +#endif // CUBLASMP_VERSION >= 900 std::string ParamSuffix(const testing::TestParamInfo& info) { const auto [a_type, b_type, d_type, transa, transb, m, n, k, _tol] = info.param; @@ -450,12 +615,13 @@ INSTANTIATE_TEST_SUITE_P(GemmRs, GemmRs, DType::kFloat16, true, false, 64, 128, 256, 7e-2}), &ParamSuffix); +#if CUBLASMP_VERSION >= 900 INSTANTIATE_TEST_SUITE_P( GemmAr, GemmAr, testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, true, false, 64, 64 * 4, 64 * 4, 7e-2}, Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, true, false, 64, - 64 * 4, 64 * 4, 1e-3}, + 64 * 4, 64 * 4, 6e-1}, Params{DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kFloat16, true, false, 128, 128 * 4, 128 * 4, 1.5e-1}, Params{DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kFloat16, true, false, @@ -463,3 +629,4 @@ INSTANTIATE_TEST_SUITE_P( Params{DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kFloat16, true, false, 128, 128 * 4, 128 * 4, 1.5e-1}), &ParamSuffix); +#endif // CUBLASMP_VERSION >= 900 diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 96a7e43231..d1701406f2 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -151,6 +151,27 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA graphs." ) + parser.add_argument( + "--use-cublasmp", action="store_true", default=False, help="Use cuBLASMp backend." + ) + parser.add_argument( + "--rtol", + type=float, + default=None, + help=( + "Override the relative-error tolerance used in the numerical check. " + "When unset, defaults to 0.125 for FP8/MXFP8 and 0.02 otherwise." + ), + ) + parser.add_argument( + "--atol", + type=float, + default=None, + help=( + "Override the absolute-error tolerance used in the numerical check. " + "When unset, defaults to 0.0625 for FP8/MXFP8 and 0.002 otherwise." + ), + ) parser.add_argument( "-v", "--verbose", action="store_true", default=False, help="Verbose info messages." ) @@ -203,6 +224,7 @@ def _main(opts): capture_output=True, text=True, shell=True, + check=False, ) if result.stdout == "0": # Extra checks for non-MNNVL platforms @@ -306,7 +328,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None helper = ( tex.CommOverlapHelper() if tex.ubuf_built_with_mpi() - else tex.CommOverlapHelper(bootstrap_pg) + else tex.CommOverlapHelper(bootstrap_pg, tp_group) ) # Initialize userbuffers with (M, N) buffer @@ -322,50 +344,59 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None and opts.comm_type == tex.CommOverlapType.AG ): buffer_dtype = torch.uint8 - ub_obj = ( - tex.CommOverlapP2P( + if opts.p2p: + ub_obj = tex.CommOverlapP2P( (outer_size, hidden_size), buffer_dtype, helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) opts.comm_type, + use_cublasmp=opts.use_cublasmp, + num_comm_sm=3 if opts.use_cublasmp else 1, set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic, atomic_gemm=opts.atomic, aggregate=opts.aggregate, use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))), ) - if opts.p2p - else tex.CommOverlap( + else: + ub_obj = tex.CommOverlap( (outer_size, hidden_size), buffer_dtype, helper, tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + use_cublasmp=opts.use_cublasmp, + comm_type=opts.comm_type, + num_comm_sm=16, atomic_gemm=opts.atomic, ) - ) # Numerical check on AG + atomic GEMM requires testing an AG+RS pair ub_obj2 = None if opts.atomic and opts.comm_type == tex.CommOverlapType.AG and opts.check_numerics: - ub_obj2 = ( - tex.CommOverlapP2P( + ub2_buffer_dtype = torch.uint8 if opts.fp8_output else torch.bfloat16 + if opts.atomic_rs_p2p: + ub_obj2 = tex.CommOverlapP2P( (outer_size, hidden_size), - torch.uint8 if opts.fp8_output else torch.bfloat16, + ub2_buffer_dtype, helper, - tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + tp_size, tex.CommOverlapType.RS, + use_cublasmp=opts.use_cublasmp, + num_comm_sm=16 if opts.use_cublasmp else 1, set_sm_margin=True, atomic_gemm=True, ) - if opts.atomic_rs_p2p - else tex.CommOverlap( + else: + ub_obj2 = tex.CommOverlap( (outer_size, hidden_size), - torch.uint8 if opts.fp8_output else torch.bfloat16, + ub2_buffer_dtype, helper, - tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE) + tp_size, + use_cublasmp=opts.use_cublasmp, + comm_type=tex.CommOverlapType.RS, + num_comm_sm=3 if opts.use_cublasmp else 16, atomic_gemm=True, ) - ) # Figure out problem sizing: # M = sequence * batch @@ -408,7 +439,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None mean=0.0, std=opts.std, ) - if ub_obj2 is not None: + if opts.comm_type == tex.CommOverlapType.AG and ub_obj2 is not None: kernel2_t = torch.nn.init.normal_( torch.empty(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda"), mean=0.0, @@ -426,25 +457,19 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None ) else: if opts.comm_type == tex.CommOverlapType.AG: - # AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K) - ker_g = torch.transpose( - te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1 - ).to(dtype=torch.float32) - # AG Input: (M/P, N) -> gather -> (M, N) - inp_g = te.distributed.gather_along_first_dim(inp, tp_group)[0].to(dtype=torch.float32) + # AG Kernel: Keep local (K/P, N) for per-rank reference comparison + ker_g = kernel_t + # AG Input: (M/P, N) -> gather -> (M, N); full input needed for local GEMM chunk + inp_g = te.distributed.gather_along_first_dim(inp, tp_group)[0] if ub_obj2 is not None: - ker2_g = te.distributed.gather_along_first_dim( - torch.transpose(kernel2_t, 0, 1), tp_group - )[0].to(dtype=torch.float32) + # Keep kernel2 local (N, K/P) for per-rank reference comparison + ker2_g = kernel2_t else: - # RS Kernel: (N, K/P) -> T -> (K/P, N) -> gather -> (K, N) - ker_g = te.distributed.gather_along_first_dim( - torch.transpose(kernel_t, 0, 1), tp_group - )[0].to(dtype=torch.float32) - # RS Input: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K) - inp_g = torch.transpose( - te.distributed.gather_along_first_dim(torch.transpose(inp, 0, 1), tp_group)[0], 0, 1 - ).to(dtype=torch.float32) + # RS: Compute local GEMM on each rank, will apply reduce-scatter after + # RS Kernel: (N, K/P) - keep local + ker_g = kernel_t + # RS Input: (M, K/P) - keep local + inp_g = inp if opts.bulk_overlap: if opts.comm_type == tex.CommOverlapType.AG: @@ -456,10 +481,46 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None # Sum the list together for final global result ref_g = torch.stack(bulk_inp_list).sum(dim=0) else: - ref_g = torch.matmul(inp_g, ker_g) + # cuBLASMp always uses split-accumulator internally and does not expose a + # control to disable it, so force the reference to match when testing the + # cuBLASMp backend; otherwise honor the framework default. + ref_use_split_accumulator = True if opts.use_cublasmp else te.module.base._2X_ACC_FPROP + # For AG: ker_g=kernel_t=(K/P,N), inp_g=(M,N) -> ref_g=(M,K/P) local chunk + # For RS: ker_g=kernel_t=(N,K/P), inp_g=(M,K/P) -> ref_g=(M,N) partial + ref_g, *_ = tex.general_gemm( + ker_g, + inp_g, + out_dtype=torch.bfloat16, + use_split_accumulator=ref_use_split_accumulator, + ) + if opts.comm_type == tex.CommOverlapType.RS: + # Apply non-overlapped reduce-scatter to local reference GEMM output + # ref_g is currently (M, N) on each rank (partial result from local GEMM) + # All-gather to collect all (M, N) from all ranks + ref_rs_list = [torch.zeros_like(ref_g) for _ in range(tp_size)] + dist.all_gather(ref_rs_list, ref_g, group=tp_group) + # Stack and sum across ranks to get global result + ref_global = torch.stack(ref_rs_list, dim=0).sum(dim=0) # (M, N) + # Scatter: each rank keeps its portion (M/P, N) + start_idx = tp_rank * (outer_size // tp_size) + end_idx = (tp_rank + 1) * (outer_size // tp_size) + ref_g = ref_global[start_idx:end_idx, :] # (M/P, N) if ub_obj2 is not None: inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable - ref2_g = torch.matmul(inp2_g, ker2_g) + # ker2_g=kernel2_t=(N,K/P), inp2_g=(M,K/P) -> partial (M,N) per rank + ref2_partial, *_ = tex.general_gemm( + ker2_g, + inp2_g, + out_dtype=torch.bfloat16, + use_split_accumulator=ref_use_split_accumulator, + ) + # Apply non-overlapped reduce-scatter to partial results + ref2_rs_list = [torch.zeros_like(ref2_partial) for _ in range(tp_size)] + dist.all_gather(ref2_rs_list, ref2_partial, group=tp_group) + ref2_global = torch.stack(ref2_rs_list, dim=0).sum(dim=0) # (M, N) fully reduced + start_idx = tp_rank * (outer_size // tp_size) + end_idx = (tp_rank + 1) * (outer_size // tp_size) + ref2_g = ref2_global[start_idx:end_idx, :] # (M/P, N) # Initialize quantizers with_quantized_compute = opts.quantization != "none" @@ -580,7 +641,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None tp_group, ) gemm_inp = inp - else: + elif not opts.use_cublasmp: ag_out, _ = fill_userbuffers_buffer_for_all_gather( ub_obj, inp_fp8 if with_quantized_compute else inp, @@ -588,6 +649,8 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None tp_group, ) gemm_inp = ag_out + else: + gemm_inp = inp_fp8 if with_quantized_compute else inp if ub_obj2 is not None: rs_out2 = torch.empty( (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" @@ -658,13 +721,13 @@ def _gemm(): # Trace the CUDA graph first g = torch.cuda.CUDAGraph() if with_quantized_compute: - if ub_obj is None: + if ub_obj2 is None: with torch.cuda.graph(g): all_outputs = _fp8_gemm() else: with torch.cuda.graph(g): all_outputs = _fp8_gemm() - _ = _fp8_gemm2(all_outputs[0]) + all_outputs2 = _fp8_gemm2(all_outputs[0]) else: with torch.cuda.graph(g): all_outputs = _gemm() @@ -682,7 +745,7 @@ def _gemm(): all_outputs = _fp8_gemm() end_events[i].record() if ub_obj2 is not None: - _fp8_gemm2(all_outputs[0]) + all_outputs2 = _fp8_gemm2(all_outputs[0]) else: start_events[i].record() all_outputs = _gemm() @@ -758,23 +821,26 @@ def _gemm(): else: if opts.comm_type == tex.CommOverlapType.AG: if ub_obj2 is not None: - # AG+RS Output: (M/P, N) -> gather -> (M, N) - output = rs_out2.to(dtype=torch.float32) - test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] + # AG+RS Output: Keep local (M/P, N) for comparison with local reference + output = ( + rs_out2.to(dtype=torch.float32) + if not opts.use_cublasmp + else (all_outputs2[0].dequantize() if opts.fp8_output else all_outputs2[0]) + ) + test_out = output else: - # AG Output: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K) + # AG Output: Keep local (M, K/P) for comparison with local reference output = all_outputs[0].dequantize() if opts.fp8_output else all_outputs[0] - test_out = torch.transpose( - te.distributed.gather_along_first_dim( - torch.transpose(output, 0, 1), tp_group - )[0], - 0, - 1, - ) + test_out = output else: - # RS Output: (M/P, N) -> gather -> (M, N) - output = rs_out.to(dtype=torch.float32) - test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] + # RS Output: Keep local (M/P, N) for comparison with local reference + output = ( + rs_out.to(dtype=torch.float32) + if not opts.use_cublasmp + else (all_outputs[0].dequantize() if opts.fp8_output else all_outputs[0]) + ) + # Don't gather - keep local for comparison with local reduce-scattered reference + test_out = output ref_out = ref2_g if ub_obj2 is not None else ref_g test_nonzeros = torch.count_nonzero(test_out) @@ -782,7 +848,9 @@ def _gemm(): nonzero_info = ( f"output nonzeros = {test_nonzeros} " + f"| reference count = {ref_nonzeros}" ) - dist_print(nonzero_info, src=0, section=True, group=tp_group) + + # Both AG and RS now compare local outputs across all ranks + dist_print(nonzero_info, section=True, group=tp_group) sizing_info = ( f"input: {list(inp.shape)} " + f"| GEMM1 weights: {list(kernel_t.shape)[::-1]} " @@ -792,15 +860,17 @@ def _gemm(): sizing_info += f"| output: {list(output.shape)}\n" dist_print(sizing_info, section=True, group=tp_group) + # Both AG and RS now compare local outputs; print per-rank sizing_info_g = ( - f"input: {list(inp_g.shape)} " + f"| GEMM1 weights: {list(ker_g.shape)} " + f"input: {list(inp.shape)} | GEMM1 weights: {list(kernel_t.shape)[::-1]} " ) if ub_obj2 is not None: - sizing_info_g += f"| GEMM2 weights: {list(ker2_g.shape)} " + sizing_info_g += f"| GEMM2 weights: {list(kernel2_t.shape)[::-1]} " sizing_info_g += ( - f"| output: {list(test_out.shape)} " + f"| reference: {list(ref_out.shape)}\n" + f"| output (local): {list(test_out.shape)} | reference (local):" + f" {list(ref_out.shape)}\n" ) - dist_print(sizing_info_g, src=0, group=tp_group) + dist_print(sizing_info_g, group=tp_group) torch.cuda.synchronize() dist.barrier(tp_group) @@ -808,8 +878,14 @@ def _gemm(): m = torch.argmax(diff) abs_err = diff[m].item() rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5) - rtol = 0.02 if opts.quantization == "none" else 0.125 - atol = 0.001 if opts.quantization == "none" else 0.0625 + rtol = ( + opts.rtol if opts.rtol is not None else (0.02 if opts.quantization == "none" else 0.125) + ) + atol = ( + opts.atol + if opts.atol is not None + else (0.002 if opts.quantization == "none" else 0.0625) + ) if rel_err > rtol and abs_err > atol: numerics_failed = True numerics_info = ( @@ -828,9 +904,7 @@ def _gemm(): if abs_err <= atol: numerics_info += f"abs. error = {abs_err} (tol = {atol})" - dist_print( - numerics_info, src=0, section=True, info=True, error=numerics_failed, group=tp_group - ) + dist_print(numerics_info, section=True, info=True, error=numerics_failed, group=tp_group) dist.barrier(tp_group) if LOCAL_RANK == 0: diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 53c7a5e7cc..46795415e5 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -23,6 +23,7 @@ DelayedScaling, Float8CurrentScaling, Format, + MMParams, MXFP8BlockScaling, ) @@ -258,6 +259,30 @@ def _parse_args(argv=None, namespace=None): default=0, help="Number of layers at the end to run in bf16.", ) + parser.add_argument( + "--use-cublasmp", + action="store_true", + default=False, + help="Use cuBLASMp backend.", + ) + parser.add_argument( + "--rtol", + type=float, + default=None, + help=( + "Override the relative-error tolerance used in the numerical check. " + "When unset, defaults to 0.125 for FP8 and 0.025 otherwise." + ), + ) + parser.add_argument( + "--atol", + type=float, + default=None, + help=( + "Override the absolute-error tolerance used in the numerical check. " + "When unset, defaults to 0.0625 for FP8 and 0.00125 otherwise." + ), + ) args = parser.parse_args(argv, namespace) if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]: @@ -436,6 +461,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg, + with_cublasmp=opts.use_cublasmp, ) with te.quantized_model_init(enabled=opts.fp8_init): @@ -471,6 +497,11 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): elif opts.quantization == "mxfp8": fp8_recipe = MXFP8BlockScaling() + if opts.fp8: + fp8_recipe.fp8_gemm_fprop = MMParams(use_split_accumulator=True) + fp8_recipe.fp8_gemm_dgrad = MMParams(use_split_accumulator=True) + fp8_recipe.fp8_gemm_wgrad = MMParams(use_split_accumulator=True) + layer_contexts = [ ( partial( @@ -552,8 +583,8 @@ def run_fwd_bwd(model, x): # Now validate accuracy if not bool(numerics_failed.item()): for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): - rtol = 0.125 if opts.fp8 else 0.025 - atol = 0.0625 if opts.fp8 else 0.00125 + rtol = opts.rtol if opts.rtol is not None else (0.125 if opts.fp8 else 0.025) + atol = opts.atol if opts.atol is not None else (0.0625 if opts.fp8 else 0.00125) grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) numerics_failed[0] = int(grad_failed) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 7a81f93bd6..7ec1b91dfc 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -50,7 +50,9 @@ torch._dynamo.reset() -def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization): +def _run_gemm_with_overlap( + comm_type, bulk, p2p, atomic, aggregate, quantization, use_cublasmp=False +): test_path = TEST_ROOT / "run_gemm_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -79,6 +81,18 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization test_cmd.append("--atomic") if aggregate: test_cmd.append("--aggregate") + if use_cublasmp: + if not tex.nvte_built_with_cublasmp(): + pytest.skip("Transformer Engine not built with cuBLASMp (NVTE_WITH_CUBLASMP=0).") + if quantization == "mxfp8": + pytest.skip( + "cuBLASMp comm+GEMM overlap does not yet support MXFP8 (block scaling)." + ) + if comm_type == "RS" and not p2p and not tex.device_supports_multicast(): + pytest.skip( + "cuBLASMp non-P2P reduce-scatter requires NVSwitch (multicast support)." + ) + test_cmd.append("--use-cublasmp") result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) if ( @@ -90,7 +104,13 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, aggregate, quantization def _run_layer_with_overlap( - layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers=1 + layer_type, + linear_parallel_mode, + overlap_rs_dgrad, + fp8, + quantization, + num_layers=1, + use_cublasmp=False, ): test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ @@ -117,6 +137,13 @@ def _run_layer_with_overlap( test_cmd.append("--fp8") test_cmd.append(f"--quantization={quantization}") + if use_cublasmp: + if not tex.nvte_built_with_cublasmp(): + pytest.skip("Transformer Engine not built with cuBLASMp (NVTE_WITH_CUBLASMP=0).") + if quantization == "mxfp8": + pytest.skip("cuBLASMp comm+GEMM overlap does not yet support MXFP8 (block scaling).") + test_cmd.append("--use-cublasmp") + os.environ["PYTORCH_JIT"] = "0" os.environ["NVTE_TORCH_COMPILE"] = "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" @@ -141,24 +168,26 @@ def _run_layer_with_overlap( raise AssertionError(result.stderr.decode()) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8")) @pytest.mark.parametrize("aggregate", (False, True)) -def test_split_all_gather_overlaps(quantization, aggregate): +def test_split_all_gather_overlaps(quantization, aggregate, use_cublasmp): """ Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap("AG", False, True, False, aggregate, quantization) + _run_gemm_with_overlap("AG", False, True, False, aggregate, quantization, use_cublasmp) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8")) @pytest.mark.parametrize("p2p", (False, True)) -def test_split_reduce_scatter_overlaps(quantization, p2p): +def test_split_reduce_scatter_overlaps(quantization, p2p, use_cublasmp): """ Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap("RS", False, p2p, False, False, quantization) + _run_gemm_with_overlap("RS", False, p2p, False, False, quantization, use_cublasmp) @pytest.mark.parametrize( @@ -197,6 +226,7 @@ def test_bulk_overlaps(comm_type, quantization, connections): _run_gemm_with_overlap(comm_type, True, False, False, False, quantization) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( "fp8", (False,), @@ -235,13 +265,18 @@ def test_bulk_overlaps(comm_type, quantization, connections): ) ], ) -def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8): +def test_layers_with_overlap_bf16( + layer_type, linear_parallel_mode, overlap_rs_dgrad, use_cublasmp, fp8 +): """ Test Transformer Engine layers with comm+GEMM overlap. """ - _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None) + _run_layer_with_overlap( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None, use_cublasmp=use_cublasmp + ) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( "quantization", ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"], @@ -282,13 +317,22 @@ def test_layers_with_overlap_fp8( linear_parallel_mode, overlap_rs_dgrad, quantization, + use_cublasmp, ): """ Test Transformer Engine layers with comm+GEMM overlap. """ - _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization) + _run_layer_with_overlap( + layer_type, + linear_parallel_mode, + overlap_rs_dgrad, + True, + quantization, + use_cublasmp=use_cublasmp, + ) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( "fp8", (False,), @@ -321,16 +365,23 @@ def test_layers_with_overlap_fp8( ], ) def test_multi_layer_with_overlap_bf16( - layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, num_layers + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, num_layers, use_cublasmp ): """ Test Transformer Engine layers with comm+GEMM overlap. """ _run_layer_with_overlap( - layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None, num_layers + layer_type, + linear_parallel_mode, + overlap_rs_dgrad, + fp8, + None, + num_layers, + use_cublasmp=use_cublasmp, ) +@pytest.mark.parametrize("use_cublasmp", (False, True)) @pytest.mark.parametrize( "quantization", ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"], @@ -357,11 +408,17 @@ def test_multi_layer_with_overlap_bf16( ], ) def test_multi_layer_with_overlap_fp8( - layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers + layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers, use_cublasmp ): """ Test Transformer Engine layers with comm+GEMM overlap. """ _run_layer_with_overlap( - layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, num_layers + layer_type, + linear_parallel_mode, + overlap_rs_dgrad, + True, + quantization, + num_layers, + use_cublasmp=use_cublasmp, ) diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 3dcefd46fd..3caca3b62c 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -477,6 +477,7 @@ def main() -> None: parser.add_argument("--head-dim", type=int, default=256) parser.add_argument("--dtype", type=str, default="bfloat16") parser.add_argument("--quantization", type=str, default=None) + parser.add_argument("--use-cublasmp", type=bool, action="store_true") args = parser.parse_args() # Run parallel tests if needed @@ -517,6 +518,7 @@ def main() -> None: dtype=model_config.dtype, bootstrap_backend=bootstrap_backend, ub_cfgs=userbuffer_configs, + use_cublasmp=args.use_cublasmp, ) # Run tests diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 06d85b6d84..66c052dfcf 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -139,6 +139,40 @@ function(find_nccl_version OUT_VERSION OUT_INCLUDE_DIR) set(${OUT_INCLUDE_DIR} "${_nvte_nccl_include_dir}" PARENT_SCOPE) endfunction() +function(find_cublasmp_version OUT_VERSION OUT_INCLUDE_DIR SEARCH_DIR) + find_path(_nvte_cublasmp_include_dir + NAMES cublasmp.h + HINTS "${SEARCH_DIR}/include" + PATH_SUFFIXES include + REQUIRED) + + file(STRINGS "${_nvte_cublasmp_include_dir}/cublasmp.h" _nvte_cublasmp_major_line + REGEX "^#define CUBLASMP_VER_MAJOR[ \t]+[0-9]+$") + file(STRINGS "${_nvte_cublasmp_include_dir}/cublasmp.h" _nvte_cublasmp_minor_line + REGEX "^#define CUBLASMP_VER_MINOR[ \t]+[0-9]+$") + file(STRINGS "${_nvte_cublasmp_include_dir}/cublasmp.h" _nvte_cublasmp_patch_line + REGEX "^#define CUBLASMP_VER_PATCH[ \t]+[0-9]+$") + + string(REGEX REPLACE "^#define CUBLASMP_VER_MAJOR[ \t]+([0-9]+)$" "\\1" + _nvte_cublasmp_major "${_nvte_cublasmp_major_line}") + string(REGEX REPLACE "^#define CUBLASMP_VER_MINOR[ \t]+([0-9]+)$" "\\1" + _nvte_cublasmp_minor "${_nvte_cublasmp_minor_line}") + string(REGEX REPLACE "^#define CUBLASMP_VER_PATCH[ \t]+([0-9]+)$" "\\1" + _nvte_cublasmp_patch "${_nvte_cublasmp_patch_line}") + + if ("${_nvte_cublasmp_major}" STREQUAL "" + OR "${_nvte_cublasmp_minor}" STREQUAL "" + OR "${_nvte_cublasmp_patch}" STREQUAL "") + message(FATAL_ERROR + "Failed to parse cuBLASMp version from ${_nvte_cublasmp_include_dir}/cublasmp.h") + endif() + + set(${OUT_VERSION} + "${_nvte_cublasmp_major}.${_nvte_cublasmp_minor}.${_nvte_cublasmp_patch}" + PARENT_SCOPE) + set(${OUT_INCLUDE_DIR} "${_nvte_cublasmp_include_dir}" PARENT_SCOPE) +endfunction() + # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) set(transformer_engine_SOURCES) @@ -352,25 +386,36 @@ endif() option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF) if (NVTE_WITH_CUBLASMP) + target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP) target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include) find_nccl_version(NCCL_VERSION NCCL_INCLUDE_DIR) + find_cublasmp_version(CUBLASMP_VERSION CUBLASMP_INCLUDE_DIR ${CUBLASMP_DIR}) find_library(CUBLASMP_LIB - NAMES cublasmp libcublasmp + NAMES cublasmp libcublasmp.so libcublasmp.so.0 PATHS ${CUBLASMP_DIR} - PATH_SUFFIXES lib + PATH_SUFFIXES lib lib64 lib/aarch64-linux-gnu lib/sbsa-linux-gnu lib/x86_64-linux-gnu REQUIRED) find_library(NCCL_LIB NAMES nccl libnccl PATH_SUFFIXES lib REQUIRED) - if (NCCL_VERSION VERSION_LESS 2.29.0) + # cuBLASMp 0.8 is the first release with CUDA-graph-safe overlap algos, + # and NCCL 2.30 is the first release with graph-safe one-sided RMA + # primitives (ncclPutSignal/ncclWaitSignal) that those algos use. + if (CUBLASMP_VERSION VERSION_LESS 0.8.0) + message(FATAL_ERROR + "NVTE_WITH_CUBLASMP requires cuBLASMp >= 0.8.0, but found cuBLASMp " + "${CUBLASMP_VERSION} in ${CUBLASMP_INCLUDE_DIR}/cublasmp.h") + endif() + if (NCCL_VERSION VERSION_LESS 2.30.0) message(FATAL_ERROR - "NVTE_WITH_CUBLASMP requires NCCL >= 2.29.0, but found NCCL ${NCCL_VERSION} " - "in ${NCCL_INCLUDE_DIR}/nccl.h") + "NVTE_WITH_CUBLASMP requires NCCL >= 2.30.0 (for graph-capture-safe " + "one-sided RMA primitives used by cuBLASMp's overlap algorithms), but " + "found NCCL ${NCCL_VERSION} in ${NCCL_INCLUDE_DIR}/nccl.h") endif() target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB}) - message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}") + message(STATUS "Using cuBLASMp ${CUBLASMP_VERSION} at: ${CUBLASMP_DIR}") message(STATUS "Using NCCL ${NCCL_VERSION} at: ${NCCL_LIB}") endif() diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp index ce389c2006..92bd28af41 100644 --- a/transformer_engine/common/comm_gemm/comm_gemm.cpp +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -187,6 +187,7 @@ void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n get_cuda_dtype(a->dtype()), ctx->grid_row_major.get(), ctx->a_desc.get())); } + // B is (K/P, N) local -- K is distributed across ranks, N is fully replicated. if (transb) { NVTE_CHECK(b1 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b1); NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(n, k, block_size(ctx, n), block_size(ctx, k), @@ -273,10 +274,47 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo "Unsupported scaling mode: " + std::to_string(t->scaling_mode)); } + // Mirror cublaslt_gemm.cu's CanonicalizeGemmInput for FP8 tensor scaling: depending on the + // architecture and the quantizer's usage modes, the appropriate data + scale_inv + // may live on the rowwise or columnwise side of the tensor. + // * Hopper (!nvte_is_non_tn_fp8_gemm_supported): only TN FP8 GEMMs are supported, so an + // FP8 input not already in TN orientation must be swapped to its columnwise (transposed) + // view and the transpose flag flipped. + // * Blackwell+ (nvte_is_non_tn_fp8_gemm_supported): any FP8 GEMM layout is supported, but + // the quantizer usage may have only been set to columnwise. In that case, fall back to + // the columnwise view and flip the transpose flag so the GEMM sees the matching data and + // scale_inv pair. + // The original tensor is never modified; a new Tensor view aliases the columnwise pointers. + const bool fp8_needs_tn = !nvte_is_non_tn_fp8_gemm_supported(); + auto canonicalize_fp8_input = [fp8_needs_tn](const Tensor* t, bool current_trans, bool want_trans, + const char* side) -> std::pair { + if (!is_fp8_dtype(t->dtype())) { + return {*t, current_trans}; + } + const bool hopper_tn_swap = fp8_needs_tn && current_trans != want_trans; + const bool blackwell_missing_rowwise = !fp8_needs_tn && !t->has_data(); + if (!hopper_tn_swap && !blackwell_missing_rowwise) { + return {*t, current_trans}; + } + NVTE_CHECK(t->has_columnwise_data() && is_fp8_dtype(t->columnwise_data.dtype), + "cuBLASMp FP8 GEMM input ", side, " is missing column-wise usage"); + Tensor view; + view.scaling_mode = t->scaling_mode; + view.data = t->columnwise_data; + view.scale_inv = t->columnwise_scale_inv; + // Columnwise data is the transposed view of the original — flip the transpose flag. + return {view, !current_trans}; + }; + + auto [a_used, transa_eff] = canonicalize_fp8_input(a, transa, /*want_trans=*/true, "A"); + auto [b_used, transb_eff] = canonicalize_fp8_input(b, transb, /*want_trans=*/false, "B"); + transa = transa_eff; + transb = transb_eff; + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorInit(ctx->matmul_desc.get(), CUBLAS_COMPUTE_32F)); int64_t ldd{}; - init_matrices_fn(ctx, &ldd, m, n, k, a, b, d, transa, transb); + init_matrices_fn(ctx, &ldd, m, n, k, &a_used, &b_used, d, transa, transb); const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N; const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N; @@ -292,23 +330,23 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo sizeof algo_attr)); const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32; - if (is_fp8_dtype(a->dtype())) { - NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype"); + if (is_fp8_dtype(a_used.dtype())) { + NVTE_CHECK(a_used.scale_inv.dptr, "Scaling must be set for FP8 dtype"); NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode, sizeof scale_mode)); NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER, - &a->scale_inv.dptr, sizeof(void*))); + &a_used.scale_inv.dptr, sizeof(void*))); } - if (is_fp8_dtype(b->dtype())) { - NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype"); + if (is_fp8_dtype(b_used.dtype())) { + NVTE_CHECK(b_used.scale_inv.dptr, "Scaling must be set for FP8 dtype"); NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode, sizeof scale_mode)); NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER, - &b->scale_inv.dptr, sizeof(void*))); + &b_used.scale_inv.dptr, sizeof(void*))); } if (is_fp8_dtype(d->dtype())) { NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype"); @@ -407,11 +445,11 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo n, k, &alpha, - a->data.dptr, + a_used.data.dptr, 1, 1, ctx->a_desc.get(), - b->data.dptr, + b_used.data.dptr, 1, 1, ctx->b_desc.get(), @@ -445,9 +483,6 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo std::apply(cublasMpMatmul, std::tuple_cat(args, std::tuple{ctx->workspace, ctx->workspace_size, workspace_host.data(), workspace_host.size()}))); - - NVTE_CHECK_CUDA(cudaEventRecord(ctx->event.get(), main_stream)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream.get(), ctx->event.get(), 0)); } } // namespace @@ -481,6 +516,8 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank .b_desc = std::move(b_desc), .d_desc = std::move(d_desc), .matmul_desc = std::move(matmul_desc), + .workspace = nullptr, + .workspace_size = 0, }; } diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 28218e2b43..30e69b37f7 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -4,6 +4,7 @@ * See LICENSE for license information. ************************************************************************/ +#include #include #include #include @@ -23,6 +24,14 @@ using namespace std::placeholders; +bool nvte_built_with_cublasmp() { +#ifdef NVTE_WITH_CUBLASMP + return true; +#else + return false; +#endif +} + namespace transformer_engine { namespace { @@ -33,10 +42,6 @@ std::vector shape_to_vector(const NVTEShape &shape) { } // namespace -/*************************************************************************************************** - * Comm+GEMM Overlap Common Core - **************************************************************************************************/ - bool ubuf_built_with_mpi() { #ifdef NVTE_UB_WITH_MPI return true; @@ -45,6 +50,11 @@ bool ubuf_built_with_mpi() { #endif } +/*************************************************************************************************** + * Comm+GEMM Overlap Common Core + **************************************************************************************************/ + +// Constructor for Userbuffers backend CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, @@ -69,6 +79,35 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl num_comm_sm, set_sm_margin, use_ce, atomic_gemm); } +// Constructor for cuBLASMp backend +CommOverlapCore::CommOverlapCore(ncclComm_t nccl_comm_ptr, int tp_rank, int tp_size, + int num_comm_sm, bool is_p2p, bool atomic_gemm) { + NVTE_CHECK( + nvte_built_with_cublasmp(), + "Comm+GEMM overlap with cuBLASMp backend requires TE to be built with NVTE_WITH_CUBLASMP=1."); + _with_cublasmp = true; + _cublasmp_ctx = nvte_comm_gemm_ctx_create(nccl_comm_ptr, tp_size, tp_rank); + + _tp_id = tp_rank; + _tp_size = tp_size; + _num_comm_sm = num_comm_sm; + _is_p2p = is_p2p; + _atomic_gemm = atomic_gemm; + if (_is_p2p) { + if (_atomic_gemm) { + _algo_type = kNVTECommGemmAlgoAtomicP2P; + } else { + _algo_type = kNVTECommGemmAlgoSplitP2P; + } + } else { + if (_atomic_gemm) { + _algo_type = kNVTECommGemmAlgoAtomicMulticast; + } else { + _algo_type = kNVTECommGemmAlgoSplitMulticast; + } + } +} + void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, @@ -134,39 +173,43 @@ void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_stream } CommOverlapCore::~CommOverlapCore() { - cudaEventDestroy(_stop_comm); - cudaEventDestroy(_start_comm); - cudaEventDestroy(_stop_compute); - cudaEventDestroy(_start_compute); - if (_comm_launch_event) { - cudaEventDestroy(_comm_launch_event); - } + if (_with_cublasmp) { + nvte_comm_gemm_ctx_destroy(_cublasmp_ctx); + } else { + cudaEventDestroy(_stop_comm); + cudaEventDestroy(_start_comm); + cudaEventDestroy(_stop_compute); + cudaEventDestroy(_start_compute); + if (_comm_launch_event) { + cudaEventDestroy(_comm_launch_event); + } - if (_atomic_gemm) { - cudaFree(_counter.dptr()); - } + if (_atomic_gemm) { + cudaFree(_counter.dptr()); + } - for (size_t i = 0; i < _stream_compute.size(); i++) { - cudaStreamSynchronize(_stream_compute[i]); - cudaStreamDestroy(_stream_compute[i]); - } + for (size_t i = 0; i < _stream_compute.size(); i++) { + cudaStreamSynchronize(_stream_compute[i]); + cudaStreamDestroy(_stream_compute[i]); + } - auto error = cudaGetLastError(); - if (error != cudaSuccess) { - NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error)); - } + auto error = cudaGetLastError(); + if (error != cudaSuccess) { + NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error)); + } - if (_comm_created) { - try { + if (_comm_created) { + try { #ifdef NVTE_UB_WITH_MPI - destroy_communicator_mpi(_ub_comm); + destroy_communicator_mpi(_ub_comm); #else - destroy_communicator(_ub_comm); + destroy_communicator(_ub_comm); #endif - } catch (const std::exception &e) { - NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what()); + } catch (const std::exception &e) { + NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what()); + } + _comm_created = false; } - _comm_created = false; } } @@ -290,10 +333,99 @@ TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source return chunk; } +namespace { + +struct CublasMpDims { + int64_t m, n, k; +}; + +// Resolve the global m/n/k for the three cuBLASMp communication patterns +// from flattened operand shapes. +CublasMpDims compute_ag_dims(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, int tp_size) { + auto A_tensor = convertNVTETensorCheck(A.data()); + auto B_tensor = convertNVTETensorCheck(B.data()); + int64_t A0 = A_tensor->flat_first_dim(); + int64_t A1 = A_tensor->flat_last_dim(); + int64_t B0 = B_tensor->flat_first_dim(); + int64_t B1 = B_tensor->flat_last_dim(); + + // col-major A: (M/P, K) -- tensor-parallel in the non-contracting dimension + int64_t m = (transa ? A0 : A1) * tp_size; + // col-major B: (K, N/P) -- sequence-parallel in the non-contracting dimension + int64_t n = (transb ? B1 : B0) * tp_size; + // contracting dimension not distributed + int64_t k = transa ? A1 : A0; + return {m, n, k}; +} + +CublasMpDims compute_rs_dims(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, int tp_size) { + auto A_tensor = convertNVTETensorCheck(A.data()); + auto B_tensor = convertNVTETensorCheck(B.data()); + int64_t A0 = A_tensor->flat_first_dim(); + int64_t A1 = A_tensor->flat_last_dim(); + int64_t B0 = B_tensor->flat_first_dim(); + int64_t B1 = B_tensor->flat_last_dim(); + + // col-major A: (M, K/P) -- tensor-parallel in the contracting dimension + int64_t m = transa ? A0 : A1; + // col-major B: (K/P, N) -- tensor-parallel in the contracting dimension + int64_t n = transb ? B1 : B0; + // contracting dimension is distributed + int64_t k = (transa ? A1 : A0) * tp_size; + return {m, n, k}; +} + +CublasMpDims compute_ar_dims(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, int tp_size) { + // AR shares the same m/n/k semantics as RS at descriptor level. + return compute_rs_dims(A, transa, B, transb, tp_size); +} + +} // namespace + +void CommOverlapCore::cublasmp_ag_gemm(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, bool grad, bool accumulate, + cudaStream_t stream_main) { + auto [m, n, k] = compute_ag_dims(A, transa, B, transb, _tp_size); + // col-major GEMM compute overlapped with all-gather on input B + // (M/P, K) x [(K, N/P) -(AG)-> (K, N)] = (M/P, N) + nvte_all_gather_gemm(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm, + stream_main, _algo_type); +} + +void CommOverlapCore::cublasmp_gemm_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, bool grad, bool accumulate, + cudaStream_t stream_main) { + auto [m, n, k] = compute_rs_dims(A, transa, B, transb, _tp_size); + // col-major GEMM compute overlapped with reduce-scatter on the output + // (M, K/P) x (K/P, N) = (M, N) -(RS)-> (M, N/P) + nvte_gemm_reduce_scatter(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm, + stream_main, _algo_type); +} + +void CommOverlapCore::cublasmp_gemm_ar(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, bool grad, bool accumulate, + cudaStream_t stream_main) { + auto [m, n, k] = compute_ar_dims(A, transa, B, transb, _tp_size); + // col-major GEMM compute overlapped with all-reduce on the output + // (M, K/P) x (K/P, N) = (M, N) -(AR)-> (M, N) + nvte_gemm_all_reduce(_cublasmp_ctx, m, n, k, A.data(), B.data(), D.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, accumulate, _num_comm_sm, + stream_main, _algo_type); +} + /*************************************************************************************************** * Comm+GEMM Overlap Base (Pipelined / Collective) **************************************************************************************************/ +// Constructor for Userbuffers backend CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, @@ -331,9 +463,11 @@ void CommOverlapBase::initialize(const std::vector &buffer_shape, DType } CommOverlapBase::~CommOverlapBase() { - cudaEventDestroy(_start_d2dcopy); - cudaStreamSynchronize(_stream_comm); - cudaStreamDestroy(_stream_comm); + if (!_with_cublasmp) { + cudaEventDestroy(_start_d2dcopy); + cudaStreamSynchronize(_stream_comm); + cudaStreamDestroy(_stream_comm); + } } /* @@ -346,6 +480,8 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_CHECK(!_with_cublasmp, "Bulk overlap is not supported with cuBlasMp"); + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -403,10 +539,16 @@ void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); + } + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions size_t m = transa ? A.size(0) : A.size(1); size_t k = transa ? A.size(1) : A.size(0); @@ -499,6 +641,11 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); + } + // Get GEMM dimensions int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; @@ -662,6 +809,7 @@ void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStr * Comm+GEMM Overlap P2P Base (Ring-Exchange) **************************************************************************************************/ +// Constructor for Userbuffers backend CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, @@ -742,16 +890,20 @@ void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DTy } CommOverlapP2PBase::~CommOverlapP2PBase() { - cudaEventDestroy(_stop_recv); - cudaEventDestroy(_stop_send); - cudaStreamDestroy(_stream_recv); - for (size_t i = 0; i < _stream_send.size(); i++) { - cudaStreamDestroy(_stream_send[i]); + if (!_with_cublasmp) { + cudaEventDestroy(_stop_recv); + cudaEventDestroy(_stop_send); + cudaStreamDestroy(_stream_recv); + for (size_t i = 0; i < _stream_send.size(); i++) { + cudaStreamDestroy(_stream_send[i]); + } } } void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, bool rowwise) { + if (_with_cublasmp) return; // cuBlasMp executes its own copy-into-buffer op + // Check element size const size_t element_size = source.element_size(); NVTE_CHECK(_ubuf.element_size() == element_size, @@ -806,6 +958,11 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag( const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_ag_gemm(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); + } + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -908,6 +1065,11 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_ag_gemm(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); + } + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -1075,6 +1237,11 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs( TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); + } + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -1139,6 +1306,11 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + if (_with_cublasmp) { + return cublasmp_gemm_rs(A, transa, B, transb, D, bias, pre_gelu_out, grad, accumulate, + stream_main); + } + int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 6307eab14c..f700b540da 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -9,6 +9,8 @@ #include #include +#include +#include #include #include @@ -17,6 +19,12 @@ #define NVTE_COMM_OVERLAP_MAX_STREAMS 3 +/* \brief Check if TE is built with cuBLASMp. + * + * \return True if TE is built with cuBLASMp. + */ +bool nvte_built_with_cublasmp(); + namespace transformer_engine { /* \brief Check if Userbufers bootstraps with direct calls to MPI collectives. @@ -59,6 +67,10 @@ class CommOverlapCore { bool _atomic_gemm{false}; bool _is_p2p{false}; + bool _with_cublasmp{false}; + NVTECommGemmCtx *_cublasmp_ctx{nullptr}; + NVTECommGemmAlgoType _algo_type = kNVTECommGemmAlgoDefault; + TensorWrapper _ubuf; TensorWrapper _counter; float *_ubuf_scale_inv; @@ -75,12 +87,17 @@ class CommOverlapCore { public: CommOverlapCore() {} // dummy constructor for exposing type to Python + // Constructor for Userbuffers backend CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, bool atomic_gemm); + // Constructor for cuBLASMp backend + CommOverlapCore(ncclComm_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm, bool is_p2p, + bool atomic_gemm); + virtual ~CommOverlapCore(); void *get_ubuf_dptr() { return _ubuf.dptr(); } @@ -109,6 +126,20 @@ class CommOverlapCore { bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + bool with_cublasmp() { return _with_cublasmp; } + + void cublasmp_ag_gemm(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + bool grad, bool accumulate, cudaStream_t stream_main); + + void cublasmp_gemm_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + bool grad, bool accumulate, cudaStream_t stream_main); + + void cublasmp_gemm_ar(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + bool grad, bool accumulate, cudaStream_t stream_main); + virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, @@ -169,6 +200,7 @@ class CommOverlapBase : public CommOverlapCore { public: CommOverlapBase() {} // dummy constructor for exposing type to Python + // Constructor for Userbuffers backend CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, @@ -177,6 +209,11 @@ class CommOverlapBase : public CommOverlapCore { bool set_sm_margin = true, bool atomic_gemm = false, bool rs_overlap_first_gemm = false); + // Constructor for cuBLASMp backend + CommOverlapBase(ncclComm_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm = 16, + bool atomic_gemm = false) + : CommOverlapCore(nccl_comm_ptr, tp_rank, tp_size, num_comm_sm, false, atomic_gemm) {} + virtual ~CommOverlapBase(); /* @@ -249,6 +286,7 @@ class CommOverlapP2PBase : public CommOverlapCore { public: CommOverlapP2PBase() {} // dummy constructor for exposing type to Python + // Constructor for Userbuffers backend CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, @@ -257,6 +295,11 @@ class CommOverlapP2PBase : public CommOverlapCore { int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); + // Constructor for cuBLASMp backend + CommOverlapP2PBase(ncclComm_t nccl_comm_ptr, int tp_rank, int tp_size, int num_comm_sm = 1, + bool atomic_gemm = false) + : CommOverlapCore(nccl_comm_ptr, tp_rank, tp_size, num_comm_sm, true, atomic_gemm) {} + virtual ~CommOverlapP2PBase(); void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index ef7687e3e9..5d95f91b69 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -110,11 +110,15 @@ pybind11::module_local()) \ .def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \ py::call_guard()) \ + .def("get_tp_size", &transformer_engine::CommOverlapCore::get_tp_size, \ + py::call_guard()) \ .def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \ py::call_guard()) \ .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ py::call_guard()) \ .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ + py::call_guard()) \ + .def("with_cublasmp", &transformer_engine::CommOverlapCore::with_cublasmp, \ py::call_guard()); \ py::class_, \ @@ -138,6 +142,8 @@ }, \ py::call_guard(), py::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + py::call_guard()); \ + m.def("nvte_built_with_cublasmp", &nvte_built_with_cublasmp, \ py::call_guard()); #endif diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4ff6d07986..212c8083ec 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -23,7 +23,9 @@ get_num_compute_streams, JAXX_Collective_Op, get_device_compute_capability, + nvte_built_with_cublasmp, initialize_cgemm_communicator, + is_collective_gemm_with_cublasmp, get_cgemm_num_max_streams, get_grouped_gemm_setup_workspace_size, ) @@ -222,6 +224,7 @@ def collective_gemm_bootstrap( num_sm_for_communication=2, use_ce=True, aggregate_all_gather=False, + use_cublasmp=False, ): """Initialize NCCL communicators for Collective GEMM operations. @@ -260,6 +263,9 @@ def collective_gemm_bootstrap( Can improve performance by offloading memory operations. Default: True. aggregate_all_gather (bool, optional): Aggregate multiple small all-gather operations into larger ones for better efficiency. Default: False. + use_cublasmp (bool, optional): Use cuBLASMp backend for Collective GEMM overlap. + Requires Transformer Engine to be compiled with NVTE_WITH_CUBLASMP=1. + Default: False. Raises: AssertionError: If num_total_devices is not divisible by num_devices_per_process, @@ -295,6 +301,24 @@ def collective_gemm_bootstrap( This function must be called after JAX distributed initialization and before any collective GEMM operations. Each process should call this function with its own unique process_id. + + With the cuBLASMp backend, XLA command buffer capture must include + ``COLLECTIVES`` so that the NCCL calls inside cuBLASMp end up in the + same captured buffer as the CollectiveGemm custom call. Otherwise the + capture aborts with ``CUDA_ERROR_STREAM_CAPTURE_INVALIDATED``. Set the + flag before ``jax.distributed.initialize()``: + + import os + os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + + " --xla_gpu_enable_command_buffer=+COLLECTIVES" + ) + + This is not required for non-overlapped collective GEMM (when + ``collective_op`` is ``CollectiveOp.NONE`` and JAX/XLA handles the + collective via its own graph-level optimization), nor for the + Userbuffers backend, which uses CUDA multicast APIs and async + memcpy on symmetric memory pointers that XLA already captures. """ if not (num_devices_per_process == 1 and jax.local_device_count() == 1): @@ -306,6 +330,12 @@ def collective_gemm_bootstrap( ) if not 0 <= process_id < num_total_devices: raise ValueError(f"Invalid process_id={process_id}") + if use_cublasmp and not nvte_built_with_cublasmp(): + raise RuntimeError( + "Collective GEMM with cuBLASMp backend was requested, but Transformer Engine " + "was not built with cuBLASMp support. Rebuild with NVTE_WITH_CUBLASMP=1 or " + "disable use_cublasmp." + ) initialize_cgemm_communicator( num_total_devices, num_devices_per_process, @@ -317,6 +347,7 @@ def collective_gemm_bootstrap( num_sm_for_communication, use_ce, aggregate_all_gather, + use_cublasmp, ) @@ -606,7 +637,11 @@ def _dims_are_consecutive(dims): if scaling_mode.is_nvfp4_scaling: workspace_size += lhs_scale_inv.size + rhs_scale_inv.size if not collective_op.is_none: - workspace_size *= get_cgemm_num_max_streams() + if is_collective_gemm_with_cublasmp(): + # cuBlasMp manages its own cuBlasLt workspaces per stream + workspace_size = 0 + else: + workspace_size *= get_cgemm_num_max_streams() # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # necessarily 256 bytes aligned, we add some padding to ensure alignment. workspace_size += 256 @@ -812,10 +847,10 @@ def batcher( contracting_dims, scaling_mode, use_split_accumulator, - collective_op, transpose_batch_sequence, sequence_dim, is_outer, + collective_op, ): del transpose_batch_sequence, sequence_dim, is_outer if GemmPrimitive.outer_primitive is None: @@ -998,9 +1033,9 @@ def _parse_operand_output_specs( lhs_scale_specs = rhs_scale_specs = (None,) if scaling_mode.is_1d_block_scaling(): rhs_scale_specs = rhs_specs - # Set the seq spec to None to trigger AG the scales as TE/Common CGEMM does not handle - # scale collecting yet - if collective_op.is_all_gather: + # Set the seq spec to None to trigger AG the scales as TE/Common CGEMM w/ Userbuffers + # backend does not handle scale collecting yet (cuBLASMp backend does) + if collective_op.is_all_gather and not is_collective_gemm_with_cublasmp(): lhs_scale_specs = tuple( None if i == sequence_dim else s for i, s in enumerate(lhs_specs) ) @@ -1957,7 +1992,27 @@ def gemm( transpose_batch_sequence: bool, default = False Transpose the batch and sequence dimensions of the input tensor. collective_op: CollectiveOp, default = CollectiveOp.NONE - Collective operation type for collective GEMM. + Collective operation type for collective GEMM. When set to + ``CollectiveOp.ALL_GATHER`` or ``CollectiveOp.REDUCE_SCATTER``, the GEMM + is executed with communication overlap via the Userbuffers or cuBLASMp + backend (see :func:`collective_gemm_bootstrap`). + + .. note:: + Collective GEMM with communication overlap is captured into XLA + command buffers as a custom call. When executing with the cuBLASMp + backend, this captured graph spans NCCL collectives that XLA does not + include in command buffers by default, so add ``COLLECTIVES`` to the + enabled kinds before JAX initialization:: + + os.environ["XLA_FLAGS"] = ( + os.environ.get("XLA_FLAGS", "") + + " --xla_gpu_enable_command_buffer=+COLLECTIVES" + ) + + Without this, capture aborts with + ``CUDA_ERROR_STREAM_CAPTURE_INVALIDATED``. Not required when + ``collective_op`` is ``CollectiveOp.NONE`` or when using the Userbuffers + backend instead of cuBLASMp. Returns ------- diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp index 36a4a068a4..da25e3676a 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp @@ -18,11 +18,9 @@ ncclUniqueId CommunicatorHandler::coordinate_nccl_unique_id(const std::string &i int tp_domain_id = get_tp_domain_id(); bool is_tp_leader = (get_local_device_id_within_tp_domain() == 0); - pid_t pgid = getpgid(0); - std::string base_path = getenv("NVTE_JAX_NCCL_FILE_PATH", "/tmp"); - std::string id_file = base_path + "/nccl_" + id_type + "_unique_id_pgid_" + std::to_string(pgid) + - "_" + std::to_string(num_total_devices) + "_" + std::to_string(tp_size) + + std::string id_file = base_path + "/nccl_" + id_type + "_unique_id_" + + std::to_string(num_total_devices) + "_" + std::to_string(tp_size) + "_domain_" + std::to_string(tp_domain_id) + ".bin"; if (is_tp_leader) { @@ -136,7 +134,7 @@ void CommunicatorHandler::init(int num_total_devices, int num_devices_per_proces handler._initialize = true; - // Bootstrap UB via creating a dummy CommOverlapP2PBase object + // Bootstrap UB/cuBlasMp via creating a dummy CommOverlapP2PBase object std::vector buffer_shape{1, 1}; [[maybe_unused]] auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor( buffer_shape, DType::kFloat32, JAXX_Collective_Op::ALL_GATHER); @@ -144,14 +142,20 @@ void CommunicatorHandler::init(int num_total_devices, int num_devices_per_proces void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, int tp_size, int num_max_streams, int gemm_priority, - int comm_priority, int num_comm_sm, bool use_ce, - bool aggregate_ag) { + int comm_priority, int num_comm_sm, bool use_ce, bool aggregate_ag, + bool use_cublasmp) { auto &config = CgemmConfig::get(false); - config.init(num_max_streams, gemm_priority, comm_priority, num_comm_sm, use_ce, aggregate_ag); + config.init(num_max_streams, gemm_priority, comm_priority, num_comm_sm, use_ce, aggregate_ag, + use_cublasmp); auto &handler = CommunicatorHandler::get(false); handler.init(num_total_devices, num_devices_per_process, process_id, tp_size); } +bool IsCollectiveGemmWithCublasmp() { + auto &config = CgemmConfig::get(); + return config.use_cublasmp; +} + int GetCgemmNumMaxStreams() { auto &config = CgemmConfig::get(); return config.num_max_streams; @@ -168,7 +172,8 @@ CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector bu hash_combine(plan_id, buffer_shape[0], buffer_shape[1], static_cast(dtype), static_cast(collective_op), comm_handler.tp_size, cgemm_config.num_max_streams, cgemm_config.gemm_priority, cgemm_config.comm_priority, cgemm_config.num_comm_sm, - cgemm_config.use_ce, cgemm_config.aggregate_ag, device_idx); + cgemm_config.use_ce, cgemm_config.aggregate_ag, device_idx, + cgemm_config.use_cublasmp); auto it = plan_map.find(plan_id); if (it != plan_map.end()) { @@ -192,14 +197,22 @@ CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector bu } std::unique_ptr executor; - executor = std::make_unique( - buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices, - comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size, - comm_handler.get_tp_domain_id(), comm_handler.get_tp_num_domains(), comm_handler.tp_size, - comm_handler.allgather_func, comm_handler.barrier_func, get_nvte_collective_op(collective_op), - cgemm_config.num_max_streams, 1 /*comm_cga_size*/, cgemm_config.gemm_priority, - cgemm_config.comm_priority, cgemm_config.num_comm_sm, true /*set_sm_margin*/, - cgemm_config.use_ce, false /*atomic_gemm*/, cgemm_config.aggregate_ag); + if (cgemm_config.use_cublasmp) { + executor = std::make_unique( + comm_handler.get_comm_for_current_device(), + comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size, + cgemm_config.num_comm_sm, false /*atomic_gemm*/); + } else { + executor = std::make_unique( + buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices, + comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size, + comm_handler.get_tp_domain_id(), comm_handler.get_tp_num_domains(), comm_handler.tp_size, + comm_handler.allgather_func, comm_handler.barrier_func, + get_nvte_collective_op(collective_op), cgemm_config.num_max_streams, 1 /*comm_cga_size*/, + cgemm_config.gemm_priority, cgemm_config.comm_priority, cgemm_config.num_comm_sm, + true /*set_sm_margin*/, cgemm_config.use_ce, false /*atomic_gemm*/, + cgemm_config.aggregate_ag); + } CommOverlapCore *executor_ptr = executor.get(); plan_map[plan_id] = std::move(executor); diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.h b/transformer_engine/jax/csrc/extensions/cgemm_helper.h index 2b980e7ee4..9bc8c9cf8d 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.h +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.h @@ -35,9 +35,10 @@ class CgemmConfig { int num_comm_sm; bool use_ce; bool aggregate_ag; + bool use_cublasmp; static void init(int _num_max_streams, int _gemm_priority, int _comm_priority, int _num_comm_sm, - bool _use_ce, bool _aggregate_ag) { + bool _use_ce, bool _aggregate_ag, bool _use_cublasmp = false) { auto &config = get(false); config._initialized = true; config.num_max_streams = _num_max_streams; @@ -46,6 +47,7 @@ class CgemmConfig { config.num_comm_sm = _num_comm_sm; config.use_ce = _use_ce; config.aggregate_ag = _aggregate_ag; + config.use_cublasmp = _use_cublasmp; } static CgemmConfig &get(bool is_initialized = true) { @@ -178,8 +180,10 @@ class CollectiveGemmPlanRegistry { // Function declarations void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, int tp_size, int num_max_streams, int gemm_priority, - int comm_priority, int num_comm_sm, bool use_ce, - bool aggregate_ag); + int comm_priority, int num_comm_sm, bool use_ce, bool aggregate_ag, + bool use_cublasmp = false); + +bool IsCollectiveGemmWithCublasmp(); int GetCgemmNumMaxStreams(); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 6ca907032c..d65877e9a0 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -99,6 +99,21 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( return std::make_tuple(std::move(input), input_shape); } +// Build a TensorWrapper for the prepare stage. Operand contents are +// uninitialized at this point and no kernels are launched. +static TensorWrapper prepare_operand_tensor(Buffer_Type buf, Buffer_Type scale_inv, + JAXX_Scaling_Mode scaling_mode, + const std::vector &flat_shape) { + auto dtype = convert_ffi_datatype_to_te_dtype(buf.element_type()); + TensorWrapper t(get_nvte_scaling_mode(scaling_mode)); + t.set_rowwise_data(buf.untyped_data(), dtype, flat_shape); + if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING && scale_inv.element_count() > 0) { + auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + t.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, std::vector{1}); + } + return t; +} + Error_Type GemmInitV2FFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type workspace, @@ -128,8 +143,46 @@ Error_Type GemmInitV2FFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type buffer_shape[0] = out_shape[0]; buffer_shape[1] = out_shape[1]; } - [[maybe_unused]] auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor( + auto *executor = CollectiveGemmPlanRegistry::getInstance().get_executor( buffer_shape, buffer_dtype, config.collective_op); + + // Run a dummy cuBLASMp matmul in the prepare stage so its lazy NCCL + // window registration and workspace allocation happen outside any + // CUDA-graph capture. The throwaway D output gets overwritten on the + // next execute-stage call. + if (IsCollectiveGemmWithCublasmp() && executor != nullptr) { + auto lhs_ = prepare_operand_tensor(lhs, lhs_scale_inv, config.scaling_mode, lhs_shape); + auto rhs_ = prepare_operand_tensor(rhs, rhs_scale_inv, config.scaling_mode, rhs_shape); + // Match GemmV2FFI's local D shape: AG gathers along axis 0, RS scatters along axis 0. + std::vector d_shape = out_shape; + if (config.collective_op == JAXX_Collective_Op::ALL_GATHER) { + d_shape[0] *= comm_handler.tp_size; + } else if (config.collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + d_shape[0] /= comm_handler.tp_size; + } + TensorWrapper d_(get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING)); + d_.set_rowwise_data(output->untyped_data(), + convert_ffi_datatype_to_te_dtype(output->element_type()), d_shape); + TensorWrapper bias_(get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING)); + if (bias.element_count() > 0) { + bias_.set_rowwise_data(bias.untyped_data(), + convert_ffi_datatype_to_te_dtype(bias.element_type()), + std::vector{static_cast(bias.element_count())}); + } + TensorWrapper pre_gelu_out_(get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING)); + // Match GemmV2FFI's operand swap: rhs becomes A, lhs becomes B. + cudaStream_t prepare_stream = cudaStreamPerThread; + if (config.collective_op == JAXX_Collective_Op::ALL_GATHER) { + executor->cublasmp_ag_gemm(rhs_, config.rhs_transposed, lhs_, config.lhs_transposed, d_, + bias_, pre_gelu_out_, false /*grad*/, false /*accumulate*/, + prepare_stream); + } else if (config.collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + executor->cublasmp_gemm_rs(rhs_, config.rhs_transposed, lhs_, config.lhs_transposed, d_, + bias_, pre_gelu_out_, false /*grad*/, false /*accumulate*/, + prepare_stream); + } + NVTE_CHECK_CUDA(cudaStreamSynchronize(prepare_stream)); + } } return ffi_with_cuda_error_check(); } @@ -191,8 +244,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI, .Attr("fuse_gelu") .Attr("grad") .Attr("use_split_accumulator") - .Attr("collective_op"), - FFI_CudaGraph_Traits); + .Attr("collective_op") +#ifndef NVTE_WITH_CUBLASMP + // enable CUDA graphs only when cuBLASMp is NOT enabled + , + FFI_CudaGraph_Traits +#endif +); Error_Type GemmV2FFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, @@ -300,18 +358,27 @@ Error_Type GemmV2FFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale buffer_shape, buffer_dtype, config.collective_op); auto pre_gelu_ = TensorWrapper(nullptr, std::vector{0}, DType::kByte); if (config.collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { - auto ubuf_out_ = TensorWrapper(executor->get_ubuf_dptr(), buffer_shape, out_dtype); - // Prepare the auxiliary buffer for the reduce-scattered GEMM output auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); NVTE_CHECK(out_.numel() == output->element_count(), "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ", to_string_like(out_shape), " but got ", output->element_count(), " elements ", to_string_like(output->dimensions())); - // Launch GEMM+RS - executor->split_overlap_rs(rhs_, config.rhs_transposed, lhs_, config.lhs_transposed, - ubuf_out_, bias_, pre_gelu_, workspace_, false /*grad*/, - false /*accumulate*/, config.use_split_accumulator, out_, stream); + if (IsCollectiveGemmWithCublasmp()) { + // cuBLASMp writes the reduce-scattered result directly into D + auto rs_out_ = TensorWrapper(nullptr, std::vector{0}, out_dtype); + executor->split_overlap_rs(rhs_, config.rhs_transposed, lhs_, config.lhs_transposed, out_, + bias_, pre_gelu_, workspace_, false /*grad*/, + false /*accumulate*/, config.use_split_accumulator, rs_out_, + stream); + } else { + // Userbuffers writes the full GEMM result into ubuf, then reduce-scatters into rs_output + auto ubuf_out_ = TensorWrapper(executor->get_ubuf_dptr(), buffer_shape, out_dtype); + executor->split_overlap_rs(rhs_, config.rhs_transposed, lhs_, config.lhs_transposed, + ubuf_out_, bias_, pre_gelu_, workspace_, false /*grad*/, + false /*accumulate*/, config.use_split_accumulator, out_, + stream); + } } else if (config.collective_op == JAXX_Collective_Op::ALL_GATHER) { auto aux_out_ = TensorWrapper(nullptr, std::vector{0}, out_dtype); // Empty diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 70d0403b3e..bdd487e140 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -124,7 +124,9 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_topk_workspace_sizes", &GetTopkWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); + m.def("nvte_built_with_cublasmp", &::nvte_built_with_cublasmp); m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); + m.def("is_collective_gemm_with_cublasmp", &IsCollectiveGemmWithCublasmp); m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); m.def("get_grouped_gemm_setup_workspace_size", &nvte_get_grouped_gemm_setup_workspace_size); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 8082ff07ed..22a9803dd8 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -7,10 +7,14 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ +#include + #include +#include #include #include #include +#include #include #include @@ -645,10 +649,18 @@ void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t **************************************************************************************************/ class CommOverlapHelper : torch::CustomClassHolder { + public: + // Shared ownership of an ncclComm_t. The deleter calls ncclCommDestroy when + // the last reference (held by the helper and/or any CommOverlap consumers) + // is released, so the communicator outlives whichever owner is destroyed + // first. + using NcclCommSharedPtr = std::shared_ptr::type>; + private: bool initialized{false}; bool backend_is_nccl{false}; - std::map pgs; + std::map torch_pgs; + std::map nccl_comms; public: int myrank = -1; @@ -669,17 +681,33 @@ class CommOverlapHelper : torch::CustomClassHolder { ExtComm comm); void ub_barrier(ExtComm comm); + + NcclCommSharedPtr get_nccl_comm(std::string comm_name); }; class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { + private: + // Keeps the cuBLASMp NCCL communicator alive for the lifetime of this + // instance, independent of the CommOverlapHelper that created it. + CommOverlapHelper::NcclCommSharedPtr _nccl_comm; + public: CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, - CommOverlapHelper *helper, int tp_size, int num_splits = 3, + CommOverlapHelper *helper, int tp_size, int num_splits = 4, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false, bool rs_overlap_first_gemm = false); + // cuBLASMp variant. `comm_type`, `buffer_shape`, and `buffer_dtype` size + // the construction-time warmup matmul that primes cuBLASMp's lazy NCCL + // window registrations and workspace allocation so subsequent matmuls + // (including those captured in CUDA graphs) avoid the unsafe lazy paths. + CommOverlap(CommOverlapHelper *helper, int tp_rank, int tp_size, + transformer_engine::CommOverlapType comm_type, + const std::vector &buffer_shape, at::ScalarType buffer_dtype, + int num_comm_sm = 16, bool atomic_gemm = false); + ~CommOverlap() {} using transformer_engine::CommOverlapCore::copy_into_buffer; @@ -693,15 +721,26 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve }; // CommOverlap class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { + private: + // Keeps the cuBLASMp NCCL communicator alive for the lifetime of this + // instance, independent of the CommOverlapHelper that created it. + CommOverlapHelper::NcclCommSharedPtr _nccl_comm; + public: CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, transformer_engine::CommOverlapType comm_type, - int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3, - bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, + int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 1, + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 1, + bool set_sm_margin = false, bool atomic_gemm = false, bool use_ce = true, bool aggregate = false); + // cuBLASMp variant. See CommOverlap for the `comm_type`/buffer args. + CommOverlapP2P(CommOverlapHelper *helper, int tp_rank, int tp_size, + transformer_engine::CommOverlapType comm_type, + const std::vector &buffer_shape, at::ScalarType buffer_dtype, + int num_comm_sm = 1, bool atomic_gemm = false); + ~CommOverlapP2P() {} using transformer_engine::CommOverlapP2PBase::copy_into_buffer; diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index a126ab0d60..57753c38e7 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -3,6 +3,9 @@ * * See LICENSE for license information. ************************************************************************/ +#ifdef NVTE_WITH_CUBLASMP +#include +#endif #include "../extensions.h" #include "transformer_engine/transformer_engine.h" @@ -28,20 +31,20 @@ CommOverlapHelper::CommOverlapHelper() { CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, std::optional intra_domain_group) { #ifndef NVTE_UB_WITH_MPI - pgs.insert({"world", world_group}); - myrank = pgs["world"]->getRank(); - numranks = pgs["world"]->getSize(); - c10d::ProcessGroup::BackendType backend = pgs["world"]->getBackendType(); + torch_pgs.insert({"world", world_group}); + myrank = torch_pgs["world"]->getRank(); + numranks = torch_pgs["world"]->getSize(); + c10d::ProcessGroup::BackendType backend = torch_pgs["world"]->getBackendType(); backend_is_nccl = (backend == c10d::ProcessGroup::BackendType::NCCL); if (intra_domain_group.has_value()) { // Get local rank on node and number of local ranks NVTE_CHECK(intra_domain_group.value()->getBackendType() == backend, "Internal TE error: Intra-node group must be on the same backend (%s) as the world ", - "group!", pgs["world"]->getBackendName()); - pgs.insert({"intra", intra_domain_group.value()}); - mylocal = pgs["intra"]->getRank(); - numlocal = pgs["intra"]->getSize(); + "group!", torch_pgs["world"]->getBackendName()); + torch_pgs.insert({"intra", intra_domain_group.value()}); + mylocal = torch_pgs["intra"]->getRank(); + numlocal = torch_pgs["intra"]->getSize(); if (numlocal == numranks) { // Intra-node group is same as the world group so there can only be 1 node @@ -60,13 +63,68 @@ CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, // Intra-node group is not set so we assume there is only 1 node mylocal = myrank; numlocal = numranks; - pgs.insert({"intra", world_group}); + torch_pgs.insert({"intra", world_group}); mynode = 0; numnodes = 1; } initialized = true; + +#ifdef NVTE_WITH_CUBLASMP + // Initialize world NCCL communicator via ncclCommInitRank (one GPU per process under torchrun) + ncclUniqueId nccl_world_id; + if (myrank == 0) { + NVTE_CHECK_NCCL(ncclGetUniqueId(&nccl_world_id)); + } + auto nccl_world_id_tensor = + torch::from_blob(reinterpret_cast(&nccl_world_id), {sizeof(ncclUniqueId)}, + at::device(torch::kCPU).dtype(torch::kUInt8)); + nccl_world_id_tensor = (backend_is_nccl) ? nccl_world_id_tensor.cuda() : nccl_world_id_tensor; + { + c10d::BroadcastOptions bcast_opts; + bcast_opts.rootRank = 0; + std::vector bcast_tensors = {nccl_world_id_tensor}; + auto work = torch_pgs["world"]->broadcast(bcast_tensors, bcast_opts); + work->wait(); + } + nccl_world_id_tensor = (backend_is_nccl) ? nccl_world_id_tensor.cpu() : nccl_world_id_tensor; + nccl_world_id = *reinterpret_cast(nccl_world_id_tensor.data_ptr()); + + ncclComm_t nccl_world; + NVTE_CHECK_NCCL(ncclCommInitRank(&nccl_world, numranks, nccl_world_id, myrank)); + nccl_comms.insert({"world", NcclCommSharedPtr(nccl_world, ncclCommDestroy)}); + + if (intra_domain_group.has_value()) { + // Generate a separate unique ID for the intra-node communicator + ncclUniqueId nccl_intra_id; + if (mylocal == 0) { + NVTE_CHECK_NCCL(ncclGetUniqueId(&nccl_intra_id)); + } + + // Broadcast the intra-node unique ID from the local root to all local ranks + auto nccl_intra_id_tensor = + torch::from_blob(reinterpret_cast(&nccl_intra_id), {sizeof(ncclUniqueId)}, + at::device(torch::kCPU).dtype(torch::kUInt8)); + nccl_intra_id_tensor = (backend_is_nccl) ? nccl_intra_id_tensor.cuda() : nccl_intra_id_tensor; + { + c10d::BroadcastOptions bcast_opts; + bcast_opts.rootRank = 0; + std::vector bcast_tensors = {nccl_intra_id_tensor}; + auto work = torch_pgs["intra"]->broadcast(bcast_tensors, bcast_opts); + work->wait(); + } + nccl_intra_id_tensor = (backend_is_nccl) ? nccl_intra_id_tensor.cpu() : nccl_intra_id_tensor; + nccl_intra_id = *reinterpret_cast(nccl_intra_id_tensor.data_ptr()); + + // Initialize intra-node communicator + ncclComm_t nccl_intra; + NVTE_CHECK_NCCL(ncclCommInitRank(&nccl_intra, numlocal, nccl_intra_id, mylocal)); + nccl_comms.insert({"intra", NcclCommSharedPtr(nccl_intra, ncclCommDestroy)}); + } else { + nccl_comms.insert({"intra", nccl_comms["world"]}); + } +#endif #else NVTE_ERROR("Internal TE error: CommOverlapHelper cannot be initialized with valid PyTorch ", "distributed process groups when TE is compiled with NVTE_UB_WITH_MPI=1!"); @@ -75,9 +133,18 @@ CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, CommOverlapHelper::~CommOverlapHelper() { #ifndef NVTE_UB_WITH_MPI - for (auto &pg : pgs) pg.second = nullptr; + for (auto &pg : torch_pgs) { + pg.second = nullptr; + } + torch_pgs.clear(); backend_is_nccl = false; initialized = false; +#ifdef NVTE_WITH_CUBLASMP + // Releasing the helper's references is enough: each shared_ptr's deleter + // calls ncclCommDestroy once the last owner (helper or any consuming + // CommOverlap/CommOverlapP2P) drops it. + nccl_comms.clear(); +#endif #endif } @@ -96,9 +163,10 @@ void CommOverlapHelper::ub_allgather(void *globaldata, size_t globalbytes, void at::device(torch::kCPU).dtype(torch::kUInt8)); auto globaltmp = (backend_is_nccl) ? globaltensor.cuda() : globaltensor; - std::vector> globalchunks = {globaltmp.chunk(pgs[group]->getSize())}; + std::vector> globalchunks = { + globaltmp.chunk(torch_pgs[group]->getSize())}; std::vector localchunk = {localtmp}; - auto work = pgs[group]->allgather(globalchunks, localchunk); + auto work = torch_pgs[group]->allgather(globalchunks, localchunk); work->wait(); if (backend_is_nccl) { @@ -116,7 +184,7 @@ void CommOverlapHelper::ub_barrier(ExtComm group) { #ifndef NVTE_UB_WITH_MPI NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", "with valid process groups!"); - auto work = pgs[group]->barrier(); + auto work = torch_pgs[group]->barrier(); work->wait(); #else NVTE_ERROR("Internal TE error: CommOverlapHelper::ub_barrier is a no-op when TE is compiled ", @@ -124,6 +192,26 @@ void CommOverlapHelper::ub_barrier(ExtComm group) { #endif } +CommOverlapHelper::NcclCommSharedPtr CommOverlapHelper::get_nccl_comm(std::string comm_name) { +#ifdef NVTE_WITH_CUBLASMP + NVTE_CHECK(initialized, "Internal TE error: tex.CommOverlapHelper() is not initialized ", + "with valid process groups!"); + NVTE_CHECK(backend_is_nccl, + "Internal TE error: tex.CommOverlapHelper() was not initialized with an NCCL backend, " + "so no NCCL communicators are available!"); + auto it = nccl_comms.find(comm_name); + if (it != nccl_comms.end()) { + return it->second; + } else { + NVTE_ERROR("Internal TE error: No NCCL communicator found with name ", comm_name, "!"); + } +#else + NVTE_ERROR( + "Internal TE error: CommOverlapHelper::get_nccl_comm() is an internal API that requires TE " + "to be built with NVTE_WITH_CUBLASMP=1!"); +#endif +} + /*************************************************************************************************** * CommOverlap **************************************************************************************************/ @@ -141,6 +229,87 @@ CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {} +namespace { + +// Run a dummy cuBLASMp matmul during construction so its lazy NCCL window +// registration and workspace allocation happen outside any CUDA-graph +// capture. The warmup is sized from the comm buffer so the cached +// workspace covers any matmul the caller will later run with the same +// descriptor. BF16 is used unconditionally; its workspace is at least as +// large as the FP8 workspace for the same m/n/k. +void cublasmp_capture_warmup(te::CommOverlapCore *core, int tp_size, te::CommOverlapType comm_type, + const std::vector &buffer_shape) { + NVTE_CHECK(buffer_shape.size() == 2, "cuBLASMp warmup expects a 2-D buffer shape, got rank ", + buffer_shape.size()); + // Treat the matmul as square in the weight dim so workspace is sized + // for the wider of the two cases. + const int64_t N_global = static_cast(buffer_shape[0]); + const int64_t hidden = static_cast(buffer_shape[1]); + auto ceil_div = [](int64_t a, int64_t b) { return (a + b - 1) / b; }; + const int64_t M_local = ceil_div(hidden, tp_size); + const int64_t N_local = ceil_div(N_global, tp_size); + const int64_t K_local = ceil_div(hidden, tp_size); + const int64_t bf16_bytes = 2; + + std::vector a_shape, b_shape, d_shape; + if (comm_type == te::CommOverlapType::AG) { + // A = (M_local, K), B = (N_local, K), D = (N_global, M_local) + a_shape = {static_cast(M_local), static_cast(hidden)}; + b_shape = {static_cast(N_local), static_cast(hidden)}; + d_shape = {static_cast(N_global), static_cast(M_local)}; + } else { // RS (or AR -- same descriptor-level dims) + // A = (M_global, K_local), B = (N_global, K_local), D = (N_local, M_global) + a_shape = {static_cast(hidden), static_cast(K_local)}; + b_shape = {static_cast(N_global), static_cast(K_local)}; + d_shape = {static_cast(N_local), static_cast(hidden)}; + } + + const size_t a_bytes = a_shape[0] * a_shape[1] * bf16_bytes; + const size_t b_bytes = b_shape[0] * b_shape[1] * bf16_bytes; + const size_t d_bytes = d_shape[0] * d_shape[1] * bf16_bytes; + + void *a_ptr = nullptr; + void *b_ptr = nullptr; + void *d_ptr = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&a_ptr, a_bytes)); + NVTE_CHECK_CUDA(cudaMalloc(&b_ptr, b_bytes)); + NVTE_CHECK_CUDA(cudaMalloc(&d_ptr, d_bytes)); + NVTE_CHECK_CUDA(cudaMemset(a_ptr, 0, a_bytes)); + NVTE_CHECK_CUDA(cudaMemset(b_ptr, 0, b_bytes)); + + te::TensorWrapper A_tw, B_tw, D_tw, bias_tw, pre_gelu_tw; + A_tw.set_rowwise_data(a_ptr, te::DType::kBFloat16, a_shape); + B_tw.set_rowwise_data(b_ptr, te::DType::kBFloat16, b_shape); + D_tw.set_rowwise_data(d_ptr, te::DType::kBFloat16, d_shape); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (comm_type == te::CommOverlapType::AG) { + core->cublasmp_ag_gemm(A_tw, /*transa=*/true, B_tw, /*transb=*/false, D_tw, bias_tw, + pre_gelu_tw, /*grad=*/false, /*accumulate=*/false, stream); + } else { + core->cublasmp_gemm_rs(A_tw, /*transa=*/true, B_tw, /*transb=*/false, D_tw, bias_tw, + pre_gelu_tw, /*grad=*/false, /*accumulate=*/false, stream); + } + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + cudaFree(a_ptr); + cudaFree(b_ptr); + cudaFree(d_ptr); +} + +} // namespace + +CommOverlap::CommOverlap(CommOverlapHelper *helper, int tp_rank, int tp_size, + te::CommOverlapType comm_type, const std::vector &buffer_shape, + at::ScalarType buffer_dtype, int num_comm_sm, bool atomic_gemm) + : te::CommOverlapBase(helper->get_nccl_comm("intra").get(), tp_rank, tp_size, num_comm_sm, + atomic_gemm), + _nccl_comm(helper->get_nccl_comm("intra")) { + // buffer_dtype is unused on this path (the warmup runs in BF16); kept in + // the signature for API symmetry with the non-cuBLASMp ctor. + (void)buffer_dtype; + cublasmp_capture_warmup(this, tp_size, comm_type, buffer_shape); +} + /* ** Helper function to copy input to _ubuf */ @@ -240,6 +409,18 @@ CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::Scal comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) {} +CommOverlapP2P::CommOverlapP2P(CommOverlapHelper *helper, int tp_rank, int tp_size, + te::CommOverlapType comm_type, + const std::vector &buffer_shape, at::ScalarType buffer_dtype, + int num_comm_sm, bool atomic_gemm) + : te::CommOverlapP2PBase(helper->get_nccl_comm("intra").get(), tp_rank, tp_size, num_comm_sm, + atomic_gemm), + _nccl_comm(helper->get_nccl_comm("intra")) { + // See CommOverlap constructor for the buffer_dtype rationale. + (void)buffer_dtype; + cublasmp_capture_warmup(this, tp_size, comm_type, buffer_shape); +} + /* ** Copy input to _ubufs[0] */ diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 9cb1fb7f54..cce3acc7df 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -39,7 +39,8 @@ bool is_low_precision(const DType type) { } std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool transa, - const NVTEShape& B_shape, const bool transb) { + const NVTEShape& B_shape, const bool transb, + size_t tp_size = 1, size_t tp_dim = 0) { // Flatten outer dims to get 2D matrices const auto [A0, A1] = get_2d_dims(A_shape); const auto [B0, B1] = get_2d_dims(B_shape); @@ -52,21 +53,36 @@ std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool tran std::vector ret; if (transb) { ret.emplace_back(B1); - } else { + } else if (tp_size == 1) { // Unflatten B0 for (size_t i = 0; i < B_shape.ndim - 1; ++i) { ret.emplace_back(B_shape.data[i]); } + } else { + // Keep output tensor in 2D for comm+GEMM overlap + ret.emplace_back(B0); } if (transa) { ret.emplace_back(A0); } else { ret.emplace_back(A1); } + + // Correct output dims for comm+GEMM overlap if needed + if (tp_size > 1) { + if (tp_dim == 0) { + // Outer dim is sharded, comm+GEMM overlap would need to do all-gather + ret[0] *= tp_size; + } else { + // Inner dim is sharded, comm+GEMM overlap would need to do reduce-scatter + ret[0] /= tp_size; + } + } return ret; } -bool checkGemmShape(const std::vector& expected, const NVTEShape& actual) { +bool checkGemmShape(const std::vector& expected, const NVTEShape& actual, + size_t tp_size = 1, size_t tp_dim = 0) { if (expected.size() != actual.ndim) return false; for (size_t i = 0; i < expected.size(); ++i) { if (expected[i] != actual.data[i]) return false; @@ -149,11 +165,18 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D || B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D || B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D; - + // Get TP info for comm+GEMM overlap + size_t tp_size = 1; + size_t tp_dim = 0; + if (comm_overlap && !bulk_overlap && comm_overlap->with_cublasmp()) { + tp_size = comm_overlap->get_tp_size(); + tp_dim = (comm_type.value() == CommOverlapType::AG) ? 0 : 1; + } // Check tensor dimensions const auto& A_shape = A_tensor.shape(); const auto& B_shape = B_tensor.shape(); - const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); + const auto& D_shape = + detail::getGemmOutputShape(A_shape, transa, B_shape, transb, tp_size, tp_dim); NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension"); NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension"); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index a4571c64e2..b73a555174 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -641,11 +641,27 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::class_, transformer_engine::CommOverlapBase, transformer_engine::CommOverlapCore>(m, "CommOverlap") - .def(py::init &, at::ScalarType, CommOverlapHelper *, int, int, int, - int, int, int, int, bool, bool, bool>(), + .def(py::init([](const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, bool use_cublasmp, + transformer_engine::CommOverlapType comm_type, int num_splits, + int num_max_streams, int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool rs_overlap_first_gemm) { + if (use_cublasmp) { + return std::make_shared(helper, helper->mylocal, tp_size, comm_type, + buffer_shape, buffer_dtype, num_comm_sm, + atomic_gemm); + } + return std::make_shared( + buffer_shape, buffer_dtype, helper, tp_size, num_splits, num_max_streams, + comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, + atomic_gemm, rs_overlap_first_gemm); + }), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), - py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, + py::arg("use_cublasmp") = false, + py::arg("comm_type") = transformer_engine::CommOverlapType::RS, + py::arg("num_splits") = 4, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false) @@ -660,15 +676,28 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::class_, transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>( m, "CommOverlapP2P") - .def(py::init &, at::ScalarType, CommOverlapHelper *, int, - transformer_engine::CommOverlapType, int, int, int, int, int, bool, bool, bool, - bool>(), + .def(py::init([](const std::vector &buffer_shape, at::ScalarType buffer_dtype, + CommOverlapHelper *helper, int tp_size, + transformer_engine::CommOverlapType comm_type, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, int num_comm_sm, + bool set_sm_margin, bool atomic_gemm, bool use_ce, bool aggregate, + bool use_cublasmp) { + if (use_cublasmp) { + return std::make_shared(helper, helper->mylocal, tp_size, comm_type, + buffer_shape, buffer_dtype, num_comm_sm, + atomic_gemm); + } + return std::make_shared(buffer_shape, buffer_dtype, helper, tp_size, + comm_type, num_max_streams, comm_cga_size, + gemm_priority, comm_priority, num_comm_sm, + set_sm_margin, atomic_gemm, use_ce, aggregate); + }), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"), py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, - py::arg("use_ce") = true, py::arg("aggregate") = false) + py::arg("use_ce") = true, py::arg("aggregate") = false, py::arg("use_cublasmp") = false) .def("copy_into_buffer", static_cast( &CommOverlapP2P::copy_into_buffer), @@ -676,4 +705,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) .def("get_communication_stream", &CommOverlapP2P::get_communication_stream); -} +} // NOLINT(readability/fn_size) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 746177ec78..779e141a11 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -61,17 +61,35 @@ from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled -__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"] +__all__ = [ + "initialize_ub", + "destroy_ub", + "is_ub_initialized", + "using_cublasmp_backend", + "UserBufferQuantizationMode", +] _2X_ACC_FPROP = False _2X_ACC_DGRAD = True _2X_ACC_WGRAD = True _dummy_wgrads = {} _ub_communicators = None +_ub_initialized = False +_ub_with_cublasmp = False _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None layers_atomic_ring_exchange = [] +def is_ub_initialized() -> bool: + """Whether the Userbuffers communicators have been initialized.""" + return _ub_initialized + + +def using_cublasmp_backend() -> bool: + """Whether the active comm+GEMM overlap backend is cuBLASMp.""" + return _ub_initialized and _ub_with_cublasmp + + class UserBufferQuantizationMode(Enum): """ UserBufferQuantizationMode is an enum that represents the quantization mode of the UserBuffer. @@ -106,6 +124,7 @@ def initialize_ub( dtype: torch.dtype = torch.bfloat16, ub_cfgs: Optional[Union[dict, List[dict]]] = None, bootstrap_backend: Union[str, torch.distributed.Backend] = None, + with_cublasmp: bool = False, ) -> None: r""" Initialize the Userbuffers communicator for overlapping tensor-parallel communications with @@ -158,6 +177,10 @@ def initialize_ub( not available. Setting ``NVTE_UB_WITH_MPI=1`` when building TE overrides this option and always initializes Userbuffers with direct MPI calls in C++, which also requires ``MPI_HOME=/path/to/mpi/root`` to be set at compile time. + with_cublasmp : bool = False + Whether to use cuBlasMp for the all-gather and reduce-scatter overlaps. TE must + be compiled with `NVTE_WITH_CUBLASMP=1` for this option to work. + """ if not tex.device_supports_multicast(): if not bool(int(os.getenv("UB_SKIPMC", "0"))): @@ -198,10 +221,11 @@ def initialize_ub( f"quantization configurations ({len(quantization_modes)})" ) - global _ub_communicators + global _ub_communicators, _ub_with_cublasmp if _ub_communicators is not None: raise RuntimeError("UB communicators are already initialized.") _ub_communicators = {} + _ub_with_cublasmp = with_cublasmp if tex.ubuf_built_with_mpi(): # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force @@ -259,7 +283,7 @@ def initialize_ub( local_rank = world_rank tp_domain_ranks = list(range(world_size)) - helper = tex.CommOverlapHelper(world_group) + helper = tex.CommOverlapHelper(world_group, world_group) if world_rank == 0: print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True) @@ -283,6 +307,7 @@ def initialize_ub( ] layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"] + # Default overlap methods for layers methods = { "ring_exchange": [ @@ -349,6 +374,11 @@ def add_ub( gemm_priority: int = 0, pipeline_rs_overlap_first_gemm: bool = False, ) -> None: + if with_cublasmp and method in ("bulk", "external"): + raise ValueError( + f"At {name}, cuBLASMp does not support `{method}` overlap method. " + "Please select a different method or set with_cublasmp=False." + ) if atomic_gemm: warnings.warn( "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." @@ -360,7 +390,7 @@ def add_ub( ) if method in ("bulk", "external"): warnings.warn( - f"At {name}, atoimic GEMM not is supported for a bulk overlap." + f"At {name}, atomic GEMM not is supported for a bulk overlap." "Defaulting to `atomic_gemm=False`." ) atomic_gemm = 0 @@ -406,13 +436,15 @@ def add_ub( if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf) else dtype ) + comm_type = tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG if method == "ring_exchange": ub_obj = tex.CommOverlapP2P( shape, # Communication buffer shape buffer_dtype, # Communication buffer data type helper, # Helper for torch.distributed callbacks during bootstrapping - tp_size, # Tensor-parallel group size (may be different than local_size) - tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG, + tp_size, # Tensor-parallel group size (may differ from local_size) + comm_type, + use_cublasmp=with_cublasmp, num_max_streams=_NUM_MAX_UB_STREAMS, comm_cga_size=cga_size, num_comm_sm=num_sm, @@ -428,7 +460,9 @@ def add_ub( shape, # Communication buffer shape buffer_dtype, # Communication buffer data type helper, # Helper for torch.distributed callbacks during bootstrapping - tp_size, # Tensor-parallel group size (may be different than local_size) + tp_size, # Tensor-parallel group size (may differ from local_size) + use_cublasmp=with_cublasmp, + comm_type=comm_type, num_splits=num_splits, num_max_streams=_NUM_MAX_UB_STREAMS, comm_cga_size=cga_size, @@ -463,9 +497,30 @@ def add_ub( new_method = user_ub_cfg[name]["method"] methods[new_method].append(name) - for name in ( - methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"] - ): + # Adjust defaults to account for the fact that cuBLASMp does not support + # bulk or external overlaps + if with_cublasmp: + warnings.warn( + "cuBLASMp does not support bulk or external overlaps. " + "'qkv_dgrad' and 'fc1_dgrad' GEMMs will be configured with 'ring_exchange'" + "overlap unless user configuration specifies otherwise. Bulk overlaps for the " + "corresponding 'qkv_wgrad' and 'fc1_wgrad' GEMMs will be disabled." + ) + methods["bulk"] = [] + methods["external"] = [] + external_gemm_to_overlap.clear() + + for name in dgrad_reduce_scatter_overlap: + wgrad_name = name.replace("dgrad", "wgrad") + if name not in layers_reduce_scatter_overlap: + layers_reduce_scatter_overlap.append(name) + if wgrad_name in layers_reduce_scatter_overlap: + layers_reduce_scatter_overlap.remove(wgrad_name) + if name not in methods["ring_exchange"] and name not in methods["pipeline"]: + methods["ring_exchange"].append(name) + + configured_methods = ["ring_exchange", "pipeline", "bulk", "external"] + for name in (m for k in configured_methods for m in methods[k]): ub_cfg = get_default_config(name) if user_ub_cfg is not None and name in user_ub_cfg: fp8_buf = (name in layers_all_gather_overlap) or ( @@ -475,6 +530,9 @@ def add_ub( ub_cfg["fp8_buf"] = fp8_buf add_ub(name, quantization_mode, **ub_cfg) + global _ub_initialized + _ub_initialized = True + def get_ub(name: str, use_fp8: bool): """Get userbuffer communicator corresponding to give key.""" @@ -482,7 +540,7 @@ def get_ub(name: str, use_fp8: bool): # So favour simplicity until the correct design becomes clear. # This is mainly an internal API so we don't need to worry about future changes key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE) - if _ub_communicators is None: + if not _ub_initialized or _ub_communicators is None: raise RuntimeError("UB manager is not initialized.") if key not in _ub_communicators: raise KeyError(f"UB for {name} with use_fp8={use_fp8} is not registered.") @@ -491,8 +549,10 @@ def get_ub(name: str, use_fp8: bool): def destroy_ub(): """Destroy all allocated userbuffer communicators.""" - global _ub_communicators + global _ub_communicators, _ub_with_cublasmp, _ub_initialized _ub_communicators = None + _ub_with_cublasmp = False + _ub_initialized = False global layers_atomic_ring_exchange layers_atomic_ring_exchange = [] @@ -515,6 +575,9 @@ def fill_userbuffers_buffer_for_all_gather( tensor's metadata, e.g. scaling factors. """ + # cuBlasMp already handles its own buffer filling and quantization factors + if comm.with_cublasmp(): + return local_tensor, local_tensor # Tensor dimensions local_shape = local_tensor.size() diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 8c88f3ee82..7fc96d4779 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -21,6 +21,8 @@ from .base import ( fill_userbuffers_buffer_for_all_gather, get_ub, + is_ub_initialized, + using_cublasmp_backend, quantize_weight, TransformerEngineBaseModule, get_dummy_wgrad, @@ -397,7 +399,10 @@ def forward( clear_tensor_data(ln_out, ln_out_total) ln_out = ln_out_total = None elif with_input_all_gather and not return_layernorm_output_gathered: - clear_tensor_data(ln_out_total) + # ln_out_total aliases ln_out for the cuBLASMp backend; skip the + # deallocation to avoid corrupting the backward-saved tensor. + if ln_out_total is not ln_out: + clear_tensor_data(ln_out_total) ln_out_total = None # ------------------------------------------------------ @@ -406,7 +411,9 @@ def forward( # ------------------------------------------------------ out = None if ub_overlap_rs_fprop: - out = reduce_scatter_out + # cuBLASMp writes the reduce-scattered output directly into the + # GEMM output tensor; Userbuffers writes it into the extra-output buffer. + out = gemm_out if ub_obj is not None and ub_obj.with_cublasmp() else reduce_scatter_out elif parallel_mode == "row" and tp_size > 1: nvtx_range_push(f"{nvtx_label}.row_parallel_comm") out = gemm_out @@ -838,7 +845,13 @@ def backward( dgrad = None dgrad_work = None if ctx.ub_overlap_rs_dgrad: - dgrad = reduce_scatter_out + # cuBLASMp writes the reduce-scattered dgrad directly into the + # GEMM output tensor; Userbuffers uses the extra-output buffer. + dgrad = ( + gemm_out + if ub_obj_dgrad is not None and ub_obj_dgrad.with_cublasmp() + else reduce_scatter_out + ) elif ctx.ub_bulk_wgrad: dgrad = ub_obj_wgrad.get_buffer(local_chunk=True) elif ctx.parallel_mode == "column" and ctx.tp_size > 1: @@ -860,6 +873,28 @@ def backward( # Grad input tensor has been computed... # -------------------------------------------------- + # cuBLASMp's AG+GEMM consumes the gathered grad_output inline and + # does not preserve it for wgrad. Userbuffers leaves the gathered + # tensor in its persistent buffer; cuBLASMp does not, so we gather + # here. Route through the same FP8-aware all-gather as the + # non-overlap path in + # ``TransformerEngineBaseModule.grad_output_preprocess`` by passing + # the grad_output quantizer. Columnwise data needed for wgrad is + # produced by ``update_usage(columnwise_usage=True)`` further below. + if ( + ctx.requires_wgrad + and ctx.ub_overlap_ag + and ctx.ub_obj_gradout is not None + and ctx.ub_obj_gradout.with_cublasmp() + ): + if ctx.grad_output_quantizer is not None: + ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False) + grad_output, _ = gather_along_first_dim( + grad_output, + ctx.tp_group, + quantizer=ctx.grad_output_quantizer, + ) + # -------------------------------------------------- # Compute grad weight # -------------------------------------------------- @@ -1303,6 +1338,8 @@ def __init__( self.ub_overlap_rs_dgrad = ( ub_overlap_rs_dgrad and self.sequence_parallel and self.parallel_mode == "column" ) + # Bulk overlaps require the Userbuffers backend; the cuBLASMp backend + # falls back to async NCCL ops via torch.distributed. self.ub_bulk_wgrad = ( ub_bulk_wgrad and self.sequence_parallel @@ -1333,9 +1370,23 @@ def __init__( self.ub_overlap_ag_dgrad, ] ): + assert is_ub_initialized(), "initialize_ub() must be called before layer construction." assert ub_name is not None, "Userbuffer name [string] is not set." self.ub_name = ub_name + if using_cublasmp_backend(): + if self.ub_bulk_dgrad: + warnings.warn( + f"cuBLASMp backend does not support bulk overlaps for '{self.ub_name}_dgrad' " + f"and '{self.ub_name}_wgrad' GEMMs. Falling back on DGRAD+RS overlap for " + f"'{self.ub_name}_dgrad' GEMM with no bulk overlap for '{self.ub_name}_wgrad' " + "GEMM. In order to enable bulk overlaps for these GEMMs, set " + "`with_cublasmp=False` when calling `initialize_ub()`." + ) + self.ub_overlap_rs_dgrad = self.ub_overlap_rs_dgrad or self.ub_bulk_dgrad + self.ub_bulk_dgrad = False + self.ub_bulk_wgrad = False + if self.symmetric_ar_type is not None: assert torch_version() >= ( 2, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 46918ff0f1..2c0149717f 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -23,6 +23,8 @@ fill_userbuffers_buffer_for_all_gather, _ub_communicators, get_ub, + is_ub_initialized, + using_cublasmp_backend, quantize_weight, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -595,8 +597,12 @@ def _forward( # first part of if statement means that we only clear ln_out_total if # 1) checkpointing and not recomputing (in the forward stage, not bwd recompute stage) # 2) not checkpointing and grad disabled - if ((checkpoint and not is_recomputation) or not is_grad_enabled) and ( - ln_out_total is not ln_out_return + # The `is not ln_out` guard avoids clearing the bwd-saved tensor when + # ln_out_total aliases ln_out (cuBLASMp AG-fprop path). + if ( + ((checkpoint and not is_recomputation) or not is_grad_enabled) + and ln_out_total is not ln_out_return + and ln_out_total is not ln_out ): clear_tensor_data(ln_out_total) @@ -695,7 +701,13 @@ def _forward( # Note: Perform tensor-parallel communication if needed fc2_out = None if ub_overlap_rs: - fc2_out = reduce_scatter_out + # cuBLASMp writes the reduce-scattered output directly into the + # GEMM output tensor; Userbuffers writes it into the extra-output buffer. + fc2_out = ( + gemm_out + if ub_obj_fc2out is not None and ub_obj_fc2out.with_cublasmp() + else reduce_scatter_out + ) elif set_parallel_mode and sequence_parallel: fc2_out, _ = reduce_scatter_along_first_dim(gemm_out, tp_group) elif set_parallel_mode and tensor_parallel: @@ -1255,6 +1267,29 @@ def backward( # Finished FC2 DGRAD... # -------------------------------------------------- + # cuBLASMp's AG+GEMM consumes the gathered grad_output inline and + # does not preserve it for fc2_wgrad. Userbuffers leaves the + # gathered tensor in its persistent buffer; cuBLASMp does not, so + # we gather here. Route through the same FP8-aware all-gather as + # the non-overlap path in + # ``TransformerEngineBaseModule.grad_output_preprocess`` by passing + # the grad_output quantizer. Columnwise data needed for fc2_wgrad + # is produced by ``update_usage(columnwise_usage=True)`` further + # below. + if ( + ctx.fc2_weight_requires_grad + and ctx.ub_overlap_ag + and ctx.ub_obj_gradout is not None + and ctx.ub_obj_gradout.with_cublasmp() + ): + if ctx.fc2_grad_output_quantizer is not None: + ctx.fc2_grad_output_quantizer.set_usage(rowwise=True, columnwise=False) + grad_output, _ = gather_along_first_dim( + grad_output, + ctx.tp_group, + quantizer=ctx.fc2_grad_output_quantizer, + ) + # -------------------------------------------------- # FC2 WGRAD # -------------------------------------------------- @@ -1524,7 +1559,13 @@ def fc2_wgrad_gemm( fc1_dgrad = None fc1_dgrad_work = None if ctx.ub_overlap_rs_dgrad: - fc1_dgrad = reduce_scatter_out + # cuBLASMp writes the reduce-scattered dgrad directly into the + # GEMM output tensor; Userbuffers uses the extra-output buffer. + fc1_dgrad = ( + gemm_out + if ub_obj_fc1_dgrad is not None and ub_obj_fc1_dgrad.with_cublasmp() + else reduce_scatter_out + ) elif ctx.ub_bulk_wgrad: fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(local_chunk=True) elif ctx.set_parallel_mode and not ctx.ub_bulk_wgrad: @@ -1985,6 +2026,29 @@ def __init__( ub_bulk_dgrad and self.sequence_parallel and not self.ub_overlap_rs_dgrad ) + if any( + [ + self.ub_overlap_ag, + self.ub_overlap_rs, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + ] + ): + assert is_ub_initialized(), "initialize_ub() must be called before layer construction." + + if using_cublasmp_backend(): + if self.ub_bulk_dgrad: + warnings.warn( + "cuBLASMp backend does not support bulk overlaps for 'fc1_dgrad' and " + "'fc1_wgrad' GEMMs. Falling back on DGRAD+RS overlap for 'fc1_dgrad' GEMM with " + "no bulk overlap for 'fc1_wgrad' GEMM. In order to enable bulk overlaps for " + "these GEMMs, set `with_cublasmp=False` when calling `initialize_ub()`." + ) + self.ub_overlap_rs_dgrad = self.ub_overlap_rs_dgrad or self.ub_bulk_dgrad + self.ub_bulk_dgrad = False + self.ub_bulk_wgrad = False + if self.symmetric_ar_type is not None: assert torch_version() >= ( 2, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index dcbb9eaf93..6c2d98d160 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -21,6 +21,8 @@ fill_userbuffers_buffer_for_all_gather, get_dummy_wgrad, get_ub, + is_ub_initialized, + using_cublasmp_backend, quantize_weight, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -504,7 +506,9 @@ def _linear_forward_impl( # ------------------------------------------------------ out = None if ub_overlap_rs_fprop: - out = reduce_scatter_out + # cuBLASMp writes the reduce-scattered output directly into the GEMM + # output tensor; Userbuffers writes it into the extra-output buffer. + out = gemm_out if ub_obj is not None and ub_obj.with_cublasmp() else reduce_scatter_out elif parallel_mode == "row" and args.tp_size > 1: nvtx_range_push(f"{nvtx_label}.row_parallel_comm") out = gemm_out @@ -518,6 +522,13 @@ def _linear_forward_impl( nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") else: out = gemm_out + + # Restore the input's logical rank (e.g., (seq, batch, hidden)) on the output. + # This is mainly to correct for cuBLASMp comm+GEMM operators that unconditionally + # return a 2D output buffer that ends up incompatible with downstream consumers + # (e.g. ``bias_dropout_add`` residual connections inside ``TransformerLayer``). + out = out.view(-1, *inp.shape[1:-1], out_features) + # ------------------------------------------------------ # Output tensor is ready to return... # ------------------------------------------------------ @@ -1010,7 +1021,13 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. # Prepare grad input tensor # Note: Perform tensor-parallel communication if bwd_args.ub_overlap_rs_dgrad: - dgrad = reduce_scatter_out + # cuBLASMp writes the reduce-scattered dgrad directly into the + # GEMM output tensor; Userbuffers uses the extra-output buffer. + dgrad = ( + gemm_out + if ub_obj_dgrad is not None and ub_obj_dgrad.with_cublasmp() + else reduce_scatter_out + ) elif bwd_args.ub_bulk_wgrad: dgrad = ub_obj_wgrad.get_buffer(local_chunk=True) elif bwd_args.parallel_mode == "column" and bwd_args.tp_size > 1: @@ -1032,6 +1049,27 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. # Grad input tensor has been computed... # -------------------------------------------------- + # cuBLASMp's AG+GEMM consumes the gathered grad_output inline and does + # not preserve it for wgrad. Userbuffers leaves the gathered tensor in + # its persistent buffer; cuBLASMp does not, so we gather here. Route + # through the same FP8-aware all-gather as the non-overlap path in + # ``TransformerEngineBaseModule.grad_output_preprocess`` by passing the + # grad_output quantizer. The columnwise data needed for wgrad is then + # produced by ``update_usage(columnwise_usage=True)`` further below. + if ( + bwd_args.requires_wgrad + and bwd_args.ub_overlap_ag + and bwd_args.ub_obj_gradout is not None + and bwd_args.ub_obj_gradout.with_cublasmp() + ): + if grad_output_quantizer is not None: + grad_output_quantizer.set_usage(rowwise=True, columnwise=False) + grad_output, _ = gather_along_first_dim( + grad_output, + bwd_args.tp_group, + quantizer=grad_output_quantizer, + ) + # -------------------------------------------------- # Compute grad weight # -------------------------------------------------- @@ -1494,6 +1532,8 @@ def __init__( self.ub_overlap_rs_dgrad = ( self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_rs_dgrad ) + # Bulk overlaps require the Userbuffers backend; the cuBLASMp backend falls back on + # DGRAD+RS overlap with no bulk overlap for WGRAD self.ub_bulk_dgrad = ( self.parallel_mode == "column" and self.sequence_parallel @@ -1525,9 +1565,23 @@ def __init__( self.ub_bulk_wgrad, ] ): + assert is_ub_initialized(), "initialize_ub() must be called before layer construction." assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized." self.ub_name = ub_name + if using_cublasmp_backend(): + if self.ub_bulk_dgrad: + warnings.warn( + f"cuBLASMp backend does not support bulk overlaps for '{self.ub_name}_dgrad' " + f"and '{self.ub_name}_wgrad' GEMMs. Falling back on DGRAD+RS overlap for " + f"'{self.ub_name}_dgrad' GEMM with no bulk overlap for '{self.ub_name}_wgrad' " + "GEMM. In order to enable bulk overlaps for these GEMMs, set " + "`with_cublasmp=False` when calling `initialize_ub()`." + ) + self.ub_overlap_rs_dgrad = self.ub_overlap_rs_dgrad or self.ub_bulk_dgrad + self.ub_bulk_dgrad = False + self.ub_bulk_wgrad = False + if self.symmetric_ar_type is not None: assert torch_version() >= ( 2,