diff --git a/.github/workflows/_build_rocm.yml b/.github/workflows/_build_rocm.yml index 3b8cc89..be13ce1 100644 --- a/.github/workflows/_build_rocm.yml +++ b/.github/workflows/_build_rocm.yml @@ -44,7 +44,7 @@ jobs: name: Build wheels and Upload (ROCm) runs-on: ${{ inputs.runner }} container: - image: rocm/pytorch:rocm${{ inputs.rocm-version }}_ubuntu22.04_py3.10_pytorch_2.1.2 + image: ${{ inputs.rocm-version == '7.1.1' && 'rocm/pytorch:rocm7.1.1_ubuntu24.04_py3.12_pytorch_release_2.9.1' || 'rocm/pytorch:rocm7.2_ubuntu22.04_py3.10_pytorch_release_2.9.1' }} options: --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host env: DEBIAN_FRONTEND: noninteractive diff --git a/create_matrix.py b/create_matrix.py index a900517..4abbd1a 100644 --- a/create_matrix.py +++ b/create_matrix.py @@ -265,7 +265,7 @@ ROCM_MATRIX = { "2.9.1", ], "rocm-version": [ - "7.1", + "7.1.1", "7.2", ], } diff --git a/doc/plans/rocm-build-setup.md b/doc/plans/rocm-build-setup.md index 5ef55f8..99d4ed2 100644 --- a/doc/plans/rocm-build-setup.md +++ b/doc/plans/rocm-build-setup.md @@ -160,9 +160,10 @@ python -c "import flash_attn_2_cuda; print('Flash Attention ROCm loaded successf ## Docker Container Details -The workflow uses official ROCm PyTorch containers: +The workflow uses official ROCm PyTorch containers that match the published tags, for example: ``` -rocm/pytorch:rocm{version}_ubuntu22.04_py3.10_pytorch_2.1.2 +rocm/pytorch:rocm7.1.1_ubuntu24.04_py3.12_pytorch_release_2.9.1 +rocm/pytorch:rocm7.2_ubuntu22.04_py3.10_pytorch_release_2.9.1 ``` Device passthrough options: