Files
flash-attention-prebuild-wh…/generate_packages_table.py
T
Junya Morioka 39ea2fb7c0 update scripts
2025-08-10 17:05:29 +09:00

399 lines
14 KiB
Python

"""
python generate_packages_table.py --update-readme
"""
import argparse
import itertools
import re
import sys
from pathlib import Path
import pandas as pd
def parse_numeric_version(text: str) -> tuple:
    """Return the digit runs found in ``text`` as a tuple of ints.

    Every non-digit character acts purely as a separator, so ``"2.6.3"``
    becomes ``(2, 6, 3)`` and ``"cu124"`` becomes ``(124,)``. Text without
    any digits yields an empty tuple. The result is meant to be used as a
    sort key.
    """
    return tuple(map(int, re.findall(r"\d+", text)))
def extract_packages_from_history(text: str) -> list[dict]:
    """Collect one record per package combination listed under the History heading.

    Scanning starts at the first line that begins with ``## `` and contains
    ``History``; when no such line exists an empty list is returned. From
    that line onwards:

    * a ``### `` heading clears the current release URL and resets the OS
      to ``"Linux x86_64"``;
    * a line containing ``[Release](url)`` sets the current release URL;
    * a ``#### `` heading sets the current OS (the default is used when the
      heading text is empty);
    * a table whose header row starts with ``| Flash-Attention`` or
      ``|Flash-Attention`` is expanded: the first four cells of each row
      are split on commas and every combination of the resulting values
      becomes one record. Rows with fewer than four non-empty cells are
      ignored.

    Each record is a dict with the keys ``Flash-Attention``, ``Python``,
    ``PyTorch``, ``CUDA``, ``OS`` and ``package`` (the release URL in effect,
    or ``None``).
    """
    default_os = "Linux x86_64"
    all_lines = text.splitlines()

    # Everything before the History heading is irrelevant to the scan.
    start = next(
        (
            idx
            for idx, raw in enumerate(all_lines)
            if raw.strip().startswith("## ") and "History" in raw
        ),
        None,
    )
    if start is None:
        return []
    history = all_lines[start:]

    def split_versions(cell):
        # "3.10, 3.11" -> ["3.10", "3.11"]; blank entries are discarded.
        return [part.strip() for part in cell.split(",") if part.strip()]

    records = []
    release_url = None
    os_name = default_os
    total = len(history)
    pos = 0
    while pos < total:
        line = history[pos].strip()
        pos += 1  # the cursor now points at the line after the one just read

        if line.startswith("### "):
            # A new version entry starts: forget per-version state.
            release_url = None
            os_name = default_os
        elif "[Release](" in line:
            found = re.search(r"\[Release\]\(([^)]+)\)", line)
            if found:
                release_url = found.group(1)
        elif line.startswith("#### "):
            os_name = line[5:].strip() or default_os
        elif line.startswith(("| Flash-Attention", "|Flash-Attention")):
            # The header row was consumed above; step over the separator row.
            pos += 1
            while pos < total:
                row = history[pos].strip()
                if not row.startswith("|"):
                    # Not part of the table: leave the cursor here so the
                    # outer loop still inspects this line.
                    break
                pos += 1
                cells = [
                    cell for cell in (part.strip() for part in row.split("|")) if cell
                ]
                if len(cells) < 4:
                    continue
                combos = itertools.product(
                    *(split_versions(cell) for cell in cells[:4])
                )
                for fa, py, pt, cu in combos:
                    records.append(
                        {
                            "Flash-Attention": fa,
                            "Python": py,
                            "PyTorch": pt,
                            "CUDA": cu,
                            "OS": os_name,
                            "package": release_url,
                        }
                    )

    return records
def sort_packages(df: pd.DataFrame) -> pd.DataFrame:
    """Return the rows of ``df`` in display order without modifying ``df``.

    Sort priority is Flash-Attention > OS > Python > PyTorch > CUDA > package:

    * Flash-Attention, Python, PyTorch and CUDA are compared by the numbers
      in their text, newest first;
    * OS is compared case-insensitively in ascending order;
    * package URLs ending in ``/tag/<tag>`` are compared by the numbers in
      the tag, newest first; missing URLs and URLs without such a tag come
      last.

    Args:
        df: Frame with the columns ``Flash-Attention``, ``Python``,
            ``PyTorch``, ``CUDA``, ``OS`` and ``package``.

    Returns:
        A new, sorted DataFrame with the same columns as ``df``. The helper
        sort columns exist only on an internal copy, so the caller's frame
        is left exactly as it was passed in.
    """

    def newest_first(column: str) -> pd.Series:
        # Negating each component makes an ascending sort put newer versions first.
        return df[column].apply(
            lambda text: tuple(-part for part in parse_numeric_version(text))
        )

    def package_sort_key(url):
        # The leading 0/1 flag pushes rows without a usable release tag to the end.
        if pd.isna(url) or not url:
            return (1, tuple())
        tag_match = re.search(r"/tag/([^/]+)$", str(url))
        if not tag_match:
            return (1, tuple())
        version_tuple = parse_numeric_version(tag_match.group(1))
        return (0, tuple(-part for part in version_tuple))

    sort_columns = ["fa_sort", "os_sort", "py_sort", "pt_sort", "cu_sort", "pkg_sort"]

    # assign() builds a new frame, so the helper columns never leak onto ``df``.
    keyed = df.assign(
        fa_sort=newest_first("Flash-Attention"),
        os_sort=df["OS"].str.lower(),
        py_sort=newest_first("Python"),
        pt_sort=newest_first("PyTorch"),
        cu_sort=newest_first("CUDA"),
        pkg_sort=df["package"].apply(package_sort_key),
    )
    return keyed.sort_values(sort_columns).drop(columns=sort_columns)
