[BugFix][Frontend] apply task instruction as system prompt in cohere v2/embed (#38362)
Signed-off-by: walterbm <walter.beller.morales@gmail.com>
This commit is contained in:
committed by
GitHub
parent
aa4eb0db78
commit
fafca38adc
@@ -57,16 +57,25 @@ def _openai_embed(
|
||||
return [item["embedding"] for item in resp.json()["data"]]
|
||||
|
||||
|
||||
def _cosine_sim(a: list[float], b: list[float]) -> float:
|
||||
va, vb = np.array(a), np.array(b)
|
||||
return float(np.dot(va, vb) / (np.linalg.norm(va) * np.linalg.norm(vb)))
|
||||
|
||||
|
||||
def test_single_text_parity(server: RemoteOpenAIServer):
|
||||
"""A single text should produce identical embeddings via both APIs."""
|
||||
"""A single text should produce equivalent embeddings via both APIs."""
|
||||
texts = ["the quick brown fox jumps over the lazy dog"]
|
||||
v2 = _cohere_embed(server, texts)
|
||||
v1 = _openai_embed(server, texts)
|
||||
np.testing.assert_allclose(v2[0], v1[0], rtol=1e-5)
|
||||
# Full-suite BF16 runs can introduce tiny numerical drift even when both
|
||||
# endpoints are functionally equivalent, so compare semantic equivalence
|
||||
# instead of exact elementwise equality.
|
||||
cos = _cosine_sim(v2[0], v1[0])
|
||||
assert cos > 0.9999, f"single-text parity failed, cosine={cos}"
|
||||
|
||||
|
||||
def test_batch_parity(server: RemoteOpenAIServer):
|
||||
"""A batch of texts should produce identical embeddings via both APIs,
|
||||
"""A batch of texts should produce equivalent embeddings via both APIs,
|
||||
in the same order."""
|
||||
texts = [
|
||||
"machine learning",
|
||||
@@ -76,8 +85,18 @@ def test_batch_parity(server: RemoteOpenAIServer):
|
||||
v2 = _cohere_embed(server, texts)
|
||||
v1 = _openai_embed(server, texts)
|
||||
assert len(v2) == len(v1) == 3
|
||||
|
||||
similarities = np.array(
|
||||
[[_cosine_sim(v2_emb, v1_emb) for v1_emb in v1] for v2_emb in v2]
|
||||
)
|
||||
for i in range(3):
|
||||
np.testing.assert_allclose(v2[i], v1[i], rtol=1e-5, err_msg=f"index {i}")
|
||||
assert int(np.argmax(similarities[i])) == i, (
|
||||
f"batch parity order mismatch at index {i}: "
|
||||
f"similarities={similarities[i].tolist()}"
|
||||
)
|
||||
assert similarities[i, i] > 0.9999, (
|
||||
f"batch parity failed at index {i}, cosine={similarities[i, i]}"
|
||||
)
|
||||
|
||||
|
||||
def test_token_count_parity(server: RemoteOpenAIServer):
|
||||
|
||||
@@ -6,8 +6,11 @@ import pytest
|
||||
|
||||
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
CohereEmbedContent,
|
||||
CohereEmbedInput,
|
||||
CohereEmbedRequest,
|
||||
)
|
||||
from vllm.entrypoints.pooling.typing import PoolingServeContext
|
||||
|
||||
|
||||
class TestResolveTruncation:
|
||||
@@ -206,3 +209,116 @@ class TestValidateInputType:
|
||||
handler = self._make_handler({"a": "", "b": ""})
|
||||
with pytest.raises(ValueError, match="Supported values: a, b"):
|
||||
handler._validate_input_type("z")
|
||||
|
||||
|
||||
class TestPreProcessCohereOnline:
|
||||
"""Unit tests for EmbedIOProcessor._pre_process_cohere_online."""
|
||||
|
||||
@staticmethod
|
||||
def _make_context(**request_kwargs) -> PoolingServeContext[CohereEmbedRequest]:
|
||||
return PoolingServeContext(
|
||||
request=CohereEmbedRequest(model="test", **request_kwargs),
|
||||
model_name="test",
|
||||
request_id="embd-test",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _make_handler():
|
||||
handler = object.__new__(EmbedIOProcessor)
|
||||
handler._validate_input_type = lambda _input_type: None
|
||||
return handler
|
||||
|
||||
def test_text_only_without_task_prefix_uses_completion_path(self):
|
||||
handler = self._make_handler()
|
||||
ctx = self._make_context(texts=["hello"])
|
||||
calls: list[tuple[str, object]] = []
|
||||
|
||||
def preprocess_completion(request, prompt_input, prompt_embeds):
|
||||
calls.append(("completion", prompt_input))
|
||||
return ["completion"]
|
||||
|
||||
handler._get_task_instruction_prefix = lambda _input_type: None
|
||||
handler._has_chat_template = lambda: False
|
||||
handler._preprocess_completion_online = preprocess_completion
|
||||
handler._batch_render_chat = lambda *_args, **_kwargs: (
|
||||
pytest.fail("text-only request should not require chat rendering")
|
||||
)
|
||||
|
||||
handler._pre_process_cohere_online(ctx)
|
||||
|
||||
assert ctx.engine_inputs == ["completion"]
|
||||
assert calls == [("completion", ["hello"])]
|
||||
|
||||
def test_text_only_falls_back_to_prefixed_completion_without_template(self):
|
||||
handler = self._make_handler()
|
||||
ctx = self._make_context(texts=["hello"], input_type="query")
|
||||
calls: list[tuple[str, object]] = []
|
||||
|
||||
def preprocess_completion(request, prompt_input, prompt_embeds):
|
||||
calls.append(("completion", prompt_input))
|
||||
return ["fallback"]
|
||||
|
||||
handler._get_task_instruction_prefix = lambda _input_type: "query: "
|
||||
handler._has_chat_template = lambda: False
|
||||
handler._batch_render_chat = lambda *_args, **_kwargs: (
|
||||
pytest.fail("chat rendering should be skipped without a template")
|
||||
)
|
||||
handler._preprocess_completion_online = preprocess_completion
|
||||
|
||||
handler._pre_process_cohere_online(ctx)
|
||||
|
||||
assert ctx.engine_inputs == ["fallback"]
|
||||
assert calls == [("completion", ["query: hello"])]
|
||||
|
||||
def test_text_only_with_template_uses_chat_path(self):
|
||||
handler = self._make_handler()
|
||||
ctx = self._make_context(texts=["hello"], input_type="query")
|
||||
calls: list[tuple[str, object]] = []
|
||||
|
||||
def batch_render_chat(
|
||||
request,
|
||||
all_messages,
|
||||
truncate_prompt_tokens,
|
||||
truncation_side,
|
||||
):
|
||||
calls.append(
|
||||
(
|
||||
"chat",
|
||||
{
|
||||
"request": request,
|
||||
"all_messages": all_messages,
|
||||
"truncate_prompt_tokens": truncate_prompt_tokens,
|
||||
"truncation_side": truncation_side,
|
||||
},
|
||||
)
|
||||
)
|
||||
return ["chat"]
|
||||
|
||||
handler._get_task_instruction_prefix = lambda _input_type: "query: "
|
||||
handler._has_chat_template = lambda: True
|
||||
handler._batch_render_chat = batch_render_chat
|
||||
handler._preprocess_completion_online = lambda *_args, **_kwargs: (
|
||||
pytest.fail("completion path should be skipped when a template exists")
|
||||
)
|
||||
|
||||
handler._pre_process_cohere_online(ctx)
|
||||
|
||||
assert ctx.engine_inputs == ["chat"]
|
||||
assert calls == [
|
||||
(
|
||||
"chat",
|
||||
{
|
||||
"request": ctx.request,
|
||||
"all_messages": [
|
||||
handler._mixed_input_to_messages(
|
||||
CohereEmbedInput(
|
||||
content=[CohereEmbedContent(type="text", text="hello")]
|
||||
),
|
||||
task_prefix="query: ",
|
||||
)
|
||||
],
|
||||
"truncate_prompt_tokens": -1,
|
||||
"truncation_side": None,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user