Add explicit .to(dev) on W_gate after transpose — belt and suspenders
This commit is contained in:
@@ -1316,7 +1316,7 @@ def main():
|
||||
gsb = 1.0 * ws2_v # global_scale_b = gs * ws2
|
||||
gsa = torch.tensor([gsb] * gate_w.shape[0], device=dev, dtype=torch.float32)
|
||||
gate_bf16 = dequantize_nvfp4(gate_w.to(dev), gate_ws.to(dev), gsa) # (E_packed*2, H)
|
||||
router.W_gate = gate_bf16.T.contiguous() # (H, E) for F.linear(x, W_gate.T)
|
||||
router.W_gate = gate_bf16.T.contiguous().to(dev) # (H, E) for F.linear(x, W_gate.T)
|
||||
else:
|
||||
# BF16 gate weight from checkpoint
|
||||
gw = all_w.get(f"{pfx}.gate.weight")
|
||||
|
||||
Reference in New Issue
Block a user