Mirror of https://github.com/BillyOutlast/flash-attention.git (synced 2026-07-01 21:04:02 -04:00)
[Bwd,Sm90] Have score_mod and score_mod_bwd as partial functions
This commit is contained in:
@@ -1250,10 +1250,10 @@ def consume_block_sparse_mma_bwd_sm90(
|
||||
is_causal: cutlass.Constexpr,
|
||||
is_local: cutlass.Constexpr,
|
||||
thr_mma_SdP,
|
||||
softmax_scale,
|
||||
seqlen,
|
||||
subtile_factor: cutlass.Constexpr,
|
||||
m_block_max: int,
|
||||
score_mod_fn=None,
|
||||
score_mod_bwd_fn=None,
|
||||
subtile_factor: cutlass.Constexpr = 1,
|
||||
m_block_max: int = 0,
|
||||
aux_tensors=None,
|
||||
fastdiv_mods=(None, None),
|
||||
):
|
||||
@@ -1315,15 +1315,9 @@ def consume_block_sparse_mma_bwd_sm90(
|
||||
consumer_state_Q,
|
||||
consumer_state_dO,
|
||||
mask_fn=mask_fn_partial,
|
||||
score_mod_fn=score_mod_fn,
|
||||
score_mod_bwd_fn=score_mod_bwd_fn,
|
||||
dKV_accumulate=dKV_accumulate,
|
||||
thr_mma_SdP=thr_mma_SdP,
|
||||
batch_idx=batch_idx,
|
||||
head_idx=head_idx,
|
||||
n_block=n_block,
|
||||
softmax_scale=softmax_scale,
|
||||
seqlen=seqlen,
|
||||
aux_tensors=aux_tensors,
|
||||
fastdiv_mods=fastdiv_mods,
|
||||
)
|
||||
dKV_accumulate = True
|
||||
|
||||
@@ -1339,15 +1333,9 @@ def consume_block_sparse_mma_bwd_sm90(
|
||||
consumer_state_Q,
|
||||
consumer_state_dO,
|
||||
mask_fn=mask_fn_full,
|
||||
score_mod_fn=score_mod_fn,
|
||||
score_mod_bwd_fn=score_mod_bwd_fn,
|
||||
dKV_accumulate=dKV_accumulate,
|
||||
thr_mma_SdP=thr_mma_SdP,
|
||||
batch_idx=batch_idx,
|
||||
head_idx=head_idx,
|
||||
n_block=n_block,
|
||||
softmax_scale=softmax_scale,
|
||||
seqlen=seqlen,
|
||||
aux_tensors=aux_tensors,
|
||||
fastdiv_mods=fastdiv_mods,
|
||||
)
|
||||
dKV_accumulate = True
|
||||
|
||||
|
||||
@@ -1089,6 +1089,24 @@ class FlashAttentionBackwardSm90:
|
||||
smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx)
|
||||
tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum)
|
||||
|
||||
PdS_barrier = cutlass.pipeline.NamedBarrier(
|
||||
barrier_id=int(NamedBarrierBwd.PdS), num_threads=self.num_mma_threads
|
||||
)
|
||||
score_mod_fn = partial(
|
||||
self.apply_score_mod,
|
||||
thr_mma_SdP=thr_mma_SdP,
|
||||
softmax_scale=softmax_scale,
|
||||
aux_tensors=aux_tensors,
|
||||
fastdiv_mods=fastdiv_mods,
|
||||
)
|
||||
score_mod_bwd_fn = partial(
|
||||
self.apply_score_mod_bwd,
|
||||
thr_mma_SdP=thr_mma_SdP,
|
||||
softmax_scale=softmax_scale,
|
||||
aux_tensors=aux_tensors,
|
||||
fastdiv_mods=fastdiv_mods,
|
||||
)
|
||||
|
||||
mma_one_m_block_all = partial(
|
||||
self.mma_one_m_block,
|
||||
warp_group_idx=warp_group_idx,
|
||||
@@ -1107,6 +1125,7 @@ class FlashAttentionBackwardSm90:
|
||||
smem_thr_copy_PdS=smem_thr_copy_PdS,
|
||||
smem_thr_copy_dQaccum=smem_thr_copy_dQaccum,
|
||||
softmax_scale_log2=softmax_scale_log2,
|
||||
PdS_barrier=PdS_barrier,
|
||||
# acc_dV=acc_dV,
|
||||
# acc_dK=acc_dK,
|
||||
)
|
||||
@@ -1123,6 +1142,20 @@ class FlashAttentionBackwardSm90:
|
||||
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
|
||||
seqlen = SeqlenInfoCls(batch_idx)
|
||||
mask = AttentionMaskCls(seqlen)
|
||||
score_mod_fn_cur = partial(
|
||||
score_mod_fn,
|
||||
batch_idx=batch_idx,
|
||||
head_idx=head_idx,
|
||||
n_block=n_block,
|
||||
seqlen_info=seqlen,
|
||||
)
|
||||
score_mod_bwd_fn_cur = partial(
|
||||
score_mod_bwd_fn,
|
||||
batch_idx=batch_idx,
|
||||
head_idx=head_idx,
|
||||
n_block=n_block,
|
||||
seqlen_info=seqlen,
|
||||
)
|
||||
m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
|
||||
|
||||
if const_expr(not self.use_block_sparsity):
|
||||
@@ -1160,15 +1193,9 @@ class FlashAttentionBackwardSm90:
|
||||
consumer_state_Q,
|
||||
consumer_state_dO,
|
||||
mask_fn=mask_fn,
|
||||
score_mod_fn=score_mod_fn_cur,
|
||||
score_mod_bwd_fn=score_mod_bwd_fn_cur,
|
||||
dKV_accumulate=dKV_accumulate,
|
||||
thr_mma_SdP=thr_mma_SdP,
|
||||
batch_idx=batch_idx,
|
||||
head_idx=head_idx,
|
||||
n_block=n_block,
|
||||
softmax_scale=softmax_scale,
|
||||
seqlen=seqlen,
|
||||
aux_tensors=aux_tensors,
|
||||
fastdiv_mods=fastdiv_mods,
|
||||
)
|
||||
dKV_accumulate = True
|
||||
else:
|
||||
@@ -1185,8 +1212,8 @@ class FlashAttentionBackwardSm90:
|
||||
is_causal=self.is_causal,
|
||||
is_local=self.is_local,
|
||||
thr_mma_SdP=thr_mma_SdP,
|
||||
softmax_scale=softmax_scale,
|
||||
seqlen=seqlen,
|
||||
score_mod_fn=score_mod_fn_cur,
|
||||
score_mod_bwd_fn=score_mod_bwd_fn_cur,
|
||||
subtile_factor=self.subtile_factor,
|
||||
m_block_max=m_block_max,
|
||||
aux_tensors=aux_tensors,
|
||||
@@ -1266,16 +1293,11 @@ class FlashAttentionBackwardSm90:
|
||||
smem_thr_copy_PdS: cute.TiledCopy,
|
||||
smem_thr_copy_dQaccum: cute.TiledCopy,
|
||||
softmax_scale_log2: Float32,
|
||||
PdS_barrier: cutlass.pipeline.NamedBarrier,
|
||||
mask_fn: Optional[Callable] = None,
|
||||
score_mod_fn: Optional[Callable] = None,
|
||||
score_mod_bwd_fn: Optional[Callable] = None,
|
||||
dKV_accumulate: Boolean = True,
|
||||
thr_mma_SdP: Optional[cute.core.ThrMma] = None,
|
||||
batch_idx: Int32 = 0,
|
||||
head_idx: Int32 = 0,
|
||||
n_block: Int32 = 0,
|
||||
softmax_scale: Float32 = 1.0,
|
||||
seqlen: Optional[SeqlenInfoQK] = None,
|
||||
aux_tensors: Optional[list] = None,
|
||||
fastdiv_mods=(None, None),
|
||||
):
|
||||
consumer_state_dO_cur = (
|
||||
consumer_state_dO if const_expr(self.Q_stage == self.dO_stage) else consumer_state_Q
|
||||
@@ -1298,18 +1320,7 @@ class FlashAttentionBackwardSm90:
|
||||
cute.autovec_copy(acc_S, acc_S_pre)
|
||||
|
||||
if const_expr(self.score_mod is not None):
|
||||
self.apply_score_mod(
|
||||
acc_S,
|
||||
thr_mma_SdP,
|
||||
batch_idx,
|
||||
head_idx,
|
||||
m_block,
|
||||
n_block,
|
||||
softmax_scale,
|
||||
seqlen,
|
||||
aux_tensors,
|
||||
fastdiv_mods,
|
||||
)
|
||||
score_mod_fn(acc_S, m_block=m_block)
|
||||
|
||||
# (3) [Pointwise 1] P = exp(S - LSE)
|
||||
if cutlass.const_expr(mask_fn is not None):
|
||||
@@ -1328,9 +1339,7 @@ class FlashAttentionBackwardSm90:
|
||||
if const_expr(not self.mma_dkv_is_rs):
|
||||
# sync to ensure P has already been used in the previous iteration before overwriting
|
||||
if const_expr(self.PdS_stage == 1):
|
||||
cute.arch.barrier(
|
||||
barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads
|
||||
)
|
||||
PdS_barrier.arrive_and_wait()
|
||||
tPrP = smem_thr_copy_PdS.retile(tdVrP)
|
||||
cute.copy(smem_thr_copy_PdS, tPrP, tPsP[None, None, None, smem_idx_PdS])
|
||||
|
||||
@@ -1342,19 +1351,7 @@ class FlashAttentionBackwardSm90:
|
||||
acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r])
|
||||
|
||||
if const_expr(self.score_mod_bwd is not None):
|
||||
self.apply_score_mod_bwd(
|
||||
acc_dP,
|
||||
acc_S_pre,
|
||||
thr_mma_SdP,
|
||||
batch_idx,
|
||||
head_idx,
|
||||
m_block,
|
||||
n_block,
|
||||
softmax_scale,
|
||||
seqlen,
|
||||
aux_tensors,
|
||||
fastdiv_mods,
|
||||
)
|
||||
score_mod_bwd_fn(acc_dP, acc_S_pre, m_block=m_block)
|
||||
|
||||
# Convert dS from f32 -> f16
|
||||
tdKrdS = utils.cvt_f16(layout_utils.reshape_acc_to_frgA(acc_dP), self.dtype)
|
||||
@@ -1367,9 +1364,7 @@ class FlashAttentionBackwardSm90:
|
||||
# (2) dS is already read by the Mma in the previous iteration in case of mma_dkv_is_rs.
|
||||
if const_expr(not self.mma_dkv_is_rs or (self.PdS_stage == 1 and self.mma_dkv_is_rs)):
|
||||
cute.arch.fence_view_async_shared()
|
||||
cute.arch.barrier(
|
||||
barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads
|
||||
)
|
||||
PdS_barrier.arrive_and_wait()
|
||||
|
||||
# R2S for dS
|
||||
tdSrdS = smem_thr_copy_PdS.retile(tdKrdS)
|
||||
@@ -1385,9 +1380,7 @@ class FlashAttentionBackwardSm90:
|
||||
|
||||
# smem fence to make sure sdS is written before it's read by WGMMA
|
||||
cute.arch.fence_view_async_shared()
|
||||
cute.arch.barrier(
|
||||
barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads
|
||||
)
|
||||
PdS_barrier.arrive_and_wait()
|
||||
# (6) [GEMM 4] dQ = dS @ K
|
||||
acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1)
|
||||
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV)
|
||||
|
||||
Reference in New Issue
Block a user