#!/usr/bin/env python3
"""Diagnostic: isolate KV-1 quantize/dequant quality issue.

Runs NVFP4 quantize -> dequantize round-trips and prints cosine-similarity
and max-error statistics for each stage:

1. ``quantize_nvfp4_gpu_fused`` decoded by the CUDA ``dequant_nvfp4`` kernel.
2. ``quantize_to_nvfp4`` decoded by the same CUDA kernel.
3. ``quantize_to_nvfp4`` decoded by a pure-PyTorch reference, plus a
   kernel-vs-reference comparison.
4. The fused CSA compress+quant path against its unquantized reference.

All tensors are allocated on a CUDA device.
"""

import torch

# Fixed seed so each run exercises the same random inputs.
torch.manual_seed(42)

device = 'cuda'

# --- Test 1: quantize_nvfp4_gpu_fused -> CUDA dequant_nvfp4 round-trip ---
# Input is "compressed-like" data: random BF16 values scaled by 5.0.
print("=== Test 1: Existing quantize_nvfp4_gpu_fused round-trip ===")
data = torch.randn(8, 512, device=device, dtype=torch.bfloat16) * 5.0

from dsv4.ops.quantize import quantize_nvfp4_gpu_fused

fp4, sf, gsa = quantize_nvfp4_gpu_fused(data)
print(f" data shape={tuple(data.shape)}, fp4 shape={tuple(fp4.shape)}, sf shape={tuple(sf.shape)}, gsa shape={tuple(gsa.shape)}")
print(f" fp4 dtype={fp4.dtype}, sf dtype={sf.dtype}, gsa dtype={gsa.dtype}")

from dsv4.kernels.cuda.loader import get_cuda_module

# Load the CUDA extension built from dequant_nvfp4.cu.
dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])

# The packed FP4 payload and the scale factors are handed to the kernel
# reinterpreted as raw uint8 bytes.
fp4_bytes = fp4.view(torch.uint8)
sf_bytes = sf.view(torch.uint8)
deq = dequant_mod.dequant_nvfp4(fp4_bytes, sf_bytes, gsa)

# Compare in float32 so the statistics are not limited by BF16 precision.
data_f32 = data.float()
deq_f32 = deq.float()
cos = torch.nn.functional.cosine_similarity(data_f32.flatten(), deq_f32.flatten(), dim=0).item()
print(f" cos(data, dequant): {cos:.6f}")
print(f" |data|_max: {data.abs().max().item():.4f}, |deq|_max: {deq.abs().max().item():.4f}")
max_err = (data_f32 - deq_f32).abs().max().item()
print(f" max_error: {max_err:.4f}")

# --- Test 2: quantize_to_nvfp4 (CPU-synced quantizer) -> CUDA dequant ---
print("\n=== Test 2: quantize_nvfp4 (CPU sync) round-trip ===")

from dsv4.ops.quantize import quantize_to_nvfp4

fp4_2, sf_2, gs_2 = quantize_to_nvfp4(data)

# gs_2 is expanded to one entry per row before being passed to the kernel.
# Presumably quantize_to_nvfp4 returns a single global scale, while the kernel
# takes per-row scales as in Test 1 — TODO confirm.
n_rows = data.shape[0]
per_row_gs = gs_2.expand(n_rows).contiguous()
deq_2 = dequant_mod.dequant_nvfp4(fp4_2.view(torch.uint8), sf_2.view(torch.uint8), per_row_gs)

data_flat = data.float().flatten()
deq_2_flat = deq_2.float().flatten()
cos2 = torch.nn.functional.cosine_similarity(data_flat, deq_2_flat, dim=0).item()
print(f" cos(data, dequant): {cos2:.6f}")

# --- Test 3: pure-PyTorch reference dequant of quantize_to_nvfp4 output ---
print("\n=== Test 3: Python reference dequant ===")

from dsv4.ops.quantize import quantize_to_nvfp4

fp4_3, sf_3, gs_3 = quantize_to_nvfp4(data)

# Python dequant: E2M1_LUT[nibble & 0x7] * sign(nibble & 0x8) * sf * gs.
# The LUT is created on `device` so it can be indexed by the GPU nibble tensors;
# indexing a CPU tensor with CUDA indices raises a device-mismatch error.
E2M1 = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], device=device)

raw = fp4_3.view(torch.uint8)
lo = raw & 0x0F          # placed first in each (lo, hi) pair below
hi = (raw >> 4) & 0x0F   # placed second in each (lo, hi) pair below

# Indices must be int64: a uint8 index tensor is interpreted by PyTorch as a
# boolean mask, not as positions. torch.where likewise requires a bool condition.
lo_mag = E2M1[(lo & 0x07).long()] * torch.where((lo & 0x08) != 0, -1., 1.)
hi_mag = E2M1[(hi & 0x07).long()] * torch.where((hi & 0x08) != 0, -1., 1.)

# Interleave (lo, hi) back into element order. Each scale factor covers 16
# consecutive elements along the last dim; gs_3 is then applied on top.
# NOTE(review): assumes sf_3 is a float8 dtype (so .float() decodes it) stored
# in a plain row-major, unswizzled layout — confirm.
py_deq = torch.stack([lo_mag, hi_mag], -1).reshape(data.shape).float()
py_deq = py_deq * sf_3.float().repeat_interleave(16, -1) * gs_3

