[Bug][MoE] Fix TRTLLM NVFP4 Routing Kernel Precision (#36725)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
@@ -298,10 +298,7 @@ class TrtLlmNvFp4ExpertsMonolithic(
|
||||
and self.routing_method_type != RoutingMethodType.Llama4
|
||||
)
|
||||
|
||||
# Prepare routing bias into kernel format.
|
||||
routing_bias = e_score_correction_bias
|
||||
if routing_bias is not None:
|
||||
routing_bias = routing_bias.to(torch.bfloat16)
|
||||
# Prepare router logits for kernel format.
|
||||
router_logits = (
|
||||
router_logits.to(torch.float32)
|
||||
if self.routing_method_type == RoutingMethodType.DeepSeekV3
|
||||
@@ -311,7 +308,7 @@ class TrtLlmNvFp4ExpertsMonolithic(
|
||||
# Invoke kernel.
|
||||
return flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits,
|
||||
routing_bias=routing_bias,
|
||||
routing_bias=e_score_correction_bias,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
|
||||
*hidden_states.shape[:-1], -1
|
||||
|
||||
Reference in New Issue
Block a user