[Feature] vLLM CLI (#5090)

Co-authored-by: simon-mo <simon.mo@hey.com>
Author: Ethan Xu
Date: 2024-07-14 15:36:43 -07:00
Committed by: GitHub
parent 73030b7dae
commit dbfe254eda
7 changed files with 223 additions and 36 deletions

vllm/entrypoints/openai/api_server.py

@@ -8,7 +8,7 @@ from typing import Optional, Set
 import fastapi
 import uvicorn
-from fastapi import Request
+from fastapi import APIRouter, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -35,10 +35,14 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import FlexibleArgumentParser
 from vllm.version import __version__ as VLLM_VERSION
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
 logger = init_logger(__name__)
 
+engine: AsyncLLMEngine
+engine_args: AsyncEngineArgs
 openai_serving_chat: OpenAIServingChat
 openai_serving_completion: OpenAIServingCompletion
 openai_serving_embedding: OpenAIServingEmbedding
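
The two added annotations declare `engine` and `engine_args` at module scope without binding values; `run_server` later assigns them through `global` statements. A minimal sketch of that pattern, with stand-in names not taken from the diff:

    class Engine:
        """Stand-in for AsyncLLMEngine."""

    engine: Engine  # annotation only; nothing is bound at import time

    def run_server() -> None:
        global engine
        engine = Engine()  # the first real binding happens here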
@@ -64,35 +68,23 @@ async def lifespan(app: fastapi.FastAPI):
     yield
 
 
-app = fastapi.FastAPI(lifespan=lifespan)
-
-
-def parse_args():
-    parser = make_arg_parser()
-    return parser.parse_args()
+router = APIRouter()
 
 # Add prometheus asgi middleware to route /metrics requests
 route = Mount("/metrics", make_asgi_app())
 # Workaround for 307 Redirect for /metrics
 route.path_regex = re.compile('^/metrics(?P<path>.*)$')
-app.routes.append(route)
+router.routes.append(route)
 
 
-@app.exception_handler(RequestValidationError)
-async def validation_exception_handler(_, exc):
-    err = openai_serving_chat.create_error_response(message=str(exc))
-    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
-
-
-@app.get("/health")
+@router.get("/health")
 async def health() -> Response:
     """Health check."""
     await openai_serving_chat.engine.check_health()
     return Response(status_code=200)
 
 
-@app.post("/tokenize")
+@router.post("/tokenize")
 async def tokenize(request: TokenizeRequest):
     generator = await openai_serving_completion.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
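
This hunk is the heart of the refactor: routes that used to be registered directly on a module-level `app` now hang off an `APIRouter`, and the app itself is built on demand by `build_app` further down. A self-contained sketch of the pattern, with illustrative route names:

    from fastapi import APIRouter, FastAPI

    router = APIRouter()

    @router.get("/ping")
    async def ping():
        # Registered on the router, not on any concrete app instance.
        return {"pong": True}

    def build_app() -> FastAPI:
        # Every call produces a fresh app wired with the shared routes.
        app = FastAPI()
        app.include_router(router)
        return app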
@@ -103,7 +95,7 @@ async def tokenize(request: TokenizeRequest):
     return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/detokenize")
+@router.post("/detokenize")
 async def detokenize(request: DetokenizeRequest):
     generator = await openai_serving_completion.create_detokenize(request)
     if isinstance(generator, ErrorResponse):
@@ -114,19 +106,19 @@ async def detokenize(request: DetokenizeRequest):
     return JSONResponse(content=generator.model_dump())
 
 
-@app.get("/v1/models")
+@router.get("/v1/models")
 async def show_available_models():
     models = await openai_serving_completion.show_available_models()
     return JSONResponse(content=models.model_dump())
 
 
-@app.get("/version")
+@router.get("/version")
 async def show_version():
     ver = {"version": VLLM_VERSION}
     return JSONResponse(content=ver)
 
 
-@app.post("/v1/chat/completions")
+@router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     generator = await openai_serving_chat.create_chat_completion(
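
Only the decorators change in these route hunks; the OpenAI-compatible paths and handlers stay identical, so existing clients are unaffected. A sketch of a client call against a running server (assumes the `openai` Python package v1+, a server on localhost:8000, and a placeholder model name):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    resp = client.chat.completions.create(
        model="my-model",  # placeholder: use one of the served model names
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(resp.choices[0].message.content)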
@@ -142,7 +134,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/v1/completions")
+@router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
     generator = await openai_serving_completion.create_completion(
         request, raw_request)
@@ -156,7 +148,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     return JSONResponse(content=generator.model_dump())
 
 
-@app.post("/v1/embeddings")
+@router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     generator = await openai_serving_embedding.create_embedding(
         request, raw_request)
@@ -167,8 +159,10 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     return JSONResponse(content=generator.model_dump())
 
 
-if __name__ == "__main__":
-    args = parse_args()
+def build_app(args):
+    app = fastapi.FastAPI(lifespan=lifespan)
+    app.include_router(router)
+    app.root_path = args.root_path
 
     app.add_middleware(
         CORSMiddleware,
@@ -178,6 +172,12 @@ if __name__ == "__main__":
         allow_headers=args.allowed_headers,
     )
 
+    @app.exception_handler(RequestValidationError)
+    async def validation_exception_handler(_, exc):
+        err = openai_serving_chat.create_error_response(message=str(exc))
+        return JSONResponse(err.model_dump(),
+                            status_code=HTTPStatus.BAD_REQUEST)
+
     if token := envs.VLLM_API_KEY or args.api_key:
 
         @app.middleware("http")
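
The hunk ends just as the authentication middleware is declared; its body lies outside this diff. A hedged sketch of what a bearer-token check of this shape typically looks like (details assumed, not taken from the commit):

    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse

    app = FastAPI()
    token = "secret-key"  # stand-in for envs.VLLM_API_KEY or args.api_key

    @app.middleware("http")
    async def authentication(request: Request, call_next):
        # Reject requests that do not carry the expected bearer token.
        if request.headers.get("Authorization") != "Bearer " + token:
            return JSONResponse(content={"error": "Unauthorized"},
                                status_code=401)
        return await call_next(request)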
@@ -203,6 +203,12 @@ if __name__ == "__main__":
             raise ValueError(f"Invalid middleware {middleware}. "
                              f"Must be a function or a class.")
 
+    return app
+
+
+def run_server(args, llm_engine=None):
+    app = build_app(args)
+
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)
@@ -211,10 +217,12 @@ if __name__ == "__main__":
     else:
         served_model_names = [args.model]
 
-    engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngine.from_engine_args(
-        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
+    global engine, engine_args
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = (llm_engine
+              if llm_engine is not None else AsyncLLMEngine.from_engine_args(
+                  engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
 
     event_loop: Optional[asyncio.AbstractEventLoop]
     try:
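
`run_server` now accepts an optional pre-built engine and falls back to constructing one from the CLI args, which is what lets the new CLI (and tests) inject an engine. A hedged sketch of programmatic use, assuming these import paths as of this commit:

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.entrypoints.openai.api_server import run_server
    from vllm.entrypoints.openai.cli_args import make_arg_parser
    from vllm.utils import FlexibleArgumentParser

    parser = make_arg_parser(FlexibleArgumentParser())
    args = parser.parse_args(["--model", "my-model"])  # placeholder model
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="my-model"))
    run_server(args, llm_engine=engine)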
@@ -230,6 +238,10 @@ if __name__ == "__main__":
         # When using single vLLM without engine_use_ray
         model_config = asyncio.run(engine.get_model_config())
 
+    global openai_serving_chat
+    global openai_serving_completion
+    global openai_serving_embedding
+
     openai_serving_chat = OpenAIServingChat(engine, model_config,
                                             served_model_names,
                                             args.response_role,
@@ -258,3 +270,13 @@ if __name__ == "__main__":
                 ssl_certfile=args.ssl_certfile,
                 ssl_ca_certs=args.ssl_ca_certs,
                 ssl_cert_reqs=args.ssl_cert_reqs)
 
 
+if __name__ == "__main__":
+    # NOTE(simon):
+    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
+    parser = FlexibleArgumentParser(
+        description="vLLM OpenAI-Compatible RESTful API server.")
+    parser = make_arg_parser(parser)
+    args = parser.parse_args()
+    run_server(args)
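
The NOTE points at vllm/scripts.py, which this PR adds as the console-script entrypoint; it wires the same parser and `run_server` into a `vllm serve` subcommand. With this commit the server can therefore be launched either way (model name is a placeholder):

    python -m vllm.entrypoints.openai.api_server --model my-model
    vllm serve my-model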