[Layout] Use reshape_acc_to_mn and reshape_acc_to_frgA from quack

This commit is contained in:
Tri Dao
2026-02-08 16:48:16 +07:00
parent 7edcf59c9e
commit b735ef24c2
9 changed files with 37 additions and 110 deletions
+7 -6
View File
@@ -14,6 +14,7 @@ from cutlass.cute.nvgpu import cpasync, warp
from cutlass import Float32, Int32
import cutlass.utils as utils_basic
from quack import layout_utils
from flash_attn.cute import ampere_helpers as sm80_utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
from flash_attn.cute import utils
@@ -630,8 +631,8 @@ class FlashAttentionBackwardSm80:
tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB)
LSEslice = (None, 0, None) if cutlass.const_expr(not self.SdP_swapAB) else (0, None, None)
tSsLSEMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sLSEMma))[LSEslice]
tSsdPsumMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice]
tSsLSEMma = layout_utils.reshape_acc_to_mn(thr_mma_sdp.partition_C(sLSEMma))[LSEslice]
tSsdPsumMma = layout_utils.reshape_acc_to_mn(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice]
# ///////////////////////////////////////////////////////////////////////////////
# Smem copy atom tiling
@@ -875,7 +876,7 @@ class FlashAttentionBackwardSm80:
)
if cutlass.const_expr(mask_fn is not None):
mask_fn(acc_S, m_block=m_block)
acc_S_mn = utils.make_acc_tensor_mn_view(acc_S)
acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
bidx = 0
# if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)
# if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE)
@@ -901,7 +902,7 @@ class FlashAttentionBackwardSm80:
cute.autovec_copy(
smem_copy_params.tSsdPsumMma[None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], tLSErdPsum
)
acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP)
acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP)
# if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)
assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum)
for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True):
@@ -921,7 +922,7 @@ class FlashAttentionBackwardSm80:
tdSrdS = smem_copy_params.r2s_thr_copy_PdS.retile(rdS)
cute.copy(smem_copy_params.r2s_thr_copy_PdS, tdSrdS, smem_copy_params.tdSsdS)
if cutlass.const_expr(self.Mma_dKV_is_RS):
tdVrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout))
tdVrP = layout_utils.reshape_acc_to_frgA(rP)
else:
tdVrP = mma_params.tdVrP
@@ -966,7 +967,7 @@ class FlashAttentionBackwardSm80:
# MMA dK
if cutlass.const_expr(self.Mma_dKV_is_RS):
tdKrdS = cute.make_tensor(rdS.iterator, utils.convert_layout_acc_frgA(rdS.layout))
tdKrdS = layout_utils.reshape_acc_to_frgA(rdS)
else:
tdKrdS = mma_params.tdKrdS
sm80_utils.gemm(
+1 -1
View File
@@ -14,7 +14,7 @@ from cutlass.cute.nvgpu import cpasync, warp, warpgroup
from cutlass import Float32, const_expr
from cutlass.utils import LayoutEnum
import quack.sm90_utils as sm90_utils
from quack import sm90_utils
from flash_attn.cute import utils
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
+8 -7
View File
@@ -12,7 +12,8 @@ from cutlass.cute import FastDivmodDivisor
from cutlass import Float32, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
import quack.sm90_utils as sm90_utils
from quack import layout_utils
from quack import sm90_utils
from quack.sm90_utils import gemm_zero_init, gemm_w_idx
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
@@ -1100,8 +1101,8 @@ class FlashAttentionBackwardSm90:
sLSE_mma = utils.transpose_view(sLSE_mma)
sdPsum_mma = utils.transpose_view(sdPsum_mma)
LSEslice = (None, 0, None) if const_expr(not self.SdP_swapAB) else (0, None, None)
tLSEsLSE = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sLSE_mma))[LSEslice]
tLSEsdPsum = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sdPsum_mma))[LSEslice]
tLSEsLSE = layout_utils.reshape_acc_to_mn(thr_mma_SdP.partition_C(sLSE_mma))[LSEslice]
tLSEsdPsum = layout_utils.reshape_acc_to_mn(thr_mma_SdP.partition_C(sdPsum_mma))[LSEslice]
smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx)
tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum)
@@ -1331,7 +1332,7 @@ class FlashAttentionBackwardSm90:
# (3) [Pointwise 1] P = exp(S - LSE)
if cutlass.const_expr(mask_fn is not None):
mask_fn(acc_S, m_block=m_block)
acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.SdP_swapAB)
acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.SdP_swapAB)
for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])):
for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True):
acc_S_mn[r, c] = cute.math.exp2(
@@ -1340,7 +1341,7 @@ class FlashAttentionBackwardSm90:
tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO])
# Convert P from f32 -> f16
tdVrP = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_S), self.dtype)
tdVrP = utils.cvt_f16(layout_utils.reshape_acc_to_frgA(acc_S), self.dtype)
# R2S for P
if const_expr(not self.mma_dkv_is_rs):
# sync to ensure P has already been used in the previous iteration before overwriting
@@ -1353,7 +1354,7 @@ class FlashAttentionBackwardSm90:
# (4) [Pointwise 2] dS = P*(dP-dPsum)
warpgroup.wait_group(0)
acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP, transpose=self.SdP_swapAB)
acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP, transpose=self.SdP_swapAB)
for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])):
for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True):
acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r])
@@ -1374,7 +1375,7 @@ class FlashAttentionBackwardSm90:
)
# Convert dS from f32 -> f16
tdKrdS = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_dP), self.dtype)
tdKrdS = utils.cvt_f16(layout_utils.reshape_acc_to_frgA(acc_dP), self.dtype)
# If there's double buffering on dS, we don't need to sync here.
# Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ.
+11 -10
View File
@@ -21,6 +21,7 @@ from cutlass.utils import LayoutEnum
import cutlass.utils.hopper_helpers as sm90_utils_basic
from quack import copy_utils
from quack import layout_utils
from quack import sm90_utils
from flash_attn.cute import ampere_helpers as sm80_utils
@@ -378,10 +379,10 @@ class FlashAttentionForwardBase:
)
gLSE_expanded = cute.make_tensor(gLSE.iterator, gLSE_expanded_layout)
thr_mma = tiled_mma.get_slice(tidx)
taccOgLSE = utils.make_acc_tensor_mn_view(thr_mma.partition_C(gLSE_expanded))
taccOgLSE = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(gLSE_expanded))
assert cute.size(taccOgLSE, mode=[0]) == cute.size(lse)
taccOcO = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cO))
t0accOcO = utils.make_acc_tensor_mn_view(thr_mma.get_slice(0).partition_C(cO))
taccOcO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cO))
t0accOcO = layout_utils.reshape_acc_to_mn(thr_mma.get_slice(0).partition_C(cO))
# Only the thread corresponding to column 0 writes out the lse to gmem
if taccOcO[0][1] == 0:
for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])):
@@ -1125,7 +1126,7 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
softmax.rescale_O(mma_params.acc_O, row_scale)
rP = cute.make_fragment_like(acc_S, self.dtype)
rP.store(acc_S.load().to(self.dtype))
tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout))
tOrP = layout_utils.reshape_acc_to_frgA(rP)
if const_expr(self.num_stages > 1):
sync()
load_K_next()
@@ -2140,7 +2141,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
else: # Each thread might have a different sink value due to different q_head
sink_val = cute.make_fragment_like(softmax.row_max, Float32)
cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
tScS_mn = utils.make_acc_tensor_mn_view(thr_mma_qk.partition_C(cS))
tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma_qk.partition_C(cS))
for r in cutlass.range(cute.size(sink_val), unroll_full=True):
row = m_block * self.tile_m + tScS_mn[r][0]
q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead
@@ -2205,7 +2206,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
softmax.online_softmax(acc_S, is_first=is_first_block)
tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
tOrP_cur = (
tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
)
@@ -2270,8 +2271,8 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
mask_fn(acc_S=acc_S, n_block=n_block)
row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf)
# if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
# if cute.arch.thread_idx()[0] == 0: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S))
tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
tOrP_cur = (
tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
)
@@ -2332,12 +2333,12 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
if const_expr(mask_fn is not None):
mask_fn(acc_S=acc_S, n_block=n_block)
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S))
row_scale = softmax.online_softmax(acc_S, check_inf=check_inf)
warpgroup.wait_group(0)
pipeline_v.consumer_release(smem_pipe_read_v)
tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
tOrP_cur = (
tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
)
+4 -3
View File
@@ -7,6 +7,7 @@ import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32, const_expr
from quack import layout_utils
import flash_attn.cute.utils as utils
from flash_attn.cute.seqlen_info import SeqlenInfoQK
@@ -140,13 +141,13 @@ class AttentionMask:
fastdiv_mods=(None, None),
) -> None:
assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.swap_AB)
acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.swap_AB)
acc_shape = (self.tile_m, self.tile_n)
cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
tScS_mn = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cS), transpose=self.swap_AB)
tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cS), transpose=self.swap_AB)
# We use t0ScS as these indices are known at compile time. We then must subtract the
# column limit by the thread column offset.
t0ScS_mn = utils.make_acc_tensor_mn_view(
t0ScS_mn = layout_utils.reshape_acc_to_mn(
thr_mma.get_slice(0).partition_C(cS), transpose=self.swap_AB
)
ROW = 0 if const_expr(not self.swap_AB) else 1
+2 -1
View File
@@ -4,6 +4,7 @@
import cutlass
import cutlass.cute as cute
from quack import layout_utils
import flash_attn.cute.utils as utils
@@ -98,7 +99,7 @@ class PackGQA:
thr_mma = tiled_mma.get_slice(tidx)
caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
taccOcO = thr_mma.partition_C(caccO)
taccOcO_row = utils.make_acc_tensor_mn_view(taccOcO)[None, 0]
taccOcO_row = layout_utils.reshape_acc_to_mn(taccOcO)[None, 0]
assert cute.size(tLSErLSE) == cute.size(taccOcO_row)
threads_per_row = tiled_mma.tv_layout_C.shape[0][0]
assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
+1 -1
View File
@@ -28,7 +28,7 @@ dependencies = [
"typing_extensions",
"apache-tvm-ffi>=0.1.5,<0.2",
"torch-c-dlpack-ext",
"quack-kernels>=0.2.7",
"quack-kernels>=0.2.8",
]
[project.optional-dependencies]
+3 -2
View File
@@ -9,6 +9,7 @@ import cutlass
import cutlass.cute as cute
from cutlass import Float32
from quack import layout_utils
import flash_attn.cute.utils as utils
from flash_attn.cute.cute_dsl_utils import ParamsBase
from flash_attn.cute.seqlen_info import SeqlenInfoQK
@@ -63,7 +64,7 @@ class Softmax(ParamsBase):
:type is_first: cutlass.Constexpr
"""
# Change acc_S to M,N layout view.
acc_S_mn = utils.make_acc_tensor_mn_view(acc_S)
acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
row_scale = cute.make_fragment_like(self.row_max, Float32)
row_max = self.row_max
@@ -153,7 +154,7 @@ class Softmax(ParamsBase):
:param row_scale: row_scale tensor
:type row_scale: cute.Tensor
"""
acc_O_mn = utils.make_acc_tensor_mn_view(acc_O)
acc_O_mn = layout_utils.reshape_acc_to_mn(acc_O)
assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0])
for r in cutlass.range(cute.size(row_scale), unroll_full=True):
acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r])
-79
View File
@@ -163,85 +163,6 @@ def warp_reduce(
return val
def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout:
    """Regroup an accumulator layout so its first two modes are (M, N).

    For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
    For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).

    :param acc_layout: accumulator layout whose mode 0 is the per-atom fragment
    :param transpose: when true, the M and N modes of the result are swapped
    :return: ``acc_layout`` composed with the regrouping layout
    """
    # Compact column-major layout with the same shape; its strides are used to
    # address the coordinates of ``acc_layout`` in the regrouped order.
    col_major = cute.make_layout(acc_layout.shape)
    frag_shape = col_major.shape[0]
    frag_stride = col_major.stride[0]

    # M mode: second fragment sub-mode paired with MMA_M.
    first_shape = (frag_shape[1], col_major.shape[1])
    first_stride = (frag_stride[1], col_major.stride[1])
    # N mode: first fragment sub-mode, any extra fragment sub-modes, then MMA_N.
    second_shape = (frag_shape[0], *frag_shape[2:], col_major.shape[2])
    second_stride = (frag_stride[0], *frag_stride[2:], col_major.stride[2])

    if const_expr(transpose):
        first_shape, second_shape = second_shape, first_shape
        first_stride, second_stride = second_stride, first_stride

    # Trailing modes (index 3 and up) are carried through unchanged.
    regrouped = cute.make_layout(
        (first_shape, second_shape, *col_major.shape[3:]),
        stride=(first_stride, second_stride, *col_major.stride[3:]),
    )
    return cute.composition(acc_layout, regrouped)
def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
    """Return a tensor over ``acc``'s iterator whose layout is regrouped by ``convert_layout_acc_mn``.

    :param acc: accumulator tensor to re-view
    :param transpose: forwarded to ``convert_layout_acc_mn``
    """
    mn_layout = convert_layout_acc_mn(acc.layout, transpose=transpose)
    return cute.make_tensor(acc.iterator, mn_layout)
@cute.jit
def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
    """Convert the accumulator layout of GEMM 0 to the layout GEMM 1 accepts (back-to-back GEMM).

    For Sm80, as the mma instruction shape is 16x8x16:
        (4, MMA_M, MMA_N) -> ((4, 2), MMA_M, MMA_N / 2)
    For Sm90, FP16/BF16:
        ((2, 2, N / 8), MMA_M, MMA_N) -> ((2, 2, 2), MMA_M, (N / 16, MMA_N))

    :param acc_layout: accumulator layout; a rank-3 mode 0 selects the Sm90 path
    :return: the regrouped layout
    """
    # TODO: Sm90 FP8
    if const_expr(cute.rank(acc_layout.shape[0]) == 3):  # Sm90
        # Split the third fragment sub-mode by 2: ((2, 2, (2, N / 16)), MMA_M, MMA_N)
        divided = cute.logical_divide(acc_layout, ((None, None, 2), None, None))
        frag_shape = divided.shape[0]
        frag_stride = divided.stride[0]
        # The inner half of the split stays in the fragment; the outer half joins MMA_N.
        view_shape = (
            (frag_shape[0], frag_shape[1], frag_shape[2][0]),
            divided.shape[1],
            (frag_shape[2][1], divided.shape[2]),
        )
        view_stride = (
            (frag_stride[0], frag_stride[1], frag_stride[2][0]),
            divided.stride[1],
            (frag_stride[2][1], divided.stride[2]),
        )
    else:  # Sm80
        # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
        divided = cute.logical_divide(acc_layout, (None, None, 2))
        # The inner half of the split MMA_N mode moves into the fragment mode.
        view_shape = (
            (divided.shape[0], divided.shape[2][0]),
            divided.shape[1],
            divided.shape[2][1],
        )
        view_stride = (
            (divided.stride[0], divided.stride[2][0]),
            divided.stride[1],
            divided.stride[2][1],
        )
    rA_mma_view = cute.make_layout(view_shape, stride=view_stride)
    return rA_mma_view
def make_acc_tensor_frgA_view(acc: cute.Tensor) -> cute.Tensor:
    """Return a tensor over ``acc``'s iterator whose layout is produced by ``convert_layout_acc_frgA``.

    :param acc: accumulator tensor to re-view
    """
    frgA_layout = convert_layout_acc_frgA(acc.layout)
    return cute.make_tensor(acc.iterator, frgA_layout)
def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
    """Return a tensor over ``a``'s iterator whose layout is ``cute.select(a.layout, mode)``.

    :param a: source tensor
    :param mode: mode indices passed to ``cute.select``
    """
    selected_layout = cute.select(a.layout, mode)
    return cute.make_tensor(a.iterator, selected_layout)