mirror of
https://github.com/BillyOutlast/flash-attention-prebuild-wheels-rocm.git
synced 2026-07-01 01:37:53 -04:00
first attempt
This commit is contained in:
@@ -0,0 +1,174 @@
|
||||
# #########################################################
|
||||
# Build wheels for ROCm using Docker
|
||||
# #########################################################
|
||||
|
||||
name: '[ROCm] Build wheels and upload to GitHub Releases'

# Reusable workflow: it is triggered only through `workflow_call`, i.e. when
# another workflow references this file with `uses:`.
on:
  workflow_call:
    inputs:
      # --- Required build coordinates (all strings) ---
      flash-attn-version:
        description: 'Flash-Attention version'
        type: string
        required: true
      python-version:
        description: 'Python version'
        type: string
        required: true
      torch-version:
        description: 'PyTorch version'
        type: string
        required: true
      rocm-version:
        description: 'ROCm version'
        type: string
        required: true

      # --- Optional knobs ---
      # Label handed to `runs-on` of the build job.
      runner:
        description: 'Runner type'
        type: string
        required: false
        default: 'ubuntu-22.04'
      # Gates the release-upload steps of the build job.
      is-upload:
        description: 'Whether to upload the release asset'
        type: boolean
        required: false
        default: true
      # Exported to the job as BUILD_TRITON_BACKEND and used as a gate for the
      # Triton upload step.
      build-triton-backend:
        description: 'Whether to also build Triton backend'
        type: boolean
        required: false
        default: false
|
||||
|
||||
jobs:
  build_wheels:
    name: 'Build wheels and Upload (ROCm)'
    runs-on: ${{ inputs.runner }}

    # All steps of this job execute inside this container.
    container:
      # Only the ROCm version comes from the inputs; the Ubuntu, Python and
      # PyTorch parts of the image tag are fixed.
      image: rocm/pytorch:rocm${{ inputs.rocm-version }}_ubuntu22.04_py3.10_pytorch_2.1.2
      # Folded scalar: the lines are joined with single spaces into one
      # option string.
      options: >-
        --device=/dev/kfd
        --device=/dev/dri
        --group-add video
        --cap-add=SYS_PTRACE
        --security-opt seccomp=unconfined
        --ipc=host

    # Job-wide environment, visible to every step.
    env:
      DEBIAN_FRONTEND: noninteractive
      TERM: xterm-256color
      BUILD_TRITON_BACKEND: ${{ inputs.build-triton-backend }}

    steps:
      - uses: actions/checkout@v4
|
||||
|
||||
# Installs the command-line tools the later steps of this job rely on.
# `gh` is included because the release-upload steps call `gh release upload`
# from inside the job container.
# NOTE(review): assumes the container image does not already ship the GitHub
# CLI and that a `gh` package is available from the configured apt sources —
# confirm against the image in use.
- name: Install system dependencies
  run: |
    apt-get update
    apt-get install -y git ninja-build time patchelf wget curl gh
|
||||
|
||||
# Installs the requested CPython from the deadsnakes PPA, registers it as an
# alternative for `python`/`python3`, and bootstraps pip for it.
- name: Setup Python
  env:
    # Handed over through the environment instead of being expanded into the
    # script text, so the value is treated as data and never parsed as shell.
    INPUT_PYTHON_VERSION: ${{ inputs.python-version }}
  run: |
    # Install deadsnakes PPA for multiple Python versions
    apt-get update
    apt-get install -y software-properties-common
    add-apt-repository -y ppa:deadsnakes/ppa
    apt-get update

    # Install requested Python version
    PYTHON_PKG="python${INPUT_PYTHON_VERSION}"
    apt-get install -y "${PYTHON_PKG}" "${PYTHON_PKG}-dev" "${PYTHON_PKG}-venv"

    # Set as default python
    update-alternatives --install /usr/bin/python python "/usr/bin/${PYTHON_PKG}" 1
    update-alternatives --install /usr/bin/python3 python3 "/usr/bin/${PYTHON_PKG}" 1

    # Install pip.
    # Download first and run second: with `curl ... | python` a failed
    # download is hidden, because the pipeline's status is that of `python`
    # reading empty input. `-f` makes curl fail on HTTP errors.
    curl -fsSL https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py
    python /tmp/get-pip.py
    python -m pip install --upgrade pip setuptools wheel
|
||||
|
||||
# Best-effort diagnostics: prints GPU/ROCm information when the tools exist.
# Neither probe can fail the step; each falls back to an `echo`.
#
# The tool checks use the POSIX redirection `> /dev/null 2>&1`. The bash-only
# `&>` form is parsed by a plain `sh` as "run in background, then redirect",
# which makes the `if` condition always true.
- name: Verify ROCm installation
  run: |
    echo "ROCm version:"
    if command -v rocm-smi > /dev/null 2>&1; then
      rocm-smi --showproductname || echo "ROCm SMI not available in container"
    fi
    if command -v rocminfo > /dev/null 2>&1; then
      rocminfo | grep -E "Name:|Marketing Name:" | head -5 || echo "ROCminfo not fully available"
    fi
