From d2cf5ccc324714551c7bfd709758b75c5d69fd56 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 06:39:23 +0000 Subject: [PATCH] CRITICAL FIX: use SDPA for short sequences (FMHA padding bug) FMHA pads N to next multiple of 128. For N<<128 (like 5 tokens), the 123 padded zero-K entries contribute exp(0)=1 to the softmax denominator, diluting real attention weights by ~128/5 = 25.6x. This caused the model to produce incoherent output for short prompts. Fix: use SDPA for seq_len < 120 (no padding), FMHA for longer sequences where the padding effect is negligible. Also: SDPA path includes attention sinks (paper D5c), FMHA path uses analytic sink correction via LSE. --- single_shot_inference.py | 90 +++++++++++++++++++--------------------- 1 file changed, 43 insertions(+), 47 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 650e11be..f141b898 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -366,57 +366,53 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, k_full, v_full = kv_cache.get() # (1, seq_len, hd) each — RoPE'd, K=V seq_len = k_full.shape[1] - # -- FMHA with sink bias correction (paper D5c) -- + # -- Attention: SDPA for short seqs (avoids FMHA padding bug), FMHA for long -- q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd) - from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw - scale = 1.0 / math.sqrt(hd) - q_4d = q_input.unsqueeze(0).contiguous() # (1, n_h, T, hd) - k_4d = k_full.unsqueeze(0).contiguous() # (1, 1, seq_len, hd) - # CRITICAL: V must be (B, n_h, hd, N) — transposed! - v_4d = v_full.unsqueeze(0).transpose(-1, -2).contiguous() # (1, 1, hd, seq_len) - o_4d, lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale) - attn_out = o_4d.squeeze(0).permute(1, 0, 2) # (T, n_h, hd) + # FMHA pads N to next multiple of 128. For N<<128, padded zero-K entries + # contribute exp(0)=1 to softmax, diluting real attention weights by ~128/N.
+ # Use SDPA for short sequences where padding dominates. + if seq_len < 120: + k_expanded = k_full.expand(n_h, -1, -1).contiguous() + v_expanded = v_full.expand(n_h, -1, -1).contiguous() + # Add attention sink (paper D5c) + sink_key = f"{pre}.sinks" + if sink_key in w and seq_len > 0: + sinks = w[sink_key].to(device=device) # (n_h,) BF16 + sink_k = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device) + sink_v = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device) + k_with_sink = torch.cat([k_expanded, sink_k], dim=1) + v_with_sink = torch.cat([v_expanded, sink_v], dim=1) + sink_bias_mask = torch.zeros(n_h, T, seq_len + 1, dtype=torch.bfloat16, device=device) + for h in range(n_h): + sink_bias_mask[h, :, -1] = sinks[h] + attn_out = torch.nn.functional.scaled_dot_product_attention( + q_input, k_with_sink, v_with_sink, + attn_mask=sink_bias_mask, scale=scale) + else: + attn_out = torch.nn.functional.scaled_dot_product_attention( + q_input, k_expanded, v_expanded, scale=scale, is_causal=False) + attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd) + else: + # Use FMHA kernel for longer sequences (padding effect is negligible) + from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw + q_4d = q_input.unsqueeze(0).contiguous() + k_4d = k_full.unsqueeze(0).contiguous() + v_4d = v_full.unsqueeze(0).transpose(-1, -2).contiguous() + o_4d, lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale) + attn_out = o_4d.squeeze(0).permute(1, 0, 2) + # Sink correction + sink_key = f"{pre}.sinks" + if sink_key in w and seq_len > 0: + sinks = w[sink_key].to(device=device) + lse_2d = lse.squeeze(0).t() + sink_exp = torch.exp(sinks.float()) + attn_exp = torch.exp(lse_2d.float()) + correction = attn_exp / (attn_exp + sink_exp.unsqueeze(0) + 1e-10) + attn_out = (attn_out.float() * correction.unsqueeze(-1)).bfloat16() + attn_out = attn_out.bfloat16() - # Apply sink bias correction: scale the output by softmax_normalizer / (normalizer + exp(sink)) - # 
This simulates adding a virtual sink position with V=0 and logit=sink - # O_sink = O_raw * exp(lse) / (exp(lse) + exp(sink)) - sink_key = f"{pre}.sinks" - if sink_key in w and seq_len > 0: - sinks = w[sink_key].to(device=device) # (n_h,) BF16 - # lse: (1, n_h, T) — log-sum-exp of attention scores per head - lse_2d = lse.squeeze(0).t() # (T, n_h) - # For each head, compute the correction factor - sink_exp = torch.exp(sinks.float()) # (n_h,) - attn_exp = torch.exp(lse_2d.float()) # (T, n_h) - # Correction: attn_exp / (attn_exp + sink_exp) - correction = attn_exp / (attn_exp + sink_exp.unsqueeze(0) + 1e-10) # (T, n_h) - # Apply per-head correction - attn_out = (attn_out.float() * correction.unsqueeze(-1)).bfloat16() # (T, n_h, hd) - - # -- Debug: compare FMHA output with SDPA reference -- - if li == 0 and positions[0].item() < 2: - # SDPA reference - k_exp = k_full.expand(n_h, -1, -1).contiguous() - v_exp = v_full.expand(n_h, -1, -1).contiguous() - q_ref = q_input # (n_h, T, hd) - attn_ref = torch.nn.functional.scaled_dot_product_attention( - q_ref, k_exp, v_exp, scale=1.0/math.sqrt(hd), is_causal=False) - # Apply inverse RoPE to SDPA output too (since K=V with RoPE) - attn_ref = attn_ref.permute(1, 0, 2) # (T, n_h, hd) - attn_ref = apply_inverse_rope(attn_ref, positions_dev, rope_cos, rope_sin, hd, rd) - # Compare with FMHA output (before sink correction) - o_4d_nosink, _ = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale) - attn_fmha = o_4d_nosink.squeeze(0).permute(1, 0, 2) # (T, n_h, hd) - attn_fmha = apply_inverse_rope(attn_fmha, positions_dev, rope_cos, rope_sin, hd, rd) - # Cosine similarity - cos_sim = torch.nn.functional.cosine_similarity( - attn_ref.reshape(-1).float(), attn_fmha.reshape(-1).float(), dim=0) - max_diff = (attn_ref.float() - attn_fmha.float()).abs().max() - print(f" L{li} FMHA vs SDPA: cos={cos_sim:.6f} max_diff={max_diff:.6f}", flush=True) - - attn_out = attn_out.bfloat16() # -- Inverse RoPE on attention output (paper §2.3.3) -- attn_out
= apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)