diff --git a/tests/unit/test_nvfp4_gpu_quantize.py b/tests/unit/test_nvfp4_gpu_quantize.py index 70586ac8..a6350fb3 100644 --- a/tests/unit/test_nvfp4_gpu_quantize.py +++ b/tests/unit/test_nvfp4_gpu_quantize.py @@ -83,24 +83,21 @@ def test_quantize_nvfp4_gpu_larger(): torch.manual_seed(42) M, N = 64, 4096 x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda') - global_scale = 1.0 / (6.0 * 448.0) + # Use quantize_to_nvfp4 to get a proper global_scale (from data amax) + from dsv4.ops.quantize import quantize_to_nvfp4 + _, _, global_scale = quantize_to_nvfp4(x) ref_fp4, ref_sf = quantize_activation_nvfp4(x, global_scale) gpu_fp4, gpu_sf = quantize_nvfp4_gpu(x, global_scale) - ref_deq = _dequantize_nvfp4(ref_fp4, ref_sf, global_scale, N) - gpu_deq = _dequantize_nvfp4(gpu_fp4, gpu_sf, global_scale, N) + # Byte-exact comparison + fp4_match = (ref_fp4.view(torch.uint8) == gpu_fp4.view(torch.uint8)).float().mean().item() + sf_match = (ref_sf.view(torch.uint8) == gpu_sf.view(torch.uint8)).float().mean().item() + print(f" FP4 byte match: {fp4_match*100:.1f}%") + print(f" SF byte match: {sf_match*100:.1f}%") - cos_ref = torch.nn.functional.cosine_similarity( - x.flatten().float().unsqueeze(0), ref_deq.flatten().float().unsqueeze(0) - ).item() - cos_gpu = torch.nn.functional.cosine_similarity( - x.flatten().float().unsqueeze(0), gpu_deq.flatten().float().unsqueeze(0) - ).item() - - print(f" Python round-trip cos: {cos_ref:.6f}") - print(f" GPU round-trip cos: {cos_gpu:.6f}") - assert cos_gpu >= 0.95, f"GPU round-trip cosine too low: {cos_gpu}" + assert fp4_match >= 0.99, f"FP4 byte match too low: {fp4_match}" + assert sf_match >= 0.99, f"SF byte match too low: {sf_match}" print(f" ✅ PASS") @@ -110,7 +107,7 @@ def test_quantize_nvfp4_gpu_no_cpu_sync(): torch.manual_seed(42) M, N = 32, 512 x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda') - global_scale = 0.001 + global_scale = 1.0 # This should NOT trigger a CUDA synchronization # If it does, cudagraph capture would fail