[Experimental] Prefix Caching Support (#1669)

Co-authored-by: DouHappy <2278958187@qq.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
shiyi.c_98
2024-01-17 16:32:10 -08:00
committed by GitHub
parent 14cc317ba4
commit d10f8e1d43
20 changed files with 1356 additions and 71 deletions

View File

@@ -337,6 +337,7 @@ class LLMEngine:
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
prefix_pos: Optional[int] = None,
) -> None:
"""Add a request to the engine's request pool.
@@ -353,6 +354,11 @@ class LLMEngine:
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
prefix_pos: If not None, the first `prefix_pos` tokens of the prompt
(i.e. prompt_token_ids[:prefix_pos]) are treated as the prefix.
The KV cache of the prefix is stored and reused by later requests
that share the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
Details:
- Set arrival_time to the current time if it is None.
@@ -389,9 +395,13 @@ class LLMEngine:
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
# Check whether the input specifies a prefix.
prefix = self.scheduler.prefix_pool.add_or_get_prefix(
prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time)
arrival_time, prefix)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
@@ -662,6 +672,12 @@ class LLMEngine:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
# Update prefix state: every prefix that was allocated but not yet computed is now marked as computed.
for seq_group in scheduled_seq_groups:
if (seq_group.prefix is not None and seq_group.prefix.allocated
and not seq_group.prefix.computed):
seq_group.prefix.computed = True
if self.log_stats:
# Log the system stats.
self._log_system_stats(scheduler_outputs.prompt_run,