diff --git a/src/art/megatron/flex_attn/compiled.py b/src/art/megatron/flex_attn/compiled.py index 9b8c93a88..b6bd8dfe2 100644 --- a/src/art/megatron/flex_attn/compiled.py +++ b/src/art/megatron/flex_attn/compiled.py @@ -192,10 +192,6 @@ def get_dense_compiled_flex_attention( head_dim_v: int, triton_num_stages_2_head_dims: tuple[int, ...] = (), ) -> Any: - if backend == _FORCED_FLEX_BACKEND: - return dense_compiled_flex_attention - if backend == "FLASH": - return flash_dense_compiled_flex_attention if _needs_triton_num_stages_2( backend=backend, head_dim=head_dim, @@ -203,6 +199,10 @@ def get_dense_compiled_flex_attention( triton_num_stages_2_head_dims=triton_num_stages_2_head_dims, ): return triton_num_stages_2_dense_compiled_flex_attention + if backend == _FORCED_FLEX_BACKEND: + return dense_compiled_flex_attention + if backend == "FLASH": + return flash_dense_compiled_flex_attention return triton_dense_compiled_flex_attention @@ -215,10 +215,6 @@ def get_sparse_compiled_flex_attention( triton_num_stages_2_head_dims: tuple[int, ...] = (), ) -> Any: del family_key - if backend == _FORCED_FLEX_BACKEND: - return sparse_compiled_flex_attention - if backend == "FLASH": - return flash_sparse_compiled_flex_attention if _needs_triton_num_stages_2( backend=backend, head_dim=head_dim, @@ -226,6 +222,10 @@ def get_sparse_compiled_flex_attention( triton_num_stages_2_head_dims=triton_num_stages_2_head_dims, ): return triton_num_stages_2_sparse_compiled_flex_attention + if backend == _FORCED_FLEX_BACKEND: + return sparse_compiled_flex_attention + if backend == "FLASH": + return flash_sparse_compiled_flex_attention return triton_sparse_compiled_flex_attention diff --git a/src/art/megatron/model_support/handlers/gemma4.py b/src/art/megatron/model_support/handlers/gemma4.py index 123580113..33f57ece3 100644 --- a/src/art/megatron/model_support/handlers/gemma4.py +++ b/src/art/megatron/model_support/handlers/gemma4.py @@ -50,6 +50,9 @@ # google/gemma-4-31B-it: Triton flex attention raises "No valid triton # configs" for global attention head_dim=512 with backend-only options. ("dense", 60, 5376, 32, 256, 512, 4), + # google/gemma-4-26B-A4B-it hits the same Triton resource limit on global + # attention head_dim=512 with backend-only options. + ("moe", 30, 2816, 16, 256, 512, 2), } _ART_MOE_EXPERT_KEY_RE = re.compile( r"^(?P.*\.mlp\.experts)\.(?P\d+)\." @@ -130,7 +133,7 @@ def configure_provider_for_runtime(self, provider: Any) -> None: _patch_gemma4_rotary_for_hf_proportional() _patch_gemma4_qkv_for_hf_tied_value() window_size = int(getattr(provider, "window_size", 1024)) - provider.art_flex_core_attention_wrapper = _gemma4_flex_core_attention_wrapper + _install_gemma4_flex_core_attention_wrapper(provider) provider.art_flex_sliding_windows = (window_size,) provider.art_flex_head_dims_by_window = { None: int(getattr(provider, "global_head_dim", provider.kv_channels)), @@ -349,6 +352,19 @@ def compile_workaround_config( disable_compile=False, ) + def flex_attention_compile_crash_config( + self, + provider: Any, + ) -> FlexAttentionCompileCrashConfig: + if ( + _gemma4_compile_crash_signature(provider) + in _GEMMA4_TRITON_NUM_STAGES_2_SIGNATURES + ): + return FlexAttentionCompileCrashConfig( + triton_num_stages_2_head_dims=(int(provider.global_head_dim),) + ) + return FlexAttentionCompileCrashConfig() + GEMMA4_MOE_HANDLER = Gemma4MoeHandler() @@ -364,7 +380,7 @@ def configure_provider_for_runtime(self, provider: Any) -> None: _patch_gemma4_rotary_for_hf_proportional() _patch_gemma4_qkv_for_hf_tied_value() window_size = int(getattr(provider, "window_size", 1024)) - provider.art_flex_core_attention_wrapper = _gemma4_flex_core_attention_wrapper + _install_gemma4_flex_core_attention_wrapper(provider) provider.art_flex_sliding_windows = (window_size,) provider.art_flex_head_dims_by_window = { None: int(getattr(provider, "global_head_dim", provider.kv_channels)), @@ -1006,6 +1022,13 @@ def _gemma4_sliding_window_for_layer(provider: Any, layer_number: int) -> int | return int(provider.window_size) +def _install_gemma4_flex_core_attention_wrapper(provider: Any) -> None: + def _wrapper(_config: Any, base_cls: type[Any]) -> type[Any]: + return _gemma4_flex_core_attention_wrapper(provider, base_cls) + + provider.art_flex_core_attention_wrapper = _wrapper + + def _gemma4_flex_core_attention_wrapper( provider: Any, base_cls: type[Any] ) -> type[Any]: @@ -1017,6 +1040,11 @@ def __init__( *args: Any, **kwargs: Any, ) -> None: + compile_crash_config = getattr( + provider, "art_flex_compile_crash_config", None + ) + if compile_crash_config is not None: + config.art_flex_compile_crash_config = compile_crash_config super().__init__(config, layer_number, *args, **kwargs) self.art_sliding_window = _gemma4_sliding_window_for_layer( provider,