From 2a8104d9db88d3dea619d8da4b65c14b933c1797 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Sun, 28 Jun 2026 21:40:49 +0000 Subject: [PATCH 1/2] Fix Gemma4 MoE Triton flex stage config --- src/art/megatron/flex_attn/compiled.py | 16 +-- .../megatron/model_support/handlers/gemma4.py | 32 +++++- .../model_support/test_compile_flags.py | 102 ++++++++++++++++++ 3 files changed, 140 insertions(+), 10 deletions(-) diff --git a/src/art/megatron/flex_attn/compiled.py b/src/art/megatron/flex_attn/compiled.py index 9b8c93a88..b6bd8dfe2 100644 --- a/src/art/megatron/flex_attn/compiled.py +++ b/src/art/megatron/flex_attn/compiled.py @@ -192,10 +192,6 @@ def get_dense_compiled_flex_attention( head_dim_v: int, triton_num_stages_2_head_dims: tuple[int, ...] = (), ) -> Any: - if backend == _FORCED_FLEX_BACKEND: - return dense_compiled_flex_attention - if backend == "FLASH": - return flash_dense_compiled_flex_attention if _needs_triton_num_stages_2( backend=backend, head_dim=head_dim, @@ -203,6 +199,10 @@ def get_dense_compiled_flex_attention( triton_num_stages_2_head_dims=triton_num_stages_2_head_dims, ): return triton_num_stages_2_dense_compiled_flex_attention + if backend == _FORCED_FLEX_BACKEND: + return dense_compiled_flex_attention + if backend == "FLASH": + return flash_dense_compiled_flex_attention return triton_dense_compiled_flex_attention @@ -215,10 +215,6 @@ def get_sparse_compiled_flex_attention( triton_num_stages_2_head_dims: tuple[int, ...] = (), ) -> Any: del family_key - if backend == _FORCED_FLEX_BACKEND: - return sparse_compiled_flex_attention - if backend == "FLASH": - return flash_sparse_compiled_flex_attention if _needs_triton_num_stages_2( backend=backend, head_dim=head_dim, @@ -226,6 +222,10 @@ def get_sparse_compiled_flex_attention( triton_num_stages_2_head_dims=triton_num_stages_2_head_dims, ): return triton_num_stages_2_sparse_compiled_flex_attention + if backend == _FORCED_FLEX_BACKEND: + return sparse_compiled_flex_attention + if backend == "FLASH": + return flash_sparse_compiled_flex_attention return triton_sparse_compiled_flex_attention diff --git a/src/art/megatron/model_support/handlers/gemma4.py b/src/art/megatron/model_support/handlers/gemma4.py index 123580113..33f57ece3 100644 --- a/src/art/megatron/model_support/handlers/gemma4.py +++ b/src/art/megatron/model_support/handlers/gemma4.py @@ -50,6 +50,9 @@ # google/gemma-4-31B-it: Triton flex attention raises "No valid triton # configs" for global attention head_dim=512 with backend-only options. ("dense", 60, 5376, 32, 256, 512, 4), + # google/gemma-4-26B-A4B-it hits the same Triton resource limit on global + # attention head_dim=512 with backend-only options. + ("moe", 30, 2816, 16, 256, 512, 2), } _ART_MOE_EXPERT_KEY_RE = re.compile( r"^(?P.*\.mlp\.experts)\.(?P\d+)\." @@ -130,7 +133,7 @@ def configure_provider_for_runtime(self, provider: Any) -> None: _patch_gemma4_rotary_for_hf_proportional() _patch_gemma4_qkv_for_hf_tied_value() window_size = int(getattr(provider, "window_size", 1024)) - provider.art_flex_core_attention_wrapper = _gemma4_flex_core_attention_wrapper + _install_gemma4_flex_core_attention_wrapper(provider) provider.art_flex_sliding_windows = (window_size,) provider.art_flex_head_dims_by_window = { None: int(getattr(provider, "global_head_dim", provider.kv_channels)), @@ -349,6 +352,19 @@ def compile_workaround_config( disable_compile=False, ) + def flex_attention_compile_crash_config( + self, + provider: Any, + ) -> FlexAttentionCompileCrashConfig: + if ( + _gemma4_compile_crash_signature(provider) + in _GEMMA4_TRITON_NUM_STAGES_2_SIGNATURES + ): + return FlexAttentionCompileCrashConfig( + triton_num_stages_2_head_dims=(int(provider.global_head_dim),) + ) + return FlexAttentionCompileCrashConfig() + GEMMA4_MOE_HANDLER = Gemma4MoeHandler() @@ -364,7 +380,7 @@ def configure_provider_for_runtime(self, provider: Any) -> None: _patch_gemma4_rotary_for_hf_proportional() _patch_gemma4_qkv_for_hf_tied_value() window_size = int(getattr(provider, "window_size", 1024)) - provider.art_flex_core_attention_wrapper = _gemma4_flex_core_attention_wrapper + _install_gemma4_flex_core_attention_wrapper(provider) provider.art_flex_sliding_windows = (window_size,) provider.art_flex_head_dims_by_window = { None: int(getattr(provider, "global_head_dim", provider.kv_channels)), @@ -1006,6 +1022,13 @@ def _gemma4_sliding_window_for_layer(provider: Any, layer_number: int) -> int | return int(provider.window_size) +def _install_gemma4_flex_core_attention_wrapper(provider: Any) -> None: + def _wrapper(_config: Any, base_cls: type[Any]) -> type[Any]: + return _gemma4_flex_core_attention_wrapper(provider, base_cls) + + provider.art_flex_core_attention_wrapper = _wrapper + + def _gemma4_flex_core_attention_wrapper( provider: Any, base_cls: type[Any] ) -> type[Any]: @@ -1017,6 +1040,11 @@ def __init__( *args: Any, **kwargs: Any, ) -> None: + compile_crash_config = getattr( + provider, "art_flex_compile_crash_config", None + ) + if compile_crash_config is not None: + config.art_flex_compile_crash_config = compile_crash_config super().__init__(config, layer_number, *args, **kwargs) self.art_sliding_window = _gemma4_sliding_window_for_layer( provider, diff --git a/tests/integration/megatron/model_support/test_compile_flags.py b/tests/integration/megatron/model_support/test_compile_flags.py index 15654fc09..3376473e0 100644 --- a/tests/integration/megatron/model_support/test_compile_flags.py +++ b/tests/integration/megatron/model_support/test_compile_flags.py @@ -1,3 +1,11 @@ +from typing import Any + +from art.megatron.flex_attn import compiled as compiled_flex_attention +from art.megatron.model_support.handlers.gemma4 import ( + GEMMA4_DENSE_HANDLER, + GEMMA4_MOE_HANDLER, + _install_gemma4_flex_core_attention_wrapper, +) from art.megatron.model_support.handlers.qwen3_5 import QWEN3_5_MOE_HANDLER from art.megatron.model_support.handlers.qwen3_moe import QWEN3_MOE_HANDLER @@ -31,3 +39,97 @@ def test_qwen35_moe_compile_workarounds_cover_deepep_permute_restore() -> None: config = QWEN3_5_MOE_HANDLER.compile_workaround_config(provider) assert config.flags == _QWEN35_MOE_COMPILE_FLAGS assert config.unconditional_flags == () + + +def _gemma4_provider(**overrides: int) -> Any: + attrs = { + "num_moe_experts": 128, + "num_layers": 30, + "hidden_size": 2816, + "num_attention_heads": 16, + "kv_channels": 256, + "global_head_dim": 512, + "num_global_key_value_heads": 2, + } + attrs.update(overrides) + return type("Provider", (), attrs)() + + +def test_gemma4_known_wide_global_attention_signatures_use_lower_triton_stage_count() -> ( + None +): + dense_provider = _gemma4_provider( + num_moe_experts=0, + num_layers=60, + hidden_size=5376, + num_attention_heads=32, + num_global_key_value_heads=4, + ) + moe_provider = _gemma4_provider() + + assert GEMMA4_DENSE_HANDLER.flex_attention_compile_crash_config( + dense_provider + ).triton_num_stages_2_head_dims == (512,) + assert GEMMA4_MOE_HANDLER.flex_attention_compile_crash_config( + moe_provider + ).triton_num_stages_2_head_dims == (512,) + + +def test_gemma4_unlisted_wide_global_attention_signature_keeps_default_stage_count() -> ( + None +): + provider = _gemma4_provider(hidden_size=2817) + + assert ( + GEMMA4_MOE_HANDLER.flex_attention_compile_crash_config( + provider + ).triton_num_stages_2_head_dims + == () + ) + + +def test_gemma4_flex_attention_wrapper_carries_provider_compile_crash_config() -> None: + provider = _gemma4_provider() + provider.art_flex_compile_crash_config = ( + GEMMA4_MOE_HANDLER.flex_attention_compile_crash_config(provider) + ) + _install_gemma4_flex_core_attention_wrapper(provider) + + class BaseAttention: + def __init__(self, config: Any, layer_number: int) -> None: + del layer_number + self.head_dims = ( + config.art_flex_compile_crash_config.triton_num_stages_2_head_dims + ) + + copied_config = type("CopiedConfig", (), {})() + wrapped_cls = provider.art_flex_core_attention_wrapper(copied_config, BaseAttention) + wrapped = wrapped_cls(type("LayerConfig", (), {})(), 1) + + assert wrapped.head_dims == (512,) + + +def test_triton_num_stages_2_selection_overrides_forced_triton_backend( + monkeypatch: Any, +) -> None: + monkeypatch.setattr(compiled_flex_attention, "_FORCED_FLEX_BACKEND", "TRITON") + + assert ( + compiled_flex_attention.get_dense_compiled_flex_attention( + backend="TRITON", + head_dim=512, + head_dim_v=512, + triton_num_stages_2_head_dims=(512,), + ) + is compiled_flex_attention.triton_num_stages_2_dense_compiled_flex_attention + ) + assert ( + compiled_flex_attention.get_sparse_compiled_flex_attention( + family_key="test", + backend="TRITON", + head_dim=512, + head_dim_v=512, + triton_num_stages_2_head_dims=(512,), + ) + is compiled_flex_attention.triton_num_stages_2_sparse_compiled_flex_attention + ) From 77d2b04488d826a9c07c851e605305411b44ab49 Mon Sep 17 00:00:00 2001 From: FurtherAI Date: Mon, 29 Jun 2026 00:18:25 +0000 Subject: [PATCH 2/2] Remove Gemma4 compile flag tests --- .../model_support/test_compile_flags.py | 102 ------------------ 1 file changed, 102 deletions(-) diff --git a/tests/integration/megatron/model_support/test_compile_flags.py b/tests/integration/megatron/model_support/test_compile_flags.py index 3376473e0..15654fc09 100644 --- a/tests/integration/megatron/model_support/test_compile_flags.py +++ b/tests/integration/megatron/model_support/test_compile_flags.py @@ -1,11 +1,3 @@ -from typing import Any - -from art.megatron.flex_attn import compiled as compiled_flex_attention -from art.megatron.model_support.handlers.gemma4 import ( - GEMMA4_DENSE_HANDLER, - GEMMA4_MOE_HANDLER, - _install_gemma4_flex_core_attention_wrapper, -) from art.megatron.model_support.handlers.qwen3_5 import QWEN3_5_MOE_HANDLER from art.megatron.model_support.handlers.qwen3_moe import QWEN3_MOE_HANDLER @@ -39,97 +31,3 @@ def test_qwen35_moe_compile_workarounds_cover_deepep_permute_restore() -> None: config = QWEN3_5_MOE_HANDLER.compile_workaround_config(provider) assert config.flags == _QWEN35_MOE_COMPILE_FLAGS assert config.unconditional_flags == () - - -def _gemma4_provider(**overrides: int) -> Any: - attrs = { - "num_moe_experts": 128, - "num_layers": 30, - "hidden_size": 2816, - "num_attention_heads": 16, - "kv_channels": 256, - "global_head_dim": 512, - "num_global_key_value_heads": 2, - } - attrs.update(overrides) - return type("Provider", (), attrs)() - - -def test_gemma4_known_wide_global_attention_signatures_use_lower_triton_stage_count() -> ( - None -): - dense_provider = _gemma4_provider( - num_moe_experts=0, - num_layers=60, - hidden_size=5376, - num_attention_heads=32, - num_global_key_value_heads=4, - ) - moe_provider = _gemma4_provider() - - assert GEMMA4_DENSE_HANDLER.flex_attention_compile_crash_config( - dense_provider - ).triton_num_stages_2_head_dims == (512,) - assert GEMMA4_MOE_HANDLER.flex_attention_compile_crash_config( - moe_provider - ).triton_num_stages_2_head_dims == (512,) - - -def test_gemma4_unlisted_wide_global_attention_signature_keeps_default_stage_count() -> ( - None -): - provider = _gemma4_provider(hidden_size=2817) - - assert ( - GEMMA4_MOE_HANDLER.flex_attention_compile_crash_config( - provider - ).triton_num_stages_2_head_dims - == () - ) - - -def test_gemma4_flex_attention_wrapper_carries_provider_compile_crash_config() -> None: - provider = _gemma4_provider() - provider.art_flex_compile_crash_config = ( - GEMMA4_MOE_HANDLER.flex_attention_compile_crash_config(provider) - ) - _install_gemma4_flex_core_attention_wrapper(provider) - - class BaseAttention: - def __init__(self, config: Any, layer_number: int) -> None: - del layer_number - self.head_dims = ( - config.art_flex_compile_crash_config.triton_num_stages_2_head_dims - ) - - copied_config = type("CopiedConfig", (), {})() - wrapped_cls = provider.art_flex_core_attention_wrapper(copied_config, BaseAttention) - wrapped = wrapped_cls(type("LayerConfig", (), {})(), 1) - - assert wrapped.head_dims == (512,) - - -def test_triton_num_stages_2_selection_overrides_forced_triton_backend( - monkeypatch: Any, -) -> None: - monkeypatch.setattr(compiled_flex_attention, "_FORCED_FLEX_BACKEND", "TRITON") - - assert ( - compiled_flex_attention.get_dense_compiled_flex_attention( - backend="TRITON", - head_dim=512, - head_dim_v=512, - triton_num_stages_2_head_dims=(512,), - ) - is compiled_flex_attention.triton_num_stages_2_dense_compiled_flex_attention - ) - assert ( - compiled_flex_attention.get_sparse_compiled_flex_attention( - family_key="test", - backend="TRITON", - head_dim=512, - head_dim_v=512, - triton_num_stages_2_head_dims=(512,), - ) - is compiled_flex_attention.triton_num_stages_2_sparse_compiled_flex_attention - )