fix: W_gate is (H, E) but F.linear expects (E, H), transpose before linear

This commit is contained in:
2026-05-31 23:55:16 +00:00
parent 5396a04c28
commit 56dff8d185

View File

@@ -49,7 +49,7 @@ def _run_prefill_path(
out_weights, out_ids,
):
"""GEMM via torch.nn.functional.linear, then fused activation + top-k."""
logits = torch.nn.functional.linear(hidden_states.float(), W_gate.float())
logits = torch.nn.functional.linear(hidden_states.float(), W_gate.T.float())
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
run_fused_activation_topk(
logits, e_bias, routed_scaling_factor, top_k,