mirror of
https://github.com/BillyOutlast/flash-attention-prebuild-wheels-rocm.git
synced 2026-06-30 23:57:53 -04:00
chore: update script for docs
This commit is contained in:
@@ -169,7 +169,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: pip install pandas
|
||||
|
||||
- name: Update release history and packages section in README.md
|
||||
- name: Update docs
|
||||
run: |
|
||||
gh release view "${{ github.ref_name }}" --json assets > /tmp/assets.json
|
||||
python create_release_history.py \
|
||||
@@ -177,7 +177,7 @@ jobs:
|
||||
--tag "${{ github.ref_name }}" \
|
||||
--repo "${{ github.repository }}" \
|
||||
--output docs/release_history.md
|
||||
python insert_packages_to_readme.py --assets /tmp/assets.json --update
|
||||
python create_packages.py --assets /tmp/assets.json --output docs/packages.md
|
||||
|
||||
- name: Commit and push docs updates
|
||||
run: |
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""
|
||||
python insert_packages_to_readme.py --assets assets.json --update
|
||||
Create and update docs/packages.md from assets.json
|
||||
|
||||
Usage:
|
||||
python create_packages.py --assets assets.json --output docs/packages.md
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -38,49 +41,39 @@ def normalize_semantic_version(version: str) -> str:
|
||||
return version
|
||||
|
||||
|
||||
def extract_packages_from_readme(readme_path: Path) -> list[dict]:
|
||||
"""Extract package information from existing Packages section in README.md."""
|
||||
with readme_path.open("r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# Find Packages section
|
||||
packages_start = content.find("## Packages")
|
||||
if packages_start == -1:
|
||||
def extract_packages_from_packages_md(packages_md_path: Path) -> list[dict]:
|
||||
"""Extract package information from existing docs/packages.md."""
|
||||
if not packages_md_path.exists():
|
||||
return []
|
||||
|
||||
# Find the end of Packages section
|
||||
packages_end = content.find("## History", packages_start)
|
||||
if packages_end == -1:
|
||||
remaining_content = content[packages_start + len("## Packages") :]
|
||||
next_section = remaining_content.find("\n## ")
|
||||
if next_section != -1:
|
||||
packages_end = packages_start + len("## Packages") + next_section
|
||||
else:
|
||||
packages_end = len(content)
|
||||
with packages_md_path.open("r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
packages_section = content[packages_start:packages_end]
|
||||
lines = packages_section.splitlines()
|
||||
lines = content.splitlines()
|
||||
|
||||
packages = []
|
||||
current_os = None
|
||||
current_fa_version = None
|
||||
in_table = False
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
for line in lines:
|
||||
line_stripped = line.strip()
|
||||
|
||||
# Detect OS heading (### Linux x86_64)
|
||||
if line_stripped.startswith("### ") and not line_stripped.startswith("#### "):
|
||||
current_os = line_stripped[4:].strip()
|
||||
# Detect OS heading (## Linux x86_64)
|
||||
if line_stripped.startswith("## ") and not line_stripped.startswith("### "):
|
||||
# Remove emoji from OS name (e.g., "🐧 Linux x86_64" -> "Linux x86_64")
|
||||
os_name = line_stripped[3:].strip()
|
||||
while os_name and ord(os_name[0]) > 127:
|
||||
os_name = os_name[1:].strip()
|
||||
current_os = os_name
|
||||
current_fa_version = None
|
||||
in_table = False
|
||||
continue
|
||||
|
||||
# Detect Flash-Attention version heading (#### Flash-Attention 2.8.3)
|
||||
if line_stripped.startswith("#### Flash-Attention "):
|
||||
# Extract version after "#### Flash-Attention "
|
||||
# Detect Flash-Attention version heading (### Flash-Attention 2.8.3)
|
||||
if line_stripped.startswith("### Flash-Attention "):
|
||||
current_fa_version = line_stripped.replace(
|
||||
"#### Flash-Attention ", ""
|
||||
"### Flash-Attention ", ""
|
||||
).strip()
|
||||
in_table = False
|
||||
continue
|
||||
@@ -337,8 +330,26 @@ def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
|
||||
|
||||
all_sections = []
|
||||
|
||||
# Generate Table of Contents
|
||||
os_names = sorted(df["OS"].unique())
|
||||
toc_lines = ["## Table of Contents", ""]
|
||||
for os_name in os_names:
|
||||
# Create anchor link (lowercase, replace spaces with hyphens)
|
||||
os_anchor = os_name.lower().replace(" ", "-")
|
||||
toc_lines.append(f"- [{os_name}](#{os_anchor})")
|
||||
|
||||
# Add Flash-Attention versions for this OS (sorted)
|
||||
os_df = df[df["OS"] == os_name].copy()
|
||||
os_df = sort_packages(os_df, flash_ascending=False)
|
||||
for fa_version in os_df["Flash-Attention"].unique():
|
||||
# Create anchor for Flash-Attention version
|
||||
fa_anchor = f"flash-attention-{fa_version.replace('.', '')}".lower()
|
||||
toc_lines.append(f" - [Flash-Attention {fa_version}](#{fa_anchor})")
|
||||
toc_lines.append("")
|
||||
all_sections.extend(toc_lines)
|
||||
|
||||
# Group by OS and sort each group
|
||||
for os_name in sorted(df["OS"].unique()):
|
||||
for os_name in os_names:
|
||||
os_df = df[df["OS"] == os_name].copy()
|
||||
|
||||
# Sort within OS group: Flash-Attention > Python > PyTorch > CUDA
|
||||
@@ -352,7 +363,7 @@ def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
|
||||
|
||||
# Create OS section header with emoji
|
||||
os_emoji = get_os_emoji(os_name)
|
||||
os_lines = [f"### {os_emoji}{os_name}", ""]
|
||||
os_lines = [f"## {os_emoji}{os_name}", ""]
|
||||
|
||||
# Group by Flash-Attention version within each OS
|
||||
fa_versions = []
|
||||
@@ -400,7 +411,7 @@ def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
|
||||
|
||||
# Create collapsible section for this Flash-Attention version
|
||||
fa_section = [
|
||||
f"#### Flash-Attention {fa_version}",
|
||||
f"### Flash-Attention {fa_version}",
|
||||
"",
|
||||
"<details>",
|
||||
f"<summary>Packages for Flash-Attention {fa_version}</summary>",
|
||||
@@ -419,44 +430,9 @@ def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
|
||||
return "\n".join(all_sections)
|
||||
|
||||
|
||||
def update_readme_packages_section(readme_path: Path, packages_markdown: str) -> None:
|
||||
"""Update the Packages section in README.md with new content."""
|
||||
with readme_path.open("r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# Find the Packages section
|
||||
packages_start = content.find("## Packages")
|
||||
if packages_start == -1:
|
||||
raise ValueError("Packages section not found in README.md")
|
||||
|
||||
# Find the end of Packages section (next ## section or History section)
|
||||
packages_end = content.find("## History", packages_start)
|
||||
if packages_end == -1:
|
||||
# If no History section found, look for any other ## section
|
||||
remaining_content = content[packages_start + len("## Packages") :]
|
||||
next_section = remaining_content.find("\n## ")
|
||||
if next_section != -1:
|
||||
packages_end = packages_start + len("## Packages") + next_section
|
||||
else:
|
||||
packages_end = len(content)
|
||||
|
||||
# Replace the Packages section
|
||||
new_content = (
|
||||
content[:packages_start]
|
||||
+ "## Packages\n\n"
|
||||
+ packages_markdown
|
||||
+ "\n\n"
|
||||
+ content[packages_end:]
|
||||
)
|
||||
|
||||
# Write back to file
|
||||
with readme_path.open("w", encoding="utf-8") as f:
|
||||
f.write(new_content)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate a one-row-per-package Markdown table from assets.json file"
|
||||
description="Create and update docs/packages.md from assets.json"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--assets",
|
||||
@@ -465,30 +441,29 @@ def main() -> None:
|
||||
help="Path to assets.json file (default: assets.json)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--update",
|
||||
action="store_true",
|
||||
help="Update the Packages section in README.md instead of printing to stdout",
|
||||
"--output",
|
||||
type=str,
|
||||
default="docs/packages.md",
|
||||
help="Output file path (default: docs/packages.md)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
assets_path = Path(args.assets)
|
||||
if not assets_path.exists():
|
||||
print(f"Error: {assets_path} not found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
output_path = Path(args.output)
|
||||
|
||||
readme_path = Path("README.md")
|
||||
# Extract packages from assets.json if it exists
|
||||
assets_packages = []
|
||||
if assets_path.exists():
|
||||
assets_packages = extract_packages_from_assets_json(assets_path)
|
||||
|
||||
# Extract packages from assets.json
|
||||
assets_packages = extract_packages_from_assets_json(assets_path)
|
||||
|
||||
# Extract packages from existing README.md
|
||||
readme_packages = extract_packages_from_readme(readme_path)
|
||||
# Extract packages from existing docs/packages.md
|
||||
packages_md_packages = extract_packages_from_packages_md(output_path)
|
||||
|
||||
# Combine both lists
|
||||
all_packages = assets_packages + readme_packages
|
||||
all_packages = assets_packages + packages_md_packages
|
||||
|
||||
if not all_packages:
|
||||
print(f"No packages found in {assets_path} or README.md", file=sys.stderr)
|
||||
print(f"No packages found in {assets_path} or {output_path}", file=sys.stderr)
|
||||
return
|
||||
|
||||
# Convert to DataFrame and process
|
||||
@@ -503,13 +478,15 @@ def main() -> None:
|
||||
df_merged = merge_duplicate_rows(df_sorted)
|
||||
markdown = generate_markdown_table_by_os(df_merged)
|
||||
|
||||
if args.update:
|
||||
# Update the README.md file
|
||||
update_readme_packages_section(readme_path, markdown)
|
||||
print(f"Updated Packages section in {readme_path}")
|
||||
else:
|
||||
# Print to stdout (original behavior)
|
||||
print(markdown)
|
||||
# Create parent directory if it doesn't exist
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate markdown with "# Packages" header for standalone file
|
||||
standalone_markdown = f"# Packages\n\n{markdown}"
|
||||
|
||||
with output_path.open("w", encoding="utf-8") as f:
|
||||
f.write(standalone_markdown)
|
||||
print(f"Written packages to {output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
+1129
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user