remove 24 clamp

This commit is contained in:
Henry Tsang
2026-01-15 20:27:15 -08:00
parent e34d84057d
commit 08e65188b5
+3 -6
View File
@@ -86,13 +86,10 @@ def mask_r2p_dual_bound(
ncol = const_expr(cute.size(X.shape))
for s in cutlass.range_constexpr(cute.ceil_div(ncol, 24)):
# Don't need to clamp to 32 since the shr.u32 instruction does that already
right_s = max(col_limit_right - s * 24, 0)
left_s = max(col_limit_left - s * 24, 0)
# Clamp to chunk size
right_s = min(right_s, 24)
left_s = min(left_s, 24)
# XOR creates range mask: bits left_s..(right_s-1) are 1
mask_right = (1 << right_s) - 1
mask_left = (1 << left_s) - 1
@@ -100,9 +97,9 @@ def mask_r2p_dual_bound(
# This needs to be range_constexpr; otherwise the compiler can't generate the R2P instruction
for i in cutlass.range_constexpr(min(24, ncol - s * 24)):
out_bound = cutlass.Boolean(mask_range & (1 << i))
in_bound = cutlass.Boolean(mask_range & (1 << i))
c = s * 24 + i
X[c] = -Float32.inf if not out_bound else X[c]
X[c] = X[c] if in_bound else -Float32.inf
@dataclass(frozen=True)