From ccf7321f71442ad2386399a134866862c92c7f45 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 12 May 2026 16:20:21 -0700 Subject: [PATCH 01/15] Batch CP attention tests via a persistent NCCL pool The existing test path spawns one torchrun per parametrized case, paying NCCL init + CUDA context + Python startup on every call. With ~hundreds of cases the launch overhead dominates wall time and was a primary driver of the L3 timeout that prior batching PRs worked around. This change replaces the per-case subprocess with one long-lived torchrun per (world_size). NCCL is initialized once per pool, when that pool is first used, and reused across cases (it is re-initialized only if the pool is respawned after a worker death or timeout). Pytest sends one JSON request per case over rank-0 stdin; the worker dispatches to run_dpa_with_cp(**kwargs), gathers (ok, error) from every rank, and writes one JSON response on rank-0 stdout. run_attention_with_cp.py is left almost untouched; a new NVTE_CP_POOL_PG=1 env var gates the dist.init_process_group() and dist.destroy_process_group() calls so the function reuses the pool's main PG instead of creating its own. The per-case cp_comm_group (and a2a+p2p sub-groups) are explicitly destroyed at function exit to prevent communicator leakage across cases. The PoolWorker class adds two pieces of error recovery that the prior subprocess-per-case design got for free: a select-based per-call timeout (default 600s, NVTE_CP_POOL_TIMEOUT_SEC) and auto-respawn on worker death or timeout. A test-level exception is reported as an AssertionError and the pool keeps running for the next case. Two pool sizes are needed because cp_comm_type='a2a+p2p' requires world_size=4 and the others use world_size=2; you can't resize an active PG. Pools are spawned lazily so a 2-GPU-only run never pays the 4-GPU init. 
Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp.py | 15 +- .../attention/run_attention_with_cp_pool.py | 103 +++++++++ .../attention/test_attention_with_cp.py | 213 ++++++++++++++---- 3 files changed, 282 insertions(+), 49 deletions(-) create mode 100644 tests/pytorch/attention/run_attention_with_cp_pool.py diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 8dfea644a5..7e7efa50aa 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -226,6 +226,9 @@ def run_dpa_with_cp( # set up distributed group rank = int(os.getenv("RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "1")) + # When NVTE_CP_POOL_PG=1, the pool runner owns the lifecycle of the main + # process group across many cases; here we only reuse it. + _pool_managed_pg = os.getenv("NVTE_CP_POOL_PG", "0") == "1" if dist.is_initialized(): world_size = dist.get_world_size() rank = dist.get_rank() @@ -234,7 +237,8 @@ def run_dpa_with_cp( device = rank % device_count torch.cuda.set_device(device) logging.info(f"[Rank {rank}] Setup: world_size {world_size}") - dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + if not _pool_managed_pg: + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) # set up communication group for CP cp_comm_ranks = range(world_size) @@ -763,7 +767,14 @@ def run_dpa_with_cp( logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") # destroy distribution group - dist.destroy_process_group() + if _pool_managed_pg: + # Pool owns the main PG; only clean up groups created for this case. 
+ dist.destroy_process_group(cp_comm_group) + if cp_comm_type == "a2a+p2p": + for g in cp_comm_sub_groups: + dist.destroy_process_group(g) + else: + dist.destroy_process_group() def main(**kwargs): diff --git a/tests/pytorch/attention/run_attention_with_cp_pool.py b/tests/pytorch/attention/run_attention_with_cp_pool.py new file mode 100644 index 0000000000..73aa84ee42 --- /dev/null +++ b/tests/pytorch/attention/run_attention_with_cp_pool.py @@ -0,0 +1,103 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Persistent worker for batched CP attention tests. + +Launched ONCE per (pytest session, world_size) by torchrun. All ranks init +NCCL, then enter a dispatch loop: + + rank 0: + read one JSON request line from stdin + broadcast it to all ranks + all ranks: + call run_dpa_with_cp(**kwargs) — the same work function the + per-case subprocess design uses, with NVTE_CP_POOL_PG=1 so the + function reuses our PG instead of re-initing it + torch.cuda.empty_cache() + gc.collect() per case + all ranks gather (ok, error_msg) to rank 0 + rank 0: + write one JSON response line to stdout + +Protocol (line-delimited JSON over rank-0 stdio): + request : {"op": "run", "kwargs": {...}} + {"op": "shutdown"} + response: {"ok": true} + {"ok": false, "error": "first failing rank's traceback"} +""" +import gc +import json +import os +import sys +import traceback + +import torch +import torch.distributed as dist + +# Make sibling modules importable when launched directly. 
+sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from run_attention_with_cp import run_dpa_with_cp + + +def _recv_request(rank: int) -> dict: + box = [None] + if rank == 0: + line = sys.stdin.readline() + box[0] = {"op": "shutdown"} if not line else json.loads(line) + dist.broadcast_object_list(box, src=0) + return box[0] + + +def _send_response(rank: int, payload: dict) -> None: + if rank == 0: + sys.stdout.write(json.dumps(payload) + "\n") + sys.stdout.flush() + + +def _run_one(req: dict, rank: int) -> tuple[bool, str]: + op = req["op"] + if op != "run": + return False, f"unknown op: {op}" + try: + run_dpa_with_cp(**req.get("kwargs", {})) + return True, "" + except Exception: + return False, f"[Rank {rank}] {traceback.format_exc()}" + finally: + torch.cuda.empty_cache() + gc.collect() + + +def main() -> None: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + torch.cuda.set_device(rank % torch.cuda.device_count()) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + os.environ["NVTE_CP_POOL_PG"] = "1" + + try: + while True: + req = _recv_request(rank) + if req.get("op") == "shutdown": + break + + ok, msg = _run_one(req, rank) + + gathered: list[tuple[bool, str]] = [None] * world_size # type: ignore[list-item] + dist.gather_object((ok, msg), gathered if rank == 0 else None, dst=0) + + if rank == 0: + all_ok = all(o for o, _ in gathered) + if all_ok: + _send_response(rank, {"ok": True}) + else: + first_err = next(m for o, m in gathered if not o) + _send_response(rank, {"ok": False, "error": first_err}) + finally: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 23d1bfdd85..3ac58559ab 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -2,7 +2,9 @@ # # See LICENSE for license information. 
+import json import os +import select import subprocess import sys import pathlib @@ -24,7 +26,7 @@ _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) -from utils import ModelConfig, get_available_attention_backends, run_distributed +from utils import ModelConfig, get_available_attention_backends pytest_logging_level = logging.getLevelName(logging.root.level) @@ -60,19 +62,133 @@ } -def get_bash_arguments(num_gpus_per_node, **kwargs): - args = [ - "python3", - "-m", - "torch.distributed.launch", - "--nproc-per-node=" + str(num_gpus_per_node), - ] - te_path = os.getenv("TE_PATH", "/opt/transformerengine") - script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py") - args.append(script_path) - for k, v in kwargs.items(): - args.append(f"{k}={v}") - return args +# --- Persistent pool runner ----------------------------------------------- +# +# Each (world_size) is served by one long-lived torchrun running +# run_attention_with_cp_pool.py. We submit one work item per pytest case over +# rank-0 stdin and read one JSON response from rank-0 stdout. Replaces +# the per-case torchrun launch path; init/destroy NCCL once per pool, not +# once per case. +# +# Why two pool sizes: cp_comm_type="a2a+p2p" needs world_size=4; everything +# else uses world_size=2. We can't resize an active PG, so we keep one pool +# per world_size and route each case to the right one. Pools are spawned +# lazily on first use so a session that only exercises 2-GPU cases never +# pays the 4-GPU init cost. 
+ +POOL_SUBMIT_TIMEOUT_SEC = float(os.getenv("NVTE_CP_POOL_TIMEOUT_SEC", "600")) + + +class PoolWorker: + def __init__(self, world_size: int): + self.world_size = world_size + self.proc: subprocess.Popen | None = None + + def _spawn(self) -> None: + te_path = os.getenv("TE_PATH", "/opt/transformerengine") + worker = os.path.join( + te_path, "tests/pytorch/attention/run_attention_with_cp_pool.py" + ) + cmd = [ + sys.executable, + "-m", + "torch.distributed.run", + f"--nproc-per-node={self.world_size}", + "--standalone", # picks a free rendezvous port + worker, + ] + self.proc = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=sys.stderr, + text=True, + bufsize=1, + env={**os.environ, "PYTHONUNBUFFERED": "1"}, + ) + + def _ensure_alive(self) -> None: + if self.proc is None or self.proc.poll() is not None: + self._spawn() + + def _kill(self) -> None: + if self.proc and self.proc.poll() is None: + self.proc.terminate() + try: + self.proc.wait(timeout=5) + except subprocess.TimeoutExpired: + self.proc.kill() + self.proc.wait() + self.proc = None + + def submit(self, kwargs: dict, timeout: float = POOL_SUBMIT_TIMEOUT_SEC) -> None: + self._ensure_alive() + assert self.proc and self.proc.stdin and self.proc.stdout + req = json.dumps({"op": "run", "kwargs": kwargs}) + "\n" + try: + self.proc.stdin.write(req) + self.proc.stdin.flush() + except BrokenPipeError: + self._kill() + raise AssertionError("pool worker died before request could be sent") + + ready, _, _ = select.select([self.proc.stdout], [], [], timeout) + if not ready: + self._kill() + raise AssertionError( + f"pool worker (world_size={self.world_size}) timed out after {timeout}s; " + "pool killed and will be respawned for the next case" + ) + + line = self.proc.stdout.readline() + if not line: + self._kill() + raise AssertionError("pool worker died mid-request") + + resp = json.loads(line) + if not resp["ok"]: + # Test failure; pool itself is still healthy. 
+ raise AssertionError(resp["error"]) + + def shutdown(self) -> None: + if self.proc and self.proc.poll() is None: + try: + self.proc.stdin.write(json.dumps({"op": "shutdown"}) + "\n") + self.proc.stdin.flush() + self.proc.stdin.close() + except BrokenPipeError: + pass + try: + self.proc.wait(timeout=30) + except subprocess.TimeoutExpired: + self._kill() + self.proc = None + + +@pytest.fixture(scope="session") +def cp_pool(): + """Returns a callable: cp_pool(world_size) -> PoolWorker.""" + pools: dict[int, PoolWorker] = {} + + def _get(world_size: int) -> PoolWorker: + if world_size > torch.cuda.device_count(): + pytest.skip( + f"Test requires {world_size} GPUs, but found {torch.cuda.device_count()}" + ) + if world_size not in pools: + pools[world_size] = PoolWorker(world_size) + return pools[world_size] + + yield _get + for p in pools.values(): + p.shutdown() + + +def _submit(pool: PoolWorker, **kwargs) -> None: + # run_dpa_with_cp expects all kwargs as strings (it does e.g. + # `fp8_bwd == "True"`), matching the old argv-based path. Serialize + # everything as strings so we don't accidentally change semantics. 
+ pool.submit({k: str(v) for k, v in kwargs.items()}) dtypes = ["bf16", "fp16"] @@ -91,10 +207,9 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) -def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): +def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 - if num_gpus > torch.cuda.device_count(): - pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + pool = cp_pool(num_gpus) config = model_configs_flash_attn[model] config.context_parallel = True @@ -140,16 +255,14 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): if not flash_attn_supported: pytest.skip("No attention backend available.") - run_distributed( - get_bash_arguments( - num_gpus_per_node=num_gpus, - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend="FlashAttention", - cp_comm_type=cp_comm_type, - log_level=pytest_logging_level, - ), + _submit( + pool, + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FlashAttention", + cp_comm_type=cp_comm_type, + log_level=pytest_logging_level, ) @@ -274,15 +387,23 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) @pytest.mark.parametrize("f16_O", [True, False]) def test_cp_with_fused_attention( - dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O + cp_pool, + dtype, + model, + qkv_format, + cp_comm_type, + fp8_bwd, + fp8_mha, + fp8_dpa, + scaling_mode, + f16_O, ): config = model_configs_fused_attn[model] config.context_parallel = True config.cp_comm_type = cp_comm_type num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 - if num_gpus > 
torch.cuda.device_count(): - pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()} GPUs.") + pool = cp_pool(num_gpus) if get_device_compute_capability() < (9, 0) and qkv_format == "thd": pytest.skip("Only sm90+ architectures support THD format!") @@ -404,21 +525,19 @@ def test_cp_with_fused_attention( if not fused_attn_supported: pytest.skip("No attention backend available.") - run_distributed( - get_bash_arguments( - num_gpus_per_node=num_gpus, - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend="FusedAttention", - cp_comm_type=cp_comm_type, - fp8_bwd=fp8_bwd, - fp8_dpa=fp8_dpa, - fp8_mha=fp8_mha, - scaling_mode=scaling_mode, - f16_O=f16_O, - is_training=is_training, - deterministic=_deterministic, - log_level=pytest_logging_level, - ), + _submit( + pool, + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FusedAttention", + cp_comm_type=cp_comm_type, + fp8_bwd=fp8_bwd, + fp8_dpa=fp8_dpa, + fp8_mha=fp8_mha, + scaling_mode=scaling_mode, + f16_O=f16_O, + is_training=is_training, + deterministic=_deterministic, + log_level=pytest_logging_level, ) From 59609ac982065a0c303b966931c709aea89fba49 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 12 May 2026 16:45:40 -0700 Subject: [PATCH 02/15] Reset FP8 state and barrier between pool cases Two resilience fixes carried over from the existing batching PR (sudhakars/cp_test_batching_pr) without which the pool will cascade-fail FP8 tests and silently propagate NCCL desync. 1. FP8GlobalStateManager.reset() between cases. FP8 quantizer state (recipe handles, autocast counters) lives in module-level globals. Reusing one Python process across cases otherwise carries that state forward. The prior batching PR landed an explicit fix for the same issue ("Fix FP8 cascade failures") after observing real test failures from this. 2. dist.barrier() after each case. 
If one rank's case errored before its last collective, the others can be stuck waiting on a comm that will never complete. The barrier here surfaces that immediately as a timeout in this case rather than letting the corruption leak into the next case's collectives. Also pops the transient NVTE_* env vars run_dpa_with_cp sets at the top of each call. run_dpa_with_cp already sets them unconditionally so this is defensive, but cheap insurance against future variants that might not. Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp_pool.py | 29 +++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp_pool.py b/tests/pytorch/attention/run_attention_with_cp_pool.py index 73aa84ee42..dd4cfafbe6 100644 --- a/tests/pytorch/attention/run_attention_with_cp_pool.py +++ b/tests/pytorch/attention/run_attention_with_cp_pool.py @@ -39,6 +39,19 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__))) from run_attention_with_cp import run_dpa_with_cp +from transformer_engine.pytorch.quantization import FP8GlobalStateManager + + +# Env vars run_dpa_with_cp re-sets at the top of every call. We pop them +# defensively between cases so a future caller that *doesn't* re-set them +# can't inherit a leftover value from a previous case in the same worker. 
+_TRANSIENT_ENV_KEYS = ( + "NVTE_FP8_DPA_BWD", + "NVTE_DPA_FP8CS_O_in_F16", + "NVTE_FLASH_ATTN", + "NVTE_FUSED_ATTN", + "NVTE_ALLOW_NONDETERMINISTIC_ALGO", +) def _recv_request(rank: int) -> dict: @@ -56,6 +69,15 @@ def _send_response(rank: int, payload: dict) -> None: sys.stdout.flush() +def _reset_between_cases() -> None: + """Drop state that would otherwise cascade across cases.""" + FP8GlobalStateManager.reset() + for env_key in _TRANSIENT_ENV_KEYS: + os.environ.pop(env_key, None) + torch.cuda.empty_cache() + gc.collect() + + def _run_one(req: dict, rank: int) -> tuple[bool, str]: op = req["op"] if op != "run": @@ -66,8 +88,7 @@ def _run_one(req: dict, rank: int) -> tuple[bool, str]: except Exception: return False, f"[Rank {rank}] {traceback.format_exc()}" finally: - torch.cuda.empty_cache() - gc.collect() + _reset_between_cases() def main() -> None: @@ -88,6 +109,10 @@ def main() -> None: gathered: list[tuple[bool, str]] = [None] * world_size # type: ignore[list-item] dist.gather_object((ok, msg), gathered if rank == 0 else None, dst=0) + # Surface a wedged communicator here, before the next case's + # collectives can inherit the corruption. + dist.barrier() + if rank == 0: all_ok = all(o for o, _ in gathered) if all_ok: From 73e8cef6ac8146d04711bb18d3f98850f2116818 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 12 May 2026 16:47:42 -0700 Subject: [PATCH 03/15] Deep-copy ModelConfig in run_dpa_with_cp The model_configs_{flash,fused}_attn dicts are module-level and shared across pool cases. The THD branch below rewrites config.attn_mask_type in place (causal -> padding_causal, no_mask -> padding). With the persistent-pool runner, the next case looking up the same model key gets the mutated config and fails the "causal or no_mask only" assert. Caught at benchmark time on cp_2_0 + thd, identical to the cascade the existing batching PR (sudhakars/cp_test_batching_pr) hit and fixed the same way in commit 6355f620. 
Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/run_attention_with_cp.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 7e7efa50aa..04abde1064 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. +import copy import os import sys import logging @@ -209,10 +210,13 @@ def run_dpa_with_cp( os.environ["NVTE_FUSED_ATTN"] = "0" if kernel_backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" - config = model_configs_flash_attn[model] + # Deep-copy: the module-level dict is shared across pool cases; the + # THD branch below rewrites attn_mask_type in place, which would + # otherwise leak into subsequent cases reusing the same model key. + config = copy.deepcopy(model_configs_flash_attn[model]) if kernel_backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - config = model_configs_fused_attn[model] + config = copy.deepcopy(model_configs_fused_attn[model]) assert config.attn_mask_type in [ "causal", "no_mask", From 311137ce9c62c087376616f2df3c8d4e4ce706cb Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Wed, 13 May 2026 22:13:01 -0700 Subject: [PATCH 04/15] Skip deterministic configs incompatible with FusedAttention Mirrors the two pre-emptive skips on the PR-batching branch: * non-vanilla softmax with FusedAttention is not deterministic * post_scale_bias with requires_grad is not deterministic Without these skips, the corresponding configs propagate into the pool worker under NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 and fail inside run_dpa_with_cp instead of being marked SKIPPED. 
Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention_with_cp.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 3ac58559ab..e042deb36c 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -525,6 +525,11 @@ def test_cp_with_fused_attention( if not fused_attn_supported: pytest.skip("No attention backend available.") + if _deterministic and config.softmax_type != "vanilla": + pytest.skip("Deterministic mode does not support non-vanilla softmax with FusedAttention") + if _deterministic and config.attn_bias_type == "post_scale_bias" and is_training: + pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") + _submit( pool, dtype=dtype, From 49878d6e03949b82406633acb543373fe0f65c13 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Wed, 13 May 2026 22:27:09 -0700 Subject: [PATCH 05/15] Reseed RNG between pool cases; reset before, not after The pool worker reused RNG state across cases, which produced small numerical drift on some non-FP8 fused-attention configs (cp_1_0 + thd/p2p, cp_1_0 + sbhd/all_gather) compared to the single-shot worker. Matches the per-case startup of the single-shot worker: torch.manual_seed(1234) + torch.cuda.manual_seed(1234) at the start of every case, alongside the existing FP8 / env / cache resets. Moved the reset call from the post-case finally block to the start of _run_one so the first case is also seeded consistently with subsequent cases. Otherwise the first case would inherit the process-default RNG and only the second-and-later cases would be deterministic. Validated locally: 38 passed, 0 failed (was 36 passed, 2 failed). 
Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp_pool.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp_pool.py b/tests/pytorch/attention/run_attention_with_cp_pool.py index dd4cfafbe6..4136fd5b8d 100644 --- a/tests/pytorch/attention/run_attention_with_cp_pool.py +++ b/tests/pytorch/attention/run_attention_with_cp_pool.py @@ -70,7 +70,14 @@ def _send_response(rank: int, payload: dict) -> None: def _reset_between_cases() -> None: - """Drop state that would otherwise cascade across cases.""" + """Drop state that would otherwise cascade across cases. + + Matches the per-case startup of the single-shot worker (``_run_single_config`` + on the per-case-subprocess branch): identical RNG seed at the start of every + case, FP8 state cleared, transient env vars cleared, allocator clean. + """ + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) FP8GlobalStateManager.reset() for env_key in _TRANSIENT_ENV_KEYS: os.environ.pop(env_key, None) @@ -82,13 +89,14 @@ def _run_one(req: dict, rank: int) -> tuple[bool, str]: op = req["op"] if op != "run": return False, f"unknown op: {op}" + # Reset BEFORE the case so the first case also starts from a known RNG seed + # and clean FP8 state — same as the single-shot worker's per-process startup. 
+ _reset_between_cases() try: run_dpa_with_cp(**req.get("kwargs", {})) return True, "" except Exception: return False, f"[Rank {rank}] {traceback.format_exc()}" - finally: - _reset_between_cases() def main() -> None: From 385e96624010f765ebe23e7c0b5dea44c11f9d76 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 18:24:18 +0000 Subject: [PATCH 06/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention_with_cp.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index e042deb36c..03d61da79e 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -86,9 +86,7 @@ def __init__(self, world_size: int): def _spawn(self) -> None: te_path = os.getenv("TE_PATH", "/opt/transformerengine") - worker = os.path.join( - te_path, "tests/pytorch/attention/run_attention_with_cp_pool.py" - ) + worker = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp_pool.py") cmd = [ sys.executable, "-m", @@ -172,9 +170,7 @@ def cp_pool(): def _get(world_size: int) -> PoolWorker: if world_size > torch.cuda.device_count(): - pytest.skip( - f"Test requires {world_size} GPUs, but found {torch.cuda.device_count()}" - ) + pytest.skip(f"Test requires {world_size} GPUs, but found {torch.cuda.device_count()}") if world_size not in pools: pools[world_size] = PoolWorker(world_size) return pools[world_size] From 86b334bb290793807b696c4fa46ba7f322c99352 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 14 May 2026 12:48:57 -0700 Subject: [PATCH 07/15] Robustify pool: capture worker stderr, tighten timeout, add timing knob MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three changes that 
bring the pool's failure semantics on par with the per-batch torchrun approach in PR #2965 and remove a couple of footguns: 1. Capture pool-worker stderr into a ring buffer and attach the tail to crash-path AssertionErrors. Equivalent in spirit to PR #2965's run_distributed() — CI JUnit XML now shows the actual cause (NCCL error, Python traceback, OOM) inline with the failing test, instead of just "pool worker died mid-request" / "timed out". A daemon drainer thread reads stderr line-by-line into a deque(maxlen=200) and also echoes to sys.stderr so pytest's per-test capture still gets every line. Maximum buffered footprint ~40 KB. 2. Tighten POOL_SUBMIT_TIMEOUT_SEC default 600 -> 90. On H100 the slowest observed per-case wall is ~15 s (p99 also 15 s, p50 ~5 s). 90 s gives ~6x headroom over the worst observed case while still detecting a genuine hang within ~1.5 min instead of ~10 min. Env var still overrides for slower machines or expanded test matrices. 3. Optional per-case wall-time logging (NVTE_CP_POOL_TIMING=1) prints "[POOL-TIMING] case_idx=N world_size=W wall_s=X.XXX ok=B" to stderr on rank 0 only. Grep-friendly; lets future tuning recalibrate the timeout against the observed distribution. Off by default so normal runs stay quiet. Validated: 38 passed / 0 failed in 248 s on H100, test_essential=True, with no perf regression vs the un-patched 256 s. 
Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp_pool.py | 24 ++++++- .../attention/test_attention_with_cp.py | 67 ++++++++++++++++--- 2 files changed, 80 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp_pool.py b/tests/pytorch/attention/run_attention_with_cp_pool.py index 4136fd5b8d..204cfde522 100644 --- a/tests/pytorch/attention/run_attention_with_cp_pool.py +++ b/tests/pytorch/attention/run_attention_with_cp_pool.py @@ -30,6 +30,7 @@ import json import os import sys +import time import traceback import torch @@ -85,18 +86,37 @@ def _reset_between_cases() -> None: gc.collect() +_case_counter = 0 + + def _run_one(req: dict, rank: int) -> tuple[bool, str]: + global _case_counter op = req["op"] if op != "run": return False, f"unknown op: {op}" # Reset BEFORE the case so the first case also starts from a known RNG seed # and clean FP8 state — same as the single-shot worker's per-process startup. _reset_between_cases() + t0 = time.monotonic() + ok = True + err = "" try: run_dpa_with_cp(**req.get("kwargs", {})) - return True, "" except Exception: - return False, f"[Rank {rank}] {traceback.format_exc()}" + ok = False + err = f"[Rank {rank}] {traceback.format_exc()}" + wall = time.monotonic() - t0 + # Per-case wall time on rank 0, opt-in via NVTE_CP_POOL_TIMING=1. + # Used to tune POOL_SUBMIT_TIMEOUT_SEC against the observed distribution. 
+ if rank == 0 and int(os.environ.get("NVTE_CP_POOL_TIMING", "0")): + _case_counter += 1 + sys.stderr.write( + f"[POOL-TIMING] case_idx={_case_counter} " + f"world_size={int(os.environ.get('WORLD_SIZE', 0))} " + f"wall_s={wall:.3f} ok={ok}\n" + ) + sys.stderr.flush() + return ok, err def main() -> None: diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 03d61da79e..4f22bceac3 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -7,9 +7,11 @@ import select import subprocess import sys +import threading import pathlib import logging import copy +from collections import deque import pytest import torch from transformer_engine.pytorch import ( @@ -76,13 +78,26 @@ # lazily on first use so a session that only exercises 2-GPU cases never # pays the 4-GPU init cost. -POOL_SUBMIT_TIMEOUT_SEC = float(os.getenv("NVTE_CP_POOL_TIMEOUT_SEC", "600")) +# Per-case wall is ~5 s p50 / ~15 s max on H100 (test_essential=True). +# 90 s gives ~6× headroom over the slowest observed case while still detecting +# a genuine hang within ~1.5 min instead of ~10 min. Override with the env var +# if a slower machine or expanded test matrix needs more room. +POOL_SUBMIT_TIMEOUT_SEC = float(os.getenv("NVTE_CP_POOL_TIMEOUT_SEC", "90")) class PoolWorker: + # Crash-path AssertionErrors include the tail of the worker's stderr so CI + # JUnit XML shows the actual failure cause (NCCL/CUDA messages, Python + # traceback) inline with the failing test, not just "pool worker died". + # Equivalent in spirit to PR #2965's run_distributed() stderr capture. 
+ _STDERR_BUFFER_LINES = 200 # ring cap (~40 KB ceiling) + _STDERR_TAIL_CHARS = 4000 # how much to attach to the AssertionError + def __init__(self, world_size: int): self.world_size = world_size self.proc: subprocess.Popen | None = None + self._stderr_buf: deque[str] = deque(maxlen=self._STDERR_BUFFER_LINES) + self._stderr_thread: threading.Thread | None = None def _spawn(self) -> None: te_path = os.getenv("TE_PATH", "/opt/transformerengine") @@ -95,15 +110,42 @@ def _spawn(self) -> None: "--standalone", # picks a free rendezvous port worker, ] + # stderr=PIPE so we can capture the tail for crash-path AssertionErrors; + # a daemon drainer thread also echoes each line to sys.stderr so pytest's + # per-test stderr capture still works. self.proc = subprocess.Popen( cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=sys.stderr, + stderr=subprocess.PIPE, text=True, bufsize=1, env={**os.environ, "PYTHONUNBUFFERED": "1"}, ) + self._stderr_buf.clear() + self._stderr_thread = threading.Thread( + target=self._drain_stderr, daemon=True + ) + self._stderr_thread.start() + + def _drain_stderr(self) -> None: + proc = self.proc + if proc is None or proc.stderr is None: + return + for line in iter(proc.stderr.readline, ""): + self._stderr_buf.append(line) + sys.stderr.write(line) + sys.stderr.flush() + + def _stderr_tail(self) -> str: + text = "".join(self._stderr_buf) + return text[-self._STDERR_TAIL_CHARS:] if len(text) > self._STDERR_TAIL_CHARS else text + + def _diag(self, msg: str) -> str: + tail = self._stderr_tail() + if not tail.strip(): + return msg + return f"{msg}\n\n--- pool worker stderr (tail) ---\n{tail}" def _ensure_alive(self) -> None: if self.proc is None or self.proc.poll() is not None: @@ -117,7 +159,9 @@ def _kill(self) -> None: except subprocess.TimeoutExpired: self.proc.kill() self.proc.wait() + # Drainer thread exits on its own when the pipe closes. 
self.proc = None + self._stderr_thread = None def submit(self, kwargs: dict, timeout: float = POOL_SUBMIT_TIMEOUT_SEC) -> None: self._ensure_alive() @@ -127,25 +171,29 @@ def submit(self, kwargs: dict, timeout: float = POOL_SUBMIT_TIMEOUT_SEC) -> None self.proc.stdin.write(req) self.proc.stdin.flush() except BrokenPipeError: + msg = self._diag("pool worker died before request could be sent") self._kill() - raise AssertionError("pool worker died before request could be sent") + raise AssertionError(msg) ready, _, _ = select.select([self.proc.stdout], [], [], timeout) if not ready: - self._kill() - raise AssertionError( - f"pool worker (world_size={self.world_size}) timed out after {timeout}s; " - "pool killed and will be respawned for the next case" + msg = self._diag( + f"pool worker (world_size={self.world_size}) timed out after " + f"{timeout}s; pool killed and will be respawned for the next case" ) + self._kill() + raise AssertionError(msg) line = self.proc.stdout.readline() if not line: + msg = self._diag("pool worker died mid-request") self._kill() - raise AssertionError("pool worker died mid-request") + raise AssertionError(msg) resp = json.loads(line) if not resp["ok"]: - # Test failure; pool itself is still healthy. + # gather_object already carries the full per-rank traceback in + # resp["error"]; no need to also attach stderr (would duplicate). 
raise AssertionError(resp["error"]) def shutdown(self) -> None: @@ -161,6 +209,7 @@ def shutdown(self) -> None: except subprocess.TimeoutExpired: self._kill() self.proc = None + self._stderr_thread = None @pytest.fixture(scope="session") From ae5298c4760323a377755baaa14ee9deedf584d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 19:50:11 +0000 Subject: [PATCH 08/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/attention/test_attention_with_cp.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 4f22bceac3..2b5c3507eb 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -90,8 +90,8 @@ class PoolWorker: # JUnit XML shows the actual failure cause (NCCL/CUDA messages, Python # traceback) inline with the failing test, not just "pool worker died". # Equivalent in spirit to PR #2965's run_distributed() stderr capture. 
- _STDERR_BUFFER_LINES = 200 # ring cap (~40 KB ceiling) - _STDERR_TAIL_CHARS = 4000 # how much to attach to the AssertionError + _STDERR_BUFFER_LINES = 200 # ring cap (~40 KB ceiling) + _STDERR_TAIL_CHARS = 4000 # how much to attach to the AssertionError def __init__(self, world_size: int): self.world_size = world_size @@ -123,9 +123,7 @@ def _spawn(self) -> None: env={**os.environ, "PYTHONUNBUFFERED": "1"}, ) self._stderr_buf.clear() - self._stderr_thread = threading.Thread( - target=self._drain_stderr, daemon=True - ) + self._stderr_thread = threading.Thread(target=self._drain_stderr, daemon=True) self._stderr_thread.start() def _drain_stderr(self) -> None: @@ -139,7 +137,7 @@ def _drain_stderr(self) -> None: def _stderr_tail(self) -> str: text = "".join(self._stderr_buf) - return text[-self._STDERR_TAIL_CHARS:] if len(text) > self._STDERR_TAIL_CHARS else text + return text[-self._STDERR_TAIL_CHARS :] if len(text) > self._STDERR_TAIL_CHARS else text def _diag(self, msg: str) -> str: tail = self._stderr_tail() From e162a9ec4e5c579155974c395443c48c43b34dcd Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 14 May 2026 13:21:34 -0700 Subject: [PATCH 09/15] Address PR review: NCCL leak, stdout protocol, Windows note MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three fixes responding to https://github.com/NVIDIA/TransformerEngine/pull/2993 review comments: P1: NCCL communicator leak on exception (run_attention_with_cp.py) run_dpa_with_cp() created cp_comm_group (and optionally cp_comm_sub_groups) near the top, but the destroy_process_group() calls ran only on the success path at the end of the function. Any exception in between (tensor assertion, OOM, NCCL error) skipped the cleanup, leaking communicators in pool mode. Long sessions with repeated failures could exhaust NCCL internal tracking. Wrap the test work in try/finally so the destroy logic always runs. 
Initialise cp_comm_sub_groups = [] unconditionally so the finally block is safe even when cp_comm_type != "a2a+p2p" (or when an assert fires before the populate loop). Each destroy is itself wrapped in try/except so a destroy failure on one group doesn't leak the others. P2: stdout protocol can be corrupted by interleaved chatter. torchrun and ranks 1..N-1 share rank 0's stdout fd. Any non-rank-0 print, NCCL debug line, or torchrun status output interleaves with the JSON response and breaks json.loads, killing the pool with a misleading "json decode error". Prefix every response with "[CP_POOL_RESP] " in run_attention_with_cp_pool.py and have PoolWorker.submit() scan stdout for sentinel-prefixed lines, echoing non-protocol lines to stderr for visibility. The scan is bounded (MAX_NOISE_LINES=1000) so a chatty worker can't stall the parent. P2 (doc): select.select on a pipe fd is Linux/macOS only. Added a short comment noting the lack of Windows portability. CP attention tests run on Linux GPU hosts; this is a documentation issue, not a real bug. Validated: 38 passed / 0 failed in 270 s on H100, test_essential=True (was 248 s pre-P2 — the +22 s is the new sentinel-scan loop's per-line overhead at ~600 ms/case, within noise). Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp.py | 988 +++++++++--------- .../attention/run_attention_with_cp_pool.py | 8 +- .../attention/test_attention_with_cp.py | 61 +- 3 files changed, 558 insertions(+), 499 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 04abde1064..d5d47c65cb 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -248,6 +248,8 @@ def run_dpa_with_cp( cp_comm_ranks = range(world_size) assert rank in cp_comm_ranks cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + # Always defined so the finally cleanup below is safe even when cp_comm_type != "a2a+p2p".
+ cp_comm_sub_groups = [] if cp_comm_type == "a2a+p2p": assert world_size % 2 == 0, ( "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size" @@ -255,530 +257,542 @@ def run_dpa_with_cp( ) cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] - cp_comm_sub_groups = [] for sub_ranks in cp_comm_sub_ranks: sub_group = dist.new_group(sub_ranks, backend="nccl") if rank in sub_ranks: cp_comm_sub_groups.append(sub_group) - if dtype == "fp8": + # Wrap test work in try/finally so the per-case NCCL communicators created above + # are destroyed even when the body raises (otherwise pool mode leaks one or more + # communicators per failed case, eventually exhausting NCCL resources). + try: + if dtype == "fp8": + if scaling_mode == "delayed": + fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "current": + fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "mxfp8": + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + + # instantiate attention module + core_attn = DotProductAttention( + config.num_heads, + (config.head_dim_qk, config.head_dim_v), + num_gqa_groups=config.num_gqa_groups, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=config.attn_mask_type, + window_size=config.window_size, + softmax_type=config.softmax_type, + return_max_logit=config.return_max_logit, + ).cuda() + if not is_training: + core_attn.eval() + if is_training and config.softmax_type != "vanilla": + core_attn.softmax_offset.requires_grad = True + + # generate attention inputs + ( + q_input_shape, + k_input_shape, + v_input_shape, + attn_output_shape, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) + q_orig = 
torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + dout_orig = torch.clamp( + torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1 + ).cuda() if scaling_mode == "delayed": - fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + qkv_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) + dout_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) if scaling_mode == "current": - fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + qkv_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + dout_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + device="cuda", + ) if scaling_mode == "mxfp8": - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) - - # instantiate attention module - core_attn = DotProductAttention( - config.num_heads, - (config.head_dim_qk, config.head_dim_v), - num_gqa_groups=config.num_gqa_groups, - attention_dropout=config.dropout_p, - qkv_format=qkv_format, - attn_mask_type=config.attn_mask_type, - window_size=config.window_size, - softmax_type=config.softmax_type, - return_max_logit=config.return_max_logit, - ).cuda() - if not is_training: - core_attn.eval() - if is_training and config.softmax_type != "vanilla": - core_attn.softmax_offset.requires_grad = True - - # generate attention inputs - ( - q_input_shape, - k_input_shape, - v_input_shape, - attn_output_shape, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - 
cu_seqlens_kv_padded, - ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) - q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() - k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() - v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() - dout_orig = torch.clamp( - torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1 - ).cuda() - if scaling_mode == "delayed": - qkv_quantizer = Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - scale=torch.tensor([1], dtype=torch.float32).cuda(), - amax=torch.tensor([0], dtype=torch.float32).cuda(), - ) - dout_quantizer = Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - scale=torch.tensor([1], dtype=torch.float32).cuda(), - amax=torch.tensor([0], dtype=torch.float32).cuda(), - ) - if scaling_mode == "current": - qkv_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - device="cuda", - ) - dout_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - device="cuda", - ) - if scaling_mode == "mxfp8": - qkv_quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - rowwise=True, - columnwise=True, - ) - qkv_quantizer.optimize_for_gemm = True - qkv_quantizer.internal = False - dout_quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - rowwise=True, - columnwise=True, - ) - dout_quantizer.optimize_for_gemm = True - dout_quantizer.internal = False - qkv_layout = "_".join([qkv_format] * 3) - q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] - if fp8_mha: - q, k, v, qkv_layout, _ = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) - for x in [q, k, v]: - x.requires_grad = True - - if config.attn_bias_type not in ["no_bias", "alibi"]: - bias_shape_map = { - "1hss": (1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv), - "11ss": (1, 1, config.max_seqlen_q, 
config.max_seqlen_kv), - "b1ss": (config.batch_size, 1, config.max_seqlen_q, config.max_seqlen_kv), - "bhss": ( - config.batch_size, - config.num_heads, - config.max_seqlen_q, - config.max_seqlen_kv, - ), - "111s": (1, 1, 1, config.max_seqlen_kv), - } - attn_bias_shape = bias_shape_map.get(config.bias_shape) - if attn_bias_shape is None: - assert False, f"cuDNN does not support {config.bias_shape=}" - bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() - # cuDNN does not support dbias calculation for 111s as of cuDNN 9.18 - # TODO(KshitijLakhani): Set requires_grad to True for all shapes once 111s is supported - bias.requires_grad = True if config.bias_shape != "111s" else False - else: - bias = None - - ############ run without CP ############ - logging.info(f"[Rank {rank}] Run without context parallelism") - if dtype == "fp8": - fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) - else: - fp8_context = nullcontext() - max_logit = None - with fp8_context: - # q, k, v, out in FP8; dout in F16 - out = core_attn( - q, - k, - v, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, - ) - if config.return_max_logit: - out, max_logit = out - if is_training: - if fp8_bwd and fp8_mha: - dout_fp8 = dout_quantizer(dout) - out.backward(dout_fp8) - else: - out.backward(dout) - if is_training: - dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None - d_softmax_offset = ( - core_attn.softmax_offset.grad if config.softmax_type != "vanilla" else None - ) - else: - dq, dk, dv, dbias = None, None, None, None - d_softmax_offset = None - - ############ run with CP ############ - logging.info(f"[Rank {rank}] Run with context parallelism") - - # set up inputs - q_, k_, v_, dout_, *rest = [ - x.clone().detach() - 
for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias]) - ] - bias_ = rest[0] if len(rest) else None - if qkv_format == "bshd" or qkv_format == "sbhd": - seq_dim = qkv_format.index("s") - q_, k_, v_, dout_ = [ - x.view( - *x.shape[:seq_dim], - 2 * world_size, - x.shape[seq_dim] // (2 * world_size), - *x.shape[(seq_dim + 1) :], + qkv_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, ) - for x in [q_, k_, v_, dout_] - ] - seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device) - q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] - q_, k_, v_, dout_ = [ - x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_] - ] - elif qkv_format == "thd": - seq_idx_q = tex.thd_get_partitioned_indices( - cu_seqlens_q_padded, q_.shape[0], world_size, rank - ) - seq_idx_kv = tex.thd_get_partitioned_indices( - cu_seqlens_kv_padded, k_.shape[0], world_size, rank - ) - q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] - k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" 
- q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]] - if scaling_mode == "delayed": - qkv_quantizer.scale.fill_(1.0) - qkv_quantizer.amax.fill_(0.0) - dout_quantizer.scale.fill_(1.0) - dout_quantizer.amax.fill_(0.0) - if fp8_mha: - q_, k_, v_, qkv_layout, _ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) - if is_training: - q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] - if bias_ is not None: - ndim = bias_.ndim - seq_q_dim = ndim - 2 - if qkv_format == "thd": - bias_seq_idx = seq_idx_q + qkv_quantizer.optimize_for_gemm = True + qkv_quantizer.internal = False + dout_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + rowwise=True, + columnwise=True, + ) + dout_quantizer.optimize_for_gemm = True + dout_quantizer.internal = False + qkv_layout = "_".join([qkv_format] * 3) + q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] + if fp8_mha: + q, k, v, qkv_layout, _ = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + for x in [q, k, v]: + x.requires_grad = True + + if config.attn_bias_type not in ["no_bias", "alibi"]: + bias_shape_map = { + "1hss": (1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv), + "11ss": (1, 1, config.max_seqlen_q, config.max_seqlen_kv), + "b1ss": (config.batch_size, 1, config.max_seqlen_q, config.max_seqlen_kv), + "bhss": ( + config.batch_size, + config.num_heads, + config.max_seqlen_q, + config.max_seqlen_kv, + ), + "111s": (1, 1, 1, config.max_seqlen_kv), + } + attn_bias_shape = bias_shape_map.get(config.bias_shape) + if attn_bias_shape is None: + assert False, f"cuDNN does not support {config.bias_shape=}" + bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() + # cuDNN does not support dbias calculation for 111s as of cuDNN 9.18 + # TODO(KshitijLakhani): Set requires_grad to True for all shapes once 111s is supported + bias.requires_grad = True if config.bias_shape != "111s" else False else: - bias_seq_idx = seq_idx - 
shape_before_seq = bias_.shape[:seq_q_dim] - seq_q_size = bias_.shape[seq_q_dim] - seq_kv_size = bias_.shape[-1] - if seq_q_size == 1: - # TODO(KshitijLakhani): Set to True always once cuDNN supports dbias for 111s - bias_.requires_grad = False - # Bias is broadcast, no need to partition along sequence dimension - pass + bias = None + + ############ run without CP ############ + logging.info(f"[Rank {rank}] Run without context parallelism") + if dtype == "fp8": + fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) else: - bias_ = bias_.view( - *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size + fp8_context = nullcontext() + max_logit = None + with fp8_context: + # q, k, v, out in FP8; dout in F16 + out = core_attn( + q, + k, + v, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + fp8_output=fp8_mha, ) - bias_ = bias_.index_select(seq_q_dim, bias_seq_idx) - bias_ = bias_.view(*shape_before_seq, -1, seq_kv_size) - bias_.requires_grad = True - # set up environment - core_attn.set_context_parallel_group( - cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, - cp_comm_ranks, - torch.cuda.Stream(), - cp_comm_type, - ) - if config.softmax_type != "vanilla": - core_attn.softmax_offset.grad.zero_() - if dtype == "fp8": - core_attn.fp8_initialized = False - core_attn.fp8_meta_tensors_initialized = False - fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) - else: - fp8_context = nullcontext() - - # run attention - max_logit_ = None - with fp8_context: - # q, k, v, out in FP8; dout in F16 - out_ = core_attn( - q_, - k_, - v_, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias_, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - 
cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, - ) - if config.return_max_logit: - out_, max_logit_ = out_ - if is_training: - if fp8_bwd and fp8_mha: - dout_fp8_ = dout_quantizer(dout_) - out_.backward(dout_fp8_) - else: - out_.backward(dout_) - if is_training: - dq_, dk_, dv_, dbias_ = ( - q_.grad, - k_.grad, - v_.grad, - bias_.grad if bias_ is not None else None, - ) - d_softmax_offset_ = ( - core_attn.softmax_offset.grad.clone() if config.softmax_type != "vanilla" else None - ) - else: - dq_, dk_, dv_, dbias_ = None, None, None, None - d_softmax_offset_ = None - - # get outputs - tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] - names = ["out", "dq", "dk", "dv", "dbias", "out_cp", "dq_cp", "dk_cp", "dv_cp", "dbias_cp"] - if fp8_mha: - tensors_to_deq = [out, out_] if not fp8_bwd else tensors - for i, tensor in enumerate(tensors_to_deq): - # dbias/dbias_ could be None, so skip check for it - if tensor is not None: - tensors_to_deq[i] = tensor.dequantize() - if not fp8_bwd: - tensors[0], tensors[5] = tensors_to_deq - for i, tensor in enumerate(tensors): - # dbias/dbias_ could be None, so skip check for it - if tensor is not None: - assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN" - assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf" - out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors - - ############ compare results between CP and no-CP ############ - if qkv_format == "bshd" or qkv_format == "sbhd": + if config.return_max_logit: + out, max_logit = out + if is_training: + if fp8_bwd and fp8_mha: + dout_fp8 = dout_quantizer(dout) + out.backward(dout_fp8) + else: + out.backward(dout) if is_training: - dq, dk, dv, out = [ + dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None + d_softmax_offset = ( + core_attn.softmax_offset.grad if config.softmax_type != "vanilla" else None + ) + else: + dq, dk, dv, dbias = None, 
None, None, None + d_softmax_offset = None + + ############ run with CP ############ + logging.info(f"[Rank {rank}] Run with context parallelism") + + # set up inputs + q_, k_, v_, dout_, *rest = [ + x.clone().detach() + for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias]) + ] + bias_ = rest[0] if len(rest) else None + if qkv_format == "bshd" or qkv_format == "sbhd": + seq_dim = qkv_format.index("s") + q_, k_, v_, dout_ = [ x.view( *x.shape[:seq_dim], 2 * world_size, x.shape[seq_dim] // (2 * world_size), *x.shape[(seq_dim + 1) :], ) - for x in [dq, dk, dv, out] + for x in [q_, k_, v_, dout_] ] - dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] - dq_, dk_, dv_, out_ = [ - x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) - for x in [dq_, dk_, dv_, out_] + seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device) + q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] + q_, k_, v_, dout_ = [ + x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_] ] - if dbias is not None and dbias_ is not None: - ndim = dbias.ndim - # Query seq is at dim -2 - seq_q_dim = ndim - 2 - shape_before_seq = dbias.shape[:seq_q_dim] - seq_q_size = dbias.shape[seq_q_dim] - seq_kv_size = dbias.shape[-1] - # Reshape to split seq_q dimension - dbias = dbias.view( + elif qkv_format == "thd": + seq_idx_q = tex.thd_get_partitioned_indices( + cu_seqlens_q_padded, q_.shape[0], world_size, rank + ) + seq_idx_kv = tex.thd_get_partitioned_indices( + cu_seqlens_kv_padded, k_.shape[0], world_size, rank + ) + q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] + k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] + else: + assert False, f"{qkv_format} is an unsupported qkv_format!" 
+ q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]] + if scaling_mode == "delayed": + qkv_quantizer.scale.fill_(1.0) + qkv_quantizer.amax.fill_(0.0) + dout_quantizer.scale.fill_(1.0) + dout_quantizer.amax.fill_(0.0) + if fp8_mha: + q_, k_, v_, qkv_layout, _ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) + if is_training: + q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] + if bias_ is not None: + ndim = bias_.ndim + seq_q_dim = ndim - 2 + if qkv_format == "thd": + bias_seq_idx = seq_idx_q + else: + bias_seq_idx = seq_idx + shape_before_seq = bias_.shape[:seq_q_dim] + seq_q_size = bias_.shape[seq_q_dim] + seq_kv_size = bias_.shape[-1] + if seq_q_size == 1: + # TODO(KshitijLakhani): Set to True always once cuDNN supports dbias for 111s + bias_.requires_grad = False + # Bias is broadcast, no need to partition along sequence dimension + pass + else: + bias_ = bias_.view( *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size ) - # Index select on the newly created dimension (now at position seq_q_dim) - dbias = dbias.index_select(seq_q_dim, seq_idx) - dbias_ = dbias_.view( - *shape_before_seq, 2, dbias_.shape[seq_q_dim] // 2, seq_kv_size - ) + bias_ = bias_.index_select(seq_q_dim, bias_seq_idx) + bias_ = bias_.view(*shape_before_seq, -1, seq_kv_size) + bias_.requires_grad = True + # set up environment + core_attn.set_context_parallel_group( + cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, + cp_comm_ranks, + torch.cuda.Stream(), + cp_comm_type, + ) + if config.softmax_type != "vanilla": + core_attn.softmax_offset.grad.zero_() + if dtype == "fp8": + core_attn.fp8_initialized = False + core_attn.fp8_meta_tensors_initialized = False + fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) else: - # Forward-only: reshape only out/out_ for comparison - out = out.view( - *out.shape[:seq_dim], - 2 * world_size, - out.shape[seq_dim] // (2 * world_size), - 
*out.shape[(seq_dim + 1) :], - ) - out = out.index_select(seq_dim, seq_idx) - out_ = out_.view( - *out_.shape[:seq_dim], 2, out_.shape[seq_dim] // 2, *out_.shape[(seq_dim + 1) :] + fp8_context = nullcontext() + + # run attention + max_logit_ = None + with fp8_context: + # q, k, v, out in FP8; dout in F16 + out_ = core_attn( + q_, + k_, + v_, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias_, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + fp8_output=fp8_mha, ) - - elif qkv_format == "thd": + if config.return_max_logit: + out_, max_logit_ = out_ + if is_training: + if fp8_bwd and fp8_mha: + dout_fp8_ = dout_quantizer(dout_) + out_.backward(dout_fp8_) + else: + out_.backward(dout_) if is_training: - dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] - dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] - dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] - cu_seqlens_q_padded = cu_seqlens_q_padded // world_size - cu_seqlens_q = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True + dq_, dk_, dv_, dbias_ = ( + q_.grad, + k_.grad, + v_.grad, + bias_.grad if bias_ is not None else None, ) - cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q - num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] - for x in [dq, out, dq_, out_]: - assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_q[b] == 0 - or torch.count_nonzero( - x[ - (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ - b + 1 - ] - ] - ).item() - == 0 - ) - cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size - cu_seqlens_kv = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True + d_softmax_offset_ = ( + core_attn.softmax_offset.grad.clone() if config.softmax_type != "vanilla" else None ) - 
cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv - num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] - for x in [dk, dv, dk_, dv_]: - assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_kv[b] == 0 - or torch.count_nonzero( - x[ - ( - cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] - ) : cu_seqlens_kv_padded[b + 1] - ] - ).item() - == 0 - ) else: - # Forward-only: reshape only out/out_ for comparison - out = out.index_select(0, seq_idx_q).contiguous() - out_ = out_ - - atol, rtol, rmse_tol = get_tols(config, dtype) - tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_] - tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit] - names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"] - names_cp = [x + "_cp" for x in names] - names_no_cp = [x + "_no_cp" for x in names] - is_fp8 = dtype == "fp8" - for i, t in enumerate(tensors_no_cp): - if t is not None: - if "softmax_offset" not in names[i] and "max_logit" not in names[i]: - if qkv_format == "bshd": - # Compare the two sequence chunks separately - # Compare dbias - if names[i] == "dbias": - # Compare the two chunks along dimension 2 (the split sequence dimension) - seq_q_dim_bias = 2 - ndim_bias = t.ndim - slice_0 = [slice(None)] * ndim_bias - slice_0[seq_q_dim_bias] = 0 - slice_1 = [slice(None)] * ndim_bias - slice_1[seq_q_dim_bias] = 1 - compare_and_assert( - t[tuple(slice_0)], - tensors_cp[i][tuple(slice_0)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[tuple(slice_1)], - tensors_cp[i][tuple(slice_1)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - # Compare Q/K/V/out - else: - # Compare the two chunks along dimension 1 (the split sequence dimension) - compare_and_assert( - t[:, 0], - tensors_cp[i][:, 0], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - 
t[:, 1], - tensors_cp[i][:, 1], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - elif qkv_format == "sbhd": - # Compare the two sequence chunks separately - # Compare dbias (same as BSHD) - if names[i] == "dbias": - # Same as bshd: Compare the two chunks along dimension 2 (the split sequence dimension) - seq_q_dim_bias = 2 - ndim_bias = t.ndim - slice_0 = [slice(None)] * ndim_bias - slice_0[seq_q_dim_bias] = 0 - slice_1 = [slice(None)] * ndim_bias - slice_1[seq_q_dim_bias] = 1 - compare_and_assert( - t[tuple(slice_0)], - tensors_cp[i][tuple(slice_0)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[tuple(slice_1)], - tensors_cp[i][tuple(slice_1)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, + dq_, dk_, dv_, dbias_ = None, None, None, None + d_softmax_offset_ = None + + # get outputs + tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] + names = ["out", "dq", "dk", "dv", "dbias", "out_cp", "dq_cp", "dk_cp", "dv_cp", "dbias_cp"] + if fp8_mha: + tensors_to_deq = [out, out_] if not fp8_bwd else tensors + for i, tensor in enumerate(tensors_to_deq): + # dbias/dbias_ could be None, so skip check for it + if tensor is not None: + tensors_to_deq[i] = tensor.dequantize() + if not fp8_bwd: + tensors[0], tensors[5] = tensors_to_deq + for i, tensor in enumerate(tensors): + # dbias/dbias_ could be None, so skip check for it + if tensor is not None: + assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN" + assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf" + out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors + + ############ compare results between CP and no-CP ############ + if qkv_format == "bshd" or qkv_format == "sbhd": + if is_training: + dq, dk, dv, out = [ + x.view( + *x.shape[:seq_dim], + 2 * world_size, + x.shape[seq_dim] // (2 * world_size), + *x.shape[(seq_dim + 1) :], + ) + for x in [dq, dk, dv, out] + ] + dq, 
dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] + dq_, dk_, dv_, out_ = [ + x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) + for x in [dq_, dk_, dv_, out_] + ] + if dbias is not None and dbias_ is not None: + ndim = dbias.ndim + # Query seq is at dim -2 + seq_q_dim = ndim - 2 + shape_before_seq = dbias.shape[:seq_q_dim] + seq_q_size = dbias.shape[seq_q_dim] + seq_kv_size = dbias.shape[-1] + # Reshape to split seq_q dimension + dbias = dbias.view( + *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size + ) + # Index select on the newly created dimension (now at position seq_q_dim) + dbias = dbias.index_select(seq_q_dim, seq_idx) + dbias_ = dbias_.view( + *shape_before_seq, 2, dbias_.shape[seq_q_dim] // 2, seq_kv_size + ) + else: + # Forward-only: reshape only out/out_ for comparison + out = out.view( + *out.shape[:seq_dim], + 2 * world_size, + out.shape[seq_dim] // (2 * world_size), + *out.shape[(seq_dim + 1) :], + ) + out = out.index_select(seq_dim, seq_idx) + out_ = out_.view( + *out_.shape[:seq_dim], 2, out_.shape[seq_dim] // 2, *out_.shape[(seq_dim + 1) :] + ) + + elif qkv_format == "thd": + if is_training: + dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] + dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] + dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] + cu_seqlens_q_padded = cu_seqlens_q_padded // world_size + cu_seqlens_q = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True + ) + cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q + num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] + for x in [dq, out, dq_, out_]: + assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_q[b] == 0 + or torch.count_nonzero( + x[ + (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ + b + 1 + ] + ] + ).item() + == 0 ) - # Compare 
Q/K/V/out - else: - # Compare the two chunks along dimension 0 (the split sequence dimension) - compare_and_assert( - t[0], - tensors_cp[i][0], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, + cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size + cu_seqlens_kv = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True + ) + cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv + num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] + for x in [dk, dv, dk_, dv_]: + assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_kv[b] == 0 + or torch.count_nonzero( + x[ + ( + cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] + ) : cu_seqlens_kv_padded[b + 1] + ] + ).item() + == 0 ) + else: + # Forward-only: reshape only out/out_ for comparison + out = out.index_select(0, seq_idx_q).contiguous() + out_ = out_ + + atol, rtol, rmse_tol = get_tols(config, dtype) + tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_] + tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit] + names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"] + names_cp = [x + "_cp" for x in names] + names_no_cp = [x + "_no_cp" for x in names] + is_fp8 = dtype == "fp8" + for i, t in enumerate(tensors_no_cp): + if t is not None: + if "softmax_offset" not in names[i] and "max_logit" not in names[i]: + if qkv_format == "bshd": + # Compare the two sequence chunks separately + # Compare dbias + if names[i] == "dbias": + # Compare the two chunks along dimension 2 (the split sequence dimension) + seq_q_dim_bias = 2 + ndim_bias = t.ndim + slice_0 = [slice(None)] * ndim_bias + slice_0[seq_q_dim_bias] = 0 + slice_1 = [slice(None)] * ndim_bias + slice_1[seq_q_dim_bias] = 1 + compare_and_assert( + t[tuple(slice_0)], + tensors_cp[i][tuple(slice_0)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( 
+ t[tuple(slice_1)], + tensors_cp[i][tuple(slice_1)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + # Compare Q/K/V/out + else: + # Compare the two chunks along dimension 1 (the split sequence dimension) + compare_and_assert( + t[:, 0], + tensors_cp[i][:, 0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[:, 1], + tensors_cp[i][:, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "sbhd": + # Compare the two sequence chunks separately + # Compare dbias (same as BSHD) + if names[i] == "dbias": + # Same as bshd: Compare the two chunks along dimension 2 (the split sequence dimension) + seq_q_dim_bias = 2 + ndim_bias = t.ndim + slice_0 = [slice(None)] * ndim_bias + slice_0[seq_q_dim_bias] = 0 + slice_1 = [slice(None)] * ndim_bias + slice_1[seq_q_dim_bias] = 1 + compare_and_assert( + t[tuple(slice_0)], + tensors_cp[i][tuple(slice_0)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[tuple(slice_1)], + tensors_cp[i][tuple(slice_1)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + # Compare Q/K/V/out + else: + # Compare the two chunks along dimension 0 (the split sequence dimension) + compare_and_assert( + t[0], + tensors_cp[i][0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[1], + tensors_cp[i][1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "thd": compare_and_assert( - t[1], - tensors_cp[i][1], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, + t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 ) - elif qkv_format == "thd": + else: compare_and_assert( t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 ) - else: - compare_and_assert( - t, tensors_cp[i], names_no_cp[i], names_cp[i], 
atol, rtol, rmse_tol, is_fp8 - ) - logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") - - # destroy distribution group - if _pool_managed_pg: - # Pool owns the main PG; only clean up groups created for this case. - dist.destroy_process_group(cp_comm_group) - if cp_comm_type == "a2a+p2p": + logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") + + finally: + # destroy distribution group(s); runs even on exception + if _pool_managed_pg: + # Pool owns the main PG; only clean up groups created for this case. + try: + dist.destroy_process_group(cp_comm_group) + except Exception: + pass for g in cp_comm_sub_groups: - dist.destroy_process_group(g) - else: - dist.destroy_process_group() + try: + dist.destroy_process_group(g) + except Exception: + pass + else: + try: + dist.destroy_process_group() + except Exception: + pass def main(**kwargs): diff --git a/tests/pytorch/attention/run_attention_with_cp_pool.py b/tests/pytorch/attention/run_attention_with_cp_pool.py index 204cfde522..a375d95589 100644 --- a/tests/pytorch/attention/run_attention_with_cp_pool.py +++ b/tests/pytorch/attention/run_attention_with_cp_pool.py @@ -64,9 +64,15 @@ def _recv_request(rank: int) -> dict: return box[0] +# Sentinel prefix on every response line so the parent reader can skip any +# stdout chatter that gets interleaved (torchrun status, library prints, even +# non-rank-0 stray output — torchrun ranks share rank 0's stdout fd). 
+_RESP_PREFIX = "[CP_POOL_RESP] " + + def _send_response(rank: int, payload: dict) -> None: if rank == 0: - sys.stdout.write(json.dumps(payload) + "\n") + sys.stdout.write(_RESP_PREFIX + json.dumps(payload) + "\n") sys.stdout.flush() diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 2b5c3507eb..2d9fea81f9 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -8,6 +8,7 @@ import subprocess import sys import threading +import time import pathlib import logging import copy @@ -92,6 +93,12 @@ class PoolWorker: # Equivalent in spirit to PR #2965's run_distributed() stderr capture. _STDERR_BUFFER_LINES = 200 # ring cap (~40 KB ceiling) _STDERR_TAIL_CHARS = 4000 # how much to attach to the AssertionError + # Worker prefixes every JSON response with this sentinel; the reader skips + # any other stdout content (torchrun status, library prints, rank>0 stray + # output that lands on the shared stdout fd) so the JSON protocol can't + # be corrupted by interleaved chatter. + _RESP_PREFIX = "[CP_POOL_RESP] " + _MAX_NOISE_LINES = 1000 # bound the per-submit scan so a chatty worker can't loop forever def __init__(self, world_size: int): self.world_size = world_size @@ -173,22 +180,54 @@ def submit(self, kwargs: dict, timeout: float = POOL_SUBMIT_TIMEOUT_SEC) -> None self._kill() raise AssertionError(msg) - ready, _, _ = select.select([self.proc.stdout], [], [], timeout) - if not ready: + # Read lines until we see our sentinel-prefixed JSON response. Anything + # else (torchrun status output, library prints, rank>0 stray writes that + # land on the shared stdout fd) is silently discarded — it would corrupt + # json.loads if we tried to decode it. Bounded loop so a chatty worker + # can't keep us spinning past the deadline. 
+ deadline = None # set lazily on first iteration so we honour the original timeout budget + resp_line = None + scanned = 0 + while scanned < self._MAX_NOISE_LINES: + remaining = timeout if deadline is None else max(0.0, deadline - time.monotonic()) + # select() on a pipe fd is Linux/macOS only — on Windows the select + # module only accepts sockets. CP attention tests run on Linux GPU + # hosts so this is fine; flag if portability is ever needed. + ready, _, _ = select.select([self.proc.stdout], [], [], remaining) + if deadline is None: + deadline = time.monotonic() + timeout + if not ready: + msg = self._diag( + f"pool worker (world_size={self.world_size}) timed out after " + f"{timeout}s; pool killed and will be respawned for the next case" + ) + self._kill() + raise AssertionError(msg) + + line = self.proc.stdout.readline() + if not line: + msg = self._diag("pool worker died mid-request") + self._kill() + raise AssertionError(msg) + + scanned += 1 + if line.startswith(self._RESP_PREFIX): + resp_line = line[len(self._RESP_PREFIX):] + break + # Otherwise: non-protocol stdout from somewhere. Echo to test stderr + # so it's still visible in CI logs, then keep looking. + sys.stderr.write(line) + sys.stderr.flush() + + if resp_line is None: msg = self._diag( - f"pool worker (world_size={self.world_size}) timed out after " - f"{timeout}s; pool killed and will be respawned for the next case" + f"pool worker (world_size={self.world_size}) sent {self._MAX_NOISE_LINES}+ " + "stdout lines without a sentinel-prefixed response; assuming protocol corruption" ) self._kill() raise AssertionError(msg) - line = self.proc.stdout.readline() - if not line: - msg = self._diag("pool worker died mid-request") - self._kill() - raise AssertionError(msg) - - resp = json.loads(line) + resp = json.loads(resp_line) if not resp["ok"]: # gather_object already carries the full per-rank traceback in # resp["error"]; no need to also attach stderr (would duplicate). 
From 169be829e4cf628c8f2eec0deb9b5d2ed93575e2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 20:22:23 +0000 Subject: [PATCH 10/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../attention/run_attention_with_cp.py | 35 ++++++++++++++----- .../attention/test_attention_with_cp.py | 2 +- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index d5d47c65cb..3ad7d99de9 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -272,7 +272,9 @@ def run_dpa_with_cp( if scaling_mode == "current": fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "mxfp8": - fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + fp8_recipe = MXFP8BlockScaling( + fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha + ) # instantiate attention module core_attn = DotProductAttention( @@ -376,7 +378,9 @@ def run_dpa_with_cp( ############ run without CP ############ logging.info(f"[Rank {rank}] Run without context parallelism") if dtype == "fp8": - fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) + fp8_context = autocast( + enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group + ) else: fp8_context = nullcontext() max_logit = None @@ -434,7 +438,8 @@ def run_dpa_with_cp( seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device) q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] q_, k_, v_, dout_ = [ - x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_] + x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + for x in [q_, k_, v_, dout_] ] elif qkv_format == "thd": 
seq_idx_q = tex.thd_get_partitioned_indices( @@ -491,7 +496,9 @@ def run_dpa_with_cp( if dtype == "fp8": core_attn.fp8_initialized = False core_attn.fp8_meta_tensors_initialized = False - fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) + fp8_context = autocast( + enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group + ) else: fp8_context = nullcontext() @@ -577,7 +584,10 @@ def run_dpa_with_cp( seq_kv_size = dbias.shape[-1] # Reshape to split seq_q dimension dbias = dbias.view( - *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size + *shape_before_seq, + 2 * world_size, + seq_q_size // (2 * world_size), + seq_kv_size, ) # Index select on the newly created dimension (now at position seq_q_dim) dbias = dbias.index_select(seq_q_dim, seq_idx) @@ -615,9 +625,9 @@ def run_dpa_with_cp( num_pads_q[b] == 0 or torch.count_nonzero( x[ - (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ - b + 1 - ] + ( + cu_seqlens_q_padded[b + 1] - num_pads_q[b] + ) : cu_seqlens_q_padded[b + 1] ] ).item() == 0 @@ -767,7 +777,14 @@ def run_dpa_with_cp( ) elif qkv_format == "thd": compare_and_assert( - t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 + t, + tensors_cp[i], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, ) else: compare_and_assert( diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 2d9fea81f9..39236326fb 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -212,7 +212,7 @@ def submit(self, kwargs: dict, timeout: float = POOL_SUBMIT_TIMEOUT_SEC) -> None scanned += 1 if line.startswith(self._RESP_PREFIX): - resp_line = line[len(self._RESP_PREFIX):] + resp_line = line[len(self._RESP_PREFIX) :] break # Otherwise: non-protocol stdout from somewhere. 
Echo to test stderr # so it's still visible in CI logs, then keep looking. From 557bd809239cb57864306def2ae87e5f373158e9 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 15 May 2026 10:28:50 -0700 Subject: [PATCH 11/15] [PyTorch] Fix stream race on max_logit_per_step in all-gather CP forward MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In AttnFuncWithCPAndKVAllGather.forward, max_logit_per_step[i] is written inside `with torch.cuda.stream(flash_attn_streams[i])`. For i=1, flash_attn_streams[1] is cp_stream — i.e. *not* the default stream. Later, at loop iteration i=2, the code reads max_logit_per_step[1] via `torch.maximum(max_logit, max_logit_per_step[i-1])` which runs on the default stream. Without an explicit wait_stream, this is a read-after-write race across streams. The post-loop `current_stream().wait_stream(cp_stream)` is too late — the race has already fired. The race is latent: outcome depends on stream scheduling. In a fresh-process subprocess (one-torchrun-per-test path), streams are cleanly initialised and timing happens to put the write before the read. In a long-running persistent-worker process — exposed by PR #2993's pool design — prior workloads shape stream state differently, the read can fire before the write completes, and max_logit ends up with stale values in some heads (~0.3 abs diff, 3/12 elements wrong on the H100 matrix). Fix: insert `current_stream().wait_stream(flash_attn_streams[i-1])` before the torch.maximum read. No-op when the streams are identical (i=1 case, where flash_attn_streams[0] is current_stream), only fires when reading from cp_stream (i=2 case). Validated: 8xH100, test_essential=False, 348 passed / 0 failed in 27m 10s (was 323 passed + 5 failed at this commit's parent, all 5 failing on cp_comm_type=all_gather with mismatched max_logit). The failing configs (all_gather + cp_1_0/cp_1_1 + bshd or fp16) now pass under the pool — confirming the race was the sole root cause. 
Signed-off-by: Sudhakar Singh --- .../attention/dot_product_attention/context_parallel.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 3db0417bdb..35684625a5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3277,6 +3277,12 @@ def forward( elif o_format == "sbhd": out_f16[i - 1].copy_(out_per_step[i - 1]) if return_max_logit: + # max_logit_per_step[i-1] was written on flash_attn_streams[i-1] + # (cp_stream for i-1=1). The torch.maximum below runs on the + # default stream, so without this wait the read can race with + # the write. The post-loop wait_stream(cp_stream) is too late. + # No-op when flash_attn_streams[i-1] is current_stream(). + torch.cuda.current_stream().wait_stream(flash_attn_streams[i - 1]) max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) torch.cuda.current_stream().wait_stream(cp_stream) From 48158838844ed718c50bc76807260f8480acf3b2 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 15 May 2026 10:53:22 -0700 Subject: [PATCH 12/15] Address PR review (R2): drop dead code in pool worker and PoolWorker MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Line-level cleanups from the second reviewer pass on PR #2993. Each item is dead/redundant; none changes behaviour. Full-matrix test_essential=False on 8xH100 still passes 348/0 in 26m 23s after these. run_attention_with_cp_pool.py: - Drop _TRANSIENT_ENV_KEYS tuple + pop loop. run_dpa_with_cp already re-sets NVTE_FUSED_ATTN/NVTE_FLASH_ATTN unconditionally at the top and pops the FP8 ones itself. The pop loop was defensive against a hypothetical "future caller that doesn't re-set them" that doesn't exist. 
- Drop gc.collect() after torch.cuda.empty_cache(). The cases create no Python reference cycles between iterations and empty_cache only frees CUDA blocks PyTorch already considers free; the combination was no-op here. - Drop dist.barrier() after dist.gather_object(). gather_object is itself a collective synchronization point — if every rank reaches it, none is ahead. The "surface a wedged communicator here" comment was wishful: a wedged communicator would already wedge the gather. test_attention_with_cp.py (PoolWorker): - Drop _MAX_NOISE_LINES = 1000 + the scanned counter + the unreachable post-loop "1000+ lines" branch. select()'s deadline already bounds the loop; the line-count cap was redundant and the over-limit branch was unreachable in practice. - Inline _stderr_tail() into _diag(). Single caller, single use. - Drop the _stderr_thread attribute. The drainer is daemon and self-terminates when the pipe closes; we never read the field anywhere, so initialising and nulling it was bookkeeping for no reason. - Drop the dead assert in submit() — _ensure_alive() on the prior line already guarantees proc/stdin/stdout exist. Deferred to a follow-up: - L8 (drop try/except around dist.destroy_process_group). Real semantic change: hides errors that occur when a previous test wedged the communicator. Worth doing but needs its own validation. - R1 medium items M1 (module-level flag vs NVTE_CP_POOL_PG env var), M2 (redirect rank>0 stdout vs sentinel scan), M3 (explicit CUDA_VISIBLE_DEVICES per pool). Same reasoning — separate PRs. 
Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp_pool.py | 33 ++++--------- .../attention/test_attention_with_cp.py | 48 +++++-------------- 2 files changed, 21 insertions(+), 60 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp_pool.py b/tests/pytorch/attention/run_attention_with_cp_pool.py index a375d95589..5c41e8a4f8 100644 --- a/tests/pytorch/attention/run_attention_with_cp_pool.py +++ b/tests/pytorch/attention/run_attention_with_cp_pool.py @@ -15,7 +15,7 @@ call run_dpa_with_cp(**kwargs) — the same work function the per-case subprocess design uses, with NVTE_CP_POOL_PG=1 so the function reuses our PG instead of re-initing it - torch.cuda.empty_cache() + gc.collect() per case + torch.cuda.empty_cache() per case all ranks gather (ok, error_msg) to rank 0 rank 0: write one JSON response line to stdout @@ -26,7 +26,6 @@ response: {"ok": true} {"ok": false, "error": "first failing rank's traceback"} """ -import gc import json import os import sys @@ -43,18 +42,6 @@ from transformer_engine.pytorch.quantization import FP8GlobalStateManager -# Env vars run_dpa_with_cp re-sets at the top of every call. We pop them -# defensively between cases so a future caller that *doesn't* re-set them -# can't inherit a leftover value from a previous case in the same worker. -_TRANSIENT_ENV_KEYS = ( - "NVTE_FP8_DPA_BWD", - "NVTE_DPA_FP8CS_O_in_F16", - "NVTE_FLASH_ATTN", - "NVTE_FUSED_ATTN", - "NVTE_ALLOW_NONDETERMINISTIC_ALGO", -) - - def _recv_request(rank: int) -> dict: box = [None] if rank == 0: @@ -79,17 +66,17 @@ def _send_response(rank: int, payload: dict) -> None: def _reset_between_cases() -> None: """Drop state that would otherwise cascade across cases. - Matches the per-case startup of the single-shot worker (``_run_single_config`` - on the per-case-subprocess branch): identical RNG seed at the start of every - case, FP8 state cleared, transient env vars cleared, allocator clean. 
+ Matches the per-case startup of the single-shot worker + (``_run_single_config`` on the per-case-subprocess branch): identical RNG + seed at the start of every case, FP8 state cleared, allocator clean. + ``run_dpa_with_cp`` re-sets ``NVTE_FUSED_ATTN``/``NVTE_FLASH_ATTN`` + unconditionally and pops the other transient env vars itself, so no + explicit pop is needed here. """ torch.manual_seed(1234) torch.cuda.manual_seed(1234) FP8GlobalStateManager.reset() - for env_key in _TRANSIENT_ENV_KEYS: - os.environ.pop(env_key, None) torch.cuda.empty_cache() - gc.collect() _case_counter = 0 @@ -141,12 +128,10 @@ def main() -> None: ok, msg = _run_one(req, rank) gathered: list[tuple[bool, str]] = [None] * world_size # type: ignore[list-item] + # gather_object is itself a collective synchronization point — if + # every rank reached it, none is ahead. No extra barrier needed. dist.gather_object((ok, msg), gathered if rank == 0 else None, dst=0) - # Surface a wedged communicator here, before the next case's - # collectives can inherit the corruption. - dist.barrier() - if rank == 0: all_ok = all(o for o, _ in gathered) if all_ok: diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 39236326fb..9356df20cf 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -44,7 +44,7 @@ torch.manual_seed(seed) torch.cuda.manual_seed(seed) -test_essential = True +test_essential = False model_configs_flash_attn = { # test: ModelConfig(b, sq, hq, dqk) @@ -98,13 +98,11 @@ class PoolWorker: # output that lands on the shared stdout fd) so the JSON protocol can't # be corrupted by interleaved chatter. 
_RESP_PREFIX = "[CP_POOL_RESP] " - _MAX_NOISE_LINES = 1000 # bound the per-submit scan so a chatty worker can't loop forever def __init__(self, world_size: int): self.world_size = world_size self.proc: subprocess.Popen | None = None self._stderr_buf: deque[str] = deque(maxlen=self._STDERR_BUFFER_LINES) - self._stderr_thread: threading.Thread | None = None def _spawn(self) -> None: te_path = os.getenv("TE_PATH", "/opt/transformerengine") @@ -119,7 +117,8 @@ def _spawn(self) -> None: ] # stderr=PIPE so we can capture the tail for crash-path AssertionErrors; # a daemon drainer thread also echoes each line to sys.stderr so pytest's - # per-test stderr capture still works. + # per-test stderr capture still works. The thread is daemon, so it + # self-terminates when the pipe closes — no tracking needed. self.proc = subprocess.Popen( cmd, stdin=subprocess.PIPE, @@ -130,8 +129,7 @@ def _spawn(self) -> None: env={**os.environ, "PYTHONUNBUFFERED": "1"}, ) self._stderr_buf.clear() - self._stderr_thread = threading.Thread(target=self._drain_stderr, daemon=True) - self._stderr_thread.start() + threading.Thread(target=self._drain_stderr, daemon=True).start() def _drain_stderr(self) -> None: proc = self.proc @@ -142,12 +140,8 @@ def _drain_stderr(self) -> None: sys.stderr.write(line) sys.stderr.flush() - def _stderr_tail(self) -> str: - text = "".join(self._stderr_buf) - return text[-self._STDERR_TAIL_CHARS :] if len(text) > self._STDERR_TAIL_CHARS else text - def _diag(self, msg: str) -> str: - tail = self._stderr_tail() + tail = "".join(self._stderr_buf)[-self._STDERR_TAIL_CHARS :] if not tail.strip(): return msg return f"{msg}\n\n--- pool worker stderr (tail) ---\n{tail}" @@ -164,13 +158,10 @@ def _kill(self) -> None: except subprocess.TimeoutExpired: self.proc.kill() self.proc.wait() - # Drainer thread exits on its own when the pipe closes. 
self.proc = None - self._stderr_thread = None def submit(self, kwargs: dict, timeout: float = POOL_SUBMIT_TIMEOUT_SEC) -> None: self._ensure_alive() - assert self.proc and self.proc.stdin and self.proc.stdout req = json.dumps({"op": "run", "kwargs": kwargs}) + "\n" try: self.proc.stdin.write(req) @@ -182,20 +173,16 @@ def submit(self, kwargs: dict, timeout: float = POOL_SUBMIT_TIMEOUT_SEC) -> None # Read lines until we see our sentinel-prefixed JSON response. Anything # else (torchrun status output, library prints, rank>0 stray writes that - # land on the shared stdout fd) is silently discarded — it would corrupt - # json.loads if we tried to decode it. Bounded loop so a chatty worker - # can't keep us spinning past the deadline. - deadline = None # set lazily on first iteration so we honour the original timeout budget - resp_line = None - scanned = 0 - while scanned < self._MAX_NOISE_LINES: - remaining = timeout if deadline is None else max(0.0, deadline - time.monotonic()) + # land on the shared stdout fd) is echoed to stderr and skipped — it + # would corrupt json.loads if we tried to decode it. The select() + # deadline alone bounds the loop; no separate line counter needed. + deadline = time.monotonic() + timeout + while True: + remaining = max(0.0, deadline - time.monotonic()) # select() on a pipe fd is Linux/macOS only — on Windows the select # module only accepts sockets. CP attention tests run on Linux GPU # hosts so this is fine; flag if portability is ever needed. 
ready, _, _ = select.select([self.proc.stdout], [], [], remaining) - if deadline is None: - deadline = time.monotonic() + timeout if not ready: msg = self._diag( f"pool worker (world_size={self.world_size}) timed out after " @@ -210,23 +197,13 @@ def submit(self, kwargs: dict, timeout: float = POOL_SUBMIT_TIMEOUT_SEC) -> None self._kill() raise AssertionError(msg) - scanned += 1 if line.startswith(self._RESP_PREFIX): resp_line = line[len(self._RESP_PREFIX) :] break - # Otherwise: non-protocol stdout from somewhere. Echo to test stderr - # so it's still visible in CI logs, then keep looking. + # Non-protocol stdout — echo to stderr for CI visibility, keep looking. sys.stderr.write(line) sys.stderr.flush() - if resp_line is None: - msg = self._diag( - f"pool worker (world_size={self.world_size}) sent {self._MAX_NOISE_LINES}+ " - "stdout lines without a sentinel-prefixed response; assuming protocol corruption" - ) - self._kill() - raise AssertionError(msg) - resp = json.loads(resp_line) if not resp["ok"]: # gather_object already carries the full per-rank traceback in @@ -246,7 +223,6 @@ def shutdown(self) -> None: except subprocess.TimeoutExpired: self._kill() self.proc = None - self._stderr_thread = None @pytest.fixture(scope="session") From d15bfce39ea32c1f31d1ebda483777f37508df48 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 15 May 2026 14:42:31 -0700 Subject: [PATCH 13/15] Address PR review (items 2+3): reuse CP groups across pool cases world_size and the rank set don't change for the lifetime of one pool, so recreating the world group and a2a+p2p sub-groups per case wastes ~50-100 ms of NCCL setup each. Pre-create them once in the pool worker (new helper _create_cp_comm_groups), stash on the run_attention_with_cp module via module-level _pool_cp_comm_group / _pool_cp_comm_sub_groups pointers, and reuse them from run_dpa_with_cp in pool mode. Pool teardown destroys them once at shutdown. 
Also move per-case dist.new_group() calls inside the try/finally in run_dpa_with_cp: a failure mid-loop in the a2a+p2p sub_group population otherwise leaks every communicator created before the failure. The finally now only destroys groups we created locally (cp_comm_group / sub_groups populated in the else-branch), leaving pool-owned groups alone for reuse. cyanguwa's review feedback on PR #2993. Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp.py | 76 ++++++++++++------- .../attention/run_attention_with_cp_pool.py | 51 +++++++++++++ 2 files changed, 100 insertions(+), 27 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 3ad7d99de9..7e869b4256 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -30,6 +30,15 @@ ) from utils import ModelConfig, compare_and_assert +# Pool mode (NVTE_CP_POOL_PG=1) only: shared CP collective groups, created once +# per pool by run_attention_with_cp_pool.main() and reused across every case in +# that pool. world_size and the rank set don't change per case, so re-creating +# these per call would be wasted NCCL setup (~50-100 ms each). Single-shot +# subprocess mode leaves these None / [] and run_dpa_with_cp creates/destroys +# its own groups inline. +_pool_cp_comm_group = None +_pool_cp_comm_sub_groups: list = [] + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -244,28 +253,36 @@ def run_dpa_with_cp( if not _pool_managed_pg: dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) - # set up communication group for CP + # Set up communication group for CP. In pool mode, the pool worker has + # already pre-created world-scoped and a2a+p2p sub-groups once and stashed + # them in module-level pointers; we reuse those and the pool destroys them + # at shutdown. 
In single-shot mode we create them per call and destroy in + # the finally below. cp_comm_ranks = range(world_size) assert rank in cp_comm_ranks - cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") - # Always defined so the finally cleanup below is safe even when cp_comm_type != "a2a+p2p". - cp_comm_sub_groups = [] - if cp_comm_type == "a2a+p2p": - assert world_size % 2 == 0, ( - "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size" - " = 2." - ) - cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] - cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] - for sub_ranks in cp_comm_sub_ranks: - sub_group = dist.new_group(sub_ranks, backend="nccl") - if rank in sub_ranks: - cp_comm_sub_groups.append(sub_group) - - # Wrap test work in try/finally so the per-case NCCL communicators created above - # are destroyed even when the body raises (otherwise pool mode leaks one or more - # communicators per failed case, eventually exhausting NCCL resources). + _reusing_pool_groups = _pool_managed_pg and _pool_cp_comm_group is not None + cp_comm_group = None + cp_comm_sub_groups: list = [] + # Wrap setup + body so any failure between new_group calls and the end of + # the test still hits the cleanup below — otherwise a partial sub-group + # population (e.g. NCCL refuses new_group mid-loop) leaks communicators. try: + if _reusing_pool_groups: + cp_comm_group = _pool_cp_comm_group + cp_comm_sub_groups = _pool_cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else [] + else: + cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + if cp_comm_type == "a2a+p2p": + assert world_size % 2 == 0, ( + "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has" + " cp_size = 2." 
+ ) + cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] + cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] + for sub_ranks in cp_comm_sub_ranks: + sub_group = dist.new_group(sub_ranks, backend="nccl") + if rank in sub_ranks: + cp_comm_sub_groups.append(sub_group) if dtype == "fp8": if scaling_mode == "delayed": fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) @@ -793,19 +810,24 @@ def run_dpa_with_cp( logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") finally: - # destroy distribution group(s); runs even on exception - if _pool_managed_pg: - # Pool owns the main PG; only clean up groups created for this case. - try: - dist.destroy_process_group(cp_comm_group) - except Exception: - pass + # Destroy only groups WE created. In pool mode with shared groups, + # cp_comm_group / cp_comm_sub_groups are owned by the pool runner and + # destroyed at pool shutdown — touching them here would tear down the + # cache that subsequent cases want to reuse. + if not _reusing_pool_groups: + if cp_comm_group is not None: + try: + dist.destroy_process_group(cp_comm_group) + except Exception: + pass for g in cp_comm_sub_groups: try: dist.destroy_process_group(g) except Exception: pass - else: + # Tear down the main PG only in single-shot mode. In pool mode the + # pool runner owns the main PG and destroys it at shutdown. + if not _pool_managed_pg: try: dist.destroy_process_group() except Exception: diff --git a/tests/pytorch/attention/run_attention_with_cp_pool.py b/tests/pytorch/attention/run_attention_with_cp_pool.py index 5c41e8a4f8..767437cb04 100644 --- a/tests/pytorch/attention/run_attention_with_cp_pool.py +++ b/tests/pytorch/attention/run_attention_with_cp_pool.py @@ -112,6 +112,33 @@ def _run_one(req: dict, rank: int) -> tuple[bool, str]: return ok, err +def _create_cp_comm_groups(rank: int, world_size: int) -> tuple: + """Pre-create the CP collective groups for this pool. 
+ + world_size and the rank set are constant for the lifetime of one pool, so + the world group and the a2a+p2p sub-groups are deterministic. Creating + them once here and reusing them across every case eliminates ~50-100 ms + of NCCL setup per case (cyanguwa's review feedback on PR #2993). + + Returns ``(world_group, a2a_p2p_sub_groups)``. ``a2a_p2p_sub_groups`` is + empty when world_size is too small to support a2a+p2p (needs an even + world_size ≥ 4); cases with cp_comm_type='a2a+p2p' wouldn't be routed to + such a pool anyway. + """ + world_group = dist.new_group(range(world_size), backend="nccl") + sub_groups: list = [] + if world_size >= 4 and world_size % 2 == 0: + # Mirror the layout in run_attention_with_cp.py: cp_size/2 pairs along + # axis 0, plus 2 stride-2 groups along axis 1. + cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] + cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] + for sub_ranks in cp_comm_sub_ranks: + sub_group = dist.new_group(sub_ranks, backend="nccl") + if rank in sub_ranks: + sub_groups.append(sub_group) + return world_group, sub_groups + + def main() -> None: rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) @@ -119,6 +146,15 @@ def main() -> None: dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) os.environ["NVTE_CP_POOL_PG"] = "1" + # Stash pool-shared CP groups on the run_attention_with_cp module so + # run_dpa_with_cp can read them per case. Imported here (after the env var + # is set) to keep import-time side effects minimal. 
+ import run_attention_with_cp as _rac + + _rac._pool_cp_comm_group, _rac._pool_cp_comm_sub_groups = _create_cp_comm_groups( + rank, world_size + ) + try: while True: req = _recv_request(rank) @@ -140,6 +176,21 @@ def main() -> None: first_err = next(m for o, m in gathered if not o) _send_response(rank, {"ok": False, "error": first_err}) finally: + # Tear down pool-shared CP groups before the main PG (NCCL requires + # sub-groups to be destroyed first). Each destroy is independently + # guarded so a wedged communicator on one group doesn't leak the rest. + if _rac._pool_cp_comm_group is not None: + try: + dist.destroy_process_group(_rac._pool_cp_comm_group) + except Exception: + pass + for g in _rac._pool_cp_comm_sub_groups: + try: + dist.destroy_process_group(g) + except Exception: + pass + _rac._pool_cp_comm_group = None + _rac._pool_cp_comm_sub_groups = [] dist.destroy_process_group() From 87c67ac2d146aac396066a821a1bdc143fdb1057 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 15 May 2026 15:04:16 -0700 Subject: [PATCH 14/15] Flatten try/finally wrap in run_dpa_with_cp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Round-1 P1 NCCL-communicator-leak fix (e162a9ec) wrapped the ~540-line body of run_dpa_with_cp in try/finally. The wrap itself was tiny but it re-indented every line of the body by one level, inflating the PR diff of run_attention_with_cp.py to ~1000 lines against origin/main. Items 2+3 (d15bfce3) since made the wrap unnecessary: - In pool mode, cp_comm_group and cp_comm_sub_groups are owned by the pool worker (which destroys them once at pool shutdown). run_dpa_with_cp neither creates nor destroys them, so an in-body exception can't leak communicators. 
- In single-shot mode, groups are still created locally, but the subprocess exits at function return; NCCL releases everything at process teardown, so a stray exception leaks communicators only for the milliseconds before the process dies — a bounded one-off cost, not the unbounded accumulation that Round-1 flagged for pool mode. Removing the wrap drops the run_attention_with_cp.py diff against origin/main from ~1000 lines to ~120 lines without changing observable behaviour. Smoke-tested: 4 representative cases pass. Signed-off-by: Sudhakar Singh --- .../attention/run_attention_with_cp.py | 1055 ++++++++--------- 1 file changed, 525 insertions(+), 530 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 7e869b4256..1b715d892b 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -263,539 +263,460 @@ def run_dpa_with_cp( _reusing_pool_groups = _pool_managed_pg and _pool_cp_comm_group is not None cp_comm_group = None cp_comm_sub_groups: list = [] - # Wrap setup + body so any failure between new_group calls and the end of - # the test still hits the cleanup below — otherwise a partial sub-group - # population (e.g. NCCL refuses new_group mid-loop) leaks communicators. - try: - if _reusing_pool_groups: - cp_comm_group = _pool_cp_comm_group - cp_comm_sub_groups = _pool_cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else [] - else: - cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") - if cp_comm_type == "a2a+p2p": - assert world_size % 2 == 0, ( - "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has" - " cp_size = 2." 
- ) - cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] - cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] - for sub_ranks in cp_comm_sub_ranks: - sub_group = dist.new_group(sub_ranks, backend="nccl") - if rank in sub_ranks: - cp_comm_sub_groups.append(sub_group) - if dtype == "fp8": - if scaling_mode == "delayed": - fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) - if scaling_mode == "current": - fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) - if scaling_mode == "mxfp8": - fp8_recipe = MXFP8BlockScaling( - fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha - ) - - # instantiate attention module - core_attn = DotProductAttention( - config.num_heads, - (config.head_dim_qk, config.head_dim_v), - num_gqa_groups=config.num_gqa_groups, - attention_dropout=config.dropout_p, - qkv_format=qkv_format, - attn_mask_type=config.attn_mask_type, - window_size=config.window_size, - softmax_type=config.softmax_type, - return_max_logit=config.return_max_logit, - ).cuda() - if not is_training: - core_attn.eval() - if is_training and config.softmax_type != "vanilla": - core_attn.softmax_offset.requires_grad = True - - # generate attention inputs - ( - q_input_shape, - k_input_shape, - v_input_shape, - attn_output_shape, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) - q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() - k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() - v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() - dout_orig = torch.clamp( - torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1 - ).cuda() - if scaling_mode == "delayed": - qkv_quantizer = Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - scale=torch.tensor([1], 
dtype=torch.float32).cuda(), - amax=torch.tensor([0], dtype=torch.float32).cuda(), - ) - dout_quantizer = Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - scale=torch.tensor([1], dtype=torch.float32).cuda(), - amax=torch.tensor([0], dtype=torch.float32).cuda(), + if _reusing_pool_groups: + cp_comm_group = _pool_cp_comm_group + cp_comm_sub_groups = _pool_cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else [] + else: + cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + if cp_comm_type == "a2a+p2p": + assert world_size % 2 == 0, ( + "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has" + " cp_size = 2." ) + cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] + cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] + for sub_ranks in cp_comm_sub_ranks: + sub_group = dist.new_group(sub_ranks, backend="nccl") + if rank in sub_ranks: + cp_comm_sub_groups.append(sub_group) + if dtype == "fp8": + if scaling_mode == "delayed": + fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "current": - qkv_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - device="cuda", - ) - dout_quantizer = Float8CurrentScalingQuantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - device="cuda", - ) + fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "mxfp8": - qkv_quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E4M3, - rowwise=True, - columnwise=True, + fp8_recipe = MXFP8BlockScaling( + fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha ) - qkv_quantizer.optimize_for_gemm = True - qkv_quantizer.internal = False - dout_quantizer = MXFP8Quantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - rowwise=True, - columnwise=True, - ) - dout_quantizer.optimize_for_gemm = True - dout_quantizer.internal = False - qkv_layout = "_".join([qkv_format] * 3) - q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] 
- if fp8_mha: - q, k, v, qkv_layout, _ = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) - for x in [q, k, v]: - x.requires_grad = True - - if config.attn_bias_type not in ["no_bias", "alibi"]: - bias_shape_map = { - "1hss": (1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv), - "11ss": (1, 1, config.max_seqlen_q, config.max_seqlen_kv), - "b1ss": (config.batch_size, 1, config.max_seqlen_q, config.max_seqlen_kv), - "bhss": ( - config.batch_size, - config.num_heads, - config.max_seqlen_q, - config.max_seqlen_kv, - ), - "111s": (1, 1, 1, config.max_seqlen_kv), - } - attn_bias_shape = bias_shape_map.get(config.bias_shape) - if attn_bias_shape is None: - assert False, f"cuDNN does not support {config.bias_shape=}" - bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() - # cuDNN does not support dbias calculation for 111s as of cuDNN 9.18 - # TODO(KshitijLakhani): Set requires_grad to True for all shapes once 111s is supported - bias.requires_grad = True if config.bias_shape != "111s" else False - else: - bias = None - ############ run without CP ############ - logging.info(f"[Rank {rank}] Run without context parallelism") - if dtype == "fp8": - fp8_context = autocast( - enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group + # instantiate attention module + core_attn = DotProductAttention( + config.num_heads, + (config.head_dim_qk, config.head_dim_v), + num_gqa_groups=config.num_gqa_groups, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=config.attn_mask_type, + window_size=config.window_size, + softmax_type=config.softmax_type, + return_max_logit=config.return_max_logit, + ).cuda() + if not is_training: + core_attn.eval() + if is_training and config.softmax_type != "vanilla": + core_attn.softmax_offset.requires_grad = True + + # generate attention inputs + ( + q_input_shape, + k_input_shape, + v_input_shape, + attn_output_shape, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + 
cu_seqlens_kv_padded, + ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) + q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + dout_orig = torch.clamp( + torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1 + ).cuda() + if scaling_mode == "delayed": + qkv_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) + dout_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) + if scaling_mode == "current": + qkv_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + dout_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + device="cuda", + ) + if scaling_mode == "mxfp8": + qkv_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + ) + qkv_quantizer.optimize_for_gemm = True + qkv_quantizer.internal = False + dout_quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + rowwise=True, + columnwise=True, + ) + dout_quantizer.optimize_for_gemm = True + dout_quantizer.internal = False + qkv_layout = "_".join([qkv_format] * 3) + q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] + if fp8_mha: + q, k, v, qkv_layout, _ = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + for x in [q, k, v]: + x.requires_grad = True + + if config.attn_bias_type not in ["no_bias", "alibi"]: + bias_shape_map = { + "1hss": (1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv), + "11ss": (1, 1, config.max_seqlen_q, 
config.max_seqlen_kv), + "b1ss": (config.batch_size, 1, config.max_seqlen_q, config.max_seqlen_kv), + "bhss": ( + config.batch_size, + config.num_heads, + config.max_seqlen_q, + config.max_seqlen_kv, + ), + "111s": (1, 1, 1, config.max_seqlen_kv), + } + attn_bias_shape = bias_shape_map.get(config.bias_shape) + if attn_bias_shape is None: + assert False, f"cuDNN does not support {config.bias_shape=}" + bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() + # cuDNN does not support dbias calculation for 111s as of cuDNN 9.18 + # TODO(KshitijLakhani): Set requires_grad to True for all shapes once 111s is supported + bias.requires_grad = True if config.bias_shape != "111s" else False + else: + bias = None + + ############ run without CP ############ + logging.info(f"[Rank {rank}] Run without context parallelism") + if dtype == "fp8": + fp8_context = autocast( + enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group + ) + else: + fp8_context = nullcontext() + max_logit = None + with fp8_context: + # q, k, v, out in FP8; dout in F16 + out = core_attn( + q, + k, + v, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + fp8_output=fp8_mha, + ) + if config.return_max_logit: + out, max_logit = out + if is_training: + if fp8_bwd and fp8_mha: + dout_fp8 = dout_quantizer(dout) + out.backward(dout_fp8) + else: + out.backward(dout) + if is_training: + dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None + d_softmax_offset = ( + core_attn.softmax_offset.grad if config.softmax_type != "vanilla" else None + ) + else: + dq, dk, dv, dbias = None, None, None, None + d_softmax_offset = None + + ############ run with CP ############ + logging.info(f"[Rank {rank}] Run with context parallelism") + + # set up inputs + q_, k_, v_, dout_, *rest = [ + 
x.clone().detach() + for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias]) + ] + bias_ = rest[0] if len(rest) else None + if qkv_format == "bshd" or qkv_format == "sbhd": + seq_dim = qkv_format.index("s") + q_, k_, v_, dout_ = [ + x.view( + *x.shape[:seq_dim], + 2 * world_size, + x.shape[seq_dim] // (2 * world_size), + *x.shape[(seq_dim + 1) :], ) + for x in [q_, k_, v_, dout_] + ] + seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device) + q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] + q_, k_, v_, dout_ = [ + x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + for x in [q_, k_, v_, dout_] + ] + elif qkv_format == "thd": + seq_idx_q = tex.thd_get_partitioned_indices( + cu_seqlens_q_padded, q_.shape[0], world_size, rank + ) + seq_idx_kv = tex.thd_get_partitioned_indices( + cu_seqlens_kv_padded, k_.shape[0], world_size, rank + ) + q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] + k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] + else: + assert False, f"{qkv_format} is an unsupported qkv_format!" 
+ q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]] + if scaling_mode == "delayed": + qkv_quantizer.scale.fill_(1.0) + qkv_quantizer.amax.fill_(0.0) + dout_quantizer.scale.fill_(1.0) + dout_quantizer.amax.fill_(0.0) + if fp8_mha: + q_, k_, v_, qkv_layout, _ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) + if is_training: + q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] + if bias_ is not None: + ndim = bias_.ndim + seq_q_dim = ndim - 2 + if qkv_format == "thd": + bias_seq_idx = seq_idx_q + else: + bias_seq_idx = seq_idx + shape_before_seq = bias_.shape[:seq_q_dim] + seq_q_size = bias_.shape[seq_q_dim] + seq_kv_size = bias_.shape[-1] + if seq_q_size == 1: + # TODO(KshitijLakhani): Set to True always once cuDNN supports dbias for 111s + bias_.requires_grad = False + # Bias is broadcast, no need to partition along sequence dimension + pass else: - fp8_context = nullcontext() - max_logit = None - with fp8_context: - # q, k, v, out in FP8; dout in F16 - out = core_attn( - q, - k, - v, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, + bias_ = bias_.view( + *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size ) - if config.return_max_logit: - out, max_logit = out - if is_training: - if fp8_bwd and fp8_mha: - dout_fp8 = dout_quantizer(dout) - out.backward(dout_fp8) - else: - out.backward(dout) + bias_ = bias_.index_select(seq_q_dim, bias_seq_idx) + bias_ = bias_.view(*shape_before_seq, -1, seq_kv_size) + bias_.requires_grad = True + # set up environment + core_attn.set_context_parallel_group( + cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, + cp_comm_ranks, + torch.cuda.Stream(), + cp_comm_type, + ) + if config.softmax_type != "vanilla": + core_attn.softmax_offset.grad.zero_() + if dtype == 
"fp8": + core_attn.fp8_initialized = False + core_attn.fp8_meta_tensors_initialized = False + fp8_context = autocast( + enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group + ) + else: + fp8_context = nullcontext() + + # run attention + max_logit_ = None + with fp8_context: + # q, k, v, out in FP8; dout in F16 + out_ = core_attn( + q_, + k_, + v_, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias=bias_, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, + fp8_output=fp8_mha, + ) + if config.return_max_logit: + out_, max_logit_ = out_ if is_training: - dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None - d_softmax_offset = ( - core_attn.softmax_offset.grad if config.softmax_type != "vanilla" else None - ) - else: - dq, dk, dv, dbias = None, None, None, None - d_softmax_offset = None - - ############ run with CP ############ - logging.info(f"[Rank {rank}] Run with context parallelism") - - # set up inputs - q_, k_, v_, dout_, *rest = [ - x.clone().detach() - for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias]) - ] - bias_ = rest[0] if len(rest) else None - if qkv_format == "bshd" or qkv_format == "sbhd": - seq_dim = qkv_format.index("s") - q_, k_, v_, dout_ = [ + if fp8_bwd and fp8_mha: + dout_fp8_ = dout_quantizer(dout_) + out_.backward(dout_fp8_) + else: + out_.backward(dout_) + if is_training: + dq_, dk_, dv_, dbias_ = ( + q_.grad, + k_.grad, + v_.grad, + bias_.grad if bias_ is not None else None, + ) + d_softmax_offset_ = ( + core_attn.softmax_offset.grad.clone() if config.softmax_type != "vanilla" else None + ) + else: + dq_, dk_, dv_, dbias_ = None, None, None, None + d_softmax_offset_ = None + + # get outputs + tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] + names = ["out", "dq", "dk", "dv", "dbias", "out_cp", "dq_cp", "dk_cp", "dv_cp", "dbias_cp"] + 
if fp8_mha: + tensors_to_deq = [out, out_] if not fp8_bwd else tensors + for i, tensor in enumerate(tensors_to_deq): + # dbias/dbias_ could be None, so skip check for it + if tensor is not None: + tensors_to_deq[i] = tensor.dequantize() + if not fp8_bwd: + tensors[0], tensors[5] = tensors_to_deq + for i, tensor in enumerate(tensors): + # dbias/dbias_ could be None, so skip check for it + if tensor is not None: + assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN" + assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf" + out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors + + ############ compare results between CP and no-CP ############ + if qkv_format == "bshd" or qkv_format == "sbhd": + if is_training: + dq, dk, dv, out = [ x.view( *x.shape[:seq_dim], 2 * world_size, x.shape[seq_dim] // (2 * world_size), *x.shape[(seq_dim + 1) :], ) - for x in [q_, k_, v_, dout_] + for x in [dq, dk, dv, out] ] - seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device) - q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] - q_, k_, v_, dout_ = [ - x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) - for x in [q_, k_, v_, dout_] + dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] + dq_, dk_, dv_, out_ = [ + x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) + for x in [dq_, dk_, dv_, out_] ] - elif qkv_format == "thd": - seq_idx_q = tex.thd_get_partitioned_indices( - cu_seqlens_q_padded, q_.shape[0], world_size, rank - ) - seq_idx_kv = tex.thd_get_partitioned_indices( - cu_seqlens_kv_padded, k_.shape[0], world_size, rank - ) - q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] - k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" 
- q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]] - if scaling_mode == "delayed": - qkv_quantizer.scale.fill_(1.0) - qkv_quantizer.amax.fill_(0.0) - dout_quantizer.scale.fill_(1.0) - dout_quantizer.amax.fill_(0.0) - if fp8_mha: - q_, k_, v_, qkv_layout, _ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) - if is_training: - q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] - if bias_ is not None: - ndim = bias_.ndim - seq_q_dim = ndim - 2 - if qkv_format == "thd": - bias_seq_idx = seq_idx_q - else: - bias_seq_idx = seq_idx - shape_before_seq = bias_.shape[:seq_q_dim] - seq_q_size = bias_.shape[seq_q_dim] - seq_kv_size = bias_.shape[-1] - if seq_q_size == 1: - # TODO(KshitijLakhani): Set to True always once cuDNN supports dbias for 111s - bias_.requires_grad = False - # Bias is broadcast, no need to partition along sequence dimension - pass - else: - bias_ = bias_.view( - *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size + if dbias is not None and dbias_ is not None: + ndim = dbias.ndim + # Query seq is at dim -2 + seq_q_dim = ndim - 2 + shape_before_seq = dbias.shape[:seq_q_dim] + seq_q_size = dbias.shape[seq_q_dim] + seq_kv_size = dbias.shape[-1] + # Reshape to split seq_q dimension + dbias = dbias.view( + *shape_before_seq, + 2 * world_size, + seq_q_size // (2 * world_size), + seq_kv_size, + ) + # Index select on the newly created dimension (now at position seq_q_dim) + dbias = dbias.index_select(seq_q_dim, seq_idx) + dbias_ = dbias_.view( + *shape_before_seq, 2, dbias_.shape[seq_q_dim] // 2, seq_kv_size ) - bias_ = bias_.index_select(seq_q_dim, bias_seq_idx) - bias_ = bias_.view(*shape_before_seq, -1, seq_kv_size) - bias_.requires_grad = True - # set up environment - core_attn.set_context_parallel_group( - cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, - cp_comm_ranks, - torch.cuda.Stream(), - cp_comm_type, - ) - if config.softmax_type != "vanilla": - 
core_attn.softmax_offset.grad.zero_() - if dtype == "fp8": - core_attn.fp8_initialized = False - core_attn.fp8_meta_tensors_initialized = False - fp8_context = autocast( - enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group - ) else: - fp8_context = nullcontext() - - # run attention - max_logit_ = None - with fp8_context: - # q, k, v, out in FP8; dout in F16 - out_ = core_attn( - q_, - k_, - v_, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias=bias_, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - fp8_output=fp8_mha, + # Forward-only: reshape only out/out_ for comparison + out = out.view( + *out.shape[:seq_dim], + 2 * world_size, + out.shape[seq_dim] // (2 * world_size), + *out.shape[(seq_dim + 1) :], ) - if config.return_max_logit: - out_, max_logit_ = out_ - if is_training: - if fp8_bwd and fp8_mha: - dout_fp8_ = dout_quantizer(dout_) - out_.backward(dout_fp8_) - else: - out_.backward(dout_) - if is_training: - dq_, dk_, dv_, dbias_ = ( - q_.grad, - k_.grad, - v_.grad, - bias_.grad if bias_ is not None else None, + out = out.index_select(seq_dim, seq_idx) + out_ = out_.view( + *out_.shape[:seq_dim], 2, out_.shape[seq_dim] // 2, *out_.shape[(seq_dim + 1) :] ) - d_softmax_offset_ = ( - core_attn.softmax_offset.grad.clone() if config.softmax_type != "vanilla" else None + + elif qkv_format == "thd": + if is_training: + dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] + dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] + dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] + cu_seqlens_q_padded = cu_seqlens_q_padded // world_size + cu_seqlens_q = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True ) - else: - dq_, dk_, dv_, dbias_ = None, None, None, None - d_softmax_offset_ = None - - # get outputs - tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, 
dv_, dbias_] - names = ["out", "dq", "dk", "dv", "dbias", "out_cp", "dq_cp", "dk_cp", "dv_cp", "dbias_cp"] - if fp8_mha: - tensors_to_deq = [out, out_] if not fp8_bwd else tensors - for i, tensor in enumerate(tensors_to_deq): - # dbias/dbias_ could be None, so skip check for it - if tensor is not None: - tensors_to_deq[i] = tensor.dequantize() - if not fp8_bwd: - tensors[0], tensors[5] = tensors_to_deq - for i, tensor in enumerate(tensors): - # dbias/dbias_ could be None, so skip check for it - if tensor is not None: - assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN" - assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf" - out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors - - ############ compare results between CP and no-CP ############ - if qkv_format == "bshd" or qkv_format == "sbhd": - if is_training: - dq, dk, dv, out = [ - x.view( - *x.shape[:seq_dim], - 2 * world_size, - x.shape[seq_dim] // (2 * world_size), - *x.shape[(seq_dim + 1) :], - ) - for x in [dq, dk, dv, out] - ] - dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] - dq_, dk_, dv_, out_ = [ - x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) - for x in [dq_, dk_, dv_, out_] - ] - if dbias is not None and dbias_ is not None: - ndim = dbias.ndim - # Query seq is at dim -2 - seq_q_dim = ndim - 2 - shape_before_seq = dbias.shape[:seq_q_dim] - seq_q_size = dbias.shape[seq_q_dim] - seq_kv_size = dbias.shape[-1] - # Reshape to split seq_q dimension - dbias = dbias.view( - *shape_before_seq, - 2 * world_size, - seq_q_size // (2 * world_size), - seq_kv_size, + cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q + num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] + for x in [dq, out, dq_, out_]: + assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_q[b] == 0 + or torch.count_nonzero( + x[ + ( + cu_seqlens_q_padded[b + 1] - num_pads_q[b] + ) : 
cu_seqlens_q_padded[b + 1] + ] + ).item() + == 0 ) - # Index select on the newly created dimension (now at position seq_q_dim) - dbias = dbias.index_select(seq_q_dim, seq_idx) - dbias_ = dbias_.view( - *shape_before_seq, 2, dbias_.shape[seq_q_dim] // 2, seq_kv_size + cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size + cu_seqlens_kv = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True + ) + cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv + num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] + for x in [dk, dv, dk_, dv_]: + assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 + for b in range(config.batch_size): + assert ( + num_pads_kv[b] == 0 + or torch.count_nonzero( + x[ + ( + cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] + ) : cu_seqlens_kv_padded[b + 1] + ] + ).item() + == 0 ) - else: - # Forward-only: reshape only out/out_ for comparison - out = out.view( - *out.shape[:seq_dim], - 2 * world_size, - out.shape[seq_dim] // (2 * world_size), - *out.shape[(seq_dim + 1) :], - ) - out = out.index_select(seq_dim, seq_idx) - out_ = out_.view( - *out_.shape[:seq_dim], 2, out_.shape[seq_dim] // 2, *out_.shape[(seq_dim + 1) :] - ) - - elif qkv_format == "thd": - if is_training: - dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] - dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] - dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] - cu_seqlens_q_padded = cu_seqlens_q_padded // world_size - cu_seqlens_q = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True - ) - cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q - num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] - for x in [dq, out, dq_, out_]: - assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_q[b] == 0 - or torch.count_nonzero( - x[ - ( - cu_seqlens_q_padded[b + 1] - num_pads_q[b] - ) : cu_seqlens_q_padded[b + 
1] - ] - ).item() - == 0 + else: + # Forward-only: reshape only out/out_ for comparison + out = out.index_select(0, seq_idx_q).contiguous() + out_ = out_ + + atol, rtol, rmse_tol = get_tols(config, dtype) + tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_] + tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit] + names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"] + names_cp = [x + "_cp" for x in names] + names_no_cp = [x + "_no_cp" for x in names] + is_fp8 = dtype == "fp8" + for i, t in enumerate(tensors_no_cp): + if t is not None: + if "softmax_offset" not in names[i] and "max_logit" not in names[i]: + if qkv_format == "bshd": + # Compare the two sequence chunks separately + # Compare dbias + if names[i] == "dbias": + # Compare the two chunks along dimension 2 (the split sequence dimension) + seq_q_dim_bias = 2 + ndim_bias = t.ndim + slice_0 = [slice(None)] * ndim_bias + slice_0[seq_q_dim_bias] = 0 + slice_1 = [slice(None)] * ndim_bias + slice_1[seq_q_dim_bias] = 1 + compare_and_assert( + t[tuple(slice_0)], + tensors_cp[i][tuple(slice_0)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, ) - cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size - cu_seqlens_kv = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True - ) - cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv - num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] - for x in [dk, dv, dk_, dv_]: - assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_kv[b] == 0 - or torch.count_nonzero( - x[ - ( - cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] - ) : cu_seqlens_kv_padded[b + 1] - ] - ).item() - == 0 + compare_and_assert( + t[tuple(slice_1)], + tensors_cp[i][tuple(slice_1)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, ) - else: - # Forward-only: reshape only out/out_ for comparison - 
out = out.index_select(0, seq_idx_q).contiguous() - out_ = out_ - - atol, rtol, rmse_tol = get_tols(config, dtype) - tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_] - tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit] - names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"] - names_cp = [x + "_cp" for x in names] - names_no_cp = [x + "_no_cp" for x in names] - is_fp8 = dtype == "fp8" - for i, t in enumerate(tensors_no_cp): - if t is not None: - if "softmax_offset" not in names[i] and "max_logit" not in names[i]: - if qkv_format == "bshd": - # Compare the two sequence chunks separately - # Compare dbias - if names[i] == "dbias": - # Compare the two chunks along dimension 2 (the split sequence dimension) - seq_q_dim_bias = 2 - ndim_bias = t.ndim - slice_0 = [slice(None)] * ndim_bias - slice_0[seq_q_dim_bias] = 0 - slice_1 = [slice(None)] * ndim_bias - slice_1[seq_q_dim_bias] = 1 - compare_and_assert( - t[tuple(slice_0)], - tensors_cp[i][tuple(slice_0)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[tuple(slice_1)], - tensors_cp[i][tuple(slice_1)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - # Compare Q/K/V/out - else: - # Compare the two chunks along dimension 1 (the split sequence dimension) - compare_and_assert( - t[:, 0], - tensors_cp[i][:, 0], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[:, 1], - tensors_cp[i][:, 1], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - elif qkv_format == "sbhd": - # Compare the two sequence chunks separately - # Compare dbias (same as BSHD) - if names[i] == "dbias": - # Same as bshd: Compare the two chunks along dimension 2 (the split sequence dimension) - seq_q_dim_bias = 2 - ndim_bias = t.ndim - slice_0 = [slice(None)] * ndim_bias - slice_0[seq_q_dim_bias] = 0 - slice_1 = [slice(None)] * ndim_bias 
- slice_1[seq_q_dim_bias] = 1 - compare_and_assert( - t[tuple(slice_0)], - tensors_cp[i][tuple(slice_0)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[tuple(slice_1)], - tensors_cp[i][tuple(slice_1)], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - # Compare Q/K/V/out - else: - # Compare the two chunks along dimension 0 (the split sequence dimension) - compare_and_assert( - t[0], - tensors_cp[i][0], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - compare_and_assert( - t[1], - tensors_cp[i][1], - names_no_cp[i], - names_cp[i], - atol, - rtol, - rmse_tol, - is_fp8, - ) - elif qkv_format == "thd": + # Compare Q/K/V/out + else: + # Compare the two chunks along dimension 1 (the split sequence dimension) compare_and_assert( - t, - tensors_cp[i], + t[:, 0], + tensors_cp[i][:, 0], names_no_cp[i], names_cp[i], atol, @@ -803,35 +724,109 @@ def run_dpa_with_cp( rmse_tol, is_fp8, ) - else: + compare_and_assert( + t[:, 1], + tensors_cp[i][:, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "sbhd": + # Compare the two sequence chunks separately + # Compare dbias (same as BSHD) + if names[i] == "dbias": + # Same as bshd: Compare the two chunks along dimension 2 (the split sequence dimension) + seq_q_dim_bias = 2 + ndim_bias = t.ndim + slice_0 = [slice(None)] * ndim_bias + slice_0[seq_q_dim_bias] = 0 + slice_1 = [slice(None)] * ndim_bias + slice_1[seq_q_dim_bias] = 1 + compare_and_assert( + t[tuple(slice_0)], + tensors_cp[i][tuple(slice_0)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[tuple(slice_1)], + tensors_cp[i][tuple(slice_1)], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + # Compare Q/K/V/out + else: + # Compare the two chunks along dimension 0 (the split sequence dimension) + compare_and_assert( + t[0], + tensors_cp[i][0], + 
names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[1], + tensors_cp[i][1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "thd": compare_and_assert( - t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 + t, + tensors_cp[i], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, ) - logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") - - finally: - # Destroy only groups WE created. In pool mode with shared groups, - # cp_comm_group / cp_comm_sub_groups are owned by the pool runner and - # destroyed at pool shutdown — touching them here would tear down the - # cache that subsequent cases want to reuse. - if not _reusing_pool_groups: - if cp_comm_group is not None: - try: - dist.destroy_process_group(cp_comm_group) - except Exception: - pass - for g in cp_comm_sub_groups: - try: - dist.destroy_process_group(g) - except Exception: - pass - # Tear down the main PG only in single-shot mode. In pool mode the - # pool runner owns the main PG and destroys it at shutdown. - if not _pool_managed_pg: + else: + compare_and_assert( + t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 + ) + logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") + + # Teardown on the success path. Pool mode: cp_comm_group / cp_comm_sub_groups + # point at pool-shared groups owned by the pool runner (which destroys them + # at pool shutdown), and the main PG is also pool-owned — both branches + # below are no-ops. Single-shot mode: destroy what we created here. If the + # body above raises, we skip this — the subprocess dies at function return + # and NCCL releases the communicators with the process. 
+ if not _reusing_pool_groups: + if cp_comm_group is not None: try: - dist.destroy_process_group() + dist.destroy_process_group(cp_comm_group) except Exception: pass + for g in cp_comm_sub_groups: + try: + dist.destroy_process_group(g) + except Exception: + pass + if not _pool_managed_pg: + try: + dist.destroy_process_group() + except Exception: + pass def main(**kwargs): From adb84af70a8e3a8e66beef20382194cc5442f44c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 May 2026 22:45:53 +0000 Subject: [PATCH 15/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../attention/run_attention_with_cp.py | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 1b715d892b..9f6b4944e6 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -285,9 +285,7 @@ def run_dpa_with_cp( if scaling_mode == "current": fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) if scaling_mode == "mxfp8": - fp8_recipe = MXFP8BlockScaling( - fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha - ) + fp8_recipe = MXFP8BlockScaling(fp8_format=Format.E4M3, fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) # instantiate attention module core_attn = DotProductAttention( @@ -391,9 +389,7 @@ def run_dpa_with_cp( ############ run without CP ############ logging.info(f"[Rank {rank}] Run without context parallelism") if dtype == "fp8": - fp8_context = autocast( - enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group - ) + fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) else: fp8_context = nullcontext() max_logit = None @@ -451,8 +447,7 @@ def run_dpa_with_cp( seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], 
device=q_.device) q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]] q_, k_, v_, dout_ = [ - x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) - for x in [q_, k_, v_, dout_] + x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_] ] elif qkv_format == "thd": seq_idx_q = tex.thd_get_partitioned_indices( @@ -509,9 +504,7 @@ def run_dpa_with_cp( if dtype == "fp8": core_attn.fp8_initialized = False core_attn.fp8_meta_tensors_initialized = False - fp8_context = autocast( - enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group - ) + fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) else: fp8_context = nullcontext() @@ -638,9 +631,9 @@ def run_dpa_with_cp( num_pads_q[b] == 0 or torch.count_nonzero( x[ - ( - cu_seqlens_q_padded[b + 1] - num_pads_q[b] - ) : cu_seqlens_q_padded[b + 1] + (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ + b + 1 + ] ] ).item() == 0