# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
import argparse
from http import HTTPStatus
import json
import time
from typing import AsyncGenerator, Dict, List, Optional, Union, Any
import fastapi
from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
import uvicorn
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
CompletionRequest, CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse,
ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
ModelCard, ModelList, ModelPermission, UsageInfo)
from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()
def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse:
return JSONResponse(
ErrorResponse(message=message, type="invalid_request_error").dict(),
status_code=status_code.value
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
async def check_model(request) -> Optional[JSONResponse]:
if request.model == served_model:
return
ret = create_error_response(
HTTPStatus.NOT_FOUND,
f"The model `{request.model}` does not exist.",
)
return ret
async def get_gen_prompt(request) -> str:
conv = get_conv_template(request.model)
conv = Conversation(
name=conv.name,
system=conv.system,
roles=conv.roles,
messages=list(conv.messages), # prevent in-place modification
offset=conv.offset,
sep_style=SeparatorStyle(conv.sep_style),
sep=conv.sep,
sep2=conv.sep2,
stop_str=conv.stop_str,
stop_token_ids=conv.stop_token_ids,
)
if isinstance(request.messages, str):
prompt = request.messages
else:
for message in request.messages:
msg_role = message["role"]
if msg_role == "system":
conv.system = message["content"]
elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"])
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message["content"])
else:
raise ValueError(f"Unknown role: {msg_role}")
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
return prompt
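
# Illustrative note, not part of the original file: get_gen_prompt() relies on
# FastChat's conversation templates, so `request.model` must name a registered
# template. Assuming, for example, the "vicuna_v1.1" template, the messages
#     [{"role": "user", "content": "Hello!"}]
# would be rendered roughly as
#     "<system prompt> USER: Hello! ASSISTANT:"
# with the blank assistant turn appended above left open for the model to fill.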
async def check_length(request, prompt, engine):
    hf_config = engine.engine.model_config.hf_config
    if hasattr(hf_config, "max_sequence_length"):
        context_len = hf_config.max_sequence_length
    elif hasattr(hf_config, "seq_length"):
        context_len = hf_config.seq_length
    elif hasattr(hf_config, "max_position_embeddings"):
        context_len = hf_config.max_position_embeddings
    else:
        context_len = 2048
input_ids = tokenizer(prompt).input_ids
token_num = len(input_ids)
if token_num + request.max_tokens > context_len:
return create_error_response(
HTTPStatus.BAD_REQUEST,
f"This model's maximum context length is {context_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.",
)
else:
return None
@app.get("/v1/models")
async def show_available_models():
"""Show available models. Right now we only have one model."""
model_cards = [ModelCard(id=served_model, root=served_model,
permission=[ModelPermission()])]
return ModelList(data=model_cards)
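
# For illustration only (not in the original file): GET /v1/models returns a
# single ModelCard for the served model, serializing to something like
#     {"object": "list",
#      "data": [{"id": "<served model>", "object": "model",
#                "root": "<served model>", "permission": [...]}]}
# with the exact fields determined by the ModelCard/ModelList protocol classes.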
def create_logprobs(token_ids: List[int],
id_logprobs: List[Dict[int, float]],
initial_text_offset: int = 0) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
for token_id, id_logprob in zip(token_ids, id_logprobs):
token = tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(id_logprob[token_id])
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
last_token_len = len(token)
logprobs.top_logprobs.append(
{tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()})
return logprobs
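
# For illustration only (not in the original file; tokens and values are made
# up and depend on the tokenizer): for two sampled token ids, the LogProbs
# object built above serializes to an OpenAI-style payload along the lines of
#     {"text_offset": [0, 4],
#      "token_logprobs": [-0.31, -1.27],
#      "tokens": ["The", " dog"],
#      "top_logprobs": [{"The": -0.31, "A": -1.90},
#                       {" dog": -1.27, " cat": -2.41}]}
# where text_offset tracks each token's starting character position in the
# generated text, beginning at initial_text_offset.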
@app.post("/v1/chat/completions")
async def create_chat_completion(raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
NOTE: Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
request = ChatCompletionRequest(**await raw_request.json())
logger.info(f"Received chat completion request: {request}")
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.logit_bias is not None:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
prompt = await get_gen_prompt(request)
error_check_ret = await check_length(request, prompt, engine)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
try:
sampling_params = SamplingParams(
n=request.n,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
temperature=request.temperature,
top_p=request.top_p,
stop=request.stop,
max_tokens=request.max_tokens,
best_of=request.best_of,
top_k=request.top_k,
ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params,
request_id)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json(index: int,
text: str,
finish_reason: Optional[str] = None) -> str:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=text),
finish_reason=finish_reason,
)
response = ChatCompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.json(ensure_ascii=False)
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
# First chunk with role
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=request_id, choices=[choice_data], model=model_name
)
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
response_json = create_stream_response_json(
index=i,
text=delta_text,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
response_json = create_stream_response_json(
index=i,
text="",
finish_reason=output.finish_reason,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if request.stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream",
background=background_tasks)
# Non-streaming response
    final_res: Optional[RequestOutput] = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
for output in final_res.outputs:
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role="assistant", content=output.text),
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids)
for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
        # When the user requests streaming but we don't stream, we still need
        # to return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
return response
@app.post("/v1/completions")
async def create_completion(raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following features:
2023-06-17 03:07:40 -07:00
- echo (since the vLLM engine does not currently support
getting the logprobs of prompt tokens)
- suffix (the language models we currently support do not support
suffix)
2023-06-17 03:07:40 -07:00
- logit_bias (to be supported by vLLM engine)
"""
request = CompletionRequest(**await raw_request.json())
logger.info(f"Received completion request: {request}")
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.echo:
# We do not support echo since the vLLM engine does not
# currently support getting the logprobs of prompt tokens.
return create_error_response(HTTPStatus.BAD_REQUEST,
"echo is not currently supported")
if request.suffix is not None:
# The language models we currently support do not support suffix.
return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported")
if request.logit_bias is not None:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
if isinstance(request.prompt, list):
assert len(request.prompt) == 1
prompt = request.prompt[0]
else:
prompt = request.prompt
created_time = int(time.time())
try:
sampling_params = SamplingParams(
n=request.n,
best_of=request.best_of,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens,
logprobs=request.logprobs,
use_beam_search=request.use_beam_search,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params,
request_id)
    # Similar to the OpenAI API, when n != best_of, we do not stream the
    # results. In addition, we do not stream the results when using beam search.
stream = (request.stream and
(request.best_of is None or request.n == request.best_of) and
not request.use_beam_search)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json(index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
logprobs=logprobs,
finish_reason=finish_reason,
)
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.json(ensure_ascii=False)
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
if request.logprobs is not None:
logprobs = create_logprobs(
output.token_ids[previous_num_tokens[i]:],
output.logprobs[previous_num_tokens[i]:],
len(previous_texts[i]))
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
response_json = create_stream_response_json(
index=i,
text=delta_text,
logprobs=logprobs,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = LogProbs() if request.logprobs is not None else None
response_json = create_stream_response_json(
index=i,
text="",
logprobs=logprobs,
finish_reason=output.finish_reason,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream",
background=background_tasks)
# Non-streaming response
    final_res: Optional[RequestOutput] = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await abort_request()
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
for output in final_res.outputs:
if request.logprobs is not None:
logprobs = create_logprobs(output.token_ids, output.logprobs)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=output.index,
text=output.text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids)
for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
        # When the user requests streaming but we don't stream, we still need
        # to return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
return response
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server."
)
parser.add_argument("--host", type=str, default="localhost", help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument(
"--allow-credentials", action="store_true", help="allow credentials"
)
parser.add_argument(
"--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
)
parser.add_argument(
"--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
)
parser.add_argument(
"--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
)
parser.add_argument("--served-model-name", type=str, default=None,
help="The model name used in the API. If not specified, "
"the model name will be the same as the "
"huggingface name.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)
logger.info(f"args: {args}")
served_model = args.served_model_name or args.model
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(engine_args.tokenizer, engine_args.tokenizer_mode)
uvicorn.run(app, host=args.host, port=args.port, log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
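
# Example usage (a sketch, not part of the original file; the model name is a
# placeholder for any model vLLM can load):
#
#     python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m
#
#     curl http://localhost:8000/v1/completions \
#         -H "Content-Type: application/json" \
#         -d '{"model": "facebook/opt-125m", "prompt": "Hello", "max_tokens": 16}'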