[v1] Move block_hashes from KVCacheManager to Request.block_hashes (#19728)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2025-08-16 02:52:52 +03:00
committed by GitHub
parent b9dc9d2607
commit c280066f9d
19 changed files with 381 additions and 335 deletions

View File

@@ -3,7 +3,8 @@
import enum
import time
from typing import TYPE_CHECKING, Any, Optional, Union
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.pooling_params import PoolingParams
@@ -16,6 +17,7 @@ from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
from vllm.v1.core.kv_cache_utils import BlockHash
class Request:
@@ -36,6 +38,8 @@ class Request:
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None,
priority: int = 0,
block_hasher: Optional[Callable[["Request"],
list["BlockHash"]]] = None,
) -> None:
self.request_id = request_id
self.client_index = client_index
@@ -108,8 +112,18 @@ class Request:
# indicates that the output is corrupted
self.num_nans_in_logits = 0
self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Optional[Callable[
[], list[BlockHash]]] = None
if block_hasher is not None:
self.get_hash_new_full_blocks = partial(block_hasher, self)
self.block_hashes = self.get_hash_new_full_blocks()
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
def from_engine_core_request(
cls, request: EngineCoreRequest,
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
) -> "Request":
if request.mm_kwargs is not None:
assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
"mm_kwargs was not updated in EngineCore.add_request")
@@ -131,6 +145,7 @@ class Request:
if request.sampling_params else None,
cache_salt=request.cache_salt,
priority=request.priority,
block_hasher=block_hasher,
)
def append_output_token_ids(
@@ -144,6 +159,9 @@ class Request:
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
if self.get_hash_new_full_blocks is not None:
self.block_hashes.extend(self.get_hash_new_full_blocks())
@property
def is_output_corrupted(self) -> bool:
    """Tell whether this request's output counts as corrupted.

    Returns:
        True when ``num_nans_in_logits`` is greater than zero,
        False otherwise.
    """
    nan_count = self.num_nans_in_logits
    return nan_count > 0