[Kernel] Add FlashInfer MoE A2A Kernel (#36022)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com> Signed-off-by: Leo Tian <lctian@nvidia.com> Co-authored-by: wzhao18 <wzhao18.sz@gmail.com> Co-authored-by: Stefano Castagnetta <scastagnetta@nvidia.com> Co-authored-by: root <root@lyris0267.lyris.clusters.nvidia.com>
This commit is contained in:
@@ -45,7 +45,9 @@ All2AllBackend = Literal[
|
||||
"mori",
|
||||
"nixl_ep",
|
||||
"allgather_reducescatter",
|
||||
"flashinfer_all2allv",
|
||||
"flashinfer_all2allv", # temporary alias for flashinfer_nvlink_two_sided
|
||||
"flashinfer_nvlink_two_sided",
|
||||
"flashinfer_nvlink_one_sided",
|
||||
]
|
||||
|
||||
|
||||
@@ -158,7 +160,8 @@ class ParallelConfig:
|
||||
- "deepep_low_latency": Use deepep low-latency kernels\n
|
||||
- "mori": Use mori kernels\n
|
||||
- "nixl_ep": Use nixl-ep kernels\n
|
||||
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
|
||||
- "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl
|
||||
- "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels"""
|
||||
|
||||
max_parallel_loading_workers: int | None = None
|
||||
"""Maximum number of parallel loading workers when loading model
|
||||
|
||||
@@ -4,23 +4,36 @@ import threading
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.flashinfer import has_flashinfer_all2all
|
||||
from vllm.utils.flashinfer import (
|
||||
has_flashinfer_nvlink_one_sided,
|
||||
has_flashinfer_nvlink_two_sided,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori
|
||||
|
||||
from .base_device_communicator import All2AllManagerBase, Cache
|
||||
|
||||
if has_flashinfer_all2all():
|
||||
if has_flashinfer_nvlink_two_sided():
|
||||
from flashinfer.comm import Mapping # type: ignore[import-not-found]
|
||||
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
|
||||
from flashinfer.comm.trtllm_alltoall import (
|
||||
MnnvlMoe, # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
if has_flashinfer_nvlink_one_sided():
|
||||
from flashinfer.comm import Mapping # type: ignore[import-not-found]
|
||||
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
|
||||
from flashinfer.comm.trtllm_moe_alltoall import (
|
||||
MoeAlltoAll, # type: ignore[import-not-found]
|
||||
moe_a2a_get_workspace_size_per_rank,
|
||||
)
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -529,9 +542,9 @@ class NixlEPAll2AllManager(All2AllManagerBase):
|
||||
return 0
|
||||
|
||||
|
||||
class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
class FlashInferNVLinkTwoSidedManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on flashinfer kernels.
|
||||
All2All communication based on flashinfer all2allv/two-sided NVLink kernels.
|
||||
"""
|
||||
|
||||
# This type lint could be removed after all of the work in
|
||||
@@ -540,7 +553,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
assert has_flashinfer_all2all(), (
|
||||
assert has_flashinfer_nvlink_two_sided(), (
|
||||
"flashinfer all2all module not found. Please install/check flashinfer"
|
||||
) # noqa
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
@@ -597,7 +610,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
|
||||
def ensure_alltoall_workspace_initialized(self):
|
||||
"""Ensure workspace is initialized"""
|
||||
if not has_flashinfer_all2all():
|
||||
if not has_flashinfer_nvlink_two_sided():
|
||||
return False
|
||||
|
||||
if self.world_size <= 1:
|
||||
@@ -633,6 +646,119 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
self.initialized = False
|
||||
|
||||
|
||||
class FlashInferNVLinkOneSidedManager(All2AllManagerBase):
    """
    All2All communication based on FlashInfer's MoeAlltoAll/One-sided NVLink kernel.
    This is a newer kernel from trtllm that should perform better than the kernel
    used by flashinfer_nvlink_two_sided.

    The `MoeAlltoAll` object is not built by the constructor: `initialize()`
    creates it and `cleanup()` drops it again.
    """

    rank: int
    world_size: int

    def __init__(self, cpu_group):
        """Set up the manager in its un-initialized state.

        Args:
            cpu_group: forwarded unchanged to `All2AllManagerBase.__init__`.
        """
        assert has_flashinfer_nvlink_one_sided(), (
            "flashinfer trtllm_moe_alltoall module not found. "
            "Please install/check flashinfer"
        )
        super().__init__(cpu_group)
        logger.debug(
            "Initialize FlashInfer One-sided NVLink rank=%d, world size=%d",
            self.rank,
            self.world_size,
        )
        self.initialized = False
        self.moe_alltoall: MoeAlltoAll | None = None
        self.mapping = None

    def initialize(
        self,
        max_num_tokens: int,
        top_k: int,
        num_experts: int,
        hidden_size: int,
    ):
        """Initialize the MoeAlltoAll workspace.

        Only the first call builds anything. Later calls return immediately and
        log a warning if their arguments differ from the ones the workspace
        was built with.

        Args:
            max_num_tokens: passed to the workspace-size helper and to
                `MoeAlltoAll`.
            top_k: used for the per-token payload size and passed to
                `MoeAlltoAll`.
            num_experts: passed to `MoeAlltoAll`.
            hidden_size: used for the per-token dispatch/combine payload sizes.
        """
        if self.initialized:
            requested = (max_num_tokens, top_k, num_experts, hidden_size)
            active = (
                self.max_num_tokens,
                self.top_k,
                self.num_experts,
                self.hidden_size,
            )
            # The workspace was sized from the first call; a differing request
            # is not honoured, so make that visible instead of silent.
            if requested != active:
                logger.warning(
                    "FlashInfer One-sided NVLink already initialized with "
                    "(max_num_tokens, top_k, num_experts, hidden_size)=%s; "
                    "ignoring re-initialization request %s",
                    active,
                    requested,
                )
            return

        self.cleanup()
        gpus_per_node = torch.accelerator.device_count()
        logger.debug(
            "Making One-sided NVLink mapping: rank=%d, world size=%d",
            self.rank,
            self.world_size,
        )
        self.mapping = Mapping(
            self.world_size,
            self.rank,
            gpus_per_node,
            tp_size=self.world_size,
            moe_ep_size=self.world_size,
        )

        from vllm.distributed.device_communicators.mnnvl_compat import (
            CustomCommunicator,
        )

        dp_config = MnnvlConfig(
            comm_backend=CustomCommunicator(get_dp_group().cpu_group),
        )
        # NOTE(review): the byte counts below hard-code an nvfp4 dispatch and a
        # bf16 combine layout -- confirm no other quantization reaches this path.
        total_dispatch_payload_size_per_token = (
            hidden_size // 2  # nvfp4 hidden states
            + hidden_size // 16  # fp8 scaling factors
            + top_k * 4  # int32 topks ids
            + top_k * 4  # float32 topk weights
        )
        combine_payload_size_per_token = hidden_size * 2  # bf16 hidden states
        self.workspace_size = moe_a2a_get_workspace_size_per_rank(
            ep_size=self.world_size,
            max_num_tokens=max_num_tokens,
            total_dispatch_payload_size_per_token=total_dispatch_payload_size_per_token,
            combine_payload_size_per_token=combine_payload_size_per_token,
        )

        self.moe_alltoall = MoeAlltoAll(
            mapping=self.mapping,
            max_num_tokens=max_num_tokens,
            top_k=top_k,
            num_experts=num_experts,
            workspace_size_per_rank=self.workspace_size,
            mnnvl_config=dp_config,
        )

        self.gpus_per_node = gpus_per_node
        self.max_num_tokens = max_num_tokens
        self.top_k = top_k
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.initialized = True

        logger.info(
            "FlashInfer One-sided NVLink initialized for rank %s, size %s",
            self.rank,
            self.world_size,
        )
        # NOTE(review): no group is passed, so this waits on the default
        # process group -- confirm that is the intended scope.
        dist.barrier()

    def get_handle(self, kwargs):
        """Return this manager itself; `kwargs` is not used."""
        return self

    def cleanup(self):
        """Clean up resources.

        Safe to call at any time, including before `initialize()` or after a
        failed one: the references are always dropped and the manager is
        always marked un-initialized.
        """
        # Rebinding releases this manager's reference just like `del` would,
        # without needing a try/except around an attribute deletion.
        self.moe_alltoall = None
        self.mapping = None
        self.initialized = False
|
||||
|
||||
|
||||
class MoriAll2AllManager(All2AllManagerBase):
|
||||
def __init__(self, cpu_group):
|
||||
assert has_mori(), (
|
||||
|
||||
@@ -149,12 +149,25 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.all2all_manager = NixlEPAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "flashinfer_all2allv":
|
||||
from .all2all import FlashInferAllToAllManager
|
||||
elif (
|
||||
self.all2all_backend == "flashinfer_all2allv"
|
||||
or self.all2all_backend == "flashinfer_nvlink_two_sided"
|
||||
):
|
||||
if self.all2all_backend == "flashinfer_all2allv":
|
||||
# Adjacent string literals are joined verbatim, so every fragment except the
# last must end with its own space.
logger.warning_once(
    "'flashinfer_all2allv' is deprecated and has been renamed to "
    "'flashinfer_nvlink_two_sided'. It will be removed in a future "
    "release."
)
|
||||
from .all2all import FlashInferNVLinkTwoSidedManager
|
||||
|
||||
self.all2all_manager = FlashInferAllToAllManager(
|
||||
self.all2all_manager = FlashInferNVLinkTwoSidedManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "flashinfer_nvlink_one_sided":
|
||||
from .all2all import FlashInferNVLinkOneSidedManager
|
||||
|
||||
self.all2all_manager = FlashInferNVLinkOneSidedManager(self.cpu_group)
|
||||
else:
|
||||
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ from typing import Any
|
||||
import torch.distributed as dist
|
||||
from flashinfer.comm.mnnvl import CommBackend as CommBackend
|
||||
|
||||
from vllm.utils.flashinfer import has_flashinfer_all2all
|
||||
from vllm.utils.flashinfer import has_flashinfer_nvlink_two_sided
|
||||
|
||||
assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found"
|
||||
assert has_flashinfer_nvlink_two_sided(), "Flashinfer alltoallv module cannot be found"
|
||||
|
||||
|
||||
class CustomCommunicator(CommBackend):
|
||||
@@ -25,14 +25,14 @@ class CustomCommunicator(CommBackend):
|
||||
dist.all_gather_object(gathered, data, group=self._group)
|
||||
return gathered
|
||||
|
||||
# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
|
||||
# are unimplemented on vLLM side. If we need to utilize these
|
||||
# methods in the future, can create a concrete implementation.
|
||||
def bcast(self, data: Any, root: int) -> Any:
    """Broadcast `data` from rank `root` over this communicator's group.

    Args:
        data: object to send; only the value held by `root` is propagated.
        root: passed as `src` to `dist.broadcast_object_list`.

    Returns:
        The object held in the list after the broadcast completes.
    """
    # NOTE(review): `src` is forwarded as-is -- confirm callers pass the rank
    # numbering that `dist.broadcast_object_list` expects for `self._group`.
    obj_list = [data]
    # broadcast_object_list mutates obj_list in-place
    dist.broadcast_object_list(obj_list, src=root, group=self._group)
    return obj_list[0]
|
||||
|
||||
def barrier(self) -> None:
    """Run `dist.barrier` on this communicator's process group."""
    dist.barrier(group=self._group)
|
||||
|
||||
def Split(self, color: int, key: int) -> "CustomCommunicator":
|
||||
return self
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (
|
||||
get_ep_group,
|
||||
)
|
||||
@@ -14,8 +15,11 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import (
|
||||
FlashInferA2APrepareAndFinalize,
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_one_sided_prepare_finalize import ( # noqa: E501
|
||||
FlashInferNVLinkOneSidedPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_two_sided_prepare_finalize import ( # noqa: E501
|
||||
FlashInferNVLinkTwoSidedPrepareAndFinalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEPrepareAndFinalize,
|
||||
@@ -206,9 +210,22 @@ def maybe_make_prepare_finalize(
|
||||
use_fp8_dispatch=use_fp8_dispatch,
|
||||
)
|
||||
|
||||
elif moe.use_fi_all2allv_kernels:
|
||||
elif moe.use_fi_nvl_two_sided_kernels:
|
||||
assert quant_config is not None
|
||||
prepare_finalize = FlashInferA2APrepareAndFinalize(
|
||||
prepare_finalize = FlashInferNVLinkTwoSidedPrepareAndFinalize(
|
||||
num_dispatchers=all2all_manager.world_size,
|
||||
)
|
||||
|
||||
elif moe.use_fi_nvl_one_sided_kernels:
|
||||
assert quant_config is not None
|
||||
max_num_tokens = (
|
||||
get_current_vllm_config().scheduler_config.max_num_batched_tokens
|
||||
)
|
||||
prepare_finalize = FlashInferNVLinkOneSidedPrepareAndFinalize(
|
||||
max_num_tokens=max_num_tokens,
|
||||
top_k=moe.experts_per_token,
|
||||
num_experts=moe.num_experts,
|
||||
hidden_size=moe.hidden_dim,
|
||||
num_dispatchers=all2all_manager.world_size,
|
||||
)
|
||||
|
||||
|
||||
@@ -957,9 +957,17 @@ class FusedMoEParallelConfig:
|
||||
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
|
||||
|
||||
@property
def use_fi_nvl_two_sided_kernels(self):
    """Whether all2all kernels are on and the backend is the two-sided one.

    Both the current name and the older "flashinfer_all2allv" spelling of the
    backend are accepted.
    """
    return self.use_all2all_kernels and (
        self.all2all_backend == "flashinfer_all2allv"
        or self.all2all_backend == "flashinfer_nvlink_two_sided"
    )
|
||||
|
||||
@property
def use_fi_nvl_one_sided_kernels(self):
    """Whether all2all kernels are on and the backend is the one-sided one."""
    return (
        self.use_all2all_kernels
        and self.all2all_backend == "flashinfer_nvlink_one_sided"
    )
|
||||
|
||||
@property
|
||||
@@ -1240,8 +1248,12 @@ class FusedMoEConfig:
|
||||
return self.moe_parallel_config.use_mori_kernels
|
||||
|
||||
@property
def use_fi_nvl_two_sided_kernels(self):
    """Forward to `moe_parallel_config.use_fi_nvl_two_sided_kernels`."""
    return self.moe_parallel_config.use_fi_nvl_two_sided_kernels
|
||||
|
||||
@property
def use_fi_nvl_one_sided_kernels(self):
    """Forward to `moe_parallel_config.use_fi_nvl_one_sided_kernels`."""
    parallel_config = self.moe_parallel_config
    return parallel_config.use_fi_nvl_one_sided_kernels
|
||||
|
||||
@property
|
||||
def use_naive_all2all_kernels(self):
|
||||
|
||||
@@ -396,8 +396,9 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
# Note that the BATCHED activation format does not use
|
||||
# the expert map for identifying experts.
|
||||
return not (
|
||||
moe_parallel_config.use_fi_all2allv_kernels
|
||||
moe_parallel_config.use_fi_nvl_two_sided_kernels
|
||||
or moe_parallel_config.use_deepep_ht_kernels
|
||||
or moe_parallel_config.use_fi_nvl_one_sided_kernels
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
|
||||
@@ -152,7 +152,10 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular):
|
||||
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """Return False when either FlashInfer NVLink all2all backend is active."""
    # NOTE(rob): discovered an IMA with this combination. Needs investigation.
    return not (
        moe_parallel_config.use_fi_nvl_two_sided_kernels
        or moe_parallel_config.use_fi_nvl_one_sided_kernels
    )
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.distributed import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
|
||||
|
||||
|
||||
def get_local_sizes():
    """Return `get_chunk_sizes_across_dp_rank()` of the current forward
    context's `dp_metadata`."""
    dp_metadata = get_forward_context().dp_metadata
    return dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
|
||||
|
||||
class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
    """FlashInfer implementation using the Moe AlltoAll kernel.

    `prepare()` quantizes the activations and dispatches them through the EP
    group's all2all manager; `finalize()` combines the expert outputs through
    the same manager and copies the result into the caller's output tensor.
    """

    def __init__(
        self,
        max_num_tokens: int,
        top_k: int,
        num_experts: int,
        hidden_size: int,
        num_dispatchers: int = 1,
    ):
        """Store the sizes and initialize the shared all2all manager with them.

        Args:
            max_num_tokens: forwarded to the manager's `initialize()`.
            top_k: forwarded to the manager's `initialize()`.
            num_experts: forwarded to the manager's `initialize()`.
            hidden_size: forwarded to the manager's `initialize()`; also used
                to reshape received nvfp4 scales in `prepare()`.
            num_dispatchers: value reported by `num_dispatchers()`.
        """
        super().__init__()
        self.max_num_tokens = max_num_tokens
        self.top_k = top_k
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.num_dispatchers_ = num_dispatchers
        # Written by prepare() and read back by finalize(); None until the
        # first prepare() call.
        self.runtime_max_tokens_per_rank: int | None = None

        self.all2all_manager = get_ep_group().device_communicator.all2all_manager
        self.all2all_manager.initialize(
            max_num_tokens=self.max_num_tokens,
            top_k=self.top_k,
            num_experts=self.num_experts,
            hidden_size=self.hidden_size,
        )

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        """This implementation uses the Standard activation format."""
        return mk.FusedMoEActivationFormat.Standard

    def max_num_tokens_per_rank(self) -> int | None:
        """No per-rank token limit is reported."""
        return None

    def num_dispatchers(self) -> int:
        """Return the dispatcher count given at construction."""
        return self.num_dispatchers_

    def output_is_reduced(self) -> bool:
        """Always False for this implementation."""
        return False

    def topk_indices_dtype(self) -> torch.dtype | None:
        """Top-k ids are expected as int32."""
        return torch.int32

    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        defer_input_quant: bool = False,
    ) -> mk.PrepareResultType:
        """Quantize `a1` and dispatch it with the routing data.

        Returns:
            `(a1q_recv, a1q_scale_recv, None, topk_ids_recv, topk_weights_recv)`
            where the received tensors are flattened to 2-D and
            `a1q_scale_recv` is None when quantization produced no scale.
        """
        # Same guard finalize() applies: the manager must have been initialized.
        assert self.all2all_manager.moe_alltoall is not None

        if apply_router_weight_on_input:
            topk = topk_ids.size(1)
            assert topk == 1, (
                "apply_router_weight_on_input is only implemented for topk=1"
            )
            # In-place: the caller's `a1` is scaled by the router weights.
            a1.mul_(topk_weights.to(a1.dtype))

        global_num_tokens_cpu = get_local_sizes()
        self.runtime_max_tokens_per_rank = (
            max(global_num_tokens_cpu)
            if global_num_tokens_cpu is not None
            else a1.shape[0]
        )

        a1q, a1q_scale = moe_kernel_quantize_input(
            a1,
            quant_config.a1_gscale,
            quant_config.quant_dtype,
            quant_config.per_act_token_quant,
            quant_config.block_shape,
            is_fp4_scale_swizzled=False,  # delay swizzle to after comm
        )

        # Payload order must match the unpacking of `recv_payloads` below.
        payloads = [a1q]
        if a1q_scale is not None:
            payloads.append(a1q_scale)
        payloads.append(topk_ids)
        payloads.append(topk_weights)

        recv_payloads = self.all2all_manager.moe_alltoall.dispatch(
            token_selected_experts=topk_ids,
            input_payloads=payloads,
            runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
        )
        if a1q_scale is not None:
            a1q_recv, a1q_scale_recv, topk_ids_recv, topk_weights_recv = recv_payloads
            # Apply scale interleaving only for CUTLASS (not TRT-LLM)
            # NOTE(review): scales for any other quant dtype are returned
            # without being flattened -- confirm that is intended.
            if (
                quant_config.quant_dtype == "nvfp4"
                and quant_config.is_nvfp4_scale_swizzled
            ):
                a1q_scale_recv = a1q_scale_recv.view(-1, a1q_scale_recv.shape[-1])
                a1q_scale_recv = a1q_scale_recv.view(torch.uint8)
                a1q_scale_recv = nvfp4_block_scale_interleave(a1q_scale_recv)
                a1q_scale_recv = a1q_scale_recv.view(-1, self.hidden_size // 16)
        else:
            a1q_recv, topk_ids_recv, topk_weights_recv = recv_payloads
            a1q_scale_recv = None
        a1q_recv = a1q_recv.view(-1, a1q_recv.shape[-1])
        topk_ids_recv = topk_ids_recv.view(-1, topk_ids_recv.shape[-1])
        topk_weights_recv = topk_weights_recv.view(-1, topk_weights_recv.shape[-1])

        return a1q_recv, a1q_scale_recv, None, topk_ids_recv, topk_weights_recv

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
        """Combine the expert outputs and copy the result into `output`.

        Must run after `prepare()`, whose `runtime_max_tokens_per_rank` is
        reused to reshape `fused_expert_output` and for the combine call.
        """
        assert self.all2all_manager.moe_alltoall is not None
        assert self.runtime_max_tokens_per_rank is not None, (
            "finalize() called before prepare()"
        )

        ep_size = self.all2all_manager.world_size
        hidden_size = fused_expert_output.shape[-1]
        fused_expert_output = fused_expert_output.view(
            ep_size, self.runtime_max_tokens_per_rank, hidden_size
        )

        combined_output = self.all2all_manager.moe_alltoall.combine(
            payload=fused_expert_output,
            runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
        )
        output.copy_(combined_output)
|
||||
@@ -18,7 +18,7 @@ def get_local_sizes():
|
||||
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
|
||||
|
||||
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
class FlashInferNVLinkTwoSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
"""Base class for FlashInfer MoE prepare and finalize operations."""
|
||||
|
||||
def __init__(
|
||||
@@ -600,7 +600,10 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
|
||||
|
||||
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """Return False when either FlashInfer NVLink all2all backend is active."""
    return not (
        moe_parallel_config.use_fi_nvl_two_sided_kernels
        or moe_parallel_config.use_fi_nvl_one_sided_kernels
    )
|
||||
|
||||
@property
|
||||
def quant_type_id(self) -> int:
|
||||
|
||||
@@ -1965,7 +1965,10 @@ class TritonExperts(mk.FusedMoEExpertsModular):
|
||||
|
||||
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """Return False when either FlashInfer NVLink all2all backend is active."""
    return not (
        moe_parallel_config.use_fi_nvl_two_sided_kernels
        or moe_parallel_config.use_fi_nvl_one_sided_kernels
    )
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -638,7 +638,7 @@ class FusedMoE(CustomOp):
|
||||
self.use_overlapped = (
|
||||
not (
|
||||
(self.enable_eplb and backend != "allgather_reducescatter")
|
||||
or self.moe_parallel_config.use_fi_all2allv_kernels
|
||||
or self.moe_parallel_config.use_fi_nvl_two_sided_kernels
|
||||
)
|
||||
and self._shared_experts is not None
|
||||
)
|
||||
|
||||
@@ -332,7 +332,10 @@ class AiterExperts(mk.FusedMoEExpertsModular):
|
||||
|
||||
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
    """Return False when either FlashInfer NVLink all2all backend is active."""
    return not (
        moe_parallel_config.use_fi_nvl_two_sided_kernels
        or moe_parallel_config.use_fi_nvl_one_sided_kernels
    )
|
||||
|
||||
def supports_expert_map(self):
|
||||
return True
|
||||
|
||||
@@ -233,7 +233,7 @@ class DefaultMoERunner(MoERunner):
|
||||
return (
|
||||
self.moe_config.moe_parallel_config.use_deepep_ll_kernels
|
||||
or self.moe_config.moe_parallel_config.use_mori_kernels
|
||||
or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
|
||||
or self.moe_config.moe_parallel_config.use_fi_nvl_two_sided_kernels
|
||||
or self.moe_config.moe_parallel_config.use_nixl_ep_kernels
|
||||
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ def has_flashinfer_comm() -> bool:
|
||||
|
||||
|
||||
@functools.cache
|
||||
def has_flashinfer_all2all() -> bool:
|
||||
def has_flashinfer_nvlink_two_sided() -> bool:
|
||||
"""Return `True` if FlashInfer mnnvl all2all is available."""
|
||||
if not has_flashinfer_comm():
|
||||
return False
|
||||
@@ -170,6 +170,14 @@ def has_flashinfer_all2all() -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_nvlink_one_sided() -> bool:
    """Return `True` if FlashInfer trtllm_moe_alltoall module is available."""
    if has_flashinfer_comm():
        # Only look the submodule up once the comm package check has passed.
        spec = importlib.util.find_spec("flashinfer.comm.trtllm_moe_alltoall")
        return spec is not None
    return False
|
||||
|
||||
|
||||
@functools.cache
|
||||
def has_flashinfer_moe() -> bool:
|
||||
"""Return `True` if FlashInfer MoE module is available."""
|
||||
@@ -766,7 +774,8 @@ __all__ = [
|
||||
"autotune",
|
||||
"has_flashinfer_moe",
|
||||
"has_flashinfer_comm",
|
||||
"has_flashinfer_all2all",
|
||||
"has_flashinfer_nvlink_two_sided",
|
||||
"has_flashinfer_nvlink_one_sided",
|
||||
"has_flashinfer_cutlass_fused_moe",
|
||||
"has_flashinfer_cutedsl_grouped_gemm_nt_masked",
|
||||
"has_flashinfer_fp8_blockscale_gemm",
|
||||
|
||||
Reference in New Issue
Block a user