Support bge-m3 sparse embeddings and colbert embeddings (#14526)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
parent 444e2e7e1f
commit ff365eea94
@@ -125,4 +125,49 @@ class IdentityPooler(Pooler):
         return hidden_states
 
 
-__all__ = ["DispatchPooler", "IdentityPooler"]
+class BOSEOSFilter(Pooler):
+    """Filters the BOS and EOS token results from outputs."""
+
+    def __init__(
+        self,
+        pooler: Pooler,
+        bos_token_id: int = -1,  # -1 disables the filtering
+        eos_token_id: int = -1,
+    ) -> None:
+        super().__init__()
+
+        self.pooler = pooler
+        self.bos_token_id = bos_token_id
+        self.eos_token_id = eos_token_id
+
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return self.pooler.get_supported_tasks()
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        return PoolingParamsUpdate(requires_token_ids=True)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor | list[torch.Tensor],
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        pooled_outputs = self.pooler(hidden_states, pooling_metadata)
+        assert isinstance(pooled_outputs, list)
+
+        for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
+            pooled_data = pooled_outputs[i]
+            assert (
+                isinstance(pooled_data, torch.Tensor)
+                and pooled_data.shape[0] == prompt_len
+            )
+            token_ids = pooling_metadata.prompt_token_ids[i, :prompt_len]
+            if token_ids[0] == self.bos_token_id:
+                pooled_data = pooled_data[1:]
+            if token_ids[-1] == self.eos_token_id:
+                pooled_data = pooled_data[:-1]
+            pooled_outputs[i] = pooled_data.squeeze()
+
+        return pooled_outputs
+
+
+__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"]
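As a sanity check on the slicing that `forward` performs, here is a minimal standalone sketch for a single sequence; the token ids and tensor sizes below are made up for illustration and do not come from a real tokenizer:

    import torch

    # Hypothetical ids: 0 = BOS, 2 = EOS; three real tokens in between.
    bos_token_id, eos_token_id = 0, 2
    token_ids = torch.tensor([0, 1543, 87, 992, 2])
    pooled_data = torch.randn(5, 8)  # one vector per token, prompt_len == 5

    if token_ids[0] == bos_token_id:
        pooled_data = pooled_data[1:]   # drop the BOS result
    if token_ids[-1] == eos_token_id:
        pooled_data = pooled_data[:-1]  # drop the EOS result

    assert pooled_data.shape == (3, 8)  # only the real tokens remain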
@@ -6,7 +6,11 @@ from typing import TypeAlias
 import torch
 
 from vllm.config import PoolerConfig, get_current_vllm_config
-from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
+from vllm.model_executor.layers.pooler import (
+    ClassifierFn,
+    PoolingParamsUpdate,
+    ProjectorFn,
+)
 from vllm.model_executor.layers.pooler.abstract import Pooler
 from vllm.model_executor.layers.pooler.activations import (
     PoolerActivation,
@@ -89,14 +93,18 @@ class TokenPooler(Pooler):
         return pooled_data
 
 
-def pooler_for_token_embed(pooler_config: PoolerConfig):
+def pooler_for_token_embed(
+    pooler_config: PoolerConfig, projector: ProjectorFn | None = None
+) -> TokenPooler:
     pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
 
     vllm_config = get_current_vllm_config()
     model_config = vllm_config.model_config
     head = TokenEmbeddingPoolerHead(
         head_dtype=model_config.head_dtype,
-        projector=_load_st_projector(model_config),
+        projector=projector
+        if projector is not None
+        else _load_st_projector(model_config),
         activation=PoolerNormalize(),
     )
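The new optional `projector` argument lets a caller supply its own projection layer instead of the SentenceTransformers projector that `_load_st_projector` reads from disk; `BgeM3EmbeddingModel` below passes its `colbert_linear` here. As a rough sketch of what the resulting head computes per token, assuming it applies the projector and then L2-normalizes (which is what the `PoolerNormalize` activation suggests; sizes are illustrative):

    import torch
    import torch.nn.functional as F
    from torch import nn

    hidden = torch.randn(7, 1024)      # [prompt_len, hidden_size]
    projector = nn.Linear(1024, 1024)  # stand-in for BGE-M3's colbert_linear
    token_embeds = F.normalize(projector(hidden), p=2, dim=-1)
    assert token_embeds.shape == (7, 1024)  # one ColBERT vector per token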
@@ -234,6 +234,7 @@ _EMBEDDING_MODELS = {
     "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
+    "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
     # [Multimodal]
     "CLIPModel": ("clip", "CLIPEmbeddingModel"),
     "LlavaNextForConditionalGeneration": (
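With the architecture registered, BGE-M3 can be served like any other embedding model; a hypothetical invocation (entrypoint and argument names may differ across vLLM versions):

    from vllm import LLM

    llm = LLM(model="BAAI/bge-m3", task="embed")
    (output,) = llm.embed(["What is BGE M3?"])
    dense = output.outputs.embedding  # dense sentence-level vector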
@@ -1,15 +1,30 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import itertools
 from collections.abc import Iterable
 
 import torch
 from torch import nn
 from transformers import RobertaConfig
 
-from vllm.config import ModelConfig, VllmConfig
-from vllm.model_executor.layers.pooler import DispatchPooler
+from vllm.config import ModelConfig, PoolerConfig, VllmConfig
+from vllm.model_executor.layers.pooler import (
+    BOSEOSFilter,
+    DispatchPooler,
+    Pooler,
+)
+from vllm.model_executor.layers.pooler.seqwise import (
+    pooler_for_embed,
+)
+from vllm.model_executor.layers.pooler.tokwise import (
+    AllPool,
+    pooler_for_token_classify,
+    pooler_for_token_embed,
+)
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.bert import (
     TOKEN_TYPE_SHIFT,
     BertEmbeddingModel,
@@ -149,6 +164,98 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
         return loader.load_weights(weights_list, mapper=mapper)
 
 
+def filter_secondary_weights(
+    all_weights: Iterable[tuple[str, torch.Tensor]],
+    secondary_weights: list[str],
+) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
+    all_weights1, all_weights2 = itertools.tee(all_weights)
+
+    def filtered(n):
+        return any(n.startswith(f) for f in secondary_weights)
+
+    return ((n, w) for n, w in all_weights1 if filtered(n)), (
+        (n, w) for n, w in all_weights2 if not filtered(n)
+    )
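`itertools.tee` lets the single weights iterator feed both returned generators without materializing all tensors up front; a toy illustration with dummy values in place of tensors:

    import itertools

    weights = iter([("sparse_linear.weight", 1), ("encoder.layer.0.bias", 2)])
    w1, w2 = itertools.tee(weights)
    secondary = [(n, w) for n, w in w1 if n.startswith("sparse_linear.")]
    primary = [(n, w) for n, w in w2 if not n.startswith("sparse_linear.")]
    assert [n for n, _ in secondary] == ["sparse_linear.weight"]
    assert [n for n, _ in primary] == ["encoder.layer.0.bias"]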
+class BgeM3EmbeddingModel(RobertaEmbeddingModel):
+    """A model that extends RobertaEmbeddingModel with sparse embeddings.
+
+    This class supports loading the additional sparse_linear.pt and
+    colbert_linear.pt files to create the sparse and ColBERT embeddings
+    described in https://arxiv.org/abs/2402.03216
+    """
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        self.hidden_size = vllm_config.model_config.hf_config.hidden_size
+
+        model_config = vllm_config.model_config
+        self.head_dtype = model_config.head_dtype
+        self.bos_token_id = model_config.hf_config.bos_token_id
+        self.eos_token_id = model_config.hf_config.eos_token_id
+
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        self.secondary_weight_prefixes = ["sparse_linear.", "colbert_linear."]
+        self.secondary_weight_files = [
+            prefix + "pt" for prefix in self.secondary_weight_prefixes
+        ]
+
+        self.secondary_weights = [
+            DefaultModelLoader.Source(
+                model_or_path=vllm_config.model_config.model,
+                revision=None,
+                prefix=prefix,
+                allow_patterns_overrides=[filename],
+            )
+            for filename, prefix in zip(
+                self.secondary_weight_files, self.secondary_weight_prefixes
+            )
+        ]
+
+    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        self.sparse_linear = nn.Linear(self.hidden_size, 1, dtype=self.head_dtype)
+        self.colbert_linear = nn.Linear(
+            self.hidden_size, self.hidden_size, dtype=self.head_dtype
+        )
+
+        return DispatchPooler(
+            {
+                "embed": pooler_for_embed(pooler_config),
+                "token_embed": BOSEOSFilter(
+                    pooler_for_token_embed(pooler_config, self.colbert_linear),
+                    self.bos_token_id,
+                    # For some reason, M3 filters only the BOS token for
+                    # ColBERT vectors, so eos_token_id keeps its disabled
+                    # default of -1.
+                ),
+                "token_classify": BOSEOSFilter(
+                    pooler_for_token_classify(
+                        pooler_config,
+                        pooling=AllPool(),
+                        classifier=self.sparse_linear,
+                        act_fn=torch.relu,
+                    ),
+                    self.bos_token_id,
+                    self.eos_token_id,
+                ),
+            }
+        )
+
+    def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]):
+        secondary, weights = filter_secondary_weights(
+            all_weights, self.secondary_weight_prefixes
+        )
+
+        super().load_weights(weights)
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in secondary:
+            if any(
+                name.startswith(prefix) for prefix in self.secondary_weight_prefixes
+            ):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
 @default_pooling_type(seq_pooling_type="CLS")
 class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """A model that uses Roberta to provide embedding functionalities.
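The two heads built in `_build_pooler` are deliberately small. The sparse head maps each token's hidden state to a single ReLU-clamped scalar, matching the lexical weights of the BGE-M3 paper, while the ColBERT head is the per-token projection sketched earlier. A standalone sketch of the sparse side, with illustrative sizes:

    import torch
    from torch import nn

    hidden = torch.randn(5, 1024)       # [prompt_len, hidden_size]
    sparse_linear = nn.Linear(1024, 1)  # stand-in for the loaded sparse head
    token_weights = torch.relu(sparse_linear(hidden)).squeeze(-1)
    assert token_weights.shape == (5,)  # one non-negative weight per token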