[DSL] Replace old fence with cute.arch.fence_view_async_shared()

This commit is contained in:
Tri Dao
2026-02-08 10:48:54 +07:00
parent 48af662c53
commit a804a5a3ef
4 changed files with 19 additions and 35 deletions
+4 -12
View File
@@ -2277,9 +2277,7 @@ class FlashAttentionBackwardSm100:
if const_expr(not self.use_smem_dS_for_mma_dK):
cute.arch.fence_view_async_tmem_store()
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
)
cute.arch.fence_view_async_shared()
self.compute_sync_barrier.arrive_and_wait()
# with cute.arch.elect_one():
@@ -2528,9 +2526,7 @@ class FlashAttentionBackwardSm100:
)
cute.copy(thr_copy_dQaccum_r2s, tdQrdQ_r2s, tdQsdQ_r2s)
# Fence and barrier to make sure shared memory store is visible to TMA store
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
)
cute.arch.fence_view_async_shared()
# semaphore acquire
if const_expr(self.deterministic and stage == 0):
if const_expr(self.spt):
@@ -2886,9 +2882,7 @@ class FlashAttentionBackwardSm100:
# RMEM -> SMEM -- copy, fence and barrier
tdKVrdKV_r2s = cute.make_tensor(tdKVrdKV.iterator, tdKVsdKV_r2s.shape)
cute.copy(thr_copy_r2s_dKV, tdKVrdKV_r2s, tdKVsdKV_r2s)
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
)
cute.arch.fence_view_async_shared()
cute.arch.barrier(barrier_id=barrier_id + wg_idx, number_of_threads=128)
# SMEM -> GMEM
@@ -2910,9 +2904,7 @@ class FlashAttentionBackwardSm100:
)
# Barrier since all warps need to wait for SMEM to be freed
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
)
cute.arch.fence_view_async_shared()
cute.arch.barrier(
barrier_id=barrier_id + wg_idx, number_of_threads=128 + cute.arch.WARP_SIZE
)
+7 -8
View File
@@ -8,7 +8,6 @@ import cutlass
import cutlass.cute as cute
import cutlass.utils.hopper_helpers as sm90_utils_basic
from cutlass.cute.nvgpu import cpasync, warpgroup
from cutlass.cute.arch import ProxyKind, SharedSpace
from cutlass.cute import FastDivmodDivisor
from cutlass import Float32, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
@@ -1409,7 +1408,7 @@ class FlashAttentionBackwardSm90:
# This sync is to ensure (1) P is written in case of !mma_dkv_is_rs and
# (2) dS is already read by the Mma in the previous iteration in case of mma_dkv_is_rs.
if const_expr(not self.mma_dkv_is_rs or (self.PdS_stage == 1 and self.mma_dkv_is_rs)):
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
cute.arch.fence_view_async_shared()
cute.arch.barrier(
barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads
)
@@ -1427,7 +1426,7 @@ class FlashAttentionBackwardSm90:
mma_pdo_fn(tCrA=tdVrP, B_idx=smem_idx_dO, zero_init=not dKV_accumulate, wg_wait=-1)
# smem fence to make sure sdS is written before it's read by WGMMA
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
cute.arch.fence_view_async_shared()
cute.arch.barrier(
barrier_id=int(NamedBarrierBwd.PdS), number_of_threads=self.num_mma_threads
)
@@ -1451,7 +1450,7 @@ class FlashAttentionBackwardSm90:
)
tdQrdQaccum_flat = cute.make_tensor(acc_dQ.iterator, cute.make_layout(tdQsdQaccum.shape))
cute.autovec_copy(tdQrdQaccum_flat, tdQsdQaccum)
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
cute.arch.fence_view_async_shared()
cute.arch.barrier_arrive(
barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
number_of_threads=self.num_threads_per_warp_group + cute.arch.WARP_SIZE,
@@ -1524,7 +1523,7 @@ class FlashAttentionBackwardSm90:
sdV = sV if const_expr(not self.dKV_swapAB) else utils.transpose_view(sV)
taccdVsdV = smem_thr_copy_dV.partition_D(sdV)
cute.copy(smem_copy_atom_dKV, taccdVrdV, taccdVsdV)
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
cute.arch.fence_view_async_shared()
cute.arch.barrier(
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
)
@@ -1534,7 +1533,7 @@ class FlashAttentionBackwardSm90:
sdK = sK if const_expr(not self.dKV_swapAB) else utils.transpose_view(sK)
taccdKsdK = smem_thr_copy_dK.partition_D(sdK)
cute.copy(smem_copy_atom_dKV, taccdKrdK, taccdKsdK)
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
cute.arch.fence_view_async_shared()
cute.arch.barrier(
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
)
@@ -1573,7 +1572,7 @@ class FlashAttentionBackwardSm90:
acc_dK.iterator, cute.make_layout(tdKsdKVaccum.shape)
)
cute.autovec_copy(tdKrdKaccum_flat, tdKsdKVaccum)
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
cute.arch.fence_view_async_shared()
cute.arch.barrier(
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
)
@@ -1597,7 +1596,7 @@ class FlashAttentionBackwardSm90:
acc_dV.iterator, cute.make_layout(tdKsdKVaccum.shape)
)
cute.autovec_copy(tdVrdVaccum_flat, tdKsdKVaccum)
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
cute.arch.fence_view_async_shared()
cute.arch.barrier(
barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_mma_threads
)
+7 -11
View File
@@ -16,18 +16,17 @@ import cutlass
import cutlass.cute as cute
from cutlass import Constexpr, Float32, Int32, const_expr, Boolean
from cutlass.cute.nvgpu import cpasync, warp, warpgroup
from cutlass.cute.arch import ProxyKind, SharedSpace
import cutlass.utils as utils_basic
from cutlass.utils import LayoutEnum
import cutlass.utils.hopper_helpers as sm90_utils_basic
from quack import copy_utils as quack_copy_utils
from quack import copy_utils
from quack import sm90_utils
from flash_attn.cute import ampere_helpers as sm80_utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import hopper_helpers as sm90_utils
from flash_attn.cute import utils
from flash_attn.cute import copy_utils
from flash_attn.cute.mask import AttentionMask
from flash_attn.cute.softmax import Softmax, apply_score_mod_inner
from flash_attn.cute.seqlen_info import SeqlenInfoQK
@@ -357,7 +356,7 @@ class FlashAttentionForwardBase:
smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)
taccOrO = smem_thr_copy_O.retile(rO)
taccOsO = smem_thr_copy_O.partition_D(sO)
# taccOsO = quack_copy_utils.partition_D_position_independent(smem_thr_copy_O, sO)
# taccOsO = copy_utils.partition_D_position_independent(smem_thr_copy_O, sO)
# copy acc O from rmem to smem with the smem copy atom
cute.copy(smem_copy_atom_O, taccOrO, taccOsO)
@@ -406,7 +405,7 @@ class FlashAttentionForwardBase:
# sync to make sure all smem stores are done
if const_expr(self.use_tma_O):
# ensure smem writes are visible to TMA
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
cute.arch.fence_view_async_shared()
cute.arch.barrier_arrive(
barrier_id=int(NamedBarrierFwd.Epilogue),
number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
@@ -1220,7 +1219,6 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
return tiled_mma_qk, tiled_mma_pv, tiled_mma_pv_rs
def _get_shared_storage_cls(self):
# If we use cp.async to load Q, we want sQ to align to 1024 bytes
sQ_struct, sK_struct, sV_struct = [
cute.struct.Align[cute.struct.MemRange[self.dtype, cute.cosize(layout)], self.buffer_align_bytes]
for layout in (self.sQ_layout, self.sK_layout, self.sV_layout)
@@ -2247,9 +2245,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
tPrP = smem_copy_params.smem_thr_copy_P.retile(tOrP_cur)
cute.copy(smem_copy_params.smem_thr_copy_P, tPrP, smem_copy_params.tPsP)
# Fence and barrier to make smem store visible to WGMMA
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta
)
cute.arch.fence_view_async_shared()
cute.arch.sync_warp()
return kv_consumer_state
@@ -2320,7 +2316,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
softmax.rescale_O(acc_O, row_scale)
if const_expr(not self.mma_pv_is_rs):
# Fence and barrier to make sure smem store is visible to WGMMA
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
cute.arch.fence_view_async_shared()
cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
pipeline_v.consumer_wait(smem_pipe_read, pipeline_v.consumer_try_wait(smem_pipe_read))
self.warp_scheduler_barrier_sync()
@@ -2387,7 +2383,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
softmax.rescale_O(acc_O, row_scale)
if const_expr(not self.mma_pv_is_rs):
# Fence and barrier to make sure smem store is visible to WGMMA
cute.arch.fence_proxy(ProxyKind.async_shared, space=SharedSpace.shared_cta)
cute.arch.fence_view_async_shared()
cute.arch.sync_warp() # Only need syncwarp since each warp is using its own P values for MmaPV
return smem_pipe_read
+1 -4
View File
@@ -2428,10 +2428,7 @@ class FlashAttentionForwardSm100:
tOrO_frg_cvt.store(tOrO_frg.load().to(self.o_dtype))
cute.copy(tiled_smem_store, tOrO_frg_cvt, tOsO_r2s_i)
# fence view async shared
cute.arch.fence_proxy(
cute.arch.ProxyKind.async_shared,
space=cute.arch.SharedSpace.shared_cta,
)
cute.arch.fence_view_async_shared()
if const_expr(self.use_correction_warps_for_epi):
assert(not self.use_tma_O)