Fix attention for decode (1 query vs N cached KVs)
This commit is contained in:
@@ -155,20 +155,30 @@ def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim):
|
||||
# ── Attention ────────────────────────────────────────────────────────
|
||||
|
||||
def full_causal_attention(q, kv, scale):
|
||||
"""Full causal self-attention. q: (T, NH, HD), kv: (T, HD)."""
|
||||
T, NH, HD = q.shape
|
||||
q_2d = q.reshape(T * NH, HD)
|
||||
kv_exp = kv.unsqueeze(1).expand(-1, NH, -1).contiguous()
|
||||
k_2d = kv_exp.permute(1, 0, 2).unsqueeze(1).expand(NH, T, T, -1).contiguous().reshape(T * NH, T, HD)
|
||||
v_2d = k_2d.clone()
|
||||
scores = torch.matmul(q_2d.unsqueeze(1), k_2d.transpose(-1, -2)) * scale
|
||||
query_pos = torch.arange(T, device=q.device).unsqueeze(1).repeat(1, NH).reshape(T * NH)
|
||||
kv_pos = torch.arange(T, device=q.device).unsqueeze(0)
|
||||
causal = kv_pos <= query_pos.unsqueeze(1)
|
||||
scores = scores.squeeze(1).masked_fill(~causal, float('-inf'))
|
||||
weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
|
||||
out = torch.matmul(weights.unsqueeze(1), v_2d).squeeze(1)
|
||||
return out.reshape(T, NH, HD)
|
||||
"""Full causal self-attention. q: (T_q, NH, HD), kv: (T_kv, HD).
|
||||
|
||||
Works for prefill (T_q == T_kv) and decode (T_q == 1, T_kv > 1).
|
||||
Uses SDPA for efficiency.
|
||||
"""
|
||||
T_q, NH, HD = q.shape
|
||||
T_kv = kv.shape[0]
|
||||
|
||||
# q: (NH, T_q, HD), k/v: (NH, T_kv, HD) — shared KV across heads
|
||||
q_t = q.permute(1, 0, 2) # (NH, T_q, HD)
|
||||
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, T_kv, HD)
|
||||
v_exp = kv_exp.clone()
|
||||
|
||||
# Causal mask: query at position i can attend to positions <= i
|
||||
# For decode (T_q=1), all T_kv positions are valid (position T_kv-1 attends to 0..T_kv-1)
|
||||
if T_q == T_kv:
|
||||
# Prefill: standard causal
|
||||
attn_mask = torch.tril(torch.ones(T_q, T_kv, device=q.device, dtype=torch.bool)).unsqueeze(0).expand(NH, -1, -1)
|
||||
out = F.scaled_dot_product_attention(q_t, kv_exp, v_exp, attn_mask=attn_mask, scale=scale)
|
||||
else:
|
||||
# Decode or mixed: no masking needed (all positions are in the past)
|
||||
out = F.scaled_dot_product_attention(q_t, kv_exp, v_exp, is_causal=False, scale=scale)
|
||||
|
||||
return out.permute(1, 0, 2) # (T_q, NH, HD)
|
||||
|
||||
|
||||
def swa_decode_attention(q_new, kv_cache_bf16, positions_new, scale, window_size=WINDOW):
|
||||
|
||||
Reference in New Issue
Block a user