Support bge-m3 sparse embeddings and colbert embeddings (#14526)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
This commit is contained in:
Maximilien de Bayser
2026-01-22 12:52:57 -03:00
committed by GitHub
parent 444e2e7e1f
commit ff365eea94
9 changed files with 393 additions and 19 deletions

View File

@@ -305,6 +305,44 @@ Expected output:
An OpenAI client example can be found here: [examples/pooling/embed/openai_embedding_matryoshka_fy_client.py](../../examples/pooling/embed/openai_embedding_matryoshka_fy_client.py)
## Specific models
### BAAI/bge-m3
The `BAAI/bge-m3` model comes with extra weights for sparse and colbert embeddings. Unfortunately, its `config.json`
declares the architecture as `XLMRobertaModel`, which makes `vLLM` load it as a vanilla RoBERTa model without the
extra weights. To load the full model weights, override its architecture like this:
```shell
vllm serve BAAI/bge-m3 --hf-overrides '{"architectures": ["BgeM3EmbeddingModel"]}'
```
Then you obtain the sparse embeddings like this:
```shell
curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
"model": "BAAI/bge-m3",
"task": "token_classify",
"input": ["What is BGE M3?", "Defination of BM25"]
}'
```
Due to limitations in the output schema, the output consists of a list of
token scores, one score per token, for each input. This means that you'll have to call
`/tokenize` as well to be able to pair tokens with scores.
Refer to the tests in `tests/models/language/pooling/test_bge_m3.py` to see how
to do that.
You can obtain the colbert embeddings like this:
```shell
curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
"model": "BAAI/bge-m3",
"task": "token_embed",
"input": ["What is BGE M3?", "Defination of BM25"]
}'
```
## Deprecated Features
### Encode task

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
import openai
import pytest
from tests.conftest import HfRunner
@@ -65,3 +66,16 @@ def correctness_test_embed_models(
hf_model_callback(hf_model)
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
async def run_client_embeddings(
    client: openai.AsyncOpenAI,
    model_name: str,
    queries: list[str],
    instruction: str = "",
) -> list[list[float]]:
    """Embed ``queries`` through the client's embeddings endpoint.

    Args:
        client: Async client whose ``embeddings.create`` coroutine is called.
        model_name: Value sent as the ``model`` field of the request.
        queries: Texts to embed.
        instruction: Prefix prepended verbatim to every query before sending.

    Returns:
        The ``embedding`` of each entry in the response's ``data``, in the
        order the response lists them.
    """
    prompts = []
    for query in queries:
        prompts.append(instruction + query)

    response = await client.embeddings.create(model=model_name, input=prompts)

    embeddings = []
    for item in response.data:
        embeddings.append(item.embedding)
    return embeddings

View File

@@ -0,0 +1,170 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import httpx
import openai
import pytest
import pytest_asyncio
import torch
from ....utils import RemoteOpenAIServer
from .embed_utils import run_client_embeddings
# Model served by the `server` fixture and named in every request below.
MODEL_NAME = "BAAI/bge-m3"
# Passed to the server as `--max-model-len`.
MAX_MODEL_LEN = 512
# Example from https://huggingface.co/BAAI/bge-m3
# The tests below score entries of sentences_1 against entries of sentences_2.
sentences_1 = ["What is BGE M3?", "Defination of BM25"]
sentences_2 = [
"BGE M3 is an embedding model supporting dense retrieval, "
"lexical matching and multi-vector interaction.",
"BM25 is a bag-of-words retrieval function that ranks a set "
"of documents based on the query terms appearing in each document",
]
# Expected dense dot products: entry [i][j] is sentences_1[i] against sentences_2[j].
similarity_reference = [[0.6265, 0.3477], [0.3499, 0.678]]
# Expected lexical matching scores:
# [0] is sentences_1[0] vs sentences_2[0], [1] is sentences_1[0] vs sentences_1[1].
lexical_score_reference = [0.19554901123046875, 0.0]
# Expected colbert scores:
# [0] is sentences_1[0] vs sentences_2[0], [1] is sentences_1[0] vs sentences_2[1].
colbert_score_reference = [0.7797, 0.4620]
@pytest.fixture(scope="module")
def server():
    """Run one server per test module with the BgeM3 architecture override."""
    hf_overrides = '{"architectures": ["BgeM3EmbeddingModel"]}'
    cli_args = ["--max-model-len", str(MAX_MODEL_LEN)]
    cli_args += ["--hf-overrides", hf_overrides]

    # The context manager owns the server lifetime; tests run while it is open.
    with RemoteOpenAIServer(MODEL_NAME, cli_args) as running_server:
        yield running_server
@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from the ``server`` fixture."""
    client_cm = server.get_async_client()
    # Leaving the `async with` block closes the client after the test finishes.
    async with client_cm as api_client:
        yield api_client
@pytest.mark.asyncio
async def test_bge_m3_api_server_embedding(client: openai.AsyncOpenAI):
    """Dense embeddings from the embeddings API match the reference similarities."""
    dense = []
    for batch in (sentences_1, sentences_2):
        vectors = await run_client_embeddings(client, MODEL_NAME, batch)
        dense.append(torch.tensor(vectors))
    first, second = dense

    # reference values from BAAI/bge-m3 documentation
    expected = torch.tensor(similarity_reference)
    actual = first @ second.T
    assert torch.allclose(actual, expected, rtol=0.01)
async def tokenize(client: openai.AsyncOpenAI, sentences: list[str]) -> list[list[int]]:
    """Tokenize every sentence through the server's ``/tokenize`` endpoint.

    The requests are issued concurrently. Previously the coroutines were
    collected first and then awaited one by one, which serialized the requests
    and left the remaining coroutines un-awaited if an earlier one raised.

    Args:
        client: Async client used to post the requests.
        sentences: Texts to tokenize; one request is sent per sentence.

    Returns:
        The ``tokens`` field of each response, in the order of ``sentences``.
    """
    # Local import so this helper does not depend on the file's import block.
    import asyncio

    # NOTE(review): the "../" prefix presumably steps out of the client's
    # versioned base URL to reach the server root; confirm against the server.
    responses = await asyncio.gather(
        *(
            client.post(
                "../tokenize",
                body={"model": MODEL_NAME, "prompt": sentence},
                cast_to=httpx.Response,
            )
            for sentence in sentences
        )
    )
    # gather() preserves argument order, so results still line up with inputs.
    return [response.json()["tokens"] for response in responses]
async def sparse_embeddings(
    client: openai.AsyncOpenAI, sentences: list[str]
) -> list[dict[int, float]]:
    """Build a token-id -> weight mapping for every sentence.

    The per-token scores come from the ``/pooling`` endpoint with the
    ``token_classify`` task and are paired with the ids returned by
    ``tokenize``. When a token id occurs more than once, the largest score
    is kept.

    Args:
        client: Async client used for both the tokenize and pooling requests.
        sentences: Texts to embed.

    Returns:
        One ``{token_id: weight}`` dict per sentence.
    """
    token_lists = await tokenize(client, sentences)
    response = await client.post(
        "../pooling",
        body={"model": MODEL_NAME, "input": sentences, "task": "token_classify"},
        cast_to=httpx.Response,
    )
    score_lists = [item["data"] for item in response.json()["data"]]

    weights_per_sentence = []
    for token_ids, scores in zip(token_lists, score_lists):
        # A leading token id 0 is dropped so ids line up with the scores
        # (presumably the BOS token, which has no score in the response).
        aligned_ids = token_ids[1:] if token_ids[0] == 0 else token_ids

        weights: dict[int, float] = {}
        # zip() stops at the shorter sequence, so surplus trailing ids are ignored.
        for token_id, score in zip(aligned_ids, scores):
            best_so_far = weights.get(token_id, 0.0)
            weights[token_id] = max(score, best_so_far)
        weights_per_sentence.append(weights)
    return weights_per_sentence
# Based on https://github.com/FlagOpen/FlagEmbedding/blob/6fd176266f2382878bcc69cd656cff425d52f49b/FlagEmbedding/inference/embedder/encoder_only/m3.py#L129
def compute_lexical_matching_score(
    lw1: dict[int, float], lw2: dict[int, float]
) -> float:
    """Sum the products of the weights of tokens present in both mappings.

    Args:
        lw1: Token-id -> weight mapping of the first text.
        lw2: Token-id -> weight mapping of the second text.

    Returns:
        The accumulated score; ``0.0`` when the mappings share no token.
    """
    total = 0.0
    for token_id, left_weight in lw1.items():
        if token_id not in lw2:
            continue
        total += left_weight * lw2[token_id]
    return total
@pytest.mark.asyncio
async def test_bge_m3_api_server_sparse_embedding(client: openai.AsyncOpenAI):
    """Lexical matching scores built from served sparse weights match the references."""
    weights_1 = await sparse_embeddings(client, sentences_1)
    weights_2 = await sparse_embeddings(client, sentences_2)

    # sentences_1[0] is scored first against sentences_2[0], then against
    # sentences_1[1]; the references are stored in the same order.
    for ref_idx, other in enumerate((weights_2[0], weights_1[1])):
        score = compute_lexical_matching_score(weights_1[0], other)
        assert score == pytest.approx(lexical_score_reference[ref_idx], rel=0.01)
# https://github.com/FlagOpen/FlagEmbedding/blob/6fd176266f2382878bcc69cd656cff425d52f49b/FlagEmbedding/inference/embedder/encoder_only/m3.py#L163
def colbert_score(q_reps: torch.Tensor, p_reps: torch.Tensor) -> torch.Tensor:
    """Average, over the rows of ``q_reps``, of each row's best match in ``p_reps``.

    Args:
        q_reps: 2-D tensor with one row per query token.
        p_reps: 2-D tensor with one row per passage token; same width as ``q_reps``.

    Returns:
        A 0-d tensor: the sum of the per-row maxima of the pairwise dot
        products, divided by the number of rows of ``q_reps``.
    """
    pairwise = torch.einsum("in,jn->ij", q_reps, p_reps)
    best_per_query_row = pairwise.max(-1).values
    return best_per_query_row.sum() / q_reps.size(0)
@pytest.mark.asyncio
async def test_bge_m3_api_server_multi_vector(client: openai.AsyncOpenAI):
    """Colbert scores computed from served token embeddings match the references."""

    async def token_embeddings(batch: list[str]) -> list[torch.Tensor]:
        # One pooling request per batch; each response entry becomes a tensor.
        response = await client.post(
            "../pooling",
            body={"model": MODEL_NAME, "input": batch, "task": "token_embed"},
            cast_to=httpx.Response,
        )
        return [torch.tensor(item["data"]) for item in response.json()["data"]]

    vectors_1 = await token_embeddings(sentences_1)
    vectors_2 = await token_embeddings(sentences_2)

    # sentences_1[0] is scored against sentences_2[0] and then sentences_2[1].
    for passage_idx in (0, 1):
        score = colbert_score(vectors_1[0], vectors_2[passage_idx])
        assert score == pytest.approx(colbert_score_reference[passage_idx], rel=0.01)

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import openai
import pytest
from scipy.spatial.distance import cosine
@@ -9,6 +8,7 @@ from vllm import LLM, SamplingParams
from vllm.config import ModelConfig
from ....utils import RemoteOpenAIServer
from .embed_utils import run_client_embeddings
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
MAX_MODEL_LEN = 4000
@@ -55,18 +55,6 @@ def run_llm_encode(
return [output.outputs.embedding for output in outputs]
async def run_client_embeddings(
    client: openai.AsyncOpenAI,
    queries: list[str],
    instruction: str,
) -> list[list[float]]:
    """Embed ``queries`` with the module's ``MODEL_NAME`` via the embeddings API.

    Args:
        client: Async client whose ``embeddings.create`` coroutine is called.
        queries: Texts to embed.
        instruction: Prefix prepended verbatim to every query before sending.

    Returns:
        The ``embedding`` of each entry in the response's ``data``.
    """
    prompts = []
    for query in queries:
        prompts.append(instruction + query)

    response = await client.embeddings.create(model=MODEL_NAME, input=prompts)

    embeddings = []
    for item in response.data:
        embeddings.append(item.embedding)
    return embeddings
def gritlm_instruction(instruction):
return (
"<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n"
@@ -145,11 +133,13 @@ async def test_gritlm_api_server_embedding():
d_rep = await run_client_embeddings(
client_embedding,
MODEL_NAME,
documents,
d_instruction,
)
q_rep = await run_client_embeddings(
client_embedding,
MODEL_NAME,
queries,
q_instruction,
)

View File

@@ -513,6 +513,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
_EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),

View File

@@ -125,4 +125,49 @@ class IdentityPooler(Pooler):
return hidden_states
__all__ = ["DispatchPooler", "IdentityPooler"]
class BOSEOSFilter(Pooler):
    """Filters the BOS and EOS token results from outputs.

    Wraps another pooler that returns one tensor per request with one row per
    prompt token. Rows belonging to a leading BOS token and/or a trailing EOS
    token are removed from each tensor.
    """

    def __init__(
        self,
        pooler: Pooler,
        bos_token_id: int = -1,  # -1 disables the filtering
        eos_token_id: int = -1,
    ) -> None:
        """Store the wrapped pooler and the token ids to filter.

        Args:
            pooler: Pooler whose per-token outputs are filtered.
            bos_token_id: Id whose row is dropped when it is the first prompt
                token; ``-1`` disables BOS filtering.
            eos_token_id: Id whose row is dropped when it is the last prompt
                token; ``-1`` disables EOS filtering.
        """
        super().__init__()
        self.pooler = pooler
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks supported by the wrapped pooler."""
        return self.pooler.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request prompt token ids, which ``forward`` reads to find BOS/EOS."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor | list[torch.Tensor],
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Run the wrapped pooler and drop the BOS/EOS rows of every output.

        Args:
            hidden_states: Passed through unchanged to the wrapped pooler.
            pooling_metadata: Supplies the prompt lengths and prompt token ids.

        Returns:
            The wrapped pooler's list, modified in place. Each entry keeps its
            leading token axis; only a trailing feature axis of size 1 is
            squeezed away.
        """
        pooled_outputs = self.pooler(hidden_states, pooling_metadata)
        assert isinstance(pooled_outputs, list)

        for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
            pooled_data = pooled_outputs[i]
            assert (
                isinstance(pooled_data, torch.Tensor)
                and pooled_data.shape[0] == prompt_len
            )
            token_ids = pooling_metadata.prompt_token_ids[i, :prompt_len]
            if token_ids[0] == self.bos_token_id:
                pooled_data = pooled_data[1:]
            if token_ids[-1] == self.eos_token_id:
                pooled_data = pooled_data[:-1]

            # Squeeze only a trailing size-1 feature axis. A bare squeeze()
            # would also remove the token axis when a single token remains,
            # turning [1, hidden] into [hidden] and [1, 1] into a scalar.
            if pooled_data.dim() > 1:
                pooled_data = pooled_data.squeeze(-1)
            pooled_outputs[i] = pooled_data

        return pooled_outputs
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"]

View File

@@ -6,7 +6,11 @@ from typing import TypeAlias
import torch
from vllm.config import PoolerConfig, get_current_vllm_config
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
from vllm.model_executor.layers.pooler import (
ClassifierFn,
PoolingParamsUpdate,
ProjectorFn,
)
from vllm.model_executor.layers.pooler.abstract import Pooler
from vllm.model_executor.layers.pooler.activations import (
PoolerActivation,
@@ -89,14 +93,18 @@ class TokenPooler(Pooler):
return pooled_data
def pooler_for_token_embed(pooler_config: PoolerConfig):
def pooler_for_token_embed(
pooler_config: PoolerConfig, projector: ProjectorFn | None = None
) -> TokenPooler:
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config
head = TokenEmbeddingPoolerHead(
head_dtype=model_config.head_dtype,
projector=_load_st_projector(model_config),
projector=projector
if projector is not None
else _load_st_projector(model_config),
activation=PoolerNormalize(),
)

View File

@@ -234,6 +234,7 @@ _EMBEDDING_MODELS = {
"TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
# [Multimodal]
"CLIPModel": ("clip", "CLIPEmbeddingModel"),
"LlavaNextForConditionalGeneration": (

View File

@@ -1,15 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Iterable
import torch
from torch import nn
from transformers import RobertaConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
from vllm.model_executor.layers.pooler import (
BOSEOSFilter,
DispatchPooler,
Pooler,
)
from vllm.model_executor.layers.pooler.seqwise import (
pooler_for_embed,
)
from vllm.model_executor.layers.pooler.tokwise import (
AllPool,
pooler_for_token_classify,
pooler_for_token_embed,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bert import (
TOKEN_TYPE_SHIFT,
BertEmbeddingModel,
@@ -149,6 +164,98 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
return loader.load_weights(weights_list, mapper=mapper)
def filter_secondary_weights(
    all_weights: Iterable[tuple[str, torch.Tensor]],
    secondary_weights: list[str],
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
    """Split a weight stream into secondary and remaining weights.

    Args:
        all_weights: ``(name, tensor)`` pairs; consumed lazily.
        secondary_weights: Name prefixes that mark a weight as secondary.

    Returns:
        Two lazy iterables ``(secondary, others)``: the pairs whose name starts
        with one of the prefixes, and all remaining pairs. Both draw from the
        same underlying stream through ``itertools.tee``, so items are
        buffered until both iterables have passed them.
    """

    def is_secondary(name: str) -> bool:
        # str.startswith accepts a tuple of candidate prefixes.
        return name.startswith(tuple(secondary_weights))

    secondary_src, other_src = itertools.tee(all_weights, 2)
    secondary = (
        (name, tensor) for name, tensor in secondary_src if is_secondary(name)
    )
    others = (
        (name, tensor) for name, tensor in other_src if not is_secondary(name)
    )
    return secondary, others
class BgeM3EmbeddingModel(RobertaEmbeddingModel):
    """A model that extends RobertaEmbeddingModel with sparse and colbert heads.

    Besides the dense embedding, this class loads the additional
    ``sparse_linear.pt`` and ``colbert_linear.pt`` files to create the sparse
    and colbert (multi-vector) embeddings described in
    https://arxiv.org/abs/2402.03216
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Record head settings, build the base model and declare extra weights.

        Args:
            vllm_config: Configuration providing the model and HF configs.
            prefix: Module prefix forwarded to the base class.
        """
        model_config = vllm_config.model_config
        # These attributes are read by `_build_pooler`. They are assigned
        # before `super().__init__()`, presumably because the base constructor
        # is what invokes `_build_pooler` -- confirm against the base class.
        self.hidden_size = model_config.hf_config.hidden_size
        self.head_dtype = model_config.head_dtype
        self.bos_token_id = model_config.hf_config.bos_token_id
        self.eos_token_id = model_config.hf_config.eos_token_id

        super().__init__(vllm_config=vllm_config, prefix=prefix)

        self.secondary_weight_prefixes = ["sparse_linear.", "colbert_linear."]
        # "sparse_linear." -> "sparse_linear.pt", "colbert_linear." -> "colbert_linear.pt"
        self.secondary_weight_files = [
            weight_prefix + "pt" for weight_prefix in self.secondary_weight_prefixes
        ]
        self.secondary_weights = [
            DefaultModelLoader.Source(
                model_or_path=model_config.model,
                # Use the configured revision so the extra heads come from the
                # same revision as the main weights (was hard-coded to None).
                revision=model_config.revision,
                prefix=weight_prefix,
                allow_patterns_overrides=[filename],
            )
            for filename, weight_prefix in zip(
                self.secondary_weight_files, self.secondary_weight_prefixes
            )
        ]

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        """Create the sparse/colbert heads and the per-task pooler dispatch.

        Args:
            pooler_config: Pooler configuration forwarded to the pooler factories.

        Returns:
            A ``DispatchPooler`` handling the ``embed``, ``token_embed`` and
            ``token_classify`` tasks.
        """
        self.sparse_linear = nn.Linear(self.hidden_size, 1, dtype=self.head_dtype)
        self.colbert_linear = nn.Linear(
            self.hidden_size, self.hidden_size, dtype=self.head_dtype
        )

        return DispatchPooler(
            {
                "embed": pooler_for_embed(pooler_config),
                "token_embed": BOSEOSFilter(
                    pooler_for_token_embed(pooler_config, self.colbert_linear),
                    self.bos_token_id,
                    # for some reason m3 only filters the bos for colbert vectors
                ),
                "token_classify": BOSEOSFilter(
                    pooler_for_token_classify(
                        pooler_config,
                        pooling=AllPool(),
                        classifier=self.sparse_linear,
                        act_fn=torch.relu,
                    ),
                    self.bos_token_id,
                    self.eos_token_id,
                ),
            }
        )

    def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]):
        """Load the base weights, then the sparse and colbert head weights.

        Args:
            all_weights: ``(name, tensor)`` pairs covering the base model and
                the secondary heads.

        Returns:
            The names reported as loaded by the base class merged with the
            names of the secondary parameters loaded here, or ``None`` when
            the base class reports nothing.

        Raises:
            KeyError: If a secondary weight name is not a parameter of this
                module.
        """
        secondary, weights = filter_secondary_weights(
            all_weights, self.secondary_weight_prefixes
        )

        # Keep the base class result instead of discarding it, so the caller
        # still receives the loaded parameter names.
        loaded = super().load_weights(weights)

        params_dict = dict(self.named_parameters())
        secondary_loaded: set[str] = set()
        # `secondary` is already restricted to the secondary prefixes, so no
        # second prefix check is needed here.
        for name, loaded_weight in secondary:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            secondary_loaded.add(name)

        if loaded is None:
            return None
        return set(loaded) | secondary_loaded
@default_pooling_type(seq_pooling_type="CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities.