[Feature][Frontend] add support for Cohere Embed v2 API (#37074)
Signed-off-by: walterbm <walter.beller.morales@gmail.com>
(cherry picked from commit 061980c36a)
This commit is contained in:
committed by
khluu
parent
1fe3932c8b
commit
4d22667c32
208
tests/entrypoints/pooling/embed/test_io_processor.py
Normal file
208
tests/entrypoints/pooling/embed/test_io_processor.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for EmbedIOProcessor."""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
CohereEmbedRequest,
|
||||
)
|
||||
|
||||
|
||||
class TestResolveTruncation:
    """Unit tests for EmbedIOProcessor._resolve_cohere_truncation."""

    @staticmethod
    def _make_request(**overrides) -> CohereEmbedRequest:
        """Build a minimal valid Cohere embed request, applying overrides."""
        fields = {
            "model": "test",
            "input_type": "search_document",
            "texts": ["hello"],
        }
        fields.update(overrides)
        return CohereEmbedRequest(**fields)

    def test_truncate_end_default(self):
        # Omitting `truncate` resolves identically to an explicit END.
        result = EmbedIOProcessor._resolve_cohere_truncation(self._make_request())
        assert result == (-1, None)

    def test_truncate_end_explicit(self):
        result = EmbedIOProcessor._resolve_cohere_truncation(
            self._make_request(truncate="END")
        )
        assert result == (-1, None)

    def test_truncate_end_with_max_tokens(self):
        # When max_tokens is given, it replaces the -1 sentinel as the cap.
        result = EmbedIOProcessor._resolve_cohere_truncation(
            self._make_request(truncate="END", max_tokens=128)
        )
        assert result == (128, None)

    def test_truncate_none(self):
        result = EmbedIOProcessor._resolve_cohere_truncation(
            self._make_request(truncate="NONE")
        )
        assert result == (None, None)

    def test_truncate_none_with_max_tokens(self):
        """truncate=NONE should NOT set truncate_prompt_tokens; the
        max_tokens limit is enforced separately via _check_max_tokens."""
        result = EmbedIOProcessor._resolve_cohere_truncation(
            self._make_request(truncate="NONE", max_tokens=10)
        )
        assert result == (None, None)

    def test_truncate_start(self):
        # START maps to left-side truncation.
        result = EmbedIOProcessor._resolve_cohere_truncation(
            self._make_request(truncate="START")
        )
        assert result == (-1, "left")

    def test_truncate_start_with_max_tokens(self):
        result = EmbedIOProcessor._resolve_cohere_truncation(
            self._make_request(truncate="START", max_tokens=64)
        )
        assert result == (64, "left")
|
||||
|
||||
|
||||
class TestApplyTaskInstruction:
    """Unit tests for EmbedIOProcessor._apply_task_instruction.

    NOTE(review): renamed from ``TestApplyStPrompt`` — the old name was
    stale; the method under test is ``_apply_task_instruction``, as the
    original docstring already stated. The ``Test`` prefix is kept so
    pytest discovery is unaffected.
    """

    @staticmethod
    def _make_handler(task_instructions: dict[str, str] | None):
        """Create an EmbedIOProcessor without running __init__ (and any
        model loading it would do), then set only the attribute the
        method under test reads."""
        handler = object.__new__(EmbedIOProcessor)
        handler.task_instructions = task_instructions
        return handler

    def test_no_prompts_configured(self):
        # With no instructions configured the input list is returned
        # unchanged — identity, not a copy.
        handler = self._make_handler(None)
        texts = ["hello", "world"]
        assert handler._apply_task_instruction(texts, "query") is texts

    def test_matching_input_type(self):
        # A configured prefix for the input_type is prepended to each text.
        handler = self._make_handler({"query": "search_query: "})
        result = handler._apply_task_instruction(["hello"], "query")
        assert result == ["search_query: hello"]

    def test_non_matching_input_type(self):
        # An input_type with no configured prefix leaves the list untouched.
        handler = self._make_handler({"query": "search_query: "})
        texts = ["hello"]
        assert handler._apply_task_instruction(texts, "document") is texts

    def test_multiple_texts(self):
        handler = self._make_handler(
            {"query": "Represent this sentence for searching: "}
        )
        result = handler._apply_task_instruction(["a", "b", "c"], "query")
        assert result == [
            "Represent this sentence for searching: a",
            "Represent this sentence for searching: b",
            "Represent this sentence for searching: c",
        ]

    def test_empty_prefix_returns_unchanged(self):
        # An empty-string prefix is treated as "no instruction".
        handler = self._make_handler({"passage": ""})
        texts = ["hello"]
        assert handler._apply_task_instruction(texts, "passage") is texts
|
||||
|
||||
|
||||
class TestLoadTaskInstructions:
    """Unit tests for EmbedIOProcessor._load_task_instructions."""

    @staticmethod
    def _config_with(**attrs):
        """Build a minimal stand-in config object carrying only `attrs`."""
        return type("FakeConfig", (), attrs)()

    def test_no_attribute(self):
        # A config lacking `task_instructions` entirely yields None.
        cfg = self._config_with()
        assert EmbedIOProcessor._load_task_instructions(cfg) is None

    def test_with_task_instructions(self):
        instructions = {
            "retrieval.query": "Represent the query: ",
            "retrieval.passage": "",
        }
        cfg = self._config_with(task_instructions=instructions)
        assert EmbedIOProcessor._load_task_instructions(cfg) == instructions

    def test_empty_dict(self):
        # An empty mapping is normalized to None.
        cfg = self._config_with(task_instructions={})
        assert EmbedIOProcessor._load_task_instructions(cfg) is None

    def test_non_dict(self):
        # A non-dict value is ignored rather than raising.
        cfg = self._config_with(task_instructions="not a dict")
        assert EmbedIOProcessor._load_task_instructions(cfg) is None
|
||||
|
||||
|
||||
class TestCheckMaxTokens:
    """Unit tests for EmbedIOProcessor._check_cohere_max_tokens."""

    @staticmethod
    def _fake_output(n_tokens: int):
        """Return an object exposing only `prompt_token_ids` of the
        requested length — the single attribute the check reads."""
        return type("_Out", (), {"prompt_token_ids": list(range(n_tokens))})()

    def test_none_check_is_noop(self):
        # max_tokens=None disables the check entirely, so any length passes.
        EmbedIOProcessor._check_cohere_max_tokens([self._fake_output(100)], None)

    def test_within_limit(self):
        outputs = [self._fake_output(5), self._fake_output(3)]
        EmbedIOProcessor._check_cohere_max_tokens(outputs, 5)

    def test_exceeds_limit(self):
        # Any single output over the limit triggers the error.
        outputs = [self._fake_output(3), self._fake_output(10)]
        with pytest.raises(ValueError, match="exceeds max_tokens=5"):
            EmbedIOProcessor._check_cohere_max_tokens(outputs, 5)

    def test_exact_limit(self):
        # The limit is inclusive: exactly max_tokens tokens is allowed.
        EmbedIOProcessor._check_cohere_max_tokens([self._fake_output(5)], 5)
|
||||
|
||||
|
||||
class TestValidateInputType:
    """Unit tests for EmbedIOProcessor._validate_input_type."""

    @staticmethod
    def _make_handler(task_instructions: dict[str, str] | None):
        """Create an EmbedIOProcessor without invoking __init__, then set
        only the attribute the validator consults."""
        instance = object.__new__(EmbedIOProcessor)
        instance.task_instructions = task_instructions
        return instance

    def test_none_input_type_always_accepted(self):
        # None must pass whether or not any instructions are configured.
        for instructions in (None, {"query": "q: "}):
            self._make_handler(instructions)._validate_input_type(None)

    def test_no_prompts_rejects(self):
        # With no instructions configured, every non-None value is rejected.
        bare = self._make_handler(None)
        with pytest.raises(ValueError, match="does not define any input_type"):
            bare._validate_input_type("anything")

    def test_known_type_accepted(self):
        handler = self._make_handler({"query": "q: ", "document": "d: "})
        for known in ("query", "document"):
            handler._validate_input_type(known)

    def test_unknown_type_rejected(self):
        handler = self._make_handler({"query": "q: ", "document": "d: "})
        with pytest.raises(ValueError, match="Unsupported input_type 'other'"):
            handler._validate_input_type("other")

    def test_error_lists_supported(self):
        # The error message enumerates the configured input types.
        handler = self._make_handler({"a": "", "b": ""})
        with pytest.raises(ValueError, match="Supported values: a, b"):
            handler._validate_input_type("z")
|
||||
Reference in New Issue
Block a user