[Frontend][1/N] Improve all pooling tasks | Support FP16 Embedding Base64 (still uses FP32 by default). (#26414)
Signed-off-by: wang.yuqi <noooop@126.com> Co-authored-by: Maximilien de Bayser <maxdebayser@gmail.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -17,6 +17,7 @@ from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
EMBED_DTYPE_TO_TORCH_DTYPE,
|
||||
ErrorResponse,
|
||||
IOProcessorRequest,
|
||||
IOProcessorResponse,
|
||||
@@ -29,6 +30,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.utils import encoding_pooling_output
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.logger import init_logger
|
||||
@@ -90,6 +92,12 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
if request.embed_dtype not in EMBED_DTYPE_TO_TORCH_DTYPE:
|
||||
return self.create_error_response(
|
||||
f"embed_dtype={request.embed_dtype!r} is not supported. "
|
||||
f"Supported types: {EMBED_DTYPE_TO_TORCH_DTYPE.keys()}"
|
||||
)
|
||||
|
||||
model_name = self.models.model_name()
|
||||
|
||||
request_id = f"pool-{self._base_request_id(raw_request)}"
|
||||
@@ -235,6 +243,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
created_time,
|
||||
model_name,
|
||||
request.encoding_format,
|
||||
request.embed_dtype,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
@@ -251,6 +260,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
embed_dtype: str,
|
||||
) -> PoolingResponse:
|
||||
items: list[PoolingResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
@@ -258,7 +268,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
item = PoolingResponseData(
|
||||
index=idx,
|
||||
data=_get_data(final_res.outputs, encoding_format),
|
||||
data=encoding_pooling_output(final_res, encoding_format, embed_dtype),
|
||||
)
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
|
||||
|
||||
Reference in New Issue
Block a user