mirror of
https://github.com/BillyOutlast/flash-attention.git
synced 2026-06-30 21:07:55 -04:00
reduce chance of build oom (#2079)
This commit is contained in:
committed by
GitHub
parent
04e6ee1fb5
commit
f15ccf5ff2
@@ -64,6 +64,7 @@ SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE
|
||||
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
|
||||
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
|
||||
SKIP_CK_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CK_BUILD", "TRUE") == "TRUE" if USE_TRITON_ROCM else False
|
||||
NVCC_THREADS = os.getenv("NVCC_THREADS") or "4"
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def cuda_archs() -> str:
|
||||
@@ -186,8 +187,7 @@ def detect_hipify_v2():
|
||||
|
||||
|
||||
def append_nvcc_threads(nvcc_extra_args):
|
||||
nvcc_threads = os.getenv("NVCC_THREADS") or "4"
|
||||
return nvcc_extra_args + ["--threads", nvcc_threads]
|
||||
return nvcc_extra_args + ["--threads", NVCC_THREADS]
|
||||
|
||||
|
||||
def rename_cpp_to_cu(cpp_files):
|
||||
@@ -571,15 +571,23 @@ class NinjaBuildExtension(BuildExtension):
|
||||
if not os.environ.get("MAX_JOBS"):
|
||||
import psutil
|
||||
|
||||
nvcc_threads = max(1, int(NVCC_THREADS))
|
||||
|
||||
# Upper bound on MAX_JOBS based on CPU cores (half of the reported core count)
|
||||
max_num_jobs_cores = max(1, os.cpu_count() // 2)
|
||||
|
||||
# Upper bound on MAX_JOBS based on currently available memory
|
||||
free_memory_gb = psutil.virtual_memory().available / (1024 ** 3) # free memory in GB
|
||||
max_num_jobs_memory = int(free_memory_gb / 9) # each JOB peak memory cost is ~8-9GB when threads = 4
|
||||
# Assume worst-case peak observed memory usage of ~5GB per NVCC thread.
|
||||
# Limit: peak_threads = max_jobs * nvcc_threads and peak_threads * 5GB <= free_memory.
|
||||
max_num_jobs_memory = max(1, int(free_memory_gb / (5 * nvcc_threads)))
|
||||
|
||||
# Pick the lower of the core-based and memory-based job limits to minimize OOM and swap usage during compilation
|
||||
max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
|
||||
print(
|
||||
f"Auto set MAX_JOBS to `{max_jobs}`, NVCC_THREADS to `{nvcc_threads}`. "
|
||||
"If you see memory pressure, please use a lower `MAX_JOBS=N` or `NVCC_THREADS=N` value."
|
||||
)
|
||||
os.environ["MAX_JOBS"] = str(max_jobs)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user