[2/N] Elastic EP Milestone 2: Integrating NIXL-EP (#35627)
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Any

import torch
@@ -413,6 +414,121 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
        return 0


class NixlEPAll2AllManager(All2AllManagerBase):
    """
    All2All communication based on NIXL EP kernels.
    This backend supports elastic EP with dynamic rank connection/disconnection.
    """

    # (nixl_ep_buffer, ep_size)
    _buffer: tuple[Any, int] | None = None
    _lock = threading.Lock()

    def __init__(self, cpu_group, tcp_store_group=None):
        super().__init__(cpu_group, tcp_store_group)

        self.max_num_ep_ranks = envs.VLLM_NIXL_EP_MAX_NUM_RANKS

    def _init_buffer(
        self,
        max_num_tokens_per_dp_rank: int,
        token_hidden_size: int,
        num_experts_per_rank: int,
    ) -> None:
        from nixl_ep import Buffer  # type: ignore[import-not-found]

        max_num_global_experts = self.max_num_ep_ranks * num_experts_per_rank
        num_rdma_bytes = Buffer.get_rdma_size_hint(
            num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
            hidden=token_hidden_size,
            num_ranks=self.max_num_ep_ranks,
            num_experts=max_num_global_experts,
        )
        assert NixlEPAll2AllManager._buffer is None, (
            "NIXL EP buffer already initialized"
        )
        buffer = Buffer(
            rank=self.rank,
            tcp_store_group=self.tcp_store_group.store,
        )
        buffer.update_memory_buffers(
            num_ranks=self.max_num_ep_ranks,
            num_experts_per_rank=num_experts_per_rank,
            num_rdma_bytes=num_rdma_bytes,
        )
        ranks_to_connect = list(range(self.cpu_group.size()))
        buffer.connect_ranks(ranks_to_connect)
        NixlEPAll2AllManager._buffer = (buffer, self.cpu_group.size())

    def _update_buffer(self):
        assert NixlEPAll2AllManager._buffer is not None
        buffer, current_ep_size = NixlEPAll2AllManager._buffer
        current_ranks = list(range(current_ep_size))
        new_ep_size = self.cpu_group.size()
        buffer.set_tcp_store_group(self.tcp_store_group.store)
        if new_ep_size > len(current_ranks):
            ranks_to_connect = list(range(len(current_ranks), new_ep_size))
            buffer.connect_ranks(ranks_to_connect)
        else:
            ranks_to_disconnect = current_ranks[new_ep_size:]
            buffer.disconnect_ranks(ranks_to_disconnect)
        NixlEPAll2AllManager._buffer = (buffer, new_ep_size)

    def get_handle(self, kwargs):
        with NixlEPAll2AllManager._lock:
            if (
                NixlEPAll2AllManager._buffer is not None
                and NixlEPAll2AllManager._buffer[1] == self.cpu_group.size()
            ):
                return NixlEPAll2AllManager._buffer[0]

            num_experts_per_rank = (
                kwargs["num_global_experts"] // kwargs["num_ep_ranks"]
            )
            nixl_kwargs = dict(
                max_num_tokens_per_dp_rank=kwargs["max_num_tokens_per_dp_rank"],
                token_hidden_size=kwargs["token_hidden_size"],
                num_experts_per_rank=num_experts_per_rank,
            )
            if NixlEPAll2AllManager._buffer is None:
                self._init_buffer(**nixl_kwargs)
            else:
                self._update_buffer()

            assert NixlEPAll2AllManager._buffer is not None
            handle = NixlEPAll2AllManager._buffer[0]
            return handle

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
    ):
        raise NotImplementedError

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
        # NOTE(yongji): NIXLEPAll2AllManager instance is recreated during
        # scale-up/down, so we cannot destroy the persistent buffer here.
        assert NixlEPAll2AllManager._buffer is not None
        buffer = NixlEPAll2AllManager._buffer[0]
        buffer.set_tcp_store_group(None)

    # NIXL EP uses RDMA so no SMs are used for communication
    def max_sms_used(self) -> int | None:
        return 0


class FlashInferAllToAllManager(All2AllManagerBase):
    """
    All2All communication based on flashinfer kernels.
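
For reference, a sketch of the `kwargs` dict that `get_handle()` consumes. The key names come from the diff; the numeric values are hypothetical, and the only derivation performed in `get_handle()` is `num_global_experts // num_ep_ranks`.

# Hypothetical sizing values; only the key names match get_handle() above.
kwargs = {
    "max_num_tokens_per_dp_rank": 256,
    "token_hidden_size": 4096,
    "num_global_experts": 64,
    "num_ep_ranks": 8,
}
# get_handle() derives the per-rank expert count from the global expert count:
num_experts_per_rank = kwargs["num_global_experts"] // kwargs["num_ep_ranks"]
assert num_experts_per_rank == 8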