Fix: move weight tensors to CUDA before dequant
This commit is contained in:
@@ -128,12 +128,16 @@ class CheckpointReader:
|
||||
|
||||
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
|
||||
"""NVFP4 linear: dequant → BF16 matmul."""
|
||||
w = dequant_nvfp4_weight(weight, weight_scale, weight_scale_2)
|
||||
w = dequant_nvfp4_weight(
|
||||
weight.cuda(),
|
||||
weight_scale.cuda(),
|
||||
weight_scale_2.cuda() if weight_scale_2 is not None else None,
|
||||
)
|
||||
return torch.nn.functional.linear(x, w)
|
||||
|
||||
def bf16_linear(x, weight):
|
||||
"""BF16 linear."""
|
||||
return torch.nn.functional.linear(x, weight.bfloat16())
|
||||
return torch.nn.functional.linear(x, weight.cuda().bfloat16())
|
||||
|
||||
|
||||
# =====================================================================
|
||||
|
||||
Reference in New Issue
Block a user