Files
nvfp4-megamoe-kernel/tests/unit/test_kv_compress_quant.py
biondizzle f23320b5b2 KV-1/KV-2: Fused compress+NVFP4 quantize kernels + dequant
- compressor_reduce_quant.cu: Single-kernel CSA/HCA compress + RMSNorm + NVFP4 quantize.
  No intermediate BF16. FP32 → E2M1 + E4M3 + FP32 gsa in one kernel.
  Shared memory: ~2.5KB per CTA (FP32 staging + nibble buffer).

- dequant_nvfp4.cu: NVFP4 → BF16 dequantization kernels.
  Full dequant (HCA dense gather) and selective dequant (CSA top-k gather).
  Single kernel launch per gather operation.

- production_compress.py: Added csa_compress_production_nvfp4() and
  hca_compress_production_nvfp4() — production path for KV-1/KV-2.

- loader.py: Preload dequant_nvfp4 and compressor_reduce_quant modules.

- test_kv_compress_quant.py: Unit tests verifying cos >= 0.999
  between BF16 reference and NVFP4 round-trip path.
2026-06-02 09:37:53 +00:00

171 lines
6.0 KiB
Python

#!/usr/bin/env python3
"""Test KV-1/KV-2: Fused compress + NVFP4 quantize kernel.
Verifies that the single-kernel compress+quantize path produces output
with cos >= 0.999 vs the BF16 reference path.
Production values:
- hd=512, m=4 (CSA), m=128 (HCA)
- T=32 (CSA: 8 blocks), T=256 (HCA: 2 blocks)
- kv_dim=1024 (CSA: 2*hd), kv_dim=512 (HCA: hd)
"""
import torch
import math
def test_csa_compress_quant():
    """KV-1 check: fused CSA compress + NVFP4 quantize against the reference path.

    Seeds the RNG, draws random kv/gate projections, a position bias and a
    norm weight, then runs both ``csa_compress_production`` (the reference,
    labelled BF16 in this file) and ``csa_compress_production_nvfp4`` on the
    same inputs. The NVFP4 result is dequantized with the ``dequant_nvfp4``
    CUDA module and compared to the reference.

    Raises:
        AssertionError: if the cosine similarity of the flattened outputs is
            below 0.999.
    """
    torch.manual_seed(42)
    device = 'cuda'
    hd, m = 512, 4
    T = 32  # 8 compressed blocks
    kv_dim = 2 * hd  # CSA uses 2*hd for kv/gate projections

    # Draw order matters: it fixes the values produced under the seed above.
    kv_proj = torch.randn(T, kv_dim, device=device) * 0.5
    gate_proj = torch.randn(T, kv_dim, device=device) * 0.3
    position_bias = torch.randn(m, kv_dim, device=device) * 0.1
    kv_norm_weight = torch.randn(hd, device=device).abs() + 0.5
    compress_args = (kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight)

    # Reference path first, then the fused NVFP4 path, on identical inputs.
    from dsv4.kernels.compressor.production_compress import (
        csa_compress_production,
        csa_compress_production_nvfp4,
    )
    ref_bf16 = csa_compress_production(*compress_args, m=m)
    fp4_data, sf, gsa = csa_compress_production_nvfp4(*compress_args, m=m)

    # Bring the NVFP4 output back to a dense tensor for comparison.
    from dsv4.kernels.cuda.loader import get_cuda_module
    dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
    packed_bytes = fp4_data.view(torch.uint8)
    scale_bytes = sf.view(torch.uint8)
    nvfp4_bf16 = dequant_mod.dequant_nvfp4(packed_bytes, scale_bytes, gsa)

    # Error metrics are computed in FP32.
    expected = ref_bf16.float()
    actual = nvfp4_bf16.float()
    similarity = torch.nn.functional.cosine_similarity(
        expected.flatten(), actual.flatten(), dim=0
    ).item()
    worst_abs_err = torch.abs(expected - actual).max().item()
    expected_peak = expected.abs().max().item()
    actual_peak = actual.abs().max().item()

    print("CSA compress + NVFP4 quantize:")
    print(f" ref shape: {tuple(ref_bf16.shape)}, nvfp4 shape: {tuple(nvfp4_bf16.shape)}")
    print(f" fp4 shape: {tuple(fp4_data.shape)}, sf shape: {tuple(sf.shape)}, gsa shape: {tuple(gsa.shape)}")
    print(f" |ref|_max: {expected_peak:.4f}, |nvfp4|_max: {actual_peak:.4f}")
    print(f" max_error: {worst_abs_err:.6f}")
    print(f" cosine: {similarity:.6f}")
    assert similarity >= 0.999, f"CSA compress+quant cos={similarity:.6f} < 0.999"
    print(f" ✅ PASS (cos={similarity:.6f})")
