Refactor pplx init logic to make it modular (prepare for deepep) (#18200)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -1,11 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import threading
|
||||
from typing import Optional
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class Cache:
|
||||
|
||||
def __init__(self):
|
||||
self._cache: WeakValueDictionary = WeakValueDictionary()
|
||||
self._lock = threading.RLock() # Reentrant lock for thread safety
|
||||
|
||||
def get_or_create(self, kwargs, func):
|
||||
# Create a hashable key from the kwargs
|
||||
key = tuple(sorted((k, v) for k, v in kwargs.items()))
|
||||
|
||||
with self._lock:
|
||||
instance = self._cache.get(key)
|
||||
if instance is None:
|
||||
instance = func(**kwargs)
|
||||
self._cache[key] = instance
|
||||
return instance
|
||||
|
||||
|
||||
class All2AllManagerBase:
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
self.cpu_group = cpu_group
|
||||
|
||||
# compute some common properties
|
||||
from vllm.distributed.parallel_state import (get_dp_group,
|
||||
get_tp_group,
|
||||
in_the_same_node_as)
|
||||
|
||||
# all2all lives in ep group, which is merged from dp and tp group
|
||||
self.dp_group = get_dp_group()
|
||||
self.tp_group = get_tp_group()
|
||||
# no self.ep_group since self.ep_group is still in construction
|
||||
# when we create this object
|
||||
self.dp_rank = self.dp_group.rank_in_group
|
||||
self.dp_world_size = self.dp_group.world_size
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
|
||||
# all2all communication often has separate implementations for
|
||||
# intra-node and inter-node communication
|
||||
self.intranode = in_the_same_node_as(cpu_group, source_rank=0)
|
||||
self.internode = not self.intranode
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
# get a handle for the all2all communication,
|
||||
# based on the kwargs.
|
||||
# different layers can have different configs,
|
||||
# e.g. one layer has hidden size 1024, another has 2048.
|
||||
# usually the underlying implementation caches the handle
|
||||
# and reuse it for the same config.
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class DeviceCommunicatorBase:
|
||||
"""
|
||||
Base class for device-specific communicator.
|
||||
@@ -31,6 +96,18 @@ class DeviceCommunicatorBase:
|
||||
self.rank_in_group = dist.get_group_rank(self.cpu_group,
|
||||
self.global_rank)
|
||||
|
||||
use_ep = False
|
||||
from vllm.config import get_current_vllm_config
|
||||
config = get_current_vllm_config()
|
||||
if config is not None:
|
||||
# as long as we use data parallel (coupled data parallel
|
||||
# where all data parallel ranks execute forward together),
|
||||
# we initialize the all2all manager used in expert parallel.
|
||||
use_ep = config.parallel_config.data_parallel_size > 1
|
||||
|
||||
self.use_all2all = "ep" in unique_name and use_ep
|
||||
self.all2all_manager: Optional[All2AllManagerBase] = None
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
@@ -154,9 +231,17 @@ class DeviceCommunicatorBase:
|
||||
model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Prepare the communication buffer for the model.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
pass
|
||||
if not self.use_all2all:
|
||||
return
|
||||
|
||||
moe_modules = [
|
||||
module for module in model.modules()
|
||||
if module.__class__.__name__ == "FusedMoE"
|
||||
]
|
||||
for module in moe_modules:
|
||||
module.quant_method.init_prepare_finalize(module.moe_config,
|
||||
module.quant_config)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user