[Frontend] Online Pooling API (#11457)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-24 17:54:30 +08:00
committed by GitHub
parent 4f074fbf53
commit 9edca6bf8f
15 changed files with 808 additions and 156 deletions

View File

@@ -45,8 +45,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse,
LoadLoraAdapterRequest,
PoolingRequest, PoolingResponse,
ScoreRequest, ScoreResponse,
TokenizeRequest,
TokenizeResponse,
@@ -56,6 +59,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
@@ -284,6 +288,10 @@ def completion(request: Request) -> Optional[OpenAIServingCompletion]:
return request.app.state.openai_serving_completion
def pooling(request: Request) -> Optional[OpenAIServingPooling]:
    """Fetch the pooling handler registered on the FastAPI app state.

    Returns None when the running model does not expose the Pooling API
    (the handler is only installed for pooling-runner models).
    """
    app_state = request.app.state
    return app_state.openai_serving_pooling
def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
    """Fetch the embedding handler registered on the FastAPI app state.

    Returns None when the running model was not started with the
    embedding task, in which case callers may fall back to pooling.
    """
    app_state = request.app.state
    return app_state.openai_serving_embedding
@@ -395,10 +403,36 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Embeddings API")
fallback_handler = pooling(raw_request)
if fallback_handler is None:
return base(raw_request).create_error_response(
message="The model does not support Embeddings API")
logger.warning(
"Embeddings API will become exclusive to embedding models "
"in a future release. To return the hidden states directly, "
"use the Pooling API (`/pooling`) instead.")
res = await fallback_handler.create_pooling(request, raw_request)
if isinstance(res, PoolingResponse):
generator = EmbeddingResponse(
id=res.id,
object=res.object,
created=res.created,
model=res.model,
data=[
EmbeddingResponseData(
index=d.index,
embedding=d.data, # type: ignore
) for d in res.data
],
usage=res.usage,
)
else:
generator = res
else:
generator = await handler.create_embedding(request, raw_request)
generator = await handler.create_embedding(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
@@ -408,6 +442,24 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never(generator)
@router.post("/pooling")
@with_cancellation
async def create_pooling(request: PoolingRequest, raw_request: Request):
    """Serve the Pooling API: run the pooling model on the request input.

    Responds with 4xx/5xx JSON when the handler reports an error, with the
    serialized PoolingResponse on success, and is unreachable otherwise.
    """
    handler = pooling(raw_request)
    # Guard clause: the handler is only registered for pooling-runner models.
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Pooling API")

    result = await handler.create_pooling(request, raw_request)
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, PoolingResponse):
        return JSONResponse(content=result.model_dump())

    # Exhaustiveness check: any other type is a programming error.
    assert_never(result)
@router.post("/score")
@with_cancellation
async def create_score(request: ScoreRequest, raw_request: Request):
@@ -605,7 +657,7 @@ def init_app_state(
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
) if model_config.runner_type == "generate" else None
state.openai_serving_embedding = OpenAIServingEmbedding(
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,
model_config,
base_model_paths,
@@ -613,13 +665,20 @@ def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.runner_type == "pooling" else None
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embed" else None
state.openai_serving_scores = OpenAIServingScores(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger
) if (model_config.runner_type == "pooling" \
and model_config.is_cross_encoder) else None
) if model_config.task == "score" else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,