From 2866eb92e7ce02acba7cad083079787ea801ec00 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 12:56:52 +0000 Subject: [PATCH] Fix W_gate device: ensure .to(dev) after transpose --- single_shot_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 6c8744b1..6f9ae641 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -1317,7 +1317,7 @@ def main(): gate_ws2.to(dev) if gate_ws2 is not None else None, gate_isc.to(dev) if gate_isc is not None else None) # W_gate shape: (H, E) for F.linear(x, W_gate.T) - router.W_gate = gate_bf16.T.contiguous() + router.W_gate = gate_bf16.T.contiguous().to(dev) router.gate_lin = None # force BF16 dispatch path router.load_weights(e_bias=eb.to(dev, torch.float32)) if li < 5: print(f" L{li}: BF16 router gate (dequantized from NVFP4)", flush=True) @@ -1326,7 +1326,7 @@ def main(): gw = all_w.get(f"{pfx}.gate.weight") if gw is not None: g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous() - router.W_gate = g_bf16.T.contiguous() + router.W_gate = g_bf16.T.contiguous().to(dev) router.gate_lin = None router.load_weights(e_bias=eb.to(dev, torch.float32)) if li < 5: print(f" L{li}: BF16 router gate (checkpoint BF16)", flush=True)