Rename servers to engines (#152)

This commit is contained in:
Zhuohan Li
2023-06-17 17:25:21 +08:00
committed by GitHub
parent bab8f3dd0d
commit e5464ee484
15 changed files with 165 additions and 174 deletions

View File

@@ -10,29 +10,20 @@ import fastapi
from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.responses import JSONResponse, StreamingResponse
import uvicorn
from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMEngine
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.engine.arg_utils import AsyncEngineArgs
from cacheflow.engine.async_llm_engine import AsyncLLMEngine
from cacheflow.engine.tokenizer_utils import get_tokenizer
from cacheflow.entrypoints.openai.protocol import (
CompletionRequest, CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid
from cacheflow.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
LogProbs,
ModelCard,
ModelList,
ModelPermission,
UsageInfo,
)
TIMEOUT_KEEP_ALIVE = 5 # seconds
@@ -102,11 +93,11 @@ async def create_completion(raw_request: Request):
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following features:
- echo (since the cacheflow server does not currently support
- echo (since the cacheflow engine does not currently support
getting the logprobs of prompt tokens)
- suffix (the language models we currently support do not support
suffix)
- logit_bias (to be supported in cacheflow server)
- logit_bias (to be supported in cacheflow engine)
"""
request = CompletionRequest(**await raw_request.json())
logger.info(f"Received completion request: {request}")
@@ -116,7 +107,7 @@ async def create_completion(raw_request: Request):
return error_check_ret
if request.echo:
# We do not support echo since the cacheflow server does not
# We do not support echo since the cacheflow engine does not
# currently support getting the logprobs of prompt tokens.
return create_error_response(HTTPStatus.BAD_REQUEST,
"echo is not currently supported")
@@ -127,7 +118,7 @@ async def create_completion(raw_request: Request):
"suffix is not currently supported")
if request.logit_bias is not None:
# TODO: support logit_bias in cacheflow server.
# TODO: support logit_bias in cacheflow engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
@@ -153,7 +144,7 @@ async def create_completion(raw_request: Request):
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = server.generate(prompt, sampling_params,
result_generator = engine.generate(prompt, sampling_params,
request_id)
# Similar to the OpenAI API, when n != best_of, we do not stream the
@@ -163,7 +154,7 @@ async def create_completion(raw_request: Request):
not request.use_beam_search)
async def abort_request() -> None:
await server.abort(request_id)
await engine.abort(request_id)
def create_stream_response_json(index: int,
text: str,
@@ -303,7 +294,7 @@ if __name__ == "__main__":
help="The model name used in the API. If not specified, "
"the model name will be the same as the "
"huggingface name.")
parser = AsyncServerArgs.add_cli_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
app.add_middleware(
@@ -318,8 +309,8 @@ if __name__ == "__main__":
served_model = args.served_model_name or args.model
server_args = AsyncServerArgs.from_cli_args(args)
server = AsyncLLMEngine.from_server_args(server_args)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(args.model)