[Core][Hybrid allocator + connector] Support hybrid allocator + kv cache connector (#30166)

Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Co-authored-by: KuntaiDu <kuntai@uchicago.edu>
This commit is contained in:
Yifan Qiao
2025-12-26 18:25:46 -08:00
committed by GitHub
parent 5326c89803
commit 52bf066516
8 changed files with 362 additions and 111 deletions

View File

@@ -21,13 +21,23 @@ from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, SlidingWindowS
pytestmark = pytest.mark.cpu_test
def get_sliding_window_manager(sliding_window_spec, block_pool):
return SlidingWindowManager(sliding_window_spec, block_pool, kv_cache_group_id=0)
def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True):
return SlidingWindowManager(
sliding_window_spec,
block_pool,
enable_caching=enable_caching,
kv_cache_group_id=0,
)
def get_chunked_local_attention_manager(chunked_local_attention_spec, block_pool):
def get_chunked_local_attention_manager(
chunked_local_attention_spec, block_pool, enable_caching=True
):
return ChunkedLocalAttentionManager(
chunked_local_attention_spec, block_pool, kv_cache_group_id=0
chunked_local_attention_spec,
block_pool,
enable_caching=enable_caching,
kv_cache_group_id=0,
)
@@ -332,13 +342,55 @@ def test_get_num_blocks_to_allocate():
]
assert (
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0)
== 20
)
assert (
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0)
== 15
)
def test_evictable_cached_blocks_not_double_allocated():
block_size = 2
sliding_window_length = 2 * block_size
sliding_window_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=sliding_window_length,
)
block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
request_id = "req"
evictable_block = block_pool.blocks[1] # ref_cnt == 0, eviction candidate
num_blocks_to_allocate = manager.get_num_blocks_to_allocate(
request_id=request_id,
num_tokens=2 * block_size,
new_computed_blocks=[evictable_block],
total_computed_tokens=block_size,
)
# Free capacity check should count evictable cached blocks, but allocation
# should only allocate the truly new block.
assert num_blocks_to_allocate == 2
manager.allocate_new_computed_blocks(
request_id,
[evictable_block],
num_local_computed_tokens=block_size,
num_external_computed_tokens=0,
)
new_blocks = manager.allocate_new_blocks(request_id, num_tokens=4)
assert len(new_blocks) == 1
assert len(manager.req_to_blocks[request_id]) == 2
def test_chunked_local_attention_get_num_blocks_to_allocate():
block_size = 2
attention_spec = ChunkedLocalAttentionSpec(
@@ -359,8 +411,10 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
]
assert (
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0)
== 20
)
assert (
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0)
== 15
)

View File

