refactor: improve code structure and documentation with shared utilities

- Move reusable functions (parse_numeric_version, normalize_semantic_version, get_tag_from_url, get_os_emoji, collect_versions_from_assets, format_versions) to common.py for code reuse across scripts
- Add comprehensive docstrings to all modules and functions for better maintainability
- Enhance parse_wheel_filename to support multiple manylinux tags (e.g., manylinux_2_24_x86_64.manylinux_2_28_x86_64)
- Improve normalize_platform_name to handle manylinux formats with multiple tags
- Refactor create_packages.py to use shared utilities from common.py
- Refactor create_release_history.py with improved function naming and error handling
- Refactor create_release_note.py with better error messages and usage documentation
- Decode URL encoding for improved readability in package documentation
- Update docs/packages.md with formatted output from refactored scripts
This commit is contained in:
Junya Morioka
2025-12-21 03:20:11 +09:00
parent 879974db28
commit 86938251af
5 changed files with 1383 additions and 1141 deletions
+171 -1
View File
@@ -1,4 +1,165 @@
"""Common utility functions for processing Flash-Attention wheel packages.
This module provides shared functionality for parsing wheel filenames, extracting version
information, and processing GitHub release assets. It is used by scripts that generate
documentation and release notes for Flash-Attention prebuilt wheels.
Functions:
- load_assets_json: Load assets from GitHub release JSON file
- parse_wheel_filename: Extract version info from wheel filename
- normalize_platform_name: Standardize platform names for display
- parse_numeric_version: Convert version strings to tuples for sorting
- normalize_semantic_version: Remove patch version from semantic versions
- get_tag_from_url: Extract release tag from GitHub download URL
- get_os_emoji: Get emoji representation for OS names
- collect_versions_from_assets: Aggregate version info by platform
- format_versions: Format version sets as comma-separated strings
"""
import json
import re
from pathlib import Path
from typing import Iterable
import pandas as pd
def load_assets_json(path: Path) -> list[dict]:
    """Read a release JSON file and return its asset entries.

    Args:
        path: Location of the JSON file; it is read as UTF-8 text.

    Returns:
        The value stored under the top-level "assets" key, or an empty
        list when that key is absent.
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    return payload.get("assets", [])
def parse_numeric_version(text: str) -> tuple:
    """Turn a version-like string into a tuple of ints usable as a sort key.

    Each run of digits in ``text`` becomes one integer, in order of
    appearance; every other character only acts as a separator.

    Examples:
        "2.9.1" -> (2, 9, 1)
        "3.10" -> (3, 10)
    """
    return tuple(map(int, re.findall(r"\d+", text)))
def normalize_semantic_version(version: str) -> str:
    """Reduce a dotted version to its first two components (major.minor).

    Missing values (as judged by ``pd.isna``) and falsy values are handed
    back untouched, as are values with no "." in their string form.

    Examples:
        2.9.1 -> 2.9
        2.8.1 -> 2.8
        2.6.3 -> 2.6
        2.9 -> 2.9 (already major.minor)
    """
    if pd.isna(version) or not version:
        return version

    components = str(version).split(".")
    if len(components) < 2:
        # Nothing to trim: return the caller's value as it was passed in.
        return version

    major, minor = components[:2]
    return f"{major}.{minor}"
def get_tag_from_url(url: str) -> str:
    """Pull the release tag out of a GitHub release download URL.

    The tag is the path segment that directly follows "/releases/download/".
    An empty string is returned for missing or empty input and for URLs that
    do not contain that path.

    Examples:
        "https://github.com/user/repo/releases/download/v1.0.0/file.whl" -> "v1.0.0"
    """
    if pd.isna(url) or not url:
        return ""

    found = re.search(r"/releases/download/([^/]+)/", str(url))
    if found is None:
        return ""
    return found.group(1)
def get_os_emoji(os_name: str) -> str:
    """Return the emoji prefix used for an OS name.

    Matching is a case-insensitive substring test; "linux" is checked
    before "windows".

    Args:
        os_name: OS name (e.g., "Linux x86_64", "Windows")

    Returns:
        The emoji followed by a single space, or an empty string when the
        name matches neither keyword
    """
    lowered = os_name.lower()
    for keyword, emoji in (("linux", "🐧 "), ("windows", "🪟 ")):
        if keyword in lowered:
            return emoji
    return ""
def collect_versions_from_assets(
    assets: Iterable[dict],
) -> dict[str, dict[str, set[str]]]:
    """Group the versions found in wheel asset names by platform.

    Assets whose "name" does not end in ".whl", or whose name
    ``parse_wheel_filename`` cannot parse, are ignored. The platform key is
    the parsed platform passed through ``normalize_platform_name``.

    Args:
        assets: Iterable of asset dictionaries with "name" key

    Returns:
        Dictionary mapping platform name to version sets:
        {
            "Linux x86_64": {
                "flash_versions": {"2.6.3", "2.7.4"},
                "python_versions": {"3.10", "3.11"},
                "torch_versions": {"2.5", "2.6"},
                "cuda_versions": {"12.4", "13.0"}
            },
            ...
        }
    """
    # (output bucket, key in the parsed wheel info) pairs, in output order.
    field_map = (
        ("flash_versions", "flash_version"),
        ("python_versions", "python_version"),
        ("torch_versions", "torch_version"),
        ("cuda_versions", "cuda_version"),
    )

    by_platform: dict[str, dict[str, set[str]]] = {}
    for asset in assets:
        filename = asset.get("name", "")
        if not filename.endswith(".whl"):
            continue

        parsed = parse_wheel_filename(filename)
        if not parsed:
            continue

        platform_key = normalize_platform_name(parsed["platform"])
        if platform_key not in by_platform:
            by_platform[platform_key] = {bucket: set() for bucket, _ in field_map}

        buckets = by_platform[platform_key]
        for bucket, source_key in field_map:
            buckets[bucket].add(parsed[source_key])

    return by_platform
def format_versions(values: set[str]) -> str:
    """Format a set of version strings as a comma-separated, version-ordered string.

    Versions are ordered by their numeric components rather than as plain
    text, so "3.9" is listed before "3.10" and "2.9" before "2.10".

    Args:
        values: Set of version strings

    Returns:
        Comma-separated string in ascending version order, or "-" if empty
    """
    if not values:
        return "-"

    def version_key(version: str) -> tuple:
        # Compare digit runs as integers; the raw string breaks ties so the
        # order stays deterministic when numeric parts are equal.
        return tuple(int(n) for n in re.findall(r"\d+", version)), version

    return ", ".join(sorted(values, key=version_key))
def parse_wheel_filename(filename: str) -> dict | None:
@@ -9,17 +170,20 @@ def parse_wheel_filename(filename: str) -> dict | None:
flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
flash_attn-2.7.4.post1+cu130torch2.9-cp310-cp310-linux_x86_64.whl
flash_attn-2.8.3+cu128torch2.9-cp313-cp313-manylinux_2_34_x86_64.whl
flash_attn-2.6.3+cu128torch2.9-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
---
Wheel filename から情報を抽出
例: flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl
flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
flash_attn-2.7.4.post1+cu130torch2.9-cp310-cp310-linux_x86_64.whl
flash_attn-2.6.3+cu128torch2.9-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
"""
# Flash Attention wheelのパターンに合わせて正規表現を調整
# PyTorchバージョンはマイナーバージョン1桁の形式も対応 (例: torch2.9)
# post1 のようなバージョンサフィックスにも対応 (例: 2.7.4.post1)
pattern = r"flash_attn-(\d+\.\d+\.\d+(?:\.[a-z0-9]+)?)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(\w+)\.whl"
# manylinux の複数タグにも対応 (例: manylinux_2_24_x86_64.manylinux_2_28_x86_64)
pattern = r"flash_attn-(\d+\.\d+\.\d+(?:\.[a-z0-9]+)?)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(.+?)\.whl"
match = re.match(pattern, filename)
if match:
@@ -46,9 +210,15 @@ def normalize_platform_name(raw: str) -> str:
linux_x86_64 -> Linux x86_64
manylinux_2_34_x86_64 -> Manylinux 2_34 x86_64
manylinux_2_17_aarch64 -> Manylinux 2_17 arm64
manylinux_2_24_x86_64.manylinux_2_28_x86_64 -> Manylinux 2_24 x86_64
win32 -> Windows
amd64 -> x86_64
"""
# Handle manylinux format with multiple tags: use only the first tag
# Example: manylinux_2_24_x86_64.manylinux_2_28_x86_64 -> manylinux_2_24_x86_64
if "." in raw and raw.startswith("manylinux"):
raw = raw.split(".")[0]
# Handle manylinux format: manylinux_X_Y_ARCH -> Manylinux X_Y ARCH
if raw.startswith("manylinux"):
# Extract parts from manylinux_X_Y_ARCH format
+51 -67
View File
@@ -1,51 +1,60 @@
"""
Create and update docs/packages.md from assets.json
"""Create and update docs/packages.md from assets.json.
This script generates a comprehensive package documentation page (docs/packages.md) from
GitHub release assets. It combines information from both assets.json and any existing
packages.md file, creating organized tables grouped by OS and Flash-Attention version.
The script:
- Parses wheel filenames to extract version information
- Merges data from assets.json and existing packages.md
- Generates collapsible tables organized by OS and Flash-Attention version
- Creates a table of contents for easy navigation
- Handles multiple download links per package
Usage:
python create_packages.py [--assets <assets.json>] [--output <packages.md>]
Arguments:
--assets: Path to assets.json file (default: assets.json)
Can be obtained via `gh release view --json assets`
--output: Output file path (default: docs/packages.md)
Example:
# Basic usage
python create_packages.py --assets assets.json --output docs/packages.md
# Using defaults
python create_packages.py
# Generate from GitHub release
gh release view v0.7.0 --json assets > assets.json
python create_packages.py
"""
import argparse
import json
import re
import sys
from pathlib import Path
from urllib.parse import unquote
import pandas as pd
from common import normalize_platform_name, parse_wheel_filename
from common import (
get_os_emoji,
get_tag_from_url,
load_assets_json,
normalize_platform_name,
normalize_semantic_version,
parse_numeric_version,
parse_wheel_filename,
)
ADD_NOTE = """> [!NOTE]
> Since v0.5.0, wheels are built with a local version label indicating the CUDA and PyTorch versions.
> Since v0.5.0, wheels are built with a local version label indicating the CUDA and PyTorch versions.
> Example: `pip list` -> `flash_attn==2.8.3 -> flash_attn==2.8.3+cu130torch2.9`
"""
def parse_numeric_version(text: str) -> tuple:
    """Return the digit runs of ``text`` as a tuple of ints, for use as a sort key."""
    digit_runs = re.findall(r"\d+", text)
    return tuple(map(int, digit_runs))
