mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-07-01 21:04:02 -04:00
[Cute,Fwd,Sm100] distributed offset calculation for paged KV (#2104)
* fully shard paged KV address calculation across threads * use t0 indices for static bound checking * increase tiled copy to full KV row * shrink predicate tensor * clarify paged KV divisibility constraints * increase load register allocation
This commit is contained in:
@@ -201,12 +201,12 @@ class FlashAttentionForwardSm100:
|
||||
self.tmem_vec_offset = self.tmem_s_offset
|
||||
|
||||
if self.head_dim_padded < 96:
|
||||
self.num_regs_softmax = 200
|
||||
self.num_regs_softmax = 200 if not paged_kv_non_tma else 184
|
||||
self.num_regs_correction = 64
|
||||
self.num_regs_other = 48
|
||||
self.num_regs_other = 48 if not paged_kv_non_tma else 80
|
||||
else:
|
||||
# self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184
|
||||
self.num_regs_softmax = 200
|
||||
self.num_regs_softmax = 200 if not paged_kv_non_tma else 184
|
||||
# self.num_regs_softmax = 176
|
||||
# self.num_regs_correction = 96
|
||||
# self.num_regs_correction = 80
|
||||
@@ -215,7 +215,7 @@ class FlashAttentionForwardSm100:
|
||||
# self.num_regs_other = 32
|
||||
# self.num_regs_other = 64
|
||||
# self.num_regs_other = 80
|
||||
self.num_regs_other = 48
|
||||
self.num_regs_other = 48 if not paged_kv_non_tma else 80
|
||||
# self.num_regs_other = 96 if self.is_causal or self.is_local else 80
|
||||
# self.num_regs_other = 64 if self.is_causal or self.is_local else 80
|
||||
self.num_regs_empty = 24
|
||||
|
||||
+53
-14
@@ -10,6 +10,8 @@ from flash_attn.cute import utils
|
||||
from flash_attn.cute.cute_dsl_utils import ParamsBase
|
||||
from cutlass.cute import FastDivmodDivisor
|
||||
|
||||
import math
|
||||
|
||||
|
||||
@dataclass
|
||||
class PagedKVManager(ParamsBase):
|
||||
@@ -55,8 +57,16 @@ class PagedKVManager(ParamsBase):
|
||||
dtype: Type[cutlass.Numeric],
|
||||
):
|
||||
universal_copy_bits = 128
|
||||
gmem_threads_per_row = 8 # 8 threads loading 128 bits = 128 bytes = 1 cache line
|
||||
async_copy_elems = universal_copy_bits // dtype.width
|
||||
dtype_bytes = dtype.width // 8
|
||||
gmem_k_block_size = math.gcd(
|
||||
head_dim_padded,
|
||||
head_dim_v_padded,
|
||||
128 // dtype_bytes,
|
||||
)
|
||||
assert gmem_k_block_size % async_copy_elems == 0
|
||||
gmem_threads_per_row = gmem_k_block_size // async_copy_elems
|
||||
assert cute.arch.WARP_SIZE % gmem_threads_per_row == 0
|
||||
atom_async_copy = cute.make_copy_atom(
|
||||
cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
|
||||
dtype,
|
||||
@@ -69,7 +79,7 @@ class PagedKVManager(ParamsBase):
|
||||
val_layout = cute.make_layout((1, async_copy_elems))
|
||||
gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout)
|
||||
gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx)
|
||||
page_entry_per_thread = n_block_size * gmem_threads_per_row // num_threads
|
||||
page_entry_per_thread = n_block_size // num_threads
|
||||
|
||||
tPrPage = cute.make_rmem_tensor((page_entry_per_thread,), Int32)
|
||||
tPrPageOffset = cute.make_rmem_tensor((page_entry_per_thread,), Int32)
|
||||
@@ -115,7 +125,12 @@ class PagedKVManager(ParamsBase):
|
||||
@cute.jit
|
||||
def load_page_table(self, n_block: Int32):
|
||||
for i in cutlass.range(self.page_entry_per_thread, unroll=1):
|
||||
row = (i * self.num_threads + self.thread_idx) // self.gmem_threads_per_row
|
||||
row = (
|
||||
i * self.num_threads
|
||||
+ (self.thread_idx % self.gmem_threads_per_row)
|
||||
* (self.num_threads // self.gmem_threads_per_row)
|
||||
+ (self.thread_idx // self.gmem_threads_per_row)
|
||||
)
|
||||
row_idx = n_block * self.n_block_size + row
|
||||
|
||||
page_idx, page_offset = divmod(row_idx + self.leftpad_k, self.page_size_divmod)
|
||||
@@ -128,10 +143,24 @@ class PagedKVManager(ParamsBase):
|
||||
self.tPrPage[i] = page
|
||||
self.tPrPageOffset[i] = page_offset
|
||||
|
||||
@cute.jit
|
||||
def compute_X_ptr(self, K_or_V: str):
|
||||
tPrXPtr = cute.make_rmem_tensor((self.page_entry_per_thread,), cutlass.Int64)
|
||||
for i in cutlass.range(self.page_entry_per_thread, unroll=1):
|
||||
page = self.tPrPage[i]
|
||||
page_offset = self.tPrPageOffset[i]
|
||||
if const_expr(K_or_V == "K"):
|
||||
tPrXPtr[i] = utils.elem_pointer(self.mK_paged, (page_offset, 0, page)).toint()
|
||||
else:
|
||||
tPrXPtr[i] = utils.elem_pointer(self.mV_paged, (0, page_offset, page)).toint()
|
||||
return tPrXPtr
|
||||
|
||||
@cute.jit
|
||||
def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str):
|
||||
assert K_or_V in ("K", "V")
|
||||
|
||||
tPrXPtr = self.compute_X_ptr(K_or_V)
|
||||
|
||||
# Finesse sX layout to be (M, N).
|
||||
sX_pi = cute.make_tensor(
|
||||
sX.iterator,
|
||||
@@ -149,27 +178,37 @@ class PagedKVManager(ParamsBase):
|
||||
cX = cute.make_identity_tensor((self.n_block_size, head_dim))
|
||||
tXsX = self.gmem_thr_copy_KV.partition_D(sX_pi)
|
||||
tXcX = self.gmem_thr_copy_KV.partition_S(cX)
|
||||
tXc0X = self.gmem_thr_copy_KV.get_slice(0).partition_S(cX)
|
||||
|
||||
seqlenk_row_limit = self.seqlen_k - n_block * self.n_block_size if n_block >= 0 else 0
|
||||
seqlenk_row_limit = (
|
||||
self.seqlen_k - n_block * self.n_block_size - tXcX[0][0] if n_block >= 0 else 0
|
||||
)
|
||||
for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])):
|
||||
row_valid = tXcX[0, m, 0][0] < seqlenk_row_limit
|
||||
should_load = cute.make_fragment_like(tXsX[None, m, 0], cute.Boolean)
|
||||
row_valid = tXc0X[0, m, 0][0] < seqlenk_row_limit
|
||||
should_load = cute.make_fragment_like(tXsX[(0, None), m, 0], cute.Boolean)
|
||||
should_load.fill(row_valid)
|
||||
|
||||
page = self.tPrPage[m]
|
||||
page_offset = self.tPrPageOffset[m]
|
||||
mX_paged_cur = (
|
||||
self.mK_paged[page_offset, None, page]
|
||||
if const_expr(K_or_V == "K")
|
||||
else self.mV_paged[None, page_offset, page]
|
||||
x_ptr_i64 = utils.shuffle_sync(
|
||||
tPrXPtr[m // self.gmem_threads_per_row],
|
||||
m % self.gmem_threads_per_row,
|
||||
width=self.gmem_threads_per_row,
|
||||
)
|
||||
x_gmem_ptr = cute.make_ptr(
|
||||
self.mK_paged.element_type, x_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
|
||||
)
|
||||
mX_paged_cur = cute.make_tensor(x_gmem_ptr, cute.make_layout((head_dim,)))
|
||||
mX_paged_cur_copy = cute.tiled_divide(mX_paged_cur, (self.async_copy_elems,))
|
||||
|
||||
for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])):
|
||||
ki = tXcX[0, 0, k][1] // self.async_copy_elems
|
||||
mX_paged_cur_copy_ki = mX_paged_cur_copy[None, ki]
|
||||
tXsX_k = tXsX[None, m, k]
|
||||
mX_paged_cur_copy_ki = cute.make_tensor(
|
||||
mX_paged_cur_copy_ki.iterator, tXsX_k.layout
|
||||
)
|
||||
cute.copy(
|
||||
self.gmem_tiled_copy_KV,
|
||||
mX_paged_cur_copy[None, ki],
|
||||
tXsX[None, m, k],
|
||||
mX_paged_cur_copy_ki,
|
||||
tXsX_k,
|
||||
pred=should_load,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user