mirror of
https://github.com/BillyOutlast/flash-attention-prebuild-wheels-rocm.git
synced 2026-07-01 01:37:53 -04:00
update workflow and script
This commit is contained in:
+2
-23
@@ -23,30 +23,9 @@ echo " CUDA Matrix: $MATRIX_CUDA_VERSION"
|
||||
echo " Torch Matrix: $MATRIX_TORCH_VERSION"
|
||||
|
||||
# Install PyTorch
|
||||
echo "Installing PyTorch $TORCH_VERSION+cu$CUDA_VERSION..."
|
||||
TORCH_CUDA_VERSION=$(python -c "
|
||||
support_cuda_versions = {
|
||||
'2.0': [117, 118],
|
||||
'2.1': [118, 121],
|
||||
'2.2': [118, 121],
|
||||
'2.3': [118, 121],
|
||||
'2.4': [118, 121, 124],
|
||||
'2.5': [118, 121, 124],
|
||||
'2.6': [118, 124, 126],
|
||||
'2.7': [118, 126, 128],
|
||||
'2.8': [128],
|
||||
};
|
||||
cuda_version = int('$MATRIX_CUDA_VERSION');
|
||||
matrix_torch_version = '$MATRIX_TORCH_VERSION';
|
||||
target_cuda_versions = support_cuda_versions[matrix_torch_version];
|
||||
target_cuda_versions = [v for v in target_cuda_versions if str(v)[:2] == str(cuda_version)[:2]];
|
||||
if len(target_cuda_versions) == 0:
|
||||
closest_version = support_cuda_versions[matrix_torch_version][-1];
|
||||
else:
|
||||
closest_version = min(target_cuda_versions, key=lambda x: abs(x - cuda_version));
|
||||
print(closest_version);
|
||||
")
|
||||
TORCH_CUDA_VERSION=$(python get_torch_cuda_version.py $MATRIX_CUDA_VERSION $MATRIX_TORCH_VERSION)
|
||||
|
||||
echo "Installing PyTorch $TORCH_VERSION+cu$TORCH_CUDA_VERSION..."
|
||||
if [[ $TORCH_VERSION == *"dev"* ]]; then
|
||||
pip install --force-reinstall --no-cache-dir --pre torch==$TORCH_VERSION --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
|
||||
else
|
||||
|
||||
+4
-4
@@ -31,13 +31,13 @@ Write-Host " CUDA Matrix: $MatrixCudaVersion"
|
||||
Write-Host " Torch Matrix: $MatrixTorchVersion"
|
||||
|
||||
# Install PyTorch
|
||||
Write-Host "Installing PyTorch $TorchVersion+cu$CudaVersion..."
|
||||
$env:TORCH_CUDA_VERSION = python -c "from os import environ as env; support_cuda_versions = { '2.1': [121], '2.2': [121], '2.3': [121], '2.4': [121, 124], '2.5': [121, 124], '2.6': [124, 126], '2.7': [126, 128], '2.8': [128], }; target_cuda_versions = support_cuda_versions['$MatrixTorchVersion']; cuda_version = int('$MatrixCudaVersion'); closest_version = min(target_cuda_versions, key=lambda x: abs(x - cuda_version)); print(closest_version)"
|
||||
$env:TORCH_CUDA_VERSION = python get_torch_cuda_version.py $MatrixCudaVersion $MatrixTorchVersion
|
||||
|
||||
Write-Host "Installing PyTorch $TorchVersion+cu$env:TORCH_CUDA_VERSION..."
|
||||
if ($TorchVersion -like "*dev*") {
|
||||
pip install --pre torch==$TorchVersion --index-url https://download.pytorch.org/whl/nightly/cu$env:TORCH_CUDA_VERSION
|
||||
pip install --force-reinstall --no-cache-dir --pre torch==$TorchVersion --index-url https://download.pytorch.org/whl/nightly/cu$env:TORCH_CUDA_VERSION
|
||||
} else {
|
||||
pip install --no-cache-dir torch==$TorchVersion --index-url https://download.pytorch.org/whl/cu$env:TORCH_CUDA_VERSION
|
||||
pip install --force-reinstall --no-cache-dir torch==$TorchVersion --index-url https://download.pytorch.org/whl/cu$env:TORCH_CUDA_VERSION
|
||||
}
|
||||
|
||||
# Verify installation
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
import sys
|
||||
|
||||
support_cuda_versions = {
|
||||
"2.0": [117, 118],
|
||||
"2.1": [118, 121],
|
||||
"2.2": [118, 121],
|
||||
"2.3": [118, 121],
|
||||
"2.4": [118, 121, 124],
|
||||
"2.5": [118, 121, 124],
|
||||
"2.6": [118, 124, 126],
|
||||
"2.7": [118, 126, 128],
|
||||
"2.8": [128],
|
||||
}
|
||||
|
||||
cuda_version = int(sys.argv[1])
|
||||
matrix_torch_version = sys.argv[2]
|
||||
|
||||
target_cuda_versions = support_cuda_versions[matrix_torch_version]
|
||||
target_cuda_versions = [
|
||||
v for v in target_cuda_versions if str(v)[:2] == str(cuda_version)[:2]
|
||||
]
|
||||
if len(target_cuda_versions) == 0:
|
||||
closest_version = support_cuda_versions[matrix_torch_version][-1]
|
||||
else:
|
||||
closest_version = min(target_cuda_versions, key=lambda x: abs(x - cuda_version))
|
||||
print(closest_version)
|
||||
Reference in New Issue
Block a user