[Feature]: Support for multiple embedding types in a single inference call (#35829)

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
This commit is contained in:
Augusto Yao
2026-03-17 17:05:42 +08:00
committed by GitHub
parent 132bfd45b6
commit 9c7cab5ebb
7 changed files with 226 additions and 36 deletions

View File

@@ -3,10 +3,10 @@
from collections.abc import Sequence from collections.abc import Sequence
from vllm.config import VllmConfig from vllm.config import ModelConfig, PoolerConfig, VllmConfig
from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.protocol import EmbedRequestMixin
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import ( from vllm.plugins.io_processors.interface import (
IOProcessor, IOProcessor,
@@ -16,14 +16,13 @@ from vllm.renderers import BaseRenderer
from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens
from .types import ( from .types import (
EMBED_TASKS,
SparseEmbeddingCompletionRequestMixin, SparseEmbeddingCompletionRequestMixin,
SparseEmbeddingResponse, SparseEmbeddingResponse,
SparseEmbeddingResponseData, SparseEmbeddingResponseData,
SparseEmbeddingTokenWeight, SparseEmbeddingTokenWeight,
) )
logger = init_logger(__name__)
class BgeM3SparseEmbeddingsProcessor( class BgeM3SparseEmbeddingsProcessor(
IOProcessor[SparseEmbeddingCompletionRequestMixin, SparseEmbeddingResponse] IOProcessor[SparseEmbeddingCompletionRequestMixin, SparseEmbeddingResponse]
@@ -33,6 +32,22 @@ class BgeM3SparseEmbeddingsProcessor(
self.offline_requests: list[SparseEmbeddingCompletionRequestMixin] = [] self.offline_requests: list[SparseEmbeddingCompletionRequestMixin] = []
self.online_requests: dict[str, SparseEmbeddingCompletionRequestMixin] = {} self.online_requests: dict[str, SparseEmbeddingCompletionRequestMixin] = {}
self.renderer: BaseRenderer = renderer self.renderer: BaseRenderer = renderer
self.default_pooling_params = {}
pooler_config: PoolerConfig = vllm_config.model_config.pooler_config
if pooler_config is not None:
for param in ["use_activation", "dimensions"]:
if getattr(pooler_config, param, None) is None:
continue
self.default_pooling_params[param] = getattr(pooler_config, param)
self.embed_dimensions = vllm_config.model_config.embedding_size
self.embed_request_queue: list[EmbedRequestMixin] = []
def __repr__(self) -> str:
return (
f"BgeM3SparseEmbeddingsProcessor("
f"embed_dimensions={self.embed_dimensions}, "
f"default_pooling_params={self.default_pooling_params})"
)
def merge_pooling_params( def merge_pooling_params(
self, self,
@@ -41,7 +56,57 @@ class BgeM3SparseEmbeddingsProcessor(
if params is None: if params is None:
params = PoolingParams() params = PoolingParams()
# refer to PoolingCompletionRequest.to_pooling_params # refer to PoolingCompletionRequest.to_pooling_params
params.task = "token_classify" # set and verify pooling params
params.skip_reading_prefix_cache = True
raw_embed_request = self.embed_request_queue.pop(0)
if raw_embed_request.embed_task not in EMBED_TASKS:
raise ValueError(
f"Unsupported task {raw_embed_request}, "
f"Supported tasks are {EMBED_TASKS}"
)
has_dense_embed = True
if raw_embed_request.embed_task == "dense":
params.task = "embed"
params.skip_reading_prefix_cache = False
elif raw_embed_request.embed_task == "sparse":
params.task = "token_classify"
has_dense_embed = False
else:
params.task = "embed&token_classify"
params.use_activation = raw_embed_request.use_activation
if params.use_activation is None:
params.use_activation = True
if not has_dense_embed:
params.dimensions = None
return params
params.dimensions = raw_embed_request.dimensions
model_config: ModelConfig = self.vllm_config.model_config
for param in self.default_pooling_params:
if getattr(params, param, None) is None:
setattr(params, param, self.default_pooling_params[param])
if params.dimensions is not None:
if not model_config.is_matryoshka:
raise ValueError(
f'Model "{model_config.served_model_name}" does not '
f"support matryoshka representation, "
f"changing output dimensions will lead to poor results."
)
mds = model_config.matryoshka_dimensions
if mds is not None:
if params.dimensions not in mds:
raise ValueError(
f"Model {model_config.served_model_name!r} "
f"only supports {str(mds)} matryoshka dimensions, "
f"use other output dimensions will "
f"lead to poor results."
)
elif params.dimensions < 1:
raise ValueError("Dimensions must be greater than 0")
return params return params
def parse_request( def parse_request(
@@ -61,14 +126,16 @@ class BgeM3SparseEmbeddingsProcessor(
if request_id is not None: if request_id is not None:
assert request_id not in self.online_requests, "request_id duplicated" assert request_id not in self.online_requests, "request_id duplicated"
self.online_requests[request_id] = prompt self.online_requests[request_id] = prompt
self.embed_request_queue.extend(prompt.to_embed_requests_online())
else: else:
self.offline_requests.append(prompt) self.offline_requests.append(prompt)
self.embed_request_queue.extend(prompt.to_embed_requests_offline())
return prompt.input return prompt.input
def _get_sparse_embedding_request(self, request_id: str | None = None): def _get_sparse_embedding_request(self, request_id: str | None = None):
if request_id: if request_id:
return self.online_requests.pop(request_id, None) return self.online_requests.pop(request_id, None)
return self.offline_requests.pop() return self.offline_requests.pop(0)
def _build_sparse_embedding_token_weights( def _build_sparse_embedding_token_weights(
self, self,
@@ -100,26 +167,45 @@ class BgeM3SparseEmbeddingsProcessor(
) -> SparseEmbeddingResponse: ) -> SparseEmbeddingResponse:
num_prompt_tokens = 0 num_prompt_tokens = 0
response_data = [] response_data = []
return_tokens = self._get_sparse_embedding_request(request_id).return_tokens raw_request = self._get_sparse_embedding_request(request_id)
has_dense_embed = raw_request.embed_task in ["dense", "dense&sparse"]
has_sparse_embed = raw_request.embed_task in ["sparse", "dense&sparse"]
embed_dimensions = 0
if has_dense_embed:
embed_dimensions = (
self.embed_dimensions
if raw_request.dimensions is None
else raw_request.dimensions
)
for idx in range(len(model_output)): for idx in range(len(model_output)):
mo = model_output[idx] mo = model_output[idx]
sparse_embedding: dict[int, float] = {} sparse_embedding_dict: dict[int, float] = {}
num_prompt_tokens += len(mo.prompt_token_ids) num_prompt_tokens += len(mo.prompt_token_ids)
if len(mo.prompt_token_ids) != len(mo.outputs.data): dense_embedding: list[float] | None = None
# this is the case that add_special_tokens is True, sparse_embedding: list[SparseEmbeddingTokenWeight] | None = None
# which means first token and last token are special tokens if has_dense_embed:
mo.prompt_token_ids = mo.prompt_token_ids[1:] dense_embedding = mo.outputs.data[:embed_dimensions].tolist()
for token_id, weight in zip(mo.prompt_token_ids, mo.outputs.data.tolist()): if has_sparse_embed:
sparse_embedding[token_id] = max( sparse_weights = mo.outputs.data[embed_dimensions:].tolist()
weight, sparse_embedding.get(token_id, 0.0) if len(mo.prompt_token_ids) != len(sparse_weights):
# this is the case that add_special_tokens is True,
# which means first token and last token are special tokens
mo.prompt_token_ids = mo.prompt_token_ids[1:]
for token_id, weight in zip(mo.prompt_token_ids, sparse_weights):
sparse_embedding_dict[token_id] = max(
weight, sparse_embedding_dict.get(token_id, 0.0)
)
sparse_embedding = self._build_sparse_embedding_token_weights(
sparse_embedding_dict,
raw_request.return_tokens,
) )
response_data.append( response_data.append(
SparseEmbeddingResponseData( SparseEmbeddingResponseData(
index=idx, index=idx,
sparse_embedding=self._build_sparse_embedding_token_weights( object=raw_request.embed_task,
sparse_embedding, sparse_embedding=sparse_embedding,
return_tokens, dense_embedding=dense_embedding,
),
) )
) )

View File

@@ -1,18 +1,44 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Literal, get_args
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.protocol import CompletionRequestMixin from vllm.entrypoints.pooling.base.protocol import (
CompletionRequestMixin,
EmbedRequestMixin,
)
EmbedTask = Literal[
"sparse",
"dense",
"dense&sparse",
]
EMBED_TASKS: tuple[EmbedTask, ...] = get_args(EmbedTask)
class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin): class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin, EmbedRequestMixin):
return_tokens: bool | None = Field( return_tokens: bool | None = Field(
default=None, default=None,
description="Whether to return dict shows the mapping of token_id to text." description="Whether to return dict shows the mapping of token_id to text."
"`None` or False means not return.", "`None` or False means not return.",
) )
embed_task: EmbedTask = Field(
default="dense&sparse",
description="embed task, can be one of 'sparse', 'dense' , 'dense&sparse', "
"default to 'dense&sparse'",
)
def to_embed_requests_offline(self) -> list[EmbedRequestMixin]:
if isinstance(self.input, list):
return [self] * len(self.input)
return [self]
def to_embed_requests_online(self) -> list[EmbedRequestMixin]:
return [self]
class SparseEmbeddingTokenWeight(BaseModel): class SparseEmbeddingTokenWeight(BaseModel):
@@ -23,8 +49,9 @@ class SparseEmbeddingTokenWeight(BaseModel):
class SparseEmbeddingResponseData(BaseModel): class SparseEmbeddingResponseData(BaseModel):
index: int index: int
object: str = "sparse-embedding" object: str = "dense&sparse"
sparse_embedding: list[SparseEmbeddingTokenWeight] sparse_embedding: list[SparseEmbeddingTokenWeight] | None
dense_embedding: list[float] | None
class SparseEmbeddingResponse(BaseModel): class SparseEmbeddingResponse(BaseModel):

View File

@@ -19,6 +19,12 @@ model_config = {
), ),
} }
dense_embedding_sum = [
-0.7214539647102356, # "What is the capital of France?"
-0.6926871538162231, # "What is the capital of Germany?"
-0.7129564881324768, # "What is the capital of Spain?"
]
def _float_close(expected: object, result: object): def _float_close(expected: object, result: object):
assert isinstance(expected, float) and isinstance(result, float), ( assert isinstance(expected, float) and isinstance(result, float), (
@@ -33,6 +39,12 @@ def _get_attr_or_val(obj: object | dict, key: str):
return getattr(obj, key, None) return getattr(obj, key, None)
def _check_dense_embedding(data, index=0):
assert _float_close(sum(data), dense_embedding_sum[index]), (
"dense-embedding result not match"
)
def _check_sparse_embedding(data, check_tokens=False): def _check_sparse_embedding(data, check_tokens=False):
expected_weights = [ expected_weights = [
{"token_id": 32, "weight": 0.0552978515625, "token": "?"}, {"token_id": 32, "weight": 0.0552978515625, "token": "?"},
@@ -109,7 +121,7 @@ async def test_bge_m3_sparse_plugin_online(
assert len(_get_attr_or_val(parsed_response, "data")) > 0 assert len(_get_attr_or_val(parsed_response, "data")) > 0
data_entry = _get_attr_or_val(parsed_response, "data")[0] data_entry = _get_attr_or_val(parsed_response, "data")[0]
assert _get_attr_or_val(data_entry, "object") == "sparse-embedding" assert _get_attr_or_val(data_entry, "object") == "dense&sparse"
assert _get_attr_or_val(data_entry, "sparse_embedding") assert _get_attr_or_val(data_entry, "sparse_embedding")
# Verify sparse embedding format # Verify sparse embedding format
@@ -117,6 +129,11 @@ async def test_bge_m3_sparse_plugin_online(
assert isinstance(sparse_embedding, list) assert isinstance(sparse_embedding, list)
_check_sparse_embedding(sparse_embedding, return_tokens) _check_sparse_embedding(sparse_embedding, return_tokens)
# Verify dense embedding format
dense_embedding = _get_attr_or_val(data_entry, "dense_embedding")
assert isinstance(dense_embedding, list)
_check_dense_embedding(dense_embedding)
# Verify usage information # Verify usage information
usage = _get_attr_or_val(parsed_response, "usage") usage = _get_attr_or_val(parsed_response, "usage")
assert usage, f"usage not found for {parsed_response}" assert usage, f"usage not found for {parsed_response}"
@@ -164,6 +181,9 @@ def test_bge_m3_sparse_plugin_offline(vllm_runner, return_tokens: bool):
sparse_embedding = output.sparse_embedding sparse_embedding = output.sparse_embedding
assert isinstance(sparse_embedding, list) assert isinstance(sparse_embedding, list)
_check_sparse_embedding(sparse_embedding, return_tokens) _check_sparse_embedding(sparse_embedding, return_tokens)
dense_embedding = output.dense_embedding
assert isinstance(dense_embedding, list)
_check_dense_embedding(dense_embedding)
# Verify usage # Verify usage
assert response.usage.prompt_tokens > 0 assert response.usage.prompt_tokens > 0
@@ -206,6 +226,9 @@ def test_bge_m3_sparse_plugin_offline_multiple_inputs(vllm_runner):
# Each output should have sparse embeddings # Each output should have sparse embeddings
sparse_embedding = output.sparse_embedding sparse_embedding = output.sparse_embedding
assert isinstance(sparse_embedding, list) assert isinstance(sparse_embedding, list)
dense_embedding = output.dense_embedding
assert isinstance(dense_embedding, list)
_check_dense_embedding(dense_embedding, i)
# Verify usage # Verify usage
assert response.usage.prompt_tokens > 0 assert response.usage.prompt_tokens > 0

View File

@@ -170,4 +170,42 @@ class BOSEOSFilter(Pooler):
return pooled_outputs return pooled_outputs
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"] class BgeM3Pooler(Pooler):
def __init__(self, token_classify_pooler: Pooler, embed_pooler: Pooler) -> None:
super().__init__()
self.token_classify_pooler = token_classify_pooler
self.embed_pooler = embed_pooler
def forward(
self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata
) -> PoolerOutput:
embed_outputs = self.embed_pooler(hidden_states, pooling_metadata)
token_classify_outputs = self.token_classify_pooler(
hidden_states, pooling_metadata
)
pooler_outputs: list[torch.Tensor] = []
for embed_output, token_classify_output in zip(
embed_outputs, token_classify_outputs
):
pooler_outputs.append(
torch.cat(
[embed_output.view(-1), token_classify_output.view(-1)], dim=-1
)
)
return pooler_outputs
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"embed&token_classify"}
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.embed_pooler.get_pooling_updates(
"embed"
) | self.token_classify_pooler.get_pooling_updates("token_classify")
def extra_repr(self) -> str:
s = f"supported_task={self.get_supported_tasks()}"
return s
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler", "BgeM3Pooler"]

View File

@@ -10,6 +10,7 @@ from transformers import RobertaConfig
from vllm.config import ModelConfig, PoolerConfig, VllmConfig from vllm.config import ModelConfig, PoolerConfig, VllmConfig
from vllm.model_executor.layers.pooler import ( from vllm.model_executor.layers.pooler import (
BgeM3Pooler,
BOSEOSFilter, BOSEOSFilter,
DispatchPooler, DispatchPooler,
Pooler, Pooler,
@@ -216,24 +217,29 @@ class BgeM3EmbeddingModel(RobertaEmbeddingModel):
self.colbert_linear = nn.Linear( self.colbert_linear = nn.Linear(
self.hidden_size, self.hidden_size, dtype=self.head_dtype self.hidden_size, self.hidden_size, dtype=self.head_dtype
) )
embed_pooler = pooler_for_embed(pooler_config)
token_classify_pooler = BOSEOSFilter(
pooler_for_token_classify(
pooler_config,
pooling=AllPool(),
classifier=self.sparse_linear,
act_fn=torch.relu,
),
self.bos_token_id,
self.eos_token_id,
)
return DispatchPooler( return DispatchPooler(
{ {
"embed": pooler_for_embed(pooler_config), "embed": embed_pooler,
"token_embed": BOSEOSFilter( "token_embed": BOSEOSFilter(
pooler_for_token_embed(pooler_config, self.colbert_linear), pooler_for_token_embed(pooler_config, self.colbert_linear),
self.bos_token_id, self.bos_token_id,
# for some reason m3 only filters the bos for colbert vectors # for some reason m3 only filters the bos for colbert vectors
), ),
"token_classify": BOSEOSFilter( "token_classify": token_classify_pooler,
pooler_for_token_classify( "embed&token_classify": BgeM3Pooler(
pooler_config, token_classify_pooler, embed_pooler
pooling=AllPool(),
classifier=self.sparse_linear,
act_fn=torch.relu,
),
self.bos_token_id,
self.eos_token_id,
), ),
} }
) )

View File

@@ -96,6 +96,10 @@ class PoolingParams(
self.skip_reading_prefix_cache = True self.skip_reading_prefix_cache = True
return return
# skipping verify, let plugins configure and validate pooling params
if self.task not in self.valid_parameters:
return
# NOTE: Task validation needs to done against the model instance, # NOTE: Task validation needs to done against the model instance,
# which is not available in model config. So, it's not included # which is not available in model config. So, it's not included
# in this method # in this method

View File

@@ -6,7 +6,13 @@ GenerationTask = Literal["generate", "transcription", "realtime"]
GENERATION_TASKS: tuple[GenerationTask, ...] = get_args(GenerationTask) GENERATION_TASKS: tuple[GenerationTask, ...] = get_args(GenerationTask)
PoolingTask = Literal[ PoolingTask = Literal[
"embed", "classify", "score", "token_embed", "token_classify", "plugin" "embed",
"classify",
"score",
"token_embed",
"token_classify",
"plugin",
"embed&token_classify",
] ]
POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask) POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)