[Model] Allow users to control skip reading cache per request. (#28194)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-11-16 16:04:50 +08:00
committed by GitHub
parent d231876ce3
commit a55b64635c
5 changed files with 67 additions and 8 deletions

View File

@@ -57,6 +57,7 @@ class PoolingParams(
## Internal use only
task: PoolingTask | None = None
requires_token_ids: bool = False
skip_reading_prefix_cache: bool = None
extra_kwargs: dict[str, Any] | None = None
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
@@ -93,6 +94,8 @@ class PoolingParams(
# plugin task uses io_processor.parse_request to verify inputs,
# skipping PoolingParams verify
if self.task == "plugin":
if self.skip_reading_prefix_cache is None:
self.skip_reading_prefix_cache = True
return
# NOTE: Task validation needs to done against the model instance,
@@ -122,6 +125,15 @@ class PoolingParams(
if getattr(self, k, None) is None:
setattr(self, k, getattr(pooler_config, k))
if self.skip_reading_prefix_cache is None:
# If prefix caching is enabled,
# the output of all pooling may less than n_prompt_tokens,
# we need to skip reading cache at this request.
if self.task in ["token_embed", "token_classify"]:
self.skip_reading_prefix_cache = True
else:
self.skip_reading_prefix_cache = False
self._verify_step_pooling(pooler_config, valid_parameters)
def _verify_step_pooling(

View File

@@ -254,6 +254,8 @@ class SamplingParams(
generated token can complete the sequence."""
_bad_words_token_ids: list[list[int]] | None = None
skip_reading_prefix_cache: bool = None
@staticmethod
def from_optional(
n: int | None = 1,
@@ -414,6 +416,12 @@ class SamplingParams(
self.structured_outputs = self.guided_decoding
self.guided_decoding = None
if self.skip_reading_prefix_cache is None:
# If prefix caching is enabled,
# the output of prompt logprobs may less than n_prompt_tokens,
# we need to skip reading cache at this request.
self.skip_reading_prefix_cache = self.prompt_logprobs is not None
def _verify_args(self) -> None:
if not isinstance(self.n, int):
raise ValueError(f"n must be an int, but is of type {type(self.n)}")

View File

@@ -185,12 +185,11 @@ class KVCacheManager:
- A list of blocks that are computed for the request.
- The number of computed tokens.
"""
# Prefix caching is disabled or
# When the request requires prompt logprobs, we skip prefix caching.
if not self.enable_caching or (
request.sampling_params is not None
and request.sampling_params.prompt_logprobs is not None
):
# We skip finding the prefix cache hit when prefix caching is
# disabled or the request is marked as skipping kv cache read
# (which happens when the request requires prompt logprobs
# or calls a pooling model with all pooling).
if not self.enable_caching or request.skip_reading_prefix_cache:
return self.empty_kv_cache_blocks, 0
# NOTE: When all tokens hit the cache, we must recompute the last token

View File

@@ -127,6 +127,8 @@ class Request:
self.get_hash_new_full_blocks = partial(block_hasher, self)
self.block_hashes = self.get_hash_new_full_blocks()
self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
@classmethod
def from_engine_core_request(
cls,
@@ -180,6 +182,19 @@ class Request:
def num_output_tokens(self) -> int:
return len(self._output_token_ids)
def get_skip_reading_prefix_cache(self) -> bool:
if (
self.sampling_params is not None
and self.sampling_params.skip_reading_prefix_cache is not None
):
return self.sampling_params.skip_reading_prefix_cache
elif (
self.pooling_params is not None
and self.pooling_params.skip_reading_prefix_cache is not None
):
return self.pooling_params.skip_reading_prefix_cache
return False
def is_finished(self) -> bool:
return RequestStatus.is_finished(self.status)