mirror of
https://github.com/BillyOutlast/flash-attention-prebuild-wheels-rocm.git
synced 2026-07-01 01:17:55 -04:00
86938251af
- Move reusable functions (parse_numeric_version, normalize_semantic_version, get_tag_from_url, get_os_emoji, collect_versions_from_assets, format_versions) to common.py for code reuse across scripts - Add comprehensive docstrings to all modules and functions for better maintainability - Enhance parse_wheel_filename to support multiple manylinux tags (e.g., manylinux_2_24_x86_64.manylinux_2_28_x86_64) - Improve normalize_platform_name to handle manylinux formats with multiple tags - Refactor create_packages.py to use shared utilities from common.py - Refactor create_release_history.py with improved function naming and error handling - Refactor create_release_note.py with better error messages and usage documentation - Decode URL encoding for improved readability in package documentation - Update docs/packages.md with formatted output from refactored scripts
246 lines
8.2 KiB
Python
246 lines
8.2 KiB
Python
"""Common utility functions for processing Flash-Attention wheel packages.
|
|
|
|
This module provides shared functionality for parsing wheel filenames, extracting version
|
|
information, and processing GitHub release assets. It is used by scripts that generate
|
|
documentation and release notes for Flash-Attention prebuilt wheels.
|
|
|
|
Functions:
|
|
- load_assets_json: Load assets from GitHub release JSON file
|
|
- parse_wheel_filename: Extract version info from wheel filename
|
|
- normalize_platform_name: Standardize platform names for display
|
|
- parse_numeric_version: Convert version strings to tuples for sorting
|
|
- normalize_semantic_version: Remove patch version from semantic versions
|
|
- get_tag_from_url: Extract release tag from GitHub download URL
|
|
- get_os_emoji: Get emoji representation for OS names
|
|
- collect_versions_from_assets: Aggregate version info by platform
|
|
- format_versions: Format version sets as comma-separated strings
|
|
"""
|
|
|
|
import json
|
|
import re
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
|
|
import pandas as pd
|
|
|
|
|
|
def load_assets_json(path: Path) -> list[dict]:
|
|
"""Load assets from assets.json file.
|
|
|
|
Args:
|
|
path: Path to assets.json file
|
|
|
|
Returns:
|
|
List of asset dictionaries
|
|
"""
|
|
with path.open("r", encoding="utf-8") as f:
|
|
data = json.load(f)
|
|
return data.get("assets", [])
|
|
|
|
|
|
def parse_numeric_version(text: str) -> tuple[int, ...]:
    """Extract a numeric version tuple suitable for use as a sort key.

    Every run of digits in ``text`` becomes one integer, so separators and
    alphabetic suffixes are ignored. Missing values (``None``/NaN) and empty
    strings yield an empty tuple instead of raising. This matches the
    missing-value guard used by the other helpers in this module.

    Examples:
        "2.9.1" -> (2, 9, 1)
        "3.10" -> (3, 10)
        "2.7.4.post1" -> (2, 7, 4, 1)
        None -> ()

    Args:
        text: Version string; other scalars are converted with ``str``.

    Returns:
        Tuple of the integers found in ``text``, in order of appearance.
    """
    if pd.isna(text) or not text:
        return ()
    # str() lets non-string scalars (e.g. a float read by pandas) be parsed
    # instead of raising TypeError inside re.findall.
    return tuple(int(n) for n in re.findall(r"\d+", str(text)))
|
|
|
|
|
|
def normalize_semantic_version(version: str) -> str:
    """Reduce a version string to its ``major.minor`` prefix.

    Only the first two dot-separated components are kept; anything after the
    second dot (patch number, ``post`` suffix, ...) is discarded. Values with
    no dot, empty values and missing values (``None``/NaN) are returned
    unchanged.

    Examples:
        2.9.1 -> 2.9
        2.8.1 -> 2.8
        2.6.3 -> 2.6
        2.9 -> 2.9 (already major.minor)
    """
    # Missing / empty input passes straight through.
    if pd.isna(version) or not version:
        return version

    major, dot, remainder = str(version).partition(".")
    if not dot:
        # No "." at all: nothing to trim.
        return version

    minor = remainder.split(".", 1)[0]
    return f"{major}.{minor}"
