Remove benchmark result files and undo changes to the benchmark script

This commit is contained in:
Henry Tsang
2026-01-15 14:55:59 -08:00
parent a512bd8c7c
commit 2020964fc8
3 changed files with 94 additions and 158 deletions
-23
View File
@@ -1,23 +0,0 @@
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(512,512), varlen = False, deterministic = False ###
FA Python fwd: 0.283ms, 941.6 TFLOPS
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(1024,1024), varlen = False, deterministic = False ###
FA Python fwd: 0.428ms, 1204.3 TFLOPS
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(2048,2048), varlen = False, deterministic = False ###
FA Python fwd: 0.711ms, 1354.1 TFLOPS
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(4096,4096), varlen = False, deterministic = False ###
FA Python fwd: 1.133ms, 1455.6 TFLOPS
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(512,0), varlen = False, deterministic = False ###
FA Python fwd: 0.208ms, 642.5 TFLOPS
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(1024,0), varlen = False, deterministic = False ###
FA Python fwd: 0.277ms, 932.9 TFLOPS
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(2048,0), varlen = False, deterministic = False ###
FA Python fwd: 0.403ms, 1195.5 TFLOPS
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(4096,0), varlen = False, deterministic = False ###
FA Python fwd: 0.621ms, 1327.8 TFLOPS
-23
View File
@@ -1,23 +0,0 @@
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(512,512), varlen = False, deterministic = False ###
FA Python fwd: 0.304ms, 876.9 TFLOPS
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(1024,1024), varlen = False, deterministic = False ###
FA Python fwd: 0.442ms, 1166.3 TFLOPS
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(2048,2048), varlen = False, deterministic = False ###
FA Python fwd: 0.723ms, 1330.6 TFLOPS
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(4096,4096), varlen = False, deterministic = False ###
FA Python fwd: 1.135ms, 1453.5 TFLOPS
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(512,0), varlen = False, deterministic = False ###
FA Python fwd: 0.232ms, 574.9 TFLOPS
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(1024,0), varlen = False, deterministic = False ###
FA Python fwd: 0.297ms, 869.6 TFLOPS
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(2048,0), varlen = False, deterministic = False ###
FA Python fwd: 0.417ms, 1155.2 TFLOPS
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(4096,0), varlen = False, deterministic = False ###
FA Python fwd: 0.635ms, 1298.7 TFLOPS
+94 -112
View File
@@ -232,7 +232,7 @@ dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
device = 'cuda'
verbose = True
varlen = False
has_backward = False
has_backward = True
page_size = None
# page_size = 128
softcap = 0.0
@@ -263,11 +263,6 @@ time_b = {}
# for headdim in [64, 96, 128]:
# for headdim in [64, 128, 256]:
# for headdim in [64, 96, 128, 192, 256]:
# Local attention window sizes to test
window_sizes_to_test = [512, 1024, 2048, 4096]
# Window types: 'symmetric' for (w, w), 'left' for (w, 0)
window_types_to_test = ['symmetric', 'left']
for headdim in [128]:
# nheads = dim // headdim
nheads = 32 if headdim <= 64 else 16 if headdim <= 192 else 8
@@ -290,6 +285,10 @@ for headdim in [128]:
for batch_size, seqlen in bs_seqlen_vals:
num_splits = 0
# window_size = (-1, -1)
window_size = (None, None)
window_size_fa = (-1, -1)
# window_size = (seqlen // 2 - 1, 0)
pack_gqa = None
# seqlen_q = 64
seqlen_q = seqlen
@@ -326,113 +325,96 @@ for headdim in [128]:
else:
page_table = None
# Only test causal=False for local attention
for causal in [False]:
for causal in [False, True]:
# for causal in [True]:
for window_type in window_types_to_test:
for window_w in window_sizes_to_test:
# Skip window sizes larger than sequence length
if window_w >= seqlen:
continue
# Set window size based on type
if window_type == 'symmetric':
window_size = (window_w, window_w)
window_size_fa = (window_w, window_w)
window_desc = f"symmetric({window_w},{window_w})"
else: # left
window_size = (window_w, 0)
window_size_fa = (window_w, 0)
window_desc = f"left({window_w},0)"
print(f"\n### {headdim = }, {causal = }, {seqlen = }, {batch_size = }, {nheads = }, {nheads_kv = }, window={window_desc}, {varlen = }, {deterministic = } ###")
nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size)
if cudnn is not None:
# if False:
if headdim <= 256 and dtype != torch.float8_e4m3fn:
cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0])
if has_backward and headdim == headdim_v:
cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0])
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None:
# if False:
if not varlen:
m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
else:
m0 = time_fwd(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
time_f[(causal, headdim, batch_size, seqlen, window_desc), "Flash2"] = m0.mean
if has_backward:
time.sleep(1)
if not varlen:
_, m0b = benchmark_backward(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
repeats=repeats, verbose=False, desc='Fav2')
else:
_, m0b = benchmark_backward(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
repeats=repeats, verbose=False, desc='Fav2')
time_b[(causal, headdim, batch_size, seqlen, window_desc), "Flash2"] = m0b.mean
# pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True)
if cudnn is not None:
# if False:
if headdim <= 256 and dtype != torch.float8_e4m3fn:
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN')
time_f[(causal, headdim, batch_size, seqlen, window_desc), "cuDNN"] = m2.mean
if has_backward:
time.sleep(1)
m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
time_b[(causal, headdim, batch_size, seqlen, window_desc), "cuDNN"] = m2b.mean
# pytorch_profiler(cudnn_spda, backward=False)
# pytorch_profiler(cudnn_spda_bwd, backward=False)
print(f"\n### {headdim = }, {causal = }, {seqlen = }, {batch_size = }, {nheads = }, {nheads_kv = }, {varlen = }, {deterministic = } ###")
nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size)
if cudnn is not None:
# if False:
if headdim <= 256 and dtype != torch.float8_e4m3fn:
cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0])
if has_backward and headdim == headdim_v:
cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0])
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None:
# if False:
if not varlen:
m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
else:
m0 = time_fwd(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = m0.mean
if has_backward:
time.sleep(1)
if flash_attn_func_v3 is not None:
if not varlen:
# m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
# pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)
else:
m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)
time_f[(causal, headdim, batch_size, seqlen, window_desc), "Flash3"] = m1.mean
if flash_attn_func_python is not None:
if not varlen:
m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python')
else:
m1_py = time_fwd(flash_attn_varlen_func_python, q_unpad, k_unpad if page_size is None else k_paged, v_unpad if page_size is None else v_paged, cu_seqlens_q, cu_seqlens_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python')
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_v3 is not None and has_backward:
time.sleep(1)
if not varlen:
_, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav3')
else:
_, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
repeats=repeats, verbose=False, desc='Fav3')
time_b[(causal, headdim, batch_size, seqlen, window_desc), "Flash3"] = m1b.mean
time.sleep(1)
# if not varlen:
# pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, deterministic=deterministic, backward=True)
# else:
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True)
# benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy')
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_python is not None and has_backward:
if not varlen:
_, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python')
else:
_, m1b_py = benchmark_backward(flash_attn_varlen_func_python, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python')
if not varlen:
_, m0b = benchmark_backward(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
repeats=repeats, verbose=False, desc='Fav2')
else:
_, m0b = benchmark_backward(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
repeats=repeats, verbose=False, desc='Fav2')
time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean
# pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True)
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None:
# if False:
print(f'FAv2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS')
if has_backward:
print(f'FAv2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS')
if cudnn is not None:
print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS')
if has_backward:
print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS')
if flash_attn_func_v3 is not None:
print(f'FAv3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS')
if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward:
print(f'FAv3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS')
if cudnn is not None:
# if False:
if headdim <= 256 and dtype != torch.float8_e4m3fn:
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN')
time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean
if has_backward:
time.sleep(1)
m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
time_b[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2b.mean
# pytorch_profiler(cudnn_spda, backward=False)
# pytorch_profiler(cudnn_spda_bwd, backward=False)
time.sleep(1)
if flash_attn_func_v3 is not None:
if not varlen:
# m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
# pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)
else:
m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)
time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean
if flash_attn_func_python is not None:
if not varlen:
m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python')
else:
m1_py = time_fwd(flash_attn_varlen_func_python, q_unpad, k_unpad if page_size is None else k_paged, v_unpad if page_size is None else v_paged, cu_seqlens_q, cu_seqlens_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python')
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_v3 is not None and has_backward:
time.sleep(1)
if not varlen:
_, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav3')
else:
_, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
repeats=repeats, verbose=False, desc='Fav3')
time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean
time.sleep(1)
# if not varlen:
# pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, deterministic=deterministic, backward=True)
# else:
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True)
# benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy')
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_python is not None and has_backward:
if not varlen:
_, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python')
else:
_, m1b_py = benchmark_backward(flash_attn_varlen_func_python, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python')
if flash_attn_func_python is not None:
print(f'FA Python fwd: {m1_py.mean * 1e3:.3f}ms, {(nFLOPS / m1_py.mean * 1e-12):.1f} TFLOPS')
if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward:
print(f'FA Python bwd: {m1b_py.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b_py.mean * 1e-12):.1f} TFLOPS')
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None:
# if False:
print(f'FAv2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS')
if has_backward:
print(f'FAv2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS')
if cudnn is not None:
print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS')
if has_backward:
print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS')
if flash_attn_func_v3 is not None:
print(f'FAv3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS')
if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward:
print(f'FAv3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS')
if flash_attn_func_python is not None:
print(f'FA Python fwd: {m1_py.mean * 1e3:.3f}ms, {(nFLOPS / m1_py.mean * 1e-12):.1f} TFLOPS')
if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward:
print(f'FA Python bwd: {m1b_py.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b_py.mean * 1e-12):.1f} TFLOPS')