Support bge-m3 sparse embeddings and colbert embeddings (#14526)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
This commit is contained in:
committed by
GitHub
parent
444e2e7e1f
commit
ff365eea94
@@ -234,6 +234,7 @@ _EMBEDDING_MODELS = {
|
||||
"TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
||||
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
||||
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
|
||||
# [Multimodal]
|
||||
"CLIPModel": ("clip", "CLIPEmbeddingModel"),
|
||||
"LlavaNextForConditionalGeneration": (
|
||||
|
||||
@@ -1,15 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import RobertaConfig
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
BOSEOSFilter,
|
||||
DispatchPooler,
|
||||
Pooler,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.seqwise import (
|
||||
pooler_for_embed,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.tokwise import (
|
||||
AllPool,
|
||||
pooler_for_token_classify,
|
||||
pooler_for_token_embed,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.bert import (
|
||||
TOKEN_TYPE_SHIFT,
|
||||
BertEmbeddingModel,
|
||||
@@ -149,6 +164,98 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||
return loader.load_weights(weights_list, mapper=mapper)
|
||||
|
||||
|
||||
def filter_secondary_weights(
    all_weights: Iterable[tuple[str, torch.Tensor]],
    secondary_weights: list[str],
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
    """Split a named-weight stream into (secondary, primary) lazy iterables.

    A weight is "secondary" when its name starts with any of the given
    prefixes.  The input is tee'd so the single source iterable backs
    both returned iterables without being consumed twice.
    """
    # str.startswith accepts a tuple of prefixes; an empty tuple matches
    # nothing, which mirrors any() over an empty list.
    prefixes = tuple(secondary_weights)

    def is_secondary(name: str) -> bool:
        return name.startswith(prefixes)

    branch_a, branch_b = itertools.tee(all_weights)
    secondary = (pair for pair in branch_a if is_secondary(pair[0]))
    primary = (pair for pair in branch_b if not is_secondary(pair[0]))
    return secondary, primary
|
||||
|
||||
|
||||
class BgeM3EmbeddingModel(RobertaEmbeddingModel):
    """A model that extends RobertaEmbeddingModel with sparse embeddings.

    This class supports loading an additional sparse_linear.pt file
    to create sparse embeddings as described in https://arxiv.org/abs/2402.03216

    Beyond dense sequence embeddings inherited from the parent, this adds:
    - a sparse (lexical) head: ``sparse_linear`` (hidden_size -> 1, ReLU),
      served via the "token_classify" pooler task;
    - a ColBERT-style multi-vector head: ``colbert_linear``
      (hidden_size -> hidden_size), served via "token_embed".
    The extra head weights live in separate checkpoint files
    (``sparse_linear.pt`` / ``colbert_linear.pt``) and are loaded as
    secondary weight sources.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # These attributes are read by _build_pooler(), so they must exist
        # before super().__init__() runs (which presumably invokes
        # _build_pooler — TODO confirm against BertEmbeddingModel).
        self.hidden_size = vllm_config.model_config.hf_config.hidden_size

        model_config = vllm_config.model_config
        self.head_dtype = model_config.head_dtype
        self.bos_token_id = model_config.hf_config.bos_token_id
        self.eos_token_id = model_config.hf_config.eos_token_id

        super().__init__(vllm_config=vllm_config, prefix=prefix)
        # Each prefix already ends with "." so appending "pt" yields e.g.
        # "sparse_linear.pt" — the on-disk file names of the extra heads.
        self.secondary_weight_prefixes = ["sparse_linear.", "colbert_linear."]
        self.secondary_weight_files = [
            prefix + "pt" for prefix in self.secondary_weight_prefixes
        ]

        # One loader Source per extra head file.  NOTE: the comprehension's
        # `prefix` variable (the weight-name prefix) shadows the
        # constructor's module `prefix` argument; Source receives the
        # weight-name prefix, not the module prefix.
        self.secondary_weights = [
            DefaultModelLoader.Source(
                model_or_path=vllm_config.model_config.model,
                revision=None,
                prefix=prefix,
                allow_patterns_overrides=[filename],
            )
            for filename, prefix in zip(
                self.secondary_weight_files, self.secondary_weight_prefixes
            )
        ]

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        """Build the task-dispatching pooler and the two extra head layers.

        Side effect: registers ``sparse_linear`` and ``colbert_linear`` as
        submodules so their parameters appear in named_parameters() and can
        be filled by load_weights().
        """
        self.sparse_linear = nn.Linear(self.hidden_size, 1, dtype=self.head_dtype)
        self.colbert_linear = nn.Linear(
            self.hidden_size, self.hidden_size, dtype=self.head_dtype
        )

        return DispatchPooler(
            {
                # Dense sequence embedding: inherited default behavior.
                "embed": pooler_for_embed(pooler_config),
                # ColBERT multi-vector output; only the BOS token is
                # stripped (no eos_token_id passed).
                "token_embed": BOSEOSFilter(
                    pooler_for_token_embed(pooler_config, self.colbert_linear),
                    self.bos_token_id,
                    # for some reason m3 only filters the bos for colbert vectors
                ),
                # Sparse lexical weights: per-token scalar through
                # sparse_linear + ReLU, with both BOS and EOS stripped.
                "token_classify": BOSEOSFilter(
                    pooler_for_token_classify(
                        pooler_config,
                        pooling=AllPool(),
                        classifier=self.sparse_linear,
                        act_fn=torch.relu,
                    ),
                    self.bos_token_id,
                    self.eos_token_id,
                ),
            }
        )

    def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]):
        """Load backbone weights via the parent, then fill the extra heads.

        Weights whose names start with a secondary prefix
        ("sparse_linear." / "colbert_linear.") are routed to the head
        layers created in _build_pooler(); everything else goes through
        RobertaEmbeddingModel.load_weights().
        """
        secondary, weights = filter_secondary_weights(
            all_weights, self.secondary_weight_prefixes
        )

        super().load_weights(weights)

        params_dict = dict(self.named_parameters())

        for name, loaded_weight in secondary:
            # Redundant with the filtering above, but kept as a guard
            # against unexpected names in the secondary stream.
            if any(
                name.startswith(prefix) for prefix in self.secondary_weight_prefixes
            ):
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
"""A model that uses Roberta to provide embedding functionalities.
|
||||
|
||||
Reference in New Issue
Block a user