Support bge-m3 sparse embeddings and colbert embeddings (#14526)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
This commit is contained in:
committed by
GitHub
parent
444e2e7e1f
commit
ff365eea94
@@ -305,6 +305,44 @@ Expected output:
|
||||
|
||||
An OpenAI client example can be found here: [examples/pooling/embed/openai_embedding_matryoshka_fy_client.py](../../examples/pooling/embed/openai_embedding_matryoshka_fy_client.py)
|
||||
|
||||
## Specific models
|
||||
|
||||
### BAAI/bge-m3
|
||||
|
||||
The `BAAI/bge-m3` model comes with extra weights for sparse and colbert embeddings but unfortunately in its `config.json`
|
||||
the architecture is declared as `XLMRobertaModel`, which makes `vLLM` load it as a vanilla ROBERTA model without the
|
||||
extra weights. To load the full model weights, override its architecture like this:
|
||||
|
||||
```shell
|
||||
vllm serve BAAI/bge-m3 --hf-overrides '{"architectures": ["BgeM3EmbeddingModel"]}'
|
||||
```
|
||||
|
||||
Then you obtain the sparse embeddings like this:
|
||||
|
||||
```shell
|
||||
curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
|
||||
"model": "BAAI/bge-m3",
|
||||
"task": "token_classify",
|
||||
"input": ["What is BGE M3?", "Defination of BM25"]
|
||||
}'
|
||||
```
|
||||
|
||||
Due to limitations in the the output schema, the output consists of a list of
|
||||
token scores for each token for each input. This means that you'll have to call
|
||||
`/tokenize` as well to be able to pair tokens with scores.
|
||||
Refer to the tests in `tests/models/language/pooling/test_bge_m3.py` to see how
|
||||
to do that.
|
||||
|
||||
You can obtain the colbert embeddings like this:
|
||||
|
||||
```shell
|
||||
curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
|
||||
"model": "BAAI/bge-m3",
|
||||
"task": "token_embed",
|
||||
"input": ["What is BGE M3?", "Defination of BM25"]
|
||||
}'
|
||||
```
|
||||
|
||||
## Deprecated Features
|
||||
|
||||
### Encode task
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from tests.conftest import HfRunner
|
||||
@@ -65,3 +66,16 @@ def correctness_test_embed_models(
|
||||
hf_model_callback(hf_model)
|
||||
|
||||
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
|
||||
|
||||
|
||||
async def run_client_embeddings(
|
||||
client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
queries: list[str],
|
||||
instruction: str = "",
|
||||
) -> list[list[float]]:
|
||||
outputs = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=[instruction + q for q in queries],
|
||||
)
|
||||
return [data.embedding for data in outputs.data]
|
||||
|
||||
170
tests/models/language/pooling/test_bge_m3.py
Normal file
170
tests/models/language/pooling/test_bge_m3.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import httpx
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import torch
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
from .embed_utils import run_client_embeddings
|
||||
|
||||
MODEL_NAME = "BAAI/bge-m3"
|
||||
MAX_MODEL_LEN = 512
|
||||
|
||||
|
||||
# Example from https://huggingface.co/BAAI/bge-m3
|
||||
sentences_1 = ["What is BGE M3?", "Defination of BM25"]
|
||||
sentences_2 = [
|
||||
"BGE M3 is an embedding model supporting dense retrieval, "
|
||||
"lexical matching and multi-vector interaction.",
|
||||
"BM25 is a bag-of-words retrieval function that ranks a set "
|
||||
"of documents based on the query terms appearing in each document",
|
||||
]
|
||||
|
||||
similarity_reference = [[0.6265, 0.3477], [0.3499, 0.678]]
|
||||
lexical_score_reference = [0.19554901123046875, 0.0]
|
||||
colbert_score_reference = [0.7797, 0.4620]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--max-model-len",
|
||||
str(MAX_MODEL_LEN),
|
||||
"--hf-overrides",
|
||||
'{"architectures": ["BgeM3EmbeddingModel"]}',
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bge_m3_api_server_embedding(client: openai.AsyncOpenAI):
|
||||
embeddings_list_1 = await run_client_embeddings(
|
||||
client,
|
||||
MODEL_NAME,
|
||||
sentences_1,
|
||||
)
|
||||
embeddings_list_2 = await run_client_embeddings(
|
||||
client,
|
||||
MODEL_NAME,
|
||||
sentences_2,
|
||||
)
|
||||
|
||||
embeddings_1 = torch.tensor(embeddings_list_1)
|
||||
embeddings_2 = torch.tensor(embeddings_list_2)
|
||||
similarity = embeddings_1 @ embeddings_2.T
|
||||
|
||||
# reference values from BAAI/bge-m3 documentation
|
||||
reference = torch.tensor(similarity_reference)
|
||||
|
||||
assert torch.allclose(similarity, reference, rtol=0.01)
|
||||
|
||||
|
||||
async def tokenize(client: openai.AsyncOpenAI, sentences: list[str]) -> list[list[int]]:
|
||||
futures = []
|
||||
for sentence in sentences:
|
||||
futures.append(
|
||||
client.post(
|
||||
"../tokenize",
|
||||
body={"model": MODEL_NAME, "prompt": sentence},
|
||||
cast_to=httpx.Response,
|
||||
)
|
||||
)
|
||||
return [(await future).json()["tokens"] for future in futures]
|
||||
|
||||
|
||||
async def sparse_embeddings(
|
||||
client: openai.AsyncOpenAI, sentences: list[str]
|
||||
) -> list[dict[int, float]]:
|
||||
all_tokens = await tokenize(client, sentences)
|
||||
result = await client.post(
|
||||
"../pooling",
|
||||
body={"model": MODEL_NAME, "input": sentences, "task": "token_classify"},
|
||||
cast_to=httpx.Response,
|
||||
)
|
||||
all_embeddings = [data["data"] for data in result.json()["data"]]
|
||||
|
||||
ret = []
|
||||
|
||||
for sent_tokens, sent_emb in zip(all_tokens, all_embeddings):
|
||||
token_embs = dict[int, float]()
|
||||
if sent_tokens[0] == 0:
|
||||
sent_tokens = sent_tokens[1:]
|
||||
for token, val in zip(sent_tokens, sent_emb):
|
||||
token_embs[token] = max(val, token_embs.get(token, 0.0))
|
||||
ret.append(token_embs)
|
||||
return ret
|
||||
|
||||
|
||||
# Based on https://github.com/FlagOpen/FlagEmbedding/blob/6fd176266f2382878bcc69cd656cff425d52f49b/FlagEmbedding/inference/embedder/encoder_only/m3.py#L129
|
||||
def compute_lexical_matching_score(
|
||||
lw1: dict[int, float], lw2: dict[int, float]
|
||||
) -> float:
|
||||
scores = 0.0
|
||||
for token, weight in lw1.items():
|
||||
if token in lw2:
|
||||
scores += weight * lw2[token]
|
||||
return scores
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bge_m3_api_server_sparse_embedding(client: openai.AsyncOpenAI):
|
||||
embeddings_1 = await sparse_embeddings(client, sentences_1)
|
||||
embeddings_2 = await sparse_embeddings(client, sentences_2)
|
||||
|
||||
lexical_scores_1_0_x_2_0 = compute_lexical_matching_score(
|
||||
embeddings_1[0], embeddings_2[0]
|
||||
)
|
||||
assert lexical_scores_1_0_x_2_0 == pytest.approx(
|
||||
lexical_score_reference[0], rel=0.01
|
||||
)
|
||||
|
||||
lexical_scores_1_0_x_1_1 = compute_lexical_matching_score(
|
||||
embeddings_1[0], embeddings_1[1]
|
||||
)
|
||||
assert lexical_scores_1_0_x_1_1 == pytest.approx(
|
||||
lexical_score_reference[1], rel=0.01
|
||||
)
|
||||
|
||||
|
||||
# https://github.com/FlagOpen/FlagEmbedding/blob/6fd176266f2382878bcc69cd656cff425d52f49b/FlagEmbedding/inference/embedder/encoder_only/m3.py#L163
|
||||
def colbert_score(q_reps: torch.Tensor, p_reps: torch.Tensor) -> torch.Tensor:
|
||||
token_scores = torch.einsum("in,jn->ij", q_reps, p_reps)
|
||||
scores, _ = token_scores.max(-1)
|
||||
scores = torch.sum(scores) / q_reps.size(0)
|
||||
return scores
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bge_m3_api_server_multi_vector(client: openai.AsyncOpenAI):
|
||||
result_1 = await client.post(
|
||||
"../pooling",
|
||||
body={"model": MODEL_NAME, "input": sentences_1, "task": "token_embed"},
|
||||
cast_to=httpx.Response,
|
||||
)
|
||||
embeddings_1 = [torch.tensor(data["data"]) for data in result_1.json()["data"]]
|
||||
|
||||
result_2 = await client.post(
|
||||
"../pooling",
|
||||
body={"model": MODEL_NAME, "input": sentences_2, "task": "token_embed"},
|
||||
cast_to=httpx.Response,
|
||||
)
|
||||
embeddings_2 = [torch.tensor(data["data"]) for data in result_2.json()["data"]]
|
||||
|
||||
colbert_score_1_0_x_2_0 = colbert_score(embeddings_1[0], embeddings_2[0])
|
||||
assert colbert_score_1_0_x_2_0 == pytest.approx(
|
||||
colbert_score_reference[0], rel=0.01
|
||||
)
|
||||
colbert_score_1_0_x_2_1 = colbert_score(embeddings_1[0], embeddings_2[1])
|
||||
assert colbert_score_1_0_x_2_1 == pytest.approx(
|
||||
colbert_score_reference[1], rel=0.01
|
||||
)
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import openai
|
||||
import pytest
|
||||
from scipy.spatial.distance import cosine
|
||||
|
||||
@@ -9,6 +8,7 @@ from vllm import LLM, SamplingParams
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
from .embed_utils import run_client_embeddings
|
||||
|
||||
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
|
||||
MAX_MODEL_LEN = 4000
|
||||
@@ -55,18 +55,6 @@ def run_llm_encode(
|
||||
return [output.outputs.embedding for output in outputs]
|
||||
|
||||
|
||||
async def run_client_embeddings(
|
||||
client: openai.AsyncOpenAI,
|
||||
queries: list[str],
|
||||
instruction: str,
|
||||
) -> list[list[float]]:
|
||||
outputs = await client.embeddings.create(
|
||||
model=MODEL_NAME,
|
||||
input=[instruction + q for q in queries],
|
||||
)
|
||||
return [data.embedding for data in outputs.data]
|
||||
|
||||
|
||||
def gritlm_instruction(instruction):
|
||||
return (
|
||||
"<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n"
|
||||
@@ -145,11 +133,13 @@ async def test_gritlm_api_server_embedding():
|
||||
|
||||
d_rep = await run_client_embeddings(
|
||||
client_embedding,
|
||||
MODEL_NAME,
|
||||
documents,
|
||||
d_instruction,
|
||||
)
|
||||
q_rep = await run_client_embeddings(
|
||||
client_embedding,
|
||||
MODEL_NAME,
|
||||
queries,
|
||||
q_instruction,
|
||||
)
|
||||
|
||||
@@ -513,6 +513,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
_EMBEDDING_EXAMPLE_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
|
||||
"BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"),
|
||||
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
|
||||
"Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),
|
||||
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
|
||||
|
||||
@@ -125,4 +125,49 @@ class IdentityPooler(Pooler):
|
||||
return hidden_states
|
||||
|
||||
|
||||
__all__ = ["DispatchPooler", "IdentityPooler"]
|
||||
class BOSEOSFilter(Pooler):
|
||||
"""Filters the BOS and EOS token results from outputs."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pooler: Pooler,
|
||||
bos_token_id: int = -1, # -1 disables the filtering
|
||||
eos_token_id: int = -1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pooler = pooler
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return self.pooler.get_supported_tasks()
|
||||
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return PoolingParamsUpdate(requires_token_ids=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor | list[torch.Tensor],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
pooled_outputs = self.pooler(hidden_states, pooling_metadata)
|
||||
assert isinstance(pooled_outputs, list)
|
||||
|
||||
for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
|
||||
pooled_data = pooled_outputs[i]
|
||||
assert (
|
||||
isinstance(pooled_data, torch.Tensor)
|
||||
and pooled_data.shape[0] == prompt_len
|
||||
)
|
||||
token_ids = pooling_metadata.prompt_token_ids[i, :prompt_len]
|
||||
if token_ids[0] == self.bos_token_id:
|
||||
pooled_data = pooled_data[1:]
|
||||
if token_ids[-1] == self.eos_token_id:
|
||||
pooled_data = pooled_data[:-1]
|
||||
pooled_outputs[i] = pooled_data.squeeze()
|
||||
|
||||
return pooled_outputs
|
||||
|
||||
|
||||
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"]
|
||||
|
||||
@@ -6,7 +6,11 @@ from typing import TypeAlias
|
||||
import torch
|
||||
|
||||
from vllm.config import PoolerConfig, get_current_vllm_config
|
||||
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
ClassifierFn,
|
||||
PoolingParamsUpdate,
|
||||
ProjectorFn,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.abstract import Pooler
|
||||
from vllm.model_executor.layers.pooler.activations import (
|
||||
PoolerActivation,
|
||||
@@ -89,14 +93,18 @@ class TokenPooler(Pooler):
|
||||
return pooled_data
|
||||
|
||||
|
||||
def pooler_for_token_embed(pooler_config: PoolerConfig):
|
||||
def pooler_for_token_embed(
|
||||
pooler_config: PoolerConfig, projector: ProjectorFn | None = None
|
||||
) -> TokenPooler:
|
||||
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
head = TokenEmbeddingPoolerHead(
|
||||
head_dtype=model_config.head_dtype,
|
||||
projector=_load_st_projector(model_config),
|
||||
projector=projector
|
||||
if projector is not None
|
||||
else _load_st_projector(model_config),
|
||||
activation=PoolerNormalize(),
|
||||
)
|
||||
|
||||
|
||||
@@ -234,6 +234,7 @@ _EMBEDDING_MODELS = {
|
||||
"TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
||||
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
||||
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
|
||||
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
|
||||
# [Multimodal]
|
||||
"CLIPModel": ("clip", "CLIPEmbeddingModel"),
|
||||
"LlavaNextForConditionalGeneration": (
|
||||
|
||||
@@ -1,15 +1,30 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import RobertaConfig
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.config import ModelConfig, PoolerConfig, VllmConfig
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
BOSEOSFilter,
|
||||
DispatchPooler,
|
||||
Pooler,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.seqwise import (
|
||||
pooler_for_embed,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.tokwise import (
|
||||
AllPool,
|
||||
pooler_for_token_classify,
|
||||
pooler_for_token_embed,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.bert import (
|
||||
TOKEN_TYPE_SHIFT,
|
||||
BertEmbeddingModel,
|
||||
@@ -149,6 +164,98 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
|
||||
return loader.load_weights(weights_list, mapper=mapper)
|
||||
|
||||
|
||||
def filter_secondary_weights(
|
||||
all_weights: Iterable[tuple[str, torch.Tensor]],
|
||||
secondary_weights: list[str],
|
||||
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
|
||||
all_weights1, all_weights2 = itertools.tee(all_weights)
|
||||
|
||||
def filtered(n):
|
||||
return any(n.startswith(f) for f in secondary_weights)
|
||||
|
||||
return ((n, w) for n, w in all_weights1 if filtered(n)), (
|
||||
(n, w) for n, w in all_weights2 if not filtered(n)
|
||||
)
|
||||
|
||||
|
||||
class BgeM3EmbeddingModel(RobertaEmbeddingModel):
|
||||
"""A model that extends RobertaEmbeddingModel with sparse embeddings.
|
||||
|
||||
This class supports loading an additional sparse_linear.pt file
|
||||
to create sparse embeddings as described in https://arxiv.org/abs/2402.03216
|
||||
"""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
self.hidden_size = vllm_config.model_config.hf_config.hidden_size
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
self.head_dtype = model_config.head_dtype
|
||||
self.bos_token_id = model_config.hf_config.bos_token_id
|
||||
self.eos_token_id = model_config.hf_config.eos_token_id
|
||||
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
self.secondary_weight_prefixes = ["sparse_linear.", "colbert_linear."]
|
||||
self.secondary_weight_files = [
|
||||
prefix + "pt" for prefix in self.secondary_weight_prefixes
|
||||
]
|
||||
|
||||
self.secondary_weights = [
|
||||
DefaultModelLoader.Source(
|
||||
model_or_path=vllm_config.model_config.model,
|
||||
revision=None,
|
||||
prefix=prefix,
|
||||
allow_patterns_overrides=[filename],
|
||||
)
|
||||
for filename, prefix in zip(
|
||||
self.secondary_weight_files, self.secondary_weight_prefixes
|
||||
)
|
||||
]
|
||||
|
||||
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||
self.sparse_linear = nn.Linear(self.hidden_size, 1, dtype=self.head_dtype)
|
||||
self.colbert_linear = nn.Linear(
|
||||
self.hidden_size, self.hidden_size, dtype=self.head_dtype
|
||||
)
|
||||
|
||||
return DispatchPooler(
|
||||
{
|
||||
"embed": pooler_for_embed(pooler_config),
|
||||
"token_embed": BOSEOSFilter(
|
||||
pooler_for_token_embed(pooler_config, self.colbert_linear),
|
||||
self.bos_token_id,
|
||||
# for some reason m3 only filters the bos for colbert vectors
|
||||
),
|
||||
"token_classify": BOSEOSFilter(
|
||||
pooler_for_token_classify(
|
||||
pooler_config,
|
||||
pooling=AllPool(),
|
||||
classifier=self.sparse_linear,
|
||||
act_fn=torch.relu,
|
||||
),
|
||||
self.bos_token_id,
|
||||
self.eos_token_id,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
secondary, weights = filter_secondary_weights(
|
||||
all_weights, self.secondary_weight_prefixes
|
||||
)
|
||||
|
||||
super().load_weights(weights)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
for name, loaded_weight in secondary:
|
||||
if any(
|
||||
name.startswith(prefix) for prefix in self.secondary_weight_prefixes
|
||||
):
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
"""A model that uses Roberta to provide embedding functionalities.
|
||||
|
||||
Reference in New Issue
Block a user