Add comprehensive B1 mixed FP8 FMHA unit test
This commit is contained in:
641
tests/unit/test_b1_mixed_fp8_fmha.py
Normal file
641
tests/unit/test_b1_mixed_fp8_fmha.py
Normal file
@@ -0,0 +1,641 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Comprehensive unit test for B1 mixed FP8/BF16 decode FMHA.
|
||||
|
||||
Tests ALL components of the B1 pipeline at production values:
|
||||
1. quantize_q_fp8_split — Q BF16 → FP8 noPE + BF16 RoPE
|
||||
2. gather_mixed_selective/all/swa_only — KV gather preserving FP8
|
||||
3. fmha_mixed_fp8_decode_kernel — the actual FMHA at HD=512, H=128
|
||||
4. End-to-end: synthetic Q + KV → mixed FP8 FMHA → cosine vs BF16 reference
|
||||
|
||||
Production sizes: HD=512, NOPE=448, ROPE=64, H=128, N=128..2048.
|
||||
No shortcuts. No fallbacks. No toy values.
|
||||
"""
|
||||
import sys
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def quantize_fp8_e4m3(x_fp32):
|
||||
"""Quantize FP32 tensor to FP8_E4M3 with per-row scale."""
|
||||
amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
||||
scale = amax / 448.0
|
||||
fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
|
||||
return fp8.view(torch.uint8), scale.squeeze(-1)
|
||||
|
||||
|
||||
def dequantize_fp8_e4m3(fp8_uint8, scale):
|
||||
"""Dequantize FP8_E4M3 + per-row scale → FP32."""
|
||||
fp8 = fp8_uint8.view(torch.float8_e4m3fn)
|
||||
return fp8.float() * scale.unsqueeze(-1).float()
|
||||
|
||||
|
||||
def cosine(a, b):
    """Return the cosine similarity of two tensors as a Python float.

    Both inputs are flattened and cast to FP32 first, so only the element
    count has to match.
    """
    lhs = a.flatten().float()
    rhs = b.flatten().float()
    similarity = F.cosine_similarity(lhs, rhs, dim=0)
    return similarity.item()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: quantize_q_fp8_split
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_quantize_q_fp8_split():
    """Test Q quantization: BF16 → FP8 noPE + BF16 RoPE + FP32 scale.

    Feeds a random BF16 query through the project's ``_quantize_q_split``
    and checks the shapes, dtypes and round-trip accuracy of the three
    returned tensors. Requires CUDA and the ``dsv4`` package.

    Returns:
        True when every check passes.

    Raises:
        AssertionError: if any shape, dtype or cosine check fails.
    """
    print("\n" + "=" * 70)
    print("TEST 1: quantize_q_fp8_split")
    print("=" * 70)

    from dsv4.kernels.attention.fmha_mixed_fp8_op import _quantize_q_split

    HD = 512
    NOPE = 448
    ROPE = 64
    B, H, T = 1, 128, 1  # production values

    # Seed the generator (as the FMHA tests in this file do) so that a
    # failing run can be reproduced exactly.
    torch.manual_seed(42)
    q_fp32 = torch.randn(B, H, T, HD, dtype=torch.float32) * 0.5
    q_bf16 = q_fp32.bfloat16().cuda()

    q_nope_fp8, q_nope_scale, q_rope = _quantize_q_split(q_bf16, ROPE)

    # Verify shapes
    assert q_nope_fp8.shape == (B, H, T, NOPE), \
        f"q_nope_fp8 shape {q_nope_fp8.shape} != expected {(B, H, T, NOPE)}"
    assert q_nope_scale.shape == (B, H, T), \
        f"q_nope_scale shape {q_nope_scale.shape} != expected {(B, H, T)}"
    assert q_rope.shape == (B, H, T, ROPE), \
        f"q_rope shape {q_rope.shape} != expected {(B, H, T, ROPE)}"

    # Verify dtypes
    assert q_nope_fp8.dtype == torch.float8_e4m3fn, \
        f"q_nope_fp8 dtype {q_nope_fp8.dtype} != float8_e4m3fn"
    assert q_nope_scale.dtype == torch.float32, \
        f"q_nope_scale dtype {q_nope_scale.dtype} != float32"
    assert q_rope.dtype == torch.bfloat16, \
        f"q_rope dtype {q_rope.dtype} != bfloat16"

    # Verify noPE quantization round-trip accuracy against the FP32 source
    q_nope_dequant = dequantize_fp8_e4m3(
        q_nope_fp8.view(torch.uint8).cpu(), q_nope_scale.cpu())
    q_nope_ref = q_fp32[:, :, :, :NOPE]
    cos_nope = cosine(q_nope_dequant, q_nope_ref)
    print(f" Q noPE dequant cosine: {cos_nope:.6f}")
    assert cos_nope >= 0.999, f"Q noPE dequant cosine {cos_nope:.6f} < 0.999"

    # Verify RoPE passthrough. The reference is the FP32 source, so the
    # threshold leaves room for the BF16 rounding of the input.
    q_rope_ref = q_fp32[:, :, :, NOPE:]
    cos_rope = cosine(q_rope.cpu().float(), q_rope_ref)
    print(f" Q RoPE passthrough cosine: {cos_rope:.6f}")
    assert cos_rope >= 0.9999, f"Q RoPE passthrough cosine {cos_rope:.6f} < 0.9999"

    # Per-head noPE cosine check: one row per (batch, head, token)
    q_nope_dequant_h = q_nope_dequant.reshape(-1, NOPE)
    q_nope_ref_h = q_nope_ref.reshape(-1, NOPE)
    per_head_cos = F.cosine_similarity(q_nope_dequant_h, q_nope_ref_h, dim=-1)
    min_head = per_head_cos.min().item()
    mean_head = per_head_cos.mean().item()
    print(f" Q noPE per-head cosine: min={min_head:.6f} mean={mean_head:.6f}")
    assert min_head >= 0.998, f"Q noPE min per-head cosine {min_head:.6f} < 0.998"

    print(" PASS")
    return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: gather_mixed_selective / gather_mixed_all / gather_mixed_swa_only
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_gather_mixed_kernels():
    """Test KV gather kernels: selective, all, swa_only.

    Builds a synthetic compressed KV store (FP8 noPE bytes + FP32 scale +
    BF16 RoPE) and a BF16 sliding-window buffer, runs each gather entry
    point of the ``fp8_attention_io`` CUDA module, and checks the outputs.
    Requires CUDA and the ``dsv4`` package.

    Returns:
        True when every check passes.

    Raises:
        AssertionError: if any gathered row fails its check.
    """
    print("\n" + "=" * 70)
    print("TEST 2: gather_mixed kernels")
    print("=" * 70)

    from dsv4.kernels.cuda.loader import get_cuda_module
    mod = get_cuda_module(
        "fp8_attention_io", ["fp8_attention_io.cu"],
        extra_cuda_cflags=[
            "-gencode=arch=compute_100a,code=sm_100a",
            "-O3", "--use_fast_math", "--expt-relaxed-constexpr",
        ])

    HD = 512
    NOPE = 448
    ROPE = 64
    MAX_COMP = 128  # test with 128 compressed entries
    N_SWA = 32      # sliding-window rows appended after the compressed rows

    # Seed the generator (as the FMHA tests in this file do) so that a
    # failing run can be reproduced exactly.
    torch.manual_seed(42)

    # Generate compressed KV in storage format
    comp_fp32 = torch.randn(MAX_COMP, HD, dtype=torch.float32) * 0.5
    comp_nope_fp8, comp_nope_scale = quantize_fp8_e4m3(comp_fp32[:, :NOPE])
    comp_rope_bf16 = comp_fp32[:, NOPE:].bfloat16()

    comp_nope_fp8 = comp_nope_fp8.cuda()
    comp_nope_scale = comp_nope_scale.cuda()
    comp_rope_bf16 = comp_rope_bf16.cuda()

    # --- Test 2a: gather_mixed_all ---
    print("\n 2a: gather_mixed_all")
    swa_fp32 = torch.randn(N_SWA, HD, dtype=torch.float32) * 0.5
    swa_bf16 = swa_fp32.bfloat16().cuda()
    N_COMP = 64  # use first 64 compressed entries
    total = N_COMP + N_SWA

    out_nope_fp8 = torch.zeros(total, NOPE, dtype=torch.uint8, device='cuda')
    out_nope_scale = torch.zeros(total, dtype=torch.float32, device='cuda')
    out_rope_bf16 = torch.zeros(total, ROPE, dtype=torch.bfloat16, device='cuda')

    mod.gather_mixed_all_(
        comp_nope_fp8[:N_COMP], comp_nope_scale[:N_COMP], comp_rope_bf16[:N_COMP],
        swa_bf16, out_nope_fp8, out_nope_scale, out_rope_bf16)

    # Verify compressed part (should be exact copy)
    assert torch.equal(out_nope_fp8[:N_COMP].cpu(), comp_nope_fp8[:N_COMP].cpu()), \
        "gather_mixed_all: noPE FP8 bytes mismatch for compressed rows"
    assert torch.allclose(out_nope_scale[:N_COMP].cpu(), comp_nope_scale[:N_COMP].cpu()), \
        "gather_mixed_all: noPE scale mismatch for compressed rows"
    assert torch.equal(out_rope_bf16[:N_COMP].cpu(), comp_rope_bf16[:N_COMP].cpu()), \
        "gather_mixed_all: RoPE BF16 mismatch for compressed rows"

    # Verify SWA part (was BF16 → quantized to FP8, so round-trip loss expected)
    swa_nope_dequant = dequantize_fp8_e4m3(
        out_nope_fp8[N_COMP:].cpu(), out_nope_scale[N_COMP:].cpu())
    swa_nope_ref = swa_fp32[:, :NOPE]
    cos_swa_nope = cosine(swa_nope_dequant, swa_nope_ref)
    print(f" SWA noPE dequant cosine: {cos_swa_nope:.6f}")
    assert cos_swa_nope >= 0.999, f"SWA noPE dequant cosine {cos_swa_nope:.6f} < 0.999"

    swa_rope_ref = swa_fp32[:, NOPE:]
    cos_swa_rope = cosine(out_rope_bf16[N_COMP:].cpu().float(), swa_rope_ref)
    print(f" SWA RoPE cosine: {cos_swa_rope:.6f}")
    assert cos_swa_rope >= 0.9999, f"SWA RoPE cosine {cos_swa_rope:.6f} < 0.9999"

    print(" PASS")

    # --- Test 2b: gather_mixed_selective ---
    print("\n 2b: gather_mixed_selective")
    # Single source of truth for the selected rows: the tensor handed to the
    # kernel and the verification loop below are both derived from it.
    selected = [5, 10, 20, 30, 50]
    indices = torch.tensor(selected, dtype=torch.int32, device='cuda')
    K = len(selected)
    total2 = K + N_SWA  # K compressed + N_SWA sliding-window rows

    out2_nope_fp8 = torch.zeros(total2, NOPE, dtype=torch.uint8, device='cuda')
    out2_nope_scale = torch.zeros(total2, dtype=torch.float32, device='cuda')
    out2_rope_bf16 = torch.zeros(total2, ROPE, dtype=torch.bfloat16, device='cuda')

    mod.gather_mixed_selective_(
        comp_nope_fp8, comp_nope_scale, comp_rope_bf16,
        swa_bf16, indices,
        out2_nope_fp8, out2_nope_scale, out2_rope_bf16)

    # Verify selected compressed rows match original
    for i, idx in enumerate(selected):
        assert torch.equal(out2_nope_fp8[i].cpu(), comp_nope_fp8[idx].cpu()), \
            f"selective: noPE FP8 mismatch at index {idx}"
        assert torch.allclose(out2_nope_scale[i].cpu(), comp_nope_scale[idx].cpu()), \
            f"selective: noPE scale mismatch at index {idx}"
        assert torch.equal(out2_rope_bf16[i].cpu(), comp_rope_bf16[idx].cpu()), \
            f"selective: RoPE mismatch at index {idx}"

    print(" PASS")

    # --- Test 2c: gather_mixed_swa_only ---
    print("\n 2c: gather_mixed_swa_only")
    total3 = N_SWA
    out3_nope_fp8 = torch.zeros(total3, NOPE, dtype=torch.uint8, device='cuda')
    out3_nope_scale = torch.zeros(total3, dtype=torch.float32, device='cuda')
    out3_rope_bf16 = torch.zeros(total3, ROPE, dtype=torch.bfloat16, device='cuda')

    # NOTE(review): this is the only gather entry point that is passed the
    # RoPE width explicitly — confirm the other two are meant to infer it.
    mod.gather_mixed_swa_only_(
        swa_bf16, out3_nope_fp8, out3_nope_scale, out3_rope_bf16, ROPE)

    swa3_nope_dequant = dequantize_fp8_e4m3(
        out3_nope_fp8.cpu(), out3_nope_scale.cpu())
    cos3 = cosine(swa3_nope_dequant, swa_fp32[:, :NOPE])
    print(f" SWA-only noPE dequant cosine: {cos3:.6f}")
    assert cos3 >= 0.999, f"SWA-only noPE cosine {cos3:.6f} < 0.999"

    cos3_rope = cosine(out3_rope_bf16.cpu().float(), swa_fp32[:, NOPE:])
    print(f" SWA-only RoPE cosine: {cos3_rope:.6f}")
    assert cos3_rope >= 0.9999, f"SWA-only RoPE cosine {cos3_rope:.6f} < 0.9999"

    print(" PASS")
    return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: Mixed FP8 FMHA decode kernel — cosine vs BF16 reference
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_fmha_mixed_fp8_decode():
    """Test the B1 mixed FP8 decode FMHA at production values.

    Production: HD=512, NOPE=448, ROPE=64, H=128, N=128..2048.
    For each N the kernel output is compared against an FP32 SDPA reference
    built from the dequantized keys; the test uses the same tensor as keys
    and values. The log-sum-exp of the first head is printed for inspection
    but not asserted. Requires CUDA and the ``dsv4`` package.

    Returns:
        True if every N reaches a global cosine of at least 0.999 and the
        kernel never raises, otherwise False.
    """
    print("\n" + "=" * 70)
    print("TEST 3: fmha_mixed_fp8_decode — production values")
    print("=" * 70)

    from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw

    HD = 512
    NOPE = 448
    ROPE = 64
    H = 128
    B = 1
    scale = 1.0 / math.sqrt(HD)

    N_values = [128, 256, 512, 1024, 2048]
    all_pass = True

    for N in N_values:
        print(f"\n N={N} H={H} HD={HD}")
        torch.manual_seed(42)

        # Generate synthetic Q and KV
        q_fp32 = torch.randn(B, H, 1, HD, dtype=torch.float32) * 0.5
        k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5
        q_bf16 = q_fp32.bfloat16().cuda()

        # Split KV into noPE (FP8) + RoPE (BF16)
        k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE])
        k_rope_bf16 = k_fp32[:, NOPE:].bfloat16()
        k_nope_fp8 = k_nope_fp8.cuda()
        k_nope_scale = k_nope_scale.cuda()
        k_rope_bf16 = k_rope_bf16.cuda()

        # Run mixed FP8 decode. A kernel failure for one N is recorded and
        # the remaining sizes are still exercised.
        try:
            o_mixed, lse = fmha_mixed_fp8_decode_raw(
                q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE)
        except Exception as e:
            print(f" MIXED FP8 FAILED: {e}")
            all_pass = False
            continue

        # Reference keys: dequantized noPE concatenated with the FP32 RoPE part
        k_nope_dequant = dequantize_fp8_e4m3(
            k_nope_fp8.view(torch.uint8).cpu(), k_nope_scale.cpu())
        k_full = torch.cat([k_nope_dequant, k_fp32[:, NOPE:]], dim=-1)  # (N, HD) FP32

        # SDPA reference — FP32 math, one KV head shared by all Q heads
        q_f = q_fp32.cuda()  # (B, H, 1, HD) FP32
        k_f = k_full.unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1).cuda()  # (B, 1, N, HD)
        v_f = k_f  # values are the same tensor as the keys in this test
        o_ref = F.scaled_dot_product_attention(q_f, k_f, v_f, scale=scale)  # (B, H, 1, HD)
        o_ref_bf16 = o_ref.bfloat16()

        # Global cosine
        cos_global = cosine(o_mixed, o_ref_bf16)

        # Per-head cosine
        o_mixed_h = o_mixed.float().squeeze(2)  # (B, H, HD)
        o_ref_h = o_ref_bf16.float().squeeze(2)
        per_head_cos = F.cosine_similarity(o_mixed_h, o_ref_h, dim=-1)  # (B, H)
        min_cos = per_head_cos.min().item()
        mean_cos = per_head_cos.mean().item()

        # Magnitude comparison
        mixed_max = o_mixed.float().abs().max().item()
        ref_max = o_ref_bf16.float().abs().max().item()
        mag_ratio = mixed_max / ref_max if ref_max > 0 else 0.0

        # LSE comparison. ref_scores is (B, H, N); keepdim=True keeps the
        # result at (B, H, 1) so the [0, 0, 0] index used in the report
        # below is valid (without it the tensor is 2-D and indexing raises).
        ref_scores = torch.matmul(q_f.squeeze(2), k_f.squeeze(1).transpose(-2, -1)) * scale
        ref_lse = torch.logsumexp(ref_scores, dim=-1, keepdim=True)  # (B, H, 1)

        passed = cos_global >= 0.999
        status = "PASS" if passed else "FAIL"
        print(f" {status}: cos_global={cos_global:.6f} min_head={min_cos:.6f} "
              f"mean_head={mean_cos:.6f}")
        print(f" |mixed|={mixed_max:.4f} |ref|={ref_max:.4f} ratio={mag_ratio:.4f}")
        print(f" LSE: mixed={lse[0,0,0].item():.4f} ref={ref_lse[0,0,0].item():.4f} "
              f"diff={abs(lse[0,0,0].item() - ref_lse[0,0,0].item()):.4f}")

        if not passed:
            all_pass = False
            # Print worst heads
            worst = per_head_cos[0].argsort()[:5]
            print(f" Worst heads: {worst.tolist()} cos={per_head_cos[0][worst].tolist()}")

    return all_pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 4: Mixed FP8 FMHA with attention sinks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_fmha_mixed_fp8_with_sinks():
    """Test B1 mixed FP8 FMHA with attention sink bias.

    Runs the decode kernel twice on identical inputs, once with a per-head
    sink bias and once without, and checks that the bias changes the output
    without shifting its magnitude by an order of magnitude. Requires CUDA
    and the ``dsv4`` package.

    Returns:
        True when every check passes.

    Raises:
        AssertionError: if the bias has no effect, or the output magnitude
            is zero or shifts by 10x or more.
    """
    print("\n" + "=" * 70)
    print("TEST 4: fmha_mixed_fp8_decode with attention sinks")
    print("=" * 70)

    from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw

    HD = 512
    NOPE = 448
    ROPE = 64
    H = 128
    B = 1
    N = 512
    scale = 1.0 / math.sqrt(HD)
    torch.manual_seed(42)

    q_fp32 = torch.randn(B, H, 1, HD, dtype=torch.float32) * 0.5
    k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5
    q_bf16 = q_fp32.bfloat16().cuda()
    k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE])
    k_rope_bf16 = k_fp32[:, NOPE:].bfloat16()
    k_nope_fp8 = k_nope_fp8.cuda()
    k_nope_scale = k_nope_scale.cuda()
    k_rope_bf16 = k_rope_bf16.cuda()

    # Generate sink bias (production: per-head FP32). It is drawn on the CPU
    # so the random stream is unchanged, then moved to the GPU so that every
    # tensor handed to the kernel lives on the same device.
    sink_bias = (torch.randn(H, dtype=torch.float32) * 2.0).cuda()

    # Run with sink bias
    o_with_sink, lse_with = fmha_mixed_fp8_decode_raw(
        q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale,
        attn_sink=sink_bias, rope_dim=ROPE)

    # Run without sink bias
    o_no_sink, lse_no = fmha_mixed_fp8_decode_raw(
        q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale,
        rope_dim=ROPE)

    # With non-trivial sink bias, output SHOULD differ from no-sink
    diff = (o_with_sink - o_no_sink).float().abs().max().item()
    print(f" Max diff with/without sink: {diff:.6f}")
    assert diff > 1e-4, "Sink bias has no effect on output — kernel is ignoring it"

    # Sanity: output magnitudes should be in same ballpark
    with_max = o_with_sink.float().abs().max().item()
    no_max = o_no_sink.float().abs().max().item()
    print(f" |with_sink|={with_max:.4f} |no_sink|={no_max:.4f}")
    # Guard the ratio below against a zero divisor with a readable message.
    assert no_max > 0.0, "No-sink output is all zeros"
    assert 0.1 < with_max / no_max < 10.0, \
        f"Sink bias causing extreme magnitude shift: {with_max / no_max:.4f}"

    print(" PASS")
    return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 5: Mixed FP8 FMHA — multi-head GQA (multiple Q per KV)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_fmha_mixed_fp8_gqa():
    """Exercise the decode kernel with 128 query heads over one shared KV set.

    Checks the output and LSE shapes, rejects NaN/Inf in the output, and
    requires every head to produce a non-zero result. The per-head
    magnitude statistics and the LSE range are printed for inspection only.
    Requires CUDA and the ``dsv4`` package.

    Returns:
        True when every assertion holds.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TEST 5: fmha_mixed_fp8_decode — GQA/MQA (H=128 Q heads, 1 KV head)")
    print(rule)

    from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw

    HD, NOPE, ROPE = 512, 448, 64
    H, B, N = 128, 1, 256
    scale = 1.0 / math.sqrt(HD)
    torch.manual_seed(42)

    # Random query and keys; the draw order (Q first, then K) is kept so the
    # seeded values do not change.
    query_fp32 = torch.randn(B, H, 1, HD, dtype=torch.float32) * 0.5
    keys_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5
    q_bf16 = query_fp32.bfloat16().cuda()

    # Keys in the mixed layout: FP8 noPE bytes + scale, BF16 RoPE tail.
    nope_bytes, nope_scale = quantize_fp8_e4m3(keys_fp32[:, :NOPE])
    nope_bytes, nope_scale = nope_bytes.cuda(), nope_scale.cuda()
    rope_bf16 = keys_fp32[:, NOPE:].bfloat16().cuda()

    o_mixed, lse = fmha_mixed_fp8_decode_raw(
        q_bf16, nope_bytes, nope_scale, rope_bf16, scale, rope_dim=ROPE)

    want_o = (B, H, 1, HD)
    want_lse = (B, H, 1)
    assert o_mixed.shape == want_o, f"Output shape {o_mixed.shape} != {want_o}"
    assert lse.shape == want_lse, f"LSE shape {lse.shape} != {want_lse}"
    assert not torch.isnan(o_mixed).any(), "NaN in output"
    assert not torch.isinf(o_mixed).any(), "Inf in output"

    # Largest absolute output value per head, shape (B, H).
    head_peak = o_mixed.float().abs().amax(dim=-1).squeeze(2)
    peak_mean = head_peak.mean().item()
    peak_std = head_peak.std().item()
    peak_lo = head_peak.min().item()
    peak_hi = head_peak.max().item()
    print(f" Per-head |o|_max: mean={peak_mean:.4f} std={peak_std:.4f}")
    print(f" |o| range: [{peak_lo:.4f}, {peak_hi:.4f}]")

    # A head whose peak is zero produced an all-zero output row.
    assert peak_lo > 0.0, "A head produced zero output"

    # LSE spread is reported only; nothing is asserted about it.
    lse_flat = lse.squeeze(2)  # (B, H)
    lse_lo = lse_flat.min().item()
    lse_hi = lse_flat.max().item()
    print(f" LSE range: [{lse_lo:.4f}, {lse_hi:.4f}]")

    print(" PASS")
    return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 6: Weight loading verification — print actual shapes and dtypes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_weight_loading():
    """Check dtypes and shapes of buffers laid out in the B1 mixed format.

    The buffers are allocated right here on the CPU (per the original note,
    mirroring the KVCache allocations in single_shot_inference.py) and then
    inspected. No checkpoint is read and no kernel is launched.

    Layout being checked:
      - noPE payload and indexer keys: uint8 (storage for FP8)
      - noPE and indexer scales: float32
      - RoPE slices: bfloat16

    Returns:
        True if every dtype and shape matches its expectation, else False.

    Raises:
        AssertionError: if NOPE does not equal HD - ROPE.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TEST 6: Weight loading verification (requires checkpoint)")
    print(rule)

    HD, NOPE, ROPE = 512, 448, 64
    MAX_COMP, SWA = 1024, 4096
    IDX_DIM = 128  # ihd=128
    gather_rows = MAX_COMP + SWA

    # name -> (shape, dtype), in the order the dtype report is printed
    layout = {
        "comp_nope_fp8": ((MAX_COMP, NOPE), torch.uint8),
        "comp_nope_scale": ((MAX_COMP,), torch.float32),
        "comp_rope_bf16": ((MAX_COMP, ROPE), torch.bfloat16),
        "comp_idx_fp8": ((MAX_COMP, IDX_DIM), torch.uint8),
        "comp_idx_scale": ((MAX_COMP,), torch.float32),
        "gather_nope_fp8": ((gather_rows, NOPE), torch.uint8),
        "gather_nope_scale": ((gather_rows,), torch.float32),
        "gather_rope_bf16": ((gather_rows, ROPE), torch.bfloat16),
    }
    buffers = {
        label: torch.zeros(*dims, dtype=kind, device='cpu')
        for label, (dims, kind) in layout.items()
    }

    mismatch = False

    # Dtype report: every buffer
    for label, (_, want_dtype) in layout.items():
        got_dtype = buffers[label].dtype
        if got_dtype == want_dtype:
            verdict = "OK"
        else:
            verdict = "WRONG"
            mismatch = True
        print(f" {label}: {got_dtype} (expected {want_dtype}) — {verdict}")

    # Shape report: only the two-dimensional buffers
    shaped = (
        "comp_nope_fp8",
        "comp_rope_bf16",
        "comp_idx_fp8",
        "gather_nope_fp8",
        "gather_rope_bf16",
    )
    for label in shaped:
        want_shape = layout[label][0]
        got_shape = buffers[label].shape
        if got_shape == want_shape:
            verdict = "OK"
        else:
            verdict = "WRONG"
            mismatch = True
        print(f" {label} shape: {got_shape} (expected {want_shape}) — {verdict}")

    # The noPE width must be whatever is left of the head after RoPE
    assert NOPE == HD - ROPE, f"NOPE ({NOPE}) != HD - ROPE ({HD} - {ROPE} = {HD - ROPE})"
    print(f" NOPE={NOPE} = HD({HD}) - ROPE({ROPE}) — OK")

    if mismatch:
        print(" FAIL: dtype/shape mismatches detected")
    else:
        print(" PASS")
    return not mismatch
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 7: Batch test — multiple batch sizes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_fmha_mixed_fp8_batch():
    """Test B1 with different batch sizes (B=1,2,4).

    For each batch size the decode kernel is run on freshly seeded inputs
    and the output is checked for the expected shape, for NaNs, and for
    being non-zero. Requires CUDA and the ``dsv4`` package.

    Returns:
        True if the kernel ran for every batch size, False if any call
        raised.

    Raises:
        AssertionError: if an output has the wrong shape, contains NaN, or
            is entirely zero.
    """
    print("\n" + "=" * 70)
    print("TEST 7: fmha_mixed_fp8_decode — batch sizes")
    print("=" * 70)

    from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw

    HD = 512
    NOPE = 448
    ROPE = 64
    H = 128
    N = 256
    scale = 1.0 / math.sqrt(HD)

    all_pass = True
    for B in [1, 2, 4]:
        print(f"\n B={B}")
        torch.manual_seed(42)
        q_fp32 = torch.randn(B, H, 1, HD, dtype=torch.float32) * 0.5
        k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5
        q_bf16 = q_fp32.bfloat16().cuda()
        k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE])
        k_rope_bf16 = k_fp32[:, NOPE:].bfloat16()
        k_nope_fp8 = k_nope_fp8.cuda()
        k_nope_scale = k_nope_scale.cuda()
        k_rope_bf16 = k_rope_bf16.cuda()

        # A kernel failure for one batch size is recorded and the remaining
        # sizes are still exercised.
        try:
            o, lse = fmha_mixed_fp8_decode_raw(
                q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE)
        except Exception as e:
            print(f" FAILED: {e}")
            all_pass = False
            continue

        assert o.shape == (B, H, 1, HD), f"Shape {o.shape} != {(B, H, 1, HD)}"
        assert not torch.isnan(o).any(), "NaN in output"
        # Sanity: the output must not be trivially zero.
        peak = o.float().abs().max().item()
        assert peak > 0.0, "Output is trivially zero"
        print(f" OK: shape={tuple(o.shape)} |o|={peak:.4f}")

    return all_pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run every test in order, turn any exception into a
    # failed result, print a summary, and exit 0 only if all tests passed.
    rule = "=" * 70
    print(rule)
    print("B1 Mixed FP8/BF16 FMHA — Comprehensive Unit Test")
    print("Production values: HD=512, NOPE=448, ROPE=64, H=128")
    print(rule)

    # (summary label, test function), in execution order
    suite = (
        ("1_quantize_q", test_quantize_q_fp8_split),        # Q quantization
        ("2_gather_mixed", test_gather_mixed_kernels),      # Gather kernels
        ("3_fmha_cosine", test_fmha_mixed_fp8_decode),      # FMHA decode cosine
        ("4_sinks", test_fmha_mixed_fp8_with_sinks),        # Attention sinks
        ("5_gqa", test_fmha_mixed_fp8_gqa),                 # GQA/MQA
        ("6_weight_loading", test_weight_loading),          # Buffer format check
        ("7_batch", test_fmha_mixed_fp8_batch),             # Batch sizes
    )

    results = {}
    for label, run_test in suite:
        # One failing test must not stop the remaining ones from running.
        try:
            results[label] = run_test()
        except Exception as err:
            print(f" EXCEPTION: {err}")
            results[label] = False

    # Summary
    print("\n" + rule)
    print("SUMMARY")
    print(rule)
    any_failed = False
    for label, outcome in results.items():
        if outcome:
            verdict = "PASS"
        else:
            verdict = "FAIL"
            any_failed = True
        print(f" {label}: {verdict}")

    print()
    if any_failed:
        print("SOME TESTS FAILED")
        sys.exit(1)
    print("ALL TESTS PASSED")
    sys.exit(0)
|
||||
Reference in New Issue
Block a user