def test_hca_compress_quant():
    """KV-2 check: fused HCA compress + NVFP4 quantize against the reference path.

    Seeds the RNG, draws random kv/gate projections, a position bias and a
    norm weight (all with feature width ``hd``), then runs both
    ``hca_compress_production`` (the reference, labelled BF16 in this file)
    and ``hca_compress_production_nvfp4`` on the same inputs. The NVFP4
    result is dequantized with the ``dequant_nvfp4`` CUDA module and compared
    to the reference.

    Raises:
        AssertionError: if the cosine similarity of the flattened outputs is
            below 0.999.
    """
    torch.manual_seed(42)
    device = 'cuda'
    hd, m = 512, 128
    T = 256  # 2 compressed blocks

    # Draw order matters: it fixes the values produced under the seed above.
    kv_proj = torch.randn(T, hd, device=device) * 0.5
    gate_proj = torch.randn(T, hd, device=device) * 0.3
    position_bias = torch.randn(m, hd, device=device) * 0.1
    kv_norm_weight = torch.randn(hd, device=device).abs() + 0.5
    compress_args = (kv_proj.float(), gate_proj.float(), position_bias, kv_norm_weight)

    # Reference path first, then the fused NVFP4 path, on identical inputs.
    from dsv4.kernels.compressor.production_compress import (
        hca_compress_production,
        hca_compress_production_nvfp4,
    )
    ref_bf16 = hca_compress_production(*compress_args, m=m)
    fp4_data, sf, gsa = hca_compress_production_nvfp4(*compress_args, m=m)

    # Bring the NVFP4 output back to a dense tensor for comparison.
    from dsv4.kernels.cuda.loader import get_cuda_module
    dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
    packed_bytes = fp4_data.view(torch.uint8)
    scale_bytes = sf.view(torch.uint8)
    nvfp4_bf16 = dequant_mod.dequant_nvfp4(packed_bytes, scale_bytes, gsa)

    # Error metrics are computed in FP32.
    expected = ref_bf16.float()
    actual = nvfp4_bf16.float()
    similarity = torch.nn.functional.cosine_similarity(
        expected.flatten(), actual.flatten(), dim=0
    ).item()
    worst_abs_err = torch.abs(expected - actual).max().item()
    expected_peak = expected.abs().max().item()
    actual_peak = actual.abs().max().item()

    print("HCA compress + NVFP4 quantize:")
    print(f" ref shape: {tuple(ref_bf16.shape)}, nvfp4 shape: {tuple(nvfp4_bf16.shape)}")
    print(f" |ref|_max: {expected_peak:.4f}, |nvfp4|_max: {actual_peak:.4f}")
    print(f" max_error: {worst_abs_err:.6f}")
    print(f" cosine: {similarity:.6f}")
    assert similarity >= 0.999, f"HCA compress+quant cos={similarity:.6f} < 0.999"
    print(f" ✅ PASS (cos={similarity:.6f})")
def test_dequant_selective():
    """Test selective dequant: only top-k entries from a larger FP4 buffer.

    Quantizes a random ``(M, N)`` BF16 tensor with
    ``quantize_nvfp4_gpu_fused``, picks ``K`` distinct random row indices,
    and checks that ``dequant_nvfp4_selective`` on those indices matches the
    same rows taken from a full ``dequant_nvfp4`` of the whole buffer.

    Raises:
        AssertionError: if the cosine similarity between the selectively
            dequantized rows and the reference rows is below 0.9999.
    """
    torch.manual_seed(42)
    device = 'cuda'
    M = 64  # total entries in cache
    N = 512  # hd
    K = 8  # top-k

    # Create BF16 data
    bf16_data = torch.randn(M, N, device=device, dtype=torch.bfloat16) * 2.0

    # Quantize to NVFP4
    from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
    fp4, sf, gsa = quantize_nvfp4_gpu_fused(bf16_data)

    # Select K random indices (int32 is what gets handed to the kernel)
    indices = torch.randperm(M, device=device)[:K].to(torch.int32)

    # Selective dequant
    from dsv4.kernels.cuda.loader import get_cuda_module
    dequant_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
    sel_bf16 = dequant_mod.dequant_nvfp4_selective(
        fp4.view(torch.uint8),
        sf.view(torch.uint8),
        gsa,
        indices,
    )

    # Full dequant for comparison
    full_bf16 = dequant_mod.dequant_nvfp4(
        fp4.view(torch.uint8),
        sf.view(torch.uint8),
        gsa,
    )

    # Gather the reference rows with a tensor index instead of round-tripping
    # the indices through host memory and NumPy: no NumPy dependency, and
    # int64 is the index dtype torch advanced indexing accepts on every
    # version. The index is placed on whatever device the full dequant result
    # lives on, and the trailing .to(device) is kept so the result ends up on
    # `device` exactly as before.
    row_ids = indices.to(device=full_bf16.device, dtype=torch.long)
    ref = full_bf16[row_ids].to(device)

    cos = torch.nn.functional.cosine_similarity(
        sel_bf16.float().flatten(), ref.float().flatten(), dim=0
    ).item()
    print(f"Selective dequant (M={M}, K={K}, N={N}):")
    print(f" sel shape: {tuple(sel_bf16.shape)}")
    print(f" cosine vs full dequant: {cos:.6f}")
    assert cos >= 0.9999, f"Selective dequant cos={cos:.6f} < 0.9999"
    print(f" ✅ PASS (cos={cos:.6f})")
if __name__ == "__main__":
    # Script entry point: run every test in order under a banner. Any failed
    # assertion propagates and stops the run before the final summary.
    rule = "=" * 60
    print(rule)
    print("KV-1/KV-2: Compress + NVFP4 Quantize Tests")
    print("Production values: hd=512, m=4 (CSA), m=128 (HCA)")
    print(rule)

    suite = (
        test_csa_compress_quant,
        test_hca_compress_quant,
        test_dequant_selective,
    )
    for position, run_case in enumerate(suite):
        if position:
            # Blank line between consecutive test reports.
            print()
        run_case()

    print("\n" + rule)
    print("ALL TESTS PASSED")
    print(rule)