mirror of
https://github.com/BillyOutlast/flash-attention-prebuild-wheels-rocm.git
synced 2026-07-01 01:27:54 -04:00
28 lines
757 B
Python
28 lines
757 B
Python
import sys
|
|
|
|
support_cuda_versions = {
|
|
"2.0": [117, 118],
|
|
"2.1": [118, 121],
|
|
"2.2": [118, 121],
|
|
"2.3": [118, 121],
|
|
"2.4": [118, 121, 124],
|
|
"2.5": [118, 121, 124],
|
|
"2.6": [118, 124, 126],
|
|
"2.7": [118, 126, 128],
|
|
"2.8": [126, 128, 129],
|
|
"2.9": [126, 128, 130],
|
|
}
|
|
|
|
cuda_version = int(sys.argv[1])
|
|
matrix_torch_version = sys.argv[2]
|
|
|
|
target_cuda_versions = support_cuda_versions[matrix_torch_version]
|
|
target_cuda_versions = [
|
|
v for v in target_cuda_versions if str(v)[:2] == str(cuda_version)[:2]
|
|
]
|
|
if len(target_cuda_versions) == 0:
|
|
closest_version = support_cuda_versions[matrix_torch_version][-1]
|
|
else:
|
|
closest_version = min(target_cuda_versions, key=lambda x: abs(x - cuda_version))
|
|
print(closest_version)
|