TR2X/tools/sort_imports
2023-10-05 08:07:07 +02:00

105 lines
2.6 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import re
from pathlib import Path
from shutil import which
from subprocess import run
# Directory that contains this script.
TOOLS_DIR = Path(__file__).parent
# Parent of the tools directory; used as the base for locating "src".
REPO_DIR = TOOLS_DIR.parent
# Root directory that is globbed for *.c / *.h files and against which
# per-file include names are computed.
SRC_PATH = REPO_DIR / "src"
def fix_imports(path: Path) -> None:
    """Run include-what-you-use on *path* and feed its report to the fixer.

    The ``include-what-you-use`` report is read from the process's stderr
    and piped, as text, to the standard input of the include fixer. The
    fixer executable is ``fix_include`` when that name is found on PATH,
    otherwise ``iwyu-fix-includes``.

    Neither subprocess is run with ``check=True``, so non-zero exit codes
    are ignored.

    NOTE(review): the include directory is passed as the relative path
    ``"src"``, so this presumably expects the current working directory
    to be the repository root -- confirm against how the tool is invoked.

    Args:
        path: The C source or header file to analyse.
    """
    iwyu_result = run(
        ["include-what-you-use", "-I", "src", path],
        capture_output=True,
        text=True,
    ).stderr

    # Explicit conditional expression instead of the fragile
    # ``cond and a or b`` idiom.
    fixer = "fix_include" if which("fix_include") else "iwyu-fix-includes"
    run(
        [fixer],
        input=iwyu_result,
        text=True,
    )
def sort_import_group(includes: list[str]) -> list[str]:
    """Return the given include targets in sorted order, as a new list.

    The result is plain lexicographic order, except that for each pinned
    (earlier, later) pair present in the input, the "later" entry is moved
    to sit immediately after the "earlier" one. The input is not modified.
    """
    # (earlier, later): "later" must directly follow "earlier".
    pinned_order = (("<ddrawi.h>", "<d3dhal.h>"),)

    ordered = list(includes)
    ordered.sort()

    for earlier, later in pinned_order:
        if earlier not in ordered or later not in ordered:
            continue
        ordered.remove(later)
        anchor = ordered.index(earlier)
        ordered.insert(anchor + 1, later)

    return ordered
def sort_imports(path: Path) -> None:
    """Regroup and sort the ``#include`` directives of the file at *path*.

    Every run of ``#include`` lines (blank lines inside a run are allowed)
    is rewritten as up to three groups separated by a single blank line,
    in this order: the file's own header, other quoted includes, then
    angle-bracket includes. Each group is ordered by ``sort_import_group``.
    Duplicate targets within a run collapse into one, and any text that
    follows the include target on a line is dropped.

    A run that contains a directive not of the form ``#include "x"`` or
    ``#include <x>`` is left untouched rather than losing that directive.

    The file is written back only when its text actually changed.

    Args:
        path: File to rewrite; must be located under ``SRC_PATH``.

    Raises:
        ValueError: If *path* is not relative to ``SRC_PATH``.
    """
    original = path.read_text()
    rel_path = path.relative_to(SRC_PATH)

    # Include targets are written with forward slashes, whereas str() of a
    # path uses the platform separator; compare using the POSIX form so the
    # own-header check also works where the separator is a backslash.
    own_include = rel_path.with_suffix(".h").as_posix()
    own_include = {
        # files headers of which are not a 1:1 match with their filename
    }.get(rel_path.as_posix(), own_include)

    def cb(match: re.Match) -> str:
        block = match.group(0)
        includes = re.findall(r'#include (["<][^"<>]+[">])', block)

        # Every non-empty line of a matched run is an #include directive.
        # If the count of parsed targets differs, at least one directive
        # could not be parsed and would be silently deleted -- keep the
        # run as it is instead.
        directive_count = sum(1 for line in block.split("\n") if line)
        if len(includes) != directive_count:
            return block

        # Insertion order of the keys defines the output order of groups.
        groups = {
            "own": set(),
            "proj": set(),
            "lib": set(),
        }
        for include in includes:
            if include.strip('"') == own_include:
                key = "own"
            elif include.startswith("<"):
                key = "lib"
            else:
                key = "proj"
            groups[key].add(include)

        return "\n\n".join(
            "\n".join(
                f"#include {include}" for include in sort_import_group(group)
            )
            for group in groups.values()
            if group
        )

    updated = re.sub(
        "^#include [^\n]+(\n*#include [^\n]+)*",
        cb,
        original,
        flags=re.M,
    )
    if updated != original:
        path.write_text(updated)
def parse_args() -> argparse.Namespace:
    """Parse the command line.

    Returns a namespace whose ``paths`` attribute is a list of ``Path``
    objects built from the positional arguments (empty when none given).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("paths", metavar="path", type=Path, nargs="*")
    return cli.parse_args()
def main() -> None:
    """Fix and then sort the includes of every selected file.

    Paths given on the command line are made absolute. When none are
    given, every ``*.c`` / ``*.h`` file found recursively under
    ``SRC_PATH`` is processed in sorted order, except ``SRC_PATH/init.c``.
    """
    cli_args = parse_args()
    targets = [item.absolute() for item in cli_args.paths]

    if not targets:
        # No explicit selection: fall back to the whole source tree.
        excluded = SRC_PATH / "init.c"
        targets = sorted(
            candidate
            for candidate in SRC_PATH.glob("**/*.[ch]")
            if candidate != excluded
        )

    for target in targets:
        fix_imports(target)
        sort_imports(target)
# Run the tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()