[Bug] Fix routing bias dtype for trtllm per-block fp8 moe (#38989)
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -358,6 +358,11 @@ class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolit
|
||||
if self.routing_method_type == RoutingMethodType.DeepSeekV3:
|
||||
router_logits = router_logits.to(torch.float32)
|
||||
|
||||
# Currently FI requires bfloat16 routing bias.
|
||||
# https://github.com/flashinfer-ai/flashinfer/issues/2909
|
||||
if e_score_correction_bias is not None:
|
||||
e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16)
|
||||
|
||||
is_mxfp8 = self.quant_config.block_shape == [1, 32]
|
||||
if is_mxfp8:
|
||||
fp8_quant_type = Fp8QuantizationType.MxFp8
|
||||
|
||||
Reference in New Issue
Block a user