diff --git a/single_shot_inference.py b/single_shot_inference.py index 718ce59f..9b67516d 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -715,6 +715,9 @@ def main(): gw = all_w.get(f"{pfx}.gate.weight") eb = all_w.get(f"{pfx}.gate.e_score_correction_bias") if gw is not None and eb is not None: + # Checkpoint may store gate weight transposed: (n_experts, hidden) vs (hidden, n_experts) + if gw.shape == (cfg["n_routed_experts"], H): + gw = gw.T.contiguous() router.load_weights(W_gate=gw.bfloat16().to(dev), e_bias=eb.to(dev, torch.float32)) router.finalize_weights() routers[li] = router