[1/N] Elastic EP Milestone 2 (#34861)

Signed-off-by: Yongji Wu <wuyongji317@gmail.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Ron Tourgeman <rtourgeman@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
This commit is contained in:
Itay Alroy
2026-02-28 06:46:42 +02:00
committed by GitHub
parent 90805ff464
commit dea268336f
53 changed files with 3613 additions and 1016 deletions

View File

@@ -31,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
debugging.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def naive_multicast(
self,
@@ -138,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
all-gather (dispatch) and reduce-scatter (combine).
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def dispatch_router_logits(
self,
@@ -239,12 +239,12 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_deep_ep(), (
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install DeepEP kernels."
) # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
self.handle_cache = Cache()
# This is the DeepEP default. Stick to it till we can establish
@@ -282,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
raise NotImplementedError
def destroy(self):
pass
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
self.handle_cache._cache.clear()
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
@@ -290,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
@@ -314,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=True,
)
def get_handle(self, kwargs):
@@ -347,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP Low-Latency kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(
self,
@@ -387,6 +391,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_qps_per_rank=num_qps_per_rank,
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
explicitly_destroy=True,
)
def get_handle(self, kwargs):
@@ -418,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
rank: int
world_size: int
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_flashinfer_all2all(), (
"flashinfer all2all module not found. Please install/check flashinfer"
) # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
logger.debug(
"Initialize for flashinfer All2All rank=%d, world size=%d",
self.rank,