diff --git a/single_shot_inference.py b/single_shot_inference.py
index 8ca6c526..5d61e74c 100644
--- a/single_shot_inference.py
+++ b/single_shot_inference.py
@@ -53,7 +53,7 @@ NUM_GPUS = 8
 # NVFP4 dequantization — matches checkpoint format exactly
 # =====================================================================
 
-FP4_LUT = torch.tensor([0., 2., 3., 4., 6., 8., 12., 24.])
+FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]) # E2M1 magnitudes
 
 def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2):
     """Dequantize NVFP4 weight to BF16.