mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-07-01 21:04:02 -04:00
[DSL] Use cute.math.{exp2,log2,log}
This commit is contained in:
@@ -14,7 +14,6 @@ from cutlass import Float32, Int32, const_expr
|
||||
|
||||
# Import data structures from block_sparsity
|
||||
from flash_attn.cute.block_sparsity import BlockSparseTensors
|
||||
from flash_attn.cute import utils
|
||||
from flash_attn.cute import copy_utils
|
||||
from flash_attn.cute.named_barrier import NamedBarrierBwd
|
||||
|
||||
@@ -698,8 +697,8 @@ def handle_block_sparse_empty_tile_correction_sm100(
|
||||
row_max_value = sink_val * (LOG2_E / softmax_scale_log2)
|
||||
row_sum_value = Float32(1.0)
|
||||
else:
|
||||
row_sum_value = row_sum_value + utils.exp2f(
|
||||
sink_val * LOG2_E - row_max_value * softmax_scale_log2
|
||||
row_sum_value = row_sum_value + cute.math.exp2(
|
||||
sink_val * LOG2_E - row_max_value * softmax_scale_log2, fastmath=True
|
||||
)
|
||||
if tidx < m_block_size:
|
||||
scale_row_idx = tidx + stage * m_block_size
|
||||
|
||||
@@ -882,7 +882,7 @@ class FlashAttentionBackwardSm80:
|
||||
# if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE)
|
||||
assert cute.size(acc_S_mn, mode=[0]) == cute.size(tLSErLSE)
|
||||
for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True):
|
||||
acc_S_mn[r, None].store(utils.exp2f(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r]))
|
||||
acc_S_mn[r, None].store(cute.math.exp2(acc_S_mn[r, None].load() * softmax_scale_log2 - tLSErLSE[r], fastmath=True))
|
||||
# if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn)
|
||||
|
||||
# MMA dP
|
||||
|
||||
@@ -540,13 +540,15 @@ class FlashAttentionForwardCombine:
|
||||
LOG2_E = math.log2(math.e)
|
||||
lse_sum_cur = 0.0
|
||||
for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
|
||||
scale = utils.exp2f(ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E))
|
||||
scale = cute.math.exp2(
|
||||
ts2rrLSE[0, s, m] * LOG2_E - (lse_max_cur * LOG2_E), fastmath=True
|
||||
)
|
||||
lse_sum_cur += scale
|
||||
ts2rrLSE[0, s, m] = scale # Store scale for later use
|
||||
lse_sum_cur = cute.arch.warp_reduction_sum(
|
||||
lse_sum_cur, threads_in_group=threads_per_col
|
||||
)
|
||||
lse_sum[m] = utils.logf(lse_sum_cur) + lse_max
|
||||
lse_sum[m] = cute.math.log(lse_sum_cur, fastmath=True) + lse_max
|
||||
# Normalize scales
|
||||
inv_sum = (
|
||||
0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
|
||||
|
||||
@@ -1863,7 +1863,7 @@ class FlashAttentionForwardSm100:
|
||||
# )
|
||||
# LN2 = math.log(2.0)
|
||||
# lse = (
|
||||
# (softmax.row_max[0] * softmax.scale_log2 + utils.log2f(softmax.row_sum[0])) * LN2
|
||||
# (softmax.row_max[0] * softmax.scale_log2 + cute.math.log2(softmax.row_sum[0], fastmath=True)) * LN2
|
||||
# if not acc_O_mn_row_is_zero_or_nan else -Float32.inf
|
||||
# )
|
||||
# if const_expr(not seqlen.has_cu_seqlens_q):
|
||||
@@ -2004,7 +2004,7 @@ class FlashAttentionForwardSm100:
|
||||
mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase
|
||||
)
|
||||
softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first)
|
||||
# acc_scale = cute.arch.exp2(acc_scale_)
|
||||
# acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
|
||||
return mma_si_consumer_phase ^ 1, si_corr_producer_phase ^ 1, s0_s1_sequence_phase ^ 1
|
||||
|
||||
@cute.jit
|
||||
@@ -2170,8 +2170,8 @@ class FlashAttentionForwardSm100:
|
||||
row_max = sink_val * (LOG2_E / softmax_scale_log2)
|
||||
row_sum = Float32(1.0)
|
||||
else:
|
||||
row_sum += utils.exp2f(
|
||||
sink_val * LOG2_E - row_max * softmax_scale_log2
|
||||
row_sum += cute.math.exp2(
|
||||
sink_val * LOG2_E - row_max * softmax_scale_log2, fastmath=True
|
||||
)
|
||||
acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum
|
||||
stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
|
||||
@@ -2276,7 +2276,7 @@ class FlashAttentionForwardSm100:
|
||||
# cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan)
|
||||
LN2 = math.log(2.0)
|
||||
lse = (
|
||||
(row_max * softmax_scale_log2 + utils.log2f(row_sum)) * LN2
|
||||
(row_max * softmax_scale_log2 + cute.math.log2(row_sum, fastmath=True)) * LN2
|
||||
if not acc_O_mn_row_is_zero_or_nan
|
||||
else -Float32.inf
|
||||
)
|
||||
|
||||
+29
-21
@@ -92,16 +92,20 @@ class Softmax(ParamsBase):
|
||||
|
||||
if cutlass.const_expr(is_first):
|
||||
row_max_cur_scaled = row_max_cur * scale_log2
|
||||
acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled)
|
||||
|
||||
acc_S_row_exp = cute.math.exp2(
|
||||
acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True
|
||||
)
|
||||
acc_S_row_sum = utils.fadd_reduce(acc_S_row_exp, init_val=None, arch=arch)
|
||||
row_scale[r] = 1.0
|
||||
else:
|
||||
row_max_cur_scaled = row_max_cur * scale_log2
|
||||
acc_S_row_exp = utils.exp2f(acc_S_row * scale_log2 - row_max_cur_scaled)
|
||||
# row_scale[r] = utils.exp2f(row_max_prev * self.scale_log2 - row_max_cur_scaled)
|
||||
row_scale[r] = utils.exp2f((row_max_prev - row_max_cur) * scale_log2)
|
||||
|
||||
acc_S_row_exp = cute.math.exp2(
|
||||
acc_S_row * scale_log2 - row_max_cur_scaled, fastmath=True
|
||||
)
|
||||
# row_scale[r] = cute.math.exp2(row_max_prev * self.scale_log2 - row_max_cur_scaled)
|
||||
row_scale[r] = cute.math.exp2(
|
||||
(row_max_prev - row_max_cur) * scale_log2, fastmath=True
|
||||
)
|
||||
acc_S_row_sum = utils.fadd_reduce(
|
||||
acc_S_row_exp, init_val=row_sum[r] * row_scale[r], arch=arch
|
||||
)
|
||||
@@ -130,7 +134,9 @@ class Softmax(ParamsBase):
|
||||
if cutlass.const_expr(sink_val is not None):
|
||||
sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r]
|
||||
LOG2_E = math.log2(math.e)
|
||||
row_sum[r] += utils.exp2f(sink_val_cur * LOG2_E - row_max[r] * scale_log2)
|
||||
row_sum[r] += cute.math.exp2(
|
||||
sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True
|
||||
)
|
||||
|
||||
# if row_sum is zero or nan, set acc_O_mn_row to 1.0
|
||||
acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r]
|
||||
@@ -140,7 +146,7 @@ class Softmax(ParamsBase):
|
||||
row_sum_cur = row_sum[r]
|
||||
LN2 = math.log(2.0)
|
||||
row_sum[r] = (
|
||||
(row_max[r] * scale_log2 + utils.log2f(row_sum_cur)) * LN2
|
||||
(row_max[r] * scale_log2 + cute.math.log2(row_sum_cur, fastmath=True)) * LN2
|
||||
if not acc_O_mn_row_is_zero_or_nan
|
||||
else -Float32.inf
|
||||
)
|
||||
@@ -195,7 +201,7 @@ class SoftmaxSm100(Softmax):
|
||||
row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old)
|
||||
row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0
|
||||
acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2
|
||||
acc_scale = utils.exp2f(acc_scale_)
|
||||
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
|
||||
if cutlass.const_expr(self.rescale_threshold > 0.0):
|
||||
if acc_scale_ >= -self.rescale_threshold:
|
||||
row_max_new = row_max_old
|
||||
@@ -249,17 +255,19 @@ class SoftmaxSm100(Softmax):
|
||||
)
|
||||
for j in cutlass.range_constexpr(frg_cnt):
|
||||
for k in cutlass.range_constexpr(0, cute.size(acc_S_row_frg, mode=[0]), 2):
|
||||
# acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j])
|
||||
# acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j])
|
||||
# acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
|
||||
# acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
|
||||
if cutlass.const_expr(not e2e):
|
||||
acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j])
|
||||
acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j])
|
||||
acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
|
||||
acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
|
||||
else:
|
||||
if cutlass.const_expr(
|
||||
k % e2e_freq < e2e_freq - e2e_res or j >= frg_cnt - e2e_frg_limit
|
||||
):
|
||||
acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j])
|
||||
acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j])
|
||||
acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
|
||||
acc_S_row_frg[k + 1, j] = cute.math.exp2(
|
||||
acc_S_row_frg[k + 1, j], fastmath=True
|
||||
)
|
||||
else:
|
||||
# acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.e2e_asm2(acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j])
|
||||
acc_S_row_frg[k, j], acc_S_row_frg[k + 1, j] = utils.ex2_emulation_2(
|
||||
@@ -291,8 +299,8 @@ class SoftmaxSm100(Softmax):
|
||||
# (self.scale_log2, self.scale_log2),
|
||||
# (minus_row_max_scaled, minus_row_max_scaled),
|
||||
# )
|
||||
# acc_S_row[i] = cute.arch.exp2(acc_S_row[i])
|
||||
# acc_S_row[i + 1] = cute.arch.exp2(acc_S_row[i + 1])
|
||||
# acc_S_row[i] = cute.math.exp2(acc_S_row[i], fastmath=True)
|
||||
# acc_S_row[i + 1] = cute.math.exp2(acc_S_row[i + 1], fastmath=True)
|
||||
|
||||
frg_tile = 32
|
||||
assert frg_tile % 2 == 0
|
||||
@@ -311,10 +319,10 @@ class SoftmaxSm100(Softmax):
|
||||
# (minus_row_max_scaled, minus_row_max_scaled),
|
||||
# )
|
||||
# )
|
||||
# acc_S_row_frg[k, j] = utils.exp2f(acc_S_row_frg[k, j])
|
||||
# acc_S_row_frg[k + 1, j] = utils.exp2f(acc_S_row_frg[k + 1, j])
|
||||
acc_S_row_frg[k, j] = cute.arch.exp2(acc_S_row_frg[k, j])
|
||||
acc_S_row_frg[k + 1, j] = cute.arch.exp2(acc_S_row_frg[k + 1, j])
|
||||
# acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
|
||||
# acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
|
||||
acc_S_row_frg[k, j] = cute.math.exp2(acc_S_row_frg[k, j], fastmath=True)
|
||||
acc_S_row_frg[k + 1, j] = cute.math.exp2(acc_S_row_frg[k + 1, j], fastmath=True)
|
||||
acc_S_row_converted_frg[None, j].store(
|
||||
acc_S_row_frg[None, j].load().to(acc_S_row_converted.element_type)
|
||||
)
|
||||
|
||||
@@ -200,44 +200,6 @@ def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle:
|
||||
raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")
|
||||
|
||||
|
||||
@cute.jit
def exp2f(x: cute.TensorSSA | Float32) -> cute.TensorSSA | Float32:
    """Apply exp2 to a scalar or, element by element, to a vector value.

    :param x: input value
    :type x: cute.TensorSSA or Float32
    :return: exp2 of ``x``; a vector input produces a vector result
    :rtype: cute.TensorSSA or Float32
    """
    if const_expr(not isinstance(x, cute.TensorSSA)):
        # Scalar input: a single direct call is enough.
        return cute.arch.exp2(x)
    else:
        # Vector input: stage the values in a Float32 fragment, overwrite each
        # element with its exp2, then load the fragment back as the result.
        frag = cute.make_fragment(x.shape, Float32)
        frag.store(x)
        for idx in cutlass.range_constexpr(cute.size(x.shape)):
            frag[idx] = cute.arch.exp2(frag[idx])
        return frag.load()
|
||||
|
||||
|
||||
@dsl_user_op
def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
    """Return log2 of ``a`` using one inline-assembly instruction.

    The value is converted to ``Float32`` and passed to the instruction
    string ``lg2.approx.ftz.f32`` (per its mnemonic, an approximate f32
    variant with flush-to-zero).

    :param a: input value
    :type a: float or Float32
    :param loc: forwarded to ``ir_value`` when materializing the operand
    :param ip: forwarded to ``ir_value`` when materializing the operand
    :return: result of the inline instruction wrapped as ``Float32``
    :rtype: Float32
    """
    # Build the pieces in the same order the nested call evaluated them.
    result_type = T.f32()
    operand = Float32(a).ir_value(loc=loc, ip=ip)
    asm_result = llvm.inline_asm(
        result_type,
        [operand],
        "lg2.approx.ftz.f32 $0, $1;",
        "=f,f",  # one f32 output register, one f32 input register
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return Float32(asm_result)
|
||||
|
||||
|
||||
@dsl_user_op
def logf(a: float | Float32, *, loc=None, ip=None) -> Float32:
    """Return the natural logarithm of ``a``, derived from ``log2f``.

    Uses the identity ``ln(a) = log2(a) * ln(2)``.

    :param a: input value
    :type a: float or Float32
    :param loc: forwarded to ``log2f``
    :param ip: forwarded to ``log2f``
    :return: ``log2f(a)`` scaled by ``ln(2)``
    :rtype: Float32
    """
    base2_log = log2f(a, loc=loc, ip=ip)
    ln_2 = math.log(2.0)  # host-side constant converting base-2 to base-e
    return base2_log * ln_2
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def fmax(
|
||||
a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
|
||||
|
||||
Reference in New Issue
Block a user