diff --git a/dsv4/kernels/router/dense_router_decode.py b/dsv4/kernels/router/dense_router_decode.py index df840930..b9a6f80a 100644 --- a/dsv4/kernels/router/dense_router_decode.py +++ b/dsv4/kernels/router/dense_router_decode.py @@ -97,7 +97,8 @@ def dense_router_dispatch_nvfp4_fused( # Decode the gate_weight from NVFP4 to BF16 for cuBLAS from dsv4.ops.quantize import dequantize_nvfp4 gate_bf16 = dequantize_nvfp4(gate_weight, gate_weight_scale, gate_ws2) - logits = torch.nn.functional.linear(hidden_states.float(), gate_bf16.T.float()) + logits = torch.nn.functional.linear(hidden_states, gate_bf16.T) + logits = logits.float() # BF16 → FP32 for numerical stability in sqrt(softplus) run_fused_activation_topk( logits, e_bias, routed_scaling_factor, top_k,