Fix attention for decode (1 query vs N cached KVs)

This commit is contained in:
2026-05-19 15:28:52 +00:00
parent 85099c7e75
commit 85c74e5932

View File

@@ -155,20 +155,30 @@ def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim):
# ── Attention ────────────────────────────────────────────────────────
def full_causal_attention(q, kv, scale):
"""Full causal self-attention. q: (T, NH, HD), kv: (T, HD)."""
T, NH, HD = q.shape
q_2d = q.reshape(T * NH, HD)
kv_exp = kv.unsqueeze(1).expand(-1, NH, -1).contiguous()
k_2d = kv_exp.permute(1, 0, 2).unsqueeze(1).expand(NH, T, T, -1).contiguous().reshape(T * NH, T, HD)
v_2d = k_2d.clone()
scores = torch.matmul(q_2d.unsqueeze(1), k_2d.transpose(-1, -2)) * scale
query_pos = torch.arange(T, device=q.device).unsqueeze(1).repeat(1, NH).reshape(T * NH)
kv_pos = torch.arange(T, device=q.device).unsqueeze(0)
causal = kv_pos <= query_pos.unsqueeze(1)
scores = scores.squeeze(1).masked_fill(~causal, float('-inf'))
weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
out = torch.matmul(weights.unsqueeze(1), v_2d).squeeze(1)
return out.reshape(T, NH, HD)
"""Full causal self-attention. q: (T_q, NH, HD), kv: (T_kv, HD).
Works for prefill (T_q == T_kv) and decode (T_q == 1, T_kv > 1).
Uses SDPA for efficiency.
"""
T_q, NH, HD = q.shape
T_kv = kv.shape[0]
# q: (NH, T_q, HD), k/v: (NH, T_kv, HD) — shared KV across heads
q_t = q.permute(1, 0, 2) # (NH, T_q, HD)
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, T_kv, HD)
v_exp = kv_exp.clone()
# Causal mask: query at position i can attend to positions <= i
# For decode (T_q=1), all T_kv positions are valid (position T_kv-1 attends to 0..T_kv-1)
if T_q == T_kv:
# Prefill: standard causal
attn_mask = torch.tril(torch.ones(T_q, T_kv, device=q.device, dtype=torch.bool)).unsqueeze(0).expand(NH, -1, -1)
out = F.scaled_dot_product_attention(q_t, kv_exp, v_exp, attn_mask=attn_mask, scale=scale)
else:
# Decode or mixed: no masking needed (all positions are in the past)
out = F.scaled_dot_product_attention(q_t, kv_exp, v_exp, is_causal=False, scale=scale)
return out.permute(1, 0, 2) # (T_q, NH, HD)
def swa_decode_attention(q_new, kv_cache_bf16, positions_new, scale, window_size=WINDOW):