[ROCm] Enable MORI EP for unquantized MoE with AITER backend (#37529)

Signed-off-by: Tan Pin Siang <pinsiang.tan@amd.com>
This commit is contained in:
Tan Pin Siang
2026-03-30 15:19:33 +08:00
committed by GitHub
parent 57861ae48d
commit 85c0950b1f
2 changed files with 24 additions and 9 deletions

View File

@@ -186,16 +186,23 @@ def maybe_make_prepare_finalize(
use_fp8_dispatch = (
quant_config.is_per_act_token or quant_config.is_block_quantized
)
# For PTPC (per token per channel) quant, the scale dim for each token is 1
# For 1x128 quant, the scale dim for each token is hidden_dim // 128
scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
if use_fp8_dispatch:
# For PTPC (per token per channel) quant, scale dim is 1
# For 1x128 quant, scale dim is hidden_dim // 128
quant_dtype = quant_config.quant_dtype
scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
else:
# Unquantized dispatch (e.g. AITER with defer_input_quant):
# dispatch raw BF16/FP16 data, no scales needed.
quant_dtype = moe.in_dtype
scale_dim = 0
all_to_all_args = dict(
rank=all2all_manager.rank,
num_ep_ranks=all2all_manager.world_size,
quant_dtype=quant_config.quant_dtype,
quant_dtype=quant_dtype,
token_hidden_size=moe.hidden_dim,
scale_dim=scale_dim,
scale_type_size=torch.float32.itemsize,
scale_type_size=0 if scale_dim == 0 else torch.float32.itemsize,
max_num_tokens_per_dp_rank=moe.max_num_tokens,
input_dtype=moe.in_dtype,
num_local_experts=moe.num_experts // all2all_manager.world_size,

View File

@@ -108,10 +108,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalizeModular | None:
if self.unquantized_backend == UnquantizedMoeBackend.AITER:
return None
else:
return super().maybe_make_prepare_finalize(routing_tables)
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
@@ -130,6 +127,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
max_num_tokens=self.moe.max_num_tokens,
num_dispatchers=prepare_finalize.num_dispatchers(),
)
elif (
self.unquantized_backend == UnquantizedMoeBackend.AITER
and rocm_aiter_ops.is_fused_moe_enabled()
):
from .rocm_aiter_fused_moe import AiterExperts
logger.debug("AiterExperts %s", self.moe)
return AiterExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
)
else:
logger.debug("TritonExperts %s", self.moe)
return TritonExperts(