FastAPI-based working frontend (#10)

This commit is contained in:
Zhuohan Li
2023-03-29 14:48:56 +08:00
committed by GitHub
parent d359cda5fa
commit 721fa3df15
15 changed files with 536 additions and 146 deletions

View File

@@ -1,7 +1,6 @@
from typing import Dict, List
from typing import Dict, List, Tuple
from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
@@ -14,14 +13,12 @@ class Scheduler:
def __init__(
self,
frontend: Frontend,
controllers: List,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
max_num_batched_tokens: int,
) -> None:
self.frontend = frontend
self.controllers = controllers
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
@@ -47,9 +44,12 @@ class Scheduler:
# Pending sequence groups (FIFO).
self.pending: List[SequenceGroup] = []
def _fetch_inputs(self) -> None:
inputs = self.frontend.get_inputs()
for seq_group, sampling_params in inputs:
def add_sequence_groups(
self,
sequence_groups: List[Tuple[SequenceGroup, SamplingParams]],
) -> None:
# Add sequence groups to the pending queue.
for seq_group, sampling_params in sequence_groups:
self.pending.append(seq_group)
self.sampling_params[seq_group.group_id] = sampling_params
@@ -104,7 +104,7 @@ class Scheduler:
seq.status = SequenceStatus.SWAPPED
self.swapped.append(seq_group)
def step(self) -> None:
def step(self) -> List[SequenceGroup]:
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
@@ -158,7 +158,6 @@ class Scheduler:
# 3. Join new sequences if possible.
# NOTE: Here we implicitly assume FCFS scheduling.
# TODO(woosuk): Add a batching policy to control the batch size.
self._fetch_inputs()
if not self.swapped:
for i, seq_group in enumerate(self.pending):
num_prompt_tokens = seq_group.seqs[0].get_len()
@@ -176,6 +175,8 @@ class Scheduler:
# 4. Create input data structures.
input_seq_groups: List[SequenceGroupInputs] = []
updated_seq_groups: List[SequenceGroup] = self.running.copy()
for seq_group in self.running:
group_id = seq_group.group_id
num_steps = self.num_steps[group_id]
@@ -219,6 +220,8 @@ class Scheduler:
blocks_to_copy,
)
return updated_seq_groups
def post_step(
self,
seq_outputs: Dict[int, SequenceOutputs],
@@ -268,13 +271,12 @@ class Scheduler:
running: List[SequenceGroup] = []
for seq_group in self.running:
if seq_group.is_finished():
self._return(seq_group)
self._free_seq_group(seq_group)
else:
running.append(seq_group)
self.running = running
def _return(self, seq_group: SequenceGroup) -> None:
def _free_seq_group(self, seq_group: SequenceGroup) -> None:
group_id = seq_group.group_id
del self.num_steps[group_id]
del self.sampling_params[group_id]
self.frontend.print_response(seq_group)