Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions transformer_engine/jax/triton_extensions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,17 @@
from packaging import version

from jax import core
from jaxlib.mlir import ir
import jax
import jax.numpy as jnp

from ..version_utils import (
TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION,
TRITON_EXTENSION_MIN_JAX_VERSION,
TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION,
is_triton_autotuned_alias_safe,
is_triton_extension_supported,
jax_version_meet_requirement,
)


Expand Down Expand Up @@ -581,18 +584,26 @@ def lowering(ctx, x, *, block_size):

serialized_metadata = b""
call_proto = kernel_call.to_proto(actual_kernel_fn.__name__, serialized_metadata)
compressed_call_proto = zlib.compress(call_proto)

if input_output_aliases:
ffi_operand_output_aliases = input_output_aliases
else:
ffi_operand_output_aliases = None

# Use JAX FFI lowering with compressed protobuf
rule = jax.ffi.ffi_lowering(
"triton_kernel_call", # Custom call target registered in gpu_triton.py
api_version=2,
backend_config=zlib.compress(call_proto),
operand_output_aliases=ffi_operand_output_aliases,
)
if jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION):
rule = jax.ffi.ffi_lowering(
"triton_kernel_call_ffi", # Custom call target registered in gpu_triton.py
backend_config={"kernel_call_proto": ir.StringAttr.get(compressed_call_proto)},
operand_output_aliases=ffi_operand_output_aliases,
)
else:
rule = jax.ffi.ffi_lowering(
"triton_kernel_call", # Custom call target registered in gpu_triton.py
api_version=2,
backend_config=compressed_call_proto,
operand_output_aliases=ffi_operand_output_aliases,
)

return rule(ctx, *array_args)
3 changes: 3 additions & 0 deletions transformer_engine/jax/version_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def jax_version_meet_requirement(version: str):
# Minimum JAX version required for Triton kernel dispatch (jaxlib < 0.8.0 segfaults).
TRITON_EXTENSION_MIN_JAX_VERSION = "0.8.0"

# Minimum JAX version for non-legacy Triton kernel FFI (supporting CUDA graph capture)
TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION = "0.10.1.dev0"
Comment on lines +28 to +29
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Using "0.10.1.dev0" as the floor accepts every 0.10.1.devYYYYMMDD nightly build (because any real date integer > 0), including early builds that may predate the CUDA-graph support. The existing pattern in this file pins to the exact nightly that first shipped the fix (e.g. "0.9.2.dev20260317"). If the feature was not present from the very first 0.10.1 dev build, this floor will silently enable the new code path on builds that don't have the matching C++ registration, likely producing a KeyError or undefined-symbol crash at lowering time.

Suggested change
# Minimum JAX version for non-legacy Triton kernel FFI (supporting CUDA graph capture)
TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION = "0.10.1.dev0"
# Minimum JAX version for non-legacy Triton kernel FFI (supporting CUDA graph capture).
# Pin to the exact nightly that first shipped the triton_kernel_call_ffi registration
# (analogous to _TRITON_AUTOTUNED_ALIAS_NIGHTLY_FLOOR above).
# TODO: replace dev0 with the first nightly date that contains the feature, e.g.
# "0.10.1.dev20YYMMDD", and add a separate stable floor if 0.10.1 stable ships the fix.
TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION = "0.10.1.dev0"


# Nightly and stable floors for safe input_output_aliases in TritonAutotunedKernelCall.
# jaxlib/gpu/triton_kernels.cc had a bug in the autotuning save/restore loop:
# it iterated over all declared aliases unconditionally, but input_copies only
Expand Down
Loading