Skip to content
Open
19 changes: 13 additions & 6 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
fused_attn,
run_length_fill,
make_swa_mask,
check_set_window_size,
SequenceDescriptor,
CPStrategy,
ReorderStrategy,
Expand Down Expand Up @@ -1065,10 +1066,13 @@ def _get_swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tu

cuDNN < 9.2: skip (no SWA support).
cuDNN >= 9.2: left-only window (s_kv // 10, 0).
cuDNN >= 9.6: bidirectional window (s_kv // 10, s_kv // 10 + 5) for the mask types whose
bidirectional fused dispatch is meaningful here (NO_MASK, PADDING_MASK).
Other mask types keep the left-only window: causal-family masks would
collapse (W, W) -> (W, 0), hence not tested here.
cuDNN >= 9.6: bidirectional asymmetric window (s_kv // 10, s_kv // 10 + 5) for the mask
types whose bidirectional fused dispatch is meaningful here (NO_MASK,
PADDING_MASK). Other mask types keep the left-only window: causal-family
masks would collapse (W, W) -> (W, 0), hence not tested here.

The chosen ``(left, right)`` is then routed through :func:`check_set_window_size`, which
is the same canonicalizer the production modules call at construction time.
"""
cudnn_version = get_cudnn_version()
if cudnn_version < 90200:
Expand All @@ -1080,8 +1084,11 @@ def _get_swa_window_size_for_test(s_kv: int, attn_mask_type: AttnMaskType) -> Tu
AttnMaskType.NO_MASK,
AttnMaskType.PADDING_MASK,
):
return (left_window_size, right_window_size)
return (left_window_size, 0)
candidate = (left_window_size, right_window_size)
else:
candidate = (left_window_size, 0)
# Validate the window size against the contract and return the canonicalized value.
return check_set_window_size(attn_mask_type, candidate)


@pytest.mark.parametrize(
Expand Down
128 changes: 112 additions & 16 deletions transformer_engine/jax/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,88 @@ def canonicalize_attn_mask_type(attn_mask_type: str):
)


def check_set_window_size(
attn_mask_type: Union[str, AttnMaskType],
window_size: Optional[Tuple[int, int]] = None,
*,
warn: bool = True,
) -> Tuple[int, int]:
"""Check if sliding window size is compliant with attention mask type.
If not, set it to the appropriate size.

attn_mask_type | window_size
----------------------------------------------------------------------------
no_mask, padding | (-1, -1) or (>=0, >=0)
causal, padding_causal | (-1, 0) or (>=0, 0)
causal_bottom_right, padding_causal_bottom_right | (-1, 0) or (>=0, 0)

``(-1, -1)`` and ``(-1, 0)`` are sentinels meaning "no window" (full attention) and
"infinite-left causal" respectively. Negative entries are otherwise rejected.

Args:
attn_mask_type: Either a canonical ``attn_mask_type`` string (e.g. ``"no_mask"``,
``"padding"``, ``"causal"``, ``"padding_causal"``, ``"causal_bottom_right"``,
``"padding_causal_bottom_right"``) or an :class:`AttnMaskType` enum value.
window_size: ``(left, right)`` tuple, or ``None`` to use the natural default for the
given mask type.
warn: When ``True`` (default), emit a :class:`UserWarning` whenever the supplied
``window_size`` is coerced to the canonical form for ``attn_mask_type``.
Set to ``False`` for internal call sites that do not need to emit warnings.
Hard-error branches (negative bounds outside the recognized sentinels) are not
gated by this flag and always raise.

Returns:
The canonicalized ``(left, right)`` tuple.
"""
if isinstance(attn_mask_type, str):
attn_mask_type_enum = canonicalize_attn_mask_type(attn_mask_type)
attn_mask_type_str = attn_mask_type
else:
attn_mask_type_enum = attn_mask_type
attn_mask_type_str = attn_mask_type.name

orig_window_size = window_size
if attn_mask_type_enum.is_causal():
if orig_window_size is None:
window_size = (-1, 0)
# Coerce the right side window to 0.
elif orig_window_size == (-1, -1) or (orig_window_size[0] >= 0 and orig_window_size[1] > 0):
window_size = (orig_window_size[0], 0)
if warn:
warnings.warn(
"window_size should be (-1, 0) or (>=0, 0) for "
f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; "
f"coercing to {window_size}."
)
# Assert if invalid window size is provided.
elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0):
raise AssertionError(
"window_size should be (-1, 0) or (>=0, 0) for "
f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}."
)
elif attn_mask_type_enum in (AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK):
if orig_window_size is None:
window_size = (-1, -1)
# Coerce the right side window to -1.
elif orig_window_size == (-1, 0):
window_size = (-1, -1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should do this, or go the other direction and change the mask? Technically, this could be a valid combination, right? no_mask/padding + swa(left, 0) -> essentially causal + swa(left, 0)?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cyanguwa I did think a bit about this especially since I had seen it in the PyT check_set_window_size() too.

Because there is a lot of downstream branching on the mask type in the primitives, and none really on the SWA window size, I'd prefer not to coerce the mask and instead coerce the SWA window size, with a warning. Coercing the mask can also make debugging difficult: we already change the masks internally for some of the CP patterns without the user being aware, so allowing the mask to be changed in yet another place just increases the chances of something going wrong.

Also, a smaller concern: if the mask is indeed coerced, it can give the user an incorrect understanding of the support when they call is_fused_attn_kernel_available(). They may be asking about padding masks but get an answer for padding_causal instead (the same could be argued for the SWA window, but I believe the mask has more ramifications in general).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make "no_mask/padding + swa(left,0)" officially supported then, as a first-class citizen, just like any other combination, including "causal/padding_causal/BRCM/PBRCM + swa(left,0)", which is what it's equivalent to anyway?

if warn:
warnings.warn(
"window_size should be (-1, -1) or (>=0, >=0) for "
f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}; "
f"coercing to {window_size}."
)
# Assert if invalid window size is provided.
elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0):
raise AssertionError(
"window_size should be (-1, -1) or (>=0, >=0) for "
f"attn_mask_type={attn_mask_type_str}, got {orig_window_size}."
)
else:
raise AssertionError(f"Invalid attn_mask_type: {attn_mask_type_str}")
return window_size


def is_fused_attn_kernel_available(
is_training,
q_dtype,
Expand All @@ -343,7 +425,9 @@ def is_fused_attn_kernel_available(
"""
To check whether the fused attention kernel is supported
"""
window_size_tuple = (-1, -1) if window_size is None else window_size
# Canonicalize at the CPP-extension boundary so direct callers see the same
# canonical encoding as users of DPA/MHA API to ensure consistency.
window_size_tuple = check_set_window_size(attn_mask_type, window_size)
Comment thread
KshitijLakhani marked this conversation as resolved.

def make_helper(attn_mask_type):
return tex.FusedAttnHelper(
Expand Down Expand Up @@ -688,9 +772,9 @@ def _segment_ids_pos_to_seqlens_offsets(
segment_ids_kv,
segment_pos_q,
segment_pos_kv,
attn_mask_type,
window_size,
max_segments_per_seq,
attn_mask_type: AttnMaskType,
window_size: Tuple[int, int],
max_segments_per_seq: int,
):
"""Compute per-segment seqlens and start offsets (currently only used for THD)
Given segment-id and segment-position tensors for Q and KV,
Expand All @@ -708,21 +792,24 @@ def _segment_ids_pos_to_seqlens_offsets(
attn_mask_type: AttnMaskType. Selects the mask predicate used to decide
which positions are valid (top-left causal vs
bottom-right causal vs. padding-only)
window_size: Optional sliding-window tuple ``(left, right)`` or None
Used here only as a fast-path eligibility hint
window_size: Sliding-window tuple ``(left, right)`` of type Tuple[int, int].
Required (never ``None``); it must already have been
canonicalized by check_set_window_size.
max_segments_per_seq: maximum number of segments expected per row
Used to size the bincount / argwhere outputs

Routing (only invoked for THD qkv_layout):
1. Fast path -- ``_segment_ids_pos_to_seqlens_offsets_fast_causal_path``.
O(T) per row. Counts all segment tokens via bincount on
segment_ids and trims at most one token per segment at the
boundary. Used for:
- top-left CAUSAL / PADDING_CAUSAL with ``window_size is None``
- SWA with ``window_size == (-1, -1)`` and not bottom-right
Bottom-right causal cross-attention is excluded: the boundary
trim leaves kv_seqlen short by one per active segment, which
shifts the BRCM bottom-right alignment by one KV per Q row.
boundary. Used for any non-bottom-right mask with no finite
sliding window, i.e. ``window_size`` in
``{(-1, -1), (-1, 0)}``. ``window_size`` is guaranteed to be
non-``None`` here because it is already canonicalized by check_set_window_size.
Bottom-right causal cross-attention is excluded:
the boundary trim leaves kv_seqlen short by one per active
segment, which shifts the BRCM bottom-right alignment by one KV
per Q row.

2. Slow path -- ``_get_seqlens_offsets_thd``.
O(T * max_segments_per_seq) per row. Per-segment min/max
Expand Down Expand Up @@ -755,9 +842,11 @@ def _segment_ids_pos_to_seqlens_offsets(
# must route bottom-right masks to the slow path.

# Fast path: O(T) per row.
if (
attn_mask_type.is_causal() and not attn_mask_type.is_bottom_right() and window_size is None
) or (window_size == (-1, -1) and not attn_mask_type.is_bottom_right()):
# "No finite window" is encoded as (-1, -1) for non-causal masks and (-1, 0) for
# causal-family masks; both share window_size[0] == -1, which is therefore the
# mask-type-agnostic "no finite sliding window" sentinel.
no_finite_window = window_size[0] == -1
if no_finite_window and not attn_mask_type.is_bottom_right():
return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
)
Expand Down Expand Up @@ -825,10 +914,17 @@ def tree_unflatten(cls, aux_data, children):
return cls(*children)

def get_seqlens_and_offsets(
self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq
self,
attn_mask_type: "AttnMaskType",
qkv_layout: "QKVLayout",
window_size: Tuple[int, int],
max_segments_per_seq: int,
):
"""
Acquire the seqlens/offsets for cuDNN backend.

``window_size`` must be a ``Tuple[int, int]`` (never ``None``)
and already canonicalized by check_set_window_size.
"""
q_segment_ids, kv_segment_ids = self.segment_ids
q_segment_pos, kv_segment_pos = self.segment_pos
Expand Down
38 changes: 32 additions & 6 deletions transformer_engine/jax/cpp_extensions/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
QKVFormat,
CPStrategy,
SequenceDescriptor,
check_set_window_size,
)
from ..sharding import with_sharding_constraint_by_logical_axes, HEAD_AXES, is_mesh_available

Expand Down Expand Up @@ -2479,7 +2480,19 @@ def check_supported(self):
)

def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
"""Returns a _FusedAttnConfig for single CP step call to fused attention."""
"""Returns a _FusedAttnConfig for single CP step call to fused attention.

Ring CP overrides ``attn_mask_type`` per step (e.g. ``CAUSAL_MASK`` -> ``NO_MASK``
for off-diagonal steps where the kv chunk is fully past or fully future of the
local q chunk; see ``ring_attn_fwd_impl`` / ``ring_attn_bwd_impl``). The user's
``window_size`` is the canonical no-SWA form for the *original* mask, so we
re-canonicalize it for the per-step mask.

``warn=False`` because the user's ``window_size`` was already canonicalized
by check_set_window_size upstream. This per-step coercion results from an internal
mask switch (for ring P2P); reporting it would only confuse the user.
"""
per_step_window = check_set_window_size(attn_mask_type, self.config.window_size, warn=False)
return _FusedAttnConfig(
attn_bias_type=self.config.attn_bias_type,
attn_mask_type=attn_mask_type,
Expand All @@ -2489,7 +2502,7 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig:
dropout_probability=self.config.dropout_probability,
is_training=self.config.is_training,
max_segments_per_seq=self.config.max_segments_per_seq,
window_size=self.config.window_size,
window_size=per_step_window,
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=self.config.context_parallel_load_balanced,
cp_axis=self.config.cp_axis,
Expand Down Expand Up @@ -3149,7 +3162,11 @@ def compute(config):
config=config,
)

if config.window_size != (-1, -1):
# Trigger striped-window adjustment only when there is a finite SWA.
# window_size[0] == -1 is the unified "no finite window" sentinel that
# covers both the non-causal (-1, -1) form and the causal-family (-1, 0)
# form produced by check_set_window_size.
if config.window_size[0] != -1:
kv_src_rank = (cp_size + cp_rank - idx) % cp_size
# Note: all inputs of adjust_cp_striped_window_size should be host values
cp_striped_window_size = adjust_cp_striped_window_size(
Expand Down Expand Up @@ -3302,7 +3319,10 @@ def compute(config):
)
return dq_per_step, dkv_per_step, dbias_per_step

if config.window_size != (-1, -1):
# See fwd path above: window_size[0] != -1 is the unified "finite SWA"
# test; it is False for both the non-causal (-1, -1) and causal-family
# (-1, 0) no-window forms produced by check_set_window_size.
if config.window_size[0] != -1:
kv_src_rank = (cp_size + cp_rank - idx) % cp_size
# Note: all inputs of adjust_cp_striped_window_size should be host values
cp_striped_window_size = adjust_cp_striped_window_size(
Expand Down Expand Up @@ -3486,7 +3506,10 @@ def fused_attn_fwd(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size,
# Canonicalize at the CPP-extension boundary so every downstream primitive this
# function dispatches to (default fused-attn and the CP all-gather/ring variants)
# sees the same canonical encoding as the DPA/MHA API.
window_size=check_set_window_size(attn_mask_type, window_size),
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
Expand Down Expand Up @@ -3661,7 +3684,10 @@ def fused_attn_bwd(
dropout_probability=dropout_probability,
is_training=is_training,
max_segments_per_seq=max_segments_per_seq,
window_size=(-1, -1) if window_size is None else window_size,
# Canonicalize at the CPP-extension boundary so every downstream primitive this
# function dispatches to (default fused-attn and the CP all-gather/ring variants)
# sees the same canonical encoding as the DPA/MHA API.
window_size=check_set_window_size(attn_mask_type, window_size),
bottom_right_diagonal=attn_mask_type.is_bottom_right(),
context_parallel_load_balanced=context_parallel_causal_load_balanced,
cp_axis=_maybe_context_parallel_axis(context_parallel_axis),
Expand Down
Loading
Loading