[DSL] warpgroup_reg_alloc -> setmaxregister_increase, warpgroup_reg_dealloc -> setmaxregister_decrease

This commit is contained in:
Tri Dao
2026-02-08 22:17:36 +07:00
parent 17d29436b8
commit 2a8d39c540
4 changed files with 16 additions and 16 deletions
+6 -6
View File
@@ -1093,18 +1093,18 @@ class FlashAttentionBackwardSm100:
# EMPTY
# (15)
if warp_idx == self.empty_warp_id:
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
cute.arch.setmaxregister_decrease(self.num_regs_empty)
# EPI
# (14)
if warp_idx == self.epi_warp_id:
# currently no-op, could use for tma store/reduce
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
cute.arch.setmaxregister_decrease(self.num_regs_empty)
# LOAD
# (13)
if warp_idx == self.load_warp_id:
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
cute.arch.setmaxregister_decrease(self.num_regs_other)
self.load(
thr_mma_S,
thr_mma_dP,
@@ -1141,7 +1141,7 @@ class FlashAttentionBackwardSm100:
# MMA
# (12)
if warp_idx == self.mma_warp_id:
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
cute.arch.setmaxregister_decrease(self.num_regs_other)
# Alloc tmem buffer
tmem_alloc_cols = Int32(self.tmem_alloc_cols)
@@ -1194,7 +1194,7 @@ class FlashAttentionBackwardSm100:
# Compute
# (4, 5, 6, 7, 8, 9, 10, 11) --> 8 warps
if warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1]:
cute.arch.warpgroup_reg_alloc(self.num_regs_compute) # 8 warps
cute.arch.setmaxregister_increase(self.num_regs_compute) # 8 warps
self.compute_loop(
thr_mma_S,
thr_mma_dP,
@@ -1239,7 +1239,7 @@ class FlashAttentionBackwardSm100:
# Reduce
# (0, 1, 2, 3) - dQ
if warp_idx >= self.reduce_warp_ids[0] and warp_idx <= self.reduce_warp_ids[-1]:
cute.arch.warpgroup_reg_alloc(self.num_regs_reduce)
cute.arch.setmaxregister_increase(self.num_regs_reduce)
self.dQacc_reduce(
mdQaccum,
sdQaccum,
+2 -2
View File
@@ -640,7 +640,7 @@ class FlashAttentionBackwardSm90:
TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
if warp_idx < 4:
cute.arch.warpgroup_reg_dealloc(self.num_producer_regs)
cute.arch.setmaxregister_decrease(self.num_producer_regs)
if warp_idx == 0:
self.load(
mQ,
@@ -682,7 +682,7 @@ class FlashAttentionBackwardSm90:
blocksparse_tensors,
)
else:
cute.arch.warpgroup_reg_alloc(self.num_mma_regs)
cute.arch.setmaxregister_increase(self.num_mma_regs)
tidx, _, _ = cute.arch.thread_idx()
tidx = tidx - 128
self.mma(
+2 -2
View File
@@ -1659,7 +1659,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
TileSchedulerCls = partial(TileScheduler.create, tile_sched_params)
if warp_idx < 4: # Producer
cute.arch.warpgroup_reg_dealloc(self.num_producer_regs)
cute.arch.setmaxregister_decrease(self.num_producer_regs)
self.load(
mQ,
mK,
@@ -1680,7 +1680,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
)
else: # Consumer
cute.arch.warpgroup_reg_alloc(self.num_mma_regs)
cute.arch.setmaxregister_increase(self.num_mma_regs)
# ///////////////////////////////////////////////////////////////////////////////
# Tile MMA compute thread partitions and allocate accumulators
# ///////////////////////////////////////////////////////////////////////////////
+6 -6
View File
@@ -951,13 +951,13 @@ class FlashAttentionForwardSm100:
# ///////////////////////////////////////////////////////////////////////////////
for i in cutlass.range_constexpr(len(self.empty_warp_ids)):
if warp_idx == self.empty_warp_ids[i]:
cute.arch.warpgroup_reg_dealloc(self.num_regs_empty)
cute.arch.setmaxregister_decrease(self.num_regs_empty)
# ///////////////////////////////////////////////////////////////////////////////
# LOAD
# ///////////////////////////////////////////////////////////////////////////////
if warp_idx >= self.load_warp_ids[0] and warp_idx <= self.load_warp_ids[-1]:
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
cute.arch.setmaxregister_decrease(self.num_regs_other)
self.load(
thr_mma_qk,
thr_mma_pv,
@@ -985,7 +985,7 @@ class FlashAttentionForwardSm100:
# ///////////////////////////////////////////////////////////////////////////////
if warp_idx == self.mma_warp_id:
# if warp_idx == self.mma_warp_id or warp_idx == self.empty_warp_ids:
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
cute.arch.setmaxregister_decrease(self.num_regs_other)
# Alloc tmem buffer
tmem_alloc_cols = Int32(self.tmem_alloc_cols)
if warp_idx == self.mma_warp_id:
@@ -1028,7 +1028,7 @@ class FlashAttentionForwardSm100:
# ///////////////////////////////////////////////////////////////////////////////
if const_expr(not self.use_correction_warps_for_epi):
if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]:
cute.arch.warpgroup_reg_dealloc(self.num_regs_other)
cute.arch.setmaxregister_decrease(self.num_regs_other)
self.epilogue_s2g(
mO,
sO,
@@ -1049,7 +1049,7 @@ class FlashAttentionForwardSm100:
(const_expr(self.q_stage == 1) and warp_idx <= self.softmax0_warp_ids[-1])
):
# increase register after decreasing
cute.arch.warpgroup_reg_alloc(self.num_regs_softmax)
cute.arch.setmaxregister_increase(self.num_regs_softmax)
softmax_loop = partial(
self.softmax_loop,
softmax_scale_log2=softmax_scale_log2,
@@ -1096,7 +1096,7 @@ class FlashAttentionForwardSm100:
# Correction
# ///////////////////////////////////////////////////////////////////////////////
if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id:
cute.arch.warpgroup_reg_dealloc(self.num_regs_correction)
cute.arch.setmaxregister_decrease(self.num_regs_correction)
self.correction_loop(
thr_mma_qk,
thr_mma_pv,