mirror of
https://github.com/BillyOutlast/flash-attention-prebuild-wheels-rocm.git
synced 2026-07-01 01:37:53 -04:00
135 lines
4.4 KiB
Bash
135 lines
4.4 KiB
Bash
#!/bin/bash
#
# Builds flash-attention wheels for ROCm.
#
# Positional arguments:
#   $1 - flash-attention version (without the leading "v")
#   $2 - Python version (only reported in the banner below)
#   $3 - PyTorch version
#   $4 - ROCm version

set -e

# Positional parameters. No defaults are applied: a missing argument is left
# empty, and the ${N:-} form keeps the reads safe even if `set -u` is enabled.
read_parameters() {
  FLASH_ATTN_VERSION=${1:-}
  PYTHON_VERSION=${2:-}
  TORCH_VERSION=${3:-}
  ROCM_VERSION=${4:-}
}

read_parameters "$@"

echo "Building Flash Attention for ROCm with parameters:"
echo " Flash-Attention: $FLASH_ATTN_VERSION"
echo " Python: $PYTHON_VERSION"
echo " PyTorch: $TORCH_VERSION"
echo " ROCm: $ROCM_VERSION"

# Set ROCm and PyTorch versions

# Prints the first two dot-separated fields of a version string as
# "MAJOR.MINOR" (e.g. "6.2.4" -> "6.2"). A missing field is printed as empty,
# so "6" yields "6." and an empty string yields ".".
# Arguments: $1 - dot-separated version string
# Outputs:   "MAJOR.MINOR" on stdout
major_minor() {
  local major minor _rest
  IFS=. read -r major minor _rest <<< "${1:-}"
  printf '%s.%s\n' "$major" "$minor"
}

MATRIX_ROCM_VERSION=$(major_minor "${ROCM_VERSION:-}")
MATRIX_TORCH_VERSION=$(major_minor "${TORCH_VERSION:-}")

echo "Derived versions:"
echo " ROCm Matrix: $MATRIX_ROCM_VERSION"
echo " Torch Matrix: $MATRIX_TORCH_VERSION"

# Install PyTorch for ROCm
echo "Installing PyTorch $TORCH_VERSION for ROCm $MATRIX_ROCM_VERSION..."
if [[ "$TORCH_VERSION" == *"dev"* ]]; then
  # Versions containing "dev" are installed as pre-releases from the nightly index.
  pip install --force-reinstall --no-cache-dir --pre "torch==${TORCH_VERSION}" \
    --index-url "https://download.pytorch.org/whl/nightly/rocm${MATRIX_ROCM_VERSION}"
else
  pip install --force-reinstall --no-cache-dir "torch==${TORCH_VERSION}" \
    --index-url "https://download.pytorch.org/whl/rocm${MATRIX_ROCM_VERSION}"
fi

# Install additional dependencies
echo "Installing build dependencies..."
pip install ninja packaging

# Verify installation: print the interpreter version, the installed torch
# version, and the HIP version torch reports ('Not found' when torch.version
# has no `hip` attribute).
echo "Verifying installations..."
python -V
python -c "import torch; print('PyTorch:', torch.__version__)"
python -c "import torch; print('ROCm:', torch.version.hip if hasattr(torch.version, 'hip') else 'Not found')"

# Display ROCm information (best effort: each tool is skipped when it is not
# installed, and its failure is deliberately ignored so that `set -e` does not
# abort the build over a purely informational step).
if command -v rocm-smi &> /dev/null; then
  echo "ROCm SMI:"
  rocm-smi --showproductname || true
fi

if command -v rocminfo &> /dev/null; then
  echo "Available ROCm devices:"
  rocminfo | grep -E "Name:|Marketing Name:" || true
fi

# Checkout flash-attn: clone the upstream repository at the tag/branch named
# "v<FLASH_ATTN_VERSION>" into ./flash-attention.
echo "Checking out flash-attention v$FLASH_ATTN_VERSION..."
git clone https://github.com/Dao-AILab/flash-attention.git -b "v${FLASH_ATTN_VERSION}"

# Determine MAX_JOBS based on system resources
NUM_THREADS=$(nproc)
# Second column of the "Mem:" row of `free -g` (whole GiB, rounded down by free).
RAM_GB=$(free -g | awk '/^Mem:/{print $2}')
echo "System resources:"
echo " CPU threads: $NUM_THREADS"
echo " RAM: ${RAM_GB}GB"

# Prints the number of parallel build jobs for the given resources.
# ROCm builds are memory intensive, so one job is budgeted per 4 GB of RAM;
# the result is the smaller of that and the CPU thread count, clamped to 1..8.
# Arguments: $1 - CPU thread count (treated as 1 when empty)
#            $2 - RAM in GB        (treated as 0 when empty)
# Outputs:   job count on stdout
compute_max_jobs() {
  local threads=${1:-1}
  local ram_gb=${2:-0}
  local by_ram=$(( ram_gb / 4 ))
  local jobs=$(( threads < by_ram ? threads : by_ram ))

  # Ensure minimum values
  jobs=$(( jobs < 1 ? 1 : jobs ))

  # Cap at 8 to avoid overwhelming the system
  jobs=$(( jobs > 8 ? 8 : jobs ))

  printf '%s\n' "$jobs"
}

