Support bge-m3 sparse embeddings and colbert embeddings (#14526)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
commit ff365eea94 (parent 444e2e7e1f)
Author: Maximilien de Bayser
Date: 2026-01-22 12:52:57 -03:00 (committed by GitHub)
9 changed files with 393 additions and 19 deletions

@@ -6,7 +6,11 @@ from typing import TypeAlias
 import torch
 
 from vllm.config import PoolerConfig, get_current_vllm_config
-from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
+from vllm.model_executor.layers.pooler import (
+    ClassifierFn,
+    PoolingParamsUpdate,
+    ProjectorFn,
+)
 from vllm.model_executor.layers.pooler.abstract import Pooler
 from vllm.model_executor.layers.pooler.activations import (
     PoolerActivation,
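
This hunk also starts exporting ProjectorFn from vllm.model_executor.layers.pooler. Its definition is not shown in this diff; a minimal sketch of the assumed shape of the alias, a callable mapping token-level hidden states to projected embeddings:

from collections.abc import Callable
from typing import TypeAlias

import torch

# Assumed alias shape (the real definition lives in
# vllm/model_executor/layers/pooler and is not part of this hunk):
# maps hidden states [num_tokens, hidden_size] -> embeddings [num_tokens, out_dim].
ProjectorFn: TypeAlias = Callable[[torch.Tensor], torch.Tensor]

Under this reading, any nn.Module or plain function with that call signature can serve as a projector.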
@@ -89,14 +93,18 @@ class TokenPooler(Pooler):
         return pooled_data
 
 
-def pooler_for_token_embed(pooler_config: PoolerConfig):
+def pooler_for_token_embed(
+    pooler_config: PoolerConfig, projector: ProjectorFn | None = None
+) -> TokenPooler:
     pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
 
     vllm_config = get_current_vllm_config()
     model_config = vllm_config.model_config
 
     head = TokenEmbeddingPoolerHead(
         head_dtype=model_config.head_dtype,
-        projector=_load_st_projector(model_config),
+        projector=projector
+        if projector is not None
+        else _load_st_projector(model_config),
         activation=PoolerNormalize(),
     )
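
The new optional projector argument lets a model pass its own per-token projection instead of the SentenceTransformers dense module that _load_st_projector would load, which is what ColBERT-style multi-vector heads need. A hedged sketch, where ColbertProjector and its dimensions are illustrative rather than taken from this commit:

import torch
import torch.nn as nn

# Illustrative ColBERT-style head: project each token's hidden state down to a
# small per-token embedding; the head's PoolerNormalize() activation then
# normalizes each token vector, as ColBERT scoring expects.
class ColbertProjector(nn.Module):
    def __init__(self, hidden_size: int = 1024, colbert_dim: int = 128) -> None:
        super().__init__()
        self.linear = nn.Linear(hidden_size, colbert_dim, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.linear(hidden_states)

# Inside a model definition with a pooler_config in scope (sketch):
# pooler = pooler_for_token_embed(pooler_config, projector=ColbertProjector())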