Fix: move weight tensors to CUDA before dequant

This commit is contained in:
2026-05-30 22:43:47 +00:00
parent 47c7b3c50b
commit 5dcfb333ea

View File

@@ -128,12 +128,16 @@ class CheckpointReader:
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """NVFP4 linear: dequantize the weight, then run a dense matmul.

    Args:
        x: Input activations. Their device is where the matmul runs.
        weight: Quantized weight tensor, in whatever layout
            ``dequant_nvfp4_weight`` expects. May live on another device.
        weight_scale: Scale tensor forwarded to ``dequant_nvfp4_weight``.
        weight_scale_2: Optional second scale tensor; ``None`` is forwarded
            unchanged.

    Returns:
        The result of ``torch.nn.functional.linear(x, w)``, where ``w`` is
        the dequantized weight (no bias is applied).
    """
    # Move the quantized tensors next to the activations *before*
    # dequantizing, and dequantize exactly once. Following ``x.device``
    # rather than a bare ``.cuda()`` keeps the weight on the same GPU as the
    # input even when that is not the current CUDA device.
    device = x.device
    w = dequant_nvfp4_weight(
        weight.to(device),
        weight_scale.to(device),
        weight_scale_2.to(device) if weight_scale_2 is not None else None,
    )
    return torch.nn.functional.linear(x, w)
def bf16_linear(x, weight):
    """BF16 linear: cast the weight to bfloat16 and apply it to ``x``.

    Args:
        x: Input activations. Their device is where the matmul runs.
        weight: Weight tensor of any dtype/device; it is moved to the device
            of ``x`` and cast to bfloat16 before use.

    Returns:
        The result of ``torch.nn.functional.linear(x, w)`` with the converted
        weight (no bias is applied).
    """
    # The weight must sit on the same device as the activations, otherwise
    # ``linear`` raises a device-mismatch error. ``x.device`` is used instead
    # of a bare ``.cuda()`` so the weight lands on the GPU that actually
    # holds ``x``.
    w = weight.to(x.device).bfloat16()
    return torch.nn.functional.linear(x, w)
# =====================================================================