@@ -254,6 +254,10 @@ class BlockPool:
[] if self.enable_kv_cache_events else None
)
for i, blk in enumerate(new_full_blocks):
# Some blocks may be null blocks when enabling sparse attention like
# sliding window attention. We skip null blocks here.
if blk.is_null:
continue
assert blk.block_hash is None
block_hash = new_block_hashes[i]
@@ -361,7 +365,7 @@ class BlockPool:
)
return True
def touch(self, blocks: tuple[Sequence[KVCacheBlock], ...]) -> None:
def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
"""Touching a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
@@ -369,15 +373,14 @@ class BlockPool:
Args:
blocks: A list of blocks to touch.
"""
for blocks_per_group in blocks:
for block in blocks_per_group:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block)
block.ref_cnt += 1
if self.metrics_collector:
self.metrics_collector.on_block_accessed(block)
for block in blocks:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block)
block.ref_cnt += 1
if self.metrics_collector:
self.metrics_collector.on_block_accessed(block)
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
"""Free a list of blocks. The blocks should be ordered by their

View File

@@ -60,6 +60,7 @@ class KVCacheCoordinator(ABC):
get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool,
enable_caching=enable_caching,
kv_cache_group_id=i,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
@@ -73,6 +74,7 @@ class KVCacheCoordinator(ABC):
num_tokens: int,
new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
num_encoder_tokens: int,
total_computed_tokens: int,
) -> int:
"""
Get the number of blocks needed to be allocated for the request.
@@ -85,9 +87,10 @@ class KVCacheCoordinator(ABC):
prefix caching.
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.
total_computed_tokens: Include both local and external tokens.
Returns:
The number of blocks.
The number of blocks to allocate.
"""
num_blocks_to_allocate = 0
for i, manager in enumerate(self.single_type_managers):
@@ -95,30 +98,48 @@ class KVCacheCoordinator(ABC):
# For cross-attention, we issue a single static allocation
# of blocks based on the number of encoder input tokens.
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_encoder_tokens, []
request_id, num_encoder_tokens, [], 0
)
else:
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks[i]
request_id,
num_tokens,
new_computed_blocks[i],
total_computed_tokens,
)
return num_blocks_to_allocate
def save_new_computed_blocks(
self, request_id: str, new_computed_blocks: tuple[Sequence[KVCacheBlock], ...]
def allocate_new_computed_blocks(
self,
request_id: str,
new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
num_local_computed_tokens: int,
num_external_computed_tokens: int,
) -> None:
"""
Add the new computed blocks to the request.
Add the new computed blocks to the request. Optionally allocate new
blocks for external computed tokens (if any).
Args:
request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the
prefix cache.
num_local_computed_tokens: The number of local computed tokens.
num_external_computed_tokens: The number of external computed tokens.
"""
for i, manager in enumerate(self.single_type_managers):
manager.save_new_computed_blocks(request_id, new_computed_blocks[i])
manager.allocate_new_computed_blocks(
request_id,
new_computed_blocks[i],
num_local_computed_tokens,
num_external_computed_tokens,
)
def allocate_new_blocks(
self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0
self,
request_id: str,
num_tokens: int,
num_encoder_tokens: int = 0,
) -> tuple[list[KVCacheBlock], ...]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
@@ -184,17 +205,20 @@ class KVCacheCoordinator(ABC):
for manager in self.single_type_managers
]
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
def remove_skipped_blocks(
self, request_id: str, total_computed_tokens: int
) -> None:
"""
Remove the blocks that are no longer needed from `blocks` and replace
the removed blocks with null_block.
Args:
request_id: The request ID.
num_computed_tokens: The number of tokens that have been computed.
total_computed_tokens: The total number of computed tokens, including
local computed tokens and external computed tokens.
"""
for manager in self.single_type_managers:
manager.remove_skipped_blocks(request_id, num_computed_tokens)
manager.remove_skipped_blocks(request_id, total_computed_tokens)
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
"""

View File

