From c1bb105f0b88725a13fd4daa631281941ec2e55f Mon Sep 17 00:00:00 2001 From: tdophung Date: Fri, 15 May 2026 15:07:53 -0700 Subject: [PATCH 1/2] TritonKernelCall: CUDA graph compatibility Signed-off-by: tdophung --- .../jax/triton_extensions/utils.py | 25 ++++++++++++++----- transformer_engine/jax/version_utils.py | 3 +++ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 2a86321c34..b4917dec06 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -44,14 +44,17 @@ from packaging import version from jax import core +from jaxlib.mlir import ir import jax import jax.numpy as jnp from ..version_utils import ( TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION, TRITON_EXTENSION_MIN_JAX_VERSION, + TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION, is_triton_autotuned_alias_safe, is_triton_extension_supported, + jax_version_meet_requirement, ) @@ -581,6 +584,7 @@ def lowering(ctx, x, *, block_size): serialized_metadata = b"" call_proto = kernel_call.to_proto(actual_kernel_fn.__name__, serialized_metadata) + compressed_call_proto = zlib.compress(call_proto) if input_output_aliases: ffi_operand_output_aliases = input_output_aliases @@ -588,11 +592,20 @@ def lowering(ctx, x, *, block_size): ffi_operand_output_aliases = None # Use JAX FFI lowering with compressed protobuf - rule = jax.ffi.ffi_lowering( - "triton_kernel_call", # Custom call target registered in gpu_triton.py - api_version=2, - backend_config=zlib.compress(call_proto), - operand_output_aliases=ffi_operand_output_aliases, - ) + if jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION): + rule = jax.ffi.ffi_lowering( + "triton_kernel_call_ffi", # Custom call target registered in gpu_triton.py + backend_config={ + "kernel_call_proto": ir.StringAttr.get(compressed_call_proto) + }, + operand_output_aliases=ffi_operand_output_aliases, + ) + else: + rule = jax.ffi.ffi_lowering( + "triton_kernel_call", # Custom call target registered in gpu_triton.py + api_version=2, + backend_config=compressed_call_proto, + operand_output_aliases=ffi_operand_output_aliases, + ) return rule(ctx, *array_args) diff --git a/transformer_engine/jax/version_utils.py b/transformer_engine/jax/version_utils.py index e6ed9a8ea6..62368e43eb 100644 --- a/transformer_engine/jax/version_utils.py +++ b/transformer_engine/jax/version_utils.py @@ -25,6 +25,9 @@ def jax_version_meet_requirement(version: str): # Minimum JAX version required for Triton kernel dispatch (jaxlib < 0.8.0 segfaults). TRITON_EXTENSION_MIN_JAX_VERSION = "0.8.0" +# Minimum JAX version for non-legacy Triton kernel FFI (supporting CUDA graph capture) +TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION = "0.10.1.dev0" + # Nightly and stable floors for safe input_output_aliases in TritonAutotunedKernelCall. # jaxlib/gpu/triton_kernels.cc had a bug in the autotuning save/restore loop: # it iterated over all declared aliases unconditionally, but input_copies only From 931ff207659d23ef26085be064e97c9e0c485be0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 May 2026 22:15:32 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/triton_extensions/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index b4917dec06..b0cbcbb5ab 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -595,9 +595,7 @@ def lowering(ctx, x, *, block_size): if jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION): rule = jax.ffi.ffi_lowering( "triton_kernel_call_ffi", # Custom call target registered in gpu_triton.py - backend_config={ - "kernel_call_proto": ir.StringAttr.get(compressed_call_proto) - }, + backend_config={"kernel_call_proto": ir.StringAttr.get(compressed_call_proto)}, operand_output_aliases=ffi_operand_output_aliases, ) else: