diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 2a86321c34..b0cbcbb5ab 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -44,14 +44,17 @@ from packaging import version from jax import core +from jaxlib.mlir import ir import jax import jax.numpy as jnp from ..version_utils import ( TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION, TRITON_EXTENSION_MIN_JAX_VERSION, + TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION, is_triton_autotuned_alias_safe, is_triton_extension_supported, + jax_version_meet_requirement, ) @@ -581,6 +584,7 @@ def lowering(ctx, x, *, block_size): serialized_metadata = b"" call_proto = kernel_call.to_proto(actual_kernel_fn.__name__, serialized_metadata) + compressed_call_proto = zlib.compress(call_proto) if input_output_aliases: ffi_operand_output_aliases = input_output_aliases @@ -588,11 +592,18 @@ def lowering(ctx, x, *, block_size): ffi_operand_output_aliases = None # Use JAX FFI lowering with compressed protobuf - rule = jax.ffi.ffi_lowering( - "triton_kernel_call", # Custom call target registered in gpu_triton.py - api_version=2, - backend_config=zlib.compress(call_proto), - operand_output_aliases=ffi_operand_output_aliases, - ) + if jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION): + rule = jax.ffi.ffi_lowering( + "triton_kernel_call_ffi", # Custom call target registered in gpu_triton.py + backend_config={"kernel_call_proto": ir.StringAttr.get(compressed_call_proto)}, + operand_output_aliases=ffi_operand_output_aliases, + ) + else: + rule = jax.ffi.ffi_lowering( + "triton_kernel_call", # Custom call target registered in gpu_triton.py + api_version=2, + backend_config=compressed_call_proto, + operand_output_aliases=ffi_operand_output_aliases, + ) return rule(ctx, *array_args) diff --git a/transformer_engine/jax/version_utils.py b/transformer_engine/jax/version_utils.py index e6ed9a8ea6..62368e43eb 100644 --- a/transformer_engine/jax/version_utils.py +++ b/transformer_engine/jax/version_utils.py @@ -25,6 +25,9 @@ def jax_version_meet_requirement(version: str): # Minimum JAX version required for Triton kernel dispatch (jaxlib < 0.8.0 segfaults). TRITON_EXTENSION_MIN_JAX_VERSION = "0.8.0" +# Minimum JAX version for non-legacy Triton kernel FFI (supporting CUDA graph capture) +TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION = "0.10.1.dev0" + # Nightly and stable floors for safe input_output_aliases in TritonAutotunedKernelCall. # jaxlib/gpu/triton_kernels.cc had a bug in the autotuning save/restore loop: # it iterated over all declared aliases unconditionally, but input_copies only