[AMD][ROCm] MoRI EP: a high-performance all2all backend (#28664)

Signed-off-by: Alex Sun <alex.s@amd.com>
This commit is contained in:
Alex Sun
2026-01-22 16:33:18 +08:00
committed by GitHub
parent 2b8a38b6d6
commit 49a1262267
16 changed files with 397 additions and 9 deletions

View File

@@ -141,7 +141,7 @@ def make_config(args: argparse.Namespace) -> Config:
quant_config = None
if args.quant_dtype is not None:
quant_config = FusedMoEQuantConfig(
quant_config = FusedMoEQuantConfig.make(
quant_dtype=args.quant_dtype,
per_act_token_quant=args.per_token_quantized_activations,
per_out_ch_quant=args.per_channel_quantized_weights,

View File

@@ -28,7 +28,13 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.import_utils import (
has_aiter,
has_deep_ep,
has_deep_gemm,
has_mori,
has_pplx,
)
from .mk_objects import (
TestMoEQuantConfig,
@@ -211,6 +217,14 @@ class Config:
or info.backend == "deepep_low_latency"
)
def needs_aiter(self):
    """Whether the selected fused-experts implementation requires AITER."""
    return expert_info(self.fused_experts_type).needs_aiter
def needs_mori(self):
    """Whether the selected prepare/finalize type uses the MoRI backend."""
    return prepare_finalize_info(self.prepare_finalize_type).backend == "mori"
def all2all_backend(self):
    """Return the all2all backend name for the configured prepare/finalize type."""
    return prepare_finalize_info(self.prepare_finalize_type).backend
@@ -278,6 +292,10 @@ class Config:
return False, "Needs DeepGEMM, but DeepGEMM not available."
if self.needs_pplx() and not has_pplx(): # noqa: SIM103
return False, "Needs PPLX, but PPLX not available."
if self.needs_aiter() and not has_aiter(): # noqa: SIM103
return False, "Needs Aiter, but Aiter not available."
if self.needs_mori() and not has_mori(): # noqa: SIM103
return False, "Needs MoRI, but MoRI not available."
return True, None

View File

@@ -37,7 +37,13 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.import_utils import (
has_aiter,
has_deep_ep,
has_deep_gemm,
has_mori,
has_pplx,
)
@dataclass
@@ -66,6 +72,7 @@ class ExpertInfo:
supports_expert_map: bool
needs_matching_quant: bool = False
needs_deep_gemm: bool = False
needs_aiter: bool = False
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize, PrepareFinalizeInfo] = {}
@@ -126,6 +133,7 @@ def register_experts(
supports_expert_map: bool,
needs_matching_quant: bool = False,
needs_deep_gemm: bool = False,
needs_aiter: bool = False,
):
global EXPERT_INFO
global MK_FUSED_EXPERT_TYPES
@@ -139,6 +147,7 @@ def register_experts(
supports_expert_map,
needs_matching_quant,
needs_deep_gemm,
needs_aiter,
)
MK_FUSED_EXPERT_TYPES.append(kind)
@@ -218,6 +227,20 @@ if has_deep_ep() and not current_platform.has_device_capability(100):
backend="deepep_low_latency",
)
if has_mori():
from vllm.model_executor.layers.fused_moe.mori_prepare_finalize import (
MoriPrepareAndFinalize,
)
register_prepare_and_finalize(
MoriPrepareAndFinalize,
standard_format,
fp8_types,
blocked_quantization_support=True,
backend="mori",
supports_apply_weight_on_input=False,
)
if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize,
@@ -261,6 +284,25 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability
)
else:
FlashInferCutlassMoEPrepareAndFinalize = None
FlashInferExperts = None
if has_aiter():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
register_experts(
AiterExperts,
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_aiter=True,
)
else:
AiterExperts = None
if has_deep_gemm() and is_deep_gemm_supported():
register_experts(
@@ -316,6 +358,9 @@ if cutlass_fp8_supported():
supports_chunking=False,
supports_expert_map=False,
)
else:
CutlassBatchedExpertsFp8 = None
CutlassExpertsFp8 = None
if cutlass_fp4_supported():
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp4
@@ -328,6 +373,8 @@ if cutlass_fp4_supported():
supports_chunking=True,
supports_expert_map=False,
)
else:
CutlassExpertsFp4 = None
MK_QUANT_CONFIGS: list[TestMoEQuantConfig | None] = [
None,

View File

@@ -79,6 +79,8 @@ def _rocm_aiter_fused_moe_impl(
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None,
) -> torch.Tensor:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
@@ -100,6 +102,8 @@ def _rocm_aiter_fused_moe_impl(
w2_scale,
a1_scale,
a2_scale,
num_local_tokens=num_local_tokens,
dtype=output_dtype,
)
@@ -117,7 +121,11 @@ def _rocm_aiter_fused_moe_fake(
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None,
) -> torch.Tensor:
if output_dtype is not None:
return torch.empty_like(hidden_states, dtype=output_dtype)
return torch.empty_like(hidden_states)
@@ -1236,6 +1244,8 @@ class rocm_aiter_ops:
w2_scale: torch.Tensor | None = None,
a1_scale: torch.Tensor | None = None,
a2_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_fused_moe(
hidden_states,
@@ -1251,6 +1261,8 @@ class rocm_aiter_ops:
w2_scale,
a1_scale,
a2_scale,
num_local_tokens,
output_dtype,
)
@staticmethod

View File

@@ -43,6 +43,7 @@ All2AllBackend = Literal[
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"mori",
"allgather_reducescatter",
"flashinfer_all2allv",
]
@@ -158,6 +159,7 @@ class ParallelConfig:
- "pplx": Use pplx kernels\n
- "deepep_high_throughput": Use deepep high-throughput kernels\n
- "deepep_low_latency": Use deepep low-latency kernels\n
- "mori": Use mori kernels\n
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
max_parallel_loading_workers: int | None = None
@@ -443,6 +445,7 @@ class ParallelConfig:
"naive",
"deepep_high_throughput",
"deepep_low_latency",
"mori",
)
and self.enable_expert_parallel
and self.tensor_parallel_size > 1

