[2/N] Elastic EP Milestone 2: Integrating NIXL-EP (#35627)

Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
Itay Alroy authored on 2026-03-13 15:25:33 +02:00, committed by GitHub
parent 82f836d976, commit d5af196c18
14 changed files with 635 additions and 11 deletions


@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Any
import torch
@@ -413,6 +414,121 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
        return 0


class NixlEPAll2AllManager(All2AllManagerBase):
    """
    All2All communication based on NIXL EP kernels.

    This backend supports elastic EP with dynamic rank
    connection/disconnection.
    """

    # (nixl_ep_buffer, ep_size)
    _buffer: tuple[Any, int] | None = None
    _lock = threading.Lock()
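
    # NOTE: the buffer is cached at class level together with the EP world
    # size it was last connected for, so it survives re-creation of this
    # manager during scale-up/down (see destroy() below).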

    def __init__(self, cpu_group, tcp_store_group=None):
        super().__init__(cpu_group, tcp_store_group)
        self.max_num_ep_ranks = envs.VLLM_NIXL_EP_MAX_NUM_RANKS

    def _init_buffer(
        self,
        max_num_tokens_per_dp_rank: int,
        token_hidden_size: int,
        num_experts_per_rank: int,
    ) -> None:
        from nixl_ep import Buffer  # type: ignore[import-not-found]

        # Size the RDMA buffer for the maximum EP world size
        # (VLLM_NIXL_EP_MAX_NUM_RANKS), not just the current one.
        max_num_global_experts = self.max_num_ep_ranks * num_experts_per_rank
        num_rdma_bytes = Buffer.get_rdma_size_hint(
            num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
            hidden=token_hidden_size,
            num_ranks=self.max_num_ep_ranks,
            num_experts=max_num_global_experts,
        )

        assert NixlEPAll2AllManager._buffer is None, (
            "NIXL EP buffer already initialized"
        )
        buffer = Buffer(
            rank=self.rank,
            tcp_store_group=self.tcp_store_group.store,
        )
        buffer.update_memory_buffers(
            num_ranks=self.max_num_ep_ranks,
            num_experts_per_rank=num_experts_per_rank,
            num_rdma_bytes=num_rdma_bytes,
        )
        ranks_to_connect = list(range(self.cpu_group.size()))
        buffer.connect_ranks(ranks_to_connect)
        NixlEPAll2AllManager._buffer = (buffer, self.cpu_group.size())

    def _update_buffer(self):
        assert NixlEPAll2AllManager._buffer is not None
        buffer, current_ep_size = NixlEPAll2AllManager._buffer
        current_ranks = list(range(current_ep_size))
        new_ep_size = self.cpu_group.size()
        buffer.set_tcp_store_group(self.tcp_store_group.store)
        if new_ep_size > len(current_ranks):
            # Scale-up: connect only the newly added ranks.
            ranks_to_connect = list(range(len(current_ranks), new_ep_size))
            buffer.connect_ranks(ranks_to_connect)
        else:
            # Scale-down: disconnect the ranks that left the group.
            ranks_to_disconnect = current_ranks[new_ep_size:]
            buffer.disconnect_ranks(ranks_to_disconnect)
        NixlEPAll2AllManager._buffer = (buffer, new_ep_size)

    def get_handle(self, kwargs):
        with NixlEPAll2AllManager._lock:
            # Reuse the cached buffer if the EP world size has not changed.
            if (
                NixlEPAll2AllManager._buffer is not None
                and NixlEPAll2AllManager._buffer[1] == self.cpu_group.size()
            ):
                return NixlEPAll2AllManager._buffer[0]

            num_experts_per_rank = (
                kwargs["num_global_experts"] // kwargs["num_ep_ranks"]
            )
            nixl_kwargs = dict(
                max_num_tokens_per_dp_rank=kwargs["max_num_tokens_per_dp_rank"],
                token_hidden_size=kwargs["token_hidden_size"],
                num_experts_per_rank=num_experts_per_rank,
            )
            # First use: allocate and connect the buffer; on a resize,
            # reconnect/disconnect ranks to match the new EP world size.
            if NixlEPAll2AllManager._buffer is None:
                self._init_buffer(**nixl_kwargs)
            else:
                self._update_buffer()

            assert NixlEPAll2AllManager._buffer is not None
            handle = NixlEPAll2AllManager._buffer[0]
            return handle

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
    ):
        raise NotImplementedError

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        # NOTE(yongji): the NixlEPAll2AllManager instance is recreated during
        # scale-up/down, so we cannot destroy the persistent buffer here.
        assert NixlEPAll2AllManager._buffer is not None
        buffer = NixlEPAll2AllManager._buffer[0]
        buffer.set_tcp_store_group(None)

    # NIXL EP uses RDMA, so no SMs are used for communication.
    def max_sms_used(self) -> int | None:
        return 0
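
# Illustrative usage sketch (the concrete sizes below are placeholder
# assumptions; only the kwargs keys match what get_handle() reads):
#
#     manager = NixlEPAll2AllManager(cpu_group, tcp_store_group)
#     handle = manager.get_handle(
#         {
#             "max_num_tokens_per_dp_rank": 256,
#             "token_hidden_size": 4096,
#             "num_global_experts": 64,
#             "num_ep_ranks": cpu_group.size(),
#         }
#     )
#     # After a scale-up/down the manager is recreated; calling get_handle()
#     # again reconnects or disconnects ranks and returns the same buffer.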


class FlashInferAllToAllManager(All2AllManagerBase):
    """
    All2All communication based on flashinfer kernels.