[v1] Move block_hashes from KVCacheManager to Request.block_hashes (#19728)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -3,7 +3,8 @@
|
||||
|
||||
import enum
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
|
||||
from vllm.pooling_params import PoolingParams
|
||||
@@ -16,6 +17,7 @@ from vllm.v1.utils import ConstantList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
|
||||
|
||||
class Request:
|
||||
@@ -36,6 +38,8 @@ class Request:
|
||||
structured_output_request: Optional["StructuredOutputRequest"] = None,
|
||||
cache_salt: Optional[str] = None,
|
||||
priority: int = 0,
|
||||
block_hasher: Optional[Callable[["Request"],
|
||||
list["BlockHash"]]] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.client_index = client_index
|
||||
@@ -108,8 +112,18 @@ class Request:
|
||||
# indicates that the output is corrupted
|
||||
self.num_nans_in_logits = 0
|
||||
|
||||
self.block_hashes: list[BlockHash] = []
|
||||
self.get_hash_new_full_blocks: Optional[Callable[
|
||||
[], list[BlockHash]]] = None
|
||||
if block_hasher is not None:
|
||||
self.get_hash_new_full_blocks = partial(block_hasher, self)
|
||||
self.block_hashes = self.get_hash_new_full_blocks()
|
||||
|
||||
@classmethod
|
||||
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
|
||||
def from_engine_core_request(
|
||||
cls, request: EngineCoreRequest,
|
||||
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
|
||||
) -> "Request":
|
||||
if request.mm_kwargs is not None:
|
||||
assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
|
||||
"mm_kwargs was not updated in EngineCore.add_request")
|
||||
@@ -131,6 +145,7 @@ class Request:
|
||||
if request.sampling_params else None,
|
||||
cache_salt=request.cache_salt,
|
||||
priority=request.priority,
|
||||
block_hasher=block_hasher,
|
||||
)
|
||||
|
||||
def append_output_token_ids(
|
||||
@@ -144,6 +159,9 @@ class Request:
|
||||
self._output_token_ids.extend(token_ids)
|
||||
self._all_token_ids.extend(token_ids)
|
||||
|
||||
if self.get_hash_new_full_blocks is not None:
|
||||
self.block_hashes.extend(self.get_hash_new_full_blocks())
|
||||
|
||||
@property
|
||||
def is_output_corrupted(self) -> bool:
|
||||
return self.num_nans_in_logits > 0
|
||||
|
||||
Reference in New Issue
Block a user