Merge pull request #53 from mjun0812/feat/dynamic-matrix

Feat/dynamic matrix
This commit is contained in:
Junya Morioka
2025-11-02 02:46:09 +09:00
committed by GitHub
10 changed files with 2140 additions and 2020 deletions
+105 -210
View File
@@ -19,71 +19,34 @@ jobs:
--title "${{ github.ref_name }}" \
--notes "TBD"
create_matrix:
name: Create Matrix
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.create_matrix.outputs.matrix }}
steps:
- uses: actions/checkout@v4
- name: Create Matrix
id: create_matrix
run: |
python create_matrix.py | tee /tmp/matrix.json
echo "matrix=$(cat /tmp/matrix.json)" >> $GITHUB_OUTPUT
# #########################################################
# Linux
# #########################################################
build_wheels_linux:
name: Build Linux
needs: create_releases
needs: [create_releases, create_matrix]
if: ${{ fromjson(needs.create_matrix.outputs.matrix).linux }}
strategy:
fail-fast: false
matrix:
flash-attn-version:
- "2.6.3"
- "2.7.4.post1"
- "2.8.3"
python-version:
# - "3.9"
- "3.10"
- "3.11"
- "3.12"
- "3.13"
torch-version:
# - "2.5.1"
# - "2.6.0"
# - "2.7.1"
# - "2.8.0"
- "2.9.0"
# https://developer.nvidia.com/cuda-toolkit-archive
cuda-version:
# - "12.4.1"
# - "12.6.3"
# - "12.8.1"
# - "12.9.1"
- "13.0.1"
exclude:
# torch < 2.2 does not support Python 3.12
- python-version: "3.12"
torch-version: "2.0.1"
- python-version: "3.12"
torch-version: "2.1.2"
# torch 2.0.1 does not support CUDA 12.x
- torch-version: "2.0.1"
cuda-version: "12.1.1"
- torch-version: "2.0.1"
cuda-version: "12.4.1"
- torch-version: "2.0.1"
cuda-version: "12.6.3"
- torch-version: "2.0.1"
cuda-version: "12.8.1"
# torch 2.6.0 does not support CUDA 12.1
- torch-version: "2.6.0"
cuda-version: "12.1.1"
# torch 2.7.0 does not support CUDA 12.4
- torch-version: "2.7.0"
cuda-version: "12.4.1"
# torch < 2.8 does not support CUDA 12.9
- torch-version: "2.5.1"
cuda-version: "12.9.1"
- torch-version: "2.6.3"
cuda-version: "12.9.1"
- torch-version: "2.7.1"
cuda-version: "12.9.1"
# flash-attn 2.7.4 does not build in GitHub Hosted Runner
- flash-attn-version: "2.7.4"
# torch >= 2.9 does not support Python 3.9
- torch-version: "2.9.0"
python-version: "3.9"
flash-attn-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux.flash-attn-version }}
python-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux.python-version }}
torch-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux.torch-version }}
cuda-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux.cuda-version }}
exclude: ${{ fromjson(needs.create_matrix.outputs.matrix).exclude }}
uses: ./.github/workflows/build_linux.yml
with:
flash-attn-version: ${{ matrix.flash-attn-version }}
@@ -92,159 +55,76 @@ jobs:
cuda-version: ${{ matrix.cuda-version }}
secrets: inherit
# build_wheels_linux_self_hosted:
# name: Build Linux (self-hosted)
# needs: create_releases
# strategy:
# fail-fast: false
# matrix:
# flash-attn-version: ["2.8.3"]
# python-version: ["3.10", "3.11", "3.12"]
# torch-version: ["2.5.1", "2.6.0", "2.7.1", "2.8.0"]
# # https://developer.nvidia.com/cuda-toolkit-archive
# cuda-version: ["12.4.1", "12.8.1", "12.9.1"]
# exclude:
# # torch < 2.2 does not support Python 3.12
# - python-version: "3.12"
# torch-version: "2.0.1"
# - python-version: "3.12"
# torch-version: "2.1.2"
# # torch 2.0.1 does not support CUDA 12.x
# - torch-version: "2.0.1"
# cuda-version: "12.1.1"
# - torch-version: "2.0.1"
# cuda-version: "12.4.1"
# - torch-version: "2.0.1"
# cuda-version: "12.6.3"
# - torch-version: "2.0.1"
# cuda-version: "12.8.1"
# # torch 2.6.0 does not support CUDA 12.1
# - torch-version: "2.6.0"
# cuda-version: "12.1.1"
# # torch 2.7.0 does not support CUDA 12.4
# - torch-version: "2.7.0"
# cuda-version: "12.4.1"
# # torch < 2.8 does not support CUDA 12.9
# - torch-version: "2.5.1"
# cuda-version: "12.9.1"
# - torch-version: "2.6.3"
# cuda-version: "12.9.1"
# - torch-version: "2.7.1"
# cuda-version: "12.9.1"
# uses: ./.github/workflows/build_linux_self_host.yml
# with:
# flash-attn-version: ${{ matrix.flash-attn-version }}
# python-version: ${{ matrix.python-version }}
# torch-version: ${{ matrix.torch-version }}
# cuda-version: ${{ matrix.cuda-version }}
# secrets: inherit
build_wheels_linux_self_hosted:
name: Build Linux (self-hosted)
needs: [create_releases, create_matrix]
if: ${{ fromjson(needs.create_matrix.outputs.matrix).linux_self_hosted }}
strategy:
fail-fast: false
matrix:
flash-attn-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux_self_hosted.flash-attn-version }}
python-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux_self_hosted.python-version }}
torch-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux_self_hosted.torch-version }}
cuda-version: ${{ fromjson(needs.create_matrix.outputs.matrix).linux_self_hosted.cuda-version }}
exclude: ${{ fromjson(needs.create_matrix.outputs.matrix).exclude }}
uses: ./.github/workflows/build_linux_self_host.yml
with:
flash-attn-version: ${{ matrix.flash-attn-version }}
python-version: ${{ matrix.python-version }}
torch-version: ${{ matrix.torch-version }}
cuda-version: ${{ matrix.cuda-version }}
secrets: inherit
# #########################################################
# Windows
# #########################################################
# build_wheels_windows:
# name: Build Windows
# needs: create_releases
# strategy:
# fail-fast: false
# matrix:
# flash-attn-version: ["2.8.3"]
# python-version: ["3.11", "3.12", "3.13"]
# torch-version: ["2.9.0.dev20250909"]
# # https://developer.nvidia.com/cuda-toolkit-archive
# cuda-version: ["12.6.3", "12.8.1"]
# exclude:
# # torch < 2.2 does not support Python 3.12
# - python-version: "3.12"
# torch-version: "2.0.1"
# - python-version: "3.12"
# torch-version: "2.1.2"
# # torch 2.0.1 does not support CUDA 12.x
# - torch-version: "2.0.1"
# cuda-version: "12.1.1"
# - torch-version: "2.0.1"
# cuda-version: "12.4.1"
# - torch-version: "2.0.1"
# cuda-version: "12.6.3"
# - torch-version: "2.0.1"
# cuda-version: "12.8.1"
# # torch 2.6.0 does not support CUDA 12.1
# - torch-version: "2.6.0"
# cuda-version: "12.1.1"
# # torch 2.7.0 does not support CUDA 12.4
# - torch-version: "2.7.0"
# cuda-version: "12.4.1"
# # torch < 2.8 does not support CUDA 12.9
# - torch-version: "2.5.1"
# cuda-version: "12.9.1"
# - torch-version: "2.6.3"
# cuda-version: "12.9.1"
# - torch-version: "2.7.1"
# cuda-version: "12.9.1"
# uses: ./.github/workflows/build_windows.yml
# with:
# flash-attn-version: ${{ matrix.flash-attn-version }}
# python-version: ${{ matrix.python-version }}
# torch-version: ${{ matrix.torch-version }}
# cuda-version: ${{ matrix.cuda-version }}
# secrets: inherit
build_wheels_windows:
name: Build Windows
needs: [create_releases, create_matrix]
if: ${{ fromjson(needs.create_matrix.outputs.matrix).windows }}
strategy:
fail-fast: false
matrix:
flash-attn-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows.flash-attn-version }}
python-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows.python-version }}
torch-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows.torch-version }}
cuda-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows.cuda-version }}
exclude: ${{ fromjson(needs.create_matrix.outputs.matrix).exclude }}
uses: ./.github/workflows/build_windows.yml
with:
flash-attn-version: ${{ matrix.flash-attn-version }}
python-version: ${{ matrix.python-version }}
torch-version: ${{ matrix.torch-version }}
cuda-version: ${{ matrix.cuda-version }}
secrets: inherit
# build_wheels_windows_code_build:
# name: Build Windows (AWS CodeBuild)
# needs: create_releases
# strategy:
# fail-fast: false
# matrix:
# flash-attn-version: ["2.7.4", "2.8.2"]
# python-version: ["3.10", "3.11", "3.12"]
# torch-version: ["2.7.1", "2.8.0"]
# # https://developer.nvidia.com/cuda-toolkit-archive
# cuda-version: ["12.8.1"]
# exclude:
# # torch < 2.2 does not support Python 3.12
# - python-version: "3.12"
# torch-version: "2.0.1"
# - python-version: "3.12"
# torch-version: "2.1.2"
# # torch 2.0.1 does not support CUDA 12.x
# - torch-version: "2.0.1"
# cuda-version: "12.1.1"
# - torch-version: "2.0.1"
# cuda-version: "12.4.1"
# - torch-version: "2.0.1"
# cuda-version: "12.6.3"
# - torch-version: "2.0.1"
# cuda-version: "12.8.1"
# # torch 2.6.0 does not support CUDA 12.1
# - torch-version: "2.6.0"
# cuda-version: "12.1.1"
# # torch 2.7.0 does not support CUDA 12.4
# - torch-version: "2.7.0"
# cuda-version: "12.4.1"
# # torch < 2.8 does not support CUDA 12.9
# - torch-version: "2.5.1"
# cuda-version: "12.9.1"
# - torch-version: "2.6.3"
# cuda-version: "12.9.1"
# - torch-version: "2.7.1"
# cuda-version: "12.9.1"
# uses: ./.github/workflows/build_windows_code_build.yml
# with:
# flash-attn-version: ${{ matrix.flash-attn-version }}
# python-version: ${{ matrix.python-version }}
# torch-version: ${{ matrix.torch-version }}
# cuda-version: ${{ matrix.cuda-version }}
# secrets: inherit
build_wheels_windows_code_build:
name: Build Windows (AWS CodeBuild)
needs: [create_releases, create_matrix]
if: ${{ fromjson(needs.create_matrix.outputs.matrix).windows_code_build }}
strategy:
fail-fast: false
matrix:
flash-attn-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows_code_build.flash-attn-version }}
python-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows_code_build.python-version }}
torch-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows_code_build.torch-version }}
cuda-version: ${{ fromjson(needs.create_matrix.outputs.matrix).windows_code_build.cuda-version }}
exclude: ${{ fromjson(needs.create_matrix.outputs.matrix).exclude }}
uses: ./.github/workflows/build_windows_code_build.yml
with:
flash-attn-version: ${{ matrix.flash-attn-version }}
python-version: ${{ matrix.python-version }}
torch-version: ${{ matrix.torch-version }}
cuda-version: ${{ matrix.cuda-version }}
secrets: inherit
update_release_notes:
name: Update Release Notes
needs:
- build_wheels_linux
# - build_wheels_linux_self_hosted
# - build_wheels_windows
# - build_wheels_windows_code_build
permissions:
contents: write
- build_wheels_linux_self_hosted
- build_wheels_windows
- build_wheels_windows_code_build
if: always()
runs-on: ubuntu-latest
steps:
@@ -265,23 +145,38 @@ jobs:
python create_release_note.py /tmp/assets.json > /tmp/release_notes.md
gh release edit "${{ github.ref_name }}" --notes-file /tmp/release_notes.md
- name: Update README history and packages
update_docs:
name: Update Docs
needs:
- build_wheels_linux
- build_wheels_linux_self_hosted
- build_wheels_windows
- build_wheels_windows_code_build
permissions:
contents: write
if: always()
runs-on: ubuntu-latest
env:
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
steps:
- uses: actions/checkout@v4
- name: Update release history and packages section in README.md
run: |
cat /tmp/release_notes.md | python insert_history.py \
--notes - \
gh release view "${{ github.ref_name }}" --json assets > /tmp/assets.json
python create_release_history.py \
--assets /tmp/assets.json \
--tag "${{ github.ref_name }}" \
--repo "${{ github.repository }}"
python generate_packages_table.py --update-readme
--repo "${{ github.repository }}" \
--output docs/release_history.md
python insert_packages_to_readme.py --assets /tmp/assets.json --update
- name: Commit and push README updates
env:
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
- name: Commit and push docs updates
run: |
git config --global user.name "github-actions[bot]"
git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com"
if git diff --quiet; then
echo "No README updates to commit."
echo "No docs updates to commit."
exit 0
fi
git commit -am "docs: update README for ${{ github.ref_name }}"
git commit -am "docs: update docs for ${{ github.ref_name }}"
git push origin HEAD:"${DEFAULT_BRANCH}"
+2
View File
@@ -1,2 +1,4 @@
.DS_Store
.env
__pycache__/
.ruff_cache/
+1026 -1290
View File
File diff suppressed because it is too large Load Diff
+57
View File
@@ -0,0 +1,57 @@
import re
def parse_wheel_filename(filename: str) -> dict | None:
"""
Extract information from a wheel filename.
Examples:
flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl
flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
flash_attn-2.7.4.post1+cu130torch2.9-cp310-cp310-linux_x86_64.whl
---
Wheel filename から情報を抽出
例: flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl
flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
flash_attn-2.7.4.post1+cu130torch2.9-cp310-cp310-linux_x86_64.whl
"""
# Flash Attention wheelのパターンに合わせて正規表現を調整
# PyTorchバージョンはマイナーバージョン1桁の形式も対応 (例: torch2.9)
# post1 のようなバージョンサフィックスにも対応 (例: 2.7.4.post1)
pattern = (
r"flash_attn-(\d+\.\d+\.\d+(?:\.[a-z0-9]+)?)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(\w+)\.whl"
)
match = re.match(pattern, filename)
if match:
flash_version = match.group(1)
cuda_version = f"{match.group(2)[:2]}.{match.group(2)[2:]}" # 130 -> 13.0
torch_version = match.group(3)
python_version = f"{match.group(4)[:1]}.{match.group(4)[1:]}" # 310 -> 3.10
platform = match.group(5) # linux_x86_64, win32など
return {
"flash_version": flash_version,
"cuda_version": cuda_version,
"torch_version": torch_version,
"python_version": python_version,
"platform": platform,
}
return None
def normalize_platform_name(raw: str) -> str:
"""Platform name normalization
Examples:
linux -> Linux
linux_x86_64 -> Linux x86_64
win32 -> Windows
amd64 -> x86_64
"""
name = raw[:1].upper() + raw[1:] # linux -> Linux
name = name.replace("_", " ", 1) # linux_x86_64 -> Linux x86_64
if "Win" in name:
name = name.replace("Win", "Windows")
if "amd64" in name:
name = name.replace("amd64", "x86_64")
return name
+18 -85
View File
@@ -1,52 +1,17 @@
"""Update the History section in README.md from release notes or assets."""
"""Update the History section in README.md from assets."""
import argparse
import json
import re
import sys
from pathlib import Path
from typing import Dict, Iterable
WHEEL_PATTERN = re.compile(
r"flash_attn-(\d+\.\d+\.\d+)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(\w+)\.whl"
)
from common import normalize_platform_name, parse_wheel_filename
def parse_wheel_filename(filename: str) -> Dict[str, str] | None:
match = WHEEL_PATTERN.match(filename)
if not match:
return None
flash_version = match.group(1)
cuda_digits = match.group(2)
torch_version = match.group(3)
python_digits = match.group(4)
platform = match.group(5)
cuda_version = f"{cuda_digits[:2]}.{cuda_digits[2:]}"
python_version = f"{python_digits[:1]}.{python_digits[1:]}"
return {
"flash_version": flash_version,
"cuda_version": cuda_version,
"torch_version": torch_version,
"python_version": python_version,
"platform": platform,
}
def normalize_platform_name(raw: str) -> str:
name = raw[:1].upper() + raw[1:]
name = name.replace("_", " ", 1)
if "Win" in name:
name = name.replace("Win", "Windows")
if "amd64" in name:
name = name.replace("amd64", "x86_64")
return name
def collect_versions(assets: Iterable[Dict[str, str]]) -> Dict[str, Dict[str, set[str]]]:
def collect_versions(
assets: Iterable[Dict[str, str]],
) -> Dict[str, Dict[str, set[str]]]:
aggregated: Dict[str, Dict[str, set[str]]] = {}
for asset in assets:
info = parse_wheel_filename(asset.get("name", ""))
@@ -109,11 +74,6 @@ def render_body_from_aggregated(aggregated: Dict[str, Dict[str, set[str]]]) -> s
return "\n".join(body_lines).strip()
def convert_release_notes_to_body(notes_text: str) -> str:
converted = re.sub(r"^## ", "#### ", notes_text, flags=re.MULTILINE)
return converted.strip()
def build_history_section(tag: str, repo: str, body: str) -> str:
release_url = f"https://github.com/{repo}/releases/tag/{tag}"
lines = [f"### {tag}", "", f"[Release]({release_url})", "", body.strip()]
@@ -121,7 +81,9 @@ def build_history_section(tag: str, repo: str, body: str) -> str:
def remove_existing_section(content: str, tag: str) -> str:
pattern = re.compile(rf"^### {re.escape(tag)}\n.*?(?=^### |\Z)", re.MULTILINE | re.DOTALL)
pattern = re.compile(
rf"^### {re.escape(tag)}\n.*?(?=^### |\Z)", re.MULTILINE | re.DOTALL
)
return re.sub(pattern, "", content)
@@ -137,47 +99,22 @@ def insert_history_section(content: str, section: str) -> str:
def main() -> None:
parser = argparse.ArgumentParser(description="Update README.md History section")
parser.add_argument("--assets", type=Path, help="JSON file from gh release view")
parser.add_argument(
"--notes",
help="Release notes markdown file path or '-' to read from stdin",
"--assets", type=Path, required=True, help="JSON file from gh release view"
)
parser.add_argument("--tag", required=True, help="Release tag name")
parser.add_argument("--repo", required=True, help="Repository in owner/name format")
parser.add_argument(
"--readme",
type=Path,
default=Path("README.md"),
help="Path to README.md",
)
parser.add_argument("--output", type=Path, required=True, help="Output file path")
args = parser.parse_args()
history_body: str
if args.notes:
if args.notes == "-":
notes_text = sys.stdin.read()
else:
notes_path = Path(args.notes)
if not notes_path.exists():
raise FileNotFoundError(f"Notes file not found: {notes_path}")
notes_text = notes_path.read_text(encoding="utf-8")
history_body = convert_release_notes_to_body(notes_text)
else:
if not args.assets:
raise ValueError("Either --notes or --assets must be provided")
if not args.assets.exists():
raise FileNotFoundError(f"Assets file not found: {args.assets}")
data = json.loads(args.assets.read_text(encoding="utf-8"))
assets = data.get("assets", [])
aggregated = collect_versions(assets)
history_body = render_body_from_aggregated(aggregated)
data = json.loads(args.assets.read_text(encoding="utf-8"))
assets = data.get("assets", [])
aggregated = collect_versions(assets)
history_body = render_body_from_aggregated(aggregated)
section = build_history_section(args.tag, args.repo, history_body)
content = args.readme.read_text(encoding="utf-8")
content = args.output.read_text(encoding="utf-8")
stripped = remove_existing_section(content, args.tag)
updated = insert_history_section(stripped, section)
@@ -185,13 +122,9 @@ def main() -> None:
print("No changes in README.md")
return
args.readme.write_text(updated, encoding="utf-8")
print(f"Inserted history for {args.tag} into {args.readme}")
args.output.write_text(updated, encoding="utf-8")
print(f"Inserted history for {args.tag} into README.md")
if __name__ == "__main__":
try:
main()
except Exception as exc:
print(f"Error: {exc}", file=sys.stderr)
sys.exit(1)
main()
+2 -37
View File
@@ -1,37 +1,8 @@
import json
import re
import sys
from pathlib import Path
def parse_wheel_filename(filename):
"""
Wheel filename から情報を抽出
例: flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl
flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
"""
# Flash Attention wheelのパターンに合わせて正規表現を調整
# PyTorchバージョンはパッチバージョンなしの形式 (例: torch2.5)
pattern = (
r"flash_attn-(\d+\.\d+\.\d+)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(\w+)\.whl"
)
match = re.match(pattern, filename)
if match:
flash_version = match.group(1)
cuda_version = f"{match.group(2)[:2]}.{match.group(2)[2:]}" # 124 -> 12.4
torch_version = match.group(3)
python_version = f"{match.group(4)[:1]}.{match.group(4)[1:]}" # 311 -> 3.11
platform = match.group(5) # linux, win32など
return {
"flash_version": flash_version,
"cuda_version": cuda_version,
"torch_version": torch_version,
"python_version": python_version,
"platform": platform,
}
return None
from common import normalize_platform_name, parse_wheel_filename
def generate_release_notes_from_assets(assets_info: dict):
@@ -74,13 +45,7 @@ def generate_release_notes_from_assets(assets_info: dict):
if any(len(data[key]) == 0 for key in data):
continue
platform_name = platform_name[:1].upper() + platform_name[1:]
platform_name = platform_name.replace("_", " ", 1)
if "Win" in platform_name:
platform_name = platform_name.replace("Win", "Windows")
if "amd64" in platform_name:
platform_name = platform_name.replace("amd64", "x86_64")
platform_name = normalize_platform_name(platform_name)
notes.append(f"## {platform_name}")
notes.append("")
+279
View File
@@ -0,0 +1,279 @@
## History
### v0.4.18
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.18)
#### Linux x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --- | --- | --- | --- |
| 2.6.3, 2.8.3 | 3.10, 3.11, 3.12, 3.13 | 2.9 | 13.0 |
### v0.4.17
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.17)
#### Linux x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --- | --- | --- | --- |
| 2.6.3, 2.8.3 | 3.10, 3.11, 3.12, 3.13 | 2.9 | 12.6, 12.8 |
### v0.4.16
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.16)
#### Linux x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --- | --- | --- | --- |
| 2.6.3, 2.8.3 | 3.9 | 2.5, 2.6, 2.7, 2.8 | 12.4, 12.6 |
### v0.4.15
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.15)
#### Linux x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --- | --- | --- | --- |
| 2.8.3 | 3.11, 3.12, 3.13 | 2.9 | 12.6, 12.8 |
#### Windows x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --- | --- | --- | --- |
| 2.8.3 | 3.11, 3.12, 3.13 | 2.9 | 12.6 |
### v0.4.12
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.12)
#### Linux x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --- | --- | --- | --- |
| 2.8.3 | 3.13 | 2.6, 2.7, 2.8 | 12.4, 12.6, 12.8, 12.9 |
#### Windows x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --- | --- | --- | --- |
| 2.8.2 | 3.13 | 2.6, 2.7, 2.8 | 12.4, 12.6 |
### v0.4.11
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.11)
#### Linux x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --- | --- | --- | --- |
| 2.8.3 | 3.10, 3.11, 3.12 | 2.5, 2.6, 2.7, 2.8 | 12.4, 12.6, 12.8, 12.9 |
### v0.4.10
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.10)
#### Windows x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --- | --- | --- | --- |
| 2.7.4, 2.8.2 | 3.10, 3.11, 3.12 | 2.7, 2.8 | 12.8 |
### v0.4.9
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.9)
#### Windows x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --- | --- | --- | --- |
| 2.7.4 | 3.11 | 2.7 | 12.8 |
### v0.3.18
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.3.18)
#### Linux x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --- | --- | --- | --- |
| 2.7.4 | 3.10, 3.11, 3.12 | 2.5, 2.6, 2.7, 2.8 | 12.4, 12.8, 12.9 |
### v0.3.14
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.3.14)
#### Linux x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --------------- | ---------------- | -------------------------- | ---------------------- |
| 2.6.3, 2.8.2 | 3.10, 3.11, 3.12 | 2.5.1, 2.6.0, 2.7.1, 2.8.0 | 12.4.1, 12.8.1, 12.9.1 |
### v0.3.13
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.3.13)
#### Linux x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --------------- | ---------------- | -------------------------- | ------ |
| 2.8.1 | 3.10, 3.11, 3.12 | 2.4.1, 2.5.1, 2.6.0, 2.7.1 | 12.8.1 |
### v0.3.12
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.3.12)
#### Linux x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --------------- | ---------------- | -------------------------- | -------------- |
| 2.8.0 | 3.10, 3.11, 3.12 | 2.4.1, 2.5.1, 2.6.0, 2.7.1 | 12.4.1, 12.8.1 |
### v0.3.10
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.3.10)
#### Linux x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --------------- | ---------------- | ------- | ------ |
| 2.7.4 | 3.10, 3.11, 3.12 | 2.7.1 | 12.8.1 |
### v0.3.9
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.3.9)
#### Linux x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| ------------------- | ---------------- | ------- | ------ |
| 2.4.3, 2.5.9, 2.6.3 | 3.10, 3.11, 3.12 | 2.7.1 | 12.8.1 |
#### Windows x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| ------------------- | ---------------- | ------------------- | ------ |
| 2.5.9, 2.6.3, 2.7.4 | 3.10, 3.11, 3.12 | 2.4.1, 2.5.1, 2.6.0 | 12.4.1 |
> [!IMPORTANT]
> ⚠️ Building flash-attn v2.7.4 with CUDA 12.8 on Windows cannot be completed because of GitHub Actions processing-time limits. In the future, I plan to add a self-hosted Windows runner to resolve this issue.
### v0.3.1
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.3.1)
#### Windows x86_64
| Flash-Attention | Python | PyTorch | CUDA |
| --------------- | ------ | ------- | ------ |
| 2.6.3 | 3.11 | 2.6.0 | 12.6.3 |
From this version, Wheels for Windows are released.
However, we are waiting for a report on how it works because we have not tested it enough.
### v0.2.1
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.2.1)
| Flash-Attention | Python | PyTorch | CUDA |
| -------------------------- | ---------------- | ----------------- | ------ |
| 2.4.3, 2.5.9, 2.6.3, 2.7.4 | 3.10, 3.11, 3.12 | 2.8.0.dev20250523 | 12.8.1 |
### v0.2.0
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.2.0)
| Flash-Attention | Python | PyTorch | CUDA |
| ------------------- | ---------------- | ----------------- | ------ |
| 2.4.3, 2.5.9, 2.6.3 | 3.10, 3.11, 3.12 | 2.8.0.dev20250523 | 12.8.1 |
### v0.1.0
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.1.0)
| Flash-Attention | Python | PyTorch | CUDA |
| -------------------------- | ---------------- | ------- | ------ |
| 2.4.3, 2.5.9, 2.6.3, 2.7.4 | 3.10, 3.11, 3.12 | 2.7.0 | 12.8.1 |
v2.7.4 and v2.7.4.post1 are the same version.
From this release, self-hosted runners are used for building some wheels.
### v0.0.9
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.9)
| Flash-Attention | Python | PyTorch | CUDA |
| ------------------- | ---------------- | ------- | ------ |
| 2.4.3, 2.5.9, 2.6.3 | 3.10, 3.11, 3.12 | 2.7.0 | 12.8.1 |
### v0.0.8
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.8)
| Flash-Attention | Python | PyTorch | CUDA |
| -------------------------------- | ---------------- | -------------------------- | ---------------------- |
| 2.4.3, 2.5.9, 2.6.3, 2.7.4.post1 | 3.10, 3.11, 3.12 | 2.4.1, 2.5.1, 2.6.0, 2.7.0 | 11.8.0, 12.4.1, 12.6.3 |
### v0.0.7
Skip for experimental reasons.
### v0.0.6
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.6)
| Flash-Attention | Python | PyTorch | CUDA |
| -------------------------------- | ---------------- | --------------------------------- | -------------- |
| 2.4.3, 2.5.9, 2.6.3, 2.7.4.post1 | 3.10, 3.11, 3.12 | 2.2.2, 2.3.1, 2.4.1, 2.5.1, 2.6.0 | 12.4.1, 12.6.3 |
### v0.0.5
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.5)
| Flash-Attention | Python | PyTorch | CUDA |
| ------------------ | ---------------- | ----------------------------------------------- | -------------- |
| 2.6.3, 2.7.4.post1 | 3.10, 3.11, 3.12 | 2.0.1, 2.1.2, 2.2.2, 2.3.1, 2.4.1, 2.5.1, 2.6.0 | 12.4.1, 12.6.3 |
### v0.0.4
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.4)
| Flash-Attention | Python | PyTorch | CUDA |
| --------------- | ---------------- | ---------------------------------------- | ---------------------- |
| 2.7.3 | 3.10, 3.11, 3.12 | 2.0.1, 2.1.2, 2.2.2, 2.3.1, 2.4.1, 2.5.1 | 11.8.0, 12.1.1, 12.4.1 |
### v0.0.3
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.3)
| Flash-Attention | Python | PyTorch | CUDA |
| --------------- | ---------------- | ---------------------------------------- | ---------------------- |
| 2.7.2.post1 | 3.10, 3.11, 3.12 | 2.0.1, 2.1.2, 2.2.2, 2.3.1, 2.4.1, 2.5.1 | 11.8.0, 12.1.1, 12.4.1 |
### v0.0.2
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.2)
| Flash-Attention | Python | PyTorch | CUDA |
| -------------------------------- | ---------------- | ---------------------------------------- | ---------------------- |
| 2.4.3, 2.5.6, 2.6.3, 2.7.0.post2 | 3.10, 3.11, 3.12 | 2.0.1, 2.1.2, 2.2.2, 2.3.1, 2.4.1, 2.5.1 | 11.8.0, 12.1.1, 12.4.1 |
### v0.0.1
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.1)
| flash-attention | Python | PyTorch | CUDA |
| --------------------------------- | ---------------- | ---------------------------------------- | ---------------------- |
| 1.0.9, 2.4.3, 2.5.6, 2.5.9, 2.6.3 | 3.10, 3.11, 3.12 | 2.0.1, 2.1.2, 2.2.2, 2.3.1, 2.4.1, 2.5.0 | 11.8.0, 12.1.1, 12.4.1 |
### v0.0.0
[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.0)
| flash-attention | Python | PyTorch | CUDA |
| -------------------------- | ---------- | ---------------------------------------- | ---------------------- |
| 2.4.3, 2.5.6, 2.5.9, 2.6.3 | 3.11, 3.12 | 2.0.1, 2.1.2, 2.2.2, 2.3.1, 2.4.1, 2.5.0 | 11.8.0, 12.1.1, 12.4.1 |
+147
View File
@@ -0,0 +1,147 @@
"""
Fetch all assets from all GitHub releases and save to assets.json
Usage:
python fetch_all_assets.py
python fetch_all_assets.py --output all_assets.json
"""
import argparse
import json
import os
import sys
import time
from pathlib import Path
import requests
def get_github_token():
"""Get GitHub token from environment variable."""
token = os.environ.get("GITHUB_TOKEN")
if not token:
print(
"Warning: GITHUB_TOKEN not set. API rate limit will be restricted.",
file=sys.stderr,
)
return token
def fetch_all_releases(repo: str, token: str | None = None) -> list[dict]:
"""Fetch all releases from a GitHub repository."""
headers = {}
if token:
headers["Authorization"] = f"token {token}"
headers["Accept"] = "application/vnd.github.v3+json"
all_releases = []
page = 1
per_page = 100
while True:
url = f"https://api.github.com/repos/{repo}/releases"
params = {"page": page, "per_page": per_page}
print(f"Fetching releases page {page}...", file=sys.stderr)
response = requests.get(url, headers=headers, params=params, timeout=30)
if response.status_code != 200:
print(
f"Error fetching releases: {response.status_code} - {response.text}",
file=sys.stderr,
)
break
releases = response.json()
if not releases:
break
all_releases.extend(releases)
print(f" Found {len(releases)} releases on page {page}", file=sys.stderr)
# Check if there are more pages
if len(releases) < per_page:
break
page += 1
time.sleep(0.5) # Rate limiting
return all_releases
def extract_assets_from_releases(releases: list[dict]) -> list[dict]:
"""Extract all wheel assets from releases."""
all_assets = []
for release in releases:
tag = release.get("tag_name", "")
print(f"Processing release {tag}...", file=sys.stderr)
for asset in release.get("assets", []):
name = asset.get("name", "")
# Only include .whl files
if not name.endswith(".whl"):
continue
# Extract relevant asset information
asset_info = {
"name": name,
"url": asset.get("browser_download_url", ""),
"size": asset.get("size", 0),
"downloadCount": asset.get("download_count", 0),
"createdAt": asset.get("created_at", ""),
"updatedAt": asset.get("updated_at", ""),
"id": asset.get("node_id", ""),
"apiUrl": asset.get("url", ""),
"contentType": asset.get("content_type", ""),
"state": asset.get("state", ""),
"label": asset.get("label", ""),
}
all_assets.append(asset_info)
print(f"\nTotal assets found: {len(all_assets)}", file=sys.stderr)
return all_assets
def main() -> None:
parser = argparse.ArgumentParser(
description="Fetch all assets from all GitHub releases"
)
parser.add_argument(
"--repo",
type=str,
default="mjun0812/flash-attention-prebuild-wheels",
help="GitHub repository (default: mjun0812/flash-attention-prebuild-wheels)",
)
parser.add_argument(
"--output",
type=str,
default="assets.json",
help="Output file path (default: assets.json)",
)
args = parser.parse_args()
token = get_github_token()
# Fetch all releases
print(f"Fetching all releases from {args.repo}...", file=sys.stderr)
releases = fetch_all_releases(args.repo, token)
print(f"Total releases found: {len(releases)}\n", file=sys.stderr)
# Extract assets
assets = extract_assets_from_releases(releases)
# Save to file
output_path = Path(args.output)
output_data = {"assets": assets}
with output_path.open("w", encoding="utf-8") as f:
json.dump(output_data, f, indent=2, ensure_ascii=False)
print(f"\nSaved {len(assets)} assets to {output_path}", file=sys.stderr)
if __name__ == "__main__":
main()
-398
View File
@@ -1,398 +0,0 @@
"""
python generate_packages_table.py --update-readme
"""
import argparse
import itertools
import re
import sys
from pathlib import Path
import pandas as pd
def parse_numeric_version(text: str) -> tuple:
"""Extract numeric version tuple for sorting."""
nums = re.findall(r"\d+", text)
return tuple(int(n) for n in nums)
def extract_packages_from_history(text: str) -> list[dict]:
"""Extract package information from History section."""
lines = text.splitlines()
# Find start of History section
in_history = False
for i, line in enumerate(lines):
if line.strip().startswith("## ") and "History" in line:
in_history = True
lines = lines[i:]
break
if not in_history:
return []
packages = []
current_release_url = None
current_os = "Linux x86_64" # default
i = 0
while i < len(lines):
line = lines[i].strip()
# Reset on new version
if line.startswith("### "):
current_release_url = None
current_os = "Linux x86_64"
# Capture Release link
elif "[Release](" in line:
match = re.search(r"\[Release\]\(([^)]+)\)", line)
if match:
current_release_url = match.group(1)
# Capture OS heading
elif line.startswith("#### "):
current_os = line[5:].strip() or "Linux x86_64"
# Process table
elif line.startswith("| Flash-Attention") or line.startswith(
"|Flash-Attention"
):
# Skip header and separator
i += 2
# Process table rows
while i < len(lines):
row_line = lines[i].strip()
if not row_line.startswith("|") or not row_line:
break
# Parse table row
cells = [c.strip() for c in row_line.split("|")]
cells = [c for c in cells if c] # Remove empty cells
if len(cells) >= 4:
fa_versions = [v.strip() for v in cells[0].split(",") if v.strip()]
py_versions = [v.strip() for v in cells[1].split(",") if v.strip()]
pt_versions = [v.strip() for v in cells[2].split(",") if v.strip()]
cu_versions = [v.strip() for v in cells[3].split(",") if v.strip()]
# Generate all combinations
for fa, py, pt, cu in itertools.product(
fa_versions, py_versions, pt_versions, cu_versions
):
packages.append(
{
"Flash-Attention": fa,
"Python": py,
"PyTorch": pt,
"CUDA": cu,
"OS": current_os,
"package": current_release_url,
}
)
i += 1
continue
i += 1
return packages
def sort_packages(df: pd.DataFrame) -> pd.DataFrame:
"""Sort packages with custom priority."""
# Add sorting keys
# Flash-Attention: descending order (newer versions first)
df["fa_sort"] = df["Flash-Attention"].apply(
lambda x: tuple(-v for v in parse_numeric_version(x))
)
df["os_sort"] = df["OS"].str.lower()
# Python, PyTorch, CUDA: descending order (newer versions first)
df["py_sort"] = df["Python"].apply(
lambda x: tuple(-v for v in parse_numeric_version(x))
)
df["pt_sort"] = df["PyTorch"].apply(
lambda x: tuple(-v for v in parse_numeric_version(x))
)
df["cu_sort"] = df["CUDA"].apply(
lambda x: tuple(-v for v in parse_numeric_version(x))
)
# Package sort: extract version from URL, newer first
def package_sort_key(url):
if pd.isna(url) or not url:
return (1, tuple()) # No URL comes last
tag_match = re.search(r"/tag/([^/]+)$", str(url))
if not tag_match:
return (1, tuple())
tag = tag_match.group(1)
version_tuple = parse_numeric_version(tag)
return (0, tuple(-v for v in version_tuple)) # Negate for descending
df["pkg_sort"] = df["package"].apply(package_sort_key)
# Sort by priority: Flash-Attention > OS > Python > PyTorch > CUDA > package
df_sorted = df.sort_values(
["fa_sort", "os_sort", "py_sort", "pt_sort", "cu_sort", "pkg_sort"]
)
# Drop sorting columns
return df_sorted.drop(
columns=["fa_sort", "os_sort", "py_sort", "pt_sort", "cu_sort", "pkg_sort"]
)
def merge_duplicate_rows(df: pd.DataFrame) -> pd.DataFrame:
"""Merge rows with duplicate Flash-Attention, Python, PyTorch, CUDA, OS values."""
# Group by all columns except 'package'
group_cols = ["Flash-Attention", "Python", "PyTorch", "CUDA", "OS"]
def combine_packages(group):
# Get unique non-null packages
packages = [pkg for pkg in group["package"].dropna().unique() if pkg]
# Take the first row as base
result = group.iloc[0].copy()
# Combine packages into a list
result["package"] = packages if packages else [None]
return result
# Group and combine
merged_df = df.groupby(group_cols, as_index=False).apply(
combine_packages, include_groups=False
)
# Reset index to clean up
merged_df = merged_df.reset_index(drop=True)
return merged_df
def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
"""Generate markdown tables grouped by OS and Flash-Attention version."""
if df.empty:
return ""
all_sections = []
# Group by OS and sort each group
for os_name in sorted(df["OS"].unique()):
os_df = df[df["OS"] == os_name].copy()
# Re-sort within each OS group to ensure Flash-Attention is in descending order
os_df["fa_sort"] = os_df["Flash-Attention"].apply(
lambda x: tuple(-v for v in parse_numeric_version(x))
)
os_df["py_sort"] = os_df["Python"].apply(
lambda x: tuple(-v for v in parse_numeric_version(x))
)
os_df["pt_sort"] = os_df["PyTorch"].apply(
lambda x: tuple(-v for v in parse_numeric_version(x))
)
os_df["cu_sort"] = os_df["CUDA"].apply(
lambda x: tuple(-v for v in parse_numeric_version(x))
)
# Sort by Flash-Attention > Python > PyTorch > CUDA
os_df = os_df.sort_values(["fa_sort", "py_sort", "pt_sort", "cu_sort"])
os_df = os_df.drop(columns=["fa_sort", "py_sort", "pt_sort", "cu_sort"])
# Create OS section header
os_lines = [f"### {os_name}", ""]
# Group by Flash-Attention version within each OS
fa_versions = []
for fa_version in os_df["Flash-Attention"].unique():
fa_df = os_df[os_df["Flash-Attention"] == fa_version].copy()
# Re-sort by Python > PyTorch > CUDA within each Flash-Attention version
fa_df["py_sort"] = fa_df["Python"].apply(
lambda x: tuple(-v for v in parse_numeric_version(x))
)
fa_df["pt_sort"] = fa_df["PyTorch"].apply(
lambda x: tuple(-v for v in parse_numeric_version(x))
)
fa_df["cu_sort"] = fa_df["CUDA"].apply(
lambda x: tuple(-v for v in parse_numeric_version(x))
)
fa_df = fa_df.sort_values(["py_sort", "pt_sort", "cu_sort"])
fa_df = fa_df.drop(columns=["py_sort", "pt_sort", "cu_sort"])
# Create collapsible table for this Flash-Attention version
table_lines = [
"| Python | PyTorch | CUDA | package |",
"| ------ | ------- | ---- | ------- |",
]
for _, row in fa_df.iterrows():
packages = row["package"]
# Handle case where packages is a list
if isinstance(packages, list):
if packages and any(pd.notna(pkg) and pkg for pkg in packages):
# Create numbered release links
package_links = []
for i, pkg in enumerate(packages, 1):
if pd.notna(pkg) and pkg:
package_links.append(f"[Release{i}]({pkg})")
package_cell = ", ".join(package_links)
else:
package_cell = "-"
else:
# Handle single package (backward compatibility)
package_cell = (
f"[Release]({packages})"
if pd.notna(packages) and packages
else "-"
)
line = f"| {row['Python']} | {row['PyTorch']} | {row['CUDA']} | {package_cell} |"
table_lines.append(line)
# Create collapsible section for this Flash-Attention version
fa_section = [
f"#### Flash-Attention {fa_version}",
"",
"<details>",
f"<summary>Packages for Flash-Attention {fa_version}</summary>",
"",
"\n".join(table_lines),
"",
"</details>",
"",
]
fa_versions.extend(fa_section)
os_lines.extend(fa_versions)
all_sections.extend(os_lines)
return "\n".join(all_sections)
def generate_markdown_table(df: pd.DataFrame) -> str:
"""Generate markdown table from DataFrame (legacy function for backward compatibility)."""
lines = [
"| Flash-Attention | Python | PyTorch | CUDA | OS | package |",
"| --------------- | ------ | ------- | ------ | ---- | ------- |",
]
for _, row in df.iterrows():
packages = row["package"]
# Handle case where packages is a list
if isinstance(packages, list):
if packages and any(pd.notna(pkg) and pkg for pkg in packages):
# Create numbered release links
package_links = []
for i, pkg in enumerate(packages, 1):
if pd.notna(pkg) and pkg:
package_links.append(f"[Release{i}]({pkg})")
package_cell = ", ".join(package_links)
else:
package_cell = "-"
else:
# Handle single package (backward compatibility)
package_cell = (
f"[Release]({packages})" if pd.notna(packages) and packages else "-"
)
line = f"| {row['Flash-Attention']} | {row['Python']} | {row['PyTorch']} | {row['CUDA']} | {row['OS']} | {package_cell} |"
lines.append(line)
return "\n".join(lines)
def update_readme_packages_section(readme_path: Path, packages_markdown: str) -> None:
"""Update the Packages section in README.md with new content."""
try:
with readme_path.open("r", encoding="utf-8") as f:
content = f.read()
# Find the Packages section
packages_start = content.find("## Packages")
if packages_start == -1:
raise ValueError("Packages section not found in README.md")
# Find the end of Packages section (next ## section or History section)
packages_end = content.find("## History", packages_start)
if packages_end == -1:
# If no History section found, look for any other ## section
remaining_content = content[packages_start + len("## Packages") :]
next_section = remaining_content.find("\n## ")
if next_section != -1:
packages_end = packages_start + len("## Packages") + next_section
else:
packages_end = len(content)
# Replace the Packages section
new_content = (
content[:packages_start]
+ "## Packages\n\n"
+ packages_markdown
+ "\n\n"
+ content[packages_end:]
)
# Write back to file
with readme_path.open("w", encoding="utf-8") as f:
f.write(new_content)
except Exception as e:
raise RuntimeError(f"Failed to update README.md: {e}")
def main() -> None:
parser = argparse.ArgumentParser(
description="Generate a one-row-per-package Markdown table from the History section of a README.md"
)
parser.add_argument(
"readme",
nargs="?",
type=Path,
default=Path("README.md"),
help="Path to README.md (default: README.md)",
)
parser.add_argument(
"--update-readme",
action="store_true",
help="Update the Packages section in README.md instead of printing to stdout",
)
args = parser.parse_args()
try:
with args.readme.open("r", encoding="utf-8") as f:
text = f.read()
packages = extract_packages_from_history(text)
if not packages:
print("No packages found in History section", file=sys.stderr)
return
df = pd.DataFrame(packages)
df_sorted = sort_packages(df)
df_merged = merge_duplicate_rows(df_sorted)
markdown = generate_markdown_table_by_os(df_merged)
if args.update_readme:
# Update the README.md file
update_readme_packages_section(args.readme, markdown)
print(f"Updated Packages section in {args.readme}")
else:
# Print to stdout (original behavior)
print(markdown)
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
if __name__ == "__main__":
main()
+504
View File
@@ -0,0 +1,504 @@
"""
python insert_packages_to_readme.py --assets assets.json --update
"""
import argparse
import json
import re
import sys
from pathlib import Path
import pandas as pd
from common import normalize_platform_name, parse_wheel_filename
def parse_numeric_version(text: str) -> tuple:
"""Extract numeric version tuple for sorting."""
nums = re.findall(r"\d+", text)
return tuple(int(n) for n in nums)
def normalize_semantic_version(version: str) -> str:
"""Normalize semantic version by removing patch version.
Examples:
2.9.1 -> 2.9
2.8.1 -> 2.8
2.6.3 -> 2.6
2.9 -> 2.9 (no change if no patch version)
"""
if pd.isna(version) or not version:
return version
# Split by '.' and take only major.minor
parts = str(version).split(".")
if len(parts) >= 2:
return ".".join(parts[:2])
return version
def extract_packages_from_readme(readme_path: Path) -> list[dict]:
"""Extract package information from existing Packages section in README.md."""
with readme_path.open("r", encoding="utf-8") as f:
content = f.read()
# Find Packages section
packages_start = content.find("## Packages")
if packages_start == -1:
return []
# Find the end of Packages section
packages_end = content.find("## History", packages_start)
if packages_end == -1:
remaining_content = content[packages_start + len("## Packages") :]
next_section = remaining_content.find("\n## ")
if next_section != -1:
packages_end = packages_start + len("## Packages") + next_section
else:
packages_end = len(content)
packages_section = content[packages_start:packages_end]
lines = packages_section.splitlines()
packages = []
current_os = None
current_fa_version = None
in_table = False
for i, line in enumerate(lines):
line_stripped = line.strip()
# Detect OS heading (### Linux x86_64)
if line_stripped.startswith("### ") and not line_stripped.startswith("#### "):
current_os = line_stripped[4:].strip()
current_fa_version = None
in_table = False
continue
# Detect Flash-Attention version heading (#### Flash-Attention 2.8.3)
if line_stripped.startswith("#### Flash-Attention "):
# Extract version after "#### Flash-Attention "
current_fa_version = line_stripped.replace(
"#### Flash-Attention ", ""
).strip()
in_table = False
continue
# Detect table start
if "| Python | PyTorch | CUDA | package |" in line_stripped:
in_table = True
continue
# Skip table separator line
if in_table and "| ------ |" in line_stripped:
continue
# Process table rows
if (
in_table
and line_stripped.startswith("|")
and current_os
and current_fa_version
):
# Parse table row: | Python | PyTorch | CUDA | package |
cells = [
c.strip() for c in line_stripped.split("|")[1:-1]
] # Remove empty first/last cells
cells = [c for c in cells if c] # Remove empty cells
if len(cells) >= 4:
python_version = cells[0]
torch_version = cells[1]
cuda_version = cells[2]
package_cell = cells[3]
# Extract all URLs from package cell
# Pattern: [Release1](url1), [Download1](url1), [Release](url), [Download](url), ...
# Support both Release and Download patterns for backward compatibility
package_urls = re.findall(
r"\[(?:Release|Download)\d*\]\(([^)]+)\)", package_cell
)
if package_urls:
# Create a package entry for each URL
for package_url in package_urls:
packages.append(
{
"Flash-Attention": current_fa_version,
"Python": python_version,
"PyTorch": torch_version,
"CUDA": cuda_version,
"OS": current_os,
"package": package_url,
}
)
elif package_cell != "-":
# Handle other formats
packages.append(
{
"Flash-Attention": current_fa_version,
"Python": python_version,
"PyTorch": torch_version,
"CUDA": cuda_version,
"OS": current_os,
"package": None,
}
)
# Detect end of table (empty line or closing </details>)
if in_table and (not line_stripped or line_stripped == "</details>"):
in_table = False
return packages
def extract_packages_from_assets_json(assets_path: Path) -> list[dict]:
"""Extract package information from assets.json file."""
with assets_path.open("r", encoding="utf-8") as f:
data = json.load(f)
if "assets" not in data:
return []
packages = []
for asset in data["assets"]:
name = asset.get("name", "")
url = asset.get("url", "")
# Only process .whl files
if not name.endswith(".whl"):
continue
# Parse wheel filename
info = parse_wheel_filename(name)
if not info:
continue
# Normalize platform name
os_name = normalize_platform_name(info["platform"])
# Format versions for display
flash_version = info["flash_version"]
python_version = info["python_version"]
torch_version = info["torch_version"] # Already in format like "2.9"
cuda_version = info["cuda_version"]
packages.append(
{
"Flash-Attention": flash_version,
"Python": python_version,
"PyTorch": torch_version,
"CUDA": cuda_version,
"OS": os_name,
"package": url, # Use download URL directly
}
)
return packages
def sort_packages(
df: pd.DataFrame,
flash_ascending: bool = False,
python_ascending: bool = False,
pytorch_ascending: bool = False,
cuda_ascending: bool = False,
os_ascending: bool = True,
package_ascending: bool = False,
) -> pd.DataFrame:
"""
Sort packages by columns from left to right.
Args:
df: DataFrame to sort
flash_ascending: Sort Flash-Attention in ascending order (default: False, newer first)
python_ascending: Sort Python in ascending order (default: False, newer first)
pytorch_ascending: Sort PyTorch in ascending order (default: False, newer first)
cuda_ascending: Sort CUDA in ascending order (default: False, newer first)
os_ascending: Sort OS in ascending order (default: True, alphabetical)
package_ascending: Sort package in ascending order (default: False, newer first)
Returns:
Sorted DataFrame
"""
df = df.copy()
# Add sorting keys for version columns
df["fa_sort"] = df["Flash-Attention"].apply(parse_numeric_version)
df["py_sort"] = df["Python"].apply(parse_numeric_version)
df["pt_sort"] = df["PyTorch"].apply(parse_numeric_version)
df["cu_sort"] = df["CUDA"].apply(parse_numeric_version)
df["os_sort"] = df["OS"].str.lower()
# Package sort: extract version from download URL
def package_sort_key(url):
# Handle list of URLs (take the first one for sorting)
if isinstance(url, list):
if not url or all(pd.isna(u) or not u for u in url):
return tuple()
# Find first non-empty URL
for u in url:
if pd.notna(u) and u:
url = u
break
else:
return tuple()
if pd.isna(url) or not url:
return tuple() # No URL
# Extract tag from download URL: /releases/download/{tag}/
tag_match = re.search(r"/releases/download/([^/]+)/", str(url))
if not tag_match:
return tuple()
tag = tag_match.group(1)
return parse_numeric_version(tag)
df["pkg_sort"] = df["package"].apply(package_sort_key)
# Sort by columns from left to right
df_sorted = df.sort_values(
by=["fa_sort", "os_sort", "py_sort", "pt_sort", "cu_sort", "pkg_sort"],
ascending=[
flash_ascending,
os_ascending,
python_ascending,
pytorch_ascending,
cuda_ascending,
package_ascending,
],
)
# Drop sorting columns
return df_sorted.drop(
columns=["fa_sort", "os_sort", "py_sort", "pt_sort", "cu_sort", "pkg_sort"]
)
def merge_duplicate_rows(df: pd.DataFrame) -> pd.DataFrame:
"""Merge rows with duplicate Flash-Attention, Python, PyTorch, CUDA, OS values."""
# Group by all columns except 'package'
group_cols = ["Flash-Attention", "Python", "PyTorch", "CUDA", "OS"]
def combine_packages(group):
# Get unique non-null packages (handle both list and scalar values)
all_packages = []
for pkg in group["package"]:
if pd.notna(pkg):
if isinstance(pkg, list):
all_packages.extend(pkg)
else:
all_packages.append(pkg)
# Remove duplicates while preserving order
seen = set()
unique_packages = []
for pkg in all_packages:
if pkg and pkg not in seen:
seen.add(pkg)
unique_packages.append(pkg)
# Take the first row as base
result = group.iloc[0].copy()
# Combine packages into a list
result["package"] = unique_packages if unique_packages else [None]
return result
# Group and combine
merged_df = df.groupby(group_cols, as_index=False).apply(
combine_packages, include_groups=False
)
# Reset index to clean up
merged_df = merged_df.reset_index(drop=True)
return merged_df
def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
"""Generate markdown tables grouped by OS and Flash-Attention version."""
if df.empty:
return ""
all_sections = []
# Group by OS and sort each group
for os_name in sorted(df["OS"].unique()):
os_df = df[df["OS"] == os_name].copy()
# Sort within OS group: Flash-Attention > Python > PyTorch > CUDA
os_df = sort_packages(
os_df,
flash_ascending=False,
python_ascending=True,
pytorch_ascending=True,
cuda_ascending=True,
)
# Create OS section header
os_lines = [f"### {os_name}", ""]
# Group by Flash-Attention version within each OS
fa_versions = []
for fa_version in os_df["Flash-Attention"].unique():
fa_df = os_df[os_df["Flash-Attention"] == fa_version].copy()
# Sort by Python > PyTorch > CUDA within each Flash-Attention version
fa_df = sort_packages(
fa_df,
python_ascending=True,
pytorch_ascending=True,
cuda_ascending=True,
)
# Create collapsible table for this Flash-Attention version
table_lines = [
"| Python | PyTorch | CUDA | package |",
"| ------ | ------- | ---- | ------- |",
]
for _, row in fa_df.iterrows():
packages = row["package"]
# Handle case where packages is a list
if isinstance(packages, list):
if packages and any(pd.notna(pkg) and pkg for pkg in packages):
# Create numbered download links
package_links = []
for i, pkg in enumerate(packages, 1):
if pd.notna(pkg) and pkg:
package_links.append(f"[Download{i}]({pkg})")
package_cell = ", ".join(package_links)
else:
package_cell = "-"
else:
# Handle single package (backward compatibility)
package_cell = (
f"[Download]({packages})"
if pd.notna(packages) and packages
else "-"
)
line = f"| {row['Python']} | {row['PyTorch']} | {row['CUDA']} | {package_cell} |"
table_lines.append(line)
# Create collapsible section for this Flash-Attention version
fa_section = [
f"#### Flash-Attention {fa_version}",
"",
"<details>",
f"<summary>Packages for Flash-Attention {fa_version}</summary>",
"",
"\n".join(table_lines),
"",
"</details>",
"",
]
fa_versions.extend(fa_section)
os_lines.extend(fa_versions)
all_sections.extend(os_lines)
return "\n".join(all_sections)
def update_readme_packages_section(readme_path: Path, packages_markdown: str) -> None:
"""Update the Packages section in README.md with new content."""
with readme_path.open("r", encoding="utf-8") as f:
content = f.read()
# Find the Packages section
packages_start = content.find("## Packages")
if packages_start == -1:
raise ValueError("Packages section not found in README.md")
# Find the end of Packages section (next ## section or History section)
packages_end = content.find("## History", packages_start)
if packages_end == -1:
# If no History section found, look for any other ## section
remaining_content = content[packages_start + len("## Packages") :]
next_section = remaining_content.find("\n## ")
if next_section != -1:
packages_end = packages_start + len("## Packages") + next_section
else:
packages_end = len(content)
# Replace the Packages section
new_content = (
content[:packages_start]
+ "## Packages\n\n"
+ packages_markdown
+ "\n\n"
+ content[packages_end:]
)
# Write back to file
with readme_path.open("w", encoding="utf-8") as f:
f.write(new_content)
def main() -> None:
parser = argparse.ArgumentParser(
description="Generate a one-row-per-package Markdown table from assets.json file"
)
parser.add_argument(
"--assets",
type=str,
default="assets.json",
help="Path to assets.json file (default: assets.json)",
)
parser.add_argument(
"--update",
action="store_true",
help="Update the Packages section in README.md instead of printing to stdout",
)
args = parser.parse_args()
assets_path = Path(args.assets)
if not assets_path.exists():
print(f"Error: {assets_path} not found", file=sys.stderr)
sys.exit(1)
readme_path = Path("README.md")
# Extract packages from assets.json
assets_packages = extract_packages_from_assets_json(assets_path)
# Extract packages from existing README.md
readme_packages = extract_packages_from_readme(readme_path)
# Combine both lists
all_packages = assets_packages + readme_packages
if not all_packages:
print(f"No packages found in {assets_path} or README.md", file=sys.stderr)
return
# Convert to DataFrame and process
df = pd.DataFrame(all_packages)
# Normalize CUDA versions (remove patch version)
df["CUDA"] = df["CUDA"].apply(normalize_semantic_version)
# Normalize PyTorch versions (remove patch version)
df["PyTorch"] = df["PyTorch"].apply(normalize_semantic_version)
# Normalize Python versions (remove patch version)
df["Python"] = df["Python"].apply(normalize_semantic_version)
df_sorted = sort_packages(df)
df_merged = merge_duplicate_rows(df_sorted)
markdown = generate_markdown_table_by_os(df_merged)
if args.update:
# Update the README.md file
update_readme_packages_section(readme_path, markdown)
print(f"Updated Packages section in {readme_path}")
else:
# Print to stdout (original behavior)
print(markdown)
if __name__ == "__main__":
main()