Support bge-m3 sparse embeddings and colbert embeddings (#14526)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
commit ff365eea94
parent 444e2e7e1f
committed by GitHub
@@ -125,4 +125,49 @@ class IdentityPooler(Pooler):
         return hidden_states
 
 
-__all__ = ["DispatchPooler", "IdentityPooler"]
+class BOSEOSFilter(Pooler):
+    """Filters the BOS and EOS token results from outputs."""
+
+    def __init__(
+        self,
+        pooler: Pooler,
+        bos_token_id: int = -1,  # -1 disables the filtering
+        eos_token_id: int = -1,
+    ) -> None:
+        super().__init__()
+
+        self.pooler = pooler
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return self.pooler.get_supported_tasks()
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        return PoolingParamsUpdate(requires_token_ids=True)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor | list[torch.Tensor],
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        pooled_outputs = self.pooler(hidden_states, pooling_metadata)
+        assert isinstance(pooled_outputs, list)
+
+        for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
+            pooled_data = pooled_outputs[i]
+            assert (
+                isinstance(pooled_data, torch.Tensor)
+                and pooled_data.shape[0] == prompt_len
+            )
+            token_ids = pooling_metadata.prompt_token_ids[i, :prompt_len]
+            if token_ids[0] == self.bos_token_id:
+                pooled_data = pooled_data[1:]
+            if token_ids[-1] == self.eos_token_id:
+                pooled_data = pooled_data[:-1]
+            pooled_outputs[i] = pooled_data.squeeze()
+
+        return pooled_outputs
+
+
+__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"]
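The core of BOSEOSFilter is the per-prompt trimming in forward: if a prompt starts with the BOS id, the first row of its per-token output is dropped, and likewise the last row for a trailing EOS. Below is a minimal standalone sketch of just that trimming step, using only torch; the function name `trim_bos_eos`, the token ids, and the hidden size are made up for illustration and are not part of the commit.

import torch

def trim_bos_eos(
    pooled: torch.Tensor,      # [prompt_len, hidden] per-token outputs
    token_ids: torch.Tensor,   # [prompt_len] prompt token ids
    bos_id: int = -1,          # -1 disables the check, mirroring the diff's default
    eos_id: int = -1,
) -> torch.Tensor:
    # Drop the first row if the prompt starts with BOS.
    if token_ids[0] == bos_id:
        pooled = pooled[1:]
    # Drop the last row if the prompt ends with EOS.
    if token_ids[-1] == eos_id:
        pooled = pooled[:-1]
    return pooled

# Illustrative usage with arbitrary ids: BOS=0, EOS=2, hidden size 4.
ids = torch.tensor([0, 11, 12, 13, 2])
emb = torch.randn(5, 4)
print(trim_bos_eos(emb, ids, bos_id=0, eos_id=2).shape)  # torch.Size([3, 4])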
@@ -6,7 +6,11 @@ from typing import TypeAlias
 import torch
 
 from vllm.config import PoolerConfig, get_current_vllm_config
-from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
+from vllm.model_executor.layers.pooler import (
+    ClassifierFn,
+    PoolingParamsUpdate,
+    ProjectorFn,
+)
 from vllm.model_executor.layers.pooler.abstract import Pooler
 from vllm.model_executor.layers.pooler.activations import (
     PoolerActivation,
@@ -89,14 +93,18 @@ class TokenPooler(Pooler):
         return pooled_data
 
 
-def pooler_for_token_embed(pooler_config: PoolerConfig):
+def pooler_for_token_embed(
+    pooler_config: PoolerConfig, projector: ProjectorFn | None = None
+) -> TokenPooler:
     pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
 
     vllm_config = get_current_vllm_config()
     model_config = vllm_config.model_config
     head = TokenEmbeddingPoolerHead(
         head_dtype=model_config.head_dtype,
-        projector=_load_st_projector(model_config),
+        projector=projector
+        if projector is not None
+        else _load_st_projector(model_config),
         activation=PoolerNormalize(),
     )
 
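The new `projector` parameter lets a caller supply its own token-level projection (as ColBERT-style heads need) instead of the Sentence-Transformers projector loaded by `_load_st_projector`. The following is a small sketch of that fallback pattern only, with a plain `torch.nn.Linear` standing in for the projector; `default_projector`, `build_head`, and the dimensions are hypothetical and chosen just to show the pattern.

import torch
from torch import nn

def default_projector() -> nn.Module:
    # Stand-in for the projector that would otherwise be loaded
    # from the checkpoint (dimensions are arbitrary here).
    return nn.Linear(1024, 1024, bias=False)

def build_head(projector: nn.Module | None = None) -> nn.Module:
    # Same fallback as the diff: prefer the caller-supplied projector,
    # otherwise fall back to the default one.
    proj = projector if projector is not None else default_projector()
    return nn.Sequential(proj)  # normalization and dtype handling omitted

colbert_proj = nn.Linear(1024, 128, bias=False)  # e.g. a ColBERT-style down-projection
head = build_head(colbert_proj)
print(head(torch.randn(3, 1024)).shape)  # torch.Size([3, 128])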