|
|
|
|
|
|
def get_tag_from_url(url: str) -> str:
    """Return the release tag embedded in a GitHub release-download URL.

    The tag is the path segment that follows ``/releases/download/``. Missing
    values (``None``/NaN), empty strings and URLs without that segment all
    yield an empty string.

    Examples:
        "https://github.com/user/repo/releases/download/v1.0.0/file.whl" -> "v1.0.0"
    """
    if pd.isna(url) or not url:
        return ""
    if (found := re.search(r"/releases/download/([^/]+)/", str(url))) is None:
        return ""
    return found.group(1)
|
|
|
|
|
|
def get_os_emoji(os_name: str) -> str:
    """Map an OS / platform display name to a decorative emoji prefix.

    Matching is a case-insensitive substring test, checked in order:
    "linux" first, then "windows".

    Args:
        os_name: OS name (e.g., "Linux x86_64", "Windows").

    Returns:
        The emoji followed by a single space, or "" when no keyword matches.
    """
    lowered = os_name.lower()
    for keyword, emoji in (("linux", "🐧 "), ("windows", "🪟 ")):
        if keyword in lowered:
            return emoji
    return ""
|
|
|
|
|
|
def collect_versions_from_assets(
    assets: Iterable[dict],
) -> dict[str, dict[str, set[str]]]:
    """Group the versions found in wheel asset names by display platform.

    Assets whose ``"name"`` does not end in ``.whl``, or that
    ``parse_wheel_filename`` cannot parse, are skipped. Each remaining wheel
    is filed under ``normalize_platform_name(platform)``. Its flash-attn,
    Python, torch and CUDA versions are added to that platform's sets.

    Args:
        assets: Iterable of asset dictionaries with a "name" key.

    Returns:
        Mapping of platform name to version sets, for example::

            {
                "Linux x86_64": {
                    "flash_versions": {"2.6.3", "2.7.4"},
                    "python_versions": {"3.10", "3.11"},
                    "torch_versions": {"2.5", "2.6"},
                    "cuda_versions": {"12.4", "13.0"},
                },
                ...
            }
    """
    # (output bucket, parse_wheel_filename key that feeds it), in output order.
    field_map = (
        ("flash_versions", "flash_version"),
        ("python_versions", "python_version"),
        ("torch_versions", "torch_version"),
        ("cuda_versions", "cuda_version"),
    )

    by_platform: dict[str, dict[str, set[str]]] = {}
    for asset in assets:
        filename = asset.get("name", "")
        if not filename.endswith(".whl"):
            continue

        parsed = parse_wheel_filename(filename)
        if not parsed:
            continue

        key = normalize_platform_name(parsed["platform"])
        if key not in by_platform:
            by_platform[key] = {bucket: set() for bucket, _ in field_map}

        for bucket, source in field_map:
            by_platform[key][bucket].add(parsed[source])

    return by_platform
|
|
|
|
|
|
def format_versions(values: set[str]) -> str:
    """Format a set of version strings as a comma-separated, version-ordered string.

    Values are ordered numerically via ``parse_numeric_version`` rather than
    as plain strings. A lexicographic sort would put "3.10" before "3.9" and
    "2.10" before "2.9". Ties on the numeric key fall back to string order,
    so the output is deterministic.

    Args:
        values: Set of version strings.

    Returns:
        Comma-separated string such as "3.9, 3.10, 3.11", or "-" if empty.
    """
    if not values:
        return "-"
    return ", ".join(sorted(values, key=lambda v: (parse_numeric_version(v), v)))
