update workflow

This commit is contained in:
Junya Morioka
2025-04-30 22:01:30 +09:00
parent 4e8ae49f60
commit 211b267280
+8 -4
View File
@@ -83,7 +83,7 @@ jobs:
cuda: ${{ matrix.cuda-version }}
linux-local-args: '["--toolkit"]'
method: "network"
- run: sudo apt install -y ninja-build
sub-packages: '["nvcc"]'
- name: Set CUDA and PyTorch versions
run: |
@@ -91,10 +91,14 @@ jobs:
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
echo "CACHE_KEY=cuda-ext-${{ matrix.flash-attn-version }}-py${{ matrix.python-version }}-torch${{ matrix.torch-version }}-cuda${{ matrix.cuda-version }}" >> $GITHUB_ENV
- name: Install build dependencies
run: |
sudo apt install -y ninja-build
pip install -U pip
pip install setuptools==75.8.0 wheel setuptools packaging psutil
- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
run: |
pip install -U pip
pip install wheel setuptools packaging
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
support_cuda_versions = { \
'2.0': [117, 118], \
@@ -139,10 +143,10 @@ jobs:
- name: Build wheels
timeout-minutes: 600
run: |
pip install setuptools==68.0.0 ninja packaging wheel
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
export MAX_JOBS=$(nproc)
export NVCC_THREADS=$(nproc)
cd flash-attention
FLASH_ATTENTION_FORCE_BUILD="TRUE" python setup.py bdist_wheel --dist-dir=dist
base_wheel_name=$(basename $(ls dist/*.whl | head -n 1))