diff --git a/.github/workflows/build-gpu-image.yml b/.github/workflows/build-gpu-image.yml index 12dbfad96..cdfc23634 100644 --- a/.github/workflows/build-gpu-image.yml +++ b/.github/workflows/build-gpu-image.yml @@ -30,6 +30,11 @@ on: required: true default: true type: boolean + prewarm_modal: + description: "Prebuild the pushed image in Modal when auth is configured" + required: true + default: true + type: boolean prewarm_timeout: description: "Timeout for GPU node prewarm rollout" required: true @@ -155,11 +160,16 @@ jobs: PULL_IMAGE_REPO: ${{ inputs.pull_image_repo || 'docker.io/bradhiltonnw/art-gpu' }} IMAGE_TAG: ${{ inputs.tag }} NO_CACHE: ${{ inputs.no_cache }} + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} + PREWARM_MODAL_INPUT: ${{ inputs.prewarm_modal }} PREWARM_NODES: ${{ inputs.prewarm_nodes }} PREWARM_TIMEOUT: ${{ inputs.prewarm_timeout }} run: | IMAGE_TAG="${IMAGE_TAG:-latest}" NO_CACHE="${NO_CACHE:-false}" + export PREWARM_MODAL="${PREWARM_MODAL:-auto}" + PREWARM_MODAL_INPUT="${PREWARM_MODAL_INPUT:-true}" PREWARM_NODES="${PREWARM_NODES:-true}" PREWARM_TIMEOUT="${PREWARM_TIMEOUT:-30m}" @@ -175,6 +185,10 @@ jobs: args+=(--no-cache) fi + if [ "${PREWARM_MODAL_INPUT}" = "false" ]; then + args+=(--no-prewarm-modal) + fi + if [ "${PREWARM_NODES}" != "true" ]; then args+=(--no-prewarm-nodes) fi diff --git a/.gitignore b/.gitignore index d1f4ebd59..0dfae3afe 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,8 @@ data/cache.db streaming-chat-completions/ unsloth_compiled_cache/ wandb/ +!/typings/wandb/ +!/typings/wandb/** docs/node_modules/ dist/ replays/ diff --git a/dev/trainer_rank.py b/dev/trainer_rank.py new file mode 100644 index 000000000..177ece785 --- /dev/null +++ b/dev/trainer_rank.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import os + +import torch +import torch.distributed as dist +from transformers import AutoTokenizer +import typer + +from art.megatron.trainer_rank 
import AdamParams, ForwardInput, TrainerRank + + +def main( + model: str = "Qwen/Qwen3-0.6B", + dataset: str = "roneneldan/TinyStories", + split: str = "train", + text_column: str = "text", + samples: int = 16, + steps: int = 1, + lr: float = 5e-5, + layers: int = 2, + max_seq_length: int = 256, +) -> None: + os.environ.setdefault("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_CONTEXT_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_PIPELINE_MODEL_PARALLEL_SIZE", "1") + + if not torch.cuda.is_available(): + raise RuntimeError("dev/trainer_rank.py requires CUDA") + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + dist.init_process_group(backend="nccl") + + try: + from datasets import load_dataset + + from art.megatron import train as megatron_train + + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + inputs: list[ForwardInput[torch.Tensor, None, None, None]] = [] + for row in load_dataset(dataset, split=split, streaming=True): + text = str(row.get(text_column, "")).strip() # type: ignore[union-attr] + if not text: + continue + token_ids = tokenizer( + text, + add_special_tokens=True, + truncation=True, + max_length=max_seq_length + 1, + return_tensors="pt", + )["input_ids"].reshape(-1) + if int(token_ids.numel()) <= 1: + continue + inputs.append( + ForwardInput( + input_tokens=token_ids[:-1], + target_tokens=token_ids[1:], + ) + ) + if len(inputs) >= samples: + break + if not inputs: + raise RuntimeError("dataset produced no tokenized training examples") + + runtime = megatron_train.build_training_runtime( + model_identifier=model, + provider_configure=lambda provider: setattr( + provider, + "num_layers", + layers, + ), + print_env=dist.get_rank() == 0, + ) + rank = TrainerRank(runtime) + if dist.get_rank() == 0: + print( + "TrainerRank ready: " + f"dp={megatron_train.ps.get_data_parallel_world_size()} " + f"device={rank.device}", + flush=True, + ) + + for step in range(steps): + 
loss_sum = torch.tensor(0.0, device=rank.device) + token_count = torch.tensor(0.0, device=rank.device) + for micro in rank.forward_micro_batches(inputs): + loss = torch.tensor(0.0, device=rank.device) + for output in micro.outputs: + assert output.target_logprobs is not None + loss = loss - output.target_logprobs.sum() + token_count += output.target_logprobs.numel() + if loss.requires_grad: + loss.backward() + loss_sum += loss.detach() + + rank.dp_reduce(loss_sum) + rank.dp_reduce(token_count) + scale = 1.0 / max(float(token_count.item()), 1.0) + metrics = rank.optim_step( + params=AdamParams(learning_rate=lr), + scale_grads=scale, + ) + metrics["loss"] = float(loss_sum.item() * scale) + metrics["tokens"] = float(token_count.item()) + if dist.get_rank() == 0: + print(f"step={step} {metrics}", flush=True) + + dist.barrier() + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == "__main__": + typer.run(main) diff --git a/dev/trainer_rank_fast_check.py b/dev/trainer_rank_fast_check.py new file mode 100644 index 000000000..51372d7d8 --- /dev/null +++ b/dev/trainer_rank_fast_check.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import subprocess +import sys + +FAST_TESTS = ( + "tests/unit/test_trainer_rank_validation.py", + "tests/unit/test_trainer_rank_weird_shapes.py", + "tests/unit/test_shared_prefix_packing.py", + "tests/unit/test_shared_prefix_tree.py", + "tests/unit/test_shared_prefix_attention_builder.py", + "tests/unit/test_shared_prefix_grad_parity.py", +) + + +def main() -> None: + raise SystemExit( + subprocess.call( + [sys.executable, "-m", "pytest", "--tb=short", *FAST_TESTS, *sys.argv[1:]] + ) + ) + + +if __name__ == "__main__": + main() diff --git a/dev/trainer_rank_parity_probe.py b/dev/trainer_rank_parity_probe.py new file mode 100644 index 000000000..06cd0a959 --- /dev/null +++ b/dev/trainer_rank_parity_probe.py @@ -0,0 +1,539 @@ +from __future__ import annotations + +from collections.abc import Sequence 
+from dataclasses import dataclass +import json +import os +import re +from typing import Any, cast + +import torch +import torch.distributed as dist +import typer + +from art.megatron.shared_prefix_packing import SharedPrefixPack, pack_shared_prefixes +from art.megatron.trainer_rank import ( + AnyForwardInput, + TrainerRank, + _batch_seq_logits, + _language_model, +) + + +@dataclass(frozen=True) +class _Capture: + values: dict[str, torch.Tensor] + positions_by_item: tuple[torch.Tensor, ...] + source_positions_by_item: tuple[torch.Tensor, ...] + + +def main( + model: str = "Qwen/Qwen3-0.6B", + layers: int = 1, + sequences: int = 6, + sequence_length: int = 7, + compare_requests: int = 6, + request_shape: str = "varied", + oracle: str = "independent", + max_depth: int = 1, +) -> None: + os.environ.setdefault("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_CONTEXT_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_PIPELINE_MODEL_PARALLEL_SIZE", "1") + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + dist.init_process_group(backend="nccl") + try: + from megatron.core import parallel_state as ps + + from art.megatron import train as megatron_train + + torch.manual_seed(1234) + runtime = megatron_train.build_training_runtime( + model_identifier=model, + provider_configure=lambda provider: setattr( + provider, + "num_layers", + layers, + ), + print_env=dist.get_rank() == 0, + ) + if int(ps.get_tensor_model_parallel_world_size()) != 1: + raise RuntimeError("trainer_rank_parity_probe currently expects TP=1") + for chunk in runtime.model: + chunk.eval() + + rank = TrainerRank(runtime, shared_prefix_max_depth=max_depth) + requests = _unique_requests( + sequences=sequences, + sequence_length=sequence_length, + request_shape=request_shape, + ) + request_count = min(compare_requests, len(requests)) + + with torch.no_grad(): + packed = _run_capture(rank, requests) + records = _records_from_capture( + kind="packed", + 
capture=packed, + request_indices=range(len(requests)), + cp_rank=int(ps.get_context_parallel_rank()), + dp_rank=int(ps.get_data_parallel_rank()), + ) + for request_index, request in enumerate(requests): + if oracle == "independent": + oracle_capture = _run_capture(rank, [request]) + oracle_request_indices = (request_index,) + oracle_local_indices = None + elif oracle == "same-layout": + oracle_capture = _run_capture( + rank, + requests, + mutate_except=request_index, + ) + oracle_request_indices = range(len(requests)) + oracle_local_indices = (request_index,) + else: + raise ValueError("oracle must be 'independent' or 'same-layout'") + records.extend( + _records_from_capture( + kind="independent", + capture=oracle_capture, + request_indices=oracle_request_indices, + cp_rank=int(ps.get_context_parallel_rank()), + dp_rank=int(ps.get_data_parallel_rank()), + local_indices=oracle_local_indices, + ) + ) + + gathered: list[list[dict[str, object]] | None] = [None] * dist.get_world_size() + dist.all_gather_object(gathered, records) + if dist.get_rank() == 0: + flat_records = [ + record for rank_records in gathered for record in rank_records or [] + ] + report = _build_report( + records=flat_records, + requests=requests[:request_count], + topology={ + "world": dist.get_world_size(), + "dp": int(ps.get_data_parallel_world_size()), + "tp": int(ps.get_tensor_model_parallel_world_size()), + "cp": int(ps.get_context_parallel_world_size()), + }, + oracle=oracle, + ) + print(json.dumps(report, sort_keys=True), flush=True) + dist.barrier() + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def _unique_requests( + *, + sequences: int, + sequence_length: int, + request_shape: str, +) -> list[AnyForwardInput]: + from art.megatron.trainer_rank import ForwardInput + + if sequences < 1 or sequence_length < 2: + raise ValueError("sequences must be >= 1 and sequence_length must be >= 2") + if request_shape == "varied": + base_rows = ( + (11, 12, 13, 14, 15, 16, 
17), + (11, 12, 13, 14, 24, 25), + (11, 12, 13, 14, 24, 26), + (11, 12, 13, 27), + (31, 32, 33, 34), + (31, 32, 33, 35), + (11, 12, 13, 14, 15, 16, 17), + (41, 42, 43), + (41, 42, 44, 45), + (51, 52, 53, 54, 55), + (61, 62, 63), + (61, 62, 64, 65), + (71, 72), + (81, 82, 83, 84), + (91, 92, 93), + (101, 102, 103, 104, 105), + ) + return [ + ForwardInput( + input_tokens=torch.tensor(row, dtype=torch.long) + 1000 * index + ) + for index, row in enumerate(base_rows[:sequences]) + ] + if request_shape == "deep": + base_rows = ( + (11, 12, 13, 14, 15, 16, 17), + (11, 12, 13, 14, 15, 16, 18), + (11, 12, 13, 14, 15, 19), + (11, 12, 13, 14, 20), + (11, 12, 21), + (31, 32, 33, 34, 35), + (31, 32, 33, 34, 36), + (31, 32, 33, 37), + (41, 42, 43), + (41, 42, 44), + (51, 52, 53, 54), + (61, 62), + (71, 72, 73, 74, 75), + (71, 72, 73, 76), + (81,), + (91, 92, 93), + ) + return [ + ForwardInput(input_tokens=torch.tensor(row, dtype=torch.long)) + for row in base_rows[:sequences] + ] + if request_shape != "equal": + raise ValueError("request_shape must be 'equal', 'varied', or 'deep'") + return [ + ForwardInput( + input_tokens=torch.arange( + 1000 * index + 11, + 1000 * index + 11 + sequence_length, + dtype=torch.long, + ) + ) + for index in range(sequences) + ] + + +def _run_capture( + rank: TrainerRank, + requests: Sequence[AnyForwardInput], + *, + mutate_except: int | None = None, +) -> _Capture: + from art.megatron.train import _placeholder_attention_mask + + model = _language_model(rank.runtime.model[0]) + items = [rank._forward_item(request) for request in requests] + batch = pack_shared_prefixes( + (item.input_ids for item in items), + max_depth=rank.shared_prefix_max_depth, + ) + if mutate_except is not None: + batch = _mutated_batch( + batch, keep_positions=batch.positions_by_sequence[mutate_except] + ) + prepared = rank._prepare_packed_forward(batch) + local_seq_len = int(prepared.tokens.shape[1]) + values: dict[str, torch.Tensor] = {} + handles = _register_hooks(model, 
values, seq_len=local_seq_len) + try: + handler = rank._handler() + forward_kwargs = handler.get_forward_kwargs( + rank.runtime.model[0], + attention_bias=prepared.attention_state, + ) + extra_block_kwargs = cast( + dict[str, object] | None, + forward_kwargs.pop("extra_block_kwargs", None), + ) + preprocessed = model._preprocess( + input_ids=prepared.tokens, + position_ids=prepared.position_ids, + packed_seq_params=prepared.packed_seq_params, + ) + values["00.preprocess.decoder_input"] = _rows( + cast(torch.Tensor, preprocessed[0]).detach(), + seq_len=local_seq_len, + ) + hidden = cast( + torch.Tensor, + model.decoder( + hidden_states=preprocessed[0], + attention_mask=_placeholder_attention_mask(rank.device), + rotary_pos_emb=preprocessed[1], + rotary_pos_cos=preprocessed[2], + rotary_pos_sin=preprocessed[3], + rotary_pos_cos_sin=preprocessed[6] if len(preprocessed) == 7 else None, + packed_seq_params=prepared.packed_seq_params, + sequence_len_offset=preprocessed[4], + padding_mask=preprocessed[5], + **(extra_block_kwargs or {}), + ), + ) + gathered_hidden = rank._gather_sequence_parallel_hidden(hidden) + values["90.decoder.output"] = gathered_hidden.detach() + values["99.lm_head.logits"] = _logits(rank, gathered_hidden).detach() + return _Capture( + values=values, + positions_by_item=prepared.positions_by_item, + source_positions_by_item=prepared.source_positions_by_item, + ) + finally: + for handle in handles: + handle.remove() + + +def _mutated_batch( + batch: SharedPrefixPack, + *, + keep_positions: torch.Tensor, +) -> SharedPrefixPack: + tokens = batch.tokens.clone() + mask = torch.ones(int(tokens.shape[1]), dtype=torch.bool, device=tokens.device) + mask[keep_positions.to(device=tokens.device)] = False + replacement = ( + torch.arange(int(tokens.shape[1]), dtype=tokens.dtype, device=tokens.device) + + 50_000 + ) + tokens[0, mask] = replacement[mask] % 100_000 + return SharedPrefixPack( + tokens=tokens, + group_ids=batch.group_ids, + 
parent_ids=batch.parent_ids, + position_ids=batch.position_ids, + positions_by_sequence=batch.positions_by_sequence, + ) + + +def _register_hooks( + model: torch.nn.Module, + values: dict[str, torch.Tensor], + *, + seq_len: int, +) -> list[Any]: + handles: list[Any] = [] + for module_name, module in model.named_modules(): + label = _capture_label(module_name) + if label is None: + continue + + def hook( + _module: torch.nn.Module, + _inputs: tuple[object, ...], + output: object, + *, + label: str = label, + ) -> None: + tensor = _first_tensor(output) + if tensor is not None: + try: + values[label] = _rows(tensor.detach(), seq_len=seq_len) + except RuntimeError: + pass + + handles.append(module.register_forward_hook(hook)) + return handles + + +def _capture_label(module_name: str) -> str | None: + layer_prefix = r"decoder\.layers\.(\d+)(?:\._orig_mod)?" + if re.fullmatch(r"decoder\.layers\.(\d+)\._orig_mod", module_name): + return None + layer_match = re.fullmatch(r"decoder\.layers\.(\d+)", module_name) + if layer_match: + return f"30.layer.{int(layer_match.group(1)):03d}.output" + input_norm_match = re.fullmatch(rf"{layer_prefix}\.input_layernorm", module_name) + if input_norm_match: + return f"05.layer.{int(input_norm_match.group(1)):03d}.input_layernorm" + qkv_match = re.fullmatch( + rf"{layer_prefix}\.self_attention\.linear_qkv", module_name + ) + if qkv_match: + return f"08.layer.{int(qkv_match.group(1)):03d}.self_attention.linear_qkv" + core_attention_match = re.fullmatch( + rf"{layer_prefix}\.self_attention\.core_attention", + module_name, + ) + if core_attention_match: + return f"10.layer.{int(core_attention_match.group(1)):03d}.self_attention.core_attention" + attention_proj_match = re.fullmatch( + rf"{layer_prefix}\.self_attention\.linear_proj", + module_name, + ) + if attention_proj_match: + return f"12.layer.{int(attention_proj_match.group(1)):03d}.self_attention.linear_proj" + attention_match = re.fullmatch( + rf"{layer_prefix}\.self_attention", + 
module_name, + ) + if attention_match: + return f"15.layer.{int(attention_match.group(1)):03d}.self_attention" + pre_mlp_norm_match = re.fullmatch( + rf"{layer_prefix}\.pre_mlp_layernorm", + module_name, + ) + if pre_mlp_norm_match: + return f"18.layer.{int(pre_mlp_norm_match.group(1)):03d}.pre_mlp_layernorm" + fc1_match = re.fullmatch(rf"{layer_prefix}\.mlp\.linear_fc1", module_name) + if fc1_match: + return f"20.layer.{int(fc1_match.group(1)):03d}.mlp.linear_fc1" + fc2_match = re.fullmatch(rf"{layer_prefix}\.mlp\.linear_fc2", module_name) + if fc2_match: + return f"22.layer.{int(fc2_match.group(1)):03d}.mlp.linear_fc2" + mlp_match = re.fullmatch(rf"{layer_prefix}\.mlp", module_name) + if mlp_match: + return f"25.layer.{int(mlp_match.group(1)):03d}.mlp" + if module_name == "decoder.final_layernorm": + return "80.decoder.final_layernorm" + return None + + +def _first_tensor(value: object) -> torch.Tensor | None: + if isinstance(value, torch.Tensor): + return value + if isinstance(value, (tuple, list)): + for item in value: + tensor = _first_tensor(item) + if tensor is not None: + return tensor + return None + + +def _rows(tensor: torch.Tensor, *, seq_len: int) -> torch.Tensor: + if tensor.ndim >= 2 and int(tensor.shape[0]) == seq_len: + rows = tensor + if rows.ndim >= 3 and int(rows.shape[1]) == 1: + return rows[:, 0].contiguous() + return rows.contiguous() + if tensor.ndim >= 2 and int(tensor.shape[1]) == seq_len: + rows = ( + tensor[:, :, 0] + if tensor.ndim == 4 and int(tensor.shape[2]) == 1 + else tensor + ) + if int(rows.shape[0]) == 1: + return rows[0].contiguous() + raise RuntimeError( + f"Cannot identify sequence axis for tensor shape={tuple(tensor.shape)} " + f"seq_len={seq_len}" + ) + + +def _logits(rank: TrainerRank, hidden_rows: torch.Tensor) -> torch.Tensor: + model = _language_model(rank.runtime.model[0]) + output_weight = ( + model.shared_embedding_or_output_weight() + if bool(model.share_embeddings_and_output_weights) + else None + ) + if 
int(hidden_rows.shape[0]) == 0: + return hidden_rows.new_empty((0, int(model.vocab_size))) + local_logits = rank._local_logits_from_hidden_rows( + model, + hidden_rows, + output_weight=output_weight, + ) + return _batch_seq_logits( + rank._gather_tensor_parallel_logits(local_logits.unsqueeze(1)), + seq_len=int(hidden_rows.shape[0]), + ).squeeze(0) + + +def _records_from_capture( + *, + kind: str, + capture: _Capture, + request_indices: Sequence[int], + cp_rank: int, + dp_rank: int, + local_indices: Sequence[int] | None = None, +) -> list[dict[str, object]]: + records: list[dict[str, object]] = [] + local_index_set = None if local_indices is None else frozenset(local_indices) + for local_index, request_index in enumerate(request_indices): + if local_index_set is not None and local_index not in local_index_set: + continue + positions = capture.positions_by_item[local_index] + source_positions = capture.source_positions_by_item[local_index] + if int(positions.numel()) == 0: + continue + for name, rows in capture.values.items(): + records.append( + { + "kind": kind, + "name": name, + "request_index": int(request_index), + "source_positions": source_positions.cpu(), + "value": rows.index_select(0, positions.to(rows.device)).cpu(), + "cp": int(cp_rank), + "dp": int(dp_rank), + } + ) + return records + + +def _build_report( + *, + records: list[dict[str, object]], + requests: Sequence[AnyForwardInput], + topology: dict[str, int], + oracle: str, +) -> dict[str, object]: + results = [] + names = sorted( + { + cast(str, record["name"]) + for record in records + if record.get("kind") == "packed" + } + ) + for request_index, request in enumerate(requests): + length = int(request.input_tokens.numel()) + for name in names: + packed = _assemble(records, "packed", name, request_index, length) + independent = _assemble(records, "independent", name, request_index, length) + if packed is None or independent is None: + continue + diff = (packed.float() - independent.float()).abs() + 
denom = independent.float().abs().max().clamp_min(1e-12) + results.append( + { + "request": request_index, + "site": name, + "shape": list(packed.shape), + "max_abs": float(diff.max().item()) if int(diff.numel()) else 0.0, + "mean_abs": float(diff.mean().item()) if int(diff.numel()) else 0.0, + "rel_max": float((diff.max() / denom).item()) + if int(diff.numel()) + else 0.0, + } + ) + return { + "topology": topology, + "oracle": oracle, + "requests": len(requests), + "results": results, + } + + +def _assemble( + records: list[dict[str, object]], + kind: str, + name: str, + request_index: int, + length: int, +) -> torch.Tensor | None: + matching = [ + record + for record in records + if record["kind"] == kind + and record["name"] == name + and record["request_index"] == request_index + ] + if not matching: + return None + first = cast(torch.Tensor, matching[0]["value"]) + output = torch.empty((length, *first.shape[1:]), dtype=first.dtype) + filled = torch.zeros(length, dtype=torch.bool) + for record in matching: + positions = cast(torch.Tensor, record["source_positions"]) + value = cast(torch.Tensor, record["value"]) + output[positions] = value + filled[positions] = True + if not bool(filled.all().item()): + raise RuntimeError( + f"Missing positions for {kind} {name} request={request_index}" + ) + return output + + +if __name__ == "__main__": + typer.run(main) diff --git a/dev/trainer_rank_perf.py b/dev/trainer_rank_perf.py new file mode 100644 index 000000000..4d0c2305c --- /dev/null +++ b/dev/trainer_rank_perf.py @@ -0,0 +1,2906 @@ +from __future__ import annotations + +from collections.abc import Callable, Sequence +from contextlib import contextmanager, suppress +import json +import os +from pathlib import Path +import threading +import time +from typing import Any + +import torch +import torch.distributed as dist +import typer + +from art.megatron.shared_prefix_packing import SharedPrefixPack, pack_shared_prefixes +import art.megatron.trainer_rank as 
trainer_rank_module +from art.megatron.trainer_rank import ( + AdamParams, + ForwardInput, + TopK, + TrainerRank, + _batch_seq_logits, + _language_model, + _unflatten, +) + + +def _pack_forward_items(items: Sequence[Any], *, max_depth: int) -> SharedPrefixPack: + return pack_shared_prefixes( + (item.input_ids for item in items), + max_depth=max_depth, + ) + + +def main( + model: str = "Qwen/Qwen3-0.6B", + layers: int = 1, + seq_len: int = 2048, + prefix_families: int = 0, + prefix_len: int = 5000, + mid_prefixes_per_family: int = 1, + mid_prefix_len: int = 0, + branches_per_prefix: int = 16, + completion_len: int = 100, + warmup: int = 2, + repeat: int = 5, + head_chunk_tokens: int = 512, + shared_prefix_max_depth: int = 1, + benchmark: str = "target_builtin_fwd", + target_count: int = 4, + top_k: int = 5, + top_k_values: str = "1,2,5,10,20,50", + max_unpacked_output_gb: float = 0.5, + mask_prefix_targets: bool = True, + workload: str = "regular", + tree_depth: int = 3, + tree_seed: int = 1, + tree_duplicate_factor: int = 1, + adapter_slots: int = 0, + adapter_slot_mode: str = "family", + adapter_slot_rank: int = 1, + learning_rate: float = 1e-5, + full_step_offload_reload: bool = False, + memory_safety_factor: float = 1.10, + memory_reserve_fraction: float = 0.03, + memory_sample_interval_s: float = 0.05, + compare_target_correctness: bool = False, + run_adapter_sanity: bool = False, + progress_jsonl: str = "", + output_jsonl: str = "", +) -> None: + if progress_jsonl: + os.environ["ART_TRAINER_RANK_PROGRESS_JSONL"] = progress_jsonl + + os.environ.setdefault("ART_MEGATRON_TENSOR_MODEL_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_CONTEXT_PARALLEL_SIZE", "1") + os.environ.setdefault("ART_MEGATRON_PIPELINE_MODEL_PARALLEL_SIZE", "1") + + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + dist.init_process_group(backend="nccl") + try: + from megatron.core import parallel_state as ps + + from art.megatron import train as megatron_train + + 
provider_configure = ( + (lambda provider: setattr(provider, "num_layers", layers)) + if layers > 0 + else None + ) + runtime = megatron_train.build_training_runtime( + model_identifier=model, + provider_configure=provider_configure, + print_env=dist.get_rank() == 0, + ) + for chunk in runtime.model: + chunk.eval() + rank = TrainerRank( + runtime, + head_chunk_tokens=head_chunk_tokens, + shared_prefix_max_depth=shared_prefix_max_depth, + memory_safety_factor=memory_safety_factor, + memory_reserve_fraction=memory_reserve_fraction, + ) + if adapter_slots < 0: + raise ValueError("adapter_slots must be >= 0") + if adapter_slot_rank < 1: + raise ValueError("adapter_slot_rank must be >= 1") + if adapter_slots: + loaded_sites = _load_adapter_slots( + rank, + count=adapter_slots, + slot_rank=adapter_slot_rank, + ) + else: + loaded_sites = 0 + hidden_size, vocab_size, dtype_size = _runtime_output_shape(runtime) + model_config = getattr(_language_model(runtime.model[0]), "config", None) + + benchmarks = { + name.strip().replace("-", "_") + for name in benchmark.split(",") + if name.strip() + } + if "all" in benchmarks: + benchmarks = { + "target_builtin_fwd", + "target_trainer_fwd", + "target_hidden_fwd", + "logits_builtin_fwd", + "logits_hidden_fwd", + "target_builtin_fwd_bwd", + "target_builtin_masked_fwd_bwd", + "target_trainer_fwd_bwd", + "target_hidden_fwd_bwd", + "target_builtin_train_step", + "target_trainer_train_step", + "target_trainer_fixed_train_step", + "target_trainer_adaptive_train_step", + "target_trainer_adaptive_profile_train_step", + "target_hidden_train_step", + "trainer_multi_target_fwd_bwd", + "trainer_multi_target_train_step", + "trainer_multi_target_fixed_train_step", + "trainer_multi_target_adaptive_train_step", + "trainer_target", + "trainer_multi_target", + "trainer_topk", + "trainer_topk_head", + "trainer_topk_fwd_bwd", + "trainer_topk_train_step", + "trainer_topk_fixed_train_step", + "trainer_topk_adaptive_train_step", + "trainer_topk_sweep", + 
"trainer_target_topk", + "trainer_hidden", + "trainer_all_no_logits", + "trainer_logits", + } + if "trainer_all" in benchmarks: + benchmarks.update( + { + "trainer_target", + "trainer_multi_target", + "trainer_multi_target_fwd_bwd", + "trainer_multi_target_train_step", + "trainer_multi_target_fixed_train_step", + "trainer_multi_target_adaptive_train_step", + "trainer_topk", + "trainer_topk_head", + "trainer_topk_fwd_bwd", + "trainer_topk_train_step", + "trainer_topk_fixed_train_step", + "trainer_topk_adaptive_train_step", + "trainer_topk_sweep", + "trainer_target_topk", + "trainer_hidden", + "trainer_all_no_logits", + "trainer_logits", + } + ) + + if target_count < 1: + raise ValueError("target_count must be >= 1") + if top_k < 1: + raise ValueError("top_k must be >= 1") + if memory_sample_interval_s < 0: + raise ValueError("memory_sample_interval_s must be >= 0") + requests, multi_target_requests, request_metadata = _requests( + seq_len=seq_len, + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + target_count=target_count, + mask_prefix_targets=mask_prefix_targets, + workload=workload, + tree_depth=tree_depth, + tree_seed=tree_seed, + tree_duplicate_factor=tree_duplicate_factor, + ) + requests = _route_adapter_slots( + requests, + adapter_slots=adapter_slots, + mode=adapter_slot_mode, + ) + multi_target_requests = _route_adapter_slots( + multi_target_requests, + adapter_slots=adapter_slots, + mode=adapter_slot_mode, + ) + stats_items = [rank._forward_item(request) for request in requests] + stats_batch = _pack_forward_items( + stats_items, + max_depth=rank.shared_prefix_max_depth, + ) + stats_prepared = rank._prepare_packed_forward(stats_batch) + request_stats = _packed_request_stats( + requests, + stats_items, + stats_batch, + request_metadata=request_metadata, + ) + planner_metadata = 
_gather_planner_metadata(stats_prepared) + target_items = None + target_prepared = None + if any(name.startswith("target_") for name in benchmarks): + target_items = stats_items + target_prepared = stats_prepared + logits_items = None + logits_prepared = None + if any(name.startswith("logits_") for name in benchmarks): + logits_items = [ + rank._forward_item(_with_outputs(request, logits=True)) + for request in requests + ] + logits_prepared = rank._prepare_packed_forward( + _pack_forward_items( + logits_items, + max_depth=rank.shared_prefix_max_depth, + ) + ) + results: dict[str, float] = {} + metadata: dict[str, object] = {} + rate_units: dict[str, dict[str, int]] = {} + + def register_case( + name: str, + case_requests: Sequence[ + ForwardInput[ + torch.Tensor | None, + TopK | None, + torch.Tensor | None, + torch.Tensor | None, + ] + ], + case_stats: dict[str, int | str], + ) -> None: + units = _rate_units( + case_requests, + case_stats, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + rate_units[name] = units + for key, value in units.items(): + metadata[f"{name}_{key}"] = value + + for name in ( + "target_builtin_fwd", + "target_hidden_fwd", + "target_trainer_fwd", + "target_builtin_fwd_bwd", + "target_builtin_masked_fwd_bwd", + "target_trainer_fwd_bwd", + "target_hidden_fwd_bwd", + "target_builtin_train_step", + "target_trainer_train_step", + "target_trainer_fixed_train_step", + "target_trainer_adaptive_train_step", + "target_trainer_adaptive_profile_train_step", + "target_hidden_train_step", + ): + register_case(name, requests, request_stats) + + memory_tracker = _CudaMemoryTracker( + device_index=int(os.environ["LOCAL_RANK"]), + sample_interval_s=memory_sample_interval_s, + ) + memory_tracker.start() + torch.cuda.reset_peak_memory_stats() + with torch.no_grad(): + if "target_builtin_fwd" in benchmarks: + assert target_items is not None and target_prepared is not None + results["target_builtin_fwd_ms"] = _bench( + lambda: 
_builtin( + rank, + target_prepared, + _packed_labels(target_items, target_prepared), + ), + warmup=warmup, + repeat=repeat, + ) + if "target_hidden_fwd" in benchmarks: + assert target_items is not None and target_prepared is not None + results["target_hidden_fwd_ms"] = _bench( + lambda: rank._project_head( + target_items, + target_prepared, + rank._gather_sequence_parallel_hidden( + rank._decoder_hidden(target_prepared) + ), + ), + warmup=warmup, + repeat=repeat, + ) + if "target_trainer_fwd" in benchmarks: + assert target_items is not None and target_prepared is not None + results["target_trainer_fwd_ms"] = _bench( + lambda: rank._forward_packed(target_items, target_prepared), + warmup=warmup, + repeat=repeat, + ) + if "logits_builtin_fwd" in benchmarks: + assert logits_prepared is not None + register_case( + "logits_builtin_fwd", _logits_requests(requests), request_stats + ) + results["logits_builtin_fwd_ms"] = _bench( + lambda: _full_logits(rank, logits_prepared), + warmup=warmup, + repeat=repeat, + ) + if "logits_hidden_fwd" in benchmarks: + assert logits_items is not None and logits_prepared is not None + register_case( + "logits_hidden_fwd", _logits_requests(requests), request_stats + ) + results["logits_hidden_fwd_ms"] = _bench( + lambda: rank._project_head( + logits_items, + logits_prepared, + rank._gather_sequence_parallel_hidden( + rank._decoder_hidden(logits_prepared) + ), + ), + warmup=warmup, + repeat=repeat, + ) + trainer_cases = { + "trainer_target": requests, + "trainer_multi_target": multi_target_requests, + "trainer_topk": [ + _with_outputs(request, top_k=top_k) for request in requests + ], + "trainer_target_topk": [ + _with_outputs( + request, + target_tokens=request.target_tokens, + top_k=top_k, + ) + for request in requests + ], + "trainer_hidden": [ + _with_outputs(request, hidden_states=True) for request in requests + ], + "trainer_all_no_logits": [ + _with_outputs( + request, + target_tokens=multi_request.target_tokens, + top_k=top_k, + 
hidden_states=True, + ) + for request, multi_request in zip( + requests, multi_target_requests, strict=True + ) + ], + "trainer_logits": [ + ForwardInput(input_tokens=request.input_tokens, logits=True) + for request in requests + ], + } + if "trainer_topk_sweep" in benchmarks: + for k in _int_values(top_k_values): + trainer_cases[f"trainer_topk_{k}"] = [ + _with_outputs(request, top_k=k) for request in requests + ] + for name, case_requests in trainer_cases.items(): + if name not in benchmarks and not ( + "trainer_topk_sweep" in benchmarks + and name.startswith("trainer_topk_") + ): + continue + output_gb = _request_output_gb( + case_requests, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + metadata[f"{name}_output_gb"] = round(output_gb, 3) + if max_unpacked_output_gb > 0 and output_gb > max_unpacked_output_gb: + metadata[f"{name}_skipped"] = "unpacked_output_cap" + continue + items = [rank._forward_item(request) for request in case_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + name, + case_requests, + _packed_request_stats( + case_requests, items, batch, request_metadata={} + ), + ) + prepared = rank._prepare_packed_forward(batch) + if adapter_slots: + results[f"{name}_ms"] = _bench( + lambda case_requests=case_requests: rank.dp_rank_forward( + case_requests + ), + warmup=warmup, + repeat=repeat, + ) + else: + results[f"{name}_ms"] = _bench( + lambda items=items, prepared=prepared: rank._forward_packed( + items, + prepared, + ), + warmup=warmup, + repeat=repeat, + ) + if "trainer_topk_head" in benchmarks: + case_requests = [ + _with_outputs(request, top_k=top_k) for request in requests + ] + output_gb = _request_output_gb( + case_requests, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + metadata["trainer_topk_head_output_gb"] = round(output_gb, 3) + items = [rank._forward_item(request) for request in case_requests] + batch = 
_pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_topk_head", + case_requests, + _packed_request_stats( + case_requests, items, batch, request_metadata={} + ), + ) + prepared = rank._prepare_packed_forward(batch) + hidden = rank._gather_sequence_parallel_hidden( + rank._decoder_hidden(prepared) + ) + results["trainer_topk_head_ms"] = _bench( + lambda: rank._project_head(items, prepared, hidden), + warmup=warmup, + repeat=repeat, + ) + + if "target_builtin_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_builtin_fwd_bwd_ms"] = _bench( + lambda: _target_builtin_loss( + rank, + target_items, + target_prepared, + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + if "target_builtin_masked_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_builtin_masked_fwd_bwd_ms"] = _bench( + lambda: _target_builtin_masked_loss( + rank, + target_items, + target_prepared, + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + if "target_trainer_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_trainer_fwd_bwd_ms"] = _bench( + lambda: ( + _target_requests_loss(rank, requests) + if adapter_slots + else _target_trainer_loss( + rank, + target_items, + target_prepared, + ) + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + if "target_hidden_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_hidden_fwd_bwd_ms"] = _bench( + lambda: _target_hidden_loss( + rank, + target_items, + target_prepared, + ).backward(), + warmup=warmup, + repeat=repeat, + 
after=rank.zero_grad, + ) + train_step_params = AdamParams(learning_rate=learning_rate) + offload_manager = ( + _make_offload_manager(runtime) if full_step_offload_reload else None + ) + if "target_builtin_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_builtin_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: _target_builtin_loss(rank, target_items, target_prepared), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + if "target_trainer_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_trainer_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: ( + _target_requests_loss(rank, requests) + if adapter_slots + else _target_trainer_loss(rank, target_items, target_prepared) + ), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + if "target_trainer_fixed_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + fixed_stats: list[dict[str, int | bool]] = [] + results["target_trainer_fixed_train_step_ms"] = _bench( + lambda: _fixed_micro_batch_training_step( + rank, + requests, + params=train_step_params, + offload_manager=offload_manager, + loss_kind="target", + stats_sink=fixed_stats, + ), + warmup=warmup, + repeat=repeat, + ) + _record_micro_batch_stats( + metadata, "target_trainer_fixed_train_step", fixed_stats + ) + if "target_trainer_adaptive_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + adaptive_stats: list[dict[str, int | bool]] = [] + results["target_trainer_adaptive_train_step_ms"] = _bench( + lambda: _adaptive_micro_batch_training_step( + rank, + requests, + params=train_step_params, + offload_manager=offload_manager, + loss_kind="target", + 
stats_sink=adaptive_stats, + ), + warmup=warmup, + repeat=repeat, + ) + _record_micro_batch_stats( + metadata, "target_trainer_adaptive_train_step", adaptive_stats + ) + if "target_trainer_adaptive_profile_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + adaptive_stats: list[dict[str, int | bool | float]] = [] + results["target_trainer_adaptive_profile_train_step_ms"] = _bench( + lambda: _profiled_adaptive_micro_batch_training_step( + rank, + requests, + params=train_step_params, + offload_manager=offload_manager, + loss_kind="target", + stats_sink=adaptive_stats, + ), + warmup=warmup, + repeat=repeat, + ) + _record_micro_batch_stats( + metadata, + "target_trainer_adaptive_profile_train_step", + adaptive_stats, + ) + _record_profile_stats( + metadata, + "target_trainer_adaptive_profile_train_step", + adaptive_stats, + ) + if "target_hidden_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + assert target_items is not None and target_prepared is not None + results["target_hidden_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: _target_hidden_loss(rank, target_items, target_prepared), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + if "trainer_multi_target_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + items = [rank._forward_item(request) for request in multi_target_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_multi_target_fwd_bwd", + multi_target_requests, + _packed_request_stats( + multi_target_requests, + items, + batch, + request_metadata={}, + ), + ) + prepared = rank._prepare_packed_forward(batch) + results["trainer_multi_target_fwd_bwd_ms"] = _bench( + lambda: ( + _target_requests_loss(rank, multi_target_requests) + if adapter_slots + else _target_trainer_loss(rank, items, prepared) + ).backward(), + warmup=warmup, + 
repeat=repeat, + after=rank.zero_grad, + ) + if "trainer_multi_target_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + items = [rank._forward_item(request) for request in multi_target_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_multi_target_train_step", + multi_target_requests, + _packed_request_stats( + multi_target_requests, + items, + batch, + request_metadata={}, + ), + ) + prepared = rank._prepare_packed_forward(batch) + results["trainer_multi_target_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: ( + _target_requests_loss(rank, multi_target_requests) + if adapter_slots + else _target_trainer_loss(rank, items, prepared) + ), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + if ( + "trainer_multi_target_fixed_train_step" in benchmarks + or "trainer_multi_target_adaptive_train_step" in benchmarks + ): + items = [rank._forward_item(request) for request in multi_target_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + multi_target_stats = _packed_request_stats( + multi_target_requests, + items, + batch, + request_metadata={}, + ) + if "trainer_multi_target_fixed_train_step" in benchmarks: + register_case( + "trainer_multi_target_fixed_train_step", + multi_target_requests, + multi_target_stats, + ) + for chunk in runtime.model: + chunk.train() + fixed_stats = [] + results["trainer_multi_target_fixed_train_step_ms"] = _bench( + lambda: _fixed_micro_batch_training_step( + rank, + multi_target_requests, + params=train_step_params, + offload_manager=offload_manager, + loss_kind="target", + stats_sink=fixed_stats, + ), + warmup=warmup, + repeat=repeat, + ) + _record_micro_batch_stats( + metadata, + "trainer_multi_target_fixed_train_step", + fixed_stats, + ) + if "trainer_multi_target_adaptive_train_step" in benchmarks: + register_case( + 
"trainer_multi_target_adaptive_train_step", + multi_target_requests, + multi_target_stats, + ) + for chunk in runtime.model: + chunk.train() + adaptive_stats = [] + results["trainer_multi_target_adaptive_train_step_ms"] = _bench( + lambda: _adaptive_micro_batch_training_step( + rank, + multi_target_requests, + params=train_step_params, + offload_manager=offload_manager, + loss_kind="target", + stats_sink=adaptive_stats, + ), + warmup=warmup, + repeat=repeat, + ) + _record_micro_batch_stats( + metadata, + "trainer_multi_target_adaptive_train_step", + adaptive_stats, + ) + if "trainer_topk_fwd_bwd" in benchmarks: + for chunk in runtime.model: + chunk.train() + topk_requests = [ + _with_outputs(request, top_k=top_k) for request in requests + ] + items = [rank._forward_item(request) for request in topk_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_topk_fwd_bwd", + topk_requests, + _packed_request_stats(topk_requests, items, batch, request_metadata={}), + ) + prepared = rank._prepare_packed_forward(batch) + results["trainer_topk_fwd_bwd_ms"] = _bench( + lambda: ( + _topk_requests_loss(rank, topk_requests) + if adapter_slots + else _trainer_topk_loss(rank, items, prepared) + ).backward(), + warmup=warmup, + repeat=repeat, + after=rank.zero_grad, + ) + if "trainer_topk_train_step" in benchmarks: + for chunk in runtime.model: + chunk.train() + topk_requests = [ + _with_outputs(request, top_k=top_k) for request in requests + ] + items = [rank._forward_item(request) for request in topk_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + register_case( + "trainer_topk_train_step", + topk_requests, + _packed_request_stats(topk_requests, items, batch, request_metadata={}), + ) + prepared = rank._prepare_packed_forward(batch) + results["trainer_topk_train_step_ms"] = _bench( + lambda: _training_step( + rank, + lambda: ( + _topk_requests_loss(rank, 
topk_requests) + if adapter_slots + else _trainer_topk_loss(rank, items, prepared) + ), + params=train_step_params, + offload_manager=offload_manager, + ), + warmup=warmup, + repeat=repeat, + ) + if ( + "trainer_topk_fixed_train_step" in benchmarks + or "trainer_topk_adaptive_train_step" in benchmarks + ): + topk_requests = [ + _with_outputs(request, top_k=top_k) for request in requests + ] + items = [rank._forward_item(request) for request in topk_requests] + batch = _pack_forward_items( + items, + max_depth=rank.shared_prefix_max_depth, + ) + topk_stats = _packed_request_stats( + topk_requests, + items, + batch, + request_metadata={}, + ) + if "trainer_topk_fixed_train_step" in benchmarks: + register_case( + "trainer_topk_fixed_train_step", + topk_requests, + topk_stats, + ) + for chunk in runtime.model: + chunk.train() + fixed_stats = [] + results["trainer_topk_fixed_train_step_ms"] = _bench( + lambda: _fixed_micro_batch_training_step( + rank, + topk_requests, + params=train_step_params, + offload_manager=offload_manager, + loss_kind="topk", + stats_sink=fixed_stats, + ), + warmup=warmup, + repeat=repeat, + ) + _record_micro_batch_stats( + metadata, "trainer_topk_fixed_train_step", fixed_stats + ) + if "trainer_topk_adaptive_train_step" in benchmarks: + register_case( + "trainer_topk_adaptive_train_step", + topk_requests, + topk_stats, + ) + for chunk in runtime.model: + chunk.train() + adaptive_stats = [] + results["trainer_topk_adaptive_train_step_ms"] = _bench( + lambda: _adaptive_micro_batch_training_step( + rank, + topk_requests, + params=train_step_params, + offload_manager=offload_manager, + loss_kind="topk", + stats_sink=adaptive_stats, + ), + warmup=warmup, + repeat=repeat, + ) + _record_micro_batch_stats( + metadata, "trainer_topk_adaptive_train_step", adaptive_stats + ) + + if compare_target_correctness and adapter_slots: + metadata["target_correctness_skipped"] = "adapter_slots" + elif compare_target_correctness: + assert target_items is not None and 
target_prepared is not None + metadata.update( + _target_correctness_metrics(rank, target_items, target_prepared) + ) + if run_adapter_sanity and adapter_slots > 0: + metadata.update( + _adapter_sanity_metrics( + rank, + requests, + params=train_step_params, + adapter_slots=adapter_slots, + ) + ) + + memory_tracker.stop() + memory_metadata = _distributed_memory_metadata(memory_tracker) + model_metadata = _model_metadata(runtime, model, layers=layers) + + if dist.get_rank() == 0: + token_rates = _rate_metrics(results, rate_units) + payload = { + "world": dist.get_world_size(), + "tp": int(ps.get_tensor_model_parallel_world_size()), + "cp": int(ps.get_context_parallel_world_size()), + "seq_len": seq_len, + "prefix_families": prefix_families, + "prefix_len": prefix_len, + "mid_prefixes_per_family": mid_prefixes_per_family, + "mid_prefix_len": mid_prefix_len, + "branches_per_prefix": branches_per_prefix, + "completion_len": completion_len, + "head_chunk_tokens": head_chunk_tokens, + "shared_prefix_max_depth": shared_prefix_max_depth, + "warmup": warmup, + "repeat": repeat, + "target_count": target_count, + "top_k": top_k, + "top_k_values": top_k_values, + "max_unpacked_output_gb": max_unpacked_output_gb, + "mask_prefix_targets": mask_prefix_targets, + "workload": workload, + "tree_depth": tree_depth, + "tree_seed": tree_seed, + "tree_duplicate_factor": tree_duplicate_factor, + "adapter_slots": adapter_slots, + "adapter_slot_mode": adapter_slot_mode, + "adapter_slot_rank": adapter_slot_rank, + "adapter_loaded_sites": loaded_sites, + "learning_rate": learning_rate, + "full_step_offload_reload": full_step_offload_reload, + "memory_safety_factor": memory_safety_factor, + "memory_reserve_fraction": memory_reserve_fraction, + "mtp_num_layers": getattr(model_config, "mtp_num_layers", None), + "cross_entropy_loss_fusion": getattr( + model_config, "cross_entropy_loss_fusion", None + ), + "cross_entropy_fusion_impl": getattr( + model_config, "cross_entropy_fusion_impl", None + 
def _requests(
    *,
    seq_len: int,
    prefix_families: int,
    prefix_len: int,
    mid_prefixes_per_family: int,
    mid_prefix_len: int,
    branches_per_prefix: int,
    completion_len: int,
    target_count: int,
    mask_prefix_targets: bool,
    workload: str,
    tree_depth: int,
    tree_seed: int,
    tree_duplicate_factor: int,
) -> tuple[
    list[ForwardInput[torch.Tensor, None, None, None]],
    list[ForwardInput[torch.Tensor, None, None, None]],
    dict[str, int | str],
]:
    """Build the benchmark requests for the selected workload.

    Returns a triple ``(requests, multi_requests, stats)``. Both lists hold
    one ``ForwardInput`` per generated sequence over the same input tokens;
    ``requests`` carries labels built with ``target_count=1`` and
    ``multi_requests`` carries labels built with the caller's
    ``target_count``. ``stats`` records ``request_count`` and
    ``workload_shape``.

    Raises:
        ValueError: if the tree-shaping sizes are out of range (only checked
            when the single-sequence shortcut below is not taken).
    """
    if prefix_families <= 0 and workload == "regular":
        # No prefix families requested: fall back to one flat sequence of
        # ``seq_len`` tokens with nothing shared and nothing masked.
        flat_tokens = torch.arange(seq_len, dtype=torch.long) % 32_000 + 100
        single = ForwardInput(
            input_tokens=flat_tokens,
            target_tokens=_labels(flat_tokens, target_count=1),
        )
        multi = ForwardInput(
            input_tokens=flat_tokens,
            target_tokens=_labels(flat_tokens, target_count=target_count),
        )
        return (
            [single],
            [multi],
            {
                "request_count": 1,
                "workload_shape": "single",
            },
        )

    if not (prefix_len >= 1 and branches_per_prefix >= 1 and completion_len >= 1):
        raise ValueError(
            "prefix_len, branches_per_prefix, and completion_len must be >= 1"
        )
    if not (mid_prefixes_per_family >= 1 and mid_prefix_len >= 0):
        raise ValueError("mid_prefixes_per_family must be >= 1 and mid_prefix_len >= 0")

    token_rows, shared_prefix_lengths, shape_label = _workload_sequences(
        workload=workload,
        seq_len=seq_len,
        # The generators need at least one family even if the caller passed 0.
        prefix_families=max(prefix_families, 1),
        prefix_len=prefix_len,
        mid_prefixes_per_family=mid_prefixes_per_family,
        mid_prefix_len=mid_prefix_len,
        branches_per_prefix=branches_per_prefix,
        completion_len=completion_len,
        tree_depth=tree_depth,
        tree_seed=tree_seed,
        tree_duplicate_factor=tree_duplicate_factor,
    )

    single_target = []
    multi_target = []
    for row, shared_span in zip(token_rows, shared_prefix_lengths, strict=True):
        one_label = _labels(row, target_count=1)
        many_labels = _labels(row, target_count=target_count)
        if mask_prefix_targets and shared_span:
            # -100 marks the shared-prefix positions as non-trainable.
            for label_tensor in (one_label, many_labels):
                label_tensor[:shared_span] = -100
        single_target.append(ForwardInput(input_tokens=row, target_tokens=one_label))
        multi_target.append(ForwardInput(input_tokens=row, target_tokens=many_labels))

    stats = {
        "request_count": len(single_target),
        "workload_shape": shape_label,
    }
    return single_target, multi_target, stats
RuntimeError("adapter slot stress requested, but model has no LoRA sites") + return adapter + + +def _route_adapter_slots( + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + adapter_slots: int, + mode: str, +) -> list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] +]: + if adapter_slots == 0: + return list(requests) + if mode not in {"family", "round_robin", "single", "skewed_random"}: + raise ValueError( + "adapter_slot_mode must be one of: family, round_robin, single, " + "skewed_random" + ) + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=request.target_tokens, + top_k=request.top_k, + logits=request.logits, + hidden_states=request.hidden_states, + checkpoint=f"S{_adapter_slot_index(index, request, adapter_slots, mode)}", + ) + for index, request in enumerate(requests) + ] + + +def _adapter_slot_index( + index: int, + request: ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], + adapter_slots: int, + mode: str, +) -> int: + if mode == "single": + return 0 + if mode == "round_robin": + return index % adapter_slots + if mode == "skewed_random": + bucket = (index * 1103515245 + 12345) & 0x7FFFFFFF + skew = bucket % 100 + if skew < 50: + return 0 + if skew < 75: + return min(1, adapter_slots - 1) + if skew < 90: + return min(2, adapter_slots - 1) + return min(3 + (bucket % max(1, adapter_slots - 3)), adapter_slots - 1) + first_token = ( + int(request.input_tokens[0].item()) if request.input_tokens.numel() else 0 + ) + return (first_token // 10_000_019) % adapter_slots + + +def _with_outputs( + request: ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], + *, + target_tokens: torch.Tensor | None = None, + top_k: int | None = None, + logits: bool = False, + hidden_states: bool = False, +) -> ForwardInput[ + torch.Tensor | 
None, TopK | None, torch.Tensor | None, torch.Tensor | None +]: + return ForwardInput( + input_tokens=request.input_tokens, + target_tokens=target_tokens, + top_k=top_k, + logits=logits, + hidden_states=hidden_states, + checkpoint=request.checkpoint, + lora=request.lora, + ) + + +def _workload_sequences( + *, + workload: str, + seq_len: int, + prefix_families: int, + prefix_len: int, + mid_prefixes_per_family: int, + mid_prefix_len: int, + branches_per_prefix: int, + completion_len: int, + tree_depth: int, + tree_seed: int, + tree_duplicate_factor: int, +) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], str]: + if workload in {"austin_198k", "austin_5k_16x100"}: + return _regular_tree_sequences( + prefix_families=30, + prefix_len=5000, + mid_prefixes_per_family=1, + mid_prefix_len=0, + branches_per_prefix=16, + completion_len=100, + ) + if workload == "austin_varied": + return _austin_varied_sequences() + if workload == "regular": + return _regular_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + if workload == "single": + tokens = torch.arange(seq_len, dtype=torch.long) % 32_000 + 100 + return (tokens,), (0,), "single" + if workload == "long_root": + return _regular_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=1, + mid_prefix_len=0, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + if workload == "long_mid": + return _regular_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=max(2, mid_prefixes_per_family), + mid_prefix_len=max(1, mid_prefix_len), + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + if workload == "many_tiny_leaves": + return _regular_tree_sequences( + prefix_families=prefix_families, + 
prefix_len=prefix_len, + mid_prefixes_per_family=max(1, mid_prefixes_per_family), + mid_prefix_len=max(0, mid_prefix_len), + branches_per_prefix=branches_per_prefix, + completion_len=max(1, completion_len), + ) + if workload == "uneven": + return _uneven_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=max(2, mid_prefixes_per_family), + mid_prefix_len=max(1, mid_prefix_len), + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + if workload == "duplicates": + sequences, shared, shape = _regular_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=max(2, mid_prefixes_per_family), + mid_prefix_len=max(1, mid_prefix_len), + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + factor = max(1, tree_duplicate_factor) + return ( + tuple(sequence for sequence in sequences for _ in range(factor)), + tuple(length for length in shared for _ in range(factor)), + f"{shape}:duplicates={factor}", + ) + if workload == "random": + return _random_tree_sequences( + prefix_families=prefix_families, + prefix_len=prefix_len, + branches_per_prefix=max(2, min(branches_per_prefix, 4)), + completion_len=completion_len, + tree_depth=max(1, tree_depth), + seed=tree_seed, + ) + raise ValueError( + "workload must be one of: regular, single, long_root, long_mid, " + "many_tiny_leaves, uneven, duplicates, random, austin_198k, austin_varied" + ) + + +def _regular_tree_sequences( + *, + prefix_families: int, + prefix_len: int, + mid_prefixes_per_family: int, + mid_prefix_len: int, + branches_per_prefix: int, + completion_len: int, +) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], str]: + nested = mid_prefixes_per_family > 1 and mid_prefix_len > 0 + sequences: list[torch.Tensor] = [] + shared_lengths: list[int] = [] + for family in range(prefix_families): + family_base = family * 10_000_019 + root = _tokens(family_base, prefix_len) + 
def _austin_varied_sequences() -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], str]:
    """Generate the fixed "austin_varied" workload.

    Builds 30 families. Each family has one root prefix whose length varies
    between 4500 and 5500 tokens, and between 10 and 22 branches whose
    completions are 32 to 176 tokens long. All sizes are derived
    arithmetically from the family and branch indices, so the workload is
    identical on every call.

    Returns:
        ``(sequences, shared_lengths, "austin_varied")`` where each shared
        length is the root prefix length of the sequence's family.
    """
    rows: list[torch.Tensor] = []
    root_sizes: list[int] = []
    for family_index in range(30):
        base_offset = family_index * 10_000_019
        root_tokens = _tokens(base_offset, 4500 + ((family_index * 137) % 1001))
        root_size = int(root_tokens.numel())
        fanout = 10 + ((family_index * 7) % 13)
        rows.extend(
            torch.cat(
                (
                    root_tokens,
                    _tokens(
                        base_offset + branch_index * 1009 + 17,
                        32 + ((family_index * 19 + branch_index * 23) % 145),
                    ),
                )
            )
            for branch_index in range(fanout)
        )
        root_sizes.extend([root_size] * fanout)
    return tuple(rows), tuple(root_sizes), "austin_varied"
def _random_tree_sequences(
    *,
    prefix_families: int,
    prefix_len: int,
    branches_per_prefix: int,
    completion_len: int,
    tree_depth: int,
    seed: int,
) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...], str]:
    """Generate a seeded random prefix tree per family.

    Each family is grown recursively: every node appends a fresh token
    segment to its parent's tokens, then either fans out into child nodes
    or, at the last level, into leaf completions. Fan-outs are drawn from
    ``[2, branches_per_prefix]``, so callers must pass
    ``branches_per_prefix >= 2``.

    Returns:
        ``(sequences, shared_lengths, shape)`` where each shared length is
        the token count of the leaf's parent node and ``shape`` is
        ``"random:depth=<tree_depth>"``.
    """
    rng = torch.Generator().manual_seed(seed)
    # One-cell box so the nested helper can advance the token offset.
    offset_cursor = [1]
    rows: list[torch.Tensor] = []
    parent_sizes: list[int] = []

    def draw(lo: int, hi: int) -> int:
        # Inclusive on both ends; torch.randint's upper bound is exclusive.
        sample = torch.randint(lo, hi + 1, (), generator=rng)
        return int(sample.item())

    def fresh_segment(size: int) -> torch.Tensor:
        # Every segment starts at a new offset, spaced 10_000 past the
        # previous segment's end.
        size = max(1, size)
        start = offset_cursor[0]
        offset_cursor[0] = start + size + 10_000
        return _tokens(start, size)

    def segment_size(level: int) -> int:
        # The root always uses the full prefix length (no random draw);
        # deeper levels pick one of five candidate lengths.
        if level == 0:
            return max(1, prefix_len)
        options = (1, 8, 64, max(1, completion_len), max(1, prefix_len // 2))
        return options[draw(0, len(options) - 1)]

    def grow(parent: torch.Tensor, level: int) -> None:
        trunk = torch.cat((parent, fresh_segment(segment_size(level))))
        # Drawn right after the trunk is built, for inner and leaf levels
        # alike, so the random stream is consumed in a fixed order.
        fanout = draw(2, branches_per_prefix)
        if level + 1 < tree_depth:
            for _ in range(fanout):
                grow(trunk, level + 1)
            return
        trunk_size = int(trunk.numel())
        for _ in range(fanout):
            tip = fresh_segment(draw(1, max(1, completion_len)))
            rows.append(torch.cat((trunk, tip)))
            parent_sizes.append(trunk_size)

    for _ in range(prefix_families):
        grow(torch.empty(0, dtype=torch.long), 0)
    return tuple(rows), tuple(parent_sizes), f"random:depth={tree_depth}"
def _gather_planner_metadata(prepared: object) -> dict[str, object]:
    """Collect every rank's planner metrics and merge them on rank 0.

    Every rank must call this, because it performs an
    ``all_gather_object``. Ranks other than 0 return an empty dict.

    On rank 0 the result holds the per-rank GDN and attention token counts,
    the spread (max minus min) of the GDN token counts, and for each
    ``tree_*`` metric reported by at least one rank either its max (keys
    ending in ``_ratio``) or both its sum and max. ``tree_depth_count`` is
    taken from rank 0 only.
    """
    own_metrics = _local_planner_metadata(prepared)
    slots: list[dict[str, object] | None] = [None] * dist.get_world_size()
    dist.all_gather_object(slots, own_metrics)
    if dist.get_rank() != 0:
        return {}

    per_rank = [entry or {} for entry in slots]
    gdn_counts = [int(entry.get("gdn_tokens", 0)) for entry in per_rank]
    attention_counts = [int(entry.get("attention_tokens", 0)) for entry in per_rank]

    summary: dict[str, object] = {
        "planner_rank_gdn_tokens": gdn_counts,
        "planner_rank_attention_tokens": attention_counts,
        "planner_gdn_token_imbalance": max(gdn_counts, default=0)
        - min(gdn_counts, default=0),
    }

    aggregated_keys = (
        "tree_local_bucket_count",
        "tree_chain_bucket_count",
        "tree_local_segment_count",
        "tree_chain_segment_count",
        "tree_local_real_tokens",
        "tree_chain_real_tokens",
        "tree_state_transfer_count",
        "tree_state_transfer_rows",
        "tree_max_padding_ratio",
    )
    for metric in aggregated_keys:
        reported = [entry[metric] for entry in per_rank if metric in entry]
        if reported:
            if metric.endswith("_ratio"):
                # A ratio cannot be meaningfully summed; keep only the worst.
                worst = max(float(item) for item in reported)
                summary[f"planner_{metric}_max"] = round(worst, 3)
            else:
                as_ints = [int(item) for item in reported]
                summary[f"planner_{metric}_sum"] = int(sum(as_ints))
                summary[f"planner_{metric}_max"] = int(max(as_ints))

    first_rank = per_rank[0] if per_rank else {}
    if "tree_depth_count" in first_rank:
        summary["planner_tree_depth_count"] = first_rank["tree_depth_count"]
    return summary
bucket.segment_count for bucket in local_buckets + ), + "tree_chain_segment_count": sum( + bucket.segment_count for bucket in chain_buckets + ), + "tree_local_real_tokens": sum( + bucket.real_token_count for bucket in local_buckets + ), + "tree_chain_real_tokens": sum( + bucket.real_token_count for bucket in chain_buckets + ), + "tree_state_transfer_count": sum( + len(transfers) for transfers in transfers_by_depth + ), + "tree_state_transfer_rows": sum( + len(transfer.family_indices) + for transfers in transfers_by_depth + for transfer in transfers + ), + "tree_max_padding_ratio": max(padding_ratios, default=1.0), + } + + +def _tokens(offset: int, length: int) -> torch.Tensor: + return (torch.arange(length, dtype=torch.long) + offset) % 32_000 + 100 + + +def _int_values(value: str) -> list[int]: + values = [int(part) for part in value.split(",") if part.strip()] + if not values or any(item < 1 for item in values): + raise ValueError("top_k_values must contain positive integers") + return values + + +def _labels(tokens: torch.Tensor, *, target_count: int) -> torch.Tensor: + labels = torch.stack( + [((tokens * 7 + 3 + index) % 32_000) for index in range(target_count)], + dim=1, + ) + if target_count > 1: + labels[::17, -1] = -100 + return labels + return labels[:, 0] + + +class _CudaMemoryTracker: + def __init__(self, *, device_index: int, sample_interval_s: float) -> None: + self.device_index = device_index + self.sample_interval_s = sample_interval_s + self.process_peak_bytes = 0 + self.allocated_peak_bytes = 0 + self.reserved_peak_bytes = 0 + self._stop = threading.Event() + self._thread: threading.Thread | None = None + + def start(self) -> None: + if not torch.cuda.is_available(): + return + torch.cuda.reset_peak_memory_stats() + self._sample() + if self.sample_interval_s <= 0: + return + self._thread = threading.Thread(target=self._poll, daemon=True) + self._thread.start() + + def stop(self) -> None: + if not torch.cuda.is_available(): + return + 
self._stop.set() + if self._thread is not None: + self._thread.join(timeout=1.0) + torch.cuda.synchronize() + self._sample() + self.allocated_peak_bytes = max( + self.allocated_peak_bytes, + int(torch.cuda.max_memory_allocated()), + ) + self.reserved_peak_bytes = max( + self.reserved_peak_bytes, + int(torch.cuda.max_memory_reserved()), + ) + + def _poll(self) -> None: + while not self._stop.wait(self.sample_interval_s): + self._sample() + + def _sample(self) -> None: + self.process_peak_bytes = max( + self.process_peak_bytes, + _current_process_gpu_memory_bytes(self.device_index), + ) + self.allocated_peak_bytes = max( + self.allocated_peak_bytes, + int(torch.cuda.memory_allocated()) if torch.cuda.is_available() else 0, + ) + self.reserved_peak_bytes = max( + self.reserved_peak_bytes, + int(torch.cuda.memory_reserved()) if torch.cuda.is_available() else 0, + ) + + +def _current_process_gpu_memory_bytes(device_index: int) -> int: + try: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(device_index) + pid = os.getpid() + processes = list(pynvml.nvmlDeviceGetComputeRunningProcesses(handle)) + with suppress(Exception): + processes.extend(pynvml.nvmlDeviceGetGraphicsRunningProcesses(handle)) + for process in processes: + if int(process.pid) == pid: + return int(process.usedGpuMemory) + except Exception: + return 0 + return 0 + + +def _distributed_memory_metadata(tracker: _CudaMemoryTracker) -> dict[str, float]: + values = torch.tensor( + [ + tracker.allocated_peak_bytes, + tracker.reserved_peak_bytes, + tracker.process_peak_bytes, + ], + device="cuda", + dtype=torch.float64, + ) + dist.all_reduce(values, op=dist.ReduceOp.MAX) + return { + "peak_memory_allocated_gb": round(float(values[0].item()) / 1024**3, 3), + "peak_memory_reserved_gb": round(float(values[1].item()) / 1024**3, 3), + "peak_memory_process_gb": round(float(values[2].item()) / 1024**3, 3), + "peak_memory_gb": round(float(values[0].item()) / 1024**3, 3), + } + + +def 
def _model_metadata(runtime: object, model_name: str, *, layers: int) -> dict[str, Any]:
    """Summarise the model size and LoRA site count for benchmark metadata.

    Parameters, trainable parameters and ``LoRA`` modules are counted over
    every chunk in ``runtime.model``. The three counts are then MAX-reduced
    across ranks, so the ``rank_local_*`` values returned are the largest
    per-rank counts, not this rank's own.

    Requires CUDA and an initialised process group.
    """
    from art.megatron.lora import LoRA

    provider = getattr(runtime, "provider")
    language_model = _language_model(getattr(runtime, "model")[0])
    config = getattr(language_model, "config", None)

    param_total = 0
    param_trainable = 0
    lora_count = 0
    for chunk in runtime.model:
        for parameter in chunk.parameters():
            size = int(parameter.numel())
            param_total += size
            if parameter.requires_grad:
                param_trainable += size
        lora_count += sum(1 for module in chunk.modules() if isinstance(module, LoRA))

    counts = torch.tensor(
        [param_total, param_trainable, lora_count],
        device="cuda",
        dtype=torch.float64,
    )
    dist.all_reduce(counts, op=dist.ReduceOp.MAX)
    reduced_total, reduced_trainable, reduced_lora = (
        int(counts[index].item()) for index in range(3)
    )
    return {
        "model": model_name,
        "layers_arg": layers,
        "provider_num_layers": getattr(provider, "num_layers", None),
        "config_num_layers": getattr(config, "num_layers", None),
        "rank_local_param_count": reduced_total,
        "rank_local_trainable_param_count": reduced_trainable,
        "rank_local_lora_site_count": reduced_lora,
    }
op=dist.ReduceOp.MAX) + return round(float(elapsed.item()), 3) + + +def _builtin( + rank: TrainerRank, + prepared: object, + labels: torch.Tensor | None, +) -> torch.Tensor: + from art.megatron.train import _placeholder_attention_mask + + return rank.runtime.model[0]( + input_ids=prepared.tokens, + position_ids=prepared.position_ids, + attention_mask=_placeholder_attention_mask(rank.device), + labels=labels, + packed_seq_params=prepared.packed_seq_params, + **rank._handler().get_forward_kwargs( + rank.runtime.model[0], + attention_bias=prepared.attention_state, + ), + ) + + +def _full_logits(rank: TrainerRank, prepared: object) -> torch.Tensor: + logits = rank._gather_tensor_parallel_logits(_builtin(rank, prepared, None)) + return _batch_seq_logits(logits, seq_len=int(prepared.tokens.shape[1])) + + +def _target_builtin_loss( + rank: TrainerRank, + items: object, + prepared: object, +) -> torch.Tensor: + return _builtin(rank, prepared, _packed_labels(items, prepared)).float().sum() + + +def _target_builtin_masked_loss( + rank: TrainerRank, + items: object, + prepared: object, +) -> torch.Tensor: + labels = _packed_labels(items, prepared) + per_token_loss = _builtin(rank, prepared, labels).float().reshape(-1) + valid = labels.reshape(-1) != -100 + return per_token_loss[valid].sum() + per_token_loss.sum() * 0.0 + + +def _target_hidden_loss( + rank: TrainerRank, + items: object, + prepared: object, +) -> torch.Tensor: + hidden = rank._gather_sequence_parallel_hidden(rank._decoder_hidden(prepared)) + outputs = rank._project_head(items, prepared, hidden) + losses = [ + -output.target_logprobs.sum() + for output in outputs + if output.target_logprobs is not None + ] + if not losses: + raise RuntimeError("target logprobs were not produced") + return torch.stack(losses).sum() + + +def _target_trainer_loss( + rank: TrainerRank, + items: object, + prepared: object, +) -> torch.Tensor: + outputs = rank._forward_packed(items, prepared) + losses = [ + 
-output.target_logprobs.sum() + for output in outputs + if output.target_logprobs is not None + ] + if not losses: + raise RuntimeError("target logprobs were not produced") + return torch.stack(losses).sum() + + +def _target_requests_loss( + rank: TrainerRank, + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> torch.Tensor: + outputs = rank.dp_rank_forward(requests) + losses = [ + -output.target_logprobs.sum() + for output in outputs + if output.target_logprobs is not None + ] + if not losses: + raise RuntimeError("target logprobs were not produced") + return torch.stack(losses).sum() + + +def _trainer_topk_loss( + rank: TrainerRank, + items: object, + prepared: object, +) -> torch.Tensor: + outputs = rank._forward_packed(items, prepared) + losses = [ + -output.top_k.logprobs.sum() for output in outputs if output.top_k is not None + ] + if not losses: + raise RuntimeError("top_k logprobs were not produced") + return torch.stack(losses).sum() + + +def _topk_requests_loss( + rank: TrainerRank, + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> torch.Tensor: + outputs = rank.dp_rank_forward(requests) + losses = [ + -output.top_k.logprobs.sum() for output in outputs if output.top_k is not None + ] + if not losses: + raise RuntimeError("top_k logprobs were not produced") + return torch.stack(losses).sum() + + +def _fixed_micro_batch_training_step( + rank: TrainerRank, + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + params: AdamParams, + offload_manager: object | None, + loss_kind: str, + stats_sink: list[dict[str, int | bool]], +) -> dict[str, float]: + def body() -> dict[str, float]: + return _fixed_micro_batch_training_step_body( + rank, + requests, + params=params, + loss_kind=loss_kind, + stats_sink=stats_sink, + ) + + if 
def _fixed_micro_batch_training_step_body(
    rank: TrainerRank,
    requests: Sequence[
        ForwardInput[
            torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None
        ]
    ],
    *,
    params: AdamParams,
    loss_kind: str,
    stats_sink: list[dict[str, int | bool]],
) -> dict[str, float]:
    """Run one optimizer step over fixed windows of ``dp_size`` requests.

    ``requests`` is walked in consecutive windows of ``dp_size`` items. In
    each window this rank forwards the request at offset ``dp_rank``, if the
    window has one, and back-propagates the micro-batch loss when it requires
    grad. One stats row per window replaces the contents of ``stats_sink``
    before the optimizer step.

    Returns:
        The value returned by ``rank.optim_step``.
    """
    rank.zero_grad()
    dp_rank, dp_size = rank._dp_rank_and_size()
    total = len(requests)
    window_rows: list[dict[str, int | bool]] = []
    for window_start in range(0, total, dp_size):
        window_stop = min(window_start + dp_size, total)
        mine = [
            requests[position]
            for position in range(window_start + dp_rank, window_stop, dp_size)
        ]
        window_loss = _micro_batch_loss(
            rank, rank.dp_rank_forward(mine), loss_kind=loss_kind
        )
        if window_loss.requires_grad:
            window_loss.backward()
        # Fixed windows are not packed, so both token columns report the
        # logical input token count.
        token_count = _logical_input_tokens(mine)
        window_rows.append(
            {
                "global_count": window_stop - window_start,
                "local_count": len(mine),
                "packed_tokens": token_count,
                "logical_tokens": token_count,
                "rejected_candidates": 0,
                "cold_start": False,
            }
        )
    stats_sink[:] = window_rows
    return rank.optim_step(params=params, scale_grads=1.0)
def _profiled_adaptive_micro_batch_training_step(
    rank: TrainerRank,
    requests: Sequence[
        ForwardInput[
            torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None
        ]
    ],
    *,
    params: AdamParams,
    offload_manager: object | None,
    loss_kind: str,
    stats_sink: list[dict[str, int | bool | float]],
) -> dict[str, float]:
    """Run the profiled adaptive training step, optionally as an offload job.

    When ``offload_manager`` is given, the step body runs inside
    ``offload_manager.job()``. Otherwise it runs directly.

    Returns:
        The value returned by the step body.
    """
    body_kwargs = {
        "params": params,
        "loss_kind": loss_kind,
        "stats_sink": stats_sink,
    }
    if offload_manager is not None:
        with offload_manager.job():  # type: ignore[attr-defined]
            return _profiled_adaptive_micro_batch_training_step_body(
                rank, requests, **body_kwargs
            )
    return _profiled_adaptive_micro_batch_training_step_body(
        rank, requests, **body_kwargs
    )
Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + *, + params: AdamParams, + loss_kind: str, + stats_sink: list[dict[str, int | bool | float]], +) -> dict[str, float]: + rank.zero_grad() + items = list(requests) + rank._validate_replicated_top_level_count(len(items)) + start = 0 + stats: list[dict[str, int | bool | float]] = [] + step_start = time.perf_counter() + while start < len(items): + with _profile_adaptive_selection(rank) as select_profile: + candidate, select_ms = _timed_cuda( + rank, lambda: rank._select_next_micro_batch(items, start) + ) + select_profile["select_plan_residual_ms"] = max( + 0.0, + select_profile["select_plan_ms"] + - select_profile["select_forward_item_ms"] + - select_profile["select_pack_ms"] + - select_profile["select_output_estimate_ms"] + - select_profile["select_signature_ms"], + ) + select_profile["select_memory_check_residual_ms"] = max( + 0.0, + select_profile["select_memory_check_ms"] + - select_profile["select_memory_estimate_ms"] + - select_profile["select_available_memory_ms"], + ) + select_profile["select_residual_ms"] = max( + 0.0, + select_ms + - select_profile["select_estimate_ms"] + - select_profile["select_plan_ms"] + - select_profile["select_memory_check_ms"] + - select_profile["select_profile_check_ms"], + ) + flat_outputs, execute_ms = _timed_cuda( + rank, + lambda: rank._run_flat_plan_with_memory_tracking( + candidate.plan, + context="target_trainer_adaptive_profile_train_step", + ), + ) + + def unflatten_outputs() -> list[object]: + flat_iter = iter(flat_outputs) + return [_unflatten(item, flat_iter) for item in candidate.inputs] + + outputs, unflatten_ms = _timed_cuda( + rank, + unflatten_outputs, + ) + loss, loss_ms = _timed_cuda( + rank, lambda: _micro_batch_loss(rank, outputs, loss_kind=loss_kind) + ) + if loss.requires_grad: + _, backward_ms = _timed_cuda(rank, loss.backward) + else: + backward_ms = 0.0 + row = { + "global_count": 
int(candidate.stats_global_count), + "local_count": int(len(candidate.inputs)), + "packed_tokens": int(candidate.plan.packed_tokens), + "logical_tokens": int(candidate.plan.logical_tokens), + "estimated_required_bytes": int(candidate.check.estimated_required_bytes), + "available_bytes": int(candidate.check.available_bytes), + "rejected_candidates": int(candidate.rejected_candidates), + "cold_start": bool(candidate.cold_start), + "select_ms": select_ms, + "execute_ms": execute_ms, + "unflatten_ms": unflatten_ms, + "loss_ms": loss_ms, + "backward_ms": backward_ms, + "optim_ms": 0.0, + **select_profile, + } + stats.append(row) + stop = start + candidate.stats_global_count + if stop < len(items): + rank._last_global_micro_batch_size = max( + rank._last_global_micro_batch_size or 0, + candidate.stats_global_count, + ) + _emit_adaptive_progress( + "target_trainer_adaptive_profile_train_step_window", + { + **row, + "window_index": len(stats) - 1, + "global_start": int(start), + "global_stop": int(stop), + "remembered_window": int(rank._last_global_micro_batch_size or 0), + "elapsed_ms": (time.perf_counter() - step_start) * 1000.0, + }, + ) + start = stop + metrics, optim_ms = _timed_cuda( + rank, lambda: rank.optim_step(params=params, scale_grads=1.0) + ) + if stats: + stats[-1]["optim_ms"] = optim_ms + stats_sink[:] = stats + return metrics + + +def _emit_adaptive_progress(event: str, row: dict[str, object]) -> None: + if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0: + return + path = os.environ.get("ART_TRAINER_RANK_PROGRESS_JSONL") + if not path: + return + payload = {"event": event, **row} + line = json.dumps(payload, sort_keys=True) + print(line, flush=True) + progress_path = Path(path) + progress_path.parent.mkdir(parents=True, exist_ok=True) + with progress_path.open("a") as handle: + handle.write(line + "\n") + + +@contextmanager +def _profile_adaptive_selection(rank: TrainerRank) -> Any: + stats = { + "select_plan_ms": 0.0, + 
"select_plan_calls": 0, + "select_forward_item_ms": 0.0, + "select_forward_item_calls": 0, + "select_pack_ms": 0.0, + "select_pack_calls": 0, + "select_estimate_ms": 0.0, + "select_estimate_calls": 0, + "select_plan_lookup_calls": 0, + "select_plan_cache_hit_calls": 0, + "select_plan_cache_miss_calls": 0, + "select_estimate_lookup_calls": 0, + "select_estimate_cache_hit_calls": 0, + "select_estimate_cache_miss_calls": 0, + "select_output_estimate_ms": 0.0, + "select_output_estimate_calls": 0, + "select_signature_ms": 0.0, + "select_signature_calls": 0, + "select_memory_check_ms": 0.0, + "select_memory_check_calls": 0, + "select_memory_estimate_ms": 0.0, + "select_memory_estimate_calls": 0, + "select_available_memory_ms": 0.0, + "select_available_memory_calls": 0, + "select_profile_check_ms": 0.0, + "select_profile_check_calls": 0, + } + + def timed( + key: str, + calls_key: str, + fn: Callable[..., object], + *args: object, + **kwargs: object, + ) -> object: + start = time.perf_counter() + try: + return fn(*args, **kwargs) + finally: + stats[key] += (time.perf_counter() - start) * 1000.0 + stats[calls_key] += 1 + + original_plan = rank._plan_flat_forward + original_cached_plan = rank._cached_adaptive_plan + original_estimate = rank._estimate_flat_forward + original_cached_estimate = rank._cached_adaptive_estimate + original_forward_item = rank._forward_item + original_pack = trainer_rank_module.pack_shared_prefixes + original_output_estimate = rank._estimate_group_request_output_bytes + original_signature = rank._memory_signature_from_requests + original_memory_check = rank._memory_check + original_memory_estimate = rank._estimate_required_memory_bytes_from_values + original_available = rank._available_memory_bytes + original_profile_check = rank._all_ranks_have_memory_profile + + def plan_wrapper(requests: object) -> object: + return timed("select_plan_ms", "select_plan_calls", original_plan, requests) + + def cached_plan_wrapper(*args: object, **kwargs: object) 
-> object: + stats["select_plan_lookup_calls"] += 1 + before = stats["select_plan_calls"] + result = original_cached_plan(*args, **kwargs) + if stats["select_plan_calls"] == before: + stats["select_plan_cache_hit_calls"] += 1 + else: + stats["select_plan_cache_miss_calls"] += 1 + return result + + def estimate_wrapper(requests: object) -> object: + return timed( + "select_estimate_ms", + "select_estimate_calls", + original_estimate, + requests, + ) + + def cached_estimate_wrapper(*args: object, **kwargs: object) -> object: + stats["select_estimate_lookup_calls"] += 1 + before = stats["select_estimate_calls"] + result = original_cached_estimate(*args, **kwargs) + if stats["select_estimate_calls"] == before: + stats["select_estimate_cache_hit_calls"] += 1 + else: + stats["select_estimate_cache_miss_calls"] += 1 + return result + + def forward_item_wrapper(request: object) -> object: + return timed( + "select_forward_item_ms", + "select_forward_item_calls", + original_forward_item, + request, + ) + + def pack_wrapper(*args: object, **kwargs: object) -> object: + start = time.perf_counter() + try: + return original_pack(*args, **kwargs) + finally: + stats["select_pack_ms"] += (time.perf_counter() - start) * 1000.0 + stats["select_pack_calls"] += 1 + + def output_estimate_wrapper(items: object) -> object: + return timed( + "select_output_estimate_ms", + "select_output_estimate_calls", + original_output_estimate, + items, + ) + + def signature_wrapper(*args: object, **kwargs: object) -> object: + return timed( + "select_signature_ms", + "select_signature_calls", + original_signature, + *args, + **kwargs, + ) + + def memory_check_wrapper(plan: object) -> object: + return timed( + "select_memory_check_ms", + "select_memory_check_calls", + original_memory_check, + plan, + ) + + def memory_estimate_wrapper(*args: object, **kwargs: object) -> object: + return timed( + "select_memory_estimate_ms", + "select_memory_estimate_calls", + original_memory_estimate, + *args, + 
**kwargs, + ) + + def available_wrapper() -> object: + return timed( + "select_available_memory_ms", + "select_available_memory_calls", + original_available, + ) + + def profile_check_wrapper(*args: object, **kwargs: object) -> object: + return timed( + "select_profile_check_ms", + "select_profile_check_calls", + original_profile_check, + *args, + **kwargs, + ) + + rank._plan_flat_forward = plan_wrapper # type: ignore[method-assign] + rank._cached_adaptive_plan = cached_plan_wrapper # type: ignore[method-assign] + rank._estimate_flat_forward = estimate_wrapper # type: ignore[method-assign] + rank._cached_adaptive_estimate = cached_estimate_wrapper # type: ignore[method-assign] + rank._forward_item = forward_item_wrapper # type: ignore[method-assign] + trainer_rank_module.pack_shared_prefixes = pack_wrapper # type: ignore[assignment] + rank._estimate_group_request_output_bytes = output_estimate_wrapper # type: ignore[method-assign] + rank._memory_signature_from_requests = signature_wrapper # type: ignore[method-assign] + rank._memory_check = memory_check_wrapper # type: ignore[method-assign] + rank._estimate_required_memory_bytes_from_values = memory_estimate_wrapper # type: ignore[method-assign] + rank._available_memory_bytes = available_wrapper # type: ignore[method-assign] + rank._all_ranks_have_memory_profile = profile_check_wrapper # type: ignore[method-assign] + try: + yield stats + finally: + rank._plan_flat_forward = original_plan # type: ignore[method-assign] + rank._cached_adaptive_plan = original_cached_plan # type: ignore[method-assign] + rank._estimate_flat_forward = original_estimate # type: ignore[method-assign] + rank._cached_adaptive_estimate = original_cached_estimate # type: ignore[method-assign] + rank._forward_item = original_forward_item # type: ignore[method-assign] + trainer_rank_module.pack_shared_prefixes = original_pack # type: ignore[assignment] + rank._estimate_group_request_output_bytes = original_output_estimate # type: 
ignore[method-assign] + rank._memory_signature_from_requests = original_signature # type: ignore[method-assign] + rank._memory_check = original_memory_check # type: ignore[method-assign] + rank._estimate_required_memory_bytes_from_values = original_memory_estimate # type: ignore[method-assign] + rank._available_memory_bytes = original_available # type: ignore[method-assign] + rank._all_ranks_have_memory_profile = original_profile_check # type: ignore[method-assign] + + +def _timed_cuda( + rank: TrainerRank, + fn: Callable[[], object], +) -> tuple[object, float]: + _sync_cuda(rank) + start = time.perf_counter() + result = fn() + _sync_cuda(rank) + return result, (time.perf_counter() - start) * 1000.0 + + +def _sync_cuda(rank: TrainerRank) -> None: + if torch.cuda.is_available() and rank.device.type == "cuda": + torch.cuda.synchronize(rank.device) + + +def _micro_batch_loss( + rank: TrainerRank, + outputs: object, + *, + loss_kind: str, +) -> torch.Tensor: + losses: list[torch.Tensor] = [] + for output in _iter_outputs(outputs): + if loss_kind == "target": + target_logprobs = getattr(output, "target_logprobs", None) + if target_logprobs is not None: + losses.append(-target_logprobs.sum()) + elif loss_kind == "topk": + top_k = getattr(output, "top_k", None) + if top_k is not None: + losses.append(-top_k.logprobs.sum()) + else: + raise ValueError(f"unknown loss_kind: {loss_kind}") + if not losses: + return torch.tensor(0.0, device=rank.device) + return torch.stack(losses).sum() + + +def _iter_outputs(value: object) -> Sequence[object]: + if hasattr(value, "target_logprobs") and hasattr(value, "top_k"): + return (value,) + if isinstance(value, Sequence): + outputs: list[object] = [] + for item in value: + outputs.extend(_iter_outputs(item)) + return outputs + raise TypeError(f"unexpected TrainerRank output value: {type(value)!r}") + + +def _logical_input_tokens( + requests: Sequence[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | 
None + ] + ], +) -> int: + return sum( + int(request.input_tokens.numel()) + for request in requests + if request.input_tokens is not None + ) + + +def _record_micro_batch_stats( + metadata: dict[str, object], + name: str, + stats: Sequence[dict[str, int | bool | float]], +) -> None: + if not stats: + metadata[f"{name}_micro_window_count"] = 0 + return + global_counts = [int(stat["global_count"]) for stat in stats] + local_counts = [int(stat["local_count"]) for stat in stats] + packed_tokens = [int(stat["packed_tokens"]) for stat in stats] + rejected = [int(stat["rejected_candidates"]) for stat in stats] + estimated_required = [ + int(stat.get("estimated_required_bytes", 0)) for stat in stats + ] + available = [int(stat.get("available_bytes", 0)) for stat in stats] + metadata[f"{name}_micro_window_count"] = len(stats) + metadata[f"{name}_micro_global_count_first"] = global_counts[0] + metadata[f"{name}_micro_global_count_last"] = global_counts[-1] + metadata[f"{name}_micro_global_count_min"] = min(global_counts) + metadata[f"{name}_micro_global_count_max"] = max(global_counts) + metadata[f"{name}_micro_local_count_min"] = min(local_counts) + metadata[f"{name}_micro_local_count_max"] = max(local_counts) + metadata[f"{name}_micro_packed_tokens_min"] = min(packed_tokens) + metadata[f"{name}_micro_packed_tokens_max"] = max(packed_tokens) + metadata[f"{name}_micro_rejected_candidates_total"] = sum(rejected) + metadata[f"{name}_micro_estimated_required_gb_max"] = round( + max(estimated_required) / 1024**3, 3 + ) + metadata[f"{name}_micro_available_gb_min"] = round(min(available) / 1024**3, 3) + metadata[f"{name}_micro_cold_start_count"] = sum( + int(bool(stat["cold_start"])) for stat in stats + ) + metadata[f"{name}_micro_global_counts_head"] = ",".join( + str(count) for count in global_counts[:8] + ) + + +def _record_profile_stats( + metadata: dict[str, object], + name: str, + stats: Sequence[dict[str, int | bool | float]], +) -> None: + fields = sorted( + { + key + for 
stat in stats + for key, value in stat.items() + if key.endswith("_ms") and isinstance(value, int | float) + } + ) + for field in fields: + total = sum(float(stat.get(field, 0.0)) for stat in stats) + metadata[f"{name}_{field}_sum"] = round(total, 3) + metadata[f"{name}_{field}_max"] = round( + max((float(stat.get(field, 0.0)) for stat in stats), default=0.0), + 3, + ) + call_fields = sorted( + { + key + for stat in stats + for key, value in stat.items() + if key.endswith("_calls") and isinstance(value, int | float) + } + ) + for field in call_fields: + metadata[f"{name}_{field}_sum"] = int( + sum(int(stat.get(field, 0)) for stat in stats) + ) + metadata[f"{name}_{field}_max"] = int( + max((int(stat.get(field, 0)) for stat in stats), default=0) + ) + + +def _training_step( + rank: TrainerRank, + loss_fn: Callable[[], torch.Tensor], + *, + params: AdamParams, + offload_manager: object | None, +) -> dict[str, float]: + if offload_manager is None: + return _training_step_body(rank, loss_fn, params=params) + with offload_manager.job(): # type: ignore[attr-defined] + return _training_step_body(rank, loss_fn, params=params) + + +def _training_step_body( + rank: TrainerRank, + loss_fn: Callable[[], torch.Tensor], + *, + params: AdamParams, +) -> dict[str, float]: + rank.zero_grad() + loss = loss_fn() + loss.backward() + return rank.optim_step(params=params, scale_grads=1.0) + + +def _make_offload_manager(runtime: object) -> object: + from art.megatron.training.streaming_weight_offload import ( + StreamingWeightOffloadConfig, + ) + from art.megatron.training.weight_offload import WeightOffloadManager + + manager = WeightOffloadManager.from_config( + model=getattr(runtime, "model"), + rank=dist.get_rank(), + compile_enabled=bool(getattr(runtime, "transformer_layers_compiled", False)), + offload_between_jobs=True, + streaming_config=StreamingWeightOffloadConfig(enabled=False), + ) + manager.install() + manager.after_job() + return manager + + +def 
def _native_target_logprobs(
    rank: TrainerRank,
    items: object,
    prepared: object,
    labels: torch.Tensor,
) -> list[torch.Tensor]:
    """Compute per-item target logprobs from the model's built-in loss path.

    The first model chunk is called with ``labels``. Its output is negated
    and flattened, then split per item using ``prepared.positions_by_item``.
    Entries whose item label is ``-100`` are zeroed.

    Returns:
        One tensor per item, in the order of ``items``.

    Raises:
        RuntimeError: if any item has no labels.
    """
    from art.megatron.train import _placeholder_attention_mask

    device = rank.device
    model_chunk = rank.runtime.model[0]
    placeholder_mask = _placeholder_attention_mask(device)
    extra_kwargs = rank._handler().get_forward_kwargs(
        model_chunk,
        attention_bias=prepared.attention_state,
    )
    per_token_loss = model_chunk(
        input_ids=prepared.tokens,
        position_ids=prepared.position_ids,
        attention_mask=placeholder_mask,
        labels=labels,
        packed_seq_params=prepared.packed_seq_params,
        **extra_kwargs,
    )
    # Negated so that the values compare directly against target logprobs.
    flat_logprobs = -per_token_loss.reshape(-1)

    per_item: list[torch.Tensor] = []
    for item, positions, source_positions in zip(
        items,
        prepared.positions_by_item,
        prepared.source_positions_by_item,
        strict=True,
    ):
        if item.labels is None:
            raise RuntimeError("native target oracle requires labels")
        selected_labels = item.labels.to(device=device).index_select(
            0, source_positions.to(device=device)
        )
        selected_logprobs = flat_logprobs.index_select(
            0, positions.to(device=device)
        )
        per_item.append(selected_logprobs.masked_fill(selected_labels == -100, 0.0))
    return per_item
base_output.target_logprobs.float()) + .abs() + .max() + .item() + ) + + slot_params = list(rank._checkpoint_slot_params_by_name["S0"]) + other_params = ( + list(rank._checkpoint_slot_params_by_name["S1"]) if adapter_slots > 1 else [] + ) + before = [param.detach().clone() for param in slot_params] + other_before = [param.detach().clone() for param in other_params] + for chunk in rank.runtime.model: + chunk.train() + rank.zero_grad() + loss = _target_requests_loss(rank, [slot_request]) + loss.backward() + grad_sq = torch.tensor(0.0, device=rank.device) + for param in slot_params: + if param.grad is not None: + grad_sq = grad_sq + param.grad.detach().float().square().sum() + grad_norm = torch.sqrt(grad_sq) + rank.optim_step(params=params, checkpoints=["S0"]) + slot_delta = sum( + float((param.detach().float() - old.float()).abs().sum().item()) + for param, old in zip(slot_params, before, strict=True) + ) + other_delta = sum( + float((param.detach().float() - old.float()).abs().sum().item()) + for param, old in zip(other_params, other_before, strict=True) + ) + values = torch.tensor( + [output_diff, output_max, float(grad_norm.item()), slot_delta, other_delta], + device=rank.device, + ) + dist.all_reduce(values, op=dist.ReduceOp.MAX) + return { + "adapter_sanity_output_mean_abs_pct": float(values[0].item()), + "adapter_sanity_output_max_abs_diff": float(values[1].item()), + "adapter_sanity_grad_norm": float(values[2].item()), + "adapter_sanity_stepped_slot_delta": float(values[3].item()), + "adapter_sanity_unselected_slot_delta": float(values[4].item()), + } + + +def _runtime_output_shape(runtime: object) -> tuple[int, int, int]: + provider = getattr(runtime, "provider") + model = _language_model(getattr(runtime, "model")[0]) + hidden_size = int( + getattr(provider, "hidden_size", None) + or getattr(getattr(model, "config", None), "hidden_size", 0) + ) + vocab_size = int( + getattr(getattr(model, "config", None), "padded_vocab_size", None) + or getattr(model, 
def _request_output_bytes(
    request: ForwardInput[
        torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None
    ],
    *,
    hidden_size: int,
    vocab_size: int,
    dtype_size: int,
) -> int:
    """Estimate the number of output bytes a single request asks for.

    Each requested output kind adds its own term:

    * target tokens: 4 bytes per target token;
    * top-k: 12 bytes per (input token, k) entry;
    * logits: ``vocab_size * dtype_size`` bytes per input token;
    * hidden states: ``hidden_size * dtype_size`` bytes per input token.

    Args:
        request: Request whose ``input_tokens``, ``target_tokens``, ``top_k``,
            ``logits`` and ``hidden_states`` fields are inspected.
        hidden_size: Width of one hidden-state row.
        vocab_size: Width of one logits row.
        dtype_size: Bytes per logits / hidden-state element.

    Returns:
        The summed byte estimate (0 when nothing is requested).
    """
    token_count = int(request.input_tokens.numel())
    contributions: list[int] = []
    if request.target_tokens is not None:
        contributions.append(4 * int(request.target_tokens.numel()))
    if request.top_k is not None:
        # 4 + 8 bytes per entry; presumably a 4-byte value plus an 8-byte
        # index - TODO confirm against the TopK output layout.
        contributions.append(token_count * int(request.top_k) * (4 + 8))
    if request.logits:
        contributions.append(token_count * vocab_size * dtype_size)
    if request.hidden_states:
        contributions.append(token_count * hidden_size * dtype_size)
    return sum(contributions)
def _rate_metrics(
    results: dict[str, float],
    units_by_name: dict[str, dict[str, int]],
) -> dict[str, float]:
    """Convert per-case millisecond timings into per-second throughput metrics.

    Args:
        results: Maps a timing key (a case name, optionally suffixed with
            ``"_ms"``) to a duration in milliseconds. Non-positive durations
            are ignored.
        units_by_name: Maps a case name to its unit counts
            (``packed_tokens``, ``logical_tokens``, ``target_values``,
            ``output_bytes``). Missing or non-positive counts are ignored.

    Returns:
        A dict keyed ``"<name>_<suffix>"`` holding units per second rounded to
        three decimals; ``output_bytes`` is additionally divided by 1024**3.
    """
    # (unit key, metric suffix, divisor applied to the per-second rate)
    rate_table = (
        ("packed_tokens", "packed_tok_s", 1),
        ("logical_tokens", "logical_tok_s", 1),
        ("target_values", "target_logprob_s", 1),
        ("output_bytes", "output_gb_s", 1024**3),
    )
    rates: dict[str, float] = {}
    for timing_key, elapsed_ms in results.items():
        if elapsed_ms > 0:
            case_name = timing_key.removesuffix("_ms")
            case_units = units_by_name.get(case_name, {})
            for unit_key, suffix, divisor in rate_table:
                amount = int(case_units.get(unit_key, 0))
                if amount > 0:
                    per_second = amount * 1000.0 / elapsed_ms / divisor
                    rates[f"{case_name}_{suffix}"] = round(per_second, 3)
    return rates
annotations + +from collections.abc import Callable, Sequence +from dataclasses import dataclass +import json +from pathlib import Path +import time + +import numpy as np +import torch +from torch.nn.attention.flex_attention import AuxRequest, BlockMask +from torch.nn.attention.flex_attention import create_block_mask as torch_block_mask +import typer + +from art.megatron.context_parallel.block_mask import ( + build_block_mask_from_context, + prepare_block_mask_context, +) +from art.megatron.context_parallel.builder import build_shared_prefix_attention_spec +from art.megatron.context_parallel.executor import _build_stage_execution_spec +from art.megatron.context_parallel.runtime import ( + _RUNTIME_PLAN_CACHE, + get_or_build_runtime_plan, +) +from art.megatron.context_parallel.types import ( + ContextParallelConfig, + FlexMaskSpec, + ParallelTopology, + StageExecutionSpec, + StagePlan, +) +from art.megatron.flex_attn.compiled import ( + normalize_sparse_block_size, + sparse_compiled_flex_attention, +) +from art.megatron.shared_prefix_packing import SharedPrefixPack, pack_shared_prefixes + + +def main( + workload: str = "austin_198k", + max_depth: int = 1, + cp_size: int = 4, + block_size: int = 128, + prefix_families: int = 4, + prefix_len: int = 1024, + mid_prefixes_per_family: int = 1, + mid_prefix_len: int = 0, + branches_per_prefix: int = 8, + completion_len: int = 128, + warmup: int = 3, + repeat: int = 10, + shape_variants: int = 4, + validate_torch: bool = True, + validate_torch_token_cap: int = 32768, + run_flex: bool = True, + flex_token_cap: int = 8192, + flex_heads: int = 2, + flex_head_dim: int = 128, + flex_mask_variants: str = "current,causal_abs_only", + max_block_mask_build_ms: float | None = None, + max_cp_planning_cold_ms: float | None = None, + output_jsonl: Path = Path(".local/trainer_rank_review/block_mask_flex.jsonl"), +) -> None: + if warmup < 0 or repeat < 1: + raise ValueError("warmup must be >= 0 and repeat must be >= 1") + 
output_jsonl.parent.mkdir(parents=True, exist_ok=True) + + pack = _pack_workload( + workload=workload, + max_depth=max_depth, + prefix_families=prefix_families, + prefix_len=prefix_len, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len, + ) + spec = build_shared_prefix_attention_spec( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + config = ContextParallelConfig(block_size=block_size) + topology = ParallelTopology(cp=cp_size) + base = { + "workload": workload, + "max_depth": max_depth, + "cp_size": cp_size, + "block_size": block_size, + "packed_tokens": int(pack.tokens.numel()), + "logical_tokens": _logical_tokens(pack), + "warmup": warmup, + "repeat": repeat, + "validate_torch": validate_torch, + "validate_torch_token_cap": validate_torch_token_cap, + } + + plan, plan_ms = _bench_cpu( + lambda: _build_cp_plan(pack, spec, topology, config), + warmup=warmup, + repeat=repeat, + before_each=_RUNTIME_PLAN_CACHE.clear, + ) + _write( + output_jsonl, + { + **base, + "case": "cp_planning_cold", + "ms": plan_ms, + **_plan_stats(plan), + }, + ) + + cached_plan, cached_plan_ms = _bench_cpu( + lambda: _build_cp_plan(pack, spec, topology, config), + warmup=warmup, + repeat=repeat, + ) + _write( + output_jsonl, + { + **base, + "case": "cp_planning_cached", + "ms": cached_plan_ms, + **_plan_stats(cached_plan), + }, + ) + + stage_masks, mask_ms = _bench_cpu( + lambda: _build_stage_masks(pack, plan, config), + warmup=warmup, + repeat=repeat, + ) + masks = tuple(mask for mask, _ in stage_masks) + torch_validation_skipped = _torch_validation_skip_reason( + validate_torch=validate_torch, + packed_tokens=int(pack.tokens.numel()), + token_cap=validate_torch_token_cap, + ) + if torch_validation_skipped is None: + for mask, slices in stage_masks: + _assert_matches_torch_block_mask(mask, slices=slices) + _write( + output_jsonl, + { + **base, + "case": 
"block_mask_build", + "ms": mask_ms, + "torch_validation_skipped": torch_validation_skipped, + **_mask_stats(masks), + }, + ) + _check_threshold("block_mask_build", mask_ms, max_block_mask_build_ms) + _check_threshold("cp_planning_cold", plan_ms, max_cp_planning_cold_ms) + + if run_flex: + for record in _flex_records( + pack, + plan, + config, + warmup=warmup, + repeat=repeat, + token_cap=flex_token_cap, + heads=flex_heads, + head_dim=flex_head_dim, + variants=_csv_values(flex_mask_variants), + ): + _write(output_jsonl, {**base, **record}) + + for variant in range(shape_variants): + variant_pack = _pack_workload( + workload="regular", + max_depth=max_depth, + prefix_families=prefix_families, + prefix_len=prefix_len + variant * 17, + mid_prefixes_per_family=mid_prefixes_per_family, + mid_prefix_len=mid_prefix_len + variant * 3, + branches_per_prefix=branches_per_prefix, + completion_len=completion_len + variant * 11, + ) + variant_spec = build_shared_prefix_attention_spec( + group_ids=variant_pack.group_ids, + parent_ids=variant_pack.parent_ids, + ) + variant_plan, variant_plan_ms = _bench_cpu( + lambda pack=variant_pack, spec=variant_spec: _build_cp_plan( + pack, + spec, + topology, + config, + ), + warmup=0, + repeat=1, + before_each=_RUNTIME_PLAN_CACHE.clear, + ) + variant_stage_masks, variant_mask_ms = _bench_cpu( + lambda pack=variant_pack, plan=variant_plan: _build_stage_masks( + pack, + plan, + config, + ), + warmup=0, + repeat=1, + ) + variant_masks = tuple(mask for mask, _ in variant_stage_masks) + variant_torch_validation_skipped = _torch_validation_skip_reason( + validate_torch=validate_torch, + packed_tokens=int(variant_pack.tokens.numel()), + token_cap=validate_torch_token_cap, + ) + if variant_torch_validation_skipped is None: + for mask, slices in variant_stage_masks: + _assert_matches_torch_block_mask(mask, slices=slices) + _write( + output_jsonl, + { + **base, + "case": "shape_variant", + "variant": variant, + "variant_packed_tokens": 
def _pack_workload(
    *,
    workload: str,
    max_depth: int,
    prefix_families: int,
    prefix_len: int,
    mid_prefixes_per_family: int,
    mid_prefix_len: int,
    branches_per_prefix: int,
    completion_len: int,
) -> SharedPrefixPack:
    """Generate the token sequences for ``workload`` and pack shared prefixes.

    ``"austin_198k"`` and ``"austin_varied"`` select the two fixed synthetic
    workloads; every other value falls through to the parameterised regular
    workload, which is the only one that reads the prefix/branch/completion
    arguments.

    Returns:
        The result of ``pack_shared_prefixes`` for the generated sequences,
        packed with the given ``max_depth``.
    """
    if workload == "austin_198k":
        sequences = _austin_sequences()
    elif workload == "austin_varied":
        sequences = _austin_varied_sequences()
    else:
        sequences = _regular_sequences(
            prefix_families=prefix_families,
            prefix_len=prefix_len,
            mid_prefixes_per_family=mid_prefixes_per_family,
            mid_prefix_len=mid_prefix_len,
            branches_per_prefix=branches_per_prefix,
            completion_len=completion_len,
        )
    return pack_shared_prefixes(sequences, max_depth=max_depth)
def _tokens(offset: int, length: int) -> torch.Tensor:
    """Return ``length`` deterministic synthetic token ids starting at ``offset``.

    The ids are consecutive integers beginning at ``offset``, wrapped modulo
    32,000 and then shifted up by 100, so every id lies in ``[100, 32099]``.
    The result is a 1-D ``torch.long`` tensor.
    """
    consecutive = torch.arange(length, dtype=torch.long)
    wrapped = torch.remainder(consecutive + offset, 32_000)
    return wrapped + 100
_flex_records( + pack: SharedPrefixPack, + plan: object, + config: ContextParallelConfig, + *, + warmup: int, + repeat: int, + token_cap: int, + heads: int, + head_dim: int, + variants: Sequence[str], +) -> list[dict[str, object]]: + if not torch.cuda.is_available(): + return [{"case": "flex_attention_fwd_bwd", "skipped": "cuda_unavailable"}] + device = torch.device("cuda") + stage_cases = _build_stage_flex_cases( + pack, + plan, + config, + device=device, + ) + if not stage_cases: + return [{"case": "flex_attention_fwd_bwd", "skipped": "no_stage_masks"}] + largest_stage = max(max(case.q_len, case.k_len) for case in stage_cases) + if int(largest_stage) > int(token_cap): + return [ + { + "case": "flex_attention_fwd_bwd", + "skipped": "stage_tokens_exceed_flex_token_cap", + "flex_token_cap": int(token_cap), + "largest_stage_tokens": int(largest_stage), + } + ] + records: list[dict[str, object]] = [] + base_tensors = _stage_tensors( + stage_cases, + heads=heads, + head_dim=head_dim, + device=device, + ) + for variant in variants: + block_masks = [] + try: + block_masks = [ + _stage_variant_block_mask(case, variant, device=device) + for case in stage_cases + ] + except Exception as exc: + records.append( + { + "case": "flex_attention_fwd_bwd", + "flex_mask_variant": variant, + "compile_error": type(exc).__name__, + "compile_error_message": str(exc).splitlines()[0][:500], + "flex_heads": heads, + "flex_head_dim": head_dim, + } + ) + continue + qkv = [ + ( + q.detach().clone().requires_grad_(True), + k.detach().clone().requires_grad_(True), + v.detach().clone().requires_grad_(True), + ) + for q, k, v in base_tensors + ] + + def step() -> None: + loss = torch.zeros((), device=device, dtype=torch.float32) + for (q, k, v), block_mask in zip(qkv, block_masks, strict=True): + q.grad = None + k.grad = None + v.grad = None + out, _aux = sparse_compiled_flex_attention( + q, + k, + v, + block_mask=block_mask, + scale=float(head_dim) ** -0.5, + enable_gqa=False, + 
@dataclass(frozen=True)
class _StageFlexCase:
    """One (rank, stage) attention case used by the flex-attention benchmark.

    Instances are created by ``_build_stage_flex_cases`` for each stage plan
    that carries mask metadata and for which a block mask could be built.
    """

    # ``rank`` of the rank plan that owns the stage.
    rank: int
    # ``stage_index`` of the stage plan.
    stage_index: int
    # Query / key lengths taken from the stage's execution spec;
    # ``_stage_tensors`` uses them as the sequence dimension of the random
    # q and k/v tensors.
    q_len: int
    k_len: int
    # Query / key lengths taken directly from the stage plan.
    logical_q_len: int
    logical_k_len: int
    # Block mask built for this stage; the "current" variant uses it unchanged.
    block_mask: BlockMask
    # Flattened int64 copies of the mask metadata's q / k token indices; the
    # "causal_abs_only" variant compares them to decide visibility.
    q_abs: np.ndarray
    k_abs: np.ndarray
def _stage_tensors(
    cases: Sequence[_StageFlexCase],
    *,
    heads: int,
    head_dim: int,
    device: torch.device,
) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
    """Create reproducible random bfloat16 (q, k, v) tensors for every case.

    All tensors are drawn from one generator seeded with 17, in case order
    and in q, k, v order within a case, so repeated calls with the same
    arguments produce identical values.

    Args:
        cases: Stage cases; only ``q_len`` and ``k_len`` are read.
        heads: Number of attention heads (second dimension).
        head_dim: Per-head feature size (last dimension).
        device: Device for both the generator and the tensors.

    Returns:
        One ``(q, k, v)`` triple per case. ``q`` has shape
        ``(1, heads, q_len, head_dim)``; ``k`` and ``v`` have shape
        ``(1, heads, k_len, head_dim)``.
    """
    rng = torch.Generator(device=device).manual_seed(17)

    def sample(seq_len: int) -> torch.Tensor:
        # Draw order matters: every call advances the shared generator.
        shape = (1, int(heads), int(seq_len), int(head_dim))
        return torch.randn(shape, device=device, dtype=torch.bfloat16, generator=rng)

    return tuple(
        (sample(case.q_len), sample(case.k_len), sample(case.k_len))
        for case in cases
    )
def _stage_flex_stats(cases: Sequence[_StageFlexCase]) -> dict[str, object]:
    """Summarise the stage cases as totals and maxima of their q/k lengths.

    Args:
        cases: Stage cases; ``q_len``, ``k_len``, ``logical_q_len`` and
            ``logical_k_len`` are read from each.

    Returns:
        The case count plus the sum and the maximum of each length field,
        under ``flex_stage_*`` keys.

    Raises:
        ValueError: If ``cases`` is empty (there is no maximum to report).
    """
    q_lens = [case.q_len for case in cases]
    k_lens = [case.k_len for case in cases]
    logical_q_lens = [case.logical_q_len for case in cases]
    logical_k_lens = [case.logical_k_len for case in cases]
    stats: dict[str, object] = {"flex_stage_count": len(cases)}
    stats["flex_stage_q_tokens"] = sum(q_lens)
    stats["flex_stage_k_tokens"] = sum(k_lens)
    stats["flex_stage_logical_q_tokens"] = sum(logical_q_lens)
    stats["flex_stage_logical_k_tokens"] = sum(logical_k_lens)
    stats["flex_stage_max_q_tokens"] = max(q_lens)
    stats["flex_stage_max_k_tokens"] = max(k_lens)
    stats["flex_stage_max_logical_q_tokens"] = max(logical_q_lens)
    stats["flex_stage_max_logical_k_tokens"] = max(logical_k_lens)
    return stats
def _plan_stats(plan: object) -> dict[str, int]:
    """Count the ranks and stages of a context-parallel runtime plan.

    Args:
        plan: A sized iterable of rank plans, each exposing ``stage_plans``
            whose items have ``is_local_stage`` and ``mask_metadata``.

    Returns:
        ``rank_count`` (``len(plan)``), ``stage_count`` (all stages),
        ``remote_stage_count`` (stages that are not local) and
        ``mask_stage_count`` (stages that carry mask metadata).
    """
    all_stages = [stage for rank_plan in plan for stage in rank_plan.stage_plans]
    remote_total = sum(1 for stage in all_stages if not stage.is_local_stage)
    masked_total = sum(1 for stage in all_stages if stage.mask_metadata is not None)
    return {
        "rank_count": len(plan),
        "stage_count": len(all_stages),
        "remote_stage_count": remote_total,
        "mask_stage_count": masked_total,
    }
def _block_entries(
    block_mask: BlockMask,
    counts_name: str,
    indices_name: str,
) -> set[tuple[int, int, int, int]]:
    """Collect the populated block entries of one count/index pair of a mask.

    Args:
        block_mask: Object holding the count and index tensors as attributes.
        counts_name: Attribute name of the per-row block counts, indexed as
            ``[batch, head, block]``.
        indices_name: Attribute name of the block indices, indexed as
            ``[batch, head, block, slot]``.

    Returns:
        The set of ``(batch, head, block, other_block)`` tuples formed from
        the first ``count`` index slots of every row; slots beyond the count
        are ignored. An empty set when either attribute is ``None``.
    """
    row_counts = getattr(block_mask, counts_name)
    row_indices = getattr(block_mask, indices_name)
    if row_counts is None or row_indices is None:
        return set()
    return {
        (batch, head, row, int(column))
        for batch in range(int(row_counts.shape[0]))
        for head in range(int(row_counts.shape[1]))
        for row in range(int(row_counts.shape[2]))
        for column in row_indices[
            batch,
            head,
            row,
            : int(row_counts[batch, head, row]),
        ].tolist()
    }
def _csv_values(value: str) -> tuple[str, ...]:
    """Split a comma-separated option into its non-empty, stripped parts.

    Args:
        value: Raw option text such as ``"current, causal_abs_only"``.

    Returns:
        The parts in their original order, with surrounding whitespace
        removed and blank parts dropped.

    Raises:
        ValueError: If no non-blank part remains.
    """
    trimmed = (piece.strip() for piece in value.split(","))
    parsed = tuple(piece for piece in trimmed if piece)
    if parsed:
        return parsed
    raise ValueError("CSV option must contain at least one value")
def _gather_target_logprobs(
    logprobs: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """Pick the log-prob of each label from its row, zeroing ignored labels.

    Args:
        logprobs: 2-D tensor; row ``i`` is gathered along dimension 1 with
            the labels of row ``i``.
        labels: Label ids whose first dimension matches the rows of
            ``logprobs``; any trailing dimensions are flattened for the
            gather and restored afterwards. ``-100`` marks an ignored label.

    Returns:
        A tensor shaped like ``labels`` with the dtype of ``logprobs``.
        Positions labelled ``-100`` hold ``0.0``. When ``labels`` has no
        rows, an empty tensor of that shape is returned.
    """
    row_count = int(labels.shape[0])
    if row_count == 0:
        return torch.empty(labels.shape, device=logprobs.device, dtype=logprobs.dtype)
    ignored = labels == -100
    # Negative labels are clamped to 0 only so the gather index is valid;
    # the -100 positions are overwritten with 0.0 afterwards.
    gather_index = labels.clamp_min(0).reshape(row_count, -1)
    picked = torch.gather(logprobs, 1, gather_index)
    return picked.reshape(labels.shape).masked_fill(ignored, 0.0)
if layers > 0 + else None + ) + runtime = megatron_train.build_training_runtime( + model_identifier=model, + provider_configure=provider_configure, + print_env=dist.get_rank() == 0, + ) + for chunk in runtime.model: + chunk.eval() + + requests = ( + _stress_requests(stress_tokens) + if stress_tokens > 0 + else _requests(request_case) + ) + requests = _debug_output_requests(requests, debug_output) + unpacked_output_gb = _estimate_unpacked_output_gb(requests, runtime) + if max_unpacked_output_gb > 0 and unpacked_output_gb > max_unpacked_output_gb: + if dist.get_rank() == 0: + print( + json.dumps( + { + "world": dist.get_world_size(), + "dp": int(ps.get_data_parallel_world_size()), + "tp": int(ps.get_tensor_model_parallel_world_size()), + "cp": int(ps.get_context_parallel_world_size()), + "stress_tokens": stress_tokens, + "estimated_unpacked_output_gb": round( + unpacked_output_gb, 3 + ), + "max_unpacked_output_gb": max_unpacked_output_gb, + "skipped": "unpacked_output_cap", + }, + sort_keys=True, + ), + flush=True, + ) + dist.barrier() + return + dp_rank = int(ps.get_data_parallel_rank()) + dp_size = int(ps.get_data_parallel_world_size()) + local_pairs = [ + (index, request) + for index, request in enumerate(requests) + if index % dp_size == dp_rank + ] + local_requests = [request for _, request in local_pairs] + + rank_a = TrainerRank( + runtime, + head_chunk_tokens=head_chunk_a, + shared_prefix_max_depth=max_prefix_depth, + ) + rank_b = TrainerRank( + runtime, + head_chunk_tokens=head_chunk_b, + shared_prefix_max_depth=max_prefix_depth, + ) + independent_outputs: list[CheckOutput] | None = None + same_layout_outputs: list[CheckOutput] | None = None + + torch.cuda.reset_peak_memory_stats() + diff_stats = DiffStats() + with torch.no_grad(): + started_at = time.perf_counter() + if request_case == "target_only": + _debug("forward-target-only") + outputs_a = list(rank_a.dp_rank_forward(local_requests)) + outputs_b = list(rank_b.dp_rank_forward(local_requests)) + 
oracle_outputs, actual_source_positions = _packed_oracle( + rank_a, local_requests + ) + elif stress_tokens > 0: + _debug("forward-a") + outputs_a = list(rank_a.dp_rank_forward(local_requests)) + outputs_b = outputs_a + actual_source_positions = _source_positions(rank_a, local_requests) + oracle_outputs = [ + _as_check_output(source_positions, output) + for source_positions, output in zip( + actual_source_positions, + outputs_a, + strict=True, + ) + ] + else: + _debug("forward-shared") + ( + outputs_a, + outputs_b, + oracle_outputs, + actual_source_positions, + ) = _shared_hidden_check(rank_a, rank_b, local_requests) + if compare_independent and request_case in {"shared", "unique", "deep"}: + independent_outputs = _independent_check_outputs( + rank_a, local_requests + ) + if int(ps.get_context_parallel_world_size()) <= 1: + for index, (actual, independent) in enumerate( + zip(outputs_a, independent_outputs, strict=True) + ): + diff_stats = diff_stats.merge( + _assert_close( + actual, + independent, + f"independent[{index}]", + ), + ) + if compare_same_layout and request_case in {"shared", "unique", "deep"}: + same_layout_outputs = _same_layout_check_outputs( + rank_a, + local_requests, + ) + for index, (actual, same_layout) in enumerate( + zip(outputs_a, same_layout_outputs, strict=True) + ): + diff_stats = diff_stats.merge( + _assert_close( + actual, + same_layout, + f"same_layout[{index}]", + ), + ) + _debug("compare") + elapsed_s = time.perf_counter() - started_at + + peak_memory_gb = torch.tensor( + torch.cuda.max_memory_allocated() / 1024**3, + device=rank_a.device, + ) + for index, (actual, chunked, oracle) in enumerate( + zip(outputs_a, outputs_b, oracle_outputs, strict=True) + ): + if int(oracle.source_positions.numel()) == 0: + continue + diff_stats = diff_stats.merge( + _assert_close(actual, chunked, f"chunk[{index}]"), + ) + diff_stats = diff_stats.merge( + _assert_close(actual, oracle, f"oracle[{index}]"), + ) + + diff_tensor = torch.tensor( + 
[diff_stats.max_abs_diff, diff_stats.mean_abs_pct], + device=rank_a.device, + ) + dist.all_reduce(diff_tensor, op=dist.ReduceOp.MAX) + dist.all_reduce(peak_memory_gb, op=dist.ReduceOp.MAX) + max_diff_value = float(diff_tensor[0].item()) + mean_abs_pct_value = float(diff_tensor[1].item()) + records = _records( + local_pairs=local_pairs, + actual_outputs=outputs_a, + actual_source_positions=actual_source_positions, + oracle_outputs=oracle_outputs, + independent_outputs=independent_outputs, + rank=int(dist.get_rank()), + dp=dp_rank, + tp=int(ps.get_tensor_model_parallel_rank()), + cp=int(ps.get_context_parallel_rank()), + ) + gathered: list[list[dict[str, object]] | None] = [None] * dist.get_world_size() + _debug("all-gather") + dist.all_gather_object(gathered, records) + _debug("reconstruct") + reconstruction_error: str | None = None + if dist.get_rank() == 0: + seen = { + record["input_index"] + for rank_records in gathered + for record in rank_records or [] + } + if seen != set(range(len(requests))): + reconstruction_error = f"DP reconstruction missed inputs: {seen}" + else: + try: + reconstructed_stats = _assert_reconstructed(gathered, requests) + max_diff_value = max( + max_diff_value, + reconstructed_stats.max_abs_diff, + ) + mean_abs_pct_value = max( + mean_abs_pct_value, + reconstructed_stats.mean_abs_pct, + ) + except AssertionError as exc: + reconstruction_error = str(exc) + if reconstruction_error is None: + print( + json.dumps( + { + "world": dist.get_world_size(), + "dp": dp_size, + "tp": int(ps.get_tensor_model_parallel_world_size()), + "cp": int(ps.get_context_parallel_world_size()), + "mean_abs_pct": mean_abs_pct_value, + "max_abs_diff": max_diff_value, + "records": sum( + len(rank_records or []) for rank_records in gathered + ), + "same_layout": compare_same_layout, + "stress_tokens": stress_tokens, + "estimated_unpacked_output_gb": round( + unpacked_output_gb, 3 + ), + "elapsed_s": round(elapsed_s, 3), + "peak_memory_gb": 
round(float(peak_memory_gb.item()), 3), + }, + sort_keys=True, + ), + flush=True, + ) + errors = [reconstruction_error] + dist.broadcast_object_list(errors, src=0) + if errors[0] is not None: + raise AssertionError(errors[0]) + dist.barrier() + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +def _requests( + request_case: str = "shared", +) -> list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] +]: + if request_case not in {"shared", "target_only", "unique", "deep"}: + raise ValueError( + "request_case must be 'shared', 'target_only', 'unique', or 'deep'" + ) + rows = [ + torch.tensor([11, 12, 13, 14, 15, 16, 17]), + torch.tensor([11, 12, 13, 14, 24, 25]), + torch.tensor([11, 12, 13, 14, 24, 26]), + torch.tensor([11, 12, 13, 27]), + torch.tensor([31, 32, 33, 34]), + torch.tensor([31, 32, 33, 35]), + torch.tensor([11, 12, 13, 14, 15, 16, 17]), + torch.tensor([41, 42, 43]), + torch.tensor([41, 42, 44, 45]), + torch.tensor([51, 52, 53, 54, 55]), + torch.tensor([61, 62, 63]), + torch.tensor([61, 62, 64, 65]), + torch.tensor([71, 72]), + torch.tensor([81, 82, 83, 84]), + torch.tensor([91, 92, 93]), + torch.tensor([101, 102, 103, 104, 105]), + ] + if request_case == "deep": + rows = _deep_rows() + if request_case == "unique": + rows = [row + 1000 * index for index, row in enumerate(rows)] + if request_case == "target_only": + target_only_labels = [_labels(row, 0) for row in rows] + target_only_labels[0][2] = -100 + target_only_labels[3][1] = -100 + target_only_labels[10][0] = -100 + return [ + ForwardInput(input_tokens=row, target_tokens=label) + for row, label in zip(rows, target_only_labels, strict=True) + ] + + labels = [_labels(row, offset) for offset, row in enumerate(rows)] + labels[0][2] = -100 + labels[3][1] = -100 + labels[10][0] = -100 + multi_labels = torch.stack((labels[1], (labels[1] + 17) % 1000), dim=1) + multi_labels[2, 1] = -100 + requests = [] + for mask, row in enumerate(rows): 
+ target_tokens = None + if mask & 1: + target_tokens = multi_labels if mask == 1 else labels[mask] + requests.append( + ForwardInput( + input_tokens=row, + target_tokens=target_tokens, + top_k=3 if mask & 2 else None, + logits=bool(mask & 4), + hidden_states=bool(mask & 8), + ) + ) + return requests + + +def _debug_output_requests( + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + debug_output: str, +) -> list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] +]: + if debug_output == "none": + return requests + if debug_output == "hidden": + return [ + ForwardInput(input_tokens=request.input_tokens, hidden_states=True) + for request in requests + ] + if debug_output == "logits": + return [ + ForwardInput(input_tokens=request.input_tokens, logits=True) + for request in requests + ] + if debug_output == "target": + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=_labels(request.input_tokens, 0), + ) + for request in requests + ] + if debug_output == "topk": + return [ + ForwardInput(input_tokens=request.input_tokens, top_k=3) + for request in requests + ] + if debug_output == "target_topk": + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=_labels(request.input_tokens, 0), + top_k=3, + ) + for request in requests + ] + if debug_output == "mixed_no_topk": + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=request.target_tokens, + logits=request.logits, + hidden_states=request.hidden_states, + ) + for request in requests + ] + if debug_output == "mixed_no_logits": + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=request.target_tokens, + top_k=request.top_k, + hidden_states=request.hidden_states, + ) + for request in requests + ] + if debug_output == "mixed_no_targets": + return [ + ForwardInput( + input_tokens=request.input_tokens, 
+ top_k=request.top_k, + logits=request.logits, + hidden_states=request.hidden_states, + ) + for request in requests + ] + if debug_output == "mixed_targets_only": + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=request.target_tokens, + ) + for request in requests + ] + if debug_output == "mixed_targets_hidden": + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=request.target_tokens, + hidden_states=request.hidden_states, + ) + for request in requests + ] + if debug_output == "mixed_targets_logits": + return [ + ForwardInput( + input_tokens=request.input_tokens, + target_tokens=request.target_tokens, + logits=request.logits, + ) + for request in requests + ] + raise ValueError( + "debug_output must be 'none', 'hidden', 'logits', 'target', 'topk', " + "'target_topk', 'mixed_no_topk', 'mixed_no_logits', 'mixed_no_targets', " + "'mixed_targets_only', 'mixed_targets_hidden', or 'mixed_targets_logits'" + ) + + +def _deep_rows() -> list[torch.Tensor]: + return [ + torch.tensor([11, 12, 13, 14, 15, 16, 17]), + torch.tensor([11, 12, 13, 14, 15, 16, 18]), + torch.tensor([11, 12, 13, 14, 15, 19]), + torch.tensor([11, 12, 13, 14, 20]), + torch.tensor([11, 12, 21]), + torch.tensor([31, 32, 33, 34, 35]), + torch.tensor([31, 32, 33, 34, 36]), + torch.tensor([31, 32, 33, 37]), + torch.tensor([41, 42, 43]), + torch.tensor([41, 42, 44]), + torch.tensor([51, 52, 53, 54]), + torch.tensor([61, 62]), + torch.tensor([71, 72, 73, 74, 75]), + torch.tensor([71, 72, 73, 76]), + torch.tensor([81]), + torch.tensor([91, 92, 93]), + ] + + +def _stress_requests( + token_count: int, +) -> list[ForwardInput[None, None, None, torch.Tensor]]: + if token_count < 8: + raise ValueError("stress_tokens must be >= 8") + prefix_len = token_count // 2 + tail_len = max(1, token_count // 4) + prefix = _stress_tokens(0, prefix_len) + return [ + ForwardInput( + input_tokens=torch.cat((prefix, _stress_tokens(10_000, tail_len))), + hidden_states=True, 
+ ), + ForwardInput( + input_tokens=torch.cat((prefix, _stress_tokens(20_000, tail_len))), + hidden_states=True, + ), + ForwardInput(input_tokens=_stress_tokens(30_000, tail_len), hidden_states=True), + ] + + +def _stress_tokens(offset: int, length: int) -> torch.Tensor: + return (torch.arange(length, dtype=torch.long) + offset) % 32_000 + 100 + + +def _estimate_unpacked_output_gb( + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + runtime: object, +) -> float: + provider = getattr(runtime, "provider") + model = _language_model(getattr(runtime, "model")[0]) + hidden_size = int( + getattr(provider, "hidden_size", None) + or getattr(getattr(model, "config", None), "hidden_size", 0) + ) + vocab_size = int( + getattr(getattr(model, "config", None), "padded_vocab_size", None) + or getattr(model, "vocab_size", 0) + ) + dtype_size = next(getattr(runtime, "model")[0].parameters()).element_size() + bytes_total = sum( + _request_output_bytes( + request, + hidden_size=hidden_size, + vocab_size=vocab_size, + dtype_size=dtype_size, + ) + for request in requests + ) + return bytes_total / 1024**3 + + +def _request_output_bytes( + request: ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], + *, + hidden_size: int, + vocab_size: int, + dtype_size: int, +) -> int: + seq_len = int(request.input_tokens.numel()) + bytes_total = 0 + if request.target_tokens is not None: + bytes_total += int(request.target_tokens.numel()) * 4 + if request.top_k is not None: + bytes_total += seq_len * int(request.top_k) * (4 + 8) + if request.logits: + bytes_total += seq_len * vocab_size * dtype_size + if request.hidden_states: + bytes_total += seq_len * hidden_size * dtype_size + return bytes_total + + +def _debug(label: str) -> None: + if os.environ.get("TRAINER_RANK_CHECK_DEBUG") != "1": + return + print(f"[rank{dist.get_rank()}] {label}", flush=True) + + +def _labels(tokens: 
torch.Tensor, offset: int) -> torch.Tensor: + return ((tokens * 7 + 3 + offset) % 1000).to(dtype=torch.long) + + +def _packed_oracle( + rank: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> tuple[list[CheckOutput], tuple[torch.Tensor, ...]]: + items = [rank._forward_item(request) for request in requests] + prepared = rank._prepare_packed_forward( + pack_shared_prefixes( + (item.input_ids for item in items), + max_depth=rank.shared_prefix_max_depth, + ) + ) + hidden = rank._gather_sequence_parallel_hidden(rank._decoder_hidden(prepared)) + return ( + _packed_oracle_from_hidden(rank, items, prepared, hidden), + prepared.source_positions_by_item, + ) + + +def _shared_hidden_check( + rank_a: TrainerRank, + rank_b: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> tuple[ + list[ + ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + list[ + ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + list[CheckOutput], + tuple[torch.Tensor, ...], +]: + items = [rank_a._forward_item(request) for request in requests] + prepared = rank_a._prepare_packed_forward( + pack_shared_prefixes( + (item.input_ids for item in items), + max_depth=rank_a.shared_prefix_max_depth, + ) + ) + hidden = rank_a._gather_sequence_parallel_hidden(rank_a._decoder_hidden(prepared)) + outputs_a = _outputs_from_hidden(rank_a, items, prepared, hidden) + outputs_b = _outputs_from_hidden(rank_b, items, prepared, hidden) + oracle = _packed_oracle_from_hidden(rank_a, items, prepared, hidden) + return ( + outputs_a, + outputs_b, + oracle, + prepared.source_positions_by_item, + ) + + +def _independent_check_outputs( + rank: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | 
None + ] + ], +) -> list[CheckOutput]: + outputs: list[CheckOutput] = [] + for request in requests: + source_positions = _source_positions(rank, [request])[0] + outputs.append( + _as_check_output(source_positions, rank.dp_rank_forward([request])[0]) + ) + return outputs + + +def _same_layout_check_outputs( + rank: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> list[CheckOutput]: + items = [rank._forward_item(request) for request in requests] + batch = pack_shared_prefixes( + (item.input_ids for item in items), + max_depth=rank.shared_prefix_max_depth, + ) + outputs = [] + for index, positions in enumerate(batch.positions_by_sequence): + mutated = _mutated_batch(batch, keep_positions=positions) + prepared = rank._prepare_packed_forward(mutated) + hidden = rank._gather_sequence_parallel_hidden(rank._decoder_hidden(prepared)) + mutated_outputs = _outputs_from_hidden(rank, items, prepared, hidden) + outputs.append( + _as_check_output( + prepared.source_positions_by_item[index], + mutated_outputs[index], + ) + ) + return outputs + + +def _mutated_batch( + batch: SharedPrefixPack, + *, + keep_positions: torch.Tensor, +) -> SharedPrefixPack: + tokens = batch.tokens.clone() + mutate = torch.ones(int(tokens.shape[1]), dtype=torch.bool, device=tokens.device) + mutate[keep_positions.to(device=tokens.device)] = False + replacement = ( + torch.arange(int(tokens.shape[1]), dtype=tokens.dtype, device=tokens.device) + + 50_000 + ) + tokens[0, mutate] = replacement[mutate] % 100_000 + return SharedPrefixPack( + tokens=tokens, + group_ids=batch.group_ids, + parent_ids=batch.parent_ids, + position_ids=batch.position_ids, + positions_by_sequence=batch.positions_by_sequence, + ) + + +def _outputs_from_hidden( + rank: TrainerRank, + items: list[object], + prepared: object, + hidden: torch.Tensor, +) -> list[ + ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | 
None + ] +]: + return rank._project_head(items, prepared, hidden) + + +def _packed_oracle_from_hidden( + rank: TrainerRank, + items: list[object], + prepared: object, + hidden: torch.Tensor, +) -> list[CheckOutput]: + model = _language_model(rank.runtime.model[0]) + output_weight = ( + model.shared_embedding_or_output_weight() + if bool(model.share_embeddings_and_output_weights) + else None + ) + + outputs: list[CheckOutput] = [] + for item, positions, source_positions in zip( + items, + prepared.positions_by_item, + prepared.source_positions_by_item, + strict=True, + ): + needs_projection = ( + item.labels is not None or item.request.logits or item.request.top_k + ) + all_logits = None + if needs_projection: + if int(positions.numel()): + local_logits = rank._local_logits_from_hidden_rows( + model, + _select_positions(hidden, positions), + output_weight=output_weight, + ) + all_logits = _batch_seq_logits( + rank._gather_tensor_parallel_logits(local_logits.unsqueeze(1)), + seq_len=int(positions.numel()), + ).squeeze(0) + else: + all_logits = _empty_logits_like_positions(positions, model, hidden) + logprobs = ( + None + if all_logits is None + else torch.log_softmax(all_logits.float(), dim=-1) + ) + + target_logprobs = None + if item.labels is not None: + if logprobs is None: + raise RuntimeError("target_logprobs oracle requires logprobs") + labels = item.labels.to(device=logprobs.device).index_select( + 0, source_positions.to(device=logprobs.device) + ) + target_logprobs = _gather_target_logprobs(logprobs, labels) + + top_k = None + if item.request.top_k is not None: + if all_logits is None: + raise RuntimeError("top_k oracle requires logits") + log_z = torch.logsumexp(all_logits.float(), dim=-1) + values, tokens = torch.topk( + all_logits.float(), k=item.request.top_k, dim=-1 + ) + top_k = TopK(logprobs=values - log_z.unsqueeze(1), tokens=tokens) + + hidden_states = None + if item.request.hidden_states: + hidden_states = _select_positions(hidden, positions) + + 
outputs.append( + CheckOutput( + source_positions=source_positions, + target_logprobs=target_logprobs, + top_k=top_k, + logits=all_logits if item.request.logits else None, + hidden_states=hidden_states, + ) + ) + return outputs + + +def _source_positions( + rank: TrainerRank, + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> tuple[torch.Tensor, ...]: + items = [rank._forward_item(request) for request in requests] + prepared = rank._prepare_packed_forward( + pack_shared_prefixes( + (item.input_ids for item in items), + max_depth=rank.shared_prefix_max_depth, + ) + ) + return prepared.source_positions_by_item + + +def _as_check_output( + source_positions: torch.Tensor, + output: ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], +) -> CheckOutput: + return CheckOutput( + source_positions=source_positions, + target_logprobs=output.target_logprobs, + top_k=output.top_k, + logits=output.logits, + hidden_states=output.hidden_states, + ) + + +def _records( + *, + local_pairs: list[ + tuple[ + int, + ForwardInput[ + torch.Tensor | None, + TopK | None, + torch.Tensor | None, + torch.Tensor | None, + ], + ] + ], + actual_outputs: list[ + ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], + actual_source_positions: tuple[torch.Tensor, ...], + oracle_outputs: list[CheckOutput], + independent_outputs: list[CheckOutput] | None, + rank: int, + dp: int, + tp: int, + cp: int, +) -> list[dict[str, object]]: + records: list[dict[str, object]] = [] + independent_records: list[CheckOutput | None] = ( + independent_outputs + if independent_outputs is not None + else [None] * len(local_pairs) + ) + for local_index, ( + (input_index, _), + actual, + actual_sources, + oracle, + independent, + ) in enumerate( + zip( + local_pairs, + actual_outputs, + actual_source_positions, + oracle_outputs, + independent_records, + 
strict=True, + ) + ): + records.append( + { + "input_index": input_index, + "local_index": local_index, + "rank": rank, + "dp": dp, + "tp": tp, + "cp": cp, + "actual": _cpu_record(actual_sources, actual), + "oracle": _cpu_record(oracle.source_positions, oracle), + "independent": ( + None + if independent is None + else _cpu_record(independent.source_positions, independent) + ), + } + ) + return records + + +def _cpu_record( + source_positions: torch.Tensor, + output: ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + | CheckOutput, +) -> dict[str, torch.Tensor | None]: + return { + "source_positions": source_positions.cpu(), + "target_logprobs": _cpu(output.target_logprobs), + "logits": _cpu(output.logits), + "hidden_states": _cpu(output.hidden_states), + "top_k_logprobs": None if output.top_k is None else _cpu(output.top_k.logprobs), + "top_k_tokens": None if output.top_k is None else _cpu(output.top_k.tokens), + } + + +def _cpu(tensor: torch.Tensor | None) -> torch.Tensor | None: + return None if tensor is None else tensor.detach().cpu() + + +def _assert_reconstructed( + gathered: list[list[dict[str, object]] | None], + requests: list[ + ForwardInput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + ], +) -> DiffStats: + diff_stats = DiffStats() + records = [ + record + for rank_records in gathered + for record in rank_records or [] + if record["tp"] == 0 + ] + for input_index, request in enumerate(requests): + _debug(f"reconstruct-input-{input_index}") + actual = [ + record["actual"] + for record in records + if record["input_index"] == input_index + ] + oracle = [ + record["oracle"] + for record in records + if record["input_index"] == input_index + ] + independent = [ + record["independent"] + for record in records + if record["input_index"] == input_index + and record.get("independent") is not None + ] + length = int(request.input_tokens.numel()) + for key in ("target_logprobs", 
"logits", "hidden_states", "top_k_logprobs"): + _debug(f"reconstruct-input-{input_index}-{key}") + _debug(f"reconstruct-input-{input_index}-{key}-assemble-actual") + actual_value = _assemble(actual, key, length) + _debug( + f"reconstruct-input-{input_index}-{key}-actual-" + f"{_tensor_summary(actual_value)}" + ) + _debug(f"reconstruct-input-{input_index}-{key}-assemble-oracle") + oracle_value = _assemble(oracle, key, length) + _debug( + f"reconstruct-input-{input_index}-{key}-oracle-" + f"{_tensor_summary(oracle_value)}" + ) + _debug(f"reconstruct-input-{input_index}-{key}-diff-oracle") + diff_stats = diff_stats.merge( + _tensor_diff_value( + actual_value, + oracle_value, + f"reconstructed[{input_index}].{key}", + ), + ) + _debug(f"reconstruct-input-{input_index}-{key}-diff-oracle-done") + if independent: + _debug(f"reconstruct-input-{input_index}-{key}-assemble-independent") + independent_value = _assemble(independent, key, length) + _debug( + f"reconstruct-input-{input_index}-{key}-independent-" + f"{_tensor_summary(independent_value)}" + ) + _debug(f"reconstruct-input-{input_index}-{key}-diff-independent") + diff_stats = diff_stats.merge( + _tensor_diff_value( + actual_value, + independent_value, + f"independent[{input_index}].{key}", + ), + ) + _debug(f"reconstruct-input-{input_index}-{key}-diff-independent-done") + _debug(f"reconstruct-input-{input_index}-{key}-done") + actual_tokens = _assemble(actual, "top_k_tokens", length) + oracle_tokens = _assemble(oracle, "top_k_tokens", length) + if actual_tokens is None or oracle_tokens is None: + if actual_tokens is not oracle_tokens: + raise AssertionError( + f"reconstructed[{input_index}].top_k None mismatch" + ) + elif not torch.equal(actual_tokens, oracle_tokens): + actual_logprobs = _assemble(actual, "top_k_logprobs", length) + oracle_logprobs = _assemble(oracle, "top_k_logprobs", length) + if ( + actual_logprobs is None + or oracle_logprobs is None + or _tensor_diff_value( + actual_logprobs, + oracle_logprobs, 
+ f"reconstructed[{input_index}].top_k.logprobs", + ).max_abs_diff + > 5e-6 + ): + raise AssertionError( + f"reconstructed[{input_index}].top_k.tokens mismatch" + ) + if independent: + independent_tokens = _assemble(independent, "top_k_tokens", length) + if actual_tokens is None or independent_tokens is None: + if actual_tokens is not independent_tokens: + raise AssertionError( + f"independent[{input_index}].top_k None mismatch" + ) + elif not torch.equal(actual_tokens, independent_tokens): + actual_logprobs = _assemble(actual, "top_k_logprobs", length) + independent_logprobs = _assemble( + independent, + "top_k_logprobs", + length, + ) + if ( + actual_logprobs is None + or independent_logprobs is None + or _tensor_diff_value( + actual_logprobs, + independent_logprobs, + f"independent[{input_index}].top_k.logprobs", + ).max_abs_diff + > 5e-6 + ): + raise AssertionError( + f"independent[{input_index}].top_k.tokens mismatch" + ) + return diff_stats + + +def _assemble( + records: list[object], + key: str, + length: int, +) -> torch.Tensor | None: + typed_records = [record for record in records if isinstance(record, dict)] + values = [record[key] for record in typed_records if record[key] is not None] + if not values: + return None + first = values[0] + if not isinstance(first, torch.Tensor): + raise TypeError(key) + output = torch.empty((length, *first.shape[1:]), dtype=first.dtype) + filled = torch.zeros(length, dtype=torch.bool) + for record in typed_records: + value = record[key] + if value is None: + continue + if not isinstance(value, torch.Tensor): + raise TypeError(key) + positions = record["source_positions"] + if not isinstance(positions, torch.Tensor): + raise TypeError("source_positions") + output[positions] = value + filled[positions] = True + if not bool(filled.all().item()): + raise AssertionError(f"{key} reconstruction missed positions") + return output + + +def _tensor_summary(tensor: torch.Tensor | None) -> str: + if tensor is None: + return "None" + 
return f"shape={tuple(tensor.shape)} device={tensor.device} dtype={tensor.dtype}" + + +def _assert_close( + actual: ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ], + expected: ForwardOutput[ + torch.Tensor | None, TopK | None, torch.Tensor | None, torch.Tensor | None + ] + | CheckOutput, + label: str, +) -> DiffStats: + diffs = [ + _tensor_diff( + actual.target_logprobs, expected.target_logprobs, f"{label}.target_logprobs" + ) + ] + diffs.append(_tensor_diff(actual.logits, expected.logits, f"{label}.logits")) + diffs.append( + _tensor_diff( + actual.hidden_states, expected.hidden_states, f"{label}.hidden_states" + ) + ) + if actual.top_k is None or expected.top_k is None: + if actual.top_k is not expected.top_k: + raise AssertionError(f"{label}.top_k None mismatch") + else: + try: + top_k_diff = _tensor_diff( + actual.top_k.logprobs, + expected.top_k.logprobs, + f"{label}.top_k.logprobs", + ) + except AssertionError as exc: + flat_offset = int( + (actual.top_k.logprobs.float() - expected.top_k.logprobs.float()) + .abs() + .flatten() + .argmax() + ) + row, _ = divmod(flat_offset, int(actual.top_k.logprobs.shape[1])) + raise AssertionError( + f"{exc}; actual_row={actual.top_k.logprobs[row].tolist()} " + f"expected_row={expected.top_k.logprobs[row].tolist()} " + f"actual_tokens={actual.top_k.tokens[row].tolist()} " + f"expected_tokens={expected.top_k.tokens[row].tolist()}" + ) from exc + diffs.append(top_k_diff) + if ( + not torch.equal(actual.top_k.tokens, expected.top_k.tokens) + and top_k_diff.max_abs_diff > 5e-6 + ): + mismatch = torch.nonzero( + actual.top_k.tokens != expected.top_k.tokens, + as_tuple=False, + )[0] + row = int(mismatch[0].item()) + col = int(mismatch[1].item()) + raise AssertionError( + f"{label}.top_k.tokens mismatch at ({row}, {col}): " + f"actual={int(actual.top_k.tokens[row, col].item())} " + f"expected={int(expected.top_k.tokens[row, col].item())} " + 
f"actual_logprob={float(actual.top_k.logprobs[row, col].item())} " + f"expected_logprob={float(expected.top_k.logprobs[row, col].item())}" + ) + return _merge_diff_stats(diffs) + + +def _tensor_diff( + actual: torch.Tensor | None, + expected: torch.Tensor | None, + label: str, +) -> DiffStats: + return _tensor_diff_value(actual, expected, label) + + +def _tensor_diff_value( + actual: torch.Tensor | None, + expected: torch.Tensor | None, + label: str, +) -> DiffStats: + if actual is None or expected is None: + if actual is not expected: + raise AssertionError(f"{label} None mismatch") + return DiffStats() + if actual.shape != expected.shape: + raise AssertionError( + f"{label} shape mismatch: {actual.shape} != {expected.shape}" + ) + actual_for_diff = actual + expected_for_diff = expected + if torch.cuda.is_available(): + actual_for_diff = actual_for_diff.to(device="cuda") + expected_for_diff = expected_for_diff.to(device="cuda") + if actual_for_diff.numel(): + abs_diff = (actual_for_diff.float() - expected_for_diff.float()).abs() + max_abs_diff = float(abs_diff.max().item()) + denominator = float(expected_for_diff.float().abs().mean().item()) + mean_abs_pct = float(abs_diff.mean().item()) / (denominator + 1e-18) + else: + max_abs_diff = 0.0 + mean_abs_pct = 0.0 + mean_abs_pct_tolerance = 5e-3 if label.startswith("independent[") else 2e-5 + max_abs_tolerance = 0.0 + _debug( + f"{label} max_abs_diff={max_abs_diff} " + f"mean_abs_pct={mean_abs_pct} tolerance={mean_abs_pct_tolerance}" + ) + if mean_abs_pct > mean_abs_pct_tolerance: + raise AssertionError( + f"{label} mean_abs_pct {mean_abs_pct} max_abs_diff {max_abs_diff}" + ) + if max_abs_diff > max_abs_tolerance and not actual_for_diff.is_floating_point(): + raise AssertionError(f"{label} max diff {max_abs_diff}") + return DiffStats(max_abs_diff=max_abs_diff, mean_abs_pct=mean_abs_pct) + + +def _merge_diff_stats(stats: list[DiffStats]) -> DiffStats: + merged = DiffStats() + for stat in stats: + merged = 
merged.merge(stat) + return merged + + +if __name__ == "__main__": + typer.run(main) diff --git a/pyproject.toml b/pyproject.toml index bfa06e5d1..dd900d326 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -199,6 +199,7 @@ requires-dist = [ [tool.ty.environment] python-version = "3.12" +extra-paths = ["typings"] [tool.ty.rules] # Ignore unused-ignore-comment warnings because they vary depending on whether @@ -229,6 +230,7 @@ allowed-unresolved-imports = [ "peft.**", "pyarrow.**", "torch.**", + "torchvision.**", "torchao.**", "transformers.**", "trl.**", diff --git a/scripts/build-gpu-image.sh b/scripts/build-gpu-image.sh index 299678584..dbce31484 100755 --- a/scripts/build-gpu-image.sh +++ b/scripts/build-gpu-image.sh @@ -10,10 +10,12 @@ Options: --image-repo REPO Image repository to publish --infra INFRA Kubernetes-backed SkyPilot infra (default: k8s/cks-wb3) --no-cache Disable registry-backed BuildKit cache + --no-prewarm-modal Skip prebuilding the pushed image in Modal --no-prewarm-nodes Skip pre-pulling the pushed image on GPU nodes --pull-image-repo REPO Image repository for cluster pulls/prewarm + --prewarm-modal Require prebuilding the pushed image in Modal --prewarm-timeout DUR Timeout for the prewarm DaemonSet rollout (default: 30m) - --tag TAG Image tag to publish + --tag TAG Image tag to publish (default: latest) --help Show this help EOF } @@ -24,12 +26,13 @@ cluster_name="" infra="${SKY_INFRA:-k8s/cks-wb3}" image_repo="${ART_IMAGE_REPO:-}" pull_image_repo="${ART_PULL_IMAGE_REPO:-}" -image_tag="" +image_tag="${IMAGE_TAG:-latest}" docker_config_path="${DOCKER_CONFIG_PATH:-${HOME}/.docker/config.json}" buildkit_image="${BUILDKIT_IMAGE:-moby/buildkit:v0.29.0-rootless}" buildkit_namespace="${KUBECTL_NAMESPACE:-default}" buildkit_wait_timeout="${BUILDKIT_WAIT_TIMEOUT:-300s}" no_cache="${NO_CACHE:-false}" +prewarm_modal="${PREWARM_MODAL:-auto}" prewarm_nodes="${PREWARM_NODES:-true}" prewarm_namespace="${PREWARM_NAMESPACE:-default}" 
prewarm_name="${PREWARM_NAME:-art-gpu-image-prewarm}" @@ -58,6 +61,10 @@ while [[ $# -gt 0 ]]; do no_cache=true shift ;; + --no-prewarm-modal) + prewarm_modal=false + shift + ;; --no-prewarm-nodes) prewarm_nodes=false shift @@ -66,6 +73,10 @@ while [[ $# -gt 0 ]]; do pull_image_repo="$2" shift 2 ;; + --prewarm-modal) + prewarm_modal=true + shift + ;; --prewarm-timeout) prewarm_timeout="$2" shift 2 @@ -86,6 +97,14 @@ while [[ $# -gt 0 ]]; do esac done +case "${prewarm_modal}" in + auto|true|false) ;; + *) + echo "PREWARM_MODAL must be one of: auto, true, false" >&2 + exit 1 + ;; +esac + case "${infra}" in k8s/*) kube_context="${infra#k8s/}" @@ -111,10 +130,6 @@ art_sha="$(git -C "${repo_root}" rev-parse HEAD)" art_short_sha="$(git -C "${repo_root}" rev-parse --short=12 HEAD)" timestamp="$(date +%m%d-%H%M%S)" -if [[ -z "${image_tag}" ]]; then - image_tag="skypilot-${art_short_sha}" -fi - if [[ -z "${cluster_name}" ]]; then cluster_name="art-gpu-build-${timestamp}" fi @@ -409,6 +424,38 @@ if [[ -n "${prewarm_refresh_tag_image}" ]]; then prewarm_display="${prewarm_image} and refreshing ${prewarm_refresh_tag_image}" fi +modal_auth_available=false +if [[ "${prewarm_modal}" != "false" ]]; then + if uv run --with 'modal>=1.5.0' python - <<'PY' >/dev/null 2>&1; then +import modal + +modal.Workspace.from_context().hydrate() +PY + modal_auth_available=true + fi +fi + +if [[ "${prewarm_modal}" == "true" || "${modal_auth_available}" == "true" ]]; then + echo "Prewarming ${image_repo}:${image_tag} in Modal image cache" + MODAL_FORCE_BUILD=1 uv run --with 'modal>=1.5.0' python - "${image_repo}:${image_tag}" <<'PY' +import sys + +import modal + +image = ( + modal.Image.from_registry(sys.argv[1], add_python="3.12") + .apt_install("openssh-server", "sudo", "rsync", "curl", "procps", "patch", "lsof") +) +app = modal.App.lookup("skypilot-modal", create_if_missing=True) +with modal.enable_output(): + image.build(app) +PY +elif [[ "${prewarm_modal}" == "auto" ]]; then + echo "Skipping 
Modal image prewarm: Modal auth unavailable" +else + echo "Skipping Modal image prewarm" +fi + dump_prewarm_diagnostics() { echo "::group::Prewarm diagnostics" "${kubectl_cmd[@]}" get daemonset -n "${prewarm_namespace}" "${prewarm_name}" -o wide || true diff --git a/src/art/megatron/__init__.py b/src/art/megatron/__init__.py index 3c2e5e5b9..18345c630 100644 --- a/src/art/megatron/__init__.py +++ b/src/art/megatron/__init__.py @@ -1,6 +1,17 @@ from typing import Any -__all__ = ["MegatronBackend"] +_TRAINER_RANK_EXPORTS = ( + "AdamParams", + "ForwardInput", + "ForwardOutput", + "MicroBatch", + "MicroBatchStats", + "TopK", + "TrainerRank", + "TrainerRankMemoryError", +) + +__all__ = ["MegatronBackend", *_TRAINER_RANK_EXPORTS] def __getattr__(name: str) -> Any: @@ -8,4 +19,8 @@ def __getattr__(name: str) -> Any: from .backend import MegatronBackend return MegatronBackend + if name in _TRAINER_RANK_EXPORTS: + from . import trainer_rank + + return getattr(trainer_rank, name) raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/art/megatron/context_parallel/__init__.py b/src/art/megatron/context_parallel/__init__.py index 995b0c425..fc27c486e 100644 --- a/src/art/megatron/context_parallel/__init__.py +++ b/src/art/megatron/context_parallel/__init__.py @@ -1,20 +1,16 @@ from .builder import build_dense_reference_mask, build_shared_prefix_attention_spec from .layout_index import TokenLayoutIndex -from .runtime import build_context_parallel_token_layout_index from .types import ( ArtContextParallelState, AttnMaskKind, AttnSlice, ContextParallelConfig, - ContextParallelRuntimeKey, - ContextParallelRuntimePlan, DispatchedPackedTensors, FlexMaskSpec, PackedBatchAttentionSpec, PackedRowAttentionSpec, ParallelTopology, PreparedMegatronBatch, - SharedPrefixBuilderConfig, TokenRange, ) @@ -28,13 +24,9 @@ "PackedRowAttentionSpec", "ParallelTopology", "PreparedMegatronBatch", - "SharedPrefixBuilderConfig", "ContextParallelConfig", - 
"ContextParallelRuntimeKey", - "ContextParallelRuntimePlan", "TokenRange", "TokenLayoutIndex", "build_dense_reference_mask", - "build_context_parallel_token_layout_index", "build_shared_prefix_attention_spec", ] diff --git a/src/art/megatron/context_parallel/block_mask.py b/src/art/megatron/context_parallel/block_mask.py index 91fe2023b..e5ec1eaac 100644 --- a/src/art/megatron/context_parallel/block_mask.py +++ b/src/art/megatron/context_parallel/block_mask.py @@ -1,32 +1,42 @@ from __future__ import annotations +from dataclasses import dataclass + import numpy as np import torch from torch.nn.attention.flex_attention import BlockMask from art.megatron.flex_attn.compiled import normalize_sparse_block_size +from art.megatron.shared_prefix_tree import parse_shared_prefix_row from .types import AttnMaskKind, FlexMaskSpec -_INVALID_Q_GROUP = -(1 << 63) -_INVALID_Q_PARENT = _INVALID_Q_GROUP + 1 -_INVALID_K_GROUP = _INVALID_Q_GROUP + 2 +_INVALID_ABS = -(1 << 63) +_INVALID_ENTER = -1 +_INVALID_EXIT = -1 + +@dataclass(frozen=True, slots=True) +class PreparedBlockMaskContext: + source_len: int + group_enter_np: np.ndarray + group_exit_np: np.ndarray -def _build_exact_mask_mod( + +def _build_interval_mask_mod( *, q_abs: np.ndarray, k_abs: np.ndarray, - q_group: np.ndarray, - q_parent: np.ndarray, - k_group: np.ndarray, + q_enter: np.ndarray, + k_enter: np.ndarray, + k_exit: np.ndarray, device: torch.device, ): q_abs_tensor = torch.as_tensor(q_abs, device=device, dtype=torch.int64) k_abs_tensor = torch.as_tensor(k_abs, device=device, dtype=torch.int64) - q_group_tensor = torch.as_tensor(q_group, device=device, dtype=torch.int64) - q_parent_tensor = torch.as_tensor(q_parent, device=device, dtype=torch.int64) - k_group_tensor = torch.as_tensor(k_group, device=device, dtype=torch.int64) + q_enter_tensor = torch.as_tensor(q_enter, device=device, dtype=torch.int64) + k_enter_tensor = torch.as_tensor(k_enter, device=device, dtype=torch.int64) + k_exit_tensor = 
torch.as_tensor(k_exit, device=device, dtype=torch.int64) def mask_mod( batch_idx: torch.Tensor, @@ -37,9 +47,13 @@ def mask_mod( del batch_idx, head_idx q_abs_local = q_abs_tensor[query_idx] k_abs_local = k_abs_tensor[kv_idx] - same_group = q_group_tensor[query_idx] == k_group_tensor[kv_idx] - parent_prefix = q_parent_tensor[query_idx] == k_group_tensor[kv_idx] - return (q_abs_local >= k_abs_local) & (same_group | parent_prefix) + q_enter_local = q_enter_tensor[query_idx] + k_enter_local = k_enter_tensor[kv_idx] + k_exit_local = k_exit_tensor[kv_idx] + in_key_subtree = (k_enter_local <= q_enter_local) & ( + q_enter_local < k_exit_local + ) + return (q_abs_local >= k_abs_local) & in_key_subtree return mask_mod @@ -49,10 +63,15 @@ def _dense_blocks_to_ordered( *, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: - counts = torch.from_numpy(blocks.sum(axis=-1).astype(np.int32)) - indices = torch.from_numpy( - np.argsort(-blocks.astype(np.int32), axis=-1, kind="stable").astype(np.int32) - ) + row_indices, column_indices = np.nonzero(blocks) + counts_np = np.bincount(row_indices, minlength=blocks.shape[0]).astype(np.int32) + indices_np = np.zeros(blocks.shape, dtype=np.int32) + if int(row_indices.size) > 0: + starts = np.concatenate(([0], np.cumsum(counts_np[:-1], dtype=np.int64))) + offsets = np.arange(int(row_indices.size), dtype=np.int64) - starts[row_indices] + indices_np[row_indices, offsets] = column_indices + counts = torch.from_numpy(counts_np) + indices = torch.from_numpy(indices_np) return ( counts.view(1, 1, -1).to(device=device), indices.view(1, 1, blocks.shape[0], blocks.shape[1]).to(device=device), @@ -72,72 +91,252 @@ def _select_with_invalid_np( return selected -def _build_q_block_group_state( +def _refine_interval_blocks( *, + partial_blocks: np.ndarray, + full_blocks: np.ndarray, q_abs: np.ndarray, - q_group: np.ndarray, - q_parent: np.ndarray, - q_block: int, - block_idx: int, -) -> tuple[int, dict[int, int], frozenset[int]]: - start = 
int(block_idx) * q_block - end = min((int(block_idx) + 1) * q_block, int(q_abs.size)) - q = q_abs[start:end] - q_group_block = q_group[start:end] - q_parent_block = q_parent[start:end] - q_min = int(q.min()) if int(q.size) else 0 - max_by_group: dict[int, int] = {} - all_groups: list[int] = [] - for group_value in np.unique(np.concatenate((q_group_block, q_parent_block))): - allowed = (q_group_block == group_value) | (q_parent_block == group_value) - if bool(allowed.any()): - max_by_group[int(group_value)] = int(q[allowed].max()) - if bool(allowed.all()): - all_groups.append(int(group_value)) - return q_min, max_by_group, frozenset(all_groups) - - -def _build_k_block_group_state( - *, k_abs: np.ndarray, - k_group: np.ndarray, + q_enter: np.ndarray, + k_enter: np.ndarray, + k_exit: np.ndarray, + q_block: int, k_block: int, - block_idx: int, -) -> tuple[int, dict[int, int], tuple[int, ...]]: - start = int(block_idx) * k_block - end = min((int(block_idx) + 1) * k_block, int(k_abs.size)) - k = k_abs[start:end] - k_group_block = k_group[start:end] - k_max = int(k.max()) if int(k.size) else 0 - min_by_group: dict[int, int] = {} - for group_value in np.unique(k_group_block): - min_by_group[int(group_value)] = int(k[k_group_block == group_value].min()) - return k_max, min_by_group, tuple(min_by_group) - - -def _exact_block_state( - *, - q_state: tuple[int, dict[int, int], frozenset[int]], - k_state: tuple[int, dict[int, int], tuple[int, ...]], -) -> tuple[bool, bool]: - q_min, q_allowed_max, q_all_allowed = q_state - k_max, k_min, k_groups = k_state - if not any( - q_allowed_max.get(k_group_value, _INVALID_Q_GROUP) >= min_k - for k_group_value, min_k in k_min.items() +) -> None: + if not bool((partial_blocks | full_blocks).any()): + return + + q_abs_blocks = _block_matrix( + q_abs, + block_size=q_block, + block_count=int(partial_blocks.shape[0]), + fill_value=_INVALID_ABS, + ) + q_enter_blocks = _block_matrix( + q_enter, + block_size=q_block, + 
block_count=int(partial_blocks.shape[0]), + fill_value=_INVALID_ENTER, + ) + k_abs_blocks = _block_matrix( + k_abs, + block_size=k_block, + block_count=int(partial_blocks.shape[1]), + fill_value=_INVALID_ABS, + ) + k_enter_blocks = _block_matrix( + k_enter, + block_size=k_block, + block_count=int(partial_blocks.shape[1]), + fill_value=_INVALID_ENTER, + ) + k_exit_blocks = _block_matrix( + k_exit, + block_size=k_block, + block_count=int(partial_blocks.shape[1]), + fill_value=_INVALID_EXIT, + ) + + q_valid = (q_abs_blocks >= 0) & (q_enter_blocks >= 0) + k_valid = ( + (k_abs_blocks >= 0) & (k_enter_blocks >= 0) & (k_exit_blocks > k_enter_blocks) + ) + q_all_valid = q_valid.all(axis=1) + k_all_valid = k_valid.all(axis=1) + q_min_abs = np.where(q_valid, q_abs_blocks, np.iinfo(np.int64).max).min(axis=1) + q_min_enter = np.where( + q_valid, + q_enter_blocks, + np.iinfo(np.int64).max, + ).min(axis=1) + q_max_enter = np.where(q_valid, q_enter_blocks, _INVALID_ENTER).max(axis=1) + k_max_abs = np.where(k_valid, k_abs_blocks, _INVALID_ABS).max(axis=1) + k_max_enter = np.where(k_valid, k_enter_blocks, _INVALID_ENTER).max(axis=1) + k_min_exit = np.where(k_valid, k_exit_blocks, np.iinfo(np.int64).max).min(axis=1) + safe_full = ( + q_all_valid[:, None] + & k_all_valid[None, :] + & (q_min_abs[:, None] >= k_max_abs[None, :]) + & (k_max_enter[None, :] <= q_min_enter[:, None]) + & (q_max_enter[:, None] < k_min_exit[None, :]) + ) + candidate_blocks = partial_blocks | (full_blocks & ~safe_full) + q_indices, k_indices = np.nonzero(candidate_blocks) + if int(q_indices.size) == 0: + return + + rows = np.arange(int(k_valid.shape[0])) + first_valid_offsets = k_valid.argmax(axis=1) + first_enter = k_enter_blocks[rows, first_valid_offsets] + first_exit = k_exit_blocks[rows, first_valid_offsets] + k_single_interval = k_valid.any(axis=1) & ( + (~k_valid) + | ( + (k_enter_blocks == first_enter[:, None]) + & (k_exit_blocks == first_exit[:, None]) + ) + ).all(axis=1) + + single_pair = 
k_single_interval[k_indices] + if bool(single_pair.any()): + single_q = q_indices[single_pair] + single_k = k_indices[single_pair] + q_abs_selected = q_abs_blocks[single_q] + q_enter_selected = q_enter_blocks[single_q] + in_subtree = ( + q_valid[single_q] + & (q_enter_selected >= first_enter[single_k, None]) + & (q_enter_selected < first_exit[single_k, None]) + ) + max_abs_in_subtree = np.where( + in_subtree, + q_abs_selected, + _INVALID_ABS, + ).max(axis=1) + k_min_abs = np.where(k_valid, k_abs_blocks, np.iinfo(np.int64).max).min(axis=1) + has_any = max_abs_in_subtree >= k_min_abs[single_k] + + is_full = ( + has_any + & q_all_valid[single_q] + & k_all_valid[single_k] + & (q_min_abs[single_q] >= k_max_abs[single_k]) + & (first_enter[single_k] <= q_min_enter[single_q]) + & (q_max_enter[single_q] < first_exit[single_k]) + ) + partial_blocks[single_q, single_k] = has_any & ~is_full + full_blocks[single_q, single_k] = is_full + + intervals_by_k: dict[int, tuple[tuple[int, int, int], ...]] = {} + + def k_intervals(k_idx: int) -> tuple[tuple[int, int, int], ...]: + cached = intervals_by_k.get(k_idx) + if cached is not None: + return cached + min_abs_by_interval: dict[tuple[int, int], int] = {} + for abs_value, enter_value, exit_value in zip( + k_abs_blocks[k_idx, k_valid[k_idx]], + k_enter_blocks[k_idx, k_valid[k_idx]], + k_exit_blocks[k_idx, k_valid[k_idx]], + strict=True, + ): + key = (int(enter_value), int(exit_value)) + prior = min_abs_by_interval.get(key) + min_abs_by_interval[key] = ( + int(abs_value) if prior is None else min(prior, int(abs_value)) + ) + cached = tuple( + (enter, exit, min_abs) + for (enter, exit), min_abs in min_abs_by_interval.items() + ) + intervals_by_k[k_idx] = cached + return cached + + for q_idx, k_idx in zip( + q_indices[~single_pair], + k_indices[~single_pair], + strict=True, ): - return False, False - if int(q_min) < int(k_max): - return True, False - return True, all(k_group_value in q_all_allowed for k_group_value in k_groups) + 
q_valid_row = q_valid[q_idx] + intervals = k_intervals(int(k_idx)) + has_any = False + if bool(q_valid_row.any()) and intervals: + q_abs_row = q_abs_blocks[q_idx] + q_enter_row = q_enter_blocks[q_idx] + for enter, exit, min_abs in intervals: + in_subtree = q_valid_row & (q_enter_row >= enter) & (q_enter_row < exit) + if bool(in_subtree.any()) and int(q_abs_row[in_subtree].max()) >= int( + min_abs + ): + has_any = True + break + is_full = ( + has_any + and bool(q_all_valid[q_idx]) + and bool(k_all_valid[k_idx]) + and int(q_min_abs[q_idx]) >= int(k_max_abs[k_idx]) + and int(k_max_enter[k_idx]) <= int(q_min_enter[q_idx]) + and int(q_max_enter[q_idx]) < int(k_min_exit[k_idx]) + ) + partial_blocks[q_idx, k_idx] = has_any and not is_full + full_blocks[q_idx, k_idx] = bool(is_full) + + +def _is_strictly_increasing(values: np.ndarray) -> bool: + return int(values.size) <= 1 or bool(np.all(values[1:] > values[:-1])) + + +def _block_min_max( + values: np.ndarray, + starts: np.ndarray, + ends: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + mins = np.empty(starts.shape, dtype=values.dtype) + maxes = np.empty(starts.shape, dtype=values.dtype) + for index, (start, end) in enumerate(zip(starts, ends, strict=True)): + block = values[int(start) : int(end)] + mins[index] = block.min() + maxes[index] = block.max() + return mins, maxes + + +def _block_matrix( + values: np.ndarray, + *, + block_size: int, + block_count: int, + fill_value: int, +) -> np.ndarray: + padded = np.full(block_count * block_size, fill_value, dtype=np.int64) + padded[: int(values.size)] = values + return padded.reshape(block_count, block_size) + + +def _build_group_interval_arrays( + *, + row_tree, + length: int, +) -> tuple[np.ndarray, np.ndarray]: + enter_by_group: dict[int, int] = {} + exit_by_group: dict[int, int] = {} + segment_by_group = {segment.group_id: segment for segment in row_tree.segments} + children_by_group: dict[int, list[int]] = {} + roots: list[int] = [] + for segment in 
row_tree.segments: + if segment.ancestors: + children_by_group.setdefault(segment.parent_id, []).append(segment.group_id) + else: + roots.append(segment.group_id) + + next_enter = 0 + + def visit(group_id: int) -> None: + nonlocal next_enter + enter_by_group[group_id] = next_enter + next_enter += 1 + children = children_by_group.get(group_id, []) + children.sort(key=lambda child: segment_by_group[child].start) + for child_group_id in children: + visit(child_group_id) + exit_by_group[group_id] = next_enter + + roots.sort(key=lambda root: segment_by_group[root].start) + for root_group_id in roots: + visit(root_group_id) + + enter_by_token = np.full((length,), _INVALID_ENTER, dtype=np.int64) + exit_by_token = np.full((length,), _INVALID_EXIT, dtype=np.int64) + for segment in row_tree.segments: + enter_by_token[segment.start : segment.end] = enter_by_group[segment.group_id] + exit_by_token[segment.start : segment.end] = exit_by_group[segment.group_id] + return enter_by_token, exit_by_token def _build_sparse_block_mask( spec: FlexMaskSpec, *, device: torch.device, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, + context: PreparedBlockMaskContext, block_size: tuple[int, int], ) -> BlockMask: q_block, k_block = block_size @@ -156,33 +355,29 @@ def _build_sparse_block_mask( ) q_abs = q_abs_tensor.numpy() k_abs = k_abs_tensor.numpy() - flat_group_ids = group_ids.detach().to(device="cpu", dtype=torch.int64).reshape(-1) - flat_parent_ids = ( - parent_ids.detach().to(device="cpu", dtype=torch.int64).reshape(-1) - ) - flat_group_ids_np = flat_group_ids.numpy() - flat_parent_ids_np = flat_parent_ids.numpy() - q_group = _select_with_invalid_np( - flat_group_ids_np, + q_abs_sorted = _is_strictly_increasing(q_abs[q_abs >= 0]) + k_abs_sorted = _is_strictly_increasing(k_abs[k_abs >= 0]) + q_enter = _select_with_invalid_np( + context.group_enter_np, q_abs, - invalid_value=_INVALID_Q_GROUP, + invalid_value=_INVALID_ENTER, ) - q_parent = _select_with_invalid_np( - 
flat_parent_ids_np, - q_abs, - invalid_value=_INVALID_Q_PARENT, + k_enter = _select_with_invalid_np( + context.group_enter_np, + k_abs, + invalid_value=_INVALID_ENTER, ) - k_group = _select_with_invalid_np( - flat_group_ids_np, + k_exit = _select_with_invalid_np( + context.group_exit_np, k_abs, - invalid_value=_INVALID_K_GROUP, + invalid_value=_INVALID_EXIT, ) - mask_mod = _build_exact_mask_mod( + mask_mod = _build_interval_mask_mod( q_abs=q_abs, k_abs=k_abs, - q_group=q_group, - q_parent=q_parent, - k_group=k_group, + q_enter=q_enter, + k_enter=k_enter, + k_exit=k_exit, device=device, ) if not spec.slices: @@ -208,15 +403,11 @@ def _build_sparse_block_mask( if int(q_block_indices.size) == 0 or int(k_block_indices.size) == 0: continue q_block_start = q_block_indices * q_block - q_block_end = np.minimum( - (q_block_indices + 1) * q_block, - int(spec.q_len), - ) + q_block_end_raw = (q_block_indices + 1) * q_block + q_block_end = np.minimum(q_block_end_raw, int(spec.q_len)) k_block_start = k_block_indices * k_block - k_block_end = np.minimum( - (k_block_indices + 1) * k_block, - int(spec.k_len), - ) + k_block_end_raw = (k_block_indices + 1) * k_block + k_block_end = np.minimum(k_block_end_raw, int(spec.k_len)) q_overlap_start = np.maximum( q_block_start, q_start, @@ -233,12 +424,12 @@ def _build_sparse_block_mask( k_block_end, k_end, ) - q_min = q_abs[q_overlap_start] - q_max = q_abs[q_overlap_end - 1] - k_min = k_abs[k_overlap_start] - k_max = k_abs[k_overlap_end - 1] - q_is_full = (q_overlap_start == q_block_start) & (q_overlap_end == q_block_end) - k_is_full = (k_overlap_start == k_block_start) & (k_overlap_end == k_block_end) + q_is_full = (q_overlap_start == q_block_start) & ( + q_overlap_end == q_block_end_raw + ) + k_is_full = (k_overlap_start == k_block_start) & ( + k_overlap_end == k_block_end_raw + ) covers_block = q_is_full[:, None] & k_is_full[None, :] if slice_.mask_kind == AttnMaskKind.FULL: has_any = np.ones( @@ -246,6 +437,16 @@ def 
_build_sparse_block_mask( ) is_full = covers_block else: + q_min, q_max = ( + (q_abs[q_overlap_start], q_abs[q_overlap_end - 1]) + if q_abs_sorted + else _block_min_max(q_abs, q_overlap_start, q_overlap_end) + ) + k_min, k_max = ( + (k_abs[k_overlap_start], k_abs[k_overlap_end - 1]) + if k_abs_sorted + else _block_min_max(k_abs, k_overlap_start, k_overlap_end) + ) has_any = q_max[:, None] >= k_min[None, :] is_full = covers_block & (q_min[:, None] >= k_max[None, :]) @@ -255,41 +456,24 @@ def _build_sparse_block_mask( partial_blocks[q_slice, k_slice] |= has_any full_blocks[q_slice, k_slice] |= is_full - ambiguous = (touch_counts > 1) & partial_blocks & ~full_blocks - q_state_cache: dict[int, tuple[int, dict[int, int], frozenset[int]]] = {} - k_state_cache: dict[int, tuple[int, dict[int, int], tuple[int, ...]]] = {} - for q_idx, k_idx in np.argwhere(ambiguous): - q_state = q_state_cache.get(int(q_idx)) - if q_state is None: - q_state = _build_q_block_group_state( - q_abs=q_abs, - q_group=q_group, - q_parent=q_parent, - q_block=q_block, - block_idx=int(q_idx), - ) - q_state_cache[int(q_idx)] = q_state - k_state = k_state_cache.get(int(k_idx)) - if k_state is None: - k_state = _build_k_block_group_state( - k_abs=k_abs, - k_group=k_group, - k_block=k_block, - block_idx=int(k_idx), - ) - k_state_cache[int(k_idx)] = k_state - has_any, is_full = _exact_block_state( - q_state=q_state, - k_state=k_state, - ) - partial_blocks[q_idx, k_idx] = False - full_blocks[q_idx, k_idx] = False - if is_full: - full_blocks[q_idx, k_idx] = True - elif has_any: - partial_blocks[q_idx, k_idx] = True - partial_blocks &= ~full_blocks + needs_refine = full_blocks | ((touch_counts > 1) & partial_blocks) + if bool(needs_refine.any()): + refined_partial = partial_blocks & needs_refine + refined_full = full_blocks & needs_refine + _refine_interval_blocks( + partial_blocks=refined_partial, + full_blocks=refined_full, + q_abs=q_abs, + k_abs=k_abs, + q_enter=q_enter, + k_enter=k_enter, + k_exit=k_exit, 
+ q_block=q_block, + k_block=k_block, + ) + partial_blocks = (partial_blocks & ~needs_refine) | refined_partial + full_blocks = (full_blocks & ~needs_refine) | refined_full kv_num_blocks, kv_indices = _dense_blocks_to_ordered( partial_blocks, device=device, @@ -321,7 +505,44 @@ def _build_sparse_block_mask( ) -def _valid_prefix(indices: torch.Tensor, *, name: str) -> torch.Tensor: +def prepare_block_mask_context( + *, + group_ids: torch.Tensor, + parent_ids: torch.Tensor, +) -> PreparedBlockMaskContext: + if group_ids.ndim != 1 or parent_ids.ndim != 1: + raise RuntimeError( + "Shared-prefix sparse block masks require rank-1 group_ids and parent_ids." + ) + if int(group_ids.numel()) != int(parent_ids.numel()): + raise RuntimeError( + "Shared-prefix sparse block masks require equal group_ids and parent_ids lengths." + ) + flat_group_ids = group_ids.detach().to(device="cpu", dtype=torch.int64).reshape(-1) + flat_parent_ids = ( + parent_ids.detach().to(device="cpu", dtype=torch.int64).reshape(-1) + ) + row_tree = parse_shared_prefix_row( + group_ids=flat_group_ids, + parent_ids=flat_parent_ids, + ) + group_enter_np, group_exit_np = _build_group_interval_arrays( + row_tree=row_tree, + length=int(flat_group_ids.numel()), + ) + return PreparedBlockMaskContext( + source_len=int(flat_group_ids.numel()), + group_enter_np=group_enter_np, + group_exit_np=group_exit_np, + ) + + +def _validate_exact_indices( + indices: torch.Tensor, + *, + name: str, + source_len: int, +) -> int: if indices.ndim != 1: raise RuntimeError(f"{name} exact token indices must be rank 1.") if indices.dtype != torch.int64: @@ -334,52 +555,33 @@ def _valid_prefix(indices: torch.Tensor, *, name: str) -> torch.Tensor: raise RuntimeError( f"{name} exact token indices must use only contiguous tail padding." 
) - return indices_cpu[:first_invalid] - return indices_cpu - - -def _validate_exact_indices( - indices: torch.Tensor, - *, - name: str, - source_len: int, -) -> int: - valid = _valid_prefix(indices, name=name) - if int(valid.numel()) == 0: + indices_cpu = indices_cpu[:first_invalid] + if int(indices_cpu.numel()) == 0: return 0 - if bool((valid[1:] <= valid[:-1]).any().item()): - raise RuntimeError(f"{name} exact token indices must be strictly increasing.") - max_index = int(valid[-1].item()) + if int(indices_cpu.unique().numel()) != int(indices_cpu.numel()): + raise RuntimeError(f"{name} exact token indices must not contain duplicates.") + max_index = int(indices_cpu.max().item()) if max_index >= int(source_len): raise RuntimeError( f"{name} exact token index {max_index} exceeds source metadata length {int(source_len)}." ) - return int(valid.numel()) + return int(indices_cpu.numel()) def _validate_supported_mask_spec( spec: FlexMaskSpec, *, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, + source_len: int, ) -> None: - if group_ids.ndim != 1 or parent_ids.ndim != 1: - raise RuntimeError( - "Shared-prefix sparse block masks require rank-1 group_ids and parent_ids." - ) - if int(group_ids.numel()) != int(parent_ids.numel()): - raise RuntimeError( - "Shared-prefix sparse block masks require equal group_ids and parent_ids lengths." 
- ) q_valid_len = _validate_exact_indices( spec.exact_mask.q_token_indices, name="q", - source_len=int(group_ids.numel()), + source_len=source_len, ) k_valid_len = _validate_exact_indices( spec.exact_mask.k_token_indices, name="k", - source_len=int(group_ids.numel()), + source_len=source_len, ) for slice_ in spec.slices: if int(slice_.row_index) != 0: @@ -404,12 +606,12 @@ def _validate_supported_mask_spec( ) -def build_block_mask( +def build_block_mask_from_context( spec: FlexMaskSpec, *, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, + context: PreparedBlockMaskContext, device: torch.device, + validate: bool = True, ) -> BlockMask | None: if spec.q_len <= 0 or spec.k_len <= 0: return None @@ -423,12 +625,15 @@ def build_block_mask( "Exact stage k-token metadata length mismatch: " f"{int(spec.exact_mask.k_token_indices.numel())} != {int(spec.k_len)}" ) - _validate_supported_mask_spec(spec, group_ids=group_ids, parent_ids=parent_ids) + if validate: + _validate_supported_mask_spec( + spec, + source_len=context.source_len, + ) block_size = normalize_sparse_block_size(spec.block_size) return _build_sparse_block_mask( spec, device=device, - group_ids=group_ids, - parent_ids=parent_ids, + context=context, block_size=block_size, ) diff --git a/src/art/megatron/context_parallel/builder.py b/src/art/megatron/context_parallel/builder.py index 77ac1b623..5396873ab 100644 --- a/src/art/megatron/context_parallel/builder.py +++ b/src/art/megatron/context_parallel/builder.py @@ -2,110 +2,17 @@ import torch +from art.megatron.shared_prefix_tree import parse_shared_prefix_tree + from .types import ( AttnMaskKind, AttnSlice, PackedBatchAttentionSpec, PackedRowAttentionSpec, - SharedPrefixBuilderConfig, TokenRange, ) -def _valid_length( - group_ids: torch.Tensor, - parent_ids: torch.Tensor, - *, - ignore_padding_group_id: int, -) -> int: - valid_mask = group_ids != ignore_padding_group_id - valid_count = int(valid_mask.sum().item()) - if valid_count == 0: - return 0 - if not 
bool(valid_mask[:valid_count].all().item()): - raise RuntimeError("Padding tokens must be a contiguous tail") - return _infer_terminal_padding_length( - group_ids[:valid_count], - parent_ids[:valid_count], - ) - - -def _infer_terminal_padding_length( - group_row: torch.Tensor, - parent_row: torch.Tensor, -) -> int: - if group_row.numel() == 0: - return 0 - runs = _scan_runs(group_row, parent_row) - if len(runs) < 2: - return int(group_row.numel()) - last_start, _last_end, last_group_id, last_parent_id = runs[-1] - if last_parent_id >= 0: - return int(group_row.numel()) - terminal_pair = (last_group_id, last_parent_id) - if any( - (group_id, parent_id) == terminal_pair - for _start, _end, group_id, parent_id in runs[:-1] - ): - return last_start - return int(group_row.numel()) - - -def _scan_runs( - group_row: torch.Tensor, - parent_row: torch.Tensor, -) -> list[tuple[int, int, int, int]]: - length = int(group_row.numel()) - if length == 0: - return [] - - group_changes = group_row[1:] != group_row[:-1] - parent_changes = parent_row[1:] != parent_row[:-1] - inconsistent_parent = torch.nonzero( - torch.logical_not(group_changes) & parent_changes, - as_tuple=False, - ).flatten() - if int(inconsistent_parent.numel()) > 0: - mismatch_index = int(inconsistent_parent[0].item()) + 1 - prior_boundaries = torch.nonzero( - group_changes[: mismatch_index - 1], - as_tuple=False, - ).flatten() - start = ( - 0 - if int(prior_boundaries.numel()) == 0 - else int(prior_boundaries[-1].item()) + 1 - ) - group_id = int(group_row[start].item()) - raise RuntimeError( - "Found one group run with inconsistent parent ids: " - f"group_id={group_id}, start={start}, end={mismatch_index}" - ) - - run_starts = torch.cat( - ( - torch.zeros(1, dtype=torch.int64, device=group_row.device), - torch.nonzero(group_changes, as_tuple=False).flatten() + 1, - ) - ) - run_ends = torch.cat( - ( - run_starts[1:], - torch.tensor([length], dtype=torch.int64, device=group_row.device), - ) - ) - starts = 
run_starts.to(device="cpu").tolist() - ends = run_ends.to(device="cpu").tolist() - group_ids = group_row.index_select(0, run_starts).to(device="cpu").tolist() - parent_ids = parent_row.index_select(0, run_starts).to(device="cpu").tolist() - return [ - (int(start), int(end), int(group_id), int(parent_id)) - for start, end, group_id, parent_id in zip( - starts, ends, group_ids, parent_ids, strict=True - ) - ] - - def _sort_and_dedupe_slices(slices: list[AttnSlice]) -> tuple[AttnSlice, ...]: sorted_slices = sorted( slices, @@ -138,23 +45,11 @@ def _sort_and_dedupe_slices(slices: list[AttnSlice]) -> tuple[AttnSlice, ...]: return tuple(deduped) -def _is_prompt_run( - *, - start: int, - group_id: int, - parent_id: int, - ignore_padding_group_id: int, -) -> bool: - return group_id == parent_id or ( - start == 0 and parent_id == ignore_padding_group_id - ) - - def build_shared_prefix_attention_spec( *, group_ids: torch.Tensor, parent_ids: torch.Tensor, - config: SharedPrefixBuilderConfig = SharedPrefixBuilderConfig(), + ignore_padding_group_id: int = -1, ) -> PackedBatchAttentionSpec: if group_ids.shape != parent_ids.shape: raise RuntimeError( @@ -166,127 +61,49 @@ def build_shared_prefix_attention_spec( "group_ids and parent_ids must be rank-2 packed tensors, got " f"{group_ids.ndim}" ) - if int(group_ids.shape[0]) != 1: - raise RuntimeError( - "ART shared-prefix attention spec currently supports exactly one packed sequence, " - f"got batch={int(group_ids.shape[0])}." 
- ) - rows: list[PackedRowAttentionSpec] = [] - for row_index in range(group_ids.shape[0]): - group_row = group_ids[row_index] - parent_row = parent_ids[row_index] - valid_tokens = _valid_length( - group_row, - parent_row, - ignore_padding_group_id=config.ignore_padding_group_id, - ) - if valid_tokens == 0: + for row in parse_shared_prefix_tree( + group_ids=group_ids, + parent_ids=parent_ids, + ignore_padding_group_id=ignore_padding_group_id, + ): + if row.valid_tokens == 0: rows.append( - PackedRowAttentionSpec(row_index=row_index, valid_tokens=0, slices=()) + PackedRowAttentionSpec( + row_index=row.row_index, valid_tokens=0, slices=() + ) ) continue - group_row = group_row[:valid_tokens] - parent_row = parent_row[:valid_tokens] - runs = _scan_runs(group_row, parent_row) - - group_run_count: dict[int, int] = {} - prompt_by_group_id: dict[int, tuple[tuple[int, int], int]] = {} - completion_ranges_by_prompt: dict[int, list[tuple[int, int]]] = {} - - for start, end, group_id, parent_id in runs: - group_run_count[group_id] = group_run_count.get(group_id, 0) + 1 - if _is_prompt_run( - start=start, - group_id=group_id, - parent_id=parent_id, - ignore_padding_group_id=config.ignore_padding_group_id, - ): - if group_id in prompt_by_group_id: - raise RuntimeError( - f"Prompt group_id {group_id} appears more than once in row {row_index}" - ) - family_index = len(prompt_by_group_id) - prompt_by_group_id[group_id] = ( - (start, end), - family_index, - ) - completion_ranges_by_prompt[group_id] = [] - - if config.require_contiguous_group_runs: - repeated_groups = { - group_id: count - for group_id, count in group_run_count.items() - if count > 1 and group_id != config.ignore_padding_group_id - } - if repeated_groups: - raise RuntimeError( - "Shared-prefix builder requires contiguous group runs per row, " - f"found repeats in row {row_index}: {repeated_groups}" - ) - - for start, end, group_id, parent_id in runs: - if _is_prompt_run( - start=start, - group_id=group_id, - 
parent_id=parent_id, - ignore_padding_group_id=config.ignore_padding_group_id, - ): - continue - prompt_entry = prompt_by_group_id.get(parent_id) - if prompt_entry is None: - raise RuntimeError( - "Completion run points to a missing prompt run: " - f"row={row_index}, group_id={group_id}, parent_id={parent_id}" - ) - completion_ranges_by_prompt[parent_id].append((start, end)) - + segment_by_group_id = {segment.group_id: segment for segment in row.segments} row_slices: list[AttnSlice] = [] - for prompt_group_id, ( - (prompt_start, prompt_end), - family_index, - ) in prompt_by_group_id.items(): - prompt_range = TokenRange(start=prompt_start, end=prompt_end) - row_slices.append( - AttnSlice( - q_range=prompt_range, - k_range=prompt_range, - mask_kind=AttnMaskKind.CAUSAL, - row_index=row_index, - family_index=family_index, - ) - ) - for completion_start, completion_end in completion_ranges_by_prompt[ - prompt_group_id - ]: - completion_range = TokenRange( - start=completion_start, - end=completion_end, - ) + for segment in row.segments: + q_range = TokenRange(start=segment.start, end=segment.end) + for ancestor_group_id in segment.ancestors: + ancestor = segment_by_group_id[ancestor_group_id] row_slices.append( AttnSlice( - q_range=completion_range, - k_range=prompt_range, + q_range=q_range, + k_range=TokenRange(start=ancestor.start, end=ancestor.end), mask_kind=AttnMaskKind.FULL, - row_index=row_index, - family_index=family_index, + row_index=row.row_index, + family_index=segment.family_index, ) ) - row_slices.append( - AttnSlice( - q_range=completion_range, - k_range=completion_range, - mask_kind=AttnMaskKind.CAUSAL, - row_index=row_index, - family_index=family_index, - ) + row_slices.append( + AttnSlice( + q_range=q_range, + k_range=q_range, + mask_kind=AttnMaskKind.CAUSAL, + row_index=row.row_index, + family_index=segment.family_index, ) + ) rows.append( PackedRowAttentionSpec( - row_index=row_index, - valid_tokens=valid_tokens, + row_index=row.row_index, + 
valid_tokens=row.valid_tokens, slices=_sort_and_dedupe_slices(row_slices), ) ) diff --git a/src/art/megatron/context_parallel/comm.py b/src/art/megatron/context_parallel/comm.py index c1767a4dc..8ea97067d 100644 --- a/src/art/megatron/context_parallel/comm.py +++ b/src/art/megatron/context_parallel/comm.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass from typing import Any, Protocol, cast @@ -141,37 +142,27 @@ def wait_post_process(self) -> tuple[torch.Tensor, torch.Tensor]: if range_.size() > 0 ) - def _apply_reduce() -> None: - dk_reduce = ( - dk_remote - if dk_remote.dtype == self.dk_local.dtype - else dk_remote.to(dtype=self.dk_local.dtype) - ) - dv_reduce = ( - dv_remote - if dv_remote.dtype == self.dv_local.dtype - else dv_remote.to(dtype=self.dv_local.dtype) - ) - reduce_fn = ( - range_reduce_sum_head_major_ - if self.input_layout == "head_major" - else range_reduce_sum_ - ) - reduce_fn( - dk_reduce, - output_tensor=self.dk_local, - ranges=flattened_ranges, - range_meta_cache=self.range_meta_cache, - ) - reduce_fn( - dv_reduce, - output_tensor=self.dv_local, - ranges=flattened_ranges, - range_meta_cache=self.range_meta_cache, - ) - return - - _apply_reduce() + reduce_fn = ( + range_reduce_sum_head_major_ + if self.input_layout == "head_major" + else range_reduce_sum_ + ) + reduce_fn( + dk_remote + if dk_remote.dtype == self.dk_local.dtype + else dk_remote.to(dtype=self.dk_local.dtype), + output_tensor=self.dk_local, + ranges=flattened_ranges, + range_meta_cache=self.range_meta_cache, + ) + reduce_fn( + dv_remote + if dv_remote.dtype == self.dv_local.dtype + else dv_remote.to(dtype=self.dv_local.dtype), + output_tensor=self.dv_local, + ranges=flattened_ranges, + range_meta_cache=self.range_meta_cache, + ) return self.dk_local, self.dv_local @@ -191,6 +182,60 @@ def _get_stream(self, tensor: torch.Tensor) -> torch.cuda.Stream | None: self._streams[device_index] = stream return stream 
+ def _launch_exchange( + self, + *, + tensor: torch.Tensor, + recv_buffer: torch.Tensor, + total_send_rows: int, + make_send_buffer: Callable[[], torch.Tensor], + output_split_sizes: list[int], + input_split_sizes: list[int], + group: Any, + async_op: bool, + input_layout: str, + ) -> tuple[_Waitable | None, torch.Tensor, torch.cuda.Stream | None]: + stream = self._get_stream(tensor) if async_op else None + send_buffer = ( + tensor.new_empty( + _packed_peer_tensor_shape( + tensor=tensor, + total_rows=0, + input_layout=input_layout, + ) + ) + if total_send_rows <= 0 + else make_send_buffer() + ) + if stream is None: + return ( + _launch_peer_exchange( + recv_buffer=recv_buffer, + send_buffer=send_buffer, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ), + send_buffer, + None, + ) + + current_stream = torch.cuda.current_stream(tensor.device) + stream.wait_stream(current_stream) + send_buffer.record_stream(stream) + recv_buffer.record_stream(stream) + with torch.cuda.stream(stream): + handle = _launch_peer_exchange( + recv_buffer=recv_buffer, + send_buffer=send_buffer, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=True, + ) + return handle, send_buffer, stream + def launch_kv_fetch( self, *, @@ -230,70 +275,23 @@ def launch_kv_fetch( ) input_split_sizes = [split * 2 for split in plan.send_splits] output_split_sizes = [split * 2 for split in plan.recv_splits] - stream = self._get_stream(k_local) if async_op else None - if stream is not None: - current_stream = torch.cuda.current_stream(k_local.device) - if total_send_rows <= 0: - send_buffer = k_local.new_empty( - _packed_peer_tensor_shape( - tensor=k_local, - total_rows=0, - input_layout=input_layout, - ) - ) - else: - send_buffer = _pack_gathered_tensors_per_peer( - left_tensor=k_local, - right_tensor=v_local, - ranges_by_peer=plan.send_ranges_by_peer, - 
range_meta_cache=range_meta_cache, - input_layout=input_layout, - ) - stream.wait_stream(current_stream) - send_buffer.record_stream(stream) - recv_packed.record_stream(stream) - with torch.cuda.stream(stream): - handle = _launch_peer_exchange( - recv_buffer=recv_packed, - send_buffer=send_buffer, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=True, - ) - else: - if total_send_rows <= 0: - send_buffer = k_local.new_empty( - _packed_peer_tensor_shape( - tensor=k_local, - total_rows=0, - input_layout=input_layout, - ) - ) - handle = _launch_peer_exchange( - recv_buffer=recv_packed, - send_buffer=send_buffer, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) - else: - send_buffer = _pack_gathered_tensors_per_peer( - left_tensor=k_local, - right_tensor=v_local, - ranges_by_peer=plan.send_ranges_by_peer, - range_meta_cache=range_meta_cache, - input_layout=input_layout, - ) - handle = _launch_peer_exchange( - recv_buffer=recv_packed, - send_buffer=send_buffer, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) + handle, send_buffer, stream = self._launch_exchange( + tensor=k_local, + recv_buffer=recv_packed, + total_send_rows=total_send_rows, + make_send_buffer=lambda: _pack_gathered_tensors_per_peer( + left_tensor=k_local, + right_tensor=v_local, + ranges_by_peer=plan.send_ranges_by_peer, + range_meta_cache=range_meta_cache, + input_layout=input_layout, + ), + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + input_layout=input_layout, + ) return KvFetchWork( packed_buffer=recv_packed, recv_splits=plan.recv_splits, @@ -333,89 +331,33 @@ def launch_dkv_reduce( total_send_rows = int(sum(plan.send_splits)) recv_total = int(sum(plan.recv_splits)) - recv_packed = ( - dk_remote.new_empty( - _packed_peer_tensor_shape( 
- tensor=dk_remote, - total_rows=recv_total, - input_layout=input_layout, - ) - ) - if recv_total > 0 - else dk_remote.new_empty( - _packed_peer_tensor_shape( - tensor=dk_remote, - total_rows=0, - input_layout=input_layout, - ) + recv_packed = dk_remote.new_empty( + _packed_peer_tensor_shape( + tensor=dk_remote, + total_rows=recv_total, + input_layout=input_layout, ) ) input_split_sizes = [split * 2 for split in plan.send_splits] output_split_sizes = [split * 2 for split in plan.recv_splits] - stream = self._get_stream(dk_remote) if async_op else None - if stream is not None: - current_stream = torch.cuda.current_stream(dk_remote.device) - if total_send_rows <= 0: - send_buffer = dk_remote.new_empty( - _packed_peer_tensor_shape( - tensor=dk_remote, - total_rows=0, - input_layout=input_layout, - ) - ) - else: - send_buffer = _pack_split_tensors_by_peer( - left_tensor=dk_remote, - right_tensor=dv_remote, - splits=plan.send_splits, - input_layout=input_layout, - ) - stream.wait_stream(current_stream) - send_buffer.record_stream(stream) - recv_packed.record_stream(stream) - with torch.cuda.stream(stream): - handle = _launch_peer_exchange( - recv_buffer=recv_packed, - send_buffer=send_buffer, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=True, - ) - else: - if total_send_rows <= 0: - send_buffer = dk_remote.new_empty( - _packed_peer_tensor_shape( - tensor=dk_remote, - total_rows=0, - input_layout=input_layout, - ) - ) - handle = _launch_peer_exchange( - recv_buffer=recv_packed, - send_buffer=send_buffer, - output_split_sizes=output_split_sizes, - input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) - else: - send_buffer = _pack_split_tensors_by_peer( - left_tensor=dk_remote, - right_tensor=dv_remote, - splits=plan.send_splits, - input_layout=input_layout, - ) - handle = _launch_peer_exchange( - recv_buffer=recv_packed, - send_buffer=send_buffer, - output_split_sizes=output_split_sizes, - 
input_split_sizes=input_split_sizes, - group=group, - async_op=async_op, - ) + handle, send_buffer, stream = self._launch_exchange( + tensor=dk_remote, + recv_buffer=recv_packed, + total_send_rows=total_send_rows, + make_send_buffer=lambda: _pack_split_tensors_by_peer( + left_tensor=dk_remote, + right_tensor=dv_remote, + splits=plan.send_splits, + input_layout=input_layout, + ), + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + input_layout=input_layout, + ) return DkvReduceWork( - packed_buffer=recv_packed if recv_total > 0 else None, + packed_buffer=recv_packed, handle=handle, send_buffer=send_buffer, stream=stream, @@ -449,29 +391,6 @@ def range_gather_per_peer( return torch.cat(chunks, dim=0).contiguous() -def _split_tensor_to_peer( - input_tensor: torch.Tensor, - splits: tuple[int, ...], -) -> torch.Tensor: - if int(sum(splits)) == 0: - return input_tensor.new_empty((0, *input_tensor.shape[1:])) - if int(input_tensor.shape[0]) == int(sum(splits)): - return input_tensor.contiguous() - if len([split for split in splits if split > 0]) > 1: - raise RuntimeError( - f"Expected at most one non-zero send split for dKV reduce, got {splits}" - ) - pieces: list[torch.Tensor] = [] - cursor = 0 - for split in splits: - if split == 0: - pieces.append(input_tensor.new_empty((0, *input_tensor.shape[1:]))) - continue - pieces.append(input_tensor[cursor : cursor + split]) - cursor += split - return torch.cat(pieces, dim=0).contiguous() - - def _pack_gathered_tensors_per_peer( *, left_tensor: torch.Tensor, @@ -480,56 +399,16 @@ def _pack_gathered_tensors_per_peer( range_meta_cache: dict[Any, Any] | None = None, input_layout: str = "token_major", ) -> torch.Tensor: - if input_layout == "head_major": - return _pack_gathered_tensors_per_peer_head_major( - left_tensor=left_tensor, - right_tensor=right_tensor, - ranges_by_peer=ranges_by_peer, - range_meta_cache=range_meta_cache, - ) - if input_layout != 
"token_major": - raise ValueError(f"Unsupported gathered-pack input layout: {input_layout}") - total_rows = sum( - range_.size() for peer_ranges in ranges_by_peer for range_ in peer_ranges - ) - if total_rows == 0: - return left_tensor.new_empty((0, *left_tensor.shape[1:])) - packed = left_tensor.new_empty((total_rows * 2, *left_tensor.shape[1:])) - cursor = 0 - for peer_ranges in ranges_by_peer: - split = sum(range_.size() for range_ in peer_ranges) - if split <= 0: - continue - range_gather( - left_tensor, - peer_ranges, - output=packed[cursor : cursor + split], - range_meta_cache=range_meta_cache, - ) - range_gather( - right_tensor, - peer_ranges, - output=packed[cursor + split : cursor + split * 2], - range_meta_cache=range_meta_cache, - ) - cursor += split * 2 - return packed - - -def _pack_gathered_tensors_per_peer_head_major( - *, - left_tensor: torch.Tensor, - right_tensor: torch.Tensor, - ranges_by_peer: tuple[tuple[TokenRange, ...], ...], - range_meta_cache: dict[Any, Any] | None = None, -) -> torch.Tensor: + _validate_peer_layout(input_layout, context="gathered-pack input") total_rows = sum( range_.size() for peer_ranges in ranges_by_peer for range_ in peer_ranges ) - if total_rows == 0: - return left_tensor.new_empty((0, left_tensor.shape[0], left_tensor.shape[2])) packed = left_tensor.new_empty( - (total_rows * 2, left_tensor.shape[0], left_tensor.shape[2]) + _packed_peer_tensor_shape( + tensor=left_tensor, + total_rows=total_rows, + input_layout=input_layout, + ) ) cursor = 0 for peer_ranges in ranges_by_peer: @@ -537,18 +416,20 @@ def _pack_gathered_tensors_per_peer_head_major( if split <= 0: continue packed[cursor : cursor + split].copy_( - range_gather_head_major( + _gather_peer_rows( left_tensor, peer_ranges, + input_layout=input_layout, range_meta_cache=range_meta_cache, - ).permute(1, 0, 2) + ) ) packed[cursor + split : cursor + split * 2].copy_( - range_gather_head_major( + _gather_peer_rows( right_tensor, peer_ranges, + 
input_layout=input_layout, range_meta_cache=range_meta_cache, - ).permute(1, 0, 2) + ) ) cursor += split * 2 return packed @@ -561,79 +442,83 @@ def _pack_split_tensors_by_peer( splits: tuple[int, ...], input_layout: str = "token_major", ) -> torch.Tensor: - if input_layout == "head_major": - return _pack_split_tensors_by_peer_head_major( - left_tensor=left_tensor, - right_tensor=right_tensor, - splits=splits, - ) - if input_layout != "token_major": - raise ValueError(f"Unsupported split-pack input layout: {input_layout}") + _validate_peer_layout(input_layout, context="split-pack input") total_rows = int(sum(splits)) - if total_rows == 0: - return left_tensor.new_empty((0, *left_tensor.shape[1:])) - packed = left_tensor.new_empty((total_rows * 2, *left_tensor.shape[1:])) + packed = left_tensor.new_empty( + _packed_peer_tensor_shape( + tensor=left_tensor, + total_rows=total_rows, + input_layout=input_layout, + ) + ) cursor = 0 for split in splits: if split <= 0: continue packed[cursor * 2 : cursor * 2 + split].copy_( - left_tensor[cursor : cursor + split] + _slice_peer_rows(left_tensor, cursor, cursor + split, layout=input_layout) ) packed[cursor * 2 + split : cursor * 2 + split * 2].copy_( - right_tensor[cursor : cursor + split] + _slice_peer_rows(right_tensor, cursor, cursor + split, layout=input_layout) ) cursor += split - if cursor != int(left_tensor.shape[0]) or cursor != int(right_tensor.shape[0]): + left_rows = _peer_row_count(left_tensor, layout=input_layout) + right_rows = _peer_row_count(right_tensor, layout=input_layout) + if cursor != left_rows or cursor != right_rows: raise RuntimeError( "Packed split consumed the wrong number of rows: " - f"consumed={cursor}, left={int(left_tensor.shape[0])}, right={int(right_tensor.shape[0])}" + f"consumed={cursor}, left={left_rows}, right={right_rows}" ) return packed +def _validate_peer_layout(layout: str, *, context: str) -> None: + if layout not in {"token_major", "head_major"}: + raise ValueError(f"Unsupported 
{context} layout: {layout}") + + def _packed_peer_tensor_shape( *, tensor: torch.Tensor, total_rows: int, input_layout: str, ) -> tuple[int, ...]: + _validate_peer_layout(input_layout, context="peer tensor input") if input_layout == "head_major": return (total_rows * 2, int(tensor.shape[0]), int(tensor.shape[2])) - if input_layout != "token_major": - raise ValueError(f"Unsupported split-pack input layout: {input_layout}") return (total_rows * 2, *tuple(int(dim) for dim in tensor.shape[1:])) -def _pack_split_tensors_by_peer_head_major( +def _peer_row_count(tensor: torch.Tensor, *, layout: str) -> int: + return int(tensor.shape[1] if layout == "head_major" else tensor.shape[0]) + + +def _slice_peer_rows( + tensor: torch.Tensor, + start: int, + end: int, *, - left_tensor: torch.Tensor, - right_tensor: torch.Tensor, - splits: tuple[int, ...], + layout: str, ) -> torch.Tensor: - total_rows = int(sum(splits)) - if total_rows == 0: - return left_tensor.new_empty((0, left_tensor.shape[0], left_tensor.shape[2])) - packed = left_tensor.new_empty( - (total_rows * 2, left_tensor.shape[0], left_tensor.shape[2]) - ) - cursor = 0 - for split in splits: - if split <= 0: - continue - packed[cursor * 2 : cursor * 2 + split].copy_( - left_tensor[:, cursor : cursor + split].permute(1, 0, 2) - ) - packed[cursor * 2 + split : cursor * 2 + split * 2].copy_( - right_tensor[:, cursor : cursor + split].permute(1, 0, 2) - ) - cursor += split - if cursor != int(left_tensor.shape[1]) or cursor != int(right_tensor.shape[1]): - raise RuntimeError( - "Head-major split pack consumed the wrong number of rows: " - f"consumed={cursor}, left={int(left_tensor.shape[1])}, right={int(right_tensor.shape[1])}" - ) - return packed + if layout == "head_major": + return tensor[:, start:end].movedim(1, 0) + return tensor[start:end] + + +def _gather_peer_rows( + tensor: torch.Tensor, + ranges: tuple[TokenRange, ...], + *, + input_layout: str, + range_meta_cache: dict[Any, Any] | None, +) -> torch.Tensor: + if 
input_layout == "head_major": + return range_gather_head_major( + tensor, + ranges, + range_meta_cache=range_meta_cache, + ).movedim(1, 0) + return range_gather(tensor, ranges, range_meta_cache=range_meta_cache) def _unpack_packed_tensor_per_peer( @@ -642,15 +527,13 @@ def _unpack_packed_tensor_per_peer( *, output_layout: str = "token_major", ) -> tuple[torch.Tensor, torch.Tensor]: - if output_layout == "head_major": - return _unpack_packed_tensor_per_peer_head_major( + _validate_peer_layout(output_layout, context="packed-tensor output") + if int(packed_tensor.shape[0]) == 0: + empty = _new_unpacked_peer_tensor( packed_tensor, - splits, + total_rows=0, + output_layout=output_layout, ) - if output_layout != "token_major": - raise ValueError(f"Unsupported packed-tensor output layout: {output_layout}") - if int(packed_tensor.shape[0]) == 0: - empty = packed_tensor.new_empty((0, *packed_tensor.shape[1:])) return empty, empty total_rows = 0 cursor = 0 @@ -664,62 +547,59 @@ def _unpack_packed_tensor_per_peer( "Packed tensor unpack consumed the wrong number of rows: " f"consumed={cursor}, input={int(packed_tensor.shape[0])}" ) - left = packed_tensor.new_empty((total_rows, *packed_tensor.shape[1:])) - right = packed_tensor.new_empty((total_rows, *packed_tensor.shape[1:])) + left = _new_unpacked_peer_tensor( + packed_tensor, + total_rows=total_rows, + output_layout=output_layout, + ) + right = _new_unpacked_peer_tensor( + packed_tensor, + total_rows=total_rows, + output_layout=output_layout, + ) in_cursor = 0 out_cursor = 0 for split in splits: if split <= 0: continue - left[out_cursor : out_cursor + split].copy_( - packed_tensor[in_cursor : in_cursor + split] + _copy_from_peer_rows( + left, + out_cursor, + packed_tensor[in_cursor : in_cursor + split], + output_layout=output_layout, ) - right[out_cursor : out_cursor + split].copy_( - packed_tensor[in_cursor + split : in_cursor + split * 2] + _copy_from_peer_rows( + right, + out_cursor, + packed_tensor[in_cursor + split : 
in_cursor + split * 2], + output_layout=output_layout, ) in_cursor += split * 2 out_cursor += split return left, right -def _unpack_packed_tensor_per_peer_head_major( +def _new_unpacked_peer_tensor( packed_tensor: torch.Tensor, - splits: tuple[int, ...], -) -> tuple[torch.Tensor, torch.Tensor]: - if int(packed_tensor.shape[0]) == 0: - empty = packed_tensor.new_empty( - (packed_tensor.shape[1], 0, packed_tensor.shape[2]) - ) - return empty, empty - total_rows = 0 - cursor = 0 - for split in splits: - if split <= 0: - continue - cursor += split * 2 - total_rows += split - if cursor != int(packed_tensor.shape[0]): - raise RuntimeError( - "Packed tensor unpack consumed the wrong number of rows: " - f"consumed={cursor}, input={int(packed_tensor.shape[0])}" - ) - left = packed_tensor.new_empty( - (packed_tensor.shape[1], total_rows, packed_tensor.shape[2]) - ) - right = packed_tensor.new_empty( - (packed_tensor.shape[1], total_rows, packed_tensor.shape[2]) - ) - in_cursor = 0 - out_cursor = 0 - for split in splits: - if split <= 0: - continue - left[:, out_cursor : out_cursor + split].copy_( - packed_tensor[in_cursor : in_cursor + split].permute(1, 0, 2) - ) - right[:, out_cursor : out_cursor + split].copy_( - packed_tensor[in_cursor + split : in_cursor + split * 2].permute(1, 0, 2) + *, + total_rows: int, + output_layout: str, +) -> torch.Tensor: + if output_layout == "head_major": + return packed_tensor.new_empty( + (packed_tensor.shape[1], total_rows, *packed_tensor.shape[2:]) ) - in_cursor += split * 2 - out_cursor += split - return left, right + return packed_tensor.new_empty((total_rows, *packed_tensor.shape[1:])) + + +def _copy_from_peer_rows( + output: torch.Tensor, + start: int, + rows: torch.Tensor, + *, + output_layout: str, +) -> None: + if output_layout == "head_major": + output[:, start : start + int(rows.shape[0])].copy_(rows.movedim(0, 1)) + else: + output[start : start + int(rows.shape[0])].copy_(rows) diff --git 
a/src/art/megatron/context_parallel/executor.py b/src/art/megatron/context_parallel/executor.py index e5e219e72..5beaec9f4 100644 --- a/src/art/megatron/context_parallel/executor.py +++ b/src/art/megatron/context_parallel/executor.py @@ -19,7 +19,7 @@ sparse_compiled_flex_attention, ) -from .block_mask import build_block_mask +from .block_mask import build_block_mask_from_context, prepare_block_mask_context from .comm import A2AVCommunicator from .range_ops import ( range_gather_head_major, @@ -684,17 +684,24 @@ def _build_stage_block_mask( raise RuntimeError( f"Stage {stage_plan.stage_index} is missing exact mask metadata" ) - mask = build_block_mask( + block_mask_context = state.execution_cache.block_mask_context + if block_mask_context is None: + block_mask_context = prepare_block_mask_context( + group_ids=state.group_ids, + parent_ids=state.parent_ids, + ) + state.execution_cache.block_mask_context = block_mask_context + mask = build_block_mask_from_context( FlexMaskSpec( q_len=int(execution_spec.q_len), k_len=int(execution_spec.k_len), block_size=resolved_block_size, slices=stage_plan.slices, - exact_mask=mask_metadata.model_dump(mode="python"), + exact_mask=mask_metadata, ), - group_ids=state.group_ids, - parent_ids=state.parent_ids, + context=block_mask_context, device=device, + validate=False, ) cache[cache_key] = mask return mask @@ -774,30 +781,6 @@ def prepare_context_parallel_execution_state( ) -def _causal_slice_pair_count(slice_: AttnSlice) -> int: - q_start = int(slice_.q_range.start) - q_end = int(slice_.q_range.end) - k_start = int(slice_.k_range.start) - k_end = int(slice_.k_range.end) - if q_end <= q_start or k_end <= k_start: - return 0 - - k_len = k_end - k_start - partial_q_start = max(q_start, k_start) - partial_q_end = min(q_end - 1, k_end - 2) - partial = 0 - if partial_q_start <= partial_q_end: - count = partial_q_end - partial_q_start + 1 - partial = count * (partial_q_start + partial_q_end + 2 - 2 * k_start) // 2 - - full_q_start = 
max(q_start, k_end - 1) - full_q_end = q_end - 1 - full = 0 - if full_q_start <= full_q_end: - full = (full_q_end - full_q_start + 1) * k_len - return int(partial + full) - - def _validate_stage_block_alignment( *, q_len: int, diff --git a/src/art/megatron/context_parallel/layout_index.py b/src/art/megatron/context_parallel/layout_index.py index 99fb2c35b..9f60550a0 100644 --- a/src/art/megatron/context_parallel/layout_index.py +++ b/src/art/megatron/context_parallel/layout_index.py @@ -1,10 +1,9 @@ from __future__ import annotations -from pydantic import BaseModel, ConfigDict +from dataclasses import dataclass -class TokenLayoutIndex(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class TokenLayoutIndex: ownership_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...] token_counts_by_rank: tuple[int, ...] diff --git a/src/art/megatron/context_parallel/runtime.py b/src/art/megatron/context_parallel/runtime.py index c6eb9fddd..b89c42fd2 100644 --- a/src/art/megatron/context_parallel/runtime.py +++ b/src/art/megatron/context_parallel/runtime.py @@ -1,12 +1,11 @@ from __future__ import annotations from bisect import bisect_left, bisect_right +from dataclasses import dataclass, replace import hashlib import json from typing import Any, cast -import warnings -from pydantic import BaseModel, ConfigDict import torch from art.loss import shift_tensor @@ -19,8 +18,6 @@ AttnMaskKind, AttnSlice, ContextParallelConfig, - ContextParallelRuntimeKey, - ContextParallelRuntimePlan, DispatchedPackedTensors, DkvReducePlan, ExactMaskMetadata, @@ -28,17 +25,12 @@ PackedBatchAttentionSpec, PackedRowAttentionSpec, ParallelTopology, - PlannerProvenance, PreparedMegatronBatch, RankRuntimePlan, StagePlan, TokenRange, ) -_PLANNER_RUNTIME_BACKEND = "art_context_parallel" -_PLANNER_BEST_EFFORT_WARNING_KEYS: set[ - tuple[str, str, int, str, str, tuple[int, ...]] -] = set() _CHUNK_MASK_STATS_TORCH_THRESHOLD = 1024 _CP4_SEARCH_PROBE_CANDIDATE_LIMIT = 2 
_CP4_SEARCH_PROBE_IMPROVEMENT_MS = 1.0 @@ -48,17 +40,15 @@ StageSliceKey = tuple[int, int, int, int, int, str, int] -class _PlanningBundle(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - +@dataclass(frozen=True) +class _PlanningBundle: spec: PackedBatchAttentionSpec - runtime_key: ContextParallelRuntimeKey - runtime_plan: ContextParallelRuntimePlan + rank_plans: tuple[RankRuntimePlan, ...] gdn_execution_spec: Any | None = None _PLANNING_BUNDLE_CACHE: dict[str, _PlanningBundle] = {} -_RUNTIME_PLAN_CACHE: dict[tuple[str, int], ContextParallelRuntimePlan] = {} +_RUNTIME_PLAN_CACHE: dict[str, tuple[RankRuntimePlan, ...]] = {} _GDN_RANK_PLAN_CACHE: dict[tuple[str, str, int | None, int], Any] = {} @@ -111,174 +101,14 @@ def _planning_bundle_cache_key( { "group_ids": _metadata_tensor_digest(group_ids), "parent_ids": _metadata_tensor_digest(parent_ids), - "topology": topology.model_dump(mode="json"), - "config": config.model_dump(mode="json"), + "topology": _dataclass_payload(topology), + "config": _dataclass_payload(config), "original_seq_len": int(original_seq_len), "build_gdn_execution_spec": bool(build_gdn_execution_spec), } ) -def _rank_plan_cache_key( - *, - planning_key: str, - device: torch.device, - cp_rank: int, -) -> tuple[str, str, int | None, int]: - return (planning_key, device.type, device.index, int(cp_rank)) - - -def _config_for_runtime_cp( - *, - topology: ParallelTopology, - config: ContextParallelConfig, -) -> ContextParallelConfig: - cp_size = max(int(topology.cp), 1) - updates: dict[str, Any] = {} - applied_override = False - for override in config.planner_cp_overrides: - if int(override.cp_size) != cp_size: - continue - override_updates = override.model_dump(mode="python", exclude_none=True) - override_updates.pop("cp_size", None) - updates.update(override_updates) - applied_override = True - if not applied_override: - return config - updates.setdefault("planner_tuned_cp_sizes", (cp_size,)) - return 
config.model_copy(update=updates) - - -def _normalized_planner_metadata_value(value: str | None) -> str: - if value is None: - return "" - normalized = "".join( - character.lower() if character.isalnum() else " " - for character in str(value).strip() - ) - return " ".join(part for part in normalized.split() if part) - - -def _planner_metadata_matches( - expected: str | None, - actual: str | None, - *, - fuzzy: bool, -) -> bool: - normalized_expected = _normalized_planner_metadata_value(expected) - normalized_actual = _normalized_planner_metadata_value(actual) - if not normalized_expected or not normalized_actual: - return False - if normalized_expected == normalized_actual: - return True - return bool( - fuzzy - and ( - normalized_expected in normalized_actual - or normalized_actual in normalized_expected - ) - ) - - -def _planner_runtime_hardware() -> str | None: - if not torch.cuda.is_available(): - return None - try: - return str(torch.cuda.get_device_name(torch.cuda.current_device())) - except Exception: - return str(torch.cuda.get_device_name(0)) - - -def _planner_best_effort_warning_message(provenance: PlannerProvenance) -> str: - mismatch_reasons: list[str] = [] - if not provenance.backend_match: - mismatch_reasons.append( - f"backend runtime={provenance.runtime_backend!r} tuned={provenance.tuned_backend!r}" - ) - if not provenance.hardware_match: - mismatch_reasons.append( - f"hardware runtime={provenance.runtime_hardware!r} tuned={provenance.tuned_hardware!r}" - ) - if not provenance.cp_size_match: - mismatch_reasons.append( - f"cp_size runtime={int(provenance.runtime_cp_size)} tuned={list(provenance.tuned_cp_sizes)}" - ) - mismatch_text = ( - "; ".join(mismatch_reasons) if mismatch_reasons else "metadata missing" - ) - return ( - "ART context parallel planner coefficients are running in best-effort mode; " - f"{mismatch_text}. The runtime will continue with the configured coefficients." 
- ) - - -def _planner_provenance( - *, - topology: ParallelTopology, - config: ContextParallelConfig, - warn: bool = True, -) -> PlannerProvenance: - runtime_hardware = _planner_runtime_hardware() - tuned_cp_sizes = tuple( - sorted( - { - int(cp_size) - for cp_size in config.planner_tuned_cp_sizes - if int(cp_size) > 0 - } - ) - ) - provenance = PlannerProvenance( - runtime_backend=_PLANNER_RUNTIME_BACKEND, - runtime_hardware=runtime_hardware, - runtime_cp_size=max(int(topology.cp), 1), - tuned_backend=config.planner_tuned_backend, - tuned_hardware=config.planner_tuned_hardware, - tuned_cp_sizes=tuned_cp_sizes, - backend_match=_planner_metadata_matches( - config.planner_tuned_backend, - _PLANNER_RUNTIME_BACKEND, - fuzzy=False, - ), - hardware_match=_planner_metadata_matches( - config.planner_tuned_hardware, - runtime_hardware, - fuzzy=True, - ), - cp_size_match=bool(tuned_cp_sizes) - and max(int(topology.cp), 1) in tuned_cp_sizes, - using_best_effort=False, - ) - if ( - provenance.backend_match - and provenance.hardware_match - and provenance.cp_size_match - ): - return provenance - - warning_message = _planner_best_effort_warning_message(provenance) - warning_key = ( - _normalized_planner_metadata_value(provenance.runtime_backend), - _normalized_planner_metadata_value(provenance.runtime_hardware), - int(provenance.runtime_cp_size), - _normalized_planner_metadata_value(provenance.tuned_backend), - _normalized_planner_metadata_value(provenance.tuned_hardware), - provenance.tuned_cp_sizes, - ) - warning_emitted = False - if warn and warning_key not in _PLANNER_BEST_EFFORT_WARNING_KEYS: - _PLANNER_BEST_EFFORT_WARNING_KEYS.add(warning_key) - warnings.warn(warning_message, RuntimeWarning, stacklevel=3) - warning_emitted = True - return provenance.model_copy( - update={ - "using_best_effort": True, - "warning_message": warning_message, - "warning_emitted": warning_emitted, - } - ) - - def _normalized_chunk_size( *, valid_tokens: int, @@ -351,7 +181,7 @@ def 
_search_config_for_chunk_count( return config if all(int(getattr(config, key)) == int(value) for key, value in updates.items()): return config - return config.model_copy(update=updates) + return replace(config, **updates) def _best_improving_move( @@ -387,9 +217,9 @@ def _best_improving_move( candidate = list(current_owners) candidate[chunk_index] = dst_rank candidate_owners = tuple(candidate) - if not _assignment_uses_all_ranks( - candidate_owners, - cp_size=cp_size, + if ( + len(candidate_owners) >= cp_size + and len(set(candidate_owners)) != cp_size ): continue candidate_eval = evaluate_candidate( @@ -410,12 +240,10 @@ def _build_chunk_ranges( valid_tokens: int, chunk_size: int, ) -> tuple[TokenRange, ...]: - ranges: list[TokenRange] = [] - for start in range(0, valid_tokens, chunk_size): - ranges.append( - TokenRange(start=start, end=min(start + chunk_size, valid_tokens)) - ) - return tuple(ranges) + return tuple( + TokenRange(start=start, end=min(start + chunk_size, valid_tokens)) + for start in range(0, valid_tokens, chunk_size) + ) def _indexed_intersections( @@ -445,33 +273,6 @@ def _indexed_intersections( return intersections -def _slice_pair_count( - *, - mask_kind: AttnMaskKind, - q_range: TokenRange, - k_range: TokenRange, -) -> int: - if mask_kind is AttnMaskKind.FULL: - return int(q_range.size()) * int(k_range.size()) - return _causal_piece_pair_count( - q_range=q_range, - k_range=k_range, - ) - - -def _causal_piece_pair_count( - *, - q_range: TokenRange, - k_range: TokenRange, -) -> int: - return _causal_piece_pair_count_from_bounds( - q_start=int(q_range.start), - q_end=int(q_range.end), - k_start=int(k_range.start), - k_end=int(k_range.end), - ) - - def _causal_piece_pair_count_from_bounds( *, q_start: int, @@ -498,91 +299,15 @@ def _causal_piece_pair_count_from_bounds( return int(partial + full) -def _chunk_piece_decomposition( - *, - start: int, - end: int, - chunk_size: int, -) -> tuple[ - int, tuple[int, ...], tuple[int, ...], tuple[int, ...], 
tuple[int, ...], int -]: - first = start // chunk_size - last = (end - 1) // chunk_size - piece_starts: list[int] = [] - piece_ends: list[int] = [] - piece_lengths: list[int] = [] - piece_prefix_lengths: list[int] = [] - running_len = 0 - for chunk_index in range(first, last + 1): - piece_start = start if chunk_index == first else chunk_index * chunk_size - piece_end = end if chunk_index == last else (chunk_index + 1) * chunk_size - piece_len = piece_end - piece_start - if piece_len <= 0: - continue - running_len += piece_len - piece_starts.append(piece_start) - piece_ends.append(piece_end) - piece_lengths.append(piece_len) - piece_prefix_lengths.append(running_len) - return ( - first, - tuple(piece_starts), - tuple(piece_ends), - tuple(piece_lengths), - tuple(piece_prefix_lengths), - running_len, - ) - - -def _can_use_shared_prefix_chunk_pair_program( - row_spec: PackedRowAttentionSpec, -) -> bool: - slices = row_spec.slices - index = 0 - while index < len(slices): - prompt_slice = slices[index] - if ( - prompt_slice.family_index is None - or prompt_slice.mask_kind is not AttnMaskKind.CAUSAL - or prompt_slice.q_range != prompt_slice.k_range - ): - return False - prompt_family_index = prompt_slice.family_index - if prompt_family_index is None: - raise RuntimeError("shared-prefix prompt slices must carry family_index") - family_index = int(prompt_family_index) - prompt_start = int(prompt_slice.q_range.start) - prompt_end = int(prompt_slice.q_range.end) - index += 1 - while index < len(slices): - family_value = slices[index].family_index - if family_value is None or int(family_value) != family_index: - break - if index + 1 >= len(slices): - return False - full_slice = slices[index] - causal_slice = slices[index + 1] - if ( - full_slice.family_index != prompt_slice.family_index - or causal_slice.family_index != prompt_slice.family_index - or full_slice.mask_kind is not AttnMaskKind.FULL - or causal_slice.mask_kind is not AttnMaskKind.CAUSAL - or full_slice.q_range != 
causal_slice.q_range - or causal_slice.q_range != causal_slice.k_range - or int(full_slice.k_range.start) != prompt_start - or int(full_slice.k_range.end) != prompt_end - ): - return False - index += 2 - return True - - -def _build_chunk_pair_program_generic( +def _build_chunk_pair_program( row_spec: PackedRowAttentionSpec, *, - chunk_count: int, - chunk_size: int, + chunk_ranges: tuple[TokenRange, ...], ) -> tuple[torch.Tensor, list[float]]: + chunk_count = len(chunk_ranges) + if chunk_count == 0: + return torch.zeros((0, 0), dtype=torch.int64), [] + chunk_size = int(chunk_ranges[0].size()) pair_rows = [[0 for _ in range(chunk_count)] for _ in range(chunk_count)] q_weights = [0.0 for _ in range(chunk_count)] @@ -681,138 +406,6 @@ def _build_chunk_pair_program_generic( return torch.tensor(pair_rows, dtype=torch.int64), q_weights -def _build_chunk_pair_program( - row_spec: PackedRowAttentionSpec, - *, - chunk_ranges: tuple[TokenRange, ...], -) -> tuple[torch.Tensor, list[float]]: - chunk_count = len(chunk_ranges) - if chunk_count == 0: - return torch.zeros((0, 0), dtype=torch.int64), [] - chunk_size = int(chunk_ranges[0].size()) - if not _can_use_shared_prefix_chunk_pair_program(row_spec): - return _build_chunk_pair_program_generic( - row_spec, - chunk_count=chunk_count, - chunk_size=chunk_size, - ) - - pair_rows = [[0 for _ in range(chunk_count)] for _ in range(chunk_count)] - q_weights = [0.0 for _ in range(chunk_count)] - slices = row_spec.slices - index = 0 - while index < len(slices): - prompt_slice = slices[index] - ( - prompt_first, - prompt_starts, - prompt_ends, - prompt_lengths, - prompt_prefix, - prompt_total, - ) = _chunk_piece_decomposition( - start=int(prompt_slice.q_range.start), - end=int(prompt_slice.q_range.end), - chunk_size=chunk_size, - ) - for offset, q_chunk_index in enumerate( - range(prompt_first, prompt_first + len(prompt_lengths)) - ): - q_piece_len = prompt_lengths[offset] - row = pair_rows[q_chunk_index] - q_total = 0 - if offset > 0: - 
for k_offset in range(offset): - row[prompt_first + k_offset] += ( - q_piece_len * prompt_lengths[k_offset] - ) - q_total += q_piece_len * prompt_prefix[offset - 1] - pair_count = _causal_piece_pair_count_from_bounds( - q_start=prompt_starts[offset], - q_end=prompt_ends[offset], - k_start=prompt_starts[offset], - k_end=prompt_ends[offset], - ) - if pair_count > 0: - row[q_chunk_index] += pair_count - q_total += pair_count - if q_total > 0: - q_weights[q_chunk_index] += float(q_total) - - prompt_family_index = prompt_slice.family_index - if prompt_family_index is None: - raise RuntimeError("shared-prefix prompt slices must carry family_index") - family_index = int(prompt_family_index) - index += 1 - completion_chunk_indices: list[int] = [] - completion_chunk_totals: list[int] = [] - while index < len(slices): - family_value = slices[index].family_index - if family_value is None or int(family_value) != family_index: - break - full_slice = slices[index] - ( - completion_first, - completion_starts, - completion_ends, - completion_lengths, - completion_prefix, - _, - ) = _chunk_piece_decomposition( - start=int(full_slice.q_range.start), - end=int(full_slice.q_range.end), - chunk_size=chunk_size, - ) - for offset, q_chunk_index in enumerate( - range(completion_first, completion_first + len(completion_lengths)) - ): - q_piece_len = completion_lengths[offset] - if ( - completion_chunk_indices - and completion_chunk_indices[-1] == q_chunk_index - ): - completion_chunk_totals[-1] += q_piece_len - else: - completion_chunk_indices.append(q_chunk_index) - completion_chunk_totals.append(q_piece_len) - - for offset, q_chunk_index in enumerate( - range(completion_first, completion_first + len(completion_lengths)) - ): - q_piece_len = completion_lengths[offset] - row = pair_rows[q_chunk_index] - q_total = 0 - if offset > 0: - for k_offset in range(offset): - row[completion_first + k_offset] += ( - q_piece_len * completion_lengths[k_offset] - ) - q_total += q_piece_len * 
completion_prefix[offset - 1] - pair_count = _causal_piece_pair_count_from_bounds( - q_start=completion_starts[offset], - q_end=completion_ends[offset], - k_start=completion_starts[offset], - k_end=completion_ends[offset], - ) - if pair_count > 0: - row[q_chunk_index] += pair_count - q_total += pair_count - if q_total > 0: - q_weights[q_chunk_index] += float(q_total) - index += 2 - - for q_chunk_index, total_q_len in zip( - completion_chunk_indices, - completion_chunk_totals, - strict=True, - ): - row = pair_rows[q_chunk_index] - for k_offset, k_piece_len in enumerate(prompt_lengths): - row[prompt_first + k_offset] += total_q_len * k_piece_len - q_weights[q_chunk_index] += float(total_q_len * prompt_total) - return torch.tensor(pair_rows, dtype=torch.int64), q_weights - - def _collect_rank_stage_pieces( row_spec: PackedRowAttentionSpec, *, @@ -962,63 +555,6 @@ def _contiguous_chunk_assignment( return tuple(owners) -def _bucket_chunk_assignment( - *, - q_weights: list[float], - cp_size: int, -) -> tuple[int, ...]: - chunk_count = len(q_weights) - if chunk_count == 0: - return tuple() - if cp_size <= 1: - return tuple(0 for _ in range(chunk_count)) - rank_loads = [0.0 for _ in range(cp_size)] - rank_chunk_counts = [0 for _ in range(cp_size)] - owners = [-1 for _ in range(chunk_count)] - for chunk_index in sorted( - range(chunk_count), - key=lambda index: (-q_weights[index], index), - ): - rank = min( - range(cp_size), - key=lambda candidate: ( - rank_loads[candidate], - rank_chunk_counts[candidate], - candidate, - ), - ) - owners[chunk_index] = rank - rank_loads[rank] += q_weights[chunk_index] - rank_chunk_counts[rank] += 1 - return tuple(int(owner) for owner in owners) - - -def _striped_chunk_assignment( - *, - chunk_count: int, - cp_size: int, - group_size: int, -) -> tuple[int, ...]: - if chunk_count == 0: - return tuple() - if cp_size <= 1: - return tuple(0 for _ in range(chunk_count)) - group_size = max(1, int(group_size)) - return tuple( - ((chunk_index // 
group_size) % cp_size) for chunk_index in range(chunk_count) - ) - - -def _assignment_uses_all_ranks( - owners: tuple[int, ...], - *, - cp_size: int, -) -> bool: - if len(owners) < cp_size: - return True - return len({int(owner) for owner in owners}) == cp_size - - def _candidate_chunk_indices( *, owners: tuple[int, ...], @@ -1127,31 +663,6 @@ def _chunk_mask_stats( return token_count, range_count -def _merge_chunk_ranges_from_mask( - *, - chunk_ranges: tuple[TokenRange, ...], - chunk_mask: torch.Tensor, -) -> tuple[TokenRange, ...]: - chunk_indices = torch.nonzero(chunk_mask, as_tuple=False).flatten() - if int(chunk_indices.numel()) == 0: - return tuple() - ordered_chunk_indices = chunk_indices.tolist() - first_range = chunk_ranges[int(ordered_chunk_indices[0])] - current_start = int(first_range.start) - current_end = int(first_range.end) - merged: list[TokenRange] = [] - for chunk_index in ordered_chunk_indices[1:]: - range_ = chunk_ranges[int(chunk_index)] - if int(range_.start) <= current_end: - current_end = max(current_end, int(range_.end)) - continue - merged.append(TokenRange(start=current_start, end=current_end)) - current_start = int(range_.start) - current_end = int(range_.end) - merged.append(TokenRange(start=current_start, end=current_end)) - return tuple(merged) - - def _stage_cost_ms( *, pair_count: int, @@ -1521,31 +1032,6 @@ def _evaluate_plan( } -def _evaluate_plan_for_search( - *, - chunk_ranges: tuple[TokenRange, ...], - pair_matrix: list[list[int]] | torch.Tensor, - owners: tuple[int, ...], - wave_assignment: tuple[int, ...], - cp_size: int, - config: ContextParallelConfig, - pair_positive: torch.Tensor | None = None, - chunk_lengths: tuple[int, ...] 
| None = None, - chunk_lengths_tensor: torch.Tensor | None = None, -) -> dict[str, Any]: - return _evaluate_plan( - chunk_ranges=chunk_ranges, - pair_matrix=pair_matrix, - owners=owners, - wave_assignment=wave_assignment, - cp_size=cp_size, - config=config, - pair_positive=pair_positive, - chunk_lengths=chunk_lengths, - chunk_lengths_tensor=chunk_lengths_tensor, - ) - - def _search_chunk_assignment( *, chunk_ranges: tuple[TokenRange, ...], @@ -1554,7 +1040,6 @@ def _search_chunk_assignment( cp_size: int, config: ContextParallelConfig, ) -> tuple[tuple[int, ...], tuple[int, ...], dict[str, Any]]: - cp_size = int(cp_size) config = _search_config_for_chunk_count( config=config, chunk_count=len(chunk_ranges), @@ -1563,9 +1048,7 @@ def _search_chunk_assignment( 1, min(int(config.planner_max_remote_waves), len(chunk_ranges)) + 1, ) - best_owners: tuple[int, ...] = tuple() - best_waves: tuple[int, ...] = tuple() - best_eval: dict[str, Any] | None = None + best: tuple[tuple[int, ...], tuple[int, ...], dict[str, Any]] | None = None eval_cache: dict[tuple[tuple[int, ...], tuple[int, ...]], dict[str, Any]] = {} pair_counts = torch.as_tensor(pair_matrix, dtype=torch.int64) pair_positive = pair_counts > 0 @@ -1585,7 +1068,7 @@ def _evaluate_candidate( cached = eval_cache.get(cache_key) if cached is not None: return cached - cached = _evaluate_plan_for_search( + cached = _evaluate_plan( chunk_ranges=chunk_ranges, pair_matrix=pair_counts, owners=owners, @@ -1599,86 +1082,35 @@ def _evaluate_candidate( eval_cache[cache_key] = cached return cached - def _best_wave_assignment_for_owners( - owners: tuple[int, ...], - ) -> tuple[tuple[int, ...], dict[str, Any]]: - best_wave_assignment = tuple() - best_eval_local: dict[str, Any] | None = None - for wave_count in wave_count_candidates: - wave_assignment = _wave_assignment( - chunk_count=len(chunk_ranges), - wave_count=wave_count, - ) - candidate_eval = _evaluate_candidate( - owners=owners, - wave_assignment=wave_assignment, - ) - if 
best_eval_local is None or float(candidate_eval["score"]) + 1e-9 < float( - best_eval_local["score"] - ): - best_wave_assignment = wave_assignment - best_eval_local = candidate_eval - if best_eval_local is None: - raise RuntimeError("Failed to evaluate any wave assignment candidate.") - return best_wave_assignment, best_eval_local - - strategy = str(config.planner_assignment_strategy).strip().lower() - striped_owners = _striped_chunk_assignment( - chunk_count=len(chunk_ranges), - cp_size=cp_size, - group_size=int(config.planner_stripe_group_size), - ) - fixed_owners_by_strategy = { - "contiguous": _contiguous_chunk_assignment( - q_weights=q_weights, cp_size=cp_size - ), - "bucket": _bucket_chunk_assignment(q_weights=q_weights, cp_size=cp_size), - "striped": striped_owners, - } - if strategy in fixed_owners_by_strategy: - owners = fixed_owners_by_strategy[strategy] - best_waves, best_eval = _best_wave_assignment_for_owners(owners) - return owners, best_waves, best_eval - if strategy not in {"search", "search_with_striped_seed"}: - raise ValueError( - "Unsupported planner_assignment_strategy=" - f"{config.planner_assignment_strategy!r}." 
- ) - contiguous_owners = _contiguous_chunk_assignment( q_weights=q_weights, cp_size=cp_size, ) + if not contiguous_owners: + wave_assignment = _wave_assignment(chunk_count=len(chunk_ranges), wave_count=1) + return ( + contiguous_owners, + wave_assignment, + _evaluate_candidate( + owners=contiguous_owners, + wave_assignment=wave_assignment, + ), + ) + for wave_count in wave_count_candidates: wave_assignment = _wave_assignment( chunk_count=len(chunk_ranges), wave_count=wave_count, ) - initial_candidates = [ - initial_owners - for initial_owners in (contiguous_owners,) - if initial_owners - if _assignment_uses_all_ranks(initial_owners, cp_size=cp_size) - ] - if not initial_candidates: - continue - current_owners = min( - initial_candidates, - key=lambda owners: float( - _evaluate_candidate(owners=owners, wave_assignment=wave_assignment)[ - "score" - ] - ), - ) + current_owners = contiguous_owners current_eval = _evaluate_candidate( owners=current_owners, wave_assignment=wave_assignment, ) - if cp_size >= 8: - search_steps_remaining = 0 - else: - search_steps_remaining = int(config.planner_max_search_steps) + search_steps_remaining = ( + 0 if cp_size >= 8 else int(config.planner_max_search_steps) + ) if cp_size == 4 and search_steps_remaining > 0: probe_move = _best_improving_move( current_owners=current_owners, @@ -1716,27 +1148,13 @@ def _best_wave_assignment_for_owners( break current_owners, current_eval = best_move - if best_eval is None or float(current_eval["score"]) + 1e-9 < float( - best_eval["score"] + if best is None or float(current_eval["score"]) + 1e-9 < float( + best[2]["score"] ): - best_owners = current_owners - best_waves = wave_assignment - best_eval = current_eval - - if best_eval is None: - best_owners = _contiguous_chunk_assignment(q_weights=q_weights, cp_size=cp_size) - best_waves = _wave_assignment(chunk_count=len(chunk_ranges), wave_count=1) - best_eval = _evaluate_candidate( - owners=best_owners, - wave_assignment=best_waves, - ) - return 
best_owners, best_waves, best_eval - - -def _concatenate_peer_ranges( - ranges_by_peer: list[tuple[TokenRange, ...]] | tuple[tuple[TokenRange, ...], ...], -) -> tuple[tuple[TokenRange, ...], ...]: - return tuple(tuple(ranges) for ranges in ranges_by_peer) + best = (current_owners, wave_assignment, current_eval) + if best is None: + raise RuntimeError("Failed to evaluate any CP planner wave assignment.") + return best def _flatten_ranges_by_peer( @@ -1956,16 +1374,8 @@ def _build_rank_runtime_plan( _remap_subrange(range_, host_local_ranges) for range_ in local_global_k_ranges ), - kv_fetch_plan=KvFetchPlan( - send_splits=tuple(0 for _ in range(cp_size)), - recv_splits=tuple(0 for _ in range(cp_size)), - send_ranges_by_peer=tuple(tuple() for _ in range(cp_size)), - ), - dkv_reduce_plan=DkvReducePlan( - send_splits=tuple(0 for _ in range(cp_size)), - recv_splits=tuple(0 for _ in range(cp_size)), - recv_ranges_by_peer=tuple(tuple() for _ in range(cp_size)), - ), + kv_fetch_plan=None, + dkv_reduce_plan=None, remote_buffer_range=None, block_size=block_size, ) @@ -2063,14 +1473,8 @@ def _build_rank_runtime_plan( token_layout_index=token_layout_index, local_valid_lengths=(local_token_count,), local_row_ranges=local_row_ranges, - local_token_count=local_token_count, stage_plans=tuple(stage_plans), backward_stage_indices=tuple(backward_stage_indices + [0]), - remote_kv_fetch_plan=KvFetchPlan( - send_splits=aggregate_send_splits, - recv_splits=tuple(aggregate_recv_splits), - send_ranges_by_peer=aggregate_send_ranges, - ), remote_dkv_reduce_plan=DkvReducePlan( send_splits=tuple(aggregate_recv_splits), recv_splits=aggregate_send_splits, @@ -2079,58 +1483,6 @@ def _build_rank_runtime_plan( ) -def make_runtime_key( - spec: PackedBatchAttentionSpec, - *, - topology: ParallelTopology, - config: ContextParallelConfig, -) -> ContextParallelRuntimeKey: - if len(spec.rows) != 1: - raise RuntimeError( - "ART context parallel runtime keys expect exactly one packed sequence, " - f"got 
{len(spec.rows)} rows." - ) - row_signatures = tuple(_row_signature(row) for row in spec.rows) - return ContextParallelRuntimeKey( - topology=topology, - config=config, - row_signatures=row_signatures, - ) - - -def build_context_parallel_token_layout_index( - *, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, - topology: ParallelTopology, - config: ContextParallelConfig, - original_seq_len: int, -) -> TokenLayoutIndex: - """Return the token ownership chosen by the real CP attention planner.""" - - spec = build_shared_prefix_attention_spec( - group_ids=group_ids, parent_ids=parent_ids - ) - if int(topology.cp) <= 1: - valid_tokens = int(spec.rows[0].valid_tokens) if spec.rows else 0 - return TokenLayoutIndex( - ownership_ranges_by_rank=(((0, valid_tokens, 0),) if valid_tokens else (),), - token_counts_by_rank=(valid_tokens,), - ) - runtime_config = _config_for_runtime_cp(topology=topology, config=config) - _row_spec, chunk_ranges, owners, _wave_assignment = _runtime_plan_assignment( - spec, - topology=topology, - config=runtime_config, - ) - del original_seq_len - return _build_runtime_token_layout_index( - chunk_ranges=chunk_ranges, - owners=owners, - cp_size=max(int(topology.cp), 1), - ) - - def prepare_cp_micro( *, micro: PackedTensors, @@ -2172,7 +1524,7 @@ def prepare_cp_micro( ref_logprobs=ref_logprobs, ) if tensors.token_uids is not None: - state = state.model_copy(update={"trace_token_uids": tensors.token_uids}) + state = replace(state, trace_token_uids=tensors.token_uids) if prepare_execution_state: from .executor import prepare_context_parallel_execution_state @@ -2222,12 +1574,11 @@ def prepare_megatron_context_parallel_state( ) group_ids_cpu = _planning_metadata_cpu(micro["group_ids"]) parent_ids_cpu = _planning_metadata_cpu(micro["parent_ids"]) - runtime_config = _config_for_runtime_cp(topology=topology, config=config) planning_key = _planning_bundle_cache_key( group_ids=group_ids_cpu, parent_ids=parent_ids_cpu, topology=topology, - 
config=runtime_config, + config=config, original_seq_len=int(micro["tokens"].shape[1]), build_gdn_execution_spec=build_gdn_execution_spec, ) @@ -2237,12 +1588,10 @@ def prepare_megatron_context_parallel_state( group_ids=group_ids_cpu, parent_ids=parent_ids_cpu, ) - runtime_key = make_runtime_key(spec, topology=topology, config=runtime_config) runtime_plan = get_or_build_runtime_plan( spec, topology=topology, - config=runtime_config, - runtime_key=runtime_key, + config=config, original_seq_len=int(micro["tokens"].shape[1]), ) gdn_execution_spec = None @@ -2252,18 +1601,15 @@ def prepare_megatron_context_parallel_state( ) gdn_execution_spec = parse_gdn_shared_prefix_segments( - group_ids_cpu, - parent_ids_cpu, - min_completions_per_family=0, + group_ids_cpu, parent_ids_cpu ) bundle = _PlanningBundle( spec=spec, - runtime_key=runtime_key, - runtime_plan=runtime_plan, + rank_plans=runtime_plan, gdn_execution_spec=gdn_execution_spec, ) _cache_put(_PLANNING_BUNDLE_CACHE, planning_key, bundle) - rank_plan = bundle.runtime_plan.rank_plans[int(cp_rank)] + rank_plan = bundle.rank_plans[int(cp_rank)] gdn_execution_plan = None if build_gdn_execution_spec: if bundle.gdn_execution_spec is None: @@ -2271,10 +1617,11 @@ def prepare_megatron_context_parallel_state( gdn_plan_device = ( target_device if target_device is not None else micro["tokens"].device ) - rank_gdn_key = _rank_plan_cache_key( - planning_key=planning_key, - device=gdn_plan_device, - cp_rank=int(cp_rank), + rank_gdn_key = ( + planning_key, + gdn_plan_device.type, + gdn_plan_device.index, + int(cp_rank), ) gdn_execution_plan = _GDN_RANK_PLAN_CACHE.get(rank_gdn_key) if gdn_execution_plan is None: @@ -2290,22 +1637,15 @@ def prepare_megatron_context_parallel_state( attention_token_layout_index=rank_plan.token_layout_index, ) _cache_put(_GDN_RANK_PLAN_CACHE, rank_gdn_key, gdn_execution_plan) - planner_provenance = _planner_provenance( - topology=topology, - config=runtime_config, - warn=int(cp_rank) == 0, - ) 
pad_multiple = int(topology.tp) if bool(topology.sp) and int(topology.tp) > 1 else 1 state = ArtContextParallelState( - runtime_key=bundle.runtime_key, rank_plan=rank_plan, cp_group=cp_group, - config=runtime_config, + config=config, group_ids=group_ids_cpu[0].contiguous(), parent_ids=parent_ids_cpu[0].contiguous(), gdn_execution_spec=bundle.gdn_execution_spec, gdn_execution_plan=gdn_execution_plan, - planner_provenance=planner_provenance, trace_token_uids=None, ) return state, rank_plan, bundle.spec, pad_multiple @@ -2353,111 +1693,43 @@ def dispatch_megatron_context_parallel_training_tensors( if trace_token_uids else None ) - local_tokens = _dispatch_tensor( - micro["tokens"], - rank_plan=rank_plan, - pad_value=0, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) - local_labels = _dispatch_tensor( - labels, - rank_plan=rank_plan, - pad_value=-100, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) - local_input_pos = _dispatch_tensor( - micro["input_pos"], - rank_plan=rank_plan, - pad_value=0, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) - local_assistant_mask = _dispatch_tensor( - assistant_mask, - rank_plan=rank_plan, - pad_value=False, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ).to(dtype=torch.bool) - local_group_ids = _dispatch_tensor( - shifted_group_ids, - rank_plan=rank_plan, - pad_value=0, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) - local_old_logprobs = _dispatch_tensor( - old_logprobs, - rank_plan=rank_plan, - pad_value=float("nan"), - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) - local_original_logprobs = ( - None - if original_logprobs is None - else _dispatch_tensor( - original_logprobs, - rank_plan=rank_plan, - pad_value=0.0, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) - ) - local_ref_logprobs = ( - None - if ref_logprobs is None - else 
_dispatch_tensor( - ref_logprobs, + + def dispatch( + tensor: torch.Tensor, + pad_value: int | float | bool, + *, + move_to_target: bool = True, + ) -> torch.Tensor: + local = _dispatch_tensor( + tensor, rank_plan=rank_plan, - pad_value=float("nan"), + pad_value=pad_value, pad_multiple=pad_multiple, dispatch_meta_cache=dispatch_meta_cache, ) - ) - local_advantages = _dispatch_tensor( - advantages, - rank_plan=rank_plan, - pad_value=0.0, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) - local_weights = _dispatch_tensor( - weights, - rank_plan=rank_plan, - pad_value=0.0, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) + return _to_target_device(local, target_device) if move_to_target else local + + def maybe_dispatch( + tensor: torch.Tensor | None, + pad_value: int | float | bool, + ) -> torch.Tensor | None: + return None if tensor is None else dispatch(tensor, pad_value) + local_token_uids = ( - None - if token_uids is None - else _dispatch_tensor( - token_uids, - rank_plan=rank_plan, - pad_value=-1, - pad_multiple=pad_multiple, - dispatch_meta_cache=dispatch_meta_cache, - ) + None if token_uids is None else dispatch(token_uids, -1, move_to_target=False) ) return DispatchedPackedTensors( - tokens=_to_target_device(local_tokens, target_device), - labels=_to_target_device(local_labels, target_device), - input_pos=_to_target_device(local_input_pos, target_device), - assistant_mask=_to_target_device(local_assistant_mask, target_device), - group_ids=_to_target_device(local_group_ids, target_device), - old_logprobs=_to_target_device(local_old_logprobs, target_device), - advantages=_to_target_device(local_advantages, target_device), - weights=_to_target_device(local_weights, target_device), + tokens=dispatch(micro["tokens"], 0), + labels=dispatch(labels, -100), + input_pos=dispatch(micro["input_pos"], 0), + assistant_mask=dispatch(assistant_mask, False).to(dtype=torch.bool), + group_ids=dispatch(shifted_group_ids, 0), 
+ old_logprobs=dispatch(old_logprobs, float("nan")), + advantages=dispatch(advantages, 0.0), + weights=dispatch(weights, 0.0), valid_lengths=rank_plan.local_valid_lengths, - original_logprobs=None - if local_original_logprobs is None - else _to_target_device(local_original_logprobs, target_device), - ref_logprobs=None - if local_ref_logprobs is None - else _to_target_device(local_ref_logprobs, target_device), + original_logprobs=maybe_dispatch(original_logprobs, 0.0), + ref_logprobs=maybe_dispatch(ref_logprobs, float("nan")), loss_all_reduce_group=cp_group, token_uids=None if local_token_uids is None else local_token_uids.contiguous(), ) @@ -2468,12 +1740,13 @@ def get_or_build_runtime_plan( *, topology: ParallelTopology, config: ContextParallelConfig, - runtime_key: ContextParallelRuntimeKey, original_seq_len: int, -) -> ContextParallelRuntimePlan: - key = ( - _json_cache_key(runtime_key.model_dump(mode="json")), - int(original_seq_len), +) -> tuple[RankRuntimePlan, ...]: + key = _runtime_plan_cache_key( + spec, + topology=topology, + config=config, + original_seq_len=original_seq_len, ) cached = _RUNTIME_PLAN_CACHE.get(key) if cached is not None: @@ -2488,25 +1761,6 @@ def get_or_build_runtime_plan( return plan -def get_or_build_rank_runtime_plan( - spec: PackedBatchAttentionSpec, - *, - topology: ParallelTopology, - config: ContextParallelConfig, - runtime_key: ContextParallelRuntimeKey, - original_seq_len: int, - target_rank: int, -) -> RankRuntimePlan: - del runtime_key - return _build_rank_runtime_plan_for_spec( - spec, - topology=topology, - config=config, - original_seq_len=original_seq_len, - target_rank=target_rank, - ) - - def _runtime_plan_assignment( spec: PackedBatchAttentionSpec, *, @@ -2552,45 +1806,13 @@ def _runtime_plan_assignment( return row_spec, chunk_ranges, owners, wave_assignment -def _build_rank_runtime_plan_for_spec( - spec: PackedBatchAttentionSpec, - *, - topology: ParallelTopology, - config: ContextParallelConfig, - original_seq_len: 
int, - target_rank: int, -) -> RankRuntimePlan: - row_spec, chunk_ranges, owners, wave_assignment = _runtime_plan_assignment( - spec, - topology=topology, - config=config, - ) - cp_size = max(int(topology.cp), 1) - token_layout_index = _build_runtime_token_layout_index( - chunk_ranges=chunk_ranges, - owners=owners, - cp_size=cp_size, - ) - return _build_rank_runtime_plan( - row_spec=row_spec, - chunk_ranges=chunk_ranges, - owners=owners, - wave_assignment=wave_assignment, - token_layout_index=token_layout_index, - cp_size=cp_size, - original_seq_len=original_seq_len, - target_rank=int(target_rank), - block_size=int(config.block_size), - ) - - def _build_runtime_plan( spec: PackedBatchAttentionSpec, *, topology: ParallelTopology, config: ContextParallelConfig, original_seq_len: int, -) -> ContextParallelRuntimePlan: +) -> tuple[RankRuntimePlan, ...]: row_spec, chunk_ranges, owners, wave_assignment = _runtime_plan_assignment( spec, topology=topology, @@ -2602,7 +1824,7 @@ def _build_runtime_plan( owners=owners, cp_size=cp_size, ) - rank_plans = [ + return tuple( _build_rank_runtime_plan( row_spec=row_spec, chunk_ranges=chunk_ranges, @@ -2615,12 +1837,6 @@ def _build_runtime_plan( block_size=int(config.block_size), ) for rank in range(cp_size) - ] - return ContextParallelRuntimePlan( - topology=topology, - config=config, - token_layout_index=token_layout_index, - rank_plans=tuple(rank_plans), ) @@ -2648,11 +1864,42 @@ def _build_runtime_token_layout_index( def _row_signature(row_spec: PackedRowAttentionSpec) -> str: payload = { "valid_tokens": row_spec.valid_tokens, - "slices": [slice_.model_dump(mode="json") for slice_ in row_spec.slices], + "slices": [_attn_slice_payload(slice_) for slice_ in row_spec.slices], } return json.dumps(payload, sort_keys=True) +def _runtime_plan_cache_key( + spec: PackedBatchAttentionSpec, + *, + topology: ParallelTopology, + config: ContextParallelConfig, + original_seq_len: int, +) -> str: + return _json_cache_key( + { + "topology": 
_dataclass_payload(topology), + "config": _dataclass_payload(config), + "row_signatures": tuple(_row_signature(row) for row in spec.rows), + "original_seq_len": int(original_seq_len), + } + ) + + +def _dataclass_payload(value: Any) -> dict[str, Any]: + return dict(value.__dict__) + + +def _attn_slice_payload(slice_: AttnSlice) -> dict[str, Any]: + return { + "q_range": _dataclass_payload(slice_.q_range), + "k_range": _dataclass_payload(slice_.k_range), + "mask_kind": slice_.mask_kind.value, + "row_index": slice_.row_index, + "family_index": slice_.family_index, + } + + def _range_key(range_: TokenRange) -> tuple[int, int]: return (int(range_.start), int(range_.end)) @@ -2694,101 +1941,6 @@ def _set_stage_token_indices( current_indices.copy_(source_indices) -def _token_costs(row_spec: PackedRowAttentionSpec) -> list[float]: - costs = [0.0] * row_spec.valid_tokens - for slice_ in row_spec.slices: - q_range = slice_.q_range - k_range = slice_.k_range - if slice_.mask_kind is AttnMaskKind.FULL: - cost = float(k_range.size()) - for q_idx in range(q_range.start, q_range.end): - costs[q_idx] += cost - continue - if q_range.size() != k_range.size(): - raise RuntimeError( - "The current planner only supports causal slices with matched q/k sizes, got " - f"{q_range} vs {k_range}" - ) - for q_idx in range(q_range.start, q_range.end): - costs[q_idx] += float(q_idx - q_range.start + 1) - return costs - - -def _split_row_by_cost( - row_spec: PackedRowAttentionSpec, - *, - cp_size: int, - block_size: int, -) -> tuple[TokenRange | None, ...]: - if cp_size == 1: - return (TokenRange(start=0, end=row_spec.valid_tokens),) - if row_spec.valid_tokens == 0: - return tuple(None for _ in range(cp_size)) - - costs = _token_costs(row_spec) - prefix = [0.0] - for cost in costs: - prefix.append(prefix[-1] + cost) - total_cost = prefix[-1] - boundaries = [0] - block_aligned_split = int(block_size) > 1 and row_spec.valid_tokens >= ( - cp_size * int(block_size) - ) - for split_index in range(1, 
cp_size): - remaining_ranks = cp_size - split_index - min_boundary = boundaries[-1] - max_boundary = row_spec.valid_tokens - remaining_ranks - if max_boundary <= min_boundary: - boundaries.append(min_boundary) - continue - target = ( - total_cost * split_index / cp_size - if total_cost > 0.0 - else row_spec.valid_tokens * split_index / cp_size - ) - best_boundary = min_boundary + 1 - best_error = float("inf") - candidate_boundaries = range(min_boundary + 1, max_boundary + 1) - if block_aligned_split: - aligned_start = ( - (min_boundary + 1 + block_size - 1) // block_size - ) * block_size - aligned_end = (max_boundary // block_size) * block_size - if aligned_start <= aligned_end: - candidate_boundaries = range(aligned_start, aligned_end + 1, block_size) - for boundary in candidate_boundaries: - current = prefix[boundary] if total_cost > 0.0 else float(boundary) - error = abs(current - target) - if error < best_error: - best_error = error - best_boundary = boundary - boundaries.append(best_boundary) - boundaries.append(row_spec.valid_tokens) - - ranges: list[TokenRange | None] = [] - for start, end in zip(boundaries[:-1], boundaries[1:]): - if end <= start: - ranges.append(None) - else: - ranges.append(TokenRange(start=start, end=end)) - return tuple(ranges) - - -def _intersections( - base_range: TokenRange, - owner_ranges: tuple[TokenRange | None, ...], -) -> list[tuple[int, TokenRange]]: - intersections: list[tuple[int, TokenRange]] = [] - for rank, owner_range in enumerate(owner_ranges): - if owner_range is None: - continue - start = max(base_range.start, owner_range.start) - end = min(base_range.end, owner_range.end) - if end > start: - intersections.append((rank, TokenRange(start=start, end=end))) - return intersections - - def _resolve_stage_mask_kind( *, mask_kind: AttnMaskKind, diff --git a/src/art/megatron/context_parallel/types.py b/src/art/megatron/context_parallel/types.py index 5cc874d09..bf52ffddc 100644 --- a/src/art/megatron/context_parallel/types.py 
+++ b/src/art/megatron/context_parallel/types.py @@ -1,10 +1,11 @@ from __future__ import annotations +from dataclasses import dataclass, field from enum import Enum from typing import Any from megatron.core.packed_seq_params import PackedSeqParams -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict import torch from .layout_index import TokenLayoutIndex @@ -16,22 +17,17 @@ class AttnMaskKind(str, Enum): CAUSAL = "causal" -class TokenRange(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class TokenRange: start: int end: int def size(self) -> int: return self.end - self.start - def is_empty(self) -> bool: - return self.end <= self.start - - -class AttnSlice(BaseModel): - model_config = ConfigDict(frozen=True) +@dataclass(frozen=True) +class AttnSlice: q_range: TokenRange k_range: TokenRange mask_kind: AttnMaskKind @@ -39,68 +35,25 @@ class AttnSlice(BaseModel): family_index: int | None = None -class PackedRowAttentionSpec(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class PackedRowAttentionSpec: row_index: int valid_tokens: int slices: tuple[AttnSlice, ...] -class PackedBatchAttentionSpec(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class PackedBatchAttentionSpec: rows: tuple[PackedRowAttentionSpec, ...] 
-class SharedPrefixBuilderConfig(BaseModel): - model_config = ConfigDict(frozen=True) - - ignore_padding_group_id: int = -1 - require_contiguous_group_runs: bool = True - - -class PlannerCpOverride(BaseModel): - model_config = ConfigDict(frozen=True) - - cp_size: int - block_size: int | None = None - planner_chunk_size: int | None = None - planner_chunk_budget_base: int | None = None - planner_chunk_budget_per_cp_rank: int | None = None - planner_assignment_strategy: str | None = None - planner_stripe_group_size: int | None = None - planner_max_search_steps: int | None = None - planner_candidate_chunk_limit: int | None = None - planner_max_remote_waves: int | None = None - planner_stage_overhead_ms: float | None = None - planner_comm_stage_overhead_ms: float | None = None - planner_interval_overhead_ms: float | None = None - planner_merge_q_token_ms: float | None = None - planner_fetch_token_ms: float | None = None - planner_reduce_token_ms: float | None = None - planner_local_pair_ms: float | None = None - planner_remote_pair_ms: float | None = None - planner_local_backward_pair_ms: float | None = None - planner_remote_backward_pair_ms: float | None = None - planner_remote_stage_token_floor: int | None = None - planner_remote_stage_pair_floor: int | None = None - planner_remote_stage_underfill_ms: float | None = None - planner_tuned_backend: str | None = None - planner_tuned_hardware: str | None = None - planner_tuned_cp_sizes: tuple[int, ...] 
| None = None - - -class ContextParallelConfig(BaseModel): - model_config = ConfigDict(frozen=True, extra="forbid") - +@dataclass(frozen=True) +class ContextParallelConfig: block_size: int = 128 attention_sparse_block_size: tuple[int, int] | None = None planner_chunk_size: int = 512 planner_chunk_budget_base: int = 128 planner_chunk_budget_per_cp_rank: int = 16 - planner_assignment_strategy: str = "search" - planner_stripe_group_size: int = 16 planner_max_search_steps: int = 8 planner_candidate_chunk_limit: int = 8 planner_max_remote_waves: int = 4 @@ -117,15 +70,10 @@ class ContextParallelConfig(BaseModel): planner_remote_stage_token_floor: int = 4096 planner_remote_stage_pair_floor: int = 4_000_000 planner_remote_stage_underfill_ms: float = 0.287151 - planner_tuned_backend: str | None = "art_context_parallel" - planner_tuned_hardware: str | None = "NVIDIA H200" - planner_tuned_cp_sizes: tuple[int, ...] = (2,) - planner_cp_overrides: tuple[PlannerCpOverride, ...] = () -class ParallelTopology(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class ParallelTopology: tp: int = 1 cp: int = 1 dp: int = 1 @@ -133,73 +81,50 @@ class ParallelTopology(BaseModel): sp: bool = False -class ContextParallelRuntimeKey(BaseModel): - model_config = ConfigDict(frozen=True) - - topology: ParallelTopology - config: ContextParallelConfig - row_signatures: tuple[str, ...] - - -class KvFetchPlan(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class KvFetchPlan: send_splits: tuple[int, ...] recv_splits: tuple[int, ...] send_ranges_by_peer: tuple[tuple[TokenRange, ...], ...] -class DkvReducePlan(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class DkvReducePlan: send_splits: tuple[int, ...] recv_splits: tuple[int, ...] recv_ranges_by_peer: tuple[tuple[TokenRange, ...], ...] 
-class StagePlan(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class StagePlan: stage_index: int source_rank: int - source_ranks: tuple[int, ...] = () is_local_stage: bool - wave_index: int | None = None slices: tuple[AttnSlice, ...] - global_q_ranges: tuple[TokenRange, ...] = () - global_k_ranges: tuple[TokenRange, ...] = () owner_local_q_ranges: tuple[TokenRange, ...] owner_local_k_ranges: tuple[TokenRange, ...] - mask_metadata: "ExactMaskMetadata | None" = None - remote_buffer_range: TokenRange | None = None q_len: int k_len: int + source_ranks: tuple[int, ...] = () + wave_index: int | None = None + global_q_ranges: tuple[TokenRange, ...] = () + global_k_ranges: tuple[TokenRange, ...] = () + mask_metadata: "ExactMaskMetadata | None" = None + remote_buffer_range: TokenRange | None = None kv_fetch_plan: KvFetchPlan | None = None dkv_reduce_plan: DkvReducePlan | None = None -class RankRuntimePlan(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class RankRuntimePlan: rank: int original_seq_len: int token_layout_index: TokenLayoutIndex local_valid_lengths: tuple[int, ...] local_row_ranges: tuple[TokenRange | None, ...] - local_token_count: int stage_plans: tuple[StagePlan, ...] - backward_stage_indices: tuple[int, ...] = () - remote_kv_fetch_plan: KvFetchPlan remote_dkv_reduce_plan: DkvReducePlan - - -class ContextParallelRuntimePlan(BaseModel): - model_config = ConfigDict(frozen=True) - - topology: ParallelTopology - config: ContextParallelConfig - token_layout_index: TokenLayoutIndex - rank_plans: tuple[RankRuntimePlan, ...] + backward_stage_indices: tuple[int, ...] 
= () class DispatchedPackedTensors(ContextParallelLossInputs): @@ -220,47 +145,27 @@ class DispatchedPackedTensors(ContextParallelLossInputs): token_uids: torch.Tensor | None = None -class ContextParallelExecutionCache(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - block_masks: dict[Any, Any] = Field(default_factory=dict) - range_indices: dict[Any, torch.Tensor] = Field(default_factory=dict) - range_meta: dict[Any, tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]] = Field( +@dataclass +class ContextParallelExecutionCache: + block_mask_context: Any | None = None + block_masks: dict[Any, Any] = field(default_factory=dict) + range_indices: dict[Any, torch.Tensor] = field(default_factory=dict) + range_meta: dict[Any, tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]] = field( default_factory=dict ) - stage_execution_specs: dict[Any, "StageExecutionSpec"] = Field(default_factory=dict) + stage_execution_specs: dict[Any, "StageExecutionSpec"] = field(default_factory=dict) -class StageExecutionSpec(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class StageExecutionSpec: q_len: int k_len: int compile_key: str mask_metadata: "ExactMaskMetadata | None" = None -class PlannerProvenance(BaseModel): - model_config = ConfigDict(frozen=True) - - runtime_backend: str - runtime_hardware: str | None = None - runtime_cp_size: int - tuned_backend: str | None = None - tuned_hardware: str | None = None - tuned_cp_sizes: tuple[int, ...] 
= () - backend_match: bool - hardware_match: bool - cp_size_match: bool - using_best_effort: bool - warning_message: str | None = None - warning_emitted: bool = False - - -class ArtContextParallelState(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - runtime_key: ContextParallelRuntimeKey +@dataclass +class ArtContextParallelState: rank_plan: RankRuntimePlan cp_group: Any config: ContextParallelConfig @@ -272,31 +177,28 @@ class ArtContextParallelState(BaseModel): gdn_input_layout: str | None = None gdn_output_layout: str | None = None gdn_attention_original_shape: tuple[int, int, int] | None = None - gdn_attention_original_shapes: dict[int, tuple[int, int, int]] = Field( + gdn_attention_original_shapes: dict[int, tuple[int, int, int]] = field( default_factory=dict ) gdn_attention_token_uids: torch.Tensor | None = None gdn_active_module: Any | None = None - planner_provenance: PlannerProvenance trace_token_uids: torch.Tensor | None = None - execution_cache: ContextParallelExecutionCache = Field( + execution_cache: ContextParallelExecutionCache = field( default_factory=ContextParallelExecutionCache ) -class PreparedMegatronBatch(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - +@dataclass +class PreparedMegatronBatch: tensors: DispatchedPackedTensors - packed_seq_params: PackedSeqParams | None = None attention_state: Any + packed_seq_params: PackedSeqParams | None = None rank_plan: RankRuntimePlan | None = None pad_multiple: int = 1 -class FlexMaskSpec(BaseModel): - model_config = ConfigDict(frozen=True) - +@dataclass(frozen=True) +class FlexMaskSpec: q_len: int k_len: int block_size: int | tuple[int, int] @@ -304,9 +206,8 @@ class FlexMaskSpec(BaseModel): exact_mask: "ExactMaskMetadata" -class ExactMaskMetadata(BaseModel): - model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) - +@dataclass(frozen=True) +class ExactMaskMetadata: q_token_indices: torch.Tensor k_token_indices: torch.Tensor cache_key: str 
diff --git a/src/art/megatron/gdn/__init__.py b/src/art/megatron/gdn/__init__.py index cd3a0873a..1dc629403 100644 --- a/src/art/megatron/gdn/__init__.py +++ b/src/art/megatron/gdn/__init__.py @@ -3,12 +3,10 @@ from .fla_cp import chunk_gated_delta_rule_native_cp from .gdn_shared_prefix import ( GdnPackedExecutionSpec, - GdnPackedFamilySpec, GdnPlannerConfig, GdnRankExecutionPlan, GdnSegmentBucketPlan, GdnSegmentSpec, - build_gdn_cp_segment_schedule, build_gdn_rank_execution_plan, move_gdn_rank_execution_plan_to_device, parse_gdn_shared_prefix_segments, @@ -19,12 +17,10 @@ __all__ = [ "chunk_gated_delta_rule_native_cp", "GdnPackedExecutionSpec", - "GdnPackedFamilySpec", "GdnPlannerConfig", "GdnRankExecutionPlan", "GdnSegmentSpec", "GdnSegmentBucketPlan", - "build_gdn_cp_segment_schedule", "build_gdn_rank_execution_plan", "exchange_rank_tensor_all_to_all", "move_gdn_rank_execution_plan_to_device", diff --git a/src/art/megatron/gdn/gdn_shared_prefix.py b/src/art/megatron/gdn/gdn_shared_prefix.py index 3fb693891..f4bc02ba7 100644 --- a/src/art/megatron/gdn/gdn_shared_prefix.py +++ b/src/art/megatron/gdn/gdn_shared_prefix.py @@ -1,152 +1,76 @@ from __future__ import annotations from bisect import bisect_left -from typing import Any, Literal, TypeVar +from dataclasses import dataclass, replace +from typing import Any, Literal -from pydantic import BaseModel, ConfigDict, Field import torch from art.megatron.context_parallel.layout_index import TokenLayoutIndex +from art.megatron.shared_prefix_tree import parse_shared_prefix_tree GdnSegmentKind = Literal["prefix", "completion"] -GdnSegmentDecisionKey = tuple[int, int, int] # FLA's public chunk_gated_delta_rule hard-codes 64-token WY chunks. 
FLA_CHUNK_SIZE = 64 -_PydanticModelT = TypeVar("_PydanticModelT", bound=BaseModel) -class GdnSegmentSpec(BaseModel): +@dataclass(frozen=True) +class GdnSegmentSpec: """Contiguous logical GDN segment in one packed row.""" - model_config = ConfigDict(frozen=True) - - row_index: int = Field(ge=0) - family_index: int = Field(ge=0) + row_index: int + family_index: int group_id: int parent_id: int - start: int = Field(ge=0) - end: int = Field(ge=1) + start: int + end: int kind: GdnSegmentKind - child_index: int | None = Field(default=None, ge=0) + child_index: int | None = None @property def length(self) -> int: return self.end - self.start - def linear_indices(self, sequence_length: int) -> tuple[int, ...]: - base = self.row_index * sequence_length - return tuple(range(base + self.start, base + self.end)) - - -class GdnPackedFamilySpec(BaseModel): - """One shared-prefix family plus child completion segments.""" - - model_config = ConfigDict(frozen=True) - - row_index: int = Field(ge=0) - family_index: int = Field(ge=0) - prefix: GdnSegmentSpec - completions: tuple[GdnSegmentSpec, ...] - - @property - def completion_count(self) -> int: - return len(self.completions) - - @property - def token_count(self) -> int: - return self.prefix.length + sum(segment.length for segment in self.completions) - -class GdnPackedExecutionSpec(BaseModel): +@dataclass(frozen=True) +class GdnPackedExecutionSpec: """Parsed shared-prefix GDN execution metadata for a packed batch.""" - model_config = ConfigDict(frozen=True) - - batch_size: int = Field(ge=1) - sequence_length: int = Field(ge=1) + batch_size: int + sequence_length: int valid_lengths: tuple[int, ...] - families: tuple[GdnPackedFamilySpec, ...] + tree_segments: tuple[GdnSegmentSpec, ...] + tree_parent_indices: tuple[int, ...] + tree_depths: tuple[int, ...] 
@property def family_count(self) -> int: - return len(self.families) - - @property - def completion_count(self) -> int: - return sum(family.completion_count for family in self.families) + return len(self.tree_segments) @property def real_token_count(self) -> int: return sum(self.valid_lengths) - @property - def max_segment_length(self) -> int: - lengths = [ - segment.length - for family in self.families - for segment in (family.prefix, *family.completions) - ] - return max(lengths, default=0) - - def segments(self) -> tuple[GdnSegmentSpec, ...]: - return tuple( - segment - for family in self.families - for segment in (family.prefix, *family.completions) - ) - - -_GDN_SEGMENT_SPEC_FIELDS = frozenset( - { - "row_index", - "family_index", - "group_id", - "parent_id", - "start", - "end", - "kind", - "child_index", - } -) -_GDN_PACKED_FAMILY_SPEC_FIELDS = frozenset( - { - "row_index", - "family_index", - "prefix", - "completions", - } -) - - -def _trusted_pydantic_construct( - model_type: type[_PydanticModelT], - fields_set: frozenset[str], - **values: Any, -) -> _PydanticModelT: - model = model_type.__new__(model_type) - object.__setattr__(model, "__dict__", values) - object.__setattr__(model, "__pydantic_fields_set__", fields_set) - object.__setattr__(model, "__pydantic_extra__", None) - object.__setattr__(model, "__pydantic_private__", None) - return model - -class GdnSegmentBucketPlan(BaseModel): +@dataclass(frozen=True) +class GdnSegmentBucketPlan: """Device-local index tensors for a variable-length GDN segment batch.""" - model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) - - length: int = Field(ge=1) + length: int lengths: torch.Tensor lengths_cpu: torch.Tensor - lengths_by_rank_cpu: torch.Tensor | None = None real_mask: torch.Tensor cu_seqlens: torch.Tensor cu_seqlens_cpu: torch.Tensor row_indices: torch.Tensor position_indices: torch.Tensor family_indices: torch.Tensor - real_token_count_static: int = Field(ge=0) + real_token_count_static: int 
+ lengths_by_rank_cpu: torch.Tensor | None = None + family_indices_cpu: torch.Tensor | None = None + parent_indices: torch.Tensor | None = None + parent_indices_cpu: torch.Tensor | None = None + needs_final_state: bool = True output_mask: torch.Tensor | None = None @property @@ -158,89 +82,54 @@ def real_token_count(self) -> int: return self.real_token_count_static -class GdnParentStateTransferPlan(BaseModel): - """Prefix-state rows transferred from one CP rank to another.""" - - model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) +@dataclass(frozen=True) +class GdnStateExchangePlan: + """Sparse CP exchange for tree parent states needed by remote children.""" - source_rank: int = Field(ge=0) - dest_rank: int = Field(ge=0) - family_indices: tuple[int, ...] - family_indices_tensor: torch.Tensor | None = None + source_family_indices: tuple[int, ...] + dest_family_indices: tuple[int, ...] + exchange: Any + reverse_exchange: Any -class GdnPlannerConfig(BaseModel): +@dataclass(frozen=True) +class GdnPlannerConfig: """Tunable cost coefficients for one packed-row GDN execution plan.""" - model_config = ConfigDict(frozen=True) - - max_padding_ratio: float = Field(default=2.0, gt=1.0) - max_segments_per_batch: int = Field(default=4096, ge=1) - cp_chain_min_tokens_per_rank: int = Field(default=32, ge=1) - cp_chain_min_total_tokens: int = Field(default=32768, ge=1) - cp_chain_min_prefix_only_tokens: int = Field(default=32768, ge=1) - local_fork_launch_penalty_tokens: int = Field(default=256, ge=0) - cp_collective_latency_tokens: int = Field(default=512, ge=0) - parent_state_exchange_penalty_tokens: int = Field(default=16384, ge=0) - layout_cross_rank_token_cost: float = Field(default=6.0, ge=0.0) - rank_idle_token_cost: float = Field(default=1.0, ge=0.0) - empty_rank_penalty_tokens: int = Field(default=65536, ge=0) - max_zero_exchange_load_imbalance: float = Field(default=1.5, ge=1.0) - local_completion_rebalance_min_imbalance: float = Field(default=1.08, 
ge=1.0) - cp_chain_beam_width: int = Field(default=2, ge=1) - cp_chain_beam_branch_factor: int = Field(default=4, ge=1) - cp_chain_beam_candidate_limit: int = Field(default=16, ge=1) - cp_chain_beam_max_steps: int = Field(default=4, ge=0) - cp_chain_beam_min_score_delta_tokens: float = Field(default=512.0, ge=0.0) - cp_chain_min_score_delta_ms: float = Field(default=0.25, ge=0.0) - planner_local_token_ms: float = Field(default=0.00065, ge=0.0) - planner_chain_token_ms: float = Field(default=0.00055, ge=0.0) - planner_local_bucket_ms: float = Field(default=0.25, ge=0.0) - planner_chain_bucket_ms: float = Field(default=22.0, ge=0.0) - planner_local_segment_ms: float = Field(default=0.010, ge=0.0) - planner_layout_cross_rank_token_ms: float = Field(default=0.00008, ge=0.0) - planner_parent_state_exchange_base_ms: float = Field(default=40.0, ge=0.0) - planner_parent_state_exchange_ms: float = Field(default=0.5, ge=0.0) - planner_empty_rank_ms: float = Field(default=32.0, ge=0.0) - - -class GdnRankExecutionPlan(BaseModel): + max_padding_ratio: float = 2.0 + max_segments_per_batch: int = 4096 + cp_chain_min_tokens_per_rank: int = 32 + cp_chain_min_total_tokens: int = 32768 + cp_chain_min_prefix_only_tokens: int = 32768 + cp_tree_chain_min_total_tokens: int = 8192 + cp_tree_chain_min_prefix_only_tokens: int = 8192 + rank_idle_token_cost: float = 1.0 + max_zero_exchange_load_imbalance: float = 1.5 + planner_local_token_ms: float = 0.00065 + planner_layout_cross_rank_token_ms: float = 0.00008 + planner_empty_rank_ms: float = 32.0 + + +@dataclass(frozen=True) +class GdnRankExecutionPlan: """Rank-local planned execution metadata for shared-prefix GDN.""" - model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) - - cp_rank: int = Field(ge=0) - cp_size: int = Field(ge=1) - batch_size: int = Field(ge=1) - sequence_length: int = Field(ge=0) - packed_batch_size: int | None = Field(default=None, ge=1) - packed_sequence_length: int | None = Field(default=None, ge=1) + 
cp_rank: int + cp_size: int + batch_size: int + sequence_length: int real_token_mask: torch.Tensor - family_count: int = Field(ge=0) - completion_count: int = Field(ge=0) - local_prefix_buckets: tuple[GdnSegmentBucketPlan, ...] = () - local_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () - ready_local_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () - remote_local_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () - chain_prefix_buckets: tuple[GdnSegmentBucketPlan, ...] = () - chain_completion_buckets: tuple[GdnSegmentBucketPlan, ...] = () - prefix_table_is_dense_ordered: bool + packed_batch_size: int | None = None + packed_sequence_length: int | None = None attention_to_gdn: Any | None = None gdn_to_attention: Any | None = None attention_token_ranges: tuple[tuple[int, int, int], ...] = () gdn_token_ranges: tuple[tuple[int, int, int], ...] = () - attention_token_count: int = Field(default=0, ge=0) - gdn_token_count: int = Field(default=0, ge=0) - parent_state_exchange_family_indices: tuple[int, ...] = () - parent_state_transfers: tuple[GdnParentStateTransferPlan, ...] = () - prefix_boundary_buckets: tuple[GdnSegmentBucketPlan, ...] = () - prefix_tail_buckets: tuple[GdnSegmentBucketPlan, ...] = () - completion_with_prefix_tail_buckets: tuple[GdnSegmentBucketPlan, ...] = () - remote_prefix_tail_buckets: tuple[GdnSegmentBucketPlan, ...] = () - remote_completion_with_prefix_tail_buckets: tuple[GdnSegmentBucketPlan, ...] = () - remote_prefix_tail_exchange: Any | None = None - remote_prefix_tail_backward_exchange: Any | None = None - remote_prefix_tail_state_transfers: tuple[GdnParentStateTransferPlan, ...] = () + attention_token_count: int = 0 + gdn_token_count: int = 0 + tree_segment_buckets_by_depth: tuple[tuple[GdnSegmentBucketPlan, ...], ...] = () + tree_chain_buckets_by_depth: tuple[tuple[GdnSegmentBucketPlan, ...], ...] = () + tree_state_exchanges_by_depth: tuple[GdnStateExchangePlan | None, ...] 
= () @property def attention_token_indices(self) -> tuple[int, ...]: @@ -251,74 +140,12 @@ def gdn_token_indices(self) -> tuple[int, ...]: return _tokens_from_rank_ranges(self.gdn_token_ranges) -class GdnCpSegmentSchedule(BaseModel): - """CPU-side ownership and bucket schedule for one CP GDN plan.""" - - model_config = ConfigDict(frozen=True) - - gdn_token_counts_by_rank: tuple[int, ...] - gdn_token_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...] = () - cross_rank_token_count: int = Field(ge=0) - chain_prefix_buckets: tuple[tuple[GdnSegmentSpec, ...], ...] - chain_completion_buckets: tuple[tuple[GdnSegmentSpec, ...], ...] - local_prefix_segments_by_rank: tuple[tuple[GdnSegmentSpec, ...], ...] - local_completion_segments_by_rank: tuple[tuple[GdnSegmentSpec, ...], ...] - parent_state_exchange_family_indices: tuple[int, ...] = () - parent_state_transfers: tuple[GdnParentStateTransferPlan, ...] = () - - -class _GdnCpSegmentSearchDecision(BaseModel): - model_config = ConfigDict(frozen=True) - - chain_segment_keys: frozenset[GdnSegmentDecisionKey] - co_locate_local_families: bool - score: float - - -class _ExplicitBucketColumn(BaseModel): - model_config = ConfigDict(frozen=True) - - row_index: int - family_index: int - positions: tuple[int, ...] - output_mask: tuple[bool, ...] - - @property - def length(self) -> int: - return len(self.positions) - - -def _explicit_bucket_column( - *, - row_index: int, - family_index: int, - positions: tuple[int, ...], - output_mask: tuple[bool, ...], -) -> _ExplicitBucketColumn: - return _ExplicitBucketColumn.model_construct( - row_index=row_index, - family_index=family_index, - positions=positions, - output_mask=output_mask, - ) - - -class _AttentionLayoutIndex(BaseModel): +@dataclass(frozen=True) +class _AttentionLayoutIndex: """Counting index for CP attention token ownership.""" - model_config = ConfigDict(frozen=True) - token_ranges_by_rank: tuple[tuple[tuple[int, int], ...], ...] 
token_range_ends_by_rank: tuple[tuple[int, ...], ...] - range_count: int = Field(ge=0) - - -def _layout_cp_size(layout: TokenLayoutIndex) -> int: - return len(layout.token_counts_by_rank) - - -def _layout_token_count(layout: TokenLayoutIndex) -> int: - return sum(int(count) for count in layout.token_counts_by_rank) def _tokens_from_rank_ranges( @@ -327,21 +154,6 @@ def _tokens_from_rank_ranges( return tuple(token for start, end, _ in ranges for token in range(start, end)) -def _token_layout_from_rank_ranges( - ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...], -) -> TokenLayoutIndex: - return TokenLayoutIndex( - ownership_ranges_by_rank=ranges_by_rank, - token_counts_by_rank=tuple( - _ranges_token_count(ranges) for ranges in ranges_by_rank - ), - ) - - -def _ranges_token_count(ranges: tuple[tuple[int, int, int], ...]) -> int: - return sum(int(end) - int(start) for start, end, _ in ranges) - - def build_gdn_rank_execution_plan( spec: GdnPackedExecutionSpec, *, @@ -349,7 +161,6 @@ def build_gdn_rank_execution_plan( cp_rank: int = 0, cp_size: int = 1, attention_token_layout_index: TokenLayoutIndex | None = None, - cp_segment_schedule: GdnCpSegmentSchedule | None = None, planner_config: GdnPlannerConfig | None = None, ) -> GdnRankExecutionPlan: """Build rank-local tensor metadata from a parsed shared-prefix DAG. 
@@ -368,185 +179,20 @@ def build_gdn_rank_execution_plan( cp_rank=cp_rank, cp_size=cp_size, attention_token_layout_index=attention_token_layout_index, - cp_segment_schedule=cp_segment_schedule, planner_config=planner_config, ) return move_gdn_rank_execution_plan_to_device(cpu_plan, target_device) - if cp_size != 1 or cp_rank != 0: - return _build_cp_rank_execution_plan( - spec, - device=device, - cp_rank=cp_rank, - cp_size=cp_size, - attention_token_layout_index=attention_token_layout_index, - cp_segment_schedule=cp_segment_schedule, - planner_config=planner_config, - ) - ( - prefix_boundary_buckets, - prefix_tail_buckets, - completion_with_prefix_tail_buckets, - ) = _build_chunk_aligned_cp1_bucket_plans( + return _build_tree_rank_execution_plan( spec, device=device, - planner_config=planner_config, - ) - valid_lengths = torch.tensor( - spec.valid_lengths, - device=device, - dtype=torch.long, - ) - positions = torch.arange(spec.sequence_length, device=device, dtype=torch.long) - local_range_list: list[tuple[int, int, int]] = [] - local_position = 0 - for row_index, length in enumerate(spec.valid_lengths): - if length: - start = row_index * spec.sequence_length - local_range_list.append((start, start + length, local_position)) - local_position += length - local_ranges = tuple(local_range_list) - return GdnRankExecutionPlan.model_construct( cp_rank=cp_rank, cp_size=cp_size, - batch_size=spec.batch_size, - sequence_length=spec.sequence_length, - packed_batch_size=spec.batch_size, - packed_sequence_length=spec.sequence_length, - real_token_mask=positions.unsqueeze(0) < valid_lengths.unsqueeze(1), - family_count=spec.family_count, - completion_count=spec.completion_count, - local_prefix_buckets=(), - local_completion_buckets=(), - ready_local_completion_buckets=(), - remote_local_completion_buckets=(), - chain_prefix_buckets=(), - chain_completion_buckets=(), - prefix_table_is_dense_ordered=False, - attention_token_ranges=local_ranges, - gdn_token_ranges=local_ranges, - 
attention_token_count=spec.real_token_count, - gdn_token_count=spec.real_token_count, - prefix_boundary_buckets=prefix_boundary_buckets, - prefix_tail_buckets=prefix_tail_buckets, - completion_with_prefix_tail_buckets=completion_with_prefix_tail_buckets, - ) - - -def move_gdn_rank_execution_plan_to_device( - plan: GdnRankExecutionPlan, - device: torch.device | str, -) -> GdnRankExecutionPlan: - """Move planner tensors to the execution device after CPU planning.""" - - from art.megatron.gdn.layout import move_cp_exchange_plan_to_device - - return GdnRankExecutionPlan.model_construct( - cp_rank=plan.cp_rank, - cp_size=plan.cp_size, - batch_size=plan.batch_size, - sequence_length=plan.sequence_length, - packed_batch_size=plan.packed_batch_size, - packed_sequence_length=plan.packed_sequence_length, - real_token_mask=_move_planner_tensor(plan.real_token_mask, device), - family_count=plan.family_count, - completion_count=plan.completion_count, - local_prefix_buckets=_move_bucket_plans(plan.local_prefix_buckets, device), - local_completion_buckets=_move_bucket_plans( - plan.local_completion_buckets, device - ), - ready_local_completion_buckets=_move_bucket_plans( - plan.ready_local_completion_buckets, device - ), - remote_local_completion_buckets=_move_bucket_plans( - plan.remote_local_completion_buckets, device - ), - chain_prefix_buckets=_move_bucket_plans(plan.chain_prefix_buckets, device), - chain_completion_buckets=_move_bucket_plans( - plan.chain_completion_buckets, device - ), - prefix_table_is_dense_ordered=plan.prefix_table_is_dense_ordered, - attention_to_gdn=move_cp_exchange_plan_to_device(plan.attention_to_gdn, device), - gdn_to_attention=move_cp_exchange_plan_to_device(plan.gdn_to_attention, device), - attention_token_ranges=plan.attention_token_ranges, - gdn_token_ranges=plan.gdn_token_ranges, - attention_token_count=plan.attention_token_count, - gdn_token_count=plan.gdn_token_count, - 
parent_state_exchange_family_indices=plan.parent_state_exchange_family_indices, - parent_state_transfers=_move_parent_state_transfers( - plan.parent_state_transfers, device - ), - prefix_boundary_buckets=_move_bucket_plans( - plan.prefix_boundary_buckets, device - ), - prefix_tail_buckets=_move_bucket_plans(plan.prefix_tail_buckets, device), - completion_with_prefix_tail_buckets=_move_bucket_plans( - plan.completion_with_prefix_tail_buckets, device - ), - remote_prefix_tail_buckets=_move_bucket_plans( - plan.remote_prefix_tail_buckets, device - ), - remote_completion_with_prefix_tail_buckets=_move_bucket_plans( - plan.remote_completion_with_prefix_tail_buckets, device - ), - remote_prefix_tail_exchange=move_cp_exchange_plan_to_device( - plan.remote_prefix_tail_exchange, device - ), - remote_prefix_tail_backward_exchange=move_cp_exchange_plan_to_device( - plan.remote_prefix_tail_backward_exchange, device - ), - remote_prefix_tail_state_transfers=_move_parent_state_transfers( - plan.remote_prefix_tail_state_transfers, device - ), - ) - - -def _move_bucket_plans( - buckets: tuple[GdnSegmentBucketPlan, ...], - device: torch.device | str, -) -> tuple[GdnSegmentBucketPlan, ...]: - return tuple( - GdnSegmentBucketPlan.model_construct( - length=bucket.length, - lengths=_move_planner_tensor(bucket.lengths, device), - lengths_cpu=bucket.lengths_cpu, - lengths_by_rank_cpu=bucket.lengths_by_rank_cpu, - real_mask=_move_planner_tensor(bucket.real_mask, device), - cu_seqlens=_move_planner_tensor(bucket.cu_seqlens, device), - cu_seqlens_cpu=bucket.cu_seqlens_cpu, - row_indices=_move_planner_tensor(bucket.row_indices, device), - position_indices=_move_planner_tensor(bucket.position_indices, device), - family_indices=_move_planner_tensor(bucket.family_indices, device), - real_token_count_static=bucket.real_token_count, - output_mask=( - _move_planner_tensor(bucket.output_mask, device) - if bucket.output_mask is not None - else None - ), - ) - for bucket in buckets - ) - - -def 
_move_parent_state_transfers( - transfers: tuple[GdnParentStateTransferPlan, ...], - device: torch.device | str, -) -> tuple[GdnParentStateTransferPlan, ...]: - return tuple( - GdnParentStateTransferPlan.model_construct( - source_rank=transfer.source_rank, - dest_rank=transfer.dest_rank, - family_indices=transfer.family_indices, - family_indices_tensor=( - _move_planner_tensor(transfer.family_indices_tensor, device) - if transfer.family_indices_tensor is not None - else None - ), - ) - for transfer in transfers + attention_token_layout_index=attention_token_layout_index, + planner_config=planner_config, ) -def _build_local_attention_layout_rank_execution_plan( +def _build_tree_rank_execution_plan( spec: GdnPackedExecutionSpec, *, device: torch.device | str, @@ -554,14 +200,17 @@ def _build_local_attention_layout_rank_execution_plan( cp_size: int, attention_token_layout_index: TokenLayoutIndex | None, planner_config: GdnPlannerConfig, -) -> GdnRankExecutionPlan | None: - if cp_size <= 1 or not spec.families: - return None - if any( - _has_chainable_segment(family, cp_size=cp_size, planner_config=planner_config) - for family in spec.families - ): - return None +) -> GdnRankExecutionPlan: + if cp_size < 1: + raise ValueError(f"cp_size must be >= 1, got {cp_size}") + if cp_rank < 0 or cp_rank >= cp_size: + raise ValueError(f"cp_rank must be in [0, {cp_size}), got {cp_rank}") + if not spec.tree_segments: + raise ValueError("tree GDN planning requires tree segments") + if len(spec.tree_parent_indices) != len(spec.tree_segments): + raise ValueError("tree parent metadata length must match tree segments") + if len(spec.tree_depths) != len(spec.tree_segments): + raise ValueError("tree depth metadata length must match tree segments") from art.megatron.gdn.layout import ( _reverse_exchange_plan, @@ -575,2746 +224,397 @@ def _build_local_attention_layout_rank_execution_plan( planner_config=planner_config, ) attention_layout_index = 
_build_attention_layout_index_from_token_layout( - source_layout, - max_ranges=max(1, 2 * spec.real_token_count // len(tuple(spec.segments()))), + source_layout ) segment_attention_counts = _segment_attention_rank_counts( spec, cp_size=cp_size, attention_layout_index=attention_layout_index, ) - best = _assign_local_attention_segments( - spec, - cp_size=cp_size, - segment_attention_counts=segment_attention_counts, - co_locate_local_families=False, - planner_config=planner_config, - ) - co_located = _assign_local_attention_segments( - spec, - cp_size=cp_size, - segment_attention_counts=segment_attention_counts, - co_locate_local_families=True, - planner_config=planner_config, - ) - if co_located[4] < best[4]: - best = co_located - ( - prefix_owner_by_family, - completion_owners_by_family, - _, - cross_rank_token_count, - _, - ) = best - - local_prefix_segments: list[GdnSegmentSpec] = [] - local_completion_segments: list[GdnSegmentSpec] = [] - prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [[] for _ in range(cp_size)] - completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - gdn_ranges_by_rank: list[list[tuple[int, int, int]]] = [[] for _ in range(cp_size)] - rank_loads = [0] * cp_size - parent_state_exchange_families: set[int] = set() - parent_state_transfer_families: dict[tuple[int, int], set[int]] = {} - - def append_segment(rank: int, segment: GdnSegmentSpec) -> None: - token_start = _segment_token_start(segment, spec.sequence_length) - position_start = rank_loads[rank] - gdn_ranges_by_rank[rank].append( - (token_start, token_start + segment.length, position_start) - ) - rank_loads[rank] += segment.length - - for family in spec.families: - prefix_owner = prefix_owner_by_family[family.family_index] - if prefix_owner == cp_rank: - local_prefix_segments.append(family.prefix) - prefix_segments_by_rank[prefix_owner].append(family.prefix) - append_segment(prefix_owner, family.prefix) - completion_owners = 
completion_owners_by_family[family.family_index] - for completion, completion_owner in zip( - family.completions, completion_owners, strict=True - ): - if completion_owner == cp_rank: - local_completion_segments.append(completion) - completion_segments_by_rank[completion_owner].append(completion) - append_segment(completion_owner, completion) - if completion_owner != prefix_owner: - parent_state_exchange_families.add(family.family_index) - parent_state_transfer_families.setdefault( - (prefix_owner, completion_owner), set() - ).add(family.family_index) - - local_token_ranges = tuple(gdn_ranges_by_rank[cp_rank]) - local_token_count = rank_loads[cp_rank] - schedule = GdnCpSegmentSchedule.model_construct( - gdn_token_counts_by_rank=tuple(rank_loads), - gdn_token_ranges_by_rank=tuple(tuple(ranges) for ranges in gdn_ranges_by_rank), - cross_rank_token_count=cross_rank_token_count, - chain_prefix_buckets=(), - chain_completion_buckets=(), - local_prefix_segments_by_rank=tuple( - tuple(segments) for segments in prefix_segments_by_rank - ), - local_completion_segments_by_rank=tuple( - tuple(segments) for segments in completion_segments_by_rank - ), - parent_state_exchange_family_indices=tuple( - sorted(parent_state_exchange_families) - ), - parent_state_transfers=_build_parent_state_transfer_plans( - parent_state_transfer_families - ), - ) - if parent_state_transfer_families: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _build_remote_prefix_tail_plans( - spec, - schedule, - cp_rank=cp_rank, - device=device, - planner_config=planner_config, - ) - else: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = 
_empty_remote_prefix_tail_plans() - attention_to_gdn = build_local_rank_cp_exchange_plan_from_dest_ranges( - source_layout=source_layout, - device=device, - dest_ranges_by_rank=tuple(tuple(ranges) for ranges in gdn_ranges_by_rank), - local_rank=cp_rank, - cross_rank_token_count=cross_rank_token_count, - ) - gdn_to_attention = _reverse_exchange_plan(attention_to_gdn) - local_prefix_family_indices = { - segment.family_index for segment in local_prefix_segments - } - local_prefix_buckets = _batch_segments_by_padded_work( - (), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - chunk_local_completion_segments = tuple( - segment - for segment in local_completion_segments - if segment.family_index in local_prefix_family_indices - ) - plain_local_completion_segments = tuple( - segment - for segment in local_completion_segments - if segment.family_index not in local_prefix_family_indices - and segment.family_index not in remote_prefix_tail_families - ) - ready_completion_segments, remote_completion_segments = ( - _split_ready_and_remote_completion_segments( - plain_local_completion_segments, - local_prefix_segments=(), - chain_prefix_buckets=(), - ) - ) - ready_completion_buckets = _batch_segments_by_padded_work( - ready_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - remote_completion_buckets = _batch_segments_by_padded_work( - remote_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - prefix_family_order = tuple( - segment.family_index for bucket in local_prefix_buckets for segment in bucket - ) - ready_completion_bucket_plans = _build_position_bucket_plans( - ready_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ) - remote_completion_bucket_plans = 
_build_position_bucket_plans( - remote_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ) - ( - prefix_boundary_buckets, - prefix_tail_buckets, - completion_with_prefix_tail_buckets, - ) = _build_chunk_aligned_position_bucket_plans( - tuple(local_prefix_segments), - chunk_local_completion_segments, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - planner_config=planner_config, - ) - return GdnRankExecutionPlan.model_construct( - cp_rank=cp_rank, - cp_size=cp_size, - batch_size=1, - sequence_length=local_token_count, - packed_batch_size=spec.batch_size, - packed_sequence_length=spec.sequence_length, - real_token_mask=torch.ones( - 1, local_token_count, device=device, dtype=torch.bool - ), - family_count=spec.family_count, - completion_count=spec.completion_count, - local_prefix_buckets=_build_position_bucket_plans( - local_prefix_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - local_completion_buckets=( - ready_completion_bucket_plans + remote_completion_bucket_plans - ), - ready_local_completion_buckets=ready_completion_bucket_plans, - remote_local_completion_buckets=remote_completion_bucket_plans, - chain_prefix_buckets=(), - chain_completion_buckets=(), - prefix_table_is_dense_ordered=( - not local_prefix_segments - and prefix_family_order == tuple(range(spec.family_count)) - ), - attention_to_gdn=attention_to_gdn, - gdn_to_attention=gdn_to_attention, - attention_token_ranges=source_layout.ownership_ranges_by_rank[cp_rank], - gdn_token_ranges=local_token_ranges, - attention_token_count=source_layout.token_counts_by_rank[cp_rank], - gdn_token_count=local_token_count, - parent_state_exchange_family_indices=tuple( - sorted(parent_state_exchange_families - remote_prefix_tail_families) - ), - parent_state_transfers=_filter_parent_state_transfers( - _build_parent_state_transfer_plans(parent_state_transfer_families), - 
excluded_families=remote_prefix_tail_families, - device=device, - ), - prefix_boundary_buckets=prefix_boundary_buckets, - prefix_tail_buckets=prefix_tail_buckets, - completion_with_prefix_tail_buckets=completion_with_prefix_tail_buckets, - remote_prefix_tail_buckets=remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets=remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange=remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange=remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers=remote_prefix_tail_state_transfers, - ) - -def _assign_local_attention_segments( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - co_locate_local_families: bool, - planner_config: GdnPlannerConfig, -) -> tuple[ - tuple[int, ...], - tuple[tuple[int, ...], ...], - tuple[int, ...], - int, - float, -]: + depth_count = max(spec.tree_depths, default=0) + 1 rank_loads = [0] * cp_size - has_prefix = [False] * cp_size - has_completion = [False] * cp_size - prefix_owner_by_family: list[int] = [] - completion_owners_by_family: list[tuple[int, ...]] = [] - parent_state_exchange_families: set[int] = set() + owner_by_node = [-1] * len(spec.tree_segments) + chained_nodes = [False] * len(spec.tree_segments) + tree_has_children = [False] * len(spec.tree_segments) + for parent_index in spec.tree_parent_indices: + if parent_index >= 0: + tree_has_children[parent_index] = True + gdn_ranges_by_rank: list[list[tuple[int, int, int]]] = [[] for _ in range(cp_size)] + segments_by_rank_depth: list[list[list[GdnSegmentSpec]]] = [ + [[] for _ in range(depth_count)] for _ in range(cp_size) + ] + chain_segments_by_depth: list[list[GdnSegmentSpec]] = [ + [] for _ in range(depth_count) + ] cross_rank_token_count = 0 - def append_owner(rank: int, segment: GdnSegmentSpec) -> None: + children_by_node: list[list[int]] = [[] for _ in spec.tree_segments] + root_indices: 
list[int] = [] + for node_index, parent_index in enumerate(spec.tree_parent_indices): + if parent_index < 0: + root_indices.append(node_index) + else: + children_by_node[parent_index].append(node_index) + + def subtree_indices(root_index: int) -> tuple[int, ...]: + ordered: list[int] = [] + stack = [root_index] + while stack: + node_index = stack.pop() + ordered.append(node_index) + stack.extend(reversed(children_by_node[node_index])) + return tuple(ordered) + + def assign_local_group(node_indices: tuple[int, ...]) -> None: nonlocal cross_rank_token_count - rank_loads[rank] += segment.length - cross_rank_token_count += ( - segment.length - segment_attention_counts[_segment_key(segment)][rank] - ) - - for family in spec.families: - if co_locate_local_families: - owner = _best_segment_owner( - (family.prefix, *family.completions), - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - prefix_owner_by_family.append(owner) - completion_owners = tuple(owner for _ in family.completions) - completion_owners_by_family.append(completion_owners) - has_prefix[owner] = True - for segment in (family.prefix, *family.completions): - append_owner(owner, segment) - if family.completions: - has_completion[owner] = True - continue - - prefix_owner = _best_segment_owner( - (family.prefix,), + segments = tuple(spec.tree_segments[index] for index in node_indices) + owner = _best_segment_owner( + segments, rank_loads, segment_attention_counts=segment_attention_counts, planner_config=planner_config, ) - prefix_owner_by_family.append(prefix_owner) - has_prefix[prefix_owner] = True - append_owner(prefix_owner, family.prefix) - completion_owners = [] - for completion in family.completions: - owner = _best_segment_owner( - (completion,), - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - completion_owners.append(owner) - has_completion[owner] = True - append_owner(owner, completion) - if 
owner != prefix_owner: - parent_state_exchange_families.add(family.family_index) - completion_owners_by_family.append(tuple(completion_owners)) - - del has_prefix, has_completion - score = _score_local_segment_assignment( - spec, - cp_size=cp_size, - prefix_owner_by_family=tuple(prefix_owner_by_family), - completion_owners_by_family=tuple(completion_owners_by_family), - rank_loads=tuple(rank_loads), - cross_rank_token_count=cross_rank_token_count, - parent_state_exchange_family_count=len(parent_state_exchange_families), - planner_config=planner_config, - ) - return ( - tuple(prefix_owner_by_family), - tuple(completion_owners_by_family), - tuple(sorted(parent_state_exchange_families)), - cross_rank_token_count, - score, - ) - - -def _score_local_segment_assignment( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - prefix_owner_by_family: tuple[int, ...], - completion_owners_by_family: tuple[tuple[int, ...], ...], - rank_loads: tuple[int, ...], - cross_rank_token_count: int, - parent_state_exchange_family_count: int, - planner_config: GdnPlannerConfig, -) -> float: - local_prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - local_completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - for family in spec.families: - prefix_owner = prefix_owner_by_family[family.family_index] - local_prefix_segments_by_rank[prefix_owner].append(family.prefix) - completion_owners = completion_owners_by_family[family.family_index] - for completion, completion_owner in zip( - family.completions, completion_owners, strict=True - ): - local_completion_segments_by_rank[completion_owner].append(completion) - ( - local_work_by_rank, - local_bucket_count, - local_segment_count, - ) = _estimate_local_rank_kernel_work( - tuple(tuple(segments) for segments in local_prefix_segments_by_rank), - tuple(tuple(segments) for segments in local_completion_segments_by_rank), - planner_config=planner_config, - ) - return 
_score_cp_segment_stats( - rank_local_work=local_work_by_rank, - rank_chain_work=tuple(0 for _ in range(cp_size)), - rank_real_tokens=rank_loads, - cross_rank_token_count=cross_rank_token_count, - parent_state_exchange_family_count=parent_state_exchange_family_count, - local_bucket_count=local_bucket_count, - local_segment_count=local_segment_count, - chain_bucket_count=0, - planner_config=planner_config, - ) - - -def _can_zero_exchange_colocate_families( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], -) -> bool: - for family in spec.families: - family_rank_counts = [0] * cp_size - for segment in (family.prefix, *family.completions): - segment_counts = segment_attention_counts[_segment_key(segment)] - for rank in range(cp_size): - family_rank_counts[rank] += segment_counts[rank] - if max(family_rank_counts, default=0) != family.token_count: - return False - return True - - -def parse_gdn_shared_prefix_segments( - group_ids: torch.Tensor, - parent_ids: torch.Tensor, - *, - min_completions_per_family: int = 0, -) -> GdnPackedExecutionSpec: - """Parse ART packed shared-prefix metadata into a GDN segment DAG. - - The parser is intentionally strict: GDN state routing depends on prompt-family - boundaries, so malformed metadata should fail before execution can silently - leak recurrent or conv state across siblings or independent families. 
- """ - - groups = _rank2_long_cpu("group_ids", group_ids) - parents = _rank2_long_cpu("parent_ids", parent_ids) - if tuple(groups.shape) != tuple(parents.shape): - raise ValueError( - "group_ids and parent_ids must have the same shape, got " - f"{tuple(groups.shape)} and {tuple(parents.shape)}" - ) - - batch_size, sequence_length = (int(groups.shape[0]), int(groups.shape[1])) - valid_lengths: list[int] = [] - families: list[GdnPackedFamilySpec] = [] - for row_index in range(batch_size): - row_group_ids = groups[row_index] - row_parent_ids = parents[row_index] - valid_length = _validate_padding_tensor( - row_index, row_group_ids, row_parent_ids - ) - valid_lengths.append(valid_length) - if valid_length == 0: - continue - families.extend( - _parse_row_tensor( - row_index=row_index, - group_ids=row_group_ids, - parent_ids=row_parent_ids, - valid_length=valid_length, - first_family_index=len(families), - min_completions_per_family=min_completions_per_family, - ) - ) - - return GdnPackedExecutionSpec( - batch_size=batch_size, - sequence_length=sequence_length, - valid_lengths=tuple(valid_lengths), - families=tuple(families), - ) - - -def _build_segment_bucket_plans( - segment_buckets: tuple[tuple[GdnSegmentSpec, ...], ...], - *, - device: torch.device | str, -) -> tuple[GdnSegmentBucketPlan, ...]: - return tuple( - _build_segment_bucket_plan(bucket[0].length, bucket, device=device) - for bucket in segment_buckets - ) - - -def _build_chunk_aligned_cp1_bucket_plans( - spec: GdnPackedExecutionSpec, - *, - device: torch.device | str, - planner_config: GdnPlannerConfig, -) -> tuple[ - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], -]: - boundary_segments: list[GdnSegmentSpec] = [] - tail_segments: list[GdnSegmentSpec] = [] - completion_columns: list[_ExplicitBucketColumn] = [] - for family in spec.families: - prefix = family.prefix - boundary_end = _prefix_chunk_boundary_end(prefix) - if boundary_end > prefix.start: 
- boundary_segments.append( - _segment_with_bounds(prefix, prefix.start, boundary_end) - ) - prefix_tail_positions = tuple(range(boundary_end, prefix.end)) - if prefix_tail_positions and not family.completions: - tail_segments.append(_segment_with_bounds(prefix, boundary_end, prefix.end)) - for child_offset, completion in enumerate(family.completions): - completion_positions = prefix_tail_positions + tuple( - range(completion.start, completion.end) - ) - completion_columns.append( - _explicit_bucket_column( - row_index=completion.row_index, - family_index=completion.family_index, - positions=completion_positions, - output_mask=( - ((child_offset == 0),) * len(prefix_tail_positions) - + (True,) * completion.length - ), - ) - ) - boundary_buckets = _batch_segments_by_padded_work( - tuple(boundary_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - tail_buckets = _batch_segments_by_padded_work( - tuple(tail_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - completion_column_batches = _batch_explicit_bucket_columns( - tuple(completion_columns), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - return ( - _build_segment_bucket_plans(boundary_buckets, device=device), - _build_segment_bucket_plans(tail_buckets, device=device), - _build_explicit_bucket_plans(completion_column_batches, device=device), - ) - - -def _build_chunk_aligned_position_bucket_plans( - prefix_segments: tuple[GdnSegmentSpec, ...], - completion_segments: tuple[GdnSegmentSpec, ...], - local_token_ranges: tuple[tuple[int, int, int], ...], - *, - sequence_length: int, - device: torch.device | str, - planner_config: GdnPlannerConfig, -) -> tuple[ - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], -]: - 
local_range_ends = tuple(token_end for _, token_end, _ in local_token_ranges) - local_range_positions = { - (token_start, token_end): position_start - for token_start, token_end, position_start in local_token_ranges - } - completions_by_family: dict[int, list[GdnSegmentSpec]] = {} - for completion in completion_segments: - completions_by_family.setdefault(completion.family_index, []).append(completion) - boundary_segments: list[GdnSegmentSpec] = [] - tail_segments: list[GdnSegmentSpec] = [] - completion_columns: list[_ExplicitBucketColumn] = [] - for prefix in prefix_segments: - boundary_end = _prefix_chunk_boundary_end(prefix) - if boundary_end > prefix.start: - boundary_segments.append( - _segment_with_bounds(prefix, prefix.start, boundary_end) - ) - family_completions = tuple(completions_by_family.get(prefix.family_index, ())) - prefix_tail_positions = _local_positions_for_span( - prefix.row_index, - boundary_end, - prefix.end, - sequence_length=sequence_length, - local_token_ranges=local_token_ranges, - local_range_ends=local_range_ends, - local_range_positions=local_range_positions, - ) - if prefix_tail_positions and not family_completions: - tail_segments.append(_segment_with_bounds(prefix, boundary_end, prefix.end)) - for child_offset, completion in enumerate(family_completions): - completion_positions = _local_positions_for_span( - completion.row_index, - completion.start, - completion.end, - sequence_length=sequence_length, - local_token_ranges=local_token_ranges, - local_range_ends=local_range_ends, - local_range_positions=local_range_positions, - ) - positions = prefix_tail_positions + completion_positions - completion_columns.append( - _explicit_bucket_column( - row_index=0, - family_index=completion.family_index, - positions=positions, - output_mask=( - ((child_offset == 0),) * len(prefix_tail_positions) - + (True,) * len(completion_positions) - ), - ) - ) - boundary_buckets = _batch_segments_by_padded_work( - tuple(boundary_segments), - 
max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - tail_buckets = _batch_segments_by_padded_work( - tuple(tail_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - completion_column_batches = _batch_explicit_bucket_columns( - tuple(completion_columns), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - return ( - _build_position_bucket_plans( - boundary_buckets, - local_token_ranges, - sequence_length=sequence_length, - device=device, - ), - _build_position_bucket_plans( - tail_buckets, - local_token_ranges, - sequence_length=sequence_length, - device=device, - ), - _build_explicit_bucket_plans(completion_column_batches, device=device), - ) - - -def _build_remote_prefix_tail_plans( - spec: GdnPackedExecutionSpec, - schedule: GdnCpSegmentSchedule, - *, - cp_rank: int, - device: torch.device | str, - planner_config: GdnPlannerConfig, -) -> tuple[ - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], - Any | None, - Any | None, - tuple[GdnParentStateTransferPlan, ...], - frozenset[int], -]: - from art.megatron.gdn.layout import ( - GdnCpExchangePlan, - GdnCpPeerTransfer, - _reverse_exchange_plan, - ) - - family_by_index = {family.family_index: family for family in spec.families} - prefix_owner_by_family = _prefix_owner_by_family(schedule) - source_positions_by_pair: dict[tuple[int, int], list[int]] = {} - dest_positions_by_pair: dict[tuple[int, int], list[int]] = {} - dest_counts = [0 for _ in schedule.gdn_token_counts_by_rank] - state_transfer_families: dict[tuple[int, int], set[int]] = {} - remote_tail_family_indices: set[int] = set() - local_tail_columns: list[_ExplicitBucketColumn] = [] - local_completion_columns: list[_ExplicitBucketColumn] = [] - tail_positions_by_dest_family: dict[tuple[int, int], tuple[int, ...]] = {} - 
local_tail_column_families: set[int] = set() - rank_ranges = schedule.gdn_token_ranges_by_rank - rank_range_ends = tuple( - tuple(end for _, end, _ in ranges) for ranges in rank_ranges - ) - rank_range_positions = tuple( - { - (token_start, token_end): position_start - for token_start, token_end, position_start in ranges - } - for ranges in rank_ranges - ) - - for dest_rank, completions in enumerate(schedule.local_completion_segments_by_rank): - for completion in completions: - source_rank = prefix_owner_by_family.get(completion.family_index) - if source_rank is None or source_rank == dest_rank: - continue - family = family_by_index[completion.family_index] - boundary_end = _prefix_chunk_boundary_end(family.prefix) - if boundary_end == family.prefix.end: - continue - dest_family = (dest_rank, family.family_index) - dest_positions = tail_positions_by_dest_family.get(dest_family) - if dest_positions is None: - source_positions = _local_positions_for_span( - family.prefix.row_index, - boundary_end, - family.prefix.end, - sequence_length=spec.sequence_length, - local_token_ranges=rank_ranges[source_rank], - local_range_ends=rank_range_ends[source_rank], - local_range_positions=rank_range_positions[source_rank], - ) - if len(source_positions) != family.prefix.end - boundary_end: - raise ValueError( - "remote prefix-tail exchange could not locate all source tokens " - f"for family {family.family_index}" - ) - dest_start = dest_counts[dest_rank] - dest_positions = tuple( - range(dest_start, dest_start + len(source_positions)) - ) - tail_positions_by_dest_family[dest_family] = dest_positions - dest_counts[dest_rank] += len(source_positions) - pair = (source_rank, dest_rank) - source_positions_by_pair.setdefault(pair, []).extend(source_positions) - dest_positions_by_pair.setdefault(pair, []).extend(dest_positions) - state_transfer_families.setdefault(pair, set()).add(family.family_index) - remote_tail_family_indices.add(family.family_index) - - if dest_rank != cp_rank: - 
continue - completion_positions = _local_positions_for_span( - completion.row_index, - completion.start, - completion.end, - sequence_length=spec.sequence_length, - local_token_ranges=rank_ranges[dest_rank], - local_range_ends=rank_range_ends[dest_rank], - local_range_positions=rank_range_positions[dest_rank], - ) - if len(completion_positions) != completion.length: - raise ValueError( - "remote prefix-tail bucket could not locate all completion tokens " - f"for family {family.family_index}" - ) - remote_base = int(schedule.gdn_token_counts_by_rank[dest_rank]) - if ( - len(dest_positions) > 0 - and family.family_index not in local_tail_column_families - ): - local_tail_column_families.add(family.family_index) - local_tail_columns.append( - _explicit_bucket_column( - row_index=0, - family_index=family.family_index, - positions=tuple(remote_base + pos for pos in dest_positions), - output_mask=(False,) * len(dest_positions), - ) - ) - local_completion_columns.append( - _explicit_bucket_column( - row_index=0, - family_index=family.family_index, - positions=completion_positions, - output_mask=(True,) * len(completion_positions), - ) - ) - - if not source_positions_by_pair: - return (), (), None, None, (), frozenset() - - transfers = tuple( - GdnCpPeerTransfer.model_construct( - source_rank=source_rank, - dest_rank=dest_rank, - token_count=len(source_positions), - source_positions_tensor=_move_planner_tensor( - torch.tensor(source_positions, dtype=torch.long), device - ), - dest_positions_tensor=_move_planner_tensor( - torch.tensor( - dest_positions_by_pair[(source_rank, dest_rank)], - dtype=torch.long, - ), - device, - ), - ) - for (source_rank, dest_rank), source_positions in sorted( - source_positions_by_pair.items() - ) - ) - exchange = GdnCpExchangePlan.model_construct( - cp_size=len(schedule.gdn_token_counts_by_rank), - source_token_counts_by_rank=schedule.gdn_token_counts_by_rank, - dest_token_counts_by_rank=tuple(dest_counts), - transfers=transfers, - 
cross_rank_token_count_override=sum(dest_counts), - ) - tail_column_batches = _batch_explicit_bucket_columns( - tuple(local_tail_columns), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - completion_column_batches = _batch_explicit_bucket_columns( - tuple(local_completion_columns), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - return ( - _build_explicit_bucket_plans(tail_column_batches, device=device), - _build_explicit_bucket_plans(completion_column_batches, device=device), - exchange, - _reverse_exchange_plan(exchange), - _transfer_plans_to_device( - _build_parent_state_transfer_plans(state_transfer_families), - device=device, - ), - frozenset(remote_tail_family_indices), - ) - - -def _empty_remote_prefix_tail_plans() -> tuple[ - tuple[GdnSegmentBucketPlan, ...], - tuple[GdnSegmentBucketPlan, ...], - Any | None, - Any | None, - tuple[GdnParentStateTransferPlan, ...], - frozenset[int], -]: - return (), (), None, None, (), frozenset() - - -def _prefix_owner_by_family(schedule: GdnCpSegmentSchedule) -> dict[int, int]: - owners: dict[int, int] = {} - for rank, segments in enumerate(schedule.local_prefix_segments_by_rank): - for segment in segments: - owners[segment.family_index] = rank - return owners - - -def _filter_parent_state_transfers( - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - excluded_families: frozenset[int], - device: torch.device | str, -) -> tuple[GdnParentStateTransferPlan, ...]: - if not excluded_families: - return _transfer_plans_to_device(transfers, device=device) - kept: dict[tuple[int, int], set[int]] = {} - for transfer in transfers: - families = set(transfer.family_indices) - excluded_families - if families: - kept.setdefault((transfer.source_rank, transfer.dest_rank), set()).update( - families - ) - return _transfer_plans_to_device( - _build_parent_state_transfer_plans(kept), 
device=device - ) - - -def _local_positions_for_span( - row_index: int, - start: int, - end: int, - *, - sequence_length: int, - local_token_ranges: tuple[tuple[int, int, int], ...], - local_range_ends: tuple[int, ...], - local_range_positions: dict[tuple[int, int], int] | None = None, -) -> tuple[int, ...]: - if start == end: - return () - token_start = row_index * sequence_length + start - token_end = row_index * sequence_length + end - if local_range_positions is not None: - position_start = local_range_positions.get((token_start, token_end)) - if position_start is not None: - return tuple(range(position_start, position_start + end - start)) - range_index = bisect_left(local_range_ends, token_start + 1) - if range_index < len(local_token_ranges): - range_start, range_end, position_start = local_token_ranges[range_index] - if range_start <= token_start and token_end <= range_end: - local_start = position_start + token_start - range_start - return tuple(range(local_start, local_start + end - start)) - segment = _trusted_pydantic_construct( - GdnSegmentSpec, - _GDN_SEGMENT_SPEC_FIELDS, - row_index=row_index, - family_index=0, - group_id=0, - parent_id=0, - start=start, - end=end, - kind="prefix", - child_index=None, - ) - return tuple( - int(position) - for position in _local_positions_for_segment( - segment, - sequence_length=sequence_length, - local_token_ranges=local_token_ranges, - local_range_ends=local_range_ends, - ).tolist() - ) - - -def _prefix_chunk_boundary_end(prefix: GdnSegmentSpec) -> int: - aligned_length = (prefix.length // FLA_CHUNK_SIZE) * FLA_CHUNK_SIZE - return prefix.start + aligned_length - - -def _segment_with_bounds( - segment: GdnSegmentSpec, start: int, end: int -) -> GdnSegmentSpec: - return _trusted_pydantic_construct( - GdnSegmentSpec, - _GDN_SEGMENT_SPEC_FIELDS, - row_index=segment.row_index, - family_index=segment.family_index, - group_id=segment.group_id, - parent_id=segment.parent_id, - start=start, - end=end, - kind=segment.kind, - 
child_index=segment.child_index, - ) - - -def _batch_explicit_bucket_columns( - columns: tuple[_ExplicitBucketColumn, ...], - *, - max_padding_ratio: float = 1.25, - max_segments_per_batch: int = 128, -) -> tuple[tuple[_ExplicitBucketColumn, ...], ...]: - if not columns: - return () - ordered = sorted( - columns, - key=lambda column: (column.length, column.family_index, column.row_index), - ) - batches: list[list[_ExplicitBucketColumn]] = [] - current: list[_ExplicitBucketColumn] = [] - current_tokens = 0 - current_max = 0 - for column in ordered: - next_count = len(current) + 1 - next_tokens = current_tokens + column.length - next_max = max(current_max, column.length) - padded = next_max * next_count - can_extend = not current or ( - next_count <= max_segments_per_batch - and padded <= max_padding_ratio * next_tokens - ) - if not can_extend: - batches.append(current) - current = [] - current_tokens = 0 - current_max = 0 - current.append(column) - current_tokens += column.length - current_max = max(current_max, column.length) - if current: - batches.append(current) - return tuple(tuple(batch) for batch in batches) - - -def _build_explicit_bucket_plans( - bucket_columns: tuple[tuple[_ExplicitBucketColumn, ...], ...], - *, - device: torch.device | str, -) -> tuple[GdnSegmentBucketPlan, ...]: - return tuple( - _build_explicit_bucket_plan(columns, device=device) - for columns in bucket_columns - ) - - -def _build_explicit_bucket_plan( - columns: tuple[_ExplicitBucketColumn, ...], - *, - device: torch.device | str, -) -> GdnSegmentBucketPlan: - max_length = max(column.length for column in columns) - column_count = len(columns) - lengths = [column.length for column in columns] - lengths_cpu = torch.tensor(lengths, dtype=torch.long) - offsets_cpu = torch.arange(max_length, dtype=torch.long).unsqueeze(1) - real_mask_cpu = offsets_cpu < lengths_cpu.unsqueeze(0) - padded_element_count = max_length * column_count - row_indices = [0] * padded_element_count - position_indices = 
[0] * padded_element_count - output_mask = [False] * padded_element_count - for column_index, column in enumerate(columns): - length = column.length - column_slice = slice(column_index, length * column_count, column_count) - row_indices[column_slice] = [column.row_index] * length - position_indices[column_slice] = column.positions - output_mask[column_slice] = column.output_mask - row_indices_cpu = torch.tensor(row_indices, dtype=torch.long).reshape( - max_length, column_count - ) - position_indices_cpu = torch.tensor(position_indices, dtype=torch.long).reshape( - max_length, column_count - ) - output_mask_cpu = torch.tensor(output_mask, dtype=torch.bool).reshape( - max_length, column_count - ) - family_indices_cpu = torch.tensor( - [column.family_index for column in columns], dtype=torch.long - ) - cu_seqlens_cpu = torch.cat( - [lengths_cpu.new_zeros(1), torch.cumsum(lengths_cpu, dim=0)] - ) - return GdnSegmentBucketPlan.model_construct( - length=max_length, - lengths=_move_planner_tensor(lengths_cpu, device), - lengths_cpu=lengths_cpu, - lengths_by_rank_cpu=None, - real_mask=_move_planner_tensor(real_mask_cpu, device), - cu_seqlens=_move_planner_tensor(cu_seqlens_cpu, device), - cu_seqlens_cpu=cu_seqlens_cpu, - row_indices=_move_planner_tensor(row_indices_cpu, device), - position_indices=_move_planner_tensor(position_indices_cpu, device), - family_indices=_move_planner_tensor(family_indices_cpu, device), - real_token_count_static=int(lengths_cpu.sum().item()), - output_mask=_move_planner_tensor(output_mask_cpu, device), - ) - - -def _attention_source_layout( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_token_layout_index: TokenLayoutIndex | None, - planner_config: GdnPlannerConfig, -) -> TokenLayoutIndex: - if attention_token_layout_index is not None: - if _layout_cp_size(attention_token_layout_index) != cp_size: - raise ValueError( - "attention token layout index cp_size must match GDN cp_size, got " - 
f"{_layout_cp_size(attention_token_layout_index)} and {cp_size}" - ) - if _layout_token_count(attention_token_layout_index) != spec.real_token_count: - raise ValueError( - "attention token layout index token count must match GDN real token " - f"count, got {_layout_token_count(attention_token_layout_index)} and " - f"{spec.real_token_count}" - ) - return attention_token_layout_index - return _token_layout_from_rank_ranges( - _default_attention_layout_ranges( - spec, - cp_size=cp_size, - planner_config=planner_config, - ) - ) - - -def _build_cp_rank_execution_plan( - spec: GdnPackedExecutionSpec, - *, - device: torch.device | str, - cp_rank: int, - cp_size: int, - attention_token_layout_index: TokenLayoutIndex | None, - cp_segment_schedule: GdnCpSegmentSchedule | None, - planner_config: GdnPlannerConfig, -) -> GdnRankExecutionPlan: - if cp_size < 1: - raise ValueError(f"cp_size must be >= 1, got {cp_size}") - if cp_rank < 0 or cp_rank >= cp_size: - raise ValueError(f"cp_rank must be in [0, {cp_size}), got {cp_rank}") - if ( - attention_token_layout_index is not None - and _layout_cp_size(attention_token_layout_index) != cp_size - ): - raise ValueError( - "attention token layout index cp_size must match GDN cp_size, got " - f"{_layout_cp_size(attention_token_layout_index)} and {cp_size}" - ) - - from art.megatron.gdn.layout import ( - _reverse_exchange_plan, - build_local_rank_cp_exchange_plan_from_dest_ranges, - ) - - has_explicit_attention_layout = attention_token_layout_index is not None - if cp_segment_schedule is None and not has_explicit_attention_layout: - local_family_plan = _build_local_family_rank_execution_plan( - spec, - device=device, - cp_rank=cp_rank, - cp_size=cp_size, - planner_config=planner_config, - ) - if local_family_plan is not None: - return local_family_plan - if cp_segment_schedule is None and has_explicit_attention_layout: - local_layout_plan = _build_local_attention_layout_rank_execution_plan( - spec, - device=device, - cp_rank=cp_rank, - 
cp_size=cp_size, - attention_token_layout_index=attention_token_layout_index, - planner_config=planner_config, - ) - if local_layout_plan is not None: - return local_layout_plan - - source_layout = _attention_source_layout( - spec, - cp_size=cp_size, - attention_token_layout_index=attention_token_layout_index, - planner_config=planner_config, - ) - if cp_segment_schedule is None: - schedule = _build_cp_segment_schedule( - spec, - cp_size=cp_size, - attention_layout_index=_build_attention_layout_index_from_token_layout( - source_layout, - max_ranges=max( - 1, - (2 * spec.real_token_count) // max(1, len(spec.segments())), - ), - ), - planner_config=planner_config, - ) - else: - schedule = cp_segment_schedule - if len(schedule.gdn_token_counts_by_rank) != cp_size: - raise ValueError(f"CP GDN schedule must contain {cp_size} ranks") - attention_to_gdn = build_local_rank_cp_exchange_plan_from_dest_ranges( - source_layout=source_layout, - device=device, - local_rank=cp_rank, - dest_ranges_by_rank=schedule.gdn_token_ranges_by_rank, - cross_rank_token_count=schedule.cross_rank_token_count, - ) - gdn_to_attention = _reverse_exchange_plan(attention_to_gdn) - local_token_ranges = schedule.gdn_token_ranges_by_rank[cp_rank] - local_gdn_token_count = schedule.gdn_token_counts_by_rank[cp_rank] - if schedule.parent_state_exchange_family_indices: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _build_remote_prefix_tail_plans( - spec, - schedule, - cp_rank=cp_rank, - device=device, - planner_config=planner_config, - ) - else: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _empty_remote_prefix_tail_plans() - - chain_prefix_buckets = 
tuple( - bucket for bucket in schedule.chain_prefix_buckets if bucket - ) - chain_completion_buckets = tuple( - bucket for bucket in schedule.chain_completion_buckets if bucket - ) - local_prefix_segments = tuple(schedule.local_prefix_segments_by_rank[cp_rank]) - local_prefix_family_indices = { - segment.family_index for segment in local_prefix_segments - } - local_prefix_buckets = _batch_segments_by_padded_work( - () if local_prefix_segments else (), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - local_completion_segments = tuple( - schedule.local_completion_segments_by_rank[cp_rank] - ) - chunk_local_completion_segments = tuple( - segment - for segment in local_completion_segments - if segment.family_index in local_prefix_family_indices - ) - plain_local_completion_segments = tuple( - segment - for segment in local_completion_segments - if segment.family_index not in local_prefix_family_indices - and segment.family_index not in remote_prefix_tail_families - ) - ready_completion_segments, remote_completion_segments = ( - _split_ready_and_remote_completion_segments( - plain_local_completion_segments, - local_prefix_segments=(), - chain_prefix_buckets=chain_prefix_buckets, - ) - ) - ready_local_completion_buckets = _batch_segments_by_padded_work( - ready_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - remote_local_completion_buckets = _batch_segments_by_padded_work( - remote_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - local_completion_buckets = ( - ready_local_completion_buckets + remote_local_completion_buckets - ) - prefix_family_order = tuple( - segment.family_index - for bucket in ( - *chain_prefix_buckets, - *local_prefix_buckets, - ) - for segment in bucket - ) - ( - 
prefix_boundary_buckets, - prefix_tail_buckets, - completion_with_prefix_tail_buckets, - ) = _build_chunk_aligned_position_bucket_plans( - local_prefix_segments, - chunk_local_completion_segments, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - planner_config=planner_config, - ) - return GdnRankExecutionPlan.model_construct( - cp_rank=cp_rank, - cp_size=cp_size, - batch_size=1, - sequence_length=local_gdn_token_count, - packed_batch_size=spec.batch_size, - packed_sequence_length=spec.sequence_length, - real_token_mask=torch.ones( - 1, local_gdn_token_count, device=device, dtype=torch.bool - ), - family_count=spec.family_count, - completion_count=spec.completion_count, - local_prefix_buckets=_build_position_bucket_plans( - local_prefix_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - local_completion_buckets=_build_position_bucket_plans( - local_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - ready_local_completion_buckets=_build_position_bucket_plans( - ready_local_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - remote_local_completion_buckets=_build_position_bucket_plans( - remote_local_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ), - chain_prefix_buckets=_build_position_bucket_plans( - chain_prefix_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - token_ranges_by_rank=schedule.gdn_token_ranges_by_rank, - ), - chain_completion_buckets=_build_position_bucket_plans( - chain_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - token_ranges_by_rank=schedule.gdn_token_ranges_by_rank, - ), - prefix_table_is_dense_ordered=( - not local_prefix_segments - and prefix_family_order == tuple(range(spec.family_count)) - ), - 
attention_to_gdn=attention_to_gdn, - gdn_to_attention=gdn_to_attention, - attention_token_ranges=source_layout.ownership_ranges_by_rank[cp_rank], - gdn_token_ranges=local_token_ranges, - attention_token_count=source_layout.token_counts_by_rank[cp_rank], - gdn_token_count=local_gdn_token_count, - parent_state_exchange_family_indices=( - tuple( - family_index - for family_index in schedule.parent_state_exchange_family_indices - if family_index not in remote_prefix_tail_families - ) - ), - parent_state_transfers=_filter_parent_state_transfers( - schedule.parent_state_transfers, - excluded_families=remote_prefix_tail_families, - device=device, - ), - prefix_boundary_buckets=prefix_boundary_buckets, - prefix_tail_buckets=prefix_tail_buckets, - completion_with_prefix_tail_buckets=completion_with_prefix_tail_buckets, - remote_prefix_tail_buckets=remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets=remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange=remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange=remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers=remote_prefix_tail_state_transfers, - ) - - -def build_gdn_cp_segment_schedule( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_token_layout_index: TokenLayoutIndex | None = None, - planner_config: GdnPlannerConfig | None = None, -) -> GdnCpSegmentSchedule: - planner_config = planner_config or GdnPlannerConfig() - source_layout = _attention_source_layout( - spec, - cp_size=cp_size, - attention_token_layout_index=attention_token_layout_index, - planner_config=planner_config, - ) - return _build_cp_segment_schedule( - spec, - cp_size=cp_size, - attention_layout_index=_build_attention_layout_index_from_token_layout( - source_layout, - max_ranges=max( - 1, (2 * spec.real_token_count) // max(1, len(spec.segments())) - ), - ), - planner_config=planner_config, - ) - - -def _build_cp_segment_schedule( - spec: GdnPackedExecutionSpec, - *, 
- cp_size: int, - attention_layout_index: _AttentionLayoutIndex, - planner_config: GdnPlannerConfig, -) -> GdnCpSegmentSchedule: - segment_attention_counts = _segment_attention_rank_counts( - spec, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - ) - legal_chain_segments = tuple( - segment - for family in spec.families - for segment in (family.prefix, *family.completions) - if ( - _can_chain_prefix_segment( - segment, cp_size=cp_size, planner_config=planner_config - ) - if segment.kind == "prefix" - else _can_chain_segment( - segment, cp_size=cp_size, planner_config=planner_config - ) - ) - ) - decision = _beam_search_cp_segment_schedule_decision( - spec, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - segment_attention_counts=segment_attention_counts, - legal_chain_segments=legal_chain_segments, - planner_config=planner_config, - ) - return _materialize_cp_segment_schedule( - spec, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - segment_attention_counts=segment_attention_counts, - chain_segment_keys=decision.chain_segment_keys, - co_locate_local_families=decision.co_locate_local_families, - planner_config=planner_config, - ) - - -def _beam_search_cp_segment_schedule_decision( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_layout_index: _AttentionLayoutIndex, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - legal_chain_segments: tuple[GdnSegmentSpec, ...], - planner_config: GdnPlannerConfig, -) -> _GdnCpSegmentSearchDecision: - legal_chain_keys = frozenset( - _segment_key(segment) for segment in legal_chain_segments - ) - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]] = {} - chain_cross_rank_tokens_by_key: dict[GdnSegmentDecisionKey, int] = {} - for segment in legal_chain_segments: - key = _segment_key(segment) - ( - chain_rank_counts_by_key[key], - chain_cross_rank_tokens_by_key[key], - ) = 
_chain_segment_rank_counts_and_cross_rank_tokens( - segment, - spec, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - ) - - score_cache: dict[ - frozenset[GdnSegmentDecisionKey], _GdnCpSegmentSearchDecision - ] = {} - - def decision_for( - chain_segment_keys: frozenset[GdnSegmentDecisionKey], - ) -> _GdnCpSegmentSearchDecision: - cached = score_cache.get(chain_segment_keys) - if cached is not None: - return cached - non_colocated_score = _score_cp_segment_decisions( - spec, - cp_size=cp_size, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - chain_cross_rank_tokens_by_key=chain_cross_rank_tokens_by_key, - chain_segment_keys=chain_segment_keys, - co_locate_local_families=False, - planner_config=planner_config, - ) - colocated_score = _score_cp_segment_decisions( - spec, - cp_size=cp_size, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - chain_cross_rank_tokens_by_key=chain_cross_rank_tokens_by_key, - chain_segment_keys=chain_segment_keys, - co_locate_local_families=True, - planner_config=planner_config, - ) - co_locate = colocated_score < non_colocated_score - decision = _GdnCpSegmentSearchDecision.model_construct( - chain_segment_keys=chain_segment_keys, - co_locate_local_families=co_locate, - score=colocated_score if co_locate else non_colocated_score, - ) - score_cache[chain_segment_keys] = decision - return decision - - best = decision_for(frozenset()) - beam_by_keys = {best.chain_segment_keys: best} - if legal_chain_keys: - all_chain = decision_for(legal_chain_keys) - beam_by_keys[all_chain.chain_segment_keys] = all_chain - if best.score - all_chain.score > planner_config.cp_chain_min_score_delta_ms: - best = all_chain - candidate_groups = _bounded_chain_candidate_groups( - spec, - legal_chain_segments, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - 
planner_config=planner_config, - ) - beam = _best_cp_segment_search_decisions( - beam_by_keys.values(), - limit=planner_config.cp_chain_beam_width, - ) - stale_steps = 0 - for _ in range(planner_config.cp_chain_beam_max_steps): - if not candidate_groups: - break - expanded: dict[ - frozenset[GdnSegmentDecisionKey], _GdnCpSegmentSearchDecision - ] = {} - for decision in beam: - neighbors = [] - for segment_keys in _chain_beam_neighbor_groups( - decision.chain_segment_keys, - candidate_groups=candidate_groups, - branch_factor=planner_config.cp_chain_beam_branch_factor, - ): - if segment_keys.issubset(decision.chain_segment_keys): - next_keys = decision.chain_segment_keys - segment_keys - else: - next_keys = decision.chain_segment_keys | segment_keys - neighbors.append(decision_for(frozenset(next_keys))) - for neighbor in _best_cp_segment_search_decisions( - neighbors, - limit=planner_config.cp_chain_beam_branch_factor, - ): - expanded[neighbor.chain_segment_keys] = neighbor - if not expanded: - break - beam = _best_cp_segment_search_decisions( - (*beam, *expanded.values()), - limit=planner_config.cp_chain_beam_width, - ) - step_best = beam[0] - if best.score - step_best.score > planner_config.cp_chain_min_score_delta_ms: - best = step_best - stale_steps = 0 - else: - stale_steps += 1 - if stale_steps >= 2: - break - return best - - -def _chain_beam_neighbor_groups( - chain_segment_keys: frozenset[GdnSegmentDecisionKey], - *, - candidate_groups: tuple[frozenset[GdnSegmentDecisionKey], ...], - branch_factor: int, -) -> tuple[frozenset[GdnSegmentDecisionKey], ...]: - selected: list[frozenset[GdnSegmentDecisionKey]] = [] - for group in candidate_groups: - if group and not group.issubset(chain_segment_keys): - selected.append(group) - if len(selected) >= branch_factor: - return tuple(selected) - for group in reversed(candidate_groups): - if group and group.intersection(chain_segment_keys) and group not in selected: - selected.append(group) - if len(selected) >= 
branch_factor: - break - return tuple(selected) - - -def _best_cp_segment_search_decisions( - decisions: Any, - *, - limit: int, -) -> tuple[_GdnCpSegmentSearchDecision, ...]: - return tuple( - sorted( - decisions, - key=lambda decision: ( - decision.score, - len(decision.chain_segment_keys), - tuple(sorted(decision.chain_segment_keys)), - ), - )[:limit] - ) - - -def _bounded_chain_candidate_groups( - spec: GdnPackedExecutionSpec, - legal_chain_segments: tuple[GdnSegmentSpec, ...], - *, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - planner_config: GdnPlannerConfig, -) -> tuple[frozenset[GdnSegmentDecisionKey], ...]: - legal_key_set = frozenset(_segment_key(segment) for segment in legal_chain_segments) - if not legal_key_set: - return () - prefix_keys = frozenset( - _segment_key(family.prefix) - for family in spec.families - if _segment_key(family.prefix) in legal_key_set - ) - completion_keys = legal_key_set - prefix_keys - groups: list[frozenset[GdnSegmentDecisionKey]] = [] - for group in (legal_key_set, prefix_keys, completion_keys): - if group and group not in groups: - groups.append(group) - for group in _ranked_chain_beam_groups( - spec, - legal_chain_segments, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - planner_config=planner_config, - ): - if group and group not in groups: - groups.append(group) - return tuple(groups[: planner_config.cp_chain_beam_candidate_limit]) - - -def _ranked_chain_beam_groups( - spec: GdnPackedExecutionSpec, - legal_chain_segments: tuple[GdnSegmentSpec, ...], - *, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - planner_config: GdnPlannerConfig, -) -> tuple[frozenset[GdnSegmentDecisionKey], ...]: - if not legal_chain_segments: - return () - priority_by_key = { - 
_segment_key(segment): _chain_beam_segment_priority( - segment, - segment_attention_counts=segment_attention_counts, - chain_rank_counts_by_key=chain_rank_counts_by_key, - ) - for segment in legal_chain_segments - } - legal_key_set = frozenset(priority_by_key) - groups: set[frozenset[GdnSegmentDecisionKey]] = { - frozenset((key,)) for key in legal_key_set - } - for family in spec.families: - completion_keys = frozenset( - _segment_key(completion) - for completion in family.completions - if _segment_key(completion) in legal_key_set - ) - if len(completion_keys) > 1: - groups.add(completion_keys) - family_keys = completion_keys - prefix_key = _segment_key(family.prefix) - if prefix_key in legal_key_set: - family_keys = family_keys | frozenset((prefix_key,)) - if len(family_keys) > 1: - groups.add(family_keys) - ranked = tuple( - sorted( - groups, - key=lambda group: _chain_beam_group_priority( - group, priority_by_key=priority_by_key - ), - reverse=True, - ) - ) - limit = planner_config.cp_chain_beam_candidate_limit - if len(ranked) <= limit: - return ranked - high_count = (limit + 1) // 2 - low_count = limit - high_count - selected = [*ranked[:high_count]] - for group in ranked[-low_count:]: - if group not in selected: - selected.append(group) - return tuple(selected) - - -def _chain_beam_group_priority( - group: frozenset[GdnSegmentDecisionKey], - *, - priority_by_key: dict[GdnSegmentDecisionKey, tuple[int, int, int, int]], -) -> tuple[int, int, int, int, int]: - priorities = tuple(priority_by_key[key] for key in group) - return ( - sum(priority[0] for priority in priorities), - sum(priority[1] for priority in priorities), - max((priority[2] for priority in priorities), default=0), - sum(priority[3] for priority in priorities), - len(group), - ) - - -def _chain_beam_segment_priority( - segment: GdnSegmentSpec, - *, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], -) -> 
tuple[int, int, int, int]: - key = _segment_key(segment) - chain_max_load = max(chain_rank_counts_by_key[key], default=0) - best_attention_locality = max(segment_attention_counts[key], default=0) - chain_load_relief = segment.length - chain_max_load - minimum_local_exchange = segment.length - best_attention_locality - return ( - chain_load_relief, - segment.length, - best_attention_locality, - -minimum_local_exchange, - ) - - -def _score_cp_segment_decisions( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - chain_cross_rank_tokens_by_key: dict[GdnSegmentDecisionKey, int], - chain_segment_keys: frozenset[GdnSegmentDecisionKey], - co_locate_local_families: bool, - planner_config: GdnPlannerConfig, -) -> float: - rank_loads = [0] * cp_size - local_prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - local_completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - chain_prefix_segments: list[GdnSegmentSpec] = [] - chain_completion_segments: list[GdnSegmentSpec] = [] - parent_state_exchange_families: set[int] = set() - cross_rank_token_count = 0 - - for family in spec.families: - prefix_key = _segment_key(family.prefix) - chain_prefix = prefix_key in chain_segment_keys - local_completions = tuple( - completion - for completion in family.completions - if _segment_key(completion) not in chain_segment_keys - ) - prefix_owner: int | None = None - if chain_prefix: - chain_prefix_segments.append(family.prefix) - cross_rank_token_count += _add_chain_search_load( - rank_loads, - family.prefix, - chain_rank_counts_by_key=chain_rank_counts_by_key, - chain_cross_rank_tokens_by_key=chain_cross_rank_tokens_by_key, - ) - else: - owner_segments = ( - (family.prefix, *local_completions) - if co_locate_local_families - else (family.prefix,) - ) - prefix_owner = 
_best_segment_owner( - owner_segments, - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - local_prefix_segments_by_rank[prefix_owner].append(family.prefix) - cross_rank_token_count += _add_local_search_load( - rank_loads, - prefix_owner, - family.prefix, - segment_attention_counts=segment_attention_counts, - ) - for completion in family.completions: - completion_key = _segment_key(completion) - if completion_key in chain_segment_keys: - chain_completion_segments.append(completion) - cross_rank_token_count += _add_chain_search_load( - rank_loads, - completion, - chain_rank_counts_by_key=chain_rank_counts_by_key, - chain_cross_rank_tokens_by_key=chain_cross_rank_tokens_by_key, - ) - if not chain_prefix: - parent_state_exchange_families.add(family.family_index) - continue - if co_locate_local_families and not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "co-located local completion planning lost the prefix owner" - ) - owner = prefix_owner - else: - owner = _best_segment_owner( - (completion,), - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - if not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "local completion planning lost the prefix owner" - ) - if owner != prefix_owner: - parent_state_exchange_families.add(family.family_index) - local_completion_segments_by_rank[owner].append(completion) - cross_rank_token_count += _add_local_search_load( - rank_loads, - owner, - completion, - segment_attention_counts=segment_attention_counts, - ) - ( - local_work_by_rank, - local_bucket_count, - local_segment_count, - ) = _estimate_local_rank_kernel_work( - tuple(tuple(segments) for segments in local_prefix_segments_by_rank), - tuple(tuple(segments) for segments in local_completion_segments_by_rank), - planner_config=planner_config, - ) - chain_work_by_rank, chain_bucket_count = _estimate_chain_rank_kernel_work( - cp_size=cp_size, 
- chain_prefix_segments=tuple(chain_prefix_segments), - chain_completion_segments=tuple(chain_completion_segments), - chain_rank_counts_by_key=chain_rank_counts_by_key, - planner_config=planner_config, - ) - return _score_cp_segment_stats( - rank_local_work=local_work_by_rank, - rank_chain_work=chain_work_by_rank, - rank_real_tokens=tuple(rank_loads), - cross_rank_token_count=cross_rank_token_count, - parent_state_exchange_family_count=len(parent_state_exchange_families), - local_bucket_count=local_bucket_count, - local_segment_count=local_segment_count, - chain_bucket_count=chain_bucket_count, - planner_config=planner_config, - ) - - -def _estimate_local_rank_kernel_work( - local_prefix_segments_by_rank: tuple[tuple[GdnSegmentSpec, ...], ...], - local_completion_segments_by_rank: tuple[tuple[GdnSegmentSpec, ...], ...], - *, - planner_config: GdnPlannerConfig, -) -> tuple[tuple[int, ...], int, int]: - rank_work: list[int] = [] - rank_bucket_counts: list[int] = [] - rank_segment_counts: list[int] = [] - for prefix_segments, completion_segments in zip( - local_prefix_segments_by_rank, - local_completion_segments_by_rank, - strict=True, - ): - prefix_family_indices = {segment.family_index for segment in prefix_segments} - chunk_local_completion_segments = tuple( - segment - for segment in completion_segments - if segment.family_index in prefix_family_indices - ) - plain_local_completion_segments = tuple( - segment - for segment in completion_segments - if segment.family_index not in prefix_family_indices - ) - chunk_work, chunk_bucket_count = _estimate_chunk_aligned_local_work( - prefix_segments, - chunk_local_completion_segments, - planner_config=planner_config, - ) - completion_work, completion_bucket_count = _padded_work_from_lengths( - tuple(segment.length for segment in plain_local_completion_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - rank_work.append(chunk_work + 
completion_work) - rank_bucket_counts.append(chunk_bucket_count + completion_bucket_count) - rank_segment_counts.append(len(prefix_segments) + len(completion_segments)) - return ( - tuple(rank_work), - max(rank_bucket_counts, default=0), - max(rank_segment_counts, default=0), - ) - - -def _estimate_chunk_aligned_local_work( - prefix_segments: tuple[GdnSegmentSpec, ...], - completion_segments: tuple[GdnSegmentSpec, ...], - *, - planner_config: GdnPlannerConfig, -) -> tuple[int, int]: - completions_by_family: dict[int, list[GdnSegmentSpec]] = {} - for completion in completion_segments: - completions_by_family.setdefault(completion.family_index, []).append(completion) - boundary_lengths: list[int] = [] - tail_lengths: list[int] = [] - completion_column_lengths: list[int] = [] - for prefix in prefix_segments: - boundary_end = _prefix_chunk_boundary_end(prefix) - boundary_length = boundary_end - prefix.start - if boundary_length > 0: - boundary_lengths.append(boundary_length) - tail_length = prefix.end - boundary_end - family_completions = tuple(completions_by_family.get(prefix.family_index, ())) - if tail_length > 0 and not family_completions: - tail_lengths.append(tail_length) - for completion in family_completions: - completion_column_lengths.append(tail_length + completion.length) - boundary_work, boundary_bucket_count = _padded_work_from_lengths( - tuple(boundary_lengths), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - tail_work, tail_bucket_count = _padded_work_from_lengths( - tuple(tail_lengths), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - completion_work, completion_bucket_count = _padded_work_from_lengths( - tuple(completion_column_lengths), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - return ( - boundary_work + tail_work + 
completion_work, - boundary_bucket_count + tail_bucket_count + completion_bucket_count, - ) - - -def _estimate_chain_rank_kernel_work( - *, - cp_size: int, - chain_prefix_segments: tuple[GdnSegmentSpec, ...], - chain_completion_segments: tuple[GdnSegmentSpec, ...], - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - planner_config: GdnPlannerConfig, -) -> tuple[tuple[int, ...], int]: - rank_work = [0] * cp_size - bucket_count = 0 - for segments in (chain_prefix_segments, chain_completion_segments): - buckets = _batch_segments_by_padded_work( - segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - bucket_count += len(buckets) - for bucket in buckets: - for rank in range(cp_size): - lengths = tuple( - chain_rank_counts_by_key[_segment_key(segment)][rank] - for segment in bucket - ) - rank_work[rank] += max(lengths, default=0) * len(lengths) - return tuple(rank_work), bucket_count - - -def _padded_work_from_lengths( - lengths: tuple[int, ...], - *, - max_padding_ratio: float, - max_segments_per_batch: int, -) -> tuple[int, int]: - if not lengths: - return 0, 0 - ordered = sorted(length for length in lengths if length > 0) - if not ordered: - return 0, 0 - bucket_count = 0 - padded_work = 0 - current_count = 0 - current_tokens = 0 - current_max = 0 - for length in ordered: - next_count = current_count + 1 - next_tokens = current_tokens + length - next_max = max(current_max, length) - next_padded = next_max * next_count - can_extend = current_count == 0 or ( - next_count <= max_segments_per_batch - and next_padded <= max_padding_ratio * next_tokens - ) - if not can_extend: - bucket_count += 1 - padded_work += current_max * current_count - current_count = 0 - current_tokens = 0 - current_max = 0 - current_count += 1 - current_tokens += length - current_max = max(current_max, length) - if current_count: - bucket_count += 1 - padded_work += current_max * current_count - 
return padded_work, bucket_count - - -def _add_chain_search_load( - rank_loads: list[int], - segment: GdnSegmentSpec, - *, - chain_rank_counts_by_key: dict[GdnSegmentDecisionKey, tuple[int, ...]], - chain_cross_rank_tokens_by_key: dict[GdnSegmentDecisionKey, int], -) -> int: - key = _segment_key(segment) - for rank, token_count in enumerate(chain_rank_counts_by_key[key]): - rank_loads[rank] += token_count - return chain_cross_rank_tokens_by_key[key] - - -def _add_local_search_load( - rank_loads: list[int], - rank: int, - segment: GdnSegmentSpec, - *, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], -) -> int: - rank_loads[rank] += segment.length - return segment.length - segment_attention_counts[_segment_key(segment)][rank] - - -def _chain_segment_rank_counts_and_cross_rank_tokens( - segment: GdnSegmentSpec, - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_layout_index: _AttentionLayoutIndex, -) -> tuple[tuple[int, ...], int]: - token_start = _segment_token_start(segment, spec.sequence_length) - attention_shards = _attention_contiguous_chain_shards( - token_start, - segment.length, - cp_size=cp_size, - attention_layout_index=attention_layout_index, - ) - if attention_shards is not None: - return tuple(len(shard) for shard in attention_shards), 0 - shard_lengths = _fla_aligned_chain_shard_lengths(segment.length, cp_size=cp_size) - cross_rank_tokens = 0 - start = 0 - for rank, shard_length in enumerate(shard_lengths): - end = start + shard_length - shard_start = token_start + start - cross_rank_tokens += shard_length - _attention_overlap_count( - attention_layout_index, - rank, - shard_start, - shard_start + shard_length, - ) - start = end - return shard_lengths, cross_rank_tokens - - -def _materialize_cp_segment_schedule( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - attention_layout_index: _AttentionLayoutIndex, - segment_attention_counts: dict[tuple[int, int, int], tuple[int, ...]], - chain_segment_keys: 
frozenset[GdnSegmentDecisionKey], - co_locate_local_families: bool, - planner_config: GdnPlannerConfig, -) -> GdnCpSegmentSchedule: - gdn_ranges_by_rank: list[list[tuple[int, int, int]]] = [[] for _ in range(cp_size)] - rank_loads = [0] * cp_size - local_prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - local_completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - chain_prefix_segments: list[GdnSegmentSpec] = [] - chain_completion_segments: list[GdnSegmentSpec] = [] - parent_state_exchange_families: set[int] = set() - parent_state_transfer_families: dict[tuple[int, int], set[int]] = {} - cross_rank_token_count = 0 - - for family in spec.families: - prefix_key = _segment_key(family.prefix) - chain_prefix = prefix_key in chain_segment_keys - local_completions = tuple( - completion - for completion in family.completions - if _segment_key(completion) not in chain_segment_keys - ) - prefix_owner: int | None = None - if chain_prefix: - chain_prefix_segments.append(family.prefix) - cross_rank_token_count += _append_chain_segment( - gdn_ranges_by_rank, - rank_loads, - family.prefix, - spec, - attention_layout_index=attention_layout_index, - ) - else: - owner_segments = ( - (family.prefix, *local_completions) - if co_locate_local_families - else (family.prefix,) - ) - prefix_owner = _best_segment_owner( - owner_segments, - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - local_prefix_segments_by_rank[prefix_owner].append(family.prefix) - cross_rank_token_count += _append_local_segment( - gdn_ranges_by_rank, - rank_loads, - prefix_owner, - family.prefix, - spec, - segment_attention_counts=segment_attention_counts, - ) - for completion in family.completions: - if _segment_key(completion) in chain_segment_keys: - chain_completion_segments.append(completion) - cross_rank_token_count += _append_chain_segment( - gdn_ranges_by_rank, - rank_loads, - 
completion, - spec, - attention_layout_index=attention_layout_index, - ) - if not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "local-prefix/chained-completion planning lost the prefix owner" - ) - parent_state_exchange_families.add(family.family_index) - for dest_rank in range(cp_size): - if dest_rank == prefix_owner: - continue - parent_state_transfer_families.setdefault( - (prefix_owner, dest_rank), set() - ).add(family.family_index) - continue - if co_locate_local_families and not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "co-located local completion planning lost the prefix owner" - ) - owner = prefix_owner - else: - owner = _best_segment_owner( - (completion,), - rank_loads, - segment_attention_counts=segment_attention_counts, - planner_config=planner_config, - ) - if not chain_prefix: - if prefix_owner is None: - raise RuntimeError( - "local completion planning lost the prefix owner" - ) - if owner != prefix_owner: - parent_state_exchange_families.add(family.family_index) - parent_state_transfer_families.setdefault( - (prefix_owner, owner), set() - ).add(family.family_index) - local_completion_segments_by_rank[owner].append(completion) - cross_rank_token_count += _append_local_segment( - gdn_ranges_by_rank, + for segment in segments: + owner_by_node[segment.family_index] = owner + segments_by_rank_depth[owner][ + spec.tree_depths[segment.family_index] + ].append(segment) + cross_rank_token_count += _append_local_segment( + gdn_ranges_by_rank, rank_loads, owner, - completion, - spec, - segment_attention_counts=segment_attention_counts, - ) - - return GdnCpSegmentSchedule.model_construct( - gdn_token_counts_by_rank=tuple(rank_loads), - gdn_token_ranges_by_rank=tuple(tuple(ranges) for ranges in gdn_ranges_by_rank), - cross_rank_token_count=cross_rank_token_count, - chain_prefix_buckets=_batch_segments_by_padded_work( - tuple(chain_prefix_segments), - max_padding_ratio=planner_config.max_padding_ratio, - 
max_segments_per_batch=planner_config.max_segments_per_batch, - ), - chain_completion_buckets=_batch_segments_by_padded_work( - tuple(chain_completion_segments), - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ), - local_prefix_segments_by_rank=tuple( - tuple(segments) for segments in local_prefix_segments_by_rank - ), - local_completion_segments_by_rank=tuple( - tuple(segments) for segments in local_completion_segments_by_rank - ), - parent_state_exchange_family_indices=tuple( - sorted(parent_state_exchange_families) - ), - parent_state_transfers=_build_parent_state_transfer_plans( - parent_state_transfer_families - ), - ) - - -def _build_local_family_rank_execution_plan( - spec: GdnPackedExecutionSpec, - *, - device: torch.device | str, - cp_rank: int, - cp_size: int, - planner_config: GdnPlannerConfig, -) -> GdnRankExecutionPlan | None: - if cp_size <= 1 or not spec.families: - return None - target_rank_load = spec.real_token_count / cp_size - loads = [0] * cp_size - prefix_owner_by_family: list[int] = [] - completion_owners_by_family: list[tuple[int, ...]] = [] - for family in spec.families: - if _has_chainable_segment( - family, cp_size=cp_size, planner_config=planner_config - ): - return None - prefix_locality_limit = max( - planner_config.max_zero_exchange_load_imbalance * target_rank_load, - min(64.0, float(spec.real_token_count)), - ) - if family.prefix.length > prefix_locality_limit: - return None - owner = _least_loaded_rank(loads) - prefix_owner_by_family.append(owner) - completion_owners_by_family.append(tuple(owner for _ in family.completions)) - loads[owner] += family.token_count - - if max(loads, default=0) > ( - planner_config.local_completion_rebalance_min_imbalance * target_rank_load - ): - completion_owners_by_family = list( - _rebalance_local_completion_segments( + segment, spec, - prefix_owner_by_family=tuple(prefix_owner_by_family), - 
completion_owners_by_family=tuple(completion_owners_by_family), - initial_loads=tuple(loads), + segment_attention_counts=segment_attention_counts, + ) + + subtree_token_counts = [segment.length for segment in spec.tree_segments] + for node_index in reversed(range(len(spec.tree_segments))): + for child_index in children_by_node[node_index]: + subtree_token_counts[node_index] += subtree_token_counts[child_index] + target_rank_load = spec.real_token_count / max(1, cp_size) + max_local_group_tokens = max(1, int(target_rank_load)) + + def assign_tree(root_index: int) -> None: + nonlocal cross_rank_token_count + root = spec.tree_segments[root_index] + if ( + spec.tree_parent_indices[root_index] < 0 + and cp_size > 1 + and _can_chain_tree_segment( + root, + cp_size=cp_size, planner_config=planner_config, ) - ) - rank_assignments = _materialize_local_family_rank_assignments( - spec, - cp_size=cp_size, - prefix_owner_by_family=tuple(prefix_owner_by_family), - completion_owners_by_family=tuple(completion_owners_by_family), - ) - local_token_count, local_token_ranges, prefix_segments, completion_segments = ( - rank_assignments[cp_rank] - ) - parent_state_transfer_families: dict[tuple[int, int], set[int]] = {} - for family in spec.families: - prefix_owner = prefix_owner_by_family[family.family_index] - completion_owners = completion_owners_by_family[family.family_index] - for completion_owner in sorted(set(completion_owners)): - if completion_owner == prefix_owner: - continue - parent_state_transfer_families.setdefault( - (prefix_owner, completion_owner), set() - ).add(family.family_index) - - from art.megatron.gdn.layout import GdnCpExchangePlan, GdnCpPeerTransfer - - token_counts_by_rank = tuple(assignment[0] for assignment in rank_assignments) - identity_exchange = GdnCpExchangePlan.model_construct( - cp_size=cp_size, - source_token_counts_by_rank=token_counts_by_rank, - dest_token_counts_by_rank=token_counts_by_rank, - transfers=tuple( - GdnCpPeerTransfer.model_construct( 
- source_rank=rank, - dest_rank=rank, - token_count=token_count, - source_positions_tensor=None, - dest_positions_tensor=None, + ): + chained_nodes[root.family_index] = True + chain_segments_by_depth[spec.tree_depths[root.family_index]].append(root) + cross_rank_token_count += _append_chain_segment( + gdn_ranges_by_rank, + rank_loads, + root, + spec, + attention_layout_index=attention_layout_index, ) - for rank, token_count in enumerate(token_counts_by_rank) - if token_count - ), + for child_index in children_by_node[root_index]: + assign_tree(child_index) + return + + if subtree_token_counts[root_index] <= max_local_group_tokens: + assign_local_group(subtree_indices(root_index)) + return + + assign_local_group((root_index,)) + for child_index in children_by_node[root_index]: + assign_tree(child_index) + + for root_index in root_indices: + assign_tree(root_index) + + gdn_ranges_by_rank_by_position = tuple( + tuple(ranges) for ranges in gdn_ranges_by_rank ) - parent_state_exchange_family_indices = tuple( - sorted( - family_index - for family_indices in parent_state_transfer_families.values() - for family_index in family_indices - ) + gdn_ranges_by_rank_by_source = tuple( + tuple(sorted(ranges)) for ranges in gdn_ranges_by_rank ) - schedule = GdnCpSegmentSchedule.model_construct( - gdn_token_counts_by_rank=token_counts_by_rank, - gdn_token_ranges_by_rank=tuple( - assignment[1] for assignment in rank_assignments - ), - cross_rank_token_count=0, - chain_prefix_buckets=(), - chain_completion_buckets=(), - local_prefix_segments_by_rank=tuple( - assignment[2] for assignment in rank_assignments - ), - local_completion_segments_by_rank=tuple( - assignment[3] for assignment in rank_assignments - ), - parent_state_exchange_family_indices=parent_state_exchange_family_indices, - parent_state_transfers=_build_parent_state_transfer_plans( - parent_state_transfer_families - ), + + attention_to_gdn = build_local_rank_cp_exchange_plan_from_dest_ranges( + source_layout=source_layout, 
+ device=device, + local_rank=cp_rank, + dest_ranges_by_rank=gdn_ranges_by_rank_by_position, + cross_rank_token_count=cross_rank_token_count, ) - if parent_state_exchange_family_indices: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _build_remote_prefix_tail_plans( - spec, - schedule, - cp_rank=cp_rank, + local_token_ranges = gdn_ranges_by_rank_by_source[cp_rank] + tree_segment_buckets_by_depth = tuple( + _build_tree_bucket_plans( + tuple(segments_by_rank_depth[cp_rank][depth]), + spec.tree_parent_indices, + tuple(tree_has_children), + local_token_ranges=None if cp_size == 1 else local_token_ranges, + sequence_length=spec.sequence_length, device=device, planner_config=planner_config, ) - else: - ( - remote_prefix_tail_buckets, - remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers, - remote_prefix_tail_families, - ) = _empty_remote_prefix_tail_plans() - local_prefix_family_indices = {segment.family_index for segment in prefix_segments} - chunk_local_completion_segments = tuple( - segment - for segment in completion_segments - if segment.family_index in local_prefix_family_indices - ) - suffix_only_completion_segments = tuple( - segment - for segment in completion_segments - if segment.family_index not in local_prefix_family_indices - and segment.family_index not in remote_prefix_tail_families + for depth in range(depth_count) ) - ready_completion_segments, remote_completion_segments = ( - _split_ready_and_remote_completion_segments( - suffix_only_completion_segments, - local_prefix_segments=(), - chain_prefix_buckets=(), + tree_chain_buckets_by_depth = ( + tuple( + _build_tree_bucket_plans( + tuple(chain_segments_by_depth[depth]), + spec.tree_parent_indices, + tuple(tree_has_children), + 
local_token_ranges=local_token_ranges, + sequence_length=spec.sequence_length, + device=device, + planner_config=planner_config, + token_ranges_by_rank=tuple( + tuple(ranges) for ranges in gdn_ranges_by_rank_by_source + ), + split_by_final_state=False, + ) + for depth in range(depth_count) ) + if cp_size > 1 + else tuple(() for _ in range(depth_count)) ) - ready_completion_buckets = _batch_segments_by_padded_work( - ready_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - remote_completion_buckets = _batch_segments_by_padded_work( - remote_completion_segments, - max_padding_ratio=planner_config.max_padding_ratio, - max_segments_per_batch=planner_config.max_segments_per_batch, - ) - ready_completion_bucket_plans = _build_position_bucket_plans( - ready_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ) - remote_completion_bucket_plans = _build_position_bucket_plans( - remote_completion_buckets, - local_token_ranges, - sequence_length=spec.sequence_length, - device=device, - ) - local_completion_bucket_plans = ( - ready_completion_bucket_plans + remote_completion_bucket_plans - ) - ( - prefix_boundary_buckets, - prefix_tail_buckets, - completion_with_prefix_tail_buckets, - ) = _build_chunk_aligned_position_bucket_plans( - prefix_segments, - chunk_local_completion_segments, - local_token_ranges, - sequence_length=spec.sequence_length, + tree_state_exchanges_by_depth = _build_tree_state_exchanges_by_depth( + spec, + owner_by_node=tuple(owner_by_node), + chained_nodes=tuple(chained_nodes), + cp_rank=cp_rank, + cp_size=cp_size, + depth_count=depth_count, device=device, - planner_config=planner_config, ) - return GdnRankExecutionPlan.model_construct( + if cp_size == 1: + valid_lengths = torch.tensor( + spec.valid_lengths, device=device, dtype=torch.long + ) + positions = torch.arange(spec.sequence_length, device=device, 
dtype=torch.long) + real_token_mask = positions.unsqueeze(0) < valid_lengths.unsqueeze(1) + else: + real_token_mask = torch.ones( + 1, + rank_loads[cp_rank], + device=device, + dtype=torch.bool, + ) + + return GdnRankExecutionPlan( cp_rank=cp_rank, cp_size=cp_size, - batch_size=1, - sequence_length=local_token_count, + batch_size=1 if cp_size > 1 else spec.batch_size, + sequence_length=rank_loads[cp_rank] if cp_size > 1 else spec.sequence_length, packed_batch_size=spec.batch_size, packed_sequence_length=spec.sequence_length, - real_token_mask=torch.ones( - 1, local_token_count, device=device, dtype=torch.bool - ), - family_count=spec.family_count, - completion_count=spec.completion_count, - local_prefix_buckets=(), - local_completion_buckets=local_completion_bucket_plans, - ready_local_completion_buckets=ready_completion_bucket_plans, - remote_local_completion_buckets=remote_completion_bucket_plans, - chain_prefix_buckets=(), - chain_completion_buckets=(), - prefix_table_is_dense_ordered=( - tuple(segment.family_index for segment in prefix_segments) - == tuple(range(spec.family_count)) - ), - attention_to_gdn=identity_exchange, - gdn_to_attention=identity_exchange, - attention_token_ranges=local_token_ranges, - gdn_token_ranges=local_token_ranges, - attention_token_count=local_token_count, - gdn_token_count=local_token_count, - parent_state_exchange_family_indices=tuple( - family_index - for family_index in parent_state_exchange_family_indices - if family_index not in remote_prefix_tail_families - ), - parent_state_transfers=_filter_parent_state_transfers( - _build_parent_state_transfer_plans(parent_state_transfer_families), - excluded_families=remote_prefix_tail_families, - device=device, - ), - prefix_boundary_buckets=prefix_boundary_buckets, - prefix_tail_buckets=prefix_tail_buckets, - completion_with_prefix_tail_buckets=completion_with_prefix_tail_buckets, - remote_prefix_tail_buckets=remote_prefix_tail_buckets, - 
remote_completion_with_prefix_tail_buckets=remote_completion_with_prefix_tail_buckets, - remote_prefix_tail_exchange=remote_prefix_tail_exchange, - remote_prefix_tail_backward_exchange=remote_prefix_tail_backward_exchange, - remote_prefix_tail_state_transfers=remote_prefix_tail_state_transfers, + real_token_mask=real_token_mask, + attention_to_gdn=attention_to_gdn, + gdn_to_attention=_reverse_exchange_plan(attention_to_gdn), + attention_token_ranges=source_layout.ownership_ranges_by_rank[cp_rank], + gdn_token_ranges=gdn_ranges_by_rank_by_position[cp_rank], + attention_token_count=source_layout.token_counts_by_rank[cp_rank], + gdn_token_count=rank_loads[cp_rank], + tree_segment_buckets_by_depth=tree_segment_buckets_by_depth, + tree_chain_buckets_by_depth=tree_chain_buckets_by_depth, + tree_state_exchanges_by_depth=tree_state_exchanges_by_depth, ) -def _rebalance_local_completion_segments( - spec: GdnPackedExecutionSpec, - *, - prefix_owner_by_family: tuple[int, ...], - completion_owners_by_family: tuple[tuple[int, ...], ...], - initial_loads: tuple[int, ...], - planner_config: GdnPlannerConfig, -) -> tuple[tuple[int, ...], ...]: - owners = [list(family_owners) for family_owners in completion_owners_by_family] - loads = list(initial_loads) - remote_owners_by_family = [ - { - owner - for owner in family_owners - if owner != prefix_owner_by_family[family_index] - } - for family_index, family_owners in enumerate(owners) - ] - transfer_count = sum( - len(remote_owners) for remote_owners in remote_owners_by_family - ) +def move_gdn_rank_execution_plan_to_device( + plan: GdnRankExecutionPlan, + device: torch.device | str, +) -> GdnRankExecutionPlan: + """Move planner tensors to the execution device after CPU planning.""" - def score(candidate_loads: list[int], candidate_transfer_count: int) -> float: - max_load = max(candidate_loads, default=0) - idle_tokens = sum(max_load - load for load in candidate_loads) - return ( - max_load - + planner_config.rank_idle_token_cost * 
idle_tokens - + planner_config.parent_state_exchange_penalty_tokens - * candidate_transfer_count - ) + from art.megatron.gdn.layout import move_cp_exchange_plan_to_device - best_score = score(loads, transfer_count) - while True: - best_move: ( - tuple[int, int, int, tuple[int, ...], list[int], int, float] | None - ) = None - for family in spec.families: - family_owners = owners[family.family_index] - prefix_owner = prefix_owner_by_family[family.family_index] - original_remote_owners = remote_owners_by_family[family.family_index] - for source in sorted(set(family_owners)): - source_children = [ - child_index - for child_index, owner in enumerate(family_owners) - if owner == source - ] - ordered_children = sorted( - source_children, - key=lambda child_index: family.completions[child_index].length, - reverse=True, - ) - for dest in range(len(loads)): - if dest == source: - continue - moved_tokens = 0 - moved_children = [] - for child_index in ordered_children: - moved_tokens += family.completions[child_index].length - moved_children.append(child_index) - candidate_loads = list(loads) - candidate_loads[source] -= moved_tokens - candidate_loads[dest] += moved_tokens - candidate_remote_owners = set(original_remote_owners) - if source != prefix_owner and len(moved_children) == len( - source_children - ): - candidate_remote_owners.discard(source) - if dest != prefix_owner: - candidate_remote_owners.add(dest) - candidate_transfer_count = ( - transfer_count - - len(original_remote_owners) - + len(candidate_remote_owners) - ) - candidate_score = score( - candidate_loads, candidate_transfer_count - ) - if candidate_score >= best_score: - continue - if best_move is None or candidate_score < best_move[-1]: - best_move = ( - family.family_index, - source, - dest, - tuple(moved_children), - candidate_loads, - candidate_transfer_count, - candidate_score, - ) - if best_move is None: - return tuple(tuple(item) for item in owners) - ( - family_index, - _source, - dest, - 
moved_children, - loads, - transfer_count, - best_score, - ) = best_move - for child_index in moved_children: - owners[family_index][child_index] = dest - prefix_owner = prefix_owner_by_family[family_index] - remote_owners_by_family[family_index] = { - owner for owner in set(owners[family_index]) if owner != prefix_owner - } - - -def _materialize_local_family_rank_assignments( - spec: GdnPackedExecutionSpec, - *, - cp_size: int, - prefix_owner_by_family: tuple[int, ...], - completion_owners_by_family: tuple[tuple[int, ...], ...], -) -> tuple[ - tuple[ - int, - tuple[tuple[int, int, int], ...], - tuple[GdnSegmentSpec, ...], - tuple[GdnSegmentSpec, ...], - ], - ..., -]: - token_ranges_by_rank: list[list[tuple[int, int, int]]] = [ - [] for _ in range(cp_size) - ] - token_counts_by_rank = [0] * cp_size - prefix_segments_by_rank: list[list[GdnSegmentSpec]] = [[] for _ in range(cp_size)] - completion_segments_by_rank: list[list[GdnSegmentSpec]] = [ - [] for _ in range(cp_size) - ] - sequence_length = spec.sequence_length - for family in spec.families: - prefix_owner = prefix_owner_by_family[family.family_index] - prefix_segments_by_rank[prefix_owner].append(family.prefix) - prefix_token_start = ( - family.prefix.row_index * sequence_length + family.prefix.start - ) - prefix_position_start = token_counts_by_rank[prefix_owner] - token_ranges_by_rank[prefix_owner].append( - ( - prefix_token_start, - prefix_token_start + family.prefix.length, - prefix_position_start, - ) - ) - token_counts_by_rank[prefix_owner] = ( - prefix_position_start + family.prefix.length - ) - for completion, completion_owner in zip( - family.completions, - completion_owners_by_family[family.family_index], - strict=True, - ): - completion_segments_by_rank[completion_owner].append(completion) - completion_token_start = ( - completion.row_index * sequence_length + completion.start - ) - completion_position_start = token_counts_by_rank[completion_owner] - token_ranges_by_rank[completion_owner].append( - 
( - completion_token_start, - completion_token_start + completion.length, - completion_position_start, - ) - ) - token_counts_by_rank[completion_owner] = ( - completion_position_start + completion.length - ) - return tuple( - ( - token_counts_by_rank[rank], - tuple(token_ranges_by_rank[rank]), - tuple(prefix_segments_by_rank[rank]), - tuple(completion_segments_by_rank[rank]), - ) - for rank in range(cp_size) + return replace( + plan, + real_token_mask=_move_planner_tensor(plan.real_token_mask, device), + attention_to_gdn=move_cp_exchange_plan_to_device(plan.attention_to_gdn, device), + gdn_to_attention=move_cp_exchange_plan_to_device(plan.gdn_to_attention, device), + tree_segment_buckets_by_depth=tuple( + _move_bucket_plans(buckets, device) + for buckets in plan.tree_segment_buckets_by_depth + ), + tree_chain_buckets_by_depth=tuple( + _move_bucket_plans(buckets, device) + for buckets in plan.tree_chain_buckets_by_depth + ), + tree_state_exchanges_by_depth=tuple( + _move_state_exchange_plan(exchange, device) + for exchange in plan.tree_state_exchanges_by_depth + ), ) -def _empty_local_family_rank_execution_plan( - spec: GdnPackedExecutionSpec, - *, +def _move_state_exchange_plan( + exchange: GdnStateExchangePlan | None, device: torch.device | str, - cp_rank: int, - cp_size: int, -) -> GdnRankExecutionPlan: - from art.megatron.gdn.layout import GdnCpExchangePlan +) -> GdnStateExchangePlan | None: + if exchange is None: + return None + from art.megatron.gdn.layout import move_cp_exchange_plan_to_device - identity_exchange = GdnCpExchangePlan.model_construct( - cp_size=cp_size, - source_token_counts_by_rank=tuple(0 for _ in range(cp_size)), - dest_token_counts_by_rank=tuple(0 for _ in range(cp_size)), - transfers=(), - ) - return GdnRankExecutionPlan.model_construct( - cp_rank=cp_rank, - cp_size=cp_size, - batch_size=1, - sequence_length=0, - packed_batch_size=spec.batch_size, - packed_sequence_length=spec.sequence_length, - real_token_mask=torch.ones(1, 0, 
device=device, dtype=torch.bool), - family_count=spec.family_count, - completion_count=spec.completion_count, - local_prefix_buckets=(), - local_completion_buckets=(), - ready_local_completion_buckets=(), - remote_local_completion_buckets=(), - chain_prefix_buckets=(), - chain_completion_buckets=(), - prefix_table_is_dense_ordered=False, - attention_to_gdn=identity_exchange, - gdn_to_attention=identity_exchange, - attention_token_ranges=(), - gdn_token_ranges=(), - attention_token_count=0, - gdn_token_count=0, - parent_state_exchange_family_indices=(), - parent_state_transfers=(), + return replace( + exchange, + exchange=move_cp_exchange_plan_to_device(exchange.exchange, device), + reverse_exchange=move_cp_exchange_plan_to_device( + exchange.reverse_exchange, device + ), ) -def _can_chain_segment( - segment: GdnSegmentSpec, - *, - cp_size: int, - planner_config: GdnPlannerConfig, -) -> bool: - min_tokens = ( - planner_config.cp_chain_min_prefix_only_tokens - if segment.kind == "prefix" - else planner_config.cp_chain_min_total_tokens - ) - if segment.length < min_tokens: - return False - if segment.length < cp_size: - return False - if segment.length // FLA_CHUNK_SIZE < cp_size: - return False - per_rank = segment.length / cp_size - if per_rank < planner_config.cp_chain_min_tokens_per_rank: - return False - return True - - -def _build_parent_state_transfer_plans( - families_by_peer: dict[tuple[int, int], set[int]], -) -> tuple[GdnParentStateTransferPlan, ...]: +def _move_bucket_plans( + buckets: tuple[GdnSegmentBucketPlan, ...], + device: torch.device | str, +) -> tuple[GdnSegmentBucketPlan, ...]: return tuple( - GdnParentStateTransferPlan( - source_rank=source_rank, - dest_rank=dest_rank, - family_indices=tuple(sorted(family_indices)), + replace( + bucket, + lengths=_move_planner_tensor(bucket.lengths, device), + real_mask=_move_planner_tensor(bucket.real_mask, device), + cu_seqlens=_move_planner_tensor(bucket.cu_seqlens, device), + 
row_indices=_move_planner_tensor(bucket.row_indices, device), + position_indices=_move_planner_tensor(bucket.position_indices, device), + family_indices=_move_planner_tensor(bucket.family_indices, device), + parent_indices=( + _move_planner_tensor(bucket.parent_indices, device) + if bucket.parent_indices is not None + else None + ), + output_mask=( + _move_planner_tensor(bucket.output_mask, device) + if bucket.output_mask is not None + else None + ), ) - for (source_rank, dest_rank), family_indices in sorted(families_by_peer.items()) - if source_rank != dest_rank and family_indices + for bucket in buckets ) -def _split_ready_and_remote_completion_segments( - completion_segments: tuple[GdnSegmentSpec, ...], - *, - local_prefix_segments: tuple[GdnSegmentSpec, ...], - chain_prefix_buckets: tuple[tuple[GdnSegmentSpec, ...], ...], -) -> tuple[tuple[GdnSegmentSpec, ...], tuple[GdnSegmentSpec, ...]]: - ready_family_indices = { - segment.family_index for segment in local_prefix_segments - } | {segment.family_index for bucket in chain_prefix_buckets for segment in bucket} - ready = [] - remote = [] - for segment in completion_segments: - if segment.family_index in ready_family_indices: - ready.append(segment) - else: - remote.append(segment) - return tuple(ready), tuple(remote) +def parse_gdn_shared_prefix_segments( + group_ids: torch.Tensor, + parent_ids: torch.Tensor, +) -> GdnPackedExecutionSpec: + """Parse ART packed shared-prefix metadata into generic GDN tree nodes.""" + groups = _rank2_long_cpu("group_ids", group_ids) + parents = _rank2_long_cpu("parent_ids", parent_ids) + if tuple(groups.shape) != tuple(parents.shape): + raise ValueError( + "group_ids and parent_ids must have the same shape, got " + f"{tuple(groups.shape)} and {tuple(parents.shape)}" + ) -def _transfer_plans_to_device( - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - device: torch.device | str, -) -> tuple[GdnParentStateTransferPlan, ...]: - return tuple( - transfer.model_copy( - update={ 
- "family_indices_tensor": _move_planner_tensor( - torch.tensor(transfer.family_indices, dtype=torch.long), - device, + batch_size, sequence_length = (int(groups.shape[0]), int(groups.shape[1])) + rows = parse_shared_prefix_tree(group_ids=groups, parent_ids=parents) + tree_segments: list[GdnSegmentSpec] = [] + tree_parent_indices: list[int] = [] + tree_depths: list[int] = [] + valid_lengths: list[int] = [] + node_by_row_group: dict[tuple[int, int], int] = {} + child_counts_by_parent: dict[int, int] = {} + + for row in rows: + valid_lengths.append(row.valid_tokens) + for segment in row.segments: + node_index = len(tree_segments) + is_root = segment.depth == 0 + parent_node_index = ( + -1 if is_root else node_by_row_group[(row.row_index, segment.parent_id)] + ) + child_index = None + if not is_root: + child_index = child_counts_by_parent.get(parent_node_index, 0) + child_counts_by_parent[parent_node_index] = child_index + 1 + tree_segments.append( + GdnSegmentSpec( + row_index=row.row_index, + family_index=node_index, + group_id=segment.group_id, + parent_id=segment.parent_id, + start=segment.start, + end=segment.end, + kind="prefix" if is_root else "completion", + child_index=child_index, ) - } - ) - for transfer in transfers + ) + tree_parent_indices.append(parent_node_index) + tree_depths.append(segment.depth) + node_by_row_group[(row.row_index, segment.group_id)] = node_index + + return GdnPackedExecutionSpec( + batch_size=batch_size, + sequence_length=sequence_length, + valid_lengths=tuple(valid_lengths), + tree_segments=tuple(tree_segments), + tree_parent_indices=tuple(tree_parent_indices), + tree_depths=tuple(tree_depths), ) -def _has_chainable_segment( - family: GdnPackedFamilySpec, +def _attention_source_layout( + spec: GdnPackedExecutionSpec, *, cp_size: int, + attention_token_layout_index: TokenLayoutIndex | None, planner_config: GdnPlannerConfig, -) -> bool: - return _can_chain_prefix_segment( - family.prefix, cp_size=cp_size, 
planner_config=planner_config - ) or any( - _can_chain_segment(completion, cp_size=cp_size, planner_config=planner_config) - for completion in family.completions +) -> TokenLayoutIndex: + if attention_token_layout_index is not None: + layout_cp_size = len(attention_token_layout_index.token_counts_by_rank) + layout_token_count = sum( + int(count) for count in attention_token_layout_index.token_counts_by_rank + ) + if layout_cp_size != cp_size: + raise ValueError( + "attention token layout index cp_size must match GDN cp_size, got " + f"{layout_cp_size} and {cp_size}" + ) + if layout_token_count != spec.real_token_count: + raise ValueError( + "attention token layout index token count must match GDN real token " + f"count, got {layout_token_count} and {spec.real_token_count}" + ) + return attention_token_layout_index + ranges_by_rank = _default_attention_layout_ranges( + spec, + cp_size=cp_size, + planner_config=planner_config, + ) + return TokenLayoutIndex( + ownership_ranges_by_rank=ranges_by_rank, + token_counts_by_rank=tuple( + sum(int(end) - int(start) for start, end, _ in ranges) + for ranges in ranges_by_rank + ), ) -def _can_chain_prefix_segment( +def _can_chain_tree_segment( segment: GdnSegmentSpec, *, cp_size: int, planner_config: GdnPlannerConfig, ) -> bool: - return _can_chain_segment(segment, cp_size=cp_size, planner_config=planner_config) - - -def _score_cp_segment_stats( - *, - rank_local_work: tuple[int, ...], - rank_chain_work: tuple[int, ...], - rank_real_tokens: tuple[int, ...], - cross_rank_token_count: int, - parent_state_exchange_family_count: int, - local_bucket_count: int, - local_segment_count: int, - chain_bucket_count: int, - planner_config: GdnPlannerConfig, -) -> float: - empty_rank_count = sum(1 for token_count in rank_real_tokens if token_count == 0) - return ( - _rank_kernel_ms( - rank_local_work, - rank_chain_work, - local_token_ms=planner_config.planner_local_token_ms, - chain_token_ms=planner_config.planner_chain_token_ms, + 
min_total_tokens = ( + min( + planner_config.cp_tree_chain_min_prefix_only_tokens, + planner_config.cp_chain_min_prefix_only_tokens, ) - + planner_config.planner_local_bucket_ms * local_bucket_count - + planner_config.planner_chain_bucket_ms * chain_bucket_count - + planner_config.planner_local_segment_ms * local_segment_count - + planner_config.planner_layout_cross_rank_token_ms * cross_rank_token_count - + ( - planner_config.planner_parent_state_exchange_base_ms - + planner_config.planner_parent_state_exchange_ms - * parent_state_exchange_family_count - if parent_state_exchange_family_count - else 0.0 + if segment.kind == "prefix" + else min( + planner_config.cp_tree_chain_min_total_tokens, + planner_config.cp_chain_min_total_tokens, ) - + planner_config.planner_empty_rank_ms * empty_rank_count ) - - -def _rank_kernel_ms( - rank_local_work: tuple[int, ...], - rank_chain_work: tuple[int, ...], - *, - local_token_ms: float, - chain_token_ms: float, -) -> float: - return max( - ( - local_work * local_token_ms + chain_work * chain_token_ms - for local_work, chain_work in zip( - rank_local_work, rank_chain_work, strict=True - ) - ), - default=0.0, + return ( + segment.length >= min_total_tokens + and segment.length >= cp_size + and segment.length // FLA_CHUNK_SIZE >= cp_size + and segment.length / cp_size >= planner_config.cp_chain_min_tokens_per_rank ) @@ -3336,11 +636,16 @@ def _best_segment_owner( for rank in range(rank_count): counts_by_rank[rank] += segment_counts[rank] on_rank_tokens = tuple(counts_by_rank) - best: tuple[float, int, int, int, int] | None = None + best: tuple[float, float, int, int, int, int] | None = None for rank, tokens in enumerate(on_rank_tokens): projected_loads = list(rank_loads) projected_loads[rank] += segment_length max_load = max(projected_loads, default=0) + target_load = sum(projected_loads) / max(1, len(projected_loads)) + overload = max( + 0.0, + max_load - planner_config.max_zero_exchange_load_imbalance * target_load, + ) 
idle_tokens = sum(max_load - load for load in projected_loads) cross_rank_tokens = segment_length - int(tokens) empty_rank_count = sum(1 for load in projected_loads if load == 0) @@ -3353,6 +658,7 @@ def _best_segment_owner( + empty_rank_count * planner_config.planner_empty_rank_ms ) candidate = ( + overload, score, max_load, cross_rank_tokens, @@ -3366,23 +672,130 @@ def _best_segment_owner( return best[-1] +def _build_tree_state_exchanges_by_depth( + spec: GdnPackedExecutionSpec, + *, + owner_by_node: tuple[int, ...], + chained_nodes: tuple[bool, ...], + cp_rank: int, + cp_size: int, + depth_count: int, + device: torch.device | str, +) -> tuple[GdnStateExchangePlan | None, ...]: + if cp_size <= 1: + return tuple(None for _ in range(depth_count)) + + from art.megatron.gdn.layout import ( + GdnCpExchangePlan, + _make_peer_transfer, + _reverse_exchange_plan, + ) + + families_by_depth_pair: list[dict[tuple[int, int], set[int]]] = [ + {} for _ in range(depth_count) + ] + for child_index, parent_index in enumerate(spec.tree_parent_indices): + if parent_index < 0 or chained_nodes[parent_index]: + continue + source_rank = owner_by_node[parent_index] + dest_rank = owner_by_node[child_index] + if source_rank < 0 or dest_rank < 0: + raise ValueError("tree state exchange requires every node to have an owner") + if source_rank == dest_rank: + continue + depth = spec.tree_depths[child_index] + families_by_depth_pair[depth].setdefault((source_rank, dest_rank), set()).add( + parent_index + ) + + state_exchanges: list[GdnStateExchangePlan | None] = [] + for pair_families in families_by_depth_pair: + if not pair_families: + state_exchanges.append(None) + continue + source_families_by_rank = [set[int]() for _ in range(cp_size)] + dest_families_by_rank = [set[int]() for _ in range(cp_size)] + for (source_rank, dest_rank), parent_indices in pair_families.items(): + source_families_by_rank[source_rank].update(parent_indices) + dest_families_by_rank[dest_rank].update(parent_indices) + 
source_families = tuple( + tuple(sorted(families)) for families in source_families_by_rank + ) + dest_families = tuple( + tuple(sorted(families)) for families in dest_families_by_rank + ) + source_positions = ( + {family: index for index, family in enumerate(families)} + for families in source_families + ) + dest_positions = ( + {family: index for index, family in enumerate(families)} + for families in dest_families + ) + source_position_by_rank = tuple(source_positions) + dest_position_by_rank = tuple(dest_positions) + transfers = [] + transfer_count = 0 + for (source_rank, dest_rank), parent_indices in sorted(pair_families.items()): + ordered = tuple(sorted(parent_indices)) + transfer_count += len(ordered) + transfers.append( + _make_peer_transfer( + source_rank=source_rank, + dest_rank=dest_rank, + source_positions=torch.tensor( + [ + source_position_by_rank[source_rank][family] + for family in ordered + ], + dtype=torch.long, + ), + dest_positions=torch.tensor( + [ + dest_position_by_rank[dest_rank][family] + for family in ordered + ], + dtype=torch.long, + ), + source_count=len(source_families[source_rank]), + dest_count=len(dest_families[dest_rank]), + device=device, + ) + ) + exchange = GdnCpExchangePlan( + cp_size=cp_size, + source_token_counts_by_rank=tuple( + len(families) for families in source_families + ), + dest_token_counts_by_rank=tuple( + len(families) for families in dest_families + ), + transfers=tuple(transfers), + cross_rank_token_count_override=transfer_count, + ) + state_exchanges.append( + GdnStateExchangePlan( + source_family_indices=source_families[cp_rank], + dest_family_indices=dest_families[cp_rank], + exchange=exchange, + reverse_exchange=_reverse_exchange_plan(exchange), + ) + ) + return tuple(state_exchanges) + + def _build_attention_layout_index_from_token_layout( layout: TokenLayoutIndex, - *, - max_ranges: int, ) -> _AttentionLayoutIndex: - del max_ranges ranges_by_rank = tuple( tuple(sorted((int(start), int(end)) for start, end, 
_ in rank_ranges)) for rank_ranges in layout.ownership_ranges_by_rank ) - range_count = sum(len(ranges) for ranges in ranges_by_rank) - return _AttentionLayoutIndex.model_construct( + return _AttentionLayoutIndex( token_ranges_by_rank=ranges_by_rank, token_range_ends_by_rank=tuple( tuple(end for _, end in ranges) for ranges in ranges_by_rank ), - range_count=range_count, ) @@ -3393,7 +806,7 @@ def _segment_attention_rank_counts( attention_layout_index: _AttentionLayoutIndex, ) -> dict[tuple[int, int, int], tuple[int, ...]]: del cp_size - segments = tuple(spec.segments()) + segments = spec.tree_segments if not segments: return {} starts = torch.tensor( @@ -3471,74 +884,25 @@ def should_split_segment(segment: GdnSegmentSpec) -> bool: if segment.length <= planner_config.max_zero_exchange_load_imbalance * ( target_rank_load ): - return False - if segment.kind == "prefix": - return _can_chain_prefix_segment( - segment, cp_size=cp_size, planner_config=planner_config - ) - return _can_chain_segment( - segment, cp_size=cp_size, planner_config=planner_config - ) - - for family in spec.families: - has_split_segment = any( - should_split_segment(segment) - for segment in (family.prefix, *family.completions) - ) - if not has_split_segment: - if _should_co_locate_non_chain_family( - family, - total_real_tokens=spec.real_token_count, - cp_size=cp_size, - planner_config=planner_config, - ): - owner = _least_loaded_rank(loads) - for segment in (family.prefix, *family.completions): - token_start = _segment_token_start(segment, spec.sequence_length) - append_segment(owner, token_start, segment.length) - continue - for segment in (family.prefix, *family.completions): - token_start = _segment_token_start(segment, spec.sequence_length) - owner = _least_loaded_rank(loads) - append_segment(owner, token_start, segment.length) - continue - for segment in (family.prefix, *family.completions): - token_start = _segment_token_start(segment, spec.sequence_length) - if 
should_split_segment(segment): - _append_split_default_attention_segment( - ranks, loads, token_start, segment.length - ) - continue - owner = _least_loaded_rank(loads) - append_segment(owner, token_start, segment.length) - return tuple(tuple(ranges) for ranges in ranks) - - -def _should_co_locate_non_chain_family( - family: GdnPackedFamilySpec, - *, - total_real_tokens: int, - cp_size: int, - planner_config: GdnPlannerConfig, -) -> bool: - target_rank_load = total_real_tokens / cp_size - return family.token_count <= ( - planner_config.max_zero_exchange_load_imbalance * target_rank_load - ) - + return False + return _can_chain_tree_segment( + segment, cp_size=cp_size, planner_config=planner_config + ) -def _append_split_default_attention_segment( - ranks: list[list[tuple[int, int, int]]], - loads: list[int], - token_start: int, - token_count: int, -) -> None: - cp_size = len(ranks) - for rank in range(cp_size): - start = (token_count * rank) // cp_size - end = (token_count * (rank + 1)) // cp_size - ranks[rank].append((token_start + start, token_start + end, loads[rank])) - loads[rank] += end - start + for segment in spec.tree_segments: + token_start = _segment_token_start(segment, spec.sequence_length) + if should_split_segment(segment): + for rank in range(cp_size): + start = (segment.length * rank) // cp_size + end = (segment.length * (rank + 1)) // cp_size + ranks[rank].append( + (token_start + start, token_start + end, loads[rank]) + ) + loads[rank] += end - start + continue + owner = _least_loaded_rank(loads) + append_segment(owner, token_start, segment.length) + return tuple(tuple(ranges) for ranges in ranks) def _append_chain_segment( @@ -3581,36 +945,16 @@ def _append_chain_segment( ) rank_loads[rank] += shard_length if attention_layout_index is not None: - cross_rank_tokens += shard_length - _attention_overlap_count( - attention_layout_index, - rank, + cross_rank_tokens += shard_length - _range_overlap_count( shard_start, shard_start + shard_length, + 
attention_layout_index.token_ranges_by_rank[rank], + attention_layout_index.token_range_ends_by_rank[rank], ) start = end return cross_rank_tokens -def _chain_rank_token_indices( - segment: GdnSegmentSpec, - spec: GdnPackedExecutionSpec, - *, - cp_rank: int, - cp_size: int, -) -> range: - token_start = _segment_token_start(segment, spec.sequence_length) - lengths = _fla_aligned_chain_shard_lengths(segment.length, cp_size=cp_size) - start = sum(lengths[:cp_rank]) - end = start + lengths[cp_rank] - if start >= end: - raise ValueError( - "CP chain planning requires non-empty shards; " - f"segment={segment.kind}:{segment.family_index} " - f"length={segment.length} cp_size={cp_size}" - ) - return range(token_start + start, token_start + end) - - def _fla_aligned_chain_shard_lengths(length: int, *, cp_size: int) -> tuple[int, ...]: full_chunks = int(length) // FLA_CHUNK_SIZE if full_chunks < int(cp_size): @@ -3641,15 +985,14 @@ def _attention_contiguous_chain_shards( shards: list[range] = [] cursor = token_start for rank in range(cp_size): - overlap = _attention_single_contiguous_overlap( - attention_layout_index, - rank, + overlaps = _range_overlaps( token_start, segment_end, + attention_layout_index.token_ranges_by_rank[rank], ) - if overlap is None: + if len(overlaps) != 1: return None - start, end = overlap + start, end = overlaps[0] if start != cursor or end <= start: return None shards.append(range(start, end)) @@ -3661,18 +1004,6 @@ def _attention_contiguous_chain_shards( return tuple(shards) -def _attention_single_contiguous_overlap( - index: _AttentionLayoutIndex, - rank: int, - start: int, - end: int, -) -> tuple[int, int] | None: - overlaps = _range_overlaps(start, end, index.token_ranges_by_rank[rank]) - if len(overlaps) != 1: - return None - return overlaps[0] - - def _append_local_segment( gdn_ranges_by_rank: list[list[tuple[int, int, int]]], rank_loads: list[int], @@ -3695,162 +1026,154 @@ def _least_loaded_rank(rank_loads: list[int]) -> int: return 
min(range(len(rank_loads)), key=lambda rank: (rank_loads[rank], rank)) -def _owner_rank( - local_prefix_segments_by_rank: list[list[GdnSegmentSpec]], - prefix: GdnSegmentSpec, -) -> int: - for rank, segments in enumerate(local_prefix_segments_by_rank): - if prefix in segments: - return rank - raise RuntimeError("local prefix owner was not recorded") - - -def _build_position_bucket_plans( - segment_buckets: tuple[tuple[GdnSegmentSpec, ...], ...], - local_token_ranges: tuple[tuple[int, int, int], ...], +def _build_tree_bucket_plans( + segments: tuple[GdnSegmentSpec, ...], + tree_parent_indices: tuple[int, ...], + tree_has_children: tuple[bool, ...], *, + local_token_ranges: tuple[tuple[int, int, int], ...] | None, sequence_length: int, device: torch.device | str, + planner_config: GdnPlannerConfig, token_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...] | None = None, + split_by_final_state: bool = True, ) -> tuple[GdnSegmentBucketPlan, ...]: + segment_buckets = ( + _batch_tree_segments_by_padded_work( + segments, + tree_has_children, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + if split_by_final_state + else _batch_segments_by_padded_work( + segments, + max_padding_ratio=planner_config.max_padding_ratio, + max_segments_per_batch=planner_config.max_segments_per_batch, + ) + ) return tuple( - _build_position_bucket_plan( + _bucket_with_tree_parent_indices( + ( + _build_segment_bucket_plan(bucket, device=device) + if local_token_ranges is None + else _build_position_bucket_plan( + bucket, + local_token_ranges, + sequence_length=sequence_length, + device=device, + token_ranges_by_rank=token_ranges_by_rank, + ) + ), bucket, - local_token_ranges, - sequence_length=sequence_length, + tree_parent_indices, + tree_has_children, device=device, - token_ranges_by_rank=token_ranges_by_rank, ) for bucket in segment_buckets ) -def _build_position_bucket_plan( +def 
_bucket_with_tree_parent_indices( + plan: GdnSegmentBucketPlan, segments: tuple[GdnSegmentSpec, ...], - local_token_ranges: tuple[tuple[int, int, int], ...], + tree_parent_indices: tuple[int, ...], + tree_has_children: tuple[bool, ...], *, - sequence_length: int, device: torch.device | str, - token_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...] | None = None, ) -> GdnSegmentBucketPlan: - exact_plan = _build_exact_range_position_bucket_plan( - segments, - local_token_ranges, - sequence_length=sequence_length, - device=device, - token_ranges_by_rank=token_ranges_by_rank, - ) - if exact_plan is not None: - return exact_plan - local_positions_by_segment = [] - lengths = [] - local_range_ends = tuple(token_end for _, token_end, _ in local_token_ranges) - for segment in segments: - positions = _local_positions_for_segment( - segment, - sequence_length=sequence_length, - local_token_ranges=local_token_ranges, - local_range_ends=local_range_ends, - ) - length = int(positions.numel()) - if not length: - raise ValueError( - "planned GDN bucket contains a segment with no local tokens; " - f"family={segment.family_index} kind={segment.kind}" - ) - local_positions_by_segment.append(positions) - lengths.append(length) - max_length = max(lengths) - lengths_cpu = torch.tensor(lengths, dtype=torch.long) - offsets_cpu = torch.arange(max_length, dtype=torch.long).unsqueeze(1) - real_mask_cpu = offsets_cpu < lengths_cpu.unsqueeze(0) - position_indices_cpu = torch.zeros(max_length, len(segments), dtype=torch.long) - for column, positions in enumerate(local_positions_by_segment): - position_indices_cpu[: int(positions.numel()), column] = positions - cu_seqlens_cpu = torch.cat( - [lengths_cpu.new_zeros(1), torch.cumsum(lengths_cpu, dim=0)] - ) - lengths_by_rank_cpu = _bucket_lengths_by_rank_cpu( - segments, - token_ranges_by_rank, - sequence_length=sequence_length, - ) - row_indices_cpu = torch.zeros(max_length, len(segments), dtype=torch.long) - family_indices_cpu = 
torch.tensor( - [segment.family_index for segment in segments], + parent_indices = torch.tensor( + [tree_parent_indices[segment.family_index] for segment in segments], dtype=torch.long, ) - return GdnSegmentBucketPlan.model_construct( - length=max_length, - lengths=_move_planner_tensor(lengths_cpu, device), - lengths_cpu=lengths_cpu, - lengths_by_rank_cpu=lengths_by_rank_cpu, - real_mask=_move_planner_tensor(real_mask_cpu, device), - cu_seqlens=_move_planner_tensor(cu_seqlens_cpu, device), - cu_seqlens_cpu=cu_seqlens_cpu, - row_indices=_move_planner_tensor(row_indices_cpu, device), - position_indices=_move_planner_tensor(position_indices_cpu, device), - family_indices=_move_planner_tensor(family_indices_cpu, device), - real_token_count_static=sum(lengths), + return replace( + plan, + parent_indices=_move_planner_tensor(parent_indices, device), + parent_indices_cpu=parent_indices, + needs_final_state=any( + tree_has_children[segment.family_index] for segment in segments + ), ) -def _build_exact_range_position_bucket_plan( +def _build_position_bucket_plan( segments: tuple[GdnSegmentSpec, ...], local_token_ranges: tuple[tuple[int, int, int], ...], *, sequence_length: int, device: torch.device | str, token_ranges_by_rank: tuple[tuple[tuple[int, int, int], ...], ...] 
| None = None, -) -> GdnSegmentBucketPlan | None: +) -> GdnSegmentBucketPlan: range_positions = { (start, end): position for start, end, position in local_token_ranges } - starts = [] - lengths = [] + starts: list[int] = [] + lengths: list[int] = [] for segment in segments: token_start = _segment_token_start(segment, sequence_length) token_end = token_start + segment.length position_start = range_positions.get((token_start, token_end)) if position_start is None: - return None + break starts.append(position_start) lengths.append(segment.length) - max_length = max(lengths) - starts_cpu = torch.tensor(starts, dtype=torch.long) - lengths_cpu = torch.tensor(lengths, dtype=torch.long) - offsets_cpu = torch.arange(max_length, dtype=torch.long).unsqueeze(1) - real_mask_cpu = offsets_cpu < lengths_cpu.unsqueeze(0) - position_indices_cpu = torch.where( - real_mask_cpu, - starts_cpu.unsqueeze(0) + offsets_cpu, - torch.zeros_like(offsets_cpu), - ) - cu_seqlens_cpu = torch.cat( - [lengths_cpu.new_zeros(1), torch.cumsum(lengths_cpu, dim=0)] - ) - lengths_by_rank_cpu = _bucket_lengths_by_rank_cpu( - segments, - token_ranges_by_rank, - sequence_length=sequence_length, - ) - row_indices_cpu = torch.zeros(max_length, len(segments), dtype=torch.long) - family_indices_cpu = torch.tensor( - [segment.family_index for segment in segments], + else: + starts_cpu = torch.tensor(starts, dtype=torch.long) + lengths_cpu = torch.tensor(lengths, dtype=torch.long) + offsets_cpu = torch.arange(max(lengths), dtype=torch.long).unsqueeze(1) + position_indices_cpu = torch.where( + offsets_cpu < lengths_cpu.unsqueeze(0), + starts_cpu.unsqueeze(0) + offsets_cpu, + torch.zeros_like(offsets_cpu), + ) + return _build_bucket_plan( + segments, + lengths_cpu=lengths_cpu, + row_indices_cpu=torch.zeros_like(position_indices_cpu), + position_indices_cpu=position_indices_cpu, + lengths_by_rank_cpu=_bucket_lengths_by_rank_cpu( + segments, + token_ranges_by_rank, + sequence_length=sequence_length, + ), + 
device=device, + ) + + local_positions_by_segment: list[torch.Tensor] = [] + local_range_ends = tuple(token_end for _, token_end, _ in local_token_ranges) + for segment in segments: + positions = _local_positions_for_segment( + segment, + sequence_length=sequence_length, + local_token_ranges=local_token_ranges, + local_range_ends=local_range_ends, + ) + if not int(positions.numel()): + raise ValueError( + "planned GDN bucket contains a segment with no local tokens; " + f"family={segment.family_index} kind={segment.kind}" + ) + local_positions_by_segment.append(positions) + + lengths_cpu = torch.tensor( + [int(positions.numel()) for positions in local_positions_by_segment], dtype=torch.long, ) - return GdnSegmentBucketPlan.model_construct( - length=max_length, - lengths=_move_planner_tensor(lengths_cpu, device), + max_length = int(lengths_cpu.max().item()) + position_indices_cpu = torch.zeros(max_length, len(segments), dtype=torch.long) + for column, positions in enumerate(local_positions_by_segment): + position_indices_cpu[: int(positions.numel()), column] = positions + return _build_bucket_plan( + segments, lengths_cpu=lengths_cpu, - lengths_by_rank_cpu=lengths_by_rank_cpu, - real_mask=_move_planner_tensor(real_mask_cpu, device), - cu_seqlens=_move_planner_tensor(cu_seqlens_cpu, device), - cu_seqlens_cpu=cu_seqlens_cpu, - row_indices=_move_planner_tensor(row_indices_cpu, device), - position_indices=_move_planner_tensor(position_indices_cpu, device), - family_indices=_move_planner_tensor(family_indices_cpu, device), - real_token_count_static=sum(lengths), + row_indices_cpu=torch.zeros_like(position_indices_cpu), + position_indices_cpu=position_indices_cpu, + lengths_by_rank_cpu=_bucket_lengths_by_rank_cpu( + segments, + token_ranges_by_rank, + sequence_length=sequence_length, + ), + device=device, ) @@ -3927,41 +1250,86 @@ def _batch_segments_by_padded_work( return tuple(tuple(batch) for batch in batches) +def _batch_tree_segments_by_padded_work( + segments: 
tuple[GdnSegmentSpec, ...], + tree_has_children: tuple[bool, ...], + *, + max_padding_ratio: float = 1.25, + max_segments_per_batch: int = 128, +) -> tuple[tuple[GdnSegmentSpec, ...], ...]: + stateful = tuple( + segment for segment in segments if tree_has_children[segment.family_index] + ) + stateless = tuple( + segment for segment in segments if not tree_has_children[segment.family_index] + ) + return ( + *_batch_segments_by_padded_work( + stateful, + max_padding_ratio=max_padding_ratio, + max_segments_per_batch=max_segments_per_batch, + ), + *_batch_segments_by_padded_work( + stateless, + max_padding_ratio=max_padding_ratio, + max_segments_per_batch=max_segments_per_batch, + ), + ) + + def _build_segment_bucket_plan( - length: int, segments: tuple[GdnSegmentSpec, ...], *, device: torch.device | str + segments: tuple[GdnSegmentSpec, ...], *, device: torch.device | str ) -> GdnSegmentBucketPlan: - max_length = max(segment.length for segment in segments) lengths_cpu = torch.tensor( [segment.length for segment in segments], dtype=torch.long ) + max_length = int(lengths_cpu.max().item()) starts_cpu = torch.tensor([segment.start for segment in segments], dtype=torch.long) rows_cpu = torch.tensor( [segment.row_index for segment in segments], dtype=torch.long ) offsets_cpu = torch.arange(max_length, dtype=torch.long).unsqueeze(1) + return _build_bucket_plan( + segments, + lengths_cpu=lengths_cpu, + row_indices_cpu=rows_cpu.unsqueeze(0).expand(max_length, -1).contiguous(), + position_indices_cpu=starts_cpu.unsqueeze(0) + offsets_cpu, + device=device, + ) + + +def _build_bucket_plan( + segments: tuple[GdnSegmentSpec, ...], + *, + lengths_cpu: torch.Tensor, + row_indices_cpu: torch.Tensor, + position_indices_cpu: torch.Tensor, + device: torch.device | str, + lengths_by_rank_cpu: torch.Tensor | None = None, +) -> GdnSegmentBucketPlan: + max_length = int(lengths_cpu.max().item()) + offsets_cpu = torch.arange(max_length, dtype=torch.long).unsqueeze(1) real_mask_cpu = 
offsets_cpu < lengths_cpu.unsqueeze(0) - positions_cpu = starts_cpu.unsqueeze(0) + offsets_cpu + cu_seqlens_cpu = torch.cat( + [lengths_cpu.new_zeros(1), torch.cumsum(lengths_cpu, dim=0)] + ) family_indices_cpu = torch.tensor( [segment.family_index for segment in segments], dtype=torch.long, ) - cu_seqlens_cpu = torch.cat( - [lengths_cpu.new_zeros(1), torch.cumsum(lengths_cpu, dim=0)] - ) - return GdnSegmentBucketPlan.model_construct( + return GdnSegmentBucketPlan( length=max_length, lengths=_move_planner_tensor(lengths_cpu, device), lengths_cpu=lengths_cpu, - lengths_by_rank_cpu=None, + lengths_by_rank_cpu=lengths_by_rank_cpu, real_mask=_move_planner_tensor(real_mask_cpu, device), cu_seqlens=_move_planner_tensor(cu_seqlens_cpu, device), cu_seqlens_cpu=cu_seqlens_cpu, - row_indices=_move_planner_tensor( - rows_cpu.unsqueeze(0).expand(max_length, -1).contiguous(), device - ), - position_indices=_move_planner_tensor(positions_cpu, device), + row_indices=_move_planner_tensor(row_indices_cpu, device), + position_indices=_move_planner_tensor(position_indices_cpu, device), family_indices=_move_planner_tensor(family_indices_cpu, device), - real_token_count_static=sum(segment.length for segment in segments), + family_indices_cpu=family_indices_cpu, + real_token_count_static=int(lengths_cpu.sum().item()), ) @@ -3969,20 +1337,6 @@ def _segment_token_start(segment: GdnSegmentSpec, sequence_length: int) -> int: return segment.row_index * sequence_length + segment.start -def _attention_overlap_count( - index: _AttentionLayoutIndex, - rank: int, - start: int, - end: int, -) -> int: - return _range_overlap_count( - start, - end, - index.token_ranges_by_rank[rank], - index.token_range_ends_by_rank[rank], - ) - - def _range_overlap_count( start: int, end: int, @@ -4012,27 +1366,6 @@ def _range_overlaps( return overlaps -def _local_token_ranges( - local_gdn_tokens: tuple[int, ...], -) -> tuple[tuple[int, int, int], ...]: - if not local_gdn_tokens: - return () - ranges = [] - 
token_start = local_gdn_tokens[0] - token_end = token_start + 1 - position_start = 0 - for position, token in enumerate(local_gdn_tokens[1:], start=1): - if token == token_end: - token_end += 1 - continue - ranges.append((token_start, token_end, position_start)) - token_start = token - token_end = token + 1 - position_start = position - ranges.append((token_start, token_end, position_start)) - return tuple(ranges) - - def _local_positions_for_segment( segment: GdnSegmentSpec, *, @@ -4079,285 +1412,3 @@ def _rank2_long_cpu(name: str, tensor: torch.Tensor) -> torch.Tensor: ): raise TypeError(f"{name} must contain integer ids, got dtype={tensor.dtype}") return tensor.detach().to(device="cpu", dtype=torch.long) - - -def _validate_padding_tensor( - row_index: int, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, -) -> int: - padding_positions = torch.nonzero(group_ids == -1, as_tuple=False) - valid_length = ( - int(padding_positions[0].item()) - if int(padding_positions.numel()) > 0 - else int(group_ids.numel()) - ) - if valid_length == 0: - if bool(torch.any(parent_ids != -1).item()): - raise ValueError(f"row {row_index}: padding parent_ids must be -1") - return 0 - if bool(torch.any(group_ids[valid_length:] != -1).item()): - raise ValueError( - f"row {row_index}: valid tokens must be contiguous before padding" - ) - if bool(torch.any(parent_ids[:valid_length] == -1).item()): - raise ValueError( - f"row {row_index}: valid tokens must have non-padding parent_ids" - ) - if bool(torch.any(parent_ids[valid_length:] != -1).item()): - raise ValueError(f"row {row_index}: padding parent_ids must be -1") - return valid_length - - -def _validate_padding( - row_index: int, - group_ids: list[int], - parent_ids: list[int], -) -> int: - valid_length = 0 - for group_id in group_ids: - if group_id == -1: - break - valid_length += 1 - if valid_length == 0: - if any(parent_id != -1 for parent_id in parent_ids): - raise ValueError(f"row {row_index}: padding parent_ids must be -1") 
- return 0 - if any(group_id != -1 for group_id in group_ids[valid_length:]): - raise ValueError( - f"row {row_index}: valid tokens must be contiguous before padding" - ) - if any(parent_id == -1 for parent_id in parent_ids[:valid_length]): - raise ValueError( - f"row {row_index}: valid tokens must have non-padding parent_ids" - ) - if any(parent_id != -1 for parent_id in parent_ids[valid_length:]): - raise ValueError(f"row {row_index}: padding parent_ids must be -1") - return valid_length - - -def _parse_row_tensor( - *, - row_index: int, - group_ids: torch.Tensor, - parent_ids: torch.Tensor, - valid_length: int, - first_family_index: int, - min_completions_per_family: int, -) -> list[GdnPackedFamilySpec]: - valid_groups = group_ids[:valid_length] - valid_parents = parent_ids[:valid_length] - if valid_length > 1: - same_group = valid_groups[1:] == valid_groups[:-1] - parent_changed = same_group & (valid_parents[1:] != valid_parents[:-1]) - if bool(torch.any(parent_changed).item()): - position = int(torch.nonzero(parent_changed, as_tuple=False)[0].item()) + 1 - group_id = int(valid_groups[position].item()) - previous_parent = int(valid_parents[position - 1].item()) - current_parent = int(valid_parents[position].item()) - raise ValueError( - f"row {row_index}: group {group_id} changes parent from " - f"{previous_parent} to {current_parent}" - ) - boundaries = torch.nonzero(~same_group, as_tuple=False).flatten() + 1 - starts_tensor = torch.cat( - (valid_groups.new_zeros(1), boundaries.to(valid_groups.dtype)) - ) - ends_tensor = torch.cat( - ( - boundaries.to(valid_groups.dtype), - valid_groups.new_tensor([valid_length]), - ) - ) - else: - starts_tensor = valid_groups.new_zeros(1) - ends_tensor = valid_groups.new_tensor([valid_length]) - - starts = tuple(int(value) for value in starts_tensor.tolist()) - ends = tuple(int(value) for value in ends_tensor.tolist()) - segment_group_ids = tuple(int(valid_groups[start].item()) for start in starts) - segment_parent_ids = 
tuple(int(valid_parents[start].item()) for start in starts) - families: list[GdnPackedFamilySpec] = [] - seen_groups: set[int] = set() - segment_cursor = 0 - while segment_cursor < len(starts): - group_id = segment_group_ids[segment_cursor] - parent_id = segment_parent_ids[segment_cursor] - start = starts[segment_cursor] - end = ends[segment_cursor] - if group_id in seen_groups: - raise ValueError(f"row {row_index}: group_id {group_id} is non-contiguous") - if group_id != parent_id: - raise ValueError( - f"row {row_index}: completion group {group_id} appears before " - f"its prefix parent {parent_id}" - ) - seen_groups.add(group_id) - family_index = first_family_index + len(families) - prefix = _trusted_pydantic_construct( - GdnSegmentSpec, - _GDN_SEGMENT_SPEC_FIELDS, - row_index=row_index, - family_index=family_index, - group_id=group_id, - parent_id=parent_id, - start=start, - end=end, - kind="prefix", - child_index=None, - ) - segment_cursor += 1 - completions: list[GdnSegmentSpec] = [] - while segment_cursor < len(starts): - child_group_id = segment_group_ids[segment_cursor] - child_parent_id = segment_parent_ids[segment_cursor] - child_start = starts[segment_cursor] - child_end = ends[segment_cursor] - if child_group_id == child_parent_id: - break - if child_parent_id != group_id: - raise ValueError( - f"row {row_index}: completion group {child_group_id} has " - f"parent {child_parent_id}, expected active prefix {group_id}" - ) - if child_group_id in seen_groups: - raise ValueError( - f"row {row_index}: group_id {child_group_id} is non-contiguous" - ) - seen_groups.add(child_group_id) - completions.append( - _trusted_pydantic_construct( - GdnSegmentSpec, - _GDN_SEGMENT_SPEC_FIELDS, - row_index=row_index, - family_index=family_index, - group_id=child_group_id, - parent_id=child_parent_id, - start=child_start, - end=child_end, - kind="completion", - child_index=len(completions), - ) - ) - segment_cursor += 1 - if len(completions) < min_completions_per_family: - 
raise ValueError( - f"row {row_index}: prefix group {group_id} has {len(completions)} " - f"completion(s), expected at least {min_completions_per_family}" - ) - families.append( - _trusted_pydantic_construct( - GdnPackedFamilySpec, - _GDN_PACKED_FAMILY_SPEC_FIELDS, - row_index=row_index, - family_index=family_index, - prefix=prefix, - completions=tuple(completions), - ) - ) - return families - - -def _parse_row( - *, - row_index: int, - group_ids: list[int], - parent_ids: list[int], - valid_length: int, - first_family_index: int, - min_completions_per_family: int, -) -> list[GdnPackedFamilySpec]: - families: list[GdnPackedFamilySpec] = [] - seen_groups: set[int] = set() - cursor = 0 - while cursor < valid_length: - group_id, parent_id, start, end = _read_segment( - row_index, group_ids, parent_ids, valid_length, cursor - ) - if group_id in seen_groups: - raise ValueError(f"row {row_index}: group_id {group_id} is non-contiguous") - if group_id != parent_id: - raise ValueError( - f"row {row_index}: completion group {group_id} appears before " - f"its prefix parent {parent_id}" - ) - seen_groups.add(group_id) - family_index = first_family_index + len(families) - prefix = GdnSegmentSpec( - row_index=row_index, - family_index=family_index, - group_id=group_id, - parent_id=parent_id, - start=start, - end=end, - kind="prefix", - ) - cursor = end - completions: list[GdnSegmentSpec] = [] - while cursor < valid_length: - child_group_id, child_parent_id, child_start, child_end = _read_segment( - row_index, group_ids, parent_ids, valid_length, cursor - ) - if child_group_id == child_parent_id: - break - if child_parent_id != group_id: - raise ValueError( - f"row {row_index}: completion group {child_group_id} has " - f"parent {child_parent_id}, expected active prefix {group_id}" - ) - if child_group_id in seen_groups: - raise ValueError( - f"row {row_index}: group_id {child_group_id} is non-contiguous" - ) - seen_groups.add(child_group_id) - completions.append( - 
GdnSegmentSpec( - row_index=row_index, - family_index=family_index, - group_id=child_group_id, - parent_id=child_parent_id, - start=child_start, - end=child_end, - kind="completion", - child_index=len(completions), - ) - ) - cursor = child_end - if len(completions) < min_completions_per_family: - raise ValueError( - f"row {row_index}: prefix group {group_id} has {len(completions)} " - f"completion(s), expected at least {min_completions_per_family}" - ) - families.append( - GdnPackedFamilySpec( - row_index=row_index, - family_index=family_index, - prefix=prefix, - completions=tuple(completions), - ) - ) - return families - - -def _read_segment( - row_index: int, - group_ids: list[int], - parent_ids: list[int], - valid_length: int, - cursor: int, -) -> tuple[int, int, int, int]: - group_id = int(group_ids[cursor]) - parent_id = int(parent_ids[cursor]) - if group_id < 0 or parent_id < 0: - raise ValueError(f"row {row_index}: segment ids must be non-negative") - start = cursor - cursor += 1 - while cursor < valid_length and int(group_ids[cursor]) == group_id: - current_parent = int(parent_ids[cursor]) - if current_parent != parent_id: - raise ValueError( - f"row {row_index}: group {group_id} changes parent from " - f"{parent_id} to {current_parent}" - ) - cursor += 1 - return group_id, parent_id, start, cursor diff --git a/src/art/megatron/gdn/layout.py b/src/art/megatron/gdn/layout.py index c3469a451..7119218f6 100644 --- a/src/art/megatron/gdn/layout.py +++ b/src/art/megatron/gdn/layout.py @@ -1,9 +1,9 @@ from __future__ import annotations from collections.abc import Sequence +from dataclasses import dataclass, replace from typing import Any -from pydantic import BaseModel, ConfigDict, Field, model_validator import torch from torch import Tensor from torch.distributed import ( @@ -20,47 +20,47 @@ from art.megatron.context_parallel.layout_index import TokenLayoutIndex -class GdnCpPeerTransfer(BaseModel): +@dataclass(frozen=True) +class GdnCpPeerTransfer: """Token rows 
sent from one source rank to one destination rank.""" - model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) - - source_rank: int = Field(ge=0) - dest_rank: int = Field(ge=0) - token_count: int = Field(ge=0) + source_rank: int + dest_rank: int + token_count: int + source_positions_cpu: tuple[int, ...] | None = None + dest_positions_cpu: tuple[int, ...] | None = None source_positions_tensor: Tensor | None = None dest_positions_tensor: Tensor | None = None - @model_validator(mode="after") - def _same_lengths(self) -> "GdnCpPeerTransfer": + def __post_init__(self) -> None: lengths = {int(self.token_count)} + if self.source_positions_cpu is not None: + lengths.add(len(self.source_positions_cpu)) + if self.dest_positions_cpu is not None: + lengths.add(len(self.dest_positions_cpu)) if self.source_positions_tensor is not None: lengths.add(int(self.source_positions_tensor.numel())) if self.dest_positions_tensor is not None: lengths.add(int(self.dest_positions_tensor.numel())) if len(lengths) != 1: raise ValueError("token, source, and destination position counts differ") - return self -class GdnCpExchangePlan(BaseModel): +@dataclass(frozen=True) +class GdnCpExchangePlan: """Permutation/all-to-all metadata between two distributed token layouts.""" - model_config = ConfigDict(frozen=True) - - cp_size: int = Field(ge=1) + cp_size: int source_token_counts_by_rank: tuple[int, ...] dest_token_counts_by_rank: tuple[int, ...] transfers: tuple[GdnCpPeerTransfer, ...] 
- cross_rank_token_count_override: int | None = Field(default=None, ge=0) + cross_rank_token_count_override: int | None = None - @model_validator(mode="after") - def _rank_counts(self) -> "GdnCpExchangePlan": + def __post_init__(self) -> None: if len(self.source_token_counts_by_rank) != self.cp_size: raise ValueError("source token count length must equal cp_size") if len(self.dest_token_counts_by_rank) != self.cp_size: raise ValueError("destination token count length must equal cp_size") - return self @property def cross_rank_token_count(self) -> int: @@ -73,11 +73,10 @@ def cross_rank_token_count(self) -> int: ) -class GdnSpExchangePlan(BaseModel): +@dataclass(frozen=True) +class GdnSpExchangePlan: """Sequence-parallel view of an existing CP exchange plan.""" - model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) - plan: GdnCpExchangePlan rank: int @@ -206,7 +205,7 @@ def build_local_rank_cp_exchange_plan_from_dest_ranges( device=device, ) ) - return GdnCpExchangePlan.model_construct( + return GdnCpExchangePlan( cp_size=cp_size, source_token_counts_by_rank=source_layout.token_counts_by_rank, dest_token_counts_by_rank=dest_counts, @@ -238,23 +237,33 @@ def _make_peer_transfer( source_count=source_count, dest_count=dest_count, ): + source_cpu = None + dest_cpu = None source_tensor = None dest_tensor = None else: + source_cpu = _tensor_positions_tuple(source_positions) + dest_cpu = _tensor_positions_tuple(dest_positions) target = torch.device(device) if device is not None else torch.device("cpu") source_tensor = source_positions.to( device=target, dtype=torch.long ).contiguous() dest_tensor = dest_positions.to(device=target, dtype=torch.long).contiguous() - return GdnCpPeerTransfer.model_construct( + return GdnCpPeerTransfer( source_rank=source_rank, dest_rank=dest_rank, token_count=token_count, + source_positions_cpu=source_cpu, + dest_positions_cpu=dest_cpu, source_positions_tensor=source_tensor, dest_positions_tensor=dest_tensor, ) +def 
_tensor_positions_tuple(tensor: Tensor) -> tuple[int, ...]: + return tuple(int(value) for value in tensor.detach().cpu().tolist()) + + def _is_full_identity_transfer( *, source_rank: int, @@ -277,16 +286,18 @@ def _is_full_identity_transfer( def _reverse_exchange_plan(plan: GdnCpExchangePlan) -> GdnCpExchangePlan: - return GdnCpExchangePlan.model_construct( + return GdnCpExchangePlan( cp_size=plan.cp_size, source_token_counts_by_rank=_dest_counts_by_rank(plan), dest_token_counts_by_rank=_source_counts_by_rank(plan), cross_rank_token_count_override=plan.cross_rank_token_count_override, transfers=tuple( - GdnCpPeerTransfer.model_construct( + GdnCpPeerTransfer( source_rank=transfer.dest_rank, dest_rank=transfer.source_rank, token_count=_transfer_token_count(transfer), + source_positions_cpu=transfer.dest_positions_cpu, + dest_positions_cpu=transfer.source_positions_cpu, source_positions_tensor=transfer.dest_positions_tensor, dest_positions_tensor=transfer.source_positions_tensor, ) @@ -485,15 +496,11 @@ def move_cp_exchange_plan_to_device( if plan is None: return None target = torch.device(device) - return GdnCpExchangePlan.model_construct( - cp_size=plan.cp_size, - source_token_counts_by_rank=_source_counts_by_rank(plan), - dest_token_counts_by_rank=_dest_counts_by_rank(plan), + return replace( + plan, transfers=tuple( - GdnCpPeerTransfer.model_construct( - source_rank=transfer.source_rank, - dest_rank=transfer.dest_rank, - token_count=transfer.token_count, + replace( + transfer, source_positions_tensor=_move_optional_index_tensor( transfer.source_positions_tensor, target ), @@ -503,7 +510,6 @@ def move_cp_exchange_plan_to_device( ) for transfer in plan.transfers ), - cross_rank_token_count_override=plan.cross_rank_token_count_override, ) @@ -532,7 +538,7 @@ def shard_cp_exchange_plan_for_sequence_parallel( """ if tp_size <= 1: - return GdnSpExchangePlan.model_construct(plan=plan, rank=cp_rank) + return GdnSpExchangePlan(plan=plan, rank=cp_rank) _check_rank(plan, 
cp_rank) if tp_rank < 0 or tp_rank >= tp_size: raise ValueError(f"tp_rank must be in [0, {tp_size}), got {tp_rank}") @@ -603,7 +609,7 @@ def shard_cp_exchange_plan_for_sequence_parallel( # A CP-local reorder can still move rows between TP ranks, and local CP plans do # not contain enough global TP information for every rank to independently # prove that no peer exchange is needed. - sp_plan = GdnCpExchangePlan.model_construct( + sp_plan = GdnCpExchangePlan( cp_size=world_size, source_token_counts_by_rank=source_counts, dest_token_counts_by_rank=dest_counts, @@ -612,7 +618,7 @@ def shard_cp_exchange_plan_for_sequence_parallel( ), cross_rank_token_count_override=1, ) - return GdnSpExchangePlan.model_construct(plan=sp_plan, rank=composite_rank) + return GdnSpExchangePlan(plan=sp_plan, rank=composite_rank) def recv_split_sizes_for_rank(plan: GdnCpExchangePlan, rank: int) -> tuple[int, ...]: @@ -750,10 +756,15 @@ def _is_implicit_full_identity_transfer( ) -def _transfer_positions_tuple(tensor: Tensor | None) -> tuple[int, ...]: +def _transfer_positions_tuple( + positions: tuple[int, ...] 
| None, + tensor: Tensor | None, +) -> tuple[int, ...]: + if positions is not None: + return positions if tensor is None: return () - return tuple(int(value) for value in tensor.detach().cpu().tolist()) + return _tensor_positions_tuple(tensor) def _transfer_index_tensor( @@ -895,17 +906,6 @@ def _exchange_rank_tensor_local( ) -def _copy_rank_self_transfers( - local_tensor: Tensor, - plan: GdnCpExchangePlan, - *, - rank: int, -) -> Tensor: - return _init_rank_exchange_output( - local_tensor, plan, rank=rank, accumulate=False, zero_init=False - ) - - def _init_rank_exchange_output( local_tensor: Tensor, plan: GdnCpExchangePlan, @@ -1028,7 +1028,10 @@ def _transfer_dest_positions_for_duplicate_check( dest_count=_dest_count_for_rank(plan, transfer.dest_rank), ): return tuple(range(token_count)) - positions = _transfer_positions_tuple(transfer.dest_positions_tensor) + positions = _transfer_positions_tuple( + transfer.dest_positions_cpu, + transfer.dest_positions_tensor, + ) if len(positions) != token_count: raise ValueError("GDN CP transfer destination positions must match token_count") return positions diff --git a/src/art/megatron/gdn/operator.py b/src/art/megatron/gdn/operator.py index e8a122f5c..3dbed83d9 100644 --- a/src/art/megatron/gdn/operator.py +++ b/src/art/megatron/gdn/operator.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import MethodType -from typing import Any, Callable, Literal, NamedTuple, Sequence, cast +from typing import Any, Callable, Iterable, Literal, NamedTuple, Sequence, cast import torch from torch import Tensor @@ -12,9 +12,9 @@ from .fla_cp import chunk_gated_delta_rule_native_cp from .gdn_shared_prefix import ( GdnPackedExecutionSpec, - GdnParentStateTransferPlan, GdnRankExecutionPlan, GdnSegmentBucketPlan, + GdnStateExchangePlan, build_gdn_rank_execution_plan, parse_gdn_shared_prefix_segments, ) @@ -301,6 +301,7 @@ def _empty_safe_norm_forward( return original_forward(input_, *args, **kwargs) +@torch.compiler.disable def 
_shared_prefix_forward( self: Any, hidden_states: Tensor, @@ -463,9 +464,7 @@ def run_gdn_layer( ) if execution_spec is None and execution_plan is None: - execution_spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + execution_spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) if ( execution_spec is not None and requested_cp_size == 1 @@ -510,136 +509,396 @@ def run_gdn_layer( ) if input_layout != "attention" or output_layout != "attention": raise ValueError("GDN layout controls require a CP execution plan") - return _run_planned_prefixes_and_completions(gdn, hidden_states, execution_plan) + return _run_tree_prefixes(gdn, hidden_states, execution_plan) -def _run_planned_prefixes_and_completions( +def _run_tree_prefixes( gdn: Any, hidden_states: Tensor, plan: GdnRankExecutionPlan, ) -> tuple[Tensor, Tensor | None]: - if _has_chunk_aligned_local_plan(plan): - return _run_chunk_aligned_prefixes_and_completions(gdn, hidden_states, plan) - raise ValueError( - "shared-prefix GDN requires a chunk-aligned execution plan; " - "prefix/completion bucket execution has been removed" + qkv, gate, beta, recurrent_g = _project_gdn_inputs(gdn, hidden_states) + gate = gate.clone() + recurrent_output = torch.zeros_like(gate) + recurrent_output, _cp_dependency = _run_tree_depth_buckets( + gdn, + qkv, + beta, + recurrent_g, + recurrent_output, + plan, + state_reference=hidden_states, ) + return _project_gdn_output(gdn, recurrent_output, gate, plan) -def _has_chunk_aligned_local_plan(plan: GdnRankExecutionPlan) -> bool: - return bool( - plan.prefix_boundary_buckets - or plan.prefix_tail_buckets - or plan.completion_with_prefix_tail_buckets +def _run_tree_depth_buckets( + gdn: Any, + qkv: Tensor, + beta: Tensor, + recurrent_g: Tensor, + recurrent_output: Tensor, + plan: GdnRankExecutionPlan, + *, + state_reference: Tensor, + group: Any | None = None, + cp_dependency: Tensor | None = None, +) -> tuple[Tensor, Tensor | None]: + 
state_cache = _TreeStateChunkCache( + device=state_reference.device, ) + for depth, buckets in enumerate(plan.tree_segment_buckets_by_depth): + if depth < len(plan.tree_state_exchanges_by_depth): + cp_dependency = state_cache.exchange_remote_parent_states( + gdn, + plan.tree_state_exchanges_by_depth[depth], + state_reference=state_reference, + rank=plan.cp_rank, + group=group, + cp_dependency=cp_dependency, + ) + if depth < len(plan.tree_chain_buckets_by_depth): + for bucket in plan.tree_chain_buckets_by_depth[depth]: + recurrent_output, cp_dependency = _run_tree_bucket( + gdn, + qkv, + beta, + recurrent_g, + recurrent_output, + state_cache, + bucket, + state_reference=state_reference, + group=group, + cp_dependency=cp_dependency, + recurrent_cp=True, + scale_parent_state_gradient=1.0 / plan.cp_size, + ) + + for bucket in buckets: + recurrent_output, cp_dependency = _run_tree_bucket( + gdn, + qkv, + beta, + recurrent_g, + recurrent_output, + state_cache, + bucket, + state_reference=state_reference, + cp_dependency=cp_dependency, + ) + + return recurrent_output, cp_dependency -def _run_chunk_aligned_prefixes_and_completions( + +def _run_tree_bucket( gdn: Any, - hidden_states: Tensor, - plan: GdnRankExecutionPlan, + qkv: Tensor, + beta: Tensor, + recurrent_g: Tensor, + recurrent_output: Tensor, + state_cache: "_TreeStateChunkCache", + bucket: GdnSegmentBucketPlan, + *, + state_reference: Tensor, + group: Any | None = None, + cp_dependency: Tensor | None = None, + recurrent_cp: bool = False, + scale_parent_state_gradient: float | None = None, ) -> tuple[Tensor, Tensor | None]: - qkv, gate, beta, recurrent_g = _project_gdn_inputs(gdn, hidden_states) - gate = gate.clone() - recurrent_output = torch.zeros_like(gate) - boundary_family_chunks: list[Tensor] = [] - boundary_conv_chunks: list[Tensor] = [] - boundary_rec_chunks: list[Tensor] = [] - - for bucket in plan.prefix_boundary_buckets: - prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( - qkv, beta, 
recurrent_g, bucket + parent_conv, parent_rec = state_cache.parent_states( + gdn, + bucket, + state_reference=state_reference, + ) + if _bucket_has_parent_state(bucket): + parent_conv, parent_rec = _couple_parent_states(parent_conv, parent_rec) + if scale_parent_state_gradient is not None: + parent_conv = _scale_state_gradient( + parent_conv, + scale_parent_state_gradient, + ) + parent_rec = _scale_state_gradient(parent_rec, scale_parent_state_gradient) + segment_qkv, segment_beta, segment_g = _gather_bucket_streams( + qkv, + beta, + recurrent_g, + bucket, + ) + if cp_dependency is not None: + segment_qkv = _add_autograd_dependency(segment_qkv, cp_dependency) + segment_beta = _add_autograd_dependency(segment_beta, cp_dependency) + segment_g = _add_autograd_dependency(segment_g, cp_dependency) + parent_conv = _add_autograd_dependency(parent_conv, cp_dependency) + parent_rec = _add_autograd_dependency(parent_rec, cp_dependency) + segment_out, segment_conv, segment_rec = run_gdn_bucket( + bucket, + (segment_qkv, segment_beta, segment_g), + (parent_conv, parent_rec), + gdn=gdn, + group=group, + recurrent_cp=recurrent_cp, + output_final_state=bucket.needs_final_state or recurrent_cp, + ) + if bucket.needs_final_state and (segment_conv is None or segment_rec is None): + raise RuntimeError("tree GDN execution must return final states") + if ( + bucket.needs_final_state + and segment_conv is not None + and segment_rec is not None + ): + cp_dependency = _make_autograd_dependency( + segment_out, segment_conv, segment_rec ) - zero_conv = _zero_conv_state( - gdn, hidden_states, batch_size=bucket.segment_count + else: + cp_dependency = _make_autograd_dependency(segment_out) + recurrent_output = _scatter_bucket_recurrent_output( + recurrent_output, + bucket, + segment_out, + ) + if bucket.needs_final_state: + state_cache.append( + bucket, + cast(Tensor, segment_conv), + cast(Tensor, segment_rec), + ) + return recurrent_output, cp_dependency + + +class _TreeStateChunkCache: + def 
__init__(self, *, device: torch.device) -> None: + self._device = device + self._conv_chunks: list[Tensor] = [] + self._rec_chunks: list[Tensor] = [] + self._source_by_family: list[tuple[int, int] | None] = [] + + def append(self, bucket: GdnSegmentBucketPlan, conv: Tensor, rec: Tensor) -> None: + self.append_families(_bucket_family_indices_cpu(bucket), conv, rec) + + def append_families( + self, family_indices: Sequence[int], conv: Tensor, rec: Tensor + ) -> None: + if len(family_indices) == 0: + return + if int(conv.shape[0]) != len(family_indices): + raise ValueError( + "tree GDN state cache conv batch must match family count, got " + f"{tuple(conv.shape)} and {len(family_indices)} families" + ) + if int(rec.shape[0]) != len(family_indices): + raise ValueError( + "tree GDN state cache recurrent batch must match family count, got " + f"{tuple(rec.shape)} and {len(family_indices)} families" + ) + chunk_index = len(self._conv_chunks) + self._conv_chunks.append(conv) + self._rec_chunks.append(rec) + max_family = max(int(index) for index in family_indices) + if max_family >= len(self._source_by_family): + self._source_by_family.extend( + None for _ in range(max_family + 1 - len(self._source_by_family)) + ) + for source_row, family_index in enumerate(family_indices): + self._source_by_family[int(family_index)] = (chunk_index, source_row) + + def exchange_remote_parent_states( + self, + gdn: Any, + exchange: GdnStateExchangePlan | None, + *, + state_reference: Tensor, + rank: int, + group: Any | None, + cp_dependency: Tensor | None, + ) -> Tensor | None: + if exchange is None: + return cp_dependency + from .layout import exchange_rank_tensor_all_to_all + + source_conv, source_rec = self.states_for_families( + gdn, + exchange.source_family_indices, + state_reference=state_reference, + ) + if cp_dependency is not None: + source_conv = _add_autograd_dependency(source_conv, cp_dependency) + source_rec = _add_autograd_dependency(source_rec, cp_dependency) + remote_conv = 
exchange_rank_tensor_all_to_all( + source_conv, + exchange.exchange, + rank=rank, + group=group, + backward_plan=exchange.reverse_exchange, ) - zero_rec = _zero_recurrent_state( - gdn, hidden_states, batch_size=bucket.segment_count + remote_rec = exchange_rank_tensor_all_to_all( + source_rec, + exchange.exchange, + rank=rank, + group=group, + backward_plan=exchange.reverse_exchange, ) - prefix_out, prefix_conv, prefix_rec = run_gdn_bucket( - bucket, - (prefix_qkv, prefix_beta, prefix_g), - (zero_conv, zero_rec), - gdn=gdn, - output_final_state=True, - ) - if prefix_conv is None or prefix_rec is None: - raise RuntimeError("prefix boundary GDN execution must return final states") - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, prefix_out - ) - boundary_family_chunks.append(bucket.family_indices) - boundary_conv_chunks.append(prefix_conv) - boundary_rec_chunks.append(prefix_rec) - - boundary_conv_table = _materialize_indexed_family_state_table( - plan=plan, - family_chunks=boundary_family_chunks, - state_chunks=boundary_conv_chunks, - zero_state=_zero_conv_state(gdn, hidden_states, batch_size=plan.family_count), - ) - boundary_rec_table = _materialize_indexed_family_state_table( - plan=plan, - family_chunks=boundary_family_chunks, - state_chunks=boundary_rec_chunks, - zero_state=_zero_recurrent_state( - gdn, hidden_states, batch_size=plan.family_count - ), - ) - - tail_family_chunks: list[Tensor] = [] - tail_conv_chunks: list[Tensor] = [] - tail_rec_chunks: list[Tensor] = [] - for bucket in plan.prefix_tail_buckets: - tail_qkv, tail_beta, tail_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - tail_conv = boundary_conv_table.index_select(0, bucket.family_indices) - tail_rec = boundary_rec_table.index_select(0, bucket.family_indices) - tail_out, tail_conv, tail_rec = run_gdn_bucket( - bucket, - (tail_qkv, tail_beta, tail_g), - (tail_conv, tail_rec), - gdn=gdn, - output_final_state=True, + 
self.append_families(exchange.dest_family_indices, remote_conv, remote_rec) + dependency = _make_zero_autograd_dependency( + source_conv, source_rec, remote_conv, remote_rec ) - if tail_conv is None or tail_rec is None: - raise RuntimeError("prefix tail GDN execution must return final states") - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, tail_out + return dependency if cp_dependency is None else dependency + cp_dependency + + def states_for_families( + self, + gdn: Any, + family_indices: Sequence[int], + *, + state_reference: Tensor, + ) -> tuple[Tensor, Tensor]: + if len(family_indices) == 0: + conv = _zero_conv_state(gdn, state_reference, batch_size=0) + rec = _zero_recurrent_state(gdn, state_reference, batch_size=0) + return conv.requires_grad_(True), rec.requires_grad_(True) + return self._mixed_parent_states( + gdn, + tuple(int(index) for index in family_indices), + state_reference=state_reference, + batch_size=len(family_indices), + roots_allowed=False, ) - tail_family_chunks.append(bucket.family_indices) - tail_conv_chunks.append(tail_conv) - tail_rec_chunks.append(tail_rec) - prefix_conv_table = _replace_indexed_family_states( - boundary_conv_table, - family_chunks=tail_family_chunks, - state_chunks=tail_conv_chunks, - ) - prefix_rec_table = _replace_indexed_family_states( - boundary_rec_table, - family_chunks=tail_family_chunks, - state_chunks=tail_rec_chunks, - ) + def parent_states( + self, + gdn: Any, + bucket: GdnSegmentBucketPlan, + *, + state_reference: Tensor, + ) -> tuple[Tensor, Tensor]: + parent_indices = bucket.parent_indices + if parent_indices is None: + raise RuntimeError("tree GDN bucket is missing parent indices") + parent_indices_cpu = _bucket_parent_indices_cpu(bucket) + batch_size = bucket.segment_count + if all(parent_index < 0 for parent_index in parent_indices_cpu): + return ( + _zero_conv_state(gdn, state_reference, batch_size=batch_size), + _zero_recurrent_state(gdn, state_reference, 
batch_size=batch_size), + ) - for bucket in plan.completion_with_prefix_tail_buckets: - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, + return self._mixed_parent_states( + gdn, + parent_indices_cpu, + state_reference=state_reference, + batch_size=batch_size, ) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out + + def _mixed_parent_states( + self, + gdn: Any, + parent_indices_cpu: tuple[int, ...], + *, + state_reference: Tensor, + batch_size: int, + roots_allowed: bool = True, + ) -> tuple[Tensor, Tensor]: + sources_by_chunk: dict[int, list[tuple[int, int]]] = {} + missing_parents: list[int] = [] + for dest_row, parent_index in enumerate(parent_indices_cpu): + if parent_index < 0: + if roots_allowed: + continue + missing_parents.append(parent_index) + continue + source = ( + self._source_by_family[parent_index] + if parent_index < len(self._source_by_family) + else None + ) + if source is None: + missing_parents.append(parent_index) + continue + chunk_index, source_row = source + sources_by_chunk.setdefault(chunk_index, []).append((dest_row, source_row)) + if missing_parents: + raise RuntimeError( + "tree GDN append-only execution is missing parent state for " + f"families {tuple(missing_parents)}" + ) + + single_source_chunk = next(iter(sources_by_chunk.values())) + if len(sources_by_chunk) == 1 and len(single_source_chunk) == batch_size: + chunk_index, pairs = next(iter(sources_by_chunk.items())) + return ( + _select_state_rows(self._conv_chunks[chunk_index], pairs), + _select_state_rows(self._rec_chunks[chunk_index], pairs), + ) + 
+ conv = _zero_conv_state(gdn, state_reference, batch_size=batch_size) + rec = _zero_recurrent_state(gdn, state_reference, batch_size=batch_size) + for chunk_index, pairs in sources_by_chunk.items(): + dest_rows = _long_tensor( + (dest_row for dest_row, _ in pairs), + device=self._device, + ) + source_rows = _long_tensor( + (source_row for _, source_row in pairs), + device=self._device, + ) + conv = conv.index_copy( + 0, + dest_rows, + self._conv_chunks[chunk_index].index_select(0, source_rows), + ) + rec = rec.index_copy( + 0, + dest_rows, + self._rec_chunks[chunk_index].index_select(0, source_rows), + ) + return conv, rec + + +def _select_state_rows(chunk: Tensor, pairs: Sequence[tuple[int, int]]) -> Tensor: + source_rows = tuple(source_row for _, source_row in pairs) + if len(set(source_rows)) == 1: + return chunk.narrow(0, source_rows[0], 1).expand( + len(source_rows), + *tuple(chunk.shape[1:]), ) - return _project_gdn_output(gdn, recurrent_output, gate, plan) + first_row = source_rows[0] + if source_rows == tuple(range(first_row, first_row + len(source_rows))): + return chunk.narrow(0, first_row, len(source_rows)) + return chunk.index_select( + 0, + _long_tensor(source_rows, device=chunk.device), + ) + + +def _bucket_family_indices_cpu(bucket: GdnSegmentBucketPlan) -> tuple[int, ...]: + family_indices = bucket.family_indices_cpu + if family_indices is None: + family_indices = bucket.family_indices.detach().cpu() + return tuple(int(index) for index in family_indices.tolist()) + + +def _bucket_parent_indices_cpu(bucket: GdnSegmentBucketPlan) -> tuple[int, ...]: + parent_indices = bucket.parent_indices + if parent_indices is None: + raise RuntimeError("tree GDN bucket is missing parent indices") + parent_indices_cpu = bucket.parent_indices_cpu + if parent_indices_cpu is None: + parent_indices_cpu = parent_indices.detach().cpu() + return tuple(int(index) for index in parent_indices_cpu.tolist()) + + +def _long_tensor(values: Iterable[int], *, device: torch.device) 
-> Tensor: + return torch.tensor(tuple(values), dtype=torch.long, device=device) + + +def _bucket_has_parent_state(bucket: GdnSegmentBucketPlan) -> bool: + return any(parent_index >= 0 for parent_index in _bucket_parent_indices_cpu(bucket)) + + +def _bucket_has_uniform_lengths(bucket: GdnSegmentBucketPlan) -> bool: + lengths_cpu = bucket.lengths_cpu + if lengths_cpu is None: + lengths_cpu = bucket.lengths.detach().cpu() + return all(int(length) == int(bucket.length) for length in lengths_cpu.tolist()) def _run_cp_planned_prefixes_and_completions( @@ -679,385 +938,21 @@ def _run_cp_planned_prefixes_and_completions( if empty_gdn_rank else _empty_autograd_dependency(qkv) ) - qkv_with_remote_tail = qkv - beta_with_remote_tail = beta - recurrent_g_with_remote_tail = recurrent_g - if plan.remote_prefix_tail_exchange is not None: - remote_qkv, remote_beta, remote_g = _exchange_remote_prefix_tail_streams( - qkv, - beta, - recurrent_g, - plan=plan, - group=group, - ) - qkv_with_remote_tail = torch.cat([qkv, remote_qkv.unsqueeze(0)], dim=1) - beta_with_remote_tail = torch.cat([beta, remote_beta.unsqueeze(0)], dim=1) - recurrent_g_with_remote_tail = torch.cat( - [recurrent_g, remote_g.unsqueeze(0)], dim=1 - ) - cp_dependency = cp_dependency + _make_zero_autograd_dependency( - remote_qkv, remote_beta, remote_g - ) + if not plan.tree_segment_buckets_by_depth: + raise ValueError("CP shared-prefix GDN requires a tree execution plan") gate = gate.clone() recurrent_output = torch.zeros_like(gate) - prefix_family_chunks: list[Tensor] = [] - prefix_conv_chunks: list[Tensor] = [] - prefix_rec_chunks: list[Tensor] = [] - - for bucket in plan.chain_prefix_buckets: - prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - zero_conv = _zero_conv_state(gdn, qkv, batch_size=bucket.segment_count) - zero_rec = _zero_recurrent_state(gdn, qkv, batch_size=bucket.segment_count) - prefix_out, prefix_conv, prefix_rec = run_gdn_bucket( - bucket, - 
(prefix_qkv, prefix_beta, prefix_g), - (zero_conv, zero_rec), - gdn=gdn, - group=group, - recurrent_cp=True, - output_final_state=True, - ) - if prefix_conv is None or prefix_rec is None: - raise RuntimeError("CP prefix GDN execution must return final states") - prefix_out = _add_autograd_dependency(prefix_out, cp_dependency) - prefix_conv = _add_autograd_dependency(prefix_conv, cp_dependency) - prefix_rec = _add_autograd_dependency(prefix_rec, cp_dependency) - cp_dependency = _make_autograd_dependency(prefix_out, prefix_conv, prefix_rec) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, prefix_out - ) - prefix_family_chunks.append(bucket.family_indices) - prefix_conv_chunks.append(prefix_conv) - prefix_rec_chunks.append(prefix_rec) - - boundary_family_chunks: list[Tensor] = [] - boundary_conv_chunks: list[Tensor] = [] - boundary_rec_chunks: list[Tensor] = [] - for bucket in plan.prefix_boundary_buckets: - prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - zero_conv = _zero_conv_state(gdn, qkv, batch_size=bucket.segment_count) - zero_rec = _zero_recurrent_state(gdn, qkv, batch_size=bucket.segment_count) - prefix_out, prefix_conv, prefix_rec = run_gdn_bucket( - bucket, - (prefix_qkv, prefix_beta, prefix_g), - (zero_conv, zero_rec), - gdn=gdn, - output_final_state=True, - ) - if prefix_conv is None or prefix_rec is None: - raise RuntimeError("local prefix GDN execution must return final states") - prefix_out = _add_autograd_dependency(prefix_out, cp_dependency) - prefix_conv = _add_autograd_dependency(prefix_conv, cp_dependency) - prefix_rec = _add_autograd_dependency(prefix_rec, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, prefix_out - ) - boundary_family_chunks.append(bucket.family_indices) - boundary_conv_chunks.append(prefix_conv) - boundary_rec_chunks.append(prefix_rec) - prefix_family_chunks.append(bucket.family_indices) - 
prefix_conv_chunks.append(prefix_conv) - prefix_rec_chunks.append(prefix_rec) - - if ( - plan.prefix_tail_buckets - or plan.remote_prefix_tail_buckets - or plan.completion_with_prefix_tail_buckets - or plan.remote_completion_with_prefix_tail_buckets - or plan.remote_prefix_tail_state_transfers - ): - boundary_conv_table = _materialize_indexed_family_state_table( - plan=plan, - family_chunks=boundary_family_chunks, - state_chunks=boundary_conv_chunks, - zero_state=_zero_conv_state(gdn, qkv, batch_size=plan.family_count), - ) - boundary_rec_table = _materialize_indexed_family_state_table( - plan=plan, - family_chunks=boundary_family_chunks, - state_chunks=boundary_rec_chunks, - zero_state=_zero_recurrent_state(gdn, qkv, batch_size=plan.family_count), - ) - remote_boundary_conv_table = boundary_conv_table - remote_boundary_rec_table = boundary_rec_table - if plan.remote_prefix_tail_state_transfers: - ( - remote_boundary_conv_table, - remote_boundary_rec_table, - remote_boundary_dependency, - ) = _exchange_parent_state_rows( - boundary_conv_table, - boundary_rec_table, - transfers=plan.remote_prefix_tail_state_transfers, - group=group, - ) - cp_dependency = cp_dependency + remote_boundary_dependency - tail_family_chunks: list[Tensor] = [] - tail_conv_chunks: list[Tensor] = [] - tail_rec_chunks: list[Tensor] = [] - for bucket in plan.prefix_tail_buckets: - tail_qkv, tail_beta, tail_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - tail_conv = boundary_conv_table.index_select(0, bucket.family_indices) - tail_rec = boundary_rec_table.index_select(0, bucket.family_indices) - tail_out, tail_conv, tail_rec = run_gdn_bucket( - bucket, - (tail_qkv, tail_beta, tail_g), - (tail_conv, tail_rec), - gdn=gdn, - output_final_state=True, - ) - if tail_conv is None or tail_rec is None: - raise RuntimeError("local prefix tail GDN execution must return states") - tail_out = _add_autograd_dependency(tail_out, cp_dependency) - tail_conv = 
_add_autograd_dependency(tail_conv, cp_dependency) - tail_rec = _add_autograd_dependency(tail_rec, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, tail_out - ) - tail_family_chunks.append(bucket.family_indices) - tail_conv_chunks.append(tail_conv) - tail_rec_chunks.append(tail_rec) - prefix_family_chunks.append(bucket.family_indices) - prefix_conv_chunks.append(tail_conv) - prefix_rec_chunks.append(tail_rec) - for bucket in plan.remote_prefix_tail_buckets: - tail_qkv, tail_beta, tail_g = _gather_bucket_streams( - qkv_with_remote_tail, - beta_with_remote_tail, - recurrent_g_with_remote_tail, - bucket, - ) - tail_conv = remote_boundary_conv_table.index_select( - 0, bucket.family_indices - ) - tail_rec = remote_boundary_rec_table.index_select(0, bucket.family_indices) - tail_out, tail_conv, tail_rec = run_gdn_bucket( - bucket, - (tail_qkv, tail_beta, tail_g), - (tail_conv, tail_rec), - gdn=gdn, - output_final_state=True, - ) - if tail_conv is None or tail_rec is None: - raise RuntimeError( - "remote prefix tail GDN execution must return states" - ) - tail_out = _add_autograd_dependency(tail_out, cp_dependency) - tail_conv = _add_autograd_dependency(tail_conv, cp_dependency) - tail_rec = _add_autograd_dependency(tail_rec, cp_dependency) - tail_family_chunks.append(bucket.family_indices) - tail_conv_chunks.append(tail_conv) - tail_rec_chunks.append(tail_rec) - prefix_family_chunks.append(bucket.family_indices) - prefix_conv_chunks.append(tail_conv) - prefix_rec_chunks.append(tail_rec) - prefix_conv_table = _replace_indexed_family_states( - boundary_conv_table, - family_chunks=tail_family_chunks, - state_chunks=tail_conv_chunks, - ) - prefix_rec_table = _replace_indexed_family_states( - boundary_rec_table, - family_chunks=tail_family_chunks, - state_chunks=tail_rec_chunks, - ) - for bucket in plan.completion_with_prefix_tail_buckets: - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - 
completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, completion_rec - ) - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - for bucket in plan.remote_completion_with_prefix_tail_buckets: - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, completion_rec - ) - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, - beta, - recurrent_g, - bucket, - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - - for bucket in plan.local_prefix_buckets: - prefix_qkv, prefix_beta, prefix_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - zero_conv = _zero_conv_state(gdn, qkv, batch_size=bucket.segment_count) - zero_rec = _zero_recurrent_state(gdn, qkv, batch_size=bucket.segment_count) - prefix_out, prefix_conv, prefix_rec = run_gdn_bucket( - bucket, - (prefix_qkv, prefix_beta, prefix_g), - (zero_conv, zero_rec), - gdn=gdn, - output_final_state=True, - ) - if prefix_conv is None or prefix_rec is None: - raise RuntimeError("local prefix GDN execution 
must return final states") - prefix_out = _add_autograd_dependency(prefix_out, cp_dependency) - prefix_conv = _add_autograd_dependency(prefix_conv, cp_dependency) - prefix_rec = _add_autograd_dependency(prefix_rec, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, prefix_out - ) - prefix_family_chunks.append(bucket.family_indices) - prefix_conv_chunks.append(prefix_conv) - prefix_rec_chunks.append(prefix_rec) - - if not prefix_conv_chunks and not plan.parent_state_exchange_family_indices: - projected, out_bias = _project_cp_gdn_output( - gdn, - recurrent_output, - gate, - plan, - group=group, - output_layout=output_layout, - ) - projected = _add_autograd_dependency(projected, cp_dependency) - return projected, out_bias - - prefix_conv_table = _materialize_ordered_family_state_table( - family_chunks=prefix_family_chunks, - state_chunks=prefix_conv_chunks, - zero_state=_zero_conv_state(gdn, qkv, batch_size=plan.family_count), - ) - prefix_rec_table = _materialize_ordered_family_state_table( - family_chunks=prefix_family_chunks, - state_chunks=prefix_rec_chunks, - zero_state=_zero_recurrent_state(gdn, qkv, batch_size=plan.family_count), - ) - parent_state_exchanged = False - if plan.chain_completion_buckets and plan.parent_state_exchange_family_indices: - if not plan.parent_state_transfers: - raise ValueError("CP parent-state exchange requires planned transfers") - prefix_conv_table, prefix_rec_table, exchange_dependency = ( - _exchange_parent_state_rows( - prefix_conv_table, - prefix_rec_table, - transfers=plan.parent_state_transfers, - group=group, - ) - ) - cp_dependency = cp_dependency + exchange_dependency - parent_state_exchanged = True - for bucket in plan.chain_completion_buckets: - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = 
prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, completion_rec - ) - completion_conv = _scale_state_gradient(completion_conv, 1.0 / plan.cp_size) - completion_rec = _scale_state_gradient(completion_rec, 1.0 / plan.cp_size) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - group=group, - recurrent_cp=True, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - cp_dependency = _make_autograd_dependency(completion_out) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - - ready_completion_buckets = ( - plan.ready_local_completion_buckets - if plan.ready_local_completion_buckets or plan.remote_local_completion_buckets - else plan.local_completion_buckets + recurrent_output, cp_dependency = _run_tree_depth_buckets( + gdn, + qkv, + beta, + recurrent_g, + recurrent_output, + plan, + state_reference=qkv, + group=group, + cp_dependency=cp_dependency, ) - for bucket in ready_completion_buckets: - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, completion_rec - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - - if plan.parent_state_exchange_family_indices and not parent_state_exchanged: - if 
not plan.parent_state_transfers: - raise ValueError("CP parent-state exchange requires planned transfers") - prefix_conv_table, prefix_rec_table, exchange_dependency = ( - _exchange_parent_state_rows( - prefix_conv_table, - prefix_rec_table, - transfers=plan.parent_state_transfers, - group=group, - ) - ) - cp_dependency = cp_dependency + exchange_dependency - - for bucket in plan.remote_local_completion_buckets: - completion_qkv, completion_beta, completion_g = _gather_bucket_streams( - qkv, beta, recurrent_g, bucket - ) - completion_conv = prefix_conv_table.index_select(0, bucket.family_indices) - completion_rec = prefix_rec_table.index_select(0, bucket.family_indices) - completion_conv, completion_rec = _couple_parent_states( - completion_conv, completion_rec - ) - completion_out, _, _ = run_gdn_bucket( - bucket, - (completion_qkv, completion_beta, completion_g), - (completion_conv, completion_rec), - gdn=gdn, - output_final_state=False, - ) - completion_out = _add_autograd_dependency(completion_out, cp_dependency) - recurrent_output = _scatter_bucket_recurrent_output( - recurrent_output, bucket, completion_out - ) - projected, out_bias = _project_cp_gdn_output( gdn, recurrent_output, @@ -1065,8 +960,8 @@ def _run_cp_planned_prefixes_and_completions( plan, group=group, output_layout=output_layout, + dependency=cp_dependency, ) - projected = _add_autograd_dependency(projected, cp_dependency) return projected, out_bias @@ -1659,12 +1554,6 @@ def _local_layout_token_count_for_hidden( return (real_count + _tp_world_size(projection) - 1) // _tp_world_size(projection) -def _attention_original_shape_from_plan( - hidden_states: Tensor, plan: GdnRankExecutionPlan -) -> tuple[int, int, int]: - return (int(plan.attention_token_count), 1, int(hidden_states.shape[-1])) - - def _restore_hidden_from_cp_flat( flat: Tensor, original_shape: tuple[int, int, int] ) -> Tensor: @@ -1922,6 +1811,7 @@ def _project_cp_gdn_output( *, group: Any, output_layout: Literal["attention", "gdn"], 
+ dependency: Tensor | None = None, ) -> tuple[Tensor, Tensor | None]: batch_size, seq_len, _, _ = recurrent_output.shape token_uids = ( @@ -1933,6 +1823,8 @@ def _project_cp_gdn_output( norm_out = _apply_gated_rms_norm(gdn, recurrent_output, gate) norm_out = norm_out.reshape(batch_size, seq_len, _local_value_dim(gdn)) norm_out = norm_out.transpose(0, 1).contiguous() + if dependency is not None: + norm_out = _add_autograd_dependency(norm_out, dependency) if token_uids is not None: token_uids = _replicated_layout_token_uids(plan, "gdn", hidden_states=norm_out) _attach_trace_token_uids(norm_out, token_uids) @@ -2271,6 +2163,36 @@ def _local_value_dim(gdn: Any) -> int: return _local_value_heads(gdn) * int(gdn.value_head_dim) +def _prepare_dense_recurrent_inputs( + qkv: Tensor, + beta: Tensor, + recurrent_g: Tensor, + *, + key_heads: int, + value_heads: int, + key_dim: int, + value_dim: int, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + key_channels = int(key_heads) * int(key_dim) + value_channels = int(value_heads) * int(value_dim) + query = qkv[..., :key_channels].reshape(*qkv.shape[:2], key_heads, key_dim) + key = qkv[..., key_channels : 2 * key_channels].reshape( + *qkv.shape[:2], + key_heads, + key_dim, + ) + value = qkv[..., 2 * key_channels : 2 * key_channels + value_channels].reshape( + *qkv.shape[:2], + value_heads, + value_dim, + ) + repeat = int(value_heads) // int(key_heads) + if repeat != 1: + query = query.repeat_interleave(repeat, dim=2) + key = key.repeat_interleave(repeat, dim=2) + return query, key, value, beta, recurrent_g + + def _scatter_bucket_recurrent_output( output: Tensor, bucket: GdnSegmentBucketPlan, bucket_output: Tensor ) -> Tensor: @@ -2289,269 +2211,6 @@ def _bucket_output_mask(bucket: GdnSegmentBucketPlan) -> Tensor: return bucket.real_mask if output_mask is None else output_mask -def _materialize_indexed_family_state_table( - *, - plan: GdnRankExecutionPlan, - family_chunks: list[Tensor], - state_chunks: list[Tensor], - 
zero_state: Tensor, -) -> Tensor: - table = zero_state.detach() - if not state_chunks: - return table.requires_grad_(True) - values = torch.cat(state_chunks, dim=0) - family_indices = torch.cat(family_chunks, dim=0) - return table.index_copy(0, family_indices, values) - - -def _materialize_ordered_family_state_table( - *, - family_chunks: list[Tensor], - state_chunks: list[Tensor], - zero_state: Tensor, -) -> Tensor: - if len(family_chunks) != len(state_chunks): - raise RuntimeError("family and state chunk counts must match") - table = zero_state.detach().requires_grad_(True) - for family_indices, states in zip(family_chunks, state_chunks, strict=True): - table = table.index_copy(0, family_indices, states) - return table - - -def _replace_indexed_family_states( - table: Tensor, - *, - family_chunks: list[Tensor], - state_chunks: list[Tensor], -) -> Tensor: - if not state_chunks: - return table - return table.index_copy( - 0, - torch.cat(family_chunks, dim=0), - torch.cat(state_chunks, dim=0), - ) - - -def _exchange_parent_state_rows( - conv_table: Tensor, - rec_table: Tensor, - *, - transfers: tuple[GdnParentStateTransferPlan, ...], - group: Any, -) -> tuple[Tensor, Tensor, Tensor]: - if not transfers: - return conv_table, rec_table, _empty_autograd_dependency(conv_table) - conv_table, rec_table = _ParentStateExchange.apply( - conv_table, rec_table, transfers, group - ) - return conv_table, rec_table, _make_autograd_dependency(conv_table, rec_table) - - -def _exchange_remote_prefix_tail_streams( - qkv: Tensor, - beta: Tensor, - recurrent_g: Tensor, - *, - plan: GdnRankExecutionPlan, - group: Any, -) -> tuple[Tensor, Tensor, Tensor]: - from .layout import exchange_rank_tensor_all_to_all - - if plan.remote_prefix_tail_exchange is None: - return ( - qkv.new_empty((0, int(qkv.shape[-1]))), - beta.new_empty((0, int(beta.shape[-1]))), - recurrent_g.new_empty((0, int(recurrent_g.shape[-1]))), - ) - if plan.remote_prefix_tail_backward_exchange is None: - raise 
ValueError("remote prefix-tail exchange requires a backward plan") - qkv_flat = qkv.reshape(-1, int(qkv.shape[-1])) - beta_flat = beta.reshape(-1, int(beta.shape[-1])) - g_flat = recurrent_g.reshape(-1, int(recurrent_g.shape[-1])) - kwargs = { - "plan": plan.remote_prefix_tail_exchange, - "rank": plan.cp_rank, - "group": group, - "backward_plan": plan.remote_prefix_tail_backward_exchange, - } - return ( - exchange_rank_tensor_all_to_all(qkv_flat, **kwargs), - exchange_rank_tensor_all_to_all(beta_flat, **kwargs), - exchange_rank_tensor_all_to_all(g_flat, **kwargs), - ) - - -class _ParentStateExchange(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - conv_table: Tensor, - rec_table: Tensor, - transfers: tuple[GdnParentStateTransferPlan, ...], - group: Any, - ) -> tuple[Tensor, Tensor]: - ctx.group = group - ctx.transfers = transfers - ctx.save_for_backward(conv_table, rec_table) - return ( - _exchange_parent_state_tensor_forward( - conv_table, - transfers, - group=group, - ), - _exchange_parent_state_tensor_forward( - rec_table, - transfers, - group=group, - ), - ) - - @staticmethod - def backward( - ctx: Any, *grad_outputs: Tensor | None - ) -> tuple[Tensor | None, Tensor | None, None, None]: - grad_conv, grad_rec = grad_outputs - conv_ref, rec_ref = ctx.saved_tensors - return ( - _exchange_parent_state_tensor_backward( - _zero_if_none(grad_conv, conv_ref), - ctx.transfers, - group=ctx.group, - ), - _exchange_parent_state_tensor_backward( - _zero_if_none(grad_rec, rec_ref), - ctx.transfers, - group=ctx.group, - ), - None, - None, - ) - - -def _exchange_parent_state_tensor_forward( - table: Tensor, - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - group: Any, -) -> Tensor: - rank = torch.distributed.get_rank(group) # ty: ignore[possibly-missing-attribute] - output = table.clone() - recvs = _exchange_parent_state_rows_all_to_all( - table, transfers, rank=rank, reverse=False, group=group - ) - for transfer, rows in recvs: - index = 
_parent_state_index_tensor(transfer, device=table.device) - output.index_copy_(0, index, rows) - return output - - -def _exchange_parent_state_tensor_backward( - grad_output: Tensor, - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - group: Any, -) -> Tensor: - rank = torch.distributed.get_rank(group) # ty: ignore[possibly-missing-attribute] - grad_input = grad_output.clone() - for transfer in transfers: - if transfer.dest_rank != rank: - continue - index = _parent_state_index_tensor(transfer, device=grad_output.device) - grad_input.index_fill_(0, index, 0) - recvs = _exchange_parent_state_rows_all_to_all( - grad_output, transfers, rank=rank, reverse=True, group=group - ) - for transfer, rows in recvs: - index = _parent_state_index_tensor(transfer, device=grad_output.device) - grad_input.index_add_(0, index, rows) - return grad_input - - -def _zero_if_none(grad: Tensor | None, reference: Tensor) -> Tensor: - if grad is None: - return reference.new_zeros(reference.shape) - return grad.contiguous() - - -def _exchange_parent_state_rows_all_to_all( - table: Tensor, - transfers: tuple[GdnParentStateTransferPlan, ...], - *, - rank: int, - reverse: bool, - group: Any, -) -> list[tuple[GdnParentStateTransferPlan, Tensor]]: - world_size = torch.distributed.get_world_size(group) # ty: ignore[possibly-missing-attribute] - send_counts = [0 for _ in range(world_size)] - recv_counts = [0 for _ in range(world_size)] - send_pieces: list[Tensor] = [] - for peer_rank in range(world_size): - for transfer in transfers: - send_rank = transfer.dest_rank if reverse else transfer.source_rank - recv_rank = transfer.source_rank if reverse else transfer.dest_rank - if send_rank == recv_rank: - continue - row_count = len(transfer.family_indices) - if rank == send_rank and peer_rank == recv_rank: - index = _parent_state_index_tensor(transfer, device=table.device) - send_pieces.append(table.index_select(0, index).contiguous()) - send_counts[peer_rank] += row_count - if rank == 
recv_rank and peer_rank == send_rank: - recv_counts[peer_rank] += row_count - - trailing_shape = tuple(table.shape[1:]) - send_buffer = ( - torch.cat(send_pieces, dim=0) - if send_pieces - else table.new_empty((0, *trailing_shape)) - ) - recv_buffer = table.new_empty((sum(recv_counts), *trailing_shape)) - work = torch.distributed.all_to_all_single( # ty: ignore[possibly-missing-attribute] - recv_buffer, - send_buffer, - output_split_sizes=recv_counts, - input_split_sizes=send_counts, - group=group, - async_op=True, - ) - work.wait() - - recvs: list[tuple[GdnParentStateTransferPlan, Tensor]] = [] - offset = 0 - for peer_rank, count in enumerate(recv_counts): - peer_end = offset + count - for transfer in transfers: - send_rank = transfer.dest_rank if reverse else transfer.source_rank - recv_rank = transfer.source_rank if reverse else transfer.dest_rank - if send_rank == recv_rank: - continue - if rank != recv_rank or peer_rank != send_rank: - continue - rows = len(transfer.family_indices) - recvs.append((transfer, recv_buffer[offset : offset + rows])) - offset += rows - if offset != peer_end: - raise RuntimeError( - "parent-state exchange unpack mismatch: " - f"rank={rank} peer={peer_rank} consumed={offset} expected={peer_end}" - ) - return recvs - - -def _parent_state_index_tensor( - transfer: GdnParentStateTransferPlan, - *, - device: torch.device, -) -> Tensor: - if ( - transfer.family_indices_tensor is not None - and transfer.family_indices_tensor.device == device - ): - return transfer.family_indices_tensor - return torch.tensor(transfer.family_indices, device=device, dtype=torch.long) - - def run_gdn_bucket( bucket: GdnSegmentBucketPlan, projected_streams: tuple[Tensor, Tensor, Tensor], @@ -2597,14 +2256,17 @@ def run_gdn_bucket( conv_output_final_state = output_final_state chain_conv_final: Tensor | None = None + chain_gradient_dependency: Tensor | None = None if recurrent_cp: - conv_initial, chain_conv_final = _chain_conv_initial_and_final( - qkv, - 
bucket.cu_seqlens_cpu, - bucket.lengths_by_rank_cpu, - conv_initial, - group=group, - output_final_state=output_final_state, + conv_initial, chain_conv_final, chain_gradient_dependency = ( + _chain_conv_initial_and_final( + qkv, + bucket.cu_seqlens_cpu, + bucket.lengths_by_rank_cpu, + conv_initial, + group=group, + output_final_state=output_final_state, + ) ) conv_output_final_state = False @@ -2618,15 +2280,31 @@ def run_gdn_bucket( if recurrent_cp: conv_final = chain_conv_final - query, key, value, beta, recurrent_g = _prepare_packed_recurrent_inputs_fused( - qkv, - beta, - recurrent_g, - key_heads=_local_key_heads(gdn), - value_heads=_local_value_heads(gdn), - key_dim=int(gdn.key_head_dim), - value_dim=int(gdn.value_head_dim), - ) + dense_local_bucket = not recurrent_cp and _bucket_has_uniform_lengths(bucket) + if dense_local_bucket: + query, key, value, beta, recurrent_g = _prepare_dense_recurrent_inputs( + qkv.reshape(batch_size, int(bucket.length), int(qkv.shape[-1])), + beta.reshape(batch_size, int(bucket.length), int(beta.shape[-1])), + recurrent_g.reshape( + batch_size, + int(bucket.length), + int(recurrent_g.shape[-1]), + ), + key_heads=_local_key_heads(gdn), + value_heads=_local_value_heads(gdn), + key_dim=int(gdn.key_head_dim), + value_dim=int(gdn.value_head_dim), + ) + else: + query, key, value, beta, recurrent_g = _prepare_packed_recurrent_inputs_fused( + qkv, + beta, + recurrent_g, + key_heads=_local_key_heads(gdn), + value_heads=_local_value_heads(gdn), + key_dim=int(gdn.key_head_dim), + value_dim=int(gdn.value_head_dim), + ) if gdn.use_qk_l2norm: query = _l2norm(query.contiguous()) key = _l2norm(key.contiguous()) @@ -2657,8 +2335,27 @@ def run_gdn_bucket( initial_state=recurrent_initial, output_final_state=output_final_state, use_qk_l2norm_in_kernel=False, - cu_seqlens=bucket.cu_seqlens, - ) + cu_seqlens=None if dense_local_bucket else bucket.cu_seqlens, + ) + if dense_local_bucket: + recurrent_out = recurrent_out.reshape( + 1, + token_count, + 
int(recurrent_out.shape[-2]), + int(recurrent_out.shape[-1]), + ) + if chain_gradient_dependency is not None: + recurrent_out = _add_autograd_dependency( + recurrent_out, + chain_gradient_dependency, + ) + if conv_final is not None: + conv_final = _add_autograd_dependency(conv_final, chain_gradient_dependency) + if recurrent_final is not None: + recurrent_final = _add_autograd_dependency( + recurrent_final, + chain_gradient_dependency, + ) return recurrent_out, conv_final, recurrent_final @@ -2670,15 +2367,22 @@ def _chain_conv_initial_and_final( *, group: Any, output_final_state: bool, -) -> tuple[Tensor, Tensor | None]: +) -> tuple[Tensor, Tensor | None, Tensor]: if group is None: raise ValueError("CP chain conv state requires a process group") if not dist.is_available() or not dist.is_initialized(): # ty: ignore[possibly-missing-attribute] raise RuntimeError("torch.distributed must be initialized for CP chain conv") - parent_initial = _AllReduceGradient.apply(parent_initial, group) + parent_initial, gradient_dependency = _AllReduceGradient.apply( + parent_initial, + group, + ) tail_width = int(parent_initial.shape[-1]) if tail_width <= 0: - return parent_initial, parent_initial if output_final_state else None + return ( + parent_initial, + parent_initial if output_final_state else None, + gradient_dependency, + ) if lengths_by_rank_cpu is None: raise ValueError("CP chain conv requires static all-rank bucket lengths") if cu_seqlens_cpu.device.type != "cpu" or lengths_by_rank_cpu.device.type != "cpu": @@ -2705,7 +2409,7 @@ def _chain_conv_initial_and_final( if output_final_state else None ) - return conv_initial, conv_final + return conv_initial, conv_final, gradient_dependency def _local_packed_conv_tail( @@ -2782,14 +2486,20 @@ def backward(ctx: Any, *grad_outputs: Tensor) -> tuple[Tensor, None]: class _AllReduceGradient(torch.autograd.Function): @staticmethod - def forward(ctx: Any, tensor: Tensor, group: Any) -> Tensor: + def forward(ctx: Any, tensor: Tensor, 
group: Any) -> tuple[Tensor, Tensor]: ctx.group = group - return tensor + ctx.save_for_backward(tensor) + return tensor, tensor.new_zeros(()) @staticmethod - def backward(ctx: Any, *grad_outputs: Tensor) -> tuple[Tensor, None]: - (grad_output,) = grad_outputs - grad_input = grad_output.contiguous() + def backward(ctx: Any, *grad_outputs: Tensor | None) -> tuple[Tensor, None]: + grad_output, _grad_dependency = grad_outputs + (reference,) = ctx.saved_tensors + grad_input = ( + reference.new_zeros(reference.shape) + if grad_output is None + else grad_output.contiguous() + ) dist.all_reduce( # ty: ignore[possibly-missing-attribute] grad_input, op=dist.ReduceOp.SUM, # ty: ignore[possibly-missing-attribute] diff --git a/src/art/megatron/lora.py b/src/art/megatron/lora.py index 4cea46b2a..62e7c8435 100644 --- a/src/art/megatron/lora.py +++ b/src/art/megatron/lora.py @@ -1,9 +1,14 @@ -from collections.abc import Sequence +from collections.abc import Iterator, Sequence +from contextlib import contextmanager +import contextvars +from dataclasses import dataclass, replace +import functools +import importlib import json import math import os import re -from typing import Any, Literal, NamedTuple, cast +from typing import Any, Callable, Literal, NamedTuple, TypeVar, cast from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.core import parallel_state as ps @@ -22,9 +27,7 @@ ) from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.moe.experts import TEGroupedMLP -from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.transformer_layer import TransformerLayer -from pydantic import BaseModel, ConfigDict import torch from .kernels.cute_grouped_lora_quack import ( @@ -42,6 +45,8 @@ ShardDomain = Literal["tp", "expert_tp"] GradSyncDomain = Literal["tp_default", "expert_tp"] GradSyncOp = Literal["none", "sum", "avg"] +LoraSlotKind = Literal["checkpoint", "lora"] +_F = 
TypeVar("_F", bound=Callable[..., Any]) TP_DEFAULT_GRAD_SYNC_DOMAIN: GradSyncDomain = "tp_default" EXPERT_TP_GRAD_SYNC_DOMAIN: GradSyncDomain = "expert_tp" @@ -50,11 +55,114 @@ GRAD_SYNC_OP_AVG: GradSyncOp = "avg" -class LoRAParallelSpec(BaseModel): - # This spec only describes TP / expert-TP behavior. - # DP/CP vs expert-DP behavior is selected separately via `allreduce`. - model_config = ConfigDict(frozen=True) +@dataclass(frozen=True) +class LoRASlotRef: + kind: LoraSlotKind + name: str | None + +_CURRENT_LORA_SLOT: contextvars.ContextVar[LoRASlotRef | None] = contextvars.ContextVar( + "art_megatron_current_lora_slot", default=None +) + + +@contextmanager +def use_lora_slot(ref: LoRASlotRef | None) -> Iterator[None]: + token = _CURRENT_LORA_SLOT.set(ref) + try: + yield + finally: + _CURRENT_LORA_SLOT.reset(token) + + +def _with_captured_lora_slot(function: _F) -> _F: + context = _CURRENT_LORA_SLOT.get() + + @functools.wraps(function) + def wrapped(*args: Any, **kwargs: Any) -> Any: + token = _CURRENT_LORA_SLOT.set(context) + try: + return function(*args, **kwargs) + finally: + _CURRENT_LORA_SLOT.reset(token) + + return cast(_F, wrapped) + + +def _patch_function_once(module: Any, name: str, wrapper: Callable[[_F], _F]) -> None: + original = getattr(module, name, None) + if original is None or getattr(original, "_art_lora_slot_context_patch", False): + return + patched = wrapper(original) + setattr(patched, "_art_lora_slot_context_patch", True) + setattr(module, name, patched) + + +def install_lora_checkpoint_context_hooks() -> None: + """Preserve the selected dynamic LoRA slot across activation recompute.""" + + def wrap_checkpoint(original: _F, function_index: int) -> _F: + @functools.wraps(original) + def checkpoint(*args: Any, **kwargs: Any) -> Any: + if len(args) > function_index: + args = ( + *args[:function_index], + _with_captured_lora_slot(args[function_index]), + *args[function_index + 1 :], + ) + elif "function" in kwargs: + kwargs = { + **kwargs, + 
"function": _with_captured_lora_slot(kwargs["function"]), + } + elif "forward_func" in kwargs: + kwargs = { + **kwargs, + "forward_func": _with_captured_lora_slot(kwargs["forward_func"]), + } + else: + raise TypeError("checkpoint wrapper could not find callable argument") + return original(*args, **kwargs) + + return cast(_F, checkpoint) + + def patch(target: str, name: str, function_index: int) -> None: + try: + module_name, _, attr_path = target.partition(":") + target_obj = importlib.import_module(module_name) + for attr in attr_path.split(".") if attr_path else (): + target_obj = getattr(target_obj, attr, None) + if target_obj is None: + return + _patch_function_once( + target_obj, + name, + lambda original: wrap_checkpoint(original, function_index), + ) + except Exception: + pass + + for target, name, function_index in ( + ("torch.utils.checkpoint", "checkpoint", 0), + ("megatron.core.tensor_parallel", "checkpoint", 0), + ("megatron.core.tensor_parallel.random", "checkpoint", 0), + ( + "megatron.core.tensor_parallel.random:CheckpointWithoutOutput", + "checkpoint", + 1, + ), + ("megatron.core.transformer.transformer_block", "te_checkpoint", 0), + ("transformer_engine.pytorch.distributed", "checkpoint", 0), + ): + patch(target, name, function_index) + + +install_lora_checkpoint_context_hooks() + + +@dataclass(frozen=True) +class LoRAParallelSpec: + # This only describes TP / expert-TP; DP/CP vs expert-DP is selected by `allreduce`. 
shard_domain: ShardDomain = "tp" sharded: bool = False shard_dim: int | None = None @@ -72,10 +180,7 @@ class LoraShardMeta(NamedTuple): @property def numel(self) -> int: - total = 1 - for dim in self.shape: - total *= dim - return total + return math.prod(self.shape) class _LoraPublishTemplate(NamedTuple): @@ -123,14 +228,6 @@ def _get_shard_rank(domain: ShardDomain) -> int: return group.rank() -def _get_shard_group(domain: ShardDomain) -> Any | None: - if not _distributed_initialized(): - return None - if domain == "tp": - return ps.get_tensor_model_parallel_group() - return ps.get_expert_tensor_parallel_group(check_initialized=False) - - def _dtype_name(dtype: torch.dtype) -> str: return str(dtype).removeprefix("torch.") @@ -242,13 +339,8 @@ def _set_lora_parallel_metadata( setattr(param, "lora_tp_shard_dim", parallel_spec.shard_dim) setattr(param, "grad_sync_domain", parallel_spec.grad_sync_domain) setattr(param, "grad_sync_op", parallel_spec.grad_sync_op) - # Megatron DDP routing flag: - # - allreduce=True: sync with regular DP/CP replicas. - # - allreduce=False: sync with expert-DP replicas. - # TP / expert-TP replica handling is controlled by grad_sync_* metadata. setattr(param, "allreduce", allreduce) - # Megatron's native TP finalize path consumes this attr. setattr( param, "average_gradients_across_tp_domain", @@ -259,16 +351,12 @@ def _set_lora_parallel_metadata( ), ) - # Megatron optimizer and checkpoint logic rely on tensor model-parallel metadata - # to distinguish true shards from TP-duplicate params. if parallel_spec.sharded: shard_dim = parallel_spec.shard_dim if shard_dim is None: raise ValueError("LoRAParallelSpec.shard_dim must be set when sharded=True") setattr(param, "tensor_model_parallel", True) setattr(param, "partition_dim", _normalize_axis(shard_dim, param.ndim)) - # stride > 1 means the dim is split into blocks and each tp rank holds a shard of the block - # this might happen for fused e.g. 
gate_(up|proj), but loras are individual per module setattr(param, "partition_stride", 1) else: setattr(param, "tensor_model_parallel", False) @@ -307,6 +395,59 @@ def _exported_shard_dim(param: torch.nn.Parameter) -> int: return 1 - axis +def _copy_lora_param_metadata( + source: torch.nn.Parameter, + target: torch.nn.Parameter, +) -> None: + for name in ( + "lora_shard_domain", + "lora_tp_sharded", + "lora_tp_replicated", + "lora_tp_shard_dim", + "grad_sync_domain", + "grad_sync_op", + "allreduce", + "average_gradients_across_tp_domain", + "tensor_model_parallel", + "partition_dim", + "partition_stride", + "lora_tp_shard_strategy", + "lora_tp_component_sizes", + ): + if hasattr(source, name): + setattr(target, name, getattr(source, name)) + setattr(target, "_art_dynamic_lora_slot", True) + + +class LoRASlot(torch.nn.Module): + def __init__( + self, + *, + ref: LoRASlotRef, + a_t: torch.Tensor, + b_t: torch.Tensor, + alpha: float, + a_template: torch.nn.Parameter, + b_template: torch.nn.Parameter, + requires_grad: bool, + ) -> None: + super().__init__() + self.ref = ref + self.alpha = float(alpha) + self.A_T = torch.nn.Parameter(a_t.detach().clone(), requires_grad=requires_grad) + self.B_T = torch.nn.Parameter(b_t.detach().clone(), requires_grad=requires_grad) + _copy_lora_param_metadata(a_template, self.A_T) + _copy_lora_param_metadata(b_template, self.B_T) + + @property + def rank(self) -> int: + return int(self.A_T.shape[-1]) + + @property + def scale(self) -> float: + return self.alpha / self.rank + + class LoRA(torch.nn.Module): def __init__( self, @@ -327,7 +468,12 @@ def __init__( "adapter_model_prefix must contain the '{expert}' format placeholder if num_local_experts > 1" ) self.adapter_model_prefix = adapter_model_prefix + self.alpha = float(alpha) + self.in_features = int(in_features) + self.out_features = int(out_features) self.scale = alpha / rank + self._slot_modules = torch.nn.ModuleDict() + self._slot_keys: dict[LoRASlotRef, str] = {} self.A_T = 
torch.nn.Parameter( torch.zeros( num_local_experts, in_features, rank, dtype=dtype, device=device @@ -362,7 +508,11 @@ def _broadcast_if_replicated(self, param: torch.nn.Parameter) -> None: world_size = _get_shard_world_size(domain) if world_size <= 1: return - group = _get_shard_group(domain) + group = ( + ps.get_tensor_model_parallel_group() + if domain == "tp" + else ps.get_expert_tensor_parallel_group(check_initialized=False) + ) if group is None: raise RuntimeError( f"{self.adapter_model_prefix}: missing process group for replicated parameter domain={domain}" @@ -395,43 +545,104 @@ def _expected_weight_keys(self, suffix: str) -> list[str]: ] return [f"{self.adapter_model_prefix}.{suffix}.weight"] + def load_lora_slot( + self, + ref: LoRASlotRef, + adapter_model: dict[str, torch.Tensor], + *, + alpha: float = LORA_ALPHA, + requires_grad: bool, + ) -> bool: + if ref.name is None: + raise ValueError("base-model slot refs do not own LoRA tensors") + weights = self._adapter_weights(adapter_model, require=False) + if weights is None: + return False + a_t = self._localized_weight(weights[0], into=self.A_T) + b_t = self._localized_weight(weights[1], into=self.B_T) + slot_key = self._slot_keys.get(ref) + if slot_key is None: + slot_key = f"slot_{len(self._slot_keys)}" + self._slot_keys[ref] = slot_key + elif self._has_live_slot_grads(ref): + raise RuntimeError( + f"Cannot overwrite live LoRA slot {ref.kind}:{ref.name} for " + f"{self.adapter_model_prefix}; clear grads/backward graph first." 
+ ) + self._slot_modules[slot_key] = LoRASlot( + ref=ref, + a_t=a_t, + b_t=b_t, + alpha=alpha, + a_template=self.A_T, + b_template=self.B_T, + requires_grad=requires_grad, + ) + return True + + def lora_slot_params(self, ref: LoRASlotRef) -> list[torch.nn.Parameter]: + slot = self._slot(ref) + if slot is None: + return [] + return [slot.A_T, slot.B_T] + + def _slot(self, ref: LoRASlotRef) -> LoRASlot | None: + key = self._slot_keys.get(ref) + if key is None: + return None + return cast(LoRASlot, self._slot_modules[key]) + + def _has_live_slot_grads(self, ref: LoRASlotRef) -> bool: + slot = self._slot(ref) + return slot is not None and any( + param.grad is not None for param in (slot.A_T, slot.B_T) + ) + def load_lora(self, adapter_model: dict[str, torch.Tensor]) -> None: - missing_keys = [ + weights = self._adapter_weights(adapter_model, require=True) + assert weights is not None + self._load_weight(weights[0], into=self.A_T) + self._load_weight(weights[1], into=self.B_T) + + def _adapter_weights( + self, + adapter_model: dict[str, torch.Tensor], + *, + require: bool, + ) -> tuple[torch.Tensor, torch.Tensor] | None: + all_keys = [ key for suffix in ("lora_A", "lora_B") for key in self._expected_weight_keys(suffix) - if key not in adapter_model ] - if missing_keys: + missing = [key for key in all_keys if key not in adapter_model] + if len(missing) == len(all_keys) and not require: + return None + if missing: + state = "Missing" if require else "Incomplete" raise KeyError( - f"Missing LoRA adapter keys for {self.adapter_model_prefix}: {sorted(missing_keys)}" + f"{state} LoRA adapter keys for {self.adapter_model_prefix}: " + f"{sorted(missing)}" ) - self.load_weights( - adapter_model, - suffix="lora_A", - into=self.A_T, - ) - self.load_weights( - adapter_model, - suffix="lora_B", - into=self.B_T, + return ( + self._adapter_weight(adapter_model, suffix="lora_A"), + self._adapter_weight(adapter_model, suffix="lora_B"), ) - def load_weights( + def _adapter_weight( self, 
adapter_model: dict[str, torch.Tensor], *, suffix: str, - into: torch.nn.Parameter, - ) -> None: + ) -> torch.Tensor: keys = self._expected_weight_keys(suffix) if self.num_local_experts > 1: - weight = torch.stack([adapter_model[key].T for key in keys]) - else: - weight = adapter_model[keys[0]].T - self.load_weight(weight, into=into) + return torch.stack([adapter_model[key].T for key in keys]) + return adapter_model[keys[0]].T - def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None: + def _localized_weight( + self, weight: torch.Tensor, *, into: torch.nn.Parameter + ) -> torch.Tensor: domain = into.lora_shard_domain # ty: ignore[unresolved-attribute] if into.lora_tp_sharded: # ty: ignore[unresolved-attribute] axis = into.lora_tp_shard_dim # ty: ignore[unresolved-attribute] @@ -470,11 +681,10 @@ def load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None raise ValueError( f"{self.adapter_model_prefix}: unsupported shard strategy={strategy}" ) - elif tuple(weight.shape) != tuple(into.shape): - raise ValueError( - f"{self.adapter_model_prefix}: unsharded load shape mismatch, got {tuple(weight.shape)} " - f"expected {tuple(into.shape)}" - ) + return weight.contiguous() + + def _load_weight(self, weight: torch.Tensor, *, into: torch.nn.Parameter) -> None: + weight = self._localized_weight(weight, into=into) if tuple(weight.shape) != tuple(into.shape): raise ValueError( f"{self.adapter_model_prefix}: sharded load shape mismatch, got {tuple(weight.shape)} " @@ -575,9 +785,26 @@ def sharded_lora_grad_dict(self) -> dict[str, torch.Tensor]: grads[key] = local_grad.T return grads + def active_lora_tensors( + self, + ) -> tuple[torch.Tensor, torch.Tensor, float] | None: + ref = _CURRENT_LORA_SLOT.get() + if ref is None: + return self.A_T, self.B_T, self.scale + if ref.name is None: + return None + slot = self._slot(ref) + if slot is None: + return None + return slot.A_T, slot.B_T, slot.scale + def forward( self, x: torch.Tensor, 
tokens_per_expert: list[int] | torch.Tensor | None = None ) -> torch.Tensor: + active = self.active_lora_tensors() + if active is None: + return x.new_zeros((*x.shape[:-1], self.out_features)) + a_t, b_t, scale = active if tokens_per_expert is not None: assert self.num_local_experts > 1, ( "tokens_per_expert is only supported if num_local_experts > 1" @@ -586,12 +813,10 @@ def forward( if isinstance(bsz, list): bsz = torch.tensor(bsz, dtype=torch.int64, device="cpu") if x.shape[0] == 0: - return x.new_zeros((x.shape[0], self.B_T.shape[-1])) - return quack_grouped_lora(x, self.A_T, self.B_T, bsz, scale=self.scale) - out = (x @ self.A_T) @ self.B_T - if self.scale == 1.0: - return out - return out * self.scale + return x.new_zeros((*x.shape[:-1], self.out_features)) + return quack_grouped_lora(x, a_t, b_t, bsz, scale=scale) + out = (x @ a_t) @ b_t + return out if scale == 1.0 else out * scale class LoRAPublishPlanner: @@ -667,52 +892,47 @@ def _metadata_for_template( template: _LoraPublishTemplate, adapter_model: dict[str, torch.Tensor], ) -> list[LoraShardMeta]: - if template.num_local_experts > 1: - return self._expert_metadata_for_template(template, adapter_model) - return self._dense_metadata_for_template(template, adapter_model) - - def _dense_metadata_for_template( - self, - template: _LoraPublishTemplate, - adapter_model: dict[str, torch.Tensor], - ) -> list[LoraShardMeta]: - tp_ranks = self._dense_tp_ranks() shard_ranks = range(template.shard_world_size) if template.sharded else (0,) + if template.num_local_experts <= 1: + tp_ranks = ( + _process_group_ranks(ps.get_tensor_model_parallel_group()) + if _distributed_initialized() + else (0,) + ) + owners = [ + ( + f"{template.adapter_model_prefix}.{template.suffix}", + tp_ranks[shard_rank], + shard_rank, + ) + for shard_rank in shard_ranks + ] + else: + ep_world_size = 1 + if _distributed_initialized(): + ep_world_size = ps.get_expert_model_parallel_world_size() + owners = [ + ( + 
f"{template.adapter_model_prefix.format(expert=expert)}.{template.suffix}", + self._expert_owner_rank(ep_rank, shard_rank), + shard_rank, + ) + for ep_rank in range(ep_world_size) + for local_expert in range(template.num_local_experts) + for expert in [ep_rank * template.num_local_experts + local_expert] + for shard_rank in shard_ranks + ] return [ self._make_metadata( template, - key=f"{template.adapter_model_prefix}.{template.suffix}", - owner_rank=tp_ranks[shard_rank], + key=key, + owner_rank=owner_rank, shard_rank=shard_rank, adapter_model=adapter_model, ) - for shard_rank in shard_ranks + for key, owner_rank, shard_rank in owners ] - def _expert_metadata_for_template( - self, - template: _LoraPublishTemplate, - adapter_model: dict[str, torch.Tensor], - ) -> list[LoraShardMeta]: - ep_world_size = self._expert_model_world_size() - shard_ranks = range(template.shard_world_size) if template.sharded else (0,) - metadata: list[LoraShardMeta] = [] - for ep_rank in range(ep_world_size): - for local_expert in range(template.num_local_experts): - expert = ep_rank * template.num_local_experts + local_expert - key = f"{template.adapter_model_prefix.format(expert=expert)}.{template.suffix}" - for shard_rank in shard_ranks: - metadata.append( - self._make_metadata( - template, - key=key, - owner_rank=self._expert_owner_rank(ep_rank, shard_rank), - shard_rank=shard_rank, - adapter_model=adapter_model, - ) - ) - return metadata - @staticmethod def _make_metadata( template: _LoraPublishTemplate, @@ -722,6 +942,18 @@ def _make_metadata( shard_rank: int, adapter_model: dict[str, torch.Tensor], ) -> LoraShardMeta: + manifest: dict[str, Any] = { + "sharded": template.sharded, + "shard_world_size": template.shard_world_size if template.sharded else 1, + "shard_rank": shard_rank if template.sharded else 0, + } + if template.sharded: + manifest["export_shard_dim"] = template.export_shard_dim + manifest["export_shard_strategy"] = ( + template.export_shard_strategy or "uniform" + ) + 
if template.component_sizes: + manifest["component_sizes"] = list(template.component_sizes) return LoraShardMeta( key=key, owner_rank=owner_rank, @@ -731,22 +963,10 @@ def _make_metadata( if key in adapter_model else template.dtype_name ), - manifest=_publish_manifest(template, shard_rank=shard_rank), + manifest=manifest, block=_block_for_key(key), ) - @staticmethod - def _dense_tp_ranks() -> tuple[int, ...]: - if not _distributed_initialized(): - return (0,) - return _process_group_ranks(ps.get_tensor_model_parallel_group()) - - @staticmethod - def _expert_model_world_size() -> int: - if not _distributed_initialized(): - return 1 - return ps.get_expert_model_parallel_world_size() - @staticmethod def _expert_owner_rank(ep_rank: int, shard_rank: int) -> int: if not _distributed_initialized(): @@ -793,24 +1013,6 @@ def _exported_param_shape(module: LoRA, param: torch.nn.Parameter) -> tuple[int, return tuple(int(dim) for dim in param.T.shape) -def _publish_manifest( - template: _LoraPublishTemplate, - *, - shard_rank: int, -) -> dict[str, Any]: - manifest: dict[str, Any] = { - "sharded": template.sharded, - "shard_world_size": template.shard_world_size if template.sharded else 1, - "shard_rank": shard_rank if template.sharded else 0, - } - if template.sharded: - manifest["export_shard_dim"] = template.export_shard_dim - manifest["export_shard_strategy"] = template.export_shard_strategy or "uniform" - if template.component_sizes: - manifest["component_sizes"] = list(template.component_sizes) - return manifest - - @torch.compiler.disable def _expert_grouped_lora_forward( lora: LoRA, @@ -834,15 +1036,110 @@ def _expert_grouped_lora_dual_forward( counts = torch.tensor(counts, dtype=torch.int64, device="cpu") if x.shape[0] == 0: return x.new_zeros((x.shape[0], module.linear_fc1.out_features)) + gate = module.gate_lora.active_lora_tensors() + up = module.up_lora.active_lora_tensors() + if gate is None or up is None: + return torch.cat( + [ + module.gate_lora(x, 
tokens_per_expert=counts), + module.up_lora(x, tokens_per_expert=counts), + ], + dim=-1, + ) + gate_a_t, gate_b_t, gate_scale = gate + up_a_t, up_b_t, up_scale = up return quack_grouped_lora_dual( x, - module.gate_lora.A_T, - module.gate_lora.B_T, - module.up_lora.A_T, - module.up_lora.B_T, + gate_a_t, + gate_b_t, + up_a_t, + up_b_t, counts, - scale_gate=module.gate_lora.scale, - scale_up=module.up_lora.scale, + scale_gate=gate_scale, + scale_up=up_scale, + ) + + +def _parallel_lora( + *, + adapter_model_prefix: str, + linear: Any, + out_features: int, + rank: int, + alpha: float, + layout: Literal["column", "row"], + shard_domain: ShardDomain = "tp", + grad_sync_domain: GradSyncDomain = TP_DEFAULT_GRAD_SYNC_DOMAIN, + allreduce: bool = True, + num_local_experts: int = 1, +) -> LoRA: + weight = getattr(linear, "weight0", None) + if weight is None: + weight = getattr(linear, "weight", None) + assert isinstance(weight, torch.Tensor) + row_layout = layout == "row" + a_parallel_spec = LoRAParallelSpec( + shard_domain=shard_domain, + sharded=row_layout, + shard_dim=-2 if row_layout else None, + grad_sync_domain=grad_sync_domain, + grad_sync_op=GRAD_SYNC_OP_NONE if row_layout else GRAD_SYNC_OP_SUM, + ) + b_parallel_spec = replace( + a_parallel_spec, + sharded=not row_layout, + shard_dim=None if row_layout else -1, + grad_sync_domain=grad_sync_domain, + grad_sync_op=GRAD_SYNC_OP_SUM if row_layout else GRAD_SYNC_OP_NONE, + ) + return LoRA( + adapter_model_prefix=adapter_model_prefix, + in_features=linear.in_features, + out_features=out_features, + rank=rank, + alpha=alpha, + dtype=weight.dtype, + device=weight.device, + num_local_experts=num_local_experts, + a_parallel_spec=a_parallel_spec, + b_parallel_spec=b_parallel_spec, + allreduce=allreduce, + ) + + +def _parallel_lora_pair( + *, + adapter_model_prefix: str, + linear: Any, + out_features: int, + rank: int, + alpha: float, + layout: Literal["column", "row"], + suffixes: tuple[str, str], + num_local_experts: int = 1, +) 
-> tuple[LoRA, LoRA]: + expert_parallel = num_local_experts > 1 + return cast( + tuple[LoRA, LoRA], + tuple( + _parallel_lora( + adapter_model_prefix=f"{adapter_model_prefix}.{suffix}", + linear=linear, + out_features=out_features, + rank=rank, + alpha=alpha, + layout=layout, + shard_domain="expert_tp" if expert_parallel else "tp", + grad_sync_domain=( + EXPERT_TP_GRAD_SYNC_DOMAIN + if expert_parallel + else TP_DEFAULT_GRAD_SYNC_DOMAIN + ), + allreduce=not expert_parallel, + num_local_experts=num_local_experts, + ) + for suffix in suffixes + ), ) @@ -860,33 +1157,13 @@ def __init__( self.provider = provider self.linear_proj = linear_proj self.reduce_output = reduce_output - assert isinstance(linear_proj.weight, torch.Tensor) - a_parallel_spec = LoRAParallelSpec( - shard_domain="tp", - sharded=True, - shard_dim=-2, - grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_NONE, # only need DP-type reductions - ) - b_parallel_spec = a_parallel_spec.model_copy( - update={ - "sharded": False, - "shard_dim": None, - "grad_sync_op": GRAD_SYNC_OP_SUM, # sum replicated TP contributions - } - ) - self.lora = LoRA( + self.lora = _parallel_lora( adapter_model_prefix=adapter_model_prefix, - in_features=linear_proj.in_features, + linear=linear_proj, out_features=linear_proj.out_features, rank=rank, alpha=alpha, - dtype=linear_proj.weight.dtype, - device=linear_proj.weight.device, - a_parallel_spec=a_parallel_spec, - b_parallel_spec=b_parallel_spec, - # Non-expert LoRA params use Megatron's dense DP/CP gradient buckets. 
- allreduce=True, + layout="row", ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -957,64 +1234,29 @@ def __init__( self.provider.num_attention_heads // self.provider.num_query_groups ) self.hidden_size_per_attention_head = self.provider.kv_channels - self.q_proj_lora = self._build_qkv_lora( + self.q_proj_lora = _parallel_lora( adapter_model_prefix=f"{adapter_model_prefix}.q_proj", - linear_qkv=linear_qkv, + linear=linear_qkv, + out_features=q_and_gate_out_features_per_rank, rank=rank, alpha=alpha, - out_features=q_and_gate_out_features_per_rank, + layout="column", ) - self.k_proj_lora = self._build_qkv_lora( + self.k_proj_lora = _parallel_lora( adapter_model_prefix=f"{adapter_model_prefix}.k_proj", - linear_qkv=linear_qkv, + linear=linear_qkv, + out_features=kv_out_features_per_rank, rank=rank, alpha=alpha, - out_features=kv_out_features_per_rank, + layout="column", ) - self.v_proj_lora = self._build_qkv_lora( + self.v_proj_lora = _parallel_lora( adapter_model_prefix=f"{adapter_model_prefix}.v_proj", - linear_qkv=linear_qkv, - rank=rank, - alpha=alpha, + linear=linear_qkv, out_features=kv_out_features_per_rank, - ) - - @staticmethod - def _build_qkv_lora( - *, - adapter_model_prefix: str, - linear_qkv: TELayerNormColumnParallelLinear, - rank: int, - alpha: float, - out_features: int, - ) -> LoRA: - assert isinstance(linear_qkv.weight, torch.Tensor) - a_parallel_spec = LoRAParallelSpec( - shard_domain="tp", - sharded=False, - shard_dim=None, - grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_SUM, # sum replicated TP contributions - ) - b_parallel_spec = a_parallel_spec.model_copy( - update={ - "sharded": True, - "shard_dim": -1, - "grad_sync_op": GRAD_SYNC_OP_NONE, # only need DP-type reductions - } - ) - return LoRA( - adapter_model_prefix=adapter_model_prefix, - in_features=linear_qkv.in_features, - out_features=out_features, rank=rank, alpha=alpha, - dtype=linear_qkv.weight.dtype, - 
device=linear_qkv.weight.device, - a_parallel_spec=a_parallel_spec, - b_parallel_spec=b_parallel_spec, - # Non-expert LoRA params use Megatron's dense DP/CP gradient buckets. - allreduce=True, + layout="column", ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -1080,13 +1322,13 @@ def __init__( z_out_features_per_partition = ( gated_delta_net.v_dim // ps.get_tensor_model_parallel_world_size() ) - assert isinstance(in_proj.weight, torch.Tensor) - self.qkv_lora = self._build_in_proj_lora( + self.qkv_lora = _parallel_lora( adapter_model_prefix=f"{adapter_model_prefix}.in_proj_qkv", - in_proj=in_proj, + linear=in_proj, + out_features=qkv_out_features_per_partition, rank=rank, alpha=alpha, - out_features=qkv_out_features_per_partition, + layout="column", ) _set_lora_shard_strategy_metadata( self.qkv_lora.B_T, @@ -1097,49 +1339,13 @@ def __init__( gated_delta_net.v_dim, ), ) - self.z_lora = self._build_in_proj_lora( + self.z_lora = _parallel_lora( adapter_model_prefix=f"{adapter_model_prefix}.in_proj_z", - in_proj=in_proj, - rank=rank, - alpha=alpha, + linear=in_proj, out_features=z_out_features_per_partition, - ) - - @staticmethod - def _build_in_proj_lora( - *, - adapter_model_prefix: str, - in_proj: TELayerNormColumnParallelLinear, - rank: int, - alpha: float, - out_features: int, - ) -> LoRA: - assert isinstance(in_proj.weight, torch.Tensor) - a_parallel_spec = LoRAParallelSpec( - shard_domain="tp", - sharded=False, - shard_dim=None, - grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_SUM, - ) - b_parallel_spec = a_parallel_spec.model_copy( - update={ - "sharded": True, - "shard_dim": -1, - "grad_sync_op": GRAD_SYNC_OP_NONE, - } - ) - return LoRA( - adapter_model_prefix=adapter_model_prefix, - in_features=in_proj.in_features, - out_features=out_features, rank=rank, alpha=alpha, - dtype=in_proj.weight.dtype, - device=in_proj.weight.device, - a_parallel_spec=a_parallel_spec, - b_parallel_spec=b_parallel_spec, 
- allreduce=True, + layout="column", ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -1169,133 +1375,62 @@ def __init__( rank: int, alpha: float, num_local_experts: int, + fused_gate_up: bool = False, ) -> None: super().__init__() - assert linear_fc1 is not None self.linear_fc1 = linear_fc1 - self.gate_lora = self._build_fc1_lora( - adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.gate_proj", - linear_fc1=linear_fc1, - rank=rank, - alpha=alpha, - num_local_experts=num_local_experts, - ) - self.up_lora = self._build_fc1_lora( - adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.up_proj", - linear_fc1=linear_fc1, - rank=rank, - alpha=alpha, - num_local_experts=num_local_experts, - ) - self.uses_direct_quack_grouped_lora_dual = True - - @staticmethod - def _build_fc1_lora( - *, - adapter_model_prefix: str, - linear_fc1: TEColumnParallelGroupedLinear, - rank: int, - alpha: float, - num_local_experts: int, - ) -> LoRA: - assert linear_fc1 is not None - assert isinstance(linear_fc1.weight0, torch.Tensor) - a_parallel_spec = LoRAParallelSpec( - shard_domain="expert_tp", - sharded=False, - shard_dim=None, - grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_SUM, # we handle this with extended finalize_grads - ) - b_parallel_spec = a_parallel_spec.model_copy( - update={ - "sharded": True, - "shard_dim": -1, - "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN, - "grad_sync_op": GRAD_SYNC_OP_NONE, # only need DP-type reductions - } - ) - return LoRA( - adapter_model_prefix=adapter_model_prefix, - in_features=linear_fc1.in_features, - out_features=linear_fc1.out_features // 2, - rank=rank, - alpha=alpha, - dtype=linear_fc1.weight0.dtype, - device=linear_fc1.weight0.device, - num_local_experts=num_local_experts, - a_parallel_spec=a_parallel_spec, - b_parallel_spec=b_parallel_spec, - # Expert LoRA params use Megatron's expert-DP gradient buckets. 
- allreduce=False, - ) - - def forward( - self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor | None]: - base_out, bias_out = self.linear_fc1(x, tokens_per_expert) - adapter_out = _expert_grouped_lora_dual_forward(self, x, tokens_per_expert) - return base_out + adapter_out, bias_out - - -class MLPExpertsLinearFC1FusedLoRA(torch.nn.Module): - def __init__( - self, - adapter_model_prefix: str, - linear_fc1: TEColumnParallelGroupedLinear, - rank: int, - alpha: float, - num_local_experts: int, - ) -> None: - super().__init__() - assert linear_fc1 is not None - assert isinstance(linear_fc1.weight0, torch.Tensor) - self.linear_fc1 = linear_fc1 - a_parallel_spec = LoRAParallelSpec( - shard_domain="expert_tp", - sharded=False, - shard_dim=None, - grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_SUM, - ) - b_parallel_spec = a_parallel_spec.model_copy( - update={ - "sharded": True, - "shard_dim": -1, - "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN, - "grad_sync_op": GRAD_SYNC_OP_NONE, - } - ) - self.lora = LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.gate_up_proj", - in_features=linear_fc1.in_features, - out_features=linear_fc1.out_features, - rank=rank, - alpha=alpha, - dtype=linear_fc1.weight0.dtype, - device=linear_fc1.weight0.device, - num_local_experts=num_local_experts, - a_parallel_spec=a_parallel_spec, - b_parallel_spec=b_parallel_spec, - allreduce=False, - ) - gate_out_features = linear_fc1.out_features // 2 - expert_tp_world_size = _get_shard_world_size("expert_tp") - _set_lora_shard_strategy_metadata( - self.lora.B_T, - strategy="componentwise", - component_sizes=( - gate_out_features * expert_tp_world_size, - gate_out_features * expert_tp_world_size, - ), - ) + self.fused_gate_up = bool(fused_gate_up) + if self.fused_gate_up: + self.lora = _parallel_lora( + adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.gate_up_proj", + linear=linear_fc1, + 
out_features=linear_fc1.out_features, + rank=rank, + alpha=alpha, + layout="column", + shard_domain="expert_tp", + grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, + allreduce=False, + num_local_experts=num_local_experts, + ) + gate_out_features = linear_fc1.out_features // 2 + expert_tp_world_size = _get_shard_world_size("expert_tp") + _set_lora_shard_strategy_metadata( + self.lora.B_T, + strategy="componentwise", + component_sizes=( + gate_out_features * expert_tp_world_size, + gate_out_features * expert_tp_world_size, + ), + ) + else: + self.gate_lora, self.up_lora = _parallel_lora_pair( + adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}", + linear=linear_fc1, + out_features=linear_fc1.out_features // 2, + rank=rank, + alpha=alpha, + layout="column", + suffixes=("gate_proj", "up_proj"), + num_local_experts=num_local_experts, + ) def forward( self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor | None]: - base_out, bias_out = self.linear_fc1(x, tokens_per_expert) - adapter_out = _expert_grouped_lora_forward( - self.lora, x, tokens_per_expert, self.linear_fc1.out_features + base_out, bias_out = cast( + Callable[ + [torch.Tensor, list[int] | torch.Tensor], + tuple[torch.Tensor, torch.Tensor | None], + ], + self.linear_fc1, + )(x, tokens_per_expert) + adapter_out = ( + _expert_grouped_lora_forward( + self.lora, x, tokens_per_expert, self.linear_fc1.out_features + ) + if self.fused_gate_up + else _expert_grouped_lora_dual_forward(self, x, tokens_per_expert) ) return base_out + adapter_out, bias_out @@ -1310,43 +1445,30 @@ def __init__( num_local_experts: int, ) -> None: super().__init__() - assert linear_fc2 is not None - assert isinstance(linear_fc2.weight0, torch.Tensor) self.linear_fc2 = linear_fc2 - a_parallel_spec = LoRAParallelSpec( - shard_domain="expert_tp", - sharded=True, - shard_dim=-2, - grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_NONE, # only need DP-type reductions - ) 
- b_parallel_spec = a_parallel_spec.model_copy( - update={ - "sharded": False, - "shard_dim": None, - "grad_sync_domain": EXPERT_TP_GRAD_SYNC_DOMAIN, - "grad_sync_op": GRAD_SYNC_OP_SUM, # we handle this with extended finalize_grads - } - ) - self.lora = LoRA( + self.lora = _parallel_lora( adapter_model_prefix=f"{adapter_model_prefix}.{{expert}}.down_proj", - in_features=linear_fc2.in_features, + linear=linear_fc2, out_features=linear_fc2.out_features, rank=rank, alpha=alpha, - dtype=linear_fc2.weight0.dtype, - device=linear_fc2.weight0.device, - num_local_experts=num_local_experts, - a_parallel_spec=a_parallel_spec, - b_parallel_spec=b_parallel_spec, - # Expert LoRA params use Megatron's expert-DP gradient buckets. + layout="row", + shard_domain="expert_tp", + grad_sync_domain=EXPERT_TP_GRAD_SYNC_DOMAIN, allreduce=False, + num_local_experts=num_local_experts, ) def forward( self, x: torch.Tensor, tokens_per_expert: list[int] | torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor | None]: - base_out, bias_out = self.linear_fc2(x, tokens_per_expert) + base_out, bias_out = cast( + Callable[ + [torch.Tensor, list[int] | torch.Tensor], + tuple[torch.Tensor, torch.Tensor | None], + ], + self.linear_fc2, + )(x, tokens_per_expert) adapter_out = _expert_grouped_lora_forward( self.lora, x, tokens_per_expert, self.linear_fc2.out_features ) @@ -1368,53 +1490,14 @@ def __init__( linear_fc1.return_layernorm_output = True linear_fc1.return_layernorm_output_gathered = True self.linear_fc1 = linear_fc1 - self.gate_lora = self._build_fc1_lora( - adapter_model_prefix=f"{adapter_model_prefix}.gate_proj", - linear_fc1=linear_fc1, - rank=rank, - alpha=alpha, - ) - self.up_lora = self._build_fc1_lora( - adapter_model_prefix=f"{adapter_model_prefix}.up_proj", - linear_fc1=linear_fc1, - rank=rank, - alpha=alpha, - ) - - @staticmethod - def _build_fc1_lora( - *, - adapter_model_prefix: str, - linear_fc1: TEColumnParallelLinear | TELayerNormColumnParallelLinear, - rank: int, - alpha: float, - 
) -> LoRA: - assert isinstance(linear_fc1.weight, torch.Tensor) - a_parallel_spec = LoRAParallelSpec( - shard_domain="tp", - sharded=False, - shard_dim=None, - grad_sync_domain=TP_DEFAULT_GRAD_SYNC_DOMAIN, - grad_sync_op=GRAD_SYNC_OP_SUM, - ) - b_parallel_spec = a_parallel_spec.model_copy( - update={ - "sharded": True, - "shard_dim": -1, - "grad_sync_op": GRAD_SYNC_OP_NONE, - } - ) - return LoRA( + self.gate_lora, self.up_lora = _parallel_lora_pair( adapter_model_prefix=adapter_model_prefix, - in_features=linear_fc1.in_features, + linear=linear_fc1, out_features=linear_fc1.out_features // 2, rank=rank, alpha=alpha, - dtype=linear_fc1.weight.dtype, - device=linear_fc1.weight.device, - a_parallel_spec=a_parallel_spec, - b_parallel_spec=b_parallel_spec, - allreduce=True, + layout="column", + suffixes=("gate_proj", "up_proj"), ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -1438,29 +1521,6 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: return base_out + adapter_out, bias_out -class SharedExpertsLinearFC2LoRA(torch.nn.Module): - def __init__( - self, - adapter_model_prefix: str, - linear_fc2: TERowParallelLinear, - rank: int, - alpha: float, - provider: GPTModelProvider, - ) -> None: - super().__init__() - self.row_parallel_lora = SelfAttentionLinearProjLoRA( - adapter_model_prefix=f"{adapter_model_prefix}.down_proj", - linear_proj=linear_fc2, - rank=rank, - alpha=alpha, - provider=provider, - reduce_output=not _linear_disables_tensor_parallel_comm(linear_fc2), - ) - - def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: - return self.row_parallel_lora(x) - - def _unwrap_attr( value: Any, attr_name: str, @@ -1568,8 +1628,14 @@ def wrap_grouped_moe_experts( target_modules: set[str], rank: int, alpha: int, + fused_gate_up: bool = False, ) -> None: - if _targets_include(target_modules, "gate_proj", "up_proj"): + wrap_fc1 = ( + _targets_include(target_modules, "experts") + 
if fused_gate_up + else _targets_include(target_modules, "gate_proj", "up_proj") + ) + if wrap_fc1: mlp_experts_linear_fc1 = _unwrap_attr( experts.linear_fc1, "linear_fc1", @@ -1581,58 +1647,27 @@ def wrap_grouped_moe_experts( rank=rank, alpha=alpha, num_local_experts=experts.num_local_experts, + fused_gate_up=fused_gate_up, ) - if _targets_include(target_modules, "down_proj"): - mlp_experts_linear_fc2 = _unwrap_attr( - experts.linear_fc2, - "linear_fc2", - TERowParallelGroupedLinear, # type: ignore[arg-type] - ) - experts.linear_fc2 = MLPExpertsLinearFC2LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", - linear_fc2=mlp_experts_linear_fc2, - rank=rank, - alpha=alpha, - num_local_experts=experts.num_local_experts, - ) - - -def wrap_grouped_moe_experts_3d( - experts: TEGroupedMLP, - *, - adapter_model_prefix: str, - target_modules: set[str], - rank: int, - alpha: int, -) -> None: - if _targets_include(target_modules, "experts"): - mlp_experts_linear_fc1 = _unwrap_attr( - experts.linear_fc1, - "linear_fc1", - TEColumnParallelGroupedLinear, # type: ignore[arg-type] - ) - experts.linear_fc1 = MLPExpertsLinearFC1FusedLoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", - linear_fc1=mlp_experts_linear_fc1, - rank=rank, - alpha=alpha, - num_local_experts=experts.num_local_experts, - ) - mlp_experts_linear_fc2 = _unwrap_attr( + wrap_fc2 = ( + wrap_fc1 if fused_gate_up else _targets_include(target_modules, "down_proj") + ) + if wrap_fc2: + linear_fc2 = _unwrap_attr( experts.linear_fc2, "linear_fc2", TERowParallelGroupedLinear, # type: ignore[arg-type] ) experts.linear_fc2 = MLPExpertsLinearFC2LoRA( adapter_model_prefix=f"{adapter_model_prefix}.mlp.experts", - linear_fc2=mlp_experts_linear_fc2, + linear_fc2=linear_fc2, rank=rank, alpha=alpha, num_local_experts=experts.num_local_experts, ) -def wrap_dense_mlp( +def wrap_split_mlp_lora( mlp: Any, *, adapter_model_prefix: str, @@ -1642,65 +1677,30 @@ def wrap_dense_mlp( alpha: int, ) -> None: if 
_targets_include(target_modules, "gate_proj", "up_proj"): - mlp_linear_fc1 = _unwrap_attr( + linear_fc1 = _unwrap_attr( mlp.linear_fc1, "linear_fc1", (TEColumnParallelLinear, TELayerNormColumnParallelLinear), ) mlp.linear_fc1 = SharedExpertsLinearFC1LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp", - linear_fc1=mlp_linear_fc1, + adapter_model_prefix=adapter_model_prefix, + linear_fc1=linear_fc1, rank=rank, alpha=alpha, ) if _targets_include(target_modules, "down_proj"): - mlp_linear_fc2 = _unwrap_attr( + linear_fc2 = _unwrap_attr( mlp.linear_fc2, "linear_fc2", TERowParallelLinear, ) - mlp.linear_fc2 = SharedExpertsLinearFC2LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp", - linear_fc2=mlp_linear_fc2, - rank=rank, - alpha=alpha, - provider=provider, - ) - - -def wrap_shared_experts_mlp( - shared_experts: SharedExpertMLP, - *, - adapter_model_prefix: str, - provider: GPTModelProvider, - target_modules: set[str], - rank: int, - alpha: int, -) -> None: - if _targets_include(target_modules, "gate_proj", "up_proj"): - shared_experts_linear_fc1 = _unwrap_attr( - shared_experts.linear_fc1, - "linear_fc1", - (TEColumnParallelLinear, TELayerNormColumnParallelLinear), - ) - shared_experts.linear_fc1 = SharedExpertsLinearFC1LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.shared_expert", - linear_fc1=shared_experts_linear_fc1, - rank=rank, - alpha=alpha, - ) - if _targets_include(target_modules, "down_proj"): - shared_experts_linear_fc2 = _unwrap_attr( - shared_experts.linear_fc2, - "linear_fc2", - TERowParallelLinear, - ) - shared_experts.linear_fc2 = SharedExpertsLinearFC2LoRA( - adapter_model_prefix=f"{adapter_model_prefix}.mlp.shared_expert", - linear_fc2=shared_experts_linear_fc2, + mlp.linear_fc2 = SelfAttentionLinearProjLoRA( + adapter_model_prefix=f"{adapter_model_prefix}.down_proj", + linear_proj=linear_fc2, rank=rank, alpha=alpha, provider=provider, + reduce_output=not _linear_disables_tensor_parallel_comm(linear_fc2), ) @@ -1721,3 
+1721,43 @@ def apply_lora_adapters( alpha=LORA_ALPHA, ) return list(model) + + +def load_lora_slot_into_model( + model: Sequence[torch.nn.Module], + ref: LoRASlotRef, + adapter_model: dict[str, torch.Tensor], + *, + alpha: float = LORA_ALPHA, + requires_grad: bool, +) -> int: + loaded = 0 + for chunk in model: + for module in chunk.modules(): + if isinstance(module, LoRA) and module.load_lora_slot( + ref, + adapter_model, + alpha=alpha, + requires_grad=requires_grad, + ): + loaded += 1 + if loaded == 0 and ref.name is not None: + raise RuntimeError(f"LoRA slot {ref.kind}:{ref.name} loaded no adapter sites") + return loaded + + +def iter_lora_slot_parameters( + model: Sequence[torch.nn.Module], + ref: LoRASlotRef, +) -> Iterator[torch.nn.Parameter]: + seen: set[int] = set() + for chunk in model: + for module in chunk.modules(): + if not isinstance(module, LoRA): + continue + for param in module.lora_slot_params(ref): + param_id = id(param) + if param_id in seen: + continue + seen.add(param_id) + yield param diff --git a/src/art/megatron/model_support/handlers/default_dense.py b/src/art/megatron/model_support/handlers/default_dense.py index bd79332ae..d3f7d2416 100644 --- a/src/art/megatron/model_support/handlers/default_dense.py +++ b/src/art/megatron/model_support/handlers/default_dense.py @@ -137,7 +137,7 @@ def apply_lora_adapters( from art.megatron.lora import ( _adapter_model_prefix, - wrap_dense_mlp, + wrap_split_mlp_lora, wrap_standard_self_attention, ) @@ -146,18 +146,19 @@ def apply_lora_adapters( for module in chunk.modules(): if not isinstance(module, TransformerLayer): continue + adapter_model_prefix = _adapter_model_prefix(module) wrap_standard_self_attention( module.self_attention, - adapter_model_prefix=_adapter_model_prefix(module), + adapter_model_prefix=adapter_model_prefix, provider=provider, target_modules=target_set, rank=rank, alpha=alpha, ) _require_dense_mlp(module) - wrap_dense_mlp( + wrap_split_mlp_lora( module.mlp, - 
adapter_model_prefix=_adapter_model_prefix(module), + adapter_model_prefix=f"{adapter_model_prefix}.mlp", provider=provider, target_modules=target_set, rank=rank, @@ -168,32 +169,9 @@ def build_adapter_weights_by_base( self, model_chunks: Sequence[Any], ) -> dict[str, list[Any]]: - from megatron.core.transformer.transformer_layer import TransformerLayer + from art.megatron.weights import adapter_export - from art.megatron.weights.adapter_export import ( - add_dense_mlp_adapter_weights, - add_standard_self_attention_adapter_weights, - layer_base_prefix, - ) - - adapter_weights_by_base: dict[str, list[Any]] = {} - for chunk in model_chunks: - for module_name, module in chunk.named_modules(): - if not isinstance(module, TransformerLayer): - continue - layer_prefix = layer_base_prefix(module, module_name=module_name) - _require_dense_mlp(module) - add_standard_self_attention_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - self_attention=module.self_attention, - ) - add_dense_mlp_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - mlp=module.mlp, - ) - return adapter_weights_by_base + return adapter_export.build_transformer_layer_adapter_weights(model_chunks) def compile_workaround_config( self, @@ -236,7 +214,7 @@ def apply_lora_adapters( from art.megatron.lora import ( _adapter_model_prefix, wrap_grouped_moe_experts, - wrap_shared_experts_mlp, + wrap_split_mlp_lora, wrap_standard_self_attention, ) @@ -263,9 +241,9 @@ def apply_lora_adapters( ) shared_experts = getattr(module.mlp, "shared_experts", None) if shared_experts is not None: - wrap_shared_experts_mlp( + wrap_split_mlp_lora( shared_experts, - adapter_model_prefix=adapter_model_prefix, + adapter_model_prefix=f"{adapter_model_prefix}.mlp.shared_expert", provider=provider, target_modules=target_set, rank=rank, @@ -276,40 +254,13 @@ def build_adapter_weights_by_base( self, model_chunks: Sequence[Any], ) -> dict[str, list[Any]]: - from 
megatron.core.transformer.transformer_layer import TransformerLayer + from art.megatron.weights import adapter_export - from art.megatron.weights.adapter_export import ( - add_grouped_moe_adapter_weights, - add_shared_experts_adapter_weights, - add_standard_self_attention_adapter_weights, - layer_base_prefix, + return adapter_export.build_transformer_layer_adapter_weights( + model_chunks, + grouped_moe=True, ) - adapter_weights_by_base: dict[str, list[Any]] = {} - for chunk in model_chunks: - for module_name, module in chunk.named_modules(): - if not isinstance(module, TransformerLayer): - continue - layer_prefix = layer_base_prefix(module, module_name=module_name) - add_standard_self_attention_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - self_attention=module.self_attention, - ) - add_grouped_moe_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - experts=_require_moe_experts(module), - ) - shared_experts = getattr(module.mlp, "shared_experts", None) - if shared_experts is not None: - add_shared_experts_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - shared_experts=shared_experts, - ) - return adapter_weights_by_base - def _require_dense_mlp(module: Any) -> None: if getattr(module.mlp, "experts", None) is not None: diff --git a/src/art/megatron/model_support/handlers/qwen3_5.py b/src/art/megatron/model_support/handlers/qwen3_5.py index ad200499a..3d4ea98d8 100644 --- a/src/art/megatron/model_support/handlers/qwen3_5.py +++ b/src/art/megatron/model_support/handlers/qwen3_5.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Callable from copy import copy from functools import lru_cache import re @@ -50,6 +51,9 @@ r"^(?P.*\.mlp\.experts)\.(?P\d+)\." 
r"(?Pgate_proj|up_proj|down_proj)\.(?Plora_[AB])\.weight$" ) +_ART_MOE_MODULES = ("gate_up_proj", "down_proj") +_VLLM_EXPERT_MODULES = ("gate_proj", "up_proj", "down_proj") +_LORA_NAMES = ("lora_A", "lora_B") class Qwen35BaseHandler(DefaultDenseHandler): @@ -80,7 +84,7 @@ def to_vllm_lora_tensors( *, adapter_config: dict[str, Any], ) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: - if _group_art_moe_tensors(tensors): + if _group_expert_lora_tensors(tensors, _ART_MOE_EXPERT_KEY_RE): raise TypeError("Dense Qwen3.5 handler received MoE LoRA tensors") transformed: dict[str, torch.Tensor] = {} for key, tensor in tensors.items(): @@ -313,44 +317,14 @@ def build_adapter_weights_by_base( self, model_chunks: Sequence[Any], ) -> dict[str, list[Any]]: - from megatron.core.ssm.gated_delta_net import GatedDeltaNet - from megatron.core.transformer.attention import SelfAttention - from megatron.core.transformer.transformer_layer import TransformerLayer - - from art.megatron.lora import _is_language_transformer_layer_name - from art.megatron.weights.adapter_export import ( - add_gated_delta_net_adapter_weights, - add_standard_self_attention_adapter_weights, - layer_base_prefix, - ) + from art.megatron.weights import adapter_export _ensure_bridge_qwen35_adapter_name_map() - adapter_weights_by_base: dict[str, list[Any]] = {} - for chunk in model_chunks: - for module_name, module in chunk.named_modules(): - if not isinstance(module, TransformerLayer): - continue - if not _is_language_transformer_layer_name(module_name): - continue - layer_prefix = layer_base_prefix(module, module_name=module_name) - if isinstance(module.self_attention, SelfAttention): - add_standard_self_attention_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - self_attention=module.self_attention, - ) - elif isinstance(module.self_attention, GatedDeltaNet): - add_gated_delta_net_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - 
self_attention=module.self_attention, - ) - self._add_mlp_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - module=module, - ) - return adapter_weights_by_base + return adapter_export.build_transformer_layer_adapter_weights( + model_chunks, + grouped_moe=self.is_moe, + language_layers_only=True, + ) def _wrap_mlp_lora( self, @@ -362,34 +336,18 @@ def _wrap_mlp_lora( rank: int, alpha: int, ) -> None: - from art.megatron.lora import wrap_dense_mlp + from art.megatron.lora import wrap_split_mlp_lora _require_dense_mlp(module) - wrap_dense_mlp( + wrap_split_mlp_lora( module.mlp, - adapter_model_prefix=adapter_model_prefix, + adapter_model_prefix=f"{adapter_model_prefix}.mlp", provider=provider, target_modules=target_modules, rank=rank, alpha=alpha, ) - def _add_mlp_adapter_weights( - self, - adapter_weights_by_base: dict[str, list[Any]], - *, - layer_prefix: str, - module: Any, - ) -> None: - from art.megatron.weights.adapter_export import add_dense_mlp_adapter_weights - - _require_dense_mlp(module) - add_dense_mlp_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - mlp=module.mlp, - ) - def get_forward_kwargs(self, model: Any, **kwargs: Any) -> dict[str, Any]: unwrapped = model while hasattr(unwrapped, "module"): @@ -483,54 +441,27 @@ def _wrap_mlp_lora( rank: int, alpha: int, ) -> None: - from art.megatron.lora import ( - wrap_grouped_moe_experts_3d, - wrap_shared_experts_mlp, - ) + from art.megatron.lora import wrap_grouped_moe_experts, wrap_split_mlp_lora - wrap_grouped_moe_experts_3d( + wrap_grouped_moe_experts( _require_moe_experts(module), adapter_model_prefix=adapter_model_prefix, target_modules=target_modules, rank=rank, alpha=alpha, + fused_gate_up=True, ) shared_experts = getattr(module.mlp, "shared_experts", None) if shared_experts is not None: - wrap_shared_experts_mlp( + wrap_split_mlp_lora( shared_experts, - adapter_model_prefix=adapter_model_prefix, + 
adapter_model_prefix=f"{adapter_model_prefix}.mlp.shared_expert", provider=provider, target_modules=target_modules, rank=rank, alpha=alpha, ) - def _add_mlp_adapter_weights( - self, - adapter_weights_by_base: dict[str, list[Any]], - *, - layer_prefix: str, - module: Any, - ) -> None: - from art.megatron.weights.adapter_export import ( - add_grouped_moe_adapter_weights, - add_shared_experts_adapter_weights, - ) - - add_grouped_moe_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - experts=_require_moe_experts(module), - ) - shared_experts = getattr(module.mlp, "shared_experts", None) - if shared_experts is not None: - add_shared_experts_adapter_weights( - adapter_weights_by_base, - layer_prefix=layer_prefix, - shared_experts=shared_experts, - ) - def compile_workaround_config( self, provider: Any, @@ -573,10 +504,6 @@ def _from_vllm_key(key: str) -> str: ) -def _is_lora_weight_key(key: str) -> bool: - return key.endswith((".lora_A.weight", ".lora_B.weight")) - - def _is_self_attn_q_proj_lora_b(key: str) -> bool: return key.endswith(".self_attn.q_proj.lora_B.weight") @@ -725,12 +652,16 @@ def _vllm_moe_config( return config -def _group_art_moe_tensors( +type _ExpertLoraGroups = dict[str, dict[int, dict[str, dict[str, torch.Tensor]]]] + + +def _group_expert_lora_tensors( tensors: dict[str, torch.Tensor], -) -> dict[str, dict[int, dict[str, dict[str, torch.Tensor]]]]: - grouped: dict[str, dict[int, dict[str, dict[str, torch.Tensor]]]] = {} + pattern: re.Pattern[str], +) -> _ExpertLoraGroups: + grouped: _ExpertLoraGroups = {} for key, tensor in tensors.items(): - match = _ART_MOE_EXPERT_KEY_RE.match(key) + match = pattern.match(key) if match is None: continue grouped.setdefault(match.group("prefix"), {}).setdefault( @@ -740,27 +671,57 @@ def _group_art_moe_tensors( return grouped +def _expert_lora_key(prefix: str, expert: int, module: str, lora_name: str) -> str: + return f"{prefix}.{expert}.{module}.{lora_name}.weight" + + +def 
_convert_remaining_lora_tensors( + transformed: dict[str, torch.Tensor], + tensors: dict[str, torch.Tensor], + *, + used_keys: set[str], + convert: Callable[ + [str, torch.Tensor], + tuple[str, torch.Tensor], + ], + reject_fused_moe: bool = False, +) -> None: + for key, tensor in tensors.items(): + if key in used_keys: + continue + if reject_fused_moe and _VLLM_MOE_KEY_RE.match(key) is not None: + raise RuntimeError( + "Mixed fused and per-expert Qwen3.5 vLLM MoE LoRA tensors" + ) + converted_key, converted = convert(key, tensor) + if converted_key in transformed: + raise RuntimeError( + f"Duplicate Qwen3.5 LoRA tensor after conversion: {converted_key}" + ) + transformed[converted_key] = converted + + def _to_vllm_lora_tensors( tensors: dict[str, torch.Tensor], *, adapter_config: dict[str, Any], ) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: - grouped = _group_art_moe_tensors(tensors) + grouped = _group_expert_lora_tensors(tensors, _ART_MOE_EXPERT_KEY_RE) has_shared_experts = _has_shared_expert_lora_tensors(tensors) transformed: dict[str, torch.Tensor] = {} + convert = lambda key, tensor: _to_vllm_lora_tensor( + key, + tensor, + adapter_config=adapter_config, + ) if not grouped: has_fused_experts = any(_VLLM_MOE_KEY_RE.match(key) for key in tensors) - for key, tensor in tensors.items(): - vllm_key, tensor = _to_vllm_lora_tensor( - key, - tensor, - adapter_config=adapter_config, - ) - if vllm_key in transformed: - raise RuntimeError( - f"Duplicate Qwen3.5 LoRA tensor after conversion: {vllm_key}" - ) - transformed[vllm_key] = tensor + _convert_remaining_lora_tensors( + transformed, + tensors, + used_keys=set(), + convert=convert, + ) return transformed, ( _vllm_moe_config( adapter_config, @@ -772,17 +733,21 @@ def _to_vllm_lora_tensors( used_keys: set[str] = set() for prefix, experts in grouped.items(): vllm_prefix = _to_vllm_key(prefix) - gate_up_a: list[torch.Tensor] = [] - gate_up_b: list[torch.Tensor] = [] - down_a: list[torch.Tensor] = [] - down_b: 
list[torch.Tensor] = [] + blocks = { + ("gate_up_proj", "lora_A"): [], + ("gate_up_proj", "lora_B"): [], + ("down_proj", "lora_A"): [], + ("down_proj", "lora_B"): [], + } for expert in sorted(experts): modules = experts[expert] try: - gate_up_a_tensor = modules["gate_up_proj"]["lora_A"] gate_up_b_tensor = modules["gate_up_proj"]["lora_B"] - d_a = modules["down_proj"]["lora_A"] - d_b = modules["down_proj"]["lora_B"] + expert_tensors = { + (module_name, lora_name): modules[module_name][lora_name] + for module_name in _ART_MOE_MODULES + for lora_name in _LORA_NAMES + } except KeyError as exc: raise RuntimeError( f"Incomplete Qwen3.5 MoE LoRA block for {prefix}.{expert}" @@ -792,34 +757,29 @@ def _to_vllm_lora_tensors( f"{prefix}.{expert}: gate/up lora_B rows " f"{gate_up_b_tensor.shape[0]} are not even" ) - gate_up_a.append(gate_up_a_tensor.contiguous()) - gate_up_b.append(gate_up_b_tensor.contiguous()) - down_a.append(d_a.contiguous()) - down_b.append(d_b.contiguous()) - for module_name in ("gate_up_proj", "down_proj"): - for lora_name in ("lora_A", "lora_B"): - used_keys.add(f"{prefix}.{expert}.{module_name}.{lora_name}.weight") + for slot, tensor in expert_tensors.items(): + blocks[slot].append(tensor.contiguous()) + used_keys.add(_expert_lora_key(prefix, expert, *slot)) transformed[f"{vllm_prefix}.base_layer.lora_A.weight"] = torch.cat( - gate_up_a, + blocks[("gate_up_proj", "lora_A")], dim=0, ).contiguous() transformed[f"{vllm_prefix}.base_layer.lora_B.weight"] = _pack_vllm_3d_lora_b( - gate_up_b + blocks[("gate_up_proj", "lora_B")] ) transformed[f"{vllm_prefix}.lora_A.weight"] = torch.cat( - down_a, + blocks[("down_proj", "lora_A")], dim=0, ).contiguous() - transformed[f"{vllm_prefix}.lora_B.weight"] = _pack_vllm_3d_lora_b(down_b) - for key, tensor in tensors.items(): - if key in used_keys: - continue - vllm_key, tensor = _to_vllm_lora_tensor( - key, - tensor, - adapter_config=adapter_config, + transformed[f"{vllm_prefix}.lora_B.weight"] = _pack_vllm_3d_lora_b( 
+ blocks[("down_proj", "lora_B")] ) - transformed[vllm_key] = tensor + _convert_remaining_lora_tensors( + transformed, + tensors, + used_keys=used_keys, + convert=convert, + ) return transformed, _vllm_moe_config( adapter_config, has_shared_experts=has_shared_experts, @@ -831,15 +791,12 @@ def _from_vllm_lora_tensors( *, adapter_config: dict[str, Any], ) -> dict[str, torch.Tensor]: - expert_grouped: dict[str, dict[int, dict[str, dict[str, torch.Tensor]]]] = {} - for key, tensor in tensors.items(): - match = _VLLM_MOE_EXPERT_KEY_RE.match(key) - if match is None: - continue - expert_grouped.setdefault(match.group("prefix"), {}).setdefault( - int(match.group("expert")), - {}, - ).setdefault(match.group("module"), {})[match.group("lora")] = tensor + convert = lambda key, tensor: _from_vllm_lora_tensor( + key, + tensor, + adapter_config=adapter_config, + ) + expert_grouped = _group_expert_lora_tensors(tensors, _VLLM_MOE_EXPERT_KEY_RE) if expert_grouped: transformed: dict[str, torch.Tensor] = {} used_keys: set[str] = set() @@ -847,16 +804,17 @@ def _from_vllm_lora_tensors( art_prefix = _from_vllm_key(prefix) for expert, modules in experts.items(): try: - gate_a = modules["gate_proj"]["lora_A"] - gate_b = modules["gate_proj"]["lora_B"] - up_a = modules["up_proj"]["lora_A"] - up_b = modules["up_proj"]["lora_B"] - down_a = modules["down_proj"]["lora_A"] - down_b = modules["down_proj"]["lora_B"] + expert_tensors = { + (module_name, lora_name): modules[module_name][lora_name] + for module_name in _VLLM_EXPERT_MODULES + for lora_name in _LORA_NAMES + } except KeyError as exc: raise RuntimeError( f"Incomplete Qwen3.5 vLLM MoE LoRA block for {prefix}.{expert}" ) from exc + gate_a = expert_tensors[("gate_proj", "lora_A")] + up_a = expert_tensors[("up_proj", "lora_A")] if not torch.equal(gate_a, up_a): raise RuntimeError( "Qwen3.5 Megatron gate_up_proj requires gate/up " @@ -866,32 +824,29 @@ def _from_vllm_lora_tensors( _clone(gate_a) ) 
transformed[f"{art_prefix}.{expert}.gate_up_proj.lora_B.weight"] = ( - torch.cat([gate_b, up_b], dim=0).contiguous() + torch.cat( + [ + expert_tensors[("gate_proj", "lora_B")], + expert_tensors[("up_proj", "lora_B")], + ], + dim=0, + ).contiguous() ) transformed[f"{art_prefix}.{expert}.down_proj.lora_A.weight"] = _clone( - down_a + expert_tensors[("down_proj", "lora_A")] ) transformed[f"{art_prefix}.{expert}.down_proj.lora_B.weight"] = _clone( - down_b - ) - for module_name in ("gate_proj", "up_proj", "down_proj"): - for lora_name in ("lora_A", "lora_B"): - used_keys.add( - f"{prefix}.{expert}.{module_name}.{lora_name}.weight" - ) - for key, tensor in tensors.items(): - if key in used_keys: - continue - if _VLLM_MOE_KEY_RE.match(key) is not None: - raise RuntimeError( - "Mixed fused and per-expert Qwen3.5 vLLM MoE LoRA tensors" + expert_tensors[("down_proj", "lora_B")] ) - art_key, tensor = _from_vllm_lora_tensor( - key, - tensor, - adapter_config=adapter_config, - ) - transformed[art_key] = tensor + for slot in expert_tensors: + used_keys.add(_expert_lora_key(prefix, expert, *slot)) + _convert_remaining_lora_tensors( + transformed, + tensors, + used_keys=used_keys, + convert=convert, + reject_fused_moe=True, + ) return transformed grouped: dict[str, dict[str, torch.Tensor]] = {} @@ -905,13 +860,12 @@ def _from_vllm_lora_tensors( grouped.setdefault(match.group("prefix"), {})[slot] = tensor if not grouped: transformed: dict[str, torch.Tensor] = {} - for key, tensor in tensors.items(): - art_key, tensor = _from_vllm_lora_tensor( - key, - tensor, - adapter_config=adapter_config, - ) - transformed[art_key] = tensor + _convert_remaining_lora_tensors( + transformed, + tensors, + used_keys=set(), + convert=convert, + ) return transformed rank = int(adapter_config["r"]) @@ -970,15 +924,12 @@ def _from_vllm_lora_tensors( f"{prefix}.lora_B.weight", } ) - for key, tensor in tensors.items(): - if key in used_keys: - continue - art_key, tensor = _from_vllm_lora_tensor( - key, - 
tensor, - adapter_config=adapter_config, - ) - transformed[art_key] = tensor + _convert_remaining_lora_tensors( + transformed, + tensors, + used_keys=used_keys, + convert=convert, + ) return transformed diff --git a/src/art/megatron/model_support/spec.py b/src/art/megatron/model_support/spec.py index 15c6f8d96..92c1368a2 100644 --- a/src/art/megatron/model_support/spec.py +++ b/src/art/megatron/model_support/spec.py @@ -75,6 +75,7 @@ class ModelSupportSpec(BaseModel): class ModelSupportHandler(Protocol): key: str is_moe: bool + build_gdn_execution_spec: bool native_vllm_lora_status: NativeVllmLoraStatus def identity_lora_model_config(self, base_config: Any) -> Any: ... diff --git a/src/art/megatron/service.py b/src/art/megatron/service.py index 884188a8d..f8cc0d311 100644 --- a/src/art/megatron/service.py +++ b/src/art/megatron/service.py @@ -241,10 +241,6 @@ def rollout_weights_mode(self) -> Literal["lora", "merged"]: def _vllm_base_url(self) -> str: return self._vllm_runtime.base_url - @property - def _vllm_host(self) -> str: - return self._vllm_runtime.host - @property def _vllm_port(self) -> int: return self._vllm_runtime.port diff --git a/src/art/megatron/setup.sh b/src/art/megatron/setup.sh index 6d3a5548c..3e5a1cb51 100755 --- a/src/art/megatron/setup.sh +++ b/src/art/megatron/setup.sh @@ -36,3 +36,7 @@ if [ -x "${HOME}/.local/bin/uv" ]; then uv_bin="${HOME}/.local/bin/uv" fi "${uv_bin}" sync --extra backend --extra megatron --frozen --active + +if [ "${INSTALL_VLLM_RUNTIME:-true}" = "true" ]; then + "${uv_bin}" sync --project vllm_runtime --frozen --no-dev +fi diff --git a/src/art/megatron/shared_prefix_packing.py b/src/art/megatron/shared_prefix_packing.py new file mode 100644 index 000000000..b5a34ba60 --- /dev/null +++ b/src/art/megatron/shared_prefix_packing.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from collections.abc import Iterable +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True) +class 
SharedPrefixPack: + tokens: torch.Tensor + group_ids: torch.Tensor + parent_ids: torch.Tensor + position_ids: torch.Tensor + positions_by_sequence: tuple[torch.Tensor, ...] + + +@dataclass(frozen=True) +class _PrefixSegment: + sequence_indices: tuple[int, ...] + start: int + end: int + group_id: int + parent_id: int + + +def pack_shared_prefixes( + sequences: Iterable[torch.Tensor], + *, + max_depth: int, +) -> SharedPrefixPack: + """Pack token sequences by storing shared prefixes once. + + This is the small packing step that lets `TrainerRank.dp_rank_forward()` run one + model pass over a compact prefix tree instead of replaying the same prompt + tokens for every request. Think of each input sequence as a path through a + tree: when several paths start with the same tokens, this function writes + that shared segment once, then writes each branch after it. + + Args: + sequences: 1-D token tensors to pack. + max_depth: How many nested shared-prefix levels to emit. `0` disables + prefix sharing and writes each sequence as its own root segment. `1` + shares the first common segment in each branch; larger values allow + branches to contain shared sub-branches. + + Returns: + `tokens` is the compact model input, shaped `[1, packed_length]`. + `group_ids` and `parent_ids` describe the prefix tree to shared-prefix + attention. Positions in the same emitted segment share a group, and each + group points at the parent segment it continues from. Root groups point + to themselves. + `position_ids` keeps each token's original sequence position for + positional embeddings/rotary attention. + `positions_by_sequence` is the reverse index used after the model call + to unpack logits, logprobs, or hidden states back into one tensor per + original request. + + The implementation is a tiny radix-tree walk. 
It finds the longest prefix + shared by the active sequences, emits that segment once, then partitions the + remaining sequences by their next token while preserving first-seen order. + Single sequences, empty branches, and branches past `max_depth` are emitted + as ordinary unshared tails. + """ + if max_depth < 0: + raise ValueError("max_depth must be >= 0") + + tensors = tuple(_sequence_tensor(sequence) for sequence in sequences) + if not tensors: + return _empty_pack() + + device = tensors[0].device + rows = tuple(tensor.detach().cpu().tolist() for tensor in tensors) + segments = _prefix_segments(rows, max_depth=max_depth) + if not segments: + return _empty_pack(len(tensors), device=device) + + token_chunks: list[torch.Tensor] = [] + group_chunks: list[torch.Tensor] = [] + parent_chunks: list[torch.Tensor] = [] + position_chunks: list[torch.Tensor] = [] + positions_by_sequence: list[list[torch.Tensor]] = [[] for _ in tensors] + cursor = 0 + + for planned in segments: + segment = tensors[planned.sequence_indices[0]][planned.start : planned.end] + packed_positions = torch.arange(cursor, cursor + len(segment), device=device) + token_chunks.append(segment) + group_chunks.append(torch.full_like(segment, planned.group_id)) + parent_chunks.append(torch.full_like(segment, planned.parent_id)) + position_chunks.append(torch.arange(planned.start, planned.end, device=device)) + for sequence_index in planned.sequence_indices: + positions_by_sequence[sequence_index].append(packed_positions) + cursor += len(segment) + + return SharedPrefixPack( + tokens=torch.cat(token_chunks).unsqueeze(0), + group_ids=torch.cat(group_chunks).unsqueeze(0), + parent_ids=torch.cat(parent_chunks).unsqueeze(0), + position_ids=torch.cat(position_chunks).unsqueeze(0), + positions_by_sequence=tuple( + torch.cat(chunks) + if chunks + else torch.empty(0, dtype=torch.long, device=device) + for chunks in positions_by_sequence + ), + ) + + +def estimate_shared_prefix_packed_tokens( + sequences: 
Iterable[torch.Tensor], + *, + max_depth: int, +) -> int | None: + """Return the exact packed token count without building a packed batch. + + The estimator intentionally only handles CPU tensors. For CUDA tensors, many + tiny prefix probes would launch many tiny kernels, so callers should fall + back to full packing instead. + """ + if max_depth < 0: + raise ValueError("max_depth must be >= 0") + + rows: list[list[int]] = [] + for sequence in sequences: + tensor = _sequence_tensor(sequence) + if tensor.device.type != "cpu": + return None + rows.append(tensor.tolist()) + + return sum( + segment.end - segment.start + for segment in _prefix_segments(tuple(rows), max_depth=max_depth) + ) + + +def _prefix_segments( + rows: tuple[list[int], ...], + *, + max_depth: int, +) -> tuple[_PrefixSegment, ...]: + lengths = tuple(len(row) for row in rows) + segments: list[_PrefixSegment] = [] + next_group_id = 1 + + def emit( + indices: tuple[int, ...], + start: int, + end: int, + parent_group_id: int | None, + ) -> int: + nonlocal next_group_id + group_id = next_group_id + next_group_id += 1 + segments.append( + _PrefixSegment( + sequence_indices=indices, + start=start, + end=end, + group_id=group_id, + parent_id=group_id if parent_group_id is None else parent_group_id, + ) + ) + return group_id + + def shared_end(indices: tuple[int, ...], start: int) -> int: + end = min(lengths[index] for index in indices) + low = high = rows[indices[0]] + for index in indices[1:]: + row = rows[index] + if row < low: + low = row + elif row > high: + high = row + while start < end: + if low[start] != high[start]: + break + start += 1 + return start + + def branch_groups(indices: tuple[int, ...], start: int) -> list[tuple[int, ...]]: + groups: dict[int, list[int]] = {} + order: list[int] = [] + for index in indices: + token = rows[index][start] + if token not in groups: + groups[token] = [] + order.append(token) + groups[token].append(index) + return [tuple(groups[token]) for token in order] + + 
def walk( + indices: tuple[int, ...], + start: int, + parent_group_id: int | None, + depth: int, + ) -> None: + active = tuple(index for index in indices if lengths[index] > start) + if not active: + return + if ( + max_depth == 0 + or len(active) == 1 + or (parent_group_id is not None and depth >= max_depth) + ): + for index in active: + emit((index,), start, lengths[index], parent_group_id) + return + + end = shared_end(active, start) + if end > start: + walk(active, end, emit(active, start, end, parent_group_id), depth + 1) + return + for group in branch_groups(active, start): + walk(group, start, parent_group_id, depth) + + walk(tuple(range(len(rows))), 0, None, 0) + return tuple(segments) + + +def _empty_pack( + sequence_count: int = 0, + *, + device: torch.device | None = None, +) -> SharedPrefixPack: + flat = torch.empty(0, dtype=torch.long, device=device) + row = flat.unsqueeze(0) + return SharedPrefixPack( + tokens=row, + group_ids=row, + parent_ids=row, + position_ids=row, + positions_by_sequence=tuple(flat for _ in range(sequence_count)), + ) + + +def _sequence_tensor(tensor: torch.Tensor) -> torch.Tensor: + if tensor.ndim != 1: + raise ValueError( + f"pack_shared_prefixes expects 1-D tensors, got {tuple(tensor.shape)}" + ) + return tensor.detach().to(dtype=torch.long).contiguous() diff --git a/src/art/megatron/shared_prefix_state.py b/src/art/megatron/shared_prefix_state.py index 7bbda4624..f3c1565b1 100644 --- a/src/art/megatron/shared_prefix_state.py +++ b/src/art/megatron/shared_prefix_state.py @@ -2,6 +2,7 @@ from __future__ import annotations +from dataclasses import replace import gc from typing import Any @@ -10,7 +11,10 @@ from torch import Tensor from torch.nn.attention.flex_attention import BlockMask -from art.megatron.context_parallel.block_mask import build_block_mask +from art.megatron.context_parallel.block_mask import ( + build_block_mask_from_context, + prepare_block_mask_context, +) from art.megatron.context_parallel.builder import 
build_shared_prefix_attention_spec from art.megatron.context_parallel.layout_index import TokenLayoutIndex from art.megatron.context_parallel.types import ( @@ -75,14 +79,11 @@ def create_shared_prefix_state( attention_value_head_dim=attention_value_head_dim, ), ) - cp_rank, cp_size, cp_group = _gdn_cp_rank_size_group() - gdn_execution_spec = _build_gdn_execution_spec_once( - group_ids_cpu, - parent_ids_cpu, - build=build_gdn_execution_spec, - cp_rank=cp_rank, - cp_size=cp_size, - cp_group=cp_group, + cp_rank, cp_size = _gdn_cp_rank_size() + gdn_execution_spec = ( + parse_gdn_shared_prefix_segments(group_ids_cpu, parent_ids_cpu) + if build_gdn_execution_spec + else None ) return SharedPrefixAttentionState( block_mask=block_mask, @@ -94,7 +95,6 @@ def create_shared_prefix_state( device=device, cp_rank=cp_rank, cp_size=cp_size, - cp_group=cp_group, attention_token_layout_index=attention_token_layout_index, ), ) @@ -118,53 +118,105 @@ def _build_sparse_shared_prefix_block_mask( group_ids=group_ids_cpu, parent_ids=parent_ids_cpu, ) - row_spec = batch_spec.rows[0] seq_len = int(group_ids_cpu.shape[1]) - slices = _full_row_slices_with_padding( - row_slices=row_spec.slices, - valid_tokens=int(row_spec.valid_tokens), - seq_len=seq_len, - ) - if not slices: + row_masks = [] + token_indices = torch.arange(seq_len, dtype=torch.int64) + for row_spec in batch_spec.rows: + row_index = int(row_spec.row_index) + slices = tuple(replace(slice_, row_index=0) for slice_ in row_spec.slices) + if int(row_spec.valid_tokens) < seq_len: + padding_range = TokenRange(start=int(row_spec.valid_tokens), end=seq_len) + slices = ( + *slices, + AttnSlice( + q_range=padding_range, + k_range=padding_range, + mask_kind=AttnMaskKind.CAUSAL, + row_index=0, + family_index=None, + ), + ) + if not slices: + row_masks.append( + _empty_block_mask(seq_len=seq_len, block_size=block_size, device=device) + ) + continue + row_masks.append( + build_block_mask_from_context( + FlexMaskSpec( + q_len=seq_len, + 
k_len=seq_len, + block_size=block_size, + slices=slices, + exact_mask=ExactMaskMetadata( + q_token_indices=token_indices, + k_token_indices=token_indices, + cache_key=f"identity:{seq_len}", + ), + ), + context=prepare_block_mask_context( + group_ids=group_ids_cpu[row_index], + parent_ids=parent_ids_cpu[row_index], + ), + device=device, + ) + ) + if not row_masks: return _empty_block_mask(seq_len=seq_len, block_size=block_size, device=device) - return build_block_mask( - FlexMaskSpec( - q_len=seq_len, - k_len=seq_len, - block_size=block_size, - slices=slices, - exact_mask=ExactMaskMetadata( - q_token_indices=torch.arange(seq_len, dtype=torch.int64), - k_token_indices=torch.arange(seq_len, dtype=torch.int64), - cache_key=f"identity:{seq_len}", - ), - ), - group_ids=group_ids_cpu[0], - parent_ids=parent_ids_cpu[0], - device=device, + return _stack_row_block_masks( + row_masks, + seq_len=seq_len, + block_size=block_size, ) -def _full_row_slices_with_padding( +def _stack_optional_block_tensors( + masks: list[BlockMask], + name: str, +) -> Tensor | None: + tensors = [getattr(mask, name) for mask in masks] + if any(tensor is None for tensor in tensors): + return None + return torch.cat(tensors, dim=0) + + +def _stack_row_block_masks( + masks: list[BlockMask], *, - row_slices: tuple[AttnSlice, ...], - valid_tokens: int, seq_len: int, -) -> tuple[AttnSlice, ...]: - if valid_tokens >= seq_len: - return row_slices - padding_range = TokenRange(start=int(valid_tokens), end=int(seq_len)) - if padding_range.is_empty(): - return row_slices - return ( - *row_slices, - AttnSlice( - q_range=padding_range, - k_range=padding_range, - mask_kind=AttnMaskKind.CAUSAL, - row_index=0, - family_index=None, - ), + block_size: tuple[int, int], +) -> BlockMask: + if len(masks) == 1: + return masks[0] + row_mask_mods = tuple(mask.mask_mod for mask in masks) + + def mask_mod( + batch_idx: Tensor, + head_idx: Tensor, + query_idx: Tensor, + kv_idx: Tensor, + ) -> Tensor: + result = 
torch.zeros_like(query_idx, dtype=torch.bool) + for row_index, row_mask_mod in enumerate(row_mask_mods): + result = torch.where( + batch_idx == row_index, + row_mask_mod(batch_idx, head_idx, query_idx, kv_idx), + result, + ) + return result + + return BlockMask( + seq_lengths=(int(seq_len), int(seq_len)), + kv_num_blocks=torch.cat([mask.kv_num_blocks for mask in masks], dim=0), + kv_indices=torch.cat([mask.kv_indices for mask in masks], dim=0), + full_kv_num_blocks=_stack_optional_block_tensors(masks, "full_kv_num_blocks"), + full_kv_indices=_stack_optional_block_tensors(masks, "full_kv_indices"), + q_num_blocks=_stack_optional_block_tensors(masks, "q_num_blocks"), + q_indices=_stack_optional_block_tensors(masks, "q_indices"), + full_q_num_blocks=_stack_optional_block_tensors(masks, "full_q_num_blocks"), + full_q_indices=_stack_optional_block_tensors(masks, "full_q_indices"), + BLOCK_SIZE=block_size, + mask_mod=mask_mod, ) @@ -223,45 +275,17 @@ def _shared_prefix_block_size( ) -def _build_gdn_execution_spec_once( - group_ids: Tensor, - parent_ids: Tensor, - *, - build: bool, - cp_rank: int, - cp_size: int, - cp_group: Any | None, -) -> GdnPackedExecutionSpec | None: - if not build: - return None - if cp_size == 1: - return parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) - if ( - not torch.distributed.is_available() or not torch.distributed.is_initialized() # ty: ignore[possibly-missing-attribute] - ): - return parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) - return parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) - - def _build_gdn_execution_plan_once( spec: GdnPackedExecutionSpec | None, *, device: torch.device, cp_rank: int, cp_size: int, - cp_group: Any | None, attention_token_layout_index: TokenLayoutIndex | None, ) -> GdnRankExecutionPlan | None: if spec is None: return None planner_device = torch.device("cpu") if device.type == 
"cuda" else device - del cp_group gc_was_enabled = gc.isenabled() if gc_was_enabled: gc.disable() @@ -279,7 +303,7 @@ def _build_gdn_execution_plan_once( return move_gdn_rank_execution_plan_to_device(plan, device) -def _gdn_cp_rank_size_group() -> tuple[int, int, Any | None]: +def _gdn_cp_rank_size() -> tuple[int, int]: try: from megatron.core import parallel_state as ps @@ -287,8 +311,7 @@ def _gdn_cp_rank_size_group() -> tuple[int, int, Any | None]: return ( int(ps.get_context_parallel_rank()), int(ps.get_context_parallel_world_size()), - ps.get_context_parallel_group(), ) except Exception: pass - return 0, 1, None + return 0, 1 diff --git a/src/art/megatron/shared_prefix_tree.py b/src/art/megatron/shared_prefix_tree.py new file mode 100644 index 000000000..63cdb0f07 --- /dev/null +++ b/src/art/megatron/shared_prefix_tree.py @@ -0,0 +1,234 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True, slots=True) +class SharedPrefixSegment: + group_id: int + parent_id: int + start: int + end: int + family_index: int + ancestors: tuple[int, ...] + + @property + def depth(self) -> int: + return len(self.ancestors) + + +@dataclass(frozen=True, slots=True) +class SharedPrefixRowTree: + row_index: int + valid_tokens: int + segments: tuple[SharedPrefixSegment, ...] 
+ + +def parse_shared_prefix_tree( + *, + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + ignore_padding_group_id: int = -1, +) -> tuple[SharedPrefixRowTree, ...]: + if group_ids.shape != parent_ids.shape: + raise RuntimeError( + "group_ids and parent_ids must share shape, got " + f"{tuple(group_ids.shape)} vs {tuple(parent_ids.shape)}" + ) + if group_ids.ndim != 2: + raise RuntimeError( + "group_ids and parent_ids must be rank-2 packed tensors, got " + f"{group_ids.ndim}" + ) + return tuple( + parse_shared_prefix_row( + group_ids=group_ids[row_index], + parent_ids=parent_ids[row_index], + row_index=row_index, + ignore_padding_group_id=ignore_padding_group_id, + ) + for row_index in range(int(group_ids.shape[0])) + ) + + +def parse_shared_prefix_row( + *, + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + row_index: int = 0, + ignore_padding_group_id: int = -1, +) -> SharedPrefixRowTree: + if group_ids.shape != parent_ids.shape: + raise RuntimeError( + "group_ids and parent_ids must share shape, got " + f"{tuple(group_ids.shape)} vs {tuple(parent_ids.shape)}" + ) + if group_ids.ndim != 1: + raise RuntimeError( + f"group_ids and parent_ids must be rank-1 row tensors, got {group_ids.ndim}" + ) + + valid_tokens = _valid_length( + group_ids, + parent_ids, + ignore_padding_group_id=ignore_padding_group_id, + ) + if valid_tokens == 0: + return SharedPrefixRowTree(row_index=row_index, valid_tokens=0, segments=()) + + runs = _scan_runs(group_ids[:valid_tokens], parent_ids[:valid_tokens]) + first_segment_by_group: dict[int, SharedPrefixSegment] = {} + family_by_group: dict[int, int] = {} + ancestors_by_group: dict[int, tuple[int, ...]] = {} + segments: list[SharedPrefixSegment] = [] + next_family_index = 0 + + seen_groups: set[int] = set() + repeated_groups: dict[int, int] = {} + for _start, _end, group_id, _parent_id in runs: + if group_id in seen_groups and group_id != ignore_padding_group_id: + repeated_groups[group_id] = repeated_groups.get(group_id, 1) + 1 
+ seen_groups.add(group_id) + if repeated_groups: + raise RuntimeError( + "Shared-prefix metadata requires contiguous group runs per row, " + f"found repeats in row {row_index}: {repeated_groups}" + ) + + for start, end, group_id, parent_id in runs: + is_root = group_id == parent_id or ( + start == 0 and parent_id == ignore_padding_group_id + ) + if is_root: + family_index = next_family_index + next_family_index += 1 + ancestors: tuple[int, ...] = () + else: + parent_segment = first_segment_by_group.get(parent_id) + if parent_segment is None: + raise RuntimeError( + "Shared-prefix run points to a missing parent run: " + f"row={row_index}, group_id={group_id}, parent_id={parent_id}" + ) + if int(parent_segment.end) > int(start): + raise RuntimeError( + "Shared-prefix parent run must end before its child starts: " + f"row={row_index}, group_id={group_id}, parent_id={parent_id}" + ) + family_index = family_by_group[parent_id] + ancestors = (*ancestors_by_group[parent_id], parent_id) + + segment = SharedPrefixSegment( + group_id=group_id, + parent_id=parent_id, + start=start, + end=end, + family_index=family_index, + ancestors=ancestors, + ) + first_segment_by_group[group_id] = segment + family_by_group[group_id] = family_index + ancestors_by_group[group_id] = ancestors + segments.append(segment) + + return SharedPrefixRowTree( + row_index=row_index, + valid_tokens=valid_tokens, + segments=tuple(segments), + ) + + +def _valid_length( + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + *, + ignore_padding_group_id: int, +) -> int: + valid_mask = group_ids != ignore_padding_group_id + valid_count = int(valid_mask.sum().item()) + if valid_count == 0: + return 0 + if not bool(valid_mask[:valid_count].all().item()): + raise RuntimeError("Padding tokens must be a contiguous tail") + return _infer_terminal_padding_length( + group_ids[:valid_count], + parent_ids[:valid_count], + ) + + +def _infer_terminal_padding_length( + group_row: torch.Tensor, + parent_row: 
torch.Tensor, +) -> int: + if group_row.numel() == 0: + return 0 + runs = _scan_runs(group_row, parent_row) + if len(runs) < 2: + return int(group_row.numel()) + last_start, _last_end, last_group_id, last_parent_id = runs[-1] + if last_parent_id >= 0: + return int(group_row.numel()) + terminal_pair = (last_group_id, last_parent_id) + if any( + (group_id, parent_id) == terminal_pair + for _start, _end, group_id, parent_id in runs[:-1] + ): + return last_start + return int(group_row.numel()) + + +def _scan_runs( + group_row: torch.Tensor, + parent_row: torch.Tensor, +) -> list[tuple[int, int, int, int]]: + length = int(group_row.numel()) + if length == 0: + return [] + + group_changes = group_row[1:] != group_row[:-1] + parent_changes = parent_row[1:] != parent_row[:-1] + inconsistent_parent = torch.nonzero( + torch.logical_not(group_changes) & parent_changes, + as_tuple=False, + ).flatten() + if int(inconsistent_parent.numel()) > 0: + mismatch_index = int(inconsistent_parent[0].item()) + 1 + prior_boundaries = torch.nonzero( + group_changes[: mismatch_index - 1], + as_tuple=False, + ).flatten() + start = ( + 0 + if int(prior_boundaries.numel()) == 0 + else int(prior_boundaries[-1].item()) + 1 + ) + group_id = int(group_row[start].item()) + raise RuntimeError( + "Found one group run with inconsistent parent ids: " + f"group_id={group_id}, start={start}, end={mismatch_index}" + ) + + run_starts = torch.cat( + ( + torch.zeros(1, dtype=torch.int64, device=group_row.device), + torch.nonzero(group_changes, as_tuple=False).flatten() + 1, + ) + ) + run_ends = torch.cat( + ( + run_starts[1:], + torch.tensor([length], dtype=torch.int64, device=group_row.device), + ) + ) + starts = run_starts.to(device="cpu").tolist() + ends = run_ends.to(device="cpu").tolist() + group_ids = group_row.index_select(0, run_starts).to(device="cpu").tolist() + parent_ids = parent_row.index_select(0, run_starts).to(device="cpu").tolist() + return [ + (int(start), int(end), int(group_id), 
int(parent_id)) + for start, end, group_id, parent_id in zip( + starts, ends, group_ids, parent_ids, strict=True + ) + ] diff --git a/src/art/megatron/trainer_rank.py b/src/art/megatron/trainer_rank.py new file mode 100644 index 000000000..ad1bd24a8 --- /dev/null +++ b/src/art/megatron/trainer_rank.py @@ -0,0 +1,2212 @@ +from __future__ import annotations + +from collections.abc import ( + Callable, + Iterable, + Iterator, + Mapping, + MutableMapping, + Sequence, +) +from dataclasses import dataclass +import os +from typing import TYPE_CHECKING, Generic, Literal, ParamSpec, TypeVar, cast, overload + +import torch +import torch.distributed as dist + +from art.megatron.shared_prefix_packing import ( + SharedPrefixPack, + estimate_shared_prefix_packed_tokens, + pack_shared_prefixes, +) + +if TYPE_CHECKING: + from megatron.core.models.gpt.gpt_model import GPTModel + from megatron.core.optimizer import MegatronOptimizer, OptimizerConfig + from megatron.core.packed_seq_params import PackedSeqParams + + from art.megatron.context_parallel.types import ( + ArtContextParallelState, + ParallelTopology, + ) + from art.megatron.lora import LoRASlotRef + from art.megatron.shared_prefix_state import SharedPrefixAttentionState + from art.megatron.train import TrainingRuntime + + +@dataclass(frozen=True) +class AdamParams: + learning_rate: float + beta1: float = 0.9 + beta2: float = 0.99 + weight_decay: float = 0.1 + grad_clip_norm: float = 0.1 + + +@dataclass(frozen=True) +class TopK: + logprobs: torch.Tensor + tokens: torch.Tensor + + +LogprobsT = TypeVar("LogprobsT", bound=torch.Tensor | None, covariant=True) +TopKT = TypeVar("TopKT", bound=TopK | None, covariant=True) +LogitsT = TypeVar("LogitsT", bound=torch.Tensor | None, covariant=True) +HiddenStatesT = TypeVar("HiddenStatesT", bound=torch.Tensor | None, covariant=True) +T = TypeVar("T") +P = ParamSpec("P") +R = TypeVar("R") + +_COMPILED_FUNCTIONS: dict[Callable[..., object], Callable[..., object]] = {} 
@dataclass(slots=True)
class ForwardInput(Generic[LogprobsT, TopKT, LogitsT, HiddenStatesT]):
    """One forward request: the input tokens plus the outputs wanted for them.

    Construction fails with ``ValueError`` when ``top_k`` is given but below 1,
    or when both ``checkpoint`` and ``lora`` are set to something other than
    the ``Unset`` sentinel.
    """

    # Tokens fed to the model.
    input_tokens: torch.Tensor
    # Optional targets; presumably scored into target_logprobs -- confirm.
    target_tokens: torch.Tensor | None = None
    # Number of top-k entries requested; None requests none.
    top_k: int | None = None
    # Whether logits are requested.
    logits: bool = False
    # Whether hidden states are requested.
    hidden_states: bool = False
    # Per-request adapter selection; Unset means "not chosen on this request".
    checkpoint: AdapterSelection = Unset
    lora: AdapterSelection = Unset

    def __post_init__(self) -> None:
        requested_k = self.top_k
        if requested_k is not None and requested_k < 1:
            raise ValueError("top_k must be >= 1")
        adapter_choices = (self.checkpoint, self.lora)
        if all(choice is not Unset for choice in adapter_choices):
            raise ValueError("ForwardInput cannot set both checkpoint and lora")
estimated_required_bytes: int + available_bytes: int + fits: bool + + +@dataclass(frozen=True) +class _MemoryProfile: + bytes_per_token: float + packed_tokens: int + + +@dataclass(frozen=True) +class _CandidateMicroBatch(Generic[ForwardInputsT]): + inputs: Sequence[ForwardInputsT] + indices: tuple[int, ...] + plan: "_FlatForwardPlan" + check: _MemoryCheck + stats_global_count: int + rejected_candidates: int + cold_start: bool + + +class TrainerRankMemoryError(RuntimeError): + pass + + +@dataclass(frozen=True) +class _PushedSlot: + trainer: "TrainerRank" + ref: "LoRASlotRef" + + def __enter__(self) -> "_PushedSlot": + return self + + def __exit__(self, *args: object) -> bool: + if not self.trainer._slot_stack or self.trainer._slot_stack[-1] != self.ref: + raise RuntimeError( + "Pushed LoRA/checkpoint stack changed before context exit" + ) + self.trainer.pop_pushed_lora_or_checkpoint() + return False + + +@dataclass(frozen=True) +class _ForwardItem: + request: AnyForwardInput + input_ids: torch.Tensor + labels: torch.Tensor | None + + +@dataclass(frozen=True) +class _PreparedPackedForward: + tokens: torch.Tensor + position_ids: torch.Tensor + attention_state: "SharedPrefixAttentionState | ArtContextParallelState" + packed_seq_params: "PackedSeqParams | None" + positions_by_item: tuple[torch.Tensor, ...] + source_positions_by_item: tuple[torch.Tensor, ...] + + +@dataclass(frozen=True) +class _RowMatch: + source_offsets: torch.Tensor + row_offsets: torch.Tensor + + +@dataclass(frozen=True) +class _MemorySignature: + topology: tuple[int, int, int, int] + shared_prefix_max_depth: int + slot_group_count: int + request_mix: tuple[str, ...] + + +@dataclass(frozen=True) +class _ForwardGroupPlan: + slot_ref: "LoRASlotRef | None" + request_indices: tuple[int, ...] + items: tuple[_ForwardItem, ...] + packed: SharedPrefixPack + + +@dataclass(frozen=True) +class _FlatForwardPlan: + request_count: int + groups: tuple[_ForwardGroupPlan, ...] 
+ packed_tokens: int + logical_tokens: int + output_bytes: int + signature: _MemorySignature + + +type _AdaptivePlanCacheKey = tuple[tuple[int, ...], object, tuple[object, ...], int] + + +class TrainerRank: + def __init__( + self, + runtime: TrainingRuntime, + *, + head_chunk_tokens: int = 512, + shared_prefix_max_depth: int = 1, + memory_safety_factor: float = 1.10, + memory_reserve_fraction: float = 0.03, + ) -> None: + if head_chunk_tokens < 1: + raise ValueError("head_chunk_tokens must be >= 1") + if shared_prefix_max_depth < 0: + raise ValueError("shared_prefix_max_depth must be >= 0") + if memory_safety_factor < 1.0: + raise ValueError("memory_safety_factor must be >= 1.0") + if not (0.0 <= memory_reserve_fraction < 1.0): + raise ValueError("memory_reserve_fraction must be in [0, 1)") + self.runtime: TrainingRuntime = runtime + self.head_chunk_tokens = head_chunk_tokens + self.shared_prefix_max_depth = shared_prefix_max_depth + self.memory_safety_factor = memory_safety_factor + self.memory_reserve_fraction = memory_reserve_fraction + self.device = next(runtime.model[0].parameters()).device + self._param_dtype_size = _dtype_size(next(runtime.model[0].parameters()).dtype) + try: + metadata_model = _language_model(runtime.model[0]) + except RuntimeError: + metadata_model = None + self._hidden_size = _hidden_size(metadata_model, runtime.provider) + self._padded_vocab_size = ( + None if metadata_model is None else _padded_vocab_size(metadata_model) + ) + self._num_layers = int( + getattr(getattr(metadata_model, "config", None), "num_layers", 0) + or getattr(runtime.provider, "num_layers", 1) + or 1 + ) + self._default_slot_ref: LoRASlotRef | None = None + self._slot_stack: list[LoRASlotRef] = [] + self._dynamic_optimizers: dict[str, torch.optim.Optimizer] = {} + self._checkpoint_slot_params_by_name: dict[ + str, tuple[torch.nn.Parameter, ...] 
+ ] = {} + self._memory_profiles: dict[_MemorySignature, _MemoryProfile] = {} + self._adaptive_plan_cache: dict[_AdaptivePlanCacheKey, _FlatForwardPlan] = {} + self._adaptive_plan_cache_top_level_ids: tuple[int, ...] = () + self._adaptive_estimate_cache: dict[ + _AdaptivePlanCacheKey, tuple[_MemoryCheck, bool] | None + ] = {} + self._last_global_micro_batch_size: int | None = None + self.zero_grad() + + def zero_grad(self) -> None: + for chunk in self.runtime.model: + zero_grad_buffer = getattr(chunk, "zero_grad_buffer", None) + if callable(zero_grad_buffer): + zero_grad_buffer() + optimizer = cast("MegatronOptimizer | None", self.runtime.optimizer) + if optimizer is not None: + optimizer.zero_grad() + for params in self._checkpoint_slot_params_by_name.values(): + for param in params: + param.grad = None + + def _optimizer(self) -> "MegatronOptimizer": + optimizer = cast("MegatronOptimizer | None", self.runtime.optimizer) + if optimizer is None: + raise RuntimeError("TrainerRank requires a runtime with an optimizer") + return optimizer + + def set_checkpoint(self, name: str | None) -> None: + self._set_default_slot(self._slot_ref("checkpoint", name)) + + def set_lora(self, name: str | None) -> None: + self._set_default_slot(self._slot_ref("lora", name)) + + def push_checkpoint(self, name: str | None) -> _PushedSlot: + ref = self._slot_ref("checkpoint", name) + self._slot_stack.append(ref) + return _PushedSlot(self, ref) + + def push_lora(self, name: str | None) -> _PushedSlot: + ref = self._slot_ref("lora", name) + self._slot_stack.append(ref) + return _PushedSlot(self, ref) + + def pop_pushed_lora_or_checkpoint(self) -> None: + if not self._slot_stack: + raise RuntimeError("No pushed LoRA or checkpoint to pop") + self._slot_stack.pop() + + def load_checkpoint_slot( + self, + name: str, + adapter_model: dict[str, torch.Tensor], + *, + alpha: float | None = None, + ) -> int: + loaded = self._load_slot( + "checkpoint", name, adapter_model, trainable=True, alpha=alpha 
+ ) + self._checkpoint_slot_params_by_name[name] = ( + self._validate_dynamic_slot_consistency("checkpoint", name, loaded) + ) + self._dynamic_optimizers.pop(name, None) + return loaded + + def load_lora_slot( + self, + name: str, + adapter_model: dict[str, torch.Tensor], + *, + alpha: float | None = None, + ) -> int: + loaded = self._load_slot( + "lora", name, adapter_model, trainable=False, alpha=alpha + ) + self._validate_dynamic_slot_consistency("lora", name, loaded) + return loaded + + def _load_slot( + self, + kind: Literal["checkpoint", "lora"], + name: str, + adapter_model: dict[str, torch.Tensor], + *, + trainable: bool, + alpha: float | None, + ) -> int: + from art.megatron.lora import LORA_ALPHA, load_lora_slot_into_model + + return load_lora_slot_into_model( + self.runtime.model, + self._slot_ref(kind, name), + adapter_model, + alpha=LORA_ALPHA if alpha is None else alpha, + requires_grad=trainable, + ) + + def _set_default_slot(self, ref: "LoRASlotRef") -> None: + if self._slot_stack: + raise RuntimeError("Cannot set a LoRA/checkpoint while a slot is pushed") + self._default_slot_ref = ref + + @staticmethod + def _slot_ref( + kind: Literal["checkpoint", "lora"], name: str | None + ) -> "LoRASlotRef": + from art.megatron.lora import LoRASlotRef + + return LoRASlotRef(kind=kind, name=name) + + def _validate_dynamic_slot_consistency( + self, + kind: Literal["checkpoint", "lora"], + name: str, + loaded_sites: int, + ) -> tuple[torch.nn.Parameter, ...]: + from art.megatron.lora import iter_lora_slot_parameters + + ref = self._slot_ref(kind, name) + params = tuple(iter_lora_slot_parameters(self.runtime.model, ref)) + if not (dist.is_available() and dist.is_initialized()): + return params + + local = { + "rank": dist.get_rank(), + "loaded_sites": int(loaded_sites), + "param_count": len(params), + "numel": sum(int(param.numel()) for param in params), + "signature": [ + ( + tuple(int(dim) for dim in param.shape), + str(param.dtype), + bool(getattr(param, 
"allreduce", True)), + str(getattr(param, "grad_sync_domain", "tp_default")), + str(getattr(param, "grad_sync_op", "none")), + ) + for param in params + ], + } + gathered: list[dict[str, object] | None] = [None] * dist.get_world_size() + dist.all_gather_object(gathered, local) + ranks = [rank for rank in gathered if rank is not None] + reference = ranks[0] + if all( + rank["loaded_sites"] == reference["loaded_sites"] + and rank["signature"] == reference["signature"] + for rank in ranks + ): + return params + + summary = [ + {key: rank[key] for key in ("rank", "loaded_sites", "param_count", "numel")} + for rank in ranks + ] + raise RuntimeError( + f"Dynamic LoRA slot {kind}:{name} is not loaded consistently across " + "distributed ranks. This usually means a sharded/exported LoRA state " + "dict was passed directly to TrainerRank; gather or materialize the " + "full adapter state before loading a dynamic slot. " + f"Rank summary: {summary}." + ) + + def _resolve_slot_ref(self, request: AnyForwardInput) -> "LoRASlotRef | None": + if request.checkpoint is not Unset: + return self._slot_ref("checkpoint", cast(str | None, request.checkpoint)) + if request.lora is not Unset: + return self._slot_ref("lora", cast(str | None, request.lora)) + if self._slot_stack: + return self._slot_stack[-1] + return self._default_slot_ref + + def forward_micro_batches( + self, + inputs: Iterable[ForwardInputsT], + ) -> Iterator[MicroBatch[ForwardInputsT]]: + items = list(inputs) + self._validate_replicated_top_level_count(len(items)) + start = 0 + while start < len(items): + candidate = self._select_next_micro_batch(items, start) + flat_outputs = iter( + self._run_flat_plan_with_memory_tracking( + candidate.plan, + context="forward_micro_batches", + ) + ) + outputs = [_unflatten(item, flat_outputs) for item in candidate.inputs] + stop = start + candidate.stats_global_count + if stop < len(items): + self._last_global_micro_batch_size = max( + self._last_global_micro_batch_size or 0, + 
candidate.stats_global_count, + ) + yield MicroBatch( + inputs=candidate.inputs, + outputs=outputs, + indices=candidate.indices, + stats=MicroBatchStats( + global_start=start, + global_stop=stop, + global_count=candidate.stats_global_count, + local_count=len(candidate.inputs), + packed_tokens=candidate.plan.packed_tokens, + logical_tokens=candidate.plan.logical_tokens, + estimated_required_bytes=candidate.check.estimated_required_bytes, + available_bytes=candidate.check.available_bytes, + rejected_candidates=candidate.rejected_candidates, + cold_start=candidate.cold_start, + ), + ) + start = stop + + @overload + def dp_rank_forward( + self, + inputs: Iterable[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]], + ) -> Sequence[ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT]]: ... + + @overload + def dp_rank_forward( + self, + inputs: Iterable[ + Iterable[ForwardInput[LogprobsT, TopKT, LogitsT, HiddenStatesT]] + ], + ) -> Sequence[ + Sequence[ForwardOutput[LogprobsT, TopKT, LogitsT, HiddenStatesT]] + ]: ... 
+ + def dp_rank_forward(self, inputs: ForwardInputs) -> ForwardOutputs: + materialized = _materialize(inputs) + plan = self._plan_flat_forward(list(_flatten(materialized))) + check = self._memory_check(plan) + if not check.fits: + self._raise_memory_error( + plan, + check, + context="dp_rank_forward", + message="forward is predicted to exceed available memory", + ) + outputs = iter( + self._run_flat_plan_with_memory_tracking( + plan, + context="dp_rank_forward", + ) + ) + return _unflatten(materialized, outputs) + + def dp_reduce( + self, + tensor: torch.Tensor, + *, + op: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM, + ) -> None: + from megatron.core import parallel_state as ps + + dist.all_reduce( + tensor, + op=op, + group=ps.get_data_parallel_group(with_context_parallel=True), + ) + + def optim_step( + self, + *, + params: AdamParams, + scale_grads: float = 1.0, + checkpoints: Sequence[str] | None = None, + ) -> dict[str, float]: + selected_checkpoints = self._selected_dynamic_checkpoints(checkpoints) + if selected_checkpoints: + return self._dynamic_optim_step( + selected_checkpoints, + params=params, + scale_grads=scale_grads, + ) + + from art.megatron.training.finalize_grads import ( + finalize_model_grads_extended, + flush_param_grads_to_main_grads, + ) + from art.megatron.training.model_chunks import as_megatron_api_chunks + + optimizer = self._optimizer() + flush_param_grads_to_main_grads(self.runtime.model) + finalize_model_grads_extended( + as_megatron_api_chunks(self.runtime.model), + num_tokens=None, + ) + self._scale_main_grads(scale_grads) + self._configure_optimizer(params) + update_successful, grad_norm, num_zeros = optimizer.step() + optimizer.zero_grad() + self.zero_grad() + return { + "learning_rate": float(params.learning_rate), + "grad_norm": float(grad_norm), + "update_successful": float(bool(update_successful)), + "num_zeros_in_grad": float(num_zeros or 0), + } + + def _selected_dynamic_checkpoints( + self, + checkpoints: Sequence[str] | 
None, + ) -> tuple[str, ...]: + if checkpoints is not None: + if ( + unknown := set(checkpoints) + - self._checkpoint_slot_params_by_name.keys() + ): + raise ValueError(f"Unknown checkpoint slots: {sorted(unknown)}") + return tuple(dict.fromkeys(checkpoints)) + slots = tuple(sorted(self._checkpoint_slot_params_by_name.items())) + if not slots: + return () + has_grad = torch.tensor( + [ + int(any(param.grad is not None for param in params)) + for _, params in slots + ], + device=self.device, + dtype=torch.int32, + ) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(has_grad, op=dist.ReduceOp.MAX) + return tuple(name for (name, _), flag in zip(slots, has_grad.tolist()) if flag) + + def _dynamic_optim_step( + self, + checkpoint_names: Sequence[str], + *, + params: AdamParams, + scale_grads: float, + ) -> dict[str, float]: + all_params: list[torch.nn.Parameter] = [] + for name in checkpoint_names: + slot_params = self._checkpoint_slot_params_by_name[name] + for param in slot_params: + if param.grad is None: + param.grad = torch.zeros_like(param) + elif scale_grads != 1.0: + param.grad.mul_(scale_grads) + self._reduce_dynamic_grads(slot_params) + all_params.extend(slot_params) + + grad_norm = torch.nn.utils.clip_grad_norm_( + all_params, + max_norm=params.grad_clip_norm, + ) + for name in checkpoint_names: + optimizer = self._dynamic_optimizer(name, params) + optimizer.step() + optimizer.zero_grad(set_to_none=True) + return { + "learning_rate": float(params.learning_rate), + "grad_norm": float(grad_norm), + "update_successful": 1.0, + "num_zeros_in_grad": 0.0, + } + + def _dynamic_optimizer( + self, + name: str, + params: AdamParams, + ) -> torch.optim.Optimizer: + optimizer = self._dynamic_optimizers.get(name) + if optimizer is None: + optimizer = torch.optim.AdamW( + self._checkpoint_slot_params_by_name[name], + lr=params.learning_rate, + betas=(params.beta1, params.beta2), + weight_decay=params.weight_decay, + ) + self._dynamic_optimizers[name] = 
    def _select_next_micro_batch(
        self,
        items: Sequence[ForwardInputsT],
        start: int,
    ) -> _CandidateMicroBatch[ForwardInputsT]:
        """Choose the width of the next microbatch window starting at ``start``.

        A window of global width ``w`` covers ``items[start:start + w]``; this
        rank takes every ``dp_size``-th item of it, offset by its DP rank. The
        search probes candidate widths with ``probe`` and keeps the widest one
        that passed, then returns the matching candidate for this rank.

        Args:
            items: All top-level inputs of the current call.
            start: Global index of the first item not yet consumed.

        Returns:
            The chosen candidate, including its plan, memory check, the
            number of rejected probes and whether it is a cold start.

        Raises:
            RuntimeError: If no items remain at ``start``.
        """
        dp_rank, dp_size = self._dp_rank_and_size()
        remaining = len(items) - start
        # Smallest window considered: dp_size items, or whatever is left.
        min_width = min(dp_size, remaining)
        if min_width <= 0:
            raise RuntimeError("cannot select an empty microbatch window")
        # The plan/estimate caches are only valid for one list of item
        # objects; a different identity tuple resets them.
        top_level_ids = tuple(id(item) for item in items)
        if top_level_ids != self._adaptive_plan_cache_top_level_ids:
            self._adaptive_plan_cache.clear()
            self._adaptive_estimate_cache.clear()
            self._adaptive_plan_cache_top_level_ids = top_level_ids

        def clamp_width(width: int) -> int:
            # Keep a width inside [min_width, remaining].
            return max(min_width, min(width, remaining))

        # Step size for accepted widths: 1, 8 or 32 depending on how many
        # items remain, rounded up to a multiple of dp_size.
        base_granularity = 1 if remaining < 64 else 8 if remaining < 256 else 32
        granularity = max(
            1,
            ((base_granularity + dp_size - 1) // dp_size) * dp_size,
        )

        def snap_width(width: int) -> int:
            # Round a width down to a multiple of `granularity`; the two end
            # points and widths below one granule are kept as they are.
            width = clamp_width(width)
            if width in (min_width, remaining) or granularity <= 1:
                return width
            if width < granularity:
                return width
            return max(min_width, (width // granularity) * granularity)

        def local_slice(width: int) -> tuple[tuple[int, ...], list[ForwardInputsT]]:
            # This rank's share of the window: indices start + dp_rank,
            # start + dp_rank + dp_size, ... below the window's stop.
            stop = start + clamp_width(width)
            indices = tuple(range(start + dp_rank, stop, dp_size))
            return indices, [items[index] for index in indices]

        def candidate(
            width: int,
            estimated_check: _MemoryCheck | None = None,
            *,
            rejected: int,
        ) -> _CandidateMicroBatch[ForwardInputsT]:
            # Build the full candidate (with a real plan) for a width. A check
            # handed in by the caller is reused; otherwise it is computed
            # from the plan.
            width = clamp_width(width)
            indices, local_inputs = local_slice(width)
            plan = self._cached_adaptive_plan(indices, local_inputs)
            return _CandidateMicroBatch(
                inputs=local_inputs,
                indices=indices,
                plan=plan,
                check=estimated_check or self._memory_check(plan),
                stats_global_count=width,
                rejected_candidates=rejected,
                cold_start=not self._all_ranks_have_memory_profile(
                    packed_tokens=plan.packed_tokens,
                    signature=plan.signature,
                ),
            )

        def estimate(width: int) -> tuple[_MemoryCheck, bool] | None:
            # Estimate for a width without building a plan; None when no
            # estimate is available.
            indices, local_inputs = local_slice(width)
            return self._cached_adaptive_estimate(indices, local_inputs)

        def probe(width: int) -> tuple[bool, _MemoryCheck | None, bool]:
            # Returns (accepted, check, trusted). An estimate is accepted only
            # when it is both trusted and fits; without an estimate the full
            # plan decides.
            estimated = estimate(width)
            if estimated is not None:
                check, trusted = estimated
                return trusted and check.fits, check, trusted
            item = candidate(width, rejected=0)
            return item.check.fits, item.check, not item.cold_start

        # Search state shared by `fit` and `search_below`.
        rejected = 0
        best_width = min_width
        best_check: _MemoryCheck | None = None

        def fit(width: int) -> bool:
            # Probe a width; on success remember its snapped width and the
            # check of the probed width, otherwise count a rejection.
            nonlocal best_width, best_check, rejected
            ok, check, _ = probe(width)
            if ok:
                best_width = snap_width(width)
                best_check = check
            else:
                rejected += 1
            return ok

        def search_below(failed_width: int) -> None:
            # Binary search strictly between the current best width and a
            # width that is known to have failed.
            low = best_width + 1
            high = failed_width - 1
            while low <= high:
                mid = (low + high) // 2
                if fit(mid):
                    low = mid + 1
                else:
                    high = mid - 1

        # The minimum window must be usable before anything wider is tried.
        first_fits, first_check, first_trusted = probe(min_width)
        if not first_fits:
            # The probe rejected the minimum width: re-check with the real
            # plan, fail if it still does not fit, and on a cold start run
            # the minimum window without searching further.
            first = candidate(min_width, first_check, rejected=rejected)
            if not first.check.fits:
                self._raise_memory_error(
                    first.plan,
                    first.check,
                    context="forward_micro_batches",
                    message="smallest DP microbatch is predicted to exceed available memory",
                )
            if first.cold_start:
                return first
            best_check = first.check
        else:
            best_check = first_check

        # Fast path: a previously recorded width that is large enough is
        # retried first, then one growth step is attempted.
        stable_width = self._last_global_micro_batch_size
        if stable_width is not None and stable_width >= max(64, granularity * 2):
            stable_capacity = stable_width
            stable_width = clamp_width(stable_capacity)
            if fit(stable_width):
                # Grow by 4x below 256 items and by 2x from there on; growth
                # is only tried while more than the grown capacity remains.
                grow_multiplier = 4 if stable_capacity < 256 else 2
                grow_capacity = min(remaining, stable_capacity * grow_multiplier)
                if remaining > grow_capacity:
                    grow_width = clamp_width(grow_capacity)
                    if grow_width > stable_width and not fit(grow_width):
                        search_below(grow_width)
                return candidate(best_width, best_check, rejected=rejected)
            # The recorded width no longer fits: search below it and record
            # the smaller result.
            search_below(stable_width)
            self._last_global_micro_batch_size = best_width
            return candidate(best_width, best_check, rejected=rejected)

        # General path: start at twice the last recorded width (or twice the
        # minimum) and keep doubling until a width fails or all remaining
        # items fit.
        high_fail: int | None = None
        width = min(
            remaining,
            max(min_width, (self._last_global_micro_batch_size or min_width) * 2),
        )
        while width <= remaining:
            if fit(width):
                if width == remaining:
                    break
                width = min(remaining, max(width + 1, width * 2))
                continue
            high_fail = width
            break

        # Refine between the best accepted width and the first failure.
        if high_fail is not None:
            search_below(high_fail)

        if not first_trusted and best_width == min_width and best_check is None:
            return candidate(min_width, first_check, rejected=rejected)
        return candidate(best_width, best_check, rejected=rejected)
_cached_adaptive_plan( + self, + indices: tuple[int, ...], + local_inputs: Sequence[ForwardInputsT], + ) -> _FlatForwardPlan: + key = self._adaptive_cache_key(indices) + cached = self._adaptive_plan_cache.get(key) + if cached is not None: + return cached + plan = self._plan_flat_forward(list(_flatten(local_inputs))) + self._adaptive_plan_cache[key] = plan + return plan + + def _cached_adaptive_estimate( + self, + indices: tuple[int, ...], + local_inputs: Sequence[ForwardInputsT], + ) -> tuple[_MemoryCheck, bool] | None: + key = self._adaptive_cache_key(indices) + if key in self._adaptive_estimate_cache: + return self._adaptive_estimate_cache[key] + estimate = self._estimate_flat_forward(list(_flatten(local_inputs))) + if estimate is not None: + packed_tokens, output_bytes, signature = estimate + estimate = ( + self._memory_check_required( + self._estimate_required_memory_bytes_from_values( + packed_tokens=packed_tokens, + output_bytes=output_bytes, + signature=signature, + ) + ), + self._all_ranks_have_memory_profile( + packed_tokens=packed_tokens, + signature=signature, + ), + ) + self._adaptive_estimate_cache[key] = estimate + return estimate + + def _adaptive_cache_key( + self, + indices: tuple[int, ...], + ) -> _AdaptivePlanCacheKey: + return ( + indices, + self._default_slot_ref, + tuple(self._slot_stack), + self.shared_prefix_max_depth, + ) + + def _validate_replicated_top_level_count(self, count: int) -> None: + if not (dist.is_available() and dist.is_initialized()): + return + counts = [0 for _ in range(dist.get_world_size())] + dist.all_gather_object(counts, int(count)) + if len(set(counts)) == 1: + return + raise ValueError( + "forward_micro_batches requires the same top-level input count on every " + "distributed rank. Pass already-DP-local inputs to dp_rank_forward instead. " + f"Observed counts by rank: {counts}." 
+ ) + + def _dp_rank_and_size(self) -> tuple[int, int]: + try: + from megatron.core import parallel_state as ps + + return int(ps.get_data_parallel_rank()), int( + ps.get_data_parallel_world_size() + ) + except (AssertionError, ImportError, RuntimeError, ValueError): + return 0, 1 + + def _plan_flat_forward( + self, requests: Sequence[AnyForwardInput] + ) -> _FlatForwardPlan: + plans: list[_ForwardGroupPlan] = [] + output_bytes = self._estimate_group_request_output_bytes(requests) + logical_tokens = sum(int(request.input_tokens.numel()) for request in requests) + groups = self._group_active_request_indices(requests) + for slot_ref, group_indices in groups: + items = tuple( + self._forward_item(requests[index]) for index in group_indices + ) + packed = pack_shared_prefixes( + (item.input_ids for item in items), + max_depth=self.shared_prefix_max_depth, + ) + plans.append( + _ForwardGroupPlan( + slot_ref=slot_ref, + request_indices=tuple(group_indices), + items=items, + packed=packed, + ) + ) + + return _FlatForwardPlan( + request_count=len(requests), + groups=tuple(plans), + packed_tokens=sum(int(plan.packed.tokens.numel()) for plan in plans), + logical_tokens=logical_tokens, + output_bytes=output_bytes, + signature=self._memory_signature_from_requests( + requests, + slot_group_count=len(plans), + ), + ) + + def _estimate_flat_forward( + self, requests: Sequence[AnyForwardInput] + ) -> tuple[int, int, _MemorySignature] | None: + groups = self._group_active_request_indices(requests) + packed_tokens = 0 + for _, group_indices in groups: + group_packed_tokens = estimate_shared_prefix_packed_tokens( + (requests[index].input_tokens for index in group_indices), + max_depth=self.shared_prefix_max_depth, + ) + if group_packed_tokens is None: + return None + packed_tokens += group_packed_tokens + + return ( + packed_tokens, + self._estimate_group_request_output_bytes(requests), + self._memory_signature_from_requests( + requests, + slot_group_count=len(groups), + ), + ) + + 
def _group_active_request_indices(
    self,
    requests: Sequence[AnyForwardInput],
) -> tuple[tuple["LoRASlotRef | None", tuple[int, ...]], ...]:
    """Group indices of requests that ask for any output by slot reference.

    A request is skipped when it has no target tokens, no logits, no top-k
    and no hidden-state output. Groups keep first-seen order and each
    group's indices keep request order.
    """
    by_slot: dict[LoRASlotRef | None, list[int]] = {}
    for position, request in enumerate(requests):
        if (
            request.target_tokens is None
            and not request.logits
            and request.top_k is None
            and not request.hidden_states
        ):
            continue
        slot_ref = self._resolve_slot_ref(request)
        by_slot.setdefault(slot_ref, []).append(position)
    return tuple((slot_ref, tuple(members)) for slot_ref, members in by_slot.items())


def _run_flat_plan_with_memory_tracking(
    self,
    plan: _FlatForwardPlan,
    *,
    context: str,
) -> list[AnyForwardOutput]:
    """Execute ``plan`` and record its peak CUDA memory growth.

    On a CUDA device the allocator is synchronized and its peak statistics
    reset before execution; afterwards the peak-minus-baseline delta is
    passed to ``_update_memory_profile``. On other devices the plan is
    simply executed.

    Raises:
        Whatever ``_raise_memory_error`` raises when the execution hits
        ``torch.cuda.OutOfMemoryError``.
    """
    track_cuda = torch.cuda.is_available() and self.device.type == "cuda"
    baseline = 0
    if track_cuda:
        torch.cuda.synchronize(self.device)
        baseline = int(torch.cuda.memory_allocated(self.device))
        torch.cuda.reset_peak_memory_stats(self.device)

    try:
        outputs = self._execute_flat_plan(plan)
    except torch.cuda.OutOfMemoryError as exc:
        # NOTE(review): _memory_check issues collectives when
        # torch.distributed is initialized; confirm every rank in the group
        # reaches this handler, otherwise this call may block.
        check = self._memory_check(plan)
        self._raise_memory_error(
            plan,
            check,
            context=context,
            message="CUDA OOM occurred despite the planner estimate",
        )
        # _raise_memory_error always raises; this line only satisfies
        # control-flow analysis.
        raise AssertionError("unreachable") from exc

    if track_cuda:
        torch.cuda.synchronize(self.device)
        peak = int(torch.cuda.max_memory_allocated(self.device))
        self._update_memory_profile(plan, max(0, peak - baseline))
    return outputs
def _estimate_group_request_output_bytes(
    self,
    requests: Sequence[AnyForwardInput],
) -> int:
    """Sum the bytes needed to hold every requested output.

    Per request: target log-probs are counted as float32 per target element,
    top-k as one float32 plus one long per (position, k) entry, logits as
    ``seq_len * padded_vocab`` parameter-dtype elements, and hidden states
    as ``seq_len * hidden_size`` parameter-dtype elements.

    Raises:
        RuntimeError: when logits are requested but ``_padded_vocab_size``
            is ``None``.
    """
    total = 0
    for request in requests:
        seq_len = int(request.input_tokens.numel())
        request_bytes = 0
        if request.target_tokens is not None:
            request_bytes += int(request.target_tokens.numel()) * _dtype_size(
                torch.float32
            )
        if request.top_k is not None:
            entry_bytes = _dtype_size(torch.float32) + _dtype_size(torch.long)
            request_bytes += seq_len * int(request.top_k) * entry_bytes
        if request.logits:
            if self._padded_vocab_size is None:
                raise RuntimeError("logits output memory requires a GPT model")
            request_bytes += seq_len * self._padded_vocab_size * self._param_dtype_size
        if request.hidden_states:
            request_bytes += seq_len * self._hidden_size * self._param_dtype_size
        total += request_bytes
    return total


def _memory_signature_from_requests(
    self,
    requests: Sequence[AnyForwardInput],
    *,
    slot_group_count: int,
) -> _MemorySignature:
    """Build the memory-profile lookup key for a set of requests.

    The request mix is the sorted set of distinct ``_request_mix_key``
    values, so order and multiplicity of requests do not affect the key.
    """
    topology = self._topology_key()
    distinct_mix = {_request_mix_key(request) for request in requests}
    return _MemorySignature(
        topology=topology,
        shared_prefix_max_depth=self.shared_prefix_max_depth,
        slot_group_count=slot_group_count,
        request_mix=tuple(sorted(distinct_mix)),
    )


def _topology_key(self) -> tuple[int, int, int, int]:
    """Return ``(dp, tp, cp, pp)`` from ``_topology()``.

    Falls back to ``(1, 1, 1, 1)`` when the topology cannot be obtained or
    read.
    """
    try:
        topology = self._topology()
        dims = tuple(
            int(getattr(topology, axis)) for axis in ("dp", "tp", "cp", "pp")
        )
    except (AssertionError, AttributeError, ImportError, RuntimeError, ValueError):
        return (1, 1, 1, 1)
    return cast(tuple[int, int, int, int], dims)


def _memory_check(
    self,
    forward: _FlatForwardPlan,
) -> _MemoryCheck:
    """Run the memory check for an already-built flat forward plan."""
    required = self._estimate_required_memory_bytes_from_values(
        packed_tokens=forward.packed_tokens,
        output_bytes=forward.output_bytes,
        signature=forward.signature,
    )
    return self._memory_check_required(required)
[float(required), float(available)], + device=self.device if self.device.type == "cuda" else "cpu", + dtype=torch.float64, + ) + dist.all_reduce(values[0], op=dist.ReduceOp.MAX, group=group) + dist.all_reduce(values[1], op=dist.ReduceOp.MIN, group=group) + required = int(values[0].item()) + available = int(values[1].item()) + return _MemoryCheck( + estimated_required_bytes=required, + available_bytes=available, + fits=required <= available, + ) + + @staticmethod + def _forward_memory_group() -> object | None: + try: + from megatron.core import parallel_state as ps + + return ps.get_tensor_and_context_parallel_group(check_initialized=False) + except (AssertionError, ImportError, RuntimeError, ValueError): + return None + + def _raise_memory_error( + self, + plan: _FlatForwardPlan, + check: _MemoryCheck, + *, + context: str, + message: str, + ) -> None: + raise TrainerRankMemoryError( + f"{context}: {message}. " + f"packed_tokens={plan.packed_tokens} " + f"logical_tokens={plan.logical_tokens} " + f"output_gb={plan.output_bytes / 1024**3:.3f} " + f"estimated_required_gb={check.estimated_required_bytes / 1024**3:.3f} " + f"available_gb={check.available_bytes / 1024**3:.3f}. " + "Use smaller top-level items, reduce output requests, or call " + "dp_rank_forward with already-DP-local smaller inputs." 
def _estimate_required_memory_bytes_from_values(
    self,
    *,
    packed_tokens: int,
    output_bytes: int,
    signature: _MemorySignature,
) -> int:
    """Estimate the bytes one forward pass needs.

    With no packed tokens the estimate is just ``output_bytes``. Otherwise
    a static heuristic (tokens * hidden size * parameter dtype size * a
    layer-derived factor clamped to 4..16) is used, raised to the profiled
    bytes-per-token figure when a profile exists and covers enough tokens.
    The sum with ``output_bytes`` is scaled by ``memory_safety_factor``.
    """
    if packed_tokens <= 0:
        return output_bytes

    profile = self._memory_profiles.get(signature)
    activation_factor = max(4, min(16, self._num_layers // 4 + 4))
    heuristic = (
        packed_tokens * self._hidden_size * self._param_dtype_size * activation_factor
    )
    # A profile is trusted only up to _MEMORY_PROFILE_TRUST_GROWTH times the
    # token count it was measured at.
    trusted = profile is not None and not (
        profile.packed_tokens * _MEMORY_PROFILE_TRUST_GROWTH < packed_tokens
    )
    if trusted:
        compute = max(heuristic, int(profile.bytes_per_token * packed_tokens))
    else:
        compute = heuristic
    return int((output_bytes + compute) * self.memory_safety_factor)


def _available_memory_bytes(self) -> int:
    """Return the bytes the planner may use on this rank's device.

    Off CUDA the result is ``1 << 60``. On CUDA it is free memory plus
    reserved-but-unallocated memory, minus ``memory_reserve_fraction`` of
    total device memory, floored at zero.
    """
    if not torch.cuda.is_available() or self.device.type != "cuda":
        return 1 << 60
    free_bytes, total_bytes = torch.cuda.mem_get_info(self.device)
    allocated = int(torch.cuda.memory_allocated(self.device))
    reserved = int(torch.cuda.memory_reserved(self.device))
    reusable = max(0, reserved - allocated)
    held_back = int(total_bytes * self.memory_reserve_fraction)
    return max(0, int(free_bytes) + reusable - held_back)


def _all_ranks_have_memory_profile(
    self,
    *,
    packed_tokens: int,
    signature: _MemorySignature,
) -> bool:
    """Report whether every rank has a usable memory profile.

    The local answer is true when there are no packed tokens, or when a
    profile exists and covers enough tokens. When torch.distributed is
    initialized the answers are MIN-reduced over the default process group.
    """
    profile = self._memory_profiles.get(signature)
    if packed_tokens <= 0:
        local = True
    elif profile is None:
        local = False
    else:
        local = (
            profile.packed_tokens * _MEMORY_PROFILE_TRUST_GROWTH >= packed_tokens
        )

    if not (dist.is_available() and dist.is_initialized()):
        return local
    flag = torch.tensor(
        int(local),
        device=self.device if self.device.type == "cuda" else "cpu",
        dtype=torch.int32,
    )
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)
    return bool(flag.item())
def _forward_item(self, request: AnyForwardInput) -> _ForwardItem:
    """Validate one request and normalize its tensors.

    Input tokens are flattened to a 1-D long tensor. Target tokens, when
    given, are cast to long and flattened over the input dimensions; any
    extra trailing target dimensions are kept.

    Raises:
        ValueError: for empty input tokens, empty target tokens, or target
            shapes that line up with neither the input shape nor the
            flattened input length.
    """
    if request.top_k is not None:
        _validate_top_k(request.top_k, _language_model(self.runtime.model[0]))

    input_ids = request.input_tokens.reshape(-1).to(dtype=torch.long)
    token_count = int(input_ids.numel())
    if token_count == 0:
        raise ValueError("input_tokens must not be empty")

    labels = None
    if request.target_tokens is not None:
        labels = request.target_tokens.to(dtype=torch.long)
        if int(labels.numel()) == 0:
            raise ValueError("target_tokens must not be empty")
        input_shape = tuple(request.input_tokens.shape)
        input_ndim = request.input_tokens.ndim
        label_shape = tuple(labels.shape)
        if label_shape == input_shape:
            labels = labels.reshape(-1)
        elif labels.ndim > input_ndim and label_shape[:input_ndim] == input_shape:
            # Flatten only the leading input dimensions; keep trailing ones.
            labels = labels.reshape(token_count, *labels.shape[input_ndim:])
        elif labels.ndim < 1 or int(labels.shape[0]) != token_count:
            raise ValueError(
                "target_tokens must match input_tokens or add trailing target "
                f"dimensions: input_tokens={input_shape} "
                f"target_tokens={tuple(labels.shape)}"
            )
    return _ForwardItem(request=request, input_ids=input_ids, labels=labels)


def _forward_packed(
    self,
    items: Sequence[_ForwardItem],
    prepared: _PreparedPackedForward,
) -> list[AnyForwardOutput]:
    """Run the decoder on a prepared pack and project the requested outputs."""
    decoder_hidden = self._decoder_hidden(prepared)
    hidden_by_row = self._gather_sequence_parallel_hidden(decoder_hidden)
    return self._project_head(items, prepared, hidden_by_row)
self.runtime.model_support_handler + model = _language_model(self.runtime.model[0]) + attention_mask = _placeholder_attention_mask(self.device) + forward_kwargs = handler.get_forward_kwargs( + self.runtime.model[0], + attention_bias=prepared.attention_state, + ) + extra_block_kwargs = cast( + dict[str, object] | None, + forward_kwargs.pop("extra_block_kwargs", None), + ) + preprocessed = model._preprocess( + input_ids=prepared.tokens, + position_ids=prepared.position_ids, + packed_seq_params=cast("PackedSeqParams", prepared.packed_seq_params), + ) + ( + decoder_input, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + padding_mask, + ) = preprocessed[:6] + rotary_pos_cos_sin = preprocessed[6] if len(preprocessed) == 7 else None + return cast( + torch.Tensor, + model.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, + packed_seq_params=prepared.packed_seq_params, + sequence_len_offset=sequence_len_offset, + padding_mask=padding_mask, + **(extra_block_kwargs or {}), + ), + ) + + def _project_head( + self, + items: Sequence[_ForwardItem], + prepared: _PreparedPackedForward, + hidden_by_row: torch.Tensor, + ) -> list[AnyForwardOutput]: + model = _language_model(self.runtime.model[0]) + output_weight = ( + model.shared_embedding_or_output_weight() + if bool(model.share_embeddings_and_output_weights) + else None + ) + device = hidden_by_row.device + target_logprobs = [None for _ in items] + logits: list[torch.Tensor | None] = [None for _ in items] + top_k: list[TopK | None] = [None for _ in items] + label_rows: list[torch.Tensor | None] = [None for _ in items] + projected_rows: list[torch.Tensor] = [] + + for index, (item, positions_cpu) in enumerate( + zip(items, prepared.positions_by_item, strict=True) + ): + positions = positions_cpu.to(device=device) + if item.request.logits or 
item.request.top_k is not None: + projected_rows.append(positions) + if item.labels is not None: + source_positions = prepared.source_positions_by_item[index].to(device) + labels = item.labels.to(device=device).index_select(0, source_positions) + label_rows[index] = labels + target_logprobs[index] = torch.zeros( + tuple(labels.shape), + device=device, + dtype=torch.float32, + ) + if item.request.top_k is None and not item.request.logits: + valid = labels != -100 + if labels.ndim > 1: + valid = valid.reshape(int(labels.shape[0]), -1).any(dim=1) + valid_offsets = torch.nonzero(valid, as_tuple=False).reshape(-1) + if int(valid_offsets.numel()): + projected_rows.append(positions.index_select(0, valid_offsets)) + if item.request.logits: + logits[index] = torch.empty( + (int(positions.numel()), _padded_vocab_size(model)), + device=hidden_by_row.device, + dtype=hidden_by_row.dtype, + ) + + row_tensor = ( + torch.cat(projected_rows).unique(sorted=True) + if projected_rows + else torch.empty(0, dtype=torch.long, device=device) + ) + if int(row_tensor.numel()): + local_row_matches = tuple( + _row_match(positions.to(device=device), row_tensor) + for positions in prepared.positions_by_item + ) + self._project_vocab_parallel( + items, + hidden_by_row, + row_tensor, + row_matches=local_row_matches, + item_lengths=tuple( + int(positions.numel()) for positions in prepared.positions_by_item + ), + output_weight=output_weight, + target_logprobs=target_logprobs, + top_k=top_k, + logits=logits, + label_rows=label_rows, + ) + + target_logprobs, top_k = _anchor_disconnected_outputs( + target_logprobs, + top_k, + hidden_by_row, + ) + return [ + ForwardOutput( + target_logprobs=target_logprobs[index], + top_k=top_k[index], + logits=logits[index], + hidden_states=( + _select_positions(hidden_by_row, positions) + if item.request.hidden_states + else None + ), + ) + for index, (item, positions) in enumerate( + zip(items, prepared.positions_by_item, strict=True) + ) + ] + + def 
_project_vocab_parallel( + self, + items: Sequence[_ForwardItem], + hidden_by_row: torch.Tensor, + rows: torch.Tensor, + *, + row_matches: Sequence[_RowMatch], + item_lengths: Sequence[int], + output_weight: torch.Tensor | None, + target_logprobs: list[torch.Tensor | None], + top_k: list[TopK | None], + logits: list[torch.Tensor | None], + label_rows: list[torch.Tensor | None], + ) -> None: + model = _language_model(self.runtime.model[0]) + max_top_k = max((int(item.request.top_k or 0) for item in items), default=0) + need_log_z = any( + item.labels is not None or item.request.top_k is not None for item in items + ) + for start in range(0, int(rows.numel()), self.head_chunk_tokens): + chunk_rows = rows[start : start + self.head_chunk_tokens] + local_logits = self._local_logits_from_hidden_rows( + model, + _select_positions(hidden_by_row, chunk_rows), + output_weight=output_weight, + ) + log_z: torch.Tensor | None = None + local_topk: tuple[torch.Tensor, torch.Tensor] | None = None + if need_log_z: + topk_stats = _try_triton_local_topk_stats(local_logits, k=max_top_k) + logsumexp_stats = ( + _try_triton_local_logsumexp_stats(local_logits) + if topk_stats is None + else None + ) + stats = topk_stats if topk_stats is not None else logsumexp_stats + if stats is not None: + local_max, local_sum = stats[:2] + local_max = local_max.detach() + global_max = _all_reduce_tensor_parallel_max(local_max) + global_sum = _all_reduce_tensor_parallel_sum( + local_sum * torch.exp(local_max - global_max) + ) + log_z = global_max + torch.log(global_sum) + else: + log_z = _vocab_parallel_log_z(local_logits) + + if topk_stats is not None: + _, _, local_values, local_tokens = topk_stats + local_topk = (local_values, local_tokens) + elif logsumexp_stats is not None and max_top_k > 0: + local_k = min(max_top_k, int(local_logits.shape[1])) + local_values, local_tokens = torch.topk( + local_logits, k=local_k, dim=-1 + ) + local_topk = (local_values.float(), local_tokens) + + logit_chunks = [ 
+ chunk_offsets + for item, match in zip(items, row_matches, strict=True) + if item.request.logits + for _, chunk_offsets in ( + _match_chunk_offsets( + match, + start=start, + end=start + int(chunk_rows.numel()), + ), + ) + if int(chunk_offsets.numel()) + ] + logit_chunk_offsets = ( + torch.cat(logit_chunks).unique(sorted=True) + if logit_chunks + else torch.empty(0, dtype=torch.long, device=rows.device) + ) + chunk_logits: torch.Tensor | None = None + if int(logit_chunk_offsets.numel()): + chunk_logits = _batch_seq_logits( + self._gather_tensor_parallel_logits( + local_logits.index_select(0, logit_chunk_offsets).unsqueeze(1) + ), + seq_len=int(logit_chunk_offsets.numel()), + ).squeeze(0) + + for index, item in enumerate(items): + offsets, chunk_offsets = _match_chunk_offsets( + row_matches[index], + start=start, + end=start + int(chunk_rows.numel()), + ) + if int(offsets.numel()) == 0: + continue + item_logits = logits[index] + if item_logits is not None: + if chunk_logits is None: + raise RuntimeError("logits output requires gathered logits") + source_offsets, gathered_offsets = _matching_offsets( + chunk_offsets, + logit_chunk_offsets, + ) + item_logits[offsets.index_select(0, source_offsets)] = ( + chunk_logits.index_select(0, gathered_offsets) + ) + labels = label_rows[index] + item_logprobs = target_logprobs[index] + if item_logprobs is not None and labels is not None: + if log_z is None: + raise RuntimeError("target logprobs require logsumexp") + selected_log_z = log_z.index_select(0, chunk_offsets) + item_logprobs[offsets] = _vocab_parallel_target_logprobs( + local_logits, + labels.index_select(0, offsets), + selected_log_z, + row_offsets=chunk_offsets, + ) + k = item.request.top_k + if k is not None: + if log_z is None: + raise RuntimeError("top_k requires logsumexp") + selected_log_z = log_z.index_select(0, chunk_offsets) + if local_topk is not None: + local_values, local_tokens = local_topk + selected_values = local_values.index_select(0, chunk_offsets) 
def _local_logits_from_hidden_rows(
    self,
    model: "GPTModel",
    hidden: torch.Tensor,
    *,
    output_weight: torch.Tensor | None,
) -> torch.Tensor:
    """Project hidden rows through the model's output layer.

    The output layer's ``sequence_parallel`` flag, when truthy, is switched
    off for the duration of the call and restored afterwards even if the
    projection raises. The scaled logits are reshaped with
    ``_batch_seq_logits`` and the leading dimension squeezed away.
    """
    head = model.output_layer
    had_sequence_parallel = bool(getattr(head, "sequence_parallel", False))
    if had_sequence_parallel:
        head.sequence_parallel = False
    try:
        raw_logits, _ = head(
            hidden.unsqueeze(1),
            weight=output_weight,
            runtime_gather_output=None,
        )
    finally:
        if had_sequence_parallel:
            head.sequence_parallel = True

    scaled = model._scale_logits(raw_logits)
    row_count = int(hidden.shape[0])
    return _batch_seq_logits(scaled, seq_len=row_count).squeeze(0)
def _prepare_packed_forward(
    self,
    batch: SharedPrefixPack,
) -> _PreparedPackedForward:
    """Turn a shared-prefix pack into device tensors and attention state.

    The pack is first padded to a multiple of the tensor-parallel size.
    With context parallelism (``cp > 1``) the work is delegated to
    ``_prepare_context_parallel_forward``. Otherwise each item's source
    positions are simply ``0..len-1`` over its own positions.
    """
    topology = self._topology()
    padded = _pad_packed_batch(batch, multiple=int(topology.tp))
    if int(topology.cp) > 1:
        return self._prepare_context_parallel_forward(padded, topology=topology)

    from art.megatron.shared_prefix_state import create_shared_prefix_state

    handler = self.runtime.model_support_handler
    provider = self.runtime.provider
    device_tokens = padded.tokens.to(self.device)
    device_position_ids = padded.position_ids.to(self.device)
    attention_state = create_shared_prefix_state(
        group_ids=padded.group_ids,
        parent_ids=padded.parent_ids,
        target_device=self.device,
        build_gdn_execution_spec=handler.build_gdn_execution_spec,
        attention_head_dim=provider.kv_channels,
        attention_value_head_dim=provider.kv_channels,
    )
    source_positions = tuple(
        torch.arange(
            int(item_positions.numel()),
            dtype=torch.long,
            device=item_positions.device,
        )
        for item_positions in padded.positions_by_sequence
    )
    return _PreparedPackedForward(
        tokens=device_tokens,
        position_ids=device_position_ids,
        attention_state=attention_state,
        packed_seq_params=None,
        positions_by_item=padded.positions_by_sequence,
        source_positions_by_item=source_positions,
    )
def _topology(self) -> "ParallelTopology":
    """Infer the parallel topology from the runtime's model chunks."""
    from art.megatron.train import _infer_parallel_topology

    model_chunks = self.runtime.model
    return _infer_parallel_topology(model_chunks)


def _gather_tensor_parallel_logits(self, logits: torch.Tensor) -> torch.Tensor:
    """Gather logits across the tensor-parallel group.

    The input is returned unchanged when the tensor-parallel world size is
    at most one.
    """
    from megatron.core import parallel_state as ps

    tp_world_size = int(ps.get_tensor_model_parallel_world_size())
    if tp_world_size > 1:
        from megatron.core import tensor_parallel

        gathered = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
        return cast(torch.Tensor, gathered)
    return logits
def _scale_main_grads(self, scale: float) -> None:
    """Multiply every parameter gradient in the runtime model by ``scale``.

    For each parameter the ``main_grad`` tensor is scaled in place when it
    exists; otherwise ``.grad`` is scaled when it is set. A scale of exactly
    ``1.0`` is a no-op.
    """
    if scale == 1.0:
        return
    for chunk in self.runtime.model:
        for param in chunk.parameters():
            grad = getattr(param, "main_grad", None)
            if isinstance(grad, torch.Tensor):
                grad.mul_(scale)
            elif param.grad is not None:
                param.grad.mul_(scale)


def _validate_top_k(top_k: int, model: "GPTModel") -> None:
    """Reject ``top_k`` values the model cannot serve.

    Raises:
        ValueError: when ``top_k`` is negative or exceeds the padded
            vocabulary size of ``model``.
    """
    # A negative k would otherwise surface later as an opaque torch.topk
    # failure and produce a negative output-byte estimate.
    if top_k < 0:
        raise ValueError(f"top_k={top_k} must be non-negative")
    vocab_size = _padded_vocab_size(model)
    if top_k > vocab_size:
        raise ValueError(f"top_k={top_k} exceeds vocabulary size {vocab_size}")


def _request_mix_key(request: AnyForwardInput) -> str:
    """Describe which outputs a request asks for as a ``+``-joined string.

    Parts, in order: ``target:<trailing shape or 'single'>``, ``topk:<k>``,
    ``logits``, ``hidden``. A request asking for nothing yields
    ``"inactive"``.
    """
    parts = []
    if request.target_tokens is not None:
        target = request.target_tokens
        # Target dimensions beyond the input's rank distinguish multi-target
        # requests from single-target ones.
        tail_shape = tuple(target.shape[request.input_tokens.ndim :])
        parts.append(f"target:{tail_shape or 'single'}")
    if request.top_k is not None:
        parts.append(f"topk:{int(request.top_k)}")
    if request.logits:
        parts.append("logits")
    if request.hidden_states:
        parts.append("hidden")
    return "+".join(parts) if parts else "inactive"
def _language_model(model: torch.nn.Module) -> "GPTModel":
    """Unwrap ``model`` down to the GPT language model.

    Nested ``.module`` wrappers are peeled off first. The innermost object
    is returned if it has both ``_preprocess`` and ``decoder``; otherwise
    its non-``None`` ``language_model`` attribute is returned.

    Raises:
        RuntimeError: when neither form is found.
    """
    missing = object()
    current: object = model
    while True:
        wrapped = getattr(current, "module", missing)
        if wrapped is missing:
            break
        current = wrapped

    if hasattr(current, "_preprocess") and hasattr(current, "decoder"):
        return cast("GPTModel", current)
    nested = getattr(current, "language_model", None)
    if nested is None:
        raise RuntimeError("expected a Megatron GPT model")
    return cast("GPTModel", nested)


def _padded_vocab_size(model: "GPTModel") -> int:
    """Return the vocabulary size, preferring ``model.config.padded_vocab_size``.

    Falls back to ``model.vocab_size``.

    Raises:
        RuntimeError: when neither attribute yields a value.
    """
    config = getattr(model, "config", None)
    size = getattr(config, "padded_vocab_size", None)
    if size is None:
        size = getattr(model, "vocab_size", None)
    if size is None:
        raise RuntimeError("could not determine full padded vocabulary size")
    return int(size)


def _hidden_size(model: "GPTModel | None", provider: object) -> int:
    """Return ``hidden_size`` from the model config, the model, or the provider.

    Sources are consulted in that order; ``None`` sources and sources
    without a value are skipped.

    Raises:
        RuntimeError: when no source provides a hidden size.
    """
    candidates = (getattr(model, "config", None), model, provider)
    for holder in candidates:
        value = None if holder is None else getattr(holder, "hidden_size", None)
        if value is not None:
            return int(value)
    raise RuntimeError("could not determine hidden size")


def _dtype_size(dtype: torch.dtype) -> int:
    """Return the size in bytes of one element of ``dtype``."""
    probe = torch.empty(0, dtype=dtype)
    return probe.element_size()
def _owned_target_logits_for_rows(
    local_logits: torch.Tensor,
    labels: torch.Tensor,
    vocab_start: int,
    row_offsets: torch.Tensor,
) -> torch.Tensor:
    """Pick each label's logit from this rank's vocabulary shard.

    A label is owned when it is not ``-100`` and falls inside
    ``[vocab_start, vocab_start + local_logits.shape[1])``. Entries for
    labels that are not owned are zero. The result is float32 and has the
    shape of ``labels``.
    """
    shard_width = int(local_logits.shape[1])
    row_count = int(labels.shape[0])
    labels_2d = labels.reshape(row_count, -1)
    shard_labels = labels_2d - vocab_start

    not_ignored = labels_2d != -100
    in_shard = (shard_labels >= 0) & (shard_labels < shard_width)
    owned = not_ignored & in_shard

    row_index = row_offsets.reshape(int(row_offsets.shape[0]), 1).expand_as(labels_2d)
    # Clamp so the gather stays in range; unowned entries are zeroed below.
    column_index = shard_labels.clamp(0, shard_width - 1)
    picked = local_logits[row_index, column_index].float()
    return picked.masked_fill(~owned, 0.0).reshape(labels.shape)


def _finish_target_logprobs(
    target_logits: torch.Tensor,
    labels: torch.Tensor,
    log_z: torch.Tensor,
) -> torch.Tensor:
    """Convert target logits to log-probabilities.

    ``log_z`` is broadcast over any trailing label dimensions and
    subtracted; positions whose label is ``-100`` are set to zero.
    """
    trailing_ones = (1,) * (int(labels.ndim) - 1)
    broadcast_log_z = log_z.reshape(int(log_z.shape[0]), *trailing_ones)
    logprobs = target_logits.float() - broadcast_log_z
    return logprobs.masked_fill(labels == -100, 0.0)
tokens=item_top_k.tokens, + ) + for item_top_k in top_k + ], + ) + + +def _try_triton_local_topk_stats( + local_logits: torch.Tensor, + *, + k: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | None: + if k <= 0 or k > int( + os.environ.get("ART_TRAINER_RANK_TRITON_FUSED_TOPK_MAX", "10") + ): + return None + return cast( + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | None, + _try_triton_stats( + "local_topk_stats", + local_logits, + k=min(k, int(local_logits.shape[1])), + ), + ) + + +def _try_triton_local_logsumexp_stats( + local_logits: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor] | None: + return cast( + tuple[torch.Tensor, torch.Tensor] | None, + _try_triton_stats("local_logsumexp_stats", local_logits), + ) + + +def _try_triton_stats( + name: str, + local_logits: torch.Tensor, + **kwargs: object, +) -> object | None: + if not local_logits.is_cuda: + return None + if os.environ.get("ART_TRAINER_RANK_TRITON_TOPK", "1").lower() in { + "0", + "false", + } or int(local_logits.shape[0]) < int( + os.environ.get("ART_TRAINER_RANK_TRITON_MIN_ROWS", "64") + ): + return None + try: + from art.megatron import trainer_rank_topk + + return getattr(trainer_rank_topk, name)(local_logits, **kwargs) + except Exception: + if os.environ.get("ART_TRAINER_RANK_TRITON_TOPK", "1").lower() == "strict": + raise + return None + + +def _vocab_parallel_topk_from_local( + local_values: torch.Tensor, + local_tokens: torch.Tensor, + *, + k: int, + log_z: torch.Tensor, + vocab_start: int, +) -> TopK: + local_k = min(k, int(local_values.shape[1])) + local_values = local_values[:, :local_k] - log_z.unsqueeze(1) + local_tokens = local_tokens[:, :local_k] + vocab_start + + from megatron.core import parallel_state as ps + + tp_size = int(ps.get_tensor_model_parallel_world_size()) + if tp_size <= 1: + return TopK(logprobs=local_values, tokens=local_tokens) + + from torch.distributed.nn.functional import all_gather + + group = 
def _vocab_parallel_log_z(local_logits: torch.Tensor) -> torch.Tensor:
    """Compute the per-row log partition function over a sharded vocabulary.

    The row maximum (detached) is MAX-reduced across tensor-parallel ranks,
    the shifted exponentials are summed locally and SUM-reduced, and the
    result is ``global_max + log(global_sum)``.
    """
    logits_f32 = local_logits.float()
    shard_max = logits_f32.max(dim=-1).values.detach()
    row_max = _all_reduce_tensor_parallel_max(shard_max)
    shard_sum = _call_compiled(_local_vocab_exp_sum, logits_f32, row_max)
    row_sum = _all_reduce_tensor_parallel_sum(shard_sum)
    return row_max + torch.log(row_sum)


def _local_vocab_exp_sum(
    local_logits: torch.Tensor,
    global_max: torch.Tensor,
) -> torch.Tensor:
    """Sum ``exp(logit - global_max)`` over this shard's vocabulary columns."""
    shifted = local_logits.float() - global_max.unsqueeze(1)
    return torch.exp(shifted).sum(dim=-1)


def _vocab_range(local_logits: torch.Tensor) -> tuple[int, int]:
    """Return the ``[start, end)`` vocabulary range implied by this shard.

    The range is the tensor-parallel rank times the local logits width, so
    it presumes every rank holds the same width.
    """
    from megatron.core import parallel_state as ps

    shard_width = int(local_logits.shape[1])
    tp_rank = int(ps.get_tensor_model_parallel_rank())
    first = tp_rank * shard_width
    return first, first + shard_width


def _all_reduce_tensor_parallel_sum(tensor: torch.Tensor) -> torch.Tensor:
    """SUM-reduce ``tensor`` across the tensor-parallel group.

    Uses ``torch.distributed.nn.functional.all_reduce``. The input is
    returned unchanged when the tensor-parallel world size is at most one.
    """
    from megatron.core import parallel_state as ps

    if int(ps.get_tensor_model_parallel_world_size()) > 1:
        from torch.distributed.nn.functional import all_reduce

        reduced = all_reduce(
            tensor,
            op=dist.ReduceOp.SUM,
            group=ps.get_tensor_model_parallel_group(check_initialized=False),
        )
        return cast(torch.Tensor, reduced)
    return tensor
tensor + output = tensor.clone() + dist.all_reduce( + output, + op=dist.ReduceOp.MAX, + group=ps.get_tensor_model_parallel_group(check_initialized=False), + ) + return output + + +def _call_compiled(fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + if os.environ.get("ART_TRAINER_RANK_COMPILE", "0").lower() in {"0", "false"}: + return fn(*args, **kwargs) + compiled = _COMPILED_FUNCTIONS.get(fn) + if compiled is None: + compiled = cast(Callable[..., object], torch.compile(fn, dynamic=True)) + _COMPILED_FUNCTIONS[fn] = compiled + try: + return cast(Callable[P, R], compiled)(*args, **kwargs) + except Exception: + return fn(*args, **kwargs) + + +def _matching_offsets( + positions: torch.Tensor, + chunk_rows: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + if int(positions.numel()) == 0 or int(chunk_rows.numel()) == 0: + empty = torch.empty(0, dtype=torch.long, device=positions.device) + return empty, empty + sorted_rows, order = chunk_rows.sort() + indices = torch.searchsorted(sorted_rows, positions) + in_bounds = indices < int(sorted_rows.numel()) + source_offsets = torch.arange( + int(positions.numel()), + device=positions.device, + dtype=torch.long, + )[in_bounds] + found = indices[in_bounds] + keep = sorted_rows.index_select(0, found) == positions.index_select( + 0, + source_offsets, + ) + return source_offsets[keep], order.index_select(0, found[keep]) + + +def _row_match(positions: torch.Tensor, rows: torch.Tensor) -> _RowMatch: + source_offsets, row_offsets = _matching_offsets(positions, rows) + if int(row_offsets.numel()) > 1: + order = row_offsets.argsort() + source_offsets = source_offsets.index_select(0, order) + row_offsets = row_offsets.index_select(0, order) + return _RowMatch(source_offsets=source_offsets, row_offsets=row_offsets) + + +def _match_chunk_offsets( + match: _RowMatch, + *, + start: int, + end: int, +) -> tuple[torch.Tensor, torch.Tensor]: + keep = (match.row_offsets >= start) & (match.row_offsets < end) + source_offsets = 
match.source_offsets[keep] + return source_offsets, match.row_offsets[keep] - start + + +def _local_position_pairs( + local_global_positions: torch.Tensor, + item_positions: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + flat = local_global_positions.reshape(-1).to(device=item_positions.device) + local_positions = torch.nonzero(flat >= 0, as_tuple=False).reshape(-1) + global_positions = flat.index_select(0, local_positions) + source_offsets, local_offsets = _matching_offsets(item_positions, global_positions) + return ( + local_positions.index_select(0, local_offsets).to("cpu"), + source_offsets.to("cpu"), + ) + + +def _select_positions(values: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + if int(positions.numel()) == 0: + return values[:0] + return values.index_select(0, positions.to(device=values.device)) + + +def _batch_seq_logits(logits: torch.Tensor, *, seq_len: int) -> torch.Tensor: + if int(logits.ndim) != 3: + raise RuntimeError( + f"expected logits with shape [B, S, V] or [S, B, V], got {tuple(logits.shape)}" + ) + if int(logits.shape[0]) == 1 and int(logits.shape[1]) == seq_len: + return logits + if int(logits.shape[0]) == seq_len and int(logits.shape[1]) == 1: + return logits.transpose(0, 1).contiguous() + raise RuntimeError( + f"logits do not match sequence length {seq_len}: {tuple(logits.shape)}" + ) + + +def _materialize(inputs: ForwardInputs) -> ForwardInputs: + if isinstance(inputs, ForwardInput): + return inputs + return [_materialize(item) for item in _nested_forward_children(inputs)] + + +def _flatten(inputs: ForwardInputs) -> Iterator[AnyForwardInput]: + if isinstance(inputs, ForwardInput): + yield inputs + return + for item in _nested_forward_children(inputs): + yield from _flatten(item) + + +def _unflatten( + template: ForwardInputs, outputs: Iterator[AnyForwardOutput] +) -> ForwardOutputs: + if isinstance(template, ForwardInput): + return next(outputs) + return [_unflatten(item, outputs) for item in 
_nested_forward_children(template)] + + +def _nested_forward_children(inputs: ForwardInputs) -> Iterator[ForwardInputs]: + if isinstance(inputs, Mapping): + raise TypeError( + "dict was passed directly to TrainerRank; gather or materialize the " + "values into a list/tuple so nested forward output ordering is explicit" + ) + if isinstance(inputs, str | bytes): + raise TypeError( + "TrainerRank forward inputs must be ForwardInput objects or nested " + "iterables of ForwardInput objects, not strings" + ) + try: + return iter(cast(Iterable[ForwardInputs], inputs)) + except TypeError as exc: + raise TypeError( + "TrainerRank forward inputs must be ForwardInput objects or nested " + "iterables of ForwardInput objects" + ) from exc + + +__all__ = [ + "AdamParams", + "ForwardInput", + "ForwardOutput", + "MicroBatch", + "MicroBatchStats", + "TopK", + "TrainerRank", + "TrainerRankMemoryError", +] diff --git a/src/art/megatron/trainer_rank_topk.py b/src/art/megatron/trainer_rank_topk.py new file mode 100644 index 000000000..e0a84722f --- /dev/null +++ b/src/art/megatron/trainer_rank_topk.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +from typing import Any + +import torch +import triton +import triton.language as tl + +type LocalTopKStats = tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] +type LocalLogSumExpStats = tuple[torch.Tensor, torch.Tensor] + + +@triton.jit +def _stats_stage1_kernel( + logits_ptr, + partial_max_ptr, + partial_sum_ptr, + partial_values_ptr, + partial_tokens_ptr, + stride_row: tl.constexpr, + vocab_size: tl.constexpr, + n_blocks: tl.constexpr, + k: tl.constexpr, + block_v: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + offsets = block * block_v + tl.arange(0, block_v) + mask = offsets < vocab_size + values = tl.load( + logits_ptr + row * stride_row + offsets, + mask=mask, + other=-float("inf"), + ).to(tl.float32) + + block_max = tl.max(values, axis=0) + block_sum = tl.sum(tl.exp(values - block_max), 
axis=0) + partial_offset = row * n_blocks + block + tl.store(partial_max_ptr + partial_offset, block_max) + tl.store(partial_sum_ptr + partial_offset, block_sum) + + work = values + arange = tl.arange(0, block_v) + for slot in tl.static_range(0, k): + top_value, top_index = tl.max( + work, + axis=0, + return_indices=True, + return_indices_tie_break_left=True, + ) + output_offset = (partial_offset * k) + slot + tl.store(partial_values_ptr + output_offset, top_value) + tl.store( + partial_tokens_ptr + output_offset, + (block * block_v + top_index).to(tl.int64), + ) + work = tl.where(arange == top_index, -float("inf"), work) + + +@triton.jit +def _stats_stage2_kernel( + partial_max_ptr, + partial_sum_ptr, + partial_values_ptr, + partial_tokens_ptr, + local_max_ptr, + local_sum_ptr, + values_ptr, + tokens_ptr, + n_blocks: tl.constexpr, + k: tl.constexpr, + block_b: tl.constexpr, + block_candidates: tl.constexpr, +): + row = tl.program_id(0) + + block_offsets = tl.arange(0, block_b) + block_mask = block_offsets < n_blocks + partial_base = row * n_blocks + block_max = tl.load( + partial_max_ptr + partial_base + block_offsets, + mask=block_mask, + other=-float("inf"), + ) + row_max = tl.max(block_max, axis=0) + block_sum = tl.load( + partial_sum_ptr + partial_base + block_offsets, + mask=block_mask, + other=0.0, + ) + row_sum = tl.sum(block_sum * tl.exp(block_max - row_max), axis=0) + tl.store(local_max_ptr + row, row_max) + tl.store(local_sum_ptr + row, row_sum) + + if k > 0: + candidate_offsets = tl.arange(0, block_candidates) + candidate_mask = candidate_offsets < n_blocks * k + candidate_base = row * n_blocks * k + candidates = tl.load( + partial_values_ptr + candidate_base + candidate_offsets, + mask=candidate_mask, + other=-float("inf"), + ) + work = candidates + for slot in tl.static_range(0, k): + top_value, top_index = tl.max( + work, + axis=0, + return_indices=True, + return_indices_tie_break_left=True, + ) + output_offset = row * k + slot + tl.store(values_ptr 
+ output_offset, top_value) + tl.store( + tokens_ptr + output_offset, + tl.load(partial_tokens_ptr + candidate_base + top_index), + ) + work = tl.where(candidate_offsets == top_index, -float("inf"), work) + + +@triton.jit +def _stats_backward_kernel( + logits_ptr, + local_max_ptr, + tokens_ptr, + grad_sum_ptr, + grad_values_ptr, + grad_logits_ptr, + stride_row: tl.constexpr, + vocab_size: tl.constexpr, + k: tl.constexpr, + block_v: tl.constexpr, +): + row = tl.program_id(0) + block = tl.program_id(1) + offsets = block * block_v + tl.arange(0, block_v) + mask = offsets < vocab_size + + logits = tl.load( + logits_ptr + row * stride_row + offsets, + mask=mask, + other=-float("inf"), + ).to(tl.float32) + local_max = tl.load(local_max_ptr + row) + grad = tl.load(grad_sum_ptr + row).to(tl.float32) * tl.exp(logits - local_max) + + for slot in tl.static_range(0, k): + token = tl.load(tokens_ptr + row * k + slot) + value_grad = tl.load(grad_values_ptr + row * k + slot).to(tl.float32) + grad += tl.where(offsets == token, value_grad, 0.0) + + tl.store(grad_logits_ptr + row * stride_row + offsets, grad, mask=mask) + + +class _LocalStatsFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, local_logits: torch.Tensor, k: int): + local_max, local_sum, values, tokens = _local_stats_forward(local_logits, k=k) + ctx.save_for_backward(local_logits, local_max, tokens) + ctx.k = k + return local_max, local_sum, values, tokens + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + grad_local_max, grad_local_sum, grad_values, grad_tokens = grad_outputs + del grad_local_max, grad_tokens + logits, local_max, tokens = ctx.saved_tensors + k = int(ctx.k) + rows = int(logits.shape[0]) + vocab_size = int(logits.shape[1]) + block_v = 4096 + n_blocks = int(triton.cdiv(vocab_size, block_v)) + + if grad_local_sum is None: + grad_local_sum = torch.zeros_like(local_max) + if grad_values is None: + grad_values = torch.zeros( + (rows, k), + device=logits.device, + 
dtype=torch.float32, + ) + + grad_logits = torch.empty_like(logits) + _stats_backward_kernel[(rows, n_blocks)]( + logits, + local_max, + tokens, + grad_local_sum.contiguous(), + grad_values.contiguous(), + grad_logits, + logits.stride(0), + vocab_size=vocab_size, # ty: ignore[invalid-argument-type] + k=k, # ty: ignore[invalid-argument-type] + block_v=block_v, # ty: ignore[invalid-argument-type] + num_warps=8, # ty: ignore[unknown-argument] + ) + return grad_logits, None + + +def _check_local_logits(local_logits: torch.Tensor) -> torch.Tensor: + if local_logits.ndim != 2: + raise ValueError( + f"expected [rows, vocab] logits, got {tuple(local_logits.shape)}" + ) + if not local_logits.is_cuda: + raise ValueError("local top-k helpers require CUDA logits") + return local_logits.contiguous() + + +def _local_stats_forward(local_logits: torch.Tensor, *, k: int) -> LocalTopKStats: + logits = _check_local_logits(local_logits) + if k < 0 or k > int(local_logits.shape[1]): + raise ValueError( + f"k={k} is outside local vocab size {int(local_logits.shape[1])}" + ) + + rows = int(logits.shape[0]) + vocab_size = int(logits.shape[1]) + block_v = 4096 + n_blocks = int(triton.cdiv(vocab_size, block_v)) + block_b = int(triton.next_power_of_2(n_blocks)) + block_candidates = int(triton.next_power_of_2(n_blocks * k)) if k else 1 + + partial_shape = (rows, n_blocks) + partial_max = torch.empty(partial_shape, device=logits.device, dtype=torch.float32) + partial_sum = torch.empty_like(partial_max) + partial_topk_shape = (rows, n_blocks, k) if k else (1,) + partial_values = torch.empty( + partial_topk_shape, device=logits.device, dtype=torch.float32 + ) + partial_tokens = torch.empty( + partial_topk_shape, device=logits.device, dtype=torch.long + ) + local_max = torch.empty((rows,), device=logits.device, dtype=torch.float32) + local_sum = torch.empty_like(local_max) + values = torch.empty((rows, k), device=logits.device, dtype=torch.float32) + tokens = torch.empty((rows, k), 
device=logits.device, dtype=torch.long) + + _stats_stage1_kernel[(rows, n_blocks)]( + logits, + partial_max, + partial_sum, + partial_values, + partial_tokens, + stride_row=logits.stride(0), # ty: ignore[invalid-argument-type] + vocab_size=vocab_size, # ty: ignore[invalid-argument-type] + n_blocks=n_blocks, # ty: ignore[invalid-argument-type] + k=k, # ty: ignore[invalid-argument-type] + block_v=block_v, # ty: ignore[invalid-argument-type] + num_warps=8, # ty: ignore[unknown-argument] + ) + _stats_stage2_kernel[(rows,)]( + partial_max, + partial_sum, + partial_values, + partial_tokens, + local_max, + local_sum, + values, + tokens, + n_blocks=n_blocks, # ty: ignore[invalid-argument-type] + k=k, # ty: ignore[invalid-argument-type] + block_b=block_b, # ty: ignore[invalid-argument-type] + block_candidates=block_candidates, # ty: ignore[invalid-argument-type] + num_warps=8, # ty: ignore[unknown-argument] + ) + return local_max, local_sum, values, tokens + + +def local_topk_stats(local_logits: torch.Tensor, *, k: int) -> LocalTopKStats: + logits = local_logits.contiguous() + if not logits.requires_grad: + return _local_stats_forward(logits, k=k) + return _LocalStatsFunction.apply(logits, k) + + +def local_logsumexp_stats(local_logits: torch.Tensor) -> LocalLogSumExpStats: + logits = local_logits.contiguous() + if not logits.requires_grad: + local_max, local_sum, _, _ = _local_stats_forward(logits, k=0) + return local_max, local_sum + local_max, local_sum, _, _ = _LocalStatsFunction.apply(logits, 0) + return local_max, local_sum diff --git a/src/art/megatron/training/finalize_grads.py b/src/art/megatron/training/finalize_grads.py index cde0e7b06..e00cd8218 100644 --- a/src/art/megatron/training/finalize_grads.py +++ b/src/art/megatron/training/finalize_grads.py @@ -1,4 +1,3 @@ -from collections import defaultdict from collections.abc import Iterable from typing import Any, Literal, cast @@ -16,7 +15,6 @@ GRAD_SYNC_OP_NONE: GradSyncOp = "none" GRAD_SYNC_OP_SUM: GradSyncOp = 
"sum" GRAD_SYNC_OP_AVG: GradSyncOp = "avg" -VALID_DOMAINS = (TP_DEFAULT_GRAD_SYNC_DOMAIN, EXPERT_TP_GRAD_SYNC_DOMAIN) VALID_SYNC_OPS = (GRAD_SYNC_OP_NONE, GRAD_SYNC_OP_SUM, GRAD_SYNC_OP_AVG) @@ -28,6 +26,8 @@ def _iter_named_trainable_parameters( for name, param in model_chunk.named_parameters(): if not param.requires_grad: continue + if getattr(param, "_art_dynamic_lora_slot", False): + continue param_id = id(param) if param_id in seen: continue @@ -60,6 +60,48 @@ def _resolve_reduce_op(op: GradSyncOp) -> Any: raise RuntimeError(f"Unknown grad sync op: {op}") +def tensor_parallel_grad_sync( + param: torch.nn.Parameter, + *, + name: str, +) -> tuple[Any, Any] | None: + domain: GradSyncDomain = getattr( + param, "grad_sync_domain", TP_DEFAULT_GRAD_SYNC_DOMAIN + ) + group = _resolve_domain_group(domain) + if group is None: + return None + op: GradSyncOp = getattr(param, "grad_sync_op", GRAD_SYNC_OP_NONE) + if op not in VALID_SYNC_OPS: + raise RuntimeError(f"{name}: unsupported grad_sync_op={op}") + if op == GRAD_SYNC_OP_NONE: + return None + return group, _resolve_reduce_op(op) + + +def coalesced_all_reduce( + grads: list[torch.Tensor], + *, + group: Any, + op: Any, +) -> None: + coalesced = _flatten_dense_tensors(grads) + reduced = ( + coalesced.float() + if torch.is_floating_point(coalesced) and coalesced.dtype != torch.float32 + else coalesced + ) + torch.distributed.all_reduce( # ty: ignore[possibly-missing-attribute] + reduced, + op=op, + group=group, + ) + if reduced is not coalesced: + reduced = reduced.to(dtype=coalesced.dtype) + for grad, synced in zip(grads, _unflatten_dense_tensors(reduced, grads)): + grad.copy_(synced) + + def flush_param_grads_to_main_grads(model_chunks: Iterable[torch.nn.Module]) -> None: """Fallback for direct jobs when DDP post-hooks leave grads in param.grad. 
@@ -100,57 +142,30 @@ def finalize_model_grads_extended( ) buckets: dict[ - tuple[GradSyncDomain, GradSyncOp, torch.dtype, torch.device], - list[tuple[str, torch.Tensor]], - ] = defaultdict(list) + tuple[int, str, torch.dtype, torch.device], + tuple[Any, Any, list[torch.Tensor]], + ] = {} for name, param in _iter_named_trainable_parameters(model): - domain: GradSyncDomain = getattr( - param, "grad_sync_domain", TP_DEFAULT_GRAD_SYNC_DOMAIN - ) - if _resolve_domain_group(domain) is None: - continue - - op: GradSyncOp = getattr(param, "grad_sync_op", GRAD_SYNC_OP_NONE) - if op not in VALID_SYNC_OPS: - raise RuntimeError(f"{name}: unsupported grad_sync_op={op}") - if op == GRAD_SYNC_OP_NONE: + sync = tensor_parallel_grad_sync(param, name=name) + if sync is None: continue if not hasattr(param, "main_grad"): raise RuntimeError( - f"{name}: expected main_grad for domain={domain} reduce_op={op}, but attribute is missing" + f"{name}: expected main_grad for tensor-parallel grad sync, but attribute is missing" ) grad = param.main_grad if grad is None: raise RuntimeError( - f"{name}: expected non-None main_grad for domain={domain} reduce_op={op}" + f"{name}: expected non-None main_grad for tensor-parallel grad sync" ) local_grad = cast( # local part of dtensor torch.Tensor, grad._local_tensor if hasattr(grad, "_local_tensor") else grad ) - buckets[(domain, op, local_grad.dtype, local_grad.device)].append( - (name, local_grad) - ) + group, reduce_op = sync + key = (id(group), str(reduce_op), local_grad.dtype, local_grad.device) + buckets.setdefault(key, (group, reduce_op, []))[2].append(local_grad) - for (domain, op, _dtype, _device), entries in buckets.items(): - group = _resolve_domain_group( - domain - ) # already checked if the domain is one we are handling - - grads = [grad for _name, grad in entries] - coalesced = _flatten_dense_tensors(grads) - reduced = ( - coalesced.float() - if torch.is_floating_point(coalesced) and coalesced.dtype != torch.float32 - else coalesced - 
) - torch.distributed.all_reduce( # ty: ignore[possibly-missing-attribute] - reduced, - op=_resolve_reduce_op(op), - group=group, - ) - if reduced is not coalesced: - reduced = reduced.to(dtype=coalesced.dtype) - for grad, synced in zip(grads, _unflatten_dense_tensors(reduced, grads)): - grad.copy_(synced) + for group, op, grads in buckets.values(): + coalesced_all_reduce(grads, group=group, op=op) diff --git a/src/art/megatron/weights/adapter_export.py b/src/art/megatron/weights/adapter_export.py index cce081188..76c545bda 100644 --- a/src/art/megatron/weights/adapter_export.py +++ b/src/art/megatron/weights/adapter_export.py @@ -1,3 +1,4 @@ +from collections.abc import Callable, Sequence import math from typing import Any @@ -9,31 +10,15 @@ from art.megatron.lora import ( GatedDeltaNetInProjLoRA, LoRA, - MLPExpertsLinearFC1FusedLoRA, MLPExpertsLinearFC1LoRA, MLPExpertsLinearFC2LoRA, SelfAttentionLinearProjLoRA, SelfAttentionLinearQKVLoRA, SharedExpertsLinearFC1LoRA, - SharedExpertsLinearFC2LoRA, ) from art.megatron.weights.param_name_canonicalization import canonical_art_param_name -def layer_base_prefix( - module: TransformerLayer, - *, - module_name: str | None = None, -) -> str: - if module_name is not None: - canonical_name = canonical_art_param_name(module_name) - if canonical_name.startswith( - ("decoder.layers.", "language_model.decoder.layers.") - ): - return canonical_name - return f"language_model.decoder.layers.{module.layer_number - 1}" - - def _adapter_alpha_dim(lora: LoRA) -> tuple[int, int]: dim = int(lora.A_T.shape[-1]) alpha = float(lora.scale) * dim @@ -51,12 +36,6 @@ def _adapter_tensors( return a_t.transpose(-1, -2).contiguous(), b_t.transpose(-1, -2).contiguous() -def _adapter_param_prefix(base_prefix: str, adapter_key: str | None) -> str: - if adapter_key is None: - return f"{base_prefix}.adapter" - return f"{base_prefix}.adapter.{adapter_key}" - - def _adapter_weight( *, base_prefix: str, @@ -66,7 +45,8 @@ def _adapter_weight( linear_in: 
torch.Tensor, linear_out: torch.Tensor, ) -> AdapterWeight: - param_prefix = _adapter_param_prefix(base_prefix, adapter_key) + adapter_suffix = "" if adapter_key is None else f".{adapter_key}" + param_prefix = f"{base_prefix}.adapter{adapter_suffix}" return AdapterWeight( global_base_prefix=base_prefix, adapter_key=adapter_key, @@ -162,83 +142,174 @@ def _fused_pair_adapter_weight( ) -def add_standard_self_attention_adapter_weights( +def _set_adapter_weights( + out: dict[str, list[Any]], + base_prefix: str, + *weights: AdapterWeight, + weight_suffix: str = ".weight", +) -> None: + out[f"{base_prefix}{weight_suffix}"] = list(weights) + + +def _set_expert_adapter_weights( + out: dict[str, list[Any]], + base_prefix: str, + lora: LoRA, + build_weight: Callable[[int], AdapterWeight], +) -> None: + for local_expert_idx in range(lora.num_local_experts): + global_expert_idx = local_expert_idx + lora._expert_offset + _set_adapter_weights( + out, + base_prefix, + build_weight(local_expert_idx), + weight_suffix=f".weight{global_expert_idx}", + ) + + +def _set_lora_weights( + out: dict[str, list[Any]], + base_prefix: str, + *items: tuple[LoRA, str | None], +) -> None: + _set_adapter_weights( + out, + base_prefix, + *( + _simple_adapter_weight(base_prefix, lora, adapter_key=adapter_key) + for lora, adapter_key in items + ), + ) + + +def build_transformer_layer_adapter_weights( + model_chunks: Sequence[Any], + grouped_moe: bool = False, + language_layers_only: bool = False, +) -> dict[str, list[Any]]: + layer_filter = None + if language_layers_only: + from art.megatron.lora import ( + _is_language_transformer_layer_name as layer_filter, + ) + + add_mlp_adapter_weights = ( + _add_moe_mlp_adapter_weights_for_layer + if grouped_moe + else _add_dense_mlp_adapter_weights_for_layer + ) + adapter_weights_by_base: dict[str, list[Any]] = {} + for chunk in model_chunks: + for module_name, module in chunk.named_modules(): + if not isinstance(module, TransformerLayer): + continue + if 
layer_filter is not None and not layer_filter(module_name): + continue + canonical_name = canonical_art_param_name(module_name) + layer_prefix = ( + canonical_name + if canonical_name.startswith( + ("decoder.layers.", "language_model.decoder.layers.") + ) + else f"language_model.decoder.layers.{module.layer_number - 1}" + ) + add_self_attention_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + self_attention=module.self_attention, + ) + add_mlp_adapter_weights(adapter_weights_by_base, layer_prefix, module) + return adapter_weights_by_base + + +def add_self_attention_adapter_weights( adapter_weights_by_base: dict[str, list[Any]], *, layer_prefix: str, self_attention: Any, ) -> None: - linear_proj = getattr(self_attention, "linear_proj", None) - if isinstance(linear_proj, SelfAttentionLinearProjLoRA): - base_prefix = f"{layer_prefix}.self_attention.linear_proj" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight(base_prefix, linear_proj.lora) - ] + for attr in ("linear_proj", "out_proj"): + linear_proj = getattr(self_attention, attr, None) + if isinstance(linear_proj, SelfAttentionLinearProjLoRA): + base_prefix = f"{layer_prefix}.self_attention.{attr}" + _set_lora_weights( + adapter_weights_by_base, + base_prefix, + (linear_proj.lora, None), + ) linear_qkv = getattr(self_attention, "linear_qkv", None) if isinstance(linear_qkv, SelfAttentionLinearQKVLoRA): base_prefix = f"{layer_prefix}.self_attention.linear_qkv" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ + _set_lora_weights( + adapter_weights_by_base, + base_prefix, + (linear_qkv.q_proj_lora, "adapter_q"), + (linear_qkv.k_proj_lora, "adapter_k"), + (linear_qkv.v_proj_lora, "adapter_v"), + ) + + in_proj = getattr(self_attention, "in_proj", None) + if isinstance(in_proj, GatedDeltaNetInProjLoRA): + base_prefix = f"{layer_prefix}.self_attention.in_proj" + input_dim = int(in_proj.qkv_lora.A_T.shape[-2]) + output_dim = 
int(in_proj.num_value_heads_per_partition) + _set_adapter_weights( + adapter_weights_by_base, + base_prefix, _simple_adapter_weight( - base_prefix, - linear_qkv.q_proj_lora, - adapter_key="adapter_q", + base_prefix, in_proj.qkv_lora, adapter_key="adapter_qkv" ), _simple_adapter_weight( - base_prefix, - linear_qkv.k_proj_lora, - adapter_key="adapter_k", + base_prefix, in_proj.z_lora, adapter_key="adapter_z" ), - _simple_adapter_weight( - base_prefix, - linear_qkv.v_proj_lora, - adapter_key="adapter_v", + *( + _zero_adapter_weight( + base_prefix=base_prefix, + adapter_key=adapter_key, + input_dim=input_dim, + output_dim=output_dim, + like=in_proj.qkv_lora.B_T, + ) + for adapter_key in ("adapter_b", "adapter_a") ), - ] + ) -def add_gated_delta_net_adapter_weights( +def _add_dense_mlp_adapter_weights_for_layer( adapter_weights_by_base: dict[str, list[Any]], - *, layer_prefix: str, - self_attention: Any, + module: Any, ) -> None: - out_proj = getattr(self_attention, "out_proj", None) - if isinstance(out_proj, SelfAttentionLinearProjLoRA): - base_prefix = f"{layer_prefix}.self_attention.out_proj" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight(base_prefix, out_proj.lora) - ] + from art.megatron.model_support.handlers.default_dense import _require_dense_mlp + + _require_dense_mlp(module) + add_split_mlp_adapter_weights( + adapter_weights_by_base, + f"{layer_prefix}.mlp", + module.mlp, + ) - in_proj = getattr(self_attention, "in_proj", None) - if isinstance(in_proj, GatedDeltaNetInProjLoRA): - base_prefix = f"{layer_prefix}.self_attention.in_proj" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight( - base_prefix, - in_proj.qkv_lora, - adapter_key="adapter_qkv", - ), - _simple_adapter_weight( - base_prefix, - in_proj.z_lora, - adapter_key="adapter_z", - ), - _zero_adapter_weight( - base_prefix=base_prefix, - adapter_key="adapter_b", - input_dim=int(in_proj.qkv_lora.A_T.shape[-2]), - 
output_dim=int(in_proj.num_value_heads_per_partition), - like=in_proj.qkv_lora.B_T, - ), - _zero_adapter_weight( - base_prefix=base_prefix, - adapter_key="adapter_a", - input_dim=int(in_proj.qkv_lora.A_T.shape[-2]), - output_dim=int(in_proj.num_value_heads_per_partition), - like=in_proj.qkv_lora.B_T, - ), - ] + +def _add_moe_mlp_adapter_weights_for_layer( + adapter_weights_by_base: dict[str, list[Any]], + layer_prefix: str, + module: Any, +) -> None: + from art.megatron.model_support.handlers.default_dense import _require_moe_experts + + add_grouped_moe_adapter_weights( + adapter_weights_by_base, + layer_prefix=layer_prefix, + experts=_require_moe_experts(module), + ) + shared_experts = getattr(module.mlp, "shared_experts", None) + if shared_experts is not None: + add_split_mlp_adapter_weights( + adapter_weights_by_base, + f"{layer_prefix}.mlp.shared_experts", + shared_experts, + ) def add_grouped_moe_adapter_weights( @@ -248,100 +319,66 @@ def add_grouped_moe_adapter_weights( experts: Any, ) -> None: linear_fc1 = getattr(experts, "linear_fc1", None) - if isinstance(linear_fc1, MLPExpertsLinearFC1FusedLoRA): - base_prefix = f"{layer_prefix}.mlp.experts.linear_fc1" - for local_expert_idx in range(linear_fc1.lora.num_local_experts): - global_expert_idx = local_expert_idx + linear_fc1.lora._expert_offset - adapter_weights_by_base[f"{base_prefix}.weight{global_expert_idx}"] = [ - _simple_adapter_weight( - base_prefix, - linear_fc1.lora, - expert_idx=local_expert_idx, - ) - ] - elif isinstance(linear_fc1, MLPExpertsLinearFC1LoRA): - base_prefix = f"{layer_prefix}.mlp.experts.linear_fc1" - for local_expert_idx in range(linear_fc1.gate_lora.num_local_experts): - global_expert_idx = local_expert_idx + linear_fc1.gate_lora._expert_offset - adapter_weights_by_base[f"{base_prefix}.weight{global_expert_idx}"] = [ - _fused_pair_adapter_weight( - base_prefix, - linear_fc1.gate_lora, - linear_fc1.up_lora, - first_expert_idx=local_expert_idx, - second_expert_idx=local_expert_idx, 
- ) - ] + base_prefix = f"{layer_prefix}.mlp.experts.linear_fc1" + if isinstance(linear_fc1, MLPExpertsLinearFC1LoRA): + if linear_fc1.fused_gate_up: + lora = linear_fc1.lora + build_weight = lambda local_expert_idx: _simple_adapter_weight( + base_prefix, + linear_fc1.lora, + expert_idx=local_expert_idx, + ) + else: + lora = linear_fc1.gate_lora + build_weight = lambda local_expert_idx: _fused_pair_adapter_weight( + base_prefix, + linear_fc1.gate_lora, + linear_fc1.up_lora, + first_expert_idx=local_expert_idx, + second_expert_idx=local_expert_idx, + ) + _set_expert_adapter_weights( + adapter_weights_by_base, + base_prefix, + lora, + build_weight, + ) linear_fc2 = getattr(experts, "linear_fc2", None) if isinstance(linear_fc2, MLPExpertsLinearFC2LoRA): base_prefix = f"{layer_prefix}.mlp.experts.linear_fc2" - for local_expert_idx in range(linear_fc2.lora.num_local_experts): - global_expert_idx = local_expert_idx + linear_fc2.lora._expert_offset - adapter_weights_by_base[f"{base_prefix}.weight{global_expert_idx}"] = [ - _simple_adapter_weight( - base_prefix, - linear_fc2.lora, - expert_idx=local_expert_idx, - ) - ] + _set_expert_adapter_weights( + adapter_weights_by_base, + base_prefix, + linear_fc2.lora, + lambda local_expert_idx: _simple_adapter_weight( + base_prefix, + linear_fc2.lora, + expert_idx=local_expert_idx, + ), + ) -def add_dense_mlp_adapter_weights( +def add_split_mlp_adapter_weights( adapter_weights_by_base: dict[str, list[Any]], - *, - layer_prefix: str, + base_prefix: str, mlp: Any, ) -> None: linear_fc1 = getattr(mlp, "linear_fc1", None) if isinstance(linear_fc1, SharedExpertsLinearFC1LoRA): - base_prefix = f"{layer_prefix}.mlp.linear_fc1" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight( - base_prefix, - linear_fc1.gate_lora, - adapter_key="adapter_gate", - ), - _simple_adapter_weight( - base_prefix, - linear_fc1.up_lora, - adapter_key="adapter_up", - ), - ] + fc1_prefix = f"{base_prefix}.linear_fc1" + 
_set_lora_weights( + adapter_weights_by_base, + fc1_prefix, + (linear_fc1.gate_lora, "adapter_gate"), + (linear_fc1.up_lora, "adapter_up"), + ) linear_fc2 = getattr(mlp, "linear_fc2", None) - if isinstance(linear_fc2, SharedExpertsLinearFC2LoRA): - base_prefix = f"{layer_prefix}.mlp.linear_fc2" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight(base_prefix, linear_fc2.row_parallel_lora.lora) - ] - - -def add_shared_experts_adapter_weights( - adapter_weights_by_base: dict[str, list[Any]], - *, - layer_prefix: str, - shared_experts: Any, -) -> None: - linear_fc1 = getattr(shared_experts, "linear_fc1", None) - if isinstance(linear_fc1, SharedExpertsLinearFC1LoRA): - base_prefix = f"{layer_prefix}.mlp.shared_experts.linear_fc1" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight( - base_prefix, - linear_fc1.gate_lora, - adapter_key="adapter_gate", - ), - _simple_adapter_weight( - base_prefix, - linear_fc1.up_lora, - adapter_key="adapter_up", - ), - ] - - linear_fc2 = getattr(shared_experts, "linear_fc2", None) - if isinstance(linear_fc2, SharedExpertsLinearFC2LoRA): - base_prefix = f"{layer_prefix}.mlp.shared_experts.linear_fc2" - adapter_weights_by_base[f"{base_prefix}.weight"] = [ - _simple_adapter_weight(base_prefix, linear_fc2.row_parallel_lora.lora) - ] + if isinstance(linear_fc2, SelfAttentionLinearProjLoRA): + fc2_prefix = f"{base_prefix}.linear_fc2" + _set_adapter_weights( + adapter_weights_by_base, + fc2_prefix, + _simple_adapter_weight(fc2_prefix, linear_fc2.lora), + ) diff --git a/src/art/megatron/weights/lora_publish.py b/src/art/megatron/weights/lora_publish.py index f4fd02a0a..930b45c58 100644 --- a/src/art/megatron/weights/lora_publish.py +++ b/src/art/megatron/weights/lora_publish.py @@ -1,16 +1,22 @@ from collections.abc import Iterable, Sequence -import re from typing import Any, NamedTuple import torch -from art.megatron.lora import LoRAPublishPlanner, LoraShardMeta +from art.megatron.lora 
import ( + LoRA, + LoRAPublishPlanner, + LoraShardMeta, + _block_for_key, + _dtype_name, +) +from art.megatron.lora import ( + _distributed_initialized as _distributed_ready, +) from art.megatron.model_support.lora_disk import save_vllm_lora_tensors from art.megatron.model_support.spec import ExpertPackedLoraGroup, ExpertPackedLoraSlot from art.megatron.training.model_chunks import ModelChunks -_LAYER_BLOCK_RE = re.compile(r"^(?P.*\.layers\.\d+)\.") - class PackedExpertShardMeta(NamedTuple): key: str @@ -58,35 +64,17 @@ def finish(self) -> None: self._events.clear() -def iter_lora_modules(model_chunks: ModelChunks) -> Iterable[Any]: +def iter_lora_modules(model_chunks: ModelChunks) -> Iterable[LoRA]: for chunk in model_chunks: for module in chunk.modules(): - yield module - - -def _dtype_name(dtype: torch.dtype) -> str: - return str(dtype).removeprefix("torch.") + if isinstance(module, LoRA): + yield module def _dtype_from_name(name: str) -> torch.dtype: - dtype = getattr(torch, name, None) - if not isinstance(dtype, torch.dtype): - raise RuntimeError(f"Unsupported LoRA tensor dtype={name!r}") - return dtype - - -def _block_for_key(key: str) -> str: - match = _LAYER_BLOCK_RE.match(key) - if match is not None: - return match.group("block") - return "__global__" - - -def _expert_prefix_projection(adapter_model_prefix: str) -> tuple[str, str] | None: - group_prefix, separator, projection = adapter_model_prefix.partition(".{expert}.") - if not separator: - return None - return group_prefix, projection + if isinstance(dtype := getattr(torch, name, None), torch.dtype): + return dtype + raise RuntimeError(f"Unsupported LoRA tensor dtype={name!r}") def _packed_expert_slot( @@ -94,10 +82,9 @@ def _packed_expert_slot( suffix: str, groups: Sequence[ExpertPackedLoraGroup], ) -> tuple[str, ExpertPackedLoraSlot] | None: - parts = _expert_prefix_projection(adapter_model_prefix) - if parts is None: + group_prefix, separator, projection = 
adapter_model_prefix.partition(".{expert}.") + if not separator: return None - group_prefix, projection = parts lora_name = suffix.removesuffix(".weight") for group in groups: if not group_prefix.endswith(group.art_group_suffix): @@ -109,23 +96,15 @@ def _packed_expert_slot( def _uses_packed_expert_publish( - module: Any, + module: LoRA, groups: Sequence[ExpertPackedLoraGroup], ) -> bool: - if int(getattr(module, "num_local_experts", 1)) <= 1: - return False - if not hasattr(module, "_lora_params"): + if module.num_local_experts <= 1: return False - adapter_model_prefix = getattr(module, "adapter_model_prefix", "") - if not isinstance(adapter_model_prefix, str): - return False - lora_suffixes = [ - suffix - for suffix, _param in module._lora_params() # type: ignore[attr-defined] - ] - return bool(lora_suffixes) and all( - _packed_expert_slot(adapter_model_prefix, suffix, groups) is not None - for suffix in lora_suffixes + params = tuple(module._lora_params()) + return bool(params) and all( + _packed_expert_slot(module.adapter_model_prefix, suffix, groups) is not None + for suffix, _param in params ) @@ -141,15 +120,12 @@ def collect_local_lora_entries( for module in iter_lora_modules(model_chunks): if _uses_packed_expert_publish(module, packed_expert_groups): continue - if hasattr(module, "sharded_lora_state_dict"): - module_state: dict[str, torch.Tensor] = module.sharded_lora_state_dict() # type: ignore[attr-defined] - for key, value in module_state.items(): - target_dtype = ( - adapter_model[key].dtype if key in adapter_model else value.dtype - ) - local_tensors[key] = value.to(target_dtype).contiguous() - if hasattr(module, "sharded_lora_manifest"): - local_manifest.update(module.sharded_lora_manifest()) # type: ignore[attr-defined] + for key, value in module.sharded_lora_state_dict().items(): + target_dtype = ( + adapter_model[key].dtype if key in adapter_model else value.dtype + ) + local_tensors[key] = value.to(target_dtype).contiguous() + 
local_manifest.update(module.sharded_lora_manifest()) if set(local_tensors) != set(local_manifest): raise RuntimeError( @@ -171,18 +147,6 @@ def collect_local_lora_entries( return local_tensors, metadata -def _target_dtype_for_lora_param( - module: Any, - adapter_model: dict[str, torch.Tensor], - suffix: str, - fallback: torch.dtype, -) -> torch.dtype: - keys = module._expected_weight_keys(suffix.removesuffix(".weight")) # type: ignore[attr-defined] - return ( - adapter_model[keys[0]].dtype if keys and keys[0] in adapter_model else fallback - ) - - def collect_local_packed_expert_entries( model_chunks: ModelChunks, adapter_model: dict[str, torch.Tensor], @@ -195,25 +159,24 @@ def collect_local_packed_expert_entries( for module in iter_lora_modules(model_chunks): if not _uses_packed_expert_publish(module, packed_expert_groups): continue - adapter_model_prefix = module.adapter_model_prefix # type: ignore[attr-defined] - expert_start = int(module._expert_offset) # type: ignore[attr-defined] - expert_count = int(module.num_local_experts) # type: ignore[attr-defined] - for suffix, param in module._lora_params(): # type: ignore[attr-defined] + expert_start = int(module._expert_offset) + expert_count = int(module.num_local_experts) + for suffix, param in module._lora_params(): slot_match = _packed_expert_slot( - adapter_model_prefix, + module.adapter_model_prefix, suffix, packed_expert_groups, ) - if slot_match is None or not module._should_export_parameter(param): # type: ignore[attr-defined] + if slot_match is None or not module._should_export_parameter(param): continue group_prefix, slot = slot_match key = f"{group_prefix}.{slot.output_suffix}" tensor = param.data.transpose(1, 2).contiguous() - target_dtype = _target_dtype_for_lora_param( - module, - adapter_model, - suffix, - tensor.dtype, + source_keys = module._expected_weight_keys(suffix.removesuffix(".weight")) + target_dtype = ( + adapter_model[source_keys[0]].dtype + if source_keys and source_keys[0] in 
adapter_model + else tensor.dtype ) tensor = tensor.to(target_dtype).contiguous() if key in local_tensors: @@ -225,7 +188,7 @@ def collect_local_packed_expert_entries( owner_rank=owner_rank, shape=tuple(int(dim) for dim in tensor.shape), dtype_name=_dtype_name(tensor.dtype), - manifest=module._manifest_for_param(param), # type: ignore[attr-defined] + manifest=module._manifest_for_param(param), expert_start=expert_start, expert_count=expert_count, pack_layout=slot.pack_layout, @@ -252,7 +215,12 @@ def _global_packed_expert_metadata( continue group_prefix, slot = slot_match shard_ranks = range(template.shard_world_size) if template.sharded else (0,) - for ep_rank in range(planner._expert_model_world_size()): + ep_world_size = 1 + if _distributed_ready(): + from megatron.core import parallel_state as ps + + ep_world_size = ps.get_expert_model_parallel_world_size() + for ep_rank in range(ep_world_size): expert_start = ep_rank * template.num_local_experts expert_key = ( f"{template.adapter_model_prefix.format(expert=expert_start)}." 
@@ -350,70 +318,66 @@ def _merge_sharded_tensor( return torch.cat(tuple(ordered_shards), dim=axis).contiguous() -def merge_sharded_adapter_entries( - entries_by_key: dict[str, list[tuple[dict[str, Any], torch.Tensor]]], -) -> dict[str, torch.Tensor]: - adapter_model: dict[str, torch.Tensor] = {} - for key, key_entries in entries_by_key.items(): - first_manifest = key_entries[0][0] - sharded = bool(first_manifest["sharded"]) - shard_world_size = int(first_manifest["shard_world_size"]) - for manifest_entry, _tensor in key_entries: - if bool(manifest_entry["sharded"]) != sharded: - raise RuntimeError(f"Inconsistent sharded flag for key={key}") - if int(manifest_entry["shard_world_size"]) != shard_world_size: - raise RuntimeError(f"Inconsistent shard world size for key={key}") - - if not sharded: - if len(key_entries) != 1: - raise RuntimeError( - f"Replicated key={key} expected 1 shard, got {len(key_entries)}" - ) - adapter_model[key] = key_entries[0][1] - continue - - shard_rank_to_tensor: dict[int, torch.Tensor] = {} - for manifest_entry, shard_tensor in key_entries: - shard_rank = int(manifest_entry["shard_rank"]) - if shard_rank in shard_rank_to_tensor: - raise RuntimeError(f"Duplicate shard_rank={shard_rank} for key={key}") - shard_rank_to_tensor[shard_rank] = shard_tensor +def _merge_manifest_entries( + key: str, + key_entries: Sequence[tuple[dict[str, Any], torch.Tensor]], + *, + manifest: dict[str, Any] | None = None, +) -> torch.Tensor: + first_manifest = key_entries[0][0] + sharded = bool(first_manifest["sharded"]) + shard_world_size = int(first_manifest["shard_world_size"]) + for entry_manifest, _tensor in key_entries: + if bool(entry_manifest["sharded"]) != sharded: + raise RuntimeError(f"Inconsistent sharded flag for key={key}") + if int(entry_manifest["shard_world_size"]) != shard_world_size: + raise RuntimeError(f"Inconsistent shard world size for key={key}") - expected_shard_ranks = set(range(shard_world_size)) - if set(shard_rank_to_tensor) != 
expected_shard_ranks: + if not sharded: + if len(key_entries) != 1: raise RuntimeError( - f"Shard rank coverage mismatch for key={key}: " - f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor)}" + f"Replicated key={key} expected 1 shard, got {len(key_entries)}" ) + return key_entries[0][1] - ordered_shards = [ - shard_rank_to_tensor[shard_rank] for shard_rank in range(shard_world_size) - ] - adapter_model[key] = _merge_sharded_tensor( - key, - ordered_shards=ordered_shards, - manifest=first_manifest, + shard_rank_to_tensor: dict[int, torch.Tensor] = {} + for entry_manifest, shard_tensor in key_entries: + shard_rank = int(entry_manifest["shard_rank"]) + if shard_rank in shard_rank_to_tensor: + raise RuntimeError(f"Duplicate shard_rank={shard_rank} for key={key}") + shard_rank_to_tensor[shard_rank] = shard_tensor + + expected_shard_ranks = set(range(shard_world_size)) + if set(shard_rank_to_tensor) != expected_shard_ranks: + raise RuntimeError( + f"Shard rank coverage mismatch for key={key}: " + f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor)}" ) - return adapter_model + return _merge_sharded_tensor( + key, + ordered_shards=[ + shard_rank_to_tensor[shard_rank] for shard_rank in range(shard_world_size) + ], + manifest=first_manifest if manifest is None else manifest, + ) -def _distributed_ready() -> bool: - is_initialized = getattr(torch.distributed, "is_initialized", None) - return ( - torch.distributed.is_available() - and callable(is_initialized) - and bool(is_initialized()) - ) +def merge_sharded_adapter_entries( + entries_by_key: dict[str, list[tuple[dict[str, Any], torch.Tensor]]], +) -> dict[str, torch.Tensor]: + return { + key: _merge_manifest_entries(key, key_entries) + for key, key_entries in entries_by_key.items() + } def _rank_and_device() -> tuple[int, torch.device]: - if _distributed_ready(): - rank = torch.distributed.get_rank() # type: ignore[possibly-missing-attribute] - else: - rank = 0 - if 
torch.cuda.is_available(): - return rank, torch.device("cuda", torch.cuda.current_device()) - return rank, torch.device("cpu") + return ( + torch.distributed.get_rank() if _distributed_ready() else 0, # type: ignore[possibly-missing-attribute] + torch.device("cuda", torch.cuda.current_device()) + if torch.cuda.is_available() + else torch.device("cpu"), + ) def _metadata_by_owner_dtype( @@ -514,45 +478,10 @@ def _merge_packed_expert_block( key: str, key_entries: list[tuple[dict[str, Any], torch.Tensor]], ) -> torch.Tensor: - first_manifest = key_entries[0][0] - sharded = bool(first_manifest["sharded"]) - shard_world_size = int(first_manifest["shard_world_size"]) - if not sharded: - if len(key_entries) != 1: - raise RuntimeError( - f"Replicated packed key={key} expected 1 shard, got {len(key_entries)}" - ) - return key_entries[0][1] - - shard_rank_to_tensor: dict[int, torch.Tensor] = {} - for manifest_entry, shard_tensor in key_entries: - if bool(manifest_entry["sharded"]) != sharded: - raise RuntimeError(f"Inconsistent sharded flag for packed key={key}") - if int(manifest_entry["shard_world_size"]) != shard_world_size: - raise RuntimeError(f"Inconsistent shard world size for packed key={key}") - shard_rank = int(manifest_entry["shard_rank"]) - if shard_rank in shard_rank_to_tensor: - raise RuntimeError( - f"Duplicate shard_rank={shard_rank} for packed key={key}" - ) - shard_rank_to_tensor[shard_rank] = shard_tensor - - expected_shard_ranks = set(range(shard_world_size)) - if set(shard_rank_to_tensor) != expected_shard_ranks: - raise RuntimeError( - f"Shard rank coverage mismatch for packed key={key}: " - f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor)}" - ) - - manifest = dict(first_manifest) - manifest["export_shard_dim"] = int(manifest["export_shard_dim"]) + 1 - return _merge_sharded_tensor( - key, - ordered_shards=[ - shard_rank_to_tensor[shard_rank] for shard_rank in range(shard_world_size) - ], - manifest=manifest, - ) + manifest = 
dict(key_entries[0][0]) + if bool(manifest["sharded"]): + manifest["export_shard_dim"] = int(manifest["export_shard_dim"]) + 1 + return _merge_manifest_entries(key, key_entries, manifest=manifest) def _pack_merged_expert_blocks( diff --git a/tests/integration/megatron/cp_attn/test_attention_packed_vs_flattened.py b/tests/integration/megatron/cp_attn/test_attention_packed_vs_flattened.py index 3d3d51d4c..5b6e39390 100644 --- a/tests/integration/megatron/cp_attn/test_attention_packed_vs_flattened.py +++ b/tests/integration/megatron/cp_attn/test_attention_packed_vs_flattened.py @@ -60,9 +60,7 @@ def test_shared_prefix_attention_matches_flattened_grad_accumulation() -> None: tensors = build_phase0_packed_tensors(case) group_ids = tensors["group_ids"].cuda() parent_ids = tensors["parent_ids"].cuda() - spec = parse_gdn_shared_prefix_segments( - group_ids.cpu(), parent_ids.cpu(), min_completions_per_family=1 - ) + spec = parse_gdn_shared_prefix_segments(group_ids.cpu(), parent_ids.cpu()) q, k, v = _attention_inputs(group_ids.shape, seed=20260425) q_ref = q.detach().clone().requires_grad_(True) k_ref = k.detach().clone().requires_grad_(True) @@ -82,40 +80,19 @@ def test_shared_prefix_attention_matches_flattened_grad_accumulation() -> None: ref_out = torch.zeros_like(packed_out) ref_loss = q_ref.new_zeros(()) - for family in spec.families: - prefix = family.prefix - prefix_grad_used = False - for completion in family.completions: - indices = torch.tensor( - [ - *range(prefix.start, prefix.end), - *range(completion.start, completion.end), - ], - device=q.device, - dtype=torch.long, - ) - row = family.row_index - q_slice = q_ref[row : row + 1].index_select(2, indices) - k_slice = k_ref[row : row + 1].index_select(2, indices) - v_slice = v_ref[row : row + 1].index_select(2, indices) - flat_out = _dense_causal_attention(q_slice, k_slice, v_slice) - - ref_out[row, :, completion.start : completion.end] = flat_out[ - 0, :, prefix.length : - ] - flat_grad = 
torch.zeros_like(flat_out) - flat_grad[0, :, prefix.length :] = output_grad[ - row, :, completion.start : completion.end - ] - if not prefix_grad_used: - ref_out[row, :, prefix.start : prefix.end] = flat_out[ - 0, :, : prefix.length - ] - flat_grad[0, :, : prefix.length] = output_grad[ - row, :, prefix.start : prefix.end - ] - prefix_grad_used = True - ref_loss = ref_loss + (flat_out * flat_grad).sum() + for segment_index, segment in enumerate(spec.tree_segments): + indices, output_slice = _segment_context_positions(spec, segment_index) + index_tensor = torch.tensor(indices, device=q.device, dtype=torch.long) + row = segment.row_index + q_slice = q_ref[row : row + 1].index_select(2, index_tensor) + k_slice = k_ref[row : row + 1].index_select(2, index_tensor) + v_slice = v_ref[row : row + 1].index_select(2, index_tensor) + flat_out = _dense_causal_attention(q_slice, k_slice, v_slice) + + ref_out[row, :, segment.start : segment.end] = flat_out[0, :, output_slice] + flat_grad = torch.zeros_like(flat_out) + flat_grad[0, :, output_slice] = output_grad[row, :, segment.start : segment.end] + ref_loss = ref_loss + (flat_out * flat_grad).sum() ref_loss.backward() real_mask = _real_token_mask(spec, q.shape, device=q.device) @@ -142,9 +119,7 @@ def test_physical_causal_attention_leaks_across_siblings() -> None: tensors = build_phase0_packed_tensors(case) group_ids = tensors["group_ids"].cuda() parent_ids = tensors["parent_ids"].cuda() - spec = parse_gdn_shared_prefix_segments( - group_ids.cpu(), parent_ids.cpu(), min_completions_per_family=1 - ) + spec = parse_gdn_shared_prefix_segments(group_ids.cpu(), parent_ids.cpu()) q, k, v = _attention_inputs(group_ids.shape, seed=20260427) attention_state = create_shared_prefix_state(group_ids, parent_ids) packed_out = FlexAttentionWrapper()( @@ -225,11 +200,27 @@ def _completion_token_mask( spec: Any, shape: torch.Size, *, device: torch.device ) -> torch.Tensor: mask = torch.zeros(shape, device=device, dtype=torch.bool) - for family 
in spec.families: - for completion in family.completions: - mask[ - family.row_index, - :, - completion.start : completion.end, - ] = True + for index, segment in enumerate(spec.tree_segments): + if spec.tree_parent_indices[index] >= 0: + mask[segment.row_index, :, segment.start : segment.end] = True return mask + + +def _segment_context_positions( + spec: Any, segment_index: int +) -> tuple[list[int], slice]: + path = [] + cursor = segment_index + while cursor >= 0: + path.append(cursor) + cursor = spec.tree_parent_indices[cursor] + path.reverse() + positions = [ + position + for index in path + for position in range( + spec.tree_segments[index].start, spec.tree_segments[index].end + ) + ] + segment_length = spec.tree_segments[segment_index].length + return positions, slice(len(positions) - segment_length, len(positions)) diff --git a/tests/integration/megatron/gdn_shared_prefix/distributed_init.py b/tests/integration/megatron/gdn_shared_prefix/distributed_init.py new file mode 100644 index 000000000..b9b4075c8 --- /dev/null +++ b/tests/integration/megatron/gdn_shared_prefix/distributed_init.py @@ -0,0 +1,7 @@ +from pathlib import Path + + +def file_init_method(tmp_path: Path, name: str) -> str: + path = tmp_path / f"{name}.dist" + path.unlink(missing_ok=True) + return f"file://{path}" diff --git a/tests/integration/megatron/gdn_shared_prefix/layout_reference.py b/tests/integration/megatron/gdn_shared_prefix/layout_reference.py index 7369eaef7..8cff82405 100644 --- a/tests/integration/megatron/gdn_shared_prefix/layout_reference.py +++ b/tests/integration/megatron/gdn_shared_prefix/layout_reference.py @@ -19,7 +19,7 @@ class TestGdnCpLayoutPlan(BaseModel): - model_config = ConfigDict(frozen=True) + model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) batch_size: int = Field(ge=1) sequence_length: int = Field(ge=1) @@ -39,9 +39,7 @@ def build_test_gdn_cp_layout_plan( gdn_token_ranges_by_rank: Sequence[Sequence[tuple[int, int, int]]] | None = None, 
device: torch.device | str | None = None, ) -> TestGdnCpLayoutPlan: - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) gdn_ranges = ( _normalize_rank_ranges(gdn_token_ranges_by_rank, cp_size=cp_size) if gdn_token_ranges_by_rank is not None @@ -89,7 +87,7 @@ def _build_full_exchange_plan( ) for transfer in local_plan.transfers: transfers.setdefault((transfer.source_rank, transfer.dest_rank), transfer) - return GdnCpExchangePlan.model_construct( + return GdnCpExchangePlan( cp_size=len(source_layout.token_counts_by_rank), source_token_counts_by_rank=source_layout.token_counts_by_rank, dest_token_counts_by_rank=tuple( @@ -133,7 +131,7 @@ def _split_gdn_token_ranges_by_rank( _segment_token_start(segment, spec.sequence_length), _segment_token_start(segment, spec.sequence_length) + segment.length, ) - for segment in spec.segments() + for segment in spec.tree_segments ), cp_size=cp_size, ) diff --git a/tests/integration/megatron/gdn_shared_prefix/oracles.py b/tests/integration/megatron/gdn_shared_prefix/oracles.py index 3d3f9ae12..019ec74e7 100644 --- a/tests/integration/megatron/gdn_shared_prefix/oracles.py +++ b/tests/integration/megatron/gdn_shared_prefix/oracles.py @@ -7,6 +7,8 @@ from torch import Tensor import torch.nn.functional as F +from art.megatron.gdn.gdn_shared_prefix import GdnPackedExecutionSpec, GdnSegmentSpec + from .metrics import ( mean_abs_pct, parameter_grad_mean_abs_pct_with_name, @@ -107,27 +109,27 @@ def run_toy_packed( group_ids: Tensor, parent_ids: Tensor, ) -> Tensor: - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=1 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) output = torch.zeros_like(hidden) - for family in spec.families: - row = family.row_index - prefix_hidden = hidden[row, family.prefix.start : family.prefix.end] - prefix_out, prefix_conv, prefix_rec = 
module.forward_segment( - prefix_hidden, - conv_initial=module.zero_conv_state(hidden), - recurrent_initial=module.zero_recurrent_state(hidden), + conv_states: list[Tensor] = [] + rec_states: list[Tensor] = [] + for segment_index, segment in enumerate(spec.tree_segments): + row = segment.row_index + parent_index = spec.tree_parent_indices[segment_index] + if parent_index < 0: + conv_initial = module.zero_conv_state(hidden) + rec_initial = module.zero_recurrent_state(hidden) + else: + conv_initial = conv_states[parent_index] + rec_initial = rec_states[parent_index] + segment_out, conv_final, rec_final = module.forward_segment( + hidden[row, segment.start : segment.end], + conv_initial=conv_initial, + recurrent_initial=rec_initial, ) - output[row, family.prefix.start : family.prefix.end] = prefix_out - for completion in family.completions: - suffix_hidden = hidden[row, completion.start : completion.end] - suffix_out, _, _ = module.forward_segment( - suffix_hidden, - conv_initial=prefix_conv, - recurrent_initial=prefix_rec, - ) - output[row, completion.start : completion.end] = suffix_out + output[row, segment.start : segment.end] = segment_out + conv_states.append(conv_final) + rec_states.append(rec_final) return output @@ -138,30 +140,36 @@ def run_toy_flattened_reference( group_ids: Tensor, parent_ids: Tensor, ) -> Tensor: - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=1 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) output = torch.zeros_like(hidden) - for family in spec.families: - row = family.row_index - prefix_hidden = hidden[row, family.prefix.start : family.prefix.end] - prefix_len = family.prefix.length - for child_index, completion in enumerate(family.completions): - suffix_hidden = hidden[row, completion.start : completion.end] - flattened = torch.cat([prefix_hidden, suffix_hidden], dim=0) - flat_out, _, _ = module.forward_segment( - flattened, - 
conv_initial=module.zero_conv_state(hidden), - recurrent_initial=module.zero_recurrent_state(hidden), - ) - if child_index == 0: - output[row, family.prefix.start : family.prefix.end] = flat_out[ - :prefix_len - ] - output[row, completion.start : completion.end] = flat_out[prefix_len:] + for segment_index, segment in enumerate(spec.tree_segments): + path = _segment_path(spec, segment_index) + flattened = torch.cat( + [hidden[node.row_index, node.start : node.end] for node in path], + dim=0, + ) + flat_out, _, _ = module.forward_segment( + flattened, + conv_initial=module.zero_conv_state(hidden), + recurrent_initial=module.zero_recurrent_state(hidden), + ) + segment_len = segment.length + output[segment.row_index, segment.start : segment.end] = flat_out[-segment_len:] return output +def _segment_path( + spec: GdnPackedExecutionSpec, + segment_index: int, +) -> tuple[GdnSegmentSpec, ...]: + indices = [] + cursor = segment_index + while cursor >= 0: + indices.append(cursor) + cursor = spec.tree_parent_indices[cursor] + return tuple(spec.tree_segments[index] for index in reversed(indices)) + + def run_toy_physical_stream( module: ToyStatefulGdn, hidden: Tensor, diff --git a/tests/integration/megatron/gdn_shared_prefix/packed_layout.py b/tests/integration/megatron/gdn_shared_prefix/packed_layout.py index 45a41ff58..fa1b00d05 100644 --- a/tests/integration/megatron/gdn_shared_prefix/packed_layout.py +++ b/tests/integration/megatron/gdn_shared_prefix/packed_layout.py @@ -137,19 +137,23 @@ def summarize_case( conv_width: int, cp_sizes: tuple[int, ...] 
= (2, 4, 8), ) -> GdnCaseSummary: - spec = parse_gdn_shared_prefix_segments( - tensors["group_ids"], tensors["parent_ids"], min_completions_per_family=1 - ) + spec = parse_gdn_shared_prefix_segments(tensors["group_ids"], tensors["parent_ids"]) suffix_lengths = [ - segment.length for family in spec.families for segment in family.completions + segment.length + for index, segment in enumerate(spec.tree_segments) + if spec.tree_parent_indices[index] >= 0 ] boundary = _boundary_flags(spec, cp_sizes) return GdnCaseSummary( name=case.name, total_tokens=spec.real_token_count, family_count=spec.family_count, - completion_count=spec.completion_count, - max_segment_length=spec.max_segment_length, + completion_count=sum( + 1 for parent_index in spec.tree_parent_indices if parent_index >= 0 + ), + max_segment_length=max( + (segment.length for segment in spec.tree_segments), default=0 + ), suffix_shorter_than_conv=any(length < conv_width for length in suffix_lengths), suffix_equal_to_conv=any(length == conv_width for length in suffix_lengths), suffix_longer_than_conv=any(length > conv_width for length in suffix_lengths), @@ -227,19 +231,49 @@ def _boundary_flags( boundaries = {shard * rank for rank in range(1, cp_size)} if shard * (cp_size - 1) >= spec.real_token_count: flags["empty_trailing_rank"] = True - for family in spec.families: - family_start = _segment_real_start(family.prefix, spec, real_index) - family_end = _segment_real_end(family.completions[-1], spec, real_index) + for root in _root_segments(spec): + descendants = _descendant_segments(spec, root.family_index) + family_segments = (root, *descendants) + family_start = min( + _segment_real_start(segment, spec, real_index) + for segment in family_segments + ) + family_end = max( + _segment_real_end(segment, spec, real_index) + for segment in family_segments + ) if family_start in boundaries or family_end in boundaries: flags["family_boundary_at_partition"] = True - if _crosses_boundary(family.prefix, spec, real_index, 
boundaries): + if _crosses_boundary(root, spec, real_index, boundaries): flags["cp_boundary_prefix"] = True - for completion in family.completions: + for completion in descendants: if _crosses_boundary(completion, spec, real_index, boundaries): flags["cp_boundary_suffix"] = True return flags +def _root_segments(spec: GdnPackedExecutionSpec) -> tuple[Any, ...]: + return tuple( + segment + for index, segment in enumerate(spec.tree_segments) + if spec.tree_parent_indices[index] < 0 + ) + + +def _descendant_segments( + spec: GdnPackedExecutionSpec, root_index: int +) -> tuple[Any, ...]: + descendants = [] + for index, segment in enumerate(spec.tree_segments): + parent = spec.tree_parent_indices[index] + while parent >= 0: + if parent == root_index: + descendants.append(segment) + break + parent = spec.tree_parent_indices[parent] + return tuple(descendants) + + def _segment_real_start( segment: Any, spec: GdnPackedExecutionSpec, real_index: dict[int, int] ) -> int: diff --git a/tests/integration/megatron/gdn_shared_prefix/parser_import.py b/tests/integration/megatron/gdn_shared_prefix/parser_import.py index ce184d96e..3a473ebf3 100644 --- a/tests/integration/megatron/gdn_shared_prefix/parser_import.py +++ b/tests/integration/megatron/gdn_shared_prefix/parser_import.py @@ -24,6 +24,5 @@ def _load_parser_module() -> ModuleType: _MODULE = _load_parser_module() GdnPackedExecutionSpec: Any = _MODULE.GdnPackedExecutionSpec -build_gdn_cp_segment_schedule: Any = _MODULE.build_gdn_cp_segment_schedule build_gdn_rank_execution_plan: Any = _MODULE.build_gdn_rank_execution_plan parse_gdn_shared_prefix_segments: Any = _MODULE.parse_gdn_shared_prefix_segments diff --git a/tests/integration/megatron/gdn_shared_prefix/real_gdn_oracle.py b/tests/integration/megatron/gdn_shared_prefix/real_gdn_oracle.py index e69fef22b..38fb01889 100644 --- a/tests/integration/megatron/gdn_shared_prefix/real_gdn_oracle.py +++ b/tests/integration/megatron/gdn_shared_prefix/real_gdn_oracle.py @@ -1,6 +1,6 
@@ from __future__ import annotations -from typing import Any, Literal +from typing import Any, Literal, NamedTuple from pydantic import BaseModel, ConfigDict import torch @@ -61,6 +61,57 @@ class GdnChainBoundaryDebug(BaseModel): ] +class _TreeFamily(NamedTuple): + row_index: int + family_index: int + prefix: Any + completions: tuple[Any, ...] + segment_indices: tuple[int, ...] + parent_indices: tuple[int, ...] + + @property + def token_count(self) -> int: + return self.prefix.length + sum(segment.length for segment in self.completions) + + +def _segment_path(spec: Any, segment_index: int) -> tuple[Any, ...]: + path = [] + cursor = segment_index + while cursor >= 0: + path.append(cursor) + cursor = spec.tree_parent_indices[cursor] + return tuple(spec.tree_segments[index] for index in reversed(path)) + + +def _tree_families(spec: Any) -> tuple[_TreeFamily, ...]: + families = [] + for root_index, root in enumerate(spec.tree_segments): + if spec.tree_parent_indices[root_index] >= 0: + continue + segment_indices = [root_index] + for index in range(root_index + 1, len(spec.tree_segments)): + parent = spec.tree_parent_indices[index] + while parent >= 0: + if parent == root_index: + segment_indices.append(index) + break + parent = spec.tree_parent_indices[parent] + segments = tuple(spec.tree_segments[index] for index in segment_indices) + families.append( + _TreeFamily( + row_index=root.row_index, + family_index=root_index, + prefix=root, + completions=segments[1:], + segment_indices=tuple(segment_indices), + parent_indices=tuple( + spec.tree_parent_indices[index] for index in segment_indices + ), + ) + ) + return tuple(families) + + def compare_real_gdn_cp1_to_flattened( *, packed_gdn: Any, @@ -296,35 +347,34 @@ def run_real_gdn_flattened_reference( parent_ids: Tensor, execution_spec: Any | None = None, ) -> Tensor: - spec = execution_spec or parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=1 - ) + spec = execution_spec or 
parse_gdn_shared_prefix_segments(group_ids, parent_ids) output = torch.zeros_like(hidden_states) - for family in spec.families: - row = family.row_index - prefix_hidden = hidden_states[ - family.prefix.start : family.prefix.end, row : row + 1, : - ] - prefix_len = family.prefix.length - for child_index, completion in enumerate(family.completions): - suffix_hidden = hidden_states[ - completion.start : completion.end, row : row + 1, : - ] - flat_hidden = torch.cat([prefix_hidden, suffix_hidden], dim=0) - flat_out, _, _, _ = _run_gdn_segment( - gdn, - flat_hidden, - conv_initial=_zero_conv_state(gdn, hidden_states, row), - recurrent_initial=_zero_recurrent_state(gdn, hidden_states, row), - output_final_state=False, - ) - if child_index == 0: - output[family.prefix.start : family.prefix.end, row : row + 1, :] = ( - flat_out[:prefix_len] - ) - output[completion.start : completion.end, row : row + 1, :] = flat_out[ - prefix_len: - ] + for segment_index, segment in enumerate(spec.tree_segments): + flat_hidden = torch.cat( + [ + hidden_states[ + node.start : node.end, + node.row_index : node.row_index + 1, + :, + ] + for node in _segment_path(spec, segment_index) + ], + dim=0, + ) + flat_out, _, _, _ = _run_gdn_segment( + gdn, + flat_hidden, + conv_initial=_zero_conv_state(gdn, hidden_states, segment.row_index), + recurrent_initial=_zero_recurrent_state( + gdn, hidden_states, segment.row_index + ), + output_final_state=False, + ) + output[ + segment.start : segment.end, + segment.row_index : segment.row_index + 1, + :, + ] = flat_out[-segment.length :] return output @@ -359,9 +409,7 @@ def run_real_gdn_local_fork_reference( cp_size: int, attention_token_layout_index: TokenLayoutIndex | None = None, ) -> Tensor: - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) gdn_token_indices_by_rank = _split_gdn_families_by_rank(spec, cp_size=cp_size) 
gdn_token_ranges_by_rank = _rank_ranges_from_tokens_by_rank( gdn_token_indices_by_rank @@ -414,12 +462,12 @@ def _split_gdn_families_by_rank( raise ValueError(f"cp_size must be >= 1, got {cp_size}") ranks: list[list[int]] = [[] for _ in range(cp_size)] loads = [0] * cp_size - for family in spec.families: + for family in _tree_families(spec): rank = min(range(cp_size), key=lambda index: (loads[index], index)) family_tokens = tuple( token for segment in (family.prefix, *family.completions) - for token in segment.linear_indices(spec.sequence_length) + for token in _segment_linear_indices(segment, spec.sequence_length) ) ranks[rank].extend(family_tokens) loads[rank] += len(family_tokens) @@ -471,6 +519,11 @@ def _simulate_all_to_all_single( return tuple(outputs) +def _segment_linear_indices(segment: Any, sequence_length: int) -> range: + base = int(segment.row_index) * int(sequence_length) + return range(base + int(segment.start), base + int(segment.end)) + + def _transfer_positions(tensor: Tensor | None, *, count: int) -> tuple[int, ...]: if tensor is None: return tuple(range(count)) @@ -523,11 +576,9 @@ def run_real_gdn_suffix_only_chain_reference( mutation: GdnChainMutation | None = None, boundary_debug: list[GdnChainBoundaryDebug] | None = None, ) -> Tensor: - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) output = torch.zeros_like(hidden_states) - for family in spec.families: + for family in _tree_families(spec): row = family.row_index zero_conv = _zero_conv_state(gdn, hidden_states, batch_size=1) zero_rec = _zero_recurrent_state(gdn, hidden_states, batch_size=1) @@ -575,11 +626,9 @@ def run_real_gdn_chunk_native_reference( group_ids: Tensor, parent_ids: Tensor, ) -> Tensor: - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) output = 
torch.zeros_like(hidden_states) - for family in spec.families: + for family in _tree_families(spec): _scatter_family_output( output, family, @@ -597,13 +646,11 @@ def run_real_gdn_mixed_cp_reference( cp_size: int, local_fork_max_tokens: int, ) -> Tensor: - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) output = torch.zeros_like(hidden_states) local_count = 0 chain_count = 0 - for family in spec.families: + for family in _tree_families(spec): if family.token_count <= local_fork_max_tokens: local_count += 1 _scatter_family_output( @@ -753,14 +800,21 @@ def _family_group_tensors( ) -> tuple[Tensor, Tensor]: group_ids = [] parent_ids = [] - prefix_group_id = 0 - group_ids.extend([prefix_group_id] * family.prefix.length) - parent_ids.extend([prefix_group_id] * family.prefix.length) - next_group_id = 1 - for completion in family.completions: - group_ids.extend([next_group_id] * completion.length) - parent_ids.extend([prefix_group_id] * completion.length) - next_group_id += 1 + local_group_by_global: dict[int, int] = {} + for local_group_id, (segment, global_index, parent_index) in enumerate( + zip( + (family.prefix, *family.completions), + family.segment_indices, + family.parent_indices, + strict=True, + ) + ): + local_group_by_global[global_index] = local_group_id + local_parent_id = ( + local_group_id if parent_index < 0 else local_group_by_global[parent_index] + ) + group_ids.extend([local_group_id] * segment.length) + parent_ids.extend([local_parent_id] * segment.length) return ( torch.tensor([group_ids], device=device, dtype=torch.long), torch.tensor([parent_ids], device=device, dtype=torch.long), @@ -883,12 +937,12 @@ def _local_fork_group_tensors( ) parent_ids = torch.full_like(group_ids, -1) next_group_id = 0 - for family in spec.families: + for family in _tree_families(spec): family_segments = (family.prefix, *family.completions) family_tokens = 
tuple( token_index for segment in family_segments - for token_index in segment.linear_indices(spec.sequence_length) + for token_index in _segment_linear_indices(segment, spec.sequence_length) ) token_is_local = tuple( token_index in local_position for token_index in family_tokens @@ -898,19 +952,23 @@ def _local_fork_group_tensors( if not all(token_is_local): raise ValueError("local-fork execution requires whole prompt families") - prefix_group_id = next_group_id - next_group_id += 1 - for token_index in family.prefix.linear_indices(spec.sequence_length): - position = local_position[token_index] - group_ids[position] = prefix_group_id - parent_ids[position] = prefix_group_id - for completion in family.completions: - child_group_id = next_group_id + group_by_segment_index: dict[int, int] = {} + for segment, global_index, parent_index in zip( + family_segments, + family.segment_indices, + family.parent_indices, + strict=True, + ): + group_id = next_group_id next_group_id += 1 - for token_index in completion.linear_indices(spec.sequence_length): + group_by_segment_index[global_index] = group_id + parent_group_id = ( + group_id if parent_index < 0 else group_by_segment_index[parent_index] + ) + for token_index in _segment_linear_indices(segment, spec.sequence_length): position = local_position[token_index] - group_ids[position] = child_group_id - parent_ids[position] = prefix_group_id + group_ids[position] = group_id + parent_ids[position] = parent_group_id if torch.any(group_ids == -1): raise RuntimeError("local-fork metadata left unassigned token rows") return group_ids.unsqueeze(0), parent_ids.unsqueeze(0) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_fla_cp_native_recurrent.py b/tests/integration/megatron/gdn_shared_prefix/test_fla_cp_native_recurrent.py index bcf3a0cfb..6f5eefc17 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_fla_cp_native_recurrent.py +++ 
b/tests/integration/megatron/gdn_shared_prefix/test_fla_cp_native_recurrent.py @@ -1,7 +1,6 @@ from __future__ import annotations from pathlib import Path -import socket from typing import Any, cast import pytest @@ -20,6 +19,7 @@ chunk_gated_delta_rule_native_cp, ) +from .distributed_init import file_init_method # noqa: E402 from .metrics import GDN_CORRECTNESS_DTYPE, assert_mean_abs_pct # noqa: E402 _CP_SIZES = ( @@ -43,10 +43,10 @@ def test_native_fla_cp_recurrent_matches_single_rank( cp_size: int, tmp_path: Path ) -> None: - port = _find_free_port() + init_method = file_init_method(tmp_path, f"native_fla_recurrent_cp{cp_size}") mp.spawn( _native_fla_cp_worker, - args=(cp_size, port, str(tmp_path)), + args=(cp_size, init_method, str(tmp_path)), nprocs=cp_size, join=True, ) @@ -62,10 +62,10 @@ def test_native_fla_cp_recurrent_matches_single_rank( def test_native_fla_cp_recurrent_varlen_multichain_matches_single_rank( cp_size: int, tmp_path: Path ) -> None: - port = _find_free_port() + init_method = file_init_method(tmp_path, f"native_fla_varlen_cp{cp_size}") mp.spawn( _native_fla_cp_varlen_multichain_worker, - args=(cp_size, port, str(tmp_path)), + args=(cp_size, init_method, str(tmp_path)), nprocs=cp_size, join=True, ) @@ -119,13 +119,13 @@ def test_native_fla_summary_affine_debug_matches_final_state() -> None: def _native_fla_cp_worker( rank: int, cp_size: int, - port: int, + init_method: str, output_dir: str, ) -> None: torch.cuda.set_device(rank) init_process_group( backend="nccl", - init_method=f"tcp://127.0.0.1:{port}", + init_method=init_method, rank=rank, world_size=cp_size, ) @@ -201,13 +201,13 @@ def _native_fla_cp_worker( def _native_fla_cp_varlen_multichain_worker( rank: int, cp_size: int, - port: int, + init_method: str, output_dir: str, ) -> None: torch.cuda.set_device(rank) init_process_group( backend="nccl", - init_method=f"tcp://127.0.0.1:{port}", + init_method=init_method, rank=rank, world_size=cp_size, ) @@ -519,9 +519,3 @@ def 
_cat_varlen_slices( def _assert_grad_close(left: torch.Tensor, right_grad: torch.Tensor, name: str) -> None: assert left.grad is not None, name assert_mean_abs_pct(right_grad, left.grad, name) - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py index 2151b41e1..ac14e8df8 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_packed_correctness.py @@ -2,7 +2,6 @@ from collections.abc import Callable from pathlib import Path -import socket from typing import Any import pytest @@ -21,6 +20,7 @@ parse_gdn_shared_prefix_segments, ) from art.megatron.gdn.operator import run_gdn_layer # noqa: E402 +from art.megatron.shared_prefix_packing import pack_shared_prefixes # noqa: E402 from .cases import ( # noqa: E402 GdnFamilyShape, @@ -29,6 +29,7 @@ default_phase0_cases, ) from .distributed_grad import all_reduce_parameter_grads_coalesced # noqa: E402 +from .distributed_init import file_init_method # noqa: E402 from .metrics import ( # noqa: E402 GDN_CORRECTNESS_DTYPE, REAL_GDN_GRAD_MEAN_ABS_PCT_THRESHOLD, @@ -50,10 +51,10 @@ def test_gdn_cp_packed_matches_cp1_oracle_all_edge_cases( cp_size: int, tmp_path: Path ) -> None: _skip_without_gpus(cp_size) - port = _find_free_port() + init_method = file_init_method(tmp_path, f"cp1_oracle_cp{cp_size}") mp.spawn( _cp1_oracle_worker, - args=(cp_size, port, str(tmp_path), False), + args=(cp_size, init_method, str(tmp_path), False), nprocs=cp_size, join=True, ) @@ -66,10 +67,10 @@ def test_gdn_cp_packed_sibling_order_matches_cp1_oracle( cp_size: int, tmp_path: Path ) -> None: _skip_without_gpus(cp_size) - port = _find_free_port() + init_method = 
file_init_method(tmp_path, f"cp1_oracle_sibling_cp{cp_size}") mp.spawn( _cp1_oracle_worker, - args=(cp_size, port, str(tmp_path), True), + args=(cp_size, init_method, str(tmp_path), True), nprocs=cp_size, join=True, ) @@ -77,17 +78,45 @@ def test_gdn_cp_packed_sibling_order_matches_cp1_oracle( assert (tmp_path / f"cp1_oracle_sibling_rank_{rank}.ok").read_text() == "ok\n" +@pytest.mark.parametrize("cp_size", (2, 4)) +def test_gdn_cp_tree_chain_matches_cp1_oracle(cp_size: int, tmp_path: Path) -> None: + _skip_without_gpus(cp_size) + init_method = file_init_method(tmp_path, f"tree_chain_cp{cp_size}") + mp.spawn( + _tree_chain_oracle_worker, + args=(cp_size, init_method, str(tmp_path)), + nprocs=cp_size, + join=True, + ) + for rank in range(cp_size): + assert (tmp_path / f"tree_chain_rank_{rank}.ok").read_text() == "ok\n" + + +def test_gdn_cp_tree_fuzz_matches_cp1_oracle(tmp_path: Path) -> None: + cp_size = 4 + _skip_without_gpus(cp_size) + init_method = file_init_method(tmp_path, "tree_fuzz_cp4") + mp.spawn( + _tree_fuzz_oracle_worker, + args=(cp_size, init_method, str(tmp_path)), + nprocs=cp_size, + join=True, + ) + for rank in range(cp_size): + assert (tmp_path / f"tree_fuzz_rank_{rank}.ok").read_text() == "ok\n" + + def _cp1_oracle_worker( rank: int, cp_size: int, - port: int, + init_method: str, output_dir: str, sibling_only: bool, ) -> None: torch.cuda.set_device(rank) init_process_group( backend="nccl", - init_method=f"tcp://127.0.0.1:{port}", + init_method=init_method, rank=rank, world_size=cp_size, ) @@ -126,6 +155,86 @@ def _cp1_oracle_worker( destroy_process_group() +def _tree_chain_oracle_worker( + rank: int, + cp_size: int, + init_method: str, + output_dir: str, +) -> None: + torch.cuda.set_device(rank) + init_process_group( + backend="nccl", + init_method=init_method, + rank=rank, + world_size=cp_size, + ) + try: + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=cp_size, + 
expert_model_parallel_size=1, + ) + ref_gdn, cp_gdn = _make_matching_gdn_pair(cp_size=cp_size) + _assert_tree_pack_matches_cp1( + "tree_chain", + ref_gdn, + cp_gdn, + _tree_chain_pack(), + rank=rank, + cp_size=cp_size, + seed=9090, + planner_config=_tree_chain_planner_config(), + require_chain=True, + ) + Path(output_dir, f"tree_chain_rank_{rank}.ok").write_text("ok\n") + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + destroy_process_group() + + +def _tree_fuzz_oracle_worker( + rank: int, + cp_size: int, + init_method: str, + output_dir: str, +) -> None: + torch.cuda.set_device(rank) + init_process_group( + backend="nccl", + init_method=init_method, + rank=rank, + world_size=cp_size, + ) + try: + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=cp_size, + expert_model_parallel_size=1, + ) + ref_gdn, cp_gdn = _make_matching_gdn_pair(cp_size=cp_size) + for case_index, (name, pack) in enumerate(_tree_fuzz_packs()): + _assert_tree_pack_matches_cp1( + name, + ref_gdn, + cp_gdn, + pack, + rank=rank, + cp_size=cp_size, + seed=9190 + case_index, + planner_config=_tree_fuzz_planner_config(), + require_chain=False, + ) + torch.distributed.barrier() + Path(output_dir, f"tree_fuzz_rank_{rank}.ok").write_text("ok\n") + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + destroy_process_group() + + def _assert_case_matches_cp1( ref_gdn: torch.nn.Module, cp_gdn: torch.nn.Module, @@ -141,9 +250,7 @@ def _assert_case_matches_cp1( tensors = build_phase0_packed_tensors(case) group_ids = tensors["group_ids"].cuda() parent_ids = tensors["parent_ids"].cuda() - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) plan = build_gdn_rank_execution_plan( spec, device=group_ids.device, @@ -212,6 
+319,81 @@ def _assert_case_matches_cp1( ) +def _assert_tree_pack_matches_cp1( + name: str, + ref_gdn: torch.nn.Module, + cp_gdn: torch.nn.Module, + pack: Any, + *, + rank: int, + cp_size: int, + seed: int, + planner_config: GdnPlannerConfig, + require_chain: bool, +) -> None: + zero_parameter_grads(ref_gdn) + zero_parameter_grads(cp_gdn) + group_ids = pack.group_ids.cuda() + parent_ids = pack.parent_ids.cuda() + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) + plan = build_gdn_rank_execution_plan( + spec, + device=group_ids.device, + cp_rank=rank, + cp_size=cp_size, + planner_config=planner_config, + ) + if require_chain: + assert any(plan.tree_chain_buckets_by_depth) + hidden, output_grad = _tree_hidden_and_grad(spec.real_token_count, seed=seed) + ref_hidden = hidden.clone().detach().requires_grad_(True) + ref_out, _ = run_gdn_layer( + ref_gdn, + ref_hidden, + group_ids=group_ids, + parent_ids=parent_ids, + ) + ref_loss = (ref_out * output_grad).sum() + ref_loss.backward() + + flat_hidden = hidden.transpose(0, 1).reshape(-1, hidden.shape[-1]) + flat_grad = output_grad.transpose(0, 1).reshape(-1, output_grad.shape[-1]) + local_index = torch.tensor( + plan.attention_token_indices, device=hidden.device, dtype=torch.long + ) + local_hidden = ( + flat_hidden.index_select(0, local_index) + .unsqueeze(1) + .contiguous() + .detach() + .requires_grad_(True) + ) + local_output_grad = flat_grad.index_select(0, local_index).unsqueeze(1).contiguous() + cp_out, _ = run_gdn_layer( + cp_gdn, + local_hidden, + group_ids=group_ids, + parent_ids=parent_ids, + execution_spec=spec, + execution_plan=plan, + cp_group=torch.distributed.group.WORLD, + ) + cp_loss = (cp_out * local_output_grad).sum() + cp_loss.backward() + _assert_cp_matches_reference( + name, + ref_gdn, + cp_gdn, + ref_hidden, + ref_out, + ref_loss.detach(), + local_hidden, + cp_out, + cp_loss.detach(), + local_index, + ) + + def _assert_sibling_order_matches_cp1( ref_gdn: torch.nn.Module, cp_gdn: 
torch.nn.Module, @@ -233,9 +415,7 @@ def _assert_sibling_order_matches_cp1( swapped_parent_ids[0, 5:9] = 0 swapped_group_ids[0, 9:12] = 2 swapped_parent_ids[0, 9:12] = 0 - spec = parse_gdn_shared_prefix_segments( - swapped_group_ids, swapped_parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(swapped_group_ids, swapped_parent_ids) plan = build_gdn_rank_execution_plan( spec, device=group_ids.device, @@ -377,6 +557,126 @@ def _hidden_and_grad( return hidden, grad +def _tree_hidden_and_grad( + sequence_length: int, *, seed: int +) -> tuple[torch.Tensor, torch.Tensor]: + generator = torch.Generator(device="cuda").manual_seed(seed) + hidden = torch.randn( + sequence_length, + 1, + 64, + device="cuda", + dtype=GDN_CORRECTNESS_DTYPE, + generator=generator, + ) + grad = torch.randn( + hidden.shape, + device="cuda", + dtype=GDN_CORRECTNESS_DTYPE, + generator=generator, + ) + torch.distributed.broadcast(hidden, src=0) + torch.distributed.broadcast(grad, src=0) + return hidden, grad + + +def _tree_chain_pack(): + long_root = torch.arange(11, 267) + short_root = torch.arange(1001, 1097) + long_mid = torch.arange(2001, 2641) + other_mid = torch.arange(3001, 3065) + return pack_shared_prefixes( + ( + torch.cat((long_root, torch.tensor([301]))), + torch.cat((long_root, torch.tensor([302]))), + torch.cat((short_root, long_mid, torch.tensor([401]))), + torch.cat((short_root, long_mid, torch.tensor([402]))), + torch.cat((short_root, other_mid, torch.tensor([403]))), + ), + max_depth=2, + ) + + +def _tree_chain_planner_config() -> GdnPlannerConfig: + return GdnPlannerConfig( + cp_chain_min_tokens_per_rank=16, + cp_chain_min_total_tokens=128, + cp_chain_min_prefix_only_tokens=128, + max_padding_ratio=4.0, + ) + + +def _tree_fuzz_planner_config() -> GdnPlannerConfig: + return GdnPlannerConfig( + cp_chain_min_tokens_per_rank=1, + cp_chain_min_total_tokens=64, + cp_chain_min_prefix_only_tokens=64, + cp_tree_chain_min_total_tokens=64, + 
cp_tree_chain_min_prefix_only_tokens=64, + max_padding_ratio=4.0, + ) + + +def _tree_fuzz_packs() -> tuple[tuple[str, Any], ...]: + return ( + ( + "tree_fuzz_duplicates", + pack_shared_prefixes(_duplicate_tree_sequences(), max_depth=4), + ), + ( + "tree_fuzz_ragged_depth4", + pack_shared_prefixes(_random_tree_sequences(13, max_depth=4), max_depth=4), + ), + ( + "tree_fuzz_mixed_tiny_long", + pack_shared_prefixes(_random_tree_sequences(29, max_depth=5), max_depth=5), + ), + ) + + +def _duplicate_tree_sequences() -> tuple[torch.Tensor, ...]: + root = torch.arange(11, 331) + mid_a = torch.arange(1001, 1261) + mid_b = torch.arange(2001, 2065) + leaf_a = torch.arange(3001, 3013) + leaf_b = torch.arange(4001, 4017) + first = torch.cat((root, mid_a, leaf_a)) + second = torch.cat((root, mid_a, leaf_b)) + third = torch.cat((root, mid_b, torch.tensor([91, 92, 93]))) + return (first, first, second, third, third) + + +def _random_tree_sequences(seed: int, *, max_depth: int) -> tuple[torch.Tensor, ...]: + generator = torch.Generator().manual_seed(seed) + next_token = 1 + + def randint(low: int, high: int) -> int: + return int(torch.randint(low, high + 1, (), generator=generator).item()) + + def tokens(length: int) -> torch.Tensor: + nonlocal next_token + out = torch.arange(next_token, next_token + length) + next_token += length + 997 + return out + + def segment_length(depth: int) -> int: + choices = (1, 3, 17, 64, 129, 257, 384 if depth == 0 else 96) + return choices[randint(0, len(choices) - 1)] + + def walk(prefix: torch.Tensor, depth: int) -> list[torch.Tensor]: + here = torch.cat((prefix, tokens(segment_length(depth)))) + if depth + 1 >= max_depth: + return [ + torch.cat((here, tokens(randint(1, 17)))) for _ in range(randint(2, 4)) + ] + leaves: list[torch.Tensor] = [] + for _ in range(randint(2, 3)): + leaves.extend(walk(here, depth + 1)) + return leaves + + return tuple(walk(torch.empty(0, dtype=torch.long), 0)) + + def _packed_correctness_cases() -> tuple[GdnPhase0Case, 
...]: return ( *default_phase0_cases(conv_width=2), @@ -450,9 +750,3 @@ def _swap_siblings(tensor: torch.Tensor) -> torch.Tensor: def _skip_without_gpus(cp_size: int) -> None: if not torch.cuda.is_available() or torch.cuda.device_count() < cp_size: pytest.skip(f"Need {cp_size} CUDA devices for CP{cp_size} packed GDN.") - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_train_prepare.py b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_train_prepare.py index e0d2e831f..6ef5a8890 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_train_prepare.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_gdn_cp_train_prepare.py @@ -1,7 +1,6 @@ from __future__ import annotations from pathlib import Path -import socket from typing import Any, cast import pytest @@ -24,6 +23,7 @@ from art.preprocessing.pack import PackedTensors # noqa: E402 from .cases import default_phase0_cases # noqa: E402 +from .distributed_init import file_init_method # noqa: E402 from .packed_layout import build_phase0_packed_tensors # noqa: E402 @@ -31,10 +31,10 @@ def test_gdn_cp_training_batch_carries_prebuilt_rank_plan(tmp_path: Path) -> Non cp_size = 2 if not torch.cuda.is_available() or torch.cuda.device_count() < cp_size: pytest.skip(f"requires {cp_size} CUDA devices") - port = _find_free_port() + init_method = file_init_method(tmp_path, "gdn_cp_train_prepare") mp.spawn( _worker, - args=(cp_size, port, str(tmp_path)), + args=(cp_size, init_method, str(tmp_path)), nprocs=cp_size, join=True, ) @@ -42,11 +42,11 @@ def test_gdn_cp_training_batch_carries_prebuilt_rank_plan(tmp_path: Path) -> Non assert (tmp_path / f"rank_{rank}.ok").read_text() == "ok\n" -def _worker(rank: int, cp_size: int, port: int, output_dir: str) -> None: +def _worker(rank: int, cp_size: int, init_method: str, 
output_dir: str) -> None: torch.cuda.set_device(rank) init_process_group( backend="nccl", - init_method=f"tcp://127.0.0.1:{port}", + init_method=init_method, rank=rank, world_size=cp_size, ) @@ -101,12 +101,6 @@ def _worker(rank: int, cp_size: int, port: int, output_dir: str) -> None: destroy_process_group() -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) - - def test_main_loss_matches_shifted_dispatched_loss_inputs() -> None: packed = cast( Any, diff --git a/tests/integration/megatron/gdn_shared_prefix/test_qwen35_full_model_cp1_packed_vs_flattened.py b/tests/integration/megatron/gdn_shared_prefix/test_qwen35_full_model_cp1_packed_vs_flattened.py index 19f33970c..b8e61537d 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_qwen35_full_model_cp1_packed_vs_flattened.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_qwen35_full_model_cp1_packed_vs_flattened.py @@ -2,7 +2,8 @@ from collections.abc import Iterator from contextlib import ExitStack, contextmanager -import socket +from pathlib import Path +import tempfile from typing import Any import pytest @@ -31,6 +32,7 @@ _apply_test_flex_inner_fp32_patch, ) from .cases import default_phase0_cases +from .distributed_init import file_init_method from .metrics import ( GDN_CORRECTNESS_DTYPE, MEAN_ABS_PCT_THRESHOLD, @@ -96,66 +98,66 @@ def test_qwen35_full_model_cp1_matches_flattened_grad_accumulation() -> None: flat_loss_sum: torch.Tensor | None = None logits_mean_abs_pct = 0.0 - spec = parse_gdn_shared_prefix_segments( - group_ids.cpu(), parent_ids.cpu(), min_completions_per_family=1 - ) - for family in spec.families: - row = family.row_index - prefix = family.prefix - for completion in family.completions: - ref_tokens = torch.cat( - [ - tokens[row : row + 1, prefix.start : prefix.end], - tokens[row : row + 1, completion.start : completion.end], - ], - dim=1, - ) - ref_pos = torch.cat( 
- [ - input_pos[row : row + 1, prefix.start : prefix.end], - input_pos[row : row + 1, completion.start : completion.end], - ], - dim=1, - ) - ref_assistant_mask = torch.cat( - [ - torch.zeros( - (1, prefix.length), dtype=torch.bool, device=device - ), - assistant_mask[ - row : row + 1, completion.start : completion.end - ], - ], - dim=1, - ) - ref_group_ids = torch.zeros_like(ref_tokens) - ref_parent_ids = torch.zeros_like(ref_tokens) - ref_logits, ref_loss = _run_model_loss( - flat_model, - tokens=ref_tokens, - input_pos=ref_pos, - group_ids=ref_group_ids, - parent_ids=ref_parent_ids, - assistant_mask=ref_assistant_mask, - ) - ref_loss.backward() - flat_loss_sum = ( - ref_loss.detach() - if flat_loss_sum is None - else flat_loss_sum + ref_loss.detach() - ) + spec = parse_gdn_shared_prefix_segments(group_ids.cpu(), parent_ids.cpu()) + for segment_index, completion in enumerate(spec.tree_segments): + if spec.tree_parent_indices[segment_index] < 0: + continue + row = completion.row_index + path = _segment_path(spec, segment_index) + completion_offset = sum(segment.length for segment in path[:-1]) + ref_tokens = torch.cat( + [ + tokens[row : row + 1, segment.start : segment.end] + for segment in path + ], + dim=1, + ) + ref_pos = torch.cat( + [ + input_pos[row : row + 1, segment.start : segment.end] + for segment in path + ], + dim=1, + ) + ref_assistant_mask = torch.cat( + [ + torch.zeros( + (1, completion_offset), + dtype=torch.bool, + device=device, + ), + assistant_mask[row : row + 1, completion.start : completion.end], + ], + dim=1, + ) + ref_group_ids = torch.zeros_like(ref_tokens) + ref_parent_ids = torch.zeros_like(ref_tokens) + ref_logits, ref_loss = _run_model_loss( + flat_model, + tokens=ref_tokens, + input_pos=ref_pos, + group_ids=ref_group_ids, + parent_ids=ref_parent_ids, + assistant_mask=ref_assistant_mask, + ) + ref_loss.backward() + flat_loss_sum = ( + ref_loss.detach() + if flat_loss_sum is None + else flat_loss_sum + ref_loss.detach() + ) - if 
completion.length > 1: - packed_slice = packed_logits[ - row : row + 1, completion.start : completion.end - 1 - ] - ref_slice = ref_logits[ - :, prefix.length : prefix.length + completion.length - 1 - ] - logits_mean_abs_pct = max( - logits_mean_abs_pct, - mean_abs_pct(ref_slice, packed_slice), - ) + if completion.length > 1: + packed_slice = packed_logits[ + row : row + 1, completion.start : completion.end - 1 + ] + ref_slice = ref_logits[ + :, completion_offset : completion_offset + completion.length - 1 + ] + logits_mean_abs_pct = max( + logits_mean_abs_pct, + mean_abs_pct(ref_slice, packed_slice), + ) assert flat_loss_sum is not None grad_name, grad_pct = parameter_grad_mean_abs_pct_with_name( @@ -214,70 +216,64 @@ def _assert_logits_vjp_equivalence( flat_loss_sum: torch.Tensor | None = None logits_mean_abs_pct = 0.0 - spec = parse_gdn_shared_prefix_segments( - group_ids.cpu(), parent_ids.cpu(), min_completions_per_family=1 - ) - for family in spec.families: - row = family.row_index - prefix = family.prefix - for completion in family.completions: - ref_tokens = torch.cat( - [ - tokens[row : row + 1, prefix.start : prefix.end], - tokens[row : row + 1, completion.start : completion.end], - ], - dim=1, - ) - ref_pos = torch.cat( - [ - input_pos[row : row + 1, prefix.start : prefix.end], - input_pos[row : row + 1, completion.start : completion.end], - ], - dim=1, - ) - ref_logits = _run_model_logits( - flat_model, - tokens=ref_tokens, - input_pos=ref_pos, - group_ids=torch.zeros_like(ref_tokens), - parent_ids=torch.zeros_like(ref_tokens), - ) - ref_output_grad = torch.zeros_like(ref_logits) - ref_output_mask = torch.zeros( - ref_logits.shape[:2], - device=ref_logits.device, - dtype=torch.bool, - ) - if completion.length > 1: - ref_output_grad[ - :, prefix.length : prefix.length + completion.length - 1 - ] = output_grad[row : row + 1, completion.start : completion.end - 1] - ref_output_mask[ - :, prefix.length : prefix.length + completion.length - 1 - ] = True - 
ref_loss = stable_output_mse_loss( - ref_logits, - ref_output_grad, - mask=ref_output_mask.unsqueeze(-1), - denominator=loss_denominator, - ) - ref_loss.backward() - flat_loss_sum = ( - ref_loss.detach() - if flat_loss_sum is None - else flat_loss_sum + ref_loss.detach() + spec = parse_gdn_shared_prefix_segments(group_ids.cpu(), parent_ids.cpu()) + for segment_index, completion in enumerate(spec.tree_segments): + if spec.tree_parent_indices[segment_index] < 0: + continue + row = completion.row_index + path = _segment_path(spec, segment_index) + completion_offset = sum(segment.length for segment in path[:-1]) + ref_tokens = torch.cat( + [tokens[row : row + 1, segment.start : segment.end] for segment in path], + dim=1, + ) + ref_pos = torch.cat( + [input_pos[row : row + 1, segment.start : segment.end] for segment in path], + dim=1, + ) + ref_logits = _run_model_logits( + flat_model, + tokens=ref_tokens, + input_pos=ref_pos, + group_ids=torch.zeros_like(ref_tokens), + parent_ids=torch.zeros_like(ref_tokens), + ) + ref_output_grad = torch.zeros_like(ref_logits) + ref_output_mask = torch.zeros( + ref_logits.shape[:2], + device=ref_logits.device, + dtype=torch.bool, + ) + if completion.length > 1: + ref_output_grad[ + :, completion_offset : completion_offset + completion.length - 1 + ] = output_grad[row : row + 1, completion.start : completion.end - 1] + ref_output_mask[ + :, completion_offset : completion_offset + completion.length - 1 + ] = True + ref_loss = stable_output_mse_loss( + ref_logits, + ref_output_grad, + mask=ref_output_mask.unsqueeze(-1), + denominator=loss_denominator, + ) + ref_loss.backward() + flat_loss_sum = ( + ref_loss.detach() + if flat_loss_sum is None + else flat_loss_sum + ref_loss.detach() + ) + if completion.length > 1: + packed_slice = packed_logits[ + row : row + 1, completion.start : completion.end - 1 + ] + ref_slice = ref_logits[ + :, completion_offset : completion_offset + completion.length - 1 + ] + logits_mean_abs_pct = max( + 
logits_mean_abs_pct, + mean_abs_pct(ref_slice, packed_slice), ) - if completion.length > 1: - packed_slice = packed_logits[ - row : row + 1, completion.start : completion.end - 1 - ] - ref_slice = ref_logits[ - :, prefix.length : prefix.length + completion.length - 1 - ] - logits_mean_abs_pct = max( - logits_mean_abs_pct, - mean_abs_pct(ref_slice, packed_slice), - ) assert flat_loss_sum is not None grad_name, grad_pct = parameter_grad_mean_abs_pct_with_name( @@ -359,6 +355,15 @@ def _run_model_logits( return logits +def _segment_path(spec: Any, segment_index: int) -> tuple[Any, ...]: + indices = [] + cursor = segment_index + while cursor >= 0: + indices.append(cursor) + cursor = spec.tree_parent_indices[cursor] + return tuple(spec.tree_segments[index] for index in reversed(indices)) + + def _make_matching_models() -> tuple[torch.nn.Module, torch.nn.Module]: model_parallel_cuda_manual_seed(1234) packed = _make_model() @@ -424,28 +429,23 @@ def _single_rank_model_parallel() -> Iterator[None]: if is_initialized(): pytest.skip("torch.distributed is already initialized in this process.") torch.cuda.set_device(0) - init_process_group( - backend="nccl", - init_method=f"tcp://127.0.0.1:{_find_free_port()}", - rank=0, - world_size=1, - ) - try: - ps.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - context_parallel_size=1, - expert_model_parallel_size=1, + with tempfile.TemporaryDirectory(prefix="art_dist_") as tmp: + init_process_group( + backend="nccl", + init_method=file_init_method(Path(tmp), "qwen35_full_model_cp1"), + rank=0, + world_size=1, ) - yield - finally: - if getattr(ps, "model_parallel_is_initialized", lambda: False)(): - ps.destroy_model_parallel() - if is_initialized(): - destroy_process_group() - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) + try: + ps.initialize_model_parallel( + 
tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + expert_model_parallel_size=1, + ) + yield + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + if is_initialized(): + destroy_process_group() diff --git a/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_cp1_packed_vs_flattened.py b/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_cp1_packed_vs_flattened.py index de6933582..f026b90c9 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_cp1_packed_vs_flattened.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_cp1_packed_vs_flattened.py @@ -2,7 +2,8 @@ from collections.abc import Iterator from contextlib import contextmanager -import socket +from pathlib import Path +import tempfile import pytest @@ -18,13 +19,13 @@ from megatron.core.ssm.gated_delta_net import GatedDeltaNet from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from torch.distributed import ( - DistNetworkError, destroy_process_group, init_process_group, is_initialized, ) from .cases import default_phase0_cases +from .distributed_init import file_init_method from .metrics import ( GDN_CORRECTNESS_DTYPE, MEAN_ABS_PCT_MISMATCH_THRESHOLD, @@ -232,44 +233,23 @@ def _single_rank_model_parallel() -> Iterator[None]: if is_initialized(): pytest.skip("torch.distributed is already initialized in this process.") torch.cuda.set_device(0) - _init_single_rank_process_group() - try: - ps.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - context_parallel_size=1, - expert_model_parallel_size=1, + with tempfile.TemporaryDirectory(prefix="art_dist_") as tmp: + init_process_group( + backend="nccl", + init_method=file_init_method(Path(tmp), "single_rank"), + rank=0, + world_size=1, ) - yield - finally: - if getattr(ps, "model_parallel_is_initialized", lambda: False)(): - 
ps.destroy_model_parallel() - if is_initialized(): - destroy_process_group() - - -def _init_single_rank_process_group() -> None: - last_error: DistNetworkError | None = None - for _ in range(16): try: - init_process_group( - backend="nccl", - init_method=f"tcp://127.0.0.1:{_find_free_port()}", - rank=0, - world_size=1, + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + expert_model_parallel_size=1, ) - return - except DistNetworkError as error: - if "EADDRINUSE" not in str(error): - raise - last_error = error + yield + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() if is_initialized(): destroy_process_group() - if last_error is not None: - raise last_error - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_native_fla_cp.py b/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_native_fla_cp.py index e0d164c56..7a173ae8f 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_native_fla_cp.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_native_fla_cp.py @@ -1,7 +1,6 @@ from __future__ import annotations from pathlib import Path -import socket from typing import cast import pytest @@ -37,6 +36,7 @@ ) from .cases import GdnFamilyShape, GdnPackedRowShape, GdnPhase0Case # noqa: E402 +from .distributed_init import file_init_method # noqa: E402 from .metrics import ( # noqa: E402 GDN_CORRECTNESS_DTYPE, MEAN_ABS_PCT_THRESHOLD, @@ -70,10 +70,10 @@ def test_real_qwen35_gdn_native_fla_cp_prepared_varlen_batch_matches_single_rank( cp_size: int, tmp_path: Path ) -> None: - port = _find_free_port() + init_method = file_init_method(tmp_path, f"native_gdn_prepared_cp{cp_size}") mp.spawn( 
_native_gdn_cp_prepared_varlen_worker, - args=(cp_size, port, str(tmp_path)), + args=(cp_size, init_method, str(tmp_path)), nprocs=cp_size, join=True, ) @@ -89,10 +89,10 @@ def test_real_qwen35_gdn_native_fla_cp_prepared_varlen_batch_matches_single_rank def test_real_qwen35_gdn_native_cp_packed_layer_matches_cp1( cp_size: int, tmp_path: Path ) -> None: - port = _find_free_port() + init_method = file_init_method(tmp_path, f"native_gdn_packed_cp{cp_size}") mp.spawn( _native_gdn_cp_packed_layer_worker, - args=(cp_size, port, str(tmp_path)), + args=(cp_size, init_method, str(tmp_path)), nprocs=cp_size, join=True, ) @@ -103,13 +103,13 @@ def test_real_qwen35_gdn_native_cp_packed_layer_matches_cp1( def _native_gdn_cp_packed_layer_worker( rank: int, cp_size: int, - port: int, + init_method: str, output_dir: str, ) -> None: torch.cuda.set_device(rank) init_process_group( backend="nccl", - init_method=f"tcp://127.0.0.1:{port}", + init_method=init_method, rank=rank, world_size=cp_size, ) @@ -127,9 +127,7 @@ def _native_gdn_cp_packed_layer_worker( tensors = build_phase0_packed_tensors(case) group_ids = tensors["group_ids"].cuda() parent_ids = tensors["parent_ids"].cuda() - spec = parse_gdn_shared_prefix_segments( - group_ids, parent_ids, min_completions_per_family=0 - ) + spec = parse_gdn_shared_prefix_segments(group_ids, parent_ids) plan = build_gdn_rank_execution_plan( spec, device=group_ids.device, @@ -139,17 +137,9 @@ def _native_gdn_cp_packed_layer_worker( cp_chain_min_tokens_per_rank=16, cp_chain_min_total_tokens=128, cp_chain_min_prefix_only_tokens=128, - # This test is the native chain correctness guard, so force the - # planner onto chain prefix and completion buckets. 
- planner_chain_bucket_ms=0.0, - planner_chain_token_ms=0.0, - planner_local_bucket_ms=1.0, - planner_local_token_ms=1.0, - cp_chain_min_score_delta_ms=0.0, ), ) - assert plan.chain_prefix_buckets - assert plan.chain_completion_buckets + assert any(plan.tree_chain_buckets_by_depth) hidden, output_grad = _packed_hidden_and_grad(case, cp_size) ref_hidden = hidden.clone().detach().requires_grad_(True) ref_out, _ = run_gdn_layer( @@ -215,13 +205,13 @@ def _native_gdn_cp_packed_layer_worker( def _native_gdn_cp_prepared_varlen_worker( rank: int, cp_size: int, - port: int, + init_method: str, output_dir: str, ) -> None: torch.cuda.set_device(rank) init_process_group( backend="nccl", - init_method=f"tcp://127.0.0.1:{port}", + init_method=init_method, rank=rank, world_size=cp_size, ) @@ -599,9 +589,3 @@ def _all_reduce_parameter_grads(module: torch.nn.Module) -> None: main_grad = getattr(parameter, "main_grad", None) if main_grad is not None: torch.distributed.all_reduce(main_grad) - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) diff --git a/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_tp_lora.py b/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_tp_lora.py index c4bd99abc..62e217daf 100644 --- a/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_tp_lora.py +++ b/tests/integration/megatron/gdn_shared_prefix/test_real_gdn_tp_lora.py @@ -1,7 +1,6 @@ from __future__ import annotations from pathlib import Path -import socket import pytest @@ -26,6 +25,7 @@ from art.megatron.model_support.handlers import QWEN3_5_MOE_HANDLER # noqa: E402 from .cases import GdnPhase0Case, default_phase0_cases # noqa: E402 +from .distributed_init import file_init_method # noqa: E402 from .metrics import GDN_CORRECTNESS_DTYPE, assert_real_gdn_metrics # noqa: E402 from .packed_layout import build_phase0_packed_tensors # noqa: E402 from .real_gdn_oracle 
import ( # noqa: E402 @@ -68,10 +68,10 @@ def test_real_qwen35_gdn_lora_gradients_match_flattened() -> None: reason="At least two CUDA devices are required for TP2 GDN coverage.", ) def test_real_qwen35_gdn_tp2_gradients_match_flattened(tmp_path: Path) -> None: - port = _find_free_port() + init_method = file_init_method(tmp_path, "real_gdn_tp2_lora") mp.spawn( _tp2_worker, - args=(port, str(tmp_path)), + args=(init_method, str(tmp_path)), nprocs=2, join=True, ) @@ -79,11 +79,11 @@ def test_real_qwen35_gdn_tp2_gradients_match_flattened(tmp_path: Path) -> None: assert (tmp_path / f"rank_{rank}.ok").read_text() == "ok\n" -def _tp2_worker(rank: int, port: int, output_dir: str) -> None: +def _tp2_worker(rank: int, init_method: str, output_dir: str) -> None: torch.cuda.set_device(rank) init_process_group( backend="nccl", - init_method=f"tcp://127.0.0.1:{port}", + init_method=init_method, rank=rank, world_size=2, ) @@ -229,9 +229,3 @@ def _gdn_lora_grad_names(gdn: torch.nn.Module) -> tuple[str, ...]: and parameter.grad is not None and bool(parameter.grad.abs().max().item() > 0) ) - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - return int(sock.getsockname()[1]) diff --git a/tests/integration/megatron/lora/test_dynamic_lora_slots.py b/tests/integration/megatron/lora/test_dynamic_lora_slots.py new file mode 100644 index 000000000..ce3b3b4d1 --- /dev/null +++ b/tests/integration/megatron/lora/test_dynamic_lora_slots.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +from contextlib import contextmanager +import os +import socket +from types import SimpleNamespace + +import pytest + +torch = pytest.importorskip("torch") +pytest.importorskip("megatron.core") + +from megatron.core import parallel_state as ps # noqa: E402 +from torch.distributed import destroy_process_group, init_process_group # noqa: E402 + +from art.megatron.lora import LoRA, LoRASlotRef, use_lora_slot # noqa: E402 
+from art.megatron.trainer_rank import AdamParams, TrainerRank # noqa: E402 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required.") +def test_dynamic_lora_slots_capture_recompute_context_and_step_independently() -> None: + with _single_rank_model_parallel(): + device = torch.device("cuda") + lora = LoRA( + "dense", + in_features=4, + out_features=5, + rank=2, + alpha=32, + dtype=torch.float32, + device=device, + ) + ref_a = LoRASlotRef("checkpoint", "A") + ref_b = LoRASlotRef("checkpoint", "B") + lora.load_lora_slot( + ref_a, _adapter("dense", rank=1, seed=1), requires_grad=True + ) + lora.load_lora_slot( + ref_b, _adapter("dense", rank=4, seed=2), requires_grad=True + ) + + x = torch.randn(7, 4, device=device) + with use_lora_slot(LoRASlotRef("checkpoint", None)): + assert torch.equal(lora(x), torch.zeros(7, 5, device=device)) + with use_lora_slot(LoRASlotRef("lora", "missing")): + assert torch.equal(lora(x), torch.zeros(7, 5, device=device)) + + slot_a = lora._slot(ref_a) + assert slot_a is not None + with use_lora_slot(ref_a): + actual = lora(x) + expected = (x @ slot_a.A_T) @ slot_a.B_T * slot_a.scale + assert torch.allclose(actual, expected, atol=0, rtol=0) + assert slot_a.rank == 1 + assert slot_a.scale == 32.0 + assert lora._slot(ref_b).scale == 8.0 # type: ignore[union-attr] + + trainer = _trainer_for(lora, device) + with trainer.push_checkpoint("A"): + assert trainer._slot_stack[-1] == ref_a + with trainer.push_lora(None): + assert trainer._slot_stack[-1].name is None + assert trainer._slot_stack[-1] == ref_a + assert trainer._slot_stack == [] + + from megatron.core.tensor_parallel.random import ( + checkpoint as megatron_checkpoint, + ) + from torch.utils.checkpoint import checkpoint as torch_checkpoint + + _assert_checkpoint_recomputes_with(ref_a, ref_b, lora, torch_checkpoint) + _assert_checkpoint_recomputes_with( + ref_a, ref_b, lora, megatron_checkpoint, False + ) + _assert_step_updates_only(ref_a, ref_b, lora, trainer) + 
_assert_reload_replaces_slot_optimizer(ref_a, lora, trainer) + + +def _adapter(prefix: str, *, rank: int, seed: int) -> dict[str, torch.Tensor]: + device = torch.device("cuda") + generator = torch.Generator(device=device).manual_seed(seed) + return { + f"{prefix}.lora_A.weight": torch.randn( + rank, 4, generator=generator, device=device + ), + f"{prefix}.lora_B.weight": torch.randn( + 5, rank, generator=generator, device=device + ), + } + + +def _assert_checkpoint_recomputes_with( + expected_ref: LoRASlotRef, + ambient_ref: LoRASlotRef, + lora: LoRA, + checkpoint, + *checkpoint_args, +) -> None: + for param in lora.parameters(): + param.grad = None + x = torch.randn(3, 4, device="cuda", requires_grad=True) + with use_lora_slot(expected_ref): + y = checkpoint(lambda t: lora(t), *checkpoint_args, x) + with use_lora_slot(ambient_ref): + y.sum().backward() + assert lora._slot(expected_ref).A_T.grad is not None # type: ignore[union-attr] + assert lora._slot(ambient_ref).A_T.grad is None # type: ignore[union-attr] + + +def _assert_step_updates_only( + stepped_ref: LoRASlotRef, + frozen_ref: LoRASlotRef, + lora: LoRA, + trainer: TrainerRank, +) -> None: + for param in lora.parameters(): + param.grad = None + with use_lora_slot(stepped_ref): + lora(torch.randn(5, 4, device="cuda")).sum().backward() + before_stepped = [p.detach().clone() for p in lora.lora_slot_params(stepped_ref)] + before_frozen = [p.detach().clone() for p in lora.lora_slot_params(frozen_ref)] + trainer.optim_step( + params=AdamParams(learning_rate=1e-3, weight_decay=0.0, grad_clip_norm=1.0), + checkpoints=[stepped_ref.name or ""], + ) + assert any( + not torch.equal(before, after) + for before, after in zip( + before_stepped, lora.lora_slot_params(stepped_ref), strict=True + ) + ) + assert all( + torch.equal(before, after) + for before, after in zip( + before_frozen, lora.lora_slot_params(frozen_ref), strict=True + ) + ) + + +def _assert_reload_replaces_slot_optimizer( + ref: LoRASlotRef, + lora: LoRA, + 
trainer: TrainerRank, +) -> None: + assert ref.name is not None + old_params = trainer._checkpoint_slot_params_by_name[ref.name] + assert ref.name in trainer._dynamic_optimizers + + trainer.load_checkpoint_slot(ref.name, _adapter("dense", rank=3, seed=9)) + + new_params = trainer._checkpoint_slot_params_by_name[ref.name] + assert ref.name not in trainer._dynamic_optimizers + assert [tuple(param.shape) for param in new_params] == [(4, 3), (3, 5)] + assert all(old is not new for old, new in zip(old_params, new_params, strict=True)) + assert lora._slot(ref).rank == 3 # type: ignore[union-attr] + + +def _trainer_for(lora: LoRA, device: torch.device) -> TrainerRank: + trainer = TrainerRank.__new__(TrainerRank) + trainer.runtime = SimpleNamespace(model=[lora], optimizer=None) + trainer.device = device + trainer._slot_stack = [] + trainer._default_slot_ref = None + trainer._dynamic_optimizers = {} + trainer._checkpoint_slot_params_by_name = { + "A": tuple(lora.lora_slot_params(LoRASlotRef("checkpoint", "A"))), + "B": tuple(lora.lora_slot_params(LoRASlotRef("checkpoint", "B"))), + } + return trainer + + +@contextmanager +def _single_rank_model_parallel(): + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ["MASTER_PORT"] = str(_free_port()) + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + os.environ.setdefault("LOCAL_RANK", "0") + torch.cuda.set_device(0) + init_process_group("nccl", rank=0, world_size=1) + try: + ps.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + expert_model_parallel_size=1, + ) + yield + finally: + if getattr(ps, "model_parallel_is_initialized", lambda: False)(): + ps.destroy_model_parallel() + destroy_process_group() + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) diff --git 
a/tests/integration/megatron/lora/test_lora_disk_codecs.py b/tests/integration/megatron/lora/test_lora_disk_codecs.py index b14cd2a4c..7bb3e1b94 100644 --- a/tests/integration/megatron/lora/test_lora_disk_codecs.py +++ b/tests/integration/megatron/lora/test_lora_disk_codecs.py @@ -1,12 +1,17 @@ import json +import os from pathlib import Path +import shutil import subprocess import sys from typing import Any, cast +import pytest from safetensors.torch import load_file, save_file import torch +pytest.importorskip("megatron.bridge.models.gpt_provider") + from art.megatron import lora as lora_module from art.megatron.lora import LoRA, LoRAParallelSpec, LoRAPublishPlanner from art.megatron.model_support.handlers import ( @@ -29,6 +34,66 @@ REPO_ROOT = Path(__file__).parents[4] VLLM_PYTHON = REPO_ROOT / "vllm_runtime/.venv/bin/python" +_VLLM_RUNTIME_UNAVAILABLE_REASON: str | None | object = object() + + +def _vllm_python_cmd() -> list[str]: + override = os.environ.get("ART_TEST_VLLM_PYTHON") + if override: + return [override] + if VLLM_PYTHON.exists(): + return [str(VLLM_PYTHON)] + uv = shutil.which("uv") + if uv is None: + raise RuntimeError( + f"{VLLM_PYTHON} does not exist and uv is not available to run " + "the locked vLLM runtime project" + ) + return [ + uv, + "run", + "--project", + str(REPO_ROOT / "vllm_runtime"), + "--frozen", + "--no-dev", + "python", + ] + + +def _vllm_runtime_unavailable_reason() -> str | None: + global _VLLM_RUNTIME_UNAVAILABLE_REASON + if isinstance(_VLLM_RUNTIME_UNAVAILABLE_REASON, str): + return _VLLM_RUNTIME_UNAVAILABLE_REASON + if _VLLM_RUNTIME_UNAVAILABLE_REASON is None: + return None + try: + subprocess.run( + [ + *_vllm_python_cmd(), + "-c", + "import vllm; from vllm.lora.lora_model import LoRAModel", + ], + check=True, + text=True, + capture_output=True, + timeout=120, + ) + except Exception as exc: + _VLLM_RUNTIME_UNAVAILABLE_REASON = ( + "Stock vLLM loader runtime is unavailable. 
Run " + "`uv sync --project vllm_runtime --frozen --no-dev`, or set " + "`ART_TEST_VLLM_PYTHON` to a Python environment with vLLM installed. " + f"Original error: {exc}" + ) + return _VLLM_RUNTIME_UNAVAILABLE_REASON + _VLLM_RUNTIME_UNAVAILABLE_REASON = None + return None + + +def test_stock_vllm_loader_runtime_is_available() -> None: + reason = _vllm_runtime_unavailable_reason() + if reason is not None: + pytest.fail(reason) def _config(base_model: str, rank: int = 2, alpha: int = 4) -> dict: @@ -116,6 +181,8 @@ def _assert_stock_vllm_loads( expected_modules: set[str], mapper: str = "none", ) -> list[str]: + if reason := _vllm_runtime_unavailable_reason(): + pytest.skip(reason) script = r""" import json import sys @@ -142,7 +209,7 @@ def _assert_stock_vllm_loads( """ result = subprocess.run( [ - str(VLLM_PYTHON), + *_vllm_python_cmd(), "-c", script, str(path), diff --git a/tests/integration/megatron/model_support/forward_trace.py b/tests/integration/megatron/model_support/forward_trace.py index 289b8b7a6..30731cbdd 100644 --- a/tests/integration/megatron/model_support/forward_trace.py +++ b/tests/integration/megatron/model_support/forward_trace.py @@ -1118,6 +1118,7 @@ def _canonicalize_row_aligned_value( def _canonicalize_call_row_token_order(cls, call: dict[str, Any]) -> None: """Canonicalizes all row-aligned call tensors to global token order.""" cls._align_exact_zero_padding_row_token_uids(call) + cls._drop_exact_zero_padding_rows(call) row_token_uids = call.get("row_token_uids") if not isinstance(row_token_uids, torch.Tensor) or row_token_uids.ndim != 1: return @@ -1138,6 +1139,38 @@ def _canonicalize_call_row_token_order(cls, call: dict[str, Any]) -> None: ) call["row_token_uids"] = row_token_uids.index_select(0, order).contiguous() + @classmethod + def _drop_exact_zero_padding_rows(cls, call: dict[str, Any]) -> None: + """Removes traced sequence-padding rows before comparing compact CP traces.""" + row_token_uids = call.get("row_token_uids") + tensor = 
call.get("primary_output") + if ( + not isinstance(row_token_uids, torch.Tensor) + or row_token_uids.ndim != 1 + or not isinstance(tensor, torch.Tensor) + or tensor.ndim == 0 + or int(tensor.shape[0]) != int(row_token_uids.numel()) + ): + return + row_count = int(row_token_uids.numel()) + padding_rows = row_token_uids < 0 + if row_count == 0 or not bool(padding_rows.any().item()): + return + flat = tensor.detach().reshape(row_count, -1) + if not bool((flat[padding_rows] == 0).all().item()): + return + valid_rows = torch.nonzero(~padding_rows, as_tuple=False).reshape(-1) + original_call = dict(call) + for key, value in original_call.items(): + if key == "row_token_uids": + continue + call[key] = cls._slice_row_aligned_value( + value, + row_indices=valid_rows, + total_rows=row_count, + ) + call["row_token_uids"] = row_token_uids.index_select(0, valid_rows).contiguous() + @staticmethod def _align_exact_zero_padding_row_token_uids(call: dict[str, Any]) -> None: """Moves padding UID markers onto exact-zero sequence-parallel pad rows.""" diff --git a/tests/integration/megatron/model_support/oracle_worker.py b/tests/integration/megatron/model_support/oracle_worker.py index 86ad2fb0e..50141edde 100644 --- a/tests/integration/megatron/model_support/oracle_worker.py +++ b/tests/integration/megatron/model_support/oracle_worker.py @@ -1117,12 +1117,16 @@ def _reference_forward( def _reference_fc1_forward(self: Any, x: torch.Tensor, tokens_per_expert: Any): base_out, bias_out = self.linear_fc1(x, tokens_per_expert) - adapter_out = torch.cat( - ( - self.gate_lora(x, tokens_per_expert), - self.up_lora(x, tokens_per_expert), - ), - dim=1, + adapter_out = ( + self.lora(x, tokens_per_expert) + if self.fused_gate_up + else torch.cat( + ( + self.gate_lora(x, tokens_per_expert), + self.up_lora(x, tokens_per_expert), + ), + dim=1, + ) ) return base_out + adapter_out, bias_out diff --git a/tests/integration/megatron/model_support/test_oracle_harness_invariants.py 
b/tests/integration/megatron/model_support/test_oracle_harness_invariants.py index 5a45bc03a..043736553 100644 --- a/tests/integration/megatron/model_support/test_oracle_harness_invariants.py +++ b/tests/integration/megatron/model_support/test_oracle_harness_invariants.py @@ -314,6 +314,37 @@ def test_forward_trace_canonicalizes_row_outputs_by_token_uid() -> None: ) +def test_forward_trace_drops_exact_zero_padding_rows() -> None: + trace: dict[str, list[dict[str, Any]]] = { + "chunk0.module.decoder.layers.0.self_attention.out_proj": [ + { + "primary_output": torch.tensor( + [[0.0, 0.0], [30.0, 31.0], [10.0, 11.0], [20.0, 21.0]] + ), + "output": { + "hidden": torch.tensor( + [[0.0, 0.0], [3.0, 3.1], [1.0, 1.1], [2.0, 2.1]] + ) + }, + "row_token_uids": torch.tensor([-1, 3, 1, 2]), + } + ] + } + + ForwardTraceCapture.canonicalize_trace(trace) + + call = trace["chunk0.module.decoder.layers.0.self_attention.out_proj"][0] + assert torch.equal(call["row_token_uids"], torch.tensor([1, 2, 3])) + assert torch.equal( + call["primary_output"], + torch.tensor([[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]]), + ) + assert torch.equal( + call["output"]["hidden"], + torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]), + ) + + def test_forward_trace_expands_attention_output_uids_for_out_norm_heads() -> None: trace: dict[str, list[dict[str, Any]]] = { "chunk0.module.decoder.layers.0.self_attention": [ diff --git a/tests/unit/test_shared_prefix_attention_builder.py b/tests/unit/test_shared_prefix_attention_builder.py new file mode 100644 index 000000000..34992bac9 --- /dev/null +++ b/tests/unit/test_shared_prefix_attention_builder.py @@ -0,0 +1,586 @@ +from __future__ import annotations + +import pytest +import torch +from torch.nn.attention.flex_attention import BlockMask +from torch.nn.attention.flex_attention import create_block_mask as torch_block_mask + +pytest.importorskip("megatron.core.packed_seq_params") + +from art.megatron.context_parallel.block_mask import ( + 
build_block_mask_from_context, + prepare_block_mask_context, +) +from art.megatron.context_parallel.builder import ( + build_dense_reference_mask, + build_shared_prefix_attention_spec, +) +from art.megatron.context_parallel.runtime import get_or_build_runtime_plan +from art.megatron.context_parallel.types import ( + AttnMaskKind, + AttnSlice, + ContextParallelConfig, + ExactMaskMetadata, + FlexMaskSpec, + ParallelTopology, + TokenRange, +) +from art.megatron.shared_prefix_packing import SharedPrefixPack, pack_shared_prefixes +from art.megatron.shared_prefix_state import create_shared_prefix_state + + +def build_block_mask( + spec: FlexMaskSpec, + *, + group_ids: torch.Tensor, + parent_ids: torch.Tensor, + device: torch.device, +) -> BlockMask | None: + return build_block_mask_from_context( + spec, + context=prepare_block_mask_context( + group_ids=group_ids, + parent_ids=parent_ids, + ), + device=device, + ) + + +def test_shared_prefix_attention_spec_supports_branching_completions() -> None: + group_ids, parent_ids = _branching_prefix_inputs() + + spec = build_shared_prefix_attention_spec( + group_ids=group_ids, + parent_ids=parent_ids, + ) + dense = build_dense_reference_mask(row_spec=spec.rows[0]) + + assert dense.int().tolist() == [ + [1, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 1, 0, 0], + [1, 1, 1, 0, 0, 1, 0], + [1, 1, 1, 0, 0, 1, 1], + ] + + +def test_shared_prefix_attention_spec_matches_tree_reference() -> None: + group_ids, parent_ids = _branching_prefix_inputs() + + spec = build_shared_prefix_attention_spec( + group_ids=group_ids, + parent_ids=parent_ids, + ) + dense = build_dense_reference_mask(row_spec=spec.rows[0]) + + assert dense.equal(_reference_tree_mask(group_ids[0], parent_ids[0])) + + +def test_shared_prefix_can_build_context_parallel_layout() -> None: + group_ids, parent_ids = _branching_prefix_inputs() + spec = build_shared_prefix_attention_spec( + group_ids=group_ids, + 
parent_ids=parent_ids, + ) + + plan = get_or_build_runtime_plan( + spec, + topology=ParallelTopology(cp=2), + config=ContextParallelConfig(planner_chunk_size=2, planner_max_search_steps=1), + original_seq_len=int(group_ids.numel()), + ) + + assert sum(plan[rank].local_valid_lengths[0] for rank in range(2)) == int( + group_ids.numel() + ) + + +def test_sparse_block_mask_exact_predicate_matches_dense_reference() -> None: + group_ids, parent_ids = _branching_prefix_inputs() + spec = build_shared_prefix_attention_spec( + group_ids=group_ids, + parent_ids=parent_ids, + ) + row = spec.rows[0] + token_indices = torch.arange(row.valid_tokens, dtype=torch.long) + block_mask = build_block_mask( + FlexMaskSpec( + q_len=row.valid_tokens, + k_len=row.valid_tokens, + block_size=(2, 2), + slices=row.slices, + exact_mask=ExactMaskMetadata( + q_token_indices=token_indices, + k_token_indices=token_indices, + cache_key="depth-two", + ), + ), + group_ids=group_ids[0], + parent_ids=parent_ids[0], + device=torch.device("cpu"), + ) + + assert block_mask is not None + q_indices = torch.arange(row.valid_tokens)[:, None] + k_indices = torch.arange(row.valid_tokens)[None, :] + actual = block_mask.mask_mod( + torch.zeros_like(q_indices), + torch.zeros_like(q_indices), + q_indices, + k_indices, + ) + + assert actual.equal(build_dense_reference_mask(row_spec=row)) + + +@pytest.mark.parametrize( + ("name", "pack"), + ( + ( + "no-sharing", + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3]), + torch.tensor([4, 5]), + torch.tensor([6, 7, 8, 9]), + ), + max_depth=0, + ), + ), + ( + "depth-one", + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 2, 6]), + ), + max_depth=1, + ), + ), + ( + "depth-three", + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4, 8]), + torch.tensor([1, 2, 3, 4, 9]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6]), + ), + max_depth=3, + ), + ), + ), +) +def test_sparse_block_mask_matches_torch_block_metadata( 
+ name: str, + pack: SharedPrefixPack, +) -> None: + del name + spec = build_shared_prefix_attention_spec( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + row = spec.rows[0] + token_indices = torch.arange(row.valid_tokens, dtype=torch.long) + block_mask = build_block_mask( + FlexMaskSpec( + q_len=row.valid_tokens, + k_len=row.valid_tokens, + block_size=(2, 2), + slices=row.slices, + exact_mask=ExactMaskMetadata( + q_token_indices=token_indices, + k_token_indices=token_indices, + cache_key="torch-parity", + ), + ), + group_ids=pack.group_ids[0], + parent_ids=pack.parent_ids[0], + device=torch.device("cpu"), + ) + + assert block_mask is not None + _assert_matches_torch_block_mask(block_mask) + + +def test_sparse_block_mask_prunes_exact_blocks_rejected_by_group_tree() -> None: + group_ids = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2], dtype=torch.long) + parent_ids = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2], dtype=torch.long) + block_mask = build_block_mask( + FlexMaskSpec( + q_len=4, + k_len=4, + block_size=(2, 2), + slices=( + AttnSlice( + q_range=TokenRange(start=0, end=4), + k_range=TokenRange(start=0, end=4), + mask_kind=AttnMaskKind.CAUSAL, + row_index=0, + ), + ), + exact_mask=ExactMaskMetadata( + q_token_indices=torch.tensor([4, 5, 6, 7], dtype=torch.long), + k_token_indices=torch.tensor([0, 1, 2, 3], dtype=torch.long), + cache_key="all-false-cross-family", + ), + ), + group_ids=group_ids, + parent_ids=parent_ids, + device=torch.device("cpu"), + ) + + assert block_mask is not None + assert int(block_mask.kv_num_blocks.sum().item()) == 0 + assert int(block_mask.full_kv_num_blocks.sum().item()) == 0 + _assert_matches_torch_block_mask(block_mask) + + +def test_shared_prefix_state_builds_batched_block_mask() -> None: + group_ids = torch.tensor( + [ + [1, 1, 2, 2, -1], + [10, 11, 11, -1, -1], + ], + dtype=torch.long, + ) + parent_ids = torch.tensor( + [ + [1, 1, 1, 1, -1], + [10, 10, 10, -1, -1], + ], + dtype=torch.long, + ) + + state = 
create_shared_prefix_state( + group_ids=group_ids, + parent_ids=parent_ids, + target_device=torch.device("cpu"), + ) + seq_len = int(group_ids.shape[1]) + batch_idx = torch.arange(2)[:, None, None].expand(2, seq_len, seq_len) + query_idx = torch.arange(seq_len)[None, :, None].expand(2, seq_len, seq_len) + kv_idx = torch.arange(seq_len)[None, None, :].expand(2, seq_len, seq_len) + actual = state.block_mask.mask_mod( + batch_idx, + torch.zeros_like(batch_idx), + query_idx, + kv_idx, + ) + spec = build_shared_prefix_attention_spec( + group_ids=group_ids, + parent_ids=parent_ids, + ) + assert int(state.block_mask.kv_num_blocks.shape[0]) == 2 + for row_index, row_spec in enumerate(spec.rows): + valid_tokens = int(row_spec.valid_tokens) + assert actual[ + row_index, + :valid_tokens, + :valid_tokens, + ].equal(build_dense_reference_mask(row_spec=row_spec)) + _assert_matches_torch_block_mask(state.block_mask, batch_size=2) + + +def test_context_parallel_stage_masks_match_dense_nested_tree() -> None: + _assert_context_parallel_stage_masks_match_dense( + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4, 8]), + torch.tensor([1, 2, 3, 4, 9]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6]), + ), + max_depth=3, + ), + require_remote_stage=True, + ) + _assert_context_parallel_stage_masks_match_dense( + pack_shared_prefixes( + ( + torch.tensor([1, 2, 3]), + torch.tensor([4, 5, 6]), + torch.tensor([7, 8]), + torch.tensor([9, 10, 11, 12]), + ), + max_depth=3, + ), + require_remote_stage=False, + ) + + +def _assert_context_parallel_stage_masks_match_dense( + pack: SharedPrefixPack, + *, + require_remote_stage: bool, +) -> None: + spec = build_shared_prefix_attention_spec( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + row = spec.rows[0] + dense = build_dense_reference_mask(row_spec=row) + topology = ParallelTopology(cp=2) + config = ContextParallelConfig( + block_size=2, + planner_chunk_size=2, + planner_max_search_steps=1, + 
planner_remote_stage_token_floor=1, + planner_remote_stage_pair_floor=1, + ) + plan = get_or_build_runtime_plan( + spec, + topology=topology, + config=config, + original_seq_len=int(pack.tokens.numel()), + ) + + checked_stages = 0 + checked_remote_stages = 0 + for rank_plan in plan: + for stage in rank_plan.stage_plans: + if stage.mask_metadata is None: + continue + block_mask = build_block_mask( + FlexMaskSpec( + q_len=stage.q_len, + k_len=stage.k_len, + block_size=(2, 2), + slices=stage.slices, + exact_mask=stage.mask_metadata, + ), + group_ids=pack.group_ids[0], + parent_ids=pack.parent_ids[0], + device=torch.device("cpu"), + ) + assert block_mask is not None + q_offsets = torch.arange(stage.q_len)[:, None] + k_offsets = torch.arange(stage.k_len)[None, :] + actual = block_mask.mask_mod( + torch.zeros_like(q_offsets), + torch.zeros_like(q_offsets), + q_offsets, + k_offsets, + ) + q_tokens = stage.mask_metadata.q_token_indices + k_tokens = stage.mask_metadata.k_token_indices + expected = ( + dense[q_tokens.clamp_min(0)[:, None], k_tokens.clamp_min(0)[None, :]] + & (q_tokens[:, None] >= 0) + & (k_tokens[None, :] >= 0) + ) + + assert actual.equal(expected) + assert _effective_block_mask(block_mask).equal(expected) + _assert_matches_torch_block_mask(block_mask) + checked_stages += 1 + checked_remote_stages += int(not stage.is_local_stage) + + assert checked_stages + if require_remote_stage: + assert checked_remote_stages + + +def _effective_block_mask(block_mask: BlockMask) -> torch.Tensor: + q_len, k_len = block_mask.seq_lengths + q_block, k_block = block_mask.BLOCK_SIZE + effective = torch.zeros((q_len, k_len), dtype=torch.bool) + _fill_full_blocks(effective, block_mask, q_block=q_block, k_block=k_block) + _fill_partial_blocks(effective, block_mask, q_block=q_block, k_block=k_block) + return effective + + +def _fill_full_blocks( + effective: torch.Tensor, + block_mask: BlockMask, + *, + q_block: int, + k_block: int, +) -> None: + if block_mask.full_kv_num_blocks is 
None or block_mask.full_kv_indices is None: + return + for q_block_index in range(int(block_mask.full_kv_num_blocks.shape[-1])): + q_slice = slice(q_block_index * q_block, (q_block_index + 1) * q_block) + block_count = int(block_mask.full_kv_num_blocks[0, 0, q_block_index]) + for k_block_index in block_mask.full_kv_indices[ + 0, 0, q_block_index, :block_count + ].tolist(): + k_slice = slice( + int(k_block_index) * k_block, + (int(k_block_index) + 1) * k_block, + ) + effective[q_slice, k_slice] = True + + +def _fill_partial_blocks( + effective: torch.Tensor, + block_mask: BlockMask, + *, + q_block: int, + k_block: int, +) -> None: + for q_block_index in range(int(block_mask.kv_num_blocks.shape[-1])): + q_offsets = torch.arange( + q_block_index * q_block, + min((q_block_index + 1) * q_block, effective.shape[0]), + )[:, None] + block_count = int(block_mask.kv_num_blocks[0, 0, q_block_index]) + for k_block_index in block_mask.kv_indices[ + 0, 0, q_block_index, :block_count + ].tolist(): + k_offsets = torch.arange( + int(k_block_index) * k_block, + min((int(k_block_index) + 1) * k_block, effective.shape[1]), + )[None, :] + effective[q_offsets, k_offsets] |= block_mask.mask_mod( + torch.zeros_like(q_offsets), + torch.zeros_like(q_offsets), + q_offsets, + k_offsets, + ) + + +def test_sparse_block_mask_supports_non_monotonic_remote_k_indices() -> None: + q_token_indices = torch.tensor([4, 5, 6, 7], dtype=torch.long) + k_token_indices = torch.tensor([0, 1, 6, 2, 3, 4], dtype=torch.long) + block_mask = build_block_mask( + FlexMaskSpec( + q_len=int(q_token_indices.numel()), + k_len=int(k_token_indices.numel()), + block_size=(2, 2), + slices=( + AttnSlice( + q_range=TokenRange(start=0, end=int(q_token_indices.numel())), + k_range=TokenRange(start=0, end=int(k_token_indices.numel())), + mask_kind=AttnMaskKind.CAUSAL, + row_index=0, + ), + ), + exact_mask=ExactMaskMetadata( + q_token_indices=q_token_indices, + k_token_indices=k_token_indices, + cache_key="non-monotonic-k", + ), 
+ ), + group_ids=torch.ones(8, dtype=torch.long), + parent_ids=torch.ones(8, dtype=torch.long), + device=torch.device("cpu"), + ) + + assert block_mask is not None + q_indices = torch.arange(q_token_indices.numel())[:, None] + k_indices = torch.arange(k_token_indices.numel())[None, :] + + actual = block_mask.mask_mod( + torch.zeros_like(q_indices), + torch.zeros_like(q_indices), + q_indices, + k_indices, + ) + + assert actual.equal(q_token_indices[:, None] >= k_token_indices[None, :]) + _assert_matches_torch_block_mask(block_mask) + + +def _assert_matches_torch_block_mask( + block_mask: BlockMask, + *, + batch_size: int = 1, +) -> None: + q_len, k_len = block_mask.seq_lengths + reference = torch_block_mask( + block_mask.mask_mod, + B=batch_size, + H=1, + Q_LEN=q_len, + KV_LEN=k_len, + device="cpu", + BLOCK_SIZE=block_mask.BLOCK_SIZE, + ) + assert _effective_block_mask(block_mask).equal(_effective_block_mask(reference)) + for counts_name, indices_name in ( + ("kv_num_blocks", "kv_indices"), + ("full_kv_num_blocks", "full_kv_indices"), + ("q_num_blocks", "q_indices"), + ("full_q_num_blocks", "full_q_indices"), + ): + assert _block_entries(block_mask, counts_name, indices_name) == _block_entries( + reference, + counts_name, + indices_name, + ) + + +def _block_entries( + block_mask: BlockMask, + counts_name: str, + indices_name: str, +) -> set[tuple[int, int, int, int]]: + counts = getattr(block_mask, counts_name) + indices = getattr(block_mask, indices_name) + if counts is None or indices is None: + return set() + entries = set() + for batch_index in range(int(counts.shape[0])): + for head_index in range(int(counts.shape[1])): + for block_index in range(int(counts.shape[2])): + block_count = int(counts[batch_index, head_index, block_index]) + for other_block in indices[ + batch_index, + head_index, + block_index, + :block_count, + ].tolist(): + entries.add( + ( + batch_index, + head_index, + block_index, + int(other_block), + ) + ) + return entries + + +def 
_branching_prefix_inputs() -> tuple[torch.Tensor, torch.Tensor]: + return ( + torch.tensor([[1, 1, 1, 2, 3, 4, 4]], dtype=torch.long), + torch.tensor([[1, 1, 1, 1, 1, 1, 1]], dtype=torch.long), + ) + + +def _reference_tree_mask( + group_ids: torch.Tensor, parent_ids: torch.Tensor +) -> torch.Tensor: + group_list = [int(value) for value in group_ids.tolist()] + parent_by_group: dict[int, int | None] = {} + for group_id, parent_id in zip(group_list, parent_ids.tolist(), strict=True): + group_id = int(group_id) + parent_id = int(parent_id) + if group_id not in parent_by_group: + parent_by_group[group_id] = None if parent_id == group_id else parent_id + + ancestors_by_group = { + group_id: _ancestors(group_id, parent_by_group) for group_id in parent_by_group + } + dense = torch.zeros((len(group_list), len(group_list)), dtype=torch.bool) + for q_pos, q_group in enumerate(group_list): + allowed_groups = ancestors_by_group[q_group] | {q_group} + for k_pos, k_group in enumerate(group_list): + dense[q_pos, k_pos] = k_pos <= q_pos and k_group in allowed_groups + return dense + + +def _ancestors( + group_id: int, + parent_by_group: dict[int, int | None], +) -> set[int]: + ancestors: set[int] = set() + cursor = parent_by_group[group_id] + while cursor is not None and cursor not in ancestors: + ancestors.add(cursor) + cursor = parent_by_group.get(cursor) + return ancestors diff --git a/tests/unit/test_shared_prefix_grad_parity.py b/tests/unit/test_shared_prefix_grad_parity.py new file mode 100644 index 000000000..5b812782b --- /dev/null +++ b/tests/unit/test_shared_prefix_grad_parity.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +from copy import deepcopy + +import pytest +import torch +from torch import nn +import torch.nn.functional as F + +from art.megatron.shared_prefix_packing import SharedPrefixPack, pack_shared_prefixes + + +class _ToyCausalLM(nn.Module): + def __init__(self) -> None: + super().__init__() + self.token_embedding = nn.Embedding(32, 8, 
dtype=torch.float64) + self.position_embedding = nn.Embedding(8, 8, dtype=torch.float64) + self.mix = nn.Linear(8, 8, bias=False, dtype=torch.float64) + self.output = nn.Linear(8, 32, bias=False, dtype=torch.float64) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + causal_mask: torch.Tensor, + ) -> torch.Tensor: + states = self.token_embedding(input_ids) + self.position_embedding(position_ids) + context = causal_mask.to(states.dtype) @ states + return self.output(torch.tanh(self.mix(context))) + + +@pytest.mark.parametrize("max_depth", (1, 2, 3)) +@pytest.mark.parametrize("multi_target", (False, True)) +def test_shared_prefix_ce_parameter_grads_match_independent_sequences( + *, + max_depth: int, + multi_target: bool, +) -> None: + input_ids = _input_ids() + target_ids = tuple( + _targets(tokens, multi_target=multi_target) for tokens in input_ids + ) + pack = pack_shared_prefixes(input_ids, max_depth=max_depth) + + assert int(pack.tokens.numel()) < sum(len(row) for row in input_ids) + + torch.manual_seed(20260518) + naive_model = _ToyCausalLM() + packed_model = deepcopy(naive_model) + + naive_loss = torch.stack( + [ + _sequence_ce_loss(naive_model, tokens, labels) + for tokens, labels in zip(input_ids, target_ids, strict=True) + ] + ).sum() + packed_loss = _packed_ce_loss(packed_model, pack, target_ids) + + torch.testing.assert_close(packed_loss, naive_loss, rtol=1e-12, atol=1e-12) + naive_loss.backward() + packed_loss.backward() + + for (name, naive_param), packed_param in zip( + naive_model.named_parameters(), + packed_model.parameters(), + strict=True, + ): + assert naive_param.grad is not None, name + assert packed_param.grad is not None, name + torch.testing.assert_close( + packed_param.grad, + naive_param.grad, + rtol=1e-10, + atol=1e-10, + msg=lambda msg, name=name: f"{name} grad mismatch:\n{msg}", + ) + + +@pytest.mark.parametrize("max_depth", (1, 2, 3)) +def test_same_layout_mutation_preserves_forward_outputs(max_depth: 
int) -> None: + pack = pack_shared_prefixes(_input_ids(), max_depth=max_depth) + torch.manual_seed(20260518) + model = _ToyCausalLM() + logits = _packed_logits(model, pack) + + for positions in pack.positions_by_sequence: + mutated_logits = _packed_logits(model, _mutated_pack(pack, keep=positions)) + torch.testing.assert_close( + mutated_logits.index_select(0, positions), + logits.index_select(0, positions), + rtol=0.0, + atol=0.0, + ) + + +@pytest.mark.parametrize("max_depth", (1, 2, 3)) +@pytest.mark.parametrize("sequence_index", (0, 2, 4)) +def test_same_layout_mutation_preserves_target_loss_grads( + max_depth: int, + sequence_index: int, +) -> None: + input_ids = _input_ids() + target_ids = tuple(_targets(tokens, multi_target=True) for tokens in input_ids) + pack = pack_shared_prefixes(input_ids, max_depth=max_depth) + mutated = _mutated_pack(pack, keep=pack.positions_by_sequence[sequence_index]) + + torch.manual_seed(20260518) + base_model = _ToyCausalLM() + mutated_model = deepcopy(base_model) + + base_loss = _packed_sequence_ce_loss(base_model, pack, target_ids, sequence_index) + mutated_loss = _packed_sequence_ce_loss( + mutated_model, + mutated, + target_ids, + sequence_index, + ) + + torch.testing.assert_close(mutated_loss, base_loss, rtol=0.0, atol=0.0) + base_loss.backward() + mutated_loss.backward() + _assert_matching_grads(mutated_model, base_model) + + +def _input_ids() -> tuple[torch.Tensor, ...]: + return ( + torch.tensor([1, 2, 3, 4, 5]), + torch.tensor([1, 2, 3, 4, 6]), + torch.tensor([1, 2, 3, 7]), + torch.tensor([1, 2, 8]), + torch.tensor([9, 10, 11]), + ) + + +def _targets(tokens: torch.Tensor, *, multi_target: bool) -> torch.Tensor: + labels = (tokens * 3 + 5) % 31 + if not multi_target: + return labels + alternate = (tokens * 5 + 7) % 31 + stacked = torch.stack((labels, alternate), dim=1) + if int(stacked.numel()) > 2: + stacked[1, 1] = -100 + return stacked + + +def _sequence_ce_loss( + model: _ToyCausalLM, + input_ids: torch.Tensor, + 
target_ids: torch.Tensor, +) -> torch.Tensor: + seq_len = int(input_ids.numel()) + logits = model( + input_ids, + torch.arange(seq_len), + torch.ones((seq_len, seq_len), dtype=torch.bool).tril(), + ) + return _target_ce_loss(logits, target_ids) + + +def _packed_ce_loss( + model: _ToyCausalLM, + pack: SharedPrefixPack, + target_ids: tuple[torch.Tensor, ...], +) -> torch.Tensor: + logits = _packed_logits(model, pack) + losses = [ + _target_ce_loss(logits.index_select(0, positions), labels) + for positions, labels in zip( + pack.positions_by_sequence, + target_ids, + strict=True, + ) + ] + return torch.stack(losses).sum() + + +def _packed_sequence_ce_loss( + model: _ToyCausalLM, + pack: SharedPrefixPack, + target_ids: tuple[torch.Tensor, ...], + sequence_index: int, +) -> torch.Tensor: + return _target_ce_loss( + _packed_logits(model, pack).index_select( + 0, + pack.positions_by_sequence[sequence_index], + ), + target_ids[sequence_index], + ) + + +def _packed_logits(model: _ToyCausalLM, pack: SharedPrefixPack) -> torch.Tensor: + return model( + pack.tokens.reshape(-1), + pack.position_ids.reshape(-1), + _shared_prefix_causal_mask(pack), + ) + + +def _target_ce_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + if labels.ndim == 1: + return F.cross_entropy(logits, labels, ignore_index=-100, reduction="sum") + expanded = logits.unsqueeze(1).expand(-1, int(labels.shape[1]), -1) + return F.cross_entropy( + expanded.reshape(-1, int(logits.shape[-1])), + labels.reshape(-1), + ignore_index=-100, + reduction="sum", + ) + + +def _mutated_pack(pack: SharedPrefixPack, *, keep: torch.Tensor) -> SharedPrefixPack: + tokens = pack.tokens.clone() + mutate = torch.ones(int(tokens.shape[1]), dtype=torch.bool) + mutate[keep] = False + replacement = torch.arange(int(tokens.shape[1]), dtype=tokens.dtype) + 17 + tokens[0, mutate] = replacement[mutate] % 31 + return SharedPrefixPack( + tokens=tokens, + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + 
position_ids=pack.position_ids, + positions_by_sequence=pack.positions_by_sequence, + ) + + +def _assert_matching_grads(actual_model: nn.Module, expected_model: nn.Module) -> None: + for (name, expected_param), actual_param in zip( + expected_model.named_parameters(), + actual_model.parameters(), + strict=True, + ): + assert expected_param.grad is not None, name + assert actual_param.grad is not None, name + torch.testing.assert_close( + actual_param.grad, + expected_param.grad, + rtol=1e-10, + atol=1e-10, + msg=lambda msg, name=name: f"{name} grad mismatch:\n{msg}", + ) + + +def _shared_prefix_causal_mask(pack: SharedPrefixPack) -> torch.Tensor: + group_ids = pack.group_ids.reshape(-1).tolist() + parent_ids = pack.parent_ids.reshape(-1).tolist() + position_ids = pack.position_ids.reshape(-1).tolist() + parent_by_group: dict[int, int] = {} + for group_id, parent_id in zip(group_ids, parent_ids, strict=True): + previous = parent_by_group.setdefault(group_id, parent_id) + assert previous == parent_id + + ancestors = { + group_id: _ancestor_groups(group_id, parent_by_group) + for group_id in parent_by_group + } + mask = torch.zeros((len(group_ids), len(group_ids)), dtype=torch.bool) + for query_index, query_group in enumerate(group_ids): + query_ancestors = ancestors[query_group] + query_position = position_ids[query_index] + for key_index, key_group in enumerate(group_ids): + if ( + key_group in query_ancestors + and position_ids[key_index] <= query_position + ): + mask[query_index, key_index] = True + return mask + + +def _ancestor_groups(group_id: int, parent_by_group: dict[int, int]) -> set[int]: + ancestors = {group_id} + parent_id = parent_by_group[group_id] + while parent_id != group_id: + if parent_id in ancestors: + raise AssertionError("shared-prefix group parents contain a cycle") + ancestors.add(parent_id) + group_id = parent_id + parent_id = parent_by_group[group_id] + return ancestors diff --git a/tests/unit/test_shared_prefix_packing.py 
b/tests/unit/test_shared_prefix_packing.py new file mode 100644 index 000000000..6243dd3da --- /dev/null +++ b/tests/unit/test_shared_prefix_packing.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import pytest +import torch + +from art.megatron.shared_prefix_packing import ( + estimate_shared_prefix_packed_tokens, + pack_shared_prefixes, +) +from art.megatron.trainer_rank import _local_position_pairs + + +def test_pack_shared_prefixes_support_depth_one() -> None: + inputs = ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 5]), + torch.tensor([9]), + ) + + pack = pack_shared_prefixes(inputs, max_depth=1) + + assert pack.tokens.tolist() == [[1, 2, 3, 4, 5, 9]] + assert pack.group_ids.tolist() == [[1, 1, 2, 2, 3, 4]] + assert pack.parent_ids.tolist() == [[1, 1, 1, 1, 1, 4]] + assert pack.position_ids.tolist() == [[0, 1, 2, 3, 2, 0]] + assert [positions.tolist() for positions in pack.positions_by_sequence] == [ + [0, 1, 2, 3], + [0, 1, 4], + [5], + ] + + +def test_pack_shared_prefixes_support_zero_depth_without_sharing() -> None: + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2]), + torch.tensor([1, 3]), + torch.tensor([4]), + ), + max_depth=0, + ) + + assert pack.tokens.tolist() == [[1, 2, 1, 3, 4]] + assert pack.group_ids.tolist() == [[1, 1, 2, 2, 3]] + assert pack.parent_ids.tolist() == [[1, 1, 2, 2, 3]] + assert pack.position_ids.tolist() == [[0, 1, 0, 1, 0]] + assert [positions.tolist() for positions in pack.positions_by_sequence] == [ + [0, 1], + [2, 3], + [4], + ] + + +def test_pack_shared_prefixes_support_deeper_trees() -> None: + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6, 7]), + ), + max_depth=2, + ) + + assert pack.tokens.tolist() == [[1, 2, 3, 4, 5, 6, 7]] + assert pack.group_ids.tolist() == [[1, 2, 2, 3, 4, 5, 5]] + assert pack.parent_ids.tolist() == [[1, 1, 1, 2, 2, 1, 1]] + assert pack.position_ids.tolist() == [[0, 1, 2, 3, 3, 1, 2]] + assert 
[positions.tolist() for positions in pack.positions_by_sequence] == [ + [0, 1, 2, 3], + [0, 1, 2, 4], + [0, 5, 6], + ] + + +def test_packing_preserves_first_seen_branch_order() -> None: + pack = pack_shared_prefixes( + (torch.tensor([9]), torch.tensor([1])), + max_depth=1, + ) + + assert pack.tokens.tolist() == [[9, 1]] + assert [positions.tolist() for positions in pack.positions_by_sequence] == [ + [0], + [1], + ] + + +def test_packing_handles_empty_sequences() -> None: + pack = pack_shared_prefixes( + (torch.empty(0, dtype=torch.long), torch.empty(0, dtype=torch.long)), + max_depth=1, + ) + + assert pack.tokens.tolist() == [[]] + assert pack.group_ids.tolist() == [[]] + assert pack.parent_ids.tolist() == [[]] + assert [positions.tolist() for positions in pack.positions_by_sequence] == [[], []] + + +def test_packed_token_estimator_matches_real_packing() -> None: + cases = [ + (torch.tensor([1, 2, 3]), torch.tensor([1, 2, 4]), torch.tensor([5])), + ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 2, 6, 7]), + torch.tensor([1, 8]), + ), + ( + torch.tensor([9, 1, 2]), + torch.tensor([9, 1, 3]), + torch.tensor([9, 4, 5]), + torch.tensor([6, 7]), + torch.tensor([], dtype=torch.long), + ), + ] + + for inputs in cases: + for depth in range(5): + pack = pack_shared_prefixes(inputs, max_depth=depth) + + assert estimate_shared_prefix_packed_tokens(inputs, max_depth=depth) == int( + pack.tokens.numel() + ) + + +def test_packed_token_estimator_matches_randomized_packing() -> None: + generator = torch.Generator().manual_seed(123) + inputs = [] + for family in range(5): + prefix = torch.randint(1, 100, (4,), generator=generator) + for branch in range(4): + middle = torch.tensor([family, branch]) + suffix = torch.randint(1, 100, (3,), generator=generator) + inputs.append(torch.cat((prefix, middle, suffix))) + + for depth in range(5): + pack = pack_shared_prefixes(inputs, max_depth=depth) + + assert estimate_shared_prefix_packed_tokens(inputs, 
max_depth=depth) == int( + pack.tokens.numel() + ) + + +def test_packing_rejects_non_1d_sequences() -> None: + with pytest.raises(ValueError, match="expects 1-D tensors"): + pack_shared_prefixes((torch.tensor([[1, 2], [3, 4]]),), max_depth=1) + + +def test_local_position_pairs_preserve_requested_order_without_dense_match() -> None: + local_global_positions = torch.tensor([[2, -1, 0, 4, 1]]) + item_positions = torch.tensor([0, 1, 2, 3, 4]) + + local_positions, source_positions = _local_position_pairs( + local_global_positions, + item_positions, + ) + + assert local_positions.tolist() == [2, 4, 0, 3] + assert source_positions.tolist() == [0, 1, 2, 4] diff --git a/tests/unit/test_shared_prefix_tree.py b/tests/unit/test_shared_prefix_tree.py new file mode 100644 index 000000000..ce95c4fe1 --- /dev/null +++ b/tests/unit/test_shared_prefix_tree.py @@ -0,0 +1,502 @@ +from __future__ import annotations + +import pytest +import torch + +from art.megatron.shared_prefix_packing import pack_shared_prefixes +from art.megatron.shared_prefix_tree import parse_shared_prefix_row + + +def test_parse_shared_prefix_row_tracks_ancestors_and_depth() -> None: + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4, 8]), + torch.tensor([1, 2, 3, 4, 9]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6]), + ), + max_depth=3, + ) + + tree = parse_shared_prefix_row( + group_ids=pack.group_ids[0], + parent_ids=pack.parent_ids[0], + ) + + assert tree.valid_tokens == int(pack.tokens.numel()) + assert max(segment.depth for segment in tree.segments) == 3 + assert [(segment.group_id, segment.ancestors) for segment in tree.segments] == [ + (1, ()), + (2, (1,)), + (3, (1, 2)), + (4, (1, 2, 3)), + (5, (1, 2, 3)), + (6, (1, 2)), + (7, (1,)), + ] + + +def test_parse_shared_prefix_row_rejects_missing_parent() -> None: + with pytest.raises(RuntimeError, match="missing parent"): + parse_shared_prefix_row( + group_ids=torch.tensor([1, 2]), + parent_ids=torch.tensor([1, 3]), + ) + + +def 
test_parse_shared_prefix_row_rejects_non_contiguous_group() -> None: + with pytest.raises(RuntimeError, match="contiguous group runs"): + parse_shared_prefix_row( + group_ids=torch.tensor([1, 2, 1]), + parent_ids=torch.tensor([1, 1, 1]), + ) + + +def test_gdn_tree_parser_accepts_nested_tree() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + GdnPlannerConfig, + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4]), + torch.tensor([1, 2, 3, 5]), + torch.tensor([1, 6]), + ), + max_depth=2, + ) + + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plan = build_gdn_rank_execution_plan(spec, device="cpu") + + assert spec.tree_parent_indices == (-1, 0, 1, 1, 0) + assert spec.tree_depths == (0, 1, 2, 2, 1) + assert [ + sum(bucket.segment_count for bucket in buckets) + for buckets in plan.tree_segment_buckets_by_depth + ] == [1, 2, 2] + + +def test_gdn_tree_parser_accepts_zero_depth_roots() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2]), + torch.tensor([1, 3]), + torch.tensor([4]), + ), + max_depth=0, + ) + + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plan = build_gdn_rank_execution_plan(spec, device="cpu") + + assert spec.tree_parent_indices == (-1, -1, -1) + assert spec.tree_depths == (0, 0, 0) + assert [bucket.segment_count for bucket in plan.tree_segment_buckets_by_depth[0]] + assert not hasattr(plan, "local_prefix_buckets") + assert not hasattr(plan, "chain_completion_buckets") + assert not hasattr(plan, "prefix_boundary_buckets") + assert all( + not bucket.needs_final_state for bucket in 
plan.tree_segment_buckets_by_depth[0] + ) + + +def test_gdn_tree_planner_splits_leaf_and_internal_final_state_buckets() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + GdnPlannerConfig, + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + pack = pack_shared_prefixes( + ( + torch.tensor([1, 2, 3, 4, 7]), + torch.tensor([1, 2, 3, 4, 8]), + torch.tensor([1, 2, 5, 6]), + ), + max_depth=2, + ) + + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plan = build_gdn_rank_execution_plan( + spec, + device="cpu", + planner_config=GdnPlannerConfig(max_padding_ratio=4.0), + ) + tree_has_children = _tree_has_children(spec) + + depth_one_buckets = plan.tree_segment_buckets_by_depth[1] + assert any(bucket.needs_final_state for bucket in depth_one_buckets) + assert any(not bucket.needs_final_state for bucket in depth_one_buckets) + for bucket in depth_one_buckets: + expected = { + tree_has_children[family_index] + for family_index in bucket.family_indices.tolist() + } + assert expected == {bucket.needs_final_state} + + +def test_gdn_tree_cp_plan_chains_long_nodes() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + GdnPlannerConfig, + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + root = torch.arange(1, 321) + mid = torch.arange(1001, 1321) + other = torch.arange(2001, 2321) + pack = pack_shared_prefixes( + ( + torch.cat((root, mid, torch.tensor([11]))), + torch.cat((root, mid, torch.tensor([12]))), + torch.cat((root, other, torch.tensor([13]))), + ), + max_depth=3, + ) + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + config = _chain_every_legal_segment_config() + plans = tuple( + build_gdn_rank_execution_plan( + spec, + device="cpu", + cp_rank=rank, + cp_size=4, + 
planner_config=config, + ) + for rank in range(4) + ) + + assert _covered_token_indices(plans) == set(range(spec.real_token_count)) + assert any(plans[0].tree_chain_buckets_by_depth[0]) + assert not any( + bucket + for plan in plans + for depth_buckets in plan.tree_chain_buckets_by_depth[1:] + for bucket in depth_buckets + ) + _assert_remote_parent_state_transfers_cover(spec, plans) + for plan in plans: + assert sum(plan.gdn_token_count for plan in plans) == spec.real_token_count + for depth_buckets in plan.tree_chain_buckets_by_depth: + for bucket in depth_buckets: + assert bucket.lengths_by_rank_cpu is not None + assert tuple(bucket.lengths_by_rank_cpu.shape)[0] == 4 + assert bucket.parent_indices is not None + + +def test_gdn_tree_cp_plan_exchanges_remote_parent_states() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + root = torch.arange(1, 17) + mid = torch.arange(1001, 1321) + pack = pack_shared_prefixes( + ( + torch.cat((root, mid, torch.tensor([11]))), + torch.cat((root, mid, torch.tensor([12]))), + torch.cat((root, torch.tensor([99]))), + ), + max_depth=2, + ) + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plans = tuple( + build_gdn_rank_execution_plan( + spec, + device="cpu", + cp_rank=rank, + cp_size=4, + planner_config=_chain_every_legal_segment_config(), + ) + for rank in range(4) + ) + assert _covered_token_indices(plans) == set(range(spec.real_token_count)) + assert not any( + bucket + for plan in plans + for depth_buckets in plan.tree_chain_buckets_by_depth[1:] + for bucket in depth_buckets + ) + assert _remote_parent_state_transfer_count(plans) > 0 + _assert_remote_parent_state_transfers_cover(spec, plans) + + +def test_gdn_tree_cp_randomized_plans_cover_each_token_once() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from 
art.megatron.gdn.gdn_shared_prefix import ( + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + config = _chain_every_legal_segment_config() + for seed in range(8): + pack = pack_shared_prefixes( + _random_tree_sequences(seed), + max_depth=4, + ) + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plans = tuple( + build_gdn_rank_execution_plan( + spec, + device="cpu", + cp_rank=rank, + cp_size=4, + planner_config=config, + ) + for rank in range(4) + ) + + assert _covered_token_indices(plans) == set(range(spec.real_token_count)) + assert sum(plan.gdn_token_count for plan in plans) == spec.real_token_count + for plan in plans: + for depth_buckets in ( + *plan.tree_segment_buckets_by_depth, + *plan.tree_chain_buckets_by_depth, + ): + for bucket in depth_buckets: + assert bucket.parent_indices is not None + assert int(bucket.real_token_count) > 0 + + +def test_gdn_tree_cp_randomized_plans_pass_health_checks() -> None: + pytest.importorskip("megatron.core.packed_seq_params") + from art.megatron.gdn.gdn_shared_prefix import ( + GdnPlannerConfig, + build_gdn_rank_execution_plan, + parse_gdn_shared_prefix_segments, + ) + + config = GdnPlannerConfig( + cp_chain_min_tokens_per_rank=1, + cp_chain_min_total_tokens=64, + cp_chain_min_prefix_only_tokens=64, + cp_tree_chain_min_total_tokens=64, + cp_tree_chain_min_prefix_only_tokens=64, + max_padding_ratio=4.0, + ) + for seed in range(16): + pack = pack_shared_prefixes( + _random_tree_sequences(seed + 100, max_depth=5), + max_depth=5, + ) + spec = parse_gdn_shared_prefix_segments( + group_ids=pack.group_ids, + parent_ids=pack.parent_ids, + ) + plans = tuple( + build_gdn_rank_execution_plan( + spec, + device="cpu", + cp_rank=rank, + cp_size=4, + planner_config=config, + ) + for rank in range(4) + ) + + _assert_tree_plan_health( + spec, plans, max_padding_ratio=config.max_padding_ratio + ) + + +def _chain_every_legal_segment_config(): + from 
art.megatron.gdn.gdn_shared_prefix import GdnPlannerConfig + + return GdnPlannerConfig( + cp_chain_min_tokens_per_rank=1, + cp_chain_min_total_tokens=1, + cp_chain_min_prefix_only_tokens=1, + max_padding_ratio=4.0, + ) + + +def _covered_token_indices(plans) -> set[int]: + return { + token + for plan in plans + for start, end, _position in plan.gdn_token_ranges + for token in range(start, end) + } + + +def _local_owner_by_family(plans) -> dict[int, int]: + owner_by_family = {} + for rank, plan in enumerate(plans): + for depth_buckets in plan.tree_segment_buckets_by_depth: + for bucket in depth_buckets: + for family_index in bucket.family_indices.tolist(): + previous = owner_by_family.setdefault(int(family_index), rank) + assert previous == rank + return owner_by_family + + +def _assert_remote_parent_state_transfers_cover(spec, plans) -> None: + owner_by_family = _local_owner_by_family(plans) + for family_index, parent_index in enumerate(spec.tree_parent_indices): + if parent_index < 0 or parent_index not in owner_by_family: + continue + source_rank = owner_by_family[parent_index] + dest_rank = owner_by_family[family_index] + if source_rank == dest_rank: + continue + depth = spec.tree_depths[family_index] + source_exchange = plans[source_rank].tree_state_exchanges_by_depth[depth] + dest_exchange = plans[dest_rank].tree_state_exchanges_by_depth[depth] + assert source_exchange is not None + assert dest_exchange is not None + assert parent_index in source_exchange.source_family_indices + assert parent_index in dest_exchange.dest_family_indices + matching = [ + transfer + for transfer in dest_exchange.exchange.transfers + if transfer.source_rank == source_rank and transfer.dest_rank == dest_rank + ] + assert matching + + +def _remote_parent_state_transfer_count(plans) -> int: + return sum( + exchange.exchange.cross_rank_token_count + for plan in plans + for exchange in plan.tree_state_exchanges_by_depth + if exchange is not None + ) // len(plans) + + +def 
_tree_has_children(spec) -> list[bool]: + has_children = [False] * spec.family_count + for parent_index in spec.tree_parent_indices: + if parent_index >= 0: + has_children[parent_index] = True + return has_children + + +def _assert_tree_plan_health(spec, plans, *, max_padding_ratio: float) -> None: + tree_has_children = _tree_has_children(spec) + token_counts = [0] * int(spec.real_token_count) + for plan in plans: + range_tokens = sum( + end - start for start, end, _position in plan.gdn_token_ranges + ) + assert range_tokens == int(plan.gdn_token_count) + assert len(plan.attention_token_indices) == int(plan.attention_token_count) + + bucket_tokens = 0 + for depth_buckets in plan.tree_segment_buckets_by_depth: + for bucket in depth_buckets: + bucket_tokens += int(bucket.real_token_count) + assert bucket.parent_indices is not None + assert int(bucket.parent_indices.numel()) == int(bucket.segment_count) + assert int(bucket.real_token_count) > 0 + padding_ratio = ( + bucket.length * bucket.segment_count / bucket.real_token_count + ) + assert padding_ratio <= max_padding_ratio + bucket_state_flags = { + tree_has_children[family_index] + for family_index in bucket.family_indices.tolist() + } + assert bucket_state_flags == {bucket.needs_final_state} + for family_index, parent_index in zip( + bucket.family_indices.tolist(), + bucket.parent_indices.tolist(), + strict=True, + ): + assert spec.tree_parent_indices[family_index] == parent_index + + for depth_buckets in plan.tree_chain_buckets_by_depth: + for bucket in depth_buckets: + bucket_tokens += int(bucket.real_token_count) + assert bucket.parent_indices is not None + assert int(bucket.parent_indices.numel()) == int(bucket.segment_count) + assert int(bucket.real_token_count) > 0 + padding_ratio = ( + bucket.length * bucket.segment_count / bucket.real_token_count + ) + assert padding_ratio <= max_padding_ratio + bucket_state_flags = { + tree_has_children[family_index] + for family_index in bucket.family_indices.tolist() + 
} + if bucket.needs_final_state: + assert any(bucket_state_flags) + else: + assert bucket_state_flags == {False} + for family_index, parent_index in zip( + bucket.family_indices.tolist(), + bucket.parent_indices.tolist(), + strict=True, + ): + assert spec.tree_parent_indices[family_index] == parent_index + assert bucket_tokens == int(plan.gdn_token_count) + + for start, end, _position in plan.gdn_token_ranges: + for token_index in range(start, end): + token_counts[token_index] += 1 + + _assert_remote_parent_state_transfers_cover(spec, plans) + assert token_counts == [1] * int(spec.real_token_count) + rank_tokens = [int(plan.gdn_token_count) for plan in plans] + assert max(rank_tokens) - min(rank_tokens) <= max(256, spec.real_token_count // 3) + + +def _random_tree_sequences( + seed: int, *, max_depth: int = 4 +) -> tuple[torch.Tensor, ...]: + generator = torch.Generator().manual_seed(seed) + next_token = 1 + + def tokens(length: int) -> torch.Tensor: + nonlocal next_token + out = torch.arange(next_token, next_token + length) + next_token += length + return out + + def randint(low: int, high: int) -> int: + return int(torch.randint(low, high + 1, (), generator=generator).item()) + + def walk(prefix: torch.Tensor, depth: int) -> list[torch.Tensor]: + segment_length = [1, 3, 17, 64, 129, 257][randint(0, 5)] + here = torch.cat((prefix, tokens(segment_length))) + if depth + 1 >= max_depth: + return [ + torch.cat((here, tokens(randint(1, 9)))) for _ in range(randint(2, 4)) + ] + leaves: list[torch.Tensor] = [] + for _ in range(randint(2, 3)): + leaves.extend(walk(here, depth + 1)) + return leaves + + return tuple(walk(torch.empty(0, dtype=torch.long), 0)) diff --git a/tests/unit/test_trainer_rank_validation.py b/tests/unit/test_trainer_rank_validation.py new file mode 100644 index 000000000..deaf25886 --- /dev/null +++ b/tests/unit/test_trainer_rank_validation.py @@ -0,0 +1,448 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest 
+import torch + +from art.megatron.trainer_rank import ( + ForwardInput, + ForwardOutput, + TopK, + TrainerRank, + TrainerRankMemoryError, + Unset, + _anchor_disconnected_outputs, + _MemoryCheck, + _MemoryProfile, + _validate_top_k, +) + + +class _Model: + vocab_size = 8 + + +def _runtime(model: torch.nn.Module | None = None) -> SimpleNamespace: + return SimpleNamespace( + model=[model or torch.nn.Linear(1, 1)], + optimizer=None, + provider=SimpleNamespace(hidden_size=4, num_layers=1), + model_support_handler=SimpleNamespace(build_gdn_execution_spec=True), + ) + + +def _target_request(token: int) -> ForwardInput[torch.Tensor, None, None, None]: + tokens = torch.tensor([token, token + 1], dtype=torch.long) + return ForwardInput(input_tokens=tokens, target_tokens=tokens) + + +def test_forward_input_rejects_non_positive_top_k() -> None: + with pytest.raises(ValueError, match="top_k must be >= 1"): + ForwardInput(input_tokens=torch.tensor([1]), top_k=0) + + +def test_forward_input_adapter_selection_defaults_to_unset() -> None: + request = ForwardInput(input_tokens=torch.tensor([1])) + + assert request.checkpoint is Unset + assert request.lora is Unset + + +def test_forward_input_accepts_explicit_base_checkpoint() -> None: + request = ForwardInput(input_tokens=torch.tensor([1]), checkpoint=None) + + assert request.checkpoint is None + assert request.lora is Unset + + +def test_forward_input_rejects_checkpoint_and_lora_together() -> None: + with pytest.raises(ValueError, match="cannot set both checkpoint and lora"): + ForwardInput(input_tokens=torch.tensor([1]), checkpoint="a", lora="b") + + +def test_validate_top_k_rejects_values_above_vocab_size() -> None: + with pytest.raises(ValueError, match="top_k=9 exceeds vocabulary size 8"): + _validate_top_k(9, _Model()) # type: ignore[arg-type] + + +def test_trainer_rank_accepts_nested_shared_prefix_for_gdn_runtime() -> None: + trainer = TrainerRank(_runtime(), shared_prefix_max_depth=2) # type: ignore[arg-type] + + assert 
trainer.shared_prefix_max_depth == 2 + + +def test_trainer_rank_accepts_zero_depth_shared_prefix_for_gdn_runtime() -> None: + trainer = TrainerRank(_runtime(), shared_prefix_max_depth=0) # type: ignore[arg-type] + + assert trainer.shared_prefix_max_depth == 0 + + +def test_trainer_rank_pop_rejects_empty_adapter_stack() -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + + with pytest.raises(RuntimeError, match="No pushed LoRA or checkpoint"): + trainer.pop_pushed_lora_or_checkpoint() + + +def test_dp_rank_forward_preserves_nested_shape_for_inactive_requests() -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + request_a = ForwardInput(input_tokens=torch.tensor([1])) + request_b = ForwardInput(input_tokens=torch.tensor([2])) + + outputs = trainer.dp_rank_forward([[request_a], [request_b]]) + + assert len(outputs) == 2 + assert len(outputs[0]) == 1 + assert outputs[0][0].target_logprobs is None + assert outputs[1][0].target_logprobs is None + assert not hasattr(trainer, "forward") + assert not hasattr(trainer, "micro_batches") + + +def test_forward_micro_batches_uses_deterministic_dp_windows( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (1, 2)) + monkeypatch.setattr( + trainer, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + + batches = list( + trainer.forward_micro_batches([_target_request(i) for i in range(5)]) + ) + + assert [batch.indices for batch in batches] == [(1,), (3,), ()] + assert [len(batch.outputs) for batch in batches] == [1, 1, 0] + + +def test_forward_micro_batches_outputs_match_top_level_nested_inputs( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + 
monkeypatch.setattr( + trainer, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + + nested = [[_target_request(1), _target_request(3)]] + batch = next(iter(trainer.forward_micro_batches(nested))) + + assert batch.inputs == nested + assert len(batch.outputs) == 1 + assert len(batch.outputs[0]) == 2 + + +def test_forward_micro_batches_ramps_after_first_success( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + + def run(plan, **_kwargs): + trainer._memory_profiles[plan.signature] = _MemoryProfile( + bytes_per_token=0.0, + packed_tokens=plan.packed_tokens, + ) + return [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ] + + monkeypatch.setattr(trainer, "_run_flat_plan_with_memory_tracking", run) + + batches = list( + trainer.forward_micro_batches([_target_request(i) for i in range(8)]) + ) + + assert batches[0].stats.global_count == 1 + assert batches[0].stats.cold_start + assert batches[1].stats.global_count > 1 + assert not batches[1].stats.cold_start + + +def test_forward_micro_batches_does_not_overtrust_tiny_memory_profile( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + inputs = [_target_request(i) for i in range(64)] + tiny_plan = trainer._plan_flat_forward([inputs[0]]) + trainer._memory_profiles[tiny_plan.signature] = _MemoryProfile( + bytes_per_token=0.0, + packed_tokens=tiny_plan.packed_tokens, + ) + + candidate = trainer._select_next_micro_batch(inputs, 0) + + assert candidate.stats_global_count == 8 + assert candidate.plan.packed_tokens == 16 + + +def test_forward_micro_batches_shrinks_to_largest_fitting_window( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = 
TrainerRank(_runtime()) # type: ignore[arg-type] + trainer._last_global_micro_batch_size = 4 + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + trainer, "_all_ranks_have_memory_profile", lambda **_kwargs: True + ) + + def required_memory(**kwargs): + return kwargs["packed_tokens"] + + def memory_check(required): + return _MemoryCheck( + estimated_required_bytes=required, + available_bytes=6, + fits=required <= 6, + ) + + monkeypatch.setattr( + trainer, "_estimate_required_memory_bytes_from_values", required_memory + ) + monkeypatch.setattr(trainer, "_memory_check_required", memory_check) + monkeypatch.setattr( + trainer, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + + batch = next( + iter(trainer.forward_micro_batches([_target_request(i) for i in range(8)])) + ) + + assert batch.stats.global_count == 3 + assert batch.stats.rejected_candidates >= 1 + + +def test_forward_micro_batches_tail_does_not_reset_stable_window( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + trainer._last_global_micro_batch_size = 64 + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + trainer, "_all_ranks_have_memory_profile", lambda **_kwargs: True + ) + monkeypatch.setattr( + trainer, + "_estimate_required_memory_bytes_from_values", + lambda **kwargs: kwargs["packed_tokens"], + ) + monkeypatch.setattr( + trainer, + "_memory_check_required", + lambda required: _MemoryCheck( + estimated_required_bytes=required, + available_bytes=128, + fits=required <= 128, + ), + ) + monkeypatch.setattr( + trainer, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + + batches = list( + trainer.forward_micro_batches([_target_request(i) for i in range(130)]) + 
) + + assert [batch.stats.global_count for batch in batches] == [64, 64, 2] + assert trainer._last_global_micro_batch_size == 64 + + +def test_forward_micro_batches_grows_small_stable_window_when_work_remains( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + trainer._last_global_micro_batch_size = 64 + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + trainer, "_all_ranks_have_memory_profile", lambda **_kwargs: True + ) + monkeypatch.setattr( + trainer, + "_estimate_required_memory_bytes_from_values", + lambda **kwargs: kwargs["packed_tokens"], + ) + monkeypatch.setattr( + trainer, + "_memory_check_required", + lambda required: _MemoryCheck( + estimated_required_bytes=required, + available_bytes=512, + fits=required <= 512, + ), + ) + + candidate = trainer._select_next_micro_batch( + [_target_request(i) for i in range(512)], + 0, + ) + + assert candidate.stats_global_count == 256 + + +def test_forward_micro_batches_reuses_cached_candidate_plans( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + trainer, "_all_ranks_have_memory_profile", lambda **_kwargs: True + ) + monkeypatch.setattr( + trainer, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + original_plan = trainer._plan_flat_forward + plan_calls = 0 + memory_checks = 0 + + def plan(requests): + nonlocal plan_calls + plan_calls += 1 + return original_plan(requests) + + def memory_check(plan): + nonlocal memory_checks + memory_checks += 1 + return _MemoryCheck( + estimated_required_bytes=plan.packed_tokens, + available_bytes=10, + fits=True, + ) + + monkeypatch.setattr(trainer, "_plan_flat_forward", plan) + monkeypatch.setattr(trainer, "_memory_check", 
memory_check) + inputs = [_target_request(i) for i in range(8)] + + list(trainer.forward_micro_batches(inputs)) + first_plan_calls = plan_calls + first_memory_checks = memory_checks + list(trainer.forward_micro_batches(inputs)) + + assert first_plan_calls > 0 + assert first_plan_calls == 1 + assert plan_calls == first_plan_calls + assert memory_checks == first_memory_checks == 0 + + +def test_forward_micro_batches_raises_when_smallest_batch_will_not_fit( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(trainer, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + trainer, + "_estimate_required_memory_bytes_from_values", + lambda **_kwargs: 4, + ) + monkeypatch.setattr( + trainer, + "_memory_check_required", + lambda required: _MemoryCheck( + estimated_required_bytes=required, + available_bytes=3, + fits=False, + ), + ) + with pytest.raises(TrainerRankMemoryError, match="smallest DP microbatch"): + next(iter(trainer.forward_micro_batches([_target_request(1)]))) + + +def test_forward_micro_batches_rejects_mismatched_replicated_counts( + monkeypatch: pytest.MonkeyPatch, +) -> None: + trainer = TrainerRank(_runtime()) # type: ignore[arg-type] + import art.megatron.trainer_rank as trainer_rank + + monkeypatch.setattr(trainer_rank.dist, "is_available", lambda: True) + monkeypatch.setattr(trainer_rank.dist, "is_initialized", lambda: True) + monkeypatch.setattr(trainer_rank.dist, "get_world_size", lambda: 2) + + def gather(output, value): + output[:] = [value, value + 1] + + monkeypatch.setattr(trainer_rank.dist, "all_gather_object", gather) + + with pytest.raises(ValueError, match="same top-level input count"): + list(trainer.forward_micro_batches([_target_request(1)])) + + +def test_forward_plan_estimates_output_memory_for_request_combo() -> None: + class FakeGPT(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.zeros(())) 
+ self.config = SimpleNamespace( + hidden_size=4, + num_layers=1, + padded_vocab_size=10, + ) + self.decoder = object() + + def _preprocess(self, *args: object, **kwargs: object) -> None: + return None + + trainer = TrainerRank(_runtime(FakeGPT())) # type: ignore[arg-type] + tokens = torch.tensor([1, 2, 3], dtype=torch.long) + labels = torch.stack((tokens, tokens + 1), dim=1) + + plan = trainer._plan_flat_forward( + [ + ForwardInput( + input_tokens=tokens, + target_tokens=labels, + top_k=5, + logits=True, + hidden_states=True, + ) + ] + ) + + target_bytes = 3 * 2 * 4 + topk_bytes = 3 * 5 * (4 + 8) + logits_bytes = 3 * 10 * 4 + hidden_bytes = 3 * 4 * 4 + assert plan.output_bytes == target_bytes + topk_bytes + logits_bytes + hidden_bytes + + +def test_disconnected_outputs_keep_zero_graph_anchor() -> None: + hidden = torch.randn(2, 3, requires_grad=True) + disconnected = torch.zeros(4) + top_k = TopK(logprobs=torch.zeros(4, 2), tokens=torch.ones(4, 2, dtype=torch.long)) + + (anchored,), (anchored_top_k,) = _anchor_disconnected_outputs( + [disconnected], + [top_k], + hidden, + ) + + assert anchored is not None + assert anchored.requires_grad + assert anchored_top_k is not None + assert anchored_top_k.logprobs.requires_grad + torch.testing.assert_close(anchored, disconnected) + torch.testing.assert_close(anchored_top_k.logprobs, top_k.logprobs) + (anchored.sum() + anchored_top_k.logprobs.sum()).backward() + assert hidden.grad is not None + torch.testing.assert_close(hidden.grad, torch.zeros_like(hidden)) diff --git a/tests/unit/test_trainer_rank_weird_shapes.py b/tests/unit/test_trainer_rank_weird_shapes.py new file mode 100644 index 000000000..4dfbae5d7 --- /dev/null +++ b/tests/unit/test_trainer_rank_weird_shapes.py @@ -0,0 +1,501 @@ +from __future__ import annotations + +from collections.abc import Iterable +from types import SimpleNamespace + +import pytest +import torch + +from art.megatron.shared_prefix_packing import ( + estimate_shared_prefix_packed_tokens, + 
pack_shared_prefixes, +) +from art.megatron.trainer_rank import ( + ForwardInput, + ForwardOutput, + TopK, + TrainerRank, + TrainerRankMemoryError, + Unset, + _flatten, + _MemoryCheck, +) + + +class _FakeGPT(torch.nn.Module): + def __init__(self, *, hidden_size: int = 8, vocab_size: int = 32) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.zeros((), dtype=torch.float16)) + self.config = SimpleNamespace( + hidden_size=hidden_size, + num_layers=4, + padded_vocab_size=vocab_size, + ) + self.decoder = object() + + def _preprocess(self, *args: object, **kwargs: object) -> None: + return None + + +def _runtime() -> SimpleNamespace: + return SimpleNamespace( + model=[_FakeGPT()], + optimizer=None, + provider=SimpleNamespace(hidden_size=8, num_layers=4), + model_support_handler=SimpleNamespace(build_gdn_execution_spec=True), + ) + + +def _tokens(*values: int) -> torch.Tensor: + return torch.tensor(values, dtype=torch.long) + + +def _target_request( + tokens: torch.Tensor, + *, + target_count: int = 1, + top_k: int | None = None, + logits: bool = False, + hidden_states: bool = False, + checkpoint: object = Unset, + lora: object = Unset, +) -> ForwardInput: + labels = ( + tokens + if target_count == 1 + else torch.stack( + tuple(tokens + offset for offset in range(target_count)), + dim=-1, + ) + ) + return ForwardInput( + input_tokens=tokens, + target_tokens=labels, + top_k=top_k, + logits=logits, + hidden_states=hidden_states, + checkpoint=checkpoint, # type: ignore[arg-type] + lora=lora, # type: ignore[arg-type] + ) + + +def _ternary_tree_sequences() -> tuple[torch.Tensor, ...]: + # Shape: shared root, two continuation branches, and terminal nodes at + # several depths. This mirrors prompt -> continuation A/B -> terminal data. 
+ root = [10, 11, 12] + left = root + [20, 21] + right = root + [30, 31, 32] + return ( + _tokens(*(root + [1])), + _tokens(*(left + [2])), + _tokens(*(left + [3, 4])), + _tokens(*(right + [5])), + _tokens(*(right + [6, 7])), + _tokens(80, 81), + ) + + +def _vineppo_like_inputs() -> list[list[ForwardInput]]: + groups: list[list[ForwardInput]] = [] + for prompt_index in range(4): + prompt = [100 + prompt_index, 200 + prompt_index, 201 + prompt_index] + trajectories = [] + for branch_index, completion_len in enumerate((1, 2, 4)): + completion = [300 + branch_index] * completion_len + tokens = _tokens(*(prompt + completion)) + trajectories.append( + _target_request( + tokens, + target_count=2 if branch_index == 2 else 1, + top_k=5 if branch_index == 1 else None, + hidden_states=branch_index == 0, + ) + ) + groups.append(trajectories) + return groups + + +def _random_tree_sequences(seed: int, *, max_depth: int) -> tuple[torch.Tensor, ...]: + generator = torch.Generator().manual_seed(seed) + out: list[torch.Tensor] = [] + + def randint(low: int, high: int) -> int: + return int(torch.randint(low, high + 1, (), generator=generator).item()) + + def segment(depth: int) -> list[int]: + return [depth * 100 + randint(1, 40) for _ in range(randint(1, 4))] + + def walk(prefix: list[int], depth: int) -> None: + if depth >= max_depth or randint(0, 2) == 0: + out.append(_tokens(*(prefix + segment(depth)))) + return + shared = prefix + segment(depth) + out.append(_tokens(*shared)) + walk(shared + [10 + depth], depth + 1) + walk(shared + [20 + depth], depth + 1) + + walk([], 0) + return tuple(out) + + +@pytest.mark.parametrize("max_depth", (0, 1, 2, 4)) +def test_pack_estimator_matches_ternary_and_random_trees(max_depth: int) -> None: + cases = [ + _ternary_tree_sequences(), + _random_tree_sequences(3, max_depth=4), + _random_tree_sequences(99, max_depth=5), + ] + + for sequences in cases: + pack = pack_shared_prefixes(sequences, max_depth=max_depth) + + assert 
estimate_shared_prefix_packed_tokens( + sequences, max_depth=max_depth + ) == int(pack.tokens.numel()) + for sequence, positions in zip( + sequences, pack.positions_by_sequence, strict=True + ): + torch.testing.assert_close(pack.tokens.reshape(-1)[positions], sequence) + + +def test_planner_handles_vineppo_nested_shape_and_request_mix() -> None: + rank = TrainerRank(_runtime(), shared_prefix_max_depth=3) # type: ignore[arg-type] + inputs = _vineppo_like_inputs() + flat = list(_flatten(inputs)) + + plan = rank._plan_flat_forward(flat) + estimate = rank._estimate_flat_forward(flat) + + assert estimate is not None + packed_tokens, output_bytes, signature = estimate + assert packed_tokens == plan.packed_tokens + assert output_bytes == plan.output_bytes + assert signature == plan.signature + assert plan.request_count == 12 + assert plan.signature.request_mix == ( + "target:(2,)", + "target:single+hidden", + "target:single+topk:5", + ) + + +def test_forward_micro_batches_preserves_nested_vineppo_groups( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime(), shared_prefix_max_depth=2) # type: ignore[arg-type] + monkeypatch.setattr(rank, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr(rank, "_all_ranks_have_memory_profile", lambda **_kwargs: True) + monkeypatch.setattr( + rank, + "_memory_check", + lambda plan: _MemoryCheck(plan.packed_tokens, 10_000, True), + ) + monkeypatch.setattr( + rank, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + groups = _vineppo_like_inputs() + + micro_batches = list(rank.forward_micro_batches(groups)) + + assert [batch.indices for batch in micro_batches] == [(0, 1, 2, 3)] + assert micro_batches[0].select(groups) == groups + assert len(micro_batches[0].outputs) == 4 + assert all( + isinstance(group_outputs, list) and len(group_outputs) == 3 + for group_outputs in micro_batches[0].outputs + ) + + +def 
test_adaptive_planner_materializes_only_final_large_candidate( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime(), shared_prefix_max_depth=3) # type: ignore[arg-type] + rank._last_global_micro_batch_size = 32 + monkeypatch.setattr(rank, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr(rank, "_all_ranks_have_memory_profile", lambda **_kwargs: True) + plan_calls = 0 + estimate_calls = 0 + original_plan = rank._plan_flat_forward + original_estimate = rank._estimate_flat_forward + inputs = [ + _target_request( + _tokens(1, 2, 3, index % 7, index), + target_count=2 if index % 5 == 0 else 1, + top_k=3 if index % 4 == 0 else None, + hidden_states=index % 9 == 0, + ) + for index in range(96) + ] + limit = rank._estimate_flat_forward(inputs[:40]) + assert limit is not None + limit_packed_tokens = limit[0] + + def plan(requests): + nonlocal plan_calls + plan_calls += 1 + return original_plan(requests) + + def estimate(requests): + nonlocal estimate_calls + estimate_calls += 1 + return original_estimate(requests) + + def required_memory(**kwargs): + return kwargs["packed_tokens"] + + def check(required): + return _MemoryCheck( + estimated_required_bytes=required, + available_bytes=limit_packed_tokens, + fits=required <= limit_packed_tokens, + ) + + monkeypatch.setattr(rank, "_plan_flat_forward", plan) + monkeypatch.setattr(rank, "_estimate_flat_forward", estimate) + monkeypatch.setattr( + rank, "_estimate_required_memory_bytes_from_values", required_memory + ) + monkeypatch.setattr(rank, "_memory_check_required", check) + + candidate = rank._select_next_micro_batch(inputs, 0) + + assert candidate.stats_global_count == 40 + assert plan_calls == 1 + assert estimate_calls <= 10 + assert candidate.rejected_candidates <= 8 + + +def test_adaptive_planner_reuses_large_stable_window( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime(), shared_prefix_max_depth=1) # type: ignore[arg-type] + 
rank._last_global_micro_batch_size = 512 + monkeypatch.setattr(rank, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr(rank, "_all_ranks_have_memory_profile", lambda **_kwargs: True) + monkeypatch.setattr( + rank, + "_estimate_required_memory_bytes_from_values", + lambda **kwargs: kwargs["packed_tokens"], + ) + monkeypatch.setattr( + rank, + "_memory_check_required", + lambda required: _MemoryCheck( + estimated_required_bytes=required, + available_bytes=700, + fits=required <= 700, + ), + ) + + candidate = rank._select_next_micro_batch( + [_target_request(_tokens(index)) for index in range(900)], + 0, + ) + + assert candidate.stats_global_count == 512 + assert candidate.rejected_candidates == 0 + + +def test_forward_micro_batches_shrinks_when_memory_budget_drops( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime(), shared_prefix_max_depth=2) # type: ignore[arg-type] + monkeypatch.setattr(rank, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr(rank, "_all_ranks_have_memory_profile", lambda **_kwargs: True) + inputs = [_target_request(_tokens(1, 2, 3, index)) for index in range(14)] + first_limit = rank._estimate_flat_forward(inputs[:8]) + tail_limit = rank._estimate_flat_forward(inputs[8:11]) + assert first_limit is not None + assert tail_limit is not None + first_limit_packed_tokens = first_limit[0] + tail_limit_packed_tokens = tail_limit[0] + available = {"packed_tokens": first_limit_packed_tokens} + plan_calls = 0 + original_plan = rank._plan_flat_forward + + def plan(requests): + nonlocal plan_calls + plan_calls += 1 + return original_plan(requests) + + def required_memory(**kwargs): + return kwargs["packed_tokens"] + + def check(required): + limit = available["packed_tokens"] + return _MemoryCheck( + estimated_required_bytes=required, + available_bytes=limit, + fits=required <= limit, + ) + + def run(plan, **_kwargs): + if available["packed_tokens"] == first_limit_packed_tokens: + available["packed_tokens"] = 
tail_limit_packed_tokens + return [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ] + + monkeypatch.setattr(rank, "_plan_flat_forward", plan) + monkeypatch.setattr( + rank, "_estimate_required_memory_bytes_from_values", required_memory + ) + monkeypatch.setattr(rank, "_memory_check_required", check) + monkeypatch.setattr(rank, "_run_flat_plan_with_memory_tracking", run) + + batches = list(rank.forward_micro_batches(inputs)) + + assert [batch.stats.global_count for batch in batches] == [8, 3, 3] + assert [batch.stats.available_bytes for batch in batches] == [ + first_limit_packed_tokens, + tail_limit_packed_tokens, + tail_limit_packed_tokens, + ] + assert [batch.indices for batch in batches] == [ + tuple(range(8)), + (8, 9, 10), + (11, 12, 13), + ] + assert plan_calls == len(batches) + + +def test_heterogeneous_slots_split_packing_without_losing_output_estimates( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime(), shared_prefix_max_depth=4) # type: ignore[arg-type] + monkeypatch.setattr( + TrainerRank, + "_slot_ref", + staticmethod(lambda kind, name: (kind, name)), + ) + rank.set_checkpoint("student") + requests = [ + _target_request(_tokens(1, 2, 3), top_k=3), + _target_request(_tokens(1, 2, 4), checkpoint=None, logits=True), + _target_request(_tokens(1, 2, 5), lora="teacher", hidden_states=True), + _target_request(_tokens(1, 2, 6), checkpoint="critic", target_count=4), + ] + + plan = rank._plan_flat_forward(requests) + estimate = rank._estimate_flat_forward(requests) + + assert estimate is not None + packed_tokens, output_bytes, signature = estimate + assert packed_tokens == plan.packed_tokens + assert output_bytes == plan.output_bytes + assert signature == plan.signature + assert plan.signature.slot_group_count == 4 + assert {group.slot_ref for group in plan.groups} == { + ("checkpoint", "student"), + ("checkpoint", None), + ("lora", "teacher"), + ("checkpoint", "critic"), + } + + +def 
test_dp_uneven_tail_yields_empty_rank_batch( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(rank, "_dp_rank_and_size", lambda: (3, 4)) + monkeypatch.setattr( + rank, + "_run_flat_plan_with_memory_tracking", + lambda plan, **_kwargs: [ + ForwardOutput(None, None, None, None) for _ in range(plan.request_count) + ], + ) + + batches = list( + rank.forward_micro_batches( + [_target_request(_tokens(i, i + 1)) for i in range(5)] + ) + ) + + assert [batch.indices for batch in batches] == [(3,), ()] + assert [batch.stats.local_count for batch in batches] == [1, 0] + + +def test_dp_rank_forward_raises_before_expected_oom( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr( + rank, + "_memory_check", + lambda plan: _MemoryCheck( + estimated_required_bytes=plan.output_bytes + 1, + available_bytes=plan.output_bytes, + fits=False, + ), + ) + + with pytest.raises(TrainerRankMemoryError, match="dp_rank_forward"): + rank.dp_rank_forward( + [_target_request(_tokens(1, 2, 3), logits=True, hidden_states=True)] + ) + + +def test_memory_error_includes_actionable_shape_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + rank = TrainerRank(_runtime()) # type: ignore[arg-type] + monkeypatch.setattr(rank, "_dp_rank_and_size", lambda: (0, 1)) + monkeypatch.setattr( + rank, + "_estimate_required_memory_bytes_from_values", + lambda **_kwargs: 99, + ) + monkeypatch.setattr( + rank, + "_memory_check_required", + lambda required: _MemoryCheck(required, 1, False), + ) + + with pytest.raises(TrainerRankMemoryError) as exc_info: + next( + iter( + rank.forward_micro_batches( + [_target_request(_tokens(1, 2, 3), logits=True)] + ) + ) + ) + + message = str(exc_info.value) + assert "packed_tokens=" in message + assert "logical_tokens=" in message + assert "output_gb=" in message + assert "Use smaller top-level items" in message + + +def 
test_topk_output_memory_scales_with_requested_k() -> None: + rank = TrainerRank(_runtime()) # type: ignore[arg-type] + tokens = _tokens(1, 2, 3, 4) + + small = rank._plan_flat_forward([_target_request(tokens, top_k=1)]) + large = rank._plan_flat_forward([_target_request(tokens, top_k=7)]) + + assert large.output_bytes - small.output_bytes == 4 * 6 * (4 + 8) + + +def test_flatten_rejects_dicts_to_avoid_silent_top_level_shape_changes() -> None: + with pytest.raises(TypeError, match="dict was passed directly"): + list(_flatten({"bad": _target_request(_tokens(1, 2))})) # type: ignore[arg-type] + + +def test_no_output_requests_do_not_pack_or_consume_compute_memory() -> None: + rank = TrainerRank(_runtime()) # type: ignore[arg-type] + requests: Iterable[ForwardInput] = [ + ForwardInput(input_tokens=_tokens(1, 2, 3)), + ForwardInput(input_tokens=_tokens(1, 2, 4)), + ] + + plan = rank._plan_flat_forward(list(requests)) + + assert plan.groups == () + assert plan.packed_tokens == 0 + assert rank._memory_check(plan).estimated_required_bytes == 0 diff --git a/typings/wandb/__init__.pyi b/typings/wandb/__init__.pyi new file mode 100644 index 000000000..09d1c8d16 --- /dev/null +++ b/typings/wandb/__init__.pyi @@ -0,0 +1,38 @@ +from typing import Any + +class Settings: + def __init__(self, **kwargs: Any) -> None: ... + +class Artifact: + aliases: list[str] + metadata: dict[str, Any] + def __init__(self, name: str, type: str, **kwargs: Any) -> None: ... + def add_dir(self, local_path: str, **kwargs: Any) -> None: ... + def add_file(self, local_path: str, **kwargs: Any) -> None: ... + def download(self, **kwargs: Any) -> str: ... + def save(self) -> None: ... + def wait(self) -> Artifact: ... + +class Run: + entity: str + project: str + name: str + config: Any + _is_finished: bool + def finish(self, *args: Any, **kwargs: Any) -> None: ... + def define_metric(self, *args: Any, **kwargs: Any) -> None: ... + def log(self, *args: Any, **kwargs: Any) -> None: ... 
+ def log_artifact(self, *args: Any, **kwargs: Any) -> Artifact: ... + +class Api: + default_entity: str + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + def artifact(self, *args: Any, **kwargs: Any) -> Artifact: ... + def artifacts(self, *args: Any, **kwargs: Any) -> list[Artifact]: ... + def run(self, *args: Any, **kwargs: Any) -> Run: ... + +def init(*args: Any, **kwargs: Any) -> Run: ... +def login(*args: Any, **kwargs: Any) -> Any: ... + +class errors: + class CommError(Exception): ... diff --git a/typings/wandb/sdk/__init__.pyi b/typings/wandb/sdk/__init__.pyi new file mode 100644 index 000000000..1ce9ecf99 --- /dev/null +++ b/typings/wandb/sdk/__init__.pyi @@ -0,0 +1,5 @@ +from typing import Any + +__all__: list[str] + +def __getattr__(name: str) -> Any: ... diff --git a/typings/wandb/sdk/wandb_run.pyi b/typings/wandb/sdk/wandb_run.pyi new file mode 100644 index 000000000..416ae101e --- /dev/null +++ b/typings/wandb/sdk/wandb_run.pyi @@ -0,0 +1,3 @@ +from wandb import Run + +__all__ = ["Run"]