Files
vllm/vllm/distributed/device_communicators/mnnvl_compat.py
leo-cf-tian 2754231ba3 [Kernel] Add FlashInfer MoE A2A Kernel (#36022)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: Leo Tian <lctian@nvidia.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Stefano Castagnetta <scastagnetta@nvidia.com>
Co-authored-by: root <root@lyris0267.lyris.clusters.nvidia.com>
2026-03-15 23:45:32 -07:00

39 lines
1.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend
from vllm.utils.flashinfer import has_flashinfer_nvlink_two_sided
# Fail fast at import time: the rest of this module is unusable without
# FlashInfer's two-sided NVLink all-to-all support.  Use an explicit raise
# rather than `assert` so the guard is not stripped under `python -O`.
if not has_flashinfer_nvlink_two_sided():
    raise ImportError("Flashinfer alltoallv module cannot be found")
class CustomCommunicator(CommBackend):
    """Adapter exposing a ``torch.distributed`` process group through the
    FlashInfer MNNVL ``CommBackend`` interface.

    The non-PEP8 method casing (``Get_rank``/``Get_size``/``Split``) mirrors
    the mpi4py-style API that FlashInfer's ``CommBackend`` expects, so it
    must not be renamed.
    """

    def __init__(self, group):
        # Process group all collectives in this backend run on.
        # NOTE(review): presumably a torch.distributed ProcessGroup — the
        # rank()/size() calls below match that interface; confirm at callers.
        self._group = group

    def Get_rank(self) -> int:
        """Return this process's rank within the wrapped group."""
        return self._group.rank()

    def Get_size(self) -> int:
        """Return the number of ranks in the wrapped group."""
        return self._group.size()

    def allgather(self, data: Any) -> list[Any]:
        """Gather ``data`` from every rank.

        Returns a list of length ``Get_size()`` indexed by rank.  Any
        picklable object is accepted (``all_gather_object`` semantics); the
        previous ``data: int`` annotation was too narrow.
        """
        gathered: list[Any] = [None] * self.Get_size()
        dist.all_gather_object(gathered, data, group=self._group)
        return gathered

    def bcast(self, data: Any, root: int) -> Any:
        """Broadcast ``data`` from rank ``root`` and return the received object."""
        obj_list = [data]
        # broadcast_object_list mutates obj_list in-place
        dist.broadcast_object_list(obj_list, src=root, group=self._group)
        return obj_list[0]

    def barrier(self) -> None:
        """Block until every rank in the group reaches this point."""
        dist.barrier(group=self._group)

    def Split(self, color: int, key: int) -> "CustomCommunicator":
        # Splitting is not supported: every caller receives the same
        # communicator regardless of color/key.
        # NOTE(review): assumes FlashInfer only ever performs a trivial
        # split here — confirm color/key never need to be honored.
        return self