Support bge-m3 sparse embeddings and colbert embeddings (#14526)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Author: Maximilien de Bayser
Date: 2026-01-22 12:52:57 -03:00
Committed by: GitHub
parent 444e2e7e1f
commit ff365eea94
9 changed files with 393 additions and 19 deletions

vllm/model_executor/models/registry.py

@@ -234,6 +234,7 @@ _EMBEDDING_MODELS = {
"TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
# [Multimodal]
"CLIPModel": ("clip", "CLIPEmbeddingModel"),
"LlavaNextForConditionalGeneration": (

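The registry entry above maps the "BgeM3EmbeddingModel" architecture string to the new class in roberta.py below. Because the upstream BAAI/bge-m3 checkpoint declares "XLMRobertaModel" in its HF config, opting into the new class presumably goes through an architecture override. A minimal usage sketch; the hf_overrides route and the runner value are assumptions, not part of this commit:

# Hypothetical usage sketch (not part of this commit): point vLLM at a
# bge-m3 checkpoint and override its declared architecture so the
# registry resolves to BgeM3EmbeddingModel instead of RobertaEmbeddingModel.
from vllm import LLM

llm = LLM(
    model="BAAI/bge-m3",  # assumption: the checkpoint this class targets
    runner="pooling",
    hf_overrides={"architectures": ["BgeM3EmbeddingModel"]},  # assumption
)
outputs = llm.embed(["What is BGE M3?"])  # dense ("embed") pooling path
print(len(outputs[0].outputs.embedding))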
vllm/model_executor/models/roberta.py

@@ -1,15 +1,30 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import itertools
 from collections.abc import Iterable
 
 import torch
 from torch import nn
 from transformers import RobertaConfig
 
-from vllm.config import ModelConfig, VllmConfig
-from vllm.model_executor.layers.pooler import DispatchPooler
+from vllm.config import ModelConfig, PoolerConfig, VllmConfig
+from vllm.model_executor.layers.pooler import (
+    BOSEOSFilter,
+    DispatchPooler,
+    Pooler,
+)
+from vllm.model_executor.layers.pooler.seqwise import (
+    pooler_for_embed,
+)
+from vllm.model_executor.layers.pooler.tokwise import (
+    AllPool,
+    pooler_for_token_classify,
+    pooler_for_token_embed,
+)
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.bert import (
     TOKEN_TYPE_SHIFT,
     BertEmbeddingModel,
@@ -149,6 +164,98 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
         return loader.load_weights(weights_list, mapper=mapper)
 
 
+def filter_secondary_weights(
+    all_weights: Iterable[tuple[str, torch.Tensor]],
+    secondary_weights: list[str],
+) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
+    all_weights1, all_weights2 = itertools.tee(all_weights)
+
+    def filtered(n):
+        return any(n.startswith(f) for f in secondary_weights)
+
+    return ((n, w) for n, w in all_weights1 if filtered(n)), (
+        (n, w) for n, w in all_weights2 if not filtered(n)
+    )
+
+
+class BgeM3EmbeddingModel(RobertaEmbeddingModel):
+    """A model that extends RobertaEmbeddingModel with sparse embeddings.
+
+    This class supports loading an additional sparse_linear.pt file
+    to create sparse embeddings as described in https://arxiv.org/abs/2402.03216
+    """
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        self.hidden_size = vllm_config.model_config.hf_config.hidden_size
+        model_config = vllm_config.model_config
+        self.head_dtype = model_config.head_dtype
+        self.bos_token_id = model_config.hf_config.bos_token_id
+        self.eos_token_id = model_config.hf_config.eos_token_id
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+        self.secondary_weight_prefixes = ["sparse_linear.", "colbert_linear."]
+        self.secondary_weight_files = [
+            prefix + "pt" for prefix in self.secondary_weight_prefixes
+        ]
+        self.secondary_weights = [
+            DefaultModelLoader.Source(
+                model_or_path=vllm_config.model_config.model,
+                revision=None,
+                prefix=prefix,
+                allow_patterns_overrides=[filename],
+            )
+            for filename, prefix in zip(
+                self.secondary_weight_files, self.secondary_weight_prefixes
+            )
+        ]
+
+    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        self.sparse_linear = nn.Linear(self.hidden_size, 1, dtype=self.head_dtype)
+        self.colbert_linear = nn.Linear(
+            self.hidden_size, self.hidden_size, dtype=self.head_dtype
+        )
+        return DispatchPooler(
+            {
+                "embed": pooler_for_embed(pooler_config),
+                "token_embed": BOSEOSFilter(
+                    pooler_for_token_embed(pooler_config, self.colbert_linear),
+                    self.bos_token_id,
+                    # for some reason m3 only filters the bos for colbert vectors
+                ),
+                "token_classify": BOSEOSFilter(
+                    pooler_for_token_classify(
+                        pooler_config,
+                        pooling=AllPool(),
+                        classifier=self.sparse_linear,
+                        act_fn=torch.relu,
+                    ),
+                    self.bos_token_id,
+                    self.eos_token_id,
+                ),
+            }
+        )
+
+    def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]):
+        secondary, weights = filter_secondary_weights(
+            all_weights, self.secondary_weight_prefixes
+        )
+        super().load_weights(weights)
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in secondary:
+            if any(
+                name.startswith(prefix) for prefix in self.secondary_weight_prefixes
+            ):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
 @default_pooling_type(seq_pooling_type="CLS")
 class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
     """A model that uses Roberta to provide embedding functionalities.