debug: compare FMHA vs SDPA output at layer 0

This commit is contained in:
2026-05-31 06:16:58 +00:00
parent 59c75ca4e9
commit 152af7295a

View File

@@ -395,6 +395,27 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
# Apply per-head correction
attn_out = (attn_out.float() * correction.unsqueeze(-1)).bfloat16() # (T, n_h, hd)
# -- Debug: compare FMHA output with SDPA reference --
if li == 0 and positions[0].item() < 2:
# SDPA reference
k_exp = k_full.expand(n_h, -1, -1).contiguous()
v_exp = v_full.expand(n_h, -1, -1).contiguous()
q_ref = q_input # (n_h, T, hd)
attn_ref = torch.nn.functional.scaled_dot_product_attention(
q_ref, k_exp, v_exp, scale=1.0/math.sqrt(hd), is_causal=False)
# Apply inverse RoPE to SDPA output too (since K=V with RoPE)
attn_ref = attn_ref.permute(1, 0, 2) # (T, n_h, hd)
attn_ref = apply_inverse_rope(attn_ref, positions_dev, rope_cos, rope_sin, hd, rd)
# FMHA output without sink correction, to be compared against the SDPA reference
o_4d_nosink, _ = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale)
attn_fmha = o_4d_nosink.squeeze(0).permute(1, 0, 2) # (T, n_h, hd)
attn_fmha = apply_inverse_rope(attn_fmha, positions_dev, rope_cos, rope_sin, hd, rd)
# Cosine similarity
cos_sim = torch.nn.functional.cosine_similarity(
attn_ref.reshape(-1).float(), attn_fmha.reshape(-1).float(), dim=0)
max_diff = (attn_ref.float() - attn_fmha.float()).abs().max()
print(f" L{li} FMHA vs SDPA: cos={cos_sim:.6f} max_diff={max_diff:.6f}", flush=True)
attn_out = attn_out.bfloat16()
# -- Inverse RoPE on attention output (paper §2.3.3) --