152 lines
6.4 KiB
Python
152 lines
6.4 KiB
Python
#!/usr/bin/env python3
|
|
"""B1 FMHA debug: isolate the cosine failure to noPE vs RoPE path.
|
|
|
|
Strategy:
|
|
1. Run mixed FP8 kernel with RoPE=0 (all noPE) → compare vs BF16
|
|
2. Run mixed FP8 kernel with noPE=0 (all RoPE) → compare vs BF16
|
|
3. Run with full split → see which part is broken
|
|
4. Print per-dimension residual to find where the error lives
|
|
"""
|
|
import sys
|
|
import math
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def quantize_fp8_e4m3(x_fp32):
|
|
amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
|
scale = amax / 448.0
|
|
fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
|
|
return fp8.view(torch.uint8), scale.squeeze(-1)
|
|
|
|
|
|
def dequantize_fp8_e4m3(fp8_uint8, scale):
|
|
fp8 = fp8_uint8.view(torch.float8_e4m3fn)
|
|
return fp8.float() * scale.unsqueeze(-1).float()
|
|
|
|
|
|
def cosine(a, b):
    """Return the cosine similarity of two tensors as a Python float.

    Both inputs are flattened to 1-D and converted to float32 before the
    comparison, so they only need to hold the same number of elements.
    """
    vec_a = torch.flatten(a).float()
    vec_b = torch.flatten(b).float()
    similarity = F.cosine_similarity(vec_a, vec_b, dim=0)
    return similarity.item()
