[Experimental] Prefix Caching Support (#1669)
Co-authored-by: DouHappy <2278958187@qq.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
@@ -337,6 +337,7 @@ class LLMEngine:
|
||||
sampling_params: SamplingParams,
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
arrival_time: Optional[float] = None,
|
||||
prefix_pos: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Add a request to the engine's request pool.
|
||||
|
||||
@@ -353,6 +354,11 @@ class LLMEngine:
|
||||
use the tokenizer to convert the prompts to token IDs.
|
||||
arrival_time: The arrival time of the request. If None, we use
|
||||
the current monotonic time.
|
||||
prefix_pos: If not None, we use the given position as the prefix
|
||||
position for each prompt. We will cache the prefix's KV
|
||||
cache and reuse it for the next request with the same prefix.
|
||||
This is an experimental feature, and may be replaced with
|
||||
automatic prefix caching in the future.
|
||||
|
||||
Details:
|
||||
- Set arrival_time to the current time if it is None.
|
||||
@@ -389,9 +395,13 @@ class LLMEngine:
|
||||
seq_id = next(self.seq_counter)
|
||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
|
||||
|
||||
# Check whether the input specifies prefix
|
||||
prefix = self.scheduler.prefix_pool.add_or_get_prefix(
|
||||
prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None
|
||||
|
||||
# Create the sequence group.
|
||||
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
||||
arrival_time)
|
||||
arrival_time, prefix)
|
||||
|
||||
# Add the sequence group to the scheduler.
|
||||
self.scheduler.add_seq_group(seq_group)
|
||||
@@ -662,6 +672,12 @@ class LLMEngine:
|
||||
request_output = RequestOutput.from_seq_group(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
|
||||
# Update prefix state, now all the uncomputed prefixes are computed.
|
||||
for seq_group in scheduled_seq_groups:
|
||||
if (seq_group.prefix is not None and seq_group.prefix.allocated
|
||||
and not seq_group.prefix.computed):
|
||||
seq_group.prefix.computed = True
|
||||
|
||||
if self.log_stats:
|
||||
# Log the system stats.
|
||||
self._log_system_stats(scheduler_outputs.prompt_run,
|
||||
|
||||
Reference in New Issue
Block a user