mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-07-01 21:04:02 -04:00
Remove the clamp to 24 (the chunk size) in mask_r2p_dual_bound
This commit is contained in:
@@ -86,13 +86,10 @@ def mask_r2p_dual_bound(
|
||||
ncol = const_expr(cute.size(X.shape))
|
||||
|
||||
for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
|
||||
# Don't need to clamp to 32 since the shr.u32 instruction does that already
|
||||
right_s = max(col_limit_right - s * 24, 0)
|
||||
left_s = max(col_limit_left - s * 24, 0)
|
||||
|
||||
# Clamp to chunk size
|
||||
right_s = min(right_s, 24)
|
||||
left_s = min(left_s, 24)
|
||||
|
||||
# XOR creates range mask: bits left_s..(right_s-1) are 1
|
||||
mask_right = (1 << right_s) - 1
|
||||
mask_left = (1 << left_s) - 1
|
||||
@@ -100,9 +97,9 @@ def mask_r2p_dual_bound(
|
||||
|
||||
# This needs to be range_constexpr, o/w the compiler can't generate the R2P instruction
|
||||
for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
|
||||
out_bound = cutlass.Boolean(mask_range & (1 << i))
|
||||
in_bound = cutlass.Boolean(mask_range & (1 << i))
|
||||
c = s * 24 + i
|
||||
X[c] = -Float32.inf if not out_bound else X[c]
|
||||
X[c] = X[c] if in_bound else -Float32.inf
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
Reference in New Issue
Block a user