cos3 = torch.nn.functional.cosine_similarity(data.float().flatten(), py_deq.flatten(), dim=0).item()
print(f" cos(data, py_dequant): {cos3:.6f}")
print(f" max_error: {(data.float() - py_deq).abs().max().item():.4f}")

# 4. Compare CUDA dequant vs Python dequant on the SAME quantized tensors.
# Both decoders must consume identical (fp4_3, sf_3, gs_3). Test 1's `deq` came
# from a different quantizer (quantize_nvfp4_gpu_fused), so comparing against it
# would mix quantizer differences into what is meant to be a decoder check.
deq_3 = dequant_mod.dequant_nvfp4(
    fp4_3.view(torch.uint8),
    sf_3.view(torch.uint8),
    gs_3.expand(data.shape[0]).contiguous(),  # one scale entry per row, as in Test 2
)
cos_cd = torch.nn.functional.cosine_similarity(deq_3.float().flatten(), py_deq.flatten(), dim=0).item()
print(f" cos(cuda_deq, py_deq): {cos_cd:.6f}")
print(f" max|cuda_deq - py_deq|: {(deq_3.float() - py_deq).abs().max().item():.6f}")

# --- Test 4: fused CSA compress+quant vs the unquantized reference compress ---
print("\n=== Test 4: Fused CSA compress+quant — value comparison ===")

hd, m, T = 512, 4, 32
kv_dim = 2 * hd

# Synthetic inputs. The randn calls keep this exact order so the seeded RNG
# yields the same tensors on every run.
kv_proj = torch.randn(T, kv_dim, device=device) * 0.5
gate_proj = torch.randn(T, kv_dim, device=device) * 0.3
position_bias = torch.randn(m, kv_dim, device=device) * 0.1
# abs() + 0.5 keeps every norm weight >= 0.5.
kv_norm_weight = torch.randn(hd, device=device).abs() + 0.5

from dsv4.kernels.compressor.production_compress import csa_compress_production, csa_compress_production_nvfp4

kv_f32 = kv_proj.float()
gate_f32 = gate_proj.float()

# Reference (unquantized) output and the fused NVFP4 variant fed identical inputs.
ref = csa_compress_production(kv_f32, gate_f32, position_bias, kv_norm_weight, m=m)
fp4, sf, gsa = csa_compress_production_nvfp4(kv_f32, gate_f32, position_bias, kv_norm_weight, m=m)
deq = dequant_mod.dequant_nvfp4(fp4.view(torch.uint8), sf.view(torch.uint8), gsa)

ref_flat = ref.float().flatten()
deq_flat = deq.float().flatten()
cos4 = torch.nn.functional.cosine_similarity(ref_flat, deq_flat, dim=0).item()
print(f" cos(ref, deq): {cos4:.6f}")
print(f" |ref|_max: {ref.abs().max().item():.4f}, |deq|_max: {deq.abs().max().item():.4f}")

# Compare with Python dequant of the fused output. This separates a fault in the
# fused quantizer from a fault in the CUDA dequant kernel.
raw = fp4.view(torch.uint8)
lo = raw & 0x0F
hi = (raw >> 4) & 0x0F

# Put the LUT on the same device as the nibbles. Indices must be int64, because
# uint8 indices are treated as a boolean mask, and the where-condition must be bool.
e2m1_lut = E2M1.to(device)
lo_mag = e2m1_lut[(lo & 0x07).long()] * torch.where((lo & 0x08) != 0, -1., 1.)
hi_mag = e2m1_lut[(hi & 0x07).long()] * torch.where((hi & 0x08) != 0, -1., 1.)

py_deq4 = torch.stack([lo_mag, hi_mag], -1).reshape(ref.shape).float()
# gsa is unsqueezed so it broadcasts across the last dim — presumably one
# scale per row from the fused path; TODO confirm.
py_deq4 = py_deq4 * sf.float().repeat_interleave(16, -1) * gsa.unsqueeze(-1)

cos4p = torch.nn.functional.cosine_similarity(ref.float().flatten(), py_deq4.flatten(), dim=0).item()
print(f" cos(ref, py_deq_of_fused): {cos4p:.6f}")

# Baseline: quantize the unquantized reference with the standard fused
# quantizer and decode it with the same CUDA kernel, for comparison with the
# fused compress+quant path above.
fp4_ref, sf_ref, gsa_ref = quantize_nvfp4_gpu_fused(ref)
fp4_ref_bytes = fp4_ref.view(torch.uint8)
sf_ref_bytes = sf_ref.view(torch.uint8)
deq_ref = dequant_mod.dequant_nvfp4(fp4_ref_bytes, sf_ref_bytes, gsa_ref)

cos4r = torch.nn.functional.cosine_similarity(
    ref.float().flatten(),
    deq_ref.float().flatten(),
    dim=0,
).item()
print(f" cos(ref, standard_quant_dequant): {cos4r:.6f}")

# Global scales side by side: fused compress+quant path vs standard quantizer.
fused_gsa_list = gsa.tolist()
standard_gsa_list = gsa_ref.tolist()
print(f" Fused gsa: {fused_gsa_list}")
print(f" Standard gsa: {standard_gsa_list}")

# Spot-check the first 10 elements of row 0 in each representation.
first_ten = slice(0, 10)
print(f"\n First 10 ref values (row 0): {ref[0, first_ten].float().tolist()}")
print(f" First 10 deq values (row 0): {deq[0, first_ten].float().tolist()}")
print(f" First 10 py_deq values (row 0): {py_deq4[0, first_ten].float().tolist()}")