@@ -210,6 +210,7 @@ class KVCacheManager:
num_new_computed_tokens: int = 0,
new_computed_blocks: KVCacheBlocks | None = None,
num_lookahead_tokens: int = 0,
num_external_computed_tokens: int = 0,
delay_cache_blocks: bool = False,
num_encoder_tokens: int = 0,
) -> KVCacheBlocks | None:
@@ -217,16 +218,16 @@ class KVCacheManager:
Args:
request: The request to allocate slots.
num_new_tokens: The number of tokens to allocate, including external
tokens. Note that this does not include tokens that have
already been computed locally (i.e. new_computed_blocks).
num_new_tokens: The number of new tokens to be allocated and computed.
num_new_computed_tokens: The number of new computed tokens just
hitting the prefix caching, excluding external tokens.
new_computed_blocks: The cached blocks for the above new computed
tokens.
tokens, grouped as a tuple by kv cache groups.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
num_external_computed_tokens: The number of tokens whose
KV caches are not cached by vLLM but are cached by the connector.
delay_cache_blocks: Whether to skip caching the blocks. This is
used by P/D when allocating blocks used in a KV transfer
which will complete in a future step.
@@ -236,29 +237,81 @@ class KVCacheManager:
Blocks layout:
```
-----------------------------------------------------------------------
| < computed > | < new computed > | < new > | < pre-allocated > |
-----------------------------------------------------------------------
| < required > |
--------------------------------------------------
| < full > |
------------------------------------------------
| <new full> |
--------------
----------------------------------------------------------------------
| < comp > | < new_comp > | < ext_comp > | < new > | < lookahead > |
----------------------------------------------------------------------
| < to be computed > |
----------------------------------------------------------------------
| < to be allocated > |
----------------------------------------------------------------------
| < to be cached (roughly, |
| details below)> |
----------------------------------------------------------------------
| Prefix-cached tokens from either vLLM |
| or connector. Can be safely removed if |
| they are outside sliding window. |
----------------------------------------------------------------------
| < cached by vLLM > | not cached by |
| vLLM, but |
| ref_cnt | ref_cnt not | cached by |
| increased| increased yet| connector |
----------------------------------------------------------------------
```
The following *_blocks are illustrated in this layout.
Abbreviations:
```
comp = request.num_computed_tokens
new_comp = num_new_computed_tokens
= len(new_computed_blocks) * block_size
ext_comp = num_external_computed_tokens, cached by the connector
new = num_new_tokens, including unverified draft tokens
lookahead = num_lookahead_tokens
```
NOTE: for new tokens which include both verified and unverified draft
tokens, we only cache the verified tokens (by capping the number at
`request.num_tokens`).
The allocation has three stages:
- Free unnecessary blocks in `comp` and check
if we have sufficient free blocks (return None if not).
- Handle prefix tokens (`comp + new_comp + ext_comp`):
- Free unnecessary blocks (e.g. outside sliding window)
- Allocate new blocks for `ext_comp` tokens inside
sliding window
- Allocate new blocks for tokens to be computed (`new + lookahead`)
Returns:
A list of new allocated blocks.
"""
if num_new_tokens == 0:
raise ValueError("num_new_tokens must be greater than 0")
# When loading KV data asynchronously, we may have zero new tokens to
# compute while still allocating slots for externally computed tokens.
if num_new_tokens == 0 and num_external_computed_tokens == 0:
raise ValueError(
"num_new_tokens must be greater than 0 when there are no "
"external computed tokens"
)
if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
else:
new_computed_block_list = self.empty_kv_cache_blocks.blocks
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_local_computed_tokens = (
request.num_computed_tokens + num_new_computed_tokens
)
total_computed_tokens = min(
num_local_computed_tokens + num_external_computed_tokens,
self.max_model_len,
)
num_tokens_need_slot = min(
total_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len,
)
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
# We can do this even if we cannot schedule this request due to
@@ -266,15 +319,7 @@ class KVCacheManager:
# Should call this function before allocating new blocks to reduce
# the number of evicted blocks.
self.coordinator.remove_skipped_blocks(
request.request_id, request.num_computed_tokens
)
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens
num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len,
request.request_id, total_computed_tokens
)
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
@@ -282,25 +327,25 @@ class KVCacheManager:
num_tokens=num_tokens_need_slot,
new_computed_blocks=new_computed_block_list,
num_encoder_tokens=num_encoder_tokens,
total_computed_tokens=num_local_computed_tokens
+ num_external_computed_tokens,
)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
# Cannot allocate new blocks
return None
# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self.block_pool.touch(new_computed_block_list)
else:
assert not any(new_computed_block_list), (
"Computed blocks should be empty when prefix caching is disabled"
)
if new_computed_block_list is not self.empty_kv_cache_blocks.blocks:
if (
new_computed_block_list is not self.empty_kv_cache_blocks.blocks
or num_external_computed_tokens > 0
):
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
self.coordinator.save_new_computed_blocks(
request.request_id, new_computed_block_list
self.coordinator.allocate_new_computed_blocks(
request_id=request.request_id,
new_computed_blocks=new_computed_block_list,
num_local_computed_tokens=num_local_computed_tokens,
num_external_computed_tokens=num_external_computed_tokens,
)
new_blocks = self.coordinator.allocate_new_blocks(
@@ -312,12 +357,14 @@ class KVCacheManager:
if not self.enable_caching or delay_cache_blocks:
return self.create_kv_cache_blocks(new_blocks)
# NOTE(woosuk): We want to commit (cache) up to num_computed_tokens +
# num_new_tokens, but must exclude "non-committable" tokens (e.g.,
# draft tokens that could be rejected). Therefore, we cap the number
# at `request.num_tokens`, ensuring only "finalized" tokens are cached.
# NOTE(woosuk): We want to commit (cache) up to num_local_computed_tokens
# + num_external_computed_tokens + num_new_tokens, but must exclude
# "non-committable" tokens (e.g., draft tokens that could be rejected).
# Therefore, we cap the number at `request.num_tokens`, ensuring only
# "finalized" tokens are cached.
num_tokens_to_cache = min(
num_computed_tokens + num_new_tokens, request.num_tokens
total_computed_tokens + num_new_tokens,
request.num_tokens,
)
self.coordinator.cache_blocks(request, num_tokens_to_cache)
@@ -333,6 +380,19 @@ class KVCacheManager:
"""
self.coordinator.free(request.request_id)
def remove_skipped_blocks(
self, request_id: str, total_computed_tokens: int
) -> None:
"""Remove the blocks that are no longer needed from `blocks` and replace
the removed blocks with null_block.
Args:
request_id: The request ID.
total_computed_tokens: The total number of computed tokens, including
local computed tokens and external computed tokens.
"""
self.coordinator.remove_skipped_blocks(request_id, total_computed_tokens)
def evict_blocks(self, block_ids: set[int]) -> None:
"""Evict blocks from the prefix cache by their block IDs.
@@ -408,7 +468,13 @@ class KVCacheManager:
return self.get_blocks(request_id).get_block_ids()
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled."""
"""Cache the blocks for the request, if enabled.
Args:
request: The request to cache the blocks.
num_computed_tokens: The number of computed tokens, including tokens
that are already cached and tokens to be cached.
"""
if self.enable_caching:
self.coordinator.cache_blocks(request, num_computed_tokens)

