diff --git a/single_shot_inference.py b/single_shot_inference.py index 2e31296f..c8474ecb 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -444,37 +444,40 @@ class Indexer: # KV Cache # ===================================================================== class KVCache: - """KV Cache with NVFP4 compressed KV and FP8_E4M3 indexer keys. + """KV Cache with mixed-precision compressed KV (DeepSeek V4 paper format). - KV-1/KV-2: Compressed KV is stored as NVFP4 (E2M1 + E4M3 + FP32 gsa). - KV-3: Indexer keys are stored as FP8_E4M3 (1 byte + per-row scale). - SWA: BF16 (only 128 tokens × 512 × 61 layers = 8MB, fits in L2). + KV-1/KV-2: Compressed KV uses mixed storage: + - Non-RoPE dims (448 of 512): FP8_E4M3 → ~50% size reduction + - RoPE dims (64 of 512): BF16 (RoPE applied directly, stored as BF16) + KV-3: Indexer keys stored as FP8_E4M3 (ihd=128, no RoPE). + SWA: BF16 (128 tokens × 512 × 61 layers = 8MB, fits in L2). - Storage savings vs BF16: - NVFP4: 0.5 bytes/val + 0.125 bytes/val (sf) + 4 bytes/row (gsa) - = hd/2 + hd/16 + 1 scalars per entry - = 256 + 32 + 1 = 289 bytes/entry at hd=512 - vs 1024 bytes/entry BF16 → 3.5× savings - FP8_E4M3: 1 byte/val + 4 bytes/row (scale) - = 128 + 4 = 132 bytes/entry at ihd=128 - vs 256 bytes/entry BF16 → 1.9× savings + This matches the DeepSeek V4 paper: "BF16 for RoPE dims, FP8 for remaining dims. + This hybrid representation reduces the KV cache size by nearly half." + + Storage per compressed entry at hd=512: + nope (448) × FP8 = 448 bytes + 4 bytes (scale) = 452 + rope (64) × BF16 = 128 bytes + Total = 580 bytes vs 1024 bytes BF16 → 1.76× savings """ def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0', - indexer_key_dim=128, compress_ratio=4, indexer_top_k=1024): + indexer_key_dim=128, compress_ratio=4, indexer_top_k=1024, rope_dim=64): self.hd, self.ws, self.dev = head_dim, window_size, device self.idx_key_dim = indexer_key_dim self.ratio = compress_ratio self.max_comp = max_comp + self.rope_dim = rope_dim + self.nope_dim = head_dim - rope_dim # 448 # SWA: BF16 (small, fits in L2) self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device) self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device) self.swa_len, self.swa_head = 0, 0 - # Compressed KV: NVFP4 storage - self.comp_kv_fp4 = torch.zeros(max_comp, head_dim // 2, dtype=torch.uint8, device=device) - self.comp_kv_sf = torch.zeros(max_comp, head_dim // 16, dtype=torch.uint8, device=device) - self.comp_kv_gsa = torch.zeros(max_comp, dtype=torch.float32, device=device) + # Compressed KV: mixed FP8 (nope) + BF16 (rope) + self.comp_nope_fp8 = torch.zeros(max_comp, self.nope_dim, dtype=torch.uint8, device=device) + self.comp_nope_scale = torch.zeros(max_comp, dtype=torch.float32, device=device) + self.comp_rope_bf16 = torch.zeros(max_comp, rope_dim, dtype=torch.bfloat16, device=device) self.comp_pos_buf = torch.zeros(max_comp, dtype=torch.long, device=device) # Indexer compressed keys: FP8_E4M3 @@ -487,21 +490,12 @@ class KVCache: self._has_idx = False # Cache dequant modules (loaded once) - self._dequant_mod = None self._kv_quant_mod = None - def _get_dequant_mod(self): - if self._dequant_mod is None: - from dsv4.kernels.cuda.loader import get_cuda_module - self._dequant_mod = get_cuda_module( - "dequant_nvfp4", ["dequant_nvfp4.cu"]) - return self._dequant_mod - def _get_kv_quant_mod(self): if self._kv_quant_mod is None: from dsv4.kernels.cuda.loader import get_cuda_module - self._kv_quant_mod = get_cuda_module( - "kv_quantize", ["kv_quantize.cu"]) + self._kv_quant_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"]) return self._kv_quant_mod def append_swa(self, kv, pos): @@ -513,78 +507,52 @@ class KVCache: self.swa_head = (self.swa_head + T) % self.ws self.swa_len = min(self.swa_len + T, self.ws) - def add_compressed(self, ckv, cpos, idx_kv=None): - """Add compressed KV entries to NVFP4 cache. - - ckv can be: - - BF16 tensor (n_comp, hd) — will be quantized to NVFP4 - - NVFP4 triple (fp4, sf, gsa) — stored directly - idx_kv can be: - - BF16 tensor (n_comp, ihd) — will be quantized to FP8_E4M3 - - FP8 triple (fp8, scale) — stored directly - """ - if ckv is None: return + def set_compressed_mixed(self, nope_fp8, nope_scale, rope_bf16): + """Add compressed KV entries (mixed FP8 nope + BF16 rope).""" + T = nope_fp8.shape[0] end = self.n_comp - - # Handle compressed KV - if isinstance(ckv, tuple) and len(ckv) == 3: - # NVFP4 triple: (fp4, sf, gsa) - fp4, sf, gsa = ckv - T = fp4.shape[0] - self.comp_kv_fp4[end:end+T] = fp4.view(torch.uint8) if fp4.dtype != torch.uint8 else fp4 - self.comp_kv_sf[end:end+T] = sf.view(torch.uint8) if sf.dtype != torch.uint8 else sf - self.comp_kv_gsa[end:end+T] = gsa - elif isinstance(ckv, torch.Tensor): - # BF16 tensor — quantize to NVFP4 using proven two-kernel path - T = ckv.shape[0] - from dsv4.ops.quantize import quantize_nvfp4_gpu_fused - fp4, sf, gsa = quantize_nvfp4_gpu_fused(ckv) - self.comp_kv_fp4[end:end+T] = fp4.view(torch.uint8) - self.comp_kv_sf[end:end+T] = sf.view(torch.uint8) - self.comp_kv_gsa[end:end+T] = gsa - else: - raise ValueError(f"Unexpected ckv type: {type(ckv)}") - - self.comp_pos_buf[end:end+ckv.shape[0] if isinstance(ckv, torch.Tensor) else ckv[0].shape[0]] = cpos - T = ckv.shape[0] if isinstance(ckv, torch.Tensor) else ckv[0].shape[0] - - # Handle indexer keys - if idx_kv is not None: - if isinstance(idx_kv, tuple) and len(idx_kv) == 2: - # FP8 triple: (fp8, scale) - fp8, scale = idx_kv - self.comp_idx_fp8[end:end+T] = fp8.view(torch.uint8) if fp8.dtype != torch.uint8 else fp8 - self.comp_idx_scale[end:end+T] = scale - elif isinstance(idx_kv, torch.Tensor): - # BF16 tensor — quantize to FP8_E4M3 - mod = self._get_kv_quant_mod() - fp8, scale = mod.quantize_fp8_e4m3_from_fp32(idx_kv.float().contiguous()) - self.comp_idx_fp8[end:end+T] = fp8.view(torch.uint8) - self.comp_idx_scale[end:end+T] = scale - self._has_idx = True - + self.comp_nope_fp8[end:end+T] = nope_fp8.view(torch.uint8) if nope_fp8.dtype != torch.uint8 else nope_fp8 + self.comp_nope_scale[end:end+T] = nope_scale + self.comp_rope_bf16[end:end+T] = rope_bf16 self.n_comp = end + T - @property - def comp_kv(self): - """Dequantize NVFP4 → BF16 for FMHA. Returns (n_comp, hd) BF16.""" - if self.n_comp == 0: return None - mod = self._get_dequant_mod() - return mod.dequant_nvfp4( - self.comp_kv_fp4[:self.n_comp], - self.comp_kv_sf[:self.n_comp], - self.comp_kv_gsa[:self.n_comp], - ) + def set_indexer_keys_fp8(self, idx_kv): + """Add indexer compressed keys. idx_kv is BF16 (n_comp, ihd) or FP8 (fp8, scale).""" + T = self.n_comp # should match compressed KV count + end = T - (idx_kv[0].shape[0] if isinstance(idx_kv, tuple) else idx_kv.shape[0]) + if isinstance(idx_kv, tuple) and len(idx_kv) == 2: + fp8, scale = idx_kv + self.comp_idx_fp8[end:end+T-end] = fp8.view(torch.uint8) if fp8.dtype != torch.uint8 else fp8 + self.comp_idx_scale[end:end+T-end] = scale + elif isinstance(idx_kv, torch.Tensor): + mod = self._get_kv_quant_mod() + fp8, scale = mod.quantize_fp8_e4m3_from_fp32(idx_kv.float().contiguous()) + self.comp_idx_fp8[end:self.n_comp] = fp8.view(torch.uint8) + self.comp_idx_scale[end:self.n_comp] = scale + self._has_idx = True - def comp_kv_selective(self, indices): - """Dequantize selected NVFP4 entries → BF16 for CSA top-k gather.""" - mod = self._get_dequant_mod() - return mod.dequant_nvfp4_selective( - self.comp_kv_fp4, - self.comp_kv_sf, - self.comp_kv_gsa, - indices.int(), - ) + def comp_nope_selective(self, indices): + """Dequant FP8 nope for selected entries → BF16.""" + mod = self._get_kv_quant_mod() + return mod.dequant_fp8_e4m3_selective( + self.comp_nope_fp8, self.comp_nope_scale, indices.int()) + + def comp_rope_selective(self, indices): + """Gather BF16 rope for selected entries.""" + return self.comp_rope_bf16[indices.long()] + + @property + def comp_nope_all(self): + """Dequant all FP8 nope → BF16.""" + mod = self._get_kv_quant_mod() + return mod.dequant_fp8_e4m3( + self.comp_nope_fp8[:self.n_comp], + self.comp_nope_scale[:self.n_comp]) + + @property + def comp_rope_all(self): + """Return all BF16 rope entries.""" + return self.comp_rope_bf16[:self.n_comp] @property def comp_pos(self): @@ -592,21 +560,19 @@ class KVCache: @property def comp_idx_kv(self): - """Dequantize FP8 indexer keys → BF16 for scoring.""" + """Dequant FP8 indexer keys → BF16 for scoring.""" if not self._has_idx or self.n_comp == 0: return None mod = self._get_kv_quant_mod() return mod.dequant_fp8_e4m3( self.comp_idx_fp8[:self.n_comp], - self.comp_idx_scale[:self.n_comp], - ) + self.comp_idx_scale[:self.n_comp]) def get_swa(self): - """Return SWA KV and positions as views (no clone). Caller copies into gather_buf.""" + """Return SWA KV and positions as views (no clone).""" if self.swa_len == 0: return self.swa[:0], self.swa_pos[:0] if self.swa_len < self.ws: return self.swa[:self.swa_len], self.swa_pos[:self.swa_len] - # Ring buffer wrap — gather non-contiguous rows idx = torch.arange(self.swa_head, self.swa_head + self.ws, device=self.dev) % self.ws return self.swa[idx], self.swa_pos[idx] @@ -691,26 +657,31 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, _pt('rope_kv_end') kv_roped = kv_3d.reshape(T, hd); kv_cache.append_swa(kv_roped, positions) - # 3. Compressor → compressed KV → FP32 RoPE → NVFP4 + # 3. Compressor → compressed KV (mixed storage: FP8 + BF16 RoPE) + # DeepSeek V4 paper: "BF16 for RoPE dims, FP8 for remaining dims" _pt('compress_start') - comp_nvfp4, comp_pos, block_bias = None, None, None; comp_idx_kv = None + comp_pos, block_bias = None, None; comp_idx_kv = None if compressor is not None and compressor.ratio > 0: comp_kv_fp32, comp_pos, block_bias = compressor.forward(x_normed, positions) if comp_kv_fp32 is not None: - # Apply RoPE on FP32 (no BF16 intermediate) from dsv4.kernels.cuda.loader import get_cuda_module kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"]) - comp_kv_fp32_contig = comp_kv_fp32.contiguous() - c = rope_cos[comp_pos].contiguous() - s = rope_sin[comp_pos].contiguous() - kv_mod.rope_fp32(comp_kv_fp32_contig, comp_pos.contiguous(), c, s, rd, False) - # Quantize FP32 → NVFP4 (two-kernel, proven pattern) - gsa = kv_mod.compute_amax_gsa_fp32(comp_kv_fp32_contig, 6.0 * 448.0) - fp4, sf = kv_mod.quantize_nvfp4_from_fp32(comp_kv_fp32_contig, gsa) - comp_nvfp4 = (fp4, sf, gsa) + nope_dim = hd - rd # 448 + # Split into non-RoPE (FP8) and RoPE (BF16) parts + nope_fp32 = comp_kv_fp32[:, :nope_dim].contiguous() # (n_comp, 448) FP32 + rope_bf16 = comp_kv_fp32[:, nope_dim:].bfloat16().contiguous() # (n_comp, 64) BF16 + # Apply RoPE on BF16 rope dims (existing BF16 RoPE kernel) + rope_3d = rope_bf16.unsqueeze(1) # (n_comp, 1, 64) + rope_3d = _apply_rope(rope_3d, comp_pos, rope_cos, rope_sin, rd) + rope_bf16 = rope_3d.squeeze(1) # (n_comp, 64) BF16 + # Quantize non-RoPE part FP32 → FP8_E4M3 + nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32) + # Store mixed-format compressed KV + kv_cache.set_compressed_mixed(nope_fp8, nope_scale, rope_bf16) if compressor.is_csa and indexer is not None and indexer.compressor is not None: comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, positions) - kv_cache.add_compressed(comp_nvfp4, comp_pos, comp_idx_kv) + # Indexer keys: FP8_E4M3 (ihd=128, no RoPE) + kv_cache.set_indexer_keys_fp8(comp_idx_kv) _pt('compress_end') # 4. Indexer top-k (CSA) @@ -718,24 +689,31 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, if indexer is not None and ratio == 4: topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions, layer_idx=li) - # 5. Gather KV — NVFP4 dequant for compressed KV + # 5. Gather KV — mixed storage: FP8 nope dequant + BF16 rope concat _pt('gather_start') swa_kv, _swa_pos = kv_cache.get_swa() swa_len = swa_kv.shape[0] - gbuf = kv_cache.gather_buf # (indexer_top_k + window_size, hd) pre-allocated BF16 + gbuf = kv_cache.gather_buf # (max_len, hd) pre-allocated BF16 if kv_cache.n_comp > 0: if ratio == 4: - # CSA: dequant only top-k entries (bandwidth savings) + # CSA: dequant only top-k entries assert topk_idx is not None, f"CSA layer {li}: indexer returned no top-k — indexer is broken" tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1).int() n_tk = tk.shape[0] - gbuf[:n_tk] = kv_cache.comp_kv_selective(tk) # NVFP4 → BF16 + # Dequant FP8 nope + gather BF16 rope for top-k + nope_bf16 = kv_cache.comp_nope_selective(tk) # FP8→BF16 (n_tk, 448) + rope_bf16 = kv_cache.comp_rope_selective(tk) # BF16 gather (n_tk, 64) + gbuf[:n_tk, :nope_dim] = nope_bf16 + gbuf[:n_tk, nope_dim:] = rope_bf16 gbuf[n_tk:n_tk + swa_len] = swa_kv all_kv = gbuf[:n_tk + swa_len] elif ratio > 4: - # HCA: dequant all entries (dense gather) + # HCA: dequant all entries n_comp = kv_cache.n_comp - gbuf[:n_comp] = kv_cache.comp_kv # NVFP4 → BF16 + nope_bf16 = kv_cache.comp_nope_all # FP8→BF16 (n_comp, 448) + rope_bf16 = kv_cache.comp_rope_all # BF16 (n_comp, 64) + gbuf[:n_comp, :nope_dim] = nope_bf16 + gbuf[:n_comp, nope_dim:] = rope_bf16 gbuf[n_comp:n_comp + swa_len] = swa_kv all_kv = gbuf[:n_comp + swa_len] else: @@ -1223,7 +1201,7 @@ def main(): # C1: max_comp derived from target context and compress ratio max_comp = (max_ctx + ratio - 1) // ratio if ratio > 0 else 0 kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), max_comp=max_comp, device=dev, - indexer_key_dim=ihd, compress_ratio=ratio, indexer_top_k=itk) + indexer_key_dim=ihd, compress_ratio=ratio, indexer_top_k=itk, rope_dim=rd) if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev) if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev) diff --git a/tests/unit/test_kv_quantize.py b/tests/unit/test_kv_quantize.py index 443b0304..62927991 100644 --- a/tests/unit/test_kv_quantize.py +++ b/tests/unit/test_kv_quantize.py @@ -1,16 +1,14 @@ #!/usr/bin/env python3 -"""KV-1/KV-2/KV-3: NVFP4 compressed KV + FP8 indexer keys — production-value unit tests. +"""KV-1/KV-2/KV-3: Mixed FP8+BF16 compressed KV + FP8 indexer keys — production-value unit tests. Tests the kv_quantize.cu kernels at production shapes: - - NVFP4: hd=512 (not 64/128) - - FP8_E4M3: ihd=128 - - FP32 RoPE: rope_dim=64 - - Multiple batch sizes (1, 4, 8, 32) + - FP8_E4M3: nope_dim=448, ihd=128 + - BF16 RoPE: rope_dim=64 + - Mixed storage: FP8 nope + BF16 rope → concat → compare with FP32 reference Falsifiable gates: - - NVFP4 quantize FP32→NVFP4→BF16: cos ≥ 0.995 vs FP32 reference - - FP8_E4M3 quantize FP32→FP8→BF16: cos ≥ 0.999 vs FP32 reference - - FP32 RoPE: cos = 1.000000 vs PyTorch FP32 reference + - FP8_E4M3 quantize FP32→FP8→BF16: cos ≥ 0.998 vs FP32 reference + - Mixed storage round-trip: FP32 → (FP8 nope + BF16 rope) → BF16 concat: cos ≥ 0.998 - Selective dequant matches full dequant """ @@ -24,118 +22,92 @@ from dsv4.kernels.cuda.loader import get_cuda_module mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"]) print("=" * 60) -print("KV-1/KV-2/KV-3: Production-Value Unit Tests") +print("KV-1/KV-2/KV-3: Mixed FP8+BF16 Storage — Unit Tests") print("=" * 60) -# =========================================================================== -# Test 1: NVFP4 quantize FP32 → NVFP4 → BF16 -# =========================================================================== -print("\n--- Test 1: NVFP4 FP32→NVFP4 round-trip (production hd=512) ---") -for M in [1, 4, 8, 32]: - data = torch.randn(M, 512, device=device, dtype=torch.float32) * 5.0 - gsa = mod.compute_amax_gsa_fp32(data.contiguous(), 6.0 * 448.0) - fp4, sf = mod.quantize_nvfp4_from_fp32(data.contiguous(), gsa) - - # Dequant using the proven dequant_nvfp4 kernel - deq_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"]) - deq = deq_mod.dequant_nvfp4(fp4.view(torch.uint8), sf.view(torch.uint8), gsa) - - cos = torch.nn.functional.cosine_similarity(data.float().flatten(), deq.float().flatten(), dim=0).item() - max_err = (data.float() - deq.float()).abs().max().item() - print(f" M={M:3d}: cos={cos:.6f} max_err={max_err:.4f} |data|_max={data.abs().max().item():.2f}") - assert cos >= 0.990, f"NVFP4 round-trip cos={cos:.6f} < 0.990 at M={M}" +hd = 512; rope_dim = 64; nope_dim = hd - rope_dim # 448 # =========================================================================== -# Test 2: FP8_E4M3 quantize FP32 → FP8 → BF16 +# Test 1: FP8_E4M3 nope round-trip (production nope_dim=448) # =========================================================================== -print("\n--- Test 2: FP8_E4M3 FP32→FP8 round-trip (production ihd=128) ---") +print("\n--- Test 1: FP8_E4M3 nope FP32→FP8→BF16 (nope_dim=448) ---") +for M in [1, 4, 8, 32, 128]: + data = torch.randn(M, nope_dim, device=device, dtype=torch.float32) * 3.0 + fp8, scale = mod.quantize_fp8_e4m3_from_fp32(data.contiguous()) + deq = mod.dequant_fp8_e4m3(fp8.view(torch.uint8), scale) + cos = torch.nn.functional.cosine_similarity(data.float().flatten(), deq.float().flatten(), dim=0).item() + max_err = (data.float() - deq.float()).abs().max().item() + print(f" M={M:3d}: cos={cos:.6f} max_err={max_err:.4f}") + assert cos >= 0.998, f"FP8 nope round-trip cos={cos:.6f} < 0.998 at M={M}" + +# =========================================================================== +# Test 2: FP8_E4M3 indexer keys (production ihd=128) +# =========================================================================== +print("\n--- Test 2: FP8_E4M3 indexer keys (ihd=128) ---") for M in [1, 4, 32, 128]: data = torch.randn(M, 128, device=device, dtype=torch.float32) * 3.0 fp8, scale = mod.quantize_fp8_e4m3_from_fp32(data.contiguous()) deq = mod.dequant_fp8_e4m3(fp8.view(torch.uint8), scale) - cos = torch.nn.functional.cosine_similarity(data.float().flatten(), deq.float().flatten(), dim=0).item() - max_err = (data.float() - deq.float()).abs().max().item() - print(f" M={M:3d}: cos={cos:.6f} max_err={max_err:.4f} |data|_max={data.abs().max().item():.2f}") - assert cos >= 0.998, f"FP8 round-trip cos={cos:.6f} < 0.998 at M={M}" + print(f" M={M:3d}: cos={cos:.6f}") + assert cos >= 0.998, f"FP8 indexer cos={cos:.6f} < 0.998 at M={M}" # =========================================================================== -# Test 3: FP32 RoPE +# Test 3: Mixed storage round-trip (FP8 nope + BF16 rope → concat) # =========================================================================== -print("\n--- Test 3: FP32 RoPE (production rope_dim=64, hd=512) ---") -hd = 512; rope_dim = 64 -# Build proper RoPE cache (same as single_shot build_rope_cache) +print("\n--- Test 3: Mixed FP8+BF16 full round-trip (hd=512) ---") +# Build proper RoPE cache theta = 10000.0 freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim)) angles = torch.outer(torch.arange(1024, dtype=torch.float32), freqs) -cos_cache = torch.cos(angles).to(device) # (1024, rope_dim/2) FP32 -sin_cache = torch.sin(angles).to(device) # (1024, rope_dim/2) FP32 -# cos²+sin²=1 by construction +cos_cache = torch.cos(angles).to(device) +sin_cache = torch.sin(angles).to(device) -for M in [1, 4, 8]: - data = torch.randn(M, hd, device=device, dtype=torch.float32) * 2.0 +from dsv4.ops.rope_cuda import apply_rope + +for M in [1, 4, 8, 32]: + # Simulate compressor FP32 output + data_fp32 = torch.randn(M, hd, device=device, dtype=torch.float32) * 3.0 positions = torch.arange(M, device=device, dtype=torch.long) - - # FP32 RoPE via kv_quantize - data_kv = data.clone() - mod.rope_fp32(data_kv, positions, cos_cache, sin_cache, rope_dim, False) - - # PyTorch FP32 reference - data_ref = data.clone() - nope = hd - rope_dim - for m in range(M): - p = positions[m].item() - c = cos_cache[p] # (rope_dim/2,) - s = sin_cache[p] - for i in range(rope_dim // 2): - ev = data_ref[m, nope + 2 * i] - od = data_ref[m, nope + 2 * i + 1] - data_ref[m, nope + 2 * i] = ev * c[i] - od * s[i] - data_ref[m, nope + 2 * i + 1] = ev * s[i] + od * c[i] - - cos_sim = torch.nn.functional.cosine_similarity(data_kv.flatten(), data_ref.flatten(), dim=0).item() - max_err = (data_kv - data_ref).abs().max().item() - print(f" M={M}: cos={cos_sim:.6f} max_err={max_err:.8f}") - assert cos_sim >= 0.99999, f"FP32 RoPE cos={cos_sim:.6f} < 0.99999 at M={M}" + + # Reference: FP32 → BF16 → RoPE → full BF16 + ref_bf16 = data_fp32.bfloat16() + ref_3d = ref_bf16.unsqueeze(1) # (M, 1, hd) + ref_3d = apply_rope(ref_3d, positions, cos_cache, sin_cache, rope_dim) + ref_full = ref_3d.squeeze(1) # (M, hd) BF16 + + # Our path: FP32 → split → FP8 nope + BF16 rope (with RoPE) → concat + nope_fp32 = data_fp32[:, :nope_dim].contiguous() + rope_bf16 = data_fp32[:, nope_dim:].bfloat16().contiguous() + rope_3d = rope_bf16.unsqueeze(1) + rope_3d = apply_rope(rope_3d, positions, cos_cache, sin_cache, rope_dim) + rope_bf16 = rope_3d.squeeze(1) + + nope_fp8, nope_scale = mod.quantize_fp8_e4m3_from_fp32(nope_fp32) + nope_bf16 = mod.dequant_fp8_e4m3(nope_fp8.view(torch.uint8), nope_scale) + + # Concat nope + rope + result = torch.cat([nope_bf16, rope_bf16], dim=1) # (M, hd) BF16 + + cos = torch.nn.functional.cosine_similarity(ref_full.float().flatten(), result.float().flatten(), dim=0).item() + max_err = (ref_full.float() - result.float()).abs().max().item() + print(f" M={M:3d}: cos={cos:.6f} max_err={max_err:.4f}") + assert cos >= 0.998, f"Mixed storage cos={cos:.6f} < 0.998 at M={M}" # =========================================================================== -# Test 4: Selective dequant matches full dequant (NVFP4) +# Test 4: Selective dequant (CSA top-k gather) # =========================================================================== -print("\n--- Test 4: Selective dequant NVFP4 (CSA top-k gather) ---") -M = 32; hd = 512 -data = torch.randn(M, hd, device=device, dtype=torch.float32) * 5.0 -gsa = mod.compute_amax_gsa_fp32(data.contiguous(), 6.0 * 448.0) -fp4, sf = mod.quantize_nvfp4_from_fp32(data.contiguous(), gsa) - -deq_mod = get_cuda_module("dequant_nvfp4", ["dequant_nvfp4.cu"]) -# Full dequant -full_deq = deq_mod.dequant_nvfp4(fp4.view(torch.uint8), sf.view(torch.uint8), gsa) -# Selective dequant — pick 5 entries -indices = torch.tensor([0, 5, 10, 20, 31], device=device, dtype=torch.int32) -sel_deq = deq_mod.dequant_nvfp4_selective(fp4.view(torch.uint8), sf.view(torch.uint8), gsa, indices) - -# Compare -for i, idx in enumerate(indices): - cos = torch.nn.functional.cosine_similarity( - full_deq[idx].float().flatten(), sel_deq[i].float().flatten(), dim=0).item() - assert cos >= 0.99999, f"Selective dequant mismatch at idx={idx}: cos={cos:.6f}" -print(f" All 5 selective dequant entries match full dequant: PASS") - -# =========================================================================== -# Test 5: FP8 selective dequant -# =========================================================================== -print("\n--- Test 5: Selective dequant FP8 (indexer key gather) ---") -M = 64; ihd = 128 -data = torch.randn(M, ihd, device=device, dtype=torch.float32) * 3.0 +print("\n--- Test 4: Selective FP8 dequant (CSA top-k gather) ---") +M = 32; data = torch.randn(M, nope_dim, device=device, dtype=torch.float32) * 3.0 fp8, scale = mod.quantize_fp8_e4m3_from_fp32(data.contiguous()) full_deq = mod.dequant_fp8_e4m3(fp8.view(torch.uint8), scale) -indices = torch.tensor([0, 15, 30, 45, 63], device=device, dtype=torch.int32) +indices = torch.tensor([0, 5, 10, 20, 31], device=device, dtype=torch.int32) sel_deq = mod.dequant_fp8_e4m3_selective(fp8.view(torch.uint8), scale, indices) for i, idx in enumerate(indices): cos = torch.nn.functional.cosine_similarity( full_deq[idx].float().flatten(), sel_deq[i].float().flatten(), dim=0).item() - assert cos >= 0.99999, f"FP8 selective mismatch at idx={idx}: cos={cos:.6f}" -print(f" All 5 selective dequant entries match: PASS") + assert cos >= 0.99999, f"Selective mismatch at idx={idx}: cos={cos:.6f}" +print(f" All 5 selective dequant entries match full dequant: PASS") print("\n" + "=" * 60) print("ALL TESTS PASSED")