OpenAI Compatible Frontend (#116)
cacheflow/entrypoints/openai/openai_frontend.py (new file, 300 lines)
@@ -0,0 +1,300 @@
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py

import argparse
from http import HTTPStatus
import json
import time
from typing import AsyncGenerator, Dict, List, Optional

import fastapi
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn

from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid
from cacheflow.entrypoints.openai.protocol import (
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    ErrorResponse,
    LogProbs,
    ModelCard,
    ModelList,
    ModelPermission,
    UsageInfo,
)


logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()


def create_error_response(status_code: HTTPStatus,
                          message: str) -> JSONResponse:
    return JSONResponse(
        ErrorResponse(message=message, type="invalid_request_error").dict(),
        status_code=status_code.value
    )


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))


async def check_model(request) -> Optional[JSONResponse]:
    if request.model == served_model:
        return
    ret = create_error_response(
        HTTPStatus.NOT_FOUND,
        f"The model `{request.model}` does not exist.",
    )
    return ret


@app.get("/v1/models")
async def show_available_models():
    """Show available models. Right now we only have one model."""
    model_cards = [ModelCard(id=served_model, root=served_model,
                             permission=[ModelPermission()])]
    return ModelList(data=model_cards)
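
# The JSON this returns follows the OpenAI "list" convention, roughly
# (illustrative values): {"object": "list", "data": [{"id": "<served-model>",
# "object": "model", "root": "<served-model>", "permission": [...]}]}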


def create_logprobs(token_ids: List[int],
                    id_logprobs: List[Dict[int, float]],
                    initial_text_offset: int = 0) -> LogProbs:
    """Create OpenAI-style logprobs."""
    logprobs = LogProbs()
    last_token_len = 0
    for token_id, id_logprob in zip(token_ids, id_logprobs):
        token = tokenizer.convert_ids_to_tokens(token_id)
        logprobs.tokens.append(token)
        logprobs.token_logprobs.append(id_logprob[token_id])
        if len(logprobs.text_offset) == 0:
            logprobs.text_offset.append(initial_text_offset)
        else:
            logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
        last_token_len = len(token)

        logprobs.top_logprobs.append(
            {tokenizer.convert_ids_to_tokens(i): p
             for i, p in id_logprob.items()})
    return logprobs
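
# A worked sketch of the offset bookkeeping above, with hypothetical values:
# for token_ids decoding to ["Hello", ",", " world"] and initial_text_offset=0,
# text_offset becomes [0, 5, 6] -- each entry is the previous offset plus the
# length of the previous token string, matching the `text_offset` field of
# OpenAI's Completions logprobs object.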


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    logger.info(f"Received completion request: {request}")

    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret

    if request.echo:
        # We do not support echo since the cacheflow server does not
        # currently support getting the logprobs of prompt tokens.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "echo is not currently supported")

    if request.suffix is not None:
        # The language models we currently support do not support suffix.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "suffix is not currently supported")

    if request.logit_bias is not None:
        # TODO: support logit_bias in cacheflow server.
        return create_error_response(HTTPStatus.BAD_REQUEST,
                                     "logit_bias is not currently supported")

    model_name = request.model
    request_id = f"cmpl-{random_uuid()}"
    prompt = request.prompt
    created_time = int(time.time())
    try:
        sampling_params = SamplingParams(
            n=request.n,
            best_of=request.best_of,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            stop=request.stop,
            ignore_eos=request.ignore_eos,
            max_tokens=request.max_tokens,
            logprobs=request.logprobs,
            use_beam_search=request.use_beam_search,
        )
    except ValueError as e:
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

    result_generator = server.generate(prompt, sampling_params,
                                       request_id=request_id)

    # Similar to the OpenAI API, when n != best_of, we do not stream the
    # results. In addition, we do not stream the results when using beam
    # search.
    stream = (request.stream and
              (request.best_of is None or request.n == request.best_of) and
              not request.use_beam_search)

    def create_stream_response_json(index: int,
                                    text: str,
                                    logprobs: Optional[LogProbs] = None,
                                    finish_reason: Optional[str] = None) -> str:
        choice_data = CompletionResponseStreamChoice(
            index=index,
            text=text,
            logprobs=logprobs,
            finish_reason=finish_reason,
        )
        response = CompletionStreamResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=[choice_data],
        )
        response_json = response.json(ensure_ascii=False)

        return response_json

    async def completion_stream_generator() -> AsyncGenerator[str, None]:
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        async for res in result_generator:
            res: RequestOutput
            for output in res.outputs:
                i = output.index
                delta_text = output.text[len(previous_texts[i]):]
                if request.logprobs is not None:
                    logprobs = create_logprobs(
                        output.token_ids[previous_num_tokens[i]:],
                        output.logprobs[previous_num_tokens[i]:],
                        len(previous_texts[i]))
                else:
                    logprobs = None
                previous_texts[i] = output.text
                previous_num_tokens[i] = len(output.token_ids)
                response_json = create_stream_response_json(
                    index=i,
                    text=delta_text,
                    logprobs=logprobs,
                )
                yield f"data: {response_json}\n\n"
                if output.finish_reason is not None:
                    logprobs = LogProbs() if request.logprobs is not None else None
                    response_json = create_stream_response_json(
                        index=i,
                        text="",
                        logprobs=logprobs,
                        finish_reason=output.finish_reason,
                    )
                    yield f"data: {response_json}\n\n"
        yield "data: [DONE]\n\n"

    # Streaming response
    if stream:
        return StreamingResponse(completion_stream_generator(),
                                 media_type="text/event-stream")

    # Non-streaming response
    final_res: Optional[RequestOutput] = None
    async for res in result_generator:
        final_res = res
    assert final_res is not None
    choices = []
    for output in final_res.outputs:
        if request.logprobs is not None:
            logprobs = create_logprobs(output.token_ids, output.logprobs)
        else:
            logprobs = None
        choice_data = CompletionResponseChoice(
            index=output.index,
            text=output.text,
            logprobs=logprobs,
            finish_reason=output.finish_reason,
        )
        choices.append(choice_data)

    num_prompt_tokens = len(final_res.prompt_token_ids)
    num_generated_tokens = sum(len(output.token_ids)
                               for output in final_res.outputs)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )
    response = CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )

    if request.stream:
        # When the user requests streaming but we don't stream, we still
        # need to return a streaming response with a single event.
        response_json = response.json(ensure_ascii=False)

        async def fake_stream_generator() -> AsyncGenerator[str, None]:
            yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(fake_stream_generator(),
                                 media_type="text/event-stream")

    return response


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="CacheFlow OpenAI-Compatible RESTful API server."
    )
    parser.add_argument("--host", type=str, default="localhost", help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument(
        "--allow-credentials", action="store_true", help="allow credentials"
    )
    parser.add_argument(
        "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
    )
    parser.add_argument(
        "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
    )
    parser.add_argument(
        "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
    )
    parser.add_argument("--served-model-name", type=str, default=None,
                        help="The model name used in the API. If not specified, "
                             "the model name will be the same as the "
                             "huggingface name.")
    parser = ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    logger.info(f"args: {args}")

    served_model = args.served_model_name or args.model

    server_args = ServerArgs.from_cli_args(args)
    server = AsyncLLMServer.from_server_args(server_args)

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(args.model)

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
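
As a quick sanity check, here is a minimal client-side sketch for the endpoint above. Assumptions: the server was launched from this script with --model and is listening on the default localhost:8000, the third-party `requests` package is installed, and the model name below is purely illustrative:

import requests

payload = {
    "model": "facebook/opt-125m",  # illustrative; must match the served model name
    "prompt": "Hello, my name is",
    "max_tokens": 16,
    "temperature": 0.0,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])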
cacheflow/entrypoints/openai/protocol.py (new file, 126 lines)
@@ -0,0 +1,126 @@
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field

from cacheflow.utils import random_uuid


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: Optional[str] = None


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "cacheflow"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None


class CompletionRequest(BaseModel):
    model: str
    prompt: str
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    # Additional parameters supported by cacheflow
    top_k: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
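
For illustration, a minimal sketch of how these pydantic (v1-style) models round-trip; the values are made up:

req = CompletionRequest(model="facebook/opt-125m", prompt="Hello", logprobs=2)
body = req.json(ensure_ascii=False)         # the JSON a client would POST
parsed = CompletionRequest.parse_raw(body)  # what FastAPI reconstructs server-side
assert parsed.max_tokens == 16              # unset fields fall back to their defaults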