def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """NVFP4 linear: dequantize the weight, then apply a bias-free linear.

    The quantized weight and its scales are moved to the device of ``x``
    immediately before dequantization, so they may be stored on a different
    device than the activations.

    Args:
        x: Input activations; their device decides where the work runs.
        weight: Quantized weight passed to ``dequant_nvfp4_weight``.
        weight_scale: Scale tensor passed to ``dequant_nvfp4_weight``.
        weight_scale_2: Optional second scale tensor; ``None`` is forwarded
            unchanged.

    Returns:
        ``torch.nn.functional.linear(x, w)`` where ``w`` is the dequantized
        weight (presumably BF16, going by the helper's use here -- confirm
        against ``dequant_nvfp4_weight``).
    """
    # Follow the activations instead of the process-wide current CUDA device:
    # a bare ``.cuda()`` breaks when ``x`` sits on another GPU or on the CPU.
    device = x.device
    w = dequant_nvfp4_weight(
        weight.to(device),
        weight_scale.to(device),
        weight_scale_2.to(device) if weight_scale_2 is not None else None,
    )
    return torch.nn.functional.linear(x, w)


def bf16_linear(x, weight):
    """BF16 linear: cast the weight to bfloat16 and apply a bias-free linear.

    Args:
        x: Input activations; their device decides where the work runs.
        weight: Weight tensor on any device and of any dtype that can be
            cast to bfloat16.

    Returns:
        ``torch.nn.functional.linear(x, weight)`` with the weight moved to
        ``x.device`` and converted to bfloat16.
    """
    # Transfer first, then cast, mirroring the original order of operations.
    return torch.nn.functional.linear(x, weight.to(x.device).bfloat16())


# =====================================================================