[CPU] Refactor CPU fused MOE (#30531)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@@ -110,6 +110,17 @@ void cpu_gemm_wna16(const torch::Tensor& input, const torch::Tensor& q_weight,
|
||||
const std::optional<torch::Tensor>& bias,
|
||||
const int64_t pack_factor, const std::string& isa_hint);
|
||||
|
||||
// Forward declaration only; the definition is outside this view.
// `weight` is taken by const reference and `packed_weight` by mutable
// reference. The function returns void, so any result is presumably written
// into `packed_weight` -- confirm against the definition.
// NOTE(review): the accepted values of `isa` are not visible here.
void prepack_moe_weight(const torch::Tensor& weight,
|
||||
torch::Tensor& packed_weight, const std::string& isa);
|
||||
|
||||
// Forward declaration only; the definition is outside this view.
// `output` is the only tensor taken by mutable reference; every other tensor
// argument is const. `w13_bias` and `w2_bias` are std::optional, so callers
// may pass std::nullopt for either.
// `act` and `isa` are strings; their accepted values are not visible here.
// NOTE(review): presumably `topk_weights` / `topk_id` carry the per-token
// routing weights and selected expert indices -- confirm shapes and dtypes
// against the definition.
void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
|
||||
const torch::Tensor& w13, const torch::Tensor& w2,
|
||||
const std::optional<torch::Tensor>& w13_bias,
|
||||
const std::optional<torch::Tensor>& w2_bias,
|
||||
const torch::Tensor& topk_weights,
|
||||
const torch::Tensor& topk_id, const std::string& act,
|
||||
const std::string& isa);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@@ -296,6 +307,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"pack_factor, str isa_hint) -> ()");
|
||||
ops.impl("cpu_gemm_wna16", torch::kCPU, &cpu_gemm_wna16);
|
||||
#endif
|
||||
|
||||
// fused moe
// The two ops in this section are defined and bound only when the file is
// compiled with __AVX512F__ defined; without it this section registers
// nothing.
|
||||
#if defined(__AVX512F__)
|
||||
// Schema for prepack_moe_weight: the `Tensor(a1!)` annotation marks
// `packed_weight` as mutated in place, and `-> ()` means no return value.
ops.def(
|
||||
"prepack_moe_weight(Tensor weight, Tensor(a1!) packed_weight, str isa) "
|
||||
"-> ()");
|
||||
// Bind the CPU dispatch key to the C++ function of the same name.
ops.impl("prepack_moe_weight", torch::kCPU, &prepack_moe_weight);
|
||||
// Schema for cpu_fused_moe: `Tensor(a0!)` marks `output` as mutated in
// place; `Tensor?` marks w13_bias / w2_bias as optional, matching the
// std::optional parameters of the C++ declaration. No return value.
ops.def(
|
||||
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
|
||||
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
|
||||
"str act, str isa) -> ()");
|
||||
// Bind the CPU dispatch key to the C++ function of the same name.
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
|
||||
Reference in New Issue
Block a user