Rename servers and change port numbers to reduce confusion (#149)

This commit is contained in:
Zhuohan Li
2023-06-17 00:13:02 +08:00
committed by GitHub
parent 311490a720
commit eedb46bf03
10 changed files with 41 additions and 37 deletions

View File

@@ -8,7 +8,7 @@ import uvicorn
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.async_llm_server import AsyncLLMEngine
from cacheflow.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds.
@@ -18,7 +18,7 @@ app = FastAPI()
@app.post("/generate")
async def generate(request: Request) -> Response:
""" Stream the results of the generation request.
"""Generate completion for the request.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
@@ -74,12 +74,12 @@ async def generate(request: Request) -> Response:
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--port", type=int, default=8000)
parser = AsyncServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = AsyncServerArgs.from_cli_args(args)
server = AsyncLLMServer.from_server_args(server_args)
server = AsyncLLMEngine.from_server_args(server_args)
uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)

View File

@@ -6,7 +6,7 @@ from tqdm import tqdm
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.llm_server import LLMEngine
from cacheflow.utils import Counter
@@ -20,7 +20,7 @@ class LLM:
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMServer` class instead.
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `ServerArgs`.
Args:
@@ -52,7 +52,7 @@ class LLM:
seed=seed,
**kwargs,
)
self.llm_server = LLMServer.from_server_args(server_args)
self.llm_server = LLMEngine.from_server_args(server_args)
self.request_counter = Counter()
def get_tokenizer(

View File

@@ -15,7 +15,7 @@ import uvicorn
from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.async_llm_server import AsyncLLMEngine
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
@@ -319,7 +319,7 @@ if __name__ == "__main__":
served_model = args.served_model_name or args.model
server_args = AsyncServerArgs.from_cli_args(args)
server = AsyncLLMServer.from_server_args(server_args)
server = AsyncLLMEngine.from_server_args(server_args)
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(args.model)