[CPU] Enable shared-memory based pipeline parallel for CPU backend (#21289)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Author: Li, Jiang
Date: 2025-07-22 00:07:08 +08:00 (committed by GitHub)
Parent: 6dda13c86b
Commit: a15a50fc17
8 changed files with 165 additions and 59 deletions
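
Before this change, the CPU backend created its shared-memory (SHM) communicator only for tensor-parallel ("tp") groups; pipeline-parallel ("pp") groups fell back to the default torch.distributed path. This commit extends the SHM communicator to "pp" groups and adds send_tensor_dict/recv_tensor_dict so pipeline stages can pass intermediate tensors through shared memory. A minimal sketch of exercising the new path (assumes an x86 CPU build of vLLM with the custom _C SHM ops compiled in; the model name is illustrative):

    from vllm import LLM, SamplingParams

    # With pipeline_parallel_size > 1, a "pp" process group is created and,
    # on an x86 CPU build, now routes through the shared-memory communicator.
    llm = LLM(
        model="meta-llama/Llama-3.2-1B-Instruct",
        pipeline_parallel_size=2,
        distributed_executor_backend="mp",  # one worker process per pp rank
    )
    outputs = llm.generate(["Hello!"], SamplingParams(max_tokens=8))
    print(outputs[0].outputs[0].text)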

vllm/distributed/device_communicators/cpu_communicator.py

@@ -2,11 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import os
-from typing import Optional
+from typing import Any, Optional, Union
 
 import torch
 from torch.distributed import ProcessGroup
 
+from vllm.distributed.utils import pickle
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 
@@ -26,7 +27,8 @@ class CpuCommunicator(DeviceCommunicatorBase):
         if (current_platform.get_cpu_architecture()
                 == CpuArchEnum.X86) and hasattr(
                     torch.ops._C,
-                    "init_shm_manager") and unique_name.startswith("tp"):
+                    "init_shm_manager") and (unique_name.startswith("tp")
+                                             or unique_name.startswith("pp")):
             self.dist_module = _CPUSHMDistributed(self)
 
     def all_reduce(self, input_):
@@ -94,6 +96,19 @@ class CpuCommunicator(DeviceCommunicatorBase):
                                 input_size[dim + 1:])
         return output_tensor
 
+    def send_tensor_dict(
+        self,
+        tensor_dict: dict[str, Union[torch.Tensor, Any]],
+        dst: int,
+    ) -> None:
+        return self.dist_module.send_tensor_dict(tensor_dict, dst)
+
+    def recv_tensor_dict(
+        self,
+        src: int,
+    ) -> dict[str, Union[torch.Tensor, Any]]:
+        return self.dist_module.recv_tensor_dict(src)
+
 
 class _CPUSHMDistributed:
@@ -143,3 +158,44 @@ class _CPUSHMDistributed:
                    input: torch.Tensor,
                    group: Optional[ProcessGroup] = None) -> None:
         torch.ops._C.shm_all_gather(self.handle, input, output)
+
+    def send_tensor_dict(
+        self,
+        tensor_dict: dict[str, Union[torch.Tensor, Any]],
+        dst: int,
+    ) -> None:
+        key_list = list(tensor_dict.keys())
+        value_list = list(tensor_dict.values())
+        size_list = []
+        for v in value_list:
+            if not isinstance(v, torch.Tensor):
+                raise RuntimeError(
+                    "CpuCommunicator only supports sending tensors.")
+            size_list.append(v.size())
+
+        key_size_tensor = torch.frombuffer(pickle.dumps([key_list, size_list]),
+                                           dtype=torch.uint8)
+        value_list.append(key_size_tensor)
+        torch.ops._C.shm_send_tensor_list(self.handle, value_list, dst)
+
+        return None
+
+    def recv_tensor_dict(
+        self,
+        src: int,
+    ) -> dict[str, Union[torch.Tensor, Any]]:
+        tensor_list = torch.ops._C.shm_recv_tensor_list(self.handle, src)
+
+        value_list: list[torch.Tensor] = tensor_list[:-1]
+        key_size_tensor = tensor_list[-1]
+        key_size = pickle.loads(key_size_tensor.numpy().tobytes())
+        key_list = key_size[0]
+        size_list = key_size[1]
+        assert len(key_list) == len(size_list)
+        assert len(key_list) == len(value_list)
+
+        tensor_dict: dict[str, torch.Tensor] = {}
+        for key, size, t in zip(key_list, size_list, value_list):
+            tensor_dict[key] = t.view(size)
+
+        return tensor_dict
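
The wire format used by these two methods is simple: every value must be a torch.Tensor, and one extra uint8 tensor holding pickle.dumps([key_list, size_list]) is appended as the last element, so the receiver can rebuild the dict, keys and shapes included, from the tensor list alone. A self-contained sketch of that encode/decode round-trip, with the custom torch.ops._C.shm_send_tensor_list/shm_recv_tensor_list ops replaced by a direct hand-off (encode/decode are illustrative names, not vLLM APIs):

    import pickle

    import torch

    def encode(tensor_dict: dict[str, torch.Tensor]) -> list[torch.Tensor]:
        keys = list(tensor_dict.keys())
        values = list(tensor_dict.values())
        sizes = [v.size() for v in values]
        # Keys and shapes ride along as one pickled uint8 tensor at the end
        # (bytearray avoids torch.frombuffer's non-writable-buffer warning).
        meta = torch.frombuffer(bytearray(pickle.dumps([keys, sizes])),
                                dtype=torch.uint8)
        return values + [meta]

    def decode(tensor_list: list[torch.Tensor]) -> dict[str, torch.Tensor]:
        values, meta = tensor_list[:-1], tensor_list[-1]
        keys, sizes = pickle.loads(meta.numpy().tobytes())
        assert len(keys) == len(sizes) == len(values)
        # view(size) restores the original shape in case the transport
        # handed back flattened storage.
        return {k: t.view(s) for k, s, t in zip(keys, sizes, values)}

    # Round-trip check standing in for the shm ops:
    sent = {"hidden_states": torch.randn(2, 8), "residual": torch.randn(2, 8)}
    received = decode(encode(sent))
    assert all(torch.equal(sent[k], received[k]) for k in sent)

Pickling only keys and shapes keeps the metadata tensor small regardless of payload size, while the tensor data itself moves through shared memory; the RuntimeError in send_tensor_dict makes the trade-off explicit: non-tensor values cannot take this path.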