Add B1 QK vs PV isolation test
This commit is contained in:
171
tests/unit/test_b1_isolate_qk_pv.py
Normal file
171
tests/unit/test_b1_isolate_qk_pv.py
Normal file
@@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python3
|
||||
"""B1 FMHA isolate: test QK scores separately from PV.
|
||||
|
||||
Strategy:
|
||||
1. Compute QK scores with the kernel and with reference
|
||||
2. If QK is wrong, the bug is in QK. If QK is right, bug is in PV.
|
||||
3. Also test: single-head, single N=128 to minimize moving parts.
|
||||
"""
|
||||
import sys
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def quantize_fp8_e4m3(x_fp32):
|
||||
amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
||||
scale = amax / 448.0
|
||||
fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
|
||||
return fp8.view(torch.uint8), scale.squeeze(-1)
|
||||
|
||||
|
||||
def dequantize_fp8_e4m3(fp8_uint8, scale):
|
||||
fp8 = fp8_uint8.view(torch.float8_e4m3fn)
|
||||
return fp8.float() * scale.unsqueeze(-1).float()
|
||||
|
||||
|
||||
def main():
|
||||
torch.manual_seed(42)
|
||||
HD = 512; NOPE = 448; ROPE = 64
|
||||
H = 1; B = 1; T = 1; N = 128
|
||||
scale = 1.0 / math.sqrt(HD)
|
||||
|
||||
print(f"=== B1 FMHA Isolate: QK vs PV (N={N} H={H}) ===\n")
|
||||
|
||||
# Generate Q and KV
|
||||
q_fp32 = torch.randn(B, H, T, HD, dtype=torch.float32) * 0.5
|
||||
k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5
|
||||
q_bf16 = q_fp32.bfloat16().cuda()
|
||||
|
||||
k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE])
|
||||
k_rope_bf16 = k_fp32[:, NOPE:].bfloat16()
|
||||
k_nope_fp8 = k_nope_fp8.cuda(); k_nope_scale = k_nope_scale.cuda()
|
||||
k_rope_bf16 = k_rope_bf16.cuda()
|
||||
|
||||
# --- Reference QK scores (head 0) ---
|
||||
k_nope_dequant = dequantize_fp8_e4m3(k_nope_fp8.view(torch.uint8).cpu(), k_nope_scale.cpu()).cuda()
|
||||
k_full = torch.cat([k_nope_dequant, k_fp32[:, NOPE:].cuda()], dim=-1) # (N, HD)
|
||||
|
||||
q_h0 = q_fp32[0, 0, 0, :].cuda().float()
|
||||
scores_ref = torch.matmul(q_h0, k_full.T) * scale # (N,)
|
||||
|
||||
# Separate noPE and RoPE scores
|
||||
scores_nope_ref = torch.matmul(q_h0[:NOPE], k_full[:, :NOPE].T) * scale
|
||||
scores_rope_ref = torch.matmul(q_h0[NOPE:], k_full[:, NOPE:].T) * scale
|
||||
|
||||
print(f"Reference scores (head 0):")
|
||||
print(f" Total: [{scores_ref.min():.4f}, {scores_ref.max():.4f}]")
|
||||
print(f" noPE: [{scores_nope_ref.min():.4f}, {scores_nope_ref.max():.4f}]")
|
||||
print(f" RoPE: [{scores_rope_ref.min():.4f}, {scores_rope_ref.max():.4f}]")
|
||||
print(f" First 8 total scores: {scores_ref[:8].tolist()}")
|
||||
|
||||
# --- Run kernel and extract LSE ---
|
||||
from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw
|
||||
o_mixed, lse = fmha_mixed_fp8_decode_raw(
|
||||
q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE)
|
||||
|
||||
# LSE should equal logsumexp of scores
|
||||
ref_lse = torch.logsumexp(scores_ref, dim=0)
|
||||
print(f"\nLSE comparison (head 0):")
|
||||
print(f" Kernel LSE: {lse[0,0,0].item():.4f}")
|
||||
print(f" Reference LSE: {ref_lse.item():.4f}")
|
||||
print(f" Diff: {abs(lse[0,0,0].item() - ref_lse.item()):.4f}")
|
||||
|
||||
# If LSE is close but output is wrong, bug is in PV.
|
||||
# If LSE is far off, bug is in QK.
|
||||
lse_close = abs(lse[0,0,0].item() - ref_lse.item()) < 0.1
|
||||
|
||||
# --- Check softmax probabilities ---
|
||||
# From reference scores
|
||||
probs_ref = F.softmax(scores_ref, dim=0)
|
||||
print(f"\nReference softmax probs: [{probs_ref.min():.6f}, {probs_ref.max():.6f}]")
|
||||
print(f" First 8 probs: {probs_ref[:8].tolist()}")
|
||||
|
||||
# --- Check output = P @ V ---
|
||||
# Reference: o = probs @ K (since V = K)
|
||||
o_ref = torch.matmul(probs_ref.unsqueeze(0), k_full).squeeze(0) # (HD,)
|
||||
|
||||
o_mixed_h0 = o_mixed[0, 0, 0, :].float()
|
||||
cos = F.cosine_similarity(o_mixed_h0.unsqueeze(0), o_ref.unsqueeze(0)).item()
|
||||
|
||||
print(f"\nOutput comparison (head 0):")
|
||||
print(f" cos(mixed, ref_P@V) = {cos:.6f}")
|
||||
print(f" |mixed| = {o_mixed_h0.norm():.6f}")
|
||||
print(f" |ref_P@V| = {o_ref.norm():.6f}")
|
||||
|
||||
# --- Check if PV is computing P @ V correctly ---
|
||||
# Compute P @ V step by step
|
||||
# The kernel does PV by splitting V into (SK_TILE, 16) sub-tiles
|
||||
# For N=128, HD=512: 32 sub-tiles of 16 dims each
|
||||
# P is (1, 128), V is (128, 512)
|
||||
# Expected: (1, 512)
|
||||
|
||||
# Verify that ref P@V matches the simple attention output
|
||||
o_ref_full = F.scaled_dot_product_attention(
|
||||
q_fp32.cuda()[:, :1, :, :], # (1, 1, 1, HD)
|
||||
k_full.unsqueeze(0).unsqueeze(0), # (1, 1, N, HD)
|
||||
k_full.unsqueeze(0).unsqueeze(0), # V=K
|
||||
scale=scale
|
||||
)
|
||||
cos_ref = F.cosine_similarity(o_ref.unsqueeze(0), o_ref_full[0,0,0,:].unsqueeze(0)).item()
|
||||
print(f" cos(ref_P@V, ref_SDPA) = {cos_ref:.6f}")
|
||||
|
||||
# --- Analyze the PV sub-tile structure ---
|
||||
# The kernel computes PV as:
|
||||
# For n_sub = 0..31:
|
||||
# MMA(P[128x128], V[128x16]) → TMEM[128x16] at offset n_sub*16
|
||||
# Then reads TMEM and accumulates
|
||||
#
|
||||
# The V matrix construction: V[row, d_base+dd] where d_base = n_sub*16
|
||||
# For noPE V: dequantized from FP8
|
||||
# For RoPE V: directly from k_rope_bf16
|
||||
|
||||
# Check that the V matrix indexing matches
|
||||
# V should be K, so V[row, d] = K[row, d]
|
||||
print(f"\nV matrix sanity (row 0):")
|
||||
# noPE part
|
||||
v_nope_ref = k_nope_dequant[0, :8] # first 8 noPE dims
|
||||
print(f" K_nope[0,:8] (dequant): {v_nope_ref.tolist()}")
|
||||
print(f" K_orig[0,:8]: {k_fp32[0, :8].tolist()}")
|
||||
cos_v = F.cosine_similarity(v_nope_ref.unsqueeze(0), k_fp32[0, :8].unsqueeze(0)).item()
|
||||
print(f" cos(dequant, orig) = {cos_v:.6f}")
|
||||
|
||||
# Diagnose
|
||||
if lse_close:
|
||||
print(f"\n*** DIAGNOSIS: LSE is close ({abs(lse[0,0,0].item() - ref_lse.item()):.4f}) but output cos is {cos:.6f}")
|
||||
print(f"*** BUG IS IN PV (probability-value multiply), NOT IN QK")
|
||||
else:
|
||||
print(f"\n*** DIAGNOSIS: LSE is far off ({abs(lse[0,0,0].item() - ref_lse.item()):.4f})")
|
||||
print(f"*** BUG IS IN QK (query-key scoring)")
|
||||
|
||||
# --- Extra: compare noPE-only output ---
|
||||
# Zero out RoPE dims in Q and K, run kernel, compare
|
||||
print(f"\n--- noPE-only test (RoPE zeroed) ---")
|
||||
q_nope_only = q_bf16.clone()
|
||||
q_nope_only[:, :, :, NOPE:] = 0 # zero RoPE in Q
|
||||
k_rope_zero = torch.zeros(N, ROPE, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
try:
|
||||
o_nope, lse_nope = fmha_mixed_fp8_decode_raw(
|
||||
q_nope_only, k_nope_fp8, k_nope_scale, k_rope_zero, scale, rope_dim=ROPE)
|
||||
|
||||
# Reference with zeroed RoPE
|
||||
q_nz = q_fp32.clone().cuda()
|
||||
q_nz[:, :, :, NOPE:] = 0
|
||||
k_nz = k_full.clone()
|
||||
k_nz[:, NOPE:] = 0
|
||||
o_ref_nope = F.scaled_dot_product_attention(
|
||||
q_nz, k_nz.unsqueeze(0).unsqueeze(0), k_nz.unsqueeze(0).unsqueeze(0), scale=scale)
|
||||
|
||||
cos_nope = F.cosine_similarity(
|
||||
o_nope[0,0,0,:].float().unsqueeze(0),
|
||||
o_ref_nope[0,0,0,:].float().unsqueeze(0)).item()
|
||||
print(f" noPE-only cos = {cos_nope:.6f}")
|
||||
except Exception as e:
|
||||
print(f" noPE-only test failed: {e}")
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user