diff --git a/.claude/plans/2026-01-25-fix-update-docs-ci.md b/.claude/plans/2026-01-25-fix-update-docs-ci.md deleted file mode 100644 index 703020f..0000000 --- a/.claude/plans/2026-01-25-fix-update-docs-ci.md +++ /dev/null @@ -1,80 +0,0 @@ -# Fix Update Docs CI Failure - -## 問題の概要 - -GitHub Actions の `Update Docs` ジョブが失敗している。 - -### エラーの原因 - -`create_packages.py` の 310行目でエラーが発生: - -``` -ValueError: setting an array element with a sequence. -``` - -### 根本原因 - -**pandas 3.0.0 の破壊的変更**: pandas 3.0.0 では、デフォルトの文字列型が `StringDtype` に変更された。これにより、文字列カラムにリスト(シーケンス)を直接代入できなくなった。 - -問題のコード (`create_packages.py:310`): -```python -result["package"] = unique_packages if unique_packages else [None] -``` - -`result` は pandas Series であり、`"package"` カラムに `unique_packages`(リスト)を代入しようとしているが、pandas 3.0.0 では文字列型カラムにリストを代入できない。 - ---- - -## 修正方法 - -`combine_packages` 関数を修正して、pandas Series の代わりに辞書を返すようにする。 - -### 修正箇所 - -**ファイル**: `create_packages.py` - -**変更内容**: `combine_packages` 関数の戻り値を Series から辞書に変更 - -```python -# Before (line 306-312) -# Take the first row as base -result = group.iloc[0].copy() - -# Combine packages into a list -result["package"] = unique_packages if unique_packages else [None] - -return result - -# After -# Return as a dictionary to avoid pandas StringDtype issues -return { - "Flash-Attention": group.iloc[0]["Flash-Attention"], - "Python": group.iloc[0]["Python"], - "PyTorch": group.iloc[0]["PyTorch"], - "CUDA": group.iloc[0]["CUDA"], - "OS": group.iloc[0]["OS"], - "package": unique_packages if unique_packages else [None], -} -``` - ---- - -## 検証方法 - -1. ローカルで `create_packages.py` を実行して動作確認 - ```bash - # テスト用のassets.jsonを取得 - gh release view v0.7.12 --json assets > /tmp/assets.json - - # スクリプトを実行 - python create_packages.py --assets /tmp/assets.json --output /tmp/packages.md - ``` - -2. エラーなく完了し、`/tmp/packages.md` が正しく生成されることを確認 - ---- - -## 影響範囲 - -- `create_packages.py` の `combine_packages` 関数のみ変更 -- 出力結果(Markdown ファイル)への影響なし diff --git a/.gitignore b/.gitignore index c9e6b29..85c4f76 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ .DS_Store .env* __pycache__/ -.ruff_cache/ \ No newline at end of file +.ruff_cache/ +assets.json +.claude/plans diff --git a/check_missing_packages.py b/check_missing_packages.py new file mode 100644 index 0000000..840fc9d --- /dev/null +++ b/check_missing_packages.py @@ -0,0 +1,583 @@ +#!/usr/bin/env python3 +"""Check missing packages by comparing GitHub releases with expected matrix. + +This script fetches wheel assets from GitHub releases and compares them with +the expected package matrix defined in create_matrix.py. It displays a colored +table showing which packages exist, are missing, or are excluded. + +Usage: + python check_missing_packages.py + python check_missing_packages.py --cache + python check_missing_packages.py --platform linux --flash-version 2.8.3 + python check_missing_packages.py --show-missing-only +""" + +import argparse +import json +import os +import sys +import time +from pathlib import Path + +import requests +from rich.console import Console +from rich.table import Table +from rich.text import Text + +from common import parse_wheel_filename +from create_matrix import ( + EXCLUDE, + LINUX_ARM64_MATRIX, + LINUX_ARM64_SELF_HOSTED_MATRIX, + LINUX_MATRIX, + LINUX_SELF_HOSTED_MATRIX, + WINDOWS_CODEBUILD_MATRIX, + WINDOWS_MATRIX, + WINDOWS_SELF_HOSTED_MATRIX, +) + + +# Comprehensive matrix combining all platform-specific matrices +def get_comprehensive_matrix(platform: str) -> dict: + """Get comprehensive matrix for a platform by merging all related matrices.""" + if platform == "linux": + # Merge LINUX_MATRIX and LINUX_SELF_HOSTED_MATRIX + return merge_matrices([LINUX_MATRIX, LINUX_SELF_HOSTED_MATRIX]) + elif platform == "linux_arm64": + return merge_matrices([LINUX_ARM64_MATRIX, LINUX_ARM64_SELF_HOSTED_MATRIX]) + elif platform == "windows": + return merge_matrices( + [WINDOWS_MATRIX, WINDOWS_SELF_HOSTED_MATRIX, WINDOWS_CODEBUILD_MATRIX] + ) + else: + return {} + + +def merge_matrices(matrices: list[dict]) -> dict: + """Merge multiple matrices by combining their version lists.""" + merged = { + "flash-attn-version": set(), + "python-version": set(), + "torch-version": set(), + "cuda-version": set(), + } + for matrix in matrices: + for key in merged: + merged[key].update(matrix.get(key, [])) + # Convert sets to sorted lists + return {key: sorted(vals, key=parse_version_tuple) for key, vals in merged.items()} + + +def parse_version_tuple(version: str) -> tuple: + """Parse version string to tuple for sorting.""" + parts = version.replace("post", ".").split(".") + result = [] + for p in parts: + try: + result.append(int(p)) + except ValueError: + result.append(0) + return tuple(result) + + +def get_github_token() -> str | None: + """Get GitHub token from environment variable.""" + token = os.environ.get("GITHUB_TOKEN") + if not token: + print( + "Warning: GITHUB_TOKEN not set. API rate limit will be restricted.", + file=sys.stderr, + ) + return token + + +def fetch_all_releases(repo: str, token: str | None = None) -> list[dict]: + """Fetch all releases from a GitHub repository.""" + headers = {} + if token: + headers["Authorization"] = f"token {token}" + headers["Accept"] = "application/vnd.github.v3+json" + + all_releases = [] + page = 1 + per_page = 100 + + while True: + url = f"https://api.github.com/repos/{repo}/releases" + params = {"page": page, "per_page": per_page} + + print(f"Fetching releases page {page}...", file=sys.stderr) + response = requests.get(url, headers=headers, params=params, timeout=30) + + if response.status_code != 200: + print( + f"Error fetching releases: {response.status_code} - {response.text}", + file=sys.stderr, + ) + break + + releases = response.json() + if not releases: + break + + all_releases.extend(releases) + print(f" Found {len(releases)} releases on page {page}", file=sys.stderr) + + if len(releases) < per_page: + break + + page += 1 + time.sleep(0.5) + + return all_releases + + +def extract_assets_from_releases(releases: list[dict]) -> list[dict]: + """Extract all wheel assets from releases.""" + all_assets = [] + + for release in releases: + for asset in release.get("assets", []): + name = asset.get("name", "") + if not name.endswith(".whl"): + continue + asset_info = { + "name": name, + "url": asset.get("browser_download_url", ""), + } + all_assets.append(asset_info) + + return all_assets + + +def load_or_fetch_assets(repo: str, cache_path: Path, use_cache: bool) -> list[dict]: + """Load assets from cache or fetch from GitHub.""" + if use_cache and cache_path.exists(): + print(f"Loading assets from cache: {cache_path}", file=sys.stderr) + with cache_path.open("r", encoding="utf-8") as f: + data = json.load(f) + return data.get("assets", []) + + token = get_github_token() + print(f"Fetching all releases from {repo}...", file=sys.stderr) + releases = fetch_all_releases(repo, token) + print(f"Total releases found: {len(releases)}", file=sys.stderr) + + assets = extract_assets_from_releases(releases) + print(f"Total wheel assets found: {len(assets)}", file=sys.stderr) + + if use_cache: + print(f"Saving assets to cache: {cache_path}", file=sys.stderr) + with cache_path.open("w", encoding="utf-8") as f: + json.dump({"assets": assets}, f, indent=2, ensure_ascii=False) + + return assets + + +def is_excluded( + flash_version: str, + python_version: str, + torch_version: str, + cuda_version: str, +) -> bool: + """Check if a combination is in the EXCLUDE list.""" + for excl in EXCLUDE: + match = True + if "flash-attn-version" in excl and excl["flash-attn-version"] != flash_version: + match = False + if "python-version" in excl and excl["python-version"] != python_version: + match = False + if "torch-version" in excl and excl["torch-version"] != torch_version: + match = False + if "cuda-version" in excl and excl["cuda-version"] != cuda_version: + match = False + if match: + return True + return False + + +def normalize_platform_for_comparison(platform_raw: str) -> str: + """Normalize platform string for comparison. + + Returns: "linux", "linux_arm64", or "windows" + """ + platform_lower = platform_raw.lower() + if "win" in platform_lower: + return "windows" + elif "aarch64" in platform_lower or "arm64" in platform_lower: + return "linux_arm64" + elif "x86_64" in platform_lower or "linux" in platform_lower: + return "linux" + else: + return platform_lower + + +def build_existing_packages_set(assets: list[dict]) -> dict[str, set[tuple]]: + """Build a set of existing packages grouped by normalized platform. + + Returns: + Dict mapping platform to set of (flash, python, torch, cuda) tuples + """ + packages: dict[str, set[tuple]] = { + "linux": set(), + "linux_arm64": set(), + "windows": set(), + } + + for asset in assets: + name = asset.get("name", "") + info = parse_wheel_filename(name) + if not info: + continue + + platform = normalize_platform_for_comparison(info["platform"]) + if platform not in packages: + continue + + # Normalize torch version (2.9 -> 2.9.1 etc) + # The wheel has minor version only, but matrix uses full version + key = ( + info["flash_version"], + info["python_version"], + info["torch_version"], # This is like "2.9", not "2.9.1" + info["cuda_version"], + ) + packages[platform].add(key) + + return packages + + +def normalize_torch_version(version: str) -> str: + """Convert full torch version to minor version for comparison. + + Example: 2.9.1 -> 2.9, 2.10.0 -> 2.10 + """ + parts = version.split(".") + if len(parts) >= 2: + return f"{parts[0]}.{parts[1]}" + return version + + +def generate_expected_matrix(matrix: dict) -> list[tuple]: + """Generate all expected combinations from a matrix definition.""" + combinations = [] + for flash in matrix.get("flash-attn-version", []): + for python in matrix.get("python-version", []): + for torch in matrix.get("torch-version", []): + for cuda in matrix.get("cuda-version", []): + combinations.append((flash, python, torch, cuda)) + return combinations + + +def create_status_table( + platform_name: str, + flash_version: str, + matrix: dict, + existing: set[tuple], + console: Console, +) -> tuple[Table, int, int, int]: + """Create a rich table for a specific platform and flash-attn version. + + Returns: + Tuple of (table, existing_count, missing_count, excluded_count) + """ + python_versions = sorted(matrix.get("python-version", []), key=parse_version_tuple) + torch_versions = sorted(matrix.get("torch-version", []), key=parse_version_tuple) + cuda_versions = sorted(matrix.get("cuda-version", []), key=parse_version_tuple) + + # Create table + table = Table( + title=f"{platform_name} - Flash-Attention {flash_version}", + show_header=True, + header_style="bold cyan", + border_style="dim", + ) + + # Add Python column + table.add_column("Python", style="bold", justify="center") + + # Add Torch/CUDA columns - group by torch version + for torch in torch_versions: + torch_minor = normalize_torch_version(torch) + for cuda in cuda_versions: + table.add_column( + f"T{torch_minor}\nCU{cuda}", + justify="center", + min_width=6, + ) + + existing_count = 0 + missing_count = 0 + excluded_count = 0 + + # Add rows for each Python version + for python in python_versions: + row = [f"cp{python.replace('.', '')}"] + + for torch in torch_versions: + torch_minor = normalize_torch_version(torch) + for cuda in cuda_versions: + # Check status + key = (flash_version, python, torch_minor, cuda) + is_excl = is_excluded(flash_version, python, torch, cuda) + + if is_excl: + cell = Text("-", style="dim") + excluded_count += 1 + elif key in existing: + cell = Text("✓", style="bold green") + existing_count += 1 + else: + cell = Text("✗", style="bold red") + missing_count += 1 + + row.append(cell) + + table.add_row(*row) + + return table, existing_count, missing_count, excluded_count + + +def display_platform_tables( + platform: str, + matrix: dict, + existing_packages: set[tuple], + console: Console, + flash_version_filter: str | None = None, + show_missing_only: bool = False, +) -> dict: + """Display tables for a platform and return summary statistics.""" + platform_display_names = { + "linux": "🐧 Linux x86_64", + "linux_arm64": "🐧 Linux ARM64", + "windows": "🪟 Windows", + } + platform_name = platform_display_names.get(platform, platform) + + flash_versions = matrix.get("flash-attn-version", []) + if flash_version_filter: + flash_versions = [v for v in flash_versions if v == flash_version_filter] + + total_existing = 0 + total_missing = 0 + total_excluded = 0 + missing_packages = [] + + for flash_version in flash_versions: + table, existing, missing, excluded = create_status_table( + platform_name, + flash_version, + matrix, + existing_packages, + console, + ) + + total_existing += existing + total_missing += missing + total_excluded += excluded + + # Collect missing packages for summary + if missing > 0: + for python in matrix.get("python-version", []): + for torch in matrix.get("torch-version", []): + torch_minor = normalize_torch_version(torch) + for cuda in matrix.get("cuda-version", []): + key = (flash_version, python, torch_minor, cuda) + is_excl = is_excluded(flash_version, python, torch, cuda) + if not is_excl and key not in existing_packages: + missing_packages.append( + { + "platform": platform, + "flash_version": flash_version, + "python_version": python, + "torch_version": torch, + "cuda_version": cuda, + } + ) + + # Show table only if there are missing packages (when --show-missing-only) + if not show_missing_only or missing > 0: + console.print(table) + console.print() + + return { + "existing": total_existing, + "missing": total_missing, + "excluded": total_excluded, + "missing_packages": missing_packages, + } + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Check missing packages by comparing GitHub releases with expected matrix" + ) + parser.add_argument( + "--repo", + type=str, + default="mjun0812/flash-attention-prebuild-wheels", + help="GitHub repository (default: mjun0812/flash-attention-prebuild-wheels)", + ) + parser.add_argument( + "--cache", + action="store_true", + help="Use assets.json as cache (load if exists, save after fetch)", + ) + parser.add_argument( + "--cache-file", + type=str, + default="assets.json", + help="Cache file path (default: assets.json)", + ) + parser.add_argument( + "--platform", + type=str, + choices=["linux", "linux_arm64", "windows", "all"], + default="all", + help="Platform to display (default: all)", + ) + parser.add_argument( + "--flash-version", + type=str, + help="Filter by specific flash-attn version", + ) + parser.add_argument( + "--show-missing-only", + action="store_true", + help="Only show tables with missing packages", + ) + parser.add_argument( + "--list-missing", + action="store_true", + help="List all missing packages at the end", + ) + args = parser.parse_args() + + console = Console() + cache_path = Path(args.cache_file) + + # Load or fetch assets + assets = load_or_fetch_assets(args.repo, cache_path, args.cache) + + # Build existing packages set + existing_packages = build_existing_packages_set(assets) + + # Determine which platforms to process + platforms = ["linux", "linux_arm64", "windows"] + if args.platform != "all": + platforms = [args.platform] + + # Display tables and collect statistics + all_stats = {} + all_missing = [] + + console.print() + console.rule("[bold blue]Flash-Attention Package Status", style="blue") + console.print() + + for platform in platforms: + matrix = get_comprehensive_matrix(platform) + if not matrix.get("flash-attn-version"): + continue + + stats = display_platform_tables( + platform, + matrix, + existing_packages.get(platform, set()), + console, + flash_version_filter=args.flash_version, + show_missing_only=args.show_missing_only, + ) + all_stats[platform] = stats + all_missing.extend(stats["missing_packages"]) + + # Display summary + console.rule("[bold blue]Summary", style="blue") + console.print() + + summary_table = Table(show_header=True, header_style="bold") + summary_table.add_column("Platform", style="bold") + summary_table.add_column("Existing", justify="right", style="green") + summary_table.add_column("Missing", justify="right", style="red") + summary_table.add_column("Excluded", justify="right", style="dim") + summary_table.add_column("Coverage", justify="right") + + total_existing = 0 + total_missing = 0 + total_excluded = 0 + + for platform, stats in all_stats.items(): + existing = stats["existing"] + missing = stats["missing"] + excluded = stats["excluded"] + total = existing + missing + + total_existing += existing + total_missing += missing + total_excluded += excluded + + coverage = f"{existing / total * 100:.1f}%" if total > 0 else "N/A" + coverage_style = ( + "green" if missing == 0 else "yellow" if existing > missing else "red" + ) + + summary_table.add_row( + platform, + str(existing), + str(missing), + str(excluded), + Text(coverage, style=coverage_style), + ) + + # Add total row + grand_total = total_existing + total_missing + grand_coverage = ( + f"{total_existing / grand_total * 100:.1f}%" if grand_total > 0 else "N/A" + ) + summary_table.add_row( + Text("TOTAL", style="bold"), + Text(str(total_existing), style="bold green"), + Text(str(total_missing), style="bold red"), + Text(str(total_excluded), style="dim"), + Text(grand_coverage, style="bold"), + ) + + console.print(summary_table) + console.print() + + # List missing packages if requested + if args.list_missing and all_missing: + console.rule("[bold red]Missing Packages", style="red") + console.print() + + missing_table = Table(show_header=True, header_style="bold red") + missing_table.add_column("Platform") + missing_table.add_column("Flash-Attn") + missing_table.add_column("Python") + missing_table.add_column("Torch") + missing_table.add_column("CUDA") + + for pkg in sorted( + all_missing, + key=lambda x: ( + x["platform"], + parse_version_tuple(x["flash_version"]), + parse_version_tuple(x["python_version"]), + parse_version_tuple(x["torch_version"]), + parse_version_tuple(x["cuda_version"]), + ), + ): + missing_table.add_row( + pkg["platform"], + pkg["flash_version"], + pkg["python_version"], + pkg["torch_version"], + pkg["cuda_version"], + ) + + console.print(missing_table) + console.print() + + +if __name__ == "__main__": + main()