diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
index 9a6f67b42..c0a7dfc49 100644
--- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
+++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
@@ -256,13 +256,18 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
     ) -> bool:
         """
         The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
-        Only DeepSeekV3 routing supports float32 router_logits (which is converted
-        internally in the kernel).
+        DeepSeekV3 routing supports float32 router_logits (converted internally).
+        Simulated routing generates synthetic decisions and is agnostic to dtype.
         """
         if router_logits_dtype == torch.float32:
-            # Only DeepSeekV3 routing handles float32 logits
+            # DeepSeekV3 routing handles float32 logits internally.
+            # Simulated routing generates synthetic decisions, so the
+            # kernel doesn't care about the actual logits dtype.
             # https://github.com/flashinfer-ai/flashinfer/issues/2469
-            return routing_method == RoutingMethodType.DeepSeekV3
+            return routing_method in (
+                RoutingMethodType.DeepSeekV3,
+                RoutingMethodType.Simulated,
+            )
         return True

     @staticmethod
@@ -288,12 +293,14 @@
             # NOTE(rob): potentially allow others here. This is a conservative list.
             return routing_method in [
                 RoutingMethodType.DeepSeekV3,
+                RoutingMethodType.Simulated,
             ]
         elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
             # NOTE(dbari): as above, potentially allow others here.
             return routing_method in [
                 RoutingMethodType.DeepSeekV3,
                 RoutingMethodType.Llama4,
+                RoutingMethodType.Simulated,
             ]
         else:
             raise ValueError("Unsupported quantization scheme.")
diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
index 84beb6abb..b47391c41 100644
--- a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
+++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
@@ -255,6 +255,7 @@ class TrtLlmNvFp4ExpertsMonolithic(
             RoutingMethodType.Renormalize,
             RoutingMethodType.RenormalizeNaive,
             RoutingMethodType.Llama4,
+            RoutingMethodType.Simulated,
         ]

     @staticmethod
@@ -264,13 +265,18 @@
     ) -> bool:
         """
         The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
-        Only DeepSeekV3 routing supports float32 router_logits (which is converted
-        internally in the kernel).
+        DeepSeekV3 routing supports float32 router_logits (converted internally).
+        Simulated routing generates synthetic decisions and is agnostic to dtype.
         """
         if router_logits_dtype == torch.float32:
-            # Only DeepSeekV3 routing handles float32 logits
+            # DeepSeekV3 routing handles float32 logits internally.
+            # Simulated routing generates synthetic decisions, so the
+            # kernel doesn't care about the actual logits dtype.
             # https://github.com/flashinfer-ai/flashinfer/issues/2469
-            return routing_method == RoutingMethodType.DeepSeekV3
+            return routing_method in (
+                RoutingMethodType.DeepSeekV3,
+                RoutingMethodType.Simulated,
+            )
         return True

     def apply(