# Mirror of https://github.com/BillyOutlast/flash-attention-prebuild-wheels-rocm.git
# Synced 2026-07-01 01:27:54 -04:00 (175 lines, 6.7 KiB, YAML)
# #########################################################
# Build wheels for ROCm using Docker
#
# Reusable workflow (triggered via `workflow_call`) that builds a
# Flash-Attention wheel inside a rocm/pytorch container, smoke-tests
# the install, and optionally uploads the result to a GitHub Release.
# #########################################################

name: "[ROCm] Build wheels and upload to GitHub Releases"

# This workflow has no direct triggers; it is only invoked by other workflows.
on:
  workflow_call:
    inputs:
      # Forwarded as the 1st argument to build_rocm.sh.
      flash-attn-version:
        description: "Flash-Attention version"
        required: true
        type: string
      # Used to form the apt package name `python<version>` and forwarded as
      # the 2nd argument to build_rocm.sh. Callers must pass it as a quoted
      # string (e.g. "3.10") so YAML does not turn it into the float 3.1.
      python-version:
        description: "Python version"
        required: true
        type: string
      # Forwarded as the 3rd argument to build_rocm.sh.
      torch-version:
        description: "PyTorch version"
        required: true
        type: string
      # Selects the container image and is forwarded as the 4th argument to
      # build_rocm.sh. Only the exact value '7.1.1' picks the ROCm 7.1.1
      # image; every other value uses the ROCm 7.2 image.
      rocm-version:
        description: "ROCm version"
        required: true
        type: string
      # Value of `runs-on` for the build job.
      runner:
        description: "Runner type"
        required: false
        type: string
        default: "ubuntu-22.04"
      # Gates every "Upload ..." step.
      is-upload:
        description: "Whether to upload the release asset"
        required: false
        type: boolean
        default: true
      # Exported to the job as BUILD_TRITON_BACKEND and also gates the
      # Triton-backend upload step.
      build-triton-backend:
        description: "Whether to also build Triton backend"
        required: false
        type: boolean
        default: false

