Support mnnvl all2allv from Flashinfer (#21003)
Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. <shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
@@ -10,9 +10,15 @@ from vllm.distributed import get_dp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import has_deep_ep, has_pplx
|
||||
from vllm.utils.flashinfer import has_flashinfer_all2all
|
||||
|
||||
from .base_device_communicator import All2AllManagerBase, Cache
|
||||
|
||||
if has_flashinfer_all2all():
|
||||
from flashinfer.comm import Mapping
|
||||
from flashinfer.comm.mnnvl import MnnvlConfig
|
||||
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -47,24 +53,22 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states, router_logits = get_dp_group().all_gatherv(
|
||||
[hidden_states, router_logits],
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
|
||||
hidden_states = self.naive_multicast(hidden_states,
|
||||
cu_tokens_across_dp_cpu)
|
||||
router_logits = self.naive_multicast(router_logits,
|
||||
cu_tokens_across_dp_cpu)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
|
||||
all_hidden_states = self.dp_group.all_reduce(hidden_states)
|
||||
hidden_states = all_hidden_states[start:end, :]
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
@@ -300,4 +304,95 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
|
||||
# DeepEP LL uses RDMA so no SMs are used for communication
|
||||
def max_sms_used(self) -> Optional[int]:
|
||||
return 0
|
||||
return 0
|
||||
|
||||
|
||||
class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on flashinfer kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
assert has_flashinfer_all2all(
|
||||
), "flashinfer all2all module not found. Please install/check flashinfer" # noqa
|
||||
super().__init__(cpu_group)
|
||||
logger.debug(
|
||||
"Initialize for flashinfer All2All "
|
||||
"rank=%d, world size=%d", self.rank, self.world_size)
|
||||
self.initialized = False
|
||||
self.alltoall_info = None
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
gpus_per_node: int,
|
||||
):
|
||||
"""Initialize workspace"""
|
||||
if self.initialized:
|
||||
return
|
||||
|
||||
self.cleanup()
|
||||
logger.debug("making map: "
|
||||
"rank=%d, world size=%d", rank, world_size)
|
||||
self.mapping = Mapping(
|
||||
world_size,
|
||||
rank,
|
||||
gpus_per_node,
|
||||
tp_size=world_size,
|
||||
)
|
||||
|
||||
from vllm.distributed.device_communicators.mnnvl_compat import (
|
||||
CustomCommunicator)
|
||||
dp_config = MnnvlConfig(
|
||||
comm_backend=CustomCommunicator(get_dp_group().cpu_group),
|
||||
fabric_page_size=1 << 29, # 512MB
|
||||
allocation_granularity=0 # Auto-detect
|
||||
)
|
||||
|
||||
self.workspace_tensor = MnnvlMoe.get_moe_workspaces(
|
||||
self.mapping, dp_config)
|
||||
self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
|
||||
self.mapping, dp_config)
|
||||
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
self.gpus_per_node = gpus_per_node
|
||||
self.initialized = True
|
||||
|
||||
logger.info("FlashInfer All2All initialized for rank %s, size %s",
|
||||
rank, world_size)
|
||||
|
||||
def ensure_alltoall_workspace_initialized(self):
|
||||
"""Ensure workspace is initialized"""
|
||||
if not has_flashinfer_all2all():
|
||||
return False
|
||||
|
||||
if self.world_size <= 1:
|
||||
return False
|
||||
|
||||
if not self.initialized:
|
||||
self.initialize(
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
gpus_per_node=torch.cuda.device_count,
|
||||
)
|
||||
return self.initialized
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
return self
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up workspace"""
|
||||
if self.initialized and self.workspace_tensor is not None \
|
||||
and self.prepare_workspace_tensor is not None:
|
||||
try:
|
||||
del self.workspace_tensor
|
||||
del self.prepare_workspace_tensor
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cleanup FlashInfer workspace: %s", e)
|
||||
finally:
|
||||
self.workspace_tensor = None
|
||||
self.prepare_workspace_tensor = None
|
||||
self.mapping = None
|
||||
self.initialized = False
|
||||
Reference in New Issue
Block a user