diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base
index 948f8dc56..c6e972e89 100644
--- a/docker/Dockerfile.rocm_base
+++ b/docker/Dockerfile.rocm_base
@@ -1,5 +1,5 @@
 ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete
-ARG TRITON_BRANCH="f332c492"
+ARG TRITON_BRANCH="57c693b6"
 ARG TRITON_REPO="https://github.com/ROCm/triton.git"
 ARG PYTORCH_BRANCH="89075173"
 ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
@@ -9,7 +9,7 @@ ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
 ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
 ARG FA_BRANCH="0e60e394"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="6af8b687"
+ARG AITER_BRANCH="v0.1.10.post2"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
 ARG MORI_BRANCH="2d02c6a9"
 ARG MORI_REPO="https://github.com/ROCm/mori.git"
@@ -239,7 +239,7 @@ RUN pip install pyyaml && cd aiter \
         export HIP_CLANG_PATH=/opt/sccache-wrappers \
         && sccache --show-stats; \
     fi \
-    && PREBUILD_KERNELS=1 GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist \
+    && GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist \
     && if [ "$USE_SCCACHE" = "1" ]; then sccache --show-stats; fi \
     && ls /app/aiter/dist/*.whl
 RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 555b94c1c..66db09505 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -933,7 +933,15 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
             layer.w2_weight.view(self.fp4_dtype),
             requires_grad=layer.w2_weight.requires_grad,
         )
+        # Pre-shuffle weight
+        shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
+            layer.w13_weight.data, layer.w2_weight.data
+        )
+        layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
+        layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
+        layer.w13_weight.is_shuffled = True
+        layer.w2_weight.is_shuffled = True
         torch.cuda.empty_cache()

     def get_fused_moe_quant_config(