debug: compare FMHA vs SDPA output at layer 0
This commit is contained in:
@@ -395,6 +395,27 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
# Apply per-head correction
|
||||
attn_out = (attn_out.float() * correction.unsqueeze(-1)).bfloat16() # (T, n_h, hd)
|
||||
|
||||
# -- Debug: compare FMHA output with SDPA reference --
|
||||
if li == 0 and positions[0].item() < 2:
|
||||
# SDPA reference
|
||||
k_exp = k_full.expand(n_h, -1, -1).contiguous()
|
||||
v_exp = v_full.expand(n_h, -1, -1).contiguous()
|
||||
q_ref = q_input # (n_h, T, hd)
|
||||
attn_ref = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_ref, k_exp, v_exp, scale=1.0/math.sqrt(hd), is_causal=False)
|
||||
# Permute to (T, n_h, hd), then apply inverse RoPE to the SDPA output too (since K=V with RoPE)
|
||||
attn_ref = attn_ref.permute(1, 0, 2) # (T, n_h, hd)
|
||||
attn_ref = apply_inverse_rope(attn_ref, positions_dev, rope_cos, rope_sin, hd, rd)
|
||||
# Compare with FMHA output (before sink correction)
|
||||
o_4d_nosink, _ = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale)
|
||||
attn_fmha = o_4d_nosink.squeeze(0).permute(1, 0, 2) # (T, n_h, hd)
|
||||
attn_fmha = apply_inverse_rope(attn_fmha, positions_dev, rope_cos, rope_sin, hd, rd)
|
||||
# Cosine similarity and max absolute difference between the SDPA and FMHA outputs
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
attn_ref.reshape(-1).float(), attn_fmha.reshape(-1).float(), dim=0)
|
||||
max_diff = (attn_ref.float() - attn_fmha.float()).abs().max()
|
||||
print(f" L{li} FMHA vs SDPA: cos={cos_sim:.6f} max_diff={max_diff:.6f}", flush=True)
|
||||
|
||||
attn_out = attn_out.bfloat16()
|
||||
|
||||
# -- Inverse RoPE on attention output (paper §2.3.3) --
|
||||
|
||||
Reference in New Issue
Block a user