diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index 74096ef6e..5f4607657 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -69,54 +69,11 @@ class TrtLlmFp8ExpertsBase: """Does not support non-gated MoE (i.e. Nanotron-3-Nano).""" return True - @staticmethod - def _supports_quant_scheme( - weight_key: QuantKey | None, - activation_key: QuantKey | None, - ) -> bool: - """Supports Fp8 per-tensor, Fp8 block, and MXFP8.""" - SUPPORTED_W_A = [ - (kFp8Static128BlockSym, kFp8Dynamic128Sym), - (kFp8StaticTensorSym, kFp8StaticTensorSym), - (kMxfp8Static, kMxfp8Dynamic), - ] - return (weight_key, activation_key) in SUPPORTED_W_A - @staticmethod def _supports_activation(activation: MoEActivation) -> bool: """Supports only SiLU and RELU^2 non-gated activation.""" return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] - @staticmethod - def _supports_routing_method( - routing_method: RoutingMethodType, - weight_key: QuantKey | None, - activation_key: QuantKey | None, - ) -> bool: - """Monolithic kernels need to express router support.""" - # NOTE(dbari): TopK routing could also be enabled, but need to validate models - # NOTE(dbari): Default is not implemented and should not be enabled until it is - if (weight_key, activation_key) in [ - (kFp8Static128BlockSym, kFp8Dynamic128Sym), - (kMxfp8Static, kMxfp8Dynamic), - ]: - # NOTE(rob): potentially allow others here. This is a conservative list. - return routing_method in [ - RoutingMethodType.DeepSeekV3, - RoutingMethodType.Renormalize, - RoutingMethodType.RenormalizeNaive, - ] - elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): - # NOTE(dbari): as above, potentially allow others here. - return routing_method in [ - RoutingMethodType.DeepSeekV3, - RoutingMethodType.Llama4, - RoutingMethodType.Renormalize, - RoutingMethodType.RenormalizeNaive, - ] - else: - raise ValueError("Unsupported quantization scheme.") - @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: """Monolithic kernel so only use with naive DP/EP and TP.""" @@ -125,22 +82,6 @@ class TrtLlmFp8ExpertsBase: or moe_parallel_config.use_naive_all2all_kernels ) and not moe_parallel_config.enable_eplb - @staticmethod - def _supports_router_logits_dtype( - router_logits_dtype: torch.dtype | None, - routing_method: RoutingMethodType, - ) -> bool: - """ - The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default. - Only DeepSeekV3 routing supports float32 router_logits (which is converted - internally in the kernel). - """ - if router_logits_dtype == torch.float32: - # Only DeepSeekV3 routing handles float32 logits - # https://github.com/flashinfer-ai/flashinfer/issues/2469 - return routing_method == RoutingMethodType.DeepSeekV3 - return True - def supports_chunking(self) -> bool: return False @@ -306,6 +247,22 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit ] return (weight_key, activation_key) in SUPPORTED_W_A + @staticmethod + def _supports_router_logits_dtype( + router_logits_dtype: torch.dtype | None, + routing_method: RoutingMethodType, + ) -> bool: + """ + The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default. + Only DeepSeekV3 routing supports float32 router_logits (which is converted + internally in the kernel). + """ + if router_logits_dtype == torch.float32: + # Only DeepSeekV3 routing handles float32 logits + # https://github.com/flashinfer-ai/flashinfer/issues/2469 + return routing_method == RoutingMethodType.DeepSeekV3 + return True + @staticmethod def _supports_routing_method( routing_method: RoutingMethodType,