#!/usr/bin/env python3
|
|
"""Debug test: compare T=1 prefill vs T=1 decode, step by step.
|
|
|
|
Uses synthetic data. Prints per-step comparisons to identify
|
|
where the prefill kernel diverges from the decode kernel.
|
|
"""
|
|
import math
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
# Problem sizes for the synthetic test (all consumed by main()).
HD = 512    # width of the last dim of Q and K
NOPE = 448  # leading K columns that are FP8-quantized
ROPE = 64   # passed to the kernels as rope_dim; NOPE + ROPE == HD
H = 128     # size of the head dim of Q

B = 1    # size of the leading (batch) dim of Q
T = 1    # size of the third dim of Q
N = 256  # number of rows in the synthetic K

# Softmax scaling handed to both the SDPA reference and the kernels.
scale = 1.0 / math.sqrt(HD)
|
|
|
|
def quantize_fp8_e4m3(x_fp32):
|
|
amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
|
s = amax / 448.0
|
|
fp8 = (x_fp32 / s).clamp(-448, 448).to(torch.float8_e4m3fn)
|
|
return fp8.view(torch.uint8), s.squeeze(-1)
|
|
|
|
def cosine(a, b):
    """Return the cosine similarity of ``a`` and ``b`` as a Python float.

    Both tensors are flattened to 1-D and cast to float32 first, so the
    result is a single number for the whole tensor pair.
    """
    lhs = a.flatten().float()
    rhs = b.flatten().float()
    similarity = F.cosine_similarity(lhs, rhs, dim=0)
    return similarity.item()
|
|
|
|
def main():
    """Run the SDPA reference, the decode kernel and the prefill kernel, and print comparisons.

    Builds seeded synthetic Q (B, H, T, HD) and K (N, HD), quantizes the first
    NOPE columns of K to FP8 and keeps the rest in bf16, then feeds the same
    inputs to both kernels. The reference uses the dequantized K as both key
    and value. Requires a CUDA device and the ``dsv4`` kernel package.
    """
    torch.manual_seed(42)

    # Draw order (Q, then K) is part of the seeded setup; do not reorder.
    q_host = torch.randn(B, H, T, HD, dtype=torch.float32) * 0.5
    k_host = torch.randn(N, HD, dtype=torch.float32) * 0.5
    k_nope_host = k_host[:, :NOPE]
    k_rope_host = k_host[:, NOPE:]

    q_bf16 = q_host.bfloat16().cuda()
    nope_codes, nope_scales = quantize_fp8_e4m3(k_nope_host)
    k_nope_fp8 = nope_codes.cuda()
    k_nope_scale = nope_scales.cuda()
    k_rope_bf16 = k_rope_host.bfloat16().cuda()

    # Reference SDPA: rebuild K from the quantized NOPE part plus the raw
    # ROPE part so the reference sees the same quantization error.
    dequant_vals = k_nope_fp8.view(torch.float8_e4m3fn).cpu().float()
    dequant_scales = k_nope_scale.cpu().unsqueeze(-1).float()
    nope_dequant = dequant_vals * dequant_scales
    k_full = torch.cat([nope_dequant, k_rope_host], dim=-1).bfloat16().cuda()
    # (1, 1, N, HD): a single K/V head that is shared by every query head.
    k_4d = k_full[None, None].expand(1, 1, -1, -1)
    v_4d = k_4d.clone()
    o_ref = F.scaled_dot_product_attention(q_bf16, k_4d, v_4d, scale=scale)
    ref_f32 = o_ref.float()
    print(f"Reference: |o|={ref_f32.abs().max().item():.6f} mean={ref_f32.mean().item():.6f}")

    # Both kernels receive exactly the same positional inputs.
    kernel_args = (q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale)

    # Decode kernel (imported late so the reference prints even if this fails).
    from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw
    o_decode, _ = fmha_mixed_fp8_decode_raw(*kernel_args, rope_dim=ROPE)
    decode_f32 = o_decode.float()
    print(f"Decode: |o|={decode_f32.abs().max().item():.6f} mean={decode_f32.mean().item():.6f}")
    print(f"Decode vs Ref: cos={cosine(o_decode, o_ref):.6f}")

    # Prefill kernel
    from dsv4.kernels.attention.fmha_mixed_fp8_prefill_op import fmha_mixed_fp8_prefill_raw
    o_prefill, _ = fmha_mixed_fp8_prefill_raw(*kernel_args, rope_dim=ROPE)
    prefill_f32 = o_prefill.float()
    print(f"Prefill: |o|={prefill_f32.abs().max().item():.6f} mean={prefill_f32.mean().item():.6f}")
    print(f"Prefill vs Ref: cos={cosine(o_prefill, o_ref):.6f}")
    print(f"Prefill vs Decode: cos={cosine(o_prefill, o_decode):.6f}")

    # Check for NaN
    has_nan = o_prefill.isnan().any().item()
    print(f"Prefill NaN: {has_nan}")

    def _rows_per_head(out):
        # Drop singleton dims to get one row per head; expected (H, HD).
        rows = out.float().squeeze(0).squeeze(1)
        return rows.squeeze(0) if rows.dim() == 3 else rows

    # Per-head cosine
    per_head_cos = F.cosine_similarity(
        _rows_per_head(o_decode), _rows_per_head(o_prefill), dim=-1
    )
    print(f"Per-head cos: min={per_head_cos.min().item():.6f} mean={per_head_cos.mean().item():.6f} max={per_head_cos.max().item():.6f}")

    # Value comparison for head 0 is skipped when the prefill output has NaNs.
    if has_nan:
        return

    d0 = o_decode[0, 0, 0, :8].float()
    p0 = o_prefill[0, 0, 0, :8].float()
    r0 = o_ref[0, 0, 0, :8].float()
    print(f"Decode[0,0,0,:8]: {d0.tolist()}")
    print(f"Prefill[0,0,0,:8]: {p0.tolist()}")
    print(f"Ref[0,0,0,:8]: {r0.tolist()}")
    # The small epsilon keeps an exactly-zero reference from dividing by zero.
    ref_denom = r0 + 1e-10
    print(f"Ratio decode/ref: {(d0 / ref_denom).tolist()}")
    print(f"Ratio prefill/ref: {(p0 / ref_denom).tolist()}")
|
|
|
|
|
|
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|