Architecture matches paper: 'BF16 for RoPE dims, FP8 for remaining dims'
- Non-RoPE dims (448 of 512): FP8_E4M3 storage → dequant to BF16 for FMHA
- RoPE dims (64 of 512): BF16 storage (RoPE applied directly, no conversion)
- Indexer keys: FP8_E4M3 (ihd=128, no RoPE)
- SWA: BF16 (unchanged)

Pipeline: Compressor → FP32 → split → [nope: FP32→FP8] + [rope: FP32→BF16→RoPE]
Gather: [nope: FP8→BF16] + [rope: BF16] → concat → FMHA

No BF16 intermediate for non-RoPE data. No FP32 intermediate after BF16 RoPE.
BF16 is the final format consumed by FMHA (no further conversion).

KVCache rewritten:
- comp_nope_fp8/scale: FP8 storage for non-RoPE
- comp_rope_bf16: BF16 storage for RoPE
- comp_nope_selective/all: FP8→BF16 dequant
- comp_rope_selective/all: BF16 gather
- set_compressed_mixed: write mixed format
- set_indexer_keys_fp8: write FP8 indexer keys
115 lines
5.3 KiB
Python
115 lines
5.3 KiB
Python
#!/usr/bin/env python3
"""KV-1/KV-2/KV-3: Mixed FP8+BF16 compressed KV + FP8 indexer keys — production-value unit tests.

Tests the kv_quantize.cu kernels at production shapes:
- FP8_E4M3: nope_dim=448, ihd=128
- BF16 RoPE: rope_dim=64
- Mixed storage: FP8 nope + BF16 rope → concat → compare with a BF16 reference
  built from the same FP32 data

Falsifiable gates:
- FP8_E4M3 quantize FP32→FP8→BF16: cos ≥ 0.998 vs FP32 reference
- Mixed storage round-trip: FP32 → (FP8 nope + BF16 rope) → BF16 concat: cos ≥ 0.998
- Selective dequant matches full dequant
"""

import torch
import math

# Seed PyTorch's RNG so the torch.randn test data below is repeatable.
torch.manual_seed(42)
device = 'cuda'

# Project loader for the kv_quantize extension; every test below calls into `mod`.
# NOTE(review): presumably this builds/loads kv_quantize.cu on first use — confirm in the loader.
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])

# Banner: rule, title, rule — one line each.
print(
    "=" * 60,
    "KV-1/KV-2/KV-3: Mixed FP8+BF16 Storage — Unit Tests",
    "=" * 60,
    sep="\n",
)

# Head layout used by all tests: the trailing rope_dim columns carry RoPE,
# the leading nope_dim columns do not.
hd = 512
rope_dim = 64
nope_dim = hd - rope_dim  # 448

# ===========================================================================
# Test 1: FP8_E4M3 nope round-trip (production nope_dim=448)
# ===========================================================================
print("\n--- Test 1: FP8_E4M3 nope FP32→FP8→BF16 (nope_dim=448) ---")
for M in (1, 4, 8, 32, 128):
    # FP32 source rows: standard-normal samples scaled by 3.0.
    source = 3.0 * torch.randn(M, nope_dim, device=device, dtype=torch.float32)

    # Quantize, then pass the FP8 payload to the dequant kernel reinterpreted as uint8.
    packed, quant_scale = mod.quantize_fp8_e4m3_from_fp32(source.contiguous())
    restored = mod.dequant_fp8_e4m3(packed.view(torch.uint8), quant_scale)

    # Compare in FP32; `source` is already FP32, only the kernel output is upcast.
    restored_fp32 = restored.float()
    cos = torch.nn.functional.cosine_similarity(
        source.flatten(), restored_fp32.flatten(), dim=0
    ).item()
    max_err = (source - restored_fp32).abs().max().item()

    print(f" M={M:3d}: cos={cos:.6f} max_err={max_err:.4f}")
    assert cos >= 0.998, f"FP8 nope round-trip cos={cos:.6f} < 0.998 at M={M}"

# ===========================================================================
# Test 2: FP8_E4M3 indexer keys (production ihd=128)
# ===========================================================================
print("\n--- Test 2: FP8_E4M3 indexer keys (ihd=128) ---")
for M in (1, 4, 32, 128):
    # Indexer-key rows are 128 wide (the ihd from the section title).
    keys_fp32 = 3.0 * torch.randn(M, 128, device=device, dtype=torch.float32)

    # Same quantize → dequant pair as Test 1, only the row width differs.
    keys_fp8, keys_scale = mod.quantize_fp8_e4m3_from_fp32(keys_fp32.contiguous())
    keys_back = mod.dequant_fp8_e4m3(keys_fp8.view(torch.uint8), keys_scale)

    # One cosine over the whole flattened batch, computed in FP32.
    cos = torch.nn.functional.cosine_similarity(
        keys_fp32.flatten(), keys_back.float().flatten(), dim=0
    ).item()

    print(f" M={M:3d}: cos={cos:.6f}")
    assert cos >= 0.998, f"FP8 indexer cos={cos:.6f} < 0.998 at M={M}"

