mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-07-01 21:04:02 -04:00
701ebe0578
* Fused Bwd (#137) * Fused with Good perf and stride fixed Fix fused bugs isolate failing case fix bug bring back test cases rm split impl in fused use exp2 is global variable now try oom fix save make fused the default limit to reproduce failure return default to split fix head size bug use exp2 back to true * new grid * BLK_SLICE_FACTOR = 1 * add tflops * new commit * test in parrallel * strides added by jusson * disable alibi * fix bugs again * default to fused * add bwd options for varlen * backend filter * default to jingning and batch 4 * best fwd config * fix TRITON_PRINT_AUTOTUNING flag bug * tune * Tuning fwd prefill * add if else * use flag * Minor mask fix * FLIP GRID * use best config for default * print when autotuning * test bfloat16 * fix k and v stride bugs * skip bfloat16 * test kvpacked * disable internal tests * pick default config based on arch * Add alibi in the new bwd kernel (#139) * enable alibi for jinging kernel enable alibi for jinging kernel match * save bad configs * fix alibi and causal bug * disable autotune by default * auto tune when benching is good * set best config * remove env var * Update amd_tests.yml * upgrad to triton==3.3.0 * increase shm * use 64 x 64 for now * save * handle 1d alibi * Add fp8 to fused kernel (#140) * fp8 stuff find test case compute delta fp8 basic fp8 config passing non causal path works * isolate bad case * fix fp8 bug * didnot fix fp8 bug * back to failing test * fp8 tests passing * skip * skip ref tests --------- Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com> * head, seq, batch (#141) * Fix keys (#144) * save * rm keys * fix keys * use GHA_RENDER_DEVICES * normal docker * Pad LSE (#148) * add round multiple * fix fwd * backward fix * use rounded lse flag * passing ROUNDED_LSE * default is new rounded mode * rename to fused_atmoics and fused_no_atomics * add test for torch_compile * add varlen torch compile test * add old one kernel for ref * fix varlen mismatch bug * fix shape issue in varlen 
but mismatch * sync torch compile kernel launch * simple varlen test * add debug code * rm old * ignore old impls * DEBUG flag works in interface only * ref uses the righ shape for lse * rm oldest bwd kernel * fix typo * fix varlen bug * fix bug. Get info from q for now * simple shape and stride checkout * add more tests * test kvcache * kvcache safe * match case * fix segfault due to bad return_softmax * run bench * run seperate for the main functions * just output benchmark * default csv format and time stamp files * non verbsoe bench * Sliding Window Forward (#151) * Compress SWA work test case set up debug inputs add fwd ref one mask ref fwd first pass save ref doesnot work for bigger seqlens save new version some causal cases failing found bad cases working new attn new atten works new attn_fwd works reorg n_extra_tokens use seqlen_delta_qk ref fwd works add sliding window to bwd ref test kvcache decode ref work with everything except sliding window add debug code for 12 failing sliding window cases for decode attention_decode_forward_ref_impl mostly works except for alibi fix alibi in attention_decode_forward_ref_impl ref works with normal, varlen & kvcache move stuff around figure out masking old attn inner two inner functions remove load_fn do Lk - Lq like ref unify IS_CAUSAL code in epilogue clean up add args rm inference stuff simplify compute_masking simpler compute mask stub out returning front masking variables remove pointer pass compute ptrs inloop compute block min and max window stub inside inner mask loop trying to use attn_fwd_mask causes issues fix compiler bug when front masking gen specifc types add sliding window and debug statements use identity for v add more taste cases add comments save use k_max_token for clarity disable debug configs basic NON-CAUSAL SLIDING WINDOW non causal sliding window works on the all the shapes non sliding window working in fwd clean up fused bwd seperate old fwd_prefill move configs to utils.py * fix bwd ref bug 
* skip local cases so that fa output * no sliding window causal green * add backward test skip for sliding window * clean reduce in fwd_kvcache. no is_CASUAL branching * add kvcache masking * kvcache working * fix some bugs in test.py * clean up * Fix Device Segfault (#152) * Compress segfault work fix backward segfault rework offset ignore .profile ignore .analysis save * assert the kernel launch device and tensor devices are the same * fix failing asserts * add asserts to fwd * Fix SDMASK bug * Log triton, torch and fa version * Fix fp8 import issues * fix docs (#154) * Sliding Window block classification logic (#155) * add aiter code * remove aiter stuff * sliding window non causal masking works * causal and sliding window block masking * extract common * clean up typo * helper for swa * ignore .amd * fix last block bug * Enable FA V3 (#157) * Compress PA work narrow pa test ref works on most cases inplace ref with new_kv inplace paged attention add pa ref save pa basic paged works save fix swa + causal in pa. 
Also new_kv only on pa path passing build fa v3 import interface from fa v3 copy fa tests use v3 api clean up rename to match old test support different head sizes remove fp8 basisc passing v3 cases test_flash_attn_varlen_output v3 working isolate bad case for kvcache case passing save use decode is seqused/ cacheseql is given use decode if not varlen basci kvcache v3 working kvcache enable more cases detect kvcache case if seqused_q is non and sequese_k is not None skip failing test find fp8 failing case mha fp8 works fix fp8 MQA/GQA bug clean up more clean up clean up more don't need fp8 dead code remove train code with fp8 stuff fp8 working in kvcache paged + fp8 seems to be working new_kv allowed * clean up * skip hopper race test * clean up more * fix paged + alibi * similar inner paged api * unify _attn_fwd_inner * AITER integration (#159) * clean up v2 interface * assert fp8 scale shapes * rotary working * move rotary to impl layers * remove einops * enable rotarry in v3 * create interface * fix descale assert * unify bwd * lint from aiter * clean fp8 api * add api change * assert shapes for v2 * remove ref and bench.py * remove metadata class and clean up * bwd_prefill * one bwd.py * rename * lint * add bwd_change (#156) * Tune FP8 Perf (#160) * check cu count for gfx942 * create get_cu_count * update repo root * update forward tune * clean up load * use float8_e4m3fnuz * save * show bwd mode * recommend fp8 * use torch.float32 for fp8 kernel * add both best fp16 and fp8 config * tune fp8 backward * descale factors should be b, hk * fp8 bwd working on all primus configs * tune bwd configs * fa v3 tests passing * better warning * clean up bwd launcher * v3 passing * tune more * improve perf * clean up * lint * clean * start tuning gfx950 * tune non causal path * fix bug * save * Skip configs where BLOCK_M2 % BLOCK_N2 != 0 * skip more * stop tuning * fix varlen bug * fix dropout & causal/swa segfault * update the to machine new changes * save * fix more bugs 
* remove random seed * clean up * update readme * print tensor stats for debug * disable sliding window tests * add rdna configs * fix k partial bug * fix block_size_n bug * fix type check bug --------- Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com> Co-authored-by: Tianxing Wu <tianxing.wu@amd.com>
1617 lines
60 KiB
Python
# Copyright (c) 2023, Tri Dao.

from typing import Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import os

# isort: off
# We need to import the CUDA kernels after importing torch
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
if USE_TRITON_ROCM:
    from .flash_attn_triton_amd import flash_attn_2 as flash_attn_gpu
else:
    import flash_attn_2_cuda as flash_attn_gpu

# isort: on

def maybe_contiguous(x):
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def _get_block_size_n(device, head_dim, is_dropout, is_causal):
    # This should match the block sizes in the CUDA kernel
    assert head_dim <= 256
    major, minor = torch.cuda.get_device_capability(device)
    is_sm8x = major == 8 and minor > 0  # Only include sm86 and sm89, exclude sm80 (A100)
    is_sm80 = major == 8 and minor == 0
    is_sm90 = major == 9 and minor == 0
    if head_dim <= 32:
        return 128
    if head_dim <= 64:
        return 128 if not is_dropout else 64
    elif head_dim <= 96:
        return 64
    elif head_dim <= 128:
        if is_sm8x:
            return 64 if (not is_dropout and is_causal) else 32
        else:
            return 64 if not is_dropout else 32
    elif head_dim <= 192:
        return 64
    elif head_dim <= 224:
        return 64
    elif head_dim <= 256:
        return 64


def round_multiple(x, m):
    return (x + m - 1) // m * m


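# Illustration only, not part of the upstream module: a minimal, torch-free sketch of how
# round_multiple rounds a length up to the next multiple of m. The fake (meta) kernels in
# this file call it with m = 128 when sizing the padded softmax buffer on the non-HIP path.
# The sample lengths below are made up for the example.

```python
def round_multiple(x, m):
    # Round x up to the nearest multiple of m using integer arithmetic only.
    return (x + m - 1) // m * m


# Hypothetical sequence lengths mapped to their padded sizes.
samples = {seqlen: round_multiple(seqlen, 128) for seqlen in (1, 127, 128, 129, 300)}
```

# A length that is already a multiple of 128 is left unchanged; every other length moves up
# to the next multiple.
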
# torch.compile() support is only enabled for pytorch >= 2.4
# The reason for this is that we are using the new custom_op and register_fake
# APIs, which support inplace modification of inputs in the function itself
if torch.__version__ >= "2.4.0":
    _torch_custom_op_wrapper = torch.library.custom_op
    _torch_register_fake_wrapper = torch.library.register_fake
else:
    def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func
        if fn is None:
            return wrap
        return fn
    def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func
        if fn is None:
            return wrap
        return fn
    _torch_custom_op_wrapper = noop_custom_op_wrapper
    _torch_register_fake_wrapper = noop_register_fake_wrapper


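# Illustration only, not part of the upstream module: a minimal, torch-free sketch of the
# no-op fallback pattern used above for PyTorch versions without custom_op/register_fake.
# The names noop_wrapper, add_one and double are hypothetical. The point is that the
# stand-in accepts the same call shapes as the real decorator and returns the function
# unchanged, so the decorated definitions keep working as plain Python functions.

```python
def noop_wrapper(name, fn=None, /, *, mutates_args=(), device_types=None):
    # Hypothetical stand-in: ignores its options and hands back the original function.
    def wrap(func):
        return func

    if fn is None:
        # Used as @noop_wrapper("ns::op", ...): return a decorator.
        return wrap
    # Used as noop_wrapper("ns::op", fn, ...): return fn directly.
    return fn


@noop_wrapper("demo::add_one", mutates_args=())
def add_one(x):
    return x + 1


def double(x):
    return 2 * x


same = noop_wrapper("demo::double", double, mutates_args=())
```

# Both call styles leave the function object untouched, which is why the module can fall
# back to calling the undecorated Python function on older PyTorch versions.
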
@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
        q,
        k,
        v,
        None,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        None,
    )
    return out, softmax_lse, S_dmask, rng_state


@_torch_register_fake_wrapper("flash_attn::_flash_attn_forward")
def _flash_attn_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    batch_size, seqlen_q, num_heads, head_size = q.shape
    seqlen_k = k.shape[1]
    out = torch.empty_like(q)
    softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout)
    p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
    if return_softmax:
        if torch.cuda.is_available() and torch.version.hip:
            p = torch.empty((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device, layout=q.layout)
        else:
            p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)
    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)

    return out, softmax_lse, p, rng_state


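# Illustration only, not part of the upstream module: a minimal, torch-free sketch of the
# shape bookkeeping done by _flash_attn_forward_fake. The helper fake_forward_shapes and the
# sample shapes are hypothetical; the is_hip flag stands in for the
# torch.cuda.is_available() and torch.version.hip check in the real function.

```python
def round_multiple(x, m):
    return (x + m - 1) // m * m


def fake_forward_shapes(q_shape, k_shape, return_softmax, is_hip):
    # q and k are laid out as (batch, seqlen, heads, head_size).
    batch_size, seqlen_q, num_heads, head_size = q_shape
    seqlen_k = k_shape[1]
    out = tuple(q_shape)                       # same shape as q
    softmax_lse = (batch_size, num_heads, seqlen_q)
    p = (0,)                                   # empty placeholder unless requested
    if return_softmax:
        if is_hip:
            p = (batch_size, num_heads, seqlen_q, seqlen_k)
        else:
            # The non-HIP path pads both sequence dimensions to multiples of 128.
            p = (batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128))
    rng_state = (2,)
    return out, softmax_lse, p, rng_state


shapes_cuda = fake_forward_shapes((2, 100, 8, 64), (2, 200, 8, 64), True, False)
shapes_hip = fake_forward_shapes((2, 100, 8, 64), (2, 200, 8, 64), True, True)
shapes_no_p = fake_forward_shapes((2, 100, 8, 64), (2, 200, 8, 64), False, False)
```

# Only the softmax buffer differs between the two backends; out, softmax_lse and rng_state
# have the same shapes on both paths.
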
if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward
else:
    _wrapped_flash_attn_forward = _flash_attn_forward


@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_forward", mutates_args=(), device_types="cuda")
def _flash_attn_varlen_forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    return_softmax: bool = False,
    block_table: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    zero_tensors: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
        leftpad_k,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        zero_tensors,
        causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        None,
    )
    # if out.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    return out, softmax_lse, S_dmask, rng_state


@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_forward")
def _flash_attn_varlen_forward_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    return_softmax: bool = False,
    block_table: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    zero_tensors: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    paged_kv = block_table is not None
    batch_size = cu_seqlens_q.numel() - 1
    total_q, num_heads, _ = q.shape

    out = torch.empty_like(q)
    softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout)
    p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout)
    if return_softmax:
        if torch.cuda.is_available() and torch.version.hip:
            p = torch.empty((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device, layout=q.layout)
        else:
            p = torch.empty((batch_size, num_heads, round_multiple(max_seqlen_q, 128), round_multiple(max_seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout)
    rng_state = torch.empty((2,), dtype=torch.int64, device=q.device)
    return out, softmax_lse, p, rng_state


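# Illustration only, not part of the upstream module: a minimal, torch-free sketch of the
# variable-length layout that the varlen functions rely on. It assumes cu_seqlens holds
# cumulative sequence lengths with a leading 0, so it has batch_size + 1 entries (consistent
# with batch_size = cu_seqlens_q.numel() - 1 above). make_cu_seqlens and the sample lengths
# are hypothetical.

```python
from itertools import accumulate


def make_cu_seqlens(seqlens):
    # Cumulative offsets into the packed (total_tokens, heads, head_size) tensor.
    return [0] + list(accumulate(seqlens))


seqlens = [3, 5, 2]                      # made-up per-sequence lengths
cu_seqlens = make_cu_seqlens(seqlens)
batch_size = len(cu_seqlens) - 1         # mirrors cu_seqlens_q.numel() - 1
total_q = cu_seqlens[-1]                 # first dimension of the packed q tensor
max_seqlen = max(seqlens)

# Token range [start, end) occupied by each sequence in the packed tensor.
spans = [(cu_seqlens[i], cu_seqlens[i + 1]) for i in range(batch_size)]
```

# Sequence i occupies rows spans[i][0] up to (but excluding) spans[i][1] of the packed
# tensor, which is how a batch of unequal lengths is stored without padding.
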
if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
else:
    _wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward


@_torch_custom_op_wrapper("flash_attn::_flash_attn_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
def _flash_attn_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    (
        dq,
        dk,
        dv,
        softmax_d,
    ) = flash_attn_gpu.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        None,
        rng_state,
    )
    return softmax_d


@_torch_register_fake_wrapper("flash_attn::_flash_attn_backward")
def _flash_attn_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    if dq is None:
        dq = torch.empty_like(q)
    if dk is None:
        dk = torch.empty_like(k)
    if dv is None:
        dv = torch.empty_like(v)
    batch_size, seqlen_q, num_heads, _ = q.shape
    if torch.cuda.is_available() and torch.version.hip:
        softmax_d = torch.empty((batch_size, num_heads, seqlen_q), device=q.device, dtype=torch.float32)
    else:
        softmax_d = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128)), device=q.device, dtype=torch.float32)

    return softmax_d


if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward
else:
    _wrapped_flash_attn_backward = _flash_attn_backward


@_torch_custom_op_wrapper("flash_attn::_flash_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda")
def _flash_attn_varlen_backward(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
    zero_tensors: bool = False,
) -> torch.Tensor:
    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    (
        dq,
        dk,
        dv,
        softmax_d,
    ) = flash_attn_gpu.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        zero_tensors,
        causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        None,
        rng_state,
    )
    # if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return softmax_d


@_torch_register_fake_wrapper("flash_attn::_flash_attn_varlen_backward")
def _flash_attn_varlen_backward_fake(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    alibi_slopes: Optional[torch.Tensor],
    deterministic: bool,
    rng_state: Optional[torch.Tensor] = None,
    zero_tensors: bool = False,
) -> torch.Tensor:
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    batch_size = cu_seqlens_q.numel() - 1
    total_q, num_heads, _ = q.shape

    if dq is None:
        dq = torch.empty_like(q)
    if dk is None:
        dk = torch.empty_like(k)
    if dv is None:
        dv = torch.empty_like(v)
    if torch.cuda.is_available() and torch.version.hip:
        softmax_d = torch.empty((num_heads, total_q), device=q.device, dtype=torch.float32)
    else:
        softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32)

    return softmax_d


if torch.__version__ >= "2.4.0":
    _wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward
else:
    _wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward


class FlashAttnQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        is_grad_enabled,
    ):
        is_grad = is_grad_enabled and qkv.requires_grad
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach()
        head_size_og = q.size(3)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
        )
        if is_grad:
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        head_size_og = dout.size(3)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None, None, None


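# Illustration only, not part of the upstream module: a minimal, torch-free sketch of two
# small computations the autograd wrappers perform before calling the kernel: the default
# softmax scale (head_dim ** -0.5) and the number of zero columns appended so that the head
# dimension becomes a multiple of 8. The helper names and sample head sizes are hypothetical.

```python
def default_softmax_scale(head_dim):
    # Used when the caller passes softmax_scale=None.
    return head_dim ** (-0.5)


def head_dim_padding(head_size_og):
    # Zero columns to append on the last dimension; 0 when already a multiple of 8.
    return 0 if head_size_og % 8 == 0 else 8 - head_size_og % 8


scale = default_softmax_scale(64)
pads = {d: head_dim_padding(d) for d in (64, 65, 100, 127)}
padded = {d: d + p for d, p in pads.items()}
```

# The wrappers slice the kernel output (and the gradients in backward) back to the original
# head size, so the padding is not visible to the caller.
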
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cu_seqlens,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        is_grad_enabled,
    ):
        is_grad = is_grad_enabled and qkv.requires_grad
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach()
        head_size_og = q.size(2)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        if is_grad:
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
            ctx.dropout_p = dropout_p
            ctx.max_seqlen = max_seqlen
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
        dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
        head_size_og = dout.size(2)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_varlen_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        is_grad_enabled,
    ):
        is_grad = is_grad_enabled and any(
            x.requires_grad for x in [q, kv]
        )
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        k, v = kv[:, :, 0].detach(), kv[:, :, 1].detach()
        head_size_og = q.size(3)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
        )
        if is_grad:
            ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
            ctx.dropout_p = dropout_p
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        head_size_og = dout.size(3)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None, None, None, None, None, None


class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        is_grad_enabled,
    ):
        is_grad = is_grad_enabled and any(
            x.requires_grad for x in [q, kv]
        )
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        k, v = kv[:, 0].detach(), kv[:, 1].detach()
        head_size_og = q.size(2)
        if head_size_og % 8 != 0:
            q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
            k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
            v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=None,
        )
        if is_grad:
            ctx.save_for_backward(
                q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
            )
            ctx.dropout_p = dropout_p
            ctx.max_seqlen_q = max_seqlen_q
            ctx.max_seqlen_k = max_seqlen_k
            ctx.softmax_scale = softmax_scale
            ctx.causal = causal
            ctx.window_size = window_size
            ctx.softcap = softcap
            ctx.alibi_slopes = alibi_slopes
            ctx.deterministic = deterministic
        out = out_padded[..., :head_size_og]
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
        head_size_og = dout.size(2)
        dout_padded = dout
        if head_size_og % 8 != 0:
            dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
        _wrapped_flash_attn_varlen_backward(
            dout_padded,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            ctx.window_size[0],
            ctx.window_size[1],
            ctx.softcap,
            ctx.alibi_slopes,
            ctx.deterministic,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dkv = dkv[..., : dout.shape[-1]]
        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None


class FlashAttnFunc(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(
|
|
ctx,
|
|
q,
|
|
k,
|
|
v,
|
|
dropout_p,
|
|
softmax_scale,
|
|
causal,
|
|
window_size,
|
|
softcap,
|
|
alibi_slopes,
|
|
deterministic,
|
|
return_softmax,
|
|
is_grad_enabled,
|
|
):
|
|
is_grad = is_grad_enabled and any(
|
|
x.requires_grad for x in [q, k, v]
|
|
)
|
|
if softmax_scale is None:
|
|
softmax_scale = q.shape[-1] ** (-0.5)
|
|
head_size_og = q.size(3)
|
|
if head_size_og % 8 != 0:
|
|
q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
|
|
k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
|
|
v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
|
|
out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
|
|
q,
|
|
k,
|
|
v,
|
|
dropout_p,
|
|
softmax_scale,
|
|
causal=causal,
|
|
window_size_left=window_size[0],
|
|
window_size_right=window_size[1],
|
|
softcap=softcap,
|
|
alibi_slopes=alibi_slopes,
|
|
return_softmax=return_softmax and dropout_p > 0,
|
|
)
|
|
if is_grad:
|
|
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
|
|
ctx.dropout_p = dropout_p
|
|
ctx.softmax_scale = softmax_scale
|
|
ctx.causal = causal
|
|
ctx.window_size = window_size
|
|
ctx.softcap = softcap
|
|
ctx.alibi_slopes = alibi_slopes
|
|
ctx.deterministic = deterministic
|
|
out = out_padded[..., :head_size_og]
|
|
return out if not return_softmax else (out, softmax_lse, S_dmask)
|
|
|
|
@staticmethod
|
|
def backward(ctx, dout, *args):
|
|
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
|
|
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
|
head_size_og = dout.size(3)
|
|
dout_padded = dout
|
|
if head_size_og % 8 != 0:
|
|
dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
|
|
_wrapped_flash_attn_backward(
|
|
dout_padded,
|
|
q,
|
|
k,
|
|
v,
|
|
out,
|
|
softmax_lse,
|
|
dq,
|
|
dk,
|
|
dv,
|
|
ctx.dropout_p,
|
|
ctx.softmax_scale,
|
|
ctx.causal,
|
|
ctx.window_size[0],
|
|
ctx.window_size[1],
|
|
ctx.softcap,
|
|
ctx.alibi_slopes,
|
|
ctx.deterministic,
|
|
rng_state=rng_state,
|
|
)
|
|
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
|
|
dk = dk[..., : dout.shape[-1]]
|
|
dv = dv[..., : dout.shape[-1]]
|
|
return dq, dk, dv, None, None, None, None, None, None, None, None, None
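The forward and backward passes above pad the head dimension up to the next multiple of 8 and then slice the results back to the original size. The padding arithmetic can be sketched in plain Python as follows; `pad_amount` and `padded_head_size` are hypothetical helper names used only for illustration and are not part of this module.

```python
def pad_amount(head_size: int, multiple: int = 8) -> int:
    """Number of zero columns to append so that head_size becomes a multiple.

    Mirrors the `8 - head_size_og % 8` expression used when
    `head_size_og % 8 != 0`; returns 0 when no padding is needed.
    """
    remainder = head_size % multiple
    return 0 if remainder == 0 else multiple - remainder


def padded_head_size(head_size: int, multiple: int = 8) -> int:
    """Head size after padding (hypothetical helper, for illustration only)."""
    return head_size + pad_amount(head_size, multiple)
```

For example, a head size of 60 would be padded by 4 columns to 64, and the trailing columns are dropped again via `[..., :head_size_og]`.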
|
|
|
|
|
|
class FlashAttnVarlenFunc(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(
|
|
ctx,
|
|
q,
|
|
k,
|
|
v,
|
|
cu_seqlens_q,
|
|
cu_seqlens_k,
|
|
max_seqlen_q,
|
|
max_seqlen_k,
|
|
dropout_p,
|
|
softmax_scale,
|
|
causal,
|
|
window_size,
|
|
softcap,
|
|
alibi_slopes,
|
|
deterministic,
|
|
return_softmax,
|
|
block_table,
|
|
is_grad_enabled,
|
|
):
|
|
is_grad = is_grad_enabled and any(
|
|
x.requires_grad for x in [q, k, v]
|
|
)
|
|
if softmax_scale is None:
|
|
softmax_scale = q.shape[-1] ** (-0.5)
|
|
head_size_og = q.size(2)
|
|
if head_size_og % 8 != 0:
|
|
q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8])
|
|
k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
|
|
v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
|
|
out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_varlen_forward(
|
|
q,
|
|
k,
|
|
v,
|
|
cu_seqlens_q,
|
|
cu_seqlens_k,
|
|
max_seqlen_q,
|
|
max_seqlen_k,
|
|
dropout_p,
|
|
softmax_scale,
|
|
causal=causal,
|
|
window_size_left=window_size[0],
|
|
window_size_right=window_size[1],
|
|
softcap=softcap,
|
|
alibi_slopes=alibi_slopes,
|
|
return_softmax=return_softmax and dropout_p > 0,
|
|
block_table=block_table,
|
|
)
|
|
if is_grad:
|
|
ctx.save_for_backward(
|
|
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
|
|
)
|
|
ctx.dropout_p = dropout_p
|
|
ctx.max_seqlen_q = max_seqlen_q
|
|
ctx.max_seqlen_k = max_seqlen_k
|
|
ctx.softmax_scale = softmax_scale
|
|
ctx.causal = causal
|
|
ctx.window_size = window_size
|
|
ctx.softcap = softcap
|
|
ctx.alibi_slopes = alibi_slopes
|
|
ctx.deterministic = deterministic
|
|
|
|
out = out_padded[..., :head_size_og]
|
|
return out if not return_softmax else (out, softmax_lse, S_dmask)
|
|
|
|
@staticmethod
|
|
def backward(ctx, dout, *args):
|
|
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
|
|
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
|
head_size_og = dout.size(2)
|
|
dout_padded = dout
|
|
if head_size_og % 8 != 0:
|
|
dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8])
|
|
_wrapped_flash_attn_varlen_backward(
|
|
dout_padded,
|
|
q,
|
|
k,
|
|
v,
|
|
out,
|
|
softmax_lse,
|
|
dq,
|
|
dk,
|
|
dv,
|
|
cu_seqlens_q,
|
|
cu_seqlens_k,
|
|
ctx.max_seqlen_q,
|
|
ctx.max_seqlen_k,
|
|
ctx.dropout_p,
|
|
ctx.softmax_scale,
|
|
ctx.causal,
|
|
ctx.window_size[0],
|
|
ctx.window_size[1],
|
|
ctx.softcap,
|
|
ctx.alibi_slopes,
|
|
ctx.deterministic,
|
|
rng_state=rng_state,
|
|
)
|
|
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
|
|
dk = dk[..., : dout.shape[-1]]
|
|
dv = dv[..., : dout.shape[-1]]
|
|
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
|
|
|
|
|
def flash_attn_qkvpacked_func(
|
|
qkv,
|
|
dropout_p=0.0,
|
|
softmax_scale=None,
|
|
causal=False,
|
|
window_size=(-1, -1), # -1 means infinite context window
|
|
softcap=0.0, # <= 0.0 means deactivated
|
|
alibi_slopes=None,
|
|
deterministic=False,
|
|
return_attn_probs=False,
|
|
):
|
|
"""dropout_p should be set to 0.0 during evaluation
|
|
If Q, K, V are already stacked into 1 tensor, this function will be faster than
|
|
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
|
|
of the gradients of Q, K, V.
|
|
For multi-query and grouped-query attention (MQA/GQA), please see
|
|
flash_attn_kvpacked_func and flash_attn_func.
|
|
|
|
If window_size != (-1, -1), implements sliding window local attention. Query at position i
|
|
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
|
|
|
|
Arguments:
|
|
qkv: (batch_size, seqlen, 3, nheads, headdim)
|
|
dropout_p: float. Dropout probability.
|
|
softmax_scale: float. The scaling of QK^T before applying softmax.
|
|
Default to 1 / sqrt(headdim).
|
|
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
|
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
|
|
softcap: float. Anything > 0 activates softcapping attention.
|
|
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
|
|
the attention score of query i and key j.
|
|
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
|
|
which is slightly slower and uses more memory. The forward pass is always deterministic.
|
|
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
|
testing only. The returned probabilities are not guaranteed to be correct
|
|
(they might not have the right scaling).
|
|
Return:
|
|
out: (batch_size, seqlen, nheads, headdim).
|
|
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
|
logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
|
|
normalization factor).
|
|
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
|
|
The output of softmax (possibly with different scaling). It also encodes the dropout
|
|
pattern (negative means that location was dropped, nonnegative means it was kept).
|
|
"""
|
|
return FlashAttnQKVPackedFunc.apply(
|
|
qkv,
|
|
dropout_p,
|
|
softmax_scale,
|
|
causal,
|
|
window_size,
|
|
softcap,
|
|
alibi_slopes,
|
|
deterministic,
|
|
return_attn_probs,
|
|
torch.is_grad_enabled(),
|
|
)
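The docstring above mentions two quantities that are easy to check by hand: the default `softmax_scale` of 1 / sqrt(headdim), and `softmax_lse`, the logsumexp of each row of the scaled scores. The following is a minimal pure-Python sketch of both for a single row of scores; the function names are hypothetical, and it ignores masking, dropout, and softcapping.

```python
import math


def default_softmax_scale(headdim: int) -> float:
    """Default scaling of QK^T: 1 / sqrt(headdim)."""
    return headdim ** (-0.5)


def row_logsumexp(scores, scale):
    """Logsumexp of one row of QK^T * scale, computed stably.

    Subtracting the row maximum before exponentiating avoids overflow.
    """
    scaled = [s * scale for s in scores]
    row_max = max(scaled)
    return row_max + math.log(sum(math.exp(x - row_max) for x in scaled))
```

With the logsumexp in hand, the softmax probabilities of a row are `exp(score * scale - lse)`, which sum to 1.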
|
|
|
|
|
|
def flash_attn_kvpacked_func(
|
|
q,
|
|
kv,
|
|
dropout_p=0.0,
|
|
softmax_scale=None,
|
|
causal=False,
|
|
window_size=(-1, -1), # -1 means infinite context window
|
|
softcap=0.0, # 0.0 means deactivated
|
|
alibi_slopes=None,
|
|
deterministic=False,
|
|
return_attn_probs=False,
|
|
):
|
|
"""dropout_p should be set to 0.0 during evaluation
|
|
If K, V are already stacked into 1 tensor, this function will be faster than
|
|
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
|
|
of the gradients of K, V.
|
|
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
|
|
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
|
|
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
|
|
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
|
|
|
|
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
|
|
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
|
|
1 1 1 1 0
|
|
1 1 1 1 1
|
|
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
|
|
0 0
|
|
0 0
|
|
0 0
|
|
1 0
|
|
1 1
|
|
If the row of the mask is all zero, the output will be zero.
|
|
|
|
If window_size != (-1, -1), implements sliding window local attention. Query at position i
|
|
will only attend to keys between
|
|
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
|
|
|
|
Arguments:
|
|
q: (batch_size, seqlen, nheads, headdim)
|
|
kv: (batch_size, seqlen, 2, nheads_k, headdim)
|
|
dropout_p: float. Dropout probability.
|
|
softmax_scale: float. The scaling of QK^T before applying softmax.
|
|
Default to 1 / sqrt(headdim).
|
|
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
|
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
|
|
softcap: float. Anything > 0 activates softcapping attention.
|
|
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
|
|
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
|
|
is added to the attention score of query i and key j.
|
|
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
|
|
which is slightly slower and uses more memory. The forward pass is always deterministic.
|
|
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
|
testing only. The returned probabilities are not guaranteed to be correct
|
|
(they might not have the right scaling).
|
|
Return:
|
|
out: (batch_size, seqlen, nheads, headdim).
|
|
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
|
logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
|
|
normalization factor).
|
|
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
|
|
The output of softmax (possibly with different scaling). It also encodes the dropout
|
|
pattern (negative means that location was dropped, nonnegative means it was kept).
|
|
"""
|
|
return FlashAttnKVPackedFunc.apply(
|
|
q,
|
|
kv,
|
|
dropout_p,
|
|
softmax_scale,
|
|
causal,
|
|
window_size,
|
|
softcap,
|
|
alibi_slopes,
|
|
deterministic,
|
|
return_attn_probs,
|
|
torch.is_grad_enabled(),
|
|
)
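The sliding-window rule in the docstring above can be made concrete with a small pure-Python sketch. It assumes that a value of -1 on either side means "unbounded" (as the `window_size` default comment states) and that the bounds are clamped to valid key indices; `window_bounds` is a hypothetical name, not a function of this module.

```python
def window_bounds(i, seqlen_q, seqlen_k, window_size):
    """Inclusive (lo, hi) range of key indices that query i may attend to.

    Follows the docstring formula:
    [i + seqlen_k - seqlen_q - left, i + seqlen_k - seqlen_q + right].
    If lo > hi, the range is empty.
    """
    left, right = window_size
    center = i + seqlen_k - seqlen_q
    lo = 0 if left < 0 else max(0, center - left)
    hi = seqlen_k - 1 if right < 0 else min(seqlen_k - 1, center + right)
    return lo, hi
```

For instance, with `seqlen_q = 2`, `seqlen_k = 5`, and `window_size = (1, 1)`, query 0 is centered on key 3 and may attend to keys 2 through 4.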
|
|
|
|
|
|
def flash_attn_func(
|
|
q,
|
|
k,
|
|
v,
|
|
dropout_p=0.0,
|
|
softmax_scale=None,
|
|
causal=False,
|
|
window_size=(-1, -1), # -1 means infinite context window
|
|
softcap=0.0, # 0.0 means deactivated
|
|
alibi_slopes=None,
|
|
deterministic=False,
|
|
return_attn_probs=False,
|
|
):
|
|
"""dropout_p should be set to 0.0 during evaluation
|
|
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
|
|
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
|
|
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
|
|
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
|
|
|
|
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
|
|
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
|
|
1 1 1 1 0
|
|
1 1 1 1 1
|
|
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
|
|
0 0
|
|
0 0
|
|
0 0
|
|
1 0
|
|
1 1
|
|
If the row of the mask is all zero, the output will be zero.
|
|
|
|
If window_size != (-1, -1), implements sliding window local attention. Query at position i
|
|
will only attend to keys between
|
|
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
|
|
|
|
Arguments:
|
|
q: (batch_size, seqlen, nheads, headdim)
|
|
k: (batch_size, seqlen, nheads_k, headdim)
|
|
v: (batch_size, seqlen, nheads_k, headdim)
|
|
dropout_p: float. Dropout probability.
|
|
softmax_scale: float. The scaling of QK^T before applying softmax.
|
|
Default to 1 / sqrt(headdim).
|
|
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
|
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
|
|
softcap: float. Anything > 0 activates softcapping attention.
|
|
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
|
|
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
|
|
is added to the attention score of query i and key j.
|
|
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
|
|
which is slightly slower and uses more memory. The forward pass is always deterministic.
|
|
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
|
testing only. The returned probabilities are not guaranteed to be correct
|
|
(they might not have the right scaling).
|
|
Return:
|
|
out: (batch_size, seqlen, nheads, headdim).
|
|
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
|
logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
|
|
normalization factor).
|
|
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
|
|
The output of softmax (possibly with different scaling). It also encodes the dropout
|
|
pattern (negative means that location was dropped, nonnegative means it was kept).
|
|
"""
|
|
return FlashAttnFunc.apply(
|
|
q,
|
|
k,
|
|
v,
|
|
dropout_p,
|
|
softmax_scale,
|
|
causal,
|
|
window_size,
|
|
softcap,
|
|
alibi_slopes,
|
|
deterministic,
|
|
return_attn_probs,
|
|
torch.is_grad_enabled(),
|
|
)
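The bottom-right-aligned causal mask described in the docstring above can be reproduced with a few lines of plain Python. This is only an illustrative sketch of the masking rule (key j is kept for query i when `j <= i + seqlen_k - seqlen_q`); `causal_mask` is a hypothetical helper, and the actual kernels do not materialize such a matrix.

```python
def causal_mask(seqlen_q: int, seqlen_k: int):
    """Bottom-right-aligned causal mask as nested lists (1 = keep, 0 = masked out)."""
    offset = seqlen_k - seqlen_q
    return [
        [1 if j <= i + offset else 0 for j in range(seqlen_k)]
        for i in range(seqlen_q)
    ]
```

Calling `causal_mask(2, 5)` and `causal_mask(5, 2)` yields the two example masks shown in the docstring, including the all-zero rows in the second case.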
|
|
|
|
|
|
def flash_attn_varlen_qkvpacked_func(
|
|
qkv,
|
|
cu_seqlens,
|
|
max_seqlen,
|
|
dropout_p=0.0,
|
|
softmax_scale=None,
|
|
causal=False,
|
|
window_size=(-1, -1), # -1 means infinite context window
|
|
softcap=0.0, # 0.0 means deactivated
|
|
alibi_slopes=None,
|
|
deterministic=False,
|
|
return_attn_probs=False,
|
|
):
|
|
"""dropout_p should be set to 0.0 during evaluation
|
|
If Q, K, V are already stacked into 1 tensor, this function will be faster than
|
|
calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
|
|
of the gradients of Q, K, V.
|
|
For multi-query and grouped-query attention (MQA/GQA), please see
|
|
flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.
|
|
|
|
If window_size != (-1, -1), implements sliding window local attention. Query at position i
|
|
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
|
|
|
|
Arguments:
|
|
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
|
|
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
|
of the sequences in the batch, used to index into qkv.
|
|
max_seqlen: int. Maximum sequence length in the batch.
|
|
dropout_p: float. Dropout probability.
|
|
softmax_scale: float. The scaling of QK^T before applying softmax.
|
|
Default to 1 / sqrt(headdim).
|
|
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
|
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
|
|
softcap: float. Anything > 0 activates softcapping attention.
|
|
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|)
|
|
is added to the attention score of query i and key j.
|
|
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
|
|
which is slightly slower and uses more memory. The forward pass is always deterministic.
|
|
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
|
testing only. The returned probabilities are not guaranteed to be correct
|
|
(they might not have the right scaling).
|
|
Return:
|
|
out: (total, nheads, headdim).
|
|
softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
|
|
logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
|
|
normalization factor).
|
|
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
|
|
The output of softmax (possibly with different scaling). It also encodes the dropout
|
|
pattern (negative means that location was dropped, nonnegative means it was kept).
|
|
"""
|
|
return FlashAttnVarlenQKVPackedFunc.apply(
|
|
qkv,
|
|
cu_seqlens,
|
|
max_seqlen,
|
|
dropout_p,
|
|
softmax_scale,
|
|
causal,
|
|
window_size,
|
|
softcap,
|
|
alibi_slopes,
|
|
deterministic,
|
|
return_attn_probs,
|
|
torch.is_grad_enabled(),
|
|
)
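The `cu_seqlens` and `max_seqlen` arguments described above are derived from the per-sequence lengths of the packed batch. The sketch below shows that derivation with plain Python lists; `build_cu_seqlens` is a hypothetical helper, and in real use the result would need to be converted to a `torch.int32` tensor on the same device as `qkv`.

```python
from itertools import accumulate


def build_cu_seqlens(seqlens):
    """Return (cu_seqlens, max_seqlen) for a batch of sequence lengths.

    cu_seqlens has batch_size + 1 entries; sequence b occupies rows
    cu_seqlens[b]:cu_seqlens[b + 1] of the packed (total, ...) tensor.
    """
    cu_seqlens = [0] + list(accumulate(seqlens))
    max_seqlen = max(seqlens)
    return cu_seqlens, max_seqlen
```

For sequence lengths `[3, 1, 4]`, the packed tensor has `total = 8` rows and the second sequence lives in row 3 only.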
|
|
|
|
|
|
def flash_attn_varlen_kvpacked_func(
|
|
q,
|
|
kv,
|
|
cu_seqlens_q,
|
|
cu_seqlens_k,
|
|
max_seqlen_q,
|
|
max_seqlen_k,
|
|
dropout_p=0.0,
|
|
softmax_scale=None,
|
|
causal=False,
|
|
window_size=(-1, -1), # -1 means infinite context window
|
|
softcap=0.0, # 0.0 means deactivated
|
|
alibi_slopes=None,
|
|
deterministic=False,
|
|
return_attn_probs=False,
|
|
):
|
|
"""dropout_p should be set to 0.0 during evaluation
|
|
If K, V are already stacked into 1 tensor, this function will be faster than
|
|
calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
|
|
of the gradients of K, V.
|
|
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
|
|
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
|
|
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
|
|
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
|
|
|
|
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
|
|
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
|
|
1 1 1 1 0
|
|
1 1 1 1 1
|
|
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
|
|
0 0
|
|
0 0
|
|
0 0
|
|
1 0
|
|
1 1
|
|
If the row of the mask is all zero, the output will be zero.
|
|
|
|
If window_size != (-1, -1), implements sliding window local attention. Query at position i
|
|
will only attend to keys between
|
|
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
|
|
|
|
Arguments:
|
|
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
|
|
kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
|
|
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
|
of the sequences in the batch, used to index into q.
|
|
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
|
of the sequences in the batch, used to index into kv.
|
|
max_seqlen_q: int. Maximum query sequence length in the batch.
|
|
max_seqlen_k: int. Maximum key sequence length in the batch.
|
|
dropout_p: float. Dropout probability.
|
|
softmax_scale: float. The scaling of QK^T before applying softmax.
|
|
Default to 1 / sqrt(headdim).
|
|
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
|
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
|
|
softcap: float. Anything > 0 activates softcapping attention.
|
|
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
|
|
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
|
|
is added to the attention score of query i and key j.
|
|
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
|
|
which is slightly slower and uses more memory. The forward pass is always deterministic.
|
|
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
|
testing only. The returned probabilities are not guaranteed to be correct
|
|
(they might not have the right scaling).
|
|
Return:
|
|
out: (total_q, nheads, headdim).
|
|
softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
|
|
logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
|
|
normalization factor).
|
|
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
|
|
The output of softmax (possibly with different scaling). It also encodes the dropout
|
|
pattern (negative means that location was dropped, nonnegative means it was kept).
|
|
"""
|
|
return FlashAttnVarlenKVPackedFunc.apply(
|
|
q,
|
|
kv,
|
|
cu_seqlens_q,
|
|
cu_seqlens_k,
|
|
max_seqlen_q,
|
|
max_seqlen_k,
|
|
dropout_p,
|
|
softmax_scale,
|
|
causal,
|
|
window_size,
|
|
softcap,
|
|
alibi_slopes,
|
|
deterministic,
|
|
return_attn_probs,
|
|
torch.is_grad_enabled(),
|
|
)
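The MQA/GQA head grouping described in the docstring above (consecutive query heads share one key/value head) amounts to an integer division. The following is a minimal sketch of that mapping; `kv_head_for_query_head` is a hypothetical name used for illustration only.

```python
def kv_head_for_query_head(q_head: int, nheads: int, nheads_k: int) -> int:
    """Index of the K/V head that a given query head attends to.

    Assumes nheads is divisible by nheads_k, as the docstring requires.
    """
    if nheads % nheads_k != 0:
        raise ValueError("nheads must be divisible by nheads_k")
    group_size = nheads // nheads_k
    return q_head // group_size
```

With 6 query heads and 2 key/value heads, heads 0, 1, 2 map to K/V head 0 and heads 3, 4, 5 map to K/V head 1; with `nheads_k = 1` (MQA) every query head maps to head 0.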
|
|
|
|
|
|
def flash_attn_varlen_func(
|
|
q,
|
|
k,
|
|
v,
|
|
cu_seqlens_q,
|
|
cu_seqlens_k,
|
|
max_seqlen_q,
|
|
max_seqlen_k,
|
|
dropout_p=0.0,
|
|
softmax_scale=None,
|
|
causal=False,
|
|
window_size=(-1, -1), # -1 means infinite context window
|
|
softcap=0.0, # 0.0 means deactivated
|
|
alibi_slopes=None,
|
|
deterministic=False,
|
|
return_attn_probs=False,
|
|
block_table=None,
|
|
):
|
|
"""dropout_p should be set to 0.0 during evaluation
|
|
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
|
|
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
|
|
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
|
|
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
|
|
|
|
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
|
|
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
|
|
1 1 1 1 0
|
|
1 1 1 1 1
|
|
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
|
|
0 0
|
|
0 0
|
|
0 0
|
|
1 0
|
|
1 1
|
|
If the row of the mask is all zero, the output will be zero.
|
|
|
|
If window_size != (-1, -1), implements sliding window local attention. Query at position i
|
|
will only attend to keys between
|
|
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
|
|
|
|
Arguments:
|
|
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
|
|
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
|
|
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
|
|
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
|
of the sequences in the batch, used to index into q.
|
|
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
|
of the sequences in the batch, used to index into k and v.
|
|
max_seqlen_q: int. Maximum query sequence length in the batch.
|
|
max_seqlen_k: int. Maximum key sequence length in the batch.
|
|
dropout_p: float. Dropout probability.
|
|
softmax_scale: float. The scaling of QK^T before applying softmax.
|
|
Default to 1 / sqrt(headdim).
|
|
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
|
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
|
|
softcap: float. Anything > 0 activates softcapping attention.
|
|
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
|
|
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
|
|
is added to the attention score of query i and key j.
|
|
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
|
|
which is slightly slower and uses more memory. The forward pass is always deterministic.
|
|
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
|
testing only. The returned probabilities are not guaranteed to be correct
|
|
(they might not have the right scaling).
|
|
Return:
|
|
out: (total_q, nheads, headdim).
|
|
softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
|
|
logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
|
|
normalization factor).
|
|
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
|
|
The output of softmax (possibly with different scaling). It also encodes the dropout
|
|
pattern (negative means that location was dropped, nonnegative means it was kept).
|
|
"""
|
|
return FlashAttnVarlenFunc.apply(
|
|
q,
|
|
k,
|
|
v,
|
|
cu_seqlens_q,
|
|
cu_seqlens_k,
|
|
max_seqlen_q,
|
|
max_seqlen_k,
|
|
dropout_p,
|
|
softmax_scale,
|
|
causal,
|
|
window_size,
|
|
softcap,
|
|
alibi_slopes,
|
|
deterministic,
|
|
return_attn_probs,
|
|
block_table,
|
|
torch.is_grad_enabled(),
|
|
)
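The ALiBi bias formula given in the docstring above, `-alibi_slope * |i + seqlen_k - seqlen_q - j|`, can be tabulated for a single head with plain Python. This sketch is illustrative only; `alibi_bias` is a hypothetical helper, and the kernels apply the bias on the fly rather than building the full matrix.

```python
def alibi_bias(slope: float, seqlen_q: int, seqlen_k: int):
    """ALiBi bias added to the attention score of query i and key j, for one head.

    Uses the same bottom-right alignment (offset = seqlen_k - seqlen_q)
    as the causal mask.
    """
    offset = seqlen_k - seqlen_q
    return [
        [-slope * abs(i + offset - j) for j in range(seqlen_k)]
        for i in range(seqlen_q)
    ]
```

The bias is zero where key j lines up with query i (after the offset) and becomes more negative linearly with distance, scaled by the per-head slope.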
|
|
|
|
|
|
def flash_attn_with_kvcache(
|
|
q,
|
|
k_cache,
|
|
v_cache,
|
|
k=None,
|
|
v=None,
|
|
rotary_cos=None,
|
|
rotary_sin=None,
|
|
cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
|
|
cache_batch_idx: Optional[torch.Tensor] = None,
|
|
cache_leftpad: Optional[torch.Tensor] = None,
|
|
block_table: Optional[torch.Tensor] = None,
|
|
softmax_scale=None,
|
|
causal=False,
|
|
window_size=(-1, -1), # -1 means infinite context window
|
|
softcap=0.0, # 0.0 means deactivated
|
|
rotary_interleaved=True,
|
|
alibi_slopes=None,
|
|
num_splits=0,
|
|
return_softmax_lse=False,
|
|
):
|
|
"""
|
|
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
|
|
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
|
|
the previous step, and update them with the new keys/values from the current step, and do
|
|
attention with the updated cache, all in 1 kernel.
|
|
|
|
If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
|
|
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
|
|
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
|
|
|
|
Also applies rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
|
|
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
|
|
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
|
|
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
|
|
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
|
|
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
|
|
|
|
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
|
|
|
|
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
|
|
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
|
|
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
|
|
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
|
|
|
|
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
|
|
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
|
|
1 1 1 1 0
|
|
1 1 1 1 1
|
|
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
|
|
0 0
|
|
0 0
|
|
0 0
|
|
1 0
|
|
1 1
|
|
If the row of the mask is all zero, the output will be zero.
|
|
|
|
If window_size != (-1, -1), implements sliding window local attention. Query at position i
|
|
will only attend to keys between
|
|
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
|
|
|
|
Note: Does not support backward pass.
|
|
|
|
Arguments:
|
|
q: (batch_size, seqlen, nheads, headdim)
|
|
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
|
|
or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
|
|
page_block_size must be a multiple of 256.
|
|
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
|
|
or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
|
|
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
|
|
k with k_cache, starting at the indices specified by cache_seqlens.
|
|
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
|
|
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
|
|
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
|
|
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
|
|
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
|
|
KV cache.
|
|
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
|
|
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
|
|
If the indices are not distinct, and k and v are provided, the values updated in the cache
|
|
might come from any of the duplicate indices.
|
|
cache_leftpad: (batch_size,), dtype torch.int32. The index at which the KV cache starts. If None, assume 0.
|
|
block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
|
|
softmax_scale: float. The scaling of QK^T before applying softmax.
|
|
Default to 1 / sqrt(headdim).
|
|
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
|
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
|
|
softcap: float. Anything > 0 activates softcapping attention.
|
|
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
|
|
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
|
|
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
|
|
(i.e. GPT-NeoX style).
|
|
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
|
|
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
|
|
is added to the attention score of query i and key j.
|
|
num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
|
|
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
|
|
to automatically determine the number of splits.
|
|
Don't change this unless you know what you are doing.
|
|
return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
|
|
|
|
Return:
|
|
out: (batch_size, seqlen, nheads, headdim).
|
|
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
|
|
logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
|
|
normalization factor).
|
|
"""
|
|
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
|
|
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
|
|
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
|
|
if softmax_scale is None:
|
|
softmax_scale = q.shape[-1] ** (-0.5)
|
|
if cache_seqlens is not None and isinstance(cache_seqlens, int):
|
|
cache_seqlens = torch.full(
|
|
(q.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
|
|
)
|
|
cache_seqlens = maybe_contiguous(cache_seqlens)
|
|
cache_batch_idx = maybe_contiguous(cache_batch_idx)
|
|
block_table = maybe_contiguous(block_table)
|
|
out, softmax_lse = flash_attn_gpu.fwd_kvcache(
|
|
q,
|
|
k_cache,
|
|
v_cache,
|
|
k,
|
|
v,
|
|
cache_seqlens,
|
|
rotary_cos,
|
|
rotary_sin,
|
|
cache_batch_idx,
|
|
cache_leftpad,
|
|
block_table,
|
|
alibi_slopes,
|
|
None,
|
|
softmax_scale,
|
|
causal,
|
|
window_size[0],
|
|
window_size[1],
|
|
softcap,
|
|
rotary_interleaved,
|
|
num_splits,
|
|
)
|
|
return (out, softmax_lse) if return_softmax_lse else out
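The in-place cache update described in the docstring above (new keys/values written into a pre-allocated cache starting at `cache_seqlens`) can be modeled with plain Python lists. This is a conceptual sketch under simplifying assumptions, not the kernel's implementation: `append_to_cache` is a hypothetical helper, it handles one cache at a time, ignores `cache_batch_idx`, `cache_leftpad`, and paging, and returns the new lengths as caller-side bookkeeping.

```python
def append_to_cache(cache, cache_seqlens, new_tokens):
    """Write new_tokens[b] into cache[b] starting at index cache_seqlens[b].

    cache: one pre-allocated, fixed-length list per batch entry.
    Returns the updated sequence lengths for the next decoding step.
    """
    updated_seqlens = []
    for b, (start, tokens) in enumerate(zip(cache_seqlens, new_tokens)):
        end = start + len(tokens)
        if end > len(cache[b]):
            raise ValueError("cache is too small to hold the new values")
        cache[b][start:end] = tokens  # in-place update; list length is unchanged
        updated_seqlens.append(end)
    return updated_seqlens
```

In an incremental decoding loop, the returned lengths would be passed as `cache_seqlens` on the next step, which matches the docstring's advice to pre-allocate the cache at the maximum sequence length.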
|