def merge_duplicate_rows(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse rows that share Flash-Attention, Python, PyTorch, CUDA and OS.

    Args:
        df: Frame containing the five grouping columns plus ``package``.
            Any additional columns are carried over from the first row of
            each group.

    Returns:
        A new DataFrame with a fresh 0..n-1 index and one row per distinct
        combination of the grouping columns. Groups appear in the order in
        which they first occur in ``df``, so an ordering applied beforehand
        is preserved. The ``package`` cell of every row is a list holding
        the group's unique, non-null, non-empty package values in order of
        appearance, or ``[None]`` when the group has none.
    """
    group_cols = ["Flash-Attention", "Python", "PyTorch", "CUDA", "OS"]
    other_cols = [col for col in df.columns if col not in group_cols]

    records = []
    # sort=False keeps first-appearance order; the default would re-sort the
    # groups by their string keys and discard the caller's ordering.
    for _, group in df.groupby(group_cols, sort=False):
        # The first row supplies the key values and any extra columns.
        record = group.iloc[0].to_dict()
        urls = [url for url in group["package"].dropna().unique() if url]
        record["package"] = urls if urls else [None]
        records.append(record)

    # Building the frame from plain records avoids groupby.apply() and
    # writing a list into a single Series cell.
    return pd.DataFrame(records, columns=[*group_cols, *other_cols])
def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
    """Render ``df`` as Markdown: one section per OS, one table per Flash-Attention version.

    OS names are emitted in ascending order as ``### <OS>`` headings. Within
    an OS, rows are ordered newest-first by Flash-Attention, Python, PyTorch
    and CUDA. Each Flash-Attention version gets a ``#### Flash-Attention
    <version>`` heading followed by a ``<details>`` block wrapping a
    ``Python | PyTorch | CUDA | package`` table.

    The ``package`` value of a row may be a list (rendered as numbered
    ``[ReleaseN](url)`` links, with missing or empty entries skipped) or a
    single value (rendered as ``[Release](url)``); a row without any usable
    value shows ``-``.

    Returns an empty string when ``df`` has no rows.
    """
    if df.empty:
        return ""

    helper_column = {
        "Flash-Attention": "fa_sort",
        "Python": "py_sort",
        "PyTorch": "pt_sort",
        "CUDA": "cu_sort",
    }

    def newest_first(frame, columns):
        # Sort a copy of ``frame`` by the given version columns, newest first,
        # using temporary helper columns that are dropped again afterwards.
        keyed = frame.copy()
        helpers = [helper_column[name] for name in columns]
        for name, helper in zip(columns, helpers):
            keyed[helper] = keyed[name].apply(
                lambda text: tuple(-part for part in parse_numeric_version(text))
            )
        return keyed.sort_values(helpers).drop(columns=helpers)

    def package_cell(value):
        # Single value (backward compatibility) versus list of release URLs.
        if not isinstance(value, list):
            return f"[Release]({value})" if pd.notna(value) and value else "-"
        links = [
            f"[Release{number}]({url})"
            for number, url in enumerate(value, 1)
            if pd.notna(url) and url
        ]
        return ", ".join(links) if links else "-"

    output = []
    for os_name in sorted(df["OS"].unique()):
        os_rows = newest_first(
            df[df["OS"] == os_name],
            ["Flash-Attention", "Python", "PyTorch", "CUDA"],
        )
        output += [f"### {os_name}", ""]

        for fa_version in os_rows["Flash-Attention"].unique():
            fa_rows = newest_first(
                os_rows[os_rows["Flash-Attention"] == fa_version],
                ["Python", "PyTorch", "CUDA"],
            )
            table = [
                "| Python | PyTorch | CUDA | package |",
                "| ------ | ------- | ---- | ------- |",
            ]
            table += [
                f"| {row['Python']} | {row['PyTorch']} | {row['CUDA']} | {package_cell(row['package'])} |"
                for _, row in fa_rows.iterrows()
            ]
            # The table is collapsed behind a <details> element per version.
            output += [
                f"#### Flash-Attention {fa_version}",
                "",
                "<details>",
                f"<summary>Packages for Flash-Attention {fa_version}</summary>",
                "",
                "\n".join(table),
                "",
                "</details>",
                "",
            ]

    return "\n".join(output)
def generate_markdown_table(df: pd.DataFrame) -> str:
    """Render ``df`` as a single flat Markdown table (legacy, kept for backward compatibility).

    The table has the columns Flash-Attention, Python, PyTorch, CUDA, OS and
    package, with one line per row of ``df`` in the frame's own order.

    The ``package`` value of a row may be a list (rendered as numbered
    ``[ReleaseN](url)`` links, with missing or empty entries skipped) or a
    single value (rendered as ``[Release](url)``); a row without any usable
    value shows ``-``.
    """

    def package_cell(value):
        # Single value (backward compatibility) versus list of release URLs.
        if not isinstance(value, list):
            return f"[Release]({value})" if pd.notna(value) and value else "-"
        links = [
            f"[Release{number}]({url})"
            for number, url in enumerate(value, 1)
            if pd.notna(url) and url
        ]
        return ", ".join(links) if links else "-"

    header = [
        "| Flash-Attention | Python | PyTorch | CUDA | OS | package |",
        "| --------------- | ------ | ------- | ------ | ---- | ------- |",
    ]
    body = [
        f"| {row['Flash-Attention']} | {row['Python']} | {row['PyTorch']} | {row['CUDA']} | {row['OS']} | {package_cell(row['package'])} |"
        for _, row in df.iterrows()
    ]
    return "\n".join(header + body)
def update_readme_packages_section(readme_path: Path, packages_markdown: str) -> None:
    """Rewrite the ``## Packages`` section of the file at ``readme_path`` in place.

    The section starts at the first occurrence of ``## Packages``. It ends at
    the next occurrence of ``## History``; when there is none, at the next
    line starting with ``## ``; and when there is none either, at the end of
    the file. The whole span is replaced by the heading, a blank line,
    ``packages_markdown`` and another blank line.

    Raises:
        RuntimeError: if the file cannot be read or written, or if it has no
            ``## Packages`` section. The message includes the underlying
            error text.
    """
    heading = "## Packages"
    try:
        content = readme_path.read_text(encoding="utf-8")

        start = content.find(heading)
        if start == -1:
            raise ValueError("Packages section not found in README.md")

        end = content.find("## History", start)
        if end == -1:
            # No History section: fall back to the next level-2 heading,
            # searching only after the Packages heading itself.
            next_heading = content.find("\n## ", start + len(heading))
            end = len(content) if next_heading == -1 else next_heading

        before, after = content[:start], content[end:]
        readme_path.write_text(
            before + "## Packages\n\n" + packages_markdown + "\n\n" + after,
            encoding="utf-8",
        )
    except Exception as e:
        raise RuntimeError(f"Failed to update README.md: {e}")
def main() -> None:
    """Command-line entry point.

    Reads the README given on the command line (default ``README.md``),
    extracts the packages listed in its History section, then sorts them,
    merges duplicate rows and renders the per-OS Markdown tables. The result
    is printed to stdout, or written into the README's Packages section when
    ``--update-readme`` is passed.

    When the History section yields no packages, a message is written to
    stderr and the function returns without producing output. Any exception
    raised while processing is reported on stderr and the process exits with
    status 1, so that shells and CI jobs can detect the failure.
    """
    parser = argparse.ArgumentParser(
        description="Generate a one-row-per-package Markdown table from the History section of a README.md"
    )
    parser.add_argument(
        "readme",
        nargs="?",
        type=Path,
        default=Path("README.md"),
        help="Path to README.md (default: README.md)",
    )
    parser.add_argument(
        "--update-readme",
        action="store_true",
        help="Update the Packages section in README.md instead of printing to stdout",
    )
    args = parser.parse_args()

    try:
        with args.readme.open("r", encoding="utf-8") as f:
            text = f.read()

        packages = extract_packages_from_history(text)
        if not packages:
            print("No packages found in History section", file=sys.stderr)
            return

        df = pd.DataFrame(packages)
        df_sorted = sort_packages(df)
        df_merged = merge_duplicate_rows(df_sorted)
        markdown = generate_markdown_table_by_os(df_merged)

        if args.update_readme:
            # Update the README.md file
            update_readme_packages_section(args.readme, markdown)
            print(f"Updated Packages section in {args.readme}")
        else:
            # Print to stdout (original behavior)
            print(markdown)
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        # Report the failure through the exit status instead of exiting 0.
        sys.exit(1)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()