mirror of
https://github.com/BillyOutlast/flash-attention-prebuild-wheels-rocm.git
synced 2026-07-01 01:37:53 -04:00
chore: update script for docs
This commit is contained in:
@@ -169,7 +169,7 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pip install pandas
|
run: pip install pandas
|
||||||
|
|
||||||
- name: Update release history and packages section in README.md
|
- name: Update docs
|
||||||
run: |
|
run: |
|
||||||
gh release view "${{ github.ref_name }}" --json assets > /tmp/assets.json
|
gh release view "${{ github.ref_name }}" --json assets > /tmp/assets.json
|
||||||
python create_release_history.py \
|
python create_release_history.py \
|
||||||
@@ -177,7 +177,7 @@ jobs:
|
|||||||
--tag "${{ github.ref_name }}" \
|
--tag "${{ github.ref_name }}" \
|
||||||
--repo "${{ github.repository }}" \
|
--repo "${{ github.repository }}" \
|
||||||
--output docs/release_history.md
|
--output docs/release_history.md
|
||||||
python insert_packages_to_readme.py --assets /tmp/assets.json --update
|
python create_packages.py --assets /tmp/assets.json --output docs/packages.md
|
||||||
|
|
||||||
- name: Commit and push docs updates
|
- name: Commit and push docs updates
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
python insert_packages_to_readme.py --assets assets.json --update
|
Create and update docs/packages.md from assets.json
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python create_packages.py --assets assets.json --output docs/packages.md
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -38,49 +41,39 @@ def normalize_semantic_version(version: str) -> str:
|
|||||||
return version
|
return version
|
||||||
|
|
||||||
|
|
||||||
def extract_packages_from_readme(readme_path: Path) -> list[dict]:
|
def extract_packages_from_packages_md(packages_md_path: Path) -> list[dict]:
|
||||||
"""Extract package information from existing Packages section in README.md."""
|
"""Extract package information from existing docs/packages.md."""
|
||||||
with readme_path.open("r", encoding="utf-8") as f:
|
if not packages_md_path.exists():
|
||||||
content = f.read()
|
|
||||||
|
|
||||||
# Find Packages section
|
|
||||||
packages_start = content.find("## Packages")
|
|
||||||
if packages_start == -1:
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Find the end of Packages section
|
with packages_md_path.open("r", encoding="utf-8") as f:
|
||||||
packages_end = content.find("## History", packages_start)
|
content = f.read()
|
||||||
if packages_end == -1:
|
|
||||||
remaining_content = content[packages_start + len("## Packages") :]
|
|
||||||
next_section = remaining_content.find("\n## ")
|
|
||||||
if next_section != -1:
|
|
||||||
packages_end = packages_start + len("## Packages") + next_section
|
|
||||||
else:
|
|
||||||
packages_end = len(content)
|
|
||||||
|
|
||||||
packages_section = content[packages_start:packages_end]
|
lines = content.splitlines()
|
||||||
lines = packages_section.splitlines()
|
|
||||||
|
|
||||||
packages = []
|
packages = []
|
||||||
current_os = None
|
current_os = None
|
||||||
current_fa_version = None
|
current_fa_version = None
|
||||||
in_table = False
|
in_table = False
|
||||||
|
|
||||||
for i, line in enumerate(lines):
|
for line in lines:
|
||||||
line_stripped = line.strip()
|
line_stripped = line.strip()
|
||||||
|
|
||||||
# Detect OS heading (### Linux x86_64)
|
# Detect OS heading (## Linux x86_64)
|
||||||
if line_stripped.startswith("### ") and not line_stripped.startswith("#### "):
|
if line_stripped.startswith("## ") and not line_stripped.startswith("### "):
|
||||||
current_os = line_stripped[4:].strip()
|
# Remove emoji from OS name (e.g., "🐧 Linux x86_64" -> "Linux x86_64")
|
||||||
|
os_name = line_stripped[3:].strip()
|
||||||
|
while os_name and ord(os_name[0]) > 127:
|
||||||
|
os_name = os_name[1:].strip()
|
||||||
|
current_os = os_name
|
||||||
current_fa_version = None
|
current_fa_version = None
|
||||||
in_table = False
|
in_table = False
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Detect Flash-Attention version heading (#### Flash-Attention 2.8.3)
|
# Detect Flash-Attention version heading (### Flash-Attention 2.8.3)
|
||||||
if line_stripped.startswith("#### Flash-Attention "):
|
if line_stripped.startswith("### Flash-Attention "):
|
||||||
# Extract version after "#### Flash-Attention "
|
|
||||||
current_fa_version = line_stripped.replace(
|
current_fa_version = line_stripped.replace(
|
||||||
"#### Flash-Attention ", ""
|
"### Flash-Attention ", ""
|
||||||
).strip()
|
).strip()
|
||||||
in_table = False
|
in_table = False
|
||||||
continue
|
continue
|
||||||
@@ -337,8 +330,26 @@ def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
|
|||||||
|
|
||||||
all_sections = []
|
all_sections = []
|
||||||
|
|
||||||
|
# Generate Table of Contents
|
||||||
|
os_names = sorted(df["OS"].unique())
|
||||||
|
toc_lines = ["## Table of Contents", ""]
|
||||||
|
for os_name in os_names:
|
||||||
|
# Create anchor link (lowercase, replace spaces with hyphens)
|
||||||
|
os_anchor = os_name.lower().replace(" ", "-")
|
||||||
|
toc_lines.append(f"- [{os_name}](#{os_anchor})")
|
||||||
|
|
||||||
|
# Add Flash-Attention versions for this OS (sorted)
|
||||||
|
os_df = df[df["OS"] == os_name].copy()
|
||||||
|
os_df = sort_packages(os_df, flash_ascending=False)
|
||||||
|
for fa_version in os_df["Flash-Attention"].unique():
|
||||||
|
# Create anchor for Flash-Attention version
|
||||||
|
fa_anchor = f"flash-attention-{fa_version.replace('.', '')}".lower()
|
||||||
|
toc_lines.append(f" - [Flash-Attention {fa_version}](#{fa_anchor})")
|
||||||
|
toc_lines.append("")
|
||||||
|
all_sections.extend(toc_lines)
|
||||||
|
|
||||||
# Group by OS and sort each group
|
# Group by OS and sort each group
|
||||||
for os_name in sorted(df["OS"].unique()):
|
for os_name in os_names:
|
||||||
os_df = df[df["OS"] == os_name].copy()
|
os_df = df[df["OS"] == os_name].copy()
|
||||||
|
|
||||||
# Sort within OS group: Flash-Attention > Python > PyTorch > CUDA
|
# Sort within OS group: Flash-Attention > Python > PyTorch > CUDA
|
||||||
@@ -352,7 +363,7 @@ def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
|
|||||||
|
|
||||||
# Create OS section header with emoji
|
# Create OS section header with emoji
|
||||||
os_emoji = get_os_emoji(os_name)
|
os_emoji = get_os_emoji(os_name)
|
||||||
os_lines = [f"### {os_emoji}{os_name}", ""]
|
os_lines = [f"## {os_emoji}{os_name}", ""]
|
||||||
|
|
||||||
# Group by Flash-Attention version within each OS
|
# Group by Flash-Attention version within each OS
|
||||||
fa_versions = []
|
fa_versions = []
|
||||||
@@ -400,7 +411,7 @@ def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
|
|||||||
|
|
||||||
# Create collapsible section for this Flash-Attention version
|
# Create collapsible section for this Flash-Attention version
|
||||||
fa_section = [
|
fa_section = [
|
||||||
f"#### Flash-Attention {fa_version}",
|
f"### Flash-Attention {fa_version}",
|
||||||
"",
|
"",
|
||||||
"<details>",
|
"<details>",
|
||||||
f"<summary>Packages for Flash-Attention {fa_version}</summary>",
|
f"<summary>Packages for Flash-Attention {fa_version}</summary>",
|
||||||
@@ -419,44 +430,9 @@ def generate_markdown_table_by_os(df: pd.DataFrame) -> str:
|
|||||||
return "\n".join(all_sections)
|
return "\n".join(all_sections)
|
||||||
|
|
||||||
|
|
||||||
def update_readme_packages_section(readme_path: Path, packages_markdown: str) -> None:
|
|
||||||
"""Update the Packages section in README.md with new content."""
|
|
||||||
with readme_path.open("r", encoding="utf-8") as f:
|
|
||||||
content = f.read()
|
|
||||||
|
|
||||||
# Find the Packages section
|
|
||||||
packages_start = content.find("## Packages")
|
|
||||||
if packages_start == -1:
|
|
||||||
raise ValueError("Packages section not found in README.md")
|
|
||||||
|
|
||||||
# Find the end of Packages section (next ## section or History section)
|
|
||||||
packages_end = content.find("## History", packages_start)
|
|
||||||
if packages_end == -1:
|
|
||||||
# If no History section found, look for any other ## section
|
|
||||||
remaining_content = content[packages_start + len("## Packages") :]
|
|
||||||
next_section = remaining_content.find("\n## ")
|
|
||||||
if next_section != -1:
|
|
||||||
packages_end = packages_start + len("## Packages") + next_section
|
|
||||||
else:
|
|
||||||
packages_end = len(content)
|
|
||||||
|
|
||||||
# Replace the Packages section
|
|
||||||
new_content = (
|
|
||||||
content[:packages_start]
|
|
||||||
+ "## Packages\n\n"
|
|
||||||
+ packages_markdown
|
|
||||||
+ "\n\n"
|
|
||||||
+ content[packages_end:]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Write back to file
|
|
||||||
with readme_path.open("w", encoding="utf-8") as f:
|
|
||||||
f.write(new_content)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Generate a one-row-per-package Markdown table from assets.json file"
|
description="Create and update docs/packages.md from assets.json"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--assets",
|
"--assets",
|
||||||
@@ -465,30 +441,29 @@ def main() -> None:
|
|||||||
help="Path to assets.json file (default: assets.json)",
|
help="Path to assets.json file (default: assets.json)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--update",
|
"--output",
|
||||||
action="store_true",
|
type=str,
|
||||||
help="Update the Packages section in README.md instead of printing to stdout",
|
default="docs/packages.md",
|
||||||
|
help="Output file path (default: docs/packages.md)",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
assets_path = Path(args.assets)
|
assets_path = Path(args.assets)
|
||||||
if not assets_path.exists():
|
output_path = Path(args.output)
|
||||||
print(f"Error: {assets_path} not found", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
readme_path = Path("README.md")
|
# Extract packages from assets.json if it exists
|
||||||
|
assets_packages = []
|
||||||
|
if assets_path.exists():
|
||||||
|
assets_packages = extract_packages_from_assets_json(assets_path)
|
||||||
|
|
||||||
# Extract packages from assets.json
|
# Extract packages from existing docs/packages.md
|
||||||
assets_packages = extract_packages_from_assets_json(assets_path)
|
packages_md_packages = extract_packages_from_packages_md(output_path)
|
||||||
|
|
||||||
# Extract packages from existing README.md
|
|
||||||
readme_packages = extract_packages_from_readme(readme_path)
|
|
||||||
|
|
||||||
# Combine both lists
|
# Combine both lists
|
||||||
all_packages = assets_packages + readme_packages
|
all_packages = assets_packages + packages_md_packages
|
||||||
|
|
||||||
if not all_packages:
|
if not all_packages:
|
||||||
print(f"No packages found in {assets_path} or README.md", file=sys.stderr)
|
print(f"No packages found in {assets_path} or {output_path}", file=sys.stderr)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Convert to DataFrame and process
|
# Convert to DataFrame and process
|
||||||
@@ -503,13 +478,15 @@ def main() -> None:
|
|||||||
df_merged = merge_duplicate_rows(df_sorted)
|
df_merged = merge_duplicate_rows(df_sorted)
|
||||||
markdown = generate_markdown_table_by_os(df_merged)
|
markdown = generate_markdown_table_by_os(df_merged)
|
||||||
|
|
||||||
if args.update:
|
# Create parent directory if it doesn't exist
|
||||||
# Update the README.md file
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
update_readme_packages_section(readme_path, markdown)
|
|
||||||
print(f"Updated Packages section in {readme_path}")
|
# Generate markdown with "# Packages" header for standalone file
|
||||||
else:
|
standalone_markdown = f"# Packages\n\n{markdown}"
|
||||||
# Print to stdout (original behavior)
|
|
||||||
print(markdown)
|
with output_path.open("w", encoding="utf-8") as f:
|
||||||
|
f.write(standalone_markdown)
|
||||||
|
print(f"Written packages to {output_path}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
+1129
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user