[Feature][Frontend] add support for Cohere Embed v2 API (#37074)

Signed-off-by: walterbm <walter.beller.morales@gmail.com>
(cherry picked from commit 061980c36a)
Walter Beller-Morales
2026-03-16 19:55:53 -04:00
committed by khluu
parent 1fe3932c8b
commit 4d22667c32
16 changed files with 1609 additions and 40 deletions
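This change adds a Cohere-compatible /v2/embed route alongside the existing OpenAI /v1/embeddings route. A minimal request sketch against the new endpoint, assuming a locally running vLLM server; the host, port, and model name below are placeholders, not taken from this diff:

import requests

resp = requests.post(
    "http://localhost:8000/v2/embed",
    json={
        "model": "my-embedding-model",          # placeholder model name
        "texts": ["hello world", "vLLM embeddings"],
        "input_type": "search_query",           # optional; only accepted when the model defines task instructions
        "embedding_types": ["float", "base64"],  # defaults to ["float"] when omitted
        "truncate": "END",                       # one of NONE, START, END (default END)
    },
)
body = resp.json()
print(body["response_type"])                     # "embeddings_by_type"
print(body["embeddings"]["float"][0][:4])        # first few dimensions of the first embedding
print(body["meta"]["billed_units"]["input_tokens"])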

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Annotated, Any
from typing import Annotated, Any, Literal
from pydantic import Field, model_validator
@@ -24,6 +24,14 @@ class PoolingBasicRequestMixin(OpenAIBaseModel):
# --8<-- [start:pooling-common-extra-params]
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
truncation_side: Literal["left", "right"] | None = Field(
default=None,
description=(
"Which side to truncate from when truncate_prompt_tokens is active. "
"'right' keeps the first N tokens. "
"'left' keeps the last N tokens."
),
)
request_id: str = Field(
default_factory=random_uuid,
description=(

View File

@@ -32,6 +32,7 @@ class ClassificationCompletionRequest(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
truncation_side=self.truncation_side,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
@@ -54,6 +55,7 @@ class ClassificationChatRequest(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
truncation_side=self.truncation_side,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",

View File

@@ -7,12 +7,12 @@ from fastapi import APIRouter, Depends, Request
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
from vllm.entrypoints.utils import (
load_aware_call,
with_cancellation,
from vllm.entrypoints.pooling.embed.protocol import (
CohereEmbedRequest,
EmbeddingRequest,
)
from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
from vllm.entrypoints.utils import load_aware_call, with_cancellation
router = APIRouter()
@@ -40,3 +40,24 @@ async def create_embedding(
raise NotImplementedError("The model does not support Embeddings API")
return await handler(request, raw_request)
@router.post(
"/v2/embed",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_cohere_embedding(
request: CohereEmbedRequest,
raw_request: Request,
):
handler = embedding(raw_request)
if handler is None:
raise NotImplementedError("The model does not support Embeddings API")
return await handler(request, raw_request)

View File

@@ -1,14 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, cast
from collections.abc import Sequence
from typing import Any, Literal, cast
import torch
from openai.types.chat import (
ChatCompletionContentPartImageParam,
ChatCompletionContentPartTextParam,
)
from openai.types.chat.chat_completion_content_part_image_param import ImageURL
from vllm import PoolingParams
from vllm.entrypoints.chat_utils import (
ChatCompletionContentPartParam,
ChatCompletionMessageParam,
CustomChatCompletionMessageParam,
)
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.embed.protocol import (
CohereEmbedInput,
CohereEmbedRequest,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
)
from vllm.entrypoints.pooling.typing import PoolingServeContext
from vllm.inputs.data import ProcessorInputs, token_inputs
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.renderers import merge_kwargs
from vllm.utils.collection_utils import chunk_list
from vllm.utils.mistral import is_mistral_tokenizer
logger = init_logger(__name__)
class EmbedIOProcessor(PoolingIOProcessor):
@@ -21,16 +44,45 @@ class EmbedIOProcessor(PoolingIOProcessor):
self.pooler_config = self.model_config.pooler_config
self.enable_chunked_processing = self.pooler_config.enable_chunked_processing
# Load task instructions from HF config or sentence-transformers config
self.task_instructions: dict[str, str] | None = self._load_task_instructions(
self.model_config.hf_config
) or self._load_st_prompts(self.model_config.model, self.model_config.revision)
if self.task_instructions:
logger.info(
"Loaded prompt prefixes for input_type: %s",
list(self.task_instructions.keys()),
)
def pre_process_online(self, ctx: PoolingServeContext):
if isinstance(ctx.request, CohereEmbedRequest):
self._pre_process_cohere_online(ctx)
else:
super().pre_process_online(ctx)
if self.enable_chunked_processing:
self._pre_process_chunked(ctx)
def post_process_online(
self,
ctx: PoolingServeContext,
):
if ctx.final_res_batch is None:
raise ValueError("Final response batch not available")
if not self.enable_chunked_processing:
self._enforce_cohere_max_tokens(ctx)
return super().post_process_online(ctx)
self._post_process_chunked(ctx)
self._enforce_cohere_max_tokens(ctx)
#################################################################
# Long Text Embedding with Chunked Processing
# PTAL: examples/pooling/embed/openai_embedding_long_text
#################################################################
def pre_process_online(self, ctx: PoolingServeContext):
super().pre_process_online(ctx)
if not self.enable_chunked_processing:
return None
def _pre_process_chunked(self, ctx: PoolingServeContext) -> None:
if ctx.engine_prompts is None:
raise ValueError("Engine prompts not available")
@@ -61,18 +113,10 @@ class EmbedIOProcessor(PoolingIOProcessor):
ctx.engine_prompts = chunked_engine_prompts
ctx.prompt_request_ids = prompt_request_ids
return None
def post_process_online(
self,
ctx: PoolingServeContext,
):
if ctx.final_res_batch is None:
raise ValueError("Final response batch not available")
if not self.enable_chunked_processing:
return super().post_process_online(ctx)
def _post_process_chunked(self, ctx: PoolingServeContext) -> None:
# Online aggregation for chunked requests to
# minimize memory usage
# Track aggregation state for each prompt
@@ -195,4 +239,245 @@ class EmbedIOProcessor(PoolingIOProcessor):
raise ValueError(f"Result not found for prompt {prompt_idx}")
ctx.final_res_batch = final_res_batch
return None
#################################################################
# Cohere Request Preprocessing & Postprocessing
#################################################################
@staticmethod
def _load_task_instructions(hf_config: Any) -> dict[str, str] | None:
"""Extract ``task_instructions`` from the HF model config."""
ti = getattr(hf_config, "task_instructions", None)
if not isinstance(ti, dict) or not ti:
return None
return {k: v for k, v in ti.items() if isinstance(v, str)}
@staticmethod
def _load_st_prompts(
model: str | Any,
revision: str | None,
) -> dict[str, str] | None:
"""Load ``task_instructions`` from ``config_sentence_transformers.json``."""
from vllm.transformers_utils.repo_utils import get_hf_file_to_dict
try:
cfg = get_hf_file_to_dict(
"config_sentence_transformers.json", str(model), revision
)
except (ValueError, OSError):
return None
if cfg is None:
return None
prompts = cfg.get("prompts")
if not isinstance(prompts, dict) or not prompts:
return None
return {k: v for k, v in prompts.items() if isinstance(v, str)}
@staticmethod
def _mixed_input_to_messages(
inp: CohereEmbedInput,
*,
task_prefix: str | None = None,
) -> list[ChatCompletionMessageParam]:
"""Build chat messages from a mixed text+image input.
When *task_prefix* is given, it is prepended to each text part.
"""
parts: list[ChatCompletionContentPartParam] = []
for item in inp.content:
if item.type == "text" and item.text is not None:
text = task_prefix + item.text if task_prefix else item.text
parts.append(ChatCompletionContentPartTextParam(type="text", text=text))
elif item.type == "image_url" and item.image_url is not None:
parts.append(
ChatCompletionContentPartImageParam(
type="image_url",
image_url=ImageURL(url=item.image_url["url"]),
)
)
return [CustomChatCompletionMessageParam(role="user", content=parts)]
@staticmethod
def _check_cohere_max_tokens(
outputs: list[PoolingRequestOutput],
max_tokens_check: int | None,
) -> None:
"""Raise if any output exceeds *max_tokens_check* tokens.
Used to enforce ``truncate=NONE`` with an explicit ``max_tokens``:
the pipeline runs without truncation and we reject afterwards.
"""
if max_tokens_check is None:
return
for out in outputs:
n = len(out.prompt_token_ids)
if n > max_tokens_check:
raise ValueError(
f"Input of {n} tokens exceeds max_tokens={max_tokens_check} "
"with truncate=NONE. Set truncate to END or START to "
"allow truncation."
)
@staticmethod
def _resolve_cohere_truncation(
request: CohereEmbedRequest,
) -> tuple[int | None, Literal["left", "right"] | None]:
"""Return ``(truncate_prompt_tokens, truncation_side)``."""
if request.truncate == "NONE":
return None, None
if request.truncate == "START":
tokens = request.max_tokens if request.max_tokens is not None else -1
return tokens, "left"
if request.max_tokens is not None:
return request.max_tokens, None
return -1, None
def create_pooling_params(self, request):
if isinstance(request, CohereEmbedRequest):
return PoolingParams(
task="embed",
dimensions=request.output_dimension,
)
return super().create_pooling_params(request)
def _pre_process_cohere_online(self, ctx: PoolingServeContext) -> None:
"""Convert a ``CohereEmbedRequest`` into engine prompts.
For texts, a single batched completion request path is used.
For images and mixed inputs, conversations are batch-rendered
through the chat template in one ``render_chat`` call.
"""
request = ctx.request
assert isinstance(request, CohereEmbedRequest)
if request.texts is None and request.images is None and request.inputs is None:
raise ValueError("One of texts, images, or inputs must be provided")
truncate_prompt_tokens, truncation_side = self._resolve_cohere_truncation(
request
)
input_type = request.input_type
self._validate_input_type(input_type)
if request.images is not None:
all_messages: list[list[ChatCompletionMessageParam]] = [
[
CustomChatCompletionMessageParam(
role="user",
content=[{"type": "image_url", "image_url": {"url": uri}}],
)
]
for uri in request.images
]
ctx.engine_prompts = self._batch_render_chat(
request, all_messages, truncate_prompt_tokens, truncation_side
)
elif request.inputs is not None:
task_prefix = self._get_task_instruction_prefix(input_type)
all_messages = [
self._mixed_input_to_messages(inp, task_prefix=task_prefix)
for inp in request.inputs
]
ctx.engine_prompts = self._batch_render_chat(
request, all_messages, truncate_prompt_tokens, truncation_side
)
else:
prefixed = self._apply_task_instruction(request.texts or [], input_type)
proxy = EmbeddingCompletionRequest(
model=request.model,
input=prefixed,
dimensions=request.output_dimension,
encoding_format="float",
truncate_prompt_tokens=truncate_prompt_tokens,
truncation_side=truncation_side,
)
ctx.engine_prompts = self._preprocess_completion_online(
proxy, prompt_input=proxy.input, prompt_embeds=None
)
def _batch_render_chat(
self,
request: CohereEmbedRequest,
all_messages: Sequence[list[ChatCompletionMessageParam]],
truncate_prompt_tokens: int | None,
truncation_side: Literal["left", "right"] | None,
) -> list[ProcessorInputs]:
"""Batch-render multiple conversations through the chat template."""
if not all_messages:
return []
proxy = EmbeddingChatRequest(
model=request.model,
messages=list(all_messages[0]),
dimensions=request.output_dimension,
encoding_format="float",
truncate_prompt_tokens=truncate_prompt_tokens,
truncation_side=truncation_side,
)
renderer = self.renderer
mm_config = self.model_config.multimodal_config
tok_params = proxy.build_tok_params(self.model_config)
chat_params = proxy.build_chat_params(
self.chat_template,
self.chat_template_content_format,
).with_defaults(
merge_kwargs(
None,
dict(
tools=None,
tokenize=is_mistral_tokenizer(renderer.tokenizer),
),
),
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
)
_, engine_prompts = renderer.render_chat(all_messages, chat_params, tok_params)
return engine_prompts
def _validate_input_type(self, input_type: str | None) -> None:
"""Raise if *input_type* is not supported by this model."""
if input_type is None:
return
if self.task_instructions is None:
raise ValueError(
f"Unsupported input_type {input_type!r}. "
"This model does not define any input_type task instructions."
)
if input_type not in self.task_instructions:
supported = ", ".join(sorted(self.task_instructions))
raise ValueError(
f"Unsupported input_type {input_type!r}. Supported values: {supported}"
)
def _apply_task_instruction(
self,
texts: list[str],
input_type: str | None,
) -> list[str]:
"""Prepend the task-instruction prefix for *input_type*.
Returns *texts* unchanged when no matching prefix is configured.
"""
prefix = self._get_task_instruction_prefix(input_type)
if not prefix:
return texts
return [prefix + t for t in texts]
def _get_task_instruction_prefix(self, input_type: str | None) -> str | None:
"""Return the task-instruction prefix for *input_type*, or ``None``."""
if not self.task_instructions or input_type is None:
return None
return self.task_instructions.get(input_type) or None
def _enforce_cohere_max_tokens(self, ctx: PoolingServeContext) -> None:
if isinstance(ctx.request, CohereEmbedRequest):
request = ctx.request
if request.truncate == "NONE" and request.max_tokens is not None:
self._check_cohere_max_tokens(ctx.final_res_batch, request.max_tokens)
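
To summarize the preprocessing above: the Cohere truncate/max_tokens pair is mapped onto the internal truncate_prompt_tokens/truncation_side, and input_type selects an optional task-instruction prefix for text inputs. A small illustrative sketch; the mapping comments restate the helper above, and the prefix values are made up (real ones come from hf_config.task_instructions or the "prompts" section of config_sentence_transformers.json):

# truncate / max_tokens       -> (truncate_prompt_tokens, truncation_side)
# "NONE",  max_tokens=None    -> (None, None)      no truncation
# "NONE",  max_tokens=512     -> (None, None)      inputs over 512 tokens are rejected after pooling
# "START", max_tokens=512     -> (512,  "left")    keep the last 512 tokens
# "START", max_tokens=None    -> (-1,   "left")    -1 maps to the model's max input length
# "END",   max_tokens=512     -> (512,  None)      side falls back to the tokenizer default
# "END",   max_tokens=None    -> (-1,   None)

task_instructions = {"search_query": "query: ", "search_document": "passage: "}  # illustrative values
texts = ["what is vllm?"]
prefix = task_instructions.get("search_query")
prefixed = [prefix + t for t in texts] if prefix else texts
print(prefixed)  # ['query: what is vllm?']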

View File

@@ -1,9 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from typing import TypeAlias
from pydantic import Field
"""Embedding API protocol models for OpenAI and Cohere formats.
OpenAI: https://platform.openai.com/docs/api-reference/embeddings
Cohere: https://docs.cohere.com/reference/embed
"""
import base64
import builtins
import struct
import time
from collections.abc import Sequence
from typing import Literal, TypeAlias
from pydantic import BaseModel, Field
from vllm import PoolingParams
from vllm.config import ModelConfig
@@ -17,6 +27,10 @@ from vllm.entrypoints.pooling.base.protocol import (
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
# ---------------------------------------------------------------------------
# OpenAI /v1/embeddings — request models
# ---------------------------------------------------------------------------
def _get_max_total_output_tokens(
model_config: ModelConfig,
@@ -50,6 +64,7 @@ class EmbeddingCompletionRequest(
max_total_tokens=max_total_tokens,
max_output_tokens=max_output_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
truncation_side=self.truncation_side,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
@@ -79,6 +94,7 @@ class EmbeddingChatRequest(
max_total_tokens=max_total_tokens,
max_output_tokens=max_output_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
truncation_side=self.truncation_side,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
@@ -96,6 +112,11 @@ class EmbeddingChatRequest(
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
# ---------------------------------------------------------------------------
# OpenAI /v1/embeddings — response models
# ---------------------------------------------------------------------------
class EmbeddingResponseData(OpenAIBaseModel):
index: int
object: str = "embedding"
@@ -106,7 +127,7 @@ class EmbeddingResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
model: str | None = None
data: list[EmbeddingResponseData]
usage: UsageInfo
@@ -115,3 +136,146 @@ class EmbeddingBytesResponse(OpenAIBaseModel):
content: list[bytes]
headers: dict[str, str] | None = None
media_type: str = "application/octet-stream"
# ---------------------------------------------------------------------------
# Cohere /v2/embed — request models
# ---------------------------------------------------------------------------
CohereEmbeddingType = Literal[
"float",
"binary",
"ubinary",
"base64",
]
CohereTruncate = Literal["NONE", "START", "END"]
class CohereEmbedContent(BaseModel):
type: Literal["text", "image_url"]
text: str | None = None
image_url: dict[str, str] | None = None
class CohereEmbedInput(BaseModel):
content: list[CohereEmbedContent]
class CohereEmbedRequest(BaseModel):
model: str | None = None
input_type: str | None = None
texts: list[str] | None = None
images: list[str] | None = None
inputs: list[CohereEmbedInput] | None = None
output_dimension: int | None = None
embedding_types: list[CohereEmbeddingType] | None = None
truncate: CohereTruncate = "END"
max_tokens: int | None = None
priority: int = 0
# ---------------------------------------------------------------------------
# Cohere /v2/embed — response models
# ---------------------------------------------------------------------------
class CohereApiVersion(BaseModel):
version: str = "2"
class CohereBilledUnits(BaseModel):
input_tokens: int | None = None
image_tokens: int | None = None
class CohereMeta(BaseModel):
api_version: CohereApiVersion = Field(default_factory=CohereApiVersion)
billed_units: CohereBilledUnits | None = None
class CohereEmbedByTypeEmbeddings(BaseModel):
# The field name ``float`` shadows the builtin type, so the annotation
# must use ``builtins.float`` to avoid a self-referential type error.
float: list[list[builtins.float]] | None = None
binary: list[list[int]] | None = None
ubinary: list[list[int]] | None = None
base64: list[str] | None = None
class CohereEmbedResponse(BaseModel):
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
embeddings: CohereEmbedByTypeEmbeddings
texts: list[str] | None = None
meta: CohereMeta | None = None
response_type: Literal["embeddings_by_type"] = "embeddings_by_type"
# ---------------------------------------------------------------------------
# Cohere embedding type conversion helpers
# ---------------------------------------------------------------------------
_UNSIGNED_TO_SIGNED_DIFF = 1 << 7 # 128
def _pack_binary_embeddings(
float_embeddings: list[list[float]],
signed: bool,
) -> list[list[int]]:
"""Bit-pack float embeddings: positive -> 1, negative -> 0.
Each bit is shifted left by ``7 - idx%8``, and every 8 bits are packed
into one byte.
"""
result: list[list[int]] = []
for embedding in float_embeddings:
dim = len(embedding)
if dim % 8 != 0:
raise ValueError(
"Embedding dimension must be a multiple of 8 for binary "
f"embedding types, but got {dim}."
)
packed_len = dim // 8
packed: list[int] = []
byte_val = 0
for idx, value in enumerate(embedding):
bit = 1 if value >= 0 else 0
byte_val += bit << (7 - idx % 8)
if (idx + 1) % 8 == 0:
if signed:
byte_val -= _UNSIGNED_TO_SIGNED_DIFF
packed.append(byte_val)
byte_val = 0
assert len(packed) == packed_len
result.append(packed)
return result
def _encode_base64_embeddings(
float_embeddings: list[list[float]],
) -> list[str]:
"""Encode float embeddings as base64 (little-endian float32)."""
result: list[str] = []
for embedding in float_embeddings:
buf = struct.pack(f"<{len(embedding)}f", *embedding)
result.append(base64.b64encode(buf).decode("utf-8"))
return result
def build_typed_embeddings(
float_embeddings: list[list[float]],
embedding_types: Sequence[str],
) -> CohereEmbedByTypeEmbeddings:
"""Convert float embeddings to all requested Cohere embedding types."""
result = CohereEmbedByTypeEmbeddings()
for emb_type in embedding_types:
if emb_type == "float":
result.float = float_embeddings
elif emb_type == "binary":
result.binary = _pack_binary_embeddings(float_embeddings, signed=True)
elif emb_type == "ubinary":
result.ubinary = _pack_binary_embeddings(float_embeddings, signed=False)
elif emb_type == "base64":
result.base64 = _encode_base64_embeddings(float_embeddings)
return result
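
The binary and ubinary types bit-pack the sign pattern of the embedding, eight dimensions per byte (MSB first), while base64 serializes the raw little-endian float32 buffer. A tiny standalone worked example, re-implemented inline to mirror the helpers above:

import base64
import struct

emb = [0.5, -0.1, 0.2, -0.3, 0.7, -0.9, 0.1, -0.4]  # dimension 8 -> exactly one packed byte

# ubinary: non-negative -> 1, negative -> 0, packed MSB-first within each byte
byte_val = 0
for idx, value in enumerate(emb):
    bit = 1 if value >= 0 else 0
    byte_val += bit << (7 - idx % 8)
print(byte_val)        # 0b10101010 == 170
print(byte_val - 128)  # signed "binary" variant: 42

# base64: raw little-endian float32 bytes of the embedding
buf = struct.pack(f"<{len(emb)}f", *emb)
print(base64.b64encode(buf).decode("utf-8"))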

View File

@@ -5,7 +5,7 @@ from collections.abc import Callable
from functools import partial
from typing import Literal, TypeAlias, cast
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import JSONResponse, Response, StreamingResponse
from typing_extensions import assert_never
from vllm.config import ModelConfig
@@ -14,10 +14,15 @@ from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
from vllm.entrypoints.pooling.embed.protocol import (
CohereBilledUnits,
CohereEmbedRequest,
CohereEmbedResponse,
CohereMeta,
EmbeddingBytesResponse,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
build_typed_embeddings,
)
from vllm.entrypoints.pooling.typing import PoolingServeContext
from vllm.entrypoints.pooling.utils import (
@@ -26,24 +31,23 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_float,
get_json_response_cls,
)
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.renderers import BaseRenderer
from vllm.utils.serial_utils import EmbedDType, Endianness
logger = init_logger(__name__)
JSONResponseCLS = get_json_response_cls()
EmbeddingServeContext: TypeAlias = PoolingServeContext[EmbeddingRequest]
class ServingEmbedding(PoolingServing):
"""
Embedding API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
"""Embedding API supporting both OpenAI and Cohere formats."""
request_id_prefix = "embd"
io_processor: EmbedIOProcessor
def init_io_processor(
self,
@@ -58,6 +62,14 @@ class ServingEmbedding(PoolingServing):
)
async def _build_response(
self,
ctx: PoolingServeContext,
) -> Response:
if isinstance(ctx.request, CohereEmbedRequest):
return self._build_cohere_response_from_ctx(ctx)
return await self._build_openai_response(ctx)
async def _build_openai_response(
self,
ctx: EmbeddingServeContext,
) -> JSONResponse | StreamingResponse:
@@ -66,7 +78,7 @@ class ServingEmbedding(PoolingServing):
endianness = ctx.request.endianness
if encoding_format == "float" or encoding_format == "base64":
return self._request_output_to_embed_json_response(
return self._openai_json_response(
ctx.final_res_batch,
ctx.request_id,
ctx.created_time,
@@ -77,7 +89,7 @@ class ServingEmbedding(PoolingServing):
)
if encoding_format == "bytes" or encoding_format == "bytes_only":
return self._request_output_to_to_embed_bytes_response(
return self._openai_bytes_response(
ctx.final_res_batch,
ctx.request_id,
ctx.created_time,
@@ -89,7 +101,7 @@ class ServingEmbedding(PoolingServing):
assert_never(encoding_format)
def _request_output_to_embed_json_response(
def _openai_json_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
@@ -139,7 +151,7 @@ class ServingEmbedding(PoolingServing):
)
return JSONResponseCLS(content=response.model_dump())
def _request_output_to_to_embed_bytes_response(
def _openai_bytes_response(
self,
final_res_batch: list[PoolingRequestOutput],
request_id: str,
@@ -177,3 +189,33 @@ class ServingEmbedding(PoolingServing):
headers=response.headers,
media_type=response.media_type,
)
@staticmethod
def _build_cohere_response_from_ctx(
ctx: PoolingServeContext,
) -> JSONResponse:
request = ctx.request
assert isinstance(request, CohereEmbedRequest)
all_floats = [encode_pooling_output_float(out) for out in ctx.final_res_batch]
total_tokens = sum(len(out.prompt_token_ids) for out in ctx.final_res_batch)
image_tokens = total_tokens if request.images is not None else 0
texts_echo = request.texts
embedding_types = request.embedding_types or ["float"]
embeddings_obj = build_typed_embeddings(all_floats, embedding_types)
input_tokens = total_tokens - image_tokens
response = CohereEmbedResponse(
id=ctx.request_id,
embeddings=embeddings_obj,
texts=texts_echo,
meta=CohereMeta(
billed_units=CohereBilledUnits(
input_tokens=input_tokens,
image_tokens=image_tokens,
),
),
)
return JSONResponse(content=response.model_dump(exclude_none=True))
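
For reference, the response body assembled here can be reproduced directly from the protocol models. A rough sketch, assuming this branch of vLLM is importable; the token counts are made up for illustration:

from vllm.entrypoints.pooling.embed.protocol import (
    CohereBilledUnits,
    CohereEmbedResponse,
    CohereMeta,
    build_typed_embeddings,
)

floats = [[0.5, -0.1, 0.2, -0.3, 0.7, -0.9, 0.1, -0.4]]
response = CohereEmbedResponse(
    embeddings=build_typed_embeddings(floats, ["float", "ubinary"]),
    texts=["hello world"],
    meta=CohereMeta(billed_units=CohereBilledUnits(input_tokens=4, image_tokens=0)),
)
print(response.model_dump(exclude_none=True))
# {'id': 'embd-...', 'embeddings': {'float': [[...]], 'ubinary': [[170]]},
#  'texts': ['hello world'], 'meta': {...}, 'response_type': 'embeddings_by_type'}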

View File

@@ -36,6 +36,7 @@ class PoolingCompletionRequest(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
truncation_side=self.truncation_side,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
@@ -61,6 +62,7 @@ class PoolingChatRequest(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
truncation_side=self.truncation_side,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
@@ -88,6 +90,7 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
truncation_side=self.truncation_side,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=not model_config.is_encoder_decoder,
max_total_tokens_param="max_model_len",

View File

@@ -30,6 +30,7 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
truncation_side=self.truncation_side,
do_lower_case=encoder_config.get("do_lower_case", False),
max_total_tokens_param="max_model_len",
)
@@ -105,6 +106,7 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
truncation_side=self.truncation_side,
do_lower_case=encoder_config.get("do_lower_case", False),
max_total_tokens_param="max_model_len",
)

View File

@@ -15,6 +15,7 @@ from vllm.entrypoints.pooling.classify.protocol import (
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
CohereEmbedRequest,
EmbeddingBytesResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
@@ -50,6 +51,7 @@ AnyPoolingRequest: TypeAlias = (
| IOProcessorRequest
| RerankRequest
| ScoreRequest
| CohereEmbedRequest
)
AnyPoolingResponse: TypeAlias = (

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Any, Literal, TypeVar
from vllm.exceptions import VLLMValidationError
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
@@ -153,6 +153,14 @@ class TokenizeParams:
- `-1` maps to `max_input_tokens`.
"""
truncation_side: Literal["left", "right"] | None = None
"""
Which side to truncate from when ``truncate_prompt_tokens`` is active:
- ``"right"`` keeps the first N tokens (truncate from the end).
- ``"left"`` keeps the last N tokens (truncate from the start).
- ``None`` falls back to the tokenizer default.
"""
do_lower_case: bool = False
"""Whether to normalize text to lower case before tokenization."""
@@ -271,6 +279,7 @@ class TokenizeParams:
),
pad_prompt_tokens=pad_prompt_tokens,
truncate_prompt_tokens=truncate_prompt_tokens,
truncation_side=self.truncation_side,
do_lower_case=do_lower_case,
add_special_tokens=add_special_tokens,
needs_detokenization=needs_detokenization,
@@ -286,6 +295,16 @@ class TokenizeParams:
# while still failing `self._token_len_check` as expected by users
max_length = self.max_input_tokens + 1
# Left-side truncation requires the full token sequence so we can
# slice from the end in _token_truncation. Disable HF-level
# truncation (which would incorrectly truncate from the right for
# pooling models) and let _token_truncation handle it.
if self.truncation_side == "left":
return dict(
truncation=False,
add_special_tokens=self.add_special_tokens,
)
return dict(
truncation=max_length is not None,
max_length=max_length,
@@ -375,7 +394,10 @@ class TokenizeParams:
if max_length == 0:
return tokens[:0]
if getattr(tokenizer, "truncation_side", "left") == "left":
side = self.truncation_side or (
tokenizer.truncation_side if tokenizer is not None else None
)
if side == "left":
return tokens[-max_length:]
return tokens[:max_length]
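
The net effect of the new truncation_side handling on a token sequence, shown as a small standalone illustration of the slicing above (the token ids are made up):

tokens = [101, 7592, 2088, 2003, 1037, 3231, 102]  # 7 tokens
max_length = 4

print(tokens[-max_length:])  # truncation_side="left":  keep the last 4  -> [2003, 1037, 3231, 102]
print(tokens[:max_length])   # truncation_side="right": keep the first 4 -> [101, 7592, 2088, 2003]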