[Bug] Fix routing bias dtype for trtllm per-block fp8 moe (#38989)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Wei Zhao
2026-04-08 22:42:43 -04:00
committed by GitHub
parent 2f41d6c063
commit 92fbec391b

View File

@@ -358,6 +358,11 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
router_logits = router_logits.to(torch.float32)
# Currently FI requires bfloat16 routing bias.
# https://github.com/flashinfer-ai/flashinfer/issues/2909
if e_score_correction_bias is not None:
e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)
is_mxfp8 = self.quant_config.block_shape == [1, 32]
if is_mxfp8:
fp8_quant_type = Fp8QuantizationType.MxFp8