Refactor pplx init logic to make it modular (prepare for deepep) (#18200)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-05-23 23:43:43 +08:00
committed by GitHub
parent 022d8abe29
commit 6a7988c55b
16 changed files with 300 additions and 287 deletions

View File

@@ -1,11 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
import threading
from typing import Optional
from weakref import WeakValueDictionary
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
class Cache:
def __init__(self):
self._cache: WeakValueDictionary = WeakValueDictionary()
self._lock = threading.RLock() # Reentrant lock for thread safety
def get_or_create(self, kwargs, func):
# Create a hashable key from the kwargs
key = tuple(sorted((k, v) for k, v in kwargs.items()))
with self._lock:
instance = self._cache.get(key)
if instance is None:
instance = func(**kwargs)
self._cache[key] = instance
return instance
class All2AllManagerBase:
def __init__(self, cpu_group):
self.cpu_group = cpu_group
# compute some common properties
from vllm.distributed.parallel_state import (get_dp_group,
get_tp_group,
in_the_same_node_as)
# all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group()
self.tp_group = get_tp_group()
# no self.ep_group since self.ep_group is still in construction
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
# all2all communication often has separate implementations for
# intra-node and inter-node communication
self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
self.internode = not self.intranode
def get_handle(self, kwargs):
# get a handle for the all2all communication,
# based on the kwargs.
# different layers can have different configs,
# e.g. one layer has hidden size 1024, another has 2048.
# usually the underlying implementation caches the handle
# and reuse it for the same config.
raise NotImplementedError
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
pass
class DeviceCommunicatorBase:
"""
Base class for device-specific communicator.
@@ -31,6 +96,18 @@ class DeviceCommunicatorBase:
self.rank_in_group = dist.get_group_rank(self.cpu_group,
self.global_rank)
use_ep = False
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
# as long as we use data parallel (coupled data parallel
# where all data parallel ranks execute forward together),
# we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1
self.use_all2all = "ep" in unique_name and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_
@@ -154,9 +231,17 @@ class DeviceCommunicatorBase:
model: torch.nn.Module) -> None:
"""
Prepare the communication buffer for the model.
This is a no-op in the base class.
"""
pass
if not self.use_all2all:
return
moe_modules = [
module for module in model.modules()
if module.__class__.__name__ == "FusedMoE"
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config,
module.quant_config)
def dispatch(
self, hidden_states: torch.Tensor,