Rename servers to engines (#152)

Author: Zhuohan Li
Date: 2023-06-17 17:25:21 +08:00
Committed by: GitHub
Parent commit: bab8f3dd0d
Commit: e5464ee484
15 changed files with 165 additions and 174 deletions

Changed file: the standalone FastAPI demo server (async generate endpoint).

@@ -6,9 +6,9 @@ from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import Response, StreamingResponse
import uvicorn
+ from cacheflow.engine.arg_utils import AsyncEngineArgs
+ from cacheflow.engine.async_llm_engine import AsyncLLMEngine
from cacheflow.sampling_params import SamplingParams
- from cacheflow.server.arg_utils import AsyncServerArgs
- from cacheflow.server.async_llm_server import AsyncLLMEngine
from cacheflow.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds.
@@ -30,7 +30,7 @@ async def generate(request: Request) -> Response:
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
- results_generator = server.generate(prompt, sampling_params, request_id)
+ results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
@@ -44,7 +44,7 @@ async def generate(request: Request) -> Response:
yield (json.dumps(ret) + "\0").encode("utf-8")
async def abort_request() -> None:
- await server.abort(request_id)
+ await engine.abort(request_id)
if stream:
background_tasks = BackgroundTasks()
@@ -57,7 +57,7 @@ async def generate(request: Request) -> Response:
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
- await server.abort(request_id)
+ await engine.abort(request_id)
return Response(status_code=499)
final_output = request_output
@@ -75,11 +75,11 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
- parser = AsyncServerArgs.add_cli_args(parser)
+ parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
- server_args = AsyncServerArgs.from_cli_args(args)
- server = AsyncLLMEngine.from_server_args(server_args)
+ engine_args = AsyncEngineArgs.from_cli_args(args)
+ engine = AsyncLLMEngine.from_engine_args(engine_args)
uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
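For orientation, a minimal client sketch against the demo endpoint edited above. It is not part of this commit: the /generate route and the "prompt" field name are assumptions (the route decorator and the prompt pop sit outside the hunks shown), the extra keys are simply forwarded into SamplingParams(**request_dict) as the diff indicates, and the host/port mirror the argparse defaults above.

import json
import requests  # any HTTP client works; requests is used here for brevity

payload = {
    "prompt": "San Francisco is a",  # assumed field name (popped outside the shown hunk)
    "stream": True,                  # popped by the handler above
    "temperature": 0.8,              # illustrative key forwarded to SamplingParams(**request_dict)
    "max_tokens": 64,                # illustrative sampling field
}
with requests.post("http://localhost:8000/generate", json=payload, stream=True) as resp:
    # The streaming branch above yields "\0"-delimited JSON chunks.
    for chunk in resp.iter_lines(delimiter=b"\0"):
        if chunk:
            print(json.loads(chunk.decode("utf-8")))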

Changed file: the offline LLM wrapper class.

@@ -1,12 +1,12 @@
from typing import List, Optional, Union
- from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from tqdm import tqdm
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+ from cacheflow.engine.arg_utils import EngineArgs
+ from cacheflow.engine.llm_engine import LLMEngine
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
- from cacheflow.server.arg_utils import ServerArgs
- from cacheflow.server.llm_server import LLMEngine
from cacheflow.utils import Counter
@@ -21,7 +21,7 @@ class LLM:
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
- NOTE: For the comprehensive list of arguments, see `ServerArgs`.
+ NOTE: For the comprehensive list of arguments, see `EngineArgs`.
Args:
model: The name or path of a HuggingFace Transformers model.
@@ -45,20 +45,20 @@ class LLM:
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
- server_args = ServerArgs(
+ engine_args = EngineArgs(
model=model,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
seed=seed,
**kwargs,
)
- self.llm_server = LLMEngine.from_server_args(server_args)
+ self.llm_engine = LLMEngine.from_engine_args(engine_args)
self.request_counter = Counter()
def get_tokenizer(
self,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
- return self.llm_server.tokenizer
+ return self.llm_engine.tokenizer
def generate(
self,
@@ -99,7 +99,7 @@ class LLM:
# Use default sampling params.
sampling_params = SamplingParams()
- # Add requests to the server.
+ # Add requests to the engine.
if prompts is not None:
num_requests = len(prompts)
else:
@@ -111,7 +111,7 @@ class LLM:
else:
token_ids = prompt_token_ids[i]
self._add_request(prompt, sampling_params, token_ids)
- return self._run_server(use_tqdm)
+ return self._run_engine(use_tqdm)
def _add_request(
self,
@@ -120,18 +120,18 @@ class LLM:
prompt_token_ids: Optional[List[int]],
) -> None:
request_id = str(next(self.request_counter))
- self.llm_server.add_request(request_id, prompt, sampling_params,
+ self.llm_engine.add_request(request_id, prompt, sampling_params,
prompt_token_ids)
- def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
+ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
if use_tqdm:
- num_requests = self.llm_server.get_num_unfinished_requests()
+ num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(total=num_requests, desc="Processed prompts")
- # Run the server.
+ # Run the engine.
outputs: List[RequestOutput] = []
- while self.llm_server.has_unfinished_requests():
- step_outputs = self.llm_server.step()
+ while self.llm_engine.has_unfinished_requests():
+ step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished():
outputs.append(output)
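A minimal offline-inference sketch showing how the renamed wrapper is driven end to end. It is illustrative only: the top-level import location of LLM is an assumption (the changed file's path is not shown here), and the model name and SamplingParams fields are placeholders; the constructor arguments and the generate() call shape come from the hunks above.

from cacheflow import LLM  # assumed export location for the wrapper class above
from cacheflow.sampling_params import SamplingParams  # module path taken from the diff

llm = LLM(model="facebook/opt-125m")  # other constructor args assumed to have defaults
params = SamplingParams(temperature=0.8, max_tokens=32)  # illustrative fields
# Internally this calls add_request() per prompt, then _run_engine() to step the engine.
outputs = llm.generate(["Hello, my name is"], params)
for out in outputs:
    print(out)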

Changed file: the OpenAI-compatible completion API server.

@@ -10,29 +10,20 @@ import fastapi
from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import StreamingResponse, JSONResponse
+ from fastapi.responses import JSONResponse, StreamingResponse
import uvicorn
- from cacheflow.outputs import RequestOutput
- from cacheflow.server.arg_utils import AsyncServerArgs
- from cacheflow.server.async_llm_server import AsyncLLMEngine
- from cacheflow.server.tokenizer_utils import get_tokenizer
+ from cacheflow.engine.arg_utils import AsyncEngineArgs
+ from cacheflow.engine.async_llm_engine import AsyncLLMEngine
+ from cacheflow.engine.tokenizer_utils import get_tokenizer
+ from cacheflow.entrypoints.openai.protocol import (
+ CompletionRequest, CompletionResponse, CompletionResponseChoice,
+ CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
+ LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
from cacheflow.logger import init_logger
+ from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid
- from cacheflow.entrypoints.openai.protocol import (
- CompletionRequest,
- CompletionResponse,
- CompletionResponseChoice,
- CompletionResponseStreamChoice,
- CompletionStreamResponse,
- ErrorResponse,
- LogProbs,
- ModelCard,
- ModelList,
- ModelPermission,
- UsageInfo,
- )
TIMEOUT_KEEP_ALIVE = 5 # seconds
@@ -102,11 +93,11 @@ async def create_completion(raw_request: Request):
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following features:
- - echo (since the cacheflow server does not currently support
+ - echo (since the cacheflow engine does not currently support
getting the logprobs of prompt tokens)
- suffix (the language models we currently support do not support
suffix)
- - logit_bias (to be supported in cacheflow server)
+ - logit_bias (to be supported in cacheflow engine)
"""
request = CompletionRequest(**await raw_request.json())
logger.info(f"Received completion request: {request}")
@@ -116,7 +107,7 @@ async def create_completion(raw_request: Request):
return error_check_ret
if request.echo:
- # We do not support echo since the cacheflow server does not
+ # We do not support echo since the cacheflow engine does not
# currently support getting the logprobs of prompt tokens.
return create_error_response(HTTPStatus.BAD_REQUEST,
"echo is not currently supported")
@@ -127,7 +118,7 @@ async def create_completion(raw_request: Request):
"suffix is not currently supported")
if request.logit_bias is not None:
- # TODO: support logit_bias in cacheflow server.
+ # TODO: support logit_bias in cacheflow engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
@@ -153,7 +144,7 @@ async def create_completion(raw_request: Request):
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
- result_generator = server.generate(prompt, sampling_params,
+ result_generator = engine.generate(prompt, sampling_params,
request_id)
# Similar to the OpenAI API, when n != best_of, we do not stream the
@@ -163,7 +154,7 @@ async def create_completion(raw_request: Request):
not request.use_beam_search)
async def abort_request() -> None:
- await server.abort(request_id)
+ await engine.abort(request_id)
def create_stream_response_json(index: int,
text: str,
@@ -303,7 +294,7 @@ if __name__ == "__main__":
help="The model name used in the API. If not specified, "
"the model name will be the same as the "
"huggingface name.")
- parser = AsyncServerArgs.add_cli_args(parser)
+ parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
app.add_middleware(
@@ -318,8 +309,8 @@ if __name__ == "__main__":
served_model = args.served_model_name or args.model
- server_args = AsyncServerArgs.from_cli_args(args)
- server = AsyncLLMEngine.from_server_args(server_args)
+ engine_args = AsyncEngineArgs.from_cli_args(args)
+ engine = AsyncLLMEngine.from_engine_args(engine_args)
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(args.model)
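Finally, a hedged sketch of querying the OpenAI-compatible server above with the pre-1.0 openai Python client. The /v1 route prefix, port, and dummy API key are assumptions (they sit outside the hunks shown); the model name must match --served-model-name, or the --model value when no served name is given.

import openai  # pre-1.0 client interface

openai.api_key = "EMPTY"                      # placeholder; auth handling is not shown in the diff
openai.api_base = "http://localhost:8000/v1"  # assumed host, port, and route prefix

completion = openai.Completion.create(
    model="facebook/opt-125m",                # should match served_model above
    prompt="San Francisco is a",
    max_tokens=32,
    # echo, suffix, and logit_bias are rejected by the handler in this file.
)
print(completion)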