[CPU] Support gelu act in cpu_fused_moe (#38770)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@@ -30,13 +30,15 @@
|
||||
}()
|
||||
|
||||
namespace {
|
||||
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul };
|
||||
enum class FusedMOEAct { SiluAndMul, SwigluOAIAndMul, GeluAndMul };
|
||||
|
||||
// Translate the activation name passed down from Python into the internal
// FusedMOEAct enum. Aborts via TORCH_CHECK on any unrecognized name.
FusedMOEAct get_act_type(const std::string& act) {
  if (act == "silu") {
    return FusedMOEAct::SiluAndMul;
  }
  if (act == "swigluoai") {
    return FusedMOEAct::SwigluOAIAndMul;
  }
  if (act == "gelu") {
    return FusedMOEAct::GeluAndMul;
  }
  TORCH_CHECK(false, "Invalid act type: " + act);
}
|
||||
@@ -104,6 +106,43 @@ void silu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void gelu_and_mul(float* __restrict__ input, scalar_t* __restrict__ output,
|
||||
const int32_t m_size, const int32_t n_size,
|
||||
const int32_t input_stride, const int32_t output_stride) {
|
||||
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
|
||||
const int32_t dim = n_size / 2;
|
||||
float* __restrict__ gate = input;
|
||||
float* __restrict__ up = input + dim;
|
||||
vec_op::FP32Vec16 one_vec(1.0);
|
||||
vec_op::FP32Vec16 w1_vec(M_SQRT1_2);
|
||||
vec_op::FP32Vec16 w2_vec(0.5);
|
||||
alignas(64) float temp[16];
|
||||
|
||||
DEFINE_FAST_EXP
|
||||
|
||||
for (int32_t m = 0; m < m_size; ++m) {
|
||||
for (int32_t n = 0; n < dim; n += 16) {
|
||||
vec_op::FP32Vec16 gate_vec(gate + n);
|
||||
vec_op::FP32Vec16 up_vec(up + n);
|
||||
auto er_input_vec = gate_vec * w1_vec;
|
||||
|
||||
er_input_vec.save(temp);
|
||||
for (int32_t i = 0; i < 16; ++i) {
|
||||
temp[i] = std::erf(temp[i]);
|
||||
}
|
||||
vec_op::FP32Vec16 er_vec(temp);
|
||||
auto gelu = gate_vec * w2_vec * (one_vec + er_vec);
|
||||
auto gated_output_fp32 = up_vec * gelu;
|
||||
scalar_vec_t gated_output = scalar_vec_t(gated_output_fp32);
|
||||
gated_output.save(output + n);
|
||||
}
|
||||
gate += input_stride;
|
||||
up += input_stride;
|
||||
output += output_stride;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
|
||||
float* __restrict__ input,
|
||||
@@ -118,6 +157,9 @@ FORCE_INLINE void apply_gated_act(const FusedMOEAct act,
|
||||
case FusedMOEAct::SiluAndMul:
|
||||
silu_and_mul(input, output, m, n, input_stride, output_stride);
|
||||
return;
|
||||
case FusedMOEAct::GeluAndMul:
|
||||
gelu_and_mul(input, output, m, n, input_stride, output_stride);
|
||||
return;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported act type.");
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ EXPERT_NUM = [
|
||||
HIDDEN_DIM = [128, 2880]
|
||||
INTERMEDIATE_DIM = [128, 2880]
|
||||
BATCH_SIZE = [1, 64, 256]
|
||||
ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI]
|
||||
ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI, MoEActivation.GELU]
|
||||
USE_BIAS = [True, False]
|
||||
ISA = ["amx", "vec"] if torch.cpu._is_amx_tile_supported() else ["vec"]
|
||||
DTYPE = [torch.bfloat16]
|
||||
|
||||
@@ -51,6 +51,7 @@ if TYPE_CHECKING:
|
||||
VLLM_CPU_OMP_THREADS_BIND: str = "auto"
|
||||
VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
|
||||
VLLM_CPU_SGL_KERNEL: bool = False
|
||||
VLLM_CPU_ATTN_SPLIT_KV: bool = True
|
||||
VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True
|
||||
VLLM_CPU_INT4_W4A8: bool = True
|
||||
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
|
||||
@@ -726,6 +727,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
else None,
|
||||
# (CPU backend only) whether to use SGL kernels, optimized for small batch.
|
||||
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
|
||||
# (CPU backend only) whether to enable attention spilt KV.
|
||||
"VLLM_CPU_ATTN_SPLIT_KV": lambda: bool(
|
||||
int(os.getenv("VLLM_CPU_ATTN_SPLIT_KV", "1"))
|
||||
),
|
||||
# (Zen CPU backend) eagerly prepack weights into ZenDNN blocked layout
|
||||
# at model load time. Eliminates per-inference layout conversion overhead.
|
||||
"VLLM_ZENTORCH_WEIGHT_PREPACK": lambda: bool(
|
||||
|
||||
@@ -34,12 +34,20 @@ def _swigluoai_forward_native(
|
||||
return gated_output
|
||||
|
||||
|
||||
def _gelu_and_mul(
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
return F.gelu(x[..., :d], approximate="none") * x[..., d:]
|
||||
|
||||
|
||||
# Map activation names to their native forward functions.
|
||||
# Uses static methods or standalone functions to avoid instantiating CustomOp
|
||||
# classes, which would call get_current_vllm_config() before config is set.
|
||||
_CPU_MOE_ACT_FN: dict[MoEActivation, Callable[[torch.Tensor], torch.Tensor]] = {
|
||||
MoEActivation.SILU: SiluAndMul.forward_native,
|
||||
MoEActivation.SWIGLUOAI: _swigluoai_forward_native,
|
||||
MoEActivation.GELU: _gelu_and_mul,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import ClassVar
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
@@ -181,7 +182,7 @@ class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]
|
||||
causal=causal,
|
||||
sliding_window_size=self.window_size,
|
||||
isa=self.isa,
|
||||
enable_kv_split=True,
|
||||
enable_kv_split=envs.VLLM_CPU_ATTN_SPLIT_KV,
|
||||
)
|
||||
|
||||
attn_metadata = CPUAttentionMetadata(
|
||||
|
||||
Reference in New Issue
Block a user