Support bge-m3 sparse embeddings and colbert embeddings (#14526)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
This commit is contained in:
committed by
GitHub
parent
444e2e7e1f
commit
ff365eea94
@@ -6,7 +6,11 @@ from typing import TypeAlias
|
||||
import torch
|
||||
|
||||
from vllm.config import PoolerConfig, get_current_vllm_config
|
||||
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
ClassifierFn,
|
||||
PoolingParamsUpdate,
|
||||
ProjectorFn,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.abstract import Pooler
|
||||
from vllm.model_executor.layers.pooler.activations import (
|
||||
PoolerActivation,
|
||||
@@ -89,14 +93,18 @@ class TokenPooler(Pooler):
|
||||
return pooled_data
|
||||
|
||||
|
||||
def pooler_for_token_embed(pooler_config: PoolerConfig):
|
||||
def pooler_for_token_embed(
|
||||
pooler_config: PoolerConfig, projector: ProjectorFn | None = None
|
||||
) -> TokenPooler:
|
||||
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
head = TokenEmbeddingPoolerHead(
|
||||
head_dtype=model_config.head_dtype,
|
||||
projector=_load_st_projector(model_config),
|
||||
projector=projector
|
||||
if projector is not None
|
||||
else _load_st_projector(model_config),
|
||||
activation=PoolerNormalize(),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user