-
Notifications
You must be signed in to change notification settings - Fork 724
Plumb FP8+THD #2994
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Plumb FP8+THD #2994
Changes from all commits
1b9b13e
977bfb6
dc4a172
10aa194
148fe40
ad50a2e
bfbafe9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -281,12 +281,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( | |||||||||||||
| attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) && | ||||||||||||||
| // pre-9.21: {bshd, sbhd}, {vanilla} | ||||||||||||||
| // 9.21+: {bshd, sbhd, bhsd}, {vanilla, off-by-one, learnable} | ||||||||||||||
| // 9.23+ sm100+: {thd} (ragged/packed variable-length) | ||||||||||||||
| ((cudnn_runtime_version < 92100 && | ||||||||||||||
| (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && | ||||||||||||||
| softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) || | ||||||||||||||
| (cudnn_runtime_version >= 92100 && | ||||||||||||||
| (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || | ||||||||||||||
| qkv_format == NVTE_QKV_Format::NVTE_BHSD))) && | ||||||||||||||
| qkv_format == NVTE_QKV_Format::NVTE_BHSD)) || | ||||||||||||||
| (cudnn_runtime_version >= 92300 && sm_arch_ >= 100 && | ||||||||||||||
| qkv_format == NVTE_QKV_Format::NVTE_THD && supported_ragged_offset_size)) && | ||||||||||||||
|
Comment on lines
+291
to
+292
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
| !requires_64bit_ragged_offset && | ||||||||||||||
| // 9.10.0: known bugs with SDPA FP8 | ||||||||||||||
| (cudnn_runtime_version != 91000) && !return_max_logit) { | ||||||||||||||
|
|
@@ -623,12 +626,13 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso | |||||||||||||
| input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, | ||||||||||||||
| input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); | ||||||||||||||
| } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { | ||||||||||||||
| fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, is_training, | ||||||||||||||
| fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training, | ||||||||||||||
| attn_scale, dropout, qkv_layout, o_format, qkv_scale_inv_format, bias_type, | ||||||||||||||
| attn_mask_type, softmax_type, window_size_left, window_size_right, | ||||||||||||||
| bottom_right_diagonal, input_Q, input_K, input_V, input_SoftmaxOffset, | ||||||||||||||
| input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, | ||||||||||||||
| input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); | ||||||||||||||
| input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, | ||||||||||||||
| input_rng_state, wkspace, stream, handle); | ||||||||||||||
| } else { | ||||||||||||||
| NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); | ||||||||||||||
| } | ||||||||||||||
|
|
@@ -725,14 +729,15 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso | |||||||||||||
| if (input_dO->scaling_mode == NVTE_MXFP8_1D_SCALING) { | ||||||||||||||
| input_dO_f16 = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); | ||||||||||||||
| } | ||||||||||||||
| fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, attn_scale, dropout, | ||||||||||||||
| qkv_layout, o_format, do_format, dqkv_layout, qkv_scale_inv_format, | ||||||||||||||
| fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, | ||||||||||||||
| dropout, qkv_layout, o_format, do_format, dqkv_layout, qkv_scale_inv_format, | ||||||||||||||
| do_scale_inv_format, bias_type, attn_mask_type, softmax_type, | ||||||||||||||
| window_size_left, window_size_right, bottom_right_diagonal, deterministic, | ||||||||||||||
| input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_S, | ||||||||||||||
| input_SoftmaxOffset, input_output_dP, output_dQ, output_dK, output_dV, | ||||||||||||||
| output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, | ||||||||||||||
| input_rng_state, wkspace, stream, handle); | ||||||||||||||
| input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, | ||||||||||||||
| wkspace, stream, handle); | ||||||||||||||
| } else { | ||||||||||||||
| NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); | ||||||||||||||
| } | ||||||||||||||
|
|
||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Two
printstatements were left in_run_dpa_fp8_vs_f16, which will produce noise on every test run and in CI logs. These look like leftover debug instrumentation from development.