diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py
index d711b9246..5e08ae35f 100644
--- a/tests/v1/engine/test_engine_core_client.py
+++ b/tests/v1/engine/test_engine_core_client.py
@@ -150,6 +150,7 @@ def test_mp_client_uses_env_timeout(monkeypatch: pytest.MonkeyPatch):
         data_parallel_hybrid_lb=False,
         data_parallel_external_lb=False,
         local_engines_only=False,
+        enable_elastic_ep=False,
     )
     vllm_config = SimpleNamespace(parallel_config=parallel_config)
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index 10a9cd9a5..fcad56133 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -43,6 +43,7 @@ All2AllBackend = Literal[
     "deepep_high_throughput",
     "deepep_low_latency",
     "mori",
+    "nixl_ep",
     "allgather_reducescatter",
     "flashinfer_all2allv",
 ]
@@ -156,6 +157,7 @@ class ParallelConfig:
     - "deepep_high_throughput": Use deepep high-throughput kernels\n
     - "deepep_low_latency": Use deepep low-latency kernels\n
     - "mori": Use mori kernels\n
+    - "nixl_ep": Use nixl-ep kernels\n
     - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""

     max_parallel_loading_workers: int | None = None
@@ -580,6 +582,7 @@ class ParallelConfig:
                 "deepep_high_throughput",
                 "deepep_low_latency",
                 "mori",
+                "nixl_ep",
             )
             and self.enable_expert_parallel
             and self.tensor_parallel_size > 1
diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py
index 97c5faad6..de5c5a79c 100644
--- a/vllm/distributed/device_communicators/all2all.py
+++ b/vllm/distributed/device_communicators/all2all.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import threading
 from typing import Any

 import torch
@@ -413,6 +414,121 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
         return 0


+class NixlEPAll2AllManager(All2AllManagerBase):
+    """
+    All2All communication based on NIXL EP kernels.
+    This backend supports elastic EP with dynamic rank connection/disconnection.
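+
+    The NIXL buffer is created once per process and cached at class level
+    (_buffer below); manager instances recreated during scale-up/down reuse it
+    and only connect or disconnect the ranks that changed.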
+ """ + + # (nixl_ep_buffer, ep_size) + _buffer: tuple[Any, int] | None = None + _lock = threading.Lock() + + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) + + self.max_num_ep_ranks = envs.VLLM_NIXL_EP_MAX_NUM_RANKS + + def _init_buffer( + self, + max_num_tokens_per_dp_rank: int, + token_hidden_size: int, + num_experts_per_rank: int, + ) -> None: + from nixl_ep import Buffer # type: ignore[import-not-found] + + max_num_global_experts = self.max_num_ep_ranks * num_experts_per_rank + num_rdma_bytes = Buffer.get_rdma_size_hint( + num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank, + hidden=token_hidden_size, + num_ranks=self.max_num_ep_ranks, + num_experts=max_num_global_experts, + ) + assert NixlEPAll2AllManager._buffer is None, ( + "NIXL EP buffer already initialized" + ) + buffer = Buffer( + rank=self.rank, + tcp_store_group=self.tcp_store_group.store, + ) + buffer.update_memory_buffers( + num_ranks=self.max_num_ep_ranks, + num_experts_per_rank=num_experts_per_rank, + num_rdma_bytes=num_rdma_bytes, + ) + ranks_to_connect = list(range(self.cpu_group.size())) + buffer.connect_ranks(ranks_to_connect) + NixlEPAll2AllManager._buffer = (buffer, self.cpu_group.size()) + + def _update_buffer(self): + assert NixlEPAll2AllManager._buffer is not None + buffer, current_ep_size = NixlEPAll2AllManager._buffer + current_ranks = list(range(current_ep_size)) + new_ep_size = self.cpu_group.size() + buffer.set_tcp_store_group(self.tcp_store_group.store) + if new_ep_size > len(current_ranks): + ranks_to_connect = list(range(len(current_ranks), new_ep_size)) + buffer.connect_ranks(ranks_to_connect) + else: + ranks_to_disconnect = current_ranks[new_ep_size:] + buffer.disconnect_ranks(ranks_to_disconnect) + NixlEPAll2AllManager._buffer = (buffer, new_ep_size) + + def get_handle(self, kwargs): + with NixlEPAll2AllManager._lock: + if ( + NixlEPAll2AllManager._buffer is not None + and NixlEPAll2AllManager._buffer[1] == self.cpu_group.size() + ): + return NixlEPAll2AllManager._buffer[0] + + num_experts_per_rank = ( + kwargs["num_global_experts"] // kwargs["num_ep_ranks"] + ) + nixl_kwargs = dict( + max_num_tokens_per_dp_rank=kwargs["max_num_tokens_per_dp_rank"], + token_hidden_size=kwargs["token_hidden_size"], + num_experts_per_rank=num_experts_per_rank, + ) + if NixlEPAll2AllManager._buffer is None: + self._init_buffer(**nixl_kwargs) + else: + self._update_buffer() + + assert NixlEPAll2AllManager._buffer is not None + handle = NixlEPAll2AllManager._buffer[0] + return handle + + def dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): + raise NotImplementedError + + def combine( + self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False + ) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + # NOTE(yongji): NIXLEPAll2AllManager instance is recreated during + # scale-up/down, so we cannot destroy the persistent buffer here. 
+        assert NixlEPAll2AllManager._buffer is not None
+        buffer = NixlEPAll2AllManager._buffer[0]
+        buffer.set_tcp_store_group(None)
+
+    # NIXL EP uses RDMA so no SMs are used for communication
+    def max_sms_used(self) -> int | None:
+        return 0
+
+
 class FlashInferAllToAllManager(All2AllManagerBase):
     """
     All2All communication based on flashinfer kernels.
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 5e18dbde9..faa3d093a 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -143,6 +143,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
             from .all2all import MoriAll2AllManager

             self.all2all_manager = MoriAll2AllManager(self.cpu_group)
+        elif self.all2all_backend == "nixl_ep":
+            from .all2all import NixlEPAll2AllManager
+
+            self.all2all_manager = NixlEPAll2AllManager(
+                self.cpu_group, tcp_store_group
+            )
         elif self.all2all_backend == "flashinfer_all2allv":
             from .all2all import FlashInferAllToAllManager

diff --git a/vllm/envs.py b/vllm/envs.py
index 3b7312a4f..d310e9e13 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -244,6 +244,7 @@ if TYPE_CHECKING:
     VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False
     VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False
     VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False
+    VLLM_NIXL_EP_MAX_NUM_RANKS: int = 32


 def get_default_cache_root():
@@ -1628,6 +1629,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS": lambda: bool(
        int(os.getenv("VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS", "0"))
     ),
+    # NIXL EP environment variables
+    "VLLM_NIXL_EP_MAX_NUM_RANKS": lambda: int(
+        os.getenv("VLLM_NIXL_EP_MAX_NUM_RANKS", "32")
+    ),
 }
diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py
index 47ca95ee5..4d215645e 100644
--- a/vllm/model_executor/layers/fused_moe/all2all_utils.py
+++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     make_moe_prepare_and_finalize_no_dp_ep,
 )
 from vllm.platforms import current_platform
-from vllm.utils.import_utils import has_deep_ep, has_mori
+from vllm.utils.import_utils import has_deep_ep, has_mori, has_nixl_ep

 logger = init_logger(__name__)

@@ -38,6 +38,11 @@ if current_platform.is_cuda_alike():
         )
     if has_mori():
         from .mori_prepare_finalize import MoriPrepareAndFinalize
+    if has_nixl_ep():
+        from .nixl_ep_prepare_finalize import (
+            NIXL_EP_QUANT_BLOCK_SHAPE,
+            NixlEPPrepareAndFinalize,
+        )


 def maybe_roundup_layer_hidden_size(
@@ -69,6 +74,11 @@ def maybe_roundup_layer_hidden_size(
             hidden_size
         )

+    if moe_parallel_config.use_nixl_ep_kernels:
+        hidden_size = NixlEPPrepareAndFinalize.maybe_roundup_layer_hidden_size(
+            hidden_size
+        )
+
     return hidden_size


@@ -209,4 +219,39 @@ def maybe_make_prepare_finalize(
             num_dispatchers=all2all_manager.world_size,
         )

+    elif moe.use_nixl_ep_kernels:
+        assert quant_config is not None
+        global_to_physical = physical_to_global = local_expert_global_ids = None
+        if routing_tables is not None:
+            (
+                global_to_physical,
+                physical_to_global,
+                local_expert_global_ids,
+            ) = routing_tables
+        all_to_all_args = dict(
+            max_num_tokens_per_dp_rank=moe.max_num_tokens,
+            token_hidden_size=moe.hidden_dim,
+            num_ep_ranks=all2all_manager.world_size,
+            num_global_experts=moe.num_experts,
+            num_local_experts=moe.num_experts // all2all_manager.world_size,
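+            # These sizes are read by NixlEPAll2AllManager.get_handle() when it
+            # creates the shared NIXL buffer.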
+        )
+        handle = all2all_manager.get_handle(all_to_all_args)
+
+        # Note: We may want to use FP8 dispatch just to reduce
+        # data movement.
+        use_fp8_dispatch = (
+            quant_config.quant_dtype == current_platform.fp8_dtype()
+            and quant_config.block_shape == NIXL_EP_QUANT_BLOCK_SHAPE
+        )
+
+        prepare_finalize = NixlEPPrepareAndFinalize(
+            handle,
+            max_tokens_per_rank=moe.max_num_tokens,
+            num_dispatchers=all2all_manager.world_size,
+            use_fp8_dispatch=use_fp8_dispatch,
+            global_to_physical=global_to_physical,
+            physical_to_global=physical_to_global,
+            local_expert_global_ids=local_expert_global_ids,
+        )
+
     return prepare_finalize
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index e0ed9130c..57c787ca6 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -976,6 +976,10 @@ class FusedMoEParallelConfig:
     def use_mori_kernels(self):
         return self.use_all2all_kernels and self.all2all_backend == "mori"

+    @property
+    def use_nixl_ep_kernels(self):
+        return self.use_all2all_kernels and self.all2all_backend == "nixl_ep"
+
     @staticmethod
     def flatten_tp_across_dp_and_pcp(
         tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
@@ -1242,3 +1246,7 @@ class FusedMoEConfig:
     @property
     def use_naive_all2all_kernels(self):
         return self.moe_parallel_config.use_naive_all2all_kernels
+
+    @property
+    def use_nixl_ep_kernels(self):
+        return self.moe_parallel_config.use_nixl_ep_kernels
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 92b0f0e0d..6b35c18dc 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -177,10 +177,11 @@ def determine_expert_placement_strategy(
     if (
         moe_parallel_config.use_all2all_kernels
         and not moe_parallel_config.use_deepep_ll_kernels
+        and not moe_parallel_config.use_nixl_ep_kernels
     ):
         logger.warning(
             "Round-robin expert placement currently only supports "
-            "the DeepEP low-latency backend, but '%s' was configured. "
+            "the DeepEP low-latency or NIXL EP backend, but '%s' was configured. "
             "Falling back to linear expert placement.",
             moe_parallel_config.all2all_backend,
         )
@@ -745,10 +746,10 @@ class FusedMoE(CustomOp):
         self,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
         # Currently routing_tables only needed for round-robin expert placement
-        # with DeepEP-ll all2all backend.
-        if (
-            self.expert_placement_strategy != "round_robin"
-            or not self.moe_parallel_config.use_deepep_ll_kernels
+        # with DeepEP-ll or NIXL EP all2all backends.
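+        # (Other backends fall back to linear placement, see
+        # determine_expert_placement_strategy.)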
+        if self.expert_placement_strategy != "round_robin" or (
+            not self.moe_parallel_config.use_deepep_ll_kernels
+            and not self.moe_parallel_config.use_nixl_ep_kernels
         ):
             return None
diff --git a/vllm/model_executor/layers/fused_moe/nixl_ep_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/nixl_ep_prepare_finalize.py
new file mode 100644
index 000000000..dbc54e2c9
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/nixl_ep_prepare_finalize.py
@@ -0,0 +1,406 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Callable
+
+import nixl_ep
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import envs
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate,
+)
+from vllm.model_executor.layers.fused_moe.utils import (
+    moe_kernel_quantize_input,
+    normalize_batched_scales_shape,
+)
+from vllm.v1.worker.ubatching import (
+    dbo_current_ubatch_id,
+    dbo_enabled,
+    dbo_maybe_run_recv_hook,
+)
+
+logger = init_logger(__name__)
+
+# NIXL EP kernels quantize dispatch inputs in 128 element chunks.
+NIXL_EP_QUANT_BLOCK_SIZE = 128
+NIXL_EP_QUANT_BLOCK_SHAPE = [NIXL_EP_QUANT_BLOCK_SIZE, NIXL_EP_QUANT_BLOCK_SIZE]
+
+
+def dequant_fp8(
+    expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor
+) -> torch.Tensor:
+    """
+    Return dequantized tensor in fp32
+    """
+    assert expert_x_fp8.is_contiguous()
+    expert_x_scales = expert_x_scales.contiguous()
+    num_experts = expert_x_fp8.size(0)
+
+    expert_x_fp32 = expert_x_fp8.to(torch.float32).view(
+        num_experts, -1, NIXL_EP_QUANT_BLOCK_SIZE
+    )
+    expert_x_scales = expert_x_scales.view(num_experts, -1, 1)
+    return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size())
+
+
+class NixlEPPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
+    """
+    Prepare/Finalize using NIXL EP kernels.
+    """
+
+    # NIXL EP kernels are compiled only for certain specific hidden sizes.
+    # NOTE: Keep this list sorted, maybe_roundup_layer_hidden_size depends
+    # on it.
+    SUPPORTED_HIDDEN_SIZES = [2048, 2560, 3072, 4096, 5120, 6144, 7168, 8192]
+    assert sorted(set(SUPPORTED_HIDDEN_SIZES)) == SUPPORTED_HIDDEN_SIZES
+
+    @staticmethod
+    def maybe_roundup_layer_hidden_size(hidden_size: int) -> int:
+        # Round up hidden size to the closest supported hidden size.
+        _supported_hs = NixlEPPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES
+
+        for x in _supported_hs:
+            if x >= hidden_size:
+                return x
+
+        raise ValueError(
+            f"Hidden Size {hidden_size} is greater than the "
+            f"maximum supported hidden size {_supported_hs[-1]}"
+        )
+
+    def __init__(
+        self,
+        buffer: nixl_ep.Buffer,
+        max_tokens_per_rank: int,
+        num_dispatchers: int,
+        use_fp8_dispatch: bool = False,
+        global_to_physical: torch.Tensor | None = None,
+        physical_to_global: torch.Tensor | None = None,
+        local_expert_global_ids: torch.Tensor | None = None,
+    ):
+        super().__init__()
+
+        self.buffer = buffer
+        self.max_tokens_per_rank = max_tokens_per_rank
+        self.use_fp8_dispatch = use_fp8_dispatch
+        # The dispatch function returns a handle that the combine function
+        # requires. We store the handle here so it is available to the
+        # combine function.
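+        # Two slots, one per DBO micro-batch; dispatch and combine index into
+        # them with dbo_current_ubatch_id().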
+        self.handles: list[tuple | None] = [None, None]
+        self.num_dispatchers_ = num_dispatchers
+
+        topk_indices_dtype = self.topk_indices_dtype()
+
+        def _maybe_cast(tensor: torch.Tensor | None) -> torch.Tensor | None:
+            if tensor is None or topk_indices_dtype is None:
+                return tensor
+            return tensor.to(dtype=topk_indices_dtype)
+
+        self.global_to_physical = _maybe_cast(global_to_physical)
+        self.physical_to_global = _maybe_cast(physical_to_global)
+        self.local_expert_global_ids = _maybe_cast(local_expert_global_ids)
+
+        # We don't have enough information to determine if we should dispatch
+        # activation scales in a packed ue8m0 format during object construction
+        # time. This setting is handled by post_init_setup.
+        self.use_ue8m0_dispatch = False
+
+    def post_init_setup(self, fused_experts: mk.FusedMoEExperts):
+        if not fused_experts.supports_packed_ue8m0_act_scales():
+            # Early exit.
+            return
+
+        if self.use_fp8_dispatch:
+            logger.debug_once(
+                "Update NixlEPPrepareAndFinalize to do packed ue8m0 scales dispatch."
+            )
+            self.use_ue8m0_dispatch = True
+        else:
+            logger.warning_once(
+                "NixlEPPrepareAndFinalize is set up to dispatch raw/unquantized "
+                f"activations despite ({fused_experts.__class__.__name__}) being able "
+                "to support quantized activations.",
+                scope="local",
+            )
+
+    def num_dispatchers(self) -> int:
+        return self.num_dispatchers_
+
+    def output_is_reduced(self) -> bool:
+        return True
+
+    @property
+    def activation_format(self) -> mk.FusedMoEActivationFormat:
+        return mk.FusedMoEActivationFormat.BatchedExperts
+
+    def max_num_tokens_per_rank(self) -> int | None:
+        return self.max_tokens_per_rank
+
+    def topk_indices_dtype(self) -> torch.dtype | None:
+        return torch.int64
+
+    def _map_global_to_physical_ids(self, topk_ids: torch.Tensor) -> torch.Tensor:
+        if self.global_to_physical is None:
+            return topk_ids
+        return self.global_to_physical[topk_ids]
+
+    def _map_local_to_global_ids(self, expert_topk_ids: torch.Tensor) -> torch.Tensor:
+        if self.local_expert_global_ids is None:
+            return expert_topk_ids
+        return self.local_expert_global_ids[expert_topk_ids]
+
+    def _do_quant(
+        self,
+        x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
+        a1_dtype: torch.dtype,
+        quant_config: FusedMoEQuantConfig,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        if self.use_fp8_dispatch:
+            block_k = (
+                quant_config.block_shape[1]
+                if quant_config.block_shape is not None
+                else None
+            )
+            if block_k == NIXL_EP_QUANT_BLOCK_SIZE:
+                # NIXL EP kernels did the quantization for us.
+                x, x_scales = x
+                return x, x_scales
+
+            # Dequant to get back the tokens in the datatype we dispatched in.
+            x_fp8, x_scales = x
+            x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)
+
+        assert isinstance(x, torch.Tensor)
+
+        num_experts, max_tokens, hidden_dim = x.size()
+
+        x = x.view((-1, hidden_dim))
+        q_dtype = quant_config.quant_dtype
+
+        if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
+            logger.info_once(
+                "Skip quantization when using FlashInfer CUTEDSL(masked_gemm) "
+                "for ModelOptNvFp4FusedMoE."
+            )
+            q_dtype = None
+
+        x, x_scales = moe_kernel_quantize_input(
+            x,
+            quant_config.a1_scale,
+            q_dtype,
+            quant_config.per_act_token_quant,
+            quant_config.block_shape,
+        )
+        x = x.view((num_experts, -1, hidden_dim))
+
+        if q_dtype is not None:
+            assert x_scales is not None
+            x_scales = normalize_batched_scales_shape(x_scales, num_experts)
+
+        return x, x_scales
+
+    def supports_async(self) -> bool:
+        return True
+
+    def prepare_async(
+        self,
+        a1: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: torch.Tensor | None,
+        apply_router_weight_on_input: bool,
+        quant_config: FusedMoEQuantConfig,
+        defer_input_quant: bool = False,
+    ) -> tuple[Callable, mk.ReceiverType]:
+        if defer_input_quant:
+            raise NotImplementedError(
+                f"{self.__class__.__name__} does not support defer_input_quant=True. "
+                "Please select an MoE kernel that accepts quantized inputs."
+            )
+
+        hidden_size = a1.size(1)
+        assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, (
+            f"Hidden Size {hidden_size} not in supported list of hidden sizes "
+            f"{self.SUPPORTED_HIDDEN_SIZES}"
+        )
+
+        a2a_idx = dbo_current_ubatch_id()
+
+        if self.use_fp8_dispatch:
+            assert hidden_size % 128 == 0, (
+                "NIXL EP kernels quantize the inputs in blocks of shape 128"
+            )
+
+        has_per_token_scales = (
+            quant_config.a1_scale.numel() != 1
+            if quant_config.a1_scale is not None
+            else (
+                quant_config.a2_scale.numel() != 1
+                if quant_config.a2_scale is not None
+                else False
+            )
+        )
+        assert not has_per_token_scales, (
+            "NIXL EP kernels don't support dispatching per-token scales"
+        )
+
+        if apply_router_weight_on_input:
+            topk = topk_ids.size(1)
+            # TODO: this only works for topK=1, will need to update for topK>1
+            assert topk == 1, (
+                "apply_router_weight_on_input is only implemented for topk=1"
+            )
+            a1 = a1 * topk_weights.to(a1.dtype)
+
+        # Dispatch
+        dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
+        expert_x, expert_num_tokens, handle, _, hook = self.buffer.dispatch(
+            a1,
+            dispatch_topk_ids,
+            self.max_tokens_per_rank,
+            num_experts,
+            use_fp8=self.use_fp8_dispatch,
+            # round_scale needs to be set to dispatch in ue8m0
+            round_scale=self.use_ue8m0_dispatch,
+            use_ue8m0=self.use_ue8m0_dispatch,
+            async_finish=False,
+            return_recv_hook=True,
+        )
+        self.handles[a2a_idx] = handle
+
+        return (
+            hook,
+            lambda: self._receiver(
+                expert_x,
+                expert_num_tokens,
+                quant_config.a1_scale,
+                a1.dtype,
+                quant_config,
+            ),
+        )
+
+    def _receiver(
+        self,
+        expert_x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
+        expert_num_tokens: torch.Tensor,
+        a1_scale: torch.Tensor | None,
+        a1_dtype: torch.dtype,
+        quant_config: FusedMoEQuantConfig,
+    ) -> mk.PrepareResultType:
+        expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config)
+
+        expert_tokens_meta = mk.ExpertTokensMetadata(
+            expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
+        )
+
+        return expert_x, expert_x_scale, expert_tokens_meta, None, None
+
+    def prepare(
+        self,
+        a1: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: torch.Tensor | None,
+        apply_router_weight_on_input: bool,
+        quant_config: FusedMoEQuantConfig,
+        defer_input_quant: bool = False,
+    ) -> mk.PrepareResultType:
+        if defer_input_quant:
+            raise NotImplementedError(
+                f"{self.__class__.__name__} does not support defer_input_quant=True. "
+                "Please select an MoE kernel that accepts quantized inputs."
+            )
+        hook, receiver = self.prepare_async(
+            a1,
+            topk_weights,
+            topk_ids,
+            num_experts,
+            expert_map,
+            apply_router_weight_on_input,
+            quant_config,
+        )
+        hook()
+        return receiver()
+
+    def _finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+        do_async: bool,
+    ) -> tuple[Callable, Callable]:
+        assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), (
+            "Weight application and reduction happens in the combine kernel."
+        )
+
+        a2a_idx = dbo_current_ubatch_id()
+        do_recv_hook = dbo_enabled() or do_async
+        handle = self.handles[a2a_idx]
+        assert handle is not None
+
+        combine_topk_weights = topk_weights
+        if apply_router_weight_on_input:
+            # weights have already been applied.
+            combine_topk_weights = torch.ones_like(topk_weights)
+
+        combine_topk_ids = self._map_global_to_physical_ids(topk_ids)
+        # TODO (varun) : Enable zero copy mode
+        dbo_maybe_run_recv_hook()
+        _, _, recv_hook = self.buffer.combine(
+            fused_expert_output,
+            combine_topk_ids,
+            combine_topk_weights,
+            handle,
+            async_finish=False,
+            zero_copy=False,
+            return_recv_hook=do_recv_hook,
+            out=output,
+        )
+
+        return recv_hook, lambda: None
+
+    def finalize_async(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> tuple[Callable, Callable]:
+        return self._finalize(
+            output,
+            fused_expert_output,
+            topk_weights,
+            topk_ids,
+            apply_router_weight_on_input,
+            weight_and_reduce_impl,
+            do_async=True,
+        )
+
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
+        self._finalize(
+            output,
+            fused_expert_output,
+            topk_weights,
+            topk_ids,
+            apply_router_weight_on_input,
+            weight_and_reduce_impl,
+            do_async=False,
+        )
diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
index db97a5374..d3c950dcb 100644
--- a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
+++ b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
@@ -234,6 +234,7 @@ class DefaultMoERunner(MoERunner):
             self.moe_config.moe_parallel_config.use_deepep_ll_kernels
             or self.moe_config.moe_parallel_config.use_mori_kernels
             or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
+            or self.moe_config.moe_parallel_config.use_nixl_ep_kernels
         ) and envs.VLLM_ENABLE_MOE_DP_CHUNK

     def _maybe_setup_shared_experts_stream(
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 01df2b000..1ad024a6f 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -896,7 +896,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # batched activation format. As self.fused_experts is not
         # initialized at this point, we resort to checking the MoE config
         # directly.
-        is_batched_moe = self.moe.use_deepep_ll_kernels
+        is_batched_moe = (
+            self.moe.use_deepep_ll_kernels or self.moe.use_nixl_ep_kernels
+        )
         if is_batched_moe:
             num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
         else:
diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py
index 91e724012..e7f966b27 100644
--- a/vllm/utils/import_utils.py
+++ b/vllm/utils/import_utils.py
@@ -412,6 +412,11 @@ def has_deep_gemm() -> bool:
     return _has_module("deep_gemm")


+def has_nixl_ep() -> bool:
+    """Whether the optional `nixl_ep` package is available."""
+    return _has_module("nixl_ep")
+
+
 def has_triton_kernels() -> bool:
     """Whether the optional `triton_kernels` package is available."""
     is_available = _has_module("triton_kernels") or _has_module(
diff --git a/vllm/utils/network_utils.py b/vllm/utils/network_utils.py
index 6ffae768e..6b940c92d 100644
--- a/vllm/utils/network_utils.py
+++ b/vllm/utils/network_utils.py
@@ -288,6 +288,7 @@ def make_zmq_socket(
     bind: bool | None = None,
     identity: bytes | None = None,
     linger: int | None = None,
+    router_handover: bool = False,
 ) -> zmq.Socket | zmq.asyncio.Socket:  # type: ignore[name-defined]
     """Make a ZMQ socket with the proper bind/connect semantics."""

@@ -314,6 +315,10 @@ def make_zmq_socket(
         socket.setsockopt(zmq.SNDHWM, 0)
         socket.setsockopt(zmq.SNDBUF, buf_size)

+    if socket_type == zmq.ROUTER and router_handover:
+        # Let a new connection take over an identity left behind by a dead one.
+        socket.setsockopt(zmq.ROUTER_HANDOVER, 1)
+
     if identity is not None:
         socket.setsockopt(zmq.IDENTITY, identity)

@@ -344,12 +349,20 @@ def zmq_socket_ctx(
     bind: bool | None = None,
     linger: int = 0,
     identity: bytes | None = None,
+    router_handover: bool = False,
 ) -> Iterator[zmq.Socket]:
     """Context manager for a ZMQ socket"""

     ctx = zmq.Context()  # type: ignore[attr-defined]
     try:
-        yield make_zmq_socket(ctx, path, socket_type, bind=bind, identity=identity)
+        yield make_zmq_socket(
+            ctx,
+            path,
+            socket_type,
+            bind=bind,
+            identity=identity,
+            router_handover=router_handover,
+        )
     except KeyboardInterrupt:
         logger.debug("Got Keyboard Interrupt.")

diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index f199e3b8d..2c0135589 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -544,6 +544,11 @@ class MPClient(EngineCoreClient):
         try:
             # State used for data parallel.
             self.engines_running = False
+            parallel_config = vllm_config.parallel_config
+            # Elastic EP can remove a rank and later add it back with the same
+            # identity. The client input ROUTER needs handover to allow the new
+            # engine to replace the dead connection.
+            enable_input_socket_handover = parallel_config.enable_elastic_ep

             self.stats_update_address: str | None = None
             if client_addresses:
@@ -552,7 +557,11 @@ class MPClient(EngineCoreClient):
                 output_address = client_addresses["output_address"]
                 self.stats_update_address = client_addresses.get("stats_update_address")
                 self.input_socket = self.resources.input_socket = make_zmq_socket(
-                    self.ctx, input_address, zmq.ROUTER, bind=True
+                    self.ctx,
+                    input_address,
+                    zmq.ROUTER,
+                    bind=True,
+                    router_handover=enable_input_socket_handover,
                 )
                 self.resources.output_socket = make_zmq_socket(
                     self.ctx, output_address, zmq.PULL
@@ -561,7 +570,11 @@ class MPClient(EngineCoreClient):
                 # Engines are managed by this client.
                 addresses = get_engine_zmq_addresses(vllm_config)
                 self.input_socket = self.resources.input_socket = make_zmq_socket(
-                    self.ctx, addresses.inputs[0], zmq.ROUTER, bind=True
+                    self.ctx,
+                    addresses.inputs[0],
+                    zmq.ROUTER,
+                    bind=True,
+                    router_handover=enable_input_socket_handover,
                 )
                 self.resources.output_socket = make_zmq_socket(
                     self.ctx, addresses.outputs[0], zmq.PULL
@@ -582,7 +595,6 @@ class MPClient(EngineCoreClient):
                     coordinator.get_stats_publish_address()
                 )

-            parallel_config = vllm_config.parallel_config
             dp_size = parallel_config.data_parallel_size
             dp_rank = parallel_config.data_parallel_index
             dp_local_size = parallel_config.data_parallel_size_local
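
Usage note (editor's addition, not part of the patch): a minimal, hypothetical sketch of how the new backend would be selected, using only names this diff introduces or touches (the "nixl_ep" all2all_backend value, enable_expert_parallel, enable_elastic_ep, and VLLM_NIXL_EP_MAX_NUM_RANKS). Constructing ParallelConfig directly like this is illustrative only; the engine normally builds the config from its own arguments.

import os

# Cap on the number of EP ranks the persistent NIXL buffer is sized for
# (default 32, as added to vllm/envs.py above).
os.environ["VLLM_NIXL_EP_MAX_NUM_RANKS"] = "32"

from vllm.config import ParallelConfig

# Hypothetical standalone construction for illustration: the "nixl_ep" value
# is only consulted when expert parallelism is enabled, and enable_elastic_ep
# turns on the ROUTER_HANDOVER path added in core_client.py above.
parallel_config = ParallelConfig(
    data_parallel_size=4,
    enable_expert_parallel=True,
    enable_elastic_ep=True,
    all2all_backend="nixl_ep",
)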