Fix PART A test: proper FP8 quantization and MQA reference
This commit is contained in:
@@ -1,13 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
"""PART A diagnostic: Compressor + FMHA at production scale.
|
||||
|
||||
Tests:
|
||||
1. CSA compression FP8/BF16 round-trip
|
||||
2. HCA compression FP8/BF16 round-trip
|
||||
3. B1 FMHA with mixed FP8/BF16 KV vs SDPA
|
||||
|
||||
All values are production: HD=512, NOPE=448, ROPE=64, H=128.
|
||||
"""
|
||||
"""PART A diagnostic: Compressor + FMHA at production scale."""
|
||||
import sys, math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -23,104 +15,95 @@ def main():
|
||||
|
||||
print("=" * 70)
|
||||
print("PART A: Compressor + FMHA at Production Scale")
|
||||
print(f"HD={HD}, NOPE={NOPE}, ROPE={ROPE}, H={n_h}")
|
||||
print("=" * 70)
|
||||
|
||||
all_pass = True
|
||||
|
||||
# ---- Test 1: CSA compression round-trip ----
|
||||
print("\n--- Test 1: CSA compression (ratio=4) with FP8/BF16 KV ---")
|
||||
print("\n--- Test 1: CSA compression (ratio=4) ---")
|
||||
from dsv4.kernels.compressor.production_compress import csa_compress_production_fp32
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
|
||||
for T in [4, 16, 32, 64]:
|
||||
m = 4
|
||||
n_blocks = T // m
|
||||
kv_dim = HD * 2
|
||||
|
||||
m = 4; n_blocks = T // m; kv_dim = HD * 2
|
||||
kv_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
|
||||
gate_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
|
||||
|
||||
compressed = csa_compress_production_fp32(
|
||||
kv_proj, gate_proj, None, None, m=4)
|
||||
|
||||
if compressed.shape[0] == 0:
|
||||
print(f" T={T}: n_blocks=0, SKIPPED")
|
||||
continue
|
||||
|
||||
compressed = csa_compress_production_fp32(kv_proj, gate_proj, None, None, m=4)
|
||||
if compressed.shape[0] == 0: print(f" T={T}: SKIP"); continue
|
||||
comp_kv = compressed[:, :HD]
|
||||
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
nope_fp32 = comp_kv[:, :NOPE].contiguous()
|
||||
rope_bf16 = comp_kv[:, NOPE:].bfloat16().contiguous()
|
||||
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
|
||||
nope_dequant = nope_fp8.view(torch.float8_e4m3fn).float() * nope_scale.unsqueeze(-1).float()
|
||||
comp_kv_rt = torch.cat([nope_dequant, rope_bf16.float()], dim=-1)
|
||||
|
||||
cos = cosine(comp_kv, comp_kv_rt)
|
||||
status = "PASS" if cos > 0.999 else "FAIL"
|
||||
if cos < 0.999: all_pass = False
|
||||
print(f" T={T}: n_blocks={n_blocks} FP8/BF16 round-trip cos={cos:.6f} {status}")
|
||||
ok = cos > 0.999
|
||||
if not ok: all_pass = False
|
||||
print(f" T={T}: n_blocks={n_blocks} cos={cos:.6f} {'PASS' if ok else 'FAIL'}")
|
||||
|
||||
# ---- Test 2: HCA compression (ratio=128) ----
|
||||
print("\n--- Test 2: HCA compression (ratio=128) with FP8/BF16 KV ---")
|
||||
# ---- Test 2: HCA compression round-trip ----
|
||||
print("\n--- Test 2: HCA compression (ratio=128) ---")
|
||||
from dsv4.kernels.compressor.production_compress import hca_compress_production_fp32
|
||||
|
||||
for T in [128, 256]:
|
||||
m = 128
|
||||
n_blocks = T // m
|
||||
if n_blocks == 0:
|
||||
print(f" T={T}: n_blocks=0, SKIPPED")
|
||||
continue
|
||||
|
||||
m = 128; n_blocks = T // m
|
||||
if n_blocks == 0: print(f" T={T}: SKIP"); continue
|
||||
kv_dim = HD * 2
|
||||
kv_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
|
||||
gate_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
|
||||
|
||||
compressed = hca_compress_production_fp32(
|
||||
kv_proj, gate_proj, None, None, m=128)
|
||||
|
||||
compressed = hca_compress_production_fp32(kv_proj, gate_proj, None, None, m=128)
|
||||
comp_kv = compressed[:, :HD]
|
||||
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
nope_fp32 = comp_kv[:, :NOPE].contiguous()
|
||||
rope_bf16 = comp_kv[:, NOPE:].bfloat16().contiguous()
|
||||
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
|
||||
nope_dequant = nope_fp8.view(torch.float8_e4m3fn).float() * nope_scale.unsqueeze(-1).float()
|
||||
comp_kv_rt = torch.cat([nope_dequant, rope_bf16.float()], dim=-1)
|
||||
|
||||
cos = cosine(comp_kv, comp_kv_rt)
|
||||
status = "PASS" if cos > 0.999 else "FAIL"
|
||||
if cos < 0.999: all_pass = False
|
||||
print(f" T={T}: n_blocks={n_blocks} FP8/BF16 round-trip cos={cos:.6f} {status}")
|
||||
ok = cos > 0.999
|
||||
if not ok: all_pass = False
|
||||
print(f" T={T}: n_blocks={n_blocks} cos={cos:.6f} {'PASS' if ok else 'FAIL'}")
|
||||
|
||||
# ---- Test 3: FMHA with mixed FP8/BF16 KV vs SDPA ----
|
||||
print("\n--- Test 3: B1 FMHA with mixed FP8/BF16 KV vs SDPA ---")
|
||||
# ---- Test 3: B1 FMHA decode vs SDPA (H=128, MQA) ----
|
||||
print("\n--- Test 3: B1 FMHA decode vs SDPA (H=128, MQA) ---")
|
||||
from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode
|
||||
|
||||
for N in [128, 512, 1024]:
|
||||
kv_nope_fp8 = torch.randint(0, 200, (N, NOPE), dtype=torch.uint8, device=device)
|
||||
kv_nope_scale = torch.rand(N, dtype=torch.float32, device=device) * 0.1 + 0.01
|
||||
# Realistic FP8 quantized KV
|
||||
kv_nope_fp32 = torch.randn(N, NOPE, dtype=torch.float32, device=device) * 0.3
|
||||
kv_rope_bf16 = torch.randn(N, ROPE, dtype=torch.bfloat16, device=device) * 0.3
|
||||
amax = kv_nope_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
||||
nope_scale = (amax / 448.0).squeeze(-1)
|
||||
nope_clamped = (kv_nope_fp32 / nope_scale.unsqueeze(-1)).clamp(-448, 448)
|
||||
kv_nope_fp8 = nope_clamped.to(torch.float8_e4m3fn).view(torch.uint8).contiguous()
|
||||
kv_nope_scale = nope_scale.contiguous()
|
||||
|
||||
q = torch.randn(n_h, 1, HD, dtype=torch.bfloat16, device=device) * 0.3
|
||||
|
||||
# Production FMHA (128 heads, each attends to the same KV)
|
||||
attn_out = dsv4_attention_mixed_fp8_decode(
|
||||
q=q, k_nope_fp8=kv_nope_fp8, k_nope_scale=kv_nope_scale,
|
||||
k_rope_bf16=kv_rope_bf16, scale=scale, rope_dim=ROPE)
|
||||
|
||||
# Reference: dequantize, run SDPA per-head (MQA: all Q heads share 1 KV head)
|
||||
nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
|
||||
k_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1)
|
||||
k_4d = k_full.unsqueeze(0).unsqueeze(0)
|
||||
v_4d = k_4d.clone()
|
||||
q_4d = q.unsqueeze(0)
|
||||
o_ref = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale)
|
||||
# MQA reference: expand K/V for all Q heads
|
||||
k_expanded = k_full.unsqueeze(0).expand(n_h, -1, -1) # (n_h, N, HD)
|
||||
# SDPA per head
|
||||
o_ref = torch.zeros_like(attn_out)
|
||||
for h in range(n_h):
|
||||
q_h = q[h:h+1] # (1, 1, HD)
|
||||
k_h = k_full.unsqueeze(0).unsqueeze(0) # (1, 1, N, HD)
|
||||
v_h = k_h.clone()
|
||||
q_4d = q_h.unsqueeze(0) # (1, 1, 1, HD)
|
||||
o_h = F.scaled_dot_product_attention(q_4d, k_h, v_h, scale=scale)
|
||||
o_ref[h] = o_h.squeeze()
|
||||
|
||||
cos = cosine(attn_out, o_ref.squeeze(0))
|
||||
status = "PASS" if cos > 0.999 else "FAIL"
|
||||
if cos < 0.999: all_pass = False
|
||||
print(f" N={N}: FMHA cos vs SDPA = {cos:.6f} {status}")
|
||||
cos = cosine(attn_out, o_ref)
|
||||
ok = cos > 0.999
|
||||
if not ok: all_pass = False
|
||||
print(f" N={N}: cos={cos:.6f} {'PASS' if ok else 'FAIL'}")
|
||||
|
||||
# ---- Summary ----
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
Reference in New Issue
Block a user