From c6f722b93e8e795065751172812ee6a5540e5901 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Thu, 2 Apr 2026 14:14:32 +0800 Subject: [PATCH] [CPU] Support gelu act in cpu_fused_moe (#38770) Signed-off-by: jiang1.li --- csrc/cpu/cpu_fused_moe.cpp | 44 ++++++++++++++++++- tests/kernels/moe/test_cpu_fused_moe.py | 2 +- vllm/envs.py | 5 +++ .../layers/fused_moe/cpu_fused_moe.py | 8 ++++ vllm/v1/attention/backends/cpu_attn.py | 3 +- 5 files changed, 59 insertions(+), 3 deletions(-) diff --git a/csrc/cpu/cpu_fused_moe.cpp b/csrc/cpu/cpu_fused_moe.cpp index 1a8264539..0dc5060fe 100644 --- a/csrc/cpu/cpu_fused_moe.cpp +++ b/csrc/cpu/cpu_fused_moe.cpp @@ -30,13 +30,15 @@ }() namespace { -enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul }; +enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul, GeluAndMul }; FusedMOEAct get_act_type(const std::string& act) { if (act == "silu") { return FusedMOEAct::SiluAndMul; } else if (act == "swigluoai") { return FusedMOEAct::SwigluOAIAndMul; + } else if (act == "gelu") { + return FusedMOEAct::GeluAndMul; } else { TORCH_CHECK(false, "Invalid act type: " + act); } @@ -104,6 +106,43 @@ void silu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output, } } +template +void gelu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output, + const int32_t m_size, const int32_t n_size, + const int32_t input_stride, const int32_t output_stride) { + using scalar_vec_t = typename cpu_utils::VecTypeTrait::vec_t; + const int32_t dim = n_size / 2; + float* __restrict__ gate = input; + float* __restrict__ up = input + dim; + vec_op::FP32Vec16 one_vec(1.0); + vec_op::FP32Vec16 w1_vec(M_SQRT1_2); + vec_op::FP32Vec16 w2_vec(0.5); + alignas(64) float temp[16]; + + DEFINE_FAST_EXP + + for (int32_t m = 0; m < m_size; ++m) { + for (int32_t n = 0; n < dim; n += 16) { + vec_op::FP32Vec16 gate_vec(gate + n); + vec_op::FP32Vec16 up_vec(up + n); + auto er_input_vec = gate_vec * w1_vec; + + er_input_vec.save(temp); + for (int32_t i = 0; i < 16; 
++i) { + temp[i] = std::erf(temp[i]); + } + vec_op::FP32Vec16 er_vec(temp); + auto gelu = gate_vec * w2_vec * (one_vec + er_vec); + auto gated_output_fp32 = up_vec * gelu; + scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32); + gated_output.save(output + n); + } + gate += input_stride; + up += input_stride; + output += output_stride; + } +} + template FORCE_INLINE void apply_gated_act(const FusedMOEAct act, float* __restrict__ input, @@ -118,6 +157,9 @@ FORCE_INLINE void apply_gated_act(const FusedMOEAct act, case FusedMOEAct::SiluAndMul: silu_and_mul(input, output, m, n, input_stride, output_stride); return; + case FusedMOEAct::GeluAndMul: + gelu_and_mul(input, output, m, n, input_stride, output_stride); + return; default: TORCH_CHECK(false, "Unsupported act type."); } diff --git a/tests/kernels/moe/test_cpu_fused_moe.py b/tests/kernels/moe/test_cpu_fused_moe.py index 467ba3c5f..73859175c 100644 --- a/tests/kernels/moe/test_cpu_fused_moe.py +++ b/tests/kernels/moe/test_cpu_fused_moe.py @@ -20,7 +20,7 @@ EXPERT_NUM = [ HIDDEN_DIM = [128, 2880] INTERMEDIATE_DIM = [128, 2880] BATCH_SIZE = [1, 64, 256] -ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI] +ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI, MoEActivation.GELU] USE_BIAS = [True, False] ISA = ["amx", "vec"] if torch.cpu._is_amx_tile_supported() else ["vec"] DTYPE = [torch.bfloat16] diff --git a/vllm/envs.py b/vllm/envs.py index 0d68b0f97..0a40030cf 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -51,6 +51,7 @@ if TYPE_CHECKING: VLLM_CPU_OMP_THREADS_BIND: str = "auto" VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None VLLM_CPU_SGL_KERNEL: bool = False + VLLM_CPU_ATTN_SPLIT_KV: bool = True VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True VLLM_CPU_INT4_W4A8: bool = True VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") @@ -726,6 +727,10 @@ environment_variables: dict[str, Callable[[], Any]] = { else None, # (CPU backend only) whether to use SGL kernels, optimized for small batch. 
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), + # (CPU backend only) whether to enable attention split KV. + "VLLM_CPU_ATTN_SPLIT_KV": lambda: bool( + int(os.getenv("VLLM_CPU_ATTN_SPLIT_KV", "1")) + ), # (Zen CPU backend) eagerly prepack weights into ZenDNN blocked layout # at model load time. Eliminates per-inference layout conversion overhead. "VLLM_ZENTORCH_WEIGHT_PREPACK": lambda: bool( diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py index 72e9db514..e1bedd6f4 100644 --- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -34,12 +34,20 @@ def _swigluoai_forward_native( return gated_output +def _gelu_and_mul( + x: torch.Tensor, +) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.gelu(x[..., :d], approximate="none") * x[..., d:] + + # Map activation names to their native forward functions. # Uses static methods or standalone functions to avoid instantiating CustomOp # classes, which would call get_current_vllm_config() before config is set. 
_CPU_MOE_ACT_FN: dict[MoEActivation, Callable[[torch.Tensor], torch.Tensor]] = { MoEActivation.SILU: SiluAndMul.forward_native, MoEActivation.SWIGLUOAI: _swigluoai_forward_native, + MoEActivation.GELU: _gelu_and_mul, } diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 90151a251..5216301ef 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -6,6 +6,7 @@ from typing import ClassVar import torch from vllm import _custom_ops as ops +from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform @@ -181,7 +182,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata] causal=causal, sliding_window_size=self.window_size, isa=self.isa, - enable_kv_split=True, + enable_kv_split=envs.VLLM_CPU_ATTN_SPLIT_KV, ) attn_metadata = CPUAttentionMetadata(