reduce chance of build oom (#2079)

This commit is contained in:
Qubitium-ModelCloud
2026-01-21 18:36:22 +08:00
committed by GitHub
parent 04e6ee1fb5
commit f15ccf5ff2
+11 -3
View File
@@ -64,6 +64,7 @@ SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
SKIP_CK_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CK_BUILD", "TRUE") == "TRUE" if USE_TRITON_ROCM else False
NVCC_THREADS = os.getenv("NVCC_THREADS") or "4"
@functools.lru_cache(maxsize=None)
def cuda_archs() -> str:
@@ -186,8 +187,7 @@ def detect_hipify_v2():
def append_nvcc_threads(nvcc_extra_args):
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
return nvcc_extra_args + ["--threads", nvcc_threads]
return nvcc_extra_args + ["--threads", NVCC_THREADS]
def rename_cpp_to_cu(cpp_files):
@@ -571,15 +571,23 @@ class NinjaBuildExtension(BuildExtension):
if not os.environ.get("MAX_JOBS"):
import psutil
nvcc_threads = max(1, int(NVCC_THREADS))
# calculate the maximum allowed NUM_JOBS based on cores
max_num_jobs_cores = max(1, os.cpu_count() // 2)
# calculate the maximum allowed NUM_JOBS based on free memory
free_memory_gb = psutil.virtual_memory().available / (1024 ** 3) # free memory in GB
max_num_jobs_memory = int(free_memory_gb / 9) # each JOB peak memory cost is ~8-9GB when threads = 4
# Assume worst-case peak observed memory usage of ~5GB per NVCC thread.
# Limit: peak_threads = max_jobs * nvcc_threads and peak_threads * 5GB <= free_memory.
max_num_jobs_memory = max(1, int(free_memory_gb / (5 * nvcc_threads)))
# pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
print(
f"Auto set MAX_JOBS to `{max_jobs}`, NVCC_THREADS to `{nvcc_threads}`. "
"If you see memory pressure, please use a lower `MAX_JOBS=N` or `NVCC_THREADS=N` value."
)
os.environ["MAX_JOBS"] = str(max_jobs)
super().__init__(*args, **kwargs)