[ROCm] Enable MORI EP for unquantized MoE with AITER backend (#37529)
Signed-off-by: Tan Pin Siang <pinsiang.tan@amd.com>
This commit is contained in:
@@ -186,16 +186,23 @@ def maybe_make_prepare_finalize(
|
||||
use_fp8_dispatch = (
|
||||
quant_config.is_per_act_token or quant_config.is_block_quantized
|
||||
)
|
||||
# For PTPC (per token per channel) quant, the scale dim for each token is 1
|
||||
# For 1x128 quant, the scale dim for each token is hidden_dim // 128
|
||||
scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
|
||||
if use_fp8_dispatch:
|
||||
# For PTPC (per token per channel) quant, scale dim is 1
|
||||
# For 1x128 quant, scale dim is hidden_dim // 128
|
||||
quant_dtype = quant_config.quant_dtype
|
||||
scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
|
||||
else:
|
||||
# Unquantized dispatch (e.g. AITER with defer_input_quant):
|
||||
# dispatch raw BF16/FP16 data, no scales needed.
|
||||
quant_dtype = moe.in_dtype
|
||||
scale_dim = 0
|
||||
all_to_all_args = dict(
|
||||
rank=all2all_manager.rank,
|
||||
num_ep_ranks=all2all_manager.world_size,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
token_hidden_size=moe.hidden_dim,
|
||||
scale_dim=scale_dim,
|
||||
scale_type_size=torch.float32.itemsize,
|
||||
scale_type_size=0 if scale_dim == 0 else torch.float32.itemsize,
|
||||
max_num_tokens_per_dp_rank=moe.max_num_tokens,
|
||||
input_dtype=moe.in_dtype,
|
||||
num_local_experts=moe.num_experts // all2all_manager.world_size,
|
||||
|
||||
@@ -108,10 +108,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> FusedMoEPrepareAndFinalizeModular | None:
|
||||
if self.unquantized_backend == UnquantizedMoeBackend.AITER:
|
||||
return None
|
||||
else:
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
@@ -130,6 +127,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
max_num_tokens=self.moe.max_num_tokens,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
)
|
||||
elif (
|
||||
self.unquantized_backend == UnquantizedMoeBackend.AITER
|
||||
and rocm_aiter_ops.is_fused_moe_enabled()
|
||||
):
|
||||
from .rocm_aiter_fused_moe import AiterExperts
|
||||
|
||||
logger.debug("AiterExperts %s", self.moe)
|
||||
return AiterExperts(
|
||||
moe_config=self.moe,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
logger.debug("TritonExperts %s", self.moe)
|
||||
return TritonExperts(
|
||||
|
||||
Reference in New Issue
Block a user