Files
nvfp4-megamoe-kernel/tests/unit/test_fmha_mixed_fp8_debug.py

106 lines
4.1 KiB
Python

#!/usr/bin/env python3
"""Minimal debug test for B1 mixed FP8 FMHA — compare per-step with BF16 reference.
Tests a single head with small N to isolate the precision issue.
"""
import sys
import math
import torch
import torch.nn.functional as F
def cosine_passes(cos, threshold=0.999):
    """Return True only when *cos* is finite and at least *threshold*.

    NaN compares False against every number, so a failure check written as
    ``cos < threshold`` lets a NaN similarity (e.g. from NaN/Inf kernel
    output) fall through to the success branch. Requiring a finite value
    that is ``>= threshold`` makes every non-numeric result a failure.
    """
    return math.isfinite(cos) and cos >= threshold


def main():
    """Compare the B1 mixed-FP8 FMHA decode kernel with an FP32 SDPA reference.

    Builds one synthetic query and N synthetic keys (V is a copy of K),
    quantizes the noPE slice of K to FP8 with one scale per key row, runs
    ``fmha_mixed_fp8_decode_raw`` and compares its output with
    ``F.scaled_dot_product_attention`` evaluated on the dequantized keys.
    Diagnostics (outputs, LSE, score ranges) are printed along the way.

    The function never returns: it ends the process through ``sys.exit``
    with status 0 when the cosine similarity is finite and >= 0.999, and
    status 1 otherwise. It calls ``.cuda()`` unconditionally and imports the
    kernel from the ``dsv4`` package, so both must be available.
    """
    torch.manual_seed(42)

    HD = 512
    NOPE = 448
    ROPE = 64
    H = 1
    B = 1
    T = 1
    N = 128  # small
    # Largest finite float8_e4m3fn magnitude. It only coincidentally equals
    # NOPE; the two constants are unrelated.
    FP8_MAX = 448.0
    scale = 1.0 / math.sqrt(HD)
    print(f"=== B1 Minimal Debug: N={N} H={H} HD={HD} ===")

    # Generate synthetic Q and KV (q is drawn before k; keep this order so the
    # seeded values stay the same).
    q_fp32 = torch.randn(B, H, T, HD, dtype=torch.float32) * 0.5
    k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5
    q_bf16 = q_fp32.bfloat16().cuda()

    # Split KV into the FP8-quantized noPE part and the BF16 RoPE part.
    k_nope_fp32 = k_fp32[:, :NOPE].contiguous()
    k_rope_fp32 = k_fp32[:, NOPE:].contiguous()

    # Quantize noPE to FP8 (same method as the production path, per the
    # original author): one scale per row, chosen so the row's amax maps to
    # FP8_MAX.
    amax = k_nope_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    k_nope_scale_cpu = (amax / FP8_MAX).squeeze(-1)  # (N,) FP32
    k_nope_fp8_cpu = (
        (k_nope_fp32 / k_nope_scale_cpu.unsqueeze(-1))
        .clamp(-FP8_MAX, FP8_MAX)
        .to(torch.float8_e4m3fn)
        .view(torch.uint8)  # raw FP8 bytes, as handed to the kernel
    )
    k_nope_fp8 = k_nope_fp8_cpu.cuda()
    k_nope_scale = k_nope_scale_cpu.cuda()
    k_rope_bf16 = k_rope_fp32.bfloat16().cuda()

    # Reference: SDPA on the dequantized keys. Dequantization is done in BF16
    # (values and scales), straight from the CPU copies.
    k_nope_dequant = (
        k_nope_fp8_cpu.view(torch.float8_e4m3fn).bfloat16()
        * k_nope_scale_cpu.unsqueeze(-1).bfloat16()
    )
    k_full_bf16 = torch.cat([k_nope_dequant, k_rope_fp32.bfloat16()], dim=-1).cuda()
    v_full_bf16 = k_full_bf16.clone()  # V shares K's values in this test

    q_f = q_bf16.squeeze(0).float()  # (H, 1, HD)
    k_f = k_full_bf16.unsqueeze(0).unsqueeze(0).float()  # (1, 1, N, HD)
    v_f = v_full_bf16.unsqueeze(0).unsqueeze(0).float()  # (1, 1, N, HD)
    # q is 3-D while k/v are 4-D; the leading dims broadcast, giving a 4-D
    # result whose first axis is dropped below.
    o_ref = F.scaled_dot_product_attention(q_f, k_f, v_f, scale=scale).bfloat16()
    o_ref = o_ref.squeeze(0)  # (H, 1, HD)
    print(f"Reference: |o|={o_ref.abs().max().item():.6f} mean={o_ref.float().mean().item():.6f}")
    print(f" o[0,0,:8]={o_ref[0,0,:8].float().tolist()}")
    print(f" o[0,0,440:448]={o_ref[0,0,440:448].float().tolist()}")

    # Mixed FP8 kernel. Imported here so the reference numbers above are
    # printed even if the kernel module fails to import.
    from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw

    q_4d = q_bf16  # (B, H, T, HD)
    o_mixed, lse = fmha_mixed_fp8_decode_raw(
        q_4d, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE
    )
    # o_mixed is expected to be (B, H, T, HD) per the original author's note;
    # the kernel itself is not visible from this file.
    o_mixed_3d = o_mixed.squeeze(0)  # (H, 1, HD)
    print(f"Mixed FP8: |o|={o_mixed.abs().max().item():.6f} mean={o_mixed.float().mean().item():.6f}")
    print(f" o[0,0,:8]={o_mixed_3d[0,0,:8].float().tolist()}")
    print(f" o[0,0,440:448]={o_mixed_3d[0,0,440:448].float().tolist()}")

    # Cosine similarity over the flattened outputs; NaN if either side
    # contains NaN/Inf.
    cos = F.cosine_similarity(o_ref.flatten().float(), o_mixed.flatten().float(), dim=0).item()
    print(f"\nCosine: {cos:.6f}")

    # LSE comparison. Reference LSE: log(sum(exp(scores))).
    scores = torch.matmul(q_f, k_f.transpose(-2, -1)) * scale  # (1, H, 1, N)
    ref_lse = torch.logsumexp(scores, dim=-1)  # (1, H, 1)
    print(f"Ref LSE: {ref_lse[0,0,0].item():.6f}")
    print(f"Mixed LSE: {lse[0,0,0].item():.6f}")

    # Score distribution
    print(f"\nScores: min={scores.min().item():.4f} max={scores.max().item():.4f} mean={scores.mean().item():.4f}")

    # Split the scores into their noPE and RoPE contributions to see which
    # part dominates.
    q_nope_f = q_f[:, :, :NOPE]  # (H, 1, NOPE)
    q_rope_f = q_f[:, :, NOPE:]  # (H, 1, ROPE)
    k_nope_f = k_f[:, :, :, :NOPE]  # (1, 1, N, NOPE)
    k_rope_f = k_f[:, :, :, NOPE:]  # (1, 1, N, ROPE)
    scores_nope = torch.matmul(q_nope_f, k_nope_f.transpose(-2, -1)) * scale
    scores_rope = torch.matmul(q_rope_f, k_rope_f.transpose(-2, -1)) * scale
    print(f"noPE scores: [{scores_nope.min().item():.4f}, {scores_nope.max().item():.4f}]")
    print(f"RoPE scores: [{scores_rope.min().item():.4f}, {scores_rope.max().item():.4f}]")

    # NaN-safe gate: a NaN cosine must fail, not slip through as a pass.
    if not cosine_passes(cos):
        print(f"\n!!! COSINE TOO LOW ({cos:.6f}) — B1 KERNEL IS BROKEN !!!")
        sys.exit(1)
    print(f"\nPASS: cosine {cos:.6f}")
    sys.exit(0)
# Run the check only when executed as a script; main() ends the process
# itself via sys.exit, so nothing follows the call.
if __name__ == "__main__":
    main()