[v1] Hybrid Memory Allocator (#17996)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -7,8 +7,8 @@ from typing import Callable, Optional
|
||||
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
|
||||
BlockStored, KVCacheEvent)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, FreeKVCacheBlockQueue,
|
||||
KVCacheBlock,
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
FreeKVCacheBlockQueue, KVCacheBlock,
|
||||
generate_block_hash_extra_keys,
|
||||
hash_block_tokens)
|
||||
from vllm.v1.request import Request
|
||||
@@ -27,6 +27,7 @@ class BlockPool:
|
||||
Args:
|
||||
num_gpu_blocks: The number of blocks in the pool.
|
||||
enable_caching: Whether to enable prefix caching.
|
||||
enable_kv_cache_events: Whether to enable kv cache events.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -56,7 +57,7 @@ class BlockPool:
|
||||
# if there is already an identical block in the cache. This is because
|
||||
# we want to make sure the allocated block IDs won't change so that
|
||||
# block tables are append-only.
|
||||
self.cached_block_hash_to_block: dict[BlockHash, dict[
|
||||
self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[
|
||||
int, KVCacheBlock]] = defaultdict(dict)
|
||||
|
||||
# To represent a placeholder block with block_id=0.
|
||||
@@ -68,22 +69,29 @@ class BlockPool:
|
||||
self.enable_kv_cache_events = enable_kv_cache_events
|
||||
self.kv_event_queue: list[KVCacheEvent] = []
|
||||
|
||||
def get_cached_block(self,
|
||||
block_hash: BlockHash) -> Optional[KVCacheBlock]:
|
||||
"""Get a cached block by the block hash, or None if cache miss.
|
||||
def get_cached_block(
|
||||
self, block_hash: BlockHash,
|
||||
kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]:
|
||||
"""Get the cached block by the block hash for each group in
|
||||
`kv_cache_group_ids`, or None if cache miss for any group.
|
||||
If there are duplicated blocks, we return the first block in the cache.
|
||||
|
||||
Args:
|
||||
block_hash: The hash value of the block.
|
||||
kv_cache_group_ids: The ids of the KV cache groups.
|
||||
|
||||
Returns:
|
||||
The cached block if it exists, or None.
|
||||
The cached blocks if exists, or None.
|
||||
"""
|
||||
cached_blocks = self.cached_block_hash_to_block.get(block_hash)
|
||||
if not cached_blocks:
|
||||
return None
|
||||
first_block_id = next(iter(cached_blocks))
|
||||
return cached_blocks[first_block_id]
|
||||
cached_blocks = []
|
||||
for group_id in kv_cache_group_ids:
|
||||
cached_blocks_one_group = self.cached_block_hash_to_block.get(
|
||||
BlockHashWithGroupId(block_hash, group_id))
|
||||
if not cached_blocks_one_group:
|
||||
return None
|
||||
first_block_id = next(iter(cached_blocks_one_group))
|
||||
cached_blocks.append(cached_blocks_one_group[first_block_id])
|
||||
return cached_blocks
|
||||
|
||||
def cache_full_blocks(
|
||||
self,
|
||||
@@ -93,6 +101,7 @@ class BlockPool:
|
||||
num_cached_blocks: int,
|
||||
num_full_blocks: int,
|
||||
block_size: int,
|
||||
kv_cache_group_id: int,
|
||||
hash_fn: Callable,
|
||||
) -> None:
|
||||
"""Cache a list of full blocks for prefix caching.
|
||||
@@ -112,6 +121,7 @@ class BlockPool:
|
||||
num_full_blocks: The number of blocks that are full and should
|
||||
be cached after this function.
|
||||
block_size: Number of tokens in each block.
|
||||
kv_cache_group_id: The id of the KV cache group.
|
||||
hash_fn: The hash function to use for block hashes.
|
||||
"""
|
||||
if num_cached_blocks == num_full_blocks:
|
||||
@@ -126,7 +136,7 @@ class BlockPool:
|
||||
else:
|
||||
prev_block = blocks[num_cached_blocks - 1]
|
||||
assert prev_block.block_hash is not None
|
||||
prev_block_hash_value = prev_block.block_hash.hash_value
|
||||
prev_block_hash_value = prev_block.block_hash.get_hash_value()
|
||||
|
||||
parent_block_hash = prev_block_hash_value
|
||||
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
|
||||
@@ -138,8 +148,9 @@ class BlockPool:
|
||||
# The block hash may already be computed in
|
||||
# "get_computed_blocks" if the tokens are not generated by
|
||||
# this request (either the prompt tokens or the previously
|
||||
# generated tokens with preemption). In this case we simply
|
||||
# reuse the block hash.
|
||||
# generated tokens with preemption), or by other
|
||||
# single_type_managers with the same block_size.
|
||||
# In this case we simply reuse the block hash.
|
||||
block_hash = new_block_hashes[i]
|
||||
else:
|
||||
# Otherwise compute the block hash and cache it in the request
|
||||
@@ -166,8 +177,11 @@ class BlockPool:
|
||||
block_hashes.append(block_hash)
|
||||
|
||||
# Update and added the full block to the cache.
|
||||
blk.block_hash = block_hash
|
||||
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
|
||||
block_hash_with_group_id = BlockHashWithGroupId(
|
||||
block_hash, kv_cache_group_id)
|
||||
blk.block_hash = block_hash_with_group_id
|
||||
self.cached_block_hash_to_block[block_hash_with_group_id][
|
||||
blk.block_id] = blk
|
||||
if new_hashes is not None:
|
||||
new_hashes.append(block_hash.hash_value)
|
||||
prev_block_hash_value = block_hash.hash_value
|
||||
@@ -237,12 +251,16 @@ class BlockPool:
|
||||
del self.cached_block_hash_to_block[block_hash]
|
||||
|
||||
if self.enable_kv_cache_events:
|
||||
# FIXME (Chen): Not sure whether we should return `hash_value`
|
||||
# or `(hash_value, group_id)` here. But it's fine now because
|
||||
# we disable hybrid kv cache manager when kv cache event is
|
||||
# enabled, so there is only one group.
|
||||
self.kv_event_queue.append(
|
||||
BlockRemoved(block_hashes=[block_hash.hash_value]))
|
||||
BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
|
||||
return True
|
||||
return False
|
||||
|
||||
def touch(self, blocks: list[KVCacheBlock]) -> None:
|
||||
def touch(self, blocks: list[list[KVCacheBlock]]) -> None:
|
||||
"""Touch a block increases its reference count by 1, and may remove
|
||||
the block from the free queue. This is used when a block is hit by
|
||||
another request with the same prefix.
|
||||
@@ -250,12 +268,13 @@ class BlockPool:
|
||||
Args:
|
||||
blocks: A list of blocks to touch.
|
||||
"""
|
||||
for block in blocks:
|
||||
# ref_cnt=0 means this block is in the free list (i.e. eviction
|
||||
# candidate), so remove it.
|
||||
if block.ref_cnt == 0 and not block.is_null:
|
||||
self.free_block_queue.remove(block)
|
||||
block.incr_ref()
|
||||
for blocks_per_group in blocks:
|
||||
for block in blocks_per_group:
|
||||
# ref_cnt=0 means this block is in the free list (i.e. eviction
|
||||
# candidate), so remove it.
|
||||
if block.ref_cnt == 0 and not block.is_null:
|
||||
self.free_block_queue.remove(block)
|
||||
block.incr_ref()
|
||||
|
||||
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
|
||||
"""Free a list of blocks. The blocks should be ordered by their
|
||||
|
||||
358
vllm/v1/core/kv_cache_coordinator.py
Normal file
358
vllm/v1/core/kv_cache_coordinator.py
Normal file
@@ -0,0 +1,358 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Optional
|
||||
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
FullAttentionManager, SingleTypeKVCacheManager,
|
||||
get_manager_for_kv_cache_spec)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class KVCacheCoordinator(ABC):
|
||||
"""
|
||||
Coordinate the KV cache of different KV cache groups.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
max_model_len: int,
|
||||
use_eagle: bool,
|
||||
enable_caching: bool,
|
||||
caching_hash_fn: Callable,
|
||||
enable_kv_cache_events: bool,
|
||||
):
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.max_model_len = max_model_len
|
||||
|
||||
self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
|
||||
enable_kv_cache_events)
|
||||
self.single_type_managers: list[SingleTypeKVCacheManager] = []
|
||||
|
||||
# Needs special handling for find_longest_cache_hit if eagle is enabled
|
||||
self.use_eagle = use_eagle
|
||||
|
||||
for i in range(len(self.kv_cache_config.kv_cache_groups)):
|
||||
kv_cache_spec = self.kv_cache_config.kv_cache_groups[
|
||||
i].kv_cache_spec
|
||||
self.single_type_managers.append(
|
||||
get_manager_for_kv_cache_spec(
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_group_id=i,
|
||||
caching_hash_fn=caching_hash_fn,
|
||||
))
|
||||
|
||||
def get_num_blocks_to_allocate(
|
||||
self, request_id: str, num_tokens: int,
|
||||
new_computed_blocks: list[list[KVCacheBlock]]) -> int:
|
||||
"""
|
||||
Get the number of blocks needed to be allocated for the request.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
num_tokens: The total number of tokens that need a slot (including
|
||||
tokens that are already allocated).
|
||||
new_computed_blocks: The new computed blocks just hitting the
|
||||
prefix caching.
|
||||
|
||||
Returns:
|
||||
The number of blocks.
|
||||
"""
|
||||
num_blocks_to_allocate = 0
|
||||
for i, manager in enumerate(self.single_type_managers):
|
||||
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
|
||||
request_id, num_tokens, new_computed_blocks[i])
|
||||
return num_blocks_to_allocate
|
||||
|
||||
def save_new_computed_blocks(
|
||||
self, request_id: str,
|
||||
new_computed_blocks: list[list[KVCacheBlock]]) -> None:
|
||||
"""
|
||||
Add the new computed blocks to the request.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
new_computed_blocks: The new computed blocks just hitting the
|
||||
prefix cache.
|
||||
"""
|
||||
for i, manager in enumerate(self.single_type_managers):
|
||||
manager.save_new_computed_blocks(request_id,
|
||||
new_computed_blocks[i])
|
||||
|
||||
def allocate_new_blocks(self, request_id: str,
|
||||
num_tokens: int) -> list[list[KVCacheBlock]]:
|
||||
"""
|
||||
Allocate new blocks for the request to give it at least `num_tokens`
|
||||
token slots.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
num_tokens: The total number of tokens that need a slot (including
|
||||
tokens that are already allocated).
|
||||
|
||||
Returns:
|
||||
The new allocated blocks.
|
||||
"""
|
||||
new_blocks = []
|
||||
for manager in self.single_type_managers:
|
||||
new_blocks.append(
|
||||
manager.allocate_new_blocks(request_id, num_tokens))
|
||||
return new_blocks
|
||||
|
||||
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
|
||||
num_computed_tokens: int) -> None:
|
||||
"""
|
||||
Cache the blocks for the request.
|
||||
|
||||
Args:
|
||||
request: The request.
|
||||
block_hashes: The block hashes of the request.
|
||||
num_tokens: The total number of tokens that need to be cached
|
||||
(including tokens that are already cached).
|
||||
"""
|
||||
for manager in self.single_type_managers:
|
||||
manager.cache_blocks(request, block_hashes, num_computed_tokens)
|
||||
|
||||
def free(self, request_id: str) -> None:
|
||||
"""
|
||||
Free the blocks for the request.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
"""
|
||||
for manager in self.single_type_managers:
|
||||
manager.free(request_id)
|
||||
|
||||
def get_num_common_prefix_blocks(self, request_id: str,
|
||||
num_running_requests: int) -> list[int]:
|
||||
"""
|
||||
Get the number of common prefix blocks for a request.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
block_hashes: The block hashes of the request.
|
||||
|
||||
Returns:
|
||||
The number of common prefix blocks.
|
||||
"""
|
||||
num_blocks_per_group = [
|
||||
manager.get_num_common_prefix_blocks(request_id,
|
||||
num_running_requests)
|
||||
for manager in self.single_type_managers
|
||||
]
|
||||
return num_blocks_per_group
|
||||
|
||||
def remove_skipped_blocks(self, request_id: str,
|
||||
num_computed_tokens: int) -> None:
|
||||
"""
|
||||
Remove the blocks that are no longer needed from `blocks` and replace
|
||||
the removed blocks with null_block.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
num_computed_tokens: The number of tokens that have been computed.
|
||||
"""
|
||||
for manager in self.single_type_managers:
|
||||
manager.remove_skipped_blocks(request_id, num_computed_tokens)
|
||||
|
||||
def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]:
|
||||
"""
|
||||
Get the blocks for the request.
|
||||
"""
|
||||
return [
|
||||
manager.req_to_blocks[request_id]
|
||||
for manager in self.single_type_managers
|
||||
]
|
||||
|
||||
@abstractmethod
|
||||
def find_longest_cache_hit(
|
||||
self, block_hashes: list[BlockHash],
|
||||
max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
|
||||
pass
|
||||
|
||||
|
||||
class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
"""
|
||||
KV cache coordinator for models with only one KV cache group. This is the
|
||||
case for models with only one KV cache type, e.g., all attention layers use
|
||||
full attention or all attention layers use sliding window attention.
|
||||
"""
|
||||
|
||||
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
|
||||
use_eagle: bool, enable_caching: bool,
|
||||
caching_hash_fn: Callable, enable_kv_cache_events: bool):
|
||||
super().__init__(kv_cache_config, max_model_len, use_eagle,
|
||||
enable_caching, caching_hash_fn,
|
||||
enable_kv_cache_events)
|
||||
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
|
||||
0].kv_cache_spec
|
||||
self.block_size = self.kv_cache_spec.block_size
|
||||
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
||||
"UnitaryKVCacheCoordinator assumes only one kv cache group")
|
||||
|
||||
def find_longest_cache_hit(
|
||||
self, block_hashes: list[BlockHash],
|
||||
max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
|
||||
hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
|
||||
block_hashes=block_hashes,
|
||||
max_length=max_cache_hit_length,
|
||||
kv_cache_group_ids=[0],
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.kv_cache_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
)
|
||||
return hit_blocks, len(hit_blocks[0]) * self.block_size
|
||||
|
||||
|
||||
class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
"""
|
||||
KV cache coordinator for hybrid models with multiple KV cache types, and
|
||||
thus multiple kv cache groups.
|
||||
To simplify `find_longest_cache_hit`, it only supports the combination of
|
||||
two types of KV cache groups, and one of them must be full attention.
|
||||
May extend to more general cases in the future.
|
||||
"""
|
||||
|
||||
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
|
||||
use_eagle: bool, enable_caching: bool,
|
||||
caching_hash_fn: Callable, enable_kv_cache_events: bool):
|
||||
super().__init__(kv_cache_config, max_model_len, use_eagle,
|
||||
enable_caching, caching_hash_fn,
|
||||
enable_kv_cache_events)
|
||||
self.verify_and_split_kv_cache_groups()
|
||||
|
||||
def verify_and_split_kv_cache_groups(self) -> None:
|
||||
"""
|
||||
Verifies that the model has exactly two types of KV cache groups, and
|
||||
one of them is full attention. Then, split the kv cache groups into full
|
||||
attention groups and other groups.
|
||||
"""
|
||||
full_attention_type_id: Optional[str] = None
|
||||
other_type_id: Optional[str] = None
|
||||
self.full_attention_group_ids: list[int] = []
|
||||
self.other_group_ids: list[int] = []
|
||||
for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
|
||||
if isinstance(g.kv_cache_spec, FullAttentionSpec):
|
||||
if full_attention_type_id is None:
|
||||
full_attention_type_id = g.kv_cache_spec.type_id
|
||||
else:
|
||||
assert full_attention_type_id == g.kv_cache_spec.type_id, (
|
||||
"HybridKVCacheCoordinator assumes exactly one type of "
|
||||
"full attention groups now.")
|
||||
self.full_attention_group_ids.append(i)
|
||||
else:
|
||||
if other_type_id is None:
|
||||
other_type_id = g.kv_cache_spec.type_id
|
||||
else:
|
||||
assert other_type_id == g.kv_cache_spec.type_id, (
|
||||
"HybridKVCacheCoordinator assumes "
|
||||
"exactly one other type of groups now.")
|
||||
self.other_group_ids.append(i)
|
||||
|
||||
assert full_attention_type_id is not None, (
|
||||
"HybridKVCacheCoordinator assumes exactly one type of full "
|
||||
"attention groups now.")
|
||||
assert other_type_id is not None, (
|
||||
"HybridKVCacheCoordinator assumes exactly one type of other "
|
||||
"groups now.")
|
||||
|
||||
self.full_attention_manager_cls = FullAttentionManager
|
||||
self.other_attention_cls = self.single_type_managers[
|
||||
self.other_group_ids[0]].__class__
|
||||
|
||||
self.full_attention_spec = self.kv_cache_config.kv_cache_groups[
|
||||
self.full_attention_group_ids[0]].kv_cache_spec
|
||||
self.other_spec = self.kv_cache_config.kv_cache_groups[
|
||||
self.other_group_ids[0]].kv_cache_spec
|
||||
|
||||
self.full_attention_block_size = self.full_attention_spec.block_size
|
||||
self.other_block_size = self.other_spec.block_size
|
||||
assert self.other_block_size % self.full_attention_block_size == 0, (
|
||||
"KVCacheCoordinator assumes the block_size of full attention "
|
||||
"layers is divisible by other layers now.")
|
||||
|
||||
def find_longest_cache_hit(
|
||||
self,
|
||||
block_hashes: list[BlockHash],
|
||||
max_cache_hit_length: int,
|
||||
) -> tuple[list[list[KVCacheBlock]], int]:
|
||||
"""
|
||||
Find the longest cache hit for the request.
|
||||
|
||||
Args:
|
||||
block_hashes: The block hashes of the request.
|
||||
max_cache_hit_length: The maximum length of the cache hit.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A list of the cache hit blocks for each single type manager.
|
||||
- The number of tokens of the longest cache hit.
|
||||
"""
|
||||
# First, find the longest cache hit for full attention.
|
||||
hit_blocks_full_attn = (
|
||||
self.full_attention_manager_cls.find_longest_cache_hit(
|
||||
block_hashes=block_hashes,
|
||||
max_length=max_cache_hit_length,
|
||||
kv_cache_group_ids=self.full_attention_group_ids,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.full_attention_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
))
|
||||
hit_length = len(
|
||||
hit_blocks_full_attn[0]) * self.full_attention_block_size
|
||||
|
||||
# Next, find the cache hit for the other attention WITHIN
|
||||
# the cache hit of full attention.
|
||||
hit_blocks_other_attn = (
|
||||
self.other_attention_cls.find_longest_cache_hit(
|
||||
block_hashes=block_hashes,
|
||||
max_length=hit_length,
|
||||
kv_cache_group_ids=self.other_group_ids,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.other_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
))
|
||||
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
|
||||
|
||||
# NOTE: the prefix cache hit length must be a multiply of block_size as
|
||||
# we don't support partial block cache hit yet. The cache hit length
|
||||
# of other attention is ensured to be a multiply of the block size of
|
||||
# full attention layers in current implementation, because hit_length is
|
||||
# a multiply of other attention's block size, and other attention's
|
||||
# block size is a multiply of full attention's block size (verified in
|
||||
# `verify_and_split_kv_cache_groups`).
|
||||
assert hit_length % self.full_attention_block_size == 0
|
||||
|
||||
# Truncate the full attention cache hit to the length of the
|
||||
# cache hit of the other attention.
|
||||
for i in range(len(hit_blocks_full_attn)):
|
||||
del hit_blocks_full_attn[i][hit_length //
|
||||
self.full_attention_block_size:]
|
||||
|
||||
# Merge the hit blocks of full attention and other attention.
|
||||
hit_blocks = hit_blocks_other_attn
|
||||
for group_id, blocks in enumerate(hit_blocks_full_attn):
|
||||
# NOTE: there is only one full attention group in most cases. So
|
||||
# the time complexity of insert is fine.
|
||||
hit_blocks.insert(group_id, blocks)
|
||||
return hit_blocks, hit_length
|
||||
|
||||
|
||||
def get_kv_cache_coordinator(
|
||||
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
|
||||
enable_caching: bool, caching_hash_fn: Callable,
|
||||
enable_kv_cache_events: bool) -> KVCacheCoordinator:
|
||||
if len(kv_cache_config.kv_cache_groups) == 1:
|
||||
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
|
||||
use_eagle, enable_caching,
|
||||
caching_hash_fn,
|
||||
enable_kv_cache_events)
|
||||
else:
|
||||
return HybridKVCacheCoordinator(kv_cache_config, max_model_len,
|
||||
use_eagle, enable_caching,
|
||||
caching_hash_fn,
|
||||
enable_kv_cache_events)
|
||||
@@ -8,11 +8,9 @@ from typing import Optional
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
hash_request_tokens)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
get_manager_for_kv_cache_spec)
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
@@ -22,16 +20,24 @@ logger = init_logger(__name__)
|
||||
|
||||
@dataclass
|
||||
class KVCacheBlocks:
|
||||
blocks: list[KVCacheBlock]
|
||||
"""
|
||||
The allocation result of KVCacheManager, work as the interface between
|
||||
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
|
||||
structure from the Scheduler.
|
||||
"""
|
||||
blocks: list[list[KVCacheBlock]]
|
||||
"""
|
||||
blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
|
||||
We don't use block of tokens as the outer dimension because it assumes all
|
||||
kv_cache_groups have the same number of blocks, which is true for now but
|
||||
will be broken if we want to give different block_size to different
|
||||
kv_cache_groups in the future.
|
||||
"""
|
||||
|
||||
def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
|
||||
"""Adds two KVCacheBlocks instances."""
|
||||
return KVCacheBlocks(self.blocks + other.blocks)
|
||||
|
||||
@classmethod
|
||||
def create_empty(cls) -> "KVCacheBlocks":
|
||||
"""Creates a new KVCacheBlocks instance with no blocks."""
|
||||
return cls([])
|
||||
return KVCacheBlocks(
|
||||
[blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)])
|
||||
|
||||
def get_block_ids(self) -> list[list[int]]:
|
||||
"""
|
||||
@@ -39,15 +45,20 @@ class KVCacheBlocks:
|
||||
|
||||
Returns:
|
||||
list[list[int]]: A two-level list where
|
||||
* the outer list corresponds to KV cache groups (only 1 group now)
|
||||
* the outer list corresponds to KV cache groups
|
||||
* each inner list contains the block_ids of the blocks in that group
|
||||
"""
|
||||
return [[block.block_id for block in self.blocks]]
|
||||
block_ids = []
|
||||
for group in self.blocks:
|
||||
block_ids.append([blk.block_id for blk in group])
|
||||
return block_ids
|
||||
|
||||
def get_unhashed_block_ids(self) -> list[int]:
|
||||
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
|
||||
assert len(self.blocks) == 1, "Only one group is supported"
|
||||
return [
|
||||
block.block_id for block in self.blocks if block.block_hash is None
|
||||
block.block_id for block in self.blocks[0]
|
||||
if block.block_hash is None
|
||||
]
|
||||
|
||||
|
||||
@@ -63,12 +74,6 @@ class KVCacheManager:
|
||||
log_stats: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
) -> None:
|
||||
assert len(kv_cache_config.kv_cache_groups) == 1, (
|
||||
"KVCacheManager does not support hybrid models with more than 1 "
|
||||
"kv cache group")
|
||||
kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.num_gpu_blocks = kv_cache_config.num_blocks
|
||||
self.max_model_len = max_model_len
|
||||
|
||||
self.enable_caching = enable_caching
|
||||
@@ -77,17 +82,24 @@ class KVCacheManager:
|
||||
self.log_stats = log_stats
|
||||
# FIXME: make prefix cache stats conditional on log_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
||||
assert len(
|
||||
set(g.kv_cache_spec.block_size
|
||||
for g in kv_cache_config.kv_cache_groups)
|
||||
) == 1, "Only one block size is supported for now"
|
||||
self.block_size = kv_cache_config.kv_cache_groups[
|
||||
0].kv_cache_spec.block_size
|
||||
|
||||
self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching,
|
||||
enable_kv_cache_events)
|
||||
|
||||
self.single_type_manager = get_manager_for_kv_cache_spec(
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
block_pool=self.block_pool,
|
||||
self.coordinator = get_kv_cache_coordinator(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=self.max_model_len,
|
||||
use_eagle=self.use_eagle,
|
||||
num_kv_cache_groups=1,
|
||||
enable_caching=enable_caching,
|
||||
caching_hash_fn=self.caching_hash_fn,
|
||||
enable_kv_cache_events=enable_kv_cache_events,
|
||||
)
|
||||
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
||||
self.block_pool = self.coordinator.block_pool
|
||||
self.kv_cache_config = kv_cache_config
|
||||
|
||||
# Mapping from request ID to kv block hashes.
|
||||
# This is to avoid recomputing the block hashes for each call of
|
||||
@@ -133,7 +145,7 @@ class KVCacheManager:
|
||||
# When the request requires prompt logprobs, we skip prefix caching.
|
||||
if (not self.enable_caching
|
||||
or request.sampling_params.prompt_logprobs is not None):
|
||||
return KVCacheBlocks.create_empty(), 0
|
||||
return self.create_empty_block_list(), 0
|
||||
|
||||
# The block hashes for the request may already be computed
|
||||
# if the scheduler has tried to schedule the request before.
|
||||
@@ -154,20 +166,16 @@ class KVCacheManager:
|
||||
# num_computed_tokens to be block-size aligned. Removing this limitation
|
||||
# could slightly improve performance in the future.
|
||||
max_cache_hit_length = request.num_tokens - 1
|
||||
|
||||
computed_blocks = self.single_type_manager.find_longest_cache_hit(
|
||||
block_hashes, max_cache_hit_length)
|
||||
# NOTE(woosuk): Since incomplete blocks are not eligible for
|
||||
# sharing, `num_computed_tokens` is always a multiple of
|
||||
# `block_size`.
|
||||
num_computed_tokens = len(computed_blocks) * self.block_size
|
||||
computed_blocks, num_new_computed_tokens = (
|
||||
self.coordinator.find_longest_cache_hit(block_hashes,
|
||||
max_cache_hit_length))
|
||||
|
||||
if self.log_stats:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.queries += request.num_tokens
|
||||
self.prefix_cache_stats.hits += num_computed_tokens
|
||||
self.prefix_cache_stats.hits += num_new_computed_tokens
|
||||
|
||||
return KVCacheBlocks(computed_blocks), num_computed_tokens
|
||||
return KVCacheBlocks(computed_blocks), num_new_computed_tokens
|
||||
|
||||
def allocate_slots(
|
||||
self,
|
||||
@@ -220,7 +228,9 @@ class KVCacheManager:
|
||||
if new_computed_blocks is not None:
|
||||
new_computed_block_list = new_computed_blocks.blocks
|
||||
else:
|
||||
new_computed_block_list = []
|
||||
new_computed_block_list = [
|
||||
[] for _ in range(len(self.kv_cache_config.kv_cache_groups))
|
||||
]
|
||||
|
||||
# Free the blocks that are skipped during the attention computation
|
||||
# (e.g., tokens outside the sliding window).
|
||||
@@ -228,8 +238,8 @@ class KVCacheManager:
|
||||
# insufficient free blocks.
|
||||
# Should call this function before allocating new blocks to reduce
|
||||
# the number of evicted blocks.
|
||||
self.single_type_manager.remove_skipped_blocks(
|
||||
request.request_id, request.num_computed_tokens)
|
||||
self.coordinator.remove_skipped_blocks(request.request_id,
|
||||
request.num_computed_tokens)
|
||||
|
||||
# The number of computed tokens is the number of computed tokens plus
|
||||
# the new prefix caching hits
|
||||
@@ -238,12 +248,12 @@ class KVCacheManager:
|
||||
num_tokens_need_slot = min(
|
||||
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
|
||||
self.max_model_len)
|
||||
num_blocks_to_allocate = (
|
||||
self.single_type_manager.get_num_blocks_to_allocate(
|
||||
request_id=request.request_id,
|
||||
num_tokens=num_tokens_need_slot,
|
||||
new_computed_blocks=new_computed_block_list,
|
||||
))
|
||||
|
||||
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
|
||||
request_id=request.request_id,
|
||||
num_tokens=num_tokens_need_slot,
|
||||
new_computed_blocks=new_computed_block_list,
|
||||
)
|
||||
|
||||
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
||||
# Cannot allocate new blocks
|
||||
@@ -253,16 +263,16 @@ class KVCacheManager:
|
||||
if self.enable_caching:
|
||||
self.block_pool.touch(new_computed_block_list)
|
||||
else:
|
||||
assert not new_computed_block_list, (
|
||||
assert all(not blocks for blocks in new_computed_block_list), (
|
||||
"Computed blocks should be empty when "
|
||||
"prefix caching is disabled")
|
||||
|
||||
# Append the new computed blocks to the request blocks until now to
|
||||
# avoid the case where the new blocks cannot be allocated.
|
||||
self.single_type_manager.save_new_computed_blocks(
|
||||
request.request_id, new_computed_block_list)
|
||||
self.coordinator.save_new_computed_blocks(request.request_id,
|
||||
new_computed_block_list)
|
||||
|
||||
new_blocks = self.single_type_manager.allocate_new_blocks(
|
||||
new_blocks = self.coordinator.allocate_new_blocks(
|
||||
request.request_id, num_tokens_need_slot)
|
||||
|
||||
# P/D: delay caching blocks if we have to recv from
|
||||
@@ -273,7 +283,7 @@ class KVCacheManager:
|
||||
# Speculated tokens might be rejected in the future, so we does
|
||||
# not cache any speculated tokens. We only cache blocks with
|
||||
# generated (accepted) tokens.
|
||||
self.single_type_manager.cache_blocks(
|
||||
self.coordinator.cache_blocks(
|
||||
request, self.req_to_block_hashes[request.request_id],
|
||||
num_computed_tokens + num_new_tokens - num_draft_tokens)
|
||||
|
||||
@@ -287,7 +297,7 @@ class KVCacheManager:
|
||||
Args:
|
||||
request: The request to free the blocks.
|
||||
"""
|
||||
self.single_type_manager.free(request.request_id)
|
||||
self.coordinator.free(request.request_id)
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset prefix cache. This function may be used in RLHF
|
||||
@@ -345,10 +355,8 @@ class KVCacheManager:
|
||||
group.
|
||||
"""
|
||||
assert request.status == RequestStatus.RUNNING
|
||||
return [
|
||||
self.single_type_manager.get_num_common_prefix_blocks(
|
||||
request.request_id, num_running_requests)
|
||||
]
|
||||
return self.coordinator.get_num_common_prefix_blocks(
|
||||
request.request_id, num_running_requests)
|
||||
|
||||
def free_block_hashes(self, request: Request) -> None:
|
||||
"""Discard the block hashes for the request.
|
||||
@@ -368,6 +376,15 @@ class KVCacheManager:
|
||||
|
||||
def get_block_ids(self, request_id: str) -> list[list[int]]:
|
||||
"""Get the block ids of a request."""
|
||||
assert request_id in self.single_type_manager.req_to_blocks
|
||||
return KVCacheBlocks(self.single_type_manager.req_to_blocks[request_id]
|
||||
).get_block_ids()
|
||||
return KVCacheBlocks(
|
||||
self.coordinator.get_blocks(request_id)).get_block_ids()
|
||||
|
||||
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
|
||||
num_computed_tokens: int) -> None:
|
||||
"""Cache the blocks for the request."""
|
||||
self.coordinator.cache_blocks(request, block_hashes,
|
||||
num_computed_tokens)
|
||||
|
||||
def create_empty_block_list(self) -> KVCacheBlocks:
|
||||
"""Creates a new KVCacheBlocks instance with no blocks."""
|
||||
return KVCacheBlocks([[] for _ in range(self.num_kv_cache_groups)])
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""KV-Cache Utilities."""
|
||||
|
||||
import os
|
||||
from collections import deque
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, NamedTuple, Optional
|
||||
@@ -33,6 +34,18 @@ class BlockHash(NamedTuple):
|
||||
extra_keys: Optional[Any] = None
|
||||
|
||||
|
||||
class BlockHashWithGroupId(NamedTuple):
|
||||
# The hash value for the contents (e.g., token_ids) of a block without group
|
||||
# ID. The value is the same for blocks representing the same tokens but for
|
||||
# different groups.
|
||||
block_hash: BlockHash
|
||||
# The KV cache group ID.
|
||||
group_id: int
|
||||
|
||||
def get_hash_value(self) -> int:
|
||||
return self.block_hash.hash_value
|
||||
|
||||
|
||||
# The hash seed for the first block of the prefix block sequence.
|
||||
#
|
||||
# Even if the hash function is the builtin hash(), we use sha256 to generate
|
||||
@@ -44,7 +57,7 @@ class BlockHash(NamedTuple):
|
||||
# This aligns with the behavior of Python's hash() function, which also uses
|
||||
# a random seed if PYTHONHASHSEED is not set.
|
||||
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv(
|
||||
'PYTHONHASHSEED') is None else sha256(os.getenv('PYTHONHASHSEED'))
|
||||
"PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED"))
|
||||
|
||||
|
||||
class PrefixCachingMetrics:
|
||||
@@ -118,7 +131,7 @@ class KVCacheBlock:
|
||||
ref_cnt: int = 0
|
||||
# The hash of the block composed of (block hash, tuple of token IDs).
|
||||
# It is only available when the block is full.
|
||||
_block_hash: Optional[BlockHash] = None
|
||||
_block_hash: Optional[BlockHashWithGroupId] = None
|
||||
|
||||
# Used to construct a doubly linked list for free blocks.
|
||||
# These two attributes should only be manipulated by FreeKVCacheBlockQueue.
|
||||
@@ -135,11 +148,11 @@ class KVCacheBlock:
|
||||
self.ref_cnt -= 1
|
||||
|
||||
@property
|
||||
def block_hash(self) -> Optional[BlockHash]:
|
||||
def block_hash(self) -> Optional[BlockHashWithGroupId]:
|
||||
return self._block_hash
|
||||
|
||||
@block_hash.setter
|
||||
def block_hash(self, block_hash: BlockHash):
|
||||
def block_hash(self, block_hash: BlockHashWithGroupId):
|
||||
assert self.block_hash is None, (
|
||||
"The block already has a hash. This should not happen.")
|
||||
self._block_hash = block_hash
|
||||
@@ -151,10 +164,10 @@ class KVCacheBlock:
|
||||
def __repr__(self) -> str:
|
||||
# Use block_id instead of KVCacheBlock object to avoid calling __repr__
|
||||
# on KVCacheBlock object recursively.
|
||||
prev_block_id = self.prev_free_block.block_id \
|
||||
if self.prev_free_block else None
|
||||
next_block_id = self.next_free_block.block_id \
|
||||
if self.next_free_block else None
|
||||
prev_block_id = (self.prev_free_block.block_id
|
||||
if self.prev_free_block else None)
|
||||
next_block_id = (self.next_free_block.block_id
|
||||
if self.next_free_block else None)
|
||||
return (f"KVCacheBlock(block_id={self.block_id}, "
|
||||
f"ref_cnt={self.ref_cnt}, "
|
||||
f"_block_hash={self._block_hash}, "
|
||||
@@ -570,20 +583,20 @@ def create_kv_cache_group_specs(
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
|
||||
"""
|
||||
Create KVCacheGroupSpec object for each kv cache group layer.
|
||||
The layers in the same group should share the same
|
||||
KVCacheSpec.
|
||||
Create KVCacheGroupSpec object for each kv cache group layer.
|
||||
The layers in the same group should share the same
|
||||
KVCacheSpec.
|
||||
|
||||
Args:
|
||||
kv_cache_spec:
|
||||
A mapping from each layer name to its corresponding KVCacheSpec.
|
||||
grouped_layer_names:
|
||||
A list of kv cache groups, where each element is a list of layer
|
||||
names that belong to the same group and should share the same
|
||||
KVCacheSpec.
|
||||
Returns:
|
||||
A list of KVCacheGroupSpec objects, one for each group.
|
||||
"""
|
||||
Args:
|
||||
kv_cache_spec:
|
||||
A mapping from each layer name to its corresponding KVCacheSpec.
|
||||
grouped_layer_names:
|
||||
A list of kv cache groups, where each element is a list of layer
|
||||
names that belong to the same group and should share the same
|
||||
KVCacheSpec.
|
||||
Returns:
|
||||
A list of KVCacheGroupSpec objects, one for each group.
|
||||
"""
|
||||
kv_cache_groups = []
|
||||
for layer_names_one_group in grouped_layer_names:
|
||||
layer_specs = [
|
||||
@@ -628,6 +641,37 @@ def get_max_concurrency_for_kv_cache_config(
|
||||
return max_concurrency
|
||||
|
||||
|
||||
def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
|
||||
available_memory: int, page_size: int) -> int:
|
||||
"""
|
||||
Get the number of kv cache blocks.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
num_layers: The number of layers
|
||||
available_memory: Memory available for KV cache in bytes.
|
||||
page_size: The page size of the KV cache.
|
||||
"""
|
||||
num_blocks = int(available_memory // page_size // num_layers)
|
||||
num_blocks = max(num_blocks, 0)
|
||||
if vllm_config.cache_config.num_gpu_blocks_override is not None:
|
||||
num_gpu_blocks_override = \
|
||||
vllm_config.cache_config.num_gpu_blocks_override
|
||||
logger.info(
|
||||
"Overriding num_gpu_blocks=%d with "
|
||||
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
|
||||
return num_blocks
|
||||
|
||||
|
||||
def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
|
||||
"""
|
||||
Get the page size of the KV cache.
|
||||
"""
|
||||
page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values())
|
||||
assert len(page_sizes) == 1
|
||||
return page_sizes.pop()
|
||||
|
||||
|
||||
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int) -> KVCacheConfig:
|
||||
@@ -644,32 +688,24 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
|
||||
The generated KVCacheConfig
|
||||
"""
|
||||
|
||||
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
|
||||
assert len(page_sizes) == 1
|
||||
page_size = page_sizes.pop()
|
||||
|
||||
num_blocks = int(available_memory // page_size // len(kv_cache_spec))
|
||||
num_blocks = max(num_blocks, 0)
|
||||
|
||||
if vllm_config.cache_config.num_gpu_blocks_override is not None:
|
||||
num_gpu_blocks_override = \
|
||||
vllm_config.cache_config.num_gpu_blocks_override
|
||||
logger.info(
|
||||
"Overriding num_gpu_blocks=%d with "
|
||||
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
|
||||
num_blocks = num_gpu_blocks_override
|
||||
page_size = get_uniform_page_size(kv_cache_spec)
|
||||
num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec),
|
||||
available_memory, page_size)
|
||||
|
||||
per_layer_size = page_size * num_blocks
|
||||
# All layers have the same KV cache spec, so we create one kv cache group
|
||||
# for all layers.
|
||||
grouped_layer_names = [list(kv_cache_spec.keys())]
|
||||
|
||||
# Each layer uses a separate Tensor to store its KV cache.
|
||||
kv_cache_tensors = [
|
||||
KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
|
||||
for layer_name in kv_cache_spec
|
||||
]
|
||||
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=num_blocks,
|
||||
tensors={
|
||||
layer_name: KVCacheTensor(size=per_layer_size)
|
||||
for layer_name in kv_cache_spec
|
||||
},
|
||||
kv_cache_tensors=kv_cache_tensors,
|
||||
kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
|
||||
grouped_layer_names),
|
||||
)
|
||||
@@ -685,17 +721,185 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
|
||||
return kv_cache_config
|
||||
|
||||
|
||||
def is_kv_cache_page_size_uniform(
|
||||
kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
|
||||
"""
|
||||
Whether all layers in the given KVCacheSpec have the same page size.
|
||||
Args:
|
||||
kv_cache_spec: The KVCacheSpec of each attention layer in the model
|
||||
|
||||
Returns:
|
||||
True if all layers have the same page size, False otherwise.
|
||||
"""
|
||||
|
||||
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
|
||||
return len(page_sizes) == 1
|
||||
|
||||
|
||||
def _get_kv_cache_config_uniform_page_size(
|
||||
vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int) -> KVCacheConfig:
|
||||
"""
|
||||
Generates the KV cache configuration for hybrid models with multiple
|
||||
attention types but still with a uniform page size (physical memory per
|
||||
block per layer) for all layers.
|
||||
|
||||
Detailed explanation about kv cache management of hybrid models:
|
||||
The layers in the models are repeated with some patterns, e.g., a model
|
||||
with 10 full attention layers and 20 sliding window attention layers can be
|
||||
regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
|
||||
The KVCacheManager allocates different block tables for each of the 3 layers
|
||||
in the pattern, and repeats each of them 10 times to generate the
|
||||
block_table for the 30 layers in the model.
|
||||
Therefore, we can group the layers in the model into 3 kv_cache_groups, each
|
||||
of which contains 10 layers in the model.
|
||||
The KVCacheManager allocates the block_table for each group based on its
|
||||
kv_cache spec, and the model runner applies the block table to each layer
|
||||
in the group.
|
||||
For example:
|
||||
1. A model only uses full attention. The pattern is
|
||||
(num_hidden_layers * full), so there is only one group and the block table
|
||||
is shared by all layers. It is already handled by
|
||||
`_get_kv_cache_config_uniform_type`.
|
||||
2. A model with 10 full attention layers and 20 sliding window
|
||||
attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
|
||||
there are 3 kv_cache_groups, each of which represents 10 layers.
|
||||
|
||||
To simplify the implementation, we make the following assumptions:
|
||||
1. Physical memory per block: Must be the same across all KV cache groups.
|
||||
Breaking this assumption is non-trivial due to memory fragmentation concerns
|
||||
when allocating blocks of different sizes.
|
||||
2. Tokens per block (block_size): Currently, we directly use
|
||||
`CacheConfig.block_size` for all layers. It can be extended to vary by KV
|
||||
cache group, but within each KV cache group, all layers must share the same
|
||||
block size.
|
||||
3. Physical memory per token per layer: This property is decided by model
|
||||
config. Currently we only support models that have the same physical memory
|
||||
per token per layer for all layers. Can be relaxed with a simple extension,
|
||||
but still need to keep physical memory per block the same for all groups.
|
||||
4. Number of layers per group: Currently assumed the same for all layers.
|
||||
Can be relaxed with a simple extension, but still need to keep physical
|
||||
memory per block the same for all groups.
|
||||
5. Attention type within groups: All layers in a group must share the same
|
||||
attention type. One exception is that, when
|
||||
`--disable-hybrid-kv-cache-manager` is true, the single group for full
|
||||
attention layers may also include attention layers using sliding window or
|
||||
LLaMA 4 local attention. See `unify_hybrid_kv_cache_specs` for more details.
|
||||
6. Support for multiple attention types: The design for most components is
|
||||
general to an arbitrary number of attention types. But
|
||||
`find_longest_cache_hit` only supports one attention type or two
|
||||
types of full-attention plus exactly one another type. The general
|
||||
implementation of this function is feasible but we don't know how to
|
||||
implement it cleanly yet.
|
||||
|
||||
As we assume tokens per block, physical memory per token per layer, and
|
||||
number of layers per group are the same now, we can ensure that physical
|
||||
memory per block is the same for all groups.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_spec: The KVCacheSpec of each attention layer in the model
|
||||
available_memory: Memory available for KV cache in bytes.
|
||||
Returns:
|
||||
The generated KVCacheConfig
|
||||
"""
|
||||
# Group all layers by type_id.
|
||||
# E.g., 2 full attention layers and 3 sliding window attention layers,
|
||||
# -> (full.0, full.1), (sw.0, sw.1, sw.2).
|
||||
same_type_layers: dict[str, list[str]] = defaultdict(list)
|
||||
for layer_name, layer_spec in kv_cache_spec.items():
|
||||
same_type_layers[layer_spec.type_id].append(layer_name)
|
||||
|
||||
# Split each group into smaller groups, to make the number of layers in each
|
||||
# group identical. Add padding to the last group of each type if necessary.
|
||||
# E.g., (full.0, full.1), (sw.0, sw.1, sw.2)
|
||||
# split to 3 groups with 2 layers each:
|
||||
# (full.0, full.1), (sw.0, sw.1), (sw.2, padding).
|
||||
# FIXME(Chen): At the moment of writing this code (2025-06-02), all
|
||||
# open-source hybrid model follows a n:1 pattern between different attention
|
||||
# types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and
|
||||
# full), so we can use the "1" in the n:1 pattern as the group size, which
|
||||
# is the minimum number of layers among all attention types. Need a better
|
||||
# strategy if we want to support more complex patterns (e.g., 20 full + 30
|
||||
# sw, where the group size should be 10).
|
||||
group_size = min([len(layers) for layers in same_type_layers.values()])
|
||||
grouped_layers = []
|
||||
for layers in same_type_layers.values():
|
||||
num_padding_layers = group_size - len(layers) % group_size
|
||||
if num_padding_layers != group_size:
|
||||
logger.warning(
|
||||
"Add %d padding layers, may waste at most %.2f%% KV cache memory", # noqa
|
||||
num_padding_layers,
|
||||
num_padding_layers / len(layers) * 100,
|
||||
)
|
||||
for i in range(0, len(layers), group_size):
|
||||
grouped_layers.append(layers[i:i + group_size])
|
||||
kv_cache_groups = create_kv_cache_group_specs(kv_cache_spec,
|
||||
grouped_layers)
|
||||
|
||||
# Determine how model runners should initialize the KV cache tensors.
|
||||
# We will have group_size memory pools, each is shared by one layer from
|
||||
# each group. As layers of different groups have different block table,
|
||||
# they will use different parts of the shared Tensor.
|
||||
# The memory layout in the example will be:
|
||||
# full.0, sw.0, sw.2: share a Tensor with size=available_memory//2
|
||||
# full.1, sw.1: share another Tensor with size=available_memory//2
|
||||
page_size = get_uniform_page_size(kv_cache_spec)
|
||||
num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
|
||||
page_size)
|
||||
per_memory_pool_size = page_size * num_blocks
|
||||
kv_cache_tensors = []
|
||||
for i in range(group_size):
|
||||
shared_by = []
|
||||
for j in range(len(kv_cache_groups)):
|
||||
if i < len(grouped_layers[j]):
|
||||
shared_by.append(grouped_layers[j][i])
|
||||
kv_cache_tensors.append(
|
||||
KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by))
|
||||
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=num_blocks,
|
||||
kv_cache_tensors=kv_cache_tensors,
|
||||
kv_cache_groups=kv_cache_groups,
|
||||
)
|
||||
|
||||
# Print the KV cache size and maximum concurrency.
|
||||
num_tokens = num_blocks // len(
|
||||
grouped_layers) * vllm_config.cache_config.block_size
|
||||
num_tokens_str = f"{num_tokens:,}"
|
||||
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
|
||||
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
|
||||
max_concurrency = get_max_concurrency_for_kv_cache_config(
|
||||
vllm_config, kv_cache_config)
|
||||
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
|
||||
max_model_len_str, max_concurrency)
|
||||
return kv_cache_config
|
||||
|
||||
|
||||
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
"""
|
||||
Only models with one type of KV cache are supported yet. This function tries
|
||||
to convert the KV cache specs to one type if the model is a hybrid model
|
||||
with multiple type of KV cache. It will convert all SlidingWindowSpec to
|
||||
FullAttentionSpec if both types are present.
|
||||
This function tries to convert the KV cache specs to one type if the model
|
||||
is a hybrid model with multiple type of KV cache. It will convert all
|
||||
SlidingWindowSpec to FullAttentionSpec if both types are present.
|
||||
|
||||
Args:
|
||||
kv_cache_spec: The kv cache spec of each attention layer in the model
|
||||
"""
|
||||
|
||||
def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
|
||||
type_ids = set(layer_spec.type_id
|
||||
for layer_spec in kv_cache_spec.values())
|
||||
return len(type_ids) > 1
|
||||
|
||||
if not is_hybrid(kv_cache_spec):
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"Hybrid KV cache manager is disabled for this hybrid model, "
|
||||
"This means we do not enable any optimizations for saving KV cache "
|
||||
"memory (e.g., dropping the KV cache outside the sliding window). "
|
||||
"The compute of layers like sliding window is still saved.")
|
||||
|
||||
has_full_attention = any(
|
||||
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
|
||||
has_sliding_window = any(
|
||||
@@ -712,13 +916,18 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
sliding_window=spec.sliding_window,
|
||||
)
|
||||
|
||||
if is_hybrid(kv_cache_spec):
|
||||
raise ValueError("Hybrid KV cache manager is disabled but failed to "
|
||||
"convert the KV cache specs to one unified type.")
|
||||
|
||||
def get_kv_cache_config(vllm_config: VllmConfig,
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int) -> KVCacheConfig:
|
||||
|
||||
def get_kv_cache_config(
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int,
|
||||
) -> KVCacheConfig:
|
||||
"""
|
||||
Generates the KV cache configuration for a model
|
||||
TODO: support hybrid models with more than one type of KV cache.
|
||||
Generates the KV cache configuration for a model.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
@@ -728,14 +937,25 @@ def get_kv_cache_config(vllm_config: VllmConfig,
|
||||
Returns:
|
||||
The generated KVCacheConfigs
|
||||
"""
|
||||
unify_hybrid_kv_cache_specs(kv_cache_spec)
|
||||
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
|
||||
|
||||
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
|
||||
unify_hybrid_kv_cache_specs(kv_cache_spec)
|
||||
|
||||
if is_kv_cache_type_uniform(kv_cache_spec):
|
||||
# KV cache of all layers are the same, which is true for
|
||||
# most models. Allocate the same amount of memory for
|
||||
# each layer.
|
||||
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
elif is_kv_cache_page_size_uniform(kv_cache_spec):
|
||||
# Model contains multiple attention types, but KV cache of all layers
|
||||
# have the same physical memory per block per layer. Split the layers
|
||||
# into groups with the same number of layers, and thus same total page
|
||||
# size.
|
||||
return _get_kv_cache_config_uniform_page_size(vllm_config,
|
||||
kv_cache_spec,
|
||||
available_memory)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
|
||||
compute_encoder_budget)
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
@@ -377,7 +377,8 @@ class Scheduler(SchedulerInterface):
|
||||
# KVTransfer: WAITING reqs have num_computed_tokens > 0
|
||||
# after async KV recvs are completed.
|
||||
else:
|
||||
new_computed_blocks = KVCacheBlocks.create_empty()
|
||||
new_computed_blocks = (
|
||||
self.kv_cache_manager.create_empty_block_list())
|
||||
num_new_local_computed_tokens = 0
|
||||
num_computed_tokens = request.num_computed_tokens
|
||||
|
||||
@@ -1010,7 +1011,7 @@ class Scheduler(SchedulerInterface):
|
||||
num_computed_tokens = len(block_ids) * self.block_size
|
||||
if num_computed_tokens == request.num_tokens:
|
||||
num_computed_tokens -= 1
|
||||
self.kv_cache_manager.single_type_manager.cache_blocks(
|
||||
self.kv_cache_manager.cache_blocks(
|
||||
request,
|
||||
self.kv_cache_manager.req_to_block_hashes[request.request_id],
|
||||
num_computed_tokens,
|
||||
|
||||
@@ -22,8 +22,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
self,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
block_pool: BlockPool,
|
||||
use_eagle: bool,
|
||||
num_kv_cache_groups: int,
|
||||
kv_cache_group_id: int,
|
||||
caching_hash_fn: Callable,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -31,9 +30,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
Args:
|
||||
kv_cache_spec: The kv_cache_spec for this manager.
|
||||
block_pool: The block pool.
|
||||
use_eagle: Whether to use eagle.
|
||||
num_kv_cache_groups: The number of kv cache groups managed by this
|
||||
manager.
|
||||
kv_cache_group_id: The id of the kv cache group of this manager.
|
||||
caching_hash_fn: The caching hash function.
|
||||
"""
|
||||
|
||||
@@ -41,9 +38,6 @@ class SingleTypeKVCacheManager(ABC):
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_pool = block_pool
|
||||
|
||||
# Needs special handling for find_longest_cache_hit if eagle is enabled
|
||||
self.use_eagle = use_eagle
|
||||
|
||||
# Mapping from request ID to blocks to track the blocks allocated
|
||||
# for each request, so that we can free the blocks when the request
|
||||
# is finished.
|
||||
@@ -56,8 +50,8 @@ class SingleTypeKVCacheManager(ABC):
|
||||
# data for reempted ones.
|
||||
self.num_cached_block: dict[str, int] = {}
|
||||
|
||||
self.num_kv_cache_groups = num_kv_cache_groups
|
||||
self.caching_hash_fn = caching_hash_fn
|
||||
self.kv_cache_group_id = kv_cache_group_id
|
||||
|
||||
def get_num_blocks_to_allocate(
|
||||
self, request_id: str, num_tokens: int,
|
||||
@@ -86,8 +80,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
num_evictable_computed_blocks = sum(
|
||||
blk.ref_cnt == 0 and not blk.is_null
|
||||
for blk in new_computed_blocks)
|
||||
return ((num_new_blocks + num_evictable_computed_blocks) *
|
||||
self.num_kv_cache_groups)
|
||||
return num_new_blocks + num_evictable_computed_blocks
|
||||
|
||||
def save_new_computed_blocks(
|
||||
self, request_id: str,
|
||||
@@ -130,8 +123,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
if num_new_blocks <= 0:
|
||||
return []
|
||||
else:
|
||||
new_blocks = self.block_pool.get_new_blocks(
|
||||
num_new_blocks * self.num_kv_cache_groups)
|
||||
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
|
||||
req_blocks.extend(new_blocks)
|
||||
return new_blocks
|
||||
|
||||
@@ -156,12 +148,19 @@ class SingleTypeKVCacheManager(ABC):
|
||||
num_cached_blocks=num_cached_blocks,
|
||||
num_full_blocks=num_full_blocks,
|
||||
block_size=self.block_size,
|
||||
kv_cache_group_id=self.kv_cache_group_id,
|
||||
hash_fn=self.caching_hash_fn,
|
||||
)
|
||||
|
||||
self.num_cached_block[request.request_id] = num_full_blocks
|
||||
|
||||
def free(self, request_id: str) -> None:
|
||||
"""
|
||||
Free the blocks for the request.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
"""
|
||||
# Default to [] in case a request is freed (aborted) before alloc.
|
||||
req_blocks = self.req_to_blocks.pop(request_id, [])
|
||||
|
||||
@@ -188,12 +187,22 @@ class SingleTypeKVCacheManager(ABC):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
|
||||
max_length: int) -> list[KVCacheBlock]:
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: list[BlockHash],
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
) -> list[list[KVCacheBlock]]:
|
||||
"""
|
||||
Get the longest cache hit prefix of the blocks that is not longer than
|
||||
`max_length`. If no cache hit is found, return an empty list.
|
||||
`max_length`. The prefix should be a common prefix hit for all the
|
||||
kv cache groups in `kv_cache_group_ids`. If no cache hit is found,
|
||||
return an empty list.
|
||||
If eagle is enabled, drop the last matched block to force recompute the
|
||||
last block to get the required hidden states for eagle drafting head.
|
||||
Need to be customized for each attention type.
|
||||
@@ -201,12 +210,20 @@ class SingleTypeKVCacheManager(ABC):
|
||||
Args:
|
||||
block_hashes: The block hashes of the request.
|
||||
max_length: The maximum length of the cache hit prefix.
|
||||
kv_cache_group_ids: The ids of the kv cache groups.
|
||||
block_pool: The block pool.
|
||||
kv_cache_spec: The kv cache spec.
|
||||
use_eagle: Whether to use eagle.
|
||||
|
||||
Returns:
|
||||
A list of cached blocks with skipped blocks replaced by null block.
|
||||
A list of cached blocks with skipped blocks replaced by null block
|
||||
for each kv cache group in `kv_cache_group_ids`.
|
||||
Return a list of length `len(kv_cache_group_ids)`, where the i-th
|
||||
element is a list of cached blocks for the i-th kv cache group
|
||||
in `kv_cache_group_ids`.
|
||||
For example, sliding window manager should return a list like
|
||||
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and
|
||||
sliding window 8.
|
||||
[[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]] for block size 4
|
||||
and sliding window 8 and len(kv_cache_group_ids) = 1.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
@@ -215,11 +232,9 @@ class SingleTypeKVCacheManager(ABC):
|
||||
def remove_skipped_blocks(self, request_id: str,
|
||||
num_computed_tokens: int) -> None:
|
||||
"""
|
||||
Remove the blocks that are no longer needed from `blocks`. The removed
|
||||
blocks should be replaced by null_block. Return the removed blocks in
|
||||
eviction order, where the first returned block should be evicted first.
|
||||
Don't free the removed blocks in this function. Need to be customized
|
||||
for each attention type.
|
||||
Remove the blocks that are no longer needed from `blocks` and free the
|
||||
blocks. The removed blocks should be replaced by null_block.
|
||||
Need to be customized for each attention type.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
@@ -230,21 +245,36 @@ class SingleTypeKVCacheManager(ABC):
|
||||
|
||||
class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
|
||||
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
|
||||
max_length: int) -> list[KVCacheBlock]:
|
||||
computed_blocks: list[KVCacheBlock] = []
|
||||
max_num_blocks = max_length // self.block_size
|
||||
@classmethod
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: list[BlockHash],
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
) -> list[list[KVCacheBlock]]:
|
||||
assert isinstance(kv_cache_spec, FullAttentionSpec), (
|
||||
"FullAttentionManager can only be used for full attention groups")
|
||||
computed_blocks: list[list[KVCacheBlock]] = [
|
||||
[] for _ in range(len(kv_cache_group_ids))
|
||||
]
|
||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||
for i in range(max_num_blocks):
|
||||
block_hash = block_hashes[i]
|
||||
# block_hashes is a chain of block hashes. If a block hash is not
|
||||
# in the cached_block_hash_to_id, the following block hashes are
|
||||
# not computed yet for sure.
|
||||
if cached_block := self.block_pool.get_cached_block(block_hash):
|
||||
computed_blocks.append(cached_block)
|
||||
if cached_block := block_pool.get_cached_block(
|
||||
block_hash, kv_cache_group_ids):
|
||||
for j in range(len(kv_cache_group_ids)):
|
||||
computed_blocks[j].append(cached_block[j])
|
||||
else:
|
||||
break
|
||||
if self.use_eagle and len(computed_blocks) > 0:
|
||||
computed_blocks.pop()
|
||||
if use_eagle and len(computed_blocks[0]) > 0:
|
||||
for j in range(len(kv_cache_group_ids)):
|
||||
computed_blocks[j].pop()
|
||||
return computed_blocks
|
||||
|
||||
def remove_skipped_blocks(self, request_id: str,
|
||||
@@ -267,45 +297,58 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
|
||||
def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
|
||||
use_eagle: bool, **kwargs) -> None:
|
||||
super().__init__(kv_cache_spec, block_pool, use_eagle, **kwargs)
|
||||
**kwargs) -> None:
|
||||
super().__init__(kv_cache_spec, block_pool, **kwargs)
|
||||
self.sliding_window = kv_cache_spec.sliding_window
|
||||
self._null_block = block_pool.null_block
|
||||
|
||||
@classmethod
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: list[BlockHash],
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
) -> list[list[KVCacheBlock]]:
|
||||
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
|
||||
"SlidingWindowManager can only be used for sliding window groups")
|
||||
|
||||
# The number of contiguous blocks needed for prefix cache hit.
|
||||
# -1 since the input token itself is also included in the window
|
||||
self.sliding_window_contiguous_blocks = cdiv(
|
||||
(kv_cache_spec.sliding_window - 1), self.block_size)
|
||||
if self.use_eagle:
|
||||
sliding_window_contiguous_blocks = cdiv(
|
||||
kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size)
|
||||
if use_eagle:
|
||||
# Need to drop the last matched block if eagle is enabled. For
|
||||
# sliding window layer, we achieve this by increasing the number of
|
||||
# contiguous blocks needed for prefix cache hit by one and dropping
|
||||
# the last matched block.
|
||||
self.sliding_window_contiguous_blocks += 1
|
||||
self._null_block = block_pool.null_block
|
||||
sliding_window_contiguous_blocks += 1
|
||||
|
||||
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
|
||||
max_length: int) -> list[KVCacheBlock]:
|
||||
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
|
||||
# optimize the time complexity from O(max_num_blocks) to
|
||||
# O(max_num_blocks / sliding_window_contiguous_blocks +
|
||||
# sliding_window_contiguous_blocks),
|
||||
# which is good for low cache hit rate scenarios.
|
||||
max_num_blocks = max_length // self.block_size
|
||||
computed_blocks = [self._null_block] * max_num_blocks
|
||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||
computed_blocks = [[block_pool.null_block] * max_num_blocks
|
||||
for _ in range(len(kv_cache_group_ids))]
|
||||
num_contiguous_blocks = 0
|
||||
|
||||
match_found = False
|
||||
# Search from right to left and early stop when a match is found.
|
||||
for i in range(max_num_blocks - 1, -1, -1):
|
||||
if cached_block := self.block_pool.get_cached_block(
|
||||
block_hashes[i]):
|
||||
computed_blocks[i] = cached_block
|
||||
if cached_block := block_pool.get_cached_block(
|
||||
block_hashes[i], kv_cache_group_ids):
|
||||
for j in range(len(kv_cache_group_ids)):
|
||||
computed_blocks[j][i] = cached_block[j]
|
||||
num_contiguous_blocks += 1
|
||||
if (num_contiguous_blocks
|
||||
>= self.sliding_window_contiguous_blocks):
|
||||
if (num_contiguous_blocks >= sliding_window_contiguous_blocks):
|
||||
# Trim the trailing blocks.
|
||||
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
|
||||
# when sliding_window_contiguous_blocks=2.
|
||||
del computed_blocks[i + num_contiguous_blocks:]
|
||||
for j in range(len(kv_cache_group_ids)):
|
||||
del computed_blocks[j][i + num_contiguous_blocks:]
|
||||
match_found = True
|
||||
break
|
||||
else:
|
||||
@@ -313,9 +356,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
if not match_found:
|
||||
# The first `num_contiguous_blocks` is a cache hit even if
|
||||
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
|
||||
del computed_blocks[num_contiguous_blocks:]
|
||||
if self.use_eagle and len(computed_blocks) > 0:
|
||||
computed_blocks.pop()
|
||||
for j in range(len(kv_cache_group_ids)):
|
||||
del computed_blocks[j][num_contiguous_blocks:]
|
||||
if use_eagle and len(computed_blocks[0]) > 0:
|
||||
for j in range(len(kv_cache_group_ids)):
|
||||
computed_blocks[j].pop()
|
||||
return computed_blocks
|
||||
|
||||
def remove_skipped_blocks(self, request_id: str,
|
||||
|
||||
Reference in New Issue
Block a user