#!/usr/bin/env python3 """Debug test: compare T=1 prefill vs T=1 decode, step by step. Uses synthetic data. Prints per-step comparisons to identify where the prefill kernel diverges from the decode kernel. """ import math import torch import torch.nn.functional as F HD = 512; NOPE = 448; ROPE = 64; H = 128 B = 1; T = 1; N = 256 scale = 1.0 / math.sqrt(HD) def quantize_fp8_e4m3(x_fp32): amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) s = amax / 448.0 fp8 = (x_fp32 / s).clamp(-448, 448).to(torch.float8_e4m3fn) return fp8.view(torch.uint8), s.squeeze(-1) def cosine(a, b): return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item() def main(): torch.manual_seed(42) q_fp32 = torch.randn(B, H, T, HD, dtype=torch.float32) * 0.5 k_fp32 = torch.randn(N, HD, dtype=torch.float32) * 0.5 q_bf16 = q_fp32.bfloat16().cuda() k_nope_fp8, k_nope_scale = quantize_fp8_e4m3(k_fp32[:, :NOPE]) k_rope_bf16 = k_fp32[:, NOPE:].bfloat16() k_nope_fp8 = k_nope_fp8.cuda() k_nope_scale = k_nope_scale.cuda() k_rope_bf16 = k_rope_bf16.cuda() # Reference SDPA nope_dequant = k_nope_fp8.view(torch.float8_e4m3fn).cpu().float() * k_nope_scale.cpu().unsqueeze(-1).float() k_full = torch.cat([nope_dequant, k_fp32[:, NOPE:]], dim=-1).bfloat16().cuda() k_4d = k_full.unsqueeze(0).unsqueeze(0).expand(1, 1, -1, -1) v_4d = k_4d.clone() o_ref = F.scaled_dot_product_attention(q_bf16, k_4d, v_4d, scale=scale) print(f"Reference: |o|={o_ref.float().abs().max().item():.6f} mean={o_ref.float().mean().item():.6f}") # Decode kernel from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw o_decode, lse_decode = fmha_mixed_fp8_decode_raw( q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE) print(f"Decode: |o|={o_decode.float().abs().max().item():.6f} mean={o_decode.float().mean().item():.6f}") print(f"Decode vs Ref: cos={cosine(o_decode, o_ref):.6f}") # Prefill kernel from dsv4.kernels.attention.fmha_mixed_fp8_prefill_op import fmha_mixed_fp8_prefill_raw o_prefill, lse_prefill = fmha_mixed_fp8_prefill_raw( q_bf16, k_nope_fp8, k_nope_scale, k_rope_bf16, scale, rope_dim=ROPE) print(f"Prefill: |o|={o_prefill.float().abs().max().item():.6f} mean={o_prefill.float().mean().item():.6f}") print(f"Prefill vs Ref: cos={cosine(o_prefill, o_ref):.6f}") print(f"Prefill vs Decode: cos={cosine(o_prefill, o_decode):.6f}") # Check for NaN has_nan = torch.isnan(o_prefill).any().item() print(f"Prefill NaN: {has_nan}") # Per-head cosine o_d_h = o_decode.float().squeeze(0).squeeze(1) # (H, HD) o_p_h = o_prefill.float().squeeze(0).squeeze(1) if o_d_h.dim() == 3: o_d_h = o_d_h.squeeze(0) if o_p_h.dim() == 3: o_p_h = o_p_h.squeeze(0) per_head_cos = F.cosine_similarity(o_d_h, o_p_h, dim=-1) print(f"Per-head cos: min={per_head_cos.min().item():.6f} mean={per_head_cos.mean().item():.6f} max={per_head_cos.max().item():.6f}") # Value comparison for head 0 if not has_nan: d0 = o_decode[0, 0, 0, :8].float() p0 = o_prefill[0, 0, 0, :8].float() r0 = o_ref[0, 0, 0, :8].float() print(f"Decode[0,0,0,:8]: {d0.tolist()}") print(f"Prefill[0,0,0,:8]: {p0.tolist()}") print(f"Ref[0,0,0,:8]: {r0.tolist()}") print(f"Ratio decode/ref: {(d0 / (r0 + 1e-10)).tolist()}") print(f"Ratio prefill/ref: {(p0 / (r0 + 1e-10)).tolist()}") if __name__ == "__main__": main()