Merge pull request #2185 from henrylhtsang/test_local_r2p

[Cute,Fwd,Sm100] Add r2p for local mask
This commit is contained in:
Markus Hoehnerbach
2026-01-16 16:06:04 -08:00
committed by GitHub
+49 -8
View File
@@ -68,6 +68,43 @@ def mask_r2p_transposed(X: cute.Tensor, row_limit_top: Int32, num_rep: int) -> N
# cute.printf("tidx = {}, s = {}, i = {}, row_limit_top = {}, row_limit_top_s = {}, mask = {}, out_bound = {}", tidx, s, i, row_limit_top, row_limit_top_s, mask, out_bound)
@cute.jit
def mask_r2p_dual_bound(
X: cute.Tensor,
col_limit_left: Int32, # Inclusive lower bound
col_limit_right: Int32, # Exclusive upper bound
) -> None:
"""
Dual-bound masking using two bitmasks for SM100, following mask_r2p.
Masks elements where: NOT (col_limit_left <= col < col_limit_right)
Uses bit manipulation to create a range mask:
mask_right = (1 << right) - 1 -> bits (right-1)..0 are 1
mask_left = (1 << left) - 1 -> bits (left-1)..0 are 1
mask_range = mask_range = mask_right & ~ mask_left -> bits (right-1)..left are 1
"""
ncol = const_expr(cute.size(X.shape))
for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
right_s = max(col_limit_right - s * 24, 0)
left_s = max(col_limit_left - s * 24, 0)
# otherwise cute dsl complains about python int too large to convert into c long
right_s = min(right_s, 24)
left_s = min(left_s, 24)
# bits (right-1)..left are 1
mask_right = (1 << right_s) - 1
mask_left = (1 << left_s) - 1
mask_range = mask_right & ~mask_left
# This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
in_bound = cutlass.Boolean(mask_range & (1 << i))
c = s * 24 + i
X[c] = X[c] if in_bound else -Float32.inf
@dataclass(frozen=True)
class AttentionMask:
tile_m: cutlass.Constexpr[int]
@@ -444,14 +481,18 @@ class AttentionMask:
if const_expr(self.window_size_left is not None)
else 0
)
# if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left)
for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
col_idx = tScS_t2r[i][1]
acc_S[i] = (
-Float32.inf
if col_idx >= col_limit_right or col_idx < col_limit_left
else acc_S[i]
)
if const_expr(not r2p):
# if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left)
for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
col_idx = tScS_t2r[i][1]
acc_S[i] = (
-Float32.inf
if col_idx >= col_limit_right or col_idx < col_limit_left
else acc_S[i]
)
else:
# XOR-based R2P dual bound masking
mask_r2p_dual_bound(acc_S, col_limit_left, col_limit_right)
@cute.jit
def apply_mask_sm100_transposed(