[Feature][Frontend] add support for Cohere Embed v2 API (#37074)
Signed-off-by: walterbm <walter.beller.morales@gmail.com>
(cherry picked from commit 061980c36a)
Committed by khluu (parent 1fe3932c8b, commit 4d22667c32)

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Annotated, Any
+from typing import Annotated, Any, Literal

 from pydantic import Field, model_validator

@@ -24,6 +24,14 @@ class PoolingBasicRequestMixin(OpenAIBaseModel):

    # --8<-- [start:pooling-common-extra-params]
    truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
+    truncation_side: Literal["left", "right"] | None = Field(
+        default=None,
+        description=(
+            "Which side to truncate from when truncate_prompt_tokens is active. "
+            "'right' keeps the first N tokens. "
+            "'left' keeps the last N tokens."
+        ),
+    )
    request_id: str = Field(
        default_factory=random_uuid,
        description=(

@@ -32,6 +32,7 @@ class ClassificationCompletionRequest(
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            max_total_tokens_param="max_model_len",

@@ -54,6 +55,7 @@ class ClassificationChatRequest(
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            max_total_tokens_param="max_model_len",

@@ -7,12 +7,12 @@ from fastapi import APIRouter, Depends, Request

 from vllm.entrypoints.openai.engine.protocol import ErrorResponse
 from vllm.entrypoints.openai.utils import validate_json_request
-from vllm.entrypoints.pooling.embed.protocol import EmbeddingRequest
-from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
-from vllm.entrypoints.utils import (
-    load_aware_call,
-    with_cancellation,
-)
+from vllm.entrypoints.pooling.embed.protocol import (
+    CohereEmbedRequest,
+    EmbeddingRequest,
+)
+from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
+from vllm.entrypoints.utils import load_aware_call, with_cancellation

 router = APIRouter()

@@ -40,3 +40,24 @@ async def create_embedding(
        raise NotImplementedError("The model does not support Embeddings API")

    return await handler(request, raw_request)


@router.post(
    "/v2/embed",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_cohere_embedding(
    request: CohereEmbedRequest,
    raw_request: Request,
):
    handler = embedding(raw_request)
    if handler is None:
        raise NotImplementedError("The model does not support Embeddings API")

    return await handler(request, raw_request)
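
For reference, a minimal client-side sketch of calling the new route. The server address and all request values below are illustrative assumptions, not part of this change:

    # Hypothetical example: call the Cohere-compatible /v2/embed route
    # of a locally running vLLM server.
    import requests

    resp = requests.post(
        "http://localhost:8000/v2/embed",  # assumed server address
        json={
            "texts": ["hello world", "vLLM embeddings"],
            "input_type": "search_query",        # optional, model-dependent
            "embedding_types": ["float", "base64"],
            "truncate": "END",
        },
    )
    resp.raise_for_status()
    body = resp.json()
    print(len(body["embeddings"]["float"][0]))  # embedding dimension
    print(body["meta"]["billed_units"])         # token accounting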

@@ -1,14 +1,37 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, cast
+from collections.abc import Sequence
+from typing import Any, Literal, cast

import torch
from openai.types.chat import (
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartTextParam,
)
from openai.types.chat.chat_completion_content_part_image_param import ImageURL

from vllm import PoolingParams
from vllm.entrypoints.chat_utils import (
    ChatCompletionContentPartParam,
    ChatCompletionMessageParam,
    CustomChatCompletionMessageParam,
)
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.embed.protocol import (
    CohereEmbedInput,
    CohereEmbedRequest,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
)
from vllm.entrypoints.pooling.typing import PoolingServeContext
from vllm.inputs.data import ProcessorInputs, token_inputs
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.renderers import merge_kwargs
from vllm.utils.collection_utils import chunk_list
from vllm.utils.mistral import is_mistral_tokenizer

logger = init_logger(__name__)


class EmbedIOProcessor(PoolingIOProcessor):

@@ -21,16 +44,45 @@ class EmbedIOProcessor(PoolingIOProcessor):
        self.pooler_config = self.model_config.pooler_config
        self.enable_chunked_processing = self.pooler_config.enable_chunked_processing

        # Load task instructions from HF config or sentence-transformers config
        self.task_instructions: dict[str, str] | None = self._load_task_instructions(
            self.model_config.hf_config
        ) or self._load_st_prompts(self.model_config.model, self.model_config.revision)
        if self.task_instructions:
            logger.info(
                "Loaded prompt prefixes for input_type: %s",
                list(self.task_instructions.keys()),
            )

    def pre_process_online(self, ctx: PoolingServeContext):
        if isinstance(ctx.request, CohereEmbedRequest):
            self._pre_process_cohere_online(ctx)
        else:
            super().pre_process_online(ctx)

        if self.enable_chunked_processing:
            self._pre_process_chunked(ctx)

    def post_process_online(
        self,
        ctx: PoolingServeContext,
    ):
        if ctx.final_res_batch is None:
            raise ValueError("Final response batch not available")

        if not self.enable_chunked_processing:
            self._enforce_cohere_max_tokens(ctx)
            return super().post_process_online(ctx)

        self._post_process_chunked(ctx)
        self._enforce_cohere_max_tokens(ctx)

    #################################################################
    # Long Text Embedding with Chunked Processing
    # PTAL: examples/pooling/embed/openai_embedding_long_text
    #################################################################

-    def pre_process_online(self, ctx: PoolingServeContext):
-        super().pre_process_online(ctx)
-
-        if not self.enable_chunked_processing:
-            return None
-
    def _pre_process_chunked(self, ctx: PoolingServeContext) -> None:
        if ctx.engine_prompts is None:
            raise ValueError("Engine prompts not available")

@@ -61,18 +113,10 @@ class EmbedIOProcessor(PoolingIOProcessor):

        ctx.engine_prompts = chunked_engine_prompts
        ctx.prompt_request_ids = prompt_request_ids

        return None

-    def post_process_online(
-        self,
-        ctx: PoolingServeContext,
-    ):
-        if ctx.final_res_batch is None:
-            raise ValueError("Final response batch not available")
-
-        if not self.enable_chunked_processing:
-            return super().post_process_online(ctx)
-
    def _post_process_chunked(self, ctx: PoolingServeContext) -> None:
        # Online aggregation for chunked requests to minimize memory usage.
        # Track aggregation state for each prompt.

@@ -195,4 +239,245 @@ class EmbedIOProcessor(PoolingIOProcessor):
                raise ValueError(f"Result not found for prompt {prompt_idx}")

        ctx.final_res_batch = final_res_batch

        return None

    #################################################################
    # Cohere Request Preprocessing & Postprocessing
    #################################################################

    @staticmethod
    def _load_task_instructions(hf_config: Any) -> dict[str, str] | None:
        """Extract ``task_instructions`` from the HF model config."""
        ti = getattr(hf_config, "task_instructions", None)
        if not isinstance(ti, dict) or not ti:
            return None
        return {k: v for k, v in ti.items() if isinstance(v, str)}

    @staticmethod
    def _load_st_prompts(
        model: str | Any,
        revision: str | None,
    ) -> dict[str, str] | None:
        """Load ``task_instructions`` from ``config_sentence_transformers.json``."""
        from vllm.transformers_utils.repo_utils import get_hf_file_to_dict

        try:
            cfg = get_hf_file_to_dict(
                "config_sentence_transformers.json", str(model), revision
            )
        except (ValueError, OSError):
            return None

        if cfg is None:
            return None
        prompts = cfg.get("prompts")
        if not isinstance(prompts, dict) or not prompts:
            return None
        return {k: v for k, v in prompts.items() if isinstance(v, str)}
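
As an illustration (all values invented for the example), a checkpoint's config_sentence_transformers.json could carry a prompts mapping like the following, which _load_st_prompts returns after filtering to string values:

    # Hypothetical config_sentence_transformers.json contents:
    # {
    #   "prompts": {
    #     "search_query": "search_query: ",
    #     "search_document": "search_document: "
    #   }
    # }
    #
    # _load_st_prompts would then yield:
    #   {"search_query": "search_query: ", "search_document": "search_document: "}
    # and a request with input_type="search_query" gets each text prefixed
    # with "search_query: " before tokenization.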

    @staticmethod
    def _mixed_input_to_messages(
        inp: CohereEmbedInput,
        *,
        task_prefix: str | None = None,
    ) -> list[ChatCompletionMessageParam]:
        """Build chat messages from a mixed text+image input.

        When *task_prefix* is given, it is prepended to each text part.
        """
        parts: list[ChatCompletionContentPartParam] = []
        for item in inp.content:
            if item.type == "text" and item.text is not None:
                text = task_prefix + item.text if task_prefix else item.text
                parts.append(ChatCompletionContentPartTextParam(type="text", text=text))
            elif item.type == "image_url" and item.image_url is not None:
                parts.append(
                    ChatCompletionContentPartImageParam(
                        type="image_url",
                        image_url=ImageURL(url=item.image_url["url"]),
                    )
                )
        return [CustomChatCompletionMessageParam(role="user", content=parts)]

    @staticmethod
    def _check_cohere_max_tokens(
        outputs: list[PoolingRequestOutput],
        max_tokens_check: int | None,
    ) -> None:
        """Raise if any output exceeds *max_tokens_check* tokens.

        Used to enforce ``truncate=NONE`` with an explicit ``max_tokens``:
        the pipeline runs without truncation and we reject afterwards.
        """
        if max_tokens_check is None:
            return
        for out in outputs:
            n = len(out.prompt_token_ids)
            if n > max_tokens_check:
                raise ValueError(
                    f"Input of {n} tokens exceeds max_tokens={max_tokens_check} "
                    "with truncate=NONE. Set truncate to END or START to "
                    "allow truncation."
                )

    @staticmethod
    def _resolve_cohere_truncation(
        request: CohereEmbedRequest,
    ) -> tuple[int | None, Literal["left", "right"] | None]:
        """Return ``(truncate_prompt_tokens, truncation_side)``."""
        if request.truncate == "NONE":
            return None, None
        if request.truncate == "START":
            tokens = request.max_tokens if request.max_tokens is not None else -1
            return tokens, "left"
        if request.max_tokens is not None:
            return request.max_tokens, None
        return -1, None
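
Spelled out, the mapping this helper implements (a sketch for reference; when the side is None, the truncation side falls back to the tokenizer default as described in TokenizeParams):

    # Cohere `truncate` -> vLLM (truncate_prompt_tokens, truncation_side):
    #   "NONE"                 -> (None, None)       no truncation; max_tokens is
    #                                                enforced after pooling instead
    #   "START"                -> (max_tokens or -1, "left")   keep the last N tokens
    #   "END"  with max_tokens -> (max_tokens, None)           keep the first N tokens
    #   "END"  without         -> (-1, None)         -1 maps to the model's max length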

    def create_pooling_params(self, request):
        if isinstance(request, CohereEmbedRequest):
            return PoolingParams(
                task="embed",
                dimensions=request.output_dimension,
            )
        return super().create_pooling_params(request)

    def _pre_process_cohere_online(self, ctx: PoolingServeContext) -> None:
        """Convert a ``CohereEmbedRequest`` into engine prompts.

        For texts, a single batched completion request path is used.
        For images and mixed inputs, conversations are batch-rendered
        through the chat template in one ``render_chat`` call.
        """
        request = ctx.request
        assert isinstance(request, CohereEmbedRequest)

        if request.texts is None and request.images is None and request.inputs is None:
            raise ValueError("One of texts, images, or inputs must be provided")

        truncate_prompt_tokens, truncation_side = self._resolve_cohere_truncation(
            request
        )
        input_type = request.input_type
        self._validate_input_type(input_type)

        if request.images is not None:
            all_messages: list[list[ChatCompletionMessageParam]] = [
                [
                    CustomChatCompletionMessageParam(
                        role="user",
                        content=[{"type": "image_url", "image_url": {"url": uri}}],
                    )
                ]
                for uri in request.images
            ]
            ctx.engine_prompts = self._batch_render_chat(
                request, all_messages, truncate_prompt_tokens, truncation_side
            )

        elif request.inputs is not None:
            task_prefix = self._get_task_instruction_prefix(input_type)
            all_messages = [
                self._mixed_input_to_messages(inp, task_prefix=task_prefix)
                for inp in request.inputs
            ]
            ctx.engine_prompts = self._batch_render_chat(
                request, all_messages, truncate_prompt_tokens, truncation_side
            )

        else:
            prefixed = self._apply_task_instruction(request.texts or [], input_type)
            proxy = EmbeddingCompletionRequest(
                model=request.model,
                input=prefixed,
                dimensions=request.output_dimension,
                encoding_format="float",
                truncate_prompt_tokens=truncate_prompt_tokens,
                truncation_side=truncation_side,
            )
            ctx.engine_prompts = self._preprocess_completion_online(
                proxy, prompt_input=proxy.input, prompt_embeds=None
            )

    def _batch_render_chat(
        self,
        request: CohereEmbedRequest,
        all_messages: Sequence[list[ChatCompletionMessageParam]],
        truncate_prompt_tokens: int | None,
        truncation_side: Literal["left", "right"] | None,
    ) -> list[ProcessorInputs]:
        """Batch-render multiple conversations through the chat template."""
        if not all_messages:
            return []

        proxy = EmbeddingChatRequest(
            model=request.model,
            messages=list(all_messages[0]),
            dimensions=request.output_dimension,
            encoding_format="float",
            truncate_prompt_tokens=truncate_prompt_tokens,
            truncation_side=truncation_side,
        )

        renderer = self.renderer
        mm_config = self.model_config.multimodal_config

        tok_params = proxy.build_tok_params(self.model_config)
        chat_params = proxy.build_chat_params(
            self.chat_template,
            self.chat_template_content_format,
        ).with_defaults(
            merge_kwargs(
                None,
                dict(
                    tools=None,
                    tokenize=is_mistral_tokenizer(renderer.tokenizer),
                ),
            ),
            default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
        )

        _, engine_prompts = renderer.render_chat(all_messages, chat_params, tok_params)
        return engine_prompts
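
For reference, the shape of all_messages this helper consumes, sketched for an images request with two entries (URIs are placeholders):

    # all_messages for CohereEmbedRequest(images=[uri1, uri2]): one
    # single-message conversation per image.
    # [
    #   [{"role": "user",
    #     "content": [{"type": "image_url", "image_url": {"url": "<uri1>"}}]}],
    #   [{"role": "user",
    #     "content": [{"type": "image_url", "image_url": {"url": "<uri2>"}}]}],
    # ]
    # render_chat applies the chat template to every conversation in one call
    # and returns one engine prompt per conversation.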

    def _validate_input_type(self, input_type: str | None) -> None:
        """Raise if *input_type* is not supported by this model."""
        if input_type is None:
            return
        if self.task_instructions is None:
            raise ValueError(
                f"Unsupported input_type {input_type!r}. "
                "This model does not define any input_type task instructions."
            )
        if input_type not in self.task_instructions:
            supported = ", ".join(sorted(self.task_instructions))
            raise ValueError(
                f"Unsupported input_type {input_type!r}. Supported values: {supported}"
            )

    def _apply_task_instruction(
        self,
        texts: list[str],
        input_type: str | None,
    ) -> list[str]:
        """Prepend the task-instruction prefix for *input_type*.

        Returns *texts* unchanged when no matching prefix is configured.
        """
        prefix = self._get_task_instruction_prefix(input_type)
        if not prefix:
            return texts
        return [prefix + t for t in texts]

    def _get_task_instruction_prefix(self, input_type: str | None) -> str | None:
        """Return the task-instruction prefix for *input_type*, or ``None``."""
        if not self.task_instructions or input_type is None:
            return None
        return self.task_instructions.get(input_type) or None

    def _enforce_cohere_max_tokens(self, ctx: PoolingServeContext) -> None:
        if isinstance(ctx.request, CohereEmbedRequest):
            request = ctx.request
            if request.truncate == "NONE" and request.max_tokens is not None:
                self._check_cohere_max_tokens(ctx.final_res_batch, request.max_tokens)

@@ -1,9 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import time
-from typing import TypeAlias
-
-from pydantic import Field
+"""Embedding API protocol models for OpenAI and Cohere formats.
+
+OpenAI: https://platform.openai.com/docs/api-reference/embeddings
+Cohere: https://docs.cohere.com/reference/embed
+"""
+
+import base64
+import builtins
+import struct
+import time
+from collections.abc import Sequence
+from typing import Literal, TypeAlias
+
+from pydantic import BaseModel, Field

 from vllm import PoolingParams
 from vllm.config import ModelConfig

@@ -17,6 +27,10 @@ from vllm.entrypoints.pooling.base.protocol import (
 from vllm.renderers import TokenizeParams
 from vllm.utils import random_uuid

+# ---------------------------------------------------------------------------
+# OpenAI /v1/embeddings — request models
+# ---------------------------------------------------------------------------
+

 def _get_max_total_output_tokens(
     model_config: ModelConfig,

@@ -50,6 +64,7 @@ class EmbeddingCompletionRequest(
            max_total_tokens=max_total_tokens,
            max_output_tokens=max_output_tokens,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            max_total_tokens_param="max_model_len",

@@ -79,6 +94,7 @@ class EmbeddingChatRequest(
            max_total_tokens=max_total_tokens,
            max_output_tokens=max_output_tokens,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            max_total_tokens_param="max_model_len",

@@ -96,6 +112,11 @@ class EmbeddingChatRequest(
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest


+# ---------------------------------------------------------------------------
+# OpenAI /v1/embeddings — response models
+# ---------------------------------------------------------------------------
+

class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"

@@ -106,7 +127,7 @@ class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
-    model: str
+    model: str | None = None
    data: list[EmbeddingResponseData]
    usage: UsageInfo

@@ -115,3 +136,146 @@ class EmbeddingBytesResponse(OpenAIBaseModel):
    content: list[bytes]
    headers: dict[str, str] | None = None
    media_type: str = "application/octet-stream"


# ---------------------------------------------------------------------------
# Cohere /v2/embed — request models
# ---------------------------------------------------------------------------

CohereEmbeddingType = Literal[
    "float",
    "binary",
    "ubinary",
    "base64",
]
CohereTruncate = Literal["NONE", "START", "END"]


class CohereEmbedContent(BaseModel):
    type: Literal["text", "image_url"]
    text: str | None = None
    image_url: dict[str, str] | None = None


class CohereEmbedInput(BaseModel):
    content: list[CohereEmbedContent]


class CohereEmbedRequest(BaseModel):
    model: str | None = None
    input_type: str | None = None
    texts: list[str] | None = None
    images: list[str] | None = None
    inputs: list[CohereEmbedInput] | None = None
    output_dimension: int | None = None
    embedding_types: list[CohereEmbeddingType] | None = None
    truncate: CohereTruncate = "END"
    max_tokens: int | None = None
    priority: int = 0


# ---------------------------------------------------------------------------
# Cohere /v2/embed — response models
# ---------------------------------------------------------------------------


class CohereApiVersion(BaseModel):
    version: str = "2"


class CohereBilledUnits(BaseModel):
    input_tokens: int | None = None
    image_tokens: int | None = None


class CohereMeta(BaseModel):
    api_version: CohereApiVersion = Field(default_factory=CohereApiVersion)
    billed_units: CohereBilledUnits | None = None


class CohereEmbedByTypeEmbeddings(BaseModel):
    # The field name ``float`` shadows the builtin type, so the annotation
    # must use ``builtins.float`` to avoid a self-referential type error.
    float: list[list[builtins.float]] | None = None
    binary: list[list[int]] | None = None
    ubinary: list[list[int]] | None = None
    base64: list[str] | None = None


class CohereEmbedResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    embeddings: CohereEmbedByTypeEmbeddings
    texts: list[str] | None = None
    meta: CohereMeta | None = None
    response_type: Literal["embeddings_by_type"] = "embeddings_by_type"


# ---------------------------------------------------------------------------
# Cohere embedding type conversion helpers
# ---------------------------------------------------------------------------

_UNSIGNED_TO_SIGNED_DIFF = 1 << 7  # 128


def _pack_binary_embeddings(
    float_embeddings: list[list[float]],
    signed: bool,
) -> list[list[int]]:
    """Bit-pack float embeddings: non-negative -> 1, negative -> 0.

    Each bit is shifted left by ``7 - idx % 8``, and every 8 bits are packed
    into one byte.
    """
    result: list[list[int]] = []
    for embedding in float_embeddings:
        dim = len(embedding)
        if dim % 8 != 0:
            raise ValueError(
                "Embedding dimension must be a multiple of 8 for binary "
                f"embedding types, but got {dim}."
            )
        packed_len = dim // 8
        packed: list[int] = []
        byte_val = 0
        for idx, value in enumerate(embedding):
            bit = 1 if value >= 0 else 0
            byte_val += bit << (7 - idx % 8)
            if (idx + 1) % 8 == 0:
                if signed:
                    byte_val -= _UNSIGNED_TO_SIGNED_DIFF
                packed.append(byte_val)
                byte_val = 0
        assert len(packed) == packed_len
        result.append(packed)
    return result
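
A worked example of the packing, with illustrative values:

    # One 8-dim embedding packs to exactly one byte.
    emb = [[0.5, -0.3, 0.2, -0.1, 0.7, -0.9, 0.4, -0.2]]
    # signs:  +     -    +     -    +     -    +     -   -> bits 10101010
    assert _pack_binary_embeddings(emb, signed=False) == [[0b10101010]]  # [[170]]
    # Signed packing shifts each byte into int8 range: 170 - 128 = 42.
    assert _pack_binary_embeddings(emb, signed=True) == [[42]]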


def _encode_base64_embeddings(
    float_embeddings: list[list[float]],
) -> list[str]:
    """Encode float embeddings as base64 (little-endian float32)."""
    result: list[str] = []
    for embedding in float_embeddings:
        buf = struct.pack(f"<{len(embedding)}f", *embedding)
        result.append(base64.b64encode(buf).decode("utf-8"))
    return result
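
Decoding mirrors the encoding; a quick round-trip sketch:

    # Decode with the same little-endian float32 layout (4 bytes per value).
    encoded = _encode_base64_embeddings([[1.0, -2.0, 0.5]])[0]
    raw = base64.b64decode(encoded)
    decoded = list(struct.unpack(f"<{len(raw) // 4}f", raw))
    assert decoded == [1.0, -2.0, 0.5]  # exact: these values are float32-exact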


def build_typed_embeddings(
    float_embeddings: list[list[float]],
    embedding_types: Sequence[str],
) -> CohereEmbedByTypeEmbeddings:
    """Convert float embeddings to all requested Cohere embedding types."""
    result = CohereEmbedByTypeEmbeddings()

    for emb_type in embedding_types:
        if emb_type == "float":
            result.float = float_embeddings
        elif emb_type == "binary":
            result.binary = _pack_binary_embeddings(float_embeddings, signed=True)
        elif emb_type == "ubinary":
            result.ubinary = _pack_binary_embeddings(float_embeddings, signed=False)
        elif emb_type == "base64":
            result.base64 = _encode_base64_embeddings(float_embeddings)

    return result
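
A usage sketch (the embedding dimension must be a multiple of 8 for the binary types):

    # Request float and ubinary forms of two 8-dim embeddings.
    typed = build_typed_embeddings(
        [[0.1] * 8, [-0.1] * 8],
        embedding_types=["float", "ubinary"],
    )
    assert typed.float == [[0.1] * 8, [-0.1] * 8]
    assert typed.ubinary == [[0b11111111], [0b00000000]]  # all-positive / all-negative
    assert typed.binary is None and typed.base64 is None   # not requested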

@@ -5,7 +5,7 @@ from collections.abc import Callable
 from functools import partial
 from typing import Literal, TypeAlias, cast

-from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.responses import JSONResponse, Response, StreamingResponse
 from typing_extensions import assert_never

 from vllm.config import ModelConfig

@@ -14,10 +14,15 @@ from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
from vllm.entrypoints.pooling.embed.protocol import (
+    CohereBilledUnits,
+    CohereEmbedRequest,
+    CohereEmbedResponse,
+    CohereMeta,
    EmbeddingBytesResponse,
    EmbeddingRequest,
    EmbeddingResponse,
    EmbeddingResponseData,
+    build_typed_embeddings,
)
from vllm.entrypoints.pooling.typing import PoolingServeContext
from vllm.entrypoints.pooling.utils import (

@@ -26,24 +31,23 @@ from vllm.entrypoints.pooling.utils import (
    encode_pooling_output_float,
    get_json_response_cls,
)
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.renderers import BaseRenderer
from vllm.utils.serial_utils import EmbedDType, Endianness

logger = init_logger(__name__)

JSONResponseCLS = get_json_response_cls()

EmbeddingServeContext: TypeAlias = PoolingServeContext[EmbeddingRequest]


class ServingEmbedding(PoolingServing):
-    """
-    Embedding API similar to OpenAI's API.
-
-    See https://platform.openai.com/docs/api-reference/embeddings/create
-    for the API specification. This API mimics the OpenAI Embedding API.
-    """
+    """Embedding API supporting both OpenAI and Cohere formats."""

    request_id_prefix = "embd"
    io_processor: EmbedIOProcessor

    def init_io_processor(
        self,

@@ -58,6 +62,14 @@ class ServingEmbedding(PoolingServing):
        )

+    async def _build_response(
+        self,
+        ctx: PoolingServeContext,
+    ) -> Response:
+        if isinstance(ctx.request, CohereEmbedRequest):
+            return self._build_cohere_response_from_ctx(ctx)
+        return await self._build_openai_response(ctx)
+
    async def _build_openai_response(
        self,
        ctx: EmbeddingServeContext,
    ) -> JSONResponse | StreamingResponse:

@@ -66,7 +78,7 @@ class ServingEmbedding(PoolingServing):
        endianness = ctx.request.endianness

        if encoding_format == "float" or encoding_format == "base64":
-            return self._request_output_to_embed_json_response(
+            return self._openai_json_response(
                ctx.final_res_batch,
                ctx.request_id,
                ctx.created_time,

@@ -77,7 +89,7 @@ class ServingEmbedding(PoolingServing):
        )

        if encoding_format == "bytes" or encoding_format == "bytes_only":
-            return self._request_output_to_to_embed_bytes_response(
+            return self._openai_bytes_response(
                ctx.final_res_batch,
                ctx.request_id,
                ctx.created_time,

@@ -89,7 +101,7 @@ class ServingEmbedding(PoolingServing):

        assert_never(encoding_format)

-    def _request_output_to_embed_json_response(
+    def _openai_json_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,

@@ -139,7 +151,7 @@ class ServingEmbedding(PoolingServing):
        )
        return JSONResponseCLS(content=response.model_dump())

-    def _request_output_to_to_embed_bytes_response(
+    def _openai_bytes_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,

@@ -177,3 +189,33 @@ class ServingEmbedding(PoolingServing):
            headers=response.headers,
            media_type=response.media_type,
        )

    @staticmethod
    def _build_cohere_response_from_ctx(
        ctx: PoolingServeContext,
    ) -> JSONResponse:
        request = ctx.request
        assert isinstance(request, CohereEmbedRequest)

        all_floats = [encode_pooling_output_float(out) for out in ctx.final_res_batch]
        total_tokens = sum(len(out.prompt_token_ids) for out in ctx.final_res_batch)

        image_tokens = total_tokens if request.images is not None else 0
        texts_echo = request.texts

        embedding_types = request.embedding_types or ["float"]
        embeddings_obj = build_typed_embeddings(all_floats, embedding_types)

        input_tokens = total_tokens - image_tokens
        response = CohereEmbedResponse(
            id=ctx.request_id,
            embeddings=embeddings_obj,
            texts=texts_echo,
            meta=CohereMeta(
                billed_units=CohereBilledUnits(
                    input_tokens=input_tokens,
                    image_tokens=image_tokens,
                ),
            ),
        )
        return JSONResponse(content=response.model_dump(exclude_none=True))
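
The serialized payload then has the Cohere embeddings-by-type shape, roughly as follows (values illustrative; exclude_none drops the embedding types that were not requested):

    # Illustrative response body for a single-text float request:
    # {
    #   "id": "embd-...",
    #   "embeddings": {"float": [[0.01, -0.02, ...]]},
    #   "texts": ["hello world"],
    #   "meta": {
    #     "api_version": {"version": "2"},
    #     "billed_units": {"input_tokens": 3, "image_tokens": 0}
    #   },
    #   "response_type": "embeddings_by_type"
    # }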

@@ -36,6 +36,7 @@ class PoolingCompletionRequest(
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            max_total_tokens_param="max_model_len",

@@ -61,6 +62,7 @@ class PoolingChatRequest(
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=self.add_special_tokens,
            max_total_tokens_param="max_model_len",

@@ -88,6 +90,7 @@ class IOProcessorRequest(PoolingBasicRequestMixin, EncodingRequestMixin, Generic
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
            do_lower_case=encoder_config.get("do_lower_case", False),
            add_special_tokens=not model_config.is_encoder_decoder,
            max_total_tokens_param="max_model_len",

@@ -30,6 +30,7 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
            do_lower_case=encoder_config.get("do_lower_case", False),
            max_total_tokens_param="max_model_len",
        )

@@ -105,6 +106,7 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
            max_total_tokens=model_config.max_model_len,
            max_output_tokens=0,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            truncation_side=self.truncation_side,
            do_lower_case=encoder_config.get("do_lower_case", False),
            max_total_tokens_param="max_model_len",
        )

@@ -15,6 +15,7 @@ from vllm.entrypoints.pooling.classify.protocol import (
    ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
+    CohereEmbedRequest,
    EmbeddingBytesResponse,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,

@@ -50,6 +51,7 @@ AnyPoolingRequest: TypeAlias = (
    | IOProcessorRequest
    | RerankRequest
    | ScoreRequest
+    | CohereEmbedRequest
)

AnyPoolingResponse: TypeAlias = (

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, TypeVar
+from typing import TYPE_CHECKING, Any, Literal, TypeVar

 from vllm.exceptions import VLLMValidationError
 from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt

@@ -153,6 +153,14 @@ class TokenizeParams:
    - `-1` maps to `max_input_tokens`.
    """

+    truncation_side: Literal["left", "right"] | None = None
+    """
+    Which side to truncate from when ``truncate_prompt_tokens`` is active:
+    - ``"right"`` keeps the first N tokens (truncate from the end).
+    - ``"left"`` keeps the last N tokens (truncate from the start).
+    - ``None`` falls back to the tokenizer default.
+    """

    do_lower_case: bool = False
    """Whether to normalize text to lower case before tokenization."""

@@ -271,6 +279,7 @@ class TokenizeParams:
        ),
        pad_prompt_tokens=pad_prompt_tokens,
        truncate_prompt_tokens=truncate_prompt_tokens,
+        truncation_side=self.truncation_side,
        do_lower_case=do_lower_case,
        add_special_tokens=add_special_tokens,
        needs_detokenization=needs_detokenization,

@@ -286,6 +295,16 @@ class TokenizeParams:
        # while still failing `self._token_len_check` as expected by users
        max_length = self.max_input_tokens + 1

+        # Left-side truncation requires the full token sequence so we can
+        # slice from the end in _token_truncation. Disable HF-level
+        # truncation (which would incorrectly truncate from the right for
+        # pooling models) and let _token_truncation handle it.
+        if self.truncation_side == "left":
+            return dict(
+                truncation=False,
+                add_special_tokens=self.add_special_tokens,
+            )
+
        return dict(
            truncation=max_length is not None,
            max_length=max_length,

@@ -375,7 +394,10 @@ class TokenizeParams:
        if max_length == 0:
            return tokens[:0]

-        if getattr(tokenizer, "truncation_side", "left") == "left":
+        side = self.truncation_side or (
+            tokenizer.truncation_side if tokenizer is not None else None
+        )
+        if side == "left":
            return tokens[-max_length:]

        return tokens[:max_length]