Add script for benchmarking serving throughput (#145)

commit 311490a720
parent da5ddcd544
Woosuk Kwon authored on 2023-06-14 19:55:38 -07:00, committed by GitHub

10 changed files with 421 additions and 415 deletions

@@ -32,12 +32,14 @@ class AsyncLLMServer:
         server_use_ray: Whether to make LLMServer a Ray actor. If so, the
             async frontend will be executed in a separate process as the
             model workers.
+        log_requests: Whether to log the requests.
         *args, *kwargs: Arguments for LLMServer.
     """
     def __init__(self, worker_use_ray: bool, server_use_ray: bool,
-                 *args, **kwargs) -> None:
+                 log_requests: bool = True, *args, **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.server_use_ray = server_use_ray
+        self.log_requests = log_requests
         if not self.server_use_ray:
             server_class = LLMServer
         elif self.worker_use_ray:
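A note on the new signature: log_requests takes a default value ahead of *args, so any caller that was already forwarding extra positional arguments to LLMServer must be updated to pass log_requests third, as from_server_args is below, or its first forwarded argument gets silently bound to log_requests. A minimal, self-contained illustration of the pitfall (the init function and the "cfg" value are hypothetical, not cacheflow code):

    def init(worker_use_ray: bool, server_use_ray: bool,
             log_requests: bool = True, *args):
        return log_requests, args

    # Updated caller: log_requests is supplied third, ahead of the
    # forwarded arguments, matching the from_server_args change below.
    assert init(False, False, True, "cfg") == (True, ("cfg",))
    # Stale caller: the first forwarded argument is swallowed into
    # log_requests instead of reaching *args.
    assert init(False, False, "cfg") == ("cfg", ())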
@@ -106,10 +108,11 @@ class AsyncLLMServer:
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event

-        logger.info(f"Received request {request_id}: "
-                    f"prompt: {prompt!r}, "
-                    f"sampling params: {sampling_params}, "
-                    f"prompt token ids: {prompt_token_ids}.")
+        if self.log_requests:
+            logger.info(f"Received request {request_id}: "
+                        f"prompt: {prompt!r}, "
+                        f"sampling params: {sampling_params}, "
+                        f"prompt token ids: {prompt_token_ids}.")

         # Add the request into the cacheflow server's waiting queue.
         if self.server_use_ray:
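Every per-request log site is gated the same way. The explicit flag check is cheaper than relying on logger levels here, because f-strings are formatted eagerly at the call site before logging ever sees them; skipping the call avoids that work on the hot path of a throughput benchmark. A runnable toy sketch of the pattern (DemoServer and the "demo" logger name are stand-ins, not cacheflow code):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("demo")

    class DemoServer:
        """Toy stand-in for AsyncLLMServer's log_requests gating."""
        def __init__(self, log_requests: bool = True) -> None:
            self.log_requests = log_requests

        def add_request(self, request_id: str, prompt: str) -> None:
            # Guarding the call skips the f-string formatting entirely.
            if self.log_requests:
                logger.info(f"Received request {request_id}: "
                            f"prompt: {prompt!r}.")

    DemoServer(log_requests=False).add_request("req-0", "hi")  # silent
    DemoServer().add_request("req-1", "hi")                    # logs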
@@ -152,7 +155,8 @@ class AsyncLLMServer:
             # Once finished, release the resources of the sequence group.
             if request_output.finished():
-                logger.info(f"Finished request {request_id}.")
+                if self.log_requests:
+                    logger.info(f"Finished request {request_id}.")
                 del self.request_outputs[request_id]
                 del self.request_events[request_id]
@@ -176,7 +180,8 @@ class AsyncLLMServer:
             # The request has already finished or been aborted.
             return
-        logger.info(f"Aborted request {request_id}.")
+        if self.log_requests:
+            logger.info(f"Aborted request {request_id}.")
         if self.server_use_ray:
             await self.server.abort_request.remote(request_id)
@@ -206,6 +211,7 @@ class AsyncLLMServer:
         # Create the LLM server.
         server = cls(server_args.worker_use_ray,
                      server_args.server_use_ray,
+                     not server_args.disable_log_requests,
                      *server_configs,
                      distributed_init_method, devices,
                      log_stats=not server_args.disable_log_stats)
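The constructor's new third positional argument is derived from a disable_log_requests server argument. Its definition is not in these hunks; presumably it is a store-true argparse option mirroring the existing --disable-log-stats flag. A hedged reconstruction of that wiring (the parser setup here is illustrative, not the actual server-args code):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--disable-log-requests", action="store_true",
                        help="Disable logging of individual requests.")
    server_args = parser.parse_args(["--disable-log-requests"])

    # Mirrors the call site above: the CLI flag is negated into the
    # constructor's log_requests parameter.
    log_requests = not server_args.disable_log_requests
    assert log_requests is False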