fix: W_gate is (H, E) but F.linear expects (E, H), transpose before linear
This commit is contained in:
@@ -49,7 +49,7 @@ def _run_prefill_path(
|
||||
out_weights, out_ids,
|
||||
):
|
||||
"""GEMM via torch.nn.functional.linear, then fused activation + top-k."""
|
||||
logits = torch.nn.functional.linear(hidden_states.float(), W_gate.float())
|
||||
logits = torch.nn.functional.linear(hidden_states.float(), W_gate.T.float())
|
||||
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
|
||||
run_fused_activation_topk(
|
||||
logits, e_bias, routed_scaling_factor, top_k,
|
||||
|
||||
Reference in New Issue
Block a user