143 lines
6.5 KiB
Python
143 lines
6.5 KiB
Python
#!/usr/bin/env python3
"""KV-1/KV-2/KV-3: NVFP4 compressed KV + FP8 indexer keys — production-value unit tests.

Tests the kv_quantize.cu kernels at production shapes:
- NVFP4: hd=512 (not 64/128)
- FP8_E4M3: ihd=128
- FP32 RoPE: rope_dim=64
- Multiple batch sizes (between 1 and 128, depending on the test)

Falsifiable gates (thresholds as enforced by the assertions in this file):
- NVFP4 quantize FP32→NVFP4→BF16: cos ≥ 0.990 vs FP32 reference
- FP8_E4M3 quantize FP32→FP8→BF16: cos ≥ 0.998 vs FP32 reference
- FP32 RoPE: cos ≥ 0.99999 vs PyTorch FP32 reference
- Selective dequant matches full dequant (cos ≥ 0.99999 per selected row)
"""
|
|
|
|
import torch
|
|
import math
|
|
|
|
# Fixed seed so every run draws the same random test tensors.
torch.manual_seed(42)
device = 'cuda'

# Every tensor below is allocated on `device`, so nothing in this script can
# run without a GPU.  Fail up front with a readable message instead of an
# error raised from somewhere inside the extension loader or a kernel call.
if not torch.cuda.is_available():
    raise RuntimeError("kv_quantize unit tests require a CUDA device")

from dsv4.kernels.cuda.loader import get_cuda_module

# Extension module under test.  Presumably get_cuda_module builds/loads the
# listed .cu source and returns its Python bindings — confirm in the loader.
mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])

print("=" * 60)
print("KV-1/KV-2/KV-3: Production-Value Unit Tests")
print("=" * 60)
|
|
|
|
# ===========================================================================
# Test 1: NVFP4 quantize FP32 → NVFP4 → BF16
# ===========================================================================
print("\n--- Test 1: NVFP4 FP32→NVFP4 round-trip (production hd=512) ---")

# The dequant module does not depend on the batch size, so request it once
# here rather than on every loop iteration.
deq_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])

for M in [1, 4, 8, 32]:
    # Random FP32 input of shape (M, 512), scaled up to widen the value range.
    data = torch.randn(M, 512, device=device, dtype=torch.float32) * 5.0

    # NOTE(review): 6.0 * 448.0 is presumably (E2M1 max) * (E4M3 max), i.e. the
    # representable range of an FP4 value times its FP8 block scale — confirm
    # against compute_amax_gsa_fp32 in kv_quantize.cu.
    gsa = mod.compute_amax_gsa_fp32(data.contiguous(), 6.0 * 448.0)
    fp4, sf = mod.quantize_nvfp4_from_fp32(data.contiguous(), gsa)

    # Dequant using the proven dequant_nvfp4 kernel; both buffers are passed
    # as raw uint8 views.
    deq = deq_mod.dequant_nvfp4(fp4.view(torch.uint8), sf.view(torch.uint8), gsa)

    # `data` is already FP32; only the dequantized output needs upcasting.
    deq_f32 = deq.float()
    cos = torch.nn.functional.cosine_similarity(data.flatten(), deq_f32.flatten(), dim=0).item()
    max_err = (data - deq_f32).abs().max().item()
    print(f" M={M:3d}: cos={cos:.6f} max_err={max_err:.4f} |data|_max={data.abs().max().item():.2f}")
    assert cos >= 0.990, f"NVFP4 round-trip cos={cos:.6f} < 0.990 at M={M}"
|
|
|
|
# ===========================================================================
# Test 2: FP8_E4M3 quantize FP32 → FP8 → BF16
# ===========================================================================
print("\n--- Test 2: FP8_E4M3 FP32→FP8 round-trip (production ihd=128) ---")

for M in (1, 4, 32, 128):
    # Random FP32 input of shape (M, 128).
    data = torch.randn(M, 128, device=device, dtype=torch.float32) * 3.0

    # Quantize, then immediately dequantize through the module under test;
    # the FP8 payload is handed back as a raw uint8 view.
    fp8, scale = mod.quantize_fp8_e4m3_from_fp32(data.contiguous())
    deq = mod.dequant_fp8_e4m3(fp8.view(torch.uint8), scale)

    # Compare in FP32: direction agreement (cosine) and worst-case abs error.
    reference = data.float()
    restored = deq.float()
    cos = torch.nn.functional.cosine_similarity(
        reference.flatten(), restored.flatten(), dim=0
    ).item()
    max_err = (reference - restored).abs().max().item()

    print(f" M={M:3d}: cos={cos:.6f} max_err={max_err:.4f} |data|_max={data.abs().max().item():.2f}")
    assert cos >= 0.998, f"FP8 round-trip cos={cos:.6f} < 0.998 at M={M}"
|
|
|
|
# ===========================================================================
# Test 3: FP32 RoPE
# ===========================================================================
print("\n--- Test 3: FP32 RoPE (production rope_dim=64, hd=512) ---")
hd = 512
rope_dim = 64

