From 1b9b13e0cc2df9aab5840b51e1c4b59832ee3938 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 9 Dec 2025 15:37:17 -0800 Subject: [PATCH 1/5] update_filter_fp8_thd_attention Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention.py | 28 ++++++++++++----------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 4aedcff1b8..dfd797cd66 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1757,23 +1757,23 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - "fp8_9": ModelConfig(2, 2048, 16, 128), - "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), - "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), - "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), - "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), - "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), - "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), - "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), - "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), + # "fp8_9": ModelConfig(2, 2048, 16, 128), + # "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), + # "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), + # "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), + # "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), + # "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + # "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), + # "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), + # "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), + # "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), - "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), + # "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), } param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] -qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"] -qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] +qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd", "thd_thd_thd"] +qkv_format_fp8_vs_f16 = ["bshd", "sbhd", "thd"] @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @@ -2300,6 +2300,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda") + print(f"qkv_format: {qkv_format}") + print(f"inp shape: {inp[0].shape}, {inp[1].shape}, {inp[2].shape}") with autocast(enabled=fp8_dpa, recipe=fp8_recipe): out = dpa( inp[0], From dc4a172fb41a2b61993d185359ae150fdbdfa8db Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 2 Jan 2026 15:22:51 -0800 Subject: [PATCH 2/5] add for fp8+thd debug Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention.py | 1 + .../common/fused_attn/fused_attn.cpp | 3 +-- .../common/fused_attn/fused_attn_fp8.cu | 27 +++++-------------- 3 files changed, 8 insertions(+), 23 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 553d7b30db..539eaa9ea4 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2316,6 +2316,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, fp8_output=fp8_dpa, + fast_zero_fill=False, ) if is_training: out.backward(out_grad) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index fde0d38921..1a9ccc26ee 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -252,7 +252,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { @@ -539,7 +538,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, h = input_QKV->data.shape[ndim - 3]; } else { NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); - } + } size_t d = input_QKV->data.shape[ndim - 1]; size_t t = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 3630041ccf..9cdbb49492 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2474,7 +2474,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) + || (qkv_format == NVTE_QKV_Format::NVTE_THD)) { fused_attn::fused_attn_fp8_fwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, @@ -2482,16 +2483,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, - p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, - devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, - devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); + NVTE_ERROR("FP8 fused attention only supports qkv_format=bshd/sbhd/thd. \n"); } if (workspace_size > 0) { @@ -2569,7 +2562,8 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { + if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) + || (qkv_format == NVTE_QKV_Format::NVTE_THD)) { fused_attn::fused_attn_fp8_bwd_impl_v1( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, @@ -2580,17 +2574,8 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, - qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, - devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, - devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, - devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, - devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); + NVTE_ERROR("FP8 fused attention only supports qkv_format=bshd/sbhd/thd. \n"); } if (workspace_size > 0) { From 10aa1940a53bb3b16b0ffa043903596f9b1e3eb7 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 2 Jan 2026 15:49:16 -0800 Subject: [PATCH 3/5] uncomment the configs Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 539eaa9ea4..9b940caae5 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1757,18 +1757,18 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) - # "fp8_9": ModelConfig(2, 2048, 16, 128), - # "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), - # "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), - # "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), - # "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), - # "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - # "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), - # "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), - # "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), - # "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), + "fp8_9": ModelConfig(2, 2048, 16, 128), + "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), + "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), + "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), + "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), + "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), + "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), + "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), + "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), - # "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), + "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), } param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] From ad50a2e7b8daf3ff7268ea8b62bd4b827360450d Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 12 May 2026 16:48:34 -0700 Subject: [PATCH 4/5] Re-plumb FP8+THD ragged-offset support on top of merged main Mirrors the F16 arbitrary_seqlen ragged-offset pattern in the FP8 path: - Backend selector: enable FP8+THD for cuDNN >= 9.23 on sm >= 100 - fwd/bwd _impl: ragged detection, batch/seqlen bucketing, set_ragged_offset() on Q/K/V/O/dO/dQ/dK/dV/Stats, workspace allocation for ragged offsets, cu_seqlens_padded_to_offsets kernel - fwd/bwd dispatchers: accept num_tokens_q/kv, cu_seqlens_padded, compute max_batch/max_tokens, THD Stats shape Signed-off-by: Sudhakar Singh --- .../common/fused_attn/fused_attn.cpp | 34 +- .../common/fused_attn/fused_attn_fp8.cu | 429 ++++++++++++++++-- .../common/fused_attn/fused_attn_fp8.h | 40 +- 3 files changed, 422 insertions(+), 81 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index d2eb1a831c..df70bcce66 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -281,12 +281,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) && // pre-9.21: {bshd, sbhd}, {vanilla} // 9.21+: {bshd, sbhd, bhsd}, {vanilla, off-by-one, learnable} + // 9.23+ sm100+: {thd} (ragged/packed variable-length) ((cudnn_runtime_version < 92100 && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) || (cudnn_runtime_version >= 92100 && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || - qkv_format == NVTE_QKV_Format::NVTE_BHSD))) && + qkv_format == NVTE_QKV_Format::NVTE_BHSD)) || + (cudnn_runtime_version >= 92300 && sm_arch_ >= 100 && + qkv_format == NVTE_QKV_Format::NVTE_THD && supported_ragged_offset_size)) && !requires_64bit_ragged_offset && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { @@ -623,12 +626,14 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, - attn_scale, dropout, qkv_layout, o_format, qkv_scale_inv_format, bias_type, - attn_mask_type, softmax_type, window_size_left, window_size_right, - bottom_right_diagonal, input_Q, input_K, input_V, input_SoftmaxOffset, - input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, + is_training, attn_scale, dropout, qkv_layout, o_format, + qkv_scale_inv_format, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, input_Q, + input_K, input_V, input_SoftmaxOffset, input_output_S, output_O, + Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, + wkspace, stream, handle); } else { NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } @@ -725,13 +730,14 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, - qkv_layout, o_format, do_format, dqkv_layout, qkv_scale_inv_format, - do_scale_inv_format, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, deterministic, - input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_S, - input_SoftmaxOffset, input_output_dP, output_dQ, output_dK, output_dV, - output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, + attn_scale, dropout, qkv_layout, o_format, do_format, dqkv_layout, + qkv_scale_inv_format, do_scale_inv_format, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, + input_M, input_S, input_SoftmaxOffset, input_output_dP, output_dQ, + output_dK, output_dV, output_dSoftmaxOffset, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); } else { NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index eab1ae02e6..978a6c0035 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -6,6 +6,7 @@ #include "../common.h" #include "../cudnn_utils.h" +#include "../util/cuda_runtime.h" #include "../util/system.h" #include "fused_attn_fp8.h" #include "utils.h" @@ -18,6 +19,7 @@ using namespace transformer_engine; // fused attention FWD FP8 with FE 1.0+ void fused_attn_fp8_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, @@ -25,7 +27,8 @@ void fused_attn_fp8_fwd_impl( void* devPtrSoftmaxOffset, void* devPtrM, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, - void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, + void* devPtrcuSeqlensKV, void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsKV, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, NVTEScalingMode scaling_mode, NVTE_QKV_Format qkv_scale_inv_format, void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -60,6 +63,27 @@ void fused_attn_fp8_fwd_impl( NVTE_CHECK(!is_mxfp8 || cudnn_runtime_version >= 92100, "MXFP8 fused attention requires cuDNN 9.21.0 or later!"); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); + bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); + const int device_id = cuda::current_device(); + const int sm_arch_ = cuda::sm_arch(device_id); + bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120; + + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; + + int64_t actual_b = b; + if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { + NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + if (sm_arch_ != 120) { + b = max_b; + s_q = is_ragged_q ? max_t_q : s_q; + s_kv = is_ragged_kv ? max_t_kv : s_kv; + } + } + try { FADescriptor_v1 descriptor{b, h, @@ -121,6 +145,11 @@ void fused_attn_fp8_fwd_impl( std::shared_ptr, // softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv + std::shared_ptr, // offset_q + std::shared_ptr, // offset_o + std::shared_ptr, // offset_k + std::shared_ptr, // offset_v + std::shared_ptr, // offset_stats std::shared_ptr, // dropout_seed std::shared_ptr>; // dropout_offset @@ -146,6 +175,8 @@ void fused_attn_fp8_fwd_impl( std::shared_ptr descale_q, descale_k, descale_v; std::shared_ptr descale_s, scale_s, scale_o; std::shared_ptr bias, softmax_offset, seq_q, seq_kv; + std::shared_ptr offset_q, offset_k, offset_v, offset_o, + offset_stats; std::shared_ptr dropout_seed, dropout_offset; // Q, K, V, attn_scale @@ -157,6 +188,14 @@ void fused_attn_fp8_fwd_impl( .set_dim({b, h, s_q, d_qk}) .set_stride(q_strides) .set_data_type(qkv_tensor_type)); + if (is_ragged_q) { + offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + Q->set_ragged_offset(offset_q); + } K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") .set_dim({b, hg, s_kv, d_qk}) @@ -167,6 +206,20 @@ void fused_attn_fp8_fwd_impl( .set_dim({b, hg, s_kv, d_v}) .set_stride(v_strides) .set_data_type(qkv_tensor_type)); + if (is_ragged_kv) { + offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + K->set_ragged_offset(offset_k); + V->set_ragged_offset(offset_v); + } attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") .set_dim({1, 1, 1, 1}) @@ -323,15 +376,33 @@ void fused_attn_fp8_fwd_impl( .set_dim({b, h, s_q, d_v}) .set_stride(o_strides) .set_data_type(o_tensor_type); + if (is_ragged_q) { + offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + O->set_ragged_offset(offset_o); + } amax_o->set_output(!is_mxfp8) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - Stats->set_output(true) - .set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}); + if (use_ragged_stats) { + offset_stats = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + } + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); + if (use_ragged_stats) { + Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); + } else { + Stats->set_stride({h * s_q, s_q, 1, 1}); + } std::tuple, // Q std::shared_ptr, // K @@ -357,6 +428,12 @@ void fused_attn_fp8_fwd_impl( is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); + auto offset_qo_tuple = + is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr); + auto offset_kv_tuple = + is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr); + auto offset_s_tuple = + use_ragged_stats ? std::make_tuple(offset_stats) : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); @@ -367,22 +444,37 @@ void fused_attn_fp8_fwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, - softmax_offset_tuple, padding_tuple, dropout_tuple); + softmax_offset_tuple, padding_tuple, offset_qo_tuple, offset_kv_tuple, + offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, - attn_scale, O, amax_s, amax_o, Stats, bias, softmax_offset, seq_q, seq_kv, dropout_seed, - dropout_offset] = get_graph(sdpa_fp8_fprop_cache, descriptor); - - auto plan_workspace_size = mha_graph->get_workspace_size(); + attn_scale, O, amax_s, amax_o, Stats, bias, softmax_offset, seq_q, seq_kv, offset_q, + offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = + get_graph(sdpa_fp8_fprop_cache, descriptor); + + auto plan_workspace_size = alignTo<16>(mha_graph->get_workspace_size()); + const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t)); + const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0; + const size_t num_bytes_per_ragged_offset = + alignTo<16>(((b + 1) * typeToNumBits(ragged_offset_type)) / 8); + size_t seqlen_offsets_workspace_size = 0; + if (is_ragged_q || is_ragged_kv) { + size_t count = 2 * (static_cast(is_ragged_q) + static_cast(is_ragged_kv)); + if (use_ragged_stats) { + seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset; + } else { + seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset; + } + } // Exit to request upper level API to allocate memory if needed - size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); if (workspace == nullptr) { - *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; + *workspace_size = + plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size; return; } @@ -420,9 +512,9 @@ void fused_attn_fp8_fwd_impl( constexpr size_t nthreads_per_block = 128; const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; void* devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - void* devActualSeqlenKV = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + void* devActualSeqlenKV = static_cast(devActualSeqlenQ) + num_bytes_per_seqlen; cu_seqlens_to_actual_seqlens<<>>( - b, b, static_cast(devPtrcuSeqlensQ), // TODO(pass max_b) + actual_b, b, static_cast(devPtrcuSeqlensQ), static_cast(devPtrcuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); NVTE_CHECK_CUDA(cudaGetLastError()); @@ -430,6 +522,49 @@ void fused_attn_fp8_fwd_impl( variant_pack[seq_kv] = devActualSeqlenKV; } + if (is_ragged_q || is_ragged_kv) { + constexpr size_t nthreads_per_block = 128; + const size_t grid = (b + nthreads_per_block) / nthreads_per_block; + void* devOffsets = + static_cast(workspace) + plan_workspace_size + actual_seqlen_workspace_size; + void* devOffsetsQ = nullptr; + void* devOffsetsO = nullptr; + if (is_ragged_q) { + devOffsetsQ = devOffsets; + devOffsetsO = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; + } + void* devOffsetsK = nullptr; + void* devOffsetsV = nullptr; + if (is_ragged_kv) { + devOffsetsK = static_cast(devOffsets) + + static_cast(is_ragged_q) * 2 * num_bytes_per_ragged_offset; + devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; + } + void* devOffsetsS = nullptr; + if (use_ragged_stats) { + devOffsetsS = static_cast(devOffsets) + + (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * + num_bytes_per_ragged_offset; + } + cu_seqlens_padded_to_offsets<<>>( + layout_group, actual_b, b, h, hg, d_qk, d_v, + static_cast(devPtrSeqOffsetsQ), + static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, + devOffsetsV, devOffsetsO, devOffsetsS); + NVTE_CHECK_CUDA(cudaGetLastError()); + if (is_ragged_q) { + variant_pack[offset_q] = devOffsetsQ; + variant_pack[offset_o] = devOffsetsO; + } + if (is_ragged_kv) { + variant_pack[offset_k] = devOffsetsK; + variant_pack[offset_v] = devOffsetsV; + } + if (use_ragged_stats) { + variant_pack[offset_stats] = devOffsetsS; + } + } + if (is_dropout) { variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_offset] = devPtrDropoutOffset; @@ -448,6 +583,7 @@ void fused_attn_fp8_fwd_impl( // fused attention BWD FP8 with FE 1.0+ void fused_attn_fp8_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, @@ -460,7 +596,8 @@ void fused_attn_fp8_bwd_impl( void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrQ_t, void* devPtrK_t, void* devPtrdO_f16, void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, - void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, + void* devPtrcuSeqlensKV, void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsKV, + void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, NVTE_QKV_Format qkv_scale_inv_format, @@ -497,6 +634,27 @@ void fused_attn_fp8_bwd_impl( NVTE_CHECK(!is_mxfp8 || cudnn_runtime_version >= 92100, "MXFP8 fused attention requires cuDNN 9.21.0 or later!"); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); + bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); + const int device_id = cuda::current_device(); + const int sm_arch_ = cuda::sm_arch(device_id); + bool use_ragged_stats = is_ragged_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120; + + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; + + int64_t actual_b = b; + if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { + NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); + if (sm_arch_ != 120) { + b = max_b; + s_q = is_ragged_q ? max_t_q : s_q; + s_kv = is_ragged_kv ? max_t_kv : s_kv; + } + } + bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); @@ -582,6 +740,11 @@ void fused_attn_fp8_bwd_impl( std::shared_ptr, // d_softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv + std::shared_ptr, // offset_q + std::shared_ptr, // offset_o + std::shared_ptr, // offset_k + std::shared_ptr, // offset_v + std::shared_ptr, // offset_stats std::shared_ptr, // dropout_seed std::shared_ptr>; // dropout_offset @@ -614,6 +777,8 @@ void fused_attn_fp8_bwd_impl( std::shared_ptr scale_dQ, scale_dK, scale_dV; std::shared_ptr bias, dBias, softmax_offset, d_softmax_offset; std::shared_ptr seq_q, seq_kv; + std::shared_ptr offset_q, offset_k, offset_v, offset_o, + offset_stats; std::shared_ptr dropout_seed, dropout_offset; // Q, K, V, O, dO, stats, attn_scale @@ -627,6 +792,19 @@ void fused_attn_fp8_bwd_impl( .set_dim({b, h, s_q, d_qk}) .set_stride(q_strides) .set_data_type(qkv_tensor_type)); + if (is_ragged_q) { + offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + Q->set_ragged_offset(offset_q); + } K = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("K") .set_dim({b, hg, s_kv, d_qk}) @@ -637,21 +815,53 @@ void fused_attn_fp8_bwd_impl( .set_dim({b, hg, s_kv, d_v}) .set_stride(v_strides) .set_data_type(qkv_tensor_type)); + if (is_ragged_kv) { + offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_v") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + K->set_ragged_offset(offset_k); + V->set_ragged_offset(offset_v); + } O = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") .set_dim({b, h, s_q, d_v}) .set_stride(o_strides) .set_data_type(o_tensor_type)); + if (is_ragged_q) { + O->set_ragged_offset(offset_o); + } dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") .set_dim({b, h, s_q, d_v}) .set_stride(dO_strides) .set_data_type(do_tensor_type)); + if (is_ragged_q) { + dO->set_ragged_offset(offset_o); + } + if (use_ragged_stats) { + offset_stats = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + } Stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Stats") .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); + if (use_ragged_stats) { + Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); + } else { + Stats->set_stride({h * s_q, s_q, 1, 1}); + } attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("attn_scale") .set_dim({1, 1, 1, 1}) @@ -901,6 +1111,9 @@ void fused_attn_fp8_bwd_impl( .set_dim({b, h, s_q, d_qk}) .set_stride(dq_strides) .set_data_type(dqkv_tensor_type); + if (is_ragged_q) { + dQ->set_ragged_offset(offset_q); + } dK->set_output(true) .set_dim({b, hg, s_kv, d_qk}) .set_stride(dk_strides) @@ -909,6 +1122,10 @@ void fused_attn_fp8_bwd_impl( .set_dim({b, hg, s_kv, d_v}) .set_stride(dv_strides) .set_data_type(dqkv_tensor_type); + if (is_ragged_kv) { + dK->set_ragged_offset(offset_k); + dV->set_ragged_offset(offset_v); + } amax_dQ->set_output(!is_mxfp8) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -967,6 +1184,12 @@ void fused_attn_fp8_bwd_impl( : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); + auto offset_qo_tuple = + is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr); + auto offset_kv_tuple = + is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr); + auto offset_s_tuple = + use_ragged_stats ? std::make_tuple(offset_stats) : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) : std::make_tuple(nullptr, nullptr); @@ -978,7 +1201,8 @@ void fused_attn_fp8_bwd_impl( auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, mxfp8_tensors_tuple, - bias_tuple, softmax_offset_tuple, padding_tuple, dropout_tuple); + bias_tuple, softmax_offset_tuple, padding_tuple, offset_qo_tuple, + offset_kv_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; @@ -987,14 +1211,27 @@ void fused_attn_fp8_bwd_impl( descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP, Q_t, K_t, dO_f16, dO_t, descale_q_t, descale_k_t, descale_dO_t, bias, dBias, softmax_offset, d_softmax_offset, seq_q, seq_kv, - dropout_seed, dropout_offset] = get_graph(sdpa_fp8_bprop_cache, descriptor); - - auto plan_workspace_size = mha_graph->get_workspace_size(); + offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = + get_graph(sdpa_fp8_bprop_cache, descriptor); + + auto plan_workspace_size = alignTo<16>(mha_graph->get_workspace_size()); + const size_t num_bytes_per_seqlen = alignTo<16>(b * sizeof(int32_t)); + const size_t actual_seqlen_workspace_size = is_padding ? 2 * num_bytes_per_seqlen : 0; + const size_t num_bytes_per_ragged_offset = + alignTo<16>(((b + 1) * typeToNumBits(ragged_offset_type)) / 8); + size_t seqlen_offsets_workspace_size = 0; + if (is_ragged_q || is_ragged_kv) { + size_t count = 2 * (static_cast(is_ragged_q) + static_cast(is_ragged_kv)); + if (use_ragged_stats) { + seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset; + } else { + seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset; + } + } - // Exit to request upper level API to allocate memory if needed - size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); if (workspace == nullptr) { - *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; + *workspace_size = + plan_workspace_size + actual_seqlen_workspace_size + seqlen_offsets_workspace_size; return; } @@ -1060,9 +1297,9 @@ void fused_attn_fp8_bwd_impl( constexpr size_t nthreads_per_block = 128; const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; void* devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - void* devActualSeqlenKV = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + void* devActualSeqlenKV = static_cast(devActualSeqlenQ) + num_bytes_per_seqlen; cu_seqlens_to_actual_seqlens<<>>( - b, b, static_cast(devPtrcuSeqlensQ), // TODO(pass max_b) + actual_b, b, static_cast(devPtrcuSeqlensQ), static_cast(devPtrcuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); NVTE_CHECK_CUDA(cudaGetLastError()); @@ -1070,6 +1307,49 @@ void fused_attn_fp8_bwd_impl( variant_pack[seq_kv] = devActualSeqlenKV; } + if (is_ragged_q || is_ragged_kv) { + constexpr size_t nthreads_per_block = 128; + const size_t grid = (b + nthreads_per_block) / nthreads_per_block; + void* devOffsets = + static_cast(workspace) + plan_workspace_size + actual_seqlen_workspace_size; + void* devOffsetsQ = nullptr; + void* devOffsetsO = nullptr; + if (is_ragged_q) { + devOffsetsQ = devOffsets; + devOffsetsO = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; + } + void* devOffsetsK = nullptr; + void* devOffsetsV = nullptr; + if (is_ragged_kv) { + devOffsetsK = static_cast(devOffsets) + + static_cast(is_ragged_q) * 2 * num_bytes_per_ragged_offset; + devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; + } + void* devOffsetsS = nullptr; + if (use_ragged_stats) { + devOffsetsS = static_cast(devOffsets) + + (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * + num_bytes_per_ragged_offset; + } + cu_seqlens_padded_to_offsets<<>>( + layout_group, actual_b, b, h, hg, d_qk, d_v, + static_cast(devPtrSeqOffsetsQ), + static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, + devOffsetsV, devOffsetsO, devOffsetsS); + NVTE_CHECK_CUDA(cudaGetLastError()); + if (is_ragged_q) { + variant_pack[offset_q] = devOffsetsQ; + variant_pack[offset_o] = devOffsetsO; + } + if (is_ragged_kv) { + variant_pack[offset_k] = devOffsetsK; + variant_pack[offset_v] = devOffsetsV; + } + if (use_ragged_stats) { + variant_pack[offset_stats] = devOffsetsS; + } + } + if (is_dropout) { variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_offset] = devPtrDropoutOffset; @@ -1091,14 +1371,16 @@ void fused_attn_fp8_bwd_impl( // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, - bool bottom_right_diagonal, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, + const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_SoftmaxOffset, Tensor* input_output_S, Tensor* output_O, NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { + const Tensor* cu_seqlens_q_padded, const Tensor* cu_seqlens_kv_padded, const Tensor* rng_state, + Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; void *devPtrQ = nullptr, *devPtrK = nullptr, *devPtrV = nullptr; void *devPtrDescaleQ = nullptr, *devPtrDescaleK = nullptr, *devPtrDescaleV = nullptr; @@ -1125,12 +1407,36 @@ void fused_attn_fp8_fwd( if (softmax_type != NVTE_VANILLA_SOFTMAX) { devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; } + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + const auto cudnn_runtime_version = cudnnGetVersion(); + + void* devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; + void* devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; + + size_t max_batch_size = 0; + size_t max_tokens_q = 0; + size_t max_tokens_kv = 0; + if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = fused_attn::get_max_batch_size(batch); + } + if (q_format == NVTE_QKV_Format::NVTE_THD) { + max_tokens_q = fused_attn::get_max_tokens(num_tokens_q); + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + max_tokens_kv = fused_attn::get_max_tokens(num_tokens_kv); + } + void* devPtrM = nullptr; if (Aux_CTX_Tensors->size == 0) { int i = 0; Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_M->data.dptr = nullptr; - output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_M->data.shape = {num_tokens_q, num_attn_heads, 1}; + } else { + output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } output_M->data.dtype = DType::kFloat32; Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = nullptr; @@ -1172,18 +1478,20 @@ void fused_attn_fp8_fwd( NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || - (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { + (qkv_format == NVTE_QKV_Format::NVTE_BHSD) || (qkv_format == NVTE_QKV_Format::NVTE_THD)) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + max_batch_size, max_tokens_q, max_tokens_kv, is_training, attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrSoftmaxOffset, devPtrM, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, - devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, qkv_scale_inv_format, workspace->data.dptr, &workspace_size, stream, handle); } else { - NVTE_ERROR("FP8 fused attention only supports qkv_format=BSHD, SBHD, or BHSD.\n"); + NVTE_ERROR("FP8 fused attention only supports qkv_format=BSHD, SBHD, BHSD, or THD.\n"); } if (workspace_size > 0) { @@ -1201,18 +1509,20 @@ void fused_attn_fp8_fwd( // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, - NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, - NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, - bool bottom_right_diagonal, bool deterministic, const Tensor* input_Q, const Tensor* input_K, - const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, - const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_S, - const Tensor* input_SoftmaxOffset, Tensor* input_output_dP, const Tensor* output_dQ, - const Tensor* output_dK, const Tensor* output_dV, Tensor* output_dSoftmaxOffset, - const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, - Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, + bool deterministic, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, + const Tensor* input_O, const Tensor* input_dO, const Tensor* input_dO_f16, + const Tensor* input_M, const Tensor* input_S, const Tensor* input_SoftmaxOffset, + Tensor* input_output_dP, const Tensor* output_dQ, const Tensor* output_dK, + const Tensor* output_dV, Tensor* output_dSoftmaxOffset, const Tensor* cu_seqlens_q, + const Tensor* cu_seqlens_kv, const Tensor* cu_seqlens_q_padded, + const Tensor* cu_seqlens_kv_padded, const Tensor* rng_state, Tensor* workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; @@ -1277,6 +1587,25 @@ void fused_attn_fp8_bwd( devPtrScaledV = output_dV->scale.dptr; } + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + + void* devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; + void* devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; + + size_t max_batch_size = 0; + size_t max_tokens_q = 0; + size_t max_tokens_kv = 0; + if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { + max_batch_size = fused_attn::get_max_batch_size(batch); + } + if (q_format == NVTE_QKV_Format::NVTE_THD) { + max_tokens_q = fused_attn::get_max_tokens(num_tokens_q); + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + max_tokens_kv = fused_attn::get_max_tokens(num_tokens_kv); + } + void* devPtrcuSeqlensQ = reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); void* devPtrcuSeqlensKV = @@ -1293,9 +1622,10 @@ void fused_attn_fp8_bwd( NVTE_QKV_Format dqkv_format = nvte_get_qkv_format(dqkv_layout); if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD) || - (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { + (dqkv_format == NVTE_QKV_Format::NVTE_BHSD) || (dqkv_format == NVTE_QKV_Format::NVTE_THD)) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, + max_batch_size, max_tokens_q, max_tokens_kv, attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrO, devPtrdO, devPtrSoftmaxOffset, devPtrdQ, @@ -1304,12 +1634,13 @@ void fused_attn_fp8_bwd( devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, devPtrdO_f16, devPtrdO_t, devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, devPtrcuSeqlensQ, devPtrcuSeqlensKV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), input_dO->scaling_mode, qkv_scale_inv_format, do_scale_inv_format, workspace->data.dptr, &workspace_size, stream, handle); } else { - NVTE_ERROR("FP8 fused attention only supports dqkv_format=BSHD, SBHD, or BHSD.\n"); + NVTE_ERROR("FP8 fused attention only supports dqkv_format=BSHD, SBHD, BHSD, or THD.\n"); } if (workspace_size > 0) { diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index b9660128ca..5b3b2fff14 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -15,28 +15,32 @@ namespace transformer_engine { // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, - bool bottom_right_diagonal, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_SoftmaxOffset, Tensor *input_output_S, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); // fused attention BWD FP8 with separate Q, K, V void fused_attn_fp8_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, - NVTE_QKV_Layout dqkv_layout, NVTE_QKV_Format qkv_scale_inv_format, - NVTE_QKV_Format do_scale_inv_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, - bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_S, - const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, const Tensor *output_dQ, - const Tensor *output_dK, const Tensor *output_dV, Tensor *output_dSoftmaxOffset, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, + size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, + bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_dO_f16, + const Tensor *input_M, const Tensor *input_S, const Tensor *input_SoftmaxOffset, + Tensor *input_output_dP, const Tensor *output_dQ, const Tensor *output_dK, + const Tensor *output_dV, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); } // namespace transformer_engine From bfbafe9dbcb6ace0e1a378dafe93b6ccfdd1a7a9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 19:10:21 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 33 ++++--- .../common/fused_attn/fused_attn_fp8.cu | 86 +++++++++---------- 2 files changed, 56 insertions(+), 63 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index df70bcce66..4742d1bdd7 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -626,14 +626,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { - fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, - is_training, attn_scale, dropout, qkv_layout, o_format, - qkv_scale_inv_format, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, input_Q, - input_K, input_V, input_SoftmaxOffset, input_output_S, output_O, - Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, - wkspace, stream, handle); + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training, + attn_scale, dropout, qkv_layout, o_format, qkv_scale_inv_format, bias_type, + attn_mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, input_Q, input_K, input_V, input_SoftmaxOffset, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_rng_state, wkspace, stream, handle); } else { NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } @@ -730,15 +729,15 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } - fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, - attn_scale, dropout, qkv_layout, o_format, do_format, dqkv_layout, - qkv_scale_inv_format, do_scale_inv_format, bias_type, attn_mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, - deterministic, input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, - input_M, input_S, input_SoftmaxOffset, input_output_dP, output_dQ, - output_dK, output_dV, output_dSoftmaxOffset, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, - input_rng_state, wkspace, stream, handle); + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, + dropout, qkv_layout, o_format, do_format, dqkv_layout, qkv_scale_inv_format, + do_scale_inv_format, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, deterministic, + input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_S, + input_SoftmaxOffset, input_output_dP, output_dQ, output_dK, output_dV, + output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, + wkspace, stream, handle); } else { NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 978a6c0035..f29f0d92be 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -19,16 +19,15 @@ using namespace transformer_engine; // fused attention FWD FP8 with FE 1.0+ void fused_attn_fp8_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t max_b, int64_t max_t_q, int64_t max_t_kv, - bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, - NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrSoftmaxOffset, void* devPtrM, void* devPtrO, void* devPtrDescaleQ, - void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, - void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, - void* devPtrcuSeqlensKV, void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, bool is_training, float scaling_factor, + float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, + void* devPtrK, void* devPtrV, void* devPtrSoftmaxOffset, void* devPtrM, void* devPtrO, + void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, + void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, + void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrSeqOffsetsQ, + void* devPtrSeqOffsetsKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, NVTEScalingMode scaling_mode, NVTE_QKV_Format qkv_scale_inv_format, void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { @@ -547,8 +546,7 @@ void fused_attn_fp8_fwd_impl( num_bytes_per_ragged_offset; } cu_seqlens_padded_to_offsets<<>>( - layout_group, actual_b, b, h, hg, d_qk, d_v, - static_cast(devPtrSeqOffsetsQ), + layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); NVTE_CHECK_CUDA(cudaGetLastError()); @@ -583,13 +581,13 @@ void fused_attn_fp8_fwd_impl( // fused attention BWD FP8 with FE 1.0+ void fused_attn_fp8_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t max_b, int64_t max_t_q, int64_t max_t_kv, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, - NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrO, - void* devPtrdO, void* devPtrSoftmaxOffset, void* devPtrdQ, void* devPtrdK, void* devPtrdV, + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, float scaling_factor, + float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, void* devPtrQ, + void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrO, void* devPtrdO, + void* devPtrSoftmaxOffset, void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdSoftmaxOffset, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, @@ -597,12 +595,11 @@ void fused_attn_fp8_bwd_impl( void* devPtrAmaxdV, void* devPtrQ_t, void* devPtrK_t, void* devPtrdO_f16, void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrSeqOffsetsQ, void* devPtrSeqOffsetsKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, - cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, - cudnn_frontend::DataType_t do_tensor_type, cudnn_frontend::DataType_t dqkv_tensor_type, - NVTEScalingMode scaling_mode, NVTE_QKV_Format qkv_scale_inv_format, - NVTE_QKV_Format do_scale_inv_format, void* workspace, size_t* workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, + cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, void* workspace, + size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -1332,8 +1329,7 @@ void fused_attn_fp8_bwd_impl( num_bytes_per_ragged_offset; } cu_seqlens_padded_to_offsets<<>>( - layout_group, actual_b, b, h, hg, d_qk, d_v, - static_cast(devPtrSeqOffsetsQ), + layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); NVTE_CHECK_CUDA(cudaGetLastError()); @@ -1481,15 +1477,14 @@ void fused_attn_fp8_fwd( (qkv_format == NVTE_QKV_Format::NVTE_BHSD) || (qkv_format == NVTE_QKV_Format::NVTE_THD)) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - max_batch_size, max_tokens_q, max_tokens_kv, - is_training, attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, - devPtrV, devPtrSoftmaxOffset, devPtrM, devPtrO, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, - devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, - qkv_scale_inv_format, workspace->data.dptr, &workspace_size, stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, is_training, attn_scale, p_dropout, qkv_layout, + o_format, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, devPtrQ, devPtrK, devPtrV, devPtrSoftmaxOffset, devPtrM, devPtrO, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, + devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, qkv_scale_inv_format, + workspace->data.dptr, &workspace_size, stream, handle); } else { NVTE_ERROR("FP8 fused attention only supports qkv_format=BSHD, SBHD, BHSD, or THD.\n"); } @@ -1625,16 +1620,15 @@ void fused_attn_fp8_bwd( (dqkv_format == NVTE_QKV_Format::NVTE_BHSD) || (dqkv_format == NVTE_QKV_Format::NVTE_THD)) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - max_batch_size, max_tokens_q, max_tokens_kv, - attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, - devPtrQ, devPtrK, devPtrV, devPtrM, devPtrO, devPtrdO, devPtrSoftmaxOffset, devPtrdQ, - devPtrdK, devPtrdV, devPtrdSoftmaxOffset, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, - devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, - devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, - devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, devPtrdO_f16, devPtrdO_t, - devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + max_batch_size, max_tokens_q, max_tokens_kv, attn_scale, p_dropout, qkv_layout, o_format, + do_format, dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrM, + devPtrO, devPtrdO, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdSoftmaxOffset, + devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, + devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, + devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, + devPtrdO_f16, devPtrdO_t, devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, + devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), input_dO->scaling_mode, qkv_scale_inv_format, do_scale_inv_format, workspace->data.dptr,