[mypy] Enable following imports for entrypoints (#7248)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Fei <dfdfcai4@gmail.com>
Author:    Cyrus Leung
Date:      2024-08-21 14:28:21 +08:00
Committer: GitHub
Parent:    4506641212
Commit:    baaedfdb2d

26 changed files with 480 additions and 320 deletions

vllm/entrypoints/openai/api_server.py

@@ -15,6 +15,7 @@ from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from starlette.routing import Mount
+from typing_extensions import assert_never
 
 import vllm.envs as envs
 from vllm.config import ModelConfig
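The newly imported `assert_never` (from `typing_extensions`; also available as `typing.assert_never` on Python 3.11+) is what makes the rewritten handlers below exhaustiveness-checked: its parameter is typed as `Never`, so mypy rejects the call unless every member of a union has already been narrowed away, and at runtime it raises if it is ever reached. A minimal sketch of the pattern, using a hypothetical two-member union:

    from typing import Union
    from typing_extensions import assert_never

    def describe(value: Union[int, str]) -> str:
        if isinstance(value, int):
            return f"int: {value}"   # value narrowed to int
        elif isinstance(value, str):
            return f"str: {value}"   # value narrowed to str
        # mypy accepts this call only if value has type Never here;
        # adding a new union member turns it into a type error
        assert_never(value)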
@@ -29,14 +30,16 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ChatCompletionResponse,
                                               CompletionRequest,
+                                              CompletionResponse,
                                               DetokenizeRequest,
                                               DetokenizeResponse,
-                                              EmbeddingRequest, ErrorResponse,
+                                              EmbeddingRequest,
+                                              EmbeddingResponse, ErrorResponse,
                                               TokenizeRequest,
                                               TokenizeResponse)
-# yapf: enable
 from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
 from vllm.entrypoints.openai.rpc.server import run_rpc_server
+# yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
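The `# yapf: enable` pragma moves below the two rpc imports, presumably so they join the hand-formatted region opened by a `# yapf: disable` earlier in the import block (outside this hunk). The pragma pair works like this illustrative snippet, which is not part of the diff:

    # yapf: disable
    from os.path import (join,
                         split)  # hand-aligned; yapf leaves this region alone
    # yapf: enable
    from os.path import basename  # formatted normally again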
@@ -90,7 +93,8 @@ async def lifespan(app: FastAPI):
 
 @asynccontextmanager
-async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
+async def build_async_engine_client(
+        args: Namespace) -> AsyncIterator[AsyncEngineClient]:
     # Context manager to handle async_engine_client lifecycle
     # Ensures everything is shutdown and cleaned up on error/exit
     global engine_args
 
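Annotating `args: Namespace` matters once mypy follows imports into this module: an unannotated parameter leaves the whole signature untyped. Note that typeshed gives `argparse.Namespace` a `__getattr__` returning `Any`, so attribute access remains unchecked; the annotation mainly completes the signature. A small self-contained illustration:

    from argparse import ArgumentParser, Namespace

    def report(args: Namespace) -> None:
        # args.port is still typed Any (via Namespace.__getattr__), but
        # the function itself is now fully annotated for mypy
        print(f"serving on port {args.port}")

    parser = ArgumentParser()
    parser.add_argument("--port", type=int, default=8000)
    report(parser.parse_args([]))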
@@ -142,12 +146,15 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
         logger.info("Started engine process with PID %d",
                     rpc_server_process.pid)
 
         # Build RPCClient, which conforms to AsyncEngineClient Protocol.
-        async_engine_client = AsyncEngineRPCClient(rpc_path)
+        # NOTE: Actually, this is not true yet. We still need to support
+        # embedding models via RPC (see TODO above)
+        rpc_client = AsyncEngineRPCClient(rpc_path)
+        async_engine_client = rpc_client  # type: ignore
+
         try:
             while True:
                 try:
-                    await async_engine_client.setup()
+                    await rpc_client.setup()
                     break
                 except TimeoutError as e:
                     if not rpc_server_process.is_alive():
@@ -161,7 +168,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
            rpc_server_process.terminate()
 
            # Close all open connections to the backend
-           async_engine_client.close()
+           rpc_client.close()
 
            # Wait for server process to join
            rpc_server_process.join()
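Splitting the client into `rpc_client` (concretely typed `AsyncEngineRPCClient`) and `async_engine_client` keeps lifecycle calls such as `setup()` and `close()` on the concrete type, while the rest of the server programs against the `AsyncEngineClient` protocol; the `# type: ignore` acknowledges that the class does not fully satisfy the protocol yet, per the NOTE about embedding models. A runnable sketch of the same two-variable pattern, with hypothetical names standing in for the vLLM types:

    import asyncio
    from typing import Protocol

    class EngineClient(Protocol):
        # structural interface, analogous to AsyncEngineClient
        async def generate(self, prompt: str) -> str: ...

    class RPCClient:
        # concrete client with lifecycle methods beyond the protocol
        async def setup(self) -> None: ...
        def close(self) -> None: ...
        async def generate(self, prompt: str) -> str:
            return prompt.upper()

    async def main() -> None:
        rpc_client = RPCClient()
        client: EngineClient = rpc_client   # checked once the protocol is met
        await rpc_client.setup()            # lifecycle stays on the concrete type
        print(await client.generate("hi"))
        rpc_client.close()

    asyncio.run(main())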
@@ -216,10 +223,11 @@ async def tokenize(request: TokenizeRequest):
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
-    else:
-        assert isinstance(generator, TokenizeResponse)
+    elif isinstance(generator, TokenizeResponse):
         return JSONResponse(content=generator.model_dump())
 
+    assert_never(generator)
+
 
 @router.post("/detokenize")
 async def detokenize(request: DetokenizeRequest):
@@ -227,10 +235,11 @@ async def detokenize(request: DetokenizeRequest):
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
-    else:
-        assert isinstance(generator, DetokenizeResponse)
+    elif isinstance(generator, DetokenizeResponse):
         return JSONResponse(content=generator.model_dump())
 
+    assert_never(generator)
+
 
 @router.get("/v1/models")
 async def show_available_models():
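Both tokenize handlers swap a runtime-only check for static narrowing: the old `else: assert isinstance(...)` could only fail when the server was running, whereas `elif isinstance(...)` plus a trailing `assert_never` lets mypy prove every union member is handled. A condensed sketch of the new handler shape, reusing the names from the diff:

    from typing import Union
    from fastapi.responses import JSONResponse
    from typing_extensions import assert_never
    from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                                  TokenizeResponse)

    def to_json(generator: Union[ErrorResponse, TokenizeResponse]) -> JSONResponse:
        if isinstance(generator, ErrorResponse):
            return JSONResponse(content=generator.model_dump(),
                                status_code=generator.code)
        elif isinstance(generator, TokenizeResponse):
            return JSONResponse(content=generator.model_dump())
        # if a new response type joins the union, this call becomes
        # a mypy error until it is handled above
        assert_never(generator)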
@@ -252,13 +261,11 @@ async def create_chat_completion(request: ChatCompletionRequest,
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
-    if request.stream:
-        return StreamingResponse(content=generator,
-                                 media_type="text/event-stream")
-    else:
-        assert isinstance(generator, ChatCompletionResponse)
+    elif isinstance(generator, ChatCompletionResponse):
         return JSONResponse(content=generator.model_dump())
 
+    return StreamingResponse(content=generator, media_type="text/event-stream")
+
 
 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
@@ -267,12 +274,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
-    if request.stream:
-        return StreamingResponse(content=generator,
-                                 media_type="text/event-stream")
-    else:
+    elif isinstance(generator, CompletionResponse):
         return JSONResponse(content=generator.model_dump())
 
+    return StreamingResponse(content=generator, media_type="text/event-stream")
+
 
 @router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
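In the two streaming endpoints the control flow inverts: instead of trusting `request.stream` to imply a generator, the handler narrows the concrete response classes first, and the final `return StreamingResponse(...)` receives whatever remains of the union. Assuming the serving layer annotates its result as `ErrorResponse | CompletionResponse | AsyncGenerator[str, None]`, mypy can verify that the leftover type is exactly the async generator. A schematic with a simplified signature:

    from typing import AsyncGenerator, Union
    from fastapi.responses import JSONResponse, StreamingResponse
    from vllm.entrypoints.openai.protocol import (CompletionResponse,
                                                  ErrorResponse)

    def to_http(generator: Union[ErrorResponse, CompletionResponse,
                                 AsyncGenerator[str, None]]):
        if isinstance(generator, ErrorResponse):
            return JSONResponse(content=generator.model_dump(),
                                status_code=generator.code)
        elif isinstance(generator, CompletionResponse):
            return JSONResponse(content=generator.model_dump())
        # remaining static type: AsyncGenerator[str, None], which is what
        # StreamingResponse consumes for server-sent events
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")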
@@ -281,9 +287,11 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
-    else:
+    elif isinstance(generator, EmbeddingResponse):
         return JSONResponse(content=generator.model_dump())
 
+    assert_never(generator)
+
 
 def build_app(args: Namespace) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
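Context for the commit title: enabling following imports means mypy analyzes the `vllm` modules these entrypoints import instead of treating them as `Any`, which is what surfaced the missing annotations and non-exhaustive branches fixed above. The relevant mypy knob, shown as an illustrative config fragment rather than vLLM's actual CI setup:

    # mypy.ini (illustrative)
    [mypy]
    follow_imports = normal   # follow and type-check imported modules

    # the previous, looser behavior corresponds to:
    #   follow_imports = skip   # imported modules become Any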