[Bug] Fix fp8 trtllm MoE modular kernel supported routing methods (#37346)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
This commit is contained in:
@@ -69,54 +69,11 @@ class TrtLlmFp8ExpertsBase:
|
|||||||
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
|
"""Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _supports_quant_scheme(
|
|
||||||
weight_key: QuantKey | None,
|
|
||||||
activation_key: QuantKey | None,
|
|
||||||
) -> bool:
|
|
||||||
"""Supports Fp8 per-tensor, Fp8 block, and MXFP8."""
|
|
||||||
SUPPORTED_W_A = [
|
|
||||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
|
||||||
(kFp8StaticTensorSym, kFp8StaticTensorSym),
|
|
||||||
(kMxfp8Static, kMxfp8Dynamic),
|
|
||||||
]
|
|
||||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _supports_activation(activation: MoEActivation) -> bool:
|
def _supports_activation(activation: MoEActivation) -> bool:
|
||||||
"""Supports only SiLU and RELU^2 non-gated activation."""
|
"""Supports only SiLU and RELU^2 non-gated activation."""
|
||||||
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _supports_routing_method(
|
|
||||||
routing_method: RoutingMethodType,
|
|
||||||
weight_key: QuantKey | None,
|
|
||||||
activation_key: QuantKey | None,
|
|
||||||
) -> bool:
|
|
||||||
"""Monolithic kernels need to express router support."""
|
|
||||||
# NOTE(dbari): TopK routing could also be enabled, but need to validate models
|
|
||||||
# NOTE(dbari): Default is not implemented and should not be enabled until it is
|
|
||||||
if (weight_key, activation_key) in [
|
|
||||||
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
|
|
||||||
(kMxfp8Static, kMxfp8Dynamic),
|
|
||||||
]:
|
|
||||||
# NOTE(rob): potentially allow others here. This is a conservative list.
|
|
||||||
return routing_method in [
|
|
||||||
RoutingMethodType.DeepSeekV3,
|
|
||||||
RoutingMethodType.Renormalize,
|
|
||||||
RoutingMethodType.RenormalizeNaive,
|
|
||||||
]
|
|
||||||
elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
|
|
||||||
# NOTE(dbari): as above, potentially allow others here.
|
|
||||||
return routing_method in [
|
|
||||||
RoutingMethodType.DeepSeekV3,
|
|
||||||
RoutingMethodType.Llama4,
|
|
||||||
RoutingMethodType.Renormalize,
|
|
||||||
RoutingMethodType.RenormalizeNaive,
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported quantization scheme.")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
|
||||||
"""Monolithic kernel so only use with naive DP/EP and TP."""
|
"""Monolithic kernel so only use with naive DP/EP and TP."""
|
||||||
@@ -125,22 +82,6 @@ class TrtLlmFp8ExpertsBase:
|
|||||||
or moe_parallel_config.use_naive_all2all_kernels
|
or moe_parallel_config.use_naive_all2all_kernels
|
||||||
) and not moe_parallel_config.enable_eplb
|
) and not moe_parallel_config.enable_eplb
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _supports_router_logits_dtype(
|
|
||||||
router_logits_dtype: torch.dtype | None,
|
|
||||||
routing_method: RoutingMethodType,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
|
|
||||||
Only DeepSeekV3 routing supports float32 router_logits (which is converted
|
|
||||||
internally in the kernel).
|
|
||||||
"""
|
|
||||||
if router_logits_dtype == torch.float32:
|
|
||||||
# Only DeepSeekV3 routing handles float32 logits
|
|
||||||
# https://github.com/flashinfer-ai/flashinfer/issues/2469
|
|
||||||
return routing_method == RoutingMethodType.DeepSeekV3
|
|
||||||
return True
|
|
||||||
|
|
||||||
def supports_chunking(self) -> bool:
|
def supports_chunking(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -306,6 +247,22 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
|||||||
]
|
]
|
||||||
return (weight_key, activation_key) in SUPPORTED_W_A
|
return (weight_key, activation_key) in SUPPORTED_W_A
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _supports_router_logits_dtype(
|
||||||
|
router_logits_dtype: torch.dtype | None,
|
||||||
|
routing_method: RoutingMethodType,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
|
||||||
|
Only DeepSeekV3 routing supports float32 router_logits (which is converted
|
||||||
|
internally in the kernel).
|
||||||
|
"""
|
||||||
|
if router_logits_dtype == torch.float32:
|
||||||
|
# Only DeepSeekV3 routing handles float32 logits
|
||||||
|
# https://github.com/flashinfer-ai/flashinfer/issues/2469
|
||||||
|
return routing_method == RoutingMethodType.DeepSeekV3
|
||||||
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _supports_routing_method(
|
def _supports_routing_method(
|
||||||
routing_method: RoutingMethodType,
|
routing_method: RoutingMethodType,
|
||||||
|
|||||||
Reference in New Issue
Block a user