- compressor_reduce_quant.cu: Single-kernel CSA/HCA compress + RMSNorm + NVFP4 quantize with no intermediate BF16 (FP32 → E2M1 data + E4M3 scales + FP32 gsa in one kernel). Shared memory: ~2.5 KB per CTA (FP32 staging + nibble buffer).
- dequant_nvfp4.cu: NVFP4 → BF16 dequantization kernels: full dequant (HCA dense gather) and selective dequant (CSA top-k gather), with a single kernel launch per gather operation.
- production_compress.py: Added csa_compress_production_nvfp4() and hca_compress_production_nvfp4() — the production path for KV-1/KV-2.
- loader.py: Preload the dequant_nvfp4 and compressor_reduce_quant modules.
- test_kv_compress_quant.py: Unit tests verifying cos >= 0.999 between the BF16 reference and the NVFP4 round-trip path.
171 lines · 6.0 KiB · Python
#!/usr/bin/env python3
"""Test KV-1/KV-2: Fused compress + NVFP4 quantize kernel.

Verifies that the single-kernel compress+quantize path produces output
with cos >= 0.999 vs the BF16 reference path, and that selective (top-k)
NVFP4 dequant matches the corresponding rows of a full dequant.

Production values:
- hd=512, m=4 (CSA), m=128 (HCA)
- T=32 (CSA: 8 blocks), T=256 (HCA: 2 blocks)
- kv_dim=1024 (CSA: 2*hd), kv_dim=512 (HCA: hd)
"""
|
|
|
|
import torch
|
|
import math
|
|
|
|
def test_csa_compress_quant():
    """KV-1: CSA compress + NVFP4 quantize vs BF16 reference.

    Runs the CSA compressor on the same random inputs twice: once through
    ``csa_compress_production`` (reference) and once through the fused
    ``csa_compress_production_nvfp4`` path. The NVFP4 output is dequantized
    back to BF16 and must reach cosine similarity >= 0.999 with the reference.

    Raises:
        unittest.SkipTest: if no CUDA device is available (the kernels under
            test run on 'cuda').
        AssertionError: if the round-trip shape differs from the reference
            shape, or the cosine similarity is below 0.999.
    """
    import unittest

    # Skip (rather than crash) on CPU-only machines; pytest honours SkipTest.
    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA device required for NVFP4 kernels")

    from dsv4.kernels.compressor.production_compress import (
        csa_compress_production,
        csa_compress_production_nvfp4,
    )
    from dsv4.kernels.cuda.loader import get_cuda_module

    torch.manual_seed(42)
    device = 'cuda'
    hd = 512
    m = 4
    T = 32  # 8 compressed blocks
    kv_dim = 2 * hd  # CSA uses 2*hd for kv/gate projections

    kv_proj = torch.randn(T, kv_dim, device=device) * 0.5
    gate_proj = torch.randn(T, kv_dim, device=device) * 0.3
    position_bias = torch.randn(m, kv_dim, device=device) * 0.1
    kv_norm_weight = torch.randn(hd, device=device).abs() + 0.5

    # BF16 reference path
    ref_bf16 = csa_compress_production(
        kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m
    )

    # NVFP4 fused path
    fp4_data, sf, gsa = csa_compress_production_nvfp4(
        kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m
    )

    # Dequant NVFP4 → BF16
    dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
    nvfp4_bf16 = dequant_mod.dequant_nvfp4(
        fp4_data.view(torch.uint8),
        sf.view(torch.uint8),
        gsa,
    )

    ref_f = ref_bf16.float()
    nvfp4_f = nvfp4_bf16.float()

    print("CSA compress + NVFP4 quantize:")
    print(f" ref shape: {tuple(ref_bf16.shape)}, nvfp4 shape: {tuple(nvfp4_bf16.shape)}")
    print(f" fp4 shape: {tuple(fp4_data.shape)}, sf shape: {tuple(sf.shape)}, gsa shape: {tuple(gsa.shape)}")

    # Flattened cosine only means something if both tensors share a layout;
    # fail loudly on a shape mismatch instead of comparing misaligned elements.
    assert nvfp4_f.shape == ref_f.shape, (
        f"CSA shape mismatch: nvfp4 {tuple(nvfp4_f.shape)} vs ref {tuple(ref_f.shape)}"
    )

    cos = torch.nn.functional.cosine_similarity(ref_f.flatten(), nvfp4_f.flatten(), dim=0).item()
    max_err = (ref_f - nvfp4_f).abs().max().item()
    ref_max = ref_f.abs().max().item()

    print(f" |ref|_max: {ref_max:.4f}, |nvfp4|_max: {nvfp4_f.abs().max().item():.4f}")
    print(f" max_error: {max_err:.6f}")
    print(f" cosine: {cos:.6f}")
    assert cos >= 0.999, f"CSA compress+quant cos={cos:.6f} < 0.999"
    print(f" ✅ PASS (cos={cos:.6f})")
