#!/usr/bin/env python3 """Test KV-1/KV-2: Fused compress + NVFP4 quantize kernel. Verifies that the single-kernel compress+quantize path produces output with cos >= 0.999 vs the BF16 reference path. Production values: - hd=512, m=4 (CSA), m=128 (HCA) - T=32 (CSA: 8 blocks), T=256 (HCA: 2 blocks) - kv_dim=1024 (CSA: 2*hd), kv_dim=512 (HCA: hd) """ import torch import math def test_csa_compress_quant(): """KV-1: CSA compress + NVFP4 quantize vs BF16 reference.""" torch.manual_seed(42) device = 'cuda' hd = 512 m = 4 T = 32 # 8 compressed blocks kv_dim = 2 * hd # CSA uses 2*hd for kv/gate projections kv_proj = torch.randn(T, kv_dim, device=device) * 0.5 gate_proj = torch.randn(T, kv_dim, device=device) * 0.3 position_bias = torch.randn(m, kv_dim, device=device) * 0.1 kv_norm_weight = torch.randn(hd, device=device).abs() + 0.5 # BF16 reference path from dsv4.kernels.compressor.production_compress import csa_compress_production ref_bf16 = csa_compress_production(kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m) # NVFP4 fused path from dsv4.kernels.compressor.production_compress import csa_compress_production_nvfp4 fp4_data, sf, gsa = csa_compress_production_nvfp4(kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m) # Dequant NVFP4 → BF16 from dsv4.kernels.cuda.loader import get_cuda_module dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"]) nvfp4_bf16 = dequant_mod.dequant_nvfp4( fp4_data.view(torch.uint8), sf.view(torch.uint8), gsa, ) # Compare ref_f = ref_bf16.float() nvfp4_f = nvfp4_bf16.float() cos = torch.nn.functional.cosine_similarity(ref_f.flatten(), nvfp4_f.flatten(), dim=0).item() max_err = (ref_f - nvfp4_f).abs().max().item() ref_max = ref_f.abs().max().item() print(f"CSA compress + NVFP4 quantize:") print(f" ref shape: {tuple(ref_bf16.shape)}, nvfp4 shape: {tuple(nvfp4_bf16.shape)}") print(f" fp4 shape: {tuple(fp4_data.shape)}, sf shape: {tuple(sf.shape)}, gsa shape: {tuple(gsa.shape)}") print(f" |ref|_max: {ref_max:.4f}, |nvfp4|_max: {nvfp4_f.abs().max().item():.4f}") print(f" max_error: {max_err:.6f}") print(f" cosine: {cos:.6f}") assert cos >= 0.999, f"CSA compress+quant cos={cos:.6f} < 0.999" print(f" ✅ PASS (cos={cos:.6f})") def test_hca_compress_quant(): """KV-2: HCA compress + NVFP4 quantize vs BF16 reference.""" torch.manual_seed(42) device = 'cuda' hd = 512 m = 128 T = 256 # 2 compressed blocks kv_proj = torch.randn(T, hd, device=device) * 0.5 gate_proj = torch.randn(T, hd, device=device) * 0.3 position_bias = torch.randn(m, hd, device=device) * 0.1 kv_norm_weight = torch.randn(hd, device=device).abs() + 0.5 # BF16 reference path from dsv4.kernels.compressor.production_compress import hca_compress_production ref_bf16 = hca_compress_production(kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m) # NVFP4 fused path from dsv4.kernels.compressor.production_compress import hca_compress_production_nvfp4 fp4_data, sf, gsa = hca_compress_production_nvfp4(kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m) # Dequant NVFP4 → BF16 from dsv4.kernels.cuda.loader import get_cuda_module dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"]) nvfp4_bf16 = dequant_mod.dequant_nvfp4( fp4_data.view(torch.uint8), sf.view(torch.uint8), gsa, ) # Compare ref_f = ref_bf16.float() nvfp4_f = nvfp4_bf16.float() cos = torch.nn.functional.cosine_similarity(ref_f.flatten(), nvfp4_f.flatten(), dim=0).item() max_err = (ref_f - nvfp4_f).abs().max().item() ref_max = ref_f.abs().max().item() print(f"HCA compress + NVFP4 quantize:") print(f" ref shape: {tuple(ref_bf16.shape)}, nvfp4 shape: {tuple(nvfp4_bf16.shape)}") print(f" |ref|_max: {ref_max:.4f}, |nvfp4|_max: {nvfp4_f.abs().max().item():.4f}") print(f" max_error: {max_err:.6f}") print(f" cosine: {cos:.6f}") assert cos >= 0.999, f"HCA compress+quant cos={cos:.6f} < 0.999" print(f" ✅ PASS (cos={cos:.6f})") def test_dequant_selective(): """Test selective dequant: only top-k entries from a larger FP4 buffer.""" torch.manual_seed(42) device = 'cuda' M = 64 # total entries in cache N = 512 # hd K = 8 # top-k # Create BF16 data bf16_data = torch.randn(M, N, device=device, dtype=torch.bfloat16) * 2.0 # Quantize to NVFP4 from dsv4.ops.quantize import quantize_nvfp4_gpu_fused fp4, sf, gsa = quantize_nvfp4_gpu_fused(bf16_data) # Select K random indices indices = torch.randperm(M, device=device)[:K].to(torch.int32) # Selective dequant from dsv4.kernels.cuda.loader import get_cuda_module dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"]) sel_bf16 = dequant_mod.dequant_nvfp4_selective( fp4.view(torch.uint8), sf.view(torch.uint8), gsa, indices, ) # Full dequant for comparison full_bf16 = dequant_mod.dequant_nvfp4( fp4.view(torch.uint8), sf.view(torch.uint8), gsa, ) # Compare selected entries ref = full_bf16[indices.cpu().numpy()].to(device) cos = torch.nn.functional.cosine_similarity(sel_bf16.float().flatten(), ref.float().flatten(), dim=0).item() print(f"Selective dequant (M={M}, K={K}, N={N}):") print(f" sel shape: {tuple(sel_bf16.shape)}") print(f" cosine vs full dequant: {cos:.6f}") assert cos >= 0.9999, f"Selective dequant cos={cos:.6f} < 0.9999" print(f" ✅ PASS (cos={cos:.6f})") if __name__ == "__main__": print("=" * 60) print("KV-1/KV-2: Compress + NVFP4 Quantize Tests") print("Production values: hd=512, m=4 (CSA), m=128 (HCA)") print("=" * 60) test_csa_compress_quant() print() test_hca_compress_quant() print() test_dequant_selective() print("\n" + "=" * 60) print("ALL TESTS PASSED") print("=" * 60)