[Frontend] Online Pooling API (#11457)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored on 2024-12-24 17:54:30 +08:00, committed by GitHub
commit 9edca6bf8f (parent 4f074fbf53)
15 changed files with 808 additions and 156 deletions

File: vllm/entrypoints/openai/api_server.py

@@ -45,8 +45,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse,
LoadLoraAdapterRequest,
PoolingRequest, PoolingResponse,
ScoreRequest, ScoreResponse,
TokenizeRequest,
TokenizeResponse,
@@ -56,6 +59,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
@@ -284,6 +288,10 @@ def completion(request: Request) -> Optional[OpenAIServingCompletion]:
return request.app.state.openai_serving_completion
def pooling(request: Request) -> Optional[OpenAIServingPooling]:
return request.app.state.openai_serving_pooling
def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
return request.app.state.openai_serving_embedding
@@ -395,10 +403,36 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Embeddings API")
fallback_handler = pooling(raw_request)
if fallback_handler is None:
return base(raw_request).create_error_response(
message="The model does not support Embeddings API")
logger.warning(
"Embeddings API will become exclusive to embedding models "
"in a future release. To return the hidden states directly, "
"use the Pooling API (`/pooling`) instead.")
res = await fallback_handler.create_pooling(request, raw_request)
if isinstance(res, PoolingResponse):
generator = EmbeddingResponse(
id=res.id,
object=res.object,
created=res.created,
model=res.model,
data=[
EmbeddingResponseData(
index=d.index,
embedding=d.data, # type: ignore
) for d in res.data
],
usage=res.usage,
)
else:
generator = res
else:
generator = await handler.create_embedding(request, raw_request)
generator = await handler.create_embedding(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
@@ -408,6 +442,24 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never(generator)
@router.post("/pooling")
@with_cancellation
async def create_pooling(request: PoolingRequest, raw_request: Request):
handler = pooling(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Pooling API")
generator = await handler.create_pooling(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, PoolingResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/score")
@with_cancellation
async def create_score(request: ScoreRequest, raw_request: Request):
@@ -605,7 +657,7 @@ def init_app_state(
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
) if model_config.runner_type == "generate" else None
state.openai_serving_embedding = OpenAIServingEmbedding(
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,
model_config,
base_model_paths,
@@ -613,13 +665,20 @@ def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.runner_type == "pooling" else None
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embed" else None
state.openai_serving_scores = OpenAIServingScores(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger
) if (model_config.runner_type == "pooling" \
and model_config.is_cross_encoder) else None
) if model_config.task == "score" else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,
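
With the /pooling route and OpenAIServingPooling wired into init_app_state above, a pooling-runner model served by vLLM exposes the raw pooled hidden states over HTTP. A minimal client sketch, assuming a server on localhost:8000; the model name is illustrative and not part of this commit:

import requests

# Sketch: call the new Pooling API. The request body reuses the Embeddings
# API schema (PoolingCompletionRequest is an alias of EmbeddingCompletionRequest).
resp = requests.post(
    "http://localhost:8000/pooling",
    json={
        "model": "BAAI/bge-base-en-v1.5",  # illustrative model name
        "input": "vLLM is a fast and easy-to-use library for LLM inference.",
        "encoding_format": "float",
    },
)
resp.raise_for_status()

# Each item carries the pooled hidden states under "data" rather than
# an "embedding" field as in the Embeddings API.
for item in resp.json()["data"]:
    print(item["index"], len(item["data"]))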

File: vllm/entrypoints/openai/protocol.py

@@ -963,6 +963,10 @@ class EmbeddingChatRequest(OpenAIBaseModel):
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
PoolingCompletionRequest = EmbeddingCompletionRequest
PoolingChatRequest = EmbeddingChatRequest
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest]
class ScoreRequest(OpenAIBaseModel):
model: str
@@ -1058,6 +1062,21 @@ class EmbeddingResponse(OpenAIBaseModel):
usage: UsageInfo
class PoolingResponseData(OpenAIBaseModel):
index: int
object: str = "pooling"
data: Union[List[List[float]], List[float], str]
class PoolingResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"pool-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
data: List[PoolingResponseData]
usage: UsageInfo
class ScoreResponseData(OpenAIBaseModel):
index: int
object: str = "score"
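
For reference, the new PoolingResponse model above serializes to a list-style payload that mirrors EmbeddingResponse. A hand-written illustration of the shape; all values are invented for the example:

# Illustrative payload only; not produced by the code in this commit.
pooling_response_example = {
    "id": "pool-abc123",
    "object": "list",
    "created": 1735027200,
    "model": "BAAI/bge-base-en-v1.5",
    "data": [
        {
            "index": 0,
            "object": "pooling",
            # Pooled hidden states: a flat list, a nested per-token list,
            # or a base64 string, depending on the model and encoding_format.
            "data": [0.12, -0.034, 0.56],
        },
    ],
    "usage": {"prompt_tokens": 7, "total_tokens": 7},
}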

File: vllm/entrypoints/openai/run_batch.py

@@ -232,7 +232,7 @@ async def main(args):
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
) if model_config.runner_type == "pooling" else None
) if model_config.task == "embed" else None
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)

File: vllm/entrypoints/openai/serving_embedding.py

@@ -40,36 +40,6 @@ def _get_embedding(
assert_never(encoding_format)
def request_output_to_embedding_response(
final_res_batch: List[PoolingRequestOutput], request_id: str,
created_time: int, model_name: str,
encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
embedding_res = EmbeddingRequestOutput.from_base(final_res)
prompt_token_ids = final_res.prompt_token_ids
embedding = _get_embedding(embedding_res.outputs, encoding_format)
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return EmbeddingResponse(
id=request_id,
created=created_time,
model=model_name,
data=data,
usage=usage,
)
class OpenAIServingEmbedding(OpenAIServing):
def __init__(
@@ -114,7 +84,7 @@ class OpenAIServingEmbedding(OpenAIServing):
model_name = request.model
request_id = f"embd-{self._base_request_id(raw_request)}"
created_time = int(time.monotonic())
created_time = int(time.time())
truncate_prompt_tokens = None
@@ -218,9 +188,13 @@ class OpenAIServingEmbedding(OpenAIServing):
final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)
response = request_output_to_embedding_response(
final_res_batch_checked, request_id, created_time, model_name,
encoding_format)
response = self.request_output_to_embedding_response(
final_res_batch_checked,
request_id,
created_time,
model_name,
encoding_format,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
@@ -228,3 +202,40 @@ class OpenAIServingEmbedding(OpenAIServing):
return self.create_error_response(str(e))
return response
def request_output_to_embedding_response(
self,
final_res_batch: List[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: Literal["float", "base64"],
) -> EmbeddingResponse:
items: List[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
embedding_res = EmbeddingRequestOutput.from_base(final_res)
item = EmbeddingResponseData(
index=idx,
embedding=_get_embedding(embedding_res.outputs,
encoding_format),
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return EmbeddingResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)
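
One behavioral fix in this file: created_time now uses time.time() instead of time.monotonic(). time.monotonic() counts seconds from an arbitrary reference point, so it never produced a valid epoch timestamp for the created field. A quick illustration (output values depend on the machine):

import time

# time.monotonic() starts from an arbitrary point (often system boot),
# while time.time() returns seconds since the Unix epoch.
print(int(time.monotonic()))  # e.g. 52341 -- not an epoch timestamp
print(int(time.time()))       # e.g. 1735027200 -- valid for a "created" field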

File: vllm/entrypoints/openai/serving_pooling.py (new file)

@@ -0,0 +1,234 @@
import asyncio
import base64
import time
from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast
import numpy as np
from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse,
PoolingChatRequest,
PoolingRequest, PoolingResponse,
PoolingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.utils import merge_async_iterators
logger = init_logger(__name__)
def _get_data(
output: PoolingOutput,
encoding_format: Literal["float", "base64"],
) -> Union[List[float], str]:
if encoding_format == "float":
return output.data.tolist()
elif encoding_format == "base64":
# Force to use float32 for base64 encoding
# to match the OpenAI python client behavior
pooling_bytes = np.array(output.data, dtype="float32").tobytes()
return base64.b64encode(pooling_bytes).decode("utf-8")
assert_never(encoding_format)
class OpenAIServingPooling(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
*,
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
async def create_pooling(
self,
request: PoolingRequest,
raw_request: Optional[Request] = None,
) -> Union[PoolingResponse, ErrorResponse]:
"""
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = request.model
request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time())
truncate_prompt_tokens = None
if request.truncate_prompt_tokens is not None:
if request.truncate_prompt_tokens <= self.max_model_len:
truncate_prompt_tokens = request.truncate_prompt_tokens
else:
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for pooling models")
if isinstance(request, PoolingChatRequest):
(
_,
request_prompts,
engine_prompts,
) = await self._preprocess_chat(
request,
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.
chat_template_content_format,
# In pooling requests, we are not generating tokens,
# so there is no need to append extra tokens to the input
add_generation_prompt=False,
continue_final_message=False,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
else:
(request_prompts,
engine_prompts) = await self._preprocess_completion(
request,
tokenizer,
request.input,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators)
num_prompts = len(engine_prompts)
# Non-streaming response
final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)
response = self.request_output_to_pooling_response(
final_res_batch_checked,
request_id,
created_time,
model_name,
encoding_format,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
def request_output_to_pooling_response(
self,
final_res_batch: List[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
encoding_format: Literal["float", "base64"],
) -> PoolingResponse:
items: List[PoolingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
item = PoolingResponseData(
index=idx,
data=_get_data(final_res.outputs, encoding_format),
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return PoolingResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)
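
When a client requests encoding_format="base64", _get_data above serializes the pooled tensor as float32 bytes to match the OpenAI Python client's convention. The values can be recovered with numpy; a sketch, not part of the commit:

import base64
import numpy as np

def decode_pooling_data(b64_data: str) -> np.ndarray:
    """Decode one base64-encoded "data" entry back into a float32 array."""
    raw = base64.b64decode(b64_data)
    # The server serialized float32 bytes, so decode with the same dtype.
    return np.frombuffer(raw, dtype=np.float32)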

File: vllm/entrypoints/openai/serving_score.py

@@ -20,32 +20,6 @@ from vllm.utils import make_async, merge_async_iterators
logger = init_logger(__name__)
def request_output_to_score_response(
final_res_batch: List[PoolingRequestOutput], request_id: str,
created_time: int, model_name: str) -> ScoreResponse:
data: List[ScoreResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
classify_res = ScoringRequestOutput.from_base(final_res)
score_data = ScoreResponseData(index=idx,
score=classify_res.outputs.score)
data.append(score_data)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return ScoreResponse(
id=request_id,
created=created_time,
model=model_name,
data=data,
usage=usage,
)
def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str],
str]) -> List:
if isinstance(text_1, (str, dict)):
@@ -103,7 +77,7 @@ class OpenAIServingScores(OpenAIServing):
model_name = request.model
request_id = f"score-{self._base_request_id(raw_request)}"
created_time = int(time.monotonic())
created_time = int(time.time())
truncate_prompt_tokens = request.truncate_prompt_tokens
request_prompts = []
@@ -203,8 +177,12 @@ class OpenAIServingScores(OpenAIServing):
final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)
response = request_output_to_score_response(
final_res_batch_checked, request_id, created_time, model_name)
response = self.request_output_to_score_response(
final_res_batch_checked,
request_id,
created_time,
model_name,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
@@ -212,3 +190,38 @@ class OpenAIServingScores(OpenAIServing):
return self.create_error_response(str(e))
return response
def request_output_to_score_response(
self,
final_res_batch: List[PoolingRequestOutput],
request_id: str,
created_time: int,
model_name: str,
) -> ScoreResponse:
items: List[ScoreResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
classify_res = ScoringRequestOutput.from_base(final_res)
item = ScoreResponseData(
index=idx,
score=classify_res.outputs.score,
)
prompt_token_ids = final_res.prompt_token_ids
items.append(item)
num_prompt_tokens += len(prompt_token_ids)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return ScoreResponse(
id=request_id,
created=created_time,
model=model_name,
data=items,
usage=usage,
)

File: vllm/outputs.py

@@ -355,7 +355,8 @@ class PoolingRequestOutput(Generic[_O]):
pooled_data = seq_group.pooled_data
assert pooled_data is not None
output = PoolingOutput(pooled_data)
data = pooled_data.to(dtype=torch.float32, device="cpu")
output = PoolingOutput(data)
prompt_token_ids = seq_group.prompt_token_ids
finished = seq_group.is_finished()
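
The outputs.py change converts pooled_data to a float32 CPU tensor before wrapping it in PoolingOutput, so the serving layer can call .tolist() or hand the data to numpy without any device or dtype handling. A tiny sketch of the resulting invariant (the tensor is a stand-in, not from the diff):

import torch

pooled = torch.randn(768)  # stand-in for seq_group.pooled_data
data = pooled.to(dtype=torch.float32, device="cpu")
assert data.dtype == torch.float32 and data.device.type == "cpu"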