mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-07-01 21:04:02 -04:00
[DSL] warpgroup_reg_alloc -> setmaxregister_increase
This commit is contained in:
@@ -1093,18 +1093,18 @@ class FlashAttentionBackwardSm100:
|
||||
# EMPTY
|
||||
# (15)
|
||||
if warp_idx == self.empty_warp_id:
|
||||
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
|
||||
cute.arch.setmaxregister_decrease(self.num_regs_empty)
|
||||
|
||||
# EPI
|
||||
# (14)
|
||||
if warp_idx == self.epi_warp_id:
|
||||
# currently no-op, could use for tma store/reduce
|
||||
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
|
||||
cute.arch.setmaxregister_decrease(self.num_regs_empty)
|
||||
|
||||
# LOAD
|
||||
# (13)
|
||||
if warp_idx == self.load_warp_id:
|
||||
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
|
||||
cute.arch.setmaxregister_decrease(self.num_regs_other)
|
||||
self.load(
|
||||
thr_mma_S,
|
||||
thr_mma_dP,
|
||||
@@ -1141,7 +1141,7 @@ class FlashAttentionBackwardSm100:
|
||||
# MMA
|
||||
# (12)
|
||||
if warp_idx == self.mma_warp_id:
|
||||
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
|
||||
cute.arch.setmaxregister_decrease(self.num_regs_other)
|
||||
|
||||
# Alloc tmem buffer
|
||||
tmem_alloc_cols = Int32(self.tmem_alloc_cols)
|
||||
@@ -1194,7 +1194,7 @@ class FlashAttentionBackwardSm100:
|
||||
# Compute
|
||||
# (4, 5, 6, 7, 8, 9, 10, 11) --> 8 warps
|
||||
if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]:
|
||||
cute.arch.warpgroup_reg_alloc(self.num_regs_compute) # 8 warps
|
||||
cute.arch.setmaxregister_increase(self.num_regs_compute) # 8 warps
|
||||
self.compute_loop(
|
||||
thr_mma_S,
|
||||
thr_mma_dP,
|
||||
@@ -1239,7 +1239,7 @@ class FlashAttentionBackwardSm100:
|
||||
# Reduce
|
||||
# (0, 1, 2, 3) - dQ
|
||||
if warp_idx >= self.reduce_warp_ids[0] and warp_idx <= self.reduce_warp_ids[-1]:
|
||||
cute.arch.warpgroup_reg_alloc(self.num_regs_reduce)
|
||||
cute.arch.setmaxregister_increase(self.num_regs_reduce)
|
||||
self.dQacc_reduce(
|
||||
mdQaccum,
|
||||
sdQaccum,
|
||||
|
||||
@@ -640,7 +640,7 @@ class FlashAttentionBackwardSm90:
|
||||
TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
|
||||
|
||||
if warp_idx < 4:
|
||||
cute.arch.warpgroup_reg_dealloc(self.num_producer_regs)
|
||||
cute.arch.setmaxregister_decrease(self.num_producer_regs)
|
||||
if warp_idx == 0:
|
||||
self.load(
|
||||
mQ,
|
||||
@@ -682,7 +682,7 @@ class FlashAttentionBackwardSm90:
|
||||
blocksparse_tensors,
|
||||
)
|
||||
else:
|
||||
cute.arch.warpgroup_reg_alloc(self.num_mma_regs)
|
||||
cute.arch.setmaxregister_increase(self.num_mma_regs)
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
tidx = tidx - 128
|
||||
self.mma(
|
||||
|
||||
@@ -1659,7 +1659,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
|
||||
TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
|
||||
|
||||
if warp_idx < 4: # Producer
|
||||
cute.arch.warpgroup_reg_dealloc(self.num_producer_regs)
|
||||
cute.arch.setmaxregister_decrease(self.num_producer_regs)
|
||||
self.load(
|
||||
mQ,
|
||||
mK,
|
||||
@@ -1680,7 +1680,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
|
||||
)
|
||||
|
||||
else: # Consumer
|
||||
cute.arch.warpgroup_reg_alloc(self.num_mma_regs)
|
||||
cute.arch.setmaxregister_increase(self.num_mma_regs)
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Tile MMA compute thread partitions and allocate accumulators
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -951,13 +951,13 @@ class FlashAttentionForwardSm100:
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
for i in cutlass.range_constexpr(len(self.empty_warp_ids)):
|
||||
if warp_idx == self.empty_warp_ids[i]:
|
||||
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
|
||||
cute.arch.setmaxregister_decrease(self.num_regs_empty)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# LOAD
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
if warp_idx >= self.load_warp_ids[0] and warp_idx <= self.load_warp_ids[-1]:
|
||||
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
|
||||
cute.arch.setmaxregister_decrease(self.num_regs_other)
|
||||
self.load(
|
||||
thr_mma_qk,
|
||||
thr_mma_pv,
|
||||
@@ -985,7 +985,7 @@ class FlashAttentionForwardSm100:
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
if warp_idx == self.mma_warp_id:
|
||||
# if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids:
|
||||
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
|
||||
cute.arch.setmaxregister_decrease(self.num_regs_other)
|
||||
# Alloc tmem buffer
|
||||
tmem_alloc_cols = Int32(self.tmem_alloc_cols)
|
||||
if warp_idx == self.mma_warp_id:
|
||||
@@ -1028,7 +1028,7 @@ class FlashAttentionForwardSm100:
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
if const_expr(not self.use_correction_warps_for_epi):
|
||||
if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]:
|
||||
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
|
||||
cute.arch.setmaxregister_decrease(self.num_regs_other)
|
||||
self.epilogue_s2g(
|
||||
mO,
|
||||
sO,
|
||||
@@ -1049,7 +1049,7 @@ class FlashAttentionForwardSm100:
|
||||
(const_expr(self.q_stage == 1) and warp_idx <= self.softmax0_warp_ids[-1])
|
||||
):
|
||||
# increase register after decreasing
|
||||
cute.arch.warpgroup_reg_alloc(self.num_regs_softmax)
|
||||
cute.arch.setmaxregister_increase(self.num_regs_softmax)
|
||||
softmax_loop = partial(
|
||||
self.softmax_loop,
|
||||
softmax_scale_log2=softmax_scale_log2,
|
||||
@@ -1096,7 +1096,7 @@ class FlashAttentionForwardSm100:
|
||||
# Correction
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id:
|
||||
cute.arch.warpgroup_reg_dealloc(self.num_regs_correction)
|
||||
cute.arch.setmaxregister_decrease(self.num_regs_correction)
|
||||
self.correction_loop(
|
||||
thr_mma_qk,
|
||||
thr_mma_pv,
|
||||
|
||||
Reference in New Issue
Block a user