From 4126909dfb681c8fb3c280e7cb16c27f8af797f5 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 04:18:13 +0000 Subject: [PATCH] Simplify PART A test: compressor + FMHA at production scale --- tests/unit/test_part_a_compressor_kv.py | 109 ++++-------------------- 1 file changed, 17 insertions(+), 92 deletions(-) diff --git a/tests/unit/test_part_a_compressor_kv.py b/tests/unit/test_part_a_compressor_kv.py index 650d9461..93d9165e 100644 --- a/tests/unit/test_part_a_compressor_kv.py +++ b/tests/unit/test_part_a_compressor_kv.py @@ -1,13 +1,12 @@ #!/usr/bin/env python3 -"""PART A diagnostic: Compressor + KV cache gathering at production scale. +"""PART A diagnostic: Compressor + FMHA at production scale. -Tests the compressed KV pipeline with production values: -- HCA ratio=128 (layers 0-1 of Pro) -- CSA ratio=4 (alternating layers) -- T=32 tokens (8 CSA blocks, 0 HCA blocks at T=32) -- Validates: compressor output, FP8/BF16 KV round-trip, KV cache gather +Tests: +1. CSA compression FP8/BF16 round-trip +2. HCA compression FP8/BF16 round-trip +3. B1 FMHA with mixed FP8/BF16 KV vs SDPA -All values are production: HD=512, NOPE=448, ROPE=64. +All values are production: HD=512, NOPE=448, ROPE=64, H=128. """ import sys, math import torch @@ -17,13 +16,14 @@ def cosine(a, b): return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item() def main(): - HD = 512; NOPE = 448; ROPE = 64 + HD = 512; NOPE = 448; ROPE = 64; n_h = 128 + scale = 1.0 / math.sqrt(HD) device = "cuda:0" torch.manual_seed(42) print("=" * 70) - print("PART A: Compressor + KV Cache Gathering at Production Scale") - print(f"HD={HD}, NOPE={NOPE}, ROPE={ROPE}") + print("PART A: Compressor + FMHA at Production Scale") + print(f"HD={HD}, NOPE={NOPE}, ROPE={ROPE}, H={n_h}") print("=" * 70) all_pass = True @@ -35,13 +35,11 @@ def main(): for T in [4, 16, 32, 64]: m = 4 n_blocks = T // m - kv_dim = HD * 2 # Compressor outputs 2*hd + kv_dim = HD * 2 - # Simulate compressor inputs (from NVFP4 GEMM outputs) kv_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3 gate_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3 - # Run compressor compressed = csa_compress_production_fp32( kv_proj, gate_proj, None, None, m=4) @@ -49,18 +47,13 @@ def main(): print(f" T={T}: n_blocks=0, SKIPPED") continue - # Split compressed output into KV (first HD) and check - comp_kv = compressed[:, :HD] # (n_blocks, HD) + comp_kv = compressed[:, :HD] - # Quantize to FP8 (noPE) + BF16 (RoPE) — same as production path from dsv4.kernels.cuda.loader import get_cuda_module kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"]) - nope_fp32 = comp_kv[:, :NOPE].contiguous() rope_bf16 = comp_kv[:, NOPE:].bfloat16().contiguous() nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32) - - # Dequantize back nope_dequant = nope_fp8.view(torch.float8_e4m3fn).float() * nope_scale.unsqueeze(-1).float() comp_kv_rt = torch.cat([nope_dequant, rope_bf16.float()], dim=-1) @@ -102,94 +95,26 @@ def main(): if cos < 0.999: all_pass = False print(f" T={T}: n_blocks={n_blocks} FP8/BF16 round-trip cos={cos:.6f} {status}") - # ---- Test 3: KV Cache gathering (mixed storage) ---- - print("\n--- Test 3: KV Cache gathering with FP8/BF16 mixed storage ---") - from single_shot_inference import KVCache - import json - - # Use model config - cfg = { - "num_attention_heads": 128, - "head_dim": HD, - "qk_rope_head_dim": ROPE, - "hidden_size": 7168, - } - - for ratio in [4, 128]: - cache = KVCache( - head_dim=HD, window_size=128, max_comp=65536, - device=device, indexer_key_dim=128, - compress_ratio=ratio, indexer_top_k=1024, rope_dim=ROPE - ) - # Simulate adding compressed KV entries - n_comp = 16 if ratio == 128 else 64 - comp_nope_fp8 = torch.randint(0, 200, (n_comp, NOPE), dtype=torch.uint8, device=device) - comp_nope_scale = torch.rand(n_comp, dtype=torch.float32, device=device) * 0.1 + 0.01 - comp_rope_bf16 = torch.randn(n_comp, ROPE, dtype=torch.bfloat16, device=device) * 0.3 - comp_pos = torch.arange(n_comp, dtype=torch.long, device=device) * ratio - cache.set_compressed_mixed(comp_nope_fp8, comp_nope_scale, comp_rope_bf16, comp_pos) - - # Add SWA entries - swa_len = min(128, n_comp) - swa_kv = torch.randn(swa_len, HD, dtype=torch.bfloat16, device=device) * 0.3 - swa_pos = torch.arange(swa_len, dtype=torch.long, device=device) + n_comp * ratio - for i in range(swa_len): - cache.append_swa(swa_kv[i:i+1], swa_pos[i:i+1]) - - # Gather all (HCA path) - if ratio > 4: - kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = cache.gather_mixed_all() - else: - # CSA: use top-k indices - tk = torch.arange(min(cache.n_comp, 16), device=device) - kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = cache.gather_mixed_selective(tk) - - total_len = kv_nope_scale.shape[0] - - # Validate gathered shapes - assert kv_nope_fp8.shape == (total_len, NOPE), f"Wrong nope shape: {kv_nope_fp8.shape} vs ({total_len}, {NOPE})" - assert kv_nope_scale.shape == (total_len,), f"Wrong scale shape: {kv_nope_scale.shape} vs ({total_len},)" - assert kv_rope_bf16.shape == (total_len, ROPE), f"Wrong rope shape: {kv_rope_bf16.shape} vs ({total_len}, {ROPE})" - - # Dequantize and check values - nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float() - - # Compare against original compressed entries (first n_comp rows) - if ratio > 4: - orig_nope = comp_nope_fp8[:n_comp].view(torch.float8_e4m3fn).float() * comp_nope_scale[:n_comp].unsqueeze(-1).float() - else: - orig_nope = comp_nope_fp8[:min(n_comp,16)].view(torch.float8_e4m3fn).float() * comp_nope_scale[:min(n_comp,16)].unsqueeze(-1).float() - - cos = cosine(nope_dequant[:orig_nope.shape[0]], orig_nope) - status = "PASS" if cos > 0.9999 else "FAIL" - if cos < 0.9999: all_pass = False - print(f" ratio={ratio}: n_comp={n_comp} swa_len={swa_len} gathered_len={total_len} " - f"dequant cos={cos:.6f} {status}") - - # ---- Test 4: FMHA with gathered mixed KV vs SDPA ---- - print("\n--- Test 4: B1 FMHA with mixed FP8/BF16 gathered KV vs SDPA ---") + # ---- Test 3: FMHA with mixed FP8/BF16 KV vs SDPA ---- + print("\n--- Test 3: B1 FMHA with mixed FP8/BF16 KV vs SDPA ---") from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode for N in [128, 512, 1024]: - # Create mixed-format KV (as if gathered from cache) kv_nope_fp8 = torch.randint(0, 200, (N, NOPE), dtype=torch.uint8, device=device) kv_nope_scale = torch.rand(N, dtype=torch.float32, device=device) * 0.1 + 0.01 kv_rope_bf16 = torch.randn(N, ROPE, dtype=torch.bfloat16, device=device) * 0.3 - # Q: (n_h, T=1, HD) BF16 q = torch.randn(n_h, 1, HD, dtype=torch.bfloat16, device=device) * 0.3 - # Production FMHA attn_out = dsv4_attention_mixed_fp8_decode( q=q, k_nope_fp8=kv_nope_fp8, k_nope_scale=kv_nope_scale, k_rope_bf16=kv_rope_bf16, scale=scale, rope_dim=ROPE) - # Reference: dequantize all KV to BF16, run SDPA nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float() - k_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1) # (N, HD) - k_4d = k_full.unsqueeze(0).unsqueeze(0) # (1, 1, N, HD) + k_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1) + k_4d = k_full.unsqueeze(0).unsqueeze(0) v_4d = k_4d.clone() - q_4d = q.unsqueeze(0) # (1, n_h, 1, HD) + q_4d = q.unsqueeze(0) o_ref = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale) cos = cosine(attn_out, o_ref.squeeze(0))