Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,8 @@ class _HubKernelConfig:
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_varlen_func",
-wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_varlen_forward",
-wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_varlen_backward",
+wrapped_forward_attr="flash_attn_interface._flash_attn_varlen_forward",
+wrapped_backward_attr="flash_attn_interface._flash_attn_varlen_backward",
version=1,
),
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
Expand Down Expand Up @@ -1325,8 +1325,8 @@ def _flash_varlen_attention_hub_forward_op(
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
-"Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_forward` and "
-"`_wrapped_flash_attn_varlen_backward` for context parallel execution."
+"Flash attention varlen hub kernels must expose `_flash_attn_varlen_forward` and "
+"`_flash_attn_varlen_backward` for context parallel execution."
)

if scale is None:
Expand Down Expand Up @@ -1419,7 +1419,7 @@ def _flash_varlen_attention_hub_backward_op(
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
-"Flash attention varlen hub kernels must expose `_wrapped_flash_attn_varlen_backward` "
+"Flash attention varlen hub kernels must expose `_flash_attn_varlen_backward` "
"for context parallel execution."
)

Expand Down
Loading