[Core][Hybrid allocator + connector] Support hybrid allocator + kv cache connector (#30166)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu> Co-authored-by: KuntaiDu <kuntai@uchicago.edu>
This commit is contained in:
@@ -21,13 +21,23 @@ from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, SlidingWindowS
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def get_sliding_window_manager(sliding_window_spec, block_pool):
|
||||
return SlidingWindowManager(sliding_window_spec, block_pool, kv_cache_group_id=0)
|
||||
def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True):
|
||||
return SlidingWindowManager(
|
||||
sliding_window_spec,
|
||||
block_pool,
|
||||
enable_caching=enable_caching,
|
||||
kv_cache_group_id=0,
|
||||
)
|
||||
|
||||
|
||||
def get_chunked_local_attention_manager(chunked_local_attention_spec, block_pool):
|
||||
def get_chunked_local_attention_manager(
|
||||
chunked_local_attention_spec, block_pool, enable_caching=True
|
||||
):
|
||||
return ChunkedLocalAttentionManager(
|
||||
chunked_local_attention_spec, block_pool, kv_cache_group_id=0
|
||||
chunked_local_attention_spec,
|
||||
block_pool,
|
||||
enable_caching=enable_caching,
|
||||
kv_cache_group_id=0,
|
||||
)
|
||||
|
||||
|
||||
@@ -332,13 +342,55 @@ def test_get_num_blocks_to_allocate():
|
||||
]
|
||||
|
||||
assert (
|
||||
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20
|
||||
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0)
|
||||
== 20
|
||||
)
|
||||
assert (
|
||||
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15
|
||||
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0)
|
||||
== 15
|
||||
)
|
||||
|
||||
|
||||
def test_evictable_cached_blocks_not_double_allocated():
|
||||
block_size = 2
|
||||
sliding_window_length = 2 * block_size
|
||||
sliding_window_spec = SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=sliding_window_length,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(
|
||||
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
|
||||
)
|
||||
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
||||
|
||||
request_id = "req"
|
||||
evictable_block = block_pool.blocks[1] # ref_cnt == 0, eviction candidate
|
||||
|
||||
num_blocks_to_allocate = manager.get_num_blocks_to_allocate(
|
||||
request_id=request_id,
|
||||
num_tokens=2 * block_size,
|
||||
new_computed_blocks=[evictable_block],
|
||||
total_computed_tokens=block_size,
|
||||
)
|
||||
# Free capacity check should count evictable cached blocks, but allocation
|
||||
# should only allocate the truly new block.
|
||||
assert num_blocks_to_allocate == 2
|
||||
|
||||
manager.allocate_new_computed_blocks(
|
||||
request_id,
|
||||
[evictable_block],
|
||||
num_local_computed_tokens=block_size,
|
||||
num_external_computed_tokens=0,
|
||||
)
|
||||
new_blocks = manager.allocate_new_blocks(request_id, num_tokens=4)
|
||||
assert len(new_blocks) == 1
|
||||
assert len(manager.req_to_blocks[request_id]) == 2
|
||||
|
||||
|
||||
def test_chunked_local_attention_get_num_blocks_to_allocate():
|
||||
block_size = 2
|
||||
attention_spec = ChunkedLocalAttentionSpec(
|
||||
@@ -359,8 +411,10 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
|
||||
]
|
||||
|
||||
assert (
|
||||
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1) == 20
|
||||
manager.get_num_blocks_to_allocate("1", 20 * block_size, cached_blocks_1, 0)
|
||||
== 20
|
||||
)
|
||||
assert (
|
||||
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15
|
||||
manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2, 0)
|
||||
== 15
|
||||
)
|
||||
|
||||
@@ -254,6 +254,10 @@ class BlockPool:
|
||||
[] if self.enable_kv_cache_events else None
|
||||
)
|
||||
for i, blk in enumerate(new_full_blocks):
|
||||
# Some blocks may be null blocks when enabling sparse attention like
|
||||
# sliding window attention. We skip null blocks here.
|
||||
if blk.is_null:
|
||||
continue
|
||||
assert blk.block_hash is None
|
||||
block_hash = new_block_hashes[i]
|
||||
|
||||
@@ -361,7 +365,7 @@ class BlockPool:
|
||||
)
|
||||
return True
|
||||
|
||||
def touch(self, blocks: tuple[Sequence[KVCacheBlock], ...]) -> None:
|
||||
def touch(self, blocks: Sequence[KVCacheBlock]) -> None:
|
||||
"""Touch a block increases its reference count by 1, and may remove
|
||||
the block from the free queue. This is used when a block is hit by
|
||||
another request with the same prefix.
|
||||
@@ -369,15 +373,14 @@ class BlockPool:
|
||||
Args:
|
||||
blocks: A list of blocks to touch.
|
||||
"""
|
||||
for blocks_per_group in blocks:
|
||||
for block in blocks_per_group:
|
||||
# ref_cnt=0 means this block is in the free list (i.e. eviction
|
||||
# candidate), so remove it.
|
||||
if block.ref_cnt == 0 and not block.is_null:
|
||||
self.free_block_queue.remove(block)
|
||||
block.ref_cnt += 1
|
||||
if self.metrics_collector:
|
||||
self.metrics_collector.on_block_accessed(block)
|
||||
for block in blocks:
|
||||
# ref_cnt=0 means this block is in the free list (i.e. eviction
|
||||
# candidate), so remove it.
|
||||
if block.ref_cnt == 0 and not block.is_null:
|
||||
self.free_block_queue.remove(block)
|
||||
block.ref_cnt += 1
|
||||
if self.metrics_collector:
|
||||
self.metrics_collector.on_block_accessed(block)
|
||||
|
||||
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
|
||||
"""Free a list of blocks. The blocks should be ordered by their
|
||||
|
||||
@@ -60,6 +60,7 @@ class KVCacheCoordinator(ABC):
|
||||
get_manager_for_kv_cache_spec(
|
||||
kv_cache_spec=kv_cache_group.kv_cache_spec,
|
||||
block_pool=self.block_pool,
|
||||
enable_caching=enable_caching,
|
||||
kv_cache_group_id=i,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
@@ -73,6 +74,7 @@ class KVCacheCoordinator(ABC):
|
||||
num_tokens: int,
|
||||
new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
|
||||
num_encoder_tokens: int,
|
||||
total_computed_tokens: int,
|
||||
) -> int:
|
||||
"""
|
||||
Get the number of blocks needed to be allocated for the request.
|
||||
@@ -85,9 +87,10 @@ class KVCacheCoordinator(ABC):
|
||||
prefix caching.
|
||||
num_encoder_tokens: The number of encoder tokens for allocating
|
||||
blocks for cross-attention.
|
||||
total_computed_tokens: Include both local and external tokens.
|
||||
|
||||
Returns:
|
||||
The number of blocks.
|
||||
The number of blocks to allocate.
|
||||
"""
|
||||
num_blocks_to_allocate = 0
|
||||
for i, manager in enumerate(self.single_type_managers):
|
||||
@@ -95,30 +98,48 @@ class KVCacheCoordinator(ABC):
|
||||
# For cross-attention, we issue a single static allocation
|
||||
# of blocks based on the number of encoder input tokens.
|
||||
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
|
||||
request_id, num_encoder_tokens, []
|
||||
request_id, num_encoder_tokens, [], 0
|
||||
)
|
||||
else:
|
||||
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
|
||||
request_id, num_tokens, new_computed_blocks[i]
|
||||
request_id,
|
||||
num_tokens,
|
||||
new_computed_blocks[i],
|
||||
total_computed_tokens,
|
||||
)
|
||||
return num_blocks_to_allocate
|
||||
|
||||
def save_new_computed_blocks(
|
||||
self, request_id: str, new_computed_blocks: tuple[Sequence[KVCacheBlock], ...]
|
||||
def allocate_new_computed_blocks(
|
||||
self,
|
||||
request_id: str,
|
||||
new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
|
||||
num_local_computed_tokens: int,
|
||||
num_external_computed_tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Add the new computed blocks to the request.
|
||||
Add the new computed blocks to the request. Optionally allocate new
|
||||
blocks for external computed tokens (if any).
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
new_computed_blocks: The new computed blocks just hitting the
|
||||
prefix cache.
|
||||
num_local_computed_tokens: The number of local computed tokens.
|
||||
num_external_computed_tokens: The number of external computed tokens.
|
||||
"""
|
||||
for i, manager in enumerate(self.single_type_managers):
|
||||
manager.save_new_computed_blocks(request_id, new_computed_blocks[i])
|
||||
manager.allocate_new_computed_blocks(
|
||||
request_id,
|
||||
new_computed_blocks[i],
|
||||
num_local_computed_tokens,
|
||||
num_external_computed_tokens,
|
||||
)
|
||||
|
||||
def allocate_new_blocks(
|
||||
self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0
|
||||
self,
|
||||
request_id: str,
|
||||
num_tokens: int,
|
||||
num_encoder_tokens: int = 0,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
"""
|
||||
Allocate new blocks for the request to give it at least `num_tokens`
|
||||
@@ -184,17 +205,20 @@ class KVCacheCoordinator(ABC):
|
||||
for manager in self.single_type_managers
|
||||
]
|
||||
|
||||
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
|
||||
def remove_skipped_blocks(
|
||||
self, request_id: str, total_computed_tokens: int
|
||||
) -> None:
|
||||
"""
|
||||
Remove the blocks that are no longer needed from `blocks` and replace
|
||||
the removed blocks with null_block.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
num_computed_tokens: The number of tokens that have been computed.
|
||||
total_computed_tokens: The total number of computed tokens, including
|
||||
local computed tokens and external computed tokens.
|
||||
"""
|
||||
for manager in self.single_type_managers:
|
||||
manager.remove_skipped_blocks(request_id, num_computed_tokens)
|
||||
manager.remove_skipped_blocks(request_id, total_computed_tokens)
|
||||
|
||||
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
|
||||
"""
|
||||
|
||||
@@ -210,6 +210,7 @@ class KVCacheManager:
|
||||
num_new_computed_tokens: int = 0,
|
||||
new_computed_blocks: KVCacheBlocks | None = None,
|
||||
num_lookahead_tokens: int = 0,
|
||||
num_external_computed_tokens: int = 0,
|
||||
delay_cache_blocks: bool = False,
|
||||
num_encoder_tokens: int = 0,
|
||||
) -> KVCacheBlocks | None:
|
||||
@@ -217,16 +218,16 @@ class KVCacheManager:
|
||||
|
||||
Args:
|
||||
request: The request to allocate slots.
|
||||
num_new_tokens: The number of tokens to allocate, including external
|
||||
tokens. Note that this does not include tokens that have
|
||||
already been computed locally (i.e. new_computed_blocks).
|
||||
num_new_tokens: The number of new tokens to be allocated and computed.
|
||||
num_new_computed_tokens: The number of new computed tokens just
|
||||
hitting the prefix caching, excluding external tokens.
|
||||
new_computed_blocks: The cached blocks for the above new computed
|
||||
tokens.
|
||||
tokens, grouped as a tuple by kv cache groups.
|
||||
num_lookahead_tokens: The number of speculative tokens to allocate.
|
||||
This is used by spec decode proposers with kv-cache such
|
||||
as eagle.
|
||||
num_external_computed_tokens: The number of tokens whose
|
||||
KV caches are not cached by vLLM but cached by the connector.
|
||||
delay_cache_blocks: Whether to skip caching the blocks. This is
|
||||
used by P/D when allocating blocks used in a KV transfer
|
||||
which will complete in a future step.
|
||||
@@ -236,29 +237,81 @@ class KVCacheManager:
|
||||
|
||||
Blocks layout:
|
||||
```
|
||||
-----------------------------------------------------------------------
|
||||
| < computed > | < new computed > | < new > | < pre-allocated > |
|
||||
-----------------------------------------------------------------------
|
||||
| < required > |
|
||||
--------------------------------------------------
|
||||
| < full > |
|
||||
------------------------------------------------
|
||||
| <new full> |
|
||||
--------------
|
||||
----------------------------------------------------------------------
|
||||
| < comp > | < new_comp > | < ext_comp > | < new > | < lookahead > |
|
||||
----------------------------------------------------------------------
|
||||
| < to be computed > |
|
||||
----------------------------------------------------------------------
|
||||
| < to be allocated > |
|
||||
----------------------------------------------------------------------
|
||||
| < to be cached (roughly, |
|
||||
| details below)> |
|
||||
----------------------------------------------------------------------
|
||||
| Prefix-cached tokens from either vLLM |
|
||||
| or connector. Can be safely removed if |
|
||||
| they are outside sliding window. |
|
||||
----------------------------------------------------------------------
|
||||
| < cached by vLLM > | not cached by |
|
||||
| vLLM, but |
|
||||
| ref_cnt | ref_cnt not | cached by |
|
||||
| increased| increased yet| connector |
|
||||
----------------------------------------------------------------------
|
||||
```
|
||||
The following *_blocks are illustrated in this layout.
|
||||
|
||||
Abbreviations:
|
||||
|
||||
```
|
||||
comp = request.num_computed_tokens
|
||||
new_comp = num_new_computed_tokens
|
||||
= len(new_computed_blocks) * block_size
|
||||
ext_comp = num_external_computed_tokens, cached by the connector
|
||||
new = num_new_tokens, including unverified draft tokens
|
||||
lookahead = num_lookahead_tokens
|
||||
```
|
||||
|
||||
NOTE: for new tokens which include both verified and unverified draft
|
||||
tokens, we only cache the verified tokens (by capping the number at
|
||||
`request.num_tokens`).
|
||||
|
||||
The allocation has three stages:
|
||||
- Free unnecessary blocks in `comp` and check
|
||||
if we have sufficient free blocks (return None if not).
|
||||
- Handle prefix tokens (`comp + new_comp + ext_comp`):
|
||||
- Free unnecessary blocks (e.g. outside sliding window)
|
||||
- Allocate new blocks for `ext_comp` tokens inside
|
||||
sliding window
|
||||
- Allocate new blocks for tokens to be computed (`new + lookahead`)
|
||||
|
||||
Returns:
|
||||
A list of new allocated blocks.
|
||||
"""
|
||||
if num_new_tokens == 0:
|
||||
raise ValueError("num_new_tokens must be greater than 0")
|
||||
# When loading KV data asynchronously, we may have zero new tokens to
|
||||
# compute while still allocating slots for externally computed tokens.
|
||||
if num_new_tokens == 0 and num_external_computed_tokens == 0:
|
||||
raise ValueError(
|
||||
"num_new_tokens must be greater than 0 when there are no "
|
||||
"external computed tokens"
|
||||
)
|
||||
|
||||
if new_computed_blocks is not None:
|
||||
new_computed_block_list = new_computed_blocks.blocks
|
||||
else:
|
||||
new_computed_block_list = self.empty_kv_cache_blocks.blocks
|
||||
|
||||
# The number of computed tokens is the number of computed tokens plus
|
||||
# the new prefix caching hits
|
||||
num_local_computed_tokens = (
|
||||
request.num_computed_tokens + num_new_computed_tokens
|
||||
)
|
||||
total_computed_tokens = min(
|
||||
num_local_computed_tokens + num_external_computed_tokens,
|
||||
self.max_model_len,
|
||||
)
|
||||
num_tokens_need_slot = min(
|
||||
total_computed_tokens + num_new_tokens + num_lookahead_tokens,
|
||||
self.max_model_len,
|
||||
)
|
||||
|
||||
# Free the blocks that are skipped during the attention computation
|
||||
# (e.g., tokens outside the sliding window).
|
||||
# We can do this even if we cannot schedule this request due to
|
||||
@@ -266,15 +319,7 @@ class KVCacheManager:
|
||||
# Should call this function before allocating new blocks to reduce
|
||||
# the number of evicted blocks.
|
||||
self.coordinator.remove_skipped_blocks(
|
||||
request.request_id, request.num_computed_tokens
|
||||
)
|
||||
|
||||
# The number of computed tokens is the number of computed tokens plus
|
||||
# the new prefix caching hits
|
||||
num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens
|
||||
num_tokens_need_slot = min(
|
||||
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
|
||||
self.max_model_len,
|
||||
request.request_id, total_computed_tokens
|
||||
)
|
||||
|
||||
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
|
||||
@@ -282,25 +327,25 @@ class KVCacheManager:
|
||||
num_tokens=num_tokens_need_slot,
|
||||
new_computed_blocks=new_computed_block_list,
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
total_computed_tokens=num_local_computed_tokens
|
||||
+ num_external_computed_tokens,
|
||||
)
|
||||
|
||||
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
||||
# Cannot allocate new blocks
|
||||
return None
|
||||
|
||||
# Touch the computed blocks to make sure they won't be evicted.
|
||||
if self.enable_caching:
|
||||
self.block_pool.touch(new_computed_block_list)
|
||||
else:
|
||||
assert not any(new_computed_block_list), (
|
||||
"Computed blocks should be empty when prefix caching is disabled"
|
||||
)
|
||||
|
||||
if new_computed_block_list is not self.empty_kv_cache_blocks.blocks:
|
||||
if (
|
||||
new_computed_block_list is not self.empty_kv_cache_blocks.blocks
|
||||
or num_external_computed_tokens > 0
|
||||
):
|
||||
# Append the new computed blocks to the request blocks until now to
|
||||
# avoid the case where the new blocks cannot be allocated.
|
||||
self.coordinator.save_new_computed_blocks(
|
||||
request.request_id, new_computed_block_list
|
||||
self.coordinator.allocate_new_computed_blocks(
|
||||
request_id=request.request_id,
|
||||
new_computed_blocks=new_computed_block_list,
|
||||
num_local_computed_tokens=num_local_computed_tokens,
|
||||
num_external_computed_tokens=num_external_computed_tokens,
|
||||
)
|
||||
|
||||
new_blocks = self.coordinator.allocate_new_blocks(
|
||||
@@ -312,12 +357,14 @@ class KVCacheManager:
|
||||
if not self.enable_caching or delay_cache_blocks:
|
||||
return self.create_kv_cache_blocks(new_blocks)
|
||||
|
||||
# NOTE(woosuk): We want to commit (cache) up to num_computed_tokens +
|
||||
# num_new_tokens, but must exclude "non-committable" tokens (e.g.,
|
||||
# draft tokens that could be rejected). Therefore, we cap the number
|
||||
# at `request.num_tokens`, ensuring only "finalized" tokens are cached.
|
||||
# NOTE(woosuk): We want to commit (cache) up to num_local_computed_tokens
|
||||
# + num_external_computed_tokens + num_new_tokens, but must exclude
|
||||
# "non-committable" tokens (e.g., draft tokens that could be rejected).
|
||||
# Therefore, we cap the number at `request.num_tokens`, ensuring only
|
||||
# "finalized" tokens are cached.
|
||||
num_tokens_to_cache = min(
|
||||
num_computed_tokens + num_new_tokens, request.num_tokens
|
||||
total_computed_tokens + num_new_tokens,
|
||||
request.num_tokens,
|
||||
)
|
||||
self.coordinator.cache_blocks(request, num_tokens_to_cache)
|
||||
|
||||
@@ -333,6 +380,19 @@ class KVCacheManager:
|
||||
"""
|
||||
self.coordinator.free(request.request_id)
|
||||
|
||||
def remove_skipped_blocks(
|
||||
self, request_id: str, total_computed_tokens: int
|
||||
) -> None:
|
||||
"""Remove the blocks that are no longer needed from `blocks` and replace
|
||||
the removed blocks with null_block.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
total_computed_tokens: The total number of computed tokens, including
|
||||
local computed tokens and external computed tokens.
|
||||
"""
|
||||
self.coordinator.remove_skipped_blocks(request_id, total_computed_tokens)
|
||||
|
||||
def evict_blocks(self, block_ids: set[int]) -> None:
|
||||
"""evict blocks from the prefix cache by their block IDs.
|
||||
|
||||
@@ -408,7 +468,13 @@ class KVCacheManager:
|
||||
return self.get_blocks(request_id).get_block_ids()
|
||||
|
||||
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
|
||||
"""Cache the blocks for the request, if enabled."""
|
||||
"""Cache the blocks for the request, if enabled.
|
||||
|
||||
Args:
|
||||
request: The request to cache the blocks.
|
||||
num_computed_tokens: The number of computed tokens, including tokens
|
||||
that are already cached and tokens to be cached.
|
||||
"""
|
||||
if self.enable_caching:
|
||||
self.coordinator.cache_blocks(request, num_computed_tokens)
|
||||
|
||||
|
||||
@@ -587,10 +587,11 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens + num_external_computed_tokens,
|
||||
num_new_local_computed_tokens,
|
||||
new_computed_blocks,
|
||||
num_new_tokens,
|
||||
num_new_computed_tokens=num_new_local_computed_tokens,
|
||||
new_computed_blocks=new_computed_blocks,
|
||||
num_lookahead_tokens=effective_lookahead_tokens,
|
||||
num_external_computed_tokens=num_external_computed_tokens,
|
||||
delay_cache_blocks=load_kv_async,
|
||||
num_encoder_tokens=num_encoder_tokens,
|
||||
)
|
||||
@@ -606,7 +607,7 @@ class Scheduler(SchedulerInterface):
|
||||
if self.connector is not None:
|
||||
self.connector.update_state_after_alloc(
|
||||
request,
|
||||
new_computed_blocks + new_blocks,
|
||||
self.kv_cache_manager.get_blocks(request.request_id),
|
||||
num_external_computed_tokens,
|
||||
)
|
||||
|
||||
@@ -1580,6 +1581,13 @@ class Scheduler(SchedulerInterface):
|
||||
if self.connector is None:
|
||||
return False, None
|
||||
|
||||
# Free any out-of-window prefix blocks before we hand the block table to
|
||||
# the connector.
|
||||
self.kv_cache_manager.remove_skipped_blocks(
|
||||
request_id=request.request_id,
|
||||
total_computed_tokens=request.num_tokens,
|
||||
)
|
||||
|
||||
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
|
||||
|
||||
if not isinstance(self.connector, SupportsHMA):
|
||||
|
||||
@@ -30,6 +30,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
self,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
block_pool: BlockPool,
|
||||
enable_caching: bool,
|
||||
kv_cache_group_id: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
@@ -48,6 +49,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
self.block_size *= dcp_world_size * pcp_world_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_pool = block_pool
|
||||
self.enable_caching = enable_caching
|
||||
|
||||
# Mapping from request ID to blocks to track the blocks allocated
|
||||
# for each request, so that we can free the blocks when the request
|
||||
@@ -68,6 +70,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
request_id: str,
|
||||
num_tokens: int,
|
||||
new_computed_blocks: Sequence[KVCacheBlock],
|
||||
total_computed_tokens: int,
|
||||
) -> int:
|
||||
"""
|
||||
Get the number of blocks needed to be allocated for the request.
|
||||
@@ -78,46 +81,121 @@ class SingleTypeKVCacheManager(ABC):
|
||||
tokens that are already allocated).
|
||||
new_computed_blocks: The new computed blocks just hitting the
|
||||
prefix caching.
|
||||
total_computed_tokens: Include both local and external computed
|
||||
tokens.
|
||||
|
||||
Returns:
|
||||
The number of blocks.
|
||||
The number of blocks to allocate.
|
||||
"""
|
||||
|
||||
num_required_blocks = cdiv(num_tokens, self.block_size)
|
||||
num_new_blocks = (
|
||||
num_required_blocks
|
||||
- len(new_computed_blocks)
|
||||
- len(self.req_to_blocks[request_id])
|
||||
)
|
||||
# If a computed block of a request is an eviction candidate (in the
|
||||
# free queue and ref_cnt == 0), it will be changed from a free block
|
||||
# to a computed block when the request is allocated, so we also count
|
||||
# it as needed to be allocated.
|
||||
num_evictable_computed_blocks = sum(
|
||||
blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks
|
||||
)
|
||||
return num_new_blocks + num_evictable_computed_blocks
|
||||
num_req_blocks = len(self.req_to_blocks.get(request_id, ()))
|
||||
|
||||
def save_new_computed_blocks(
|
||||
self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
|
||||
if request_id in self.num_cached_block:
|
||||
# Fast-path: a running request won't have any new prefix-cache hits.
|
||||
assert len(new_computed_blocks) == 0
|
||||
# NOTE: With speculative decoding, request's blocks may be allocated
|
||||
# for draft tokens which are later rejected. In this case,
|
||||
# num_required_blocks may be smaller than num_req_blocks.
|
||||
return max(num_required_blocks - num_req_blocks, 0)
|
||||
|
||||
num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)
|
||||
num_local_computed_blocks = len(new_computed_blocks) + num_req_blocks
|
||||
# Number of whole blocks that are skipped by the attention window.
|
||||
# If nothing is skipped, this is 0.
|
||||
num_skipped_blocks = num_skipped_tokens // self.block_size
|
||||
# We need blocks for the non-skipped suffix. If there are still
|
||||
# local-computed blocks inside the window, they contribute to the
|
||||
# required capacity; otherwise, skipped blocks dominate.
|
||||
num_new_blocks = max(
|
||||
num_required_blocks - max(num_skipped_blocks, num_local_computed_blocks),
|
||||
0,
|
||||
)
|
||||
|
||||
# Among the `new_computed_blocks`, the first `num_skipped_blocks` worth
|
||||
# of blocks are skipped; `num_req_blocks` of those may already be in
|
||||
# `req_to_blocks`, so only skip the remainder from `new_computed_blocks`.
|
||||
num_skipped_new_computed_blocks = max(0, num_skipped_blocks - num_req_blocks)
|
||||
|
||||
# If a computed block is an eviction candidate (in the free queue and
|
||||
# ref_cnt == 0), it will be removed from the free queue when touched by
|
||||
# the allocated request, so we must count it in the free-capacity check.
|
||||
num_evictable_blocks = sum(
|
||||
blk.ref_cnt == 0 and not blk.is_null
|
||||
for blk in new_computed_blocks[num_skipped_new_computed_blocks:]
|
||||
)
|
||||
return num_new_blocks + num_evictable_blocks
|
||||
|
||||
def allocate_new_computed_blocks(
|
||||
self,
|
||||
request_id: str,
|
||||
new_computed_blocks: Sequence[KVCacheBlock],
|
||||
num_local_computed_tokens: int,
|
||||
num_external_computed_tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Add the new computed blocks to the request.
|
||||
Add the new computed blocks to the request. This involves three steps:
|
||||
1. Touch the computed blocks to make sure they won't be evicted.
|
||||
1.5. (Optional) For sliding window, skipped blocks are padded with null blocks.
|
||||
2. Add the remaining computed blocks.
|
||||
3. (Optional) For KV connectors, allocate new blocks for external computed
|
||||
tokens (if any).
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
new_computed_blocks: The new computed blocks just hitting the
|
||||
prefix cache.
|
||||
num_local_computed_tokens: The number of local computed tokens.
|
||||
num_external_computed_tokens: The number of external computed tokens.
|
||||
"""
|
||||
if request_id not in self.num_cached_block:
|
||||
# A new request.
|
||||
req_blocks = self.req_to_blocks[request_id]
|
||||
assert len(req_blocks) == 0
|
||||
req_blocks.extend(new_computed_blocks)
|
||||
self.num_cached_block[request_id] = len(new_computed_blocks)
|
||||
else:
|
||||
# A running request. Should not have new computed blocks.
|
||||
|
||||
if request_id in self.num_cached_block:
|
||||
# Fast-path: a running request won't have any new prefix-cache hits.
|
||||
# It should not have any new computed blocks.
|
||||
assert len(new_computed_blocks) == 0
|
||||
return
|
||||
|
||||
# A new request.
|
||||
req_blocks = self.req_to_blocks[request_id]
|
||||
assert len(req_blocks) == 0
|
||||
num_total_computed_tokens = (
|
||||
num_local_computed_tokens + num_external_computed_tokens
|
||||
)
|
||||
num_skipped_tokens = self.get_num_skipped_tokens(num_total_computed_tokens)
|
||||
num_skipped_blocks = num_skipped_tokens // self.block_size
|
||||
if num_skipped_blocks > 0:
|
||||
# It is possible that all new computed blocks are skipped when
|
||||
# num_skipped_blocks > len(new_computed_blocks).
|
||||
new_computed_blocks = new_computed_blocks[num_skipped_blocks:]
|
||||
# Some external computed tokens may be skipped too.
|
||||
num_external_computed_tokens = min(
|
||||
num_total_computed_tokens - num_skipped_tokens,
|
||||
num_external_computed_tokens,
|
||||
)
|
||||
|
||||
# Touch the computed blocks to make sure they won't be evicted.
|
||||
if self.enable_caching:
|
||||
self.block_pool.touch(new_computed_blocks)
|
||||
else:
|
||||
assert not any(new_computed_blocks), (
|
||||
"Computed blocks should be empty when prefix caching is disabled"
|
||||
)
|
||||
|
||||
# Skipped blocks are padded with null blocks.
|
||||
req_blocks.extend([self._null_block] * num_skipped_blocks)
|
||||
# Add the remaining computed blocks.
|
||||
req_blocks.extend(new_computed_blocks)
|
||||
# All cached hits (including skipped nulls) are already cached; mark
|
||||
# them so cache_blocks() will not try to re-cache blocks that already
|
||||
# have a block_hash set.
|
||||
self.num_cached_block[request_id] = len(req_blocks)
|
||||
|
||||
if num_external_computed_tokens > 0:
|
||||
# Allocate new blocks for external computed tokens.
|
||||
allocated_blocks = self.block_pool.get_new_blocks(
|
||||
cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
|
||||
)
|
||||
req_blocks.extend(allocated_blocks)
|
||||
|
||||
def allocate_new_blocks(
|
||||
self, request_id: str, num_tokens: int
|
||||
@@ -252,7 +330,9 @@ class SingleTypeKVCacheManager(ABC):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
|
||||
def remove_skipped_blocks(
|
||||
self, request_id: str, total_computed_tokens: int
|
||||
) -> None:
|
||||
"""
|
||||
Remove and free the blocks that are no longer needed for attention computation.
|
||||
The removed blocks should be replaced by null_block.
|
||||
@@ -262,18 +342,24 @@ class SingleTypeKVCacheManager(ABC):
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
num_computed_tokens: The number of tokens that have been computed.
|
||||
total_computed_tokens: The total number of computed tokens, including
|
||||
local computed tokens and external computed tokens.
|
||||
"""
|
||||
# Remove the blocks that will be skipped during attention computation.
|
||||
num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
|
||||
num_skipped_tokens = self.get_num_skipped_tokens(total_computed_tokens)
|
||||
if num_skipped_tokens <= 0:
|
||||
# This indicates that ALL tokens are inside attention window.
|
||||
# Thus we do not need to free any blocks outside attention window.
|
||||
# A typical case is full attention that we never free any token
|
||||
# before the request is finished.
|
||||
return
|
||||
num_skipped_blocks = num_skipped_tokens // self.block_size
|
||||
blocks = self.req_to_blocks[request_id]
|
||||
num_skipped_blocks = num_skipped_tokens // self.block_size
|
||||
# `num_skipped_tokens` may include tokens that haven't been allocated yet
|
||||
# (e.g., when the attention window moves into the external computed tokens
|
||||
# range), so we must cap to the number of blocks that currently exist for
|
||||
# this request.
|
||||
num_skipped_blocks = min(num_skipped_blocks, len(blocks))
|
||||
removed_blocks: list[KVCacheBlock] = []
|
||||
# Because the block starts from index 0, the num_skipped_block-th block
|
||||
# corresponds to index num_skipped_blocks - 1.
|
||||
@@ -486,7 +572,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
Returns:
|
||||
The number of tokens that will be skipped for attention computation.
|
||||
"""
|
||||
return num_computed_tokens - self.sliding_window + 1
|
||||
return max(0, num_computed_tokens - self.sliding_window + 1)
|
||||
|
||||
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
|
||||
"""
|
||||
@@ -711,6 +797,7 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
request_id: str,
|
||||
num_tokens: int,
|
||||
new_computed_blocks: Sequence[KVCacheBlock],
|
||||
total_computed_tokens: int,
|
||||
) -> int:
|
||||
# Allocate extra `num_speculative_blocks` blocks for
|
||||
# speculative decoding (MTP/EAGLE) with linear attention.
|
||||
@@ -721,7 +808,7 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
* self.kv_cache_spec.num_speculative_blocks
|
||||
)
|
||||
return super().get_num_blocks_to_allocate(
|
||||
request_id, num_tokens, new_computed_blocks
|
||||
request_id, num_tokens, new_computed_blocks, total_computed_tokens
|
||||
)
|
||||
|
||||
def allocate_new_blocks(
|
||||
@@ -749,8 +836,12 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
class CrossAttentionManager(SingleTypeKVCacheManager):
|
||||
"""Manager for cross-attention KV cache in encoder-decoder models."""
|
||||
|
||||
def save_new_computed_blocks(
|
||||
self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
|
||||
def allocate_new_computed_blocks(
|
||||
self,
|
||||
request_id: str,
|
||||
new_computed_blocks: Sequence[KVCacheBlock],
|
||||
num_local_computed_tokens: int,
|
||||
num_external_computed_tokens: int,
|
||||
) -> None:
|
||||
# We do not cache blocks for cross-attention to be shared between
|
||||
# requests, so `new_computed_blocks` should always be empty.
|
||||
|
||||
@@ -624,7 +624,7 @@ class Worker(WorkerBase):
|
||||
output = self.model_runner.execute_model(
|
||||
scheduler_output, intermediate_tensors
|
||||
)
|
||||
if isinstance(output, (ModelRunnerOutput, NoneType)):
|
||||
if isinstance(output, ModelRunnerOutput | NoneType):
|
||||
return output
|
||||
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
|
||||
@@ -304,6 +304,13 @@ class TPUWorker:
|
||||
|
||||
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Allocate GPU KV cache with the specified kv_cache_config."""
|
||||
# Init kv cache connector here, because it requires
|
||||
# `kv_cache_config`.
|
||||
# NOTE(Kuntai): This needs to be done before `initialize_kv_cache`,
|
||||
# because `initialize_kv_cache` will inject kv cache groups not
|
||||
# related to kv cache connector (e.g. kv cache sharing layers).
|
||||
ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)
|
||||
|
||||
self.model_runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
def check_health(self) -> None:
|
||||
@@ -336,8 +343,6 @@ class TPUWorker:
|
||||
parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size
|
||||
)
|
||||
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self.model_runner.ensure_kv_transfer_shutdown()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user