diff --git a/tests/kernels/moe/test_cpu_fused_moe.py b/tests/kernels/moe/test_cpu_fused_moe.py
index c0f817a9c..681f42091 100644
--- a/tests/kernels/moe/test_cpu_fused_moe.py
+++ b/tests/kernels/moe/test_cpu_fused_moe.py
@@ -6,7 +6,7 @@ import torch
 
 from tests.kernels.allclose_default import get_default_atol, get_default_rtol
 from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
-from vllm.model_executor.layers.fused_moe.cpu_fused_moe import _CPU_MOE_ACT
+from vllm.model_executor.layers.fused_moe.cpu_fused_moe import _CPU_MOE_ACT_FN
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import set_random_seed
 
@@ -68,12 +68,7 @@ def ref_fused_moe(
             tokens_for_this_expert, curr_w13, curr_w13_bias
         )
         # Note: to simulate the kernel implementation
-        gate_up = (
-            _CPU_MOE_ACT[activation]
-            .forward_native(gate_up)
-            .to(dtype=input.dtype)
-            .float()
-        )
+        gate_up = _CPU_MOE_ACT_FN[activation](gate_up).to(dtype=input.dtype).float()
         expert_out = torch.nn.functional.linear(gate_up, curr_w2, curr_w2_bias)
         outputs.append(expert_out)
 
diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
index 7055e41aa..ee4798d84 100644
--- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -8,33 +8,38 @@ from torch.nn import functional as F
 from vllm import _custom_ops as ops
 from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
-from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.quantization.utils.layer_utils import replace_parameter
 from vllm.utils.torch_utils import direct_register_custom_op
 
 
 _CPU_MOE_LAYER_CACHE = {}
 
 
-class _LazyActivationDict(dict):
-    """Lazily instantiate activation functions on first access.
+def _swigluoai_forward_native(
+    x: torch.Tensor,
+    alpha: float = 1.702,
+    limit: float = 7.0,
+) -> torch.Tensor:
+    """PyTorch-native implementation of SwigluOAIAndMul.forward_native.
 
-    Avoids triggering CustomOp.__init__() at module import time,
-    which would call get_current_vllm_config() before config is set.
+    Standalone function to avoid instantiating SwigluOAIAndMul (a CustomOp)
+    which would trigger get_current_vllm_config() before config is set.
     """
-
-    _factories: dict[str, type[SiluAndMul] | type[SwigluOAIAndMul]] = {
-        "silu": SiluAndMul,
-        "swigluoai": SwigluOAIAndMul,
-    }
-
-    def __missing__(self, key: str) -> SiluAndMul | SwigluOAIAndMul:
-        if key not in self._factories:
-            raise KeyError(f"{key} is not a supported activation")
-        self[key] = self._factories[key]()
-        return self[key]
+    gate, up = x[..., ::2], x[..., 1::2]
+    gate = gate.clamp(min=None, max=limit)
+    up = up.clamp(min=-limit, max=limit)
+    glu = gate * torch.sigmoid(gate * alpha)
+    gated_output = (up + 1) * glu
+    return gated_output
 
 
-_CPU_MOE_ACT = _LazyActivationDict()
+# Map activation names to their native forward functions.
+# Uses static methods or standalone functions to avoid instantiating CustomOp
+# classes, which would call get_current_vllm_config() before config is set.
+_CPU_MOE_ACT_FN: dict[str, Callable[[torch.Tensor], torch.Tensor]] = {
+    "silu": SiluAndMul.forward_native,
+    "swigluoai": _swigluoai_forward_native,
+}
 
 
 def grouped_topk(
@@ -230,7 +235,7 @@ class CPUFusedMOE:
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
-        assert activation in _CPU_MOE_ACT._factories, f"{activation} is not supported."
+        assert activation in _CPU_MOE_ACT_FN, f"{activation} is not supported."
        assert not apply_router_weight_on_input
 
         topk_weights, topk_ids = select_experts(
@@ -418,7 +423,7 @@ def cpu_fused_moe_torch(
         tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
 
         gate_up = layer.gate_up_linear[i](tokens_for_this_expert)  # type: ignore
-        gate_up = _CPU_MOE_ACT[activation].forward_native(gate_up)
+        gate_up = _CPU_MOE_ACT_FN[activation](gate_up)
        expert_out = layer.down_linear[i](gate_up)  # type: ignore
         outputs.append(expert_out)
         start_idx = end_idx
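For context on the replacement activation path, here is a minimal standalone sketch (plain PyTorch, not part of this diff; the function names below are illustrative only). It re-derives the interleaved SwiGLU-OAI formula that _swigluoai_forward_native implements above, with gate taken from even columns and up from odd columns, both clamped at limit, and cross-checks it against a column-by-column reference.

# Minimal sanity-check sketch; assumes only PyTorch is installed.
# Formula mirrored from the diff: gate/up are interleaved along the last dim,
# gate is clamped from above, up is clamped to [-limit, limit], and the output
# is (up + 1) * gate * sigmoid(alpha * gate).
import torch


def swigluoai_interleaved(
    x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0
) -> torch.Tensor:
    gate, up = x[..., ::2], x[..., 1::2]
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    return (up + 1) * gate * torch.sigmoid(gate * alpha)


def swigluoai_reference(
    x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0
) -> torch.Tensor:
    # Column-by-column re-derivation of the same formula, used only as a check.
    half = x.shape[-1] // 2
    out = torch.empty(*x.shape[:-1], half, dtype=x.dtype)
    for j in range(half):
        g = x[..., 2 * j].clamp(max=limit)
        u = x[..., 2 * j + 1].clamp(min=-limit, max=limit)
        out[..., j] = (u + 1) * g * torch.sigmoid(g * alpha)
    return out


if __name__ == "__main__":
    x = torch.randn(4, 16)
    torch.testing.assert_close(swigluoai_interleaved(x), swigluoai_reference(x))
    print("interleaved SwiGLU-OAI matches the element-wise reference")

The lookup in the diff, _CPU_MOE_ACT_FN[activation](gate_up), resolves to a plain callable of this kind, so no CustomOp instance (and therefore no vLLM config) is needed at import time.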