[CPU] Refactor CPU fused MOE (#30531)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
Li, Jiang
2025-12-18 14:36:49 +08:00
committed by GitHub
parent fc2ae6d617
commit e3ab93c896
23 changed files with 1388 additions and 200 deletions

View File

@@ -110,6 +110,17 @@ void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
const std::optional<torch::Tensor>& bias,
const int64_t pack_factor, const std::string& isa_hint);
void prepack_moe_weight(const torch::Tensor& weight,
torch::Tensor& packed_weight, const std::string& isa);
void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
const torch::Tensor& w13, const torch::Tensor& w2,
const std::optional<torch::Tensor>& w13_bias,
const std::optional<torch::Tensor>& w2_bias,
const torch::Tensor& topk_weights,
const torch::Tensor& topk_id, const std::string& act,
const std::string& isa);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
@@ -296,6 +307,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"pack_factor, str isa_hint) -> ()");
ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
#endif
// fused moe
#if defined(__AVX512F__)
ops.def(
"prepack_moe_weight(Tensor weight, Tensor(a1!) packed_weight, str isa) "
"-> ()");
ops.impl("prepack_moe_weight", torch::kCPU, &prepack_moe_weight);
ops.def(
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
"str act, str isa) -> ()");
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
#endif
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {