CP Tests batching using subprocess worker pool #2993
Open
sudhakarsingh27 wants to merge 16 commits into NVIDIA:main from sudhakarsingh27:sudhakars/cp_batching_pool
+513 −71
Commits (16)
ccf7321  Batch CP attention tests via a persistent NCCL pool (sudhakarsingh27)
59609ac  Reset FP8 state and barrier between pool cases (sudhakarsingh27)
73e8cef  Deep-copy ModelConfig in run_dpa_with_cp (sudhakarsingh27)
311137c  Skip deterministic configs incompatible with FusedAttention (sudhakarsingh27)
49878d6  Reseed RNG between pool cases; reset before, not after (sudhakarsingh27)
385e966  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
86b334b  Robustify pool: capture worker stderr, tighten timeout, add timing knob (sudhakarsingh27)
ae5298c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
e162a9e  Address PR review: NCCL leak, stdout protocol, Windows note (sudhakarsingh27)
169be82  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
557bd80  [PyTorch] Fix stream race on max_logit_per_step in all-gather CP forward (sudhakarsingh27)
4815883  Address PR review (R2): drop dead code in pool worker and PoolWorker (sudhakarsingh27)
d15bfce  Address PR review (items 2+3): reuse CP groups across pool cases (sudhakarsingh27)
dd1d802  Merge branch 'main' of github.com:NVIDIA/TransformerEngine into sudha… (sudhakarsingh27)
87c67ac  Flatten try/finally wrap in run_dpa_with_cp (sudhakarsingh27)
adb84af  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
New file (+198 lines): the persistent pool worker for batched CP attention tests.
```python
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
Persistent worker for batched CP attention tests.

Launched ONCE per (pytest session, world_size) by torchrun. All ranks init
NCCL, then enter a dispatch loop:

rank 0:
    read one JSON request line from stdin
    broadcast it to all ranks
all ranks:
    call run_dpa_with_cp(**kwargs) — the same work function the
    per-case subprocess design uses, with NVTE_CP_POOL_PG=1 so the
    function reuses our PG instead of re-initing it
    torch.cuda.empty_cache() per case
all ranks gather (ok, error_msg) to rank 0
rank 0:
    write one JSON response line to stdout

Protocol (line-delimited JSON over rank-0 stdio):
    request : {"op": "run", "kwargs": {...}}
              {"op": "shutdown"}
    response: {"ok": true}
              {"ok": false, "error": "first failing rank's traceback"}
"""
import json
import os
import sys
import time
import traceback

import torch
import torch.distributed as dist

# Make sibling modules importable when launched directly.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from run_attention_with_cp import run_dpa_with_cp
from transformer_engine.pytorch.quantization import FP8GlobalStateManager


def _recv_request(rank: int) -> dict:
    box = [None]
    if rank == 0:
        line = sys.stdin.readline()
        box[0] = {"op": "shutdown"} if not line else json.loads(line)
    dist.broadcast_object_list(box, src=0)
    return box[0]


# Sentinel prefix on every response line so the parent reader can skip any
# stdout chatter that gets interleaved (torchrun status, library prints, even
# non-rank-0 stray output — torchrun ranks share rank 0's stdout fd).
_RESP_PREFIX = "[CP_POOL_RESP] "


def _send_response(rank: int, payload: dict) -> None:
    if rank == 0:
        sys.stdout.write(_RESP_PREFIX + json.dumps(payload) + "\n")
        sys.stdout.flush()


def _reset_between_cases() -> None:
    """Drop state that would otherwise cascade across cases.

    Matches the per-case startup of the single-shot worker
    (``_run_single_config`` on the per-case-subprocess branch): identical RNG
    seed at the start of every case, FP8 state cleared, allocator clean.
    ``run_dpa_with_cp`` re-sets ``NVTE_FUSED_ATTN``/``NVTE_FLASH_ATTN``
    unconditionally and pops the other transient env vars itself, so no
    explicit pop is needed here.
    """
    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)
    FP8GlobalStateManager.reset()
    torch.cuda.empty_cache()


_case_counter = 0


def _run_one(req: dict, rank: int) -> tuple[bool, str]:
    global _case_counter
    op = req["op"]
    if op != "run":
        return False, f"unknown op: {op}"
    # Reset BEFORE the case so the first case also starts from a known RNG seed
    # and clean FP8 state — same as the single-shot worker's per-process startup.
    _reset_between_cases()
    t0 = time.monotonic()
    ok = True
    err = ""
    try:
        run_dpa_with_cp(**req.get("kwargs", {}))
    except Exception:
        ok = False
        err = f"[Rank {rank}] {traceback.format_exc()}"
    wall = time.monotonic() - t0
    # Per-case wall time on rank 0, opt-in via NVTE_CP_POOL_TIMING=1.
    # Used to tune POOL_SUBMIT_TIMEOUT_SEC against the observed distribution.
    if rank == 0 and int(os.environ.get("NVTE_CP_POOL_TIMING", "0")):
        _case_counter += 1
        sys.stderr.write(
            f"[POOL-TIMING] case_idx={_case_counter} "
            f"world_size={int(os.environ.get('WORLD_SIZE', 0))} "
            f"wall_s={wall:.3f} ok={ok}\n"
        )
        sys.stderr.flush()
    return ok, err


def _create_cp_comm_groups(rank: int, world_size: int) -> tuple:
    """Pre-create the CP collective groups for this pool.

    world_size and the rank set are constant for the lifetime of one pool, so
    the world group and the a2a+p2p sub-groups are deterministic. Creating
    them once here and reusing them across every case eliminates ~50-100 ms
    of NCCL setup per case (cyanguwa's review feedback on PR #2993).

    Returns ``(world_group, a2a_p2p_sub_groups)``. ``a2a_p2p_sub_groups`` is
    empty when world_size is too small to support a2a+p2p (needs an even
    world_size ≥ 4); cases with cp_comm_type='a2a+p2p' wouldn't be routed to
    such a pool anyway.
    """
    world_group = dist.new_group(range(world_size), backend="nccl")
    sub_groups: list = []
    if world_size >= 4 and world_size % 2 == 0:
        # Mirror the layout in run_attention_with_cp.py: cp_size/2 pairs along
        # axis 0, plus 2 stride-2 groups along axis 1.
        cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
        cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
        for sub_ranks in cp_comm_sub_ranks:
            sub_group = dist.new_group(sub_ranks, backend="nccl")
            if rank in sub_ranks:
                sub_groups.append(sub_group)
    return world_group, sub_groups


def main() -> None:
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(rank % torch.cuda.device_count())
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    os.environ["NVTE_CP_POOL_PG"] = "1"

    # Stash pool-shared CP groups on the run_attention_with_cp module so
    # run_dpa_with_cp can read them per case. Imported here (after the env var
    # is set) to keep import-time side effects minimal.
    import run_attention_with_cp as _rac

    _rac._pool_cp_comm_group, _rac._pool_cp_comm_sub_groups = _create_cp_comm_groups(
        rank, world_size
    )

    try:
        while True:
            req = _recv_request(rank)
            if req.get("op") == "shutdown":
                break

            ok, msg = _run_one(req, rank)

            gathered: list[tuple[bool, str]] = [None] * world_size  # type: ignore[list-item]
            # gather_object is itself a collective synchronization point — if
            # every rank reached it, none is ahead. No extra barrier needed.
            dist.gather_object((ok, msg), gathered if rank == 0 else None, dst=0)

            if rank == 0:
                all_ok = all(o for o, _ in gathered)
                if all_ok:
                    _send_response(rank, {"ok": True})
                else:
                    first_err = next(m for o, m in gathered if not o)
                    _send_response(rank, {"ok": False, "error": first_err})
    finally:
        # Tear down pool-shared CP groups before the main PG (NCCL requires
        # sub-groups to be destroyed first). Each destroy is independently
        # guarded so a wedged communicator on one group doesn't leak the rest.
        if _rac._pool_cp_comm_group is not None:
            try:
                dist.destroy_process_group(_rac._pool_cp_comm_group)
            except Exception:
                pass
        for g in _rac._pool_cp_comm_sub_groups:
            try:
                dist.destroy_process_group(g)
            except Exception:
                pass
        _rac._pool_cp_comm_group = None
        _rac._pool_cp_comm_sub_groups = []
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
```
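The parent-side PoolWorker that drives this loop is part of the test-suite changes and is not reproduced here. As a rough sketch of how a driver could speak the line-delimited protocol above: the worker module name (cp_pool_worker.py), the torchrun invocation, and the function names below are assumptions for illustration, not taken from this diff.

```python
# Hypothetical parent-side driver for the pool worker above. Assumptions not
# taken from this diff: the worker module is named cp_pool_worker.py and is
# launched via `python -m torch.distributed.run`; the real PoolWorker in the
# test suite may launch and parse differently.
import json
import subprocess
import sys

RESP_PREFIX = "[CP_POOL_RESP] "  # must match _RESP_PREFIX in the worker


def start_pool(world_size: int) -> subprocess.Popen:
    # One persistent torchrun process per (pytest session, world_size).
    # stderr is left inherited here; the real pool captures it separately.
    return subprocess.Popen(
        [
            sys.executable,
            "-m",
            "torch.distributed.run",
            f"--nproc_per_node={world_size}",
            "cp_pool_worker.py",
        ],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        text=True,
    )


def submit(pool: subprocess.Popen, kwargs: dict) -> None:
    # One JSON request line in, one sentinel-prefixed JSON response line out.
    pool.stdin.write(json.dumps({"op": "run", "kwargs": kwargs}) + "\n")
    pool.stdin.flush()
    while True:
        line = pool.stdout.readline()
        if not line:
            raise RuntimeError("pool worker exited before responding")
        if not line.startswith(RESP_PREFIX):
            continue  # torchrun status or stray library prints: skip
        resp = json.loads(line[len(RESP_PREFIX):])
        if not resp["ok"]:
            raise AssertionError(resp["error"])
        return


def shutdown(pool: subprocess.Popen) -> None:
    pool.stdin.write(json.dumps({"op": "shutdown"}) + "\n")
    pool.stdin.flush()
    pool.wait(timeout=60)
```

Because the torchrun process stays alive across submit() calls, NCCL and CUDA setup is paid once per (session, world_size) rather than once per test case, which is the point of the pool.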
Review comment:
`torchrun` (and the worker processes it spawns for ranks 1–N) all inherit the same stdout file descriptor as rank 0. If torchrun writes any status line to stdout, or if any non-rank-0 worker accidentally prints (e.g. via a `print` call in a library, NCCL debug output, or a Python warning), those bytes are interleaved with rank 0's JSON responses. The parent's `readline()` in `PoolWorker.submit` would then receive a non-JSON line and raise a `json.JSONDecodeError`, killing the pool and failing the test with a misleading error message. Consider redirecting torchrun's own output or adding a sentinel prefix to every response line so the reader can skip unrecognised lines.
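The sentinel prefix the PR adopts (`_RESP_PREFIX` in the worker above) addresses this. In miniature, the prefix-tolerant parse on the parent side could look like the following; this is a hypothetical helper for illustration, not the test-suite code.

```python
# Toy illustration of the review point: a naive json.loads(readline()) breaks
# on interleaved stdout chatter, while filtering on the sentinel prefix lets
# the parent skip anything that is not a protocol response.
import json

RESP_PREFIX = "[CP_POOL_RESP] "


def parse_pool_responses(lines):
    """Yield decoded responses, ignoring non-protocol stdout lines."""
    for line in lines:
        if line.startswith(RESP_PREFIX):
            yield json.loads(line[len(RESP_PREFIX):])
        # else: torchrun status, NCCL debug, warnings, etc. are skipped


# Chatter interleaved with one real response is handled cleanly:
stdout_lines = [
    "W0101 torch.distributed.run: some status line\n",
    '[CP_POOL_RESP] {"ok": true}\n',
]
assert list(parse_pool_responses(stdout_lines)) == [{"ok": True}]
```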