[Feature]: Support for multiple embedding types in a single inference call (#35829)

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
This commit is contained in:
Augusto Yao
2026-03-17 17:05:42 +08:00
committed by GitHub
parent 132bfd45b6
commit 9c7cab5ebb
7 changed files with 226 additions and 36 deletions

View File

@@ -3,10 +3,10 @@
from collections.abc import Sequence
from vllm.config import VllmConfig
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.protocol import EmbedRequestMixin
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import (
IOProcessor,
@@ -16,14 +16,13 @@ from vllm.renderers import BaseRenderer
from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens
from .types import (
EMBED_TASKS,
SparseEmbeddingCompletionRequestMixin,
SparseEmbeddingResponse,
SparseEmbeddingResponseData,
SparseEmbeddingTokenWeight,
)
logger = init_logger(__name__)
class BgeM3SparseEmbeddingsProcessor(
IOProcessor[SparseEmbeddingCompletionRequestMixin, SparseEmbeddingResponse]
@@ -33,6 +32,22 @@ class BgeM3SparseEmbeddingsProcessor(
self.offline_requests: list[SparseEmbeddingCompletionRequestMixin] = []
self.online_requests: dict[str, SparseEmbeddingCompletionRequestMixin] = {}
self.renderer: BaseRenderer = renderer
self.default_pooling_params = {}
pooler_config: PoolerConfig = vllm_config.model_config.pooler_config
if pooler_config is not None:
for param in ["use_activation", "dimensions"]:
if getattr(pooler_config, param, None) is None:
continue
self.default_pooling_params[param] = getattr(pooler_config, param)
self.embed_dimensions = vllm_config.model_config.embedding_size
self.embed_request_queue: list[EmbedRequestMixin] = []
def __repr__(self) -> str:
    """Debug representation showing the processor's key configuration."""
    fields = (
        f"embed_dimensions={self.embed_dimensions}",
        f"default_pooling_params={self.default_pooling_params}",
    )
    return "BgeM3SparseEmbeddingsProcessor(" + ", ".join(fields) + ")"
def merge_pooling_params(
self,
@@ -41,7 +56,57 @@ class BgeM3SparseEmbeddingsProcessor(
if params is None:
params = PoolingParams()
# refer to PoolingCompletionRequest.to_pooling_params
params.task = "token_classify"
# set and verify pooling params
params.skip_reading_prefix_cache = True
raw_embed_request = self.embed_request_queue.pop(0)
if raw_embed_request.embed_task not in EMBED_TASKS:
raise ValueError(
f"Unsupported task {raw_embed_request}, "
f"Supported tasks are {EMBED_TASKS}"
)
has_dense_embed = True
if raw_embed_request.embed_task == "dense":
params.task = "embed"
params.skip_reading_prefix_cache = False
elif raw_embed_request.embed_task == "sparse":
params.task = "token_classify"
has_dense_embed = False
else:
params.task = "embed&token_classify"
params.use_activation = raw_embed_request.use_activation
if params.use_activation is None:
params.use_activation = True
if not has_dense_embed:
params.dimensions = None
return params
params.dimensions = raw_embed_request.dimensions
model_config: ModelConfig = self.vllm_config.model_config
for param in self.default_pooling_params:
if getattr(params, param, None) is None:
setattr(params, param, self.default_pooling_params[param])
if params.dimensions is not None:
if not model_config.is_matryoshka:
raise ValueError(
f'Model "{model_config.served_model_name}" does not '
f"support matryoshka representation, "
f"changing output dimensions will lead to poor results."
)
mds = model_config.matryoshka_dimensions
if mds is not None:
if params.dimensions not in mds:
raise ValueError(
f"Model {model_config.served_model_name!r} "
f"only supports {str(mds)} matryoshka dimensions, "
f"use other output dimensions will "
f"lead to poor results."
)
elif params.dimensions < 1:
raise ValueError("Dimensions must be greater than 0")
return params
def parse_request(
@@ -61,14 +126,16 @@ class BgeM3SparseEmbeddingsProcessor(
if request_id is not None:
assert request_id not in self.online_requests, "request_id duplicated"
self.online_requests[request_id] = prompt
self.embed_request_queue.extend(prompt.to_embed_requests_online())
else:
self.offline_requests.append(prompt)
self.embed_request_queue.extend(prompt.to_embed_requests_offline())
return prompt.input
def _get_sparse_embedding_request(self, request_id: str | None = None):
if request_id:
return self.online_requests.pop(request_id, None)
return self.offline_requests.pop()
return self.offline_requests.pop(0)
def _build_sparse_embedding_token_weights(
self,
@@ -100,26 +167,45 @@ class BgeM3SparseEmbeddingsProcessor(
) -> SparseEmbeddingResponse:
num_prompt_tokens = 0
response_data = []
return_tokens = self._get_sparse_embedding_request(request_id).return_tokens
raw_request = self._get_sparse_embedding_request(request_id)
has_dense_embed = raw_request.embed_task in ["dense", "dense&sparse"]
has_sparse_embed = raw_request.embed_task in ["sparse", "dense&sparse"]
embed_dimensions = 0
if has_dense_embed:
embed_dimensions = (
self.embed_dimensions
if raw_request.dimensions is None
else raw_request.dimensions
)
for idx in range(len(model_output)):
mo = model_output[idx]
sparse_embedding: dict[int, float] = {}
sparse_embedding_dict: dict[int, float] = {}
num_prompt_tokens += len(mo.prompt_token_ids)
if len(mo.prompt_token_ids) != len(mo.outputs.data):
# this is the case that add_special_tokens is True,
# which means first token and last token are special tokens
mo.prompt_token_ids = mo.prompt_token_ids[1:]
for token_id, weight in zip(mo.prompt_token_ids, mo.outputs.data.tolist()):
sparse_embedding[token_id] = max(
weight, sparse_embedding.get(token_id, 0.0)
dense_embedding: list[float] | None = None
sparse_embedding: list[SparseEmbeddingTokenWeight] | None = None
if has_dense_embed:
dense_embedding = mo.outputs.data[:embed_dimensions].tolist()
if has_sparse_embed:
sparse_weights = mo.outputs.data[embed_dimensions:].tolist()
if len(mo.prompt_token_ids) != len(sparse_weights):
# this is the case that add_special_tokens is True,
# which means first token and last token are special tokens
mo.prompt_token_ids = mo.prompt_token_ids[1:]
for token_id, weight in zip(mo.prompt_token_ids, sparse_weights):
sparse_embedding_dict[token_id] = max(
weight, sparse_embedding_dict.get(token_id, 0.0)
)
sparse_embedding = self._build_sparse_embedding_token_weights(
sparse_embedding_dict,
raw_request.return_tokens,
)
response_data.append(
SparseEmbeddingResponseData(
index=idx,
sparse_embedding=self._build_sparse_embedding_token_weights(
sparse_embedding,
return_tokens,
),
object=raw_request.embed_task,
sparse_embedding=sparse_embedding,
dense_embedding=dense_embedding,
)
)

View File

@@ -1,18 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Literal, get_args
from pydantic import BaseModel, Field
from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.protocol import CompletionRequestMixin
from vllm.entrypoints.pooling.base.protocol import (
CompletionRequestMixin,
EmbedRequestMixin,
)
EmbedTask = Literal[
"sparse",
"dense",
"dense&sparse",
]
EMBED_TASKS: tuple[EmbedTask, ...] = get_args(EmbedTask)
class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin):
class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin, EmbedRequestMixin):
return_tokens: bool | None = Field(
default=None,
description="Whether to return dict shows the mapping of token_id to text."
"`None` or False means not return.",
)
embed_task: EmbedTask = Field(
default="dense&sparse",
description="embed task, can be one of 'sparse', 'dense' , 'dense&sparse', "
"default to 'dense&sparse'",
)
def to_embed_requests_offline(self) -> list[EmbedRequestMixin]:
    """Fan this offline request out to one embed request per prompt.

    An offline request may carry a list of inputs; each input gets its
    own entry in the returned list. All entries are this same request
    object, which holds the shared embed parameters.
    """
    count = len(self.input) if isinstance(self.input, list) else 1
    return [self for _ in range(count)]
def to_embed_requests_online(self) -> list[EmbedRequestMixin]:
    """Online calls map one-to-one: a single embed request per call."""
    return [self]
class SparseEmbeddingTokenWeight(BaseModel):
@@ -23,8 +49,9 @@ class SparseEmbeddingTokenWeight(BaseModel):
class SparseEmbeddingResponseData(BaseModel):
index: int
object: str = "sparse-embedding"
sparse_embedding: list[SparseEmbeddingTokenWeight]
object: str = "dense&sparse"
sparse_embedding: list[SparseEmbeddingTokenWeight] | None
dense_embedding: list[float] | None
class SparseEmbeddingResponse(BaseModel):

View File

@@ -19,6 +19,12 @@ model_config = {
),
}
dense_embedding_sum = [
-0.7214539647102356, # "What is the capital of France?"
-0.6926871538162231, # "What is the capital of Germany?"
-0.7129564881324768, # "What is the capital of Spain?"
]
def _float_close(expected: object, result: object):
assert isinstance(expected, float) and isinstance(result, float), (
@@ -33,6 +39,12 @@ def _get_attr_or_val(obj: object | dict, key: str):
return getattr(obj, key, None)
def _check_dense_embedding(data, index=0):
    """Assert the dense embedding sums to the expected reference value.

    `index` selects which reference prompt's precomputed sum in
    `dense_embedding_sum` to compare against (defaults to the first).
    """
    total = sum(data)
    expected = dense_embedding_sum[index]
    assert _float_close(total, expected), "dense-embedding result not match"
def _check_sparse_embedding(data, check_tokens=False):
expected_weights = [
{"token_id": 32, "weight": 0.0552978515625, "token": "?"},
@@ -109,7 +121,7 @@ async def test_bge_m3_sparse_plugin_online(
assert len(_get_attr_or_val(parsed_response, "data")) > 0
data_entry = _get_attr_or_val(parsed_response, "data")[0]
assert _get_attr_or_val(data_entry, "object") == "sparse-embedding"
assert _get_attr_or_val(data_entry, "object") == "dense&sparse"
assert _get_attr_or_val(data_entry, "sparse_embedding")
# Verify sparse embedding format
@@ -117,6 +129,11 @@ async def test_bge_m3_sparse_plugin_online(
assert isinstance(sparse_embedding, list)
_check_sparse_embedding(sparse_embedding, return_tokens)
# Verify dense embedding format
dense_embedding = _get_attr_or_val(data_entry, "dense_embedding")
assert isinstance(dense_embedding, list)
_check_dense_embedding(dense_embedding)
# Verify usage information
usage = _get_attr_or_val(parsed_response, "usage")
assert usage, f"usage not found for {parsed_response}"
@@ -164,6 +181,9 @@ def test_bge_m3_sparse_plugin_offline(vllm_runner, return_tokens: bool):
sparse_embedding = output.sparse_embedding
assert isinstance(sparse_embedding, list)
_check_sparse_embedding(sparse_embedding, return_tokens)
dense_embedding = output.dense_embedding
assert isinstance(dense_embedding, list)
_check_dense_embedding(dense_embedding)
# Verify usage
assert response.usage.prompt_tokens > 0
@@ -206,6 +226,9 @@ def test_bge_m3_sparse_plugin_offline_multiple_inputs(vllm_runner):
# Each output should have sparse embeddings
sparse_embedding = output.sparse_embedding
assert isinstance(sparse_embedding, list)
dense_embedding = output.dense_embedding
assert isinstance(dense_embedding, list)
_check_dense_embedding(dense_embedding, i)
# Verify usage
assert response.usage.prompt_tokens > 0

View File

@@ -170,4 +170,42 @@ class BOSEOSFilter(Pooler):
return pooled_outputs
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"]
class BgeM3Pooler(Pooler):
    """Combined pooler serving the "embed&token_classify" task.

    Runs both the dense-embedding sub-pooler and the sparse
    (token-classification) sub-pooler over the same hidden states, and
    emits one flattened tensor per request: the dense embedding followed
    by the per-token sparse weights. Consumers are expected to split the
    two halves back apart by the known embedding size.
    """

    def __init__(self, token_classify_pooler: Pooler, embed_pooler: Pooler) -> None:
        super().__init__()
        # Sub-pooler producing per-token sparse weights.
        self.token_classify_pooler = token_classify_pooler
        # Sub-pooler producing the dense sentence embedding.
        self.embed_pooler = embed_pooler

    def forward(
        self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata
    ) -> PoolerOutput:
        """Run both sub-poolers and concatenate their flattened outputs.

        For each request the result tensor is laid out as
        [dense_embedding..., sparse_token_weights...] along dim -1.

        NOTE(review): annotated as PoolerOutput but returns a plain list
        of tensors — presumably PoolerOutput is list-compatible; verify.
        """
        embed_outputs = self.embed_pooler(hidden_states, pooling_metadata)
        token_classify_outputs = self.token_classify_pooler(
            hidden_states, pooling_metadata
        )
        pooler_outputs: list[torch.Tensor] = []
        # Pair up the per-request outputs of the two sub-poolers and fuse
        # each pair into a single 1-D tensor.
        for embed_output, token_classify_output in zip(
            embed_outputs, token_classify_outputs
        ):
            pooler_outputs.append(
                torch.cat(
                    [embed_output.view(-1), token_classify_output.view(-1)], dim=-1
                )
            )
        return pooler_outputs

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # This pooler only serves the combined task; the individual
        # "embed" / "token_classify" tasks are dispatched elsewhere.
        return {"embed&token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Merge the parameter updates required by both sub-poolers,
        # each queried for its own native task.
        return self.embed_pooler.get_pooling_updates(
            "embed"
        ) | self.token_classify_pooler.get_pooling_updates("token_classify")

    def extra_repr(self) -> str:
        # Surface the supported task set in the module repr.
        s = f"supported_task={self.get_supported_tasks()}"
        return s
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler", "BgeM3Pooler"]

View File

@@ -10,6 +10,7 @@ from transformers import RobertaConfig
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
from vllm.model_executor.layers.pooler import (
BgeM3Pooler,
BOSEOSFilter,
DispatchPooler,
Pooler,
@@ -216,24 +217,29 @@ class BgeM3EmbeddingModel(RobertaEmbeddingModel):
self.colbert_linear = nn.Linear(
self.hidden_size, self.hidden_size, dtype=self.head_dtype
)
embed_pooler = pooler_for_embed(pooler_config)
token_classify_pooler = BOSEOSFilter(
pooler_for_token_classify(
pooler_config,
pooling=AllPool(),
classifier=self.sparse_linear,
act_fn=torch.relu,
),
self.bos_token_id,
self.eos_token_id,
)
return DispatchPooler(
{
"embed": pooler_for_embed(pooler_config),
"embed": embed_pooler,
"token_embed": BOSEOSFilter(
pooler_for_token_embed(pooler_config, self.colbert_linear),
self.bos_token_id,
# for some reason m3 only filters the bos for colbert vectors
),
"token_classify": BOSEOSFilter(
pooler_for_token_classify(
pooler_config,
pooling=AllPool(),
classifier=self.sparse_linear,
act_fn=torch.relu,
),
self.bos_token_id,
self.eos_token_id,
"token_classify": token_classify_pooler,
"embed&token_classify": BgeM3Pooler(
token_classify_pooler, embed_pooler
),
}
)

View File

@@ -96,6 +96,10 @@ class PoolingParams(
self.skip_reading_prefix_cache = True
return
# skipping verify, let plugins configure and validate pooling params
if self.task not in self.valid_parameters:
return
# NOTE: Task validation needs to done against the model instance,
# which is not available in model config. So, it's not included
# in this method

View File

@@ -6,7 +6,13 @@ GenerationTask = Literal["generate", "transcription", "realtime"]
GENERATION_TASKS: tuple[GenerationTask, ...] = get_args(GenerationTask)
PoolingTask = Literal[
"embed", "classify", "score", "token_embed", "token_classify", "plugin"
"embed",
"classify",
"score",
"token_embed",
"token_classify",
"plugin",
"embed&token_classify",
]
POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)