View File

@@ -587,10 +587,11 @@ class Scheduler(SchedulerInterface):
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
num_new_local_computed_tokens,
new_computed_blocks,
num_new_tokens,
num_new_computed_tokens=num_new_local_computed_tokens,
new_computed_blocks=new_computed_blocks,
num_lookahead_tokens=effective_lookahead_tokens,
num_external_computed_tokens=num_external_computed_tokens,
delay_cache_blocks=load_kv_async,
num_encoder_tokens=num_encoder_tokens,
)
@@ -606,7 +607,7 @@ class Scheduler(SchedulerInterface):
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
self.kv_cache_manager.get_blocks(request.request_id),
num_external_computed_tokens,
)
@@ -1580,6 +1581,13 @@ class Scheduler(SchedulerInterface):
if self.connector is None:
return False, None
# Free any out-of-window prefix blocks before we hand the block table to
# the connector.
self.kv_cache_manager.remove_skipped_blocks(
request_id=request.request_id,
total_computed_tokens=request.num_tokens,
)
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
if not isinstance(self.connector, SupportsHMA):

View File

@@ -30,6 +30,7 @@ class SingleTypeKVCacheManager(ABC):
self,
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
enable_caching: bool,
kv_cache_group_id: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
@@ -48,6 +49,7 @@ class SingleTypeKVCacheManager(ABC):
self.block_size *= dcp_world_size * pcp_world_size
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
self.enable_caching = enable_caching
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
@@ -68,6 +70,7 @@ class SingleTypeKVCacheManager(ABC):
request_id: str,
num_tokens: int,
new_computed_blocks: Sequence[KVCacheBlock],
total_computed_tokens: int,
) -> int:
"""
Get the number of blocks needed to be allocated for the request.
@@ -78,46 +81,121 @@ class SingleTypeKVCacheManager(ABC):
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
total_computed_tokens: Include both local and external computed
tokens.
Returns:
The number of blocks.
The number of blocks to allocate.
"""
num_required_blocks = cdiv(num_tokens, self.block_size)
num_new_blocks = (
num_required_blocks
- len(new_computed_blocks)
- len(self.req_to_blocks[request_id])
)
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it will be changed from a free block
# to a computed block when the request is allocated, so we also count
# it as needed to be allocated.
num_evictable_computed_blocks = sum(
blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks
)
return num_new_blocks + num_evictable_computed_blocks
num_req_blocks = len(self.req_to_blocks.get(request_id, ()))
def save_new_computed_blocks(
self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
if request_id in self.num_cached_block:
# Fast-path: a running request won't have any new prefix-cache hits.
assert len(new_computed_blocks) == 0
# NOTE: With speculative decoding, request's blocks may be allocated
# for draft tokens which are later rejected. In this case,
# num_required_blocks may be smaller than num_req_blocks.
return max(num_required_blocks - num_req_blocks, 0)
num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)
num_local_computed_blocks = len(new_computed_blocks) + num_req_blocks
# Number of whole blocks that are skipped by the attention window.
# If nothing is skipped, this is 0.
num_skipped_blocks = num_skipped_tokens // self.block_size
# We need blocks for the non-skipped suffix. If there are still
# local-computed blocks inside the window, they contribute to the
# required capacity; otherwise, skipped blocks dominate.
num_new_blocks = max(
num_required_blocks - max(num_skipped_blocks, num_local_computed_blocks),
0,
)
# Among the `new_computed_blocks`, the first `num_skipped_blocks` worth
# of blocks are skipped; `num_req_blocks` of those may already be in
# `req_to_blocks`, so only skip the remainder from `new_computed_blocks`.
num_skipped_new_computed_blocks = max(0, num_skipped_blocks - num_req_blocks)
# If a computed block is an eviction candidate (in the free queue and
# ref_cnt == 0), it will be removed from the free queue when touched by
# the allocated request, so we must count it in the free-capacity check.
num_evictable_blocks = sum(
blk.ref_cnt == 0 and not blk.is_null
for blk in new_computed_blocks[num_skipped_new_computed_blocks:]
)
return num_new_blocks + num_evictable_blocks
def allocate_new_computed_blocks(
self,
request_id: str,
new_computed_blocks: Sequence[KVCacheBlock],
num_local_computed_tokens: int,
num_external_computed_tokens: int,
) -> None:
"""
Add the new computed blocks to the request.
Add the new computed blocks to the request. This involves four steps:
1. Touch the computed blocks to make sure they won't be evicted.
2. (Optional) For sliding window, skipped blocks are padded with null blocks.
3. Add the remaining computed blocks.
4. (Optional) For KV connectors, allocate new blocks for external computed
tokens (if any).
Args:
request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the
prefix cache.
num_local_computed_tokens: The number of local computed tokens.
num_external_computed_tokens: The number of external computed tokens.
"""
if request_id not in self.num_cached_block:
# A new request.
req_blocks = self.req_to_blocks[request_id]
assert len(req_blocks) == 0
req_blocks.extend(new_computed_blocks)
self.num_cached_block[request_id] = len(new_computed_blocks)
else:
# A running request. Should not have new computed blocks.
if request_id in self.num_cached_block:
# Fast-path: a running request won't have any new prefix-cache hits.
# It should not have any new computed blocks.
assert len(new_computed_blocks) == 0
return
# A new request.
req_blocks = self.req_to_blocks[request_id]
assert len(req_blocks) == 0
num_total_computed_tokens = (
num_local_computed_tokens + num_external_computed_tokens
)
num_skipped_tokens = self.get_num_skipped_tokens(num_total_computed_tokens)
num_skipped_blocks = num_skipped_tokens // self.block_size
if num_skipped_blocks > 0:
# It is possible that all new computed blocks are skipped when
# num_skipped_blocks > len(new_computed_blocks).
new_computed_blocks = new_computed_blocks[num_skipped_blocks:]
# Some external computed tokens may be skipped too.
num_external_computed_tokens = min(
num_total_computed_tokens - num_skipped_tokens,
num_external_computed_tokens,
)
# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self.block_pool.touch(new_computed_blocks)
else:
assert not any(new_computed_blocks), (
"Computed blocks should be empty when prefix caching is disabled"
)
# Skipped blocks are padded with null blocks.
req_blocks.extend([self._null_block] * num_skipped_blocks)
# Add the remaining computed blocks.
req_blocks.extend(new_computed_blocks)
# All cached hits (including skipped nulls) are already cached; mark
# them so cache_blocks() will not try to re-cache blocks that already
# have a block_hash set.
self.num_cached_block[request_id] = len(req_blocks)
if num_external_computed_tokens > 0:
# Allocate new blocks for external computed tokens.
allocated_blocks = self.block_pool.get_new_blocks(
cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
)
req_blocks.extend(allocated_blocks)
def allocate_new_blocks(
self, request_id: str, num_tokens: int
@@ -252,7 +330,9 @@ class SingleTypeKVCacheManager(ABC):
raise NotImplementedError
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
def remove_skipped_blocks(
self, request_id: str, total_computed_tokens: int
) -> None:
"""
Remove and free the blocks that are no longer needed for attention computation.
The removed blocks should be replaced by null_block.
@@ -262,18 +342,24 @@ class SingleTypeKVCacheManager(ABC):
Args:
request_id: The request ID.
num_computed_tokens: The number of tokens that have been computed.
total_computed_tokens: The total number of computed tokens, including
local computed tokens and external computed tokens.
"""
# Remove the blocks that will be skipped during attention computation.
num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)
if num_skipped_tokens <= 0:
# This indicates that ALL tokens are inside attention window.
# Thus we do not need to free any blocks outside attention window.
# A typical case is full attention that we never free any token
# before the request is finished.
return
num_skipped_blocks = num_skipped_tokens // self.block_size
blocks = self.req_to_blocks[request_id]
num_skipped_blocks = num_skipped_tokens // self.block_size
# `num_skipped_tokens` may include tokens that haven't been allocated yet
# (e.g., when the attention window moves into the external computed tokens
# range), so we must cap to the number of blocks that currently exist for
# this request.
num_skipped_blocks = min(num_skipped_blocks, len(blocks))
removed_blocks: list[KVCacheBlock] = []
# Because blocks are indexed from 0, the num_skipped_blocks-th block
# corresponds to index num_skipped_blocks - 1.
@@ -486,7 +572,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
Returns:
The number of tokens that will be skipped for attention computation.
"""
return num_computed_tokens - self.sliding_window + 1
return max(0, num_computed_tokens - self.sliding_window + 1)
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
@@ -711,6 +797,7 @@ class MambaManager(SingleTypeKVCacheManager):
request_id: str,
num_tokens: int,
new_computed_blocks: Sequence[KVCacheBlock],
total_computed_tokens: int,
) -> int:
# Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention.
@@ -721,7 +808,7 @@ class MambaManager(SingleTypeKVCacheManager):
* self.kv_cache_spec.num_speculative_blocks
)
return super().get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks
request_id, num_tokens, new_computed_blocks, total_computed_tokens
)
def allocate_new_blocks(
@@ -749,8 +836,12 @@ class MambaManager(SingleTypeKVCacheManager):
class CrossAttentionManager(SingleTypeKVCacheManager):
"""Manager for cross-attention KV cache in encoder-decoder models."""
def save_new_computed_blocks(
self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
def allocate_new_computed_blocks(
self,
request_id: str,
new_computed_blocks: Sequence[KVCacheBlock],
num_local_computed_tokens: int,
num_external_computed_tokens: int,
) -> None:
# We do not cache blocks for cross-attention to be shared between
# requests, so `new_computed_blocks` should always be empty.

View File

@@ -624,7 +624,7 @@ class Worker(WorkerBase):
output = self.model_runner.execute_model(
scheduler_output, intermediate_tensors
)
if isinstance(output, (ModelRunnerOutput, NoneType)):
if isinstance(output, ModelRunnerOutput | NoneType):
return output
assert isinstance(output, IntermediateTensors)

View File

@@ -304,6 +304,13 @@ class TPUWorker:
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
# Init kv cache connector here, because it requires
# `kv_cache_config`.
# NOTE(Kuntai): This needs to be done before `initialize_kv_cache`,
# because `initialize_kv_cache` will inject kv cache groups not
# related to kv cache connector (e.g. kv cache sharing layers).
ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)
self.model_runner.initialize_kv_cache(kv_cache_config)
def check_health(self) -> None:
@@ -336,8 +343,6 @@ class TPUWorker:
parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
)
ensure_kv_transfer_initialized(vllm_config)
def shutdown(self) -> None:
self.model_runner.ensure_kv_transfer_shutdown()