Rename servers to engines (#152)
@@ -10,29 +10,20 @@ import fastapi
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 import uvicorn
 
-from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMEngine
-from cacheflow.server.tokenizer_utils import get_tokenizer
+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.async_llm_engine import AsyncLLMEngine
+from cacheflow.engine.tokenizer_utils import get_tokenizer
+from cacheflow.entrypoints.openai.protocol import (
+    CompletionRequest, CompletionResponse, CompletionResponseChoice,
+    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
+    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
 from cacheflow.logger import init_logger
+from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import random_uuid
-from cacheflow.entrypoints.openai.protocol import (
-    CompletionRequest,
-    CompletionResponse,
-    CompletionResponseChoice,
-    CompletionResponseStreamChoice,
-    CompletionStreamResponse,
-    ErrorResponse,
-    LogProbs,
-    ModelCard,
-    ModelList,
-    ModelPermission,
-    UsageInfo,
-)
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
@@ -102,11 +93,11 @@ async def create_completion(raw_request: Request):
     for the API specification. This API mimics the OpenAI Completion API.
 
     NOTE: Currently we do not support the following features:
-        - echo (since the cacheflow server does not currently support
+        - echo (since the cacheflow engine does not currently support
           getting the logprobs of prompt tokens)
         - suffix (the language models we currently support do not support
          suffix)
-        - logit_bias (to be supported in cacheflow server)
+        - logit_bias (to be supported in cacheflow engine)
     """
     request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")
@@ -116,7 +107,7 @@ async def create_completion(raw_request: Request):
         return error_check_ret
 
     if request.echo:
-        # We do not support echo since the cacheflow server does not
+        # We do not support echo since the cacheflow engine does not
         # currently support getting the logprobs of prompt tokens.
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "echo is not currently supported")
@@ -127,7 +118,7 @@ async def create_completion(raw_request: Request):
                                      "suffix is not currently supported")
 
     if request.logit_bias is not None:
-        # TODO: support logit_bias in cacheflow server.
+        # TODO: support logit_bias in cacheflow engine.
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "logit_bias is not currently supported")
 
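For reference, the checks above reject echo, suffix, and logit_bias with HTTPStatus.BAD_REQUEST. A minimal client-side sketch of the resulting behaviour; the /v1/completions route, host, port, and model name are assumptions not visible in this diff:

    import requests

    BASE_URL = "http://localhost:8000"  # assumed host/port

    # A request that avoids the unsupported fields succeeds.
    ok = requests.post(f"{BASE_URL}/v1/completions", json={
        "model": "facebook/opt-125m",          # assumed model name
        "prompt": "The capital of France is",
        "max_tokens": 16,
    })
    print(ok.status_code, ok.json())

    # Setting one of the unsupported fields should yield HTTP 400,
    # per the create_completion checks shown above.
    bad = requests.post(f"{BASE_URL}/v1/completions", json={
        "model": "facebook/opt-125m",
        "prompt": "The capital of France is",
        "echo": True,
    })
    print(bad.status_code)  # expected: 400
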
@@ -153,7 +144,7 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = server.generate(prompt, sampling_params,
+    result_generator = engine.generate(prompt, sampling_params,
                                        request_id)
 
     # Similar to the OpenAI API, when n != best_of, we do not stream the
@@ -163,7 +154,7 @@ async def create_completion(raw_request: Request):
              not request.use_beam_search)
 
     async def abort_request() -> None:
-        await server.abort(request_id)
+        await engine.abort(request_id)
 
     def create_stream_response_json(index: int,
                                     text: str,
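The hunks above also show the streaming path: engine.generate(prompt, sampling_params, request_id) is consumed as an asynchronous stream of RequestOutput objects, and abort_request() cancels it via engine.abort(request_id). A minimal sketch of that pattern, assuming the generate/abort signatures shown in the diff and an illustrative RequestOutput field layout:

    import asyncio

    from cacheflow.engine.async_llm_engine import AsyncLLMEngine
    from cacheflow.sampling_params import SamplingParams
    from cacheflow.utils import random_uuid

    async def stream_one_prompt(engine: AsyncLLMEngine, prompt: str) -> None:
        request_id = random_uuid()
        sampling_params = SamplingParams(max_tokens=16)  # assumed constructor argument
        try:
            # engine.generate is assumed to be an async generator of RequestOutput.
            async for request_output in engine.generate(prompt, sampling_params,
                                                        request_id):
                # Assumed field layout: each output carries the text generated so far.
                print(request_output.outputs[0].text)
        except asyncio.CancelledError:
            # Mirrors abort_request() above: tell the engine to drop the request.
            await engine.abort(request_id)
            raise
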
@@ -303,7 +294,7 @@ if __name__ == "__main__":
                         help="The model name used in the API. If not specified, "
                              "the model name will be the same as the "
                              "huggingface name.")
-    parser = AsyncServerArgs.add_cli_args(parser)
+    parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
 
     app.add_middleware(
@@ -318,8 +309,8 @@ if __name__ == "__main__":
 
     served_model = args.served_model_name or args.model
 
-    server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMEngine.from_server_args(server_args)
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
 
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)
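
Putting the __main__ changes together, the post-rename wiring is roughly the following sketch; the extra argparse flags, the FastAPI app placeholder, and the uvicorn.run options are assumptions, while the engine calls come straight from the diff:

    import argparse

    import fastapi
    import uvicorn

    from cacheflow.engine.arg_utils import AsyncEngineArgs
    from cacheflow.engine.async_llm_engine import AsyncLLMEngine
    from cacheflow.engine.tokenizer_utils import get_tokenizer

    TIMEOUT_KEEP_ALIVE = 5  # seconds

    app = fastapi.FastAPI()  # placeholder; the real module defines the completion routes on app

    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--host", type=str, default="localhost")        # assumed flag
        parser.add_argument("--port", type=int, default=8000)               # assumed flag
        parser.add_argument("--served-model-name", type=str, default=None)
        parser = AsyncEngineArgs.add_cli_args(parser)  # engine flags, formerly AsyncServerArgs
        args = parser.parse_args()  # args.model is assumed to be added by AsyncEngineArgs.add_cli_args

        served_model = args.served_model_name or args.model

        engine_args = AsyncEngineArgs.from_cli_args(args)
        engine = AsyncLLMEngine.from_engine_args(engine_args)

        # A separate tokenizer to map token IDs to strings.
        tokenizer = get_tokenizer(args.model)

        uvicorn.run(app, host=args.host, port=args.port,
                    timeout_keep_alive=TIMEOUT_KEEP_ALIVE)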