# A MAX_JOBS value already present in the environment takes precedence.
if [[ -z "${MAX_JOBS:-}" ]]; then
  MAX_JOBS=$(compute_max_jobs "$NUM_THREADS" "$RAM_GB")
fi

echo "Build parallelism settings:"
echo " MAX_JOBS: $MAX_JOBS"

# Detect GPU architecture if available

# Prints the first "gfx…" token found on stdin, or nothing when there is none.
# The pipeline's status is that of `head`, so a failed match does not trip
# `set -e` in the caller.
first_gfx_arch() {
  grep -o -m 1 'gfx[0-9a-z]*' | head -n 1
}

if command -v rocminfo &> /dev/null; then
  GPU_ARCH=$(rocminfo | first_gfx_arch)
  if [[ -n "$GPU_ARCH" ]]; then
    echo "Detected GPU architecture: $GPU_ARCH"
    # NOTE: a detected GPU replaces any PYTORCH_ROCM_ARCH inherited from the
    # environment; only the first architecture reported is used.
    export PYTORCH_ROCM_ARCH=$GPU_ARCH
  fi
fi

# Set default architectures if not detected
# Common ROCm architectures: gfx90a (MI200), gfx942 (MI300), gfx1030, gfx1100
if [[ -z "${PYTORCH_ROCM_ARCH:-}" ]]; then
  # Build for multiple common architectures
  export PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx1030;gfx1100"
  echo "No GPU detected, building for multiple architectures: $PYTORCH_ROCM_ARCH"
fi

# Build wheels with Composable Kernel (CK) backend (default for AMD)
echo "Building wheels with Composable Kernel backend..."
cd flash-attention

# Local version label: "rocm" + ROCm MAJOR.MINOR with the dots removed +
# "torch" + torch MAJOR.MINOR, e.g. "rocm62torch2.5".
LOCAL_VERSION_LABEL="rocm${MATRIX_ROCM_VERSION//./}torch${MATRIX_TORCH_VERSION}"

# Disable Triton backend to use Composable Kernel (CK) backend
export FLASH_ATTENTION_TRITON_AMD_ENABLE="FALSE"
export FLASH_ATTENTION_FORCE_BUILD=TRUE
export FLASH_ATTN_LOCAL_VERSION=${LOCAL_VERSION_LABEL}

# MAX_JOBS is exported rather than written as a `VAR=value time …` prefix:
# after an assignment prefix `time` is no longer the shell keyword, so Bash
# would look for an external `time` binary, which is absent on many minimal
# images. With the export in place the built-in keyword is used.
export MAX_JOBS
time python setup.py bdist_wheel --dist-dir=dist

# Pick the first wheel via a glob instead of parsing `ls`; fail loudly if the
# build produced nothing.
built_wheels=(dist/*.whl)
if [[ ! -e "${built_wheels[0]}" ]]; then
  echo "Error: no wheel was produced in dist/" >&2
  exit 1
fi
wheel_name=${built_wheels[0]##*/}
echo "Built wheel: $wheel_name"

# Optional: Also build Triton backend if requested
# (enabled by setting BUILD_TRITON_BACKEND=true in the environment).
if [[ "${BUILD_TRITON_BACKEND:-false}" == "true" ]]; then
  echo "Building additional wheel with Triton backend..."

  # Clean previous build. Only the intermediate build tree is removed: the
  # Triton wheel goes to dist_triton/, so dist/ (holding the wheel built
  # earlier) is kept rather than deleted.
  python setup.py clean
  rm -rf build

  LOCAL_VERSION_LABEL_TRITON="rocm${MATRIX_ROCM_VERSION//./}torch${MATRIX_TORCH_VERSION}triton"

  export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
  export FLASH_ATTN_LOCAL_VERSION=${LOCAL_VERSION_LABEL_TRITON}

  # Export MAX_JOBS so that `time` below is the shell keyword; written as
  # `MAX_JOBS=… time python …` Bash would need an external `time` binary.
  export MAX_JOBS
  time python setup.py bdist_wheel --dist-dir=dist_triton

  # Pick the first wheel via a glob instead of parsing `ls`; fail loudly if
  # the build produced nothing.
  triton_wheels=(dist_triton/*.whl)
  if [[ ! -e "${triton_wheels[0]}" ]]; then
    echo "Error: no wheel was produced in dist_triton/" >&2
    exit 1
  fi
  wheel_name_triton=${triton_wheels[0]##*/}
  echo "Built Triton wheel: $wheel_name_triton"
fi

echo "Build complete!"