jobs:
  # Builds the wheel inside a rocm/pytorch container, smoke-tests it, then
  # (optionally) uploads the plain, Triton, and manylinux-repaired wheels to
  # the release named after the triggering ref.
  build_wheels:
    name: Build wheels and Upload (ROCm)
    runs-on: ${{ inputs.runner }}
    container:
      # Only rocm-version == '7.1.1' selects the ROCm 7.1.1 image; any other
      # value falls back to the ROCm 7.2 image.
      image: ${{ inputs.rocm-version == '7.1.1' && 'rocm/pytorch:rocm7.1.1_ubuntu24.04_py3.12_pytorch_release_2.9.1' || 'rocm/pytorch:rocm7.2_ubuntu22.04_py3.10_pytorch_release_2.9.1' }}
      # NOTE(review): these --device flags require /dev/kfd and /dev/dri to
      # exist on the runner host — confirm the selected runner exposes them.
      options: --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host
    defaults:
      run:
        # Container jobs default to `sh`; the scripts below rely on bash
        # semantics (and on `-eo pipefail`, which explicit `bash` enables).
        shell: bash
    env:
      DEBIAN_FRONTEND: noninteractive
      TERM: xterm-256color
      BUILD_TRITON_BACKEND: ${{ inputs.build-triton-backend }}

    steps:
      - uses: actions/checkout@v4

      - name: Install system dependencies
        run: |
          apt-get update
          apt-get install -y git ninja-build time patchelf wget curl

      - name: Setup Python
        env:
          # Passed through env (not inlined) so the value is never parsed as
          # shell source.
          INPUT_PYTHON_VERSION: ${{ inputs.python-version }}
        run: |
          # Install deadsnakes PPA for multiple Python versions
          apt-get update
          apt-get install -y software-properties-common
          add-apt-repository -y ppa:deadsnakes/ppa
          apt-get update

          # Install requested Python version
          PYTHON_PKG="python${INPUT_PYTHON_VERSION}"
          apt-get install -y "${PYTHON_PKG}" "${PYTHON_PKG}-dev" "${PYTHON_PKG}-venv"

          # Set as default python
          update-alternatives --install /usr/bin/python python "/usr/bin/${PYTHON_PKG}" 1
          update-alternatives --install /usr/bin/python3 python3 "/usr/bin/${PYTHON_PKG}" 1

          # Install pip (-f: fail on HTTP errors instead of piping an error
          # page into the interpreter)
          curl -fsS https://bootstrap.pypa.io/get-pip.py | python
          python -m pip install --upgrade pip setuptools wheel

      - name: Verify ROCm installation
        run: |
          # Informational only: every probe is guarded so a missing tool or
          # GPU does not fail the job.
          echo "ROCm version:"
          if command -v rocm-smi > /dev/null 2>&1; then
            rocm-smi --showproductname || echo "ROCm SMI not available in container"
          fi
          if command -v rocminfo > /dev/null 2>&1; then
            rocminfo | grep -E "Name:|Marketing Name:" | head -5 || echo "ROCminfo not fully available"
          fi

      - name: Build wheels
        id: build_wheels
        env:
          INPUT_FLASH_ATTN_VERSION: ${{ inputs.flash-attn-version }}
          INPUT_PYTHON_VERSION: ${{ inputs.python-version }}
          INPUT_TORCH_VERSION: ${{ inputs.torch-version }}
          INPUT_ROCM_VERSION: ${{ inputs.rocm-version }}
        run: |
          chmod +x build_rocm.sh
          ./build_rocm.sh "$INPUT_FLASH_ATTN_VERSION" "$INPUT_PYTHON_VERSION" "$INPUT_TORCH_VERSION" "$INPUT_ROCM_VERSION"

          # The main wheel is mandatory: with pipefail, a missing wheel makes
          # `ls` fail and aborts the step here.
          wheel_path=$(ls flash-attention/dist/*.whl | head -n 1)
          echo "WHEEL_PATH=$wheel_path" >> "$GITHUB_OUTPUT"

          # Check for Triton backend wheel if built (optional: an empty
          # dist_triton directory must not fail the build)
          if [ -d "flash-attention/dist_triton" ]; then
            wheel_path_triton=$(ls flash-attention/dist_triton/*.whl 2>/dev/null | head -n 1 || true)
            if [ -n "$wheel_path_triton" ]; then
              echo "WHEEL_PATH_TRITON=$wheel_path_triton" >> "$GITHUB_OUTPUT"
            fi
          fi

      - name: Test wheel installation
        env:
          WHEEL_PATH: ${{ steps.build_wheels.outputs.WHEEL_PATH }}
        run: |
          # Use `python -m pip` so the wheel is installed into the same
          # interpreter that runs the import checks below.
          python -m pip uninstall -y flash-attn > /dev/null 2>&1 || true
          python -m pip install --no-cache-dir "$WHEEL_PATH"
          python -c "import flash_attn; print('Flash Attention version:', flash_attn.__version__)"
          python -c "import flash_attn_2_cuda as flash_attn_cuda; print('Flash Attention CUDA module loaded successfully')" || echo "CUDA module test skipped"

      # NOTE(review): the upload steps call `gh`, which the "Install system
      # dependencies" step does not install — confirm the container image
      # ships the GitHub CLI.
      - name: Upload Release Asset (CK Backend)
        if: ${{ inputs.is-upload }}
        env:
          GITHUB_TOKEN: ${{ github.token }}
          # Explicit target repo so `gh` does not depend on git repository
          # detection inside the container.
          GH_REPO: ${{ github.repository }}
          TAG_NAME: ${{ github.ref_name }}
          WHEEL_PATH: ${{ steps.build_wheels.outputs.WHEEL_PATH }}
        run: |
          # Check if the file exists
          if [ ! -f "$WHEEL_PATH" ]; then
            echo "Error: Wheel file not found at $WHEEL_PATH"
            exit 1
          fi

          # Upload the release asset using GitHub CLI
          gh release upload "$TAG_NAME" "$WHEEL_PATH" --clobber

      - name: Upload Release Asset (Triton Backend)
        if: ${{ inputs.is-upload && inputs.build-triton-backend && steps.build_wheels.outputs.WHEEL_PATH_TRITON }}
        env:
          GITHUB_TOKEN: ${{ github.token }}
          GH_REPO: ${{ github.repository }}
          TAG_NAME: ${{ github.ref_name }}
          WHEEL_PATH: ${{ steps.build_wheels.outputs.WHEEL_PATH_TRITON }}
        run: |
          # Check if the file exists (a missing Triton wheel is only a
          # warning, not a failure)
          if [ ! -f "$WHEEL_PATH" ]; then
            echo "Warning: Triton wheel file not found at $WHEEL_PATH"
            exit 0
          fi

          # Upload the release asset using GitHub CLI
          gh release upload "$TAG_NAME" "$WHEEL_PATH" --clobber

      # Best-effort: a failed repair must not fail the job, hence
      # continue-on-error and the `|| echo` fallbacks.
      - name: Apply auditwheel repair
        continue-on-error: true
        id: auditwheel_repair
        env:
          WHEEL_PATH: ${{ steps.build_wheels.outputs.WHEEL_PATH }}
        run: |
          python -m pip install auditwheel
          auditwheel show "$WHEEL_PATH" || echo "Auditwheel show failed"
          # Exclude patterns are quoted so the shell passes them to
          # auditwheel verbatim instead of glob-expanding them.
          auditwheel repair \
            --exclude 'libc10*' --exclude 'libtorch*' --exclude 'libamdhip*' --exclude 'librocm*' --exclude 'libhsa*' \
            --exclude 'libMIOpen*' --exclude 'librccl*' --exclude 'librocblas*' --exclude 'libhipblas*' \
            "$WHEEL_PATH" -w flash-attention/dist_manylinux || echo "Auditwheel repair failed"

          wheel_path_manylinux=$(ls flash-attention/dist_manylinux/*manylinux*.whl 2>/dev/null | head -n 1 || echo "")
          if [ -n "$wheel_path_manylinux" ]; then
            echo "WHEEL_PATH_MANYLINUX=$wheel_path_manylinux" >> "$GITHUB_OUTPUT"
          fi

      - name: Upload manylinux wheel
        if: ${{ inputs.is-upload && steps.auditwheel_repair.outputs.WHEEL_PATH_MANYLINUX }}
        continue-on-error: true
        env:
          GITHUB_TOKEN: ${{ github.token }}
          GH_REPO: ${{ github.repository }}
          TAG_NAME: ${{ github.ref_name }}
          WHEEL_PATH: ${{ steps.auditwheel_repair.outputs.WHEEL_PATH_MANYLINUX }}
        run: |
          if [ -f "$WHEEL_PATH" ]; then
            gh release upload "$TAG_NAME" "$WHEEL_PATH" --clobber
          fi