[v1] Move block_hashes from KVCacheManager to Request.block_hashes (#19728)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py
@@ -2,15 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import defaultdict
 from collections.abc import Iterable
-from typing import Callable, Optional
+from typing import Optional

 from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
                                         BlockStored, KVCacheEvent)
 from vllm.logger import init_logger
-from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
-                                         FreeKVCacheBlockQueue, KVCacheBlock,
-                                         generate_block_hash_extra_keys,
-                                         hash_block_tokens)
+from vllm.v1.core.kv_cache_utils import (BlockHashWithGroupId,
+                                         FreeKVCacheBlockQueue, KVCacheBlock)
 from vllm.v1.request import Request

 logger = init_logger(__name__)
@@ -97,84 +95,39 @@ class BlockPool:
         self,
         request: Request,
         blocks: list[KVCacheBlock],
-        block_hashes: list[BlockHash],
         num_cached_blocks: int,
         num_full_blocks: int,
-        block_size: int,
         kv_cache_group_id: int,
-        hash_fn: Callable,
     ) -> None:
         """Cache a list of full blocks for prefix caching.
         This function takes a list of blocks that will have their block hash
-        metadata to be updated and cached. Given a request, it computes the
-        block hashes for the blocks starting from `num_cached_blocks` to
-        `num_full_blocks`, updating the metadata for each block
-        and caching them in the `cached_block_hash_to_block`.
+        metadata to be updated and cached. Given a request, it updates the
+        metadata for each block and caches it in
+        `cached_block_hash_to_block`.
+        The block hash values are computed by the Request object immediately
+        when it is created and when new tokens are appended.

         Args:
             request: The request to cache the blocks.
             blocks: All blocks in the request.
-            block_hashes: Block hashes of the blocks in the request. Note that
-                this list may be shorter than the blocks list. In this case the
-                missed block hash will be computed in this function.
             num_cached_blocks: The number of blocks that are already cached.
             num_full_blocks: The number of blocks that are full and should
                 be cached after this function.
-            block_size: Number of tokens in each block.
             kv_cache_group_id: The id of the KV cache group.
-            hash_fn: The hash function to use for block hashes.
         """
         if num_cached_blocks == num_full_blocks:
             return
         new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
-        assert len(block_hashes) >= num_cached_blocks
-        new_block_hashes = block_hashes[num_cached_blocks:]
+        assert len(request.block_hashes) >= num_full_blocks
+        new_block_hashes = request.block_hashes[num_cached_blocks:]

-        # Update the new blocks with the block hashes through the chain.
-        if num_cached_blocks == 0:
-            prev_block_hash_value = None
-        else:
-            prev_block = blocks[num_cached_blocks - 1]
-            assert prev_block.block_hash is not None
-            prev_block_hash_value = prev_block.block_hash.get_hash_value()
-
-        parent_block_hash = prev_block_hash_value
         new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
                                            else None)
         for i, blk in enumerate(new_full_blocks):
             assert blk.block_hash is None
-
-            if i < len(new_block_hashes):
-                # The block hash may already be computed in
-                # "get_computed_blocks" if the tokens are not generated by
-                # this request (either the prompt tokens or the previously
-                # generated tokens with preemption), or by other
-                # single_type_managers with the same block_size.
-                # In this case we simply reuse the block hash.
-                block_hash = new_block_hashes[i]
-            else:
-                # Otherwise compute the block hash and cache it in the request
-                # in case it will be preempted in the future.
-                blk_idx = num_cached_blocks + i
-                start_token_idx = blk_idx * block_size
-                end_token_idx = (blk_idx + 1) * block_size
-                block_tokens = request.all_token_ids[
-                    start_token_idx:end_token_idx]
-                assert len(block_tokens) == block_size, (
-                    f"Expected {block_size} tokens, got "
-                    f"{len(block_tokens)} at {blk_idx}th block for request "
-                    f"{request.request_id}({request})")
-
-                # Generate extra keys for multi-modal inputs. Note that since
-                # we reach to this branch only when the block is completed with
-                # generated tokens, we only need to consider the last mm input.
-                extra_keys, _ = generate_block_hash_extra_keys(
-                    request, start_token_idx, end_token_idx, -1)
-
-                # Compute the hash of the current block.
-                block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
-                                               block_tokens, extra_keys)
-                block_hashes.append(block_hash)
+            block_hash = new_block_hashes[i]

             # Update and added the full block to the cache.
             block_hash_with_group_id = BlockHashWithGroupId(
@@ -184,9 +137,15 @@ class BlockPool:
                     blk.block_id] = blk
             if new_hashes is not None:
                 new_hashes.append(block_hash.hash_value)
-            prev_block_hash_value = block_hash.hash_value

         if self.enable_kv_cache_events:
+            if num_cached_blocks == 0:
+                parent_block_hash = None
+            else:
+                parent_block = blocks[num_cached_blocks - 1]
+                assert parent_block.block_hash is not None
+                parent_block_hash = parent_block.block_hash.get_hash_value()
+
             self.kv_event_queue.append(
                 BlockStored(
                     block_hashes=new_hashes,
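Editor's note: after this change, `cache_full_blocks` no longer hashes tokens itself; it assumes `request.block_hashes` already covers every full block, so the pool only attaches the precomputed hashes to blocks and indexes them for reuse. Below is a minimal, hypothetical sketch of that indexing step, with simplified stand-ins for `KVCacheBlock` and `BlockHashWithGroupId` (not vLLM's real types):

```python
from dataclasses import dataclass, field

@dataclass
class Block:
    """Stand-in for KVCacheBlock (hypothetical, simplified)."""
    block_id: int
    block_hash: "tuple[int, int] | None" = None  # (hash_value, group_id)

@dataclass
class PoolSketch:
    # (hash_value, group_id) -> {block_id: Block}, loosely mirroring
    # cached_block_hash_to_block keyed by BlockHashWithGroupId.
    cached: dict = field(default_factory=dict)

    def cache_full_blocks(self, block_hashes: list[int], blocks: list[Block],
                          num_cached_blocks: int, num_full_blocks: int,
                          group_id: int) -> None:
        # The request-owned hash list must already cover every full block.
        assert len(block_hashes) >= num_full_blocks
        for i, blk in enumerate(blocks[num_cached_blocks:num_full_blocks]):
            key = (block_hashes[num_cached_blocks + i], group_id)
            blk.block_hash = key
            self.cached.setdefault(key, {})[blk.block_id] = blk

pool = PoolSketch()
blocks = [Block(0), Block(1), Block(2)]
pool.cache_full_blocks([101, 102], blocks,
                       num_cached_blocks=0, num_full_blocks=2, group_id=0)
assert blocks[0].block_hash == (101, 0) and blocks[2].block_hash is None
```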
diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
-from typing import Callable, Optional
+from typing import Optional

 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
@@ -23,7 +23,6 @@ class KVCacheCoordinator(ABC):
         max_model_len: int,
         use_eagle: bool,
         enable_caching: bool,
-        caching_hash_fn: Callable,
         enable_kv_cache_events: bool,
     ):
         self.kv_cache_config = kv_cache_config
@@ -40,7 +39,6 @@ class KVCacheCoordinator(ABC):
             kv_cache_spec=kv_cache_group.kv_cache_spec,
             block_pool=self.block_pool,
             kv_cache_group_id=i,
-            caching_hash_fn=caching_hash_fn,
         ) for i, kv_cache_group in enumerate(
             self.kv_cache_config.kv_cache_groups))

@@ -99,19 +97,17 @@ class KVCacheCoordinator(ABC):
             manager.allocate_new_blocks(request_id, num_tokens)
             for manager in self.single_type_managers)

-    def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
-                     num_computed_tokens: int) -> None:
+    def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
         """
         Cache the blocks for the request.

         Args:
             request: The request.
-            block_hashes: The block hashes of the request.
             num_tokens: The total number of tokens that need to be cached
                 (including tokens that are already cached).
         """
         for manager in self.single_type_managers:
-            manager.cache_blocks(request, block_hashes, num_computed_tokens)
+            manager.cache_blocks(request, num_computed_tokens)

     def free(self, request_id: str) -> None:
         """
@@ -184,10 +180,9 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
     """

     def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
-                 use_eagle: bool, caching_hash_fn: Callable,
-                 enable_kv_cache_events: bool):
+                 use_eagle: bool, enable_kv_cache_events: bool):
         super().__init__(kv_cache_config, max_model_len, use_eagle, False,
-                         caching_hash_fn, enable_kv_cache_events)
+                         enable_kv_cache_events)
         self.num_single_type_manager = len(self.single_type_managers)

     def get_num_common_prefix_blocks(self, request_id: str,
@@ -213,10 +208,9 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):

     def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
                  use_eagle: bool, enable_caching: bool,
-                 caching_hash_fn: Callable, enable_kv_cache_events: bool):
+                 enable_kv_cache_events: bool):
         super().__init__(kv_cache_config, max_model_len, use_eagle,
-                         enable_caching, caching_hash_fn,
-                         enable_kv_cache_events)
+                         enable_caching, enable_kv_cache_events)
         self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
             0].kv_cache_spec
         self.block_size = self.kv_cache_spec.block_size
@@ -250,10 +244,9 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):

     def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
                  use_eagle: bool, enable_caching: bool,
-                 caching_hash_fn: Callable, enable_kv_cache_events: bool):
+                 enable_kv_cache_events: bool):
         super().__init__(kv_cache_config, max_model_len, use_eagle,
-                         enable_caching, caching_hash_fn,
-                         enable_kv_cache_events)
+                         enable_caching, enable_kv_cache_events)
         self.verify_and_split_kv_cache_groups()

     def verify_and_split_kv_cache_groups(self) -> None:
@@ -386,17 +379,15 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):

 def get_kv_cache_coordinator(
         kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
-        enable_caching: bool, caching_hash_fn: Callable,
+        enable_caching: bool,
         enable_kv_cache_events: bool) -> KVCacheCoordinator:
     if not enable_caching:
         return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
-                                               use_eagle, caching_hash_fn,
+                                               use_eagle,
                                                enable_kv_cache_events)
     if len(kv_cache_config.kv_cache_groups) == 1:
         return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
                                          use_eagle, enable_caching,
-                                         caching_hash_fn,
                                          enable_kv_cache_events)
     return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
-                                    enable_caching, caching_hash_fn,
-                                    enable_kv_cache_events)
+                                    enable_caching, enable_kv_cache_events)
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
@@ -1,16 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from collections import defaultdict
 from dataclasses import dataclass
 from typing import Optional

 from vllm.distributed.kv_events import KVCacheEvent
 from vllm.logger import init_logger
-from vllm.utils import sha256, sha256_cbor_64bit
 from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
-from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
-                                         hash_request_tokens, init_none_hash)
+from vllm.v1.core.kv_cache_utils import KVCacheBlock
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request, RequestStatus
@@ -71,23 +68,13 @@ class KVCacheManager:
         kv_cache_config: KVCacheConfig,
         max_model_len: int,
         enable_caching: bool = True,
-        caching_hash_algo: str = "builtin",
         use_eagle: bool = False,
         log_stats: bool = False,
         enable_kv_cache_events: bool = False,
     ) -> None:
         self.max_model_len = max_model_len

         if len(kv_cache_config.kv_cache_groups) == 0:
             # Attention free models don't have kv cache,
             # thus don't need prefix caching.
             enable_caching = False
         self.enable_caching = enable_caching

-        self.caching_hash_fn = (
-            sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else
-            sha256 if caching_hash_algo == "sha256" else hash)
-        init_none_hash(self.caching_hash_fn)
         self.use_eagle = use_eagle
         self.log_stats = log_stats
         # FIXME: make prefix cache stats conditional on log_stats
@@ -107,19 +94,12 @@ class KVCacheManager:
             max_model_len=self.max_model_len,
             use_eagle=self.use_eagle,
             enable_caching=self.enable_caching,
-            caching_hash_fn=self.caching_hash_fn,
             enable_kv_cache_events=enable_kv_cache_events,
         )
         self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
         self.block_pool = self.coordinator.block_pool
         self.kv_cache_config = kv_cache_config

-        # Mapping from request ID to kv block hashes.
-        # This is to avoid recomputing the block hashes for each call of
-        # `get_computed_blocks` or `allocate_slots`.
-        self.req_to_block_hashes: defaultdict[
-            str, list[BlockHash]] = defaultdict(list)
-
     @property
     def usage(self) -> float:
         """Get the KV cache usage.
@@ -161,15 +141,6 @@ class KVCacheManager:
                 and request.sampling_params.prompt_logprobs is not None)):
             return self.create_empty_block_list(), 0

-        # The block hashes for the request may already be computed
-        # if the scheduler has tried to schedule the request before.
-        block_hashes = self.req_to_block_hashes[request.request_id]
-        if not block_hashes:
-            assert self.block_size is not None
-            block_hashes = hash_request_tokens(self.caching_hash_fn,
-                                               self.block_size, request)
-            self.req_to_block_hashes[request.request_id] = block_hashes
-
         # NOTE: When all tokens hit the cache, we must recompute the last token
         # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
         # This can trigger recomputation of an entire block, rather than just
@@ -178,7 +149,7 @@ class KVCacheManager:
         # could slightly improve performance in the future.
         max_cache_hit_length = request.num_tokens - 1
         computed_blocks, num_new_computed_tokens = (
-            self.coordinator.find_longest_cache_hit(block_hashes,
+            self.coordinator.find_longest_cache_hit(request.block_hashes,
                                                     max_cache_hit_length))

         if self.log_stats:
@@ -296,11 +267,7 @@ class KVCacheManager:
         # at `request.num_tokens`, ensuring only "finalized" tokens are cached.
         num_tokens_to_cache = min(num_computed_tokens + num_new_tokens,
                                   request.num_tokens)
-        self.coordinator.cache_blocks(
-            request,
-            self.req_to_block_hashes[request.request_id],
-            num_tokens_to_cache,
-        )
+        self.coordinator.cache_blocks(request, num_tokens_to_cache)

         return KVCacheBlocks(new_blocks)

@@ -373,14 +340,6 @@ class KVCacheManager:
         return self.coordinator.get_num_common_prefix_blocks(
             request.request_id, num_running_requests)

-    def free_block_hashes(self, request: Request) -> None:
-        """Discard the block hashes for the request.
-
-        NOTE: Unlike `free`, this method should be called only when the request
-        is finished, not when it is preempted.
-        """
-        self.req_to_block_hashes.pop(request.request_id, None)
-
     def take_events(self) -> list[KVCacheEvent]:
         """Take the KV cache events from the block pool.

@@ -397,9 +356,7 @@ class KVCacheManager:
     def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
         """Cache the blocks for the request, if enabled."""
         if self.enable_caching:
-            block_hashes = self.req_to_block_hashes[request.request_id]
-            self.coordinator.cache_blocks(request, block_hashes,
-                                          num_computed_tokens)
+            self.coordinator.cache_blocks(request, num_computed_tokens)

     def create_empty_block_list(self) -> KVCacheBlocks:
         """Creates a new KVCacheBlocks instance with no blocks."""
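Editor's note: the net effect of the deletions in this file is that the manager no longer owns a `req_to_block_hashes` side table, so there is nothing to populate lazily and nothing to pop when a request finishes; the hashes share the request's lifetime. A hedged sketch of the resulting lookup path (the wrapper function is hypothetical; `find_longest_cache_hit` and `request.block_hashes` are the real identifiers from this diff):

```python
def get_computed_blocks_sketch(coordinator, request, max_cache_hit_length: int):
    # Before: hashes were fetched from (or lazily inserted into) a dict keyed
    # by request_id, then explicitly popped in free_block_hashes().
    # After: the request carries its own up-to-date hashes.
    return coordinator.find_longest_cache_hit(request.block_hashes,
                                              max_cache_hit_length)
```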
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
@@ -547,41 +547,61 @@ def hash_block_tokens(
                              curr_block_token_ids_tuple, extra_keys)


-def hash_request_tokens(hash_function: Any, block_size: int,
-                        request: Request) -> list[BlockHash]:
-    """Computes hash values of a chain of blocks given a sequence of
-    token IDs. The hash value is used for prefix caching.
-
-    Args:
-        block_size: The size of each block.
-        request: The request object.
-
-    Returns:
-        The list of computed hash values.
+def get_request_block_hasher(
+        block_size: int,
+        caching_hash_fn: Callable[[Any],
+                                  int]) -> Callable[[Request], list[BlockHash]]:
     """
-    token_ids = request.all_token_ids
+    Returns a function which computes the list of not-yet-computed block
+    hashes of a request.

-    req_need_extra_keys = need_extra_keys(request)
-    req_extra_keys = None
-    curr_mm_idx = 0
+    Each request holds a list of its block hashes (request.block_hashes).
+    When a request is created, it calls the below function to compute
+    the hashes of all full blocks of the request's initial tokens.
+    The hashes are then stored in request.block_hashes.
+    Later, whenever new tokens are appended to the request, it calls
+    the below function again to compute any new full blocks of tokens.
+    The returned new hashes are appended to request.block_hashes.
+    """

-    ret = []
-    parent_block_hash_value = None
-    # Only full blocks will be hashed
-    for start in range(0, len(token_ids) - block_size + 1, block_size):
-        end = start + block_size
-        block_token_ids = token_ids[start:end]
+    def request_block_hasher(request: Request) -> list[BlockHash]:
+        start_token_idx = len(request.block_hashes) * block_size
+        num_tokens = request.num_tokens
+
+        curr_mm_idx = 0
+        if start_token_idx > 0:
+            # Set curr_mm_idx = -1 to indicate the last mm input.
+            # Note that since we reach to this branch only when the block is
+            # completed with generated tokens, we only need to consider the
+            # last mm input.
+            curr_mm_idx = -1
+
+        prev_block_hash_value = request.block_hashes[-1].hash_value \
+            if request.block_hashes else None
+        new_block_hashes: list[BlockHash] = []
+        while True:
+            end_token_idx = start_token_idx + block_size
+            if end_token_idx > num_tokens:
+                # We only hash full blocks
+                break

-        if req_need_extra_keys:
-            # MM and LoRA requests need extra keys for block-hash computation.
-            req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
-                request, start, end, curr_mm_idx)
+            extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
+                request, start_token_idx, end_token_idx, curr_mm_idx)

-        block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
-                                       block_token_ids, req_extra_keys)
-        ret.append(block_hash)
-        parent_block_hash_value = block_hash.hash_value
-    return ret
+            # Compute the hash of the current block
+            block_tokens = request.all_token_ids[start_token_idx:end_token_idx]
+            block_hash = hash_block_tokens(caching_hash_fn,
+                                           prev_block_hash_value, block_tokens,
+                                           extra_keys)
+
+            new_block_hashes.append(block_hash)
+            start_token_idx += block_size
+            prev_block_hash_value = block_hash.hash_value
+
+        return new_block_hashes
+
+    return request_block_hasher


 def max_memory_usage_bytes(vllm_config: VllmConfig,
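Editor's note: a runnable sketch of the incremental-hashing pattern `get_request_block_hasher` introduces, using a stub request and Python's builtin `hash`. Extra-key handling for multi-modal and LoRA inputs is omitted, and all names besides the pattern itself are hypothetical:

```python
from typing import Callable, Optional

class StubRequest:
    """Stand-in for vllm.v1.request.Request (hypothetical, simplified)."""

    def __init__(self, token_ids: list[int]) -> None:
        self.all_token_ids = list(token_ids)
        self.block_hashes: list[int] = []

    @property
    def num_tokens(self) -> int:
        return len(self.all_token_ids)

def get_block_hasher(block_size: int) -> Callable[[StubRequest], list[int]]:
    """Returns a closure that hashes only blocks not yet in block_hashes."""

    def hasher(request: StubRequest) -> list[int]:
        # Resume right after the last block that already has a hash.
        start = len(request.block_hashes) * block_size
        prev: Optional[int] = (request.block_hashes[-1]
                               if request.block_hashes else None)
        new_hashes: list[int] = []
        while start + block_size <= request.num_tokens:  # full blocks only
            block = tuple(request.all_token_ids[start:start + block_size])
            prev = hash((prev, block))  # chain each block to its parent
            new_hashes.append(prev)
            start += block_size
        return new_hashes

    return hasher

hasher = get_block_hasher(block_size=4)
req = StubRequest([1, 2, 3, 4, 5])
req.block_hashes += hasher(req)    # hashes the one full block [1, 2, 3, 4]
req.all_token_ids += [6, 7, 8]     # a second block becomes full
req.block_hashes += hasher(req)    # hashes only [5, 6, 7, 8]
assert len(req.block_hashes) == 2
```

Chaining each hash to its parent means equal hash values imply an equal prefix of blocks, not merely equal block contents, which is what makes the values usable as prefix-cache keys.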
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
@@ -155,7 +155,6 @@ class Scheduler(SchedulerInterface):
             kv_cache_config=kv_cache_config,
             max_model_len=self.max_model_len,
             enable_caching=self.cache_config.enable_prefix_caching,
-            caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
             use_eagle=self.use_eagle,
             log_stats=self.log_stats,
             enable_kv_cache_events=self.enable_kv_cache_events,
@@ -1036,7 +1035,6 @@ class Scheduler(SchedulerInterface):
     def _free_blocks(self, request: Request):
         assert request.is_finished()
         self.kv_cache_manager.free(request)
-        self.kv_cache_manager.free_block_hashes(request)
         del self.requests[request.request_id]

     def get_num_unfinished_requests(self) -> int:
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -3,7 +3,6 @@

 import itertools
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from typing import Callable

 from vllm.utils import cdiv
 from vllm.v1.core.block_pool import BlockPool
@@ -25,7 +24,6 @@ class SingleTypeKVCacheManager(ABC):
         kv_cache_spec: KVCacheSpec,
         block_pool: BlockPool,
         kv_cache_group_id: int,
-        caching_hash_fn: Callable,
     ) -> None:
         """
         Initializes the SingleTypeKVCacheManager.
@@ -33,7 +31,6 @@ class SingleTypeKVCacheManager(ABC):
             kv_cache_spec: The kv_cache_spec for this manager.
             block_pool: The block pool.
             kv_cache_group_id: The id of the kv cache group of this manager.
-            caching_hash_fn: The caching hash function.
         """

         self.block_size = kv_cache_spec.block_size
@@ -52,7 +49,6 @@ class SingleTypeKVCacheManager(ABC):
         # data for reempted ones.
         self.num_cached_block: dict[str, int] = {}

-        self.caching_hash_fn = caching_hash_fn
         self.kv_cache_group_id = kv_cache_group_id
         self._null_block = block_pool.null_block

@@ -130,14 +126,12 @@ class SingleTypeKVCacheManager(ABC):
         req_blocks.extend(new_blocks)
         return new_blocks

-    def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
-                     num_tokens: int) -> None:
+    def cache_blocks(self, request: Request, num_tokens: int) -> None:
         """
         Cache the blocks for the request.

         Args:
             request: The request.
-            block_hashes: The block hashes of the request.
             num_tokens: The total number of tokens that need to be cached
                 (including tokens that are already cached).
         """
@@ -147,12 +141,10 @@ class SingleTypeKVCacheManager(ABC):
         self.block_pool.cache_full_blocks(
             request=request,
             blocks=self.req_to_blocks[request.request_id],
-            block_hashes=block_hashes,
             num_cached_blocks=num_cached_blocks,
             num_full_blocks=num_full_blocks,
-            block_size=self.block_size,
             kv_cache_group_id=self.kv_cache_group_id,
-            hash_fn=self.caching_hash_fn,
         )

         self.num_cached_block[request.request_id] = num_full_blocks
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
@@ -25,9 +25,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.tasks import POOLING_TASKS, SupportedTask
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
-from vllm.utils import (decorate_logs, make_zmq_socket,
+from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket,
                         resolve_obj_by_qualname, set_process_title)
-from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
+from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config,
+                                         get_request_block_hasher,
+                                         init_none_hash,
                                          unify_kv_cache_configs)
 from vllm.v1.core.sched.interface import SchedulerInterface
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -140,6 +142,19 @@ class EngineCore:
                              self.batch_queue_size)
             self.batch_queue = queue.Queue(self.batch_queue_size)

+        self.request_block_hasher: Optional[Callable[[Request],
+                                                     list[BlockHash]]] = None
+        if (self.vllm_config.cache_config.enable_prefix_caching
+                or self.scheduler.get_kv_connector() is not None):
+
+            block_size = vllm_config.cache_config.block_size
+            caching_hash_fn = get_hash_fn_by_name(
+                vllm_config.cache_config.prefix_caching_hash_algo)
+            init_none_hash(caching_hash_fn)
+
+            self.request_block_hasher = get_request_block_hasher(
+                block_size, caching_hash_fn)
+
     def _initialize_kv_caches(
             self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
         start = time.time()
@@ -417,7 +432,8 @@ class EngineCore:
             request.mm_kwargs = self.mm_input_cache_server.get_and_update(
                 request.mm_kwargs, request.mm_hashes)

-        req = Request.from_engine_core_request(request)
+        req = Request.from_engine_core_request(request,
+                                               self.request_block_hasher)
        if req.use_structured_output:
             # Note on thread safety: no race condition.
             # `grammar_init` is only invoked in input processing thread. For
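Editor's note: the hash algorithm stays pluggable (`prefix_caching_hash_algo`) because cache keys may need to be stable across processes; Python's builtin `hash` is salted per process for `str` and `bytes`, whereas a sha256-based hash is reproducible. A small illustration of a stable hash, offered as an assumption about the intent rather than vLLM's exact implementation:

```python
import hashlib

def stable_hash(obj) -> int:
    """Deterministic across processes, unlike builtin hash() on str/bytes."""
    return int.from_bytes(
        hashlib.sha256(repr(obj).encode()).digest()[:8], "big")

# The same chained-hash input always yields the same key, so caches keyed
# by it can be shared or persisted across engine processes.
assert stable_hash((None, (1, 2, 3))) == stable_hash((None, (1, 2, 3)))
```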
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
@@ -3,7 +3,8 @@

 import enum
 import time
-from typing import TYPE_CHECKING, Any, Optional, Union
+from functools import partial
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.pooling_params import PoolingParams
@@ -16,6 +17,7 @@ from vllm.v1.utils import ConstantList

 if TYPE_CHECKING:
     from vllm.lora.request import LoRARequest
+    from vllm.v1.core.kv_cache_utils import BlockHash


 class Request:
@@ -36,6 +38,8 @@ class Request:
         structured_output_request: Optional["StructuredOutputRequest"] = None,
         cache_salt: Optional[str] = None,
         priority: int = 0,
+        block_hasher: Optional[Callable[["Request"],
+                                        list["BlockHash"]]] = None,
     ) -> None:
         self.request_id = request_id
         self.client_index = client_index
@@ -108,8 +112,18 @@ class Request:
         # indicates that the output is corrupted
         self.num_nans_in_logits = 0

+        self.block_hashes: list[BlockHash] = []
+        self.get_hash_new_full_blocks: Optional[Callable[
+            [], list[BlockHash]]] = None
+        if block_hasher is not None:
+            self.get_hash_new_full_blocks = partial(block_hasher, self)
+            self.block_hashes = self.get_hash_new_full_blocks()
+
     @classmethod
-    def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
+    def from_engine_core_request(
+            cls, request: EngineCoreRequest,
+            block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
+    ) -> "Request":
         if request.mm_kwargs is not None:
             assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
                 "mm_kwargs was not updated in EngineCore.add_request")
@@ -131,6 +145,7 @@ class Request:
             if request.sampling_params else None,
             cache_salt=request.cache_salt,
             priority=request.priority,
+            block_hasher=block_hasher,
         )

     def append_output_token_ids(
@@ -144,6 +159,9 @@ class Request:
         self._output_token_ids.extend(token_ids)
         self._all_token_ids.extend(token_ids)

+        if self.get_hash_new_full_blocks is not None:
+            self.block_hashes.extend(self.get_hash_new_full_blocks())
+
     @property
     def is_output_corrupted(self) -> bool:
         return self.num_nans_in_logits > 0
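Editor's note: a minimal sketch of the `functools.partial` binding used above. The engine-wide hasher takes a `Request` argument, and `partial` pins it to the instance once at construction, so `get_hash_new_full_blocks()` can later be called with no arguments (stub types, hypothetical behavior):

```python
from functools import partial
from typing import Callable, Optional

def shared_hasher(request: "MiniRequest") -> list[int]:
    # Stand-in for the engine-wide request_block_hasher: pretend each call
    # discovers exactly one new full block.
    return [len(request.block_hashes)]

class MiniRequest:
    def __init__(self, block_hasher: Optional[Callable] = None) -> None:
        self.block_hashes: list[int] = []
        self.get_hash_new_full_blocks = (partial(block_hasher, self)
                                         if block_hasher else None)
        if self.get_hash_new_full_blocks is not None:
            self.block_hashes = self.get_hash_new_full_blocks()

    def append_output_token_ids(self) -> None:
        # Mirrors Request.append_output_token_ids: extend hashes in place.
        if self.get_hash_new_full_blocks is not None:
            self.block_hashes.extend(self.get_hash_new_full_blocks())

req = MiniRequest(shared_hasher)
req.append_output_token_ids()
assert req.block_hashes == [0, 1]
```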