diff --git a/tests/unit/test_part_a_compressor_kv.py b/tests/unit/test_part_a_compressor_kv.py index 93d9165e..37ebe386 100644 --- a/tests/unit/test_part_a_compressor_kv.py +++ b/tests/unit/test_part_a_compressor_kv.py @@ -1,13 +1,5 @@ #!/usr/bin/env python3 -"""PART A diagnostic: Compressor + FMHA at production scale. - -Tests: -1. CSA compression FP8/BF16 round-trip -2. HCA compression FP8/BF16 round-trip -3. B1 FMHA with mixed FP8/BF16 KV vs SDPA - -All values are production: HD=512, NOPE=448, ROPE=64, H=128. -""" +"""PART A diagnostic: Compressor + FMHA at production scale.""" import sys, math import torch import torch.nn.functional as F @@ -23,104 +15,95 @@ def main(): print("=" * 70) print("PART A: Compressor + FMHA at Production Scale") - print(f"HD={HD}, NOPE={NOPE}, ROPE={ROPE}, H={n_h}") print("=" * 70) all_pass = True # ---- Test 1: CSA compression round-trip ---- - print("\n--- Test 1: CSA compression (ratio=4) with FP8/BF16 KV ---") + print("\n--- Test 1: CSA compression (ratio=4) ---") from dsv4.kernels.compressor.production_compress import csa_compress_production_fp32 + from dsv4.kernels.cuda.loader import get_cuda_module + kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"]) for T in [4, 16, 32, 64]: - m = 4 - n_blocks = T // m - kv_dim = HD * 2 - + m = 4; n_blocks = T // m; kv_dim = HD * 2 kv_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3 gate_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3 - - compressed = csa_compress_production_fp32( - kv_proj, gate_proj, None, None, m=4) - - if compressed.shape[0] == 0: - print(f" T={T}: n_blocks=0, SKIPPED") - continue - + compressed = csa_compress_production_fp32(kv_proj, gate_proj, None, None, m=4) + if compressed.shape[0] == 0: print(f" T={T}: SKIP"); continue comp_kv = compressed[:, :HD] - - from dsv4.kernels.cuda.loader import get_cuda_module - kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"]) nope_fp32 = comp_kv[:, :NOPE].contiguous() rope_bf16 = comp_kv[:, NOPE:].bfloat16().contiguous() nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32) nope_dequant = nope_fp8.view(torch.float8_e4m3fn).float() * nope_scale.unsqueeze(-1).float() comp_kv_rt = torch.cat([nope_dequant, rope_bf16.float()], dim=-1) - cos = cosine(comp_kv, comp_kv_rt) - status = "PASS" if cos > 0.999 else "FAIL" - if cos < 0.999: all_pass = False - print(f" T={T}: n_blocks={n_blocks} FP8/BF16 round-trip cos={cos:.6f} {status}") + ok = cos > 0.999 + if not ok: all_pass = False + print(f" T={T}: n_blocks={n_blocks} cos={cos:.6f} {'PASS' if ok else 'FAIL'}") - # ---- Test 2: HCA compression (ratio=128) ---- - print("\n--- Test 2: HCA compression (ratio=128) with FP8/BF16 KV ---") + # ---- Test 2: HCA compression round-trip ---- + print("\n--- Test 2: HCA compression (ratio=128) ---") from dsv4.kernels.compressor.production_compress import hca_compress_production_fp32 for T in [128, 256]: - m = 128 - n_blocks = T // m - if n_blocks == 0: - print(f" T={T}: n_blocks=0, SKIPPED") - continue - + m = 128; n_blocks = T // m + if n_blocks == 0: print(f" T={T}: SKIP"); continue kv_dim = HD * 2 kv_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3 gate_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3 - - compressed = hca_compress_production_fp32( - kv_proj, gate_proj, None, None, m=128) - + compressed = hca_compress_production_fp32(kv_proj, gate_proj, None, None, m=128) comp_kv = compressed[:, :HD] - - from dsv4.kernels.cuda.loader import get_cuda_module - kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"]) nope_fp32 = comp_kv[:, :NOPE].contiguous() rope_bf16 = comp_kv[:, NOPE:].bfloat16().contiguous() nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32) nope_dequant = nope_fp8.view(torch.float8_e4m3fn).float() * nope_scale.unsqueeze(-1).float() comp_kv_rt = torch.cat([nope_dequant, rope_bf16.float()], dim=-1) - cos = cosine(comp_kv, comp_kv_rt) - status = "PASS" if cos > 0.999 else "FAIL" - if cos < 0.999: all_pass = False - print(f" T={T}: n_blocks={n_blocks} FP8/BF16 round-trip cos={cos:.6f} {status}") + ok = cos > 0.999 + if not ok: all_pass = False + print(f" T={T}: n_blocks={n_blocks} cos={cos:.6f} {'PASS' if ok else 'FAIL'}") - # ---- Test 3: FMHA with mixed FP8/BF16 KV vs SDPA ---- - print("\n--- Test 3: B1 FMHA with mixed FP8/BF16 KV vs SDPA ---") + # ---- Test 3: B1 FMHA decode vs SDPA (H=128, MQA) ---- + print("\n--- Test 3: B1 FMHA decode vs SDPA (H=128, MQA) ---") from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode for N in [128, 512, 1024]: - kv_nope_fp8 = torch.randint(0, 200, (N, NOPE), dtype=torch.uint8, device=device) - kv_nope_scale = torch.rand(N, dtype=torch.float32, device=device) * 0.1 + 0.01 + # Realistic FP8 quantized KV + kv_nope_fp32 = torch.randn(N, NOPE, dtype=torch.float32, device=device) * 0.3 kv_rope_bf16 = torch.randn(N, ROPE, dtype=torch.bfloat16, device=device) * 0.3 + amax = kv_nope_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) + nope_scale = (amax / 448.0).squeeze(-1) + nope_clamped = (kv_nope_fp32 / nope_scale.unsqueeze(-1)).clamp(-448, 448) + kv_nope_fp8 = nope_clamped.to(torch.float8_e4m3fn).view(torch.uint8).contiguous() + kv_nope_scale = nope_scale.contiguous() q = torch.randn(n_h, 1, HD, dtype=torch.bfloat16, device=device) * 0.3 + # Production FMHA (128 heads, each attends to the same KV) attn_out = dsv4_attention_mixed_fp8_decode( q=q, k_nope_fp8=kv_nope_fp8, k_nope_scale=kv_nope_scale, k_rope_bf16=kv_rope_bf16, scale=scale, rope_dim=ROPE) + # Reference: dequantize, run SDPA per-head (MQA: all Q heads share 1 KV head) nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float() k_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1) - k_4d = k_full.unsqueeze(0).unsqueeze(0) - v_4d = k_4d.clone() - q_4d = q.unsqueeze(0) - o_ref = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale) + # MQA reference: expand K/V for all Q heads + k_expanded = k_full.unsqueeze(0).expand(n_h, -1, -1) # (n_h, N, HD) + # SDPA per head + o_ref = torch.zeros_like(attn_out) + for h in range(n_h): + q_h = q[h:h+1] # (1, 1, HD) + k_h = k_full.unsqueeze(0).unsqueeze(0) # (1, 1, N, HD) + v_h = k_h.clone() + q_4d = q_h.unsqueeze(0) # (1, 1, 1, HD) + o_h = F.scaled_dot_product_attention(q_4d, k_h, v_h, scale=scale) + o_ref[h] = o_h.squeeze() - cos = cosine(attn_out, o_ref.squeeze(0)) - status = "PASS" if cos > 0.999 else "FAIL" - if cos < 0.999: all_pass = False - print(f" N={N}: FMHA cos vs SDPA = {cos:.6f} {status}") + cos = cosine(attn_out, o_ref) + ok = cos > 0.999 + if not ok: all_pass = False + print(f" N={N}: cos={cos:.6f} {'PASS' if ok else 'FAIL'}") # ---- Summary ---- print("\n" + "=" * 70)