Plumb FP8+THD #2994
Conversation
Take upstream versions for fused_attn.cpp and fused_attn_fp8.cu APIs. Keep branch's test_attention.py THD parametrization. FP8+THD ragged-offset plumbing is re-applied in the following commit. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Mirrors the F16 arbitrary_seqlen ragged-offset pattern in the FP8 path:
- Backend selector: enable FP8+THD for cuDNN >= 9.23 on sm >= 100
- fwd/bwd _impl: ragged detection, batch/seqlen bucketing, set_ragged_offset() on Q/K/V/O/dO/dQ/dK/dV/Stats, workspace allocation for ragged offsets, cu_seqlens_padded_to_offsets kernel
- fwd/bwd dispatchers: accept num_tokens_q/kv, cu_seqlens_padded, compute max_batch/max_tokens, THD Stats shape

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
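For context, here is a minimal sketch of what a `cu_seqlens_padded_to_offsets`-style conversion computes. This is an illustrative assumption, not the actual Transformer Engine kernel; the function name, launch configuration, and offset units (elements rather than bytes) are hypothetical.

```cuda
// Sketch (assumption) -- not the Transformer Engine kernel.
// Converts padded cumulative sequence lengths (batch + 1 entries) into
// per-sequence ragged offsets, in elements, for a packed THD tensor with
// num_heads heads and head_dim channels per head.
#include <cstdint>

__global__ void cu_seqlens_padded_to_offsets_sketch(const int32_t *cu_seqlens_padded,
                                                    int64_t *ragged_offsets,
                                                    int batch, int num_heads,
                                                    int head_dim) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i <= batch) {
    // Sequence i starts after cu_seqlens_padded[i] padded tokens, and each
    // token occupies num_heads * head_dim elements in the THD layout.
    ragged_offsets[i] =
        static_cast<int64_t>(cu_seqlens_padded[i]) * num_heads * head_dim;
  }
}

// Hypothetical launch: one thread per offset entry (batch + 1 of them).
// cu_seqlens_padded_to_offsets_sketch<<<(batch + 2 + 127) / 128, 128>>>(
//     cu_seqlens_padded, ragged_offsets, batch, num_heads, head_dim);
```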
[pre-commit.ci] auto fixes from pre-commit.com hooks. For more information, see https://pre-commit.ci
Greptile Summary

This PR extends FP8 fused attention to support the THD (packed/ragged variable-length) format, enabling the FP8+THD combination that was previously only available for the F16 arbitrary-seqlen backend. The changes thread token counts (num_tokens_q/kv) and padded cumulative sequence lengths (cu_seqlens_padded) through the FP8 forward/backward dispatchers and into the cuDNN graph construction.
Confidence Score: 3/5

The core CUDA/cuDNN graph logic appears consistent between FWD and BWD, but two issues need addressing before merge: debug prints committed to the test helper, and the FP8 backend selector accepting THD with non-padding mask types that the implementation rejects at runtime. The ragged-offset workspace arithmetic is careful and the FWD/BWD symmetry looks correct. The files needing attention are transformer_engine/common/fused_attn/fused_attn.cpp (backend-selector mask-type restriction for THD) and tests/pytorch/attention/test_attention.py (debug prints).
Sequence Diagram

```mermaid
sequenceDiagram
participant Caller
participant fused_attn.cpp as fused_attn.cpp
participant BackendSelector as nvte_get_fused_attn_backend
participant FP8FWD as fused_attn_fp8_fwd
participant FP8Impl as fused_attn_fp8_fwd_impl
participant cuDNN as cuDNN Frontend Graph
Caller->>fused_attn.cpp: Q,K,V + cu_seqlens + cu_seqlens_padded + num_tokens_q/kv
fused_attn.cpp->>BackendSelector: qkv_format, mask_type, ...
BackendSelector-->>fused_attn.cpp: "NVTE_FP8 (if THD + cuDNN>=9.23 + sm100+)"
fused_attn.cpp->>FP8FWD: batch, t_q, t_kv, cu_seqlens_q/kv_padded, ...
FP8FWD->>FP8FWD: get_max_batch_size / get_max_tokens
FP8FWD->>FP8Impl: max_b, max_t_q, max_t_kv, devPtrSeqOffsetsQ/KV, ...
FP8Impl->>FP8Impl: detect is_ragged_q/kv, use_ragged_stats
FP8Impl->>FP8Impl: "b=max_b, s_q=max_t_q, s_kv=max_t_kv (sm!=120)"
FP8Impl->>cuDNN: build graph with ragged offset tensors (offset_q/k/v/o/stats)
FP8Impl->>FP8Impl: launch cu_seqlens_padded_to_offsets kernel
FP8Impl->>cuDNN: execute plan with variant_pack
    cuDNN-->>Caller: O (FP8 output) + Stats
```
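As a rough illustration of the "build graph with ragged offset tensors" step above, the sketch below shows how a THD tensor might be declared with a ragged offset in the cuDNN frontend 1.x graph API. This is a sketch under assumptions: the helper name, dims, strides, and data types are placeholders, not Transformer Engine's actual graph-construction code.

```cpp
// Sketch (assumption) of attaching a ragged offset to a packed THD tensor
// with the cuDNN frontend graph API. Names and shapes are illustrative.
#include <cudnn_frontend.h>
#include <memory>

namespace fe = cudnn_frontend;

std::shared_ptr<fe::graph::Tensor_attributes> make_ragged_q(
    fe::graph::Graph &graph, int64_t b, int64_t h, int64_t s_q, int64_t d) {
  // One INT64 offset per sequence boundary (batch + 1 entries), produced on
  // the device by a cu_seqlens_padded_to_offsets-style kernel.
  auto offset_q = graph.tensor(fe::graph::Tensor_attributes()
                                   .set_name("offset_q")
                                   .set_dim({b + 1, 1, 1, 1})
                                   .set_stride({1, 1, 1, 1})
                                   .set_data_type(fe::DataType_t::INT64));
  // Q is declared with logical b/h/s/d dims, but its storage is packed THD;
  // the ragged offset tells cuDNN where each sequence's data begins.
  return graph.tensor(fe::graph::Tensor_attributes()
                          .set_name("Q")
                          .set_dim({b, h, s_q, d})
                          .set_stride({h * s_q * d, d, h * d, 1})
                          .set_data_type(fe::DataType_t::FP8_E4M3)
                          .set_ragged_offset(offset_q));
}
```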
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
| print(f"qkv_format: {qkv_format}") | ||
| print(f"inp shape: {inp[0].shape}, {inp[1].shape}, {inp[2].shape}") |
Debug prints committed to test code
Two print statements were left in _run_dpa_fp8_vs_f16, which will produce noise on every test run and in CI logs. These look like leftover debug instrumentation from development.
| print(f"qkv_format: {qkv_format}") | |
| print(f"inp shape: {inp[0].shape}, {inp[1].shape}, {inp[2].shape}") | |
| with autocast(enabled=fp8_dpa, recipe=fp8_recipe): |
```cpp
(cudnn_runtime_version >= 92300 && sm_arch_ >= 100 &&
 qkv_format == NVTE_QKV_Format::NVTE_THD && supported_ragged_offset_size)) &&
```
Backend selection for THD permits non-padding mask types that the implementation rejects
nvte_get_fused_attn_backend can return NVTE_FP8 for qkv_format == NVTE_THD with attn_mask_type == NVTE_NO_MASK or NVTE_CAUSAL_MASK, because the only mask-type filter is the outer condition shared with the BSHD/SBHD/BHSD paths. However, fused_attn_fp8_fwd_impl (and bwd_impl) enforce NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"), which means those combinations would pass backend selection only to fail at kernel dispatch with a confusing internal error. The THD sub-condition should additionally require attn_mask_type == NVTE_PADDING_MASK || attn_mask_type == NVTE_PADDING_CAUSAL_MASK.
```diff
-  (cudnn_runtime_version >= 92300 && sm_arch_ >= 100 &&
-   qkv_format == NVTE_QKV_Format::NVTE_THD && supported_ragged_offset_size)) &&
+  (cudnn_runtime_version >= 92300 && sm_arch_ >= 100 &&
+   qkv_format == NVTE_QKV_Format::NVTE_THD && supported_ragged_offset_size &&
+   (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
+    attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
```
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: