Add Automatic Prefix Caching (#2762)

Co-authored-by: ElizaWszola <eliza@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
Sage Moore
2024-03-02 03:50:01 -05:00
committed by GitHub
parent baee28c46c
commit ce4f5a29fb
18 changed files with 615 additions and 289 deletions

View File

@@ -225,7 +225,6 @@ class _AsyncLLMEngine(LLMEngine):
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -245,7 +244,6 @@ class _AsyncLLMEngine(LLMEngine):
sampling_params=sampling_params,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)
async def _run_workers_async(
@@ -422,7 +420,6 @@ class AsyncLLMEngine:
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
@@ -435,7 +432,6 @@ class AsyncLLMEngine:
max_log_len]
logger.info(f"Received request {request_id}: "
f"prompt: {shortened_prompt!r}, "
f"prefix_pos: {prefix_pos},"
f"sampling_params: {sampling_params}, "
f"prompt_token_ids: {shortened_token_ids}, "
f"lora_request: {lora_request}.")
@@ -472,8 +468,7 @@ class AsyncLLMEngine:
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos)
lora_request=lora_request)
return stream
@@ -484,7 +479,6 @@ class AsyncLLMEngine:
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
@@ -500,11 +494,6 @@ class AsyncLLMEngine:
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
Yields:
The output `RequestOutput` objects from the LLMEngine for the
@@ -565,7 +554,6 @@ class AsyncLLMEngine:
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)
async for request_output in stream: