CRITICAL FIX: use SDPA for short sequences (FMHA padding bug)
FMHA pads N to next multiple of 128. For N<<128 (like 5 tokens), the 123 padded zero-K entries contribute exp(0)=1 to the softmax denominator, diluting real attention weights by ~128/5 = 25.6x. This caused the model to produce incoherent output for short prompts. Fix: use SDPA for seq_len < 120 (no padding), FMHA for longer sequences where the padding effect is negligible. Also: SDPA path includes attention sinks (paper D5c), FMHA path uses analytic sink correction via LSE.
This commit is contained in:
@@ -366,57 +366,53 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
|
||||
k_full, v_full = kv_cache.get() # (1, seq_len, hd) each — RoPE'd, K=V
|
||||
seq_len = k_full.shape[1]
|
||||
|
||||
# -- FMHA with sink bias correction (paper D5c) --
|
||||
# -- Attention: SDPA for short seqs (avoids FMHA padding bug), FMHA for long --
|
||||
q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd)
|
||||
from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw
|
||||
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
q_4d = q_input.unsqueeze(0).contiguous() # (1, n_h, T, hd)
|
||||
k_4d = k_full.unsqueeze(0).contiguous() # (1, 1, seq_len, hd)
|
||||
# CRITICAL: V must be (B, n_h, hd, N) — transposed!
|
||||
v_4d = v_full.unsqueeze(0).transpose(-1, -2).contiguous() # (1, 1, hd, seq_len)
|
||||
|
||||
o_4d, lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale)
|
||||
attn_out = o_4d.squeeze(0).permute(1, 0, 2) # (T, n_h, hd)
|
||||
# FMHA pads N to next multiple of 128. For N<<128, padded zero-K entries
|
||||
# contribute exp(0)=1 to softmax, diluting real attention weights by ~128/N.
|
||||
# Use SDPA for short sequences where padding dominates.
|
||||
if seq_len < 120:
|
||||
k_expanded = k_full.expand(n_h, -1, -1).contiguous()
|
||||
v_expanded = v_full.expand(n_h, -1, -1).contiguous()
|
||||
# Add attention sink (paper D5c)
|
||||
sink_key = f"{pre}.sinks"
|
||||
if sink_key in w and seq_len > 0:
|
||||
sinks = w[sink_key].to(device=device) # (n_h,) BF16
|
||||
sink_k = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device)
|
||||
sink_v = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device)
|
||||
k_with_sink = torch.cat([k_expanded, sink_k], dim=1)
|
||||
v_with_sink = torch.cat([v_expanded, sink_v], dim=1)
|
||||
sink_bias_mask = torch.zeros(n_h, T, seq_len + 1, dtype=torch.bfloat16, device=device)
|
||||
for h in range(n_h):
|
||||
sink_bias_mask[h, :, -1] = sinks[h]
|
||||
attn_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_input, k_with_sink, v_with_sink,
|
||||
attn_mask=sink_bias_mask, scale=scale)
|
||||
else:
|
||||
attn_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_input, k_expanded, v_expanded, scale=scale, is_causal=False)
|
||||
attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd)
|
||||
else:
|
||||
# Use FMHA kernel for longer sequences (padding effect is negligible)
|
||||
from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw
|
||||
q_4d = q_input.unsqueeze(0).contiguous()
|
||||
k_4d = k_full.unsqueeze(0).contiguous()
|
||||
v_4d = v_full.unsqueeze(0).transpose(-1, -2).contiguous()
|
||||
o_4d, lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale)
|
||||
attn_out = o_4d.squeeze(0).permute(1, 0, 2)
|
||||
# Sink correction
|
||||
sink_key = f"{pre}.sinks"
|
||||
if sink_key in w and seq_len > 0:
|
||||
sinks = w[sink_key].to(device=device)
|
||||
lse_2d = lse.squeeze(0).t()
|
||||
sink_exp = torch.exp(sinks.float())
|
||||
attn_exp = torch.exp(lse_2d.float())
|
||||
correction = attn_exp / (attn_exp + sink_exp.unsqueeze(0) + 1e-10)
|
||||
attn_out = (attn_out.float() * correction.unsqueeze(-1)).bfloat16()
|
||||
attn_out = attn_out.bfloat16()
|
||||
|
||||
# Apply sink bias correction: scale the output by softmax_normalizer / (normalizer + exp(sink))
|
||||
# This simulates adding a virtual sink position with V=0 and logit=sink
|
||||
# O_sink = O_raw * exp(lse) / (exp(lse) + exp(sink))
|
||||
sink_key = f"{pre}.sinks"
|
||||
if sink_key in w and seq_len > 0:
|
||||
sinks = w[sink_key].to(device=device) # (n_h,) BF16
|
||||
# lse: (1, n_h, T) — log-sum-exp of attention scores per head
|
||||
lse_2d = lse.squeeze(0).t() # (T, n_h)
|
||||
# For each head, compute the correction factor
|
||||
sink_exp = torch.exp(sinks.float()) # (n_h,)
|
||||
attn_exp = torch.exp(lse_2d.float()) # (T, n_h)
|
||||
# Correction: attn_exp / (attn_exp + sink_exp)
|
||||
correction = attn_exp / (attn_exp + sink_exp.unsqueeze(0) + 1e-10) # (T, n_h)
|
||||
# Apply per-head correction
|
||||
attn_out = (attn_out.float() * correction.unsqueeze(-1)).bfloat16() # (T, n_h, hd)
|
||||
|
||||
# -- Debug: compare FMHA output with SDPA reference --
|
||||
if li == 0 and positions[0].item() < 2:
|
||||
# SDPA reference
|
||||
k_exp = k_full.expand(n_h, -1, -1).contiguous()
|
||||
v_exp = v_full.expand(n_h, -1, -1).contiguous()
|
||||
q_ref = q_input # (n_h, T, hd)
|
||||
attn_ref = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_ref, k_exp, v_exp, scale=1.0/math.sqrt(hd), is_causal=False)
|
||||
# Apply inverse RoPE to SDPA output too (since K=V with RoPE)
|
||||
attn_ref = attn_ref.permute(1, 0, 2) # (T, n_h, hd)
|
||||
attn_ref = apply_inverse_rope(attn_ref, positions_dev, rope_cos, rope_sin, hd, rd)
|
||||
# Compare with FMHA output (before sink correction)
|
||||
o_4d_nosink, _ = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale)
|
||||
attn_fmha = o_4d_nosink.squeeze(0).permute(1, 0, 2) # (T, n_h, hd)
|
||||
attn_fmha = apply_inverse_rope(attn_fmha, positions_dev, rope_cos, rope_sin, hd, rd)
|
||||
# Cosine similarity
|
||||
cos_sim = torch.nn.functional.cosine_similarity(
|
||||
attn_ref.reshape(-1).float(), attn_fmha.reshape(-1).float(), dim=0)
|
||||
max_diff = (attn_ref.float() - attn_fmha.float()).abs().max()
|
||||
print(f" L{li} FMHA vs SDPA: cos={cos_sim:.6f} max_diff={max_diff:.6f}", flush=True)
|
||||
|
||||
attn_out = attn_out.bfloat16()
|
||||
|
||||
# -- Inverse RoPE on attention output (paper §2.3.3) --
|
||||
attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)
|
||||
|
||||
Reference in New Issue
Block a user