Fix: ensure FP4 LUT is on the weight's device (as float32) before index op

This commit is contained in:
2026-05-30 22:43:01 +00:00
parent 13bae9dd55
commit 47c7b3c50b

View File

@@ -54,8 +54,8 @@ def dequant_nvfp4_weight(
high_sign = (high >> 3).bool()
high_idx = (high & 0x07).long()
# LUT lookup
lut = FP4_LUT.to(device=weight.device)
# LUT lookup (ensure LUT on same device as weight)
lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)
low_f = lut[low_idx] * torch.where(low_sign, -1.0, 1.0)
high_f = lut[high_idx] * torch.where(high_sign, -1.0, 1.0)