[JAX] Support for cuDNN-backed flex attention #2985

Open
vcherepanov-nv wants to merge 6 commits into NVIDIA:main from vcherepanov-nv:cudnn-score-mod-jax

@vcherepanov-nv
Collaborator

Description

This PR introduces an alternative code path for the FusedAttention backend for JAX.
The user can specify score_mod and score_mod_bprop functions, which get routed to the corresponding parameters of the sdpa and sdpa_backward calls to cuDNN FE.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • A new code path for the FusedAttention backend, taken when score_mod (and its related parameters) is specified
  • Tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 13, 2026

Greptile Summary

This PR introduces a new code path for the JAX FusedAttention backend using the cuDNN frontend Python API, allowing score_mod and score_mod_bprop callbacks to be routed into cuDNN FE's sdpa and sdpa_backward graph-build calls. The change adds Python-level graph building, C++ FFI handlers, a JAX custom_vjp layer, and tests.

  • New fused_attn code path: When score_mod is provided, validation rejects incompatible features, builds a cuDNN Python graph at JAX trace-time, registers it in a C++ registry keyed by an integer ID, and dispatches via new te_fused_attn_score_mod_forward_ffi / te_fused_attn_score_mod_backward_ffi handlers.
  • Cache-key design: Bound methods without an explicit score_mod_graph_cache_key() are left uncacheable (each trace builds a new graph); module-level functions and objects with stable keys are cached by (direction, config, input-avals).
  • Tests: Cover validation, cache-key stability, forward+backward correctness for causal, post-scale-bias, and softcap score mods, plus a distributed DP+TP sharding test.

Confidence Score: 4/5

Safe to merge for its experimental use case; the score_mod path is additive and does not touch the existing fused attention flow.

The core forward/backward logic, custom_vjp wiring, and C++ FFI handlers are well-structured and tests cover correctness for single-device and DP+TP distributed cases. Open issues from earlier review rounds (ref-counting leak, GIL during FFI dispatch, unbounded registry growth, callback-ordering assumption) are the primary risks. The missing SPMD sharding-rule registration and non-atomic sys.path mutation are quality gaps, not correctness regressions, but the sharding gap means distributed performance is untested at scale.

transformer_engine/jax/cpp_extensions/attention.py: fused_attn_score_mod_fwd/bwd bypass the BasePrimitive/SdyShardingRule infrastructure used by all existing attention primitives, and the sys.path modification is not thread-safe.

Important Files Changed

Filename: Overview
transformer_engine/jax/cpp_extensions/attention.py: Core of the new score_mod path: Python-level cuDNN graph building, cache logic, and FFI call dispatch. Missing SPMD partitioning rules; sys.path mutation not thread-safe.
transformer_engine/jax/csrc/extensions/attention.cpp: C++ FFI handlers and ScoreModGraphRegistry. Pre-existing issues flagged in prior threads (Py_INCREF leak, GIL acquisition during FFI call) remain open.
transformer_engine/jax/attention.py: Adds score_mod entry point via custom_vjp, validation, and early-return path. Logic, nondiff_argnums usage, and gradient structure are consistent.
tests/jax/test_fused_attn.py: Tests cover forward+backward correctness for causal, post-scale-bias, and softcap score mods.
tests/jax/test_distributed_fused_attn.py: Validates sharding with DP+TP meshes; asserts output and gradient shardings match the input sharding.

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment on lines +706 to +713
struct ScoreModGraphEntry {
  PyObject *py_graph = nullptr;
  std::vector<int64_t> user_uids;
  std::vector<int64_t> input_uids;
  std::vector<int64_t> output_uids;
  std::vector<int64_t> scalar_uids;
  std::vector<ScoreModScalarStorage> scalar_values;
};

Contributor

P1 Python reference leak: Py_INCREF without a matching Py_DECREF

ScoreModGraphEntry stores a raw PyObject* and its refcount is bumped at registration (Py_INCREF(entry->py_graph) at line 833), but the struct has no destructor to call Py_DECREF. Because ScoreModGraphRegistry never removes entries either, every cuDNN Python graph object registered here is permanently immortalised — it will never be collected by Python's GC regardless of what the call site does. Over many different attention shapes or graph configurations this accumulates silently. The fix is to add a destructor that acquires the GIL and calls Py_DECREF, or to store a pybind11::object (which manages the refcount automatically) and ensure destruction always happens under the GIL.

Collaborator

@vcherepanov-nv This seems like a valid comment from greptile about leaking pygraphs. But I'm also not sure if that is the intended design to prevent GC freeing up a graph too early by mistake before we use it in the XLA C++ FFI. I'm not sure what the best option is here

Collaborator Author

Not sure if I'd call it a leak, but yes, currently the cache is process-lifetime. If we ever encounter an issue with its growth, then we'll need to implement some kind of eviction policy. But it is out of scope of this PR.

Comment on lines +684 to +692
intermediate_data_type=cudnn.data_type.FLOAT,
compute_data_type=cudnn.data_type.FLOAT,
)

q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape)
k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape)
v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape)
o_dim, o_stride = _bshd_as_bhsd_dim_stride(output_aval.shape)
do_dim, do_stride = _bshd_as_bhsd_dim_stride(doutput_aval.shape)

Contributor

P2 id()-based cache keys can produce false cache hits after GC

_score_mod_callback_cache_key builds its key from id(self_obj) and id(func). Python recycles object addresses after GC, so if a callback instance is collected and a new object (of a different class or with different graph logic) is allocated at the same address, the new config will compare equal to the old one under __eq__. JAX's nondiff-argnum caching then reuses the traced function and graph built for the original callback, silently executing the wrong cuDNN graph. The risk is low for long-lived module-level functions but real for short-lived class instances. Anchoring the key to a non-id stable identifier (e.g., a weakref plus explicit id, or requiring callers to supply an explicit stable key) would eliminate the ambiguity.
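One way to read the suggestion above is a key function that never consults id(). The sketch below is a minimal illustration, not the PR's _score_mod_callback_cache_key: the helper name stable_callback_key is hypothetical, and the only name taken from the PR is the optional score_mod_graph_cache_key() method mentioned in the review summary. Callbacks that cannot provide a stable identity return None, which a caller would treat as "do not cache".

```python
import types
from typing import Callable, Hashable, Optional


def stable_callback_key(callback: Callable) -> Optional[Hashable]:
    """Return a hashable key that never depends on id(), or None if uncacheable.

    Hypothetical helper for illustration only.
    """
    # 1. An explicit key supplied on the callable itself always wins.
    explicit = getattr(callback, "score_mod_graph_cache_key", None)
    if callable(explicit):
        return ("explicit", explicit())

    # 2. Bound methods defer to the owning object's explicit key, if any.
    if isinstance(callback, types.MethodType):
        owner_key = getattr(callback.__self__, "score_mod_graph_cache_key", None)
        if callable(owner_key):
            return ("method", callback.__func__.__qualname__, owner_key())
        return None  # no stable identity: build a fresh graph each trace

    # 3. Module-level functions are identified by their import path.
    if (
        isinstance(callback, types.FunctionType)
        and callback.__name__ != "<lambda>"
        and "<locals>" not in callback.__qualname__
    ):
        return ("function", callback.__module__, callback.__qualname__)

    # Lambdas, closures, and anything else are treated as uncacheable.
    return None
```

Because the key is built from values the caller controls, two short-lived instances with the same configuration share a cache entry, while an unrelated object that happens to reuse a freed address cannot collide with it.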

Comment on lines +765 to +807
Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id,
                                const std::vector<void *> &input_ptrs,
                                const std::vector<void *> &output_ptrs, void *workspace) {
  auto entry = GetScoreModGraphEntry(graph_id);
  NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), "cuDNN score_mod graph expected ",
             entry->input_uids.size(), " inputs but got ", input_ptrs.size());
  NVTE_CHECK(output_ptrs.size() >= entry->output_uids.size(),
             "cuDNN score_mod graph expected at least ", entry->output_uids.size(),
             " outputs but got ", output_ptrs.size());

  std::unordered_map<int64_t, void *> variant_pack;
  for (size_t i = 0; i < entry->input_uids.size(); ++i) {
    variant_pack.emplace(entry->input_uids[i], input_ptrs[i]);
  }
  for (size_t i = 0; i < entry->output_uids.size(); ++i) {
    variant_pack.emplace(entry->output_uids[i], output_ptrs[i]);
  }
  for (size_t i = 0; i < entry->scalar_uids.size(); ++i) {
    variant_pack.emplace(entry->scalar_uids[i], entry->scalar_values[i].data.data());
  }

  std::vector<std::intptr_t> user_ptrs;
  user_ptrs.reserve(entry->user_uids.size());
  for (const auto uid : entry->user_uids) {
    auto it = variant_pack.find(uid);
    NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", uid);
    user_ptrs.push_back(reinterpret_cast<std::intptr_t>(it->second));
  }

  auto handle = GetScoreModCudnnHandle();
  NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream));
  {
    pybind11::gil_scoped_acquire gil;
    try {
      auto graph = pybind11::reinterpret_borrow<pybind11::object>(entry->py_graph);
      graph.attr("_execute_with_ptrs")(user_ptrs, reinterpret_cast<std::intptr_t>(workspace),
                                       reinterpret_cast<std::intptr_t>(handle));
    } catch (const pybind11::error_already_set &exc) {
      NVTE_ERROR("cuDNN score_mod SDPA graph execution failed: ", exc.what());
    }
  }
  return ffi_with_cuda_error_check();
}

Contributor

P2 GIL held across a CUDA FFI call boundary

ExecuteScoreModGraph acquires pybind11::gil_scoped_acquire while the CUDA stream is live and calls a Python method (_execute_with_ptrs) synchronously. Any other Python thread that holds the GIL and is waiting on CUDA work will deadlock. More broadly, acquiring the GIL inside an XLA/JAX FFI handler — which JAX may dispatch from a non-Python thread — creates a locking inversion risk. This is by-design if cuDNN's Python frontend has no C-level execution path, but the limitation should be documented and the possibility of multi-threaded JAX dispatch should be explicitly considered.

_SCORE_MOD_UID_DQ = 7
_SCORE_MOD_UID_DK = 8
_SCORE_MOD_UID_DV = 9
_SCORE_MOD_FWD_TENSOR_UID_BASE = 1000

Contributor

P2 _score_mod_graph_cache and C++ registry grow without bound

_score_mod_graph_cache is a module-level dict that accumulates (graph_id, workspace_size) entries for every unique (direction, config, aval-tuple) seen during tracing, and the C++ ScoreModGraphRegistry holds the corresponding cuDNN graph objects forever. Each entry keeps a Python cuDNN graph alive (and, due to the missing Py_DECREF noted separately, prevents GC). In long-running services or evaluation loops that sweep over many shapes/dtypes, this leads to unbounded cuDNN graph memory accumulation. An LRU eviction strategy or an explicit graph-release API paired with cache invalidation would contain the growth.
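The LRU idea can be sketched on the Python side as follows. This is a minimal illustration under stated assumptions, not code from the PR: GraphLruCache and the release callback are hypothetical names, and the callback stands in for a graph-release entry point that the C++ registry does not currently expose. Eviction is only safe if no compiled XLA executable still embeds the evicted graph_id, which the sketch does not attempt to track.

```python
import threading
from collections import OrderedDict
from typing import Callable, Hashable, Optional, Tuple


class GraphLruCache:
    """Bounded map from cache key to (graph_id, workspace_size).

    Hypothetical sketch: `release` is assumed to free the registry entry
    for a graph_id; no such API exists in the PR as reviewed.
    """

    def __init__(self, capacity: int, release: Callable[[int], None]):
        if capacity <= 0:
            raise ValueError("capacity must be positive")
        self._capacity = capacity
        self._release = release
        self._lock = threading.Lock()
        self._entries: "OrderedDict[Hashable, Tuple[int, int]]" = OrderedDict()

    def get(self, key: Hashable) -> Optional[Tuple[int, int]]:
        with self._lock:
            entry = self._entries.get(key)
            if entry is not None:
                # Mark as most recently used.
                self._entries.move_to_end(key)
            return entry

    def put(self, key: Hashable, graph_id: int, workspace_size: int) -> None:
        with self._lock:
            previous = self._entries.pop(key, None)
            if previous is not None and previous[0] != graph_id:
                # The key now maps to a different graph; drop the old one.
                self._release(previous[0])
            self._entries[key] = (graph_id, workspace_size)
            # Evict least recently used entries until within capacity.
            while len(self._entries) > self._capacity:
                _, (evicted_id, _) = self._entries.popitem(last=False)
                self._release(evicted_id)

    def __len__(self) -> int:
        with self._lock:
            return len(self._entries)
```

Keeping eviction and release in one place means the Python cache and the native registry cannot drift apart, which is the pairing of "graph-release API" and "cache invalidation" the comment asks for.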

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment on lines +143 to +195

    def forward(self, graph, score, tensors):
        import cudnn  # pylint: disable=import-outside-toplevel

        self.before_tanh_activation = graph.div(
            a=score,
            b=tensors["softcap"],
            compute_data_type=cudnn.data_type.FLOAT,
        )
        self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT)
        tanh_out = graph.tanh(input=self.before_tanh_activation)
        tanh_out.set_data_type(cudnn.data_type.FLOAT)
        return graph.mul(
            a=tanh_out,
            b=tensors["softcap"],
            compute_data_type=cudnn.data_type.FLOAT,
        )

    def backward(self, graph, dscore, tensors):
        import cudnn  # pylint: disable=import-outside-toplevel

        d_tanh_out = graph.mul(
            a=dscore,
            b=tensors["softcap"],
            compute_data_type=cudnn.data_type.FLOAT,
        )
        d_tanh_out.set_data_type(cudnn.data_type.FLOAT)
        d_before_tanh_activation = graph.tanh_backward(
            loss=d_tanh_out,
            input=self.before_tanh_activation,
            compute_data_type=cudnn.data_type.FLOAT,
        )
        d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT)
        return graph.div(
            a=d_before_tanh_activation,
            b=tensors["softcap"],
            compute_data_type=cudnn.data_type.FLOAT,
        )


def _reference_attention(
    query, key, value, scale, *, causal=False, relative_position=False, softcap=None
):
    scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale
    if causal:
        q_pos = jnp.arange(query.shape[1])[:, None]
        kv_pos = jnp.arange(key.shape[1])[None, :]
        scores = jnp.where(q_pos >= kv_pos, scores, -1e9)
    if relative_position:
        q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None]
        kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :]
        scores = scores + q_pos - kv_pos
    if softcap is not None:

Contributor

P1 _ScoreModSoftcap.backward relies on undocumented cuDNN callback ordering

backward reads self.before_tanh_activation, which is written by forward during sdpa_backward graph construction. This is only safe if cuDNN's sdpa_backward guarantees it calls score_mod (the forward callback) before score_mod_bprop (the backward callback) within the same graph-build invocation. If that order is ever reversed, self.before_tanh_activation is None at the time backward runs, and graph.tanh_backward(input=None, ...) will fail silently or crash at execution time rather than at graph-build time.

vcherepanov-nv and others added 2 commits May 15, 2026 03:35
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
_SCORE_MOD_UID_K = 2
_SCORE_MOD_UID_V = 3
_SCORE_MOD_UID_O = 4
_SCORE_MOD_UID_STATS = 5

Collaborator

Where do these _SCORE_MOD_UID_XXXXX come from? Is it a C/C++ enum? If so, we should make this a Python Enum that derives its values from the C/C++ enum exposed via pybind

See this enum for reference:

NO_SCALING = JAXX_Scaling_Mode.NO_SCALING

Collaborator Author

These are just arbitrary numbers, really. In fact, assigning UIDs is completely optional, cuDNN can auto-assign. UIDs are added here just for determinism / to make future troubleshooting easier, e.g. so that we know that 4 is the output tensor.
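Even for arbitrary values, a plain Python IntEnum keeps the troubleshooting benefit without tying the numbers to a C/C++ enum. The sketch below is hypothetical and lists only the UID values visible in the excerpts of this review; any others defined in the PR are deliberately left out rather than guessed.

```python
from enum import IntEnum, unique


@unique  # raises at import time if two names share a value
class ScoreModUid(IntEnum):
    """Fixed tensor UIDs; only values shown in the review excerpts are listed."""

    K = 2
    V = 3
    O = 4
    STATS = 5
    DQ = 7
    DK = 8
    DV = 9


def describe_uid(uid: int) -> str:
    """Map a raw UID from a log or error message back to a tensor name."""
    try:
        return ScoreModUid(uid).name
    except ValueError:
        return f"unknown UID {uid}"
```

For example, describe_uid(4) returns "O", which matches the author's point that a fixed UID makes it easy to tell that 4 is the output tensor.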


def _import_cudnn_for_score_mod():
    cudnn_frontend_path = str(_CUDNN_FRONTEND_PYTHON_PATH)
    cudnn_frontend_package = _CUDNN_FRONTEND_PYTHON_PATH / "cudnn"

Collaborator

Does this work for all installation types? Editable install, system installation?

Collaborator Author

Source/editable installs are covered by the repo-relative 3rdparty/cudnn-frontend/python path. Normal/system installs work if the cudnn Python package is installed on sys.path; otherwise the code raises an explicit import error.
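The two-step lookup described in this reply can be sketched as below. It is an illustration under assumptions, not the body of _import_cudnn_for_score_mod: the helper name, the lock, and the order (installed package first, vendored directory second) are all choices made for the sketch. The lock only serializes callers of this helper; it does not protect against other code that mutates sys.path concurrently.

```python
import importlib
import sys
import threading
from pathlib import Path
from types import ModuleType
from typing import Optional

_IMPORT_LOCK = threading.Lock()  # hypothetical; serializes the sys.path edit


def import_with_vendored_fallback(package: str, vendored_root: Optional[Path]) -> ModuleType:
    """Import `package`, falling back to a vendored source directory.

    Hypothetical helper: tries the normally installed package first, then
    `vendored_root / package`, and raises an explicit ImportError otherwise.
    """
    with _IMPORT_LOCK:
        try:
            return importlib.import_module(package)
        except ImportError as first_error:
            if vendored_root is None or not (vendored_root / package).is_dir():
                raise ImportError(
                    f"'{package}' is not installed and no vendored copy was found "
                    f"under {vendored_root}"
                ) from first_error
            root = str(vendored_root)
            if root not in sys.path:
                # Append rather than prepend so installed packages keep priority.
                sys.path.append(root)
            importlib.invalidate_caches()
            return importlib.import_module(package)
```

Appending keeps the vendored directory from shadowing other installed packages, at the cost of never preferring the vendored copy when both exist.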

}

std::atomic<int64_t> &NextScoreModGraphId() {
  static std::atomic<int64_t> next_id{1};

Collaborator

Why do we need these static values? Can we simplify this to track things on the Python side instead of C++ or as explicit FFI args?

static std::unordered_map<int64_t, std::shared_ptr<ScoreModGraphEntry>> registry;
...
static std::mutex mutex;
...
static std::atomic<int64_t> next_id{1};

Collaborator Author

We cannot carry cuDNN graph directly via FFI, right? The Python trace-time code registers the graph and passes a graph_id; the C++ FFI handler later uses that ID to find the graph and UID metadata when XLA executes. Moving the registry to Python would still require a global lookup and more GIL / module-lifetime coupling.
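The handoff described here (register at trace time, look up by integer ID at execution time) can be modeled in a few lines. The Python sketch below only illustrates the pattern; it is not the C++ ScoreModGraphRegistry, and the names GraphEntry and GraphRegistry are hypothetical. IDs start at 1, matching the next_id{1} initializer quoted above.

```python
import itertools
import threading
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class GraphEntry:
    """What the execution side needs: the graph plus its UID metadata."""

    graph: Any
    input_uids: List[int] = field(default_factory=list)
    output_uids: List[int] = field(default_factory=list)


class GraphRegistry:
    """Process-wide map from integer graph_id to GraphEntry."""

    def __init__(self):
        self._lock = threading.Lock()
        self._next_id = itertools.count(1)
        self._entries: Dict[int, GraphEntry] = {}

    def register(self, entry: GraphEntry) -> int:
        # Called at trace time; the returned ID is what travels through FFI.
        with self._lock:
            graph_id = next(self._next_id)
            self._entries[graph_id] = entry
        return graph_id

    def lookup(self, graph_id: int) -> GraphEntry:
        # Called at execution time with the ID taken from the FFI attributes.
        with self._lock:
            try:
                return self._entries[graph_id]
            except KeyError:
                raise KeyError(f"unknown score_mod graph_id {graph_id}") from None
```

Only the integer crosses the FFI boundary, which is why some process-wide lookup is needed regardless of whether it lives in Python or C++.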

):
    """Validate arguments for the cuDNN frontend score_mod path."""
    header = "score_mod fused_attn"
    if qkv_layout is not QKVLayout.BSHD_BSHD_BSHD:

Collaborator

Thanks for the detailed error checks!
