diff --git a/single_shot_inference.py b/single_shot_inference.py index 57a30d0b..fcc1be3b 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -395,6 +395,27 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, # Apply per-head correction attn_out = (attn_out.float() * correction.unsqueeze(-1)).bfloat16() # (T, n_h, hd) + # -- Debug: compare FMHA output with SDPA reference -- + if li == 0 and positions[0].item() < 2: + # SDPA reference + k_exp = k_full.expand(n_h, -1, -1).contiguous() + v_exp = v_full.expand(n_h, -1, -1).contiguous() + q_ref = q_input # (n_h, T, hd) + attn_ref = torch.nn.functional.scaled_dot_product_attention( + q_ref, k_exp, v_exp, scale=1.0/math.sqrt(hd), is_causal=False) + # Apply inverse RoPE to SDPA output too (since K=V with RoPE) + attn_ref = attn_ref.permute(1, 0, 2) # (T, n_h, hd) + attn_ref = apply_inverse_rope(attn_ref, positions_dev, rope_cos, rope_sin, hd, rd) + # Compare with FMHA output (before sink correction) + o_4d_nosink, _ = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale) + attn_fmha = o_4d_nosink.squeeze(0).permute(1, 0, 2) # (T, n_h, hd) + attn_fmha = apply_inverse_rope(attn_fmha, positions_dev, rope_cos, rope_sin, hd, rd) + # Cosine similarity + cos_sim = torch.nn.functional.cosine_similarity( + attn_ref.reshape(-1).float(), attn_fmha.reshape(-1).float(), dim=0) + max_diff = (attn_ref.float() - attn_fmha.float()).abs().max() + print(f" L{li} FMHA vs SDPA: cos={cos_sim:.6f} max_diff={max_diff:.6f}", flush=True) + attn_out = attn_out.bfloat16() # -- Inverse RoPE on attention output (paper §2.3.3) --