mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-06-30 21:07:55 -04:00
FlashAttention-2 release
This commit is contained in:
+2
-2
@@ -1,3 +1,3 @@
|
||||
[submodule "csrc/flash_attn/cutlass"]
|
||||
path = csrc/flash_attn/cutlass
|
||||
[submodule "csrc/cutlass"]
|
||||
path = csrc/cutlass
|
||||
url = https://github.com/NVIDIA/cutlass.git
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
Tri Dao, trid@stanford.edu
|
||||
Dan Fu, danfu@cs.stanford.edu
|
||||
Tri Dao, trid@cs.stanford.edu
|
||||
@@ -2,8 +2,10 @@ recursive-include csrc *.cu
|
||||
recursive-include csrc *.h
|
||||
recursive-include csrc *.cuh
|
||||
recursive-include csrc *.cpp
|
||||
recursive-include csrc *.hpp
|
||||
|
||||
recursive-include flash_attn *.cu
|
||||
recursive-include flash_attn *.h
|
||||
recursive-include flash_attn *.cuh
|
||||
recursive-include flash_attn *.cpp
|
||||
recursive-include flash_attn *.hpp
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# FlashAttention
|
||||
This repository provides the official implementation of FlashAttention from the
|
||||
following paper.
|
||||
This repository provides the official implementation of FlashAttention and
|
||||
FlashAttention-2 from the
|
||||
following papers.
|
||||
|
||||
**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**
|
||||
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
|
||||
@@ -8,39 +9,22 @@ Paper: https://arxiv.org/abs/2205.14135
|
||||
IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.
|
||||

|
||||
|
||||
**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning**
|
||||
Tri Dao
|
||||
|
||||
Paper: https://tridao.me/publications/flash2/flash2.pdf
|
||||
|
||||

|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
We've been very happy to see FlashAttention being widely adopted in such a short
|
||||
time after its release. This [page](https://github.com/HazyResearch/flash-attention/blob/main/usage.md)
|
||||
time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md)
|
||||
contains a partial list of places where FlashAttention is being used.
|
||||
|
||||
## Full model code and training script
|
||||
|
||||
We have released the full GPT model
|
||||
[implementation](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
|
||||
We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
|
||||
cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
|
||||
compared to the baseline implementation from Huggingface, reaching up to 189
|
||||
TFLOPs/sec per A100, equivalent to 60.6\% model FLOPs utilization (we don't need
|
||||
any activation checkpointing).
|
||||
|
||||
We also include a training
|
||||
[script](https://github.com/HazyResearch/flash-attention/tree/main/training) to
|
||||
train GPT2 on Openwebtext and GPT3 on The Pile.
|
||||
|
||||
## Triton implementation of FlashAttention
|
||||
|
||||
Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
|
||||
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
|
||||
|
||||
As Triton is a higher-level language than CUDA, it might be easier to understand
|
||||
and experiment with. The notations in the Triton implementation are also closer
|
||||
to what's used in our paper.
|
||||
|
||||
We also have an experimental implementation in Triton that support attention
|
||||
bias (e.g. ALiBi):
|
||||
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py
|
||||
|
||||
FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
|
||||
Please cite and credit FlashAttention if you use it.
|
||||
|
||||
## Installation and features
|
||||
|
||||
@@ -53,125 +37,116 @@ We recommend the
|
||||
container from Nvidia, which has all the required tools to install FlashAttention.
|
||||
|
||||
To install:
|
||||
1. Make sure that PyTorch is installed.
|
||||
2. Make sure that `packaging` is installed (`pip install packaging`)
|
||||
3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
|
||||
--version` then `echo $?` should return exit code 0). If not (sometimes `ninja
|
||||
--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
|
||||
`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`
|
||||
compiling can take a very long time (2h) since it does not use multiple CPU
|
||||
cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
|
||||
4. Then:
|
||||
```sh
|
||||
pip install flash-attn
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Alternatively you can compile from source:
|
||||
```
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
Interface: `src/flash_attention.py`
|
||||
Interface: `src/flash_attention_interface.py`
|
||||
|
||||
To run the benchmark against PyTorch standard attention:
|
||||
```
|
||||
PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
|
||||
```
|
||||
|
||||
FlashAttention currently supports:
|
||||
1. Turing, Ampere, Ada, or Hopper GPUs (e.g., H100, A100, RTX 3090, T4, RTX 2080).
|
||||
2. fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
|
||||
3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ...,
|
||||
128). Head dim > 64 backward requires A100 or H100.
|
||||
|
||||
Our tentative roadmap:
|
||||
1. ~~[Jun 2022] Make package pip-installable~~[Done, thanks to lucidrains].
|
||||
2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done].
|
||||
3. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done].
|
||||
4. ~~[Jun 2022] Support bf16~~[Done].
|
||||
5. ~~[Jul 2022] Implement cross-attention~~[Done].
|
||||
6. ~~[Jul 2022] Support head dimension 128~~[Done].
|
||||
7. ~~[Aug 2022] Fuse rotary embedding~~[Done].
|
||||
8. ~~[Mar 2023] Support SM90 GPUs (H100)~~[Done].
|
||||
FlashAttention-2 currently supports:
|
||||
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
|
||||
GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
|
||||
GPUs for now.
|
||||
2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
|
||||
3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800.
|
||||
|
||||
|
||||
## How to use FlashAttention
|
||||
|
||||
Here's a simple example:
|
||||
```python
|
||||
import torch
|
||||
from flash_attn.flash_attention import FlashMHA
|
||||
|
||||
# Replace this with your correct GPU device
|
||||
device = "cuda:0"
|
||||
|
||||
# Create attention layer. This is similar to torch.nn.MultiheadAttention,
|
||||
# and it includes the input and output linear layers
|
||||
flash_mha = FlashMHA(
|
||||
embed_dim=128, # total channels (= num_heads * head_dim)
|
||||
num_heads=8, # number of heads
|
||||
device=device,
|
||||
dtype=torch.float16,
|
||||
)
|
||||
|
||||
# Run forward pass with dummy data
|
||||
x = torch.randn(
|
||||
(64, 256, 128), # (batch, seqlen, embed_dim)
|
||||
device=device,
|
||||
dtype=torch.float16
|
||||
)
|
||||
|
||||
output = flash_mha(x)[0]
|
||||
The main functions implement scaled dot product attention (softmax(Q @ K^T *
|
||||
softmax_scale) @ V):
|
||||
```
|
||||
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
|
||||
```
|
||||
|
||||
Alternatively, you can import the inner attention layer only (so that the input
|
||||
and output linear layers are not included):
|
||||
```python
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
|
||||
# Create the nn.Module
|
||||
flash_attention = FlashAttention()
|
||||
```
|
||||
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
If Q, K, V are already stacked into 1 tensor, this function will be faster than
|
||||
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
|
||||
of the gradients of Q, K, V.
|
||||
Arguments:
|
||||
qkv: (batch_size, seqlen, 3, nheads, headdim)
|
||||
dropout_p: float. Dropout probability.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
Return:
|
||||
out: (batch_size, seqlen, nheads, headdim).
|
||||
```
|
||||
|
||||
Or, if you need more fine-grained control, you can import one of the lower-level
|
||||
functions (this is more similar to the `torch.nn.functional` style):
|
||||
```python
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
|
||||
```
|
||||
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
|
||||
than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
|
||||
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
||||
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
||||
|
||||
# or
|
||||
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
|
||||
|
||||
# etc.
|
||||
Arguments:
|
||||
q: (batch_size, seqlen, nheads, headdim)
|
||||
k: (batch_size, seqlen, nheads_k, headdim)
|
||||
v: (batch_size, seqlen, nheads_k, headdim)
|
||||
dropout_p: float. Dropout probability.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
Return:
|
||||
out: (batch_size, seqlen, nheads, headdim).
|
||||
```
|
||||
|
||||
There are also separate Python files with various FlashAttention extensions:
|
||||
```python
|
||||
# Import the triton implementation (torch.nn.functional version only)
|
||||
from flash_attn.flash_attn_triton import flash_attn_func
|
||||
To see how these functions are used in a multi-head attention layer (which
|
||||
includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).
|
||||
|
||||
# Import block sparse attention (nn.Module version)
|
||||
from flash_attn.flash_blocksparse_attention import FlashBlocksparseMHA, FlashBlocksparseAttention
|
||||
## Upgrading from FlashAttention (1.x) to FlashAttention-2
|
||||
|
||||
# Import block sparse attention (torch.nn.functional version)
|
||||
from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
|
||||
These functions have been renamed:
|
||||
- `flash_attn_unpadded_func` -> `flash_attn_varlen_func`
|
||||
- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func`
|
||||
- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func`
|
||||
|
||||
If the inputs have the same sequence lengths in the same batch, it is simpler
|
||||
and faster to use these functions:
|
||||
```
|
||||
flash_attn_qkvpacked_func(qkv, dropout_p, softmax_scale=None, causal=False)
|
||||
```
|
||||
```
|
||||
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
|
||||
```
|
||||
|
||||
## Speedup and Memory Savings
|
||||
## Performance
|
||||
|
||||
We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
|
||||
|
||||
We currently have benchmarks for these GPUs:
|
||||
* [A100](#a100)
|
||||
* [RTX 3090](#rtx-3090)
|
||||
* [T4](#t4)
|
||||
* [H100](#h100)
|
||||
<!-- * [RTX 3090](#rtx-3090) -->
|
||||
<!-- * [T4](#t4) -->
|
||||
|
||||
### A100
|
||||
|
||||
We display FlashAttention speedup using these parameters (similar to BERT-base):
|
||||
* Batch size 8
|
||||
* Head dimension 64
|
||||
* 12 attention heads
|
||||
|
||||
Our graphs show sequence lengths between 128 and 4096 (when standard attention runs out of memory on an A100), but FlashAttention can scale up to sequence length 64K.
|
||||
We display FlashAttention speedup using these parameters:
|
||||
* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads).
|
||||
* Sequence length 512, 1k, 2k, 4k, 8k, 16k.
|
||||
* Batch size set to 16k / seqlen.
|
||||
|
||||
#### Speedup
|
||||
|
||||

|
||||
|
||||
We generally see 2-4X speedup at sequence lengths between 128 and 4K, and we see more speedup when using dropout and masking, since we fuse the kernels.
|
||||
At sequence lengths that are popular with language models like 512 and 1K, we see speedups up to 4X when using dropout and masking.
|
||||

|
||||
|
||||
#### Memory
|
||||
|
||||
@@ -182,38 +157,37 @@ Memory savings are proportional to sequence length -- since standard attention h
|
||||
We see 10X memory savings at sequence length 2K, and 20X at 4K.
|
||||
As a result, FlashAttention can scale to much longer sequence lengths.
|
||||
|
||||
#### Head Dimension 128
|
||||
### H100
|
||||
|
||||

|
||||

|
||||
|
||||
We show speedup with head dimension 128.
|
||||
Here we show batch size 16 with 12 heads.
|
||||
Speedup is less than with the smaller head sizes, since we have to make the block size smaller in the tiling.
|
||||
But speedup is still significant, especially with a causal mask.
|
||||
## Full model code and training script
|
||||
|
||||
### RTX 3090
|
||||
We have released the full GPT model
|
||||
[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py).
|
||||
We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
|
||||
cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
|
||||
compared to the baseline implementation from Huggingface, reaching up to 225
|
||||
TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need
|
||||
any activation checkpointing).
|
||||
|
||||
For the RTX 3090, we use batch size 12 with 12 attention heads.
|
||||
Memory savings are the same as on an A100, so we'll only show speedup here.
|
||||
We also include a training
|
||||
[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to
|
||||
train GPT2 on Openwebtext and GPT3 on The Pile.
|
||||
|
||||

|
||||
## Triton implementation of FlashAttention
|
||||
|
||||
We see slightly higher speedups (between 2.5-4.5x) on the GTX 3090, since memory bandwidth on the GDDR6X is lower than A100 HBM (~900 GB/s vs. ~1.5 TB/s).
|
||||
Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
|
||||
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
|
||||
|
||||
### T4
|
||||
As Triton is a higher-level language than CUDA, it might be easier to understand
|
||||
and experiment with. The notations in the Triton implementation are also closer
|
||||
to what's used in our paper.
|
||||
|
||||
We again use batch size 12 with 12 attention heads.
|
||||
We also have an experimental implementation in Triton that support attention
|
||||
bias (e.g. ALiBi):
|
||||
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
|
||||
|
||||

|
||||
|
||||
T4 SRAM is smaller than the newer GPUs (64 KB), so we see less speedup (we need to make the block sizes smaller, so we end up doing more R/W).
|
||||
This matches the IO complexity analysis from section 3.2 of [our paper](https://arxiv.org/abs/2205.14135).
|
||||
|
||||
T4 GPUs are commonly used for inference, so we also measure speedup on the forward pass only (note that these are not directly comparable to the graphs above):
|
||||
|
||||

|
||||
|
||||
We see speedups between 2.5x-4.5x on the forward pass.
|
||||
|
||||
## Tests
|
||||
We test that FlashAttention produces the same output and gradient as a reference
|
||||
@@ -228,21 +202,10 @@ pytest -q -s tests/test_flash_attn.py
|
||||
```
|
||||
## When you encounter issues
|
||||
|
||||
This alpha release of FlashAttention contains code written for a research
|
||||
project to validate ideas on speeding up attention.
|
||||
We have tested it on several models (BERT, GPT2, ViT).
|
||||
However, there might still be bugs in the implementation that we hope to iron
|
||||
out in the next few months.
|
||||
This new release of FlashAttention-2 have been tested on several GPT-style
|
||||
models, mostly on A100 GPUs.
|
||||
|
||||
If you encounter any of these bugs, please open a respective GitHub Issue!
|
||||
|
||||
## Acknowledgments
|
||||
Our implementation uses Apex's
|
||||
[FMHA](https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha) code
|
||||
as a starting point.
|
||||
|
||||
We thank [Young-Jun Ko](https://yjk21.github.io/) for the in-depth explanation of his FMHA implementation
|
||||
and for his thoughtful answers to our questions about CUDA.
|
||||
If you encounter any of bugs, please open a respective GitHub Issue!
|
||||
|
||||
## Citation
|
||||
If you use this codebase, or otherwise found our work valuable, please cite:
|
||||
@@ -253,4 +216,9 @@ If you use this codebase, or otherwise found our work valuable, please cite:
|
||||
booktitle={Advances in Neural Information Processing Systems},
|
||||
year={2022}
|
||||
}
|
||||
@article{dao2023flashattention2,
|
||||
title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning,
|
||||
author={Dao, Tri},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 369 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 308 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 2.6 MiB |
+145
-21
@@ -6,11 +6,21 @@ import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from flash_attn.utils.benchmark import benchmark_forward, benchmark_all, pytorch_profiler
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||
# from flash_attn.triton.fused_attention import attention as attention
|
||||
from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
|
||||
from flash_attn.flash_attn_triton_og import attention as attention_og
|
||||
# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
|
||||
from src.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||
# # from flash_attn.triton.fused_attention import attention as attention
|
||||
# from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
|
||||
# from flash_attn.flash_attn_triton_og import attention as attention_og
|
||||
|
||||
# from triton.ops.flash_attention import attention as attention_triton
|
||||
|
||||
try:
|
||||
from fav2 import flash_attn_qkvpacked_func as fav2_qkvpacked_func
|
||||
from fav2 import flash_attn_kvpacked_func as fav2_kvpacked_func
|
||||
except ImportError:
|
||||
fav2_qkvpacked_func = None
|
||||
fav2_kvpacked_func = None
|
||||
|
||||
try:
|
||||
from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
|
||||
@@ -71,16 +81,18 @@ def attention_megatron(qkv):
|
||||
torch.manual_seed(0)
|
||||
repeats = 30
|
||||
batch_size = 2
|
||||
seqlen = 4096
|
||||
seqlen = 8192
|
||||
nheads = 12
|
||||
headdim = 128
|
||||
# nheads = 24
|
||||
# headdim = 64
|
||||
# batch_size = 64
|
||||
# seqlen = 512
|
||||
# nheads = 8
|
||||
# headdim = 128
|
||||
dropout_p = 0.0
|
||||
causal = True
|
||||
dtype = torch.bfloat16
|
||||
dropout_p = 0.1
|
||||
causal = False
|
||||
dtype = torch.float16
|
||||
device = 'cuda'
|
||||
|
||||
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
|
||||
@@ -88,18 +100,130 @@ qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=d
|
||||
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
|
||||
device=qkv.device)
|
||||
|
||||
benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b s) ...'),
|
||||
cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
|
||||
benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
|
||||
repeats=repeats, desc='PyTorch Attention')
|
||||
# qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
|
||||
# benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad,
|
||||
# cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
|
||||
# pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad,
|
||||
# cu_seqlens, seqlen, dropout_p, causal=causal, backward=True)
|
||||
# if fav2_qkvpacked_func is not None:
|
||||
# benchmark_all(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
|
||||
# pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
|
||||
|
||||
benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton')
|
||||
pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True)
|
||||
# for dropout_p in [0.1, 0.0]:
|
||||
# for causal in [False, True]:
|
||||
# print(f"### {dropout_p = }, {causal = } ###")
|
||||
# pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True)
|
||||
|
||||
q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
|
||||
requires_grad=True) for _ in range(3)]
|
||||
benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')
|
||||
# pytorch_profiler(attention, q, k, v, 1.0, backward=True)
|
||||
# nheads_k = 2
|
||||
# q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
|
||||
# kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype,
|
||||
# requires_grad=True)
|
||||
# if fav2_kvpacked_func is not None:
|
||||
# benchmark_all(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, repeats=repeats, desc='Fav2')
|
||||
# pytorch_profiler(fav2_kvpacked_func, q, kv, dropout_p, causal=causal, backward=True)
|
||||
|
||||
if scaled_upper_triang_masked_softmax is not None:
|
||||
benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')
|
||||
# dropout_p = 0.0
|
||||
# causal = False
|
||||
# benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
|
||||
# repeats=repeats, desc='PyTorch Attention')
|
||||
|
||||
# benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton')
|
||||
# pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True)
|
||||
|
||||
# q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
|
||||
# requires_grad=True) for _ in range(3)]
|
||||
# benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')
|
||||
# # pytorch_profiler(attention, q, k, v, 1.0, backward=True)
|
||||
|
||||
# if scaled_upper_triang_masked_softmax is not None:
|
||||
# benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')
|
||||
|
||||
# from src.ops.fftconv import fftconv_func
|
||||
|
||||
# dim = nheads * headdim
|
||||
# u = torch.randn(batch_size, dim, seqlen, device=device, dtype=dtype, requires_grad=True)
|
||||
# k = torch.randn(dim, seqlen, device=device, requires_grad=True)
|
||||
# D = torch.randn(dim, device=device, requires_grad=True)
|
||||
# benchmark_all(fftconv_func, u, k, D, repeats=repeats, desc='FFTConv')
|
||||
# pytorch_profiler(fftconv_func, u, k, D, backward=True)
|
||||
# pytorch_profiler(torch.fft.rfft, u.float())
|
||||
|
||||
flops = 4 * batch_size * seqlen ** 2 * nheads * headdim
|
||||
ideal_a100_time = flops / 312 / 1e9
|
||||
print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms")
|
||||
|
||||
|
||||
def time_fwd_bwd(func, *args, **kwargs):
|
||||
time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
|
||||
return time_f[1].mean, time_b[1].mean
|
||||
|
||||
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
|
||||
causal_vals = [False, True]
|
||||
headdim_vals = [64, 128]
|
||||
dim = 2048
|
||||
dropout_p = 0.0
|
||||
|
||||
time_f = {}
|
||||
time_b = {}
|
||||
for causal in causal_vals:
|
||||
for headdim in headdim_vals:
|
||||
for batch_size, seqlen in bs_seqlen_vals:
|
||||
nheads = dim // headdim
|
||||
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
|
||||
requires_grad=True)
|
||||
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
|
||||
device=qkv.device)
|
||||
qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True)
|
||||
f, b = time_fwd_bwd(
|
||||
flash_attn_varlen_qkvpacked_func, qkv_unpad, cu_seqlens, seqlen, dropout_p,
|
||||
causal=causal, repeats=repeats, verbose=False
|
||||
)
|
||||
time_f[(causal, headdim, batch_size, seqlen), "Flash"] = f
|
||||
time_b[(causal, headdim, batch_size, seqlen), "Flash"] = b
|
||||
|
||||
qkv = qkv.detach().requires_grad_(True)
|
||||
f, b = time_fwd_bwd(
|
||||
fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
|
||||
)
|
||||
time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = f
|
||||
time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = b
|
||||
|
||||
# q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
|
||||
# requires_grad=True) for _ in range(3)]
|
||||
# # Try both values of sequence_parallel and pick the faster one
|
||||
# f, b = time_fwd_bwd(
|
||||
# attention_triton, q, k, v, causal, headdim**(-0.5),
|
||||
# False, repeats=repeats, verbose=False
|
||||
# )
|
||||
# _, b0 = time_fwd_bwd(
|
||||
# attention_triton, q, k, v, causal, headdim**(-0.5),
|
||||
# True, repeats=repeats, verbose=False
|
||||
# )
|
||||
# time_f[(causal, headdim, batch_size, seqlen), "Triton"] = f
|
||||
# time_b[(causal, headdim, batch_size, seqlen), "Triton"] = min(b, b0)
|
||||
|
||||
if seqlen <= 8 * 1024:
|
||||
qkv = qkv.detach().requires_grad_(True)
|
||||
f, b = time_fwd_bwd(
|
||||
attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False
|
||||
)
|
||||
else:
|
||||
f, b = float('nan'), float('nan')
|
||||
time_f[(causal, headdim, batch_size, seqlen), "Pytorch"] = f
|
||||
time_b[(causal, headdim, batch_size, seqlen), "Pytorch"] = b
|
||||
|
||||
# q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype,
|
||||
# requires_grad=True) for _ in range(3)]
|
||||
# import xformers.ops as xops
|
||||
# f, b = time_fwd_bwd(
|
||||
# xops.memory_efficient_attention, q, k, v,
|
||||
# attn_bias=xops.LowerTriangularMask() if causal else None,
|
||||
# op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
|
||||
# )
|
||||
# time_f[(causal, headdim, batch_size, seqlen), "xformers"] = f
|
||||
# time_b[(causal, headdim, batch_size, seqlen), "xformers"] = b
|
||||
|
||||
|
||||
import pickle
|
||||
with open('flash2_attn_time_h100.plk', 'wb') as fp:
|
||||
pickle.dump((time_f, time_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
@@ -8,7 +8,7 @@ from einops import rearrange, repeat
|
||||
|
||||
from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||
|
||||
|
||||
def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
|
||||
@@ -62,7 +62,7 @@ qkv_unpad = rearrange(Wqkv(x_unpad), 'nnz (t h d) -> nnz t h d', t=3,
|
||||
h=nheads).detach().requires_grad_()
|
||||
qkv = rearrange(Wqkv(x), 'b s (t h d) -> b s t h d', t=3, h=nheads).detach().requires_grad_()
|
||||
|
||||
fn = lambda qkv_unpad: flash_attn_unpadded_qkvpacked_func(
|
||||
fn = lambda qkv_unpad: flash_attn_varlen_qkvpacked_func(
|
||||
qkv_unpad, cu_seqlens, max_seqlen_in_batch, dropout_p, causal=causal
|
||||
)
|
||||
benchmark_all(fn, qkv_unpad, repeats=repeats, desc='FlashAttention')
|
||||
|
||||
Submodule
+1
Submodule csrc/cutlass added at c4f6b8c6bc
Submodule csrc/flash_attn/cutlass deleted from 319a389f42
@@ -0,0 +1,912 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include "flash.h"
|
||||
#include "static_switch.h"
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
|
||||
void set_params_fprop(Flash_fwd_params ¶ms,
|
||||
// sizes
|
||||
const size_t b,
|
||||
const size_t seqlen_q,
|
||||
const size_t seqlen_k,
|
||||
const size_t seqlen_q_rounded,
|
||||
const size_t seqlen_k_rounded,
|
||||
const size_t h,
|
||||
const size_t h_k,
|
||||
const size_t d,
|
||||
const size_t d_rounded,
|
||||
// device pointers
|
||||
const at::Tensor q,
|
||||
const at::Tensor k,
|
||||
const at::Tensor v,
|
||||
at::Tensor out,
|
||||
void *cu_seqlens_q_d,
|
||||
void *cu_seqlens_k_d,
|
||||
void *p_d,
|
||||
void *softmax_lse_d,
|
||||
float p_dropout,
|
||||
float softmax_scale,
|
||||
bool is_causal) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.is_bf16 = q.dtype() == torch::kBFloat16;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.q_ptr = q.data_ptr();
|
||||
params.k_ptr = k.data_ptr();
|
||||
params.v_ptr = v.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.q_row_stride = q.stride(-3);
|
||||
params.k_row_stride = k.stride(-3);
|
||||
params.v_row_stride = v.stride(-3);
|
||||
params.q_head_stride = q.stride(-2);
|
||||
params.k_head_stride = k.stride(-2);
|
||||
params.v_head_stride = v.stride(-2);
|
||||
params.o_ptr = out.data_ptr();
|
||||
params.o_row_stride = out.stride(-3);
|
||||
params.o_head_stride = out.stride(-2);
|
||||
|
||||
if (cu_seqlens_q_d == nullptr) {
|
||||
params.q_batch_stride = q.stride(0);
|
||||
params.k_batch_stride = k.stride(0);
|
||||
params.v_batch_stride = v.stride(0);
|
||||
params.o_batch_stride = out.stride(0);
|
||||
}
|
||||
|
||||
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
|
||||
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
|
||||
|
||||
// P = softmax(QK^T)
|
||||
params.p_ptr = p_d;
|
||||
|
||||
// Softmax sum
|
||||
params.softmax_lse_ptr = softmax_lse_d;
|
||||
|
||||
// Set the dimensions.
|
||||
params.b = b;
|
||||
params.h = h;
|
||||
params.h_k = h_k;
|
||||
params.h_h_k_ratio = h / h_k;
|
||||
params.seqlen_q = seqlen_q;
|
||||
params.seqlen_k = seqlen_k;
|
||||
params.seqlen_q_rounded = seqlen_q_rounded;
|
||||
params.seqlen_k_rounded = seqlen_k_rounded;
|
||||
params.d = d;
|
||||
params.d_rounded = d_rounded;
|
||||
|
||||
// Set the different scale values.
|
||||
params.scale_softmax = softmax_scale;
|
||||
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
|
||||
|
||||
// Set this to probability of keeping an element to simplify things.
|
||||
params.p_dropout = 1.f - p_dropout;
|
||||
// Convert p from float to int so we don't have to convert the random uint to float to compare.
|
||||
// [Minor] We want to round down since when we do the comparison we use <= instead of <
|
||||
// params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
|
||||
// params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
|
||||
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
|
||||
params.rp_dropout = 1.f / params.p_dropout;
|
||||
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
|
||||
TORCH_CHECK(p_dropout < 1.f);
|
||||
|
||||
params.is_causal = is_causal;
|
||||
}
|
||||
|
||||
void set_params_dgrad(Flash_bwd_params ¶ms,
|
||||
// sizes
|
||||
const size_t b,
|
||||
const size_t seqlen_q,
|
||||
const size_t seqlen_k,
|
||||
const size_t seqlen_q_rounded,
|
||||
const size_t seqlen_k_rounded,
|
||||
const size_t h,
|
||||
const size_t h_k,
|
||||
const size_t d,
|
||||
const size_t d_rounded,
|
||||
// device pointers
|
||||
const at::Tensor q,
|
||||
const at::Tensor k,
|
||||
const at::Tensor v,
|
||||
const at::Tensor out,
|
||||
const at::Tensor dout,
|
||||
at::Tensor dq,
|
||||
at::Tensor dk,
|
||||
at::Tensor dv,
|
||||
void *cu_seqlens_q_d,
|
||||
void *cu_seqlens_k_d,
|
||||
void *dq_accum_d,
|
||||
void *dk_accum_d,
|
||||
void *dv_accum_d,
|
||||
void *softmax_lse_d,
|
||||
void *dsoftmax_sum_d,
|
||||
float p_dropout,
|
||||
float softmax_scale,
|
||||
bool is_causal) {
|
||||
|
||||
set_params_fprop(params,
|
||||
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
|
||||
q, k, v, out,
|
||||
cu_seqlens_q_d,
|
||||
cu_seqlens_k_d,
|
||||
nullptr,
|
||||
softmax_lse_d,
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal);
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.do_ptr = dout.data_ptr();
|
||||
params.do_row_stride = dout.stride(-3);
|
||||
params.do_head_stride = dout.stride(-2);
|
||||
params.dq_ptr = dq.data_ptr();
|
||||
params.dk_ptr = dk.data_ptr();
|
||||
params.dv_ptr = dv.data_ptr();
|
||||
params.dq_row_stride = dq.stride(-3);
|
||||
params.dk_row_stride = dk.stride(-3);
|
||||
params.dv_row_stride = dv.stride(-3);
|
||||
params.dq_head_stride = dq.stride(-2);
|
||||
params.dk_head_stride = dk.stride(-2);
|
||||
params.dv_head_stride = dv.stride(-2);
|
||||
|
||||
if (cu_seqlens_q_d == nullptr) {
|
||||
params.do_batch_stride = dout.stride(0);
|
||||
params.dq_batch_stride = dq.stride(0);
|
||||
params.dk_batch_stride = dk.stride(0);
|
||||
params.dv_batch_stride = dv.stride(0);
|
||||
}
|
||||
|
||||
params.dq_accum_ptr = dq_accum_d;
|
||||
params.dk_accum_ptr = dk_accum_d;
|
||||
params.dv_accum_ptr = dv_accum_d;
|
||||
|
||||
// Softmax sum
|
||||
params.dsoftmax_sum = dsoftmax_sum_d;
|
||||
}
|
||||
|
||||
void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
FP16_SWITCH(!params.is_bf16, [&] {
|
||||
FWD_HEADDIM_SWITCH(params.d, [&] {
|
||||
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
const bool is_causal,
|
||||
const bool return_softmax,
|
||||
c10::optional<at::Generator> gen_) {
|
||||
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
|
||||
// We will support Turing in the near future
|
||||
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type");
|
||||
if (q_dtype == torch::kBFloat16) {
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
|
||||
}
|
||||
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
||||
|
||||
TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = sizes[0];
|
||||
const int seqlen_q = sizes[1];
|
||||
const int num_heads = sizes[2];
|
||||
const int head_size_og = sizes[3];
|
||||
const int seqlen_k = k.size(1);
|
||||
const int num_heads_k = k.size(2);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be postive");
|
||||
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
|
||||
|
||||
at::Tensor q_padded, k_padded, v_padded;
|
||||
if (head_size_og % 8 != 0) {
|
||||
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
} else {
|
||||
q_padded = q;
|
||||
k_padded = k;
|
||||
v_padded = v;
|
||||
}
|
||||
|
||||
at::Tensor out;
|
||||
if (out_.has_value()) {
|
||||
out = out_.value();
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
||||
TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
|
||||
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
||||
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
|
||||
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
|
||||
} else {
|
||||
out = torch::empty_like(q_padded);
|
||||
}
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size = round_multiple(head_size_og, 8);
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
|
||||
auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor p;
|
||||
// Only return softmax if there's dropout to reduce compilation time
|
||||
if (return_softmax) {
|
||||
TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
|
||||
p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
|
||||
}
|
||||
|
||||
Flash_fwd_params params;
|
||||
set_params_fprop(params,
|
||||
batch_size,
|
||||
seqlen_q, seqlen_k,
|
||||
seqlen_q_rounded, seqlen_k_rounded,
|
||||
num_heads, num_heads_k,
|
||||
head_size, head_size_rounded,
|
||||
q_padded, k_padded, v_padded, out,
|
||||
/*cu_seqlens_q_d=*/nullptr,
|
||||
/*cu_seqlens_k_d=*/nullptr,
|
||||
return_softmax ? p.data_ptr() : nullptr,
|
||||
softmax_lse.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal);
|
||||
|
||||
if (p_dropout > 0.0) {
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
}
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
run_mha_fwd(params, stream);
|
||||
|
||||
at::Tensor out_padded = out;
|
||||
if (head_size_og % 8 != 0) {
|
||||
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
if (out_.has_value()) { out_.value().copy_(out); }
|
||||
}
|
||||
|
||||
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const int max_seqlen_q,
|
||||
const int max_seqlen_k,
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const bool return_softmax,
|
||||
c10::optional<at::Generator> gen_) {
|
||||
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
|
||||
// We will support Turing in the near future
|
||||
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type");
|
||||
if (q_dtype == torch::kBFloat16) {
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
|
||||
}
|
||||
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
||||
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
|
||||
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
|
||||
|
||||
TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device");
|
||||
TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device");
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
|
||||
TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int total_q = sizes[0];
|
||||
const int batch_size = cu_seqlens_q.numel() - 1;
|
||||
const int num_heads = sizes[1];
|
||||
const int head_size_og = sizes[2];
|
||||
const int total_k = k.size(0);
|
||||
const int num_heads_k = k.size(1);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
CHECK_SHAPE(q, total_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
||||
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
||||
|
||||
at::Tensor q_padded, k_padded, v_padded;
|
||||
if (head_size_og % 8 != 0) {
|
||||
q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
} else {
|
||||
q_padded = q;
|
||||
k_padded = k;
|
||||
v_padded = v;
|
||||
}
|
||||
|
||||
at::Tensor out;
|
||||
if (out_.has_value()) {
|
||||
out = out_.value();
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
||||
TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
|
||||
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
|
||||
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
|
||||
} else {
|
||||
out = torch::empty_like(q_padded);
|
||||
}
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size = round_multiple(head_size_og, 8);
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
|
||||
auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor p;
|
||||
// Only return softmax if there's dropout to reduce compilation time
|
||||
if (return_softmax) {
|
||||
TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
|
||||
p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
|
||||
}
|
||||
|
||||
if (zero_tensors) {
|
||||
out.zero_();
|
||||
softmax_lse.fill_(-std::numeric_limits<float>::infinity());
|
||||
if (return_softmax) {p.zero_();}
|
||||
}
|
||||
|
||||
Flash_fwd_params params;
|
||||
set_params_fprop(params,
|
||||
batch_size,
|
||||
max_seqlen_q, max_seqlen_k,
|
||||
seqlen_q_rounded, seqlen_k_rounded,
|
||||
num_heads, num_heads_k,
|
||||
head_size, head_size_rounded,
|
||||
q_padded, k_padded, v_padded, out,
|
||||
cu_seqlens_q.data_ptr(),
|
||||
cu_seqlens_k.data_ptr(),
|
||||
return_softmax ? p.data_ptr() : nullptr,
|
||||
softmax_lse.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal);
|
||||
|
||||
if (p_dropout > 0.0) {
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
}
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
run_mha_fwd(params, stream);
|
||||
|
||||
at::Tensor out_padded = out;
|
||||
if (head_size_og % 8 != 0) {
|
||||
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
if (out_.has_value()) { out_.value().copy_(out); }
|
||||
}
|
||||
|
||||
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
|
||||
}
|
||||
|
||||
void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
FP16_SWITCH(!params.is_bf16, [&] {
|
||||
if (params.d <= 32) {
|
||||
run_mha_bwd_<elem_type, 32>(params, stream, configure);
|
||||
} else if (params.d <= 64) {
|
||||
run_mha_bwd_<elem_type, 64>(params, stream, configure);
|
||||
} else if (params.d <= 96) {
|
||||
run_mha_bwd_<elem_type, 96>(params, stream, configure);
|
||||
} else if (params.d <= 128) {
|
||||
run_mha_bwd_<elem_type, 128>(params, stream, configure);
|
||||
} else if (params.d <= 160) {
|
||||
run_mha_bwd_<elem_type, 160>(params, stream, configure);
|
||||
} else if (params.d <= 192) {
|
||||
run_mha_bwd_<elem_type, 192>(params, stream, configure);
|
||||
} else if (params.d <= 224) {
|
||||
run_mha_bwd_<elem_type, 224>(params, stream, configure);
|
||||
} else if (params.d <= 256) {
|
||||
run_mha_bwd_<elem_type, 256>(params, stream, configure);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
|
||||
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &softmax_lse, // b x h x seqlen_q
|
||||
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
|
||||
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool is_causal,
|
||||
c10::optional<at::Generator> gen_) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
|
||||
// We will support Turing in the near future
|
||||
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
|
||||
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type");
|
||||
if (q_dtype == torch::kBFloat16) {
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
|
||||
}
|
||||
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
|
||||
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
|
||||
|
||||
TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device");
|
||||
TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device");
|
||||
TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device");
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = sizes[0];
|
||||
const int seqlen_q = sizes[1];
|
||||
const int num_heads = sizes[2];
|
||||
const int head_size_og = dout.size(3);
|
||||
const int head_size = sizes[3];
|
||||
const int seqlen_k = k.size(1);
|
||||
const int num_heads_k = k.size(2);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
|
||||
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
|
||||
if (head_size > 192) {
|
||||
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
|
||||
}
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
|
||||
TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
|
||||
|
||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
|
||||
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
|
||||
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
|
||||
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
|
||||
CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
|
||||
|
||||
at::Tensor dq, dk, dv;
|
||||
if (dq_.has_value()) {
|
||||
dq = dq_.value();
|
||||
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
|
||||
TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device");
|
||||
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
|
||||
CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
|
||||
} else {
|
||||
dq = torch::empty_like(q);
|
||||
}
|
||||
if (dk_.has_value()) {
|
||||
dk = dk_.value();
|
||||
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
|
||||
TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device");
|
||||
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
|
||||
CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
|
||||
} else {
|
||||
dk = torch::empty_like(k);
|
||||
}
|
||||
if (dv_.has_value()) {
|
||||
dv = dv_.value();
|
||||
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
|
||||
TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device");
|
||||
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
|
||||
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
|
||||
} else {
|
||||
dv = torch::empty_like(k);
|
||||
}
|
||||
|
||||
at::Tensor dout_padded;
|
||||
if (head_size_og % 8 != 0) {
|
||||
dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
} else {
|
||||
dout_padded = dout;
|
||||
}
|
||||
|
||||
// bool loop = seqlen_k > blocksize_c;
|
||||
// TODO: change later, for now set to true for simplicity
|
||||
bool loop = true;
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
|
||||
at::Tensor dq_accum;
|
||||
at::Tensor dk_accum, dv_accum;
|
||||
if (loop) {
|
||||
dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
|
||||
// dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
|
||||
// dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
|
||||
}
|
||||
|
||||
at::Tensor dk_expanded, dv_expanded;
|
||||
if (num_heads_k != num_heads) { // MQA / GQA
|
||||
dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
|
||||
dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
|
||||
} else {
|
||||
dk_expanded = dk;
|
||||
dv_expanded = dv;
|
||||
}
|
||||
|
||||
Flash_bwd_params params;
|
||||
|
||||
set_params_dgrad(params,
|
||||
batch_size,
|
||||
seqlen_q, seqlen_k,
|
||||
seqlen_q_rounded, seqlen_k_rounded,
|
||||
num_heads, num_heads_k,
|
||||
head_size, head_size_rounded,
|
||||
q, k, v, out,
|
||||
dout_padded, dq, dk_expanded, dv_expanded,
|
||||
nullptr,
|
||||
nullptr,
|
||||
loop ? dq_accum.data_ptr() : nullptr,
|
||||
// loop ? dk_accum.data_ptr() : nullptr,
|
||||
// loop ? dv_accum.data_ptr() : nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
softmax_lse.data_ptr(),
|
||||
softmax_d.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal);
|
||||
|
||||
auto launch = &run_mha_bwd;
|
||||
// launch(params, stream, /*configure=*/true);
|
||||
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
|
||||
if (is_dropout) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
}
|
||||
|
||||
launch(params, stream, /*configure=*/false);
|
||||
|
||||
// For MQA/GQA we need to sum dK and dV across the groups
|
||||
if (num_heads_k != num_heads) {
|
||||
at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
|
||||
at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
|
||||
}
|
||||
if (head_size_og % 8 != 0) {
|
||||
dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
}
|
||||
|
||||
return { dq, dk, dv, softmax_d };
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &out, // total_q x num_heads x head_size
|
||||
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
|
||||
c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const int max_seqlen_q,
|
||||
const int max_seqlen_k, // max sequence length to choose the kernel
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
c10::optional<at::Generator> gen_
|
||||
) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
|
||||
// We will support Turing in the near future
|
||||
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type");
|
||||
if (q_dtype == torch::kBFloat16) {
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
|
||||
}
|
||||
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
|
||||
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
|
||||
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
|
||||
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
|
||||
|
||||
TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device");
|
||||
TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device");
|
||||
TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device");
|
||||
TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device");
|
||||
TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device");
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
|
||||
TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int total_q = sizes[0];
|
||||
const int batch_size = cu_seqlens_q.numel() - 1;
|
||||
const int num_heads = sizes[1];
|
||||
const int head_size_og = dout.size(2);
|
||||
const int head_size = sizes[2];
|
||||
const int total_k = k.size(0);
|
||||
const int num_heads_k = k.size(1);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
|
||||
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
|
||||
if (head_size > 192) {
|
||||
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
|
||||
}
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
|
||||
|
||||
TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
|
||||
|
||||
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
|
||||
CHECK_SHAPE(v, total_k, num_heads_k, head_size);
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
||||
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
||||
|
||||
at::Tensor dq, dk, dv;
|
||||
if (dq_.has_value()) {
|
||||
dq = dq_.value();
|
||||
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
|
||||
TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device");
|
||||
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
|
||||
CHECK_SHAPE(dq, total_q, num_heads, head_size);
|
||||
} else {
|
||||
dq = torch::empty_like(q);
|
||||
}
|
||||
if (dk_.has_value()) {
|
||||
dk = dk_.value();
|
||||
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
|
||||
TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device");
|
||||
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
|
||||
CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
|
||||
} else {
|
||||
dk = torch::empty_like(k);
|
||||
}
|
||||
if (dv_.has_value()) {
|
||||
dv = dv_.value();
|
||||
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
|
||||
TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device");
|
||||
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
|
||||
CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
|
||||
} else {
|
||||
dv = torch::empty_like(k);
|
||||
}
|
||||
|
||||
at::Tensor dout_padded;
|
||||
if (head_size_og % 8 != 0) {
|
||||
dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
|
||||
} else {
|
||||
dout_padded = dout;
|
||||
}
|
||||
|
||||
// bool loop = max_seqlen_k > blocksize_c;
|
||||
// TODO: change later, for now set to true for simplicity
|
||||
bool loop = true;
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
|
||||
at::Tensor dq_accum;
|
||||
if (loop) {
|
||||
dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
|
||||
}
|
||||
|
||||
at::Tensor dk_expanded, dv_expanded;
|
||||
if (num_heads_k != num_heads) { // MQA / GQA
|
||||
dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
|
||||
dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
|
||||
} else {
|
||||
dk_expanded = dk;
|
||||
dv_expanded = dv;
|
||||
}
|
||||
|
||||
if( zero_tensors ) {
|
||||
dq.zero_();
|
||||
dk_expanded.zero_();
|
||||
dv_expanded.zero_();
|
||||
softmax_d.zero_();
|
||||
}
|
||||
|
||||
Flash_bwd_params params;
|
||||
|
||||
set_params_dgrad(params,
|
||||
batch_size,
|
||||
max_seqlen_q, max_seqlen_k,
|
||||
seqlen_q_rounded, seqlen_k_rounded,
|
||||
num_heads, num_heads_k,
|
||||
head_size, head_size_rounded,
|
||||
q, k, v, out,
|
||||
dout_padded, dq, dk_expanded, dv_expanded,
|
||||
cu_seqlens_q.data_ptr(),
|
||||
cu_seqlens_k.data_ptr(),
|
||||
loop ? dq_accum.data_ptr() : nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
softmax_lse.data_ptr(),
|
||||
softmax_d.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal);
|
||||
|
||||
auto launch = &run_mha_bwd;
|
||||
// launch(params, stream, /*configure=*/true);
|
||||
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
|
||||
if (is_dropout) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
}
|
||||
|
||||
launch(params, stream, /*configure=*/false);
|
||||
|
||||
// For MQA/GQA we need to sum dK and dV across the groups
|
||||
if (num_heads_k != num_heads) {
|
||||
at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
|
||||
at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
|
||||
}
|
||||
if (head_size_og % 8 != 0) {
|
||||
dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
||||
}
|
||||
|
||||
return { dq, dk, dv, softmax_d };
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "FlashAttention";
|
||||
m.def("fwd", &mha_fwd, "Forward pass");
|
||||
m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
|
||||
m.def("bwd", &mha_bwd, "Backward pass");
|
||||
m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
|
||||
}
|
||||
@@ -1,796 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2022, Tri Dao.
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "fmha.h"
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
|
||||
void set_params_fprop(FMHA_fprop_params ¶ms,
|
||||
// sizes
|
||||
const size_t b,
|
||||
const size_t seqlen_q,
|
||||
const size_t seqlen_k,
|
||||
const size_t h,
|
||||
const size_t d,
|
||||
// device pointers
|
||||
const at::Tensor q,
|
||||
const at::Tensor k,
|
||||
const at::Tensor v,
|
||||
at::Tensor out,
|
||||
void *cu_seqlens_q_d,
|
||||
void *cu_seqlens_k_d,
|
||||
void *o_tmp_d,
|
||||
void *s_d,
|
||||
void *softmax_lse_d,
|
||||
float p_dropout,
|
||||
float softmax_scale,
|
||||
bool is_causal,
|
||||
int num_splits) {
|
||||
|
||||
Data_type acc_type = DATA_TYPE_FP32;
|
||||
Data_type data_type = !(q.dtype() == torch::kBFloat16) ? DATA_TYPE_FP16 : DATA_TYPE_BF16;
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.is_bf16 = q.dtype() == torch::kBFloat16;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.q_ptr = q.data_ptr();
|
||||
params.k_ptr = k.data_ptr();
|
||||
params.v_ptr = v.data_ptr();
|
||||
params.q_row_stride_in_elts = q.stride(0);
|
||||
params.k_row_stride_in_elts = k.stride(0);
|
||||
params.v_row_stride_in_elts = v.stride(0);
|
||||
params.q_head_stride_in_elts = q.stride(1);
|
||||
params.k_head_stride_in_elts = k.stride(1);
|
||||
params.v_head_stride_in_elts = v.stride(1);
|
||||
params.o_ptr = out.data_ptr();
|
||||
params.o_row_stride_in_elts = out.stride(0);
|
||||
params.o_head_stride_in_elts = out.stride(1);
|
||||
params.o_tmp_ptr = o_tmp_d;
|
||||
params.o_tmp_row_stride_in_elts = h * d;
|
||||
params.o_tmp_head_stride_in_elts = d;
|
||||
|
||||
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
|
||||
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
|
||||
|
||||
// S = softmax(P)
|
||||
params.s_ptr = s_d;
|
||||
params.s_stride_in_bytes = get_size_in_bytes(b * h * seqlen_k, data_type);
|
||||
|
||||
// Softmax sum
|
||||
params.softmax_lse_ptr = softmax_lse_d;
|
||||
|
||||
// Set the dimensions.
|
||||
params.b = b;
|
||||
params.h = h;
|
||||
params.seqlen_q = seqlen_q;
|
||||
params.seqlen_k = seqlen_k;
|
||||
params.d = d;
|
||||
|
||||
// Set the different scale values.
|
||||
// const float scale_bmm1 = 1.f / sqrtf(d);
|
||||
const float scale_bmm1 = softmax_scale;
|
||||
|
||||
params.scale_bmm1f = scale_bmm1;
|
||||
set_alpha(params.scale_bmm1, scale_bmm1, data_type);
|
||||
|
||||
// Set this to probability of keeping an element to simplify things.
|
||||
params.p_dropout = 1.f - p_dropout;
|
||||
// Convert p from float to int so we don't have to convert the random uint to float to compare.
|
||||
// [Minor] We want to round down since when we do the comparison we use <= instead of <
|
||||
params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
|
||||
params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
|
||||
params.rp_dropout = 1.f / params.p_dropout;
|
||||
params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1f;
|
||||
TORCH_CHECK(p_dropout < 1.f);
|
||||
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
|
||||
|
||||
params.is_causal = is_causal;
|
||||
params.num_splits = num_splits;
|
||||
}
|
||||
|
||||
void set_params_dgrad(FMHA_dgrad_params ¶ms,
|
||||
// sizes
|
||||
const size_t b,
|
||||
const size_t seqlen_q,
|
||||
const size_t seqlen_k,
|
||||
const size_t h,
|
||||
const size_t d,
|
||||
// device pointers
|
||||
const at::Tensor q,
|
||||
const at::Tensor k,
|
||||
const at::Tensor v,
|
||||
const at::Tensor out,
|
||||
at::Tensor dq,
|
||||
at::Tensor dk,
|
||||
at::Tensor dv,
|
||||
void *cu_seqlens_q_d,
|
||||
void *cu_seqlens_k_d,
|
||||
void *dq_tmp_d,
|
||||
void *do_packed_d,
|
||||
void *softmax_lse_d,
|
||||
void *dsoftmax_sum_d,
|
||||
float p_dropout,
|
||||
float softmax_scale,
|
||||
bool is_causal,
|
||||
int num_splits) {
|
||||
|
||||
set_params_fprop(params,
|
||||
b, seqlen_q, seqlen_k, h, d,
|
||||
q, k, v, out,
|
||||
cu_seqlens_q_d,
|
||||
cu_seqlens_k_d,
|
||||
dq_tmp_d, // Reusing the o_tmp_ptr variable to store dq_tmp
|
||||
nullptr,
|
||||
softmax_lse_d,
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
num_splits);
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.dq_ptr = dq.data_ptr();
|
||||
params.dk_ptr = dk.data_ptr();
|
||||
params.dv_ptr = dv.data_ptr();
|
||||
params.dq_row_stride_in_elts = dq.stride(0);
|
||||
params.dk_row_stride_in_elts = dk.stride(0);
|
||||
params.dv_row_stride_in_elts = dv.stride(0);
|
||||
params.dq_head_stride_in_elts = dq.stride(1);
|
||||
params.dk_head_stride_in_elts = dk.stride(1);
|
||||
params.dv_head_stride_in_elts = dv.stride(1);
|
||||
params.do_ptr = do_packed_d;
|
||||
|
||||
// Softmax sum
|
||||
params.dsoftmax_sum = dsoftmax_sum_d;
|
||||
}
|
||||
|
||||
void run_fmha_fwd(Launch_params<FMHA_fprop_params> &launch_params) {
|
||||
if (launch_params.params.d <= 32) {
|
||||
run_fmha_fwd_hdim32(launch_params);
|
||||
} else if (launch_params.params.d <= 64) {
|
||||
run_fmha_fwd_hdim64(launch_params);
|
||||
} else if (launch_params.params.d <= 128) {
|
||||
run_fmha_fwd_hdim128(launch_params);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
at::Tensor &out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const int max_seqlen_q_,
|
||||
const int max_seqlen_k_,
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const bool return_softmax,
|
||||
const int num_splits,
|
||||
c10::optional<at::Generator> gen_) {
|
||||
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90 || is_sm8x || is_sm75);
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16));
|
||||
TORCH_CHECK(k.dtype() == q_dtype);
|
||||
TORCH_CHECK(v.dtype() == q_dtype);
|
||||
TORCH_CHECK(out.dtype() == q_dtype);
|
||||
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);
|
||||
|
||||
TORCH_CHECK(q.is_cuda());
|
||||
TORCH_CHECK(k.is_cuda());
|
||||
TORCH_CHECK(v.is_cuda());
|
||||
TORCH_CHECK(out.is_cuda());
|
||||
TORCH_CHECK(cu_seqlens_q.is_cuda());
|
||||
TORCH_CHECK(cu_seqlens_k.is_cuda());
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1);
|
||||
TORCH_CHECK(k.stride(-1) == 1);
|
||||
TORCH_CHECK(v.stride(-1) == 1);
|
||||
TORCH_CHECK(out.stride(-1) == 1);
|
||||
TORCH_CHECK(cu_seqlens_q.is_contiguous());
|
||||
TORCH_CHECK(cu_seqlens_k.is_contiguous());
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = cu_seqlens_q.numel() - 1;
|
||||
const int total_q = sizes[TOTAL_DIM];
|
||||
const int num_heads = sizes[H_DIM];
|
||||
const int head_size = sizes[D_DIM];
|
||||
const int total_k = k.size(TOTAL_DIM);
|
||||
TORCH_CHECK(batch_size > 0);
|
||||
TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
|
||||
|
||||
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(k, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(v, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
||||
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
||||
|
||||
int blocksize_c = head_size > 64 ? 128 : 256;
|
||||
// Need to round max_seqlen_k to multiples of blocksize_c
|
||||
int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
|
||||
if( max_seqlen_k_ <= 128 ) {
|
||||
max_seqlen_k = 128;
|
||||
} else if( max_seqlen_k_ <= 256 ) {
|
||||
max_seqlen_k = 256;
|
||||
}
|
||||
int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
|
||||
bool loop = max_seqlen_k > blocksize_c;
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
|
||||
// auto o = torch::empty({ total_q, num_heads, head_size }, opts);
|
||||
|
||||
at::Tensor o_tmp;
|
||||
if (loop) { o_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); }
|
||||
|
||||
auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
||||
// auto softmax_lse = torch::full({batch_size, num_heads, max_seqlen_k}, -std::numeric_limits<float>::infinity(), opts.dtype(at::kFloat));
|
||||
|
||||
at::Tensor s;
|
||||
if (return_softmax) { s = torch::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); }
|
||||
|
||||
if( zero_tensors ) {
|
||||
out.zero_();
|
||||
softmax_lse.fill_(-std::numeric_limits<float>::infinity());
|
||||
if (return_softmax) {s.zero_();}
|
||||
}
|
||||
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
|
||||
set_params_fprop(launch_params.params,
|
||||
batch_size,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
num_heads,
|
||||
head_size,
|
||||
q, k, v, out,
|
||||
cu_seqlens_q.data_ptr(),
|
||||
cu_seqlens_k.data_ptr(),
|
||||
loop ? o_tmp.data_ptr() : nullptr,
|
||||
return_softmax ? s.data_ptr() : nullptr,
|
||||
softmax_lse.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
num_splits);
|
||||
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = launch_params.params.b * launch_params.params.h * 32;
|
||||
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
||||
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
|
||||
// Forward kernel will populate memory with the seed and offset.
|
||||
launch_params.params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
|
||||
|
||||
if( is_dropout ) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
}
|
||||
|
||||
run_fmha_fwd(launch_params);
|
||||
|
||||
std::vector<at::Tensor> result = {softmax_lse};
|
||||
result.push_back(rng_state);
|
||||
if (return_softmax) {result.push_back(s);}
|
||||
return result;
|
||||
}
|
||||
|
||||
void run_fmha_bwd(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
if (params.d <= 32) {
|
||||
run_fmha_bwd_hdim32(params, stream, configure);
|
||||
} else if (params.d <= 64) {
|
||||
run_fmha_bwd_hdim64(params, stream, configure);
|
||||
} else if (params.d <= 128) {
|
||||
run_fmha_bwd_hdim128(params, stream, configure);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &out, // total_q x num_heads x head_size
|
||||
const at::Tensor &softmax_lse_, // b x h x s softmax logsumexp
|
||||
at::Tensor &dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
at::Tensor &dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
at::Tensor &dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const int max_seqlen_q_,
|
||||
const int max_seqlen_k_, // max sequence length to choose the kernel
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const int num_splits,
|
||||
c10::optional<at::Generator> gen_,
|
||||
c10::optional<at::Tensor> &rng_state
|
||||
) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90 || is_sm8x || is_sm75);
|
||||
auto launch = &run_fmha_bwd;
|
||||
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16));
|
||||
TORCH_CHECK(k.dtype() == q_dtype);
|
||||
TORCH_CHECK(v.dtype() == q_dtype);
|
||||
TORCH_CHECK(out.dtype() == q_dtype);
|
||||
TORCH_CHECK(dout.dtype() == q_dtype);
|
||||
TORCH_CHECK(dq.dtype() == q_dtype);
|
||||
TORCH_CHECK(dk.dtype() == q_dtype);
|
||||
TORCH_CHECK(dv.dtype() == q_dtype);
|
||||
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);
|
||||
|
||||
TORCH_CHECK(q.is_cuda());
|
||||
TORCH_CHECK(k.is_cuda());
|
||||
TORCH_CHECK(v.is_cuda());
|
||||
TORCH_CHECK(out.is_cuda());
|
||||
TORCH_CHECK(dout.is_cuda());
|
||||
TORCH_CHECK(softmax_lse_.is_cuda());
|
||||
TORCH_CHECK(cu_seqlens_q.is_cuda());
|
||||
TORCH_CHECK(cu_seqlens_k.is_cuda());
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1);
|
||||
TORCH_CHECK(k.stride(-1) == 1);
|
||||
TORCH_CHECK(v.stride(-1) == 1);
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(dout.is_contiguous());
|
||||
TORCH_CHECK(dq.stride(-1) == 1);
|
||||
TORCH_CHECK(dk.stride(-1) == 1);
|
||||
TORCH_CHECK(dv.stride(-1) == 1);
|
||||
TORCH_CHECK(cu_seqlens_q.is_contiguous());
|
||||
TORCH_CHECK(cu_seqlens_k.is_contiguous());
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = cu_seqlens_q.numel() - 1;
|
||||
const int total_q = sizes[TOTAL_DIM];
|
||||
const int num_heads = sizes[H_DIM];
|
||||
const int head_size = sizes[D_DIM];
|
||||
const int total_k = k.size(TOTAL_DIM);
|
||||
TORCH_CHECK(batch_size > 0);
|
||||
TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
|
||||
if (head_size > 64) {
|
||||
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 64 requires A100 or H100 GPUs as the implementation needs a large amount of shared memory.");
|
||||
}
|
||||
|
||||
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(k, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(v, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(dout, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(dq, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(dk, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(dv, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
||||
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
||||
|
||||
int blocksize_c = (head_size > 64 || (is_sm75 && head_size > 32)) ? 128 : 256;
|
||||
int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
|
||||
if( max_seqlen_k_ <= 128 ) {
|
||||
max_seqlen_k = 128;
|
||||
} else if( max_seqlen_k_ <= 256 ) {
|
||||
max_seqlen_k = 256;
|
||||
}
|
||||
int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
|
||||
bool loop = max_seqlen_k > blocksize_c;
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
// It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different.
|
||||
auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous();
|
||||
|
||||
auto opts = q.options();
|
||||
auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor dq_tmp;
|
||||
if (loop) { dq_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); }
|
||||
|
||||
if( zero_tensors ) {
|
||||
dq.zero_();
|
||||
dk.zero_();
|
||||
dv.zero_();
|
||||
softmax_d.zero_();
|
||||
}
|
||||
|
||||
FMHA_dgrad_params params;
|
||||
|
||||
set_params_dgrad(params,
|
||||
batch_size,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
num_heads,
|
||||
head_size,
|
||||
q, k, v, out,
|
||||
dq, dk, dv,
|
||||
cu_seqlens_q.data_ptr(),
|
||||
cu_seqlens_k.data_ptr(),
|
||||
loop ? dq_tmp.data_ptr() : nullptr,
|
||||
dout.data_ptr(),
|
||||
softmax_lse.data_ptr(),
|
||||
softmax_d.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
num_splits);
|
||||
|
||||
launch(params, stream, /*configure=*/true);
|
||||
|
||||
if (params.num_splits > 1) {
|
||||
if (!dq_tmp.defined()) {
|
||||
dq_tmp = torch::zeros({total_q, num_heads, head_size}, opts.dtype(at::kFloat));
|
||||
params.o_tmp_ptr = dq_tmp.data_ptr(); // o_tmp stores dq_tmp in the backward pass
|
||||
} else {
|
||||
dq_tmp.zero_();
|
||||
}
|
||||
}
|
||||
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
if ( rng_state.has_value() ) {
|
||||
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
|
||||
} else if( is_dropout ) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
params.rng_state[0] = std::get<0>(seeds);
|
||||
params.rng_state[1] = std::get<1>(seeds);
|
||||
}
|
||||
|
||||
launch(params, stream, /*configure=*/false);
|
||||
|
||||
if (params.num_splits > 1) {
|
||||
dq.copy_(dq_tmp);
|
||||
}
|
||||
|
||||
return { dq, dk, dv, softmax_d };
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, total := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const at::Tensor &blockmask, // (seqlen / 256, seqlen / 16)
|
||||
const int max_seqlen_q_,
|
||||
const int max_seqlen_k_,
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
const bool is_causal,
|
||||
const bool return_softmax,
|
||||
c10::optional<at::Generator> gen_) {
|
||||
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm8x || is_sm90);
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
|
||||
|
||||
TORCH_CHECK(q.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(k.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(v.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(blockmask.dtype() == torch::kInt32);
|
||||
|
||||
TORCH_CHECK(q.is_cuda());
|
||||
TORCH_CHECK(k.is_cuda());
|
||||
TORCH_CHECK(v.is_cuda());
|
||||
TORCH_CHECK(cu_seqlens_q.is_cuda());
|
||||
TORCH_CHECK(cu_seqlens_k.is_cuda());
|
||||
TORCH_CHECK(blockmask.is_cuda())
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1);
|
||||
TORCH_CHECK(k.stride(-1) == 1);
|
||||
TORCH_CHECK(v.stride(-1) == 1);
|
||||
TORCH_CHECK(cu_seqlens_k.is_contiguous());
|
||||
TORCH_CHECK(cu_seqlens_k.is_contiguous());
|
||||
TORCH_CHECK(blockmask.is_contiguous())
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = cu_seqlens_q.numel() - 1;
|
||||
const int total_q = sizes[TOTAL_DIM];
|
||||
const int num_heads = sizes[H_DIM];
|
||||
const int head_size = sizes[D_DIM];
|
||||
const int total_k = k.size(TOTAL_DIM);
|
||||
TORCH_CHECK(batch_size > 0);
|
||||
TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
|
||||
|
||||
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(k, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(v, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
||||
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
||||
|
||||
int max_seqlen_k = ((max_seqlen_k_ + 256 - 1) / 256) * 256;
|
||||
if( max_seqlen_k <= 256 ) {
|
||||
max_seqlen_k = 256;
|
||||
}
|
||||
int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
|
||||
bool loop = max_seqlen_k > 256;
|
||||
CHECK_SHAPE(blockmask, max_seqlen_k / 256, max_seqlen_q / 16);
|
||||
|
||||
auto opts = q.options();
|
||||
|
||||
auto o = torch::zeros({ total_q, num_heads, head_size }, opts);
|
||||
|
||||
at::Tensor o_tmp;
|
||||
if (loop) {
|
||||
// o_tmp = torch::zeros({total, num_heads, head_size}, opts.dtype(at::kFloat));
|
||||
o_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat));
|
||||
}
|
||||
|
||||
// auto softmax_lse = torch::full({batch_size, num_heads, max_seqlen_k}, -std::numeric_limits<float>::infinity(), opts.dtype(at::kFloat));
|
||||
auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
||||
|
||||
at::Tensor s;
|
||||
if (return_softmax) {
|
||||
s = torch::zeros({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts);
|
||||
}
|
||||
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
|
||||
set_params_fprop(launch_params.params,
|
||||
batch_size,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
num_heads,
|
||||
head_size,
|
||||
q, k, v, o,
|
||||
cu_seqlens_q.data_ptr(),
|
||||
cu_seqlens_k.data_ptr(),
|
||||
loop ? o_tmp.data_ptr() : nullptr,
|
||||
return_softmax ? s.data_ptr() : nullptr,
|
||||
softmax_lse.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
/*num_splits=*/1);
|
||||
launch_params.params.blockmask = static_cast<int *>(blockmask.data_ptr());
|
||||
|
||||
run_fmha_block_fp16_sm80(launch_params, /*configure=*/ true);
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
int64_t counter_offset = launch_params.elts_per_thread;
|
||||
|
||||
if( is_dropout ) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
}
|
||||
|
||||
run_fmha_block_fp16_sm80(launch_params, /*configure=*/false);
|
||||
|
||||
std::vector<at::Tensor> result = {o, softmax_lse};
|
||||
if (return_softmax) {result.push_back(s);}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size
|
||||
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &out, // total_q x num_heads x head_size
|
||||
const at::Tensor &softmax_lse_, // b x h x s softmax logsumexp
|
||||
at::Tensor &dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
at::Tensor &dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
at::Tensor &dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const at::Tensor &blockmask, // (seqlen / 256, seqlen / 16)
|
||||
const int max_seqlen_q_,
|
||||
const int max_seqlen_k_, // max sequence length to choose the kernel
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool is_causal,
|
||||
c10::optional<at::Generator> gen_
|
||||
) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm8x || is_sm90);
|
||||
auto launch = &run_fmha_block_dgrad_fp16_sm80;
|
||||
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
TORCH_CHECK(q.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(k.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(v.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(dout.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(dq.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(dk.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(dv.dtype() == torch::kFloat16);
|
||||
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(blockmask.dtype() == torch::kInt32);
|
||||
|
||||
TORCH_CHECK(q.is_cuda());
|
||||
TORCH_CHECK(k.is_cuda());
|
||||
TORCH_CHECK(v.is_cuda());
|
||||
TORCH_CHECK(out.is_cuda());
|
||||
TORCH_CHECK(dout.is_cuda());
|
||||
TORCH_CHECK(softmax_lse_.is_cuda());
|
||||
TORCH_CHECK(cu_seqlens_q.is_cuda());
|
||||
TORCH_CHECK(cu_seqlens_k.is_cuda());
|
||||
TORCH_CHECK(blockmask.is_cuda());
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1);
|
||||
TORCH_CHECK(k.stride(-1) == 1);
|
||||
TORCH_CHECK(v.stride(-1) == 1);
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(dout.is_contiguous());
|
||||
TORCH_CHECK(dq.stride(-1) == 1);
|
||||
TORCH_CHECK(dk.stride(-1) == 1);
|
||||
TORCH_CHECK(dv.stride(-1) == 1);
|
||||
TORCH_CHECK(cu_seqlens_q.is_contiguous());
|
||||
TORCH_CHECK(cu_seqlens_k.is_contiguous());
|
||||
TORCH_CHECK(blockmask.is_contiguous());
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = cu_seqlens_q.numel() - 1;
|
||||
const int total_q = sizes[TOTAL_DIM];
|
||||
const int num_heads = sizes[H_DIM];
|
||||
const int head_size = sizes[D_DIM];
|
||||
const int total_k = k.size(TOTAL_DIM);
|
||||
TORCH_CHECK(batch_size > 0);
|
||||
TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
|
||||
if (head_size == 128) { // TODO: eventually we should support SM86 and SM70 with d=128 as well
|
||||
TORCH_CHECK(is_sm80 || is_sm90);
|
||||
}
|
||||
|
||||
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(k, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(v, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(dout, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(dq, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(dk, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(dv, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
||||
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
||||
|
||||
int max_seqlen_k = ((max_seqlen_k_ + 256 - 1) / 256) * 256;
|
||||
if( max_seqlen_k <= 256 ) {
|
||||
max_seqlen_k = 256;
|
||||
}
|
||||
int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
|
||||
bool loop = max_seqlen_k > 256;
|
||||
CHECK_SHAPE(blockmask, max_seqlen_k / 256, max_seqlen_q / 16);
|
||||
|
||||
// It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different.
|
||||
auto softmax_lse = softmax_lse_.index({torch::indexing::Slice(), torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, max_seqlen_q)}).contiguous();
|
||||
|
||||
auto opts = q.options();
|
||||
auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor dq_tmp;
|
||||
if (loop) {
|
||||
// dq_tmp = torch::zeros({total, num_heads, head_size}, opts.dtype(at::kFloat));
|
||||
dq_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat));
|
||||
}
|
||||
|
||||
FMHA_dgrad_params params;
|
||||
|
||||
set_params_dgrad(params,
|
||||
batch_size,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
num_heads,
|
||||
head_size,
|
||||
q, k, v, out,
|
||||
dq, dk, dv,
|
||||
cu_seqlens_q.data_ptr(),
|
||||
cu_seqlens_k.data_ptr(),
|
||||
loop ? dq_tmp.data_ptr() : nullptr,
|
||||
dout.data_ptr(),
|
||||
softmax_lse.data_ptr(),
|
||||
softmax_d.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
/*num_splits=*/1);
|
||||
params.blockmask = static_cast<int *>(blockmask.data_ptr());
|
||||
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
|
||||
// We're gonna reset the rng state in Python after this kernel, so the counter offset
|
||||
// here doesn't matter at all. We just choose an arbitrary number;
|
||||
int64_t counter_offset = 4;
|
||||
|
||||
if( is_dropout ) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
}
|
||||
|
||||
launch(params, stream);
|
||||
return { dq, dk, dv, softmax_d };
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "Fused Multi-head Self-attention";
|
||||
m.def("fwd", &mha_fwd, "Forward pass");
|
||||
m.def("bwd", &mha_bwd, "Backward pass");
|
||||
m.def("fwd_block", &mha_fwd_block, "Forward pass (blocksparse)");
|
||||
m.def("bwd_block", &mha_bwd_block, "Backward pass (blocksparse)");
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace flash {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool Varlen=true>
|
||||
struct BlockInfo {
|
||||
|
||||
template<typename Params>
|
||||
__device__ BlockInfo(const Params ¶ms, const int bidb)
|
||||
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
|
||||
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
|
||||
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
|
||||
, actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename index_t>
|
||||
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
|
||||
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
|
||||
}
|
||||
|
||||
template <typename index_t>
|
||||
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
|
||||
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
|
||||
}
|
||||
|
||||
const int sum_s_q;
|
||||
const int sum_s_k;
|
||||
const uint32_t actual_seqlen_q;
|
||||
const uint32_t actual_seqlen_k;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace flash
|
||||
@@ -0,0 +1,141 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <vector>
|
||||
|
||||
#ifdef OLD_GENERATOR_PATH
|
||||
#include <ATen/CUDAGeneratorImpl.h>
|
||||
#else
|
||||
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/cuda/CUDAGraphsUtils.cuh>
|
||||
|
||||
|
||||
constexpr int TOTAL_DIM = 0;
|
||||
constexpr int H_DIM = 1;
|
||||
constexpr int D_DIM = 2;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Qkv_params {
|
||||
using index_t = uint32_t;
|
||||
// The QKV matrices.
|
||||
void *__restrict__ q_ptr;
|
||||
void *__restrict__ k_ptr;
|
||||
void *__restrict__ v_ptr;
|
||||
|
||||
// The stride between rows of the Q, K and V matrices.
|
||||
index_t q_batch_stride;
|
||||
index_t k_batch_stride;
|
||||
index_t v_batch_stride;
|
||||
index_t q_row_stride;
|
||||
index_t k_row_stride;
|
||||
index_t v_row_stride;
|
||||
index_t q_head_stride;
|
||||
index_t k_head_stride;
|
||||
index_t v_head_stride;
|
||||
|
||||
// The number of heads.
|
||||
int h, h_k;
|
||||
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
|
||||
// different from nheads (query).
|
||||
int h_h_k_ratio; // precompute h / h_k,
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Flash_fwd_params : public Qkv_params {
|
||||
|
||||
// The O matrix (output).
|
||||
void * __restrict__ o_ptr;
|
||||
|
||||
// The stride between rows of O.
|
||||
index_t o_batch_stride;
|
||||
index_t o_row_stride;
|
||||
index_t o_head_stride;
|
||||
|
||||
// The pointer to the P matrix.
|
||||
void * __restrict__ p_ptr;
|
||||
|
||||
// The pointer to the softmax sum.
|
||||
void * __restrict__ softmax_lse_ptr;
|
||||
|
||||
// The dimensions.
|
||||
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
|
||||
|
||||
// The scaling factors for the kernel.
|
||||
float scale_softmax;
|
||||
float scale_softmax_log2;
|
||||
|
||||
// array of length b+1 holding starting offset of each sequence.
|
||||
int * __restrict__ cu_seqlens_q;
|
||||
int * __restrict__ cu_seqlens_k;
|
||||
|
||||
int *__restrict__ blockmask;
|
||||
|
||||
// The dropout probability (probability of keeping an activation).
|
||||
float p_dropout;
|
||||
// uint32_t p_dropout_in_uint;
|
||||
// uint16_t p_dropout_in_uint16_t;
|
||||
uint8_t p_dropout_in_uint8_t;
|
||||
|
||||
// Scale factor of 1 / (1 - p_dropout).
|
||||
float rp_dropout;
|
||||
float scale_softmax_rp_dropout;
|
||||
|
||||
// Random state.
|
||||
at::PhiloxCudaState philox_args;
|
||||
|
||||
bool is_bf16;
|
||||
bool is_causal;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Flash_bwd_params : public Flash_fwd_params {
|
||||
|
||||
// The dO and dQKV matrices.
|
||||
void *__restrict__ do_ptr;
|
||||
void *__restrict__ dq_ptr;
|
||||
void *__restrict__ dk_ptr;
|
||||
void *__restrict__ dv_ptr;
|
||||
|
||||
// To accumulate dQ
|
||||
void *__restrict__ dq_accum_ptr;
|
||||
void *__restrict__ dk_accum_ptr;
|
||||
void *__restrict__ dv_accum_ptr;
|
||||
|
||||
// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
|
||||
// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
|
||||
// dv_accum_ptr;
|
||||
|
||||
// The stride between rows of the dO, dQ, dK and dV matrices.
|
||||
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
|
||||
// The code probably won't work for arrays larger than 2GB.
|
||||
index_t do_batch_stride;
|
||||
index_t do_row_stride;
|
||||
index_t do_head_stride;
|
||||
index_t dq_batch_stride;
|
||||
index_t dk_batch_stride;
|
||||
index_t dv_batch_stride;
|
||||
index_t dq_row_stride;
|
||||
index_t dk_row_stride;
|
||||
index_t dv_row_stride;
|
||||
index_t dq_head_stride;
|
||||
index_t dk_head_stride;
|
||||
index_t dv_head_stride;
|
||||
|
||||
// The pointer to the softmax d sum.
|
||||
void *__restrict__ dsoftmax_sum;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
|
||||
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure);
|
||||
@@ -0,0 +1,20 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
// using elem_type = cutlass::bfloat16_t;
|
||||
// if (params.h == params.h_k) {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 4, 2, false, false, elem_type>>(params, stream, configure);
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<128, 128, 64, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// }
|
||||
// }
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// if (params.h == params.h_k) {
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<128, 32, 128, 8, 2, 2, 2, false, false, elem_type>>(params, stream, configure);
|
||||
// // This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
|
||||
// // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 2, 2, false, false, elem_type>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 4, 2, false, false, elem_type>>(params, stream, configure);
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<128, 64, 128, 8, 2, 4, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<128, 128, 64, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<128, 128, 64, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// }
|
||||
// }
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim128<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim160<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim192<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim224<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim256<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
// using elem_type = cutlass::bfloat16_t;
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<32, 128, 128, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// }
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<32, 128, 128, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// }
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim32<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
// using elem_type = cutlass::bfloat16_t;
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<64, 128, 128, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// }
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// // Changing AtomLayoutMdQ from 2 to 4 takes the same time
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 64, 128, 8, 2, 4, 2, false, false, elem_type>>(params, stream, configure);
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 64, 128, 8, 2, 4, 2, true, false, elem_type>>(params, stream, configure);
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 128, 128, 8, 2, 4, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// // This is slightly faster. We want to split M more so we need fewer registers to store LSE.
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<64, 128, 128, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 128, 64, 8, 4, 2, 4, true, false, elem_type>>(params, stream, configure);
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 64, 64, 4, 2, 2, 2, true, false, elem_type>>(params, stream, configure);
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 32, 128, 4, 1, 4, 1, false, false, elem_type>>(params, stream, configure);
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 16, 128, 4, 1, 4, 1, false, false, elem_type>>(params, stream, configure);
|
||||
// // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 128, 64, 8, 2, 2, 2, false, elem_type>>(params, stream, configure);
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 128, 64, 8, false, elem_type>>(params, stream, configure);
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 16, 256, 8, false, elem_type>>(params, stream, configure);
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 128, 128, 8, false, elem_type>>(params, stream, configure);
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 64, 64, 4, false, elem_type>>(params, stream, configure);
|
||||
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 256, 64, 8, 8, 4, 8, false, false, elem_type>>(params, stream, configure);
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 256, 64, 8, 8, 4, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<64, 128, 64, 4, 4, 2, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// }
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim64<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
// using elem_type = cutlass::bfloat16_t;
|
||||
// if (params.h == params.h_k) {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<96, 64, 128, 8, 2, 4, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<96, 128, 64, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// }
|
||||
// }
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_bwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// if (params.h == params.h_k) {
|
||||
// // run_flash_bwd<Flash_bwd_kernel_traits<96, 64, 128, 8, 2, 4, 4, true, false, elem_type>>(params, stream, configure);
|
||||
// // This is very slightly faster
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<96, 64, 128, 8, 2, 4, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<96, 128, 64, 8, 4, 4, 4, false, false, elem_type>>(params, stream, configure);
|
||||
// }
|
||||
// }
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim96<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,355 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "static_switch.h"
|
||||
#include "flash.h"
|
||||
#include "flash_bwd_kernel.h"
|
||||
|
||||
template<bool Clear_dQaccum=true, typename Kernel_traits>
|
||||
__global__ void flash_bwd_dot_do_o_kernel(Flash_bwd_params params) {
|
||||
flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) {
|
||||
flash::clear_dKVaccum<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K>
|
||||
__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
|
||||
flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K>
|
||||
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
|
||||
flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K>
|
||||
__global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params params) {
|
||||
flash::compute_dq_dk_dv_seqq_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params) {
|
||||
flash::convert_dQ<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_convert_dkv_kernel(Flash_bwd_params params) {
|
||||
flash::convert_dKV<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid_m(num_m_block, params.b, params.h);
|
||||
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
|
||||
dim3 grid_n(num_n_block, params.b, params.h);
|
||||
|
||||
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
||||
// We also use is_even_M to set Unpadded in the BlockInfo constructor, so we need to check
|
||||
// for cu_seqlens_q as well.
|
||||
const bool is_even_M = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0;
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
|
||||
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
|
||||
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
|
||||
BOOL_SWITCH(is_even_M, IsEvenMConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, true>;
|
||||
if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
}
|
||||
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
|
||||
if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
|
||||
}
|
||||
kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
|
||||
dim3 grid_n(num_n_block, params.b, params.h_k);
|
||||
flash_bwd_clear_dkvaccum_kernel<Kernel_traits><<<grid_n, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid_m(num_m_block, params.b, params.h);
|
||||
// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
|
||||
// for cu_seqlens_k as well.
|
||||
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1rowblock;
|
||||
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
|
||||
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
|
||||
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, true, true>;
|
||||
if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
}
|
||||
kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
|
||||
if (Kernel_traits::kSmemKVSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize));
|
||||
}
|
||||
kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemKVSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
//
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
if (configure) return;
|
||||
// dim3 grid(params.b, params.h);
|
||||
// const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
// dim3 grid_m(num_m_block, params.b, params.h);
|
||||
|
||||
// if (params.h == params.h_k) { // No multi-query or grouped-query attention (MQA/GQA)
|
||||
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
|
||||
// }
|
||||
|
||||
// // We also use is_even_M to set Unpadded in the BlockInfo constructor, so we need to check
|
||||
// // for cu_seqlens_q as well.
|
||||
// const bool is_even_M = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0;
|
||||
// const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
// constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
|
||||
// BOOL_SWITCH(is_even_M, IsEvenMConst, [&] {
|
||||
// BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
// // auto kernel = &flash_bwd_dq_dk_dv_loop_kernel<Kernel_traits, Is_dropout, IsCausalConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, IsEvenKConst>;
|
||||
// if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
// }
|
||||
// kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
|
||||
// C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
// });
|
||||
// });
|
||||
// });
|
||||
// });
|
||||
|
||||
// auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
|
||||
// if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
|
||||
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
// kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
|
||||
// }
|
||||
// kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params);
|
||||
// C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
//
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 32;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
|
||||
if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
} else { // 96 KB
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 64;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// This is slightly faster. We want to split M more so we need fewer registers to store LSE.
|
||||
if (max_smem_per_block >= 144 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// This has a lot of register spilling
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else {
|
||||
// if (params.h == params.h_k) {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// }
|
||||
}
|
||||
});
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
|
||||
// M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream, configure);
|
||||
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream, configure);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 96;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// if (params.h == params.h_k) {
|
||||
if (max_smem_per_block >= 116 * 1024) {
|
||||
if constexpr(!Is_dropout) { // 92KB
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else { // 116 KB
|
||||
// This is faster for dropout since we don't have many registers to spare
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// }
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 128;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// if (params.h == params.h_k) {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
|
||||
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
|
||||
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
|
||||
if (max_smem_per_block >= 144 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
|
||||
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// }
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 160;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 116 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 192;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 136 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim224(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 224;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 256;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 176 * 1024) { // H100
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else { // A100, we don't do double buffering to save smem
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::bfloat16_t;
|
||||
// if (params.p_dropout == 1.f) {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
|
||||
// } else {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
|
||||
// }
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// if (params.p_dropout == 1.f) {
|
||||
// // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
|
||||
// // 1st ones are good for H100, A100
|
||||
// // 2nd one is good for A6000 bc we get slightly better occupancy
|
||||
// } else {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
|
||||
// // 1st one is good for H100, A100, A6000
|
||||
// }
|
||||
// }
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::bfloat16_t;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// });
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream);
|
||||
// // For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
|
||||
// // For A100, H100, 1st is fastest.
|
||||
// });
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::bfloat16_t;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// });
|
||||
// }
|
||||
template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// // This one is slightly faster for causal?
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream);
|
||||
// });
|
||||
// // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout
|
||||
// // For A6000, 1st is faster when causal, 3rd is faster when not causal
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<> void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<> void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 128, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// // For dropout there might be a lot of register spilling?
|
||||
// // These two are very slow due to register spilling
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 128, 4, false, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 256, 4, false, elem_type>>(params, stream);
|
||||
// // This one is slightly slower
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 64, 4, false, elem_type>>(params, stream);
|
||||
// });
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::bfloat16_t;
|
||||
// if (params.p_dropout == 1.f) {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
|
||||
// } else {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
|
||||
// }
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// if (params.p_dropout == 1.f) {
|
||||
// // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
|
||||
// // Using block size (64 x 256) is 27% slower for seqlen=2k
|
||||
// // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream);
|
||||
// } else {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream);
|
||||
// }
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::bfloat16_t;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
|
||||
// });
|
||||
// }
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
// template<>
|
||||
// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// using elem_type = cutlass::half_t;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream);
|
||||
// // This 3rd one is good for H100, and A100, A6000
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream);
|
||||
// // These two are always slower
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream);
|
||||
// // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream);
|
||||
// });
|
||||
// }
|
||||
template<> void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
|
||||
}
|
||||
@@ -0,0 +1,576 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
#include <cute/algorithm/gemm.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
|
||||
#include "block_info.h"
|
||||
#include "kernel_traits.h"
|
||||
#include "utils.h"
|
||||
#include "softmax.h"
|
||||
#include "philox.cuh"
|
||||
|
||||
namespace flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int MMA_M,
|
||||
class... Args,
|
||||
class TiledMMA>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma) {
|
||||
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
|
||||
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
|
||||
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
|
||||
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
|
||||
constexpr int MMAStride_M = MMA_M * AtomShape_M;
|
||||
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
|
||||
Stride<_1, Int<MMAStride_M>> >{},
|
||||
make_layout(size<2>(TileShape_MNK{})));
|
||||
// if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int MMA_M,
|
||||
class... Args,
|
||||
class TiledMMA>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma) {
|
||||
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
|
||||
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
|
||||
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
|
||||
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
|
||||
constexpr int MMAStride_M = MMA_M * AtomShape_M;
|
||||
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
|
||||
Stride<_1, Int<MMAStride_M>> >{},
|
||||
// TODO: Shouldn't this be size<1>?
|
||||
make_layout(size<2>(TileShape_MNK{})));
|
||||
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
|
||||
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
|
||||
Tensor2 &acc_o, float softmax_scale_log2) {
|
||||
if (Is_first) {
|
||||
flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
|
||||
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
|
||||
flash::reduce_sum(scores, scores_sum);
|
||||
} else {
|
||||
Tensor scores_max_prev = make_fragment_like(scores_max);
|
||||
copy(scores_max, scores_max_prev);
|
||||
flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(scores_max); ++mi) {
|
||||
float scores_max_cur = !Check_inf
|
||||
? scores_max(mi)
|
||||
: (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
|
||||
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||
scores_sum(mi) *= scores_scale;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
|
||||
}
|
||||
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
|
||||
Tensor scores_sum_cur = make_fragment_like(scores_sum);
|
||||
flash::reduce_sum(scores, scores_sum_cur);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
|
||||
inline __device__ void write_softmax_to_gmem(
|
||||
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
|
||||
) {
|
||||
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
|
||||
Layout l = tOrP.layout();
|
||||
Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
|
||||
copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) {
|
||||
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
constexpr int kNWarps = Kernel_traits::kNWarps;
|
||||
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
|
||||
|
||||
const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
|
||||
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
|
||||
|
||||
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
|
||||
if (Is_causal) {
|
||||
n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
|
||||
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
|
||||
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
|
||||
// }
|
||||
}
|
||||
|
||||
// We iterate over the blocks in reverse order. This is because the last block is the only one
|
||||
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
|
||||
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
|
||||
|
||||
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
|
||||
// We move K and V to the last block.
|
||||
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
|
||||
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
|
||||
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
|
||||
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
|
||||
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
|
||||
+ m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
|
||||
|
||||
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.q_row_stride, _1{}));
|
||||
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.k_row_stride, _1{}));
|
||||
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.v_row_stride, _1{}));
|
||||
Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
|
||||
Shape<Int<kBlockM>, Int<kBlockN>>{},
|
||||
make_stride(params.seqlen_k_rounded, _1{}));
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
|
||||
typename Kernel_traits::SmemLayoutQ{});
|
||||
// Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
|
||||
Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
|
||||
typename Kernel_traits::SmemLayoutKV{});
|
||||
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
|
||||
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
|
||||
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
|
||||
|
||||
auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
|
||||
auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);
|
||||
|
||||
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
|
||||
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
|
||||
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
|
||||
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
|
||||
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
|
||||
Tensor tPgP = gmem_thr_copy_P.partition_D(gP);
|
||||
|
||||
typename Kernel_traits::TiledMma tiled_mma;
|
||||
auto thr_mma = tiled_mma.get_thread_slice(tidx);
|
||||
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
|
||||
|
||||
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
|
||||
|
||||
//
|
||||
// Copy Atom retiling
|
||||
//
|
||||
|
||||
auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
|
||||
// auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
|
||||
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
|
||||
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
|
||||
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
|
||||
|
||||
auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
|
||||
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
|
||||
|
||||
auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
|
||||
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
|
||||
|
||||
// TODO: this might need to change if we change the mma instruction in SM70
|
||||
Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
|
||||
Tensor scores_sum = make_fragment_like(scores_max);
|
||||
|
||||
//
|
||||
// PREDICATES
|
||||
//
|
||||
|
||||
// // Allocate predicate tensors for m and n
|
||||
// Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
|
||||
// Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
|
||||
|
||||
// Construct identity layout for sQ and sK
|
||||
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
// Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K)
|
||||
// if (cute::thread0()) {
|
||||
// print(tScQ.layout()); printf("\n");
|
||||
// for (int i = 0; i < size(tScQ); ++i) {
|
||||
// printf("%d ", get<0>(tScQ(i)));
|
||||
// }
|
||||
// printf("\n");
|
||||
// for (int i = 0; i < size(tScQ); ++i) {
|
||||
// printf("%d ", get<1>(tScQ(i)));
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
|
||||
// Repeat the partitioning with identity layouts
|
||||
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
|
||||
|
||||
// Allocate predicate tensors for k
|
||||
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
|
||||
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
|
||||
|
||||
// Set predicates for k bounds
|
||||
if (!Is_even_K) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
|
||||
}
|
||||
|
||||
// Prologue
|
||||
|
||||
Tensor tQrQ = make_fragment_like(tQgQ);
|
||||
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
|
||||
flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||
binfo.actual_seqlen_q - m_block * kBlockM);
|
||||
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
|
||||
|
||||
// // Copy rmem to smem
|
||||
// // copy(tQrQ, tQsQ);
|
||||
// flash::cp_async_wait<0>();
|
||||
// __syncthreads();
|
||||
// // if (cute::thread(1, 0)) { print(tQsQ); }
|
||||
// // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
|
||||
// // if (cute::thread0()) { print(sQNoSwizzle); }
|
||||
|
||||
if (Kernel_traits::Share_Q_K_smem) {
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
|
||||
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
int n_block = n_block_max - 1;
|
||||
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
|
||||
flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
|
||||
binfo.actual_seqlen_k - n_block * kBlockN);
|
||||
cute::cp_async_fence();
|
||||
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
|
||||
// __syncthreads();
|
||||
|
||||
if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
|
||||
flash::cp_async_wait<1>();
|
||||
__syncthreads();
|
||||
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
|
||||
copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
|
||||
}
|
||||
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
unsigned long long seed = std::get<0>(seeds);
|
||||
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
|
||||
|
||||
clear(acc_o);
|
||||
|
||||
// For performance reason, we separate out two kinds of iterations:
|
||||
// those that need masking on S, and those that don't.
|
||||
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
|
||||
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
|
||||
// We will have at least 1 "masking" iteration.
|
||||
|
||||
constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
|
||||
#pragma unroll
|
||||
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
|
||||
clear(acc_s);
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
// Advance gV
|
||||
if (masking_step > 0) {
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
} else {
|
||||
// Clear the smem tiles to account for predicated off loads
|
||||
flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
}
|
||||
cute::cp_async_fence();
|
||||
|
||||
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
|
||||
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
|
||||
);
|
||||
// if (cute::thread0()) { print(acc_s); }
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
// if (cute::thread0()) { print(scores); }
|
||||
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
|
||||
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
|
||||
// can produce Inf / NaN.
|
||||
if (!Is_causal) {
|
||||
if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
|
||||
} else {
|
||||
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
|
||||
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
|
||||
// static_assert(decltype(size<0>(taccScS))::value == 4);
|
||||
// // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices.
|
||||
// Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
|
||||
// Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout()));
|
||||
// flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
// m_block * kBlockM);
|
||||
// Idk why it's get<1> and not get<0> of the stride.
|
||||
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
|
||||
// I can't get the stride from idx_row
|
||||
flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
// m_block * kBlockM + get<0>(idx_row(0)),
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
kNWarps * 16);
|
||||
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
|
||||
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
|
||||
}
|
||||
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if (n_block > 0) {
|
||||
// Advance gK
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||
// This cp_async_fence needs to be in the if block, otherwise the synchronization
|
||||
// isn't right and we get race conditions.
|
||||
cute::cp_async_fence();
|
||||
}
|
||||
|
||||
// TODO: when we have key_padding_mask we'll need to Check_inf
|
||||
masking_step == 0
|
||||
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
|
||||
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||
|
||||
// Convert scores from fp32 to fp16/bf16
|
||||
Tensor rP = flash::convert_type<Element>(scores);
|
||||
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||
uint32_t block_col_idx = n_block * (kBlockN / 32);
|
||||
if (Return_softmax) {
|
||||
Tensor tOrP_copy = make_fragment_like(tOrP);
|
||||
copy(tOrP, tOrP_copy);
|
||||
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
|
||||
block_row_idx, block_col_idx, kNWarps
|
||||
);
|
||||
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
|
||||
tPgP.data() = tPgP.data() + (-kBlockN);
|
||||
}
|
||||
if (Is_dropout) {
|
||||
flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
|
||||
block_row_idx, block_col_idx, kNWarps);
|
||||
}
|
||||
// if (cute::thread0()) { print(tOrP); }
|
||||
|
||||
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
|
||||
// if (cute::thread0()) { print(scores); }
|
||||
|
||||
// This check is at the end of the loop since we always have at least 1 iteration
|
||||
if (n_masking_steps > 1 && n_block <= 0) {
|
||||
--n_block;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These are the iterations where we don't need masking on S
|
||||
for (; n_block >= 0; --n_block) {
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
|
||||
clear(acc_s);
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
// Advance gV
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
cute::cp_async_fence();
|
||||
|
||||
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
|
||||
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
|
||||
);
|
||||
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if (n_block > 0) {
|
||||
// Advance gK
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||
// This cp_async_fence needs to be in the if block, otherwise the synchronization
|
||||
// isn't right and we get race conditions.
|
||||
cute::cp_async_fence();
|
||||
}
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||
|
||||
Tensor rP = flash::convert_type<Element>(scores);
|
||||
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||
uint32_t block_col_idx = n_block * (kBlockN / 32);
|
||||
if (Return_softmax) {
|
||||
Tensor tOrP_copy = make_fragment_like(tOrP);
|
||||
copy(tOrP, tOrP_copy);
|
||||
flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
|
||||
block_row_idx, block_col_idx, kNWarps
|
||||
);
|
||||
flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
|
||||
tPgP.data() = tPgP.data() + (-kBlockN);
|
||||
}
|
||||
if (Is_dropout) {
|
||||
flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
|
||||
block_row_idx, block_col_idx, kNWarps);
|
||||
}
|
||||
|
||||
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
|
||||
}
|
||||
|
||||
// Epilogue
|
||||
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
Tensor lse = make_fragment_like(scores_sum);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
|
||||
float sum = scores_sum(mi);
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||
lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
|
||||
float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
|
||||
}
|
||||
|
||||
// if (cute::thread0()) { print(acc_o_rowcol); }
|
||||
|
||||
// Convert acc_o from fp32 to fp16/bf16
|
||||
Tensor rO = flash::convert_type<Element>(acc_o);
|
||||
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
|
||||
// Partition sO to match the accumulator partitioning
|
||||
auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
|
||||
// auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
|
||||
Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
|
||||
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
// sO has the same size as sQ, so we don't need to sync here.
|
||||
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
|
||||
|
||||
copy(smem_thr_copy_O, taccOrO, taccOsO);
|
||||
|
||||
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.o_row_stride, _1{}));
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
|
||||
auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
|
||||
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
Tensor tOrO = make_tensor<Element>(shape(tOgO));
|
||||
copy(gmem_thr_copy_O, tOsO, tOrO);
|
||||
|
||||
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
|
||||
static_assert(decltype(size<0>(taccOcO))::value == 4);
|
||||
// Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
|
||||
Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
|
||||
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
|
||||
if (get<1>(taccOcO_row(0)) == 0) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(lse); ++mi) {
|
||||
const int row = get<0>(taccOcO_row(mi));
|
||||
if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
|
||||
}
|
||||
}
|
||||
|
||||
// Construct identity layout for sO
|
||||
Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
// Repeat the partitioning with identity layouts
|
||||
Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
|
||||
if (!Is_even_K) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
|
||||
}
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
inline __device__ void compute_attn(const Params ¶ms) {
|
||||
const int m_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
|
||||
// We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
|
||||
// them to have the same number of threads or have to traverse the attention matrix
|
||||
// in the same order.
|
||||
// In the Philox RNG, we use the offset to store the batch, head, and the lane id
|
||||
// (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
|
||||
// the attention matrix. This way, as long as we have the batch, head, and the location of
|
||||
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
|
||||
|
||||
flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace flash
|
||||
@@ -0,0 +1,251 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "static_switch.h"
|
||||
#include "flash.h"
|
||||
#include "flash_fwd_kernel.h"
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
|
||||
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
|
||||
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
|
||||
void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr size_t smem_size = Kernel_traits::kSmemSize;
|
||||
// printf("smem_size = %d\n", smem_size);
|
||||
|
||||
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
|
||||
// https://github.com/kokkos/kokkos-kernels/issues/349
|
||||
// https://github.com/HazyResearch/flash-attention/issues/21
|
||||
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid(num_m_block, params.b, params.h);
|
||||
// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
|
||||
// for cu_seqlens_q as well.
|
||||
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
const bool return_softmax = params.p_ptr != nullptr;
|
||||
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
int ctas_per_sm;
|
||||
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 32;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 64;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
|
||||
// Using block size (64 x 256) is 27% slower for seqlen=2k
|
||||
// Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 96;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
if (is_sm8x) {
|
||||
if constexpr(!Is_causal) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// These two are always slower
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 128;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
// and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
|
||||
if (is_sm8x) {
|
||||
if constexpr(!Is_causal) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// 1st ones are good for H100, A100
|
||||
// 2nd one is good for A6000 bc we get slightly better occupancy
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 160;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For A100, H100, 128 x 32 is the fastest.
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
// and 128 x 64 with 8 warps is the fastest for non-causal.
|
||||
if (is_sm8x) {
|
||||
if constexpr(!Is_causal) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 192;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 224;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
|
||||
// If we have N = 32, there are only 1024 elements to load at once, where each load
|
||||
// is 8 elements. This means we can only use 128 threads and not 256 threads.
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 256;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_sm, max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
|
||||
status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For A100, we want to run with 128 x 64 (128KB smem).
|
||||
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
|
||||
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// 64 KB
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// 96 KB
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -1,211 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <vector>
|
||||
|
||||
#ifdef OLD_GENERATOR_PATH
|
||||
#include <ATen/CUDAGeneratorImpl.h>
|
||||
#else
|
||||
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/detail/UnpackRaw.cuh>
|
||||
|
||||
#include <fmha_utils.h>
|
||||
|
||||
|
||||
constexpr int TOTAL_DIM = 0;
|
||||
constexpr int H_DIM = 1;
|
||||
constexpr int D_DIM = 2;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Qkv_params {
|
||||
// The QKV matrices.
|
||||
void *__restrict__ q_ptr;
|
||||
void *__restrict__ k_ptr;
|
||||
void *__restrict__ v_ptr;
|
||||
|
||||
// The stride between rows of the Q, K and V matrices.
|
||||
// size_t qkv_stride_in_elts;
|
||||
// size_t qkv_stride_in_bytes;
|
||||
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
|
||||
// The code probably won't work for arrays larger than 2GB.
|
||||
uint32_t q_row_stride_in_elts;
|
||||
uint32_t k_row_stride_in_elts;
|
||||
uint32_t v_row_stride_in_elts;
|
||||
uint32_t q_head_stride_in_elts;
|
||||
uint32_t k_head_stride_in_elts;
|
||||
uint32_t v_head_stride_in_elts;
|
||||
|
||||
// The number of heads.
|
||||
int h;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct FMHA_fprop_params : public Qkv_params {
|
||||
|
||||
// The O matrix (output).
|
||||
void * __restrict__ o_ptr;
|
||||
|
||||
// The stride between rows of O.
|
||||
// size_t o_stride_in_elts;
|
||||
// size_t o_stride_in_bytes;
|
||||
uint32_t o_row_stride_in_elts;
|
||||
uint32_t o_head_stride_in_elts;
|
||||
uint32_t o_tmp_row_stride_in_elts;
|
||||
uint32_t o_tmp_head_stride_in_elts;
|
||||
|
||||
// The pointer to the O_tmp matrix, which holds O intermediate value during
|
||||
// the loop;
|
||||
void *__restrict__ o_tmp_ptr;
|
||||
|
||||
// The pointer to the S matrix.
|
||||
void * __restrict__ s_ptr;
|
||||
// The stride between rows of the S matrix.
|
||||
// int64_t s_stride_in_bytes;
|
||||
uint32_t s_stride_in_bytes;
|
||||
|
||||
// The pointer to the softmax sum.
|
||||
void * __restrict__ softmax_lse_ptr;
|
||||
|
||||
// The dimensions.
|
||||
int b, seqlen_q, seqlen_k, d;
|
||||
|
||||
// The scaling factors for the kernel.
|
||||
float scale_bmm1f;
|
||||
uint32_t scale_bmm1;
|
||||
|
||||
// array of length b+1 holding starting offset of each sequence.
|
||||
int * __restrict__ cu_seqlens_q;
|
||||
int * __restrict__ cu_seqlens_k;
|
||||
|
||||
int *__restrict__ blockmask;
|
||||
|
||||
// The dropout probability (probability of keeping an activation).
|
||||
float p_dropout;
|
||||
uint32_t p_dropout_in_uint;
|
||||
uint16_t p_dropout_in_uint16_t;
|
||||
|
||||
// Scale factor of 1 / (1 - p_dropout).
|
||||
float rp_dropout;
|
||||
float scale_bmm1_rp_dropout;
|
||||
|
||||
// Scale factor of 1 / (1 - p_dropout), in half2.
|
||||
uint32_t scale_dropout;
|
||||
|
||||
// Random state.
|
||||
at::PhiloxCudaState philox_args;
|
||||
// Pointer to the RNG seed (idx 0) and offset (idx 1).
|
||||
uint64_t * rng_state;
|
||||
|
||||
bool is_bf16;
|
||||
bool is_causal;
|
||||
|
||||
int num_splits; // How many SMs per attention matrix.
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct FMHA_dgrad_params : public FMHA_fprop_params {
|
||||
|
||||
// The dQKV matrices.
|
||||
void *__restrict__ dq_ptr;
|
||||
void *__restrict__ dk_ptr;
|
||||
void *__restrict__ dv_ptr;
|
||||
|
||||
// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q dimension
|
||||
// void *__restrict__ dk_accum_ptr;
|
||||
// void *__restrict__ dv_accum_ptr;
|
||||
|
||||
// The stride between rows of the dQ, dK and dV matrices.
|
||||
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
|
||||
// The code probably won't work for arrays larger than 2GB.
|
||||
uint32_t dq_row_stride_in_elts;
|
||||
uint32_t dk_row_stride_in_elts;
|
||||
uint32_t dv_row_stride_in_elts;
|
||||
uint32_t dq_head_stride_in_elts;
|
||||
uint32_t dk_head_stride_in_elts;
|
||||
uint32_t dv_head_stride_in_elts;
|
||||
|
||||
// The dO matrix. We assume it is contiguous.
|
||||
void * __restrict__ do_ptr;
|
||||
|
||||
// The pointer to the softmax d sum.
|
||||
void * __restrict__ dsoftmax_sum;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_params>
|
||||
struct Launch_params{
|
||||
Launch_params(cudaDeviceProp * props_,
|
||||
cudaStream_t stream_,
|
||||
bool is_dropout_,
|
||||
bool return_softmax_)
|
||||
: elts_per_thread(0)
|
||||
, props(props_)
|
||||
, stream(stream_)
|
||||
, is_dropout(is_dropout_)
|
||||
, return_softmax(return_softmax_) {
|
||||
}
|
||||
|
||||
size_t elts_per_thread;
|
||||
|
||||
cudaDeviceProp * props;
|
||||
|
||||
cudaStream_t stream;
|
||||
|
||||
bool is_dropout;
|
||||
bool return_softmax;
|
||||
|
||||
Kernel_params params;
|
||||
int num_full_heads;
|
||||
int num_main_groups;
|
||||
int heads_last_wave;
|
||||
int main_steps;
|
||||
int rest_steps;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params);
|
||||
void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params);
|
||||
void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params);
|
||||
|
||||
void run_fmha_bwd_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure);
|
||||
void run_fmha_bwd_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure);
|
||||
void run_fmha_bwd_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure);
|
||||
|
||||
void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
|
||||
|
||||
void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream);
|
||||
@@ -1,451 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmha/utils.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include <cutlass/arch/mma.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >
|
||||
struct Fragment_base_ {
|
||||
|
||||
// The data type.
|
||||
using Data_type = Data_type_;
|
||||
// default input type
|
||||
using Input_type_ = Data_type_;
|
||||
// Does it store the array of elements.
|
||||
static constexpr bool HAS_ELTS = BITS_PER_ELT_ >= 8;
|
||||
// The number of elements.
|
||||
static constexpr int NUM_ELTS = NUM_ELTS_;
|
||||
// The size of element in bits.
|
||||
static constexpr int BITS_PER_ELT = BITS_PER_ELT_;
|
||||
// The size of byte of a single register.
|
||||
static constexpr int BYTES_PER_REG = 4;
|
||||
// The size in bits.
|
||||
static constexpr int BITS_PER_REG = BYTES_PER_REG * 8;
|
||||
// The number of registers needed to store the fragment.
|
||||
static constexpr int NUM_REGS = DivUpConstexpr(NUM_ELTS * BITS_PER_ELT, BITS_PER_REG);
|
||||
// The size in bytes (as returned by sizeof(Fragment_base<>).
|
||||
static constexpr int SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG;
|
||||
// The alignment.
|
||||
static constexpr int ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : MinConstexpr(NUM_REGS * BYTES_PER_REG, 16);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
// The type of the elements.
|
||||
typename Data_type_,
|
||||
// The number of elements.
|
||||
int NUM_ELTS_,
|
||||
// The alignment if you want to force a value -- use 0 otherwise.
|
||||
int ALIGNMENT_ = 0,
|
||||
// The base class.
|
||||
typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
|
||||
>
|
||||
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
|
||||
|
||||
// The size of a load/store.
|
||||
static constexpr int BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t);
|
||||
|
||||
// Clear the fragment. Using PTX in that code seems to produce better SASS...
|
||||
inline __device__ void clear() {
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
|
||||
asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : );
|
||||
}
|
||||
}
|
||||
|
||||
// Immutable access to a register.
|
||||
inline __device__ const uint32_t& reg(int ii) const {
|
||||
return this->regs_[ii];
|
||||
}
|
||||
|
||||
// Mutable access to a register.
|
||||
inline __device__ uint32_t& reg(int ii) {
|
||||
return this->regs_[ii];
|
||||
}
|
||||
|
||||
uint32_t regs_[Base_::NUM_REGS];
|
||||
|
||||
// Immutable access to the elements.
|
||||
inline __device__ const Data_type_& elt(int ii) const {
|
||||
return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];
|
||||
}
|
||||
|
||||
// Mutable access to the elements.
|
||||
inline __device__ Data_type_& elt(int ii) {
|
||||
return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii];
|
||||
}
|
||||
|
||||
// Immutable access to the elements with a cast.
|
||||
template< typename Cast_type >
|
||||
inline __device__ const Cast_type& elt_as(int ii) const {
|
||||
return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];
|
||||
}
|
||||
|
||||
// Mutable access to the elements.
|
||||
template< typename Cast_type >
|
||||
inline __device__ Cast_type& elt_as(int ii) {
|
||||
return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];
|
||||
}
|
||||
|
||||
// Add another fragment.
|
||||
inline __device__ void add(const Fragment &other) {
|
||||
// TODO (TD 2022-04-09): Shouldn't this be NUM_REGS instead of NUM_ELTS?
|
||||
// Also are we doing int addition or __half2 addition?
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < NUM_ELTS_; ++ii ) {
|
||||
this->elt(ii) += other.elt(ii);
|
||||
}
|
||||
}
|
||||
|
||||
// Multiply by another fragment.
|
||||
inline __device__ void hmul(const Fragment &other) {
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
|
||||
this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename elem_type>
|
||||
inline __device__ void hrelu_() {
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
|
||||
this->reg(ii) = fmha::hrelu2<elem_type>(this->reg(ii));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Layout >
|
||||
struct Fragment_a : public Fragment<uint16_t, 8> {
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Layout >
|
||||
struct Fragment_b : public Fragment<uint16_t, 8> {
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Fragment_accumulator : public Fragment<float, 8> {
|
||||
|
||||
// The base class.
|
||||
using Base = Fragment<float, 8>;
|
||||
|
||||
// Add two fragments.
|
||||
template< typename Other_fragment_ >
|
||||
inline __device__ void add(const Other_fragment_ &other) {
|
||||
for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
|
||||
this->elt(ii) = this->elt(ii) + other.elt(ii);
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void mul_(const float other) {
|
||||
for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
|
||||
this->elt(ii) *= other;
|
||||
}
|
||||
}
|
||||
|
||||
// Do the HMMA.
|
||||
template< typename Layout_a, typename Layout_b >
|
||||
inline __device__ void mma(const Fragment_a<Layout_a> &a,
|
||||
const Fragment_b<Layout_b> &b) {
|
||||
asm volatile( \
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
|
||||
" {%0, %1, %2, %3}, \n" \
|
||||
" {%4, %5, %6, %7}, \n" \
|
||||
" {%8, %9}, \n" \
|
||||
" {%0, %1, %2, %3}; \n" \
|
||||
: "+f"( elt(0)), "+f"( elt(1)), "+f"( elt(2)), "+f"( elt(3))
|
||||
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
|
||||
, "r"(b.reg(0)), "r"(b.reg(1)));
|
||||
asm volatile( \
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
|
||||
" {%0, %1, %2, %3}, \n" \
|
||||
" {%4, %5, %6, %7}, \n" \
|
||||
" {%8, %9}, \n" \
|
||||
" {%0, %1, %2, %3}; \n" \
|
||||
: "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7))
|
||||
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
|
||||
, "r"(b.reg(2)), "r"(b.reg(3)));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Fragment, int M, int N >
|
||||
inline __device__ void clear(Fragment (&frag)[M][N]) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < M; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < N; ++ni ) {
|
||||
frag[mi][ni].clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Accumulator_type, int WARPS_K >
|
||||
struct Clear_accumulator {
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< int WARPS_K >
|
||||
struct Clear_accumulator<float, WARPS_K> {
|
||||
template< typename Acc, int M, int N >
|
||||
static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
|
||||
fmha::clear(acc);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Acc, typename A, typename B, int M, int N>
|
||||
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
|
||||
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < M; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < N; ++ni ) {
|
||||
acc[mi][ni].mma(a[mi], b[ni]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Statically maps half types => cutlass data types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename Type_>
|
||||
struct HalfTypeToCutlassType { using Type = Type_; };
|
||||
|
||||
/// Statically maps __half => cutlass::half_t
|
||||
template <> struct HalfTypeToCutlassType<__half> {
|
||||
using Type = cutlass::half_t;
|
||||
};
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
template <> struct HalfTypeToCutlassType<__nv_bfloat16> {
|
||||
using Type = cutlass::bfloat16_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename elem_type, typename Acc, typename A, typename B, int M, int N>
|
||||
inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
|
||||
using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
#else
|
||||
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
|
||||
// TD [2022-06-02] We don't support Volta (SM70) yet.
|
||||
assert(0);
|
||||
#endif
|
||||
using Element = typename HalfTypeToCutlassType<elem_type>::Type;
|
||||
using ElementC = float;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
|
||||
using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
|
||||
Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
|
||||
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;
|
||||
|
||||
constexpr int kIters = Shape::kK / InstructionShape::kK;
|
||||
// using FragmentA = typename WarpMma::FragmentA;
|
||||
// using FragmentB = typename WarpMma::FragmentB;
|
||||
using FragmentA = typename WarpMma::ArchMmaOperator::FragmentA;
|
||||
using FragmentB = typename WarpMma::ArchMmaOperator::FragmentB;
|
||||
using FragmentC = typename WarpMma::FragmentC;
|
||||
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y) == 0) {
|
||||
// printf("FragmentA::kStorageElements = %d\n", FragmentA::kStorageElements);
|
||||
// printf("Archmma::FragmentA::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentA::kStorageElements);
|
||||
// printf("FragmentB::kStorageElements = %d\n", FragmentB::kStorageElements);
|
||||
// printf("Archmma::FragmentB::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentB::kStorageElements);
|
||||
// printf("FragmentC::kStorageElements = %d\n", FragmentC::kStorageElements);
|
||||
// printf("Archmma::FragmentC::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentC::kStorageElements);
|
||||
// }
|
||||
|
||||
// static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
|
||||
// static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
|
||||
static_assert(FragmentA::kStorageElements * kIters == a[0].NUM_REGS);
|
||||
static_assert(FragmentB::kStorageElements * kIters * 16 / InstructionShape::kN == b[0].NUM_REGS);
|
||||
static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS);
|
||||
// const FragmentA a_cl = reinterpret_cast<const FragmentA (&)>(a);
|
||||
// const FragmentB b_cl = reinterpret_cast<const FragmentB (&)>(b);
|
||||
FragmentC c_cl = reinterpret_cast<FragmentC (&)>(acc);
|
||||
FragmentA a_cl[kIters][M];
|
||||
FragmentA b_cl[kIters][N];
|
||||
constexpr int kRegs = InstructionShape::kK == 16 ? 4 : 2;
|
||||
#pragma unroll
|
||||
for (int iter = 0; iter < kIters; iter++) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < M; mi++) {
|
||||
uint32_t *a_ptr = a_cl[iter][mi].raw_data();
|
||||
#pragma unroll
|
||||
for (int ki = 0; ki < kRegs; ki++) {
|
||||
a_ptr[ki] = a[mi].regs_[iter * kRegs + ki];
|
||||
}
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int iter = 0; iter < kIters; iter++) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < N; ni++) {
|
||||
uint32_t *b_ptr = b_cl[iter][ni].raw_data();
|
||||
#pragma unroll
|
||||
for (int ki = 0; ki < kRegs; ki++) {
|
||||
// b_ptr[ki] = b[ni].regs_[iter * kRegs + ki];
|
||||
// TD [2022-06-02] For some reason the order for frag_b is different.
|
||||
b_ptr[ki] = b[ni].regs_[InstructionShape::kK == 16 ? iter * kRegs + ki : ki * kRegs + iter];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
WarpMma mma_op;
|
||||
// mma_op(c_cl, a_cl, b_cl, c_cl);
|
||||
#pragma unroll
|
||||
for (int iter = 0; iter < kIters; iter++) {
|
||||
mma_op(c_cl, reinterpret_cast<const typename WarpMma::FragmentA (&)>(a_cl[iter]),
|
||||
reinterpret_cast<const typename WarpMma::FragmentB (&)>(b_cl[iter]), c_cl);
|
||||
}
|
||||
|
||||
// The modified c_cl is not copied back into acc, idk why
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < M; mi++) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < N; ni++) {
|
||||
#pragma unroll
|
||||
for (int i =0; i < 8; i++) {
|
||||
acc[mi][ni].elt(i) = c_cl[mi * N * 8 + ni * 8 + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
// The number of rows in the CTA tile.
|
||||
int M_,
|
||||
// The number of cols in the CTA tile.
|
||||
int N_,
|
||||
// The number of elements in the the K dimension of the GEMM loop.
|
||||
int K_,
|
||||
// The number of rows of warps.
|
||||
int WARPS_M_,
|
||||
// The number of cols of warps.
|
||||
int WARPS_N_,
|
||||
// The number of warps in the K dimension of the GEMM loop.
|
||||
int WARPS_K_>
|
||||
struct Cta_tile_ {
|
||||
|
||||
static constexpr int M = M_, N = N_, K = K_;
|
||||
// The number of warps.
|
||||
static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_;
|
||||
// The number of warps per CTA.
|
||||
static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K;
|
||||
// The number of threads per warp.
|
||||
static constexpr int THREADS_PER_WARP = 32;
|
||||
// The number of threads per CTA.
|
||||
static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Cta_tile>
|
||||
struct Hmma_tile {
|
||||
// The number of elements computed with a single warp-MMA.
|
||||
static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16;
|
||||
|
||||
// The number of elements computed with a single CTA-MMA.
|
||||
static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
|
||||
N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
|
||||
K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K;
|
||||
|
||||
// The number of MMAs needed to compute the GEMM.
|
||||
static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA),
|
||||
MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA),
|
||||
MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA);
|
||||
|
||||
// // The number of elements computed per warp.
|
||||
// static constexpr int M_PER_WARP = MMAS_M * M_PER_MMA,
|
||||
// N_PER_WARP = MMAS_N * N_PER_MMA,
|
||||
// K_PER_WARP = MMAS_K * K_PER_MMA;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using A_type = uint16_t;
|
||||
using B_type = uint16_t;
|
||||
using C_type = uint16_t;
|
||||
using Accumulator_type = float;
|
||||
using Epilogue_type = float;
|
||||
|
||||
constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
|
||||
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
|
||||
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
|
||||
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Cta_tile_>
|
||||
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
|
||||
Cta_tile_::N,
|
||||
Next_power_of_two<Cta_tile_::K>::VALUE,
|
||||
Cta_tile_::WARPS_M,
|
||||
Cta_tile_::WARPS_N,
|
||||
Cta_tile_::WARPS_K>;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
@@ -1,555 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
#include <fmha/utils.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
template<
|
||||
// The dimensions of the tile computed by the CTA.
|
||||
typename Cta_tile_,
|
||||
// The number of bits per element.
|
||||
int BITS_PER_ELEMENT,
|
||||
// The number of rows of Q, K or V loaded by this tile.
|
||||
int ROWS_,
|
||||
// The number of columns.
|
||||
int COLS,
|
||||
int BYTES_PER_LDGS_ = 16
|
||||
>
|
||||
struct Gmem_tile_qkv {
|
||||
|
||||
using Cta_tile = Cta_tile_;
|
||||
|
||||
static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8;
|
||||
// The size of each LDG.
|
||||
static constexpr int BYTES_PER_LDG = BYTES_PER_LDGS_;
|
||||
// The size of a row in bytes.
|
||||
static constexpr int BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8;
|
||||
|
||||
// The number of threads to load a "row" of the matrix.
|
||||
static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG;
|
||||
|
||||
static constexpr int ROWS = ROWS_;
|
||||
// The number of "rows" loaded per LDG.
|
||||
static constexpr int ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
|
||||
// The number of LDGs needed to load a chunk of the Q matrix.
|
||||
static constexpr int LDGS = DivUpConstexpr(ROWS, ROWS_PER_LDG);
|
||||
|
||||
// Ctor.
|
||||
template< typename BInfo >
|
||||
inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts,
|
||||
const uint32_t head_stride_in_elts, const int headdim,
|
||||
const BInfo &binfo, const int tidx, bool use_seqlen_q)
|
||||
: row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
|
||||
, actual_seqlen(use_seqlen_q ? binfo.actual_seqlen_q : binfo.actual_seqlen_k)
|
||||
, ptr(reinterpret_cast<char *>(ptr_))
|
||||
, tidx_(tidx)
|
||||
, col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_LDG / BYTES_PER_ELEMENT) < headdim) {
|
||||
|
||||
// Compute the position in the sequence (within the CTA for the moment).
|
||||
int row = tidx / THREADS_PER_ROW;
|
||||
// Compute the position of the thread in the row.
|
||||
int col = tidx % THREADS_PER_ROW;
|
||||
|
||||
// Store the row as we need it to disable the loads.
|
||||
// TD [2022-04-16]: To minimize registers, we'll recompute row_ instead of storing it
|
||||
// row_ = row;
|
||||
|
||||
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
|
||||
// int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
|
||||
uint32_t row_offset = (uint32_t)(((use_seqlen_q ? binfo.sum_s_q : binfo.sum_s_k) + row) * row_stride_in_bytes);
|
||||
// Add the block index.
|
||||
// row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
|
||||
row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
|
||||
|
||||
// Assemble the final pointer.
|
||||
ptr += row_offset + col * BYTES_PER_LDG;
|
||||
}
|
||||
|
||||
// Store data to shared memory.
|
||||
template< typename Smem_tile >
|
||||
inline __device__ void commit(Smem_tile &smem_tile) {
|
||||
smem_tile.store(fetch_);
|
||||
}
|
||||
|
||||
inline __device__ void load() {
|
||||
int row_ = tidx_ / THREADS_PER_ROW;
|
||||
const void *ptrs[LDGS];
|
||||
uint32_t preds[LDGS];
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < LDGS; ++ii ) {
|
||||
// ptrs[ii] = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
|
||||
ptrs[ii] = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
|
||||
preds[ii] = col_predicate && ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
|
||||
fetch_[ii] = make_uint4(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
// not packing predicates removes restrictions (e.g. FP16 384, 4 warps)
|
||||
Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < LDGS; ++ii ) {
|
||||
fct.load(ii, preds[ii]);
|
||||
}
|
||||
}
|
||||
|
||||
// Store data to memory.
|
||||
inline __device__ void store(const uint4 (&data)[LDGS]) {
|
||||
int row_ = tidx_ / THREADS_PER_ROW;
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < LDGS; ++ii ) {
|
||||
// char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
|
||||
char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
|
||||
if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
|
||||
fmha::stg(ptr_, data[ii]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void move(const int steps = 1) {
|
||||
// ptr += (int64_t)ROWS * row_stride_in_bytes * steps;
|
||||
ptr += (uint32_t)ROWS * row_stride_in_bytes * steps;
|
||||
actual_seqlen -= ROWS * steps;
|
||||
}
|
||||
|
||||
// The stride between rows for the QKV matrice.
|
||||
// int64_t row_stride_in_bytes;
|
||||
const uint32_t row_stride_in_bytes;
|
||||
// The pointer.
|
||||
char *ptr;
|
||||
// The fetch registers.
|
||||
uint4 fetch_[LDGS];
|
||||
// Keep track of the row the thread is processing as we move the tile.
|
||||
// int row_;
|
||||
const int tidx_;
|
||||
// The length of the sequence loaded by that memory tile.
|
||||
int actual_seqlen;
|
||||
const bool col_predicate;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
typename Cta_tile,
|
||||
int BYTES_PER_ELEMENT = 2
|
||||
>
|
||||
struct Gmem_tile_o {
|
||||
|
||||
static_assert(BYTES_PER_ELEMENT == 2 || BYTES_PER_ELEMENT == 4);
|
||||
|
||||
// The mma tile.
|
||||
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
||||
|
||||
// The size of each element.
|
||||
// static constexpr int BYTES_PER_ELEMENT = 2;
|
||||
// The size of each STG.
|
||||
static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 4;
|
||||
static constexpr int COLS = Cta_tile::N;
|
||||
// The size of a row in bytes.
|
||||
static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT;
|
||||
|
||||
// The number of threads to store a "row" of the matrix.
|
||||
static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG;
|
||||
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
|
||||
static constexpr int ROWS = Cta_tile::M;
|
||||
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
|
||||
static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA;
|
||||
// The number of outter loop for the stores.
|
||||
static constexpr int LOOPS = ROWS / ROWS_PER_LOOP;
|
||||
|
||||
// The number of "rows" stored per STG.
|
||||
static constexpr int ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
|
||||
// Do we have to guard against partial writes/reads.
|
||||
static constexpr bool HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0;
|
||||
// The number of STGs needed to store a chunk of the Q matrix.
|
||||
static constexpr int STGS_PER_LOOP = DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_STG);
|
||||
// The number of STGs needed to store a chunk of the Q matrix in total.
|
||||
static constexpr int STGS = STGS_PER_LOOP * LOOPS;
|
||||
|
||||
// Ctor.
|
||||
template<typename BInfo>
|
||||
// inline __device__ Gmem_tile_o(void *ptr, const size_t row_stride_in_elts, const BInfo &binfo, const int tidx)
|
||||
inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts,
|
||||
const uint32_t head_stride_in_elts, const int headdim,
|
||||
const BInfo &binfo, const int tidx)
|
||||
: row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
|
||||
, actual_seqlen_q(binfo.actual_seqlen_q)
|
||||
, ptr_(reinterpret_cast<char *>(ptr))
|
||||
, tidx_(tidx)
|
||||
, col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_STG / BYTES_PER_ELEMENT) < headdim) {
|
||||
|
||||
// Compute the position in the sequence (within the CTA for the moment).
|
||||
int row = tidx / THREADS_PER_ROW;
|
||||
// Compute the position of the thread in the row.
|
||||
int col = tidx % THREADS_PER_ROW;
|
||||
|
||||
// Store the row as we need it to disable loads.
|
||||
// row_ = row;
|
||||
|
||||
// The row offset in the batched GEMM.
|
||||
// int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
|
||||
uint32_t row_offset = (uint32_t)((binfo.sum_s_q + row) * row_stride_in_bytes);
|
||||
row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
|
||||
// Assemble the final pointer.
|
||||
ptr_ += row_offset + col * BYTES_PER_STG;
|
||||
|
||||
// Is that thread active on the last STG?
|
||||
if( HAS_INCOMPLETE_STG ) {
|
||||
is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;
|
||||
}
|
||||
}
|
||||
|
||||
// Store data to global memory.
|
||||
template<typename elem_type=__half>
|
||||
inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
|
||||
int row_ = tidx_ / THREADS_PER_ROW;
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
|
||||
int jj = mi * STGS_PER_LOOP + ii;
|
||||
if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (BYTES_PER_ELEMENT == 4) {
|
||||
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
|
||||
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, src[ii]);
|
||||
}
|
||||
} else if (BYTES_PER_ELEMENT == 2) {
|
||||
float x = reinterpret_cast<const float &>(src[ii].x);
|
||||
float y = reinterpret_cast<const float &>(src[ii].y);
|
||||
float z = reinterpret_cast<const float &>(src[ii].z);
|
||||
float w = reinterpret_cast<const float &>(src[ii].w);
|
||||
uint2 out = fmha::float4_pack<elem_type>(x, y, z, w);
|
||||
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
|
||||
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store data to global memory with atomicAdd.
|
||||
inline __device__ void atomic_add(const uint4 (&src)[STGS_PER_LOOP], int mi) {
|
||||
static_assert(BYTES_PER_ELEMENT == 4); // Only do atomic add on floats
|
||||
int row_ = tidx_ / THREADS_PER_ROW;
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
|
||||
int jj = mi * STGS_PER_LOOP + ii;
|
||||
if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
|
||||
float *ptr_ = reinterpret_cast<float *>(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
atomicAdd(ptr_ + jj, reinterpret_cast<const float(&)[4]>(src[ii])[jj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load data from global memory.
|
||||
inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) {
|
||||
static_assert(BYTES_PER_ELEMENT == 4);
|
||||
int row_ = tidx_ / THREADS_PER_ROW;
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
|
||||
int jj = mi * STGS_PER_LOOP + ii;
|
||||
if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
|
||||
fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void move(const int steps = 1) {
|
||||
// row_ += ROWS * steps;
|
||||
// ptr_ += (int64_t)ROWS * row_stride_in_bytes * steps;
|
||||
ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps;
|
||||
actual_seqlen_q -= ROWS * steps;
|
||||
}
|
||||
|
||||
// The stride between rows for the QKV matrice.
|
||||
// int64_t row_stride_in_bytes;
|
||||
const uint32_t row_stride_in_bytes;
|
||||
// The pointer.
|
||||
char *ptr_;
|
||||
// Is the thread active for the last STG?
|
||||
int is_active_for_last_stg_;
|
||||
// The length of the sequence loaded by that memory tile.
|
||||
int actual_seqlen_q;
|
||||
const int tidx_;
|
||||
const bool col_predicate;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Cta_tile, int BYTES_PER_ELEMENT >
|
||||
struct Gmem_tile_mma_sd {
|
||||
|
||||
// The mma tile.
|
||||
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
||||
|
||||
// Each STG stores 8 elements.
|
||||
static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 8;
|
||||
// The number of MMAs in the M dimension.
|
||||
static constexpr int MMAS_M = Mma_tile::MMAS_M;
|
||||
// The number of MMAs in the N dimension.
|
||||
static constexpr int MMAS_N = Mma_tile::MMAS_N;
|
||||
// The number of rows computed per MMA per thread block.
|
||||
static constexpr int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA;
|
||||
// The number of cols computed per MMA per thread block.
|
||||
static constexpr int N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA;
|
||||
// The number of threads per block.
|
||||
static constexpr int THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA;
|
||||
// The size of each row in bytes. I.e. how many bytes are stored per STG.
|
||||
static constexpr int BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG;
|
||||
// The distance between elements stored per loop (in bytes).
|
||||
static constexpr int LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW;
|
||||
|
||||
// The type of elements stored per STG.
|
||||
using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;
|
||||
|
||||
// Ctor.
|
||||
template<typename Params>
|
||||
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params ¶ms, const int bidb, const int bidh, const int tidx)
|
||||
: ptr_(static_cast<char *>(ptr)) {
|
||||
|
||||
// The block index.
|
||||
// size_t bidx = bidb * params.h + bidh;
|
||||
uint32_t bidx = bidb * params.h + bidh;
|
||||
|
||||
// The distance between two blocks (in bytes).
|
||||
// const size_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
|
||||
const uint32_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
|
||||
// Set store location for each thread at the beginning of the loop
|
||||
ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG;
|
||||
}
|
||||
|
||||
// Store to global memory.
|
||||
inline __device__ void store(const Type &data, const int mi, const int ni) {
|
||||
// size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
|
||||
uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
|
||||
fmha::stg(ptr_ + offset, data);
|
||||
}
|
||||
|
||||
// Load from global memory.
|
||||
inline __device__ void load(Type &data, const int mi, const int ni) {
|
||||
// size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
|
||||
uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
|
||||
fmha::ldg(data, ptr_ + offset);
|
||||
}
|
||||
|
||||
// Move to the next tile.
|
||||
inline __device__ void move(const int steps = 1) {
|
||||
ptr_ += LOOP_STRIDE_BYTES * steps;
|
||||
}
|
||||
|
||||
// The pointer in global memory.
|
||||
char *ptr_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >
|
||||
struct Gmem_tile_mma_s : public Base {
|
||||
|
||||
// The number of mmas in the vertical dimension.
|
||||
static constexpr int M = Base::MMAS_M;
|
||||
// The number of mmas in the horizontal dimension.
|
||||
static constexpr int N = Base::MMAS_N;
|
||||
// The type of the vectors stored by each STG.
|
||||
using Type = typename Base::Type;
|
||||
|
||||
// Ctor.
|
||||
template< typename Params, typename Block_info >
|
||||
inline __device__ Gmem_tile_mma_s(const Params ¶ms, const Block_info& binfo, const int tidx)
|
||||
: Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
|
||||
}
|
||||
|
||||
// Store to global memory.
|
||||
template<typename Mask, typename Fragment>
|
||||
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < M; mi++ ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < N; ni++ ) {
|
||||
uint4 dst;
|
||||
dst.x = frag[ni][mi].reg(0);
|
||||
dst.y = frag[ni][mi].reg(2);
|
||||
dst.z = frag[ni][mi].reg(1);
|
||||
dst.w = frag[ni][mi].reg(3);
|
||||
if( mask.any_valid(mi, ni) ) {
|
||||
Base::store(dst, mi, ni);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load from global memory.
|
||||
template<typename Mask>
|
||||
inline __device__ void load(uint4 (®s)[M][N], const Mask &mask) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < M; mi++ ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < N; ni++ ) {
|
||||
regs[mi][ni] = make_uint4(0, 0, 0, 0);
|
||||
if( mask.any_valid(mi, ni) ) {
|
||||
Base::load(regs[mi][ni], mi, ni);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
// The dimensions of the tile computed by the CTA.
|
||||
typename Cta_tile
|
||||
>
|
||||
struct Gmem_summary_stats {
|
||||
|
||||
// The Mma tile.
|
||||
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
||||
|
||||
// The number of MMAs in M/N dimensions.
|
||||
static constexpr int MMAS_M = Mma_tile::MMAS_M;
|
||||
|
||||
// The size of each element.
|
||||
static constexpr int BYTES_PER_ELEMENT = 4;
|
||||
static constexpr int BYTES_PER_MMA = (Cta_tile::THREADS_PER_WARP / 4) * 2 * BYTES_PER_ELEMENT;
|
||||
static constexpr int ROWS = Cta_tile::M;
|
||||
|
||||
// Ctor.
|
||||
template<typename Params>
|
||||
inline __device__ Gmem_summary_stats(void *ptr, const Params ¶ms, const int tidx)
|
||||
: ptr_(reinterpret_cast<char *>(ptr)), tidx_(tidx) {
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The block index.
|
||||
// size_t bidx = bidb * params.h + bidh;
|
||||
uint32_t bidx = bidb * params.h + bidh;
|
||||
|
||||
// Extract the position in the warp.
|
||||
int warp = tidx / Cta_tile::THREADS_PER_WARP;
|
||||
int lane = tidx % Cta_tile::THREADS_PER_WARP;
|
||||
|
||||
// The distance between two blocks (in bytes).
|
||||
// size_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
|
||||
uint32_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
|
||||
|
||||
// Set store location for each thread at the beginning of the loop
|
||||
ptr_row_ = ptr_ + bidx * block_stride_bytes;
|
||||
ptr_ += bidx * block_stride_bytes + (lane / 4) * BYTES_PER_ELEMENT;
|
||||
}
|
||||
|
||||
// Store data to global memory.
|
||||
inline __device__ void store(const uint32_t (&data)[MMAS_M * 2]) {
|
||||
int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
|
||||
int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
|
||||
if ((warp == 0) && (lane % 4 == 0)) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < MMAS_M; ++mi) {
|
||||
// TODO: Not sure if it's right for MMAS_M > 1
|
||||
fmha::stg(ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT, data[mi * 2 + 0]);
|
||||
fmha::stg(ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT, data[mi * 2 + 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store data to global memory.
|
||||
inline __device__ void store_row(const uint32_t (&data)[MMAS_M], const int row) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < MMAS_M; ++mi) {
|
||||
// TODO: Not sure if it's right for MMAS_M > 1
|
||||
fmha::stg(ptr_row_ + mi * BYTES_PER_MMA + row * BYTES_PER_ELEMENT, data[mi]);
|
||||
}
|
||||
}
|
||||
|
||||
// Load from global memory.
|
||||
inline __device__ void load(uint32_t (&data)[MMAS_M * 2]) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < MMAS_M; ++mi) {
|
||||
// TODO: Not sure if it's right for MMAS_M > 1
|
||||
fmha::ldg(data[mi * 2 + 0], ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
|
||||
fmha::ldg(data[mi * 2 + 1], ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
|
||||
}
|
||||
}
|
||||
|
||||
// Load from global memory.
|
||||
inline __device__ void load_next(uint32_t (&data)[MMAS_M * 2], int move_steps=1) {
|
||||
char *ptr_next = ptr_ + move_steps * ROWS * BYTES_PER_ELEMENT;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < MMAS_M; ++mi) {
|
||||
// TODO: Not sure if it's right for MMAS_M > 1
|
||||
fmha::ldg(data[mi * 2 + 0], ptr_next + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
|
||||
fmha::ldg(data[mi * 2 + 1], ptr_next + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
|
||||
}
|
||||
}
|
||||
|
||||
// Store data to global memory.
|
||||
template <int N>
|
||||
inline __device__ void load_row(uint32_t (&data)[N], const int row[N]) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < N; ++ni) {
|
||||
fmha::ldg(data[ni], ptr_row_ + row[ni] * BYTES_PER_ELEMENT);
|
||||
}
|
||||
}
|
||||
|
||||
// Move the pointer to the next location.
|
||||
inline __device__ void move() {
|
||||
ptr_ += ROWS * BYTES_PER_ELEMENT;
|
||||
ptr_row_ += ROWS * BYTES_PER_ELEMENT;
|
||||
}
|
||||
|
||||
// Move the pointer to the next location.
|
||||
inline __device__ void move(const int steps) {
|
||||
ptr_ += ROWS * BYTES_PER_ELEMENT * steps;
|
||||
ptr_row_ += ROWS * BYTES_PER_ELEMENT * steps;
|
||||
}
|
||||
|
||||
// The pointer.
|
||||
char *ptr_;
|
||||
char *ptr_row_;
|
||||
const int tidx_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u, typename elem_type_=__half>
|
||||
struct FMHA_kernel_traits {
|
||||
|
||||
// The CTA description for the 1st GEMM.
|
||||
using Cta_tile_p = fmha::Cta_tile_extd<STEP, S, D, WARPS_M, WARPS_N, 1>;
|
||||
// The CTA description for the 2nd GEMM.
|
||||
using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;
|
||||
|
||||
// Do we use one buffer for K and V.
|
||||
static constexpr bool SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u;
|
||||
// Do we keep K in registers.
|
||||
static constexpr bool K_IN_REGS = (FLAGS & 0x10u) == 0u;
|
||||
// Do we keep V in registers.
|
||||
static constexpr bool V_IN_REGS = (FLAGS & 0x100u) == 0u;
|
||||
|
||||
// The global memory tile to load Q.
|
||||
using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
|
||||
|
||||
// The shared memory tile to swizzle Q.
|
||||
// using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;
|
||||
using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
|
||||
|
||||
// The global memory tile to load K.
|
||||
using Gmem_tile_k = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D>;
|
||||
// The shared memory tile to swizzle K.
|
||||
using Smem_tile_k = fmha::Smem_tile_b<Cta_tile_p, fmha::Col>;
|
||||
|
||||
// The global memory tile to load V.
|
||||
using Gmem_tile_v = fmha::Gmem_tile_qkv<Cta_tile_o, fmha::BITS_PER_ELEMENT_B, S, D>;
|
||||
// The shared memory tile to swizzle V.
|
||||
using Smem_tile_v = fmha::Smem_tile_v<Cta_tile_o>;
|
||||
|
||||
// The global memory tile to store O.
|
||||
using Gmem_tile_o = fmha::Gmem_tile_o<Cta_tile_o>;
|
||||
// The shared memory tile for O.
|
||||
using Smem_tile_o = fmha::Smem_tile_o<Cta_tile_o>;;
|
||||
|
||||
// The global memory tile to load/store S.
|
||||
using Gmem_tile_s = fmha::Gmem_tile_mma_s<Cta_tile_p>;
|
||||
|
||||
// The shared memory tile to transpose S.
|
||||
using Smem_tile_st = fmha::Smem_tile_mma_transposed<Cta_tile_p>;
|
||||
|
||||
using Gmem_tile_do = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
|
||||
|
||||
// // The global memory tile to store the accumulated dK and dV
|
||||
// // Hack: we set BYTES_PER_LDGS=32 to emulate the access pattern of dK and dV
|
||||
// // where there are 16 bits per lements and 16 bytes per load. In reality we won't
|
||||
// // be issue any load or store of size 32 bytes.
|
||||
// using Gmem_tile_dkv_accum = fmha::Gmem_tile_qkv<Cta_tile_o, 32, S, D, 32>;
|
||||
|
||||
// The global memory tile to store the softmax sum.
|
||||
using Gmem_softmax_sum = fmha::Gmem_summary_stats<Cta_tile_p>;
|
||||
|
||||
// The shared memory tile to store dp sum.
|
||||
using Smem_dp_sum = fmha::Smem_tile_dp_sum<Gmem_tile_q, 2>;
|
||||
|
||||
using elem_type = elem_type_;
|
||||
|
||||
// Make sure the number of threads match.
|
||||
static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");
|
||||
|
||||
// The number of threads.
|
||||
static constexpr int THREADS = Cta_tile_p::THREADS_PER_CTA;
|
||||
// Make sure the number of threads matches both CTAs.
|
||||
static_assert(THREADS == Cta_tile_o::THREADS_PER_CTA, "");
|
||||
|
||||
// The amount of shared memory needed to load Q and K.
|
||||
static constexpr int BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE;
|
||||
// The extra amount of shared memory needed to load V.
|
||||
static constexpr int BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE;
|
||||
// The amount of shared memory needed for Q, K and V..
|
||||
static constexpr int BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V;
|
||||
// The amount of shared memory needed to load Q and store O.
|
||||
static constexpr int BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE;
|
||||
|
||||
// The amount of shared memory needed for Q, K, V and O.
|
||||
static constexpr int BYTES_PER_SMEM = fmha::MaxConstexpr(BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO);
|
||||
// Make sure we have enough shared memory.
|
||||
static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, "");
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -1,90 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace fmha {
|
||||
|
||||
|
||||
template<typename Cta_tile, bool Is_causal=false>
|
||||
struct Mask {
|
||||
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
||||
|
||||
template<typename BInfo>
|
||||
__device__ Mask(const BInfo &binfo, int tidx, const int loop_step_idx_ = 0)
|
||||
: actual_seqlen_k(binfo.actual_seqlen_k - loop_step_idx_ * Cta_tile::N)
|
||||
, loop_step_idx(loop_step_idx_) {
|
||||
|
||||
const int warp = tidx / Cta_tile::THREADS_PER_WARP;
|
||||
const int lane = tidx % Cta_tile::THREADS_PER_WARP;
|
||||
|
||||
static_assert(Cta_tile::WARPS_K == 1, "");
|
||||
|
||||
// find the warp in the Cta tile
|
||||
const int warp_n = (warp / Cta_tile::WARPS_M);
|
||||
const int warp_m = (warp % Cta_tile::WARPS_M);
|
||||
// decompose warp into 8x4 tile
|
||||
const int quad = lane / 4;
|
||||
const int tid = (lane % 4) * 2;
|
||||
row = warp_m * 16 + quad;
|
||||
col = warp_n * 16 + tid;
|
||||
}
|
||||
|
||||
inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {
|
||||
|
||||
// ii and jj iterate over the 2x4 fragment
|
||||
// const int current_col = (Is_causal ? loop_step_idx * Cta_tile::N : 0) + ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
|
||||
const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
|
||||
const int current_row = row_offset + ii * 8;
|
||||
const bool col_valid = current_col < actual_seqlen_k;
|
||||
// const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen_k;
|
||||
//&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen_k;
|
||||
// bool all_valid = Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 1)) {
|
||||
// printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid);
|
||||
// }
|
||||
return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
|
||||
// return row_valid && col_valid;
|
||||
}
|
||||
|
||||
//BERT Mask: if upper left is invalid, none are valid
|
||||
inline __device__ bool any_valid(const int mi, const int ni) const {
|
||||
return is_valid(mi, ni, 0, 0) || is_valid(mi, ni, 1, 0);
|
||||
}
|
||||
|
||||
inline __device__ void load(const int it) {
|
||||
row_offset = it * Cta_tile::M + row;
|
||||
}
|
||||
int row_offset;
|
||||
|
||||
int row;
|
||||
int col;
|
||||
const int loop_step_idx;
|
||||
const int actual_seqlen_k;
|
||||
};
|
||||
|
||||
} // namespace fmha
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,607 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float apply_exp_(float x, float max) {
|
||||
return __expf(x - max);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float apply_exp2_(float x, float max) {
|
||||
return exp2f(x - max);
|
||||
// With fast-math, this produces the same PTX instruction as the assembly below
|
||||
// float diff = x - max;
|
||||
// float res;
|
||||
// asm ("ex2.approx.ftz.f32 %0, %1;\n\t" : "=f"(res) : "f"(diff));
|
||||
// return res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int COLS> struct ReadType {};
|
||||
template<> struct ReadType<4> { using T = float;};
|
||||
template<> struct ReadType<8> { using T = float2;};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Cta_tile, typename Kernel_traits>
|
||||
struct Smem_tile_reduce {
|
||||
// Helper class to distribute MMA tiles reduced over rows per warp over quads.
|
||||
|
||||
// The Mma tile.
|
||||
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
||||
|
||||
// The number of MMAs in M/N dimensions.
|
||||
static constexpr int MMAS_M = Mma_tile::MMAS_M;
|
||||
static constexpr int MMAS_N = Mma_tile::MMAS_N;
|
||||
|
||||
static constexpr int WARPS_M = Cta_tile::WARPS_M;
|
||||
static constexpr int WARPS_N = Cta_tile::WARPS_N;
|
||||
|
||||
|
||||
static constexpr int ROWS = WARPS_M * MMAS_M * 16;
|
||||
static constexpr int COLS = WARPS_N;
|
||||
static_assert(COLS == 4 || COLS == 8);
|
||||
static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
|
||||
static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
|
||||
static constexpr int ELTS_PER_TILE = ROWS * COLS;
|
||||
|
||||
static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW;
|
||||
// TD [2022-05-02]: No longer true if head_dim != 64
|
||||
// static_assert(THREADS_PER_GROUP == 16); // DEBUG
|
||||
static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP;
|
||||
static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
|
||||
static_assert(LOOPS == 1);
|
||||
|
||||
using read_t = typename ReadType<COLS>::T;
|
||||
|
||||
__device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
|
||||
|
||||
int lane = tidx % 32;
|
||||
int warp = tidx / 32;
|
||||
|
||||
int warp_m = warp % WARPS_M;
|
||||
int warp_n = warp / WARPS_M;
|
||||
|
||||
qid_ = lane % 4;
|
||||
int qp = lane / 4;
|
||||
|
||||
// Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.
|
||||
// This won't affect reading as we assume commutative reduction ops.
|
||||
const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
|
||||
smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
|
||||
smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
|
||||
smem_read_row_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qid_];
|
||||
|
||||
}
|
||||
|
||||
__device__ inline void store(float (&frag)[2 * MMAS_M]) {
|
||||
if( qid_ == 0 ) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
int offset = mi * 16 * WARPS_N;
|
||||
smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];
|
||||
smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
int offset = mi * 16 * 4;
|
||||
frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];
|
||||
frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void load_row(read_t (&frag)[MMAS_M], int row) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
int offset = mi * 16 * 4;
|
||||
frag[mi] = smem_read_row_[offset + 0 * 8 * 4 + row * 4];
|
||||
}
|
||||
}
|
||||
|
||||
int qid_;
|
||||
float *smem_write_;
|
||||
read_t *smem_read_;
|
||||
read_t *smem_read_row_;
|
||||
|
||||
};
|
||||
|
||||
|
||||
template<typename Cta_tile, typename Kernel_traits>
|
||||
struct Softmax_base {
|
||||
|
||||
// The Mma tile.
|
||||
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
||||
|
||||
// The number of MMAs in M/N dimensions.
|
||||
static constexpr int MMAS_M = Mma_tile::MMAS_M;
|
||||
static constexpr int MMAS_N = Mma_tile::MMAS_N;
|
||||
|
||||
// The number of groups of warp such that we have at most 4 warps writing consecutive elements.
|
||||
static constexpr int GROUPS = fmha::DivUpConstexpr(Cta_tile::WARPS_N, 4);
|
||||
// The number of elements that we are going to store per row.
|
||||
static constexpr int ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS;
|
||||
// The number of rows.
|
||||
static constexpr int ROWS = Cta_tile::M * GROUPS;
|
||||
// The total number of elements.
|
||||
static constexpr int ELEMENTS = ROWS * ELEMENTS_PER_ROW;
|
||||
|
||||
// Ctor.
|
||||
template<typename Params>
|
||||
inline __device__ Softmax_base(const Params ¶ms, void *smem, int tidx)
|
||||
: // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),
|
||||
smem_(reinterpret_cast<float *>(smem)), tidx_(tidx) {
|
||||
|
||||
// Move to the 1st mask loaded by the thread+ tidx;
|
||||
// packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);
|
||||
|
||||
// Extract the position in the warp.
|
||||
int warp = tidx / Cta_tile::THREADS_PER_WARP;
|
||||
int lane = tidx % Cta_tile::THREADS_PER_WARP;
|
||||
|
||||
// Decompose the warp index into M and N.
|
||||
int warp_m = warp % Cta_tile::WARPS_M;
|
||||
int warp_n = warp / Cta_tile::WARPS_M;
|
||||
|
||||
// Decompose the warp-n index into group/position-inside-the-group.
|
||||
int warp_g = warp_n / ELEMENTS_PER_ROW;
|
||||
int warp_i = warp_n % ELEMENTS_PER_ROW;
|
||||
|
||||
// The location written by the threads.
|
||||
int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;
|
||||
int write_col = warp_i;
|
||||
|
||||
// Assemble the write pointer.
|
||||
smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];
|
||||
|
||||
// Assemble the read pointer.
|
||||
smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
|
||||
}
|
||||
|
||||
template<bool zero=false, typename Mask>
|
||||
inline __device__ void apply_mask(const Mask &mask) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < 2; ++ii ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ++ni ) {
|
||||
#pragma unroll
|
||||
for( int jj = 0; jj < 4; ++jj ) {
|
||||
if( !mask.is_valid(mi, ni, ii, jj) ) {
|
||||
elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool max_in_base2=false, bool elt_in_base2=false>
|
||||
inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
|
||||
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||
// max * log_2(e)) This allows the compiler to use the ffma
|
||||
// instruction instead of fadd and fmul separately.
|
||||
constexpr float kLog2e = M_LOG2E;
|
||||
const float max_base2 = max_in_base2 ? max[mi] : max[mi] * kLog2e;
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
|
||||
// elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);
|
||||
elt_[mi][ni] = apply_exp2_(elt_in_base2 ? elt_[mi][ni] : elt_[mi][ni] * kLog2e,
|
||||
max_base2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool scale_max=true>
|
||||
inline __device__ void scale_apply_exp(const float (&max)[MMAS_M * 2], const float scale_) {
|
||||
const float max_scale = scale_max ? scale_ * M_LOG2E : M_LOG2E;
|
||||
const float scale = scale_ * M_LOG2E;
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
|
||||
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||
// max * log_2(e)) This allows the compiler to use the ffma
|
||||
// instruction instead of fadd and fmul separately.
|
||||
const float max_scaled = max[mi] * max_scale;
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
|
||||
elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * scale, max_scaled);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool max_in_base2=false>
|
||||
inline __device__ void apply_exp_col(const float (&max)[MMAS_N * 4]) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
|
||||
constexpr float kLog2e = M_LOG2E;
|
||||
const float max_base2 = max_in_base2 ? max[ni] : max[ni] * kLog2e;
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
|
||||
elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2);
|
||||
}
|
||||
}
|
||||
}
|
||||
// inline __device__ void apply_exp_col(const float (&max)[MMAS_N]) {
|
||||
// constexpr float kLog2e = M_LOG2E;
|
||||
// #pragma unroll
|
||||
// for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
|
||||
// float max_base2 = max_in_base2 ? max[ni / 4] : max[ni / 4] * kLog2e;
|
||||
// max_base2 = __shfl_sync(0xffffffff, max_base2, (ni % 4) * 8 + threadIdx.x % 8);
|
||||
// #pragma unroll
|
||||
// for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
|
||||
// elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
template <bool encode_dropout_in_sign_bit=false>
|
||||
inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) {
|
||||
// We encode the dropout pattern in the sign bit of the non-negative
|
||||
// softmax to distinguish from pre-existing zeros
|
||||
auto encode_dropout = [](bool keep, float val) {
|
||||
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
|
||||
};
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ni++ ) {
|
||||
uint16_t tmp[8];
|
||||
// fmha::uint4_to_ushort8(ph(), tmp);
|
||||
uint4 tmp_32 = ph();
|
||||
fmha::uint4_to_ushort8(tmp_32, tmp);
|
||||
// if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("tidx = %d, ni = %d, ph Philox: %u, %u, %u, %u\n", threadIdx.x, ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w);
|
||||
// }
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 2; ++ii) {
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
elt_[mi * 2 + ii][4 * ni + jj] =
|
||||
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool encode_dropout_in_sign_bit=false>
|
||||
inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t,
|
||||
unsigned long long philox_subsequence) {
|
||||
// We encode the dropout pattern in the sign bit of the non-negative
|
||||
// softmax to distinguish from pre-existing zeros
|
||||
auto encode_dropout = [](bool keep, float val) {
|
||||
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
|
||||
};
|
||||
static_assert(MMAS_M == 1); // We're assuming 16x16 blocks.
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ni++ ) {
|
||||
uint16_t tmp[8];
|
||||
// fmha::uint4_to_ushort8(ph(), tmp);
|
||||
fmha::uint4_to_ushort8(ph(philox_subsequence + ni * Cta_tile::WARPS_N), tmp);
|
||||
// uint4 tmp_32 = ph(philox_subsequence + ni * Cta_tile::WARPS_N);
|
||||
// fmha::uint4_to_ushort8(tmp_32, tmp);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w);
|
||||
// }
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 2; ++ii) {
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
elt_[mi * 2 + ii][4 * ni + jj] =
|
||||
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool encode_dropout_in_sign_bit=false>
|
||||
inline __device__ void apply_dropout_16bits(Philox &ph0, Philox &ph1, uint16_t p_dropout_in_uint16_t) {
|
||||
// We encode the dropout pattern in the sign bit of the non-negative
|
||||
// softmax to distinguish from pre-existing zeros
|
||||
auto encode_dropout = [](bool keep, float val) {
|
||||
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
|
||||
};
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
static_assert(MMAS_N % 2 == 0);
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ni += 2 ) {
|
||||
uint16_t tmp[8];
|
||||
fmha::uint4_to_ushort8(ph0(), tmp);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
|
||||
// }
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 2; ++ii) {
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
elt_[mi * 2 + ii][4 * ni + jj] =
|
||||
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
|
||||
}
|
||||
}
|
||||
fmha::uint4_to_ushort8(ph1(), tmp);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
|
||||
// }
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 2; ++ii) {
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
elt_[mi * 2 + ii][4 * (ni + 1) + jj] =
|
||||
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * (ni + 1) + jj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scale all the elements.
|
||||
inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
|
||||
// Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.
|
||||
float inv_sum[MMAS_M * 2];
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
|
||||
inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];
|
||||
}
|
||||
|
||||
// Update the values.
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
|
||||
elt_[mi][ni] *= inv_sum[mi];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Subtract all elements by dp_sum
|
||||
inline __device__ void subtract_dp_sum(const float (&dp_sum)[MMAS_M * 2]) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
|
||||
elt_[mi][ni] -= dp_sum[mi];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The pointer to the mask.
|
||||
const char *packed_mask_ptr_;
|
||||
// Shared memory for the CTA-wide reduction.
|
||||
float *smem_, *smem_write_, *smem_read_;
|
||||
// The current thread index.
|
||||
int tidx_;
|
||||
// The elements.
|
||||
float elt_[MMAS_M * 2][MMAS_N * 4];
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Cta_tile, typename Kernel_traits>
|
||||
struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
|
||||
|
||||
// The base class.
|
||||
using Base = Softmax_base<Cta_tile, Kernel_traits>;
|
||||
// The fragment.
|
||||
using Fragment_a = fmha::Fragment_a<fmha::Row>;
|
||||
|
||||
static_assert(Fragment_a::NUM_REGS == 4);
|
||||
|
||||
static constexpr int WARPS_M = Cta_tile::WARPS_M;
|
||||
static constexpr int WARPS_N = Cta_tile::WARPS_N;
|
||||
// The MMAs.
|
||||
static constexpr int MMAS_M = Base::MMAS_M;
|
||||
static constexpr int MMAS_N = Base::MMAS_N;
|
||||
|
||||
// The accumulators.
|
||||
using Accumulator = fmha::Fragment_accumulator;
|
||||
using Accumulator_out = Fragment<uint16_t, 8>;
|
||||
static_assert(Accumulator_out::NUM_REGS == 4);
|
||||
|
||||
static_assert(std::is_same<Accumulator::Data_type, float>::value);
|
||||
|
||||
using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
|
||||
static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
|
||||
// Ctor.
|
||||
template<typename Params>
|
||||
inline __device__ Softmax(const Params ¶ms, void *smem, int tidx)
|
||||
: Base(params, smem, tidx)
|
||||
, params_scale_bmm1_(params.scale_bmm1)
|
||||
, smem_sum_(static_cast<float*>(smem), tidx)
|
||||
, smem_max_(static_cast<float*>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {
|
||||
}
|
||||
|
||||
// Pack the data to a fragment for the next GEMM.
|
||||
template<typename elem_type=__half, int K, int M>
|
||||
inline __device__ void pack(Fragment_a (&dst)[K][M]) const {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < M; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < K; ++ki ) {
|
||||
|
||||
// 1st row - 4 elements per row.
|
||||
float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];
|
||||
float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];
|
||||
float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2];
|
||||
float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3];
|
||||
|
||||
// 2nd row - 4 elements per row.
|
||||
float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0];
|
||||
float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1];
|
||||
float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2];
|
||||
float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];
|
||||
|
||||
// Pack to 4 registers.
|
||||
dst[ki][mi].reg(0) = fmha::float2_pack<elem_type>(tmp_00, tmp_01);
|
||||
dst[ki][mi].reg(1) = fmha::float2_pack<elem_type>(tmp_10, tmp_11);
|
||||
dst[ki][mi].reg(2) = fmha::float2_pack<elem_type>(tmp_02, tmp_03);
|
||||
dst[ki][mi].reg(3) = fmha::float2_pack<elem_type>(tmp_12, tmp_13);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scale FP32 fragments
|
||||
inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
|
||||
const float scalef = reinterpret_cast<const float &>(this->params_scale_bmm1_);
|
||||
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ++ni ) {
|
||||
// 1st row - 4 elements per row.
|
||||
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef;
|
||||
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef;
|
||||
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef;
|
||||
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef;
|
||||
// 2nd row - 4 elements per row.
|
||||
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef;
|
||||
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef;
|
||||
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef;
|
||||
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scale FP32 fragments
|
||||
inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
|
||||
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ++ni ) {
|
||||
// 1st row - 4 elements per row.
|
||||
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
|
||||
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
|
||||
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4);
|
||||
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5);
|
||||
// 2nd row - 4 elements per row.
|
||||
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2);
|
||||
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
|
||||
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
|
||||
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Operator>
|
||||
__device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < 2 * MMAS_M; mi++ ) {
|
||||
frag[mi] = zero_init ? this->elt_[mi][0] : op(frag[mi], this->elt_[mi][0]);
|
||||
#pragma unroll
|
||||
for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
|
||||
frag[mi] = op(frag[mi], this->elt_[mi][ni]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Operator>
|
||||
__device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
|
||||
thread_reduce_<zero_init>(frag, op);
|
||||
quad_reduce(frag, frag, op);
|
||||
smem_red.store(frag);
|
||||
__syncthreads();
|
||||
typename Smem_tile_red::read_t tmp[2 * MMAS_M];
|
||||
smem_red.load(tmp);
|
||||
quad_allreduce(frag, tmp, op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true>
|
||||
__device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
|
||||
MaxOp<float> max;
|
||||
reduce_<zero_init>(frag, max, smem_max_);
|
||||
}
|
||||
|
||||
__device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
|
||||
SumOp<float> sum;
|
||||
reduce_(frag, sum, smem_sum_);
|
||||
}
|
||||
|
||||
template<bool zero_init=true>
|
||||
__device__ inline void reduce_sum_before_sync_(float (&frag)[2 * MMAS_M]){
|
||||
SumOp<float> sum;
|
||||
thread_reduce_<zero_init>(frag, sum);
|
||||
quad_reduce(frag, frag, sum);
|
||||
smem_sum_.store(frag);
|
||||
}
|
||||
|
||||
template<int NROWS, typename Operator>
|
||||
__device__ inline void reduce_after_sync_(float (&frag)[NROWS][MMAS_M],
|
||||
const int (&rows)[NROWS],
|
||||
Operator &op, Smem_tile_red & smem_red) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < NROWS; ii++) {
|
||||
typename Smem_tile_red::read_t tmp[MMAS_M];
|
||||
smem_red.load_row(tmp, rows[ii]);
|
||||
quad_allreduce(frag[ii], tmp, op);
|
||||
}
|
||||
}
|
||||
|
||||
template<int NROWS>
|
||||
__device__ inline void reduce_sum_after_sync_(float (&frag)[NROWS][MMAS_M],
|
||||
const int (&rows)[NROWS]){
|
||||
SumOp<float> sum;
|
||||
reduce_after_sync_(frag, rows, sum, smem_sum_);
|
||||
}
|
||||
|
||||
template<int NROWS>
|
||||
__device__ inline void reduce_max_after_sync_(float (&frag)[NROWS][MMAS_M],
|
||||
const int (&rows)[NROWS]){
|
||||
MaxOp<float> max;
|
||||
reduce_after_sync_(frag, rows, max, smem_max_);
|
||||
}
|
||||
|
||||
const uint32_t params_scale_bmm1_;
|
||||
Smem_tile_red smem_max_;
|
||||
Smem_tile_red smem_sum_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,64 +0,0 @@
|
||||
/* Copyright (c) 2022, Tri Dao.
|
||||
*/
|
||||
|
||||
#include "fmha.h"
|
||||
#include "fmha_block_dgrad_kernel_1xN_loop.h"
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
|
||||
__global__ void fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
|
||||
fmha::compute_block_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
void run_fmha_block_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
|
||||
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
|
||||
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
|
||||
constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
|
||||
constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
|
||||
|
||||
using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
|
||||
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
|
||||
static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
|
||||
static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
|
||||
static_assert(smem_size_dp_sum == 16 * 4 * 2);
|
||||
|
||||
constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum;
|
||||
|
||||
bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping"
|
||||
bool is_causal = params.is_causal;
|
||||
auto kernel = is_dropout
|
||||
? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false>)
|
||||
: (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false>);
|
||||
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
|
||||
if (params.seqlen_k == blocksize_c) {
|
||||
kernel = is_dropout
|
||||
? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/1> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/1>)
|
||||
: (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/1> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/1>);
|
||||
} else if (params.seqlen_k == blocksize_c * 2) {
|
||||
kernel = is_dropout
|
||||
? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, true, /*loop_steps=*/2> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, true, false, /*loop_steps=*/2>)
|
||||
: (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
|
||||
}
|
||||
|
||||
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
|
||||
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
}
|
||||
dim3 grid(params.b, params.h);
|
||||
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
|
||||
FMHA_CHECK_CUDA(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream) {
|
||||
if (params.d == 16) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>;
|
||||
run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
|
||||
} else if (params.d == 32) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>;
|
||||
run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
|
||||
} else if (params.d == 64) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
|
||||
run_fmha_block_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
|
||||
}
|
||||
}
|
||||
@@ -1,772 +0,0 @@
|
||||
/* Copyright (c) 2022, Tri Dao.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fmha_fprop_kernel_1xN.h"
|
||||
#include "fmha_kernel.h"
|
||||
#include "fmha_blockmask.h"
|
||||
#include <fmha/kernel_traits.h>
|
||||
#include <fmha/gemm.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Smem_dp_sum, int M>
|
||||
inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const uint4 (&o)[M],
|
||||
Smem_dp_sum smem, const int buffer_idx) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < M; ++mi) {
|
||||
sum[mi] = smem.reduce_warp(fmha::hmulsum8<__half>(do_[mi], o[mi]));
|
||||
}
|
||||
static_assert(M == 1);
|
||||
smem.store(sum[0], buffer_idx);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_first, bool Is_last, typename Params, typename Prng>
|
||||
inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng &ph,
|
||||
const int loop_step_idx) {
|
||||
|
||||
// The description of the CTA tile for the 1st batched GEMM.
|
||||
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
|
||||
// The description of the CTA tile for the 2nd batched GEMM.
|
||||
using Cta_tile_dq = typename Kernel_traits::Cta_tile_o;
|
||||
// The description of the CTA tile for the 3rd batched GEMM.
|
||||
using Cta_tile_dkv =
|
||||
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
|
||||
|
||||
static_assert(Cta_tile_dkv::M == 512 || Cta_tile_dkv::M == 256 || Cta_tile_dkv::M == 128);
|
||||
static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64);
|
||||
static_assert(Cta_tile_dkv::K == 16);
|
||||
|
||||
// The MMA tile for the 1st GEMM.
|
||||
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
|
||||
// The MMA tile for the 2nd GEMM.
|
||||
using Mma_tile_dq = fmha::Hmma_tile<Cta_tile_dq>;
|
||||
// The MMA tile for the 3rd GEMM.
|
||||
using Mma_tile_dkv = fmha::Hmma_tile<Cta_tile_dkv>;
|
||||
|
||||
// The global memory tile to load Q.
|
||||
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
|
||||
// The shared memory tile to reload Q transposed.
|
||||
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dkv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
|
||||
|
||||
// The global memory tile to load K.
|
||||
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
|
||||
// The shared memory tile to swizzle K^T. Treat K^T as V
|
||||
using Smem_tile_kt = typename Kernel_traits::Smem_tile_v;
|
||||
|
||||
// Treating V as K. We need to use Kernel_traits::Smem_tile_k otherwise loading will be wrong
|
||||
// The global memory tile to load V.
|
||||
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_k;
|
||||
// The shared memory tile to swizzle V.
|
||||
using Smem_tile_v = typename Kernel_traits::Smem_tile_k;
|
||||
|
||||
// The global memory tile to load dO.
|
||||
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
|
||||
// The shared memory tile to load dO.
|
||||
// Treating dO as Q.
|
||||
using Smem_tile_do = typename Kernel_traits::Smem_tile_q;
|
||||
// The shared memory tile to reload dO transposed.
|
||||
using Smem_tile_dot = fmha::Smem_tile_b<Cta_tile_dkv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
|
||||
|
||||
// The global memory tile to load O.Loading O here is similar to loading dO.
|
||||
using Gmem_tile_o = Gmem_tile_do;
|
||||
|
||||
// The global memory tile to store dQ.
|
||||
using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_o;
|
||||
using Gmem_tile_dq_tmp = fmha::Gmem_tile_o<Cta_tile_dq, 4>;
|
||||
// The shared memory tile to swizzle dQ.
|
||||
using Smem_tile_dq = typename Kernel_traits::Smem_tile_o;
|
||||
|
||||
// The global memory tile to store dV.
|
||||
using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v;
|
||||
// The shared memory tile to swizzle dV.
|
||||
using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dkv>;
|
||||
|
||||
// The global memory tile to store dK.
|
||||
using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;
|
||||
// The shared memory tile to swizzle dK.
|
||||
using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dkv>;
|
||||
static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
|
||||
static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);
|
||||
|
||||
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
|
||||
|
||||
using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
|
||||
|
||||
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
|
||||
|
||||
using Smem_dp_sum = typename Kernel_traits::Smem_dp_sum;
|
||||
|
||||
// using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
|
||||
using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false>;
|
||||
|
||||
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
// Shared memory layout if we keep V in registers:
|
||||
// dO | Q | K / V | dQ | S | dP | dP_sum
|
||||
// dV | dK
|
||||
// Shared memory layout if we keep V shared memory:
|
||||
// dO | Q | K | V | dQ | S | dP | dP_sum
|
||||
// dV | dK
|
||||
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
|
||||
// if( binfo.stop_early() ) return;
|
||||
if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
|
||||
|
||||
Blockmask blockmask(params, loop_step_idx);
|
||||
int block_row_idx = 0;
|
||||
int mask_val = blockmask.mask_val(0);
|
||||
if (mask_val == -1) return;
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("mask_val = %d.\n", mask_val);
|
||||
// }
|
||||
|
||||
Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
|
||||
// Allocate the global memory tile loader for Q.
|
||||
Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
// Allocate the global memory tile loader for dQ.
|
||||
Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts,
|
||||
params.d, binfo, tidx);
|
||||
Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx);
|
||||
// Allocate the global memory tile loader for S.
|
||||
Gmem_tile_s gmem_s(params, binfo, tidx);
|
||||
|
||||
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
|
||||
|
||||
// Allocate the global memory tile loader for K.
|
||||
Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
// Allocate the global memory tile loader for V.
|
||||
Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
// The base pointer of smem_v;
|
||||
char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
|
||||
|
||||
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
|
||||
Smem_tile_v smem_v(smem_v_, tidx);
|
||||
// Allocate the shared memory tile loader for K^T. We use the same as K so be careful!!!
|
||||
Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
|
||||
|
||||
// Allocate the global memory tile loader for dO.
|
||||
Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
// Allocate the shared memory tile loader for dO.
|
||||
Smem_tile_do smem_do(&smem_[0], tidx);
|
||||
Smem_tile_dot smem_dot(&smem_[0], tidx);
|
||||
// Allocate the shared memory tile loader for Q^T.
|
||||
// TODO: assert that this points to the same memory as gemm_q_k.smem_q
|
||||
Smem_tile_qt smem_qt(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
|
||||
|
||||
Smem_tile_st smem_s(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], tidx);
|
||||
Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
|
||||
|
||||
// Allocate the global memory tile loader for O.
|
||||
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
|
||||
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
|
||||
Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
|
||||
|
||||
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
|
||||
Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
|
||||
|
||||
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
|
||||
const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M;
|
||||
|
||||
// Wind gmem tiles to the correct position.
|
||||
int block_row_idx_next = mask_val / 4;
|
||||
int block_row_idx_to_move = block_row_idx_next - block_row_idx;
|
||||
block_row_idx = block_row_idx_next;
|
||||
gmem_q.move(block_row_idx_to_move);
|
||||
gmem_do.move(block_row_idx_to_move);
|
||||
gmem_o.move(block_row_idx_to_move);
|
||||
gmem_dq.move(block_row_idx_to_move);
|
||||
gmem_dq_tmp.move(block_row_idx_to_move);
|
||||
// TODO: need to move gmem_s if we want the intermediate result for debugging
|
||||
gmem_softmax_lse.move(block_row_idx_to_move);
|
||||
gmem_softmax_d.move(block_row_idx_to_move);
|
||||
block_row_idx = block_row_idx_next;
|
||||
|
||||
if (!Is_first) {
|
||||
gmem_k.move(loop_step_idx);
|
||||
gmem_v.move(loop_step_idx);
|
||||
}
|
||||
|
||||
// Trigger the loads for K.
|
||||
gmem_k.load();
|
||||
// Trigger the loads for Q.
|
||||
gmem_q.load();
|
||||
// Trigger the loads for V.
|
||||
gmem_v.load();
|
||||
// Trigger the loads for dO.
|
||||
gmem_do.load();
|
||||
// Trigger the loads for O.
|
||||
// if (Is_first) { gmem_o.load(); }
|
||||
// if (true) { gmem_o.load(); }
|
||||
if (Is_first || mask_val % 2 == 1) { gmem_o.load(); }
|
||||
|
||||
float p_lse[Mma_tile_p::MMAS_M * 2];
|
||||
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
|
||||
|
||||
float dp_sum[Mma_tile_p::MMAS_M * 2];
|
||||
// if (!Is_first) {
|
||||
// if (false) {
|
||||
if (!(Is_first || mask_val % 2 == 1)) {
|
||||
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
|
||||
}
|
||||
|
||||
float dp_sum_regs[Gmem_tile_do::LDGS];
|
||||
Smem_dp_sum smem_dp_sum(reinterpret_cast<float *>(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE * 2]), tidx);
|
||||
|
||||
if (!Is_first) { __syncthreads(); }
|
||||
// Commit the data for Q, dO, and V to shared memory.
|
||||
gmem_q.commit(gemm_q_k.smem_q);
|
||||
gmem_do.commit(smem_do);
|
||||
// if (Is_first) {
|
||||
// if (true) {
|
||||
if (Is_first || mask_val % 2 == 1) {
|
||||
dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, 0);
|
||||
const int dp_sum_row = tidx / Smem_dp_sum::THREADS_PER_ROW;
|
||||
if ((dp_sum_row < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
|
||||
gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row);
|
||||
}
|
||||
}
|
||||
|
||||
// Instead of scaling dP by rp_dropout, we scale V instead
|
||||
if (Is_dropout) {
|
||||
const uint32_t scale_dropout = params.scale_dropout;
|
||||
#pragma unroll
|
||||
for(int it=0; it < Gmem_tile_v::LDGS; it++){
|
||||
gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]);
|
||||
}
|
||||
}
|
||||
|
||||
gmem_v.commit(smem_v);
|
||||
|
||||
// const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
|
||||
// #pragma unroll
|
||||
// for(int it=0; it < Gmem_tile_k::LDGS; it++){
|
||||
// gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
|
||||
// }
|
||||
|
||||
// Commit the data for K to shared memory.
|
||||
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
|
||||
gmem_k.commit(gemm_q_k.smem_k);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Load the fragments for Q.
|
||||
gemm_q_k.load_q();
|
||||
|
||||
// Load the fragments for V. We keep the data in registers during the entire kernel.
|
||||
typename Smem_tile_v::Fragment frag_v[Kernel_traits::V_IN_REGS ? Mma_tile_p::MMAS_K : 2][Mma_tile_p::MMAS_N];
|
||||
if (Kernel_traits::V_IN_REGS) {
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
|
||||
smem_v.load(frag_v[ki], ki);
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the data for V to shared memory if it has not been done already.
|
||||
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
|
||||
// Make sure we are done loading the fragments for K.
|
||||
__syncthreads();
|
||||
|
||||
// Commit the data to shared memory for V.
|
||||
gmem_k.commit(gemm_q_k.smem_k);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Load the fragments for K.
|
||||
gemm_q_k.load_k();
|
||||
// Load the fragments for K^T.
|
||||
// typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
|
||||
// smem_kt.load(frag_kt[0], 0);
|
||||
// typename Smem_tile_kt::Fragment frag_kt[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_N];
|
||||
// #pragma unroll
|
||||
// for( int ki = 0; ki < Mma_tile_dq::MMAS_K; ++ki ) {
|
||||
// smem_kt.load(frag_kt[ki], ki);
|
||||
// }
|
||||
|
||||
// Create the object to do the softmax.
|
||||
// We won't be using the shared memory for this softmax at all
|
||||
Softmax softmax(params, smem_, tidx);
|
||||
|
||||
// Declare the accumulators for the 3rd gemm.
|
||||
fmha::Fragment_accumulator acc_dv[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N];
|
||||
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dv);
|
||||
fmha::Fragment_accumulator acc_dk[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N];
|
||||
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dk);
|
||||
|
||||
// Load over the entire sequence length.
|
||||
for( int l = 0; l < steps; l++ ) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("block_row_idx = %d\n", block_row_idx);
|
||||
// }
|
||||
if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
|
||||
|
||||
int mask_val_next = l < steps - 1 ? blockmask.mask_val(l + 1) : -1;
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("mask_val = %d, mask_val_next = %d\n", mask_val, mask_val_next);
|
||||
// }
|
||||
|
||||
// Load the fragments for V.
|
||||
// typename Smem_tile_v::Fragment frag_v[2][Mma_tile_p::MMAS_N];
|
||||
if (!Kernel_traits::V_IN_REGS) { smem_v.load(frag_v[0], 0); }
|
||||
|
||||
// Load the fragments for dO.
|
||||
typename Smem_tile_do::Fragment frag_do[2][Mma_tile_p::MMAS_M];
|
||||
smem_do.load(frag_do[0], 0);
|
||||
|
||||
// Declare the accumulators for the 1st gemm.
|
||||
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
|
||||
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
|
||||
|
||||
// Do this part of P^T = (Q * K^T)^T.
|
||||
gemm_q_k(acc_p);
|
||||
|
||||
// Load the mask for that iteration.
|
||||
mask.load(block_row_idx);
|
||||
|
||||
// Convert from the accumulator type to FP32 for Softmax.
|
||||
softmax.unpack_noscale(acc_p);
|
||||
// Apply the mask.
|
||||
softmax.apply_mask(mask);
|
||||
// Scale by log-sum-exp of the softmax
|
||||
// softmax.apply_exp(p_lse);
|
||||
softmax.template scale_apply_exp</*scale_max=*/false>(p_lse, params.scale_bmm1f);
|
||||
if (Is_dropout) {
|
||||
// softmax.apply_dropout(ph, params.p_dropout_in_uint);
|
||||
// softmax.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint);
|
||||
softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t);
|
||||
}
|
||||
|
||||
using Frag_p = fmha::Fragment_a<fmha::Row>;
|
||||
Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
|
||||
static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
|
||||
static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
|
||||
softmax.template pack<__half>(frag_p);
|
||||
|
||||
// Store s * dmask to smem for transpose
|
||||
smem_s.store(frag_p);
|
||||
|
||||
// Trigger the load for the next Q values.
|
||||
bool not_last_iter = (l < steps - 1) && (mask_val_next != -1);
|
||||
block_row_idx_next = mask_val_next / 4;
|
||||
int block_row_idx_to_move = block_row_idx_next - block_row_idx;
|
||||
if (not_last_iter) {
|
||||
gemm_q_k.smem_q.move_to_next_write_buffer();
|
||||
gmem_q.move(block_row_idx_to_move);
|
||||
gmem_q.load();
|
||||
}
|
||||
|
||||
// if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
|
||||
// // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
|
||||
// __syncthreads();
|
||||
// }
|
||||
|
||||
bool is_first_read = Is_first || mask_val % 2 == 1;
|
||||
// TD [2022-04-24]: if Is_first, then it's faster to set acc_dp to zero then subtract by
|
||||
// dp_sum later. If !Is_first, then it's faster to set acc_dp to -dp_sum and don't subtract
|
||||
// later. This is because loading dp_sum earlier uses more registers.
|
||||
fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
|
||||
// if (Is_first) {
|
||||
// if (true) {
|
||||
if (is_first_read) {
|
||||
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_dp);
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 8; ++ii) {
|
||||
acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Do this part of dP^T = (dO * V^T)^T.
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of dO values.
|
||||
smem_do.load(frag_do[ki & 1], ki);
|
||||
if (!Kernel_traits::V_IN_REGS) {
|
||||
smem_v.load(frag_v[ki & 1], ki);
|
||||
fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
|
||||
} else {
|
||||
fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
|
||||
}
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
|
||||
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
|
||||
// printf("frag_do=%.6f, %.6f\n", tmp.x, tmp.y);
|
||||
// tmp = __half22float2(reinterpret_cast<__half2 &>(frag_v[(ki - 1) & 1]));
|
||||
// printf("frag_v=%.6f, %.6f\n", tmp.x, tmp.y);
|
||||
// }
|
||||
}
|
||||
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_p::MMAS_K;
|
||||
if (!Kernel_traits::V_IN_REGS) {
|
||||
fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
|
||||
} else {
|
||||
fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
|
||||
}
|
||||
}
|
||||
|
||||
// Load the fragments for K^T.
|
||||
typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
|
||||
smem_kt.load(frag_kt[0], 0);
|
||||
|
||||
// if (Is_first) {
|
||||
// if (true) {
|
||||
if (is_first_read) {
|
||||
const int quad = (tidx % Cta_tile_p::THREADS_PER_WARP) / 4;
|
||||
const int row[2] = {quad, quad + 8};
|
||||
smem_dp_sum.load(dp_sum, row, l % 2);
|
||||
}
|
||||
|
||||
// Trigger the load for the next dO values.
|
||||
if (not_last_iter) {
|
||||
smem_do.move_to_next_write_buffer();
|
||||
gmem_do.move(block_row_idx_to_move);
|
||||
gmem_do.load();
|
||||
gmem_o.move(block_row_idx_to_move);
|
||||
// if (Is_first) {
|
||||
// if (true) {
|
||||
if (Is_first || mask_val_next % 2 == 1) {
|
||||
gmem_o.load();
|
||||
}
|
||||
}
|
||||
|
||||
softmax.unpack_noscale(acc_dp);
|
||||
// // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
|
||||
// // will be zero.
|
||||
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; }
|
||||
// if (Is_first) { softmax.subtract_dp_sum(dp_sum); }
|
||||
// if (true) { softmax.subtract_dp_sum(dp_sum); }
|
||||
if (is_first_read) { softmax.subtract_dp_sum(dp_sum); }
|
||||
|
||||
Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
|
||||
softmax.template pack<__half>(frag_dp);
|
||||
|
||||
if (!Is_dropout) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
|
||||
frag_p[mi][ni].hmul(frag_dp[mi][ni]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
__half2 dp_sum_half[Mma_tile_p::MMAS_M * 2];
|
||||
for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
|
||||
dp_sum_half[mi] = __float2half2_rn(dp_sum[mi]);
|
||||
}
|
||||
const __half zero_h = __half(0.f);
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 4; ++ii) {
|
||||
const __half2 p = frag_p[mi][ni].template elt_as<__half2>(ii);
|
||||
const __half2 pdp = __hmul2(p, frag_dp[mi][ni].template elt_as<__half2>(ii));
|
||||
// If this element is dropped, then frag_p stores -p instead of p.
|
||||
// So pd holds -p * dp_sum in that case.
|
||||
const __half2 pd = __hmul2(p, dp_sum_half[mi * 2 + (ii % 2)]);
|
||||
const __half low = __low2half(p) >= zero_h ? __low2half(pdp) : __low2half(pd);
|
||||
const __half high = __high2half(p) >= zero_h ? __high2half(pdp) : __high2half(pd);
|
||||
frag_p[mi][ni].template elt_as<__half2>(ii) = __halves2half2(low, high);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store dp to smem for transpose
|
||||
smem_dp.store(frag_p);
|
||||
|
||||
// gmem_s.store(frag_p, mask);
|
||||
// gmem_s.move();
|
||||
|
||||
// Declare the accumulators for the 2nd gemm.
|
||||
fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
|
||||
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_dq::WARPS_K>::apply(acc_dq);
|
||||
|
||||
// Do this part of O = P^T * V^T.
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_dq::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
smem_kt.load(frag_kt[ki & 1], ki);
|
||||
// Do the math for the values already in registers.
|
||||
fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
|
||||
// fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
|
||||
}
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_dq::MMAS_K;
|
||||
fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
|
||||
// fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
|
||||
}
|
||||
|
||||
static_assert(Gmem_tile_dq::LOOPS == 1);
|
||||
|
||||
// Swizzle the elements and do the final reduction.
|
||||
smem_dq.store(acc_dq, 0);
|
||||
|
||||
typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_dkv::MMAS_N];
|
||||
static_assert(Smem_tile_dot::Fragment::NUM_REGS == 4);
|
||||
static_assert(Mma_tile_dkv::MMAS_K == 1);
|
||||
smem_dot.load(frag_dot[0], 0);
|
||||
|
||||
// Threads in a warp is communicating via shared memory (smem_s and smem_dp)
|
||||
__syncwarp();
|
||||
typename Smem_tile_st::Fragment frag_s[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
|
||||
smem_s.load(frag_s);
|
||||
|
||||
if (Is_dropout) {
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
|
||||
frag_s[ki][mi].template hrelu_<__half>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
smem_dot.load(frag_dot[ki & 1], ki);
|
||||
// Do the math for the values already in registers.
|
||||
fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
|
||||
}
|
||||
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_dkv::MMAS_K;
|
||||
fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
|
||||
}
|
||||
|
||||
// __syncthreads();
|
||||
// Commit the values for Q and dO into shared memory.
|
||||
if (not_last_iter) {
|
||||
gmem_q.commit(gemm_q_k.smem_q);
|
||||
}
|
||||
|
||||
uint4 dq_out[Gmem_tile_dq::STGS_PER_LOOP];
|
||||
// if (!Is_first) { gmem_dq_tmp.load(dq_out, 0); }
|
||||
if (!is_first_read) { gmem_dq_tmp.load(dq_out, 0); }
|
||||
|
||||
// __syncthreads();
|
||||
// Commit the values for Q and dO into shared memory.
|
||||
if (not_last_iter) {
|
||||
gmem_do.commit(smem_do);
|
||||
// if (Is_first) {
|
||||
// if (true) {
|
||||
gmem_softmax_d.move(block_row_idx_to_move);
|
||||
if (Is_first || mask_val_next % 2 == 1) {
|
||||
// dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum);
|
||||
// smem_dp_sum.move_to_next_write_buffer();
|
||||
dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, (l + 1) % 2);
|
||||
const int dp_sum_row_1 = tidx / Smem_dp_sum::THREADS_PER_ROW;
|
||||
if ((dp_sum_row_1 < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
|
||||
gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row_1);
|
||||
}
|
||||
}
|
||||
gmem_softmax_lse.move(block_row_idx_to_move);
|
||||
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
|
||||
// if (!Is_first) {
|
||||
if (!(Is_first || mask_val_next % 2 == 1)) {
|
||||
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
|
||||
}
|
||||
}
|
||||
|
||||
typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
|
||||
smem_dp.load(frag_dpt);
|
||||
|
||||
gemm_q_k.reload_k();
|
||||
|
||||
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dkv::MMAS_N];
|
||||
static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);
|
||||
static_assert(Mma_tile_dkv::MMAS_K == 1);
|
||||
smem_qt.load(frag_qt[0], 0);
|
||||
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
smem_qt.load(frag_qt[ki & 1], ki);
|
||||
// Do the math for the values already in registers.
|
||||
fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
|
||||
}
|
||||
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_dkv::MMAS_K;
|
||||
fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
|
||||
}
|
||||
|
||||
// Make sure dQ is in shared memory.
|
||||
__syncthreads();
|
||||
|
||||
// Load from shared memory.
|
||||
is_first_read ? smem_dq.template load</*zero_init=*/true>(dq_out) : smem_dq.template load</*zero_init=*/false>(dq_out);
|
||||
|
||||
const bool is_final_write =
|
||||
Is_last
|
||||
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
|
||||
|| ((mask_val & 0x2) != 0)
|
||||
|| ((Is_causal) && (block_row_idx * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
|
||||
if (is_final_write) {
|
||||
// if (Is_dropout) {
|
||||
// dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout);
|
||||
// }
|
||||
dq_out[0] = fmha::fmul4(dq_out[0], params.scale_bmm1f);
|
||||
// Output the values.
|
||||
gmem_dq.template store<__half>(dq_out, 0);
|
||||
} else {
|
||||
// Output the values.
|
||||
gmem_dq_tmp.store(dq_out, 0);
|
||||
}
|
||||
|
||||
// Move to the next part of the output.
|
||||
gmem_dq.move(block_row_idx_to_move);
|
||||
if (!(Is_first && Is_last)) { gmem_dq_tmp.move(block_row_idx_to_move); }
|
||||
|
||||
// // Make sure the data is in shared memory.
|
||||
// __syncthreads();
|
||||
|
||||
// Commit the values for Q and dO into shared memory.
|
||||
if (not_last_iter) {
|
||||
gemm_q_k.smem_q.move_to_next_read_buffer();
|
||||
gemm_q_k.reload_q();
|
||||
smem_qt.move_to_next_read_buffer();
|
||||
// smem_qt.load(frag_qt[0], 0);
|
||||
smem_do.move_to_next_read_buffer();
|
||||
smem_dot.move_to_next_read_buffer();
|
||||
// smem_dot.load(frag_dot[0], 0);
|
||||
}
|
||||
|
||||
if (mask_val_next == -1) break;
|
||||
mask_val = mask_val_next;
|
||||
block_row_idx += block_row_idx_to_move;
|
||||
|
||||
} // Outer loop over the sequence length.
|
||||
|
||||
if (Is_dropout) {
|
||||
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
|
||||
for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
|
||||
acc_dv[mi][ni].mul_(params.rp_dropout);
|
||||
}
|
||||
}
|
||||
}
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1));
|
||||
// }
|
||||
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
|
||||
for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
|
||||
// acc_dk[mi][ni].mul_(Is_dropout ? params.rp_dropout * params.scale_bmm1f : params.scale_bmm1f);
|
||||
acc_dk[mi][ni].mul_(params.scale_bmm1f);
|
||||
}
|
||||
}
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1));
|
||||
// }
|
||||
|
||||
__syncthreads();
|
||||
// TODO [TD - 2022-05-04]: Are there cases where the shared mem for dV and dK are larger than
|
||||
// the total amount of shared mem?
|
||||
// Epilogue swizzle for dV
|
||||
Smem_tile_dv smem_dv(&smem_[0], tidx);
|
||||
smem_dv.template store<__half>(acc_dv);
|
||||
|
||||
// Epilogue swizzle for dK
|
||||
Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
|
||||
smem_dk.template store<__half>(acc_dk);
|
||||
|
||||
__syncthreads();
|
||||
uint4 dv_out[Smem_tile_dv::NUM_LDS];
|
||||
smem_dv.load(dv_out);
|
||||
Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
if (!Is_first) {
|
||||
gmem_dv.move(loop_step_idx);
|
||||
}
|
||||
gmem_dv.store(dv_out);
|
||||
|
||||
uint4 dk_out[Smem_tile_dk::NUM_LDS];
|
||||
smem_dk.load(dk_out);
|
||||
// for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
|
||||
// dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
|
||||
// }
|
||||
Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
if (!Is_first) {
|
||||
gmem_dk.move(loop_step_idx);
|
||||
}
|
||||
gmem_dk.store(dk_out);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// loop_steps = -1 means the number of steps will be params.seqlen_k / Kernel_traits::Cta_tile_p::N.
|
||||
// This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2.
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1, typename Params>
|
||||
inline __device__ void compute_block_dq_dk_dv_1xN(const Params ¶ms) {
|
||||
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
const int tidx_global = (bidb * params.h + bidh) * blockDim.x + tidx;
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
|
||||
|
||||
if (loop_steps == 1) {
|
||||
compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
|
||||
} else if (loop_steps == 2) {
|
||||
compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
|
||||
compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, 1);
|
||||
} else {
|
||||
if (params.seqlen_k == blocksize_c) {
|
||||
compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
|
||||
} else {
|
||||
const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
|
||||
compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
|
||||
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
|
||||
compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false>(params, ph, loop_step_idx);
|
||||
}
|
||||
compute_block_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, max_loop_steps - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
@@ -1,90 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#include "fmha.h"
|
||||
#include "fmha_block_fprop_kernel_1xN.h"
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
|
||||
__global__ void fmha_block_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) {
|
||||
fmha::device_block_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
void run_fmha_block_fp16_sm80_loop_(Launch_params<FMHA_fprop_params> &launch_params,
|
||||
const bool configure) {
|
||||
bool is_causal = launch_params.params.is_causal;
|
||||
// TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way?
|
||||
auto kernel = launch_params.is_dropout
|
||||
? (is_causal
|
||||
? (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, true> : &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, true, false>)
|
||||
: (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, true> : &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, true, false, false>))
|
||||
: (is_causal
|
||||
? (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, true> : &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, true, false>)
|
||||
: (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, true> : &fmha_block_fprop_fp16_sm80_loop_kernel<Kernel_traits, false, false, false>));
|
||||
|
||||
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
|
||||
const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
|
||||
constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
|
||||
// Don't need smem_size_softmax_lse if we're not looping
|
||||
const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
|
||||
+ (loop_steps > 1 ? smem_size_softmax_lse : 0);
|
||||
|
||||
if( smem_size >= 48 * 1024 ) {
|
||||
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
|
||||
if (configure) {
|
||||
using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
|
||||
constexpr int M = Kernel_traits::Cta_tile_p::M;
|
||||
size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M;
|
||||
constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
|
||||
constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
|
||||
size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
|
||||
launch_params.elts_per_thread = elts_per_head;
|
||||
return;
|
||||
}
|
||||
|
||||
dim3 grid(launch_params.params.b, launch_params.params.h);
|
||||
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
|
||||
launch_params.params);
|
||||
|
||||
FMHA_CHECK_CUDA(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params,
|
||||
const bool configure) {
|
||||
if (launch_params.params.d == 16) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>;
|
||||
run_fmha_block_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
|
||||
} else if (launch_params.params.d == 32) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>;
|
||||
run_fmha_block_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
|
||||
} else if (launch_params.params.d == 64) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
|
||||
run_fmha_block_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
|
||||
}
|
||||
}
|
||||
@@ -1,533 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2022, Tri Dao.
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fmha_fprop_kernel_1xN.h"
|
||||
#include "fmha_kernel.h"
|
||||
#include "fmha_blockmask.h"
|
||||
#include <fmha/kernel_traits.h>
|
||||
#include <fmha/gemm.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
|
||||
inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, const int bidh, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) {
|
||||
|
||||
|
||||
// The description of the CTA tile for the 1st batched GEMM.
|
||||
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
|
||||
// The description of the CTA tile for the 2nd batched GEMM.
|
||||
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
|
||||
|
||||
// The MMA tile for the 1st GEMM.
|
||||
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
|
||||
// The MMA tile for the 2nd GEMM.
|
||||
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
|
||||
|
||||
// The global memory tile to load Q.
|
||||
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
|
||||
|
||||
// The global memory tile to load K.
|
||||
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
|
||||
|
||||
// The global memory tile to load V.
|
||||
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
|
||||
// The shared memory tile to swizzle V.
|
||||
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
|
||||
|
||||
// The global memory tile to store O.
|
||||
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
|
||||
using Gmem_tile_o_tmp = fmha::Gmem_tile_o<Cta_tile_o, 4>;
|
||||
// The shared memory tile to swizzle O.
|
||||
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
|
||||
|
||||
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
|
||||
|
||||
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
|
||||
|
||||
using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum;
|
||||
|
||||
using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
|
||||
|
||||
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
|
||||
// if( binfo.stop_early() ) return;
|
||||
if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
|
||||
|
||||
Blockmask blockmask(params, loop_step_idx);
|
||||
int block_row_idx = 0;
|
||||
int mask_val = blockmask.mask_val(0);
|
||||
if (mask_val == -1) return;
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("mask_val = %d.\n", mask_val);
|
||||
// }
|
||||
|
||||
Gemm1 gemm_q_k(smem_, tidx);
|
||||
// Allocate the global memory tile loader for Q.
|
||||
Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
// Allocate the global memory tile loader for O.
|
||||
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx);
|
||||
Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx);
|
||||
// Allocate the global memory tile loader for S.
|
||||
Gmem_tile_s gmem_s(params, binfo, tidx);
|
||||
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
|
||||
|
||||
// Wind gmem tiles to the correct position.
|
||||
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
|
||||
int block_row_idx_next = mask_val / 4;
|
||||
int block_row_idx_to_move = block_row_idx_next - block_row_idx;
|
||||
gmem_q.move(block_row_idx_to_move);
|
||||
gmem_o.move(block_row_idx_to_move);
|
||||
gmem_o_tmp.move(block_row_idx_to_move);
|
||||
if (Return_softmax) { gmem_s.move(block_row_idx_to_move); }
|
||||
gmem_softmax_lse.move(block_row_idx_to_move);
|
||||
block_row_idx = block_row_idx_next;
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("begin = %d, steps = %d\n", begin, steps);
|
||||
// }
|
||||
|
||||
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
|
||||
|
||||
// Allocate the global memory tile loader for K.
|
||||
Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
// Allocate the global memory tile loader for V.
|
||||
Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
// The base pointer of smem_v;
|
||||
char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
|
||||
|
||||
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
|
||||
Smem_tile_v smem_v(smem_v_, tidx);
|
||||
|
||||
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
|
||||
Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);
|
||||
|
||||
if (!Is_first) {
|
||||
gmem_k.move(loop_step_idx);
|
||||
gmem_v.move(loop_step_idx);
|
||||
if (Return_softmax) { gmem_s.move(loop_step_idx * steps); }
|
||||
}
|
||||
|
||||
// Trigger the loads for K.
|
||||
gmem_k.load();
|
||||
// Trigger the loads for Q.
|
||||
gmem_q.load();
|
||||
// Trigger the loads for V.
|
||||
gmem_v.load();
|
||||
|
||||
if (!Is_first) { __syncthreads(); }
|
||||
|
||||
float p_prev_lse[Mma_tile_p::MMAS_M * 2];
|
||||
if (!(Is_first || mask_val % 2 == 1)) {
|
||||
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
|
||||
}
|
||||
|
||||
// Commit the data for Q and V to shared memory.
|
||||
gmem_q.commit(gemm_q_k.smem_q);
|
||||
gmem_v.commit(smem_v);
|
||||
|
||||
// const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
|
||||
// #pragma unroll
|
||||
// for(int it=0;it < Gmem_tile_k::LDGS;it++){
|
||||
// gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
|
||||
// }
|
||||
|
||||
// Commit the data for K to shared memory.
|
||||
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
|
||||
gmem_k.commit(gemm_q_k.smem_k);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Load the fragments for Q.
|
||||
gemm_q_k.load_q();
|
||||
|
||||
// Load the fragments for V. We keep the data in registers during the entire kernel.
|
||||
typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
|
||||
smem_v.load(frag_v[ki], ki);
|
||||
}
|
||||
|
||||
// Commit the data for V to shared memory if it has not been done already.
|
||||
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
|
||||
// Make sure we are done loading the fragments for K.
|
||||
__syncthreads();
|
||||
|
||||
// Commit the data to shared memory for V.
|
||||
gmem_k.commit(gemm_q_k.smem_k);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Load the fragments for K.
|
||||
gemm_q_k.load_k();
|
||||
|
||||
// Create the object to do the softmax.
|
||||
Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx);
|
||||
|
||||
Smem_softmax_sum smem_softmax_lse(reinterpret_cast<float *>(&smem_[Gemm1::SMEM_BYTES]), tidx);
|
||||
|
||||
// Load over the entire sequence length.
|
||||
for( int l = 0; l < steps; l++ ) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("block_row_idx = %d\n", block_row_idx);
|
||||
// }
|
||||
if (block_row_idx * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
|
||||
|
||||
int mask_val_next = l < steps - 1 ? blockmask.mask_val(l + 1) : -1;
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("mask_val = %d, mask_val_next = %d\n", mask_val, mask_val_next);
|
||||
// }
|
||||
|
||||
// Declare the accumulators for the 1st gemm.
|
||||
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
|
||||
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
|
||||
|
||||
// Do this part of P = Q * K^T.
|
||||
gemm_q_k(acc_p);
|
||||
|
||||
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
|
||||
bool is_first_read = Is_first || mask_val % 2 == 1;
|
||||
// if (!Is_first) { gmem_o_tmp.load(out, 0); }
|
||||
if (!is_first_read) { gmem_o_tmp.load(out, 0); }
|
||||
|
||||
// Trigger the load for the next Q values.
|
||||
bool not_last_iter = (l < steps - 1) && (mask_val_next != -1);
|
||||
block_row_idx_next = mask_val_next / 4;
|
||||
int block_row_idx_to_move = block_row_idx_next - block_row_idx;
|
||||
if (not_last_iter) {
|
||||
gemm_q_k.smem_q.move_to_next_write_buffer();
|
||||
gmem_q.move(block_row_idx_to_move);
|
||||
gmem_q.load();
|
||||
}
|
||||
|
||||
// Load the mask for that iteration.
|
||||
mask.load(block_row_idx);
|
||||
|
||||
// Convert from the accumulator type to FP32 for Softmax.
|
||||
softmax.unpack_noscale(acc_p);
|
||||
|
||||
// Apply the mask.
|
||||
softmax.apply_mask(mask);
|
||||
|
||||
// softmax.unpack_noscale_half_and_apply_mask(acc_p, mask);
|
||||
|
||||
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
|
||||
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
|
||||
__syncthreads();
|
||||
}
|
||||
// if (!Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]);
|
||||
// }
|
||||
// }
|
||||
// Compute the max.
|
||||
float p_max[Mma_tile_p::MMAS_M * 2];
|
||||
// if (!Is_first) {
|
||||
if (!is_first_read) {
|
||||
smem_softmax_lse.store_pair(p_prev_lse, l % 2);
|
||||
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; }
|
||||
for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; }
|
||||
}
|
||||
|
||||
// Trigger the load for the next LSE values.
|
||||
if (not_last_iter) {
|
||||
// if (!Is_first) {
|
||||
if (!(Is_first || mask_val_next % 2 == 1)) {
|
||||
gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse),
|
||||
block_row_idx_to_move);
|
||||
}
|
||||
}
|
||||
|
||||
// __half2 p_max[Mma_tile_p::MMAS_M];
|
||||
// softmax.template reduce_max</*zero_init=*/Is_first>(p_max);
|
||||
is_first_read ? softmax.template reduce_max</*zero_init=*/true>(p_max) : softmax.template reduce_max</*zero_init=*/false>(p_max);
|
||||
|
||||
// if ((threadIdx.x == 0) && (l == 38)) {
|
||||
// printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : p_prev_lse[0], Is_first ? -10000.f : p_prev_lse[1]);
|
||||
// }
|
||||
|
||||
// if (!Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("after reduce_max=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
|
||||
// }
|
||||
// }
|
||||
|
||||
// Compute the exponential value.
|
||||
// softmax.apply_exp(p_max);
|
||||
softmax.scale_apply_exp(p_max, params.scale_bmm1f);
|
||||
|
||||
// if (!Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("after apply_exp=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
|
||||
// }
|
||||
// }
|
||||
|
||||
// Compute the sum.
|
||||
float p_sum[Mma_tile_p::MMAS_M * 2];
|
||||
// if (!Is_first) {
|
||||
// int warp = tidx / Cta_tile_p::THREADS_PER_WARP;
|
||||
// int lane = tidx % Cta_tile_p::THREADS_PER_WARP;
|
||||
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
|
||||
// p_sum[mi] = ((warp == 0) && (lane % 4 == 0)) ? expf(p_prev_lse[mi] - p_max[mi]) : 0;
|
||||
// }
|
||||
// }
|
||||
// softmax.reduce_sum(p_sum);
|
||||
softmax.reduce_sum_before_sync_(p_sum);
|
||||
// softmax.template reduce_sum_before_sync_</*zero_init=*/Is_first>(p_sum);
|
||||
|
||||
// float p_sum_log[Mma_tile_p::MMAS_M * 2];
|
||||
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; ++mi) {
|
||||
// float sum = p_sum[mi];
|
||||
// // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] + __logf(sum);
|
||||
// constexpr float kLog2e = M_LOG2E;
|
||||
// p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] * kLog2e + __log2f(sum);
|
||||
// }
|
||||
// // gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum));
|
||||
// gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum_log));
|
||||
// gmem_softmax_lse.move();
|
||||
|
||||
// // Finalize softmax on the accumulators of P^T.
|
||||
// softmax.scale(p_sum);
|
||||
|
||||
constexpr bool encode_dropout_in_sign_bit = Return_softmax;
|
||||
if (Is_dropout) {
|
||||
// softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph0, params.p_dropout_in_uint);
|
||||
// softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint);
|
||||
softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint16_t);
|
||||
}
|
||||
|
||||
using Frag_p = fmha::Fragment_a<fmha::Row>;
|
||||
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
|
||||
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
|
||||
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
|
||||
softmax.template pack<__half>(frag_p);
|
||||
if (Return_softmax) {
|
||||
gmem_s.store(frag_p, mask);
|
||||
if (not_last_iter) {
|
||||
gmem_s.move(block_row_idx_to_move);
|
||||
}
|
||||
}
|
||||
|
||||
// Commit the values for Q into shared memory.
|
||||
if (not_last_iter) {
|
||||
gmem_q.commit(gemm_q_k.smem_q);
|
||||
}
|
||||
|
||||
if (Is_dropout && encode_dropout_in_sign_bit) {
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
|
||||
frag_p[ki][mi].template hrelu_<__half>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Declare the accumulators for the 2nd gemm.
|
||||
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
|
||||
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
|
||||
|
||||
// Do this part of O = P^T * V^T.
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
|
||||
fmha::gemm_cl<__half>(acc_o, frag_p[ki], frag_v[ki]);
|
||||
}
|
||||
|
||||
// The mapping from tidx to rows changes between the softmax and the O-reduction.
|
||||
// So we recalculate the max.
|
||||
float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
|
||||
int rows[Gmem_tile_o::STGS_PER_LOOP];
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
|
||||
}
|
||||
softmax.reduce_max_after_sync_(p_max_o, rows);
|
||||
static_assert(Mma_tile_o::MMAS_M == 1);
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
p_max_o[jj][0] *= params.scale_bmm1f;
|
||||
}
|
||||
float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP];
|
||||
// if (!Is_first) { smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); }
|
||||
if (!is_first_read) { smem_softmax_lse.load(p_prev_scale_o, rows, l % 2); }
|
||||
// if (!Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]);
|
||||
// }
|
||||
// }
|
||||
|
||||
static_assert(Gmem_tile_o::LOOPS == 1);
|
||||
|
||||
// Swizzle the elements and do the final reduction.
|
||||
smem_o.store(acc_o, 0);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
|
||||
static_assert(Mma_tile_o::MMAS_M == 1);
|
||||
float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
|
||||
softmax.reduce_sum_after_sync_(p_sum_o, rows);
|
||||
// if (!Is_first) {
|
||||
if (!is_first_read) {
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]);
|
||||
p_sum_o[jj][0] += p_prev_scale_o[jj];
|
||||
}
|
||||
}
|
||||
|
||||
float p_sum_log[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
float sum = p_sum_o[jj][0];
|
||||
p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum);
|
||||
// if (sum == 0.f || sum != sum) {
|
||||
// printf("loop_step_idx = %d, l = %d, tidx = %d, sum = %.6f, p_max_o = %.6f\n", loop_step_idx, l, tidx, sum, p_max_o[jj][0]);
|
||||
// }
|
||||
// if (Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("p_sum_log=%.6f\n", p_sum_log[jj][0]);
|
||||
// }
|
||||
// }
|
||||
if ((tidx % Gmem_tile_o::THREADS_PER_ROW == 0) && (tidx / Gmem_tile_o::THREADS_PER_ROW < Gmem_tile_o::ROWS)) {
|
||||
gmem_softmax_lse.store_row(
|
||||
reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
|
||||
}
|
||||
}
|
||||
if (not_last_iter) {
|
||||
gmem_softmax_lse.move(block_row_idx_to_move);
|
||||
}
|
||||
|
||||
// Load from shared memory.
|
||||
// if (!Is_first) {
|
||||
if (!is_first_read) {
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
out[jj] = fmha::fmul4(out[jj], p_prev_scale_o[jj]);
|
||||
}
|
||||
}
|
||||
// smem_o.template load</*zero_init=*/Is_first>(out);
|
||||
is_first_read ? smem_o.template load</*zero_init=*/true>(out) : smem_o.template load</*zero_init=*/false>(out);
|
||||
|
||||
const bool is_final_write =
|
||||
Is_last
|
||||
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
|
||||
|| ((mask_val & 0x2) != 0)
|
||||
|| ((Is_causal) && (block_row_idx * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("is_final_write = %d\n", is_final_write);
|
||||
// }
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
float sum = p_sum_o[jj][0];
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||
if (Is_dropout && is_final_write) {
|
||||
inv_sum *= params.rp_dropout;
|
||||
}
|
||||
out[jj] = fmha::fmul4(out[jj], inv_sum);
|
||||
}
|
||||
|
||||
// if (Is_dropout && Is_last) {
|
||||
// for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
// out[jj] = fmha::fmul4(out[jj], params.rp_dropout);
|
||||
// }
|
||||
// }
|
||||
|
||||
// Output the values.
|
||||
if (is_final_write) {
|
||||
gmem_o.template store<__half>(out, 0);
|
||||
} else {
|
||||
gmem_o_tmp.store(out, 0);
|
||||
}
|
||||
|
||||
// Move to the next part of the output.
|
||||
gmem_o.move(block_row_idx_to_move);
|
||||
if (!(Is_first && Is_last)) { gmem_o_tmp.move(block_row_idx_to_move); }
|
||||
gemm_q_k.reload_k();
|
||||
|
||||
// Make sure we are reading from the correct buffer.
|
||||
gemm_q_k.smem_q.move_to_next_read_buffer();
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
if (not_last_iter) {
|
||||
gemm_q_k.reload_q();
|
||||
}
|
||||
|
||||
if (mask_val_next == -1) break;
|
||||
mask_val = mask_val_next;
|
||||
block_row_idx += block_row_idx_to_move;
|
||||
|
||||
} // Outer loop over the sequence length.
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, typename Params>
|
||||
inline __device__ void device_block_1xN_loop(const Params ¶ms) {
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx;
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
|
||||
Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds));
|
||||
constexpr int M = Kernel_traits::Cta_tile_p::M;
|
||||
const int STEPS = (params.seqlen_q + M - 1) / M;
|
||||
|
||||
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
|
||||
if (params.seqlen_k == blocksize_c) {
|
||||
fmha::device_block_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, ph0, ph1, 0);
|
||||
} else {
|
||||
const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
|
||||
fmha::device_block_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, ph0, ph1, 0);
|
||||
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
|
||||
fmha::device_block_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, ph0, ph1, loop_step_idx);
|
||||
}
|
||||
fmha::device_block_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, ph0, ph1, max_loop_steps - 1);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <fmha.h>
|
||||
#include <fmha/utils.h>
|
||||
#include <fmha/smem_tile.h>
|
||||
#include <fmha/gmem_tile.h>
|
||||
#include <fmha/mask.h>
|
||||
#include <fmha/softmax.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Blockmask {
|
||||
|
||||
template<typename Params>
|
||||
__device__ Blockmask(const Params ¶ms, int loop_step_idx) :
|
||||
blockmask_ptr(params.blockmask + loop_step_idx * params.seqlen_q / 16) {
|
||||
}
|
||||
|
||||
__device__ int mask_val(int block_row_idx) const {
|
||||
return blockmask_ptr[block_row_idx];
|
||||
}
|
||||
|
||||
const int *blockmask_ptr;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
@@ -1,12 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "fmha_bwd_launch_template.h"
|
||||
|
||||
void run_fmha_bwd_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
FP16_SWITCH(params.is_bf16, ([&] {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
|
||||
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
|
||||
}));
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "fmha_bwd_launch_template.h"
|
||||
|
||||
void run_fmha_bwd_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
FP16_SWITCH(params.is_bf16, ([&] {
|
||||
if (params.seqlen_k == 128) {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
|
||||
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
|
||||
} else if (params.seqlen_k >= 256) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
|
||||
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
|
||||
}
|
||||
}));
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "fmha_bwd_launch_template.h"
|
||||
|
||||
void run_fmha_bwd_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
FP16_SWITCH(params.is_bf16, ([&] {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if (params.seqlen_k == 128) {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
|
||||
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
|
||||
} else if (params.seqlen_k >= 256) {
|
||||
if ((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0)) {
|
||||
// Don't share smem for K & V, and don't keep V in registers
|
||||
// This speeds things up by 2-3% by avoiding register spills, but it
|
||||
// uses more shared memory, which is fine on A100 and H100 but not other GPUs.
|
||||
// For other GPUs, we keep V in registers.
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
|
||||
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
|
||||
} else if (dprops->major == 8 && dprops->minor > 0) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
|
||||
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
|
||||
} else if (dprops->major == 7 && dprops->minor == 5) {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
|
||||
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
|
||||
}
|
||||
}
|
||||
}));
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "static_switch.h"
|
||||
#include "fmha.h"
|
||||
#include "fmha_dgrad_kernel_1xN_loop.h"
|
||||
|
||||
// Pick whether we should parallelize across seqlen_k (num_splits > 1) or not (num_splits=1).
|
||||
// Parallelizing will have better occupancy, but has some overhead due to having to zero out
|
||||
// dq_tmp and having to copy dq_tmp to dq.
|
||||
inline int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
|
||||
int blocksize, bool is_causal) {
|
||||
float n_waves_1 = float(batch_nheads) / (num_SMs * ctas_per_sm);
|
||||
float eff_1 = n_waves_1 / ceil(n_waves_1);
|
||||
int num_splits_parallel = seqlen / blocksize;
|
||||
float n_waves_parallel = float(batch_nheads * num_splits_parallel) / (num_SMs * ctas_per_sm);
|
||||
float eff_parallel_raw = n_waves_parallel / ceil(n_waves_parallel);
|
||||
float discount_factor;
|
||||
if (!is_causal) {
|
||||
discount_factor = 1.f + float(blocksize) / seqlen;
|
||||
} else { // For causal, parallelizing seems to help with load-balancing as well
|
||||
// For example, if headdim=128, seqlen >= 1280 always prefers parallel
|
||||
if (seqlen / blocksize >= 10) return num_splits_parallel;
|
||||
discount_factor = 1.f + 0.5 * float(blocksize) / seqlen;
|
||||
}
|
||||
float eff_parallel = eff_parallel_raw / discount_factor;
|
||||
return eff_1 >= eff_parallel ? 1 : num_splits_parallel;
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void fmha_bwd_dot_do_o_kernel(FMHA_dgrad_params params) {
|
||||
fmha::compute_dot_do_o<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
|
||||
__global__ void fmha_bwd_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
|
||||
fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
|
||||
__global__ void fmha_bwd_q_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) {
|
||||
fmha::compute_dq_dk_dv_seqparallel<Kernel_traits, Is_dropout, Is_causal>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
void run_fmha_bwd_loop(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
|
||||
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
|
||||
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
|
||||
constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
|
||||
|
||||
using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
|
||||
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
|
||||
static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
|
||||
static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
|
||||
|
||||
constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2;
|
||||
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
|
||||
// printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
|
||||
|
||||
bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping"
|
||||
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
|
||||
BOOL_SWITCH(is_dropout, IsDropoutConst, ([&] {
|
||||
auto kernel = params.is_causal
|
||||
? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
|
||||
: &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
|
||||
if (params.seqlen_k == blocksize_c) {
|
||||
kernel = params.is_causal
|
||||
? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/1>
|
||||
: &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/1>;
|
||||
} else if (params.seqlen_k == blocksize_c * 2) {
|
||||
kernel = params.is_causal
|
||||
? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
|
||||
: &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
|
||||
}
|
||||
auto kernel_seqparallel = params.is_causal
|
||||
? &fmha_bwd_q_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, true>
|
||||
: &fmha_bwd_q_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, false>;
|
||||
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
|
||||
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
|
||||
kernel_seqparallel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
}
|
||||
// Automatically set num_splits to maximize occupancy
|
||||
if (params.num_splits <= 0) {
|
||||
int ctas_per_sm;
|
||||
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size_dq_dk_dv);
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
|
||||
constexpr int M = Kernel_traits::Cta_tile_p::M;
|
||||
// We don't want more than 10 splits due to numerical error.
|
||||
// Numerical error on dk/dv scales as sqrt(num_splits).
|
||||
params.num_splits = num_splits_heuristic_bwd(
|
||||
params.b * params.h, dprops->multiProcessorCount,
|
||||
ctas_per_sm, params.seqlen_k, blocksize_c, params.is_causal
|
||||
);
|
||||
}
|
||||
if (configure) return;
|
||||
if (params.num_splits == 1) {
|
||||
dim3 grid(params.b, params.h, params.num_splits);
|
||||
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
|
||||
} else {
|
||||
dim3 grid_dot(params.b, params.h, (params.seqlen_q + 128 - 1) / 128);
|
||||
fmha_bwd_dot_do_o_kernel<Kernel_traits><<<grid_dot, Kernel_traits::THREADS, 0, stream>>>(params);
|
||||
int num_splits = params.seqlen_k / blocksize_c; // seqlen_k is divisible by blocksize_c
|
||||
dim3 grid(params.b, params.h, num_splits);
|
||||
kernel_seqparallel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
|
||||
}
|
||||
FMHA_CHECK_CUDA(cudaPeekAtLastError());
|
||||
}));
|
||||
}
|
||||
@@ -1,841 +0,0 @@
|
||||
/* Copyright (c) 2022, Tri Dao.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fmha_fprop_kernel_1xN.h"
|
||||
#include "fmha_kernel.h"
|
||||
#include <fmha/kernel_traits.h>
|
||||
#include <fmha/gemm.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int ROWS, int THREADS_PER_ROW, typename elem_type=__half, int M, typename Gmem_softmax_sum>
|
||||
inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M], const float scale,
|
||||
Gmem_softmax_sum gmem_softmax_d, int tidx) {
|
||||
float sum[M];
|
||||
fmha::SumOp<float> sum_op;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < M; ++mi) {
|
||||
sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(
|
||||
fmha::hmulsum8<elem_type>(do_[mi], o[mi]), sum_op
|
||||
) * scale;
|
||||
}
|
||||
const int dp_sum_row = tidx / THREADS_PER_ROW;
|
||||
if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) {
|
||||
gmem_softmax_d.store_row(reinterpret_cast<const uint32_t (&)[M]>(sum), dp_sum_row);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
|
||||
// This is used in the case where we want to parallelize the backward across seqlen_k.
|
||||
template<typename Kernel_traits, typename Params>
|
||||
inline __device__ void compute_dot_do_o(const Params ¶ms) {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using elem_type = typename Kernel_traits::elem_type;
|
||||
#else
|
||||
constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
|
||||
assert(is_fp16_type);
|
||||
using elem_type = __half;
|
||||
#endif
|
||||
|
||||
// The description of the CTA tile for the 1st batched GEMM.
|
||||
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
|
||||
// The description of the CTA tile for the 3rd batched GEMM.
|
||||
using Cta_tile_dkv =
|
||||
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
|
||||
|
||||
static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128);
|
||||
static_assert(Cta_tile_dkv::K == 16);
|
||||
|
||||
// The global memory tile to load dO.
|
||||
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
|
||||
|
||||
// The global memory tile to load O.Loading O here is similar to loading dO.
|
||||
using Gmem_tile_o = Gmem_tile_do;
|
||||
|
||||
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
// How many steps to jump per iteration.
|
||||
const int step_stride = gridDim.z;
|
||||
|
||||
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
|
||||
if( binfo.stop_early() ) return;
|
||||
|
||||
// Allocate the global memory tile loader for dO.
|
||||
Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
|
||||
// Allocate the global memory tile loader for O.
|
||||
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
|
||||
Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
|
||||
|
||||
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
|
||||
const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M;
|
||||
// Wind gmem tiles to the correct position.
|
||||
gmem_do.move(blockIdx.z);
|
||||
gmem_o.move(blockIdx.z);
|
||||
gmem_softmax_d.move(blockIdx.z);
|
||||
|
||||
// Load over the entire sequence length.
|
||||
for (int l = blockIdx.z; l < steps; l += step_stride) {
|
||||
if (l * Cta_tile_p::M >= binfo.actual_seqlen_q)
|
||||
break;
|
||||
|
||||
gmem_do.load();
|
||||
gmem_do.move(step_stride);
|
||||
gmem_o.load();
|
||||
gmem_o.move(step_stride);
|
||||
|
||||
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
|
||||
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
|
||||
);
|
||||
gmem_softmax_d.move(step_stride);
|
||||
} // Outer loop over the sequence length.
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params, typename Prng>
|
||||
inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng &ph,
|
||||
const int loop_step_idx) {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using elem_type = typename Kernel_traits::elem_type;
|
||||
#else
|
||||
constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
|
||||
assert(is_fp16_type);
|
||||
using elem_type = __half;
|
||||
#endif
|
||||
|
||||
// The description of the CTA tile for the 1st batched GEMM.
|
||||
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
|
||||
// The description of the CTA tile for the 2nd batched GEMM.
|
||||
using Cta_tile_dq = typename Kernel_traits::Cta_tile_o;
|
||||
// The description of the CTA tile for the 3rd batched GEMM.
|
||||
using Cta_tile_dkv =
|
||||
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
|
||||
|
||||
static_assert(Cta_tile_dkv::M == 512 || Cta_tile_dkv::M == 256 || Cta_tile_dkv::M == 128);
|
||||
static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128);
|
||||
static_assert(Cta_tile_dkv::K == 16);
|
||||
|
||||
// The MMA tile for the 1st GEMM.
|
||||
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
|
||||
// The MMA tile for the 2nd GEMM.
|
||||
using Mma_tile_dq = fmha::Hmma_tile<Cta_tile_dq>;
|
||||
// The MMA tile for the 3rd GEMM.
|
||||
using Mma_tile_dkv = fmha::Hmma_tile<Cta_tile_dkv>;
|
||||
|
||||
// The global memory tile to load Q.
|
||||
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
|
||||
// The shared memory tile to reload Q transposed.
|
||||
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dkv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
|
||||
|
||||
// The global memory tile to load K.
|
||||
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
|
||||
// The shared memory tile to swizzle K^T. Treat K^T as V
|
||||
using Smem_tile_kt = typename Kernel_traits::Smem_tile_v;
|
||||
|
||||
// Treating V as K. We need to use Kernel_traits::Smem_tile_k otherwise loading will be wrong
|
||||
// The global memory tile to load V.
|
||||
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_k;
|
||||
// The shared memory tile to swizzle V.
|
||||
using Smem_tile_v = typename Kernel_traits::Smem_tile_k;
|
||||
|
||||
// The global memory tile to load dO.
|
||||
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
|
||||
// The shared memory tile to load dO.
|
||||
// Treating dO as Q.
|
||||
using Smem_tile_do = typename Kernel_traits::Smem_tile_q;
|
||||
// The shared memory tile to reload dO transposed.
|
||||
using Smem_tile_dot = fmha::Smem_tile_b<Cta_tile_dkv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
|
||||
|
||||
// The global memory tile to load O.Loading O here is similar to loading dO.
|
||||
using Gmem_tile_o = Gmem_tile_do;
|
||||
|
||||
// The global memory tile to store dQ.
|
||||
using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_o;
|
||||
using Gmem_tile_dq_tmp = fmha::Gmem_tile_o<Cta_tile_dq, 4>;
|
||||
// The shared memory tile to swizzle dQ.
|
||||
using Smem_tile_dq = typename Kernel_traits::Smem_tile_o;
|
||||
|
||||
// The global memory tile to store dV.
|
||||
using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v;
|
||||
// The shared memory tile to swizzle dV.
|
||||
using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dkv>;
|
||||
|
||||
// The global memory tile to store dK.
|
||||
using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;
|
||||
// The shared memory tile to swizzle dK.
|
||||
using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dkv>;
|
||||
static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
|
||||
static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);
|
||||
|
||||
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
|
||||
|
||||
using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
|
||||
|
||||
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
|
||||
|
||||
// using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
|
||||
using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false, elem_type>;
|
||||
|
||||
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
// Shared memory layout if we keep V in registers:
|
||||
// dO | Q | K / V | dQ | S | dP | dP_sum
|
||||
// dV | dK
|
||||
// Shared memory layout if we keep V shared memory:
|
||||
// dO | Q | K | V | dQ | S | dP | dP_sum
|
||||
// dV | dK
|
||||
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
|
||||
// if( binfo.stop_early() ) return;
|
||||
if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
|
||||
|
||||
Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
|
||||
// Allocate the global memory tile loader for Q.
|
||||
Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
// Allocate the global memory tile loader for dQ.
|
||||
Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts,
|
||||
params.d, binfo, tidx);
|
||||
Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx);
|
||||
// Allocate the global memory tile loader for S.
|
||||
Gmem_tile_s gmem_s(params, binfo, tidx);
|
||||
|
||||
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
|
||||
|
||||
// Allocate the global memory tile loader for K.
|
||||
Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
// Allocate the global memory tile loader for V.
|
||||
Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
// The base pointer of smem_v;
|
||||
char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
|
||||
|
||||
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
|
||||
Smem_tile_v smem_v(smem_v_, tidx);
|
||||
// Allocate the shared memory tile loader for K^T. We use the same as K so be careful!!!
|
||||
Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
|
||||
|
||||
// Allocate the global memory tile loader for dO.
|
||||
Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
// Allocate the shared memory tile loader for dO.
|
||||
Smem_tile_do smem_do(&smem_[0], tidx);
|
||||
Smem_tile_dot smem_dot(&smem_[0], tidx);
|
||||
// Allocate the shared memory tile loader for Q^T.
|
||||
// TODO: assert that this points to the same memory as gemm_q_k.smem_q
|
||||
Smem_tile_qt smem_qt(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
|
||||
|
||||
Smem_tile_st smem_s(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], tidx);
|
||||
Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
|
||||
|
||||
// Allocate the global memory tile loader for O.
|
||||
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
|
||||
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
|
||||
Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
|
||||
|
||||
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
|
||||
Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
|
||||
|
||||
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
|
||||
int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
|
||||
// Otherwise we'd be reading out-of-bound memory before the loop
|
||||
if (begin * Cta_tile_p::M >= binfo.actual_seqlen_q) {
|
||||
// Still need to zero out dk and dv before returning
|
||||
static_assert(Smem_tile_dk::NUM_LDS == Smem_tile_dv::NUM_LDS);
|
||||
uint4 dkv_out[Smem_tile_dk::NUM_LDS];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Smem_tile_dk::NUM_LDS; ++i) { dkv_out[i] = make_uint4(0u, 0u, 0u, 0u); }
|
||||
Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
if (!Is_first) { gmem_dk.move(loop_step_idx); }
|
||||
gmem_dk.store(dkv_out);
|
||||
Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
if (!Is_first) { gmem_dv.move(loop_step_idx); }
|
||||
gmem_dv.store(dkv_out);
|
||||
return;
|
||||
}
|
||||
|
||||
const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M - begin;
|
||||
// Wind gmem tiles to the correct position.
|
||||
gmem_q.move(begin);
|
||||
gmem_do.move(begin);
|
||||
gmem_o.move(begin);
|
||||
if (!Seq_parallel) { gmem_dq.move(begin); } // If Seq_parallel, we're not using gmem_dq at all
|
||||
gmem_dq_tmp.move(begin);
|
||||
// TODO: need to move gmem_s if we want the intermediate result for debugging
|
||||
gmem_softmax_lse.move(begin);
|
||||
gmem_softmax_d.move(begin);
|
||||
|
||||
if (!Is_first) {
|
||||
gmem_k.move(loop_step_idx);
|
||||
gmem_v.move(loop_step_idx);
|
||||
}
|
||||
|
||||
// Trigger the loads for K.
|
||||
gmem_k.load();
|
||||
// Trigger the loads for Q.
|
||||
gmem_q.load();
|
||||
// Trigger the loads for V.
|
||||
gmem_v.load();
|
||||
// Trigger the loads for dO.
|
||||
gmem_do.load();
|
||||
// Trigger the loads for O.
|
||||
if (Is_first) { gmem_o.load(); }
|
||||
|
||||
float p_lse[Mma_tile_p::MMAS_M * 2];
|
||||
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
|
||||
|
||||
if (!Is_first) { __syncthreads(); }
|
||||
// Commit the data for Q, dO, and V to shared memory.
|
||||
gmem_q.commit(gemm_q_k.smem_q);
|
||||
gmem_do.commit(smem_do);
|
||||
if (Is_first) {
|
||||
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
|
||||
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
|
||||
);
|
||||
}
|
||||
|
||||
// // Instead of scaling dP by rp_dropout, we scale V instead
|
||||
// if (Is_dropout) {
|
||||
// const uint32_t scale_dropout = params.scale_dropout;
|
||||
// #pragma unroll
|
||||
// for(int it=0; it < Gmem_tile_v::LDGS; it++){
|
||||
// gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]);
|
||||
// }
|
||||
// }
|
||||
|
||||
gmem_v.commit(smem_v);
|
||||
|
||||
// const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
|
||||
// #pragma unroll
|
||||
// for(int it=0; it < Gmem_tile_k::LDGS; it++){
|
||||
// gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
|
||||
// }
|
||||
|
||||
// Commit the data for K to shared memory.
|
||||
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
|
||||
gmem_k.commit(gemm_q_k.smem_k);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Load the fragments for Q.
|
||||
gemm_q_k.load_q();
|
||||
|
||||
// Load the fragments for V. We keep the data in registers during the entire kernel.
|
||||
typename Smem_tile_v::Fragment frag_v[Kernel_traits::V_IN_REGS ? Mma_tile_p::MMAS_K : 2][Mma_tile_p::MMAS_N];
|
||||
if (Kernel_traits::V_IN_REGS) {
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
|
||||
smem_v.load(frag_v[ki], ki);
|
||||
}
|
||||
}
|
||||
|
||||
float dp_sum[Mma_tile_p::MMAS_M * 2];
|
||||
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
|
||||
|
||||
// Commit the data for V to shared memory if it has not been done already.
|
||||
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
|
||||
// Make sure we are done loading the fragments for K.
|
||||
__syncthreads();
|
||||
|
||||
// Commit the data to shared memory for V.
|
||||
gmem_k.commit(gemm_q_k.smem_k);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Load the fragments for K.
|
||||
gemm_q_k.load_k();
|
||||
// Load the fragments for K^T.
|
||||
// typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
|
||||
// smem_kt.load(frag_kt[0], 0);
|
||||
// typename Smem_tile_kt::Fragment frag_kt[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_N];
|
||||
// #pragma unroll
|
||||
// for( int ki = 0; ki < Mma_tile_dq::MMAS_K; ++ki ) {
|
||||
// smem_kt.load(frag_kt[ki], ki);
|
||||
// }
|
||||
|
||||
// Create the object to do the softmax.
|
||||
// We won't be using the shared memory for this softmax at all
|
||||
Softmax softmax(params, smem_, tidx);
|
||||
|
||||
// Declare the accumulators for the 3rd gemm.
|
||||
fmha::Fragment_accumulator acc_dv[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N];
|
||||
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dv);
|
||||
fmha::Fragment_accumulator acc_dk[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N];
|
||||
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dk);
|
||||
|
||||
// Load over the entire sequence length.
|
||||
for (int l = 0; l < steps; l++) {
|
||||
if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q)
|
||||
break;
|
||||
|
||||
// Load the fragments for V.
|
||||
// typename Smem_tile_v::Fragment frag_v[2][Mma_tile_p::MMAS_N];
|
||||
if (!Kernel_traits::V_IN_REGS) { smem_v.load(frag_v[0], 0); }
|
||||
|
||||
// Load the fragments for dO.
|
||||
typename Smem_tile_do::Fragment frag_do[2][Mma_tile_p::MMAS_M];
|
||||
smem_do.load(frag_do[0], 0);
|
||||
|
||||
// Declare the accumulators for the 1st gemm.
|
||||
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
|
||||
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
|
||||
|
||||
// Do this part of P^T = (Q * K^T)^T.
|
||||
gemm_q_k(acc_p);
|
||||
|
||||
// Load the mask for that iteration.
|
||||
mask.load(begin + l);
|
||||
|
||||
// Convert from the accumulator type to FP32 for Softmax.
|
||||
softmax.unpack_noscale(acc_p);
|
||||
// Apply the mask.
|
||||
softmax.apply_mask(mask);
|
||||
// Scale by log-sum-exp of the softmax
|
||||
// softmax.apply_exp(p_lse);
|
||||
softmax.template scale_apply_exp</*scale_max=*/false>(p_lse, params.scale_bmm1f);
|
||||
if (Is_dropout) {
|
||||
// softmax.apply_dropout(ph, params.p_dropout_in_uint);
|
||||
// softmax.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint);
|
||||
// softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t);
|
||||
unsigned int warp_idx = threadIdx.x / 32;
|
||||
// TODO: this should change after we rearrange the warps (e.g. cutlass branch)
|
||||
unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx;
|
||||
unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx;
|
||||
softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t, philox_subsequence);
|
||||
}
|
||||
|
||||
using Frag_p = fmha::Fragment_a<fmha::Row>;
|
||||
Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
|
||||
static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
|
||||
static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
|
||||
softmax.template pack<elem_type>(frag_p);
|
||||
|
||||
// Store s * dmask to smem for transpose
|
||||
smem_s.store(frag_p);
|
||||
|
||||
// Trigger the load for the next Q values.
|
||||
if (l + 1 < steps) {
|
||||
gemm_q_k.smem_q.move_to_next_write_buffer();
|
||||
gmem_q.move();
|
||||
gmem_q.load();
|
||||
}
|
||||
|
||||
// if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
|
||||
// // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
|
||||
// __syncthreads();
|
||||
// }
|
||||
|
||||
fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 8; ++ii) {
|
||||
acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Do this part of dP^T = (dO * V^T)^T.
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of dO values.
|
||||
smem_do.load(frag_do[ki & 1], ki);
|
||||
if (!Kernel_traits::V_IN_REGS) {
|
||||
smem_v.load(frag_v[ki & 1], ki);
|
||||
fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
|
||||
} else {
|
||||
fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
|
||||
}
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
|
||||
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
|
||||
// printf("frag_do=%.6f, %.6f\n", tmp.x, tmp.y);
|
||||
// tmp = __half22float2(reinterpret_cast<__half2 &>(frag_v[(ki - 1) & 1]));
|
||||
// printf("frag_v=%.6f, %.6f\n", tmp.x, tmp.y);
|
||||
// }
|
||||
}
|
||||
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_p::MMAS_K;
|
||||
if (!Kernel_traits::V_IN_REGS) {
|
||||
fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
|
||||
} else {
|
||||
fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
|
||||
}
|
||||
}
|
||||
|
||||
auto pointwise_mult = [](float p, float dp, float d) {
|
||||
return p * ((!Is_dropout) || p >= 0.f ? dp : d);
|
||||
};
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
|
||||
softmax.elt_[2 * mi + 0][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 0], acc_dp[mi][ni].elt(0), dp_sum[2 * mi + 0]);
|
||||
softmax.elt_[2 * mi + 0][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 1], acc_dp[mi][ni].elt(1), dp_sum[2 * mi + 0]);
|
||||
softmax.elt_[2 * mi + 0][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 2], acc_dp[mi][ni].elt(4), dp_sum[2 * mi + 0]);
|
||||
softmax.elt_[2 * mi + 0][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 3], acc_dp[mi][ni].elt(5), dp_sum[2 * mi + 0]);
|
||||
softmax.elt_[2 * mi + 1][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 0], acc_dp[mi][ni].elt(2), dp_sum[2 * mi + 1]);
|
||||
softmax.elt_[2 * mi + 1][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 1], acc_dp[mi][ni].elt(3), dp_sum[2 * mi + 1]);
|
||||
softmax.elt_[2 * mi + 1][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 2], acc_dp[mi][ni].elt(6), dp_sum[2 * mi + 1]);
|
||||
softmax.elt_[2 * mi + 1][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 3], acc_dp[mi][ni].elt(7), dp_sum[2 * mi + 1]);
|
||||
}
|
||||
}
|
||||
|
||||
// Load the fragments for K^T.
|
||||
typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
|
||||
smem_kt.load(frag_kt[0], 0);
|
||||
|
||||
// Trigger the load for the next dO values.
|
||||
if (l + 1 < steps) {
|
||||
smem_do.move_to_next_write_buffer();
|
||||
gmem_do.move();
|
||||
gmem_do.load();
|
||||
if (Is_first) {
|
||||
gmem_o.move();
|
||||
gmem_o.load();
|
||||
}
|
||||
}
|
||||
|
||||
softmax.template pack<elem_type>(frag_p);
|
||||
|
||||
// Store dp to smem for transpose
|
||||
smem_dp.store(frag_p);
|
||||
|
||||
// gmem_s.store(frag_p, mask);
|
||||
// gmem_s.move();
|
||||
|
||||
// Declare the accumulators for the 2nd gemm.
|
||||
fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
|
||||
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_dq::WARPS_K>::apply(acc_dq);
|
||||
|
||||
// Do this part of O = P^T * V^T.
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_dq::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
smem_kt.load(frag_kt[ki & 1], ki);
|
||||
// Do the math for the values already in registers.
|
||||
fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
|
||||
// fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
|
||||
}
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_dq::MMAS_K;
|
||||
fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
|
||||
// fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
|
||||
}
|
||||
|
||||
static_assert(Gmem_tile_dq::LOOPS == 1);
|
||||
|
||||
// Swizzle the elements and do the final reduction.
|
||||
// Need to syncthreads here, otherwise the smem_dq reads from the previous iteration
|
||||
// might happen after the smem_dq writes in this iteration.
|
||||
__syncthreads();
|
||||
smem_dq.store(acc_dq, 0);
|
||||
|
||||
typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_dkv::MMAS_N];
|
||||
static_assert(Smem_tile_dot::Fragment::NUM_REGS == 4);
|
||||
static_assert(Mma_tile_dkv::MMAS_K == 1);
|
||||
smem_dot.load(frag_dot[0], 0);
|
||||
|
||||
// Threads in a warp is communicating via shared memory (smem_s and smem_dp)
|
||||
__syncwarp();
|
||||
typename Smem_tile_st::Fragment frag_s[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
|
||||
smem_s.load(frag_s);
|
||||
|
||||
if (Is_dropout) {
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
|
||||
frag_s[ki][mi].template hrelu_<elem_type>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
smem_dot.load(frag_dot[ki & 1], ki);
|
||||
// Do the math for the values already in registers.
|
||||
fmha::gemm_cl<elem_type>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
|
||||
}
|
||||
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_dkv::MMAS_K;
|
||||
fmha::gemm_cl<elem_type>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
|
||||
}
|
||||
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// float2 tmp0 = __half22float2(reinterpret_cast<__half2 &>(frag_dot[0][0]));
|
||||
// printf("frag_dot[0][0]=%.6f, %.6f\n", tmp0.x, tmp0.y);
|
||||
// float2 tmp1 = __half22float2(reinterpret_cast<__half2 &>(frag_dot[0][1]));
|
||||
// printf("frag_dot[0][1]=%.6f, %.6f\n", tmp1.x, tmp1.y);
|
||||
// }
|
||||
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("l = %d, acc_dv[0][0]=%.6f, %.6f\n", l, acc_dv[0][0].elt(2), acc_dv[0][0].elt(3));
|
||||
// printf("l = %d, acc_dv[0][1]=%.6f, %.6f\n", l, acc_dv[0][1].elt(2), acc_dv[0][1].elt(3));
|
||||
// }
|
||||
// __syncthreads();
|
||||
// Commit the values for Q and dO into shared memory.
|
||||
if (l + 1 < steps) {
|
||||
gmem_q.commit(gemm_q_k.smem_q);
|
||||
}
|
||||
|
||||
uint4 dq_out[Gmem_tile_dq::STGS_PER_LOOP];
|
||||
if (!Is_first && !Seq_parallel) { gmem_dq_tmp.load(dq_out, 0); }
|
||||
|
||||
// __syncthreads();
|
||||
// Commit the values for Q and dO into shared memory.
|
||||
if (l + 1 < steps) {
|
||||
gmem_do.commit(smem_do);
|
||||
gmem_softmax_d.move();
|
||||
if (Is_first) {
|
||||
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
|
||||
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
|
||||
);
|
||||
}
|
||||
gmem_softmax_lse.move();
|
||||
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
|
||||
}
|
||||
|
||||
typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
|
||||
smem_dp.load(frag_dpt);
|
||||
|
||||
gemm_q_k.reload_k();
|
||||
|
||||
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dkv::MMAS_N];
|
||||
static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);
|
||||
static_assert(Mma_tile_dkv::MMAS_K == 1);
|
||||
smem_qt.load(frag_qt[0], 0);
|
||||
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
smem_qt.load(frag_qt[ki & 1], ki);
|
||||
// Do the math for the values already in registers.
|
||||
fmha::gemm_cl<elem_type>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
|
||||
}
|
||||
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_dkv::MMAS_K;
|
||||
fmha::gemm_cl<elem_type>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
|
||||
}
|
||||
|
||||
// Make sure dQ is in shared memory.
|
||||
__syncthreads();
|
||||
|
||||
if (l + 1 < steps) {
|
||||
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
|
||||
}
|
||||
|
||||
// Load from shared memory.
|
||||
smem_dq.template load</*zero_init=*/Is_first || Seq_parallel>(dq_out);
|
||||
|
||||
if (!Seq_parallel) {
|
||||
const bool is_final_write =
|
||||
Is_last
|
||||
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
|
||||
|| ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
|
||||
if (is_final_write) {
|
||||
// if (Is_dropout) {
|
||||
// dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout);
|
||||
// }
|
||||
for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) {
|
||||
// dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f);
|
||||
dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout);
|
||||
}
|
||||
// Output the values.
|
||||
gmem_dq.template store<elem_type>(dq_out, 0);
|
||||
// Move to the next part of the output.
|
||||
gmem_dq.move();
|
||||
// TODO: for parallel, need to deal with the dropout scaling
|
||||
} else {
|
||||
// Output the values.
|
||||
gmem_dq_tmp.store(dq_out, 0);
|
||||
}
|
||||
} else {
|
||||
// We always scale dq_out before writing in this case, since we don't want to
|
||||
// have to scale at the end when copying from dq_tmp to dq.
|
||||
for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) {
|
||||
// dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f);
|
||||
dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout);
|
||||
}
|
||||
gmem_dq_tmp.atomic_add(dq_out, 0);
|
||||
}
|
||||
|
||||
// Move to the next part of the output.
|
||||
if (!(Is_first && Is_last)) { gmem_dq_tmp.move(); }
|
||||
|
||||
// // Make sure the data is in shared memory.
|
||||
// __syncthreads();
|
||||
|
||||
// Commit the values for Q and dO into shared memory.
|
||||
if (l + 1 < steps) {
|
||||
gemm_q_k.smem_q.move_to_next_read_buffer();
|
||||
gemm_q_k.reload_q();
|
||||
smem_qt.move_to_next_read_buffer();
|
||||
// smem_qt.load(frag_qt[0], 0);
|
||||
smem_do.move_to_next_read_buffer();
|
||||
smem_dot.move_to_next_read_buffer();
|
||||
// smem_dot.load(frag_dot[0], 0);
|
||||
}
|
||||
|
||||
} // Outer loop over the sequence length.
|
||||
|
||||
if (Is_dropout) {
|
||||
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
|
||||
for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
|
||||
acc_dv[mi][ni].mul_(params.rp_dropout);
|
||||
}
|
||||
}
|
||||
}
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("l final, acc_dv[0][0]=%.6f, %.6f\n", acc_dv[0][0].elt(2), acc_dv[0][0].elt(3));
|
||||
// printf("l final, acc_dv[0][1]=%.6f, %.6f\n", acc_dv[0][1].elt(2), acc_dv[0][1].elt(3));
|
||||
// }
|
||||
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
|
||||
for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
|
||||
// acc_dk[mi][ni].mul_(Is_dropout ? params.rp_dropout * params.scale_bmm1f : params.scale_bmm1f);
|
||||
// acc_dk[mi][ni].mul_(params.scale_bmm1f);
|
||||
acc_dk[mi][ni].mul_(params.scale_bmm1_rp_dropout);
|
||||
}
|
||||
}
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1));
|
||||
// }
|
||||
|
||||
__syncthreads();
|
||||
// TODO [TD - 2022-05-04]: Are there cases where the shared mem for dV and dK are larger than
|
||||
// the total amount of shared mem?
|
||||
// Epilogue swizzle for dV
|
||||
Smem_tile_dv smem_dv(&smem_[0], tidx);
|
||||
smem_dv.template store<elem_type>(acc_dv);
|
||||
|
||||
// Epilogue swizzle for dK
|
||||
Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
|
||||
smem_dk.template store<elem_type>(acc_dk);
|
||||
|
||||
__syncthreads();
|
||||
uint4 dv_out[Smem_tile_dv::NUM_LDS];
|
||||
smem_dv.load(dv_out);
|
||||
Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
if (!Is_first) {
|
||||
gmem_dv.move(loop_step_idx);
|
||||
}
|
||||
gmem_dv.store(dv_out);
|
||||
|
||||
uint4 dk_out[Smem_tile_dk::NUM_LDS];
|
||||
smem_dk.load(dk_out);
|
||||
Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
if (!Is_first) {
|
||||
gmem_dk.move(loop_step_idx);
|
||||
}
|
||||
gmem_dk.store(dk_out);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// loop_steps = -1 means the number of steps will be params.seqlen_k / Kernel_traits::Cta_tile_p::N.
|
||||
// This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2.
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1, typename Params>
|
||||
inline __device__ void compute_dq_dk_dv_1xN(const Params ¶ms) {
|
||||
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
auto seed = params.rng_state[0];
|
||||
auto offset = params.rng_state[1];
|
||||
Philox ph(seed, 0, offset + (bidb * params.h + bidh) * 32 + tidx % 32);
|
||||
|
||||
if (loop_steps == 1) {
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
|
||||
} else if (loop_steps == 2) {
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, 1);
|
||||
} else {
|
||||
if (params.seqlen_k == blocksize_c) {
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
|
||||
} else {
|
||||
const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
|
||||
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false>(params, ph, loop_step_idx);
|
||||
}
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, max_loop_steps - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, typename Params>
|
||||
inline __device__ void compute_dq_dk_dv_seqparallel(const Params ¶ms) {
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
auto seed = params.rng_state[0];
|
||||
auto offset = params.rng_state[1];
|
||||
Philox ph(seed, 0, offset + (bidb * params.h + bidh) * 32 + tidx % 32);
|
||||
|
||||
int loop_step_idx = blockIdx.z;
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false, /*Seq_parallel=*/true>(params, ph, loop_step_idx);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
@@ -1,707 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2022, Tri Dao.
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fmha_kernel.h"
|
||||
#include <fmha/kernel_traits.h>
|
||||
#include <fmha/gemm.h>
|
||||
#include <fmha/utils.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits>
|
||||
struct Gemm_Q_K_base {
|
||||
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
|
||||
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
|
||||
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
|
||||
using Fragment_q = typename Smem_tile_q::Fragment;
|
||||
using Fragment_k = typename Smem_tile_k::Fragment;
|
||||
|
||||
// The description of the CTA tile for the 1st batched GEMM.
|
||||
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
|
||||
|
||||
// The MMA tile for the 1st GEMM.
|
||||
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
|
||||
|
||||
static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2;
|
||||
|
||||
__device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx)
|
||||
: smem_q(smem_ptr_q, tidx)
|
||||
, smem_k(smem_ptr_k, tidx) {
|
||||
|
||||
}
|
||||
|
||||
__device__ inline void load_q() {
|
||||
smem_q.load(frag_q[0], 0);
|
||||
}
|
||||
|
||||
__device__ inline void reload_q() {
|
||||
smem_q.load(frag_q[0], 0);
|
||||
}
|
||||
|
||||
Fragment_q frag_q[2][Mma_tile_p::MMAS_M];
|
||||
Smem_tile_q smem_q;
|
||||
Smem_tile_k smem_k;
|
||||
};
|
||||
|
||||
template<typename Kernel_traits, bool K_in_regs, typename elem_type_=__half>
|
||||
struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
|
||||
|
||||
using Base = Gemm_Q_K_base<Kernel_traits>;
|
||||
using Smem_tile_o = typename Base::Smem_tile_o;
|
||||
using Smem_tile_q = typename Base::Smem_tile_q;
|
||||
using Smem_tile_k = typename Base::Smem_tile_k;
|
||||
using Fragment_k = typename Base::Fragment_k;
|
||||
using Mma_tile_p = typename Base::Mma_tile_p;
|
||||
using elem_type = elem_type_;
|
||||
|
||||
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
|
||||
// If V is stored in shared memory, we can't load K using the same shared memory.
|
||||
static_assert(Kernel_traits::V_IN_REGS);
|
||||
|
||||
static constexpr int SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE;
|
||||
static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE;
|
||||
static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE);
|
||||
|
||||
// Q | K / V
|
||||
// | O | SOFTMAX
|
||||
static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
|
||||
+ std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE,
|
||||
Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX);
|
||||
|
||||
__device__ inline Gemm_Q_K(char * smem_, const int tidx)
|
||||
: Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
|
||||
}
|
||||
|
||||
__device__ inline void load_k(){
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
|
||||
Base::smem_k.load(frag_k[ki], ki);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Acc, int M, int N>
|
||||
__device__ inline void operator()(Acc (&acc_p)[M][N]){
|
||||
// Do this part of P^T = (Q * K^T)^T.
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
Base::smem_q.load(Base::frag_q[ki & 1], ki);
|
||||
// Do the math for the values already in registers.
|
||||
fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
|
||||
}
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_p::MMAS_K;
|
||||
fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void reload_k(){
|
||||
// Noop.
|
||||
}
|
||||
|
||||
Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
|
||||
};
|
||||
|
||||
|
||||
template<typename Kernel_traits, typename elem_type_>
|
||||
struct Gemm_Q_K<Kernel_traits, false, elem_type_> : public Gemm_Q_K_base<Kernel_traits> {
|
||||
using Base = Gemm_Q_K_base<Kernel_traits>;
|
||||
using Smem_tile_o = typename Base::Smem_tile_o;
|
||||
using Smem_tile_q = typename Base::Smem_tile_q;
|
||||
using Smem_tile_k = typename Base::Smem_tile_k;
|
||||
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
|
||||
using Fragment_k = typename Base::Fragment_k;
|
||||
using Mma_tile_p = typename Base::Mma_tile_p;
|
||||
using elem_type = elem_type_;
|
||||
Fragment_k frag_k[2][Mma_tile_p::MMAS_N];
|
||||
|
||||
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
|
||||
static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS;
|
||||
static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V);
|
||||
|
||||
static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE);
|
||||
static_assert(Smem_tile_v::BYTES_PER_TILE == (int) Smem_tile_k::BYTES_PER_TILE);
|
||||
static constexpr int SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE;
|
||||
static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE;
|
||||
|
||||
// If V_IN_REGS and SHARE_SMEM_FOR_K_AND_V: Q | K/V | O | SOFTMAX
|
||||
// If !V_IN_REGS (then !SHARE_SMEM_FOR_K_AND_V): Q | K | V | O | SOFTMAX
|
||||
static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
|
||||
+ (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE
|
||||
+ Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX;
|
||||
|
||||
__device__ inline Gemm_Q_K(char * smem_, const int tidx)
|
||||
: Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
|
||||
}
|
||||
|
||||
__device__ inline void load_k(){
|
||||
Base::smem_k.load(frag_k[0], 0);
|
||||
}
|
||||
|
||||
template<typename Acc, int M, int N>
|
||||
__device__ inline void operator()(Acc (&acc_p)[M][N]){
|
||||
// Do this part of P^T = (Q * K^T)^T.
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
Base::smem_q.load(Base::frag_q[ki & 1], ki);
|
||||
Base::smem_k.load(frag_k[ki & 1], ki);
|
||||
// Do the math for the values already in registers.
|
||||
fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
|
||||
}
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_p::MMAS_K;
|
||||
fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void reload_k(){
|
||||
Base::smem_k.load(frag_k[0], 0);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Kernel_traits>
|
||||
constexpr size_t get_dynamic_smem_size(){
|
||||
return Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>::SMEM_BYTES;
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
|
||||
inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const int bidh, int steps, Prng &ph, const int loop_step_idx) {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using elem_type = typename Kernel_traits::elem_type;
|
||||
#else
|
||||
constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
|
||||
assert(is_fp16_type);
|
||||
using elem_type = __half;
|
||||
#endif
|
||||
|
||||
// The description of the CTA tile for the 1st batched GEMM.
|
||||
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
|
||||
// The description of the CTA tile for the 2nd batched GEMM.
|
||||
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
|
||||
|
||||
// The MMA tile for the 1st GEMM.
|
||||
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
|
||||
// The MMA tile for the 2nd GEMM.
|
||||
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
|
||||
|
||||
// The global memory tile to load Q.
|
||||
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
|
||||
|
||||
// The global memory tile to load K.
|
||||
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
|
||||
|
||||
// The global memory tile to load V.
|
||||
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
|
||||
// The shared memory tile to swizzle V.
|
||||
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
|
||||
|
||||
// The global memory tile to store O.
|
||||
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
|
||||
using Gmem_tile_o_tmp = fmha::Gmem_tile_o<Cta_tile_o, 4>;
|
||||
// The shared memory tile to swizzle O.
|
||||
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
|
||||
|
||||
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
|
||||
|
||||
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
|
||||
|
||||
using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum;
|
||||
|
||||
using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS, elem_type>;
|
||||
|
||||
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
// How many steps to jump per iteration, which is the same as params.num_splits.
|
||||
const int step_stride = gridDim.z;
|
||||
|
||||
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
|
||||
// if( binfo.stop_early() ) return;
|
||||
if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
|
||||
|
||||
Gemm1 gemm_q_k(smem_, tidx);
|
||||
// Allocate the global memory tile loader for Q.
|
||||
Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
// Allocate the global memory tile loader for O.
|
||||
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx);
|
||||
Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_tmp_row_stride_in_elts,
|
||||
params.o_tmp_head_stride_in_elts, params.d, binfo, tidx);
|
||||
// Allocate the global memory tile loader for S.
|
||||
Gmem_tile_s gmem_s(params, binfo, tidx);
|
||||
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
|
||||
|
||||
// Wind gmem tiles to the correct position.
|
||||
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
|
||||
int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
|
||||
// We want begin to be a multiple of gridDim.z
|
||||
// This is because the row indices processed by each threadblock must align between the
|
||||
// loop steps, otherwise we have a dependency between the blocks.
|
||||
// For example, threadblock with blockIdx.z == 1 must process row indices that are
|
||||
// k * gridDim.z + 1 for integer k.
|
||||
const int begin_mod_z = begin % gridDim.z;
|
||||
begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
|
||||
// Otherwise we'd be reading out-of-bound memory before the loop
|
||||
if ((begin + blockIdx.z) * Cta_tile_p::M >= binfo.actual_seqlen_q) return;
|
||||
const int steps_og = steps;
|
||||
steps -= begin;
|
||||
gmem_q.move(begin + blockIdx.z);
|
||||
gmem_o.move(begin + blockIdx.z);
|
||||
gmem_o_tmp.move(begin + blockIdx.z);
|
||||
if (Return_softmax) {
|
||||
gmem_s.move(begin + blockIdx.z);
|
||||
}
|
||||
gmem_softmax_lse.move(begin + blockIdx.z);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("begin = %d, steps = %d\n", begin, steps);
|
||||
// }
|
||||
|
||||
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
|
||||
|
||||
// Allocate the global memory tile loader for K.
|
||||
Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
// Allocate the global memory tile loader for V.
|
||||
Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
// The base pointer of smem_v;
|
||||
char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
|
||||
|
||||
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
|
||||
Smem_tile_v smem_v(smem_v_, tidx);
|
||||
|
||||
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
|
||||
Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);
|
||||
|
||||
if (!Is_first) {
|
||||
gmem_k.move(loop_step_idx);
|
||||
gmem_v.move(loop_step_idx);
|
||||
if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); }
|
||||
}
|
||||
|
||||
// Trigger the loads for K.
|
||||
gmem_k.load();
|
||||
// Trigger the loads for Q.
|
||||
gmem_q.load();
|
||||
// Trigger the loads for V.
|
||||
gmem_v.load();
|
||||
|
||||
if (!Is_first) { __syncthreads(); }
|
||||
|
||||
float p_prev_lse[Mma_tile_p::MMAS_M * 2];
|
||||
if (!Is_first) {
|
||||
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
|
||||
}
|
||||
|
||||
// Commit the data for Q and V to shared memory.
|
||||
gmem_q.commit(gemm_q_k.smem_q);
|
||||
gmem_v.commit(smem_v);
|
||||
|
||||
// const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
|
||||
// #pragma unroll
|
||||
// for(int it=0;it < Gmem_tile_k::LDGS;it++){
|
||||
// gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
|
||||
// }
|
||||
|
||||
// Commit the data for K to shared memory.
|
||||
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
|
||||
gmem_k.commit(gemm_q_k.smem_k);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Load the fragments for Q.
|
||||
gemm_q_k.load_q();
|
||||
|
||||
// Load the fragments for V. We keep the data in registers during the entire kernel.
|
||||
typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
|
||||
smem_v.load(frag_v[ki], ki);
|
||||
}
|
||||
|
||||
// Commit the data for V to shared memory if it has not been done already.
|
||||
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
|
||||
// Make sure we are done loading the fragments for K.
|
||||
__syncthreads();
|
||||
|
||||
// Commit the data to shared memory for V.
|
||||
gmem_k.commit(gemm_q_k.smem_k);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Load the fragments for K.
|
||||
gemm_q_k.load_k();
|
||||
|
||||
// Create the object to do the softmax.
|
||||
Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx);
|
||||
|
||||
Smem_softmax_sum smem_softmax_lse(reinterpret_cast<float *>(&smem_[Gemm1::SMEM_BYTES]), tidx);
|
||||
|
||||
// Load over the entire sequence length.
|
||||
for (int l = blockIdx.z; l < steps; l += step_stride) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z <= 1)) {
|
||||
// printf("l = %d\n", l);
|
||||
// }
|
||||
if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
|
||||
|
||||
// Declare the accumulators for the 1st gemm.
|
||||
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
|
||||
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
|
||||
|
||||
// Do this part of P = Q * K^T.
|
||||
gemm_q_k(acc_p);
|
||||
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1));
|
||||
// }
|
||||
|
||||
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
|
||||
if (!Is_first) { gmem_o_tmp.load(out, 0); }
|
||||
|
||||
// Trigger the load for the next Q values.
|
||||
if (l + step_stride < steps) {
|
||||
gemm_q_k.smem_q.move_to_next_write_buffer();
|
||||
gmem_q.move(step_stride);
|
||||
gmem_q.load();
|
||||
}
|
||||
|
||||
// Load the mask for that iteration.
|
||||
mask.load(begin + l);
|
||||
|
||||
// Convert from the accumulator type to FP32 for Softmax.
|
||||
softmax.unpack_noscale(acc_p);
|
||||
|
||||
// Apply the mask.
|
||||
softmax.apply_mask(mask);
|
||||
|
||||
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l < step_stride ) {
|
||||
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
|
||||
__syncthreads();
|
||||
}
|
||||
// if (!Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l >= 0)) {
|
||||
// printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]);
|
||||
// }
|
||||
// }
|
||||
// Compute the max.
|
||||
float p_max[Mma_tile_p::MMAS_M * 2];
|
||||
if (!Is_first) {
|
||||
smem_softmax_lse.store_pair(p_prev_lse);
|
||||
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; }
|
||||
for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; }
|
||||
}
|
||||
|
||||
// Trigger the load for the next LSE values.
|
||||
if (l + step_stride < steps) {
|
||||
if (!Is_first) {
|
||||
gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse),
|
||||
step_stride);
|
||||
}
|
||||
}
|
||||
|
||||
softmax.template reduce_max</*zero_init=*/Is_first>(p_max);
|
||||
|
||||
// if ((threadIdx.x == 0) && (l == 38)) {
|
||||
// printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : p_prev_lse[0], Is_first ? -10000.f : p_prev_lse[1]);
|
||||
// }
|
||||
|
||||
// if (!Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("after reduce_max=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
|
||||
// }
|
||||
// }
|
||||
|
||||
// Compute the exponential value.
|
||||
// softmax.apply_exp(p_max);
|
||||
softmax.scale_apply_exp(p_max, params.scale_bmm1f);
|
||||
|
||||
// if (!Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("after apply_exp=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
|
||||
// }
|
||||
// }
|
||||
|
||||
// Compute the sum.
|
||||
float p_sum[Mma_tile_p::MMAS_M * 2];
|
||||
// if (!Is_first) {
|
||||
// int warp = tidx / Cta_tile_p::THREADS_PER_WARP;
|
||||
// int lane = tidx % Cta_tile_p::THREADS_PER_WARP;
|
||||
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
|
||||
// p_sum[mi] = ((warp == 0) && (lane % 4 == 0)) ? expf(p_prev_lse[mi] - p_max[mi]) : 0;
|
||||
// }
|
||||
// }
|
||||
// softmax.reduce_sum(p_sum);
|
||||
softmax.reduce_sum_before_sync_(p_sum);
|
||||
// softmax.template reduce_sum_before_sync_</*zero_init=*/Is_first>(p_sum);
|
||||
|
||||
// float p_sum_log[Mma_tile_p::MMAS_M * 2];
|
||||
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; ++mi) {
|
||||
// float sum = p_sum[mi];
|
||||
// // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] + __logf(sum);
|
||||
// constexpr float kLog2e = M_LOG2E;
|
||||
// p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] * kLog2e + __log2f(sum);
|
||||
// }
|
||||
// // gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum));
|
||||
// gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum_log));
|
||||
// gmem_softmax_lse.move();
|
||||
|
||||
// // Finalize softmax on the accumulators of P^T.
|
||||
// softmax.scale(p_sum);
|
||||
|
||||
constexpr bool encode_dropout_in_sign_bit = Return_softmax;
|
||||
if (Is_dropout) {
|
||||
// softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph, params.p_dropout_in_uint);
|
||||
// softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph, ph1, params.p_dropout_in_uint);
|
||||
// softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph, ph1, params.p_dropout_in_uint16_t);
|
||||
unsigned int warp_idx = threadIdx.x / 32;
|
||||
// TODO: this should change after we rearrange the warps (e.g. cutlass branch)
|
||||
unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx;
|
||||
// We want to use actual_seqlen_k, not seqlen_k, since seqlen_k could be rounded
|
||||
// differently in the fwd and bwd pass. E.g., for d=128 on A100, fwd rounds seqlen_k
|
||||
// to multiples of 256 while bwd rounds seqlen_k to multiples of 128.
|
||||
unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx;
|
||||
softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph, params.p_dropout_in_uint16_t, philox_subsequence);
|
||||
}
|
||||
|
||||
using Frag_p = fmha::Fragment_a<fmha::Row>;
|
||||
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
|
||||
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
|
||||
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
|
||||
softmax.template pack<elem_type>(frag_p);
|
||||
if (Return_softmax) {
|
||||
gmem_s.store(frag_p, mask);
|
||||
gmem_s.move(step_stride);
|
||||
}
|
||||
|
||||
// Commit the values for Q into shared memory.
|
||||
if (l + step_stride < steps) {
|
||||
gmem_q.commit(gemm_q_k.smem_q);
|
||||
}
|
||||
|
||||
if (Is_dropout && encode_dropout_in_sign_bit) {
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
|
||||
frag_p[ki][mi].template hrelu_<elem_type>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Declare the accumulators for the 2nd gemm.
|
||||
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
|
||||
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
|
||||
|
||||
// Do this part of O = P^T * V^T.
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
|
||||
fmha::gemm_cl<elem_type>(acc_o, frag_p[ki], frag_v[ki]);
|
||||
// if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki]));
|
||||
// float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki]));
|
||||
// printf("Per warp, threadIdx.x = %d, frag_p = %.6f, %.6f, frag_v = %.6f, %.6f, acc_o=%.6f\n", threadIdx.x, tmp_p.x, tmp_p.y, tmp_v.x, tmp_v.y, acc_o[0][0].elt(0));
|
||||
// }
|
||||
}
|
||||
|
||||
// if ((threadIdx.x % 32 == 16) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("Per warp, threadIdx.x = %d, acc_o=%.6f\n", threadIdx.x, acc_o[0][2].elt(0));
|
||||
// }
|
||||
|
||||
// The mapping from tidx to rows changes between the softmax and the
|
||||
// O-reduction. So we recalculate the max.
|
||||
float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
|
||||
int rows[Gmem_tile_o::STGS_PER_LOOP];
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
|
||||
}
|
||||
softmax.reduce_max_after_sync_(p_max_o, rows);
|
||||
static_assert(Mma_tile_o::MMAS_M == 1);
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
p_max_o[jj][0] *= params.scale_bmm1f;
|
||||
}
|
||||
float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP];
|
||||
if (!Is_first) {
|
||||
smem_softmax_lse.load(p_prev_scale_o, rows);
|
||||
}
|
||||
// if (!Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]);
|
||||
// }
|
||||
// }
|
||||
|
||||
static_assert(Gmem_tile_o::LOOPS == 1);
|
||||
|
||||
// Swizzle the elements and do the final reduction.
|
||||
smem_o.store(acc_o, 0);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
|
||||
static_assert(Mma_tile_o::MMAS_M == 1);
|
||||
float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
|
||||
softmax.reduce_sum_after_sync_(p_sum_o, rows);
|
||||
if (!Is_first) {
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]);
|
||||
p_sum_o[jj][0] += p_prev_scale_o[jj];
|
||||
}
|
||||
}
|
||||
|
||||
float p_sum_log[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
float sum = p_sum_o[jj][0];
|
||||
p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum);
|
||||
// if (sum == 0.f || sum != sum) {
|
||||
// printf("loop_step_idx = %d, l = %d, tidx = %d, sum = %.6f, p_max_o = %.6f\n", loop_step_idx, l, tidx, sum, p_max_o[jj][0]);
|
||||
// }
|
||||
// if (Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("p_sum_log=%.6f\n", p_sum_log[jj][0]);
|
||||
// }
|
||||
// }
|
||||
if (tidx % Gmem_tile_o::THREADS_PER_ROW == 0) {
|
||||
gmem_softmax_lse.store_row(
|
||||
reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
|
||||
}
|
||||
}
|
||||
gmem_softmax_lse.move(step_stride);
|
||||
|
||||
// Load from shared memory.
|
||||
if (!Is_first) {
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
out[jj] = fmha::fmul4(out[jj], p_prev_scale_o[jj]);
|
||||
}
|
||||
}
|
||||
smem_o.template load</*zero_init=*/Is_first>(out);
|
||||
|
||||
const bool is_final_write =
|
||||
Is_last
|
||||
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
|
||||
|| ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
float sum = p_sum_o[jj][0];
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||
if (Is_dropout && is_final_write) {
|
||||
inv_sum *= params.rp_dropout;
|
||||
}
|
||||
out[jj] = fmha::fmul4(out[jj], inv_sum);
|
||||
}
|
||||
|
||||
// if (Is_dropout && Is_last) {
|
||||
// for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
// out[jj] = fmha::fmul4(out[jj], params.rp_dropout);
|
||||
// }
|
||||
// }
|
||||
|
||||
// Output the values.
|
||||
if (is_final_write) {
|
||||
gmem_o.template store<elem_type>(out, 0);
|
||||
gmem_o.move(step_stride);
|
||||
} else {
|
||||
gmem_o_tmp.store(out, 0);
|
||||
}
|
||||
|
||||
// Move to the next part of the output.
|
||||
if (!(Is_first && Is_last)) { gmem_o_tmp.move(step_stride); }
|
||||
gemm_q_k.reload_k();
|
||||
|
||||
// Make sure we are reading from the correct buffer.
|
||||
gemm_q_k.smem_q.move_to_next_read_buffer();
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
if (l + step_stride < steps) {
|
||||
gemm_q_k.reload_q();
|
||||
}
|
||||
} // Outer loop over the sequence length.
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, typename Params>
|
||||
inline __device__ void device_1xN_loop(const Params ¶ms) {
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The block index.
|
||||
const int bidx = gridDim.x * bidh + bidb;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
// We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
|
||||
// them to have the same number of threads or have to traverse the attention matrix
|
||||
// in the same order.
|
||||
// In the Philox RNG, we use the offset to store the batch, head, and the lane id
|
||||
// (within a warp). We use the subsequence to store the location of the 16 x 16 blocks within
|
||||
// the attention matrix. This way, as long as we have the batch, head, and the location of
|
||||
// the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
if (bidx == 0 && tidx == 0) {
|
||||
params.rng_state[0] = std::get<0>(seeds);
|
||||
params.rng_state[1] = std::get<1>(seeds);
|
||||
}
|
||||
Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
|
||||
constexpr int M = Kernel_traits::Cta_tile_p::M;
|
||||
const int STEPS = (params.seqlen_q + M - 1) / M;
|
||||
|
||||
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
|
||||
if (params.seqlen_k == blocksize_c) {
|
||||
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, ph, 0);
|
||||
} else {
|
||||
const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
|
||||
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, ph, 0);
|
||||
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
|
||||
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, ph, loop_step_idx);
|
||||
}
|
||||
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, ph, max_loop_steps - 1);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "fmha_fwd_launch_template.h"
|
||||
|
||||
void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params) {
|
||||
FP16_SWITCH(launch_params.params.is_bf16, ([&] {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_fwd_loop<Kernel_traits>(launch_params);
|
||||
}));
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "fmha_fwd_launch_template.h"
|
||||
|
||||
void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params) {
|
||||
FP16_SWITCH(launch_params.params.is_bf16, ([&] {
|
||||
if (launch_params.params.seqlen_k == 128) {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_fwd_loop<Kernel_traits>(launch_params);
|
||||
} else if (launch_params.params.seqlen_k >= 256) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_fwd_loop<Kernel_traits>(launch_params);
|
||||
}
|
||||
}));
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include "fmha_fwd_launch_template.h"
|
||||
|
||||
void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params) {
|
||||
FP16_SWITCH(launch_params.params.is_bf16, ([&] {
|
||||
if (launch_params.params.seqlen_k == 128) {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_fwd_loop<Kernel_traits>(launch_params);
|
||||
} else if (launch_params.params.seqlen_k >= 256) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_fwd_loop<Kernel_traits>(launch_params);
|
||||
}
|
||||
}));
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
#include "static_switch.h"
|
||||
#include "fmha.h"
|
||||
#include "fmha_fprop_kernel_1xN.h"
|
||||
|
||||
// Find the number of splits that maximizes the occupancy. For example, if we have
|
||||
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
|
||||
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
|
||||
// splits as that would incur more HBM reads/writes.
|
||||
// So we find the best efficiency, then find the smallest number of splits that gets 95%
|
||||
// of the best efficiency.
|
||||
// [2022-11-25] TD: Mark this as "inline" otherwise we get "multiple definition" error.
|
||||
inline int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
|
||||
float max_efficiency = 0.f;
|
||||
std::vector<float> efficiency;
|
||||
efficiency.reserve(max_splits);
|
||||
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
|
||||
float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
|
||||
float eff = n_waves / ceil(n_waves);
|
||||
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
|
||||
if (eff > max_efficiency) { max_efficiency = eff; }
|
||||
efficiency.push_back(eff);
|
||||
}
|
||||
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
|
||||
if (efficiency[num_splits - 1] > 0.95 * max_efficiency) {
|
||||
// printf("num_splits chosen = %d\n", num_splits);
|
||||
return num_splits;
|
||||
}
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
|
||||
__global__ void fmha_fwd_loop_kernel(FMHA_fprop_params params) {
|
||||
fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
void run_fmha_fwd_loop(Launch_params<FMHA_fprop_params> &launch_params) {
|
||||
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
|
||||
const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
|
||||
|
||||
constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
|
||||
// Don't need smem_size_softmax_lse if we're not looping
|
||||
const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
|
||||
+ (loop_steps > 1 ? smem_size_softmax_lse : 0);
|
||||
|
||||
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
|
||||
// https://github.com/kokkos/kokkos-kernels/issues/349
|
||||
// https://github.com/HazyResearch/flash-attention/issues/21
|
||||
BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, ([&] {
|
||||
auto kernel = launch_params.params.is_causal
|
||||
? (launch_params.return_softmax
|
||||
? &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
|
||||
: &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
|
||||
: (launch_params.return_softmax
|
||||
? &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
|
||||
: &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
|
||||
if( smem_size >= 48 * 1024 ) {
|
||||
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
// Automatically set num_splits to maximize occupancy
|
||||
if (launch_params.params.num_splits <= 0) {
|
||||
int ctas_per_sm;
|
||||
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size);
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
|
||||
constexpr int M = Kernel_traits::Cta_tile_p::M;
|
||||
launch_params.params.num_splits = num_splits_heuristic_fwd(
|
||||
launch_params.params.b * launch_params.params.h, dprops->multiProcessorCount,
|
||||
ctas_per_sm,
|
||||
/*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1 / M))
|
||||
);
|
||||
}
|
||||
// printf("smem_size = %d\n", smem_size);
|
||||
dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
|
||||
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
|
||||
launch_params.params);
|
||||
FMHA_CHECK_CUDA(cudaPeekAtLastError());
|
||||
}));
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <philox.cuh>
|
||||
|
||||
#include <fmha.h>
|
||||
#include <fmha/utils.h>
|
||||
#include <fmha/smem_tile.h>
|
||||
#include <fmha/gmem_tile.h>
|
||||
#include <fmha/mask.h>
|
||||
#include <fmha/softmax.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int THREADS_PER_CTA>
|
||||
struct BlockInfoPadded {
|
||||
|
||||
template<typename Params>
|
||||
__device__ BlockInfoPadded(const Params ¶ms,
|
||||
const int bidb,
|
||||
const int bidh,
|
||||
const int tidx)
|
||||
: bidb(bidb), bidh(bidh), h(params.h) {
|
||||
|
||||
// The block index.
|
||||
sum_s_k = params.cu_seqlens_k[bidb];
|
||||
actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - sum_s_k;
|
||||
sum_s_q = params.cu_seqlens_q[bidb];
|
||||
actual_seqlen_q = params.cu_seqlens_q[bidb + 1] - sum_s_q;
|
||||
|
||||
tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
|
||||
}
|
||||
|
||||
__device__ bool stop_early(const int start_col = 0) const {
|
||||
return actual_seqlen_k <= start_col;
|
||||
}
|
||||
|
||||
int actual_seqlen_q;
|
||||
int actual_seqlen_k;
|
||||
int sum_s_q;
|
||||
int sum_s_k;
|
||||
int bidh;
|
||||
int bidb;
|
||||
int tidx_global;
|
||||
int h;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
@@ -1,100 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define FMHA_CHECK_CUDA( call ) \
|
||||
do { \
|
||||
cudaError_t status_ = call; \
|
||||
if( status_ != cudaSuccess ) { \
|
||||
fprintf( stderr, \
|
||||
"CUDA error (%s:%d): %s\n", \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
cudaGetErrorString( status_ ) ); \
|
||||
exit( 1 ); \
|
||||
} \
|
||||
} while( 0 )
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum Data_type { DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) {
|
||||
if( dtype == DATA_TYPE_FP16 ) {
|
||||
half x = __float2half_rn( norm );
|
||||
uint16_t h = reinterpret_cast<const uint16_t &>( x );
|
||||
ushort2 h2 = { h, h };
|
||||
alpha = reinterpret_cast<const uint32_t &>( h2 );
|
||||
} else if( dtype == DATA_TYPE_BF16 ) {
|
||||
__nv_bfloat16 x = __float2bfloat16( norm );
|
||||
uint16_t h = reinterpret_cast<const uint16_t &>( x );
|
||||
ushort2 h2 = { h, h };
|
||||
alpha = reinterpret_cast<const uint32_t &>( h2 );
|
||||
} else if( dtype == DATA_TYPE_FP32 ) {
|
||||
alpha = reinterpret_cast<const uint32_t &>( norm );
|
||||
} else if( dtype == DATA_TYPE_INT32 ) {
|
||||
int32_t inorm = static_cast<int32_t>( norm );
|
||||
alpha = reinterpret_cast<const uint32_t &>( inorm );
|
||||
} else {
|
||||
assert( false );
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) {
|
||||
switch( dtype ) {
|
||||
case DATA_TYPE_FP32:
|
||||
return n * 4;
|
||||
case DATA_TYPE_FP16:
|
||||
return n * 2;
|
||||
case DATA_TYPE_BF16:
|
||||
return n * 2;
|
||||
case DATA_TYPE_INT32:
|
||||
return n * 4;
|
||||
case DATA_TYPE_INT8:
|
||||
return n;
|
||||
default:
|
||||
assert( false );
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -0,0 +1,366 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/algorithm/copy.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
|
||||
struct Flash_kernel_traits {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using Element = elem_type;
|
||||
static constexpr bool Has_cp_async = true;
|
||||
#else
|
||||
using Element = cutlass::half_t;
|
||||
static constexpr bool Has_cp_async = false;
|
||||
#endif
|
||||
|
||||
using ElementAccum = float;
|
||||
using index_t = uint32_t;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using MMA_Atom_Arch = std::conditional_t<
|
||||
std::is_same_v<elem_type, cutlass::half_t>,
|
||||
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
|
||||
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
|
||||
>;
|
||||
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
|
||||
#else
|
||||
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
|
||||
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
||||
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
|
||||
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
|
||||
#else
|
||||
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
|
||||
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
|
||||
#endif
|
||||
};
|
||||
|
||||
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
|
||||
typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
|
||||
struct Flash_fwd_kernel_traits : public Base {
|
||||
using Element = typename Base::Element;
|
||||
using ElementAccum = typename Base::ElementAccum;
|
||||
using index_t = typename Base::index_t;
|
||||
static constexpr bool Has_cp_async = Base::Has_cp_async;
|
||||
using SmemCopyAtom = typename Base::SmemCopyAtom;
|
||||
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
|
||||
|
||||
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
|
||||
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
|
||||
|
||||
// The number of threads.
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
using TiledMma = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
|
||||
using SmemLayoutAtomQ = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutQ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutKV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutAtomVtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
|
||||
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||
using SmemLayoutVtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposed{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||
// Maybe the VtransposeNoSwizzle just needs to have the right shape
|
||||
// And the strides don't matter?
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
|
||||
|
||||
using SmemLayoutAtomO = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomO{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
static constexpr int kSmemQCount = size(SmemLayoutQ{});
|
||||
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
|
||||
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
|
||||
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
|
||||
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
|
||||
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
|
||||
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
|
||||
// to the same banks.
|
||||
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
|
||||
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
|
||||
// from the same address by the same threadblock. This is slightly faster.
|
||||
using Gmem_copy_struct = std::conditional_t<
|
||||
Has_cp_async,
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
|
||||
DefaultCopy
|
||||
>;
|
||||
using GmemTiledCopyQKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
using GmemTiledCopyO = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
|
||||
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
|
||||
Stride<Int<kGmemThreadsPerRowP>, _1>>;
|
||||
|
||||
using GmemTiledCopyP = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtomP{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
|
||||
};
|
||||
|
||||
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue.
|
||||
// No_double_buffer is another option to reduce smem usage, but will slow things down.
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
|
||||
int AtomLayoutMSdP_=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=2,
|
||||
bool Is_V_in_regs_=false, bool No_double_buffer_=false, typename elem_type=cutlass::half_t,
|
||||
typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
|
||||
struct Flash_bwd_kernel_traits : public Base {
|
||||
using Element = typename Base::Element;
|
||||
using ElementAccum = typename Base::ElementAccum;
|
||||
using index_t = typename Base::index_t;
|
||||
static constexpr bool Has_cp_async = Base::Has_cp_async;
|
||||
using SmemCopyAtom = typename Base::SmemCopyAtom;
|
||||
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
|
||||
|
||||
static constexpr bool Is_V_in_regs = Is_V_in_regs_;
|
||||
static constexpr bool No_double_buffer = No_double_buffer_;
|
||||
|
||||
// The number of threads.
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
|
||||
static_assert(kNWarps % AtomLayoutMSdP == 0);
|
||||
static_assert(kNWarps % AtomLayoutNdKV == 0);
|
||||
static_assert(kNWarps % AtomLayoutMdQ == 0);
|
||||
|
||||
using TiledMmaSdP = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
|
||||
using TiledMmadKV = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
|
||||
using TiledMmadQ = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
|
||||
using SmemLayoutAtomQdO = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutQdO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQdO{},
|
||||
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
|
||||
|
||||
using SmemLayoutAtomKV = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<kBlockM / kNWarps>, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutKV = decltype(tile_to_shape(
|
||||
// SmemLayoutAtomQdO{},
|
||||
SmemLayoutAtomKV{},
|
||||
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
|
||||
|
||||
using SmemLayoutAtomKtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||
using SmemLayoutKtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomKtransposed{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
|
||||
// Maybe the KtransposeNoSwizzle just needs to have the right shape
|
||||
// And the strides don't matter?
|
||||
using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
|
||||
|
||||
// TODO: generalize to other values of kBlockN
|
||||
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
|
||||
// static constexpr int kPBlockN = kBlockN;
|
||||
static_assert(kBlockN >= 64);
|
||||
// TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
|
||||
static constexpr int kPBlockN = 64;
|
||||
static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
|
||||
// static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
|
||||
static constexpr int kSwizzlePdS = 3;
|
||||
using SmemLayoutAtomPdS = decltype(
|
||||
composition(Swizzle<kSwizzlePdS, 3, 3>{},
|
||||
Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,
|
||||
Stride<Int<kPBlockN>, _1>>{}));
|
||||
using SmemLayoutPdS = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdS{},
|
||||
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
|
||||
using SmemLayoutAtomPdStransposed = decltype(
|
||||
composition(Swizzle<kSwizzlePdS, 3, 3>{},
|
||||
Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
|
||||
Stride<_1, Int<kPBlockN>>>{}));
|
||||
using SmemLayoutPdStransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdStransposed{},
|
||||
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
|
||||
using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
|
||||
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
using SmemLayoutAtomQdOtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||
using SmemLayoutQdOtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQdOtransposed{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
|
||||
using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
|
||||
|
||||
using SmemLayoutAtomdKV = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutdKV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomdKV{},
|
||||
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
|
||||
using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
using SmemLayoutAtomdQ = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutdQ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomdQ{},
|
||||
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
|
||||
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ
|
||||
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
|
||||
static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
|
||||
static constexpr int kSmemPCount = size(SmemLayoutPdS{});
|
||||
static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
|
||||
static constexpr int kSmemdPsumCount = kBlockM;
|
||||
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
|
||||
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
|
||||
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
|
||||
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
|
||||
static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
|
||||
static constexpr int kSmemSize = kSmemQdOSize
|
||||
+ (!Is_V_in_regs
|
||||
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
|
||||
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
|
||||
static constexpr int kSmemSize1colblock = kSmemQdOSize
|
||||
+ (!Is_V_in_regs
|
||||
? kSmemKVSize + kSmemdSSize + kSmemPSize
|
||||
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
|
||||
static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3
|
||||
+ kSmemdSSize + kSmemPSize;
|
||||
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||
// Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
|
||||
// to affect speed in practice.
|
||||
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
|
||||
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
|
||||
// from the same address by the same threadblock. This is slightly faster.
|
||||
using Gmem_copy_struct = std::conditional_t<
|
||||
Has_cp_async,
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
|
||||
DefaultCopy
|
||||
>;
|
||||
using GmemTiledCopyQKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
using GmemTiledCopydO = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||
using GmemTiledCopydKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||
using GmemTiledCopydQ = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||
using GmemLayoutAtomdQaccum = std::conditional_t<
|
||||
kBlockKSmem == 32,
|
||||
Layout<Shape <_32, _8>, // Thread layout, 8 threads per row
|
||||
Stride< _8, _1>>,
|
||||
Layout<Shape <_16, _16>, // Thread layout, 16 threads per row
|
||||
Stride< _16, _1>>
|
||||
>;
|
||||
using GmemTiledCopydQaccum = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
|
||||
GmemLayoutAtomdQaccum{},
|
||||
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
|
||||
|
||||
using GmemTiledCopydQaccumAtomicAdd = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
|
||||
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
|
||||
Stride<_32, _1>>{},
|
||||
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -0,0 +1,159 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/algorithm/copy.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
|
||||
struct Flash_kernel_traits_sm90 {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using Element = elem_type;
|
||||
static constexpr bool Has_cp_async = true;
|
||||
#else
|
||||
using Element = cutlass::half_t;
|
||||
static constexpr bool Has_cp_async = false;
|
||||
#endif
|
||||
|
||||
using ElementAccum = float;
|
||||
using index_t = uint32_t;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using MMA_Atom_Arch = std::conditional_t<
|
||||
std::is_same_v<elem_type, cutlass::half_t>,
|
||||
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
|
||||
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
|
||||
>;
|
||||
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
|
||||
#else
|
||||
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
|
||||
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
||||
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
|
||||
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
|
||||
#else
|
||||
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
|
||||
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
|
||||
#endif
|
||||
};
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
|
||||
typename Base=Flash_kernel_traits_sm90<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
|
||||
struct Flash_fwd_kernel_traits : public Base {
|
||||
using Element = typename Base::Element;
|
||||
using ElementAccum = typename Base::ElementAccum;
|
||||
using index_t = typename Base::index_t;
|
||||
static constexpr bool Has_cp_async = Base::Has_cp_async;
|
||||
using SmemCopyAtom = typename Base::SmemCopyAtom;
|
||||
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
|
||||
|
||||
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
|
||||
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
|
||||
|
||||
// The number of threads.
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
using TiledMma = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
|
||||
using SmemLayoutAtomQ = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutQ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutKV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutAtomVtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
|
||||
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||
using SmemLayoutVtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposed{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||
// Maybe the VtransposeNoSwizzle just needs to have the right shape
|
||||
// And the strides don't matter?
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
|
||||
|
||||
using SmemLayoutAtomO = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomO{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
static constexpr int kSmemQCount = size(SmemLayoutQ{});
|
||||
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
|
||||
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
|
||||
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
|
||||
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
|
||||
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
|
||||
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
|
||||
// to the same banks.
|
||||
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
|
||||
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
|
||||
// from the same address by the same threadblock. This is slightly faster.
|
||||
using Gmem_copy_struct = std::conditional_t<
|
||||
Has_cp_async,
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
|
||||
DefaultCopy
|
||||
>;
|
||||
using GmemTiledCopyQKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
using GmemTiledCopyO = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
|
||||
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
|
||||
Stride<Int<kGmemThreadsPerRowP>, _1>>;
|
||||
|
||||
using GmemTiledCopyP = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtomP{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
+104
-96
@@ -1,8 +1,55 @@
|
||||
// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/multihead_attn/philox.cuh
|
||||
// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
|
||||
// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
|
||||
#pragma once
|
||||
// Philox CUDA.
|
||||
|
||||
namespace flash {
|
||||
|
||||
struct ull2 {
|
||||
unsigned long long x;
|
||||
unsigned long long y;
|
||||
};
|
||||
|
||||
inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
|
||||
uint2 *res;
|
||||
unsigned long long tmp;
|
||||
asm ("mul.wide.u32 %0, %1, %2;\n\t"
|
||||
: "=l"(tmp)
|
||||
: "r"(a), "r"(b));
|
||||
res = (uint2*)(&tmp);
|
||||
return *res;
|
||||
}
|
||||
|
||||
inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
|
||||
constexpr unsigned long kPhiloxSA = 0xD2511F53;
|
||||
constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
|
||||
uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
|
||||
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline __device__ uint4 philox(unsigned long long seed,
|
||||
unsigned long long subsequence,
|
||||
unsigned long long offset) {
|
||||
constexpr unsigned long kPhilox10A = 0x9E3779B9;
|
||||
constexpr unsigned long kPhilox10B = 0xBB67AE85;
|
||||
uint2 key = reinterpret_cast<uint2&>(seed);
|
||||
uint4 counter;
|
||||
ull2 *tmp = reinterpret_cast<ull2*>(&counter);
|
||||
tmp->x = offset;
|
||||
tmp->y = subsequence;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 6; i++) {
|
||||
counter = philox_single_round(counter, key);
|
||||
key.x += (kPhilox10A);
|
||||
key.y += (kPhilox10B);
|
||||
}
|
||||
uint4 output = philox_single_round(counter, key);
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace flash
|
||||
|
||||
namespace {
|
||||
|
||||
class Philox {
|
||||
@@ -10,7 +57,10 @@ public:
|
||||
__device__ inline Philox(unsigned long long seed,
|
||||
unsigned long long subsequence,
|
||||
unsigned long long offset)
|
||||
: key(reinterpret_cast<const uint2&>(seed)) {
|
||||
: STATE(0)
|
||||
, seed_(seed)
|
||||
, offset_(offset)
|
||||
, key(reinterpret_cast<const uint2&>(seed)) {
|
||||
//key.x = (unsigned int)seed;
|
||||
//key.y = (unsigned int)(seed >> 32);
|
||||
//counter = make_uint4(0, 0, 0, 0);
|
||||
@@ -19,6 +69,7 @@ public:
|
||||
//STATE = 0;
|
||||
//incr_n(offset / 4);
|
||||
|
||||
// key = reinterpret_cast<const uint2&>(seed);
|
||||
ull2 * tmp = reinterpret_cast<ull2*>(&counter);
|
||||
tmp->x = offset / 4;
|
||||
tmp->y = subsequence;
|
||||
@@ -26,72 +77,64 @@ public:
|
||||
// printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// }
|
||||
}
|
||||
|
||||
__device__ inline uint4 operator()() {
|
||||
uint4 counter_ = counter;
|
||||
uint2 key_ = key;
|
||||
// 7-round philox
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 6; i++) {
|
||||
counter_ = single_round(counter_, key_);
|
||||
key_.x += (kPhilox10A);
|
||||
key_.y += (kPhilox10B);
|
||||
}
|
||||
uint4 output = single_round(counter_, key_);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
|
||||
// }
|
||||
incr();
|
||||
return output;
|
||||
}
|
||||
|
||||
__device__ inline uint4 operator()(const unsigned long long subsequence) {
|
||||
uint4 counter_ = counter;
|
||||
ull2 * tmp = reinterpret_cast<ull2*>(&counter_);
|
||||
tmp->y = subsequence;
|
||||
// if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("tidx = %d, counter_: %u, %u, %u, %u\n", threadIdx.x, counter_.x, counter_.y, counter_.z, counter_.w);
|
||||
// }
|
||||
uint2 key_ = key;
|
||||
// 7-round philox
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 6; i++) {
|
||||
counter_ = single_round(counter_, key_);
|
||||
key_.x += (kPhilox10A);
|
||||
key_.y += (kPhilox10B);
|
||||
}
|
||||
uint4 output = single_round(counter_, key_);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
|
||||
// }
|
||||
return output;
|
||||
// // if (STATE == 0) {
|
||||
// uint4 counter_ = counter;
|
||||
// uint2 key_ = key;
|
||||
// // 7-round philox
|
||||
// #pragma unroll
|
||||
// for (int i = 0; i < 6; i++) {
|
||||
// counter_ = flash::philox_single_round(counter_, key_);
|
||||
// key_.x += (kPhilox10A);
|
||||
// key_.y += (kPhilox10B);
|
||||
// }
|
||||
// // output = philox_single_round(counter_, key_);
|
||||
// uint4 output = flash::philox_single_round(counter_, key_);
|
||||
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
|
||||
// // }
|
||||
// incr();
|
||||
// // }
|
||||
// // return a float4 directly
|
||||
// // unsigned long ret;
|
||||
// // switch(STATE) {
|
||||
// // case 0: ret = output.x; break;
|
||||
// // case 1: ret = output.y; break;
|
||||
// // case 2: ret = output.z; break;
|
||||
// // case 3: ret = output.w; break;
|
||||
// //}
|
||||
// // STATE = (STATE + 1) % 4;
|
||||
// return output;
|
||||
return flash::philox(seed_, offset_, offset_);
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned long long offset_, seed_;
|
||||
struct ull2 {
|
||||
uint64_t x;
|
||||
uint64_t y;
|
||||
};
|
||||
uint4 counter;
|
||||
// uint4 output;
|
||||
const uint2 key;
|
||||
unsigned int STATE;
|
||||
__device__ inline void incr_n(unsigned long long n) {
|
||||
unsigned int nlo = (unsigned int)(n);
|
||||
unsigned int nhi = (unsigned int)(n >> 32);
|
||||
counter.x += nlo;
|
||||
if (counter.x < nlo)
|
||||
nhi++;
|
||||
counter.y += nhi;
|
||||
if (nhi <= counter.y)
|
||||
return;
|
||||
if (++counter.z)
|
||||
return;
|
||||
++counter.w;
|
||||
}
|
||||
|
||||
// __device__ inline void incr_n(unsigned long long n) {
|
||||
// unsigned int nlo = (unsigned int)(n);
|
||||
// unsigned int nhi = (unsigned int)(n >> 32);
|
||||
// counter.x += nlo;
|
||||
// if (counter.x < nlo)
|
||||
// nhi++;
|
||||
// counter.y += nhi;
|
||||
// if (nhi <= counter.y)
|
||||
// return;
|
||||
// if (++counter.z)
|
||||
// return;
|
||||
// ++counter.w;
|
||||
// }
|
||||
|
||||
__device__ uint4 incr(uint4 ctr) {
|
||||
__device__ uint4 incr128 (uint4 ctr)
|
||||
{
|
||||
uint4 res;
|
||||
asm ("add.cc.u32 %0, %4, %8;\n\t"
|
||||
"addc.cc.u32 %1, %5, %9;\n\t"
|
||||
@@ -107,51 +150,16 @@ private:
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// }
|
||||
counter = incr(counter);
|
||||
counter = incr128(counter);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// }
|
||||
}
|
||||
|
||||
// __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
|
||||
// unsigned int *result_high) {
|
||||
// *result_high = __umulhi(a, b);
|
||||
// return a * b;
|
||||
// }
|
||||
|
||||
__device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
|
||||
uint2 *res;
|
||||
unsigned long long tmp;
|
||||
asm ("mul.wide.u32 %0, %1, %2;\n\t"
|
||||
: "=l"(tmp)
|
||||
: "r"(a), "r"(b));
|
||||
res = (uint2*)(&tmp);
|
||||
return *res;
|
||||
}
|
||||
|
||||
__device__ inline uint4 single_round(const uint4 ctr, const uint2 key) {
|
||||
//unsigned int hi0;
|
||||
//unsigned int hi1;
|
||||
//unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
|
||||
//unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
|
||||
//uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
|
||||
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
|
||||
uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
|
||||
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
|
||||
return ret;
|
||||
}
|
||||
|
||||
static const unsigned long kPhilox10A = 0x9E3779B9;
|
||||
static const unsigned long kPhilox10B = 0xBB67AE85;
|
||||
static const unsigned long kPhiloxSA = 0xD2511F53;
|
||||
static const unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||
// static const unsigned long kPhiloxSA = 0xD2511F53;
|
||||
// static const unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||
};
|
||||
|
||||
// Inverse of 2^32.
|
||||
constexpr float M_RAN_INVM32 = 2.3283064e-10f;
|
||||
__device__ __inline__ float4 uniform4(const uint4 x) {
|
||||
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
|
||||
x.w * M_RAN_INVM32);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -0,0 +1,272 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
|
||||
#include "philox.cuh"
|
||||
#include "utils.h"
|
||||
|
||||
namespace flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++){
|
||||
dst(i) = Allreduce<4>::run(src(i), op);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
thread_reduce_<zero_init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
|
||||
MaxOp<float> max_op;
|
||||
reduce_<zero_init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
|
||||
SumOp<float> sum_op;
|
||||
reduce_(tensor, sum, sum_op);
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
// If max is -inf, then all elements must have been -inf (possibly due to masking).
|
||||
// We don't want (-inf - (-inf)) since that would give NaN.
|
||||
// If we don't have float around M_LOG2E the multiplication is done in fp64.
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||
// max * log_2(e)) This allows the compiler to use the ffma
|
||||
// instruction instead of fadd and fmul separately.
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
MaxOp<float> max_op;
|
||||
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
max(mi) = max_op(max(mi), tensor(mi, ni));
|
||||
}
|
||||
max(mi) = Allreduce<4>::run(max(mi), max_op);
|
||||
// If max is -inf, then all elements must have been -inf (possibly due to masking).
|
||||
// We don't want (-inf - (-inf)) since that would give NaN.
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
|
||||
sum(mi) = 0;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||
// max * log_2(e)) This allows the compiler to use the ffma
|
||||
// instruction instead of fadd and fmul separately.
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
sum(mi) += tensor(mi, ni);
|
||||
}
|
||||
SumOp<float> sum_op;
|
||||
sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const uint32_t lane_id = threadIdx.x % 32;
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2;
|
||||
if (col_idx >= max_seqlen_k) {
|
||||
// Without the "make_coord" we get wrong results
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
tensor(mi, make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
|
||||
const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
|
||||
const uint32_t warp_row_stride) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const uint32_t lane_id = threadIdx.x % 32;
|
||||
// const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
|
||||
const uint32_t row_idx_offset = row_idx_offset_;
|
||||
const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const uint32_t row_idx = row_idx_base + i * 8;
|
||||
const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const uint32_t col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const uint32_t col_idx = col_idx_base + j;
|
||||
if (col_idx >= col_idx_limit) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
// if (cute::thread0()) {
|
||||
// printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
|
||||
// print(tensor(make_coord(i, mi), _));
|
||||
// // print(tensor(_, j + nj * size<1, 0>(tensor)));
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void apply_mask_causal_w_idx(
|
||||
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
|
||||
const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
|
||||
{
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 2, "Only support 2D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
|
||||
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
|
||||
tensor(mi, ni) = -INFINITY;
|
||||
}
|
||||
}
|
||||
// if (cute::thread0()) {
|
||||
// printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
|
||||
// print(tensor(_, make_coord(j, ni)));
|
||||
// // print(tensor(_, j + ni * size<1, 0>(tensor)));
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
|
||||
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
|
||||
unsigned long long seed, unsigned long long offset,
|
||||
uint32_t block_row_start, uint32_t block_col_start,
|
||||
uint32_t block_row_stride) {
|
||||
// tensor has shape (8, MMA_M, MMA_N / 2)
|
||||
using T = typename Engine::value_type;
|
||||
auto encode_dropout = [](bool keep, T val) {
|
||||
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
|
||||
};
|
||||
static_assert(decltype(size<2>(tensor))::value % 2 == 0);
|
||||
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
|
||||
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
|
||||
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
|
||||
uint2 rowcol = make_uint2(block_row_start, block_col_start);
|
||||
#pragma unroll
|
||||
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
|
||||
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
|
||||
uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
|
||||
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
|
||||
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
|
||||
// Special implementation for 16-bit types: we duplicate the threshold to the
|
||||
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
|
||||
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
|
||||
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
|
||||
// the random value is less than the threshold.
|
||||
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
|
||||
// We're exploiting the fact that floating point comparison is equivalent to integer
|
||||
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
|
||||
if (!encode_dropout_in_sign_bit
|
||||
&& (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
|
||||
uint16_t rnd_16[16];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
|
||||
uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++) {
|
||||
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
|
||||
// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
uint32_t mask;
|
||||
asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
|
||||
tensor_uint32(i) &= mask;
|
||||
}
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
|
||||
}
|
||||
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
}
|
||||
}
|
||||
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
|
||||
// // }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace flash
|
||||
@@ -1,6 +1,5 @@
|
||||
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
// and https://github.com/facebookresearch/xformers/blob/main/xformers/csrc/attention/cuda/fmha/gemm_kernel_utils.h#L8
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -10,31 +9,57 @@
|
||||
///
|
||||
/// Usage:
|
||||
/// ```
|
||||
/// BOOL_SWITCH(flag, BoolConst, ([&] {
|
||||
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
||||
/// some_function<BoolConst>(...);
|
||||
/// }));
|
||||
/// });
|
||||
/// ```
|
||||
/// We need "({" and "})" to make sure that the code is a single argument being passed to the macro.
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, F) \
|
||||
{ \
|
||||
if (COND) { \
|
||||
constexpr bool CONST_NAME = true; \
|
||||
F(); \
|
||||
} else { \
|
||||
constexpr bool CONST_NAME = false; \
|
||||
F(); \
|
||||
} \
|
||||
}
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
constexpr bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
// modified from BOOL_SWITCH
|
||||
// because MSVC cannot handle std::conditional with constexpr variable
|
||||
#define FP16_SWITCH(COND, F) \
|
||||
{ \
|
||||
if (COND) { \
|
||||
using elem_type = __nv_bfloat16; \
|
||||
F(); \
|
||||
} else { \
|
||||
using elem_type = __half; \
|
||||
F(); \
|
||||
} \
|
||||
}
|
||||
#define FP16_SWITCH(COND, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
using elem_type = cutlass::half_t; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
using elem_type = cutlass::bfloat16_t; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
|
||||
[&] { \
|
||||
if (HEADDIM <= 32) { \
|
||||
constexpr int kHeadDim = 32; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (HEADDIM <= 64) { \
|
||||
constexpr int kHeadDim = 64; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (HEADDIM <= 96) { \
|
||||
constexpr int kHeadDim = 96; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (HEADDIM <= 128) { \
|
||||
constexpr int kHeadDim = 128; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (HEADDIM <= 160) { \
|
||||
constexpr int kHeadDim = 160; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (HEADDIM <= 192) { \
|
||||
constexpr int kHeadDim = 192; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (HEADDIM <= 224) { \
|
||||
constexpr int kHeadDim = 224; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (HEADDIM <= 256) { \
|
||||
constexpr int kHeadDim = 256; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
@@ -0,0 +1,388 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
#include <cute/algorithm/gemm.hpp>
|
||||
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace flash {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
inline __device__ uint32_t relu2(const uint32_t x);
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
|
||||
uint32_t res;
|
||||
const uint32_t zero = 0u;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
|
||||
#else
|
||||
asm volatile( \
|
||||
"{\n" \
|
||||
"\t .reg .f16x2 sela;\n" \
|
||||
"\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
|
||||
"\t and.b32 %0, sela, %1;\n"
|
||||
"}\n" : "=r"(res) : "r"(x), "r"(zero));
|
||||
#endif
|
||||
return res;
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
template<>
|
||||
inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
|
||||
uint32_t res;
|
||||
const uint32_t zero = 0u;
|
||||
asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
|
||||
return res;
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
|
||||
template<typename T>
|
||||
inline __device__ uint32_t convert_relu2(const float2 x);
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
|
||||
uint32_t res;
|
||||
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
|
||||
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
|
||||
asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
|
||||
return res;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
|
||||
uint32_t res;
|
||||
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
|
||||
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
|
||||
asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
|
||||
return res;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
inline __device__ float2 half2_unpack(uint32_t a);
|
||||
|
||||
template <>
|
||||
inline __device__ float2 half2_unpack<__half>(uint32_t a) {
|
||||
return __half22float2(reinterpret_cast<__half2 (&)>(a));
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
template <>
|
||||
inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
|
||||
return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert two half2's or bf162's into float, then take their dot product.
|
||||
template <typename T>
|
||||
inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
|
||||
float2 af = flash::half2_unpack<T>(a);
|
||||
float2 bf = flash::half2_unpack<T>(b);
|
||||
return af.x * bf.x + af.y * bf.y;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
|
||||
template<typename T>
|
||||
inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
|
||||
float sum;
|
||||
sum = flash::hfma2_to_float<T>(a.x, b.x);
|
||||
sum += flash::hfma2_to_float<T>(a.y, b.y);
|
||||
sum += flash::hfma2_to_float<T>(a.z, b.z);
|
||||
sum += flash::hfma2_to_float<T>(a.w, b.w);
|
||||
return sum;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct MaxOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MaxOp<float> {
|
||||
// This is slightly faster
|
||||
__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
|
||||
typename Tensor2, typename Tensor3, typename Tensor4,
|
||||
typename TiledMma, typename TiledCopy0, typename TiledCopy1>
|
||||
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
|
||||
Tensor4 const& tCsB, TiledMma tiled_mma,
|
||||
TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
|
||||
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
|
||||
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
|
||||
if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
|
||||
if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<2>(tCrA); ++i) {
|
||||
if (i < size<2>(tCrA) - 1) {
|
||||
if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
|
||||
if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
|
||||
}
|
||||
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
|
||||
typename TiledMma, typename TiledCopy>
|
||||
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
|
||||
TiledMma tiled_mma, TiledCopy smem_thr_copy_B) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
|
||||
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
|
||||
copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<2>(tCrA); ++i) {
|
||||
if (i < size<2>(tCrA) - 1) {
|
||||
copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
|
||||
}
|
||||
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
template<typename Layout>
|
||||
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
template<typename MMA_traits, typename Layout>
|
||||
inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
|
||||
using X = Underscore;
|
||||
static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
|
||||
static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
|
||||
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
|
||||
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
|
||||
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
|
||||
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
|
||||
return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
|
||||
get<0, 1>(l),
|
||||
get<1, 1, 1>(l));
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||
// HACK: this requires tensor to be "contiguous"
|
||||
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
|
||||
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
static_assert(numel % 2 == 0);
|
||||
using value_t = typename Engine::value_type;
|
||||
// HACK: this requires tensor to be "contiguous"
|
||||
Tensor tensor_uint32 = recast<uint32_t>(tensor);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tensor_uint32); ++i) {
|
||||
tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
inline __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
|
||||
static_assert(std::is_same_v<float, From_type>);
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
static_assert(numel % 2 == 0);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
// HACK: this requires tensor to be "contiguous"
|
||||
Tensor tensor_float2 = recast<float2>(tensor);
|
||||
Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(out_uint32); ++i) {
|
||||
out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
|
||||
}
|
||||
Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
|
||||
#else
|
||||
Tensor out = flash::convert_type<To_type>(tensor);
|
||||
flash::relu_(out);
|
||||
#endif
|
||||
return out;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Blocks until all but N previous cp.async.commit_group operations have committed.
|
||||
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
|
||||
// (which is equivalent to commit_group then wait_group 0).
|
||||
// Instead we just call cp.async.wait_group 0, which is slightly faster.
|
||||
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
|
||||
template <int N>
|
||||
CUTE_HOST_DEVICE
|
||||
void cp_async_wait() {
|
||||
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
|
||||
asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
|
||||
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
|
||||
Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
// There's no case where !Clear_OOB_K && Clear_OOB_MN
|
||||
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || predicate_K(k)) {
|
||||
copy(thr_copy, S(_, m, k), D(_, m, k));
|
||||
} else if (Clear_OOB_K) {
|
||||
clear(D(_, m, k));
|
||||
}
|
||||
}
|
||||
} else if (Clear_OOB_MN) {
|
||||
clear(D(_, m, _));
|
||||
}
|
||||
}
|
||||
// TD [2023-04-13]: Strange that the code below can cause race condition.
|
||||
// I think it's because the copies are under an if statement.
|
||||
// if (Is_even_K) {
|
||||
// #pragma unroll
|
||||
// for (int m = 0; m < size<1>(S); ++m) {
|
||||
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
// copy(thr_copy, S(_, m, _), D(_, m, _));
|
||||
// } else if (Clear_OOB_MN) {
|
||||
// clear(D(_, m, _));
|
||||
// }
|
||||
// }
|
||||
// } else { // It's slightly faster in this case if iterate over K first
|
||||
// #pragma unroll
|
||||
// for (int k = 0; k < size<2>(S); ++k) {
|
||||
// if (predicate_K(k)) {
|
||||
// #pragma unroll
|
||||
// for (int m = 0; m < size<1>(S); ++m) {
|
||||
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
// copy(thr_copy, S(_, m, k), D(_, m, k));
|
||||
// } else if (Clear_OOB_MN) {
|
||||
// clear(D(_, m, k));
|
||||
// }
|
||||
// }
|
||||
// } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN
|
||||
// if (Clear_OOB_MN || Is_even_MN) {
|
||||
// clear(D(_, _, k));
|
||||
// } else {
|
||||
// #pragma unroll
|
||||
// for (int m = 0; m < size<1>(S); ++m) {
|
||||
// if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {
|
||||
// clear(D(_, m, k));
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace flash
|
||||
@@ -1 +1,8 @@
|
||||
__version__ = "1.0.9"
|
||||
__version__ = "2.0.0.post1"
|
||||
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
|
||||
|
||||
class FlashAttention(nn.Module):
|
||||
"""Implement the scaled dot product attention with softmax.
|
||||
Arguments
|
||||
---------
|
||||
softmax_scale: The temperature to use for the softmax attention.
|
||||
(default: 1/sqrt(d_keys) where d_keys is computed at
|
||||
runtime)
|
||||
attention_dropout: The dropout rate to apply to the attention
|
||||
(default: 0.0)
|
||||
"""
|
||||
def __init__(self, softmax_scale=None, attention_dropout=0.0):
|
||||
super().__init__()
|
||||
self.softmax_scale = softmax_scale
|
||||
self.dropout_p = attention_dropout
|
||||
|
||||
def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
|
||||
max_s=None, need_weights=False):
|
||||
"""Implements the multihead softmax attention.
|
||||
Arguments
|
||||
---------
|
||||
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
|
||||
if unpadded: (nnz, 3, h, d)
|
||||
key_padding_mask: a bool tensor of shape (B, S)
|
||||
"""
|
||||
assert not need_weights
|
||||
assert qkv.dtype in [torch.float16, torch.bfloat16]
|
||||
assert qkv.is_cuda
|
||||
|
||||
if cu_seqlens is None:
|
||||
batch_size = qkv.shape[0]
|
||||
seqlen = qkv.shape[1]
|
||||
if key_padding_mask is None:
|
||||
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
|
||||
max_s = seqlen
|
||||
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
|
||||
device=qkv.device)
|
||||
output = flash_attn_unpadded_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
|
||||
softmax_scale=self.softmax_scale, causal=causal
|
||||
)
|
||||
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
|
||||
else:
|
||||
nheads = qkv.shape[-2]
|
||||
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
|
||||
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
|
||||
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
|
||||
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
||||
x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
|
||||
softmax_scale=self.softmax_scale, causal=causal
|
||||
)
|
||||
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
|
||||
indices, batch_size, seqlen),
|
||||
'b s (h d) -> b s h d', h=nheads)
|
||||
else:
|
||||
assert max_s is not None
|
||||
output = flash_attn_unpadded_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
|
||||
softmax_scale=self.softmax_scale, causal=causal
|
||||
)
|
||||
|
||||
return output, None
|
||||
|
||||
|
||||
class FlashMHA(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
|
||||
causal=False, device=None, dtype=None) -> None:
|
||||
assert batch_first
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.causal = causal
|
||||
|
||||
self.num_heads = num_heads
|
||||
assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
|
||||
self.head_dim = self.embed_dim // num_heads
|
||||
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
|
||||
|
||||
self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
|
||||
self.inner_attn = FlashAttention(attention_dropout=attention_dropout)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
|
||||
|
||||
def forward(self, x, key_padding_mask=None, need_weights=False):
|
||||
"""x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
|
||||
key_padding_mask: bool tensor of shape (batch, seqlen)
|
||||
"""
|
||||
qkv = self.Wqkv(x)
|
||||
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
|
||||
context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights, causal=self.causal)
|
||||
return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
|
||||
+371
-212
@@ -1,48 +1,86 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import flash_attn_cuda
|
||||
import flash_attn_2_cuda as flash_attn_cuda
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def _get_block_size(device, head_dim, is_dropout):
|
||||
assert head_dim % 8 == 0 and head_dim <= 128
|
||||
return 256 if head_dim <= 64 else 128
|
||||
def _get_block_size(device, head_dim, is_dropout, is_causal):
|
||||
# This should match the block sizes in the CUDA kernel
|
||||
assert head_dim <= 256
|
||||
major, minor = torch.cuda.get_device_capability(device)
|
||||
is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
|
||||
is_sm80 = major == 8 and minor == 0
|
||||
is_sm90 = major == 9 and minor == 0
|
||||
if head_dim <= 32:
|
||||
return 128, 128
|
||||
if head_dim <= 64:
|
||||
return (128, 128) if not is_dropout else (128, 64)
|
||||
elif head_dim <= 96:
|
||||
return (64, 64) if (is_sm8x and is_causal) else (128, 64)
|
||||
elif head_dim <= 128:
|
||||
if is_sm8x:
|
||||
return (64, 64) if (not is_dropout and is_causal) else (128, 32)
|
||||
else:
|
||||
return 128, (64 if not is_dropout else 32)
|
||||
elif head_dim <= 160:
|
||||
if is_sm8x:
|
||||
return (128, 64) if not is_causal else (64, 64)
|
||||
else:
|
||||
return 128, 32
|
||||
elif head_dim <= 192:
|
||||
return (128, 64) if not is_dropout else (64, 64)
|
||||
elif head_dim <= 224:
|
||||
return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
|
||||
elif head_dim <= 256:
|
||||
return (128, 64) if is_sm80 else (64, 64)
|
||||
|
||||
|
||||
def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal, return_softmax, num_splits=0,
|
||||
generator=None):
|
||||
"""
|
||||
num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
|
||||
it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
|
||||
Don't change it unless you know what you're doing.
|
||||
"""
|
||||
softmax_lse, rng_state, *rest = flash_attn_cuda.fwd(
|
||||
q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, False, causal, return_softmax, num_splits, generator
|
||||
def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
|
||||
if q.stride(-1) != 1:
|
||||
q = q.contiguous()
|
||||
if k.stride(-1) != 1:
|
||||
k = k.contiguous()
|
||||
if v.stride(-1) != 1:
|
||||
v = v.contiguous()
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd(
|
||||
q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
|
||||
)
|
||||
return out, q, k, v, out_padded, softmax_lse, S_dmask
|
||||
|
||||
|
||||
def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal, return_softmax):
|
||||
if q.stride(-1) != 1:
|
||||
q = q.contiguous()
|
||||
if k.stride(-1) != 1:
|
||||
k = k.contiguous()
|
||||
if v.stride(-1) != 1:
|
||||
v = v.contiguous()
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd(
|
||||
q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, False, causal, return_softmax, None
|
||||
)
|
||||
# if out.isnan().any() or softmax_lse.isnan().any():
|
||||
# breakpoint()
|
||||
S_dmask = rest[0] if return_softmax else None
|
||||
return out, softmax_lse, rng_state, S_dmask
|
||||
return out, q, k, v, out_padded, softmax_lse, S_dmask
|
||||
|
||||
|
||||
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
|
||||
rng_state=None, num_splits=0, generator=None):
|
||||
"""
|
||||
num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
|
||||
not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
|
||||
Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same kernel
|
||||
as num_splits=3), so effectively the choices are 0, 1, and 2.
|
||||
This hyperparameter can be tuned for performance, but default value (heuristic) should work fine.
|
||||
"""
|
||||
dout = dout.contiguous() # CUDA code assumes that dout is contiguous
|
||||
_, _, _, softmax_d = flash_attn_cuda.bwd(
|
||||
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
|
||||
dropout_p, softmax_scale, causal):
|
||||
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None
|
||||
)
|
||||
return dq, dk, dv, softmax_d
|
||||
|
||||
|
||||
def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
|
||||
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal):
|
||||
dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal,
|
||||
num_splits, generator, rng_state)
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None
|
||||
)
|
||||
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
|
||||
# breakpoint()
|
||||
return dq, dk, dv, softmax_d
|
||||
@@ -51,186 +89,341 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
|
||||
class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal,
|
||||
return_softmax, deterministic):
|
||||
def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = qkv.shape[-1] ** (-0.5)
|
||||
out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
|
||||
qkv[:, 0], qkv[:, 1], qkv[:, 2], torch.empty_like(qkv[:, 0]), cu_seqlens, cu_seqlens,
|
||||
max_seqlen, max_seqlen, dropout_p, softmax_scale, causal=causal,
|
||||
return_softmax=return_softmax
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
|
||||
qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale,
|
||||
causal=causal, return_softmax=return_softmax and dropout_p > 0
|
||||
)
|
||||
ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
|
||||
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.max_seqlen = max_seqlen
|
||||
ctx.softmax_scale = softmax_scale
|
||||
ctx.causal = causal
|
||||
ctx.deterministic = deterministic
|
||||
return out if not return_softmax else (out, softmax_lse, S_dmask)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
|
||||
dqkv = torch.empty_like(qkv)
|
||||
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
|
||||
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
|
||||
_flash_attn_backward(
|
||||
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
|
||||
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
|
||||
ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
|
||||
dout, q, k, v, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
|
||||
ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
)
|
||||
return dqkv, None, None, None, None, None, None, None
|
||||
dqkv = dqkv[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dqkv, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = qkv.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
|
||||
qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
|
||||
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.max_seqlen = max_seqlen
|
||||
ctx.softmax_scale = softmax_scale
|
||||
ctx.causal = causal
|
||||
return out if not return_softmax else (out, softmax_lse, S_dmask)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
|
||||
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
|
||||
_flash_attn_varlen_backward(
|
||||
dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
|
||||
cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
|
||||
ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
)
|
||||
dqkv = dqkv[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dqkv, None, None, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, causal, return_softmax, deterministic):
|
||||
def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
|
||||
q, kv[:, 0], kv[:, 1], torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
|
||||
max_seqlen_k, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
|
||||
q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal=causal,
|
||||
return_softmax=return_softmax and dropout_p > 0
|
||||
)
|
||||
ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
|
||||
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.max_seqlen_q = max_seqlen_q
|
||||
ctx.max_seqlen_k = max_seqlen_k
|
||||
ctx.softmax_scale = softmax_scale
|
||||
ctx.causal = causal
|
||||
ctx.deterministic = deterministic
|
||||
return out if not return_softmax else (out, softmax_lse, S_dmask)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
|
||||
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
dq = torch.empty_like(q)
|
||||
dkv = torch.empty_like(kv)
|
||||
kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
|
||||
dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
|
||||
_flash_attn_backward(
|
||||
dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
|
||||
dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
|
||||
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
|
||||
dout, q, k, v, out, softmax_lse,
|
||||
dq, dkv[:, :, 0], dkv[:, :, 1], ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
)
|
||||
return dq, dkv, None, None, None, None, None, None, None, None, None
|
||||
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
dkv = dkv[..., :dout.shape[-1]]
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dq, dkv, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnFunc(torch.autograd.Function):
|
||||
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, causal, return_softmax, deterministic):
|
||||
def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, causal, return_softmax):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
|
||||
q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
|
||||
q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
|
||||
ctx.save_for_backward(q, k, v, out_padded, softmax_lse,
|
||||
cu_seqlens_q, cu_seqlens_k, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.max_seqlen_q = max_seqlen_q
|
||||
ctx.max_seqlen_k = max_seqlen_k
|
||||
ctx.softmax_scale = softmax_scale
|
||||
ctx.causal = causal
|
||||
ctx.deterministic = deterministic
|
||||
return out if not return_softmax else (out, softmax_lse, S_dmask)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
|
||||
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
||||
_flash_attn_backward(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
dq = torch.empty_like(q)
|
||||
kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
|
||||
dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
|
||||
_flash_attn_varlen_backward(
|
||||
dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1],
|
||||
cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
|
||||
ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
)
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None, None
|
||||
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
dkv = dkv[..., :dout.shape[-1]]
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dq, dkv, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
|
||||
class FlashAttnFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
|
||||
softmax_scale, causal, return_softmax, deterministic):
|
||||
def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
if dropout_p > 0:
|
||||
rng_state0 = torch.cuda.get_rng_state()
|
||||
generator1 = torch.Generator(device='cuda')
|
||||
rng_state1 = generator1.get_state()
|
||||
else:
|
||||
rng_state0, generator1, rng_state1 = None, None, None
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = qkv.shape[-1] ** (-0.5)
|
||||
out = torch.empty_like(qkv[:, 0])
|
||||
_, softmax_lse0, S_dmask0 = _flash_attn_forward(
|
||||
qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[:batch_size0 + 1],
|
||||
cu_seqlens[:batch_size0 + 1], max_seqlen0, max_seqlen0, dropout_p, softmax_scale,
|
||||
causal=causal, return_softmax=return_softmax
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
|
||||
q, k, v, dropout_p, softmax_scale, causal=causal,
|
||||
return_softmax=return_softmax and dropout_p > 0
|
||||
)
|
||||
s = torch.cuda.Stream()
|
||||
with torch.cuda.stream(s):
|
||||
_, softmax_lse1, S_dmask1 = _flash_attn_forward(
|
||||
qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[batch_size0:],
|
||||
cu_seqlens[batch_size0:], max_seqlen1, max_seqlen1, dropout_p, softmax_scale,
|
||||
causal=causal, return_softmax=return_softmax, generator=generator1
|
||||
)
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
ctx.save_for_backward(qkv, out, softmax_lse0, softmax_lse1, cu_seqlens,
|
||||
rng_state0, rng_state1)
|
||||
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.max_seqlen0 = max_seqlen0
|
||||
ctx.max_seqlen1 = max_seqlen1
|
||||
ctx.batch_size0 = batch_size0
|
||||
ctx.softmax_scale = softmax_scale
|
||||
ctx.causal = causal
|
||||
ctx.deterministic = deterministic
|
||||
if not return_softmax:
|
||||
return out
|
||||
else:
|
||||
max_seqlen_q = max(softmax_lse0.shape[2], softmax_lse1.shape[2])
|
||||
max_seqlen_k = max(S_dmask0.shape[3], S_dmask1.shape[3])
|
||||
softmax_lse = torch.cat([F.pad(softmax_lse0, (0, max_seqlen_q - softmax_lse0.shape[2])),
|
||||
F.pad(softmax_lse1, (0, max_seqlen_q - softmax_lse1.shape[2]))],
|
||||
dim=0)
|
||||
return out, softmax_lse, S_dmask0, S_dmask1
|
||||
return out if not return_softmax else (out, softmax_lse, S_dmask)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
qkv, out, softmax_lse0, softmax_lse1, cu_seqlens, rng_state0, rng_state1 = ctx.saved_tensors
|
||||
batch_size0 = ctx.batch_size0
|
||||
if rng_state0 is not None:
|
||||
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state0)
|
||||
if rng_state1 is not None:
|
||||
generator1 = torch.Generator(device='cuda')
|
||||
generator1.set_state(rng_state1)
|
||||
else:
|
||||
generator1 = None
|
||||
dqkv = torch.empty_like(qkv)
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
||||
_flash_attn_backward(
|
||||
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0,
|
||||
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1],
|
||||
cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p,
|
||||
ctx.softmax_scale, ctx.causal, num_splits=1 if ctx.deterministic else 0,
|
||||
dout, q, k, v, out, softmax_lse,
|
||||
dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
)
|
||||
s = torch.cuda.Stream()
|
||||
with torch.cuda.stream(s):
|
||||
_flash_attn_backward(
|
||||
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1,
|
||||
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[batch_size0:],
|
||||
cu_seqlens[batch_size0:], ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p,
|
||||
ctx.softmax_scale, ctx.causal, generator=generator1,
|
||||
num_splits=1 if ctx.deterministic else 0,
|
||||
)
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
if rng_state0 is not None:
|
||||
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
dk = dk[..., :dout.shape[-1]]
|
||||
dv = dv[..., :dout.shape[-1]]
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dqkv, None, None, None, None, None, None, None, None, None
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
|
||||
causal=False, return_attn_probs=False, deterministic=False):
|
||||
class FlashAttnVarlenFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, causal, return_softmax):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
|
||||
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, out_padded, softmax_lse,
|
||||
cu_seqlens_q, cu_seqlens_k, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.max_seqlen_q = max_seqlen_q
|
||||
ctx.max_seqlen_k = max_seqlen_k
|
||||
ctx.softmax_scale = softmax_scale
|
||||
ctx.causal = causal
|
||||
return out if not return_softmax else (out, softmax_lse, S_dmask)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
||||
_flash_attn_varlen_backward(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
)
|
||||
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
dk = dk[..., :dout.shape[-1]]
|
||||
dv = dv[..., :dout.shape[-1]]
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
|
||||
return_attn_probs=False):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
If Q, K, V are already stacked into 1 tensor, this function will be faster than
|
||||
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
|
||||
of the gradients of Q, K, V.
|
||||
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
|
||||
than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
|
||||
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
||||
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
||||
|
||||
Arguments:
|
||||
qkv: (batch_size, seqlen, 3, nheads, headdim)
|
||||
dropout_p: float. Dropout probability.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
Return:
|
||||
out: (batch_size, seqlen, nheads, headdim).
|
||||
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
||||
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
||||
normalization factor).
|
||||
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
|
||||
The output of softmax (possibly with different scaling). It also encodes the dropout
|
||||
pattern (negative means that location was dropped, nonnegative means it was kept).
|
||||
"""
|
||||
return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs)
|
||||
|
||||
|
||||
def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False,
|
||||
return_attn_probs=False):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
If K, V are already stacked into 1 tensor, this function will be faster than
|
||||
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
|
||||
of the gradients of K, V.
|
||||
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
|
||||
than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
|
||||
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
||||
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
||||
|
||||
Arguments:
|
||||
q: (batch_size, seqlen, nheads, headdim)
|
||||
kv: (batch_size, seqlen, 2, nheads_k, headdim)
|
||||
dropout_p: float. Dropout probability.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
Return:
|
||||
out: (batch_size, seqlen, nheads, headdim).
|
||||
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
||||
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
||||
normalization factor).
|
||||
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
|
||||
The output of softmax (possibly with different scaling). It also encodes the dropout
|
||||
pattern (negative means that location was dropped, nonnegative means it was kept).
|
||||
"""
|
||||
return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs)
|
||||
|
||||
|
||||
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
|
||||
return_attn_probs=False):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
|
||||
than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
|
||||
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
||||
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
||||
|
||||
Arguments:
|
||||
q: (batch_size, seqlen, nheads, headdim)
|
||||
k: (batch_size, seqlen, nheads_k, headdim)
|
||||
v: (batch_size, seqlen, nheads_k, headdim)
|
||||
dropout_p: float. Dropout probability.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
Return:
|
||||
out: (batch_size, seqlen, nheads, headdim).
|
||||
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
||||
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
||||
normalization factor).
|
||||
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
|
||||
The output of softmax (possibly with different scaling). It also encodes the dropout
|
||||
pattern (negative means that location was dropped, nonnegative means it was kept).
|
||||
"""
|
||||
return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)
|
||||
|
||||
|
||||
def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None,
|
||||
causal=False, return_attn_probs=False):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
If Q, K, V are already stacked into 1 tensor, this function will be faster than
|
||||
calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
|
||||
of the gradients of Q, K, V.
|
||||
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
||||
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
||||
|
||||
Arguments:
|
||||
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
|
||||
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
||||
@@ -243,7 +436,6 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
deterministic: bool. Whether or not to ensure deterministic execution.
|
||||
Return:
|
||||
out: (total, nheads, headdim).
|
||||
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
||||
@@ -253,17 +445,26 @@ def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, s
|
||||
The output of softmax (possibly with different scaling). It also encodes the dropout
|
||||
pattern (negative means that location was dropped, nonnegative means it was kept).
|
||||
"""
|
||||
return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
|
||||
causal, return_attn_probs, deterministic)
|
||||
return FlashAttnVarlenQKVPackedFunc.apply(
|
||||
qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
|
||||
)
|
||||
|
||||
|
||||
def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale=None, causal=False,
|
||||
return_attn_probs=False, deterministic=False):
|
||||
def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p=0.0, softmax_scale=None, causal=False,
|
||||
return_attn_probs=False):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
If K, V are already stacked into 1 tensor, this function will be faster than
|
||||
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
|
||||
of the gradients of K, V.
|
||||
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
|
||||
than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
|
||||
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
||||
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
||||
|
||||
Arguments:
|
||||
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
|
||||
kv: (total_k, 2, nheads, headdim), where total_k = total number of key tokens in the batch.
|
||||
kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
|
||||
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
||||
of the sequences in the batch, used to index into q.
|
||||
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
||||
@@ -277,9 +478,8 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
deterministic: bool. Whether or not to ensure deterministic execution.
|
||||
Return:
|
||||
out: (total_q, nheads, headdim).
|
||||
out: (total, nheads, headdim).
|
||||
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
||||
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
||||
normalization factor).
|
||||
@@ -287,19 +487,25 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
|
||||
The output of softmax (possibly with different scaling). It also encodes the dropout
|
||||
pattern (negative means that location was dropped, nonnegative means it was kept).
|
||||
"""
|
||||
return FlashAttnKVPackedFunc.apply(q, kv, cu_seqlens_q, cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
|
||||
return_attn_probs, deterministic)
|
||||
return FlashAttnVarlenKVPackedFunc.apply(
|
||||
q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal, return_attn_probs
|
||||
)
|
||||
|
||||
|
||||
def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale=None, causal=False, return_attn_probs=False,
|
||||
deterministic=False):
|
||||
def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p=0.0, softmax_scale=None, causal=False,
|
||||
return_attn_probs=False):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
|
||||
than Q. Note that the number of heads in K, V must be divisible by the number of heads in Q.
|
||||
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
|
||||
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
|
||||
|
||||
Arguments:
|
||||
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
|
||||
k: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
|
||||
v: (total_k, nheads, headdim), where total_k = total number of key tokens in the batch.
|
||||
k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
|
||||
v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
|
||||
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
||||
of the sequences in the batch, used to index into q.
|
||||
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
||||
@@ -313,45 +519,6 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
deterministic: bool. Whether or not to ensure deterministic execution.
|
||||
Return:
|
||||
out: (total_q, nheads, headdim).
|
||||
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
||||
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
||||
normalization factor).
|
||||
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
|
||||
The output of softmax (possibly with different scaling). It also encodes the dropout
|
||||
pattern (negative means that location was dropped, nonnegative means it was kept).
|
||||
"""
|
||||
return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal, return_attn_probs, deterministic)
|
||||
|
||||
|
||||
def flash_attn_unpadded_qkvpacked_split_func(
|
||||
qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, softmax_scale=None,
|
||||
causal=False, return_attn_probs=False, deterministic=False):
|
||||
"""
|
||||
Split attention into 2 kernels running on 2 separate streams for performance reason:
|
||||
e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to
|
||||
have one kernel dealing with seqlen <= 128 and one kernel for seqlen > 128.
|
||||
|
||||
dropout_p should be set to 0.0 during evaluation.
|
||||
|
||||
Arguments:
|
||||
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
|
||||
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
||||
of the sequences in the batch, used to index into qkv.
|
||||
max_seqlen0: int. Maximum sequence length in 1st part of the batch.
|
||||
max_seqlen1: int. Maximum sequence length in 2nd part of the batch.
|
||||
batch_size0: int. Number of sequences in the 1st part of the batch.
|
||||
dropout_p: float. Dropout probability.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(headdim).
|
||||
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
|
||||
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
|
||||
testing only. The returned probabilities are not guaranteed to be correct
|
||||
(they might not have the right scaling).
|
||||
deterministic: bool. Whether or not to ensure deterministic execution.
|
||||
Return:
|
||||
out: (total, nheads, headdim).
|
||||
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
|
||||
@@ -361,15 +528,7 @@ def flash_attn_unpadded_qkvpacked_split_func(
|
||||
The output of softmax (possibly with different scaling). It also encodes the dropout
|
||||
pattern (negative means that location was dropped, nonnegative means it was kept).
|
||||
"""
|
||||
return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
|
||||
dropout_p, softmax_scale, causal, return_attn_probs,
|
||||
deterministic)
|
||||
|
||||
|
||||
def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
|
||||
return_attn_probs=False):
|
||||
"""For backward-compatibility only, will remove soon.
|
||||
dropout_p should be set to 0.0 during evaluation
|
||||
"""
|
||||
return flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_s, dropout_p, softmax_scale,
|
||||
causal, return_attn_probs)
|
||||
return FlashAttnVarlenFunc.apply(
|
||||
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal, return_attn_probs
|
||||
)
|
||||
|
||||
+19
-17
@@ -8,7 +8,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from torchvision.ops import StochasticDepth
|
||||
# from torchvision.ops import StochasticDepth
|
||||
|
||||
from flash_attn.modules.mha import MHA
|
||||
from flash_attn.modules.mlp import Mlp
|
||||
@@ -70,12 +70,12 @@ class Block(nn.Module):
|
||||
mlp_cls = partial(Mlp, hidden_features=4 * dim)
|
||||
self.mixer = mixer_cls(dim)
|
||||
self.dropout1 = dropout_cls(resid_dropout1)
|
||||
self.drop_path1 = StochasticDepth(drop_path1, mode='row')
|
||||
# self.drop_path1 = StochasticDepth(drop_path1, mode='row')
|
||||
self.norm1 = norm_cls(dim)
|
||||
self.mlp = mlp_cls(dim)
|
||||
if not isinstance(self.mlp, nn.Identity):
|
||||
self.dropout2 = dropout_cls(resid_dropout2)
|
||||
self.drop_path2 = StochasticDepth(drop_path2, mode='row')
|
||||
# self.drop_path2 = StochasticDepth(drop_path2, mode='row')
|
||||
self.norm2 = norm_cls(dim)
|
||||
|
||||
if self.fused_dropout_add_ln:
|
||||
@@ -129,13 +129,14 @@ class Block(nn.Module):
|
||||
if self.residual_in_fp32:
|
||||
residual = residual.to(torch.float32)
|
||||
else:
|
||||
if self.drop_path1.p == 0 or not self.training:
|
||||
rowscale1 = None
|
||||
else:
|
||||
rowscale1 = self.drop_path1(torch.ones(
|
||||
hidden_states.shape[:-1], device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
)
|
||||
rowscale1 = None
|
||||
# if self.drop_path1.p == 0 or not self.training:
|
||||
# rowscale1 = None
|
||||
# else:
|
||||
# rowscale1 = self.drop_path1(torch.ones(
|
||||
# hidden_states.shape[:-1], device=hidden_states.device,
|
||||
# dtype=hidden_states.dtype)
|
||||
# )
|
||||
hidden_states, residual = fused_add_norm_fn(
|
||||
hidden_states, residual, self.norm1.weight, self.norm1.bias,
|
||||
self.dropout1.p if self.training else 0.0, self.norm1.eps,
|
||||
@@ -156,13 +157,14 @@ class Block(nn.Module):
|
||||
if self.residual_in_fp32:
|
||||
residual = residual.to(torch.float32)
|
||||
else:
|
||||
if self.drop_path2.p == 0 or not self.training:
|
||||
rowscale2 = None
|
||||
else:
|
||||
rowscale2 = self.drop_path2(torch.ones(
|
||||
hidden_states.shape[:-1], device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
)
|
||||
# if self.drop_path2.p == 0 or not self.training:
|
||||
# rowscale2 = None
|
||||
# else:
|
||||
# rowscale2 = self.drop_path2(torch.ones(
|
||||
# hidden_states.shape[:-1], device=hidden_states.device,
|
||||
# dtype=hidden_states.dtype)
|
||||
# )
|
||||
rowscale2 = None
|
||||
hidden_states, residual = fused_add_norm_fn(
|
||||
hidden_states, residual, self.norm2.weight, self.norm2.bias,
|
||||
self.dropout2.p if self.training else 0.0, self.norm2.eps,
|
||||
|
||||
+17
-56
@@ -10,14 +10,10 @@ import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
|
||||
except ImportError:
|
||||
flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func = None, None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.flash_attn_triton import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
|
||||
from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func
|
||||
from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
|
||||
except ImportError:
|
||||
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
|
||||
flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
|
||||
|
||||
try:
|
||||
@@ -46,17 +42,13 @@ class FlashSelfAttention(nn.Module):
|
||||
attention_dropout: The dropout rate to apply to the attention
|
||||
(default: 0.0)
|
||||
"""
|
||||
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
|
||||
triton=False):
|
||||
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
|
||||
super().__init__()
|
||||
if attention_dropout != 0.0 or not triton:
|
||||
assert flash_attn_unpadded_qkvpacked_func is not None, 'FlashAttention is not installed'
|
||||
if attention_dropout == 0.0 and triton:
|
||||
assert flash_attn_qkvpacked_func is not None, 'FlashAttention Triton is not installed'
|
||||
assert flash_attn_varlen_qkvpacked_func is not None, 'FlashAttention is not installed'
|
||||
assert flash_attn_qkvpacked_func is not None, 'FlashAttention is not installed'
|
||||
self.causal = causal
|
||||
self.softmax_scale = softmax_scale
|
||||
self.drop = nn.Dropout(attention_dropout)
|
||||
self.triton = triton
|
||||
|
||||
def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
|
||||
"""Implements the multihead softmax attention.
|
||||
@@ -83,26 +75,13 @@ class FlashSelfAttention(nn.Module):
|
||||
assert cu_seqlens.dtype == torch.int32
|
||||
assert max_seqlen is not None
|
||||
assert isinstance(max_seqlen, int)
|
||||
return flash_attn_unpadded_qkvpacked_func(
|
||||
return flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
|
||||
softmax_scale=self.softmax_scale, causal=causal
|
||||
)
|
||||
else:
|
||||
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
||||
# Triton version doesn't support dropout
|
||||
if self.triton and (self.drop.p == 0 or not self.training):
|
||||
output = flash_attn_qkvpacked_func(qkv, None, causal, self.softmax_scale)
|
||||
else:
|
||||
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
|
||||
max_seqlen = seqlen
|
||||
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
|
||||
device=qkv.device)
|
||||
output = flash_attn_unpadded_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
|
||||
softmax_scale=self.softmax_scale, causal=causal
|
||||
)
|
||||
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
|
||||
return output
|
||||
return flash_attn_qkvpacked_func(qkv, self.drop.p if self.training else 0.0,
|
||||
softmax_scale=self.softmax_scale, causal=causal)
|
||||
|
||||
|
||||
class FlashCrossAttention(nn.Module):
|
||||
@@ -115,17 +94,13 @@ class FlashCrossAttention(nn.Module):
|
||||
attention_dropout: The dropout rate to apply to the attention
|
||||
(default: 0.0)
|
||||
"""
|
||||
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
|
||||
triton=False):
|
||||
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
|
||||
super().__init__()
|
||||
if attention_dropout != 0.0 or not triton:
|
||||
assert flash_attn_unpadded_kvpacked_func is not None, 'FlashAttention is not installed'
|
||||
if attention_dropout == 0.0 and triton:
|
||||
assert flash_attn_kvpacked_func is not None, 'FlashAttention Triton is not installed'
|
||||
assert flash_attn_varlen_kvpacked_func is not None, 'FlashAttention is not installed'
|
||||
assert flash_attn_kvpacked_func is not None, 'FlashAttention is not installed'
|
||||
self.causal = causal
|
||||
self.softmax_scale = softmax_scale
|
||||
self.drop = nn.Dropout(attention_dropout)
|
||||
self.triton = triton
|
||||
|
||||
def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
|
||||
cu_seqlens_k=None, max_seqlen_k=None):
|
||||
@@ -133,7 +108,7 @@ class FlashCrossAttention(nn.Module):
|
||||
Arguments
|
||||
---------
|
||||
q: The tensor containing the query. (B, Sq, H, D)
|
||||
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
|
||||
kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
|
||||
causal: if passed, will override self.causal
|
||||
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
||||
of the sequences in the batch, used to index into q.
|
||||
@@ -154,7 +129,7 @@ class FlashCrossAttention(nn.Module):
|
||||
assert cu_seqlens_k.dtype == torch.int32
|
||||
assert max_seqlen_k is not None
|
||||
assert isinstance(max_seqlen, int)
|
||||
return flash_attn_unpadded_kvpacked_func(
|
||||
return flash_attn_varlen_kvpacked_func(
|
||||
q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
|
||||
self.drop.p if self.training else 0.0,
|
||||
softmax_scale=self.softmax_scale, causal=causal
|
||||
@@ -162,23 +137,9 @@ class FlashCrossAttention(nn.Module):
|
||||
else:
|
||||
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
||||
seqlen_k = kv.shape[1]
|
||||
assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
|
||||
if self.triton and (self.drop.p == 0.0 or not self.training): # Triton version doesn't support dropout
|
||||
output = flash_attn_kvpacked_func(q, kv, None, causal, self.softmax_scale)
|
||||
else:
|
||||
q = rearrange(q, 'b s ... -> (b s) ...')
|
||||
kv = rearrange(kv, 'b s ... -> (b s) ...')
|
||||
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
|
||||
dtype=torch.int32, device=q.device)
|
||||
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
|
||||
dtype=torch.int32, device=kv.device)
|
||||
output = flash_attn_unpadded_kvpacked_func(
|
||||
q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
|
||||
self.drop.p if self.training else 0.0,
|
||||
softmax_scale=self.softmax_scale, causal=causal
|
||||
)
|
||||
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
|
||||
return output
|
||||
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
|
||||
return flash_attn_kvpacked_func(q, kv, self.drop.p if self.training else 0.0,
|
||||
causal=causal, softmax_scale=self.softmax_scale)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
|
||||
@@ -111,28 +111,52 @@ cc_flag = []
|
||||
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
if bare_metal_version < Version("11.0"):
|
||||
raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_75,code=sm_75")
|
||||
# cc_flag.append("-gencode")
|
||||
# cc_flag.append("arch=compute_75,code=sm_75")
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_80,code=sm_80")
|
||||
if bare_metal_version >= Version("11.8"):
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_90,code=sm_90")
|
||||
|
||||
subprocess.run(["git", "submodule", "update", "--init", "csrc/flash_attn/cutlass"])
|
||||
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
|
||||
ext_modules.append(
|
||||
CUDAExtension(
|
||||
name="flash_attn_cuda",
|
||||
name="flash_attn_2_cuda",
|
||||
sources=[
|
||||
"csrc/flash_attn/fmha_api.cpp",
|
||||
"csrc/flash_attn/src/fmha_fwd_hdim32.cu",
|
||||
"csrc/flash_attn/src/fmha_fwd_hdim64.cu",
|
||||
"csrc/flash_attn/src/fmha_fwd_hdim128.cu",
|
||||
"csrc/flash_attn/src/fmha_bwd_hdim32.cu",
|
||||
"csrc/flash_attn/src/fmha_bwd_hdim64.cu",
|
||||
"csrc/flash_attn/src/fmha_bwd_hdim128.cu",
|
||||
"csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu",
|
||||
"csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
|
||||
"csrc/flash_attn/flash_api.cpp",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
|
||||
"csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"cxx": ["-O3", "-std=c++17"] + generator_flag,
|
||||
@@ -157,11 +181,12 @@ ext_modules.append(
|
||||
include_dirs=[
|
||||
Path(this_dir) / 'csrc' / 'flash_attn',
|
||||
Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
|
||||
Path(this_dir) / 'csrc' / 'flash_attn' / 'cutlass' / 'include',
|
||||
Path(this_dir) / 'csrc' / 'cutlass' / 'include',
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_package_version():
|
||||
with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
|
||||
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
|
||||
@@ -172,6 +197,7 @@ def get_package_version():
|
||||
else:
|
||||
return str(public_version)
|
||||
|
||||
|
||||
setup(
|
||||
name="flash_attn",
|
||||
version=get_package_version(),
|
||||
@@ -179,11 +205,9 @@ setup(
|
||||
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
|
||||
),
|
||||
author="Tri Dao",
|
||||
author_email="trid@stanford.edu",
|
||||
author_email="trid@cs.stanford.edu",
|
||||
description="Flash Attention: Fast and Memory-Efficient Exact Attention",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/HazyResearch/flash-attention",
|
||||
url="https://github.com/Dao-AILab/flash-attention",
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: BSD License",
|
||||
|
||||
+464
-643
File diff suppressed because it is too large
Load Diff
+2
-2
@@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr
|
||||
RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
|
||||
|
||||
# Install FlashAttention
|
||||
RUN pip install flash-attn==1.0.9
|
||||
RUN pip install flash-attn==2.0.0.post1
|
||||
|
||||
# Install CUDA extensions for cross-entropy, fused dense, layer norm
|
||||
RUN git clone https://github.com/HazyResearch/flash-attention \
|
||||
&& cd flash-attention && git checkout v1.0.9 \
|
||||
&& cd flash-attention && git checkout v2.0.0.post1 \
|
||||
&& cd csrc/fused_softmax && pip install . && cd ../../ \
|
||||
&& cd csrc/rotary && pip install . && cd ../../ \
|
||||
&& cd csrc/xentropy && pip install . && cd ../../ \
|
||||
|
||||
Reference in New Issue
Block a user