mirror of
https://github.com/BillyOutlast/flash-attention-prebuild-wheels-rocm.git
synced 2026-07-01 01:37:53 -04:00
fix: linux script
This commit is contained in:
+20
-21
@@ -24,27 +24,26 @@ echo " Torch Matrix: $MATRIX_TORCH_VERSION"
|
||||
|
||||
# Install PyTorch
|
||||
echo "Installing PyTorch $TORCH_VERSION+cu$CUDA_VERSION..."
|
||||
TORCH_CUDA_VERSION=$(python -c "\
|
||||
support_cuda_versions = { \
|
||||
'2.0': [117, 118], \
|
||||
'2.1': [118, 121], \
|
||||
'2.2': [118, 121], \
|
||||
'2.3': [118, 121], \
|
||||
'2.4': [118, 121, 124], \
|
||||
'2.5': [118, 121, 124], \
|
||||
'2.6': [118, 124, 126], \
|
||||
'2.7': [118, 126, 128], \
|
||||
'2.8': [128], \
|
||||
}; \
|
||||
cuda_version = int('$MATRIX_CUDA_VERSION'); \
|
||||
matrix_torch_version = '$MATRIX_TORCH_VERSION'; \
|
||||
target_cuda_versions = support_cuda_versions[matrix_torch_version]; \
|
||||
target_cuda_versions = [v for v in target_cuda_versions if str(v)[:2] == str(cuda_version)[:2]]; \
|
||||
if len(target_cuda_versions) == 0: \
|
||||
closest_version = support_cuda_versions[matrix_torch_version][-1]; \
|
||||
else: \
|
||||
closest_version = min(target_cuda_versions, key=lambda x: abs(x - cuda_version)); \
|
||||
print(closest_version) \
|
||||
TORCH_CUDA_VERSION=$(python -c "support_cuda_versions = { \
|
||||
'2.0': [117, 118], \
|
||||
'2.1': [118, 121], \
|
||||
'2.2': [118, 121], \
|
||||
'2.3': [118, 121], \
|
||||
'2.4': [118, 121, 124], \
|
||||
'2.5': [118, 121, 124], \
|
||||
'2.6': [118, 124, 126], \
|
||||
'2.7': [118, 126, 128], \
|
||||
'2.8': [128], \
|
||||
}; \
|
||||
cuda_version = int('$MATRIX_CUDA_VERSION'); \
|
||||
matrix_torch_version = '$MATRIX_TORCH_VERSION'; \
|
||||
target_cuda_versions = support_cuda_versions[matrix_torch_version]; \
|
||||
target_cuda_versions = [v for v in target_cuda_versions if str(v)[:2] == str(cuda_version)[:2]]; \
|
||||
if len(target_cuda_versions) == 0: \
|
||||
closest_version = support_cuda_versions[matrix_torch_version][-1]; \
|
||||
else: \
|
||||
closest_version = min(target_cuda_versions, key=lambda x: abs(x - cuda_version)); \
|
||||
print(closest_version) \
|
||||
")
|
||||
|
||||
if [[ $TORCH_VERSION == *"dev"* ]]; then
|
||||
|
||||
Reference in New Issue
Block a user