View File

@@ -10,7 +10,7 @@ from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_pplx
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
from .base_device_communicator import All2AllManagerBase, Cache
@@ -507,3 +507,96 @@ class FlashInferAllToAllManager(All2AllManagerBase):
self.prepare_workspace_tensor = None
self.mapping = None
self.initialized = False
class MoriAll2AllManager(All2AllManagerBase):
    """All2All manager backed by MoRI (ROCm) expert-parallel dispatch/combine
    kernels.

    Registers the CPU process group with mori's shmem layer at construction
    time and hands out cached ``EpDispatchCombineOp`` handles keyed by the
    kwargs built in ``_make_all2all_kwargs``.
    """

    def __init__(self, cpu_group):
        # Fail fast with an install pointer when the optional `mori`
        # package is missing.
        assert has_mori(), (
            "MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md"
            " to install MoRI kernels."
        )  # noqa
        import mori

        super().__init__(cpu_group)
        # Cache of EpDispatchCombineOp handles, keyed by kwargs
        # (see get_handle()).
        self.handle_cache = Cache()

        # Register the CPU process group under the name "mori" so mori's
        # shmem layer can initialize over the same torch.distributed group.
        torch._C._distributed_c10d._register_process_group("mori", cpu_group)
        mori.shmem.shmem_torch_process_group_init("mori")

    def _make_all2all_kwargs(
        self,
        rank: int,
        num_ep_ranks: int,
        input_dtype: torch.dtype,
        quant_dtype: torch.dtype,
        token_hidden_size: int,
        scale_dim: int,
        scale_type_size: int,
        max_num_tokens_per_dp_rank: int,
        num_local_experts: int,
        num_experts_per_token: int,
    ):
        """Build the kwargs dict for ``mori.ops.EpDispatchCombineConfig``.

        Args:
            rank: EP rank of this process.
            num_ep_ranks: total number of EP ranks (world size).
            input_dtype: dtype of the un-quantized activations; only its
                ``itemsize`` is forwarded (``max_token_type_size``).
            quant_dtype: dtype the tokens are dispatched in.
            token_hidden_size: hidden dimension of one token.
            scale_dim: number of scale elements per token.
            scale_type_size: byte size of one scale element.
            max_num_tokens_per_dp_rank: dispatch buffer capacity per rank.
            num_local_experts: experts resident on each rank.
            num_experts_per_token: top-k experts routed per token.
        """
        import mori  # type: ignore[import-not-found]

        from vllm.platforms.rocm import on_gfx942, on_gfx950

        assert on_gfx942() or on_gfx950(), (
            "mori currently only support arch gfx942 and gfx950"
        )

        # Kernel type and launch geometry are arch/topology specific.
        # NOTE(review): self.internode is presumably set by
        # All2AllManagerBase from the process-group topology — confirm.
        if not self.internode:
            # single node
            kernel_type = mori.ops.EpDispatchCombineKernelType.IntraNode
            rdma_block_num = 0  # no RDMA blocks needed intra-node
            warp_num_per_block = 16
            block_num = 80
        else:
            # multi node
            kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1
            if on_gfx942():
                warp_num_per_block = 16
                block_num = 32
                rdma_block_num = 16
            elif on_gfx950():
                warp_num_per_block = 8
                block_num = 64
                rdma_block_num = 32
            else:
                # Unreachable given the assert above; kept as a guard.
                raise NotImplementedError(
                    "mori currently only support arch gfx942 and gfx950"
                )

        return dict(
            rank=rank,
            world_size=num_ep_ranks,
            data_type=quant_dtype,
            hidden_dim=token_hidden_size,
            scale_dim=scale_dim,
            scale_type_size=scale_type_size,
            # Byte size of one element of the pre-quantization dtype.
            max_token_type_size=input_dtype.itemsize,
            max_num_inp_token_per_rank=max_num_tokens_per_dp_rank,
            num_experts_per_rank=num_local_experts,
            num_experts_per_token=num_experts_per_token,
            warp_num_per_block=warp_num_per_block,
            block_num=block_num,
            kernel_type=kernel_type,
            rdma_block_num=rdma_block_num,
            # Assumes up to 8 GPUs per node — TODO confirm for other
            # node configurations.
            gpu_per_node=min(8, num_ep_ranks),
        )

    def _make_handle(self, **kwargs):
        """Construct a fresh ``EpDispatchCombineOp`` from config kwargs."""
        import mori  # type: ignore[import-not-found]

        mori_config = mori.ops.EpDispatchCombineConfig(**kwargs)
        handle = mori.ops.EpDispatchCombineOp(mori_config)
        return handle

    def get_handle(self, kwargs):
        """Return a (cached) MoRI dispatch/combine op for `kwargs`.

        `kwargs` are the high-level arguments accepted by
        ``_make_all2all_kwargs``; identical kwargs reuse the same handle
        via the cache.
        """
        import mori  # type: ignore[import-not-found]

        mori_kwargs = self._make_all2all_kwargs(**kwargs)

        logger.debug("MoRI all2all args %s", mori_kwargs)

        handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create(
            mori_kwargs, self._make_handle
        )
        return handle

