[Bwd,Sm90] Have score_mod and score_mod_bwd as partial functions

This commit is contained in:
Tri Dao
2026-02-08 22:10:05 +07:00
parent deb183092b
commit 17d29436b8
2 changed files with 53 additions and 72 deletions
+8 -20
View File
@@ -1250,10 +1250,10 @@ def consume_block_sparse_mma_bwd_sm90(
is_causal: cutlass.Constexpr,
is_local: cutlass.Constexpr,
thr_mma_SdP,
softmax_scale,
seqlen,
subtile_factor: cutlass.Constexpr,
m_block_max: int,
score_mod_fn=None,
score_mod_bwd_fn=None,
subtile_factor: cutlass.Constexpr = 1,
m_block_max: int = 0,
aux_tensors=None,
fastdiv_mods=(None, None),
):
@@ -1315,15 +1315,9 @@ def consume_block_sparse_mma_bwd_sm90(
consumer_state_Q,
consumer_state_dO,
mask_fn=mask_fn_partial,
score_mod_fn=score_mod_fn,
score_mod_bwd_fn=score_mod_bwd_fn,
dKV_accumulate=dKV_accumulate,
thr_mma_SdP=thr_mma_SdP,
batch_idx=batch_idx,
head_idx=head_idx,
n_block=n_block,
softmax_scale=softmax_scale,
seqlen=seqlen,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
)
dKV_accumulate = True
@@ -1339,15 +1333,9 @@ def consume_block_sparse_mma_bwd_sm90(
consumer_state_Q,
consumer_state_dO,
mask_fn=mask_fn_full,
score_mod_fn=score_mod_fn,
score_mod_bwd_fn=score_mod_bwd_fn,
dKV_accumulate=dKV_accumulate,
thr_mma_SdP=thr_mma_SdP,
batch_idx=batch_idx,
head_idx=head_idx,
n_block=n_block,
softmax_scale=softmax_scale,
seqlen=seqlen,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
)
dKV_accumulate = True
+45 -52
View File
@@ -1089,6 +1089,24 @@ class FlashAttentionBackwardSm90:
smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx)
tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum)
PdS_barrier = cutlass.pipeline.NamedBarrier(
barrier_id=int(NamedBarrierBwd.PdS), num_threads=self.num_mma_threads
)
score_mod_fn = partial(
self.apply_score_mod,
thr_mma_SdP=thr_mma_SdP,
softmax_scale=softmax_scale,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
)
score_mod_bwd_fn = partial(
self.apply_score_mod_bwd,
thr_mma_SdP=thr_mma_SdP,
softmax_scale=softmax_scale,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
)
mma_one_m_block_all = partial(
self.mma_one_m_block,
warp_group_idx=warp_group_idx,
@@ -1107,6 +1125,7 @@ class FlashAttentionBackwardSm90:
smem_thr_copy_PdS=smem_thr_copy_PdS,
smem_thr_copy_dQaccum=smem_thr_copy_dQaccum,
softmax_scale_log2=softmax_scale_log2,
PdS_barrier=PdS_barrier,
# acc_dV=acc_dV,
# acc_dK=acc_dK,
)
@@ -1123,6 +1142,20 @@ class FlashAttentionBackwardSm90:
n_block, head_idx, batch_idx, _ = work_tile.tile_idx
seqlen = SeqlenInfoCls(batch_idx)
mask = AttentionMaskCls(seqlen)
score_mod_fn_cur = partial(
score_mod_fn,
batch_idx=batch_idx,
head_idx=head_idx,
n_block=n_block,
seqlen_info=seqlen,
)
score_mod_bwd_fn_cur = partial(
score_mod_bwd_fn,
batch_idx=batch_idx,
head_idx=head_idx,
n_block=n_block,
seqlen_info=seqlen,
)
m_block_min, m_block_max = block_info.get_m_block_min_max(seqlen, n_block)
if const_expr(not self.use_block_sparsity):
@@ -1160,15 +1193,9 @@ class FlashAttentionBackwardSm90:
consumer_state_Q,
consumer_state_dO,
mask_fn=mask_fn,
score_mod_fn=score_mod_fn_cur,
score_mod_bwd_fn=score_mod_bwd_fn_cur,
dKV_accumulate=dKV_accumulate,
thr_mma_SdP=thr_mma_SdP,
batch_idx=batch_idx,
head_idx=head_idx,
n_block=n_block,
softmax_scale=softmax_scale,
seqlen=seqlen,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
)
dKV_accumulate = True
else:
@@ -1185,8 +1212,8 @@ class FlashAttentionBackwardSm90:
is_causal=self.is_causal,
is_local=self.is_local,
thr_mma_SdP=thr_mma_SdP,
softmax_scale=softmax_scale,
seqlen=seqlen,
score_mod_fn=score_mod_fn_cur,
score_mod_bwd_fn=score_mod_bwd_fn_cur,
subtile_factor=self.subtile_factor,
m_block_max=m_block_max,
aux_tensors=aux_tensors,
@@ -1266,16 +1293,11 @@ class FlashAttentionBackwardSm90:
smem_thr_copy_PdS: cute.TiledCopy,
smem_thr_copy_dQaccum: cute.TiledCopy,
softmax_scale_log2: Float32,
PdS_barrier: cutlass.pipeline.NamedBarrier,
mask_fn: Optional[Callable] = None,
score_mod_fn: Optional[Callable] = None,
score_mod_bwd_fn: Optional[Callable] = None,
dKV_accumulate: Boolean = True,
thr_mma_SdP: Optional[cute.core.ThrMma] = None,
batch_idx: Int32 = 0,
head_idx: Int32 = 0,
n_block: Int32 = 0,
softmax_scale: Float32 = 1.0,
seqlen: Optional[SeqlenInfoQK] = None,
aux_tensors: Optional[list] = None,
fastdiv_mods=(None, None),
):
consumer_state_dO_cur = (
consumer_state_dO if const_expr(self.Q_stage == self.dO_stage) else consumer_state_Q
@@ -1298,18 +1320,7 @@ class FlashAttentionBackwardSm90:
cute.autovec_copy(acc_S, acc_S_pre)
if const_expr(self.score_mod is not None):
self.apply_score_mod(
acc_S,
thr_mma_SdP,
batch_idx,
head_idx,
m_block,
n_block,
softmax_scale,
seqlen,
aux_tensors,
fastdiv_mods,
)
score_mod_fn(acc_S, m_block=m_block)
# (3) [Pointwise 1] P = exp(S - LSE)
if cutlass.const_expr(mask_fn is not None):
@@ -1328,9 +1339,7 @@ class FlashAttentionBackwardSm90:
if const_expr(not self.mma_dkv_is_rs):
# sync to ensure P has already been used in the previous iteration before overwriting
if const_expr(self.PdS_stage == 1):
cute.arch.barrier(
barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads
)
PdS_barrier.arrive_and_wait()
tPrP = smem_thr_copy_PdS.retile(tdVrP)
cute.copy(smem_thr_copy_PdS, tPrP, tPsP[None, None, None, smem_idx_PdS])
@@ -1342,19 +1351,7 @@ class FlashAttentionBackwardSm90:
acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r])
if const_expr(self.score_mod_bwd is not None):
self.apply_score_mod_bwd(
acc_dP,
acc_S_pre,
thr_mma_SdP,
batch_idx,
head_idx,
m_block,
n_block,
softmax_scale,
seqlen,
aux_tensors,
fastdiv_mods,
)
score_mod_bwd_fn(acc_dP, acc_S_pre, m_block=m_block)
# Convert dS from f32 -> f16
tdKrdS = utils.cvt_f16(layout_utils.reshape_acc_to_frgA(acc_dP), self.dtype)
@@ -1367,9 +1364,7 @@ class FlashAttentionBackwardSm90:
# (2) dS is already read by the Mma in the previous iteration in case of mma_dkv_is_rs.
if const_expr(not self.mma_dkv_is_rs or (self.PdS_stage == 1 and self.mma_dkv_is_rs)):
cute.arch.fence_view_async_shared()
cute.arch.barrier(
barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads
)
PdS_barrier.arrive_and_wait()
# R2S for dS
tdSrdS = smem_thr_copy_PdS.retile(tdKrdS)
@@ -1385,9 +1380,7 @@ class FlashAttentionBackwardSm90:
# smem fence to make sure sdS is written before it's read by WGMMA
cute.arch.fence_view_async_shared()
cute.arch.barrier(
barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads
)
PdS_barrier.arrive_and_wait()
# (6) [GEMM 4] dQ = dS @ K
acc_dQ = mma_dsk_fn(A_idx=smem_idx_PdS, wg_wait=1)
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(acc_dV)