feat: add package status checker script

- Add `check_missing_packages.py` to compare GitHub releases with expected matrix and display status table.
- Update `.gitignore` to exclude `assets.json` and `.claude/plans`.
- Remove obsolete plan document for pandas 3.0 fix.
This commit is contained in:
Junya Morioka
2026-01-25 17:33:51 +09:00
parent d85d2aad44
commit d32a4c9b60
3 changed files with 586 additions and 81 deletions
+583
View File
@@ -0,0 +1,583 @@
#!/usr/bin/env python3
"""Check missing packages by comparing GitHub releases with expected matrix.
This script fetches wheel assets from GitHub releases and compares them with
the expected package matrix defined in create_matrix.py. It displays a colored
table showing which packages exist, are missing, or are excluded.
Usage:
python check_missing_packages.py
python check_missing_packages.py --cache
python check_missing_packages.py --platform linux --flash-version 2.8.3
python check_missing_packages.py --show-missing-only
"""
import argparse
import json
import os
import sys
import time
from pathlib import Path
import requests
from rich.console import Console
from rich.table import Table
from rich.text import Text
from common import parse_wheel_filename
from create_matrix import (
EXCLUDE,
LINUX_ARM64_MATRIX,
LINUX_ARM64_SELF_HOSTED_MATRIX,
LINUX_MATRIX,
LINUX_SELF_HOSTED_MATRIX,
WINDOWS_CODEBUILD_MATRIX,
WINDOWS_MATRIX,
WINDOWS_SELF_HOSTED_MATRIX,
)
# Comprehensive matrix combining all platform-specific matrices
def get_comprehensive_matrix(platform: str) -> dict:
    """Build the combined build matrix for one platform.

    All matrix variants defined for the platform (the base matrix, the
    self-hosted matrix and, for Windows, the CodeBuild matrix) are folded
    into one matrix with ``merge_matrices``.

    Args:
        platform: ``"linux"``, ``"linux_arm64"`` or ``"windows"``.

    Returns:
        The merged matrix, or an empty dict for any other platform name.
    """
    sources_by_platform = {
        "linux": [LINUX_MATRIX, LINUX_SELF_HOSTED_MATRIX],
        "linux_arm64": [LINUX_ARM64_MATRIX, LINUX_ARM64_SELF_HOSTED_MATRIX],
        "windows": [
            WINDOWS_MATRIX,
            WINDOWS_SELF_HOSTED_MATRIX,
            WINDOWS_CODEBUILD_MATRIX,
        ],
    }
    sources = sources_by_platform.get(platform)
    if sources is None:
        return {}
    return merge_matrices(sources)
def merge_matrices(matrices: list[dict]) -> dict:
    """Union the version lists of several matrices, axis by axis.

    Args:
        matrices: Matrix dicts. An axis missing from a matrix contributes
            nothing to the result.

    Returns:
        A dict with the four axes ``flash-attn-version``, ``python-version``,
        ``torch-version`` and ``cuda-version``. Each value is a de-duplicated
        list ordered by ``parse_version_tuple``.
    """
    axes = (
        "flash-attn-version",
        "python-version",
        "torch-version",
        "cuda-version",
    )
    combined = {}
    for axis in axes:
        unique_versions = set()
        for matrix in matrices:
            unique_versions.update(matrix.get(axis, []))
        # Numeric ordering, so that e.g. "2.10" sorts after "2.9".
        combined[axis] = sorted(unique_versions, key=parse_version_tuple)
    return combined
def parse_version_tuple(version: str) -> tuple:
    """Turn a dotted version string into a tuple of ints usable as a sort key.

    The text ``"post"`` is treated as an extra separator. Any component that
    ``int()`` cannot parse (including an empty one) contributes ``0``.

    Example: 2.8.3 -> (2, 8, 3), 2.6.3.post1 -> (2, 6, 3, 0, 1)
    """

    def component_value(component: str) -> int:
        try:
            return int(component)
        except ValueError:
            return 0

    dotted = version.replace("post", ".")
    return tuple(component_value(piece) for piece in dotted.split("."))
