Skip to content

TritonKernelCall: CUDA graph compatibility#3000

Open
tdophung wants to merge 2 commits into
NVIDIA:mainfrom
tdophung:jax-triton-cuda-graph-ffi
Open

TritonKernelCall: CUDA graph compatibility#3000
tdophung wants to merge 2 commits into
NVIDIA:mainfrom
tdophung:jax-triton-cuda-graph-ffi

Conversation

@tdophung
Copy link
Copy Markdown
Collaborator

@tdophung tdophung commented May 15, 2026

Description

Previously triton kernel calls are not CUDA graphable and will show up as individual ops instead of being able to be coalesced into a command buffer. JAX has recently added their end of the support for this. So this PR will connect that support for JAX triton kernels' CUDA graphability in TE.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Change API version to newer version for the triton kernel call ffi if JAX version supports CUDA graphability for triton kernel calls.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung marked this pull request as ready for review May 15, 2026 22:15
@tdophung
Copy link
Copy Markdown
Collaborator Author

/te_ci jax

@tdophung
Copy link
Copy Markdown
Collaborator Author

/te-ci jax

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 15, 2026

Greptile Summary

This PR enables CUDA graph capture for JAX Triton kernel calls by switching to a newer FFI target (triton_kernel_call_ffi) and a structured ir.StringAttr backend config when the installed JAX version supports it, falling back to the existing api_version=2 path on older versions.

  • version_utils.py: Adds TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION = "0.10.1.dev0" as the gating version floor; unlike the existing aliasing floor (which pins a specific nightly date), .dev0 accepts every 0.10.1 dev build unconditionally.
  • utils.py: Adds a version branch in triton_call_lowering — new JAX uses triton_kernel_call_ffi with a dict backend config; older JAX keeps the previous triton_kernel_call path with api_version=2.

Confidence Score: 3/5

The lowering branch logic is sound, but the version floor could enable the new code path on early 0.10.1 nightly builds that lack the triton_kernel_call_ffi registration, producing a crash at compilation time.

The '0.10.1.dev0' floor unconditionally enables the new FFI target on any 0.10.1.devYYYYMMDD build. If the C++ registration was only added partway through the 0.10.1 nightly cycle — as was the case for the aliasing fix where a specific date was required — early nightlies would reach the new code path and crash at lowering time.

transformer_engine/jax/version_utils.py — the version floor definition needs a documented specific nightly date or a justification for why .dev0 is safe.

Important Files Changed

Filename Overview
transformer_engine/jax/version_utils.py Adds TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION = "0.10.1.dev0"; the .dev0 floor accepts all 0.10.1 nightlies regardless of when the feature shipped, and the constant is missing from all.
transformer_engine/jax/triton_extensions/utils.py Branches triton_call_lowering to use the new triton_kernel_call_ffi target with an ir.StringAttr-wrapped backend config when JAX >= 0.10.1.dev0; logic and import additions look correct.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[triton_call_lowering called] --> B[Serialize kernel_call to proto]
    B --> C[zlib.compress call_proto]
    C --> D{JAX >= 0.10.1.dev0?}
    D -- Yes --> E[ffi_lowering triton_kernel_call_ffi\nbackend_config as ir.StringAttr dict]
    D -- No --> F[ffi_lowering triton_kernel_call\napi_version=2, backend_config as bytes]
    E --> G[CUDA-graph-compatible lowering rule]
    F --> H[Legacy lowering rule]
    G --> I[rule applied to ctx and array_args]
    H --> I
Loading

Comments Outside Diff (1)

  1. transformer_engine/jax/version_utils.py, line 77-83 (link)

    P2 The new constant TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION is not listed in __all__. Any caller that relies on from transformer_engine.jax.version_utils import * will not get it, and the exported surface is inconsistent with the other constants and helpers already listed.

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +28 to +29
# Minimum JAX version for non-legacy Triton kernel FFI (supporting CUDA graph capture)
TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION = "0.10.1.dev0"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Using "0.10.1.dev0" as the floor accepts every 0.10.1.devYYYYMMDD nightly build (because any real date integer > 0), including early builds that may predate the CUDA-graph support. The existing pattern in this file pins to the exact nightly that first shipped the fix (e.g. "0.9.2.dev20260317"). If the feature was not present from the very first 0.10.1 dev build, this floor will silently enable the new code path on builds that don't have the matching C++ registration, likely producing a KeyError or undefined-symbol crash at lowering time.

Suggested change
# Minimum JAX version for non-legacy Triton kernel FFI (supporting CUDA graph capture)
TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION = "0.10.1.dev0"
# Minimum JAX version for non-legacy Triton kernel FFI (supporting CUDA graph capture).
# Pin to the exact nightly that first shipped the triton_kernel_call_ffi registration
# (analogous to _TRITON_AUTOTUNED_ALIAS_NIGHTLY_FLOOR above).
# TODO: replace dev0 with the first nightly date that contains the feature, e.g.
# "0.10.1.dev20YYMMDD", and add a separate stable floor if 0.10.1 stable ships the fix.
TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION = "0.10.1.dev0"

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant