[Cute, Bwd, Sm100] Add varlen for sm100 bwd (#2150)

* varlen bwd with rounded padded offsets

* fix mha

* change offset mode to round down multiple

* enable varlen bwd tests

* enable deterministic mode

* fix deadlock and switch mha to no postprocess

* reenable tests

* fix lint error

* use head swizzle/spt for deterministic, update tests

* change padding offset based on arch

* rebase and update interface, tests

* add arch dispatch for padded offset q to postprocess

* address comments

* remove tile sizes from seqlen info class vars
This commit is contained in:
jayhshah
2026-01-09 15:24:29 -08:00
committed by GitHub
parent 6dd7e742df
commit ed6a82f050
12 changed files with 787 additions and 425 deletions
+7 -4
View File
@@ -325,9 +325,9 @@ for headdim in [128]:
else:
page_table = None
# for causal in [False, True]:
for causal in [True]:
print(f"\n### {headdim = }, {causal = }, {seqlen = } ###")
for causal in [False, True]:
# for causal in [True]:
print(f"\n### {headdim = }, {causal = }, {seqlen = }, {batch_size = }, {nheads = }, {nheads_kv = }, {varlen = }, {deterministic = } ###")
nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size)
if cudnn is not None:
# if False:
@@ -395,7 +395,10 @@ for headdim in [128]:
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True)
# benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy')
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_python is not None and has_backward:
_, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav2 python')
if not varlen:
_, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python')
else:
_, m1b_py = benchmark_backward(flash_attn_varlen_func_python, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python')
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None:
# if False:
+1
View File
@@ -123,6 +123,7 @@ def cute_compile_patched(*args, **kwargs):
pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
return output
def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True):
"""Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1."""
tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)
+12 -278
View File
@@ -233,13 +233,15 @@ class FlashAttentionBackwardPostprocess:
TileScheduler = SingleTileVarlenScheduler
num_head = mdQ.shape[1]
num_batch = mCuSeqlensQ.shape[0] - 1
num_block = cute.ceil_div(mdQ.shape[0], self.tile_m)
else:
TileScheduler = SingleTileScheduler
num_head = mdQ.shape[2]
num_batch = mdQ.shape[0]
num_block = cute.ceil_div(mdQ.shape[1], self.tile_m)
tile_sched_args = TileSchedulerArguments(
num_block=cute.ceil_div(mdQ.shape[1], self.tile_m),
num_block=num_block,
num_head=num_head,
num_batch=num_batch,
num_splits=1,
@@ -318,7 +320,7 @@ class FlashAttentionBackwardPostprocess:
tile_scheduler = TileScheduler.create(tile_sched_params)
work_tile = tile_scheduler.initial_work_tile_info()
m_block, num_head, batch_size, _ = work_tile.tile_idx
m_block, head_idx, batch_idx, _ = work_tile.tile_idx
if work_tile.is_valid_tile:
# ///////////////////////////////////////////////////////////////////////////////
@@ -326,7 +328,7 @@ class FlashAttentionBackwardPostprocess:
# ///////////////////////////////////////////////////////////////////////////////
seqlen = SeqlenInfoQK.create(
batch_size,
batch_idx,
mdQ.shape[1],
0,
mCuSeqlensQ=mCuSeqlensQ,
@@ -335,14 +337,16 @@ class FlashAttentionBackwardPostprocess:
mSeqUsedK=None,
)
if const_expr(not seqlen.has_cu_seqlens_q):
mdQ_cur = mdQ[batch_size, None, num_head, None]
mdQaccum_cur = mdQaccum[batch_size, num_head, None]
mdQ_cur = mdQ[batch_idx, None, head_idx, None]
mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
head_dim = mdQ.shape[3]
else:
padded_offset_q = seqlen.offset_q + batch_size * self.tile_m
mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, num_head, None])
padded_offset_q = seqlen.offset_q + batch_idx * self.tile_m
if cutlass.const_expr(self.arch >= 90):
padded_offset_q = padded_offset_q // self.tile_m * self.tile_m
mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None])
mdQaccum_cur = cute.domain_offset(
(padded_offset_q * self.tile_hdim,), mdQaccum[num_head, None]
(padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None]
)
head_dim = mdQ.shape[2]
@@ -457,273 +461,3 @@ class FlashAttentionBackwardPostprocess:
tdQgdQ[None, rest_m, None],
pred=tdQpdQ[None, rest_m, None],
)
class FlashAttentionBackwardPostprocess_sm100(FlashAttentionBackwardPostprocess):
def __init__(
self,
dtype: Type[cutlass.Numeric],
head_dim: int,
tile_m: int = 128,
num_threads: int = 256,
AtomLayoutMdQ: int = 1,
dQ_swapAB: bool = False,
):
super().__init__(
dtype=dtype,
head_dim=head_dim,
arch=90, # tmp dummy placement for now
tile_m=tile_m,
num_threads=num_threads,
AtomLayoutMdQ=AtomLayoutMdQ,
dQ_swapAB=dQ_swapAB,
)
def _setup_attributes(self):
self.num_stages = self.tile_hdim // 32 # 2 for D=64, 4 for D=128
self.sdQaccum_layout = cute.make_layout(
shape=(self.tile_m * 32, 2), stride=(1, self.tile_m * 32)
)
self.epi_tile_q = (self.tile_m, self.tile_hdim)
self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi(
self.dtype,
LayoutEnum.ROW_MAJOR,
self.epi_tile_q,
1,
)
@cute.jit
def __call__(
self,
mdQaccum: cute.Tensor,
mdQ: cute.Tensor,
scale: cutlass.Float32,
stream: cuda.CUstream,
):
# Assume all strides are divisible by 128 bits except the last stride
new_stride = lambda t: (
*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
t.stride[-1],
)
mdQaccum, mdQ = [
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
for t in (mdQaccum, mdQ)
]
# (b, h, s*d) -> (s*d, h, b)
mdQaccum = cute.make_tensor(mdQaccum.iterator, cute.select(mdQaccum.layout, mode=[2, 1, 0]))
# (b, s, h, d) -> (s, d, h, b)
mdQ = cute.make_tensor(mdQ.iterator, cute.select(mdQ.layout, mode=[1, 3, 2, 0]))
self._setup_attributes()
grid_dim = [
cute.ceil_div(mdQ.shape[0], self.tile_m),
cute.size(mdQ.shape[2]),
cute.size(mdQ.shape[3]),
]
cta_group = tcgen05.CtaGroup.ONE
self.mma_tiler_dsk = (self.tile_m, self.tile_hdim)
dS_major_mode = tcgen05.OperandMajorMode.MN
kt_major_mode_dsq = tcgen05.OperandMajorMode.MN
tiled_mma_dsk = sm100_utils_basic.make_trivial_tiled_mma(
cutlass.BFloat16,
dS_major_mode,
kt_major_mode_dsq,
cutlass.Float32,
cta_group,
self.mma_tiler_dsk,
)
dQ_cta_v_layout = cute.composition(cute.make_identity_layout(mdQ.shape), self.mma_tiler_dsk)
tma_store_op = cpasync.CopyBulkTensorTileS2GOp()
tma_atom_dQ, tma_tensor_dQ = cute.nvgpu.cpasync.make_tiled_tma_atom(
tma_store_op,
mdQ,
cute.select(self.sdQ_layout, mode=[0, 1]),
dQ_cta_v_layout,
)
buffer_align_bytes = 1024
@cute.struct
class SharedStorage:
sdQaccum: cute.struct.Align[
cute.struct.MemRange[cutlass.Float32, cute.cosize(self.sdQaccum_layout)],
128,
]
sdQ: cute.struct.Align[
cute.struct.MemRange[self.dtype, cute.cosize(self.sdQ_layout)],
buffer_align_bytes,
]
self.shared_storage = SharedStorage
self.kernel(
mdQaccum,
tma_tensor_dQ,
tma_atom_dQ,
self.sdQaccum_layout,
self.sdQ_layout,
tiled_mma_dsk,
scale,
).launch(
grid=grid_dim,
block=[self.num_threads, 1, 1],
smem=self.shared_storage.size_in_bytes(),
stream=stream,
)
@cute.kernel
def kernel(
self,
mdQaccum: cute.Tensor,
mdQ: cute.Tensor,
tma_atom_dQ: cute.CopyAtom,
sdQaccum_layout: cute.Layout,
sdQ_layout: cute.ComposedLayout,
tiled_mma_dsk: cute.TiledMma,
scale: cutlass.Float32,
):
tidx = cute.arch.thread_idx()[0]
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
m_block, head_idx, batch_idx = cute.arch.block_idx()
# SMEM
smem = cutlass.utils.SmemAllocator()
storage = smem.allocate(self.shared_storage)
swz128 = cute.make_swizzle(3, 4, 3)
sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout, swizzle=swz128)
sdQ = storage.sdQ.get_tensor(sdQ_layout.outer, swizzle=sdQ_layout.inner)
mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
mdQ_cur = mdQ[None, None, head_idx, batch_idx]
thr_mma_dsk = tiled_mma_dsk.get_slice(tidx)
dQacc_shape = thr_mma_dsk.partition_shape_C(self.mma_tiler_dsk[:2])
tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape)
tdQtdQ = cute.make_tensor(tdQtdQ.iterator, tdQtdQ.layout)
tmem_ld_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32
)
tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tdQtdQ)
thr_tmem_ld = tiled_tmem_ld.get_slice(tidx)
cdQ = cute.make_identity_tensor((self.mma_tiler_dsk[0], self.mma_tiler_dsk[1]))
tdQcdQ = thr_mma_dsk.partition_C(cdQ)
tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout)
tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor)
gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,))
num_reduce_warps = 4
num_reduce_threads = cute.arch.WARP_SIZE * num_reduce_warps
atom_universal_copy = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=128
)
tiler_mn, layout_tv = cute.make_layout_tv(
thr_layout=cute.make_layout(shape=num_reduce_threads, stride=1),
val_layout=cute.make_layout(shape=4, stride=1),
)
G2S_tiled_copy_dQaccum = cute.make_tiled_copy(
atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn
)
smem_thr_copy_g2s = G2S_tiled_copy_dQaccum.get_slice(tidx)
# S->R
tdQrdQ_t2r = cute.make_fragment(tdQrdQ.shape, cutlass.Float32)
tiled_smem_store_s2r = cute.make_tiled_copy(
atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn
)
s2r_thr_copy_dQaccum = tiled_smem_store_s2r.get_slice(tidx)
tdQsdQ_s2r = s2r_thr_copy_dQaccum.partition_S(sdQaccum)
tdQrdQ_s2r = cute.make_tensor(tdQrdQ_t2r.iterator, tdQrdQ_t2r.shape)
# R->S
smem_copy_atom = sm100_utils_basic.get_smem_store_op(
LayoutEnum.ROW_MAJOR, self.dtype, cutlass.Float32, tiled_tmem_ld
)
tiled_smem_store_r2s = cute.make_tiled_copy(
smem_copy_atom,
layout_tv=tiled_tmem_ld.layout_dst_tv_tiled,
tiler_mn=tiled_tmem_ld.tiler_mn,
)
tdQsdQ_r2s = thr_tmem_ld.partition_D(thr_mma_dsk.partition_C(sdQ))
tdQrdQ_r2s = cute.make_fragment(tdQsdQ_r2s.shape, self.dtype)
num_stages = cute.size(tdQrdQ_t2r, mode=[1])
for stage in cutlass.range_constexpr(num_stages):
# G->S
gdQaccum_stage = cute.local_tile(
gdQaccum,
(self.tile_m * 32,),
(stage,),
)
gdQaccum_layout_g2s = cute.make_layout(shape=(self.tile_m * 32, 1), stride=(1, 0))
gdQaccum_stage_g2s = cute.make_tensor(
cute.recast_ptr(gdQaccum_stage.iterator, swizzle_=swz128), gdQaccum_layout_g2s
)
tdQgdQ = smem_thr_copy_g2s.partition_S(gdQaccum_stage_g2s)
tdQsdQ = smem_thr_copy_g2s.partition_D(sdQaccum)
cute.copy(smem_thr_copy_g2s, tdQgdQ[None, None, 0], tdQsdQ[None, None, 0])
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
)
cute.arch.barrier(barrier_id=6, number_of_threads=num_reduce_threads)
# S -> R
tdQrdQ_s2r_cpy = tdQrdQ_s2r[None, stage, None, None]
tdQsdQ_s2r_p = tdQsdQ_s2r[None, None, 0]
tdQrdQ_r2s_cpy = cute.make_tensor(
tdQrdQ_s2r_cpy.iterator, cute.make_layout(tdQsdQ_s2r_p.shape)
)
cute.copy(s2r_thr_copy_dQaccum, tdQsdQ_s2r_p, tdQrdQ_r2s_cpy)
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
)
cute.arch.barrier(barrier_id=7, number_of_threads=num_reduce_threads)
# R->S
tdQrdQ_r2s_cpy = cute.make_tensor(
cute.recast_ptr(tdQrdQ_r2s_cpy.iterator),
tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].shape,
)
dQ_vec = tdQrdQ_r2s_cpy.load() * scale
tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].store(dQ_vec.to(self.dtype))
cute.copy(
tiled_smem_store_r2s,
tdQrdQ_r2s[None, None, None, None, 0],
tdQsdQ_r2s[None, None, None, None, 0],
)
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
)
cute.arch.barrier(barrier_id=8, number_of_threads=num_reduce_threads)
# S-> G
gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (None, 0))
tdQsdQ, tdQgdQ = cpasync.tma_partition(
tma_atom_dQ,
0,
cute.make_layout(1),
cute.group_modes(sdQ, 0, 2),
cute.group_modes(gdQ, 0, 2),
)
cute.copy(tma_atom_dQ, tdQsdQ[None, 0], tdQgdQ[None, m_block])
+20 -18
View File
@@ -3,7 +3,7 @@
# from Cutlass C++ to Cute-DSL.
import math
import operator
from typing import Callable, Type, Optional
from typing import Callable, Type, Optional, Literal
import cuda.bindings.driver as cuda
@@ -27,6 +27,7 @@ class FlashAttentionBackwardPreprocess:
self,
dtype: Type[cutlass.Numeric],
head_dim: int,
arch: Literal[80, 90, 100],
m_block_size: int = 128,
num_threads: int = 128,
):
@@ -43,6 +44,7 @@ class FlashAttentionBackwardPreprocess:
"""
self.dtype = dtype
self.m_block_size = m_block_size
self.arch = arch
# padding head_dim to a multiple of 32 as k_block_size
hdim_multiple_of = 32
self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
@@ -213,14 +215,14 @@ class FlashAttentionBackwardPreprocess:
tile_scheduler = TileScheduler.create(tile_sched_params)
work_tile = tile_scheduler.initial_work_tile_info()
m_block, num_head, batch_size, _ = work_tile.tile_idx
m_block, head_idx, batch_idx, _ = work_tile.tile_idx
if work_tile.is_valid_tile:
# ///////////////////////////////////////////////////////////////////////////////
# Get the appropriate tiles for this thread block.
# ///////////////////////////////////////////////////////////////////////////////
seqlen = SeqlenInfoQK.create(
batch_size,
batch_idx,
mO.shape[1],
0,
mCuSeqlensQ=mCuSeqlensQ,
@@ -230,16 +232,18 @@ class FlashAttentionBackwardPreprocess:
)
if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
mO_cur = mO[batch_size, None, num_head, None]
mdO_cur = mdO[batch_size, None, num_head, None]
mdPsum_cur = mdPsum[batch_size, num_head, None]
mO_cur = mO[batch_idx, None, head_idx, None]
mdO_cur = mdO[batch_idx, None, head_idx, None]
mdPsum_cur = mdPsum[batch_idx, head_idx, None]
headdim_v = mO.shape[3]
else:
mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, num_head, None])
mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, num_head, None])
mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, head_idx, None])
mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])
padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size
mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[num_head, None])
padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size
if cutlass.const_expr(self.arch >= 90):
padded_offset_q = padded_offset_q // self.m_block_size * self.m_block_size
mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None])
headdim_v = mO.shape[2]
blkOdO_shape = (self.m_block_size, self.head_dim_padded)
@@ -268,9 +272,9 @@ class FlashAttentionBackwardPreprocess:
if cutlass.const_expr(mLSE is not None):
if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
mLSE_cur = mLSE[batch_size, num_head, None]
mLSE_cur = mLSE[batch_idx, head_idx, None]
else:
mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[num_head, None])
mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[head_idx, None])
gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,))
lse = Float32.inf
@@ -323,11 +327,10 @@ class FlashAttentionBackwardPreprocess:
# Clear dQaccum
if cutlass.const_expr(mdQaccum is not None):
if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
mdQaccum_cur = mdQaccum[batch_size, num_head, None]
mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
else:
padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size
mdQaccum_cur = cute.domain_offset(
(padded_offset_q * self.head_dim_padded,), mdQaccum[num_head, None]
(padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None]
)
# HACK: Compiler doesn't seem to recognize that padding
@@ -352,10 +355,9 @@ class FlashAttentionBackwardPreprocess:
if cutlass.const_expr(mLSE is not None):
if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
mLSElog2_cur = mLSElog2[batch_size, num_head, None]
mLSElog2_cur = mLSElog2[batch_idx, head_idx, None]
else:
padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size
mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[num_head, None])
mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[head_idx, None])
gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,))
LOG2_E = math.log2(math.e)
+128 -60
View File
@@ -25,6 +25,7 @@ from flash_attn.cute.tile_scheduler import (
TileSchedulerArguments,
SingleTileScheduler,
SingleTileLPTBwdScheduler, # noqa
SingleTileVarlenScheduler,
ParamsBase,
)
@@ -78,7 +79,7 @@ class FlashAttentionBackwardSm100:
self.tile_n = tile_n
# CTA tiler
self.cta_tiler = (tile_m, tile_n, self.tile_hdim)
self.cta_tiler = (tile_n, tile_m, self.tile_hdim)
# S = K @ Q.T
self.mma_tiler_kq = (tile_n, tile_m, self.tile_hdim)
# dP = V @ dO.T
@@ -99,7 +100,6 @@ class FlashAttentionBackwardSm100:
self.is_local = is_local
self.qhead_per_kvhead = qhead_per_kvhead
self.pack_gqa = False
self.use_tma_store = True
self.deterministic = deterministic
# Score mod and mask mod support
@@ -353,7 +353,7 @@ class FlashAttentionBackwardSm100:
self.num_epi_stages = max(1, (self.tile_hdim // 2) // self.sdKV_epi_tile[1])
self.sdKV_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages
# TODO: dK and dV could have different shapes
if const_expr(self.qhead_per_kvhead == 1):
if const_expr(not self.dKV_postprocess):
self.sdKV_layout = sm100_utils_basic.make_smem_layout_epi(
self.dk_dtype,
LayoutEnum.ROW_MAJOR,
@@ -391,9 +391,6 @@ class FlashAttentionBackwardSm100:
# Block-sparse tensors (Q direction - for iterating m_blocks per n_block):
blocksparse_tensors: Optional[BlockSparseTensors] = None,
):
assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), (
"Variable sequence length is not supported yet in FlashAttentionBackwardSm100"
)
self.q_dtype = mQ.element_type
self.k_dtype = mK.element_type
self.v_dtype = mV.element_type
@@ -405,7 +402,12 @@ class FlashAttentionBackwardSm100:
self.dv_dtype = mdV.element_type
self.ds_dtype = self.q_dtype
if const_expr(self.qhead_per_kvhead > 1):
self.is_varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None
self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None
self.use_tma_store = not (self.qhead_per_kvhead == 1 and mCuSeqlensK is not None)
self.dKV_postprocess = self.qhead_per_kvhead > 1
if const_expr(self.dKV_postprocess):
assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA"
assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA"
@@ -429,21 +431,30 @@ class FlashAttentionBackwardSm100:
)
]
layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b)
mQ, mK, mV, mdO = [utils.select(t, mode=layout_transpose) for t in (mQ, mK, mV, mdO)]
LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b)
# (b, s, n, h) --> (s, h, n, b) or (t, n, h) -> (t, h, n)
QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
mQ, mdO = [utils.select(t, mode=QO_layout_transpose) for t in (mQ, mdO)]
KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
mK, mV = [utils.select(t, mode=KV_layout_transpose) for t in (mK, mV)]
# (b, n, s) --> (s, n, b) or (n, t) --> (t, n)
LSE_dPsum_dQaccum_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
mLSE, mdPsum, mdQaccum = [
utils.select(t, mode=LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum)
]
if const_expr(self.qhead_per_kvhead == 1):
layout_dKV_transpose = layout_transpose
if const_expr(not self.dKV_postprocess):
layout_dKV_transpose = KV_layout_transpose
else:
layout_dKV_transpose = LSE_dPsum_dQaccum_transpose
mdK, mdV = [utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)]
dO_transpose = [1, 0, 2, 3] # (s, h, n, b) --> (h, s, n, b)
# (s, h, n, b) --> (h, s, n, b) or (t, h, n) --> (h, t, n)
dO_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensQ is None) else [1, 0, 2]
mdO = utils.select(mdO, mode=dO_transpose)
semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b)
# (b, n, block, stage) -> (block, stage, n, b)
semaphore_transpose = [2, 3, 1, 0]
if const_expr(self.deterministic):
assert mdQ_semaphore is not None
mdQ_semaphore = utils.select(mdQ_semaphore, mode=semaphore_transpose)
@@ -478,7 +489,7 @@ class FlashAttentionBackwardSm100:
self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
self.is_q_do_mcast = self.num_mcast_ctas_b > 1
if const_expr(self.qhead_per_kvhead == 1):
if const_expr(not self.dKV_postprocess):
self.mdK_layout_enum = LayoutEnum.from_tensor(mdK)
self.mdV_layout_enum = LayoutEnum.from_tensor(mdV)
dK_major_mode = self.mdK_layout_enum.mma_major_mode()
@@ -488,7 +499,7 @@ class FlashAttentionBackwardSm100:
if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K):
raise RuntimeError("The layout of mdV is wrong")
if const_expr(self.use_tma_store and self.qhead_per_kvhead == 1):
if const_expr(self.use_tma_store and not self.dKV_postprocess):
tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp()
tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom(
tma_copy_op_dKV,
@@ -510,7 +521,7 @@ class FlashAttentionBackwardSm100:
tma_atom_dV = None
tma_atom_dK = None
if const_expr(self.qhead_per_kvhead == 1):
if const_expr(not self.dKV_postprocess):
thr_layout_r2s_dKV = cute.make_ordered_layout((128, 1), order=(1, 0)) # 128 threads
val_layout_r2s_dKV = cute.make_ordered_layout(
(1, 128 // self.dk_dtype.width), order=(1, 0)
@@ -589,29 +600,36 @@ class FlashAttentionBackwardSm100:
self.tma_copy_bytes["dKacc"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8
# TileScheduler = SingleTileScheduler
if const_expr(self.deterministic):
if const_expr(self.is_varlen_k):
TileScheduler = SingleTileVarlenScheduler
elif const_expr(self.deterministic):
TileScheduler = SingleTileLPTBwdScheduler
else:
TileScheduler = SingleTileScheduler
# reads n_blocks right-to-left
self.spt = (self.is_causal or self.is_local) and self.deterministic
tile_sched_args = TileSchedulerArguments(
cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]),
cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), # num_blocks
cute.size(mQ.shape[2]), # num_heads = num_query_heads
cute.size(mK.shape[3]),
cute.size(mK.shape[3])
if const_expr(mCuSeqlensK is None)
else cute.size(mCuSeqlensK.shape[0] - 1), # num_batches
1, # num_splits
cute.size(mQ.shape[0]), # pass seqlen_q for seqlen_k
mQ.shape[1],
mV.shape[1],
total_q=cute.size(mQ.shape[0]),
tile_shape_mn=self.cta_tiler[:2],
cute.size(mQ.shape[0]), # pass seqlen_q or total_q for seqlen_k
mQ.shape[1], # headdim
mV.shape[1], # headdim_v
total_q=cute.size(mK.shape[0]) # pass total_k for total_q
if const_expr(mCuSeqlensK is not None)
else cute.size(mK.shape[0]) * cute.size(mK.shape[3]),
tile_shape_mn=self.cta_tiler[:2], # (tile_n, tile_m)
cluster_shape_mn=self.cluster_shape_mnk[:2],
mCuSeqlensQ=None,
mSeqUsedQ=None,
qhead_per_kvhead_packgqa=1,
mCuSeqlensQ=mCuSeqlensK,
mSeqUsedQ=mSeqUsedK,
qhead_per_kvhead_packgqa=1, # pack_gqa disabled for bwd
element_size=self.k_dtype.width // 8,
is_persistent=self.is_persistent,
is_persistent=self.is_persistent, # persistent mode not tested
lpt=self.spt,
head_swizzle=self.deterministic,
)
tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
@@ -718,6 +736,11 @@ class FlashAttentionBackwardSm100:
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
if const_expr(self.use_block_sparsity or aux_tensors is not None):
assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), (
"Variable sequence length is not supported yet for blocksparse or aux tensors in bwd"
)
self.kernel(
tma_tensor_Q,
tma_tensor_K,
@@ -733,6 +756,10 @@ class FlashAttentionBackwardSm100:
mdQ_semaphore,
mdK_semaphore,
mdV_semaphore,
mCuSeqlensQ,
mCuSeqlensK,
mSeqUsedQ,
mSeqUsedK,
tma_atom_Q,
tma_atom_K,
tma_atom_V,
@@ -794,6 +821,10 @@ class FlashAttentionBackwardSm100:
mdQ_semaphore: Optional[cute.Tensor],
mdK_semaphore: Optional[cute.Tensor],
mdV_semaphore: Optional[cute.Tensor],
mCuSeqlensQ: Optional[cute.Tensor],
mCuSeqlensK: Optional[cute.Tensor],
mSeqUsedQ: Optional[cute.Tensor],
mSeqUsedK: Optional[cute.Tensor],
tma_atom_Q: cute.CopyAtom,
tma_atom_K: cute.CopyAtom,
tma_atom_V: cute.CopyAtom,
@@ -986,7 +1017,7 @@ class FlashAttentionBackwardSm100:
)
sLSE = storage.sLSE.get_tensor(sLSE_layout)
sdPsum = storage.sdPsum.get_tensor(sdPsum_layout)
if const_expr(self.qhead_per_kvhead == 1):
if const_expr(not self.dKV_postprocess):
sdV = storage.sdO.get_tensor(
sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype
)
@@ -1054,10 +1085,12 @@ class FlashAttentionBackwardSm100:
SeqlenInfoQK.create,
seqlen_q_static=mQ.shape[0],
seqlen_k_static=mK.shape[0],
mCuSeqlensQ=None,
mCuSeqlensK=None,
mSeqUsedQ=None,
mSeqUsedK=None,
mCuSeqlensQ=mCuSeqlensQ,
mCuSeqlensK=mCuSeqlensK,
mSeqUsedQ=mSeqUsedQ,
mSeqUsedK=mSeqUsedK,
tile_m=self.tile_m,
tile_n=self.tile_n,
)
TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params)
@@ -1294,12 +1327,17 @@ class FlashAttentionBackwardSm100:
seqlen, n_block // self.cluster_shape_mnk[0]
)
head_idx_kv = head_idx // self.qhead_per_kvhead
mQ_cur = mQ[None, None, head_idx, batch_idx]
mK_cur = mK[None, None, head_idx_kv, batch_idx]
mV_cur = mV[None, None, head_idx_kv, batch_idx]
mdO_cur = mdO[None, None, head_idx, batch_idx]
mLSE_cur = mLSE[None, head_idx, batch_idx]
mPsum_cur = mdPsum[None, head_idx, batch_idx]
mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv]
mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv]
if const_expr(not seqlen.has_cu_seqlens_q):
mdO_cur = mdO[None, None, head_idx, batch_idx]
else:
mdO_cur = cute.domain_offset((0, seqlen.offset_q), mdO[None, None, head_idx])
mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2, padded=True)[None, head_idx]
mdPsum_cur = seqlen.offset_batch_Q(mdPsum, batch_idx, dim=2, padded=True)[
None, head_idx
]
gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block, 0))
tSgK = thr_mma_S.partition_A(gK)
@@ -1308,7 +1346,7 @@ class FlashAttentionBackwardSm100:
gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0))
tSgQ = thr_mma_S.partition_B(gQ)
gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))
gdPsum = cute.local_tile(mPsum_cur, (self.tile_m,), (None,))
gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))
gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None))
tdPgdO = thr_mma_dV.partition_B(gdO)
@@ -1363,7 +1401,10 @@ class FlashAttentionBackwardSm100:
)
process_tile = total_m_block_cnt > Int32(0)
else:
process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
process_tile = (
const_expr(not self.is_local and not self.is_varlen_q)
or m_block_min < m_block_max
)
if process_tile:
if const_expr(self.use_block_sparsity):
@@ -1616,7 +1657,10 @@ class FlashAttentionBackwardSm100:
process_tile = block_iter_count > Int32(0)
else:
block_iter_count = m_block_max - m_block_min
process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
process_tile = (
const_expr(not self.is_local and not self.is_varlen_q)
or m_block_min < m_block_max
)
if process_tile:
accumulate_dK = False
@@ -2055,7 +2099,10 @@ class FlashAttentionBackwardSm100:
)
process_tile = loop_count > Int32(0)
else:
process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
process_tile = (
const_expr(not self.is_local and not self.is_varlen_q)
or m_block_min < m_block_max
)
loop_count = m_block_max - m_block_min
# Mainloop
@@ -2271,6 +2318,7 @@ class FlashAttentionBackwardSm100:
batch_idx,
head_idx,
n_block,
seqlen,
thr_mma_dV,
thr_mma_dK,
tdVtdV,
@@ -2289,6 +2337,7 @@ class FlashAttentionBackwardSm100:
batch_idx,
head_idx,
n_block,
seqlen,
thr_mma_dV,
tdVtdV,
mdV_tma_tensor,
@@ -2307,6 +2356,7 @@ class FlashAttentionBackwardSm100:
batch_idx,
head_idx,
n_block,
seqlen,
thr_mma_dK,
tdKtdK,
mdK_tma_tensor,
@@ -2315,15 +2365,15 @@ class FlashAttentionBackwardSm100:
thr_copy_r2s_dKV,
pipeline_dKV,
consumer_state_dKV,
softmax_scale if const_expr(self.qhead_per_kvhead == 1) else None,
softmax_scale if const_expr(not self.dKV_postprocess) else None,
int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id
mdK_semaphore,
)
# Zero dK/dV for empty tiles (local attention, varlen Q, or block sparsity)
# When total_m_block_cnt == 0 for block sparsity, no Q tiles contribute to this KV tile
if const_expr(self.qhead_per_kvhead == 1):
if const_expr(not self.dKV_postprocess):
should_zero_dKV = False
if const_expr(self.is_local):
if const_expr(self.is_local or seqlen.has_cu_seqlens_q):
should_zero_dKV = m_block_min >= m_block_max
if const_expr(self.use_block_sparsity):
# For block sparsity, zero when no m_blocks contribute to this n_block
@@ -2338,8 +2388,8 @@ class FlashAttentionBackwardSm100:
128, # num_threads
)
gmem_thr_copy_zero_dKV = gmem_tiled_copy_zero_dKV.get_slice(dp_idx)
mdV_cur = mdV[None, None, head_idx, batch_idx]
mdK_cur = mdK[None, None, head_idx, batch_idx]
mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx]
mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx]
gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
tdKgdK = gmem_thr_copy_zero_dKV.partition_D(gdK)
@@ -2415,7 +2465,12 @@ class FlashAttentionBackwardSm100:
m_block_min, m_block_max = block_info.get_m_block_min_max(
seqlen, n_block // self.cluster_shape_mnk[0]
)
mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
if const_expr(not seqlen.has_cu_seqlens_q):
mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
else:
mdQaccum_cur = cute.domain_offset(
(seqlen.padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx]
)
gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,))
# (M * K / STAGE, STAGE, _)
gdQaccum = cute.flat_divide(
@@ -2446,7 +2501,10 @@ class FlashAttentionBackwardSm100:
)
process_tile = loop_count > Int32(0)
else:
process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
process_tile = (
const_expr(not self.is_local and not self.is_varlen_q)
or m_block_min < m_block_max
)
loop_count = m_block_max - m_block_min
# dQacc_reduce mainloop
@@ -2580,6 +2638,7 @@ class FlashAttentionBackwardSm100:
batch_idx: Int32,
head_idx: Int32,
n_block: Int32,
seqlen,
thr_mma_dV: cute.core.ThrMma,
thr_mma_dK: cute.core.ThrMma,
tdVtdV: cute.Tensor,
@@ -2596,8 +2655,8 @@ class FlashAttentionBackwardSm100:
num_wg = cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128
assert self.qhead_per_kvhead == 1, "This epilogue path is only for MHA"
mdV_cur = mdV[None, None, head_idx, batch_idx]
mdK_cur = mdK[None, None, head_idx, batch_idx]
mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx]
mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx]
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), Float32
@@ -2647,7 +2706,8 @@ class FlashAttentionBackwardSm100:
tdVgdV_r2g_p = thr_tmem_ld_dV.partition_D(tdVgdV)
tdVgdV_r2g = self.split_wg(tdVgdV_r2g_p, wg_idx, num_wg)
cute.copy(tiled_gmem_store_dV, tdVrdV_r2s, tdVgdV_r2g)
if tidx < seqlen.seqlen_k - self.tile_n * n_block:
cute.copy(tiled_gmem_store_dV, tdVrdV_r2s, tdVgdV_r2g)
cute.arch.sync_warp()
with cute.arch.elect_one():
@@ -2700,7 +2760,8 @@ class FlashAttentionBackwardSm100:
tdKgdK_r2g_p = thr_tmem_ld_dK.partition_D(tdKgdK)
tdKgdK_r2g = self.split_wg(tdKgdK_r2g_p, wg_idx, num_wg)
cute.copy(tiled_gmem_store_dK, tdKrdK_r2s, tdKgdK_r2g)
if tidx < seqlen.seqlen_k - self.tile_n * n_block:
cute.copy(tiled_gmem_store_dK, tdKrdK_r2s, tdKgdK_r2g)
cute.arch.sync_warp()
with cute.arch.elect_one():
@@ -2715,6 +2776,7 @@ class FlashAttentionBackwardSm100:
batch_idx: Int32,
head_idx: Int32,
n_block: Int32,
seqlen,
thr_mma: cute.core.ThrMma,
tdKVtdKV: cute.Tensor,
mdKV: cute.Tensor,
@@ -2734,7 +2796,7 @@ class FlashAttentionBackwardSm100:
num_wg = num_compute_threads // 128
leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0
if const_expr(self.qhead_per_kvhead == 1):
if const_expr(not self.dKV_postprocess):
sdKV = sdKV[None, None, wg_idx] # (tile_n, 64) for bf16
else:
sdKV = sdKV[None, wg_idx] # (tile_n * 32) for fp32
@@ -2743,7 +2805,8 @@ class FlashAttentionBackwardSm100:
tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV)
head_idx_kv = head_idx // self.qhead_per_kvhead
if const_expr(self.qhead_per_kvhead == 1):
if const_expr(not self.dKV_postprocess):
assert not seqlen.has_cu_seqlens_k, "varlen uses non tma store path"
mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim)
gdKV_p = cute.local_tile(
mdKV_cur, (self.tile_n, self.tile_hdim), (n_block, 0)
@@ -2753,7 +2816,12 @@ class FlashAttentionBackwardSm100:
gdKV, self.sdKV_epi_tile, (0, None)
) # (tile_n, 64, epi_stage = (hdim / 2) / 64)
else:
mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim)
if const_expr(not seqlen.has_cu_seqlens_k):
mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim)
else:
mdKV_cur = cute.domain_offset(
(seqlen.padded_offset_k * self.tile_hdim,), mdKV[None, head_idx_kv]
)
gdKV_p = cute.local_tile(
mdKV_cur, (self.tile_n * self.tile_hdim,), (n_block,)
) # (tile_n * hdim)
@@ -2768,7 +2836,7 @@ class FlashAttentionBackwardSm100:
if const_expr(deterministic_KV):
mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx]
if const_expr(self.qhead_per_kvhead == 1):
if const_expr(not self.dKV_postprocess):
tdKVsdKV, tdKVgdKV = cpasync.tma_partition(
tma_atom_dKV,
0, # no multicast
@@ -2842,7 +2910,7 @@ class FlashAttentionBackwardSm100:
# SMEM -> GMEM
if leader_warp:
if const_expr(self.qhead_per_kvhead == 1):
if const_expr(not self.dKV_postprocess):
cute.copy(tma_atom_dKV, tdKVsdKV, tdKVgdKV[None, epi_stage])
else:
with cute.arch.elect_one():
+85 -27
View File
@@ -92,6 +92,8 @@ def _flash_attn_fwd(
cu_seqlens_k: Optional[torch.Tensor] = None,
seqused_q: Optional[torch.Tensor] = None,
seqused_k: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
page_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
@@ -115,8 +117,6 @@ def _flash_attn_fwd(
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
aux_tensors: Optional[list[torch.Tensor]] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for FlashAttention.
@@ -569,6 +569,8 @@ def _flash_attn_bwd(
cu_seqlens_k: Optional[torch.Tensor] = None,
seqused_q: Optional[torch.Tensor] = None,
seqused_k: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
deterministic: bool = False,
dq: Optional[torch.Tensor] = None,
dk: Optional[torch.Tensor] = None,
@@ -615,16 +617,19 @@ def _flash_attn_bwd(
total_q = batch_size * seqlen_q
else:
batch_size = cu_seqlens_q.shape[0] - 1
seqlen_q = None
total_q = q.shape[0]
seqlen_q = max_seqlen_q if max_seqlen_q is not None else total_q
if cu_seqlens_k is None:
batch_size, seqlen_k = k.shape[:2]
total_k = batch_size * seqlen_k
else:
batch_size = cu_seqlens_k.shape[0] - 1
seqlen_k = None
total_k = k.shape[0]
seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k
seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size
num_head_kv = k.shape[-2]
head_dim_v = v.shape[-1]
@@ -724,7 +729,6 @@ def _flash_attn_bwd(
head_dim_rounded = (head_dim + 32 - 1) // 32 * 32
if cu_seqlens_q is None:
seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
dq_accum = torch.empty(
batch_size,
num_head,
@@ -748,10 +752,10 @@ def _flash_attn_bwd(
dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
if qhead_per_kvhead > 1:
dKV_postprocess = qhead_per_kvhead > 1
if dKV_postprocess:
head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32
if cu_seqlens_k is None:
seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size
num_n_blocks = seqlen_k_rounded // n_block_size
if cluster_size == 2 and num_n_blocks % cluster_size != 0:
seqlen_k_rounded = seqlen_k_rounded + n_block_size
@@ -805,7 +809,15 @@ def _flash_attn_bwd(
dV_semaphore = None
# Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum.
compile_key_pre = (compute_capability, dtype, head_dim_v, m_block_size, num_threads)
compile_key_pre = (
compute_capability,
dtype,
head_dim_v,
m_block_size,
num_threads,
cu_seqlens_q is None,
seqused_q is None,
)
if compile_key_pre not in _flash_attn_bwd.compile_cache_pre:
o_tensor, do_tensor = [to_cute_tensor(t) for t in (out, dout)]
dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
@@ -816,9 +828,11 @@ def _flash_attn_bwd(
to_cute_tensor(t, assumed_align=4) if t is not None else None
for t in (cu_seqlens_q, seqused_q)
]
arch = compute_capability * 10
fa_bwd_pre = FlashAttentionBackwardPreprocess(
dtype,
head_dim_v,
arch,
m_block_size,
num_threads=num_threads,
)
@@ -871,6 +885,10 @@ def _flash_attn_bwd(
AtomLayoutNdKV,
AtomLayoutMdQ,
V_in_regs,
cu_seqlens_q is None,
cu_seqlens_k is None,
seqused_q is None,
seqused_k is None,
)
cute_aux_tensors = None
else:
@@ -904,6 +922,10 @@ def _flash_attn_bwd(
mask_mod_hash,
num_aux_tensors,
use_block_sparsity,
cu_seqlens_q is None,
cu_seqlens_k is None,
seqused_q is None,
seqused_k is None,
)
num_threads = 384
if compile_key not in _flash_attn_bwd.compile_cache:
@@ -913,7 +935,7 @@ def _flash_attn_bwd(
dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)
]
if qhead_per_kvhead > 1:
if dKV_postprocess:
dk_accum_tensor, dv_accum_tensor = [
to_cute_tensor(t) for t in (dk_accum, dv_accum)
]
@@ -1011,8 +1033,8 @@ def _flash_attn_bwd(
lse_log2_tensor,
dpsum_tensor,
dq_accum_tensor,
dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor,
dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor,
dk_tensor if not dKV_postprocess else dk_accum_tensor,
dv_tensor if not dKV_postprocess else dv_accum_tensor,
softmax_scale,
current_stream,
cu_seqlens_q_tensor,
@@ -1049,8 +1071,8 @@ def _flash_attn_bwd(
lse_log2,
dpsum,
dq_accum,
dk if qhead_per_kvhead == 1 else dk_accum,
dv if qhead_per_kvhead == 1 else dv_accum,
dk if not dKV_postprocess else dk_accum,
dv if not dKV_postprocess else dv_accum,
softmax_scale,
current_stream,
cu_seqlens_q,
@@ -1069,7 +1091,19 @@ def _flash_attn_bwd(
num_threads = 256 if compute_capability == 9 else 128
# Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16
compile_key_post = (dtype, head_dim, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB)
compile_key_post = (
compute_capability,
dtype,
head_dim,
m_block_size,
num_threads,
AtomLayoutMdQ,
dQ_swapAB,
cu_seqlens_q is None,
cu_seqlens_k is None,
seqused_q is None,
seqused_k is None,
)
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
dq_accum_tensor = to_cute_tensor(dq_accum)
dq_tensor = to_cute_tensor(dq)
@@ -1101,9 +1135,21 @@ def _flash_attn_bwd(
current_stream,
)
if qhead_per_kvhead > 1:
if dKV_postprocess:
# Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16
compile_key_post = (dtype, head_dim, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB)
compile_key_post = (
compute_capability,
dtype,
head_dim,
n_block_size,
num_threads,
AtomLayoutNdKV,
dKV_swapAB,
cu_seqlens_q is None,
cu_seqlens_k is None,
seqused_q is None,
seqused_k is None,
)
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
dk_accum_tensor = to_cute_tensor(dk_accum)
dk_tensor = to_cute_tensor(dk)
@@ -1111,8 +1157,9 @@ def _flash_attn_bwd(
to_cute_tensor(t, assumed_align=4) if t is not None else None
for t in (cu_seqlens_k, seqused_k)
]
arch = compute_capability * 10
fa_bwd_post = FlashAttentionBackwardPostprocess(
dtype, head_dim, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB
dtype, head_dim, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB
)
# TODO: check @can_implement
_flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
@@ -1134,12 +1181,17 @@ def _flash_attn_bwd(
current_stream,
)
compile_key_post = (
compute_capability,
dtype,
head_dim_v,
n_block_size,
num_threads,
AtomLayoutNdKV,
dKV_swapAB,
cu_seqlens_q is None,
cu_seqlens_k is None,
seqused_q is None,
seqused_k is None,
)
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
dv_accum_tensor = to_cute_tensor(dv_accum)
@@ -1148,8 +1200,9 @@ def _flash_attn_bwd(
to_cute_tensor(t, assumed_align=4) if t is not None else None
for t in (cu_seqlens_k, seqused_k)
]
arch = compute_capability * 10
fa_bwd_post = FlashAttentionBackwardPostprocess(
dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB
dtype, head_dim_v, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB
)
# TODO: check @can_implement
_flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
@@ -1263,6 +1316,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
cu_seqlens_k: Optional[torch.Tensor],
seqused_q: Optional[torch.Tensor] = None,
seqused_k: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
page_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
@@ -1274,8 +1329,6 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
deterministic: bool = False,
score_mod: Optional[Callable] = None,
aux_tensors: Optional[list] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
):
out, lse = _flash_attn_fwd(
q,
@@ -1285,6 +1338,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
cu_seqlens_k,
seqused_q,
seqused_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
page_table=page_table,
softmax_scale=softmax_scale,
causal=causal,
@@ -1296,8 +1351,6 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
pack_gqa=pack_gqa,
score_mod=score_mod,
aux_tensors=aux_tensors,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
)
ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
ctx.softmax_scale = softmax_scale
@@ -1305,12 +1358,13 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.window_size = window_size
ctx.softcap = softcap
ctx.deterministic = deterministic
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
return out, lse
@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
assert seqused_q == seqused_k == None
assert ctx.softcap == 0.0
dq, dk, dv = _flash_attn_bwd(
q,
@@ -1322,10 +1376,14 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.softcap,
window_size_left=ctx.window_size[0],
window_size_right=ctx.window_size[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
seqused_q=seqused_q,
seqused_k=seqused_k,
max_seqlen_q=ctx.max_seqlen_q,
max_seqlen_k=ctx.max_seqlen_k,
deterministic=ctx.deterministic,
)
@@ -1376,6 +1434,8 @@ def flash_attn_varlen_func(
v: torch.Tensor,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
seqused_q: Optional[torch.Tensor] = None,
seqused_k: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
@@ -1389,8 +1449,6 @@ def flash_attn_varlen_func(
deterministic: bool = False,
score_mod: Optional[Callable] = None,
aux_tensors: Optional[list] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_k: Optional[int] = None,
):
return FlashAttnVarlenFunc.apply(
q,
@@ -1400,6 +1458,8 @@ def flash_attn_varlen_func(
cu_seqlens_k,
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
page_table,
softmax_scale,
causal,
@@ -1411,8 +1471,6 @@ def flash_attn_varlen_func(
deterministic,
score_mod,
aux_tensors,
max_seqlen_q,
max_seqlen_k,
)
+34 -6
View File
@@ -38,6 +38,8 @@ class SeqlenInfo:
class SeqlenInfoQK:
offset_q: cutlass.Int32
offset_k: cutlass.Int32
padded_offset_q: cutlass.Int32
padded_offset_k: cutlass.Int32
seqlen_q: cutlass.Int32
seqlen_k: cutlass.Int32
has_cu_seqlens_q: cutlass.Constexpr[bool]
@@ -54,9 +56,21 @@ class SeqlenInfoQK:
mCuSeqlensK: Optional[cute.Tensor] = None,
mSeqUsedQ: Optional[cute.Tensor] = None,
mSeqUsedK: Optional[cute.Tensor] = None,
tile_m: cutlass.Constexpr[cutlass.Int32] = 128,
tile_n: cutlass.Constexpr[cutlass.Int32] = 128,
):
offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx]
offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx]
padded_offset_q = (
0
if const_expr(mCuSeqlensQ is None)
else (offset_q + batch_idx * tile_m) // tile_m * tile_m
)
padded_offset_k = (
0
if const_expr(mCuSeqlensK is None)
else (offset_k + batch_idx * tile_n) // tile_n * tile_n
)
if const_expr(mSeqUsedQ is not None):
seqlen_q = mSeqUsedQ[batch_idx]
else:
@@ -80,6 +94,8 @@ class SeqlenInfoQK:
return SeqlenInfoQK(
offset_q,
offset_k,
padded_offset_q,
padded_offset_k,
seqlen_q,
seqlen_k,
has_cu_seqlens_q,
@@ -88,23 +104,35 @@ class SeqlenInfoQK:
has_seqused_k,
)
def offset_batch_Q(
    self,
    mQ: cute.Tensor,
    batch_idx: Int32,
    dim: int,
    padded: cutlass.Constexpr[bool] = False,
) -> cute.Tensor:
    """Return the view of ``mQ`` that belongs to one batch element.

    Seqlen must be the first dimension of mQ.

    Args:
        mQ: Tensor to slice or offset.
        batch_idx: Batch index. Only used when there are no ``cu_seqlens_q``.
        dim: Position of the batch mode in ``mQ``. Only used when there are
            no ``cu_seqlens_q``.
        padded: Compile-time flag. When true, the varlen branch offsets by
            ``self.padded_offset_q`` instead of ``self.offset_q``.

    Returns:
        Without ``cu_seqlens_q``: ``mQ`` indexed at ``batch_idx`` in mode
        ``dim``, keeping every other mode. With ``cu_seqlens_q``: ``mQ`` with
        its first mode shifted by the selected offset via
        ``cute.domain_offset``.
    """
    if const_expr(not self.has_cu_seqlens_q):
        # Fixed-length layout: select the batch in mode `dim`, keep the rest.
        idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim)
        return mQ[idx]
    else:
        offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q
        # If the first mode is itself a tuple, the offset is applied to its
        # second entry and the first entry is left at 0.
        offset = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, offset_q)
        idx = (offset,) + (0,) * (cute.rank(mQ) - 1)
        return cute.domain_offset(idx, mQ)
def offset_batch_K(
    self,
    mK: cute.Tensor,
    batch_idx: Int32,
    dim: int,
    padded: cutlass.Constexpr[bool] = False,
) -> cute.Tensor:
    """Return the view of ``mK`` that belongs to one batch element.

    Seqlen must be the first dimension of mK.

    Args:
        mK: Tensor to slice or offset.
        batch_idx: Batch index. Only used when there are no ``cu_seqlens_k``.
        dim: Position of the batch mode in ``mK``. Only used when there are
            no ``cu_seqlens_k``.
        padded: Compile-time flag. When true, the varlen branch offsets by
            ``self.padded_offset_k`` instead of ``self.offset_k``.

    Returns:
        Without ``cu_seqlens_k``: ``mK`` indexed at ``batch_idx`` in mode
        ``dim``, keeping every other mode. With ``cu_seqlens_k``: ``mK`` with
        its first mode shifted by the selected offset via
        ``cute.domain_offset``.
    """
    if const_expr(not self.has_cu_seqlens_k):
        # Fixed-length layout: select the batch in mode `dim`, keep the rest.
        idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim)
        return mK[idx]
    else:
        offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k
        # Shift only the first mode; all remaining modes get a zero offset.
        idx = (offset_k,) + (0,) * (cute.rank(mK) - 1)
        return cute.domain_offset(idx, mK)
+6 -1
View File
@@ -92,7 +92,12 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random",
device=device,
)
else:
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
lengths = torch.randint(
max(0 if zero_lengths else 1, max_seqlen // 3),
max_seqlen + 1,
(batch_size, 1),
device=device,
)
if zero_lengths:
for i in range(batch_size):
+6 -2
View File
@@ -72,6 +72,7 @@ class TileSchedulerArguments(ParamsBase):
is_persistent: cutlass.Constexpr[bool] = False
lpt: cutlass.Constexpr[bool] = False
is_split_kv: cutlass.Constexpr[bool] = False
head_swizzle: cutlass.Constexpr[bool] = False
class SingleTileScheduler:
@@ -512,6 +513,7 @@ class SingleTileVarlenScheduler:
qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
lpt: cutlass.Constexpr[bool] = False
is_split_kv: cutlass.Constexpr[bool] = False
head_swizzle: cutlass.Constexpr[bool] = False
@staticmethod
@cute.jit
@@ -537,6 +539,7 @@ class SingleTileVarlenScheduler:
qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa,
lpt=args.lpt,
is_split_kv=args.is_split_kv,
head_swizzle=args.head_swizzle,
)
def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):
@@ -638,7 +641,7 @@ class SingleTileVarlenScheduler:
)
num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group)
mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head
if cutlass.const_expr(params.lpt):
if cutlass.const_expr(params.lpt or params.head_swizzle):
# This is a version of the SingleTileLPTScheduler, complicated by the fact that
# the seqlen can vary per batch.
# TODO: is there any case where num_m_blocks is 0?
@@ -677,7 +680,8 @@ class SingleTileVarlenScheduler:
block = l2_mod // nheads_in_this_section
head_idx_residual = l2_mod - block * nheads_in_this_section
head_idx = section_idx * nheads_in_l2 + head_idx_residual
block = num_m_blocks - 1 - block
if cutlass.const_expr(params.lpt):
block = num_m_blocks - 1 - block
else:
head_idx = mh_block // num_m_blocks
block = mh_block - head_idx * num_m_blocks
+51 -24
View File
@@ -50,7 +50,7 @@ VERBOSE = True
@pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
# @pytest.mark.parametrize("local_enum", [0])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
@@ -58,9 +58,9 @@ VERBOSE = True
# @pytest.mark.parametrize("d", [64, 128, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128, 192])
# @pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize("d", [128, 192])
@pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
@@ -113,7 +113,7 @@ def test_flash_attn_output(
torch.cuda.empty_cache()
torch.cuda.synchronize()
batch_size = 9 if seqlen_k <= 2048 else 2
# batch_size = 1
# batch_size = 2
nheads = 6
# nheads = 1
nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)
@@ -236,7 +236,7 @@ def test_flash_attn_output(
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
# num_splits_vals = [1, 3]
pack_gqa_vals = [False, True, None]
pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False]
# SplitKV is not supported for hdim >= 192
# pack_gqa_vals = [False]
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
@@ -371,17 +371,17 @@ def test_flash_attn_output(
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mqa"])
@pytest.mark.parametrize("has_learnable_sink", [False, True])
# @pytest.mark.parametrize("has_learnable_sink", [False])
# @pytest.mark.parametrize("mha_type", ["mha"])
# @pytest.mark.parametrize("has_learnable_sink", [False, True])
@pytest.mark.parametrize("has_learnable_sink", [False])
# @pytest.mark.parametrize("has_qv", [False, True])
@pytest.mark.parametrize("has_qv", [False])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [False])
# @pytest.mark.parametrize("softcap", [0.0, 15.0])
@pytest.mark.parametrize("softcap", [0.0])
# @pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
# @pytest.mark.parametrize("local_enum", [0])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("add_unused_qkv", [False, True])
@@ -393,7 +393,7 @@ def test_flash_attn_output(
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128])
# @pytest.mark.parametrize("d", [128, 192])
@pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
@@ -419,20 +419,37 @@ def test_flash_attn_output(
(2048, 2048),
],
)
@pytest.mark.parametrize("varlen_mode", ["random", "third", "full"])
# @pytest.mark.parametrize("varlen_mode", ["full"])
@pytest.mark.parametrize(
"zero_lengths_q, zero_lengths_k",
[
(False, False),
(True, False),
(False, True),
(True, True),
],
)
def test_flash_attn_varlen_output(
seqlen_q,
seqlen_k,
d,
add_unused_qkv,
causal,
local,
local_enum,
softcap,
deterministic,
has_qv,
has_learnable_sink,
mha_type,
dtype,
varlen_mode,
zero_lengths_q,
zero_lengths_k,
):
local = local_enum > 0
if local and causal:
pytest.skip()
if (
causal or local
): # Right now reference only supports causal attention with seqlen_k == seqlen_q
@@ -442,13 +459,12 @@ def test_flash_attn_varlen_output(
torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
batch_size = 49 if seqlen_q <= 1024 else 7
nheads = 6
# batch_size = 1
# nheads = 1
nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
# dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])
if dtype == torch.float8_e4m3fn:
if dtype == torch.float8_e4m3fn or TEST_BWD_ONLY:
dv_vals = [d]
# attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0]
attention_chunk_vals = [0]
@@ -490,6 +506,12 @@ def test_flash_attn_varlen_output(
window_size = (
(None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist()
)
if local_enum == 2:
window_size = (None, window_size[1])
elif local_enum == 3:
window_size = (window_size[0], None)
if local:
print("window size = ", window_size)
if has_learnable_sink:
learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)
else:
@@ -505,18 +527,19 @@ def test_flash_attn_varlen_output(
q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
qv = qv_ref.detach() if has_qv else None
query_padding_mask = generate_random_padding_mask(
seqlen_q, batch_size, device, mode="random", zero_lengths=False
seqlen_q,
batch_size,
device,
mode=varlen_mode,
zero_lengths=zero_lengths_q,
)
# TODO: test zero_lengths
key_padding_mask = generate_random_padding_mask(
# seqlen_k, batch_size, device, mode="random", zero_lengths=True
seqlen_k,
batch_size,
device,
mode="random",
zero_lengths=False,
mode=varlen_mode,
zero_lengths=zero_lengths_k,
)
def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
if add_unused:
another_mask = generate_random_padding_mask(max_seq_len, bs, device)
@@ -570,6 +593,8 @@ def test_flash_attn_varlen_output(
query_unused_mask=query_unused_mask,
key_unused_mask=key_unused_mask,
)
print("cu_seqlens_q = ", cu_seqlens_q)
print("cu_seqlens_k = ", cu_seqlens_k)
q_unpad, k_unpad, v_unpad = [
x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
]
@@ -619,11 +644,11 @@ def test_flash_attn_varlen_output(
fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
rtol = 2 if softcap == 0.0 else 3
pack_gqa_vals = [False, True, None]
pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False]
# pack_gqa_vals = [False]
# num_splits_vals = [1, 3]
# SplitKV is not supported for hdim >= 192
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1]
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
# SplitKV not supported on SM90 - skip this iteration
if IS_SM90 and num_splits > 1:
@@ -634,7 +659,8 @@ def test_flash_attn_varlen_output(
v_unpad,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
# max_seqlen_k,
max_seqlen_q=seqlen_q,
max_seqlen_k=seqlen_k,
# seqused_q=seqused_q,
# seqused_k=seqused_k,
causal=causal,
@@ -647,6 +673,7 @@ def test_flash_attn_varlen_output(
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
)
out = output_pad_fn(out_unpad)
if query_unused_mask is not None:
@@ -670,10 +697,10 @@ def test_flash_attn_varlen_output(
and not attention_chunk != 0
and dv == d
and not has_learnable_sink
and False
# and False
):
g_unpad = torch.randn_like(out_unpad)
do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
# do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
# import flash_attn_3_cuda
# dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen(
# g_unpad,
+434 -2
View File
@@ -31,7 +31,7 @@ from flash_attn.cute.interface import (
DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
IS_SM90 = torch.cuda.get_device_capability()[0] == 9
INCREASED_TRIALS = False
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@@ -304,7 +304,7 @@ def test_flash_attn_output(
dv_pt - dv_ref
).abs().max().item() + dv_atol
num_iters = 20_000
num_iters = 10_000 if INCREASED_TRIALS else 1000
for i in range(num_iters):
dq2, dk2, dv2, = _flash_attn_bwd(
q, k, v, out, g, lse,
@@ -342,3 +342,435 @@ def test_flash_attn_output(
print(f"✅ Iteration {i} passed!")
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["gqa"])
# @pytest.mark.parametrize("has_learnable_sink", [False, True])
@pytest.mark.parametrize("has_learnable_sink", [False])
# @pytest.mark.parametrize("has_qv", [False, True])
@pytest.mark.parametrize("has_qv", [False])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [True])
# @pytest.mark.parametrize("softcap", [0.0, 15.0])
@pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
# @pytest.mark.parametrize("local_enum", [0, 1])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("add_unused_qkv", [False, True])
@pytest.mark.parametrize("add_unused_qkv", [False])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128])
# @pytest.mark.parametrize("d", [128, 192])
@pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1024, 1024),
(2048, 2048),
],
)
@pytest.mark.parametrize("varlen_mode", ["random", "third", "full"])
# @pytest.mark.parametrize("varlen_mode", ["random"])
@pytest.mark.parametrize(
"zero_lengths_q, zero_lengths_k",
[
(False, False),
(True, False),
(False, True),
(True, True),
],
)
def test_flash_attn_varlen_output(
seqlen_q,
seqlen_k,
d,
add_unused_qkv,
causal,
local_enum,
softcap,
deterministic,
has_qv,
has_learnable_sink,
mha_type,
dtype,
varlen_mode,
zero_lengths_q,
zero_lengths_k,
):
local = local_enum > 0
if local and causal:
pytest.skip()
if (
causal or local
): # Right now reference only supports causal attention with seqlen_k == seqlen_q
seqlen_k = seqlen_q
device = "cuda"
# set seed
torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
batch_size = 49 if seqlen_q <= 1024 else 7
nheads = 6
# nheads = 1
nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
# dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])
dv_vals = [d] # override
# attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0]
attention_chunk_vals = [0]
for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
q_ref = torch.randn(
batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref
)
if softcap > 0.0:
# Ensure the values of qk are at least within softcap range.
q_ref = (q_ref * softcap / 4).detach().requires_grad_()
q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
k_ref = (
torch.randn(
batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
)
.to(dtype)
.to(dtype_ref)
.requires_grad_()
)
v_ref = (
torch.randn(
batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref
)
.to(dtype)
.to(dtype_ref)
.requires_grad_()
)
if has_qv:
qv_ref = (
torch.randn(
batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
)
.to(dtype)
.to(dtype_ref)
)
else:
qv_ref = None
# Put window_size after QKV randn so that window_size changes from test to test
window_size = (
(None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist()
)
if local_enum == 2:
window_size = (None, window_size[1])
elif local_enum == 3:
window_size = (window_size[0], None)
if local:
print("window size = ", window_size)
if has_learnable_sink:
learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)
else:
learnable_sink = None
if dtype == torch.float8_e4m3fn:
q_descale, k_descale, v_descale = [
torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
* 2
for _ in range(3)
]
else:
q_descale, k_descale, v_descale = None, None, None
q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
qv = qv_ref.detach() if has_qv else None
query_padding_mask = generate_random_padding_mask(
seqlen_q,
batch_size,
device,
mode=varlen_mode,
zero_lengths=zero_lengths_q,
)
key_padding_mask = generate_random_padding_mask(
seqlen_k,
batch_size,
device,
mode=varlen_mode,
zero_lengths=zero_lengths_k,
)
def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
if add_unused:
another_mask = generate_random_padding_mask(max_seq_len, bs, device)
attn_mask = torch.logical_and(padding_mask, another_mask)
unused_mask = torch.logical_xor(
torch.logical_or(padding_mask, another_mask), attn_mask
)
else:
attn_mask = padding_mask
unused_mask = None
return attn_mask, unused_mask
query_padding_mask, query_unused_mask = _gen_unused_masks(
query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
)
# query_padding_mask[:] = True
# query_unused_mask = None
key_padding_mask, key_unused_mask = _gen_unused_masks(
key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
)
if causal or local:
key_padding_mask = query_padding_mask
(
q_unpad,
k_unpad,
v_unpad,
qv_unpad,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
qv,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(
q,
k,
v,
query_padding_mask,
key_padding_mask,
qv=qv,
kvpacked=False,
query_unused_mask=query_unused_mask,
key_unused_mask=key_unused_mask,
)
print("cu_seqlens_q = ", cu_seqlens_q)
print("cu_seqlens_k = ", cu_seqlens_k)
q_unpad, k_unpad, v_unpad = [
x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
]
out_ref, attn_ref = attention_ref(
q_ref,
k_ref,
v_ref,
query_padding_mask,
key_padding_mask,
causal=causal,
qv=qv_ref,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
window_size=window_size,
attention_chunk=attention_chunk,
learnable_sink=learnable_sink,
softcap=softcap,
)
out_pt, attn_pt = attention_ref(
q_ref,
k_ref,
v_ref,
query_padding_mask,
key_padding_mask,
causal=causal,
qv=qv_ref,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
window_size=window_size,
attention_chunk=attention_chunk,
learnable_sink=learnable_sink,
softcap=softcap,
upcast=False,
reorder_ops=True,
intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
)
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
if query_unused_mask is not None:
q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")
# Numerical error if we just do any arithmetic on out_ref
fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
rtol = 2 if softcap == 0.0 else 3
out_unpad, lse = flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
# max_seqlen_k,
# seqused_q=seqused_q,
# seqused_k=seqused_k,
max_seqlen_q=seqlen_q,
max_seqlen_k=seqlen_k,
causal=causal,
# qv=qv_unpad,
# q_descale=q_descale,
# k_descale=k_descale, v_descale=v_descale,
window_size=window_size,
# attention_chunk=attention_chunk,
learnable_sink=learnable_sink,
softcap=softcap,
num_splits=1,
pack_gqa=False,
deterministic=deterministic,
)
out = output_pad_fn(out_unpad)
if query_unused_mask is not None:
out.masked_fill_(q_zero_masking, 0.0)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
# if not causal:
# print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
# breakpoint()
# Check that FlashAttention's numerical error is at most rtol times (2x, or 3x with softcap)
# the numerical error of a Pytorch implementation, plus a small absolute tolerance.
assert (out - out_ref).abs().max().item() <= rtol * (
out_pt - out_ref
).abs().max().item() + fwd_atol
if (
dtype != torch.float8_e4m3fn
and not has_qv
and not dv > 256
and not attention_chunk != 0
and dv == d
and not has_learnable_sink
# and False
):
g_unpad = torch.randn_like(out_unpad)
# do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
# import flash_attn_3_cuda
# dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen(
# g_unpad,
# q_unpad,
# k_unpad,
# v_unpad,
# out_unpad,
# lse,
# None,
# None,
# None,
# cu_seqlens_q,
# cu_seqlens_k,
# None, None,
# max_seqlen_q,
# max_seqlen_k,
# d ** (-0.5),
# causal,
# window_size[0], window_size[1],
# softcap,
# deterministic,
# 0, # sm_margin
# )
dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad
)
dq = dq_pad_fn(dq_unpad)
dk = dk_pad_fn(dk_unpad)
dv = dk_pad_fn(dv_unpad)
if key_unused_mask is not None:
k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
dk.masked_fill_(k_zero_masking, 0.0)
dv.masked_fill_(k_zero_masking, 0.0)
if query_unused_mask is not None:
dq.masked_fill_(q_zero_masking, 0.0)
# print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
# assert (softmax_d - do_o).abs().max().item() <= 1e-5
# assert dq_accum.abs().max().item() == 0.0
g = output_pad_fn(g_unpad)
# qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float()
# qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
# dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float())
# P = torch.softmax(qk, -1)
# dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1))
# dQ = torch.einsum('bhts,bshd->bthd', dP, k.float())
# dV = torch.einsum('bhts,bthd->bshd', P, g.float())
# dK = torch.einsum('bhts,bthd->bshd', dP, q.float())
# dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
dq_ref, dk_ref, dv_ref = torch.autograd.grad(
out_ref, (q_ref, k_ref, v_ref), g
)
dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
# breakpoint()
dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
0 if softcap == 0 else 3e-4
)
assert (dq - dq_ref).abs().max().item() <= rtol * (
dq_pt - dq_ref
).abs().max().item() + dq_atol
dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (
0 if softcap == 0 else 3e-4
)
assert (dk - dk_ref).abs().max().item() <= rtol * (
dk_pt - dk_ref
).abs().max().item() + dk_atol
dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (
0 if softcap == 0 else 3e-4
)
assert (dv - dv_ref).abs().max().item() <= rtol * (
dv_pt - dv_ref
).abs().max().item() + dv_atol
num_iters = 10_000 if INCREASED_TRIALS else 1000
for i in range(num_iters):
dq_unpad2, dk_unpad2, dv_unpad2 = _flash_attn_bwd(
q_unpad, k_unpad, v_unpad, out_unpad, g_unpad, lse,
causal=causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
deterministic=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=seqlen_q,
max_seqlen_k=seqlen_k,
)
diff_dq = (dq_unpad - dq_unpad2).abs()
max_idx = diff_dq.argmax()
if i % 100 == 0:
print(f"dQ max diff: {diff_dq.max().item()}")
print(f" at index {max_idx.item()}: dQ={dq_unpad.flatten()[max_idx].item()}, dQ2={dq_unpad2.flatten()[max_idx].item()}")
diff_dk = (dk_unpad - dk_unpad2).abs()
max_idx = diff_dk.argmax()
if i % 100 == 0:
print(f"dK max diff: {diff_dk.max().item()}")
print(f" at index {max_idx.item()}: dK={dk_unpad.flatten()[max_idx].item()}, dK2={dk_unpad2.flatten()[max_idx].item()}")
diff_dv = (dv_unpad - dv_unpad2).abs()
max_idx = diff_dv.argmax()
if i % 100 == 0:
print(f"dV max diff: {diff_dv.max().item()}")
print(f" at index {max_idx.item()}: dV={dv_unpad.flatten()[max_idx].item()}, dV2={dv_unpad2.flatten()[max_idx].item()}")
assert torch.equal(dq_unpad, dq_unpad2)
assert torch.equal(dk_unpad, dk_unpad2)
assert torch.equal(dv_unpad, dv_unpad2)
if i % 100 == 0:
print(f"✅ Iteration {i} passed!")
+3 -3
View File
@@ -43,8 +43,8 @@ def test_varlen(
dtype=dtype
)
# SM90/SM100 backward pass doesn't support varlen yet
skip_backward = IS_SM90 or torch.cuda.get_device_capability()[0] == 10
# SM90 backward pass doesn't support varlen yet
skip_backward = IS_SM90
ok = check_varlen_vs_torch_flash(
q, k, v,
@@ -128,7 +128,7 @@ def check_varlen_vs_torch_flash(
if not ok_fwd:
return False
# Skip backward if not supported (e.g., SM100 varlen)
# Skip backward if not supported (e.g., SM90 varlen)
if skip_backward:
return True