|
|
|
|
|
|
def parse_wheel_filename(filename: str) -> dict | None:
|
|
"""
|
|
Extract information from a wheel filename.
|
|
Examples:
|
|
flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl
|
|
flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
|
|
flash_attn-2.7.4.post1+cu130torch2.9-cp310-cp310-linux_x86_64.whl
|
|
flash_attn-2.8.3+cu128torch2.9-cp313-cp313-manylinux_2_34_x86_64.whl
|
|
flash_attn-2.6.3+cu128torch2.9-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
|
|
|
|
---
|
|
Wheel filename から情報を抽出
|
|
例: flash_attn-2.6.3+cu124torch2.5-cp311-cp311-linux_x86_64.whl
|
|
flash_attn-2.7.4+cu124torch2.6-cp311-cp311-linux_x86_64.whl
|
|
flash_attn-2.7.4.post1+cu130torch2.9-cp310-cp310-linux_x86_64.whl
|
|
flash_attn-2.6.3+cu128torch2.9-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
|
|
"""
|
|
# Flash Attention wheelのパターンに合わせて正規表現を調整
|
|
# PyTorchバージョンはマイナーバージョン1桁の形式も対応 (例: torch2.9)
|
|
# post1 のようなバージョンサフィックスにも対応 (例: 2.7.4.post1)
|
|
# manylinux の複数タグにも対応 (例: manylinux_2_24_x86_64.manylinux_2_28_x86_64)
|
|
pattern = r"flash_attn-(\d+\.\d+\.\d+(?:\.[a-z0-9]+)?)\+cu(\d+)torch(\d+\.\d+)-cp(\d+)-cp\d+-(.+?)\.whl"
|
|
match = re.match(pattern, filename)
|
|
|
|
if match:
|
|
flash_version = match.group(1)
|
|
cuda_version = f"{match.group(2)[:2]}.{match.group(2)[2:]}" # 130 -> 13.0
|
|
torch_version = match.group(3)
|
|
python_version = f"{match.group(4)[:1]}.{match.group(4)[1:]}" # 310 -> 3.10
|
|
platform = match.group(5) # linux_x86_64, win32など
|
|
|
|
return {
|
|
"flash_version": flash_version,
|
|
"cuda_version": cuda_version,
|
|
"torch_version": torch_version,
|
|
"python_version": python_version,
|
|
"platform": platform,
|
|
}
|
|
return None
|
|
|
|
|
|
# Architecture spellings rewritten to the names used for display.
_ARCH_ALIASES = {"amd64": "x86_64", "aarch64": "arm64"}


def normalize_platform_name(raw: str) -> str:
    """Normalize a wheel platform tag into a human-readable platform name.

    Examples:
        linux -> Linux
        linux_x86_64 -> Linux x86_64
        linux_aarch64 -> Linux arm64
        manylinux_2_34_x86_64 -> Manylinux 2_34 x86_64
        manylinux_2_17_aarch64 -> Manylinux 2_17 arm64
        manylinux_2_24_x86_64.manylinux_2_28_x86_64 -> Manylinux 2_24 x86_64
        win32 -> Windows
        win_amd64 -> Windows x86_64
        amd64 -> x86_64

    Args:
        raw: Platform tag as it appears in a wheel filename.

    Returns:
        Display name for the platform.
    """
    # A compressed tag set (several manylinux tags joined by "."): use only
    # the first tag.
    if raw.startswith("manylinux") and "." in raw:
        raw = raw.split(".", 1)[0]

    # manylinux_X_Y_ARCH -> "Manylinux X_Y ARCH". ARCH may itself contain
    # underscores (x86_64), so everything after X and Y is rejoined. Legacy
    # tags such as manylinux2014_x86_64 have fewer parts and fall through.
    if raw.startswith("manylinux"):
        parts = raw.split("_")
        if len(parts) >= 4:
            version = f"{parts[1]}_{parts[2]}"
            arch = "_".join(parts[3:])
            return f"Manylinux {version} {_ARCH_ALIASES.get(arch, arch)}"

    # A bare architecture tag has no OS part. Map it directly; the generic
    # path below would capitalize it ("amd64" -> "Amd64") and miss the alias.
    if raw in _ARCH_ALIASES:
        return _ARCH_ALIASES[raw]

    # "win32" has no "_" separator; the generic path would yield "Windows32".
    if raw == "win32":
        return "Windows"

    name = raw[:1].upper() + raw[1:]  # linux -> Linux
    name = name.replace("_", " ", 1)  # linux_x86_64 -> Linux x86_64
    # Expand only a leading "Win" so an already-expanded "Windows" is untouched.
    if name.startswith("Win") and not name.startswith("Windows"):
        name = "Windows" + name[3:]  # Win amd64 -> Windows amd64
    for alias, canonical in _ARCH_ALIASES.items():
        name = name.replace(alias, canonical)
    return name
|