Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
ccf7321
Batch CP attention tests via a persistent NCCL pool
sudhakarsingh27 May 12, 2026
59609ac
Reset FP8 state and barrier between pool cases
sudhakarsingh27 May 12, 2026
73e8cef
Deep-copy ModelConfig in run_dpa_with_cp
sudhakarsingh27 May 12, 2026
311137c
Skip deterministic configs incompatible with FusedAttention
sudhakarsingh27 May 14, 2026
49878d6
Reseed RNG between pool cases; reset before, not after
sudhakarsingh27 May 14, 2026
385e966
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2026
86b334b
Robustify pool: capture worker stderr, tighten timeout, add timing knob
sudhakarsingh27 May 14, 2026
ae5298c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2026
e162a9e
Address PR review: NCCL leak, stdout protocol, Windows note
sudhakarsingh27 May 14, 2026
169be82
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2026
557bd80
[PyTorch] Fix stream race on max_logit_per_step in all-gather CP forward
sudhakarsingh27 May 15, 2026
4815883
Address PR review (R2): drop dead code in pool worker and PoolWorker
sudhakarsingh27 May 15, 2026
d15bfce
Address PR review (items 2+3): reuse CP groups across pool cases
sudhakarsingh27 May 15, 2026
dd1d802
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into sudha…
sudhakarsingh27 May 15, 2026
87c67ac
Flatten try/finally wrap in run_dpa_with_cp
sudhakarsingh27 May 15, 2026
adb84af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 79 additions & 23 deletions tests/pytorch/attention/run_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# See LICENSE for license information.

import copy
import os
import sys
import logging
Expand Down Expand Up @@ -29,6 +30,15 @@
)
from utils import ModelConfig, compare_and_assert

# Pool mode (NVTE_CP_POOL_PG=1) only: shared CP collective groups, created once
# per pool by run_attention_with_cp_pool.main() and reused across every case in
# that pool. world_size and the rank set don't change per case, so re-creating
# these per call would be wasted NCCL setup (~50-100 ms each). Single-shot
# subprocess mode leaves these None / [] and run_dpa_with_cp creates/destroys
# its own groups inline.
_pool_cp_comm_group = None
_pool_cp_comm_sub_groups: list = []

# Base torch dtype per test-dtype label. "fp8" cases run on a bf16 base dtype;
# FP8 execution is enabled separately via a recipe inside run_dpa_with_cp.
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}


Expand Down Expand Up @@ -209,10 +219,13 @@ def run_dpa_with_cp(
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
# Deep-copy: the module-level dict is shared across pool cases; the
# THD branch below rewrites attn_mask_type in place, which would
# otherwise leak into subsequent cases reusing the same model key.
config = copy.deepcopy(model_configs_flash_attn[model])
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
config = copy.deepcopy(model_configs_fused_attn[model])
assert config.attn_mask_type in [
"causal",
"no_mask",
Expand All @@ -226,6 +239,9 @@ def run_dpa_with_cp(
# set up distributed group
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
# When NVTE_CP_POOL_PG=1, the pool runner owns the lifecycle of the main
# process group across many cases; here we only reuse it.
_pool_managed_pg = os.getenv("NVTE_CP_POOL_PG", "0") == "1"
if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
Expand All @@ -234,25 +250,35 @@ def run_dpa_with_cp(
device = rank % device_count
torch.cuda.set_device(device)
logging.info(f"[Rank {rank}] Setup: world_size {world_size}")
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)

# set up communication group for CP
if not _pool_managed_pg:
dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)

# Set up communication group for CP. In pool mode, the pool worker has
# already pre-created world-scoped and a2a+p2p sub-groups once and stashed
# them in module-level pointers; we reuse those and the pool destroys them
# at shutdown. In single-shot mode we create them per call and destroy in
# the finally below.
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert world_size % 2 == 0, (
"{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size"
" = 2."
)
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)

_reusing_pool_groups = _pool_managed_pg and _pool_cp_comm_group is not None
cp_comm_group = None
cp_comm_sub_groups: list = []
if _reusing_pool_groups:
cp_comm_group = _pool_cp_comm_group
cp_comm_sub_groups = _pool_cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else []
else:
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert world_size % 2 == 0, (
"{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has"
" cp_size = 2."
)
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)
if dtype == "fp8":
if scaling_mode == "delayed":
fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
Expand Down Expand Up @@ -564,7 +590,10 @@ def run_dpa_with_cp(
seq_kv_size = dbias.shape[-1]
# Reshape to split seq_q dimension
dbias = dbias.view(
*shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size
*shape_before_seq,
2 * world_size,
seq_q_size // (2 * world_size),
seq_kv_size,
)
# Index select on the newly created dimension (now at position seq_q_dim)
dbias = dbias.index_select(seq_q_dim, seq_idx)
Expand Down Expand Up @@ -754,16 +783,43 @@ def run_dpa_with_cp(
)
elif qkv_format == "thd":
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
t,
tensors_cp[i],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
else:
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
)
logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches")

