[MoE] Add RoutingMethodType.Simulated to TRT-LLM FP8/NVFP4 kernel allowlists (#38329)
Signed-off-by: Jaewon Lee <jaewon@meta.com>
This commit is contained in:
@@ -256,13 +256,18 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
) -> bool:
|
||||
"""
|
||||
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
|
||||
Only DeepSeekV3 routing supports float32 router_logits (which is converted
|
||||
internally in the kernel).
|
||||
DeepSeekV3 routing supports float32 router_logits (converted internally).
|
||||
Simulated routing generates synthetic decisions and is agnostic to dtype.
|
||||
"""
|
||||
if router_logits_dtype == torch.float32:
|
||||
# Only DeepSeekV3 routing handles float32 logits
|
||||
# DeepSeekV3 routing handles float32 logits internally.
|
||||
# Simulated routing generates synthetic decisions, so the
|
||||
# kernel doesn't care about the actual logits dtype.
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2469
|
||||
return routing_method == RoutingMethodType.DeepSeekV3
|
||||
return routing_method in (
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Simulated,
|
||||
)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
@@ -288,12 +293,14 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Simulated,
|
||||
]
|
||||
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
|
||||
# NOTE(dbari): as above, potentially allow others here.
|
||||
return routing_method in [
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Llama4,
|
||||
RoutingMethodType.Simulated,
|
||||
]
|
||||
else:
|
||||
raise ValueError("Unsupported quantization scheme.")
|
||||
|
||||
@@ -255,6 +255,7 @@ class TrtLlmNvFp4ExpertsMonolithic(
|
||||
RoutingMethodType.Renormalize,
|
||||
RoutingMethodType.RenormalizeNaive,
|
||||
RoutingMethodType.Llama4,
|
||||
RoutingMethodType.Simulated,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -264,13 +265,18 @@ class TrtLlmNvFp4ExpertsMonolithic(
|
||||
) -> bool:
|
||||
"""
|
||||
The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
|
||||
Only DeepSeekV3 routing supports float32 router_logits (which is converted
|
||||
internally in the kernel).
|
||||
DeepSeekV3 routing supports float32 router_logits (converted internally).
|
||||
Simulated routing generates synthetic decisions and is agnostic to dtype.
|
||||
"""
|
||||
if router_logits_dtype == torch.float32:
|
||||
# Only DeepSeekV3 routing handles float32 logits
|
||||
# DeepSeekV3 routing handles float32 logits internally.
|
||||
# Simulated routing generates synthetic decisions, so the
|
||||
# kernel doesn't care about the actual logits dtype.
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2469
|
||||
return routing_method == RoutingMethodType.DeepSeekV3
|
||||
return routing_method in (
|
||||
RoutingMethodType.DeepSeekV3,
|
||||
RoutingMethodType.Simulated,
|
||||
)
|
||||
return True
|
||||
|
||||
def apply(
|
||||
|
||||
Reference in New Issue
Block a user