Fix various issues of async servers (#135)
@@ -7,13 +7,14 @@ import time
 from typing import AsyncGenerator, Dict, List, Optional

 import fastapi
+from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse, JSONResponse
 import uvicorn

 from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.arg_utils import AsyncServerArgs
 from cacheflow.server.async_llm_server import AsyncLLMServer
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.logger import init_logger
@@ -33,6 +34,7 @@ from cacheflow.entrypoints.openai.protocol import (
     UsageInfo,
 )
+TIMEOUT_KEEP_ALIVE = 5  # seconds

 logger = init_logger(__name__)

 served_model = None
@@ -93,7 +95,8 @@ def create_logprobs(token_ids: List[int],


 @app.post("/v1/completions")
-async def create_completion(request: CompletionRequest):
+async def create_completion(raw_request: Request):
+    request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")

     error_check_ret = await check_model(request)
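Note: the handler now takes the raw fastapi.Request and builds the CompletionRequest itself, which keeps a handle on the underlying connection so the code can later call raw_request.is_disconnected(). A minimal sketch of this pattern, assuming a hypothetical Item model (not part of this diff):

    # Sketch only: accept the raw request, validate the body by hand.
    from fastapi import FastAPI, Request
    from pydantic import BaseModel

    app = FastAPI()

    class Item(BaseModel):  # hypothetical stand-in for CompletionRequest
        name: str
        count: int = 1

    @app.post("/items")
    async def create_item(raw_request: Request):
        # Manual parsing keeps `raw_request` in scope, e.g. for
        # `await raw_request.is_disconnected()` later in the handler.
        item = Item(**await raw_request.json())
        return {"name": item.name, "count": item.count}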
@@ -139,7 +142,7 @@ async def create_completion(request: CompletionRequest):
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

     result_generator = server.generate(prompt, sampling_params,
-                                       request_id=request_id)
+                                       request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
@@ -147,6 +150,9 @@ async def create_completion(request: CompletionRequest):
               (request.best_of is None or request.n == request.best_of) and
               not request.use_beam_search)

+    async def abort_request() -> None:
+        await server.abort(request_id)
+
     def create_stream_response_json(index: int,
                                     text: str,
                                     logprobs: Optional[LogProbs] = None,
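The abort_request closure captures request_id from the enclosing handler; the streaming branch below registers it with FastAPI's BackgroundTasks so it runs once the response ends, including when the client drops mid-stream. A hedged standalone sketch of that mechanism, with cleanup standing in for server.abort:

    # Sketch only: run a cleanup coroutine after a StreamingResponse ends.
    import asyncio
    from fastapi import BackgroundTasks, FastAPI
    from fastapi.responses import StreamingResponse

    app = FastAPI()

    @app.get("/stream")
    async def stream():
        async def chunks():
            for i in range(3):
                await asyncio.sleep(0.1)
                yield f"data: chunk {i}\n\n"

        async def cleanup() -> None:
            # Stand-in for `await server.abort(request_id)`.
            print("response finished or client disconnected")

        tasks = BackgroundTasks()
        tasks.add_task(cleanup)
        return StreamingResponse(chunks(), media_type="text/event-stream",
                                 background=tasks)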
@@ -203,12 +209,21 @@ async def create_completion(request: CompletionRequest):

     # Streaming response
     if stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
         return StreamingResponse(completion_stream_generator(),
-                                 media_type="text/event-stream")
+                                 media_type="text/event-stream",
+                                 background=background_tasks)

     # Non-streaming response
     final_res: RequestOutput = None
     async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await server.abort(request_id)
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
         final_res = res
     assert final_res is not None
     choices = []
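The non-streaming path has no background hook to lean on, so it polls raw_request.is_disconnected() between results and aborts the engine-side request as soon as the client goes away. A minimal sketch of the same idea (work is a hypothetical stand-in for server.generate):

    # Sketch only: stop consuming an async generator when the client leaves.
    import asyncio
    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse

    app = FastAPI()

    async def work():  # stand-in for server.generate(...)
        for i in range(5):
            await asyncio.sleep(0.2)
            yield {"step": i}

    @app.post("/run")
    async def run(raw_request: Request):
        final = None
        async for partial in work():
            if await raw_request.is_disconnected():
                # No one is listening; free the resources early.
                return JSONResponse({"error": "Client disconnected"},
                                    status_code=400)
            final = partial
        return final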
@@ -276,7 +291,7 @@ if __name__ == "__main__":
                         help="The model name used in the API. If not specified, "
                              "the model name will be the same as the "
                              "huggingface name.")
-    parser = ServerArgs.add_cli_args(parser)
+    parser = AsyncServerArgs.add_cli_args(parser)
     args = parser.parse_args()

     app.add_middleware(
@@ -291,10 +306,11 @@ if __name__ == "__main__":

     served_model = args.served_model_name or args.model

-    server_args = ServerArgs.from_cli_args(args)
+    server_args = AsyncServerArgs.from_cli_args(args)
     server = AsyncLLMServer.from_server_args(server_args)

     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)

-    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info",
+                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
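timeout_keep_alive bounds how long uvicorn holds an idle keep-alive connection open before closing it; TIMEOUT_KEEP_ALIVE = 5 matches uvicorn's documented default, so the change mainly makes the setting explicit. An equivalent standalone invocation (the "example:app" module path is a placeholder):

    # Sketch only: the same uvicorn options in isolation.
    import uvicorn

    TIMEOUT_KEEP_ALIVE = 5  # seconds

    if __name__ == "__main__":
        uvicorn.run("example:app", host="0.0.0.0", port=8000,
                    log_level="info",
                    timeout_keep_alive=TIMEOUT_KEEP_ALIVE)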