Add script for benchmarking serving throughput (#145)
@@ -32,12 +32,14 @@ class AsyncLLMServer:
         server_use_ray: Whether to make LLMServer a Ray actor. If so, the
             async frontend will be executed in a separate process as the
             model workers.
+        log_requests: Whether to log the requests.
         *args, *kwargs: Arguments for LLMServer.
     """
     def __init__(self, worker_use_ray: bool, server_use_ray: bool,
-                 *args, **kwargs) -> None:
+                 log_requests: bool = True, *args, **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.server_use_ray = server_use_ray
+        self.log_requests = log_requests
         if not self.server_use_ray:
             server_class = LLMServer
         elif self.worker_use_ray:
@@ -106,10 +108,11 @@ class AsyncLLMServer:
         request_event = asyncio.Event()
         self.request_events[request_id] = request_event

-        logger.info(f"Received request {request_id}: "
-                    f"prompt: {prompt!r}, "
-                    f"sampling params: {sampling_params}, "
-                    f"prompt token ids: {prompt_token_ids}.")
+        if self.log_requests:
+            logger.info(f"Received request {request_id}: "
+                        f"prompt: {prompt!r}, "
+                        f"sampling params: {sampling_params}, "
+                        f"prompt token ids: {prompt_token_ids}.")

         # Add the request into the cacheflow server's waiting queue.
         if self.server_use_ray:
@@ -152,7 +155,8 @@ class AsyncLLMServer:

             # Once finished, release the resources of the sequence group.
             if request_output.finished():
-                logger.info(f"Finished request {request_id}.")
+                if self.log_requests:
+                    logger.info(f"Finished request {request_id}.")

                 del self.request_outputs[request_id]
                 del self.request_events[request_id]
@@ -176,7 +180,8 @@ class AsyncLLMServer:
             # The request has already finished or been aborted.
             return

-        logger.info(f"Aborted request {request_id}.")
+        if self.log_requests:
+            logger.info(f"Aborted request {request_id}.")

         if self.server_use_ray:
             await self.server.abort_request.remote(request_id)
@@ -206,6 +211,7 @@ class AsyncLLMServer:
         # Create the LLM server.
         server = cls(server_args.worker_use_ray,
                      server_args.server_use_ray,
+                     not server_args.disable_log_requests,
                      *server_configs,
                      distributed_init_method, devices,
                      log_stats=not server_args.disable_log_stats)
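
Below is a minimal sketch, not part of this commit's diff, of how a `disable_log_requests` server argument is typically exposed on the command line so that `not server_args.disable_log_requests` can be forwarded to the constructor as `log_requests`. The exact flag names and argument-parsing code in the repository may differ; only argparse itself is assumed here.

# Hedged sketch: wiring a --disable-log-requests flag to the new log_requests
# keyword. Flag names are assumptions based on `server_args.disable_log_requests`
# in the diff above, not copied from the repository.
import argparse

parser = argparse.ArgumentParser(description="serving benchmark args (sketch)")
parser.add_argument("--disable-log-requests", action="store_true",
                    help="Disable per-request logging in AsyncLLMServer.")
parser.add_argument("--disable-log-stats", action="store_true",
                    help="Disable periodic stats logging.")
args = parser.parse_args([])  # empty list: use defaults for this sketch

log_requests = not args.disable_log_requests  # True unless the flag is given
print(log_requests)  # -> True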