KV-1/KV-2: Mixed FP8+BF16 compressed KV (DeepSeek V4 paper format)
Architecture matches the paper: "BF16 for RoPE dims, FP8 for remaining dims".
- Non-RoPE dims (448 of 512): FP8_E4M3 storage → dequantized to BF16 for FMHA
- RoPE dims (64 of 512): BF16 storage (RoPE applied directly, no conversion)
- Indexer keys: FP8_E4M3 (ihd=128, no RoPE)
- SWA: BF16 (unchanged)

Pipeline: Compressor → FP32 → split → [nope: FP32→FP8] + [rope: FP32→BF16→RoPE]
Gather: [nope: FP8→BF16] + [rope: BF16] → concat → FMHA

No BF16 intermediate for non-RoPE data. No FP32 intermediate after BF16 RoPE. BF16 is the final format consumed by FMHA (no further conversion).

KVCache rewritten:
- comp_nope_fp8/scale: FP8 storage for non-RoPE
- comp_rope_bf16: BF16 storage for RoPE
- comp_nope_selective/all: FP8→BF16 dequant
- comp_rope_selective/all: BF16 gather
- set_compressed_mixed: write mixed format
- set_indexer_keys_fp8: write FP8 indexer keys
This commit is contained in:
@@ -444,37 +444,40 @@ class Indexer:
|
||||
# KV Cache
|
||||
# =====================================================================
|
||||
class KVCache:
|
||||
"""KV Cache with NVFP4 compressed KV and FP8_E4M3 indexer keys.
|
||||
"""KV Cache with mixed-precision compressed KV (DeepSeek V4 paper format).
|
||||
|
||||
KV-1/KV-2: Compressed KV is stored as NVFP4 (E2M1 + E4M3 + FP32 gsa).
|
||||
KV-3: Indexer keys are stored as FP8_E4M3 (1 byte + per-row scale).
|
||||
SWA: BF16 (only 128 tokens × 512 × 61 layers = 8MB, fits in L2).
|
||||
KV-1/KV-2: Compressed KV uses mixed storage:
|
||||
- Non-RoPE dims (448 of 512): FP8_E4M3 → ~50% size reduction
|
||||
- RoPE dims (64 of 512): BF16 (RoPE applied directly, stored as BF16)
|
||||
KV-3: Indexer keys stored as FP8_E4M3 (ihd=128, no RoPE).
|
||||
SWA: BF16 (128 tokens × 512 × 61 layers = 8MB, fits in L2).
|
||||
|
||||
Storage savings vs BF16:
|
||||
NVFP4: 0.5 bytes/val + 0.125 bytes/val (sf) + 4 bytes/row (gsa)
|
||||
= hd/2 + hd/16 + 1 scalars per entry
|
||||
= 256 + 32 + 1 = 289 bytes/entry at hd=512
|
||||
vs 1024 bytes/entry BF16 → 3.5× savings
|
||||
FP8_E4M3: 1 byte/val + 4 bytes/row (scale)
|
||||
= 128 + 4 = 132 bytes/entry at ihd=128
|
||||
vs 256 bytes/entry BF16 → 1.9× savings
|
||||
This matches the DeepSeek V4 paper: "BF16 for RoPE dims, FP8 for remaining dims.
|
||||
This hybrid representation reduces the KV cache size by nearly half."
|
||||
|
||||
Storage per compressed entry at hd=512:
|
||||
nope (448) × FP8 = 448 bytes + 4 bytes (scale) = 452
|
||||
rope (64) × BF16 = 128 bytes
|
||||
Total = 580 bytes vs 1024 bytes BF16 → 1.76× savings
|
||||
"""
|
||||
def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0',
|
||||
indexer_key_dim=128, compress_ratio=4, indexer_top_k=1024):
|
||||
indexer_key_dim=128, compress_ratio=4, indexer_top_k=1024, rope_dim=64):
|
||||
self.hd, self.ws, self.dev = head_dim, window_size, device
|
||||
self.idx_key_dim = indexer_key_dim
|
||||
self.ratio = compress_ratio
|
||||
self.max_comp = max_comp
|
||||
self.rope_dim = rope_dim
|
||||
self.nope_dim = head_dim - rope_dim # 448
|
||||
|
||||
# SWA: BF16 (small, fits in L2)
|
||||
self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device)
|
||||
self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device)
|
||||
self.swa_len, self.swa_head = 0, 0
|
||||
|
||||
# Compressed KV: NVFP4 storage
|
||||
self.comp_kv_fp4 = torch.zeros(max_comp, head_dim // 2, dtype=torch.uint8, device=device)
|
||||
self.comp_kv_sf = torch.zeros(max_comp, head_dim // 16, dtype=torch.uint8, device=device)
|
||||
self.comp_kv_gsa = torch.zeros(max_comp, dtype=torch.float32, device=device)
|
||||
# Compressed KV: mixed FP8 (nope) + BF16 (rope)
|
||||
self.comp_nope_fp8 = torch.zeros(max_comp, self.nope_dim, dtype=torch.uint8, device=device)
|
||||
self.comp_nope_scale = torch.zeros(max_comp, dtype=torch.float32, device=device)
|
||||
self.comp_rope_bf16 = torch.zeros(max_comp, rope_dim, dtype=torch.bfloat16, device=device)
|
||||
self.comp_pos_buf = torch.zeros(max_comp, dtype=torch.long, device=device)
|
||||
|
||||
# Indexer compressed keys: FP8_E4M3
|
||||
@@ -487,21 +490,12 @@ class KVCache:
|
||||
self._has_idx = False
|
||||
|
||||
# Cache dequant modules (loaded once)
|
||||
self._dequant_mod = None
|
||||
self._kv_quant_mod = None
|
||||
|
||||
def _get_dequant_mod(self):
|
||||
if self._dequant_mod is None:
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
self._dequant_mod = get_cuda_module(
|
||||
"dequant_nvfp4", ["dequant_nvfp4.cu"])
|
||||
return self._dequant_mod
|
||||
|
||||
def _get_kv_quant_mod(self):
|
||||
if self._kv_quant_mod is None:
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
self._kv_quant_mod = get_cuda_module(
|
||||
"kv_quantize", ["kv_quantize.cu"])
|
||||
self._kv_quant_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
return self._kv_quant_mod
|
||||
|
||||
def append_swa(self, kv, pos):
|
||||
@@ -513,78 +507,52 @@ class KVCache:
|
||||
self.swa_head = (self.swa_head + T) % self.ws
|
||||
self.swa_len = min(self.swa_len + T, self.ws)
|
||||
|
||||
def add_compressed(self, ckv, cpos, idx_kv=None):
|
||||
"""Add compressed KV entries to NVFP4 cache.
|
||||
|
||||
ckv can be:
|
||||
- BF16 tensor (n_comp, hd) — will be quantized to NVFP4
|
||||
- NVFP4 triple (fp4, sf, gsa) — stored directly
|
||||
idx_kv can be:
|
||||
- BF16 tensor (n_comp, ihd) — will be quantized to FP8_E4M3
|
||||
- FP8 triple (fp8, scale) — stored directly
|
||||
"""
|
||||
if ckv is None: return
|
||||
def set_compressed_mixed(self, nope_fp8, nope_scale, rope_bf16):
|
||||
"""Add compressed KV entries (mixed FP8 nope + BF16 rope)."""
|
||||
T = nope_fp8.shape[0]
|
||||
end = self.n_comp
|
||||
|
||||
# Handle compressed KV
|
||||
if isinstance(ckv, tuple) and len(ckv) == 3:
|
||||
# NVFP4 triple: (fp4, sf, gsa)
|
||||
fp4, sf, gsa = ckv
|
||||
T = fp4.shape[0]
|
||||
self.comp_kv_fp4[end:end+T] = fp4.view(torch.uint8) if fp4.dtype != torch.uint8 else fp4
|
||||
self.comp_kv_sf[end:end+T] = sf.view(torch.uint8) if sf.dtype != torch.uint8 else sf
|
||||
self.comp_kv_gsa[end:end+T] = gsa
|
||||
elif isinstance(ckv, torch.Tensor):
|
||||
# BF16 tensor — quantize to NVFP4 using proven two-kernel path
|
||||
T = ckv.shape[0]
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
fp4, sf, gsa = quantize_nvfp4_gpu_fused(ckv)
|
||||
self.comp_kv_fp4[end:end+T] = fp4.view(torch.uint8)
|
||||
self.comp_kv_sf[end:end+T] = sf.view(torch.uint8)
|
||||
self.comp_kv_gsa[end:end+T] = gsa
|
||||
else:
|
||||
raise ValueError(f"Unexpected ckv type: {type(ckv)}")
|
||||
|
||||
self.comp_pos_buf[end:end+ckv.shape[0] if isinstance(ckv, torch.Tensor) else ckv[0].shape[0]] = cpos
|
||||
T = ckv.shape[0] if isinstance(ckv, torch.Tensor) else ckv[0].shape[0]
|
||||
|
||||
# Handle indexer keys
|
||||
if idx_kv is not None:
|
||||
if isinstance(idx_kv, tuple) and len(idx_kv) == 2:
|
||||
# FP8 triple: (fp8, scale)
|
||||
fp8, scale = idx_kv
|
||||
self.comp_idx_fp8[end:end+T] = fp8.view(torch.uint8) if fp8.dtype != torch.uint8 else fp8
|
||||
self.comp_idx_scale[end:end+T] = scale
|
||||
elif isinstance(idx_kv, torch.Tensor):
|
||||
# BF16 tensor — quantize to FP8_E4M3
|
||||
mod = self._get_kv_quant_mod()
|
||||
fp8, scale = mod.quantize_fp8_e4m3_from_fp32(idx_kv.float().contiguous())
|
||||
self.comp_idx_fp8[end:end+T] = fp8.view(torch.uint8)
|
||||
self.comp_idx_scale[end:end+T] = scale
|
||||
self._has_idx = True
|
||||
|
||||
self.comp_nope_fp8[end:end+T] = nope_fp8.view(torch.uint8) if nope_fp8.dtype != torch.uint8 else nope_fp8
|
||||
self.comp_nope_scale[end:end+T] = nope_scale
|
||||
self.comp_rope_bf16[end:end+T] = rope_bf16
|
||||
self.n_comp = end + T
|
||||
|
||||
@property
|
||||
def comp_kv(self):
|
||||
"""Dequantize NVFP4 → BF16 for FMHA. Returns (n_comp, hd) BF16."""
|
||||
if self.n_comp == 0: return None
|
||||
mod = self._get_dequant_mod()
|
||||
return mod.dequant_nvfp4(
|
||||
self.comp_kv_fp4[:self.n_comp],
|
||||
self.comp_kv_sf[:self.n_comp],
|
||||
self.comp_kv_gsa[:self.n_comp],
|
||||
)
|
||||
def set_indexer_keys_fp8(self, idx_kv):
    """Store indexer keys for the most recently appended compressed entries.

    The keys are written into the last ``n_new`` rows below ``self.n_comp``,
    so this must be called after the matching compressed-KV rows have been
    added and ``self.n_comp`` has been advanced.

    Args:
        idx_kv: One of
            - ``None``: nothing to store; the call is a no-op.
            - a ``(fp8, scale)`` pair that is already quantized; it is
              stored directly (``fp8`` is reinterpreted as uint8 when it
              has another dtype).
            - a single tensor; it is converted to FP32 and quantized with
              ``quantize_fp8_e4m3_from_fp32`` from the ``kv_quantize``
              module before being stored.

    Raises:
        ValueError: If ``idx_kv`` is none of the accepted forms, or if it
            holds more rows than there are compressed entries in the cache.
    """
    # The indexer compressor may produce nothing for a step; skip instead of
    # failing on ``None.shape``.
    if idx_kv is None:
        return

    if isinstance(idx_kv, tuple) and len(idx_kv) == 2:
        fp8, scale = idx_kv
    elif isinstance(idx_kv, torch.Tensor):
        mod = self._get_kv_quant_mod()
        fp8, scale = mod.quantize_fp8_e4m3_from_fp32(idx_kv.float().contiguous())
    else:
        # Previously this fell through and still set ``_has_idx``, which
        # marked never-written rows as valid indexer keys.
        raise ValueError(f"Unexpected idx_kv type: {type(idx_kv)}")

    n_new = fp8.shape[0]
    stop = self.n_comp
    start = stop - n_new
    if start < 0:
        # A negative start would be interpreted as "count from the end of
        # the buffer" by slice assignment and write to the wrong rows.
        raise ValueError(
            f"idx_kv has {n_new} rows but only {stop} compressed entries exist")

    if fp8.dtype != torch.uint8:
        fp8 = fp8.view(torch.uint8)
    self.comp_idx_fp8[start:stop] = fp8
    self.comp_idx_scale[start:stop] = scale
    self._has_idx = True
|
||||
|
||||
def comp_kv_selective(self, indices):
|
||||
"""Dequantize selected NVFP4 entries → BF16 for CSA top-k gather."""
|
||||
mod = self._get_dequant_mod()
|
||||
return mod.dequant_nvfp4_selective(
|
||||
self.comp_kv_fp4,
|
||||
self.comp_kv_sf,
|
||||
self.comp_kv_gsa,
|
||||
indices.int(),
|
||||
)
|
||||
def comp_nope_selective(self, indices):
|
||||
"""Dequant FP8 nope for selected entries → BF16."""
|
||||
mod = self._get_kv_quant_mod()
|
||||
return mod.dequant_fp8_e4m3_selective(
|
||||
self.comp_nope_fp8, self.comp_nope_scale, indices.int())
|
||||
|
||||
def comp_rope_selective(self, indices):
|
||||
"""Gather BF16 rope for selected entries."""
|
||||
return self.comp_rope_bf16[indices.long()]
|
||||
|
||||
@property
|
||||
def comp_nope_all(self):
|
||||
"""Dequant all FP8 nope → BF16."""
|
||||
mod = self._get_kv_quant_mod()
|
||||
return mod.dequant_fp8_e4m3(
|
||||
self.comp_nope_fp8[:self.n_comp],
|
||||
self.comp_nope_scale[:self.n_comp])
|
||||
|
||||
@property
|
||||
def comp_rope_all(self):
|
||||
"""Return all BF16 rope entries."""
|
||||
return self.comp_rope_bf16[:self.n_comp]
|
||||
|
||||
@property
|
||||
def comp_pos(self):
|
||||
@@ -592,21 +560,19 @@ class KVCache:
|
||||
|
||||
@property
|
||||
def comp_idx_kv(self):
|
||||
"""Dequantize FP8 indexer keys → BF16 for scoring."""
|
||||
"""Dequant FP8 indexer keys → BF16 for scoring."""
|
||||
if not self._has_idx or self.n_comp == 0: return None
|
||||
mod = self._get_kv_quant_mod()
|
||||
return mod.dequant_fp8_e4m3(
|
||||
self.comp_idx_fp8[:self.n_comp],
|
||||
self.comp_idx_scale[:self.n_comp],
|
||||
)
|
||||
self.comp_idx_scale[:self.n_comp])
|
||||
|
||||
def get_swa(self):
|
||||
"""Return SWA KV and positions as views (no clone). Caller copies into gather_buf."""
|
||||
"""Return SWA KV and positions as views (no clone)."""
|
||||
if self.swa_len == 0:
|
||||
return self.swa[:0], self.swa_pos[:0]
|
||||
if self.swa_len < self.ws:
|
||||
return self.swa[:self.swa_len], self.swa_pos[:self.swa_len]
|
||||
# Ring buffer wrap — gather non-contiguous rows
|
||||
idx = torch.arange(self.swa_head, self.swa_head + self.ws, device=self.dev) % self.ws
|
||||
return self.swa[idx], self.swa_pos[idx]
|
||||
|
||||
@@ -691,26 +657,31 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
_pt('rope_kv_end')
|
||||
kv_roped = kv_3d.reshape(T, hd); kv_cache.append_swa(kv_roped, positions)
|
||||
|
||||
# 3. Compressor → compressed KV → FP32 RoPE → NVFP4
|
||||
# 3. Compressor → compressed KV (mixed storage: FP8 + BF16 RoPE)
|
||||
# DeepSeek V4 paper: "BF16 for RoPE dims, FP8 for remaining dims"
|
||||
_pt('compress_start')
|
||||
comp_nvfp4, comp_pos, block_bias = None, None, None; comp_idx_kv = None
|
||||
comp_pos, block_bias = None, None; comp_idx_kv = None
|
||||
if compressor is not None and compressor.ratio > 0:
|
||||
comp_kv_fp32, comp_pos, block_bias = compressor.forward(x_normed, positions)
|
||||
if comp_kv_fp32 is not None:
|
||||
# Apply RoPE on FP32 (no BF16 intermediate)
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
comp_kv_fp32_contig = comp_kv_fp32.contiguous()
|
||||
c = rope_cos[comp_pos].contiguous()
|
||||
s = rope_sin[comp_pos].contiguous()
|
||||
kv_mod.rope_fp32(comp_kv_fp32_contig, comp_pos.contiguous(), c, s, rd, False)
|
||||
# Quantize FP32 → NVFP4 (two-kernel, proven pattern)
|
||||
gsa = kv_mod.compute_amax_gsa_fp32(comp_kv_fp32_contig, 6.0 * 448.0)
|
||||
fp4, sf = kv_mod.quantize_nvfp4_from_fp32(comp_kv_fp32_contig, gsa)
|
||||
comp_nvfp4 = (fp4, sf, gsa)
|
||||
nope_dim = hd - rd # 448
|
||||
# Split into non-RoPE (FP8) and RoPE (BF16) parts
|
||||
nope_fp32 = comp_kv_fp32[:, :nope_dim].contiguous() # (n_comp, 448) FP32
|
||||
rope_bf16 = comp_kv_fp32[:, nope_dim:].bfloat16().contiguous() # (n_comp, 64) BF16
|
||||
# Apply RoPE on BF16 rope dims (existing BF16 RoPE kernel)
|
||||
rope_3d = rope_bf16.unsqueeze(1) # (n_comp, 1, 64)
|
||||
rope_3d = _apply_rope(rope_3d, comp_pos, rope_cos, rope_sin, rd)
|
||||
rope_bf16 = rope_3d.squeeze(1) # (n_comp, 64) BF16
|
||||
# Quantize non-RoPE part FP32 → FP8_E4M3
|
||||
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
|
||||
# Store mixed-format compressed KV
|
||||
kv_cache.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16)
|
||||
if compressor.is_csa and indexer is not None and indexer.compressor is not None:
|
||||
comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, positions)
|
||||
kv_cache.add_compressed(comp_nvfp4, comp_pos, comp_idx_kv)
|
||||
# Indexer keys: FP8_E4M3 (ihd=128, no RoPE)
|
||||
kv_cache.set_indexer_keys_fp8(comp_idx_kv)
|
||||
_pt('compress_end')
|
||||
|
||||
# 4. Indexer top-k (CSA)
|
||||
@@ -718,24 +689,31 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
if indexer is not None and ratio == 4:
|
||||
topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions, layer_idx=li)
|
||||
|
||||
# 5. Gather KV — NVFP4 dequant for compressed KV
|
||||
# 5. Gather KV — mixed storage: FP8 nope dequant + BF16 rope concat
|
||||
_pt('gather_start')
|
||||
swa_kv, _swa_pos = kv_cache.get_swa()
|
||||
swa_len = swa_kv.shape[0]
|
||||
gbuf = kv_cache.gather_buf # (indexer_top_k + window_size, hd) pre-allocated BF16
|
||||
gbuf = kv_cache.gather_buf # (max_len, hd) pre-allocated BF16
|
||||
if kv_cache.n_comp > 0:
|
||||
if ratio == 4:
|
||||
# CSA: dequant only top-k entries (bandwidth savings)
|
||||
# CSA: dequant only top-k entries
|
||||
assert topk_idx is not None, f"CSA layer {li}: indexer returned no top-k — indexer is broken"
|
||||
tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1).int()
|
||||
n_tk = tk.shape[0]
|
||||
gbuf[:n_tk] = kv_cache.comp_kv_selective(tk) # NVFP4 → BF16
|
||||
# Dequant FP8 nope + gather BF16 rope for top-k
|
||||
nope_bf16 = kv_cache.comp_nope_selective(tk) # FP8→BF16 (n_tk, 448)
|
||||
rope_bf16 = kv_cache.comp_rope_selective(tk) # BF16 gather (n_tk, 64)
|
||||
gbuf[:n_tk, :nope_dim] = nope_bf16
|
||||
gbuf[:n_tk, nope_dim:] = rope_bf16
|
||||
gbuf[n_tk:n_tk + swa_len] = swa_kv
|
||||
all_kv = gbuf[:n_tk + swa_len]
|
||||
elif ratio > 4:
|
||||
# HCA: dequant all entries (dense gather)
|
||||
# HCA: dequant all entries
|
||||
n_comp = kv_cache.n_comp
|
||||
gbuf[:n_comp] = kv_cache.comp_kv # NVFP4 → BF16
|
||||
nope_bf16 = kv_cache.comp_nope_all # FP8→BF16 (n_comp, 448)
|
||||
rope_bf16 = kv_cache.comp_rope_all # BF16 (n_comp, 64)
|
||||
gbuf[:n_comp, :nope_dim] = nope_bf16
|
||||
gbuf[:n_comp, nope_dim:] = rope_bf16
|
||||
gbuf[n_comp:n_comp + swa_len] = swa_kv
|
||||
all_kv = gbuf[:n_comp + swa_len]
|
||||
else:
|
||||
@@ -1223,7 +1201,7 @@ def main():
|
||||
# C1: max_comp derived from target context and compress ratio
|
||||
max_comp = (max_ctx + ratio - 1) // ratio if ratio > 0 else 0
|
||||
kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), max_comp=max_comp, device=dev,
|
||||
indexer_key_dim=ihd, compress_ratio=ratio, indexer_top_k=itk)
|
||||
indexer_key_dim=ihd, compress_ratio=ratio, indexer_top_k=itk, rope_dim=rd)
|
||||
if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev)
|
||||
if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)
|
||||
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
#!/usr/bin/env python3
|
||||
"""KV-1/KV-2/KV-3: NVFP4 compressed KV + FP8 indexer keys — production-value unit tests.
|
||||
"""KV-1/KV-2/KV-3: Mixed FP8+BF16 compressed KV + FP8 indexer keys — production-value unit tests.
|
||||
|
||||
Tests the kv_quantize.cu kernels at production shapes:
|
||||
- NVFP4: hd=512 (not 64/128)
|
||||
- FP8_E4M3: ihd=128
|
||||
- FP32 RoPE: rope_dim=64
|
||||
- Multiple batch sizes (1, 4, 8, 32)
|
||||
- FP8_E4M3: nope_dim=448, ihd=128
|
||||
- BF16 RoPE: rope_dim=64
|
||||
- Mixed storage: FP8 nope + BF16 rope → concat → compare with FP32 reference
|
||||
|
||||
Falsifiable gates:
|
||||
- NVFP4 quantize FP32→NVFP4→BF16: cos ≥ 0.995 vs FP32 reference
|
||||
- FP8_E4M3 quantize FP32→FP8→BF16: cos ≥ 0.999 vs FP32 reference
|
||||
- FP32 RoPE: cos = 1.000000 vs PyTorch FP32 reference
|
||||
- FP8_E4M3 quantize FP32→FP8→BF16: cos ≥ 0.998 vs FP32 reference
|
||||
- Mixed storage round-trip: FP32 → (FP8 nope + BF16 rope) → BF16 concat: cos ≥ 0.998
|
||||
- Selective dequant matches full dequant
|
||||
"""
|
||||
|
||||
@@ -24,118 +22,92 @@ from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
|
||||
print("=" * 60)
|
||||
print("KV-1/KV-2/KV-3: Production-Value Unit Tests")
|
||||
print("KV-1/KV-2/KV-3: Mixed FP8+BF16 Storage — Unit Tests")
|
||||
print("=" * 60)
|
||||
|
||||
# ===========================================================================
|
||||
# Test 1: NVFP4 quantize FP32 → NVFP4 → BF16
|
||||
# ===========================================================================
|
||||
print("\n--- Test 1: NVFP4 FP32→NVFP4 round-trip (production hd=512) ---")
|
||||
for M in [1, 4, 8, 32]:
|
||||
data = torch.randn(M, 512, device=device, dtype=torch.float32) * 5.0
|
||||
gsa = mod.compute_amax_gsa_fp32(data.contiguous(), 6.0 * 448.0)
|
||||
fp4, sf = mod.quantize_nvfp4_from_fp32(data.contiguous(), gsa)
|
||||
|
||||
# Dequant using the proven dequant_nvfp4 kernel
|
||||
deq_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
|
||||
deq = deq_mod.dequant_nvfp4(fp4.view(torch.uint8), sf.view(torch.uint8), gsa)
|
||||
|
||||
cos = torch.nn.functional.cosine_similarity(data.float().flatten(), deq.float().flatten(), dim=0).item()
|
||||
max_err = (data.float() - deq.float()).abs().max().item()
|
||||
print(f" M={M:3d}: cos={cos:.6f} max_err={max_err:.4f} |data|_max={data.abs().max().item():.2f}")
|
||||
assert cos >= 0.990, f"NVFP4 round-trip cos={cos:.6f} < 0.990 at M={M}"
|
||||
hd = 512; rope_dim = 64; nope_dim = hd - rope_dim # 448
|
||||
|
||||
# ===========================================================================
|
||||
# Test 2: FP8_E4M3 quantize FP32 → FP8 → BF16
|
||||
# Test 1: FP8_E4M3 nope round-trip (production nope_dim=448)
|
||||
# ===========================================================================
|
||||
print("\n--- Test 2: FP8_E4M3 FP32→FP8 round-trip (production ihd=128) ---")
|
||||
print("\n--- Test 1: FP8_E4M3 nope FP32→FP8→BF16 (nope_dim=448) ---")
|
||||
for M in [1, 4, 8, 32, 128]:
|
||||
data = torch.randn(M, nope_dim, device=device, dtype=torch.float32) * 3.0
|
||||
fp8, scale = mod.quantize_fp8_e4m3_from_fp32(data.contiguous())
|
||||
deq = mod.dequant_fp8_e4m3(fp8.view(torch.uint8), scale)
|
||||
cos = torch.nn.functional.cosine_similarity(data.float().flatten(), deq.float().flatten(), dim=0).item()
|
||||
max_err = (data.float() - deq.float()).abs().max().item()
|
||||
print(f" M={M:3d}: cos={cos:.6f} max_err={max_err:.4f}")
|
||||
assert cos >= 0.998, f"FP8 nope round-trip cos={cos:.6f} < 0.998 at M={M}"
|
||||
|
||||
# ===========================================================================
|
||||
# Test 2: FP8_E4M3 indexer keys (production ihd=128)
|
||||
# ===========================================================================
|
||||
print("\n--- Test 2: FP8_E4M3 indexer keys (ihd=128) ---")
|
||||
for M in [1, 4, 32, 128]:
|
||||
data = torch.randn(M, 128, device=device, dtype=torch.float32) * 3.0
|
||||
fp8, scale = mod.quantize_fp8_e4m3_from_fp32(data.contiguous())
|
||||
deq = mod.dequant_fp8_e4m3(fp8.view(torch.uint8), scale)
|
||||
|
||||
cos = torch.nn.functional.cosine_similarity(data.float().flatten(), deq.float().flatten(), dim=0).item()
|
||||
max_err = (data.float() - deq.float()).abs().max().item()
|
||||
print(f" M={M:3d}: cos={cos:.6f} max_err={max_err:.4f} |data|_max={data.abs().max().item():.2f}")
|
||||
assert cos >= 0.998, f"FP8 round-trip cos={cos:.6f} < 0.998 at M={M}"
|
||||
print(f" M={M:3d}: cos={cos:.6f}")
|
||||
assert cos >= 0.998, f"FP8 indexer cos={cos:.6f} < 0.998 at M={M}"
|
||||
|
||||
# ===========================================================================
|
||||
# Test 3: FP32 RoPE
|
||||
# Test 3: Mixed storage round-trip (FP8 nope + BF16 rope → concat)
|
||||
# ===========================================================================
|
||||
print("\n--- Test 3: FP32 RoPE (production rope_dim=64, hd=512) ---")
|
||||
hd = 512; rope_dim = 64
|
||||
# Build proper RoPE cache (same as single_shot build_rope_cache)
|
||||
print("\n--- Test 3: Mixed FP8+BF16 full round-trip (hd=512) ---")
|
||||
# Build proper RoPE cache
|
||||
theta = 10000.0
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
|
||||
angles = torch.outer(torch.arange(1024, dtype=torch.float32), freqs)
|
||||
cos_cache = torch.cos(angles).to(device) # (1024, rope_dim/2) FP32
|
||||
sin_cache = torch.sin(angles).to(device) # (1024, rope_dim/2) FP32
|
||||
# cos²+sin²=1 by construction
|
||||
cos_cache = torch.cos(angles).to(device)
|
||||
sin_cache = torch.sin(angles).to(device)
|
||||
|
||||
for M in [1, 4, 8]:
|
||||
data = torch.randn(M, hd, device=device, dtype=torch.float32) * 2.0
|
||||
from dsv4.ops.rope_cuda import apply_rope
|
||||
|
||||
for M in [1, 4, 8, 32]:
|
||||
# Simulate compressor FP32 output
|
||||
data_fp32 = torch.randn(M, hd, device=device, dtype=torch.float32) * 3.0
|
||||
positions = torch.arange(M, device=device, dtype=torch.long)
|
||||
|
||||
# FP32 RoPE via kv_quantize
|
||||
data_kv = data.clone()
|
||||
mod.rope_fp32(data_kv, positions, cos_cache, sin_cache, rope_dim, False)
|
||||
|
||||
# PyTorch FP32 reference
|
||||
data_ref = data.clone()
|
||||
nope = hd - rope_dim
|
||||
for m in range(M):
|
||||
p = positions[m].item()
|
||||
c = cos_cache[p] # (rope_dim/2,)
|
||||
s = sin_cache[p]
|
||||
for i in range(rope_dim // 2):
|
||||
ev = data_ref[m, nope + 2 * i]
|
||||
od = data_ref[m, nope + 2 * i + 1]
|
||||
data_ref[m, nope + 2 * i] = ev * c[i] - od * s[i]
|
||||
data_ref[m, nope + 2 * i + 1] = ev * s[i] + od * c[i]
|
||||
|
||||
cos_sim = torch.nn.functional.cosine_similarity(data_kv.flatten(), data_ref.flatten(), dim=0).item()
|
||||
max_err = (data_kv - data_ref).abs().max().item()
|
||||
print(f" M={M}: cos={cos_sim:.6f} max_err={max_err:.8f}")
|
||||
assert cos_sim >= 0.99999, f"FP32 RoPE cos={cos_sim:.6f} < 0.99999 at M={M}"
|
||||
|
||||
# Reference: FP32 → BF16 → RoPE → full BF16
|
||||
ref_bf16 = data_fp32.bfloat16()
|
||||
ref_3d = ref_bf16.unsqueeze(1) # (M, 1, hd)
|
||||
ref_3d = apply_rope(ref_3d, positions, cos_cache, sin_cache, rope_dim)
|
||||
ref_full = ref_3d.squeeze(1) # (M, hd) BF16
|
||||
|
||||
# Our path: FP32 → split → FP8 nope + BF16 rope (with RoPE) → concat
|
||||
nope_fp32 = data_fp32[:, :nope_dim].contiguous()
|
||||
rope_bf16 = data_fp32[:, nope_dim:].bfloat16().contiguous()
|
||||
rope_3d = rope_bf16.unsqueeze(1)
|
||||
rope_3d = apply_rope(rope_3d, positions, cos_cache, sin_cache, rope_dim)
|
||||
rope_bf16 = rope_3d.squeeze(1)
|
||||
|
||||
nope_fp8, nope_scale = mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
|
||||
nope_bf16 = mod.dequant_fp8_e4m3(nope_fp8.view(torch.uint8), nope_scale)
|
||||
|
||||
# Concat nope + rope
|
||||
result = torch.cat([nope_bf16, rope_bf16], dim=1) # (M, hd) BF16
|
||||
|
||||
cos = torch.nn.functional.cosine_similarity(ref_full.float().flatten(), result.float().flatten(), dim=0).item()
|
||||
max_err = (ref_full.float() - result.float()).abs().max().item()
|
||||
print(f" M={M:3d}: cos={cos:.6f} max_err={max_err:.4f}")
|
||||
assert cos >= 0.998, f"Mixed storage cos={cos:.6f} < 0.998 at M={M}"
|
||||
|
||||
# ===========================================================================
|
||||
# Test 4: Selective dequant matches full dequant (NVFP4)
|
||||
# Test 4: Selective dequant (CSA top-k gather)
|
||||
# ===========================================================================
|
||||
print("\n--- Test 4: Selective dequant NVFP4 (CSA top-k gather) ---")
|
||||
M = 32; hd = 512
|
||||
data = torch.randn(M, hd, device=device, dtype=torch.float32) * 5.0
|
||||
gsa = mod.compute_amax_gsa_fp32(data.contiguous(), 6.0 * 448.0)
|
||||
fp4, sf = mod.quantize_nvfp4_from_fp32(data.contiguous(), gsa)
|
||||
|
||||
deq_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"])
|
||||
# Full dequant
|
||||
full_deq = deq_mod.dequant_nvfp4(fp4.view(torch.uint8), sf.view(torch.uint8), gsa)
|
||||
# Selective dequant — pick 5 entries
|
||||
indices = torch.tensor([0, 5, 10, 20, 31], device=device, dtype=torch.int32)
|
||||
sel_deq = deq_mod.dequant_nvfp4_selective(fp4.view(torch.uint8), sf.view(torch.uint8), gsa, indices)
|
||||
|
||||
# Compare
|
||||
for i, idx in enumerate(indices):
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
full_deq[idx].float().flatten(), sel_deq[i].float().flatten(), dim=0).item()
|
||||
assert cos >= 0.99999, f"Selective dequant mismatch at idx={idx}: cos={cos:.6f}"
|
||||
print(f" All 5 selective dequant entries match full dequant: PASS")
|
||||
|
||||
# ===========================================================================
|
||||
# Test 5: FP8 selective dequant
|
||||
# ===========================================================================
|
||||
print("\n--- Test 5: Selective dequant FP8 (indexer key gather) ---")
|
||||
M = 64; ihd = 128
|
||||
data = torch.randn(M, ihd, device=device, dtype=torch.float32) * 3.0
|
||||
print("\n--- Test 4: Selective FP8 dequant (CSA top-k gather) ---")
|
||||
M = 32; data = torch.randn(M, nope_dim, device=device, dtype=torch.float32) * 3.0
|
||||
fp8, scale = mod.quantize_fp8_e4m3_from_fp32(data.contiguous())
|
||||
full_deq = mod.dequant_fp8_e4m3(fp8.view(torch.uint8), scale)
|
||||
indices = torch.tensor([0, 15, 30, 45, 63], device=device, dtype=torch.int32)
|
||||
indices = torch.tensor([0, 5, 10, 20, 31], device=device, dtype=torch.int32)
|
||||
sel_deq = mod.dequant_fp8_e4m3_selective(fp8.view(torch.uint8), scale, indices)
|
||||
for i, idx in enumerate(indices):
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
full_deq[idx].float().flatten(), sel_deq[i].float().flatten(), dim=0).item()
|
||||
assert cos >= 0.99999, f"FP8 selective mismatch at idx={idx}: cos={cos:.6f}"
|
||||
print(f" All 5 selective dequant entries match: PASS")
|
||||
assert cos >= 0.99999, f"Selective mismatch at idx={idx}: cos={cos:.6f}"
|
||||
print(f" All 5 selective dequant entries match full dequant: PASS")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("ALL TESTS PASSED")
|
||||
|
||||
Reference in New Issue
Block a user