Add docstrings for LLMServer and related classes and examples (#142)
@@ -8,7 +8,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import ray, initialize_cluster
+from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
 from cacheflow.server.tokenizer_utils import (get_tokenizer,
                                               detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -19,6 +19,33 @@ logger = init_logger(__name__)
 
 
 class LLMServer:
+    """An LLM server that receives requests and generates texts.
+
+    This is the main class for the CacheFlow LLM server. It receives requests
+    from clients and generates texts from the LLM. It includes a tokenizer, a
+    language model (possibly distributed across multiple GPUs), and GPU memory
+    space allocated for intermediate states (aka KV cache). This class utilizes
+    iteration-level scheduling and efficient memory management to maximize the
+    serving throughput.
+
+    The `LLM` class wraps this class for offline batched inference and the
+    `AsyncLLMServer` class wraps this class for online serving.
+
+    NOTE: The config arguments are derived from the `ServerArgs` class. For the
+    comprehensive list of arguments, see `ServerArgs`.
+
+    Args:
+        model_config: The configuration related to the LLM model.
+        cache_config: The configuration related to the KV cache memory
+            management.
+        parallel_config: The configuration related to distributed execution.
+        scheduler_config: The configuration related to the request scheduler.
+        distributed_init_method: The initialization method for distributed
+            execution. See `torch.distributed.init_process_group` for details.
+        stage_devices: The list of devices for each stage. Each stage is a list
+            of (rank, node_resource, device) tuples.
+        log_stats: Whether to log statistics.
+    """
 
     def __init__(
         self,
@@ -27,7 +54,7 @@ class LLMServer:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
-        stage_devices: List[List[Any]],
+        stage_devices: List[List[DeviceID]],
         log_stats: bool,
     ) -> None:
         logger.info(
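
Example (illustrative, not part of this diff): the four config objects above are normally produced by `ServerArgs.create_server_configs()` rather than built by hand. A minimal sketch of the manual wiring, where the `model` field name on `ServerArgs` and the return shape of `initialize_cluster` are assumptions for illustration:

    from cacheflow.server.arg_utils import ServerArgs
    from cacheflow.server.ray_utils import initialize_cluster

    # Build the four configs from CLI-style arguments (field name assumed).
    server_args = ServerArgs(model="facebook/opt-125m")
    model_config, cache_config, parallel_config, scheduler_config = (
        server_args.create_server_configs())
    # Set up (possibly distributed) execution; return shape assumed here.
    distributed_init_method, stage_devices = initialize_cluster(parallel_config)
    server = LLMServer(model_config, cache_config, parallel_config,
                       scheduler_config, distributed_init_method,
                       stage_devices, log_stats=True)

The `from_server_args` classmethod shown later in this diff is the supported shortcut for this wiring.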
@@ -83,6 +110,7 @@ class LLMServer:
         self.cache_config.verify_with_parallel_config(self.parallel_config)
 
     def _init_cache(self) -> None:
+        """Profiles the memory usage and initializes the KV cache."""
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers(
             "profile_num_available_blocks",
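
The profiling step above sizes the KV cache in fixed-size blocks. A back-of-the-envelope sketch of the arithmetic, where every number is an illustrative assumption rather than a value taken from cacheflow:

    block_size = 16                   # tokens per KV block (assumed)
    num_layers, num_heads, head_dim = 32, 32, 128
    dtype_bytes = 2                   # fp16
    # Bytes per block: 2 (K and V) * layers * tokens * hidden size * dtype.
    block_bytes = 2 * num_layers * block_size * num_heads * head_dim * dtype_bytes
    free_gpu_bytes = 70 * 1024**3     # memory left after weights/activations
    num_gpu_blocks = free_gpu_bytes // block_bytes
    print(num_gpu_blocks)             # upper bound on resident KV blocks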
@@ -108,6 +136,7 @@ class LLMServer:
 
     @classmethod
     def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+        """Creates an LLM server from the server arguments."""
         # Create the server configs.
         server_configs = server_args.create_server_configs()
         parallel_config = server_configs[2]
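
Hypothetical usage of this classmethod (the `model` field name on `ServerArgs` is an assumption for illustration):

    from cacheflow.server.arg_utils import ServerArgs

    server = LLMServer.from_server_args(ServerArgs(model="facebook/opt-125m"))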
@@ -126,6 +155,22 @@ class LLMServer:
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
     ) -> None:
+        """Add a request to the server's request pool.
+
+        The request is added to the request pool and will be processed by the
+        scheduler as `server.step()` is called. The exact scheduling policy is
+        determined by the scheduler.
+
+        Args:
+            request_id: The unique ID of the request.
+            prompt: The prompt string. Can be None if prompt_token_ids is
+                provided.
+            sampling_params: The sampling parameters for text generation.
+            prompt_token_ids: The token IDs of the prompt. If None, we
+                use the tokenizer to convert the prompt to token IDs.
+            arrival_time: The arrival time of the request. If None, we use
+                the current time.
+        """
         if arrival_time is None:
             arrival_time = time.time()
         if prompt_token_ids is None:
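
Sketch of enqueueing a request against this API, given a `server` constructed as above; the `SamplingParams` constructor arguments are assumptions for illustration:

    from cacheflow.sampling_params import SamplingParams

    params = SamplingParams(temperature=0.8, top_p=0.95)  # args assumed
    server.add_request(request_id="req-0",
                       prompt="Hello, my name is",
                       sampling_params=params)
    # Or pass pre-tokenized input and skip the tokenizer:
    server.add_request(request_id="req-1", prompt=None,
                       sampling_params=params,
                       prompt_token_ids=[101, 7592, 102])  # IDs illustrative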
@@ -148,15 +193,30 @@ class LLMServer:
         self.scheduler.add_seq_group(seq_group)
 
     def abort_request(self, request_id: str) -> None:
+        """Aborts a request with the given ID.
+
+        Args:
+            request_id: The ID of the request to abort.
+        """
         self.scheduler.abort_seq_group(request_id)
 
     def get_num_unfinished_requests(self) -> int:
+        """Gets the number of unfinished requests."""
         return self.scheduler.get_num_unfinished_seq_groups()
 
     def has_unfinished_requests(self) -> bool:
+        """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()
 
     def step(self) -> List[RequestOutput]:
+        """Performs one decoding iteration and returns newly generated results.
+
+        This function performs one decoding iteration for the server. It first
+        schedules the sequences to be executed in the next iteration and the
+        token blocks to be swapped in/out/copied. Then, it executes the model
+        and updates the scheduler with the model outputs. Finally, it decodes
+        the sequences and returns the newly generated results.
+        """
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
         if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
             # Nothing to do.
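
Together, `add_request`, `step`, and `has_unfinished_requests` imply the driver loop below (the printable contents of `RequestOutput` are left unspecified here):

    # Drain the request pool: each step() call performs one decoding
    # iteration and returns outputs for sequences that made progress.
    while server.has_unfinished_requests():
        request_outputs = server.step()
        for output in request_outputs:
            print(output)  # inspect newly generated results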
@@ -188,7 +248,7 @@ class LLMServer:
         return request_outputs
 
     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Decode the sequence outputs.
+        """Decodes the sequence outputs."""
         for seq_group in seq_groups:
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 new_token, new_output_text = detokenize_incrementally(
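
The call to `detokenize_incrementally` exists because re-decoding the whole token list on every step costs O(n) work per new token. A hypothetical contrast sketch (helper names made up; not cacheflow's implementation, which also handles multi-token characters):

    def decode_full(tokenizer, token_ids):
        # O(len(token_ids)) work on every generation step.
        return tokenizer.decode(token_ids)

    def decode_incremental(tokenizer, output_text, token_ids):
        # Decode only a short tail and append the delta: ~O(1) per step.
        tail = tokenizer.decode(token_ids[-2:])
        prev = tokenizer.decode(token_ids[-2:-1])
        return output_text + tail[len(prev):]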
@@ -201,7 +261,7 @@ class LLMServer:
                 seq.output_text = new_output_text
 
     def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Stop the sequences.
+        """Stops the finished sequences."""
         for seq_group in seq_groups:
             sampling_params = seq_group.sampling_params
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
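
Typical stopping criteria behind a method like `_stop_sequences` (a simplified, assumed sketch; the real method reads these limits from `sampling_params`):

    def should_stop(output_text, num_output_tokens, last_token_id,
                    stop_strings, max_tokens, eos_token_id):
        # Stop on a matching stop string, the token budget, or EOS.
        if any(output_text.endswith(s) for s in stop_strings):
            return True
        if num_output_tokens >= max_tokens:
            return True
        return last_token_id == eos_token_id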
@@ -238,6 +298,7 @@ class LLMServer:
         *args,
         **kwargs,
     ) -> Any:
+        """Runs the given method on all workers."""
         all_outputs = []
         for worker in self.workers:
             executor = getattr(worker, method)
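
The hunk above shows only the start of the fan-out helper; a simplified sketch of the whole pattern follows (the real `_run_workers` additionally handles Ray remote workers, which is why the diff changed the `ray_utils` import):

    def run_workers(workers, method, *args, **kwargs):
        # Invoke the named method on every worker and collect the results.
        all_outputs = []
        for worker in workers:
            executor = getattr(worker, method)
            all_outputs.append(executor(*args, **kwargs))
        # Workers are expected to agree; return the first output.
        return all_outputs[0]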