# Random (cos, sin) tables: one row per position, one column per rotated pair.
cos_cache = torch.randn(1024, rope_dim // 2, device=device, dtype=torch.float32)
sin_cache = torch.randn(1024, rope_dim // 2, device=device, dtype=torch.float32)

# Normalize for proper RoPE: scale each (cos, sin) pair so cos^2 + sin^2 == 1.
# The norm is computed from the un-normalized values before either table is
# modified, and applied to the whole table at once.
norm = (cos_cache ** 2 + sin_cache ** 2).sqrt().clamp(min=1e-8)
cos_cache /= norm
sin_cache /= norm

for M in [1, 4, 8]:
    data = torch.randn(M, hd, device=device, dtype=torch.float32) * 2.0
    positions = torch.arange(M, device=device, dtype=torch.long)

    # FP32 RoPE via kv_quantize.  The return value is unused and data_kv is
    # compared afterwards, so the kernel presumably rotates data_kv in place.
    # NOTE(review): the meaning of the trailing `False` flag is not visible
    # here — confirm against rope_fp32 in kv_quantize.cu.
    data_kv = data.clone()
    mod.rope_fp32(data_kv, positions, cos_cache, sin_cache, rope_dim, False)

    # PyTorch FP32 reference: rotate the interleaved (even, odd) pairs that
    # occupy the last rope_dim columns; the first `nope` columns are untouched.
    data_ref = data.clone()
    nope = hd - rope_dim
    c = cos_cache[positions]  # (M, rope_dim/2)
    s = sin_cache[positions]  # (M, rope_dim/2)

    # Snapshot both lanes BEFORE writing.  Slices of data_ref are views, so
    # without the copies the odd-lane formula would read the already-rotated
    # even lane instead of the original value.
    ev = data_ref[:, nope::2].clone()
    od = data_ref[:, nope + 1::2].clone()
    data_ref[:, nope::2] = ev * c - od * s
    data_ref[:, nope + 1::2] = ev * s + od * c

    cos_sim = torch.nn.functional.cosine_similarity(data_kv.flatten(), data_ref.flatten(), dim=0).item()
    max_err = (data_kv - data_ref).abs().max().item()
    print(f" M={M}: cos={cos_sim:.6f} max_err={max_err:.8f}")
    assert cos_sim >= 0.99999, f"FP32 RoPE cos={cos_sim:.6f} < 0.99999 at M={M}"
|
|
|
|
# ===========================================================================
# Test 4: Selective dequant matches full dequant (NVFP4)
# ===========================================================================
print("\n--- Test 4: Selective dequant NVFP4 (CSA top-k gather) ---")
M = 32
hd = 512
data = torch.randn(M, hd, device=device, dtype=torch.float32) * 5.0
gsa = mod.compute_amax_gsa_fp32(data.contiguous(), 6.0 * 448.0)
fp4, sf = mod.quantize_nvfp4_from_fp32(data.contiguous(), gsa)

deq_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])

# Full dequant
full_deq = deq_mod.dequant_nvfp4(fp4.view(torch.uint8), sf.view(torch.uint8), gsa)

# Selective dequant — pick 5 entries.  The row numbers are kept as a plain
# Python list as well, so the comparison loop indexes with ints (no per-element
# device sync) and a failure message shows "idx=5" rather than a tensor repr.
picks = [0, 5, 10, 20, 31]
indices = torch.tensor(picks, device=device, dtype=torch.int32)
sel_deq = deq_mod.dequant_nvfp4_selective(fp4.view(torch.uint8), sf.view(torch.uint8), gsa, indices)

# Compare: row i of the selective output against row picks[i] of the full one.
# NOTE(review): cosine similarity is scale-invariant, so a uniform scaling
# error in the selective path would not be detected by this gate; consider
# adding a torch.equal / torch.allclose check once bit-exactness is confirmed.
for i, idx in enumerate(picks):
    cos = torch.nn.functional.cosine_similarity(
        full_deq[idx].float().flatten(), sel_deq[i].float().flatten(), dim=0).item()
    assert cos >= 0.99999, f"Selective dequant mismatch at idx={idx}: cos={cos:.6f}"
print(f" All {len(picks)} selective dequant entries match full dequant: PASS")
|
|
|
|
# ===========================================================================
# Test 5: FP8 selective dequant
# ===========================================================================
print("\n--- Test 5: Selective dequant FP8 (indexer key gather) ---")
M = 64
ihd = 128
data = torch.randn(M, ihd, device=device, dtype=torch.float32) * 3.0
fp8, scale = mod.quantize_fp8_e4m3_from_fp32(data.contiguous())

# Full dequant of every row, used as the reference for the gathered rows.
full_deq = mod.dequant_fp8_e4m3(fp8.view(torch.uint8), scale)

# The row numbers are kept as a plain Python list as well, so the comparison
# loop indexes with ints (no per-element device sync) and a failure message
# shows "idx=15" rather than a tensor repr.
picks = [0, 15, 30, 45, 63]
indices = torch.tensor(picks, device=device, dtype=torch.int32)
sel_deq = mod.dequant_fp8_e4m3_selective(fp8.view(torch.uint8), scale, indices)

# Compare: row i of the selective output against row picks[i] of the full one.
# NOTE(review): cosine similarity is scale-invariant, so a uniform scaling
# error in the selective path would not be detected by this gate; consider
# adding a torch.equal / torch.allclose check once bit-exactness is confirmed.
for i, idx in enumerate(picks):
    cos = torch.nn.functional.cosine_similarity(
        full_deq[idx].float().flatten(), sel_deq[i].float().flatten(), dim=0).item()
    assert cos >= 0.99999, f"FP8 selective mismatch at idx={idx}: cos={cos:.6f}"
print(f" All {len(picks)} selective dequant entries match: PASS")
|
|
|
|
# Closing banner: only reached when every assertion above has passed.
print("\n" + "=" * 60, "ALL TESTS PASSED", "=" * 60, sep="\n")
|