diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py
index 44c9bb79e..534004e11 100644
--- a/vllm/model_executor/layers/fused_moe/all2all_utils.py
+++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py
@@ -186,16 +186,23 @@ def maybe_make_prepare_finalize(
         use_fp8_dispatch = (
             quant_config.is_per_act_token or quant_config.is_block_quantized
         )
-        # For PTPC (per token per channel) quant, the scale dim for each token is 1
-        # For 1x128 quant, the scale dim for each token is hidden_dim // 128
-        scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
+        if use_fp8_dispatch:
+            # For PTPC (per token per channel) quant, scale dim is 1
+            # For 1x128 quant, scale dim is hidden_dim // 128
+            quant_dtype = quant_config.quant_dtype
+            scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
+        else:
+            # Unquantized dispatch (e.g. AITER with defer_input_quant):
+            # dispatch raw BF16/FP16 data, no scales needed.
+            quant_dtype = moe.in_dtype
+            scale_dim = 0
         all_to_all_args = dict(
             rank=all2all_manager.rank,
             num_ep_ranks=all2all_manager.world_size,
-            quant_dtype=quant_config.quant_dtype,
+            quant_dtype=quant_dtype,
             token_hidden_size=moe.hidden_dim,
             scale_dim=scale_dim,
-            scale_type_size=torch.float32.itemsize,
+            scale_type_size=0 if scale_dim == 0 else torch.float32.itemsize,
             max_num_tokens_per_dp_rank=moe.max_num_tokens,
             input_dtype=moe.in_dtype,
             num_local_experts=moe.num_experts // all2all_manager.world_size,
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index a29d8a7d8..38b552b02 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -108,10 +108,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self,
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> FusedMoEPrepareAndFinalizeModular | None:
-        if self.unquantized_backend == UnquantizedMoeBackend.AITER:
-            return None
-        else:
-            return super().maybe_make_prepare_finalize(routing_tables)
+        return super().maybe_make_prepare_finalize(routing_tables)

     def select_gemm_impl(
         self,
@@ -130,6 +127,17 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 max_num_tokens=self.moe.max_num_tokens,
                 num_dispatchers=prepare_finalize.num_dispatchers(),
             )
+        elif (
+            self.unquantized_backend == UnquantizedMoeBackend.AITER
+            and rocm_aiter_ops.is_fused_moe_enabled()
+        ):
+            from .rocm_aiter_fused_moe import AiterExperts
+
+            logger.debug("AiterExperts %s", self.moe)
+            return AiterExperts(
+                moe_config=self.moe,
+                quant_config=self.moe_quant_config,
+            )
         else:
             logger.debug("TritonExperts %s", self.moe)
             return TritonExperts(
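
For reference, below is a standalone sketch of the dispatch-argument selection the first hunk introduces. MoeStub and QuantStub are hypothetical stand-ins for vLLM's real moe and quant_config objects, modeling only the fields the patch touches; the logic mirrors the patched branch (FP8 dispatch carries per-token fp32 scales, unquantized dispatch ships raw BF16/FP16 with no scales).

from dataclasses import dataclass

import torch


@dataclass
class MoeStub:
    # Hypothetical stand-in for vLLM's MoE config; only patched fields modeled.
    hidden_dim: int
    in_dtype: torch.dtype


@dataclass
class QuantStub:
    # Hypothetical stand-in for the activation quant config.
    is_per_act_token: bool
    is_block_quantized: bool
    quant_dtype: torch.dtype | None


def dispatch_args(moe: MoeStub, quant_config: QuantStub) -> dict:
    """Mirror the patched branch of maybe_make_prepare_finalize."""
    use_fp8_dispatch = (
        quant_config.is_per_act_token or quant_config.is_block_quantized
    )
    if use_fp8_dispatch:
        # PTPC: one scale per token; 1x128 block quant: hidden_dim // 128 scales.
        quant_dtype = quant_config.quant_dtype
        scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
    else:
        # Deferred input quant: ship the activations as-is, no scale buffer.
        quant_dtype = moe.in_dtype
        scale_dim = 0
    return dict(
        quant_dtype=quant_dtype,
        scale_dim=scale_dim,
        scale_type_size=0 if scale_dim == 0 else torch.float32.itemsize,
    )


# PTPC FP8 dispatch: one fp32 scale (4 bytes) per token.
print(dispatch_args(
    MoeStub(hidden_dim=4096, in_dtype=torch.bfloat16),
    QuantStub(is_per_act_token=True, is_block_quantized=False,
              quant_dtype=torch.float8_e4m3fn),
))
# Unquantized dispatch (e.g. AITER with defer_input_quant): raw BF16, no scales.
print(dispatch_args(
    MoeStub(hidden_dim=4096, in_dtype=torch.bfloat16),
    QuantStub(is_per_act_token=False, is_block_quantized=False, quant_dtype=None),
))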