mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-07-01 21:04:02 -04:00
remove benchmark result, undo changes to benchmark
This commit is contained in:
@@ -1,23 +0,0 @@
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(512,512), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 0.283ms, 941.6 TFLOPS
|
||||
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(1024,1024), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 0.428ms, 1204.3 TFLOPS
|
||||
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(2048,2048), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 0.711ms, 1354.1 TFLOPS
|
||||
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(4096,4096), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 1.133ms, 1455.6 TFLOPS
|
||||
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(512,0), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 0.208ms, 642.5 TFLOPS
|
||||
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(1024,0), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 0.277ms, 932.9 TFLOPS
|
||||
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(2048,0), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 0.403ms, 1195.5 TFLOPS
|
||||
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(4096,0), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 0.621ms, 1327.8 TFLOPS
|
||||
@@ -1,23 +0,0 @@
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(512,512), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 0.304ms, 876.9 TFLOPS
|
||||
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(1024,1024), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 0.442ms, 1166.3 TFLOPS
|
||||
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(2048,2048), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 0.723ms, 1330.6 TFLOPS
|
||||
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=symmetric(4096,4096), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 1.135ms, 1453.5 TFLOPS
|
||||
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(512,0), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 0.232ms, 574.9 TFLOPS
|
||||
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(1024,0), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 0.297ms, 869.6 TFLOPS
|
||||
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(2048,0), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 0.417ms, 1155.2 TFLOPS
|
||||
|
||||
### headdim = 128, causal = False, seqlen = 8192, batch_size = 4, nheads = 16, nheads_kv = 16, window=left(4096,0), varlen = False, deterministic = False ###
|
||||
FA Python fwd: 0.635ms, 1298.7 TFLOPS
|
||||
+94
-112
@@ -232,7 +232,7 @@ dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
|
||||
device = 'cuda'
|
||||
verbose = True
|
||||
varlen = False
|
||||
has_backward = False
|
||||
has_backward = True
|
||||
page_size = None
|
||||
# page_size = 128
|
||||
softcap = 0.0
|
||||
@@ -263,11 +263,6 @@ time_b = {}
|
||||
# for headdim in [64, 96, 128]:
|
||||
# for headdim in [64, 128, 256]:
|
||||
# for headdim in [64, 96, 128, 192, 256]:
|
||||
# Local attention window sizes to test
|
||||
window_sizes_to_test = [512, 1024, 2048, 4096]
|
||||
# Window types: 'symmetric' for (w, w), 'left' for (w, 0)
|
||||
window_types_to_test = ['symmetric', 'left']
|
||||
|
||||
for headdim in [128]:
|
||||
# nheads = dim // headdim
|
||||
nheads = 32 if headdim <= 64 else 16 if headdim <= 192 else 8
|
||||
@@ -290,6 +285,10 @@ for headdim in [128]:
|
||||
|
||||
for batch_size, seqlen in bs_seqlen_vals:
|
||||
num_splits = 0
|
||||
# window_size = (-1, -1)
|
||||
window_size = (None, None)
|
||||
window_size_fa = (-1, -1)
|
||||
# window_size = (seqlen // 2 - 1, 0)
|
||||
pack_gqa = None
|
||||
# seqlen_q = 64
|
||||
seqlen_q = seqlen
|
||||
@@ -326,113 +325,96 @@ for headdim in [128]:
|
||||
else:
|
||||
page_table = None
|
||||
|
||||
# Only test causal=False for local attention
|
||||
for causal in [False]:
|
||||
for causal in [False, True]:
|
||||
# for causal in [True]:
|
||||
for window_type in window_types_to_test:
|
||||
for window_w in window_sizes_to_test:
|
||||
# Skip window sizes larger than sequence length
|
||||
if window_w >= seqlen:
|
||||
continue
|
||||
|
||||
# Set window size based on type
|
||||
if window_type == 'symmetric':
|
||||
window_size = (window_w, window_w)
|
||||
window_size_fa = (window_w, window_w)
|
||||
window_desc = f"symmetric({window_w},{window_w})"
|
||||
else: # left
|
||||
window_size = (window_w, 0)
|
||||
window_size_fa = (window_w, 0)
|
||||
window_desc = f"left({window_w},0)"
|
||||
|
||||
print(f"\n### {headdim = }, {causal = }, {seqlen = }, {batch_size = }, {nheads = }, {nheads_kv = }, window={window_desc}, {varlen = }, {deterministic = } ###")
|
||||
nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size)
|
||||
if cudnn is not None:
|
||||
# if False:
|
||||
if headdim <= 256 and dtype != torch.float8_e4m3fn:
|
||||
cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0])
|
||||
if has_backward and headdim == headdim_v:
|
||||
cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0])
|
||||
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None:
|
||||
# if False:
|
||||
if not varlen:
|
||||
m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
|
||||
else:
|
||||
m0 = time_fwd(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
|
||||
time_f[(causal, headdim, batch_size, seqlen, window_desc), "Flash2"] = m0.mean
|
||||
if has_backward:
|
||||
time.sleep(1)
|
||||
if not varlen:
|
||||
_, m0b = benchmark_backward(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
|
||||
repeats=repeats, verbose=False, desc='Fav2')
|
||||
else:
|
||||
_, m0b = benchmark_backward(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
|
||||
repeats=repeats, verbose=False, desc='Fav2')
|
||||
time_b[(causal, headdim, batch_size, seqlen, window_desc), "Flash2"] = m0b.mean
|
||||
# pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True)
|
||||
|
||||
if cudnn is not None:
|
||||
# if False:
|
||||
if headdim <= 256 and dtype != torch.float8_e4m3fn:
|
||||
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
|
||||
m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN')
|
||||
time_f[(causal, headdim, batch_size, seqlen, window_desc), "cuDNN"] = m2.mean
|
||||
if has_backward:
|
||||
time.sleep(1)
|
||||
m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
|
||||
time_b[(causal, headdim, batch_size, seqlen, window_desc), "cuDNN"] = m2b.mean
|
||||
# pytorch_profiler(cudnn_spda, backward=False)
|
||||
# pytorch_profiler(cudnn_spda_bwd, backward=False)
|
||||
print(f"\n### {headdim = }, {causal = }, {seqlen = }, {batch_size = }, {nheads = }, {nheads_kv = }, {varlen = }, {deterministic = } ###")
|
||||
nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size)
|
||||
if cudnn is not None:
|
||||
# if False:
|
||||
if headdim <= 256 and dtype != torch.float8_e4m3fn:
|
||||
cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0])
|
||||
if has_backward and headdim == headdim_v:
|
||||
cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0])
|
||||
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None:
|
||||
# if False:
|
||||
if not varlen:
|
||||
m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
|
||||
else:
|
||||
m0 = time_fwd(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
|
||||
time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = m0.mean
|
||||
if has_backward:
|
||||
time.sleep(1)
|
||||
if flash_attn_func_v3 is not None:
|
||||
if not varlen:
|
||||
# m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
|
||||
m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
|
||||
# pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)
|
||||
else:
|
||||
m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
|
||||
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)
|
||||
time_f[(causal, headdim, batch_size, seqlen, window_desc), "Flash3"] = m1.mean
|
||||
if flash_attn_func_python is not None:
|
||||
if not varlen:
|
||||
m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python')
|
||||
else:
|
||||
m1_py = time_fwd(flash_attn_varlen_func_python, q_unpad, k_unpad if page_size is None else k_paged, v_unpad if page_size is None else v_paged, cu_seqlens_q, cu_seqlens_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python')
|
||||
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_v3 is not None and has_backward:
|
||||
time.sleep(1)
|
||||
if not varlen:
|
||||
_, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav3')
|
||||
else:
|
||||
_, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
|
||||
repeats=repeats, verbose=False, desc='Fav3')
|
||||
time_b[(causal, headdim, batch_size, seqlen, window_desc), "Flash3"] = m1b.mean
|
||||
time.sleep(1)
|
||||
# if not varlen:
|
||||
# pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, deterministic=deterministic, backward=True)
|
||||
# else:
|
||||
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True)
|
||||
# benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy')
|
||||
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_python is not None and has_backward:
|
||||
if not varlen:
|
||||
_, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python')
|
||||
else:
|
||||
_, m1b_py = benchmark_backward(flash_attn_varlen_func_python, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python')
|
||||
if not varlen:
|
||||
_, m0b = benchmark_backward(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
|
||||
repeats=repeats, verbose=False, desc='Fav2')
|
||||
else:
|
||||
_, m0b = benchmark_backward(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
|
||||
repeats=repeats, verbose=False, desc='Fav2')
|
||||
time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean
|
||||
# pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True)
|
||||
|
||||
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None:
|
||||
# if False:
|
||||
print(f'FAv2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS')
|
||||
if has_backward:
|
||||
print(f'FAv2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS')
|
||||
if cudnn is not None:
|
||||
print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS')
|
||||
if has_backward:
|
||||
print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS')
|
||||
if flash_attn_func_v3 is not None:
|
||||
print(f'FAv3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS')
|
||||
if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward:
|
||||
print(f'FAv3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS')
|
||||
if cudnn is not None:
|
||||
# if False:
|
||||
if headdim <= 256 and dtype != torch.float8_e4m3fn:
|
||||
time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark
|
||||
m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN')
|
||||
time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean
|
||||
if has_backward:
|
||||
time.sleep(1)
|
||||
m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
|
||||
time_b[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2b.mean
|
||||
# pytorch_profiler(cudnn_spda, backward=False)
|
||||
# pytorch_profiler(cudnn_spda_bwd, backward=False)
|
||||
time.sleep(1)
|
||||
if flash_attn_func_v3 is not None:
|
||||
if not varlen:
|
||||
# m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, cache_leftpad = leftpad_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
|
||||
m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
|
||||
# pytorch_profiler(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa)
|
||||
else:
|
||||
m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size_fa, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
|
||||
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits)
|
||||
time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean
|
||||
if flash_attn_func_python is not None:
|
||||
if not varlen:
|
||||
m1_py = time_fwd(flash_attn_func_python, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, learnable_sink=sinks, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python')
|
||||
else:
|
||||
m1_py = time_fwd(flash_attn_varlen_func_python, q_unpad, k_unpad if page_size is None else k_paged, v_unpad if page_size is None else v_paged, cu_seqlens_q, cu_seqlens_k, page_table=page_table, causal=causal, window_size=window_size, softcap=softcap, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3 python')
|
||||
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_v3 is not None and has_backward:
|
||||
time.sleep(1)
|
||||
if not varlen:
|
||||
_, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, softcap=softcap, repeats=repeats, verbose=False, desc='Fav3')
|
||||
else:
|
||||
_, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
|
||||
repeats=repeats, verbose=False, desc='Fav3')
|
||||
time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean
|
||||
time.sleep(1)
|
||||
# if not varlen:
|
||||
# pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, deterministic=deterministic, backward=True)
|
||||
# else:
|
||||
# pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True)
|
||||
# benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy')
|
||||
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func_python is not None and has_backward:
|
||||
if not varlen:
|
||||
_, m1b_py = benchmark_backward(flash_attn_func_python, q, k, v, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python')
|
||||
else:
|
||||
_, m1b_py = benchmark_backward(flash_attn_varlen_func_python, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, causal=causal, softcap=softcap, deterministic=deterministic, repeats=repeats, verbose=False, desc='Fav4 python')
|
||||
|
||||
if flash_attn_func_python is not None:
|
||||
print(f'FA Python fwd: {m1_py.mean * 1e3:.3f}ms, {(nFLOPS / m1_py.mean * 1e-12):.1f} TFLOPS')
|
||||
if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward:
|
||||
print(f'FA Python bwd: {m1b_py.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b_py.mean * 1e-12):.1f} TFLOPS')
|
||||
if dtype != torch.float8_e4m3fn and headdim == headdim_v and flash_attn_func is not None:
|
||||
# if False:
|
||||
print(f'FAv2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS')
|
||||
if has_backward:
|
||||
print(f'FAv2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS')
|
||||
if cudnn is not None:
|
||||
print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS')
|
||||
if has_backward:
|
||||
print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS')
|
||||
if flash_attn_func_v3 is not None:
|
||||
print(f'FAv3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS')
|
||||
if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward:
|
||||
print(f'FAv3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS')
|
||||
|
||||
if flash_attn_func_python is not None:
|
||||
print(f'FA Python fwd: {m1_py.mean * 1e3:.3f}ms, {(nFLOPS / m1_py.mean * 1e-12):.1f} TFLOPS')
|
||||
if dtype != torch.float8_e4m3fn and headdim == headdim_v and has_backward:
|
||||
print(f'FA Python bwd: {m1b_py.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b_py.mean * 1e-12):.1f} TFLOPS')
|
||||
|
||||
Reference in New Issue
Block a user