[Bugfix] Add CpuCommunicator.dispatch and combine to fix DP+MoE inference (#31867)
Signed-off-by: kunzh <zhikun.wu@outlook.com>
This commit is contained in:
@@ -8,11 +8,14 @@ import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.distributed.utils import pickle
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CpuCommunicator(DeviceCommunicatorBase):
|
||||
def __init__(
|
||||
@@ -32,6 +35,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
):
|
||||
self.dist_module = _CPUSHMDistributed(self)
|
||||
|
||||
if self.use_all2all:
|
||||
if self.all2all_backend != "naive": # type: ignore[has-type]
|
||||
logger.warning(
|
||||
"`%s` all2all manager is not supported on CPU. "
|
||||
"Falling back to `naive` all2all manager for CPU.",
|
||||
self.all2all_backend, # type: ignore[has-type]
|
||||
)
|
||||
self.all2all_backend = "naive"
|
||||
if self.all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
logger.info("Using naive all2all manager.")
|
||||
|
||||
def all_reduce(self, input_):
    """All-reduce ``input_`` in place across the device group.

    Delegates to the shared-memory distributed backend; the reduced
    values are written back into ``input_``, which is also returned
    for caller convenience.
    """
    self.dist_module.all_reduce(input_, group=self.device_group)
    return input_
|
||||
@@ -110,6 +127,30 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
) -> dict[str, torch.Tensor | Any]:
|
||||
return self.dist_module.recv_tensor_dict(src)
|
||||
|
||||
def dispatch(  # type: ignore[override]
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Scatter MoE activations and router logits via the all2all manager.

    Requires that an all2all manager was configured during
    construction (the naive CPU manager); forwards all arguments
    unchanged and returns the manager's (hidden_states, router_logits)
    pair.
    """
    manager = self.all2all_manager
    assert manager is not None
    dispatched = manager.dispatch(
        hidden_states,
        router_logits,
        is_sequence_parallel,
        extra_tensors,  # type: ignore[call-arg]
    )
    return dispatched
|
||||
|
||||
def combine(
    self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
    """Gather MoE expert outputs back via the all2all manager.

    Inverse of :meth:`dispatch`; requires a configured all2all manager
    and returns the combined hidden states tensor it produces.
    """
    manager = self.all2all_manager
    assert manager is not None
    return manager.combine(hidden_states, is_sequence_parallel)
|
||||
|
||||
|
||||
class _CPUSHMDistributed:
|
||||
def __init__(self, communicator: CpuCommunicator):
|
||||
|
||||
Reference in New Issue
Block a user