
Plumb FP8+THD #2994

Open
sudhakarsingh27 wants to merge 7 commits into NVIDIA:main from sudhakarsingh27:fp8_thd_attention_try2

Conversation

@sudhakarsingh27
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Take upstream versions for fused_attn.cpp and fused_attn_fp8.cu APIs.
Keep branch's test_attention.py THD parametrization.
FP8+THD ragged-offset plumbing is re-applied in the following commit.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Mirrors the F16 arbitrary_seqlen ragged-offset pattern in the FP8 path:
- Backend selector: enable FP8+THD for cuDNN >= 9.23 on sm >= 100
- fwd/bwd _impl: ragged detection, batch/seqlen bucketing,
  set_ragged_offset() on Q/K/V/O/dO/dQ/dK/dV/Stats, workspace
  allocation for ragged offsets, cu_seqlens_padded_to_offsets kernel
- fwd/bwd dispatchers: accept num_tokens_q/kv, cu_seqlens_padded,
  compute max_batch/max_tokens, THD Stats shape

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
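As a rough illustration of the ragged-offset plumbing this commit describes, the cu_seqlens_padded_to_offsets mapping can be sketched in Python. This is a minimal sketch assuming a packed [total_tokens, h, d] THD layout; the actual implementation is a CUDA kernel, and the Python signature here is an assumption:

```python
# Illustrative sketch, NOT the actual TE CUDA kernel: derive per-sequence
# element offsets for a packed THD tensor from padded cumulative sequence
# lengths. In a [total_tokens, h, d] layout, sequence i starts at element
# cu_seqlens_padded[i] * h * d.
def cu_seqlens_padded_to_offsets(cu_seqlens_padded, h, d):
    """Map padded cu_seqlens (length b+1) to ragged element offsets."""
    return [c * h * d for c in cu_seqlens_padded]

# Example: b=2, sequences padded to lengths 4 and 8, h=2 heads, d=3.
offsets = cu_seqlens_padded_to_offsets([0, 4, 12], h=2, d=3)
# offsets == [0, 24, 72]
```

Stats would use the same mapping with d=1 (one statistic per head per token), though that detail is inferred from the commit message rather than stated in the diff.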
@sudhakarsingh27 sudhakarsingh27 requested a review from cyanguwa as a code owner May 14, 2026 19:09
@greptile-apps
Contributor

greptile-apps Bot commented May 14, 2026

Greptile Summary

This PR extends FP8 fused attention to support the THD (packed/ragged variable-length) format, enabling the FP8+THD combination that was previously only available for the F16 arbitrary-seqlen backend. The changes thread token counts (num_tokens_q, num_tokens_kv) and padded cu_seqlens pointers down through fused_attn_fp8_fwd/bwd into the cuDNN frontend graph builder, add ragged offset tensors for Q/K/V/O/Stats, and update workspace sizing to accommodate the new offset arrays.

  • fused_attn_fp8.cu: Both FWD and BWD _impl functions gain ragged offset graph tensors, adjusted workspace sizing with 16-byte alignment, and a cu_seqlens_padded_to_offsets kernel dispatch; the stats tensor stride is switched to the ragged layout {h*s_q, 1, h, 1} when use_ragged_stats is true.
  • fused_attn.cpp: The FP8 backend condition gains a new THD sub-clause (cuDNN 9.23+ / sm100+); the FP8 fwd/bwd call sites are updated to forward the new parameters.
  • test_attention.py: thd_thd_thd layout and thd format added to the FP8 vs F16 test matrix; two debug print statements were left in the test helper and need to be removed.
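The Stats stride switch mentioned in the first bullet can be illustrated with a small sketch. stats_strides is a hypothetical helper, not TE's API, and the non-ragged layout shown is an assumption included only for contrast:

```python
# Hypothetical helper illustrating the stride switch described in the
# review summary, for a Stats tensor of shape [b, h, s_q, 1].
def stats_strides(b, h, s_q, use_ragged_stats):
    if use_ragged_stats:
        # Ragged/THD layout from the summary: {h*s_q, 1, h, 1}.
        # Tokens are packed, so advancing one position along s_q jumps
        # over the h per-head stats of the current token.
        return (h * s_q, 1, h, 1)
    # Assumed contiguous [b, h, s_q, 1] layout for comparison.
    return (h * s_q, s_q, 1, 1)

print(stats_strides(2, 4, 128, True))   # (512, 1, 4, 1)
```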

Confidence Score: 3/5

The core CUDA/cuDNN graph logic appears consistent between FWD and BWD, but two issues need addressing before merge: debug prints committed to the test helper, and the FP8 backend selector accepting THD with non-padding mask types that the implementation rejects at runtime.

The ragged offset workspace arithmetic is careful and the FWD/BWD symmetry looks correct. However, nvte_get_fused_attn_backend will return NVTE_FP8 for THD + NVTE_NO_MASK or NVTE_CAUSAL_MASK, causing callers to dispatch to FP8 only to receive an NVTE_CHECK failure inside fused_attn_fp8_fwd_impl; this inconsistency is present on a newly-added code path. The debug prints in the test helper are an obvious oversight that produces noise in every CI run exercising the new THD test cases.

Files flagged: transformer_engine/common/fused_attn/fused_attn.cpp (backend selector mask-type restriction for THD) and tests/pytorch/attention/test_attention.py (debug prints).

Important Files Changed

  • tests/pytorch/attention/test_attention.py: Adds thd_thd_thd layout and thd format to the FP8 vs F16 test matrix; adds fp8_output and fast_zero_fill arguments to the DPA call; leaves two debug print statements that must be removed.
  • transformer_engine/common/fused_attn/fused_attn.cpp: Adds THD format support to the FP8 backend selection condition and plumbs t_q/t_kv token counts and padded cu_seqlens pointers through to the FP8 fwd/bwd calls; the new THD sub-condition in the backend selector does not restrict to padding mask types, diverging from the implementation's own enforcement.
  • transformer_engine/common/fused_attn/fused_attn_fp8.cu: Plumbs THD/ragged (packed variable-length) support into the FP8 forward and backward attention kernels: adds ragged offset tensors for Q/K/V/O/Stats, adjusts workspace sizing with aligned offsets, and adds cu_seqlens_padded_to_offsets kernel dispatch; logic appears consistent between FWD and BWD paths.
  • transformer_engine/common/fused_attn/fused_attn_fp8.h: Header updated to match the new num_tokens_q, num_tokens_kv, cu_seqlens_q_padded, and cu_seqlens_kv_padded parameters added to both fused_attn_fp8_fwd and fused_attn_fp8_bwd; no issues found.
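The "aligned workspace offsets" in the fused_attn_fp8.cu row can be sketched as follows. The names, the five-tensor count, and the int64 itemsize are illustrative assumptions inferred from the summary (the backward pass presumably also sizes dO/dQ/dK/dV offsets), not TE's actual code:

```python
# Sketch of 16-byte-aligned workspace carving for ragged offset arrays.
# Names and counts are illustrative, not TE's actual code.
def align16(n):
    """Round a byte offset up to the next multiple of 16."""
    return (n + 15) & ~15

def ragged_offset_workspace_bytes(max_b, n_tensors=5, itemsize=8):
    """Total bytes for n_tensors offset arrays of (max_b + 1) int64
    entries, each array starting at a 16-byte-aligned offset."""
    total = 0
    for _ in range(n_tensors):  # e.g. Q, K, V, O, Stats in the fwd pass
        total = align16(total) + (max_b + 1) * itemsize
    return total

print(ragged_offset_workspace_bytes(2))  # 152
```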

Sequence Diagram

sequenceDiagram
    participant Caller
    participant fused_attn.cpp as fused_attn.cpp
    participant BackendSelector as nvte_get_fused_attn_backend
    participant FP8FWD as fused_attn_fp8_fwd
    participant FP8Impl as fused_attn_fp8_fwd_impl
    participant cuDNN as cuDNN Frontend Graph

    Caller->>fused_attn.cpp: Q,K,V + cu_seqlens + cu_seqlens_padded + num_tokens_q/kv
    fused_attn.cpp->>BackendSelector: qkv_format, mask_type, ...
    BackendSelector-->>fused_attn.cpp: "NVTE_FP8 (if THD + cuDNN>=9.23 + sm100+)"
    fused_attn.cpp->>FP8FWD: batch, t_q, t_kv, cu_seqlens_q/kv_padded, ...
    FP8FWD->>FP8FWD: get_max_batch_size / get_max_tokens
    FP8FWD->>FP8Impl: max_b, max_t_q, max_t_kv, devPtrSeqOffsetsQ/KV, ...
    FP8Impl->>FP8Impl: detect is_ragged_q/kv, use_ragged_stats
    FP8Impl->>FP8Impl: "b=max_b, s_q=max_t_q, s_kv=max_t_kv (sm!=120)"
    FP8Impl->>cuDNN: build graph with ragged offset tensors (offset_q/k/v/o/stats)
    FP8Impl->>FP8Impl: launch cu_seqlens_padded_to_offsets kernel
    FP8Impl->>cuDNN: execute plan with variant_pack
    cuDNN-->>Caller: O (FP8 output) + Stats

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +2515 to +2516
print(f"qkv_format: {qkv_format}")
print(f"inp shape: {inp[0].shape}, {inp[1].shape}, {inp[2].shape}")
Contributor


P1 Debug prints committed to test code

Two print statements were left in _run_dpa_fp8_vs_f16, which will produce noise on every test run and in CI logs. These look like leftover debug instrumentation from development.

Suggested change
print(f"qkv_format: {qkv_format}")
print(f"inp shape: {inp[0].shape}, {inp[1].shape}, {inp[2].shape}")
with autocast(enabled=fp8_dpa, recipe=fp8_recipe):

Comment on lines +291 to +292
(cudnn_runtime_version >= 92300 && sm_arch_ >= 100 &&
qkv_format == NVTE_QKV_Format::NVTE_THD && supported_ragged_offset_size)) &&
Contributor


P1 Backend selection for THD permits non-padding mask types that the implementation rejects

nvte_get_fused_attn_backend can return NVTE_FP8 for qkv_format == NVTE_THD with attn_mask_type == NVTE_NO_MASK or NVTE_CAUSAL_MASK, because the only mask-type filter is the outer condition shared with the BSHD/SBHD/BHSD paths. However, fused_attn_fp8_fwd_impl (and bwd_impl) enforce NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"), which means those combinations would pass backend selection only to fail at kernel dispatch with a confusing internal error. The THD sub-condition should additionally require attn_mask_type == NVTE_PADDING_MASK || attn_mask_type == NVTE_PADDING_CAUSAL_MASK.

Suggested change
(cudnn_runtime_version >= 92300 && sm_arch_ >= 100 &&
qkv_format == NVTE_QKV_Format::NVTE_THD && supported_ragged_offset_size)) &&
(cudnn_runtime_version >= 92300 && sm_arch_ >= 100 &&
qkv_format == NVTE_QKV_Format::NVTE_THD && supported_ragged_offset_size &&
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&

