diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 03d730e..8d93660 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -604,6 +604,13 @@ def _flash_attn_bwd( AtomLayoutMdQ = 1 cluster_size = 1 assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x" + is_varlen = ( + cu_seqlens_q is not None + or cu_seqlens_k is not None + or seqused_q is not None + or seqused_k is not None + ) + assert not is_varlen, "varlen backward is not yet supported on sm90" else: m_block_size = 128 n_block_size = 128 diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 1c2088d..c1f227d 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -709,6 +709,7 @@ def test_flash_attn_varlen_output( and not attention_chunk != 0 and dv == d and not has_learnable_sink + and not IS_SM90 # and False ): g_unpad = torch.randn_like(out_unpad) diff --git a/tests/cute/test_flash_attn_race_condition.py b/tests/cute/test_flash_attn_race_condition.py index c2a6490..cadb4a9 100644 --- a/tests/cute/test_flash_attn_race_condition.py +++ b/tests/cute/test_flash_attn_race_condition.py @@ -26,6 +26,7 @@ from flash_attn.cute.interface import ( flash_attn_varlen_func, flash_attn_combine, _flash_attn_bwd, + _get_device_capability, ) @@ -407,6 +408,11 @@ def test_flash_attn_varlen_output( local = local_enum > 0 if local and causal: pytest.skip() + is_sm90 = _get_device_capability() == 9 + if is_sm90 and local: + pytest.xfail("bwd local attention not supported on sm90") + if is_sm90 and deterministic: + pytest.xfail("bwd deterministic not supported on sm90") if ( causal or local ): # Right now reference only supports causal attention with seqlen_k == seqlen_q @@ -645,6 +651,7 @@ def test_flash_attn_varlen_output( and not attention_chunk != 0 and dv == d and not has_learnable_sink + and not is_sm90 # and False ): g_unpad = torch.randn_like(out_unpad) diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 37a68c3..438ac8a 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -277,6 +277,7 @@ def _run_mask_test( # SM90 block-sparse backward expects BlockMask granularity (128, 128) regardless of fwd tiling. sparse_tile_m_bwd = sparse_tile_m + tile_n_bwd = tile_n if COMPUTE_CAPABILITY == 9 and use_block_sparsity and (sparse_tile_m, tile_n) != (128, 128): bm_bwd = create_block_mask( mask_mod_flex, @@ -301,6 +302,7 @@ def _run_mask_test( *_, ) = bm_bwd.as_tuple() sparse_tile_m_bwd = 128 + tile_n_bwd = 128 softmax_scale = 1.0 / math.sqrt(headdim) @@ -323,7 +325,7 @@ def _run_mask_test( mask_block_idx=q_mask_idx, full_block_cnt=full_q_cnt, full_block_idx=full_q_idx, - block_size=(sparse_tile_m_bwd, tile_n), + block_size=(sparse_tile_m_bwd, tile_n_bwd), ) if use_block_sparsity else None