[WideEP] Remove pplx all2all backend (#33724)
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
0f2f24c8b2
commit
eb19955c37
@@ -3,14 +3,13 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.flashinfer import has_flashinfer_all2all
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori
|
||||
|
||||
from .base_device_communicator import All2AllManagerBase, Cache
|
||||
|
||||
@@ -235,96 +234,6 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
pass
|
||||
|
||||
|
||||
class PPLXAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on PPLX kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
assert has_pplx(), (
|
||||
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
|
||||
" to install pplx_kernels."
|
||||
)
|
||||
super().__init__(cpu_group)
|
||||
|
||||
if self.internode:
|
||||
# inter-node communication needs nvshmem,
|
||||
# intra-node communication uses p2p mapping directly
|
||||
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
|
||||
nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
|
||||
self.rank,
|
||||
self.world_size,
|
||||
)
|
||||
uid = (
|
||||
nvshmem_get_unique_id()
|
||||
if self.rank == 0
|
||||
else nvshmem_alloc_empty_unique_id()
|
||||
)
|
||||
dist.broadcast(
|
||||
uid,
|
||||
src=dist.get_process_group_ranks(self.cpu_group)[0],
|
||||
group=self.cpu_group,
|
||||
)
|
||||
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
||||
nvshmem_init(uid, self.rank, self.world_size)
|
||||
|
||||
self.handle_cache = Cache()
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
import pplx_kernels as pplx # type: ignore[import-not-found]
|
||||
|
||||
return self.handle_cache.get_or_create(
|
||||
kwargs,
|
||||
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
|
||||
)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
|
||||
if self.internode:
|
||||
from pplx_kernels.nvshmem import (
|
||||
nvshmem_finalize, # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
|
||||
Reference in New Issue
Block a user