mirror of
https://github.com/LostArtefacts/TR2X.git
synced 2025-01-10 07:21:55 +00:00
105 lines
2.6 KiB
Python
Executable File
105 lines
2.6 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
import argparse
|
|
import re
|
|
from pathlib import Path
|
|
from shutil import which
|
|
from subprocess import run
|
|
|
|
# Directory that contains this script.
TOOLS_DIR = Path(__file__).parent
# Parent of the script directory; everything else is resolved from here.
REPO_DIR = TOOLS_DIR.parent
# Directory whose *.c / *.h files are processed when no paths are given.
SRC_PATH = REPO_DIR.joinpath("src")
|
|
|
|
|
|
def fix_imports(path: Path) -> None:
    """Run include-what-you-use on ``path`` and apply its suggestions.

    The report that ``include-what-you-use`` prints on its standard error
    stream is captured and fed on standard input to the fixer script. The
    fixer is ``fix_include`` when that executable is found on ``PATH`` and
    ``iwyu-fix-includes`` otherwise.

    Neither process's exit status is checked.

    Args:
        path: File to analyse.
    """
    # NOTE: "-I src" is a relative include directory, so it is resolved
    # against the current working directory of this script.
    iwyu_result = run(
        ["include-what-you-use", "-I", "src", path],
        capture_output=True,
        text=True,
    ).stderr

    # The fixer script is installed under different names; prefer
    # "fix_include" and fall back to "iwyu-fix-includes".
    fixer = "fix_include" if which("fix_include") else "iwyu-fix-includes"
    run(
        [fixer],
        input=iwyu_result,
        text=True,
    )
|
|
|
|
|
|
def sort_import_group(includes: list[str]) -> list[str]:
    """Return ``includes`` in sorted order, honouring fixed adjacency pairs.

    The entries are sorted as plain strings. Afterwards, for every pinned
    ``(leader, follower)`` pair of which both members are present, the
    follower is moved so that it sits immediately after the leader.

    Args:
        includes: Include targets with their delimiters, e.g. ``"<stdio.h>"``.

    Returns:
        A new list; the argument is not modified.
    """
    # (leader, follower): the follower must directly follow the leader.
    pinned_pairs = (("<ddrawi.h>", "<d3dhal.h>"),)

    ordered = list(includes)
    ordered.sort()

    for leader, follower in pinned_pairs:
        if leader not in ordered or follower not in ordered:
            continue
        ordered.remove(follower)
        # Look the leader up only after the removal, since its index may
        # have shifted.
        ordered.insert(ordered.index(leader) + 1, follower)

    return ordered
|
|
|
|
|
|
def sort_imports(path: Path) -> None:
    """Regroup and sort the ``#include`` directives of the file at ``path``.

    Every run of ``#include`` lines that start at column 0 (blank lines
    between them are part of the run) is rewritten as up to three groups
    separated by one blank line, in this order:

    1. the file's own header (same relative path with a ``.h`` suffix),
    2. the remaining quoted includes,
    3. angle-bracket includes.

    Each group is ordered by ``sort_import_group`` and duplicates inside a
    run collapse into one line. The file is written back only when its
    text actually changed.

    NOTE: only the include target is re-emitted, so anything that follows
    it on the same line (such as a trailing comment) is not preserved.

    Args:
        path: File to rewrite in place; must be located under ``SRC_PATH``.

    Raises:
        ValueError: If ``path`` is not inside ``SRC_PATH``.
    """
    original = path.read_text()
    rel_path = path.relative_to(SRC_PATH)

    # Include targets are spelled with forward slashes, so compare against
    # the POSIX form of the path; str() would produce backslashes on
    # Windows and the own header would never be recognised there.
    own_include = rel_path.with_suffix(".h").as_posix()
    own_include = {
        # files headers of which are not a 1:1 match with their filename
        # (keys are POSIX-style paths relative to SRC_PATH)
    }.get(rel_path.as_posix(), own_include)

    def cb(match: re.Match[str]) -> str:
        block = match.group(0)
        includes = re.findall(r'#include (["<][^"<>]+[">])', block)

        # A directive the pattern above cannot parse (for example a
        # macro-computed include) would vanish from the output. Leave such
        # a run exactly as it is rather than losing lines.
        directive_count = len(re.findall(r"^#include ", block, flags=re.M))
        if len(includes) != directive_count:
            return block

        # Insertion order of the keys is the output order of the groups.
        groups: dict[str, set[str]] = {
            "own": set(),
            "proj": set(),
            "lib": set(),
        }
        for include in includes:
            if include.strip('"') == own_include:
                groups["own"].add(include)
            elif include.startswith("<"):
                groups["lib"].add(include)
            else:
                groups["proj"].add(include)

        return "\n\n".join(
            "\n".join(
                f"#include {include}" for include in sort_import_group(group)
            )
            for group in groups.values()
            if group  # empty groups must not produce blank separators
        ).strip()

    updated = re.sub(
        "^#include [^\n]+(\n*#include [^\n]+)*",
        cb,
        original,
        flags=re.M,
    )
    if updated != original:
        path.write_text(updated)
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command line of this script.

    Returns:
        A namespace whose ``paths`` attribute is a list of ``Path`` objects
        built from the positional arguments; the list is empty when no
        positional arguments were given.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(dest="paths", metavar="path", nargs="*", type=Path)
    namespace = cli.parse_args()
    return namespace
|
|
|
|
|
|
def main() -> None:
    """Fix and then sort the includes of every selected file.

    The files are the paths given on the command line, made absolute. When
    none are given, every ``*.c`` and ``*.h`` file found recursively under
    ``SRC_PATH`` is processed in sorted order, except ``init.c`` directly
    inside ``SRC_PATH``.
    """
    cli_args = parse_args()
    targets = [given.absolute() for given in cli_args.paths]

    if not targets:
        skipped = SRC_PATH / "init.c"
        candidates = SRC_PATH.glob("**/*.[ch]")
        targets = sorted(found for found in candidates if found != skipped)

    for target in targets:
        fix_imports(target)
        sort_imports(target)
|
|
|
|
|
|
# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|