def get_github_token() -> str | None:
"""Get GitHub token from environment variable."""
token = os.environ.get("GITHUB_TOKEN")
if not token:
print(
"Warning: GITHUB_TOKEN not set. API rate limit will be restricted.",
file=sys.stderr,
)
return token
def fetch_all_releases(repo: str, token: str | None = None) -> list[dict]:
    """Download every release of a GitHub repository, page by page.

    Pages of 100 releases are requested until a page comes back empty or
    short. Progress is reported on stderr, and there is a 0.5 s pause between
    consecutive page requests.

    Args:
        repo: Repository in ``owner/name`` form, inserted into the API URL.
        token: Optional token sent in the ``Authorization`` header.

    Returns:
        The release objects collected so far. On a non-200 response the error
        is printed to stderr and the pages already fetched are returned.
    """
    page_size = 100
    endpoint = f"https://api.github.com/repos/{repo}/releases"

    request_headers = {"Authorization": f"token {token}"} if token else {}
    request_headers["Accept"] = "application/vnd.github.v3+json"

    collected: list[dict] = []
    page_number = 1
    while True:
        print(f"Fetching releases page {page_number}...", file=sys.stderr)
        response = requests.get(
            endpoint,
            headers=request_headers,
            params={"page": page_number, "per_page": page_size},
            timeout=30,
        )
        if response.status_code != 200:
            print(
                f"Error fetching releases: {response.status_code} - {response.text}",
                file=sys.stderr,
            )
            break

        batch = response.json()
        if not batch:
            break
        collected.extend(batch)
        print(f" Found {len(batch)} releases on page {page_number}", file=sys.stderr)

        # A short page means this was the last one; skip the pause and stop.
        if len(batch) < page_size:
            break
        page_number += 1
        time.sleep(0.5)
    return collected
def extract_assets_from_releases(releases: list[dict]) -> list[dict]:
    """Collect the wheel files attached to a list of releases.

    Args:
        releases: Release dicts; each may carry an ``"assets"`` list.

    Returns:
        One ``{"name": ..., "url": ...}`` dict per asset whose name ends in
        ``.whl``, in release order. ``url`` comes from the asset's
        ``browser_download_url`` and is ``""`` when that key is absent.
    """
    return [
        {"name": asset_name, "url": asset.get("browser_download_url", "")}
        for release in releases
        for asset in release.get("assets", [])
        if (asset_name := asset.get("name", "")).endswith(".whl")
    ]
def load_or_fetch_assets(repo: str, cache_path: Path, use_cache: bool) -> list[dict]:
    """Return the wheel assets of a repository, using a JSON cache when asked.

    Args:
        repo: Repository in ``owner/name`` form.
        cache_path: Location of the JSON cache file.
        use_cache: When true, read ``cache_path`` if it exists; otherwise
            fetch from GitHub and write the result to ``cache_path``.

    Returns:
        The list of asset dicts, taken from the cache file's ``"assets"`` key
        (``[]`` if the key is absent) or freshly extracted from the releases.
    """
    if use_cache and cache_path.exists():
        print(f"Loading assets from cache: {cache_path}", file=sys.stderr)
        cached = json.loads(cache_path.read_text(encoding="utf-8"))
        return cached.get("assets", [])

    token = get_github_token()
    print(f"Fetching all releases from {repo}...", file=sys.stderr)
    releases = fetch_all_releases(repo, token)
    print(f"Total releases found: {len(releases)}", file=sys.stderr)

    wheel_assets = extract_assets_from_releases(releases)
    print(f"Total wheel assets found: {len(wheel_assets)}", file=sys.stderr)

    if not use_cache:
        return wheel_assets

    print(f"Saving assets to cache: {cache_path}", file=sys.stderr)
    with cache_path.open("w", encoding="utf-8") as cache_file:
        json.dump({"assets": wheel_assets}, cache_file, indent=2, ensure_ascii=False)
    return wheel_assets
def is_excluded(
    flash_version: str,
    python_version: str,
    torch_version: str,
    cuda_version: str,
) -> bool:
    """Tell whether a version combination is covered by an ``EXCLUDE`` entry.

    An entry matches when every one of the four version keys it contains
    equals the corresponding argument. Keys an entry omits act as wildcards,
    so an entry with none of the four keys matches every combination.

    Returns:
        ``True`` if at least one entry of ``EXCLUDE`` matches.
    """
    candidate = {
        "flash-attn-version": flash_version,
        "python-version": python_version,
        "torch-version": torch_version,
        "cuda-version": cuda_version,
    }
    return any(
        all(
            field not in rule or rule[field] == value
            for field, value in candidate.items()
        )
        for rule in EXCLUDE
    )