|
|
|
|
|
|
def main():
    """Compare the mixed-FP8 FMHA decode kernel against an FP32 reference.

    Builds random Q and K (V is a copy of K), quantizes the first ``NOPE``
    columns of K to FP8 and keeps the remaining ``ROPE`` columns in bf16, runs
    ``fmha_mixed_fp8_decode_raw`` once, and prints cosine / residual
    diagnostics for the whole output and for its noPE and RoPE column ranges.

    Requires a CUDA device (tensors are moved with ``.cuda()``) and the
    ``dsv4`` package. The kernel output is indexed as (B, H, T, HD) —
    NOTE(review): confirm against the kernel's actual return layout.

    Exits the process with status 0 when the global cosine is >= 0.999,
    otherwise with status 1.
    """
    torch.manual_seed(42)
    HD = 512
    NOPE = 448
    ROPE = 64
    H = 4
    B = 1
    T = 1
    N = 128  # small for debugging
    PREVIEW = 8  # number of leading elements shown in each vector dump
    scale = 1.0 / math.sqrt(HD)

    print(f"=== B1 FMHA Debug: N={N} H={H} HD={HD} NOPE={NOPE} ROPE={ROPE} ===\n")

    # Generate Q and KV (order of the randn calls fixes the seeded values)
    q_fp32 = torch.randn(B, H, T, HD, dtype=torch.float32) * 0.5
    k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5
    q_bf16 = q_fp32.bfloat16().cuda()

    # Split KV: noPE columns go to FP8 + per-row scale, RoPE columns stay bf16
    k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE])
    k_rope_bf16 = k_fp32[:, NOPE:].bfloat16()
    k_nope_fp8 = k_nope_fp8.cuda()
    k_nope_scale = k_nope_scale.cuda()
    k_rope_bf16 = k_rope_bf16.cuda()

    # --- FP32 Reference ---
    # The reference sees the *dequantized* noPE columns, so FP8 rounding of K
    # is not counted as kernel error.
    k_nope_dequant = dequantize_fp8_e4m3(
        k_nope_fp8.view(torch.uint8).cpu(), k_nope_scale.cpu())
    k_full = torch.cat([k_nope_dequant, k_fp32[:, NOPE:]], dim=-1)  # (N, HD) FP32
    v_full = k_full.clone()

    q_f = q_fp32.cuda()  # (B, H, 1, HD)
    # Expand K/V over the head axis explicitly instead of relying on SDPA to
    # broadcast a size-1 head dimension against H query heads.
    k_f = k_full.cuda().unsqueeze(0).unsqueeze(0).expand(B, H, -1, -1)
    v_f = v_full.cuda().unsqueeze(0).unsqueeze(0).expand(B, H, -1, -1)
    o_ref = F.scaled_dot_product_attention(q_f, k_f, v_f, scale=scale)  # (B, H, 1, HD)

    print(f"Reference output: |o|={o_ref.abs().max().item():.6f}")
    print(f" head 0: {o_ref[0, 0, 0, :PREVIEW].tolist()}")
    print(f" head 0 noPE part: {o_ref[0, 0, 0, :PREVIEW].tolist()}")
    print(f" head 0 RoPE part: {o_ref[0, 0, 0, NOPE:NOPE + PREVIEW].tolist()}")

    # --- Mixed FP8 kernel ---
    # Imported here so the reference above still runs (and prints) even when
    # the kernel package fails to import.
    from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw
    o_mixed, _lse = fmha_mixed_fp8_decode_raw(
        q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE)

    print(f"\nMixed FP8 output: |o|={o_mixed.abs().max().item():.6f}")
    print(f" head 0: {o_mixed[0, 0, 0, :PREVIEW].tolist()}")
    print(f" head 0 RoPE part: {o_mixed[0, 0, 0, NOPE:NOPE + PREVIEW].tolist()}")

    # Global cosine (reference rounded to bf16 first)
    cos = cosine(o_mixed, o_ref.bfloat16())
    print(f"\nGlobal cosine: {cos:.6f}")

    # Per-head cosine
    o_mixed_h = o_mixed.float().squeeze(2)  # (B, H, HD)
    o_ref_h = o_ref.bfloat16().float().squeeze(2)
    per_head = F.cosine_similarity(o_mixed_h, o_ref_h, dim=-1)  # (B, H)
    print(f"Per-head cosine: {per_head[0].tolist()}")
    print(f" min={per_head.min().item():.6f} mean={per_head.mean().item():.6f}")

    # --- Per-dimension analysis ---
    # Compare noPE vs RoPE portions separately
    o_mixed_nope = o_mixed[0, 0, 0, :NOPE].float()
    o_ref_nope = o_ref[0, 0, 0, :NOPE].float()
    o_mixed_rope = o_mixed[0, 0, 0, NOPE:].float()
    o_ref_rope = o_ref[0, 0, 0, NOPE:].float()

    cos_nope = cosine(o_mixed_nope, o_ref_nope)
    cos_rope = cosine(o_mixed_rope, o_ref_rope)

    print(f"\nPer-dim cosine (head 0):")
    print(f" noPE (0..{NOPE - 1}): cos={cos_nope:.6f} "
          f"|mixed|={o_mixed_nope.abs().max().item():.6f} "
          f"|ref|={o_ref_nope.abs().max().item():.6f}")
    print(f" RoPE ({NOPE}..{HD - 1}): cos={cos_rope:.6f} "
          f"|mixed|={o_mixed_rope.abs().max().item():.6f} "
          f"|ref|={o_ref_rope.abs().max().item():.6f}")

    # Residual: upcast both operands BEFORE subtracting so the difference is
    # not itself rounded to bf16 when the kernel output is bf16.
    residual = o_mixed[0, 0, 0, :].float() - o_ref[0, 0, 0, :].bfloat16().float()
    print(f"\nResidual: |res|={residual.abs().max().item():.6f} mean={residual.mean().item():.6f}")
    print(f" noPE residual: |res|={residual[:NOPE].abs().max().item():.6f}")
    print(f" RoPE residual: |res|={residual[NOPE:].abs().max().item():.6f}")

    # --- Per-head breakdown ---
    print(f"\nPer-head noPE/RoPE cosines:")
    for h in range(H):
        cn = cosine(o_mixed[0, h, 0, :NOPE], o_ref[0, h, 0, :NOPE])
        cr = cosine(o_mixed[0, h, 0, NOPE:], o_ref[0, h, 0, NOPE:])
        print(f" H{h}: noPE_cos={cn:.4f} rope_cos={cr:.4f} total_cos={per_head[0, h].item():.4f}")

    # --- Score comparison ---
    # Compute reference scores manually
    q_h0 = q_fp32[0, 0, 0, :].cuda().float()  # (HD,)
    k_all = k_full.cuda().float()  # (N, HD)
    scores_ref = torch.matmul(q_h0, k_all.T) * scale  # (N,)

    # Compute noPE and RoPE scores separately
    scores_nope_ref = torch.matmul(q_h0[:NOPE], k_all[:, :NOPE].T) * scale
    scores_rope_ref = torch.matmul(q_h0[NOPE:], k_all[:, NOPE:].T) * scale

    print(f"\nScore analysis (head 0):")
    print(f" noPE: [{scores_nope_ref.min().item():.4f}, {scores_nope_ref.max().item():.4f}]")
    print(f" RoPE: [{scores_rope_ref.min().item():.4f}, {scores_rope_ref.max().item():.4f}]")
    print(f" Total: [{scores_ref.min().item():.4f}, {scores_ref.max().item():.4f}]")

    # Check: are the noPE and RoPE scores the right order of magnitude?
    nope_range = scores_nope_ref.max().item() - scores_nope_ref.min().item()
    rope_range = scores_rope_ref.max().item() - scores_rope_ref.min().item()
    print(f" noPE range: {nope_range:.4f} RoPE range: {rope_range:.4f}")
    if rope_range > 0:
        print(f" noPE/RoPE ratio: {nope_range / rope_range:.2f}")
    else:
        print(" RoPE range is zero!")

    # --- Check if the kernel is producing V=K correctly ---
    # In MQA self-attention, V=K. Check if the output magnitude matches
    # the expected: o = softmax(QK^T/sqrt(d)) @ K
    # With K=V and N=128, the output should be a weighted average of K rows
    print(f"\n K (row 0): {k_full[0, :PREVIEW].tolist()}")
    print(f" o (head 0): {o_mixed[0, 0, 0, :PREVIEW].float().tolist()}")
    print(f" o_ref (head 0): {o_ref[0, 0, 0, :PREVIEW].tolist()}")

    # A NaN cosine compares False and therefore also exits with failure.
    sys.exit(0 if cos >= 0.999 else 1)
|
|
|
|
|
|
# Run the comparison only when executed as a script; importing this module
# must not launch the kernel or call sys.exit().
if __name__ == "__main__":
    main()
|