[2/N] Elastic EP Milestone 2: Integrating NIXL-EP (#35627)

Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
This commit is contained in:
Itay Alroy
2026-03-13 15:25:33 +02:00
committed by GitHub
parent 82f836d976
commit d5af196c18
14 changed files with 635 additions and 11 deletions

View File

@@ -150,6 +150,7 @@ def test_mp_client_uses_env_timeout(monkeypatch: pytest.MonkeyPatch):
data_parallel_hybrid_lb=False,
data_parallel_external_lb=False,
local_engines_only=False,
enable_elastic_ep=False,
)
vllm_config = SimpleNamespace(parallel_config=parallel_config)

View File

@@ -43,6 +43,7 @@ All2AllBackend = Literal[
"deepep_high_throughput",
"deepep_low_latency",
"mori",
"nixl_ep",
"allgather_reducescatter",
"flashinfer_all2allv",
]
@@ -156,6 +157,7 @@ class ParallelConfig:
- "deepep_high_throughput": Use deepep high-throughput kernels\n
- "deepep_low_latency": Use deepep low-latency kernels\n
- "mori": Use mori kernels\n
- "nixl_ep": Use nixl-ep kernels\n
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
max_parallel_loading_workers: int | None = None
@@ -580,6 +582,7 @@ class ParallelConfig:
"deepep_high_throughput",
"deepep_low_latency",
"mori",
"nixl_ep",
)
and self.enable_expert_parallel
and self.tensor_parallel_size > 1

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Any
import torch
@@ -413,6 +414,121 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
return 0
class NixlEPAll2AllManager(All2AllManagerBase):
    """
    All2All communication based on NIXL EP kernels.
    This backend supports elastic EP with dynamic rank connection/disconnection.

    The NIXL EP buffer is kept on the class rather than on the instance, so it
    outlives any single manager object (manager instances are recreated during
    scale-up/down). ``get_handle`` creates the buffer on first use and
    connects/disconnects ranks when the CPU group size has changed since the
    buffer was last touched.
    """

    # (nixl_ep_buffer, ep_size)
    _buffer: tuple[Any, int] | None = None
    # Serializes creation/resizing of the shared buffer in get_handle().
    _lock = threading.Lock()

    def __init__(self, cpu_group, tcp_store_group=None):
        super().__init__(cpu_group, tcp_store_group)
        # Upper bound on EP ranks; the RDMA buffer is sized for this many ranks.
        self.max_num_ep_ranks = envs.VLLM_NIXL_EP_MAX_NUM_RANKS

    def _init_buffer(
        self,
        max_num_tokens_per_dp_rank: int,
        token_hidden_size: int,
        num_experts_per_rank: int,
    ) -> None:
        """Create the shared NIXL EP buffer and connect all current ranks.

        The memory buffers are sized for ``max_num_ep_ranks`` ranks (not the
        current group size), and only the ranks of the current CPU group are
        connected.
        """
        from nixl_ep import Buffer  # type: ignore[import-not-found]

        max_num_global_experts = self.max_num_ep_ranks * num_experts_per_rank
        num_rdma_bytes = Buffer.get_rdma_size_hint(
            num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
            hidden=token_hidden_size,
            num_ranks=self.max_num_ep_ranks,
            num_experts=max_num_global_experts,
        )
        assert NixlEPAll2AllManager._buffer is None, (
            "NIXL EP buffer already initialized"
        )
        buffer = Buffer(
            rank=self.rank,
            tcp_store_group=self.tcp_store_group.store,
        )
        buffer.update_memory_buffers(
            num_ranks=self.max_num_ep_ranks,
            num_experts_per_rank=num_experts_per_rank,
            num_rdma_bytes=num_rdma_bytes,
        )
        ep_size = self.cpu_group.size()
        buffer.connect_ranks(list(range(ep_size)))
        NixlEPAll2AllManager._buffer = (buffer, ep_size)

    def _update_buffer(self):
        """Re-target the existing buffer at the current CPU group.

        Points the buffer at this manager's TCP store, then connects the newly
        added ranks (group grew) or disconnects the trailing ranks (otherwise).
        """
        assert NixlEPAll2AllManager._buffer is not None
        buffer, current_ep_size = NixlEPAll2AllManager._buffer
        current_ranks = list(range(current_ep_size))
        new_ep_size = self.cpu_group.size()
        buffer.set_tcp_store_group(self.tcp_store_group.store)
        if new_ep_size > len(current_ranks):
            ranks_to_connect = list(range(len(current_ranks), new_ep_size))
            buffer.connect_ranks(ranks_to_connect)
        else:
            ranks_to_disconnect = current_ranks[new_ep_size:]
            buffer.disconnect_ranks(ranks_to_disconnect)
        NixlEPAll2AllManager._buffer = (buffer, new_ep_size)

    def get_handle(self, kwargs):
        """Return the shared NIXL EP buffer, creating or resizing it if needed.

        Args:
            kwargs: Mapping with the keys ``num_global_experts``,
                ``num_ep_ranks``, ``max_num_tokens_per_dp_rank`` and
                ``token_hidden_size``.

        Returns:
            The NIXL EP buffer object shared by all manager instances.

        Raises:
            ValueError: If the current EP group is larger than
                ``VLLM_NIXL_EP_MAX_NUM_RANKS``, i.e. larger than the number of
                ranks the buffer is sized for.
        """
        with NixlEPAll2AllManager._lock:
            ep_size = self.cpu_group.size()
            cached = NixlEPAll2AllManager._buffer
            if cached is not None and cached[1] == ep_size:
                # Buffer exists and already matches the current group size.
                return cached[0]

            # The buffer is allocated for max_num_ep_ranks ranks; fail with a
            # clear message instead of connecting ranks it has no room for.
            if ep_size > self.max_num_ep_ranks:
                raise ValueError(
                    f"NIXL EP group size {ep_size} exceeds "
                    f"VLLM_NIXL_EP_MAX_NUM_RANKS={self.max_num_ep_ranks}."
                )

            num_experts_per_rank = (
                kwargs["num_global_experts"] // kwargs["num_ep_ranks"]
            )
            nixl_kwargs = dict(
                max_num_tokens_per_dp_rank=kwargs["max_num_tokens_per_dp_rank"],
                token_hidden_size=kwargs["token_hidden_size"],
                num_experts_per_rank=num_experts_per_rank,
            )
            if NixlEPAll2AllManager._buffer is None:
                self._init_buffer(**nixl_kwargs)
            else:
                self._update_buffer()
            assert NixlEPAll2AllManager._buffer is not None
            handle = NixlEPAll2AllManager._buffer[0]
            return handle

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
    ):
        """Not implemented for this manager; always raises."""
        raise NotImplementedError

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """Not implemented for this manager; always raises."""
        raise NotImplementedError

    def destroy(self):
        """Detach the shared buffer from the TCP store without freeing it."""
        # NOTE(yongji): NIXLEPAll2AllManager instance is recreated during
        # scale-up/down, so we cannot destroy the persistent buffer here.
        cached = NixlEPAll2AllManager._buffer
        if cached is None:
            # get_handle() never ran, so there is no buffer to detach. Teardown
            # must not fail just because the backend was never used.
            return
        cached[0].set_tcp_store_group(None)

    # NIXL EP uses RDMA so no SMs are used for communication
    def max_sms_used(self) -> int | None:
        return 0
