[Bugfix][CPU] Fix llama4 inference on CPU (#34321)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
@@ -3078,6 +3078,7 @@ def cpu_fused_moe(
|
||||
topk_ids: torch.Tensor,
|
||||
act: str,
|
||||
isa: str,
|
||||
skip_weighted: bool = False,
|
||||
) -> torch.Tensor:
|
||||
output = torch.empty_like(input)
|
||||
torch.ops._C.cpu_fused_moe(
|
||||
@@ -3089,6 +3090,7 @@ def cpu_fused_moe(
|
||||
w2_bias,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
skip_weighted,
|
||||
act,
|
||||
isa,
|
||||
)
|
||||
|
||||
@@ -238,7 +238,6 @@ class CPUFusedMOE:
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
assert activation in _CPU_MOE_ACT_FN, f"{activation} is not supported."
|
||||
-        assert not apply_router_weight_on_input
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
@@ -261,6 +260,7 @@ class CPUFusedMOE:
|
||||
topk_ids,
|
||||
activation,
|
||||
global_num_experts,
|
||||
apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
def check_grouped_gemm(
|
||||
@@ -355,7 +355,14 @@ class CPUFusedMOE:
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int = -1,
|
||||
skip_weighted: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if skip_weighted:
|
||||
assert topk_ids.size(1) == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
input.mul_(topk_weights.to(input.dtype))
|
||||
|
||||
output = cpu_fused_moe(
|
||||
input,
|
||||
layer.w13_weight,
|
||||
@@ -366,6 +373,7 @@ class CPUFusedMOE:
|
||||
topk_ids,
|
||||
activation,
|
||||
self.isa,
|
||||
skip_weighted,
|
||||
)
|
||||
return output
|
||||
|
||||
@@ -377,7 +385,14 @@ class CPUFusedMOE:
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int = -1,
|
||||
skip_weighted: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if skip_weighted:
|
||||
assert topk_ids.size(1) == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
input.mul_(topk_weights.to(input.dtype))
|
||||
|
||||
output = torch.empty_like(input)
|
||||
layer_id = id(layer)
|
||||
torch.ops.vllm.cpu_fused_moe_torch(
|
||||
@@ -388,6 +403,7 @@ class CPUFusedMOE:
|
||||
topk_ids,
|
||||
activation,
|
||||
global_num_experts,
|
||||
skip_weighted,
|
||||
)
|
||||
|
||||
return output
|
||||
@@ -401,6 +417,7 @@ def cpu_fused_moe_torch(
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int = -1,
|
||||
skip_weighted: bool = False,
|
||||
) -> None:
|
||||
layer = _CPU_MOE_LAYER_CACHE[layer_id]()
|
||||
|
||||
@@ -434,13 +451,16 @@ def cpu_fused_moe_torch(
|
||||
new_x = torch.empty_like(outs)
|
||||
|
||||
new_x[idxs] = outs
|
||||
-    final_out = (
-        new_x.view(*topk_ids.shape, -1)
-        .type(topk_weights.dtype)
-        .mul_(topk_weights.unsqueeze(dim=-1))
-        .sum(dim=1)
-        .type(new_x.dtype)
-    )
+    if skip_weighted:
+        final_out = new_x
+    else:
+        final_out = (
+            new_x.view(*topk_ids.shape, -1)
+            .type(topk_weights.dtype)
+            .mul_(topk_weights.unsqueeze(dim=-1))
+            .sum(dim=1)
+            .type(new_x.dtype)
+        )
|
||||
output.copy_(final_out)
|
||||
|
||||
|
||||
|
||||
@@ -160,12 +160,21 @@ class CPUWorker(Worker):
|
||||
x for x in logical_cpu_list if x.numa_node == selected_numa_node
|
||||
]
|
||||
else:
|
||||
-            assert len(logical_cpu_list) >= self.parallel_config.world_size
-            logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.numa_node)
-            sim_cpu_num_per_node = (
-                len(logical_cpu_list) // self.parallel_config.world_size
+            # This is a bit tricky because the internal DP size
+            # is always 1 for non-MoE models
+            world_size_across_dp = (
+                self.parallel_config.world_size
+                * self.parallel_config._api_process_count
             )
-            start_idx = self.local_rank * sim_cpu_num_per_node
+            assert len(logical_cpu_list) >= world_size_across_dp
+            logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.numa_node)
+            sim_cpu_num_per_node = len(logical_cpu_list) // world_size_across_dp
+            assert self.parallel_config.data_parallel_rank_local is not None
+            start_idx = (
+                self.local_rank
+                + self.parallel_config.world_size
+                * self.parallel_config.data_parallel_rank_local
+            ) * sim_cpu_num_per_node
|
||||
logical_cpu_list = logical_cpu_list[
|
||||
start_idx : (start_idx + sim_cpu_num_per_node)
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user