mirror of
https://github.com/BillyOutlast/flash-attention-prebuild-wheels-rocm.git
synced 2026-07-01 01:27:54 -04:00
Merge pull request #53 from mjun0812/feat/dynamic-matrix
Feat/dynamic matrix
This commit is contained in:
+105
-210
@@ -19,71 +19,34 @@ jobs:
|
||||
--title "${{ github.ref_name }}" \
|
||||
--notes "TBD"
|
||||
|
||||
create_matrix:
|
||||
name: Create Matrix
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
matrix: ${{ steps.create_matrix.outputs.matrix }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Create Matrix
|
||||
id: create_matrix
|
||||
run: |
|
||||
python create_matrix.py | tee /tmp/matrix.json
|
||||
echo "matrix=$(cat /tmp/matrix.json)" >> $GITHUB_OUTPUT
|
||||
|
||||
# #########################################################
|
||||
# Linux
|
||||
# #########################################################
|
||||
build_wheels_linux:
|
||||
name: Build Linux
|
||||
needs: create_releases
|
||||
needs: [create_releases, create_matrix]
|
||||
if: ${{ fromjson(needs.create_matrix.outputs.matrix).linux }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
flash-attn-version:
|
||||
- "2.6.3"
|
||||
- "2.7.4.post1"
|
||||
- "2.8.3"
|
||||
python-version:
|
||||
# - "3.9"
|
||||
- "3.10"
|
||||
- "3.11"
|
||||
- "3.12"
|
||||
- "3.13"
|
||||
torch-version:
|
||||
# - "2.5.1"
|
||||
# - "2.6.0"
|
||||
# - "2.7.1"
|
||||
# - "2.8.0"
|
||||
- "2.9.0"
|
||||
# https://developer.nvidia.com/cuda-toolkit-archive
|
||||
cuda-version:
|
||||
# - "12.4.1"
|
||||
# - "12.6.3"
|
||||
# - "12.8.1"
|
||||
# - "12.9.1"
|
||||
- "13.0.1"
|
||||
exclude:
|
||||
# torch < 2.2 does not support Python 3.12
|
||||
- python-version: "3.12"
|
||||
torch-version: "2.0.1"
|
||||
- python-version: "3.12"
|
||||
torch-version: "2.1.2"
|
||||
# torch 2.0.1 does not support CUDA 12.x
|
||||
- torch-version: "2.0.1"
|
||||
cuda-version: "12.1.1"
|
||||
- torch-version: "2.0.1"
|
||||
cuda-version: "12.4.1"
|
||||
- torch-version: "2.0.1"
|
||||
cuda-version: "12.6.3"
|
||||
- torch-version: "2.0.1"
|
||||
cuda-version: "12.8.1"
|
||||
# torch 2.6.0 does not support CUDA 12.1
|
||||
- torch-version: "2.6.0"
|
||||
cuda-version: "12.1.1"
|
||||
# torch 2.7.0 does not support CUDA 12.4
|
||||
- torch-version: "2.7.0"
|
||||
cuda-version: "12.4.1"
|
||||
# torch < 2.8 does not support CUDA 12.9
|
||||
- torch-version: "2.5.1"
|
||||
cuda-version: "12.9.1"
|
||||
- torch-version: "2.6.3"
|
||||
cuda-version: "12.9.1"
|
||||
- torch-version: "2.7.1"
|
||||
cuda-version: "12.9.1"
|
||||
# flash-attn 2.7.4 does not build in GitHub Hosted Runner
|
||||
- flash-attn-version: "2.7.4"
|
||||
# torch >= 2.9 does not support Python 3.9
|
||||
- torch-version: "2.9.0"
|
||||
python-version: "3.9"
|
||||
flash-attn-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux.flash-attn-version }}
|
||||
python-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux.python-version }}
|
||||
torch-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux.torch-version }}
|
||||
cuda-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux.cuda-version }}
|
||||
exclude: ${{ fromjson(needs.create_matrix.outputs.matrix).exclude }}
|
||||
uses: ./.github/workflows/build_linux.yml
|
||||
with:
|
||||
flash-attn-version: ${{ matrix.flash-attn-version }}
|
||||
@@ -92,159 +55,76 @@ jobs:
|
||||
cuda-version: ${{ matrix.cuda-version }}
|
||||
secrets: inherit
|
||||
|
||||
# build_wheels_linux_self_hosted:
|
||||
# name: Build Linux (self-hosted)
|
||||
# needs: create_releases
|
||||
# strategy:
|
||||
# fail-fast: false
|
||||
# matrix:
|
||||
# flash-attn-version: ["2.8.3"]
|
||||
# python-version: ["3.10", "3.11", "3.12"]
|
||||
# torch-version: ["2.5.1", "2.6.0", "2.7.1", "2.8.0"]
|
||||
# # https://developer.nvidia.com/cuda-toolkit-archive
|
||||
# cuda-version: ["12.4.1", "12.8.1", "12.9.1"]
|
||||
# exclude:
|
||||
# # torch < 2.2 does not support Python 3.12
|
||||
# - python-version: "3.12"
|
||||
# torch-version: "2.0.1"
|
||||
# - python-version: "3.12"
|
||||
# torch-version: "2.1.2"
|
||||
# # torch 2.0.1 does not support CUDA 12.x
|
||||
# - torch-version: "2.0.1"
|
||||
# cuda-version: "12.1.1"
|
||||
# - torch-version: "2.0.1"
|
||||
# cuda-version: "12.4.1"
|
||||
# - torch-version: "2.0.1"
|
||||
# cuda-version: "12.6.3"
|
||||
# - torch-version: "2.0.1"
|
||||
# cuda-version: "12.8.1"
|
||||
# # torch 2.6.0 does not support CUDA 12.1
|
||||
# - torch-version: "2.6.0"
|
||||
# cuda-version: "12.1.1"
|
||||
# # torch 2.7.0 does not support CUDA 12.4
|
||||
# - torch-version: "2.7.0"
|
||||
# cuda-version: "12.4.1"
|
||||
# # torch < 2.8 does not support CUDA 12.9
|
||||
# - torch-version: "2.5.1"
|
||||
# cuda-version: "12.9.1"
|
||||
# - torch-version: "2.6.3"
|
||||
# cuda-version: "12.9.1"
|
||||
# - torch-version: "2.7.1"
|
||||
# cuda-version: "12.9.1"
|
||||
# uses: ./.github/workflows/build_linux_self_host.yml
|
||||
# with:
|
||||
# flash-attn-version: ${{ matrix.flash-attn-version }}
|
||||
# python-version: ${{ matrix.python-version }}
|
||||
# torch-version: ${{ matrix.torch-version }}
|
||||
# cuda-version: ${{ matrix.cuda-version }}
|
||||
# secrets: inherit
|
||||
build_wheels_linux_self_hosted:
|
||||
name: Build Linux (self-hosted)
|
||||
needs: [create_releases, create_matrix]
|
||||
if: ${{ fromjson(needs.create_matrix.outputs.matrix).linux_self_hosted }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
flash-attn-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux_self_hosted.flash-attn-version }}
|
||||
python-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux_self_hosted.python-version }}
|
||||
torch-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux_self_hosted.torch-version }}
|
||||
cuda-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux_self_hosted.cuda-version }}
|
||||
exclude: ${{ fromjson(needs.create_matrix.outputs.matrix).exclude }}
|
||||
uses: ./.github/workflows/build_linux_self_host.yml
|
||||
with:
|
||||
flash-attn-version: ${{ matrix.flash-attn-version }}
|
||||
python-version: ${{ matrix.python-version }}
|
||||
torch-version: ${{ matrix.torch-version }}
|
||||
cuda-version: ${{ matrix.cuda-version }}
|
||||
secrets: inherit
|
||||
|
||||
# #########################################################
|
||||
# Windows
|
||||
# #########################################################
|
||||
# build_wheels_windows:
|
||||
# name: Build Windows
|
||||
# needs: create_releases
|
||||
# strategy:
|
||||
# fail-fast: false
|
||||
# matrix:
|
||||
# flash-attn-version: ["2.8.3"]
|
||||
# python-version: ["3.11", "3.12", "3.13"]
|
||||
# torch-version: ["2.9.0.dev20250909"]
|
||||
# # https://developer.nvidia.com/cuda-toolkit-archive
|
||||
# cuda-version: ["12.6.3", "12.8.1"]
|
||||
# exclude:
|
||||
# # torch < 2.2 does not support Python 3.12
|
||||
# - python-version: "3.12"
|
||||
# torch-version: "2.0.1"
|
||||
# - python-version: "3.12"
|
||||
# torch-version: "2.1.2"
|
||||
# # torch 2.0.1 does not support CUDA 12.x
|
||||
# - torch-version: "2.0.1"
|
||||
# cuda-version: "12.1.1"
|
||||
# - torch-version: "2.0.1"
|
||||
# cuda-version: "12.4.1"
|
||||
# - torch-version: "2.0.1"
|
||||
# cuda-version: "12.6.3"
|
||||
# - torch-version: "2.0.1"
|
||||
# cuda-version: "12.8.1"
|
||||
# # torch 2.6.0 does not support CUDA 12.1
|
||||
# - torch-version: "2.6.0"
|
||||
# cuda-version: "12.1.1"
|
||||
# # torch 2.7.0 does not support CUDA 12.4
|
||||
# - torch-version: "2.7.0"
|
||||
# cuda-version: "12.4.1"
|
||||
# # torch < 2.8 does not support CUDA 12.9
|
||||
# - torch-version: "2.5.1"
|
||||
# cuda-version: "12.9.1"
|
||||
# - torch-version: "2.6.3"
|
||||
# cuda-version: "12.9.1"
|
||||
# - torch-version: "2.7.1"
|
||||
# cuda-version: "12.9.1"
|
||||
# uses: ./.github/workflows/build_windows.yml
|
||||
# with:
|
||||
# flash-attn-version: ${{ matrix.flash-attn-version }}
|
||||
# python-version: ${{ matrix.python-version }}
|
||||
# torch-version: ${{ matrix.torch-version }}
|
||||
# cuda-version: ${{ matrix.cuda-version }}
|
||||
# secrets: inherit
|
||||
build_wheels_windows:
|
||||
name: Build Windows
|
||||
needs: [create_releases, create_matrix]
|
||||
if: ${{ fromjson(needs.create_matrix.outputs.matrix).windows }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
flash-attn-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows.flash-attn-version }}
|
||||
python-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows.python-version }}
|
||||
torch-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows.torch-version }}
|
||||
cuda-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows.cuda-version }}
|
||||
exclude: ${{ fromjson(needs.create_matrix.outputs.matrix).exclude }}
|
||||
uses: ./.github/workflows/build_windows.yml
|
||||
with:
|
||||
flash-attn-version: ${{ matrix.flash-attn-version }}
|
||||
python-version: ${{ matrix.python-version }}
|
||||
torch-version: ${{ matrix.torch-version }}
|
||||
cuda-version: ${{ matrix.cuda-version }}
|
||||
secrets: inherit
|
||||
|
||||
# build_wheels_windows_code_build:
|
||||
# name: Build Windows (AWS CodeBuild)
|
||||
# needs: create_releases
|
||||
# strategy:
|
||||
# fail-fast: false
|
||||
# matrix:
|
||||
# flash-attn-version: ["2.7.4", "2.8.2"]
|
||||
# python-version: ["3.10", "3.11", "3.12"]
|
||||
# torch-version: ["2.7.1", "2.8.0"]
|
||||
# # https://developer.nvidia.com/cuda-toolkit-archive
|
||||
# cuda-version: ["12.8.1"]
|
||||
# exclude:
|
||||
# # torch < 2.2 does not support Python 3.12
|
||||
# - python-version: "3.12"
|
||||
# torch-version: "2.0.1"
|
||||
# - python-version: "3.12"
|
||||
# torch-version: "2.1.2"
|
||||
# # torch 2.0.1 does not support CUDA 12.x
|
||||
# - torch-version: "2.0.1"
|
||||
# cuda-version: "12.1.1"
|
||||
# - torch-version: "2.0.1"
|
||||
# cuda-version: "12.4.1"
|
||||
# - torch-version: "2.0.1"
|
||||
# cuda-version: "12.6.3"
|
||||
# - torch-version: "2.0.1"
|
||||
# cuda-version: "12.8.1"
|
||||
# # torch 2.6.0 does not support CUDA 12.1
|
||||
# - torch-version: "2.6.0"
|
||||
# cuda-version: "12.1.1"
|
||||
# # torch 2.7.0 does not support CUDA 12.4
|
||||
# - torch-version: "2.7.0"
|
||||
# cuda-version: "12.4.1"
|
||||
# # torch < 2.8 does not support CUDA 12.9
|
||||
# - torch-version: "2.5.1"
|
||||
# cuda-version: "12.9.1"
|
||||
# - torch-version: "2.6.3"
|
||||
# cuda-version: "12.9.1"
|
||||
# - torch-version: "2.7.1"
|
||||
# cuda-version: "12.9.1"
|
||||
# uses: ./.github/workflows/build_windows_code_build.yml
|
||||
# with:
|
||||
# flash-attn-version: ${{ matrix.flash-attn-version }}
|
||||
# python-version: ${{ matrix.python-version }}
|
||||
# torch-version: ${{ matrix.torch-version }}
|
||||
# cuda-version: ${{ matrix.cuda-version }}
|
||||
# secrets: inherit
|
||||
build_wheels_windows_code_build:
|
||||
name: Build Windows (AWS CodeBuild)
|
||||
needs: [create_releases, create_matrix]
|
||||
if: ${{ fromjson(needs.create_matrix.outputs.matrix).windows_code_build }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
flash-attn-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows_code_build.flash-attn-version }}
|
||||
python-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows_code_build.python-version }}
|
||||
torch-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows_code_build.torch-version }}
|
||||
cuda-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows_code_build.cuda-version }}
|
||||
exclude: ${{ fromjson(needs.create_matrix.outputs.matrix).exclude }}
|
||||
uses: ./.github/workflows/build_windows_code_build.yml
|
||||
with:
|
||||
flash-attn-version: ${{ matrix.flash-attn-version }}
|
||||
python-version: ${{ matrix.python-version }}
|
||||
torch-version: ${{ matrix.torch-version }}
|
||||
cuda-version: ${{ matrix.cuda-version }}
|
||||
secrets: inherit
|
||||
|
||||
update_release_notes:
|
||||
name: Update Release Notes
|
||||
needs:
|
||||
- build_wheels_linux
|
||||
# - build_wheels_linux_self_hosted
|
||||
# - build_wheels_windows
|
||||
# - build_wheels_windows_code_build
|
||||
permissions:
|
||||
contents: write
|
||||
- build_wheels_linux_self_hosted
|
||||
- build_wheels_windows
|
||||
- build_wheels_windows_code_build
|
||||
if: always()
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
@@ -265,23 +145,38 @@ jobs:
|
||||
python create_release_note.py /tmp/assets.json > /tmp/release_notes.md
|
||||
gh release edit "${{ github.ref_name }}" --notes-file /tmp/release_notes.md
|
||||
|
||||
- name: Update README history and packages
|
||||
update_docs:
|
||||
name: Update Docs
|
||||
needs:
|
||||
- build_wheels_linux
|
||||
- build_wheels_linux_self_hosted
|
||||
- build_wheels_windows
|
||||
- build_wheels_windows_code_build
|
||||
permissions:
|
||||
contents: write
|
||||
if: always()
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Update release history and packages section in README.md
|
||||
run: |
|
||||
cat /tmp/release_notes.md | python insert_history.py \
|
||||
--notes - \
|
||||
gh release view "${{ github.ref_name }}" --json assets > /tmp/assets.json
|
||||
python create_release_history.py \
|
||||
--assets /tmp/assets.json \
|
||||
--tag "${{ github.ref_name }}" \
|
||||
--repo "${{ github.repository }}"
|
||||
python generate_packages_table.py --update-readme
|
||||
--repo "${{ github.repository }}" \
|
||||
--output docs/release_history.md
|
||||
python insert_packages_to_readme.py --assets /tmp/assets.json --update
|
||||
|
||||
- name: Commit and push README updates
|
||||
env:
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
- name: Commit and push docs updates
|
||||
run: |
|
||||
git config --global user.name "github-actions[bot]"
|
||||
git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com"
|
||||
if git diff --quiet; then
|
||||
echo "No README updates to commit."
|
||||
echo "No docs updates to commit."
|
||||
exit 0
|
||||
fi
|
||||
git commit -am "docs: update README for ${{ github.ref_name }}"
|
||||
git commit -am "docs: update docs for ${{ github.ref_name }}"
|
||||
git push origin HEAD:"${DEFAULT_BRANCH}"
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
.DS_Store
|
||||
.env
|
||||
__pycache__/
|
||||
.ruff_cache/
|
||||
@@ -0,0 +1,57 @@
|
||||
import re
|
||||
|
||||
|
||||
def parse_wheel_filename(filename: str) -> dict | None:
|
||||
"""
|
||||
Extract information from a wheel filename.
|
||||
Examples:
|
||||
flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl
|
||||
flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
|
||||
flash_attn-2.7.4.post1+cu130torch2.9-cp310-cp310-linux_x86_64.whl
|
||||
|
||||
---
|
||||
Wheel filename から情報を抽出
|
||||
例: flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl
|
||||
flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
|
||||
flash_attn-2.7.4.post1+cu130torch2.9-cp310-cp310-linux_x86_64.whl
|
||||
"""
|
||||
# Flash Attention wheelのパターンに合わせて正規表現を調整
|
||||
# PyTorchバージョンはマイナーバージョン1桁の形式も対応 (例: torch2.9)
|
||||
# post1 のようなバージョンサフィックスにも対応 (例: 2.7.4.post1)
|
||||
pattern = (
|
||||
r"flash_attn-(\d+\.\d+\.\d+(?:\.[a-z0-9]+)?)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(\w+)\.whl"
|
||||
)
|
||||
match = re.match(pattern, filename)
|
||||
|
||||
if match:
|
||||
flash_version = match.group(1)
|
||||
cuda_version = f"{match.group(2)[:2]}.{match.group(2)[2:]}" # 130 -> 13.0
|
||||
torch_version = match.group(3)
|
||||
python_version = f"{match.group(4)[:1]}.{match.group(4)[1:]}" # 310 -> 3.10
|
||||
platform = match.group(5) # linux_x86_64, win32など
|
||||
|
||||
return {
|
||||
"flash_version": flash_version,
|
||||
"cuda_version": cuda_version,
|
||||
"torch_version": torch_version,
|
||||
"python_version": python_version,
|
||||
"platform": platform,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def normalize_platform_name(raw: str) -> str:
|
||||
"""Platform name normalization
|
||||
Examples:
|
||||
linux -> Linux
|
||||
linux_x86_64 -> Linux x86_64
|
||||
win32 -> Windows
|
||||
amd64 -> x86_64
|
||||
"""
|
||||
name = raw[:1].upper() + raw[1:] # linux -> Linux
|
||||
name = name.replace("_", " ", 1) # linux_x86_64 -> Linux x86_64
|
||||
if "Win" in name:
|
||||
name = name.replace("Win", "Windows")
|
||||
if "amd64" in name:
|
||||
name = name.replace("amd64", "x86_64")
|
||||
return name
|
||||
@@ -1,52 +1,17 @@
|
||||
"""Update the History section in README.md from release notes or assets."""
|
||||
"""Update the History section in README.md from assets."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable
|
||||
|
||||
|
||||
WHEEL_PATTERN = re.compile(
|
||||
r"flash_attn-(\d+\.\d+\.\d+)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(\w+)\.whl"
|
||||
)
|
||||
from common import normalize_platform_name, parse_wheel_filename
|
||||
|
||||
|
||||
def parse_wheel_filename(filename: str) -> Dict[str, str] | None:
|
||||
match = WHEEL_PATTERN.match(filename)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
flash_version = match.group(1)
|
||||
cuda_digits = match.group(2)
|
||||
torch_version = match.group(3)
|
||||
python_digits = match.group(4)
|
||||
platform = match.group(5)
|
||||
|
||||
cuda_version = f"{cuda_digits[:2]}.{cuda_digits[2:]}"
|
||||
python_version = f"{python_digits[:1]}.{python_digits[1:]}"
|
||||
|
||||
return {
|
||||
"flash_version": flash_version,
|
||||
"cuda_version": cuda_version,
|
||||
"torch_version": torch_version,
|
||||
"python_version": python_version,
|
||||
"platform": platform,
|
||||
}
|
||||
|
||||
|
||||
def normalize_platform_name(raw: str) -> str:
|
||||
name = raw[:1].upper() + raw[1:]
|
||||
name = name.replace("_", " ", 1)
|
||||
if "Win" in name:
|
||||
name = name.replace("Win", "Windows")
|
||||
if "amd64" in name:
|
||||
name = name.replace("amd64", "x86_64")
|
||||
return name
|
||||
|
||||
|
||||
def collect_versions(assets: Iterable[Dict[str, str]]) -> Dict[str, Dict[str, set[str]]]:
|
||||
def collect_versions(
|
||||
assets: Iterable[Dict[str, str]],
|
||||
) -> Dict[str, Dict[str, set[str]]]:
|
||||
aggregated: Dict[str, Dict[str, set[str]]] = {}
|
||||
for asset in assets:
|
||||
info = parse_wheel_filename(asset.get("name", ""))
|
||||
@@ -109,11 +74,6 @@ def render_body_from_aggregated(aggregated: Dict[str, Dict[str, set[str]]]) -> s
|
||||
return "\n".join(body_lines).strip()
|
||||
|
||||
|
||||
def convert_release_notes_to_body(notes_text: str) -> str:
|
||||
converted = re.sub(r"^## ", "#### ", notes_text, flags=re.MULTILINE)
|
||||
return converted.strip()
|
||||
|
||||
|
||||
def build_history_section(tag: str, repo: str, body: str) -> str:
|
||||
release_url = f"https://github.com/{repo}/releases/tag/{tag}"
|
||||
lines = [f"### {tag}", "", f"[Release]({release_url})", "", body.strip()]
|
||||
@@ -121,7 +81,9 @@ def build_history_section(tag: str, repo: str, body: str) -> str:
|
||||
|
||||
|
||||
def remove_existing_section(content: str, tag: str) -> str:
|
||||
pattern = re.compile(rf"^### {re.escape(tag)}\n.*?(?=^### |\Z)", re.MULTILINE | re.DOTALL)
|
||||
pattern = re.compile(
|
||||
rf"^### {re.escape(tag)}\n.*?(?=^### |\Z)", re.MULTILINE | re.DOTALL
|
||||
)
|
||||
return re.sub(pattern, "", content)
|
||||
|
||||
|
||||
@@ -137,47 +99,22 @@ def insert_history_section(content: str, section: str) -> str:
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Update README.md History section")
|
||||
parser.add_argument("--assets", type=Path, help="JSON file from gh release view")
|
||||
parser.add_argument(
|
||||
"--notes",
|
||||
help="Release notes markdown file path or '-' to read from stdin",
|
||||
"--assets", type=Path, required=True, help="JSON file from gh release view"
|
||||
)
|
||||
parser.add_argument("--tag", required=True, help="Release tag name")
|
||||
parser.add_argument("--repo", required=True, help="Repository in owner/name format")
|
||||
parser.add_argument(
|
||||
"--readme",
|
||||
type=Path,
|
||||
default=Path("README.md"),
|
||||
help="Path to README.md",
|
||||
)
|
||||
parser.add_argument("--output", type=Path, required=True, help="Output file path")
|
||||
args = parser.parse_args()
|
||||
|
||||
history_body: str
|
||||
|
||||
if args.notes:
|
||||
if args.notes == "-":
|
||||
notes_text = sys.stdin.read()
|
||||
else:
|
||||
notes_path = Path(args.notes)
|
||||
if not notes_path.exists():
|
||||
raise FileNotFoundError(f"Notes file not found: {notes_path}")
|
||||
notes_text = notes_path.read_text(encoding="utf-8")
|
||||
|
||||
history_body = convert_release_notes_to_body(notes_text)
|
||||
else:
|
||||
if not args.assets:
|
||||
raise ValueError("Either --notes or --assets must be provided")
|
||||
if not args.assets.exists():
|
||||
raise FileNotFoundError(f"Assets file not found: {args.assets}")
|
||||
|
||||
data = json.loads(args.assets.read_text(encoding="utf-8"))
|
||||
assets = data.get("assets", [])
|
||||
aggregated = collect_versions(assets)
|
||||
history_body = render_body_from_aggregated(aggregated)
|
||||
data = json.loads(args.assets.read_text(encoding="utf-8"))
|
||||
assets = data.get("assets", [])
|
||||
aggregated = collect_versions(assets)
|
||||
history_body = render_body_from_aggregated(aggregated)
|
||||
|
||||
section = build_history_section(args.tag, args.repo, history_body)
|
||||
|
||||
content = args.readme.read_text(encoding="utf-8")
|
||||
content = args.output.read_text(encoding="utf-8")
|
||||
stripped = remove_existing_section(content, args.tag)
|
||||
updated = insert_history_section(stripped, section)
|
||||
|
||||
@@ -185,13 +122,9 @@ def main() -> None:
|
||||
print("No changes in README.md")
|
||||
return
|
||||
|
||||
args.readme.write_text(updated, encoding="utf-8")
|
||||
print(f"Inserted history for {args.tag} into {args.readme}")
|
||||
args.output.write_text(updated, encoding="utf-8")
|
||||
print(f"Inserted history for {args.tag} into README.md")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except Exception as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
main()
|
||||
+2
-37
@@ -1,37 +1,8 @@
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def parse_wheel_filename(filename):
|
||||
"""
|
||||
Wheel filename から情報を抽出
|
||||
例: flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl
|
||||
flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
|
||||
"""
|
||||
# Flash Attention wheelのパターンに合わせて正規表現を調整
|
||||
# PyTorchバージョンはパッチバージョンなしの形式 (例: torch2.5)
|
||||
pattern = (
|
||||
r"flash_attn-(\d+\.\d+\.\d+)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(\w+)\.whl"
|
||||
)
|
||||
match = re.match(pattern, filename)
|
||||
|
||||
if match:
|
||||
flash_version = match.group(1)
|
||||
cuda_version = f"{match.group(2)[:2]}.{match.group(2)[2:]}" # 124 -> 12.4
|
||||
torch_version = match.group(3)
|
||||
python_version = f"{match.group(4)[:1]}.{match.group(4)[1:]}" # 311 -> 3.11
|
||||
platform = match.group(5) # linux, win32など
|
||||
|
||||
return {
|
||||
"flash_version": flash_version,
|
||||
"cuda_version": cuda_version,
|
||||
"torch_version": torch_version,
|
||||
"python_version": python_version,
|
||||
"platform": platform,
|
||||
}
|
||||
return None
|
||||
from common import normalize_platform_name, parse_wheel_filename
|
||||
|
||||
|
||||
def generate_release_notes_from_assets(assets_info: dict):
|
||||
@@ -74,13 +45,7 @@ def generate_release_notes_from_assets(assets_info: dict):
|
||||
if any(len(data[key]) == 0 for key in data):
|
||||
continue
|
||||
|
||||
platform_name = platform_name[:1].upper() + platform_name[1:]
|
||||
platform_name = platform_name.replace("_", " ", 1)
|
||||
|
||||
if "Win" in platform_name:
|
||||
platform_name = platform_name.replace("Win", "Windows")
|
||||
if "amd64" in platform_name:
|
||||
platform_name = platform_name.replace("amd64", "x86_64")
|
||||
platform_name = normalize_platform_name(platform_name)
|
||||
|
||||
notes.append(f"## {platform_name}")
|
||||
notes.append("")
|
||||
|
||||
@@ -0,0 +1,279 @@
|
||||
## History
|
||||
|
||||
### v0.4.18
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.18)
|
||||
|
||||
#### Linux x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --- | --- | --- | --- |
|
||||
| 2.6.3, 2.8.3 | 3.10, 3.11, 3.12, 3.13 | 2.9 | 13.0 |
|
||||
|
||||
### v0.4.17
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.17)
|
||||
|
||||
#### Linux x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --- | --- | --- | --- |
|
||||
| 2.6.3, 2.8.3 | 3.10, 3.11, 3.12, 3.13 | 2.9 | 12.6, 12.8 |
|
||||
|
||||
### v0.4.16
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.16)
|
||||
|
||||
#### Linux x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --- | --- | --- | --- |
|
||||
| 2.6.3, 2.8.3 | 3.9 | 2.5, 2.6, 2.7, 2.8 | 12.4, 12.6 |
|
||||
|
||||
### v0.4.15
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.15)
|
||||
|
||||
#### Linux x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --- | --- | --- | --- |
|
||||
| 2.8.3 | 3.11, 3.12, 3.13 | 2.9 | 12.6, 12.8 |
|
||||
|
||||
#### Windows x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --- | --- | --- | --- |
|
||||
| 2.8.3 | 3.11, 3.12, 3.13 | 2.9 | 12.6 |
|
||||
|
||||
### v0.4.12
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.12)
|
||||
|
||||
#### Linux x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --- | --- | --- | --- |
|
||||
| 2.8.3 | 3.13 | 2.6, 2.7, 2.8 | 12.4, 12.6, 12.8, 12.9 |
|
||||
|
||||
#### Windows x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --- | --- | --- | --- |
|
||||
| 2.8.2 | 3.13 | 2.6, 2.7, 2.8 | 12.4, 12.6 |
|
||||
|
||||
### v0.4.11
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.11)
|
||||
|
||||
#### Linux x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --- | --- | --- | --- |
|
||||
| 2.8.3 | 3.10, 3.11, 3.12 | 2.5, 2.6, 2.7, 2.8 | 12.4, 12.6, 12.8, 12.9 |
|
||||
|
||||
### v0.4.10
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.10)
|
||||
|
||||
#### Windows x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --- | --- | --- | --- |
|
||||
| 2.7.4, 2.8.2 | 3.10, 3.11, 3.12 | 2.7, 2.8 | 12.8 |
|
||||
|
||||
### v0.4.9
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.9)
|
||||
|
||||
#### Windows x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --- | --- | --- | --- |
|
||||
| 2.7.4 | 3.11 | 2.7 | 12.8 |
|
||||
|
||||
### v0.3.18
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.3.18)
|
||||
|
||||
#### Linux x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --- | --- | --- | --- |
|
||||
| 2.7.4 | 3.10, 3.11, 3.12 | 2.5, 2.6, 2.7, 2.8 | 12.4, 12.8, 12.9 |
|
||||
|
||||
### v0.3.14
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.3.14)
|
||||
|
||||
#### Linux x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --------------- | ---------------- | -------------------------- | ---------------------- |
|
||||
| 2.6.3, 2.8.2 | 3.10, 3.11, 3.12 | 2.5.1, 2.6.0, 2.7.1, 2.8.0 | 12.4.1, 12.8.1, 12.9.1 |
|
||||
|
||||
### v0.3.13
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.3.13)
|
||||
|
||||
#### Linux x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --------------- | ---------------- | -------------------------- | ------ |
|
||||
| 2.8.1 | 3.10, 3.11, 3.12 | 2.4.1, 2.5.1, 2.6.0, 2.7.1 | 12.8.1 |
|
||||
|
||||
### v0.3.12
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.3.12)
|
||||
|
||||
#### Linux x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --------------- | ---------------- | -------------------------- | -------------- |
|
||||
| 2.8.0 | 3.10, 3.11, 3.12 | 2.4.1, 2.5.1, 2.6.0, 2.7.1 | 12.4.1, 12.8.1 |
|
||||
|
||||
### v0.3.10
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.3.10)
|
||||
|
||||
#### Linux x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --------------- | ---------------- | ------- | ------ |
|
||||
| 2.7.4 | 3.10, 3.11, 3.12 | 2.7.1 | 12.8.1 |
|
||||
|
||||
### v0.3.9
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.3.9)
|
||||
|
||||
#### Linux x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| ------------------- | ---------------- | ------- | ------ |
|
||||
| 2.4.3, 2.5.9, 2.6.3 | 3.10, 3.11, 3.12 | 2.7.1 | 12.8.1 |
|
||||
|
||||
#### Windows x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| ------------------- | ---------------- | ------------------- | ------ |
|
||||
| 2.5.9, 2.6.3, 2.7.4 | 3.10, 3.11, 3.12 | 2.4.1, 2.5.1, 2.6.0 | 12.4.1 |
|
||||
|
||||
> [!IMPORTANT]
|
||||
> ⚠️ Building flash-attn v2.7.4 with CUDA 12.8 on Windows cannot be completed because of GitHub Actions’ processing-time limits. In the future, I plan to add a self-hosted Windows runner to resolve this issue.
|
||||
|
||||
### v0.3.1
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.3.1)
|
||||
|
||||
#### Windows x86_64
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --------------- | ------ | ------- | ------ |
|
||||
| 2.6.3 | 3.11 | 2.6.0 | 12.6.3 |
|
||||
|
||||
From this version, Wheels for Windows are released.
|
||||
However, we are waiting for a report on how it works because we have not tested it enough.
|
||||
|
||||
### v0.2.1
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.2.1)
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| -------------------------- | ---------------- | ----------------- | ------ |
|
||||
| 2.4.3, 2.5.9, 2.6.3, 2.7.4 | 3.10, 3.11, 3.12 | 2.8.0.dev20250523 | 12.8.1 |
|
||||
|
||||
### v0.2.0
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.2.0)
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| ------------------- | ---------------- | ----------------- | ------ |
|
||||
| 2.4.3, 2.5.9, 2.6.3 | 3.10, 3.11, 3.12 | 2.8.0.dev20250523 | 12.8.1 |
|
||||
|
||||
### v0.1.0
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.1.0)
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| -------------------------- | ---------------- | ------- | ------ |
|
||||
| 2.4.3, 2.5.9, 2.6.3, 2.7.4 | 3.10, 3.11, 3.12 | 2.7.0 | 12.8.1 |
|
||||
|
||||
v2.7.4 and v2.7.4.post1 are the same version.
|
||||
|
||||
From this release, self-hosted runners are used for building some wheels.
|
||||
|
||||
### v0.0.9
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.9)
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| ------------------- | ---------------- | ------- | ------ |
|
||||
| 2.4.3, 2.5.9, 2.6.3 | 3.10, 3.11, 3.12 | 2.7.0 | 12.8.1 |
|
||||
|
||||
### v0.0.8
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.8)
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| -------------------------------- | ---------------- | -------------------------- | ---------------------- |
|
||||
| 2.4.3, 2.5.9, 2.6.3, 2.7.4.post1 | 3.10, 3.11, 3.12 | 2.4.1, 2.5.1, 2.6.0, 2.7.0 | 11.8.0, 12.4.1, 12.6.3 |
|
||||
|
||||
### v0.0.7
|
||||
|
||||
Skip for experimental reasons.
|
||||
|
||||
### v0.0.6
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.6)
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| -------------------------------- | ---------------- | --------------------------------- | -------------- |
|
||||
| 2.4.3, 2.5.9, 2.6.3, 2.7.4.post1 | 3.10, 3.11, 3.12 | 2.2.2, 2.3.1, 2.4.1, 2.5.1, 2.6.0 | 12.4.1, 12.6.3 |
|
||||
|
||||
### v0.0.5
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.5)
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| ------------------ | ---------------- | ----------------------------------------------- | -------------- |
|
||||
| 2.6.3, 2.7.4.post1 | 3.10, 3.11, 3.12 | 2.0.1, 2.1.2, 2.2.2, 2.3.1, 2.4.1, 2.5.1, 2.6.0 | 12.4.1, 12.6.3 |
|
||||
|
||||
### v0.0.4
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.4)
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --------------- | ---------------- | ---------------------------------------- | ---------------------- |
|
||||
| 2.7.3 | 3.10, 3.11, 3.12 | 2.0.1, 2.1.2, 2.2.2, 2.3.1, 2.4.1, 2.5.1 | 11.8.0, 12.1.1, 12.4.1 |
|
||||
|
||||
### v0.0.3
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.3)
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| --------------- | ---------------- | ---------------------------------------- | ---------------------- |
|
||||
| 2.7.2.post1 | 3.10, 3.11, 3.12 | 2.0.1, 2.1.2, 2.2.2, 2.3.1, 2.4.1, 2.5.1 | 11.8.0, 12.1.1, 12.4.1 |
|
||||
|
||||
### v0.0.2
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.2)
|
||||
|
||||
| Flash-Attention | Python | PyTorch | CUDA |
|
||||
| -------------------------------- | ---------------- | ---------------------------------------- | ---------------------- |
|
||||
| 2.4.3, 2.5.6, 2.6.3, 2.7.0.post2 | 3.10, 3.11, 3.12 | 2.0.1, 2.1.2, 2.2.2, 2.3.1, 2.4.1, 2.5.1 | 11.8.0, 12.1.1, 12.4.1 |
|
||||
|
||||
### v0.0.1
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.1)
|
||||
|
||||
| flash-attention | Python | PyTorch | CUDA |
|
||||
| --------------------------------- | ---------------- | ---------------------------------------- | ---------------------- |
|
||||
| 1.0.9, 2.4.3, 2.5.6, 2.5.9, 2.6.3 | 3.10, 3.11, 3.12 | 2.0.1, 2.1.2, 2.2.2, 2.3.1, 2.4.1, 2.5.0 | 11.8.0, 12.1.1, 12.4.1 |
|
||||
|
||||
### v0.0.0
|
||||
|
||||
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.0)
|
||||
|
||||
| flash-attention | Python | PyTorch | CUDA |
|
||||
| -------------------------- | ---------- | ---------------------------------------- | ---------------------- |
|
||||
| 2.4.3, 2.5.6, 2.5.9, 2.6.3 | 3.11, 3.12 | 2.0.1, 2.1.2, 2.2.2, 2.3.1, 2.4.1, 2.5.0 | 11.8.0, 12.1.1, 12.4.1 |
|
||||
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
Fetch all assets from all GitHub releases and save to assets.json
|
||||
|
||||
Usage:
|
||||
python fetch_all_assets.py
|
||||
python fetch_all_assets.py --output all_assets.json
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def get_github_token():
|
||||
"""Get GitHub token from environment variable."""
|
||||
token = os.environ.get("GITHUB_TOKEN")
|
||||
if not token:
|
||||
print(
|
||||
"Warning: GITHUB_TOKEN not set. API rate limit will be restricted.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return token
|
||||
|
||||
|
||||
def fetch_all_releases(repo: str, token: str | None = None) -> list[dict]:
|
||||
"""Fetch all releases from a GitHub repository."""
|
||||
headers = {}
|
||||
if token:
|
||||
headers["Authorization"] = f"token {token}"
|
||||
headers["Accept"] = "application/vnd.github.v3+json"
|
||||
|
||||
all_releases = []
|
||||
page = 1
|
||||
per_page = 100
|
||||
|
||||
while True:
|
||||
url = f"https://api.github.com/repos/{repo}/releases"
|
||||
params = {"page": page, "per_page": per_page}
|
||||
|
||||
print(f"Fetching releases page {page}...", file=sys.stderr)
|
||||
response = requests.get(url, headers=headers, params=params, timeout=30)
|
||||
|
||||
if response.status_code != 200:
|
||||
print(
|
||||
f"Error fetching releases: {response.status_code} - {response.text}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
break
|
||||
|
||||
releases = response.json()
|
||||
if not releases:
|
||||
break
|
||||
|
||||
all_releases.extend(releases)
|
||||
print(f" Found {len(releases)} releases on page {page}", file=sys.stderr)
|
||||
|
||||
# Check if there are more pages
|
||||
if len(releases) < per_page:
|
||||
break
|
||||
|
||||
page += 1
|
||||
time.sleep(0.5) # Rate limiting
|
||||
|
||||
return all_releases
|
||||
|
||||
|
||||
def extract_assets_from_releases(releases: list[dict]) -> list[dict]:
|
||||
"""Extract all wheel assets from releases."""
|
||||
all_assets = []
|
||||
|
||||
for release in releases:
|
||||
tag = release.get("tag_name", "")
|
||||
print(f"Processing release {tag}...", file=sys.stderr)
|
||||
|
||||
for asset in release.get("assets", []):
|
||||
name = asset.get("name", "")
|
||||
|
||||
# Only include .whl files
|
||||
if not name.endswith(".whl"):
|
||||
continue
|
||||
|
||||
# Extract relevant asset information
|
||||
asset_info = {
|
||||
"name": name,
|
||||
"url": asset.get("browser_download_url", ""),
|
||||
"size": asset.get("size", 0),
|
||||
"downloadCount": asset.get("download_count", 0),
|
||||
"createdAt": asset.get("created_at", ""),
|
||||
"updatedAt": asset.get("updated_at", ""),
|
||||
"id": asset.get("node_id", ""),
|
||||
"apiUrl": asset.get("url", ""),
|
||||
"contentType": asset.get("content_type", ""),
|
||||
"state": asset.get("state", ""),
|
||||
"label": asset.get("label", ""),
|
||||
}
|
||||
|
||||
all_assets.append(asset_info)
|
||||
|
||||
print(f"\nTotal assets found: {len(all_assets)}", file=sys.stderr)
|
||||
return all_assets
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Fetch all assets from all GitHub releases"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo",
|
||||
type=str,
|
||||
default="mjun0812/flash-attention-prebuild-wheels",
|
||||
help="GitHub repository (default: mjun0812/flash-attention-prebuild-wheels)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="assets.json",
|
||||
help="Output file path (default: assets.json)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
token = get_github_token()
|
||||
|
||||
# Fetch all releases
|
||||
print(f"Fetching all releases from {args.repo}...", file=sys.stderr)
|
||||
releases = fetch_all_releases(args.repo, token)
|
||||
print(f"Total releases found: {len(releases)}\n", file=sys.stderr)
|
||||
|
||||
# Extract assets
|
||||
assets = extract_assets_from_releases(releases)
|
||||
|
||||
# Save to file
|
||||
output_path = Path(args.output)
|
||||
output_data = {"assets": assets}
|
||||
|
||||
with output_path.open("w", encoding="utf-8") as f:
|
||||
json.dump(output_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\nSaved {len(assets)} assets to {output_path}", file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,398 +0,0 @@
|
||||
"""
|
||||
python generate_packages_table.py --update-readme
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def parse_numeric_version(text: str) -> tuple:
|
||||
"""Extract numeric version tuple for sorting."""
|
||||
nums = re.findall(r"\d+", text)
|
||||
return tuple(int(n) for n in nums)
|
||||
|
||||
|
||||
def extract_packages_from_history(text: str) -> list[dict]:
|
||||
"""Extract package information from History section."""
|
||||
lines = text.splitlines()
|
||||
|
||||
# Find start of History section
|
||||
in_history = False
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip().startswith("## ") and "History" in line:
|
||||
in_history = True
|
||||
lines = lines[i:]
|
||||
break
|
||||
|
||||
if not in_history:
|
||||
return []
|
||||
|
||||
packages = []
|
||||
current_release_url = None
|
||||
current_os = "Linux x86_64" # default
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i].strip()
|
||||
|
||||
# Reset on new version
|
||||
if line.startswith("### "):
|
||||
current_release_url = None
|
||||
current_os = "Linux x86_64"
|
||||
|
||||
# Capture Release link
|
||||
elif "[Release](" in line:
|
||||
match = re.search(r"\[Release\]\(([^)]+)\)", line)
|
||||
if match:
|
||||
current_release_url = match.group(1)
|
||||
|
||||
# Capture OS heading
|
||||
elif line.startswith("#### "):
|
||||
current_os = line[5:].strip() or "Linux x86_64"
|
||||
|
||||
# Process table
|
||||
elif line.startswith("| Flash-Attention") or line.startswith(
|
||||
"|Flash-Attention"
|
||||
):
|
||||
# Skip header and separator
|
||||
i += 2
|
||||
|
||||
# Process table rows
|
||||
while i < len(lines):
|
||||
row_line = lines[i].strip()
|
||||
if not row_line.startswith("|") or not row_line:
|
||||
break
|
||||
|
||||
# Parse table row
|
||||
cells = [c.strip() for c in row_line.split("|")]
|
||||
cells = [c for c in cells if c] # Remove empty cells
|
||||
|
||||
if len(cells) >= 4:
|
||||
fa_versions = [v.strip() for v in cells[0].split(",") if v.strip()]
|
||||
py_versions = [v.strip() for v in cells[1].split(",") if v.strip()]
|
||||
pt_versions = [v.strip() for v in cells[2].split(",") if v.strip()]
|
||||
cu_versions = [v.strip() for v in cells[3].split(",") if v.strip()]
|
||||
|
||||
# Generate all combinations
|
||||
for fa, py, pt, cu in itertools.product(
|
||||
fa_versions, py_versions, pt_versions, cu_versions
|
||||
):
|
||||
packages.append(
|
||||
{
|
||||
"Flash-Attention": fa,
|
||||
"Python": py,
|
||||
"PyTorch": pt,
|
||||
"CUDA": cu,
|
||||
"OS": current_os,
|
||||
"package": current_release_url,
|
||||
}
|
||||
)
|
||||
|
||||
i += 1
|
||||
continue
|
||||
|
||||
i += 1
|
||||
|
||||
return packages
|
||||
|
||||
|
||||
def sort_packages(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Sort packages with custom priority."""
|
||||
|
||||
# Add sorting keys
|
||||
# Flash-Attention: descending order (newer versions first)
|
||||
df["fa_sort"] = df["Flash-Attention"].apply(
|
||||
lambda x: tuple(-v for v in parse_numeric_version(x))
|
||||
)
|
||||
df["os_sort"] = df["OS"].str.lower()
|
||||
# Python, PyTorch, CUDA: descending order (newer versions first)
|
||||
df["py_sort"] = df["Python"].apply(
|
||||
lambda x: tuple(-v for v in parse_numeric_version(x))
|
||||
)
|
||||
df["pt_sort"] = df["PyTorch"].apply(
|
||||
lambda x: tuple(-v for v in parse_numeric_version(x))
|
||||
)
|
||||
df["cu_sort"] = df["CUDA"].apply(
|
||||
lambda x: tuple(-v for v in parse_numeric_version(x))
|
||||
)
|
||||
|
||||
# Package sort: extract version from URL, newer first
|
||||
def package_sort_key(url):
|
||||
if pd.isna(url) or not url:
|
||||
return (1, tuple()) # No URL comes last
|
||||
|
||||
tag_match = re.search(r"/tag/([^/]+)$", str(url))
|
||||
if not tag_match:
|
||||
return (1, tuple())
|
||||
|
||||
tag = tag_match.group(1)
|
||||
version_tuple = parse_numeric_version(tag)
|
||||
return (0, tuple(-v for v in version_tuple)) # Negate for descending
|
||||
|
||||
df["pkg_sort"] = df["package"].apply(package_sort_key)
|
||||
|
||||
# Sort by priority: Flash-Attention > OS > Python > PyTorch > CUDA > package
|
||||
df_sorted = df.sort_values(
|
||||
["fa_sort", "os_sort", "py_sort", "pt_sort", "cu_sort", "pkg_sort"]
|
||||
)
|
||||
|
||||
# Drop sorting columns
|
||||
return df_sorted.drop(
|
||||
columns=["fa_sort", "os_sort", "py_sort", "pt_sort", "cu_sort", "pkg_sort"]
|
||||
)
|
||||
|
||||
|
||||
def merge_duplicate_rows(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Merge rows with duplicate Flash-Attention, Python, PyTorch, CUDA, OS values."""
|
||||
# Group by all columns except 'package'
|
||||
group_cols = ["Flash-Attention", "Python", "PyTorch", "CUDA", "OS"]
|
||||
|
||||
def combine_packages(group):
|
||||
# Get unique non-null packages
|
||||
packages = [pkg for pkg in group["package"].dropna().unique() if pkg]
|
||||
|
||||
# Take the first row as base
|
||||
result = group.iloc[0].copy()
|
||||
|
||||
# Combine packages into a list
|
||||
result["package"] = packages if packages else [None]
|
||||
|
||||
return result
|
||||
|
||||
# Group and combine
|
||||
merged_df = df.groupby(group_cols, as_index=False).apply(
|
||||
combine_packages, include_groups=False
|
||||
)
|
||||
|
||||
# Reset index to clean up
|
||||
merged_df = merged_df.reset_index(drop=True)
|
||||
|
||||
return merged_df
|
||||
|
||||
|
||||
def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
|
||||
"""Generate markdown tables grouped by OS and Flash-Attention version."""
|
||||
if df.empty:
|
||||
return ""
|
||||
|
||||
all_sections = []
|
||||
|
||||
# Group by OS and sort each group
|
||||
for os_name in sorted(df["OS"].unique()):
|
||||
os_df = df[df["OS"] == os_name].copy()
|
||||
|
||||
# Re-sort within each OS group to ensure Flash-Attention is in descending order
|
||||
os_df["fa_sort"] = os_df["Flash-Attention"].apply(
|
||||
lambda x: tuple(-v for v in parse_numeric_version(x))
|
||||
)
|
||||
os_df["py_sort"] = os_df["Python"].apply(
|
||||
lambda x: tuple(-v for v in parse_numeric_version(x))
|
||||
)
|
||||
os_df["pt_sort"] = os_df["PyTorch"].apply(
|
||||
lambda x: tuple(-v for v in parse_numeric_version(x))
|
||||
)
|
||||
os_df["cu_sort"] = os_df["CUDA"].apply(
|
||||
lambda x: tuple(-v for v in parse_numeric_version(x))
|
||||
)
|
||||
|
||||
# Sort by Flash-Attention > Python > PyTorch > CUDA
|
||||
os_df = os_df.sort_values(["fa_sort", "py_sort", "pt_sort", "cu_sort"])
|
||||
os_df = os_df.drop(columns=["fa_sort", "py_sort", "pt_sort", "cu_sort"])
|
||||
|
||||
# Create OS section header
|
||||
os_lines = [f"### {os_name}", ""]
|
||||
|
||||
# Group by Flash-Attention version within each OS
|
||||
fa_versions = []
|
||||
for fa_version in os_df["Flash-Attention"].unique():
|
||||
fa_df = os_df[os_df["Flash-Attention"] == fa_version].copy()
|
||||
|
||||
# Re-sort by Python > PyTorch > CUDA within each Flash-Attention version
|
||||
fa_df["py_sort"] = fa_df["Python"].apply(
|
||||
lambda x: tuple(-v for v in parse_numeric_version(x))
|
||||
)
|
||||
fa_df["pt_sort"] = fa_df["PyTorch"].apply(
|
||||
lambda x: tuple(-v for v in parse_numeric_version(x))
|
||||
)
|
||||
fa_df["cu_sort"] = fa_df["CUDA"].apply(
|
||||
lambda x: tuple(-v for v in parse_numeric_version(x))
|
||||
)
|
||||
fa_df = fa_df.sort_values(["py_sort", "pt_sort", "cu_sort"])
|
||||
fa_df = fa_df.drop(columns=["py_sort", "pt_sort", "cu_sort"])
|
||||
|
||||
# Create collapsible table for this Flash-Attention version
|
||||
table_lines = [
|
||||
"| Python | PyTorch | CUDA | package |",
|
||||
"| ------ | ------- | ---- | ------- |",
|
||||
]
|
||||
|
||||
for _, row in fa_df.iterrows():
|
||||
packages = row["package"]
|
||||
|
||||
# Handle case where packages is a list
|
||||
if isinstance(packages, list):
|
||||
if packages and any(pd.notna(pkg) and pkg for pkg in packages):
|
||||
# Create numbered release links
|
||||
package_links = []
|
||||
for i, pkg in enumerate(packages, 1):
|
||||
if pd.notna(pkg) and pkg:
|
||||
package_links.append(f"[Release{i}]({pkg})")
|
||||
package_cell = ", ".join(package_links)
|
||||
else:
|
||||
package_cell = "-"
|
||||
else:
|
||||
# Handle single package (backward compatibility)
|
||||
package_cell = (
|
||||
f"[Release]({packages})"
|
||||
if pd.notna(packages) and packages
|
||||
else "-"
|
||||
)
|
||||
|
||||
line = f"| {row['Python']} | {row['PyTorch']} | {row['CUDA']} | {package_cell} |"
|
||||
table_lines.append(line)
|
||||
|
||||
# Create collapsible section for this Flash-Attention version
|
||||
fa_section = [
|
||||
f"#### Flash-Attention {fa_version}",
|
||||
"",
|
||||
"<details>",
|
||||
f"<summary>Packages for Flash-Attention {fa_version}</summary>",
|
||||
"",
|
||||
"\n".join(table_lines),
|
||||
"",
|
||||
"</details>",
|
||||
"",
|
||||
]
|
||||
|
||||
fa_versions.extend(fa_section)
|
||||
|
||||
os_lines.extend(fa_versions)
|
||||
all_sections.extend(os_lines)
|
||||
|
||||
return "\n".join(all_sections)
|
||||
|
||||
|
||||
def generate_markdown_table(df: pd.DataFrame) -> str:
|
||||
"""Generate markdown table from DataFrame (legacy function for backward compatibility)."""
|
||||
lines = [
|
||||
"| Flash-Attention | Python | PyTorch | CUDA | OS | package |",
|
||||
"| --------------- | ------ | ------- | ------ | ---- | ------- |",
|
||||
]
|
||||
|
||||
for _, row in df.iterrows():
|
||||
packages = row["package"]
|
||||
|
||||
# Handle case where packages is a list
|
||||
if isinstance(packages, list):
|
||||
if packages and any(pd.notna(pkg) and pkg for pkg in packages):
|
||||
# Create numbered release links
|
||||
package_links = []
|
||||
for i, pkg in enumerate(packages, 1):
|
||||
if pd.notna(pkg) and pkg:
|
||||
package_links.append(f"[Release{i}]({pkg})")
|
||||
package_cell = ", ".join(package_links)
|
||||
else:
|
||||
package_cell = "-"
|
||||
else:
|
||||
# Handle single package (backward compatibility)
|
||||
package_cell = (
|
||||
f"[Release]({packages})" if pd.notna(packages) and packages else "-"
|
||||
)
|
||||
|
||||
line = f"| {row['Flash-Attention']} | {row['Python']} | {row['PyTorch']} | {row['CUDA']} | {row['OS']} | {package_cell} |"
|
||||
lines.append(line)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def update_readme_packages_section(readme_path: Path, packages_markdown: str) -> None:
|
||||
"""Update the Packages section in README.md with new content."""
|
||||
try:
|
||||
with readme_path.open("r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# Find the Packages section
|
||||
packages_start = content.find("## Packages")
|
||||
if packages_start == -1:
|
||||
raise ValueError("Packages section not found in README.md")
|
||||
|
||||
# Find the end of Packages section (next ## section or History section)
|
||||
packages_end = content.find("## History", packages_start)
|
||||
if packages_end == -1:
|
||||
# If no History section found, look for any other ## section
|
||||
remaining_content = content[packages_start + len("## Packages") :]
|
||||
next_section = remaining_content.find("\n## ")
|
||||
if next_section != -1:
|
||||
packages_end = packages_start + len("## Packages") + next_section
|
||||
else:
|
||||
packages_end = len(content)
|
||||
|
||||
# Replace the Packages section
|
||||
new_content = (
|
||||
content[:packages_start]
|
||||
+ "## Packages\n\n"
|
||||
+ packages_markdown
|
||||
+ "\n\n"
|
||||
+ content[packages_end:]
|
||||
)
|
||||
|
||||
# Write back to file
|
||||
with readme_path.open("w", encoding="utf-8") as f:
|
||||
f.write(new_content)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to update README.md: {e}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate a one-row-per-package Markdown table from the History section of a README.md"
|
||||
)
|
||||
parser.add_argument(
|
||||
"readme",
|
||||
nargs="?",
|
||||
type=Path,
|
||||
default=Path("README.md"),
|
||||
help="Path to README.md (default: README.md)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-readme",
|
||||
action="store_true",
|
||||
help="Update the Packages section in README.md instead of printing to stdout",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
with args.readme.open("r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
|
||||
packages = extract_packages_from_history(text)
|
||||
|
||||
if not packages:
|
||||
print("No packages found in History section", file=sys.stderr)
|
||||
return
|
||||
|
||||
df = pd.DataFrame(packages)
|
||||
df_sorted = sort_packages(df)
|
||||
df_merged = merge_duplicate_rows(df_sorted)
|
||||
markdown = generate_markdown_table_by_os(df_merged)
|
||||
|
||||
if args.update_readme:
|
||||
# Update the README.md file
|
||||
update_readme_packages_section(args.readme, markdown)
|
||||
print(f"Updated Packages section in {args.readme}")
|
||||
else:
|
||||
# Print to stdout (original behavior)
|
||||
print(markdown)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,504 @@
|
||||
"""
|
||||
python insert_packages_to_readme.py --assets assets.json --update
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from common import normalize_platform_name, parse_wheel_filename
|
||||
|
||||
|
||||
def parse_numeric_version(text: str) -> tuple:
|
||||
"""Extract numeric version tuple for sorting."""
|
||||
nums = re.findall(r"\d+", text)
|
||||
return tuple(int(n) for n in nums)
|
||||
|
||||
|
||||
def normalize_semantic_version(version: str) -> str:
|
||||
"""Normalize semantic version by removing patch version.
|
||||
|
||||
Examples:
|
||||
2.9.1 -> 2.9
|
||||
2.8.1 -> 2.8
|
||||
2.6.3 -> 2.6
|
||||
2.9 -> 2.9 (no change if no patch version)
|
||||
"""
|
||||
if pd.isna(version) or not version:
|
||||
return version
|
||||
|
||||
# Split by '.' and take only major.minor
|
||||
parts = str(version).split(".")
|
||||
if len(parts) >= 2:
|
||||
return ".".join(parts[:2])
|
||||
return version
|
||||
|
||||
|
||||
def extract_packages_from_readme(readme_path: Path) -> list[dict]:
|
||||
"""Extract package information from existing Packages section in README.md."""
|
||||
with readme_path.open("r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# Find Packages section
|
||||
packages_start = content.find("## Packages")
|
||||
if packages_start == -1:
|
||||
return []
|
||||
|
||||
# Find the end of Packages section
|
||||
packages_end = content.find("## History", packages_start)
|
||||
if packages_end == -1:
|
||||
remaining_content = content[packages_start + len("## Packages") :]
|
||||
next_section = remaining_content.find("\n## ")
|
||||
if next_section != -1:
|
||||
packages_end = packages_start + len("## Packages") + next_section
|
||||
else:
|
||||
packages_end = len(content)
|
||||
|
||||
packages_section = content[packages_start:packages_end]
|
||||
lines = packages_section.splitlines()
|
||||
|
||||
packages = []
|
||||
current_os = None
|
||||
current_fa_version = None
|
||||
in_table = False
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
line_stripped = line.strip()
|
||||
|
||||
# Detect OS heading (### Linux x86_64)
|
||||
if line_stripped.startswith("### ") and not line_stripped.startswith("#### "):
|
||||
current_os = line_stripped[4:].strip()
|
||||
current_fa_version = None
|
||||
in_table = False
|
||||
continue
|
||||
|
||||
# Detect Flash-Attention version heading (#### Flash-Attention 2.8.3)
|
||||
if line_stripped.startswith("#### Flash-Attention "):
|
||||
# Extract version after "#### Flash-Attention "
|
||||
current_fa_version = line_stripped.replace(
|
||||
"#### Flash-Attention ", ""
|
||||
).strip()
|
||||
in_table = False
|
||||
continue
|
||||
|
||||
# Detect table start
|
||||
if "| Python | PyTorch | CUDA | package |" in line_stripped:
|
||||
in_table = True
|
||||
continue
|
||||
|
||||
# Skip table separator line
|
||||
if in_table and "| ------ |" in line_stripped:
|
||||
continue
|
||||
|
||||
# Process table rows
|
||||
if (
|
||||
in_table
|
||||
and line_stripped.startswith("|")
|
||||
and current_os
|
||||
and current_fa_version
|
||||
):
|
||||
# Parse table row: | Python | PyTorch | CUDA | package |
|
||||
cells = [
|
||||
c.strip() for c in line_stripped.split("|")[1:-1]
|
||||
] # Remove empty first/last cells
|
||||
cells = [c for c in cells if c] # Remove empty cells
|
||||
if len(cells) >= 4:
|
||||
python_version = cells[0]
|
||||
torch_version = cells[1]
|
||||
cuda_version = cells[2]
|
||||
package_cell = cells[3]
|
||||
|
||||
# Extract all URLs from package cell
|
||||
# Pattern: [Release1](url1), [Download1](url1), [Release](url), [Download](url), ...
|
||||
# Support both Release and Download patterns for backward compatibility
|
||||
package_urls = re.findall(
|
||||
r"\[(?:Release|Download)\d*\]\(([^)]+)\)", package_cell
|
||||
)
|
||||
|
||||
if package_urls:
|
||||
# Create a package entry for each URL
|
||||
for package_url in package_urls:
|
||||
packages.append(
|
||||
{
|
||||
"Flash-Attention": current_fa_version,
|
||||
"Python": python_version,
|
||||
"PyTorch": torch_version,
|
||||
"CUDA": cuda_version,
|
||||
"OS": current_os,
|
||||
"package": package_url,
|
||||
}
|
||||
)
|
||||
elif package_cell != "-":
|
||||
# Handle other formats
|
||||
packages.append(
|
||||
{
|
||||
"Flash-Attention": current_fa_version,
|
||||
"Python": python_version,
|
||||
"PyTorch": torch_version,
|
||||
"CUDA": cuda_version,
|
||||
"OS": current_os,
|
||||
"package": None,
|
||||
}
|
||||
)
|
||||
|
||||
# Detect end of table (empty line or closing </details>)
|
||||
if in_table and (not line_stripped or line_stripped == "</details>"):
|
||||
in_table = False
|
||||
|
||||
return packages
|
||||
|
||||
|
||||
def extract_packages_from_assets_json(assets_path: Path) -> list[dict]:
|
||||
"""Extract package information from assets.json file."""
|
||||
with assets_path.open("r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
if "assets" not in data:
|
||||
return []
|
||||
|
||||
packages = []
|
||||
|
||||
for asset in data["assets"]:
|
||||
name = asset.get("name", "")
|
||||
url = asset.get("url", "")
|
||||
|
||||
# Only process .whl files
|
||||
if not name.endswith(".whl"):
|
||||
continue
|
||||
|
||||
# Parse wheel filename
|
||||
info = parse_wheel_filename(name)
|
||||
if not info:
|
||||
continue
|
||||
|
||||
# Normalize platform name
|
||||
os_name = normalize_platform_name(info["platform"])
|
||||
|
||||
# Format versions for display
|
||||
flash_version = info["flash_version"]
|
||||
python_version = info["python_version"]
|
||||
torch_version = info["torch_version"] # Already in format like "2.9"
|
||||
cuda_version = info["cuda_version"]
|
||||
|
||||
packages.append(
|
||||
{
|
||||
"Flash-Attention": flash_version,
|
||||
"Python": python_version,
|
||||
"PyTorch": torch_version,
|
||||
"CUDA": cuda_version,
|
||||
"OS": os_name,
|
||||
"package": url, # Use download URL directly
|
||||
}
|
||||
)
|
||||
|
||||
return packages
|
||||
|
||||
|
||||
def sort_packages(
|
||||
df: pd.DataFrame,
|
||||
flash_ascending: bool = False,
|
||||
python_ascending: bool = False,
|
||||
pytorch_ascending: bool = False,
|
||||
cuda_ascending: bool = False,
|
||||
os_ascending: bool = True,
|
||||
package_ascending: bool = False,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Sort packages by columns from left to right.
|
||||
|
||||
Args:
|
||||
df: DataFrame to sort
|
||||
flash_ascending: Sort Flash-Attention in ascending order (default: False, newer first)
|
||||
python_ascending: Sort Python in ascending order (default: False, newer first)
|
||||
pytorch_ascending: Sort PyTorch in ascending order (default: False, newer first)
|
||||
cuda_ascending: Sort CUDA in ascending order (default: False, newer first)
|
||||
os_ascending: Sort OS in ascending order (default: True, alphabetical)
|
||||
package_ascending: Sort package in ascending order (default: False, newer first)
|
||||
|
||||
Returns:
|
||||
Sorted DataFrame
|
||||
"""
|
||||
df = df.copy()
|
||||
|
||||
# Add sorting keys for version columns
|
||||
df["fa_sort"] = df["Flash-Attention"].apply(parse_numeric_version)
|
||||
df["py_sort"] = df["Python"].apply(parse_numeric_version)
|
||||
df["pt_sort"] = df["PyTorch"].apply(parse_numeric_version)
|
||||
df["cu_sort"] = df["CUDA"].apply(parse_numeric_version)
|
||||
df["os_sort"] = df["OS"].str.lower()
|
||||
|
||||
# Package sort: extract version from download URL
|
||||
def package_sort_key(url):
|
||||
# Handle list of URLs (take the first one for sorting)
|
||||
if isinstance(url, list):
|
||||
if not url or all(pd.isna(u) or not u for u in url):
|
||||
return tuple()
|
||||
# Find first non-empty URL
|
||||
for u in url:
|
||||
if pd.notna(u) and u:
|
||||
url = u
|
||||
break
|
||||
else:
|
||||
return tuple()
|
||||
|
||||
if pd.isna(url) or not url:
|
||||
return tuple() # No URL
|
||||
|
||||
# Extract tag from download URL: /releases/download/{tag}/
|
||||
tag_match = re.search(r"/releases/download/([^/]+)/", str(url))
|
||||
if not tag_match:
|
||||
return tuple()
|
||||
|
||||
tag = tag_match.group(1)
|
||||
return parse_numeric_version(tag)
|
||||
|
||||
df["pkg_sort"] = df["package"].apply(package_sort_key)
|
||||
|
||||
# Sort by columns from left to right
|
||||
df_sorted = df.sort_values(
|
||||
by=["fa_sort", "os_sort", "py_sort", "pt_sort", "cu_sort", "pkg_sort"],
|
||||
ascending=[
|
||||
flash_ascending,
|
||||
os_ascending,
|
||||
python_ascending,
|
||||
pytorch_ascending,
|
||||
cuda_ascending,
|
||||
package_ascending,
|
||||
],
|
||||
)
|
||||
|
||||
# Drop sorting columns
|
||||
return df_sorted.drop(
|
||||
columns=["fa_sort", "os_sort", "py_sort", "pt_sort", "cu_sort", "pkg_sort"]
|
||||
)
|
||||
|
||||
|
||||
def merge_duplicate_rows(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Merge rows with duplicate Flash-Attention, Python, PyTorch, CUDA, OS values."""
|
||||
# Group by all columns except 'package'
|
||||
group_cols = ["Flash-Attention", "Python", "PyTorch", "CUDA", "OS"]
|
||||
|
||||
def combine_packages(group):
|
||||
# Get unique non-null packages (handle both list and scalar values)
|
||||
all_packages = []
|
||||
for pkg in group["package"]:
|
||||
if pd.notna(pkg):
|
||||
if isinstance(pkg, list):
|
||||
all_packages.extend(pkg)
|
||||
else:
|
||||
all_packages.append(pkg)
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
seen = set()
|
||||
unique_packages = []
|
||||
for pkg in all_packages:
|
||||
if pkg and pkg not in seen:
|
||||
seen.add(pkg)
|
||||
unique_packages.append(pkg)
|
||||
|
||||
# Take the first row as base
|
||||
result = group.iloc[0].copy()
|
||||
|
||||
# Combine packages into a list
|
||||
result["package"] = unique_packages if unique_packages else [None]
|
||||
|
||||
return result
|
||||
|
||||
# Group and combine
|
||||
merged_df = df.groupby(group_cols, as_index=False).apply(
|
||||
combine_packages, include_groups=False
|
||||
)
|
||||
|
||||
# Reset index to clean up
|
||||
merged_df = merged_df.reset_index(drop=True)
|
||||
|
||||
return merged_df
|
||||
|
||||
|
||||
def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
|
||||
"""Generate markdown tables grouped by OS and Flash-Attention version."""
|
||||
if df.empty:
|
||||
return ""
|
||||
|
||||
all_sections = []
|
||||
|
||||
# Group by OS and sort each group
|
||||
for os_name in sorted(df["OS"].unique()):
|
||||
os_df = df[df["OS"] == os_name].copy()
|
||||
|
||||
# Sort within OS group: Flash-Attention > Python > PyTorch > CUDA
|
||||
os_df = sort_packages(
|
||||
os_df,
|
||||
flash_ascending=False,
|
||||
python_ascending=True,
|
||||
pytorch_ascending=True,
|
||||
cuda_ascending=True,
|
||||
)
|
||||
|
||||
# Create OS section header
|
||||
os_lines = [f"### {os_name}", ""]
|
||||
|
||||
# Group by Flash-Attention version within each OS
|
||||
fa_versions = []
|
||||
for fa_version in os_df["Flash-Attention"].unique():
|
||||
fa_df = os_df[os_df["Flash-Attention"] == fa_version].copy()
|
||||
|
||||
# Sort by Python > PyTorch > CUDA within each Flash-Attention version
|
||||
fa_df = sort_packages(
|
||||
fa_df,
|
||||
python_ascending=True,
|
||||
pytorch_ascending=True,
|
||||
cuda_ascending=True,
|
||||
)
|
||||
|
||||
# Create collapsible table for this Flash-Attention version
|
||||
table_lines = [
|
||||
"| Python | PyTorch | CUDA | package |",
|
||||
"| ------ | ------- | ---- | ------- |",
|
||||
]
|
||||
|
||||
for _, row in fa_df.iterrows():
|
||||
packages = row["package"]
|
||||
|
||||
# Handle case where packages is a list
|
||||
if isinstance(packages, list):
|
||||
if packages and any(pd.notna(pkg) and pkg for pkg in packages):
|
||||
# Create numbered download links
|
||||
package_links = []
|
||||
for i, pkg in enumerate(packages, 1):
|
||||
if pd.notna(pkg) and pkg:
|
||||
package_links.append(f"[Download{i}]({pkg})")
|
||||
package_cell = ", ".join(package_links)
|
||||
else:
|
||||
package_cell = "-"
|
||||
else:
|
||||
# Handle single package (backward compatibility)
|
||||
package_cell = (
|
||||
f"[Download]({packages})"
|
||||
if pd.notna(packages) and packages
|
||||
else "-"
|
||||
)
|
||||
|
||||
line = f"| {row['Python']} | {row['PyTorch']} | {row['CUDA']} | {package_cell} |"
|
||||
table_lines.append(line)
|
||||
|
||||
# Create collapsible section for this Flash-Attention version
|
||||
fa_section = [
|
||||
f"#### Flash-Attention {fa_version}",
|
||||
"",
|
||||
"<details>",
|
||||
f"<summary>Packages for Flash-Attention {fa_version}</summary>",
|
||||
"",
|
||||
"\n".join(table_lines),
|
||||
"",
|
||||
"</details>",
|
||||
"",
|
||||
]
|
||||
|
||||
fa_versions.extend(fa_section)
|
||||
|
||||
os_lines.extend(fa_versions)
|
||||
all_sections.extend(os_lines)
|
||||
|
||||
return "\n".join(all_sections)
|
||||
|
||||
|
||||
def update_readme_packages_section(readme_path: Path, packages_markdown: str) -> None:
|
||||
"""Update the Packages section in README.md with new content."""
|
||||
with readme_path.open("r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# Find the Packages section
|
||||
packages_start = content.find("## Packages")
|
||||
if packages_start == -1:
|
||||
raise ValueError("Packages section not found in README.md")
|
||||
|
||||
# Find the end of Packages section (next ## section or History section)
|
||||
packages_end = content.find("## History", packages_start)
|
||||
if packages_end == -1:
|
||||
# If no History section found, look for any other ## section
|
||||
remaining_content = content[packages_start + len("## Packages") :]
|
||||
next_section = remaining_content.find("\n## ")
|
||||
if next_section != -1:
|
||||
packages_end = packages_start + len("## Packages") + next_section
|
||||
else:
|
||||
packages_end = len(content)
|
||||
|
||||
# Replace the Packages section
|
||||
new_content = (
|
||||
content[:packages_start]
|
||||
+ "## Packages\n\n"
|
||||
+ packages_markdown
|
||||
+ "\n\n"
|
||||
+ content[packages_end:]
|
||||
)
|
||||
|
||||
# Write back to file
|
||||
with readme_path.open("w", encoding="utf-8") as f:
|
||||
f.write(new_content)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate a one-row-per-package Markdown table from assets.json file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--assets",
|
||||
type=str,
|
||||
default="assets.json",
|
||||
help="Path to assets.json file (default: assets.json)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update",
|
||||
action="store_true",
|
||||
help="Update the Packages section in README.md instead of printing to stdout",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
assets_path = Path(args.assets)
|
||||
if not assets_path.exists():
|
||||
print(f"Error: {assets_path} not found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
readme_path = Path("README.md")
|
||||
|
||||
# Extract packages from assets.json
|
||||
assets_packages = extract_packages_from_assets_json(assets_path)
|
||||
|
||||
# Extract packages from existing README.md
|
||||
readme_packages = extract_packages_from_readme(readme_path)
|
||||
|
||||
# Combine both lists
|
||||
all_packages = assets_packages + readme_packages
|
||||
|
||||
if not all_packages:
|
||||
print(f"No packages found in {assets_path} or README.md", file=sys.stderr)
|
||||
return
|
||||
|
||||
# Convert to DataFrame and process
|
||||
df = pd.DataFrame(all_packages)
|
||||
# Normalize CUDA versions (remove patch version)
|
||||
df["CUDA"] = df["CUDA"].apply(normalize_semantic_version)
|
||||
# Normalize PyTorch versions (remove patch version)
|
||||
df["PyTorch"] = df["PyTorch"].apply(normalize_semantic_version)
|
||||
# Normalize Python versions (remove patch version)
|
||||
df["Python"] = df["Python"].apply(normalize_semantic_version)
|
||||
df_sorted = sort_packages(df)
|
||||
df_merged = merge_duplicate_rows(df_sorted)
|
||||
markdown = generate_markdown_table_by_os(df_merged)
|
||||
|
||||
if args.update:
|
||||
# Update the README.md file
|
||||
update_readme_packages_section(readme_path, markdown)
|
||||
print(f"Updated Packages section in {readme_path}")
|
||||
else:
|
||||
# Print to stdout (original behavior)
|
||||
print(markdown)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user