Add B1 FMHA debug test for cosine failure investigation
This commit is contained in:
151
tests/unit/test_b1_debug_cosine.py
Normal file
151
tests/unit/test_b1_debug_cosine.py
Normal file
@@ -0,0 +1,151 @@
|
||||
#!/usr/bin/env python3
|
||||
"""B1 FMHA debug: isolate the cosine failure to noPE vs RoPE path.
|
||||
|
||||
Strategy:
|
||||
1. Run mixed FP8 kernel with RoPE=0 (all noPE) → compare vs BF16
|
||||
2. Run mixed FP8 kernel with noPE=0 (all RoPE) → compare vs BF16
|
||||
3. Run with full split → see which part is broken
|
||||
4. Print per-dimension residual to find where the error lives
|
||||
"""
|
||||
import sys
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def quantize_fp8_e4m3(x_fp32):
|
||||
amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
||||
scale = amax / 448.0
|
||||
fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
|
||||
return fp8.view(torch.uint8), scale.squeeze(-1)
|
||||
|
||||
|
||||
def dequantize_fp8_e4m3(fp8_uint8, scale):
|
||||
fp8 = fp8_uint8.view(torch.float8_e4m3fn)
|
||||
return fp8.float() * scale.unsqueeze(-1).float()
|
||||
|
||||
|
||||
def cosine(a, b):
|
||||
return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()
|
||||
|
||||
|
||||
def main():
|
||||
torch.manual_seed(42)
|
||||
HD = 512; NOPE = 448; ROPE = 64
|
||||
H = 4; B = 1; T = 1; N = 128 # small for debugging
|
||||
scale = 1.0 / math.sqrt(HD)
|
||||
|
||||
print(f"=== B1 FMHA Debug: N={N} H={H} HD={HD} NOPE={NOPE} ROPE={ROPE} ===\n")
|
||||
|
||||
# Generate Q and KV
|
||||
q_fp32 = torch.randn(B, H, T, HD, dtype=torch.float32) * 0.5
|
||||
k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5
|
||||
q_bf16 = q_fp32.bfloat16().cuda()
|
||||
|
||||
# Split KV
|
||||
k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE])
|
||||
k_rope_bf16 = k_fp32[:, NOPE:].bfloat16()
|
||||
k_nope_fp8 = k_nope_fp8.cuda()
|
||||
k_nope_scale = k_nope_scale.cuda()
|
||||
k_rope_bf16 = k_rope_bf16.cuda()
|
||||
|
||||
# --- FP32 Reference ---
|
||||
k_nope_dequant = dequantize_fp8_e4m3(k_nope_fp8.view(torch.uint8).cpu(), k_nope_scale.cpu())
|
||||
k_full = torch.cat([k_nope_dequant, k_fp32[:, NOPE:]], dim=-1) # (N, HD) FP32
|
||||
v_full = k_full.clone()
|
||||
|
||||
q_f = q_fp32.cuda() # (B, H, 1, HD)
|
||||
k_f = k_full.unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1).cuda()
|
||||
v_f = v_full.unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1).cuda()
|
||||
o_ref = F.scaled_dot_product_attention(q_f, k_f, v_f, scale=scale) # (B, H, 1, HD)
|
||||
|
||||
print(f"Reference output: |o|={o_ref.abs().max().item():.6f}")
|
||||
print(f" head 0: {o_ref[0,0,0,:8].tolist()}")
|
||||
print(f" head 0 noPE part: {o_ref[0,0,0,:8].tolist()}")
|
||||
print(f" head 0 RoPE part: {o_ref[0,0,0,448:456].tolist()}")
|
||||
|
||||
# --- Mixed FP8 kernel ---
|
||||
from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw
|
||||
o_mixed, lse = fmha_mixed_fp8_decode_raw(
|
||||
q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE)
|
||||
|
||||
print(f"\nMixed FP8 output: |o|={o_mixed.abs().max().item():.6f}")
|
||||
print(f" head 0: {o_mixed[0,0,0,:8].tolist()}")
|
||||
print(f" head 0 RoPE part: {o_mixed[0,0,0,448:456].tolist()}")
|
||||
|
||||
# Global cosine
|
||||
cos = cosine(o_mixed, o_ref.bfloat16())
|
||||
print(f"\nGlobal cosine: {cos:.6f}")
|
||||
|
||||
# Per-head cosine
|
||||
o_mixed_h = o_mixed.float().squeeze(2) # (B, H, HD)
|
||||
o_ref_h = o_ref.bfloat16().float().squeeze(2)
|
||||
per_head = F.cosine_similarity(o_mixed_h, o_ref_h, dim=-1) # (B, H)
|
||||
print(f"Per-head cosine: {per_head[0].tolist()}")
|
||||
print(f" min={per_head.min().item():.6f} mean={per_head.mean().item():.6f}")
|
||||
|
||||
# --- Per-dimension analysis ---
|
||||
# Compare noPE vs RoPE portions separately
|
||||
o_mixed_nope = o_mixed[0, 0, 0, :NOPE].float()
|
||||
o_ref_nope = o_ref[0, 0, 0, :NOPE].float()
|
||||
o_mixed_rope = o_mixed[0, 0, 0, NOPE:].float()
|
||||
o_ref_rope = o_ref[0, 0, 0, NOPE:].float()
|
||||
|
||||
cos_nope = F.cosine_similarity(o_mixed_nope.unsqueeze(0), o_ref_nope.unsqueeze(0), dim=1).item()
|
||||
cos_rope = F.cosine_similarity(o_mixed_rope.unsqueeze(0), o_ref_rope.unsqueeze(0), dim=1).item()
|
||||
|
||||
print(f"\nPer-dim cosine (head 0):")
|
||||
print(f" noPE (0..447): cos={cos_nope:.6f} |mixed|={o_mixed_nope.abs().max():.6f} |ref|={o_ref_nope.abs().max():.6f}")
|
||||
print(f" RoPE (448..511): cos={cos_rope:.6f} |mixed|={o_mixed_rope.abs().max():.6f} |ref|={o_ref_rope.abs().max():.6f}")
|
||||
|
||||
# Residual
|
||||
residual = (o_mixed[0,0,0,:] - o_ref[0,0,0,:].bfloat16()).float()
|
||||
print(f"\nResidual: |res|={residual.abs().max().item():.6f} mean={residual.mean().item():.6f}")
|
||||
print(f" noPE residual: |res|={residual[:NOPE].abs().max().item():.6f}")
|
||||
print(f" RoPE residual: |res|={residual[NOPE:].abs().max().item():.6f}")
|
||||
|
||||
# --- Per-head breakdown ---
|
||||
print(f"\nPer-head noPE/RoPE cosines:")
|
||||
for h in range(H):
|
||||
mn = o_mixed[0,h,0,:NOPE].float()
|
||||
rn = o_ref[0,h,0,:NOPE].float()
|
||||
mr = o_mixed[0,h,0,NOPE:].float()
|
||||
rr = o_ref[0,h,0,NOPE:].float()
|
||||
cn = F.cosine_similarity(mn.unsqueeze(0), rn.unsqueeze(0)).item()
|
||||
cr = F.cosine_similarity(mr.unsqueeze(0), rr.unsqueeze(0)).item()
|
||||
print(f" H{h}: noPE_cos={cn:.4f} rope_cos={cr:.4f} total_cos={per_head[0,h].item():.4f}")
|
||||
|
||||
# --- Score comparison ---
|
||||
# Compute reference scores manually
|
||||
q_h0 = q_fp32[0, 0, 0, :].cuda().float() # (HD,)
|
||||
k_all = k_full.cuda().float() # (N, HD)
|
||||
scores_ref = torch.matmul(q_h0, k_all.T) * scale # (N,)
|
||||
|
||||
# Compute noPE and RoPE scores separately
|
||||
scores_nope_ref = torch.matmul(q_h0[:NOPE], k_all[:, :NOPE].T) * scale
|
||||
scores_rope_ref = torch.matmul(q_h0[NOPE:], k_all[:, NOPE:].T) * scale
|
||||
|
||||
print(f"\nScore analysis (head 0):")
|
||||
print(f" noPE: [{scores_nope_ref.min().item():.4f}, {scores_nope_ref.max().item():.4f}]")
|
||||
print(f" RoPE: [{scores_rope_ref.min().item():.4f}, {scores_rope_ref.max().item():.4f}]")
|
||||
print(f" Total: [{scores_ref.min().item():.4f}, {scores_ref.max().item():.4f}]")
|
||||
|
||||
# Check: are the noPE and RoPE scores the right order of magnitude?
|
||||
nope_range = scores_nope_ref.max().item() - scores_nope_ref.min().item()
|
||||
rope_range = scores_rope_ref.max().item() - scores_rope_ref.min().item()
|
||||
print(f" noPE range: {nope_range:.4f} RoPE range: {rope_range:.4f}")
|
||||
print(f" noPE/RoPE ratio: {nope_range/rope_range:.2f}" if rope_range > 0 else " RoPE range is zero!")
|
||||
|
||||
# --- Check if the kernel is producing V=K correctly ---
|
||||
# In MQA self-attention, V=K. Check if the output magnitude matches
|
||||
# the expected: o = softmax(QK^T/sqrt(d)) @ K
|
||||
# With K=V and N=128, the output should be a weighted average of K rows
|
||||
print(f"\n K (row 0): {k_full[0,:8].tolist()}")
|
||||
print(f" o (head 0): {o_mixed[0,0,0,:8].float().tolist()}")
|
||||
print(f" o_ref (head 0): {o_ref[0,0,0,:8].tolist()}")
|
||||
|
||||
sys.exit(0 if cos >= 0.999 else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user