[Bugfix][CPU] Fix llama4 inference on CPU (#34321)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
Li, Jiang
2026-02-11 19:07:23 +08:00
committed by GitHub
parent 40b8f55358
commit 05339a7b20
6 changed files with 60 additions and 18 deletions

View File

@@ -3078,6 +3078,7 @@ def cpu_fused_moe(
topk_ids: torch.Tensor,
act: str,
isa: str,
skip_weighted: bool = False,
) -> torch.Tensor:
output = torch.empty_like(input)
torch.ops._C.cpu_fused_moe(
@@ -3089,6 +3090,7 @@ def cpu_fused_moe(
w2_bias,
topk_weights,
topk_ids,
skip_weighted,
act,
isa,
)

View File

@@ -238,7 +238,6 @@ class CPUFusedMOE:
activation: str = "silu",
) -> torch.Tensor:
assert activation in _CPU_MOE_ACT_FN, f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(
hidden_states=x,
@@ -261,6 +260,7 @@ class CPUFusedMOE:
topk_ids,
activation,
global_num_experts,
apply_router_weight_on_input,
)
def check_grouped_gemm(
@@ -355,7 +355,14 @@ class CPUFusedMOE:
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int = -1,
skip_weighted: bool = False,
) -> torch.Tensor:
if skip_weighted:
assert topk_ids.size(1) == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
input.mul_(topk_weights.to(input.dtype))
output = cpu_fused_moe(
input,
layer.w13_weight,
@@ -366,6 +373,7 @@ class CPUFusedMOE:
topk_ids,
activation,
self.isa,
skip_weighted,
)
return output
@@ -377,7 +385,14 @@ class CPUFusedMOE:
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int = -1,
skip_weighted: bool = False,
) -> torch.Tensor:
if skip_weighted:
assert topk_ids.size(1) == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
input.mul_(topk_weights.to(input.dtype))
output = torch.empty_like(input)
layer_id = id(layer)
torch.ops.vllm.cpu_fused_moe_torch(
@@ -388,6 +403,7 @@ class CPUFusedMOE:
topk_ids,
activation,
global_num_experts,
skip_weighted,
)
return output
@@ -401,6 +417,7 @@ def cpu_fused_moe_torch(
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int = -1,
skip_weighted: bool = False,
) -> None:
layer = _CPU_MOE_LAYER_CACHE[layer_id]()
@@ -434,13 +451,16 @@ def cpu_fused_moe_torch(
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weights.dtype)
.mul_(topk_weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
if skip_weighted:
final_out = new_x
else:
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weights.dtype)
.mul_(topk_weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
output.copy_(final_out)

View File

@@ -160,12 +160,21 @@ class CPUWorker(Worker):
x for x in logical_cpu_list if x.numa_node == selected_numa_node
]
else:
assert len(logical_cpu_list) >= self.parallel_config.world_size
logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.numa_node)
sim_cpu_num_per_node = (
len(logical_cpu_list) // self.parallel_config.world_size
# This is a bit tricky because the internal DP size
# is always 1 for non-MoE models
world_size_across_dp = (
self.parallel_config.world_size
* self.parallel_config._api_process_count
)
start_idx = self.local_rank * sim_cpu_num_per_node
assert len(logical_cpu_list) >= world_size_across_dp
logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.numa_node)
sim_cpu_num_per_node = len(logical_cpu_list) // world_size_across_dp
assert self.parallel_config.data_parallel_rank_local is not None
start_idx = (
self.local_rank
+ self.parallel_config.world_size
* self.parallel_config.data_parallel_rank_local
) * sim_cpu_num_per_node
logical_cpu_list = logical_cpu_list[
start_idx : (start_idx + sim_cpu_num_per_node)
]