Add docstrings for LLMServer and related classes and examples (#142)

This commit is contained in:
Zhuohan Li
2023-06-07 18:25:20 +08:00
committed by GitHub
parent e38074b1e6
commit 4298374265
10 changed files with 212 additions and 18 deletions

View File

@@ -8,7 +8,7 @@ from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.ray_utils import ray, initialize_cluster
from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
from cacheflow.server.tokenizer_utils import (get_tokenizer,
detokenize_incrementally)
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -19,6 +19,33 @@ logger = init_logger(__name__)
class LLMServer:
"""An LLM server that receives requests and generates texts.
This is the main class for the CacheFlow LLM server. It receives requests
from clients and generates texts from the LLM. It includes a tokenizer, a
language model (possibly distributed across multiple GPUs), and GPU memory
space allocated for intermediate states (aka KV cache). This class utilizes
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The `LLM` class wraps this class for offline batched inference and the
`AsyncLLMServer` class wraps this class for online serving.
NOTE: The config arguments are derived from the `ServerArgs` class. For the
comprehensive list of arguments, see `ServerArgs`.
Args:
model_config: The configuration related to the LLM model.
cache_config: The configuration related to the KV cache memory
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
distributed_init_method: The initialization method for distributed
execution. See `torch.distributed.init_process_group` for details.
stage_devices: The list of devices for each stage. Each stage is a list
of (rank, node_resource, device) tuples.
log_stats: Whether to log statistics.
"""
def __init__(
self,
@@ -27,7 +54,7 @@ class LLMServer:
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
distributed_init_method: str,
stage_devices: List[List[Any]],
stage_devices: List[List[DeviceID]],
log_stats: bool,
) -> None:
logger.info(
@@ -83,6 +110,7 @@ class LLMServer:
self.cache_config.verify_with_parallel_config(self.parallel_config)
def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache."""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers(
"profile_num_available_blocks",
@@ -108,6 +136,7 @@ class LLMServer:
@classmethod
def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
"""Creates an LLM server from the server arguments."""
# Create the server configs.
server_configs = server_args.create_server_configs()
parallel_config = server_configs[2]
@@ -126,6 +155,22 @@ class LLMServer:
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
) -> None:
"""Add a request to the server's request pool.
The request is added to the request pool and will be processed by the
scheduler as `server.step()` is called. The exact scheduling policy is
determined by the scheduler.
Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters for text generation.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current time.
"""
if arrival_time is None:
arrival_time = time.time()
if prompt_token_ids is None:
@@ -148,15 +193,30 @@ class LLMServer:
self.scheduler.add_seq_group(seq_group)
def abort_request(self, request_id: str) -> None:
"""Aborts a request with the given ID.
Args:
request_id: The ID of the request to abort.
"""
self.scheduler.abort_seq_group(request_id)
def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups()
def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()
def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
This function performs one decoding iteration for the server. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copy. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
# Nothing to do.
@@ -188,7 +248,7 @@ class LLMServer:
return request_outputs
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
# Decode the sequence outputs.
"""Decodes the sequence outputs."""
for seq_group in seq_groups:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
new_token, new_output_text = detokenize_incrementally(
@@ -201,7 +261,7 @@ class LLMServer:
seq.output_text = new_output_text
def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
# Stop the sequences.
"""Stop the finished sequences."""
for seq_group in seq_groups:
sampling_params = seq_group.sampling_params
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -238,6 +298,7 @@ class LLMServer:
*args,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
all_outputs = []
for worker in self.workers:
executor = getattr(worker, method)