[MoE] Add RoutingMethodType.Simulated to TRT-LLM FP8/NVFP4 kernel allowlists (#38329)

Signed-off-by: Jaewon Lee <jaewon@meta.com>
This commit is contained in:
Jaewon
2026-03-29 22:53:43 -07:00
committed by GitHub
parent 92f0db57a8
commit d816834c1a
2 changed files with 21 additions and 8 deletions

View File

@@ -256,13 +256,18 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
) -> bool:
"""
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
DeepSeekV3 routing supports float32 router_logits (converted internally).
Simulated routing generates synthetic decisions and is agnostic to dtype.
"""
if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits
# DeepSeekV3 routing handles float32 logits internally.
# Simulated routing generates synthetic decisions, so the
# kernel doesn't care about the actual logits dtype.
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3
return routing_method in (
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Simulated,
)
return True
@staticmethod
@@ -288,12 +293,14 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
# NOTE(rob): potentially allow others here. This is a conservative list.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Simulated,
]
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
# NOTE(dbari): as above, potentially allow others here.
return routing_method in [
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Llama4,
RoutingMethodType.Simulated,
]
else:
raise ValueError("Unsupported quantization scheme.")

View File

@@ -255,6 +255,7 @@ class TrtLlmNvFp4ExpertsMonolithic(
RoutingMethodType.Renormalize,
RoutingMethodType.RenormalizeNaive,
RoutingMethodType.Llama4,
RoutingMethodType.Simulated,
]
@staticmethod
@@ -264,13 +265,18 @@ class TrtLlmNvFp4ExpertsMonolithic(
) -> bool:
"""
The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
Only DeepSeekV3 routing supports float32 router_logits (which is converted
internally in the kernel).
DeepSeekV3 routing supports float32 router_logits (converted internally).
Simulated routing generates synthetic decisions and is agnostic to dtype.
"""
if router_logits_dtype == torch.float32:
# Only DeepSeekV3 routing handles float32 logits
# DeepSeekV3 routing handles float32 logits internally.
# Simulated routing generates synthetic decisions, so the
# kernel doesn't care about the actual logits dtype.
# https://github.com/flashinfer-ai/flashinfer/issues/2469
return routing_method == RoutingMethodType.DeepSeekV3
return routing_method in (
RoutingMethodType.DeepSeekV3,
RoutingMethodType.Simulated,
)
return True
def apply(