|
|
|
|
|
|
def test_hca_compress_quant():
    """KV-2: HCA compress + NVFP4 quantize vs BF16 reference.

    Runs the HCA compressor on the same random inputs twice: once through
    ``hca_compress_production`` (reference) and once through the fused
    ``hca_compress_production_nvfp4`` path. The NVFP4 output is dequantized
    back to BF16 and must reach cosine similarity >= 0.999 with the reference.

    Raises:
        unittest.SkipTest: if no CUDA device is available (the kernels under
            test run on 'cuda').
        AssertionError: if the round-trip shape differs from the reference
            shape, or the cosine similarity is below 0.999.
    """
    import unittest

    # Skip (rather than crash) on CPU-only machines; pytest honours SkipTest.
    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA device required for NVFP4 kernels")

    from dsv4.kernels.compressor.production_compress import (
        hca_compress_production,
        hca_compress_production_nvfp4,
    )
    from dsv4.kernels.cuda.loader import get_cuda_module

    torch.manual_seed(42)
    device = 'cuda'
    hd = 512
    m = 128
    T = 256  # 2 compressed blocks

    # HCA projections are hd wide (CSA uses 2*hd).
    kv_proj = torch.randn(T, hd, device=device) * 0.5
    gate_proj = torch.randn(T, hd, device=device) * 0.3
    position_bias = torch.randn(m, hd, device=device) * 0.1
    kv_norm_weight = torch.randn(hd, device=device).abs() + 0.5

    # BF16 reference path
    ref_bf16 = hca_compress_production(
        kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m
    )

    # NVFP4 fused path
    fp4_data, sf, gsa = hca_compress_production_nvfp4(
        kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight, m=m
    )

    # Dequant NVFP4 → BF16
    dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
    nvfp4_bf16 = dequant_mod.dequant_nvfp4(
        fp4_data.view(torch.uint8),
        sf.view(torch.uint8),
        gsa,
    )

    ref_f = ref_bf16.float()
    nvfp4_f = nvfp4_bf16.float()

    print("HCA compress + NVFP4 quantize:")
    print(f" ref shape: {tuple(ref_bf16.shape)}, nvfp4 shape: {tuple(nvfp4_bf16.shape)}")
    print(f" fp4 shape: {tuple(fp4_data.shape)}, sf shape: {tuple(sf.shape)}, gsa shape: {tuple(gsa.shape)}")

    # Flattened cosine only means something if both tensors share a layout;
    # fail loudly on a shape mismatch instead of comparing misaligned elements.
    assert nvfp4_f.shape == ref_f.shape, (
        f"HCA shape mismatch: nvfp4 {tuple(nvfp4_f.shape)} vs ref {tuple(ref_f.shape)}"
    )

    cos = torch.nn.functional.cosine_similarity(ref_f.flatten(), nvfp4_f.flatten(), dim=0).item()
    max_err = (ref_f - nvfp4_f).abs().max().item()
    ref_max = ref_f.abs().max().item()

    print(f" |ref|_max: {ref_max:.4f}, |nvfp4|_max: {nvfp4_f.abs().max().item():.4f}")
    print(f" max_error: {max_err:.6f}")
    print(f" cosine: {cos:.6f}")
    assert cos >= 0.999, f"HCA compress+quant cos={cos:.6f} < 0.999"
    print(f" ✅ PASS (cos={cos:.6f})")
