From 95e45a87e3609dfac3bc267678944eee5b4a2ba3 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 14:17:02 +0000 Subject: [PATCH] =?UTF-8?q?Add=20explicit=20.to(dev)=20on=20W=5Fgate=20aft?= =?UTF-8?q?er=20transpose=20=E2=80=94=20belt=20and=20suspenders?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- single_shot_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 008d1d3a..bd6a376d 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -1316,7 +1316,7 @@ def main(): gsb = 1.0 * ws2_v # global_scale_b = gs * ws2 gsa = torch.tensor([gsb] * gate_w.shape[0], device=dev, dtype=torch.float32) gate_bf16 = dequantize_nvfp4(gate_w.to(dev), gate_ws.to(dev), gsa) # (E_packed*2, H) - router.W_gate = gate_bf16.T.contiguous() # (H, E) for F.linear(x, W_gate.T) + router.W_gate = gate_bf16.T.contiguous().to(dev) # (H, E) for F.linear(x, W_gate.T) else: # BF16 gate weight from checkpoint gw = all_w.get(f"{pfx}.gate.weight")