1261 Commits

Author SHA1 Message Date
Driss Guessous 72c7ba484d Fix Hopper tests (#2242) 2026-02-08 09:25:01 -08:00
Tri Dao 2a8d39c540 [DSL] warpgroup_reg_alloc -> setmaxregister_increase 2026-02-08 22:17:36 +07:00
Tri Dao 17d29436b8 [Bwd,Sm90] Have score_mod and score_mod_bwd as partial functions 2026-02-08 22:10:05 +07:00
Tri Dao deb183092b [Bwd,Sm100] Shorten PipelineTmaUmma create 2026-02-08 21:11:17 +07:00
Tri Dao c912a37d52 [Bwd,Sm90] Use quack.copy_utils 2026-02-08 18:57:20 +07:00
Tri Dao b9148cec6f [Layout] Use layout_utils.transpose_view and select from quack 2026-02-08 18:24:28 +07:00
Tri Dao 90f10faafd [DSL] Use cute.math.{exp2,log2,log} 2026-02-08 18:02:19 +07:00
Tri Dao 8dd8019cef [Layout] Use quack.layout_utils.mma_partition_C_vec 2026-02-08 17:05:48 +07:00
Tri Dao b735ef24c2 [Layout] Use reshape_acc_to_mn and reshape_acc_to_frgA from quack 2026-02-08 16:48:16 +07:00
Tri Dao 7edcf59c9e [DSL] Use cute.arch.warp_reduction_{max,sum} 2026-02-08 16:18:56 +07:00
Tri Dao 81f2c2dcdc [Sm90] Use functions from quack.sm90_utils 2026-02-08 15:52:49 +07:00
Tri Dao d39b6292bb [DSL] Remove coord_offset_i64, domain_offset_i64, elem_pointer_i64
Cute-dsl now supports i64 strides by default
2026-02-08 11:11:43 +07:00
Tri Dao 5a66f2cca3 [DSL]Replace utils.{fma,mul,add}_packed_f32x2 with cute.arch version 2026-02-08 11:07:39 +07:00
Tri Dao a804a5a3ef [DSL] Replace old fence with cute.arch.fence_view_async_shared() 2026-02-08 10:52:54 +07:00
Driss Guessous 48af662c53 pytest-dist round robin to gpus (#2241) 2026-02-07 19:50:48 -08:00
Driss Guessous abaa87875d [CUTE]Bump to Cutedsl (#2216)
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-07 18:32:44 -08:00
Alex Butler 912c6c4518 Add FLASH_ATTENTION_TRITON_AMD_CONFIG_JSON env var support (#2239)
* Add FLASH_ATTENTION_TRITON_AMD_CONFIG_JSON env var support

Allows users to override triton config when not autotuning.

* Add FLASH_ATTENTION_TRITON_AMD_CONFIG_JSON to readme

* Rename to FLASH_ATTENTION_FWD_TRITON_AMD_CONFIG_JSON
2026-02-07 08:13:18 -08:00
jayhshah f1284cff5d hdim 192 smem fix (#2235) 2026-02-05 08:22:36 -08:00
Luca Wehrstedt e2743ab5b3 [FA3] Mark current main version as v3.0.0 stable (#2223)
A collaboration between Flash-Attention, PyTorch and xFormers is trying to provide pre-built wheels for FA3 across as many platforms/environments as possible (e.g., ARM, Windows, CUDA 13, ...). To simplify the installation workflow, it would help to tag these packages as stable, but the current main version is tagged as beta.

FA3 hasn't received substantial updates in a while (the latest was a bugfix almost two months ago), and most new development is happening in FA4. Thus, in this PR, I propose we just claim that the current main version _is_ stable.

I have heard concerns that the feature set of FA3 doesn't currently match FA2 (e.g., dropout is missing). I think this concern is partly addressed by the fact that the new wheels will have a different name than the FA2 ones (`flash_attn_3` and `flash_attn` respectively), hence the former does _not_ claim to be a replacement for the latter, and the two can coexist (and they provide different modules).
2026-02-04 19:45:41 -08:00
Markus Hoehnerbach 24445c0c17 short readme for flex flash (#2231) 2026-02-04 18:30:56 -08:00
Jane (Yuan) Xu ef9e6a6441 Use TORCH_TARGET_VERSION over TORCH_STABLE_ONLY (#2155) 2026-02-04 13:58:49 -08:00
Driss Guessous 188643b82d Fix shared-memory race (#2229) 2026-02-04 12:33:20 -08:00
zhuochen 514e63cc26 fix compute_block_sparsity usage in benchmark_mask_mod (#2221) 2026-02-02 13:41:42 -08:00
Michael Melesse 701ebe0578 [AMD] Triton Backend for ROCm #3 (#2178)
* Fused Bwd (#137)

* Fused with Good perf and stride fixed

Fix fused bugs

isolate failing case

fix bug

bring back test cases

rm split impl in fused

use exp2 is global variable now

try oom fix

save

make fused the default

limit to reproduce failure

return default to split

fix head size bug

use exp2 back to true

* new grid

* BLK_SLICE_FACTOR = 1

* add tflops

* new commit

* test in parrallel

* strides added by jusson

* disable alibi

* fix bugs again

* default to fused

* add bwd options for varlen

* backend filter

* default to jingning and batch 4

* best fwd config

* fix TRITON_PRINT_AUTOTUNING flag bug

* tune

* Tuning fwd prefill

* add if else

* use flag

* Minor mask fix

* FLIP GRID

* use best config for default

* print when autotuning

* test bfloat16

* fix k and v stride bugs

* skip bfloat16

* test kvpacked

* disable internal tests

* pick default config based on arch

* Add alibi in the new bwd kernel (#139)

* enable alibi for jinging kernel

enable alibi for jinging kernel

match

* save bad configs

* fix alibi and causal bug

* disable autotune by default

* auto tune when benching is good

* set best config

* remove env var

* Update amd_tests.yml

* upgrad to triton==3.3.0

* increase shm

* use 64 x 64 for now

* save

* handle 1d alibi

* Add fp8 to fused kernel (#140)

* fp8 stuff

find test case

compute delta fp8

basic fp8 config passing

non causal path works

* isolate bad case

* fix fp8 bug

* didnot fix fp8 bug

* back to failing test

* fp8 tests passing

* skip

* skip ref tests

---------

Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com>

* head, seq, batch (#141)

* Fix keys (#144)

* save

* rm keys

* fix keys

* use GHA_RENDER_DEVICES

* normal docker

* Pad LSE (#148)

* add round multiple

* fix fwd

* backward fix

* use rounded lse flag

* passing ROUNDED_LSE

* default is new rounded mode

* rename to fused_atmoics and fused_no_atomics

* add test for torch_compile

* add varlen torch compile test

* add old one kernel for ref

* fix varlen mismatch bug

* fix shape issue in varlen but mismatch

* sync torch compile kernel launch

* simple varlen test

* add debug code

* rm old

* ignore old impls

* DEBUG flag works in interface only

* ref uses the righ shape for lse

* rm oldest bwd kernel

* fix typo

* fix varlen bug

* fix bug. Get info from q for now

* simple shape and stride checkout

* add more tests

* test kvcache

* kvcache safe

* match case

* fix segfault due to bad return_softmax

* run bench

* run seperate for the main functions

* just output benchmark

* default csv format and time stamp files

* non verbsoe bench

* Sliding Window Forward (#151)

* Compress SWA work

test case

set up debug inputs

add fwd ref

one mask ref

fwd first pass

save

ref doesnot work for bigger seqlens

save new version

some causal cases failing

found bad cases

working new attn

new atten works

new attn_fwd works

reorg n_extra_tokens

use seqlen_delta_qk

ref fwd works

add sliding window to bwd ref

test kvcache

decode ref work with everything except sliding window

add debug code for 12 failing sliding window cases for decode

attention_decode_forward_ref_impl mostly works except for alibi

fix alibi in attention_decode_forward_ref_impl

ref works with normal, varlen & kvcache

move stuff around

figure out masking

old attn inner

two inner functions

remove load_fn

do Lk - Lq like ref

unify IS_CAUSAL code in epilogue

clean up

add args

rm inference stuff

simplify compute_masking

simpler compute mask

stub out returning front masking variables

remove pointer pass

compute ptrs inloop

compute block min and max

window stub inside inner mask loop

trying to use attn_fwd_mask causes issues

fix compiler bug when front masking

gen specifc types

add sliding window and debug statements

use identity for v

add more taste cases

add comments

save

use k_max_token for clarity

disable debug configs

basic NON-CAUSAL SLIDING WINDOW

non causal sliding window works on the all the shapes

non sliding window working in fwd

clean up fused bwd

seperate old fwd_prefill

move configs to utils.py

* fix bwd ref bug

* skip local cases so that fa output

* no sliding window causal green

* add backward test skip for sliding window

* clean reduce in fwd_kvcache. no is_CASUAL branching

* add kvcache masking

* kvcache working

* fix some bugs in test.py

* clean up

* Fix Device Segfault (#152)

* Compress segfault work

fix backward segfault

rework offset

ignore .profile

ignore .analysis

save

* assert the kernel launch device and tensor devices are the same

* fix failing asserts

* add asserts to fwd

* Fix SDMASK bug

* Log triton, torch and fa version

* Fix fp8 import issues

* fix docs (#154)

* Sliding Window block classification logic (#155)

* add aiter code

* remove aiter stuff

* sliding window non causal masking works

* causal and sliding window block masking

* extract common

* clean up typo

* helper for swa

* ignore .amd

* fix last block bug

* Enable FA V3 (#157)

* Compress PA work

narrow pa test

ref works on most cases

inplace ref with new_kv

inplace paged attention

add pa ref

save pa

basic  paged works

save

fix swa + causal in pa. Also new_kv only on pa path

passing

build fa v3

import interface from fa v3

copy fa tests

use v3 api

clean up

rename to match old test

support different head sizes

remove fp8

basisc passing v3 cases

test_flash_attn_varlen_output v3 working

isolate bad case for kvcache

case passing

save

use decode is seqused/ cacheseql is given

use decode if not varlen

basci kvcache v3 working

kvcache enable more cases

detect kvcache case if seqused_q is non and sequese_k is not None

skip failing test

find fp8 failing case

mha fp8 works

fix fp8 MQA/GQA bug

clean up

more clean up

clean up more

don't need fp8 dead code

remove train code with fp8 stuff

fp8 working in kvcache

paged + fp8 seems to be working

new_kv allowed

* clean up

* skip hopper race test

* clean up more

* fix paged + alibi

* similar inner paged api

* unify _attn_fwd_inner

* AITER integration (#159)

* clean up v2 interface

* assert fp8 scale shapes

* rotary working

* move rotary to impl layers

* remove einops

* enable rotarry in v3

* create interface

* fix descale assert

* unify bwd

* lint from aiter

* clean fp8 api

* add api change

* assert shapes for v2

* remove ref and bench.py

* remove metadata class and clean up

* bwd_prefill

* one bwd.py

* rename

* lint

* add bwd_change (#156)

* Tune FP8 Perf (#160)

* check cu count for gfx942

* create get_cu_count

* update repo root

* update forward tune

* clean up load

* use float8_e4m3fnuz

* save

* show bwd mode

* recommend fp8

* use torch.float32 for fp8 kernel

* add both best fp16 and fp8 config

* tune fp8 backward

* descale factors should be b, hk

* fp8 bwd working on all primus configs

* tune bwd configs

* fa v3 tests passing

* better warning

* clean up bwd launcher

* v3 passing

* tune more

* improve perf

* clean up

* lint

* clean

* start tuning gfx950

* tune non causal path

* fix bug

* save

* Skip configs where BLOCK_M2 % BLOCK_N2 != 0

* skip more

* stop tuning

* fix varlen bug

* fix dropout & causal/swa segfault

* update the to machine new changes

* save

* fix more bugs

* remove random seed

* clean up

* update readme

* print tensor stats for debug

* disable sliding window tests

* add rdna configs

* fix k partial bug

* fix block_size_n bug

* fix type check bug

---------

Co-authored-by: Aliasger Zaidy <aliasger.zaidy@amd.com>
Co-authored-by: Tianxing Wu <tianxing.wu@amd.com>
2026-01-28 07:49:08 -08:00
Tri Dao 99589e5a66 [DSL] Optionally patch cute-dsl to use system's ptxas 2026-01-27 21:09:36 +07:00
Driss Guessous 4f892461bb [Flex][SM100] Replay expand fix on sm100 (#2209)
stack-info: PR: https://github.com/Dao-AILab/flash-attention/pull/2209, branch: drisspg/stack/6
2026-01-26 11:22:14 -08:00
Wang Lecheng 438325c2c3 Update README to include 'psutil' package as build requirement (#2210)
Added 'psutil' as a build requirement in the README.
2026-01-25 01:41:59 -08:00
oliver könig 57cef6c2e7 ci: Use 1 ninja job for cu13 (#2195)
Signed-off-by: oliver könig <okoenig@nvidia.com>
2026-01-23 19:04:16 -08:00
Driss Guessous 2580b5a488 [Cute][Flex] Allow q_offset 1 and add block-sizes to disambiguate edge cases (#2187) 2026-01-21 17:25:32 -08:00
Qubitium-ModelCloud f15ccf5ff2 reduce chance of build oom (#2079) 2026-01-21 02:36:22 -08:00
Kareem 04e6ee1fb5 [Cute, SM90] fix fwd varlen Cute implementation bug for H100 (#2194)
* fix

* same fix for bwd and SM80
2026-01-20 09:42:46 -08:00
Driss Guessous a0f9f418fd [Cute][Flex] Fix expanded tensor bug (#2189) 2026-01-16 20:47:10 -08:00
Markus Hoehnerbach 2d6b146893 Merge pull request #2185 from henrylhtsang/test_local_r2p
[Cute,Fwd,Sm100] Add r2p for local mask
2026-01-16 16:06:04 -08:00
Henry Tsang 137ad8e6e0 add comment 2026-01-16 09:39:15 -08:00
Henry Tsang 2e6ae05b4b added back clamp to avoid "OverflowError: Python int too large to convert to C long" 2026-01-16 09:26:19 -08:00
Henry Tsang e94012ac16 lint 2026-01-15 20:40:24 -08:00
Henry Tsang 94f034800e doc 2026-01-15 20:38:42 -08:00
Henry Tsang 08e65188b5 remove 24 clamp 2026-01-15 20:27:15 -08:00
henrylhtsang e34d84057d remove zero logic for right_s and left_s 2026-01-15 19:43:44 -08:00
Henry Tsang ac8885812e flip in_bound to out_bound 2026-01-15 18:19:25 -08:00
henrylhtsang e4ec1ad333 switch from xor to mask_right & ~ mask_left 2026-01-15 17:10:03 -08:00
henrylhtsang 7108d1c854 Add R2P dual bound masking for local attention
Add mask_r2p_dual_bound function using XOR of two bitmasks
to efficiently mask elements outside [col_limit_left, col_limit_right)
range for SM100 local attention.
2026-01-15 15:10:19 -08:00
Henry Tsang 2020964fc8 remove benchmark result, undo changes to benchmark 2026-01-15 14:55:59 -08:00
henrylhtsang a512bd8c7c Add R2P dual bound masking for local attention
Add mask_r2p_dual_bound function using XOR of two bitmasks
to efficiently mask elements outside [col_limit_left, col_limit_right)
range for SM100 local attention.
2026-01-15 14:51:33 -08:00
timmy-feng fffabc3de1 [Cute,Fwd,Sm100] distributed offset calculation for paged KV (#2104)
* fully shard paged KV address calculation across threads

* use t0 indices for static bound checking

* increase tiled copy to full KV row

* shrink predicate tensor

* clarify paged KV divisibility constraints

* increase load register allocation
2026-01-15 12:11:01 -08:00
henrylhtsang 88067b00de baseline local flops 2026-01-15 11:46:15 -08:00
Driss Guessous 68649fb784 [Cute][Flex]Add pack-gqa divmod (#2180) 2026-01-15 10:09:54 -08:00
Driss Guessous 506441a3fc [Cute][Flex] add back in contig (#2177) 2026-01-14 17:04:02 -08:00
jayhshah 13696f2e5e [Cute] update row_max before safe overwrite for online_softmax (#2174)
* update row_max before safe overwrite

* move up row_max_prev
2026-01-13 13:09:43 -08:00
Driss Guessous 4894657ee0 [Cute][Flex] Remove no longer needed contig (#2172) 2026-01-12 14:08:33 -08:00