diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index 9c5069ba7..60db341b8 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -305,6 +305,44 @@ Expected output: An OpenAI client example can be found here: [examples/pooling/embed/openai_embedding_matryoshka_fy_client.py](../../examples/pooling/embed/openai_embedding_matryoshka_fy_client.py) +## Specific models + +### BAAI/bge-m3 + +The `BAAI/bge-m3` model comes with extra weights for sparse and colbert embeddings but unfortunately in its `config.json` +the architecture is declared as `XLMRobertaModel`, which makes `vLLM` load it as a vanilla ROBERTA model without the +extra weights. To load the full model weights, override its architecture like this: + +```shell +vllm serve BAAI/bge-m3 --hf-overrides '{"architectures": ["BgeM3EmbeddingModel"]}' +``` + +Then you obtain the sparse embeddings like this: + +```shell +curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{ + "model": "BAAI/bge-m3", + "task": "token_classify", + "input": ["What is BGE M3?", "Defination of BM25"] +}' +``` + +Due to limitations in the the output schema, the output consists of a list of +token scores for each token for each input. This means that you'll have to call +`/tokenize` as well to be able to pair tokens with scores. +Refer to the tests in `tests/models/language/pooling/test_bge_m3.py` to see how +to do that. + +You can obtain the colbert embeddings like this: + +```shell +curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{ + "model": "BAAI/bge-m3", + "task": "token_embed", + "input": ["What is BGE M3?", "Defination of BM25"] +}' +``` + ## Deprecated Features ### Encode task diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py index 4ac40656b..3b818aef9 100644 --- a/tests/models/language/pooling/embed_utils.py +++ b/tests/models/language/pooling/embed_utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence +import openai import pytest from tests.conftest import HfRunner @@ -65,3 +66,16 @@ def correctness_test_embed_models( hf_model_callback(hf_model) run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs) + + +async def run_client_embeddings( + client: openai.AsyncOpenAI, + model_name: str, + queries: list[str], + instruction: str = "", +) -> list[list[float]]: + outputs = await client.embeddings.create( + model=model_name, + input=[instruction + q for q in queries], + ) + return [data.embedding for data in outputs.data] diff --git a/tests/models/language/pooling/test_bge_m3.py b/tests/models/language/pooling/test_bge_m3.py new file mode 100644 index 000000000..5ad1fee03 --- /dev/null +++ b/tests/models/language/pooling/test_bge_m3.py @@ -0,0 +1,170 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import httpx +import openai +import pytest +import pytest_asyncio +import torch + +from ....utils import RemoteOpenAIServer +from .embed_utils import run_client_embeddings + +MODEL_NAME = "BAAI/bge-m3" +MAX_MODEL_LEN = 512 + + +# Example from https://huggingface.co/BAAI/bge-m3 +sentences_1 = ["What is BGE M3?", "Defination of BM25"] +sentences_2 = [ + "BGE M3 is an embedding model supporting dense retrieval, " + "lexical matching and multi-vector interaction.", + "BM25 is a bag-of-words retrieval function that ranks a set " + "of documents based on the query terms appearing in each document", +] + +similarity_reference = [[0.6265, 0.3477], [0.3499, 0.678]] +lexical_score_reference = [0.19554901123046875, 0.0] +colbert_score_reference = [0.7797, 0.4620] + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--max-model-len", + str(MAX_MODEL_LEN), + "--hf-overrides", + '{"architectures": ["BgeM3EmbeddingModel"]}', + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_bge_m3_api_server_embedding(client: openai.AsyncOpenAI): + embeddings_list_1 = await run_client_embeddings( + client, + MODEL_NAME, + sentences_1, + ) + embeddings_list_2 = await run_client_embeddings( + client, + MODEL_NAME, + sentences_2, + ) + + embeddings_1 = torch.tensor(embeddings_list_1) + embeddings_2 = torch.tensor(embeddings_list_2) + similarity = embeddings_1 @ embeddings_2.T + + # reference values from BAAI/bge-m3 documentation + reference = torch.tensor(similarity_reference) + + assert torch.allclose(similarity, reference, rtol=0.01) + + +async def tokenize(client: openai.AsyncOpenAI, sentences: list[str]) -> list[list[int]]: + futures = [] + for sentence in sentences: + futures.append( + client.post( + "../tokenize", + body={"model": MODEL_NAME, "prompt": sentence}, + cast_to=httpx.Response, + ) + ) + return [(await future).json()["tokens"] for future in futures] + + +async def sparse_embeddings( + client: openai.AsyncOpenAI, sentences: list[str] +) -> list[dict[int, float]]: + all_tokens = await tokenize(client, sentences) + result = await client.post( + "../pooling", + body={"model": MODEL_NAME, "input": sentences, "task": "token_classify"}, + cast_to=httpx.Response, + ) + all_embeddings = [data["data"] for data in result.json()["data"]] + + ret = [] + + for sent_tokens, sent_emb in zip(all_tokens, all_embeddings): + token_embs = dict[int, float]() + if sent_tokens[0] == 0: + sent_tokens = sent_tokens[1:] + for token, val in zip(sent_tokens, sent_emb): + token_embs[token] = max(val, token_embs.get(token, 0.0)) + ret.append(token_embs) + return ret + + +# Based on https://github.com/FlagOpen/FlagEmbedding/blob/6fd176266f2382878bcc69cd656cff425d52f49b/FlagEmbedding/inference/embedder/encoder_only/m3.py#L129 +def compute_lexical_matching_score( + lw1: dict[int, float], lw2: dict[int, float] +) -> float: + scores = 0.0 + for token, weight in lw1.items(): + if token in lw2: + scores += weight * lw2[token] + return scores + + +@pytest.mark.asyncio +async def test_bge_m3_api_server_sparse_embedding(client: openai.AsyncOpenAI): + embeddings_1 = await sparse_embeddings(client, sentences_1) + embeddings_2 = await sparse_embeddings(client, sentences_2) + + lexical_scores_1_0_x_2_0 = compute_lexical_matching_score( + embeddings_1[0], embeddings_2[0] + ) + assert lexical_scores_1_0_x_2_0 == pytest.approx( + lexical_score_reference[0], rel=0.01 + ) + + lexical_scores_1_0_x_1_1 = compute_lexical_matching_score( + embeddings_1[0], embeddings_1[1] + ) + assert lexical_scores_1_0_x_1_1 == pytest.approx( + lexical_score_reference[1], rel=0.01 + ) + + +# https://github.com/FlagOpen/FlagEmbedding/blob/6fd176266f2382878bcc69cd656cff425d52f49b/FlagEmbedding/inference/embedder/encoder_only/m3.py#L163 +def colbert_score(q_reps: torch.Tensor, p_reps: torch.Tensor) -> torch.Tensor: + token_scores = torch.einsum("in,jn->ij", q_reps, p_reps) + scores, _ = token_scores.max(-1) + scores = torch.sum(scores) / q_reps.size(0) + return scores + + +@pytest.mark.asyncio +async def test_bge_m3_api_server_multi_vector(client: openai.AsyncOpenAI): + result_1 = await client.post( + "../pooling", + body={"model": MODEL_NAME, "input": sentences_1, "task": "token_embed"}, + cast_to=httpx.Response, + ) + embeddings_1 = [torch.tensor(data["data"]) for data in result_1.json()["data"]] + + result_2 = await client.post( + "../pooling", + body={"model": MODEL_NAME, "input": sentences_2, "task": "token_embed"}, + cast_to=httpx.Response, + ) + embeddings_2 = [torch.tensor(data["data"]) for data in result_2.json()["data"]] + + colbert_score_1_0_x_2_0 = colbert_score(embeddings_1[0], embeddings_2[0]) + assert colbert_score_1_0_x_2_0 == pytest.approx( + colbert_score_reference[0], rel=0.01 + ) + colbert_score_1_0_x_2_1 = colbert_score(embeddings_1[0], embeddings_2[1]) + assert colbert_score_1_0_x_2_1 == pytest.approx( + colbert_score_reference[1], rel=0.01 + ) diff --git a/tests/models/language/pooling/test_gritlm.py b/tests/models/language/pooling/test_gritlm.py index 0adc9b5cf..5ff5073e8 100644 --- a/tests/models/language/pooling/test_gritlm.py +++ b/tests/models/language/pooling/test_gritlm.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import numpy as np -import openai import pytest from scipy.spatial.distance import cosine @@ -9,6 +8,7 @@ from vllm import LLM, SamplingParams from vllm.config import ModelConfig from ....utils import RemoteOpenAIServer +from .embed_utils import run_client_embeddings MODEL_NAME = "parasail-ai/GritLM-7B-vllm" MAX_MODEL_LEN = 4000 @@ -55,18 +55,6 @@ def run_llm_encode( return [output.outputs.embedding for output in outputs] -async def run_client_embeddings( - client: openai.AsyncOpenAI, - queries: list[str], - instruction: str, -) -> list[list[float]]: - outputs = await client.embeddings.create( - model=MODEL_NAME, - input=[instruction + q for q in queries], - ) - return [data.embedding for data in outputs.data] - - def gritlm_instruction(instruction): return ( "<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n" @@ -145,11 +133,13 @@ async def test_gritlm_api_server_embedding(): d_rep = await run_client_embeddings( client_embedding, + MODEL_NAME, documents, d_instruction, ) q_rep = await run_client_embeddings( client_embedding, + MODEL_NAME, queries, q_instruction, ) diff --git a/tests/models/registry.py b/tests/models/registry.py index 2af3c1506..5c6db71b1 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -513,6 +513,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { _EMBEDDING_EXAMPLE_MODELS = { # [Text-only] "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), + "BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), "Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"), "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), diff --git a/vllm/model_executor/layers/pooler/special.py b/vllm/model_executor/layers/pooler/special.py index 425f61a98..707e7c907 100644 --- a/vllm/model_executor/layers/pooler/special.py +++ b/vllm/model_executor/layers/pooler/special.py @@ -125,4 +125,49 @@ class IdentityPooler(Pooler): return hidden_states -__all__ = ["DispatchPooler", "IdentityPooler"] +class BOSEOSFilter(Pooler): + """Filters the BOS and EOS token results from outputs.""" + + def __init__( + self, + pooler: Pooler, + bos_token_id: int = -1, # -1 disables the filtering + eos_token_id: int = -1, + ) -> None: + super().__init__() + + self.pooler = pooler + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + def get_supported_tasks(self) -> Set[PoolingTask]: + return self.pooler.get_supported_tasks() + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate(requires_token_ids=True) + + def forward( + self, + hidden_states: torch.Tensor | list[torch.Tensor], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_outputs = self.pooler(hidden_states, pooling_metadata) + assert isinstance(pooled_outputs, list) + + for i, prompt_len in enumerate(pooling_metadata.prompt_lens): + pooled_data = pooled_outputs[i] + assert ( + isinstance(pooled_data, torch.Tensor) + and pooled_data.shape[0] == prompt_len + ) + token_ids = pooling_metadata.prompt_token_ids[i, :prompt_len] + if token_ids[0] == self.bos_token_id: + pooled_data = pooled_data[1:] + if token_ids[-1] == self.eos_token_id: + pooled_data = pooled_data[:-1] + pooled_outputs[i] = pooled_data.squeeze() + + return pooled_outputs + + +__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"] diff --git a/vllm/model_executor/layers/pooler/tokwise/poolers.py b/vllm/model_executor/layers/pooler/tokwise/poolers.py index 20790eff6..996f20d98 100644 --- a/vllm/model_executor/layers/pooler/tokwise/poolers.py +++ b/vllm/model_executor/layers/pooler/tokwise/poolers.py @@ -6,7 +6,11 @@ from typing import TypeAlias import torch from vllm.config import PoolerConfig, get_current_vllm_config -from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate +from vllm.model_executor.layers.pooler import ( + ClassifierFn, + PoolingParamsUpdate, + ProjectorFn, +) from vllm.model_executor.layers.pooler.abstract import Pooler from vllm.model_executor.layers.pooler.activations import ( PoolerActivation, @@ -89,14 +93,18 @@ class TokenPooler(Pooler): return pooled_data -def pooler_for_token_embed(pooler_config: PoolerConfig): +def pooler_for_token_embed( + pooler_config: PoolerConfig, projector: ProjectorFn | None = None +) -> TokenPooler: pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type()) vllm_config = get_current_vllm_config() model_config = vllm_config.model_config head = TokenEmbeddingPoolerHead( head_dtype=model_config.head_dtype, - projector=_load_st_projector(model_config), + projector=projector + if projector is not None + else _load_st_projector(model_config), activation=PoolerNormalize(), ) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 5b1881dba..25b6e4025 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -234,6 +234,7 @@ _EMBEDDING_MODELS = { "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), + "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"), # [Multimodal] "CLIPModel": ("clip", "CLIPEmbeddingModel"), "LlavaNextForConditionalGeneration": ( diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 7bf9a6882..5faa64654 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -1,15 +1,30 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools from collections.abc import Iterable import torch from torch import nn from transformers import RobertaConfig -from vllm.config import ModelConfig, VllmConfig -from vllm.model_executor.layers.pooler import DispatchPooler +from vllm.config import ModelConfig, PoolerConfig, VllmConfig +from vllm.model_executor.layers.pooler import ( + BOSEOSFilter, + DispatchPooler, + Pooler, +) +from vllm.model_executor.layers.pooler.seqwise import ( + pooler_for_embed, +) +from vllm.model_executor.layers.pooler.tokwise import ( + AllPool, + pooler_for_token_classify, + pooler_for_token_embed, +) from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.model_loader.default_loader import DefaultModelLoader +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.bert import ( TOKEN_TYPE_SHIFT, BertEmbeddingModel, @@ -149,6 +164,98 @@ class RobertaEmbeddingModel(BertEmbeddingModel): return loader.load_weights(weights_list, mapper=mapper) +def filter_secondary_weights( + all_weights: Iterable[tuple[str, torch.Tensor]], + secondary_weights: list[str], +) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]: + all_weights1, all_weights2 = itertools.tee(all_weights) + + def filtered(n): + return any(n.startswith(f) for f in secondary_weights) + + return ((n, w) for n, w in all_weights1 if filtered(n)), ( + (n, w) for n, w in all_weights2 if not filtered(n) + ) + + +class BgeM3EmbeddingModel(RobertaEmbeddingModel): + """A model that extends RobertaEmbeddingModel with sparse embeddings. + + This class supports loading an additional sparse_linear.pt file + to create sparse embeddings as described in https://arxiv.org/abs/2402.03216 + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + self.hidden_size = vllm_config.model_config.hf_config.hidden_size + + model_config = vllm_config.model_config + self.head_dtype = model_config.head_dtype + self.bos_token_id = model_config.hf_config.bos_token_id + self.eos_token_id = model_config.hf_config.eos_token_id + + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.secondary_weight_prefixes = ["sparse_linear.", "colbert_linear."] + self.secondary_weight_files = [ + prefix + "pt" for prefix in self.secondary_weight_prefixes + ] + + self.secondary_weights = [ + DefaultModelLoader.Source( + model_or_path=vllm_config.model_config.model, + revision=None, + prefix=prefix, + allow_patterns_overrides=[filename], + ) + for filename, prefix in zip( + self.secondary_weight_files, self.secondary_weight_prefixes + ) + ] + + def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: + self.sparse_linear = nn.Linear(self.hidden_size, 1, dtype=self.head_dtype) + self.colbert_linear = nn.Linear( + self.hidden_size, self.hidden_size, dtype=self.head_dtype + ) + + return DispatchPooler( + { + "embed": pooler_for_embed(pooler_config), + "token_embed": BOSEOSFilter( + pooler_for_token_embed(pooler_config, self.colbert_linear), + self.bos_token_id, + # for some reason m3 only filters the bos for colbert vectors + ), + "token_classify": BOSEOSFilter( + pooler_for_token_classify( + pooler_config, + pooling=AllPool(), + classifier=self.sparse_linear, + act_fn=torch.relu, + ), + self.bos_token_id, + self.eos_token_id, + ), + } + ) + + def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]): + secondary, weights = filter_secondary_weights( + all_weights, self.secondary_weight_prefixes + ) + + super().load_weights(weights) + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in secondary: + if any( + name.startswith(prefix) for prefix in self.secondary_weight_prefixes + ): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + @default_pooling_type(seq_pooling_type="CLS") class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): """A model that uses Roberta to provide embedding functionalities.