mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-06-30 21:07:55 -04:00
ruff all the smaller files (#2040)
This commit is contained in:
@@ -7,19 +7,10 @@ repos:
|
||||
files: ^flash_attn/cute/.*\.py$
|
||||
exclude: &cute_exclude |
|
||||
(?x)^flash_attn/cute/(
|
||||
__init__|
|
||||
copy_utils|
|
||||
cute_dsl_utils|
|
||||
fast_math|
|
||||
flash_bwd|
|
||||
flash_fwd|
|
||||
flash_fwd_combine|
|
||||
flash_fwd_sm100|
|
||||
hopper_helpers|
|
||||
interface|
|
||||
pack_gqa|
|
||||
testing|
|
||||
utils
|
||||
)\.py$
|
||||
- id: ruff-format
|
||||
files: ^flash_attn/cute/.*\.py$
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
|
||||
|
||||
import math
|
||||
from typing import Optional, Type, Tuple, Callable
|
||||
from typing import Optional, Type, Callable
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass import Float32, Int32, Boolean, const_expr
|
||||
from cutlass import Float32, Int32, const_expr
|
||||
from cutlass.cute.nvgpu import cpasync
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cutlass_dsl import T, dsl_user_op
|
||||
@@ -279,7 +279,7 @@ def cpasync_bulk_get_copy_fn(
|
||||
dst[None, dst_idx].iterator,
|
||||
size=size,
|
||||
**new_kwargs,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def copy_bulk_single_stage(**new_kwargs):
|
||||
|
||||
@@ -55,8 +55,13 @@ class FlashAttentionForwardCombine:
|
||||
|
||||
@staticmethod
|
||||
def can_implement(
|
||||
dtype, dtype_partial, head_dim, m_block_size, k_block_size,
|
||||
log_max_splits, num_threads,
|
||||
dtype,
|
||||
dtype_partial,
|
||||
head_dim,
|
||||
m_block_size,
|
||||
k_block_size,
|
||||
log_max_splits,
|
||||
num_threads,
|
||||
) -> bool:
|
||||
"""Check if the kernel can be implemented with the given parameters."""
|
||||
if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]:
|
||||
@@ -83,8 +88,7 @@ class FlashAttentionForwardCombine:
|
||||
assert self.k_block_size % async_copy_elems == 0
|
||||
|
||||
k_block_gmem = (
|
||||
128 if self.k_block_size % 128 == 0 else
|
||||
(64 if self.k_block_size % 64 == 0 else 32)
|
||||
128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32)
|
||||
)
|
||||
gmem_threads_per_row = k_block_gmem // async_copy_elems
|
||||
assert self.num_threads % gmem_threads_per_row == 0
|
||||
@@ -111,16 +115,25 @@ class FlashAttentionForwardCombine:
|
||||
num_bits_per_copy=async_copy_elems * self.dtype.width,
|
||||
)
|
||||
self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(
|
||||
atom_universal_copy, tOpartial_layout, vOpartial_layout # 4 vals per store
|
||||
atom_universal_copy,
|
||||
tOpartial_layout,
|
||||
vOpartial_layout, # 4 vals per store
|
||||
)
|
||||
|
||||
# LSE copy setup with async copy (alignment = 1)
|
||||
lse_copy_bits = Float32.width # 1 element per copy, width is in bits
|
||||
m_block_smem = (
|
||||
128 if self.m_block_size % 128 == 0 else
|
||||
(64 if self.m_block_size % 64 == 0 else
|
||||
(32 if self.m_block_size % 32 == 0 else
|
||||
(16 if self.m_block_size % 16 == 0 else 8)))
|
||||
128
|
||||
if self.m_block_size % 128 == 0
|
||||
else (
|
||||
64
|
||||
if self.m_block_size % 64 == 0
|
||||
else (
|
||||
32
|
||||
if self.m_block_size % 32 == 0
|
||||
else (16 if self.m_block_size % 16 == 0 else 8)
|
||||
)
|
||||
)
|
||||
)
|
||||
gmem_threads_per_row_lse = m_block_smem
|
||||
assert self.num_threads % gmem_threads_per_row_lse == 0
|
||||
@@ -167,9 +180,7 @@ class FlashAttentionForwardCombine:
|
||||
else:
|
||||
smem_lse_swizzle = cute.make_swizzle(3, 2, 3)
|
||||
smem_layout_atom_lse = cute.make_composed_layout(
|
||||
smem_lse_swizzle,
|
||||
0,
|
||||
cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
|
||||
smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
|
||||
)
|
||||
self.smem_layout_lse = cute.tile_to_shape(
|
||||
smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1)
|
||||
@@ -177,11 +188,9 @@ class FlashAttentionForwardCombine:
|
||||
|
||||
# O partial shared memory layout (simple layout for pipeline stages)
|
||||
self.smem_layout_o = cute.make_ordered_layout(
|
||||
(self.m_block_size, self.k_block_size, self.stages),
|
||||
order=(1, 0, 2)
|
||||
(self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2)
|
||||
)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def __call__(
|
||||
self,
|
||||
@@ -200,38 +209,63 @@ class FlashAttentionForwardCombine:
|
||||
raise TypeError("O partial tensor must match dtype_partial")
|
||||
if const_expr(not (mO.element_type == self.dtype)):
|
||||
raise TypeError("O tensor must match dtype")
|
||||
if const_expr(not mLSE_partial.element_type in [Float32]):
|
||||
if const_expr(mLSE_partial.element_type not in [Float32]):
|
||||
raise TypeError("LSE partial tensor must be Float32")
|
||||
if const_expr(mLSE is not None and not mLSE.element_type in [Float32]):
|
||||
if const_expr(mLSE is not None and mLSE.element_type not in [Float32]):
|
||||
raise TypeError("LSE tensor must be Float32")
|
||||
|
||||
# Shape validation - input tensors are in user format, need to be converted to kernel format
|
||||
if const_expr(len(mO_partial.shape) not in [4, 5]):
|
||||
raise ValueError("O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)")
|
||||
raise ValueError(
|
||||
"O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)"
|
||||
)
|
||||
if const_expr(len(mLSE_partial.shape) not in [3, 4]):
|
||||
raise ValueError("LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)")
|
||||
raise ValueError(
|
||||
"LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)"
|
||||
)
|
||||
if const_expr(len(mO.shape) not in [3, 4]):
|
||||
raise ValueError("O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)")
|
||||
raise ValueError(
|
||||
"O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)"
|
||||
)
|
||||
if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]):
|
||||
raise ValueError("LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)")
|
||||
raise ValueError(
|
||||
"LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)"
|
||||
)
|
||||
|
||||
# Assume all strides are divisible by 128 bits except the last stride
|
||||
new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1])
|
||||
mO_partial, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mO_partial, mO)]
|
||||
new_stride = lambda t: (
|
||||
*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
|
||||
t.stride[-1],
|
||||
)
|
||||
mO_partial, mO = [
|
||||
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
|
||||
for t in (mO_partial, mO)
|
||||
]
|
||||
# (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)
|
||||
# or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h)
|
||||
O_partial_layout_transpose = [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
|
||||
O_partial_layout_transpose = (
|
||||
[2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
|
||||
)
|
||||
# (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h)
|
||||
mO_partial = cute.make_tensor(mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose))
|
||||
mO_partial = cute.make_tensor(
|
||||
mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)
|
||||
)
|
||||
O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1]
|
||||
mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))
|
||||
# (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b)
|
||||
# or (num_splits, total_q, h) -> (total_q, num_splits, h)
|
||||
LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2]
|
||||
mLSE_partial = cute.make_tensor(mLSE_partial.iterator, cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose))
|
||||
mLSE_partial = cute.make_tensor(
|
||||
mLSE_partial.iterator,
|
||||
cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose),
|
||||
)
|
||||
# (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h)
|
||||
LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1]
|
||||
mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None
|
||||
mLSE = (
|
||||
cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
|
||||
if mLSE is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Determine if we have variable length sequences
|
||||
varlen = const_expr(cu_seqlens is not None or seqused is not None)
|
||||
@@ -243,9 +277,7 @@ class FlashAttentionForwardCombine:
|
||||
sLSE: cute.struct.Align[
|
||||
cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
|
||||
]
|
||||
sMaxValidSplit: cute.struct.Align[
|
||||
cute.struct.MemRange[Int32, self.m_block_size], 128
|
||||
]
|
||||
sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128]
|
||||
sO: cute.struct.Align[
|
||||
cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
|
||||
]
|
||||
@@ -255,7 +287,11 @@ class FlashAttentionForwardCombine:
|
||||
# Grid dimensions: (ceil_div(seqlen, m_block), ceil_div(head_dim, k_block), num_head * batch)
|
||||
seqlen = mO_partial.shape[0]
|
||||
num_head = mO_partial.shape[3]
|
||||
batch_size = mO_partial.shape[4] if const_expr(cu_seqlens is None) else Int32(cu_seqlens.shape[0] - 1)
|
||||
batch_size = (
|
||||
mO_partial.shape[4]
|
||||
if const_expr(cu_seqlens is None)
|
||||
else Int32(cu_seqlens.shape[0] - 1)
|
||||
)
|
||||
|
||||
# Create FastDivmodDivisor objects for efficient division
|
||||
seqlen_divmod = FastDivmodDivisor(seqlen)
|
||||
@@ -330,14 +366,18 @@ class FlashAttentionForwardCombine:
|
||||
|
||||
# Handle semaphore reset
|
||||
if const_expr(semaphore_to_reset is not None):
|
||||
if (tidx == 0 and m_block == cute.arch.grid_dim()[0] - 1 and
|
||||
k_block == cute.arch.grid_dim()[1] - 1 and
|
||||
batch_idx == cute.arch.grid_dim()[2] - 1):
|
||||
if (
|
||||
tidx == 0
|
||||
and m_block == cute.arch.grid_dim()[0] - 1
|
||||
and k_block == cute.arch.grid_dim()[1] - 1
|
||||
and batch_idx == cute.arch.grid_dim()[2] - 1
|
||||
):
|
||||
semaphore_to_reset[0] = 0
|
||||
|
||||
# Get number of splits
|
||||
num_splits = (
|
||||
num_splits_dynamic_ptr[batch_idx] if const_expr(num_splits_dynamic_ptr is not None)
|
||||
num_splits_dynamic_ptr[batch_idx]
|
||||
if const_expr(num_splits_dynamic_ptr is not None)
|
||||
else mLSE_partial.shape[1]
|
||||
)
|
||||
# Handle variable length sequences using SeqlenInfo
|
||||
@@ -345,7 +385,7 @@ class FlashAttentionForwardCombine:
|
||||
batch_idx=batch_idx,
|
||||
seqlen_static=mO_partial.shape[0],
|
||||
cu_seqlens=cu_seqlens,
|
||||
seqused=seqused
|
||||
seqused=seqused,
|
||||
)
|
||||
seqlen, offset = seqlen_info.seqlen, seqlen_info.offset
|
||||
|
||||
@@ -354,8 +394,9 @@ class FlashAttentionForwardCombine:
|
||||
max_idx = seqlen * num_head
|
||||
|
||||
# Early exit for single split if dynamic
|
||||
if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (const_expr(not varlen) or m_block * self.m_block_size < max_idx):
|
||||
|
||||
if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (
|
||||
const_expr(not varlen) or m_block * self.m_block_size < max_idx
|
||||
):
|
||||
# ===============================
|
||||
# Step 1: Load LSE_partial from gmem to shared memory
|
||||
# ===============================
|
||||
@@ -390,7 +431,11 @@ class FlashAttentionForwardCombine:
|
||||
for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
|
||||
si = tLSEcLSE[0, s, 0][0] # Get split coordinate
|
||||
if si < num_splits:
|
||||
cute.copy(gmem_thr_copy_LSE, mLSE_partial_cur_copy[None, si], tLSEsLSE[None, s, m])
|
||||
cute.copy(
|
||||
gmem_thr_copy_LSE,
|
||||
mLSE_partial_cur_copy[None, si],
|
||||
tLSEsLSE[None, s, m],
|
||||
)
|
||||
else:
|
||||
tLSEsLSE[None, s, m].fill(-Float32.inf)
|
||||
# Don't need to zero out the rest of the LSEs, as we will not write the output to gmem
|
||||
@@ -424,7 +469,9 @@ class FlashAttentionForwardCombine:
|
||||
else:
|
||||
tOhidx[m] = idx // seqlen
|
||||
tOmidx[m] = idx - tOhidx[m] * seqlen
|
||||
tOrOptr[m] = utils.elem_pointer_i64(mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])).toint()
|
||||
tOrOptr[m] = utils.elem_pointer_i64(
|
||||
mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])
|
||||
).toint()
|
||||
if idx >= max_idx:
|
||||
tOhidx[m] = -1
|
||||
|
||||
@@ -483,7 +530,9 @@ class FlashAttentionForwardCombine:
|
||||
# Find max LSE value across splits
|
||||
threads_per_col = const_expr(self.smem_threads_per_col_lse)
|
||||
lse_max = utils.warp_reduce(
|
||||
ts2rrLSE[None, None, m].load().reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
|
||||
ts2rrLSE[None, None, m]
|
||||
.load()
|
||||
.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
|
||||
op=cute.arch.fmax,
|
||||
width=threads_per_col,
|
||||
)
|
||||
@@ -496,7 +545,9 @@ class FlashAttentionForwardCombine:
|
||||
# if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx)
|
||||
max_valid_split[m] = utils.warp_reduce(max_valid_idx, max, width=threads_per_col)
|
||||
# Compute exp scales and sum
|
||||
lse_max_cur = 0.0 if lse_max == -Float32.inf else lse_max # In case all local LSEs are -inf
|
||||
lse_max_cur = (
|
||||
0.0 if lse_max == -Float32.inf else lse_max
|
||||
) # In case all local LSEs are -inf
|
||||
LOG2_E = math.log2(math.e)
|
||||
lse_sum_cur = 0.0
|
||||
for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
|
||||
@@ -506,7 +557,9 @@ class FlashAttentionForwardCombine:
|
||||
lse_sum_cur = utils.warp_reduce(lse_sum_cur, operator.add, width=threads_per_col)
|
||||
lse_sum[m] = utils.logf(lse_sum_cur) + lse_max
|
||||
# Normalize scales
|
||||
inv_sum = 0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
|
||||
inv_sum = (
|
||||
0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
|
||||
)
|
||||
ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum)
|
||||
# Store the scales exp(lse - lse_logsum) back to smem
|
||||
cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE)
|
||||
@@ -584,7 +637,10 @@ class FlashAttentionForwardCombine:
|
||||
# Accumulate scaled partial results
|
||||
for m in cutlass.range(num_rows, unroll_full=True):
|
||||
if tOhidx[m] >= 0 and scale[m] > 0.0:
|
||||
tOrO[None, m, None].store(tOrO[None, m, None].load() + scale[m] * tOrO_partial[None, m, None].load().to(Float32))
|
||||
tOrO[None, m, None].store(
|
||||
tOrO[None, m, None].load()
|
||||
+ scale[m] * tOrO_partial[None, m, None].load().to(Float32)
|
||||
)
|
||||
|
||||
# ===============================
|
||||
# Step 7: Write final O to gmem
|
||||
@@ -605,7 +661,9 @@ class FlashAttentionForwardCombine:
|
||||
# Write final results
|
||||
for m in cutlass.range(num_rows, unroll_full=True):
|
||||
if tOhidx[m] >= 0:
|
||||
mO_cur_copy = cute.tiled_divide(mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,))
|
||||
mO_cur_copy = cute.tiled_divide(
|
||||
mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,)
|
||||
)
|
||||
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
|
||||
k_idx = tOcO[0, 0, k][1] // elems_per_store
|
||||
if const_expr(self.is_even_k) or tOpO[k]:
|
||||
@@ -631,7 +689,9 @@ class FlashAttentionForwardCombine:
|
||||
o_gmem_ptr = cute.make_ptr(
|
||||
tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16
|
||||
)
|
||||
mO_partial_cur = cute.make_tensor(o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0)))
|
||||
mO_partial_cur = cute.make_tensor(
|
||||
o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))
|
||||
)
|
||||
mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))
|
||||
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
|
||||
k_idx = tOcO[0, 0, k][1] // elems_per_load
|
||||
@@ -640,5 +700,5 @@ class FlashAttentionForwardCombine:
|
||||
gmem_tiled_copy_O_partial,
|
||||
# mO_partial_cur_copy[None, k_idx, split],
|
||||
utils.coord_offset_i64(mO_partial_cur_copy, split, dim=2)[None, k_idx],
|
||||
tOsO_partial_cur[None, m, k]
|
||||
tOsO_partial_cur[None, m, k],
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass import Int32, Float32, Boolean, const_expr
|
||||
from cutlass.cute.nvgpu import warpgroup
|
||||
from cutlass._mlir.dialects import llvm
|
||||
from cutlass.cutlass_dsl import Numeric, dsl_user_op
|
||||
from cutlass.utils import LayoutEnum
|
||||
import cutlass.utils.hopper_helpers as sm90_utils_og
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
# Copyright (c) 2025, Tri Dao.
|
||||
|
||||
import math
|
||||
import operator
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
|
||||
@@ -99,7 +99,9 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random",
|
||||
if i % 5 == 0:
|
||||
lengths[i] = 0
|
||||
lengths[-1] = 0
|
||||
padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
|
||||
padding_mask = (
|
||||
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
|
||||
)
|
||||
return padding_mask
|
||||
|
||||
|
||||
@@ -129,7 +131,9 @@ def generate_qkv(
|
||||
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
|
||||
q, query_padding_mask, query_unused_mask
|
||||
)
|
||||
output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q)
|
||||
output_pad_fn = lambda output_unpad: pad_input(
|
||||
output_unpad, indices_q, batch_size, seqlen_q
|
||||
)
|
||||
qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None
|
||||
else:
|
||||
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||
@@ -138,7 +142,9 @@ def generate_qkv(
|
||||
)
|
||||
seqused_q = None
|
||||
max_seqlen_q = seqlen_q
|
||||
output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size)
|
||||
output_pad_fn = lambda output_unpad: rearrange(
|
||||
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||
)
|
||||
qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None
|
||||
|
||||
if key_padding_mask is not None:
|
||||
@@ -256,7 +262,9 @@ def construct_local_mask(
|
||||
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
|
||||
return torch.logical_or(
|
||||
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
|
||||
torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length),
|
||||
torch.logical_and(
|
||||
col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -368,7 +376,9 @@ def attention_ref(
|
||||
key_leftpad=key_leftpad,
|
||||
device=q.device,
|
||||
)
|
||||
local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask
|
||||
local_mask = (
|
||||
torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask
|
||||
)
|
||||
if local_mask is not None:
|
||||
scores.masked_fill_(local_mask, float("-inf"))
|
||||
if attn_bias is not None:
|
||||
|
||||
+68
-23
@@ -10,7 +10,7 @@ from functools import partial
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
|
||||
from cutlass import Float32, Int32, const_expr
|
||||
from cutlass import Float32, const_expr
|
||||
from cutlass.cutlass_dsl import T, dsl_user_op
|
||||
from cutlass._mlir.dialects import nvvm, llvm
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
@@ -24,9 +24,10 @@ sub_packed_f32x2 = partial(
|
||||
cute.arch.calc_packed_f32x2_op,
|
||||
src_c=None,
|
||||
calc_func=nvvm.sub_packed_f32x2,
|
||||
rnd=nvvm.RoundingModeKind.RN
|
||||
rnd=nvvm.RoundingModeKind.RN,
|
||||
)
|
||||
|
||||
|
||||
def hash_callable(func: Callable) -> str:
|
||||
"""Hash a callable based on the source code or bytecode and closure values."""
|
||||
if hasattr(func, "__wrapped__"):
|
||||
@@ -62,6 +63,7 @@ def create_softcap_scoremod(softcap_val):
|
||||
|
||||
return scoremod_premask_fn
|
||||
|
||||
|
||||
def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
|
||||
return (
|
||||
from_dlpack(x, assumed_align=alignment)
|
||||
@@ -71,7 +73,10 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te
|
||||
)
|
||||
)
|
||||
|
||||
def convert_from_dlpack_leading_static(x, leading_dim, alignment=16, static_modes=None, stride_order=None) -> cute.Tensor:
|
||||
|
||||
def convert_from_dlpack_leading_static(
|
||||
x, leading_dim, alignment=16, static_modes=None, stride_order=None
|
||||
) -> cute.Tensor:
|
||||
if stride_order is None:
|
||||
stride_order = x.dim_order()
|
||||
x_ = from_dlpack(x, assumed_align=alignment)
|
||||
@@ -80,6 +85,7 @@ def convert_from_dlpack_leading_static(x, leading_dim, alignment=16, static_mode
|
||||
x_ = x_.mark_compact_shape_dynamic(mode=i, stride_order=stride_order)
|
||||
return x_
|
||||
|
||||
|
||||
def make_tiled_copy_A(
|
||||
copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
|
||||
) -> cute.TiledCopy:
|
||||
@@ -258,7 +264,7 @@ def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle:
|
||||
# the string here.
|
||||
swizzle_str = str(ptr.type.swizzle_type)
|
||||
# Extract the inner part "S<b,m,s>"
|
||||
match = re.search(r'S<(\d+),(\d+),(\d+)>', swizzle_str)
|
||||
match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str)
|
||||
if match:
|
||||
b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
return cute.make_swizzle(b, m, s)
|
||||
@@ -298,6 +304,7 @@ def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def logf(a: float | Float32, *, loc=None, ip=None) -> Float32:
|
||||
return log2f(a, loc=loc, ip=ip) * math.log(2.0)
|
||||
@@ -350,7 +357,11 @@ def fmax_reduce(
|
||||
# We instead force the 3-input max.
|
||||
res = cute.make_fragment(x.shape, Float32)
|
||||
res.store(x)
|
||||
local_max_0 = fmax(init_val, res[0], res[1]) if const_expr(init_val is not None) else fmax(res[0], res[1])
|
||||
local_max_0 = (
|
||||
fmax(init_val, res[0], res[1])
|
||||
if const_expr(init_val is not None)
|
||||
else fmax(res[0], res[1])
|
||||
)
|
||||
local_max = [
|
||||
local_max_0,
|
||||
fmax(res[2], res[3]),
|
||||
@@ -438,7 +449,9 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut
|
||||
def elem_pointer_i64(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
|
||||
flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
|
||||
flat_stride = cute.flatten_to_tuple(x.stride)
|
||||
assert len(flat_coord_i64) == len(flat_stride), "Coordinate and stride must have the same length"
|
||||
assert len(flat_coord_i64) == len(flat_stride), (
|
||||
"Coordinate and stride must have the same length"
|
||||
)
|
||||
offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
|
||||
# HACK: we assume that applying the offset does not change the pointer alignment
|
||||
byte_offset = offset * x.element_type.width // 8
|
||||
@@ -517,7 +530,10 @@ def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) ->
|
||||
return cutlass.Uint32(
|
||||
llvm.inline_asm(
|
||||
T.i32(),
|
||||
[cutlass.Uint32(val).ir_value(loc=loc, ip=ip), cutlass.Uint32(shift).ir_value(loc=loc, ip=ip)],
|
||||
[
|
||||
cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
|
||||
cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
|
||||
],
|
||||
"shr.s32 $0, $1, $2;",
|
||||
"=r,r,r",
|
||||
has_side_effects=False,
|
||||
@@ -543,7 +559,9 @@ def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) ->
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def cvt_f16x2_f32(a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None) -> cutlass.Int32:
|
||||
def cvt_f16x2_f32(
|
||||
a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None
|
||||
) -> cutlass.Int32:
|
||||
assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
|
||||
return cutlass.Int32(
|
||||
llvm.inline_asm(
|
||||
@@ -561,9 +579,11 @@ def cvt_f16x2_f32(a: float | Float32, b: float | Float32, to_dtype: Type, *, loc
|
||||
@overload
|
||||
def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...
|
||||
|
||||
|
||||
@overload
|
||||
def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...
|
||||
|
||||
|
||||
@cute.jit
|
||||
def cvt_f16(src: cute.Tensor, dst_or_dtype):
|
||||
"""Convert Float32 tensor to Float16/BFloat16.
|
||||
@@ -586,7 +606,9 @@ def cvt_f16(src: cute.Tensor, dst_or_dtype):
|
||||
dst = dst_or_dtype
|
||||
assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
|
||||
assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
|
||||
assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], "dst must be BFloat16 or Float16"
|
||||
assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], (
|
||||
"dst must be BFloat16 or Float16"
|
||||
)
|
||||
assert src.element_type is Float32, "src must be Float32"
|
||||
dst_i32 = cute.recast_tensor(dst, cutlass.Int32)
|
||||
assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)
|
||||
@@ -606,7 +628,9 @@ def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=N
|
||||
|
||||
@dsl_user_op
|
||||
@cute.jit
|
||||
def evaluate_polynomial_2(x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Tuple[Float32, Float32]:
|
||||
def evaluate_polynomial_2(
|
||||
x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None
|
||||
) -> Tuple[Float32, Float32]:
|
||||
deg = len(poly) - 1
|
||||
out = (poly[deg], poly[deg])
|
||||
for i in cutlass.range_constexpr(deg - 1, -1, -1):
|
||||
@@ -621,7 +645,7 @@ def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None)
|
||||
llvm.inline_asm(
|
||||
T.f32(),
|
||||
[Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],
|
||||
f"add.rm.ftz.f32 $0, $1, $2;",
|
||||
"add.rm.ftz.f32 $0, $1, $2;",
|
||||
"=f,f,f",
|
||||
has_side_effects=False,
|
||||
is_align_stack=False,
|
||||
@@ -635,7 +659,10 @@ def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=
|
||||
return cutlass.Float32(
|
||||
llvm.inline_asm(
|
||||
T.f32(),
|
||||
[Float32(x_rounded).ir_value(loc=loc, ip=ip), Float32(frac_ex2).ir_value(loc=loc, ip=ip)],
|
||||
[
|
||||
Float32(x_rounded).ir_value(loc=loc, ip=ip),
|
||||
Float32(frac_ex2).ir_value(loc=loc, ip=ip),
|
||||
],
|
||||
"{\n\t"
|
||||
".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t"
|
||||
"mov.b32 x_rounded_i, $1;\n\t"
|
||||
@@ -657,7 +684,12 @@ def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=
|
||||
@dsl_user_op
|
||||
def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32:
|
||||
# We assume x <= 127.0
|
||||
poly_ex2_deg3 = (1.0, 0.695146143436431884765625, 0.227564394474029541015625, 0.077119089663028717041015625)
|
||||
poly_ex2_deg3 = (
|
||||
1.0,
|
||||
0.695146143436431884765625,
|
||||
0.227564394474029541015625,
|
||||
0.077119089663028717041015625,
|
||||
)
|
||||
fp32_round_int = float(2**23 + 2**22)
|
||||
x_clamped = cute.arch.fmax(x, -127.0)
|
||||
# We want to round down here, so that the fractional part is in [0, 1)
|
||||
@@ -674,11 +706,18 @@ def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32:
|
||||
@dsl_user_op
|
||||
def ex2_emulation_2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
|
||||
# We assume x <= 127.0 and y <= 127.0
|
||||
poly_ex2_deg3 = (1.0, 0.695146143436431884765625, 0.227564394474029541015625, 0.077119089663028717041015625)
|
||||
poly_ex2_deg3 = (
|
||||
1.0,
|
||||
0.695146143436431884765625,
|
||||
0.227564394474029541015625,
|
||||
0.077119089663028717041015625,
|
||||
)
|
||||
fp32_round_int = float(2**23 + 2**22)
|
||||
xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))
|
||||
# We want to round down here, so that the fractional part is in [0, 1)
|
||||
xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM)
|
||||
xy_rounded = cute.arch.add_packed_f32x2(
|
||||
xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM
|
||||
)
|
||||
# The integer floor of x & y are now in the last 8 bits of xy_rounded
|
||||
# We want the next 2 ops to round to nearest even. The rounding mode is important.
|
||||
xy_rounded_back = sub_packed_f32x2(xy_rounded, (fp32_round_int, fp32_round_int))
|
||||
@@ -734,8 +773,12 @@ def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Flo
|
||||
out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))
|
||||
out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))
|
||||
return out0, out1
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def domain_offset_aligned(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
|
||||
def domain_offset_aligned(
|
||||
coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
|
||||
) -> cute.Tensor:
|
||||
assert isinstance(tensor.iterator, cute.Pointer)
|
||||
# We assume that applying the offset does not change the pointer alignment
|
||||
new_ptr = cute.make_ptr(
|
||||
@@ -751,9 +794,9 @@ def domain_offset_aligned(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, i
|
||||
def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
|
||||
flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
|
||||
flat_stride = cute.flatten_to_tuple(tensor.stride)
|
||||
assert len(flat_coord_i64) == len(
|
||||
flat_stride
|
||||
), "Coordinate and stride must have the same length"
|
||||
assert len(flat_coord_i64) == len(flat_stride), (
|
||||
"Coordinate and stride must have the same length"
|
||||
)
|
||||
offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
|
||||
assert isinstance(tensor.iterator, cute.Pointer)
|
||||
# HACK: we assume that applying the offset does not change the pointer alignment
|
||||
@@ -779,18 +822,20 @@ def coord_offset_i64(
|
||||
tensor.memspace,
|
||||
assumed_align=tensor.iterator.max_alignment,
|
||||
)
|
||||
new_layout = cute.slice_(tensor.layout, (*[None] * dim, 0, *[None] * (cute.rank(tensor) - dim - 1)))
|
||||
new_layout = cute.slice_(
|
||||
tensor.layout, (*[None] * dim, 0, *[None] * (cute.rank(tensor) - dim - 1))
|
||||
)
|
||||
return cute.make_tensor(new_ptr, new_layout)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
|
||||
""" Convert a scalar to a cute TensorSSA of shape (1,) and given dtype """
|
||||
"""Convert a scalar to a cute TensorSSA of shape (1,) and given dtype"""
|
||||
vec = cute.make_fragment(1, dtype)
|
||||
vec[0] = a
|
||||
return vec.load()
|
||||
|
||||
|
||||
def ssa_to_scalar(val):
|
||||
""" Could inline but nice for reflecting the above api """
|
||||
return val[0]
|
||||
"""Could inline but nice for reflecting the above api"""
|
||||
return val[0]
|
||||
|
||||
Reference in New Issue
Block a user