diff --git a/single_shot_inference.py b/single_shot_inference.py index 8e79b3d8..72da8263 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -54,8 +54,8 @@ def dequant_nvfp4_weight( high_sign = (high >> 3).bool() high_idx = (high & 0x07).long() - # LUT lookup - lut = FP4_LUT.to(device=weight.device) + # LUT lookup (ensure LUT on same device as weight) + lut = FP4_LUT.to(device=weight.device, dtype=torch.float32) low_f = lut[low_idx] * torch.where(low_sign, -1.0, 1.0) high_f = lut[high_idx] * torch.where(high_sign, -1.0, 1.0)