TritonKernelCall: CUDA graph compatibility#3000
Conversation
Signed-off-by: tdophung <tdophung@nvidia.com>
|
/te_ci jax |
for more information, see https://pre-commit.ci
|
/te-ci jax |
Greptile SummaryThis PR enables CUDA graph capture for JAX Triton kernel calls by switching to a newer FFI target (
Confidence Score: 3/5The lowering branch logic is sound, but the version floor could enable the new code path on early 0.10.1 nightly builds that lack the triton_kernel_call_ffi registration, producing a crash at compilation time. The '0.10.1.dev0' floor unconditionally enables the new FFI target on any 0.10.1.devYYYYMMDD build. If the C++ registration was only added partway through the 0.10.1 nightly cycle — as was the case for the aliasing fix where a specific date was required — early nightlies would reach the new code path and crash at lowering time. transformer_engine/jax/version_utils.py — the version floor definition needs a documented specific nightly date or a justification for why .dev0 is safe. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[triton_call_lowering called] --> B[Serialize kernel_call to proto]
B --> C[zlib.compress call_proto]
C --> D{JAX >= 0.10.1.dev0?}
D -- Yes --> E[ffi_lowering triton_kernel_call_ffi\nbackend_config as ir.StringAttr dict]
D -- No --> F[ffi_lowering triton_kernel_call\napi_version=2, backend_config as bytes]
E --> G[CUDA-graph-compatible lowering rule]
F --> H[Legacy lowering rule]
G --> I[rule applied to ctx and array_args]
H --> I
|
| # Minimum JAX version for non-legacy Triton kernel FFI (supporting CUDA graph capture) | ||
| TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION = "0.10.1.dev0" |
There was a problem hiding this comment.
Using
"0.10.1.dev0" as the floor accepts every 0.10.1.devYYYYMMDD nightly build (because any real date integer > 0), including early builds that may predate the CUDA-graph support. The existing pattern in this file pins to the exact nightly that first shipped the fix (e.g. "0.9.2.dev20260317"). If the feature was not present from the very first 0.10.1 dev build, this floor will silently enable the new code path on builds that don't have the matching C++ registration, likely producing a KeyError or undefined-symbol crash at lowering time.
| # Minimum JAX version for non-legacy Triton kernel FFI (supporting CUDA graph capture) | |
| TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION = "0.10.1.dev0" | |
| # Minimum JAX version for non-legacy Triton kernel FFI (supporting CUDA graph capture). | |
| # Pin to the exact nightly that first shipped the triton_kernel_call_ffi registration | |
| # (analogous to _TRITON_AUTOTUNED_ALIAS_NIGHTLY_FLOOR above). | |
| # TODO: replace dev0 with the first nightly date that contains the feature, e.g. | |
| # "0.10.1.dev20YYMMDD", and add a separate stable floor if 0.10.1 stable ships the fix. | |
| TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION = "0.10.1.dev0" |
Description
Previously triton kernel calls are not CUDA graphable and will show up as individual ops instead of being able to be coalesced into a command buffer. JAX has recently added their end of the support for this. So this PR will connect that support for JAX triton kernels' CUDA graphability in TE.
Fixes # (issue)
Type of change
Changes
Change API version to newer version for the triton kernel call ffi if JAX version supports CUDA graphability for triton kernel calls.
Checklist: