From 5dcfb333eabcd1db1414393b936626e44ecbe4c4 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 22:43:47 +0000 Subject: [PATCH] Fix: move weight tensors to CUDA before dequant --- single_shot_inference.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 72da8263..fdcd5877 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -128,12 +128,16 @@ class CheckpointReader: def nvfp4_linear(x, weight, weight_scale, weight_scale_2): """NVFP4 linear: dequant → BF16 matmul.""" - w = dequant_nvfp4_weight(weight, weight_scale, weight_scale_2) + w = dequant_nvfp4_weight( + weight.cuda(), + weight_scale.cuda(), + weight_scale_2.cuda() if weight_scale_2 is not None else None, + ) return torch.nn.functional.linear(x, w) def bf16_linear(x, weight): """BF16 linear.""" - return torch.nn.functional.linear(x, weight.bfloat16()) + return torch.nn.functional.linear(x, weight.cuda().bfloat16()) # =====================================================================