diff --git a/vllm/patches/deepseek_v4_attention.py b/vllm/patches/deepseek_v4_attention.py
index 3165490d..c8182531 100644
--- a/vllm/patches/deepseek_v4_attention.py
+++ b/vllm/patches/deepseek_v4_attention.py
@@ -710,6 +710,9 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
                     swa_metadata.block_size,
                     self.scale,
                     self.window_size,
+                    swa_indices=swa_metadata.decode_swa_indices,
+                    swa_lens=swa_metadata.decode_swa_lens,
+                    decode_token_idx=t,
                 ).squeeze(0)
             else:
                 # CSA/HCA layers: sparse attention + SWA + sink merge
diff --git a/vllm/patches/layers/csa_attention.py b/vllm/patches/layers/csa_attention.py
index 7137b486..c315fed5 100644
--- a/vllm/patches/layers/csa_attention.py
+++ b/vllm/patches/layers/csa_attention.py
@@ -219,29 +219,61 @@ def blackwell_attention_kv_write(
 def blackwell_attention_decode(
     q,  # (1, NH, HD) single decode query with RoPE
     positions,  # (1,) absolute position
-    swa_kv_cache,  # (num_blocks, block_size, HD) fp8 paged cache
+    swa_kv_cache,  # (num_blocks, block_size, HD) fp8 SWA cache (uint8)
     swa_inv_scale,  # (max_slots, 1) per-token inv scale
     slot_mapping,  # (1,) slot for the new token (already written)
     block_size,  # tokens per block
     scale,  # 1/sqrt(HD)
     window_size,  # 128
+    swa_indices=None,  # (num_decode_tokens, window_size) pre-computed paged indices
+    swa_lens=None,  # (num_decode_tokens,) number of valid indices per token
+    decode_token_idx=0,  # which decode token this is
 ) -> torch.Tensor:
-    """Decode attention: read all cached KV, attend.
+    """Decode attention: read cached KV using paged indices, attend.
+
+    With both ``swa_indices`` and ``swa_lens`` given, the first
+    ``swa_lens[decode_token_idx]`` entries of row ``decode_token_idx`` of
+    ``swa_indices`` pick the cache entries to attend over; each index is
+    split into ``(index // block_size, index % block_size)``. Otherwise
+    slots ``0..pos`` are read and trimmed to the last ``window_size``.
+    K and V are both the gathered cache rows, shared by all NH heads.
 
     Returns: (1, NH, HD) attention output.
     """
-    pos = positions[0].item()
-    # Read all KV from position 0 to pos (inclusive)
-    all_slots = torch.arange(pos + 1, dtype=torch.int64, device=q.device)
-    kv_cached_fp8 = paged_kv_read(all_slots, swa_kv_cache, block_size, pos + 1, q.shape[2])
-    kv_inv_scales = swa_inv_scale[all_slots]
-    kv_cached = kv_dequantize_fp8(kv_cached_fp8, kv_inv_scales)
+    NH = q.shape[1]
+    HD = q.shape[2]
 
-    # Apply SWA window
-    window_start = max(0, pos - window_size + 1)
-    kv_window = kv_cached[window_start:]
+    if swa_indices is not None and swa_lens is not None:
+        # Use pre-computed paged indices from vLLM
+        num_valid = swa_lens[decode_token_idx].item()
+        indices = swa_indices[decode_token_idx, :num_valid]
+        block_indices = indices // block_size
+        offsets = indices % block_size
+        kv_cached = swa_kv_cache[block_indices, offsets]
+        if kv_cached.dtype == torch.uint8:
+            # Reinterpret the raw bytes as fp8; a cast would convert values.
+            kv_cached = kv_cached.view(torch.float8_e4m3fn)
+        # TODO: store and read per-token inv_scale in paged cache. Until then
+        # the fp8 values are used unscaled here (swa_inv_scale is ignored).
+    else:
+        # Fallback: sequential slot access
+        pos = positions[0].item()
+        all_slots = torch.arange(pos + 1, dtype=torch.int64, device=q.device)
+        kv_cached_raw = paged_kv_read(all_slots, swa_kv_cache, block_size, pos + 1, HD)
+        kv_inv_scales = swa_inv_scale[all_slots]
+        kv_cached = kv_dequantize_fp8(kv_cached_raw, kv_inv_scales)
+        window_start = max(0, pos - window_size + 1)
+        kv_cached = kv_cached[window_start:]
+
+    # SDPA needs q, k and v in one dtype; follow q rather than assuming bf16.
+    kv_cached = kv_cached.to(q.dtype)
 
-    return decode_attention(q, kv_window, scale)
+    q_t = q.permute(1, 0, 2)
+    kv_exp = kv_cached.unsqueeze(0).expand(NH, -1, -1)
+    out = torch.nn.functional.scaled_dot_product_attention(
+        q_t, kv_exp, kv_exp, is_causal=False, scale=scale
+    )
+    return out.permute(1, 0, 2)
 
 
 def full_sdpa_attention(