# ===========================================================================
# Test 3: Mixed storage round-trip (FP8 nope + BF16 rope → concat)
# ===========================================================================
print("\n--- Test 3: Mixed FP8+BF16 full round-trip (hd=512) ---")
# Build proper RoPE cache
theta = 10000.0
pair_exponents = torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim
freqs = 1.0 / (theta ** pair_exponents)
# One row of angles per position 0..1023, one column per frequency.
angles = torch.outer(torch.arange(1024, dtype=torch.float32), freqs)
# cos/sin are evaluated on the CPU tensor first and only then moved to the device.
cos_cache = torch.cos(angles).to(device)
sin_cache = torch.sin(angles).to(device)

from dsv4.ops.rope_cuda import apply_rope

for M in (1, 4, 8, 32):
    # Simulate compressor FP32 output
    data_fp32 = 3.0 * torch.randn(M, hd, device=device, dtype=torch.float32)
    positions = torch.arange(M, device=device, dtype=torch.long)

    # Reference: FP32 → BF16 → RoPE on the full-width row, shaped (M, 1, hd)
    # for apply_rope and squeezed back to (M, hd) afterwards.
    ref_full = apply_rope(
        data_fp32.bfloat16().unsqueeze(1), positions, cos_cache, sin_cache, rope_dim
    ).squeeze(1)

    # Our path, RoPE half: the last rope_dim columns go FP32 → BF16 → RoPE.
    rope_in = data_fp32[:, nope_dim:].bfloat16().contiguous().unsqueeze(1)
    rope_bf16 = apply_rope(
        rope_in, positions, cos_cache, sin_cache, rope_dim
    ).squeeze(1)

    # Our path, non-RoPE half: the first nope_dim columns go FP32 → FP8 → dequant.
    # NOTE(review): the dequant output is presumably BF16, as the variable name suggests — confirm.
    nope_fp32 = data_fp32[:, :nope_dim].contiguous()
    nope_fp8, nope_scale = mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
    nope_bf16 = mod.dequant_fp8_e4m3(nope_fp8.view(torch.uint8), nope_scale)

    # Concat nope + rope back into a full-width row.
    result = torch.cat([nope_bf16, rope_bf16], dim=1)

    # Score the reassembled rows against the reference in FP32.
    ref_vals = ref_full.float()
    out_vals = result.float()
    cos = torch.nn.functional.cosine_similarity(
        ref_vals.flatten(), out_vals.flatten(), dim=0
    ).item()
    max_err = (ref_vals - out_vals).abs().max().item()

    print(f" M={M:3d}: cos={cos:.6f} max_err={max_err:.4f}")
    assert cos >= 0.998, f"Mixed storage cos={cos:.6f} < 0.998 at M={M}"

# ===========================================================================
# Test 4: Selective dequant (CSA top-k gather)
# ===========================================================================
print("\n--- Test 4: Selective FP8 dequant (CSA top-k gather) ---")
M = 32
data = torch.randn(M, nope_dim, device=device, dtype=torch.float32) * 3.0
fp8, scale = mod.quantize_fp8_e4m3_from_fp32(data.contiguous())
fp8_bytes = fp8.view(torch.uint8)

# Dequantize every row once, then only the gathered rows, from the same FP8 bytes and scale.
full_deq = mod.dequant_fp8_e4m3(fp8_bytes, scale)
indices = torch.tensor([0, 5, 10, 20, 31], device=device, dtype=torch.int32)
sel_deq = mod.dequant_fp8_e4m3_selective(fp8_bytes, scale, indices)

# Iterate plain Python ints rather than 0-d CUDA tensors so row lookups use
# ordinary integer indexing and failure messages print "idx=5", not a tensor repr.
row_ids = indices.tolist()
for i, idx in enumerate(row_ids):
    full_row = full_deq[idx].float().flatten()
    sel_row = sel_deq[i].float().flatten()

    # Direction check: entry i of the selective output must line up with row idx.
    cos = torch.nn.functional.cosine_similarity(full_row, sel_row, dim=0).item()
    assert cos >= 0.99999, f"Selective mismatch at idx={idx}: cos={cos:.6f}"

    # Magnitude check: cosine similarity is scale-invariant, so on its own it
    # would accept a row dequantized with the wrong scale. The 1% bound relative
    # to the row's largest magnitude is meant to tolerate a last-bit rounding
    # difference between the two kernels while still failing on a scale mismatch.
    max_err = (full_row - sel_row).abs().max().item()
    tol = 1e-2 * full_row.abs().max().item()
    assert max_err <= tol, (
        f"Selective magnitude mismatch at idx={idx}: "
        f"max_err={max_err:.6f} > tol={tol:.6f}"
    )
print(f" All {len(row_ids)} selective dequant entries match full dequant: PASS")

# Closing banner: blank line, rule, verdict, rule. Only reached if no assert above fired.
print("\n" + "=" * 60, "ALL TESTS PASSED", "=" * 60, sep="\n")