[1/N] Elastic EP Milestone 2 (#34861)
Signed-off-by: Yongji Wu <wuyongji317@gmail.com> Signed-off-by: Itay Alroy <ialroy@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Ron Tourgeman <rtourgeman@nvidia.com> Co-authored-by: Yongji Wu <wuyongji317@gmail.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
This commit is contained in:
@@ -31,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
debugging.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def naive_multicast(
|
||||
self,
|
||||
@@ -138,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
all-gather (dispatch) and reduce-scatter (combine).
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
@@ -239,12 +239,12 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
assert has_deep_ep(), (
|
||||
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
|
||||
" to install DeepEP kernels."
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
self.handle_cache = Cache()
|
||||
|
||||
# This is the DeepEP default. Stick to it till we can establish
|
||||
@@ -282,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
self.handle_cache._cache.clear()
|
||||
|
||||
|
||||
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
@@ -290,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
@@ -314,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
@@ -347,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
All2All communication based on DeepEP Low-Latency kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def _make_all2all_kwargs(
|
||||
self,
|
||||
@@ -387,6 +391,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
allow_nvlink_for_low_latency_mode=True,
|
||||
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
@@ -418,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
assert has_flashinfer_all2all(), (
|
||||
"flashinfer all2all module not found. Please install/check flashinfer"
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
logger.debug(
|
||||
"Initialize for flashinfer All2All rank=%d, world size=%d",
|
||||
self.rank,
|
||||
|
||||
Reference in New Issue
Block a user