diff --git a/tests/unit/test_nvfp4_primitives.py b/tests/unit/test_nvfp4_primitives.py index cc2942e2..f4fa4aea 100644 --- a/tests/unit/test_nvfp4_primitives.py +++ b/tests/unit/test_nvfp4_primitives.py @@ -93,10 +93,11 @@ def test_nvfp4_primitives(): print(f" a_sf (E8M0) dtype = {a_sf_e8m0.dtype}, shape = {a_sf_e8m0.shape}") # The runner uses E4M3 or E8M0? Check what quantize actually produces - from dsv4.ops.quantize import quantize_to_nvfp4 + from dsv4.ops.quantize import quantize_activation_nvfp4 x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16) - x_fp4, x_sf = quantize_to_nvfp4(x) - print(f" quantize_to_nvfp4 output: FP4 dtype={x_fp4.dtype}, SF dtype={x_sf.dtype}") + global_scale = torch.tensor(448.0, device='cuda', dtype=torch.float32) + x_fp4, x_sf = quantize_activation_nvfp4(x, global_scale) + print(f" quantize_activation_nvfp4 output: FP4 dtype={x_fp4.dtype}, SF dtype={x_sf.dtype}") if x_sf.dtype == torch.float8_e4m3fn: print(f" ✅ Scale factors are E4M3 — NVFP4 correct") elif x_sf.dtype == torch.float8_e8m0fnu: