[UX] Replace VLLM_ALL2ALL_BACKEND with --all2all-backend (#26732)
Signed-off-by: mgoin <mgoin64@gmail.com>
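For context, a minimal before/after usage sketch (illustrative: the `vllm serve` invocation and `<model>` placeholder are assumptions; the flag name comes from the title and the backend value from the diff below):

    # Before: backend selected via environment variable
    VLLM_ALL2ALL_BACKEND=deepep_low_latency vllm serve <model>

    # After: backend selected via CLI argument, read from
    # config.parallel_config.all2all_backend
    vllm serve <model> --all2all-backend deepep_low_latency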
vllm/distributed/device_communicators/base_device_communicator.py
@@ -111,6 +111,7 @@ class DeviceCommunicatorBase:
         self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
 
         use_ep = False
+        all2all_backend = None
         from vllm.config import get_current_vllm_config
 
         config = get_current_vllm_config()
@@ -119,9 +120,11 @@ class DeviceCommunicatorBase:
             # where all data parallel ranks execute forward together),
             # we initialize the all2all manager used in expert parallel.
             use_ep = config.parallel_config.data_parallel_size > 1
+            all2all_backend = config.parallel_config.all2all_backend
 
         self.is_ep_communicator = "ep" in unique_name
         self.use_all2all = self.is_ep_communicator and use_ep
+        self.all2all_backend = all2all_backend
         self.all2all_manager: All2AllManagerBase | None = None
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
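Taken together, the two hunks above move backend selection to construction time: the base communicator resolves the backend once from the parallel config and stores it on `self`, so device-specific subclasses no longer consult `envs.VLLM_ALL2ALL_BACKEND`. A condensed sketch of the resulting flow (the constructor signature and the `config is not None` guard are assumptions inferred from the surrounding context):

    from vllm.config import get_current_vllm_config

    class DeviceCommunicatorBase:
        def __init__(self, cpu_group, device=None, device_group=None,
                     unique_name=""):
            use_ep = False
            all2all_backend = None
            config = get_current_vllm_config()
            if config is not None:
                # Coupled data parallel implies expert-parallel all2all.
                use_ep = config.parallel_config.data_parallel_size > 1
                all2all_backend = config.parallel_config.all2all_backend
            self.is_ep_communicator = "ep" in unique_name
            self.use_all2all = self.is_ep_communicator and use_ep
            # Subclasses dispatch on this instead of re-reading the env var.
            self.all2all_backend = all2all_backend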
vllm/distributed/device_communicators/cuda_communicator.py
@@ -91,33 +91,32 @@ class CudaCommunicator(DeviceCommunicatorBase):
             self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device)
 
         if self.use_all2all:
-            all2all_backend = envs.VLLM_ALL2ALL_BACKEND
-            if all2all_backend == "naive":
+            if self.all2all_backend == "naive":
                 from .all2all import NaiveAll2AllManager
 
                 self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
-            elif all2all_backend == "allgather_reducescatter":
+            elif self.all2all_backend == "allgather_reducescatter":
                 from .all2all import AgRsAll2AllManager
 
                 self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
-            elif all2all_backend == "pplx":
+            elif self.all2all_backend == "pplx":
                 from .all2all import PPLXAll2AllManager
 
                 self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
-            elif all2all_backend == "deepep_high_throughput":
+            elif self.all2all_backend == "deepep_high_throughput":
                 from .all2all import DeepEPHTAll2AllManager
 
                 self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
-            elif all2all_backend == "deepep_low_latency":
+            elif self.all2all_backend == "deepep_low_latency":
                 from .all2all import DeepEPLLAll2AllManager
 
                 self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
-            elif all2all_backend == "flashinfer_all2allv":
+            elif self.all2all_backend == "flashinfer_all2allv":
                 from .all2all import FlashInferAllToAllManager
 
                 self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
             else:
-                raise ValueError(f"Unknown all2all backend: {all2all_backend}")
+                raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
 
         if is_global_first_rank():
             logger.info(
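A design note on the hunk above: the if/elif chain is a pure string-to-class dispatch, so an equivalent table-driven form is possible. A sketch under the assumption that all managers live in the `vllm.distributed.device_communicators.all2all` module (this is not how the PR writes it, only an illustration of the mapping):

    import importlib

    # Backend name -> manager class name, copied from the branches above.
    _ALL2ALL_MANAGERS = {
        "naive": "NaiveAll2AllManager",
        "allgather_reducescatter": "AgRsAll2AllManager",
        "pplx": "PPLXAll2AllManager",
        "deepep_high_throughput": "DeepEPHTAll2AllManager",
        "deepep_low_latency": "DeepEPLLAll2AllManager",
        "flashinfer_all2allv": "FlashInferAllToAllManager",
    }

    def make_all2all_manager(backend: str, cpu_group):
        if backend not in _ALL2ALL_MANAGERS:
            raise ValueError(f"Unknown all2all backend: {backend}")
        # Import lazily, mirroring the per-branch imports in the diff.
        mod = importlib.import_module(
            "vllm.distributed.device_communicators.all2all")
        return getattr(mod, _ALL2ALL_MANAGERS[backend])(cpu_group)

The per-branch imports in the actual code keep optional dependencies (pplx, DeepEP, FlashInfer) lazy; a single table preserves that only if the import stays inside the factory, as above.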
vllm/distributed/device_communicators/xpu_communicator.py
@@ -6,7 +6,6 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
-import vllm.envs as envs
 from vllm.logger import init_logger
 
 from .base_device_communicator import DeviceCommunicatorBase
@@ -24,15 +23,14 @@ class XpuCommunicator(DeviceCommunicatorBase):
     ):
         super().__init__(cpu_group, device, device_group, unique_name)
         if self.use_all2all:
-            all2all_backend = envs.VLLM_ALL2ALL_BACKEND
-            if all2all_backend != "naive":
+            if self.all2all_backend != "naive":
                 logger.warning(
-                    "`%s` all2all manager is not supported on XPU."
+                    "`%s` all2all manager is not supported on XPU. "
                     "Falling back to `naive` all2all manager for XPU.",
-                    all2all_backend,
+                    self.all2all_backend,
                 )
-                all2all_backend = "naive"
-            if all2all_backend == "naive":
+                self.all2all_backend = "naive"
+            if self.all2all_backend == "naive":
                 from .all2all import NaiveAll2AllManager
 
                 self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
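A usage note implied by this hunk: on XPU, any configured backend other than `naive` (e.g. `--all2all-backend pplx`) is demoted with a warning and the naive manager is instantiated anyway. Expected log line (message text verbatim from the diff; the WARNING prefix formatting is an assumption):

    WARNING ... `pplx` all2all manager is not supported on XPU. Falling back to `naive` all2all manager for XPU.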