Support bge-m3 sparse embeddings and colbert embeddings (#14526)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Author: Maximilien de Bayser
Date: 2026-01-22 12:52:57 -03:00
Committed by: GitHub
parent 444e2e7e1f
commit ff365eea94
9 changed files with 393 additions and 19 deletions
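
For orientation, here is a minimal, self-contained sketch (not vLLM API) of the two output kinds this commit enables: BGE-M3-style "sparse" embeddings assign one lexical weight per token, while ColBERT-style embeddings keep one dense vector per token instead of a single pooled vector. The shapes and the ReLU-of-a-linear-head weighting are illustrative assumptions, not the model's actual heads.

import torch

hidden = torch.randn(6, 8)  # toy per-token hidden states: 6 tokens, hidden size 8

# ColBERT-style (multi-vector): keep one embedding per token.
colbert_vecs = hidden

# Sparse-style: collapse each token's hidden state to one non-negative weight.
sparse_head = torch.randn(8, 1)  # stand-in for a learned linear head
sparse_weights = torch.relu(hidden @ sparse_head).squeeze(-1)

print(colbert_vecs.shape)    # torch.Size([6, 8])
print(sparse_weights.shape)  # torch.Size([6])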


@@ -125,4 +125,49 @@ class IdentityPooler(Pooler):
         return hidden_states
 
 
-__all__ = ["DispatchPooler", "IdentityPooler"]
+class BOSEOSFilter(Pooler):
+    """Filters the BOS and EOS token results from outputs."""
+
+    def __init__(
+        self,
+        pooler: Pooler,
+        bos_token_id: int = -1,  # -1 disables the filtering
+        eos_token_id: int = -1,
+    ) -> None:
+        super().__init__()
+        self.pooler = pooler
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return self.pooler.get_supported_tasks()
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        return PoolingParamsUpdate(requires_token_ids=True)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor | list[torch.Tensor],
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        pooled_outputs = self.pooler(hidden_states, pooling_metadata)
+        assert isinstance(pooled_outputs, list)
+
+        for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
+            pooled_data = pooled_outputs[i]
+            assert (
+                isinstance(pooled_data, torch.Tensor)
+                and pooled_data.shape[0] == prompt_len
+            )
+
+            token_ids = pooling_metadata.prompt_token_ids[i, :prompt_len]
+            if token_ids[0] == self.bos_token_id:
+                pooled_data = pooled_data[1:]
+            if token_ids[-1] == self.eos_token_id:
+                pooled_data = pooled_data[:-1]
+            pooled_outputs[i] = pooled_data.squeeze()
+
+        return pooled_outputs
+
+
+__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"]


@@ -6,7 +6,11 @@ from typing import TypeAlias
 import torch
 
 from vllm.config import PoolerConfig, get_current_vllm_config
-from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
+from vllm.model_executor.layers.pooler import (
+    ClassifierFn,
+    PoolingParamsUpdate,
+    ProjectorFn,
+)
 from vllm.model_executor.layers.pooler.abstract import Pooler
 from vllm.model_executor.layers.pooler.activations import (
     PoolerActivation,
@@ -89,14 +93,18 @@ class TokenPooler(Pooler):
         return pooled_data
 
 
-def pooler_for_token_embed(pooler_config: PoolerConfig):
+def pooler_for_token_embed(
+    pooler_config: PoolerConfig, projector: ProjectorFn | None = None
+) -> TokenPooler:
     pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
 
     vllm_config = get_current_vllm_config()
     model_config = vllm_config.model_config
 
     head = TokenEmbeddingPoolerHead(
         head_dtype=model_config.head_dtype,
-        projector=_load_st_projector(model_config),
+        projector=projector
+        if projector is not None
+        else _load_st_projector(model_config),
         activation=PoolerNormalize(),
     )
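
The only behavioral change in pooler_for_token_embed is the projector fallback: a caller-supplied projector takes precedence, and otherwise the SentenceTransformers projector is loaded from the model config as before. Here is a minimal sketch of that pattern, with nn.Linear standing in for a real ProjectorFn and _default_projector standing in for _load_st_projector.

import torch.nn as nn

def _default_projector() -> nn.Module:
    # Stand-in for _load_st_projector(model_config).
    return nn.Linear(8, 4)

def pick_projector(projector: nn.Module | None = None) -> nn.Module:
    return projector if projector is not None else _default_projector()

custom = nn.Linear(8, 2)
assert pick_projector(custom) is custom    # explicit projector wins
assert pick_projector().out_features == 4  # falls back to the default

Keeping the default at None rather than eagerly loading the projector means existing callers are unaffected, while models such as BGE-M3 can inject their own per-token projection.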