# destroy distribution group
dist.destroy_process_group()
# Teardown on the success path. Pool mode: cp_comm_group / cp_comm_sub_groups
# point at pool-shared groups owned by the pool runner (which destroys them
# at pool shutdown), and the main PG is also pool-owned — both branches
# below are no-ops. Single-shot mode: destroy what we created here. If the
# body above raises, we skip this — the subprocess dies at function return
# and NCCL releases the communicators with the process.
if not _reusing_pool_groups:
if cp_comm_group is not None:
try:
dist.destroy_process_group(cp_comm_group)
except Exception:
pass
for g in cp_comm_sub_groups:
try:
dist.destroy_process_group(g)
except Exception:
pass
if not _pool_managed_pg:
try:
dist.destroy_process_group()
except Exception:
pass


def main(**kwargs):
Expand Down
198 changes: 198 additions & 0 deletions tests/pytorch/attention/run_attention_with_cp_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
Persistent worker for batched CP attention tests.

Launched ONCE per (pytest session, world_size) by torchrun. All ranks init
NCCL, then enter a dispatch loop:

rank 0:
read one JSON request line from stdin
broadcast it to all ranks
all ranks:
call run_dpa_with_cp(**kwargs) — the same work function the
per-case subprocess design uses, with NVTE_CP_POOL_PG=1 so the
function reuses our PG instead of re-initing it
torch.cuda.empty_cache() per case
all ranks gather (ok, error_msg) to rank 0
rank 0:
write one JSON response line to stdout

