Add explicit .to(dev) on W_gate after transpose — belt and suspenders

This commit is contained in:
2026-06-03 14:17:02 +00:00
parent ef94c48957
commit 95e45a87e3

View File

@@ -1316,7 +1316,7 @@ def main():
gsb = 1.0 * ws2_v # global_scale_b = gs * ws2
gsa = torch.tensor([gsb] * gate_w.shape[0], device=dev, dtype=torch.float32)
gate_bf16 = dequantize_nvfp4(gate_w.to(dev), gate_ws.to(dev), gsa) # (E_packed*2, H)
router.W_gate = gate_bf16.T.contiguous() # (H, E) for F.linear(x, W_gate.T)
router.W_gate = gate_bf16.T.contiguous().to(dev) # (H, E) for F.linear(x, W_gate.T)
else:
# BF16 gate weight from checkpoint
gw = all_w.get(f"{pfx}.gate.weight")