2026-03-16 19:55:53 -04:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
"""Unit tests for EmbedIOProcessor."""
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
|
|
|
|
|
from vllm.entrypoints.pooling.embed.protocol import (
|
2026-03-28 14:30:54 -04:00
|
|
|
CohereEmbedContent,
|
|
|
|
|
CohereEmbedInput,
|
2026-03-16 19:55:53 -04:00
|
|
|
CohereEmbedRequest,
|
|
|
|
|
)
|
2026-03-28 14:30:54 -04:00
|
|
|
from vllm.entrypoints.pooling.typing import PoolingServeContext
|
2026-03-16 19:55:53 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestResolveTruncation:
    """Unit tests for EmbedIOProcessor._resolve_cohere_truncation."""

    @staticmethod
    def _make_request(**overrides) -> CohereEmbedRequest:
        # Base request fields; keyword overrides win on conflict.
        base = {
            "model": "test",
            "input_type": "search_document",
            "texts": ["hello"],
        }
        return CohereEmbedRequest(**{**base, **overrides})

    def test_truncate_end_default(self):
        """Omitting `truncate` behaves like END: truncate to model max."""
        request = self._make_request()
        limit, side = EmbedIOProcessor._resolve_cohere_truncation(request)
        assert limit == -1
        assert side is None

    def test_truncate_end_explicit(self):
        """An explicit END truncates to the model maximum on the default side."""
        request = self._make_request(truncate="END")
        limit, side = EmbedIOProcessor._resolve_cohere_truncation(request)
        assert limit == -1
        assert side is None

    def test_truncate_end_with_max_tokens(self):
        """END combined with max_tokens truncates to that token budget."""
        request = self._make_request(truncate="END", max_tokens=128)
        limit, side = EmbedIOProcessor._resolve_cohere_truncation(request)
        assert limit == 128
        assert side is None

    def test_truncate_none(self):
        """NONE disables prompt truncation entirely."""
        request = self._make_request(truncate="NONE")
        limit, side = EmbedIOProcessor._resolve_cohere_truncation(request)
        assert limit is None
        assert side is None

    def test_truncate_none_with_max_tokens(self):
        """truncate=NONE should NOT set truncate_prompt_tokens; the
        max_tokens limit is enforced separately via _check_max_tokens."""
        request = self._make_request(truncate="NONE", max_tokens=10)
        limit, side = EmbedIOProcessor._resolve_cohere_truncation(request)
        assert limit is None
        assert side is None

    def test_truncate_start(self):
        """START truncates from the left, up to the model maximum."""
        request = self._make_request(truncate="START")
        limit, side = EmbedIOProcessor._resolve_cohere_truncation(request)
        assert limit == -1
        assert side == "left"

    def test_truncate_start_with_max_tokens(self):
        """START with max_tokens truncates from the left to that budget."""
        request = self._make_request(truncate="START", max_tokens=64)
        limit, side = EmbedIOProcessor._resolve_cohere_truncation(request)
        assert limit == 64
        assert side == "left"
class TestApplyStPrompt:
    """Unit tests for EmbedIOProcessor._apply_task_instruction."""

    @staticmethod
    def _make_handler(task_instructions: dict[str, str] | None):
        # Build a bare instance without running __init__ and set only the
        # attribute that _apply_task_instruction reads.
        instance = object.__new__(EmbedIOProcessor)
        instance.task_instructions = task_instructions
        return instance

    def test_no_prompts_configured(self):
        """Without configured instructions the input list is returned as-is."""
        handler = self._make_handler(None)
        payload = ["hello", "world"]
        assert handler._apply_task_instruction(payload, "query") is payload

    def test_matching_input_type(self):
        """A configured prefix for the input_type is prepended to the text."""
        handler = self._make_handler({"query": "search_query: "})
        out = handler._apply_task_instruction(["hello"], "query")
        assert out == ["search_query: hello"]

    def test_non_matching_input_type(self):
        """An input_type without a configured prefix leaves texts untouched."""
        handler = self._make_handler({"query": "search_query: "})
        payload = ["hello"]
        assert handler._apply_task_instruction(payload, "document") is payload

    def test_multiple_texts(self):
        """The prefix is applied to every element of a multi-text batch."""
        prefix = "Represent this sentence for searching: "
        handler = self._make_handler({"query": prefix})
        out = handler._apply_task_instruction(["a", "b", "c"], "query")
        assert out == [prefix + "a", prefix + "b", prefix + "c"]

    def test_empty_prefix_returns_unchanged(self):
        """An empty-string prefix behaves like no instruction at all."""
        handler = self._make_handler({"passage": ""})
        payload = ["hello"]
        assert handler._apply_task_instruction(payload, "passage") is payload
class TestLoadTaskInstructions:
    """Unit tests for EmbedIOProcessor._load_task_instructions."""

    def test_no_attribute(self):
        """A config object lacking task_instructions yields None."""

        class FakeConfig:
            pass

        assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None

    def test_with_task_instructions(self):
        """A populated mapping is returned unchanged (empty values kept)."""

        class FakeConfig:
            task_instructions = {
                "retrieval.query": "Represent the query: ",
                "retrieval.passage": "",
            }

        expected = {
            "retrieval.query": "Represent the query: ",
            "retrieval.passage": "",
        }
        assert EmbedIOProcessor._load_task_instructions(FakeConfig()) == expected

    def test_empty_dict(self):
        """An empty mapping is normalized to None."""

        class FakeConfig:
            task_instructions = {}

        assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None

    def test_non_dict(self):
        """A non-mapping value is treated as if no instructions exist."""

        class FakeConfig:
            task_instructions = "not a dict"

        assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None
class TestCheckMaxTokens:
    """Unit tests for EmbedIOProcessor._check_cohere_max_tokens."""

    @staticmethod
    def _fake_output(n_tokens: int):
        # Minimal stand-in exposing only the prompt_token_ids attribute
        # that _check_cohere_max_tokens inspects.
        class _Stub:
            pass

        stub = _Stub()
        stub.prompt_token_ids = list(range(n_tokens))
        return stub

    def test_none_check_is_noop(self):
        """max_tokens=None performs no validation at all."""
        EmbedIOProcessor._check_cohere_max_tokens([self._fake_output(100)], None)

    def test_within_limit(self):
        """Outputs under the limit pass silently."""
        outputs = [self._fake_output(5), self._fake_output(3)]
        EmbedIOProcessor._check_cohere_max_tokens(outputs, 5)

    def test_exceeds_limit(self):
        """Any single output above the limit raises ValueError."""
        outputs = [self._fake_output(3), self._fake_output(10)]
        with pytest.raises(ValueError, match="exceeds max_tokens=5"):
            EmbedIOProcessor._check_cohere_max_tokens(outputs, 5)

    def test_exact_limit(self):
        """A prompt exactly at the limit is accepted (inclusive bound)."""
        EmbedIOProcessor._check_cohere_max_tokens([self._fake_output(5)], 5)
class TestValidateInputType:
    """Unit tests for EmbedIOProcessor._validate_input_type."""

    @staticmethod
    def _make_handler(task_instructions: dict[str, str] | None):
        # Bare instance: skip __init__, set only the attribute under test.
        instance = object.__new__(EmbedIOProcessor)
        instance.task_instructions = task_instructions
        return instance

    def test_none_input_type_always_accepted(self):
        """input_type=None is valid with or without configured instructions."""
        self._make_handler(None)._validate_input_type(None)
        self._make_handler({"query": "q: "})._validate_input_type(None)

    def test_no_prompts_rejects(self):
        """Every concrete input_type is rejected when none are configured."""
        handler = self._make_handler(None)
        with pytest.raises(ValueError, match="does not define any input_type"):
            handler._validate_input_type("anything")

    def test_known_type_accepted(self):
        """Configured input_type keys pass validation without raising."""
        handler = self._make_handler({"query": "q: ", "document": "d: "})
        handler._validate_input_type("query")
        handler._validate_input_type("document")

    def test_unknown_type_rejected(self):
        """An unconfigured input_type raises, naming the offending value."""
        handler = self._make_handler({"query": "q: ", "document": "d: "})
        with pytest.raises(ValueError, match="Unsupported input_type 'other'"):
            handler._validate_input_type("other")

    def test_error_lists_supported(self):
        """The rejection message enumerates the supported input_type keys."""
        handler = self._make_handler({"a": "", "b": ""})
        with pytest.raises(ValueError, match="Supported values: a, b"):
            handler._validate_input_type("z")