Protocol (line-delimited JSON over rank-0 stdio):
request : {"op": "run", "kwargs": {...}}
{"op": "shutdown"}
response: {"ok": true}
{"ok": false, "error": "first failing rank's traceback"}
"""
import json
import os
import sys
import time
import traceback

import torch
import torch.distributed as dist

# Make sibling modules importable when launched directly.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from run_attention_with_cp import run_dpa_with_cp
from transformer_engine.pytorch.quantization import FP8GlobalStateManager


def _recv_request(rank: int) -> dict:
    """Receive one protocol request on every rank.

    Rank 0 reads a single JSON line from stdin; the parsed dict is then
    broadcast so all ranks act on the same request. EOF on stdin (parent
    closed the pipe) is translated into a ``{"op": "shutdown"}`` request so
    the pool winds down cleanly instead of blocking forever.
    """
    payload: list = [None]
    if rank == 0:
        raw = sys.stdin.readline()
        payload[0] = json.loads(raw) if raw else {"op": "shutdown"}
    dist.broadcast_object_list(payload, src=0)
    return payload[0]


# Every response line carries this sentinel so the parent's reader can discard
# interleaved stdout noise (torchrun status lines, library prints, stray output
# from other ranks — all torchrun ranks share rank 0's stdout fd).
_RESP_PREFIX = "[CP_POOL_RESP] "


def _send_response(rank: int, payload: dict) -> None:
    """Emit one prefixed JSON response line on rank 0; no-op on other ranks."""
    if rank != 0:
        return
    # print(..., flush=True) == write + "\n" + flush on sys.stdout.
    print(_RESP_PREFIX + json.dumps(payload), flush=True)
Comment on lines +60 to +63
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 stdout pollution can silently corrupt the JSON protocol

torchrun (and the worker processes it spawns for ranks 1–N) all inherit the same stdout file descriptor as rank 0. If torchrun writes any status line to stdout, or if any non-rank-0 worker accidentally prints (e.g. via a print call in a library, NCCL debug output, or a Python warning), those bytes are interleaved with rank 0's JSON responses. The parent's readline() in PoolWorker.submit would then receive a non-JSON line and raise a json.JSONDecodeError, killing the pool and failing the test with a misleading error message.

Consider redirecting torchrun's own output or adding a sentinel prefix to every response line so the reader can skip unrecognised lines.



def _reset_between_cases() -> None:
    """Restore per-case invariants so state cannot cascade across pool cases.

    Mirrors the per-process startup of the single-shot subprocess worker
    (``_run_single_config`` on the per-case-subprocess branch): every case
    starts from the identical RNG seed (1234), cleared FP8 global state, and
    a flushed CUDA caching allocator. No env-var cleanup is done here because
    ``run_dpa_with_cp`` re-sets ``NVTE_FUSED_ATTN``/``NVTE_FLASH_ATTN``
    unconditionally and pops its other transient env vars itself.
    """
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)
    FP8GlobalStateManager.reset()
    torch.cuda.empty_cache()


# Running count of cases executed on this rank, used only for the opt-in
# timing report below.
_case_counter = 0


def _run_one(req: dict, rank: int) -> tuple[bool, str]:
    """Execute one dispatched pool request on this rank.

    Returns ``(ok, err)``: ``ok`` is False on an unknown op or when
    ``run_dpa_with_cp`` raised, and ``err`` then carries the rank-prefixed
    traceback (or an "unknown op" message); on success ``err`` is ``""``.
    """
    global _case_counter
    op = req["op"]
    if op != "run":
        return False, f"unknown op: {op}"
    # Reset BEFORE the case so even the first case starts from a known RNG
    # seed and clean FP8 state — same as the single-shot worker's startup.
    _reset_between_cases()
    started = time.monotonic()
    succeeded = True
    failure = ""
    try:
        run_dpa_with_cp(**req.get("kwargs", {}))
    except Exception:
        succeeded = False
        failure = f"[Rank {rank}] {traceback.format_exc()}"
    elapsed = time.monotonic() - started
    # Opt-in (NVTE_CP_POOL_TIMING=1) per-case wall time on rank 0, written to
    # stderr so the stdout JSON protocol stays clean. Used to tune
    # POOL_SUBMIT_TIMEOUT_SEC against the observed case-duration distribution.
    if rank == 0 and int(os.environ.get("NVTE_CP_POOL_TIMING", "0")):
        _case_counter += 1
        sys.stderr.write(
            f"[POOL-TIMING] case_idx={_case_counter} "
            f"world_size={int(os.environ.get('WORLD_SIZE', 0))} "
            f"wall_s={elapsed:.3f} ok={succeeded}\n"
        )
        sys.stderr.flush()
    return succeeded, failure


def _create_cp_comm_groups(rank: int, world_size: int) -> tuple:
    """Create the pool-shared CP collective groups exactly once.

    world_size and the rank set never change for the lifetime of one pool,
    so the world-scoped group and the a2a+p2p sub-groups are deterministic.
    Building them here and reusing them for every case avoids ~50-100 ms of
    NCCL setup per case (cyanguwa's review feedback on PR #2993).

    Returns ``(world_group, a2a_p2p_sub_groups)``. The sub-group list stays
    empty when world_size cannot support a2a+p2p (an even world_size >= 4 is
    required); cases with cp_comm_type='a2a+p2p' wouldn't be routed to such
    a pool anyway.
    """
    world_group = dist.new_group(range(world_size), backend="nccl")
    sub_groups: list = []
    if world_size % 2 == 0 and world_size >= 4:
        # Same layout as run_attention_with_cp.py: world_size/2 adjacent
        # pairs along axis 0, then 2 stride-2 groups along axis 1.
        layouts = [range(2 * i, 2 * i + 2) for i in range(world_size // 2)]
        layouts += [range(start, world_size, 2) for start in (0, 1)]
        # dist.new_group is collective: every rank must create every group,
        # in the same order, but only keeps handles to groups it belongs to.
        for member_ranks in layouts:
            group = dist.new_group(member_ranks, backend="nccl")
            if rank in member_ranks:
                sub_groups.append(group)
    return world_group, sub_groups


def main() -> None:
    """Pool worker entry point: init NCCL once, then serve the dispatch loop.

    Reads torchrun-style ``RANK`` / ``WORLD_SIZE`` env vars, binds one GPU
    per rank (round-robin over visible devices), and sets
    ``NVTE_CP_POOL_PG=1`` so ``run_dpa_with_cp`` reuses this process group
    instead of creating/destroying its own. Loops over requests (see module
    docstring for the stdio protocol) until a shutdown request or stdin EOF.
    """
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(rank % torch.cuda.device_count())
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    os.environ["NVTE_CP_POOL_PG"] = "1"

    # Stash pool-shared CP groups on the run_attention_with_cp module so
    # run_dpa_with_cp can read them per case. Imported here (after the env var
    # is set) to keep import-time side effects minimal.
    import run_attention_with_cp as _rac

    _rac._pool_cp_comm_group, _rac._pool_cp_comm_sub_groups = _create_cp_comm_groups(
        rank, world_size
    )

    try:
        while True:
            req = _recv_request(rank)
            if req.get("op") == "shutdown":
                break

            ok, msg = _run_one(req, rank)

            gathered: list[tuple[bool, str]] = [None] * world_size  # type: ignore[list-item]
            # gather_object is itself a collective synchronization point — if
            # every rank reached it, none is ahead. No extra barrier needed.
            dist.gather_object((ok, msg), gathered if rank == 0 else None, dst=0)

            if rank == 0:
                all_ok = all(o for o, _ in gathered)
                if all_ok:
                    _send_response(rank, {"ok": True})
                else:
                    # Only the first failing rank's message is reported; the
                    # remaining ranks' errors are dropped.
                    first_err = next(m for o, m in gathered if not o)
                    _send_response(rank, {"ok": False, "error": first_err})
    finally:
        # Tear down pool-shared CP groups before the main PG (NCCL requires
        # sub-groups to be destroyed first). Each destroy is independently
        # guarded so a wedged communicator on one group doesn't leak the rest.
        if _rac._pool_cp_comm_group is not None:
            try:
                dist.destroy_process_group(_rac._pool_cp_comm_group)
            except Exception:
                pass
        for g in _rac._pool_cp_comm_sub_groups:
            try:
                dist.destroy_process_group(g)
            except Exception:
                pass
        _rac._pool_cp_comm_group = None
        _rac._pool_cp_comm_sub_groups = []
        dist.destroy_process_group()


# Script entry: launched under torchrun (see module docstring), one process
# per rank.
if __name__ == "__main__":
    main()
Loading
Loading