Fix W_gate device: ensure .to(dev) after transpose

This commit is contained in:
2026-06-03 12:56:52 +00:00
parent bd10bdbbd9
commit 2866eb92e7

View File

@@ -1317,7 +1317,7 @@ def main():
gate_ws2.to(dev) if gate_ws2 is not None else None,
gate_isc.to(dev) if gate_isc is not None else None)
# W_gate shape: (H, E) for F.linear(x, W_gate.T)
router.W_gate = gate_bf16.T.contiguous()
router.W_gate = gate_bf16.T.contiguous().to(dev)
router.gate_lin = None # force BF16 dispatch path
router.load_weights(e_bias=eb.to(dev, torch.float32))
if li < 5: print(f" L{li}: BF16 router gate (dequantized from NVFP4)", flush=True)
@@ -1326,7 +1326,7 @@ def main():
gw = all_w.get(f"{pfx}.gate.weight")
if gw is not None:
g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
router.W_gate = g_bf16.T.contiguous()
router.W_gate = g_bf16.T.contiguous().to(dev)
router.gate_lin = None
router.load_weights(e_bias=eb.to(dev, torch.float32))
if li < 5: print(f" L{li}: BF16 router gate (checkpoint BF16)", flush=True)