Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/art/megatron/flex_attn/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,17 @@ def get_dense_compiled_flex_attention(
head_dim_v: int,
triton_num_stages_2_head_dims: tuple[int, ...] = (),
) -> Any:
if backend == _FORCED_FLEX_BACKEND:
return dense_compiled_flex_attention
if backend == "FLASH":
return flash_dense_compiled_flex_attention
if _needs_triton_num_stages_2(
backend=backend,
head_dim=head_dim,
head_dim_v=head_dim_v,
triton_num_stages_2_head_dims=triton_num_stages_2_head_dims,
):
return triton_num_stages_2_dense_compiled_flex_attention
if backend == _FORCED_FLEX_BACKEND:
return dense_compiled_flex_attention
if backend == "FLASH":
return flash_dense_compiled_flex_attention
return triton_dense_compiled_flex_attention


Expand All @@ -215,17 +215,17 @@ def get_sparse_compiled_flex_attention(
triton_num_stages_2_head_dims: tuple[int, ...] = (),
) -> Any:
del family_key
if backend == _FORCED_FLEX_BACKEND:
return sparse_compiled_flex_attention
if backend == "FLASH":
return flash_sparse_compiled_flex_attention
if _needs_triton_num_stages_2(
backend=backend,
head_dim=head_dim,
head_dim_v=head_dim_v,
triton_num_stages_2_head_dims=triton_num_stages_2_head_dims,
):
return triton_num_stages_2_sparse_compiled_flex_attention
if backend == _FORCED_FLEX_BACKEND:
return sparse_compiled_flex_attention
if backend == "FLASH":
return flash_sparse_compiled_flex_attention
return triton_sparse_compiled_flex_attention


Expand Down
32 changes: 30 additions & 2 deletions src/art/megatron/model_support/handlers/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
# google/gemma-4-31B-it: Triton flex attention raises "No valid triton
# configs" for global attention head_dim=512 with backend-only options.
("dense", 60, 5376, 32, 256, 512, 4),
# google/gemma-4-26B-A4B-it hits the same Triton resource limit on global
# attention head_dim=512 with backend-only options.
("moe", 30, 2816, 16, 256, 512, 2),
}
_ART_MOE_EXPERT_KEY_RE = re.compile(
r"^(?P<prefix>.*\.mlp\.experts)\.(?P<expert>\d+)\."
Expand Down Expand Up @@ -130,7 +133,7 @@ def configure_provider_for_runtime(self, provider: Any) -> None:
_patch_gemma4_rotary_for_hf_proportional()
_patch_gemma4_qkv_for_hf_tied_value()
window_size = int(getattr(provider, "window_size", 1024))
provider.art_flex_core_attention_wrapper = _gemma4_flex_core_attention_wrapper
_install_gemma4_flex_core_attention_wrapper(provider)
provider.art_flex_sliding_windows = (window_size,)
provider.art_flex_head_dims_by_window = {
None: int(getattr(provider, "global_head_dim", provider.kv_channels)),
Expand Down Expand Up @@ -349,6 +352,19 @@ def compile_workaround_config(
disable_compile=False,
)

def flex_attention_compile_crash_config(
    self,
    provider: Any,
) -> FlexAttentionCompileCrashConfig:
    """Return the flex-attention compile-crash config for ``provider``.

    The provider's compile-crash signature is looked up in
    ``_GEMMA4_TRITON_NUM_STAGES_2_SIGNATURES``. On a match, the returned
    config lists ``provider.global_head_dim`` (coerced to ``int``) in
    ``triton_num_stages_2_head_dims``; otherwise a default-constructed
    config is returned.
    """
    signature = _gemma4_compile_crash_signature(provider)
    if signature not in _GEMMA4_TRITON_NUM_STAGES_2_SIGNATURES:
        # Unlisted signature: no head dims are flagged.
        return FlexAttentionCompileCrashConfig()

    global_head_dim = int(provider.global_head_dim)
    return FlexAttentionCompileCrashConfig(
        triton_num_stages_2_head_dims=(global_head_dim,)
    )


GEMMA4_MOE_HANDLER = Gemma4MoeHandler()

Expand All @@ -364,7 +380,7 @@ def configure_provider_for_runtime(self, provider: Any) -> None:
_patch_gemma4_rotary_for_hf_proportional()
_patch_gemma4_qkv_for_hf_tied_value()
window_size = int(getattr(provider, "window_size", 1024))
provider.art_flex_core_attention_wrapper = _gemma4_flex_core_attention_wrapper
_install_gemma4_flex_core_attention_wrapper(provider)
provider.art_flex_sliding_windows = (window_size,)
provider.art_flex_head_dims_by_window = {
None: int(getattr(provider, "global_head_dim", provider.kv_channels)),
Expand Down Expand Up @@ -1006,6 +1022,13 @@ def _gemma4_sliding_window_for_layer(provider: Any, layer_number: int) -> int |
return int(provider.window_size)


def _install_gemma4_flex_core_attention_wrapper(provider: Any) -> None:
    """Store a provider-bound wrapper factory on ``provider``.

    The callable assigned to ``provider.art_flex_core_attention_wrapper``
    accepts ``(config, base_cls)``, ignores the config argument, and
    returns ``_gemma4_flex_core_attention_wrapper(provider, base_cls)``.
    """

    def _provider_bound_wrapper(_config: Any, base_cls: type[Any]) -> type[Any]:
        # The config passed by the caller is unused; the wrapped class is
        # built from the closed-over provider instead.
        wrapped_cls = _gemma4_flex_core_attention_wrapper(provider, base_cls)
        return wrapped_cls

    provider.art_flex_core_attention_wrapper = _provider_bound_wrapper


def _gemma4_flex_core_attention_wrapper(
provider: Any, base_cls: type[Any]
) -> type[Any]:
Expand All @@ -1017,6 +1040,11 @@ def __init__(
*args: Any,
**kwargs: Any,
) -> None:
compile_crash_config = getattr(
provider, "art_flex_compile_crash_config", None
)
if compile_crash_config is not None:
config.art_flex_compile_crash_config = compile_crash_config
super().__init__(config, layer_number, *args, **kwargs)
self.art_sliding_window = _gemma4_sliding_window_for_layer(
provider,
Expand Down
Loading