mirror of
https://github.com/BillyOutlast/flash-attention-prebuild-wheels-rocm.git
synced 2026-07-01 01:17:55 -04:00
268c7013f1
Co-authored-by: mjun0812 <77187490+mjun0812@users.noreply.github.com>
558 lines
17 KiB
Python
558 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
"""Check missing packages by comparing GitHub releases with expected matrix.
|
|
|
|
This script fetches wheel assets from GitHub releases and compares them with
|
|
the expected package matrix defined in create_matrix.py. It displays a colored
|
|
table showing which packages exist, are missing, or are excluded.
|
|
|
|
Usage:
|
|
python check_missing_packages.py
|
|
python check_missing_packages.py --cache
|
|
python check_missing_packages.py --platform linux --flash-version 2.8.3
|
|
python check_missing_packages.py --show-missing-only
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import requests
|
|
from rich.console import Console
|
|
from rich.table import Table
|
|
from rich.text import Text
|
|
|
|
from common import parse_wheel_filename
|
|
from coverage_matrix import LINUX_ARM64_MATRIX, LINUX_MATRIX, WINDOWS_MATRIX
|
|
from create_matrix import EXCLUDE
|
|
|
|
|
|
# Look up the comprehensive coverage matrix for a single platform
|
|
def get_comprehensive_matrix(platform: str) -> dict:
    """Return the coverage matrix registered for ``platform``.

    Args:
        platform: Platform key; "linux", "linux_arm64" and "windows" are
            recognised (exact, case-sensitive match).

    Returns:
        The matching matrix imported from ``coverage_matrix``, or an empty
        dict for any other platform name.
    """
    if platform == "linux":
        return LINUX_MATRIX
    if platform == "linux_arm64":
        return LINUX_ARM64_MATRIX
    if platform == "windows":
        return WINDOWS_MATRIX
    # Unrecognised platform: hand back an empty matrix instead of raising.
    return {}
|
|
|
|
|
|
def parse_version_tuple(version: str) -> tuple:
    """Turn a dotted version string into a tuple of ints usable as a sort key.

    Every occurrence of "post" is replaced by "." before splitting on ".",
    so "2.7.4post1" becomes (2, 7, 4, 1). Components that are not valid
    integers (including empty ones) are mapped to 0.

    Args:
        version: Version string such as "3.10" or "2.8.3".

    Returns:
        Tuple with one int per dot-separated component.
    """

    def _to_int(component: str) -> int:
        # Non-numeric pieces must not break sorting, so they collapse to 0.
        try:
            return int(component)
        except ValueError:
            return 0

    pieces = version.replace("post", ".").split(".")
    return tuple(_to_int(piece) for piece in pieces)
|
|
|
|
|
|
def get_github_token() -> str | None:
|
|
"""Get GitHub token from environment variable."""
|
|
token = os.environ.get("GITHUB_TOKEN")
|
|
if not token:
|
|
print(
|
|
"Warning: GITHUB_TOKEN not set. API rate limit will be restricted.",
|
|
file=sys.stderr,
|
|
)
|
|
return token
|
|
|
|
|
|
def fetch_all_releases(repo: str, token: str | None = None) -> list[dict]:
    """Download every release record of a GitHub repository, page by page.

    Pages of 100 releases are requested from the GitHub REST API until an
    empty or short page is returned. Progress is reported on stderr. A
    non-200 response is reported on stderr and ends the loop; whatever was
    collected up to that point is still returned.

    Args:
        repo: Repository in "owner/name" form.
        token: Optional token sent in the ``Authorization`` header.

    Returns:
        List of release dicts as decoded from the API's JSON responses.
    """
    request_headers = {}
    if token:
        request_headers["Authorization"] = f"token {token}"
    request_headers["Accept"] = "application/vnd.github.v3+json"

    # The endpoint does not change between pages, only the query parameters.
    endpoint = f"https://api.github.com/repos/{repo}/releases"
    page_size = 100
    collected = []
    page_number = 1

    while True:
        print(f"Fetching releases page {page_number}...", file=sys.stderr)
        reply = requests.get(
            endpoint,
            headers=request_headers,
            params={"page": page_number, "per_page": page_size},
            timeout=30,
        )

        if reply.status_code != 200:
            print(
                f"Error fetching releases: {reply.status_code} - {reply.text}",
                file=sys.stderr,
            )
            break

        batch = reply.json()
        if not batch:
            break

        collected.extend(batch)
        print(f" Found {len(batch)} releases on page {page_number}", file=sys.stderr)

        # A short page means this was the last one.
        if len(batch) < page_size:
            break

        page_number += 1
        # Brief pause between consecutive page requests.
        time.sleep(0.5)

    return collected
|
|
|
|
|
|
def extract_assets_from_releases(releases: list[dict]) -> list[dict]:
    """Collect the wheel (``.whl``) assets attached to the given releases.

    Args:
        releases: Release dicts; each may carry an "assets" list.

    Returns:
        One ``{"name": ..., "url": ...}`` dict per asset whose name ends with
        ".whl", in release/asset order. "url" is the asset's
        "browser_download_url", or "" when that key is absent.
    """
    return [
        {
            "name": asset.get("name", ""),
            "url": asset.get("browser_download_url", ""),
        }
        for release in releases
        for asset in release.get("assets", [])
        if asset.get("name", "").endswith(".whl")
    ]
|
|
|
|
|
|
def load_or_fetch_assets(repo: str, cache_path: Path, use_cache: bool) -> list[dict]:
    """Return the wheel assets of ``repo``, using a JSON cache when requested.

    When ``use_cache`` is true and ``cache_path`` exists, the "assets" list
    stored in that file is returned and no network request is made.
    Otherwise all releases are fetched from GitHub, their wheel assets are
    extracted and, if ``use_cache`` is true, written to ``cache_path``.

    Args:
        repo: Repository in "owner/name" form.
        cache_path: Location of the JSON cache file.
        use_cache: Whether to read from / write to the cache file.

    Returns:
        List of asset dicts (an empty list if the cache file has no "assets"
        key).
    """
    cache_hit = use_cache and cache_path.exists()
    if cache_hit:
        print(f"Loading assets from cache: {cache_path}", file=sys.stderr)
        with cache_path.open("r", encoding="utf-8") as cache_file:
            cached = json.load(cache_file)
        return cached.get("assets", [])

    token = get_github_token()
    print(f"Fetching all releases from {repo}...", file=sys.stderr)
    releases = fetch_all_releases(repo, token)
    print(f"Total releases found: {len(releases)}", file=sys.stderr)

    assets = extract_assets_from_releases(releases)
    print(f"Total wheel assets found: {len(assets)}", file=sys.stderr)

    if not use_cache:
        return assets

    # Persist the fresh result so the next --cache run can skip the API.
    print(f"Saving assets to cache: {cache_path}", file=sys.stderr)
    with cache_path.open("w", encoding="utf-8") as cache_file:
        json.dump({"assets": assets}, cache_file, indent=2, ensure_ascii=False)
    return assets
|
|
|
|
|
|
def is_excluded(
    flash_version: str,
    python_version: str,
    torch_version: str,
    cuda_version: str,
) -> bool:
    """Tell whether a version combination matches an entry of ``EXCLUDE``.

    An entry matches when every one of the keys "flash-attn-version",
    "python-version", "torch-version" and "cuda-version" that it contains
    equals the corresponding argument; keys an entry omits act as wildcards.

    Args:
        flash_version: flash-attn version to test.
        python_version: Python version to test.
        torch_version: torch version to test, compared verbatim with the
            entry's "torch-version" value.
        cuda_version: CUDA version to test.

    Returns:
        True if at least one entry matches, otherwise False.
    """
    requested = (
        ("flash-attn-version", flash_version),
        ("python-version", python_version),
        ("torch-version", torch_version),
        ("cuda-version", cuda_version),
    )
    return any(
        # A rule applies unless one of the fields it specifies disagrees.
        not any(field in rule and rule[field] != value for field, value in requested)
        for rule in EXCLUDE
    )
|
|
|
|
|
|
def normalize_platform_for_comparison(platform_raw: str) -> str:
    """Map a wheel platform tag onto one of the platform keys used here.

    Matching is case-insensitive and substring based:

    * tags containing "win" -> "windows"
    * tags containing "aarch64" or "arm64" -> "linux_arm64"
    * tags containing "x86_64" or "linux" -> "linux"

    macOS tags are checked first and returned unchanged (lower-cased),
    because they would otherwise be caught by the substring tests above:
    "macosx_11_0_arm64" contains "arm64", "macosx_10_9_x86_64" contains
    "x86_64" and "darwin" contains "win".

    Args:
        platform_raw: Platform part of a wheel filename, e.g. "linux_x86_64"
            or "win_amd64".

    Returns:
        "linux", "linux_arm64" or "windows" for recognised tags; otherwise
        the lower-cased input, which callers can treat as "unsupported".
    """
    platform_lower = platform_raw.lower()

    # Guard against macOS tags being misfiled as Linux/Windows builds.
    if "macos" in platform_lower or "darwin" in platform_lower:
        return platform_lower

    if "win" in platform_lower:
        return "windows"
    if "aarch64" in platform_lower or "arm64" in platform_lower:
        return "linux_arm64"
    if "x86_64" in platform_lower or "linux" in platform_lower:
        return "linux"
    return platform_lower
|
|
|
|
|
|
def build_existing_packages_set(assets: list[dict]) -> dict[str, set[tuple]]:
    """Index the released wheels by normalized platform.

    Each asset name is parsed with ``parse_wheel_filename``; names that do
    not parse, or whose platform normalizes to something other than "linux",
    "linux_arm64" or "windows", are ignored.

    Args:
        assets: Asset dicts carrying at least a "name" entry.

    Returns:
        Dict with exactly the keys "linux", "linux_arm64" and "windows", each
        mapping to a set of (flash, python, torch, cuda) tuples.
    """
    packages: dict[str, set[tuple]] = {
        platform_key: set() for platform_key in ("linux", "linux_arm64", "windows")
    }

    for asset in assets:
        info = parse_wheel_filename(asset.get("name", ""))
        if not info:
            continue

        bucket = packages.get(normalize_platform_for_comparison(info["platform"]))
        if bucket is None:
            # Platform we do not track.
            continue

        # Versions are stored exactly as parsed from the filename. The torch
        # component is presumably major.minor only (e.g. "2.9") -- confirm
        # against parse_wheel_filename.
        bucket.add(
            (
                info["flash_version"],
                info["python_version"],
                info["torch_version"],
                info["cuda_version"],
            )
        )

    return packages
|
|
|
|
|
|
def normalize_torch_version(version: str) -> str:
    """Reduce a torch version to its "major.minor" prefix.

    Examples: "2.9.1" -> "2.9", "2.10.0" -> "2.10". A version without a dot
    is returned unchanged.

    Args:
        version: Dotted torch version string.

    Returns:
        The first two dot-separated components joined by ".".
    """
    # Slicing keeps a lone component as-is and trims anything after minor.
    return ".".join(version.split(".")[:2])
|
|
|
|
|
|
def generate_expected_matrix(matrix: dict) -> list[tuple]:
    """Expand a matrix definition into every version combination.

    Args:
        matrix: Dict with optional list entries "flash-attn-version",
            "python-version", "torch-version" and "cuda-version".

    Returns:
        List of (flash, python, torch, cuda) tuples, with flash varying
        slowest and cuda fastest. Empty if any axis is missing or empty.
    """
    flash_axis = matrix.get("flash-attn-version", [])
    python_axis = matrix.get("python-version", [])
    torch_axis = matrix.get("torch-version", [])
    cuda_axis = matrix.get("cuda-version", [])

    return [
        (flash, python, torch, cuda)
        for flash in flash_axis
        for python in python_axis
        for torch in torch_axis
        for cuda in cuda_axis
    ]
|
|
|
|
|
|
def create_status_table(
    platform_name: str,
    flash_version: str,
    matrix: dict,
    existing: set[tuple],
    console: Console,
) -> tuple[Table, int, int, int]:
    """Build the status grid for one platform and one flash-attn version.

    Rows are Python versions, columns are (torch, CUDA) pairs; all three axes
    are sorted with ``parse_version_tuple``. Each cell is "-" (combination is
    excluded), "✓" (wheel exists) or "✗" (wheel missing). Exclusion is
    checked first and with the full torch version, while existence is looked
    up with the major.minor torch version.

    Args:
        platform_name: Display name used in the table title.
        flash_version: flash-attn version the table describes.
        matrix: Matrix definition providing the "python-version",
            "torch-version" and "cuda-version" lists.
        existing: Set of (flash, python, torch_minor, cuda) tuples for the
            wheels that exist on this platform.
        console: Accepted for interface compatibility; not used here.

    Returns:
        Tuple of (table, existing_count, missing_count, excluded_count)
    """

    def _sorted_axis(axis_name: str) -> list:
        return sorted(matrix.get(axis_name, []), key=parse_version_tuple)

    python_axis = _sorted_axis("python-version")
    torch_axis = _sorted_axis("torch-version")
    cuda_axis = _sorted_axis("cuda-version")

    # One column per (torch, cuda) pair, grouped by torch version. The minor
    # form is kept alongside the full version because both are needed below.
    columns = [
        (torch, normalize_torch_version(torch), cuda)
        for torch in torch_axis
        for cuda in cuda_axis
    ]

    table = Table(
        title=f"{platform_name} - Flash-Attention {flash_version}",
        show_header=True,
        header_style="bold cyan",
        border_style="dim",
    )
    table.add_column("Python", style="bold", justify="center")
    for _, torch_minor, cuda in columns:
        table.add_column(
            f"T{torch_minor}\nCU{cuda}",
            justify="center",
            min_width=6,
        )

    tally = {"existing": 0, "missing": 0, "excluded": 0}

    for python in python_axis:
        cells = [f"cp{python.replace('.', '')}"]

        for torch, torch_minor, cuda in columns:
            if is_excluded(flash_version, python, torch, cuda):
                status, cell = "excluded", Text("-", style="dim")
            elif (flash_version, python, torch_minor, cuda) in existing:
                status, cell = "existing", Text("✓", style="bold green")
            else:
                status, cell = "missing", Text("✗", style="bold red")

            tally[status] += 1
            cells.append(cell)

        table.add_row(*cells)

    return table, tally["existing"], tally["missing"], tally["excluded"]
|
|
|
|
|
|
def display_platform_tables(
    platform: str,
    matrix: dict,
    existing_packages: set[tuple],
    console: Console,
    flash_version_filter: str | None = None,
    show_missing_only: bool = False,
) -> dict:
    """Print one status table per flash-attn version and summarise the result.

    Args:
        platform: Platform key ("linux", "linux_arm64" or "windows"); other
            values are used verbatim as the display name.
        matrix: Matrix definition for the platform.
        existing_packages: Set of (flash, python, torch_minor, cuda) tuples
            for wheels that exist on this platform.
        console: Rich console the tables are printed to.
        flash_version_filter: When given, only this flash-attn version is
            processed.
        show_missing_only: When true, tables without any missing wheel are
            not printed (they are still counted).

    Returns:
        Dict with the integer totals "existing", "missing" and "excluded",
        plus "missing_packages": a list of dicts with the keys "platform",
        "flash_version", "python_version", "torch_version" (full version)
        and "cuda_version".
    """
    display_names = {
        "linux": "🐧 Linux x86_64",
        "linux_arm64": "🐧 Linux ARM64",
        "windows": "🪟 Windows",
    }
    platform_name = display_names.get(platform, platform)

    flash_versions = matrix.get("flash-attn-version", [])
    if flash_version_filter:
        flash_versions = [v for v in flash_versions if v == flash_version_filter]

    summary = {"existing": 0, "missing": 0, "excluded": 0, "missing_packages": []}

    for flash_version in flash_versions:
        table, n_existing, n_missing, n_excluded = create_status_table(
            platform_name,
            flash_version,
            matrix,
            existing_packages,
            console,
        )
        summary["existing"] += n_existing
        summary["missing"] += n_missing
        summary["excluded"] += n_excluded

        if n_missing > 0:
            # Re-walk the matrix (in its declared order) to name each
            # combination that is neither excluded nor already released.
            summary["missing_packages"].extend(
                {
                    "platform": platform,
                    "flash_version": flash_version,
                    "python_version": python,
                    "torch_version": torch,
                    "cuda_version": cuda,
                }
                for python in matrix.get("python-version", [])
                for torch in matrix.get("torch-version", [])
                for cuda in matrix.get("cuda-version", [])
                if not is_excluded(flash_version, python, torch, cuda)
                and (flash_version, python, normalize_torch_version(torch), cuda)
                not in existing_packages
            )

        # With --show-missing-only, fully covered tables are suppressed.
        if show_missing_only and n_missing <= 0:
            continue
        console.print(table)
        console.print()

    return summary
|
|
|
|
|
|
def main() -> None:
    """Run the command-line tool.

    Parses the CLI options, loads the wheel assets (from cache or GitHub),
    prints the per-platform status tables followed by a summary table with
    coverage percentages, and, with ``--list-missing``, a sorted table of
    every missing package.
    """
    parser = argparse.ArgumentParser(
        description="Check missing packages by comparing GitHub releases with expected matrix"
    )
    option_specs = [
        (
            "--repo",
            {
                "type": str,
                "default": "mjun0812/flash-attention-prebuild-wheels",
                "help": "GitHub repository (default: mjun0812/flash-attention-prebuild-wheels)",
            },
        ),
        (
            "--cache",
            {
                "action": "store_true",
                "help": "Use assets.json as cache (load if exists, save after fetch)",
            },
        ),
        (
            "--cache-file",
            {
                "type": str,
                "default": "assets.json",
                "help": "Cache file path (default: assets.json)",
            },
        ),
        (
            "--platform",
            {
                "type": str,
                "choices": ["linux", "linux_arm64", "windows", "all"],
                "default": "all",
                "help": "Platform to display (default: all)",
            },
        ),
        (
            "--flash-version",
            {
                "type": str,
                "help": "Filter by specific flash-attn version",
            },
        ),
        (
            "--show-missing-only",
            {
                "action": "store_true",
                "help": "Only show tables with missing packages",
            },
        ),
        (
            "--list-missing",
            {
                "action": "store_true",
                "help": "List all missing packages at the end",
            },
        ),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    console = Console()
    cache_path = Path(args.cache_file)

    # Gather what has been released and index it by platform.
    assets = load_or_fetch_assets(args.repo, cache_path, args.cache)
    existing_by_platform = build_existing_packages_set(assets)

    if args.platform == "all":
        selected_platforms = ["linux", "linux_arm64", "windows"]
    else:
        selected_platforms = [args.platform]

    console.print()
    console.rule("[bold blue]Flash-Attention Package Status", style="blue")
    console.print()

    stats_by_platform = {}
    missing_rows = []
    for platform in selected_platforms:
        matrix = get_comprehensive_matrix(platform)
        if not matrix.get("flash-attn-version"):
            # Nothing is expected for this platform.
            continue

        platform_stats = display_platform_tables(
            platform,
            matrix,
            existing_by_platform.get(platform, set()),
            console,
            flash_version_filter=args.flash_version,
            show_missing_only=args.show_missing_only,
        )
        stats_by_platform[platform] = platform_stats
        missing_rows.extend(platform_stats["missing_packages"])

    # Summary section.
    console.rule("[bold blue]Summary", style="blue")
    console.print()

    summary_table = Table(show_header=True, header_style="bold")
    summary_table.add_column("Platform", style="bold")
    for heading, colour in (("Existing", "green"), ("Missing", "red"), ("Excluded", "dim")):
        summary_table.add_column(heading, justify="right", style=colour)
    summary_table.add_column("Coverage", justify="right")

    for platform, platform_stats in stats_by_platform.items():
        have = platform_stats["existing"]
        lack = platform_stats["missing"]
        skipped = platform_stats["excluded"]

        # Excluded combinations are not part of the coverage denominator.
        buildable = have + lack
        if buildable > 0:
            coverage = f"{have / buildable * 100:.1f}%"
        else:
            coverage = "N/A"

        if lack == 0:
            coverage_style = "green"
        elif have > lack:
            coverage_style = "yellow"
        else:
            coverage_style = "red"

        summary_table.add_row(
            platform,
            str(have),
            str(lack),
            str(skipped),
            Text(coverage, style=coverage_style),
        )

    total_existing = sum(s["existing"] for s in stats_by_platform.values())
    total_missing = sum(s["missing"] for s in stats_by_platform.values())
    total_excluded = sum(s["excluded"] for s in stats_by_platform.values())

    grand_total = total_existing + total_missing
    if grand_total > 0:
        grand_coverage = f"{total_existing / grand_total * 100:.1f}%"
    else:
        grand_coverage = "N/A"

    summary_table.add_row(
        Text("TOTAL", style="bold"),
        Text(str(total_existing), style="bold green"),
        Text(str(total_missing), style="bold red"),
        Text(str(total_excluded), style="dim"),
        Text(grand_coverage, style="bold"),
    )

    console.print(summary_table)
    console.print()

    # Optional detailed listing of every missing package.
    if not (args.list_missing and missing_rows):
        return

    console.rule("[bold red]Missing Packages", style="red")
    console.print()

    missing_table = Table(show_header=True, header_style="bold red")
    for heading in ("Platform", "Flash-Attn", "Python", "Torch", "CUDA"):
        missing_table.add_column(heading)

    def _missing_sort_key(pkg: dict) -> tuple:
        return (
            pkg["platform"],
            parse_version_tuple(pkg["flash_version"]),
            parse_version_tuple(pkg["python_version"]),
            parse_version_tuple(pkg["torch_version"]),
            parse_version_tuple(pkg["cuda_version"]),
        )

    for pkg in sorted(missing_rows, key=_missing_sort_key):
        missing_table.add_row(
            pkg["platform"],
            pkg["flash_version"],
            pkg["python_version"],
            pkg["torch_version"],
            pkg["cuda_version"],
        )

    console.print(missing_table)
    console.print()
|
|
|
|
|
|
# Only run the CLI when executed as a script, never on import.
if __name__ == "__main__":
    main()
|