
[PyTorch] Support for cuDNN-backed flex attention#2984

Open
vcherepanov-nv wants to merge 12 commits into NVIDIA:main from vcherepanov-nv:cudnn-score-mod-3

Conversation

@vcherepanov-nv
Collaborator

Description

This PR introduces an alternative, Python-only code path for the FusedAttention backend for PyTorch.
The user can specify score_mod and score_mod_bprop functions, which get routed to the corresponding parameters of the sdpa and sdpa_backward calls to cuDNN FE.
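A minimal usage sketch of the intended API (based on this PR's description and sequence diagram; the identity callbacks, shapes, and module configuration below are illustrative assumptions, not code from this PR):

import torch
from transformer_engine.pytorch import DotProductAttention

# score_mod callbacks receive the cuDNN FE graph, the score tensor node, and the
# user-provided tensors dict, and return the (possibly modified) score node.
def my_score_mod(graph, score, tensors):
    return score  # identity: leave the attention scores unchanged

def my_score_mod_bprop(graph, d_score, tensors):
    return d_score  # matching identity for the backward pass

dpa = DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attention_dropout=0.0,       # dropout is not supported on the score_mod path
    attn_mask_type="no_mask",
    qkv_format="bshd",
)
q = torch.randn(2, 128, 16, 64, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
out = dpa(
    q, k, v,
    score_mod=my_score_mod,
    score_mod_bprop=my_score_mod_bprop,
    score_mod_tensors={},        # extra tensors bound into the forward graph
    score_mod_bprop_tensors={},  # extra tensors bound into the backward graph
)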

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • A new code path for the FusedAttention backend when score_mod (and related parameters) is specified
  • Tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 13, 2026

Greptile Summary

This PR adds a Python-only cuDNN flex-attention code path to the FusedAttention backend, routing user-supplied score_mod and score_mod_bprop callbacks to cuDNN Frontend's sdpa / sdpa_backward calls. The implementation includes a graph cache keyed on callback identity, tensor metadata, and user-defined structural keys, plus a dedicated FusedAttentionWithScoreModFunc autograd function.

  • New FusedAttentionWithScoreModFunc (backends.py): cuDNN Frontend Python graph build/cache/execute pipeline for forward and backward passes, with save_for_backward version-counter checks that catch in-place mutations of score_mod_tensors before backward.
  • Callback caching (backends.py): distinguishes module-level named functions (stable by qualname), lambdas (keyed by code object), and stateful bound methods (requires explicit score_mod_graph_cache_key() or bypasses cache) to avoid stale graph reuse.
  • Integration into DotProductAttention (dot_product_attention.py, utils.py): comprehensive precondition assertions guard against incompatible feature combinations (FP8, THD format, context parallelism, KV caching, dropout, etc.) and the get_attention_backend filter correctly restricts the path to the F16/BF16 arbitrary-seqlen cuDNN sub-backend.

Confidence Score: 4/5

Safe to merge once the test tensors are moved to CUDA; the production code path has no device-related issues.

The production-side implementation looks correct: device validation is applied before graph construction, and the autograd function properly guards the backward path. However, the tensors in score_mod_tensors for the causal and softcap integration test cases are created on CPU while the cuDNN graph runs on CUDA. When cuDNN attempts to bind these CPU tensors during graph execution, it will raise a device-mismatch error, causing both affected parameterized test cases to fail unconditionally on any CUDA-capable machine.

tests/pytorch/attention/test_attention.py — the causal and softcap branches of test_dot_product_attention_score_mod create score_mod_tensors without device="cuda".

Important Files Changed

  • tests/pytorch/attention/test_attention.py: Adds score_mod cache-key unit tests and an integration test for causal/softcap/post_scale_bias score mods; the causal and softcap test cases create score_mod_tensors on CPU but pass them to a CUDA cuDNN graph, which will fail at execution time.
  • transformer_engine/pytorch/attention/dot_product_attention/backends.py: Adds ~800 lines implementing the cuDNN-backed flex-attention code path: graph cache, FusedAttentionWithScoreModFunc autograd function, and score_mod callback key logic. Logic looks correct; no critical issues found in the production path.
  • transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py: Wires score_mod parameters through DotProductAttention.forward with extensive precondition assertions covering dtype, format, and feature exclusivity; the change is straightforward and safe.
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: Extends the AttentionParams dataclass with score_mod/score_mod_bprop flags and adds backend-selection filtering logic to disable flash/unfused backends when score_mod is active.
  • tests/pytorch/utils.py: Trivial: adds score_mod/score_mod_bprop bool parameters to the get_available_attention_backends helper.

Sequence Diagram

sequenceDiagram
    participant User
    participant DPA as DotProductAttention
    participant FA as FusedAttention
    participant Func as FusedAttentionWithScoreModFunc
    participant Cache as _cudnn_score_mod_graph_cache
    participant cuDNN as cuDNN Frontend

    User->>DPA: "forward(q,k,v, score_mod=..., score_mod_tensors=...)"
    DPA->>DPA: validate preconditions
    DPA->>FA: forward(..., score_mod, score_mod_bprop, ...)
    FA->>Func: apply(is_training, q, k, v, ..., score_mod, ...)
    Func->>Cache: _get_cudnn_score_mod_fwd_graph(key)
    alt cache miss
        Cache->>cuDNN: _build_cudnn_score_mod_fwd_graph()
        cuDNN-->>Cache: _CudnnScoreModFwdGraphEntry
        Cache->>Cache: store entry
    end
    Cache-->>Func: graph entry
    Func->>cuDNN: _execute_cudnn_graph(variant_pack)
    cuDNN-->>Func: output tensor
    Func->>Func: ctx.save_for_backward(q,k,v,out,stats,tensors...)
    Func-->>User: output

    User->>Func: backward(d_out)
    Func->>Cache: _get_cudnn_score_mod_bwd_graph(key)
    alt cache miss
        Cache->>cuDNN: _build_cudnn_score_mod_bwd_graph()
        cuDNN-->>Cache: _CudnnScoreModBwdGraphEntry
    end
    Cache-->>Func: bwd graph entry
    Func->>cuDNN: _execute_cudnn_graph(variant_pack)
    cuDNN-->>Func: dq, dk, dv
    Func-->>User: (None, dq, dk, dv, ...)

Reviews (3): Last reviewed commit: "Fix score_mod lambda cache keys"

Comment on lines +1273 to +1281
def _score_mod_callback_cache_key(callback: Optional[Callable]) -> Optional[Tuple[Any, ...]]:
    """Create a stable cache key for a score_mod callable."""
    if callback is None:
        return None
    self_obj = getattr(callback, "__self__", None)
    func_obj = getattr(callback, "__func__", None)
    if self_obj is not None and func_obj is not None:
        return ("bound_method", id(self_obj), id(func_obj))
    return ("callable", id(callback))
Contributor


P1 id()-based cache key is unsafe for parameterized bound-method score_mods

id(self_obj) identifies a Python object by its memory address. When a bound-method instance is garbage-collected, Python may immediately reuse that memory for a new instance. If the new instance belongs to the same class (same id(func_obj)), the cache key is identical, so _get_cudnn_score_mod_fwd_graph returns the old compiled graph even though the new instance might construct a structurally different computation — e.g., a score_mod class whose forward loops self.n_layers times. The wrong graph is executed without any error, silently producing incorrect attention outputs.

For stateless module-level functions this is fine (they're never GC'd), but any stateful class-based score_mod where different instances produce different graph topologies can hit this bug in long-running programs. Consider using type(self_obj) and a per-class sequence counter, or requiring callers to provide an explicit cache key.
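One possible shape for such a key (a sketch only, not this PR's implementation; the per-class counter and the optional score_mod_graph_cache_key() hook follow the suggestion above):

import itertools
from typing import Any, Callable, Tuple

_instance_counters: dict = {}  # one monotonically increasing counter per class

def _bound_method_cache_key(callback: Callable) -> Tuple[Any, ...]:
    self_obj = callback.__self__
    # Prefer an explicit, user-defined structural key if the instance provides one.
    explicit = getattr(self_obj, "score_mod_graph_cache_key", None)
    if callable(explicit):
        return ("bound_method", type(self_obj).__qualname__, explicit())
    # Otherwise tag the instance with a never-reused sequence number instead of id(),
    # so a recycled memory address can never alias a previous instance's key.
    seq = getattr(self_obj, "_score_mod_cache_seq", None)
    if seq is None:
        seq = next(_instance_counters.setdefault(type(self_obj), itertools.count()))
        self_obj._score_mod_cache_seq = seq
    return ("bound_method", type(self_obj).__qualname__, seq)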

Comment on lines 91 to 92
_flash_attn_varlen_fwd = None
_flash_attn_varlen_bwd = None
Contributor


P2 Unbounded module-level graph cache will grow indefinitely

_cudnn_score_mod_graph_cache is a plain dict with no eviction policy. Cache keys encode tensor shapes, strides, dtype, and device, so every new (batch, seq, heads, dim) combination — extremely common in training with variable-length sequences or multi-task workloads — inserts a permanent entry. Each cached cuDNN graph holds compiled CUDA kernels and associated state, which can be several tens of MB. Over a long training run this will silently consume increasing GPU/CPU memory. Consider a bounded LRU cache (e.g., functools.lru_cache or a collections.OrderedDict with a size cap).
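A minimal sketch of such a cap (the cache values would be the _CudnnScoreMod*GraphEntry objects mentioned above; the size limit is arbitrary):

from collections import OrderedDict

_MAX_CACHED_GRAPHS = 64  # arbitrary cap; tune to the workload

class _BoundedGraphCache(OrderedDict):
    """LRU dict: evicts the least-recently-used graph entry once the cap is exceeded."""

    def __getitem__(self, key):
        value = super().__getitem__(key)
        self.move_to_end(key)  # mark as most recently used
        return value

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self.move_to_end(key)
        if len(self) > _MAX_CACHED_GRAPHS:
            self.popitem(last=False)  # drop the oldest entry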

Comment on lines +1556 to +1563
fused_attention_backend = tex.get_fused_attn_backend(
    self.training,
    q_type,
    q_type,
    dpa_utils.QKVLayout["bshd_bshd_bshd"],
    dpa_utils.AttnBiasType["no_bias"],
    dpa_utils.AttnMaskType["no_mask"],
    dpa_utils.SoftmaxType["vanilla"],
Contributor


P2 get_fused_attn_backend availability check always uses bshd_bshd_bshd regardless of actual format

The score_mod path hard-codes dpa_utils.QKVLayout["bshd_bshd_bshd"] for the backend probe, even when the user passes qkv_format="sbhd". The result is only used to gate on NVTE_No_Backend, so in practice it likely works today because backend availability for a given dtype is layout-independent. However, if a future cuDNN version makes SBHD/BSHD support diverge, this probe would give a false-positive (accepts sbhd even though no backend supports it) or false-negative (rejects sbhd when it is actually supported). Using the real layout for the probe would make the check self-documenting and future-proof.
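For illustration, the probe could derive the layout from the actual format (a fragment of the discussed code, assuming qkv_format is "bshd" or "sbhd" at this point, as elsewhere on this path):

# Probe with the layout that will actually be executed instead of hard-coding bshd.
probe_layout = dpa_utils.QKVLayout[f"{qkv_format}_{qkv_format}_{qkv_format}"]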

)

if context_parallel:
    if score_mod is not None:
Collaborator


I think this should be in the else branch, because it doesn't support context parallelism; something like the sketch below.
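A sketch of the suggested structure (branch bodies are placeholders):

if context_parallel:
    ...  # existing context-parallel dispatch; score_mod is unsupported here
elif score_mod is not None:
    ...  # cuDNN score_mod code path
else:
    ...  # existing non-CP, non-score_mod path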

score_mod: Optional[Callable] = None,
score_mod_bprop: Optional[Callable] = None,
score_mod_tensors: Optional[Dict[str, torch.Tensor]] = None,
score_mod_bprop_tensors: Optional[Dict[str, torch.Tensor]] = None,
Collaborator


Do you think it'd be clearer if we add "fprop" to the names? i.e. score_mod_fprop, score_mod_bprop, score_mod_fprop_tensors, score_mod_bprop_tensors?

isinstance(k, str) and isinstance(v, torch.Tensor)
for k, v in score_mod_bprop_tensors.items()
), "score_mod_bprop_tensors must map string names to torch.Tensor instances!"

Collaborator


I think all these checks can go into dpa_utils.get_attention_backend(): with the score_mod_xxx args passed in (via AttentionParams), that utility function can return use_fused_attention=False if one of the checks is violated. dpa_utils.get_attention_backend() is used in the tests as well (by get_available_attention_backends()).
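A rough sketch of what that filter could look like inside get_attention_backend (a fragment of the discussed code; the AttentionParams field names beyond score_mod/score_mod_bprop are assumptions):

# Inside dpa_utils.get_attention_backend(attention_params): disable incompatible
# backends instead of asserting in DotProductAttention.forward.
if attention_params.score_mod is not None:
    use_flash_attention = False
    use_unfused_attention = False
    if attention_params.fp8 or attention_params.qkv_format == "thd":
        use_fused_attention = False  # the score_mod path supports neither FP8 nor THD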

raise ValueError(
    "score_mod requires a cuDNN FusedAttention backend, but no fused "
    "attention backend supports the provided inputs."
)
Collaborator


For the score_mod path, I don't think we need to call tex.get_fused_attn_backend() and check if it's supported or not. If anything, we should add graph.validate() -> ... graph.build_plans() to dpa_utils.get_attention_backend(attention_params), but if that's too heavy-handed, we can just do the checks you had above (the asserts). Once those checks are added to dpa_utils.get_attention_backend, whether the FusedAttention backend is run or not will be controlled by the following logic (just like in the non-score_mod cases):

(
                        use_flash_attention,
                        flash_attention_backend,
                        use_fused_attention,
                        fused_attention_backend,
                        use_unfused_attention,
                        _,
                    ) = dpa_utils.get_attention_backend(attention_params)

else:
    pad_between_seqs = False

if score_mod is None:
Collaborator


Please label this "experimental".


def _build_cudnn_pygraph(dtype: torch.dtype, device: torch.device):
    """Create a cuDNN frontend Python graph for F16/BF16 SDPA."""
    import cudnn  # pylint: disable=import-outside-toplevel
Collaborator


Can you import cudnn from 3rdparty/cudnn-frontend instead of from the environment/system-wide installation? We have control over the version in 3rdparty/cudnn-frontend, but not over the system one.
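One hypothetical way to do that (a sketch only; the relative path to the submodule's Python bindings is an assumption and would need to match the repository layout and packaging):

# Prefer the vendored cudnn-frontend bindings if present; otherwise fall back to
# whatever "cudnn" package is installed in the environment.
import importlib
import os
import sys

_VENDORED = os.path.join(
    os.path.dirname(__file__), "..", "..", "..", "3rdparty", "cudnn-frontend", "python"
)
if os.path.isdir(_VENDORED) and _VENDORED not in sys.path:
    sys.path.insert(0, _VENDORED)
cudnn = importlib.import_module("cudnn")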

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("scalar_loss", [False, True])
def test_dot_product_attention_score_mod(dtype, qkv_format, scalar_loss):
Collaborator


Would @pytest.mark.parametrize("score_mod", ["causal", "softcap", "post_scale_bias"]) simplify the tests a bit, so that we don't have 3 separate tests with a lot of repeated code?
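Roughly like this sketch (the _SCORE_MOD_CASES lookup table is a hypothetical helper that would hold the per-case score_mod callbacks and tensors):

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("scalar_loss", [False, True])
@pytest.mark.parametrize("score_mod_case", ["causal", "softcap", "post_scale_bias"])
def test_dot_product_attention_score_mod(dtype, qkv_format, scalar_loss, score_mod_case):
    score_mod_kwargs = _SCORE_MOD_CASES[score_mod_case]  # hypothetical case table
    ...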

score_mod: Callable,
score_mod_tensors: Optional[Dict[str, torch.Tensor]],
output_layer: torch.Tensor,
stats_bhs1: Optional[torch.Tensor],
Collaborator


I think we can just call this stats, even though it might only support bhs1 shape right now. On the C++ side, cuDNN does support th1 (for THD format) as well. Could we leave the name generic for now in case we want to add more support to it in the future?

return output.contiguous()


def _bhsd_dim_stride(
Collaborator


We have a lot of small utility functions here - is there a way to pack them up a bit or group them in some way, so the code is easier to read? I know this is Python and we probably do need more than 2 functions (fwd+bwd) but could you please have a look into this? Thanks.

)


def _score_mod_relative_position(score_mod_graph, score_tensor, _tensors):
Collaborator


We can just call this "post_scale_bias" to be consistent with our nomenclature elsewhere.

vcherepanov-nv and others added 3 commits May 15, 2026 00:48
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment on lines +1368 to +1373
if (
    inspect.isfunction(callback)
    and callback.__closure__ is None
    and "<locals>" not in callback.__qualname__
):
    return ("function", callback.__module__, callback.__qualname__)
Contributor


P1 Module-level lambdas all share the same __qualname__ = "<lambda>", so two different lambdas defined at module scope in the same file (e.g., sm1 = lambda g, s, t: s and sm2 = lambda g, s, t: g.neg(input=s)) would produce the identical cache key ("function", module, "<lambda>"). The second lambda would silently reuse the compiled graph from the first, computing wrong attention scores with no error. Named module-level functions are safe because their qualnames are unique, but lambdas are not. Excluding <lambda> from the cacheable path makes them _SCORE_MOD_UNCACHEABLE, which builds a fresh graph every call — the same safe fallback already used for closures and nested functions.

Suggested change
# before (current)
if (
    inspect.isfunction(callback)
    and callback.__closure__ is None
    and "<locals>" not in callback.__qualname__
):
    return ("function", callback.__module__, callback.__qualname__)

# after (suggested)
if (
    inspect.isfunction(callback)
    and callback.__closure__ is None
    and "<locals>" not in callback.__qualname__
    and "<lambda>" not in callback.__qualname__
):
    return ("function", callback.__module__, callback.__qualname__)

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment on lines +1968 to +1973
score_mod_kwargs = {
    "score_mod": _score_mod_causal,
    "score_mod_bprop": _score_mod_causal_bprop,
    "score_mod_tensors": {"neg_inf": torch.full((1, 1, 1, 1), -1e9)},
    "score_mod_bprop_tensors": {"zero": torch.full((1, 1, 1, 1), 0.0)},
}
Contributor


P1 The neg_inf and zero tensors are created on CPU (torch.full defaults to CPU), but the attention computation runs on CUDA. When cuDNN executes the graph it calls into CUDA kernels and expects all variant-pack tensors to reside on the compute device. Passing CPU tensors here will produce a device-mismatch error at graph execution time, causing both the "causal" test cases to fail.

Suggested change
# before (current)
score_mod_kwargs = {
    "score_mod": _score_mod_causal,
    "score_mod_bprop": _score_mod_causal_bprop,
    "score_mod_tensors": {"neg_inf": torch.full((1, 1, 1, 1), -1e9)},
    "score_mod_bprop_tensors": {"zero": torch.full((1, 1, 1, 1), 0.0)},
}

# after (suggested)
score_mod_kwargs = {
    "score_mod": _score_mod_causal,
    "score_mod_bprop": _score_mod_causal_bprop,
    "score_mod_tensors": {"neg_inf": torch.full((1, 1, 1, 1), -1e9, device="cuda")},
    "score_mod_bprop_tensors": {"zero": torch.full((1, 1, 1, 1), 0.0, device="cuda")},
}

