[BugFix][Frontend] apply task instruction as system prompt in cohere v2/embed (#38362)

Signed-off-by: walterbm <walter.beller.morales@gmail.com>
This commit is contained in:
Walter Beller-Morales
2026-03-28 14:30:54 -04:00
committed by GitHub
parent aa4eb0db78
commit fafca38adc
3 changed files with 245 additions and 42 deletions

View File

@@ -6,8 +6,11 @@ import pytest
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
from vllm.entrypoints.pooling.embed.protocol import (
CohereEmbedContent,
CohereEmbedInput,
CohereEmbedRequest,
)
from vllm.entrypoints.pooling.typing import PoolingServeContext
class TestResolveTruncation:
@@ -206,3 +209,116 @@ class TestValidateInputType:
handler = self._make_handler({"a": "", "b": ""})
with pytest.raises(ValueError, match="Supported values: a, b"):
handler._validate_input_type("z")
class TestPreProcessCohereOnline:
    """Unit tests for EmbedIOProcessor._pre_process_cohere_online."""

    @staticmethod
    def _make_context(**request_kwargs) -> PoolingServeContext[CohereEmbedRequest]:
        # Build a minimal serving context around a CohereEmbedRequest.
        request = CohereEmbedRequest(model="test", **request_kwargs)
        return PoolingServeContext(
            request=request,
            model_name="test",
            request_id="embd-test",
        )

    @staticmethod
    def _make_handler():
        # Bypass __init__ so the tests need no engine/tokenizer wiring;
        # input-type validation is stubbed out to a no-op.
        processor = object.__new__(EmbedIOProcessor)
        processor._validate_input_type = lambda _input_type: None
        return processor

    def test_text_only_without_task_prefix_uses_completion_path(self):
        processor = self._make_handler()
        context = self._make_context(texts=["hello"])
        recorded: list[tuple[str, object]] = []

        def fake_completion(request, prompt_input, prompt_embeds):
            recorded.append(("completion", prompt_input))
            return ["completion"]

        # No task prefix and no chat template: the plain completion path
        # must be taken and chat rendering must never be reached.
        processor._get_task_instruction_prefix = lambda _input_type: None
        processor._has_chat_template = lambda: False
        processor._preprocess_completion_online = fake_completion
        processor._batch_render_chat = lambda *_args, **_kwargs: (
            pytest.fail("text-only request should not require chat rendering")
        )

        processor._pre_process_cohere_online(context)

        assert context.engine_inputs == ["completion"]
        assert recorded == [("completion", ["hello"])]

    def test_text_only_falls_back_to_prefixed_completion_without_template(self):
        processor = self._make_handler()
        context = self._make_context(texts=["hello"], input_type="query")
        recorded: list[tuple[str, object]] = []

        def fake_completion(request, prompt_input, prompt_embeds):
            recorded.append(("completion", prompt_input))
            return ["fallback"]

        # A task prefix exists but no chat template: the prefix should be
        # prepended to the raw text and the completion path used instead.
        processor._get_task_instruction_prefix = lambda _input_type: "query: "
        processor._has_chat_template = lambda: False
        processor._batch_render_chat = lambda *_args, **_kwargs: (
            pytest.fail("chat rendering should be skipped without a template")
        )
        processor._preprocess_completion_online = fake_completion

        processor._pre_process_cohere_online(context)

        assert context.engine_inputs == ["fallback"]
        assert recorded == [("completion", ["query: hello"])]

    def test_text_only_with_template_uses_chat_path(self):
        processor = self._make_handler()
        context = self._make_context(texts=["hello"], input_type="query")
        recorded: list[tuple[str, object]] = []

        def fake_render_chat(
            request,
            all_messages,
            truncate_prompt_tokens,
            truncation_side,
        ):
            recorded.append(
                (
                    "chat",
                    {
                        "request": request,
                        "all_messages": all_messages,
                        "truncate_prompt_tokens": truncate_prompt_tokens,
                        "truncation_side": truncation_side,
                    },
                )
            )
            return ["chat"]

        # Task prefix plus a chat template: chat rendering wins and the
        # completion path must not be invoked at all.
        processor._get_task_instruction_prefix = lambda _input_type: "query: "
        processor._has_chat_template = lambda: True
        processor._batch_render_chat = fake_render_chat
        processor._preprocess_completion_online = lambda *_args, **_kwargs: (
            pytest.fail("completion path should be skipped when a template exists")
        )

        processor._pre_process_cohere_online(context)

        assert context.engine_inputs == ["chat"]
        expected_messages = [
            processor._mixed_input_to_messages(
                CohereEmbedInput(
                    content=[CohereEmbedContent(type="text", text="hello")]
                ),
                task_prefix="query: ",
            )
        ]
        assert recorded == [
            (
                "chat",
                {
                    "request": context.request,
                    "all_messages": expected_messages,
                    "truncate_prompt_tokens": -1,
                    "truncation_side": None,
                },
            )
        ]