mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-07-01 21:04:02 -04:00
[Layout] Use reshape_acc_to_mn and reshape_acc_to_frgA from quack
This commit is contained in:
@@ -14,6 +14,7 @@ from cutlass.cute.nvgpu import cpasync, warp
|
||||
from cutlass import Float32, Int32
|
||||
import cutlass.utils as utils_basic
|
||||
|
||||
from quack import layout_utils
|
||||
from flash_attn.cute import ampere_helpers as sm80_utils
|
||||
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
|
||||
from flash_attn.cute import utils
|
||||
@@ -630,8 +631,8 @@ class FlashAttentionBackwardSm80:
|
||||
tdQrK = utils.mma_make_fragment_B(sKt, thr_mma_dq, swapAB=self.dQ_swapAB)
|
||||
|
||||
LSEslice = (None, 0, None) if cutlass.const_expr(not self.SdP_swapAB) else (0, None, None)
|
||||
tSsLSEMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sLSEMma))[LSEslice]
|
||||
tSsdPsumMma = utils.make_acc_tensor_mn_view(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice]
|
||||
tSsLSEMma = layout_utils.reshape_acc_to_mn(thr_mma_sdp.partition_C(sLSEMma))[LSEslice]
|
||||
tSsdPsumMma = layout_utils.reshape_acc_to_mn(thr_mma_sdp.partition_C(sdPsumMma))[LSEslice]
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Smem copy atom tiling
|
||||
@@ -875,7 +876,7 @@ class FlashAttentionBackwardSm80:
|
||||
)
|
||||
if cutlass.const_expr(mask_fn is not None):
|
||||
mask_fn(acc_S, m_block=m_block)
|
||||
acc_S_mn = utils.make_acc_tensor_mn_view(acc_S)
|
||||
acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
|
||||
bidx = 0
|
||||
# if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)
|
||||
# if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE)
|
||||
@@ -901,7 +902,7 @@ class FlashAttentionBackwardSm80:
|
||||
cute.autovec_copy(
|
||||
smem_copy_params.tSsdPsumMma[None, smem_pipe_read_do if cutlass.const_expr(self.num_stages_dO > 1) else 0], tLSErdPsum
|
||||
)
|
||||
acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP)
|
||||
acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP)
|
||||
# if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn)
|
||||
assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum)
|
||||
for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True):
|
||||
@@ -921,7 +922,7 @@ class FlashAttentionBackwardSm80:
|
||||
tdSrdS = smem_copy_params.r2s_thr_copy_PdS.retile(rdS)
|
||||
cute.copy(smem_copy_params.r2s_thr_copy_PdS, tdSrdS, smem_copy_params.tdSsdS)
|
||||
if cutlass.const_expr(self.Mma_dKV_is_RS):
|
||||
tdVrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout))
|
||||
tdVrP = layout_utils.reshape_acc_to_frgA(rP)
|
||||
else:
|
||||
tdVrP = mma_params.tdVrP
|
||||
|
||||
@@ -966,7 +967,7 @@ class FlashAttentionBackwardSm80:
|
||||
|
||||
# MMA dK
|
||||
if cutlass.const_expr(self.Mma_dKV_is_RS):
|
||||
tdKrdS = cute.make_tensor(rdS.iterator, utils.convert_layout_acc_frgA(rdS.layout))
|
||||
tdVrP = layout_utils.reshape_acc_to_frgA(rdS)
|
||||
else:
|
||||
tdKrdS = mma_params.tdKrdS
|
||||
sm80_utils.gemm(
|
||||
|
||||
@@ -14,7 +14,7 @@ from cutlass.cute.nvgpu import cpasync, warp, warpgroup
|
||||
from cutlass import Float32, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
|
||||
import quack.sm90_utils as sm90_utils
|
||||
from quack import sm90_utils
|
||||
|
||||
from flash_attn.cute import utils
|
||||
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
|
||||
|
||||
@@ -12,7 +12,8 @@ from cutlass.cute import FastDivmodDivisor
|
||||
from cutlass import Float32, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
|
||||
import quack.sm90_utils as sm90_utils
|
||||
from quack import layout_utils
|
||||
from quack import sm90_utils
|
||||
from quack.sm90_utils import gemm_zero_init, gemm_w_idx
|
||||
|
||||
from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned
|
||||
@@ -1100,8 +1101,8 @@ class FlashAttentionBackwardSm90:
|
||||
sLSE_mma = utils.transpose_view(sLSE_mma)
|
||||
sdPsum_mma = utils.transpose_view(sdPsum_mma)
|
||||
LSEslice = (None, 0, None) if const_expr(not self.SdP_swapAB) else (0, None, None)
|
||||
tLSEsLSE = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sLSE_mma))[LSEslice]
|
||||
tLSEsdPsum = utils.make_acc_tensor_mn_view(thr_mma_SdP.partition_C(sdPsum_mma))[LSEslice]
|
||||
tLSEsLSE = layout_utils.reshape_acc_to_mn(thr_mma_SdP.partition_C(sLSE_mma))[LSEslice]
|
||||
tLSEsdPsum = layout_utils.reshape_acc_to_mn(thr_mma_SdP.partition_C(sdPsum_mma))[LSEslice]
|
||||
|
||||
smem_thr_copy_dQaccum = r2s_tiled_copy_dQaccum.get_slice(tidx)
|
||||
tdQsdQaccum = smem_thr_copy_dQaccum.partition_D(sdQaccum)
|
||||
@@ -1331,7 +1332,7 @@ class FlashAttentionBackwardSm90:
|
||||
# (3) [Pointwise 1] P = exp(S - LSE)
|
||||
if cutlass.const_expr(mask_fn is not None):
|
||||
mask_fn(acc_S, m_block=m_block)
|
||||
acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.SdP_swapAB)
|
||||
acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.SdP_swapAB)
|
||||
for r in cutlass.range_constexpr(cute.size(acc_S_mn, mode=[0])):
|
||||
for c in cutlass.range(cute.size(acc_S_mn, mode=[1]), unroll_full=True):
|
||||
acc_S_mn[r, c] = cute.math.exp2(
|
||||
@@ -1340,7 +1341,7 @@ class FlashAttentionBackwardSm90:
|
||||
tLSErdPsum = copy_utils.load_s2r(tLSEsdPsum[None, smem_idx_dO])
|
||||
|
||||
# Convert P from f32 -> f16
|
||||
tdVrP = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_S), self.dtype)
|
||||
tdVrP = utils.cvt_f16(layout_utils.reshape_acc_to_frgA(acc_S), self.dtype)
|
||||
# R2S for P
|
||||
if const_expr(not self.mma_dkv_is_rs):
|
||||
# sync to ensure P has already been used in the previous iteration before overwriting
|
||||
@@ -1353,7 +1354,7 @@ class FlashAttentionBackwardSm90:
|
||||
|
||||
# (4) [Pointwise 2] dS = P*(dP-dPsum)
|
||||
warpgroup.wait_group(0)
|
||||
acc_dP_mn = utils.make_acc_tensor_mn_view(acc_dP, transpose=self.SdP_swapAB)
|
||||
acc_dP_mn = layout_utils.reshape_acc_to_mn(acc_dP, transpose=self.SdP_swapAB)
|
||||
for r in cutlass.range_constexpr(cute.size(acc_dP_mn, mode=[0])):
|
||||
for c in cutlass.range(cute.size(acc_dP_mn, mode=[1]), unroll_full=True):
|
||||
acc_dP_mn[r, c] = acc_S_mn[r, c] * (acc_dP_mn[r, c] - tLSErdPsum[r])
|
||||
@@ -1374,7 +1375,7 @@ class FlashAttentionBackwardSm90:
|
||||
)
|
||||
|
||||
# Convert dS from f32 -> f16
|
||||
tdKrdS = utils.cvt_f16(utils.make_acc_tensor_frgA_view(acc_dP), self.dtype)
|
||||
tdKrdS = utils.cvt_f16(layout_utils.reshape_acc_to_frgA(acc_dP), self.dtype)
|
||||
|
||||
# If there's double buffering on dS, we don't need to sync here.
|
||||
# Otherwise we might have WG1 writing to dS before WG2 is done reading from it during MmadQ.
|
||||
|
||||
@@ -21,6 +21,7 @@ from cutlass.utils import LayoutEnum
|
||||
import cutlass.utils.hopper_helpers as sm90_utils_basic
|
||||
|
||||
from quack import copy_utils
|
||||
from quack import layout_utils
|
||||
from quack import sm90_utils
|
||||
|
||||
from flash_attn.cute import ampere_helpers as sm80_utils
|
||||
@@ -378,10 +379,10 @@ class FlashAttentionForwardBase:
|
||||
)
|
||||
gLSE_expanded = cute.make_tensor(gLSE.iterator, gLSE_expanded_layout)
|
||||
thr_mma = tiled_mma.get_slice(tidx)
|
||||
taccOgLSE = utils.make_acc_tensor_mn_view(thr_mma.partition_C(gLSE_expanded))
|
||||
taccOgLSE = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(gLSE_expanded))
|
||||
assert cute.size(taccOgLSE, mode=[0]) == cute.size(lse)
|
||||
taccOcO = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cO))
|
||||
t0accOcO = utils.make_acc_tensor_mn_view(thr_mma.get_slice(0).partition_C(cO))
|
||||
taccOcO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cO))
|
||||
t0accOcO = layout_utils.reshape_acc_to_mn(thr_mma.get_slice(0).partition_C(cO))
|
||||
# Only the thread corresponding to column 0 writes out the lse to gmem
|
||||
if taccOcO[0][1] == 0:
|
||||
for m in cutlass.range_constexpr(cute.size(taccOgLSE.shape[1])):
|
||||
@@ -1125,7 +1126,7 @@ class FlashAttentionForwardSm80(FlashAttentionForwardBase):
|
||||
softmax.rescale_O(mma_params.acc_O, row_scale)
|
||||
rP = cute.make_fragment_like(acc_S, self.dtype)
|
||||
rP.store(acc_S.load().to(self.dtype))
|
||||
tOrP = cute.make_tensor(rP.iterator, utils.convert_layout_acc_frgA(rP.layout))
|
||||
tOrP = layout_utils.reshape_acc_to_frgA(rP)
|
||||
if const_expr(self.num_stages > 1):
|
||||
sync()
|
||||
load_K_next()
|
||||
@@ -2140,7 +2141,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
|
||||
else: # Each thread might have a different sink value due to different q_head
|
||||
sink_val = cute.make_fragment_like(softmax.row_max, Float32)
|
||||
cS = cute.make_identity_tensor((self.tile_m, self.tile_n))
|
||||
tScS_mn = utils.make_acc_tensor_mn_view(thr_mma_qk.partition_C(cS))
|
||||
tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma_qk.partition_C(cS))
|
||||
for r in cutlass.range(cute.size(sink_val), unroll_full=True):
|
||||
row = m_block * self.tile_m + tScS_mn[r][0]
|
||||
q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead
|
||||
@@ -2205,7 +2206,7 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
|
||||
|
||||
softmax.online_softmax(acc_S, is_first=is_first_block)
|
||||
|
||||
tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
|
||||
tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
|
||||
tOrP_cur = (
|
||||
tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
|
||||
)
|
||||
@@ -2270,8 +2271,8 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
|
||||
mask_fn(acc_S=acc_S, n_block=n_block)
|
||||
|
||||
row_scale = softmax.online_softmax(acc_S, is_first=is_first_n_block, check_inf=check_inf)
|
||||
# if cute.arch.thread_idx()[0] == 0: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
|
||||
tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
|
||||
# if cute.arch.thread_idx()[0] == 0: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S))
|
||||
tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
|
||||
tOrP_cur = (
|
||||
tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
|
||||
)
|
||||
@@ -2332,12 +2333,12 @@ class FlashAttentionForwardSm90(FlashAttentionForwardBase):
|
||||
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
|
||||
if const_expr(mask_fn is not None):
|
||||
mask_fn(acc_S=acc_S, n_block=n_block)
|
||||
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
|
||||
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(layout_utils.reshape_acc_to_mn(acc_S))
|
||||
|
||||
row_scale = softmax.online_softmax(acc_S, check_inf=check_inf)
|
||||
warpgroup.wait_group(0)
|
||||
pipeline_v.consumer_release(smem_pipe_read_v)
|
||||
tOrP_acc = cute.make_tensor(acc_S.iterator, utils.convert_layout_acc_frgA(acc_S.layout))
|
||||
tOrP_acc = layout_utils.reshape_acc_to_frgA(acc_S)
|
||||
tOrP_cur = (
|
||||
tOrP if const_expr(self.mma_pv_is_rs) else cute.make_fragment_like(tOrP_acc, self.dtype)
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass import Float32, Int32, const_expr
|
||||
|
||||
from quack import layout_utils
|
||||
import flash_attn.cute.utils as utils
|
||||
from flash_attn.cute.seqlen_info import SeqlenInfoQK
|
||||
|
||||
@@ -140,13 +141,13 @@ class AttentionMask:
|
||||
fastdiv_mods=(None, None),
|
||||
) -> None:
|
||||
assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
|
||||
acc_S_mn = utils.make_acc_tensor_mn_view(acc_S, transpose=self.swap_AB)
|
||||
acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S, transpose=self.swap_AB)
|
||||
acc_shape = (self.tile_m, self.tile_n)
|
||||
cS = cute.make_identity_tensor(acc_shape if not self.swap_AB else acc_shape[::-1])
|
||||
tScS_mn = utils.make_acc_tensor_mn_view(thr_mma.partition_C(cS), transpose=self.swap_AB)
|
||||
tScS_mn = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cS), transpose=self.swap_AB)
|
||||
# We use t0ScS as these indices are known at compile time. We then must subtract the
|
||||
# column limit by the thread column offset.
|
||||
t0ScS_mn = utils.make_acc_tensor_mn_view(
|
||||
t0ScS_mn = layout_utils.reshape_acc_to_mn(
|
||||
thr_mma.get_slice(0).partition_C(cS), transpose=self.swap_AB
|
||||
)
|
||||
ROW = 0 if const_expr(not self.swap_AB) else 1
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
|
||||
from quack import layout_utils
|
||||
import flash_attn.cute.utils as utils
|
||||
|
||||
|
||||
@@ -98,7 +99,7 @@ class PackGQA:
|
||||
thr_mma = tiled_mma.get_slice(tidx)
|
||||
caccO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
|
||||
taccOcO = thr_mma.partition_C(caccO)
|
||||
taccOcO_row = utils.make_acc_tensor_mn_view(taccOcO)[None, 0]
|
||||
taccOcO_row = layout_utils.reshape_acc_to_mn(taccOcO)[None, 0]
|
||||
assert cute.size(tLSErLSE) == cute.size(taccOcO_row)
|
||||
threads_per_row = tiled_mma.tv_layout_C.shape[0][0]
|
||||
assert cute.arch.WARP_SIZE % threads_per_row == 0, "threads_per_row must divide WARP_SIZE"
|
||||
|
||||
@@ -28,7 +28,7 @@ dependencies = [
|
||||
"typing_extensions",
|
||||
"apache-tvm-ffi>=0.1.5,<0.2",
|
||||
"torch-c-dlpack-ext",
|
||||
"quack-kernels>=0.2.7",
|
||||
"quack-kernels>=0.2.8",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -9,6 +9,7 @@ import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass import Float32
|
||||
|
||||
from quack import layout_utils
|
||||
import flash_attn.cute.utils as utils
|
||||
from flash_attn.cute.cute_dsl_utils import ParamsBase
|
||||
from flash_attn.cute.seqlen_info import SeqlenInfoQK
|
||||
@@ -63,7 +64,7 @@ class Softmax(ParamsBase):
|
||||
:type is_first: cutlass.Constexpr
|
||||
"""
|
||||
# Change acc_S to M,N layout view.
|
||||
acc_S_mn = utils.make_acc_tensor_mn_view(acc_S)
|
||||
acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S)
|
||||
row_scale = cute.make_fragment_like(self.row_max, Float32)
|
||||
|
||||
row_max = self.row_max
|
||||
@@ -153,7 +154,7 @@ class Softmax(ParamsBase):
|
||||
:param row_scale: row_scale tensor
|
||||
:type row_scale: cute.Tensor
|
||||
"""
|
||||
acc_O_mn = utils.make_acc_tensor_mn_view(acc_O)
|
||||
acc_O_mn = layout_utils.reshape_acc_to_mn(acc_O)
|
||||
assert cute.size(row_scale) == cute.size(acc_O_mn, mode=[0])
|
||||
for r in cutlass.range(cute.size(row_scale), unroll_full=True):
|
||||
acc_O_mn[r, None].store(acc_O_mn[r, None].load() * row_scale[r])
|
||||
|
||||
@@ -163,85 +163,6 @@ def warp_reduce(
|
||||
return val
|
||||
|
||||
|
||||
def convert_layout_acc_mn(acc_layout: cute.Layout, transpose: bool = False) -> cute.Layout:
|
||||
"""
|
||||
For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...).
|
||||
For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...).
|
||||
"""
|
||||
acc_layout_col_major = cute.make_layout(acc_layout.shape)
|
||||
shape = (
|
||||
(acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M
|
||||
(
|
||||
acc_layout_col_major.shape[0][0],
|
||||
*acc_layout_col_major.shape[0][2:],
|
||||
acc_layout_col_major.shape[2],
|
||||
), # MMA_N
|
||||
*acc_layout_col_major.shape[3:],
|
||||
)
|
||||
stride = (
|
||||
(acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M
|
||||
(
|
||||
acc_layout_col_major.stride[0][0],
|
||||
*acc_layout_col_major.stride[0][2:],
|
||||
acc_layout_col_major.stride[2],
|
||||
), # MMA_N
|
||||
*acc_layout_col_major.stride[3:],
|
||||
)
|
||||
if const_expr(transpose):
|
||||
shape = (shape[1], shape[0], *shape[2:])
|
||||
stride = (stride[1], stride[0], *stride[2:])
|
||||
acc_layout_mn = cute.make_layout(shape, stride=stride)
|
||||
return cute.composition(acc_layout, acc_layout_mn)
|
||||
|
||||
|
||||
def make_acc_tensor_mn_view(acc: cute.Tensor, transpose: bool = False) -> cute.Tensor:
|
||||
return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout, transpose=transpose))
|
||||
|
||||
|
||||
@cute.jit
|
||||
def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout:
|
||||
# For back to back gemm, convert layout of acc0 to gemm 1 accept layout.
|
||||
# For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
||||
# For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
|
||||
# TODO: Sm90 FP8
|
||||
if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90
|
||||
l = cute.logical_divide(
|
||||
acc_layout, ((None, None, 2), None, None)
|
||||
) # ((2, 2, (2, N / 16)), MMA_M, MMA_N)
|
||||
rA_mma_view = cute.make_layout(
|
||||
(
|
||||
(l.shape[0][0], l.shape[0][1], l.shape[0][2][0]),
|
||||
l.shape[1],
|
||||
(l.shape[0][2][1], l.shape[2]),
|
||||
),
|
||||
stride=(
|
||||
(l.stride[0][0], l.stride[0][1], l.stride[0][2][0]),
|
||||
l.stride[1],
|
||||
(l.stride[0][2][1], l.stride[2]),
|
||||
),
|
||||
)
|
||||
else: # Sm80
|
||||
# (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2))
|
||||
l = cute.logical_divide(acc_layout, (None, None, 2))
|
||||
rA_mma_view = cute.make_layout(
|
||||
(
|
||||
(l.shape[0], l.shape[2][0]),
|
||||
l.shape[1],
|
||||
l.shape[2][1],
|
||||
),
|
||||
stride=(
|
||||
(l.stride[0], l.stride[2][0]),
|
||||
l.stride[1],
|
||||
l.stride[2][1],
|
||||
),
|
||||
)
|
||||
return rA_mma_view
|
||||
|
||||
|
||||
def make_acc_tensor_frgA_view(acc: cute.Tensor) -> cute.Tensor:
|
||||
return cute.make_tensor(acc.iterator, convert_layout_acc_frgA(acc.layout))
|
||||
|
||||
|
||||
def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor:
|
||||
return cute.make_tensor(a.iterator, cute.select(a.layout, mode))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user