Refactor system architecture (#109)
This commit is contained in:
@@ -2,10 +2,10 @@ import enum
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from cacheflow.config import CacheConfig, SchedulerConfig
|
||||
from cacheflow.core.block_manager import BlockSpaceManager
|
||||
from cacheflow.core.policy import PolicyFactory
|
||||
from cacheflow.logger import init_logger
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceOutputs,
|
||||
SequenceStatus)
|
||||
@@ -28,43 +28,53 @@ class PreemptionMode(enum.Enum):
|
||||
RECOMPUTE = enum.auto()
|
||||
|
||||
|
||||
class SchedulerOutputs:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
) -> None:
|
||||
self.blocks_to_swap_in = blocks_to_swap_in
|
||||
self.blocks_to_swap_out = blocks_to_swap_out
|
||||
self.blocks_to_copy = blocks_to_copy
|
||||
# Swap in and swap out should never happen at the same time.
|
||||
assert not (blocks_to_swap_in and blocks_to_swap_out)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return (not self.blocks_to_swap_in
|
||||
and not self.blocks_to_swap_out
|
||||
and not self.blocks_to_copy)
|
||||
|
||||
|
||||
class Scheduler:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
controllers: List,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
max_num_batched_tokens: int,
|
||||
max_num_sequences: int,
|
||||
scheduler_config: SchedulerConfig,
|
||||
cache_config: CacheConfig,
|
||||
log_stats: bool,
|
||||
) -> None:
|
||||
self.controllers = controllers
|
||||
self.block_size = block_size
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.max_num_sequences = max_num_sequences
|
||||
self.scheduler_config = scheduler_config
|
||||
self.cache_config = cache_config
|
||||
self.log_stats = log_stats
|
||||
|
||||
# Instantiate the scheduling policy.
|
||||
self.policy = PolicyFactory.get_policy(policy_name='fcfs')
|
||||
# Create the block space manager.
|
||||
self.block_manager = BlockSpaceManager(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
block_size=self.cache_config.block_size,
|
||||
num_gpu_blocks=self.cache_config.num_gpu_blocks,
|
||||
num_cpu_blocks=self.cache_config.num_cpu_blocks,
|
||||
)
|
||||
|
||||
# Sequence groups in the WAITING state.
|
||||
self.waiting: List[SequenceGroup] = []
|
||||
# Sequence groups in the RUNNING state.
|
||||
self.running: List[SequenceGroup] = []
|
||||
# Mapping: group_id -> num_steps.
|
||||
self.num_steps: Dict[int, int] = {}
|
||||
# Mapping: group_id -> sampling params.
|
||||
self.sampling_params: Dict[int, SamplingParams] = {}
|
||||
# Mapping: request_id -> num_steps.
|
||||
self.num_steps: Dict[str, int] = {}
|
||||
# Sequence groups in the SWAPPED state.
|
||||
self.swapped: List[SequenceGroup] = []
|
||||
|
||||
@@ -72,18 +82,15 @@ class Scheduler:
|
||||
# List[timestamp, num_tokens]
|
||||
self.num_input_tokens: List[Tuple[float, int]] = []
|
||||
|
||||
def add_sequence_groups(
|
||||
self,
|
||||
seq_groups: List[Tuple[SequenceGroup, SamplingParams]],
|
||||
) -> None:
|
||||
def add_seq_group(self, seq_group: SequenceGroup) -> None:
|
||||
# Add sequence groups to the waiting queue.
|
||||
for seq_group, sampling_params in seq_groups:
|
||||
self.waiting.append(seq_group)
|
||||
self.sampling_params[seq_group.group_id] = sampling_params
|
||||
assert seq_group.request_id not in self.num_steps
|
||||
self.waiting.append(seq_group)
|
||||
|
||||
def _schedule(
|
||||
self,
|
||||
) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]:
|
||||
def has_unfinished_seqs(self) -> bool:
|
||||
return self.waiting or self.running or self.swapped
|
||||
|
||||
def _schedule(self) -> Tuple[SchedulerOutputs, List[int]]:
|
||||
# Blocks that need to be swapped or copied before model execution.
|
||||
blocks_to_swap_in: Dict[int, int] = {}
|
||||
blocks_to_swap_out: Dict[int, int] = {}
|
||||
@@ -136,8 +143,9 @@ class Scheduler:
|
||||
|
||||
# The total number of sequences in the RUNNING state should not
|
||||
# exceed the maximum number of sequences.
|
||||
num_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
|
||||
if len(self.running) + num_seqs > self.max_num_sequences:
|
||||
num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
|
||||
num_curr_seqs = len(self.running)
|
||||
if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
|
||||
break
|
||||
|
||||
seq_group = self.swapped.pop(0)
|
||||
@@ -151,7 +159,7 @@ class Scheduler:
|
||||
)
|
||||
|
||||
# Join waiting sequences if possible.
|
||||
prompt_group_ids: List[int] = []
|
||||
prompt_group_ids: List[str] = []
|
||||
# NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
|
||||
# prioritized over the sequence groups in the WAITING state.
|
||||
# This is because we want to bound the amount of CPU memory taken by
|
||||
@@ -172,25 +180,31 @@ class Scheduler:
|
||||
# If the number of batched tokens exceeds the limit, stop.
|
||||
num_prompt_tokens = seq_group.seqs[0].get_len()
|
||||
if (num_batched_tokens + num_prompt_tokens
|
||||
> self.max_num_batched_tokens):
|
||||
> self.scheduler_config.max_num_batched_tokens):
|
||||
break
|
||||
|
||||
# The total number of sequences in the RUNNING state should not
|
||||
# exceed the maximum number of sequences.
|
||||
num_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
|
||||
if len(self.running) + num_seqs > self.max_num_sequences:
|
||||
num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
|
||||
num_curr_seqs = len(self.running)
|
||||
if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
|
||||
break
|
||||
|
||||
seq_group = self.waiting.pop(0)
|
||||
self._allocate(seq_group)
|
||||
self.running.append(seq_group)
|
||||
num_batched_tokens += num_prompt_tokens
|
||||
prompt_group_ids.append(seq_group.group_id)
|
||||
prompt_group_ids.append(seq_group.request_id)
|
||||
|
||||
scheduler_outputs = SchedulerOutputs(
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
)
|
||||
if not self.log_stats:
|
||||
return (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
|
||||
prompt_group_ids)
|
||||
return scheduler_outputs, prompt_group_ids
|
||||
|
||||
# TODO(woosuk): Move the below code to server.
|
||||
now = time.time()
|
||||
if num_batched_tokens > 0:
|
||||
self.num_input_tokens.append((now, num_batched_tokens))
|
||||
@@ -208,13 +222,16 @@ class Scheduler:
|
||||
else:
|
||||
avg_throughput = 0.0
|
||||
|
||||
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
|
||||
num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
|
||||
num_used_gpu_blocks = self.num_gpu_blocks - num_free_gpu_blocks
|
||||
gpu_cache_usage = num_used_gpu_blocks / self.num_gpu_blocks
|
||||
if self.num_cpu_blocks > 0:
|
||||
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
|
||||
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
|
||||
|
||||
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
|
||||
if total_num_cpu_blocks > 0:
|
||||
num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
|
||||
num_used_cpu_blocks = self.num_cpu_blocks - num_free_cpu_blocks
|
||||
cpu_cache_usage = num_used_cpu_blocks / self.num_cpu_blocks
|
||||
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
|
||||
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
|
||||
else:
|
||||
cpu_cache_usage = 0.0
|
||||
|
||||
@@ -225,27 +242,18 @@ class Scheduler:
|
||||
f"Pending: {len(self.waiting)} reqs, "
|
||||
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
|
||||
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
||||
return scheduler_outputs, prompt_group_ids
|
||||
|
||||
return (blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy,
|
||||
prompt_group_ids)
|
||||
|
||||
def step(self) -> List[SequenceGroup]:
|
||||
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
|
||||
# Schedule sequence groups.
|
||||
# This function call changes the internal states of the scheduler
|
||||
# such as self.running, self.swapped, and self.waiting.
|
||||
scheduler_output = self._schedule()
|
||||
blocks_to_swap_in = scheduler_output[0]
|
||||
blocks_to_swap_out = scheduler_output[1]
|
||||
blocks_to_copy = scheduler_output[2]
|
||||
prompt_group_ids = scheduler_output[3]
|
||||
scheduler_outputs, prompt_group_ids = self._schedule()
|
||||
|
||||
# Create input data structures.
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
updated_seq_groups: List[SequenceGroup] = self.running.copy()
|
||||
|
||||
for seq_group in self.running:
|
||||
group_id = seq_group.group_id
|
||||
is_prompt = group_id in prompt_group_ids
|
||||
is_prompt = seq_group.request_id in prompt_group_ids
|
||||
|
||||
seq_data: Dict[int, List[SequenceData]] = {}
|
||||
block_tables: Dict[int, List[int]] = {}
|
||||
@@ -255,36 +263,24 @@ class Scheduler:
|
||||
block_tables[seq_id] = self.block_manager.get_block_table(seq)
|
||||
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
group_id=group_id,
|
||||
request_id=seq_group.request_id,
|
||||
is_prompt=is_prompt,
|
||||
seq_data=seq_data,
|
||||
sampling_params=self.sampling_params[group_id],
|
||||
sampling_params=seq_group.sampling_params,
|
||||
block_tables=block_tables,
|
||||
)
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
return seq_group_metadata_list, scheduler_outputs
|
||||
|
||||
# Execute the first stage of the pipeline.
|
||||
if seq_group_metadata_list or blocks_to_swap_in or blocks_to_swap_out:
|
||||
# Swap in and swap out should never happen at the same time.
|
||||
assert not (blocks_to_swap_in and blocks_to_swap_out)
|
||||
self.controllers[0].execute_stage(
|
||||
seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
)
|
||||
|
||||
return updated_seq_groups
|
||||
|
||||
def post_step(
|
||||
def update(
|
||||
self,
|
||||
seq_outputs: Dict[int, SequenceOutputs],
|
||||
) -> None:
|
||||
) -> List[SequenceGroup]:
|
||||
# Update the running sequences and free blocks.
|
||||
for seq_group in self.running:
|
||||
group_id = seq_group.group_id
|
||||
self.num_steps[group_id] += 1
|
||||
stop_token_ids = self.sampling_params[group_id].stop_token_ids
|
||||
request_id = seq_group.request_id
|
||||
self.num_steps[request_id] += 1
|
||||
stop_token_ids = seq_group.sampling_params.stop_token_ids
|
||||
|
||||
# Process beam search results before processing the next tokens.
|
||||
for seq in seq_group.seqs:
|
||||
@@ -316,12 +312,13 @@ class Scheduler:
|
||||
continue
|
||||
|
||||
# Check if the sequence has reached the maximum number of steps.
|
||||
max_num_steps = self.sampling_params[group_id].max_tokens
|
||||
if self.num_steps[group_id] == max_num_steps:
|
||||
max_num_steps = seq_group.sampling_params.max_tokens
|
||||
if self.num_steps[request_id] == max_num_steps:
|
||||
self._free_seq(seq)
|
||||
continue
|
||||
|
||||
# Update the running sequences.
|
||||
updated = self.running.copy()
|
||||
running: List[SequenceGroup] = []
|
||||
for seq_group in self.running:
|
||||
if seq_group.is_finished():
|
||||
@@ -329,13 +326,14 @@ class Scheduler:
|
||||
else:
|
||||
running.append(seq_group)
|
||||
self.running = running
|
||||
return updated
|
||||
|
||||
def _allocate(self, seq_group: SequenceGroup) -> None:
|
||||
self.block_manager.allocate(seq_group)
|
||||
for seq in seq_group.seqs:
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
if seq_group.group_id not in self.num_steps:
|
||||
self.num_steps[seq_group.group_id] = 0
|
||||
if seq_group.request_id not in self.num_steps:
|
||||
self.num_steps[seq_group.request_id] = 0
|
||||
|
||||
def _append_slot(
|
||||
self,
|
||||
@@ -410,9 +408,7 @@ class Scheduler:
|
||||
self.block_manager.free(seq)
|
||||
|
||||
def _free_seq_group(self, seq_group: SequenceGroup) -> None:
|
||||
group_id = seq_group.group_id
|
||||
del self.num_steps[group_id]
|
||||
del self.sampling_params[group_id]
|
||||
del self.num_steps[seq_group.request_id]
|
||||
|
||||
def _swap_in(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user