diff --git a/dsv4/kernels/router/dense_router_decode.py b/dsv4/kernels/router/dense_router_decode.py index dda5ae88..606a2f74 100644 --- a/dsv4/kernels/router/dense_router_decode.py +++ b/dsv4/kernels/router/dense_router_decode.py @@ -49,7 +49,7 @@ def _run_prefill_path( out_weights, out_ids, ): """GEMM via torch.nn.functional.linear, then fused activation + top-k.""" - logits = torch.nn.functional.linear(hidden_states.float(), W_gate.float()) + logits = torch.nn.functional.linear(hidden_states.float(), W_gate.T.float()) from dsv4.kernels.router._activation_topk import run_fused_activation_topk run_fused_activation_topk( logits, e_bias, routed_scaling_factor, top_k,