[Feature][Frontend] add support for Cohere Embed v2 API (#37074)
Signed-off-by: walterbm <walter.beller.morales@gmail.com>
(cherry picked from commit 061980c36a)
This commit is contained in:
committed by
khluu
parent
1fe3932c8b
commit
4d22667c32
310
tests/entrypoints/pooling/embed/test_cohere_online.py
Normal file
310
tests/entrypoints/pooling/embed/test_cohere_online.py
Normal file
@@ -0,0 +1,310 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the Cohere /v2/embed API with generic (non-Cohere) models.
|
||||
|
||||
Validates that the Cohere v2 embed endpoint works correctly with standard
|
||||
embedding models, covering text embedding, embedding type conversions,
|
||||
response structure, batching, normalisation, and semantic similarity.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import struct
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
# Model dtype passed to every server fixture in this module.
DTYPE = "bfloat16"

# (model name, extra CLI args) pairs exercised by the parametrized fixtures.
# The Snowflake model additionally exercises trust_remote_code and a
# matryoshka-dimension override via --hf_overrides.
MODELS: list[tuple[str, list[str]]] = [
    ("intfloat/multilingual-e5-small", []),
    (
        "Snowflake/snowflake-arctic-embed-m-v1.5",
        [
            "--trust_remote_code",
            "--hf_overrides",
            '{"matryoshka_dimensions":[256]}',
        ],
    ),
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=MODELS, ids=lambda m: m[0])
def model_config(request):
    """Yield one (model name, extra CLI args) tuple per parametrized model."""
    return request.param


@pytest.fixture(scope="module")
def model_name(model_config):
    """The model identifier for the currently active model_config."""
    name, _extra_args = model_config
    return name
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server(model_config):
    """Launch a vLLM server in pooling mode for the current model.

    Extra per-model CLI args from model_config are appended to the
    common argument list.
    """
    name, extra_args = model_config
    base_args = [
        "--runner",
        "pooling",
        "--dtype",
        DTYPE,
        "--enforce-eager",
        "--max-model-len",
        "512",
        "--gpu-memory-utilization",
        "0.02",
    ]
    with RemoteOpenAIServer(name, base_args + extra_args) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
def _cohere_embed(
    server: RemoteOpenAIServer,
    model_name: str,
    texts: list[str] | None = None,
    images: list[str] | None = None,
    input_type: str | None = None,
    embedding_types: list[str] | None = None,
) -> dict:
    """POST to /v2/embed and return the parsed JSON body.

    Only keyword arguments that are not None are included in the request
    payload. Raises for non-2xx HTTP status codes.
    """
    body: dict = {"model": model_name}
    for key, value in (
        ("input_type", input_type),
        ("texts", texts),
        ("images", images),
        ("embedding_types", embedding_types),
    ):
        if value is not None:
            body[key] = value
    resp = requests.post(server.url_for("/v2/embed"), json=body)
    resp.raise_for_status()
    return resp.json()
|
||||
|
||||
|
||||
def _openai_embed(
    server: RemoteOpenAIServer, model_name: str, texts: list[str]
) -> dict:
    """POST to the OpenAI-compatible /v1/embeddings endpoint; return JSON."""
    payload = {
        "model": model_name,
        "input": texts,
        "encoding_format": "float",
    }
    resp = requests.post(server.url_for("/v1/embeddings"), json=payload)
    resp.raise_for_status()
    return resp.json()
|
||||
|
||||
|
||||
def _cosine_sim(a: list[float], b: list[float]) -> float:
|
||||
va, vb = np.array(a), np.array(b)
|
||||
return float(np.dot(va, vb) / (np.linalg.norm(va) * np.linalg.norm(vb)))
|
||||
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Text embedding tests
|
||||
# -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_basic_embed(server: RemoteOpenAIServer, model_name: str):
    """A single text should yield exactly one non-empty float embedding."""
    result = _cohere_embed(
        server, model_name, texts=["hello world"], embedding_types=["float"]
    )
    assert "embeddings" in result
    floats = result["embeddings"]["float"]
    assert len(floats) == 1
    assert len(floats[0]) > 0


def test_unsupported_input_type_rejected(server: RemoteOpenAIServer, model_name: str):
    """An input_type not defined in the model's prompt config should be
    rejected with a 400 error."""
    payload = {
        "model": model_name,
        "input_type": "nonexistent_type",
        "texts": ["hello world"],
        "embedding_types": ["float"],
    }
    resp = requests.post(server.url_for("/v2/embed"), json=payload)
    assert resp.status_code == 400
    assert "Unsupported input_type" in resp.json()["error"]["message"]


def test_omitted_input_type_accepted(server: RemoteOpenAIServer, model_name: str):
    """Omitting input_type should always work (no prompt prefix applied)."""
    payload = {
        "model": model_name,
        "texts": ["hello world"],
        "embedding_types": ["float"],
    }
    resp = requests.post(server.url_for("/v2/embed"), json=payload)
    assert resp.status_code == 200
    assert len(resp.json()["embeddings"]["float"]) == 1


def test_v1_v2_parity(server: RemoteOpenAIServer, model_name: str):
    """v1 (OpenAI) and v2 (Cohere) endpoints should produce the same
    float embeddings for a generic model."""
    texts = ["hello world"]
    v2 = _cohere_embed(server, model_name, texts=texts, embedding_types=["float"])
    v1 = _openai_embed(server, model_name, texts)
    cos = _cosine_sim(v2["embeddings"]["float"][0], v1["data"][0]["embedding"])
    assert cos > 0.9999, f"v1/v2 parity failed, cosine={cos}"
|
||||
|
||||
|
||||
def test_embedding_types(server: RemoteOpenAIServer, model_name: str):
    """binary/ubinary embeddings pack 8 float dimensions into one byte."""
    result = _cohere_embed(
        server,
        model_name,
        texts=["test"],
        embedding_types=["float", "binary", "ubinary"],
    )
    dim = len(result["embeddings"]["float"][0])
    assert len(result["embeddings"]["binary"][0]) == dim // 8
    assert len(result["embeddings"]["ubinary"][0]) == dim // 8


def test_response_structure(server: RemoteOpenAIServer, model_name: str):
    """The v2 response carries id, embeddings, echoed texts, and meta."""
    result = _cohere_embed(server, model_name, texts=["test"], embedding_types=["float"])
    for key in ("id", "embeddings", "texts", "meta"):
        assert key in result
    assert result["texts"] == ["test"]
    assert result["meta"]["api_version"]["version"] == "2"
    assert "billed_units" in result["meta"]
    billed = result["meta"]["billed_units"]
    assert billed["input_tokens"] > 0
    assert billed["image_tokens"] == 0


def test_batch(server: RemoteOpenAIServer, model_name: str):
    """Batched texts return one same-dimension embedding per input."""
    batch = ["apple", "banana", "cherry"]
    result = _cohere_embed(server, model_name, texts=batch, embedding_types=["float"])
    vectors = result["embeddings"]["float"]
    assert len(vectors) == 3
    dim = len(vectors[0])
    for vec in vectors:
        assert len(vec) == dim
|
||||
|
||||
|
||||
def test_l2_normalized(server: RemoteOpenAIServer, model_name: str):
    """Returned float embeddings should be unit-normalised."""
    result = _cohere_embed(
        server, model_name, texts=["hello world"], embedding_types=["float"]
    )
    vec = np.array(result["embeddings"]["float"][0])
    assert abs(float(np.linalg.norm(vec)) - 1.0) < 0.01


def test_semantic_similarity(server: RemoteOpenAIServer, model_name: str):
    """Related texts should be closer than unrelated ones."""
    result = _cohere_embed(
        server,
        model_name,
        texts=["machine learning", "deep learning", "chocolate cake recipe"],
        embedding_types=["float"],
    )
    ml_vec, dl_vec, cake_vec = result["embeddings"]["float"]
    assert _cosine_sim(ml_vec, dl_vec) > _cosine_sim(ml_vec, cake_vec)


def test_missing_input_returns_error(server: RemoteOpenAIServer, model_name: str):
    """A request with neither texts nor images is a client error."""
    resp = requests.post(server.url_for("/v2/embed"), json={"model": model_name})
    assert resp.status_code == 400


def test_base64_embedding_type(server: RemoteOpenAIServer, model_name: str):
    """base64 embeddings decode to the same values as the float variant."""
    result = _cohere_embed(
        server,
        model_name,
        texts=["test encoding"],
        embedding_types=["float", "base64"],
    )
    float_emb = result["embeddings"]["float"][0]
    raw = base64.b64decode(result["embeddings"]["base64"][0])
    # base64 payload is little-endian float32, one value per dimension.
    decoded = struct.unpack(f"<{len(float_emb)}f", raw)
    np.testing.assert_allclose(float_emb, decoded, rtol=1e-5)
|
||||
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Truncation tests
|
||||
# -----------------------------------------------------------
|
||||
|
||||
|
||||
def _cohere_embed_raw(
    server: RemoteOpenAIServer,
    body: dict,
) -> requests.Response:
    """POST *body* to /v2/embed and return the raw Response without raising,
    so callers can assert on error status codes."""
    return requests.post(server.url_for("/v2/embed"), json=body)
|
||||
|
||||
|
||||
def test_truncate_end_succeeds(server: RemoteOpenAIServer, model_name: str):
    """truncate=END should silently truncate long input."""
    resp = _cohere_embed_raw(
        server,
        {
            "model": model_name,
            "texts": [" ".join(["word"] * 2000)],
            "embedding_types": ["float"],
            "truncate": "END",
        },
    )
    assert resp.status_code == 200
    assert len(resp.json()["embeddings"]["float"]) == 1


def test_truncate_start_succeeds(server: RemoteOpenAIServer, model_name: str):
    """truncate=START should silently truncate long input from the start."""
    resp = _cohere_embed_raw(
        server,
        {
            "model": model_name,
            "texts": [" ".join(["word"] * 2000)],
            "embedding_types": ["float"],
            "truncate": "START",
        },
    )
    assert resp.status_code == 200
    assert len(resp.json()["embeddings"]["float"]) == 1


def test_truncate_none_rejects_long_input(server: RemoteOpenAIServer, model_name: str):
    """truncate=NONE should error when input exceeds model context."""
    resp = _cohere_embed_raw(
        server,
        {
            "model": model_name,
            "texts": [" ".join(["word"] * 2000)],
            "embedding_types": ["float"],
            "truncate": "NONE",
        },
    )
    assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_truncate_start_vs_end_differ(server: RemoteOpenAIServer, model_name: str):
    """START and END truncation should produce different embeddings
    when the input is long enough to actually be truncated.

    We construct input with distinct tokens at the start vs end
    so that keeping different halves produces different embeddings.
    """
    head = " ".join(f"alpha{i}" for i in range(300))
    tail = " ".join(f"omega{i}" for i in range(300))
    long_text = head + " " + tail

    # Issue the same request twice, varying only the truncation side.
    results = {}
    for mode in ("END", "START"):
        body = {
            "model": model_name,
            "texts": [long_text],
            "embedding_types": ["float"],
            "truncate": mode,
        }
        results[mode] = _cohere_embed_raw(server, body).json()

    emb_end = results["END"]["embeddings"]["float"][0]
    emb_start = results["START"]["embeddings"]["float"][0]
    cos = _cosine_sim(emb_end, emb_start)
    assert cos < 0.99, (
        f"START and END truncation should produce different embeddings "
        f"for long input, but cosine similarity was {cos}"
    )
|
||||
135
tests/entrypoints/pooling/embed/test_cohere_online_vision.py
Normal file
135
tests/entrypoints/pooling/embed/test_cohere_online_vision.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for the Cohere /v2/embed API with a multimodal model (SigLIP).
|
||||
|
||||
Validates image embedding, batching, normalisation, and embedding type
|
||||
conversions through the /v2/embed endpoint.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import struct
|
||||
import zlib
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
# SigLIP multimodal checkpoint used by every test in this module.
MODEL_NAME = "google/siglip-so400m-patch14-384"
# Model dtype passed to the server fixture.
DTYPE = "bfloat16"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server():
    """Launch a vLLM pooling server for the SigLIP multimodal model."""
    cli_args = [
        "--runner",
        "pooling",
        "--dtype",
        DTYPE,
        "--enforce-eager",
        "--max-model-len",
        "64",
        "--gpu-memory-utilization",
        "0.3",
    ]
    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
def _make_tiny_png(r: int, g: int, b: int, w: int = 2, h: int = 2) -> str:
|
||||
raw = b""
|
||||
for _ in range(h):
|
||||
raw += b"\x00" + bytes([r, g, b]) * w
|
||||
compressed = zlib.compress(raw)
|
||||
|
||||
def chunk(ctype: bytes, cdata: bytes) -> bytes:
|
||||
c = ctype + cdata
|
||||
return (
|
||||
struct.pack(">I", len(cdata))
|
||||
+ c
|
||||
+ struct.pack(">I", zlib.crc32(c) & 0xFFFFFFFF)
|
||||
)
|
||||
|
||||
ihdr = struct.pack(">IIBBBBB", w, h, 8, 2, 0, 0, 0)
|
||||
png = (
|
||||
b"\x89PNG\r\n\x1a\n"
|
||||
+ chunk(b"IHDR", ihdr)
|
||||
+ chunk(b"IDAT", compressed)
|
||||
+ chunk(b"IEND", b"")
|
||||
)
|
||||
return "data:image/png;base64," + base64.b64encode(png).decode()
|
||||
|
||||
|
||||
def _cohere_embed(
    server: RemoteOpenAIServer,
    texts: list[str] | None = None,
    images: list[str] | None = None,
    embedding_types: list[str] | None = None,
) -> dict:
    """POST the non-None fields to /v2/embed; raise on HTTP errors."""
    body: dict = {"model": MODEL_NAME}
    for key, value in (
        ("texts", texts),
        ("images", images),
        ("embedding_types", embedding_types),
    ):
        if value is not None:
            body[key] = value
    resp = requests.post(server.url_for("/v2/embed"), json=body)
    resp.raise_for_status()
    return resp.json()
|
||||
|
||||
|
||||
def test_image_embed(server: RemoteOpenAIServer):
    """A single image yields one embedding, billed as image tokens only."""
    result = _cohere_embed(
        server,
        images=[_make_tiny_png(255, 0, 0)],
        embedding_types=["float"],
    )
    assert "embeddings" in result
    floats = result["embeddings"]["float"]
    assert len(floats) == 1
    assert len(floats[0]) > 0
    billed = result["meta"]["billed_units"]
    assert billed["image_tokens"] > 0
    assert billed["input_tokens"] == 0


def test_image_batch(server: RemoteOpenAIServer):
    """Two images in one request return two embeddings."""
    result = _cohere_embed(
        server,
        images=[_make_tiny_png(255, 0, 0), _make_tiny_png(0, 0, 255)],
        embedding_types=["float"],
    )
    assert len(result["embeddings"]["float"]) == 2


def test_image_l2_normalized(server: RemoteOpenAIServer):
    """Image embeddings should come back unit-normalised."""
    result = _cohere_embed(
        server,
        images=[_make_tiny_png(0, 255, 0)],
        embedding_types=["float"],
    )
    vec = np.array(result["embeddings"]["float"][0])
    assert abs(float(np.linalg.norm(vec)) - 1.0) < 0.01


def test_image_embedding_types(server: RemoteOpenAIServer):
    """binary/ubinary image embeddings pack 8 float dims per byte."""
    result = _cohere_embed(
        server,
        images=[_make_tiny_png(128, 128, 128)],
        embedding_types=["float", "binary", "ubinary"],
    )
    dim = len(result["embeddings"]["float"][0])
    assert len(result["embeddings"]["binary"][0]) == dim // 8
    assert len(result["embeddings"]["ubinary"][0]) == dim // 8


def test_text_embed_on_multimodal(server: RemoteOpenAIServer):
    """SigLIP also supports text-only embedding via /v2/embed."""
    result = _cohere_embed(server, texts=["hello world"], embedding_types=["float"])
    assert "embeddings" in result
    floats = result["embeddings"]["float"]
    assert len(floats) == 1
    assert len(floats[0]) > 0
|
||||
102
tests/entrypoints/pooling/embed/test_cohere_openai_parity.py
Normal file
102
tests/entrypoints/pooling/embed/test_cohere_openai_parity.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Parity test between Cohere /v2/embed and OpenAI /v1/embeddings.
|
||||
|
||||
Verifies that both endpoints produce identical float embeddings when
|
||||
no prompt prefix is applied (input_type omitted for Cohere /v2/embed).
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
# Generic embedding model with no task prompts, so v1 and v2 should match.
MODEL_NAME = "BAAI/bge-base-en-v1.5"
# Model dtype passed to the server fixture.
DTYPE = "bfloat16"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server():
    """Launch a vLLM pooling server for the parity-test model."""
    cli_args = [
        "--runner",
        "pooling",
        "--dtype",
        DTYPE,
        "--enforce-eager",
        "--max-model-len",
        "512",
        "--gpu-memory-utilization",
        "0.02",
    ]
    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
def _cohere_embed(
    server: RemoteOpenAIServer,
    texts: list[str],
) -> list[list[float]]:
    """Embed *texts* via Cohere /v2/embed and return the float vectors."""
    payload = {
        "model": MODEL_NAME,
        "texts": texts,
        "embedding_types": ["float"],
    }
    resp = requests.post(server.url_for("/v2/embed"), json=payload)
    resp.raise_for_status()
    return resp.json()["embeddings"]["float"]


def _openai_embed(
    server: RemoteOpenAIServer,
    texts: list[str],
) -> list[list[float]]:
    """Embed *texts* via OpenAI /v1/embeddings and return the float vectors."""
    payload = {"model": MODEL_NAME, "input": texts, "encoding_format": "float"}
    resp = requests.post(server.url_for("/v1/embeddings"), json=payload)
    resp.raise_for_status()
    return [entry["embedding"] for entry in resp.json()["data"]]
|
||||
|
||||
|
||||
def test_single_text_parity(server: RemoteOpenAIServer):
    """A single text should produce identical embeddings via both APIs."""
    texts = ["the quick brown fox jumps over the lazy dog"]
    cohere_vecs = _cohere_embed(server, texts)
    openai_vecs = _openai_embed(server, texts)
    np.testing.assert_allclose(cohere_vecs[0], openai_vecs[0], rtol=1e-5)


def test_batch_parity(server: RemoteOpenAIServer):
    """A batch of texts should produce identical embeddings via both APIs,
    in the same order."""
    texts = [
        "machine learning",
        "deep learning",
        "natural language processing",
    ]
    cohere_vecs = _cohere_embed(server, texts)
    openai_vecs = _openai_embed(server, texts)
    assert len(cohere_vecs) == len(openai_vecs) == 3
    for i, (v2_vec, v1_vec) in enumerate(zip(cohere_vecs, openai_vecs)):
        np.testing.assert_allclose(v2_vec, v1_vec, rtol=1e-5, err_msg=f"index {i}")


def test_token_count_parity(server: RemoteOpenAIServer):
    """Both APIs should report the same prompt token count."""
    texts = ["hello world"]
    v2_resp = requests.post(
        server.url_for("/v2/embed"),
        json={
            "model": MODEL_NAME,
            "texts": texts,
            "embedding_types": ["float"],
        },
    )
    v1_resp = requests.post(
        server.url_for("/v1/embeddings"),
        json={"model": MODEL_NAME, "input": texts, "encoding_format": "float"},
    )
    v2_resp.raise_for_status()
    v1_resp.raise_for_status()
    cohere_tokens = v2_resp.json()["meta"]["billed_units"]["input_tokens"]
    openai_tokens = v1_resp.json()["usage"]["prompt_tokens"]
    assert cohere_tokens == openai_tokens
|
||||
208
tests/entrypoints/pooling/embed/test_io_processor.py
Normal file
208
tests/entrypoints/pooling/embed/test_io_processor.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for EmbedIOProcessor."""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
CohereEmbedRequest,
|
||||
)
|
||||
|
||||
|
||||
class TestResolveTruncation:
    """Unit tests for EmbedIOProcessor._resolve_cohere_truncation."""

    @staticmethod
    def _make_request(**kwargs) -> CohereEmbedRequest:
        # Minimal valid request; individual tests override via kwargs.
        base = {
            "model": "test",
            "input_type": "search_document",
            "texts": ["hello"],
        }
        return CohereEmbedRequest(**(base | kwargs))

    def test_truncate_end_default(self):
        # END is the default when truncate is omitted.
        req = self._make_request()
        tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
        assert tokens == -1
        assert side is None

    def test_truncate_end_explicit(self):
        req = self._make_request(truncate="END")
        tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
        assert tokens == -1
        assert side is None

    def test_truncate_end_with_max_tokens(self):
        # max_tokens caps the truncation length instead of -1 (model max).
        req = self._make_request(truncate="END", max_tokens=128)
        tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
        assert tokens == 128
        assert side is None

    def test_truncate_none(self):
        req = self._make_request(truncate="NONE")
        tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
        assert tokens is None
        assert side is None

    def test_truncate_none_with_max_tokens(self):
        """truncate=NONE should NOT set truncate_prompt_tokens; the
        max_tokens limit is enforced separately via _check_max_tokens."""
        req = self._make_request(truncate="NONE", max_tokens=10)
        tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
        assert tokens is None
        assert side is None

    def test_truncate_start(self):
        # START maps to left-side truncation.
        req = self._make_request(truncate="START")
        tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
        assert tokens == -1
        assert side == "left"

    def test_truncate_start_with_max_tokens(self):
        req = self._make_request(truncate="START", max_tokens=64)
        tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
        assert tokens == 64
        assert side == "left"
|
||||
|
||||
|
||||
class TestApplyStPrompt:
    """Unit tests for EmbedIOProcessor._apply_task_instruction."""

    @staticmethod
    def _make_handler(task_instructions: dict[str, str] | None):
        # Bypass __init__ so no engine/config is needed for these unit tests.
        handler = object.__new__(EmbedIOProcessor)
        handler.task_instructions = task_instructions
        return handler

    def test_no_prompts_configured(self):
        handler = self._make_handler(None)
        original = ["hello", "world"]
        # Identity check: with no prompts the input list is returned as-is.
        assert handler._apply_task_instruction(original, "query") is original

    def test_matching_input_type(self):
        handler = self._make_handler({"query": "search_query: "})
        assert handler._apply_task_instruction(["hello"], "query") == [
            "search_query: hello"
        ]

    def test_non_matching_input_type(self):
        handler = self._make_handler({"query": "search_query: "})
        original = ["hello"]
        assert handler._apply_task_instruction(original, "document") is original

    def test_multiple_texts(self):
        prefix = "Represent this sentence for searching: "
        handler = self._make_handler({"query": prefix})
        result = handler._apply_task_instruction(["a", "b", "c"], "query")
        assert result == [prefix + text for text in ("a", "b", "c")]

    def test_empty_prefix_returns_unchanged(self):
        # An empty prefix should not trigger any copying either.
        handler = self._make_handler({"passage": ""})
        original = ["hello"]
        assert handler._apply_task_instruction(original, "passage") is original
|
||||
|
||||
|
||||
class TestLoadTaskInstructions:
    """Unit tests for EmbedIOProcessor._load_task_instructions."""

    def test_no_attribute(self):
        # A config that lacks the attribute yields no instructions.
        class FakeConfig:
            pass

        assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None

    def test_with_task_instructions(self):
        class FakeConfig:
            task_instructions = {
                "retrieval.query": "Represent the query: ",
                "retrieval.passage": "",
            }

        loaded = EmbedIOProcessor._load_task_instructions(FakeConfig())
        assert loaded == {
            "retrieval.query": "Represent the query: ",
            "retrieval.passage": "",
        }

    def test_empty_dict(self):
        # An empty mapping is normalised to None.
        class FakeConfig:
            task_instructions = {}

        assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None

    def test_non_dict(self):
        # A non-mapping value is treated the same as no configuration.
        class FakeConfig:
            task_instructions = "not a dict"

        assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None
|
||||
|
||||
|
||||
class TestCheckMaxTokens:
    """Unit tests for EmbedIOProcessor._check_cohere_max_tokens."""

    @staticmethod
    def _fake_output(n_tokens: int):
        # Minimal stand-in exposing only the prompt_token_ids the check reads.
        class _Out:
            def __init__(self, n: int):
                self.prompt_token_ids = list(range(n))

        return _Out(n_tokens)

    def test_none_check_is_noop(self):
        # max_tokens=None disables the check entirely.
        EmbedIOProcessor._check_cohere_max_tokens([self._fake_output(100)], None)

    def test_within_limit(self):
        outputs = [self._fake_output(5), self._fake_output(3)]
        EmbedIOProcessor._check_cohere_max_tokens(outputs, 5)

    def test_exceeds_limit(self):
        outputs = [self._fake_output(3), self._fake_output(10)]
        with pytest.raises(ValueError, match="exceeds max_tokens=5"):
            EmbedIOProcessor._check_cohere_max_tokens(outputs, 5)

    def test_exact_limit(self):
        # The limit is inclusive: exactly max_tokens tokens is accepted.
        EmbedIOProcessor._check_cohere_max_tokens([self._fake_output(5)], 5)
|
||||
|
||||
|
||||
class TestValidateInputType:
    """Unit tests for EmbedIOProcessor._validate_input_type."""

    @staticmethod
    def _make_handler(task_instructions: dict[str, str] | None):
        # Bypass __init__ so no engine/config is needed for these unit tests.
        handler = object.__new__(EmbedIOProcessor)
        handler.task_instructions = task_instructions
        return handler

    def test_none_input_type_always_accepted(self):
        # Omitting input_type is valid with or without configured prompts.
        self._make_handler(None)._validate_input_type(None)
        self._make_handler({"query": "q: "})._validate_input_type(None)

    def test_no_prompts_rejects(self):
        handler = self._make_handler(None)
        with pytest.raises(ValueError, match="does not define any input_type"):
            handler._validate_input_type("anything")

    def test_known_type_accepted(self):
        handler = self._make_handler({"query": "q: ", "document": "d: "})
        for known in ("query", "document"):
            handler._validate_input_type(known)

    def test_unknown_type_rejected(self):
        handler = self._make_handler({"query": "q: ", "document": "d: "})
        with pytest.raises(ValueError, match="Unsupported input_type 'other'"):
            handler._validate_input_type("other")

    def test_error_lists_supported(self):
        # The error message enumerates the configured input types.
        handler = self._make_handler({"a": "", "b": ""})
        with pytest.raises(ValueError, match="Supported values: a, b"):
            handler._validate_input_type("z")
|
||||
129
tests/entrypoints/pooling/embed/test_protocol.py
Normal file
129
tests/entrypoints/pooling/embed/test_protocol.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for Cohere embed protocol: build_typed_embeddings and its
|
||||
underlying packing helpers, plus Cohere-specific serving helpers."""
|
||||
|
||||
import base64
|
||||
import struct
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
build_typed_embeddings,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_embeddings() -> list[list[float]]:
    """Two 8-dim vectors with alternating signs, shared by the packing tests."""
    row_a = [0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8]
    row_b = [-0.05, 0.15, -0.25, 0.35, -0.45, 0.55, -0.65, 0.75]
    return [row_a, row_b]
|
||||
|
||||
|
||||
class TestBuildTypedEmbeddingsFloat:
    def test_float_passthrough(self, sample_embeddings: list[list[float]]):
        # "float" keeps the vectors untouched; unrequested kinds stay unset.
        result = build_typed_embeddings(sample_embeddings, ["float"])
        assert result.float == sample_embeddings
        assert result.binary is None

    def test_empty_input(self):
        # No embeddings in -> an empty float list out.
        result = build_typed_embeddings([], ["float"])
        assert result.float == []
|
||||
|
||||
|
||||
class TestBuildTypedEmbeddingsBinary:
    def test_binary_packing(self):
        # Signs +,-,+,-,... map to bits 10101010 = 0xAA = 170;
        # the signed variant shifts by 128: 170 - 128 = 42.
        vectors = [[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]]
        result = build_typed_embeddings(vectors, ["binary"])
        assert result.binary is not None
        assert result.binary[0] == [42]

    def test_ubinary_packing(self):
        vectors = [[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]]
        result = build_typed_embeddings(vectors, ["ubinary"])
        assert result.ubinary is not None
        assert result.ubinary[0] == [170]  # 0b10101010

    def test_binary_all_positive(self):
        result = build_typed_embeddings([[0.1] * 8], ["binary"])
        assert result.binary is not None
        # all bits set -> 0xFF = 255; signed: 255 - 128 = 127
        assert result.binary[0] == [127]

    def test_binary_all_negative(self):
        result = build_typed_embeddings([[-0.1] * 8], ["binary"])
        assert result.binary is not None
        # no bits set -> 0; signed: 0 - 128 = -128
        assert result.binary[0] == [-128]

    def test_binary_dimension_is_eighth(self, sample_embeddings: list[list[float]]):
        result = build_typed_embeddings(sample_embeddings, ["binary"])
        assert result.binary is not None
        for source_vec, packed_vec in zip(sample_embeddings, result.binary):
            assert len(packed_vec) == len(source_vec) // 8

    def test_zero_treated_as_positive(self):
        result = build_typed_embeddings([[0.0] * 8], ["binary"])
        assert result.binary is not None
        # 0.0 >= 0 holds, so every bit is 1 => 127 (signed)
        assert result.binary[0] == [127]

    def test_non_multiple_of_8_raises(self):
        with pytest.raises(ValueError, match="multiple of 8"):
            build_typed_embeddings([[0.1] * 7], ["binary"])

    def test_ubinary_non_multiple_of_8_raises(self):
        with pytest.raises(ValueError, match="multiple of 8"):
            build_typed_embeddings([[0.1] * 10], ["ubinary"])
|
||||
|
||||
|
||||
class TestBuildTypedEmbeddingsBase64:
    def test_base64_roundtrip(self, sample_embeddings: list[list[float]]):
        # Each base64 string is the little-endian float32 packing of a vector.
        result = build_typed_embeddings(sample_embeddings, ["base64"])
        assert result.base64 is not None
        assert len(result.base64) == 2

        for source_vec, encoded in zip(sample_embeddings, result.base64):
            raw = base64.b64decode(encoded)
            decoded = struct.unpack(f"<{len(source_vec)}f", raw)
            np.testing.assert_allclose(source_vec, decoded, rtol=1e-5)

    def test_base64_byte_length(self):
        result = build_typed_embeddings([[0.1, 0.2, 0.3]], ["base64"])
        assert result.base64 is not None
        raw = base64.b64decode(result.base64[0])
        assert len(raw) == 3 * 4  # 3 floats * 4 bytes each
|
||||
|
||||
|
||||
class TestBuildTypedEmbeddingsMultiple:
    def test_all_types_at_once(self, sample_embeddings: list[list[float]]):
        result = build_typed_embeddings(
            sample_embeddings,
            ["float", "binary", "ubinary", "base64"],
        )
        for kind in ("float", "binary", "ubinary", "base64"):
            assert getattr(result, kind) is not None

    def test_subset_types(self, sample_embeddings: list[list[float]]):
        result = build_typed_embeddings(sample_embeddings, ["float", "binary"])
        assert result.float is not None
        assert result.binary is not None
        assert result.ubinary is None
        assert result.base64 is None

    def test_unknown_type_ignored(self, sample_embeddings: list[list[float]]):
        # Unrecognised type names are silently skipped, not errors.
        result = build_typed_embeddings(sample_embeddings, ["float", "unknown_type"])
        assert result.float is not None
|
||||
Reference in New Issue
Block a user