[Kernel] Add FlashInfer MoE A2A Kernel (#36022)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: Leo Tian <lctian@nvidia.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Stefano Castagnetta <scastagnetta@nvidia.com>
Co-authored-by: root <root@lyris0267.lyris.clusters.nvidia.com>
leo-cf-tian
2026-03-16 02:45:32 -04:00
committed by GitHub
parent 2390d44209
commit 2754231ba3
19 changed files with 417 additions and 43 deletions

View File

@@ -45,7 +45,9 @@ All2AllBackend = Literal[
"mori",
"nixl_ep",
"allgather_reducescatter",
"flashinfer_all2allv",
"flashinfer_all2allv", # temporary alias for flashinfer_nvlink_two_sided
"flashinfer_nvlink_two_sided",
"flashinfer_nvlink_one_sided",
]
@@ -158,7 +160,8 @@ class ParallelConfig:
- "deepep_low_latency": Use deepep low-latency kernels\n
- "mori": Use mori kernels\n
- "nixl_ep": Use nixl-ep kernels\n
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
- "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl
- "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels"""
max_parallel_loading_workers: int | None = None
"""Maximum number of parallel loading workers when loading model

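A quick way to sanity-check a backend string against this Literal outside of vLLM — a minimal sketch that assumes nothing about vLLM's own validation and lists only the names visible in this hunk:

```python
from typing import Literal, get_args

# Only the members visible in the hunk above; the real All2AllBackend has more.
Backend = Literal[
    "mori",
    "nixl_ep",
    "allgather_reducescatter",
    "flashinfer_all2allv",  # temporary alias for flashinfer_nvlink_two_sided
    "flashinfer_nvlink_two_sided",
    "flashinfer_nvlink_one_sided",
]

def validate_backend(name: str) -> str:
    """Raise on names that are not in the Literal."""
    if name not in get_args(Backend):
        raise ValueError(f"Unknown all2all backend: {name}")
    return name

validate_backend("flashinfer_nvlink_one_sided")  # ok
```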
View File

@@ -4,23 +4,36 @@ import threading
from typing import Any
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.flashinfer import (
has_flashinfer_nvlink_one_sided,
has_flashinfer_nvlink_two_sided,
)
from vllm.utils.import_utils import has_deep_ep, has_mori
from .base_device_communicator import All2AllManagerBase, Cache
if has_flashinfer_all2all():
if has_flashinfer_nvlink_two_sided():
from flashinfer.comm import Mapping # type: ignore[import-not-found]
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
from flashinfer.comm.trtllm_alltoall import (
MnnvlMoe, # type: ignore[import-not-found]
)
if has_flashinfer_nvlink_one_sided():
from flashinfer.comm import Mapping # type: ignore[import-not-found]
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
from flashinfer.comm.trtllm_moe_alltoall import (
MoeAlltoAll, # type: ignore[import-not-found]
moe_a2a_get_workspace_size_per_rank,
)
logger = init_logger(__name__)
@@ -529,9 +542,9 @@ class NixlEPAll2AllManager(All2AllManagerBase):
return 0
class FlashInferAllToAllManager(All2AllManagerBase):
class FlashInferNVLinkTwoSidedManager(All2AllManagerBase):
"""
All2All communication based on flashinfer kernels.
All2All communication based on flashinfer all2allv/two-sided NVLink kernels.
"""
# This type lint could be removed after all of the work in
@@ -540,7 +553,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
world_size: int
def __init__(self, cpu_group, tcp_store_group=None):
assert has_flashinfer_all2all(), (
assert has_flashinfer_nvlink_two_sided(), (
"flashinfer all2all module not found. Please install/check flashinfer"
) # noqa
super().__init__(cpu_group, tcp_store_group)
@@ -597,7 +610,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
def ensure_alltoall_workspace_initialized(self):
"""Ensure workspace is initialized"""
if not has_flashinfer_all2all():
if not has_flashinfer_nvlink_two_sided():
return False
if self.world_size <= 1:
@@ -633,6 +646,119 @@ class FlashInferAllToAllManager(All2AllManagerBase):
self.initialized = False
class FlashInferNVLinkOneSidedManager(All2AllManagerBase):
"""
All2All communication based on FlashInfer's MoeAlltoAll/One-sided NVLink kernel.
This is a newer kernel from trtllm that should perform better than the kernel
used by flashinfer_nvlink_two_sided.
"""
rank: int
world_size: int
def __init__(self, cpu_group):
assert has_flashinfer_nvlink_one_sided(), (
"flashinfer trtllm_moe_alltoall module not found. "
"Please install/check flashinfer"
)
super().__init__(cpu_group)
logger.debug(
"Initialize FlashInfer One-sided NVLink rank=%d, world size=%d",
self.rank,
self.world_size,
)
self.initialized = False
self.moe_alltoall: MoeAlltoAll | None = None
self.mapping = None
def initialize(
self,
max_num_tokens: int,
top_k: int,
num_experts: int,
hidden_size: int,
):
"""Initialize the MoeAlltoAll workspace."""
if self.initialized:
return
self.cleanup()
gpus_per_node = torch.accelerator.device_count()
logger.debug(
"Making One-sided NVLink mapping: rank=%d, world size=%d",
self.rank,
self.world_size,
)
self.mapping = Mapping(
self.world_size,
self.rank,
gpus_per_node,
tp_size=self.world_size,
moe_ep_size=self.world_size,
)
from vllm.distributed.device_communicators.mnnvl_compat import (
CustomCommunicator,
)
dp_config = MnnvlConfig(
comm_backend=CustomCommunicator(get_dp_group().cpu_group),
)
total_dispatch_payload_size_per_token = (
hidden_size // 2 # nvfp4 hidden states
+ hidden_size // 16 # fp8 scaling factors
+ top_k * 4 # int32 topk ids
+ top_k * 4 # float32 topk weights
)
combine_payload_size_per_token = hidden_size * 2 # bf16 hidden states
self.workspace_size = moe_a2a_get_workspace_size_per_rank(
ep_size=self.world_size,
max_num_tokens=max_num_tokens,
total_dispatch_payload_size_per_token=total_dispatch_payload_size_per_token,
combine_payload_size_per_token=combine_payload_size_per_token,
)
self.moe_alltoall = MoeAlltoAll(
mapping=self.mapping,
max_num_tokens=max_num_tokens,
top_k=top_k,
num_experts=num_experts,
workspace_size_per_rank=self.workspace_size,
mnnvl_config=dp_config,
)
self.gpus_per_node = gpus_per_node
self.max_num_tokens = max_num_tokens
self.top_k = top_k
self.num_experts = num_experts
self.hidden_size = hidden_size
self.initialized = True
logger.info(
"FlashInfer One-sided NVLink initialized for rank %s, size %s",
self.rank,
self.world_size,
)
dist.barrier()
def get_handle(self, kwargs):
return self
def cleanup(self):
"""Clean up resources."""
if self.initialized and self.moe_alltoall is not None:
try:
del self.moe_alltoall
except Exception as e:
logger.warning(
"Failed to cleanup FlashInfer One-sided NVLink workspace: %s", e
)
finally:
self.moe_alltoall = None
self.mapping = None
self.initialized = False
class MoriAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
assert has_mori(), (

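The dispatch/combine byte budget above is fixed per token: nvfp4 packs two elements per byte, fp8 block scales add one byte per 16 elements, and the top-k ids/weights are 4-byte int32/float32 each, while the combine payload is bf16 (2 bytes per element). A worked check of that arithmetic with illustrative sizes (hidden_size=7168, top_k=8 are example numbers, not taken from this PR):

```python
def dispatch_bytes_per_token(hidden_size: int, top_k: int) -> int:
    """Mirror of the sizing above: nvfp4 activations + fp8 block scales + routing."""
    return (
        hidden_size // 2      # nvfp4 hidden states: 2 elements per byte
        + hidden_size // 16   # fp8 block scaling factors: 1 per 16 elements
        + top_k * 4           # int32 topk ids
        + top_k * 4           # float32 topk weights
    )

def combine_bytes_per_token(hidden_size: int) -> int:
    return hidden_size * 2    # bf16 hidden states: 2 bytes per element

# Illustrative numbers only.
assert dispatch_bytes_per_token(7168, 8) == 3584 + 448 + 32 + 32  # 4096 bytes
assert combine_bytes_per_token(7168) == 14336
```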
View File

@@ -149,12 +149,25 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.all2all_manager = NixlEPAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager
elif (
self.all2all_backend == "flashinfer_all2allv"
or self.all2all_backend == "flashinfer_nvlink_two_sided"
):
if self.all2all_backend == "flashinfer_all2allv":
logger.warning_once(
"'flashinfer_all2allv' is deprecated and has been renamed to"
"'flashinfer_nvlink_two_sided'. It will be removed in a future"
"release."
)
from .all2all import FlashInferNVLinkTwoSidedManager
self.all2all_manager = FlashInferAllToAllManager(
self.all2all_manager = FlashInferNVLinkTwoSidedManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "flashinfer_nvlink_one_sided":
from .all2all import FlashInferNVLinkOneSidedManager
self.all2all_manager = FlashInferNVLinkOneSidedManager(self.cpu_group)
else:
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")

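logger.warning_once is vLLM's helper; outside of vLLM, the same "accept the deprecated name, warn once, continue as the new name" shim can be sketched with the standard warnings module (the function name here is hypothetical):

```python
import warnings

def resolve_all2all_backend(backend: str) -> str:
    """Fold the deprecated alias into its replacement, warning at most once per call site."""
    if backend == "flashinfer_all2allv":
        warnings.warn(
            "'flashinfer_all2allv' is deprecated and has been renamed to "
            "'flashinfer_nvlink_two_sided'. It will be removed in a future release.",
            DeprecationWarning,
            stacklevel=2,
        )
        return "flashinfer_nvlink_two_sided"
    return backend

assert resolve_all2all_backend("flashinfer_all2allv") == "flashinfer_nvlink_two_sided"
```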
View File

@@ -5,9 +5,9 @@ from typing import Any
import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.flashinfer import has_flashinfer_nvlink_two_sided
assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found"
assert has_flashinfer_nvlink_two_sided(), "Flashinfer alltoallv module cannot be found"
class CustomCommunicator(CommBackend):
@@ -25,14 +25,14 @@ class CustomCommunicator(CommBackend):
dist.all_gather_object(gathered, data, group=self._group)
return gathered
# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
# are unimplemented on vLLM side. If we need to utilize these
# methods in the future, can create a concrete implementation.
def bcast(self, data: Any, root: int) -> Any:
raise NotImplementedError
obj_list = [data]
# broadcast_object_list mutates obj_list in-place
dist.broadcast_object_list(obj_list, src=root, group=self._group)
return obj_list[0]
def barrier(self) -> None:
raise NotImplementedError
dist.barrier(group=self._group)
def Split(self, color: int, key: int) -> "CustomCommunicator":
return self

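The newly filled-in bcast relies on torch.distributed.broadcast_object_list, which mutates the passed list in place instead of returning a value, hence the one-element buffer. A standalone sketch of the same call shape, run as a single-rank gloo group (the address/port are arbitrary placeholders, not values from this PR):

```python
import torch.distributed as dist

def bcast_object(obj, root: int, group=None):
    """broadcast_object_list fills the list in place; return its first element."""
    buf = [obj]
    dist.broadcast_object_list(buf, src=root, group=group)
    return buf[0]

if __name__ == "__main__":
    # Single-process group, just to exercise the call shape.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
    )
    assert bcast_object({"rank0": "payload"}, root=0) == {"rank0": "payload"}
    dist.destroy_process_group()
```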
View File

@@ -5,6 +5,7 @@ from typing import Any
import torch
from vllm.config import get_current_vllm_config
from vllm.distributed import (
get_ep_group,
)
@@ -14,8 +15,11 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import (
FlashInferA2APrepareAndFinalize,
from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_one_sided_prepare_finalize import ( # noqa: E501
FlashInferNVLinkOneSidedPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.flashinfer_nvlink_two_sided_prepare_finalize import ( # noqa: E501
FlashInferNVLinkTwoSidedPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
@@ -206,9 +210,22 @@ def maybe_make_prepare_finalize(
use_fp8_dispatch=use_fp8_dispatch,
)
elif moe.use_fi_all2allv_kernels:
elif moe.use_fi_nvl_two_sided_kernels:
assert quant_config is not None
prepare_finalize = FlashInferA2APrepareAndFinalize(
prepare_finalize = FlashInferNVLinkTwoSidedPrepareAndFinalize(
num_dispatchers=all2all_manager.world_size,
)
elif moe.use_fi_nvl_one_sided_kernels:
assert quant_config is not None
max_num_tokens = (
get_current_vllm_config().scheduler_config.max_num_batched_tokens
)
prepare_finalize = FlashInferNVLinkOneSidedPrepareAndFinalize(
max_num_tokens=max_num_tokens,
top_k=moe.experts_per_token,
num_experts=moe.num_experts,
hidden_size=moe.hidden_dim,
num_dispatchers=all2all_manager.world_size,
)

View File

@@ -957,9 +957,17 @@ class FusedMoEParallelConfig:
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
@property
def use_fi_all2allv_kernels(self):
def use_fi_nvl_two_sided_kernels(self):
return self.use_all2all_kernels and (
self.all2all_backend == "flashinfer_all2allv"
or self.all2all_backend == "flashinfer_nvlink_two_sided"
)
@property
def use_fi_nvl_one_sided_kernels(self):
return (
self.use_all2all_kernels and self.all2all_backend == "flashinfer_all2allv"
self.use_all2all_kernels
and self.all2all_backend == "flashinfer_nvlink_one_sided"
)
@property
@@ -1240,8 +1248,12 @@ class FusedMoEConfig:
return self.moe_parallel_config.use_mori_kernels
@property
def use_fi_all2allv_kernels(self):
return self.moe_parallel_config.use_fi_all2allv_kernels
def use_fi_nvl_two_sided_kernels(self):
return self.moe_parallel_config.use_fi_nvl_two_sided_kernels
@property
def use_fi_nvl_one_sided_kernels(self):
return self.moe_parallel_config.use_fi_nvl_one_sided_kernels
@property
def use_naive_all2all_kernels(self):

View File

@@ -396,8 +396,9 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
# Note that the BATCHED activation format does not use
# the expert map for identifying experts.
return not (
moe_parallel_config.use_fi_all2allv_kernels
moe_parallel_config.use_fi_nvl_two_sided_kernels
or moe_parallel_config.use_deepep_ht_kernels
or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
def supports_expert_map(self) -> bool:

View File

@@ -152,7 +152,10 @@ class DeepGemmExperts(mk.FusedMoEExpertsModular):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
# NOTE(rob): discovered an IMA with this combination. Needs investigation.
return not moe_parallel_config.use_fi_all2allv_kernels
return not (
moe_parallel_config.use_fi_nvl_two_sided_kernels
or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
def supports_expert_map(self) -> bool:
return True

View File

@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_ep_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferNVLinkOneSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""FlashInfer implementation using the Moe AlltoAll kernel."""
def __init__(
self,
max_num_tokens: int,
top_k: int,
num_experts: int,
hidden_size: int,
num_dispatchers: int = 1,
):
super().__init__()
self.max_num_tokens = max_num_tokens
self.top_k = top_k
self.num_experts = num_experts
self.hidden_size = hidden_size
self.num_dispatchers_ = num_dispatchers
self.all2all_manager = get_ep_group().device_communicator.all2all_manager
self.all2all_manager.initialize(
max_num_tokens=self.max_num_tokens,
top_k=self.top_k,
num_experts=self.num_experts,
hidden_size=self.hidden_size,
)
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> int | None:
return None
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def output_is_reduced(self) -> bool:
return False
def topk_indices_dtype(self) -> torch.dtype | None:
return torch.int32
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
a1.mul_(topk_weights.to(a1.dtype))
global_num_tokens_cpu = get_local_sizes()
self.runtime_max_tokens_per_rank = (
max(global_num_tokens_cpu)
if global_num_tokens_cpu is not None
else a1.shape[0]
)
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
is_fp4_scale_swizzled=False, # delay swizzle to after comm
)
payloads = []
payloads.append(a1q)
if a1q_scale is not None:
payloads.append(a1q_scale)
payloads.append(topk_ids)
payloads.append(topk_weights)
recv_payloads = self.all2all_manager.moe_alltoall.dispatch(
token_selected_experts=topk_ids,
input_payloads=payloads,
runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
)
if a1q_scale is not None:
a1q_recv, a1q_scale_recv, topk_ids_recv, topk_weights_recv = recv_payloads
# Apply scale interleaving only for CUTLASS (not TRT-LLM)
if (
quant_config.quant_dtype == "nvfp4"
and quant_config.is_nvfp4_scale_swizzled
):
a1q_scale_recv = a1q_scale_recv.view(-1, a1q_scale_recv.shape[-1])
a1q_scale_recv = a1q_scale_recv.view(torch.uint8)
a1q_scale_recv = nvfp4_block_scale_interleave(a1q_scale_recv)
a1q_scale_recv = a1q_scale_recv.view(-1, self.hidden_size // 16)
else:
a1q_recv, topk_ids_recv, topk_weights_recv = recv_payloads
a1q_scale_recv = None
a1q_recv = a1q_recv.view(-1, a1q_recv.shape[-1])
topk_ids_recv = topk_ids_recv.view(-1, topk_ids_recv.shape[-1])
topk_weights_recv = topk_weights_recv.view(-1, topk_weights_recv.shape[-1])
return a1q_recv, a1q_scale_recv, None, topk_ids_recv, topk_weights_recv
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
assert self.all2all_manager.moe_alltoall is not None
ep_size = self.all2all_manager.world_size
hidden_size = fused_expert_output.shape[-1]
fused_expert_output = fused_expert_output.view(
ep_size, self.runtime_max_tokens_per_rank, hidden_size
)
combined_output = self.all2all_manager.moe_alltoall.combine(
payload=fused_expert_output,
runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
)
output.copy_(combined_output)

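The prepare/finalize pair above keeps a simple shape contract: after dispatch each payload is flattened to (ep_size * runtime_max_tokens_per_rank, ...) for the experts, and before combine the expert output is regrouped by source rank. A communication-free sketch of just that reshape bookkeeping, with made-up sizes:

```python
import torch

ep_size, runtime_max_tokens_per_rank, hidden_size = 4, 8, 64  # illustrative only

# After dispatch: received hidden states arrive padded per source rank and are
# flattened into a single 2D view for the experts.
recv = torch.randn(ep_size, runtime_max_tokens_per_rank, hidden_size)
a1q_recv = recv.view(-1, recv.shape[-1])  # (ep_size * max_tokens, hidden)

# Before combine: regroup the expert output by source rank so each slice can be
# routed back to the rank its tokens came from.
fused_expert_output = a1q_recv  # stand-in for the experts' output
regrouped = fused_expert_output.view(ep_size, runtime_max_tokens_per_rank, hidden_size)
assert regrouped.shape == (ep_size, runtime_max_tokens_per_rank, hidden_size)
```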
View File

@@ -18,7 +18,7 @@ def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
class FlashInferNVLinkTwoSidedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""Base class for FlashInfer MoE prepare and finalize operations."""
def __init__(

View File

@@ -600,7 +600,10 @@ class MarlinExpertsBase(mk.FusedMoEExpertsModular):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return not moe_parallel_config.use_fi_all2allv_kernels
return not (
moe_parallel_config.use_fi_nvl_two_sided_kernels
or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
@property
def quant_type_id(self) -> int:

View File

@@ -1965,7 +1965,10 @@ class TritonExperts(mk.FusedMoEExpertsModular):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return not moe_parallel_config.use_fi_all2allv_kernels
return not (
moe_parallel_config.use_fi_nvl_two_sided_kernels
or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
def supports_expert_map(self) -> bool:
return True

View File

@@ -638,7 +638,7 @@ class FusedMoE(CustomOp):
self.use_overlapped = (
not (
(self.enable_eplb and backend != "allgather_reducescatter")
or self.moe_parallel_config.use_fi_all2allv_kernels
or self.moe_parallel_config.use_fi_nvl_two_sided_kernels
)
and self._shared_experts is not None
)

View File

@@ -332,7 +332,10 @@ class AiterExperts(mk.FusedMoEExpertsModular):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return not moe_parallel_config.use_fi_all2allv_kernels
return not (
moe_parallel_config.use_fi_nvl_two_sided_kernels
or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
def supports_expert_map(self):
return True

View File

@@ -233,7 +233,7 @@ class DefaultMoERunner(MoERunner):
return (
self.moe_config.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.moe_parallel_config.use_mori_kernels
or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
or self.moe_config.moe_parallel_config.use_fi_nvl_two_sided_kernels
or self.moe_config.moe_parallel_config.use_nixl_ep_kernels
) and envs.VLLM_ENABLE_MOE_DP_CHUNK

View File

@@ -150,7 +150,7 @@ def has_flashinfer_comm() -> bool:
@functools.cache
def has_flashinfer_all2all() -> bool:
def has_flashinfer_nvlink_two_sided() -> bool:
"""Return `True` if FlashInfer mnnvl all2all is available."""
if not has_flashinfer_comm():
return False
@@ -170,6 +170,14 @@ def has_flashinfer_all2all() -> bool:
return True
@functools.cache
def has_flashinfer_nvlink_one_sided() -> bool:
"""Return `True` if FlashInfer trtllm_moe_alltoall module is available."""
if not has_flashinfer_comm():
return False
return importlib.util.find_spec("flashinfer.comm.trtllm_moe_alltoall") is not None
@functools.cache
def has_flashinfer_moe() -> bool:
"""Return `True` if FlashInfer MoE module is available."""
@@ -766,7 +774,8 @@ __all__ = [
"autotune",
"has_flashinfer_moe",
"has_flashinfer_comm",
"has_flashinfer_all2all",
"has_flashinfer_nvlink_two_sided",
"has_flashinfer_nvlink_one_sided",
"has_flashinfer_cutlass_fused_moe",
"has_flashinfer_cutedsl_grouped_gemm_nt_masked",
"has_flashinfer_fp8_blockscale_gemm",
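These probes only report whether the optional module can be found; call sites still guard their imports at module scope, as in the all2all.py hunk above. A self-contained sketch of that pattern (the helper name here is hypothetical and skips the has_flashinfer_comm() pre-check the real function performs):

```python
import functools
import importlib.util

@functools.cache
def _has_one_sided_kernels() -> bool:
    """True only if flashinfer's trtllm_moe_alltoall submodule is importable."""
    try:
        return (
            importlib.util.find_spec("flashinfer.comm.trtllm_moe_alltoall") is not None
        )
    except ModuleNotFoundError:  # flashinfer (or flashinfer.comm) not installed
        return False

# Module-level guarded import, mirroring the pattern in all2all.py.
if _has_one_sided_kernels():
    from flashinfer.comm.trtllm_moe_alltoall import (  # type: ignore[import-not-found]
        MoeAlltoAll,
        moe_a2a_get_workspace_size_per_rank,
    )
```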