mirror of
https://github.com/BillyOutlast/flash-attention-prebuild-wheels-rocm.git
synced 2026-07-01 01:37:53 -04:00
first attempt
This commit is contained in:
@@ -0,0 +1,174 @@
|
||||
# #########################################################
|
||||
# Build wheels for ROCm using Docker
|
||||
# #########################################################
|
||||
|
||||
name: "[ROCm] Build wheels and upload to GitHub Releases"
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
flash-attn-version:
|
||||
description: "Flash-Attention version"
|
||||
required: true
|
||||
type: string
|
||||
python-version:
|
||||
description: "Python version"
|
||||
required: true
|
||||
type: string
|
||||
torch-version:
|
||||
description: "PyTorch version"
|
||||
required: true
|
||||
type: string
|
||||
rocm-version:
|
||||
description: "ROCm version"
|
||||
required: true
|
||||
type: string
|
||||
runner:
|
||||
description: "Runner type"
|
||||
required: false
|
||||
type: string
|
||||
default: "ubuntu-22.04"
|
||||
is-upload:
|
||||
description: "Whether to upload the release asset"
|
||||
required: false
|
||||
type: boolean
|
||||
default: true
|
||||
build-triton-backend:
|
||||
description: "Whether to also build Triton backend"
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
|
||||
jobs:
|
||||
build_wheels:
|
||||
name: Build wheels and Upload (ROCm)
|
||||
runs-on: ${{ inputs.runner }}
|
||||
container:
|
||||
image: rocm/pytorch:rocm${{ inputs.rocm-version }}_ubuntu22.04_py3.10_pytorch_2.1.2
|
||||
options: --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host
|
||||
env:
|
||||
DEBIAN_FRONTEND: noninteractive
|
||||
TERM: xterm-256color
|
||||
BUILD_TRITON_BACKEND: ${{ inputs.build-triton-backend }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
apt-get update
|
||||
apt-get install -y git ninja-build time patchelf wget curl
|
||||
|
||||
- name: Setup Python
|
||||
run: |
|
||||
# Install deadsnakes PPA for multiple Python versions
|
||||
apt-get update
|
||||
apt-get install -y software-properties-common
|
||||
add-apt-repository -y ppa:deadsnakes/ppa
|
||||
apt-get update
|
||||
|
||||
# Install requested Python version
|
||||
PYTHON_PKG="python${{ inputs.python-version }}"
|
||||
apt-get install -y ${PYTHON_PKG} ${PYTHON_PKG}-dev ${PYTHON_PKG}-venv
|
||||
|
||||
# Set as default python
|
||||
update-alternatives --install /usr/bin/python python /usr/bin/${PYTHON_PKG} 1
|
||||
update-alternatives --install /usr/bin/python3 python3 /usr/bin/${PYTHON_PKG} 1
|
||||
|
||||
# Install pip
|
||||
curl -sS https://bootstrap.pypa.io/get-pip.py | python
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
|
||||
- name: Verify ROCm installation
|
||||
run: |
|
||||
echo "ROCm version:"
|
||||
if command -v rocm-smi &> /dev/null; then
|
||||
rocm-smi --showproductname || echo "ROCm SMI not available in container"
|
||||
fi
|
||||
if command -v rocminfo &> /dev/null; then
|
||||
rocminfo | grep -E "Name:|Marketing Name:" | head -5 || echo "ROCminfo not fully available"
|
||||
fi
|
||||
|
||||
- name: Build wheels
|
||||
id: build_wheels
|
||||
run: |
|
||||
chmod +x build_rocm.sh
|
||||
./build_rocm.sh ${{ inputs.flash-attn-version }} ${{ inputs.python-version }} ${{ inputs.torch-version }} ${{ inputs.rocm-version }}
|
||||
wheel_path=$(ls flash-attention/dist/*.whl | head -n 1)
|
||||
echo "WHEEL_PATH=$wheel_path" >> $GITHUB_OUTPUT
|
||||
|
||||
# Check for Triton backend wheel if built
|
||||
if [ -d "flash-attention/dist_triton" ]; then
|
||||
wheel_path_triton=$(ls flash-attention/dist_triton/*.whl | head -n 1)
|
||||
echo "WHEEL_PATH_TRITON=$wheel_path_triton" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Test wheel installation
|
||||
run: |
|
||||
pip uninstall -y flash-attn > /dev/null 2>&1 || true
|
||||
pip install --no-cache-dir ${{ steps.build_wheels.outputs.WHEEL_PATH }}
|
||||
python -c "import flash_attn; print('Flash Attention version:', flash_attn.__version__)"
|
||||
python -c "import flash_attn_2_cuda as flash_attn_cuda; print('Flash Attention CUDA module loaded successfully')" || echo "CUDA module test skipped"
|
||||
|
||||
- name: Upload Release Asset (CK Backend)
|
||||
if: ${{ inputs.is-upload }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
tag_name=${{ github.ref_name }}
|
||||
wheel_path="${{ steps.build_wheels.outputs.WHEEL_PATH }}"
|
||||
|
||||
# Check if the file exists
|
||||
if [ ! -f "$wheel_path" ]; then
|
||||
echo "Error: Wheel file not found at $wheel_path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Upload the release asset using GitHub CLI
|
||||
gh release upload "$tag_name" "$wheel_path" --clobber
|
||||
|
||||
- name: Upload Release Asset (Triton Backend)
|
||||
if: ${{ inputs.is-upload && inputs.build-triton-backend && steps.build_wheels.outputs.WHEEL_PATH_TRITON }}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
tag_name=${{ github.ref_name }}
|
||||
wheel_path="${{ steps.build_wheels.outputs.WHEEL_PATH_TRITON }}"
|
||||
|
||||
# Check if the file exists
|
||||
if [ ! -f "$wheel_path" ]; then
|
||||
echo "Warning: Triton wheel file not found at $wheel_path"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Upload the release asset using GitHub CLI
|
||||
gh release upload "$tag_name" "$wheel_path" --clobber
|
||||
|
||||
- name: Apply auditwheel repair
|
||||
continue-on-error: true
|
||||
id: auditwheel_repair
|
||||
run: |
|
||||
pip install auditwheel
|
||||
auditwheel show ${{ steps.build_wheels.outputs.WHEEL_PATH }} || echo "Auditwheel show failed"
|
||||
auditwheel repair \
|
||||
--exclude libc10* --exclude libtorch* --exclude libamdhip* --exclude librocm* --exclude libhsa* \
|
||||
--exclude libMIOpen* --exclude librccl* --exclude librocblas* --exclude libhipblas* \
|
||||
${{ steps.build_wheels.outputs.WHEEL_PATH }} -w flash-attention/dist_manylinux || echo "Auditwheel repair failed"
|
||||
|
||||
wheel_path_manylinux=$(ls flash-attention/dist_manylinux/*manylinux*.whl 2>/dev/null | head -n 1 || echo "")
|
||||
if [ -n "$wheel_path_manylinux" ]; then
|
||||
echo "WHEEL_PATH_MANYLINUX=$wheel_path_manylinux" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Upload manylinux wheel
|
||||
if: ${{ inputs.is-upload && steps.auditwheel_repair.outputs.WHEEL_PATH_MANYLINUX }}
|
||||
continue-on-error: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
tag_name=${{ github.ref_name }}
|
||||
wheel_path="${{ steps.auditwheel_repair.outputs.WHEEL_PATH_MANYLINUX }}"
|
||||
|
||||
if [ -f "$wheel_path" ]; then
|
||||
gh release upload "$tag_name" "$wheel_path" --clobber
|
||||
fi
|
||||
@@ -123,6 +123,29 @@ jobs:
|
||||
container-image: "quay.io/pypa/manylinux_2_28_aarch64"
|
||||
secrets: inherit
|
||||
|
||||
# #########################################################
|
||||
# ROCm
|
||||
# #########################################################
|
||||
|
||||
build_wheels_rocm:
|
||||
name: Build ROCm
|
||||
needs: [create_releases, create_matrix]
|
||||
if: ${{ fromjson(needs.create_matrix.outputs.matrix).rocm }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
flash-attn-version: ${{ fromjson(needs.create_matrix.outputs.matrix).rocm.flash-attn-version }}
|
||||
python-version: ${{ fromjson(needs.create_matrix.outputs.matrix).rocm.python-version }}
|
||||
torch-version: ${{ fromjson(needs.create_matrix.outputs.matrix).rocm.torch-version }}
|
||||
rocm-version: ${{ fromjson(needs.create_matrix.outputs.matrix).rocm.rocm-version }}
|
||||
uses: ./.github/workflows/_build_rocm.yml
|
||||
with:
|
||||
flash-attn-version: ${{ matrix.flash-attn-version }}
|
||||
python-version: ${{ matrix.python-version }}
|
||||
torch-version: ${{ matrix.torch-version }}
|
||||
rocm-version: ${{ matrix.rocm-version }}
|
||||
secrets: inherit
|
||||
|
||||
# #########################################################
|
||||
# Windows
|
||||
# #########################################################
|
||||
@@ -199,6 +222,7 @@ jobs:
|
||||
- build_wheels_linux_arm64
|
||||
- build_wheels_linux_self_hosted
|
||||
- build_wheels_linux_arm64_self_hosted
|
||||
- build_wheels_rocm
|
||||
- build_wheels_windows
|
||||
- build_wheels_windows_code_build
|
||||
- build_wheels_windows_self_hosted
|
||||
@@ -228,6 +252,7 @@ jobs:
|
||||
- build_wheels_linux
|
||||
- build_wheels_linux_arm64
|
||||
- build_wheels_linux_self_hosted
|
||||
- build_wheels_rocm
|
||||
- build_wheels_linux_arm64_self_hosted
|
||||
- build_wheels_windows
|
||||
- build_wheels_windows_code_build
|
||||
|
||||
+134
@@ -0,0 +1,134 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
# Parameters with defaults
|
||||
FLASH_ATTN_VERSION=$1
|
||||
PYTHON_VERSION=$2
|
||||
TORCH_VERSION=$3
|
||||
ROCM_VERSION=$4
|
||||
|
||||
echo "Building Flash Attention for ROCm with parameters:"
|
||||
echo " Flash-Attention: $FLASH_ATTN_VERSION"
|
||||
echo " Python: $PYTHON_VERSION"
|
||||
echo " PyTorch: $TORCH_VERSION"
|
||||
echo " ROCm: $ROCM_VERSION"
|
||||
|
||||
# Set ROCm and PyTorch versions
|
||||
MATRIX_ROCM_VERSION=$(echo $ROCM_VERSION | awk -F \. {'print $1 "." $2'})
|
||||
MATRIX_TORCH_VERSION=$(echo $TORCH_VERSION | awk -F \. {'print $1 "." $2'})
|
||||
|
||||
echo "Derived versions:"
|
||||
echo " ROCm Matrix: $MATRIX_ROCM_VERSION"
|
||||
echo " Torch Matrix: $MATRIX_TORCH_VERSION"
|
||||
|
||||
# Install PyTorch for ROCm
|
||||
echo "Installing PyTorch $TORCH_VERSION for ROCm $MATRIX_ROCM_VERSION..."
|
||||
if [[ $TORCH_VERSION == *"dev"* ]]; then
|
||||
pip install --force-reinstall --no-cache-dir --pre torch==$TORCH_VERSION --index-url https://download.pytorch.org/whl/nightly/rocm${MATRIX_ROCM_VERSION}
|
||||
else
|
||||
pip install --force-reinstall --no-cache-dir torch==$TORCH_VERSION --index-url https://download.pytorch.org/whl/rocm${MATRIX_ROCM_VERSION}
|
||||
fi
|
||||
|
||||
# Install additional dependencies
|
||||
echo "Installing build dependencies..."
|
||||
pip install ninja packaging
|
||||
|
||||
# Verify installation
|
||||
echo "Verifying installations..."
|
||||
python -V
|
||||
python -c "import torch; print('PyTorch:', torch.__version__)"
|
||||
python -c "import torch; print('ROCm:', torch.version.hip if hasattr(torch.version, 'hip') else 'Not found')"
|
||||
|
||||
# Display ROCm information
|
||||
if command -v rocm-smi &> /dev/null; then
|
||||
echo "ROCm SMI:"
|
||||
rocm-smi --showproductname || true
|
||||
fi
|
||||
|
||||
if command -v rocminfo &> /dev/null; then
|
||||
echo "Available ROCm devices:"
|
||||
rocminfo | grep -E "Name:|Marketing Name:" || true
|
||||
fi
|
||||
|
||||
# Checkout flash-attn
|
||||
echo "Checking out flash-attention v$FLASH_ATTN_VERSION..."
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git -b "v$FLASH_ATTN_VERSION"
|
||||
|
||||
# Determine MAX_JOBS based on system resources
|
||||
NUM_THREADS=$(nproc)
|
||||
RAM_GB=$(free -g | awk '/^Mem:/{print $2}')
|
||||
echo "System resources:"
|
||||
echo " CPU threads: $NUM_THREADS"
|
||||
echo " RAM: ${RAM_GB}GB"
|
||||
|
||||
# Calculate MAX_JOBS based on available resources
|
||||
# ROCm builds are memory intensive, so we use conservative estimates
|
||||
if [[ -z "${MAX_JOBS:-}" ]]; then
|
||||
MAX_PRODUCT_CPU=$NUM_THREADS
|
||||
MAX_PRODUCT_RAM=$(awk -v ram="$RAM_GB" 'BEGIN {print int(ram / 4)}')
|
||||
MAX_JOBS=$((MAX_PRODUCT_CPU < MAX_PRODUCT_RAM ? MAX_PRODUCT_CPU : MAX_PRODUCT_RAM))
|
||||
|
||||
# Ensure minimum values
|
||||
MAX_JOBS=$((MAX_JOBS < 1 ? 1 : MAX_JOBS))
|
||||
|
||||
# Cap at 8 to avoid overwhelming the system
|
||||
MAX_JOBS=$((MAX_JOBS > 8 ? 8 : MAX_JOBS))
|
||||
fi
|
||||
|
||||
echo "Build parallelism settings:"
|
||||
echo " MAX_JOBS: $MAX_JOBS"
|
||||
|
||||
# Detect GPU architecture if available
|
||||
if command -v rocminfo &> /dev/null; then
|
||||
GPU_ARCH=$(rocminfo | grep -o -m 1 'gfx[0-9a-z]*' | head -n 1)
|
||||
if [ -n "$GPU_ARCH" ]; then
|
||||
echo "Detected GPU architecture: $GPU_ARCH"
|
||||
export PYTORCH_ROCM_ARCH=$GPU_ARCH
|
||||
fi
|
||||
fi
|
||||
|
||||
# Set default architectures if not detected
|
||||
# Common ROCm architectures: gfx90a (MI200), gfx942 (MI300), gfx1030, gfx1100
|
||||
if [ -z "${PYTORCH_ROCM_ARCH:-}" ]; then
|
||||
# Build for multiple common architectures
|
||||
export PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx1030;gfx1100"
|
||||
echo "No GPU detected, building for multiple architectures: $PYTORCH_ROCM_ARCH"
|
||||
fi
|
||||
|
||||
# Build wheels with Composable Kernel (CK) backend (default for AMD)
|
||||
echo "Building wheels with Composable Kernel backend..."
|
||||
cd flash-attention
|
||||
|
||||
LOCAL_VERSION_LABEL="rocm${MATRIX_ROCM_VERSION//./}torch${MATRIX_TORCH_VERSION}"
|
||||
|
||||
# Disable Triton backend to use Composable Kernel (CK) backend
|
||||
export FLASH_ATTENTION_TRITON_AMD_ENABLE="FALSE"
|
||||
export FLASH_ATTENTION_FORCE_BUILD=TRUE
|
||||
export FLASH_ATTN_LOCAL_VERSION=${LOCAL_VERSION_LABEL}
|
||||
|
||||
MAX_JOBS=$MAX_JOBS time python setup.py bdist_wheel --dist-dir=dist
|
||||
|
||||
wheel_name=$(basename $(ls dist/*.whl | head -n 1))
|
||||
echo "Built wheel: $wheel_name"
|
||||
|
||||
# Optional: Also build Triton backend if requested
|
||||
if [ "${BUILD_TRITON_BACKEND:-false}" = "true" ]; then
|
||||
echo "Building additional wheel with Triton backend..."
|
||||
|
||||
# Clean previous build
|
||||
python setup.py clean
|
||||
rm -rf build dist
|
||||
|
||||
LOCAL_VERSION_LABEL_TRITON="rocm${MATRIX_ROCM_VERSION//./}torch${MATRIX_TORCH_VERSION}triton"
|
||||
|
||||
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
|
||||
export FLASH_ATTN_LOCAL_VERSION=${LOCAL_VERSION_LABEL_TRITON}
|
||||
|
||||
MAX_JOBS=$MAX_JOBS time python setup.py bdist_wheel --dist-dir=dist_triton
|
||||
|
||||
wheel_name_triton=$(basename $(ls dist_triton/*.whl | head -n 1))
|
||||
echo "Built Triton wheel: $wheel_name_triton"
|
||||
fi
|
||||
|
||||
echo "Build complete!"
|
||||
@@ -245,6 +245,31 @@ WINDOWS_SELF_HOSTED_MATRIX = {
|
||||
],
|
||||
}
|
||||
|
||||
ROCM_MATRIX = {
|
||||
"flash-attn-version": [
|
||||
#"2.6.3",
|
||||
#"2.7.4",
|
||||
"2.8.3",
|
||||
],
|
||||
"python-version": [
|
||||
#"3.10",
|
||||
# "3.11",
|
||||
#"3.12",
|
||||
"3.13",
|
||||
"3.14",
|
||||
],
|
||||
"torch-version": [
|
||||
# "2.5.1",
|
||||
# "2.6.0",
|
||||
# "2.7.1",
|
||||
"2.9.1",
|
||||
],
|
||||
"rocm-version": [
|
||||
"7.1",
|
||||
"7.2",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
print(
|
||||
@@ -271,6 +296,9 @@ def main():
|
||||
"windows_code_build": False,
|
||||
# "windows_code_build": WINDOWS_CODEBUILD_MATRIX,
|
||||
#
|
||||
#"rocm": False,
|
||||
"rocm": ROCM_MATRIX,
|
||||
#
|
||||
"exclude": EXCLUDE,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -0,0 +1,208 @@
|
||||
# ROCm Build Setup for Flash Attention Wheels
|
||||
|
||||
This document describes the ROCm build infrastructure added to the flash-attention-prebuild-wheels repository.
|
||||
|
||||
## Overview
|
||||
|
||||
ROCm (Radeon Open Compute) support has been added to enable building Flash Attention wheels for AMD GPUs. The implementation follows the same pattern as existing CUDA builds but uses ROCm Docker containers and AMD-specific build configurations.
|
||||
|
||||
## Created Files
|
||||
|
||||
### 1. `build_rocm.sh`
|
||||
Build script for creating Flash Attention wheels with ROCm support.
|
||||
|
||||
**Features:**
|
||||
- Installs PyTorch for ROCm from the official PyTorch ROCm index
|
||||
- Supports both Composable Kernel (CK) and Triton backends
|
||||
- Auto-detects AMD GPU architecture (`gfx90a`, `gfx942`, etc.)
|
||||
- Falls back to building for multiple common architectures when no GPU is detected
|
||||
- Configures build parallelism based on available system resources
|
||||
- Creates wheels with version labels like `rocm61torch2.5`
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
./build_rocm.sh <flash-attn-version> <python-version> <torch-version> <rocm-version>
|
||||
|
||||
# Example:
|
||||
./build_rocm.sh 2.8.3 3.11 2.5.1 6.2
|
||||
```
|
||||
|
||||
**Environment Variables:**
|
||||
- `BUILD_TRITON_BACKEND=true` - Also build Triton backend variant (optional)
|
||||
- `PYTORCH_ROCM_ARCH` - Specify target GPU architectures (auto-detected if not set)
|
||||
- `MAX_JOBS` - Override parallel build jobs
|
||||
|
||||
### 2. `.github/workflows/_build_rocm.yml`
|
||||
Reusable GitHub Actions workflow for ROCm builds.
|
||||
|
||||
**Features:**
|
||||
- Uses official ROCm PyTorch Docker containers (`rocm/pytorch:rocm*`)
|
||||
- Supports multiple Python versions via deadsnakes PPA
|
||||
- Includes GPU device passthrough for container (`/dev/kfd`, `/dev/dri`)
|
||||
- Tests wheel installation before upload
|
||||
- Optionally builds both CK and Triton backend wheels
|
||||
- Applies `auditwheel repair` for distribution compatibility
|
||||
|
||||
**Inputs:**
|
||||
- `flash-attn-version` - Flash Attention version to build (required)
|
||||
- `python-version` - Python version (required)
|
||||
- `torch-version` - PyTorch version (required)
|
||||
- `rocm-version` - ROCm version (required, e.g., "6.1" or "6.2")
|
||||
- `runner` - Runner type (default: "ubuntu-22.04")
|
||||
- `is-upload` - Whether to upload to GitHub Releases (default: true)
|
||||
- `build-triton-backend` - Also build Triton backend (default: false)
|
||||
|
||||
### 3. Updated `create_matrix.py`
|
||||
Added ROCm build matrix configuration.
|
||||
|
||||
**ROCm Matrix:**
|
||||
```python
|
||||
ROCM_MATRIX = {
|
||||
"flash-attn-version": ["2.6.3", "2.7.4", "2.8.3"],
|
||||
"python-version": ["3.10", "3.11", "3.12"],
|
||||
"torch-version": ["2.5.1", "2.6.0"],
|
||||
"rocm-version": ["6.1", "6.2"],
|
||||
}
|
||||
```
|
||||
|
||||
**Note:** Set `"rocm": ROCM_MATRIX` in the `main()` function to enable ROCm builds.
|
||||
|
||||
### 4. Updated `.github/workflows/build.yml`
|
||||
Integrated ROCm builds into the main build workflow.
|
||||
|
||||
**Changes:**
|
||||
- Added `build_wheels_rocm` job that uses the `_build_rocm.yml` workflow
|
||||
- Updated `update_release_notes` and `update_docs` jobs to include ROCm builds in dependencies
|
||||
|
||||
## ROCm Architecture Support
|
||||
|
||||
The build system supports the following AMD GPU architectures:
|
||||
|
||||
- **gfx90a** - AMD Instinct MI200 series (MI210, MI250, MI250X)
|
||||
- **gfx942** - AMD Instinct MI300 series (MI300A, MI300X)
|
||||
- **gfx1030** - AMD Radeon RX 6000 series
|
||||
- **gfx1100** - AMD Radeon RX 7000 series
|
||||
|
||||
When building in a GitHub Actions runner without AMD GPUs, the build script compiles for all common architectures to ensure broad compatibility.
|
||||
|
||||
## Backend Options
|
||||
|
||||
Flash Attention 2 on ROCm supports two backends:
|
||||
|
||||
### Composable Kernel (CK) - Default
|
||||
- Native AMD backend optimized for ROCm
|
||||
- Better integration with AMD GPU features
|
||||
- Default choice for production builds
|
||||
- Controlled by `FLASH_ATTENTION_TRITON_AMD_ENABLE="FALSE"`
|
||||
|
||||
### Triton
|
||||
- OpenAI Triton backend
|
||||
- Python-friendly kernel implementation
|
||||
- Useful for experimental features
|
||||
- Controlled by `FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"`
|
||||
|
||||
Wheels are labeled to distinguish backends (e.g., `rocm62torch2.5` vs `rocm62torch2.5triton`).
|
||||
|
||||
## Wheel Naming Convention
|
||||
|
||||
ROCm wheels follow this naming pattern:
|
||||
```
|
||||
flash_attn-{version}+rocm{rocm_version}torch{torch_version}-cp{python_ver}-cp{python_ver}-linux_x86_64.whl
|
||||
|
||||
Example:
|
||||
flash_attn-2.8.3+rocm62torch2.5-cp311-cp311-linux_x86_64.whl
|
||||
```
|
||||
|
||||
For Triton backend:
|
||||
```
|
||||
flash_attn-2.8.3+rocm62torch2.5triton-cp311-cp311-linux_x86_64.whl
|
||||
```
|
||||
|
||||
## Enabling ROCm Builds
|
||||
|
||||
To enable ROCm builds in the GitHub Actions workflow:
|
||||
|
||||
1. Open `create_matrix.py`
|
||||
2. Find the `main()` function
|
||||
3. Change `"rocm": False` to `"rocm": ROCM_MATRIX`
|
||||
4. Adjust the matrix values as needed for your requirements
|
||||
5. Commit and create a release tag (e.g., `v0.1.0`)
|
||||
|
||||
The workflow will automatically:
|
||||
- Create a GitHub release
|
||||
- Build wheels for all matrix combinations
|
||||
- Upload wheels to the release
|
||||
- Update documentation
|
||||
|
||||
## Testing ROCm Wheels
|
||||
|
||||
After building, test the wheel:
|
||||
|
||||
```bash
|
||||
pip install flash_attn-2.8.3+rocm62torch2.5-cp311-cp311-linux_x86_64.whl
|
||||
python -c "import flash_attn; print(flash_attn.__version__)"
|
||||
python -c "import flash_attn_2_cuda; print('Flash Attention ROCm loaded successfully')"
|
||||
```
|
||||
|
||||
## Requirements
|
||||
|
||||
### For GitHub Actions:
|
||||
- Ubuntu 22.04 runners (GitHub-hosted or self-hosted)
|
||||
- Docker support
|
||||
- Sufficient disk space (~20GB per build)
|
||||
- ROCm-capable GPU (optional, builds work without GPU)
|
||||
|
||||
### For Local Builds:
|
||||
- ROCm installation (6.1+ recommended)
|
||||
- ROCm-enabled PyTorch
|
||||
- Build dependencies: `git`, `ninja-build`, `cmake`, `python3-dev`
|
||||
- AMD GPU (optional, for architecture detection)
|
||||
|
||||
## Docker Container Details
|
||||
|
||||
The workflow uses official ROCm PyTorch containers:
|
||||
```
|
||||
rocm/pytorch:rocm{version}_ubuntu22.04_py3.10_pytorch_2.1.2
|
||||
```
|
||||
|
||||
Device passthrough options:
|
||||
```yaml
|
||||
--device=/dev/kfd
|
||||
--device=/dev/dri
|
||||
--group-add video
|
||||
--cap-add=SYS_PTRACE
|
||||
--security-opt seccomp=unconfined
|
||||
--ipc=host
|
||||
```
|
||||
|
||||
These options enable GPU access from within the container, though the builds can complete without physical AMD GPUs present.
|
||||
|
||||
## References
|
||||
|
||||
- [AMD ROCm Documentation](https://rocm.docs.amd.com/)
|
||||
- [Flash Attention on ROCm](https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.html)
|
||||
- [ROCm PyTorch Docker Images](https://hub.docker.com/r/rocm/pytorch)
|
||||
- [Flash Attention GitHub](https://github.com/Dao-AILab/flash-attention)
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Build Failures
|
||||
|
||||
1. **Out of memory**: Reduce `MAX_JOBS` in the build script
|
||||
2. **Architecture mismatch**: Set `PYTORCH_ROCM_ARCH` explicitly
|
||||
3. **PyTorch version issues**: Verify ROCm version compatibility with PyTorch
|
||||
|
||||
### Common Issues
|
||||
|
||||
- **Container GPU access**: Ensure Docker has proper device permissions
|
||||
- **Python version not found**: The deadsnakes PPA may not have all versions immediately
|
||||
- **Wheel incompatibility**: Use `auditwheel repair` output for broader compatibility
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
Potential improvements:
|
||||
- Self-hosted runners with AMD GPUs for faster builds
|
||||
- ROCm version auto-detection
|
||||
- Multi-architecture wheel bundling
|
||||
- Performance benchmarking integration
|
||||
- Support for newer ROCm versions as they're released
|
||||
Reference in New Issue
Block a user