[JAX] Support for cuDNN-backed flex attention#2985
Conversation
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR introduces a new code path for the JAX FusedAttention backend using the cuDNN frontend Python API, allowing
Confidence Score: 4/5Safe to merge for its experimental use case; the score_mod path is additive and does not touch the existing fused attention flow. The core forward/backward logic, custom_vjp wiring, and C++ FFI handlers are well-structured and tests cover correctness for single-device and DP+TP distributed cases. Open issues from earlier review rounds (ref-counting leak, GIL during FFI dispatch, unbounded registry growth, callback-ordering assumption) are the primary risks. The missing SPMD sharding-rule registration and non-atomic sys.path mutation are quality gaps, not correctness regressions, but the sharding gap means distributed performance is untested at scale. transformer_engine/jax/cpp_extensions/attention.py: fused_attn_score_mod_fwd/bwd bypass the BasePrimitive/SdyShardingRule infrastructure used by all existing attention primitives, and the sys.path modification is not thread-safe. Important Files Changed
Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| struct ScoreModGraphEntry { | ||
| PyObject *py_graph = nullptr; | ||
| std::vector<int64_t> user_uids; | ||
| std::vector<int64_t> input_uids; | ||
| std::vector<int64_t> output_uids; | ||
| std::vector<int64_t> scalar_uids; | ||
| std::vector<ScoreModScalarStorage> scalar_values; | ||
| }; |
There was a problem hiding this comment.
Python reference leak:
Py_INCREF without a matching Py_DECREF
ScoreModGraphEntry stores a raw PyObject* and its refcount is bumped at registration (Py_INCREF(entry->py_graph) at line 833), but the struct has no destructor to call Py_DECREF. Because ScoreModGraphRegistry never removes entries either, every cuDNN Python graph object registered here is permanently immortalised — it will never be collected by Python's GC regardless of what the call site does. Over many different attention shapes or graph configurations this accumulates silently. The fix is to add a destructor that acquires the GIL and calls Py_DECREF, or to store a pybind11::object (which manages the refcount automatically) and ensure destruction always happens under the GIL.
There was a problem hiding this comment.
@vcherepanov-nv This seems like a valid comment from greptile about leaking pygraphs. But I'm also not sure if that is the intended design to prevent GC freeing up a graph too early by mistake before we use it in the XLA C++ FFI. I'm not sure what the best option is here
There was a problem hiding this comment.
Not sure if I'd call it a leak, but yes, currently the cache is process-lifetime. If we ever encounter an issue with it's growth, then we'll need to implement some kind of eviction policy. But it is out of scope of this PR.
| intermediate_data_type=cudnn.data_type.FLOAT, | ||
| compute_data_type=cudnn.data_type.FLOAT, | ||
| ) | ||
|
|
||
| q_dim, q_stride = _bshd_as_bhsd_dim_stride(q_aval.shape) | ||
| k_dim, k_stride = _bshd_as_bhsd_dim_stride(k_aval.shape) | ||
| v_dim, v_stride = _bshd_as_bhsd_dim_stride(v_aval.shape) | ||
| o_dim, o_stride = _bshd_as_bhsd_dim_stride(output_aval.shape) | ||
| do_dim, do_stride = _bshd_as_bhsd_dim_stride(doutput_aval.shape) |
There was a problem hiding this comment.
id()-based cache keys can produce false cache hits after GC
_score_mod_callback_cache_key builds its key from id(self_obj) and id(func). Python recycles object addresses after GC, so if a callback instance is collected and a new object (of a different class or with different graph logic) is allocated at the same address, the new config will compare equal to the old one under __eq__. JAX's nondiff-argnum caching then reuses the traced function and graph built for the original callback, silently executing the wrong cuDNN graph. The risk is low for long-lived module-level functions but real for short-lived class instances. Anchoring the key to a non-id stable identifier (e.g., a weakref plus explicit id, or requiring callers to supply an explicit stable key) would eliminate the ambiguity.
| Error_Type ExecuteScoreModGraph(cudaStream_t stream, int64_t graph_id, | ||
| const std::vector<void *> &input_ptrs, | ||
| const std::vector<void *> &output_ptrs, void *workspace) { | ||
| auto entry = GetScoreModGraphEntry(graph_id); | ||
| NVTE_CHECK(input_ptrs.size() == entry->input_uids.size(), "cuDNN score_mod graph expected ", | ||
| entry->input_uids.size(), " inputs but got ", input_ptrs.size()); | ||
| NVTE_CHECK(output_ptrs.size() >= entry->output_uids.size(), | ||
| "cuDNN score_mod graph expected at least ", entry->output_uids.size(), | ||
| " outputs but got ", output_ptrs.size()); | ||
|
|
||
| std::unordered_map<int64_t, void *> variant_pack; | ||
| for (size_t i = 0; i < entry->input_uids.size(); ++i) { | ||
| variant_pack.emplace(entry->input_uids[i], input_ptrs[i]); | ||
| } | ||
| for (size_t i = 0; i < entry->output_uids.size(); ++i) { | ||
| variant_pack.emplace(entry->output_uids[i], output_ptrs[i]); | ||
| } | ||
| for (size_t i = 0; i < entry->scalar_uids.size(); ++i) { | ||
| variant_pack.emplace(entry->scalar_uids[i], entry->scalar_values[i].data.data()); | ||
| } | ||
|
|
||
| std::vector<std::intptr_t> user_ptrs; | ||
| user_ptrs.reserve(entry->user_uids.size()); | ||
| for (const auto uid : entry->user_uids) { | ||
| auto it = variant_pack.find(uid); | ||
| NVTE_CHECK(it != variant_pack.end(), "cuDNN score_mod graph variant pack is missing UID ", uid); | ||
| user_ptrs.push_back(reinterpret_cast<std::intptr_t>(it->second)); | ||
| } | ||
|
|
||
| auto handle = GetScoreModCudnnHandle(); | ||
| NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); | ||
| { | ||
| pybind11::gil_scoped_acquire gil; | ||
| try { | ||
| auto graph = pybind11::reinterpret_borrow<pybind11::object>(entry->py_graph); | ||
| graph.attr("_execute_with_ptrs")(user_ptrs, reinterpret_cast<std::intptr_t>(workspace), | ||
| reinterpret_cast<std::intptr_t>(handle)); | ||
| } catch (const pybind11::error_already_set &exc) { | ||
| NVTE_ERROR("cuDNN score_mod SDPA graph execution failed: ", exc.what()); | ||
| } | ||
| } | ||
| return ffi_with_cuda_error_check(); | ||
| } |
There was a problem hiding this comment.
GIL held across a CUDA FFI call boundary
ExecuteScoreModGraph acquires pybind11::gil_scoped_acquire while the CUDA stream is live and calls a Python method (_execute_with_ptrs) synchronously. Any other Python thread that holds the GIL and is waiting on CUDA work will deadlock. More broadly, acquiring the GIL inside an XLA/JAX FFI handler — which JAX may dispatch from a non-Python thread — creates a locking inversion risk. This is by-design if cuDNN's Python frontend has no C-level execution path, but the limitation should be documented and the possibility of multi-threaded JAX dispatch should be explicitly considered.
| _SCORE_MOD_UID_DQ = 7 | ||
| _SCORE_MOD_UID_DK = 8 | ||
| _SCORE_MOD_UID_DV = 9 | ||
| _SCORE_MOD_FWD_TENSOR_UID_BASE = 1000 |
There was a problem hiding this comment.
_score_mod_graph_cache and C++ registry grow without bound
_score_mod_graph_cache is a module-level dict that accumulates (graph_id, workspace_size) entries for every unique (direction, config, aval-tuple) seen during tracing, and the C++ ScoreModGraphRegistry holds the corresponding cuDNN graph objects forever. Each entry keeps a Python cuDNN graph alive (and, due to the missing Py_DECREF noted separately, prevents GC). In long-running services or evaluation loops that sweep over many shapes/dtypes, this leads to unbounded cuDNN graph memory accumulation. An LRU eviction strategy or an explicit graph-release API paired with cache invalidation would contain the growth.
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
|
|
||
| def forward(self, graph, score, tensors): | ||
| import cudnn # pylint: disable=import-outside-toplevel | ||
|
|
||
| self.before_tanh_activation = graph.div( | ||
| a=score, | ||
| b=tensors["softcap"], | ||
| compute_data_type=cudnn.data_type.FLOAT, | ||
| ) | ||
| self.before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) | ||
| tanh_out = graph.tanh(input=self.before_tanh_activation) | ||
| tanh_out.set_data_type(cudnn.data_type.FLOAT) | ||
| return graph.mul( | ||
| a=tanh_out, | ||
| b=tensors["softcap"], | ||
| compute_data_type=cudnn.data_type.FLOAT, | ||
| ) | ||
|
|
||
| def backward(self, graph, dscore, tensors): | ||
| import cudnn # pylint: disable=import-outside-toplevel | ||
|
|
||
| d_tanh_out = graph.mul( | ||
| a=dscore, | ||
| b=tensors["softcap"], | ||
| compute_data_type=cudnn.data_type.FLOAT, | ||
| ) | ||
| d_tanh_out.set_data_type(cudnn.data_type.FLOAT) | ||
| d_before_tanh_activation = graph.tanh_backward( | ||
| loss=d_tanh_out, | ||
| input=self.before_tanh_activation, | ||
| compute_data_type=cudnn.data_type.FLOAT, | ||
| ) | ||
| d_before_tanh_activation.set_data_type(cudnn.data_type.FLOAT) | ||
| return graph.div( | ||
| a=d_before_tanh_activation, | ||
| b=tensors["softcap"], | ||
| compute_data_type=cudnn.data_type.FLOAT, | ||
| ) | ||
|
|
||
|
|
||
| def _reference_attention( | ||
| query, key, value, scale, *, causal=False, relative_position=False, softcap=None | ||
| ): | ||
| scores = jnp.einsum("bqhd,bkhd->bhqk", query, key).astype(jnp.float32) * scale | ||
| if causal: | ||
| q_pos = jnp.arange(query.shape[1])[:, None] | ||
| kv_pos = jnp.arange(key.shape[1])[None, :] | ||
| scores = jnp.where(q_pos >= kv_pos, scores, -1e9) | ||
| if relative_position: | ||
| q_pos = jnp.arange(query.shape[1], dtype=jnp.float32)[:, None] | ||
| kv_pos = jnp.arange(key.shape[1], dtype=jnp.float32)[None, :] | ||
| scores = scores + q_pos - kv_pos | ||
| if softcap is not None: |
There was a problem hiding this comment.
_ScoreModSoftcap.backward relies on undocumented cuDNN callback ordering
backward reads self.before_tanh_activation, which is written by forward during sdpa_backward graph construction. This is only safe if cuDNN's sdpa_backward guarantees it calls score_mod (the forward callback) before score_mod_bprop (the backward callback) within the same graph-build invocation. If that order is ever reversed, self.before_tanh_activation is None at the time backward runs, and graph.tanh_backward(input=None, ...) will fail silently or crash at execution time rather than at graph-build time.
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
| _SCORE_MOD_UID_K = 2 | ||
| _SCORE_MOD_UID_V = 3 | ||
| _SCORE_MOD_UID_O = 4 | ||
| _SCORE_MOD_UID_STATS = 5 |
There was a problem hiding this comment.
Where do these _SCORE_MOD_UID_XXXXX come from? Is it a C/C++ enum? If so, we should make this a Python Enum that derives its values from the C/C++ enum exposed via pybind
See this enum for reference:
There was a problem hiding this comment.
These are just arbitrary numbers, really. In fact, assigning UIDs is completely optional, cuDNN can auto-assign. UIDs are added here just for determinism / to make future troubleshooting easier, e.g. so that we know that 4 is the output tensor.
|
|
||
| def _import_cudnn_for_score_mod(): | ||
| cudnn_frontend_path = str(_CUDNN_FRONTEND_PYTHON_PATH) | ||
| cudnn_frontend_package = _CUDNN_FRONTEND_PYTHON_PATH / "cudnn" |
There was a problem hiding this comment.
Does this work for all installation types? Editable install, system installation?
There was a problem hiding this comment.
Source/editable installs are covered by the repo-relative 3rdparty/cudnn-frontend/python path. Normal/system installs work if the cudnn Python package is installed on sys.path; otherwise the code raises an explicit import error.
| } | ||
|
|
||
| std::atomic<int64_t> &NextScoreModGraphId() { | ||
| static std::atomic<int64_t> next_id{1}; |
There was a problem hiding this comment.
Why do we need these static values? Can we simplify this to track things on the Python side instead of C++ or as explicit FFI args?
static std::unordered_map<int64_t, std::shared_ptr<ScoreModGraphEntry>> registry;
...
static std::mutex mutex;
...
static std::atomic<int64_t> next_id{1};
There was a problem hiding this comment.
We cannot carry cuDNN graph directly via FFI, right? The Python trace-time code registers the graph and passes a graph_id; the C++ FFI handler later uses that ID to find the graph and UID metadata when XLA executes. Moving the registry to Python would still require a global lookup and more GIL / module-lifetime coupling.
| struct ScoreModGraphEntry { | ||
| PyObject *py_graph = nullptr; | ||
| std::vector<int64_t> user_uids; | ||
| std::vector<int64_t> input_uids; | ||
| std::vector<int64_t> output_uids; | ||
| std::vector<int64_t> scalar_uids; | ||
| std::vector<ScoreModScalarStorage> scalar_values; | ||
| }; |
There was a problem hiding this comment.
@vcherepanov-nv This seems like a valid comment from greptile about leaking pygraphs. But I'm also not sure if that is the intended design to prevent GC freeing up a graph too early by mistake before we use it in the XLA C++ FFI. I'm not sure what the best option is here
| ): | ||
| """Validate arguments for the cuDNN frontend score_mod path.""" | ||
| header = "score_mod fused_attn" | ||
| if qkv_layout is not QKVLayout.BSHD_BSHD_BSHD: |
There was a problem hiding this comment.
Thanks for the detailed error checks!
Description
This PR introduces an alternative code path for the FusedAttention backend for JAX.
The user can specify score_mod and score_mod_bprop functions, which get routed to the corresponding parameters of the sdpa and sdpa_backward calls to cuDNN FE.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: