Rename servers to engines (#152)
@@ -6,9 +6,9 @@ from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import Response, StreamingResponse
 import uvicorn

+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.async_llm_engine import AsyncLLMEngine
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMEngine
 from cacheflow.utils import random_uuid

 TIMEOUT_KEEP_ALIVE = 5  # seconds.
@@ -30,7 +30,7 @@ async def generate(request: Request) -> Response:
     stream = request_dict.pop("stream", False)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
-    results_generator = server.generate(prompt, sampling_params, request_id)
+    results_generator = engine.generate(prompt, sampling_params, request_id)

     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
@@ -44,7 +44,7 @@ async def generate(request: Request) -> Response:
             yield (json.dumps(ret) + "\0").encode("utf-8")

     async def abort_request() -> None:
-        await server.abort(request_id)
+        await engine.abort(request_id)

     if stream:
         background_tasks = BackgroundTasks()
@@ -57,7 +57,7 @@ async def generate(request: Request) -> Response:
     async for request_output in results_generator:
         if await request.is_disconnected():
             # Abort the request if the client disconnects.
-            await server.abort(request_id)
+            await engine.abort(request_id)
             return Response(status_code=499)
         final_output = request_output

@@ -75,11 +75,11 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8000)
-    parser = AsyncServerArgs.add_cli_args(parser)
+    parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

-    server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMEngine.from_server_args(server_args)
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)

     uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
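
For reference, a minimal client sketch against the renamed demo server above. It assumes the server is running on localhost:8000 and that the handler also pops a "prompt" field before building SamplingParams (not shown in these hunks); any remaining fields are forwarded to SamplingParams(**request_dict).

# Hypothetical client for the /generate endpoint shown above.
import json

import requests

payload = {
    "prompt": "San Francisco is a",  # assumed field name, not shown in the hunks
    "stream": False,                 # set True to consume the streaming branch
    "max_tokens": 16,                # forwarded to SamplingParams(**request_dict)
}
response = requests.post("http://localhost:8000/generate", json=payload)
print(json.dumps(response.json(), indent=2))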

@@ -1,12 +1,12 @@
 from typing import List, Optional, Union

-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from tqdm import tqdm
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMEngine
 from cacheflow.utils import Counter

@@ -21,7 +21,7 @@ class LLM:

     NOTE: This class is intended to be used for offline inference. For online
     serving, use the `AsyncLLMEngine` class instead.
-    NOTE: For the comprehensive list of arguments, see `ServerArgs`.
+    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

     Args:
         model: The name or path of a HuggingFace Transformers model.
@@ -45,20 +45,20 @@ class LLM:
     ) -> None:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        server_args = ServerArgs(
+        engine_args = EngineArgs(
             model=model,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             seed=seed,
             **kwargs,
         )
-        self.llm_server = LLMEngine.from_server_args(server_args)
+        self.llm_engine = LLMEngine.from_engine_args(engine_args)
         self.request_counter = Counter()

     def get_tokenizer(
         self,
     ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-        return self.llm_server.tokenizer
+        return self.llm_engine.tokenizer

     def generate(
         self,
@@ -99,7 +99,7 @@ class LLM:
             # Use default sampling params.
             sampling_params = SamplingParams()

-        # Add requests to the server.
+        # Add requests to the engine.
         if prompts is not None:
             num_requests = len(prompts)
         else:
@@ -111,7 +111,7 @@ class LLM:
             else:
                 token_ids = prompt_token_ids[i]
             self._add_request(prompt, sampling_params, token_ids)
-        return self._run_server(use_tqdm)
+        return self._run_engine(use_tqdm)

     def _add_request(
         self,
@@ -120,18 +120,18 @@ class LLM:
         prompt_token_ids: Optional[List[int]],
     ) -> None:
         request_id = str(next(self.request_counter))
-        self.llm_server.add_request(request_id, prompt, sampling_params,
+        self.llm_engine.add_request(request_id, prompt, sampling_params,
                                     prompt_token_ids)

-    def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
+    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
         # Initialize tqdm.
         if use_tqdm:
-            num_requests = self.llm_server.get_num_unfinished_requests()
+            num_requests = self.llm_engine.get_num_unfinished_requests()
             pbar = tqdm(total=num_requests, desc="Processed prompts")
-        # Run the server.
+        # Run the engine.
         outputs: List[RequestOutput] = []
-        while self.llm_server.has_unfinished_requests():
-            step_outputs = self.llm_server.step()
+        while self.llm_engine.has_unfinished_requests():
+            step_outputs = self.llm_engine.step()
             for output in step_outputs:
                 if output.finished():
                     outputs.append(output)
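
A short usage sketch of the offline path after the rename: the LLM wrapper builds EngineArgs internally and drives LLMEngine via from_engine_args() and step(), as in the hunks above. The model name and sampling values are illustrative, and the package-level re-export of LLM and SamplingParams is an assumption.

# Offline inference sketch using the renamed engine plumbing (illustrative only).
from cacheflow import LLM, SamplingParams  # assumed package-level re-exports

llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1, seed=0)
sampling_params = SamplingParams(temperature=0.8, max_tokens=32)

# generate() adds each prompt as a request and then calls _run_engine(),
# stepping the engine until every request has finished.
outputs = llm.generate(["Hello, my name is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)  # assumes RequestOutput exposes completions as .outputs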

@@ -10,29 +10,20 @@ import fastapi
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 import uvicorn

-from cacheflow.outputs import RequestOutput
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.async_llm_server import AsyncLLMEngine
-from cacheflow.server.tokenizer_utils import get_tokenizer
+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.async_llm_engine import AsyncLLMEngine
+from cacheflow.engine.tokenizer_utils import get_tokenizer
+from cacheflow.entrypoints.openai.protocol import (
+    CompletionRequest, CompletionResponse, CompletionResponseChoice,
+    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
+    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
 from cacheflow.logger import init_logger
+from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.utils import random_uuid
-from cacheflow.entrypoints.openai.protocol import (
-    CompletionRequest,
-    CompletionResponse,
-    CompletionResponseChoice,
-    CompletionResponseStreamChoice,
-    CompletionStreamResponse,
-    ErrorResponse,
-    LogProbs,
-    ModelCard,
-    ModelList,
-    ModelPermission,
-    UsageInfo,
-)

 TIMEOUT_KEEP_ALIVE = 5  # seconds

@@ -102,11 +93,11 @@ async def create_completion(raw_request: Request):
     for the API specification. This API mimics the OpenAI Completion API.

     NOTE: Currently we do not support the following features:
-        - echo (since the cacheflow server does not currently support
+        - echo (since the cacheflow engine does not currently support
           getting the logprobs of prompt tokens)
         - suffix (the language models we currently support do not support
           suffix)
-        - logit_bias (to be supported in cacheflow server)
+        - logit_bias (to be supported in cacheflow engine)
     """
     request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")
@@ -116,7 +107,7 @@ async def create_completion(raw_request: Request):
         return error_check_ret

     if request.echo:
-        # We do not support echo since the cacheflow server does not
+        # We do not support echo since the cacheflow engine does not
         # currently support getting the logprobs of prompt tokens.
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "echo is not currently supported")
@@ -127,7 +118,7 @@ async def create_completion(raw_request: Request):
                                      "suffix is not currently supported")

     if request.logit_bias is not None:
-        # TODO: support logit_bias in cacheflow server.
+        # TODO: support logit_bias in cacheflow engine.
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "logit_bias is not currently supported")

@@ -153,7 +144,7 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = server.generate(prompt, sampling_params,
+    result_generator = engine.generate(prompt, sampling_params,
                                        request_id)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
@@ -163,7 +154,7 @@ async def create_completion(raw_request: Request):
               not request.use_beam_search)

     async def abort_request() -> None:
-        await server.abort(request_id)
+        await engine.abort(request_id)

     def create_stream_response_json(index: int,
                                     text: str,
@@ -303,7 +294,7 @@ if __name__ == "__main__":
                         help="The model name used in the API. If not specified, "
                         "the model name will be the same as the "
                         "huggingface name.")
-    parser = AsyncServerArgs.add_cli_args(parser)
+    parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()

     app.add_middleware(
@@ -318,8 +309,8 @@ if __name__ == "__main__":

     served_model = args.served_model_name or args.model

-    server_args = AsyncServerArgs.from_cli_args(args)
-    server = AsyncLLMEngine.from_server_args(server_args)
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngine.from_engine_args(engine_args)

     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(args.model)