Add docstrings for LLMServer and related classes and examples (#142)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from cacheflow.logger import init_logger
|
||||
from cacheflow.outputs import RequestOutput
|
||||
@@ -15,7 +15,25 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
|
||||
|
||||
|
||||
class AsyncLLMServer:
|
||||
"""An asynchronous wrapper for LLMServer.
|
||||
|
||||
This class is used to wrap the LLMServer class to make it asynchronous. It
|
||||
uses asyncio to create a background loop that keeps processing incoming
|
||||
requests. The LLMServer is kicked by the generate method when there
|
||||
are requests in the waiting queue. The generate method yields the outputs
|
||||
from the LLMServer to the caller.
|
||||
|
||||
NOTE: For the comprehensive list of arguments, see `LLMServer`.
|
||||
|
||||
Args:
|
||||
worker_use_ray: Whether to use Ray for model workers. Required for
|
||||
distributed execution. Should be the same as
|
||||
`parallel_config.worker_use_ray`.
|
||||
server_use_ray: Whether to make LLMServer a Ray actor. If so, the
|
||||
async frontend will be executed in a separate process from the
|
||||
model workers.
|
||||
*args, **kwargs: Arguments for LLMServer.
|
||||
"""
|
||||
def __init__(self, worker_use_ray: bool, server_use_ray: bool,
|
||||
*args, **kwargs) -> None:
|
||||
self.worker_use_ray = worker_use_ray
|
||||
@@ -35,6 +53,7 @@ class AsyncLLMServer:
|
||||
self.kicking_request_id: Optional[str] = None
|
||||
|
||||
async def server_step(self, kicking_request_id: Optional[str] = None):
|
||||
"""Kick the server to process the waiting requests."""
|
||||
self.is_server_running = True
|
||||
self.kicking_request_id = kicking_request_id
|
||||
if self.server_use_ray:
|
||||
@@ -54,8 +73,31 @@ class AsyncLLMServer:
|
||||
self.request_outputs[request_id] = request_output
|
||||
self.request_events[request_id].set()
|
||||
|
||||
async def generate(self, prompt: str, sampling_params: SamplingParams,
|
||||
request_id: str) -> RequestOutput:
|
||||
async def generate(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[List[int]] = None
|
||||
) -> RequestOutput:
|
||||
"""Generate outputs for a request.
|
||||
|
||||
Generate outputs for a request. This method is a coroutine. It adds the
|
||||
request into the waiting queue of the LLMServer and streams the outputs
|
||||
from the LLMServer to the caller.
|
||||
|
||||
Args:
|
||||
prompt: The prompt string. Can be None if prompt_token_ids is
|
||||
provided.
|
||||
sampling_params: The sampling parameters of the request.
|
||||
request_id: The unique id of the request.
|
||||
prompt_token_ids: The token IDs of the prompt. If None, we
|
||||
use the tokenizer to convert the prompts to token IDs.
|
||||
|
||||
Yields:
|
||||
The output `RequestOutput` objects from the LLMServer for the
|
||||
request.
|
||||
"""
|
||||
# Preprocess the request.
|
||||
arrival_time = time.time()
|
||||
|
||||
@@ -66,20 +108,29 @@ class AsyncLLMServer:
|
||||
|
||||
logger.info(f"Received request {request_id}: "
|
||||
f"prompt: {prompt!r}, "
|
||||
f"sampling params: {sampling_params}.")
|
||||
f"sampling params: {sampling_params}, "
|
||||
f"prompt token ids: {prompt_token_ids}.")
|
||||
|
||||
# Add the request into the cacheflow server's waiting queue.
|
||||
if self.server_use_ray:
|
||||
await self.server.add_request.remote(
|
||||
request_id, prompt, sampling_params, arrival_time=arrival_time)
|
||||
request_id, prompt, sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
else:
|
||||
self.server.add_request(
|
||||
request_id, prompt, sampling_params, arrival_time=arrival_time)
|
||||
request_id, prompt, sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
|
||||
# The cacheflow server does not have a background loop that keeps
|
||||
# processing incoming requests. Therefore, we need to keep kicking
|
||||
# the server to process the requests.
|
||||
while True:
|
||||
if request_id not in self.request_events:
|
||||
# The request has been aborted.
|
||||
return
|
||||
|
||||
# Kick the server if the server is not running.
|
||||
if not self.is_server_running:
|
||||
await self.server_step(request_id)
|
||||
@@ -113,6 +164,14 @@ class AsyncLLMServer:
|
||||
break
|
||||
|
||||
async def abort(self, request_id: str) -> None:
|
||||
"""Abort a request.
|
||||
|
||||
Abort a submitted request. If the request is finished or not found,
|
||||
this method will be a no-op.
|
||||
|
||||
Args:
|
||||
request_id: The unique id of the request.
|
||||
"""
|
||||
if request_id not in self.request_events:
|
||||
# The request has already finished or been aborted.
|
||||
return
|
||||
@@ -137,6 +196,7 @@ class AsyncLLMServer:
|
||||
|
||||
@classmethod
|
||||
def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
|
||||
"""Creates an async LLM server from the server arguments."""
|
||||
# Create the server configs.
|
||||
server_configs = server_args.create_server_configs()
|
||||
parallel_config = server_configs[2]
|
||||
|
||||
Reference in New Issue
Block a user