Support chat template and echo for chat API (#1756)

This commit is contained in:
Adam Brusselback
2023-11-30 19:43:13 -05:00
committed by GitHub
parent 05a38612b0
commit 66785cc05c
7 changed files with 440 additions and 181 deletions

View File

@@ -3,6 +3,7 @@
import argparse
import asyncio
import codecs
import json
import time
from http import HTTPStatus
@@ -14,7 +15,6 @@ from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse, Response
from packaging import version
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -31,20 +31,55 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid
try:
import fastchat
from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
_fastchat_available = True
except ImportError:
_fastchat_available = False
TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()
engine = None
response_role = None
def parse_args():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser.add_argument("--chat-template",
type=str,
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
parser.add_argument("--response-role",
type=str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
parser = AsyncEngineArgs.add_cli_args(parser)
return parser.parse_args()
def create_error_response(status_code: HTTPStatus,
@@ -54,6 +89,25 @@ def create_error_response(status_code: HTTPStatus,
status_code=status_code.value)
def load_chat_template(args, tokenizer):
if args.chat_template is not None:
try:
with open(args.chat_template, "r") as f:
chat_template = f.read()
except OSError:
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
chat_template = codecs.decode(args.chat_template, "unicode_escape")
tokenizer.chat_template = chat_template
logger.info(
f"Using supplied chat template:\n{tokenizer.chat_template}")
elif tokenizer.chat_template is not None:
logger.info(f"Using default chat template:\n{tokenizer.chat_template}")
else:
logger.warning("No chat template provided. Chat API will not work.")
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
@@ -69,53 +123,6 @@ async def check_model(request) -> Optional[JSONResponse]:
return ret
async def get_gen_prompt(request) -> str:
if not _fastchat_available:
raise ModuleNotFoundError(
"fastchat is not installed. Please install fastchat to use "
"the chat completion and conversation APIs: `$ pip install fschat`"
)
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
raise ImportError(
f"fastchat version is low. Current version: {fastchat.__version__} "
"Please upgrade fastchat to use: `$ pip install -U fschat`")
conv = get_conversation_template(request.model)
conv = Conversation(
name=conv.name,
system_template=conv.system_template,
system_message=conv.system_message,
roles=conv.roles,
messages=list(conv.messages), # prevent in-place modification
offset=conv.offset,
sep_style=SeparatorStyle(conv.sep_style),
sep=conv.sep,
sep2=conv.sep2,
stop_str=conv.stop_str,
stop_token_ids=conv.stop_token_ids,
)
if isinstance(request.messages, str):
prompt = request.messages
else:
for message in request.messages:
msg_role = message["role"]
if msg_role == "system":
conv.system_message = message["content"]
elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"])
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message["content"])
else:
raise ValueError(f"Unknown role: {msg_role}")
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
return prompt
async def check_length(
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
@@ -207,7 +214,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
@@ -217,7 +223,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
prompt = await get_gen_prompt(request)
try:
prompt = tokenizer.apply_chat_template(
conversation=request.messages,
tokenize=False,
add_generation_prompt=request.add_generation_prompt)
except Exception as e:
logger.error(f"Error in applying chat template from request: {str(e)}")
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
@@ -225,6 +239,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.monotonic())
chunk_object_type = "chat.completion.chunk"
try:
spaces_between_special_tokens = request.spaces_between_special_tokens
sampling_params = SamplingParams(
@@ -249,128 +264,162 @@ async def create_chat_completion(request: ChatCompletionRequest,
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)
def create_stream_response_json(
index: int,
text: str,
finish_reason: Optional[str] = None,
usage: Optional[UsageInfo] = None,
) -> str:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=text),
finish_reason=finish_reason,
)
response = ChatCompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
if usage is not None:
response.usage = usage
# exclude unset to leave details out of each sse
response_json = response.json(exclude_unset=True, ensure_ascii=False)
return response_json
def get_role() -> str:
if request.add_generation_prompt:
return response_role
else:
return request.messages[-1]["role"]
async def completion_stream_generator() -> AsyncGenerator[str, None]:
# First chunk with role
# Send first response for each request.n (index) with the role
role = get_role()
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
)
index=i, delta=DeltaMessage(role=role), finish_reason=None)
chunk = ChatCompletionStreamResponse(id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.json(exclude_unset=True, ensure_ascii=False)
yield f"data: {data}\n\n"
# Send response to echo the input portion of the last message
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages, list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
if last_msg_content:
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=last_msg_content),
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.json(exclude_unset=True, ensure_ascii=False)
yield f"data: {data}\n\n"
# Send response for each token for each request.n (index)
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
finish_reason_sent = [False] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
completion_tokens = len(output.token_ids)
previous_num_tokens[i] = completion_tokens
response_json = create_stream_response_json(
index=i,
text=delta_text,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
if finish_reason_sent[i]:
continue
if output.finish_reason is None:
# Send token-by-token response for each request.n
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
completion_tokens = len(output.token_ids)
previous_num_tokens[i] = completion_tokens
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.json(exclude_unset=True, ensure_ascii=False)
yield f"data: {data}\n\n"
else:
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = create_stream_response_json(
index=i,
text="",
finish_reason=output.finish_reason,
usage=final_usage,
)
yield f"data: {response_json}\n\n"
choice_data = ChatCompletionResponseStreamChoice(
index=i, delta=[], finish_reason=output.finish_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if final_usage is not None:
chunk.usage = final_usage
data = chunk.json(exclude_unset=True,
exclude_none=True,
ensure_ascii=False)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
async def completion_full_generator():
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
role = get_role()
for output in final_res.outputs:
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role=role, content=output.text),
finish_reason=output.finish_reason,
)
choices.append(choice_data)
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages, list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
for choice in choices:
full_message = last_msg_content + choice.message.content
choice.message.content = full_message
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
return response
# Streaming response
if request.stream:
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream")
# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
assert final_res is not None
choices = []
for output in final_res.outputs:
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role="assistant", content=output.text),
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
return response
else:
return await completion_full_generator()
@app.post("/v1/completions")
@@ -642,34 +691,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
args = parse_args()
app.add_middleware(
CORSMiddleware,
@@ -686,6 +708,8 @@ if __name__ == "__main__":
else:
served_model = args.model
response_role = args.response_role
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
engine_model_config = asyncio.run(engine.get_model_config())
@@ -696,6 +720,7 @@ if __name__ == "__main__":
engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode,
trust_remote_code=engine_model_config.trust_remote_code)
load_chat_template(args, tokenizer)
uvicorn.run(app,
host=args.host,

View File

@@ -73,6 +73,8 @@ class ChatCompletionRequest(BaseModel):
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
add_generation_prompt: Optional[bool] = True
echo: Optional[bool] = False
class CompletionRequest(BaseModel):