|
||||
|
||||
# Runs build_rocm.sh and publishes the resulting wheel paths as step outputs:
#   WHEEL_PATH         first *.whl in flash-attention/dist (mandatory)
#   WHEEL_PATH_TRITON  first *.whl in flash-attention/dist_triton (only
#                      written when that directory exists)
- name: Build wheels
  id: build_wheels
  env:
    # Workflow inputs are passed as environment variables and quoted at the
    # call site, so they reach the script as four literal arguments.
    INPUT_FLASH_ATTN_VERSION: ${{ inputs.flash-attn-version }}
    INPUT_PYTHON_VERSION: ${{ inputs.python-version }}
    INPUT_TORCH_VERSION: ${{ inputs.torch-version }}
    INPUT_ROCM_VERSION: ${{ inputs.rocm-version }}
  run: |
    chmod +x build_rocm.sh
    ./build_rocm.sh "$INPUT_FLASH_ATTN_VERSION" "$INPUT_PYTHON_VERSION" "$INPUT_TORCH_VERSION" "$INPUT_ROCM_VERSION"

    # `ls | head` reports head's exit status, so a missing wheel would
    # otherwise go unnoticed here; fail explicitly instead.
    wheel_path=$(ls flash-attention/dist/*.whl 2>/dev/null | head -n 1)
    if [ -z "$wheel_path" ]; then
      echo "Error: no wheel found in flash-attention/dist"
      exit 1
    fi
    echo "WHEEL_PATH=$wheel_path" >> "$GITHUB_OUTPUT"

    # Check for Triton backend wheel if built
    if [ -d "flash-attention/dist_triton" ]; then
      wheel_path_triton=$(ls flash-attention/dist_triton/*.whl 2>/dev/null | head -n 1)
      echo "WHEEL_PATH_TRITON=$wheel_path_triton" >> "$GITHUB_OUTPUT"
    fi
|
||||
|
||||
# Smoke test: installs the freshly built wheel and imports it.
# The compiled-extension import is informational only; its failure is
# reported with an `echo` and does not fail the step.
- name: Test wheel installation
  env:
    WHEEL_PATH: ${{ steps.build_wheels.outputs.WHEEL_PATH }}
  run: |
    # `python -m pip` guarantees the wheel is installed into the same
    # interpreter that runs the import checks below.
    python -m pip uninstall -y flash-attn > /dev/null 2>&1 || true
    python -m pip install --no-cache-dir "$WHEEL_PATH"
    python -c "import flash_attn; print('Flash Attention version:', flash_attn.__version__)"
    python -c "import flash_attn_2_cuda as flash_attn_cuda; print('Flash Attention CUDA module loaded successfully')" || echo "CUDA module test skipped"
|
||||
|
||||
# Uploads the main wheel to the release named after the current ref.
# Runs only when the `is-upload` input is true; a missing wheel is an error.
- name: Upload Release Asset (CK Backend)
  if: ${{ inputs.is-upload }}
  env:
    GITHUB_TOKEN: ${{ github.token }}
    # The ref name and wheel path are supplied as environment variables so
    # that their contents are never interpreted as shell syntax.
    TAG_NAME: ${{ github.ref_name }}
    WHEEL_PATH: ${{ steps.build_wheels.outputs.WHEEL_PATH }}
  run: |
    # Check if the file exists
    if [ ! -f "$WHEEL_PATH" ]; then
      echo "Error: Wheel file not found at $WHEEL_PATH"
      exit 1
    fi

    # Upload the release asset using GitHub CLI
    gh release upload "$TAG_NAME" "$WHEEL_PATH" --clobber
|
||||
|
||||
# Uploads the Triton-backend wheel to the release named after the current ref.
# Runs only when uploading is enabled, the Triton build was requested, and the
# build step produced a WHEEL_PATH_TRITON output. A missing file is only a
# warning: the step exits successfully without uploading.
- name: Upload Release Asset (Triton Backend)
  if: ${{ inputs.is-upload && inputs.build-triton-backend && steps.build_wheels.outputs.WHEEL_PATH_TRITON }}
  env:
    GITHUB_TOKEN: ${{ github.token }}
    # The ref name and wheel path are supplied as environment variables so
    # that their contents are never interpreted as shell syntax.
    TAG_NAME: ${{ github.ref_name }}
    WHEEL_PATH: ${{ steps.build_wheels.outputs.WHEEL_PATH_TRITON }}
  run: |
    # Check if the file exists
    if [ ! -f "$WHEEL_PATH" ]; then
      echo "Warning: Triton wheel file not found at $WHEEL_PATH"
      exit 0
    fi

    # Upload the release asset using GitHub CLI
    gh release upload "$TAG_NAME" "$WHEEL_PATH" --clobber
|
||||
|
||||
# Best-effort manylinux repair of the main wheel (`continue-on-error: true`).
# Publishes WHEEL_PATH_MANYLINUX only when a *manylinux*.whl was produced in
# flash-attention/dist_manylinux.
- name: Apply auditwheel repair
  continue-on-error: true
  id: auditwheel_repair
  env:
    WHEEL_PATH: ${{ steps.build_wheels.outputs.WHEEL_PATH }}
  run: |
    pip install auditwheel
    auditwheel show "$WHEEL_PATH" || echo "Auditwheel show failed"

    # The --exclude patterns are quoted so they reach auditwheel verbatim
    # instead of being glob-expanded by the shell against the working
    # directory.
    auditwheel repair \
      --exclude 'libc10*' --exclude 'libtorch*' --exclude 'libamdhip*' --exclude 'librocm*' --exclude 'libhsa*' \
      --exclude 'libMIOpen*' --exclude 'librccl*' --exclude 'librocblas*' --exclude 'libhipblas*' \
      "$WHEEL_PATH" -w flash-attention/dist_manylinux || echo "Auditwheel repair failed"

    wheel_path_manylinux=$(ls flash-attention/dist_manylinux/*manylinux*.whl 2>/dev/null | head -n 1 || echo "")
    if [ -n "$wheel_path_manylinux" ]; then
      echo "WHEEL_PATH_MANYLINUX=$wheel_path_manylinux" >> "$GITHUB_OUTPUT"
    fi
|
||||
|
||||
# Uploads the repaired manylinux wheel, if the repair step produced one.
# Entirely optional: the step is skipped without that output, does nothing
# when the file is absent, and never fails the job (`continue-on-error`).
- name: Upload manylinux wheel
  if: ${{ inputs.is-upload && steps.auditwheel_repair.outputs.WHEEL_PATH_MANYLINUX }}
  continue-on-error: true
  env:
    GITHUB_TOKEN: ${{ github.token }}
    # The ref name and wheel path are supplied as environment variables so
    # that their contents are never interpreted as shell syntax.
    TAG_NAME: ${{ github.ref_name }}
    WHEEL_PATH: ${{ steps.auditwheel_repair.outputs.WHEEL_PATH_MANYLINUX }}
  run: |
    if [ -f "$WHEEL_PATH" ]; then
      gh release upload "$TAG_NAME" "$WHEEL_PATH" --clobber
    fi
|
||||
@@ -123,6 +123,29 @@ jobs:
|
||||
container-image: "quay.io/pypa/manylinux_2_28_aarch64"
|
||||
secrets: inherit
|
||||
|
||||
# #########################################################
|
||||
# ROCm
|
||||
# #########################################################
|
||||
|
||||
build_wheels_rocm:
  name: 'Build ROCm'
  needs:
    - create_releases
    - create_matrix
  # Run only when the matrix produced by `create_matrix` has a truthy `rocm`
  # entry.
  if: ${{ fromJSON(needs.create_matrix.outputs.matrix).rocm }}
  strategy:
    # One failing combination does not cancel the remaining ones.
    fail-fast: false
    # Each axis is read from the `rocm` section of the generated matrix.
    matrix:
      flash-attn-version: ${{ fromJSON(needs.create_matrix.outputs.matrix).rocm.flash-attn-version }}
      python-version: ${{ fromJSON(needs.create_matrix.outputs.matrix).rocm.python-version }}
      torch-version: ${{ fromJSON(needs.create_matrix.outputs.matrix).rocm.torch-version }}
      rocm-version: ${{ fromJSON(needs.create_matrix.outputs.matrix).rocm.rocm-version }}
  # Delegates the actual build to the reusable ROCm workflow, forwarding the
  # four matrix values as its inputs.
  uses: ./.github/workflows/_build_rocm.yml
  with:
    flash-attn-version: ${{ matrix.flash-attn-version }}
    python-version: ${{ matrix.python-version }}
    torch-version: ${{ matrix.torch-version }}
    rocm-version: ${{ matrix.rocm-version }}
  secrets: inherit
|
||||
|
||||
# #########################################################
|
||||
# Windows
|
||||
# #########################################################
|
||||
@@ -199,6 +222,7 @@ jobs:
|
||||
- build_wheels_linux_arm64
|
||||
- build_wheels_linux_self_hosted
|
||||
- build_wheels_linux_arm64_self_hosted
|
||||
- build_wheels_rocm
|
||||
- build_wheels_windows
|
||||
- build_wheels_windows_code_build
|
||||
- build_wheels_windows_self_hosted
|
||||
@@ -228,6 +252,7 @@ jobs:
|
||||
- build_wheels_linux
|
||||
- build_wheels_linux_arm64
|
||||
- build_wheels_linux_self_hosted
|
||||
- build_wheels_rocm
|
||||
- build_wheels_linux_arm64_self_hosted
|
||||
- build_wheels_windows
|
||||
- build_wheels_windows_code_build
|
||||
|
||||
Reference in New Issue
Block a user