Fix Hopper tests (#2242)

This commit is contained in:
Driss Guessous
2026-02-08 09:25:01 -08:00
committed by GitHub
parent 2a8d39c540
commit 72c7ba484d
4 changed files with 18 additions and 1 deletions
+7
View File
@@ -604,6 +604,13 @@ def _flash_attn_bwd(
AtomLayoutMdQ = 1
cluster_size = 1
assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x"
is_varlen = (
cu_seqlens_q is not None
or cu_seqlens_k is not None
or seqused_q is not None
or seqused_k is not None
)
assert not is_varlen, "varlen backward is not yet supported on sm90"
else:
m_block_size = 128
n_block_size = 128
+1
View File
@@ -709,6 +709,7 @@ def test_flash_attn_varlen_output(
and not attention_chunk != 0
and dv == d
and not has_learnable_sink
and not IS_SM90
# and False
):
g_unpad = torch.randn_like(out_unpad)
@@ -26,6 +26,7 @@ from flash_attn.cute.interface import (
flash_attn_varlen_func,
flash_attn_combine,
_flash_attn_bwd,
_get_device_capability,
)
@@ -407,6 +408,11 @@ def test_flash_attn_varlen_output(
local = local_enum > 0
if local and causal:
pytest.skip()
is_sm90 = _get_device_capability() == 9
if is_sm90 and local:
pytest.xfail("bwd local attention not supported on sm90")
if is_sm90 and deterministic:
pytest.xfail("bwd deterministic not supported on sm90")
if (
causal or local
): # Right now reference only supports causal attention with seqlen_k == seqlen_q
@@ -645,6 +651,7 @@ def test_flash_attn_varlen_output(
and not attention_chunk != 0
and dv == d
and not has_learnable_sink
and not is_sm90
# and False
):
g_unpad = torch.randn_like(out_unpad)
+3 -1
View File
@@ -277,6 +277,7 @@ def _run_mask_test(
# SM90 block-sparse backward expects BlockMask granularity (128, 128) regardless of fwd tiling.
sparse_tile_m_bwd = sparse_tile_m
tile_n_bwd = tile_n
if COMPUTE_CAPABILITY == 9 and use_block_sparsity and (sparse_tile_m, tile_n) != (128, 128):
bm_bwd = create_block_mask(
mask_mod_flex,
@@ -301,6 +302,7 @@ def _run_mask_test(
*_,
) = bm_bwd.as_tuple()
sparse_tile_m_bwd = 128
tile_n_bwd = 128
softmax_scale = 1.0 / math.sqrt(headdim)
@@ -323,7 +325,7 @@ def _run_mask_test(
mask_block_idx=q_mask_idx,
full_block_cnt=full_q_cnt,
full_block_idx=full_q_idx,
block_size=(sparse_tile_m_bwd, tile_n),
block_size=(sparse_tile_m_bwd, tile_n_bwd),
)
if use_block_sparsity
else None