Fix: ensure FP4 LUT on CUDA before index op
This commit is contained in:
@@ -54,8 +54,8 @@ def dequant_nvfp4_weight(
|
||||
high_sign = (high >> 3).bool()
|
||||
high_idx = (high & 0x07).long()
|
||||
|
||||
- # LUT lookup
|
||||
- lut = FP4_LUT.to(device=weight.device)
|
||||
+ # LUT lookup (ensure LUT on same device as weight)
|
||||
+ lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)
|
||||
low_f = lut[low_idx] * torch.where(low_sign, -1.0, 1.0)
|
||||
high_f = lut[high_idx] * torch.where(high_sign, -1.0, 1.0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user