[v1] Hybrid Memory Allocator (#17996)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-06-06 11:47:09 +08:00
committed by GitHub
parent 3465b87ef8
commit f8a1a2d108
21 changed files with 1605 additions and 440 deletions

View File

@@ -7,8 +7,8 @@ from typing import Callable, Optional
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
BlockStored, KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, FreeKVCacheBlockQueue,
KVCacheBlock,
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
FreeKVCacheBlockQueue, KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens)
from vllm.v1.request import Request
@@ -27,6 +27,7 @@ class BlockPool:
Args:
num_gpu_blocks: The number of blocks in the pool.
enable_caching: Whether to enable prefix caching.
enable_kv_cache_events: Whether to enable kv cache events.
"""
def __init__(
@@ -56,7 +57,7 @@ class BlockPool:
# if there is already an identical block in the cache. This is because
# we want to make sure the allocated block IDs won't change so that
# block tables are append-only.
self.cached_block_hash_to_block: dict[BlockHash, dict[
self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[
int, KVCacheBlock]] = defaultdict(dict)
# To represent a placeholder block with block_id=0.
@@ -68,22 +69,29 @@ class BlockPool:
self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue: list[KVCacheEvent] = []
def get_cached_block(self,
block_hash: BlockHash) -> Optional[KVCacheBlock]:
"""Get a cached block by the block hash, or None if cache miss.
def get_cached_block(
self, block_hash: BlockHash,
kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]:
"""Get the cached block by the block hash for each group in
`kv_cache_group_ids`, or None if cache miss for any group.
If there are duplicated blocks, we return the first block in the cache.
Args:
block_hash: The hash value of the block.
kv_cache_group_ids: The ids of the KV cache groups.
Returns:
The cached block if it exists, or None.
The cached blocks if they exist, or None.
"""
cached_blocks = self.cached_block_hash_to_block.get(block_hash)
if not cached_blocks:
return None
first_block_id = next(iter(cached_blocks))
return cached_blocks[first_block_id]
cached_blocks = []
for group_id in kv_cache_group_ids:
cached_blocks_one_group = self.cached_block_hash_to_block.get(
BlockHashWithGroupId(block_hash, group_id))
if not cached_blocks_one_group:
return None
first_block_id = next(iter(cached_blocks_one_group))
cached_blocks.append(cached_blocks_one_group[first_block_id])
return cached_blocks
def cache_full_blocks(
self,
@@ -93,6 +101,7 @@ class BlockPool:
num_cached_blocks: int,
num_full_blocks: int,
block_size: int,
kv_cache_group_id: int,
hash_fn: Callable,
) -> None:
"""Cache a list of full blocks for prefix caching.
@@ -112,6 +121,7 @@ class BlockPool:
num_full_blocks: The number of blocks that are full and should
be cached after this function.
block_size: Number of tokens in each block.
kv_cache_group_id: The id of the KV cache group.
hash_fn: The hash function to use for block hashes.
"""
if num_cached_blocks == num_full_blocks:
@@ -126,7 +136,7 @@ class BlockPool:
else:
prev_block = blocks[num_cached_blocks - 1]
assert prev_block.block_hash is not None
prev_block_hash_value = prev_block.block_hash.hash_value
prev_block_hash_value = prev_block.block_hash.get_hash_value()
parent_block_hash = prev_block_hash_value
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
@@ -138,8 +148,9 @@ class BlockPool:
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
# generated tokens with preemption), or by other
# single_type_managers with the same block_size.
# In this case we simply reuse the block hash.
block_hash = new_block_hashes[i]
else:
# Otherwise compute the block hash and cache it in the request
@@ -166,8 +177,11 @@ class BlockPool:
block_hashes.append(block_hash)
# Update and add the full block to the cache.
blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
block_hash_with_group_id = BlockHashWithGroupId(
block_hash, kv_cache_group_id)
blk.block_hash = block_hash_with_group_id
self.cached_block_hash_to_block[block_hash_with_group_id][
blk.block_id] = blk
if new_hashes is not None:
new_hashes.append(block_hash.hash_value)
prev_block_hash_value = block_hash.hash_value
@@ -237,12 +251,16 @@ class BlockPool:
del self.cached_block_hash_to_block[block_hash]
if self.enable_kv_cache_events:
# FIXME (Chen): Not sure whether we should return `hash_value`
# or `(hash_value, group_id)` here. But it's fine now because
# we disable hybrid kv cache manager when kv cache event is
# enabled, so there is only one group.
self.kv_event_queue.append(
BlockRemoved(block_hashes=[block_hash.hash_value]))
BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
return True
return False
def touch(self, blocks: list[KVCacheBlock]) -> None:
def touch(self, blocks: list[list[KVCacheBlock]]) -> None:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
@@ -250,12 +268,13 @@ class BlockPool:
Args:
blocks: A list of blocks to touch.
"""
for block in blocks:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block)
block.incr_ref()
for blocks_per_group in blocks:
for block in blocks_per_group:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block)
block.incr_ref()
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
"""Free a list of blocks. The blocks should be ordered by their

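To make the new (block hash, group id) keying concrete, here is a minimal, self-contained sketch of the lookup behavior that `get_cached_block` implements above. The classes and the `cache` dict are simplified stand-ins, not the actual vLLM types: a lookup succeeds only if every requested KV cache group has a cached block for the same hash.

from collections import defaultdict
from typing import NamedTuple, Optional

class ToyBlockHash(NamedTuple):
    hash_value: int
    token_ids: tuple

class ToyBlockHashWithGroupId(NamedTuple):
    block_hash: ToyBlockHash
    group_id: int

# Stand-in for cached_block_hash_to_block: (hash, group_id) -> {block_id: block}
cache: dict = defaultdict(dict)
h = ToyBlockHash(hash_value=123, token_ids=(1, 2, 3, 4))
cache[ToyBlockHashWithGroupId(h, 0)][7] = "block-7 (group 0)"
cache[ToyBlockHashWithGroupId(h, 1)][9] = "block-9 (group 1)"

def get_cached_blocks(block_hash, group_ids) -> Optional[list]:
    """Return one cached block per group, or None if any group misses."""
    result = []
    for group_id in group_ids:
        blocks = cache.get(ToyBlockHashWithGroupId(block_hash, group_id))
        if not blocks:
            # A miss in any group is a miss for the whole lookup.
            return None
        result.append(next(iter(blocks.values())))
    return result

print(get_cached_blocks(h, [0, 1]))  # hit in both groups
print(get_cached_blocks(h, [0, 2]))  # None: group 2 has no entry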
View File

@@ -0,0 +1,358 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Callable, Optional
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import (
FullAttentionManager, SingleTypeKVCacheManager,
get_manager_for_kv_cache_spec)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
from vllm.v1.request import Request
class KVCacheCoordinator(ABC):
"""
Coordinate the KV cache of different KV cache groups.
"""
def __init__(
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
use_eagle: bool,
enable_caching: bool,
caching_hash_fn: Callable,
enable_kv_cache_events: bool,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
enable_kv_cache_events)
self.single_type_managers: list[SingleTypeKVCacheManager] = []
# Needs special handling for find_longest_cache_hit if eagle is enabled
self.use_eagle = use_eagle
for i in range(len(self.kv_cache_config.kv_cache_groups)):
kv_cache_spec = self.kv_cache_config.kv_cache_groups[
i].kv_cache_spec
self.single_type_managers.append(
get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
caching_hash_fn=caching_hash_fn,
))
def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
new_computed_blocks: list[list[KVCacheBlock]]) -> int:
"""
Get the number of blocks needed to be allocated for the request.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
Returns:
The number of blocks.
"""
num_blocks_to_allocate = 0
for i, manager in enumerate(self.single_type_managers):
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks[i])
return num_blocks_to_allocate
def save_new_computed_blocks(
self, request_id: str,
new_computed_blocks: list[list[KVCacheBlock]]) -> None:
"""
Add the new computed blocks to the request.
Args:
request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the
prefix cache.
"""
for i, manager in enumerate(self.single_type_managers):
manager.save_new_computed_blocks(request_id,
new_computed_blocks[i])
def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> list[list[KVCacheBlock]]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
Returns:
The new allocated blocks.
"""
new_blocks = []
for manager in self.single_type_managers:
new_blocks.append(
manager.allocate_new_blocks(request_id, num_tokens))
return new_blocks
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_computed_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
block_hashes: The block hashes of the request.
num_computed_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
"""
for manager in self.single_type_managers:
manager.cache_blocks(request, block_hashes, num_computed_tokens)
def free(self, request_id: str) -> None:
"""
Free the blocks for the request.
Args:
request_id: The request ID.
"""
for manager in self.single_type_managers:
manager.free(request_id)
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> list[int]:
"""
Get the number of common prefix blocks for a request.
Args:
request_id: The request ID.
num_running_requests: The number of requests in the RUNNING state.
Returns:
The number of common prefix blocks for each kv cache group.
"""
num_blocks_per_group = [
manager.get_num_common_prefix_blocks(request_id,
num_running_requests)
for manager in self.single_type_managers
]
return num_blocks_per_group
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
"""
Remove the blocks that are no longer needed from `blocks` and replace
the removed blocks with null_block.
Args:
request_id: The request ID.
num_computed_tokens: The number of tokens that have been computed.
"""
for manager in self.single_type_managers:
manager.remove_skipped_blocks(request_id, num_computed_tokens)
def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]:
"""
Get the blocks for the request.
"""
return [
manager.req_to_blocks[request_id]
for manager in self.single_type_managers
]
@abstractmethod
def find_longest_cache_hit(
self, block_hashes: list[BlockHash],
max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
pass
class UnitaryKVCacheCoordinator(KVCacheCoordinator):
"""
KV cache coordinator for models with only one KV cache group. This is the
case for models with only one KV cache type, e.g., all attention layers use
full attention or all attention layers use sliding window attention.
"""
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
caching_hash_fn: Callable, enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group")
def find_longest_cache_hit(
self, block_hashes: list[BlockHash],
max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
block_hashes=block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=[0],
block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle,
)
return hit_blocks, len(hit_blocks[0]) * self.block_size
class HybridKVCacheCoordinator(KVCacheCoordinator):
"""
KV cache coordinator for hybrid models with multiple KV cache types, and
thus multiple kv cache groups.
To simplify `find_longest_cache_hit`, it only supports the combination of
exactly two types of KV cache groups, one of which must be full attention.
This may be extended to more general cases in the future.
"""
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
caching_hash_fn: Callable, enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)
self.verify_and_split_kv_cache_groups()
def verify_and_split_kv_cache_groups(self) -> None:
"""
Verifies that the model has exactly two types of KV cache groups, and
one of them is full attention. Then, splits the kv cache groups into full
attention groups and other groups.
"""
full_attention_type_id: Optional[str] = None
other_type_id: Optional[str] = None
self.full_attention_group_ids: list[int] = []
self.other_group_ids: list[int] = []
for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
if isinstance(g.kv_cache_spec, FullAttentionSpec):
if full_attention_type_id is None:
full_attention_type_id = g.kv_cache_spec.type_id
else:
assert full_attention_type_id == g.kv_cache_spec.type_id, (
"HybridKVCacheCoordinator assumes exactly one type of "
"full attention groups now.")
self.full_attention_group_ids.append(i)
else:
if other_type_id is None:
other_type_id = g.kv_cache_spec.type_id
else:
assert other_type_id == g.kv_cache_spec.type_id, (
"HybridKVCacheCoordinator assumes "
"exactly one other type of groups now.")
self.other_group_ids.append(i)
assert full_attention_type_id is not None, (
"HybridKVCacheCoordinator assumes exactly one type of full "
"attention groups now.")
assert other_type_id is not None, (
"HybridKVCacheCoordinator assumes exactly one type of other "
"groups now.")
self.full_attention_manager_cls = FullAttentionManager
self.other_attention_cls = self.single_type_managers[
self.other_group_ids[0]].__class__
self.full_attention_spec = self.kv_cache_config.kv_cache_groups[
self.full_attention_group_ids[0]].kv_cache_spec
self.other_spec = self.kv_cache_config.kv_cache_groups[
self.other_group_ids[0]].kv_cache_spec
self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size
assert self.other_block_size % self.full_attention_block_size == 0, (
"KVCacheCoordinator assumes the block_size of full attention "
"layers is divisible by other layers now.")
def find_longest_cache_hit(
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[list[list[KVCacheBlock]], int]:
"""
Find the longest cache hit for the request.
Args:
block_hashes: The block hashes of the request.
max_cache_hit_length: The maximum length of the cache hit.
Returns:
A tuple containing:
- A list of the cache hit blocks for each single type manager.
- The number of tokens of the longest cache hit.
"""
# First, find the longest cache hit for full attention.
hit_blocks_full_attn = (
self.full_attention_manager_cls.find_longest_cache_hit(
block_hashes=block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=self.full_attention_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.full_attention_spec,
use_eagle=self.use_eagle,
))
hit_length = len(
hit_blocks_full_attn[0]) * self.full_attention_block_size
# Next, find the cache hit for the other attention WITHIN
# the cache hit of full attention.
hit_blocks_other_attn = (
self.other_attention_cls.find_longest_cache_hit(
block_hashes=block_hashes,
max_length=hit_length,
kv_cache_group_ids=self.other_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.other_spec,
use_eagle=self.use_eagle,
))
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
# NOTE: the prefix cache hit length must be a multiple of block_size as
# we don't support partial block cache hits yet. The cache hit length
# of other attention is ensured to be a multiple of the block size of
# full attention layers in the current implementation, because hit_length
# is a multiple of other attention's block size, and other attention's
# block size is a multiple of full attention's block size (verified in
# `verify_and_split_kv_cache_groups`).
assert hit_length % self.full_attention_block_size == 0
# Truncate the full attention cache hit to the length of the
# cache hit of the other attention.
for i in range(len(hit_blocks_full_attn)):
del hit_blocks_full_attn[i][hit_length //
self.full_attention_block_size:]
# Merge the hit blocks of full attention and other attention.
hit_blocks = hit_blocks_other_attn
for group_id, blocks in enumerate(hit_blocks_full_attn):
# NOTE: there is only one full attention group in most cases. So
# the time complexity of insert is fine.
hit_blocks.insert(group_id, blocks)
return hit_blocks, hit_length
def get_kv_cache_coordinator(
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
enable_caching: bool, caching_hash_fn: Callable,
enable_kv_cache_events: bool) -> KVCacheCoordinator:
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
use_eagle, enable_caching,
caching_hash_fn,
enable_kv_cache_events)
else:
return HybridKVCacheCoordinator(kv_cache_config, max_model_len,
use_eagle, enable_caching,
caching_hash_fn,
enable_kv_cache_events)

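The two-pass hit search in HybridKVCacheCoordinator.find_longest_cache_hit is easiest to follow with concrete numbers. The sketch below uses made-up block sizes and hit counts (only the arithmetic mirrors the code above) to show why the final hit length is always a multiple of the full-attention block size.

# Assumed sizes for illustration only.
full_block_size = 16    # tokens per full-attention block
other_block_size = 32   # tokens per sliding-window block
assert other_block_size % full_block_size == 0

# Pass 1: full attention hits 10 blocks -> 160 tokens.
num_full_hit_blocks = 10
hit_length = num_full_hit_blocks * full_block_size

# Pass 2: the other attention type is searched only within those 160 tokens
# and hits 3 blocks -> 96 tokens.
num_other_hit_blocks = 3
hit_length = num_other_hit_blocks * other_block_size

# 96 is a multiple of 16, so the full-attention hit can be truncated exactly.
assert hit_length % full_block_size == 0
num_full_hit_blocks = hit_length // full_block_size  # keep only 6 blocks
print(num_full_hit_blocks, hit_length)               # 6 96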
View File

@@ -8,11 +8,9 @@ from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
from vllm.utils import sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
hash_request_tokens)
from vllm.v1.core.single_type_kv_cache_manager import (
get_manager_for_kv_cache_spec)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus
@@ -22,16 +20,24 @@ logger = init_logger(__name__)
@dataclass
class KVCacheBlocks:
blocks: list[KVCacheBlock]
"""
The allocation result of KVCacheManager, working as the interface between
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
structure from the Scheduler.
"""
blocks: list[list[KVCacheBlock]]
"""
blocks[i][j] refers to the j-th block of the i-th kv_cache_group.
We don't use the block index as the outer dimension because that would
assume all kv_cache_groups have the same number of blocks, which is true
for now but would break if we want to give different block_sizes to
different kv_cache_groups in the future.
"""
def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
"""Adds two KVCacheBlocks instances."""
return KVCacheBlocks(self.blocks + other.blocks)
@classmethod
def create_empty(cls) -> "KVCacheBlocks":
"""Creates a new KVCacheBlocks instance with no blocks."""
return cls([])
return KVCacheBlocks(
[blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)])
def get_block_ids(self) -> list[list[int]]:
"""
@@ -39,15 +45,20 @@ class KVCacheBlocks:
Returns:
list[list[int]]: A two-level list where
* the outer list corresponds to KV cache groups (only 1 group now)
* the outer list corresponds to KV cache groups
* each inner list contains the block_ids of the blocks in that group
"""
return [[block.block_id for block in self.blocks]]
block_ids = []
for group in self.blocks:
block_ids.append([blk.block_id for blk in group])
return block_ids
def get_unhashed_block_ids(self) -> list[int]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
assert len(self.blocks) == 1, "Only one group is supported"
return [
block.block_id for block in self.blocks if block.block_hash is None
block.block_id for block in self.blocks[0]
if block.block_hash is None
]
@@ -63,12 +74,6 @@ class KVCacheManager:
log_stats: bool = False,
enable_kv_cache_events: bool = False,
) -> None:
assert len(kv_cache_config.kv_cache_groups) == 1, (
"KVCacheManager does not support hybrid models with more than 1 "
"kv cache group")
kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = kv_cache_spec.block_size
self.num_gpu_blocks = kv_cache_config.num_blocks
self.max_model_len = max_model_len
self.enable_caching = enable_caching
@@ -77,17 +82,24 @@ class KVCacheManager:
self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
assert len(
set(g.kv_cache_spec.block_size
for g in kv_cache_config.kv_cache_groups)
) == 1, "Only one block size is supported for now"
self.block_size = kv_cache_config.kv_cache_groups[
0].kv_cache_spec.block_size
self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching,
enable_kv_cache_events)
self.single_type_manager = get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
use_eagle=self.use_eagle,
num_kv_cache_groups=1,
enable_caching=enable_caching,
caching_hash_fn=self.caching_hash_fn,
enable_kv_cache_events=enable_kv_cache_events,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool
self.kv_cache_config = kv_cache_config
# Mapping from request ID to kv block hashes.
# This is to avoid recomputing the block hashes for each call of
@@ -133,7 +145,7 @@ class KVCacheManager:
# When the request requires prompt logprobs, we skip prefix caching.
if (not self.enable_caching
or request.sampling_params.prompt_logprobs is not None):
return KVCacheBlocks.create_empty(), 0
return self.create_empty_block_list(), 0
# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.
@@ -154,20 +166,16 @@ class KVCacheManager:
# num_computed_tokens to be block-size aligned. Removing this limitation
# could slightly improve performance in the future.
max_cache_hit_length = request.num_tokens - 1
computed_blocks = self.single_type_manager.find_longest_cache_hit(
block_hashes, max_cache_hit_length)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
computed_blocks, num_new_computed_tokens = (
self.coordinator.find_longest_cache_hit(block_hashes,
max_cache_hit_length))
if self.log_stats:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.queries += request.num_tokens
self.prefix_cache_stats.hits += num_computed_tokens
self.prefix_cache_stats.hits += num_new_computed_tokens
return KVCacheBlocks(computed_blocks), num_computed_tokens
return KVCacheBlocks(computed_blocks), num_new_computed_tokens
def allocate_slots(
self,
@@ -220,7 +228,9 @@ class KVCacheManager:
if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
else:
new_computed_block_list = []
new_computed_block_list = [
[] for _ in range(len(self.kv_cache_config.kv_cache_groups))
]
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
@@ -228,8 +238,8 @@ class KVCacheManager:
# insufficient free blocks.
# Should call this function before allocating new blocks to reduce
# the number of evicted blocks.
self.single_type_manager.remove_skipped_blocks(
request.request_id, request.num_computed_tokens)
self.coordinator.remove_skipped_blocks(request.request_id,
request.num_computed_tokens)
# The number of computed tokens is the number of previously computed
# tokens plus the new prefix caching hits
@@ -238,12 +248,12 @@ class KVCacheManager:
num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)
num_blocks_to_allocate = (
self.single_type_manager.get_num_blocks_to_allocate(
request_id=request.request_id,
num_tokens=num_tokens_need_slot,
new_computed_blocks=new_computed_block_list,
))
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
request_id=request.request_id,
num_tokens=num_tokens_need_slot,
new_computed_blocks=new_computed_block_list,
)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
# Cannot allocate new blocks
@@ -253,16 +263,16 @@ class KVCacheManager:
if self.enable_caching:
self.block_pool.touch(new_computed_block_list)
else:
assert not new_computed_block_list, (
assert all(not blocks for blocks in new_computed_block_list), (
"Computed blocks should be empty when "
"prefix caching is disabled")
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
self.single_type_manager.save_new_computed_blocks(
request.request_id, new_computed_block_list)
self.coordinator.save_new_computed_blocks(request.request_id,
new_computed_block_list)
new_blocks = self.single_type_manager.allocate_new_blocks(
new_blocks = self.coordinator.allocate_new_blocks(
request.request_id, num_tokens_need_slot)
# P/D: delay caching blocks if we have to recv from
@@ -273,7 +283,7 @@ class KVCacheManager:
# Speculated tokens might be rejected in the future, so we do
# not cache any speculated tokens. We only cache blocks with
# generated (accepted) tokens.
self.single_type_manager.cache_blocks(
self.coordinator.cache_blocks(
request, self.req_to_block_hashes[request.request_id],
num_computed_tokens + num_new_tokens - num_draft_tokens)
@@ -287,7 +297,7 @@ class KVCacheManager:
Args:
request: The request to free the blocks.
"""
self.single_type_manager.free(request.request_id)
self.coordinator.free(request.request_id)
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
@@ -345,10 +355,8 @@ class KVCacheManager:
group.
"""
assert request.status == RequestStatus.RUNNING
return [
self.single_type_manager.get_num_common_prefix_blocks(
request.request_id, num_running_requests)
]
return self.coordinator.get_num_common_prefix_blocks(
request.request_id, num_running_requests)
def free_block_hashes(self, request: Request) -> None:
"""Discard the block hashes for the request.
@@ -368,6 +376,15 @@ class KVCacheManager:
def get_block_ids(self, request_id: str) -> list[list[int]]:
"""Get the block ids of a request."""
assert request_id in self.single_type_manager.req_to_blocks
return KVCacheBlocks(self.single_type_manager.req_to_blocks[request_id]
).get_block_ids()
return KVCacheBlocks(
self.coordinator.get_blocks(request_id)).get_block_ids()
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_computed_tokens: int) -> None:
"""Cache the blocks for the request."""
self.coordinator.cache_blocks(request, block_hashes,
num_computed_tokens)
def create_empty_block_list(self) -> KVCacheBlocks:
"""Creates a new KVCacheBlocks instance with no blocks."""
return KVCacheBlocks([[] for _ in range(self.num_kv_cache_groups)])

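As a quick illustration of the group-major layout `blocks[i][j]` documented above, here is a simplified stand-in (not the real KVCacheBlocks/KVCacheBlock classes) showing why `__add__` concatenates per group and how `get_block_ids` produces a two-level list.

from dataclasses import dataclass

@dataclass
class ToyKVCacheBlocks:
    # blocks[i][j]: the j-th block id of the i-th kv_cache_group
    blocks: list

    def __add__(self, other: "ToyKVCacheBlocks") -> "ToyKVCacheBlocks":
        # Concatenate group-wise so group i stays group i.
        return ToyKVCacheBlocks(
            [b1 + b2 for b1, b2 in zip(self.blocks, other.blocks)])

    def get_block_ids(self) -> list:
        return [list(group) for group in self.blocks]

computed = ToyKVCacheBlocks([[1, 2], [10]])   # two kv cache groups
new = ToyKVCacheBlocks([[3], [11, 12]])
print((computed + new).get_block_ids())       # [[1, 2, 3], [10, 11, 12]]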
View File

@@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV-Cache Utilities."""
import os
from collections import deque
from collections import defaultdict, deque
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple, Optional
@@ -33,6 +34,18 @@ class BlockHash(NamedTuple):
extra_keys: Optional[Any] = None
class BlockHashWithGroupId(NamedTuple):
# The hash value for the contents (e.g., token_ids) of a block, without the
# group ID. The value is the same for blocks representing the same tokens
# in different groups.
block_hash: BlockHash
# The KV cache group ID.
group_id: int
def get_hash_value(self) -> int:
return self.block_hash.hash_value
# The hash seed for the first block of the prefix block sequence.
#
# Even if the hash function is the builtin hash(), we use sha256 to generate
@@ -44,7 +57,7 @@ class BlockHash(NamedTuple):
# This aligns with the behavior of Python's hash() function, which also uses
# a random seed if PYTHONHASHSEED is not set.
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv(
'PYTHONHASHSEED') is None else sha256(os.getenv('PYTHONHASHSEED'))
"PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED"))
class PrefixCachingMetrics:
@@ -118,7 +131,7 @@ class KVCacheBlock:
ref_cnt: int = 0
# The hash of the block composed of (block hash, tuple of token IDs).
# It is only available when the block is full.
_block_hash: Optional[BlockHash] = None
_block_hash: Optional[BlockHashWithGroupId] = None
# Used to construct a doubly linked list for free blocks.
# These two attributes should only be manipulated by FreeKVCacheBlockQueue.
@@ -135,11 +148,11 @@ class KVCacheBlock:
self.ref_cnt -= 1
@property
def block_hash(self) -> Optional[BlockHash]:
def block_hash(self) -> Optional[BlockHashWithGroupId]:
return self._block_hash
@block_hash.setter
def block_hash(self, block_hash: BlockHash):
def block_hash(self, block_hash: BlockHashWithGroupId):
assert self.block_hash is None, (
"The block already has a hash. This should not happen.")
self._block_hash = block_hash
@@ -151,10 +164,10 @@ class KVCacheBlock:
def __repr__(self) -> str:
# Use block_id instead of KVCacheBlock object to avoid calling __repr__
# on KVCacheBlock object recursively.
prev_block_id = self.prev_free_block.block_id \
if self.prev_free_block else None
next_block_id = self.next_free_block.block_id \
if self.next_free_block else None
prev_block_id = (self.prev_free_block.block_id
if self.prev_free_block else None)
next_block_id = (self.next_free_block.block_id
if self.next_free_block else None)
return (f"KVCacheBlock(block_id={self.block_id}, "
f"ref_cnt={self.ref_cnt}, "
f"_block_hash={self._block_hash}, "
@@ -570,20 +583,20 @@ def create_kv_cache_group_specs(
kv_cache_spec: dict[str, KVCacheSpec],
grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
"""
Create KVCacheGroupSpec object for each kv cache group layer.
The layers in the same group should share the same
KVCacheSpec.
Create KVCacheGroupSpec object for each kv cache group layer.
The layers in the same group should share the same
KVCacheSpec.
Args:
kv_cache_spec:
A mapping from each layer name to its corresponding KVCacheSpec.
grouped_layer_names:
A list of kv cache groups, where each element is a list of layer
names that belong to the same group and should share the same
KVCacheSpec.
Returns:
A list of KVCacheGroupSpec objects, one for each group.
"""
Args:
kv_cache_spec:
A mapping from each layer name to its corresponding KVCacheSpec.
grouped_layer_names:
A list of kv cache groups, where each element is a list of layer
names that belong to the same group and should share the same
KVCacheSpec.
Returns:
A list of KVCacheGroupSpec objects, one for each group.
"""
kv_cache_groups = []
for layer_names_one_group in grouped_layer_names:
layer_specs = [
@@ -628,6 +641,37 @@ def get_max_concurrency_for_kv_cache_config(
return max_concurrency
def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
available_memory: int, page_size: int) -> int:
"""
Get the number of kv cache blocks.
Args:
vllm_config: The global VllmConfig
num_layers: The number of layers
available_memory: Memory available for KV cache in bytes.
page_size: The page size of the KV cache.
"""
num_blocks = int(available_memory // page_size // num_layers)
num_blocks = max(num_blocks, 0)
if vllm_config.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = \
vllm_config.cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
return num_blocks
def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
"""
Get the page size of the KV cache.
"""
page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values())
assert len(page_sizes) == 1
return page_sizes.pop()
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
@@ -644,32 +688,24 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
The generated KVCacheConfig
"""
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
assert len(page_sizes) == 1
page_size = page_sizes.pop()
num_blocks = int(available_memory // page_size // len(kv_cache_spec))
num_blocks = max(num_blocks, 0)
if vllm_config.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = \
vllm_config.cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
num_blocks = num_gpu_blocks_override
page_size = get_uniform_page_size(kv_cache_spec)
num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec),
available_memory, page_size)
per_layer_size = page_size * num_blocks
# All layers have the same KV cache spec, so we create one kv cache group
# for all layers.
grouped_layer_names = [list(kv_cache_spec.keys())]
# Each layer uses a separate Tensor to store its KV cache.
kv_cache_tensors = [
KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
for layer_name in kv_cache_spec
]
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
tensors={
layer_name: KVCacheTensor(size=per_layer_size)
for layer_name in kv_cache_spec
},
kv_cache_tensors=kv_cache_tensors,
kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
grouped_layer_names),
)
@@ -685,17 +721,185 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
return kv_cache_config
def is_kv_cache_page_size_uniform(
kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
"""
Whether all layers in the given KVCacheSpec have the same page size.
Args:
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
True if all layers have the same page size, False otherwise.
"""
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
return len(page_sizes) == 1
def _get_kv_cache_config_uniform_page_size(
vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for hybrid models with multiple
attention types but still with a uniform page size (physical memory per
block per layer) for all layers.
Detailed explanation about kv cache management of hybrid models:
The layers in the models are repeated with some patterns, e.g., a model
with 10 full attention layers and 20 sliding window attention layers can be
regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
The KVCacheManager allocates different block tables for each of the 3 layers
in the pattern, and repeats each of them 10 times to generate the
block_table for the 30 layers in the model.
Therefore, we can group the layers in the model into 3 kv_cache_groups, each
of which contains 10 layers in the model.
The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
is shared by all layers. It is already handled by
`_get_kv_cache_config_uniform_type`.
2. A model with 10 full attention layers and 20 sliding window
attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
there are 3 kv_cache_groups, each of which represents 10 layers.
To simplify the implementation, we make the following assumptions:
1. Physical memory per block: Must be the same across all KV cache groups.
Breaking this assumption is non-trivial due to memory fragmentation concerns
when allocating blocks of different sizes.
2. Tokens per block (block_size): Currently, we directly use
`CacheConfig.block_size` for all layers. It can be extended to vary by KV
cache group, but within each KV cache group, all layers must share the same
block size.
3. Physical memory per token per layer: This property is decided by model
config. Currently we only support models that have the same physical memory
per token per layer for all layers. Can be relaxed with a simple extension,
but still need to keep physical memory per block the same for all groups.
4. Number of layers per group: Currently assumed the same for all groups.
Can be relaxed with a simple extension, but still need to keep physical
memory per block the same for all groups.
5. Attention type within groups: All layers in a group must share the same
attention type. One exception is that, when
`--disable-hybrid-kv-cache-manager` is true, the single group for full
attention layers may also include attention layers using sliding window or
LLaMA 4 local attention. See `unify_hybrid_kv_cache_specs` for more details.
6. Support for multiple attention types: The design for most components is
general to an arbitrary number of attention types. But
`find_longest_cache_hit` only supports either a single attention type, or
full attention plus exactly one other type. A general implementation of
this function is feasible, but we don't know how to implement it cleanly
yet.
As we assume tokens per block, physical memory per token per layer, and
number of layers per group are the same now, we can ensure that physical
memory per block is the same for all groups.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfig
"""
# Group all layers by type_id.
# E.g., 2 full attention layers and 3 sliding window attention layers,
# -> (full.0, full.1), (sw.0, sw.1, sw.2).
same_type_layers: dict[str, list[str]] = defaultdict(list)
for layer_name, layer_spec in kv_cache_spec.items():
same_type_layers[layer_spec.type_id].append(layer_name)
# Split each group into smaller groups, to make the number of layers in each
# group identical. Add padding to the last group of each type if necessary.
# E.g., (full.0, full.1), (sw.0, sw.1, sw.2)
# split to 3 groups with 2 layers each:
# (full.0, full.1), (sw.0, sw.1), (sw.2, padding).
# FIXME(Chen): At the moment of writing this code (2025-06-02), all
# open-source hybrid models follow an n:1 pattern between different attention
# types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and
# full), so we can use the "1" in the n:1 pattern as the group size, which
# is the minimum number of layers among all attention types. Need a better
# strategy if we want to support more complex patterns (e.g., 20 full + 30
# sw, where the group size should be 10).
group_size = min([len(layers) for layers in same_type_layers.values()])
grouped_layers = []
for layers in same_type_layers.values():
num_padding_layers = group_size - len(layers) % group_size
if num_padding_layers != group_size:
logger.warning(
"Add %d padding layers, may waste at most %.2f%% KV cache memory", # noqa
num_padding_layers,
num_padding_layers / len(layers) * 100,
)
for i in range(0, len(layers), group_size):
grouped_layers.append(layers[i:i + group_size])
kv_cache_groups = create_kv_cache_group_specs(kv_cache_spec,
grouped_layers)
# Determine how model runners should initialize the KV cache tensors.
# We will have group_size memory pools, each shared by one layer from
# each group. As layers of different groups have different block tables,
# they will use different parts of the shared Tensor.
# The memory layout in the example will be:
# full.0, sw.0, sw.2: share a Tensor with size=available_memory//2
# full.1, sw.1: share another Tensor with size=available_memory//2
page_size = get_uniform_page_size(kv_cache_spec)
num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
page_size)
per_memory_pool_size = page_size * num_blocks
kv_cache_tensors = []
for i in range(group_size):
shared_by = []
for j in range(len(kv_cache_groups)):
if i < len(grouped_layers[j]):
shared_by.append(grouped_layers[j][i])
kv_cache_tensors.append(
KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by))
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=kv_cache_tensors,
kv_cache_groups=kv_cache_groups,
)
# Print the KV cache size and maximum concurrency.
num_tokens = num_blocks // len(
grouped_layers) * vllm_config.cache_config.block_size
num_tokens_str = f"{num_tokens:,}"
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
max_concurrency = get_max_concurrency_for_kv_cache_config(
vllm_config, kv_cache_config)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
max_model_len_str, max_concurrency)
return kv_cache_config
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
"""
Only models with one type of KV cache are supported yet. This function tries
to convert the KV cache specs to one type if the model is a hybrid model
with multiple type of KV cache. It will convert all SlidingWindowSpec to
FullAttentionSpec if both types are present.
This function tries to convert the KV cache specs to one type if the model
is a hybrid model with multiple types of KV cache. It will convert all
SlidingWindowSpec to FullAttentionSpec if both types are present.
Args:
kv_cache_spec: The kv cache spec of each attention layer in the model
"""
def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
type_ids = set(layer_spec.type_id
for layer_spec in kv_cache_spec.values())
return len(type_ids) > 1
if not is_hybrid(kv_cache_spec):
return
logger.warning(
"Hybrid KV cache manager is disabled for this hybrid model, "
"This means we do not enable any optimizations for saving KV cache "
"memory (e.g., dropping the KV cache outside the sliding window). "
"The compute of layers like sliding window is still saved.")
has_full_attention = any(
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
has_sliding_window = any(
@@ -712,13 +916,18 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
sliding_window=spec.sliding_window,
)
if is_hybrid(kv_cache_spec):
raise ValueError("Hybrid KV cache manager is disabled but failed to "
"convert the KV cache specs to one unified type.")
def get_kv_cache_config(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
def get_kv_cache_config(
vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int,
) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model
TODO: support hybrid models with more than one type of KV cache.
Generates the KV cache configuration for a model.
Args:
vllm_config: The global VllmConfig
@@ -728,14 +937,25 @@ def get_kv_cache_config(vllm_config: VllmConfig,
Returns:
The generated KVCacheConfigs
"""
unify_hybrid_kv_cache_specs(kv_cache_spec)
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
unify_hybrid_kv_cache_specs(kv_cache_spec)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory)
elif is_kv_cache_page_size_uniform(kv_cache_spec):
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_config_uniform_page_size(vllm_config,
kv_cache_spec,
available_memory)
raise NotImplementedError

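The grouping and memory-pool sharing described in `_get_kv_cache_config_uniform_page_size` can be reproduced with a few lines. The sketch below reuses the example layer names from the comments above (2 full-attention and 3 sliding-window layers); it only mimics the grouping and padding arithmetic, not the real KVCacheTensor/KVCacheGroupSpec objects.

from collections import defaultdict

layer_type = {
    "full.0": "full", "full.1": "full",
    "sw.0": "sw", "sw.1": "sw", "sw.2": "sw",
}

# Group all layers by attention type.
same_type_layers = defaultdict(list)
for name, type_id in layer_type.items():
    same_type_layers[type_id].append(name)

# Group size is the smallest layer count among the attention types.
group_size = min(len(layers) for layers in same_type_layers.values())

# Split each type into chunks of group_size; the last chunk may be short
# (conceptually padded).
grouped_layers = []
for layers in same_type_layers.values():
    for i in range(0, len(layers), group_size):
        grouped_layers.append(layers[i:i + group_size])
print(grouped_layers)  # [['full.0', 'full.1'], ['sw.0', 'sw.1'], ['sw.2']]

# Each of the group_size memory pools is shared by one layer from each group.
shared_by = [[group[i] for group in grouped_layers if i < len(group)]
             for i in range(group_size)]
print(shared_by)  # [['full.0', 'sw.0', 'sw.2'], ['full.1', 'sw.1']]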
View File

@@ -18,7 +18,7 @@ from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
@@ -377,7 +377,8 @@ class Scheduler(SchedulerInterface):
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
else:
new_computed_blocks = KVCacheBlocks.create_empty()
new_computed_blocks = (
self.kv_cache_manager.create_empty_block_list())
num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
@@ -1010,7 +1011,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens = len(block_ids) * self.block_size
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
self.kv_cache_manager.single_type_manager.cache_blocks(
self.kv_cache_manager.cache_blocks(
request,
self.kv_cache_manager.req_to_block_hashes[request.request_id],
num_computed_tokens,

View File

@@ -22,8 +22,7 @@ class SingleTypeKVCacheManager(ABC):
self,
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
use_eagle: bool,
num_kv_cache_groups: int,
kv_cache_group_id: int,
caching_hash_fn: Callable,
) -> None:
"""
@@ -31,9 +30,7 @@ class SingleTypeKVCacheManager(ABC):
Args:
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
use_eagle: Whether to use eagle.
num_kv_cache_groups: The number of kv cache groups managed by this
manager.
kv_cache_group_id: The id of the kv cache group of this manager.
caching_hash_fn: The caching hash function.
"""
@@ -41,9 +38,6 @@ class SingleTypeKVCacheManager(ABC):
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
# Needs special handling for find_longest_cache_hit if eagle is enabled
self.use_eagle = use_eagle
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
# is finished.
@@ -56,8 +50,8 @@ class SingleTypeKVCacheManager(ABC):
# data for preempted ones.
self.num_cached_block: dict[str, int] = {}
self.num_kv_cache_groups = num_kv_cache_groups
self.caching_hash_fn = caching_hash_fn
self.kv_cache_group_id = kv_cache_group_id
def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
@@ -86,8 +80,7 @@ class SingleTypeKVCacheManager(ABC):
num_evictable_computed_blocks = sum(
blk.ref_cnt == 0 and not blk.is_null
for blk in new_computed_blocks)
return ((num_new_blocks + num_evictable_computed_blocks) *
self.num_kv_cache_groups)
return num_new_blocks + num_evictable_computed_blocks
def save_new_computed_blocks(
self, request_id: str,
@@ -130,8 +123,7 @@ class SingleTypeKVCacheManager(ABC):
if num_new_blocks <= 0:
return []
else:
new_blocks = self.block_pool.get_new_blocks(
num_new_blocks * self.num_kv_cache_groups)
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
return new_blocks
@@ -156,12 +148,19 @@ class SingleTypeKVCacheManager(ABC):
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks,
block_size=self.block_size,
kv_cache_group_id=self.kv_cache_group_id,
hash_fn=self.caching_hash_fn,
)
self.num_cached_block[request.request_id] = num_full_blocks
def free(self, request_id: str) -> None:
"""
Free the blocks for the request.
Args:
request_id: The request ID.
"""
# Default to [] in case a request is freed (aborted) before alloc.
req_blocks = self.req_to_blocks.pop(request_id, [])
@@ -188,12 +187,22 @@ class SingleTypeKVCacheManager(ABC):
raise NotImplementedError
@classmethod
@abstractmethod
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
max_length: int) -> list[KVCacheBlock]:
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> list[list[KVCacheBlock]]:
"""
Get the longest cache hit prefix of the blocks that is not longer than
`max_length`. If no cache hit is found, return an empty list.
`max_length`. The prefix should be a common prefix hit for all the
kv cache groups in `kv_cache_group_ids`. If no cache hit is found,
return an empty list.
If eagle is enabled, drop the last matched block to force recomputation of
the last block, which provides the hidden states required by the eagle
drafting head. Needs to be customized for each attention type.
@@ -201,12 +210,20 @@ class SingleTypeKVCacheManager(ABC):
Args:
block_hashes: The block hashes of the request.
max_length: The maximum length of the cache hit prefix.
kv_cache_group_ids: The ids of the kv cache groups.
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
Returns:
A list of cached blocks with skipped blocks replaced by null block.
A list of cached blocks with skipped blocks replaced by null block
for each kv cache group in `kv_cache_group_ids`.
Return a list of length `len(kv_cache_group_ids)`, where the i-th
element is a list of cached blocks for the i-th kv cache group
in `kv_cache_group_ids`.
For example, sliding window manager should return a list like
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and
sliding window 8.
[[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]] for block size 4
and sliding window 8 and len(kv_cache_group_ids) = 1.
"""
raise NotImplementedError
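For intuition about the per-group return shape of `find_longest_cache_hit`, here is a minimal sketch with a toy cache (plain integers instead of BlockHash and KVCacheBlock objects, so not the real API): one hit list per group id, truncated at the first block whose hash misses in any group.

from typing import Optional

# Toy cache: (hash_value, group_id) -> block_id
toy_cache = {(100, 0): 1, (100, 1): 5,
             (200, 0): 2, (200, 1): 6}

def lookup(hash_value: int, group_ids: list) -> Optional[list]:
    blocks = [toy_cache.get((hash_value, gid)) for gid in group_ids]
    return None if any(b is None for b in blocks) else blocks

def find_longest_cache_hit(block_hashes: list, group_ids: list) -> list:
    hits = [[] for _ in group_ids]
    for h in block_hashes:
        cached = lookup(h, group_ids)
        if cached is None:
            break  # block hashes are chained, so later blocks cannot hit
        for j, block_id in enumerate(cached):
            hits[j].append(block_id)
    return hits

print(find_longest_cache_hit([100, 200, 300], [0, 1]))  # [[1, 2], [5, 6]]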
@@ -215,11 +232,9 @@ class SingleTypeKVCacheManager(ABC):
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
"""
Remove the blocks that are no longer needed from `blocks`. The removed
blocks should be replaced by null_block. Return the removed blocks in
eviction order, where the first returned block should be evicted first.
Don't free the removed blocks in this function. Need to be customized
for each attention type.
Remove the blocks that are no longer needed from `blocks` and free the
blocks. The removed blocks should be replaced by null_block.
Need to be customized for each attention type.
Args:
request_id: The request ID.
@@ -230,21 +245,36 @@ class SingleTypeKVCacheManager(ABC):
class FullAttentionManager(SingleTypeKVCacheManager):
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
max_length: int) -> list[KVCacheBlock]:
computed_blocks: list[KVCacheBlock] = []
max_num_blocks = max_length // self.block_size
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> list[list[KVCacheBlock]]:
assert isinstance(kv_cache_spec, FullAttentionSpec), (
"FullAttentionManager can only be used for full attention groups")
computed_blocks: list[list[KVCacheBlock]] = [
[] for _ in range(len(kv_cache_group_ids))
]
max_num_blocks = max_length // kv_cache_spec.block_size
for i in range(max_num_blocks):
block_hash = block_hashes[i]
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if cached_block := self.block_pool.get_cached_block(block_hash):
computed_blocks.append(cached_block)
if cached_block := block_pool.get_cached_block(
block_hash, kv_cache_group_ids):
for j in range(len(kv_cache_group_ids)):
computed_blocks[j].append(cached_block[j])
else:
break
if self.use_eagle and len(computed_blocks) > 0:
computed_blocks.pop()
if use_eagle and len(computed_blocks[0]) > 0:
for j in range(len(kv_cache_group_ids)):
computed_blocks[j].pop()
return computed_blocks
def remove_skipped_blocks(self, request_id: str,
@@ -267,45 +297,58 @@ class FullAttentionManager(SingleTypeKVCacheManager):
class SlidingWindowManager(SingleTypeKVCacheManager):
def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
use_eagle: bool, **kwargs) -> None:
super().__init__(kv_cache_spec, block_pool, use_eagle, **kwargs)
**kwargs) -> None:
super().__init__(kv_cache_spec, block_pool, **kwargs)
self.sliding_window = kv_cache_spec.sliding_window
self._null_block = block_pool.null_block
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> list[list[KVCacheBlock]]:
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
"SlidingWindowManager can only be used for sliding window groups")
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
self.sliding_window_contiguous_blocks = cdiv(
(kv_cache_spec.sliding_window - 1), self.block_size)
if self.use_eagle:
sliding_window_contiguous_blocks = cdiv(
kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size)
if use_eagle:
# Need to drop the last matched block if eagle is enabled. For
# sliding window layer, we achieve this by increasing the number of
# contiguous blocks needed for prefix cache hit by one and dropping
# the last matched block.
self.sliding_window_contiguous_blocks += 1
self._null_block = block_pool.null_block
sliding_window_contiguous_blocks += 1
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
max_length: int) -> list[KVCacheBlock]:
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
# optimize the time complexity from O(max_num_blocks) to
# O(max_num_blocks / sliding_window_contiguous_blocks +
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
max_num_blocks = max_length // self.block_size
computed_blocks = [self._null_block] * max_num_blocks
max_num_blocks = max_length // kv_cache_spec.block_size
computed_blocks = [[block_pool.null_block] * max_num_blocks
for _ in range(len(kv_cache_group_ids))]
num_contiguous_blocks = 0
match_found = False
# Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1):
if cached_block := self.block_pool.get_cached_block(
block_hashes[i]):
computed_blocks[i] = cached_block
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids):
for j in range(len(kv_cache_group_ids)):
computed_blocks[j][i] = cached_block[j]
num_contiguous_blocks += 1
if (num_contiguous_blocks
>= self.sliding_window_contiguous_blocks):
if (num_contiguous_blocks >= sliding_window_contiguous_blocks):
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
del computed_blocks[i + num_contiguous_blocks:]
for j in range(len(kv_cache_group_ids)):
del computed_blocks[j][i + num_contiguous_blocks:]
match_found = True
break
else:
@@ -313,9 +356,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
if not match_found:
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
del computed_blocks[num_contiguous_blocks:]
if self.use_eagle and len(computed_blocks) > 0:
computed_blocks.pop()
for j in range(len(kv_cache_group_ids)):
del computed_blocks[j][num_contiguous_blocks:]
if use_eagle and len(computed_blocks[0]) > 0:
for j in range(len(kv_cache_group_ids)):
computed_blocks[j].pop()
return computed_blocks
def remove_skipped_blocks(self, request_id: str,