[v1] Move block_hashes from KVCacheManager to Request.block_hashes (#19728)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Author: Or Ozeri
Date: 2025-08-16 02:52:52 +03:00
Committed by: GitHub
parent b9dc9d2607
commit c280066f9d
19 changed files with 381 additions and 335 deletions

View File

@@ -2,15 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable
-from typing import Callable, Optional
+from typing import Optional
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
BlockStored, KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
-FreeKVCacheBlockQueue, KVCacheBlock,
-generate_block_hash_extra_keys,
-hash_block_tokens)
+FreeKVCacheBlockQueue, KVCacheBlock)
from vllm.v1.request import Request
logger = init_logger(__name__)
@@ -97,84 +95,39 @@ class BlockPool:
self,
request: Request,
blocks: list[KVCacheBlock],
-block_hashes: list[BlockHash],
num_cached_blocks: int,
num_full_blocks: int,
block_size: int,
kv_cache_group_id: int,
-hash_fn: Callable,
) -> None:
"""Cache a list of full blocks for prefix caching.
This function takes a list of blocks that will have their block hash
-metadata to be updated and cached. Given a request, it computes the
-block hashes for the blocks starting from `num_cached_blocks` to
-`num_full_blocks`, updating the metadata for each block
-and caching them in the `cached_block_hash_to_block`.
+metadata to be updated and cached. Given a request, it updates the
+metadata for each block and caches it in
+`cached_block_hash_to_block`.
+The block hash values are computed by the Request object immediately
+when it is created and when new tokens are appended.
Args:
request: The request to cache the blocks.
blocks: All blocks in the request.
-block_hashes: Block hashes of the blocks in the request. Note that
-this list may be shorter than the blocks list. In this case the
-missed block hash will be computed in this function.
num_cached_blocks: The number of blocks that are already cached.
num_full_blocks: The number of blocks that are full and should
be cached after this function.
block_size: Number of tokens in each block.
kv_cache_group_id: The id of the KV cache group.
-hash_fn: The hash function to use for block hashes.
"""
if num_cached_blocks == num_full_blocks:
return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
-assert len(block_hashes) >= num_cached_blocks
-new_block_hashes = block_hashes[num_cached_blocks:]
+assert len(request.block_hashes) >= num_full_blocks
+new_block_hashes = request.block_hashes[num_cached_blocks:]
-# Update the new blocks with the block hashes through the chain.
-if num_cached_blocks == 0:
-prev_block_hash_value = None
-else:
-prev_block = blocks[num_cached_blocks - 1]
-assert prev_block.block_hash is not None
-prev_block_hash_value = prev_block.block_hash.get_hash_value()
-parent_block_hash = prev_block_hash_value
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
else None)
for i, blk in enumerate(new_full_blocks):
assert blk.block_hash is None
-if i < len(new_block_hashes):
-# The block hash may already be computed in
-# "get_computed_blocks" if the tokens are not generated by
-# this request (either the prompt tokens or the previously
-# generated tokens with preemption), or by other
-# single_type_managers with the same block_size.
-# In this case we simply reuse the block hash.
-block_hash = new_block_hashes[i]
-else:
-# Otherwise compute the block hash and cache it in the request
-# in case it will be preempted in the future.
-blk_idx = num_cached_blocks + i
-start_token_idx = blk_idx * block_size
-end_token_idx = (blk_idx + 1) * block_size
-block_tokens = request.all_token_ids[
-start_token_idx:end_token_idx]
-assert len(block_tokens) == block_size, (
-f"Expected {block_size} tokens, got "
-f"{len(block_tokens)} at {blk_idx}th block for request "
-f"{request.request_id}({request})")
-# Generate extra keys for multi-modal inputs. Note that since
-# we reach this branch only when the block is completed with
-# generated tokens, we only need to consider the last mm input.
-extra_keys, _ = generate_block_hash_extra_keys(
-request, start_token_idx, end_token_idx, -1)
-# Compute the hash of the current block.
-block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
-block_tokens, extra_keys)
-block_hashes.append(block_hash)
+block_hash = new_block_hashes[i]
# Update and add the full block to the cache.
block_hash_with_group_id = BlockHashWithGroupId(
@@ -184,9 +137,15 @@ class BlockPool:
blk.block_id] = blk
if new_hashes is not None:
new_hashes.append(block_hash.hash_value)
-prev_block_hash_value = block_hash.hash_value
if self.enable_kv_cache_events:
+if num_cached_blocks == 0:
+parent_block_hash = None
+else:
+parent_block = blocks[num_cached_blocks - 1]
+assert parent_block.block_hash is not None
+parent_block_hash = parent_block.block_hash.get_hash_value()
self.kv_event_queue.append(
BlockStored(
block_hashes=new_hashes,

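With the hashes now supplied by the request, `cache_full_blocks` reduces to slicing `request.block_hashes` for the newly filled blocks; the fallback path that hashed tokens inline is gone. A minimal sketch of the remaining indexing invariant, using plain ints instead of vLLM's `BlockHash` type (names here are illustrative, not the project's API):

```python
def newly_cached_hashes(block_hashes: list[int], num_cached_blocks: int,
                        num_full_blocks: int) -> list[int]:
    # The request must already hold a hash for every full block; the
    # pool no longer computes missing hashes on the fly.
    assert len(block_hashes) >= num_full_blocks
    return block_hashes[num_cached_blocks:num_full_blocks]


# Blocks 0-1 were cached on an earlier call; blocks 2-3 just became full.
assert newly_cached_hashes([11, 22, 33, 44], 2, 4) == [33, 44]
```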
View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
-from typing import Callable, Optional
+from typing import Optional
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
@@ -23,7 +23,6 @@ class KVCacheCoordinator(ABC):
max_model_len: int,
use_eagle: bool,
enable_caching: bool,
-caching_hash_fn: Callable,
enable_kv_cache_events: bool,
):
self.kv_cache_config = kv_cache_config
@@ -40,7 +39,6 @@ class KVCacheCoordinator(ABC):
kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
-caching_hash_fn=caching_hash_fn,
) for i, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups))
@@ -99,19 +97,17 @@ class KVCacheCoordinator(ABC):
manager.allocate_new_blocks(request_id, num_tokens)
for manager in self.single_type_managers)
-def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
-num_computed_tokens: int) -> None:
+def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
-block_hashes: The block hashes of the request.
num_computed_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
"""
for manager in self.single_type_managers:
-manager.cache_blocks(request, block_hashes, num_computed_tokens)
+manager.cache_blocks(request, num_computed_tokens)
def free(self, request_id: str) -> None:
"""
@@ -184,10 +180,9 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
"""
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
-use_eagle: bool, caching_hash_fn: Callable,
-enable_kv_cache_events: bool):
+use_eagle: bool, enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle, False,
-caching_hash_fn, enable_kv_cache_events)
+enable_kv_cache_events)
self.num_single_type_manager = len(self.single_type_managers)
def get_num_common_prefix_blocks(self, request_id: str,
@@ -213,10 +208,9 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
-caching_hash_fn: Callable, enable_kv_cache_events: bool):
+enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
-enable_caching, caching_hash_fn,
-enable_kv_cache_events)
+enable_caching, enable_kv_cache_events)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
@@ -250,10 +244,9 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
-caching_hash_fn: Callable, enable_kv_cache_events: bool):
+enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
-enable_caching, caching_hash_fn,
-enable_kv_cache_events)
+enable_caching, enable_kv_cache_events)
self.verify_and_split_kv_cache_groups()
def verify_and_split_kv_cache_groups(self) -> None:
@@ -386,17 +379,15 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def get_kv_cache_coordinator(
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
-enable_caching: bool, caching_hash_fn: Callable,
+enable_caching: bool,
enable_kv_cache_events: bool) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
-use_eagle, caching_hash_fn,
+use_eagle,
enable_kv_cache_events)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
use_eagle, enable_caching,
-caching_hash_fn,
enable_kv_cache_events)
return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
-enable_caching, caching_hash_fn,
-enable_kv_cache_events)
+enable_caching, enable_kv_cache_events)

View File

@@ -1,16 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections import defaultdict
from dataclasses import dataclass
from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
-from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
-from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
-hash_request_tokens, init_none_hash)
+from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus
@@ -71,23 +68,13 @@ class KVCacheManager:
kv_cache_config: KVCacheConfig,
max_model_len: int,
enable_caching: bool = True,
-caching_hash_algo: str = "builtin",
use_eagle: bool = False,
log_stats: bool = False,
enable_kv_cache_events: bool = False,
) -> None:
self.max_model_len = max_model_len
if len(kv_cache_config.kv_cache_groups) == 0:
# Attention-free models don't have a KV cache,
# thus don't need prefix caching.
enable_caching = False
self.enable_caching = enable_caching
-self.caching_hash_fn = (
-sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else
-sha256 if caching_hash_algo == "sha256" else hash)
-init_none_hash(self.caching_hash_fn)
self.use_eagle = use_eagle
self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats
@@ -107,19 +94,12 @@ class KVCacheManager:
max_model_len=self.max_model_len,
use_eagle=self.use_eagle,
enable_caching=self.enable_caching,
-caching_hash_fn=self.caching_hash_fn,
enable_kv_cache_events=enable_kv_cache_events,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool
self.kv_cache_config = kv_cache_config
-# Mapping from request ID to kv block hashes.
-# This is to avoid recomputing the block hashes for each call of
-# `get_computed_blocks` or `allocate_slots`.
-self.req_to_block_hashes: defaultdict[
-str, list[BlockHash]] = defaultdict(list)
@property
def usage(self) -> float:
"""Get the KV cache usage.
@@ -161,15 +141,6 @@ class KVCacheManager:
and request.sampling_params.prompt_logprobs is not None)):
return self.create_empty_block_list(), 0
-# The block hashes for the request may already be computed
-# if the scheduler has tried to schedule the request before.
-block_hashes = self.req_to_block_hashes[request.request_id]
-if not block_hashes:
-assert self.block_size is not None
-block_hashes = hash_request_tokens(self.caching_hash_fn,
-self.block_size, request)
-self.req_to_block_hashes[request.request_id] = block_hashes
# NOTE: When all tokens hit the cache, we must recompute the last token
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
# This can trigger recomputation of an entire block, rather than just
@@ -178,7 +149,7 @@ class KVCacheManager:
# could slightly improve performance in the future.
max_cache_hit_length = request.num_tokens - 1
computed_blocks, num_new_computed_tokens = (
-self.coordinator.find_longest_cache_hit(block_hashes,
+self.coordinator.find_longest_cache_hit(request.block_hashes,
max_cache_hit_length))
if self.log_stats:
@@ -296,11 +267,7 @@ class KVCacheManager:
# at `request.num_tokens`, ensuring only "finalized" tokens are cached.
num_tokens_to_cache = min(num_computed_tokens + num_new_tokens,
request.num_tokens)
-self.coordinator.cache_blocks(
-request,
-self.req_to_block_hashes[request.request_id],
-num_tokens_to_cache,
-)
+self.coordinator.cache_blocks(request, num_tokens_to_cache)
return KVCacheBlocks(new_blocks)
@@ -373,14 +340,6 @@ class KVCacheManager:
return self.coordinator.get_num_common_prefix_blocks(
request.request_id, num_running_requests)
-def free_block_hashes(self, request: Request) -> None:
-"""Discard the block hashes for the request.
-NOTE: Unlike `free`, this method should be called only when the request
-is finished, not when it is preempted.
-"""
-self.req_to_block_hashes.pop(request.request_id, None)
def take_events(self) -> list[KVCacheEvent]:
"""Take the KV cache events from the block pool.
@@ -397,9 +356,7 @@ class KVCacheManager:
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled."""
if self.enable_caching:
-block_hashes = self.req_to_block_hashes[request.request_id]
-self.coordinator.cache_blocks(request, block_hashes,
-num_computed_tokens)
+self.coordinator.cache_blocks(request, num_computed_tokens)
def create_empty_block_list(self) -> KVCacheBlocks:
"""Creates a new KVCacheBlocks instance with no blocks."""

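The deleted lines above removed the manager's side table (`req_to_block_hashes`) together with its dedicated cleanup hook, `free_block_hashes`. A before/after sketch of the lookup path (hypothetical, heavily simplified):

```python
class OldLookupSketch:
    """Hashes in a side table: filled lazily, popped on request finish."""

    def __init__(self) -> None:
        self.req_to_block_hashes: dict[str, list[int]] = {}

    def get_hashes(self, request_id: str, compute) -> list[int]:
        if request_id not in self.req_to_block_hashes:
            self.req_to_block_hashes[request_id] = compute()
        return self.req_to_block_hashes[request_id]

    def free_block_hashes(self, request_id: str) -> None:
        # A second lifecycle method callers could forget to invoke.
        self.req_to_block_hashes.pop(request_id, None)


class NewLookupSketch:
    """Hashes live on the request: no table, no extra cleanup."""

    def get_hashes(self, request) -> list[int]:
        return request.block_hashes
```

The hashes now share the request's lifetime, which is why the scheduler hunk below also drops its `free_block_hashes` call.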
View File

@@ -547,41 +547,61 @@ def hash_block_tokens(
curr_block_token_ids_tuple, extra_keys)
-def hash_request_tokens(hash_function: Any, block_size: int,
-                        request: Request) -> list[BlockHash]:
-    """Computes hash values of a chain of blocks given a sequence of
-    token IDs. The hash value is used for prefix caching.
-
-    Args:
-        block_size: The size of each block.
-        request: The request object.
-
-    Returns:
-        The list of computed hash values.
-    """
-    token_ids = request.all_token_ids
-
-    req_need_extra_keys = need_extra_keys(request)
-    req_extra_keys = None
-    curr_mm_idx = 0
-
-    ret = []
-    parent_block_hash_value = None
-    # Only full blocks will be hashed
-    for start in range(0, len(token_ids) - block_size + 1, block_size):
-        end = start + block_size
-        block_token_ids = token_ids[start:end]
-
-        if req_need_extra_keys:
-            # MM and LoRA requests need extra keys for block-hash computation.
-            req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
-                request, start, end, curr_mm_idx)
-
-        block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
-                                       block_token_ids, req_extra_keys)
-        ret.append(block_hash)
-        parent_block_hash_value = block_hash.hash_value
-    return ret
+def get_request_block_hasher(
+    block_size: int,
+    caching_hash_fn: Callable[[Any], int]
+) -> Callable[[Request], list[BlockHash]]:
+    """
+    Returns a function which computes the list of not-yet-computed block
+    hashes of a request.
+
+    Each request holds a list of its block hashes (request.block_hashes).
+    When a request is created, it calls the function below to compute the
+    hashes of all full blocks of the request's initial tokens.
+    The hashes are then stored in request.block_hashes.
+    Later, whenever new tokens are appended to the request, it calls the
+    function again to hash any newly completed full blocks of tokens.
+    The returned new hashes are appended to request.block_hashes.
+    """
+
+    def request_block_hasher(request: Request) -> list[BlockHash]:
+        start_token_idx = len(request.block_hashes) * block_size
+        num_tokens = request.num_tokens
+
+        curr_mm_idx = 0
+        if start_token_idx > 0:
+            # Set curr_mm_idx = -1 to indicate the last mm input.
+            # Note that since we reach this branch only when the block is
+            # completed with generated tokens, we only need to consider the
+            # last mm input.
+            curr_mm_idx = -1
+
+        prev_block_hash_value = request.block_hashes[-1].hash_value \
+            if request.block_hashes else None
+        new_block_hashes: list[BlockHash] = []
+        while True:
+            end_token_idx = start_token_idx + block_size
+            if end_token_idx > num_tokens:
+                # We only hash full blocks
+                break
+
+            # MM and LoRA requests need extra keys for block-hash computation.
+            extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
+                request, start_token_idx, end_token_idx, curr_mm_idx)
+
+            # Compute the hash of the current block
+            block_tokens = request.all_token_ids[start_token_idx:end_token_idx]
+            block_hash = hash_block_tokens(caching_hash_fn,
+                                           prev_block_hash_value, block_tokens,
+                                           extra_keys)
+
+            new_block_hashes.append(block_hash)
+            start_token_idx += block_size
+            prev_block_hash_value = block_hash.hash_value
+
+        return new_block_hashes
+
+    return request_block_hasher
def max_memory_usage_bytes(vllm_config: VllmConfig,

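The closure returned by `get_request_block_hasher` is resumable: it starts at `len(request.block_hashes) * block_size` and chains each hash to its predecessor, so repeated calls hash only newly completed blocks. A runnable sketch of the same scheme with toy stand-ins for `Request` and `BlockHash` (none of these names are vLLM's API):

```python
from typing import Callable, NamedTuple


class ToyBlockHash(NamedTuple):      # stand-in for vLLM's BlockHash
    hash_value: int


class ToyRequest:                    # stand-in for vllm.v1.request.Request
    def __init__(self, token_ids: list[int]) -> None:
        self.all_token_ids = token_ids
        self.block_hashes: list[ToyBlockHash] = []


def get_block_hasher(
        block_size: int,
        hash_fn: Callable) -> Callable[[ToyRequest], list[ToyBlockHash]]:
    def block_hasher(req: ToyRequest) -> list[ToyBlockHash]:
        # Resume where the previous call stopped.
        start = len(req.block_hashes) * block_size
        prev = req.block_hashes[-1].hash_value if req.block_hashes else None
        new_hashes: list[ToyBlockHash] = []
        while start + block_size <= len(req.all_token_ids):
            block = tuple(req.all_token_ids[start:start + block_size])
            # Chaining the parent hash makes a block's hash depend on its
            # whole prefix, which is what prefix caching needs.
            new_hashes.append(ToyBlockHash(hash_fn((prev, block))))
            prev = new_hashes[-1].hash_value
            start += block_size
        return new_hashes

    return block_hasher


hasher = get_block_hasher(block_size=4, hash_fn=hash)
req = ToyRequest(list(range(10)))     # 10 tokens -> 2 full blocks
req.block_hashes = hasher(req)
assert len(req.block_hashes) == 2
req.all_token_ids += [10, 11]         # 12 tokens -> a 3rd full block
req.block_hashes.extend(hasher(req))  # hashes only the new block
assert len(req.block_hashes) == 3
```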
View File

@@ -155,7 +155,6 @@ class Scheduler(SchedulerInterface):
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
enable_caching=self.cache_config.enable_prefix_caching,
-caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
use_eagle=self.use_eagle,
log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events,
@@ -1036,7 +1035,6 @@ class Scheduler(SchedulerInterface):
def _free_blocks(self, request: Request):
assert request.is_finished()
self.kv_cache_manager.free(request)
-self.kv_cache_manager.free_block_hashes(request)
del self.requests[request.request_id]
def get_num_unfinished_requests(self) -> int:

View File

@@ -3,7 +3,6 @@
import itertools
from abc import ABC, abstractmethod
from collections import defaultdict
-from typing import Callable
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
@@ -25,7 +24,6 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
kv_cache_group_id: int,
-caching_hash_fn: Callable,
) -> None:
"""
Initializes the SingleTypeKVCacheManager.
@@ -33,7 +31,6 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
kv_cache_group_id: The id of the kv cache group of this manager.
-caching_hash_fn: The caching hash function.
"""
self.block_size = kv_cache_spec.block_size
@@ -52,7 +49,6 @@ class SingleTypeKVCacheManager(ABC):
# data for preempted ones.
self.num_cached_block: dict[str, int] = {}
-self.caching_hash_fn = caching_hash_fn
self.kv_cache_group_id = kv_cache_group_id
self._null_block = block_pool.null_block
@@ -130,14 +126,12 @@ class SingleTypeKVCacheManager(ABC):
req_blocks.extend(new_blocks)
return new_blocks
-def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
-num_tokens: int) -> None:
+def cache_blocks(self, request: Request, num_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
-block_hashes: The block hashes of the request.
num_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
"""
@@ -147,12 +141,10 @@ class SingleTypeKVCacheManager(ABC):
self.block_pool.cache_full_blocks(
request=request,
blocks=self.req_to_blocks[request.request_id],
-block_hashes=block_hashes,
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks,
block_size=self.block_size,
kv_cache_group_id=self.kv_cache_group_id,
-hash_fn=self.caching_hash_fn,
)
self.num_cached_block[request.request_id] = num_full_blocks

View File

@@ -25,9 +25,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
-from vllm.utils import (decorate_logs, make_zmq_socket,
+from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket,
resolve_obj_by_qualname, set_process_title)
-from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
+from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config,
+get_request_block_hasher,
+init_none_hash,
unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
@@ -140,6 +142,19 @@ class EngineCore:
self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size)
+self.request_block_hasher: Optional[Callable[[Request],
+list[BlockHash]]] = None
+if (self.vllm_config.cache_config.enable_prefix_caching
+or self.scheduler.get_kv_connector() is not None):
+block_size = vllm_config.cache_config.block_size
+caching_hash_fn = get_hash_fn_by_name(
+vllm_config.cache_config.prefix_caching_hash_algo)
+init_none_hash(caching_hash_fn)
+self.request_block_hasher = get_request_block_hasher(
+block_size, caching_hash_fn)
def _initialize_kv_caches(
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
start = time.time()
@@ -417,7 +432,8 @@ class EngineCore:
request.mm_kwargs = self.mm_input_cache_server.get_and_update(
request.mm_kwargs, request.mm_hashes)
-req = Request.from_engine_core_request(request)
+req = Request.from_engine_core_request(request,
+self.request_block_hasher)
if req.use_structured_output:
# Note on thread safety: no race condition.
# `grammar_init` is only invoked in input processing thread. For

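EngineCore is now the single place that decides whether hashing happens at all, and with which algorithm; requests created without a hasher simply keep `block_hashes` empty. A sketch of that guard with a stand-in for `get_hash_fn_by_name` (the real helper lives in `vllm.utils`; the dispatch below is illustrative, not its actual implementation):

```python
import hashlib
import pickle
from typing import Any, Callable, Optional


def hash_fn_by_name(name: str) -> Callable[[Any], int]:
    # Illustrative dispatch only; vLLM also supports "sha256_cbor_64bit".
    if name == "sha256":
        return lambda obj: int.from_bytes(
            hashlib.sha256(pickle.dumps(obj)).digest()[:8], "big")
    return hash  # the "builtin" algorithm


def maybe_build_block_hasher(enable_prefix_caching: bool,
                             has_kv_connector: bool, block_size: int,
                             algo: str) -> Optional[Callable]:
    # Mirrors the guard in EngineCore.__init__: hash only when prefix
    # caching or a KV connector will actually consume the hashes.
    if not (enable_prefix_caching or has_kv_connector):
        return None
    hash_fn = hash_fn_by_name(algo)

    def hasher(request) -> list[int]:
        # Trivial per-block hashing for illustration; the real closure
        # comes from get_request_block_hasher above.
        start = len(request.block_hashes) * block_size
        out = []
        while start + block_size <= len(request.all_token_ids):
            out.append(hash_fn(tuple(
                request.all_token_ids[start:start + block_size])))
            start += block_size
        return out

    return hasher
```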
View File

@@ -3,7 +3,8 @@
import enum
import time
-from typing import TYPE_CHECKING, Any, Optional, Union
+from functools import partial
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.pooling_params import PoolingParams
@@ -16,6 +17,7 @@ from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
+from vllm.v1.core.kv_cache_utils import BlockHash
class Request:
@@ -36,6 +38,8 @@ class Request:
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None,
priority: int = 0,
+block_hasher: Optional[Callable[["Request"],
+list["BlockHash"]]] = None,
) -> None:
self.request_id = request_id
self.client_index = client_index
@@ -108,8 +112,18 @@ class Request:
# indicates that the output is corrupted
self.num_nans_in_logits = 0
+self.block_hashes: list[BlockHash] = []
+self.get_hash_new_full_blocks: Optional[Callable[
+[], list[BlockHash]]] = None
+if block_hasher is not None:
+self.get_hash_new_full_blocks = partial(block_hasher, self)
+self.block_hashes = self.get_hash_new_full_blocks()
@classmethod
-def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
+def from_engine_core_request(
+cls, request: EngineCoreRequest,
+block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
+) -> "Request":
if request.mm_kwargs is not None:
assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
"mm_kwargs was not updated in EngineCore.add_request")
@@ -131,6 +145,7 @@ class Request:
if request.sampling_params else None,
cache_salt=request.cache_salt,
priority=request.priority,
+block_hasher=block_hasher,
)
def append_output_token_ids(
@@ -144,6 +159,9 @@ class Request:
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
+if self.get_hash_new_full_blocks is not None:
+self.block_hashes.extend(self.get_hash_new_full_blocks())
@property
def is_output_corrupted(self) -> bool:
return self.num_nans_in_logits > 0
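The `partial(block_hasher, self)` binding means the request can refresh its own hashes with a zero-argument call: once eagerly at construction (covering the prompt) and once after every output append. A compact, runnable sketch of that lifecycle with simplified, hypothetical types:

```python
from functools import partial
from typing import Callable, Optional


class MiniRequest:
    """Illustrative reduction of the new Request hashing lifecycle."""

    def __init__(self, prompt_token_ids: list[int],
                 block_hasher: Optional[Callable] = None) -> None:
        self.all_token_ids = list(prompt_token_ids)
        self.block_hashes: list[int] = []
        self.get_hash_new_full_blocks: Optional[Callable] = None
        if block_hasher is not None:
            # Bind the request itself so later calls need no arguments.
            self.get_hash_new_full_blocks = partial(block_hasher, self)
            # Hash all full blocks of the initial (prompt) tokens eagerly.
            self.block_hashes = self.get_hash_new_full_blocks()

    def append_output_token_ids(self, token_ids: list[int]) -> None:
        self.all_token_ids.extend(token_ids)
        # Keep block_hashes in step as decode completes new full blocks.
        if self.get_hash_new_full_blocks is not None:
            self.block_hashes.extend(self.get_hash_new_full_blocks())


def toy_hasher(req: MiniRequest) -> list[int]:
    # One hash per full 4-token block, resuming past existing hashes.
    start, out = len(req.block_hashes) * 4, []
    while start + 4 <= len(req.all_token_ids):
        out.append(hash(tuple(req.all_token_ids[start:start + 4])))
        start += 4
    return out


r = MiniRequest(list(range(6)), toy_hasher)  # 6 tokens -> 1 full block
assert len(r.block_hashes) == 1
r.append_output_token_ids([6, 7])            # 8 tokens -> 2 full blocks
assert len(r.block_hashes) == 2
```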