[Feature]: Support for multiple embedding types in a single inference call (#35829)
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
This commit is contained in:
@@ -3,10 +3,10 @@
|
|||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
|
||||||
from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
||||||
|
from vllm.entrypoints.pooling.base.protocol import EmbedRequestMixin
|
||||||
from vllm.inputs.data import PromptType
|
from vllm.inputs.data import PromptType
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.outputs import PoolingRequestOutput
|
from vllm.outputs import PoolingRequestOutput
|
||||||
from vllm.plugins.io_processors.interface import (
|
from vllm.plugins.io_processors.interface import (
|
||||||
IOProcessor,
|
IOProcessor,
|
||||||
@@ -16,14 +16,13 @@ from vllm.renderers import BaseRenderer
|
|||||||
from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens
|
from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens
|
||||||
|
|
||||||
from .types import (
|
from .types import (
|
||||||
|
EMBED_TASKS,
|
||||||
SparseEmbeddingCompletionRequestMixin,
|
SparseEmbeddingCompletionRequestMixin,
|
||||||
SparseEmbeddingResponse,
|
SparseEmbeddingResponse,
|
||||||
SparseEmbeddingResponseData,
|
SparseEmbeddingResponseData,
|
||||||
SparseEmbeddingTokenWeight,
|
SparseEmbeddingTokenWeight,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class BgeM3SparseEmbeddingsProcessor(
|
class BgeM3SparseEmbeddingsProcessor(
|
||||||
IOProcessor[SparseEmbeddingCompletionRequestMixin, SparseEmbeddingResponse]
|
IOProcessor[SparseEmbeddingCompletionRequestMixin, SparseEmbeddingResponse]
|
||||||
@@ -33,6 +32,22 @@ class BgeM3SparseEmbeddingsProcessor(
|
|||||||
self.offline_requests: list[SparseEmbeddingCompletionRequestMixin] = []
|
self.offline_requests: list[SparseEmbeddingCompletionRequestMixin] = []
|
||||||
self.online_requests: dict[str, SparseEmbeddingCompletionRequestMixin] = {}
|
self.online_requests: dict[str, SparseEmbeddingCompletionRequestMixin] = {}
|
||||||
self.renderer: BaseRenderer = renderer
|
self.renderer: BaseRenderer = renderer
|
||||||
|
self.default_pooling_params = {}
|
||||||
|
pooler_config: PoolerConfig = vllm_config.model_config.pooler_config
|
||||||
|
if pooler_config is not None:
|
||||||
|
for param in ["use_activation", "dimensions"]:
|
||||||
|
if getattr(pooler_config, param, None) is None:
|
||||||
|
continue
|
||||||
|
self.default_pooling_params[param] = getattr(pooler_config, param)
|
||||||
|
self.embed_dimensions = vllm_config.model_config.embedding_size
|
||||||
|
self.embed_request_queue: list[EmbedRequestMixin] = []
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"BgeM3SparseEmbeddingsProcessor("
|
||||||
|
f"embed_dimensions={self.embed_dimensions}, "
|
||||||
|
f"default_pooling_params={self.default_pooling_params})"
|
||||||
|
)
|
||||||
|
|
||||||
def merge_pooling_params(
|
def merge_pooling_params(
|
||||||
self,
|
self,
|
||||||
@@ -41,7 +56,57 @@ class BgeM3SparseEmbeddingsProcessor(
|
|||||||
if params is None:
|
if params is None:
|
||||||
params = PoolingParams()
|
params = PoolingParams()
|
||||||
# refer to PoolingCompletionRequest.to_pooling_params
|
# refer to PoolingCompletionRequest.to_pooling_params
|
||||||
params.task = "token_classify"
|
# set and verify pooling params
|
||||||
|
params.skip_reading_prefix_cache = True
|
||||||
|
|
||||||
|
raw_embed_request = self.embed_request_queue.pop(0)
|
||||||
|
if raw_embed_request.embed_task not in EMBED_TASKS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported task {raw_embed_request}, "
|
||||||
|
f"Supported tasks are {EMBED_TASKS}"
|
||||||
|
)
|
||||||
|
has_dense_embed = True
|
||||||
|
if raw_embed_request.embed_task == "dense":
|
||||||
|
params.task = "embed"
|
||||||
|
params.skip_reading_prefix_cache = False
|
||||||
|
elif raw_embed_request.embed_task == "sparse":
|
||||||
|
params.task = "token_classify"
|
||||||
|
has_dense_embed = False
|
||||||
|
else:
|
||||||
|
params.task = "embed&token_classify"
|
||||||
|
params.use_activation = raw_embed_request.use_activation
|
||||||
|
if params.use_activation is None:
|
||||||
|
params.use_activation = True
|
||||||
|
if not has_dense_embed:
|
||||||
|
params.dimensions = None
|
||||||
|
return params
|
||||||
|
|
||||||
|
params.dimensions = raw_embed_request.dimensions
|
||||||
|
|
||||||
|
model_config: ModelConfig = self.vllm_config.model_config
|
||||||
|
for param in self.default_pooling_params:
|
||||||
|
if getattr(params, param, None) is None:
|
||||||
|
setattr(params, param, self.default_pooling_params[param])
|
||||||
|
|
||||||
|
if params.dimensions is not None:
|
||||||
|
if not model_config.is_matryoshka:
|
||||||
|
raise ValueError(
|
||||||
|
f'Model "{model_config.served_model_name}" does not '
|
||||||
|
f"support matryoshka representation, "
|
||||||
|
f"changing output dimensions will lead to poor results."
|
||||||
|
)
|
||||||
|
|
||||||
|
mds = model_config.matryoshka_dimensions
|
||||||
|
if mds is not None:
|
||||||
|
if params.dimensions not in mds:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model {model_config.served_model_name!r} "
|
||||||
|
f"only supports {str(mds)} matryoshka dimensions, "
|
||||||
|
f"use other output dimensions will "
|
||||||
|
f"lead to poor results."
|
||||||
|
)
|
||||||
|
elif params.dimensions < 1:
|
||||||
|
raise ValueError("Dimensions must be greater than 0")
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def parse_request(
|
def parse_request(
|
||||||
@@ -61,14 +126,16 @@ class BgeM3SparseEmbeddingsProcessor(
|
|||||||
if request_id is not None:
|
if request_id is not None:
|
||||||
assert request_id not in self.online_requests, "request_id duplicated"
|
assert request_id not in self.online_requests, "request_id duplicated"
|
||||||
self.online_requests[request_id] = prompt
|
self.online_requests[request_id] = prompt
|
||||||
|
self.embed_request_queue.extend(prompt.to_embed_requests_online())
|
||||||
else:
|
else:
|
||||||
self.offline_requests.append(prompt)
|
self.offline_requests.append(prompt)
|
||||||
|
self.embed_request_queue.extend(prompt.to_embed_requests_offline())
|
||||||
return prompt.input
|
return prompt.input
|
||||||
|
|
||||||
def _get_sparse_embedding_request(self, request_id: str | None = None):
|
def _get_sparse_embedding_request(self, request_id: str | None = None):
|
||||||
if request_id:
|
if request_id:
|
||||||
return self.online_requests.pop(request_id, None)
|
return self.online_requests.pop(request_id, None)
|
||||||
return self.offline_requests.pop()
|
return self.offline_requests.pop(0)
|
||||||
|
|
||||||
def _build_sparse_embedding_token_weights(
|
def _build_sparse_embedding_token_weights(
|
||||||
self,
|
self,
|
||||||
@@ -100,26 +167,45 @@ class BgeM3SparseEmbeddingsProcessor(
|
|||||||
) -> SparseEmbeddingResponse:
|
) -> SparseEmbeddingResponse:
|
||||||
num_prompt_tokens = 0
|
num_prompt_tokens = 0
|
||||||
response_data = []
|
response_data = []
|
||||||
return_tokens = self._get_sparse_embedding_request(request_id).return_tokens
|
raw_request = self._get_sparse_embedding_request(request_id)
|
||||||
|
has_dense_embed = raw_request.embed_task in ["dense", "dense&sparse"]
|
||||||
|
has_sparse_embed = raw_request.embed_task in ["sparse", "dense&sparse"]
|
||||||
|
embed_dimensions = 0
|
||||||
|
if has_dense_embed:
|
||||||
|
embed_dimensions = (
|
||||||
|
self.embed_dimensions
|
||||||
|
if raw_request.dimensions is None
|
||||||
|
else raw_request.dimensions
|
||||||
|
)
|
||||||
for idx in range(len(model_output)):
|
for idx in range(len(model_output)):
|
||||||
mo = model_output[idx]
|
mo = model_output[idx]
|
||||||
sparse_embedding: dict[int, float] = {}
|
sparse_embedding_dict: dict[int, float] = {}
|
||||||
num_prompt_tokens += len(mo.prompt_token_ids)
|
num_prompt_tokens += len(mo.prompt_token_ids)
|
||||||
if len(mo.prompt_token_ids) != len(mo.outputs.data):
|
dense_embedding: list[float] | None = None
|
||||||
# this is the case that add_special_tokens is True,
|
sparse_embedding: list[SparseEmbeddingTokenWeight] | None = None
|
||||||
# which means first token and last token are special tokens
|
if has_dense_embed:
|
||||||
mo.prompt_token_ids = mo.prompt_token_ids[1:]
|
dense_embedding = mo.outputs.data[:embed_dimensions].tolist()
|
||||||
for token_id, weight in zip(mo.prompt_token_ids, mo.outputs.data.tolist()):
|
if has_sparse_embed:
|
||||||
sparse_embedding[token_id] = max(
|
sparse_weights = mo.outputs.data[embed_dimensions:].tolist()
|
||||||
weight, sparse_embedding.get(token_id, 0.0)
|
if len(mo.prompt_token_ids) != len(sparse_weights):
|
||||||
|
# this is the case that add_special_tokens is True,
|
||||||
|
# which means first token and last token are special tokens
|
||||||
|
mo.prompt_token_ids = mo.prompt_token_ids[1:]
|
||||||
|
for token_id, weight in zip(mo.prompt_token_ids, sparse_weights):
|
||||||
|
sparse_embedding_dict[token_id] = max(
|
||||||
|
weight, sparse_embedding_dict.get(token_id, 0.0)
|
||||||
|
)
|
||||||
|
sparse_embedding = self._build_sparse_embedding_token_weights(
|
||||||
|
sparse_embedding_dict,
|
||||||
|
raw_request.return_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
response_data.append(
|
response_data.append(
|
||||||
SparseEmbeddingResponseData(
|
SparseEmbeddingResponseData(
|
||||||
index=idx,
|
index=idx,
|
||||||
sparse_embedding=self._build_sparse_embedding_token_weights(
|
object=raw_request.embed_task,
|
||||||
sparse_embedding,
|
sparse_embedding=sparse_embedding,
|
||||||
return_tokens,
|
dense_embedding=dense_embedding,
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,44 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from typing import Literal, get_args
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
||||||
from vllm.entrypoints.pooling.base.protocol import CompletionRequestMixin
|
from vllm.entrypoints.pooling.base.protocol import (
|
||||||
|
CompletionRequestMixin,
|
||||||
|
EmbedRequestMixin,
|
||||||
|
)
|
||||||
|
|
||||||
|
EmbedTask = Literal[
|
||||||
|
"sparse",
|
||||||
|
"dense",
|
||||||
|
"dense&sparse",
|
||||||
|
]
|
||||||
|
|
||||||
|
EMBED_TASKS: tuple[EmbedTask, ...] = get_args(EmbedTask)
|
||||||
|
|
||||||
|
|
||||||
class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin):
|
class SparseEmbeddingCompletionRequestMixin(CompletionRequestMixin, EmbedRequestMixin):
|
||||||
return_tokens: bool | None = Field(
|
return_tokens: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Whether to return dict shows the mapping of token_id to text."
|
description="Whether to return dict shows the mapping of token_id to text."
|
||||||
"`None` or False means not return.",
|
"`None` or False means not return.",
|
||||||
)
|
)
|
||||||
|
embed_task: EmbedTask = Field(
|
||||||
|
default="dense&sparse",
|
||||||
|
description="embed task, can be one of 'sparse', 'dense' , 'dense&sparse', "
|
||||||
|
"default to 'dense&sparse'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_embed_requests_offline(self) -> list[EmbedRequestMixin]:
|
||||||
|
if isinstance(self.input, list):
|
||||||
|
return [self] * len(self.input)
|
||||||
|
return [self]
|
||||||
|
|
||||||
|
def to_embed_requests_online(self) -> list[EmbedRequestMixin]:
|
||||||
|
return [self]
|
||||||
|
|
||||||
|
|
||||||
class SparseEmbeddingTokenWeight(BaseModel):
|
class SparseEmbeddingTokenWeight(BaseModel):
|
||||||
@@ -23,8 +49,9 @@ class SparseEmbeddingTokenWeight(BaseModel):
|
|||||||
|
|
||||||
class SparseEmbeddingResponseData(BaseModel):
|
class SparseEmbeddingResponseData(BaseModel):
|
||||||
index: int
|
index: int
|
||||||
object: str = "sparse-embedding"
|
object: str = "dense&sparse"
|
||||||
sparse_embedding: list[SparseEmbeddingTokenWeight]
|
sparse_embedding: list[SparseEmbeddingTokenWeight] | None
|
||||||
|
dense_embedding: list[float] | None
|
||||||
|
|
||||||
|
|
||||||
class SparseEmbeddingResponse(BaseModel):
|
class SparseEmbeddingResponse(BaseModel):
|
||||||
|
|||||||
@@ -19,6 +19,12 @@ model_config = {
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dense_embedding_sum = [
|
||||||
|
-0.7214539647102356, # "What is the capital of France?"
|
||||||
|
-0.6926871538162231, # "What is the capital of Germany?"
|
||||||
|
-0.7129564881324768, # "What is the capital of Spain?"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def _float_close(expected: object, result: object):
|
def _float_close(expected: object, result: object):
|
||||||
assert isinstance(expected, float) and isinstance(result, float), (
|
assert isinstance(expected, float) and isinstance(result, float), (
|
||||||
@@ -33,6 +39,12 @@ def _get_attr_or_val(obj: object | dict, key: str):
|
|||||||
return getattr(obj, key, None)
|
return getattr(obj, key, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_dense_embedding(data, index=0):
|
||||||
|
assert _float_close(sum(data), dense_embedding_sum[index]), (
|
||||||
|
"dense-embedding result not match"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _check_sparse_embedding(data, check_tokens=False):
|
def _check_sparse_embedding(data, check_tokens=False):
|
||||||
expected_weights = [
|
expected_weights = [
|
||||||
{"token_id": 32, "weight": 0.0552978515625, "token": "?"},
|
{"token_id": 32, "weight": 0.0552978515625, "token": "?"},
|
||||||
@@ -109,7 +121,7 @@ async def test_bge_m3_sparse_plugin_online(
|
|||||||
assert len(_get_attr_or_val(parsed_response, "data")) > 0
|
assert len(_get_attr_or_val(parsed_response, "data")) > 0
|
||||||
|
|
||||||
data_entry = _get_attr_or_val(parsed_response, "data")[0]
|
data_entry = _get_attr_or_val(parsed_response, "data")[0]
|
||||||
assert _get_attr_or_val(data_entry, "object") == "sparse-embedding"
|
assert _get_attr_or_val(data_entry, "object") == "dense&sparse"
|
||||||
assert _get_attr_or_val(data_entry, "sparse_embedding")
|
assert _get_attr_or_val(data_entry, "sparse_embedding")
|
||||||
|
|
||||||
# Verify sparse embedding format
|
# Verify sparse embedding format
|
||||||
@@ -117,6 +129,11 @@ async def test_bge_m3_sparse_plugin_online(
|
|||||||
assert isinstance(sparse_embedding, list)
|
assert isinstance(sparse_embedding, list)
|
||||||
_check_sparse_embedding(sparse_embedding, return_tokens)
|
_check_sparse_embedding(sparse_embedding, return_tokens)
|
||||||
|
|
||||||
|
# Verify dense embedding format
|
||||||
|
dense_embedding = _get_attr_or_val(data_entry, "dense_embedding")
|
||||||
|
assert isinstance(dense_embedding, list)
|
||||||
|
_check_dense_embedding(dense_embedding)
|
||||||
|
|
||||||
# Verify usage information
|
# Verify usage information
|
||||||
usage = _get_attr_or_val(parsed_response, "usage")
|
usage = _get_attr_or_val(parsed_response, "usage")
|
||||||
assert usage, f"usage not found for {parsed_response}"
|
assert usage, f"usage not found for {parsed_response}"
|
||||||
@@ -164,6 +181,9 @@ def test_bge_m3_sparse_plugin_offline(vllm_runner, return_tokens: bool):
|
|||||||
sparse_embedding = output.sparse_embedding
|
sparse_embedding = output.sparse_embedding
|
||||||
assert isinstance(sparse_embedding, list)
|
assert isinstance(sparse_embedding, list)
|
||||||
_check_sparse_embedding(sparse_embedding, return_tokens)
|
_check_sparse_embedding(sparse_embedding, return_tokens)
|
||||||
|
dense_embedding = output.dense_embedding
|
||||||
|
assert isinstance(dense_embedding, list)
|
||||||
|
_check_dense_embedding(dense_embedding)
|
||||||
|
|
||||||
# Verify usage
|
# Verify usage
|
||||||
assert response.usage.prompt_tokens > 0
|
assert response.usage.prompt_tokens > 0
|
||||||
@@ -206,6 +226,9 @@ def test_bge_m3_sparse_plugin_offline_multiple_inputs(vllm_runner):
|
|||||||
# Each output should have sparse embeddings
|
# Each output should have sparse embeddings
|
||||||
sparse_embedding = output.sparse_embedding
|
sparse_embedding = output.sparse_embedding
|
||||||
assert isinstance(sparse_embedding, list)
|
assert isinstance(sparse_embedding, list)
|
||||||
|
dense_embedding = output.dense_embedding
|
||||||
|
assert isinstance(dense_embedding, list)
|
||||||
|
_check_dense_embedding(dense_embedding, i)
|
||||||
|
|
||||||
# Verify usage
|
# Verify usage
|
||||||
assert response.usage.prompt_tokens > 0
|
assert response.usage.prompt_tokens > 0
|
||||||
|
|||||||
@@ -170,4 +170,42 @@ class BOSEOSFilter(Pooler):
|
|||||||
return pooled_outputs
|
return pooled_outputs
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"]
|
class BgeM3Pooler(Pooler):
|
||||||
|
def __init__(self, token_classify_pooler: Pooler, embed_pooler: Pooler) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.token_classify_pooler = token_classify_pooler
|
||||||
|
self.embed_pooler = embed_pooler
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata
|
||||||
|
) -> PoolerOutput:
|
||||||
|
embed_outputs = self.embed_pooler(hidden_states, pooling_metadata)
|
||||||
|
token_classify_outputs = self.token_classify_pooler(
|
||||||
|
hidden_states, pooling_metadata
|
||||||
|
)
|
||||||
|
pooler_outputs: list[torch.Tensor] = []
|
||||||
|
for embed_output, token_classify_output in zip(
|
||||||
|
embed_outputs, token_classify_outputs
|
||||||
|
):
|
||||||
|
pooler_outputs.append(
|
||||||
|
torch.cat(
|
||||||
|
[embed_output.view(-1), token_classify_output.view(-1)], dim=-1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return pooler_outputs
|
||||||
|
|
||||||
|
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||||
|
return {"embed&token_classify"}
|
||||||
|
|
||||||
|
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||||
|
return self.embed_pooler.get_pooling_updates(
|
||||||
|
"embed"
|
||||||
|
) | self.token_classify_pooler.get_pooling_updates("token_classify")
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
s = f"supported_task={self.get_supported_tasks()}"
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler", "BgeM3Pooler"]
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from transformers import RobertaConfig
|
|||||||
|
|
||||||
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
|
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
|
||||||
from vllm.model_executor.layers.pooler import (
|
from vllm.model_executor.layers.pooler import (
|
||||||
|
BgeM3Pooler,
|
||||||
BOSEOSFilter,
|
BOSEOSFilter,
|
||||||
DispatchPooler,
|
DispatchPooler,
|
||||||
Pooler,
|
Pooler,
|
||||||
@@ -216,24 +217,29 @@ class BgeM3EmbeddingModel(RobertaEmbeddingModel):
|
|||||||
self.colbert_linear = nn.Linear(
|
self.colbert_linear = nn.Linear(
|
||||||
self.hidden_size, self.hidden_size, dtype=self.head_dtype
|
self.hidden_size, self.hidden_size, dtype=self.head_dtype
|
||||||
)
|
)
|
||||||
|
embed_pooler = pooler_for_embed(pooler_config)
|
||||||
|
token_classify_pooler = BOSEOSFilter(
|
||||||
|
pooler_for_token_classify(
|
||||||
|
pooler_config,
|
||||||
|
pooling=AllPool(),
|
||||||
|
classifier=self.sparse_linear,
|
||||||
|
act_fn=torch.relu,
|
||||||
|
),
|
||||||
|
self.bos_token_id,
|
||||||
|
self.eos_token_id,
|
||||||
|
)
|
||||||
|
|
||||||
return DispatchPooler(
|
return DispatchPooler(
|
||||||
{
|
{
|
||||||
"embed": pooler_for_embed(pooler_config),
|
"embed": embed_pooler,
|
||||||
"token_embed": BOSEOSFilter(
|
"token_embed": BOSEOSFilter(
|
||||||
pooler_for_token_embed(pooler_config, self.colbert_linear),
|
pooler_for_token_embed(pooler_config, self.colbert_linear),
|
||||||
self.bos_token_id,
|
self.bos_token_id,
|
||||||
# for some reason m3 only filters the bos for colbert vectors
|
# for some reason m3 only filters the bos for colbert vectors
|
||||||
),
|
),
|
||||||
"token_classify": BOSEOSFilter(
|
"token_classify": token_classify_pooler,
|
||||||
pooler_for_token_classify(
|
"embed&token_classify": BgeM3Pooler(
|
||||||
pooler_config,
|
token_classify_pooler, embed_pooler
|
||||||
pooling=AllPool(),
|
|
||||||
classifier=self.sparse_linear,
|
|
||||||
act_fn=torch.relu,
|
|
||||||
),
|
|
||||||
self.bos_token_id,
|
|
||||||
self.eos_token_id,
|
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -96,6 +96,10 @@ class PoolingParams(
|
|||||||
self.skip_reading_prefix_cache = True
|
self.skip_reading_prefix_cache = True
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# skipping verify, let plugins configure and validate pooling params
|
||||||
|
if self.task not in self.valid_parameters:
|
||||||
|
return
|
||||||
|
|
||||||
# NOTE: Task validation needs to done against the model instance,
|
# NOTE: Task validation needs to done against the model instance,
|
||||||
# which is not available in model config. So, it's not included
|
# which is not available in model config. So, it's not included
|
||||||
# in this method
|
# in this method
|
||||||
|
|||||||
@@ -6,7 +6,13 @@ GenerationTask = Literal["generate", "transcription", "realtime"]
|
|||||||
GENERATION_TASKS: tuple[GenerationTask, ...] = get_args(GenerationTask)
|
GENERATION_TASKS: tuple[GenerationTask, ...] = get_args(GenerationTask)
|
||||||
|
|
||||||
PoolingTask = Literal[
|
PoolingTask = Literal[
|
||||||
"embed", "classify", "score", "token_embed", "token_classify", "plugin"
|
"embed",
|
||||||
|
"classify",
|
||||||
|
"score",
|
||||||
|
"token_embed",
|
||||||
|
"token_classify",
|
||||||
|
"plugin",
|
||||||
|
"embed&token_classify",
|
||||||
]
|
]
|
||||||
POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)
|
POOLING_TASKS: tuple[PoolingTask, ...] = get_args(PoolingTask)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user