mirror of
https://github.com/BillyOutlast/flash-attention-prebuild-wheels-rocm.git
synced 2026-07-01 01:27:54 -04:00
chore: update scripts
@@ -15,10 +15,10 @@ The built packages are available on the [release page](https://github.com/mjun08
## Table of Contents

- [Install](#install)
- [Self-build runner](#self-build)
- [Packages](#packages)
  - [Linux x86_64](#linux-x86_64)
  - [Windows x86_64](#windows-x86_64)
- [Self-build runner](#self-build)
- [History](#history)
- [Original Repository](#original-repository)
@@ -46,61 +46,6 @@ wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/downlo
pip install ./flash_attn-2.6.3+cu124torch2.5-cp312-cp312-linux_x86_64.whl
```

## Self build

If you cannot find the version you are looking for, you can fork this repository and build a wheel with GitHub Actions.

1. Fork this repository.
2. Edit the workflow file [`.github/workflows/build.yml`](https://github.com/mjun0812/flash-attention-prebuild-wheels/blob/main/.github/workflows/build.yml) to set the versions you want to build.
3. Push a tag matching `v*.*.*` to trigger the build workflow.

Note that some version combinations may fail to build.

### Self-Hosted Runner Build

For some version combinations, wheels cannot be built on GitHub-hosted runners because of job time limits.
To build wheels for these versions, you can use a self-hosted runner.

```bash
git clone https://github.com/mjun0812/flash-attention-prebuild-wheels.git
cd flash-attention-prebuild-wheels/self-hosted-runner
cp env.template env
```

Edit the `env` file to set the environment variables.

```bash
# Edit env
PERSONAL_ACCESS_TOKEN=[GitHub Personal Access Token]
```

Edit the `compose.yml` file if you are using a repository forked from this one.

```yaml
services:
  runner:
    privileged: true
    build:
      context: .
      dockerfile: Dockerfile
      args:
        REPOSITORY_URL: [Target Repository URL]
        PERSONAL_ACCESS_TOKEN: $PERSONAL_ACCESS_TOKEN
        GH_RUNNER_VERSION: 2.324.0
        RUNNER_NAME: self-hosted-runner
        RUNNER_GROUP: default
        RUNNER_LABELS: self-hosted
        TARGET_ARCH: x64
```

Then build and run the Docker container.

```bash
# Build and run
docker compose build
docker compose up -d
```

## Packages

### Linux x86_64
@@ -359,6 +304,7 @@ docker compose up -d
| Python | PyTorch | CUDA | package |
| ------ | ------- | ---- | ------- |
| 3.12 | 2.9 | 13.0 | [Release1](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.18) |
| 3.12 | 2.7.0 | 12.6.3 | [Release1](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.8) |
| 3.12 | 2.7.0 | 12.4.1 | [Release1](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.8) |
| 3.12 | 2.7.0 | 11.8.0 | [Release1](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.8) |
@@ -399,6 +345,7 @@ docker compose up -d
| 3.11 | 2.1.2 | 12.4.1 | [Release1](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.5) |
| 3.11 | 2.0.1 | 12.6.3 | [Release1](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.5) |
| 3.11 | 2.0.1 | 12.4.1 | [Release1](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.5) |
| 3.10 | 2.9 | 13.0 | [Release1](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.18) |
| 3.10 | 2.7.0 | 12.6.3 | [Release1](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.8) |
| 3.10 | 2.7.0 | 12.4.1 | [Release1](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.8) |
| 3.10 | 2.7.0 | 11.8.0 | [Release1](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.0.8) |
@@ -1118,7 +1065,6 @@ docker compose up -d
| --- | --- | --- | --- |
| 2.6.3, 2.8.3 | 3.10, 3.11, 3.12, 3.13 | 2.9 | 13.0 |


### v0.4.17

[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.17)
@@ -1129,7 +1075,6 @@ docker compose up -d
| --- | --- | --- | --- |
| 2.6.3, 2.8.3 | 3.10, 3.11, 3.12, 3.13 | 2.9 | 12.6, 12.8 |


### v0.4.16

[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.16)
@@ -1140,7 +1085,6 @@ docker compose up -d
| --- | --- | --- | --- |
| 2.6.3, 2.8.3 | 3.9 | 2.5, 2.6, 2.7, 2.8 | 12.4, 12.6 |


### v0.4.15

[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.15)
@@ -1157,7 +1101,6 @@ docker compose up -d
| --- | --- | --- | --- |
| 2.8.3 | 3.11, 3.12, 3.13 | 2.9 | 12.6 |


### v0.4.12

[Release](https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/tag/v0.4.12)
@@ -1390,6 +1333,61 @@ Skip for experimental reasons.
| -------------------------- | ---------- | ---------------------------------------- | ---------------------- |
| 2.4.3, 2.5.6, 2.5.9, 2.6.3 | 3.11, 3.12 | 2.0.1, 2.1.2, 2.2.2, 2.3.1, 2.4.1, 2.5.0 | 11.8.0, 12.1.1, 12.4.1 |

## Self build

If you cannot find the version you are looking for, you can fork this repository and build a wheel with GitHub Actions.

1. Fork this repository.
2. Edit the workflow file [`.github/workflows/build.yml`](https://github.com/mjun0812/flash-attention-prebuild-wheels/blob/main/.github/workflows/build.yml) to set the versions you want to build.
3. Push a tag matching `v*.*.*` to trigger the build workflow.

Note that some version combinations may fail to build.

### Self-Hosted Runner Build

For some version combinations, wheels cannot be built on GitHub-hosted runners because of job time limits.
To build wheels for these versions, you can use a self-hosted runner.

```bash
git clone https://github.com/mjun0812/flash-attention-prebuild-wheels.git
cd flash-attention-prebuild-wheels/self-hosted-runner
cp env.template env
```

Edit the `env` file to set the environment variables.

```bash
# Edit env
PERSONAL_ACCESS_TOKEN=[GitHub Personal Access Token]
```

Edit the `compose.yml` file if you are using a repository forked from this one.

```yaml
services:
  runner:
    privileged: true
    build:
      context: .
      dockerfile: Dockerfile
      args:
        REPOSITORY_URL: [Target Repository URL]
        PERSONAL_ACCESS_TOKEN: $PERSONAL_ACCESS_TOKEN
        GH_RUNNER_VERSION: 2.324.0
        RUNNER_NAME: self-hosted-runner
        RUNNER_GROUP: default
        RUNNER_LABELS: self-hosted
        TARGET_ARCH: x64
```

Then build and run the Docker container.

```bash
# Build and run
docker compose build
docker compose up -d
```

## Original Repository

[repo](https://github.com/Dao-AILab/flash-attention)

@@ -0,0 +1,50 @@
import re


def parse_wheel_filename(filename: str) -> dict | None:
    """
    Extract information from a wheel filename.
    Examples: flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl
              flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
              flash_attn-2.7.4.post1+cu130torch2.9-cp310-cp310-linux_x86_64.whl
    """
    # The regex is tuned to the Flash Attention wheel naming pattern.
    # PyTorch versions with a single-digit minor version are also handled (e.g. torch2.9).
    # Version suffixes such as post1 are also handled (e.g. 2.7.4.post1).
    pattern = (
        r"flash_attn-(\d+\.\d+\.\d+(?:\.[a-z0-9]+)?)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(\w+)\.whl"
    )
    match = re.match(pattern, filename)

    if match:
        flash_version = match.group(1)
        cuda_version = f"{match.group(2)[:2]}.{match.group(2)[2:]}" # 130 -> 13.0
        torch_version = match.group(3)
        python_version = f"{match.group(4)[:1]}.{match.group(4)[1:]}" # 310 -> 3.10
        platform = match.group(5) # linux_x86_64, win32, etc.

        return {
            "flash_version": flash_version,
            "cuda_version": cuda_version,
            "torch_version": torch_version,
            "python_version": python_version,
            "platform": platform,
        }
    return None


def normalize_platform_name(raw: str) -> str:
    """Platform name normalization
    Examples:
        linux -> Linux
        linux_x86_64 -> Linux x86_64
        win32 -> Windows32
        win_amd64 -> Windows x86_64
    """
    name = raw[:1].upper() + raw[1:] # linux -> Linux
    name = name.replace("_", " ", 1) # linux_x86_64 -> Linux x86_64
    if "Win" in name:
        name = name.replace("Win", "Windows")
    if "amd64" in name:
        name = name.replace("amd64", "x86_64")
    return name
|
||||
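To see what the widened pattern in `common.py` changes, here is a minimal self-contained sketch that runs the old and new regular expressions from this commit against two of the docstring's example filenames. The names `OLD_PATTERN`, `NEW_PATTERN`, and `parse` are illustrative and do not exist in the repository.

```python
import re

# Both patterns are copied from this commit; the constant names are illustrative.
OLD_PATTERN = re.compile(
    r"flash_attn-(\d+\.\d+\.\d+)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(\w+)\.whl"
)
NEW_PATTERN = re.compile(
    r"flash_attn-(\d+\.\d+\.\d+(?:\.[a-z0-9]+)?)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(\w+)\.whl"
)


def parse(pattern, filename):
    """Hypothetical helper mirroring parse_wheel_filename for a given pattern."""
    match = pattern.match(filename)
    if not match:
        return None
    cuda, py = match.group(2), match.group(4)
    return {
        "flash_version": match.group(1),
        "cuda_version": f"{cuda[:2]}.{cuda[2:]}",  # 130 -> 13.0
        "torch_version": match.group(3),
        "python_version": f"{py[:1]}.{py[1:]}",  # 310 -> 3.10
        "platform": match.group(5),
    }


plain = "flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl"
post = "flash_attn-2.7.4.post1+cu130torch2.9-cp310-cp310-linux_x86_64.whl"

print(parse(OLD_PATTERN, plain))  # both patterns accept a plain x.y.z version
print(parse(OLD_PATTERN, post))   # the old pattern rejects the .post1 suffix
print(parse(NEW_PATTERN, post))   # the new pattern captures 2.7.4.post1
```

The optional non-capturing group `(?:\.[a-z0-9]+)?` is the only difference between the two patterns, so filenames without a suffix parse identically under both.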
+2
-37
@@ -1,37 +1,8 @@
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def parse_wheel_filename(filename):
|
||||
"""
|
||||
Extract information from a wheel filename
|
||||
Example: flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl
|
||||
flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
|
||||
"""
|
||||
# The regex is tuned to the Flash Attention wheel naming pattern
|
||||
# The PyTorch version has no patch component (e.g. torch2.5)
|
||||
pattern = (
|
||||
r"flash_attn-(\d+\.\d+\.\d+)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(\w+)\.whl"
|
||||
)
|
||||
match = re.match(pattern, filename)
|
||||
|
||||
if match:
|
||||
flash_version = match.group(1)
|
||||
cuda_version = f"{match.group(2)[:2]}.{match.group(2)[2:]}" # 124 -> 12.4
|
||||
torch_version = match.group(3)
|
||||
python_version = f"{match.group(4)[:1]}.{match.group(4)[1:]}" # 311 -> 3.11
|
||||
platform = match.group(5) # linux, win32, etc.
|
||||
|
||||
return {
|
||||
"flash_version": flash_version,
|
||||
"cuda_version": cuda_version,
|
||||
"torch_version": torch_version,
|
||||
"python_version": python_version,
|
||||
"platform": platform,
|
||||
}
|
||||
return None
|
||||
from common import normalize_platform_name, parse_wheel_filename
|
||||
|
||||
|
||||
def generate_release_notes_from_assets(assets_info: dict):
|
||||
@@ -74,13 +45,7 @@ def generate_release_notes_from_assets(assets_info: dict):
|
||||
if any(len(data[key]) == 0 for key in data):
|
||||
continue
|
||||
|
||||
platform_name = platform_name[:1].upper() + platform_name[1:]
|
||||
platform_name = platform_name.replace("_", " ", 1)
|
||||
|
||||
if "Win" in platform_name:
|
||||
platform_name = platform_name.replace("Win", "Windows")
|
||||
if "amd64" in platform_name:
|
||||
platform_name = platform_name.replace("amd64", "x86_64")
|
||||
platform_name = normalize_platform_name(platform_name)
|
||||
|
||||
notes.append(f"## {platform_name}")
|
||||
notes.append("")
|
||||
|
||||
@@ -1,52 +1,17 @@
|
||||
"""Update the History section in README.md from release notes or assets."""
|
||||
"""Update the History section in README.md from assets."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable
|
||||
|
||||
|
||||
WHEEL_PATTERN = re.compile(
|
||||
r"flash_attn-(\d+\.\d+\.\d+)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(\w+)\.whl"
|
||||
)
|
||||
from common import normalize_platform_name, parse_wheel_filename
|
||||
|
||||
|
||||
def parse_wheel_filename(filename: str) -> Dict[str, str] | None:
|
||||
match = WHEEL_PATTERN.match(filename)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
flash_version = match.group(1)
|
||||
cuda_digits = match.group(2)
|
||||
torch_version = match.group(3)
|
||||
python_digits = match.group(4)
|
||||
platform = match.group(5)
|
||||
|
||||
cuda_version = f"{cuda_digits[:2]}.{cuda_digits[2:]}"
|
||||
python_version = f"{python_digits[:1]}.{python_digits[1:]}"
|
||||
|
||||
return {
|
||||
"flash_version": flash_version,
|
||||
"cuda_version": cuda_version,
|
||||
"torch_version": torch_version,
|
||||
"python_version": python_version,
|
||||
"platform": platform,
|
||||
}
|
||||
|
||||
|
||||
def normalize_platform_name(raw: str) -> str:
|
||||
name = raw[:1].upper() + raw[1:]
|
||||
name = name.replace("_", " ", 1)
|
||||
if "Win" in name:
|
||||
name = name.replace("Win", "Windows")
|
||||
if "amd64" in name:
|
||||
name = name.replace("amd64", "x86_64")
|
||||
return name
|
||||
|
||||
|
||||
def collect_versions(assets: Iterable[Dict[str, str]]) -> Dict[str, Dict[str, set[str]]]:
|
||||
def collect_versions(
|
||||
assets: Iterable[Dict[str, str]],
|
||||
) -> Dict[str, Dict[str, set[str]]]:
|
||||
aggregated: Dict[str, Dict[str, set[str]]] = {}
|
||||
for asset in assets:
|
||||
info = parse_wheel_filename(asset.get("name", ""))
|
||||
@@ -109,11 +74,6 @@ def render_body_from_aggregated(aggregated: Dict[str, Dict[str, set[str]]]) -> s
|
||||
return "\n".join(body_lines).strip()
|
||||
|
||||
|
||||
def convert_release_notes_to_body(notes_text: str) -> str:
|
||||
converted = re.sub(r"^## ", "#### ", notes_text, flags=re.MULTILINE)
|
||||
return converted.strip()
|
||||
|
||||
|
||||
def build_history_section(tag: str, repo: str, body: str) -> str:
|
||||
release_url = f"https://github.com/{repo}/releases/tag/{tag}"
|
||||
lines = [f"### {tag}", "", f"[Release]({release_url})", "", body.strip()]
|
||||
@@ -121,7 +81,9 @@ def build_history_section(tag: str, repo: str, body: str) -> str:
|
||||
|
||||
|
||||
def remove_existing_section(content: str, tag: str) -> str:
|
||||
pattern = re.compile(rf"^### {re.escape(tag)}\n.*?(?=^### |\Z)", re.MULTILINE | re.DOTALL)
|
||||
pattern = re.compile(
|
||||
rf"^### {re.escape(tag)}\n.*?(?=^### |\Z)", re.MULTILINE | re.DOTALL
|
||||
)
|
||||
return re.sub(pattern, "", content)
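As a quick check of how the section-removal pattern behaves, the sketch below applies the same regular expression to a made-up README excerpt. The excerpt and its `example.invalid` URLs are hypothetical; only the pattern comes from the diff.

```python
import re


def remove_existing_section(content: str, tag: str) -> str:
    # Same pattern as in the diff: match from the tag heading up to the next
    # "### " heading or the end of the text.
    pattern = re.compile(
        rf"^### {re.escape(tag)}\n.*?(?=^### |\Z)", re.MULTILINE | re.DOTALL
    )
    return re.sub(pattern, "", content)


# Hypothetical README excerpt with two release entries.
readme = (
    "## History\n\n"
    "### v0.4.17\n\n[Release](https://example.invalid/v0.4.17)\n\n"
    "### v0.4.1\n\n[Release](https://example.invalid/v0.4.1)\n"
)

# Removing v0.4.1 leaves v0.4.17 untouched, because the pattern requires a
# newline immediately after the tag.
print(remove_existing_section(readme, "v0.4.1"))
```

Because the heading must end right after the tag, a tag that is a prefix of another tag (here `v0.4.1` and `v0.4.17`) does not remove the wrong section.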
|
||||
|
||||
|
||||
@@ -137,47 +99,21 @@ def insert_history_section(content: str, section: str) -> str:
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Update README.md History section")
|
||||
parser.add_argument("--assets", type=Path, help="JSON file from gh release view")
|
||||
parser.add_argument(
|
||||
"--notes",
|
||||
help="Release notes markdown file path or '-' to read from stdin",
|
||||
"--assets", type=Path, required=True, help="JSON file from gh release view"
|
||||
)
|
||||
parser.add_argument("--tag", required=True, help="Release tag name")
|
||||
parser.add_argument("--repo", required=True, help="Repository in owner/name format")
|
||||
parser.add_argument(
|
||||
"--readme",
|
||||
type=Path,
|
||||
default=Path("README.md"),
|
||||
help="Path to README.md",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
history_body: str
|
||||
|
||||
if args.notes:
|
||||
if args.notes == "-":
|
||||
notes_text = sys.stdin.read()
|
||||
else:
|
||||
notes_path = Path(args.notes)
|
||||
if not notes_path.exists():
|
||||
raise FileNotFoundError(f"Notes file not found: {notes_path}")
|
||||
notes_text = notes_path.read_text(encoding="utf-8")
|
||||
|
||||
history_body = convert_release_notes_to_body(notes_text)
|
||||
else:
|
||||
if not args.assets:
|
||||
raise ValueError("Either --notes or --assets must be provided")
|
||||
if not args.assets.exists():
|
||||
raise FileNotFoundError(f"Assets file not found: {args.assets}")
|
||||
|
||||
data = json.loads(args.assets.read_text(encoding="utf-8"))
|
||||
assets = data.get("assets", [])
|
||||
aggregated = collect_versions(assets)
|
||||
history_body = render_body_from_aggregated(aggregated)
|
||||
data = json.loads(args.assets.read_text(encoding="utf-8"))
|
||||
assets = data.get("assets", [])
|
||||
aggregated = collect_versions(assets)
|
||||
history_body = render_body_from_aggregated(aggregated)
|
||||
|
||||
section = build_history_section(args.tag, args.repo, history_body)
|
||||
|
||||
content = args.readme.read_text(encoding="utf-8")
|
||||
content = Path("README.md").read_text(encoding="utf-8")
|
||||
stripped = remove_existing_section(content, args.tag)
|
||||
updated = insert_history_section(stripped, section)
|
||||
|
||||
@@ -185,13 +121,9 @@ def main() -> None:
|
||||
print("No changes in README.md")
|
||||
return
|
||||
|
||||
args.readme.write_text(updated, encoding="utf-8")
|
||||
print(f"Inserted history for {args.tag} into {args.readme}")
|
||||
Path("README.md").write_text(updated, encoding="utf-8")
|
||||
print(f"Inserted history for {args.tag} into README.md")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except Exception as exc:
|
||||
print(f"Error: {exc}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
main()
|
||||
@@ -1,15 +1,17 @@
|
||||
"""
|
||||
python generate_packages_table.py --update-readme
|
||||
python insert_packages_to_readme.py --assets assets.json --update
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from common import normalize_platform_name, parse_wheel_filename
|
||||
|
||||
|
||||
def parse_numeric_version(text: str) -> tuple:
|
||||
"""Extract numeric version tuple for sorting."""
|
||||
@@ -17,86 +19,180 @@ def parse_numeric_version(text: str) -> tuple:
|
||||
return tuple(int(n) for n in nums)
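The point of sorting by a numeric tuple can be shown with a short sketch. The line that builds `nums` is not visible in this hunk, so the `re.findall` call below is an assumption about how the digits are collected.

```python
import re


def parse_numeric_version(text: str) -> tuple:
    """Extract numeric version tuple for sorting."""
    # Assumption: the elided line collects every run of digits in the text.
    nums = re.findall(r"\d+", text)
    return tuple(int(n) for n in nums)


versions = ["3.9", "3.12", "3.10", "3.11"]

# A plain string sort compares character by character, so "3.9" ranks highest.
print(sorted(versions, reverse=True))

# Sorting by the numeric tuple ranks 3.12 above 3.9 as intended.
print(sorted(versions, key=parse_numeric_version, reverse=True))
```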
|
||||
|
||||
|
||||
def extract_packages_from_history(text: str) -> list[dict]:
|
||||
"""Extract package information from History section."""
|
||||
lines = text.splitlines()
|
||||
def extract_release_url_from_download_url(download_url: str) -> str | None:
|
||||
"""Extract release tag from download URL and construct release page URL."""
|
||||
# Pattern: /releases/download/{tag}/
|
||||
match = re.search(r"/releases/download/([^/]+)/", download_url)
|
||||
if not match:
|
||||
return None
|
||||
|
||||
tag = match.group(1)
|
||||
# Construct release page URL
|
||||
# Extract repo path from download URL
|
||||
repo_match = re.search(r"(https://github\.com/[^/]+/[^/]+)", download_url)
|
||||
if not repo_match:
|
||||
return None
|
||||
|
||||
repo_path = repo_match.group(1)
|
||||
return f"{repo_path}/releases/tag/{tag}"
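A minimal sketch of the download-URL-to-release-URL conversion follows. The logic is condensed from the function in the diff; the sample asset URL, in particular the wheel filename, is illustrative and not taken from an actual release listing.

```python
import re


def extract_release_url_from_download_url(download_url):
    # Condensed version of the function in the diff, for illustration only.
    tag_match = re.search(r"/releases/download/([^/]+)/", download_url)
    repo_match = re.search(r"(https://github\.com/[^/]+/[^/]+)", download_url)
    if not tag_match or not repo_match:
        return None
    return f"{repo_match.group(1)}/releases/tag/{tag_match.group(1)}"


# Illustrative asset URL; the exact wheel filename is an assumption.
url = (
    "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/"
    "download/v0.4.18/flash_attn-2.8.3+cu130torch2.9-cp312-cp312-linux_x86_64.whl"
)
print(extract_release_url_from_download_url(url))

# URLs without a /releases/download/<tag>/ segment yield None.
print(extract_release_url_from_download_url("https://example.invalid/file.whl"))
```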
|
||||
|
||||
|
||||
def extract_packages_from_readme(readme_path: Path) -> list[dict]:
|
||||
"""Extract package information from existing Packages section in README.md."""
|
||||
with readme_path.open("r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# Find Packages section
|
||||
packages_start = content.find("## Packages")
|
||||
if packages_start == -1:
|
||||
return []
|
||||
|
||||
# Find the end of Packages section
|
||||
packages_end = content.find("## History", packages_start)
|
||||
if packages_end == -1:
|
||||
remaining_content = content[packages_start + len("## Packages") :]
|
||||
next_section = remaining_content.find("\n## ")
|
||||
if next_section != -1:
|
||||
packages_end = packages_start + len("## Packages") + next_section
|
||||
else:
|
||||
packages_end = len(content)
|
||||
|
||||
packages_section = content[packages_start:packages_end]
|
||||
lines = packages_section.splitlines()
|
||||
|
||||
packages = []
|
||||
current_os = None
|
||||
current_fa_version = None
|
||||
in_table = False
|
||||
|
||||
# Find start of History section
|
||||
in_history = False
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip().startswith("## ") and "History" in line:
|
||||
in_history = True
|
||||
lines = lines[i:]
|
||||
break
|
||||
line_stripped = line.strip()
|
||||
|
||||
if not in_history:
|
||||
# Detect OS heading (### Linux x86_64)
|
||||
if line_stripped.startswith("### ") and not line_stripped.startswith("#### "):
|
||||
current_os = line_stripped[4:].strip()
|
||||
current_fa_version = None
|
||||
in_table = False
|
||||
continue
|
||||
|
||||
# Detect Flash-Attention version heading (#### Flash-Attention 2.8.3)
|
||||
if line_stripped.startswith("#### Flash-Attention "):
|
||||
# Extract version after "#### Flash-Attention "
|
||||
current_fa_version = line_stripped.replace(
|
||||
"#### Flash-Attention ", ""
|
||||
).strip()
|
||||
in_table = False
|
||||
continue
|
||||
|
||||
# Detect table start
|
||||
if "| Python | PyTorch | CUDA | package |" in line_stripped:
|
||||
in_table = True
|
||||
continue
|
||||
|
||||
# Skip table separator line
|
||||
if in_table and "| ------ |" in line_stripped:
|
||||
continue
|
||||
|
||||
# Process table rows
|
||||
if (
|
||||
in_table
|
||||
and line_stripped.startswith("|")
|
||||
and current_os
|
||||
and current_fa_version
|
||||
):
|
||||
# Parse table row: | Python | PyTorch | CUDA | package |
|
||||
cells = [
|
||||
c.strip() for c in line_stripped.split("|")[1:-1]
|
||||
] # Remove empty first/last cells
|
||||
cells = [c for c in cells if c] # Remove empty cells
|
||||
if len(cells) >= 4:
|
||||
python_version = cells[0]
|
||||
torch_version = cells[1]
|
||||
cuda_version = cells[2]
|
||||
package_cell = cells[3]
|
||||
|
||||
# Extract all release URLs from package cell
|
||||
# Pattern: [Release1](url1), [Release2](url2), ...
|
||||
release_urls = re.findall(r"\[Release\d+\]\(([^)]+)\)", package_cell)
|
||||
|
||||
if release_urls:
|
||||
# Create a package entry for each release URL
|
||||
for release_url in release_urls:
|
||||
packages.append(
|
||||
{
|
||||
"Flash-Attention": current_fa_version,
|
||||
"Python": python_version,
|
||||
"PyTorch": torch_version,
|
||||
"CUDA": cuda_version,
|
||||
"OS": current_os,
|
||||
"package": release_url,
|
||||
}
|
||||
)
|
||||
elif package_cell != "-":
|
||||
# Handle single release or other formats
|
||||
packages.append(
|
||||
{
|
||||
"Flash-Attention": current_fa_version,
|
||||
"Python": python_version,
|
||||
"PyTorch": torch_version,
|
||||
"CUDA": cuda_version,
|
||||
"OS": current_os,
|
||||
"package": None,
|
||||
}
|
||||
)
|
||||
|
||||
# Detect end of table (empty line or closing </details>)
|
||||
if in_table and (not line_stripped or line_stripped == "</details>"):
|
||||
in_table = False
|
||||
|
||||
return packages
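The row-parsing step in `extract_packages_from_readme` can be tried in isolation. The sketch below uses a hypothetical table row with two release links and placeholder URLs; the cell splitting and the `re.findall` pattern are the ones shown in the diff.

```python
import re

# Hypothetical row in the layout the parser expects:
# | Python | PyTorch | CUDA | package |
row = (
    "| 3.12 | 2.9 | 13.0 | "
    "[Release1](https://example.invalid/tag/v1), "
    "[Release2](https://example.invalid/tag/v2) |"
)

# Drop the empty strings produced by the leading and trailing pipes.
cells = [c.strip() for c in row.split("|")[1:-1]]
python_version, torch_version, cuda_version, package_cell = cells

# One entry per release link, as in extract_packages_from_readme.
release_urls = re.findall(r"\[Release\d+\]\(([^)]+)\)", package_cell)
packages = [
    {
        "Python": python_version,
        "PyTorch": torch_version,
        "CUDA": cuda_version,
        "package": release_url,
    }
    for release_url in release_urls
]
print(packages)
```

Emitting one entry per link lets the later merge step recombine the links for a row without losing any release.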
|
||||
|
||||
|
||||
def extract_packages_from_assets_json(assets_path: Path) -> list[dict]:
|
||||
"""Extract package information from assets.json file."""
|
||||
with assets_path.open("r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
if "assets" not in data:
|
||||
return []
|
||||
|
||||
packages = []
|
||||
current_release_url = None
|
||||
current_os = "Linux x86_64" # default
|
||||
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
line = lines[i].strip()
|
||||
for asset in data["assets"]:
|
||||
name = asset.get("name", "")
|
||||
url = asset.get("url", "")
|
||||
|
||||
# Reset on new version
|
||||
if line.startswith("### "):
|
||||
current_release_url = None
|
||||
current_os = "Linux x86_64"
|
||||
|
||||
# Capture Release link
|
||||
elif "[Release](" in line:
|
||||
match = re.search(r"\[Release\]\(([^)]+)\)", line)
|
||||
if match:
|
||||
current_release_url = match.group(1)
|
||||
|
||||
# Capture OS heading
|
||||
elif line.startswith("#### "):
|
||||
current_os = line[5:].strip() or "Linux x86_64"
|
||||
|
||||
# Process table
|
||||
elif line.startswith("| Flash-Attention") or line.startswith(
|
||||
"|Flash-Attention"
|
||||
):
|
||||
# Skip header and separator
|
||||
i += 2
|
||||
|
||||
# Process table rows
|
||||
while i < len(lines):
|
||||
row_line = lines[i].strip()
|
||||
if not row_line.startswith("|") or not row_line:
|
||||
break
|
||||
|
||||
# Parse table row
|
||||
cells = [c.strip() for c in row_line.split("|")]
|
||||
cells = [c for c in cells if c] # Remove empty cells
|
||||
|
||||
if len(cells) >= 4:
|
||||
fa_versions = [v.strip() for v in cells[0].split(",") if v.strip()]
|
||||
py_versions = [v.strip() for v in cells[1].split(",") if v.strip()]
|
||||
pt_versions = [v.strip() for v in cells[2].split(",") if v.strip()]
|
||||
cu_versions = [v.strip() for v in cells[3].split(",") if v.strip()]
|
||||
|
||||
# Generate all combinations
|
||||
for fa, py, pt, cu in itertools.product(
|
||||
fa_versions, py_versions, pt_versions, cu_versions
|
||||
):
|
||||
packages.append(
|
||||
{
|
||||
"Flash-Attention": fa,
|
||||
"Python": py,
|
||||
"PyTorch": pt,
|
||||
"CUDA": cu,
|
||||
"OS": current_os,
|
||||
"package": current_release_url,
|
||||
}
|
||||
)
|
||||
|
||||
i += 1
|
||||
# Only process .whl files
|
||||
if not name.endswith(".whl"):
|
||||
continue
|
||||
|
||||
i += 1
|
||||
# Parse wheel filename
|
||||
info = parse_wheel_filename(name)
|
||||
if not info:
|
||||
continue
|
||||
|
||||
# Extract release URL from download URL
|
||||
release_url = extract_release_url_from_download_url(url)
|
||||
|
||||
# Normalize platform name
|
||||
os_name = normalize_platform_name(info["platform"])
|
||||
|
||||
# Format versions for display
|
||||
flash_version = info["flash_version"]
|
||||
python_version = info["python_version"]
|
||||
torch_version = info["torch_version"] # Already in format like "2.9"
|
||||
cuda_version = info["cuda_version"]
|
||||
|
||||
packages.append(
|
||||
{
|
||||
"Flash-Attention": flash_version,
|
||||
"Python": python_version,
|
||||
"PyTorch": torch_version,
|
||||
"CUDA": cuda_version,
|
||||
"OS": os_name,
|
||||
"package": release_url,
|
||||
}
|
||||
)
|
||||
|
||||
return packages
|
||||
|
||||
@@ -153,14 +249,28 @@ def merge_duplicate_rows(df: pd.DataFrame) -> pd.DataFrame:
|
||||
group_cols = ["Flash-Attention", "Python", "PyTorch", "CUDA", "OS"]
|
||||
|
||||
def combine_packages(group):
|
||||
# Get unique non-null packages
|
||||
packages = [pkg for pkg in group["package"].dropna().unique() if pkg]
|
||||
# Get unique non-null packages (handle both list and scalar values)
|
||||
all_packages = []
|
||||
for pkg in group["package"]:
|
||||
if pd.notna(pkg):
|
||||
if isinstance(pkg, list):
|
||||
all_packages.extend(pkg)
|
||||
else:
|
||||
all_packages.append(pkg)
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
seen = set()
|
||||
unique_packages = []
|
||||
for pkg in all_packages:
|
||||
if pkg and pkg not in seen:
|
||||
seen.add(pkg)
|
||||
unique_packages.append(pkg)
|
||||
|
||||
# Take the first row as base
|
||||
result = group.iloc[0].copy()
|
||||
|
||||
# Combine packages into a list
|
||||
result["package"] = packages if packages else [None]
|
||||
result["package"] = unique_packages if unique_packages else [None]
|
||||
|
||||
return result
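The flatten-and-deduplicate step can be sketched without pandas. The helper name `combine_unique` is hypothetical. Unlike the diff, this sketch tests for a list before testing for a missing value: `pd.notna` applied to a list returns a boolean array, and using a multi-element array in an `if` raises a `ValueError`, so the ordering below may be the safer one if list values ever reach this code.

```python
def combine_unique(values):
    """Hypothetical stand-in for the flatten-and-deduplicate step."""
    flat = []
    for value in values:
        if isinstance(value, list):
            flat.extend(value)  # rows that were already merged carry a list
        elif value is not None and value == value:  # skip None and NaN
            flat.append(value)

    # Remove duplicates while preserving first-seen order.
    seen = set()
    unique = []
    for value in flat:
        if value and value not in seen:
            seen.add(value)
            unique.append(value)
    return unique if unique else [None]


print(combine_unique(["a", None, ["b", "a"], float("nan"), "c"]))
print(combine_unique([None]))
```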
|
||||
|
||||
@@ -351,47 +461,54 @@ def update_readme_packages_section(readme_path: Path, packages_markdown: str) ->
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate a one-row-per-package Markdown table from the History section of a README.md"
|
||||
description="Generate a one-row-per-package Markdown table from assets.json file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"readme",
|
||||
nargs="?",
|
||||
type=Path,
|
||||
default=Path("README.md"),
|
||||
help="Path to README.md (default: README.md)",
|
||||
"--assets",
|
||||
type=str,
|
||||
default="assets.json",
|
||||
help="Path to assets.json file (default: assets.json)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update-readme",
|
||||
"--update",
|
||||
action="store_true",
|
||||
help="Update the Packages section in README.md instead of printing to stdout",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
with args.readme.open("r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
assets_path = Path(args.assets)
|
||||
if not assets_path.exists():
|
||||
print(f"Error: {assets_path} not found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
packages = extract_packages_from_history(text)
|
||||
readme_path = Path("README.md")
|
||||
|
||||
if not packages:
|
||||
print("No packages found in History section", file=sys.stderr)
|
||||
return
|
||||
# Extract packages from assets.json
|
||||
assets_packages = extract_packages_from_assets_json(assets_path)
|
||||
|
||||
df = pd.DataFrame(packages)
|
||||
df_sorted = sort_packages(df)
|
||||
df_merged = merge_duplicate_rows(df_sorted)
|
||||
markdown = generate_markdown_table_by_os(df_merged)
|
||||
# Extract packages from existing README.md
|
||||
readme_packages = extract_packages_from_readme(readme_path)
|
||||
|
||||
if args.update_readme:
|
||||
# Update the README.md file
|
||||
update_readme_packages_section(args.readme, markdown)
|
||||
print(f"Updated Packages section in {args.readme}")
|
||||
else:
|
||||
# Print to stdout (original behavior)
|
||||
print(markdown)
|
||||
# Combine both lists
|
||||
all_packages = assets_packages + readme_packages
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
if not all_packages:
|
||||
print(f"No packages found in {assets_path} or README.md", file=sys.stderr)
|
||||
return
|
||||
|
||||
# Convert to DataFrame and process
|
||||
df = pd.DataFrame(all_packages)
|
||||
df_sorted = sort_packages(df)
|
||||
df_merged = merge_duplicate_rows(df_sorted)
|
||||
markdown = generate_markdown_table_by_os(df_merged)
|
||||
|
||||
if args.update:
|
||||
# Update the README.md file
|
||||
update_readme_packages_section(readme_path, markdown)
|
||||
print(f"Updated Packages section in {readme_path}")
|
||||
else:
|
||||
# Print to stdout (original behavior)
|
||||
print(markdown)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||