diff --git a/build_linux.sh b/build_linux.sh index d072e2a..3b9571b 100755 --- a/build_linux.sh +++ b/build_linux.sh @@ -47,7 +47,8 @@ git clone https://github.com/Dao-AILab/flash-attention.git -b "v$FLASH_ATTN_VERS # Build wheels echo "Building wheels..." cd flash-attention -FLASH_ATTENTION_FORCE_BUILD=TRUE python setup.py bdist_wheel --dist-dir=dist +LOCAL_VERSION_LABEL="cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}" +FLASH_ATTENTION_FORCE_BUILD=TRUE FLASH_ATTN_LOCAL_VERSION=${LOCAL_VERSION_LABEL} python setup.py bdist_wheel --dist-dir=dist base_wheel_name=$(basename $(ls dist/*.whl | head -n 1)) wheel_name=$(echo $base_wheel_name | sed "s/$FLASH_ATTN_VERSION/$FLASH_ATTN_VERSION+cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}/") mv -v dist/$base_wheel_name dist/$wheel_name diff --git a/build_windows.ps1 b/build_windows.ps1 index d67b5d6..c9f32e0 100644 --- a/build_windows.ps1 +++ b/build_windows.ps1 @@ -68,6 +68,7 @@ $env:FLASH_ATTENTION_FORCE_BUILD = "TRUE" $env:NVCC_FLAGS = "-w --disable-warnings" $env:CXXFLAGS = "/w" $env:CFLAGS = "/w" +$env:FLASH_ATTN_LOCAL_VERSION = "cu$MatrixCudaVersion" + "torch$MatrixTorchVersion" cd flash-attention python setup.py bdist_wheel --dist-dir=dist diff --git a/create_matrix.py b/create_matrix.py index 831a6e7..62d2108 100644 --- a/create_matrix.py +++ b/create_matrix.py @@ -23,18 +23,21 @@ EXCLUDE = [ LINUX_MATRIX = { "flash-attn-version": [ - # "2.6.3", "2.7.4.post1", "2.8.3" - "2.8.1" + # "2.6.3", "2.7.4.post1" + "2.8.3" + ], + "python-version": [ + # "3.10", "3.11", "3.12", + "3.13" ], - "python-version": ["3.10", "3.11", "3.12", "3.13"], "torch-version": [ # "2.5.1", "2.6.0", "2.7.1", "2.8.0", "2.9.0", ], "cuda-version": [ # "12.4.1", "12.6.3", "12.8.1", "12.9.1", - "12.8.1", - "13.0.1", + # "12.8.1", + "13.0.2", ], } @@ -46,10 +49,22 @@ LINUX_SELF_HOSTED_MATRIX = { } WINDOWS_MATRIX = { - "flash-attn-version": ["2.7.4.post1", "2.8.3"], - "python-version": ["3.10", "3.11", "3.12", "3.13"], - "torch-version": ["2.5.1", "2.6.0", "2.7.1", "2.8.0", "2.9.0"], - "cuda-version": ["12.4.1", "12.6.3", "12.8.1", "12.9.1", "13.0.1"], + "flash-attn-version": [ + # "2.7.4.post1", + "2.8.3" + ], + "python-version": [ + # "3.10", "3.11", "3.12", + "3.13" + ], + "torch-version": [ + # "2.5.1", "2.6.0", "2.7.1", "2.8.0", + "2.9.0" + ], + "cuda-version": [ + # "12.4.1", "12.6.3", "12.8.1", "12.9.1", + "13.0.1" + ], } WINDOWS_CODEBUILD_MATRIX = { @@ -68,8 +83,8 @@ def main(): # "linux": False, # "linux_self_hosted": LINUX_SELF_HOSTED_MATRIX, "linux_self_hosted": False, - # "windows": WINDOWS_MATRIX, - "windows": False, + "windows": WINDOWS_MATRIX, + # "windows": False, # "windows_code_build": WINDOWS_CODEBUILD_MATRIX, "windows_code_build": False, "exclude": EXCLUDE,