Files
flash-attention-prebuild-wh…/get_torch_cuda_version.py
T
Junya Morioka 6e93ad0663 fix: update PyTorch and CUDA version compatibility rules
- Refine exclusion logic in create_matrix.py to match actual support
- Update CUDA version support for PyTorch 2.9 and 2.10
- Remove temporary build exclusions
2026-01-27 18:33:34 +09:00

29 lines
786 B
Python

import sys
# CUDA versions to build against for each PyTorch "major.minor" release.
# Versions are encoded as three-digit integers (e.g. 124 for CUDA 12.4); the
# selection logic below treats the first two digits as the CUDA major version.
support_cuda_versions = {
    "2.0": [117, 118],
    "2.1": [118, 121],
    "2.2": [118, 121],
    "2.3": [118, 121],
    "2.4": [118, 121, 124],
    "2.5": [118, 121, 124],
    "2.6": [118, 124, 126],
    "2.7": [118, 126, 128],
    "2.8": [126, 128, 129],
    "2.9": [126, 128, 130],
    "2.10": [126, 128, 130],
}


def select_cuda_version(cuda_version, matrix_torch_version):
    """Return the supported CUDA version to use for a given PyTorch version.

    Among the CUDA versions listed for ``matrix_torch_version``, pick the one
    whose first two digits (the CUDA major version) match ``cuda_version`` and
    that is numerically closest to it; on a tie the earlier (lower) entry wins.
    If no listed version shares the major version, the last (highest) listed
    version is returned.

    Args:
        cuda_version: Requested CUDA version as an integer, e.g. ``124``.
        matrix_torch_version: PyTorch version key, e.g. ``"2.6"``.

    Returns:
        The selected CUDA version as an integer.

    Raises:
        ValueError: If ``matrix_torch_version`` is not in
            ``support_cuda_versions``.
    """
    try:
        candidates = support_cuda_versions[matrix_torch_version]
    except KeyError:
        supported = ", ".join(support_cuda_versions)
        raise ValueError(
            f"unsupported PyTorch version {matrix_torch_version!r}; "
            f"expected one of: {supported}"
        ) from None

    # Compare on the leading two digits, i.e. the CUDA major version.
    requested_major = str(cuda_version)[:2]
    same_major = [v for v in candidates if str(v)[:2] == requested_major]
    if not same_major:
        # No build for this CUDA major version: fall back to the last listed.
        return candidates[-1]
    # min() keeps the first of equally-close candidates, so ties go low.
    return min(same_major, key=lambda v: abs(v - cuda_version))


def main(argv=None):
    """CLI entry point: ``get_torch_cuda_version.py CUDA_VERSION TORCH_VERSION``.

    Prints the selected CUDA version to stdout. Exits with a non-zero status
    and a message on stderr when arguments are missing or invalid.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        sys.exit("usage: get_torch_cuda_version.py CUDA_VERSION TORCH_VERSION")
    try:
        cuda_version = int(argv[1])
    except ValueError:
        sys.exit(
            f"error: CUDA version must be an integer such as 124, got {argv[1]!r}"
        )
    try:
        print(select_cuda_version(cuda_version, argv[2]))
    except ValueError as err:
        sys.exit(f"error: {err}")


if __name__ == "__main__":
    main()