[FEAT] [ROCm] Upgrade AITER Fused MoE kernels. (#18271)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Author: vllmellm
Date: 2025-05-27 14:14:07 +08:00 (committed by GitHub)
Parent: b50602d5f0
Commit: d260f799a9
4 changed files with 130 additions and 314 deletions

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -286,9 +286,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 rocm_aiter_fused_experts, shuffle_weights)
             # reshaping weights is required for aiter moe kernel.
-            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
-                                                        layer.w2_weight.data,
-                                                        layout=(16, 16))
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight.data, layer.w2_weight.data)
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
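
The layout keyword disappears from every call site in this commit because the upgraded helper now owns the tiling choice itself. A minimal sketch of what the post-upgrade shuffle_weights could look like; the (16, 16) default and the delegation to aiter.ops.shuffle.shuffle_weight are assumptions inferred from the old call sites, not verbatim upstream code:

# Sketch only: assumes AITER exposes aiter.ops.shuffle.shuffle_weight with a
# layout keyword; verify against the AITER version pinned by this commit.
import torch


def shuffle_weights(*tensors: torch.Tensor,
                    layout: tuple[int, int] = (16, 16)
                    ) -> tuple[torch.Tensor, ...]:
    """Apply AITER's weight shuffle to each tensor, returning new tensors."""
    from aiter.ops.shuffle import shuffle_weight  # ROCm-only dependency

    return tuple(shuffle_weight(t, layout=layout) for t in tensors)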

vllm/model_executor/layers/quantization/fp8.py

@@ -595,7 +595,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
+            is_rocm_aiter_moe_enabled, shuffle_weights)
         self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
@@ -627,9 +627,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if self.rocm_aiter_moe_enabled:
             # reshaping weights is required for aiter moe kernel.
             shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight.data,
-                layer.w2_weight.data,
-                layout=(16, 16))
+                layer.w13_weight.data, layer.w2_weight.data)
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
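
Both sides of this hunk end the same way: the shuffled tensor has to be wrapped back into a torch.nn.Parameter with requires_grad=False so the module keeps tracking it for state_dict and device placement. A self-contained illustration of that re-registration pattern (the transpose is only a stand-in for the real shuffle):

import torch

layer = torch.nn.Linear(16, 16, bias=False)

# Re-layout the weight out of place, then re-register it as a frozen
# parameter on the module.
with torch.no_grad():
    relaid = layer.weight.data.t().contiguous()  # stand-in for shuffle_weights
layer.weight = torch.nn.Parameter(relaid, requires_grad=False)

assert not layer.weight.requires_grad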
@@ -675,20 +673,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                                  requires_grad=False)
         if self.rocm_aiter_moe_enabled:
             # reshaping weights is required for aiter moe kernel.
-            w13_scales, w2_scales = expand_weights(
-                layer.w13_weight_scale.data,
-                layer.w2_weight_scale.data,
-                expansion_dims=[
-                    layer.w13_weight.shape[1], layer.w2_weight.shape[1]
-                ])
-            layer.w13_weight_scale = torch.nn.Parameter(
-                w13_scales.contiguous(), requires_grad=False)
-            layer.w2_weight_scale = torch.nn.Parameter(
-                w2_scales.contiguous(), requires_grad=False)
-            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
-                                                        layer.w2_weight,
-                                                        layout=(16, 16))
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight, layer.w2_weight)
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
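
The deleted expand_weights helper existed only to broadcast each per-expert scale across the intermediate dimension for the old kernels; the upgraded AITER kernels consume the unexpanded scales directly, which is where much of this commit's 314-line deletion count comes from. A sketch of the removed helper, reconstructed from the call sites above:

import torch


def expand_weights(*tensors: torch.Tensor,
                   expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
    """Broadcast each tensor along a new trailing dim of the given size."""
    assert len(tensors) == len(expansion_dims), (
        "Number of tensors must match the number of expansion dimensions.")
    return tuple(
        tensor.unsqueeze(-1).expand(*tensor.shape, dim)
        for tensor, dim in zip(tensors, expansion_dims))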
@@ -760,20 +746,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 start += shard_size
         if self.rocm_aiter_moe_enabled:
             # reshaping weights is required for aiter moe kernel.
-            expansion_dims = [
-                layer.w13_weight.shape[1], layer.w2_weight.shape[1]
-            ]
-            max_w13_scales, w2_scales = expand_weights(
-                max_w13_scales,
-                layer.w2_weight_scale.data,
-                expansion_dims=expansion_dims)
-            layer.w2_weight_scale = torch.nn.Parameter(
-                w2_scales.contiguous(), requires_grad=False)
-            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
-                                                        layer.w2_weight,
-                                                        layout=(32, 32))
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight, layer.w2_weight)
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
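
None of the shuffling above runs unless the AITER path is switched on. A hypothetical end-to-end sketch, assuming a ROCm build of vLLM and the environment toggles named earlier (the model choice is purely illustrative):

import os

# Set the toggles before the engine is constructed so envs picks them up.
os.environ["VLLM_ROCM_USE_AITER"] = "1"      # master switch for AITER kernels
os.environ["VLLM_ROCM_USE_AITER_MOE"] = "1"  # MoE-specific toggle

from vllm import LLM

llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", quantization="fp8")
print(llm.generate("ROCm says hello")[0].outputs[0].text)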