[Bugfix] Fix moe weight losing all extra attrs after process_weights_after_loading. (#16854)
Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
@@ -113,12 +113,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
||||
layer.w13_weight.data),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
||||
layer.w2_weight.data),
|
||||
requires_grad=False)
|
||||
# Padding the weight for better performance on ROCm
|
||||
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
|
||||
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||
# Lazy import to avoid importing triton.
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
@@ -127,10 +124,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
layer.w13_weight.data, layer.w2_weight.data)
|
||||
|
||||
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
layer.w13_weight.data = shuffled_w13
|
||||
layer.w2_weight.data = shuffled_w2
|
||||
|
||||
if current_platform.is_cpu():
|
||||
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
|
||||
|
||||
Reference in New Issue
Block a user