FastAPI-based working frontend (#10)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from cacheflow.master.block_manager import BlockSpaceManager
|
||||
from cacheflow.master.frontend import Frontend
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import Sequence
|
||||
from cacheflow.sequence import SequenceGroup
|
||||
@@ -14,14 +13,12 @@ class Scheduler:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frontend: Frontend,
|
||||
controllers: List,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
max_num_batched_tokens: int,
|
||||
) -> None:
|
||||
self.frontend = frontend
|
||||
self.controllers = controllers
|
||||
self.block_size = block_size
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
@@ -47,9 +44,12 @@ class Scheduler:
|
||||
# Pending sequence groups (FIFO).
|
||||
self.pending: List[SequenceGroup] = []
|
||||
|
||||
def _fetch_inputs(self) -> None:
|
||||
inputs = self.frontend.get_inputs()
|
||||
for seq_group, sampling_params in inputs:
|
||||
def add_sequence_groups(
|
||||
self,
|
||||
sequence_groups: List[Tuple[SequenceGroup, SamplingParams]],
|
||||
) -> None:
|
||||
# Add sequence groups to the pending queue.
|
||||
for seq_group, sampling_params in sequence_groups:
|
||||
self.pending.append(seq_group)
|
||||
self.sampling_params[seq_group.group_id] = sampling_params
|
||||
|
||||
@@ -104,7 +104,7 @@ class Scheduler:
|
||||
seq.status = SequenceStatus.SWAPPED
|
||||
self.swapped.append(seq_group)
|
||||
|
||||
def step(self) -> None:
|
||||
def step(self) -> List[SequenceGroup]:
|
||||
# Blocks that need to be swaped or copied before model execution.
|
||||
blocks_to_swap_in: Dict[int, int] = {}
|
||||
blocks_to_swap_out: Dict[int, int] = {}
|
||||
@@ -158,7 +158,6 @@ class Scheduler:
|
||||
# 3. Join new sequences if possible.
|
||||
# NOTE: Here we implicitly assume FCFS scheduling.
|
||||
# TODO(woosuk): Add a batching policy to control the batch size.
|
||||
self._fetch_inputs()
|
||||
if not self.swapped:
|
||||
for i, seq_group in enumerate(self.pending):
|
||||
num_prompt_tokens = seq_group.seqs[0].get_len()
|
||||
@@ -176,6 +175,8 @@ class Scheduler:
|
||||
|
||||
# 4. Create input data structures.
|
||||
input_seq_groups: List[SequenceGroupInputs] = []
|
||||
updated_seq_groups: List[SequenceGroup] = self.running.copy()
|
||||
|
||||
for seq_group in self.running:
|
||||
group_id = seq_group.group_id
|
||||
num_steps = self.num_steps[group_id]
|
||||
@@ -219,6 +220,8 @@ class Scheduler:
|
||||
blocks_to_copy,
|
||||
)
|
||||
|
||||
return updated_seq_groups
|
||||
|
||||
def post_step(
|
||||
self,
|
||||
seq_outputs: Dict[int, SequenceOutputs],
|
||||
@@ -268,13 +271,12 @@ class Scheduler:
|
||||
running: List[SequenceGroup] = []
|
||||
for seq_group in self.running:
|
||||
if seq_group.is_finished():
|
||||
self._return(seq_group)
|
||||
self._free_seq_group(seq_group)
|
||||
else:
|
||||
running.append(seq_group)
|
||||
self.running = running
|
||||
|
||||
def _return(self, seq_group: SequenceGroup) -> None:
|
||||
def _free_seq_group(self, seq_group: SequenceGroup) -> None:
|
||||
group_id = seq_group.group_id
|
||||
del self.num_steps[group_id]
|
||||
del self.sampling_params[group_id]
|
||||
self.frontend.print_response(seq_group)
|
||||
|
||||
Reference in New Issue
Block a user