class FlashInferAllToAllManager(All2AllManagerBase):
"""
All2All communication based on flashinfer kernels.

View File

@@ -143,6 +143,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import MoriAll2AllManager
self.all2all_manager = MoriAll2AllManager(self.cpu_group)
elif self.all2all_backend == "nixl_ep":
from .all2all import NixlEPAll2AllManager
self.all2all_manager = NixlEPAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager

View File

@@ -244,6 +244,7 @@ if TYPE_CHECKING:
VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False
VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False
VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False
VLLM_NIXL_EP_MAX_NUM_RANKS: int = 32
def get_default_cache_root():
@@ -1628,6 +1629,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS": lambda: bool(
int(os.getenv("VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS", "0"))
),
# NIXL EP environment variables
"VLLM_NIXL_EP_MAX_NUM_RANKS": lambda: int(
os.getenv("VLLM_NIXL_EP_MAX_NUM_RANKS", "32")
),
}

View File

@@ -25,7 +25,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
make_moe_prepare_and_finalize_no_dp_ep,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_mori
from vllm.utils.import_utils import has_deep_ep, has_mori, has_nixl_ep
logger = init_logger(__name__)
@@ -38,6 +38,11 @@ if current_platform.is_cuda_alike():
)
if has_mori():
from .mori_prepare_finalize import MoriPrepareAndFinalize
if has_nixl_ep():
from .nixl_ep_prepare_finalize import (
NIXL_EP_QUANT_BLOCK_SHAPE,
NixlEPPrepareAndFinalize,
)
def maybe_roundup_layer_hidden_size(
@@ -69,6 +74,11 @@ def maybe_roundup_layer_hidden_size(
hidden_size
)
if moe_parallel_config.use_nixl_ep_kernels:
hidden_size = NixlEPPrepareAndFinalize.maybe_roundup_layer_hidden_size(
hidden_size
)
return hidden_size
@@ -209,4 +219,39 @@ def maybe_make_prepare_finalize(
num_dispatchers=all2all_manager.world_size,
)
elif moe.use_nixl_ep_kernels:
assert quant_config is not None
global_to_physical = physical_to_global = local_expert_global_ids = None
if routing_tables is not None:
(
global_to_physical,
physical_to_global,
local_expert_global_ids,
) = routing_tables
all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
num_ep_ranks=all2all_manager.world_size,
num_global_experts=moe.num_experts,
num_local_experts=moe.num_experts // all2all_manager.world_size,
)
handle = all2all_manager.get_handle(all_to_all_args)
# Note: We may want to use FP8 dispatch just to reduce
# data movement.
use_fp8_dispatch = (
quant_config.quant_dtype == current_platform.fp8_dtype()
and quant_config.block_shape == NIXL_EP_QUANT_BLOCK_SHAPE
)
prepare_finalize = NixlEPPrepareAndFinalize(
handle,
max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch,
global_to_physical=global_to_physical,
physical_to_global=physical_to_global,
local_expert_global_ids=local_expert_global_ids,
)
return prepare_finalize

View File

@@ -976,6 +976,10 @@ class FusedMoEParallelConfig:
def use_mori_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "mori"
@property
def use_nixl_ep_kernels(self):
    """Whether all2all kernels are in use and the backend is ``nixl_ep``."""
    all2all_enabled = self.use_all2all_kernels
    if not all2all_enabled:
        # Same result as ``and`` short-circuiting: return the falsy operand.
        return all2all_enabled
    return self.all2all_backend == "nixl_ep"
@staticmethod
def flatten_tp_across_dp_and_pcp(
tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
@@ -1242,3 +1246,7 @@ class FusedMoEConfig:
@property
def use_naive_all2all_kernels(self):
    """Forward the flag of the same name from ``moe_parallel_config``."""
    parallel_cfg = self.moe_parallel_config
    return parallel_cfg.use_naive_all2all_kernels
@property
def use_nixl_ep_kernels(self):
    """Forward the flag of the same name from ``moe_parallel_config``."""
    parallel_cfg = self.moe_parallel_config
    return parallel_cfg.use_nixl_ep_kernels

View File

@@ -177,10 +177,11 @@ def determine_expert_placement_strategy(
if (
moe_parallel_config.use_all2all_kernels
and not moe_parallel_config.use_deepep_ll_kernels
and not moe_parallel_config.use_nixl_ep_kernels
):
logger.warning(
"Round-robin expert placement currently only supports "
"the DeepEP low-latency backend, but '%s' was configured. "
"the DeepEP low-latency or NIXL EP backend, but '%s' was configured. "
"Falling back to linear expert placement.",
moe_parallel_config.all2all_backend,
)
@@ -745,10 +746,10 @@ class FusedMoE(CustomOp):
self,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
# Currently routing_tables only needed for round-robin expert placement
# with DeepEP-ll all2all backend.
if (
self.expert_placement_strategy != "round_robin"
or not self.moe_parallel_config.use_deepep_ll_kernels
# with DeepEP-ll or NIXL EP all2all backends.
if self.expert_placement_strategy != "round_robin" or (
not self.moe_parallel_config.use_deepep_ll_kernels
and not self.moe_parallel_config.use_nixl_ep_kernels
):
return None

View File

@@ -0,0 +1,406 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import nixl_ep
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input,
normalize_batched_scales_shape,
)
from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id,
dbo_enabled,
dbo_maybe_run_recv_hook,
)
logger = init_logger(__name__)
# NIXL EP kernels quantize dispatch inputs in 128 element chunks.
NIXL_EP_QUANT_BLOCK_SIZE = 128
NIXL_EP_QUANT_BLOCK_SHAPE = [NIXL_EP_QUANT_BLOCK_SIZE, NIXL_EP_QUANT_BLOCK_SIZE]


def dequant_fp8(
    expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor
) -> torch.Tensor:
    """
    Return dequantized tensor in fp32

    The values are regrouped into rows of ``NIXL_EP_QUANT_BLOCK_SIZE`` elements
    per expert, each row is multiplied by its own scale, and the result is
    viewed back to the shape of ``expert_x_fp8``.
    """
    assert expert_x_fp8.is_contiguous()
    num_experts = expert_x_fp8.size(0)

    # One row per quantization block, one scale per row.
    blocks_fp32 = expert_x_fp8.float().view(
        num_experts, -1, NIXL_EP_QUANT_BLOCK_SIZE
    )
    block_scales = expert_x_scales.contiguous().view(num_experts, -1, 1)

    return torch.mul(blocks_fp32, block_scales).view_as(expert_x_fp8)
class NixlEPPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
    """
    Prepare/Finalize using NIXL EP kernels.

    ``prepare``/``prepare_async`` send tokens through ``buffer.dispatch`` and
    store the handle it returns; ``finalize``/``finalize_async`` pass that
    handle to ``buffer.combine`` together with ``out=output``.
    """

    # NIXL EP kernels are compiled only for certain specific hidden sizes.
    # NOTE: Keep this list sorted, maybe_roundup_layer_hidden_size depends
    # on it.
    SUPPORTED_HIDDEN_SIZES = [2048, 2560, 3072, 4096, 5120, 6144, 7168, 8192]
    assert sorted(set(SUPPORTED_HIDDEN_SIZES)) == SUPPORTED_HIDDEN_SIZES

    @staticmethod
    def maybe_roundup_layer_hidden_size(hidden_size: int) -> int:
        """Return the smallest supported hidden size that is >= ``hidden_size``.

        Raises:
            ValueError: If ``hidden_size`` exceeds the largest supported size.
        """
        # Round up hidden size to the closest supported hidden size.
        _supported_hs = NixlEPPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES
        for x in _supported_hs:
            if x >= hidden_size:
                return x
        raise ValueError(
            f"Hidden Size {hidden_size} is greater than the "
            f"maximum supported hidden size {_supported_hs[-1]}"
        )

    def __init__(
        self,
        buffer: nixl_ep.Buffer,
        max_tokens_per_rank: int,
        num_dispatchers: int,
        use_fp8_dispatch: bool = False,
        global_to_physical: torch.Tensor | None = None,
        physical_to_global: torch.Tensor | None = None,
        local_expert_global_ids: torch.Tensor | None = None,
    ):
        """Store the NIXL EP buffer and dispatch settings.

        The three optional routing tensors are cast to
        ``topk_indices_dtype()`` when provided; ``None`` is kept as ``None``.
        """
        super().__init__()

        self.buffer = buffer
        self.max_tokens_per_rank = max_tokens_per_rank
        self.use_fp8_dispatch = use_fp8_dispatch
        # The dispatch function returns a handle that the combine function
        # requires. We store the handle here so it is available to the
        # combine function. Slots are indexed by dbo_current_ubatch_id().
        self.handles: list[tuple | None] = [None, None]
        self.num_dispatchers_ = num_dispatchers

        topk_indices_dtype = self.topk_indices_dtype()

        def _maybe_cast(tensor: torch.Tensor | None) -> torch.Tensor | None:
            if tensor is None or topk_indices_dtype is None:
                return tensor
            return tensor.to(dtype=topk_indices_dtype)

        self.global_to_physical = _maybe_cast(global_to_physical)
        self.physical_to_global = _maybe_cast(physical_to_global)
        self.local_expert_global_ids = _maybe_cast(local_expert_global_ids)

        # We don't have enough information to determine if we should dispatch
        # activation scales in a packed ue8m0 format during object construction
        # time. This setting is handled by post_init_setup.
        self.use_ue8m0_dispatch = False

    def post_init_setup(self, fused_experts: mk.FusedMoEExperts):
        """Enable packed ue8m0 scale dispatch when the experts support it.

        Only takes effect with fp8 dispatch; otherwise a warning is logged
        that raw/unquantized activations are being dispatched.
        """
        if not fused_experts.supports_packed_ue8m0_act_scales():
            # Early exit.
            return

        if self.use_fp8_dispatch:
            logger.debug_once(
                "Update NixlEPPrepareAndFinalize to do packed ue8m0 scales dispatch."
            )
            self.use_ue8m0_dispatch = True
        else:
            logger.warning_once(
                "NixlEPPrepareAndFinalize is setup to dispatch raw/unquantized "
                f"activations despite ({fused_experts.__class__.__name__}) being able "
                "to support quantized activations.",
                scope="local",
            )

    def num_dispatchers(self) -> int:
        return self.num_dispatchers_

    def output_is_reduced(self) -> bool:
        return True

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.BatchedExperts

    def max_num_tokens_per_rank(self) -> int | None:
        return self.max_tokens_per_rank

    def topk_indices_dtype(self) -> torch.dtype | None:
        return torch.int64

    def _map_global_to_physical_ids(self, topk_ids: torch.Tensor) -> torch.Tensor:
        """Index ``global_to_physical`` with ``topk_ids``; identity if unset."""
        if self.global_to_physical is None:
            return topk_ids
        return self.global_to_physical[topk_ids]

    def _map_local_to_global_ids(self, expert_topk_ids: torch.Tensor) -> torch.Tensor:
        """Index ``local_expert_global_ids``; identity if unset."""
        if self.local_expert_global_ids is None:
            return expert_topk_ids
        return self.local_expert_global_ids[expert_topk_ids]

    def _do_quant(
        self,
        x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        a1_dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Turn the dispatched activations into ``(x, x_scales)``.

        With fp8 dispatch, ``x`` is an ``(fp8, scales)`` pair. It is returned
        untouched when ``quant_config.block_shape[1]`` equals
        ``NIXL_EP_QUANT_BLOCK_SIZE``; otherwise it is dequantized to
        ``a1_dtype`` and then quantized via ``moe_kernel_quantize_input``.
        """
        if self.use_fp8_dispatch:
            block_k = (
                quant_config.block_shape[1]
                if quant_config.block_shape is not None
                else None
            )
            if block_k == NIXL_EP_QUANT_BLOCK_SIZE:
                # NIXL EP kernels did the quantization for us.
                x, x_scales = x
                return x, x_scales

            # Dequant to get back the tokens in the datatype we dispatched in.
            x_fp8, x_scales = x
            x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)

        assert isinstance(x, torch.Tensor)

        # 3-way unpacking also enforces that x is rank 3.
        num_experts, _, hidden_dim = x.size()

        # Flatten the leading two dims so quantization sees a 2D input.
        x = x.view((-1, hidden_dim))
        q_dtype = quant_config.quant_dtype

        if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
            logger.info_once(
                "Skip quantization when using FlashInfer CUTEDSL(masked_gemm) "
                "for ModelOptNvFp4FusedMoE."
            )
            q_dtype = None

        x, x_scales = moe_kernel_quantize_input(
            x,
            quant_config.a1_scale,
            q_dtype,
            quant_config.per_act_token_quant,
            quant_config.block_shape,
        )
        x = x.view((num_experts, -1, hidden_dim))

        if q_dtype is not None:
            assert x_scales is not None
            x_scales = normalize_batched_scales_shape(x_scales, num_experts)

        return x, x_scales

    def supports_async(self) -> bool:
        return True

    def prepare_async(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool = False,
    ) -> tuple[Callable, mk.ReceiverType]:
        """Start the dispatch and return ``(hook, receiver)``.

        ``hook`` is the recv hook returned by ``buffer.dispatch``; ``receiver``
        is a callable that runs ``_receiver`` on the dispatched tensors.

        Raises:
            NotImplementedError: If ``defer_input_quant`` is True.
        """
        if defer_input_quant:
            raise NotImplementedError(
                f"{self.__class__.__name__} does not support defer_input_quant=True. "
                "Please select an MoE kernel that accepts quantized inputs."
            )

        hidden_size = a1.size(1)
        assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, (
            f"Hidden Size {hidden_size} not in supported list of hidden sizes "
            f"{self.SUPPORTED_HIDDEN_SIZES}"
        )

        a2a_idx = dbo_current_ubatch_id()

        if self.use_fp8_dispatch:
            assert hidden_size % NIXL_EP_QUANT_BLOCK_SIZE == 0, (
                "NIXL EP kernels quantize the inputs in blocks of shape 128"
            )

        # Inspect a1_scale if present, otherwise fall back to a2_scale. A scale
        # with more than one element is treated as per-token.
        scale = (
            quant_config.a1_scale
            if quant_config.a1_scale is not None
            else quant_config.a2_scale
        )
        has_per_token_scales = scale is not None and scale.numel() != 1
        assert not has_per_token_scales, (
            "NIXL EP kernels don't support dispatching per-token scales"
        )

        if apply_router_weight_on_input:
            topk = topk_ids.size(1)
            # TODO: this only works for topK=1, will need to update for topK>1
            assert topk == 1, (
                "apply_router_weight_on_input is only implemented for topk=1"
            )
            a1 = a1 * topk_weights.to(a1.dtype)

        # Dispatch
        dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
        expert_x, expert_num_tokens, handle, _, hook = self.buffer.dispatch(
            a1,
            dispatch_topk_ids,
            self.max_tokens_per_rank,
            num_experts,
            use_fp8=self.use_fp8_dispatch,
            # round_scale needs to be set to dispatch in ue8m0
            round_scale=self.use_ue8m0_dispatch,
            use_ue8m0=self.use_ue8m0_dispatch,
            async_finish=False,
            return_recv_hook=True,
        )
        # Remembered for the matching combine call in _finalize.
        self.handles[a2a_idx] = handle

        return (
            hook,
            lambda: self._receiver(
                expert_x,
                expert_num_tokens,
                quant_config.a1_scale,
                a1.dtype,
                quant_config,
            ),
        )

    def _receiver(
        self,
        expert_x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        expert_num_tokens: torch.Tensor,
        a1_scale: torch.Tensor | None,
        a1_dtype: torch.dtype,
        quant_config: FusedMoEQuantConfig,
    ) -> mk.PrepareResultType:
        """Quantize the dispatched activations and build the prepare result.

        ``a1_scale`` is accepted but not read here; ``_do_quant`` takes the
        scale from ``quant_config``.
        """
        expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config)

        expert_tokens_meta = mk.ExpertTokensMetadata(
            expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
        )

        return expert_x, expert_x_scale, expert_tokens_meta, None, None

    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool = False,
    ) -> mk.PrepareResultType:
        """Synchronous prepare: run ``prepare_async``, then its hook and receiver.

        Raises:
            NotImplementedError: If ``defer_input_quant`` is True.
        """
        if defer_input_quant:
            raise NotImplementedError(
                f"{self.__class__.__name__} does not support defer_input_quant=True. "
                "Please select an MoE kernel that accepts quantized inputs."
            )
        hook, receiver = self.prepare_async(
            a1,
            topk_weights,
            topk_ids,
            num_experts,
            expert_map,
            apply_router_weight_on_input,
            quant_config,
        )
        hook()
        return receiver()

    def _finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
        do_async: bool,
    ) -> tuple[Callable, Callable]:
        """Issue ``buffer.combine`` using the handle stored by the dispatch.

        Returns:
            ``(recv_hook, noop)`` where ``recv_hook`` is whatever
            ``buffer.combine`` returned for its recv hook and ``noop`` is a
            callable that does nothing.
        """
        assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), (
            "Weight application and reduction happens in the combine kernel."
        )

        a2a_idx = dbo_current_ubatch_id()
        do_recv_hook = dbo_enabled() or do_async
        handle = self.handles[a2a_idx]
        assert handle is not None

        combine_topk_weights = topk_weights
        if apply_router_weight_on_input:
            # weights have already been applied.
            combine_topk_weights = torch.ones_like(topk_weights)

        combine_topk_ids = self._map_global_to_physical_ids(topk_ids)
        # TODO (varun) : Enable zero copy mode
        dbo_maybe_run_recv_hook()
        _, _, recv_hook = self.buffer.combine(
            fused_expert_output,
            combine_topk_ids,
            combine_topk_weights,
            handle,
            async_finish=False,
            zero_copy=False,
            return_recv_hook=do_recv_hook,
            out=output,
        )

        return recv_hook, lambda: None

    def finalize_async(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> tuple[Callable, Callable]:
        """Asynchronous finalize: ``_finalize`` with ``do_async=True``."""
        return self._finalize(
            output,
            fused_expert_output,
            topk_weights,
            topk_ids,
            apply_router_weight_on_input,
            weight_and_reduce_impl,
            do_async=True,
        )

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
        """Synchronous finalize: ``_finalize`` with ``do_async=False``."""
        self._finalize(
            output,
            fused_expert_output,
            topk_weights,
            topk_ids,
            apply_router_weight_on_input,
            weight_and_reduce_impl,
            do_async=False,
        )

View File

@@ -234,6 +234,7 @@ class DefaultMoERunner(MoERunner):
self.moe_config.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.moe_parallel_config.use_mori_kernels
or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
or self.moe_config.moe_parallel_config.use_nixl_ep_kernels
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
def _maybe_setup_shared_experts_stream(

View File

@@ -896,7 +896,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
# batched activation format. As self.fused_experts is not
# initialized at this point, we resort to checking the MoE config
# directly.
is_batched_moe = self.moe.use_deepep_ll_kernels
is_batched_moe = (
self.moe.use_deepep_ll_kernels or self.moe.use_nixl_ep_kernels
)
if is_batched_moe:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:

View File

@@ -412,6 +412,11 @@ def has_deep_gemm() -> bool:
return _has_module("deep_gemm")
def has_nixl_ep() -> bool:
    """Report whether the optional `nixl_ep` package is available."""
    nixl_ep_present = _has_module("nixl_ep")
    return nixl_ep_present
def has_triton_kernels() -> bool:
"""Whether the optional `triton_kernels` package is available."""
is_available = _has_module("triton_kernels") or _has_module(

View File

@@ -288,6 +288,7 @@ def make_zmq_socket(
bind: bool | None = None,
identity: bytes | None = None,
linger: int | None = None,
router_handover: bool = False,
) -> zmq.Socket | zmq.asyncio.Socket: # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""
@@ -314,6 +315,10 @@ def make_zmq_socket(
socket.setsockopt(zmq.SNDHWM, 0)
socket.setsockopt(zmq.SNDBUF, buf_size)
if socket_type == zmq.ROUTER and router_handover:
# Let a new connection take over an identity left behind by a dead one.
socket.setsockopt(zmq.ROUTER_HANDOVER, 1)
if identity is not None:
socket.setsockopt(zmq.IDENTITY, identity)
@@ -344,12 +349,20 @@ def zmq_socket_ctx(
bind: bool | None = None,
linger: int = 0,
identity: bytes | None = None,
router_handover: bool = False,
) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""
ctx = zmq.Context() # type: ignore[attr-defined]
try:
yield make_zmq_socket(ctx, path, socket_type, bind=bind, identity=identity)
yield make_zmq_socket(
ctx,
path,
socket_type,
bind=bind,
identity=identity,
router_handover=router_handover,
)
except KeyboardInterrupt:
logger.debug("Got Keyboard Interrupt.")

View File

@@ -544,6 +544,11 @@ class MPClient(EngineCoreClient):
try:
# State used for data parallel.
self.engines_running = False
parallel_config = vllm_config.parallel_config
# Elastic EP can remove a rank and later add it back with the same
# identity. The client input ROUTER needs handover to allow the new
# engine to replace the dead connection.
enable_input_socket_handover = parallel_config.enable_elastic_ep
self.stats_update_address: str | None = None
if client_addresses:
@@ -552,7 +557,11 @@ class MPClient(EngineCoreClient):
output_address = client_addresses["output_address"]
self.stats_update_address = client_addresses.get("stats_update_address")
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, input_address, zmq.ROUTER, bind=True
self.ctx,
input_address,
zmq.ROUTER,
bind=True,
router_handover=enable_input_socket_handover,
)
self.resources.output_socket = make_zmq_socket(
self.ctx, output_address, zmq.PULL
@@ -561,7 +570,11 @@ class MPClient(EngineCoreClient):
# Engines are managed by this client.
addresses = get_engine_zmq_addresses(vllm_config)
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, addresses.inputs[0], zmq.ROUTER, bind=True
self.ctx,
addresses.inputs[0],
zmq.ROUTER,
bind=True,
router_handover=enable_input_socket_handover,
)
self.resources.output_socket = make_zmq_socket(
self.ctx, addresses.outputs[0], zmq.PULL
@@ -582,7 +595,6 @@ class MPClient(EngineCoreClient):
coordinator.get_stats_publish_address()
)
parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_index
dp_local_size = parallel_config.data_parallel_size_local