Non-streaming simple FastAPI server (#144)

This commit is contained in:
Zhuohan Li
2023-06-11 01:43:07 +08:00
committed by GitHub
parent 4298374265
commit 5020e1e80c
3 changed files with 61 additions and 20 deletions

View File

@@ -233,7 +233,7 @@ async def create_completion(raw_request: Request):
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await server.abort(request_id)
await abort_request()
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res

View File

@@ -3,7 +3,7 @@ import json
from typing import AsyncGenerator
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse
from fastapi.responses import Response, StreamingResponse
import uvicorn
from cacheflow.sampling_params import SamplingParams
@@ -17,19 +17,22 @@ app = FastAPI()
@app.post("/generate")
async def generate_stream(request: Request) -> StreamingResponse:
async def generate(request: Request) -> Response:
""" Stream the results of the generation request.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
results_generator = server.generate(prompt, sampling_params, request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
@@ -37,19 +40,35 @@ async def generate_stream(request: Request) -> StreamingResponse:
prompt + output.text
for output in request_output.outputs
]
ret = {
"text": text_outputs,
"error": 0,
}
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")
async def abort_request() -> None:
await server.abort(request_id)
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(stream_results(), background=background_tasks)
if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(stream_results(), background=background_tasks)
# Non-streaming case
final_output = None
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await server.abort(request_id)
return Response(status_code=499)
final_output = request_output
assert final_output is not None
prompt = final_output.prompt
text_outputs = [
prompt + output.text
for output in final_output.outputs
]
ret = {"text": text_outputs}
return Response(content=json.dumps(ret))
if __name__ == "__main__":