mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-07-01 21:04:02 -04:00
Merge pull request #2185 from henrylhtsang/test_local_r2p
[Cute,Fwd,Sm100] Add r2p for local mask
This commit is contained in:
+49
-8
@@ -68,6 +68,43 @@ def mask_r2p_transposed(X: cute.Tensor, row_limit_top: Int32, num_rep: int) -> N
|
||||
# cute.printf("tidx = {}, s = {}, i = {}, row_limit_top = {}, row_limit_top_s = {}, mask = {}, out_bound = {}", tidx, s, i, row_limit_top, row_limit_top_s, mask, out_bound)
|
||||
|
||||
|
||||
@cute.jit
def mask_r2p_dual_bound(
    X: cute.Tensor,
    col_limit_left: Int32,  # Inclusive lower bound
    col_limit_right: Int32,  # Exclusive upper bound
) -> None:
    """
    Dual-bound masking using two bitmasks for SM100, following mask_r2p.

    Masks elements where: NOT (col_limit_left <= col < col_limit_right)
    Masked elements of X are overwritten in place with -inf; elements inside
    the range are left unchanged. Nothing is returned.

    Uses bit manipulation to create a range mask:
        mask_right = (1 << right) - 1           -> bits (right-1)..0 are 1
        mask_left  = (1 << left) - 1            -> bits (left-1)..0 are 1
        mask_range = mask_right & ~mask_left    -> bits (right-1)..left are 1

    Example (within one chunk): left = 2, right = 5 gives
        mask_right = 0b11111, mask_left = 0b00011, mask_range = 0b11100,
    so only elements 2, 3 and 4 are kept.

    X is processed in chunks of 24 elements; for chunk ``s`` both limits are
    rebased by ``s * 24`` and clamped to [0, 24] before the masks are built.
    If left >= right within a chunk, mask_range is 0 and the whole chunk is
    masked.

    Args:
        X: Tensor to mask. It is indexed with a single flat index
            ``s * 24 + i`` running over ``cute.size(X.shape)`` elements.
            NOTE(review): this treats flat element index ``c`` as the column
            index compared against the limits -- confirm this holds for the
            layout the caller passes in.
        col_limit_left: Inclusive lower bound of the columns to keep.
        col_limit_right: Exclusive upper bound of the columns to keep.
    """
    # Total element count of X, resolved as a compile-time constant.
    ncol = const_expr(cute.size(X.shape))

    # One iteration per 24-element chunk of X.
    for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
        # Rebase the limits to this chunk; limits below the chunk start clamp to 0.
        right_s = max(col_limit_right - s * 24, 0)
        left_s = max(col_limit_left - s * 24, 0)

        # otherwise cute dsl complains about python int too large to convert into c long
        right_s = min(right_s, 24)
        left_s = min(left_s, 24)

        # bits (right-1)..left are 1
        mask_right = (1 << right_s) - 1
        mask_left = (1 << left_s) - 1
        mask_range = mask_right & ~mask_left

        # This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
        # The last chunk may hold fewer than 24 elements, hence the min().
        for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
            # Bit i of mask_range set <=> element i of this chunk is inside the range.
            in_bound = cutlass.Boolean(mask_range & (1 << i))
            c = s * 24 + i
            X[c] = X[c] if in_bound else -Float32.inf
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AttentionMask:
|
||||
tile_m: cutlass.Constexpr[int]
|
||||
@@ -444,14 +481,18 @@ class AttentionMask:
|
||||
if const_expr(self.window_size_left is not None)
|
||||
else 0
|
||||
)
|
||||
# if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left)
|
||||
for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
|
||||
col_idx = tScS_t2r[i][1]
|
||||
acc_S[i] = (
|
||||
-Float32.inf
|
||||
if col_idx >= col_limit_right or col_idx < col_limit_left
|
||||
else acc_S[i]
|
||||
)
|
||||
if const_expr(not r2p):
|
||||
# if cute.arch.thread_idx()[0] == 0 or cute.arch.thread_idx()[0] == 128: cute.printf("m_block = {}, n_block = {}, row_idx = {}, causal_row_offset = {}, col_limit_right = {}, col_limit_left = {}", m_block, n_block, row_idx, causal_row_offset, col_limit_right, col_limit_left)
|
||||
for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
|
||||
col_idx = tScS_t2r[i][1]
|
||||
acc_S[i] = (
|
||||
-Float32.inf
|
||||
if col_idx >= col_limit_right or col_idx < col_limit_left
|
||||
else acc_S[i]
|
||||
)
|
||||
else:
|
||||
# R2P dual bound masking (range mask built as mask_right & ~mask_left)
|
||||
mask_r2p_dual_bound(acc_S, col_limit_left, col_limit_right)
|
||||
|
||||
@cute.jit
|
||||
def apply_mask_sm100_transposed(
|
||||
|
||||
Reference in New Issue
Block a user