[Bugfix][CPU] Fix llama4 inference on CPU (#34321)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Author: Li, Jiang
Date: 2026-02-11 19:07:23 +08:00
Committed by: GitHub
Parent: 40b8f55358
Commit: 05339a7b20
6 changed files with 60 additions and 18 deletions


@@ -119,8 +119,8 @@ void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
                    const std::optional<torch::Tensor>& w13_bias,
                    const std::optional<torch::Tensor>& w2_bias,
                    const torch::Tensor& topk_weights,
-                   const torch::Tensor& topk_id, const std::string& act,
-                   const std::string& isa);
+                   const torch::Tensor& topk_id, const bool skip_weighted,
+                   const std::string& act, const std::string& isa);

 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
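
The sketch below illustrates what a skip_weighted switch typically controls in a fused-MoE combine step: when the routing weights have already been applied to the expert inputs (as Llama 4-style models do), the per-expert outputs are summed as-is instead of being scaled by topk_weights again. This is a simplified standalone sketch under that assumption, not the actual vLLM kernel; the function name and layout are illustrative.

// Illustrative sketch only, NOT the vLLM kernel: a minimal combine step for a
// fused-MoE output, parameterized by a skip_weighted flag.
#include <cstdint>

void combine_expert_outputs(float* output,             // [hidden]
                            const float* expert_out,   // [topk, hidden]
                            const float* topk_weights, // [topk]
                            int64_t topk, int64_t hidden,
                            bool skip_weighted) {
  for (int64_t h = 0; h < hidden; ++h) output[h] = 0.0f;
  for (int64_t k = 0; k < topk; ++k) {
    // With skip_weighted set, the routing weight is assumed to be baked into
    // expert_out already, so scale by 1 instead of topk_weights[k].
    const float w = skip_weighted ? 1.0f : topk_weights[k];
    for (int64_t h = 0; h < hidden; ++h) {
      output[h] += w * expert_out[k * hidden + h];
    }
  }
}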
@@ -320,6 +320,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def(
       "cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
       "Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
+      "bool skip_weighted, "
       "str act, str isa) -> ()");
   ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
 #endif
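
For readability, the registration reads as follows once the hunk above is applied (reconstructed directly from the diff; the surrounding lines are assumed unchanged). Because the schema gains a positional bool skip_weighted argument, any caller of the op, presumably updated in the remaining files of this commit, now has to pass it explicitly.

  // Net result of the hunk above: the cpu_fused_moe schema with the new
  // skip_weighted argument inserted between topk_id and act.
  ops.def(
      "cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
      "Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
      "bool skip_weighted, "
      "str act, str isa) -> ()");
  ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);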