Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1949,8 +1949,8 @@ def get_model(dtype, config):
}

param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd", "thd_thd_thd"]
qkv_format_fp8_vs_f16 = ["bshd", "sbhd", "thd"]


@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
Expand Down Expand Up @@ -2512,6 +2512,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")

print(f"qkv_format: {qkv_format}")
print(f"inp shape: {inp[0].shape}, {inp[1].shape}, {inp[2].shape}")
Comment on lines +2515 to +2516
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Debug prints committed to test code

Two print statements were left in _run_dpa_fp8_vs_f16, which will produce noise on every test run and in CI logs. These look like leftover debug instrumentation from development.

Suggested change
print(f"qkv_format: {qkv_format}")
print(f"inp shape: {inp[0].shape}, {inp[1].shape}, {inp[2].shape}")
with autocast(enabled=fp8_dpa, recipe=fp8_recipe):

with autocast(enabled=fp8_dpa, recipe=fp8_recipe):
out = dpa(
inp[0],
Expand All @@ -2526,6 +2528,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
fp8_output=fp8_dpa,
fast_zero_fill=False,
)
if is_training:
out.backward(out_grad)
Expand Down
17 changes: 11 additions & 6 deletions transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) &&
// pre-9.21: {bshd, sbhd}, {vanilla}
// 9.21+: {bshd, sbhd, bhsd}, {vanilla, off-by-one, learnable}
// 9.23+ sm100+: {thd} (ragged/packed variable-length)
((cudnn_runtime_version < 92100 &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) ||
(cudnn_runtime_version >= 92100 &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD ||
qkv_format == NVTE_QKV_Format::NVTE_BHSD))) &&
qkv_format == NVTE_QKV_Format::NVTE_BHSD)) ||
(cudnn_runtime_version >= 92300 && sm_arch_ >= 100 &&
qkv_format == NVTE_QKV_Format::NVTE_THD && supported_ragged_offset_size)) &&
Comment on lines +291 to +292
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Backend selection for THD permits non-padding mask types that the implementation rejects

nvte_get_fused_attn_backend can return NVTE_FP8 for qkv_format == NVTE_THD with attn_mask_type == NVTE_NO_MASK or NVTE_CAUSAL_MASK, because the only mask-type filter is the outer condition shared with the BSHD/SBHD/BHSD paths. However, fused_attn_fp8_fwd_impl (and bwd_impl) enforce NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"), which means those combinations would pass backend selection only to fail at kernel dispatch with a confusing internal error. The THD sub-condition should additionally require attn_mask_type == NVTE_PADDING_MASK || attn_mask_type == NVTE_PADDING_CAUSAL_MASK.

Suggested change
(cudnn_runtime_version >= 92300 && sm_arch_ >= 100 &&
qkv_format == NVTE_QKV_Format::NVTE_THD && supported_ragged_offset_size)) &&
(cudnn_runtime_version >= 92300 && sm_arch_ >= 100 &&
qkv_format == NVTE_QKV_Format::NVTE_THD && supported_ragged_offset_size &&
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&

!requires_64bit_ragged_offset &&
// 9.10.0: known bugs with SDPA FP8
(cudnn_runtime_version != 91000) && !return_max_logit) {
Expand Down Expand Up @@ -623,12 +626,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training,
fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training,
attn_scale, dropout, qkv_layout, o_format, qkv_scale_inv_format, bias_type,
attn_mask_type, softmax_type, window_size_left, window_size_right,
bottom_right_diagonal, input_Q, input_K, input_V, input_SoftmaxOffset,
input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_rng_state, wkspace, stream, handle);
} else {
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
Expand Down Expand Up @@ -725,14 +729,15 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) {
input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout,
qkv_layout, o_format, do_format, dqkv_layout, qkv_scale_inv_format,
fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale,
dropout, qkv_layout, o_format, do_format, dqkv_layout, qkv_scale_inv_format,
do_scale_inv_format, bias_type, attn_mask_type, softmax_type,
window_size_left, window_size_right, bottom_right_diagonal, deterministic,
input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_S,
input_SoftmaxOffset, input_output_dP, output_dQ, output_dK, output_dV,
output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv,
input_rng_state, wkspace, stream, handle);
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state,
wkspace, stream, handle);
} else {
NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
}
Expand Down
Loading
Loading