[Core] Support Local Chunked Attention for Hybrid KV Cache (#19351)

Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Lu Fang <fanglu@meta.com>
This commit is contained in:
Lucia Fang
2025-07-19 11:48:38 +08:00
committed by GitHub
parent 466e878f2a
commit 9a9fda1423
9 changed files with 351 additions and 19 deletions

View File

@@ -11,7 +11,8 @@ from typing import Any, Callable, NamedTuple, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
@@ -976,7 +977,11 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
has_sliding_window = any(
isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
if has_full_attention and has_sliding_window:
has_chunked_local_attention = any(
isinstance(spec, ChunkedLocalAttentionSpec)
for spec in kv_cache_spec.values())
if has_full_attention and (has_sliding_window
or has_chunked_local_attention):
for layer_name, spec in kv_cache_spec.items():
if isinstance(spec, SlidingWindowSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
@@ -987,6 +992,15 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
use_mla=spec.use_mla,
sliding_window=spec.sliding_window,
)
elif isinstance(spec, ChunkedLocalAttentionSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
use_mla=spec.use_mla,
attention_chunk_size=spec.attention_chunk_size,
)
if is_hybrid(kv_cache_spec):
raise ValueError("Hybrid KV cache manager is disabled but failed to "
@@ -1010,7 +1024,6 @@ def get_kv_cache_config(
The generated KVCacheConfigs
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
unify_hybrid_kv_cache_specs(kv_cache_spec)

View File

@@ -394,6 +394,129 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
return 0
class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec,
block_pool: BlockPool, **kwargs) -> None:
super().__init__(kv_cache_spec, block_pool, **kwargs)
self.attention_chunk_size = kv_cache_spec.attention_chunk_size
self._null_block = block_pool.null_block
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> tuple[list[KVCacheBlock], ...]:
"""
For chunked local attention, we need to find the longest cache hit
prefix of the blocks that is not longer than `max_length`. The prefix
should be a common prefix hit for all the kv cache groups in
`kv_cache_group_ids`. If no cache hit is found, return an empty list.
note we mark as computed if the whole block is outside of the local
window, and set the block as null. Examples:
1. Attention chunk size of 8, block size of 4, max length of 15
for next token at 15th (zero-indexed), 8th - 14th tokens are in
the window(needs lookup), 0th - 7th are not in the window,
so they are already marked as computed. We check the complete
block3 (8th - 11th tokens), Assume block 3 is hit, we will return
[null, null, block 3], otherwise, we return [null, null]
2. Attention chunk size of 8, block size of 4, max length of 16
for next token at 16th (zero-indexed), 0th - 15th tokens are not
in the window, so they are already marked as computed.
we return 4 blocks[null, null, null, null]
Args:
block_hashes: The block hashes of the request.
max_length: The maximum length of the cache hit prefix.
kv_cache_group_ids: The ids of the kv cache groups.
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
Returns:
A list of cached blocks
"""
assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), (
"ChunkedLocalAttentionManager can only be used for " +
"chunked local attention groups")
assert use_eagle is False, ("Hybrid KV cache is not supported for " +
"eagle + chunked local attention.")
max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0:
local_attention_start_idx = (max_length //
kv_cache_spec.attention_chunk_size *
kv_cache_spec.attention_chunk_size)
else:
local_attention_start_idx = 0
# we marked blocks out of window as computed
# with null blocks, and blocks inside window based on cache lookup
# result [null] [null] ... [null] [hit block 1 (1st block contain
# last window)] [hit block 2] ... [hit block x]
local_attention_start_block_idx = (local_attention_start_idx //
kv_cache_spec.block_size)
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[block_pool.null_block] * local_attention_start_block_idx
for _ in range(len(kv_cache_group_ids)))
for i in range(local_attention_start_block_idx, max_num_blocks):
block_hash = block_hashes[i]
if cached_block := block_pool.get_cached_block(
block_hash, kv_cache_group_ids):
for computed, cached in zip(computed_blocks, cached_block):
computed.append(cached)
else:
break
return computed_blocks
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
# Remove the blocks that are no longer be in the chunked attention
# window and skipped during the attention computation.
# [chunk 0][chunk 1]local_attention_start_idx ... current
# we computed previous number of chunks to get the idx of
# current chunk window starting offset,
# e.g. for computed 1024 tokens, the 1024th token (0 indexed)
# is in the second chunk, there are 1 prev chunk, the start idx
# is 1024. for 1023, it will be 0.
num_cached_block = self.num_cached_block.get(request_id, 0)
local_attention_start_idx = (
num_computed_tokens
) // self.attention_chunk_size * self.attention_chunk_size
first_useful_block_idx = local_attention_start_idx // self.block_size
if num_cached_block > 0:
# Make sure we don't delete the last cached block
first_useful_block_idx = min(first_useful_block_idx,
num_cached_block - 1)
# if block size = 128, 0 -> block 0, 1024 (= 128 * 8) ->
# block 8, 372 (= 128 * 2 + 116) -> block 2
blocks = self.req_to_blocks[request_id]
removed_blocks: list[KVCacheBlock] = []
# we need to keep the last block to get the previous hash key
for i in range(first_useful_block_idx - 1, -1, -1):
if blocks[i] == self._null_block:
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls
# to this function.
break
removed_blocks.append(blocks[i])
blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks)
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
"""
cascade attention is not supported by chunked local attention.
"""
return 0
class MambaManager(SingleTypeKVCacheManager):
@classmethod
@@ -435,8 +558,8 @@ class MambaManager(SingleTypeKVCacheManager):
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
ChunkedLocalAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager,
}