flash-attention/hopper/flash_attn_interface.py
Michael Melesse 701ebe0578 [AMD] Triton Backend for ROCm #3 (#2178)
* Fused Bwd (#137)

* Fused with Good perf and stride fixed

Fix fused bugs

isolate failing case

fix bug

bring back test cases

rm split impl in fused

use exp2 is global variable now

try oom fix

save

make fused the default

limit to reproduce failure

return default to split

fix head size bug

use exp2 back to true

* new grid

* BLK_SLICE_FACTOR = 1

* add tflops

* new commit

* test in parrallel

* strides added by jusson

* disable alibi

* fix bugs again

* default to fused

* add bwd options for varlen

* backend filter

* default to jingning and batch 4

* best fwd config

* fix TRITON_PRINT_AUTOTUNING flag bug

* tune

* Tuning fwd prefill

* add if else

* use flag

* Minor mask fix

* FLIP GRID

* use best config for default

* print when autotuning

* test bfloat16

* fix k and v stride bugs

* skip bfloat16

* test kvpacked

* disable internal tests

* pick default config based on arch

* Add alibi in the new bwd kernel (#139)

* enable alibi for jinging kernel

enable alibi for jinging kernel

match

* save bad configs

* fix alibi and causal bug

* disable autotune by default

* auto tune when benching is good

* set best config

* remove env var

* Update amd_tests.yml

* upgrad to triton==3.3.0

* increase shm

* use 64 x 64 for now

* save

* handle 1d alibi

* Add fp8 to fused kernel (#140)

* fp8 stuff

find test case

compute delta fp8

basic fp8 config passing

non causal path works

* isolate bad case

* fix fp8 bug

* didnot fix fp8 bug

* back to failing test

* fp8 tests passing

* skip

* skip ref tests

---------

Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com>

* head, seq, batch (#141)

* Fix keys (#144)

* save

* rm keys

* fix keys

* use GHA_RENDER_DEVICES

* normal docker

* Pad LSE (#148)

* add round multiple

* fix fwd

* backward fix

* use rounded lse flag

* passing ROUNDED_LSE

* default is new rounded mode

* rename to fused_atmoics and fused_no_atomics

* add test for torch_compile

* add varlen torch compile test

* add old one kernel for ref

* fix varlen mismatch bug

* fix shape issue in varlen but mismatch

* sync torch compile kernel launch

* simple varlen test

* add debug code

* rm old

* ignore old impls

* DEBUG flag works in interface only

* ref uses the righ shape for lse

* rm oldest bwd kernel

* fix typo

* fix varlen bug

* fix bug. Get info from q for now

* simple shape and stride checkout

* add more tests

* test kvcache

* kvcache safe

* match case

* fix segfault due to bad return_softmax

* run bench

* run seperate for the main functions

* just output benchmark

* default csv format and time stamp files

* non verbsoe bench

* Sliding Window Forward (#151)

* Compress SWA work

test case

set up debug inputs

add fwd ref

one mask ref

fwd first pass

save

ref doesnot work for bigger seqlens

save new version

some causal cases failing

found bad cases

working new attn

new atten works

new attn_fwd works

reorg n_extra_tokens

use seqlen_delta_qk

ref fwd works

add sliding window to bwd ref

test kvcache

decode ref work with everything except sliding window

add debug code for 12 failing sliding window cases for decode

attention_decode_forward_ref_impl mostly works except for alibi

fix alibi in attention_decode_forward_ref_impl

ref works with normal, varlen & kvcache

move stuff around

figure out masking

old attn inner

two inner functions

remove load_fn

do Lk - Lq like ref

unify IS_CAUSAL code in epilogue

clean up

add args

rm inference stuff

simplify compute_masking

simpler compute mask

stub out returning front masking variables

remove pointer pass

compute ptrs inloop

compute block min and max

window stub inside inner mask loop

trying to use attn_fwd_mask causes issues

fix compiler bug when front masking

gen specifc types

add sliding window and debug statements

use identity for v

add more taste cases

add comments

save

use k_max_token for clarity

disable debug configs

basic NON-CAUSAL SLIDING WINDOW

non causal sliding window works on the all the shapes

non sliding window working in fwd

clean up fused bwd

seperate old fwd_prefill

move configs to utils.py

* fix bwd ref bug

* skip local cases so that fa output

* no sliding window causal green

* add backward test skip for sliding window

* clean reduce in fwd_kvcache. no is_CASUAL branching

* add kvcache masking

* kvcache working

* fix some bugs in test.py

* clean up

* Fix Device Segfault (#152)

* Compress segfault work

fix backward segfault

rework offset

ignore .profile

ignore .analysis

save

* assert the kernel launch device and tensor devices are the same

* fix failing asserts

* add asserts to fwd

* Fix SDMASK bug

* Log triton, torch and fa version

* Fix fp8 import issues

* fix docs (#154)

* Sliding Window block classification logic (#155)

* add aiter code

* remove aiter stuff

* sliding window non causal masking works

* causal and sliding window block masking

* extract common

* clean up typo

* helper for swa

* ignore .amd

* fix last block bug

* Enable FA V3 (#157)

* Compress PA work

narrow pa test

ref works on most cases

inplace ref with new_kv

inplace paged attention

add pa ref

save pa

basic  paged works

save

fix swa + causal in pa. Also new_kv only on pa path

passing

build fa v3

import interface from fa v3

copy fa tests

use v3 api

clean up

rename to match old test

support different head sizes

remove fp8

basisc passing v3 cases

test_flash_attn_varlen_output v3 working

isolate bad case for kvcache

case passing

save

use decode is seqused/ cacheseql is given

use decode if not varlen

basci kvcache v3 working

kvcache enable more cases

detect kvcache case if seqused_q is non and sequese_k is not None

skip failing test

find fp8 failing case

mha fp8 works

fix fp8 MQA/GQA bug

clean up

more clean up

clean up more

don't need fp8 dead code

remove train code with fp8 stuff

fp8 working in kvcache

paged + fp8 seems to be working

new_kv allowed

* clean up

* skip hopper race test

* clean up more

* fix paged + alibi

* similar inner paged api

* unify _attn_fwd_inner

* AITER integration (#159)

* clean up v2 interface

* assert fp8 scale shapes

* rotary working

* move rotary to impl layers

* remove einops

* enable rotarry in v3

* create interface

* fix descale assert

* unify bwd

* lint from aiter

* clean fp8 api

* add api change

* assert shapes for v2

* remove ref and bench.py

* remove metadata class and clean up

* bwd_prefill

* one bwd.py

* rename

* lint

* add bwd_change (#156)

* Tune FP8 Perf (#160)

* check cu count for gfx942

* create get_cu_count

* update repo root

* update forward tune

* clean up load

* use float8_e4m3fnuz

* save

* show bwd mode

* recommend fp8

* use torch.float32 for fp8 kernel

* add both best fp16 and fp8 config

* tune fp8 backward

* descale factors should be b, hk

* fp8 bwd working on all primus configs

* tune bwd configs

* fa v3 tests passing

* better warning

* clean up bwd launcher

* v3 passing

* tune more

* improve perf

* clean up

* lint

* clean

* start tuning gfx950

* tune non causal path

* fix bug

* save

* Skip configs where BLOCK_M2 % BLOCK_N2 != 0

* skip more

* stop tuning

* fix varlen bug

* fix dropout & causal/swa segfault

* update the to machine new changes

* save

* fix more bugs

* remove random seed

* clean up

* update readme

* print tensor stats for debug

* disable sliding window tests

* add rdna configs

* fix k partial bug

* fix block_size_n bug

* fix type check bug

---------

Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com>
Co-authored-by: Tianxing Wu <tianxing.wu@amd.com>
2026-01-28 07:49:08 -08:00

1144 lines
41 KiB
Python
Executable File

# Copyright (c) 2023, Tri Dao.

from typing import Optional, Union, List, Tuple

import os
import sys
from pathlib import Path

import torch
import torch.nn as nn


USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
if USE_TRITON_ROCM:
    repo_root = Path(__file__).resolve().parent.parent
    if str(repo_root) not in sys.path:
        sys.path.insert(0, str(repo_root))
    from flash_attn.flash_attn_triton_amd import flash_attn_3 as flash_attn_3_gpu  # type: ignore
else:
    # isort: off
    # We need to import the CUDA kernels after importing torch
    import flash_attn_3._C  # Registers operators with PyTorch
    # isort: on

    flash_attn_3_gpu = torch.ops.flash_attn_3
def maybe_contiguous(x):
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def round_multiple(x, m):
    return (x + m - 1) // m * m


def round_up_headdim(head_size: int) -> int:
    from flash_attn_config import CONFIG

    if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM64"]:
        if head_size <= 64:
            return 64
    if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM96"]:
        if head_size <= 96:
            return 96
    if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM128"]:
        if head_size <= 128:
            return 128
    if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM192"]:
        if head_size <= 192:
            return 192
    if not CONFIG["build_flags"]["FLASHATTENTION_DISABLE_HDIM256"]:
        if head_size <= 256:
            return 256
    return 256
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out_: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
    v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
    cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
        maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new)
    ]
    seqused_q, seqused_k = [maybe_contiguous(x) for x in (seqused_q, seqused_k)]
    page_table, kv_batch_idx, leftpad_k = [
        maybe_contiguous(x) for x in (page_table, kv_batch_idx, leftpad_k)
    ]
    rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
    seqlens_rotary = maybe_contiguous(seqlens_rotary)
    out, softmax_lse, out_accum, softmax_lse_accum = flash_attn_3_gpu.fwd(
        q,
        k,
        v,
        k_new,
        v_new,
        qv,
        out_,
        cu_seqlens_q,
        cu_seqlens_k,
        cu_seqlens_k_new,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        page_table,
        kv_batch_idx,
        leftpad_k,
        rotary_cos,
        rotary_sin,
        seqlens_rotary,
        q_descale,
        k_descale,
        v_descale,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        attention_chunk,
        softcap,
        rotary_interleaved,
        scheduler_metadata,
        num_splits,
        pack_gqa,
        sm_margin,
    )

    if out_accum is None:
        out_accum = torch.tensor([], device=out.device)

    if softmax_lse_accum is None:
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor] = None,
    v_new: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    out_: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    page_table: Optional[torch.Tensor] = None,
    kv_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    seqlens_rotary: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    rotary_interleaved: bool = True,
    scheduler_metadata: Optional[torch.Tensor] = None,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Symbolic fake implementation of flash attention forward.
    Returns tensors with the correct shapes and dtypes without actual computation.
    """
    # Determine if we're in varlen mode
    is_varlen_q = cu_seqlens_q is not None

    # Get dimensions from query tensor
    if is_varlen_q:
        # varlen mode: q is (total_q, num_heads, head_size)
        total_q, num_heads, head_size = q.shape
        batch_size = cu_seqlens_q.shape[0] - 1
        if max_seqlen_q is None:
            raise ValueError("max_seqlen_q must be provided if cu_seqlens_q is provided")
        seqlen_q = max_seqlen_q
    else:
        # batch mode: q is (batch_size, seqlen_q, num_heads, head_size)
        batch_size, seqlen_q, num_heads, head_size = q.shape
        total_q = batch_size * q.shape[1]

    # Get value head dimension
    head_size_v = v.shape[-1]

    # Determine output dtype (FP8 inputs produce BF16 outputs)
    q_type = q.dtype
    if q_type == torch.float8_e4m3fn:
        out_dtype = torch.bfloat16
    else:
        out_dtype = q_type

    # Create output tensor
    if out_ is not None:
        # If out_ is provided, _flash_attn_forward becomes non-functional
        raise TypeError("Tracing (torch.compile/torch.export) with pre-allocated output tensor is not supported.")
    if is_varlen_q:
        out = torch.empty((total_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)
    else:
        out = torch.empty((batch_size, seqlen_q, num_heads, head_size_v), dtype=out_dtype, device=q.device)

    # Create softmax_lse tensor
    if is_varlen_q:
        softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device)
    else:
        softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)

    # TODO(guilhermeleobas): Implement "get_num_splits"
    # There's a heuristic to compute num_splits when "num_splits <= 0";
    # for now, require num_splits > 0.
    if num_splits <= 0:
        raise ValueError(f"tracing (torch.compile/torch.export) with num_splits <= 0 not supported. Got {num_splits=}")

    if num_splits > 1:
        if is_varlen_q:
            out_accum = torch.empty((num_splits, num_heads, total_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, num_heads, total_q), dtype=torch.float32, device=q.device)
        else:
            out_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q, head_size_v), dtype=torch.float32, device=q.device)
            softmax_lse_accum = torch.empty((num_splits, batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device)
    else:
        # Accumulators are not used when num_splits == 1; return empty tensors
        out_accum = torch.tensor([], device=out.device)
        softmax_lse_accum = torch.tensor([], device=out.device)

    return out, softmax_lse, out_accum, softmax_lse_accum
@torch.library.custom_op("flash_attn_3::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> torch.Tensor:
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    softmax_d, *rest = flash_attn_3_gpu.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        sequed_q,
        sequed_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        sm_margin,
    )
    return softmax_d
@torch.library.register_fake("flash_attn_3::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    sequed_q: Optional[torch.Tensor] = None,
    sequed_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> torch.Tensor:
    is_varlen_q = cu_seqlens_q is not None
    # K is varlen when cu_seqlens_k is given (the original tested cu_seqlens_q here)
    is_varlen_k = cu_seqlens_k is not None
    is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None

    if not is_varlen_q:
        batch_size = q.size(0)
        seqlen_q = q.size(1)
        seqlen_k = k.size(1)
        total_q = batch_size * q.size(1)
    else:
        batch_size = cu_seqlens_q.size(0) - 1
        total_q = q.size(0)
        seqlen_q = max_seqlen_q
        seqlen_k = max_seqlen_k

    # A window that covers the whole sequence is equivalent to no window
    if window_size_left >= seqlen_k - 1:
        window_size_left = -1

    if window_size_right >= seqlen_q - 1:
        window_size_right = -1

    if is_causal:
        window_size_right = 0

    is_causal = window_size_left < 0 and window_size_right == 0

    head_size = q.size(-1)
    head_size_v = v.size(-1)
    head_size_rounded = round_up_headdim(max(head_size, head_size_v))

    # Hopper GPUs use CUDA compute capability 9.0
    cap = torch.cuda.get_device_capability(q.device)
    arch = cap[0] * 10 + cap[1]

    is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

    if head_size_rounded <= 64:
        kBlockM_sm90 = 96 if (is_causal and softcap > 0.0) else 128
    elif head_size_rounded <= 96:
        kBlockM_sm90 = 64
    elif head_size_rounded <= 128:
        kBlockM_sm90 = 64 if (is_causal or is_local or softcap > 0.0) else 80
    else:
        kBlockM_sm90 = 64

    kBlockM_sm80 = 128 if head_size_rounded <= 64 else 64
    kBlockM_sm86 = 64 if head_size_rounded <= 192 else 32

    if arch >= 90:
        kBlockM = kBlockM_sm90
    elif arch == 86 or arch == 89:
        kBlockM = kBlockM_sm86
    else:
        kBlockM = kBlockM_sm80

    num_heads = q.shape[-2]
    seqlen_q_rounded = round_multiple(seqlen_q, kBlockM)

    total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM)

    dq = torch.empty_like(q) if dq is None else dq
    dk = torch.empty_like(k) if dk is None else dk
    dv = torch.empty_like(v) if dv is None else dv

    if not is_varlen:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q_rounded), dtype=torch.float32, device=q.device)
    else:
        softmax_d = torch.empty((num_heads, total_q_padded_rounded), dtype=torch.float32, device=q.device)

    return softmax_d
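# Illustrative sketch (not part of the original file): the fake backward above
# pads the softmax_d shape to a multiple of kBlockM. The helper name
# `softmax_d_shape` and its parameters are hypothetical; only the rounding
# formula and the two shape expressions are taken from the code above.

```python
def round_multiple(x: int, m: int) -> int:
    """Round x up to the nearest multiple of m (same formula as in this file)."""
    return (x + m - 1) // m * m


def softmax_d_shape(batch_size, num_heads, seqlen_q, total_q, k_block_m, is_varlen):
    """Hypothetical helper mirroring the shape logic of _flash_attn_backward_fake."""
    if not is_varlen:
        # Batch mode: the query length is padded per sequence.
        return (batch_size, num_heads, round_multiple(seqlen_q, k_block_m))
    # Varlen mode: one extra block per sequence is reserved before rounding.
    return (num_heads, round_multiple(total_q + batch_size * k_block_m, k_block_m))


print(softmax_d_shape(2, 8, 100, 200, 64, is_varlen=False))  # (2, 8, 128)
print(softmax_d_shape(3, 8, 128, 300, 64, is_varlen=True))   # (8, 512)
```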
def setup_context(ctx, inputs, output):
    q, k, v = inputs[:3]
    out, softmax_lse, _, _ = output
    ctx.save_for_backward(q, k, v, out, softmax_lse)
    ctx.softmax_scale = inputs[-11]
    ctx.causal = inputs[-10]
    ctx.window_size = [inputs[-9], inputs[-8]]
    ctx.attention_chunk = inputs[-7]
    ctx.softcap = inputs[-6]
    ctx.sm_margin = inputs[-1]


def _backward(ctx, dout, *grads):
    q, k, v, out, softmax_lse = ctx.saved_tensors
    dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
    _flash_attn_backward(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        None, None,  # cu_seqlens_q, cu_seqlens_k,
        None, None,  # sequed_q, sequed_k,
        None, None,  # max_seqlen_q, max_seqlen_k,
        dq,
        dk,
        dv,
        ctx.softmax_scale,
        ctx.causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        False,  # deterministic
        ctx.sm_margin,
    )
    return dq, dk, dv, *((None,) * 21)


_flash_attn_forward.register_autograd(_backward, setup_context=setup_context)
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        softmax_scale,
        causal,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        deterministic=False,
        num_heads_q=None,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        if qkv.dim() == 5:
            assert qkv.shape[-3] == 3
            q, k, v = qkv.unbind(dim=-3)
        else:
            assert qkv.dim() == 4
            assert num_heads_q is not None
            num_heads_k = (qkv.shape[2] - num_heads_q) // 2
            assert num_heads_k * 2 + num_heads_q == qkv.shape[2]
            q, k, v = qkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            None,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.ndim = qkv.dim()
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        if ctx.ndim == 5:
            qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
            dq, dk, dv = dqkv.unbind(dim=-3)
        else:
            num_heads_q = q.shape[2]
            num_heads_k = k.shape[2]
            qkv_shape = q.shape[:-2] + (num_heads_q + num_heads_k * 2, *q.shape[-1:])
            dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
            dq, dk, dv = dqkv.split([num_heads_q, num_heads_k, num_heads_k], dim=-2)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None, None, None, None, None, None
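# Illustrative sketch (not part of the original file): how the 4-D packed
# layout above splits its head dimension into Q, K and V head counts. The
# function name `split_packed_heads` is hypothetical; the arithmetic follows
# the asserts in FlashAttnQKVPackedFunc.forward.

```python
def split_packed_heads(total_heads: int, num_heads_q: int):
    """Return (num_heads_q, num_heads_k, num_heads_v) for a packed QKV tensor.

    The packed head dimension holds the Q heads first, then an equal number
    of K heads and V heads, so the remainder must be even.
    """
    num_heads_k, remainder = divmod(total_heads - num_heads_q, 2)
    if num_heads_k < 0 or remainder != 0:
        raise ValueError(
            f"cannot split {total_heads} packed heads with num_heads_q={num_heads_q}"
        )
    return num_heads_q, num_heads_k, num_heads_k


# GQA example: 6 query heads sharing 2 key/value heads gives 6 + 2 + 2 = 10.
print(split_packed_heads(10, 6))  # (6, 2, 2)
```

# These sizes are the ones passed to qkv.split(..., dim=-2) in the forward pass.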
class FlashAttnFunc(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            None, None, None,  # cu_seqlens_q/k/k_new
            None, None,  # seqused_q/k
            None, None,  # max_seqlen_q/k
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
        ctx.save_for_backward(q, k, v, out, softmax_lse)
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            None, None,  # cu_seqlens_q, cu_seqlens_k,
            None, None,  # sequed_q, sequed_k,
            None, None,  # max_seqlen_q, max_seqlen_k,
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class FlashAttnVarlenFunc(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_q,
        seqused_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        qv=None,
        q_descale=None, k_descale=None, v_descale=None,
        window_size=(-1, -1),
        attention_chunk=0,
        softcap=0.0,
        num_splits=1,
        pack_gqa=None,
        deterministic=False,
        sm_margin=0,
        return_softmax=False,
    ):
        if softmax_scale is None:
            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
        # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward(
        out, softmax_lse, *rest = _flash_attn_forward(
            q,
            k,
            v,
            None, None,  # k_new, v_new
            qv,  # qv
            None,  # out
            cu_seqlens_q,
            cu_seqlens_k,
            None,  # cu_seqlens_k_new
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            None, None, None,  # page_table, kv_batch_idx, leftpad_k,
            None, None, None,  # rotary_cos/sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            attention_chunk=attention_chunk,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            sm_margin=sm_margin,
        )
        # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.attention_chunk = attention_chunk
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return (out, softmax_lse) if return_softmax else out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
        assert ctx.attention_chunk == 0, "FA3 backward does not support attention_chunk"
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            dq,
            dk,
            dv,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
        )
        dq = dq[..., : q.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : k.shape[-1]]
        dv = dv[..., : v.shape[-1]]
        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
def flash_attn_qkvpacked_func(
    qkv,
    softmax_scale=None,
    causal=False,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    deterministic=False,
    num_heads_q=None,
    sm_margin=0,
    return_attn_probs=False,
):
    """dropout_p should be set to 0.0 during evaluation
    If Q, K, V are already stacked into 1 tensor, this function will be faster than
    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
    of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), please see
    flash_attn_kvpacked_func and flash_attn_func.

    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
            the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedFunc.apply(
        qkv,
        softmax_scale,
        causal,
        q_descale, k_descale, v_descale,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        num_heads_q,
        sm_margin,
        return_attn_probs,
    )
def flash_attn_func(
q,
k,
v,
softmax_scale=None,
causal=False,
qv=None,
q_descale=None, k_descale=None, v_descale=None,
window_size=(-1, -1),
attention_chunk=0,
softcap=0.0,
num_splits=1,
pack_gqa=None,
deterministic=False,
sm_margin=0,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
normalization factor).
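For reference, a pure-Python sketch of what softmax_lse holds for a single
row. The helper row_logsumexp and the score values are made up for
illustration; the kernel computes this on the GPU:

```python
import math

def row_logsumexp(scores, softmax_scale):
    # Numerically stable log(sum(exp(score * softmax_scale))) for one row.
    scaled = [s * softmax_scale for s in scores]
    m = max(scaled)
    return m + math.log(sum(math.exp(x - m) for x in scaled))

scores = [1.0, 2.0, 3.0]  # one row of QK^T (made-up values)
softmax_scale = 0.5
lse = row_logsumexp(scores, softmax_scale)
# exp(scaled score - lse) recovers the softmax probabilities for the row.
probs = [math.exp(s * softmax_scale - lse) for s in scores]
print(abs(sum(probs) - 1.0) < 1e-12)  # True
```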
"""
return FlashAttnFunc.apply(
q,
k,
v,
softmax_scale,
causal,
qv,
q_descale, k_descale, v_descale,
window_size,
attention_chunk,
softcap,
num_splits,
pack_gqa,
deterministic,
sm_margin,
return_attn_probs,
)
def flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
seqused_q=None,
seqused_k=None,
softmax_scale=None,
causal=False,
qv=None,
q_descale=None, k_descale=None, v_descale=None,
window_size=(-1, -1),
attention_chunk=0,
softcap=0.0,
num_splits=1,
pack_gqa=None,
deterministic=False,
sm_margin=0,
return_attn_probs=False,
):
return FlashAttnVarlenFunc.apply(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
seqused_q,
seqused_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
qv,
q_descale, k_descale, v_descale,
window_size,
attention_chunk,
softcap,
num_splits,
pack_gqa,
deterministic,
sm_margin,
return_attn_probs,
)
def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None):
return flash_attn_3_gpu.fwd_combine(out_partial, lse_partial, out, out_dtype)
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
qv=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
rotary_seqlens: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
attention_chunk=0,
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
scheduler_metadata=None,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
return_softmax_lse=False,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
the previous step, and update them with the new keys/values from the current step, and do
attention with the updated cache, all in 1 kernel.
If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
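The bookkeeping described above can be sketched for a single sequence with a
plain Python list standing in for the pre-allocated cache. The helper
append_to_cache is hypothetical; it returns the new length so the caller can
keep tracking it, as the caller does with cache_seqlens here:

```python
def append_to_cache(cache, cache_seqlen, new_items):
    # Write new_items into the pre-allocated cache starting at cache_seqlen.
    end = cache_seqlen + len(new_items)
    if end > len(cache):
        raise ValueError('cache is too small for the new values')
    cache[cache_seqlen:end] = new_items  # in-place update, length unchanged
    return end  # the caller stores this as the new sequence length

max_seqlen = 8
k_cache_row = ['k0', 'k1', 'k2'] + [None] * (max_seqlen - 3)
new_len = append_to_cache(k_cache_row, cache_seqlen=3, new_items=['k3'])
print(new_len)             # 4
print(k_cache_row[:new_len])  # ['k0', 'k1', 'k2', 'k3']
```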
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
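A small sketch of which rotary positions are used for one sequence, following
the three cases above. Both helper names are hypothetical and exist only for
illustration:

```python
def key_rotary_positions(cache_seqlen, seqlen_new):
    # New keys are rotated at cache_seqlen, cache_seqlen + 1, etc.
    return [cache_seqlen + t for t in range(seqlen_new)]

def query_rotary_positions(cache_seqlen, seqlen_q, causal, window_size=(-1, -1)):
    local = tuple(window_size) != (-1, -1)
    if causal or local:
        # Each query token gets its own position.
        return [cache_seqlen + t for t in range(seqlen_q)]
    # Not causal and not local: every query token sits at cache_seqlen.
    return [cache_seqlen] * seqlen_q

print(key_rotary_positions(10, 3))                  # [10, 11, 12]
print(query_rotary_positions(10, 3, causal=True))   # [10, 11, 12]
print(query_rotary_positions(10, 3, causal=False))  # [10, 10, 10]
```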
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Note: Does not support backward pass.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
page_block_size can be arbitrary (e.g., 1, 2, 3, 64, etc.).
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
qv [optional]: (batch_size, seqlen, nheads, headdim_v)
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
might come from any of the duplicate indices.
cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts. If None, assume 0.
page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
softmax_scale: float. The scaling of QK^T before applying softmax.
Defaults to 1 / sqrt(headdim), or 1 / sqrt(headdim + headdim_v) when qv is passed in.
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style).
num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits.
Don't change this unless you know what you are doing.
return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
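To make the rotary_interleaved pairing above concrete, here is a sketch that
lists which dimension indices are rotated together. The helper rotary_pairs
is hypothetical and only enumerates index pairs; it does not apply any
rotation:

```python
def rotary_pairs(rotary_dim, interleaved):
    half = rotary_dim // 2
    if interleaved:
        # Adjacent dimensions are paired: (0, 1), (2, 3), etc.
        return [(2 * i, 2 * i + 1) for i in range(half)]
    # GPT-NeoX style: dimension i is paired with i + rotary_dim / 2.
    return [(i, i + half) for i in range(half)]

print(rotary_pairs(16, interleaved=True)[:2])   # [(0, 1), (2, 3)]
print(rotary_pairs(16, interleaved=False)[:2])  # [(0, 8), (1, 9)]
```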
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
normalization factor).
"""
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
if cache_seqlens is not None and isinstance(cache_seqlens, int):
cache_seqlens = torch.full(
(q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
)
cache_seqlens = maybe_contiguous(cache_seqlens)
out, softmax_lse, *rest = _flash_attn_forward(
q,
k_cache,
v_cache,
k,
v,
qv,
None, # out
cu_seqlens_q,
None, # cu_seqlens_k
cu_seqlens_k_new,
None, # seqused_q
cache_seqlens,
max_seqlen_q,
None, # max_seqlen_k
page_table,
cache_batch_idx,
cache_leftpad,
rotary_cos,
rotary_sin,
rotary_seqlens,
q_descale, k_descale, v_descale,
softmax_scale,
causal=causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
attention_chunk=attention_chunk,
softcap=softcap,
rotary_interleaved=rotary_interleaved,
scheduler_metadata=scheduler_metadata,
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
)
# return (out, softmax_lse) if return_softmax_lse else out
return (out, softmax_lse, *rest) if return_softmax_lse else out
def get_scheduler_metadata(
batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim,
cache_seqlens: torch.Tensor,
qkv_dtype=torch.bfloat16,
headdim_v=None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_size: Optional[int] = None,
max_seqlen_k_new=0,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
attention_chunk=0,
has_softcap=False,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
):
cache_seqlens = maybe_contiguous(cache_seqlens)
if headdim_v is None:
headdim_v = headdim
scheduler_metadata = flash_attn_3_gpu.get_scheduler_metadata(
batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v,
qkv_dtype,
cache_seqlens,
cu_seqlens_q,
None, # cu_seqlens_k
cu_seqlens_k_new,
None, # seqused_q
cache_leftpad,
page_size,
max_seqlen_k_new,
causal,
window_size[0], window_size[1],
attention_chunk,
has_softcap,
num_splits,
pack_gqa,
sm_margin,
)
return scheduler_metadata