Prefix Cache Aware Scheduling [1/n] (#10128)
Signed-off-by: rickyx <rickyx@anyscale.com>
@@ -1,17 +1,20 @@
 import time
-from typing import List, Optional
+from collections import defaultdict
+from typing import Any, Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Tuple

 from vllm import SamplingParams
+from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.inputs import EncoderDecoderInputs, token_inputs
 from vllm.lora.request import LoRARequest
-from vllm.sequence import Logprob, Sequence, SequenceGroup
+from vllm.sequence import (Logprob, Sequence, SequenceGroup,
+                           SequenceGroupMetadata)


 def create_dummy_prompt(
     request_id: str,
-    prompt_length: int,
+    prompt_length: int = -1,
     block_size: Optional[int] = None,
     lora_request: Optional[LoRARequest] = None,
     best_of: int = 1,
@@ -26,6 +29,7 @@ def create_dummy_prompt(
     # Create dummy prompt sequence with tokens 0...block_size-1
     # and prompt "0 ... block_size".
     prompt_tokens = list(range(prompt_length))
+
     prompt_str = " ".join([str(t) for t in prompt_tokens])
     prompt = Sequence(int(request_id),
                       inputs=token_inputs(prompt_tokens, prompt=prompt_str),
@@ -42,6 +46,15 @@ def create_dummy_prompt(
     return prompt, seq_group


+def create_dummy_sequence(request_id: int, token_ids: List[int],
+                          block_size: int) -> Sequence:
+    return Sequence(
+        seq_id=request_id,
+        inputs=token_inputs(token_ids),
+        block_size=block_size,
+    )
+
+
 def create_dummy_prompt_encoder_decoder(
     request_id: str,
     decoder_prompt_length: int,
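As a side note (not part of the diff): a minimal sketch of how the new create_dummy_sequence helper might be used in a test, assuming it is imported from this utilities module and that Sequence exposes its usual get_token_ids() accessor. The request id, token ids, and block size below are illustrative values, not taken from the PR.

    # Build a bare Sequence whose prompt tokens fill exactly two blocks of size 4.
    seq = create_dummy_sequence(request_id=0, token_ids=list(range(8)), block_size=4)
    # A freshly created sequence holds only its prompt tokens.
    assert seq.get_token_ids() == list(range(8))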
@@ -194,12 +207,40 @@ def append_new_token(out, token_id: int):

 def schedule_and_update_computed_tokens(scheduler):
     metas, out, _ = scheduler.schedule()
-    for s, meta in zip(out.scheduled_seq_groups, metas):
-        s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
+    for s in out.scheduled_seq_groups:
+        s.seq_group.update_num_computed_tokens(s.token_chunk_size)
     return metas, out


+def append_new_token_seq(seq: Sequence, token_id: int):
+    seq.append_token_id(token_id, {token_id: Logprob(token_id)})
+
+
 def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
     seq_group.update_num_computed_tokens(token_chunk_size)
     for seq in seq_group.get_seqs():
         seq.append_token_id(token_id, {token_id: Logprob(token_id)})
+
+
+class SchedulerProxy:
+    """
+    A proxy class to forward calls to the scheduler.
+    """
+
+    def __init__(self, scheduler: Scheduler):
+        self.scheduler_ = scheduler
+        self.call_history: Dict[str, List[Any]] = defaultdict(list)
+
+    def __getattr__(self, name: str) -> Any:
+
+        def wrapper(*args, **kwargs):
+            result = getattr(self.scheduler_, name)(*args, **kwargs)
+            self.call_history[name].append((args, kwargs, result))
+            return result
+
+        return wrapper
+
+    def last_schedule_ret(
+        self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Any]:
+        _, _, ret = self.call_history["schedule"][-1]
+        return ret
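As a closing note (not part of the diff): a minimal sketch of how SchedulerProxy might be exercised, assuming the class added above is in scope. _StubScheduler is a hypothetical stand-in used only to show the forwarding and call-recording behavior; a real test would wrap an actual vllm.core.scheduler.Scheduler.

    class _StubScheduler:
        # Stand-in whose schedule() mimics the (metadata list, outputs, flag)
        # shape of Scheduler.schedule() with placeholder values.
        def schedule(self):
            return [], None, False

    proxy = SchedulerProxy(_StubScheduler())  # illustrative; normally a real Scheduler

    # schedule() is not defined on the proxy itself, so __getattr__ forwards the
    # call to the wrapped scheduler and records (args, kwargs, result).
    metas, out, _ = proxy.schedule()

    # last_schedule_ret() returns the result of the most recent schedule() call.
    assert proxy.last_schedule_ret() == (metas, out, False)
    assert proxy.call_history["schedule"] == [((), {}, (metas, out, False))]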