diff --git a/single_shot_inference.py b/single_shot_inference.py index 008d1d3a..bd6a376d 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -1316,7 +1316,7 @@ def main(): gsb = 1.0 * ws2_v # global_scale_b = gs * ws2 gsa = torch.tensor([gsb] * gate_w.shape[0], device=dev, dtype=torch.float32) gate_bf16 = dequantize_nvfp4(gate_w.to(dev), gate_ws.to(dev), gsa) # (E_packed*2, H) - router.W_gate = gate_bf16.T.contiguous() # (H, E) for F.linear(x, W_gate.T) + router.W_gate = gate_bf16.T.contiguous().to(dev) # (H, E) for F.linear(x, W_gate.T) else: # BF16 gate weight from checkpoint gw = all_w.get(f"{pfx}.gate.weight")