Prefix Cache Aware Scheduling [1/n] (#10128)

Signed-off-by: rickyx <rickyx@anyscale.com>
This commit is contained in:
Ricky Xu
2024-11-22 21:15:55 -08:00
committed by GitHub
parent 7c25fe45a6
commit 4634a89d18
13 changed files with 962 additions and 236 deletions

View File

@@ -1,17 +1,20 @@
import time
from typing import List, Optional
from collections import defaultdict
from typing import Any, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple
from vllm import SamplingParams
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.inputs import EncoderDecoderInputs, token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, Sequence, SequenceGroup
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupMetadata)
def create_dummy_prompt(
request_id: str,
prompt_length: int,
prompt_length: int = -1,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
best_of: int = 1,
@@ -26,6 +29,7 @@ def create_dummy_prompt(
# Create dummy prompt sequence with tokens 0...block_size-1
# and prompt "0 ... block_size".
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id),
inputs=token_inputs(prompt_tokens, prompt=prompt_str),
@@ -42,6 +46,15 @@ def create_dummy_prompt(
return prompt, seq_group
def create_dummy_sequence(request_id: int, token_ids: List[int],
block_size: int) -> Sequence:
return Sequence(
seq_id=request_id,
inputs=token_inputs(token_ids),
block_size=block_size,
)
def create_dummy_prompt_encoder_decoder(
request_id: str,
decoder_prompt_length: int,
@@ -194,12 +207,40 @@ def append_new_token(out, token_id: int):
def schedule_and_update_computed_tokens(scheduler):
metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
for s in out.scheduled_seq_groups:
s.seq_group.update_num_computed_tokens(s.token_chunk_size)
return metas, out
def append_new_token_seq(seq: Sequence, token_id: int):
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
seq_group.update_num_computed_tokens(token_chunk_size)
for seq in seq_group.get_seqs():
seq.append_token_id(token_id, {token_id: Logprob(token_id)})
class SchedulerProxy:
"""
A proxy class to forward calls to the scheduler.
"""
def __init__(self, scheduler: Scheduler):
self.scheduler_ = scheduler
self.call_history: Dict[str, List[Any]] = defaultdict(list)
def __getattr__(self, name: str) -> Any:
def wrapper(*args, **kwargs):
result = getattr(self.scheduler_, name)(*args, **kwargs)
self.call_history[name].append((args, kwargs, result))
return result
return wrapper
def last_schedule_ret(
self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Any]:
_, _, ret = self.call_history["schedule"][-1]
return ret