[Feature]: Support for multiple embedding types in a single inference call (#35829)
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
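For orientation: this change lets a single pooling request ask for dense embeddings, sparse per-token weights, or both, via the new `embed_task` field introduced in the diff below. A minimal sketch of what such a request body might look like; the `model` key and the exact endpoint shape are assumptions not taken from this diff, while `input`, `embed_task`, `return_tokens`, and `dimensions` come from the request mixins it touches:

request_body = {
    "model": "BAAI/bge-m3",                 # assumed model name; any BGE-M3-style pooling model
    "input": "What is the capital of France?",
    "embed_task": "dense&sparse",           # one of "sparse", "dense", "dense&sparse" (the default)
    "return_tokens": True,                  # also map sparse token ids back to token text
    "dimensions": None,                     # optional matryoshka truncation of the dense part
}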
@@ -3,10 +3,10 @@
 
 from collections.abc import Sequence
 
-from vllm.config import VllmConfig
+from vllm.config import ModelConfig, PoolerConfig, VllmConfig
 from vllm.entrypoints.openai.engine.protocol import UsageInfo
+from vllm.entrypoints.pooling.base.protocol import EmbedRequestMixin
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput
 from vllm.plugins.io_processors.interface import (
     IOProcessor,
@@ -16,14 +16,13 @@ from vllm.renderers import BaseRenderer
 from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens
 
 from .types import (
+    EMBED_TASKS,
     SparseEmbeddingCompletionRequestMixin,
     SparseEmbeddingResponse,
     SparseEmbeddingResponseData,
     SparseEmbeddingTokenWeight,
 )
 
 logger = init_logger(__name__)
 
 
 class BgeM3SparseEmbeddingsProcessor(
     IOProcessor[SparseEmbeddingCompletionRequestMixin, SparseEmbeddingResponse]
@@ -33,6 +32,22 @@ class BgeM3SparseEmbeddingsProcessor(
         self.offline_requests: list[SparseEmbeddingCompletionRequestMixin] = []
         self.online_requests: dict[str, SparseEmbeddingCompletionRequestMixin] = {}
         self.renderer: BaseRenderer = renderer
+        self.default_pooling_params = {}
+        pooler_config: PoolerConfig = vllm_config.model_config.pooler_config
+        if pooler_config is not None:
+            for param in ["use_activation", "dimensions"]:
+                if getattr(pooler_config, param, None) is None:
+                    continue
+                self.default_pooling_params[param] = getattr(pooler_config, param)
+        self.embed_dimensions = vllm_config.model_config.embedding_size
+        self.embed_request_queue: list[EmbedRequestMixin] = []
+
+    def __repr__(self) -> str:
+        return (
+            f"BgeM3SparseEmbeddingsProcessor("
+            f"embed_dimensions={self.embed_dimensions}, "
+            f"default_pooling_params={self.default_pooling_params})"
+        )
 
     def merge_pooling_params(
         self,
@@ -41,7 +56,57 @@ class BgeM3SparseEmbeddingsProcessor(
         if params is None:
             params = PoolingParams()
-        # refer to PoolingCompletionRequest.to_pooling_params
-        params.task = "token_classify"
+        # set and verify pooling params
         params.skip_reading_prefix_cache = True
+
+        raw_embed_request = self.embed_request_queue.pop(0)
+        if raw_embed_request.embed_task not in EMBED_TASKS:
+            raise ValueError(
+                f"Unsupported task {raw_embed_request.embed_task!r}; "
+                f"supported tasks are {EMBED_TASKS}"
+            )
+        has_dense_embed = True
+        if raw_embed_request.embed_task == "dense":
+            params.task = "embed"
+            params.skip_reading_prefix_cache = False
+        elif raw_embed_request.embed_task == "sparse":
+            params.task = "token_classify"
+            has_dense_embed = False
+        else:
+            params.task = "embed&token_classify"
+        params.use_activation = raw_embed_request.use_activation
+        if params.use_activation is None:
+            params.use_activation = True
+        if not has_dense_embed:
+            params.dimensions = None
+            return params
+
+        params.dimensions = raw_embed_request.dimensions
+
+        model_config: ModelConfig = self.vllm_config.model_config
+        for param in self.default_pooling_params:
+            if getattr(params, param, None) is None:
+                setattr(params, param, self.default_pooling_params[param])
+
+        if params.dimensions is not None:
+            if not model_config.is_matryoshka:
+                raise ValueError(
+                    f'Model "{model_config.served_model_name}" does not '
+                    f"support matryoshka representation; "
+                    f"changing output dimensions will lead to poor results."
+                )
+
+            mds = model_config.matryoshka_dimensions
+            if mds is not None:
+                if params.dimensions not in mds:
+                    raise ValueError(
+                        f"Model {model_config.served_model_name!r} "
+                        f"only supports {str(mds)} matryoshka dimensions; "
+                        f"using other output dimensions will "
+                        f"lead to poor results."
+                    )
+            elif params.dimensions < 1:
+                raise ValueError("Dimensions must be greater than 0")
+        return params
 
     def parse_request(
@@ -61,14 +126,16 @@ class BgeM3SparseEmbeddingsProcessor(
         if request_id is not None:
             assert request_id not in self.online_requests, "request_id duplicated"
             self.online_requests[request_id] = prompt
+            self.embed_request_queue.extend(prompt.to_embed_requests_online())
         else:
             self.offline_requests.append(prompt)
+            self.embed_request_queue.extend(prompt.to_embed_requests_offline())
         return prompt.input
 
     def _get_sparse_embedding_request(self, request_id: str | None = None):
         if request_id:
             return self.online_requests.pop(request_id, None)
-        return self.offline_requests.pop()
+        return self.offline_requests.pop(0)
 
     def _build_sparse_embedding_token_weights(
         self,
@@ -100,26 +167,45 @@ class BgeM3SparseEmbeddingsProcessor(
     ) -> SparseEmbeddingResponse:
         num_prompt_tokens = 0
         response_data = []
-        return_tokens = self._get_sparse_embedding_request(request_id).return_tokens
+        raw_request = self._get_sparse_embedding_request(request_id)
+        has_dense_embed = raw_request.embed_task in ["dense", "dense&sparse"]
+        has_sparse_embed = raw_request.embed_task in ["sparse", "dense&sparse"]
+        embed_dimensions = 0
+        if has_dense_embed:
+            embed_dimensions = (
+                self.embed_dimensions
+                if raw_request.dimensions is None
+                else raw_request.dimensions
+            )
         for idx in range(len(model_output)):
             mo = model_output[idx]
-            sparse_embedding: dict[int, float] = {}
+            sparse_embedding_dict: dict[int, float] = {}
             num_prompt_tokens += len(mo.prompt_token_ids)
-            if len(mo.prompt_token_ids) != len(mo.outputs.data):
-                # this is the case that add_special_tokens is True,
-                # which means first token and last token are special tokens
-                mo.prompt_token_ids = mo.prompt_token_ids[1:]
-            for token_id, weight in zip(mo.prompt_token_ids, mo.outputs.data.tolist()):
-                sparse_embedding[token_id] = max(
-                    weight, sparse_embedding.get(token_id, 0.0)
-                )
+            dense_embedding: list[float] | None = None
+            sparse_embedding: list[SparseEmbeddingTokenWeight] | None = None
+            if has_dense_embed:
+                dense_embedding = mo.outputs.data[:embed_dimensions].tolist()
+            if has_sparse_embed:
+                sparse_weights = mo.outputs.data[embed_dimensions:].tolist()
+                if len(mo.prompt_token_ids) != len(sparse_weights):
+                    # this is the case where add_special_tokens is True,
+                    # which means the first and last tokens are special tokens
+                    mo.prompt_token_ids = mo.prompt_token_ids[1:]
+                for token_id, weight in zip(mo.prompt_token_ids, sparse_weights):
+                    sparse_embedding_dict[token_id] = max(
+                        weight, sparse_embedding_dict.get(token_id, 0.0)
+                    )
+                sparse_embedding = self._build_sparse_embedding_token_weights(
+                    sparse_embedding_dict,
+                    raw_request.return_tokens,
+                )
 
             response_data.append(
                 SparseEmbeddingResponseData(
                     index=idx,
-                    sparse_embedding=self._build_sparse_embedding_token_weights(
-                        sparse_embedding,
-                        return_tokens,
-                    ),
+                    object=raw_request.embed_task,
+                    sparse_embedding=sparse_embedding,
+                    dense_embedding=dense_embedding,
                 )
             )
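The remaining hunks come from the accompanying `.types` module (imported above via `from .types import (...)`). Before moving on, here is a small illustrative helper, not part of the commit, that restates how `merge_pooling_params` above maps the requested `embed_task` onto vLLM pooling tasks:

def pooling_task_for(embed_task: str) -> str:
    # Same branching as merge_pooling_params: dense-only requests use the
    # "embed" task, sparse-only requests use "token_classify", and the
    # combined mode runs both in one inference call.
    mapping = {
        "dense": "embed",
        "sparse": "token_classify",
        "dense&sparse": "embed&token_classify",
    }
    if embed_task not in mapping:
        raise ValueError(f"Unsupported task {embed_task!r}; supported tasks are {tuple(mapping)}")
    return mapping[embed_task]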
@@ -1,18 +1,44 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from typing import Literal, get_args
+
 from pydantic import BaseModel, Field
 
 from vllm.entrypoints.openai.engine.protocol import UsageInfo
-from vllm.entrypoints.pooling.base.protocol import CompletionRequestMixin
+from vllm.entrypoints.pooling.base.protocol import (
+    CompletionRequestMixin,
+    EmbedRequestMixin,
+)
+
+EmbedTask = Literal[
+    "sparse",
+    "dense",
+    "dense&sparse",
+]
+
+EMBED_TASKS: tuple[EmbedTask, ...] = get_args(EmbedTask)
 
 
-class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin):
+class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin, EmbedRequestMixin):
     return_tokens: bool | None = Field(
         default=None,
         description="Whether to return dict shows the mapping of token_id to text."
         "`None` or False means not return.",
     )
+    embed_task: EmbedTask = Field(
+        default="dense&sparse",
+        description="Embedding task; one of 'sparse', 'dense', or 'dense&sparse'. "
+        "Defaults to 'dense&sparse'.",
+    )
+
+    def to_embed_requests_offline(self) -> list[EmbedRequestMixin]:
+        if isinstance(self.input, list):
+            return [self] * len(self.input)
+        return [self]
+
+    def to_embed_requests_online(self) -> list[EmbedRequestMixin]:
+        return [self]
 
 
 class SparseEmbeddingTokenWeight(BaseModel):
@@ -23,8 +49,9 @@ class SparseEmbeddingTokenWeight(BaseModel):
 
 class SparseEmbeddingResponseData(BaseModel):
     index: int
-    object: str = "sparse-embedding"
-    sparse_embedding: list[SparseEmbeddingTokenWeight]
+    object: str = "dense&sparse"
+    sparse_embedding: list[SparseEmbeddingTokenWeight] | None
+    dense_embedding: list[float] | None
 
 
 class SparseEmbeddingResponse(BaseModel):
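With `SparseEmbeddingResponseData` widened as above, each response item can carry either or both embedding kinds, with the unused one set to None. A minimal construction sketch using only fields visible in this diff; the values are made up:

dense_only_item = SparseEmbeddingResponseData(
    index=0,
    object="dense",                       # mirrors the requested embed_task
    dense_embedding=[0.12, -0.03, 0.44],  # dense vector, truncated to the requested dimensions
    sparse_embedding=None,                # no sparse weights were requested
)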