From 85c74e593279e86e28610ad07ecf20580ef21157 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 15:28:52 +0000 Subject: [PATCH] Fix attention for decode (1 query vs N cached KVs) --- tests/test_decode_attention_b200.py | 38 ++++++++++++++++++----------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/tests/test_decode_attention_b200.py b/tests/test_decode_attention_b200.py index c62be70d..c5d4da1f 100644 --- a/tests/test_decode_attention_b200.py +++ b/tests/test_decode_attention_b200.py @@ -155,20 +155,30 @@ def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim): # ── Attention ──────────────────────────────────────────────────────── def full_causal_attention(q, kv, scale): - """Full causal self-attention. q: (T, NH, HD), kv: (T, HD).""" - T, NH, HD = q.shape - q_2d = q.reshape(T * NH, HD) - kv_exp = kv.unsqueeze(1).expand(-1, NH, -1).contiguous() - k_2d = kv_exp.permute(1, 0, 2).unsqueeze(1).expand(NH, T, T, -1).contiguous().reshape(T * NH, T, HD) - v_2d = k_2d.clone() - scores = torch.matmul(q_2d.unsqueeze(1), k_2d.transpose(-1, -2)) * scale - query_pos = torch.arange(T, device=q.device).unsqueeze(1).repeat(1, NH).reshape(T * NH) - kv_pos = torch.arange(T, device=q.device).unsqueeze(0) - causal = kv_pos <= query_pos.unsqueeze(1) - scores = scores.squeeze(1).masked_fill(~causal, float('-inf')) - weights = F.softmax(scores.float(), dim=-1).to(q.dtype) - out = torch.matmul(weights.unsqueeze(1), v_2d).squeeze(1) - return out.reshape(T, NH, HD) + """Full causal self-attention. q: (T_q, NH, HD), kv: (T_kv, HD). + + Works for prefill (T_q == T_kv) and decode (T_q == 1, T_kv > 1). + Uses SDPA for efficiency. + """ + T_q, NH, HD = q.shape + T_kv = kv.shape[0] + + # q: (NH, T_q, HD), k/v: (NH, T_kv, HD) — shared KV across heads + q_t = q.permute(1, 0, 2) # (NH, T_q, HD) + kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, T_kv, HD) + v_exp = kv_exp.clone() + + # Causal mask: query at position i can attend to positions <= i + # For decode (T_q=1), all T_kv positions are valid (position T_kv-1 attends to 0..T_kv-1) + if T_q == T_kv: + # Prefill: standard causal + attn_mask = torch.tril(torch.ones(T_q, T_kv, device=q.device, dtype=torch.bool)).unsqueeze(0).expand(NH, -1, -1) + out = F.scaled_dot_product_attention(q_t, kv_exp, v_exp, attn_mask=attn_mask, scale=scale) + else: + # Decode or mixed: no masking needed (all positions are in the past) + out = F.scaled_dot_product_attention(q_t, kv_exp, v_exp, is_causal=False, scale=scale) + + return out.permute(1, 0, 2) # (T_q, NH, HD) def swa_decode_attention(q_new, kv_cache_bf16, positions_new, scale, window_size=WINDOW):