def normalize_platform_for_comparison(platform_raw: str) -> str:
    """Map a wheel platform tag onto the platform names used for comparison.

    Matching is case-insensitive and substring based, checked in this order:
    macOS tags, then Windows, then ARM64 Linux, then x86_64 Linux.

    Args:
        platform_raw: Platform tag as found in a wheel filename, for example
            ``"linux_x86_64"``, ``"manylinux_2_28_aarch64"`` or ``"win_amd64"``.

    Returns:
        ``"linux"``, ``"linux_arm64"`` or ``"windows"``. Any tag that matches
        none of them (including macOS tags) is returned lower-cased.
    """
    platform_lower = platform_raw.lower()

    # macOS tags must be ruled out first: "darwin" contains "win", and
    # "macosx_*_arm64" / "macosx_*_x86_64" would otherwise be mistaken for
    # Windows or Linux builds by the substring checks below.
    if "darwin" in platform_lower or "macos" in platform_lower:
        return platform_lower

    if "win" in platform_lower:
        return "windows"
    if "aarch64" in platform_lower or "arm64" in platform_lower:
        return "linux_arm64"
    if "x86_64" in platform_lower or "linux" in platform_lower:
        return "linux"
    return platform_lower
def build_existing_packages_set(assets: list[dict]) -> dict[str, set[tuple]]:
    """Index the published wheels by platform.

    Each asset name is parsed with ``parse_wheel_filename``. Assets that do
    not parse, or whose platform is not one of the three tracked ones, are
    skipped.

    Args:
        assets: Asset dicts carrying a ``"name"`` key.

    Returns:
        Dict mapping ``"linux"``, ``"linux_arm64"`` and ``"windows"`` to a
        set of (flash, python, torch, cuda) tuples.
    """
    by_platform: dict[str, set[tuple]] = {
        name: set() for name in ("linux", "linux_arm64", "windows")
    }
    for asset in assets:
        parsed = parse_wheel_filename(asset.get("name", ""))
        if not parsed:
            continue
        bucket = by_platform.get(normalize_platform_for_comparison(parsed["platform"]))
        if bucket is None:
            continue
        # The torch component is stored exactly as parsed from the filename.
        # NOTE(review): presumably major.minor only (e.g. "2.9"), while the
        # matrix uses full versions - confirm against parse_wheel_filename.
        bucket.add(
            (
                parsed["flash_version"],
                parsed["python_version"],
                parsed["torch_version"],
                parsed["cuda_version"],
            )
        )
    return by_platform
def normalize_torch_version(version: str) -> str:
    """Reduce a torch version to its ``major.minor`` prefix.

    A version with fewer than two dot-separated components is returned
    unchanged.

    Example: 2.9.1 -> 2.9, 2.10.0 -> 2.10
    """
    leading = version.split(".")[:2]
    if len(leading) < 2:
        return version
    return ".".join(leading)
def generate_expected_matrix(matrix: dict) -> list[tuple]:
    """Expand a matrix definition into every version combination.

    Args:
        matrix: Dict with the axes ``flash-attn-version``, ``python-version``,
            ``torch-version`` and ``cuda-version``; a missing axis is treated
            as empty.

    Returns:
        ``(flash, python, torch, cuda)`` tuples, with the flash-attn version
        varying slowest and the CUDA version fastest.
    """
    flash_versions = matrix.get("flash-attn-version", [])
    python_versions = matrix.get("python-version", [])
    torch_versions = matrix.get("torch-version", [])
    cuda_versions = matrix.get("cuda-version", [])
    return [
        (flash, python, torch, cuda)
        for flash in flash_versions
        for python in python_versions
        for torch in torch_versions
        for cuda in cuda_versions
    ]
def create_status_table(
    platform_name: str,
    flash_version: str,
    matrix: dict,
    existing: set[tuple],
    console: Console,
) -> tuple[Table, int, int, int]:
    """Create a rich table for a specific platform and flash-attn version.

    Rows are Python versions and columns are every torch/CUDA pair of the
    matrix. Each cell shows one of three states: a dim ``-`` for an excluded
    combination, a green ``✓`` for a published wheel, and a red ``✗`` for a
    missing one.

    Args:
        platform_name: Label used in the table title.
        flash_version: flash-attn version the table is built for.
        matrix: Matrix dict providing the ``python-version``,
            ``torch-version`` and ``cuda-version`` axes.
        existing: Published wheels as (flash, python, torch minor, cuda)
            tuples.
        console: Not used by this function; kept so existing callers keep
            working.

    Returns:
        Tuple of (table, existing_count, missing_count, excluded_count)
    """
    python_versions = sorted(matrix.get("python-version", []), key=parse_version_tuple)
    torch_versions = sorted(matrix.get("torch-version", []), key=parse_version_tuple)
    cuda_versions = sorted(matrix.get("cuda-version", []), key=parse_version_tuple)

    # One entry per table column, computed once and shared by the header and
    # every row: (full torch version, torch major.minor, CUDA version).
    columns = [
        (torch, normalize_torch_version(torch), cuda)
        for torch in torch_versions
        for cuda in cuda_versions
    ]

    table = Table(
        title=f"{platform_name} - Flash-Attention {flash_version}",
        show_header=True,
        header_style="bold cyan",
        border_style="dim",
    )
    table.add_column("Python", style="bold", justify="center")
    for _, torch_minor, cuda in columns:
        table.add_column(
            f"T{torch_minor}\nCU{cuda}",
            justify="center",
            min_width=6,
        )

    existing_count = 0
    missing_count = 0
    excluded_count = 0

    for python in python_versions:
        row = [f"cp{python.replace('.', '')}"]
        for torch, torch_minor, cuda in columns:
            # Exclusions are checked against the full torch version, while
            # published wheels are looked up by major.minor only.
            if is_excluded(flash_version, python, torch, cuda):
                cell = Text("-", style="dim")
                excluded_count += 1
            elif (flash_version, python, torch_minor, cuda) in existing:
                # A visible marker is required: an empty Text renders as a
                # blank cell regardless of its style.
                cell = Text("✓", style="bold green")
                existing_count += 1
            else:
                cell = Text("✗", style="bold red")
                missing_count += 1
            row.append(cell)
        table.add_row(*row)

    return table, existing_count, missing_count, excluded_count