|
|
|
|
|
|
def test_dequant_selective():
    """Test selective dequant: only top-k entries from a larger FP4 buffer.

    Quantizes a random (M, N) BF16 tensor to NVFP4 with
    ``quantize_nvfp4_gpu_fused``. It then dequantizes K randomly chosen rows
    with ``dequant_nvfp4_selective`` and checks them against the same rows
    taken from a full ``dequant_nvfp4`` of the whole buffer.

    Raises:
        unittest.SkipTest: if no CUDA device is available.
        AssertionError: on a shape mismatch, cosine similarity below 0.9999,
            or an elementwise mismatch beyond BF16 tolerance.
    """
    import unittest

    # Skip (rather than crash) on CPU-only machines; pytest honours SkipTest.
    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA device required for NVFP4 kernels")

    from dsv4.kernels.cuda.loader import get_cuda_module
    from dsv4.ops.quantize import quantize_nvfp4_gpu_fused

    torch.manual_seed(42)
    device = 'cuda'
    M = 64  # total entries in cache
    N = 512  # hd
    K = 8  # top-k

    # Create BF16 data
    bf16_data = torch.randn(M, N, device=device, dtype=torch.bfloat16) * 2.0

    # Quantize to NVFP4
    fp4, sf, gsa = quantize_nvfp4_gpu_fused(bf16_data)

    # K distinct random row indices (randperm has no repeats), passed as int32.
    indices = torch.randperm(M, device=device)[:K].to(torch.int32)

    dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])

    # Selective dequant
    sel_bf16 = dequant_mod.dequant_nvfp4_selective(
        fp4.view(torch.uint8),
        sf.view(torch.uint8),
        gsa,
        indices,
    )

    # Full dequant for comparison
    full_bf16 = dequant_mod.dequant_nvfp4(
        fp4.view(torch.uint8),
        sf.view(torch.uint8),
        gsa,
    )

    # Gather the reference rows directly on the GPU; int64 is the canonical
    # index dtype and no host/numpy round-trip is needed.
    ref = full_bf16[indices.long()]

    print(f"Selective dequant (M={M}, K={K}, N={N}):")
    print(f" sel shape: {tuple(sel_bf16.shape)}")
    assert sel_bf16.shape == ref.shape, (
        f"Selective dequant shape {tuple(sel_bf16.shape)} != expected {tuple(ref.shape)}"
    )

    cos = torch.nn.functional.cosine_similarity(sel_bf16.float().flatten(), ref.float().flatten(), dim=0).item()
    print(f" cosine vs full dequant: {cos:.6f}")
    assert cos >= 0.9999, f"Selective dequant cos={cos:.6f} < 0.9999"

    # Both kernels decode the same bytes, so every element should agree up to
    # BF16 rounding. Aggregate cosine alone can hide a few mis-decoded values.
    # rtol/atol are torch.testing's bfloat16 defaults.
    torch.testing.assert_close(sel_bf16.float(), ref.float(), rtol=1.6e-2, atol=1e-5)
    print(f" ✅ PASS (cos={cos:.6f})")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run every test in sequence and print a summary banner.
    rule = "=" * 60
    print(rule)
    print("KV-1/KV-2: Compress + NVFP4 Quantize Tests")
    print("Production values: hd=512, m=4 (CSA), m=128 (HCA)")
    print(rule)

    suite = (test_csa_compress_quant, test_hca_compress_quant, test_dequant_selective)
    for position, run_test in enumerate(suite):
        # A blank line separates consecutive tests' output.
        if position > 0:
            print()
        run_test()

    print("\n" + rule)
    print("ALL TESTS PASSED")
    print(rule)
|