Add B1 QK vs PV isolation test

This commit is contained in:
2026-06-03 00:23:35 +00:00
parent c322e3f301
commit 29a95a3db6

View File

@@ -0,0 +1,171 @@
#!/usr/bin/env python3
"""B1 FMHA isolate: test QK scores separately from PV.
Strategy:
1. Compute reference QK scores and compare their logsumexp with the LSE returned by the kernel
2. If the LSE is wrong, the bug is in QK. If the LSE is right but the output is wrong, the bug is in PV.
3. Also test: single-head, single N=128 to minimize moving parts.
"""
import sys
import math
import torch
import torch.nn.functional as F
def quantize_fp8_e4m3(x_fp32):
amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
scale = amax / 448.0
fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
return fp8.view(torch.uint8), scale.squeeze(-1)
def dequantize_fp8_e4m3(fp8_uint8, scale):
fp8 = fp8_uint8.view(torch.float8_e4m3fn)
return fp8.float() * scale.unsqueeze(-1).float()
def main():
    """Run the B1 FMHA isolation test and print a QK-vs-PV diagnosis.

    Builds a single-head, single-query problem (N=128 keys, V = K), computes
    reference scores, LSE and output in PyTorch, runs the mixed-FP8 decode
    kernel on the same inputs and compares:

    * the kernel LSE against ``logsumexp`` of the reference scores (proxy
      for the QK stage), and
    * the kernel output against the reference ``softmax(scores) @ V``
      (QK and PV together).

    The kernel is then re-run with the RoPE dims zeroed to exercise the noPE
    path on its own.  Needs a CUDA device and the ``dsv4`` package.  The
    report is printed to stdout; the function always finishes by calling
    ``sys.exit(0)``.
    """
    torch.manual_seed(42)
    HD, NOPE, ROPE = 512, 448, 64
    H, B, T, N = 1, 1, 1, 128
    scale = 1.0 / math.sqrt(HD)
    # Heuristic thresholds for the diagnosis below.
    LSE_TOL = 0.1
    COS_OK = 0.99
    print(f"=== B1 FMHA Isolate: QK vs PV (N={N} H={H}) ===\n")

    # Generate Q and KV (Q first, then K, so the seeded RNG stream is stable).
    q_fp32 = torch.randn(B, H, T, HD, dtype=torch.float32) * 0.5
    k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5
    q_bf16 = q_fp32.bfloat16().cuda()
    k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE])
    k_rope_bf16 = k_fp32[:, NOPE:].bfloat16().cuda()
    k_nope_fp8 = k_nope_fp8.cuda()
    k_nope_scale = k_nope_scale.cuda()

    # --- Reference QK scores (head 0) ---
    # The noPE part is dequantized on the CPU and then moved to the GPU; the
    # RoPE part of the reference uses the original fp32 values.
    k_nope_dequant = dequantize_fp8_e4m3(
        k_nope_fp8.cpu(), k_nope_scale.cpu()).cuda()
    k_full = torch.cat([k_nope_dequant, k_fp32[:, NOPE:].cuda()], dim=-1)  # (N, HD)
    q_h0 = q_fp32[0, 0, 0, :].cuda().float()
    scores_ref = torch.matmul(q_h0, k_full.T) * scale  # (N,)
    # Separate noPE and RoPE contributions to the scores.
    scores_nope_ref = torch.matmul(q_h0[:NOPE], k_full[:, :NOPE].T) * scale
    scores_rope_ref = torch.matmul(q_h0[NOPE:], k_full[:, NOPE:].T) * scale
    print("Reference scores (head 0):")
    print(f" Total: [{scores_ref.min():.4f}, {scores_ref.max():.4f}]")
    print(f" noPE: [{scores_nope_ref.min():.4f}, {scores_nope_ref.max():.4f}]")
    print(f" RoPE: [{scores_rope_ref.min():.4f}, {scores_rope_ref.max():.4f}]")
    print(f" First 8 total scores: {scores_ref[:8].tolist()}")

    # --- Run kernel and extract LSE ---
    # Imported here so the reference numbers above are printed even if the
    # project kernel cannot be imported.
    from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw
    o_mixed, lse = fmha_mixed_fp8_decode_raw(
        q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE)

    # LSE should equal logsumexp of scores.
    ref_lse = torch.logsumexp(scores_ref, dim=0)
    kernel_lse = lse[0, 0, 0].item()
    lse_diff = abs(kernel_lse - ref_lse.item())
    print("\nLSE comparison (head 0):")
    print(f" Kernel LSE: {kernel_lse:.4f}")
    print(f" Reference LSE: {ref_lse.item():.4f}")
    print(f" Diff: {lse_diff:.4f}")
    # If LSE is close but output is wrong, bug is in PV.
    # If LSE is far off, bug is in QK.
    lse_close = lse_diff < LSE_TOL

    # --- Check softmax probabilities (from reference scores) ---
    probs_ref = F.softmax(scores_ref, dim=0)
    print(f"\nReference softmax probs: [{probs_ref.min():.6f}, {probs_ref.max():.6f}]")
    print(f" First 8 probs: {probs_ref[:8].tolist()}")

    # --- Check output = P @ V ---
    # Reference: o = probs @ K (since V = K)
    o_ref = torch.matmul(probs_ref.unsqueeze(0), k_full).squeeze(0)  # (HD,)
    o_mixed_h0 = o_mixed[0, 0, 0, :].float()
    cos = F.cosine_similarity(o_mixed_h0.unsqueeze(0), o_ref.unsqueeze(0)).item()
    print("\nOutput comparison (head 0):")
    print(f" cos(mixed, ref_P@V) = {cos:.6f}")
    print(f" |mixed| = {o_mixed_h0.norm():.6f}")
    print(f" |ref_P@V| = {o_ref.norm():.6f}")

    # Verify that the hand-rolled reference P@V matches PyTorch's SDPA.
    # P is (1, 128), V is (128, 512); expected output is (1, 512).
    kv_4d = k_full.unsqueeze(0).unsqueeze(0)  # (1, 1, N, HD); V = K
    o_ref_full = F.scaled_dot_product_attention(
        q_fp32.cuda()[:, :1, :, :],  # (1, 1, 1, HD)
        kv_4d,
        kv_4d,
        scale=scale,
    )
    cos_ref = F.cosine_similarity(
        o_ref.unsqueeze(0), o_ref_full[0, 0, 0, :].unsqueeze(0)).item()
    print(f" cos(ref_P@V, ref_SDPA) = {cos_ref:.6f}")

    # --- PV sub-tile structure (original author's notes on the kernel) ---
    # The kernel computes PV as:
    #   For n_sub = 0..31:
    #     MMA(P[128x128], V[128x16]) -> TMEM[128x16] at offset n_sub*16
    #   Then reads TMEM and accumulates.
    # V[row, d_base+dd] with d_base = n_sub*16; noPE V is dequantized from
    # FP8, RoPE V comes directly from k_rope_bf16.
    # V should be K, so V[row, d] = K[row, d].
    print("\nV matrix sanity (row 0):")
    v_nope_ref = k_nope_dequant[0, :8]  # first 8 noPE dims
    print(f" K_nope[0,:8] (dequant): {v_nope_ref.tolist()}")
    print(f" K_orig[0,:8]: {k_fp32[0, :8].tolist()}")
    # k_fp32 lives on the CPU while the dequantized K is on the GPU; move the
    # slice over so cosine_similarity does not fail on a device mismatch.
    k_orig_row = k_fp32[0, :8].to(v_nope_ref.device)
    cos_v = F.cosine_similarity(
        v_nope_ref.unsqueeze(0), k_orig_row.unsqueeze(0)).item()
    print(f" cos(dequant, orig) = {cos_v:.6f}")

    # --- Diagnose ---
    # Only blame PV when the output actually disagrees with the reference;
    # a NaN cosine fails the ">=" test and is therefore reported as a PV bug.
    if not lse_close:
        print(f"\n*** DIAGNOSIS: LSE is far off ({lse_diff:.4f})")
        print("*** BUG IS IN QK (query-key scoring)")
    elif cos >= COS_OK:
        print(f"\n*** DIAGNOSIS: LSE is close ({lse_diff:.4f}) and output cos is {cos:.6f}")
        print("*** NO BUG DETECTED IN QK OR PV")
    else:
        print(f"\n*** DIAGNOSIS: LSE is close ({lse_diff:.4f}) but output cos is {cos:.6f}")
        print("*** BUG IS IN PV (probability-value multiply), NOT IN QK")

    # --- Extra: compare noPE-only output ---
    # Zero out RoPE dims in Q and K, run kernel, compare.
    print("\n--- noPE-only test (RoPE zeroed) ---")
    q_nope_only = q_bf16.clone()
    q_nope_only[:, :, :, NOPE:] = 0  # zero RoPE in Q
    k_rope_zero = torch.zeros(N, ROPE, dtype=torch.bfloat16, device='cuda')
    try:
        o_nope, _ = fmha_mixed_fp8_decode_raw(
            q_nope_only, k_nope_fp8, k_nope_scale, k_rope_zero, scale,
            rope_dim=ROPE)
        # Reference with zeroed RoPE
        q_nz = q_fp32.clone().cuda()
        q_nz[:, :, :, NOPE:] = 0
        k_nz = k_full.clone()
        k_nz[:, NOPE:] = 0
        kv_nz = k_nz.unsqueeze(0).unsqueeze(0)
        o_ref_nope = F.scaled_dot_product_attention(
            q_nz, kv_nz, kv_nz, scale=scale)
        cos_nope = F.cosine_similarity(
            o_nope[0, 0, 0, :].float().unsqueeze(0),
            o_ref_nope[0, 0, 0, :].float().unsqueeze(0)).item()
        print(f" noPE-only cos = {cos_nope:.6f}")
    except Exception as e:
        # Deliberately best-effort: this extra check must not hide the
        # diagnosis already printed above, so the error is only reported.
        print(f" noPE-only test failed: {e}")
    sys.exit(0)
if __name__ == "__main__":
    # Script entry point: run the isolation test only when executed directly.
    main()