[CPU] Support gelu act in cpu_fused_moe (#38770)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
Li, Jiang
2026-04-02 14:14:32 +08:00
committed by GitHub
parent 9bd7231106
commit c6f722b93e
5 changed files with 59 additions and 3 deletions

View File

@@ -30,13 +30,15 @@
}()
namespace {
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul };
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul, GeluAndMul };
FusedMOEAct get_act_type(const std::string& act) {
if (act == "silu") {
return FusedMOEAct::SiluAndMul;
} else if (act == "swigluoai") {
return FusedMOEAct::SwigluOAIAndMul;
} else if (act == "gelu") {
return FusedMOEAct::GeluAndMul;
} else {
TORCH_CHECK(false, "Invalid act type: " + act);
}
@@ -104,6 +106,43 @@ void silu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
}
}
template <typename scalar_t>
void gelu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
const int32_t m_size, const int32_t n_size,
const int32_t input_stride, const int32_t output_stride) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
const int32_t dim = n_size / 2;
float* __restrict__ gate = input;
float* __restrict__ up = input + dim;
vec_op::FP32Vec16 one_vec(1.0);
vec_op::FP32Vec16 w1_vec(M_SQRT1_2);
vec_op::FP32Vec16 w2_vec(0.5);
alignas(64) float temp[16];
DEFINE_FAST_EXP
for (int32_t m = 0; m < m_size; ++m) {
for (int32_t n = 0; n < dim; n += 16) {
vec_op::FP32Vec16 gate_vec(gate + n);
vec_op::FP32Vec16 up_vec(up + n);
auto er_input_vec = gate_vec * w1_vec;
er_input_vec.save(temp);
for (int32_t i = 0; i < 16; ++i) {
temp[i] = std::erf(temp[i]);
}
vec_op::FP32Vec16 er_vec(temp);
auto gelu = gate_vec * w2_vec * (one_vec + er_vec);
auto gated_output_fp32 = up_vec * gelu;
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
gated_output.save(output + n);
}
gate += input_stride;
up += input_stride;
output += output_stride;
}
}
template <typename scalar_t>
FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
float* __restrict__ input,
@@ -118,6 +157,9 @@ FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
case FusedMOEAct::SiluAndMul:
silu_and_mul(input, output, m, n, input_stride, output_stride);
return;
case FusedMOEAct::GeluAndMul:
gelu_and_mul(input, output, m, n, input_stride, output_stride);
return;
default:
TORCH_CHECK(false, "Unsupported act type.");
}

View File

@@ -20,7 +20,7 @@ EXPERT_NUM = [
HIDDEN_DIM = [128, 2880]
INTERMEDIATE_DIM = [128, 2880]
BATCH_SIZE = [1, 64, 256]
ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI]
ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI, MoEActivation.GELU]
USE_BIAS = [True, False]
ISA = ["amx", "vec"] if torch.cpu._is_amx_tile_supported() else ["vec"]
DTYPE = [torch.bfloat16]

View File

@@ -51,6 +51,7 @@ if TYPE_CHECKING:
VLLM_CPU_OMP_THREADS_BIND: str = "auto"
VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
VLLM_CPU_SGL_KERNEL: bool = False
VLLM_CPU_ATTN_SPLIT_KV: bool = True
VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True
VLLM_CPU_INT4_W4A8: bool = True
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
@@ -726,6 +727,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
else None,
# (CPU backend only) whether to use SGL kernels, optimized for small batch.
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
# (CPU backend only) whether to enable attention split KV.
"VLLM_CPU_ATTN_SPLIT_KV": lambda: bool(
int(os.getenv("VLLM_CPU_ATTN_SPLIT_KV", "1"))
),
# (Zen CPU backend) eagerly prepack weights into ZenDNN blocked layout
# at model load time. Eliminates per-inference layout conversion overhead.
"VLLM_ZENTORCH_WEIGHT_PREPACK": lambda: bool(

View File

@@ -34,12 +34,20 @@ def _swigluoai_forward_native(
return gated_output
def _gelu_and_mul(
x: torch.Tensor,
) -> torch.Tensor:
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate="none") * x[..., d:]
# Map activation names to their native forward functions.
# Uses static methods or standalone functions to avoid instantiating CustomOp
# classes, which would call get_current_vllm_config() before config is set.
# Dispatch table used by the CPU fused-MoE path; each callable takes the
# concatenated [gate | up] tensor (last dim) and returns the gated activation.
_CPU_MOE_ACT_FN: dict[MoEActivation, Callable[[torch.Tensor], torch.Tensor]] = {
    MoEActivation.SILU: SiluAndMul.forward_native,
    MoEActivation.SWIGLUOAI: _swigluoai_forward_native,
    # Exact (erf-based) GELU on the gate half, multiplied by the up half.
    MoEActivation.GELU: _gelu_and_mul,
}

View File

@@ -6,6 +6,7 @@ from typing import ClassVar
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
@@ -181,7 +182,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
causal=causal,
sliding_window_size=self.window_size,
isa=self.isa,
enable_kv_split=True,
enable_kv_split=envs.VLLM_CPU_ATTN_SPLIT_KV,
)
attn_metadata = CPUAttentionMetadata(