Fix various issues of async servers (#135)

Zhuohan Li
2023-06-05 23:44:50 +08:00
committed by GitHub
parent 8274ca23ac
commit 1a956e136b
11 changed files with 289 additions and 121 deletions


@@ -7,13 +7,14 @@ import time
 from typing import AsyncGenerator, Dict, List, Optional
 
 import fastapi
+from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn
 
 from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.logger import init_logger
@@ -33,6 +34,7 @@ from cacheflow.entrypoints.openai.protocol import (
     UsageInfo,
 )
 
+TIMEOUT_KEEP_ALIVE = 5  # seconds
 
 logger = init_logger(__name__)
 served_model = None
@@ -93,7 +95,8 @@ def create_logprobs(token_ids: List[int],
 
 
 @app.post("/v1/completions")
-async def create_completion(request: CompletionRequest):
+async def create_completion(raw_request: Request):
+    request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")
 
     error_check_ret = await check_model(request)
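
The endpoint now takes the raw fastapi.Request and validates the JSON body by hand, which keeps a handle on the underlying connection so the handler can later call raw_request.is_disconnected(). A minimal sketch of the same pattern outside this codebase, with a hypothetical EchoRequest model and /echo route standing in for CompletionRequest and /v1/completions:

from fastapi import FastAPI, Request
from pydantic import BaseModel

app = FastAPI()


class EchoRequest(BaseModel):
    # Hypothetical stand-in for CompletionRequest.
    prompt: str
    n: int = 1


@app.post("/echo")
async def echo(raw_request: Request):
    # Parse and validate the body manually; keeping the raw Request around
    # is what later makes raw_request.is_disconnected() checks possible.
    request = EchoRequest(**await raw_request.json())
    return {"prompt": request.prompt, "n": request.n}

One trade-off of this style is that FastAPI no longer validates the body automatically: malformed input surfaces as a pydantic ValidationError inside the handler rather than the usual 422 response.
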
@@ -139,7 +142,7 @@ async def create_completion(request: CompletionRequest):
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
     result_generator = server.generate(prompt, sampling_params,
-                                       request_id=request_id)
+                                       request_id)
 
     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
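
The generate call now passes request_id positionally; its result is consumed further down as an async generator (async for res in result_generator), yielding progressively more complete outputs. A rough, self-contained illustration of that producer/consumer shape (fake_generate and its token list are invented, not the real AsyncLLMServer.generate):

import asyncio
from typing import AsyncGenerator


async def fake_generate(prompt: str,
                        request_id: str) -> AsyncGenerator[str, None]:
    # Hypothetical stand-in for server.generate(): yields progressively
    # longer partial completions instead of returning one final value.
    text = prompt
    for token in ["Hello", ",", " world", "!"]:
        await asyncio.sleep(0)  # hand control back to the event loop
        text += token
        yield text


async def main() -> None:
    async for partial in fake_generate("Say hi: ", request_id="cmpl-0"):
        print(partial)


asyncio.run(main())
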
@@ -147,6 +150,9 @@ async def create_completion(request: CompletionRequest):
               (request.best_of is None or request.n == request.best_of) and
               not request.use_beam_search)
 
+    async def abort_request() -> None:
+        await server.abort(request_id)
+
     def create_stream_response_json(index: int,
                                     text: str,
                                     logprobs: Optional[LogProbs] = None,
@@ -203,12 +209,21 @@ async def create_completion(request: CompletionRequest):
     # Streaming response
     if stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream")
+                                 media_type="text/event-stream",
+                                 background=background_tasks)
 
     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await server.abort(request_id)
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
         final_res = res
 
     assert final_res is not None
     choices = []
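
Both cleanup paths added here tie the request's lifetime inside the engine to the lifetime of the HTTP connection: for streaming responses the abort_request closure runs as a FastAPI background task once the StreamingResponse is finished, and for non-streaming responses the handler polls raw_request.is_disconnected() between partial results. A condensed, runnable sketch of both paths built on toy stand-ins (slow_stream, the active_requests set, and the /generate route are illustrative, not the server's real API):

import asyncio
from http import HTTPStatus
from typing import AsyncGenerator

from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

app = FastAPI()
active_requests: set[str] = set()  # toy stand-in for the engine's bookkeeping


async def slow_stream(request_id: str) -> AsyncGenerator[str, None]:
    # Hypothetical generator standing in for server.generate().
    for i in range(100):
        if request_id not in active_requests:
            return  # the request was aborted, stop producing output
        await asyncio.sleep(0.1)
        yield f"chunk {i}\n"


async def abort(request_id: str) -> None:
    # Toy abort; the real server cancels scheduling for this request id.
    active_requests.discard(request_id)


@app.post("/generate")
async def generate(raw_request: Request, stream: bool = False):
    request_id = "req-0"
    active_requests.add(request_id)

    async def abort_request() -> None:
        await abort(request_id)

    if stream:
        # The background task runs after the streaming response is done,
        # so the request is cleaned up even if the client went away.
        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return StreamingResponse(slow_stream(request_id),
                                 media_type="text/event-stream",
                                 background=background_tasks)

    # Non-streaming: poll for a disconnect between partial results.
    last = None
    async for chunk in slow_stream(request_id):
        if await raw_request.is_disconnected():
            await abort(request_id)
            return JSONResponse({"error": "Client disconnected"},
                                status_code=HTTPStatus.BAD_REQUEST)
        last = chunk
    await abort(request_id)  # normal completion: clean up the toy bookkeeping
    return {"text": last}

Served with uvicorn, interrupting a client that is streaming the response still triggers the abort through the background task, while the non-streaming branch answers with a 400 "Client disconnected" error, mirroring the behaviour this diff adds.
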
@@ -276,7 +291,7 @@ if __name__ == "__main__":
                         help="The model name used in the API. If not specified, "
                         "the model name will be the same as the "
                         "huggingface name.")
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()
 
     app.add_middleware(
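
The entrypoint now builds its parser from AsyncServerArgs, which exposes the same add_cli_args / from_cli_args pair as ServerArgs: register the engine's flags on an existing argparse parser, then rebuild the args object from the parsed namespace. A generic sketch of that pattern with an invented MyArgs dataclass (the real class lives in cacheflow/server/arg_utils.py and defines different fields):

import argparse
from dataclasses import dataclass, fields


@dataclass
class MyArgs:
    # Invented fields for illustration only.
    model: str = "facebook/opt-125m"
    use_ray: bool = False

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        parser.add_argument("--model", type=str, default=MyArgs.model)
        parser.add_argument("--use-ray", action="store_true")
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "MyArgs":
        # Keep only the attributes this dataclass declares, so flags added
        # by the entrypoint itself (e.g. --host, --port) are ignored here.
        names = [f.name for f in fields(cls)]
        return cls(**{name: getattr(args, name) for name in names})


parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser = MyArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "my-model", "--use-ray"])
print(MyArgs.from_cli_args(args))  # MyArgs(model='my-model', use_ray=True)
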
@@ -291,10 +306,11 @@ if __name__ == "__main__":
 
     served_model = args.served_model_name or args.model
 
-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)
 
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)
 
-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
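
uvicorn.run also gains timeout_keep_alive=TIMEOUT_KEEP_ALIVE, so idle keep-alive connections are closed after five seconds instead of relying on an implicit default. For reference, the same setting through uvicorn's programmatic Config/Server API (host and port below are placeholders):

import uvicorn
from fastapi import FastAPI

app = FastAPI()

if __name__ == "__main__":
    # Equivalent to uvicorn.run(app, ..., timeout_keep_alive=5): connections
    # idle on keep-alive for more than 5 seconds are closed by the server.
    config = uvicorn.Config(app, host="0.0.0.0", port=8000,
                            log_level="info", timeout_keep_alive=5)
    uvicorn.Server(config).run()
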