diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index 45af2b693..cf44a1bfe 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -72,6 +72,9 @@ In addition, we have the following custom APIs: - Only applicable to [classification models](../models/pooling_models.md). - [Score API](#score-api) (`/score`) - Applicable to [embedding models and cross-encoder models](../models/pooling_models.md). +- [Cohere Embed API](#cohere-embed-api) (`/v2/embed`) + - Compatible with [Cohere's Embed API](https://docs.cohere.com/reference/embed) + - Works with any [embedding model](../models/pooling_models.md), including multimodal models. - [Re-rank API](#re-rank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`) - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/) - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank) @@ -429,6 +432,137 @@ these extra parameters are supported instead: --8<-- "vllm/entrypoints/pooling/base/protocol.py:embed-extra-params" ``` +### Cohere Embed API + +Our API is also compatible with [Cohere's Embed v2 API](https://docs.cohere.com/reference/embed) which adds support for some modern embedding features such as truncation, output dimensions, embedding types, and input types. This endpoint works with any embedding model (including multimodal models). 
+ +#### Cohere Embed API request parameters + +| Parameter | Type | Required | Description | +| --------- | ---- | -------- | ----------- | +| `model` | string | Yes | Model name | +| `input_type` | string | No | Prompt prefix key (model-dependent, see below) | +| `texts` | list[string] | No | Text inputs (use one of `texts`, `images`, or `inputs`) | +| `images` | list[string] | No | Base64 data URI images | +| `inputs` | list[object] | No | Mixed text and image content objects | +| `embedding_types` | list[string] | No | Output types (default: `["float"]`) | +| `output_dimension` | int | No | Truncate embeddings to this dimension (Matryoshka) | +| `truncate` | string | No | `END`, `START`, or `NONE` (default: `END`) | + +#### Text embedding + +```bash +curl -X POST "http://localhost:8000/v2/embed" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Snowflake/snowflake-arctic-embed-m-v1.5", + "input_type": "query", + "texts": ["Hello world", "How are you?"], + "embedding_types": ["float"] + }' +``` + +??? console "Response" + + ```json + { + "id": "embd-...", + "embeddings": { + "float": [ + [0.012, -0.034, ...], + [0.056, 0.078, ...] + ] + }, + "texts": ["Hello world", "How are you?"], + "meta": { + "api_version": {"version": "2"}, + "billed_units": {"input_tokens": 12} + } + } + ``` + +#### Mixed text and image inputs + +For multimodal models, you can embed images by passing base64 data URIs. The `inputs` field accepts a list of objects with mixed text and image content: + +```bash +curl -X POST "http://localhost:8000/v2/embed" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "google/siglip-so400m-patch14-384", + "inputs": [ + { + "content": [ + {"type": "text", "text": "A photo of a cat"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}} + ] + } + ], + "embedding_types": ["float"] + }' +``` + +#### Embedding types + +The `embedding_types` parameter controls the output format. 
Multiple types can be requested in a single call: + +| Type | Description | +| ---- | ----------- | +| `float` | Raw float32 embeddings (default) | +| `binary` | Sign bits packed 8 per byte, returned as signed int8 values (-128 to 127); the embedding dimension must be a multiple of 8 | +| `ubinary` | Sign bits packed 8 per byte, returned as unsigned uint8 values (0 to 255); the embedding dimension must be a multiple of 8 | +| `base64` | Little-endian float32 encoded as base64 | + +```bash +curl -X POST "http://localhost:8000/v2/embed" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Snowflake/snowflake-arctic-embed-m-v1.5", + "input_type": "query", + "texts": ["What is machine learning?"], + "embedding_types": ["float", "binary"] + }' +``` + +??? console "Response" + + ```json + { + "id": "embd-...", + "embeddings": { + "float": [[0.012, -0.034, ...]], + "binary": [[42, -117, ...]] + }, + "texts": ["What is machine learning?"], + "meta": { + "api_version": {"version": "2"}, + "billed_units": {"input_tokens": 8} + } + } + ``` + +#### Truncation + +The `truncate` parameter controls how inputs exceeding the model's maximum sequence length are handled: + +| Value | Behavior | +| ----- | --------- | +| `END` (default) | Keep the first tokens, drop the end | +| `START` | Keep the last tokens, drop the beginning | +| `NONE` | Return an error if the input is too long | + +#### Input type and prompt prefixes + +The `input_type` field selects a prompt prefix to prepend to each text input. The available values +depend on the model: + +- **Models with `task_instructions` in `config.json`**: The keys from the `task_instructions` dict are + the valid `input_type` values and the corresponding value is prepended to each text. +- **Models with `config_sentence_transformers.json` prompts**: The keys from the `prompts` dict are + the valid `input_type` values. For example, `Snowflake/snowflake-arctic-embed-xs` defines `"query"`, + so setting `input_type: "query"` prepends `"Represent this sentence for searching relevant passages: "`. +- **Other models**: `input_type` is not accepted and will raise a validation error if passed. 
+ ### Transcriptions API Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription); diff --git a/tests/entrypoints/pooling/embed/test_cohere_online.py b/tests/entrypoints/pooling/embed/test_cohere_online.py new file mode 100644 index 000000000..fc313819f --- /dev/null +++ b/tests/entrypoints/pooling/embed/test_cohere_online.py @@ -0,0 +1,310 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the Cohere /v2/embed API with generic (non-Cohere) models. + +Validates that the Cohere v2 embed endpoint works correctly with standard +embedding models, covering text embedding, embedding type conversions, +response structure, batching, normalisation, and semantic similarity. +""" + +import base64 +import struct + +import numpy as np +import pytest +import requests + +from tests.utils import RemoteOpenAIServer + +DTYPE = "bfloat16" + +MODELS: list[tuple[str, list[str]]] = [ + ("intfloat/multilingual-e5-small", []), + ( + "Snowflake/snowflake-arctic-embed-m-v1.5", + [ + "--trust_remote_code", + "--hf_overrides", + '{"matryoshka_dimensions":[256]}', + ], + ), +] + + +@pytest.fixture(scope="module", params=MODELS, ids=lambda m: m[0]) +def model_config(request): + return request.param + + +@pytest.fixture(scope="module") +def model_name(model_config): + return model_config[0] + + +@pytest.fixture(scope="module") +def server(model_config): + name, extra_args = model_config + args = [ + "--runner", + "pooling", + "--dtype", + DTYPE, + "--enforce-eager", + "--max-model-len", + "512", + "--gpu-memory-utilization", + "0.02", + ] + extra_args + with RemoteOpenAIServer(name, args) as remote_server: + yield remote_server + + +def _cohere_embed( + server: RemoteOpenAIServer, + model_name: str, + texts: list[str] | None = None, + images: list[str] | None = None, + input_type: str | None = None, + embedding_types: list[str] | None 
= None, +) -> dict: + body: dict = {"model": model_name} + if input_type is not None: + body["input_type"] = input_type + if texts is not None: + body["texts"] = texts + if images is not None: + body["images"] = images + if embedding_types is not None: + body["embedding_types"] = embedding_types + resp = requests.post(server.url_for("/v2/embed"), json=body) + resp.raise_for_status() + return resp.json() + + +def _openai_embed( + server: RemoteOpenAIServer, model_name: str, texts: list[str] +) -> dict: + body = {"model": model_name, "input": texts, "encoding_format": "float"} + resp = requests.post(server.url_for("/v1/embeddings"), json=body) + resp.raise_for_status() + return resp.json() + + +def _cosine_sim(a: list[float], b: list[float]) -> float: + va, vb = np.array(a), np.array(b) + return float(np.dot(va, vb) / (np.linalg.norm(va) * np.linalg.norm(vb))) + + +# ----------------------------------------------------------- +# Text embedding tests +# ----------------------------------------------------------- + + +def test_basic_embed(server: RemoteOpenAIServer, model_name: str): + r = _cohere_embed( + server, model_name, texts=["hello world"], embedding_types=["float"] + ) + assert "embeddings" in r + assert len(r["embeddings"]["float"]) == 1 + assert len(r["embeddings"]["float"][0]) > 0 + + +def test_unsupported_input_type_rejected(server: RemoteOpenAIServer, model_name: str): + """An input_type not defined in the model's prompt config should be + rejected with a 400 error.""" + body = { + "model": model_name, + "input_type": "nonexistent_type", + "texts": ["hello world"], + "embedding_types": ["float"], + } + resp = requests.post(server.url_for("/v2/embed"), json=body) + assert resp.status_code == 400 + assert "Unsupported input_type" in resp.json()["error"]["message"] + + +def test_omitted_input_type_accepted(server: RemoteOpenAIServer, model_name: str): + """Omitting input_type should always work (no prompt prefix applied).""" + body = { + "model": model_name, 
+ "texts": ["hello world"], + "embedding_types": ["float"], + } + resp = requests.post(server.url_for("/v2/embed"), json=body) + assert resp.status_code == 200 + data = resp.json() + assert len(data["embeddings"]["float"]) == 1 + + +def test_v1_v2_parity(server: RemoteOpenAIServer, model_name: str): + """v1 (OpenAI) and v2 (Cohere) endpoints should produce the same + float embeddings for a generic model.""" + texts = ["hello world"] + v2 = _cohere_embed(server, model_name, texts=texts, embedding_types=["float"]) + v1 = _openai_embed(server, model_name, texts) + cos = _cosine_sim(v2["embeddings"]["float"][0], v1["data"][0]["embedding"]) + assert cos > 0.9999, f"v1/v2 parity failed, cosine={cos}" + + +def test_embedding_types(server: RemoteOpenAIServer, model_name: str): + r = _cohere_embed( + server, + model_name, + texts=["test"], + embedding_types=["float", "binary", "ubinary"], + ) + dim = len(r["embeddings"]["float"][0]) + assert len(r["embeddings"]["binary"][0]) == dim // 8 + assert len(r["embeddings"]["ubinary"][0]) == dim // 8 + + +def test_response_structure(server: RemoteOpenAIServer, model_name: str): + r = _cohere_embed(server, model_name, texts=["test"], embedding_types=["float"]) + assert "id" in r + assert "embeddings" in r + assert "texts" in r + assert r["texts"] == ["test"] + assert "meta" in r + assert r["meta"]["api_version"]["version"] == "2" + assert "billed_units" in r["meta"] + assert r["meta"]["billed_units"]["input_tokens"] > 0 + assert r["meta"]["billed_units"]["image_tokens"] == 0 + + +def test_batch(server: RemoteOpenAIServer, model_name: str): + texts = ["apple", "banana", "cherry"] + r = _cohere_embed(server, model_name, texts=texts, embedding_types=["float"]) + assert len(r["embeddings"]["float"]) == 3 + dim = len(r["embeddings"]["float"][0]) + for emb in r["embeddings"]["float"]: + assert len(emb) == dim + + +def test_l2_normalized(server: RemoteOpenAIServer, model_name: str): + r = _cohere_embed( + server, model_name, texts=["hello 
world"], embedding_types=["float"] + ) + emb = np.array(r["embeddings"]["float"][0]) + assert abs(float(np.linalg.norm(emb)) - 1.0) < 0.01 + + +def test_semantic_similarity(server: RemoteOpenAIServer, model_name: str): + r = _cohere_embed( + server, + model_name, + texts=["machine learning", "deep learning", "chocolate cake recipe"], + embedding_types=["float"], + ) + embs = r["embeddings"]["float"] + cos_related = _cosine_sim(embs[0], embs[1]) + cos_unrelated = _cosine_sim(embs[0], embs[2]) + assert cos_related > cos_unrelated + + +def test_missing_input_returns_error(server: RemoteOpenAIServer, model_name: str): + body = {"model": model_name} + resp = requests.post(server.url_for("/v2/embed"), json=body) + assert resp.status_code == 400 + + +def test_base64_embedding_type(server: RemoteOpenAIServer, model_name: str): + r = _cohere_embed( + server, + model_name, + texts=["test encoding"], + embedding_types=["float", "base64"], + ) + float_emb = r["embeddings"]["float"][0] + b64_str = r["embeddings"]["base64"][0] + decoded = struct.unpack(f"<{len(float_emb)}f", base64.b64decode(b64_str)) + np.testing.assert_allclose(float_emb, decoded, rtol=1e-5) + + +# ----------------------------------------------------------- +# Truncation tests +# ----------------------------------------------------------- + + +def _cohere_embed_raw( + server: RemoteOpenAIServer, + body: dict, +) -> requests.Response: + return requests.post(server.url_for("/v2/embed"), json=body) + + +def test_truncate_end_succeeds(server: RemoteOpenAIServer, model_name: str): + """truncate=END should silently truncate long input.""" + long_text = " ".join(["word"] * 2000) + body = { + "model": model_name, + "texts": [long_text], + "embedding_types": ["float"], + "truncate": "END", + } + resp = _cohere_embed_raw(server, body) + assert resp.status_code == 200 + data = resp.json() + assert len(data["embeddings"]["float"]) == 1 + + +def test_truncate_start_succeeds(server: RemoteOpenAIServer, model_name: str): + 
"""truncate=START should silently truncate long input from the start.""" + long_text = " ".join(["word"] * 2000) + body = { + "model": model_name, + "texts": [long_text], + "embedding_types": ["float"], + "truncate": "START", + } + resp = _cohere_embed_raw(server, body) + assert resp.status_code == 200 + data = resp.json() + assert len(data["embeddings"]["float"]) == 1 + + +def test_truncate_none_rejects_long_input(server: RemoteOpenAIServer, model_name: str): + """truncate=NONE should error when input exceeds model context.""" + long_text = " ".join(["word"] * 2000) + body = { + "model": model_name, + "texts": [long_text], + "embedding_types": ["float"], + "truncate": "NONE", + } + resp = _cohere_embed_raw(server, body) + assert resp.status_code == 400 + + +def test_truncate_start_vs_end_differ(server: RemoteOpenAIServer, model_name: str): + """START and END truncation should produce different embeddings + when the input is long enough to actually be truncated. + + We construct input with distinct tokens at the start vs end + so that keeping different halves produces different embeddings. 
+ """ + start_words = " ".join([f"alpha{i}" for i in range(300)]) + end_words = " ".join([f"omega{i}" for i in range(300)]) + long_text = start_words + " " + end_words + + body_end = { + "model": model_name, + "texts": [long_text], + "embedding_types": ["float"], + "truncate": "END", + } + body_start = { + "model": model_name, + "texts": [long_text], + "embedding_types": ["float"], + "truncate": "START", + } + r_end = _cohere_embed_raw(server, body_end).json() + r_start = _cohere_embed_raw(server, body_start).json() + + emb_end = r_end["embeddings"]["float"][0] + emb_start = r_start["embeddings"]["float"][0] + cos = _cosine_sim(emb_end, emb_start) + assert cos < 0.99, ( + f"START and END truncation should produce different embeddings " + f"for long input, but cosine similarity was {cos}" + ) diff --git a/tests/entrypoints/pooling/embed/test_cohere_online_vision.py b/tests/entrypoints/pooling/embed/test_cohere_online_vision.py new file mode 100644 index 000000000..ab874e4e2 --- /dev/null +++ b/tests/entrypoints/pooling/embed/test_cohere_online_vision.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the Cohere /v2/embed API with a multimodal model (SigLIP). + +Validates image embedding, batching, normalisation, and embedding type +conversions through the /v2/embed endpoint. 
+""" + +import base64 +import struct +import zlib + +import numpy as np +import pytest +import requests + +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "google/siglip-so400m-patch14-384" +DTYPE = "bfloat16" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--runner", + "pooling", + "--dtype", + DTYPE, + "--enforce-eager", + "--max-model-len", + "64", + "--gpu-memory-utilization", + "0.3", + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +def _make_tiny_png(r: int, g: int, b: int, w: int = 2, h: int = 2) -> str: + raw = b"" + for _ in range(h): + raw += b"\x00" + bytes([r, g, b]) * w + compressed = zlib.compress(raw) + + def chunk(ctype: bytes, cdata: bytes) -> bytes: + c = ctype + cdata + return ( + struct.pack(">I", len(cdata)) + + c + + struct.pack(">I", zlib.crc32(c) & 0xFFFFFFFF) + ) + + ihdr = struct.pack(">IIBBBBB", w, h, 8, 2, 0, 0, 0) + png = ( + b"\x89PNG\r\n\x1a\n" + + chunk(b"IHDR", ihdr) + + chunk(b"IDAT", compressed) + + chunk(b"IEND", b"") + ) + return "data:image/png;base64," + base64.b64encode(png).decode() + + +def _cohere_embed( + server: RemoteOpenAIServer, + texts: list[str] | None = None, + images: list[str] | None = None, + embedding_types: list[str] | None = None, +) -> dict: + body: dict = {"model": MODEL_NAME} + if texts is not None: + body["texts"] = texts + if images is not None: + body["images"] = images + if embedding_types is not None: + body["embedding_types"] = embedding_types + resp = requests.post(server.url_for("/v2/embed"), json=body) + resp.raise_for_status() + return resp.json() + + +def test_image_embed(server: RemoteOpenAIServer): + img_uri = _make_tiny_png(255, 0, 0) + r = _cohere_embed( + server, + images=[img_uri], + embedding_types=["float"], + ) + assert "embeddings" in r + assert len(r["embeddings"]["float"]) == 1 + assert len(r["embeddings"]["float"][0]) > 0 + assert r["meta"]["billed_units"]["image_tokens"] > 0 + assert 
r["meta"]["billed_units"]["input_tokens"] == 0 + + +def test_image_batch(server: RemoteOpenAIServer): + red = _make_tiny_png(255, 0, 0) + blue = _make_tiny_png(0, 0, 255) + r = _cohere_embed( + server, + images=[red, blue], + embedding_types=["float"], + ) + assert len(r["embeddings"]["float"]) == 2 + + +def test_image_l2_normalized(server: RemoteOpenAIServer): + img_uri = _make_tiny_png(0, 255, 0) + r = _cohere_embed( + server, + images=[img_uri], + embedding_types=["float"], + ) + emb = np.array(r["embeddings"]["float"][0]) + assert abs(float(np.linalg.norm(emb)) - 1.0) < 0.01 + + +def test_image_embedding_types(server: RemoteOpenAIServer): + img_uri = _make_tiny_png(128, 128, 128) + r = _cohere_embed( + server, + images=[img_uri], + embedding_types=["float", "binary", "ubinary"], + ) + dim = len(r["embeddings"]["float"][0]) + assert len(r["embeddings"]["binary"][0]) == dim // 8 + assert len(r["embeddings"]["ubinary"][0]) == dim // 8 + + +def test_text_embed_on_multimodal(server: RemoteOpenAIServer): + """SigLIP also supports text-only embedding via /v2/embed.""" + r = _cohere_embed(server, texts=["hello world"], embedding_types=["float"]) + assert "embeddings" in r + assert len(r["embeddings"]["float"]) == 1 + assert len(r["embeddings"]["float"][0]) > 0 diff --git a/tests/entrypoints/pooling/embed/test_cohere_openai_parity.py b/tests/entrypoints/pooling/embed/test_cohere_openai_parity.py new file mode 100644 index 000000000..d23e1461b --- /dev/null +++ b/tests/entrypoints/pooling/embed/test_cohere_openai_parity.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Parity test between Cohere /v2/embed and OpenAI /v1/embeddings. + +Verifies that both endpoints produce identical float embeddings when +no prompt prefix is applied (input_type omitted for Cohere /v2/embed). 
+""" + +import numpy as np +import pytest +import requests + +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "BAAI/bge-base-en-v1.5" +DTYPE = "bfloat16" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--runner", + "pooling", + "--dtype", + DTYPE, + "--enforce-eager", + "--max-model-len", + "512", + "--gpu-memory-utilization", + "0.02", + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +def _cohere_embed( + server: RemoteOpenAIServer, + texts: list[str], +) -> list[list[float]]: + body = { + "model": MODEL_NAME, + "texts": texts, + "embedding_types": ["float"], + } + resp = requests.post(server.url_for("/v2/embed"), json=body) + resp.raise_for_status() + return resp.json()["embeddings"]["float"] + + +def _openai_embed( + server: RemoteOpenAIServer, + texts: list[str], +) -> list[list[float]]: + body = {"model": MODEL_NAME, "input": texts, "encoding_format": "float"} + resp = requests.post(server.url_for("/v1/embeddings"), json=body) + resp.raise_for_status() + return [item["embedding"] for item in resp.json()["data"]] + + +def test_single_text_parity(server: RemoteOpenAIServer): + """A single text should produce identical embeddings via both APIs.""" + texts = ["the quick brown fox jumps over the lazy dog"] + v2 = _cohere_embed(server, texts) + v1 = _openai_embed(server, texts) + np.testing.assert_allclose(v2[0], v1[0], rtol=1e-5) + + +def test_batch_parity(server: RemoteOpenAIServer): + """A batch of texts should produce identical embeddings via both APIs, + in the same order.""" + texts = [ + "machine learning", + "deep learning", + "natural language processing", + ] + v2 = _cohere_embed(server, texts) + v1 = _openai_embed(server, texts) + assert len(v2) == len(v1) == 3 + for i in range(3): + np.testing.assert_allclose(v2[i], v1[i], rtol=1e-5, err_msg=f"index {i}") + + +def test_token_count_parity(server: RemoteOpenAIServer): + """Both APIs should report the same prompt token count.""" + texts 
= ["hello world"] + v2_resp = requests.post( + server.url_for("/v2/embed"), + json={ + "model": MODEL_NAME, + "texts": texts, + "embedding_types": ["float"], + }, + ) + v1_resp = requests.post( + server.url_for("/v1/embeddings"), + json={"model": MODEL_NAME, "input": texts, "encoding_format": "float"}, + ) + v2_resp.raise_for_status() + v1_resp.raise_for_status() + v2_tokens = v2_resp.json()["meta"]["billed_units"]["input_tokens"] + v1_tokens = v1_resp.json()["usage"]["prompt_tokens"] + assert v2_tokens == v1_tokens diff --git a/tests/entrypoints/pooling/embed/test_io_processor.py b/tests/entrypoints/pooling/embed/test_io_processor.py new file mode 100644 index 000000000..e7db0df1e --- /dev/null +++ b/tests/entrypoints/pooling/embed/test_io_processor.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for EmbedIOProcessor.""" + +import pytest + +from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor +from vllm.entrypoints.pooling.embed.protocol import ( + CohereEmbedRequest, +) + + +class TestResolveTruncation: + """Unit tests for EmbedIOProcessor._resolve_cohere_truncation.""" + + @staticmethod + def _make_request(**kwargs) -> CohereEmbedRequest: + defaults = { + "model": "test", + "input_type": "search_document", + "texts": ["hello"], + } + return CohereEmbedRequest(**(defaults | kwargs)) + + def test_truncate_end_default(self): + req = self._make_request() + tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req) + assert tokens == -1 + assert side is None + + def test_truncate_end_explicit(self): + req = self._make_request(truncate="END") + tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req) + assert tokens == -1 + assert side is None + + def test_truncate_end_with_max_tokens(self): + req = self._make_request(truncate="END", max_tokens=128) + tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req) + assert tokens == 128 + 
assert side is None + + def test_truncate_none(self): + req = self._make_request(truncate="NONE") + tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req) + assert tokens is None + assert side is None + + def test_truncate_none_with_max_tokens(self): + """truncate=NONE should NOT set truncate_prompt_tokens; the + max_tokens limit is enforced separately via _check_max_tokens.""" + req = self._make_request(truncate="NONE", max_tokens=10) + tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req) + assert tokens is None + assert side is None + + def test_truncate_start(self): + req = self._make_request(truncate="START") + tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req) + assert tokens == -1 + assert side == "left" + + def test_truncate_start_with_max_tokens(self): + req = self._make_request(truncate="START", max_tokens=64) + tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req) + assert tokens == 64 + assert side == "left" + + +class TestApplyStPrompt: + """Unit tests for EmbedIOProcessor._apply_task_instruction.""" + + @staticmethod + def _make_handler(task_instructions: dict[str, str] | None): + handler = object.__new__(EmbedIOProcessor) + handler.task_instructions = task_instructions + return handler + + def test_no_prompts_configured(self): + handler = self._make_handler(None) + texts = ["hello", "world"] + assert handler._apply_task_instruction(texts, "query") is texts + + def test_matching_input_type(self): + handler = self._make_handler({"query": "search_query: "}) + result = handler._apply_task_instruction(["hello"], "query") + assert result == ["search_query: hello"] + + def test_non_matching_input_type(self): + handler = self._make_handler({"query": "search_query: "}) + texts = ["hello"] + assert handler._apply_task_instruction(texts, "document") is texts + + def test_multiple_texts(self): + handler = self._make_handler( + {"query": "Represent this sentence for searching: "} + ) + result = 
handler._apply_task_instruction(["a", "b", "c"], "query") + assert result == [ + "Represent this sentence for searching: a", + "Represent this sentence for searching: b", + "Represent this sentence for searching: c", + ] + + def test_empty_prefix_returns_unchanged(self): + handler = self._make_handler({"passage": ""}) + texts = ["hello"] + assert handler._apply_task_instruction(texts, "passage") is texts + + +class TestLoadTaskInstructions: + """Unit tests for EmbedIOProcessor._load_task_instructions.""" + + def test_no_attribute(self): + class FakeConfig: + pass + + assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None + + def test_with_task_instructions(self): + class FakeConfig: + task_instructions = { + "retrieval.query": "Represent the query: ", + "retrieval.passage": "", + } + + result = EmbedIOProcessor._load_task_instructions(FakeConfig()) + assert result == { + "retrieval.query": "Represent the query: ", + "retrieval.passage": "", + } + + def test_empty_dict(self): + class FakeConfig: + task_instructions = {} + + assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None + + def test_non_dict(self): + class FakeConfig: + task_instructions = "not a dict" + + assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None + + +class TestCheckMaxTokens: + """Unit tests for EmbedIOProcessor._check_cohere_max_tokens.""" + + @staticmethod + def _fake_output(n_tokens: int): + class _Out: + def __init__(self, n: int): + self.prompt_token_ids = list(range(n)) + + return _Out(n_tokens) + + def test_none_check_is_noop(self): + outs = [self._fake_output(100)] + EmbedIOProcessor._check_cohere_max_tokens(outs, None) + + def test_within_limit(self): + outs = [self._fake_output(5), self._fake_output(3)] + EmbedIOProcessor._check_cohere_max_tokens(outs, 5) + + def test_exceeds_limit(self): + outs = [self._fake_output(3), self._fake_output(10)] + with pytest.raises(ValueError, match="exceeds max_tokens=5"): + 
EmbedIOProcessor._check_cohere_max_tokens(outs, 5) + + def test_exact_limit(self): + outs = [self._fake_output(5)] + EmbedIOProcessor._check_cohere_max_tokens(outs, 5) + + +class TestValidateInputType: + """Unit tests for EmbedIOProcessor._validate_input_type.""" + + @staticmethod + def _make_handler(task_instructions: dict[str, str] | None): + handler = object.__new__(EmbedIOProcessor) + handler.task_instructions = task_instructions + return handler + + def test_none_input_type_always_accepted(self): + handler = self._make_handler(None) + handler._validate_input_type(None) + handler_with = self._make_handler({"query": "q: "}) + handler_with._validate_input_type(None) + + def test_no_prompts_rejects(self): + handler = self._make_handler(None) + with pytest.raises(ValueError, match="does not define any input_type"): + handler._validate_input_type("anything") + + def test_known_type_accepted(self): + handler = self._make_handler({"query": "q: ", "document": "d: "}) + handler._validate_input_type("query") + handler._validate_input_type("document") + + def test_unknown_type_rejected(self): + handler = self._make_handler({"query": "q: ", "document": "d: "}) + with pytest.raises(ValueError, match="Unsupported input_type 'other'"): + handler._validate_input_type("other") + + def test_error_lists_supported(self): + handler = self._make_handler({"a": "", "b": ""}) + with pytest.raises(ValueError, match="Supported values: a, b"): + handler._validate_input_type("z") diff --git a/tests/entrypoints/pooling/embed/test_protocol.py b/tests/entrypoints/pooling/embed/test_protocol.py new file mode 100644 index 000000000..f2bd5d2cc --- /dev/null +++ b/tests/entrypoints/pooling/embed/test_protocol.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for Cohere embed protocol: build_typed_embeddings and its +underlying packing helpers, plus Cohere-specific serving helpers.""" + +import base64 
+import struct + +import numpy as np +import pytest + +from vllm.entrypoints.pooling.embed.protocol import ( + build_typed_embeddings, +) + + +@pytest.fixture +def sample_embeddings() -> list[list[float]]: + return [ + [0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8], + [-0.05, 0.15, -0.25, 0.35, -0.45, 0.55, -0.65, 0.75], + ] + + +class TestBuildTypedEmbeddingsFloat: + def test_float_passthrough(self, sample_embeddings: list[list[float]]): + result = build_typed_embeddings(sample_embeddings, ["float"]) + assert result.float == sample_embeddings + assert result.binary is None + + def test_empty_input(self): + result = build_typed_embeddings([], ["float"]) + assert result.float == [] + + +class TestBuildTypedEmbeddingsBinary: + def test_binary_packing(self): + # 8 values: positive->1, negative->0 => bits: 10101010 = 0xAA = 170 + # signed: 170 - 128 = 42 + embs = [[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]] + result = build_typed_embeddings(embs, ["binary"]) + assert result.binary is not None + assert result.binary[0] == [42] + + def test_ubinary_packing(self): + embs = [[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]] + result = build_typed_embeddings(embs, ["ubinary"]) + assert result.ubinary is not None + assert result.ubinary[0] == [170] # 0b10101010 + + def test_binary_all_positive(self): + embs = [[0.1] * 8] + result = build_typed_embeddings(embs, ["binary"]) + assert result.binary is not None + # all bits = 1 => 0xFF = 255, signed: 255 - 128 = 127 + assert result.binary[0] == [127] + + def test_binary_all_negative(self): + embs = [[-0.1] * 8] + result = build_typed_embeddings(embs, ["binary"]) + assert result.binary is not None + # all bits = 0, signed: 0 - 128 = -128 + assert result.binary[0] == [-128] + + def test_binary_dimension_is_eighth(self, sample_embeddings: list[list[float]]): + result = build_typed_embeddings(sample_embeddings, ["binary"]) + assert result.binary is not None + for orig, packed in zip(sample_embeddings, result.binary): + assert len(packed) 
== len(orig) // 8 + + def test_zero_treated_as_positive(self): + embs = [[0.0] * 8] + result = build_typed_embeddings(embs, ["binary"]) + assert result.binary is not None + # 0.0 >= 0 is True, so bit=1 for all => 127 (signed) + assert result.binary[0] == [127] + + def test_non_multiple_of_8_raises(self): + embs = [[0.1] * 7] + with pytest.raises(ValueError, match="multiple of 8"): + build_typed_embeddings(embs, ["binary"]) + + def test_ubinary_non_multiple_of_8_raises(self): + embs = [[0.1] * 10] + with pytest.raises(ValueError, match="multiple of 8"): + build_typed_embeddings(embs, ["ubinary"]) + + +class TestBuildTypedEmbeddingsBase64: + def test_base64_roundtrip(self, sample_embeddings: list[list[float]]): + result = build_typed_embeddings(sample_embeddings, ["base64"]) + assert result.base64 is not None + assert len(result.base64) == 2 + + for orig, b64_str in zip(sample_embeddings, result.base64): + decoded = base64.b64decode(b64_str) + n = len(orig) + values = struct.unpack(f"<{n}f", decoded) + np.testing.assert_allclose(orig, values, rtol=1e-5) + + def test_base64_byte_length(self): + embs = [[0.1, 0.2, 0.3]] + result = build_typed_embeddings(embs, ["base64"]) + assert result.base64 is not None + raw = base64.b64decode(result.base64[0]) + assert len(raw) == 3 * 4 # 3 floats * 4 bytes each + + +class TestBuildTypedEmbeddingsMultiple: + def test_all_types_at_once(self, sample_embeddings: list[list[float]]): + result = build_typed_embeddings( + sample_embeddings, + ["float", "binary", "ubinary", "base64"], + ) + assert result.float is not None + assert result.binary is not None + assert result.ubinary is not None + assert result.base64 is not None + + def test_subset_types(self, sample_embeddings: list[list[float]]): + result = build_typed_embeddings(sample_embeddings, ["float", "binary"]) + assert result.float is not None + assert result.binary is not None + assert result.ubinary is None + assert result.base64 is None + + def test_unknown_type_ignored(self, 
sample_embeddings: list[list[float]]): + result = build_typed_embeddings(sample_embeddings, ["float", "unknown_type"]) + assert result.float is not None diff --git a/vllm/entrypoints/pooling/base/protocol.py b/vllm/entrypoints/pooling/base/protocol.py index 50be58374..2f547df8d 100644 --- a/vllm/entrypoints/pooling/base/protocol.py +++ b/vllm/entrypoints/pooling/base/protocol.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Annotated, Any +from typing import Annotated, Any, Literal from pydantic import Field, model_validator @@ -24,6 +24,14 @@ class PoolingBasicRequestMixin(OpenAIBaseModel): # --8<-- [start:pooling-common-extra-params] truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None + truncation_side: Literal["left", "right"] | None = Field( + default=None, + description=( + "Which side to truncate from when truncate_prompt_tokens is active. " + "'right' keeps the first N tokens. " + "'left' keeps the last N tokens." 
+ ), + ) request_id: str = Field( default_factory=random_uuid, description=( diff --git a/vllm/entrypoints/pooling/classify/protocol.py b/vllm/entrypoints/pooling/classify/protocol.py index bfc38ebef..fe8c898e0 100644 --- a/vllm/entrypoints/pooling/classify/protocol.py +++ b/vllm/entrypoints/pooling/classify/protocol.py @@ -32,6 +32,7 @@ class ClassificationCompletionRequest( max_total_tokens=model_config.max_model_len, max_output_tokens=0, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), add_special_tokens=self.add_special_tokens, max_total_tokens_param="max_model_len", @@ -54,6 +55,7 @@ class ClassificationChatRequest( max_total_tokens=model_config.max_model_len, max_output_tokens=0, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), add_special_tokens=self.add_special_tokens, max_total_tokens_param="max_model_len", diff --git a/vllm/entrypoints/pooling/embed/api_router.py b/vllm/entrypoints/pooling/embed/api_router.py index f88999468..390efc6a1 100644 --- a/vllm/entrypoints/pooling/embed/api_router.py +++ b/vllm/entrypoints/pooling/embed/api_router.py @@ -7,12 +7,12 @@ from fastapi import APIRouter, Depends, Request from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.utils import validate_json_request -from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest -from vllm.entrypoints.pooling.embed.serving import ServingEmbedding -from vllm.entrypoints.utils import ( - load_aware_call, - with_cancellation, +from vllm.entrypoints.pooling.embed.protocol import ( + CohereEmbedRequest, + EmbeddingRequest, ) +from vllm.entrypoints.pooling.embed.serving import ServingEmbedding +from vllm.entrypoints.utils import load_aware_call, with_cancellation router = APIRouter() @@ -40,3 +40,24 @@ async def create_embedding( raise 
NotImplementedError("The model does not support Embeddings API") return await handler(request, raw_request) + + +@router.post( + "/v2/embed", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +@load_aware_call +async def create_cohere_embedding( + request: CohereEmbedRequest, + raw_request: Request, +): + handler = embedding(raw_request) + if handler is None: + raise NotImplementedError("The model does not support Embeddings API") + + return await handler(request, raw_request) diff --git a/vllm/entrypoints/pooling/embed/io_processor.py b/vllm/entrypoints/pooling/embed/io_processor.py index 22ece7542..9342013bf 100644 --- a/vllm/entrypoints/pooling/embed/io_processor.py +++ b/vllm/entrypoints/pooling/embed/io_processor.py @@ -1,14 +1,37 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, cast +from collections.abc import Sequence +from typing import Any, Literal, cast import torch +from openai.types.chat import ( + ChatCompletionContentPartImageParam, + ChatCompletionContentPartTextParam, +) +from openai.types.chat.chat_completion_content_part_image_param import ImageURL +from vllm import PoolingParams +from vllm.entrypoints.chat_utils import ( + ChatCompletionContentPartParam, + ChatCompletionMessageParam, + CustomChatCompletionMessageParam, +) from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor +from vllm.entrypoints.pooling.embed.protocol import ( + CohereEmbedInput, + CohereEmbedRequest, + EmbeddingChatRequest, + EmbeddingCompletionRequest, +) from vllm.entrypoints.pooling.typing import PoolingServeContext from vllm.inputs.data import ProcessorInputs, token_inputs +from vllm.logger import init_logger from vllm.outputs import PoolingOutput, PoolingRequestOutput +from vllm.renderers import 
merge_kwargs from vllm.utils.collection_utils import chunk_list +from vllm.utils.mistral import is_mistral_tokenizer + +logger = init_logger(__name__) class EmbedIOProcessor(PoolingIOProcessor): @@ -21,16 +44,45 @@ class EmbedIOProcessor(PoolingIOProcessor): self.pooler_config = self.model_config.pooler_config self.enable_chunked_processing = self.pooler_config.enable_chunked_processing + # Load task instructions from HF config or sentence-transformers config + self.task_instructions: dict[str, str] | None = self._load_task_instructions( + self.model_config.hf_config + ) or self._load_st_prompts(self.model_config.model, self.model_config.revision) + if self.task_instructions: + logger.info( + "Loaded prompt prefixes for input_type: %s", + list(self.task_instructions.keys()), + ) + + def pre_process_online(self, ctx: PoolingServeContext): + if isinstance(ctx.request, CohereEmbedRequest): + self._pre_process_cohere_online(ctx) + else: + super().pre_process_online(ctx) + + if self.enable_chunked_processing: + self._pre_process_chunked(ctx) + + def post_process_online( + self, + ctx: PoolingServeContext, + ): + if ctx.final_res_batch is None: + raise ValueError("Final response batch not available") + + if not self.enable_chunked_processing: + self._enforce_cohere_max_tokens(ctx) + return super().post_process_online(ctx) + + self._post_process_chunked(ctx) + self._enforce_cohere_max_tokens(ctx) + ################################################################# # Long Text Embedding with Chunked Processing # PTAL: examples/pooling/embed/openai_embedding_long_text + ################################################################# - def pre_process_online(self, ctx: PoolingServeContext): - super().pre_process_online(ctx) - - if not self.enable_chunked_processing: - return None - + def _pre_process_chunked(self, ctx: PoolingServeContext) -> None: if ctx.engine_prompts is None: raise ValueError("Engine prompts not available") @@ -61,18 +113,10 @@ class 
EmbedIOProcessor(PoolingIOProcessor): ctx.engine_prompts = chunked_engine_prompts ctx.prompt_request_ids = prompt_request_ids + return None - def post_process_online( - self, - ctx: PoolingServeContext, - ): - if ctx.final_res_batch is None: - raise ValueError("Final response batch not available") - - if not self.enable_chunked_processing: - return super().post_process_online(ctx) - + def _post_process_chunked(self, ctx: PoolingServeContext) -> None: # Online aggregation for chunked requests to # minimize memory usage # Track aggregation state for each prompt @@ -195,4 +239,245 @@ class EmbedIOProcessor(PoolingIOProcessor): raise ValueError(f"Result not found for prompt {prompt_idx}") ctx.final_res_batch = final_res_batch + return None + + ################################################################# + # Cohere Request Preprocessing & Postprocessing + ################################################################# + + @staticmethod + def _load_task_instructions(hf_config: Any) -> dict[str, str] | None: + """Extract ``task_instructions`` from the HF model config.""" + ti = getattr(hf_config, "task_instructions", None) + if not isinstance(ti, dict) or not ti: + return None + return {k: v for k, v in ti.items() if isinstance(v, str)} + + @staticmethod + def _load_st_prompts( + model: str | Any, + revision: str | None, + ) -> dict[str, str] | None: + """Load ``task_instructions`` from ``config_sentence_transformers.json``.""" + from vllm.transformers_utils.repo_utils import get_hf_file_to_dict + + try: + cfg = get_hf_file_to_dict( + "config_sentence_transformers.json", str(model), revision + ) + except (ValueError, OSError): + return None + + if cfg is None: + return None + prompts = cfg.get("prompts") + if not isinstance(prompts, dict) or not prompts: + return None + return {k: v for k, v in prompts.items() if isinstance(v, str)} + + @staticmethod + def _mixed_input_to_messages( + inp: CohereEmbedInput, + *, + task_prefix: str | None = None, + ) -> 
list[ChatCompletionMessageParam]: + """Build chat messages from a mixed text+image input. + + When *task_prefix* is given, it is prepended to each text part. + """ + parts: list[ChatCompletionContentPartParam] = [] + for item in inp.content: + if item.type == "text" and item.text is not None: + text = task_prefix + item.text if task_prefix else item.text + parts.append(ChatCompletionContentPartTextParam(type="text", text=text)) + elif item.type == "image_url" and item.image_url is not None: + parts.append( + ChatCompletionContentPartImageParam( + type="image_url", + image_url=ImageURL(url=item.image_url["url"]), + ) + ) + return [CustomChatCompletionMessageParam(role="user", content=parts)] + + @staticmethod + def _check_cohere_max_tokens( + outputs: list[PoolingRequestOutput], + max_tokens_check: int | None, + ) -> None: + """Raise if any output exceeds *max_tokens_check* tokens. + + Used to enforce ``truncate=NONE`` with an explicit ``max_tokens``: + the pipeline runs without truncation and we reject afterwards. + """ + if max_tokens_check is None: + return + for out in outputs: + n = len(out.prompt_token_ids) + if n > max_tokens_check: + raise ValueError( + f"Input of {n} tokens exceeds max_tokens={max_tokens_check} " + "with truncate=NONE. Set truncate to END or START to " + "allow truncation." 
+ ) + + @staticmethod + def _resolve_cohere_truncation( + request: CohereEmbedRequest, + ) -> tuple[int | None, Literal["left", "right"] | None]: + """Return ``(truncate_prompt_tokens, truncation_side)``.""" + if request.truncate == "NONE": + return None, None + if request.truncate == "START": + tokens = request.max_tokens if request.max_tokens is not None else -1 + return tokens, "left" + if request.max_tokens is not None: + return request.max_tokens, None + return -1, None + + def create_pooling_params(self, request): + if isinstance(request, CohereEmbedRequest): + return PoolingParams( + task="embed", + dimensions=request.output_dimension, + ) + return super().create_pooling_params(request) + + def _pre_process_cohere_online(self, ctx: PoolingServeContext) -> None: + """Convert a ``CohereEmbedRequest`` into engine prompts. + + For texts, a single batched completion request path is used. + For images and mixed inputs, conversations are batch-rendered + through the chat template in one ``render_chat`` call. 
+ """ + request = ctx.request + assert isinstance(request, CohereEmbedRequest) + + if request.texts is None and request.images is None and request.inputs is None: + raise ValueError("One of texts, images, or inputs must be provided") + + truncate_prompt_tokens, truncation_side = self._resolve_cohere_truncation( + request + ) + input_type = request.input_type + self._validate_input_type(input_type) + + if request.images is not None: + all_messages: list[list[ChatCompletionMessageParam]] = [ + [ + CustomChatCompletionMessageParam( + role="user", + content=[{"type": "image_url", "image_url": {"url": uri}}], + ) + ] + for uri in request.images + ] + ctx.engine_prompts = self._batch_render_chat( + request, all_messages, truncate_prompt_tokens, truncation_side + ) + + elif request.inputs is not None: + task_prefix = self._get_task_instruction_prefix(input_type) + all_messages = [ + self._mixed_input_to_messages(inp, task_prefix=task_prefix) + for inp in request.inputs + ] + ctx.engine_prompts = self._batch_render_chat( + request, all_messages, truncate_prompt_tokens, truncation_side + ) + + else: + prefixed = self._apply_task_instruction(request.texts or [], input_type) + proxy = EmbeddingCompletionRequest( + model=request.model, + input=prefixed, + dimensions=request.output_dimension, + encoding_format="float", + truncate_prompt_tokens=truncate_prompt_tokens, + truncation_side=truncation_side, + ) + ctx.engine_prompts = self._preprocess_completion_online( + proxy, prompt_input=proxy.input, prompt_embeds=None + ) + + def _batch_render_chat( + self, + request: CohereEmbedRequest, + all_messages: Sequence[list[ChatCompletionMessageParam]], + truncate_prompt_tokens: int | None, + truncation_side: Literal["left", "right"] | None, + ) -> list[ProcessorInputs]: + """Batch-render multiple conversations through the chat template.""" + if not all_messages: + return [] + + proxy = EmbeddingChatRequest( + model=request.model, + messages=list(all_messages[0]), + 
dimensions=request.output_dimension, + encoding_format="float", + truncate_prompt_tokens=truncate_prompt_tokens, + truncation_side=truncation_side, + ) + + renderer = self.renderer + mm_config = self.model_config.multimodal_config + + tok_params = proxy.build_tok_params(self.model_config) + chat_params = proxy.build_chat_params( + self.chat_template, + self.chat_template_content_format, + ).with_defaults( + merge_kwargs( + None, + dict( + tools=None, + tokenize=is_mistral_tokenizer(renderer.tokenizer), + ), + ), + default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None), + ) + + _, engine_prompts = renderer.render_chat(all_messages, chat_params, tok_params) + return engine_prompts + + def _validate_input_type(self, input_type: str | None) -> None: + """Raise if *input_type* is not supported by this model.""" + if input_type is None: + return + if self.task_instructions is None: + raise ValueError( + f"Unsupported input_type {input_type!r}. " + "This model does not define any input_type task instructions." + ) + if input_type not in self.task_instructions: + supported = ", ".join(sorted(self.task_instructions)) + raise ValueError( + f"Unsupported input_type {input_type!r}. Supported values: {supported}" + ) + + def _apply_task_instruction( + self, + texts: list[str], + input_type: str | None, + ) -> list[str]: + """Prepend the task-instruction prefix for *input_type*. + + Returns *texts* unchanged when no matching prefix is configured. 
+ """ + prefix = self._get_task_instruction_prefix(input_type) + if not prefix: + return texts + return [prefix + t for t in texts] + + def _get_task_instruction_prefix(self, input_type: str | None) -> str | None: + """Return the task-instruction prefix for *input_type*, or ``None``.""" + if not self.task_instructions or input_type is None: + return None + return self.task_instructions.get(input_type) or None + + def _enforce_cohere_max_tokens(self, ctx: PoolingServeContext) -> None: + if isinstance(ctx.request, CohereEmbedRequest): + request = ctx.request + if request.truncate == "NONE" and request.max_tokens is not None: + self._check_cohere_max_tokens(ctx.final_res_batch, request.max_tokens) diff --git a/vllm/entrypoints/pooling/embed/protocol.py b/vllm/entrypoints/pooling/embed/protocol.py index 4b47c6522..b02f91dfa 100644 --- a/vllm/entrypoints/pooling/embed/protocol.py +++ b/vllm/entrypoints/pooling/embed/protocol.py @@ -1,9 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time -from typing import TypeAlias +"""Embedding API protocol models for OpenAI and Cohere formats. 
-from pydantic import Field +OpenAI: https://platform.openai.com/docs/api-reference/embeddings +Cohere: https://docs.cohere.com/reference/embed +""" + +import base64 +import builtins +import struct +import time +from collections.abc import Sequence +from typing import Literal, TypeAlias + +from pydantic import BaseModel, Field from vllm import PoolingParams from vllm.config import ModelConfig @@ -17,6 +27,10 @@ from vllm.entrypoints.pooling.base.protocol import ( from vllm.renderers import TokenizeParams from vllm.utils import random_uuid +# --------------------------------------------------------------------------- +# OpenAI /v1/embeddings — request models +# --------------------------------------------------------------------------- + def _get_max_total_output_tokens( model_config: ModelConfig, @@ -50,6 +64,7 @@ class EmbeddingCompletionRequest( max_total_tokens=max_total_tokens, max_output_tokens=max_output_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), add_special_tokens=self.add_special_tokens, max_total_tokens_param="max_model_len", @@ -79,6 +94,7 @@ class EmbeddingChatRequest( max_total_tokens=max_total_tokens, max_output_tokens=max_output_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), add_special_tokens=self.add_special_tokens, max_total_tokens_param="max_model_len", @@ -96,6 +112,11 @@ class EmbeddingChatRequest( EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest +# --------------------------------------------------------------------------- +# OpenAI /v1/embeddings — response models +# --------------------------------------------------------------------------- + + class EmbeddingResponseData(OpenAIBaseModel): index: int object: str = "embedding" @@ -106,7 +127,7 @@ class EmbeddingResponse(OpenAIBaseModel): id: 
str = Field(default_factory=lambda: f"embd-{random_uuid()}") object: str = "list" created: int = Field(default_factory=lambda: int(time.time())) - model: str + model: str | None = None data: list[EmbeddingResponseData] usage: UsageInfo @@ -115,3 +136,146 @@ class EmbeddingBytesResponse(OpenAIBaseModel): content: list[bytes] headers: dict[str, str] | None = None media_type: str = "application/octet-stream" + + +# --------------------------------------------------------------------------- +# Cohere /v2/embed — request models +# --------------------------------------------------------------------------- + +CohereEmbeddingType = Literal[ + "float", + "binary", + "ubinary", + "base64", +] +CohereTruncate = Literal["NONE", "START", "END"] + + +class CohereEmbedContent(BaseModel): + type: Literal["text", "image_url"] + text: str | None = None + image_url: dict[str, str] | None = None + + +class CohereEmbedInput(BaseModel): + content: list[CohereEmbedContent] + + +class CohereEmbedRequest(BaseModel): + model: str | None = None + input_type: str | None = None + texts: list[str] | None = None + images: list[str] | None = None + inputs: list[CohereEmbedInput] | None = None + output_dimension: int | None = None + embedding_types: list[CohereEmbeddingType] | None = None + truncate: CohereTruncate = "END" + max_tokens: int | None = None + priority: int = 0 + + +# --------------------------------------------------------------------------- +# Cohere /v2/embed — response models +# --------------------------------------------------------------------------- + + +class CohereApiVersion(BaseModel): + version: str = "2" + + +class CohereBilledUnits(BaseModel): + input_tokens: int | None = None + image_tokens: int | None = None + + +class CohereMeta(BaseModel): + api_version: CohereApiVersion = Field(default_factory=CohereApiVersion) + billed_units: CohereBilledUnits | None = None + + +class CohereEmbedByTypeEmbeddings(BaseModel): + # The field name ``float`` shadows the builtin type, so 
the annotation + # must use ``builtins.float`` to avoid a self-referential type error. + float: list[list[builtins.float]] | None = None + binary: list[list[int]] | None = None + ubinary: list[list[int]] | None = None + base64: list[str] | None = None + + +class CohereEmbedResponse(BaseModel): + id: str = Field(default_factory=lambda: f"embd-{random_uuid()}") + embeddings: CohereEmbedByTypeEmbeddings + texts: list[str] | None = None + meta: CohereMeta | None = None + response_type: Literal["embeddings_by_type"] = "embeddings_by_type" + + +# --------------------------------------------------------------------------- +# Cohere embedding type conversion helpers +# --------------------------------------------------------------------------- + +_UNSIGNED_TO_SIGNED_DIFF = 1 << 7 # 128 + + +def _pack_binary_embeddings( + float_embeddings: list[list[float]], + signed: bool, +) -> list[list[int]]: + """Bit-pack float embeddings: positive -> 1, negative -> 0. + + Each bit is shifted left by ``7 - idx%8``, and every 8 bits are packed + into one byte. + """ + result: list[list[int]] = [] + for embedding in float_embeddings: + dim = len(embedding) + if dim % 8 != 0: + raise ValueError( + "Embedding dimension must be a multiple of 8 for binary " + f"embedding types, but got {dim}." 
+ ) + packed_len = dim // 8 + packed: list[int] = [] + byte_val = 0 + for idx, value in enumerate(embedding): + bit = 1 if value >= 0 else 0 + byte_val += bit << (7 - idx % 8) + if (idx + 1) % 8 == 0: + if signed: + byte_val -= _UNSIGNED_TO_SIGNED_DIFF + packed.append(byte_val) + byte_val = 0 + assert len(packed) == packed_len + result.append(packed) + return result + + +def _encode_base64_embeddings( + float_embeddings: list[list[float]], +) -> list[str]: + """Encode float embeddings as base64 (little-endian float32).""" + result: list[str] = [] + for embedding in float_embeddings: + buf = struct.pack(f"<{len(embedding)}f", *embedding) + result.append(base64.b64encode(buf).decode("utf-8")) + return result + + +def build_typed_embeddings( + float_embeddings: list[list[float]], + embedding_types: Sequence[str], +) -> CohereEmbedByTypeEmbeddings: + """Convert float embeddings to all requested Cohere embedding types.""" + result = CohereEmbedByTypeEmbeddings() + + for emb_type in embedding_types: + if emb_type == "float": + result.float = float_embeddings + elif emb_type == "binary": + result.binary = _pack_binary_embeddings(float_embeddings, signed=True) + elif emb_type == "ubinary": + result.ubinary = _pack_binary_embeddings(float_embeddings, signed=False) + elif emb_type == "base64": + result.base64 = _encode_base64_embeddings(float_embeddings) + + return result diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index c4ecf2683..f0c331645 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -5,7 +5,7 @@ from collections.abc import Callable from functools import partial from typing import Literal, TypeAlias, cast -from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.responses import JSONResponse, Response, StreamingResponse from typing_extensions import assert_never from vllm.config import ModelConfig @@ -14,10 +14,15 @@ from 
vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.pooling.base.serving import PoolingServing from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor from vllm.entrypoints.pooling.embed.protocol import ( + CohereBilledUnits, + CohereEmbedRequest, + CohereEmbedResponse, + CohereMeta, EmbeddingBytesResponse, EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, + build_typed_embeddings, ) from vllm.entrypoints.pooling.typing import PoolingServeContext from vllm.entrypoints.pooling.utils import ( @@ -26,24 +31,23 @@ from vllm.entrypoints.pooling.utils import ( encode_pooling_output_float, get_json_response_cls, ) +from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput from vllm.renderers import BaseRenderer from vllm.utils.serial_utils import EmbedDType, Endianness +logger = init_logger(__name__) + JSONResponseCLS = get_json_response_cls() EmbeddingServeContext: TypeAlias = PoolingServeContext[EmbeddingRequest] class ServingEmbedding(PoolingServing): - """ - Embedding API similar to OpenAI's API. - - See https://platform.openai.com/docs/api-reference/embeddings/create - for the API specification. This API mimics the OpenAI Embedding API. 
- """ + """Embedding API supporting both OpenAI and Cohere formats.""" request_id_prefix = "embd" + io_processor: EmbedIOProcessor def init_io_processor( self, @@ -58,6 +62,14 @@ class ServingEmbedding(PoolingServing): ) async def _build_response( + self, + ctx: PoolingServeContext, + ) -> Response: + if isinstance(ctx.request, CohereEmbedRequest): + return self._build_cohere_response_from_ctx(ctx) + return await self._build_openai_response(ctx) + + async def _build_openai_response( self, ctx: EmbeddingServeContext, ) -> JSONResponse | StreamingResponse: @@ -66,7 +78,7 @@ class ServingEmbedding(PoolingServing): endianness = ctx.request.endianness if encoding_format == "float" or encoding_format == "base64": - return self._request_output_to_embed_json_response( + return self._openai_json_response( ctx.final_res_batch, ctx.request_id, ctx.created_time, @@ -77,7 +89,7 @@ class ServingEmbedding(PoolingServing): ) if encoding_format == "bytes" or encoding_format == "bytes_only": - return self._request_output_to_to_embed_bytes_response( + return self._openai_bytes_response( ctx.final_res_batch, ctx.request_id, ctx.created_time, @@ -89,7 +101,7 @@ class ServingEmbedding(PoolingServing): assert_never(encoding_format) - def _request_output_to_embed_json_response( + def _openai_json_response( self, final_res_batch: list[PoolingRequestOutput], request_id: str, @@ -139,7 +151,7 @@ class ServingEmbedding(PoolingServing): ) return JSONResponseCLS(content=response.model_dump()) - def _request_output_to_to_embed_bytes_response( + def _openai_bytes_response( self, final_res_batch: list[PoolingRequestOutput], request_id: str, @@ -177,3 +189,33 @@ class ServingEmbedding(PoolingServing): headers=response.headers, media_type=response.media_type, ) + + @staticmethod + def _build_cohere_response_from_ctx( + ctx: PoolingServeContext, + ) -> JSONResponse: + request = ctx.request + assert isinstance(request, CohereEmbedRequest) + + all_floats = [encode_pooling_output_float(out) for out in 
ctx.final_res_batch] + total_tokens = sum(len(out.prompt_token_ids) for out in ctx.final_res_batch) + + image_tokens = total_tokens if request.images is not None else 0 + texts_echo = request.texts + + embedding_types = request.embedding_types or ["float"] + embeddings_obj = build_typed_embeddings(all_floats, embedding_types) + + input_tokens = total_tokens - image_tokens + response = CohereEmbedResponse( + id=ctx.request_id, + embeddings=embeddings_obj, + texts=texts_echo, + meta=CohereMeta( + billed_units=CohereBilledUnits( + input_tokens=input_tokens, + image_tokens=image_tokens, + ), + ), + ) + return JSONResponse(content=response.model_dump(exclude_none=True)) diff --git a/vllm/entrypoints/pooling/pooling/protocol.py b/vllm/entrypoints/pooling/pooling/protocol.py index b99f98959..098690db2 100644 --- a/vllm/entrypoints/pooling/pooling/protocol.py +++ b/vllm/entrypoints/pooling/pooling/protocol.py @@ -36,6 +36,7 @@ class PoolingCompletionRequest( max_total_tokens=model_config.max_model_len, max_output_tokens=0, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), add_special_tokens=self.add_special_tokens, max_total_tokens_param="max_model_len", @@ -61,6 +62,7 @@ class PoolingChatRequest( max_total_tokens=model_config.max_model_len, max_output_tokens=0, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), add_special_tokens=self.add_special_tokens, max_total_tokens_param="max_model_len", @@ -88,6 +90,7 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic max_total_tokens=model_config.max_model_len, max_output_tokens=0, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), add_special_tokens=not model_config.is_encoder_decoder, 
max_total_tokens_param="max_model_len", diff --git a/vllm/entrypoints/pooling/score/protocol.py b/vllm/entrypoints/pooling/score/protocol.py index 643eeed36..2aea1bd7b 100644 --- a/vllm/entrypoints/pooling/score/protocol.py +++ b/vllm/entrypoints/pooling/score/protocol.py @@ -30,6 +30,7 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin): max_total_tokens=model_config.max_model_len, max_output_tokens=0, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), max_total_tokens_param="max_model_len", ) @@ -105,6 +106,7 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin): max_total_tokens=model_config.max_model_len, max_output_tokens=0, truncate_prompt_tokens=self.truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=encoder_config.get("do_lower_case", False), max_total_tokens_param="max_model_len", ) diff --git a/vllm/entrypoints/pooling/typing.py b/vllm/entrypoints/pooling/typing.py index 74ed9b50c..f9f361824 100644 --- a/vllm/entrypoints/pooling/typing.py +++ b/vllm/entrypoints/pooling/typing.py @@ -15,6 +15,7 @@ from vllm.entrypoints.pooling.classify.protocol import ( ClassificationResponse, ) from vllm.entrypoints.pooling.embed.protocol import ( + CohereEmbedRequest, EmbeddingBytesResponse, EmbeddingChatRequest, EmbeddingCompletionRequest, @@ -50,6 +51,7 @@ AnyPoolingRequest: TypeAlias = ( | IOProcessorRequest | RerankRequest | ScoreRequest + | CohereEmbedRequest ) AnyPoolingResponse: TypeAlias = ( diff --git a/vllm/renderers/params.py b/vllm/renderers/params.py index 54da0f3b5..a2c95690c 100644 --- a/vllm/renderers/params.py +++ b/vllm/renderers/params.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, 
Literal, TypeVar from vllm.exceptions import VLLMValidationError from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt @@ -153,6 +153,14 @@ class TokenizeParams: - `-1` maps to `max_input_tokens`. """ + truncation_side: Literal["left", "right"] | None = None + """ + Which side to truncate from when ``truncate_prompt_tokens`` is active: + - ``"right"`` keeps the first N tokens (truncate from the end). + - ``"left"`` keeps the last N tokens (truncate from the start). + - ``None`` falls back to the tokenizer default. + """ + do_lower_case: bool = False """Whether to normalize text to lower case before tokenization.""" @@ -271,6 +279,7 @@ class TokenizeParams: ), pad_prompt_tokens=pad_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens, + truncation_side=self.truncation_side, do_lower_case=do_lower_case, add_special_tokens=add_special_tokens, needs_detokenization=needs_detokenization, @@ -286,6 +295,16 @@ class TokenizeParams: # while still failing `self._token_len_check` as expected by users max_length = self.max_input_tokens + 1 + # Left-side truncation requires the full token sequence so we can + # slice from the end in _token_truncation. Disable HF-level + # truncation (which would incorrectly truncate from the right for + # pooling models) and let _token_truncation handle it. + if self.truncation_side == "left": + return dict( + truncation=False, + add_special_tokens=self.add_special_tokens, + ) + return dict( truncation=max_length is not None, max_length=max_length, @@ -375,7 +394,10 @@ class TokenizeParams: if max_length == 0: return tokens[:0] - if getattr(tokenizer, "truncation_side", "left") == "left": + side = self.truncation_side or ( + tokenizer.truncation_side if tokenizer is not None else None + ) + if side == "left": return tokens[-max_length:] return tokens[:max_length]