[ROCm] Enable MXFP4 MoE weight pre-shuffling on gfx950 and update aiter (#34192)
Signed-off-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete
|
||||
ARG TRITON_BRANCH="f332c492"
|
||||
ARG TRITON_BRANCH="57c693b6"
|
||||
ARG TRITON_REPO="https://github.com/ROCm/triton.git"
|
||||
ARG PYTORCH_BRANCH="89075173"
|
||||
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
|
||||
@@ -9,7 +9,7 @@ ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
|
||||
ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
|
||||
ARG FA_BRANCH="0e60e394"
|
||||
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
||||
ARG AITER_BRANCH="6af8b687"
|
||||
ARG AITER_BRANCH="v0.1.10.post2"
|
||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||
ARG MORI_BRANCH="2d02c6a9"
|
||||
ARG MORI_REPO="https://github.com/ROCm/mori.git"
|
||||
@@ -239,7 +239,7 @@ RUN pip install pyyaml && cd aiter \
|
||||
export HIP_CLANG_PATH=/opt/sccache-wrappers \
|
||||
&& sccache --show-stats; \
|
||||
fi \
|
||||
&& PREBUILD_KERNELS=1 GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist \
|
||||
&& GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist \
|
||||
&& if [ "$USE_SCCACHE" = "1" ]; then sccache --show-stats; fi \
|
||||
&& ls /app/aiter/dist/*.whl
|
||||
RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install
|
||||
|
||||
@@ -933,7 +933,15 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
layer.w2_weight.view(self.fp4_dtype),
|
||||
requires_grad=layer.w2_weight.requires_grad,
|
||||
)
|
||||
# Pre-shuffle weight
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||
layer.w13_weight.is_shuffled = True
|
||||
layer.w2_weight.is_shuffled = True
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
|
||||
Reference in New Issue
Block a user