def normalize_semantic_version(version: str) -> str:
    """Trim a dotted version down to major.minor.

    Missing (``pd.isna``) or falsy input, and input without a "." in its
    string form, is returned as given.

    Examples:
        2.9.1 -> 2.9
        2.8.1 -> 2.8
        2.6.3 -> 2.6
        2.9 -> 2.9 (already major.minor)
    """
    if pd.isna(version) or not version:
        return version

    pieces = str(version).split(".")
    # Only rebuild the string when there is at least a major and a minor part.
    return ".".join(pieces[:2]) if len(pieces) >= 2 else version
def extract_packages_from_packages_md(packages_md_path: Path) -> list[dict]:
"""Extract package information from existing docs/packages.md."""
if not packages_md_path.exists():
@@ -122,6 +131,8 @@ def extract_packages_from_packages_md(packages_md_path: Path) -> list[dict]:
if package_urls:
# Create a package entry for each URL
for package_url in package_urls:
# Decode URL to make it more readable
decoded_url = unquote(package_url)
packages.append(
{
"Flash-Attention": current_fa_version,
@@ -129,7 +140,7 @@ def extract_packages_from_packages_md(packages_md_path: Path) -> list[dict]:
"PyTorch": torch_version,
"CUDA": cuda_version,
"OS": current_os,
"package": package_url,
"package": decoded_url,
}
)
elif package_cell != "-":
@@ -154,15 +165,10 @@ def extract_packages_from_packages_md(packages_md_path: Path) -> list[dict]:
def extract_packages_from_assets_json(assets_path: Path) -> list[dict]:
"""Extract package information from assets.json file."""
with assets_path.open("r", encoding="utf-8") as f:
data = json.load(f)
if "assets" not in data:
return []
assets = load_assets_json(assets_path)
packages = []
for asset in data["assets"]:
for asset in assets:
name = asset.get("name", "")
url = asset.get("url", "")
@@ -178,20 +184,17 @@ def extract_packages_from_assets_json(assets_path: Path) -> list[dict]:
# Normalize platform name
os_name = normalize_platform_name(info["platform"])
# Format versions for display
flash_version = info["flash_version"]
python_version = info["python_version"]
torch_version = info["torch_version"] # Already in format like "2.9"
cuda_version = info["cuda_version"]
# Decode URL to make it more readable
decoded_url = unquote(url)
packages.append(
{
"Flash-Attention": flash_version,
"Python": python_version,
"PyTorch": torch_version,
"CUDA": cuda_version,
"Flash-Attention": info["flash_version"],
"Python": info["python_version"],
"PyTorch": info["torch_version"],
"CUDA": info["cuda_version"],
"OS": os_name,
"package": url, # Use download URL directly
"package": decoded_url,
}
)
@@ -319,25 +322,6 @@ def merge_duplicate_rows(df: pd.DataFrame) -> pd.DataFrame:
return merged_df
def get_tag_from_url(url: str) -> str:
    """Return the path segment after "/releases/download/" in ``url``, or "" if none."""
    if pd.isna(url) or not url:
        return ""
    hit = re.search(r"/releases/download/([^/]+)/", str(url))
    if hit is None:
        return ""
    return hit.group(1)
def get_os_emoji(os_name: str) -> str:
    """Return the emoji prefix (with trailing space) for an OS name, or ""."""
    name = os_name.lower()
    # "linux" is tested first, so a name containing both keywords gets the penguin.
    if "linux" in name:
        return "🐧 "
    if "windows" in name:
        return "🪟 "
    return ""
def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
"""Generate markdown tables grouped by OS and Flash-Attention version."""
if df.empty:
+74 -44
View File
@@ -1,56 +1,57 @@
"""Update the History section in README.md from assets."""
"""Update the History section in README.md from assets.
This script updates the History section in README.md by inserting or updating a release entry.
It extracts version information from a GitHub release assets JSON file and formats it as a
markdown table, then inserts it into the README.md History section.
Usage:
python create_release_history.py --assets <assets.json> --tag <tag> --repo <owner/name> --output <README.md>
Arguments:
--assets: Path to JSON file containing GitHub release assets
(obtained via `gh release view --json assets`)
--tag: Release tag name (e.g., v0.7.0)
--repo: Repository in owner/name format (e.g., mjun0812/flash-attention-prebuild-wheels)
--output: Path to README.md file to update
Example:
gh release view v0.7.0 --json assets > /tmp/assets.json
python create_release_history.py \\
--assets /tmp/assets.json \\
--tag v0.7.0 \\
--repo mjun0812/flash-attention-prebuild-wheels \\
--output README.md
"""
import argparse
import json
import re
from pathlib import Path
from typing import Dict, Iterable
from common import normalize_platform_name, parse_wheel_filename
from common import (
collect_versions_from_assets,
format_versions,
load_assets_json,
)
def collect_versions(
assets: Iterable[Dict[str, str]],
) -> Dict[str, Dict[str, set[str]]]:
aggregated: Dict[str, Dict[str, set[str]]] = {}
for asset in assets:
info = parse_wheel_filename(asset.get("name", ""))
if not info:
continue
def render_body_from_versions(
versions_by_platform: dict[str, dict[str, set[str]]]
) -> str:
"""Render markdown body from aggregated version data.
platform = normalize_platform_name(info["platform"])
platform_data = aggregated.setdefault(
platform,
{
"flash_versions": set(),
"python_versions": set(),
"torch_versions": set(),
"cuda_versions": set(),
},
)
Args:
versions_by_platform: Dictionary mapping platform to version sets
platform_data["flash_versions"].add(info["flash_version"])
platform_data["python_versions"].add(info["python_version"])
platform_data["torch_versions"].add(info["torch_version"])
platform_data["cuda_versions"].add(info["cuda_version"])
return aggregated
def format_versions(values: set[str]) -> str:
    """Join version strings with ", " in ascending version order.

    Ordering uses the numeric components of each version instead of plain
    text comparison, so "3.9" comes before "3.10".

    Args:
        values: Set of version strings

    Returns:
        Comma-separated string in ascending version order, or "-" if empty
    """
    if not values:
        return "-"

    def version_key(version: str) -> tuple:
        # Digit runs compared as integers; raw string as a deterministic tie-break.
        return tuple(int(n) for n in re.findall(r"\d+", version)), version

    return ", ".join(sorted(values, key=version_key))
def render_body_from_aggregated(aggregated: Dict[str, Dict[str, set[str]]]) -> str:
if not aggregated:
Returns:
Formatted markdown string
"""
if not versions_by_platform:
raise ValueError("No wheel assets found")
body_lines: list[str] = []
for platform in sorted(aggregated.keys()):
data = aggregated[platform]
for platform in sorted(versions_by_platform.keys()):
data = versions_by_platform[platform]
body_lines.extend(
[
f"#### {platform}",
@@ -75,12 +76,31 @@ def render_body_from_aggregated(aggregated: Dict[str, Dict[str, set[str]]]) -> s
def build_history_section(tag: str, repo: str, body: str) -> str:
    """Assemble the markdown block for one release in the README history.

    The block is a "### <tag>" heading, a "[Release](...)" link to the tag's
    GitHub release page, and the stripped body, each separated by a blank
    line. Trailing whitespace is removed and the result ends with exactly
    two newlines.

    Args:
        tag: Release tag name
        repo: Repository in owner/name format
        body: Markdown body content

    Returns:
        Formatted history section
    """
    release_url = f"https://github.com/{repo}/releases/tag/{tag}"
    section = f"### {tag}\n\n[Release]({release_url})\n\n{body.strip()}"
    return section.rstrip() + "\n\n"
def remove_existing_section(content: str, tag: str) -> str:
"""Remove existing section for the given tag from README content.
Args:
content: README.md content
tag: Release tag name
Returns:
Content with the section removed
"""
pattern = re.compile(
rf"^### {re.escape(tag)}\n.*?(?=^### |\Z)", re.MULTILINE | re.DOTALL
)
@@ -88,6 +108,15 @@ def remove_existing_section(content: str, tag: str) -> str:
def insert_history_section(content: str, section: str) -> str:
"""Insert history section into README content.
Args:
content: README.md content
section: History section to insert
Returns:
Updated README.md content
"""
marker = "## History\n"
idx = content.find(marker)
if idx == -1:
@@ -107,11 +136,12 @@ def main() -> None:
parser.add_argument("--output", type=Path, required=True, help="Output file path")
args = parser.parse_args()
data = json.loads(args.assets.read_text(encoding="utf-8"))
assets = data.get("assets", [])
aggregated = collect_versions(assets)
history_body = render_body_from_aggregated(aggregated)
# Load and process assets
assets = load_assets_json(args.assets)
versions_by_platform = collect_versions_from_assets(assets)
history_body = render_body_from_versions(versions_by_platform)
# Build and insert history section
section = build_history_section(args.tag, args.repo, history_body)
content = args.output.read_text(encoding="utf-8")
+48 -47
View File
@@ -1,61 +1,57 @@
import json
"""Generate release notes from assets.json.
This script generates markdown release notes from a GitHub release assets JSON file.
It extracts version information from wheel filenames and creates a formatted table
showing supported Flash-Attention, Python, PyTorch, and CUDA versions for each platform.
Usage:
python create_release_note.py <assets.json>
Arguments:
assets.json: Path to JSON file containing GitHub release assets
(obtained via `gh release view --json assets`)
Output:
Markdown-formatted release notes to stdout
Example:
gh release view v0.7.0 --json assets > assets.json
python create_release_note.py assets.json > release_notes.md
"""
import sys
from pathlib import Path
from common import normalize_platform_name, parse_wheel_filename
from common import collect_versions_from_assets, format_versions, load_assets_json
def generate_release_notes_from_assets(assets_info: dict):
assets_names = [
asset["name"] for asset in assets_info if asset["name"].endswith(".whl")
]
if len(assets_names) == 0:
sys.exit(1)
def generate_release_notes(assets: list[dict]) -> str:
"""Generate release notes from assets.
assets_dict = {}
Args:
assets: List of asset dictionaries
for asset_name in assets_names:
asset_info = parse_wheel_filename(asset_name)
if asset_info is None:
continue
Returns:
Formatted release notes as markdown string
"""
versions_by_platform = collect_versions_from_assets(assets)
if asset_info["platform"] not in assets_dict:
assets_dict[asset_info["platform"]] = {
"flash_versions": set(),
"python_versions": set(),
"torch_versions": set(),
"cuda_versions": set(),
}
assets_dict[asset_info["platform"]]["flash_versions"].add(
asset_info["flash_version"]
)
assets_dict[asset_info["platform"]]["python_versions"].add(
asset_info["python_version"]
)
assets_dict[asset_info["platform"]]["torch_versions"].add(
asset_info["torch_version"]
)
assets_dict[asset_info["platform"]]["cuda_versions"].add(
asset_info["cuda_version"]
)
if not versions_by_platform:
return ""
notes = []
for platform_name, data in sorted(assets_dict.items()):
if any(len(data[key]) == 0 for key in data):
continue
platform_name = normalize_platform_name(platform_name)
for platform_name in sorted(versions_by_platform.keys()):
data = versions_by_platform[platform_name]
notes.append(f"## {platform_name}")
notes.append("")
notes.append("| Flash-Attention | Python | PyTorch | CUDA |")
notes.append("| --- | --- | --- | --- |")
flash_versions = ", ".join(sorted(data["flash_versions"]))
python_versions = ", ".join(sorted(data["python_versions"]))
torch_versions = ", ".join(sorted(data["torch_versions"]))
cuda_versions = ", ".join(sorted(data["cuda_versions"]))
flash_versions = format_versions(data["flash_versions"])
python_versions = format_versions(data["python_versions"])
torch_versions = format_versions(data["torch_versions"])
cuda_versions = format_versions(data["cuda_versions"])
notes.append(
f"| {flash_versions} | {python_versions} | {torch_versions} | {cuda_versions} |"
@@ -68,21 +64,26 @@ def generate_release_notes_from_assets(assets_info: dict):
def main():
try:
if len(sys.argv) != 2:
print("Usage: python create_release_note.py <assets.json>", file=sys.stderr)
sys.exit(1)
assets_json_path = Path(sys.argv[1])
if not assets_json_path.exists():
print(f"File not found: {assets_json_path}", file=sys.stderr)
sys.exit(1)
with open(assets_json_path, "r") as f:
assets_info = json.load(f)["assets"]
if len(assets_info) == 0:
assets = load_assets_json(assets_json_path)
if not assets:
print("No assets found in JSON file", file=sys.stderr)
sys.exit(1)
text = generate_release_notes_from_assets(assets_info)
text = generate_release_notes(assets)
if text:
print(text)
else:
print("No wheel assets found", file=sys.stderr)
sys.exit(1)
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)
+1039 -982
View File
File diff suppressed because it is too large Load Diff