mirror of
https://github.com/BillyOutlast/flash-attention-prebuild-wheels-rocm.git
synced 2026-06-30 23:57:53 -04:00
refactor: reorganize auditwheel repair and add manylinux platform support
- Move auditwheel repair after initial release upload with continue-on-error to allow pipeline continuation - Add manylinux platform normalization support in normalize_platform_name() - Expand self-hosted build matrix to include Python 3.14 and Flash Attention 2.8.3 - Improve wheel upload flow by separating regular and manylinux wheel handling
This commit is contained in:
@@ -65,7 +65,7 @@ jobs:
|
||||
- name: Install build dependencies
|
||||
run: |
|
||||
sudo apt install -y ninja-build clang time
|
||||
pip install -U pip setuptools==75.8.0 wheel packaging psutil auditwheel
|
||||
pip install -U pip setuptools==75.8.0 wheel packaging psutil
|
||||
|
||||
- name: Build wheels
|
||||
id: build_wheels
|
||||
@@ -75,23 +75,12 @@ jobs:
|
||||
wheel_name=$(basename $(ls flash-attention/dist/*.whl | head -n 1))
|
||||
echo "WHEEL_NAME=$wheel_name" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Apply auditwheel repair
|
||||
run: |
|
||||
auditwheel show flash-attention/dist/${{ steps.build_wheels.outputs.WHEEL_NAME }}
|
||||
auditwheel repair flash-attention/dist/${{ steps.build_wheels.outputs.WHEEL_NAME }} -w flash-attention/dist_manylinux
|
||||
wheel_name=$(basename $(ls flash-attention/dist_manylinux/*.whl | head -n 1))
|
||||
echo "WHEEL_NAME_MANYLINUX=$wheel_name" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Install Test
|
||||
run: |
|
||||
pip uninstall -y flash-attn > /dev/null 2>&1
|
||||
pip install --no-cache-dir flash-attention/dist/${{ steps.build_wheels.outputs.WHEEL_NAME }}
|
||||
python -c "import flash_attn; print(flash_attn.__version__)"
|
||||
|
||||
pip uninstall -y flash-attn > /dev/null 2>&1
|
||||
pip install --no-cache-dir flash-attention/dist_manylinux/${{ steps.build_wheels.outputs.WHEEL_NAME_MANYLINUX }}
|
||||
python -c "import flash_attn; print(flash_attn.__version__)"
|
||||
|
||||
- name: Upload Release Asset
|
||||
if: ${{ inputs.is-upload }}
|
||||
env:
|
||||
@@ -99,18 +88,41 @@ jobs:
|
||||
run: |
|
||||
tag_name=${{ github.ref_name }}
|
||||
wheel_path="flash-attention/dist/${{ steps.build_wheels.outputs.WHEEL_NAME }}"
|
||||
wheel_path_manylinux="flash-attention/dist_manylinux/${{ steps.build_wheels.outputs.WHEEL_NAME_MANYLINUX }}"
|
||||
|
||||
|
||||
# Check if the file exists
|
||||
if [ ! -f "$wheel_path" ]; then
|
||||
echo "Error: Wheel file not found at $wheel_path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Upload the release asset using GitHub CLI
|
||||
gh release upload "$tag_name" "$wheel_path" --clobber
|
||||
|
||||
- name: Apply auditwheel repair
|
||||
continue-on-error: true
|
||||
run: |
|
||||
pip install auditwheel
|
||||
auditwheel show flash-attention/dist/${{ steps.build_wheels.outputs.WHEEL_NAME }}
|
||||
auditwheel repair flash-attention/dist/${{ steps.build_wheels.outputs.WHEEL_NAME }} -w flash-attention/dist_manylinux
|
||||
wheel_name=$(basename $(ls flash-attention/dist_manylinux/*.whl | head -n 1))
|
||||
echo "WHEEL_NAME_MANYLINUX=$wheel_name" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Apply auditwheel repair
|
||||
if: ${{ inputs.is-upload }}
|
||||
continue-on-error: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
pip uninstall -y flash-attn > /dev/null 2>&1
|
||||
pip install --no-cache-dir flash-attention/dist_manylinux/${{ steps.build_wheels.outputs.WHEEL_NAME_MANYLINUX }}
|
||||
python -c "import flash_attn; print(flash_attn.__version__)"
|
||||
|
||||
wheel_path_manylinux="flash-attention/dist_manylinux/${{ steps.build_wheels.outputs.WHEEL_NAME_MANYLINUX }}"
|
||||
if [ ! -f "$wheel_path_manylinux" ]; then
|
||||
echo "Error: Wheel file not found at $wheel_path_manylinux"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Upload the release asset using GitHub CLI
|
||||
gh release upload "$tag_name" "$wheel_path" --clobber
|
||||
gh release upload "$tag_name" "$wheel_path_manylinux" --clobber
|
||||
|
||||
@@ -230,7 +230,7 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt install -y ninja-build clang
|
||||
pip install -U pip setuptools==75.8.0 wheel packaging psutil auditwheel
|
||||
pip install -U pip setuptools==75.8.0 wheel packaging psutil
|
||||
|
||||
- name: Build wheels
|
||||
timeout-minutes: 2160
|
||||
@@ -242,13 +242,6 @@ jobs:
|
||||
wheel_name=$(basename $(ls flash-attention/dist/*.whl | head -n 1))
|
||||
echo "WHEEL_NAME=$wheel_name" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Apply auditwheel repair
|
||||
run: |
|
||||
auditwheel show flash-attention/dist/${{ steps.build_wheels.outputs.WHEEL_NAME }}
|
||||
auditwheel repair flash-attention/dist/${{ steps.build_wheels.outputs.WHEEL_NAME }} -w flash-attention/dist_manylinux
|
||||
wheel_name=$(basename $(ls flash-attention/dist_manylinux/*.whl | head -n 1))
|
||||
echo "WHEEL_NAME_MANYLINUX=$wheel_name" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Install Test
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -256,10 +249,6 @@ jobs:
|
||||
pip install --no-cache-dir flash-attention/dist/${{ steps.build_wheels.outputs.WHEEL_NAME }}
|
||||
python -c "import flash_attn; print(flash_attn.__version__)"
|
||||
|
||||
pip uninstall -y flash-attn > /dev/null 2>&1
|
||||
pip install --no-cache-dir flash-attention/dist_manylinux/${{ steps.build_wheels.outputs.WHEEL_NAME_MANYLINUX }}
|
||||
python -c "import flash_attn; print(flash_attn.__version__)"
|
||||
|
||||
- name: Upload Release Asset
|
||||
if: ${{ inputs.is-upload }}
|
||||
shell: bash
|
||||
@@ -268,20 +257,42 @@ jobs:
|
||||
run: |
|
||||
tag_name=${{ github.ref_name }}
|
||||
wheel_path="flash-attention/dist/${{ steps.build_wheels.outputs.WHEEL_NAME }}"
|
||||
wheel_path_manylinux="flash-attention/dist_manylinux/${{ steps.build_wheels.outputs.WHEEL_NAME_MANYLINUX }}"
|
||||
|
||||
# Check if the file exists
|
||||
if [ ! -f "$wheel_path" ]; then
|
||||
echo "Error: Wheel file not found at $wheel_path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Upload the release asset using GitHub CLI
|
||||
gh release upload "$tag_name" "$wheel_path" --clobber
|
||||
|
||||
- name: Apply auditwheel repair
|
||||
continue-on-error: true
|
||||
run: |
|
||||
pip install auditwheel
|
||||
auditwheel show flash-attention/dist/${{ steps.build_wheels.outputs.WHEEL_NAME }}
|
||||
auditwheel repair flash-attention/dist/${{ steps.build_wheels.outputs.WHEEL_NAME }} -w flash-attention/dist_manylinux
|
||||
wheel_name=$(basename $(ls flash-attention/dist_manylinux/*.whl | head -n 1))
|
||||
echo "WHEEL_NAME_MANYLINUX=$wheel_name" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Apply auditwheel repair
|
||||
if: ${{ inputs.is-upload }}
|
||||
continue-on-error: true
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
pip uninstall -y flash-attn > /dev/null 2>&1
|
||||
pip install --no-cache-dir flash-attention/dist_manylinux/${{ steps.build_wheels.outputs.WHEEL_NAME_MANYLINUX }}
|
||||
python -c "import flash_attn; print(flash_attn.__version__)"
|
||||
|
||||
wheel_path_manylinux="flash-attention/dist_manylinux/${{ steps.build_wheels.outputs.WHEEL_NAME_MANYLINUX }}"
|
||||
if [ ! -f "$wheel_path_manylinux" ]; then
|
||||
echo "Error: Wheel file not found at $wheel_path_manylinux"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Upload the release asset using GitHub CLI
|
||||
gh release upload "$tag_name" "$wheel_path" --clobber
|
||||
gh release upload "$tag_name" "$wheel_path_manylinux" --clobber
|
||||
|
||||
- name: Clean up
|
||||
|
||||
@@ -8,6 +8,7 @@ def parse_wheel_filename(filename: str) -> dict | None:
|
||||
flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl
|
||||
flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
|
||||
flash_attn-2.7.4.post1+cu130torch2.9-cp310-cp310-linux_x86_64.whl
|
||||
flash_attn-2.8.3+cu128torch2.9-cp313-cp313-manylinux_2_34_x86_64.whl
|
||||
|
||||
---
|
||||
Wheel filename から情報を抽出
|
||||
@@ -43,9 +44,26 @@ def normalize_platform_name(raw: str) -> str:
|
||||
Examples:
|
||||
linux -> Linux
|
||||
linux_x86_64 -> Linux x86_64
|
||||
manylinux_2_34_x86_64 -> Manylinux 2_34 x86_64
|
||||
manylinux_2_17_aarch64 -> Manylinux 2_17 arm64
|
||||
win32 -> Windows
|
||||
amd64 -> x86_64
|
||||
"""
|
||||
# Handle manylinux format: manylinux_X_Y_ARCH -> Manylinux X_Y ARCH
|
||||
if raw.startswith("manylinux"):
|
||||
# Extract parts from manylinux_X_Y_ARCH format
|
||||
# Examples: manylinux_2_34_x86_64, manylinux_2_17_aarch64
|
||||
parts = raw.split("_")
|
||||
if len(parts) >= 4:
|
||||
# parts[0] = 'manylinux', parts[1] = X, parts[2] = Y, parts[3:] = ARCH parts
|
||||
# ARCH can contain underscores (e.g., x86_64)
|
||||
version = f"{parts[1]}_{parts[2]}"
|
||||
arch = "_".join(parts[3:]) # Join remaining parts for arch (e.g., x86_64)
|
||||
# Apply architecture normalization
|
||||
if arch == "aarch64":
|
||||
arch = "arm64"
|
||||
return f"Manylinux {version} {arch}"
|
||||
|
||||
name = raw[:1].upper() + raw[1:] # linux -> Linux
|
||||
name = name.replace("_", " ", 1) # linux_x86_64 -> Linux x86_64
|
||||
if "Win" in name:
|
||||
|
||||
+18
-6
@@ -89,10 +89,22 @@ LINUX_ARM64_MATRIX = {
|
||||
}
|
||||
|
||||
LINUX_SELF_HOSTED_MATRIX = {
|
||||
"flash-attn-version": ["2.7.4"],
|
||||
"python-version": ["3.10", "3.11", "3.12", "3.13"],
|
||||
"flash-attn-version": [
|
||||
"2.7.4",
|
||||
"2.8.3",
|
||||
],
|
||||
"python-version": [
|
||||
"3.10",
|
||||
"3.11",
|
||||
"3.12",
|
||||
"3.13",
|
||||
"3.14",
|
||||
],
|
||||
"torch-version": ["2.9.1"],
|
||||
"cuda-version": ["12.8", "13.0"],
|
||||
"cuda-version": [
|
||||
"12.8",
|
||||
"13.0",
|
||||
],
|
||||
}
|
||||
|
||||
LINUX_ARM64_SELF_HOSTED_MATRIX = {
|
||||
@@ -169,11 +181,11 @@ def main():
|
||||
"linux_arm64": False,
|
||||
# "linux_arm64": LINUX_ARM64_MATRIX,
|
||||
#
|
||||
"linux_self_hosted": False,
|
||||
# "linux_self_hosted": LINUX_SELF_HOSTED_MATRIX,
|
||||
# "linux_self_hosted": False,
|
||||
"linux_self_hosted": LINUX_SELF_HOSTED_MATRIX,
|
||||
#
|
||||
# "linux_arm64_self_hosted": False,
|
||||
"linux_arm64_self_hosted": LINUX_ARM64_SELF_HOSTED_MATRIX,
|
||||
# "linux_arm64_self_hosted": LINUX_ARM64_SELF_HOSTED_MATRIX,
|
||||
#
|
||||
"windows": False,
|
||||
# "windows": WINDOWS_MATRIX,
|
||||
|
||||
Reference in New Issue
Block a user