106 lines
4.1 KiB
Python
106 lines
4.1 KiB
Python
#!/usr/bin/env python3
|
|
"""Minimal debug test for B1 mixed FP8 FMHA — compare per-step with BF16 reference.
|
|
|
|
Tests a single head with small N to isolate the precision issue.
|
|
"""
|
|
import sys
|
|
import math
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
def cosine_passes(cos, threshold=0.999):
    """Return True only when *cos* is a real number at or above *threshold*.

    A NaN similarity (which is what cosine similarity yields when either
    input contains NaN) counts as a failure.  A plain ``cos < threshold``
    test would let NaN through, because every ordered comparison with NaN
    evaluates to False.
    """
    return not math.isnan(cos) and cos >= threshold


def main():
    """Compare the mixed-FP8 decode kernel against a PyTorch SDPA reference.

    Builds one synthetic query and N keys, quantizes the first NOPE key
    channels to FP8 E4M3 with one scale per key row, keeps the remaining
    ROPE channels in BF16, and evaluates both a PyTorch reference (computed
    in FP32 on the BF16-rounded inputs) and ``fmha_mixed_fp8_decode_raw`` on
    the same data.  Prints the outputs, the log-sum-exp values and the score
    ranges, then terminates the process via ``sys.exit``: status 0 when the
    cosine similarity of the two outputs is at least 0.999, status 1
    otherwise (a NaN similarity is a failure).
    """
    # Every tensor handed to the kernel is moved with .cuda(); fail early with
    # a readable message instead of a traceback from the first transfer.
    if not torch.cuda.is_available():
        sys.exit("CUDA device required: the mixed FP8 FMHA debug test runs on GPU")

    torch.manual_seed(42)

    HD = 512
    NOPE = 448
    ROPE = 64
    H = 1
    B = 1
    T = 1
    N = 128  # small
    FP8_E4M3_MAX = 448.0  # largest finite magnitude of float8_e4m3fn
    scale = 1.0 / math.sqrt(HD)

    print(f"=== B1 Minimal Debug: N={N} H={H} HD={HD} ===")

    # Generate synthetic Q and KV
    q_fp32 = torch.randn(B, H, T, HD, dtype=torch.float32) * 0.5
    k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5

    q_bf16 = q_fp32.bfloat16().cuda()

    # Split KV into the FP8-quantized (noPE) and BF16 (RoPE) channel ranges
    k_nope_fp32 = k_fp32[:, :NOPE].contiguous()
    k_rope_fp32 = k_fp32[:, NOPE:].contiguous()

    # Quantize noPE to FP8 with a per-row absmax scale (intended to mirror the
    # production path).
    amax = k_nope_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    k_nope_scale_cpu = (amax / FP8_E4M3_MAX).squeeze(-1)  # (N,) FP32
    k_nope_q = (
        (k_nope_fp32 / k_nope_scale_cpu.unsqueeze(-1))
        .clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
        .to(torch.float8_e4m3fn)
    )
    k_rope_bf16_cpu = k_rope_fp32.bfloat16()

    # Reference K: dequantize the FP8 values on the CPU copies (no GPU round
    # trip needed) and append the BF16 RoPE channels.
    # NOTE(review): the scale is rounded to BF16 here while the kernel is
    # given the FP32 scale -- confirm this matches the intended reference.
    k_nope_dequant = k_nope_q.bfloat16() * k_nope_scale_cpu.unsqueeze(-1).bfloat16()
    k_full_bf16 = torch.cat([k_nope_dequant, k_rope_bf16_cpu], dim=-1).cuda()
    # Values are a copy of the keys.
    v_full_bf16 = k_full_bf16.clone()

    # Kernel inputs: raw FP8 bytes as uint8, FP32 scales, BF16 RoPE channels
    k_nope_fp8 = k_nope_q.view(torch.uint8).cuda()
    k_nope_scale = k_nope_scale_cpu.cuda()
    k_rope_bf16 = k_rope_bf16_cpu.cuda()

    q_3d = q_bf16.squeeze(0)  # (H, 1, HD)
    k_3d = k_full_bf16.unsqueeze(0)  # (1, N, HD)
    v_3d = v_full_bf16.unsqueeze(0)  # (1, N, HD)
    k_4d_f = k_3d.unsqueeze(0).float()  # (1, 1, N, HD)
    v_4d_f = v_3d.unsqueeze(0).float()  # (1, 1, N, HD)

    # Reference: SDPA evaluated in FP32, result rounded to BF16.  The 3-D
    # query is broadcast against the 4-D key/value tensors.
    o_ref = F.scaled_dot_product_attention(
        q_3d.float(), k_4d_f, v_4d_f, scale=scale
    ).bfloat16()  # (1, H, 1, HD)
    o_ref = o_ref.squeeze(0)  # (H, 1, HD)

    print(f"Reference: |o|={o_ref.abs().max().item():.6f} mean={o_ref.float().mean().item():.6f}")
    print(f" o[0,0,:8]={o_ref[0,0,:8].float().tolist()}")
    print(f" o[0,0,440:448]={o_ref[0,0,440:448].float().tolist()}")

    # Mixed FP8 kernel (project-local import kept inside main, as in the
    # original script)
    from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw

    q_4d = q_bf16  # (B, H, T, HD)
    o_mixed, lse = fmha_mixed_fp8_decode_raw(
        q_4d, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE)
    # o_mixed: (B, H, T, HD)
    o_mixed_3d = o_mixed.squeeze(0)  # (H, 1, HD)

    print(f"Mixed FP8: |o|={o_mixed.abs().max().item():.6f} mean={o_mixed.float().mean().item():.6f}")
    print(f" o[0,0,:8]={o_mixed_3d[0,0,:8].float().tolist()}")
    print(f" o[0,0,440:448]={o_mixed_3d[0,0,440:448].float().tolist()}")

    # Cosine similarity over the flattened outputs
    cos = F.cosine_similarity(o_ref.flatten().float(), o_mixed.flatten().float(), dim=0).item()
    print(f"\nCosine: {cos:.6f}")

    # LSE comparison
    # Reference LSE: log(sum(exp(scores)))
    q_f = q_3d.float()  # (H, 1, HD)
    # matmul broadcasts the 3-D query against the 4-D keys
    scores = torch.matmul(q_f, k_4d_f.transpose(-2, -1)) * scale  # (1, H, 1, N)
    ref_lse = torch.logsumexp(scores, dim=-1)  # (1, H, 1)
    print(f"Ref LSE: {ref_lse[0,0,0].item():.6f}")
    print(f"Mixed LSE: {lse[0,0,0].item():.6f}")

    # Score distribution
    print(f"\nScores: min={scores.min().item():.4f} max={scores.max().item():.4f} mean={scores.mean().item():.4f}")

    # Check if the noPE vs RoPE contributions are correct
    q_nope_f = q_f[:, :, :NOPE]  # (H, 1, NOPE)
    q_rope_f = q_f[:, :, NOPE:]  # (H, 1, ROPE)
    k_nope_f = k_4d_f[:, :, :, :NOPE]  # (1, 1, N, NOPE)
    k_rope_f = k_4d_f[:, :, :, NOPE:]  # (1, 1, N, ROPE)

    scores_nope = torch.matmul(q_nope_f, k_nope_f.transpose(-2, -1)) * scale
    scores_rope = torch.matmul(q_rope_f, k_rope_f.transpose(-2, -1)) * scale
    print(f"noPE scores: [{scores_nope.min().item():.4f}, {scores_nope.max().item():.4f}]")
    print(f"RoPE scores: [{scores_rope.min().item():.4f}, {scores_rope.max().item():.4f}]")

    # NaN-safe verdict: a NaN cosine must fail rather than fall into PASS.
    if not cosine_passes(cos):
        print(f"\n!!! COSINE TOO LOW ({cos:.6f}) — B1 KERNEL IS BROKEN !!!")
        sys.exit(1)
    else:
        print(f"\nPASS: cosine {cos:.6f}")
        sys.exit(0)
|
|
|
|
|
|
# Script entry point: run the debug comparison only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|