[Frontend] Chat-based Embeddings API (#9759)

Author: Cyrus Leung
Date: 2024-11-01 16:13:35 +08:00
Committed by: GitHub
Parent: d3aa2a8b2f
Commit: 06386a64dd
21 changed files with 846 additions and 408 deletions
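
The headline change: the OpenAI-compatible /v1/embeddings route can now take chat-style messages (rendered with the model's chat template) instead of only a plain string, and the chat, completion, and embedding handlers each become optional depending on the loaded model's task. A rough client-side sketch, assuming a local server on port 8000 and that the extended EmbeddingRequest accepts a messages field shaped like the Chat Completions one (the schema itself lives in protocol.py, outside this file's diff):

# Rough client-side sketch: chat-style input to the embeddings endpoint.
# Model name, address, and the exact "messages" schema are assumptions here;
# the request model is defined in vllm/entrypoints/openai/protocol.py.
import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "TIGER-Lab/VLM2Vec-Full",  # placeholder embedding-task model
        "messages": [
            {"role": "user", "content": "Represent this sentence for retrieval."},
        ],
    },
)
print(resp.json()["data"][0]["embedding"][:8])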


@@ -11,7 +11,7 @@ from argparse import Namespace
 from contextlib import asynccontextmanager
 from functools import partial
 from http import HTTPStatus
-from typing import AsyncIterator, Set
+from typing import AsyncIterator, Optional, Set
 import uvloop
 from fastapi import APIRouter, FastAPI, Request
@@ -51,7 +51,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
-from vllm.entrypoints.openai.serving_engine import BaseModelPath
+from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -248,22 +248,27 @@ def mount_metrics(app: FastAPI):
     app.routes.append(metrics_route)
-def chat(request: Request) -> OpenAIServingChat:
+def base(request: Request) -> OpenAIServing:
+    # Reuse the existing instance
+    return tokenization(request)
+def chat(request: Request) -> Optional[OpenAIServingChat]:
     return request.app.state.openai_serving_chat
-def completion(request: Request) -> OpenAIServingCompletion:
+def completion(request: Request) -> Optional[OpenAIServingCompletion]:
     return request.app.state.openai_serving_completion
+def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
+    return request.app.state.openai_serving_embedding
 def tokenization(request: Request) -> OpenAIServingTokenization:
     return request.app.state.openai_serving_tokenization
-def embedding(request: Request) -> OpenAIServingEmbedding:
-    return request.app.state.openai_serving_embedding
 def engine_client(request: Request) -> EngineClient:
     return request.app.state.engine_client
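
The getters above just read per-API serving objects off app.state; with this change chat(), completion(), and embedding() may return None, while tokenization() always exists, which is why base() reuses it for model-agnostic endpoints. A minimal standalone FastAPI sketch of the same pattern (illustrative only, not vLLM code):

# Standalone illustration of the app.state accessor pattern (not vLLM code):
# handlers are attached at startup and may be None when the loaded model's
# task does not support the corresponding API.
from typing import Optional

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()
app.state.openai_serving_chat = None           # e.g. only built when task == "generate"
app.state.openai_serving_embedding = object()  # stand-in for OpenAIServingEmbedding

def embedding(request: Request) -> Optional[object]:
    return request.app.state.openai_serving_embedding

@app.post("/v1/embeddings")
async def create_embedding(raw_request: Request):
    handler = embedding(raw_request)
    if handler is None:
        return JSONResponse(status_code=400,
                            content={"error": "model does not support embeddings"})
    return {"object": "list", "data": []}  # real handler would compute embeddings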
@@ -277,7 +282,9 @@ async def health(raw_request: Request) -> Response:
 @router.post("/tokenize")
 async def tokenize(request: TokenizeRequest, raw_request: Request):
-    generator = await tokenization(raw_request).create_tokenize(request)
+    handler = tokenization(raw_request)
+    generator = await handler.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -289,7 +296,9 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
 @router.post("/detokenize")
 async def detokenize(request: DetokenizeRequest, raw_request: Request):
-    generator = await tokenization(raw_request).create_detokenize(request)
+    handler = tokenization(raw_request)
+    generator = await handler.create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -301,7 +310,9 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
 @router.get("/v1/models")
 async def show_available_models(raw_request: Request):
-    models = await completion(raw_request).show_available_models()
+    handler = base(raw_request)
+    models = await handler.show_available_models()
     return JSONResponse(content=models.model_dump())
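
Since model listing has nothing to do with the model's task, the route now resolves its handler through base() rather than the possibly-None completion handler. A small sketch, assuming the default local address:

# Model listing no longer depends on the completion handler, so it also works
# on an embedding-only deployment. Local address is an assumption.
import requests

models = requests.get("http://localhost:8000/v1/models").json()
print([m["id"] for m in models["data"]])  # OpenAI-style list response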
@@ -314,9 +325,12 @@ async def show_version():
 @router.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
+    handler = chat(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Chat Completions API")
-    generator = await chat(raw_request).create_chat_completion(
-        request, raw_request)
+    generator = await handler.create_chat_completion(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -330,8 +344,12 @@ async def create_chat_completion(request: ChatCompletionRequest,
 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
-    generator = await completion(raw_request).create_completion(
-        request, raw_request)
+    handler = completion(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Completions API")
+    generator = await handler.create_completion(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
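
Both generation routes now guard the same way: when the served model was loaded for embedding, chat() and completion() return None and the request is answered with an ErrorResponse built by base() instead of failing inside the engine. A hedged sketch of what a client would observe (model name and address are placeholders):

# Sketch of the new failure mode: calling a generation API on a server that
# was started with an embedding-task model. Model name and address are
# placeholders; the response body carries the error message built above.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "intfloat/e5-mistral-7b-instruct",  # placeholder embedding model
        "messages": [{"role": "user", "content": "hello"}],
    },
)
print(resp.json())  # expected: "The model does not support Chat Completions API"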
@@ -343,8 +361,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
 @router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
-    generator = await embedding(raw_request).create_embedding(
-        request, raw_request)
+    handler = embedding(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Embeddings API")
+    generator = await handler.create_embedding(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
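
The same guard protects /v1/embeddings in the opposite direction: a generation-only model now gets a clear "does not support Embeddings API" message. For a model actually served for embedding, plain string input continues to work through the regular OpenAI SDK (a sketch; server address and model name are assumptions):

# For a model actually served for embedding, plain string input still goes
# through the regular OpenAI SDK; server address and model are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
out = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",  # placeholder embedding model
    input="The quick brown fox jumps over the lazy dog",
)
print(len(out.data[0].embedding))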
@@ -382,30 +404,26 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
     @router.post("/v1/load_lora_adapter")
     async def load_lora_adapter(request: LoadLoraAdapterRequest,
                                 raw_request: Request):
-        response = await chat(raw_request).load_lora_adapter(request)
-        if isinstance(response, ErrorResponse):
-            return JSONResponse(content=response.model_dump(),
-                                status_code=response.code)
-        response = await completion(raw_request).load_lora_adapter(request)
-        if isinstance(response, ErrorResponse):
-            return JSONResponse(content=response.model_dump(),
-                                status_code=response.code)
+        for route in [chat, completion, embedding]:
+            handler = route(raw_request)
+            if handler is not None:
+                response = await handler.load_lora_adapter(request)
+                if isinstance(response, ErrorResponse):
+                    return JSONResponse(content=response.model_dump(),
+                                        status_code=response.code)
         return Response(status_code=200, content=response)
     @router.post("/v1/unload_lora_adapter")
     async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
                                   raw_request: Request):
-        response = await chat(raw_request).unload_lora_adapter(request)
-        if isinstance(response, ErrorResponse):
-            return JSONResponse(content=response.model_dump(),
-                                status_code=response.code)
-        response = await completion(raw_request).unload_lora_adapter(request)
-        if isinstance(response, ErrorResponse):
-            return JSONResponse(content=response.model_dump(),
-                                status_code=response.code)
+        for route in [chat, completion, embedding]:
+            handler = route(raw_request)
+            if handler is not None:
+                response = await handler.unload_lora_adapter(request)
+                if isinstance(response, ErrorResponse):
+                    return JSONResponse(content=response.model_dump(),
+                                        status_code=response.code)
         return Response(status_code=200, content=response)
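
The duplicated chat/completion blocks collapse into a loop over whichever handlers exist, so runtime LoRA loading also reaches the embedding handler when that is the one that was constructed. A hedged request sketch (adapter name and path are placeholders, and the server must run with VLLM_ALLOW_RUNTIME_LORA_UPDATING enabled):

# Hedged sketch of a runtime LoRA load; adapter name and path are placeholders
# and the server must run with VLLM_ALLOW_RUNTIME_LORA_UPDATING enabled.
import requests

resp = requests.post(
    "http://localhost:8000/v1/load_lora_adapter",
    json={"lora_name": "my_adapter", "lora_path": "/path/to/adapter"},
)
print(resp.status_code, resp.text)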
@@ -501,7 +519,8 @@ def init_app_state(
         chat_template=args.chat_template,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
-        tool_parser=args.tool_call_parser)
+        tool_parser=args.tool_call_parser,
+    ) if model_config.task == "generate" else None
     state.openai_serving_completion = OpenAIServingCompletion(
         engine_client,
         model_config,
@@ -510,13 +529,14 @@ def init_app_state(
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-    )
+    ) if model_config.task == "generate" else None
     state.openai_serving_embedding = OpenAIServingEmbedding(
         engine_client,
         model_config,
         base_model_paths,
         request_logger=request_logger,
-    )
+        chat_template=args.chat_template,
+    ) if model_config.task == "embedding" else None
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
         model_config,
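
Handler construction in init_app_state is now gated on model_config.task: only the serving objects that make sense for the loaded model are built, the rest stay None, and the route guards above turn requests for the missing APIs into error responses. The embedding handler also receives the chat template, which is what lets it render chat-style messages. An isolated sketch of the selection (illustrative only, not vLLM code):

# Isolated sketch of the task-gated wiring (illustrative only, not vLLM code):
# only handlers matching the resolved task are constructed, the rest stay None.
from types import SimpleNamespace

def init_handlers(task: str) -> SimpleNamespace:
    return SimpleNamespace(
        openai_serving_chat="OpenAIServingChat" if task == "generate" else None,
        openai_serving_completion="OpenAIServingCompletion" if task == "generate" else None,
        openai_serving_embedding="OpenAIServingEmbedding" if task == "embedding" else None,
        openai_serving_tokenization="OpenAIServingTokenization",  # always built; backs base()
    )

print(vars(init_handlers("embedding")))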