[Cute,Fwd,Sm100] distributed offset calculation for paged KV (#2104)

* fully shard paged KV address calculation across threads

* use t0 indices for static bound checking

* increase tiled copy to full KV row

* shrink predicate tensor

* clarify paged KV divisibility constraints

* increase load register allocation
This commit is contained in:
timmy-feng
2026-01-15 15:11:01 -05:00
committed by GitHub
parent 68649fb784
commit fffabc3de1
2 changed files with 57 additions and 18 deletions
+4 -4
View File
@@ -201,12 +201,12 @@ class FlashAttentionForwardSm100:
self.tmem_vec_offset = self.tmem_s_offset
if self.head_dim_padded < 96:
self.num_regs_softmax = 200
self.num_regs_softmax = 200 if not paged_kv_non_tma else 184
self.num_regs_correction = 64
self.num_regs_other = 48
self.num_regs_other = 48 if not paged_kv_non_tma else 80
else:
# self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184
self.num_regs_softmax = 200
self.num_regs_softmax = 200 if not paged_kv_non_tma else 184
# self.num_regs_softmax = 176
# self.num_regs_correction = 96
# self.num_regs_correction = 80
@@ -215,7 +215,7 @@ class FlashAttentionForwardSm100:
# self.num_regs_other = 32
# self.num_regs_other = 64
# self.num_regs_other = 80
self.num_regs_other = 48
self.num_regs_other = 48 if not paged_kv_non_tma else 80
# self.num_regs_other = 96 if self.is_causal or self.is_local else 80
# self.num_regs_other = 64 if self.is_causal or self.is_local else 80
self.num_regs_empty = 24
+53 -14
View File
@@ -10,6 +10,8 @@ from flash_attn.cute import utils
from flash_attn.cute.cute_dsl_utils import ParamsBase
from cutlass.cute import FastDivmodDivisor
import math
@dataclass
class PagedKVManager(ParamsBase):
@@ -55,8 +57,16 @@ class PagedKVManager(ParamsBase):
dtype: Type[cutlass.Numeric],
):
universal_copy_bits = 128
gmem_threads_per_row = 8 # 8 threads loading 128 bits = 128 bytes = 1 cache line
async_copy_elems = universal_copy_bits // dtype.width
dtype_bytes = dtype.width // 8
gmem_k_block_size = math.gcd(
head_dim_padded,
head_dim_v_padded,
128 // dtype_bytes,
)
assert gmem_k_block_size % async_copy_elems == 0
gmem_threads_per_row = gmem_k_block_size // async_copy_elems
assert cute.arch.WARP_SIZE % gmem_threads_per_row == 0
atom_async_copy = cute.make_copy_atom(
cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL),
dtype,
@@ -69,7 +79,7 @@ class PagedKVManager(ParamsBase):
val_layout = cute.make_layout((1, async_copy_elems))
gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout)
gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx)
page_entry_per_thread = n_block_size * gmem_threads_per_row // num_threads
page_entry_per_thread = n_block_size // num_threads
tPrPage = cute.make_rmem_tensor((page_entry_per_thread,), Int32)
tPrPageOffset = cute.make_rmem_tensor((page_entry_per_thread,), Int32)
@@ -115,7 +125,12 @@ class PagedKVManager(ParamsBase):
@cute.jit
def load_page_table(self, n_block: Int32):
for i in cutlass.range(self.page_entry_per_thread, unroll=1):
row = (i * self.num_threads + self.thread_idx) // self.gmem_threads_per_row
row = (
i * self.num_threads
+ (self.thread_idx % self.gmem_threads_per_row)
* (self.num_threads // self.gmem_threads_per_row)
+ (self.thread_idx // self.gmem_threads_per_row)
)
row_idx = n_block * self.n_block_size + row
page_idx, page_offset = divmod(row_idx + self.leftpad_k, self.page_size_divmod)
@@ -128,10 +143,24 @@ class PagedKVManager(ParamsBase):
self.tPrPage[i] = page
self.tPrPageOffset[i] = page_offset
@cute.jit
def compute_X_ptr(self, K_or_V: str):
# Precompute one 64-bit address per page-table entry held by this thread,
# pointing into the paged K or V tensor.
#
# K_or_V: "K" selects self.mK_paged; any other value takes the else branch and
#   selects self.mV_paged (load_KV asserts the value is "K" or "V" before
#   calling this method).
# Returns: a cutlass.Int64 rmem tensor of self.page_entry_per_thread entries;
#   entry i is built from self.tPrPage[i] and self.tPrPageOffset[i].
tPrXPtr = cute.make_rmem_tensor((self.page_entry_per_thread,), cutlass.Int64)
for i in cutlass.range(self.page_entry_per_thread, unroll=1):
# (page, page_offset) pair that load_page_table stored for entry i.
page = self.tPrPage[i]
page_offset = self.tPrPageOffset[i]
# K and V put the in-page offset in different modes: K is addressed as
# (page_offset, 0, page) and V as (0, page_offset, page); the 0 picks the
# first element along the remaining mode.
# NOTE(review): const_expr presumably resolves this branch at trace time
# since K_or_V is a Python str — confirm.
if const_expr(K_or_V == "K"):
tPrXPtr[i] = utils.elem_pointer(self.mK_paged, (page_offset, 0, page)).toint()
else:
tPrXPtr[i] = utils.elem_pointer(self.mV_paged, (0, page_offset, page)).toint()
# load_KV redistributes these values between threads with utils.shuffle_sync.
return tPrXPtr
@cute.jit
def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str):
assert K_or_V in ("K", "V")
tPrXPtr = self.compute_X_ptr(K_or_V)
# Finesse sX layout to be (M, N).
sX_pi = cute.make_tensor(
sX.iterator,
@@ -149,27 +178,37 @@ class PagedKVManager(ParamsBase):
cX = cute.make_identity_tensor((self.n_block_size, head_dim))
tXsX = self.gmem_thr_copy_KV.partition_D(sX_pi)
tXcX = self.gmem_thr_copy_KV.partition_S(cX)
tXc0X = self.gmem_thr_copy_KV.get_slice(0).partition_S(cX)
seqlenk_row_limit = self.seqlen_k - n_block * self.n_block_size if n_block >= 0 else 0
seqlenk_row_limit = (
self.seqlen_k - n_block * self.n_block_size - tXcX[0][0] if n_block >= 0 else 0
)
for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])):
row_valid = tXcX[0, m, 0][0] < seqlenk_row_limit
should_load = cute.make_fragment_like(tXsX[None, m, 0], cute.Boolean)
row_valid = tXc0X[0, m, 0][0] < seqlenk_row_limit
should_load = cute.make_fragment_like(tXsX[(0, None), m, 0], cute.Boolean)
should_load.fill(row_valid)
page = self.tPrPage[m]
page_offset = self.tPrPageOffset[m]
mX_paged_cur = (
self.mK_paged[page_offset, None, page]
if const_expr(K_or_V == "K")
else self.mV_paged[None, page_offset, page]
x_ptr_i64 = utils.shuffle_sync(
tPrXPtr[m // self.gmem_threads_per_row],
m % self.gmem_threads_per_row,
width=self.gmem_threads_per_row,
)
x_gmem_ptr = cute.make_ptr(
self.mK_paged.element_type, x_ptr_i64, cute.AddressSpace.gmem, assumed_align=16
)
mX_paged_cur = cute.make_tensor(x_gmem_ptr, cute.make_layout((head_dim,)))
mX_paged_cur_copy = cute.tiled_divide(mX_paged_cur, (self.async_copy_elems,))
for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])):
ki = tXcX[0, 0, k][1] // self.async_copy_elems
mX_paged_cur_copy_ki = mX_paged_cur_copy[None, ki]
tXsX_k = tXsX[None, m, k]
mX_paged_cur_copy_ki = cute.make_tensor(
mX_paged_cur_copy_ki.iterator, tXsX_k.layout
)
cute.copy(
self.gmem_tiled_copy_KV,
mX_paged_cur_copy[None, ki],
tXsX[None, m, k],
mX_paged_cur_copy_ki,
tXsX_k,
pred=should_load,
)