[FEAT] [ROCm] Upgrade AITER Fused MoE kernels. (#18271)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
@@ -286,9 +286,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 rocm_aiter_fused_experts, shuffle_weights)
 
             # reshaping weights is required for aiter moe kernel.
-            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
-                                                        layer.w2_weight.data,
-                                                        layout=(16, 16))
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight.data, layer.w2_weight.data)
 
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
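Note on the change above: the call sites no longer pass an explicit `layout=(16, 16)`; the tile layout is presumably chosen inside `shuffle_weights` to match the upgraded AITER kernels. As a rough illustration of what a tile shuffle of this kind does, here is a pure-PyTorch sketch; the function name, signature, and the reshape/permute pattern are assumptions for illustration, not AITER's actual implementation.

```python
import torch

def shuffle_weights_sketch(*tensors: torch.Tensor,
                           layout: tuple[int, int] = (16, 16)
                           ) -> tuple[torch.Tensor, ...]:
    """Illustrative only: regroup the last two dims into (IN, IK) tiles
    so each tile is contiguous, the kind of hardware-friendly layout a
    fused MoE kernel expects pre-shuffled weights to have."""
    IN, IK = layout
    out = []
    for t in tensors:
        *prefix, n, k = t.shape
        # Split rows/cols into tiles, then move the two tile axes next
        # to each other so every (IN, IK) tile is contiguous in memory.
        tiled = t.reshape(*prefix, n // IN, IN, k // IK, IK)
        tiled = tiled.permute(*range(len(prefix)), -4, -2, -3, -1)
        out.append(tiled.reshape(t.shape).contiguous())
    return tuple(out)
```

The new call shape, `shuffle_weights(layer.w13_weight.data, layer.w2_weight.data)`, would then simply rely on the helper's default layout.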
@@ -595,7 +595,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
+            is_rocm_aiter_moe_enabled, shuffle_weights)
 
         self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
 
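The lazy import above keeps triton out of the module import path; only `expand_weights` is dropped from it, since the hunks below remove its call sites. For context, `is_rocm_aiter_moe_enabled()` is an environment-gated feature check; the sketch below is a guess at its general shape, assuming vLLM's `VLLM_ROCM_USE_AITER` / `VLLM_ROCM_USE_AITER_MOE` flags and omitting the ROCm-platform check the real helper would also need.

```python
import os

def is_rocm_aiter_moe_enabled_sketch() -> bool:
    # Assumed gating: the global AITER switch must be turned on, and
    # the MoE-specific switch must not be turned off.
    use_aiter = os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
    use_aiter_moe = os.environ.get("VLLM_ROCM_USE_AITER_MOE", "1") == "1"
    return use_aiter and use_aiter_moe
```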
@@ -627,9 +627,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if self.rocm_aiter_moe_enabled:
             # reshaping weights is required for aiter moe kernel.
             shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight.data,
-                layer.w2_weight.data,
-                layout=(16, 16))
+                layer.w13_weight.data, layer.w2_weight.data)
 
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
@@ -675,20 +673,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                                       requires_grad=False)
             if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
-                w13_scales, w2_scales = expand_weights(
-                    layer.w13_weight_scale.data,
-                    layer.w2_weight_scale.data,
-                    expansion_dims=[
-                        layer.w13_weight.shape[1], layer.w2_weight.shape[1]
-                    ])
-                layer.w13_weight_scale = torch.nn.Parameter(
-                    w13_scales.contiguous(), requires_grad=False)
-                layer.w2_weight_scale = torch.nn.Parameter(
-                    w2_scales.contiguous(), requires_grad=False)
-
-                shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
-                                                            layer.w2_weight,
-                                                            layout=(16, 16))
+                shuffled_w13, shuffled_w2 = shuffle_weights(
+                    layer.w13_weight, layer.w2_weight)
 
                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
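This hunk and the @@ -760 hunk below delete the `expand_weights` pre-processing step, which is why the import at @@ -595 drops it as well. Judging from the removed call sites, `expand_weights` broadcast each per-expert scale across the weight's row dimension (`layer.w13_weight.shape[1]`, `layer.w2_weight.shape[1]`) so the old kernels could consume one scale entry per row; the upgraded kernels evidently take the compact scales as-is. A sketch of the removed behavior, reconstructed from the arguments visible in the diff (input shapes are assumptions):

```python
import torch

def expand_weights_sketch(*tensors: torch.Tensor,
                          expansion_dims: list[int]
                          ) -> tuple[torch.Tensor, ...]:
    # Reconstructed from the removed call sites: a per-expert scale of
    # assumed shape (num_experts,) gains a trailing dim and is broadcast
    # to (num_experts, dim), e.g. dim = layer.w13_weight.shape[1].
    return tuple(
        scale.unsqueeze(-1).expand(*scale.shape, dim)
        for scale, dim in zip(tensors, expansion_dims))

# Two experts, scales broadcast across an assumed row count of 4.
scales = torch.tensor([0.5, 2.0])
(expanded,) = expand_weights_sketch(scales, expansion_dims=[4])
assert expanded.shape == (2, 4)
```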
@@ -760,20 +746,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                         start += shard_size
 
             if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
-                expansion_dims = [
-                    layer.w13_weight.shape[1], layer.w2_weight.shape[1]
-                ]
-                max_w13_scales, w2_scales = expand_weights(
-                    max_w13_scales,
-                    layer.w2_weight_scale.data,
-                    expansion_dims=expansion_dims)
-                layer.w2_weight_scale = torch.nn.Parameter(
-                    w2_scales.contiguous(), requires_grad=False)
-
-                shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
-                                                            layer.w2_weight,
-                                                            layout=(32, 32))
+                shuffled_w13, shuffled_w2 = shuffle_weights(
+                    layer.w13_weight, layer.w2_weight)
 
                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
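The `start += shard_size` context line is the tail of the per-expert loop that folds the separate w1/w3 shard scales into one shared scale (`max_w13_scales`) before shuffling. A simplified sketch of that surrounding logic, with tensor shapes and the float32 round-trip assumed for illustration (the real code quantizes back to FP8):

```python
import torch

def fold_w13_scales_sketch(
    w13_weight: torch.Tensor,        # (E, 2 * shard_size, K), quantized values
    w13_weight_scale: torch.Tensor,  # (E, 2): one scale per w1/w3 shard
) -> tuple[torch.Tensor, torch.Tensor]:
    num_experts, total_rows, _ = w13_weight.shape
    shard_size = total_rows // 2
    # One shared scale per expert: the max of its two shard scales.
    max_w13_scales = w13_weight_scale.max(dim=1).values
    requant = torch.empty_like(w13_weight)
    for expert in range(num_experts):
        start = 0
        for shard in range(2):
            rows = w13_weight[expert, start:start + shard_size]
            # Dequantize with the shard's own scale, then requantize with
            # the shared scale so both shards agree on a single value.
            dq = rows.float() * w13_weight_scale[expert, shard]
            requant[expert, start:start + shard_size] = (
                dq / max_w13_scales[expert]).to(w13_weight.dtype)
            start += shard_size
    return requant, max_w13_scales
```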