[Frontend][1/N] Improve all pooling task | Support FP16 Embedding Base64 (Still uses fp32 by default). (#26414)

Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Maximilien de Bayser <maxdebayser@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Author: wang.yuqi
Date: 2025-10-14 03:06:43 +08:00
Committed by: GitHub
Parent: 89342ce4c0
Commit: d2a7938582
8 changed files with 312 additions and 30 deletions
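From the client side the change is opt-in: float responses keep returning fp32, while a base64 response can now carry a narrower dtype selected through the new `embed_dtype` request field. Below is a minimal sketch of a request against the OpenAI-compatible `/pooling` endpoint, assuming a server on localhost:8000 and a placeholder model name; the field names follow the embeddings-style request schema, and `"float16"` is assumed to be one of the accepted `embed_dtype` keys.

```python
import base64

import numpy as np
import requests

# Hypothetical client call; the model name and port are placeholders.
resp = requests.post(
    "http://localhost:8000/pooling",
    json={
        "model": "intfloat/e5-small",
        "input": "vLLM pooling example",
        "encoding_format": "base64",  # base64 is where embed_dtype matters
        "embed_dtype": "float16",     # new field introduced by this commit
    },
)
resp.raise_for_status()

# With base64 encoding, each item's "data" is the raw buffer of the
# requested dtype, so it can be decoded directly with numpy.
payload = resp.json()["data"][0]["data"]
embedding = np.frombuffer(base64.b64decode(payload), dtype=np.float16)
print(embedding.shape, embedding.dtype)
```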


@@ -17,6 +17,7 @@ from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
+    EMBED_DTYPE_TO_TORCH_DTYPE,
     ErrorResponse,
     IOProcessorRequest,
     IOProcessorResponse,
@@ -29,6 +30,7 @@ from vllm.entrypoints.openai.protocol import (
 )
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.openai.utils import encoding_pooling_output
 from vllm.entrypoints.renderer import RenderConfig
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
@@ -90,6 +92,12 @@ class OpenAIServingPooling(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret
 
+        if request.embed_dtype not in EMBED_DTYPE_TO_TORCH_DTYPE:
+            return self.create_error_response(
+                f"embed_dtype={request.embed_dtype!r} is not supported. "
+                f"Supported types: {EMBED_DTYPE_TO_TORCH_DTYPE.keys()}"
+            )
+
         model_name = self.models.model_name()
         request_id = f"pool-{self._base_request_id(raw_request)}"
@@ -235,6 +243,7 @@ class OpenAIServingPooling(OpenAIServing):
                 created_time,
                 model_name,
                 request.encoding_format,
+                request.embed_dtype,
             )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
@@ -251,6 +260,7 @@ class OpenAIServingPooling(OpenAIServing):
         created_time: int,
         model_name: str,
         encoding_format: Literal["float", "base64"],
+        embed_dtype: str,
     ) -> PoolingResponse:
         items: list[PoolingResponseData] = []
         num_prompt_tokens = 0
@@ -258,7 +268,7 @@ class OpenAIServingPooling(OpenAIServing):
         for idx, final_res in enumerate(final_res_batch):
             item = PoolingResponseData(
                 index=idx,
-                data=_get_data(final_res.outputs, encoding_format),
+                data=encoding_pooling_output(final_res, encoding_format, embed_dtype),
             )
 
             prompt_token_ids = final_res.prompt_token_ids
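Per-item serialization now goes through `encoding_pooling_output` from `vllm.entrypoints.openai.utils`, which receives the whole request output plus the requested format and dtype; its body lives in one of the other changed files. A rough, hypothetical sketch of the behavior it needs, written against a bare tensor rather than the full request output for brevity:

```python
import base64
from typing import Union

import torch

# Hypothetical stand-in for encoding_pooling_output; `output` plays the
# role of the pooled tensor inside final_res.outputs.
def encode_pooling_tensor(
    output: torch.Tensor,
    encoding_format: str,
    embed_dtype: str,
) -> Union[list[float], str]:
    # Assumed mapping, mirroring EMBED_DTYPE_TO_TORCH_DTYPE in protocol.py.
    dtypes = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }

    if encoding_format == "float":
        # The float path keeps the fp32 default and ignores embed_dtype.
        return output.float().tolist()

    # base64 path: cast to the requested dtype and ship the raw buffer.
    target = dtypes[embed_dtype]
    buf = output.to(target).flatten().contiguous().cpu()
    raw = buf.view(torch.uint8).numpy().tobytes()
    return base64.b64encode(raw).decode("utf-8")
```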