class TestPreProcessCohereOnline:
    """Unit tests for EmbedIOProcessor._pre_process_cohere_online."""

    @staticmethod
    def _make_context(**request_kwargs) -> PoolingServeContext[CohereEmbedRequest]:
        # Build a minimal serve context wrapping a CohereEmbedRequest;
        # extra keyword args flow straight into the request model.
        return PoolingServeContext(
            request=CohereEmbedRequest(model="test", **request_kwargs),
            model_name="test",
            request_id="embd-test",
        )

    @staticmethod
    def _make_handler():
        # Bare EmbedIOProcessor (skips __init__) with input_type validation
        # stubbed to a no-op so these tests exercise only routing logic.
        handler = object.__new__(EmbedIOProcessor)
        handler._validate_input_type = lambda _input_type: None
        return handler

    def test_text_only_without_task_prefix_uses_completion_path(self):
        # No task prefix and no chat template: texts should be forwarded
        # unchanged through the completion preprocessing path.
        handler = self._make_handler()
        ctx = self._make_context(texts=["hello"])
        calls: list[tuple[str, object]] = []

        def preprocess_completion(request, prompt_input, prompt_embeds):
            # Record which path ran and what prompt input it received.
            calls.append(("completion", prompt_input))
            return ["completion"]

        handler._get_task_instruction_prefix = lambda _input_type: None
        handler._has_chat_template = lambda: False
        handler._preprocess_completion_online = preprocess_completion
        # Chat rendering must never be reached on this path.
        handler._batch_render_chat = lambda *_args, **_kwargs: (
            pytest.fail("text-only request should not require chat rendering")
        )

        handler._pre_process_cohere_online(ctx)

        assert ctx.engine_inputs == ["completion"]
        assert calls == [("completion", ["hello"])]

    def test_text_only_falls_back_to_prefixed_completion_without_template(self):
        # A task prefix exists but no chat template: the prefix is applied
        # to the raw text and the completion path is used as a fallback.
        handler = self._make_handler()
        ctx = self._make_context(texts=["hello"], input_type="query")
        calls: list[tuple[str, object]] = []

        def preprocess_completion(request, prompt_input, prompt_embeds):
            # Record the (already prefixed) prompt input for verification.
            calls.append(("completion", prompt_input))
            return ["fallback"]

        handler._get_task_instruction_prefix = lambda _input_type: "query: "
        handler._has_chat_template = lambda: False
        # Chat rendering must be skipped when no template is available.
        handler._batch_render_chat = lambda *_args, **_kwargs: (
            pytest.fail("chat rendering should be skipped without a template")
        )
        handler._preprocess_completion_online = preprocess_completion

        handler._pre_process_cohere_online(ctx)

        assert ctx.engine_inputs == ["fallback"]
        # Prefix "query: " is prepended before the completion path runs.
        assert calls == [("completion", ["query: hello"])]

    def test_text_only_with_template_uses_chat_path(self):
        # A task prefix AND a chat template: the request is converted to
        # chat messages and rendered, bypassing the completion path.
        handler = self._make_handler()
        ctx = self._make_context(texts=["hello"], input_type="query")
        calls: list[tuple[str, object]] = []

        def batch_render_chat(
            request,
            all_messages,
            truncate_prompt_tokens,
            truncation_side,
        ):
            # Capture every argument so the assertion below can verify the
            # exact call the production code makes.
            calls.append(
                (
                    "chat",
                    {
                        "request": request,
                        "all_messages": all_messages,
                        "truncate_prompt_tokens": truncate_prompt_tokens,
                        "truncation_side": truncation_side,
                    },
                )
            )
            return ["chat"]

        handler._get_task_instruction_prefix = lambda _input_type: "query: "
        handler._has_chat_template = lambda: True
        handler._batch_render_chat = batch_render_chat
        # Completion preprocessing must never run when a template exists.
        handler._preprocess_completion_online = lambda *_args, **_kwargs: (
            pytest.fail("completion path should be skipped when a template exists")
        )

        handler._pre_process_cohere_online(ctx)

        assert ctx.engine_inputs == ["chat"]
        # Expected messages are built with the real _mixed_input_to_messages
        # helper so this assertion tracks its output format automatically.
        # truncate_prompt_tokens=-1 mirrors the default END truncation
        # resolved for requests without an explicit `truncate` field.
        assert calls == [
            (
                "chat",
                {
                    "request": ctx.request,
                    "all_messages": [
                        handler._mixed_input_to_messages(
                            CohereEmbedInput(
                                content=[CohereEmbedContent(type="text", text="hello")]
                            ),
                            task_prefix="query: ",
                        )
                    ],
                    "truncate_prompt_tokens": -1,
                    "truncation_side": None,
                },
            )
        ]