diff --git a/tests/unit/test_b1_isolate_qk_pv.py b/tests/unit/test_b1_isolate_qk_pv.py new file mode 100644 index 00000000..866f6c07 --- /dev/null +++ b/tests/unit/test_b1_isolate_qk_pv.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +"""B1 FMHA isolate: test QK scores separately from PV. + +Strategy: +1. Compute QK scores with the kernel and with reference +2. If QK is wrong, the bug is in QK. If QK is right, bug is in PV. +3. Also test: single-head, single N=128 to minimize moving parts. +""" +import sys +import math +import torch +import torch.nn.functional as F + + +def quantize_fp8_e4m3(x_fp32): + amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) + scale = amax / 448.0 + fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn) + return fp8.view(torch.uint8), scale.squeeze(-1) + + +def dequantize_fp8_e4m3(fp8_uint8, scale): + fp8 = fp8_uint8.view(torch.float8_e4m3fn) + return fp8.float() * scale.unsqueeze(-1).float() + + +def main(): + torch.manual_seed(42) + HD = 512; NOPE = 448; ROPE = 64 + H = 1; B = 1; T = 1; N = 128 + scale = 1.0 / math.sqrt(HD) + + print(f"=== B1 FMHA Isolate: QK vs PV (N={N} H={H}) ===\n") + + # Generate Q and KV + q_fp32 = torch.randn(B, H, T, HD, dtype=torch.float32) * 0.5 + k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5 + q_bf16 = q_fp32.bfloat16().cuda() + + k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE]) + k_rope_bf16 = k_fp32[:, NOPE:].bfloat16() + k_nope_fp8 = k_nope_fp8.cuda(); k_nope_scale = k_nope_scale.cuda() + k_rope_bf16 = k_rope_bf16.cuda() + + # --- Reference QK scores (head 0) --- + k_nope_dequant = dequantize_fp8_e4m3(k_nope_fp8.view(torch.uint8).cpu(), k_nope_scale.cpu()).cuda() + k_full = torch.cat([k_nope_dequant, k_fp32[:, NOPE:].cuda()], dim=-1) # (N, HD) + + q_h0 = q_fp32[0, 0, 0, :].cuda().float() + scores_ref = torch.matmul(q_h0, k_full.T) * scale # (N,) + + # Separate noPE and RoPE scores + scores_nope_ref = torch.matmul(q_h0[:NOPE], k_full[:, :NOPE].T) * scale + scores_rope_ref = torch.matmul(q_h0[NOPE:], k_full[:, NOPE:].T) * scale + + print(f"Reference scores (head 0):") + print(f" Total: [{scores_ref.min():.4f}, {scores_ref.max():.4f}]") + print(f" noPE: [{scores_nope_ref.min():.4f}, {scores_nope_ref.max():.4f}]") + print(f" RoPE: [{scores_rope_ref.min():.4f}, {scores_rope_ref.max():.4f}]") + print(f" First 8 total scores: {scores_ref[:8].tolist()}") + + # --- Run kernel and extract LSE --- + from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw + o_mixed, lse = fmha_mixed_fp8_decode_raw( + q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE) + + # LSE should equal logsumexp of scores + ref_lse = torch.logsumexp(scores_ref, dim=0) + print(f"\nLSE comparison (head 0):") + print(f" Kernel LSE: {lse[0,0,0].item():.4f}") + print(f" Reference LSE: {ref_lse.item():.4f}") + print(f" Diff: {abs(lse[0,0,0].item() - ref_lse.item()):.4f}") + + # If LSE is close but output is wrong, bug is in PV. + # If LSE is far off, bug is in QK. + lse_close = abs(lse[0,0,0].item() - ref_lse.item()) < 0.1 + + # --- Check softmax probabilities --- + # From reference scores + probs_ref = F.softmax(scores_ref, dim=0) + print(f"\nReference softmax probs: [{probs_ref.min():.6f}, {probs_ref.max():.6f}]") + print(f" First 8 probs: {probs_ref[:8].tolist()}") + + # --- Check output = P @ V --- + # Reference: o = probs @ K (since V = K) + o_ref = torch.matmul(probs_ref.unsqueeze(0), k_full).squeeze(0) # (HD,) + + o_mixed_h0 = o_mixed[0, 0, 0, :].float() + cos = F.cosine_similarity(o_mixed_h0.unsqueeze(0), o_ref.unsqueeze(0)).item() + + print(f"\nOutput comparison (head 0):") + print(f" cos(mixed, ref_P@V) = {cos:.6f}") + print(f" |mixed| = {o_mixed_h0.norm():.6f}") + print(f" |ref_P@V| = {o_ref.norm():.6f}") + + # --- Check if PV is computing P @ V correctly --- + # Compute P @ V step by step + # The kernel does PV by splitting V into (SK_TILE, 16) sub-tiles + # For N=128, HD=512: 32 sub-tiles of 16 dims each + # P is (1, 128), V is (128, 512) + # Expected: (1, 512) + + # Verify that ref P@V matches the simple attention output + o_ref_full = F.scaled_dot_product_attention( + q_fp32.cuda()[:, :1, :, :], # (1, 1, 1, HD) + k_full.unsqueeze(0).unsqueeze(0), # (1, 1, N, HD) + k_full.unsqueeze(0).unsqueeze(0), # V=K + scale=scale + ) + cos_ref = F.cosine_similarity(o_ref.unsqueeze(0), o_ref_full[0,0,0,:].unsqueeze(0)).item() + print(f" cos(ref_P@V, ref_SDPA) = {cos_ref:.6f}") + + # --- Analyze the PV sub-tile structure --- + # The kernel computes PV as: + # For n_sub = 0..31: + # MMA(P[128x128], V[128x16]) → TMEM[128x16] at offset n_sub*16 + # Then reads TMEM and accumulates + # + # The V matrix construction: V[row, d_base+dd] where d_base = n_sub*16 + # For noPE V: dequantized from FP8 + # For RoPE V: directly from k_rope_bf16 + + # Check that the V matrix indexing matches + # V should be K, so V[row, d] = K[row, d] + print(f"\nV matrix sanity (row 0):") + # noPE part + v_nope_ref = k_nope_dequant[0, :8] # first 8 noPE dims + print(f" K_nope[0,:8] (dequant): {v_nope_ref.tolist()}") + print(f" K_orig[0,:8]: {k_fp32[0, :8].tolist()}") + cos_v = F.cosine_similarity(v_nope_ref.unsqueeze(0), k_fp32[0, :8].unsqueeze(0)).item() + print(f" cos(dequant, orig) = {cos_v:.6f}") + + # Diagnose + if lse_close: + print(f"\n*** DIAGNOSIS: LSE is close ({abs(lse[0,0,0].item() - ref_lse.item()):.4f}) but output cos is {cos:.6f}") + print(f"*** BUG IS IN PV (probability-value multiply), NOT IN QK") + else: + print(f"\n*** DIAGNOSIS: LSE is far off ({abs(lse[0,0,0].item() - ref_lse.item()):.4f})") + print(f"*** BUG IS IN QK (query-key scoring)") + + # --- Extra: compare noPE-only output --- + # Zero out RoPE dims in Q and K, run kernel, compare + print(f"\n--- noPE-only test (RoPE zeroed) ---") + q_nope_only = q_bf16.clone() + q_nope_only[:, :, :, NOPE:] = 0 # zero RoPE in Q + k_rope_zero = torch.zeros(N, ROPE, dtype=torch.bfloat16, device='cuda') + + try: + o_nope, lse_nope = fmha_mixed_fp8_decode_raw( + q_nope_only, k_nope_fp8, k_nope_scale, k_rope_zero, scale, rope_dim=ROPE) + + # Reference with zeroed RoPE + q_nz = q_fp32.clone().cuda() + q_nz[:, :, :, NOPE:] = 0 + k_nz = k_full.clone() + k_nz[:, NOPE:] = 0 + o_ref_nope = F.scaled_dot_product_attention( + q_nz, k_nz.unsqueeze(0).unsqueeze(0), k_nz.unsqueeze(0).unsqueeze(0), scale=scale) + + cos_nope = F.cosine_similarity( + o_nope[0,0,0,:].float().unsqueeze(0), + o_ref_nope[0,0,0,:].float().unsqueeze(0)).item() + print(f" noPE-only cos = {cos_nope:.6f}") + except Exception as e: + print(f" noPE-only test failed: {e}") + + sys.exit(0) + + +if __name__ == "__main__": + main()