Signed-off-by: walterbm <walter.beller.morales@gmail.com>
(cherry picked from commit 061980c36a)
209 lines
7.2 KiB
Python
209 lines
7.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""Unit tests for EmbedIOProcessor."""
|
|
|
|
import pytest
|
|
|
|
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
|
|
from vllm.entrypoints.pooling.embed.protocol import (
|
|
CohereEmbedRequest,
|
|
)
|
|
|
|
|
|
class TestResolveTruncation:
|
|
"""Unit tests for EmbedIOProcessor._resolve_cohere_truncation."""
|
|
|
|
@staticmethod
|
|
def _make_request(**kwargs) -> CohereEmbedRequest:
|
|
defaults = {
|
|
"model": "test",
|
|
"input_type": "search_document",
|
|
"texts": ["hello"],
|
|
}
|
|
return CohereEmbedRequest(**(defaults | kwargs))
|
|
|
|
def test_truncate_end_default(self):
|
|
req = self._make_request()
|
|
tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
|
|
assert tokens == -1
|
|
assert side is None
|
|
|
|
def test_truncate_end_explicit(self):
|
|
req = self._make_request(truncate="END")
|
|
tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
|
|
assert tokens == -1
|
|
assert side is None
|
|
|
|
def test_truncate_end_with_max_tokens(self):
|
|
req = self._make_request(truncate="END", max_tokens=128)
|
|
tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
|
|
assert tokens == 128
|
|
assert side is None
|
|
|
|
def test_truncate_none(self):
|
|
req = self._make_request(truncate="NONE")
|
|
tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
|
|
assert tokens is None
|
|
assert side is None
|
|
|
|
def test_truncate_none_with_max_tokens(self):
|
|
"""truncate=NONE should NOT set truncate_prompt_tokens; the
|
|
max_tokens limit is enforced separately via _check_max_tokens."""
|
|
req = self._make_request(truncate="NONE", max_tokens=10)
|
|
tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
|
|
assert tokens is None
|
|
assert side is None
|
|
|
|
def test_truncate_start(self):
|
|
req = self._make_request(truncate="START")
|
|
tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
|
|
assert tokens == -1
|
|
assert side == "left"
|
|
|
|
def test_truncate_start_with_max_tokens(self):
|
|
req = self._make_request(truncate="START", max_tokens=64)
|
|
tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
|
|
assert tokens == 64
|
|
assert side == "left"
|
|
|
|
|
|
class TestApplyStPrompt:
|
|
"""Unit tests for EmbedIOProcessor._apply_task_instruction."""
|
|
|
|
@staticmethod
|
|
def _make_handler(task_instructions: dict[str, str] | None):
|
|
handler = object.__new__(EmbedIOProcessor)
|
|
handler.task_instructions = task_instructions
|
|
return handler
|
|
|
|
def test_no_prompts_configured(self):
|
|
handler = self._make_handler(None)
|
|
texts = ["hello", "world"]
|
|
assert handler._apply_task_instruction(texts, "query") is texts
|
|
|
|
def test_matching_input_type(self):
|
|
handler = self._make_handler({"query": "search_query: "})
|
|
result = handler._apply_task_instruction(["hello"], "query")
|
|
assert result == ["search_query: hello"]
|
|
|
|
def test_non_matching_input_type(self):
|
|
handler = self._make_handler({"query": "search_query: "})
|
|
texts = ["hello"]
|
|
assert handler._apply_task_instruction(texts, "document") is texts
|
|
|
|
def test_multiple_texts(self):
|
|
handler = self._make_handler(
|
|
{"query": "Represent this sentence for searching: "}
|
|
)
|
|
result = handler._apply_task_instruction(["a", "b", "c"], "query")
|
|
assert result == [
|
|
"Represent this sentence for searching: a",
|
|
"Represent this sentence for searching: b",
|
|
"Represent this sentence for searching: c",
|
|
]
|
|
|
|
def test_empty_prefix_returns_unchanged(self):
|
|
handler = self._make_handler({"passage": ""})
|
|
texts = ["hello"]
|
|
assert handler._apply_task_instruction(texts, "passage") is texts
|
|
|
|
|
|
class TestLoadTaskInstructions:
|
|
"""Unit tests for EmbedIOProcessor._load_task_instructions."""
|
|
|
|
def test_no_attribute(self):
|
|
class FakeConfig:
|
|
pass
|
|
|
|
assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None
|
|
|
|
def test_with_task_instructions(self):
|
|
class FakeConfig:
|
|
task_instructions = {
|
|
"retrieval.query": "Represent the query: ",
|
|
"retrieval.passage": "",
|
|
}
|
|
|
|
result = EmbedIOProcessor._load_task_instructions(FakeConfig())
|
|
assert result == {
|
|
"retrieval.query": "Represent the query: ",
|
|
"retrieval.passage": "",
|
|
}
|
|
|
|
def test_empty_dict(self):
|
|
class FakeConfig:
|
|
task_instructions = {}
|
|
|
|
assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None
|
|
|
|
def test_non_dict(self):
|
|
class FakeConfig:
|
|
task_instructions = "not a dict"
|
|
|
|
assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None
|
|
|
|
|
|
class TestCheckMaxTokens:
|
|
"""Unit tests for EmbedIOProcessor._check_cohere_max_tokens."""
|
|
|
|
@staticmethod
|
|
def _fake_output(n_tokens: int):
|
|
class _Out:
|
|
def __init__(self, n: int):
|
|
self.prompt_token_ids = list(range(n))
|
|
|
|
return _Out(n_tokens)
|
|
|
|
def test_none_check_is_noop(self):
|
|
outs = [self._fake_output(100)]
|
|
EmbedIOProcessor._check_cohere_max_tokens(outs, None)
|
|
|
|
def test_within_limit(self):
|
|
outs = [self._fake_output(5), self._fake_output(3)]
|
|
EmbedIOProcessor._check_cohere_max_tokens(outs, 5)
|
|
|
|
def test_exceeds_limit(self):
|
|
outs = [self._fake_output(3), self._fake_output(10)]
|
|
with pytest.raises(ValueError, match="exceeds max_tokens=5"):
|
|
EmbedIOProcessor._check_cohere_max_tokens(outs, 5)
|
|
|
|
def test_exact_limit(self):
|
|
outs = [self._fake_output(5)]
|
|
EmbedIOProcessor._check_cohere_max_tokens(outs, 5)
|
|
|
|
|
|
class TestValidateInputType:
|
|
"""Unit tests for EmbedIOProcessor._validate_input_type."""
|
|
|
|
@staticmethod
|
|
def _make_handler(task_instructions: dict[str, str] | None):
|
|
handler = object.__new__(EmbedIOProcessor)
|
|
handler.task_instructions = task_instructions
|
|
return handler
|
|
|
|
def test_none_input_type_always_accepted(self):
|
|
handler = self._make_handler(None)
|
|
handler._validate_input_type(None)
|
|
handler_with = self._make_handler({"query": "q: "})
|
|
handler_with._validate_input_type(None)
|
|
|
|
def test_no_prompts_rejects(self):
|
|
handler = self._make_handler(None)
|
|
with pytest.raises(ValueError, match="does not define any input_type"):
|
|
handler._validate_input_type("anything")
|
|
|
|
def test_known_type_accepted(self):
|
|
handler = self._make_handler({"query": "q: ", "document": "d: "})
|
|
handler._validate_input_type("query")
|
|
handler._validate_input_type("document")
|
|
|
|
def test_unknown_type_rejected(self):
|
|
handler = self._make_handler({"query": "q: ", "document": "d: "})
|
|
with pytest.raises(ValueError, match="Unsupported input_type 'other'"):
|
|
handler._validate_input_type("other")
|
|
|
|
def test_error_lists_supported(self):
|
|
handler = self._make_handler({"a": "", "b": ""})
|
|
with pytest.raises(ValueError, match="Supported values: a, b"):
|
|
handler._validate_input_type("z")
|