[Frontend][3/N] Improve all pooling task | Support binary embedding response (#27066)

Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
wang.yuqi
2025-10-22 18:38:57 +08:00
committed by GitHub
parent a4c29e6e82
commit 1f633b8632
12 changed files with 691 additions and 230 deletions

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import json
import numpy as np
import openai
@@ -15,11 +16,17 @@ from tests.models.language.pooling.embed_utils import run_embedding_correctness_
from tests.models.utils import check_embeddings_close
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.openai.protocol import (
EMBED_DTYPE_TO_TORCH_DTYPE,
EmbeddingResponse,
PoolingResponse,
)
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils.serial_utils import (
EMBED_DTYPE_TO_TORCH_DTYPE,
ENDIANNESS,
MetadataItem,
binary2tensor,
decode_pooling_output,
)
MODEL_NAME = "intfloat/multilingual-e5-small"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
@@ -250,8 +257,8 @@ async def test_batch_base64_embedding(
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embed_dtype(
hf_model, server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
async def test_base64_embed_dtype_and_endianness(
server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
input_texts = [
"The best thing about vLLM is that it supports many different models",
@@ -262,44 +269,86 @@ async def test_base64_embed_dtype(
)
float_data = [d.embedding for d in responses_float.data]
for embed_dtype, torch_dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
responses_base64 = requests.post(
server.url_for("/v1/embeddings"),
json={
"model": model_name,
"input": input_texts,
"encoding_format": "base64",
"embed_dtype": embed_dtype,
},
)
base64_data = []
for data in responses_base64.json()["data"]:
base64_data.append(
torch.frombuffer(base64.b64decode(data["embedding"]), dtype=torch_dtype)
.to(torch.float32)
.tolist()
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
for endianness in ENDIANNESS:
responses_base64 = requests.post(
server.url_for("/v1/embeddings"),
json={
"model": model_name,
"input": input_texts,
"encoding_format": "base64",
"embed_dtype": embed_dtype,
"endianness": endianness,
},
)
check_embeddings_close(
embeddings_0_lst=float_data,
embeddings_1_lst=base64_data,
name_0="float_data",
name_1="base64_data",
tol=1e-2,
)
base64_data = []
for data in responses_base64.json()["data"]:
binary = base64.b64decode(data["embedding"])
tensor = binary2tensor(binary, (-1,), embed_dtype, endianness)
base64_data.append(tensor.to(torch.float32).tolist())
check_embeddings_close(
embeddings_0_lst=float_data,
embeddings_1_lst=base64_data,
name_0="float_data",
name_1="base64_data",
tol=1e-2,
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embed_dtype_not_supported(
hf_model, server: RemoteOpenAIServer, model_name: str
async def test_bytes_embed_dtype_and_endianness(
server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
input_texts = [
"The best thing about vLLM is that it supports many different models",
]
bad_embed_dtype = "bad_embed_dtype"
responses_float = await client.embeddings.create(
input=input_texts, model=model_name, encoding_format="float"
)
float_data = [d.embedding for d in responses_float.data]
for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()):
for endianness in ENDIANNESS:
responses_bytes = requests.post(
server.url_for("/v1/embeddings"),
json={
"model": model_name,
"input": input_texts,
"encoding_format": "bytes",
"embed_dtype": embed_dtype,
"endianness": endianness,
},
)
metadata = json.loads(responses_bytes.headers["metadata"])
body = responses_bytes.content
items = [MetadataItem(**x) for x in metadata["data"]]
bytes_data = decode_pooling_output(items=items, body=body)
bytes_data = [x.to(torch.float32).tolist() for x in bytes_data]
check_embeddings_close(
embeddings_0_lst=float_data,
embeddings_1_lst=bytes_data,
name_0="float_data",
name_1="bytes_data",
tol=1e-2,
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("param_name", ["encoding_format", "embed_dtype", "endianness"])
async def test_params_not_supported(
server: RemoteOpenAIServer, model_name: str, param_name: str
):
input_texts = [
"The best thing about vLLM is that it supports many different models",
]
responses_base64 = requests.post(
server.url_for("/v1/embeddings"),
@@ -307,14 +356,13 @@ async def test_base64_embed_dtype_not_supported(
"model": model_name,
"input": input_texts,
"encoding_format": "base64",
"embed_dtype": bad_embed_dtype,
param_name: f"bad_{param_name}",
},
)
assert responses_base64.status_code == 400
assert responses_base64.json()["error"]["message"].startswith(
f"embed_dtype={bad_embed_dtype!r} is not supported."
)
assert "literal_error" in responses_base64.json()["error"]["message"]
assert f"bad_{param_name}" in responses_base64.json()["error"]["message"]
@pytest.mark.asyncio