[ROCm] Enable MXFP4 MoE weight pre-shuffling on gfx950 and update aiter (#34192)

Signed-off-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
This commit is contained in:
Douglas Lehr
2026-02-12 07:06:33 -06:00
committed by GitHub
parent fb455ed547
commit 8a798be929
2 changed files with 11 additions and 3 deletions

View File

@@ -933,7 +933,15 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_weight.view(self.fp4_dtype),
requires_grad=layer.w2_weight.requires_grad,
)
# Pre-shuffle the w13 and w2 weights via rocm_aiter_ops and mark them as shuffled
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
layer.w13_weight.is_shuffled = True
layer.w2_weight.is_shuffled = True
torch.cuda.empty_cache()
def get_fused_moe_quant_config(