From 538dbb0643e5c62000c7ecb04dcb221050b4e031 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 08:39:12 +0000 Subject: [PATCH] fix: use quantize_activation_nvfp4 in diag --- tests/unit/test_nvfp4_primitives.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_nvfp4_primitives.py b/tests/unit/test_nvfp4_primitives.py index cc2942e2..f4fa4aea 100644 --- a/tests/unit/test_nvfp4_primitives.py +++ b/tests/unit/test_nvfp4_primitives.py @@ -93,10 +93,11 @@ def test_nvfp4_primitives(): print(f" a_sf (E8M0) dtype = {a_sf_e8m0.dtype}, shape = {a_sf_e8m0.shape}") # The runner uses E4M3 or E8M0? Check what quantize actually produces - from dsv4.ops.quantize import quantize_to_nvfp4 + from dsv4.ops.quantize import quantize_activation_nvfp4 x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16) - x_fp4, x_sf = quantize_to_nvfp4(x) - print(f" quantize_to_nvfp4 output: FP4 dtype={x_fp4.dtype}, SF dtype={x_sf.dtype}") + global_scale = torch.tensor(448.0, device='cuda', dtype=torch.float32) + x_fp4, x_sf = quantize_activation_nvfp4(x, global_scale) + print(f" quantize_activation_nvfp4 output: FP4 dtype={x_fp4.dtype}, SF dtype={x_sf.dtype}") if x_sf.dtype == torch.float8_e4m3fn: print(f" ✅ Scale factors are E4M3 — NVFP4 correct") elif x_sf.dtype == torch.float8_e8m0fnu: