Files
flash-attention-prebuild-wh…/common.py
T
Junya Morioka 86938251af refactor: improve code structure and documentation with shared utilities
- Move reusable functions (parse_numeric_version, normalize_semantic_version, get_tag_from_url, get_os_emoji, collect_versions_from_assets, format_versions) to common.py for code reuse across scripts
- Add comprehensive docstrings to all modules and functions for better maintainability
- Enhance parse_wheel_filename to support multiple manylinux tags (e.g., manylinux_2_24_x86_64.manylinux_2_28_x86_64)
- Improve normalize_platform_name to handle manylinux formats with multiple tags
- Refactor create_packages.py to use shared utilities from common.py
- Refactor create_release_history.py with improved function naming and error handling
- Refactor create_release_note.py with better error messages and usage documentation
- Decode URL encoding for improved readability in package documentation
- Update docs/packages.md with formatted output from refactored scripts
2025-12-21 03:20:11 +09:00

246 lines
8.2 KiB
Python

"""Common utility functions for processing Flash-Attention wheel packages.
This module provides shared functionality for parsing wheel filenames, extracting version
information, and processing GitHub release assets. It is used by scripts that generate
documentation and release notes for Flash-Attention prebuilt wheels.
Functions:
- load_assets_json: Load assets from GitHub release JSON file
- parse_wheel_filename: Extract version info from wheel filename
- normalize_platform_name: Standardize platform names for display
- parse_numeric_version: Convert version strings to tuples for sorting
- normalize_semantic_version: Remove patch version from semantic versions
- get_tag_from_url: Extract release tag from GitHub download URL
- get_os_emoji: Get emoji representation for OS names
- collect_versions_from_assets: Aggregate version info by platform
- format_versions: Format version sets as comma-separated strings
"""
import json
import re
from pathlib import Path
from typing import Iterable
import pandas as pd
def load_assets_json(path: Path) -> list[dict]:
    """Read a release JSON file and return its list of assets.

    Args:
        path: Path to the assets.json file (read as UTF-8)

    Returns:
        The value stored under the top-level "assets" key, or an empty
        list when that key is absent
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    return payload.get("assets", [])
def parse_numeric_version(text: str) -> tuple:
    """Turn a version string into a tuple of ints usable as a sort key.

    Every run of decimal digits in ``text`` becomes one tuple element;
    all other characters only act as separators.

    Examples:
        "2.9.1" -> (2, 9, 1)
        "3.10" -> (3, 10)
    """
    return tuple(map(int, re.findall(r"\d+", text)))
def normalize_semantic_version(version: str) -> str:
    """Reduce a semantic version to its ``major.minor`` prefix.

    Missing values (as judged by ``pd.isna``) and falsy values are returned
    untouched, as is any value whose string form has no "." separator.

    Examples:
        2.9.1 -> 2.9
        2.8.1 -> 2.8
        2.6.3 -> 2.6
        2.9 -> 2.9 (already major.minor)
    """
    if pd.isna(version) or not version:
        return version

    # Keep at most the first two dot-separated components.
    major_minor = str(version).split(".")[:2]
    if len(major_minor) < 2:
        # No minor component: hand back the caller's original object.
        return version
    return ".".join(major_minor)
def get_tag_from_url(url: str) -> str:
    """Pull the release tag out of a GitHub release download URL.

    The tag is the path segment that directly follows
    ``/releases/download/``. An empty string is returned for missing or
    falsy input and for URLs that do not contain that segment.

    Examples:
        "https://github.com/user/repo/releases/download/v1.0.0/file.whl" -> "v1.0.0"
    """
    if pd.isna(url) or not url:
        return ""

    found = re.search(r"/releases/download/([^/]+)/", str(url))
    if found is None:
        return ""
    return found.group(1)
def get_os_emoji(os_name: str) -> str:
    """Return the emoji prefix associated with an OS name.

    Matching is a case-insensitive substring test; "linux" is checked
    before "windows".

    Args:
        os_name: OS name (e.g., "Linux x86_64", "Windows")

    Returns:
        Emoji followed by a single space, or an empty string when the name
        matches neither OS
    """
    lowered = os_name.lower()
    # Ordered lookup: the first keyword found in the name wins.
    for keyword, emoji in (("linux", "🐧 "), ("windows", "🪟 ")):
        if keyword in lowered:
            return emoji
    return ""
def collect_versions_from_assets(
    assets: Iterable[dict],
) -> dict[str, dict[str, set[str]]]:
    """Group the versions found in wheel asset names by platform.

    Assets whose "name" does not end in ".whl", or whose name cannot be
    parsed by ``parse_wheel_filename``, are skipped.

    Args:
        assets: Iterable of asset dictionaries with "name" key

    Returns:
        Dictionary mapping normalized platform name to version sets:
        {
            "Linux x86_64": {
                "flash_versions": {"2.6.3", "2.7.4"},
                "python_versions": {"3.10", "3.11"},
                "torch_versions": {"2.5", "2.6"},
                "cuda_versions": {"12.4", "13.0"}
            },
            ...
        }
    """
    # (output bucket, key in the parsed wheel info), in output order.
    field_map = (
        ("flash_versions", "flash_version"),
        ("python_versions", "python_version"),
        ("torch_versions", "torch_version"),
        ("cuda_versions", "cuda_version"),
    )

    by_platform: dict[str, dict[str, set[str]]] = {}
    for asset in assets:
        filename = asset.get("name", "")
        if not filename.endswith(".whl"):
            continue

        parsed = parse_wheel_filename(filename)
        if not parsed:
            continue

        key = normalize_platform_name(parsed["platform"])
        if key not in by_platform:
            # First wheel seen for this platform: start with empty sets.
            by_platform[key] = {bucket: set() for bucket, _ in field_map}

        buckets = by_platform[key]
        for bucket, source_field in field_map:
            buckets[bucket].add(parsed[source_field])
    return by_platform
def format_versions(values: set[str]) -> str:
    """Format a set of version strings as a comma-separated, version-ordered string.

    Versions are ordered by their numeric components rather than as plain
    text, so "3.9" comes before "3.10" and "2.5" before "2.10". A plain
    ``sorted()`` would put the two-digit minor versions first.

    Args:
        values: Set of version strings

    Returns:
        Comma-separated string in ascending version order, or "-" if empty
    """
    if not values:
        return "-"

    def _version_key(version: str) -> tuple:
        # Same digit-run extraction as parse_numeric_version; the raw string
        # is a tie-breaker so the order stays deterministic when two values
        # share the same numeric components.
        return tuple(int(n) for n in re.findall(r"\d+", version)), version

    return ", ".join(sorted(values, key=_version_key))
def parse_wheel_filename(filename: str) -> dict | None:
"""
Extract information from a wheel filename.
Examples:
flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl
flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
flash_attn-2.7.4.post1+cu130torch2.9-cp310-cp310-linux_x86_64.whl
flash_attn-2.8.3+cu128torch2.9-cp313-cp313-manylinux_2_34_x86_64.whl
flash_attn-2.6.3+cu128torch2.9-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
---
Wheel filename から情報を抽出
例: flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl
flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
flash_attn-2.7.4.post1+cu130torch2.9-cp310-cp310-linux_x86_64.whl
flash_attn-2.6.3+cu128torch2.9-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
"""
# Flash Attention wheelのパターンに合わせて正規表現を調整
# PyTorchバージョンはマイナーバージョン1桁の形式も対応 (例: torch2.9)
# post1 のようなバージョンサフィックスにも対応 (例: 2.7.4.post1)
# manylinux の複数タグにも対応 (例: manylinux_2_24_x86_64.manylinux_2_28_x86_64)
pattern = r"flash_attn-(\d+\.\d+\.\d+(?:\.[a-z0-9]+)?)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(.+?)\.whl"
match = re.match(pattern, filename)
if match:
flash_version = match.group(1)
cuda_version = f"{match.group(2)[:2]}.{match.group(2)[2:]}" # 130 -> 13.0
torch_version = match.group(3)
python_version = f"{match.group(4)[:1]}.{match.group(4)[1:]}" # 310 -> 3.10
platform = match.group(5) # linux_x86_64, win32など
return {
"flash_version": flash_version,
"cuda_version": cuda_version,
"torch_version": torch_version,
"python_version": python_version,
"platform": platform,
}
return None
def normalize_platform_name(raw: str) -> str:
    """Convert a wheel platform tag into a display name.

    Examples:
        linux -> Linux
        linux_x86_64 -> Linux x86_64
        linux_aarch64 -> Linux arm64
        manylinux_2_34_x86_64 -> Manylinux 2_34 x86_64
        manylinux_2_17_aarch64 -> Manylinux 2_17 arm64
        manylinux_2_24_x86_64.manylinux_2_28_x86_64 -> Manylinux 2_24 x86_64
        win_amd64 -> Windows x86_64

    Note:
        The Windows rule is a plain "Win" -> "Windows" substitution, so a
        tag such as win32 becomes "Windows32".
    """
    if raw.startswith("manylinux"):
        # With several "."-joined manylinux tags, only the first one is used:
        # manylinux_2_24_x86_64.manylinux_2_28_x86_64 -> manylinux_2_24_x86_64
        raw = raw.split(".")[0]

        # manylinux_X_Y_ARCH -> Manylinux X_Y ARCH. ARCH may itself contain
        # underscores (x86_64), so everything after Y is re-joined.
        pieces = raw.split("_")
        if len(pieces) >= 4:
            _, major, minor, *arch_pieces = pieces
            arch = "_".join(arch_pieces)
            arch = "arm64" if arch == "aarch64" else arch
            return f"Manylinux {major}_{minor} {arch}"

    # Generic path: capitalize the first character and turn only the first
    # underscore into a space (linux_x86_64 -> Linux x86_64).
    label = (raw[:1].upper() + raw[1:]).replace("_", " ", 1)

    # Substitutions are applied in this fixed order.
    for old, new in (("Win", "Windows"), ("amd64", "x86_64"), ("aarch64", "arm64")):
        label = label.replace(old, new)
    return label