View File

@@ -110,6 +110,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
elif self.all2all_backend == "mori":
from .all2all import MoriAll2AllManager
self.all2all_manager = MoriAll2AllManager(self.cpu_group)
elif self.all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager

View File

@@ -187,6 +187,7 @@ if TYPE_CHECKING:
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"mori",
"allgather_reducescatter",
"flashinfer_all2allv",
] = "allgather_reducescatter"
@@ -1298,6 +1299,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "pplx": use pplx kernels
# - "deepep_high_throughput", use deepep high-throughput kernels
# - "deepep_low_latency", use deepep low-latency kernels
# - "mori", use MoRI kernels
# - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl
"VLLM_ALL2ALL_BACKEND": env_with_choices(
"VLLM_ALL2ALL_BACKEND",
@@ -1307,6 +1309,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"mori",
"allgather_reducescatter",
"flashinfer_all2allv",
],

View File

@@ -88,6 +88,9 @@ if HAS_TRITON:
fused_experts,
get_config_file_name,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
fused_topk,
)
@@ -99,6 +102,7 @@ if HAS_TRITON:
)
__all__ += [
"AiterExperts",
"fused_topk",
"fused_experts",
"get_config_file_name",

View File

@@ -16,7 +16,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_pplx
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
if current_platform.is_cuda_alike():
if has_pplx():
@@ -30,6 +30,8 @@ if current_platform.is_cuda_alike():
DEEPEP_QUANT_BLOCK_SHAPE,
DeepEPLLPrepareAndFinalize,
)
if has_mori():
from .mori_prepare_finalize import MoriPrepareAndFinalize
def maybe_roundup_layer_hidden_size(
@@ -169,5 +171,36 @@ def maybe_make_prepare_finalize(
physical_to_global=physical_to_global,
local_expert_global_ids=local_expert_global_ids,
)
elif moe.use_mori_kernels:
assert quant_config is not None
# Note: We may want to use FP8 dispatch just to reduce
# data movement.
use_fp8_dispatch = (
quant_config.is_per_act_token or quant_config.is_block_quantized
)
# For PTPC (per token per channel) quant, the scale dim for each token is 1
# For 1x128 quant, the scale dim for each token is hidden_dim // 128
scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
all_to_all_args = dict(
rank=all2all_manager.rank,
num_ep_ranks=all2all_manager.world_size,
quant_dtype=quant_config.quant_dtype,
token_hidden_size=moe.hidden_dim,
scale_dim=scale_dim,
scale_type_size=torch.float32.itemsize,
max_num_tokens_per_dp_rank=moe.max_num_tokens,
input_dtype=moe.in_dtype,
num_local_experts=moe.num_experts // all2all_manager.world_size,
num_experts_per_token=moe.experts_per_token,
)
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = MoriPrepareAndFinalize(
handle,
max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch,
)
return prepare_finalize

View File

@@ -893,6 +893,10 @@ class FusedMoEParallelConfig:
self.all2all_backend in ["naive", "allgather_reducescatter"]
)
@property
def use_mori_kernels(self):
    """True when all2all kernels are in use and the backend is MoRI."""
    if not self.use_all2all_kernels:
        return False
    return self.all2all_backend == "mori"
