first attempt

This commit is contained in:
John Doe
2026-02-10 20:49:18 -05:00
parent 5dbee4edb4
commit 6ac2a69fa9
5 changed files with 569 additions and 0 deletions
+174
View File
@@ -0,0 +1,174 @@
# #########################################################
# Build wheels for ROCm using Docker
# #########################################################
name: "[ROCm] Build wheels and upload to GitHub Releases"
on:
workflow_call:
inputs:
flash-attn-version:
description: "Flash-Attention version"
required: true
type: string
python-version:
description: "Python version"
required: true
type: string
torch-version:
description: "PyTorch version"
required: true
type: string
rocm-version:
description: "ROCm version"
required: true
type: string
runner:
description: "Runner type"
required: false
type: string
default: "ubuntu-22.04"
is-upload:
description: "Whether to upload the release asset"
required: false
type: boolean
default: true
build-triton-backend:
description: "Whether to also build Triton backend"
required: false
type: boolean
default: false
jobs:
build_wheels:
name: Build wheels and Upload (ROCm)
runs-on: ${{ inputs.runner }}
container:
image: rocm/pytorch:rocm${{ inputs.rocm-version }}_ubuntu22.04_py3.10_pytorch_2.1.2
options: --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host
env:
DEBIAN_FRONTEND: noninteractive
TERM: xterm-256color
BUILD_TRITON_BACKEND: ${{ inputs.build-triton-backend }}
steps:
- uses: actions/checkout@v4
- name: Install system dependencies
run: |
apt-get update
apt-get install -y git ninja-build time patchelf wget curl
- name: Setup Python
run: |
# Install deadsnakes PPA for multiple Python versions
apt-get update
apt-get install -y software-properties-common
add-apt-repository -y ppa:deadsnakes/ppa
apt-get update
# Install requested Python version
PYTHON_PKG="python${{ inputs.python-version }}"
apt-get install -y ${PYTHON_PKG} ${PYTHON_PKG}-dev ${PYTHON_PKG}-venv
# Set as default python
update-alternatives --install /usr/bin/python python /usr/bin/${PYTHON_PKG} 1
update-alternatives --install /usr/bin/python3 python3 /usr/bin/${PYTHON_PKG} 1
# Install pip
curl -sS https://bootstrap.pypa.io/get-pip.py | python
python -m pip install --upgrade pip setuptools wheel
- name: Verify ROCm installation
run: |
echo "ROCm version:"
if command -v rocm-smi &> /dev/null; then
rocm-smi --showproductname || echo "ROCm SMI not available in container"
fi
if command -v rocminfo &> /dev/null; then
rocminfo | grep -E "Name:|Marketing Name:" | head -5 || echo "ROCminfo not fully available"
fi
- name: Build wheels
id: build_wheels
run: |
chmod +x build_rocm.sh
./build_rocm.sh ${{ inputs.flash-attn-version }} ${{ inputs.python-version }} ${{ inputs.torch-version }} ${{ inputs.rocm-version }}
wheel_path=$(ls flash-attention/dist/*.whl | head -n 1)
echo "WHEEL_PATH=$wheel_path" >> $GITHUB_OUTPUT
# Check for Triton backend wheel if built
if [ -d "flash-attention/dist_triton" ]; then
wheel_path_triton=$(ls flash-attention/dist_triton/*.whl | head -n 1)
echo "WHEEL_PATH_TRITON=$wheel_path_triton" >> $GITHUB_OUTPUT
fi
- name: Test wheel installation
run: |
pip uninstall -y flash-attn > /dev/null 2>&1 || true
pip install --no-cache-dir ${{ steps.build_wheels.outputs.WHEEL_PATH }}
python -c "import flash_attn; print('Flash Attention version:', flash_attn.__version__)"
python -c "import flash_attn_2_cuda as flash_attn_cuda; print('Flash Attention CUDA module loaded successfully')" || echo "CUDA module test skipped"
- name: Upload Release Asset (CK Backend)
if: ${{ inputs.is-upload }}
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
tag_name=${{ github.ref_name }}
wheel_path="${{ steps.build_wheels.outputs.WHEEL_PATH }}"
# Check if the file exists
if [ ! -f "$wheel_path" ]; then
echo "Error: Wheel file not found at $wheel_path"
exit 1
fi
# Upload the release asset using GitHub CLI
gh release upload "$tag_name" "$wheel_path" --clobber
- name: Upload Release Asset (Triton Backend)
if: ${{ inputs.is-upload && inputs.build-triton-backend && steps.build_wheels.outputs.WHEEL_PATH_TRITON }}
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
tag_name=${{ github.ref_name }}
wheel_path="${{ steps.build_wheels.outputs.WHEEL_PATH_TRITON }}"
# Check if the file exists
if [ ! -f "$wheel_path" ]; then
echo "Warning: Triton wheel file not found at $wheel_path"
exit 0
fi
# Upload the release asset using GitHub CLI
gh release upload "$tag_name" "$wheel_path" --clobber
- name: Apply auditwheel repair
continue-on-error: true
id: auditwheel_repair
run: |
pip install auditwheel
auditwheel show ${{ steps.build_wheels.outputs.WHEEL_PATH }} || echo "Auditwheel show failed"
auditwheel repair \
--exclude libc10* --exclude libtorch* --exclude libamdhip* --exclude librocm* --exclude libhsa* \
--exclude libMIOpen* --exclude librccl* --exclude librocblas* --exclude libhipblas* \
${{ steps.build_wheels.outputs.WHEEL_PATH }} -w flash-attention/dist_manylinux || echo "Auditwheel repair failed"
wheel_path_manylinux=$(ls flash-attention/dist_manylinux/*manylinux*.whl 2>/dev/null | head -n 1 || echo "")
if [ -n "$wheel_path_manylinux" ]; then
echo "WHEEL_PATH_MANYLINUX=$wheel_path_manylinux" >> $GITHUB_OUTPUT
fi
- name: Upload manylinux wheel
if: ${{ inputs.is-upload && steps.auditwheel_repair.outputs.WHEEL_PATH_MANYLINUX }}
continue-on-error: true
env:
GITHUB_TOKEN: ${{ github.token }}
run: |
tag_name=${{ github.ref_name }}
wheel_path="${{ steps.auditwheel_repair.outputs.WHEEL_PATH_MANYLINUX }}"
if [ -f "$wheel_path" ]; then
gh release upload "$tag_name" "$wheel_path" --clobber
fi
+25
View File
@@ -123,6 +123,29 @@ jobs:
container-image: "quay.io/pypa/manylinux_2_28_aarch64"
secrets: inherit
# #########################################################
# ROCm
# #########################################################
build_wheels_rocm:
name: Build ROCm
needs: [create_releases, create_matrix]
if: ${{ fromjson(needs.create_matrix.outputs.matrix).rocm }}
strategy:
fail-fast: false
matrix:
flash-attn-version: ${{ fromjson(needs.create_matrix.outputs.matrix).rocm.flash-attn-version }}
python-version: ${{ fromjson(needs.create_matrix.outputs.matrix).rocm.python-version }}
torch-version: ${{ fromjson(needs.create_matrix.outputs.matrix).rocm.torch-version }}
rocm-version: ${{ fromjson(needs.create_matrix.outputs.matrix).rocm.rocm-version }}
uses: ./.github/workflows/_build_rocm.yml
with:
flash-attn-version: ${{ matrix.flash-attn-version }}
python-version: ${{ matrix.python-version }}
torch-version: ${{ matrix.torch-version }}
rocm-version: ${{ matrix.rocm-version }}
secrets: inherit
# #########################################################
# Windows
# #########################################################
@@ -199,6 +222,7 @@ jobs:
- build_wheels_linux_arm64
- build_wheels_linux_self_hosted
- build_wheels_linux_arm64_self_hosted
- build_wheels_rocm
- build_wheels_windows
- build_wheels_windows_code_build
- build_wheels_windows_self_hosted
@@ -228,6 +252,7 @@ jobs:
- build_wheels_linux
- build_wheels_linux_arm64
- build_wheels_linux_self_hosted
- build_wheels_rocm
- build_wheels_linux_arm64_self_hosted
- build_wheels_windows
- build_wheels_windows_code_build
+134
View File
@@ -0,0 +1,134 @@
#!/bin/bash
set -e
# Parameters with defaults
FLASH_ATTN_VERSION=$1
PYTHON_VERSION=$2
TORCH_VERSION=$3
ROCM_VERSION=$4
echo "Building Flash Attention for ROCm with parameters:"
echo " Flash-Attention: $FLASH_ATTN_VERSION"
echo " Python: $PYTHON_VERSION"
echo " PyTorch: $TORCH_VERSION"
echo " ROCm: $ROCM_VERSION"
# Set ROCm and PyTorch versions
MATRIX_ROCM_VERSION=$(echo $ROCM_VERSION | awk -F \. {'print $1 "." $2'})
MATRIX_TORCH_VERSION=$(echo $TORCH_VERSION | awk -F \. {'print $1 "." $2'})
echo "Derived versions:"
echo " ROCm Matrix: $MATRIX_ROCM_VERSION"
echo " Torch Matrix: $MATRIX_TORCH_VERSION"
# Install PyTorch for ROCm
echo "Installing PyTorch $TORCH_VERSION for ROCm $MATRIX_ROCM_VERSION..."
if [[ $TORCH_VERSION == *"dev"* ]]; then
pip install --force-reinstall --no-cache-dir --pre torch==$TORCH_VERSION --index-url https://download.pytorch.org/whl/nightly/rocm${MATRIX_ROCM_VERSION}
else
pip install --force-reinstall --no-cache-dir torch==$TORCH_VERSION --index-url https://download.pytorch.org/whl/rocm${MATRIX_ROCM_VERSION}
fi
# Install additional dependencies
echo "Installing build dependencies..."
pip install ninja packaging
# Verify installation
echo "Verifying installations..."
python -V
python -c "import torch; print('PyTorch:', torch.__version__)"
python -c "import torch; print('ROCm:', torch.version.hip if hasattr(torch.version, 'hip') else 'Not found')"
# Display ROCm information
if command -v rocm-smi &> /dev/null; then
echo "ROCm SMI:"
rocm-smi --showproductname || true
fi
if command -v rocminfo &> /dev/null; then
echo "Available ROCm devices:"
rocminfo | grep -E "Name:|Marketing Name:" || true
fi
# Checkout flash-attn
echo "Checking out flash-attention v$FLASH_ATTN_VERSION..."
git clone https://github.com/Dao-AILab/flash-attention.git -b "v$FLASH_ATTN_VERSION"
# Determine MAX_JOBS based on system resources
NUM_THREADS=$(nproc)
RAM_GB=$(free -g | awk '/^Mem:/{print $2}')
echo "System resources:"
echo " CPU threads: $NUM_THREADS"
echo " RAM: ${RAM_GB}GB"
# Calculate MAX_JOBS based on available resources
# ROCm builds are memory intensive, so we use conservative estimates
if [[ -z "${MAX_JOBS:-}" ]]; then
MAX_PRODUCT_CPU=$NUM_THREADS
MAX_PRODUCT_RAM=$(awk -v ram="$RAM_GB" 'BEGIN {print int(ram / 4)}')
MAX_JOBS=$((MAX_PRODUCT_CPU < MAX_PRODUCT_RAM ? MAX_PRODUCT_CPU : MAX_PRODUCT_RAM))
# Ensure minimum values
MAX_JOBS=$((MAX_JOBS < 1 ? 1 : MAX_JOBS))
# Cap at 8 to avoid overwhelming the system
MAX_JOBS=$((MAX_JOBS > 8 ? 8 : MAX_JOBS))
fi
echo "Build parallelism settings:"
echo " MAX_JOBS: $MAX_JOBS"
# Detect GPU architecture if available
if command -v rocminfo &> /dev/null; then
GPU_ARCH=$(rocminfo | grep -o -m 1 'gfx[0-9a-z]*' | head -n 1)
if [ -n "$GPU_ARCH" ]; then
echo "Detected GPU architecture: $GPU_ARCH"
export PYTORCH_ROCM_ARCH=$GPU_ARCH
fi
fi
# Set default architectures if not detected
# Common ROCm architectures: gfx90a (MI200), gfx942 (MI300), gfx1030, gfx1100
if [ -z "${PYTORCH_ROCM_ARCH:-}" ]; then
# Build for multiple common architectures
export PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx1030;gfx1100"
echo "No GPU detected, building for multiple architectures: $PYTORCH_ROCM_ARCH"
fi
# Build wheels with Composable Kernel (CK) backend (default for AMD)
echo "Building wheels with Composable Kernel backend..."
cd flash-attention
LOCAL_VERSION_LABEL="rocm${MATRIX_ROCM_VERSION//./}torch${MATRIX_TORCH_VERSION}"
# Disable Triton backend to use Composable Kernel (CK) backend
export FLASH_ATTENTION_TRITON_AMD_ENABLE="FALSE"
export FLASH_ATTENTION_FORCE_BUILD=TRUE
export FLASH_ATTN_LOCAL_VERSION=${LOCAL_VERSION_LABEL}
MAX_JOBS=$MAX_JOBS time python setup.py bdist_wheel --dist-dir=dist
wheel_name=$(basename $(ls dist/*.whl | head -n 1))
echo "Built wheel: $wheel_name"
# Optional: Also build Triton backend if requested
if [ "${BUILD_TRITON_BACKEND:-false}" = "true" ]; then
echo "Building additional wheel with Triton backend..."
# Clean previous build
python setup.py clean
rm -rf build dist
LOCAL_VERSION_LABEL_TRITON="rocm${MATRIX_ROCM_VERSION//./}torch${MATRIX_TORCH_VERSION}triton"
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTN_LOCAL_VERSION=${LOCAL_VERSION_LABEL_TRITON}
MAX_JOBS=$MAX_JOBS time python setup.py bdist_wheel --dist-dir=dist_triton
wheel_name_triton=$(basename $(ls dist_triton/*.whl | head -n 1))
echo "Built Triton wheel: $wheel_name_triton"
fi
echo "Build complete!"
+28
View File
@@ -245,6 +245,31 @@ WINDOWS_SELF_HOSTED_MATRIX = {
],
}
ROCM_MATRIX = {
"flash-attn-version": [
#"2.6.3",
#"2.7.4",
"2.8.3",
],
"python-version": [
#"3.10",
# "3.11",
#"3.12",
"3.13",
"3.14",
],
"torch-version": [
# "2.5.1",
# "2.6.0",
# "2.7.1",
"2.9.1",
],
"rocm-version": [
"7.1",
"7.2",
],
}
def main():
print(
@@ -271,6 +296,9 @@ def main():
"windows_code_build": False,
# "windows_code_build": WINDOWS_CODEBUILD_MATRIX,
#
#"rocm": False,
"rocm": ROCM_MATRIX,
#
"exclude": EXCLUDE,
}
)
+208
View File
@@ -0,0 +1,208 @@
# ROCm Build Setup for Flash Attention Wheels
This document describes the ROCm build infrastructure added to the flash-attention-prebuild-wheels repository.
## Overview
ROCm (Radeon Open Compute) support has been added to enable building Flash Attention wheels for AMD GPUs. The implementation follows the same pattern as existing CUDA builds but uses ROCm Docker containers and AMD-specific build configurations.
## Created Files
### 1. `build_rocm.sh`
Build script for creating Flash Attention wheels with ROCm support.
**Features:**
- Installs PyTorch for ROCm from the official PyTorch ROCm index
- Supports both Composable Kernel (CK) and Triton backends
- Auto-detects AMD GPU architecture (`gfx90a`, `gfx942`, etc.)
- Falls back to building for multiple common architectures when no GPU is detected
- Configures build parallelism based on available system resources
- Creates wheels with version labels like `rocm61torch2.5`
**Usage:**
```bash
./build_rocm.sh <flash-attn-version> <python-version> <torch-version> <rocm-version>
# Example:
./build_rocm.sh 2.8.3 3.11 2.5.1 6.2
```
**Environment Variables:**
- `BUILD_TRITON_BACKEND=true` - Also build Triton backend variant (optional)
- `PYTORCH_ROCM_ARCH` - Specify target GPU architectures (auto-detected if not set)
- `MAX_JOBS` - Override parallel build jobs
### 2. `.github/workflows/_build_rocm.yml`
Reusable GitHub Actions workflow for ROCm builds.
**Features:**
- Uses official ROCm PyTorch Docker containers (`rocm/pytorch:rocm*`)
- Supports multiple Python versions via deadsnakes PPA
- Includes GPU device passthrough for container (`/dev/kfd`, `/dev/dri`)
- Tests wheel installation before upload
- Optionally builds both CK and Triton backend wheels
- Applies `auditwheel repair` for distribution compatibility
**Inputs:**
- `flash-attn-version` - Flash Attention version to build (required)
- `python-version` - Python version (required)
- `torch-version` - PyTorch version (required)
- `rocm-version` - ROCm version (required, e.g., "6.1" or "6.2")
- `runner` - Runner type (default: "ubuntu-22.04")
- `is-upload` - Whether to upload to GitHub Releases (default: true)
- `build-triton-backend` - Also build Triton backend (default: false)
### 3. Updated `create_matrix.py`
Added ROCm build matrix configuration.
**ROCm Matrix:**
```python
ROCM_MATRIX = {
"flash-attn-version": ["2.6.3", "2.7.4", "2.8.3"],
"python-version": ["3.10", "3.11", "3.12"],
"torch-version": ["2.5.1", "2.6.0"],
"rocm-version": ["6.1", "6.2"],
}
```
**Note:** Set `"rocm": ROCM_MATRIX` in the `main()` function to enable ROCm builds.
### 4. Updated `.github/workflows/build.yml`
Integrated ROCm builds into the main build workflow.
**Changes:**
- Added `build_wheels_rocm` job that uses the `_build_rocm.yml` workflow
- Updated `update_release_notes` and `update_docs` jobs to include ROCm builds in dependencies
## ROCm Architecture Support
The build system supports the following AMD GPU architectures:
- **gfx90a** - AMD Instinct MI200 series (MI210, MI250, MI250X)
- **gfx942** - AMD Instinct MI300 series (MI300A, MI300X)
- **gfx1030** - AMD Radeon RX 6000 series
- **gfx1100** - AMD Radeon RX 7000 series
When building in a GitHub Actions runner without AMD GPUs, the build script compiles for all common architectures to ensure broad compatibility.
## Backend Options
Flash Attention 2 on ROCm supports two backends:
### Composable Kernel (CK) - Default
- Native AMD backend optimized for ROCm
- Better integration with AMD GPU features
- Default choice for production builds
- Controlled by `FLASH_ATTENTION_TRITON_AMD_ENABLE="FALSE"`
### Triton
- OpenAI Triton backend
- Python-friendly kernel implementation
- Useful for experimental features
- Controlled by `FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"`
Wheels are labeled to distinguish backends (e.g., `rocm62torch2.5` vs `rocm62torch2.5triton`).
## Wheel Naming Convention
ROCm wheels follow this naming pattern:
```
flash_attn-{version}+rocm{rocm_version}torch{torch_version}-cp{python_ver}-cp{python_ver}-linux_x86_64.whl
Example:
flash_attn-2.8.3+rocm62torch2.5-cp311-cp311-linux_x86_64.whl
```
For Triton backend:
```
flash_attn-2.8.3+rocm62torch2.5triton-cp311-cp311-linux_x86_64.whl
```
## Enabling ROCm Builds
To enable ROCm builds in the GitHub Actions workflow:
1. Open `create_matrix.py`
2. Find the `main()` function
3. Change `"rocm": False` to `"rocm": ROCM_MATRIX`
4. Adjust the matrix values as needed for your requirements
5. Commit and create a release tag (e.g., `v0.1.0`)
The workflow will automatically:
- Create a GitHub release
- Build wheels for all matrix combinations
- Upload wheels to the release
- Update documentation
## Testing ROCm Wheels
After building, test the wheel:
```bash
pip install flash_attn-2.8.3+rocm62torch2.5-cp311-cp311-linux_x86_64.whl
python -c "import flash_attn; print(flash_attn.__version__)"
python -c "import flash_attn_2_cuda; print('Flash Attention ROCm loaded successfully')"
```
## Requirements
### For GitHub Actions:
- Ubuntu 22.04 runners (GitHub-hosted or self-hosted)
- Docker support
- Sufficient disk space (~20GB per build)
- ROCm-capable GPU (optional, builds work without GPU)
### For Local Builds:
- ROCm installation (6.1+ recommended)
- ROCm-enabled PyTorch
- Build dependencies: `git`, `ninja-build`, `cmake`, `python3-dev`
- AMD GPU (optional, for architecture detection)
## Docker Container Details
The workflow uses official ROCm PyTorch containers:
```
rocm/pytorch:rocm{version}_ubuntu22.04_py3.10_pytorch_2.1.2
```
Device passthrough options:
```yaml
--device=/dev/kfd
--device=/dev/dri
--group-add video
--cap-add=SYS_PTRACE
--security-opt seccomp=unconfined
--ipc=host
```
These options enable GPU access from within the container, though the builds can complete without physical AMD GPUs present.
## References
- [AMD ROCm Documentation](https://rocm.docs.amd.com/)
- [Flash Attention on ROCm](https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/model-acceleration-libraries.html)
- [ROCm PyTorch Docker Images](https://hub.docker.com/r/rocm/pytorch)
- [Flash Attention GitHub](https://github.com/Dao-AILab/flash-attention)
## Troubleshooting
### Build Failures
1. **Out of memory**: Reduce `MAX_JOBS` in the build script
2. **Architecture mismatch**: Set `PYTORCH_ROCM_ARCH` explicitly
3. **PyTorch version issues**: Verify ROCm version compatibility with PyTorch
### Common Issues
- **Container GPU access**: Ensure Docker has proper device permissions
- **Python version not found**: The deadsnakes PPA may not have all versions immediately
- **Wheel incompatibility**: Use `auditwheel repair` output for broader compatibility
## Future Enhancements
Potential improvements:
- Self-hosted runners with AMD GPUs for faster builds
- ROCm version auto-detection
- Multi-architecture wheel bundling
- Performance benchmarking integration
- Support for newer ROCm versions as they're released