[ROCm] Enable MXFP4 MoE weight pre-shuffling on gfx950 and update aiter (#34192)
Signed-off-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
This commit is contained in:
@@ -933,7 +933,15 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
layer.w2_weight.view(self.fp4_dtype),
|
||||
requires_grad=layer.w2_weight.requires_grad,
|
||||
)
|
||||
# Pre-shuffle weight
|
||||
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data
|
||||
)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
|
||||
layer.w13_weight.is_shuffled = True
|
||||
layer.w2_weight.is_shuffled = True
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
|
||||
Reference in New Issue
Block a user