[Kernel] Add FlashInfer MoE A2A Kernel (#36022)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: Leo Tian <lctian@nvidia.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Stefano Castagnetta <scastagnetta@nvidia.com>
Co-authored-by: root <root@lyris0267.lyris.clusters.nvidia.com>
This commit is contained in:
@@ -4,23 +4,36 @@ import threading
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.flashinfer import has_flashinfer_all2all
|
||||
from vllm.utils.flashinfer import (
|
||||
has_flashinfer_nvlink_one_sided,
|
||||
has_flashinfer_nvlink_two_sided,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori
|
||||
|
||||
from .base_device_communicator import All2AllManagerBase, Cache
|
||||
|
||||
if has_flashinfer_all2all():
|
||||
if has_flashinfer_nvlink_two_sided():
|
||||
from flashinfer.comm import Mapping # type: ignore[import-not-found]
|
||||
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
|
||||
from flashinfer.comm.trtllm_alltoall import (
|
||||
MnnvlMoe, # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
if has_flashinfer_nvlink_one_sided():
|
||||
from flashinfer.comm import Mapping # type: ignore[import-not-found]
|
||||
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
|
||||
from flashinfer.comm.trtllm_moe_alltoall import (
|
||||
MoeAlltoAll, # type: ignore[import-not-found]
|
||||
moe_a2a_get_workspace_size_per_rank,
|
||||
)
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -529,9 +542,9 @@ class NixlEPAll2AllManager(All2AllManagerBase):
|
||||
return 0
|
||||
|
||||
|
||||
class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
class FlashInferNVLinkTwoSidedManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on flashinfer kernels.
|
||||
All2All communication based on flashinfer all2allv/two-sided NVLink kernels.
|
||||
"""
|
||||
|
||||
# This type lint could be removed after all of the work in
|
||||
@@ -540,7 +553,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
assert has_flashinfer_all2all(), (
|
||||
assert has_flashinfer_nvlink_two_sided(), (
|
||||
"flashinfer all2all module not found. Please install/check flashinfer"
|
||||
) # noqa
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
@@ -597,7 +610,7 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
|
||||
def ensure_alltoall_workspace_initialized(self):
|
||||
"""Ensure workspace is initialized"""
|
||||
if not has_flashinfer_all2all():
|
||||
if not has_flashinfer_nvlink_two_sided():
|
||||
return False
|
||||
|
||||
if self.world_size <= 1:
|
||||
@@ -633,6 +646,119 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
self.initialized = False
|
||||
|
||||
|
||||
class FlashInferNVLinkOneSidedManager(All2AllManagerBase):
    """
    All2All communication based on FlashInfer's MoeAlltoAll/One-sided NVLink kernel.
    This is a newer kernel from trtllm that should perform better than the kernel
    used by flashinfer_nvlink_two_sided.
    """

    # Populated by the base class; declared here for type checkers.
    rank: int
    world_size: int

    def __init__(self, cpu_group):
        """Create the manager; the heavy workspace setup is deferred to initialize()."""
        assert has_flashinfer_nvlink_one_sided(), (
            "flashinfer trtllm_moe_alltoall module not found. "
            "Please install/check flashinfer"
        )
        super().__init__(cpu_group)
        logger.debug(
            "Initialize FlashInfer One-sided NVLink rank=%d, world size=%d",
            self.rank,
            self.world_size,
        )
        # Lazy-initialized state: no workspace exists until initialize() runs.
        self.initialized = False
        self.moe_alltoall: MoeAlltoAll | None = None
        self.mapping = None

    def initialize(
        self,
        max_num_tokens: int,
        top_k: int,
        num_experts: int,
        hidden_size: int,
    ):
        """Initialize the MoeAlltoAll workspace."""
        # Idempotent: repeated calls after a successful setup are no-ops.
        if self.initialized:
            return

        self.cleanup()

        device_count = torch.accelerator.device_count()
        logger.debug(
            "Making One-sided NVLink mapping: rank=%d, world size=%d",
            self.rank,
            self.world_size,
        )
        # EP spans the whole world: tp_size == moe_ep_size == world_size.
        self.mapping = Mapping(
            self.world_size,
            self.rank,
            device_count,
            tp_size=self.world_size,
            moe_ep_size=self.world_size,
        )

        from vllm.distributed.device_communicators.mnnvl_compat import (
            CustomCommunicator,
        )

        # Bootstrap MNNVL over the DP group's CPU process group.
        mnnvl_cfg = MnnvlConfig(
            comm_backend=CustomCommunicator(get_dp_group().cpu_group),
        )

        # Per-token byte budget for the dispatch side of the kernel.
        dispatch_bytes_per_token = (
            hidden_size // 2  # nvfp4 hidden states
            + hidden_size // 16  # fp8 scaling factors
            + top_k * 4  # int32 topks ids
            + top_k * 4  # float32 topk weights
        )
        # Combine returns bf16 activations.
        combine_bytes_per_token = hidden_size * 2  # bf16 hidden states
        self.workspace_size = moe_a2a_get_workspace_size_per_rank(
            ep_size=self.world_size,
            max_num_tokens=max_num_tokens,
            total_dispatch_payload_size_per_token=dispatch_bytes_per_token,
            combine_payload_size_per_token=combine_bytes_per_token,
        )

        self.moe_alltoall = MoeAlltoAll(
            mapping=self.mapping,
            max_num_tokens=max_num_tokens,
            top_k=top_k,
            num_experts=num_experts,
            workspace_size_per_rank=self.workspace_size,
            mnnvl_config=mnnvl_cfg,
        )

        # Remember the configuration the workspace was sized for.
        self.gpus_per_node = device_count
        self.max_num_tokens = max_num_tokens
        self.top_k = top_k
        self.num_experts = num_experts
        self.hidden_size = hidden_size
        self.initialized = True

        logger.info(
            "FlashInfer One-sided NVLink initialized for rank %s, size %s",
            self.rank,
            self.world_size,
        )
        # Make sure every rank finishes setup before any rank starts using it.
        dist.barrier()

    def get_handle(self, kwargs):
        """The manager itself is the kernel handle."""
        return self

    def cleanup(self):
        """Clean up resources."""
        # Nothing to release unless a workspace was actually created.
        if not (self.initialized and self.moe_alltoall is not None):
            return
        try:
            del self.moe_alltoall
        except Exception as e:
            logger.warning(
                "Failed to cleanup FlashInfer One-sided NVLink workspace: %s", e
            )
        finally:
            # Reset all lazy-init state regardless of whether the delete raised.
            self.moe_alltoall = None
            self.mapping = None
            self.initialized = False
|
||||
|
||||
|
||||
class MoriAll2AllManager(All2AllManagerBase):
|
||||
def __init__(self, cpu_group):
|
||||
assert has_mori(), (
|
||||
|
||||
Reference in New Issue
Block a user