[Feature][Frontend] add support for Cohere Embed v2 API (#37074)
Signed-off-by: walterbm <walter.beller.morales@gmail.com>
(cherry picked from commit 061980c36a)
Committed by: khluu
Parent: 1fe3932c8b
Commit: 4d22667c32
@@ -72,6 +72,9 @@ In addition, we have the following custom APIs:
    - Only applicable to [classification models](../models/pooling_models.md).
- [Score API](#score-api) (`/score`)
    - Applicable to [embedding models and cross-encoder models](../models/pooling_models.md).
- [Cohere Embed API](#cohere-embed-api) (`/v2/embed`)
    - Compatible with [Cohere's Embed API](https://docs.cohere.com/reference/embed)
    - Works with any [embedding model](../models/pooling_models.md), including multimodal models.
- [Re-rank API](#re-rank-api) (`/rerank`, `/v1/rerank`, `/v2/rerank`)
    - Implements [Jina AI's v1 re-rank API](https://jina.ai/reranker/)
    - Also compatible with [Cohere's v1 & v2 re-rank APIs](https://docs.cohere.com/v2/reference/rerank)
@@ -429,6 +432,137 @@ these extra parameters are supported instead:
--8<-- "vllm/entrypoints/pooling/base/protocol.py:embed-extra-params"
```

### Cohere Embed API

Our API is also compatible with [Cohere's Embed v2 API](https://docs.cohere.com/reference/embed), which adds support for modern embedding features such as truncation, output dimensions, embedding types, and input types. This endpoint works with any embedding model, including multimodal models.

#### Cohere Embed API request parameters

| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `model` | string | Yes | Model name |
| `input_type` | string | No | Prompt prefix key (model-dependent, see below) |
| `texts` | list[string] | No | Text inputs (use one of `texts`, `images`, or `inputs`) |
| `images` | list[string] | No | Base64 data URI images |
| `inputs` | list[object] | No | Mixed text and image content objects |
| `embedding_types` | list[string] | No | Output types (default: `["float"]`) |
| `output_dimension` | int | No | Truncate embeddings to this dimension (Matryoshka) |
| `truncate` | string | No | `END`, `START`, or `NONE` (default: `END`) |
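
As a rough client-side sketch of how the parameters above fit together, the following Python helper assembles a request body. The helper name `build_embed_request` and the "exactly one input field" check are illustrative assumptions and are not part of vLLM; the server performs its own validation.

```python
from __future__ import annotations

from typing import Any

# Values listed for `truncate` in the parameter table.
_TRUNCATE_VALUES = ("END", "START", "NONE")


def build_embed_request(
    model: str,
    *,
    texts: list[str] | None = None,
    images: list[str] | None = None,
    inputs: list[dict[str, Any]] | None = None,
    input_type: str | None = None,
    embedding_types: list[str] | None = None,
    output_dimension: int | None = None,
    truncate: str | None = None,
) -> dict[str, Any]:
    """Assemble a JSON body for POST /v2/embed (hypothetical helper)."""
    sources = {"texts": texts, "images": images, "inputs": inputs}
    provided = {key: value for key, value in sources.items() if value is not None}
    # Assumption: interpret "use one of" as "exactly one" on the client side.
    if len(provided) != 1:
        raise ValueError("provide exactly one of: texts, images, inputs")
    if truncate is not None and truncate not in _TRUNCATE_VALUES:
        raise ValueError(f"truncate must be one of {_TRUNCATE_VALUES}")

    body: dict[str, Any] = {"model": model, **provided}
    optional = {
        "input_type": input_type,
        "embedding_types": embedding_types,
        "output_dimension": output_dimension,
        "truncate": truncate,
    }
    # Omit unset fields so the server applies its documented defaults.
    body.update({key: value for key, value in optional.items() if value is not None})
    return body


body = build_embed_request(
    "Snowflake/snowflake-arctic-embed-m-v1.5",
    texts=["Hello world", "How are you?"],
    input_type="query",
    embedding_types=["float"],
)
```

The resulting `body` can be sent with `requests.post(url, json=body)` against `/v2/embed`, which is how the tests in this commit call the endpoint.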

#### Text embedding

```bash
curl -X POST "http://localhost:8000/v2/embed" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Snowflake/snowflake-arctic-embed-m-v1.5",
    "input_type": "query",
    "texts": ["Hello world", "How are you?"],
    "embedding_types": ["float"]
  }'
```

??? console "Response"

    ```json
    {
      "id": "embd-...",
      "embeddings": {
        "float": [
          [0.012, -0.034, ...],
          [0.056, 0.078, ...]
        ]
      },
      "texts": ["Hello world", "How are you?"],
      "meta": {
        "api_version": {"version": "2"},
        "billed_units": {"input_tokens": 12}
      }
    }
    ```
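
To illustrate how a client might consume this response shape, here is a small Python sketch that pairs each echoed text with its `float` vector and compares two vectors by cosine similarity. The function names and the two-dimensional sample values are made up for illustration; real embeddings are much longer.

```python
from __future__ import annotations

import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def float_embeddings_by_text(response: dict) -> dict[str, list[float]]:
    """Map each echoed input text to its float embedding, in request order."""
    return dict(zip(response["texts"], response["embeddings"]["float"]))


# Hand-written stand-in for a /v2/embed response (values are not real output).
sample_response = {
    "id": "embd-example",
    "embeddings": {"float": [[0.6, 0.8], [0.8, 0.6]]},
    "texts": ["Hello world", "How are you?"],
    "meta": {
        "api_version": {"version": "2"},
        "billed_units": {"input_tokens": 12},
    },
}

by_text = float_embeddings_by_text(sample_response)
similarity = cosine_similarity(by_text["Hello world"], by_text["How are you?"])
```

The tests added in this commit assert that `float` embeddings have unit L2 norm for the models they cover, in which case the cosine similarity reduces to a plain dot product.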

#### Mixed text and image inputs

For multimodal models, you can embed images by passing base64 data URIs. The `inputs` field accepts a list of objects with mixed text and image content:

```bash
curl -X POST "http://localhost:8000/v2/embed" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "google/siglip-so400m-patch14-384",
    "inputs": [
      {
        "content": [
          {"type": "text", "text": "A photo of a cat"},
          {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}}
        ]
      }
    ],
    "embedding_types": ["float"]
  }'
```
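
The data URI in the request above can be produced from raw image bytes. The following Python sketch shows one way to do that and to wrap the result in the `content` structure from the curl example; the helper names `image_data_uri` and `mixed_input` are hypothetical.

```python
import base64


def image_data_uri(image_bytes: bytes, mime_type: str = "image/png") -> str:
    """Encode raw image bytes as a base64 data URI."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"


def mixed_input(text: str, image_bytes: bytes) -> dict:
    """Build one `inputs` entry holding a text part and an image part."""
    return {
        "content": [
            {"type": "text", "text": text},
            {"type": "image_url", "image_url": {"url": image_data_uri(image_bytes)}},
        ]
    }


# The 8-byte PNG signature stands in for a full image file here.
png_signature = b"\x89PNG\r\n\x1a\n"
entry = mixed_input("A photo of a cat", png_signature)
```

In practice `image_bytes` would be the full contents of an image file, for example `pathlib.Path("cat.png").read_bytes()`.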

#### Embedding types

The `embedding_types` parameter controls the output format. Multiple types can be requested in a single call:

| Type | Description |
| ---- | ----------- |
| `float` | Raw float32 embeddings (default) |
| `binary` | Bit-packed signed binary |
| `ubinary` | Bit-packed unsigned binary |
| `base64` | Little-endian float32 encoded as base64 |

```bash
curl -X POST "http://localhost:8000/v2/embed" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Snowflake/snowflake-arctic-embed-m-v1.5",
    "input_type": "query",
    "texts": ["What is machine learning?"],
    "embedding_types": ["float", "binary"]
  }'
```

??? console "Response"

    ```json
    {
      "id": "embd-...",
      "embeddings": {
        "float": [[0.012, -0.034, ...]],
        "binary": [[42, -117, ...]]
      },
      "texts": ["What is machine learning?"],
      "meta": {
        "api_version": {"version": "2"},
        "billed_units": {"input_tokens": 8}
      }
    }
    ```
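
On the client side, a `base64` embedding can be turned back into floats by decoding it as little-endian float32, which mirrors what the test for this type in this commit does. The sketch below uses only the standard library; the function names are illustrative. `packed_length` reflects the tests' expectation that `binary` and `ubinary` outputs contain one byte per eight float dimensions.

```python
import base64
import struct


def decode_base64_embedding(encoded: str) -> list[float]:
    """Decode a base64 string of little-endian float32 values."""
    raw = base64.b64decode(encoded)
    if len(raw) % 4 != 0:
        raise ValueError("payload is not a whole number of float32 values")
    count = len(raw) // 4
    return list(struct.unpack(f"<{count}f", raw))


def packed_length(dimension: int) -> int:
    """Entries per `binary`/`ubinary` embedding for a given float dimension."""
    return dimension // 8


# Round trip with values that float32 represents exactly.
example = base64.b64encode(struct.pack("<3f", 0.5, -0.25, 1.0)).decode("ascii")
decoded = decode_base64_embedding(example)
```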

#### Truncation

The `truncate` parameter controls how inputs exceeding the model's maximum sequence length are handled:

| Value | Behavior |
| ----- | --------- |
| `END` (default) | Keep the first tokens, drop the end |
| `START` | Keep the last tokens, drop the beginning |
| `NONE` | Return an error if the input is too long |
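
The three modes can be pictured as operations on a list of token IDs. The Python sketch below is a conceptual illustration of the table only, not vLLM's implementation; the function name `truncate_tokens` and its signature are assumptions.

```python
def truncate_tokens(token_ids: list[int], max_len: int, mode: str = "END") -> list[int]:
    """Apply END / START / NONE semantics to a token list (illustrative)."""
    if max_len <= 0:
        raise ValueError("max_len must be positive")
    if mode not in ("END", "START", "NONE"):
        raise ValueError(f"unknown truncate mode: {mode!r}")
    if len(token_ids) <= max_len:
        # Short inputs pass through unchanged in every mode.
        return list(token_ids)
    if mode == "END":
        return list(token_ids[:max_len])  # keep the first tokens
    if mode == "START":
        return list(token_ids[-max_len:])  # keep the last tokens
    raise ValueError(f"input has {len(token_ids)} tokens, limit is {max_len}")


tokens = list(range(10))
kept_head = truncate_tokens(tokens, 4, "END")
kept_tail = truncate_tokens(tokens, 4, "START")
```

With `NONE`, the same over-long input raises instead of being shortened, which corresponds to the error response described in the table.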

#### Input type and prompt prefixes

The `input_type` field selects a prompt prefix to prepend to each text input. The available values
depend on the model:

- **Models with `task_instructions` in `config.json`**: The keys from the `task_instructions` dict are
  the valid `input_type` values and the corresponding value is prepended to each text.
- **Models with `config_sentence_transformers.json` prompts**: The keys from the `prompts` dict are
  the valid `input_type` values. For example, `Snowflake/snowflake-arctic-embed-xs` defines `"query"`,
  so setting `input_type: "query"` prepends `"Represent this sentence for searching relevant passages: "`.
- **Other models**: `input_type` is not accepted and will raise a validation error if passed.
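
The rules in this list can be sketched as a single function. This is a simplified stand-in written for illustration, loosely following the behavior exercised by the unit tests in this commit; the name `apply_input_type` and the exact error wording are assumptions rather than vLLM's API.

```python
from __future__ import annotations


def apply_input_type(
    texts: list[str],
    input_type: str | None,
    task_instructions: dict[str, str] | None,
) -> list[str]:
    """Validate `input_type` and prepend the matching prefix (illustrative)."""
    if input_type is None:
        # Omitting input_type is always allowed; no prefix is applied.
        return texts
    if not task_instructions:
        raise ValueError("this model does not define any input_type values")
    if input_type not in task_instructions:
        supported = ", ".join(task_instructions)
        raise ValueError(
            f"Unsupported input_type {input_type!r}. Supported values: {supported}"
        )
    prefix = task_instructions[input_type]
    return [prefix + text for text in texts] if prefix else texts


prompts = {"query": "Represent this sentence for searching relevant passages: "}
prefixed = apply_input_type(["What is machine learning?"], "query", prompts)
```

Here `prompts` plays the role of the `task_instructions` or `prompts` dict loaded from the model's configuration.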

### Transcriptions API

Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);

310
tests/entrypoints/pooling/embed/test_cohere_online.py
Normal file
@@ -0,0 +1,310 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the Cohere /v2/embed API with generic (non-Cohere) models.

Validates that the Cohere v2 embed endpoint works correctly with standard
embedding models, covering text embedding, embedding type conversions,
response structure, batching, normalisation, and semantic similarity.
"""

import base64
import struct

import numpy as np
import pytest
import requests

from tests.utils import RemoteOpenAIServer

DTYPE = "bfloat16"

MODELS: list[tuple[str, list[str]]] = [
    ("intfloat/multilingual-e5-small", []),
    (
        "Snowflake/snowflake-arctic-embed-m-v1.5",
        [
            "--trust_remote_code",
            "--hf_overrides",
            '{"matryoshka_dimensions":[256]}',
        ],
    ),
]


@pytest.fixture(scope="module", params=MODELS, ids=lambda m: m[0])
def model_config(request):
    return request.param


@pytest.fixture(scope="module")
def model_name(model_config):
    return model_config[0]


@pytest.fixture(scope="module")
def server(model_config):
    name, extra_args = model_config
    args = [
        "--runner",
        "pooling",
        "--dtype",
        DTYPE,
        "--enforce-eager",
        "--max-model-len",
        "512",
        "--gpu-memory-utilization",
        "0.02",
    ] + extra_args
    with RemoteOpenAIServer(name, args) as remote_server:
        yield remote_server


def _cohere_embed(
    server: RemoteOpenAIServer,
    model_name: str,
    texts: list[str] | None = None,
    images: list[str] | None = None,
    input_type: str | None = None,
    embedding_types: list[str] | None = None,
) -> dict:
    body: dict = {"model": model_name}
    if input_type is not None:
        body["input_type"] = input_type
    if texts is not None:
        body["texts"] = texts
    if images is not None:
        body["images"] = images
    if embedding_types is not None:
        body["embedding_types"] = embedding_types
    resp = requests.post(server.url_for("/v2/embed"), json=body)
    resp.raise_for_status()
    return resp.json()


def _openai_embed(
    server: RemoteOpenAIServer, model_name: str, texts: list[str]
) -> dict:
    body = {"model": model_name, "input": texts, "encoding_format": "float"}
    resp = requests.post(server.url_for("/v1/embeddings"), json=body)
    resp.raise_for_status()
    return resp.json()


def _cosine_sim(a: list[float], b: list[float]) -> float:
    va, vb = np.array(a), np.array(b)
    return float(np.dot(va, vb) / (np.linalg.norm(va) * np.linalg.norm(vb)))


# -----------------------------------------------------------
# Text embedding tests
# -----------------------------------------------------------


def test_basic_embed(server: RemoteOpenAIServer, model_name: str):
    r = _cohere_embed(
        server, model_name, texts=["hello world"], embedding_types=["float"]
    )
    assert "embeddings" in r
    assert len(r["embeddings"]["float"]) == 1
    assert len(r["embeddings"]["float"][0]) > 0


def test_unsupported_input_type_rejected(server: RemoteOpenAIServer, model_name: str):
    """An input_type not defined in the model's prompt config should be
    rejected with a 400 error."""
    body = {
        "model": model_name,
        "input_type": "nonexistent_type",
        "texts": ["hello world"],
        "embedding_types": ["float"],
    }
    resp = requests.post(server.url_for("/v2/embed"), json=body)
    assert resp.status_code == 400
    assert "Unsupported input_type" in resp.json()["error"]["message"]


def test_omitted_input_type_accepted(server: RemoteOpenAIServer, model_name: str):
    """Omitting input_type should always work (no prompt prefix applied)."""
    body = {
        "model": model_name,
        "texts": ["hello world"],
        "embedding_types": ["float"],
    }
    resp = requests.post(server.url_for("/v2/embed"), json=body)
    assert resp.status_code == 200
    data = resp.json()
    assert len(data["embeddings"]["float"]) == 1


def test_v1_v2_parity(server: RemoteOpenAIServer, model_name: str):
    """v1 (OpenAI) and v2 (Cohere) endpoints should produce the same
    float embeddings for a generic model."""
    texts = ["hello world"]
    v2 = _cohere_embed(server, model_name, texts=texts, embedding_types=["float"])
    v1 = _openai_embed(server, model_name, texts)
    cos = _cosine_sim(v2["embeddings"]["float"][0], v1["data"][0]["embedding"])
    assert cos > 0.9999, f"v1/v2 parity failed, cosine={cos}"


def test_embedding_types(server: RemoteOpenAIServer, model_name: str):
    r = _cohere_embed(
        server,
        model_name,
        texts=["test"],
        embedding_types=["float", "binary", "ubinary"],
    )
    dim = len(r["embeddings"]["float"][0])
    assert len(r["embeddings"]["binary"][0]) == dim // 8
    assert len(r["embeddings"]["ubinary"][0]) == dim // 8


def test_response_structure(server: RemoteOpenAIServer, model_name: str):
    r = _cohere_embed(server, model_name, texts=["test"], embedding_types=["float"])
    assert "id" in r
    assert "embeddings" in r
    assert "texts" in r
    assert r["texts"] == ["test"]
    assert "meta" in r
    assert r["meta"]["api_version"]["version"] == "2"
    assert "billed_units" in r["meta"]
    assert r["meta"]["billed_units"]["input_tokens"] > 0
    assert r["meta"]["billed_units"]["image_tokens"] == 0


def test_batch(server: RemoteOpenAIServer, model_name: str):
    texts = ["apple", "banana", "cherry"]
    r = _cohere_embed(server, model_name, texts=texts, embedding_types=["float"])
    assert len(r["embeddings"]["float"]) == 3
    dim = len(r["embeddings"]["float"][0])
    for emb in r["embeddings"]["float"]:
        assert len(emb) == dim


def test_l2_normalized(server: RemoteOpenAIServer, model_name: str):
    r = _cohere_embed(
        server, model_name, texts=["hello world"], embedding_types=["float"]
    )
    emb = np.array(r["embeddings"]["float"][0])
    assert abs(float(np.linalg.norm(emb)) - 1.0) < 0.01


def test_semantic_similarity(server: RemoteOpenAIServer, model_name: str):
    r = _cohere_embed(
        server,
        model_name,
        texts=["machine learning", "deep learning", "chocolate cake recipe"],
        embedding_types=["float"],
    )
    embs = r["embeddings"]["float"]
    cos_related = _cosine_sim(embs[0], embs[1])
    cos_unrelated = _cosine_sim(embs[0], embs[2])
    assert cos_related > cos_unrelated


def test_missing_input_returns_error(server: RemoteOpenAIServer, model_name: str):
    body = {"model": model_name}
    resp = requests.post(server.url_for("/v2/embed"), json=body)
    assert resp.status_code == 400


def test_base64_embedding_type(server: RemoteOpenAIServer, model_name: str):
    r = _cohere_embed(
        server,
        model_name,
        texts=["test encoding"],
        embedding_types=["float", "base64"],
    )
    float_emb = r["embeddings"]["float"][0]
    b64_str = r["embeddings"]["base64"][0]
    decoded = struct.unpack(f"<{len(float_emb)}f", base64.b64decode(b64_str))
    np.testing.assert_allclose(float_emb, decoded, rtol=1e-5)


# -----------------------------------------------------------
# Truncation tests
# -----------------------------------------------------------


def _cohere_embed_raw(
    server: RemoteOpenAIServer,
    body: dict,
) -> requests.Response:
    return requests.post(server.url_for("/v2/embed"), json=body)


def test_truncate_end_succeeds(server: RemoteOpenAIServer, model_name: str):
    """truncate=END should silently truncate long input."""
    long_text = " ".join(["word"] * 2000)
    body = {
        "model": model_name,
        "texts": [long_text],
        "embedding_types": ["float"],
        "truncate": "END",
    }
    resp = _cohere_embed_raw(server, body)
    assert resp.status_code == 200
    data = resp.json()
    assert len(data["embeddings"]["float"]) == 1


def test_truncate_start_succeeds(server: RemoteOpenAIServer, model_name: str):
    """truncate=START should silently truncate long input from the start."""
    long_text = " ".join(["word"] * 2000)
    body = {
        "model": model_name,
        "texts": [long_text],
        "embedding_types": ["float"],
        "truncate": "START",
    }
    resp = _cohere_embed_raw(server, body)
    assert resp.status_code == 200
    data = resp.json()
    assert len(data["embeddings"]["float"]) == 1


def test_truncate_none_rejects_long_input(server: RemoteOpenAIServer, model_name: str):
    """truncate=NONE should error when input exceeds model context."""
    long_text = " ".join(["word"] * 2000)
    body = {
        "model": model_name,
        "texts": [long_text],
        "embedding_types": ["float"],
        "truncate": "NONE",
    }
    resp = _cohere_embed_raw(server, body)
    assert resp.status_code == 400


def test_truncate_start_vs_end_differ(server: RemoteOpenAIServer, model_name: str):
    """START and END truncation should produce different embeddings
    when the input is long enough to actually be truncated.

    We construct input with distinct tokens at the start vs end
    so that keeping different halves produces different embeddings.
    """
    start_words = " ".join([f"alpha{i}" for i in range(300)])
    end_words = " ".join([f"omega{i}" for i in range(300)])
    long_text = start_words + " " + end_words

    body_end = {
        "model": model_name,
        "texts": [long_text],
        "embedding_types": ["float"],
        "truncate": "END",
    }
    body_start = {
        "model": model_name,
        "texts": [long_text],
        "embedding_types": ["float"],
        "truncate": "START",
    }
    r_end = _cohere_embed_raw(server, body_end).json()
    r_start = _cohere_embed_raw(server, body_start).json()

    emb_end = r_end["embeddings"]["float"][0]
    emb_start = r_start["embeddings"]["float"][0]
    cos = _cosine_sim(emb_end, emb_start)
    assert cos < 0.99, (
        f"START and END truncation should produce different embeddings "
        f"for long input, but cosine similarity was {cos}"
    )
135
tests/entrypoints/pooling/embed/test_cohere_online_vision.py
Normal file
@@ -0,0 +1,135 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the Cohere /v2/embed API with a multimodal model (SigLIP).

Validates image embedding, batching, normalisation, and embedding type
conversions through the /v2/embed endpoint.
"""

import base64
import struct
import zlib

import numpy as np
import pytest
import requests

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "google/siglip-so400m-patch14-384"
DTYPE = "bfloat16"


@pytest.fixture(scope="module")
def server():
    args = [
        "--runner",
        "pooling",
        "--dtype",
        DTYPE,
        "--enforce-eager",
        "--max-model-len",
        "64",
        "--gpu-memory-utilization",
        "0.3",
    ]
    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


def _make_tiny_png(r: int, g: int, b: int, w: int = 2, h: int = 2) -> str:
    raw = b""
    for _ in range(h):
        raw += b"\x00" + bytes([r, g, b]) * w
    compressed = zlib.compress(raw)

    def chunk(ctype: bytes, cdata: bytes) -> bytes:
        c = ctype + cdata
        return (
            struct.pack(">I", len(cdata))
            + c
            + struct.pack(">I", zlib.crc32(c) & 0xFFFFFFFF)
        )

    ihdr = struct.pack(">IIBBBBB", w, h, 8, 2, 0, 0, 0)
    png = (
        b"\x89PNG\r\n\x1a\n"
        + chunk(b"IHDR", ihdr)
        + chunk(b"IDAT", compressed)
        + chunk(b"IEND", b"")
    )
    return "data:image/png;base64," + base64.b64encode(png).decode()


def _cohere_embed(
    server: RemoteOpenAIServer,
    texts: list[str] | None = None,
    images: list[str] | None = None,
    embedding_types: list[str] | None = None,
) -> dict:
    body: dict = {"model": MODEL_NAME}
    if texts is not None:
        body["texts"] = texts
    if images is not None:
        body["images"] = images
    if embedding_types is not None:
        body["embedding_types"] = embedding_types
    resp = requests.post(server.url_for("/v2/embed"), json=body)
    resp.raise_for_status()
    return resp.json()


def test_image_embed(server: RemoteOpenAIServer):
    img_uri = _make_tiny_png(255, 0, 0)
    r = _cohere_embed(
        server,
        images=[img_uri],
        embedding_types=["float"],
    )
    assert "embeddings" in r
    assert len(r["embeddings"]["float"]) == 1
    assert len(r["embeddings"]["float"][0]) > 0
    assert r["meta"]["billed_units"]["image_tokens"] > 0
    assert r["meta"]["billed_units"]["input_tokens"] == 0


def test_image_batch(server: RemoteOpenAIServer):
    red = _make_tiny_png(255, 0, 0)
    blue = _make_tiny_png(0, 0, 255)
    r = _cohere_embed(
        server,
        images=[red, blue],
        embedding_types=["float"],
    )
    assert len(r["embeddings"]["float"]) == 2


def test_image_l2_normalized(server: RemoteOpenAIServer):
    img_uri = _make_tiny_png(0, 255, 0)
    r = _cohere_embed(
        server,
        images=[img_uri],
        embedding_types=["float"],
    )
    emb = np.array(r["embeddings"]["float"][0])
    assert abs(float(np.linalg.norm(emb)) - 1.0) < 0.01


def test_image_embedding_types(server: RemoteOpenAIServer):
    img_uri = _make_tiny_png(128, 128, 128)
    r = _cohere_embed(
        server,
        images=[img_uri],
        embedding_types=["float", "binary", "ubinary"],
    )
    dim = len(r["embeddings"]["float"][0])
    assert len(r["embeddings"]["binary"][0]) == dim // 8
    assert len(r["embeddings"]["ubinary"][0]) == dim // 8


def test_text_embed_on_multimodal(server: RemoteOpenAIServer):
    """SigLIP also supports text-only embedding via /v2/embed."""
    r = _cohere_embed(server, texts=["hello world"], embedding_types=["float"])
    assert "embeddings" in r
    assert len(r["embeddings"]["float"]) == 1
    assert len(r["embeddings"]["float"][0]) > 0
102
tests/entrypoints/pooling/embed/test_cohere_openai_parity.py
Normal file
@@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Parity test between Cohere /v2/embed and OpenAI /v1/embeddings.

Verifies that both endpoints produce identical float embeddings when
no prompt prefix is applied (input_type omitted for Cohere /v2/embed).
"""

import numpy as np
import pytest
import requests

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "BAAI/bge-base-en-v1.5"
DTYPE = "bfloat16"


@pytest.fixture(scope="module")
def server():
    args = [
        "--runner",
        "pooling",
        "--dtype",
        DTYPE,
        "--enforce-eager",
        "--max-model-len",
        "512",
        "--gpu-memory-utilization",
        "0.02",
    ]
    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


def _cohere_embed(
    server: RemoteOpenAIServer,
    texts: list[str],
) -> list[list[float]]:
    body = {
        "model": MODEL_NAME,
        "texts": texts,
        "embedding_types": ["float"],
    }
    resp = requests.post(server.url_for("/v2/embed"), json=body)
    resp.raise_for_status()
    return resp.json()["embeddings"]["float"]


def _openai_embed(
    server: RemoteOpenAIServer,
    texts: list[str],
) -> list[list[float]]:
    body = {"model": MODEL_NAME, "input": texts, "encoding_format": "float"}
    resp = requests.post(server.url_for("/v1/embeddings"), json=body)
    resp.raise_for_status()
    return [item["embedding"] for item in resp.json()["data"]]


def test_single_text_parity(server: RemoteOpenAIServer):
    """A single text should produce identical embeddings via both APIs."""
    texts = ["the quick brown fox jumps over the lazy dog"]
    v2 = _cohere_embed(server, texts)
    v1 = _openai_embed(server, texts)
    np.testing.assert_allclose(v2[0], v1[0], rtol=1e-5)


def test_batch_parity(server: RemoteOpenAIServer):
    """A batch of texts should produce identical embeddings via both APIs,
    in the same order."""
    texts = [
        "machine learning",
        "deep learning",
        "natural language processing",
    ]
    v2 = _cohere_embed(server, texts)
    v1 = _openai_embed(server, texts)
    assert len(v2) == len(v1) == 3
    for i in range(3):
        np.testing.assert_allclose(v2[i], v1[i], rtol=1e-5, err_msg=f"index {i}")


def test_token_count_parity(server: RemoteOpenAIServer):
    """Both APIs should report the same prompt token count."""
    texts = ["hello world"]
    v2_resp = requests.post(
        server.url_for("/v2/embed"),
        json={
            "model": MODEL_NAME,
            "texts": texts,
            "embedding_types": ["float"],
        },
    )
    v1_resp = requests.post(
        server.url_for("/v1/embeddings"),
        json={"model": MODEL_NAME, "input": texts, "encoding_format": "float"},
    )
    v2_resp.raise_for_status()
    v1_resp.raise_for_status()
    v2_tokens = v2_resp.json()["meta"]["billed_units"]["input_tokens"]
    v1_tokens = v1_resp.json()["usage"]["prompt_tokens"]
    assert v2_tokens == v1_tokens
208
tests/entrypoints/pooling/embed/test_io_processor.py
Normal file
@@ -0,0 +1,208 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for EmbedIOProcessor."""

import pytest

from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
from vllm.entrypoints.pooling.embed.protocol import (
    CohereEmbedRequest,
)


class TestResolveTruncation:
    """Unit tests for EmbedIOProcessor._resolve_cohere_truncation."""

    @staticmethod
    def _make_request(**kwargs) -> CohereEmbedRequest:
        defaults = {
            "model": "test",
            "input_type": "search_document",
            "texts": ["hello"],
        }
        return CohereEmbedRequest(**(defaults | kwargs))

    def test_truncate_end_default(self):
        req = self._make_request()
        tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
        assert tokens == -1
        assert side is None

    def test_truncate_end_explicit(self):
        req = self._make_request(truncate="END")
        tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
        assert tokens == -1
        assert side is None

    def test_truncate_end_with_max_tokens(self):
        req = self._make_request(truncate="END", max_tokens=128)
        tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
        assert tokens == 128
        assert side is None

    def test_truncate_none(self):
        req = self._make_request(truncate="NONE")
        tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
        assert tokens is None
        assert side is None

    def test_truncate_none_with_max_tokens(self):
        """truncate=NONE should NOT set truncate_prompt_tokens; the
        max_tokens limit is enforced separately via _check_max_tokens."""
        req = self._make_request(truncate="NONE", max_tokens=10)
        tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
        assert tokens is None
        assert side is None

    def test_truncate_start(self):
        req = self._make_request(truncate="START")
        tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
        assert tokens == -1
        assert side == "left"

    def test_truncate_start_with_max_tokens(self):
        req = self._make_request(truncate="START", max_tokens=64)
        tokens, side = EmbedIOProcessor._resolve_cohere_truncation(req)
        assert tokens == 64
        assert side == "left"


class TestApplyStPrompt:
    """Unit tests for EmbedIOProcessor._apply_task_instruction."""

    @staticmethod
    def _make_handler(task_instructions: dict[str, str] | None):
        handler = object.__new__(EmbedIOProcessor)
        handler.task_instructions = task_instructions
        return handler

    def test_no_prompts_configured(self):
        handler = self._make_handler(None)
        texts = ["hello", "world"]
        assert handler._apply_task_instruction(texts, "query") is texts

    def test_matching_input_type(self):
        handler = self._make_handler({"query": "search_query: "})
        result = handler._apply_task_instruction(["hello"], "query")
        assert result == ["search_query: hello"]

    def test_non_matching_input_type(self):
        handler = self._make_handler({"query": "search_query: "})
        texts = ["hello"]
        assert handler._apply_task_instruction(texts, "document") is texts

    def test_multiple_texts(self):
        handler = self._make_handler(
            {"query": "Represent this sentence for searching: "}
        )
        result = handler._apply_task_instruction(["a", "b", "c"], "query")
        assert result == [
            "Represent this sentence for searching: a",
            "Represent this sentence for searching: b",
            "Represent this sentence for searching: c",
        ]

    def test_empty_prefix_returns_unchanged(self):
        handler = self._make_handler({"passage": ""})
        texts = ["hello"]
        assert handler._apply_task_instruction(texts, "passage") is texts


class TestLoadTaskInstructions:
    """Unit tests for EmbedIOProcessor._load_task_instructions."""

    def test_no_attribute(self):
        class FakeConfig:
            pass

        assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None

    def test_with_task_instructions(self):
        class FakeConfig:
            task_instructions = {
                "retrieval.query": "Represent the query: ",
                "retrieval.passage": "",
            }

        result = EmbedIOProcessor._load_task_instructions(FakeConfig())
        assert result == {
            "retrieval.query": "Represent the query: ",
            "retrieval.passage": "",
        }

    def test_empty_dict(self):
        class FakeConfig:
            task_instructions = {}

        assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None

    def test_non_dict(self):
        class FakeConfig:
            task_instructions = "not a dict"

        assert EmbedIOProcessor._load_task_instructions(FakeConfig()) is None


class TestCheckMaxTokens:
    """Unit tests for EmbedIOProcessor._check_cohere_max_tokens."""

    @staticmethod
    def _fake_output(n_tokens: int):
        class _Out:
            def __init__(self, n: int):
                self.prompt_token_ids = list(range(n))

        return _Out(n_tokens)

    def test_none_check_is_noop(self):
        outs = [self._fake_output(100)]
        EmbedIOProcessor._check_cohere_max_tokens(outs, None)

    def test_within_limit(self):
        outs = [self._fake_output(5), self._fake_output(3)]
        EmbedIOProcessor._check_cohere_max_tokens(outs, 5)

    def test_exceeds_limit(self):
        outs = [self._fake_output(3), self._fake_output(10)]
        with pytest.raises(ValueError, match="exceeds max_tokens=5"):
            EmbedIOProcessor._check_cohere_max_tokens(outs, 5)

    def test_exact_limit(self):
        outs = [self._fake_output(5)]
        EmbedIOProcessor._check_cohere_max_tokens(outs, 5)


class TestValidateInputType:
    """Unit tests for EmbedIOProcessor._validate_input_type."""

    @staticmethod
    def _make_handler(task_instructions: dict[str, str] | None):
        handler = object.__new__(EmbedIOProcessor)
        handler.task_instructions = task_instructions
        return handler

    def test_none_input_type_always_accepted(self):
        handler = self._make_handler(None)
        handler._validate_input_type(None)
        handler_with = self._make_handler({"query": "q: "})
        handler_with._validate_input_type(None)

    def test_no_prompts_rejects(self):
        handler = self._make_handler(None)
        with pytest.raises(ValueError, match="does not define any input_type"):
            handler._validate_input_type("anything")

    def test_known_type_accepted(self):
        handler = self._make_handler({"query": "q: ", "document": "d: "})
        handler._validate_input_type("query")
        handler._validate_input_type("document")

    def test_unknown_type_rejected(self):
        handler = self._make_handler({"query": "q: ", "document": "d: "})
        with pytest.raises(ValueError, match="Unsupported input_type 'other'"):
            handler._validate_input_type("other")

    def test_error_lists_supported(self):
        handler = self._make_handler({"a": "", "b": ""})
        with pytest.raises(ValueError, match="Supported values: a, b"):
            handler._validate_input_type("z")
129
tests/entrypoints/pooling/embed/test_protocol.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for Cohere embed protocol: build_typed_embeddings and its
|
||||
underlying packing helpers, plus Cohere-specific serving helpers."""
|
||||
|
||||
import base64
|
||||
import struct
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
build_typed_embeddings,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_embeddings() -> list[list[float]]:
|
||||
return [
|
||||
[0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8],
|
||||
[-0.05, 0.15, -0.25, 0.35, -0.45, 0.55, -0.65, 0.75],
|
||||
]
|
||||
|
||||
|
||||
class TestBuildTypedEmbeddingsFloat:
|
||||
def test_float_passthrough(self, sample_embeddings: list[list[float]]):
|
||||
result = build_typed_embeddings(sample_embeddings, ["float"])
|
||||
assert result.float == sample_embeddings
|
||||
assert result.binary is None
|
||||
|
||||
def test_empty_input(self):
|
||||
result = build_typed_embeddings([], ["float"])
|
||||
assert result.float == []
|
||||
|
||||
|
||||
class TestBuildTypedEmbeddingsBinary:
|
||||
def test_binary_packing(self):
|
||||
# 8 values: non-negative->1, negative->0 => bits: 10101010 = 0xAA = 170
|
||||
# signed: 170 - 128 = 42
|
||||
embs = [[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]]
|
||||
result = build_typed_embeddings(embs, ["binary"])
|
||||
assert result.binary is not None
|
||||
assert result.binary[0] == [42]
|
||||
|
||||
def test_ubinary_packing(self):
|
||||
embs = [[1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0]]
|
||||
result = build_typed_embeddings(embs, ["ubinary"])
|
||||
assert result.ubinary is not None
|
||||
assert result.ubinary[0] == [170] # 0b10101010
|
||||
|
||||
def test_binary_all_positive(self):
|
||||
embs = [[0.1] * 8]
|
||||
result = build_typed_embeddings(embs, ["binary"])
|
||||
assert result.binary is not None
|
||||
# all bits = 1 => 0xFF = 255, signed: 255 - 128 = 127
|
||||
assert result.binary[0] == [127]
|
||||
|
||||
def test_binary_all_negative(self):
|
||||
embs = [[-0.1] * 8]
|
||||
result = build_typed_embeddings(embs, ["binary"])
|
||||
assert result.binary is not None
|
||||
# all bits = 0, signed: 0 - 128 = -128
|
||||
assert result.binary[0] == [-128]
|
||||
|
||||
def test_binary_dimension_is_eighth(self, sample_embeddings: list[list[float]]):
|
||||
result = build_typed_embeddings(sample_embeddings, ["binary"])
|
||||
assert result.binary is not None
|
||||
for orig, packed in zip(sample_embeddings, result.binary):
|
||||
assert len(packed) == len(orig) // 8
|
||||
|
||||
def test_zero_treated_as_positive(self):
|
||||
embs = [[0.0] * 8]
|
||||
result = build_typed_embeddings(embs, ["binary"])
|
||||
assert result.binary is not None
|
||||
# 0.0 >= 0 is True, so bit=1 for all => 127 (signed)
|
||||
assert result.binary[0] == [127]
|
||||
|
||||
def test_non_multiple_of_8_raises(self):
|
||||
embs = [[0.1] * 7]
|
||||
with pytest.raises(ValueError, match="multiple of 8"):
|
||||
build_typed_embeddings(embs, ["binary"])
|
||||
|
||||
def test_ubinary_non_multiple_of_8_raises(self):
|
||||
embs = [[0.1] * 10]
|
||||
with pytest.raises(ValueError, match="multiple of 8"):
|
||||
build_typed_embeddings(embs, ["ubinary"])
|
||||
|
||||
|
||||
class TestBuildTypedEmbeddingsBase64:
|
||||
def test_base64_roundtrip(self, sample_embeddings: list[list[float]]):
|
||||
result = build_typed_embeddings(sample_embeddings, ["base64"])
|
||||
assert result.base64 is not None
|
||||
assert len(result.base64) == 2
|
||||
|
||||
for orig, b64_str in zip(sample_embeddings, result.base64):
|
||||
decoded = base64.b64decode(b64_str)
|
||||
n = len(orig)
|
||||
values = struct.unpack(f"<{n}f", decoded)
|
||||
np.testing.assert_allclose(orig, values, rtol=1e-5)
|
||||
|
||||
def test_base64_byte_length(self):
|
||||
embs = [[0.1, 0.2, 0.3]]
|
||||
result = build_typed_embeddings(embs, ["base64"])
|
||||
assert result.base64 is not None
|
||||
raw = base64.b64decode(result.base64[0])
|
||||
assert len(raw) == 3 * 4 # 3 floats * 4 bytes each
|
||||
|
||||
|
||||
class TestBuildTypedEmbeddingsMultiple:
|
||||
def test_all_types_at_once(self, sample_embeddings: list[list[float]]):
|
||||
result = build_typed_embeddings(
|
||||
sample_embeddings,
|
||||
["float", "binary", "ubinary", "base64"],
|
||||
)
|
||||
assert result.float is not None
|
||||
assert result.binary is not None
|
||||
assert result.ubinary is not None
|
||||
assert result.base64 is not None
|
||||
|
||||
def test_subset_types(self, sample_embeddings: list[list[float]]):
|
||||
result = build_typed_embeddings(sample_embeddings, ["float", "binary"])
|
||||
assert result.float is not None
|
||||
assert result.binary is not None
|
||||
assert result.ubinary is None
|
||||
assert result.base64 is None
|
||||
|
||||
def test_unknown_type_ignored(self, sample_embeddings: list[list[float]]):
|
||||
result = build_typed_embeddings(sample_embeddings, ["float", "unknown_type"])
|
||||
assert result.float is not None
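For readers consuming the `base64` embedding type on the client side, the round trip checked in `test_base64_roundtrip` can be reversed with the standard library alone. This is a minimal sketch, assuming the little-endian float32 layout that the tests decode with `struct.unpack`; `decode_base64_embedding` is a hypothetical helper name and is not part of vLLM.

```python
import base64
import struct


def decode_base64_embedding(b64_str: str) -> list[float]:
    """Decode one base64 string into a list of float32 values.

    Hypothetical client-side helper; assumes little-endian float32 packing.
    """
    raw = base64.b64decode(b64_str)
    if len(raw) % 4 != 0:
        raise ValueError("payload length is not a multiple of 4 bytes")
    count = len(raw) // 4
    return list(struct.unpack(f"<{count}f", raw))


# Round-trip check with values that are exactly representable in float32.
original = [0.5, -0.25, 1.0]
encoded = base64.b64encode(struct.pack("<3f", *original)).decode("utf-8")
decoded = decode_base64_embedding(encoded)
```

Values that are not exactly representable in float32 (such as 0.1) come back with a small rounding error, which is why the test above compares with a relative tolerance rather than strict equality.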
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from typing import Annotated, Any
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
@@ -24,6 +24,14 @@ class PoolingBasicRequestMixin(OpenAIBaseModel):
|
||||
|
||||
# --8<-- [start:pooling-common-extra-params]
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
truncation_side: Literal["left", "right"] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Which side to truncate from when truncate_prompt_tokens is active. "
|
||||
"'right' keeps the first N tokens. "
|
||||
"'left' keeps the last N tokens."
|
||||
),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
|
||||
@@ -32,6 +32,7 @@ class ClassificationCompletionRequest(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
truncation_side=self.truncation_side,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
@@ -54,6 +55,7 @@ class ClassificationChatRequest(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
truncation_side=self.truncation_side,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
|
||||
@@ -7,12 +7,12 @@ from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
|
||||
from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
|
||||
from vllm.entrypoints.utils import (
|
||||
load_aware_call,
|
||||
with_cancellation,
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
CohereEmbedRequest,
|
||||
EmbeddingRequest,
|
||||
)
|
||||
from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
|
||||
from vllm.entrypoints.utils import load_aware_call, with_cancellation
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -40,3 +40,24 @@ async def create_embedding(
|
||||
raise NotImplementedError("The model does not support Embeddings API")
|
||||
|
||||
return await handler(request, raw_request)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v2/embed",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
||||
},
|
||||
)
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_cohere_embedding(
|
||||
request: CohereEmbedRequest,
|
||||
raw_request: Request,
|
||||
):
|
||||
handler = embedding(raw_request)
|
||||
if handler is None:
|
||||
raise NotImplementedError("The model does not support Embeddings API")
|
||||
|
||||
return await handler(request, raw_request)
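A request to the new `/v2/embed` route could be assembled roughly as below. This is a sketch only: the base URL `http://localhost:8000`, the model name `my-embedding-model`, and the helper names `build_embed_payload` and `post_embed` are assumptions for illustration. The field names (`model`, `texts`, `embedding_types`, `truncate`) are taken from `CohereEmbedRequest`.

```python
import json
import urllib.request

_TRUNCATE_VALUES = ("NONE", "START", "END")


def build_embed_payload(
    texts,
    model="my-embedding-model",  # hypothetical model name
    embedding_types=("float",),
    truncate="END",
):
    """Build a JSON-serializable body for POST /v2/embed."""
    if truncate not in _TRUNCATE_VALUES:
        raise ValueError(f"truncate must be one of {_TRUNCATE_VALUES}")
    return {
        "model": model,
        "texts": list(texts),
        "embedding_types": list(embedding_types),
        "truncate": truncate,
    }


def post_embed(payload, base_url="http://localhost:8000"):
    """Send the payload to a running server (base_url is an assumption)."""
    req = urllib.request.Request(
        f"{base_url}/v2/embed",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))


# Only the payload is built here; post_embed needs a live server.
payload = build_embed_payload(["hello world"], embedding_types=["float", "base64"])
```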
|
||||
|
||||
@@ -1,14 +1,37 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, cast
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import torch
|
||||
from openai.types.chat import (
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_content_part_image_param import ImageURL
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionContentPartParam,
|
||||
ChatCompletionMessageParam,
|
||||
CustomChatCompletionMessageParam,
|
||||
)
|
||||
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
CohereEmbedInput,
|
||||
CohereEmbedRequest,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.pooling.typing import PoolingServeContext
|
||||
from vllm.inputs.data import ProcessorInputs, token_inputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.renderers import merge_kwargs
|
||||
from vllm.utils.collection_utils import chunk_list
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class EmbedIOProcessor(PoolingIOProcessor):
|
||||
@@ -21,16 +44,45 @@ class EmbedIOProcessor(PoolingIOProcessor):
|
||||
self.pooler_config = self.model_config.pooler_config
|
||||
self.enable_chunked_processing = self.pooler_config.enable_chunked_processing
|
||||
|
||||
# Load task instructions from HF config or sentence-transformers config
|
||||
self.task_instructions: dict[str, str] | None = self._load_task_instructions(
|
||||
self.model_config.hf_config
|
||||
) or self._load_st_prompts(self.model_config.model, self.model_config.revision)
|
||||
if self.task_instructions:
|
||||
logger.info(
|
||||
"Loaded prompt prefixes for input_type: %s",
|
||||
list(self.task_instructions.keys()),
|
||||
)
|
||||
|
||||
def pre_process_online(self, ctx: PoolingServeContext):
|
||||
if isinstance(ctx.request, CohereEmbedRequest):
|
||||
self._pre_process_cohere_online(ctx)
|
||||
else:
|
||||
super().pre_process_online(ctx)
|
||||
|
||||
if self.enable_chunked_processing:
|
||||
self._pre_process_chunked(ctx)
|
||||
|
||||
def post_process_online(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
):
|
||||
if ctx.final_res_batch is None:
|
||||
raise ValueError("Final response batch not available")
|
||||
|
||||
if not self.enable_chunked_processing:
|
||||
self._enforce_cohere_max_tokens(ctx)
|
||||
return super().post_process_online(ctx)
|
||||
|
||||
self._post_process_chunked(ctx)
|
||||
self._enforce_cohere_max_tokens(ctx)
|
||||
|
||||
#################################################################
|
||||
# Long Text Embedding with Chunked Processing
|
||||
# PTAL: examples/pooling/embed/openai_embedding_long_text
|
||||
#################################################################
|
||||
|
||||
def pre_process_online(self, ctx: PoolingServeContext):
|
||||
super().pre_process_online(ctx)
|
||||
|
||||
if not self.enable_chunked_processing:
|
||||
return None
|
||||
|
||||
def _pre_process_chunked(self, ctx: PoolingServeContext) -> None:
|
||||
if ctx.engine_prompts is None:
|
||||
raise ValueError("Engine prompts not available")
|
||||
|
||||
@@ -61,18 +113,10 @@ class EmbedIOProcessor(PoolingIOProcessor):
|
||||
|
||||
ctx.engine_prompts = chunked_engine_prompts
|
||||
ctx.prompt_request_ids = prompt_request_ids
|
||||
|
||||
return None
|
||||
|
||||
def post_process_online(
|
||||
self,
|
||||
ctx: PoolingServeContext,
|
||||
):
|
||||
if ctx.final_res_batch is None:
|
||||
raise ValueError("Final response batch not available")
|
||||
|
||||
if not self.enable_chunked_processing:
|
||||
return super().post_process_online(ctx)
|
||||
|
||||
def _post_process_chunked(self, ctx: PoolingServeContext) -> None:
|
||||
# Online aggregation for chunked requests to
|
||||
# minimize memory usage
|
||||
# Track aggregation state for each prompt
|
||||
@@ -195,4 +239,245 @@ class EmbedIOProcessor(PoolingIOProcessor):
|
||||
raise ValueError(f"Result not found for prompt {prompt_idx}")
|
||||
|
||||
ctx.final_res_batch = final_res_batch
|
||||
|
||||
return None
|
||||
|
||||
#################################################################
|
||||
# Cohere Request Preprocessing & Postprocessing
|
||||
#################################################################
|
||||
|
||||
@staticmethod
|
||||
def _load_task_instructions(hf_config: Any) -> dict[str, str] | None:
|
||||
"""Extract ``task_instructions`` from the HF model config."""
|
||||
ti = getattr(hf_config, "task_instructions", None)
|
||||
if not isinstance(ti, dict) or not ti:
|
||||
return None
|
||||
return {k: v for k, v in ti.items() if isinstance(v, str)}
|
||||
|
||||
@staticmethod
|
||||
def _load_st_prompts(
|
||||
model: str | Any,
|
||||
revision: str | None,
|
||||
) -> dict[str, str] | None:
|
||||
"""Load prompt prefixes from the ``prompts`` key of ``config_sentence_transformers.json``."""
|
||||
from vllm.transformers_utils.repo_utils import get_hf_file_to_dict
|
||||
|
||||
try:
|
||||
cfg = get_hf_file_to_dict(
|
||||
"config_sentence_transformers.json", str(model), revision
|
||||
)
|
||||
except (ValueError, OSError):
|
||||
return None
|
||||
|
||||
if cfg is None:
|
||||
return None
|
||||
prompts = cfg.get("prompts")
|
||||
if not isinstance(prompts, dict) or not prompts:
|
||||
return None
|
||||
return {k: v for k, v in prompts.items() if isinstance(v, str)}
|
||||
|
||||
@staticmethod
|
||||
def _mixed_input_to_messages(
|
||||
inp: CohereEmbedInput,
|
||||
*,
|
||||
task_prefix: str | None = None,
|
||||
) -> list[ChatCompletionMessageParam]:
|
||||
"""Build chat messages from a mixed text+image input.
|
||||
|
||||
When *task_prefix* is given, it is prepended to each text part.
|
||||
"""
|
||||
parts: list[ChatCompletionContentPartParam] = []
|
||||
for item in inp.content:
|
||||
if item.type == "text" and item.text is not None:
|
||||
text = task_prefix + item.text if task_prefix else item.text
|
||||
parts.append(ChatCompletionContentPartTextParam(type="text", text=text))
|
||||
elif item.type == "image_url" and item.image_url is not None:
|
||||
parts.append(
|
||||
ChatCompletionContentPartImageParam(
|
||||
type="image_url",
|
||||
image_url=ImageURL(url=item.image_url["url"]),
|
||||
)
|
||||
)
|
||||
return [CustomChatCompletionMessageParam(role="user", content=parts)]
|
||||
|
||||
@staticmethod
|
||||
def _check_cohere_max_tokens(
|
||||
outputs: list[PoolingRequestOutput],
|
||||
max_tokens_check: int | None,
|
||||
) -> None:
|
||||
"""Raise if any output exceeds *max_tokens_check* tokens.
|
||||
|
||||
Used to enforce ``truncate=NONE`` with an explicit ``max_tokens``:
|
||||
the pipeline runs without truncation and we reject afterwards.
|
||||
"""
|
||||
if max_tokens_check is None:
|
||||
return
|
||||
for out in outputs:
|
||||
n = len(out.prompt_token_ids)
|
||||
if n > max_tokens_check:
|
||||
raise ValueError(
|
||||
f"Input of {n} tokens exceeds max_tokens={max_tokens_check} "
|
||||
"with truncate=NONE. Set truncate to END or START to "
|
||||
"allow truncation."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_cohere_truncation(
|
||||
request: CohereEmbedRequest,
|
||||
) -> tuple[int | None, Literal["left", "right"] | None]:
|
||||
"""Return ``(truncate_prompt_tokens, truncation_side)``."""
|
||||
if request.truncate == "NONE":
|
||||
return None, None
|
||||
if request.truncate == "START":
|
||||
tokens = request.max_tokens if request.max_tokens is not None else -1
|
||||
return tokens, "left"
|
||||
if request.max_tokens is not None:
|
||||
return request.max_tokens, None
|
||||
return -1, None
|
||||
|
||||
def create_pooling_params(self, request):
|
||||
if isinstance(request, CohereEmbedRequest):
|
||||
return PoolingParams(
|
||||
task="embed",
|
||||
dimensions=request.output_dimension,
|
||||
)
|
||||
return super().create_pooling_params(request)
|
||||
|
||||
def _pre_process_cohere_online(self, ctx: PoolingServeContext) -> None:
|
||||
"""Convert a ``CohereEmbedRequest`` into engine prompts.
|
||||
|
||||
For texts, a single batched completion request path is used.
|
||||
For images and mixed inputs, conversations are batch-rendered
|
||||
through the chat template in one ``render_chat`` call.
|
||||
"""
|
||||
request = ctx.request
|
||||
assert isinstance(request, CohereEmbedRequest)
|
||||
|
||||
if request.texts is None and request.images is None and request.inputs is None:
|
||||
raise ValueError("One of texts, images, or inputs must be provided")
|
||||
|
||||
truncate_prompt_tokens, truncation_side = self._resolve_cohere_truncation(
|
||||
request
|
||||
)
|
||||
input_type = request.input_type
|
||||
self._validate_input_type(input_type)
|
||||
|
||||
if request.images is not None:
|
||||
all_messages: list[list[ChatCompletionMessageParam]] = [
|
||||
[
|
||||
CustomChatCompletionMessageParam(
|
||||
role="user",
|
||||
content=[{"type": "image_url", "image_url": {"url": uri}}],
|
||||
)
|
||||
]
|
||||
for uri in request.images
|
||||
]
|
||||
ctx.engine_prompts = self._batch_render_chat(
|
||||
request, all_messages, truncate_prompt_tokens, truncation_side
|
||||
)
|
||||
|
||||
elif request.inputs is not None:
|
||||
task_prefix = self._get_task_instruction_prefix(input_type)
|
||||
all_messages = [
|
||||
self._mixed_input_to_messages(inp, task_prefix=task_prefix)
|
||||
for inp in request.inputs
|
||||
]
|
||||
ctx.engine_prompts = self._batch_render_chat(
|
||||
request, all_messages, truncate_prompt_tokens, truncation_side
|
||||
)
|
||||
|
||||
else:
|
||||
prefixed = self._apply_task_instruction(request.texts or [], input_type)
|
||||
proxy = EmbeddingCompletionRequest(
|
||||
model=request.model,
|
||||
input=prefixed,
|
||||
dimensions=request.output_dimension,
|
||||
encoding_format="float",
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
truncation_side=truncation_side,
|
||||
)
|
||||
ctx.engine_prompts = self._preprocess_completion_online(
|
||||
proxy, prompt_input=proxy.input, prompt_embeds=None
|
||||
)
|
||||
|
||||
def _batch_render_chat(
|
||||
self,
|
||||
request: CohereEmbedRequest,
|
||||
all_messages: Sequence[list[ChatCompletionMessageParam]],
|
||||
truncate_prompt_tokens: int | None,
|
||||
truncation_side: Literal["left", "right"] | None,
|
||||
) -> list[ProcessorInputs]:
|
||||
"""Batch-render multiple conversations through the chat template."""
|
||||
if not all_messages:
|
||||
return []
|
||||
|
||||
proxy = EmbeddingChatRequest(
|
||||
model=request.model,
|
||||
messages=list(all_messages[0]),
|
||||
dimensions=request.output_dimension,
|
||||
encoding_format="float",
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
truncation_side=truncation_side,
|
||||
)
|
||||
|
||||
renderer = self.renderer
|
||||
mm_config = self.model_config.multimodal_config
|
||||
|
||||
tok_params = proxy.build_tok_params(self.model_config)
|
||||
chat_params = proxy.build_chat_params(
|
||||
self.chat_template,
|
||||
self.chat_template_content_format,
|
||||
).with_defaults(
|
||||
merge_kwargs(
|
||||
None,
|
||||
dict(
|
||||
tools=None,
|
||||
tokenize=is_mistral_tokenizer(renderer.tokenizer),
|
||||
),
|
||||
),
|
||||
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
|
||||
)
|
||||
|
||||
_, engine_prompts = renderer.render_chat(all_messages, chat_params, tok_params)
|
||||
return engine_prompts
|
||||
|
||||
def _validate_input_type(self, input_type: str | None) -> None:
|
||||
"""Raise if *input_type* is not supported by this model."""
|
||||
if input_type is None:
|
||||
return
|
||||
if self.task_instructions is None:
|
||||
raise ValueError(
|
||||
f"Unsupported input_type {input_type!r}. "
|
||||
"This model does not define any input_type task instructions."
|
||||
)
|
||||
if input_type not in self.task_instructions:
|
||||
supported = ", ".join(sorted(self.task_instructions))
|
||||
raise ValueError(
|
||||
f"Unsupported input_type {input_type!r}. Supported values: {supported}"
|
||||
)
|
||||
|
||||
def _apply_task_instruction(
|
||||
self,
|
||||
texts: list[str],
|
||||
input_type: str | None,
|
||||
) -> list[str]:
|
||||
"""Prepend the task-instruction prefix for *input_type*.
|
||||
|
||||
Returns *texts* unchanged when no matching prefix is configured.
|
||||
"""
|
||||
prefix = self._get_task_instruction_prefix(input_type)
|
||||
if not prefix:
|
||||
return texts
|
||||
return [prefix + t for t in texts]
|
||||
|
||||
def _get_task_instruction_prefix(self, input_type: str | None) -> str | None:
|
||||
"""Return the task-instruction prefix for *input_type*, or ``None``."""
|
||||
if not self.task_instructions or input_type is None:
|
||||
return None
|
||||
return self.task_instructions.get(input_type) or None
|
||||
|
||||
def _enforce_cohere_max_tokens(self, ctx: PoolingServeContext) -> None:
|
||||
if isinstance(ctx.request, CohereEmbedRequest):
|
||||
request = ctx.request
|
||||
if request.truncate == "NONE" and request.max_tokens is not None:
|
||||
self._check_cohere_max_tokens(ctx.final_res_batch, request.max_tokens)
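The `truncation_side` help text says that `'right'` keeps the first N tokens and `'left'` keeps the last N tokens, and `_resolve_cohere_truncation` maps Cohere's `START` to `"left"`. The effect on a token list can be sketched as follows. `truncate_tokens` is a hypothetical stand-in for what the tokenizer does, not vLLM's implementation, and it takes an explicit non-negative limit rather than the `-1` sentinel used in the request models.

```python
from collections.abc import Sequence
from typing import Literal


def truncate_tokens(
    tokens: Sequence[int],
    limit: int,
    side: Literal["left", "right"] = "right",
) -> list[int]:
    """Keep at most `limit` tokens, dropping from the given side."""
    if limit < 0:
        raise ValueError("limit must be non-negative")
    if len(tokens) <= limit:
        return list(tokens)
    if side == "left":
        # Drop tokens from the start, keep the last `limit` tokens.
        return list(tokens[len(tokens) - limit:])
    # Drop tokens from the end, keep the first `limit` tokens.
    return list(tokens[:limit])


tokens = [10, 11, 12, 13, 14]
kept_right = truncate_tokens(tokens, 3, "right")
kept_left = truncate_tokens(tokens, 3, "left")
```

With `truncate="NONE"` no truncation is requested at all, which is why the diff checks the token count afterwards and rejects the request instead of shortening it.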
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from typing import TypeAlias
|
||||
"""Embedding API protocol models for OpenAI and Cohere formats.
|
||||
|
||||
from pydantic import Field
|
||||
OpenAI: https://platform.openai.com/docs/api-reference/embeddings
|
||||
Cohere: https://docs.cohere.com/reference/embed
|
||||
"""
|
||||
|
||||
import base64
|
||||
import builtins
|
||||
import struct
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.config import ModelConfig
|
||||
@@ -17,6 +27,10 @@ from vllm.entrypoints.pooling.base.protocol import (
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI /v1/embeddings — request models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_max_total_output_tokens(
|
||||
model_config: ModelConfig,
|
||||
@@ -50,6 +64,7 @@ class EmbeddingCompletionRequest(
|
||||
max_total_tokens=max_total_tokens,
|
||||
max_output_tokens=max_output_tokens,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
truncation_side=self.truncation_side,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
@@ -79,6 +94,7 @@ class EmbeddingChatRequest(
|
||||
max_total_tokens=max_total_tokens,
|
||||
max_output_tokens=max_output_tokens,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
truncation_side=self.truncation_side,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
@@ -96,6 +112,11 @@ class EmbeddingChatRequest(
|
||||
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI /v1/embeddings — response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EmbeddingResponseData(OpenAIBaseModel):
|
||||
index: int
|
||||
object: str = "embedding"
|
||||
@@ -106,7 +127,7 @@ class EmbeddingResponse(OpenAIBaseModel):
|
||||
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
|
||||
object: str = "list"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str
|
||||
model: str | None = None
|
||||
data: list[EmbeddingResponseData]
|
||||
usage: UsageInfo
|
||||
|
||||
@@ -115,3 +136,146 @@ class EmbeddingBytesResponse(OpenAIBaseModel):
|
||||
content: list[bytes]
|
||||
headers: dict[str, str] | None = None
|
||||
media_type: str = "application/octet-stream"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cohere /v2/embed — request models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CohereEmbeddingType = Literal[
|
||||
"float",
|
||||
"binary",
|
||||
"ubinary",
|
||||
"base64",
|
||||
]
|
||||
CohereTruncate = Literal["NONE", "START", "END"]
|
||||
|
||||
|
||||
class CohereEmbedContent(BaseModel):
|
||||
type: Literal["text", "image_url"]
|
||||
text: str | None = None
|
||||
image_url: dict[str, str] | None = None
|
||||
|
||||
|
||||
class CohereEmbedInput(BaseModel):
|
||||
content: list[CohereEmbedContent]
|
||||
|
||||
|
||||
class CohereEmbedRequest(BaseModel):
|
||||
model: str | None = None
|
||||
input_type: str | None = None
|
||||
texts: list[str] | None = None
|
||||
images: list[str] | None = None
|
||||
inputs: list[CohereEmbedInput] | None = None
|
||||
output_dimension: int | None = None
|
||||
embedding_types: list[CohereEmbeddingType] | None = None
|
||||
truncate: CohereTruncate = "END"
|
||||
max_tokens: int | None = None
|
||||
priority: int = 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cohere /v2/embed — response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CohereApiVersion(BaseModel):
|
||||
version: str = "2"
|
||||
|
||||
|
||||
class CohereBilledUnits(BaseModel):
|
||||
input_tokens: int | None = None
|
||||
image_tokens: int | None = None
|
||||
|
||||
|
||||
class CohereMeta(BaseModel):
|
||||
api_version: CohereApiVersion = Field(default_factory=CohereApiVersion)
|
||||
billed_units: CohereBilledUnits | None = None
|
||||
|
||||
|
||||
class CohereEmbedByTypeEmbeddings(BaseModel):
|
||||
# The field name ``float`` shadows the builtin type, so the annotation
|
||||
# must use ``builtins.float`` to avoid a self-referential type error.
|
||||
float: list[list[builtins.float]] | None = None
|
||||
binary: list[list[int]] | None = None
|
||||
ubinary: list[list[int]] | None = None
|
||||
base64: list[str] | None = None
|
||||
|
||||
|
||||
class CohereEmbedResponse(BaseModel):
|
||||
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
|
||||
embeddings: CohereEmbedByTypeEmbeddings
|
||||
texts: list[str] | None = None
|
||||
meta: CohereMeta | None = None
|
||||
response_type: Literal["embeddings_by_type"] = "embeddings_by_type"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cohere embedding type conversion helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_UNSIGNED_TO_SIGNED_DIFF = 1 << 7 # 128
|
||||
|
||||
|
||||
def _pack_binary_embeddings(
    float_embeddings: list[list[float]],
    signed: bool,
) -> list[list[int]]:
    """Bit-pack float embeddings: non-negative -> 1, negative -> 0.

    Each bit is shifted left by ``7 - idx%8``, and every 8 bits are packed
    into one byte.
    """
    result: list[list[int]] = []
    for embedding in float_embeddings:
        dim = len(embedding)
        if dim % 8 != 0:
            raise ValueError(
                "Embedding dimension must be a multiple of 8 for binary "
                f"embedding types, but got {dim}."
            )
        packed_len = dim // 8
        packed: list[int] = []
        byte_val = 0
        for idx, value in enumerate(embedding):
            bit = 1 if value >= 0 else 0
            byte_val += bit << (7 - idx % 8)
            if (idx + 1) % 8 == 0:
                if signed:
                    byte_val -= _UNSIGNED_TO_SIGNED_DIFF
                packed.append(byte_val)
                byte_val = 0
        assert len(packed) == packed_len
        result.append(packed)
    return result
|
||||
|
||||
|
||||
def _encode_base64_embeddings(
|
||||
float_embeddings: list[list[float]],
|
||||
) -> list[str]:
|
||||
"""Encode float embeddings as base64 (little-endian float32)."""
|
||||
result: list[str] = []
|
||||
for embedding in float_embeddings:
|
||||
buf = struct.pack(f"<{len(embedding)}f", *embedding)
|
||||
result.append(base64.b64encode(buf).decode("utf-8"))
|
||||
return result
|
||||
|
||||
|
||||
def build_typed_embeddings(
|
||||
float_embeddings: list[list[float]],
|
||||
embedding_types: Sequence[str],
|
||||
) -> CohereEmbedByTypeEmbeddings:
|
||||
"""Convert float embeddings to all requested Cohere embedding types."""
|
||||
result = CohereEmbedByTypeEmbeddings()
|
||||
|
||||
for emb_type in embedding_types:
|
||||
if emb_type == "float":
|
||||
result.float = float_embeddings
|
||||
elif emb_type == "binary":
|
||||
result.binary = _pack_binary_embeddings(float_embeddings, signed=True)
|
||||
elif emb_type == "ubinary":
|
||||
result.ubinary = _pack_binary_embeddings(float_embeddings, signed=False)
|
||||
elif emb_type == "base64":
|
||||
result.base64 = _encode_base64_embeddings(float_embeddings)
|
||||
|
||||
return result
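On the receiving side, the packed `binary` and `ubinary` values can be expanded back into sign bits. The following is a minimal sketch, assuming the most-significant-bit-first layout and the offset of 128 used by `_pack_binary_embeddings`; `unpack_sign_bits` is a hypothetical helper and does not exist in vLLM.

```python
def unpack_sign_bits(packed: list[int], signed: bool) -> list[int]:
    """Expand packed bytes into a flat list of 0/1 sign bits.

    A bit of 1 means the original float was >= 0, a bit of 0 means it was < 0.
    """
    bits: list[int] = []
    for value in packed:
        # Undo the signed offset so every byte is in the range 0..255.
        byte_val = value + 128 if signed else value
        if not 0 <= byte_val <= 255:
            raise ValueError(f"value {value} is out of range for one byte")
        # The first embedding dimension sits in the most significant bit.
        for shift in range(7, -1, -1):
            bits.append((byte_val >> shift) & 1)
    return bits


bits_from_signed = unpack_sign_bits([42], signed=True)
bits_from_unsigned = unpack_sign_bits([170], signed=False)
```

Both calls recover the same bit pattern, since 42 is the signed encoding of the unsigned byte 170.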
|
||||
|
||||
@@ -5,7 +5,7 @@ from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import Literal, TypeAlias, cast
|
||||
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
@@ -14,10 +14,15 @@ from vllm.entrypoints.openai.engine.protocol import UsageInfo
 from vllm.entrypoints.pooling.base.serving import PoolingServing
 from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
 from vllm.entrypoints.pooling.embed.protocol import (
+    CohereBilledUnits,
+    CohereEmbedRequest,
+    CohereEmbedResponse,
+    CohereMeta,
     EmbeddingBytesResponse,
     EmbeddingRequest,
     EmbeddingResponse,
     EmbeddingResponseData,
+    build_typed_embeddings,
 )
 from vllm.entrypoints.pooling.typing import PoolingServeContext
 from vllm.entrypoints.pooling.utils import (
@@ -26,24 +31,23 @@ from vllm.entrypoints.pooling.utils import (
     encode_pooling_output_float,
     get_json_response_cls,
 )
+from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput
 from vllm.renderers import BaseRenderer
 from vllm.utils.serial_utils import EmbedDType, Endianness

+logger = init_logger(__name__)
+
 JSONResponseCLS = get_json_response_cls()

 EmbeddingServeContext: TypeAlias = PoolingServeContext[EmbeddingRequest]


 class ServingEmbedding(PoolingServing):
-    """
-    Embedding API similar to OpenAI's API.
-
-    See https://platform.openai.com/docs/api-reference/embeddings/create
-    for the API specification. This API mimics the OpenAI Embedding API.
-    """
+    """Embedding API supporting both OpenAI and Cohere formats."""

     request_id_prefix = "embd"
+    io_processor: EmbedIOProcessor

     def init_io_processor(
         self,
@@ -58,6 +62,14 @@ class ServingEmbedding(PoolingServing):
         )

     async def _build_response(
+        self,
+        ctx: PoolingServeContext,
+    ) -> Response:
+        if isinstance(ctx.request, CohereEmbedRequest):
+            return self._build_cohere_response_from_ctx(ctx)
+        return await self._build_openai_response(ctx)
+
+    async def _build_openai_response(
         self,
         ctx: EmbeddingServeContext,
     ) -> JSONResponse | StreamingResponse:
@@ -66,7 +78,7 @@ class ServingEmbedding(PoolingServing):
         endianness = ctx.request.endianness

         if encoding_format == "float" or encoding_format == "base64":
-            return self._request_output_to_embed_json_response(
+            return self._openai_json_response(
                 ctx.final_res_batch,
                 ctx.request_id,
                 ctx.created_time,
@@ -77,7 +89,7 @@ class ServingEmbedding(PoolingServing):
             )

         if encoding_format == "bytes" or encoding_format == "bytes_only":
-            return self._request_output_to_to_embed_bytes_response(
+            return self._openai_bytes_response(
                 ctx.final_res_batch,
                 ctx.request_id,
                 ctx.created_time,
@@ -89,7 +101,7 @@ class ServingEmbedding(PoolingServing):

         assert_never(encoding_format)

-    def _request_output_to_embed_json_response(
+    def _openai_json_response(
         self,
         final_res_batch: list[PoolingRequestOutput],
         request_id: str,
@@ -139,7 +151,7 @@ class ServingEmbedding(PoolingServing):
         )
         return JSONResponseCLS(content=response.model_dump())

-    def _request_output_to_to_embed_bytes_response(
+    def _openai_bytes_response(
         self,
         final_res_batch: list[PoolingRequestOutput],
         request_id: str,
@@ -177,3 +189,33 @@ class ServingEmbedding(PoolingServing):
             headers=response.headers,
             media_type=response.media_type,
         )
+
+    @staticmethod
+    def _build_cohere_response_from_ctx(
+        ctx: PoolingServeContext,
+    ) -> JSONResponse:
+        request = ctx.request
+        assert isinstance(request, CohereEmbedRequest)
+
+        all_floats = [encode_pooling_output_float(out) for out in ctx.final_res_batch]
+        total_tokens = sum(len(out.prompt_token_ids) for out in ctx.final_res_batch)
+
+        image_tokens = total_tokens if request.images is not None else 0
+        texts_echo = request.texts
+
+        embedding_types = request.embedding_types or ["float"]
+        embeddings_obj = build_typed_embeddings(all_floats, embedding_types)
+
+        input_tokens = total_tokens - image_tokens
+        response = CohereEmbedResponse(
+            id=ctx.request_id,
+            embeddings=embeddings_obj,
+            texts=texts_echo,
+            meta=CohereMeta(
+                billed_units=CohereBilledUnits(
+                    input_tokens=input_tokens,
+                    image_tokens=image_tokens,
+                ),
+            ),
+        )
+        return JSONResponse(content=response.model_dump(exclude_none=True))
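Reviewer note: the token accounting in `_build_cohere_response_from_ctx` is all-or-nothing. When the request's `images` field is set, every prompt token is reported as an image token; otherwise every token is reported as an input token. The sketch below restates that rule as a standalone function. The name `billed_units` and the plain-dict return value are illustrative only and are not part of the commit.

```python
def billed_units(prompt_token_counts: list[int], has_images: bool) -> dict[str, int]:
    """Split the total prompt token count the way the handler above does."""
    total_tokens = sum(prompt_token_counts)
    # All tokens are attributed to images when the request carried `images`.
    image_tokens = total_tokens if has_images else 0
    return {
        "input_tokens": total_tokens - image_tokens,
        "image_tokens": image_tokens,
    }


text_only = billed_units([3, 5], has_images=False)
with_images = billed_units([3, 5], has_images=True)
```

Under this rule exactly one of the two counters is non-zero for any request with a non-empty prompt.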
@@ -36,6 +36,7 @@ class PoolingCompletionRequest(
             max_total_tokens=model_config.max_model_len,
             max_output_tokens=0,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
             do_lower_case=encoder_config.get("do_lower_case", False),
             add_special_tokens=self.add_special_tokens,
             max_total_tokens_param="max_model_len",
@@ -61,6 +62,7 @@ class PoolingChatRequest(
             max_total_tokens=model_config.max_model_len,
             max_output_tokens=0,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
             do_lower_case=encoder_config.get("do_lower_case", False),
             add_special_tokens=self.add_special_tokens,
             max_total_tokens_param="max_model_len",
@@ -88,6 +90,7 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
             max_total_tokens=model_config.max_model_len,
             max_output_tokens=0,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
             do_lower_case=encoder_config.get("do_lower_case", False),
             add_special_tokens=not model_config.is_encoder_decoder,
             max_total_tokens_param="max_model_len",
@@ -30,6 +30,7 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
             max_total_tokens=model_config.max_model_len,
             max_output_tokens=0,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
             do_lower_case=encoder_config.get("do_lower_case", False),
             max_total_tokens_param="max_model_len",
         )
@@ -105,6 +106,7 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
             max_total_tokens=model_config.max_model_len,
             max_output_tokens=0,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
             do_lower_case=encoder_config.get("do_lower_case", False),
             max_total_tokens_param="max_model_len",
         )
@@ -15,6 +15,7 @@ from vllm.entrypoints.pooling.classify.protocol import (
     ClassificationResponse,
 )
 from vllm.entrypoints.pooling.embed.protocol import (
+    CohereEmbedRequest,
     EmbeddingBytesResponse,
     EmbeddingChatRequest,
     EmbeddingCompletionRequest,
@@ -50,6 +51,7 @@ AnyPoolingRequest: TypeAlias = (
     | IOProcessorRequest
     | RerankRequest
     | ScoreRequest
+    | CohereEmbedRequest
 )

 AnyPoolingResponse: TypeAlias = (
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, TypeVar
+from typing import TYPE_CHECKING, Any, Literal, TypeVar

 from vllm.exceptions import VLLMValidationError
 from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
@@ -153,6 +153,14 @@ class TokenizeParams:
     - `-1` maps to `max_input_tokens`.
     """

+    truncation_side: Literal["left", "right"] | None = None
+    """
+    Which side to truncate from when ``truncate_prompt_tokens`` is active:
+    - ``"right"`` keeps the first N tokens (truncate from the end).
+    - ``"left"`` keeps the last N tokens (truncate from the start).
+    - ``None`` falls back to the tokenizer default.
+    """
+
     do_lower_case: bool = False
     """Whether to normalize text to lower case before tokenization."""

@@ -271,6 +279,7 @@ class TokenizeParams:
             ),
             pad_prompt_tokens=pad_prompt_tokens,
             truncate_prompt_tokens=truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
             do_lower_case=do_lower_case,
             add_special_tokens=add_special_tokens,
             needs_detokenization=needs_detokenization,
@@ -286,6 +295,16 @@ class TokenizeParams:
             # while still failing `self._token_len_check` as expected by users
             max_length = self.max_input_tokens + 1

+        # Left-side truncation requires the full token sequence so we can
+        # slice from the end in _token_truncation. Disable HF-level
+        # truncation (which would incorrectly truncate from the right for
+        # pooling models) and let _token_truncation handle it.
+        if self.truncation_side == "left":
+            return dict(
+                truncation=False,
+                add_special_tokens=self.add_special_tokens,
+            )
+
         return dict(
             truncation=max_length is not None,
             max_length=max_length,
@@ -375,7 +394,10 @@ class TokenizeParams:
         if max_length == 0:
             return tokens[:0]

-        if getattr(tokenizer, "truncation_side", "left") == "left":
+        side = self.truncation_side or (
+            tokenizer.truncation_side if tokenizer is not None else None
+        )
+        if side == "left":
             return tokens[-max_length:]

         return tokens[:max_length]
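Reviewer note: the new side-selection logic in the last hunk can be exercised without a tokenizer. The sketch below mirrors it as a free function. `truncate_tokens` and its `tokenizer_default` parameter are hypothetical stand-ins for the method and for `tokenizer.truncation_side`; the sketch assumes `max_length` has already been resolved to a non-negative integer.

```python
from typing import Literal, Optional

Side = Literal["left", "right"]


def truncate_tokens(
    tokens: list[int],
    max_length: int,
    truncation_side: Optional[Side] = None,
    tokenizer_default: Optional[Side] = None,
) -> list[int]:
    """Keep at most `max_length` tokens from the chosen side."""
    # Guard first: tokens[-0:] would return the whole list, not an empty one.
    if max_length == 0:
        return tokens[:0]
    # The explicit request setting wins; otherwise use the tokenizer default.
    side = truncation_side or tokenizer_default
    if side == "left":
        return tokens[-max_length:]  # drop from the start, keep the tail
    return tokens[:max_length]  # drop from the end, keep the head


ids = [1, 2, 3, 4, 5]
kept_tail = truncate_tokens(ids, 3, truncation_side="left")
kept_head = truncate_tokens(ids, 3, truncation_side="right")
fallback = truncate_tokens(ids, 3, tokenizer_default="left")
```

One behavioural difference from the removed line is visible here: the old `getattr(tokenizer, "truncation_side", "left")` defaulted to left truncation when the attribute was missing, whereas the new code falls through to right truncation when neither the request nor the tokenizer specifies a side.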