@staticmethod
def flatten_tp_across_dp_and_pcp(
tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
@@ -1136,6 +1140,10 @@ class FusedMoEConfig:
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
@property
def use_mori_kernels(self):
    """Delegate the MoRI-kernel decision to the parallel config."""
    parallel_cfg = self.moe_parallel_config
    return parallel_cfg.use_mori_kernels
@property
def use_flashinfer_cutlass_kernels(self):
"""

View File

@@ -570,6 +570,14 @@ class FusedMoE(CustomOp):
self.moe_config_use_flashinfer_cutlass_kernels = (
self.moe_config.use_flashinfer_cutlass_kernels
)
if self.use_mori_kernels:
assert self.rocm_aiter_fmoe_enabled, (
"Mori needs to be used with aiter fused_moe for now."
)
assert not self.aiter_fmoe_shared_expert_enabled, (
"Mori does not support fusion shared expert now. "
"Turn it off by setting VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS=0"
)
self.quant_config = quant_config
@@ -712,6 +720,10 @@ class FusedMoE(CustomOp):
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
@property
def use_mori_kernels(self):
    """Delegate the MoRI-kernel decision to the parallel config."""
    parallel_cfg = self.moe_parallel_config
    return parallel_cfg.use_mori_kernels
@property
def use_flashinfer_cutlass_kernels(self):
return (
@@ -729,6 +741,7 @@ class FusedMoE(CustomOp):
return (
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_mori_kernels
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
) and envs.VLLM_ENABLE_MOE_DP_CHUNK

View File

@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import mori
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.platforms import current_platform
logger = init_logger(__name__)
class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    """
    Prepare/Finalize using MoRI kernels.
    """

    def __init__(
        self,
        mori_op: mori.ops.EpDispatchCombineOp,
        max_tokens_per_rank: int,
        num_dispatchers: int,
        use_fp8_dispatch: bool = False,
    ):
        """
        Args:
            mori_op: pre-built MoRI dispatch/combine op (created and cached
                by the all2all manager).
            max_tokens_per_rank: dispatch buffer capacity per rank.
            num_dispatchers: number of dispatching ranks (EP world size).
            use_fp8_dispatch: quantize activations to fp8 before dispatch
                to reduce inter-rank data movement.
        """
        super().__init__()
        self.mori_op = mori_op
        self.num_dispatchers_ = num_dispatchers
        self.max_tokens_per_rank = max_tokens_per_rank
        self.use_fp8_dispatch = use_fp8_dispatch

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        # Tokens are exchanged in the standard (non-batched) layout.
        return mk.FusedMoEActivationFormat.Standard

    def output_is_reduced(self) -> bool:
        # combine() returns the fully reduced output; no further
        # reduction is needed by the caller.
        return True

    def num_dispatchers(self):
        return self.num_dispatchers_

    def max_num_tokens_per_rank(self) -> int | None:
        return self.max_tokens_per_rank

    def topk_indices_dtype(self) -> torch.dtype | None:
        # MoRI expects int32 top-k expert indices.
        return torch.int32

    def supports_async(self) -> bool:
        return False

    def prepare(
        self,
        a1: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
    ) -> mk.PrepareResultType:
        """
        Returns a tuple of:
        - quantized + dispatched a.
        - Optional quantized + dispatched a1_scales.
        - Optional ExpertTokensMetadata containing gpu/cpu tensors
          as big as the number of local experts with the information about the
          number of tokens assigned to each local expert.
        - Optional dispatched expert topk IDs
        - Optional dispatched expert topk weight
        """
        assert not apply_router_weight_on_input, (
            "mori does not support apply_router_weight_on_input=True now."
        )

        scale = None
        if self.use_fp8_dispatch:
            # Quantize locally before dispatch so only fp8 data (plus
            # scales) crosses ranks. Quant scheme follows quant_config;
            # unquantized dispatch is the fallback when neither matches.
            from aiter import QuantType, get_hip_quant

            if quant_config.is_block_quantized:
                quant_func = get_hip_quant(QuantType.per_1x128)
                a1, scale = quant_func(a1, quant_dtype=current_platform.fp8_dtype())
            elif quant_config.is_per_act_token:
                quant_func = get_hip_quant(QuantType.per_Token)
                a1, scale = quant_func(a1, quant_dtype=current_platform.fp8_dtype())

        (
            dispatch_a1,
            dispatch_weights,
            dispatch_scale,
            dispatch_ids,
            dispatch_recv_token_num,
        ) = self.mori_op.dispatch(a1, topk_weights, scale, topk_ids)

        # Per-local-expert received token counts live on GPU only; no CPU
        # copy is produced here.
        expert_tokens_meta = mk.ExpertTokensMetadata(
            expert_num_tokens=dispatch_recv_token_num, expert_num_tokens_cpu=None
        )

        return (
            dispatch_a1,
            dispatch_scale,
            expert_tokens_meta,
            dispatch_ids,
            dispatch_weights,
        )

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
        """Combine dispatched expert outputs back into `output` in place."""
        num_token = output.shape[0]
        # combine() returns a tuple; the first element holds the reduced
        # tokens. Its row count may exceed the caller's token count
        # (dispatch buffer capacity), hence the slice below.
        result = self.mori_op.combine(
            fused_expert_output,
            None,
            topk_ids,
        )[0]
        output.copy_(result[:num_token])

View File

@@ -188,6 +188,9 @@ def rocm_aiter_fused_experts(
apply_router_weight_on_input: bool = False,
expert_map: torch.Tensor | None = None,
quant_config: FusedMoEQuantConfig | None = None,
a1q_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
@@ -216,6 +219,9 @@ def rocm_aiter_fused_experts(
assert topk_weights.shape[-1] == 1, (
"Only support topk=1 when `apply_router_weight_on_input` is True"
)
assert num_local_tokens is None, (
"AITER tkw1 kernel does not support `num_local_tokens`"
)
return rocm_aiter_ops.asm_moe_tkw1(
hidden_states,
@@ -272,9 +278,11 @@ def rocm_aiter_fused_experts(
activation_method=activation_method,
w1_scale=quant_config.w1_scale,
w2_scale=quant_config.w2_scale,
a1_scale=quant_config.a1_scale,
a1_scale=quant_config.a1_scale if a1q_scale is None else a1q_scale,
a2_scale=quant_config.a2_scale,
doweight_stage1=apply_router_weight_on_input,
num_local_tokens=num_local_tokens,
output_dtype=output_dtype,
)
@@ -370,9 +378,12 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
# TODO(rob): rocm_aiter_fused_experts uses self.quant_config's
# a_scales for static quantization. Update this to fit better
# with the interface once all quant integrations are complete.
assert a1q_scale is None
assert a2_scale == self.quant_config.a2_scale
assert expert_tokens_meta is None
if expert_tokens_meta is not None:
num_local_tokens = expert_tokens_meta.expert_num_tokens
else:
num_local_tokens = None
result = rocm_aiter_fused_experts(
hidden_states=hidden_states,
@@ -384,6 +395,8 @@ class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
quant_config=self.quant_config,
a1q_scale=a1q_scale,
num_local_tokens=num_local_tokens,
output_dtype=output.dtype,
)
assert result.shape == output.shape
output.copy_(result)

View File

@@ -106,6 +106,12 @@ def on_gfx9() -> bool:
return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
@cache
def on_gfx942() -> bool:
    """Return True when the current GPU reports a gfx942 architecture.

    Result is cached for the process lifetime (the device does not change).
    """
    # gcnArchName may carry feature suffixes (e.g. "gfx942:sramecc+:xnack-"),
    # so a substring check is used. The original single-element
    # `any(arch in GPU_ARCH for arch in ["gfx942"])` reduces to this.
    GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
    return "gfx942" in GPU_ARCH
@cache
def on_gfx950() -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName

View File

@@ -451,3 +451,13 @@ def has_helion() -> bool:
# use helion...
"""
return _has_module("helion")
def has_aiter() -> bool:
    """Whether the optional `aiter` package is available."""
    aiter_present = _has_module("aiter")
    return aiter_present
def has_mori() -> bool:
    """Whether the optional `mori` package is available."""
    mori_present = _has_module("mori")
    return mori_present