[Core] Support Local Chunked Attention for Hybrid KV Cache (#19351)
Signed-off-by: Lucia Fang <fanglu@fb.com> Signed-off-by: Lu Fang <fanglu@meta.com> Signed-off-by: Lu Fang <fanglu@fb.com> Co-authored-by: Lu Fang <fanglu@meta.com>
This commit is contained in:
@@ -11,7 +11,8 @@ from typing import Any, Callable, NamedTuple, Optional
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
||||
FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheSpec,
|
||||
KVCacheTensor, SlidingWindowSpec)
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
@@ -976,7 +977,11 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
|
||||
has_sliding_window = any(
|
||||
isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
|
||||
if has_full_attention and has_sliding_window:
|
||||
has_chunked_local_attention = any(
|
||||
isinstance(spec, ChunkedLocalAttentionSpec)
|
||||
for spec in kv_cache_spec.values())
|
||||
if has_full_attention and (has_sliding_window
|
||||
or has_chunked_local_attention):
|
||||
for layer_name, spec in kv_cache_spec.items():
|
||||
if isinstance(spec, SlidingWindowSpec):
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
@@ -987,6 +992,15 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
use_mla=spec.use_mla,
|
||||
sliding_window=spec.sliding_window,
|
||||
)
|
||||
elif isinstance(spec, ChunkedLocalAttentionSpec):
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=spec.block_size,
|
||||
num_kv_heads=spec.num_kv_heads,
|
||||
head_size=spec.head_size,
|
||||
dtype=spec.dtype,
|
||||
use_mla=spec.use_mla,
|
||||
attention_chunk_size=spec.attention_chunk_size,
|
||||
)
|
||||
|
||||
if is_hybrid(kv_cache_spec):
|
||||
raise ValueError("Hybrid KV cache manager is disabled but failed to "
|
||||
@@ -1010,7 +1024,6 @@ def get_kv_cache_config(
|
||||
The generated KVCacheConfigs
|
||||
"""
|
||||
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
|
||||
|
||||
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
|
||||
unify_hybrid_kv_cache_specs(kv_cache_spec)
|
||||
|
||||
|
||||
@@ -394,6 +394,129 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
return 0
|
||||
|
||||
|
||||
class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
|
||||
def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec,
|
||||
block_pool: BlockPool, **kwargs) -> None:
|
||||
super().__init__(kv_cache_spec, block_pool, **kwargs)
|
||||
self.attention_chunk_size = kv_cache_spec.attention_chunk_size
|
||||
self._null_block = block_pool.null_block
|
||||
|
||||
@classmethod
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: list[BlockHash],
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
"""
|
||||
For chunked local attention, we need to find the longest cache hit
|
||||
prefix of the blocks that is not longer than `max_length`. The prefix
|
||||
should be a common prefix hit for all the kv cache groups in
|
||||
`kv_cache_group_ids`. If no cache hit is found, return an empty list.
|
||||
note we mark as computed if the whole block is outside of the local
|
||||
window, and set the block as null. Examples:
|
||||
|
||||
1. Attention chunk size of 8, block size of 4, max length of 15
|
||||
for next token at 15th (zero-indexed), 8th - 14th tokens are in
|
||||
the window(needs lookup), 0th - 7th are not in the window,
|
||||
so they are already marked as computed. We check the complete
|
||||
block3 (8th - 11th tokens), Assume block 3 is hit, we will return
|
||||
[null, null, block 3], otherwise, we return [null, null]
|
||||
|
||||
2. Attention chunk size of 8, block size of 4, max length of 16
|
||||
for next token at 16th (zero-indexed), 0th - 15th tokens are not
|
||||
in the window, so they are already marked as computed.
|
||||
we return 4 blocks[null, null, null, null]
|
||||
|
||||
Args:
|
||||
block_hashes: The block hashes of the request.
|
||||
max_length: The maximum length of the cache hit prefix.
|
||||
kv_cache_group_ids: The ids of the kv cache groups.
|
||||
block_pool: The block pool.
|
||||
kv_cache_spec: The kv cache spec.
|
||||
use_eagle: Whether to use eagle.
|
||||
|
||||
Returns:
|
||||
A list of cached blocks
|
||||
"""
|
||||
assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), (
|
||||
"ChunkedLocalAttentionManager can only be used for " +
|
||||
"chunked local attention groups")
|
||||
assert use_eagle is False, ("Hybrid KV cache is not supported for " +
|
||||
"eagle + chunked local attention.")
|
||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||
if max_length > 0:
|
||||
local_attention_start_idx = (max_length //
|
||||
kv_cache_spec.attention_chunk_size *
|
||||
kv_cache_spec.attention_chunk_size)
|
||||
else:
|
||||
local_attention_start_idx = 0
|
||||
# we marked blocks out of window as computed
|
||||
# with null blocks, and blocks inside window based on cache lookup
|
||||
# result [null] [null] ... [null] [hit block 1 (1st block contain
|
||||
# last window)] [hit block 2] ... [hit block x]
|
||||
local_attention_start_block_idx = (local_attention_start_idx //
|
||||
kv_cache_spec.block_size)
|
||||
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
|
||||
[block_pool.null_block] * local_attention_start_block_idx
|
||||
for _ in range(len(kv_cache_group_ids)))
|
||||
for i in range(local_attention_start_block_idx, max_num_blocks):
|
||||
block_hash = block_hashes[i]
|
||||
if cached_block := block_pool.get_cached_block(
|
||||
block_hash, kv_cache_group_ids):
|
||||
for computed, cached in zip(computed_blocks, cached_block):
|
||||
computed.append(cached)
|
||||
else:
|
||||
break
|
||||
return computed_blocks
|
||||
|
||||
def remove_skipped_blocks(self, request_id: str,
|
||||
num_computed_tokens: int) -> None:
|
||||
# Remove the blocks that are no longer be in the chunked attention
|
||||
# window and skipped during the attention computation.
|
||||
|
||||
# [chunk 0][chunk 1]local_attention_start_idx ... current
|
||||
# we computed previous number of chunks to get the idx of
|
||||
# current chunk window starting offset,
|
||||
# e.g. for computed 1024 tokens, the 1024th token (0 indexed)
|
||||
# is in the second chunk, there are 1 prev chunk, the start idx
|
||||
# is 1024. for 1023, it will be 0.
|
||||
num_cached_block = self.num_cached_block.get(request_id, 0)
|
||||
local_attention_start_idx = (
|
||||
num_computed_tokens
|
||||
) // self.attention_chunk_size * self.attention_chunk_size
|
||||
first_useful_block_idx = local_attention_start_idx // self.block_size
|
||||
if num_cached_block > 0:
|
||||
# Make sure we don't delete the last cached block
|
||||
first_useful_block_idx = min(first_useful_block_idx,
|
||||
num_cached_block - 1)
|
||||
# if block size = 128, 0 -> block 0, 1024 (= 128 * 8) ->
|
||||
# block 8, 372 (= 128 * 2 + 116) -> block 2
|
||||
blocks = self.req_to_blocks[request_id]
|
||||
removed_blocks: list[KVCacheBlock] = []
|
||||
# we need to keep the last block to get the previous hash key
|
||||
for i in range(first_useful_block_idx - 1, -1, -1):
|
||||
if blocks[i] == self._null_block:
|
||||
# If the block is already a null block, the blocks before it
|
||||
# should also have been set to null blocks by the previous calls
|
||||
# to this function.
|
||||
break
|
||||
removed_blocks.append(blocks[i])
|
||||
blocks[i] = self._null_block
|
||||
self.block_pool.free_blocks(removed_blocks)
|
||||
|
||||
def get_num_common_prefix_blocks(self, request_id: str,
|
||||
num_running_requests: int) -> int:
|
||||
"""
|
||||
cascade attention is not supported by chunked local attention.
|
||||
"""
|
||||
return 0
|
||||
|
||||
|
||||
class MambaManager(SingleTypeKVCacheManager):
|
||||
|
||||
@classmethod
|
||||
@@ -435,8 +558,8 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
|
||||
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||
FullAttentionSpec: FullAttentionManager,
|
||||
ChunkedLocalAttentionSpec: FullAttentionManager,
|
||||
SlidingWindowSpec: SlidingWindowManager,
|
||||
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
|
||||
MambaSpec: MambaManager,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user