[2/N] Elastic EP Milestone 2: Integrating NIXL-EP (#35627)
Signed-off-by: Itay Alroy <ialroy@nvidia.com> Co-authored-by: Yongji Wu <wuyongji317@gmail.com> Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
This commit is contained in:
@@ -150,6 +150,7 @@ def test_mp_client_uses_env_timeout(monkeypatch: pytest.MonkeyPatch):
|
||||
data_parallel_hybrid_lb=False,
|
||||
data_parallel_external_lb=False,
|
||||
local_engines_only=False,
|
||||
enable_elastic_ep=False,
|
||||
)
|
||||
vllm_config = SimpleNamespace(parallel_config=parallel_config)
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ All2AllBackend = Literal[
|
||||
"deepep_high_throughput",
|
||||
"deepep_low_latency",
|
||||
"mori",
|
||||
"nixl_ep",
|
||||
"allgather_reducescatter",
|
||||
"flashinfer_all2allv",
|
||||
]
|
||||
@@ -156,6 +157,7 @@ class ParallelConfig:
|
||||
- "deepep_high_throughput": Use deepep high-throughput kernels\n
|
||||
- "deepep_low_latency": Use deepep low-latency kernels\n
|
||||
- "mori": Use mori kernels\n
|
||||
- "nixl_ep": Use nixl-ep kernels\n
|
||||
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
|
||||
|
||||
max_parallel_loading_workers: int | None = None
|
||||
@@ -580,6 +582,7 @@ class ParallelConfig:
|
||||
"deepep_high_throughput",
|
||||
"deepep_low_latency",
|
||||
"mori",
|
||||
"nixl_ep",
|
||||
)
|
||||
and self.enable_expert_parallel
|
||||
and self.tensor_parallel_size > 1
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -413,6 +414,121 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
return 0
|
||||
|
||||
|
||||
class NixlEPAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on NIXL EP kernels.
|
||||
This backend supports elastic EP with dynamic rank connection/disconnection.
|
||||
"""
|
||||
|
||||
# (nixl_ep_buffer, ep_size)
|
||||
_buffer: tuple[Any, int] | None = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
self.max_num_ep_ranks = envs.VLLM_NIXL_EP_MAX_NUM_RANKS
|
||||
|
||||
def _init_buffer(
|
||||
self,
|
||||
max_num_tokens_per_dp_rank: int,
|
||||
token_hidden_size: int,
|
||||
num_experts_per_rank: int,
|
||||
) -> None:
|
||||
from nixl_ep import Buffer # type: ignore[import-not-found]
|
||||
|
||||
max_num_global_experts = self.max_num_ep_ranks * num_experts_per_rank
|
||||
num_rdma_bytes = Buffer.get_rdma_size_hint(
|
||||
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
|
||||
hidden=token_hidden_size,
|
||||
num_ranks=self.max_num_ep_ranks,
|
||||
num_experts=max_num_global_experts,
|
||||
)
|
||||
assert NixlEPAll2AllManager._buffer is None, (
|
||||
"NIXL EP buffer already initialized"
|
||||
)
|
||||
buffer = Buffer(
|
||||
rank=self.rank,
|
||||
tcp_store_group=self.tcp_store_group.store,
|
||||
)
|
||||
buffer.update_memory_buffers(
|
||||
num_ranks=self.max_num_ep_ranks,
|
||||
num_experts_per_rank=num_experts_per_rank,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
)
|
||||
ranks_to_connect = list(range(self.cpu_group.size()))
|
||||
buffer.connect_ranks(ranks_to_connect)
|
||||
NixlEPAll2AllManager._buffer = (buffer, self.cpu_group.size())
|
||||
|
||||
def _update_buffer(self):
|
||||
assert NixlEPAll2AllManager._buffer is not None
|
||||
buffer, current_ep_size = NixlEPAll2AllManager._buffer
|
||||
current_ranks = list(range(current_ep_size))
|
||||
new_ep_size = self.cpu_group.size()
|
||||
buffer.set_tcp_store_group(self.tcp_store_group.store)
|
||||
if new_ep_size > len(current_ranks):
|
||||
ranks_to_connect = list(range(len(current_ranks), new_ep_size))
|
||||
buffer.connect_ranks(ranks_to_connect)
|
||||
else:
|
||||
ranks_to_disconnect = current_ranks[new_ep_size:]
|
||||
buffer.disconnect_ranks(ranks_to_disconnect)
|
||||
NixlEPAll2AllManager._buffer = (buffer, new_ep_size)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
with NixlEPAll2AllManager._lock:
|
||||
if (
|
||||
NixlEPAll2AllManager._buffer is not None
|
||||
and NixlEPAll2AllManager._buffer[1] == self.cpu_group.size()
|
||||
):
|
||||
return NixlEPAll2AllManager._buffer[0]
|
||||
|
||||
num_experts_per_rank = (
|
||||
kwargs["num_global_experts"] // kwargs["num_ep_ranks"]
|
||||
)
|
||||
nixl_kwargs = dict(
|
||||
max_num_tokens_per_dp_rank=kwargs["max_num_tokens_per_dp_rank"],
|
||||
token_hidden_size=kwargs["token_hidden_size"],
|
||||
num_experts_per_rank=num_experts_per_rank,
|
||||
)
|
||||
if NixlEPAll2AllManager._buffer is None:
|
||||
self._init_buffer(**nixl_kwargs)
|
||||
else:
|
||||
self._update_buffer()
|
||||
|
||||
assert NixlEPAll2AllManager._buffer is not None
|
||||
handle = NixlEPAll2AllManager._buffer[0]
|
||||
return handle
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
# NOTE(yongji): NIXLEPAll2AllManager instance is recreated during
|
||||
# scale-up/down, so we cannot destroy the persistent buffer here.
|
||||
assert NixlEPAll2AllManager._buffer is not None
|
||||
buffer = NixlEPAll2AllManager._buffer[0]
|
||||
buffer.set_tcp_store_group(None)
|
||||
|
||||
# NIXL EP uses RDMA so no SMs are used for communication
|
||||
def max_sms_used(self) -> int | None:
|
||||
return 0
|
||||
|
||||
|
||||
class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on flashinfer kernels.
|
||||
|
||||
@@ -143,6 +143,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
from .all2all import MoriAll2AllManager
|
||||
|
||||
self.all2all_manager = MoriAll2AllManager(self.cpu_group)
|
||||
elif self.all2all_backend == "nixl_ep":
|
||||
from .all2all import NixlEPAll2AllManager
|
||||
|
||||
self.all2all_manager = NixlEPAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "flashinfer_all2allv":
|
||||
from .all2all import FlashInferAllToAllManager
|
||||
|
||||
|
||||
@@ -244,6 +244,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False
|
||||
VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False
|
||||
VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False
|
||||
VLLM_NIXL_EP_MAX_NUM_RANKS: int = 32
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@@ -1628,6 +1629,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS": lambda: bool(
|
||||
int(os.getenv("VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS", "0"))
|
||||
),
|
||||
# NIXL EP environment variables
|
||||
"VLLM_NIXL_EP_MAX_NUM_RANKS": lambda: int(
|
||||
os.getenv("VLLM_NIXL_EP_MAX_NUM_RANKS", "32")
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
make_moe_prepare_and_finalize_no_dp_ep,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori, has_nixl_ep
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -38,6 +38,11 @@ if current_platform.is_cuda_alike():
|
||||
)
|
||||
if has_mori():
|
||||
from .mori_prepare_finalize import MoriPrepareAndFinalize
|
||||
if has_nixl_ep():
|
||||
from .nixl_ep_prepare_finalize import (
|
||||
NIXL_EP_QUANT_BLOCK_SHAPE,
|
||||
NixlEPPrepareAndFinalize,
|
||||
)
|
||||
|
||||
|
||||
def maybe_roundup_layer_hidden_size(
|
||||
@@ -69,6 +74,11 @@ def maybe_roundup_layer_hidden_size(
|
||||
hidden_size
|
||||
)
|
||||
|
||||
if moe_parallel_config.use_nixl_ep_kernels:
|
||||
hidden_size = NixlEPPrepareAndFinalize.maybe_roundup_layer_hidden_size(
|
||||
hidden_size
|
||||
)
|
||||
|
||||
return hidden_size
|
||||
|
||||
|
||||
@@ -209,4 +219,39 @@ def maybe_make_prepare_finalize(
|
||||
num_dispatchers=all2all_manager.world_size,
|
||||
)
|
||||
|
||||
elif moe.use_nixl_ep_kernels:
|
||||
assert quant_config is not None
|
||||
global_to_physical = physical_to_global = local_expert_global_ids = None
|
||||
if routing_tables is not None:
|
||||
(
|
||||
global_to_physical,
|
||||
physical_to_global,
|
||||
local_expert_global_ids,
|
||||
) = routing_tables
|
||||
all_to_all_args = dict(
|
||||
max_num_tokens_per_dp_rank=moe.max_num_tokens,
|
||||
token_hidden_size=moe.hidden_dim,
|
||||
num_ep_ranks=all2all_manager.world_size,
|
||||
num_global_experts=moe.num_experts,
|
||||
num_local_experts=moe.num_experts // all2all_manager.world_size,
|
||||
)
|
||||
handle = all2all_manager.get_handle(all_to_all_args)
|
||||
|
||||
# Note: We may want to use FP8 dispatch just to reduce
|
||||
# data movement.
|
||||
use_fp8_dispatch = (
|
||||
quant_config.quant_dtype == current_platform.fp8_dtype()
|
||||
and quant_config.block_shape == NIXL_EP_QUANT_BLOCK_SHAPE
|
||||
)
|
||||
|
||||
prepare_finalize = NixlEPPrepareAndFinalize(
|
||||
handle,
|
||||
max_tokens_per_rank=moe.max_num_tokens,
|
||||
num_dispatchers=all2all_manager.world_size,
|
||||
use_fp8_dispatch=use_fp8_dispatch,
|
||||
global_to_physical=global_to_physical,
|
||||
physical_to_global=physical_to_global,
|
||||
local_expert_global_ids=local_expert_global_ids,
|
||||
)
|
||||
|
||||
return prepare_finalize
|
||||
|
||||
@@ -976,6 +976,10 @@ class FusedMoEParallelConfig:
|
||||
def use_mori_kernels(self):
|
||||
return self.use_all2all_kernels and self.all2all_backend == "mori"
|
||||
|
||||
@property
|
||||
def use_nixl_ep_kernels(self):
|
||||
return self.use_all2all_kernels and self.all2all_backend == "nixl_ep"
|
||||
|
||||
@staticmethod
|
||||
def flatten_tp_across_dp_and_pcp(
|
||||
tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
|
||||
@@ -1242,3 +1246,7 @@ class FusedMoEConfig:
|
||||
@property
|
||||
def use_naive_all2all_kernels(self):
|
||||
return self.moe_parallel_config.use_naive_all2all_kernels
|
||||
|
||||
@property
|
||||
def use_nixl_ep_kernels(self):
|
||||
return self.moe_parallel_config.use_nixl_ep_kernels
|
||||
|
||||
@@ -177,10 +177,11 @@ def determine_expert_placement_strategy(
|
||||
if (
|
||||
moe_parallel_config.use_all2all_kernels
|
||||
and not moe_parallel_config.use_deepep_ll_kernels
|
||||
and not moe_parallel_config.use_nixl_ep_kernels
|
||||
):
|
||||
logger.warning(
|
||||
"Round-robin expert placement currently only supports "
|
||||
"the DeepEP low-latency backend, but '%s' was configured. "
|
||||
"the DeepEP low-latency or NIXL EP backend, but '%s' was configured. "
|
||||
"Falling back to linear expert placement.",
|
||||
moe_parallel_config.all2all_backend,
|
||||
)
|
||||
@@ -745,10 +746,10 @@ class FusedMoE(CustomOp):
|
||||
self,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
|
||||
# Currently routing_tables only needed for round-robin expert placement
|
||||
# with DeepEP-ll all2all backend.
|
||||
if (
|
||||
self.expert_placement_strategy != "round_robin"
|
||||
or not self.moe_parallel_config.use_deepep_ll_kernels
|
||||
# with DeepEP-ll or NIXL EP all2all backends.
|
||||
if self.expert_placement_strategy != "round_robin" or (
|
||||
not self.moe_parallel_config.use_deepep_ll_kernels
|
||||
and not self.moe_parallel_config.use_nixl_ep_kernels
|
||||
):
|
||||
return None
|
||||
|
||||
|
||||
406
vllm/model_executor/layers/fused_moe/nixl_ep_prepare_finalize.py
Normal file
406
vllm/model_executor/layers/fused_moe/nixl_ep_prepare_finalize.py
Normal file
@@ -0,0 +1,406 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import nixl_ep
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input,
|
||||
normalize_batched_scales_shape,
|
||||
)
|
||||
from vllm.v1.worker.ubatching import (
|
||||
dbo_current_ubatch_id,
|
||||
dbo_enabled,
|
||||
dbo_maybe_run_recv_hook,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# NIXL EP kernels quantize dispatch inputs in 128 element chunks.
|
||||
NIXL_EP_QUANT_BLOCK_SIZE = 128
|
||||
NIXL_EP_QUANT_BLOCK_SHAPE = [NIXL_EP_QUANT_BLOCK_SIZE, NIXL_EP_QUANT_BLOCK_SIZE]
|
||||
|
||||
|
||||
def dequant_fp8(
|
||||
expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Return dequantized tensor in fp32
|
||||
"""
|
||||
assert expert_x_fp8.is_contiguous()
|
||||
expert_x_scales = expert_x_scales.contiguous()
|
||||
num_experts = expert_x_fp8.size(0)
|
||||
|
||||
expert_x_fp32 = expert_x_fp8.to(torch.float32).view(
|
||||
num_experts, -1, NIXL_EP_QUANT_BLOCK_SIZE
|
||||
)
|
||||
expert_x_scales = expert_x_scales.view(num_experts, -1, 1)
|
||||
return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size())
|
||||
|
||||
|
||||
class NixlEPPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""
|
||||
Prepare/Finalize using NIXL EP kernels.
|
||||
"""
|
||||
|
||||
# NIXL EP kernels are compiled only for certain specific hidden sizes.
|
||||
# NOTE: Keep this list sorted, maybe_roundup_layer_hidden_size depends
|
||||
# on it.
|
||||
SUPPORTED_HIDDEN_SIZES = [2048, 2560, 3072, 4096, 5120, 6144, 7168, 8192]
|
||||
assert sorted(set(SUPPORTED_HIDDEN_SIZES)) == SUPPORTED_HIDDEN_SIZES
|
||||
|
||||
@staticmethod
|
||||
def maybe_roundup_layer_hidden_size(hidden_size: int) -> int:
|
||||
# Round up hidden size to the closest supported hidden size.
|
||||
_supported_hs = NixlEPPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES
|
||||
|
||||
for x in _supported_hs:
|
||||
if x >= hidden_size:
|
||||
return x
|
||||
|
||||
raise ValueError(
|
||||
f"Hidden Size {hidden_size} is greater than the "
|
||||
f"maximum supported hidden size {_supported_hs[-1]}"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
buffer: nixl_ep.Buffer,
|
||||
max_tokens_per_rank: int,
|
||||
num_dispatchers: int,
|
||||
use_fp8_dispatch: bool = False,
|
||||
global_to_physical: torch.Tensor | None = None,
|
||||
physical_to_global: torch.Tensor | None = None,
|
||||
local_expert_global_ids: torch.Tensor | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.buffer = buffer
|
||||
self.max_tokens_per_rank = max_tokens_per_rank
|
||||
self.use_fp8_dispatch = use_fp8_dispatch
|
||||
# The dispatch function returns a handle that the combine function
|
||||
# requires. We store the handle here so it is available to the
|
||||
# combine function.
|
||||
self.handles: list[tuple | None] = [None, None]
|
||||
self.num_dispatchers_ = num_dispatchers
|
||||
|
||||
topk_indices_dtype = self.topk_indices_dtype()
|
||||
|
||||
def _maybe_cast(tensor: torch.Tensor | None) -> torch.Tensor | None:
|
||||
if tensor is None or topk_indices_dtype is None:
|
||||
return tensor
|
||||
return tensor.to(dtype=topk_indices_dtype)
|
||||
|
||||
self.global_to_physical = _maybe_cast(global_to_physical)
|
||||
self.physical_to_global = _maybe_cast(physical_to_global)
|
||||
self.local_expert_global_ids = _maybe_cast(local_expert_global_ids)
|
||||
|
||||
# We don't have enough information to determine if we should dispatch
|
||||
# activation scales in a packed ue8m0 format during object construction
|
||||
# time. This setting is handled by post_init_setup.
|
||||
self.use_ue8m0_dispatch = False
|
||||
|
||||
def post_init_setup(self, fused_experts: mk.FusedMoEExperts):
|
||||
if not fused_experts.supports_packed_ue8m0_act_scales():
|
||||
# Early exit.
|
||||
return
|
||||
|
||||
if self.use_fp8_dispatch:
|
||||
logger.debug_once(
|
||||
"Update NixlEPPrepareAndFinalize to do packed ue8m0 scales dispatch."
|
||||
)
|
||||
self.use_ue8m0_dispatch = True
|
||||
else:
|
||||
logger.warning_once(
|
||||
"NixlEPPrepareAndFinalize is setup to dispatch raw/unquantized "
|
||||
f"activations despite ({fused_experts.__class__.__name__}) being able "
|
||||
"to support quantized activations.",
|
||||
scope="local",
|
||||
)
|
||||
|
||||
def num_dispatchers(self) -> int:
|
||||
return self.num_dispatchers_
|
||||
|
||||
def output_is_reduced(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
def max_num_tokens_per_rank(self) -> int | None:
|
||||
return self.max_tokens_per_rank
|
||||
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return torch.int64
|
||||
|
||||
def _map_global_to_physical_ids(self, topk_ids: torch.Tensor) -> torch.Tensor:
|
||||
if self.global_to_physical is None:
|
||||
return topk_ids
|
||||
return self.global_to_physical[topk_ids]
|
||||
|
||||
def _map_local_to_global_ids(self, expert_topk_ids: torch.Tensor) -> torch.Tensor:
|
||||
if self.local_expert_global_ids is None:
|
||||
return expert_topk_ids
|
||||
return self.local_expert_global_ids[expert_topk_ids]
|
||||
|
||||
def _do_quant(
|
||||
self,
|
||||
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
a1_dtype: torch.dtype,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
if self.use_fp8_dispatch:
|
||||
block_k = (
|
||||
quant_config.block_shape[1]
|
||||
if quant_config.block_shape is not None
|
||||
else None
|
||||
)
|
||||
if block_k == NIXL_EP_QUANT_BLOCK_SIZE:
|
||||
# NIXL EP kernels did the quantization for us.
|
||||
x, x_scales = x
|
||||
return x, x_scales
|
||||
|
||||
# Dequant to get back the tokens in the datatype we dispatched in.
|
||||
x_fp8, x_scales = x
|
||||
x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)
|
||||
|
||||
assert isinstance(x, torch.Tensor)
|
||||
|
||||
num_experts, max_tokens, hidden_dim = x.size()
|
||||
|
||||
x = x.view((-1, hidden_dim))
|
||||
q_dtype = quant_config.quant_dtype
|
||||
|
||||
if envs.VLLM_FLASHINFER_MOE_BACKEND == "masked_gemm":
|
||||
logger.info_once(
|
||||
"Skip quantization when using FlashInfer CUTEDSL(masked_gemm) "
|
||||
"for ModelOptNvFp4FusedMoE."
|
||||
)
|
||||
q_dtype = None
|
||||
|
||||
x, x_scales = moe_kernel_quantize_input(
|
||||
x,
|
||||
quant_config.a1_scale,
|
||||
q_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
)
|
||||
x = x.view((num_experts, -1, hidden_dim))
|
||||
|
||||
if q_dtype is not None:
|
||||
assert x_scales is not None
|
||||
x_scales = normalize_batched_scales_shape(x_scales, num_experts)
|
||||
|
||||
return x, x_scales
|
||||
|
||||
def supports_async(self) -> bool:
|
||||
return True
|
||||
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> tuple[Callable, mk.ReceiverType]:
|
||||
if defer_input_quant:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not support defer_input_quant=True. "
|
||||
"Please select an MoE kernel that accepts quantized inputs."
|
||||
)
|
||||
|
||||
hidden_size = a1.size(1)
|
||||
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, (
|
||||
f"Hidden Size {hidden_size} not in supported list of hidden sizes"
|
||||
f"{self.SUPPORTED_HIDDEN_SIZES}"
|
||||
)
|
||||
|
||||
a2a_idx = dbo_current_ubatch_id()
|
||||
|
||||
if self.use_fp8_dispatch:
|
||||
assert hidden_size % 128 == 0, (
|
||||
"NIXL EP kernels quantize the inputs in blocks of shape 128"
|
||||
)
|
||||
|
||||
has_per_token_scales = (
|
||||
quant_config.a1_scale.numel() != 1
|
||||
if quant_config.a1_scale is not None
|
||||
else (
|
||||
quant_config.a2_scale.numel() != 1
|
||||
if quant_config.a2_scale is not None
|
||||
else False
|
||||
)
|
||||
)
|
||||
assert not has_per_token_scales, (
|
||||
"NIXL EP kernels don't support dispatching per-token scales"
|
||||
)
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
)
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
# Dispatch
|
||||
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
|
||||
expert_x, expert_num_tokens, handle, _, hook = self.buffer.dispatch(
|
||||
a1,
|
||||
dispatch_topk_ids,
|
||||
self.max_tokens_per_rank,
|
||||
num_experts,
|
||||
use_fp8=self.use_fp8_dispatch,
|
||||
# round_scale needs to be set to dispatch in ue8m0
|
||||
round_scale=self.use_ue8m0_dispatch,
|
||||
use_ue8m0=self.use_ue8m0_dispatch,
|
||||
async_finish=False,
|
||||
return_recv_hook=True,
|
||||
)
|
||||
self.handles[a2a_idx] = handle
|
||||
|
||||
return (
|
||||
hook,
|
||||
lambda: self._receiver(
|
||||
expert_x,
|
||||
expert_num_tokens,
|
||||
quant_config.a1_scale,
|
||||
a1.dtype,
|
||||
quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
def _receiver(
|
||||
self,
|
||||
expert_x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
expert_num_tokens: torch.Tensor,
|
||||
a1_scale: torch.Tensor | None,
|
||||
a1_dtype: torch.dtype,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config)
|
||||
|
||||
expert_tokens_meta = mk.ExpertTokensMetadata(
|
||||
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
|
||||
)
|
||||
|
||||
return expert_x, expert_x_scale, expert_tokens_meta, None, None
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
defer_input_quant: bool = False,
|
||||
) -> mk.PrepareResultType:
|
||||
if defer_input_quant:
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not support defer_input_quant=True. "
|
||||
"Please select an MoE kernel that accepts quantized inputs."
|
||||
)
|
||||
hook, receiver = self.prepare_async(
|
||||
a1,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config,
|
||||
)
|
||||
hook()
|
||||
return receiver()
|
||||
|
||||
def _finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
do_async: bool,
|
||||
) -> tuple[Callable, Callable]:
|
||||
assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), (
|
||||
"Weight application and reduction happens in the combine kernel."
|
||||
)
|
||||
|
||||
a2a_idx = dbo_current_ubatch_id()
|
||||
do_recv_hook = dbo_enabled() or do_async
|
||||
handle = self.handles[a2a_idx]
|
||||
assert handle is not None
|
||||
|
||||
combine_topk_weights = topk_weights
|
||||
if apply_router_weight_on_input:
|
||||
# weights have already been applied.
|
||||
combine_topk_weights = torch.ones_like(topk_weights)
|
||||
|
||||
combine_topk_ids = self._map_global_to_physical_ids(topk_ids)
|
||||
# TODO (varun) : Enable zero copy mode
|
||||
dbo_maybe_run_recv_hook()
|
||||
_, _, recv_hook = self.buffer.combine(
|
||||
fused_expert_output,
|
||||
combine_topk_ids,
|
||||
combine_topk_weights,
|
||||
handle,
|
||||
async_finish=False,
|
||||
zero_copy=False,
|
||||
return_recv_hook=do_recv_hook,
|
||||
out=output,
|
||||
)
|
||||
|
||||
return recv_hook, lambda: None
|
||||
|
||||
def finalize_async(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> tuple[Callable, Callable]:
|
||||
return self._finalize(
|
||||
output,
|
||||
fused_expert_output,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
weight_and_reduce_impl,
|
||||
do_async=True,
|
||||
)
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
apply_router_weight_on_input: bool,
|
||||
weight_and_reduce_impl: mk.TopKWeightAndReduce,
|
||||
) -> None:
|
||||
self._finalize(
|
||||
output,
|
||||
fused_expert_output,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
apply_router_weight_on_input,
|
||||
weight_and_reduce_impl,
|
||||
do_async=False,
|
||||
)
|
||||
@@ -234,6 +234,7 @@ class DefaultMoERunner(MoERunner):
|
||||
self.moe_config.moe_parallel_config.use_deepep_ll_kernels
|
||||
or self.moe_config.moe_parallel_config.use_mori_kernels
|
||||
or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
|
||||
or self.moe_config.moe_parallel_config.use_nixl_ep_kernels
|
||||
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
|
||||
|
||||
def _maybe_setup_shared_experts_stream(
|
||||
|
||||
@@ -896,7 +896,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
# batched activation format. As self.fused_experts is not
|
||||
# initialized at this point, we resort to checking the MoE config
|
||||
# directly.
|
||||
is_batched_moe = self.moe.use_deepep_ll_kernels
|
||||
is_batched_moe = (
|
||||
self.moe.use_deepep_ll_kernels or self.moe.use_nixl_ep_kernels
|
||||
)
|
||||
if is_batched_moe:
|
||||
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
|
||||
else:
|
||||
|
||||
@@ -412,6 +412,11 @@ def has_deep_gemm() -> bool:
|
||||
return _has_module("deep_gemm")
|
||||
|
||||
|
||||
def has_nixl_ep() -> bool:
|
||||
"""Whether the optional `nixl_ep` package is available."""
|
||||
return _has_module("nixl_ep")
|
||||
|
||||
|
||||
def has_triton_kernels() -> bool:
|
||||
"""Whether the optional `triton_kernels` package is available."""
|
||||
is_available = _has_module("triton_kernels") or _has_module(
|
||||
|
||||
@@ -288,6 +288,7 @@ def make_zmq_socket(
|
||||
bind: bool | None = None,
|
||||
identity: bytes | None = None,
|
||||
linger: int | None = None,
|
||||
router_handover: bool = False,
|
||||
) -> zmq.Socket | zmq.asyncio.Socket: # type: ignore[name-defined]
|
||||
"""Make a ZMQ socket with the proper bind/connect semantics."""
|
||||
|
||||
@@ -314,6 +315,10 @@ def make_zmq_socket(
|
||||
socket.setsockopt(zmq.SNDHWM, 0)
|
||||
socket.setsockopt(zmq.SNDBUF, buf_size)
|
||||
|
||||
if socket_type == zmq.ROUTER and router_handover:
|
||||
# Let a new connection take over an identity left behind by a dead one.
|
||||
socket.setsockopt(zmq.ROUTER_HANDOVER, 1)
|
||||
|
||||
if identity is not None:
|
||||
socket.setsockopt(zmq.IDENTITY, identity)
|
||||
|
||||
@@ -344,12 +349,20 @@ def zmq_socket_ctx(
|
||||
bind: bool | None = None,
|
||||
linger: int = 0,
|
||||
identity: bytes | None = None,
|
||||
router_handover: bool = False,
|
||||
) -> Iterator[zmq.Socket]:
|
||||
"""Context manager for a ZMQ socket"""
|
||||
|
||||
ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
try:
|
||||
yield make_zmq_socket(ctx, path, socket_type, bind=bind, identity=identity)
|
||||
yield make_zmq_socket(
|
||||
ctx,
|
||||
path,
|
||||
socket_type,
|
||||
bind=bind,
|
||||
identity=identity,
|
||||
router_handover=router_handover,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
logger.debug("Got Keyboard Interrupt.")
|
||||
|
||||
|
||||
@@ -544,6 +544,11 @@ class MPClient(EngineCoreClient):
|
||||
try:
|
||||
# State used for data parallel.
|
||||
self.engines_running = False
|
||||
parallel_config = vllm_config.parallel_config
|
||||
# Elastic EP can remove a rank and later add it back with the same
|
||||
# identity. The client input ROUTER needs handover to allow the new
|
||||
# engine to replace the dead connection.
|
||||
enable_input_socket_handover = parallel_config.enable_elastic_ep
|
||||
|
||||
self.stats_update_address: str | None = None
|
||||
if client_addresses:
|
||||
@@ -552,7 +557,11 @@ class MPClient(EngineCoreClient):
|
||||
output_address = client_addresses["output_address"]
|
||||
self.stats_update_address = client_addresses.get("stats_update_address")
|
||||
self.input_socket = self.resources.input_socket = make_zmq_socket(
|
||||
self.ctx, input_address, zmq.ROUTER, bind=True
|
||||
self.ctx,
|
||||
input_address,
|
||||
zmq.ROUTER,
|
||||
bind=True,
|
||||
router_handover=enable_input_socket_handover,
|
||||
)
|
||||
self.resources.output_socket = make_zmq_socket(
|
||||
self.ctx, output_address, zmq.PULL
|
||||
@@ -561,7 +570,11 @@ class MPClient(EngineCoreClient):
|
||||
# Engines are managed by this client.
|
||||
addresses = get_engine_zmq_addresses(vllm_config)
|
||||
self.input_socket = self.resources.input_socket = make_zmq_socket(
|
||||
self.ctx, addresses.inputs[0], zmq.ROUTER, bind=True
|
||||
self.ctx,
|
||||
addresses.inputs[0],
|
||||
zmq.ROUTER,
|
||||
bind=True,
|
||||
router_handover=enable_input_socket_handover,
|
||||
)
|
||||
self.resources.output_socket = make_zmq_socket(
|
||||
self.ctx, addresses.outputs[0], zmq.PULL
|
||||
@@ -582,7 +595,6 @@ class MPClient(EngineCoreClient):
|
||||
coordinator.get_stats_publish_address()
|
||||
)
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
dp_rank = parallel_config.data_parallel_index
|
||||
dp_local_size = parallel_config.data_parallel_size_local
|
||||
|
||||
Reference in New Issue
Block a user