Fix gate weight transpose: checkpoint is (E, H), Router expects (H, E)
This commit is contained in:
@@ -715,6 +715,9 @@ def main():
|
||||
gw = all_w.get(f"{pfx}.gate.weight")
|
||||
eb = all_w.get(f"{pfx}.gate.e_score_correction_bias")
|
||||
if gw is not None and eb is not None:
|
||||
# Checkpoint may store gate weight transposed: (n_experts, hidden) vs (hidden, n_experts)
|
||||
if gw.shape == (cfg["n_routed_experts"], H):
|
||||
gw = gw.T.contiguous()
|
||||
router.load_weights(W_gate=gw.bfloat16().to(dev), e_bias=eb.to(dev, torch.float32))
|
||||
router.finalize_weights()
|
||||
routers[li] = router
|
||||
|
||||
Reference in New Issue
Block a user