Fix PART A test: proper FP8 quantization and MQA reference

This commit is contained in:
2026-06-03 04:20:36 +00:00
parent 4126909dfb
commit d8306be3f2

View File

@@ -1,13 +1,5 @@
#!/usr/bin/env python3
"""PART A diagnostic: Compressor + FMHA at production scale.
Tests:
1. CSA compression FP8/BF16 round-trip
2. HCA compression FP8/BF16 round-trip
3. B1 FMHA with mixed FP8/BF16 KV vs SDPA
All values are production: HD=512, NOPE=448, ROPE=64, H=128.
"""
"""PART A diagnostic: Compressor + FMHA at production scale."""
import sys, math
import torch
import torch.nn.functional as F
@@ -23,104 +15,95 @@ def main():
print("=" * 70)
print("PART A: Compressor + FMHA at Production Scale")
print(f"HD={HD}, NOPE={NOPE}, ROPE={ROPE}, H={n_h}")
print("=" * 70)
all_pass = True
# ---- Test 1: CSA compression round-trip ----
print("\n--- Test 1: CSA compression (ratio=4) with FP8/BF16 KV ---")
print("\n--- Test 1: CSA compression (ratio=4) ---")
from dsv4.kernels.compressor.production_compress import csa_compress_production_fp32
from dsv4.kernels.cuda.loader import get_cuda_module
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
for T in [4, 16, 32, 64]:
m = 4
n_blocks = T // m
kv_dim = HD * 2
m = 4; n_blocks = T // m; kv_dim = HD * 2
kv_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
gate_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
compressed = csa_compress_production_fp32(
kv_proj, gate_proj, None, None, m=4)
if compressed.shape[0] == 0:
print(f" T={T}: n_blocks=0, SKIPPED")
continue
compressed = csa_compress_production_fp32(kv_proj, gate_proj, None, None, m=4)
if compressed.shape[0] == 0: print(f" T={T}: SKIP"); continue
comp_kv = compressed[:, :HD]
from dsv4.kernels.cuda.loader import get_cuda_module
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
nope_fp32 = comp_kv[:, :NOPE].contiguous()
rope_bf16 = comp_kv[:, NOPE:].bfloat16().contiguous()
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
nope_dequant = nope_fp8.view(torch.float8_e4m3fn).float() * nope_scale.unsqueeze(-1).float()
comp_kv_rt = torch.cat([nope_dequant, rope_bf16.float()], dim=-1)
cos = cosine(comp_kv, comp_kv_rt)
status = "PASS" if cos > 0.999 else "FAIL"
if cos < 0.999: all_pass = False
print(f" T={T}: n_blocks={n_blocks} FP8/BF16 round-trip cos={cos:.6f} {status}")
ok = cos > 0.999
if not ok: all_pass = False
print(f" T={T}: n_blocks={n_blocks} cos={cos:.6f} {'PASS' if ok else 'FAIL'}")
# ---- Test 2: HCA compression (ratio=128) ----
print("\n--- Test 2: HCA compression (ratio=128) with FP8/BF16 KV ---")
# ---- Test 2: HCA compression round-trip ----
print("\n--- Test 2: HCA compression (ratio=128) ---")
from dsv4.kernels.compressor.production_compress import hca_compress_production_fp32
for T in [128, 256]:
m = 128
n_blocks = T // m
if n_blocks == 0:
print(f" T={T}: n_blocks=0, SKIPPED")
continue
m = 128; n_blocks = T // m
if n_blocks == 0: print(f" T={T}: SKIP"); continue
kv_dim = HD * 2
kv_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
gate_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
compressed = hca_compress_production_fp32(
kv_proj, gate_proj, None, None, m=128)
compressed = hca_compress_production_fp32(kv_proj, gate_proj, None, None, m=128)
comp_kv = compressed[:, :HD]
from dsv4.kernels.cuda.loader import get_cuda_module
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
nope_fp32 = comp_kv[:, :NOPE].contiguous()
rope_bf16 = comp_kv[:, NOPE:].bfloat16().contiguous()
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
nope_dequant = nope_fp8.view(torch.float8_e4m3fn).float() * nope_scale.unsqueeze(-1).float()
comp_kv_rt = torch.cat([nope_dequant, rope_bf16.float()], dim=-1)
cos = cosine(comp_kv, comp_kv_rt)
status = "PASS" if cos > 0.999 else "FAIL"
if cos < 0.999: all_pass = False
print(f" T={T}: n_blocks={n_blocks} FP8/BF16 round-trip cos={cos:.6f} {status}")
ok = cos > 0.999
if not ok: all_pass = False
print(f" T={T}: n_blocks={n_blocks} cos={cos:.6f} {'PASS' if ok else 'FAIL'}")
# ---- Test 3: FMHA with mixed FP8/BF16 KV vs SDPA ----
print("\n--- Test 3: B1 FMHA with mixed FP8/BF16 KV vs SDPA ---")
# ---- Test 3: B1 FMHA decode vs SDPA (H=128, MQA) ----
print("\n--- Test 3: B1 FMHA decode vs SDPA (H=128, MQA) ---")
from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode
for N in [128, 512, 1024]:
kv_nope_fp8 = torch.randint(0, 200, (N, NOPE), dtype=torch.uint8, device=device)
kv_nope_scale = torch.rand(N, dtype=torch.float32, device=device) * 0.1 + 0.01
# Realistic FP8 quantized KV
kv_nope_fp32 = torch.randn(N, NOPE, dtype=torch.float32, device=device) * 0.3
kv_rope_bf16 = torch.randn(N, ROPE, dtype=torch.bfloat16, device=device) * 0.3
amax = kv_nope_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
nope_scale = (amax / 448.0).squeeze(-1)
nope_clamped = (kv_nope_fp32 / nope_scale.unsqueeze(-1)).clamp(-448, 448)
kv_nope_fp8 = nope_clamped.to(torch.float8_e4m3fn).view(torch.uint8).contiguous()
kv_nope_scale = nope_scale.contiguous()
q = torch.randn(n_h, 1, HD, dtype=torch.bfloat16, device=device) * 0.3
# Production FMHA (128 heads, each attends to the same KV)
attn_out = dsv4_attention_mixed_fp8_decode(
q=q, k_nope_fp8=kv_nope_fp8, k_nope_scale=kv_nope_scale,
k_rope_bf16=kv_rope_bf16, scale=scale, rope_dim=ROPE)
# Reference: dequantize, run SDPA per-head (MQA: all Q heads share 1 KV head)
nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
k_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1)
k_4d = k_full.unsqueeze(0).unsqueeze(0)
v_4d = k_4d.clone()
q_4d = q.unsqueeze(0)
o_ref = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale)
# MQA reference: expand K/V for all Q heads
k_expanded = k_full.unsqueeze(0).expand(n_h, -1, -1) # (n_h, N, HD)
# SDPA per head
o_ref = torch.zeros_like(attn_out)
for h in range(n_h):
q_h = q[h:h+1] # (1, 1, HD)
k_h = k_full.unsqueeze(0).unsqueeze(0) # (1, 1, N, HD)
v_h = k_h.clone()
q_4d = q_h.unsqueeze(0) # (1, 1, 1, HD)
o_h = F.scaled_dot_product_attention(q_4d, k_h, v_h, scale=scale)
o_ref[h] = o_h.squeeze()
cos = cosine(attn_out, o_ref.squeeze(0))
status = "PASS" if cos > 0.999 else "FAIL"
if cos < 0.999: all_pass = False
print(f" N={N}: FMHA cos vs SDPA = {cos:.6f} {status}")
cos = cosine(attn_out, o_ref)
ok = cos > 0.999
if not ok: all_pass = False
print(f" N={N}: cos={cos:.6f} {'PASS' if ok else 'FAIL'}")
# ---- Summary ----
print("\n" + "=" * 70)