mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-07-01 21:04:02 -04:00
[Cute, Bwd, Sm100] Add varlen for sm100 bwd (#2150)
* varlen bwd with rounded padded offsets
* fix mha
* change offset mode to round down multiple
* enable varlen bwd tests
* enable deterministic mode
* fix deadlock and switch mha to no postprocess
* reenable tests
* fix lint error
* use head swizzle/spt for deterministic, update tests
* change padding offset based on arch
* rebase and update interface, tests
* add arch dispatch for padded offset q to postprocess
* address comments
* remove tile sizes from seqlen info class vars
This commit is contained in:
@@ -325,9 +325,9 @@ for headdim in [128]:
|
||||
else:
|
||||
page_table = None
|
||||
|
||||
# for causal in [False, True]:
|
||||
for causal in [True]:
|
||||
print(f"\n### {headdim = }, {causal = }, {seqlen = } ###")
|
||||
for causal in [False, True]:
|
||||
# for causal in [True]:
|
||||
print(f"\n### {headdim = }, {causal = }, {seqlen = }, {batch_size = }, {nheads = }, {nheads_kv = }, {varlen = }, {deterministic = } ###")
|
||||
nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size)
|
||||
if cudnn is not None:
|
||||
# if False:
|
||||
@@ -395,7 +395,10 @@ for headdim in [128]:
|
||||
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True)
|
||||
# benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy')
|
||||
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_python is not None and has_backward:
|
||||
_, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav2 python')
|
||||
if not varlen:
|
||||
_, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python')
|
||||
else:
|
||||
_, m1b_py = benchmark_backward(flash_attn_varlen_func_python, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python')
|
||||
|
||||
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None:
|
||||
# if False:
|
||||
|
||||
@@ -123,6 +123,7 @@ def cute_compile_patched(*args, **kwargs):
|
||||
pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass)
|
||||
return output
|
||||
|
||||
|
||||
def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True):
|
||||
"""Convert torch tensor to cute tensor for TVM FFI. leading_dim=-1 defaults to t.ndim-1."""
|
||||
tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)
|
||||
|
||||
@@ -233,13 +233,15 @@ class FlashAttentionBackwardPostprocess:
|
||||
TileScheduler = SingleTileVarlenScheduler
|
||||
num_head = mdQ.shape[1]
|
||||
num_batch = mCuSeqlensQ.shape[0] - 1
|
||||
num_block = cute.ceil_div(mdQ.shape[0], self.tile_m)
|
||||
else:
|
||||
TileScheduler = SingleTileScheduler
|
||||
num_head = mdQ.shape[2]
|
||||
num_batch = mdQ.shape[0]
|
||||
num_block = cute.ceil_div(mdQ.shape[1], self.tile_m)
|
||||
|
||||
tile_sched_args = TileSchedulerArguments(
|
||||
num_block=cute.ceil_div(mdQ.shape[1], self.tile_m),
|
||||
num_block=num_block,
|
||||
num_head=num_head,
|
||||
num_batch=num_batch,
|
||||
num_splits=1,
|
||||
@@ -318,7 +320,7 @@ class FlashAttentionBackwardPostprocess:
|
||||
tile_scheduler = TileScheduler.create(tile_sched_params)
|
||||
work_tile = tile_scheduler.initial_work_tile_info()
|
||||
|
||||
m_block, num_head, batch_size, _ = work_tile.tile_idx
|
||||
m_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
||||
|
||||
if work_tile.is_valid_tile:
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -326,7 +328,7 @@ class FlashAttentionBackwardPostprocess:
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
seqlen = SeqlenInfoQK.create(
|
||||
batch_size,
|
||||
batch_idx,
|
||||
mdQ.shape[1],
|
||||
0,
|
||||
mCuSeqlensQ=mCuSeqlensQ,
|
||||
@@ -335,14 +337,16 @@ class FlashAttentionBackwardPostprocess:
|
||||
mSeqUsedK=None,
|
||||
)
|
||||
if const_expr(not seqlen.has_cu_seqlens_q):
|
||||
mdQ_cur = mdQ[batch_size, None, num_head, None]
|
||||
mdQaccum_cur = mdQaccum[batch_size, num_head, None]
|
||||
mdQ_cur = mdQ[batch_idx, None, head_idx, None]
|
||||
mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
|
||||
head_dim = mdQ.shape[3]
|
||||
else:
|
||||
padded_offset_q = seqlen.offset_q + batch_size * self.tile_m
|
||||
mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, num_head, None])
|
||||
padded_offset_q = seqlen.offset_q + batch_idx * self.tile_m
|
||||
if cutlass.const_expr(self.arch >= 90):
|
||||
padded_offset_q = padded_offset_q // self.tile_m * self.tile_m
|
||||
mdQ_cur = cute.domain_offset((seqlen.offset_q, 0), mdQ[None, head_idx, None])
|
||||
mdQaccum_cur = cute.domain_offset(
|
||||
(padded_offset_q * self.tile_hdim,), mdQaccum[num_head, None]
|
||||
(padded_offset_q * self.tile_hdim,), mdQaccum[head_idx, None]
|
||||
)
|
||||
head_dim = mdQ.shape[2]
|
||||
|
||||
@@ -457,273 +461,3 @@ class FlashAttentionBackwardPostprocess:
|
||||
tdQgdQ[None, rest_m, None],
|
||||
pred=tdQpdQ[None, rest_m, None],
|
||||
)
|
||||
|
||||
|
||||
class FlashAttentionBackwardPostprocess_sm100(FlashAttentionBackwardPostprocess):
|
||||
def __init__(
|
||||
self,
|
||||
dtype: Type[cutlass.Numeric],
|
||||
head_dim: int,
|
||||
tile_m: int = 128,
|
||||
num_threads: int = 256,
|
||||
AtomLayoutMdQ: int = 1,
|
||||
dQ_swapAB: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
dtype=dtype,
|
||||
head_dim=head_dim,
|
||||
arch=90, # tmp dummy placement for now
|
||||
tile_m=tile_m,
|
||||
num_threads=num_threads,
|
||||
AtomLayoutMdQ=AtomLayoutMdQ,
|
||||
dQ_swapAB=dQ_swapAB,
|
||||
)
|
||||
|
||||
def _setup_attributes(self):
|
||||
self.num_stages = self.tile_hdim // 32 # 2 for D=64, 4 for D=128
|
||||
|
||||
self.sdQaccum_layout = cute.make_layout(
|
||||
shape=(self.tile_m * 32, 2), stride=(1, self.tile_m * 32)
|
||||
)
|
||||
self.epi_tile_q = (self.tile_m, self.tile_hdim)
|
||||
self.sdQ_layout = sm100_utils_basic.make_smem_layout_epi(
|
||||
self.dtype,
|
||||
LayoutEnum.ROW_MAJOR,
|
||||
self.epi_tile_q,
|
||||
1,
|
||||
)
|
||||
|
||||
@cute.jit
|
||||
def __call__(
|
||||
self,
|
||||
mdQaccum: cute.Tensor,
|
||||
mdQ: cute.Tensor,
|
||||
scale: cutlass.Float32,
|
||||
stream: cuda.CUstream,
|
||||
):
|
||||
# Assume all strides are divisible by 128 bits except the last stride
|
||||
new_stride = lambda t: (
|
||||
*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
|
||||
t.stride[-1],
|
||||
)
|
||||
mdQaccum, mdQ = [
|
||||
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
|
||||
for t in (mdQaccum, mdQ)
|
||||
]
|
||||
# (b, h, s*d) -> (s*d, h, b)
|
||||
mdQaccum = cute.make_tensor(mdQaccum.iterator, cute.select(mdQaccum.layout, mode=[2, 1, 0]))
|
||||
# (b, s, h, d) -> (s, d, h, b)
|
||||
mdQ = cute.make_tensor(mdQ.iterator, cute.select(mdQ.layout, mode=[1, 3, 2, 0]))
|
||||
|
||||
self._setup_attributes()
|
||||
|
||||
grid_dim = [
|
||||
cute.ceil_div(mdQ.shape[0], self.tile_m),
|
||||
cute.size(mdQ.shape[2]),
|
||||
cute.size(mdQ.shape[3]),
|
||||
]
|
||||
|
||||
cta_group = tcgen05.CtaGroup.ONE
|
||||
self.mma_tiler_dsk = (self.tile_m, self.tile_hdim)
|
||||
|
||||
dS_major_mode = tcgen05.OperandMajorMode.MN
|
||||
kt_major_mode_dsq = tcgen05.OperandMajorMode.MN
|
||||
|
||||
tiled_mma_dsk = sm100_utils_basic.make_trivial_tiled_mma(
|
||||
cutlass.BFloat16,
|
||||
dS_major_mode,
|
||||
kt_major_mode_dsq,
|
||||
cutlass.Float32,
|
||||
cta_group,
|
||||
self.mma_tiler_dsk,
|
||||
)
|
||||
|
||||
dQ_cta_v_layout = cute.composition(cute.make_identity_layout(mdQ.shape), self.mma_tiler_dsk)
|
||||
tma_store_op = cpasync.CopyBulkTensorTileS2GOp()
|
||||
tma_atom_dQ, tma_tensor_dQ = cute.nvgpu.cpasync.make_tiled_tma_atom(
|
||||
tma_store_op,
|
||||
mdQ,
|
||||
cute.select(self.sdQ_layout, mode=[0, 1]),
|
||||
dQ_cta_v_layout,
|
||||
)
|
||||
|
||||
buffer_align_bytes = 1024
|
||||
|
||||
@cute.struct
|
||||
class SharedStorage:
|
||||
sdQaccum: cute.struct.Align[
|
||||
cute.struct.MemRange[cutlass.Float32, cute.cosize(self.sdQaccum_layout)],
|
||||
128,
|
||||
]
|
||||
|
||||
sdQ: cute.struct.Align[
|
||||
cute.struct.MemRange[self.dtype, cute.cosize(self.sdQ_layout)],
|
||||
buffer_align_bytes,
|
||||
]
|
||||
|
||||
self.shared_storage = SharedStorage
|
||||
|
||||
self.kernel(
|
||||
mdQaccum,
|
||||
tma_tensor_dQ,
|
||||
tma_atom_dQ,
|
||||
self.sdQaccum_layout,
|
||||
self.sdQ_layout,
|
||||
tiled_mma_dsk,
|
||||
scale,
|
||||
).launch(
|
||||
grid=grid_dim,
|
||||
block=[self.num_threads, 1, 1],
|
||||
smem=self.shared_storage.size_in_bytes(),
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
self,
|
||||
mdQaccum: cute.Tensor,
|
||||
mdQ: cute.Tensor,
|
||||
tma_atom_dQ: cute.CopyAtom,
|
||||
sdQaccum_layout: cute.Layout,
|
||||
sdQ_layout: cute.ComposedLayout,
|
||||
tiled_mma_dsk: cute.TiledMma,
|
||||
scale: cutlass.Float32,
|
||||
):
|
||||
tidx = cute.arch.thread_idx()[0]
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
m_block, head_idx, batch_idx = cute.arch.block_idx()
|
||||
|
||||
# SMEM
|
||||
smem = cutlass.utils.SmemAllocator()
|
||||
storage = smem.allocate(self.shared_storage)
|
||||
swz128 = cute.make_swizzle(3, 4, 3)
|
||||
sdQaccum = storage.sdQaccum.get_tensor(sdQaccum_layout, swizzle=swz128)
|
||||
|
||||
sdQ = storage.sdQ.get_tensor(sdQ_layout.outer, swizzle=sdQ_layout.inner)
|
||||
|
||||
mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
|
||||
mdQ_cur = mdQ[None, None, head_idx, batch_idx]
|
||||
|
||||
thr_mma_dsk = tiled_mma_dsk.get_slice(tidx)
|
||||
dQacc_shape = thr_mma_dsk.partition_shape_C(self.mma_tiler_dsk[:2])
|
||||
tdQtdQ = thr_mma_dsk.make_fragment_C(dQacc_shape)
|
||||
tdQtdQ = cute.make_tensor(tdQtdQ.iterator, tdQtdQ.layout)
|
||||
|
||||
tmem_ld_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32
|
||||
)
|
||||
tiled_tmem_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tdQtdQ)
|
||||
thr_tmem_ld = tiled_tmem_ld.get_slice(tidx)
|
||||
|
||||
cdQ = cute.make_identity_tensor((self.mma_tiler_dsk[0], self.mma_tiler_dsk[1]))
|
||||
tdQcdQ = thr_mma_dsk.partition_C(cdQ)
|
||||
tdQcdQ_tensor = cute.make_tensor(tdQcdQ.iterator, tdQcdQ.layout)
|
||||
tdQrdQ = thr_tmem_ld.partition_D(tdQcdQ_tensor)
|
||||
|
||||
gdQaccum = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (m_block,))
|
||||
|
||||
num_reduce_warps = 4
|
||||
num_reduce_threads = cute.arch.WARP_SIZE * num_reduce_warps
|
||||
|
||||
atom_universal_copy = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(), cutlass.Float32, num_bits_per_copy=128
|
||||
)
|
||||
tiler_mn, layout_tv = cute.make_layout_tv(
|
||||
thr_layout=cute.make_layout(shape=num_reduce_threads, stride=1),
|
||||
val_layout=cute.make_layout(shape=4, stride=1),
|
||||
)
|
||||
G2S_tiled_copy_dQaccum = cute.make_tiled_copy(
|
||||
atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn
|
||||
)
|
||||
|
||||
smem_thr_copy_g2s = G2S_tiled_copy_dQaccum.get_slice(tidx)
|
||||
|
||||
# S->R
|
||||
tdQrdQ_t2r = cute.make_fragment(tdQrdQ.shape, cutlass.Float32)
|
||||
tiled_smem_store_s2r = cute.make_tiled_copy(
|
||||
atom_universal_copy, layout_tv=layout_tv, tiler_mn=tiler_mn
|
||||
)
|
||||
|
||||
s2r_thr_copy_dQaccum = tiled_smem_store_s2r.get_slice(tidx)
|
||||
tdQsdQ_s2r = s2r_thr_copy_dQaccum.partition_S(sdQaccum)
|
||||
tdQrdQ_s2r = cute.make_tensor(tdQrdQ_t2r.iterator, tdQrdQ_t2r.shape)
|
||||
|
||||
# R->S
|
||||
smem_copy_atom = sm100_utils_basic.get_smem_store_op(
|
||||
LayoutEnum.ROW_MAJOR, self.dtype, cutlass.Float32, tiled_tmem_ld
|
||||
)
|
||||
tiled_smem_store_r2s = cute.make_tiled_copy(
|
||||
smem_copy_atom,
|
||||
layout_tv=tiled_tmem_ld.layout_dst_tv_tiled,
|
||||
tiler_mn=tiled_tmem_ld.tiler_mn,
|
||||
)
|
||||
tdQsdQ_r2s = thr_tmem_ld.partition_D(thr_mma_dsk.partition_C(sdQ))
|
||||
tdQrdQ_r2s = cute.make_fragment(tdQsdQ_r2s.shape, self.dtype)
|
||||
|
||||
num_stages = cute.size(tdQrdQ_t2r, mode=[1])
|
||||
for stage in cutlass.range_constexpr(num_stages):
|
||||
# G->S
|
||||
gdQaccum_stage = cute.local_tile(
|
||||
gdQaccum,
|
||||
(self.tile_m * 32,),
|
||||
(stage,),
|
||||
)
|
||||
|
||||
gdQaccum_layout_g2s = cute.make_layout(shape=(self.tile_m * 32, 1), stride=(1, 0))
|
||||
gdQaccum_stage_g2s = cute.make_tensor(
|
||||
cute.recast_ptr(gdQaccum_stage.iterator, swizzle_=swz128), gdQaccum_layout_g2s
|
||||
)
|
||||
|
||||
tdQgdQ = smem_thr_copy_g2s.partition_S(gdQaccum_stage_g2s)
|
||||
tdQsdQ = smem_thr_copy_g2s.partition_D(sdQaccum)
|
||||
|
||||
cute.copy(smem_thr_copy_g2s, tdQgdQ[None, None, 0], tdQsdQ[None, None, 0])
|
||||
|
||||
cute.arch.fence_proxy(
|
||||
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
||||
)
|
||||
cute.arch.barrier(barrier_id=6, number_of_threads=num_reduce_threads)
|
||||
|
||||
# S -> R
|
||||
tdQrdQ_s2r_cpy = tdQrdQ_s2r[None, stage, None, None]
|
||||
tdQsdQ_s2r_p = tdQsdQ_s2r[None, None, 0]
|
||||
tdQrdQ_r2s_cpy = cute.make_tensor(
|
||||
tdQrdQ_s2r_cpy.iterator, cute.make_layout(tdQsdQ_s2r_p.shape)
|
||||
)
|
||||
|
||||
cute.copy(s2r_thr_copy_dQaccum, tdQsdQ_s2r_p, tdQrdQ_r2s_cpy)
|
||||
|
||||
cute.arch.fence_proxy(
|
||||
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
||||
)
|
||||
cute.arch.barrier(barrier_id=7, number_of_threads=num_reduce_threads)
|
||||
|
||||
# R->S
|
||||
tdQrdQ_r2s_cpy = cute.make_tensor(
|
||||
cute.recast_ptr(tdQrdQ_r2s_cpy.iterator),
|
||||
tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].shape,
|
||||
)
|
||||
dQ_vec = tdQrdQ_r2s_cpy.load() * scale
|
||||
tdQrdQ_r2s[((None, 0), stage, 0, 0, 0)].store(dQ_vec.to(self.dtype))
|
||||
|
||||
cute.copy(
|
||||
tiled_smem_store_r2s,
|
||||
tdQrdQ_r2s[None, None, None, None, 0],
|
||||
tdQsdQ_r2s[None, None, None, None, 0],
|
||||
)
|
||||
cute.arch.fence_proxy(
|
||||
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
|
||||
)
|
||||
cute.arch.barrier(barrier_id=8, number_of_threads=num_reduce_threads)
|
||||
|
||||
# S-> G
|
||||
gdQ = cute.local_tile(mdQ_cur, (self.tile_m, self.tile_hdim), (None, 0))
|
||||
tdQsdQ, tdQgdQ = cpasync.tma_partition(
|
||||
tma_atom_dQ,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
cute.group_modes(sdQ, 0, 2),
|
||||
cute.group_modes(gdQ, 0, 2),
|
||||
)
|
||||
|
||||
cute.copy(tma_atom_dQ, tdQsdQ[None, 0], tdQgdQ[None, m_block])
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# from Cutlass C++ to Cute-DSL.
|
||||
import math
|
||||
import operator
|
||||
from typing import Callable, Type, Optional
|
||||
from typing import Callable, Type, Optional, Literal
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
@@ -27,6 +27,7 @@ class FlashAttentionBackwardPreprocess:
|
||||
self,
|
||||
dtype: Type[cutlass.Numeric],
|
||||
head_dim: int,
|
||||
arch: Literal[80, 90, 100],
|
||||
m_block_size: int = 128,
|
||||
num_threads: int = 128,
|
||||
):
|
||||
@@ -43,6 +44,7 @@ class FlashAttentionBackwardPreprocess:
|
||||
"""
|
||||
self.dtype = dtype
|
||||
self.m_block_size = m_block_size
|
||||
self.arch = arch
|
||||
# padding head_dim to a multiple of 32 as k_block_size
|
||||
hdim_multiple_of = 32
|
||||
self.head_dim_padded = int(math.ceil(head_dim / hdim_multiple_of) * hdim_multiple_of)
|
||||
@@ -213,14 +215,14 @@ class FlashAttentionBackwardPreprocess:
|
||||
|
||||
tile_scheduler = TileScheduler.create(tile_sched_params)
|
||||
work_tile = tile_scheduler.initial_work_tile_info()
|
||||
m_block, num_head, batch_size, _ = work_tile.tile_idx
|
||||
m_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
||||
|
||||
if work_tile.is_valid_tile:
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Get the appropriate tiles for this thread block.
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
seqlen = SeqlenInfoQK.create(
|
||||
batch_size,
|
||||
batch_idx,
|
||||
mO.shape[1],
|
||||
0,
|
||||
mCuSeqlensQ=mCuSeqlensQ,
|
||||
@@ -230,16 +232,18 @@ class FlashAttentionBackwardPreprocess:
|
||||
)
|
||||
|
||||
if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
|
||||
mO_cur = mO[batch_size, None, num_head, None]
|
||||
mdO_cur = mdO[batch_size, None, num_head, None]
|
||||
mdPsum_cur = mdPsum[batch_size, num_head, None]
|
||||
mO_cur = mO[batch_idx, None, head_idx, None]
|
||||
mdO_cur = mdO[batch_idx, None, head_idx, None]
|
||||
mdPsum_cur = mdPsum[batch_idx, head_idx, None]
|
||||
headdim_v = mO.shape[3]
|
||||
else:
|
||||
mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, num_head, None])
|
||||
mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, num_head, None])
|
||||
mO_cur = cute.domain_offset((seqlen.offset_q, 0), mO[None, head_idx, None])
|
||||
mdO_cur = cute.domain_offset((seqlen.offset_q, 0), mdO[None, head_idx, None])
|
||||
|
||||
padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size
|
||||
mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[num_head, None])
|
||||
padded_offset_q = seqlen.offset_q + batch_idx * self.m_block_size
|
||||
if cutlass.const_expr(self.arch >= 90):
|
||||
padded_offset_q = padded_offset_q // self.m_block_size * self.m_block_size
|
||||
mdPsum_cur = cute.domain_offset((padded_offset_q,), mdPsum[head_idx, None])
|
||||
headdim_v = mO.shape[2]
|
||||
|
||||
blkOdO_shape = (self.m_block_size, self.head_dim_padded)
|
||||
@@ -268,9 +272,9 @@ class FlashAttentionBackwardPreprocess:
|
||||
|
||||
if cutlass.const_expr(mLSE is not None):
|
||||
if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
|
||||
mLSE_cur = mLSE[batch_size, num_head, None]
|
||||
mLSE_cur = mLSE[batch_idx, head_idx, None]
|
||||
else:
|
||||
mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[num_head, None])
|
||||
mLSE_cur = cute.domain_offset((seqlen.offset_q,), mLSE[head_idx, None])
|
||||
|
||||
gLSE = cute.local_tile(mLSE_cur, (self.m_block_size,), (m_block,))
|
||||
lse = Float32.inf
|
||||
@@ -323,11 +327,10 @@ class FlashAttentionBackwardPreprocess:
|
||||
# Clear dQaccum
|
||||
if cutlass.const_expr(mdQaccum is not None):
|
||||
if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
|
||||
mdQaccum_cur = mdQaccum[batch_size, num_head, None]
|
||||
mdQaccum_cur = mdQaccum[batch_idx, head_idx, None]
|
||||
else:
|
||||
padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size
|
||||
mdQaccum_cur = cute.domain_offset(
|
||||
(padded_offset_q * self.head_dim_padded,), mdQaccum[num_head, None]
|
||||
(padded_offset_q * self.head_dim_padded,), mdQaccum[head_idx, None]
|
||||
)
|
||||
|
||||
# HACK: Compiler doesn't seem to recognize that padding
|
||||
@@ -352,10 +355,9 @@ class FlashAttentionBackwardPreprocess:
|
||||
|
||||
if cutlass.const_expr(mLSE is not None):
|
||||
if cutlass.const_expr(not seqlen.has_cu_seqlens_q):
|
||||
mLSElog2_cur = mLSElog2[batch_size, num_head, None]
|
||||
mLSElog2_cur = mLSElog2[batch_idx, head_idx, None]
|
||||
else:
|
||||
padded_offset_q = seqlen.offset_q + batch_size * self.m_block_size
|
||||
mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[num_head, None])
|
||||
mLSElog2_cur = cute.domain_offset((padded_offset_q,), mLSElog2[head_idx, None])
|
||||
|
||||
gLSElog2 = cute.local_tile(mLSElog2_cur, (self.m_block_size,), (m_block,))
|
||||
LOG2_E = math.log2(math.e)
|
||||
|
||||
@@ -25,6 +25,7 @@ from flash_attn.cute.tile_scheduler import (
|
||||
TileSchedulerArguments,
|
||||
SingleTileScheduler,
|
||||
SingleTileLPTBwdScheduler, # noqa
|
||||
SingleTileVarlenScheduler,
|
||||
ParamsBase,
|
||||
)
|
||||
|
||||
@@ -78,7 +79,7 @@ class FlashAttentionBackwardSm100:
|
||||
self.tile_n = tile_n
|
||||
|
||||
# CTA tiler
|
||||
self.cta_tiler = (tile_m, tile_n, self.tile_hdim)
|
||||
self.cta_tiler = (tile_n, tile_m, self.tile_hdim)
|
||||
# S = K @ Q.T
|
||||
self.mma_tiler_kq = (tile_n, tile_m, self.tile_hdim)
|
||||
# dP = V @ dO.T
|
||||
@@ -99,7 +100,6 @@ class FlashAttentionBackwardSm100:
|
||||
self.is_local = is_local
|
||||
self.qhead_per_kvhead = qhead_per_kvhead
|
||||
self.pack_gqa = False
|
||||
self.use_tma_store = True
|
||||
self.deterministic = deterministic
|
||||
|
||||
# Score mod and mask mod support
|
||||
@@ -353,7 +353,7 @@ class FlashAttentionBackwardSm100:
|
||||
self.num_epi_stages = max(1, (self.tile_hdim // 2) // self.sdKV_epi_tile[1])
|
||||
self.sdKV_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages
|
||||
# TODO: dK and dV could have different shapes
|
||||
if const_expr(self.qhead_per_kvhead == 1):
|
||||
if const_expr(not self.dKV_postprocess):
|
||||
self.sdKV_layout = sm100_utils_basic.make_smem_layout_epi(
|
||||
self.dk_dtype,
|
||||
LayoutEnum.ROW_MAJOR,
|
||||
@@ -391,9 +391,6 @@ class FlashAttentionBackwardSm100:
|
||||
# Block-sparse tensors (Q direction - for iterating m_blocks per n_block):
|
||||
blocksparse_tensors: Optional[BlockSparseTensors] = None,
|
||||
):
|
||||
assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), (
|
||||
"Variable sequence length is not supported yet in FlashAttentionBackwardSm100"
|
||||
)
|
||||
self.q_dtype = mQ.element_type
|
||||
self.k_dtype = mK.element_type
|
||||
self.v_dtype = mV.element_type
|
||||
@@ -405,7 +402,12 @@ class FlashAttentionBackwardSm100:
|
||||
self.dv_dtype = mdV.element_type
|
||||
self.ds_dtype = self.q_dtype
|
||||
|
||||
if const_expr(self.qhead_per_kvhead > 1):
|
||||
self.is_varlen_k = mCuSeqlensK is not None or mSeqUsedK is not None
|
||||
self.is_varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None
|
||||
self.use_tma_store = not (self.qhead_per_kvhead == 1 and mCuSeqlensK is not None)
|
||||
self.dKV_postprocess = self.qhead_per_kvhead > 1
|
||||
|
||||
if const_expr(self.dKV_postprocess):
|
||||
assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA"
|
||||
assert self.dv_dtype.width == 32, "Must accumulate dV in float precision for GQA"
|
||||
|
||||
@@ -429,21 +431,30 @@ class FlashAttentionBackwardSm100:
|
||||
)
|
||||
]
|
||||
|
||||
layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b)
|
||||
mQ, mK, mV, mdO = [utils.select(t, mode=layout_transpose) for t in (mQ, mK, mV, mdO)]
|
||||
LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b)
|
||||
# (b, s, n, h) --> (s, h, n, b) or (t, n, h) -> (t, h, n)
|
||||
QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
|
||||
mQ, mdO = [utils.select(t, mode=QO_layout_transpose) for t in (mQ, mdO)]
|
||||
|
||||
KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
|
||||
mK, mV = [utils.select(t, mode=KV_layout_transpose) for t in (mK, mV)]
|
||||
|
||||
# (b, n, s) --> (s, n, b) or (n, t) --> (t, n)
|
||||
LSE_dPsum_dQaccum_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
|
||||
mLSE, mdPsum, mdQaccum = [
|
||||
utils.select(t, mode=LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum)
|
||||
]
|
||||
if const_expr(self.qhead_per_kvhead == 1):
|
||||
layout_dKV_transpose = layout_transpose
|
||||
|
||||
if const_expr(not self.dKV_postprocess):
|
||||
layout_dKV_transpose = KV_layout_transpose
|
||||
else:
|
||||
layout_dKV_transpose = LSE_dPsum_dQaccum_transpose
|
||||
mdK, mdV = [utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)]
|
||||
dO_transpose = [1, 0, 2, 3] # (s, h, n, b) --> (h, s, n, b)
|
||||
# (s, h, n, b) --> (h, s, n, b) or (t, h, n) --> (h, t, n)
|
||||
dO_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensQ is None) else [1, 0, 2]
|
||||
mdO = utils.select(mdO, mode=dO_transpose)
|
||||
|
||||
semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b)
|
||||
# (b, n, block, stage) -> (block, stage, n, b)
|
||||
semaphore_transpose = [2, 3, 1, 0]
|
||||
if const_expr(self.deterministic):
|
||||
assert mdQ_semaphore is not None
|
||||
mdQ_semaphore = utils.select(mdQ_semaphore, mode=semaphore_transpose)
|
||||
@@ -478,7 +489,7 @@ class FlashAttentionBackwardSm100:
|
||||
self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
|
||||
self.is_q_do_mcast = self.num_mcast_ctas_b > 1
|
||||
|
||||
if const_expr(self.qhead_per_kvhead == 1):
|
||||
if const_expr(not self.dKV_postprocess):
|
||||
self.mdK_layout_enum = LayoutEnum.from_tensor(mdK)
|
||||
self.mdV_layout_enum = LayoutEnum.from_tensor(mdV)
|
||||
dK_major_mode = self.mdK_layout_enum.mma_major_mode()
|
||||
@@ -488,7 +499,7 @@ class FlashAttentionBackwardSm100:
|
||||
if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K):
|
||||
raise RuntimeError("The layout of mdV is wrong")
|
||||
|
||||
if const_expr(self.use_tma_store and self.qhead_per_kvhead == 1):
|
||||
if const_expr(self.use_tma_store and not self.dKV_postprocess):
|
||||
tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp()
|
||||
tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom(
|
||||
tma_copy_op_dKV,
|
||||
@@ -510,7 +521,7 @@ class FlashAttentionBackwardSm100:
|
||||
tma_atom_dV = None
|
||||
tma_atom_dK = None
|
||||
|
||||
if const_expr(self.qhead_per_kvhead == 1):
|
||||
if const_expr(not self.dKV_postprocess):
|
||||
thr_layout_r2s_dKV = cute.make_ordered_layout((128, 1), order=(1, 0)) # 128 threads
|
||||
val_layout_r2s_dKV = cute.make_ordered_layout(
|
||||
(1, 128 // self.dk_dtype.width), order=(1, 0)
|
||||
@@ -589,29 +600,36 @@ class FlashAttentionBackwardSm100:
|
||||
self.tma_copy_bytes["dKacc"] = self.tile_n * self.dK_reduce_ncol * Float32.width // 8
|
||||
|
||||
# TileScheduler = SingleTileScheduler
|
||||
if const_expr(self.deterministic):
|
||||
if const_expr(self.is_varlen_k):
|
||||
TileScheduler = SingleTileVarlenScheduler
|
||||
elif const_expr(self.deterministic):
|
||||
TileScheduler = SingleTileLPTBwdScheduler
|
||||
else:
|
||||
TileScheduler = SingleTileScheduler
|
||||
# reads n_blocks right-to-left
|
||||
self.spt = (self.is_causal or self.is_local) and self.deterministic
|
||||
tile_sched_args = TileSchedulerArguments(
|
||||
cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]),
|
||||
cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), # num_blocks
|
||||
cute.size(mQ.shape[2]), # num_heads = num_query_heads
|
||||
cute.size(mK.shape[3]),
|
||||
cute.size(mK.shape[3])
|
||||
if const_expr(mCuSeqlensK is None)
|
||||
else cute.size(mCuSeqlensK.shape[0] - 1), # num_batches
|
||||
1, # num_splits
|
||||
cute.size(mQ.shape[0]), # pass seqlen_q for seqlen_k
|
||||
mQ.shape[1],
|
||||
mV.shape[1],
|
||||
total_q=cute.size(mQ.shape[0]),
|
||||
tile_shape_mn=self.cta_tiler[:2],
|
||||
cute.size(mQ.shape[0]), # pass seqlen_q or total_q for seqlen_k
|
||||
mQ.shape[1], # headdim
|
||||
mV.shape[1], # headdim_v
|
||||
total_q=cute.size(mK.shape[0]) # pass total_k for total_q
|
||||
if const_expr(mCuSeqlensK is not None)
|
||||
else cute.size(mK.shape[0]) * cute.size(mK.shape[3]),
|
||||
tile_shape_mn=self.cta_tiler[:2], # (tile_n, tile_m)
|
||||
cluster_shape_mn=self.cluster_shape_mnk[:2],
|
||||
mCuSeqlensQ=None,
|
||||
mSeqUsedQ=None,
|
||||
qhead_per_kvhead_packgqa=1,
|
||||
mCuSeqlensQ=mCuSeqlensK,
|
||||
mSeqUsedQ=mSeqUsedK,
|
||||
qhead_per_kvhead_packgqa=1, # pack_gqa disabled for bwd
|
||||
element_size=self.k_dtype.width // 8,
|
||||
is_persistent=self.is_persistent,
|
||||
is_persistent=self.is_persistent, # persistent mode not tested
|
||||
lpt=self.spt,
|
||||
head_swizzle=self.deterministic,
|
||||
)
|
||||
|
||||
tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
|
||||
@@ -718,6 +736,11 @@ class FlashAttentionBackwardSm100:
|
||||
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
|
||||
self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None)
|
||||
|
||||
if const_expr(self.use_block_sparsity or aux_tensors is not None):
|
||||
assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), (
|
||||
"Variable sequence length is not supported yet for blocksparse or aux tensors in bwd"
|
||||
)
|
||||
|
||||
self.kernel(
|
||||
tma_tensor_Q,
|
||||
tma_tensor_K,
|
||||
@@ -733,6 +756,10 @@ class FlashAttentionBackwardSm100:
|
||||
mdQ_semaphore,
|
||||
mdK_semaphore,
|
||||
mdV_semaphore,
|
||||
mCuSeqlensQ,
|
||||
mCuSeqlensK,
|
||||
mSeqUsedQ,
|
||||
mSeqUsedK,
|
||||
tma_atom_Q,
|
||||
tma_atom_K,
|
||||
tma_atom_V,
|
||||
@@ -794,6 +821,10 @@ class FlashAttentionBackwardSm100:
|
||||
mdQ_semaphore: Optional[cute.Tensor],
|
||||
mdK_semaphore: Optional[cute.Tensor],
|
||||
mdV_semaphore: Optional[cute.Tensor],
|
||||
mCuSeqlensQ: Optional[cute.Tensor],
|
||||
mCuSeqlensK: Optional[cute.Tensor],
|
||||
mSeqUsedQ: Optional[cute.Tensor],
|
||||
mSeqUsedK: Optional[cute.Tensor],
|
||||
tma_atom_Q: cute.CopyAtom,
|
||||
tma_atom_K: cute.CopyAtom,
|
||||
tma_atom_V: cute.CopyAtom,
|
||||
@@ -986,7 +1017,7 @@ class FlashAttentionBackwardSm100:
|
||||
)
|
||||
sLSE = storage.sLSE.get_tensor(sLSE_layout)
|
||||
sdPsum = storage.sdPsum.get_tensor(sdPsum_layout)
|
||||
if const_expr(self.qhead_per_kvhead == 1):
|
||||
if const_expr(not self.dKV_postprocess):
|
||||
sdV = storage.sdO.get_tensor(
|
||||
sdKV_layout.outer, swizzle=sdKV_layout.inner, dtype=self.dv_dtype
|
||||
)
|
||||
@@ -1054,10 +1085,12 @@ class FlashAttentionBackwardSm100:
|
||||
SeqlenInfoQK.create,
|
||||
seqlen_q_static=mQ.shape[0],
|
||||
seqlen_k_static=mK.shape[0],
|
||||
mCuSeqlensQ=None,
|
||||
mCuSeqlensK=None,
|
||||
mSeqUsedQ=None,
|
||||
mSeqUsedK=None,
|
||||
mCuSeqlensQ=mCuSeqlensQ,
|
||||
mCuSeqlensK=mCuSeqlensK,
|
||||
mSeqUsedQ=mSeqUsedQ,
|
||||
mSeqUsedK=mSeqUsedK,
|
||||
tile_m=self.tile_m,
|
||||
tile_n=self.tile_n,
|
||||
)
|
||||
TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params)
|
||||
|
||||
@@ -1294,12 +1327,17 @@ class FlashAttentionBackwardSm100:
|
||||
seqlen, n_block // self.cluster_shape_mnk[0]
|
||||
)
|
||||
head_idx_kv = head_idx // self.qhead_per_kvhead
|
||||
mQ_cur = mQ[None, None, head_idx, batch_idx]
|
||||
mK_cur = mK[None, None, head_idx_kv, batch_idx]
|
||||
mV_cur = mV[None, None, head_idx_kv, batch_idx]
|
||||
mdO_cur = mdO[None, None, head_idx, batch_idx]
|
||||
mLSE_cur = mLSE[None, head_idx, batch_idx]
|
||||
mPsum_cur = mdPsum[None, head_idx, batch_idx]
|
||||
mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
|
||||
mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv]
|
||||
mV_cur = seqlen.offset_batch_K(mV, batch_idx, dim=3)[None, None, head_idx_kv]
|
||||
if const_expr(not seqlen.has_cu_seqlens_q):
|
||||
mdO_cur = mdO[None, None, head_idx, batch_idx]
|
||||
else:
|
||||
mdO_cur = cute.domain_offset((0, seqlen.offset_q), mdO[None, None, head_idx])
|
||||
mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2, padded=True)[None, head_idx]
|
||||
mdPsum_cur = seqlen.offset_batch_Q(mdPsum, batch_idx, dim=2, padded=True)[
|
||||
None, head_idx
|
||||
]
|
||||
|
||||
gK = cute.local_tile(mK_cur, cute.select(self.mma_tiler_kq, mode=[0, 2]), (n_block, 0))
|
||||
tSgK = thr_mma_S.partition_A(gK)
|
||||
@@ -1308,7 +1346,7 @@ class FlashAttentionBackwardSm100:
|
||||
gQ = cute.local_tile(mQ_cur, cute.select(self.mma_tiler_kq, mode=[1, 2]), (None, 0))
|
||||
tSgQ = thr_mma_S.partition_B(gQ)
|
||||
gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (None,))
|
||||
gdPsum = cute.local_tile(mPsum_cur, (self.tile_m,), (None,))
|
||||
gdPsum = cute.local_tile(mdPsum_cur, (self.tile_m,), (None,))
|
||||
gdO = cute.local_tile(mdO_cur, cute.select(self.mma_tiler_pdo, mode=[1, 2]), (0, None))
|
||||
tdPgdO = thr_mma_dV.partition_B(gdO)
|
||||
|
||||
@@ -1363,7 +1401,10 @@ class FlashAttentionBackwardSm100:
|
||||
)
|
||||
process_tile = total_m_block_cnt > Int32(0)
|
||||
else:
|
||||
process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
|
||||
process_tile = (
|
||||
const_expr(not self.is_local and not self.is_varlen_q)
|
||||
or m_block_min < m_block_max
|
||||
)
|
||||
|
||||
if process_tile:
|
||||
if const_expr(self.use_block_sparsity):
|
||||
@@ -1616,7 +1657,10 @@ class FlashAttentionBackwardSm100:
|
||||
process_tile = block_iter_count > Int32(0)
|
||||
else:
|
||||
block_iter_count = m_block_max - m_block_min
|
||||
process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
|
||||
process_tile = (
|
||||
const_expr(not self.is_local and not self.is_varlen_q)
|
||||
or m_block_min < m_block_max
|
||||
)
|
||||
|
||||
if process_tile:
|
||||
accumulate_dK = False
|
||||
@@ -2055,7 +2099,10 @@ class FlashAttentionBackwardSm100:
|
||||
)
|
||||
process_tile = loop_count > Int32(0)
|
||||
else:
|
||||
process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
|
||||
process_tile = (
|
||||
const_expr(not self.is_local and not self.is_varlen_q)
|
||||
or m_block_min < m_block_max
|
||||
)
|
||||
loop_count = m_block_max - m_block_min
|
||||
|
||||
# Mainloop
|
||||
@@ -2271,6 +2318,7 @@ class FlashAttentionBackwardSm100:
|
||||
batch_idx,
|
||||
head_idx,
|
||||
n_block,
|
||||
seqlen,
|
||||
thr_mma_dV,
|
||||
thr_mma_dK,
|
||||
tdVtdV,
|
||||
@@ -2289,6 +2337,7 @@ class FlashAttentionBackwardSm100:
|
||||
batch_idx,
|
||||
head_idx,
|
||||
n_block,
|
||||
seqlen,
|
||||
thr_mma_dV,
|
||||
tdVtdV,
|
||||
mdV_tma_tensor,
|
||||
@@ -2307,6 +2356,7 @@ class FlashAttentionBackwardSm100:
|
||||
batch_idx,
|
||||
head_idx,
|
||||
n_block,
|
||||
seqlen,
|
||||
thr_mma_dK,
|
||||
tdKtdK,
|
||||
mdK_tma_tensor,
|
||||
@@ -2315,15 +2365,15 @@ class FlashAttentionBackwardSm100:
|
||||
thr_copy_r2s_dKV,
|
||||
pipeline_dKV,
|
||||
consumer_state_dKV,
|
||||
softmax_scale if const_expr(self.qhead_per_kvhead == 1) else None,
|
||||
softmax_scale if const_expr(not self.dKV_postprocess) else None,
|
||||
int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id
|
||||
mdK_semaphore,
|
||||
)
|
||||
# Zero dK/dV for empty tiles (local attention or block sparsity)
|
||||
# When total_m_block_cnt == 0 for block sparsity, no Q tiles contribute to this KV tile
|
||||
if const_expr(self.qhead_per_kvhead == 1):
|
||||
if const_expr(not self.dKV_postprocess):
|
||||
should_zero_dKV = False
|
||||
if const_expr(self.is_local):
|
||||
if const_expr(self.is_local or seqlen.has_cu_seqlens_q):
|
||||
should_zero_dKV = m_block_min >= m_block_max
|
||||
if const_expr(self.use_block_sparsity):
|
||||
# For block sparsity, zero when no m_blocks contribute to this n_block
|
||||
@@ -2338,8 +2388,8 @@ class FlashAttentionBackwardSm100:
|
||||
128, # num_threads
|
||||
)
|
||||
gmem_thr_copy_zero_dKV = gmem_tiled_copy_zero_dKV.get_slice(dp_idx)
|
||||
mdV_cur = mdV[None, None, head_idx, batch_idx]
|
||||
mdK_cur = mdK[None, None, head_idx, batch_idx]
|
||||
mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx]
|
||||
mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx]
|
||||
gdK = cute.local_tile(mdK_cur, (self.tile_n, self.tile_hdim), (n_block, 0))
|
||||
gdV = cute.local_tile(mdV_cur, (self.tile_n, self.tile_hdimv), (n_block, 0))
|
||||
tdKgdK = gmem_thr_copy_zero_dKV.partition_D(gdK)
|
||||
@@ -2415,7 +2465,12 @@ class FlashAttentionBackwardSm100:
|
||||
m_block_min, m_block_max = block_info.get_m_block_min_max(
|
||||
seqlen, n_block // self.cluster_shape_mnk[0]
|
||||
)
|
||||
mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
|
||||
if const_expr(not seqlen.has_cu_seqlens_q):
|
||||
mdQaccum_cur = mdQaccum[None, head_idx, batch_idx]
|
||||
else:
|
||||
mdQaccum_cur = cute.domain_offset(
|
||||
(seqlen.padded_offset_q * self.tile_hdim,), mdQaccum[None, head_idx]
|
||||
)
|
||||
gdQaccum_ = cute.local_tile(mdQaccum_cur, (self.tile_m * self.tile_hdim,), (None,))
|
||||
# (M * K / STAGE, STAGE, _)
|
||||
gdQaccum = cute.flat_divide(
|
||||
@@ -2446,7 +2501,10 @@ class FlashAttentionBackwardSm100:
|
||||
)
|
||||
process_tile = loop_count > Int32(0)
|
||||
else:
|
||||
process_tile = const_expr(not self.is_local) or m_block_min < m_block_max
|
||||
process_tile = (
|
||||
const_expr(not self.is_local and not self.is_varlen_q)
|
||||
or m_block_min < m_block_max
|
||||
)
|
||||
loop_count = m_block_max - m_block_min
|
||||
|
||||
# dQacc_reduce mainloop
|
||||
@@ -2580,6 +2638,7 @@ class FlashAttentionBackwardSm100:
|
||||
batch_idx: Int32,
|
||||
head_idx: Int32,
|
||||
n_block: Int32,
|
||||
seqlen,
|
||||
thr_mma_dV: cute.core.ThrMma,
|
||||
thr_mma_dK: cute.core.ThrMma,
|
||||
tdVtdV: cute.Tensor,
|
||||
@@ -2596,8 +2655,8 @@ class FlashAttentionBackwardSm100:
|
||||
num_wg = cute.arch.WARP_SIZE * len(self.compute_warp_ids) // 128
|
||||
|
||||
assert self.qhead_per_kvhead == 1, "This epilogue path is only for MHA"
|
||||
mdV_cur = mdV[None, None, head_idx, batch_idx]
|
||||
mdK_cur = mdK[None, None, head_idx, batch_idx]
|
||||
mdV_cur = seqlen.offset_batch_K(mdV, batch_idx, dim=3)[None, None, head_idx]
|
||||
mdK_cur = seqlen.offset_batch_K(mdK, batch_idx, dim=3)[None, None, head_idx]
|
||||
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(16)), Float32
|
||||
@@ -2647,7 +2706,8 @@ class FlashAttentionBackwardSm100:
|
||||
tdVgdV_r2g_p = thr_tmem_ld_dV.partition_D(tdVgdV)
|
||||
tdVgdV_r2g = self.split_wg(tdVgdV_r2g_p, wg_idx, num_wg)
|
||||
|
||||
cute.copy(tiled_gmem_store_dV, tdVrdV_r2s, tdVgdV_r2g)
|
||||
if tidx < seqlen.seqlen_k - self.tile_n * n_block:
|
||||
cute.copy(tiled_gmem_store_dV, tdVrdV_r2s, tdVgdV_r2g)
|
||||
|
||||
cute.arch.sync_warp()
|
||||
with cute.arch.elect_one():
|
||||
@@ -2700,7 +2760,8 @@ class FlashAttentionBackwardSm100:
|
||||
tdKgdK_r2g_p = thr_tmem_ld_dK.partition_D(tdKgdK)
|
||||
tdKgdK_r2g = self.split_wg(tdKgdK_r2g_p, wg_idx, num_wg)
|
||||
|
||||
cute.copy(tiled_gmem_store_dK, tdKrdK_r2s, tdKgdK_r2g)
|
||||
if tidx < seqlen.seqlen_k - self.tile_n * n_block:
|
||||
cute.copy(tiled_gmem_store_dK, tdKrdK_r2s, tdKgdK_r2g)
|
||||
|
||||
cute.arch.sync_warp()
|
||||
with cute.arch.elect_one():
|
||||
@@ -2715,6 +2776,7 @@ class FlashAttentionBackwardSm100:
|
||||
batch_idx: Int32,
|
||||
head_idx: Int32,
|
||||
n_block: Int32,
|
||||
seqlen,
|
||||
thr_mma: cute.core.ThrMma,
|
||||
tdKVtdKV: cute.Tensor,
|
||||
mdKV: cute.Tensor,
|
||||
@@ -2734,7 +2796,7 @@ class FlashAttentionBackwardSm100:
|
||||
num_wg = num_compute_threads // 128
|
||||
leader_warp = (cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4) == 0
|
||||
|
||||
if const_expr(self.qhead_per_kvhead == 1):
|
||||
if const_expr(not self.dKV_postprocess):
|
||||
sdKV = sdKV[None, None, wg_idx] # (tile_n, 64) for bf16
|
||||
else:
|
||||
sdKV = sdKV[None, wg_idx] # (tile_n * 32) for fp32
|
||||
@@ -2743,7 +2805,8 @@ class FlashAttentionBackwardSm100:
|
||||
tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV)
|
||||
|
||||
head_idx_kv = head_idx // self.qhead_per_kvhead
|
||||
if const_expr(self.qhead_per_kvhead == 1):
|
||||
if const_expr(not self.dKV_postprocess):
|
||||
assert not seqlen.has_cu_seqlens_k, "varlen uses non tma store path"
|
||||
mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim)
|
||||
gdKV_p = cute.local_tile(
|
||||
mdKV_cur, (self.tile_n, self.tile_hdim), (n_block, 0)
|
||||
@@ -2753,7 +2816,12 @@ class FlashAttentionBackwardSm100:
|
||||
gdKV, self.sdKV_epi_tile, (0, None)
|
||||
) # (tile_n, 64, epi_stage = (hdim / 2) / 64)
|
||||
else:
|
||||
mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim)
|
||||
if const_expr(not seqlen.has_cu_seqlens_k):
|
||||
mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim)
|
||||
else:
|
||||
mdKV_cur = cute.domain_offset(
|
||||
(seqlen.padded_offset_k * self.tile_hdim,), mdKV[None, head_idx_kv]
|
||||
)
|
||||
gdKV_p = cute.local_tile(
|
||||
mdKV_cur, (self.tile_n * self.tile_hdim,), (n_block,)
|
||||
) # (tile_n * hdim)
|
||||
@@ -2768,7 +2836,7 @@ class FlashAttentionBackwardSm100:
|
||||
if const_expr(deterministic_KV):
|
||||
mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx]
|
||||
|
||||
if const_expr(self.qhead_per_kvhead == 1):
|
||||
if const_expr(not self.dKV_postprocess):
|
||||
tdKVsdKV, tdKVgdKV = cpasync.tma_partition(
|
||||
tma_atom_dKV,
|
||||
0, # no multicast
|
||||
@@ -2842,7 +2910,7 @@ class FlashAttentionBackwardSm100:
|
||||
|
||||
# SMEM -> GMEM
|
||||
if leader_warp:
|
||||
if const_expr(self.qhead_per_kvhead == 1):
|
||||
if const_expr(not self.dKV_postprocess):
|
||||
cute.copy(tma_atom_dKV, tdKVsdKV, tdKVgdKV[None, epi_stage])
|
||||
else:
|
||||
with cute.arch.elect_one():
|
||||
|
||||
@@ -92,6 +92,8 @@ def _flash_attn_fwd(
|
||||
cu_seqlens_k: Optional[torch.Tensor] = None,
|
||||
seqused_q: Optional[torch.Tensor] = None,
|
||||
seqused_k: Optional[torch.Tensor] = None,
|
||||
max_seqlen_q: Optional[int] = None,
|
||||
max_seqlen_k: Optional[int] = None,
|
||||
page_table: Optional[torch.Tensor] = None,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
@@ -115,8 +117,6 @@ def _flash_attn_fwd(
|
||||
out: Optional[torch.Tensor] = None,
|
||||
lse: Optional[torch.Tensor] = None,
|
||||
aux_tensors: Optional[list[torch.Tensor]] = None,
|
||||
max_seqlen_q: Optional[int] = None,
|
||||
max_seqlen_k: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass for FlashAttention.
|
||||
|
||||
@@ -569,6 +569,8 @@ def _flash_attn_bwd(
|
||||
cu_seqlens_k: Optional[torch.Tensor] = None,
|
||||
seqused_q: Optional[torch.Tensor] = None,
|
||||
seqused_k: Optional[torch.Tensor] = None,
|
||||
max_seqlen_q: Optional[int] = None,
|
||||
max_seqlen_k: Optional[int] = None,
|
||||
deterministic: bool = False,
|
||||
dq: Optional[torch.Tensor] = None,
|
||||
dk: Optional[torch.Tensor] = None,
|
||||
@@ -615,16 +617,19 @@ def _flash_attn_bwd(
|
||||
total_q = batch_size * seqlen_q
|
||||
else:
|
||||
batch_size = cu_seqlens_q.shape[0] - 1
|
||||
seqlen_q = None
|
||||
total_q = q.shape[0]
|
||||
seqlen_q = max_seqlen_q if max_seqlen_q is not None else total_q
|
||||
|
||||
if cu_seqlens_k is None:
|
||||
batch_size, seqlen_k = k.shape[:2]
|
||||
total_k = batch_size * seqlen_k
|
||||
else:
|
||||
batch_size = cu_seqlens_k.shape[0] - 1
|
||||
seqlen_k = None
|
||||
total_k = k.shape[0]
|
||||
seqlen_k = max_seqlen_k if max_seqlen_k is not None else total_k
|
||||
|
||||
seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
|
||||
seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size
|
||||
|
||||
num_head_kv = k.shape[-2]
|
||||
head_dim_v = v.shape[-1]
|
||||
@@ -724,7 +729,6 @@ def _flash_attn_bwd(
|
||||
head_dim_rounded = (head_dim + 32 - 1) // 32 * 32
|
||||
|
||||
if cu_seqlens_q is None:
|
||||
seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size
|
||||
dq_accum = torch.empty(
|
||||
batch_size,
|
||||
num_head,
|
||||
@@ -748,10 +752,10 @@ def _flash_attn_bwd(
|
||||
dpsum = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
|
||||
lse_log2 = torch.empty(num_head, total_q_rounded_padded, dtype=torch.float32, device=device)
|
||||
|
||||
if qhead_per_kvhead > 1:
|
||||
dKV_postprocess = qhead_per_kvhead > 1
|
||||
if dKV_postprocess:
|
||||
head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32
|
||||
if cu_seqlens_k is None:
|
||||
seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size
|
||||
num_n_blocks = seqlen_k_rounded // n_block_size
|
||||
if cluster_size == 2 and num_n_blocks % cluster_size != 0:
|
||||
seqlen_k_rounded = seqlen_k_rounded + n_block_size
|
||||
@@ -805,7 +809,15 @@ def _flash_attn_bwd(
|
||||
dV_semaphore = None
|
||||
|
||||
# Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum.
|
||||
compile_key_pre = (compute_capability, dtype, head_dim_v, m_block_size, num_threads)
|
||||
compile_key_pre = (
|
||||
compute_capability,
|
||||
dtype,
|
||||
head_dim_v,
|
||||
m_block_size,
|
||||
num_threads,
|
||||
cu_seqlens_q is None,
|
||||
seqused_q is None,
|
||||
)
|
||||
if compile_key_pre not in _flash_attn_bwd.compile_cache_pre:
|
||||
o_tensor, do_tensor = [to_cute_tensor(t) for t in (out, dout)]
|
||||
dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
|
||||
@@ -816,9 +828,11 @@ def _flash_attn_bwd(
|
||||
to_cute_tensor(t, assumed_align=4) if t is not None else None
|
||||
for t in (cu_seqlens_q, seqused_q)
|
||||
]
|
||||
arch = compute_capability * 10
|
||||
fa_bwd_pre = FlashAttentionBackwardPreprocess(
|
||||
dtype,
|
||||
head_dim_v,
|
||||
arch,
|
||||
m_block_size,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
@@ -871,6 +885,10 @@ def _flash_attn_bwd(
|
||||
AtomLayoutNdKV,
|
||||
AtomLayoutMdQ,
|
||||
V_in_regs,
|
||||
cu_seqlens_q is None,
|
||||
cu_seqlens_k is None,
|
||||
seqused_q is None,
|
||||
seqused_k is None,
|
||||
)
|
||||
cute_aux_tensors = None
|
||||
else:
|
||||
@@ -904,6 +922,10 @@ def _flash_attn_bwd(
|
||||
mask_mod_hash,
|
||||
num_aux_tensors,
|
||||
use_block_sparsity,
|
||||
cu_seqlens_q is None,
|
||||
cu_seqlens_k is None,
|
||||
seqused_q is None,
|
||||
seqused_k is None,
|
||||
)
|
||||
num_threads = 384
|
||||
if compile_key not in _flash_attn_bwd.compile_cache:
|
||||
@@ -913,7 +935,7 @@ def _flash_attn_bwd(
|
||||
dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [
|
||||
to_cute_tensor(t) for t in (dq_accum, dpsum, lse_log2)
|
||||
]
|
||||
if qhead_per_kvhead > 1:
|
||||
if dKV_postprocess:
|
||||
dk_accum_tensor, dv_accum_tensor = [
|
||||
to_cute_tensor(t) for t in (dk_accum, dv_accum)
|
||||
]
|
||||
@@ -1011,8 +1033,8 @@ def _flash_attn_bwd(
|
||||
lse_log2_tensor,
|
||||
dpsum_tensor,
|
||||
dq_accum_tensor,
|
||||
dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor,
|
||||
dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor,
|
||||
dk_tensor if not dKV_postprocess else dk_accum_tensor,
|
||||
dv_tensor if not dKV_postprocess else dv_accum_tensor,
|
||||
softmax_scale,
|
||||
current_stream,
|
||||
cu_seqlens_q_tensor,
|
||||
@@ -1049,8 +1071,8 @@ def _flash_attn_bwd(
|
||||
lse_log2,
|
||||
dpsum,
|
||||
dq_accum,
|
||||
dk if qhead_per_kvhead == 1 else dk_accum,
|
||||
dv if qhead_per_kvhead == 1 else dv_accum,
|
||||
dk if not dKV_postprocess else dk_accum,
|
||||
dv if not dKV_postprocess else dv_accum,
|
||||
softmax_scale,
|
||||
current_stream,
|
||||
cu_seqlens_q,
|
||||
@@ -1069,7 +1091,19 @@ def _flash_attn_bwd(
|
||||
|
||||
num_threads = 256 if compute_capability == 9 else 128
|
||||
# Postprocess kernel: convert dq_accum from float32 to dq in bf16/fp16
|
||||
compile_key_post = (dtype, head_dim, m_block_size, num_threads, AtomLayoutMdQ, dQ_swapAB)
|
||||
compile_key_post = (
|
||||
compute_capability,
|
||||
dtype,
|
||||
head_dim,
|
||||
m_block_size,
|
||||
num_threads,
|
||||
AtomLayoutMdQ,
|
||||
dQ_swapAB,
|
||||
cu_seqlens_q is None,
|
||||
cu_seqlens_k is None,
|
||||
seqused_q is None,
|
||||
seqused_k is None,
|
||||
)
|
||||
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
|
||||
dq_accum_tensor = to_cute_tensor(dq_accum)
|
||||
dq_tensor = to_cute_tensor(dq)
|
||||
@@ -1101,9 +1135,21 @@ def _flash_attn_bwd(
|
||||
current_stream,
|
||||
)
|
||||
|
||||
if qhead_per_kvhead > 1:
|
||||
if dKV_postprocess:
|
||||
# Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16
|
||||
compile_key_post = (dtype, head_dim, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB)
|
||||
compile_key_post = (
|
||||
compute_capability,
|
||||
dtype,
|
||||
head_dim,
|
||||
n_block_size,
|
||||
num_threads,
|
||||
AtomLayoutNdKV,
|
||||
dKV_swapAB,
|
||||
cu_seqlens_q is None,
|
||||
cu_seqlens_k is None,
|
||||
seqused_q is None,
|
||||
seqused_k is None,
|
||||
)
|
||||
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
|
||||
dk_accum_tensor = to_cute_tensor(dk_accum)
|
||||
dk_tensor = to_cute_tensor(dk)
|
||||
@@ -1111,8 +1157,9 @@ def _flash_attn_bwd(
|
||||
to_cute_tensor(t, assumed_align=4) if t is not None else None
|
||||
for t in (cu_seqlens_k, seqused_k)
|
||||
]
|
||||
arch = compute_capability * 10
|
||||
fa_bwd_post = FlashAttentionBackwardPostprocess(
|
||||
dtype, head_dim, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB
|
||||
dtype, head_dim, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB
|
||||
)
|
||||
# TODO: check @can_implement
|
||||
_flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
|
||||
@@ -1134,12 +1181,17 @@ def _flash_attn_bwd(
|
||||
current_stream,
|
||||
)
|
||||
compile_key_post = (
|
||||
compute_capability,
|
||||
dtype,
|
||||
head_dim_v,
|
||||
n_block_size,
|
||||
num_threads,
|
||||
AtomLayoutNdKV,
|
||||
dKV_swapAB,
|
||||
cu_seqlens_q is None,
|
||||
cu_seqlens_k is None,
|
||||
seqused_q is None,
|
||||
seqused_k is None,
|
||||
)
|
||||
if compile_key_post not in _flash_attn_bwd.compile_cache_post:
|
||||
dv_accum_tensor = to_cute_tensor(dv_accum)
|
||||
@@ -1148,8 +1200,9 @@ def _flash_attn_bwd(
|
||||
to_cute_tensor(t, assumed_align=4) if t is not None else None
|
||||
for t in (cu_seqlens_k, seqused_k)
|
||||
]
|
||||
arch = compute_capability * 10
|
||||
fa_bwd_post = FlashAttentionBackwardPostprocess(
|
||||
dtype, head_dim_v, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB
|
||||
dtype, head_dim_v, arch, n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB
|
||||
)
|
||||
# TODO: check @can_implement
|
||||
_flash_attn_bwd.compile_cache_post[compile_key_post] = cute.compile(
|
||||
@@ -1263,6 +1316,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
|
||||
cu_seqlens_k: Optional[torch.Tensor],
|
||||
seqused_q: Optional[torch.Tensor] = None,
|
||||
seqused_k: Optional[torch.Tensor] = None,
|
||||
max_seqlen_q: Optional[int] = None,
|
||||
max_seqlen_k: Optional[int] = None,
|
||||
page_table: Optional[torch.Tensor] = None,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
@@ -1274,8 +1329,6 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
|
||||
deterministic: bool = False,
|
||||
score_mod: Optional[Callable] = None,
|
||||
aux_tensors: Optional[list] = None,
|
||||
max_seqlen_q: Optional[int] = None,
|
||||
max_seqlen_k: Optional[int] = None,
|
||||
):
|
||||
out, lse = _flash_attn_fwd(
|
||||
q,
|
||||
@@ -1285,6 +1338,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
|
||||
cu_seqlens_k,
|
||||
seqused_q,
|
||||
seqused_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
page_table=page_table,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
@@ -1296,8 +1351,6 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
|
||||
pack_gqa=pack_gqa,
|
||||
score_mod=score_mod,
|
||||
aux_tensors=aux_tensors,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
|
||||
ctx.softmax_scale = softmax_scale
|
||||
@@ -1305,12 +1358,13 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
|
||||
ctx.window_size = window_size
|
||||
ctx.softcap = softcap
|
||||
ctx.deterministic = deterministic
|
||||
ctx.max_seqlen_q = max_seqlen_q
|
||||
ctx.max_seqlen_k = max_seqlen_k
|
||||
return out, lse
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
|
||||
assert seqused_q == seqused_k == None
|
||||
assert ctx.softcap == 0.0
|
||||
dq, dk, dv = _flash_attn_bwd(
|
||||
q,
|
||||
@@ -1322,10 +1376,14 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
|
||||
ctx.softmax_scale,
|
||||
ctx.causal,
|
||||
ctx.softcap,
|
||||
window_size_left=ctx.window_size[0],
|
||||
window_size_right=ctx.window_size[1],
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
seqused_q=seqused_q,
|
||||
seqused_k=seqused_k,
|
||||
max_seqlen_q=ctx.max_seqlen_q,
|
||||
max_seqlen_k=ctx.max_seqlen_k,
|
||||
deterministic=ctx.deterministic,
|
||||
)
|
||||
|
||||
@@ -1376,6 +1434,8 @@ def flash_attn_varlen_func(
|
||||
v: torch.Tensor,
|
||||
cu_seqlens_q: Optional[torch.Tensor] = None,
|
||||
cu_seqlens_k: Optional[torch.Tensor] = None,
|
||||
max_seqlen_q: Optional[int] = None,
|
||||
max_seqlen_k: Optional[int] = None,
|
||||
seqused_q: Optional[torch.Tensor] = None,
|
||||
seqused_k: Optional[torch.Tensor] = None,
|
||||
page_table: Optional[torch.Tensor] = None,
|
||||
@@ -1389,8 +1449,6 @@ def flash_attn_varlen_func(
|
||||
deterministic: bool = False,
|
||||
score_mod: Optional[Callable] = None,
|
||||
aux_tensors: Optional[list] = None,
|
||||
max_seqlen_q: Optional[int] = None,
|
||||
max_seqlen_k: Optional[int] = None,
|
||||
):
|
||||
return FlashAttnVarlenFunc.apply(
|
||||
q,
|
||||
@@ -1400,6 +1458,8 @@ def flash_attn_varlen_func(
|
||||
cu_seqlens_k,
|
||||
seqused_q,
|
||||
seqused_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
page_table,
|
||||
softmax_scale,
|
||||
causal,
|
||||
@@ -1411,8 +1471,6 @@ def flash_attn_varlen_func(
|
||||
deterministic,
|
||||
score_mod,
|
||||
aux_tensors,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -38,6 +38,8 @@ class SeqlenInfo:
|
||||
class SeqlenInfoQK:
|
||||
offset_q: cutlass.Int32
|
||||
offset_k: cutlass.Int32
|
||||
padded_offset_q: cutlass.Int32
|
||||
padded_offset_k: cutlass.Int32
|
||||
seqlen_q: cutlass.Int32
|
||||
seqlen_k: cutlass.Int32
|
||||
has_cu_seqlens_q: cutlass.Constexpr[bool]
|
||||
@@ -54,9 +56,21 @@ class SeqlenInfoQK:
|
||||
mCuSeqlensK: Optional[cute.Tensor] = None,
|
||||
mSeqUsedQ: Optional[cute.Tensor] = None,
|
||||
mSeqUsedK: Optional[cute.Tensor] = None,
|
||||
tile_m: cutlass.Constexpr[cutlass.Int32] = 128,
|
||||
tile_n: cutlass.Constexpr[cutlass.Int32] = 128,
|
||||
):
|
||||
offset_q = 0 if const_expr(mCuSeqlensQ is None) else mCuSeqlensQ[batch_idx]
|
||||
offset_k = 0 if const_expr(mCuSeqlensK is None) else mCuSeqlensK[batch_idx]
|
||||
padded_offset_q = (
|
||||
0
|
||||
if const_expr(mCuSeqlensQ is None)
|
||||
else (offset_q + batch_idx * tile_m) // tile_m * tile_m
|
||||
)
|
||||
padded_offset_k = (
|
||||
0
|
||||
if const_expr(mCuSeqlensK is None)
|
||||
else (offset_k + batch_idx * tile_n) // tile_n * tile_n
|
||||
)
|
||||
if const_expr(mSeqUsedQ is not None):
|
||||
seqlen_q = mSeqUsedQ[batch_idx]
|
||||
else:
|
||||
@@ -80,6 +94,8 @@ class SeqlenInfoQK:
|
||||
return SeqlenInfoQK(
|
||||
offset_q,
|
||||
offset_k,
|
||||
padded_offset_q,
|
||||
padded_offset_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
has_cu_seqlens_q,
|
||||
@@ -88,23 +104,35 @@ class SeqlenInfoQK:
|
||||
has_seqused_k,
|
||||
)
|
||||
|
||||
def offset_batch_Q(self, mQ: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor:
|
||||
def offset_batch_Q(
|
||||
self,
|
||||
mQ: cute.Tensor,
|
||||
batch_idx: Int32,
|
||||
dim: int,
|
||||
padded: cutlass.Constexpr[bool] = False,
|
||||
) -> cute.Tensor:
|
||||
"""Seqlen must be the first dimension of mQ"""
|
||||
if const_expr(not self.has_cu_seqlens_q):
|
||||
idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mQ) - 1 - dim)
|
||||
return mQ[idx]
|
||||
else:
|
||||
offset = (
|
||||
self.offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, self.offset_q)
|
||||
)
|
||||
offset_q = self.offset_q if const_expr(not padded) else self.padded_offset_q
|
||||
offset = offset_q if const_expr(cute.rank(mQ.shape[0]) == 1) else (0, offset_q)
|
||||
idx = (offset,) + (0,) * (cute.rank(mQ) - 1)
|
||||
return cute.domain_offset(idx, mQ)
|
||||
|
||||
def offset_batch_K(self, mK: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor:
|
||||
def offset_batch_K(
|
||||
self,
|
||||
mK: cute.Tensor,
|
||||
batch_idx: Int32,
|
||||
dim: int,
|
||||
padded: cutlass.Constexpr[bool] = False,
|
||||
) -> cute.Tensor:
|
||||
"""Seqlen must be the first dimension of mK"""
|
||||
if const_expr(not self.has_cu_seqlens_k):
|
||||
idx = (None,) * dim + (batch_idx,) + (None,) * (cute.rank(mK) - 1 - dim)
|
||||
return mK[idx]
|
||||
else:
|
||||
idx = (self.offset_k,) + (0,) * (cute.rank(mK) - 1)
|
||||
offset_k = self.offset_k if const_expr(not padded) else self.padded_offset_k
|
||||
idx = (offset_k,) + (0,) * (cute.rank(mK) - 1)
|
||||
return cute.domain_offset(idx, mK)
|
||||
|
||||
@@ -92,7 +92,12 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random",
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
|
||||
lengths = torch.randint(
|
||||
max(0 if zero_lengths else 1, max_seqlen // 3),
|
||||
max_seqlen + 1,
|
||||
(batch_size, 1),
|
||||
device=device,
|
||||
)
|
||||
|
||||
if zero_lengths:
|
||||
for i in range(batch_size):
|
||||
|
||||
@@ -72,6 +72,7 @@ class TileSchedulerArguments(ParamsBase):
|
||||
is_persistent: cutlass.Constexpr[bool] = False
|
||||
lpt: cutlass.Constexpr[bool] = False
|
||||
is_split_kv: cutlass.Constexpr[bool] = False
|
||||
head_swizzle: cutlass.Constexpr[bool] = False
|
||||
|
||||
|
||||
class SingleTileScheduler:
|
||||
@@ -512,6 +513,7 @@ class SingleTileVarlenScheduler:
|
||||
qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
|
||||
lpt: cutlass.Constexpr[bool] = False
|
||||
is_split_kv: cutlass.Constexpr[bool] = False
|
||||
head_swizzle: cutlass.Constexpr[bool] = False
|
||||
|
||||
@staticmethod
|
||||
@cute.jit
|
||||
@@ -537,6 +539,7 @@ class SingleTileVarlenScheduler:
|
||||
qhead_per_kvhead_packgqa=args.qhead_per_kvhead_packgqa,
|
||||
lpt=args.lpt,
|
||||
is_split_kv=args.is_split_kv,
|
||||
head_swizzle=args.head_swizzle,
|
||||
)
|
||||
|
||||
def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None):
|
||||
@@ -638,7 +641,7 @@ class SingleTileVarlenScheduler:
|
||||
)
|
||||
num_m_blocks = cute.arch.shuffle_sync(num_m_blocks, batch_idx_in_group)
|
||||
mh_block = next_tile_idx - group_start_tile - num_m_blocks_prev_lane * params.num_head
|
||||
if cutlass.const_expr(params.lpt):
|
||||
if cutlass.const_expr(params.lpt or params.head_swizzle):
|
||||
# This is a version of the SingleTileLPTScheduler, complicated by the fact that
|
||||
# the seqlen can vary per batch.
|
||||
# TODO: is there any case where num_m_blocks is 0?
|
||||
@@ -677,7 +680,8 @@ class SingleTileVarlenScheduler:
|
||||
block = l2_mod // nheads_in_this_section
|
||||
head_idx_residual = l2_mod - block * nheads_in_this_section
|
||||
head_idx = section_idx * nheads_in_l2 + head_idx_residual
|
||||
block = num_m_blocks - 1 - block
|
||||
if cutlass.const_expr(params.lpt):
|
||||
block = num_m_blocks - 1 - block
|
||||
else:
|
||||
head_idx = mh_block // num_m_blocks
|
||||
block = mh_block - head_idx * num_m_blocks
|
||||
|
||||
@@ -50,7 +50,7 @@ VERBOSE = True
|
||||
@pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
|
||||
# @pytest.mark.parametrize("local_enum", [0])
|
||||
@pytest.mark.parametrize("causal", [False, True])
|
||||
# @pytest.mark.parametrize("causal", [True])
|
||||
# @pytest.mark.parametrize("causal", [False])
|
||||
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
|
||||
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
|
||||
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
|
||||
@@ -58,9 +58,9 @@ VERBOSE = True
|
||||
# @pytest.mark.parametrize("d", [64, 128, 256])
|
||||
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
|
||||
# @pytest.mark.parametrize("d", [64, 96, 128, 192])
|
||||
# @pytest.mark.parametrize("d", [64, 128])
|
||||
# @pytest.mark.parametrize("d", [128, 192])
|
||||
@pytest.mark.parametrize("d", [64, 128])
|
||||
# @pytest.mark.parametrize("d", [128])
|
||||
@pytest.mark.parametrize(
|
||||
"seqlen_q,seqlen_k",
|
||||
[
|
||||
@@ -113,7 +113,7 @@ def test_flash_attn_output(
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
batch_size = 9 if seqlen_k <= 2048 else 2
|
||||
# batch_size = 1
|
||||
# batch_size = 2
|
||||
nheads = 6
|
||||
# nheads = 1
|
||||
nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)
|
||||
@@ -236,7 +236,7 @@ def test_flash_attn_output(
|
||||
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
||||
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
|
||||
# num_splits_vals = [1, 3]
|
||||
pack_gqa_vals = [False, True, None]
|
||||
pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False]
|
||||
# SplitKV is not supported for hdim >= 192
|
||||
# pack_gqa_vals = [False]
|
||||
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
|
||||
@@ -371,17 +371,17 @@ def test_flash_attn_output(
|
||||
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||
# @pytest.mark.parametrize("mha_type", ["mqa"])
|
||||
@pytest.mark.parametrize("has_learnable_sink", [False, True])
|
||||
# @pytest.mark.parametrize("has_learnable_sink", [False])
|
||||
# @pytest.mark.parametrize("mha_type", ["mha"])
|
||||
# @pytest.mark.parametrize("has_learnable_sink", [False, True])
|
||||
@pytest.mark.parametrize("has_learnable_sink", [False])
|
||||
# @pytest.mark.parametrize("has_qv", [False, True])
|
||||
@pytest.mark.parametrize("has_qv", [False])
|
||||
# @pytest.mark.parametrize("deterministic", [False, True])
|
||||
@pytest.mark.parametrize("deterministic", [False])
|
||||
# @pytest.mark.parametrize("softcap", [0.0, 15.0])
|
||||
@pytest.mark.parametrize("softcap", [0.0])
|
||||
# @pytest.mark.parametrize("local", [False, True])
|
||||
@pytest.mark.parametrize("local", [False])
|
||||
@pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
|
||||
# @pytest.mark.parametrize("local_enum", [0])
|
||||
@pytest.mark.parametrize("causal", [False, True])
|
||||
# @pytest.mark.parametrize("causal", [False])
|
||||
# @pytest.mark.parametrize("add_unused_qkv", [False, True])
|
||||
@@ -393,7 +393,7 @@ def test_flash_attn_output(
|
||||
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
|
||||
# @pytest.mark.parametrize("d", [64, 96, 128])
|
||||
# @pytest.mark.parametrize("d", [128, 192])
|
||||
@pytest.mark.parametrize("d", [128])
|
||||
@pytest.mark.parametrize("d", [64, 128])
|
||||
@pytest.mark.parametrize(
|
||||
"seqlen_q,seqlen_k",
|
||||
[
|
||||
@@ -419,20 +419,37 @@ def test_flash_attn_output(
|
||||
(2048, 2048),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("varlen_mode", ["random", "third", "full"])
|
||||
# @pytest.mark.parametrize("varlen_mode", ["full"])
|
||||
@pytest.mark.parametrize(
|
||||
"zero_lengths_q, zero_lengths_k",
|
||||
[
|
||||
(False, False),
|
||||
(True, False),
|
||||
(False, True),
|
||||
(True, True),
|
||||
],
|
||||
)
|
||||
def test_flash_attn_varlen_output(
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
d,
|
||||
add_unused_qkv,
|
||||
causal,
|
||||
local,
|
||||
local_enum,
|
||||
softcap,
|
||||
deterministic,
|
||||
has_qv,
|
||||
has_learnable_sink,
|
||||
mha_type,
|
||||
dtype,
|
||||
varlen_mode,
|
||||
zero_lengths_q,
|
||||
zero_lengths_k,
|
||||
):
|
||||
local = local_enum > 0
|
||||
if local and causal:
|
||||
pytest.skip()
|
||||
if (
|
||||
causal or local
|
||||
): # Right now reference only supports causal attention with seqlen_k == seqlen_q
|
||||
@@ -442,13 +459,12 @@ def test_flash_attn_varlen_output(
|
||||
torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
|
||||
batch_size = 49 if seqlen_q <= 1024 else 7
|
||||
nheads = 6
|
||||
# batch_size = 1
|
||||
# nheads = 1
|
||||
nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)
|
||||
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
|
||||
# dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
|
||||
dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
if dtype == torch.float8_e4m3fn or TEST_BWD_ONLY:
|
||||
dv_vals = [d]
|
||||
# attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0]
|
||||
attention_chunk_vals = [0]
|
||||
@@ -490,6 +506,12 @@ def test_flash_attn_varlen_output(
|
||||
window_size = (
|
||||
(None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist()
|
||||
)
|
||||
if local_enum == 2:
|
||||
window_size = (None, window_size[1])
|
||||
elif local_enum == 3:
|
||||
window_size = (window_size[0], None)
|
||||
if local:
|
||||
print("window size = ", window_size)
|
||||
if has_learnable_sink:
|
||||
learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)
|
||||
else:
|
||||
@@ -505,18 +527,19 @@ def test_flash_attn_varlen_output(
|
||||
q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
|
||||
qv = qv_ref.detach() if has_qv else None
|
||||
query_padding_mask = generate_random_padding_mask(
|
||||
seqlen_q, batch_size, device, mode="random", zero_lengths=False
|
||||
seqlen_q,
|
||||
batch_size,
|
||||
device,
|
||||
mode=varlen_mode,
|
||||
zero_lengths=zero_lengths_q,
|
||||
)
|
||||
# TODO: test zero_lengths
|
||||
key_padding_mask = generate_random_padding_mask(
|
||||
# seqlen_k, batch_size, device, mode="random", zero_lengths=True
|
||||
seqlen_k,
|
||||
batch_size,
|
||||
device,
|
||||
mode="random",
|
||||
zero_lengths=False,
|
||||
mode=varlen_mode,
|
||||
zero_lengths=zero_lengths_k,
|
||||
)
|
||||
|
||||
def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
|
||||
if add_unused:
|
||||
another_mask = generate_random_padding_mask(max_seq_len, bs, device)
|
||||
@@ -570,6 +593,8 @@ def test_flash_attn_varlen_output(
|
||||
query_unused_mask=query_unused_mask,
|
||||
key_unused_mask=key_unused_mask,
|
||||
)
|
||||
print("cu_seqlens_q = ", cu_seqlens_q)
|
||||
print("cu_seqlens_k = ", cu_seqlens_k)
|
||||
q_unpad, k_unpad, v_unpad = [
|
||||
x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
|
||||
]
|
||||
@@ -619,11 +644,11 @@ def test_flash_attn_varlen_output(
|
||||
fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
|
||||
rtol = 2 if softcap == 0.0 else 3
|
||||
|
||||
pack_gqa_vals = [False, True, None]
|
||||
pack_gqa_vals = [False, True, None] if not TEST_BWD_ONLY else [False]
|
||||
# pack_gqa_vals = [False]
|
||||
# num_splits_vals = [1, 3]
|
||||
# SplitKV is not supported for hdim >= 192
|
||||
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1]
|
||||
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
|
||||
for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
|
||||
# SplitKV not supported on SM90 - skip this iteration
|
||||
if IS_SM90 and num_splits > 1:
|
||||
@@ -634,7 +659,8 @@ def test_flash_attn_varlen_output(
|
||||
v_unpad,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
# max_seqlen_k,
|
||||
max_seqlen_q=seqlen_q,
|
||||
max_seqlen_k=seqlen_k,
|
||||
# seqused_q=seqused_q,
|
||||
# seqused_k=seqused_k,
|
||||
causal=causal,
|
||||
@@ -647,6 +673,7 @@ def test_flash_attn_varlen_output(
|
||||
softcap=softcap,
|
||||
num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
deterministic=deterministic,
|
||||
)
|
||||
out = output_pad_fn(out_unpad)
|
||||
if query_unused_mask is not None:
|
||||
@@ -670,10 +697,10 @@ def test_flash_attn_varlen_output(
|
||||
and not attention_chunk != 0
|
||||
and dv == d
|
||||
and not has_learnable_sink
|
||||
and False
|
||||
# and False
|
||||
):
|
||||
g_unpad = torch.randn_like(out_unpad)
|
||||
do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
|
||||
# do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
|
||||
# import flash_attn_3_cuda
|
||||
# dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen(
|
||||
# g_unpad,
|
||||
|
||||
@@ -31,7 +31,7 @@ from flash_attn.cute.interface import (
|
||||
|
||||
DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
|
||||
IS_SM90 = torch.cuda.get_device_capability()[0] == 9
|
||||
|
||||
INCREASED_TRIALS = False
|
||||
|
||||
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@@ -304,7 +304,7 @@ def test_flash_attn_output(
|
||||
dv_pt - dv_ref
|
||||
).abs().max().item() + dv_atol
|
||||
|
||||
num_iters = 20_000
|
||||
num_iters = 10_000 if INCREASED_TRIALS else 1000
|
||||
for i in range(num_iters):
|
||||
dq2, dk2, dv2, = _flash_attn_bwd(
|
||||
q, k, v, out, g, lse,
|
||||
@@ -342,3 +342,435 @@ def test_flash_attn_output(
|
||||
|
||||
print(f"✅ Iteration {i} passed!")
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["gqa"])
# @pytest.mark.parametrize("has_learnable_sink", [False, True])
@pytest.mark.parametrize("has_learnable_sink", [False])
# @pytest.mark.parametrize("has_qv", [False, True])
@pytest.mark.parametrize("has_qv", [False])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [True])
# @pytest.mark.parametrize("softcap", [0.0, 15.0])
@pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local_enum", [0, 1, 2, 3])
# @pytest.mark.parametrize("local_enum", [0, 1])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("add_unused_qkv", [False, True])
@pytest.mark.parametrize("add_unused_qkv", [False])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128])
# @pytest.mark.parametrize("d", [128, 192])
@pytest.mark.parametrize("d", [64, 128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1024, 1024),
        (2048, 2048),
    ],
)
@pytest.mark.parametrize("varlen_mode", ["random", "third", "full"])
# @pytest.mark.parametrize("varlen_mode", ["random"])
@pytest.mark.parametrize(
    "zero_lengths_q, zero_lengths_k",
    [
        (False, False),
        (True, False),
        (False, True),
        (True, True),
    ],
)
def test_flash_attn_varlen_output(
    seqlen_q,
    seqlen_k,
    d,
    add_unused_qkv,
    causal,
    local_enum,
    softcap,
    deterministic,
    has_qv,
    has_learnable_sink,
    mha_type,
    dtype,
    varlen_mode,
    zero_lengths_q,
    zero_lengths_k,
):
    """Check varlen forward/backward accuracy and run-to-run reproducibility.

    The test builds padded reference tensors, unpads them with ``generate_qkv``
    and runs ``flash_attn_varlen_func``. The output and the gradients are
    compared against ``attention_ref`` (default settings) using the error of a
    second ``attention_ref`` run (``upcast=False, reorder_ops=True``) as the
    tolerance scale. When ``deterministic`` is set, the backward kernel is then
    re-run ``num_iters`` times through ``_flash_attn_bwd`` and every result
    must be bitwise equal to the first set of gradients.

    ``local_enum`` selects the sliding-window variant: 0 disables it, 1 uses a
    random (left, right) window, 2 drops the left bound and 3 drops the right
    bound. Combinations of ``causal`` with a local window are skipped.
    """
    local = local_enum > 0
    if local and causal:
        pytest.skip()
    if causal or local:
        # Right now reference only supports causal attention with seqlen_k == seqlen_q
        seqlen_k = seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
    batch_size = 49 if seqlen_q <= 1024 else 7
    nheads = 6
    # nheads = 1
    nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
    # Only dv == d is exercised here; the backward check below requires it.
    # dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])
    dv_vals = [d]
    # attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k else [0]
    attention_chunk_vals = [0]

    def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
        """Split a padding mask into (attended, unused) masks.

        With ``add_unused`` a second random mask is drawn; positions set in
        both masks are attended, positions set in exactly one are "unused".
        Otherwise the padding mask is returned unchanged with no unused mask.
        """
        if not add_unused:
            return padding_mask, None
        another_mask = generate_random_padding_mask(max_seq_len, bs, device)
        attn_mask = torch.logical_and(padding_mask, another_mask)
        unused_mask = torch.logical_xor(
            torch.logical_or(padding_mask, another_mask), attn_mask
        )
        return attn_mask, unused_mask

    def _max_abs(actual, expected):
        """Largest absolute element-wise difference, as a Python float."""
        return (actual - expected).abs().max().item()

    def _mean_abs(actual, expected):
        """Mean absolute element-wise difference, as a Python float."""
        return (actual - expected).abs().mean().item()

    def _assert_within_tolerance(actual, baseline, reference, rtol, extra_atol=0):
        """Assert ``actual`` is at most ``rtol`` times worse than ``baseline``.

        The absolute slack is twice the rounding error introduced by doing any
        arithmetic on ``reference`` at all, plus ``extra_atol``.
        """
        atol = 2 * (reference + 0.3 - 0.3 - reference).abs().max().item() + extra_atol
        assert _max_abs(actual, reference) <= rtol * _max_abs(baseline, reference) + atol

    for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
        q_ref = torch.randn(
            batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref
        )
        if softcap > 0.0:
            # Ensure the values of qk are at least within softcap range.
            q_ref = (q_ref * softcap / 4).detach().requires_grad_()
        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
        k_ref = (
            torch.randn(
                batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
            )
            .to(dtype)
            .to(dtype_ref)
            .requires_grad_()
        )
        v_ref = (
            torch.randn(
                batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref
            )
            .to(dtype)
            .to(dtype_ref)
            .requires_grad_()
        )
        if has_qv:
            qv_ref = (
                torch.randn(
                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
                )
                .to(dtype)
                .to(dtype_ref)
            )
        else:
            qv_ref = None

        # Put window_size after QKV randn so that window_size changes from test to test
        window_size = (
            (None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist()
        )
        if local_enum == 2:
            window_size = (None, window_size[1])
        elif local_enum == 3:
            window_size = (window_size[0], None)
        if local:
            print("window size = ", window_size)

        if has_learnable_sink:
            learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)
        else:
            learnable_sink = None

        if dtype == torch.float8_e4m3fn:
            q_descale, k_descale, v_descale = [
                torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
                * 2
                for _ in range(3)
            ]
        else:
            q_descale, k_descale, v_descale = None, None, None

        q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
        qv = qv_ref.detach() if has_qv else None

        query_padding_mask = generate_random_padding_mask(
            seqlen_q,
            batch_size,
            device,
            mode=varlen_mode,
            zero_lengths=zero_lengths_q,
        )
        key_padding_mask = generate_random_padding_mask(
            seqlen_k,
            batch_size,
            device,
            mode=varlen_mode,
            zero_lengths=zero_lengths_k,
        )

        query_padding_mask, query_unused_mask = _gen_unused_masks(
            query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
        )
        # query_padding_mask[:] = True
        # query_unused_mask = None
        key_padding_mask, key_unused_mask = _gen_unused_masks(
            key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
        )

        if causal or local:
            # Keep per-sequence q and k lengths equal, matching the
            # seqlen_k == seqlen_q restriction applied above.
            key_padding_mask = query_padding_mask

        (
            q_unpad,
            k_unpad,
            v_unpad,
            qv_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            qv,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            qv=qv,
            kvpacked=False,
            query_unused_mask=query_unused_mask,
            key_unused_mask=key_unused_mask,
        )
        print("cu_seqlens_q = ", cu_seqlens_q)
        print("cu_seqlens_k = ", cu_seqlens_k)
        q_unpad, k_unpad, v_unpad = [
            x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
        ]

        # Arguments shared by both reference runs.
        ref_kwargs = dict(
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            attention_chunk=attention_chunk,
            learnable_sink=learnable_sink,
            softcap=softcap,
        )
        out_ref, attn_ref = attention_ref(
            q_ref, k_ref, v_ref, query_padding_mask, key_padding_mask, **ref_kwargs
        )
        out_pt, attn_pt = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            **ref_kwargs,
            upcast=False,
            reorder_ops=True,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
        )

        print(f"Pytorch max diff: {_max_abs(out_pt, out_ref)}")
        print(f"Pytorch mean diff: {_mean_abs(out_pt, out_ref)}")

        if query_unused_mask is not None:
            q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")

        rtol = 2 if softcap == 0.0 else 3

        out_unpad, lse = flash_attn_varlen_func(
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            # seqused_q=seqused_q,
            # seqused_k=seqused_k,
            max_seqlen_q=seqlen_q,
            max_seqlen_k=seqlen_k,
            causal=causal,
            # qv=qv_unpad,
            # q_descale=q_descale,
            # k_descale=k_descale, v_descale=v_descale,
            window_size=window_size,
            # attention_chunk=attention_chunk,
            learnable_sink=learnable_sink,
            softcap=softcap,
            num_splits=1,
            pack_gqa=False,
            deterministic=deterministic,
        )
        out = output_pad_fn(out_unpad)
        if query_unused_mask is not None:
            out.masked_fill_(q_zero_masking, 0.0)
        print(f"Output max diff: {_max_abs(out, out_ref)}")
        print(f"Output mean diff: {_mean_abs(out, out_ref)}")

        # Check that FlashAttention's numerical error is at most rtol times the
        # numerical error of a Pytorch implementation (plus rounding slack).
        _assert_within_tolerance(out, out_pt, out_ref, rtol)

        backward_supported = (
            dtype != torch.float8_e4m3fn
            and not has_qv
            and not dv > 256
            and not attention_chunk != 0
            and dv == d
            and not has_learnable_sink
        )
        if not backward_supported:
            continue

        g_unpad = torch.randn_like(out_unpad)
        dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
            out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad
        )
        # Named *_padded so the loop variable ``dv`` (head dim) is not clobbered.
        dq_padded = dq_pad_fn(dq_unpad)
        dk_padded = dk_pad_fn(dk_unpad)
        dv_padded = dk_pad_fn(dv_unpad)
        if key_unused_mask is not None:
            k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
            dk_padded.masked_fill_(k_zero_masking, 0.0)
            dv_padded.masked_fill_(k_zero_masking, 0.0)
        if query_unused_mask is not None:
            dq_padded.masked_fill_(q_zero_masking, 0.0)
        g = output_pad_fn(g_unpad)

        dq_ref, dk_ref, dv_ref = torch.autograd.grad(
            out_ref, (q_ref, k_ref, v_ref), g
        )
        dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)

        # (label, kernel gradient, low-precision reference, full reference)
        grad_checks = (
            ("dQ", dq_padded, dq_pt, dq_ref),
            ("dK", dk_padded, dk_pt, dk_ref),
            ("dV", dv_padded, dv_pt, dv_ref),
        )
        for label, actual, _, reference in grad_checks:
            print(f"{label} max diff: {_max_abs(actual, reference)}")
        for label, actual, _, reference in grad_checks:
            print(f"{label} mean diff: {_mean_abs(actual, reference)}")
        for label, _, baseline, reference in grad_checks:
            print(f"{label} Pytorch max diff: {_max_abs(baseline, reference)}")
        for label, _, baseline, reference in grad_checks:
            print(f"{label} Pytorch mean diff: {_mean_abs(baseline, reference)}")

        grad_extra_atol = 0 if softcap == 0 else 3e-4
        for _, actual, baseline, reference in grad_checks:
            _assert_within_tolerance(
                actual, baseline, reference, rtol, extra_atol=grad_extra_atol
            )

        if not deterministic:
            # The first gradients came from a non-deterministic run, so a
            # bitwise comparison against re-runs would not be meaningful.
            continue

        num_iters = 10_000 if INCREASED_TRIALS else 1000

        for i in range(num_iters):
            dq_unpad2, dk_unpad2, dv_unpad2 = _flash_attn_bwd(
                q_unpad, k_unpad, v_unpad, out_unpad, g_unpad, lse,
                causal=causal,
                window_size_left=window_size[0],
                window_size_right=window_size[1],
                deterministic=deterministic,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=seqlen_q,
                max_seqlen_k=seqlen_k,
            )

            report = i % 100 == 0
            if report:
                # Diagnostics are only computed on the iterations that print them.
                reruns = (
                    ("dQ", dq_unpad, dq_unpad2),
                    ("dK", dk_unpad, dk_unpad2),
                    ("dV", dv_unpad, dv_unpad2),
                )
                for label, first, again in reruns:
                    diff = (first - again).abs()
                    max_idx = diff.argmax()
                    print(f"{label} max diff: {diff.max().item()}")
                    print(f" at index {max_idx.item()}: {label}={first.flatten()[max_idx].item()}, {label}2={again.flatten()[max_idx].item()}")

            assert torch.equal(dq_unpad, dq_unpad2)
            assert torch.equal(dk_unpad, dk_unpad2)
            assert torch.equal(dv_unpad, dv_unpad2)

            if report:
                print(f"✅ Iteration {i} passed!")
|
||||
@@ -43,8 +43,8 @@ def test_varlen(
|
||||
dtype=dtype
|
||||
)
|
||||
|
||||
# SM90/SM100 backward pass doesn't support varlen yet
|
||||
skip_backward = IS_SM90 or torch.cuda.get_device_capability()[0] == 10
|
||||
# SM90 backward pass doesn't support varlen yet
|
||||
skip_backward = IS_SM90
|
||||
|
||||
ok = check_varlen_vs_torch_flash(
|
||||
q, k, v,
|
||||
@@ -128,7 +128,7 @@ def check_varlen_vs_torch_flash(
|
||||
if not ok_fwd:
|
||||
return False
|
||||
|
||||
# Skip backward if not supported (e.g., SM100 varlen)
|
||||
# Skip backward if not supported (e.g., SM90 varlen)
|
||||
if skip_backward:
|
||||
return True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user