Fix test: use the data-derived global_scale returned by quantize_to_nvfp4 in the larger-shape test; also change the hard-coded global_scale in the no-CPU-sync test
This commit is contained in:
@@ -83,24 +83,21 @@ def test_quantize_nvfp4_gpu_larger():
|
||||
torch.manual_seed(42)
|
||||
M, N = 64, 4096
|
||||
x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
|
||||
global_scale = 1.0 / (6.0 * 448.0)
|
||||
# Use quantize_to_nvfp4 to get a proper global_scale (from data amax)
|
||||
from dsv4.ops.quantize import quantize_to_nvfp4
|
||||
_, _, global_scale = quantize_to_nvfp4(x)
|
||||
|
||||
ref_fp4, ref_sf = quantize_activation_nvfp4(x, global_scale)
|
||||
gpu_fp4, gpu_sf = quantize_nvfp4_gpu(x, global_scale)
|
||||
|
||||
ref_deq = _dequantize_nvfp4(ref_fp4, ref_sf, global_scale, N)
|
||||
gpu_deq = _dequantize_nvfp4(gpu_fp4, gpu_sf, global_scale, N)
|
||||
# Byte-exact comparison
|
||||
fp4_match = (ref_fp4.view(torch.uint8) == gpu_fp4.view(torch.uint8)).float().mean().item()
|
||||
sf_match = (ref_sf.view(torch.uint8) == gpu_sf.view(torch.uint8)).float().mean().item()
|
||||
print(f" FP4 byte match: {fp4_match*100:.1f}%")
|
||||
print(f" SF byte match: {sf_match*100:.1f}%")
|
||||
|
||||
cos_ref = torch.nn.functional.cosine_similarity(
|
||||
x.flatten().float().unsqueeze(0), ref_deq.flatten().float().unsqueeze(0)
|
||||
).item()
|
||||
cos_gpu = torch.nn.functional.cosine_similarity(
|
||||
x.flatten().float().unsqueeze(0), gpu_deq.flatten().float().unsqueeze(0)
|
||||
).item()
|
||||
|
||||
print(f" Python round-trip cos: {cos_ref:.6f}")
|
||||
print(f" GPU round-trip cos: {cos_gpu:.6f}")
|
||||
assert cos_gpu >= 0.95, f"GPU round-trip cosine too low: {cos_gpu}"
|
||||
assert fp4_match >= 0.99, f"FP4 byte match too low: {fp4_match}"
|
||||
assert sf_match >= 0.99, f"SF byte match too low: {sf_match}"
|
||||
print(f" ✅ PASS")
|
||||
|
||||
|
||||
@@ -110,7 +107,7 @@ def test_quantize_nvfp4_gpu_no_cpu_sync():
|
||||
torch.manual_seed(42)
|
||||
M, N = 32, 512
|
||||
x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
|
||||
global_scale = 0.001
|
||||
global_scale = 1.0
|
||||
|
||||
# This should NOT trigger a CUDA synchronization
|
||||
# If it does, cudagraph capture would fail
|
||||
|
||||
Reference in New Issue
Block a user