[DSL] Use cute.math.{exp2,log2,log}

This commit is contained in:
Tri Dao
2026-02-08 18:02:19 +07:00
parent 8dd8019cef
commit 90f10faafd
6 changed files with 41 additions and 70 deletions
+2 -3
View File
@@ -14,7 +14,6 @@ from cutlass import Float32, Int32, const_expr
# Import data structures from block_sparsity
from flash_attn.cute.block_sparsity import BlockSparseTensors
from flash_attn.cute import utils
from flash_attn.cute import copy_utils
from flash_attn.cute.named_barrier import NamedBarrierBwd
@@ -698,8 +697,8 @@ def handle_block_sparse_empty_tile_correction_sm100(
row_max_value = sink_val * (LOG2_E / softmax_scale_log2)
row_sum_value = Float32(1.0)
else:
row_sum_value = row_sum_value + utils.exp2f(
sink_val * LOG2_E - row_max_value * softmax_scale_log2
row_sum_value = row_sum_value + cute.math.exp2(
sink_val * LOG2_E - row_max_value * softmax_scale_log2, fastmath=True
)
if tidx < m_block_size:
scale_row_idx = tidx + stage * m_block_size
+1 -1
View File
@@ -882,7 +882,7 @@ class FlashAttentionBackwardSm80:
# if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE)
assert cute.size(acc_S_mn, mode=[0]) == cute.size(tLSErLSE)
for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True):
acc_S_mn[r, None].store(utils.exp2f(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r]))
acc_S_mn[r, None].store(cute.math.exp2(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r], fastmath=True))
# if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)
# MMA dP
+4 -2
View File
@@ -540,13 +540,15 @@ class FlashAttentionForwardCombine:
LOG2_E = math.log2(math.e)
lse_sum_cur = 0.0
for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
scale = utils.exp2f(ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E))
scale = cute.math.exp2(
ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E), fastmath=True
)
lse_sum_cur += scale
ts2rrLSE[0, s, m] = scale # Store scale for later use
lse_sum_cur = cute.arch.warp_reduction_sum(
lse_sum_cur, threads_in_group=threads_per_col
)
lse_sum[m] = utils.logf(lse_sum_cur) + lse_max
lse_sum[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max
# Normalize scales
inv_sum = (
0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
+5 -5
View File
@@ -1863,7 +1863,7 @@ class FlashAttentionForwardSm100:
# )
# LN2 = math.log(2.0)
# lse = (
# (softmax.row_max[0] * softmax.scale_log2 + utils.log2f(softmax.row_sum[0])) * LN2
# (softmax.row_max[0] * softmax.scale_log2 + cute.math.log2(softmax.row_sum[0], fastmath=True)) * LN2
# if not acc_O_mn_row_is_zero_or_nan else -Float32.inf
# )
# if const_expr(not seqlen.has_cu_seqlens_q):
@@ -2004,7 +2004,7 @@ class FlashAttentionForwardSm100:
mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase
)
softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first)
# acc_scale = cute.arch.exp2(acc_scale_)
# acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
return mma_si_consumer_phase ^ 1, si_corr_producer_phase ^ 1, s0_s1_sequence_phase ^ 1
@cute.jit
@@ -2170,8 +2170,8 @@ class FlashAttentionForwardSm100:
row_max = sink_val * (LOG2_E / softmax_scale_log2)
row_sum = Float32(1.0)
else:
row_sum += utils.exp2f(
sink_val * LOG2_E - row_max * softmax_scale_log2
row_sum += cute.math.exp2(
sink_val * LOG2_E - row_max * softmax_scale_log2, fastmath=True
)
acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum
stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
@@ -2276,7 +2276,7 @@ class FlashAttentionForwardSm100:
# cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
LN2 = math.log(2.0)
lse = (
(row_max * softmax_scale_log2 + utils.log2f(row_sum)) * LN2
(row_max * softmax_scale_log2 + cute.math.log2(row_sum, fastmath=True)) * LN2
if not acc_O_mn_row_is_zero_or_nan
else -Float32.inf
)
+29 -21
View File
@@ -92,16 +92,20 @@ class Softmax(ParamsBase):
if cutlass.const_expr(is_first):
row_max_cur_scaled = row_max_cur * scale_log2
acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled)
acc_S_row_exp = cute.math.exp2(
acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True
)
acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch)
row_scale[r] = 1.0
else:
row_max_cur_scaled = row_max_cur * scale_log2
acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled)
# row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled)
row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * scale_log2)
acc_S_row_exp = cute.math.exp2(
acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True
)
# row_scale[r] = cute.math.exp2(row_max_prev * self.scale_log2 - row_max_cur_scaled)
row_scale[r] = cute.math.exp2(
(row_max_prev - row_max_cur) * scale_log2, fastmath=True
)
acc_S_row_sum = utils.fadd_reduce(
acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch
)
@@ -130,7 +134,9 @@ class Softmax(ParamsBase):
if cutlass.const_expr(sink_val is not None):
sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r]
LOG2_E = math.log2(math.e)
row_sum[r] += utils.exp2f(sink_val_cur * LOG2_E - row_max[r] * scale_log2)
row_sum[r] += cute.math.exp2(
sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True
)
# if row_sum is zero or nan, set acc_O_mn_row to 1.0
acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r]
@@ -140,7 +146,7 @@ class Softmax(ParamsBase):
row_sum_cur = row_sum[r]
LN2 = math.log(2.0)
row_sum[r] = (
(row_max[r] * scale_log2 + utils.log2f(row_sum_cur)) * LN2
(row_max[r] * scale_log2 + cute.math.log2(row_sum_cur, fastmath=True)) * LN2
if not acc_O_mn_row_is_zero_or_nan
else -Float32.inf
)
@@ -195,7 +201,7 @@ class SoftmaxSm100(Softmax):
row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old)
row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2
acc_scale = utils.exp2f(acc_scale_)
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
if cutlass.const_expr(self.rescale_threshold > 0.0):
if acc_scale_ >= -self.rescale_threshold:
row_max_new = row_max_old
@@ -249,17 +255,19 @@ class SoftmaxSm100(Softmax):
)
for j in cutlass.range_constexpr(frg_cnt):
for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):
# acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j])
# acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j])
# acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
# acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
if cutlass.const_expr(not e2e):
acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j])
acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j])
acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
else:
if cutlass.const_expr(
k % e2e_freq < e2e_freq - e2e_res or j >= frg_cnt - e2e_frg_limit
):
acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j])
acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j])
acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
acc_S_row_frg[k + 1, j] = cute.math.exp2(
acc_S_row_frg[k + 1, j], fastmath=True
)
else:
# acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j])
acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2(
@@ -291,8 +299,8 @@ class SoftmaxSm100(Softmax):
# (self.scale_log2, self.scale_log2),
# (minus_row_max_scaled, minus_row_max_scaled),
# )
# acc_S_row[i] = cute.arch.exp2(acc_S_row[i])
# acc_S_row[i + 1] = cute.arch.exp2(acc_S_row[i + 1])
# acc_S_row[i] = cute.math.exp2(acc_S_row[i], fastmath=True)
# acc_S_row[i + 1] = cute.math.exp2(acc_S_row[i + 1], fastmath=True)
frg_tile = 32
assert frg_tile % 2 == 0
@@ -311,10 +319,10 @@ class SoftmaxSm100(Softmax):
# (minus_row_max_scaled, minus_row_max_scaled),
# )
# )
# acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j])
# acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j])
acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j])
acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j])
# acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
# acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
acc_S_row_converted_frg[None, j].store(
acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)
)
-38
View File
@@ -200,44 +200,6 @@ def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle:
raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")
@cute.jit
def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
    """Evaluate ``cute.arch.exp2`` on a scalar or on each element of a vector.

    The dispatch on the input kind happens through ``const_expr``, so only
    one of the two paths is taken for a given call site.

    :param x: exponent value(s)
    :type x: cute.TensorSSA or Float32
    :return: ``cute.arch.exp2`` applied to ``x`` (element by element for a
        ``cute.TensorSSA``)
    :rtype: cute.TensorSSA or Float32
    """
    if const_expr(not isinstance(x, cute.TensorSSA)):
        # Scalar input: a single call is enough.
        return cute.arch.exp2(x)
    else:
        # Vector input: copy into a Float32 fragment of the same shape so the
        # elements can be indexed and overwritten one at a time.
        frag = cute.make_fragment(x.shape, Float32)
        frag.store(x)
        num_elems = cute.size(x.shape)
        for idx in cutlass.range_constexpr(num_elems):
            frag[idx] = cute.arch.exp2(frag[idx])
        return frag.load()
@dsl_user_op
def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
    """Return the base-2 logarithm of ``a`` via an inline ``lg2.approx.ftz.f32``.

    :param a: input value; it is converted to ``Float32`` before use
    :param loc: forwarded to ``ir_value`` when materializing the operand
    :param ip: forwarded to ``ir_value`` when materializing the operand
    :return: result of the inline-asm instruction wrapped as ``Float32``
    """
    # Single f32 operand handed to the inline assembly.
    operand = Float32(a).ir_value(loc=loc, ip=ip)
    asm_result = llvm.inline_asm(
        T.f32(),
        [operand],
        "lg2.approx.ftz.f32 $0, $1;",
        "=f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return Float32(asm_result)
@dsl_user_op
def logf(a: float | Float32, *, loc=None, ip=None) -> Float32:
    """Return the natural logarithm of ``a``, computed as ``log2f(a) * ln(2)``.

    :param a: input value
    :param loc: forwarded unchanged to ``log2f``
    :param ip: forwarded unchanged to ``log2f``
    :return: ``log2f(a)`` scaled by the host-side constant ``math.log(2.0)``
    """
    log2_of_a = log2f(a, loc=loc, ip=ip)
    # Change of base: ln(a) = log2(a) * ln(2).
    return log2_of_a * math.log(2.0)
@dsl_user_op
def fmax(
a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None