[Model] Allow users to control skip reading cache per request. (#28194)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io> Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -57,6 +57,7 @@ class PoolingParams(
|
||||
## Internal use only
|
||||
task: PoolingTask | None = None
|
||||
requires_token_ids: bool = False
|
||||
skip_reading_prefix_cache: bool = None
|
||||
extra_kwargs: dict[str, Any] | None = None
|
||||
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
|
||||
|
||||
@@ -93,6 +94,8 @@ class PoolingParams(
|
||||
# plugin task uses io_processor.parse_request to verify inputs,
|
||||
# skipping PoolingParams verify
|
||||
if self.task == "plugin":
|
||||
if self.skip_reading_prefix_cache is None:
|
||||
self.skip_reading_prefix_cache = True
|
||||
return
|
||||
|
||||
# NOTE: Task validation needs to done against the model instance,
|
||||
@@ -122,6 +125,15 @@ class PoolingParams(
|
||||
if getattr(self, k, None) is None:
|
||||
setattr(self, k, getattr(pooler_config, k))
|
||||
|
||||
if self.skip_reading_prefix_cache is None:
|
||||
# If prefix caching is enabled,
|
||||
# the output of all pooling may less than n_prompt_tokens,
|
||||
# we need to skip reading cache at this request.
|
||||
if self.task in ["token_embed", "token_classify"]:
|
||||
self.skip_reading_prefix_cache = True
|
||||
else:
|
||||
self.skip_reading_prefix_cache = False
|
||||
|
||||
self._verify_step_pooling(pooler_config, valid_parameters)
|
||||
|
||||
def _verify_step_pooling(
|
||||
|
||||
@@ -254,6 +254,8 @@ class SamplingParams(
|
||||
generated token can complete the sequence."""
|
||||
_bad_words_token_ids: list[list[int]] | None = None
|
||||
|
||||
skip_reading_prefix_cache: bool = None
|
||||
|
||||
@staticmethod
|
||||
def from_optional(
|
||||
n: int | None = 1,
|
||||
@@ -414,6 +416,12 @@ class SamplingParams(
|
||||
self.structured_outputs = self.guided_decoding
|
||||
self.guided_decoding = None
|
||||
|
||||
if self.skip_reading_prefix_cache is None:
|
||||
# If prefix caching is enabled,
|
||||
# the output of prompt logprobs may less than n_prompt_tokens,
|
||||
# we need to skip reading cache at this request.
|
||||
self.skip_reading_prefix_cache = self.prompt_logprobs is not None
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
if not isinstance(self.n, int):
|
||||
raise ValueError(f"n must be an int, but is of type {type(self.n)}")
|
||||
|
||||
@@ -185,12 +185,11 @@ class KVCacheManager:
|
||||
- A list of blocks that are computed for the request.
|
||||
- The number of computed tokens.
|
||||
"""
|
||||
# Prefix caching is disabled or
|
||||
# When the request requires prompt logprobs, we skip prefix caching.
|
||||
if not self.enable_caching or (
|
||||
request.sampling_params is not None
|
||||
and request.sampling_params.prompt_logprobs is not None
|
||||
):
|
||||
# We skip finding the prefix cache hit when prefix caching is
|
||||
# disabled or the request is marked as skipping kv cache read
|
||||
# (which happens when the request requires prompt logprobs
|
||||
# or calls a pooling model with all pooling).
|
||||
if not self.enable_caching or request.skip_reading_prefix_cache:
|
||||
return self.empty_kv_cache_blocks, 0
|
||||
|
||||
# NOTE: When all tokens hit the cache, we must recompute the last token
|
||||
|
||||
@@ -127,6 +127,8 @@ class Request:
|
||||
self.get_hash_new_full_blocks = partial(block_hasher, self)
|
||||
self.block_hashes = self.get_hash_new_full_blocks()
|
||||
|
||||
self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
|
||||
|
||||
@classmethod
|
||||
def from_engine_core_request(
|
||||
cls,
|
||||
@@ -180,6 +182,19 @@ class Request:
|
||||
def num_output_tokens(self) -> int:
|
||||
return len(self._output_token_ids)
|
||||
|
||||
def get_skip_reading_prefix_cache(self) -> bool:
|
||||
if (
|
||||
self.sampling_params is not None
|
||||
and self.sampling_params.skip_reading_prefix_cache is not None
|
||||
):
|
||||
return self.sampling_params.skip_reading_prefix_cache
|
||||
elif (
|
||||
self.pooling_params is not None
|
||||
and self.pooling_params.skip_reading_prefix_cache is not None
|
||||
):
|
||||
return self.pooling_params.skip_reading_prefix_cache
|
||||
return False
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return RequestStatus.is_finished(self.status)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user