def display_platform_tables(
    platform: str,
    matrix: dict,
    existing_packages: set[tuple],
    console: Console,
    flash_version_filter: str | None = None,
    show_missing_only: bool = False,
) -> dict:
    """Print one status table per flash-attn version and tally the results.

    Args:
        platform: Platform key (``"linux"``, ``"linux_arm64"``, ``"windows"``);
            other values are used verbatim as the display name.
        matrix: Merged matrix for the platform.
        existing_packages: Published wheels for the platform as
            (flash, python, torch minor, cuda) tuples.
        console: Rich console the tables are printed to.
        flash_version_filter: When set, only this flash-attn version is shown.
        show_missing_only: When true, tables without missing wheels are not
            printed (they still count towards the totals).

    Returns:
        Dict with the integer totals ``existing``, ``missing`` and
        ``excluded``, plus ``missing_packages``, a list of dicts describing
        each missing combination.
    """
    display_names = {
        "linux": "🐧 Linux x86_64",
        "linux_arm64": "🐧 Linux ARM64",
        "windows": "🪟 Windows",
    }
    title_name = display_names.get(platform, platform)

    versions_to_show = matrix.get("flash-attn-version", [])
    if flash_version_filter:
        versions_to_show = [v for v in versions_to_show if v == flash_version_filter]

    totals = {"existing": 0, "missing": 0, "excluded": 0}
    missing_packages: list[dict] = []

    for flash_version in versions_to_show:
        table, n_existing, n_missing, n_excluded = create_status_table(
            title_name,
            flash_version,
            matrix,
            existing_packages,
            console,
        )
        totals["existing"] += n_existing
        totals["missing"] += n_missing
        totals["excluded"] += n_excluded

        has_missing = n_missing > 0
        if has_missing:
            # Re-walk the matrix to record which combinations are missing;
            # the table itself only reports how many.
            missing_packages.extend(
                {
                    "platform": platform,
                    "flash_version": flash_version,
                    "python_version": python,
                    "torch_version": torch,
                    "cuda_version": cuda,
                }
                for python in matrix.get("python-version", [])
                for torch in matrix.get("torch-version", [])
                for cuda in matrix.get("cuda-version", [])
                if not is_excluded(flash_version, python, torch, cuda)
                and (flash_version, python, normalize_torch_version(torch), cuda)
                not in existing_packages
            )

        if has_missing or not show_missing_only:
            console.print(table)
            console.print()

    return {**totals, "missing_packages": missing_packages}
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the package status checker."""
    parser = argparse.ArgumentParser(
        description="Check missing packages by comparing GitHub releases with expected matrix"
    )
    parser.add_argument(
        "--repo",
        type=str,
        default="mjun0812/flash-attention-prebuild-wheels",
        help="GitHub repository (default: mjun0812/flash-attention-prebuild-wheels)",
    )
    parser.add_argument(
        "--cache",
        action="store_true",
        help="Use assets.json as cache (load if exists, save after fetch)",
    )
    parser.add_argument(
        "--cache-file",
        type=str,
        default="assets.json",
        help="Cache file path (default: assets.json)",
    )
    parser.add_argument(
        "--platform",
        type=str,
        choices=["linux", "linux_arm64", "windows", "all"],
        default="all",
        help="Platform to display (default: all)",
    )
    parser.add_argument(
        "--flash-version",
        type=str,
        help="Filter by specific flash-attn version",
    )
    parser.add_argument(
        "--show-missing-only",
        action="store_true",
        help="Only show tables with missing packages",
    )
    parser.add_argument(
        "--list-missing",
        action="store_true",
        help="List all missing packages at the end",
    )
    return parser


def _format_coverage(existing: int, total: int) -> str:
    """Format ``existing / total`` as a percentage, or ``N/A`` for an empty total."""
    return f"{existing / total * 100:.1f}%" if total > 0 else "N/A"


def _print_summary(console: Console, all_stats: dict) -> None:
    """Print the per-platform summary table, ending with a TOTAL row."""
    console.rule("[bold blue]Summary", style="blue")
    console.print()

    summary_table = Table(show_header=True, header_style="bold")
    summary_table.add_column("Platform", style="bold")
    summary_table.add_column("Existing", justify="right", style="green")
    summary_table.add_column("Missing", justify="right", style="red")
    summary_table.add_column("Excluded", justify="right", style="dim")
    summary_table.add_column("Coverage", justify="right")

    total_existing = 0
    total_missing = 0
    total_excluded = 0
    for platform, stats in all_stats.items():
        existing = stats["existing"]
        missing = stats["missing"]
        excluded = stats["excluded"]
        total_existing += existing
        total_missing += missing
        total_excluded += excluded

        # Green: nothing missing; yellow: more present than missing; red otherwise.
        if missing == 0:
            coverage_style = "green"
        elif existing > missing:
            coverage_style = "yellow"
        else:
            coverage_style = "red"

        # Excluded combinations are left out of the coverage denominator.
        coverage = _format_coverage(existing, existing + missing)
        summary_table.add_row(
            platform,
            str(existing),
            str(missing),
            str(excluded),
            Text(coverage, style=coverage_style),
        )

    grand_coverage = _format_coverage(total_existing, total_existing + total_missing)
    summary_table.add_row(
        Text("TOTAL", style="bold"),
        Text(str(total_existing), style="bold green"),
        Text(str(total_missing), style="bold red"),
        Text(str(total_excluded), style="dim"),
        Text(grand_coverage, style="bold"),
    )
    console.print(summary_table)
    console.print()


def _print_missing_packages(console: Console, all_missing: list[dict]) -> None:
    """Print every missing package, sorted by platform and then by versions."""
    console.rule("[bold red]Missing Packages", style="red")
    console.print()

    missing_table = Table(show_header=True, header_style="bold red")
    missing_table.add_column("Platform")
    missing_table.add_column("Flash-Attn")
    missing_table.add_column("Python")
    missing_table.add_column("Torch")
    missing_table.add_column("CUDA")

    ordered = sorted(
        all_missing,
        key=lambda x: (
            x["platform"],
            parse_version_tuple(x["flash_version"]),
            parse_version_tuple(x["python_version"]),
            parse_version_tuple(x["torch_version"]),
            parse_version_tuple(x["cuda_version"]),
        ),
    )
    for pkg in ordered:
        missing_table.add_row(
            pkg["platform"],
            pkg["flash_version"],
            pkg["python_version"],
            pkg["torch_version"],
            pkg["cuda_version"],
        )
    console.print(missing_table)
    console.print()


def main() -> None:
    """Run the checker: load assets, print per-platform tables and a summary.

    Command-line options select the repository, the cache file, the platform
    and flash-attn version to show, whether tables without missing wheels are
    hidden, and whether the missing packages are listed at the end.
    """
    args = _build_arg_parser().parse_args()
    console = Console()
    cache_path = Path(args.cache_file)

    assets = load_or_fetch_assets(args.repo, cache_path, args.cache)
    existing_packages = build_existing_packages_set(assets)

    platforms = ["linux", "linux_arm64", "windows"]
    if args.platform != "all":
        platforms = [args.platform]

    all_stats = {}
    all_missing = []
    console.print()
    console.rule("[bold blue]Flash-Attention Package Status", style="blue")
    console.print()
    for platform in platforms:
        matrix = get_comprehensive_matrix(platform)
        # A platform without any flash-attn version has nothing to report.
        if not matrix.get("flash-attn-version"):
            continue
        stats = display_platform_tables(
            platform,
            matrix,
            existing_packages.get(platform, set()),
            console,
            flash_version_filter=args.flash_version,
            show_missing_only=args.show_missing_only,
        )
        all_stats[platform] = stats
        all_missing.extend(stats["missing_packages"])

    _print_summary(console, all_stats)

    if args.list_missing and all_missing:
        _print_missing_packages(console, all_missing)
# Run the checker only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()