Files
flash-attention-prebuild-wh…/create_packages.py
T
2025-11-26 23:00:21 +09:00

508 lines
17 KiB
Python

"""
Create and update docs/packages.md from assets.json
Usage:
python create_packages.py --assets assets.json --output docs/packages.md
"""
import argparse
import json
import re
import sys
from pathlib import Path
import pandas as pd
from common import normalize_platform_name, parse_wheel_filename
def parse_numeric_version(text: str) -> tuple:
"""Extract numeric version tuple for sorting."""
nums = re.findall(r"\d+", text)
return tuple(int(n) for n in nums)
def normalize_semantic_version(version: str) -> str:
"""Normalize semantic version by removing patch version.
Examples:
2.9.1 -> 2.9
2.8.1 -> 2.8
2.6.3 -> 2.6
2.9 -> 2.9 (no change if no patch version)
"""
if pd.isna(version) or not version:
return version
# Split by '.' and take only major.minor
parts = str(version).split(".")
if len(parts) >= 2:
return ".".join(parts[:2])
return version
def extract_packages_from_packages_md(packages_md_path: Path) -> list[dict]:
"""Extract package information from existing docs/packages.md."""
if not packages_md_path.exists():
return []
with packages_md_path.open("r", encoding="utf-8") as f:
content = f.read()
lines = content.splitlines()
packages = []
current_os = None
current_fa_version = None
in_table = False
for line in lines:
line_stripped = line.strip()
# Detect OS heading (## Linux x86_64)
if line_stripped.startswith("## ") and not line_stripped.startswith("### "):
# Remove emoji from OS name (e.g., "🐧 Linux x86_64" -> "Linux x86_64")
os_name = line_stripped[3:].strip()
while os_name and ord(os_name[0]) > 127:
os_name = os_name[1:].strip()
current_os = os_name
current_fa_version = None
in_table = False
continue
# Detect Flash-Attention version heading (### Flash-Attention 2.8.3)
if line_stripped.startswith("### Flash-Attention "):
current_fa_version = line_stripped.replace(
"### Flash-Attention ", ""
).strip()
in_table = False
continue
# Detect table start
if "| Python | PyTorch | CUDA | package |" in line_stripped:
in_table = True
continue
# Skip table separator line
if in_table and "| ------ |" in line_stripped:
continue
# Process table rows
if (
in_table
and line_stripped.startswith("|")
and current_os
and current_fa_version
):
# Parse table row: | Python | PyTorch | CUDA | package |
cells = [
c.strip() for c in line_stripped.split("|")[1:-1]
] # Remove empty first/last cells
cells = [c for c in cells if c] # Remove empty cells
if len(cells) >= 4:
python_version = cells[0]
torch_version = cells[1]
cuda_version = cells[2]
package_cell = cells[3]
# Extract all URLs from package cell
# Pattern: [Release1](url1), [Download1](url1), [Release](url), [Download](url), ...
# Support both Release and Download patterns for backward compatibility
# Also support version suffix: [Download1(v1.0.0)](url)
package_urls = re.findall(
r"\[(?:Release|Download)\d*(?:\([^)]*\))?\]\(([^)]+)\)",
package_cell,
)
if package_urls:
# Create a package entry for each URL
for package_url in package_urls:
packages.append(
{
"Flash-Attention": current_fa_version,
"Python": python_version,
"PyTorch": torch_version,
"CUDA": cuda_version,
"OS": current_os,
"package": package_url,
}
)
elif package_cell != "-":
# Handle other formats
packages.append(
{
"Flash-Attention": current_fa_version,
"Python": python_version,
"PyTorch": torch_version,
"CUDA": cuda_version,
"OS": current_os,
"package": None,
}
)
# Detect end of table (empty line or closing </details>)
if in_table and (not line_stripped or line_stripped == "</details>"):
in_table = False
return packages
def extract_packages_from_assets_json(assets_path: Path) -> list[dict]:
"""Extract package information from assets.json file."""
with assets_path.open("r", encoding="utf-8") as f:
data = json.load(f)
if "assets" not in data:
return []
packages = []
for asset in data["assets"]:
name = asset.get("name", "")
url = asset.get("url", "")
# Only process .whl files
if not name.endswith(".whl"):
continue
# Parse wheel filename
info = parse_wheel_filename(name)
if not info:
continue
# Normalize platform name
os_name = normalize_platform_name(info["platform"])
# Format versions for display
flash_version = info["flash_version"]
python_version = info["python_version"]
torch_version = info["torch_version"] # Already in format like "2.9"
cuda_version = info["cuda_version"]
packages.append(
{
"Flash-Attention": flash_version,
"Python": python_version,
"PyTorch": torch_version,
"CUDA": cuda_version,
"OS": os_name,
"package": url, # Use download URL directly
}
)
return packages
def sort_packages(
df: pd.DataFrame,
flash_ascending: bool = False,
python_ascending: bool = False,
pytorch_ascending: bool = False,
cuda_ascending: bool = False,
os_ascending: bool = True,
package_ascending: bool = False,
) -> pd.DataFrame:
"""
Sort packages by columns from left to right.
Args:
df: DataFrame to sort
flash_ascending: Sort Flash-Attention in ascending order (default: False, newer first)
python_ascending: Sort Python in ascending order (default: False, newer first)
pytorch_ascending: Sort PyTorch in ascending order (default: False, newer first)
cuda_ascending: Sort CUDA in ascending order (default: False, newer first)
os_ascending: Sort OS in ascending order (default: True, alphabetical)
package_ascending: Sort package in ascending order (default: False, newer first)
Returns:
Sorted DataFrame
"""
df = df.copy()
# Add sorting keys for version columns
df["fa_sort"] = df["Flash-Attention"].apply(parse_numeric_version)
df["py_sort"] = df["Python"].apply(parse_numeric_version)
df["pt_sort"] = df["PyTorch"].apply(parse_numeric_version)
df["cu_sort"] = df["CUDA"].apply(parse_numeric_version)
df["os_sort"] = df["OS"].str.lower()
# Package sort: extract version from download URL
def package_sort_key(url):
# Handle list of URLs (take the first one for sorting)
if isinstance(url, list):
if not url or all(pd.isna(u) or not u for u in url):
return tuple()
# Find first non-empty URL
for u in url:
if pd.notna(u) and u:
url = u
break
else:
return tuple()
if pd.isna(url) or not url:
return tuple() # No URL
# Extract tag from download URL: /releases/download/{tag}/
tag_match = re.search(r"/releases/download/([^/]+)/", str(url))
if not tag_match:
return tuple()
tag = tag_match.group(1)
return parse_numeric_version(tag)
df["pkg_sort"] = df["package"].apply(package_sort_key)
# Sort by columns from left to right
df_sorted = df.sort_values(
by=["fa_sort", "os_sort", "py_sort", "pt_sort", "cu_sort", "pkg_sort"],
ascending=[
flash_ascending,
os_ascending,
python_ascending,
pytorch_ascending,
cuda_ascending,
package_ascending,
],
)
# Drop sorting columns
return df_sorted.drop(
columns=["fa_sort", "os_sort", "py_sort", "pt_sort", "cu_sort", "pkg_sort"]
)
def merge_duplicate_rows(df: pd.DataFrame) -> pd.DataFrame:
"""Merge rows with duplicate Flash-Attention, Python, PyTorch, CUDA, OS values."""
# Group by all columns except 'package'
group_cols = ["Flash-Attention", "Python", "PyTorch", "CUDA", "OS"]
def combine_packages(group):
# Get unique non-null packages (handle both list and scalar values)
all_packages = []
for pkg in group["package"]:
if pd.notna(pkg):
if isinstance(pkg, list):
all_packages.extend(pkg)
else:
all_packages.append(pkg)
# Remove duplicates while preserving order
seen = set()
unique_packages = []
for pkg in all_packages:
if pkg and pkg not in seen:
seen.add(pkg)
unique_packages.append(pkg)
# Take the first row as base
result = group.iloc[0].copy()
# Combine packages into a list
result["package"] = unique_packages if unique_packages else [None]
return result
# Group and combine
merged_df = df.groupby(group_cols, as_index=False).apply(
combine_packages, include_groups=False
)
# Reset index to clean up
merged_df = merged_df.reset_index(drop=True)
return merged_df
def get_tag_from_url(url: str) -> str:
"""Extract tag from GitHub release URL."""
if pd.isna(url) or not url:
return ""
match = re.search(r"/releases/download/([^/]+)/", str(url))
return match.group(1) if match else ""
def get_os_emoji(os_name: str) -> str:
"""Get emoji for OS name."""
os_lower = os_name.lower()
if "linux" in os_lower:
return "🐧 "
elif "windows" in os_lower:
return "🪟 "
else:
return ""
def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
"""Generate markdown tables grouped by OS and Flash-Attention version."""
if df.empty:
return ""
all_sections = []
# Generate Table of Contents
os_names = sorted(df["OS"].unique())
toc_lines = ["## Table of Contents", ""]
for os_name in os_names:
# Create anchor link (lowercase, replace spaces with hyphens)
os_anchor = os_name.lower().replace(" ", "-")
toc_lines.append(f"- [{os_name}](#{os_anchor})")
# Add Flash-Attention versions for this OS (sorted)
os_df = df[df["OS"] == os_name].copy()
os_df = sort_packages(os_df, flash_ascending=False)
for fa_version in os_df["Flash-Attention"].unique():
# Create anchor for Flash-Attention version
fa_anchor = f"flash-attention-{fa_version.replace('.', '')}".lower()
toc_lines.append(f" - [Flash-Attention {fa_version}](#{fa_anchor})")
toc_lines.append("")
all_sections.extend(toc_lines)
# Group by OS and sort each group
for os_name in os_names:
os_df = df[df["OS"] == os_name].copy()
# Sort within OS group: Flash-Attention > Python > PyTorch > CUDA
os_df = sort_packages(
os_df,
flash_ascending=False,
python_ascending=True,
pytorch_ascending=True,
cuda_ascending=True,
)
# Create OS section header with emoji
os_emoji = get_os_emoji(os_name)
os_lines = [f"## {os_emoji}{os_name}", ""]
# Group by Flash-Attention version within each OS
fa_versions = []
for fa_version in os_df["Flash-Attention"].unique():
fa_df = os_df[os_df["Flash-Attention"] == fa_version].copy()
# Sort by Python > PyTorch > CUDA within each Flash-Attention version
fa_df = sort_packages(
fa_df,
python_ascending=True,
pytorch_ascending=True,
cuda_ascending=True,
)
# Create collapsible table for this Flash-Attention version
table_lines = [
"| Python | PyTorch | CUDA | package |",
"| ------ | ------- | ---- | ------- |",
]
for _, row in fa_df.iterrows():
packages = row["package"]
# Handle case where packages is a list
if isinstance(packages, list):
if packages and any(pd.notna(pkg) and pkg for pkg in packages):
# Create numbered download links
package_links = []
for i, pkg in enumerate(packages, 1):
if pd.notna(pkg) and pkg:
tag = get_tag_from_url(pkg)
tag_str = f"({tag})" if tag else ""
package_links.append(f"[Download{i}{tag_str}]({pkg})")
package_cell = ", ".join(package_links)
else:
package_cell = "-"
else:
# Handle single package (backward compatibility)
tag = get_tag_from_url(packages)
tag_str = f"({tag})" if tag else ""
package_cell = (
f"[Download{tag_str}]({packages})"
if pd.notna(packages) and packages
else "-"
)
line = f"| {row['Python']} | {row['PyTorch']} | {row['CUDA']} | {package_cell} |"
table_lines.append(line)
# Create collapsible section for this Flash-Attention version
fa_section = [
f"### Flash-Attention {fa_version}",
"",
"<details>",
f"<summary>Packages for Flash-Attention {fa_version}</summary>",
"",
"\n".join(table_lines),
"",
"</details>",
"",
]
fa_versions.extend(fa_section)
os_lines.extend(fa_versions)
all_sections.extend(os_lines)
return "\n".join(all_sections)
def main() -> None:
parser = argparse.ArgumentParser(
description="Create and update docs/packages.md from assets.json"
)
parser.add_argument(
"--assets",
type=str,
default="assets.json",
help="Path to assets.json file (default: assets.json)",
)
parser.add_argument(
"--output",
type=str,
default="docs/packages.md",
help="Output file path (default: docs/packages.md)",
)
args = parser.parse_args()
assets_path = Path(args.assets)
output_path = Path(args.output)
# Extract packages from assets.json if it exists
assets_packages = []
if assets_path.exists():
assets_packages = extract_packages_from_assets_json(assets_path)
# Extract packages from existing docs/packages.md
packages_md_packages = extract_packages_from_packages_md(output_path)
# Combine both lists
all_packages = assets_packages + packages_md_packages
if not all_packages:
print(f"No packages found in {assets_path} or {output_path}", file=sys.stderr)
return
# Convert to DataFrame and process
df = pd.DataFrame(all_packages)
# Normalize CUDA versions (remove patch version)
df["CUDA"] = df["CUDA"].apply(normalize_semantic_version)
# Normalize PyTorch versions (remove patch version)
df["PyTorch"] = df["PyTorch"].apply(normalize_semantic_version)
# Normalize Python versions (remove patch version)
df["Python"] = df["Python"].apply(normalize_semantic_version)
df_sorted = sort_packages(df)
df_merged = merge_duplicate_rows(df_sorted)
markdown = generate_markdown_table_by_os(df_merged)
# Create parent directory if it doesn't exist
output_path.parent.mkdir(parents=True, exist_ok=True)
# Generate markdown with "# Packages" header for standalone file
standalone_markdown = f"# Packages\n\n{markdown}"
with output_path.open("w", encoding="utf-8") as f:
f.write(standalone_markdown)
print(f"Written packages to {output_path}")
if __name__ == "__main__":
main()