[Kernel] Add FlashInfer MoE A2A Kernel (#36022)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: Leo Tian <lctian@nvidia.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Stefano Castagnetta <scastagnetta@nvidia.com>
Co-authored-by: root <root@lyris0267.lyris.clusters.nvidia.com>
This commit is contained in:
leo-cf-tian
2026-03-16 02:45:32 -04:00
committed by GitHub
parent 2390d44209
commit 2754231ba3
19 changed files with 417 additions and 43 deletions

View File

@@ -4,23 +4,36 @@ import threading
from typing import Any
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.flashinfer import (
has_flashinfer_nvlink_one_sided,
has_flashinfer_nvlink_two_sided,
)
from vllm.utils.import_utils import has_deep_ep, has_mori
from .base_device_communicator import All2AllManagerBase, Cache
if has_flashinfer_all2all():
if has_flashinfer_nvlink_two_sided():
from flashinfer.comm import Mapping # type: ignore[import-not-found]
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
from flashinfer.comm.trtllm_alltoall import (
MnnvlMoe, # type: ignore[import-not-found]
)
if has_flashinfer_nvlink_one_sided():
from flashinfer.comm import Mapping # type: ignore[import-not-found]
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
from flashinfer.comm.trtllm_moe_alltoall import (
MoeAlltoAll, # type: ignore[import-not-found]
moe_a2a_get_workspace_size_per_rank,
)
logger = init_logger(__name__)
@@ -529,9 +542,9 @@ class NixlEPAll2AllManager(All2AllManagerBase):
return 0
class FlashInferAllToAllManager(All2AllManagerBase):
class FlashInferNVLinkTwoSidedManager(All2AllManagerBase):
"""
All2All communication based on flashinfer kernels.
All2All communication based on flashinfer all2allv/two-sided NVLink kernels.
"""
# This type lint could be removed after all of the work in
@@ -540,7 +553,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
world_size: int
def __init__(self, cpu_group, tcp_store_group=None):
assert has_flashinfer_all2all(), (
assert has_flashinfer_nvlink_two_sided(), (
"flashinfer all2all module not found. Please install/check flashinfer"
) # noqa
super().__init__(cpu_group, tcp_store_group)
@@ -597,7 +610,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
def ensure_alltoall_workspace_initialized(self):
"""Ensure workspace is initialized"""
if not has_flashinfer_all2all():
if not has_flashinfer_nvlink_two_sided():
return False
if self.world_size <= 1:
@@ -633,6 +646,119 @@ class FlashInferAllToAllManager(All2AllManagerBase):
self.initialized = False
class FlashInferNVLinkOneSidedManager(All2AllManagerBase):
    """
    All2All communication based on FlashInfer's MoeAlltoAll/One-sided NVLink kernel.
    This is a newer kernel from trtllm that should perform better than the kernel
    used by flashinfer_nvlink_two_sided.
    """

    # EP rank of this process and total rank count; populated by the base class.
    rank: int
    world_size: int

    def __init__(self, cpu_group):
        """Verify kernel availability and set up lazy-initialization state.

        Args:
            cpu_group: CPU process group forwarded to All2AllManagerBase.
        """
        assert has_flashinfer_nvlink_one_sided(), (
            "flashinfer trtllm_moe_alltoall module not found. "
            "Please install/check flashinfer"
        )
        super().__init__(cpu_group)
        logger.debug(
            "Initialize FlashInfer One-sided NVLink rank=%d, world size=%d",
            self.rank,
            self.world_size,
        )
        # Workspace allocation is deferred to initialize(), which needs runtime
        # sizes (max tokens, top_k, expert count, hidden size) not known here.
        self.initialized = False
        self.moe_alltoall: MoeAlltoAll | None = None
        self.mapping = None

    def initialize(
        self,
        max_num_tokens: int,
        top_k: int,
        num_experts: int,
        hidden_size: int,
    ):
        """Initialize the MoeAlltoAll workspace.

        Idempotent: returns immediately if already initialized. Otherwise tears
        down any stale state, sizes the per-rank workspace from the dispatch and
        combine payload footprints, and constructs the MoeAlltoAll handle.

        Args:
            max_num_tokens: Maximum number of tokens per rank per forward pass.
            top_k: Number of experts selected per token.
            num_experts: Total number of experts across all ranks.
            hidden_size: Model hidden dimension, used to size payload buffers.
        """
        if self.initialized:
            return
        # Drop any partially-constructed state before (re)building the workspace.
        self.cleanup()
        # Accelerators visible on this node; used to build the flashinfer Mapping.
        gpus_per_node = torch.accelerator.device_count()
        logger.debug(
            "Making One-sided NVLink mapping: rank=%d, world size=%d",
            self.rank,
            self.world_size,
        )
        # Pure expert-parallel layout: both tp_size and moe_ep_size span the
        # full world size.
        self.mapping = Mapping(
            self.world_size,
            self.rank,
            gpus_per_node,
            tp_size=self.world_size,
            moe_ep_size=self.world_size,
        )
        from vllm.distributed.device_communicators.mnnvl_compat import (
            CustomCommunicator,
        )

        # Bridge vLLM's DP CPU group into flashinfer's MNNVL config.
        dp_config = MnnvlConfig(
            comm_backend=CustomCommunicator(get_dp_group().cpu_group),
        )
        # Per-token bytes sent during dispatch; the breakdown below assumes
        # nvfp4 activations with fp8 block scales — TODO confirm this matches
        # the quantization scheme actually used by the dispatch path.
        total_dispatch_payload_size_per_token = (
            hidden_size // 2  # nvfp4 hidden states
            + hidden_size // 16  # fp8 scaling factors
            + top_k * 4  # int32 topk ids
            + top_k * 4  # float32 topk weights
        )
        combine_payload_size_per_token = hidden_size * 2  # bf16 hidden states
        self.workspace_size = moe_a2a_get_workspace_size_per_rank(
            ep_size=self.world_size,
            max_num_tokens=max_num_tokens,
            total_dispatch_payload_size_per_token=total_dispatch_payload_size_per_token,
            combine_payload_size_per_token=combine_payload_size_per_token,
        )
        self.moe_alltoall = MoeAlltoAll(
            mapping=self.mapping,
            max_num_tokens=max_num_tokens,
            top_k=top_k,
            num_experts=num_experts,
            workspace_size_per_rank=self.workspace_size,
            mnnvl_config=dp_config,
        )
        # Keep the sizing parameters around for introspection/debugging.
        self.gpus_per_node = gpus_per_node
        self.max_num_tokens = max_num_tokens
        self.top_k = top_k
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.initialized = True
        logger.info(
            "FlashInfer One-sided NVLink initialized for rank %s, size %s",
            self.rank,
            self.world_size,
        )
        # Synchronize so no rank proceeds before every rank has a workspace.
        # NOTE(review): this barriers on the default process group, not
        # cpu_group — confirm the default group matches the EP membership.
        dist.barrier()

    def get_handle(self, kwargs):
        # The manager itself is the communication handle; kwargs are unused.
        return self

    def cleanup(self):
        """Clean up resources."""
        if self.initialized and self.moe_alltoall is not None:
            try:
                # Drop our reference so the kernel workspace can be released.
                del self.moe_alltoall
            except Exception as e:
                logger.warning(
                    "Failed to cleanup FlashInfer One-sided NVLink workspace: %s", e
                )
            finally:
                # Reset state unconditionally so initialize() can run again.
                self.moe_alltoall = None
                self.mapping = None
                self.initialized = False
class MoriAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
assert has_mori(), (