CRITICAL FIX: use SDPA for short sequences (FMHA padding bug)

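FMHA pads the key length N up to the next multiple of 128. For N << 128
(e.g. 5 tokens), each of the 123 zero-padded K entries yields a logit of 0
and therefore contributes exp(0) = 1 to the softmax denominator. When the
real logits are themselves near 0, this dilutes the real attention weights
by roughly 128/5 = 25.6x.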

This caused the model to produce incoherent output for short prompts.

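Fix: use SDPA (which needs no padding) when seq_len < 120, and the FMHA
kernel for longer sequences, where the padding effect is assumed to be
negligible. Caveat: if the kernel pads as described above, that assumption
only holds when seq_len is close to a multiple of 128 -- e.g. seq_len = 129
would be padded to 256, adding 127 zero-K entries -- so sequences just past
a multiple of 128 should be re-checked.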

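Also: the SDPA path implements attention sinks (paper D5c) by appending one
extra position with zero K and V whose logit is the per-head sink bias
(passed as an additive attention mask). The FMHA path applies the equivalent
analytic correction using the kernel's LSE output, scaling the result by
exp(lse) / (exp(lse) + exp(sink)).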
This commit is contained in:
2026-05-31 06:39:23 +00:00
parent 5f98855141
commit d2cf5ccc32

View File

@@ -366,57 +366,53 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
k_full, v_full = kv_cache.get() # (1, seq_len, hd) each — RoPE'd, K=V
seq_len = k_full.shape[1]
# -- FMHA with sink bias correction (paper D5c) --
# -- Attention: SDPA for short seqs (avoids FMHA padding bug), FMHA for long --
q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd)
from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw
scale = 1.0 / math.sqrt(hd)
q_4d = q_input.unsqueeze(0).contiguous() # (1, n_h, T, hd)
k_4d = k_full.unsqueeze(0).contiguous() # (1, 1, seq_len, hd)
# CRITICAL: V must be (B, n_h, hd, N) — transposed!
v_4d = v_full.unsqueeze(0).transpose(-1, -2).contiguous() # (1, 1, hd, seq_len)
o_4d, lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale)
attn_out = o_4d.squeeze(0).permute(1, 0, 2) # (T, n_h, hd)
# FMHA pads N to next multiple of 128. For N<<128, padded zero-K entries
# contribute exp(0)=1 to softmax, diluting real attention weights by ~128/N.
# Use SDPA for short sequences where padding dominates.
if seq_len < 120:
k_expanded = k_full.expand(n_h, -1, -1).contiguous()
v_expanded = v_full.expand(n_h, -1, -1).contiguous()
# Add attention sink (paper D5c)
sink_key = f"{pre}.sinks"
if sink_key in w and seq_len > 0:
sinks = w[sink_key].to(device=device) # (n_h,) BF16
sink_k = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device)
sink_v = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device)
k_with_sink = torch.cat([k_expanded, sink_k], dim=1)
v_with_sink = torch.cat([v_expanded, sink_v], dim=1)
sink_bias_mask = torch.zeros(n_h, T, seq_len + 1, dtype=torch.bfloat16, device=device)
for h in range(n_h):
sink_bias_mask[h, :, -1] = sinks[h]
attn_out = torch.nn.functional.scaled_dot_product_attention(
q_input, k_with_sink, v_with_sink,
attn_mask=sink_bias_mask, scale=scale)
else:
attn_out = torch.nn.functional.scaled_dot_product_attention(
q_input, k_expanded, v_expanded, scale=scale, is_causal=False)
attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd)
else:
# Use FMHA kernel for longer sequences (padding effect is negligible)
from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw
q_4d = q_input.unsqueeze(0).contiguous()
k_4d = k_full.unsqueeze(0).contiguous()
v_4d = v_full.unsqueeze(0).transpose(-1, -2).contiguous()
o_4d, lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale)
attn_out = o_4d.squeeze(0).permute(1, 0, 2)
# Sink correction
sink_key = f"{pre}.sinks"
if sink_key in w and seq_len > 0:
sinks = w[sink_key].to(device=device)
lse_2d = lse.squeeze(0).t()
sink_exp = torch.exp(sinks.float())
attn_exp = torch.exp(lse_2d.float())
correction = attn_exp / (attn_exp + sink_exp.unsqueeze(0) + 1e-10)
attn_out = (attn_out.float() * correction.unsqueeze(-1)).bfloat16()
attn_out = attn_out.bfloat16()
# Apply sink bias correction: scale the output by softmax_normalizer / (normalizer + exp(sink))
# This simulates adding a virtual sink position with V=0 and logit=sink
# O_sink = O_raw * exp(lse) / (exp(lse) + exp(sink))
sink_key = f"{pre}.sinks"
if sink_key in w and seq_len > 0:
sinks = w[sink_key].to(device=device) # (n_h,) BF16
# lse: (1, n_h, T) — log-sum-exp of attention scores per head
lse_2d = lse.squeeze(0).t() # (T, n_h)
# For each head, compute the correction factor
sink_exp = torch.exp(sinks.float()) # (n_h,)
attn_exp = torch.exp(lse_2d.float()) # (T, n_h)
# Correction: attn_exp / (attn_exp + sink_exp)
correction = attn_exp / (attn_exp + sink_exp.unsqueeze(0) + 1e-10) # (T, n_h)
# Apply per-head correction
attn_out = (attn_out.float() * correction.unsqueeze(-1)).bfloat16() # (T, n_h, hd)
# -- Debug: compare FMHA output with SDPA reference --
if li == 0 and positions[0].item() < 2:
# SDPA reference
k_exp = k_full.expand(n_h, -1, -1).contiguous()
v_exp = v_full.expand(n_h, -1, -1).contiguous()
q_ref = q_input # (n_h, T, hd)
attn_ref = torch.nn.functional.scaled_dot_product_attention(
q_ref, k_exp, v_exp, scale=1.0/math.sqrt(hd), is_causal=False)
# Apply inverse RoPE to SDPA output too (since K=V with RoPE)
attn_ref = attn_ref.permute(1, 0, 2) # (T, n_h, hd)
attn_ref = apply_inverse_rope(attn_ref, positions_dev, rope_cos, rope_sin, hd, rd)
# Compare with FMHA output (before sink correction)
o_4d_nosink, _ = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale)
attn_fmha = o_4d_nosink.squeeze(0).permute(1, 0, 2) # (T, n_h, hd)
attn_fmha = apply_inverse_rope(attn_fmha, positions_dev, rope_cos, rope_sin, hd, rd)
# Cosine similarity
cos_sim = torch.nn.functional.cosine_similarity(
attn_ref.reshape(-1).float(), attn_fmha.reshape(-1).float(), dim=0)
max_diff = (attn_ref.float() - attn_fmha.float()).abs().max()
print(f" L{li} FMHA vs SDPA: cos={cos_sim:.6f} max_diff={max_diff:.6f}", flush=True)
attn_out = attn_out.bfloat16()
# -- Inverse RoPE on attention output (paper §2.3.3) --
attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)