From c322e3f3010749709ef35cd91be897258d148a75 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 00:22:00 +0000 Subject: [PATCH] Add B1 FMHA debug test for cosine failure investigation --- tests/unit/test_b1_debug_cosine.py | 151 +++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100644 tests/unit/test_b1_debug_cosine.py diff --git a/tests/unit/test_b1_debug_cosine.py b/tests/unit/test_b1_debug_cosine.py new file mode 100644 index 00000000..7821e74c --- /dev/null +++ b/tests/unit/test_b1_debug_cosine.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +"""B1 FMHA debug: isolate the cosine failure to noPE vs RoPE path. + +Strategy: +1. Run mixed FP8 kernel with RoPE=0 (all noPE) → compare vs BF16 +2. Run mixed FP8 kernel with noPE=0 (all RoPE) → compare vs BF16 +3. Run with full split → see which part is broken +4. Print per-dimension residual to find where the error lives +""" +import sys +import math +import torch +import torch.nn.functional as F + + +def quantize_fp8_e4m3(x_fp32): + amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) + scale = amax / 448.0 + fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn) + return fp8.view(torch.uint8), scale.squeeze(-1) + + +def dequantize_fp8_e4m3(fp8_uint8, scale): + fp8 = fp8_uint8.view(torch.float8_e4m3fn) + return fp8.float() * scale.unsqueeze(-1).float() + + +def cosine(a, b): + return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item() + + +def main(): + torch.manual_seed(42) + HD = 512; NOPE = 448; ROPE = 64 + H = 4; B = 1; T = 1; N = 128 # small for debugging + scale = 1.0 / math.sqrt(HD) + + print(f"=== B1 FMHA Debug: N={N} H={H} HD={HD} NOPE={NOPE} ROPE={ROPE} ===\n") + + # Generate Q and KV + q_fp32 = torch.randn(B, H, T, HD, dtype=torch.float32) * 0.5 + k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5 + q_bf16 = q_fp32.bfloat16().cuda() + + # Split KV + k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE]) + k_rope_bf16 = k_fp32[:, NOPE:].bfloat16() + k_nope_fp8 = k_nope_fp8.cuda() + k_nope_scale = k_nope_scale.cuda() + k_rope_bf16 = k_rope_bf16.cuda() + + # --- FP32 Reference --- + k_nope_dequant = dequantize_fp8_e4m3(k_nope_fp8.view(torch.uint8).cpu(), k_nope_scale.cpu()) + k_full = torch.cat([k_nope_dequant, k_fp32[:, NOPE:]], dim=-1) # (N, HD) FP32 + v_full = k_full.clone() + + q_f = q_fp32.cuda() # (B, H, 1, HD) + k_f = k_full.unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1).cuda() + v_f = v_full.unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1).cuda() + o_ref = F.scaled_dot_product_attention(q_f, k_f, v_f, scale=scale) # (B, H, 1, HD) + + print(f"Reference output: |o|={o_ref.abs().max().item():.6f}") + print(f" head 0: {o_ref[0,0,0,:8].tolist()}") + print(f" head 0 noPE part: {o_ref[0,0,0,:8].tolist()}") + print(f" head 0 RoPE part: {o_ref[0,0,0,448:456].tolist()}") + + # --- Mixed FP8 kernel --- + from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw + o_mixed, lse = fmha_mixed_fp8_decode_raw( + q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE) + + print(f"\nMixed FP8 output: |o|={o_mixed.abs().max().item():.6f}") + print(f" head 0: {o_mixed[0,0,0,:8].tolist()}") + print(f" head 0 RoPE part: {o_mixed[0,0,0,448:456].tolist()}") + + # Global cosine + cos = cosine(o_mixed, o_ref.bfloat16()) + print(f"\nGlobal cosine: {cos:.6f}") + + # Per-head cosine + o_mixed_h = o_mixed.float().squeeze(2) # (B, H, HD) + o_ref_h = o_ref.bfloat16().float().squeeze(2) + per_head = F.cosine_similarity(o_mixed_h, o_ref_h, dim=-1) # (B, H) + print(f"Per-head cosine: {per_head[0].tolist()}") + print(f" min={per_head.min().item():.6f} mean={per_head.mean().item():.6f}") + + # --- Per-dimension analysis --- + # Compare noPE vs RoPE portions separately + o_mixed_nope = o_mixed[0, 0, 0, :NOPE].float() + o_ref_nope = o_ref[0, 0, 0, :NOPE].float() + o_mixed_rope = o_mixed[0, 0, 0, NOPE:].float() + o_ref_rope = o_ref[0, 0, 0, NOPE:].float() + + cos_nope = F.cosine_similarity(o_mixed_nope.unsqueeze(0), o_ref_nope.unsqueeze(0), dim=1).item() + cos_rope = F.cosine_similarity(o_mixed_rope.unsqueeze(0), o_ref_rope.unsqueeze(0), dim=1).item() + + print(f"\nPer-dim cosine (head 0):") + print(f" noPE (0..447): cos={cos_nope:.6f} |mixed|={o_mixed_nope.abs().max():.6f} |ref|={o_ref_nope.abs().max():.6f}") + print(f" RoPE (448..511): cos={cos_rope:.6f} |mixed|={o_mixed_rope.abs().max():.6f} |ref|={o_ref_rope.abs().max():.6f}") + + # Residual + residual = (o_mixed[0,0,0,:] - o_ref[0,0,0,:].bfloat16()).float() + print(f"\nResidual: |res|={residual.abs().max().item():.6f} mean={residual.mean().item():.6f}") + print(f" noPE residual: |res|={residual[:NOPE].abs().max().item():.6f}") + print(f" RoPE residual: |res|={residual[NOPE:].abs().max().item():.6f}") + + # --- Per-head breakdown --- + print(f"\nPer-head noPE/RoPE cosines:") + for h in range(H): + mn = o_mixed[0,h,0,:NOPE].float() + rn = o_ref[0,h,0,:NOPE].float() + mr = o_mixed[0,h,0,NOPE:].float() + rr = o_ref[0,h,0,NOPE:].float() + cn = F.cosine_similarity(mn.unsqueeze(0), rn.unsqueeze(0)).item() + cr = F.cosine_similarity(mr.unsqueeze(0), rr.unsqueeze(0)).item() + print(f" H{h}: noPE_cos={cn:.4f} rope_cos={cr:.4f} total_cos={per_head[0,h].item():.4f}") + + # --- Score comparison --- + # Compute reference scores manually + q_h0 = q_fp32[0, 0, 0, :].cuda().float() # (HD,) + k_all = k_full.cuda().float() # (N, HD) + scores_ref = torch.matmul(q_h0, k_all.T) * scale # (N,) + + # Compute noPE and RoPE scores separately + scores_nope_ref = torch.matmul(q_h0[:NOPE], k_all[:, :NOPE].T) * scale + scores_rope_ref = torch.matmul(q_h0[NOPE:], k_all[:, NOPE:].T) * scale + + print(f"\nScore analysis (head 0):") + print(f" noPE: [{scores_nope_ref.min().item():.4f}, {scores_nope_ref.max().item():.4f}]") + print(f" RoPE: [{scores_rope_ref.min().item():.4f}, {scores_rope_ref.max().item():.4f}]") + print(f" Total: [{scores_ref.min().item():.4f}, {scores_ref.max().item():.4f}]") + + # Check: are the noPE and RoPE scores the right order of magnitude? + nope_range = scores_nope_ref.max().item() - scores_nope_ref.min().item() + rope_range = scores_rope_ref.max().item() - scores_rope_ref.min().item() + print(f" noPE range: {nope_range:.4f} RoPE range: {rope_range:.4f}") + print(f" noPE/RoPE ratio: {nope_range/rope_range:.2f}" if rope_range > 0 else " RoPE range is zero!") + + # --- Check if the kernel is producing V=K correctly --- + # In MQA self-attention, V=K. Check if the output magnitude matches + # the expected: o = softmax(QK^T/sqrt(d)) @ K + # With K=V and N=128, the output should be a weighted average of K rows + print(f"\n K (row 0): {k_full[0,:8].tolist()}") + print(f" o (head 0): {o_mixed[0,0,0,:8].float().tolist()}") + print(f" o_ref (head 0): {o_ref[0,0,0,:8].tolist()}") + + sys.exit(0 if cos >= 0.999 else 1) + + +if __name__ == "__main__": + main()