[AMD][ROCm] MoRI EP: a high-performance all2all backend (#28664)

Signed-off-by: Alex Sun <alex.s@amd.com>
This commit is contained in:
Alex Sun
2026-01-22 16:33:18 +08:00
committed by GitHub
parent 2b8a38b6d6
commit 49a1262267
16 changed files with 397 additions and 9 deletions

View File

@@ -10,7 +10,7 @@ from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_pplx
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
from .base_device_communicator import All2AllManagerBase, Cache
@@ -507,3 +507,96 @@ class FlashInferAllToAllManager(All2AllManagerBase):
self.prepare_workspace_tensor = None
self.mapping = None
self.initialized = False
class MoriAll2AllManager(All2AllManagerBase):
    """All2all manager backed by AMD MoRI's EP dispatch/combine kernels.

    Registers the CPU process group with MoRI's shared-memory runtime at
    construction time and hands out cached ``EpDispatchCombineOp`` handles
    keyed by their construction kwargs.
    """

    def __init__(self, cpu_group):
        # Fail fast with installation instructions when the optional MoRI
        # dependency is absent.
        assert has_mori(), (
            "MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md"
            " to install MoRI kernels."
        )  # noqa
        import mori

        super().__init__(cpu_group)
        self.handle_cache = Cache()
        # Make the CPU process group visible to MoRI under the name "mori",
        # then initialize MoRI's torch-process-group shmem layer against it.
        torch._C._distributed_c10d._register_process_group("mori", cpu_group)
        mori.shmem.shmem_torch_process_group_init("mori")

    def _make_all2all_kwargs(
        self,
        rank: int,
        num_ep_ranks: int,
        input_dtype: torch.dtype,
        quant_dtype: torch.dtype,
        token_hidden_size: int,
        scale_dim: int,
        scale_type_size: int,
        max_num_tokens_per_dp_rank: int,
        num_local_experts: int,
        num_experts_per_token: int,
    ):
        """Build the kwargs dict for ``mori.ops.EpDispatchCombineConfig``.

        Kernel launch geometry (warps/blocks and RDMA blocks) is chosen per
        GPU architecture and per intra-/inter-node topology.
        """
        import mori  # type: ignore[import-not-found]

        from vllm.platforms.rocm import on_gfx942, on_gfx950

        assert on_gfx942() or on_gfx950(), (
            "mori currently only support arch gfx942 and gfx950"
        )
        if self.internode:
            # Multi-node path: the InterNodeV1 kernel plus arch-specific
            # launch geometry, including dedicated RDMA blocks.
            kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1
            if on_gfx942():
                warp_num_per_block, block_num, rdma_block_num = 16, 32, 16
            elif on_gfx950():
                warp_num_per_block, block_num, rdma_block_num = 8, 64, 32
            else:
                raise NotImplementedError(
                    "mori currently only support arch gfx942 and gfx950"
                )
        else:
            # Single-node path: intra-node kernel, no RDMA blocks needed.
            kernel_type = mori.ops.EpDispatchCombineKernelType.IntraNode
            warp_num_per_block, block_num, rdma_block_num = 16, 80, 0
        return {
            "rank": rank,
            "world_size": num_ep_ranks,
            "data_type": quant_dtype,
            "hidden_dim": token_hidden_size,
            "scale_dim": scale_dim,
            "scale_type_size": scale_type_size,
            # Worst-case element size of the unquantized input tokens.
            "max_token_type_size": input_dtype.itemsize,
            "max_num_inp_token_per_rank": max_num_tokens_per_dp_rank,
            "num_experts_per_rank": num_local_experts,
            "num_experts_per_token": num_experts_per_token,
            "warp_num_per_block": warp_num_per_block,
            "block_num": block_num,
            "kernel_type": kernel_type,
            "rdma_block_num": rdma_block_num,
            # Capped at 8 GPUs per node; fewer when the EP world is smaller.
            "gpu_per_node": min(8, num_ep_ranks),
        }

    def _make_handle(self, **kwargs):
        """Construct a fresh MoRI dispatch/combine op from config kwargs."""
        import mori  # type: ignore[import-not-found]

        return mori.ops.EpDispatchCombineOp(
            mori.ops.EpDispatchCombineConfig(**kwargs)
        )

    def get_handle(self, kwargs):
        """Return a (possibly cached) MoRI op handle for ``kwargs``.

        ``kwargs`` are the high-level arguments of ``_make_all2all_kwargs``;
        identical configurations share one underlying handle via the cache.
        """
        import mori  # type: ignore[import-not-found]

        mori_kwargs = self._make_all2all_kwargs(**kwargs)
        logger.debug("MoRI all2all args %s", mori_kwargs)
        handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create(
            mori_kwargs, self._make_handle
        )
        return handle