diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index c1355fe49..0555eac41 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -307,6 +307,62 @@ An OpenAI client example can be found here: [examples/pooling/embed/openai_embed ## Specific models +### ColBERT Late Interaction Models + +[ColBERT](https://arxiv.org/abs/2004.12832) (Contextualized Late Interaction over BERT) is a retrieval model that uses per-token embeddings and MaxSim scoring for document ranking. Unlike single-vector embedding models, ColBERT retains token-level representations and computes relevance scores through late interaction, providing better accuracy while being more efficient than cross-encoders. + +vLLM supports ColBERT models for reranking tasks, automatically applying MaxSim scoring for query-document relevance: + +```shell +vllm serve answerdotai/answerai-colbert-small-v1 +``` + +Currently supports ColBERT models with standard BERT encoders (e.g., `answerdotai/answerai-colbert-small-v1`, `colbert-ir/colbertv2.0`). + +ColBERT models with modified encoder architectures are not yet supported, including BERT variants with rotary embeddings (e.g., `jinaai/jina-colbert-v2`) or other custom encoders (e.g., `LiquidAI/LFM2-ColBERT-350M`). + +If your standard BERT ColBERT model's config doesn't specify the architecture as `HF_ColBERT`, override it with: + +```shell +vllm serve your-colbert-model --hf-overrides '{"architectures": ["HF_ColBERT"]}' +``` + +Then you can use the rerank endpoint: + +```shell +curl -s http://localhost:8000/rerank -H "Content-Type: application/json" -d '{ + "model": "answerdotai/answerai-colbert-small-v1", + "query": "What is machine learning?", + "documents": [ + "Machine learning is a subset of artificial intelligence.", + "Python is a programming language.", + "Deep learning uses neural networks." + ] +}' +``` + +Or the score endpoint: + +```shell +curl -s http://localhost:8000/score -H "Content-Type: application/json" -d '{ + "model": "answerdotai/answerai-colbert-small-v1", + "text_1": "What is machine learning?", + "text_2": ["Machine learning is a subset of AI.", "The weather is sunny."] +}' +``` + +You can also get the raw token embeddings using the pooling endpoint with `token_embed` task: + +```shell +curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{ + "model": "answerdotai/answerai-colbert-small-v1", + "input": "What is machine learning?", + "task": "token_embed" +}' +``` + +An example can be found here: [examples/pooling/score/colbert_rerank_online.py](../../examples/pooling/score/colbert_rerank_online.py) + ### BAAI/bge-m3 The `BAAI/bge-m3` model comes with extra weights for sparse and colbert embeddings but unfortunately in its `config.json` diff --git a/examples/pooling/score/colbert_rerank_online.py b/examples/pooling/score/colbert_rerank_online.py new file mode 100644 index 000000000..b9223e791 --- /dev/null +++ b/examples/pooling/score/colbert_rerank_online.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Example of using ColBERT late interaction model for reranking. + +ColBERT (Contextualized Late Interaction over BERT) uses per-token embeddings +and MaxSim scoring for document reranking, providing better accuracy than +single-vector models while being more efficient than cross-encoders. + +Start the server with: + vllm serve answerdotai/answerai-colbert-small-v1 + +Then run this script: + python colbert_rerank_online.py +""" + +import json + +import requests + +url = "http://127.0.0.1:8000/rerank" + +headers = {"accept": "application/json", "Content-Type": "application/json"} + +data = { + "model": "answerdotai/answerai-colbert-small-v1", + "query": "What is machine learning?", + "documents": [ + "Machine learning is a subset of artificial intelligence.", + "Python is a programming language.", + "Deep learning uses neural networks for complex tasks.", + "The weather today is sunny.", + ], +} + + +def main(): + response = requests.post(url, headers=headers, json=data) + + if response.status_code == 200: + print("ColBERT Rerank Request successful!") + result = response.json() + print(json.dumps(result, indent=2)) + + # Show ranked results + print("\nRanked documents (most relevant first):") + for item in result["results"]: + doc_idx = item["index"] + score = item["relevance_score"] + print(f" Score {score:.4f}: {data['documents'][doc_idx]}") + else: + print(f"Request failed with status code: {response.status_code}") + print(response.text) + + +if __name__ == "__main__": + main() diff --git a/tests/entrypoints/pooling/score/test_online_colbert.py b/tests/entrypoints/pooling/score/test_online_colbert.py new file mode 100644 index 000000000..a7b404d0f --- /dev/null +++ b/tests/entrypoints/pooling/score/test_online_colbert.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Online API tests for ColBERT late interaction scoring.""" + +import pytest +import requests + +from tests.utils import RemoteOpenAIServer +from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse + +# ColBERT model - using answerai-colbert-small-v1 as it's a smaller model +MODEL_NAME = "answerdotai/answerai-colbert-small-v1" +COLBERT_DIM = 96 # This model uses 96-dimensional output +DTYPE = "half" +MAX_MODEL_LEN = 512 + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--max-model-len", + str(MAX_MODEL_LEN), + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_colbert_rerank(server: RemoteOpenAIServer, model_name: str): + """Test ColBERT rerank endpoint.""" + query = "What is the capital of France?" + documents = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", + ] + + rerank_response = requests.post( + server.url_for("rerank"), + json={ + "model": model_name, + "query": query, + "documents": documents, + }, + ) + rerank_response.raise_for_status() + rerank = RerankResponse.model_validate(rerank_response.json()) + + assert rerank.id is not None + assert rerank.results is not None + assert len(rerank.results) == 2 + + # The relevant document (Paris) should have higher score + paris_result = next(r for r in rerank.results if r.index == 1) + brazil_result = next(r for r in rerank.results if r.index == 0) + + assert paris_result.relevance_score > brazil_result.relevance_score + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_colbert_rerank_top_n(server: RemoteOpenAIServer, model_name: str): + """Test ColBERT rerank with top_n parameter.""" + query = "What is the capital of France?" + documents = [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris.", + "Machine learning is a field of AI.", + ] + + rerank_response = requests.post( + server.url_for("rerank"), + json={ + "model": model_name, + "query": query, + "documents": documents, + "top_n": 2, + }, + ) + rerank_response.raise_for_status() + rerank = RerankResponse.model_validate(rerank_response.json()) + + assert len(rerank.results) == 2 + # Top result should be about Paris (index 1) + assert rerank.results[0].index == 1 + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_colbert_score(server: RemoteOpenAIServer, model_name: str): + """Test ColBERT score endpoint.""" + text_1 = "What is the capital of France?" + text_2 = ["The capital of France is Paris.", "Python is a language."] + + score_response = requests.post( + server.url_for("score"), + json={ + "model": model_name, + "text_1": text_1, + "text_2": text_2, + }, + ) + score_response.raise_for_status() + score = ScoreResponse.model_validate(score_response.json()) + + assert score.id is not None + assert score.data is not None + assert len(score.data) == 2 + + # The relevant document should have higher score + assert score.data[0].score > score.data[1].score + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_colbert_token_embed(server: RemoteOpenAIServer, model_name: str): + """Test ColBERT token_embed task via pooling endpoint.""" + text = "What is the capital of France?" + + pooling_response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "input": text, + "task": "token_embed", + }, + ) + pooling_response.raise_for_status() + pooling = pooling_response.json() + + assert "data" in pooling + assert len(pooling["data"]) == 1 + + # Token embeddings should be 2D + embeddings = pooling["data"][0]["data"] + assert isinstance(embeddings, list) + assert len(embeddings) > 0 # Should have tokens + assert len(embeddings[0]) == COLBERT_DIM + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_colbert_embed_not_supported(server: RemoteOpenAIServer, model_name: str): + """Test that ColBERT model does not support 'embed' task.""" + text = "What is the capital of France?" + + pooling_response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "input": text, + "task": "embed", + }, + ) + + # Should return error + assert pooling_response.status_code == 400 + assert "Task embed is not supported" in pooling_response.text diff --git a/tests/models/language/pooling/test_colbert.py b/tests/models/language/pooling/test_colbert.py new file mode 100644 index 000000000..fa77b8c26 --- /dev/null +++ b/tests/models/language/pooling/test_colbert.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for ColBERT late interaction scoring.""" + +import pytest +import torch + +from vllm.entrypoints.pooling.score.utils import compute_maxsim_score + +# ColBERT model - using answerai-colbert-small-v1 as it's a smaller model +# suitable for testing (based on BERT-base) +COLBERT_MODEL = "answerdotai/answerai-colbert-small-v1" +COLBERT_DIM = 96 # This model uses 96-dimensional output + +TEXTS_1 = [ + "What is the capital of France?", + "What is the capital of Germany?", +] + +TEXTS_2 = [ + "The capital of France is Paris.", + "The capital of Germany is Berlin.", +] + +DTYPE = "half" + + +@pytest.fixture(scope="module") +def colbert_model_name(): + return COLBERT_MODEL + + +def test_colbert_token_embed(vllm_runner, colbert_model_name): + """Test that ColBERT model produces token embeddings.""" + with vllm_runner( + colbert_model_name, + runner="pooling", + dtype=DTYPE, + max_model_len=512, + enforce_eager=True, + ) as vllm_model: + # Get token embeddings for a single text + outputs = vllm_model.token_embed([TEXTS_1[0]]) + + assert len(outputs) == 1 + # Token embeddings should be 2D: [num_tokens, colbert_dim] + emb = torch.tensor(outputs[0]) + assert emb.dim() == 2 + assert emb.shape[1] == COLBERT_DIM + # Should have at least a few tokens + assert emb.shape[0] > 1 + + +def test_colbert_late_interaction_1_to_1(vllm_runner, colbert_model_name): + """Test ColBERT late interaction scoring with 1:1 query-document pair.""" + with vllm_runner( + colbert_model_name, + runner="pooling", + dtype=DTYPE, + max_model_len=512, + enforce_eager=True, + ) as vllm_model: + # Get token embeddings + q_outputs = vllm_model.token_embed([TEXTS_1[0]]) + d_outputs = vllm_model.token_embed([TEXTS_2[0]]) + + q_emb = torch.tensor(q_outputs[0]) + d_emb = torch.tensor(d_outputs[0]) + + # Compute MaxSim manually + manual_score = compute_maxsim_score(q_emb, d_emb).item() + + # Use the score API (which should internally use _late_interaction_score) + vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2[0]) + + assert len(vllm_scores) == 1 + assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01) + + +def test_colbert_late_interaction_1_to_N(vllm_runner, colbert_model_name): + """Test ColBERT late interaction scoring with 1:N query-documents.""" + with vllm_runner( + colbert_model_name, + runner="pooling", + dtype=DTYPE, + max_model_len=512, + enforce_eager=True, + ) as vllm_model: + # Get token embeddings + q_outputs = vllm_model.token_embed([TEXTS_1[0]]) + d_outputs = vllm_model.token_embed(TEXTS_2) + + q_emb = torch.tensor(q_outputs[0]) + + # Compute MaxSim manually for each document + manual_scores = [] + for d_out in d_outputs: + d_emb = torch.tensor(d_out) + manual_scores.append(compute_maxsim_score(q_emb, d_emb).item()) + + # Use the score API + vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2) + + assert len(vllm_scores) == 2 + for i in range(2): + assert vllm_scores[i] == pytest.approx(manual_scores[i], rel=0.01) + + +def test_colbert_late_interaction_N_to_N(vllm_runner, colbert_model_name): + """Test ColBERT late interaction scoring with N:N query-documents.""" + with vllm_runner( + colbert_model_name, + runner="pooling", + dtype=DTYPE, + max_model_len=512, + enforce_eager=True, + ) as vllm_model: + # Get token embeddings + q_outputs = vllm_model.token_embed(TEXTS_1) + d_outputs = vllm_model.token_embed(TEXTS_2) + + # Compute MaxSim manually for each pair + manual_scores = [] + for q_out, d_out in zip(q_outputs, d_outputs): + q_emb = torch.tensor(q_out) + d_emb = torch.tensor(d_out) + manual_scores.append(compute_maxsim_score(q_emb, d_emb).item()) + + # Use the score API + vllm_scores = vllm_model.score(TEXTS_1, TEXTS_2) + + assert len(vllm_scores) == 2 + for i in range(2): + assert vllm_scores[i] == pytest.approx(manual_scores[i], rel=0.01) + + +def test_colbert_relevance_ordering(vllm_runner, colbert_model_name): + """Test that ColBERT scores relevant documents higher than irrelevant ones.""" + query = "What is machine learning?" + documents = [ + "Machine learning is a subset of artificial intelligence.", + "Python is a programming language.", + "Deep learning uses neural networks.", + ] + + with vllm_runner( + colbert_model_name, + runner="pooling", + dtype=DTYPE, + max_model_len=512, + enforce_eager=True, + ) as vllm_model: + scores = vllm_model.score(query, documents) + + assert len(scores) == 3 + # ML-related documents should score higher than unrelated Python doc + # Document 0 (ML definition) should be most relevant + # Document 2 (Deep learning) should also be relevant + # Document 1 (Python) should be least relevant + assert scores[0] > scores[1], "ML doc should score higher than Python doc" + assert scores[2] > scores[1], "DL doc should score higher than Python doc" + + +def test_colbert_embed_not_supported(vllm_runner, colbert_model_name): + """Test that ColBERT model does not support 'embed' task.""" + with ( + vllm_runner( + colbert_model_name, + runner="pooling", + dtype=DTYPE, + max_model_len=512, + enforce_eager=True, + ) as vllm_model, + pytest.raises(ValueError, match="Embedding API is not supported"), + ): + vllm_model.embed([TEXTS_1[0]]) + + +def test_colbert_hf_comparison(vllm_runner, colbert_model_name): + """Test that vLLM ColBERT produces same embeddings as HuggingFace.""" + import torch.nn.functional as F + from huggingface_hub import hf_hub_download + from safetensors.torch import load_file + from transformers import AutoTokenizer, BertModel + + test_texts = [TEXTS_1[0], TEXTS_2[0]] + + # Get vLLM embeddings first (to avoid GPU memory contention) + # Use fp32 to match HuggingFace default precision for fair comparison + with vllm_runner( + colbert_model_name, + runner="pooling", + dtype="float32", + max_model_len=512, + enforce_eager=True, + ) as vllm_model: + vllm_outputs = vllm_model.token_embed(test_texts) + + # Get HuggingFace reference embeddings on CPU + # Load the base BERT model and manually apply the ColBERT linear projection + hf_tokenizer = AutoTokenizer.from_pretrained(colbert_model_name) + hf_bert = BertModel.from_pretrained(colbert_model_name) + hf_bert.eval() + + # Load the ColBERT linear weights from safetensors + weights_path = hf_hub_download(colbert_model_name, filename="model.safetensors") + weights = load_file(weights_path) + linear_weight = weights["linear.weight"] # [96, 384] + + hf_embeddings = [] + for text in test_texts: + inputs = hf_tokenizer(text, return_tensors="pt") + with torch.no_grad(): + outputs = hf_bert(**inputs) + # Get last hidden state: [1, seq_len, 384] + hidden_states = outputs.last_hidden_state + # Apply ColBERT linear projection: [1, seq_len, 96] + token_emb = F.linear(hidden_states, linear_weight) + # L2 normalize + token_emb = F.normalize(token_emb, p=2, dim=-1) + hf_embeddings.append(token_emb.squeeze(0).float()) + + # Compare embeddings + for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)): + vllm_emb = torch.tensor(vllm_out).float() + + # Print first few components for debugging + print(f"\n=== Text {i}: '{test_texts[i][:30]}...' ===") + print(f"HF shape: {hf_emb.shape}, vLLM shape: {vllm_emb.shape}") + print(f"HF first token, first 10 dims: {hf_emb[0, :10].tolist()}") + print(f"vLLM first token, first 10 dims: {vllm_emb[0, :10].tolist()}") + print(f"HF last token, first 10 dims: {hf_emb[-1, :10].tolist()}") + print(f"vLLM last token, first 10 dims: {vllm_emb[-1, :10].tolist()}") + + # Should have same shape + assert hf_emb.shape == vllm_emb.shape, ( + f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}" + ) + + # Should have same values (with tolerance for fp16) + torch.testing.assert_close( + vllm_emb, + hf_emb, + rtol=1e-2, + atol=1e-2, + msg=f"Embedding mismatch for text {i}", + ) diff --git a/tests/models/registry.py b/tests/models/registry.py index cbd07cbc1..ffa4f52f1 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -520,6 +520,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { _EMBEDDING_EXAMPLE_MODELS = { # [Text-only] "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), + "HF_ColBERT": _HfExamplesInfo("answerdotai/answerai-colbert-small-v1"), "BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), "Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"), diff --git a/vllm/config/model.py b/vllm/config/model.py index 2686df4c2..86b484181 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1411,6 +1411,11 @@ class ModelConfig: self._model_info.supports_cross_encoding or self.convert_type == "classify" ) + @property + def is_late_interaction(self) -> bool: + """Check if model uses late interaction (ColBERT-style) scoring.""" + return self._model_info.supports_late_interaction + @property def is_pp_supported(self) -> bool: return self._model_info.supports_pp diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 24545de19..435ccbee6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -44,6 +44,7 @@ from vllm.entrypoints.pooling.score.utils import ( ScoreMultiModalParam, _cosine_similarity, compress_token_type_ids, + compute_maxsim_score, get_score_prompt, validate_score_input, ) @@ -1368,6 +1369,87 @@ class LLM: items = self.engine_class.validate_outputs(scores, PoolingRequestOutput) return [ScoringRequestOutput.from_base(item) for item in items] + def _late_interaction_score( + self, + data_1: list[ScoreData], + data_2: list[ScoreData], + *, + use_tqdm: bool | Callable[..., tqdm], + pooling_params: PoolingParams | None, + lora_request: list[LoRARequest] | LoRARequest | None, + tokenization_kwargs: dict[str, Any], + ) -> list[ScoringRequestOutput]: + """ + Late interaction scoring (ColBERT MaxSim). + + Encodes queries and documents into per-token embeddings, then computes + MaxSim: sum over query tokens of max similarity to any document token. + """ + from vllm.outputs import PoolingOutput + + tokenizer = self.get_tokenizer() + + # Extract text from ScoreData + text_1: list[str] = [] + for text in data_1: + if not isinstance(text, str): + raise NotImplementedError( + "Late interaction scores currently do not support multimodal input." + ) + text_1.append(text) + + text_2: list[str] = [] + for text in data_2: + if not isinstance(text, str): + raise NotImplementedError( + "Late interaction scores currently do not support multimodal input." + ) + text_2.append(text) + + encoded_output: list[PoolingRequestOutput] = self.encode( + text_1 + text_2, + use_tqdm=use_tqdm, + lora_request=lora_request, + pooling_params=pooling_params, + pooling_task="token_embed", + tokenization_kwargs=tokenization_kwargs, + ) + + encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)] + encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :] + + if len(encoded_output_1) == 1: + encoded_output_1 = encoded_output_1 * len(encoded_output_2) + + # Compute MaxSim scores + scores: list[PoolingRequestOutput] = [] + padding: list[int] = [] + if (pad_token_id := tokenizer.pad_token_id) is not None: + padding = [pad_token_id] + + for emb_1, emb_2 in zip(encoded_output_1, encoded_output_2): + # emb_1.outputs.data: [query_len, dim] + # emb_2.outputs.data: [doc_len, dim] + q_emb = emb_1.outputs.data + d_emb = emb_2.outputs.data + + maxsim_score = compute_maxsim_score(q_emb, d_emb) + + tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids + + scores.append( + PoolingRequestOutput( + request_id=f"{emb_1.request_id}_{emb_2.request_id}", + outputs=PoolingOutput(data=maxsim_score), + prompt_token_ids=tokens, + num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens, + finished=True, + ) + ) + + items = self.engine_class.validate_outputs(scores, PoolingRequestOutput) + return [ScoringRequestOutput.from_base(item) for item in items] + def _cross_encoding_score( self, data_1: list[ScoreData], @@ -1497,7 +1579,11 @@ class LLM: ) supported_tasks = self.supported_tasks - if all(t not in supported_tasks for t in ("embed", "classify")): + # Late interaction models (e.g., ColBERT) use token_embed for scoring + is_late_interaction = model_config.is_late_interaction + if not is_late_interaction and all( + t not in supported_tasks for t in ("embed", "classify") + ): raise ValueError( "Score API is not supported by this model. " "Try converting the model using " @@ -1538,6 +1624,15 @@ class LLM: tokenization_kwargs=encode_kwargs, score_template=chat_template, ) + elif is_late_interaction: + return self._late_interaction_score( + score_data_1, + score_data_2, + use_tqdm=use_tqdm, + pooling_params=pooling_params, + lora_request=lora_request, + tokenization_kwargs=encode_kwargs, + ) else: return self._embedding_score( score_data_1, diff --git a/vllm/entrypoints/pooling/__init__.py b/vllm/entrypoints/pooling/__init__.py index 737f1efe8..4321e19f9 100644 --- a/vllm/entrypoints/pooling/__init__.py +++ b/vllm/entrypoints/pooling/__init__.py @@ -37,7 +37,11 @@ def register_pooling_api_routers( app.include_router(embed_router) - if "score" in supported_tasks or "embed" in supported_tasks: + # Score/rerank endpoints are available for: + # - "score" task (cross-encoder models) + # - "embed" task (bi-encoder models) + # - "token_embed" task (late interaction models like ColBERT) + if any(t in supported_tasks for t in ("score", "embed", "token_embed")): from vllm.entrypoints.pooling.score.api_router import router as score_router app.include_router(score_router) @@ -101,6 +105,10 @@ def init_pooling_state( if "classify" in supported_tasks else None ) + # ServingScores handles score/rerank for: + # - "score" task (cross-encoder models) + # - "embed" task (bi-encoder models) + # - "token_embed" task (late interaction models like ColBERT) state.openai_serving_scores = ( ServingScores( engine_client, @@ -109,6 +117,6 @@ def init_pooling_state( score_template=resolved_chat_template, log_error_stack=args.log_error_stack, ) - if ("embed" in supported_tasks or "score" in supported_tasks) + if any(t in supported_tasks for t in ("embed", "score", "token_embed")) else None ) diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py index c32f5470d..9ef3b9aff 100644 --- a/vllm/entrypoints/pooling/score/serving.py +++ b/vllm/entrypoints/pooling/score/serving.py @@ -31,6 +31,7 @@ from vllm.entrypoints.pooling.score.utils import ( ScoreInputs, _cosine_similarity, compress_token_type_ids, + compute_maxsim_score, get_score_prompt, validate_score_input, ) @@ -68,9 +69,12 @@ class ServingScores(OpenAIServing): self.is_cross_encoder = self.model_config.is_cross_encoder self.is_multimodal_model = self.model_config.is_multimodal_model self.architecture = self.model_config.architecture + self.is_late_interaction = self.model_config.is_late_interaction if self.is_cross_encoder: self._score_func = self._cross_encoding_score + elif self.is_late_interaction: + self._score_func = self._late_interaction_score else: self._score_func = self._embedding_score @@ -172,6 +176,142 @@ class ServingScores(OpenAIServing): return final_res_batch + async def _late_interaction_score( + self, + data_1: list[ScoreData], + data_2: list[ScoreData], + request: RerankRequest | ScoreRequest, + request_id: str, + lora_request: LoRARequest | None = None, + trace_headers: Mapping[str, str] | None = None, + ) -> list[PoolingRequestOutput] | ErrorResponse: + """ + Late interaction scoring (ColBERT MaxSim). + + Encodes queries and documents into per-token embeddings, then computes + MaxSim: sum over query tokens of max similarity to any document token. + """ + input_texts: list[str] = [] + for text in data_1 + data_2: + if not isinstance(text, str): + raise NotImplementedError( + "Late interaction scores currently do not support multimodal input." + ) + input_texts.append(text) + + model_config = self.model_config + tokenizer = self.renderer.get_tokenizer() + + encode_async = make_async( + tokenizer.encode, + executor=self._tokenizer_executor, + ) + + tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs() + tokenized_prompts = await asyncio.gather( + *(encode_async(t, **tokenization_kwargs) for t in input_texts) + ) + + engine_prompts: list[TokensPrompt] = [] + for tok_result, input_text in zip(tokenized_prompts, input_texts): + text_token_prompt = self._validate_input(request, tok_result, input_text) + + engine_prompts.append( + TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"]) + ) + + # Schedule the request and get the result generator. + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + + # Use token_embed task for late interaction models + from vllm import PoolingParams + + pooling_params = PoolingParams( + task="token_embed", + truncate_prompt_tokens=request.truncate_prompt_tokens, + use_activation=request.use_activation, + ) + + try: + pooling_params.verify("token_embed", self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) + + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs( + request_id_item, + input_texts[i], + params=pooling_params, + lora_request=lora_request, + ) + + generators.append( + self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + ) + + result_generator = merge_async_iterators(*generators) + + # Collect token embeddings + embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts) + + async for i, res in result_generator: + embeddings[i] = res + + # Split into query and document embeddings + emb_data_1: list[PoolingRequestOutput] = [] + emb_data_2: list[PoolingRequestOutput] = [] + + for i in range(0, len(data_1)): + assert (emb := embeddings[i]) is not None + emb_data_1.append(emb) + + for i in range(len(data_1), len(embeddings)): + assert (emb := embeddings[i]) is not None + emb_data_2.append(emb) + + # Expand queries if 1:N scoring + if len(emb_data_1) == 1: + emb_data_1 = emb_data_1 * len(emb_data_2) + + # Compute MaxSim scores + from vllm.outputs import PoolingOutput + + scores: list[PoolingRequestOutput] = [] + padding: list[int] = [] + if (pad_token_id := tokenizer.pad_token_id) is not None: + padding = [pad_token_id] + + for emb_1, emb_2 in zip(emb_data_1, emb_data_2): + # emb_1.outputs.data: [query_len, dim] + # emb_2.outputs.data: [doc_len, dim] + q_emb = emb_1.outputs.data + d_emb = emb_2.outputs.data + + maxsim_score = compute_maxsim_score(q_emb, d_emb) + + tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids + + scores.append( + PoolingRequestOutput( + request_id=f"{emb_1.request_id}_{emb_2.request_id}", + outputs=PoolingOutput(data=maxsim_score), + prompt_token_ids=tokens, + num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens, + finished=True, + ) + ) + + return scores + async def _cross_encoding_score( self, data_1: list[ScoreData], diff --git a/vllm/entrypoints/pooling/score/utils.py b/vllm/entrypoints/pooling/score/utils.py index bf3bfe8a8..7d00f42f5 100644 --- a/vllm/entrypoints/pooling/score/utils.py +++ b/vllm/entrypoints/pooling/score/utils.py @@ -3,6 +3,7 @@ from collections.abc import Iterable from typing import Any, TypeAlias, cast +import torch from torch.nn import CosineSimilarity from typing_extensions import Required, TypedDict @@ -34,6 +35,23 @@ ScoreContentPartParam: TypeAlias = ( ) +def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor: + """ + Compute ColBERT MaxSim score. + + Args: + q_emb: Query token embeddings [query_len, dim] + d_emb: Document token embeddings [doc_len, dim] + + Returns: + MaxSim score (sum over query tokens of max similarity to any doc token) + """ + # [query_len, doc_len] + token_scores = torch.matmul(q_emb, d_emb.T) + # Max over document tokens, sum over query tokens + return token_scores.amax(dim=-1).sum() + + class ScoreMultiModalParam(TypedDict, total=False): """ A specialized parameter type for scoring multimodal content diff --git a/vllm/model_executor/models/colbert.py b/vllm/model_executor/models/colbert.py new file mode 100644 index 000000000..dbb160556 --- /dev/null +++ b/vllm/model_executor/models/colbert.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +ColBERT late interaction model for retrieval and reranking. + +ColBERT uses per-token embeddings and late interaction (MaxSim) scoring +instead of single-vector representations or cross-encoder concatenation. + +Reference: https://arxiv.org/abs/2004.12832 +""" + +from collections.abc import Iterable +from typing import ClassVar, Literal + +import torch +from torch import nn + +from vllm.config import PoolerConfig, VllmConfig +from vllm.model_executor.layers.pooler import Pooler +from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed + +from .bert import BertEmbeddingModel, BertModel +from .interfaces_base import default_pooling_type + + +@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL") +class ColBERTModel(BertEmbeddingModel): + """ColBERT late interaction model for retrieval/reranking. + + This model extends BertEmbeddingModel with a ColBERT-style linear + projection layer for per-token embeddings. It supports only: + - "token_embed" task: Per-token embeddings for late interaction + + ColBERT is fundamentally a per-token embedding model - the linear + projection is trained for per-token representations, not for CLS + pooling. Use a dedicated dense embedding model if you need single- + vector representations. + + The ColBERT scoring (MaxSim) is computed externally, either client-side + or via the late interaction scoring path in ServingScores. + + Attributes: + colbert_linear: Linear projection from hidden_size to colbert_dim + supports_late_interaction: Flag indicating this model uses late + interaction scoring + """ + + # Mark this model as supporting late interaction scoring + supports_late_interaction: ClassVar[Literal[True]] = True + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # Get config before calling super().__init__ + config = vllm_config.model_config.hf_config + self.hidden_size = config.hidden_size + self.head_dtype = vllm_config.model_config.head_dtype + + # ColBERT dimension - check various config field names used by different + # ColBERT implementations. If not found in config, will be inferred + # from loaded weights in load_weights() + self.colbert_dim: int | None = ( + getattr(config, "colbert_dim", None) + or getattr(config, "dim", None) + or getattr(config, "projection_dim", None) + ) + + # Initialize parent (this will call _build_pooler) + super().__init__(vllm_config=vllm_config, prefix=prefix) + + def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel: + return BertModel(vllm_config=vllm_config, prefix=prefix) + + def _build_colbert_linear(self) -> nn.Linear: + """Build the ColBERT linear projection layer.""" + if self.colbert_dim is None: + raise ValueError("colbert_dim must be set before building the linear layer") + return nn.Linear( + self.hidden_size, + self.colbert_dim, + bias=False, + dtype=self.head_dtype, + ) + + def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: + # ColBERT linear projection: hidden_size -> colbert_dim + # Original ColBERT uses bias=False + # If colbert_dim is not set from config, it will be inferred during + # load_weights and the linear layer will be created there + if self.colbert_dim is not None: + self.colbert_linear = self._build_colbert_linear() + else: + # Placeholder - will be created when weights are loaded + self.colbert_linear = None + + # ColBERT only supports token_embed - it's fundamentally a per-token + # embedding model. + return pooler_for_token_embed( + pooler_config, + projector=self.colbert_linear, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + def _strip(name: str) -> str: + for p in ("model.", "bert."): + if name.startswith(p): + name = name[len(p) :] + return name + + weights_list = list(weights) + model_side: list[tuple[str, torch.Tensor]] = [] + colbert_side: list[tuple[str, torch.Tensor]] = [] + + for name, weight in weights_list: + stripped = _strip(name) + # Handle different checkpoint naming conventions for ColBERT linear + if stripped in ("linear.weight", "colbert_linear.weight"): + colbert_side.append(("colbert_linear.weight", weight)) + elif stripped.startswith("linear.") or stripped.startswith( + "colbert_linear." + ): + new_name = stripped.replace("linear.", "colbert_linear.") + colbert_side.append((new_name, weight)) + else: + model_side.append((stripped, weight)) + + # Load base BERT weights using BertModel.load_weights which handles QKV fusion + loaded: set[str] = set() + loaded_model = self.model.load_weights(model_side) + loaded.update({"model." + n for n in loaded_model}) + + # Load ColBERT linear weights + if colbert_side: + for name, weight in colbert_side: + if name == "colbert_linear.weight": + # Infer colbert_dim from weights if not set in config + if self.colbert_dim is None: + # Weight shape is [colbert_dim, hidden_size] + self.colbert_dim = weight.shape[0] + # Create the linear layer now that we know the dimension + self.colbert_linear = self._build_colbert_linear() + # Move to the same device as the model's existing parameters + device = next(self.model.parameters()).device + self.colbert_linear.to(device) + # Update the pooler's projector to use the new linear layer + self.pooler.head.projector = self.colbert_linear + + # Load weights directly into the pooler's projector + weight = weight.to(self.pooler.head.projector.weight.device) + self.pooler.head.projector.weight.data.copy_(weight) + loaded.add("pooler.head.projector.weight") + break + + return loaded diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index c97a9faf6..2c3ca1a50 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -981,6 +981,40 @@ def supports_cross_encoding( return is_pooling_model(model) and _supports_cross_encoding(model) +@runtime_checkable +class SupportsLateInteraction(Protocol): + """The interface required for all models that support late interaction. + + Late interaction models (like ColBERT) encode queries and documents + separately into per-token embeddings, then compute similarity via + MaxSim (max over document tokens, sum over query tokens). + """ + + supports_late_interaction: ClassVar[Literal[True]] = True + + +@overload +def supports_late_interaction( + model: type[object], +) -> TypeIs[type[SupportsLateInteraction]]: ... + + +@overload +def supports_late_interaction(model: object) -> TypeIs[SupportsLateInteraction]: ... + + +def _supports_late_interaction( + model: type[object] | object, +) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]: + return getattr(model, "supports_late_interaction", False) + + +def supports_late_interaction( + model: type[object] | object, +) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]: + return is_pooling_model(model) and _supports_late_interaction(model) + + class SupportsQuant: """The interface required for all models that support quantization.""" diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 5eeb32ed9..830a615ce 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -49,6 +49,7 @@ from .interfaces import ( is_hybrid, requires_raw_input_tokens, supports_cross_encoding, + supports_late_interaction, supports_mamba_prefix_caching, supports_multimodal, supports_multimodal_encoder_tp_data, @@ -205,6 +206,7 @@ _EMBEDDING_MODELS = { # [Text-only] "BertModel": ("bert", "BertEmbeddingModel"), "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"), + "HF_ColBERT": ("colbert", "ColBERTModel"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "Gemma3TextModel": ("gemma3", "Gemma3Model"), @@ -593,6 +595,7 @@ class _ModelInfo: default_seq_pooling_type: SequencePoolingType default_tok_pooling_type: TokenPoolingType supports_cross_encoding: bool + supports_late_interaction: bool supports_multimodal: bool supports_multimodal_raw_input_only: bool requires_raw_input_tokens: bool @@ -616,6 +619,7 @@ class _ModelInfo: default_tok_pooling_type=get_default_tok_pooling_type(model), attn_type=get_attn_type(model), supports_cross_encoding=supports_cross_encoding(model), + supports_late_interaction=supports_late_interaction(model), supports_multimodal=supports_multimodal(model), supports_multimodal_raw_input_only=supports_multimodal_raw_input_only( model