Refactor pplx init logic to make it modular (prepare for deepep) (#18200)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-05-23 23:43:43 +08:00
committed by GitHub
parent 022d8abe29
commit 6a7988c55b
16 changed files with 300 additions and 287 deletions

View File

@@ -1,44 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import TYPE_CHECKING
import torch
import torch.distributed as dist
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from .base_device_communicator import All2AllManagerBase, Cache
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
else:
FusedMoE = None
class All2AllBase:
def __init__(self, cpu_group, model):
self.cpu_group = cpu_group
# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
get_ep_group,
get_tp_group,
in_the_same_node_as)
# all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group()
self.tp_group = get_tp_group()
self.ep_group = get_ep_group()
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
self.internode = not self.intranode
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
pass
class NaiveAll2All(All2AllBase):
class NaiveAll2AllManager(All2AllManagerBase):
"""
A naive implementation of all2all communication.
It uses all-reduce under the hood, which is not
@@ -46,8 +26,8 @@ class NaiveAll2All(All2AllBase):
debugging.
"""
def __init__(self, cpu_group, model):
super().__init__(cpu_group, model)
def __init__(self, cpu_group):
super().__init__(cpu_group)
def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
@@ -91,3 +71,56 @@ class NaiveAll2All(All2AllBase):
def destroy(self):
pass
class PPLXAll2AllManager(All2AllManagerBase):
"""
All2All communication based on PPLX kernels.
"""
def __init__(self, cpu_group):
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
super().__init__(cpu_group)
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init)
logger.debug(
"Initialize NVSHMEM for pplx_kernels: "
"rank=%d, world size=%d", self.rank, self.world_size)
uid = nvshmem_get_unique_id(
) if self.rank == 0 else nvshmem_alloc_empty_unique_id()
dist.broadcast(uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group)
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)
self.handle_cache = Cache()
def get_handle(self, kwargs):
import pplx_kernels as pplx
return self.handle_cache.get_or_create(
kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
if self.internode:
from pplx_kernels.nvshmem import nvshmem_finalize
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()

View File

@@ -1,11 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
import threading
from typing import Optional
from weakref import WeakValueDictionary
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
class Cache:
def __init__(self):
self._cache: WeakValueDictionary = WeakValueDictionary()
self._lock = threading.RLock() # Reentrant lock for thread safety
def get_or_create(self, kwargs, func):
# Create a hashable key from the kwargs
key = tuple(sorted((k, v) for k, v in kwargs.items()))
with self._lock:
instance = self._cache.get(key)
if instance is None:
instance = func(**kwargs)
self._cache[key] = instance
return instance
class All2AllManagerBase:
def __init__(self, cpu_group):
self.cpu_group = cpu_group
# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
get_tp_group,
in_the_same_node_as)
# all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group()
self.tp_group = get_tp_group()
# no self.ep_group since self.ep_group is still in construction
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
self.internode = not self.intranode
def get_handle(self, kwargs):
# get a handle for the all2all communication,
# based on the kwargs.
# different layers can have different configs,
# e.g. one layer has hidden size 1024, another has 2048.
# usually the underlying implementation caches the handle
# and reuse it for the same config.
raise NotImplementedError
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
pass
class DeviceCommunicatorBase:
"""
Base class for device-specific communicator.
@@ -31,6 +96,18 @@ class DeviceCommunicatorBase:
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)
use_ep = False
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
# as long as we use data parallel (coupled data parallel
# where all data parallel ranks execute forward together),
# we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1
self.use_all2all = "ep" in unique_name and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_
@@ -154,9 +231,17 @@ class DeviceCommunicatorBase:
model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
This is a no-op in the base class.
"""
pass
if not self.use_all2all:
return
moe_modules = [
module for module in model.modules()
if module.__class__.__name__ == "FusedMoE"
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config,
module.quant_config)
def dispatch(
self, hidden_states: torch.Tensor,

View File

@@ -6,10 +6,12 @@ import torch
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
from .all2all import All2AllBase
from .base_device_communicator import DeviceCommunicatorBase
logger = init_logger(__name__)
class CudaCommunicator(DeviceCommunicatorBase):
@@ -31,8 +33,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
use_pynccl = "ep" not in unique_name
self.use_pynccl = use_pynccl
self.use_all2all = "ep" in unique_name
self.all2all_impl: Optional[All2AllBase] = None
self.use_custom_allreduce = use_custom_allreduce
# lazy import to avoid documentation build error
@@ -56,6 +56,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
device=self.device,
)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
elif all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
logger.info("Using PPLX all2all manager.")
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
def all_reduce(self, input_):
# always try custom allreduce first,
# and then pynccl.
@@ -136,31 +149,19 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.all2all_impl is not None:
self.all2all_impl.destroy()
self.all2all_impl = None
def prepare_communication_buffer_for_model(self,
model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
"""
if not self.use_all2all:
return
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
from .all2all import NaiveAll2All
self.all2all_impl = NaiveAll2All(self.cpu_group, model)
if self.all2all_manager is not None:
self.all2all_manager.destroy()
self.all2all_manager = None
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_impl is not None
hidden_states, router_logits = self.all2all_impl.dispatch(
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert self.all2all_impl is not None
hidden_states = self.all2all_impl.combine(hidden_states)
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
return hidden_states