[Bugfix][CPU] Fix llama4 inference on CPU (#34321)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@@ -119,8 +119,8 @@ void cpu_fused_moe(torch::Tensor& output, const torch::Tensor& input,
|
||||
const std::optional<torch::Tensor>& w13_bias,
|
||||
const std::optional<torch::Tensor>& w2_bias,
|
||||
const torch::Tensor& topk_weights,
|
||||
const torch::Tensor& topk_id, const std::string& act,
|
||||
const std::string& isa);
|
||||
const torch::Tensor& topk_id, const bool skip_weighted,
|
||||
const std::string& act, const std::string& isa);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
@@ -320,6 +320,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def(
|
||||
"cpu_fused_moe(Tensor(a0!) output, Tensor input, Tensor w13, Tensor w2, "
|
||||
"Tensor? w13_bias, Tensor? w2_bias, Tensor topk_weights, Tensor topk_id, "
|
||||
"bool skip_weighted, "
|
||||
"str act, str isa) -> ()");
|
||||
ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe);
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user