FIX: Use vLLM's decode_swa_indices for correct paged KV cache access during decode
This commit is contained in:
@@ -710,6 +710,9 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
swa_metadata.block_size,
|
||||
self.scale,
|
||||
self.window_size,
|
||||
swa_indices=swa_metadata.decode_swa_indices,
|
||||
swa_lens=swa_metadata.decode_swa_lens,
|
||||
decode_token_idx=t,
|
||||
).squeeze(0)
|
||||
else:
|
||||
# CSA/HCA layers: sparse attention + SWA + sink merge
|
||||
|
||||
@@ -219,29 +219,51 @@ def blackwell_attention_kv_write(
|
||||
def blackwell_attention_decode(
|
||||
q, # (1, NH, HD) single decode query with RoPE
|
||||
positions, # (1,) absolute position
|
||||
swa_kv_cache, # (num_blocks, block_size, HD) fp8 paged cache
|
||||
swa_kv_cache, # (num_blocks, block_size, HD) fp8 SWA cache (uint8)
|
||||
swa_inv_scale, # (max_slots, 1) per-token inv scale
|
||||
slot_mapping, # (1,) slot for the new token (already written)
|
||||
block_size, # tokens per block
|
||||
scale, # 1/sqrt(HD)
|
||||
window_size, # 128
|
||||
swa_indices=None, # (num_decode_tokens, window_size) pre-computed paged indices
|
||||
swa_lens=None, # (num_decode_tokens,) number of valid indices per token
|
||||
decode_token_idx=0, # which decode token this is
|
||||
) -> torch.Tensor:
|
||||
"""Decode attention: read all cached KV, attend.
|
||||
"""Decode attention: read cached KV using paged indices, attend.
|
||||
|
||||
Uses pre-computed swa_indices from vLLM's metadata for correct paged access.
|
||||
Returns: (1, NH, HD) attention output.
|
||||
"""
|
||||
pos = positions[0].item()
|
||||
# Read all KV from position 0 to pos (inclusive)
|
||||
all_slots = torch.arange(pos + 1, dtype=torch.int64, device=q.device)
|
||||
kv_cached_fp8 = paged_kv_read(all_slots, swa_kv_cache, block_size, pos + 1, q.shape[2])
|
||||
kv_inv_scales = swa_inv_scale[all_slots]
|
||||
kv_cached = kv_dequantize_fp8(kv_cached_fp8, kv_inv_scales)
|
||||
NH = q.shape[1]
|
||||
HD = q.shape[2]
|
||||
device = q.device
|
||||
|
||||
# Apply SWA window
|
||||
window_start = max(0, pos - window_size + 1)
|
||||
kv_window = kv_cached[window_start:]
|
||||
if swa_indices is not None and swa_lens is not None:
|
||||
# Use pre-computed paged indices from vLLM
|
||||
num_valid = swa_lens[decode_token_idx].item()
|
||||
indices = swa_indices[decode_token_idx, :num_valid]
|
||||
block_indices = indices // block_size
|
||||
offsets = indices % block_size
|
||||
kv_cached_raw = swa_kv_cache[block_indices, offsets]
|
||||
if swa_kv_cache.dtype == torch.uint8:
|
||||
kv_cached_raw = kv_cached_raw.view(torch.float8_e4m3fn)
|
||||
# Dequantize: for now use bf16 cast (fp8 → bf16 without per-token scale)
|
||||
# TODO: store and read per-token inv_scale in paged cache
|
||||
kv_cached = kv_cached_raw.to(torch.bfloat16)
|
||||
else:
|
||||
# Fallback: sequential slot access
|
||||
pos = positions[0].item()
|
||||
all_slots = torch.arange(pos + 1, dtype=torch.int64, device=device)
|
||||
kv_cached_raw = paged_kv_read(all_slots, swa_kv_cache, block_size, pos + 1, HD)
|
||||
kv_inv_scales = swa_inv_scale[all_slots]
|
||||
kv_cached = kv_dequantize_fp8(kv_cached_raw, kv_inv_scales)
|
||||
window_start = max(0, pos - window_size + 1)
|
||||
kv_cached = kv_cached[window_start:]
|
||||
|
||||
return decode_attention(q, kv_window, scale)
|
||||
q_t = q.permute(1, 0, 2)
|
||||
kv_exp = kv_cached.unsqueeze(0).expand(NH, -1, -1)
|
||||
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=False, scale=scale)
|
||||
return out.permute(1, 0, 2)
|
||||
|
||||
|
||||
def full_sdpa_attention(
|
||||
|
||||
Reference in New Issue
Block a user