Support mnnvl all2allv from Flashinfer (#21003)

Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
Shu Wang
2025-09-24 13:38:16 -05:00
committed by GitHub
parent 2dda3e35d0
commit 54e42b72db
10 changed files with 410 additions and 40 deletions

View File

@@ -10,9 +10,15 @@ from vllm.distributed import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx
from vllm.utils.flashinfer import has_flashinfer_all2all
from .base_device_communicator import All2AllManagerBase, Cache
if has_flashinfer_all2all():
from flashinfer.comm import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
logger = init_logger(__name__)
@@ -47,24 +53,22 @@ class NaiveAll2AllManager(All2AllManagerBase):
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states, router_logits = get_dp_group().all_gatherv(
[hidden_states, router_logits],
dim=0,
sizes=sizes,
)
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_dp_cpu)
router_logits = self.naive_multicast(router_logits,
cu_tokens_across_dp_cpu)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]
all_hidden_states = self.dp_group.all_reduce(hidden_states)
hidden_states = all_hidden_states[start:end, :]
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
return hidden_states
def destroy(self):
@@ -300,4 +304,95 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
# DeepEP LL uses RDMA so no SMs are used for communication
def max_sms_used(self) -> Optional[int]:
return 0
return 0
class FlashInferAllToAllManager(All2AllManagerBase):
"""
All2All communication based on flashinfer kernels.
"""
def __init__(self, cpu_group):
assert has_flashinfer_all2all(
), "flashinfer all2all module not found. Please install/check flashinfer" # noqa
super().__init__(cpu_group)
logger.debug(
"Initialize for flashinfer All2All "
"rank=%d, world size=%d", self.rank, self.world_size)
self.initialized = False
self.alltoall_info = None
def initialize(
self,
world_size: int,
rank: int,
gpus_per_node: int,
):
"""Initialize workspace"""
if self.initialized:
return
self.cleanup()
logger.debug("making map: "
"rank=%d, world size=%d", rank, world_size)
self.mapping = Mapping(
world_size,
rank,
gpus_per_node,
tp_size=world_size,
)
from vllm.distributed.device_communicators.mnnvl_compat import (
CustomCommunicator)
dp_config = MnnvlConfig(
comm_backend=CustomCommunicator(get_dp_group().cpu_group),
fabric_page_size=1 << 29, # 512MB
allocation_granularity=0 # Auto-detect
)
self.workspace_tensor = MnnvlMoe.get_moe_workspaces(
self.mapping, dp_config)
self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
self.mapping, dp_config)
self.world_size = world_size
self.rank = rank
self.gpus_per_node = gpus_per_node
self.initialized = True
logger.info("FlashInfer All2All initialized for rank %s, size %s",
rank, world_size)
def ensure_alltoall_workspace_initialized(self):
"""Ensure workspace is initialized"""
if not has_flashinfer_all2all():
return False
if self.world_size <= 1:
return False
if not self.initialized:
self.initialize(
world_size=self.world_size,
rank=self.rank,
gpus_per_node=torch.cuda.device_count,
)
return self.initialized
def get_handle(self, kwargs):
return self
def cleanup(self):
"""Clean up workspace"""
if self.initialized and self.workspace_tensor is not None \
and self.prepare_workspace_tensor is not None:
try:
del self.workspace_tensor
del self.prepare_workspace_tensor
except Exception as e:
logger.warning("Failed to cleanup FlashInfer workspace: %s", e)
finally:
self.workspace_tensor = None
self.prepare_workspace_tensor = None
self.mapping = None
self.initialized = False