ruff all the smaller files (#2040)

This commit is contained in:
Driss Guessous
2025-12-02 13:43:24 -08:00
committed by GitHub
parent 672381f72c
commit 91ba87d759
7 changed files with 193 additions and 90 deletions
-9
View File
@@ -7,19 +7,10 @@ repos:
files: ^flash_attn/cute/.*\.py$
exclude: &cute_exclude |
(?x)^flash_attn/cute/(
__init__|
copy_utils|
cute_dsl_utils|
fast_math|
flash_bwd|
flash_fwd|
flash_fwd_combine|
flash_fwd_sm100|
hopper_helpers|
interface|
pack_gqa|
testing|
utils
)\.py$
- id: ruff-format
files: ^flash_attn/cute/.*\.py$
+3 -3
View File
@@ -1,11 +1,11 @@
# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao.
import math
from typing import Optional, Type, Tuple, Callable
from typing import Optional, Type, Callable
import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32, Boolean, const_expr
from cutlass import Float32, Int32, const_expr
from cutlass.cute.nvgpu import cpasync
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cutlass_dsl import T, dsl_user_op
@@ -279,7 +279,7 @@ def cpasync_bulk_get_copy_fn(
dst[None, dst_idx].iterator,
size=size,
**new_kwargs,
**kwargs
**kwargs,
)
def copy_bulk_single_stage(**new_kwargs):
+107 -47
View File
@@ -55,8 +55,13 @@ class FlashAttentionForwardCombine:
@staticmethod
def can_implement(
dtype, dtype_partial, head_dim, m_block_size, k_block_size,
log_max_splits, num_threads,
dtype,
dtype_partial,
head_dim,
m_block_size,
k_block_size,
log_max_splits,
num_threads,
) -> bool:
"""Check if the kernel can be implemented with the given parameters."""
if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]:
@@ -83,8 +88,7 @@ class FlashAttentionForwardCombine:
assert self.k_block_size % async_copy_elems == 0
k_block_gmem = (
128 if self.k_block_size % 128 == 0 else
(64 if self.k_block_size % 64 == 0 else 32)
128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32)
)
gmem_threads_per_row = k_block_gmem // async_copy_elems
assert self.num_threads % gmem_threads_per_row == 0
@@ -111,16 +115,25 @@ class FlashAttentionForwardCombine:
num_bits_per_copy=async_copy_elems * self.dtype.width,
)
self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(
atom_universal_copy, tOpartial_layout, vOpartial_layout # 4 vals per store
atom_universal_copy,
tOpartial_layout,
vOpartial_layout, # 4 vals per store
)
# LSE copy setup with async copy (alignment = 1)
lse_copy_bits = Float32.width # 1 element per copy, width is in bits
m_block_smem = (
128 if self.m_block_size % 128 == 0 else
(64 if self.m_block_size % 64 == 0 else
(32 if self.m_block_size % 32 == 0 else
(16 if self.m_block_size % 16 == 0 else 8)))
128
if self.m_block_size % 128 == 0
else (
64
if self.m_block_size % 64 == 0
else (
32
if self.m_block_size % 32 == 0
else (16 if self.m_block_size % 16 == 0 else 8)
)
)
)
gmem_threads_per_row_lse = m_block_smem
assert self.num_threads % gmem_threads_per_row_lse == 0
@@ -167,9 +180,7 @@ class FlashAttentionForwardCombine:
else:
smem_lse_swizzle = cute.make_swizzle(3, 2, 3)
smem_layout_atom_lse = cute.make_composed_layout(
smem_lse_swizzle,
0,
cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
)
self.smem_layout_lse = cute.tile_to_shape(
smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1)
@@ -177,11 +188,9 @@ class FlashAttentionForwardCombine:
# O partial shared memory layout (simple layout for pipeline stages)
self.smem_layout_o = cute.make_ordered_layout(
(self.m_block_size, self.k_block_size, self.stages),
order=(1, 0, 2)
(self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2)
)
@cute.jit
def __call__(
self,
@@ -200,38 +209,63 @@ class FlashAttentionForwardCombine:
raise TypeError("O partial tensor must match dtype_partial")
if const_expr(not (mO.element_type == self.dtype)):
raise TypeError("O tensor must match dtype")
if const_expr(not mLSE_partial.element_type in [Float32]):
if const_expr(mLSE_partial.element_type not in [Float32]):
raise TypeError("LSE partial tensor must be Float32")
if const_expr(mLSE is not None and not mLSE.element_type in [Float32]):
if const_expr(mLSE is not None and mLSE.element_type not in [Float32]):
raise TypeError("LSE tensor must be Float32")
# Shape validation - input tensors are in user format, need to be converted to kernel format
if const_expr(len(mO_partial.shape) not in [4, 5]):
raise ValueError("O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)")
raise ValueError(
"O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)"
)
if const_expr(len(mLSE_partial.shape) not in [3, 4]):
raise ValueError("LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)")
raise ValueError(
"LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)"
)
if const_expr(len(mO.shape) not in [3, 4]):
raise ValueError("O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)")
raise ValueError(
"O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)"
)
if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]):
raise ValueError("LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)")
raise ValueError(
"LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)"
)
# Assume all strides are divisible by 128 bits except the last stride
new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1])
mO_partial, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mO_partial, mO)]
new_stride = lambda t: (
*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
t.stride[-1],
)
mO_partial, mO = [
cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
for t in (mO_partial, mO)
]
# (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)
# or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h)
O_partial_layout_transpose = [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
O_partial_layout_transpose = (
[2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
)
# (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h)
mO_partial = cute.make_tensor(mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose))
mO_partial = cute.make_tensor(
mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)
)
O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1]
mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))
# (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b)
# or (num_splits, total_q, h) -> (total_q, num_splits, h)
LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2]
mLSE_partial = cute.make_tensor(mLSE_partial.iterator, cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose))
mLSE_partial = cute.make_tensor(
mLSE_partial.iterator,
cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose),
)
# (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h)
LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1]
mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None
mLSE = (
cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
if mLSE is not None
else None
)
# Determine if we have variable length sequences
varlen = const_expr(cu_seqlens is not None or seqused is not None)
@@ -243,9 +277,7 @@ class FlashAttentionForwardCombine:
sLSE: cute.struct.Align[
cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
]
sMaxValidSplit: cute.struct.Align[
cute.struct.MemRange[Int32, self.m_block_size], 128
]
sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128]
sO: cute.struct.Align[
cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
]
@@ -255,7 +287,11 @@ class FlashAttentionForwardCombine:
# Grid dimensions: (ceil_div(seqlen, m_block), ceil_div(head_dim, k_block), num_head * batch)
seqlen = mO_partial.shape[0]
num_head = mO_partial.shape[3]
batch_size = mO_partial.shape[4] if const_expr(cu_seqlens is None) else Int32(cu_seqlens.shape[0] - 1)
batch_size = (
mO_partial.shape[4]
if const_expr(cu_seqlens is None)
else Int32(cu_seqlens.shape[0] - 1)
)
# Create FastDivmodDivisor objects for efficient division
seqlen_divmod = FastDivmodDivisor(seqlen)
@@ -330,14 +366,18 @@ class FlashAttentionForwardCombine:
# Handle semaphore reset
if const_expr(semaphore_to_reset is not None):
if (tidx == 0 and m_block == cute.arch.grid_dim()[0] - 1 and
k_block == cute.arch.grid_dim()[1] - 1 and
batch_idx == cute.arch.grid_dim()[2] - 1):
if (
tidx == 0
and m_block == cute.arch.grid_dim()[0] - 1
and k_block == cute.arch.grid_dim()[1] - 1
and batch_idx == cute.arch.grid_dim()[2] - 1
):
semaphore_to_reset[0] = 0
# Get number of splits
num_splits = (
num_splits_dynamic_ptr[batch_idx] if const_expr(num_splits_dynamic_ptr is not None)
num_splits_dynamic_ptr[batch_idx]
if const_expr(num_splits_dynamic_ptr is not None)
else mLSE_partial.shape[1]
)
# Handle variable length sequences using SeqlenInfo
@@ -345,7 +385,7 @@ class FlashAttentionForwardCombine:
batch_idx=batch_idx,
seqlen_static=mO_partial.shape[0],
cu_seqlens=cu_seqlens,
seqused=seqused
seqused=seqused,
)
seqlen, offset = seqlen_info.seqlen, seqlen_info.offset
@@ -354,8 +394,9 @@ class FlashAttentionForwardCombine:
max_idx = seqlen * num_head
# Early exit for single split if dynamic
if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (const_expr(not varlen) or m_block * self.m_block_size < max_idx):
if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (
const_expr(not varlen) or m_block * self.m_block_size < max_idx
):
# ===============================
# Step 1: Load LSE_partial from gmem to shared memory
# ===============================
@@ -390,7 +431,11 @@ class FlashAttentionForwardCombine:
for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
si = tLSEcLSE[0, s, 0][0] # Get split coordinate
if si < num_splits:
cute.copy(gmem_thr_copy_LSE, mLSE_partial_cur_copy[None, si], tLSEsLSE[None, s, m])
cute.copy(
gmem_thr_copy_LSE,
mLSE_partial_cur_copy[None, si],
tLSEsLSE[None, s, m],
)
else:
tLSEsLSE[None, s, m].fill(-Float32.inf)
# Don't need to zero out the rest of the LSEs, as we will not write the output to gmem
@@ -424,7 +469,9 @@ class FlashAttentionForwardCombine:
else:
tOhidx[m] = idx // seqlen
tOmidx[m] = idx - tOhidx[m] * seqlen
tOrOptr[m] = utils.elem_pointer_i64(mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])).toint()
tOrOptr[m] = utils.elem_pointer_i64(
mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])
).toint()
if idx >= max_idx:
tOhidx[m] = -1
@@ -483,7 +530,9 @@ class FlashAttentionForwardCombine:
# Find max LSE value across splits
threads_per_col = const_expr(self.smem_threads_per_col_lse)
lse_max = utils.warp_reduce(
ts2rrLSE[None, None, m].load().reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
ts2rrLSE[None, None, m]
.load()
.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
op=cute.arch.fmax,
width=threads_per_col,
)
@@ -496,7 +545,9 @@ class FlashAttentionForwardCombine:
# if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx)
max_valid_split[m] = utils.warp_reduce(max_valid_idx, max, width=threads_per_col)
# Compute exp scales and sum
lse_max_cur = 0.0 if lse_max == -Float32.inf else lse_max # In case all local LSEs are -inf
lse_max_cur = (
0.0 if lse_max == -Float32.inf else lse_max
) # In case all local LSEs are -inf
LOG2_E = math.log2(math.e)
lse_sum_cur = 0.0
for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
@@ -506,7 +557,9 @@ class FlashAttentionForwardCombine:
lse_sum_cur = utils.warp_reduce(lse_sum_cur, operator.add, width=threads_per_col)
lse_sum[m] = utils.logf(lse_sum_cur) + lse_max
# Normalize scales
inv_sum = 0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
inv_sum = (
0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
)
ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum)
# Store the scales exp(lse - lse_logsum) back to smem
cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE)
@@ -584,7 +637,10 @@ class FlashAttentionForwardCombine:
# Accumulate scaled partial results
for m in cutlass.range(num_rows, unroll_full=True):
if tOhidx[m] >= 0 and scale[m] > 0.0:
tOrO[None, m, None].store(tOrO[None, m, None].load() + scale[m] * tOrO_partial[None, m, None].load().to(Float32))
tOrO[None, m, None].store(
tOrO[None, m, None].load()
+ scale[m] * tOrO_partial[None, m, None].load().to(Float32)
)
# ===============================
# Step 7: Write final O to gmem
@@ -605,7 +661,9 @@ class FlashAttentionForwardCombine:
# Write final results
for m in cutlass.range(num_rows, unroll_full=True):
if tOhidx[m] >= 0:
mO_cur_copy = cute.tiled_divide(mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,))
mO_cur_copy = cute.tiled_divide(
mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,)
)
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
k_idx = tOcO[0, 0, k][1] // elems_per_store
if const_expr(self.is_even_k) or tOpO[k]:
@@ -631,7 +689,9 @@ class FlashAttentionForwardCombine:
o_gmem_ptr = cute.make_ptr(
tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16
)
mO_partial_cur = cute.make_tensor(o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0)))
mO_partial_cur = cute.make_tensor(
o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))
)
mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))
for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
k_idx = tOcO[0, 0, k][1] // elems_per_load
@@ -640,5 +700,5 @@ class FlashAttentionForwardCombine:
gmem_tiled_copy_O_partial,
# mO_partial_cur_copy[None, k_idx, split],
utils.coord_offset_i64(mO_partial_cur_copy, split, dim=2)[None, k_idx],
tOsO_partial_cur[None, m, k]
tOsO_partial_cur[None, m, k],
)
-1
View File
@@ -4,7 +4,6 @@ import cutlass
import cutlass.cute as cute
from cutlass import Int32, Float32, Boolean, const_expr
from cutlass.cute.nvgpu import warpgroup
from cutlass._mlir.dialects import llvm
from cutlass.cutlass_dsl import Numeric, dsl_user_op
from cutlass.utils import LayoutEnum
import cutlass.utils.hopper_helpers as sm90_utils_og
-2
View File
@@ -1,7 +1,5 @@
# Copyright (c) 2025, Tri Dao.
import math
import operator
import cutlass
import cutlass.cute as cute
+15 -5
View File
@@ -99,7 +99,9 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random",
if i % 5 == 0:
lengths[i] = 0
lengths[-1] = 0
padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
padding_mask = (
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
)
return padding_mask
@@ -129,7 +131,9 @@ def generate_qkv(
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
q, query_padding_mask, query_unused_mask
)
output_pad_fn = lambda output_unpad: pad_input(output_unpad, indices_q, batch_size, seqlen_q)
output_pad_fn = lambda output_unpad: pad_input(
output_unpad, indices_q, batch_size, seqlen_q
)
qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
@@ -138,7 +142,9 @@ def generate_qkv(
)
seqused_q = None
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size)
output_pad_fn = lambda output_unpad: rearrange(
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None
if key_padding_mask is not None:
@@ -256,7 +262,9 @@ def construct_local_mask(
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
return torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length),
torch.logical_and(
col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length
),
)
@@ -368,7 +376,9 @@ def attention_ref(
key_leftpad=key_leftpad,
device=q.device,
)
local_mask = torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask
local_mask = (
torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask
)
if local_mask is not None:
scores.masked_fill_(local_mask, float("-inf"))
if attn_bias is not None:
+68 -23
View File
@@ -10,7 +10,7 @@ from functools import partial
import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32, const_expr
from cutlass import Float32, const_expr
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir.dialects import nvvm, llvm
from cutlass.cute.runtime import from_dlpack
@@ -24,9 +24,10 @@ sub_packed_f32x2 = partial(
cute.arch.calc_packed_f32x2_op,
src_c=None,
calc_func=nvvm.sub_packed_f32x2,
rnd=nvvm.RoundingModeKind.RN
rnd=nvvm.RoundingModeKind.RN,
)
def hash_callable(func: Callable) -> str:
"""Hash a callable based on the source code or bytecode and closure values."""
if hasattr(func, "__wrapped__"):
@@ -62,6 +63,7 @@ def create_softcap_scoremod(softcap_val):
return scoremod_premask_fn
def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
return (
from_dlpack(x, assumed_align=alignment)
@@ -71,7 +73,10 @@ def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Te
)
)
def convert_from_dlpack_leading_static(x, leading_dim, alignment=16, static_modes=None, stride_order=None) -> cute.Tensor:
def convert_from_dlpack_leading_static(
x, leading_dim, alignment=16, static_modes=None, stride_order=None
) -> cute.Tensor:
if stride_order is None:
stride_order = x.dim_order()
x_ = from_dlpack(x, assumed_align=alignment)
@@ -80,6 +85,7 @@ def convert_from_dlpack_leading_static(x, leading_dim, alignment=16, static_mode
x_ = x_.mark_compact_shape_dynamic(mode=i, stride_order=stride_order)
return x_
def make_tiled_copy_A(
copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
@@ -258,7 +264,7 @@ def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle:
# the string here.
swizzle_str = str(ptr.type.swizzle_type)
# Extract the inner part "S<b,m,s>"
match = re.search(r'S<(\d+),(\d+),(\d+)>', swizzle_str)
match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str)
if match:
b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3))
return cute.make_swizzle(b, m, s)
@@ -298,6 +304,7 @@ def log2f(a: float | Float32, *, loc=None, ip=None) -> Float32:
)
)
@dsl_user_op
def logf(a: float | Float32, *, loc=None, ip=None) -> Float32:
return log2f(a, loc=loc, ip=ip) * math.log(2.0)
@@ -350,7 +357,11 @@ def fmax_reduce(
# We instead force the 3-input max.
res = cute.make_fragment(x.shape, Float32)
res.store(x)
local_max_0 = fmax(init_val, res[0], res[1]) if const_expr(init_val is not None) else fmax(res[0], res[1])
local_max_0 = (
fmax(init_val, res[0], res[1])
if const_expr(init_val is not None)
else fmax(res[0], res[1])
)
local_max = [
local_max_0,
fmax(res[2], res[3]),
@@ -438,7 +449,9 @@ def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cut
def elem_pointer_i64(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
flat_stride = cute.flatten_to_tuple(x.stride)
assert len(flat_coord_i64) == len(flat_stride), "Coordinate and stride must have the same length"
assert len(flat_coord_i64) == len(flat_stride), (
"Coordinate and stride must have the same length"
)
offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
# HACK: we assume that applying the offset does not change the pointer alignment
byte_offset = offset * x.element_type.width // 8
@@ -517,7 +530,10 @@ def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) ->
return cutlass.Uint32(
llvm.inline_asm(
T.i32(),
[cutlass.Uint32(val).ir_value(loc=loc, ip=ip), cutlass.Uint32(shift).ir_value(loc=loc, ip=ip)],
[
cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
],
"shr.s32 $0, $1, $2;",
"=r,r,r",
has_side_effects=False,
@@ -543,7 +559,9 @@ def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) ->
@dsl_user_op
def cvt_f16x2_f32(a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None) -> cutlass.Int32:
def cvt_f16x2_f32(
a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None
) -> cutlass.Int32:
assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
return cutlass.Int32(
llvm.inline_asm(
@@ -561,9 +579,11 @@ def cvt_f16x2_f32(a: float | Float32, b: float | Float32, to_dtype: Type, *, loc
@overload
def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...
@overload
def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...
@cute.jit
def cvt_f16(src: cute.Tensor, dst_or_dtype):
"""Convert Float32 tensor to Float16/BFloat16.
@@ -586,7 +606,9 @@ def cvt_f16(src: cute.Tensor, dst_or_dtype):
dst = dst_or_dtype
assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], "dst must be BFloat16 or Float16"
assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], (
"dst must be BFloat16 or Float16"
)
assert src.element_type is Float32, "src must be Float32"
dst_i32 = cute.recast_tensor(dst, cutlass.Int32)
assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)
@@ -606,7 +628,9 @@ def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=N
@dsl_user_op
@cute.jit
def evaluate_polynomial_2(x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Tuple[Float32, Float32]:
def evaluate_polynomial_2(
x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:
deg = len(poly) - 1
out = (poly[deg], poly[deg])
for i in cutlass.range_constexpr(deg - 1, -1, -1):
@@ -621,7 +645,7 @@ def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None)
llvm.inline_asm(
T.f32(),
[Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],
f"add.rm.ftz.f32 $0, $1, $2;",
"add.rm.ftz.f32 $0, $1, $2;",
"=f,f,f",
has_side_effects=False,
is_align_stack=False,
@@ -635,7 +659,10 @@ def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=
return cutlass.Float32(
llvm.inline_asm(
T.f32(),
[Float32(x_rounded).ir_value(loc=loc, ip=ip), Float32(frac_ex2).ir_value(loc=loc, ip=ip)],
[
Float32(x_rounded).ir_value(loc=loc, ip=ip),
Float32(frac_ex2).ir_value(loc=loc, ip=ip),
],
"{\n\t"
".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t"
"mov.b32 x_rounded_i, $1;\n\t"
@@ -657,7 +684,12 @@ def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=
@dsl_user_op
def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32:
# We assume x <= 127.0
poly_ex2_deg3 = (1.0, 0.695146143436431884765625, 0.227564394474029541015625, 0.077119089663028717041015625)
poly_ex2_deg3 = (
1.0,
0.695146143436431884765625,
0.227564394474029541015625,
0.077119089663028717041015625,
)
fp32_round_int = float(2**23 + 2**22)
x_clamped = cute.arch.fmax(x, -127.0)
# We want to round down here, so that the fractional part is in [0, 1)
@@ -674,11 +706,18 @@ def ex2_emulation(x: Float32, *, loc=None, ip=None) -> Float32:
@dsl_user_op
def ex2_emulation_2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
# We assume x <= 127.0 and y <= 127.0
poly_ex2_deg3 = (1.0, 0.695146143436431884765625, 0.227564394474029541015625, 0.077119089663028717041015625)
poly_ex2_deg3 = (
1.0,
0.695146143436431884765625,
0.227564394474029541015625,
0.077119089663028717041015625,
)
fp32_round_int = float(2**23 + 2**22)
xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))
# We want to round down here, so that the fractional part is in [0, 1)
xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM)
xy_rounded = cute.arch.add_packed_f32x2(
xy_clamped, (fp32_round_int, fp32_round_int), rnd=nvvm.RoundingModeKind.RM
)
# The integer floor of x & y are now in the last 8 bits of xy_rounded
# We want the next 2 ops to round to nearest even. The rounding mode is important.
xy_rounded_back = sub_packed_f32x2(xy_rounded, (fp32_round_int, fp32_round_int))
@@ -734,8 +773,12 @@ def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Flo
out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))
out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))
return out0, out1
@dsl_user_op
def domain_offset_aligned(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
def domain_offset_aligned(
coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
) -> cute.Tensor:
assert isinstance(tensor.iterator, cute.Pointer)
# We assume that applying the offset does not change the pointer alignment
new_ptr = cute.make_ptr(
@@ -751,9 +794,9 @@ def domain_offset_aligned(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, i
def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
flat_stride = cute.flatten_to_tuple(tensor.stride)
assert len(flat_coord_i64) == len(
flat_stride
), "Coordinate and stride must have the same length"
assert len(flat_coord_i64) == len(flat_stride), (
"Coordinate and stride must have the same length"
)
offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
assert isinstance(tensor.iterator, cute.Pointer)
# HACK: we assume that applying the offset does not change the pointer alignment
@@ -779,18 +822,20 @@ def coord_offset_i64(
tensor.memspace,
assumed_align=tensor.iterator.max_alignment,
)
new_layout = cute.slice_(tensor.layout, (*[None] * dim, 0, *[None] * (cute.rank(tensor) - dim - 1)))
new_layout = cute.slice_(
tensor.layout, (*[None] * dim, 0, *[None] * (cute.rank(tensor) - dim - 1))
)
return cute.make_tensor(new_ptr, new_layout)
@cute.jit
def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
""" Convert a scalar to a cute TensorSSA of shape (1,) and given dtype """
"""Convert a scalar to a cute TensorSSA of shape (1,) and given dtype"""
vec = cute.make_fragment(1, dtype)
vec[0] = a
return vec.load()
def ssa_to_scalar(val):
""" Could inline but nice for reflecting the above api """
return val[0]
"""Could inline but nice for reflecting the above api"""
return val[0]