[Experimental] Prefix Caching Support (#1669)

Co-authored-by: DouHappy <2278958187@qq.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Author: shiyi.c_98
Date: 2024-01-17 16:32:10 -08:00
Committed by: GitHub
Parent: 14cc317ba4
Commit: d10f8e1d43
20 changed files with 1356 additions and 71 deletions


@@ -371,6 +371,7 @@ class AsyncLLMEngine:
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
+prefix_pos: Optional[int] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
@@ -383,6 +384,7 @@ class AsyncLLMEngine:
max_log_len]
logger.info(f"Received request {request_id}: "
f"prompt: {shortened_prompt!r}, "
f"prefix_pos: {prefix_pos},"
f"sampling params: {sampling_params}, "
f"prompt token ids: {shortened_token_ids}.")
@@ -401,7 +403,8 @@ class AsyncLLMEngine:
prompt=prompt,
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
-arrival_time=arrival_time)
+arrival_time=arrival_time,
+prefix_pos=prefix_pos)
return stream
@@ -410,7 +413,8 @@ class AsyncLLMEngine:
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
-prompt_token_ids: Optional[List[int]] = None
+prompt_token_ids: Optional[List[int]] = None,
+prefix_pos: Optional[int] = None,
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
@@ -425,6 +429,11 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
+prefix_pos: If not None, we use the given position as the prefix
+position for each prompt. We will cache the prefix's KV
+cache and reuse it for the next request with the same prefix.
+This is an experimental feature, and may be replaced with
+automatic prefix caching in the future.
Yields:
The output `RequestOutput` objects from the LLMEngine for the
@@ -482,7 +491,8 @@ class AsyncLLMEngine:
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
-arrival_time=arrival_time)
+arrival_time=arrival_time,
+prefix_pos=prefix_pos)
async for request_output in stream:
yield request_output
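
The hunks above only thread the new `prefix_pos` argument through `AsyncLLMEngine`. As a rough usage sketch (not part of this commit), a caller could pass `prefix_pos` to `generate()` so that requests sharing the same prompt prefix reuse its KV cache. The model name, the prefix text, and the placeholder token count `PREFIX_LEN` below are illustrative assumptions; in practice the prefix length would be computed with the model's tokenizer.

```python
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Shared instruction prefix reused across requests.
PREFIX = ("You are an expert school principal. Draft a concise answer to the "
          "question below.\n\nQuestion: ")
# Number of tokens in PREFIX, computed once with the model's tokenizer
# (tokenization step omitted here; 32 is only a placeholder value).
PREFIX_LEN = 32


async def answer(engine: AsyncLLMEngine, question: str,
                 request_id: str) -> str:
    params = SamplingParams(temperature=0.0, max_tokens=64)
    final_output = None
    # prefix_pos marks the first PREFIX_LEN prompt tokens as a reusable
    # prefix: the first request fills the prefix's KV cache, and later
    # requests with the same prefix skip recomputing it.
    async for output in engine.generate(PREFIX + question,
                                        params,
                                        request_id,
                                        prefix_pos=PREFIX_LEN):
        final_output = output
    return final_output.outputs[0].text


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    answers = await asyncio.gather(
        answer(engine, "How should teachers handle disruptive students?", "0"),
        answer(engine, "What makes a good parent-teacher conference?", "1"))
    print(answers)


if __name__ == "__main__":
    asyncio.run(main())
```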