diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 8dfea644a5..9f6b4944e6 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. +import copy import os import sys import logging @@ -29,6 +30,15 @@ ) from utils import ModelConfig, compare_and_assert +# Pool mode (NVTE_CP_POOL_PG=1) only: shared CP collective groups, created once +# per pool by run_attention_with_cp_pool.main() and reused across every case in +# that pool. world_size and the rank set don't change per case, so re-creating +# these per call would be wasted NCCL setup (~50-100 ms each). Single-shot +# subprocess mode leaves these None / [] and run_dpa_with_cp creates/destroys +# its own groups inline. +_pool_cp_comm_group = None +_pool_cp_comm_sub_groups: list = [] + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -209,10 +219,13 @@ def run_dpa_with_cp( os.environ["NVTE_FUSED_ATTN"] = "0" if kernel_backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" - config = model_configs_flash_attn[model] + # Deep-copy: the module-level dict is shared across pool cases; the + # THD branch below rewrites attn_mask_type in place, which would + # otherwise leak into subsequent cases reusing the same model key. + config = copy.deepcopy(model_configs_flash_attn[model]) if kernel_backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" - config = model_configs_fused_attn[model] + config = copy.deepcopy(model_configs_fused_attn[model]) assert config.attn_mask_type in [ "causal", "no_mask", @@ -226,6 +239,9 @@ def run_dpa_with_cp( # set up distributed group rank = int(os.getenv("RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "1")) + # When NVTE_CP_POOL_PG=1, the pool runner owns the lifecycle of the main + # process group across many cases; here we only reuse it. + _pool_managed_pg = os.getenv("NVTE_CP_POOL_PG", "0") == "1" if dist.is_initialized(): world_size = dist.get_world_size() rank = dist.get_rank() @@ -234,25 +250,35 @@ def run_dpa_with_cp( device = rank % device_count torch.cuda.set_device(device) logging.info(f"[Rank {rank}] Setup: world_size {world_size}") - dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) - - # set up communication group for CP + if not _pool_managed_pg: + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + + # Set up communication group for CP. In pool mode, the pool worker has + # already pre-created world-scoped and a2a+p2p sub-groups once and stashed + # them in module-level pointers; we reuse those and the pool destroys them + # at shutdown. In single-shot mode we create them per call and destroy in + # the finally below. cp_comm_ranks = range(world_size) assert rank in cp_comm_ranks - cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") - if cp_comm_type == "a2a+p2p": - assert world_size % 2 == 0, ( - "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size" - " = 2." - ) - cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] - cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] - cp_comm_sub_groups = [] - for sub_ranks in cp_comm_sub_ranks: - sub_group = dist.new_group(sub_ranks, backend="nccl") - if rank in sub_ranks: - cp_comm_sub_groups.append(sub_group) - + _reusing_pool_groups = _pool_managed_pg and _pool_cp_comm_group is not None + cp_comm_group = None + cp_comm_sub_groups: list = [] + if _reusing_pool_groups: + cp_comm_group = _pool_cp_comm_group + cp_comm_sub_groups = _pool_cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else [] + else: + cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + if cp_comm_type == "a2a+p2p": + assert world_size % 2 == 0, ( + "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has" + " cp_size = 2." + ) + cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] + cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] + for sub_ranks in cp_comm_sub_ranks: + sub_group = dist.new_group(sub_ranks, backend="nccl") + if rank in sub_ranks: + cp_comm_sub_groups.append(sub_group) if dtype == "fp8": if scaling_mode == "delayed": fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) @@ -564,7 +590,10 @@ def run_dpa_with_cp( seq_kv_size = dbias.shape[-1] # Reshape to split seq_q dimension dbias = dbias.view( - *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size + *shape_before_seq, + 2 * world_size, + seq_q_size // (2 * world_size), + seq_kv_size, ) # Index select on the newly created dimension (now at position seq_q_dim) dbias = dbias.index_select(seq_q_dim, seq_idx) @@ -754,7 +783,14 @@ def run_dpa_with_cp( ) elif qkv_format == "thd": compare_and_assert( - t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 + t, + tensors_cp[i], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, ) else: compare_and_assert( @@ -762,8 +798,28 @@ def run_dpa_with_cp( ) logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") - # destroy distribution group - dist.destroy_process_group() + # Teardown on the success path. Pool mode: cp_comm_group / cp_comm_sub_groups + # point at pool-shared groups owned by the pool runner (which destroys them + # at pool shutdown), and the main PG is also pool-owned — both branches + # below are no-ops. Single-shot mode: destroy what we created here. If the + # body above raises, we skip this — the subprocess dies at function return + # and NCCL releases the communicators with the process. + if not _reusing_pool_groups: + if cp_comm_group is not None: + try: + dist.destroy_process_group(cp_comm_group) + except Exception: + pass + for g in cp_comm_sub_groups: + try: + dist.destroy_process_group(g) + except Exception: + pass + if not _pool_managed_pg: + try: + dist.destroy_process_group() + except Exception: + pass def main(**kwargs): diff --git a/tests/pytorch/attention/run_attention_with_cp_pool.py b/tests/pytorch/attention/run_attention_with_cp_pool.py new file mode 100644 index 0000000000..767437cb04 --- /dev/null +++ b/tests/pytorch/attention/run_attention_with_cp_pool.py @@ -0,0 +1,198 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Persistent worker for batched CP attention tests. + +Launched ONCE per (pytest session, world_size) by torchrun. All ranks init +NCCL, then enter a dispatch loop: + + rank 0: + read one JSON request line from stdin + broadcast it to all ranks + all ranks: + call run_dpa_with_cp(**kwargs) — the same work function the + per-case subprocess design uses, with NVTE_CP_POOL_PG=1 so the + function reuses our PG instead of re-initing it + torch.cuda.empty_cache() per case + all ranks gather (ok, error_msg) to rank 0 + rank 0: + write one JSON response line to stdout + +Protocol (line-delimited JSON over rank-0 stdio): + request : {"op": "run", "kwargs": {...}} + {"op": "shutdown"} + response: {"ok": true} + {"ok": false, "error": "first failing rank's traceback"} +""" +import json +import os +import sys +import time +import traceback + +import torch +import torch.distributed as dist + +# Make sibling modules importable when launched directly. +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from run_attention_with_cp import run_dpa_with_cp +from transformer_engine.pytorch.quantization import FP8GlobalStateManager + + +def _recv_request(rank: int) -> dict: + box = [None] + if rank == 0: + line = sys.stdin.readline() + box[0] = {"op": "shutdown"} if not line else json.loads(line) + dist.broadcast_object_list(box, src=0) + return box[0] + + +# Sentinel prefix on every response line so the parent reader can skip any +# stdout chatter that gets interleaved (torchrun status, library prints, even +# non-rank-0 stray output — torchrun ranks share rank 0's stdout fd). +_RESP_PREFIX = "[CP_POOL_RESP] " + + +def _send_response(rank: int, payload: dict) -> None: + if rank == 0: + sys.stdout.write(_RESP_PREFIX + json.dumps(payload) + "\n") + sys.stdout.flush() + + +def _reset_between_cases() -> None: + """Drop state that would otherwise cascade across cases. + + Matches the per-case startup of the single-shot worker + (``_run_single_config`` on the per-case-subprocess branch): identical RNG + seed at the start of every case, FP8 state cleared, allocator clean. + ``run_dpa_with_cp`` re-sets ``NVTE_FUSED_ATTN``/``NVTE_FLASH_ATTN`` + unconditionally and pops the other transient env vars itself, so no + explicit pop is needed here. + """ + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + FP8GlobalStateManager.reset() + torch.cuda.empty_cache() + + +_case_counter = 0 + + +def _run_one(req: dict, rank: int) -> tuple[bool, str]: + global _case_counter + op = req["op"] + if op != "run": + return False, f"unknown op: {op}" + # Reset BEFORE the case so the first case also starts from a known RNG seed + # and clean FP8 state — same as the single-shot worker's per-process startup. + _reset_between_cases() + t0 = time.monotonic() + ok = True + err = "" + try: + run_dpa_with_cp(**req.get("kwargs", {})) + except Exception: + ok = False + err = f"[Rank {rank}] {traceback.format_exc()}" + wall = time.monotonic() - t0 + # Per-case wall time on rank 0, opt-in via NVTE_CP_POOL_TIMING=1. + # Used to tune POOL_SUBMIT_TIMEOUT_SEC against the observed distribution. + if rank == 0 and int(os.environ.get("NVTE_CP_POOL_TIMING", "0")): + _case_counter += 1 + sys.stderr.write( + f"[POOL-TIMING] case_idx={_case_counter} " + f"world_size={int(os.environ.get('WORLD_SIZE', 0))} " + f"wall_s={wall:.3f} ok={ok}\n" + ) + sys.stderr.flush() + return ok, err + + +def _create_cp_comm_groups(rank: int, world_size: int) -> tuple: + """Pre-create the CP collective groups for this pool. + + world_size and the rank set are constant for the lifetime of one pool, so + the world group and the a2a+p2p sub-groups are deterministic. Creating + them once here and reusing them across every case eliminates ~50-100 ms + of NCCL setup per case (cyanguwa's review feedback on PR #2993). + + Returns ``(world_group, a2a_p2p_sub_groups)``. ``a2a_p2p_sub_groups`` is + empty when world_size is too small to support a2a+p2p (needs an even + world_size ≥ 4); cases with cp_comm_type='a2a+p2p' wouldn't be routed to + such a pool anyway. + """ + world_group = dist.new_group(range(world_size), backend="nccl") + sub_groups: list = [] + if world_size >= 4 and world_size % 2 == 0: + # Mirror the layout in run_attention_with_cp.py: cp_size/2 pairs along + # axis 0, plus 2 stride-2 groups along axis 1. + cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] + cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] + for sub_ranks in cp_comm_sub_ranks: + sub_group = dist.new_group(sub_ranks, backend="nccl") + if rank in sub_ranks: + sub_groups.append(sub_group) + return world_group, sub_groups + + +def main() -> None: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + torch.cuda.set_device(rank % torch.cuda.device_count()) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + os.environ["NVTE_CP_POOL_PG"] = "1" + + # Stash pool-shared CP groups on the run_attention_with_cp module so + # run_dpa_with_cp can read them per case. Imported here (after the env var + # is set) to keep import-time side effects minimal. + import run_attention_with_cp as _rac + + _rac._pool_cp_comm_group, _rac._pool_cp_comm_sub_groups = _create_cp_comm_groups( + rank, world_size + ) + + try: + while True: + req = _recv_request(rank) + if req.get("op") == "shutdown": + break + + ok, msg = _run_one(req, rank) + + gathered: list[tuple[bool, str]] = [None] * world_size # type: ignore[list-item] + # gather_object is itself a collective synchronization point — if + # every rank reached it, none is ahead. No extra barrier needed. + dist.gather_object((ok, msg), gathered if rank == 0 else None, dst=0) + + if rank == 0: + all_ok = all(o for o, _ in gathered) + if all_ok: + _send_response(rank, {"ok": True}) + else: + first_err = next(m for o, m in gathered if not o) + _send_response(rank, {"ok": False, "error": first_err}) + finally: + # Tear down pool-shared CP groups before the main PG (NCCL requires + # sub-groups to be destroyed first). Each destroy is independently + # guarded so a wedged communicator on one group doesn't leak the rest. + if _rac._pool_cp_comm_group is not None: + try: + dist.destroy_process_group(_rac._pool_cp_comm_group) + except Exception: + pass + for g in _rac._pool_cp_comm_sub_groups: + try: + dist.destroy_process_group(g) + except Exception: + pass + _rac._pool_cp_comm_group = None + _rac._pool_cp_comm_sub_groups = [] + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 23d1bfdd85..9356df20cf 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -2,12 +2,17 @@ # # See LICENSE for license information. +import json import os +import select import subprocess import sys +import threading +import time import pathlib import logging import copy +from collections import deque import pytest import torch from transformer_engine.pytorch import ( @@ -24,7 +29,7 @@ _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) -from utils import ModelConfig, get_available_attention_backends, run_distributed +from utils import ModelConfig, get_available_attention_backends pytest_logging_level = logging.getLevelName(logging.root.level) @@ -39,7 +44,7 @@ torch.manual_seed(seed) torch.cuda.manual_seed(seed) -test_essential = True +test_essential = False model_configs_flash_attn = { # test: ModelConfig(b, sq, hq, dqk) @@ -60,19 +65,188 @@ } -def get_bash_arguments(num_gpus_per_node, **kwargs): - args = [ - "python3", - "-m", - "torch.distributed.launch", - "--nproc-per-node=" + str(num_gpus_per_node), - ] - te_path = os.getenv("TE_PATH", "/opt/transformerengine") - script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py") - args.append(script_path) - for k, v in kwargs.items(): - args.append(f"{k}={v}") - return args +# --- Persistent pool runner ----------------------------------------------- +# +# Each (world_size) is served by one long-lived torchrun running +# run_attention_with_cp_pool.py. We submit one work item per pytest case over +# rank-0 stdin and read one JSON response from rank-0 stdout. Replaces +# the per-case torchrun launch path; init/destroy NCCL once per pool, not +# once per case. +# +# Why two pool sizes: cp_comm_type="a2a+p2p" needs world_size=4; everything +# else uses world_size=2. We can't resize an active PG, so we keep one pool +# per world_size and route each case to the right one. Pools are spawned +# lazily on first use so a session that only exercises 2-GPU cases never +# pays the 4-GPU init cost. + +# Per-case wall is ~5 s p50 / ~15 s max on H100 (test_essential=True). +# 90 s gives ~6× headroom over the slowest observed case while still detecting +# a genuine hang within ~1.5 min instead of ~10 min. Override with the env var +# if a slower machine or expanded test matrix needs more room. +POOL_SUBMIT_TIMEOUT_SEC = float(os.getenv("NVTE_CP_POOL_TIMEOUT_SEC", "90")) + + +class PoolWorker: + # Crash-path AssertionErrors include the tail of the worker's stderr so CI + # JUnit XML shows the actual failure cause (NCCL/CUDA messages, Python + # traceback) inline with the failing test, not just "pool worker died". + # Equivalent in spirit to PR #2965's run_distributed() stderr capture. + _STDERR_BUFFER_LINES = 200 # ring cap (~40 KB ceiling) + _STDERR_TAIL_CHARS = 4000 # how much to attach to the AssertionError + # Worker prefixes every JSON response with this sentinel; the reader skips + # any other stdout content (torchrun status, library prints, rank>0 stray + # output that lands on the shared stdout fd) so the JSON protocol can't + # be corrupted by interleaved chatter. + _RESP_PREFIX = "[CP_POOL_RESP] " + + def __init__(self, world_size: int): + self.world_size = world_size + self.proc: subprocess.Popen | None = None + self._stderr_buf: deque[str] = deque(maxlen=self._STDERR_BUFFER_LINES) + + def _spawn(self) -> None: + te_path = os.getenv("TE_PATH", "/opt/transformerengine") + worker = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp_pool.py") + cmd = [ + sys.executable, + "-m", + "torch.distributed.run", + f"--nproc-per-node={self.world_size}", + "--standalone", # picks a free rendezvous port + worker, + ] + # stderr=PIPE so we can capture the tail for crash-path AssertionErrors; + # a daemon drainer thread also echoes each line to sys.stderr so pytest's + # per-test stderr capture still works. The thread is daemon, so it + # self-terminates when the pipe closes — no tracking needed. + self.proc = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + env={**os.environ, "PYTHONUNBUFFERED": "1"}, + ) + self._stderr_buf.clear() + threading.Thread(target=self._drain_stderr, daemon=True).start() + + def _drain_stderr(self) -> None: + proc = self.proc + if proc is None or proc.stderr is None: + return + for line in iter(proc.stderr.readline, ""): + self._stderr_buf.append(line) + sys.stderr.write(line) + sys.stderr.flush() + + def _diag(self, msg: str) -> str: + tail = "".join(self._stderr_buf)[-self._STDERR_TAIL_CHARS :] + if not tail.strip(): + return msg + return f"{msg}\n\n--- pool worker stderr (tail) ---\n{tail}" + + def _ensure_alive(self) -> None: + if self.proc is None or self.proc.poll() is not None: + self._spawn() + + def _kill(self) -> None: + if self.proc and self.proc.poll() is None: + self.proc.terminate() + try: + self.proc.wait(timeout=5) + except subprocess.TimeoutExpired: + self.proc.kill() + self.proc.wait() + self.proc = None + + def submit(self, kwargs: dict, timeout: float = POOL_SUBMIT_TIMEOUT_SEC) -> None: + self._ensure_alive() + req = json.dumps({"op": "run", "kwargs": kwargs}) + "\n" + try: + self.proc.stdin.write(req) + self.proc.stdin.flush() + except BrokenPipeError: + msg = self._diag("pool worker died before request could be sent") + self._kill() + raise AssertionError(msg) + + # Read lines until we see our sentinel-prefixed JSON response. Anything + # else (torchrun status output, library prints, rank>0 stray writes that + # land on the shared stdout fd) is echoed to stderr and skipped — it + # would corrupt json.loads if we tried to decode it. The select() + # deadline alone bounds the loop; no separate line counter needed. + deadline = time.monotonic() + timeout + while True: + remaining = max(0.0, deadline - time.monotonic()) + # select() on a pipe fd is Linux/macOS only — on Windows the select + # module only accepts sockets. CP attention tests run on Linux GPU + # hosts so this is fine; flag if portability is ever needed. + ready, _, _ = select.select([self.proc.stdout], [], [], remaining) + if not ready: + msg = self._diag( + f"pool worker (world_size={self.world_size}) timed out after " + f"{timeout}s; pool killed and will be respawned for the next case" + ) + self._kill() + raise AssertionError(msg) + + line = self.proc.stdout.readline() + if not line: + msg = self._diag("pool worker died mid-request") + self._kill() + raise AssertionError(msg) + + if line.startswith(self._RESP_PREFIX): + resp_line = line[len(self._RESP_PREFIX) :] + break + # Non-protocol stdout — echo to stderr for CI visibility, keep looking. + sys.stderr.write(line) + sys.stderr.flush() + + resp = json.loads(resp_line) + if not resp["ok"]: + # gather_object already carries the full per-rank traceback in + # resp["error"]; no need to also attach stderr (would duplicate). + raise AssertionError(resp["error"]) + + def shutdown(self) -> None: + if self.proc and self.proc.poll() is None: + try: + self.proc.stdin.write(json.dumps({"op": "shutdown"}) + "\n") + self.proc.stdin.flush() + self.proc.stdin.close() + except BrokenPipeError: + pass + try: + self.proc.wait(timeout=30) + except subprocess.TimeoutExpired: + self._kill() + self.proc = None + + +@pytest.fixture(scope="session") +def cp_pool(): + """Returns a callable: cp_pool(world_size) -> PoolWorker.""" + pools: dict[int, PoolWorker] = {} + + def _get(world_size: int) -> PoolWorker: + if world_size > torch.cuda.device_count(): + pytest.skip(f"Test requires {world_size} GPUs, but found {torch.cuda.device_count()}") + if world_size not in pools: + pools[world_size] = PoolWorker(world_size) + return pools[world_size] + + yield _get + for p in pools.values(): + p.shutdown() + + +def _submit(pool: PoolWorker, **kwargs) -> None: + # run_dpa_with_cp expects all kwargs as strings (it does e.g. + # `fp8_bwd == "True"`), matching the old argv-based path. Serialize + # everything as strings so we don't accidentally change semantics. + pool.submit({k: str(v) for k, v in kwargs.items()}) dtypes = ["bf16", "fp16"] @@ -91,10 +265,9 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) -def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): +def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 - if num_gpus > torch.cuda.device_count(): - pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + pool = cp_pool(num_gpus) config = model_configs_flash_attn[model] config.context_parallel = True @@ -140,16 +313,14 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): if not flash_attn_supported: pytest.skip("No attention backend available.") - run_distributed( - get_bash_arguments( - num_gpus_per_node=num_gpus, - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend="FlashAttention", - cp_comm_type=cp_comm_type, - log_level=pytest_logging_level, - ), + _submit( + pool, + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FlashAttention", + cp_comm_type=cp_comm_type, + log_level=pytest_logging_level, ) @@ -274,15 +445,23 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("scaling_mode", [None, "delayed", "current", "mxfp8"]) @pytest.mark.parametrize("f16_O", [True, False]) def test_cp_with_fused_attention( - dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O + cp_pool, + dtype, + model, + qkv_format, + cp_comm_type, + fp8_bwd, + fp8_mha, + fp8_dpa, + scaling_mode, + f16_O, ): config = model_configs_fused_attn[model] config.context_parallel = True config.cp_comm_type = cp_comm_type num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 - if num_gpus > torch.cuda.device_count(): - pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()} GPUs.") + pool = cp_pool(num_gpus) if get_device_compute_capability() < (9, 0) and qkv_format == "thd": pytest.skip("Only sm90+ architectures support THD format!") @@ -404,21 +583,24 @@ def test_cp_with_fused_attention( if not fused_attn_supported: pytest.skip("No attention backend available.") - run_distributed( - get_bash_arguments( - num_gpus_per_node=num_gpus, - dtype=dtype, - model=model, - qkv_format=qkv_format, - kernel_backend="FusedAttention", - cp_comm_type=cp_comm_type, - fp8_bwd=fp8_bwd, - fp8_dpa=fp8_dpa, - fp8_mha=fp8_mha, - scaling_mode=scaling_mode, - f16_O=f16_O, - is_training=is_training, - deterministic=_deterministic, - log_level=pytest_logging_level, - ), + if _deterministic and config.softmax_type != "vanilla": + pytest.skip("Deterministic mode does not support non-vanilla softmax with FusedAttention") + if _deterministic and config.attn_bias_type == "post_scale_bias" and is_training: + pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") + + _submit( + pool, + dtype=dtype, + model=model, + qkv_format=qkv_format, + kernel_backend="FusedAttention", + cp_comm_type=cp_comm_type, + fp8_bwd=fp8_bwd, + fp8_dpa=fp8_dpa, + fp8_mha=fp8_mha, + scaling_mode=scaling_mode, + f16_O=f16_O, + is_training=is_training, + deterministic=_deterministic, + log_level=pytest_logging_level, ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 3db0417bdb..35684625a5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -3277,6 +3277,12 @@ def forward( elif o_format == "sbhd": out_f16[i - 1].copy_(out_per_step[i - 1]) if return_max_logit: + # max_logit_per_step[i-1] was written on flash_attn_streams[i-1] + # (cp_stream for i-1=1). The torch.maximum below runs on the + # default stream, so without this wait the read can race with + # the write. The post-loop wait_stream(cp_stream) is too late. + # No-op when flash_attn_streams[i-1] is current_stream(). + torch.cuda.current_stream().wait_stream(flash_attn_streams[i - 1]) max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) torch.cuda.current_stream().wait_stream(cp_stream)