[Bugfix] Add CpuCommunicator.dispatch and combine to fix DP+MoE inference (#31867)

Signed-off-by: kunzh <zhikun.wu@outlook.com>
This commit is contained in:
kzwrime
2026-01-15 12:50:48 +08:00
committed by GitHub
parent d86fc23bdd
commit edadca109c
2 changed files with 45 additions and 1 deletion

View File

@@ -8,11 +8,14 @@ import torch
from torch.distributed import ProcessGroup
from vllm.distributed.utils import pickle
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
class CpuCommunicator(DeviceCommunicatorBase):
def __init__(
@@ -32,6 +35,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
):
self.dist_module = _CPUSHMDistributed(self)
if self.use_all2all:
if self.all2all_backend != "naive": # type: ignore[has-type]
logger.warning(
"`%s` all2all manager is not supported on CPU. "
"Falling back to `naive` all2all manager for CPU.",
self.all2all_backend, # type: ignore[has-type]
)
self.all2all_backend = "naive"
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
def all_reduce(self, input_):
self.dist_module.all_reduce(input_, group=self.device_group)
return input_
@@ -110,6 +127,30 @@ class CpuCommunicator(DeviceCommunicatorBase):
) -> dict[str, torch.Tensor | Any]:
return self.dist_module.recv_tensor_dict(src)
def dispatch( # type: ignore[override]
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
return self.all2all_manager.dispatch(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(
hidden_states, is_sequence_parallel
)
return hidden_states
class _CPUSHMDistributed:
def __init__(self, communicator: CpuCommunicator):