mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-06-30 21:07:55 -04:00
Fix Hopper tests (#2242)
This commit is contained in:
@@ -604,6 +604,13 @@ def _flash_attn_bwd(
|
||||
AtomLayoutMdQ = 1
|
||||
cluster_size = 1
|
||||
assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x"
|
||||
is_varlen = (
|
||||
cu_seqlens_q is not None
|
||||
or cu_seqlens_k is not None
|
||||
or seqused_q is not None
|
||||
or seqused_k is not None
|
||||
)
|
||||
assert not is_varlen, "varlen backward is not yet supported on sm90"
|
||||
else:
|
||||
m_block_size = 128
|
||||
n_block_size = 128
|
||||
|
||||
@@ -709,6 +709,7 @@ def test_flash_attn_varlen_output(
|
||||
and not attention_chunk != 0
|
||||
and dv == d
|
||||
and not has_learnable_sink
|
||||
and not IS_SM90
|
||||
# and False
|
||||
):
|
||||
g_unpad = torch.randn_like(out_unpad)
|
||||
|
||||
@@ -26,6 +26,7 @@ from flash_attn.cute.interface import (
|
||||
flash_attn_varlen_func,
|
||||
flash_attn_combine,
|
||||
_flash_attn_bwd,
|
||||
_get_device_capability,
|
||||
)
|
||||
|
||||
|
||||
@@ -407,6 +408,11 @@ def test_flash_attn_varlen_output(
|
||||
local = local_enum > 0
|
||||
if local and causal:
|
||||
pytest.skip()
|
||||
is_sm90 = _get_device_capability() == 9
|
||||
if is_sm90 and local:
|
||||
pytest.xfail("bwd local attention not supported on sm90")
|
||||
if is_sm90 and deterministic:
|
||||
pytest.xfail("bwd deterministic not supported on sm90")
|
||||
if (
|
||||
causal or local
|
||||
): # Right now reference only supports causal attention with seqlen_k == seqlen_q
|
||||
@@ -645,6 +651,7 @@ def test_flash_attn_varlen_output(
|
||||
and not attention_chunk != 0
|
||||
and dv == d
|
||||
and not has_learnable_sink
|
||||
and not is_sm90
|
||||
# and False
|
||||
):
|
||||
g_unpad = torch.randn_like(out_unpad)
|
||||
|
||||
@@ -277,6 +277,7 @@ def _run_mask_test(
|
||||
|
||||
# SM90 block-sparse backward expects BlockMask granularity (128, 128) regardless of fwd tiling.
|
||||
sparse_tile_m_bwd = sparse_tile_m
|
||||
tile_n_bwd = tile_n
|
||||
if COMPUTE_CAPABILITY == 9 and use_block_sparsity and (sparse_tile_m, tile_n) != (128, 128):
|
||||
bm_bwd = create_block_mask(
|
||||
mask_mod_flex,
|
||||
@@ -301,6 +302,7 @@ def _run_mask_test(
|
||||
*_,
|
||||
) = bm_bwd.as_tuple()
|
||||
sparse_tile_m_bwd = 128
|
||||
tile_n_bwd = 128
|
||||
|
||||
softmax_scale = 1.0 / math.sqrt(headdim)
|
||||
|
||||
@@ -323,7 +325,7 @@ def _run_mask_test(
|
||||
mask_block_idx=q_mask_idx,
|
||||
full_block_cnt=full_q_cnt,
|
||||
full_block_idx=full_q_idx,
|
||||
block_size=(sparse_tile_m_bwd, tile_n),
|
||||
block_size=(sparse_tile_m_bwd, tile_n_bwd),
|
||||
)
|
||||
if use_block_sparsity
|
||||
else None
|
||||
|
||||
Reference in New Issue
Block a user