Extend ColBERT support to non-standard BERT backbones (#34170)
Signed-off-by: Ilya Boytsov <ilya.boytsov@aleph-alpha.com>
This commit is contained in:
@@ -311,20 +311,31 @@ An OpenAI client example can be found here: [examples/pooling/embed/openai_embed
|
||||
|
||||
[ColBERT](https://arxiv.org/abs/2004.12832) (Contextualized Late Interaction over BERT) is a retrieval model that uses per-token embeddings and MaxSim scoring for document ranking. Unlike single-vector embedding models, ColBERT retains token-level representations and computes relevance scores through late interaction, providing better accuracy while being more efficient than cross-encoders.
|
||||
|
||||
vLLM supports ColBERT models for reranking tasks, automatically applying MaxSim scoring for query-document relevance:
|
||||
vLLM supports ColBERT models with multiple encoder backbones:
|
||||
|
||||
| Architecture | Backbone | Example HF Models |
|
||||
|---|---|---|
|
||||
| `HF_ColBERT` | BERT | `answerdotai/answerai-colbert-small-v1`, `colbert-ir/colbertv2.0` |
|
||||
| `ColBERTModernBertModel` | ModernBERT | `lightonai/GTE-ModernColBERT-v1` |
|
||||
| `ColBERTJinaRobertaModel` | Jina XLM-RoBERTa | `jinaai/jina-colbert-v2` |
|
||||
|
||||
**BERT-based ColBERT** models work out of the box:
|
||||
|
||||
```shell
|
||||
vllm serve answerdotai/answerai-colbert-small-v1
|
||||
```
|
||||
|
||||
Currently supports ColBERT models with standard BERT encoders (e.g., `answerdotai/answerai-colbert-small-v1`, `colbert-ir/colbertv2.0`).
|
||||
|
||||
ColBERT models with modified encoder architectures are not yet supported, including BERT variants with rotary embeddings (e.g., `jinaai/jina-colbert-v2`) or other custom encoders (e.g., `LiquidAI/LFM2-ColBERT-350M`).
|
||||
|
||||
If your standard BERT ColBERT model's config doesn't specify the architecture as `HF_ColBERT`, override it with:
|
||||
For **non-BERT backbones**, use `--hf-overrides` to set the correct architecture:
|
||||
|
||||
```shell
|
||||
vllm serve your-colbert-model --hf-overrides '{"architectures": ["HF_ColBERT"]}'
|
||||
# ModernBERT backbone
|
||||
vllm serve lightonai/GTE-ModernColBERT-v1 \
|
||||
--hf-overrides '{"architectures": ["ColBERTModernBertModel"]}'
|
||||
|
||||
# Jina XLM-RoBERTa backbone
|
||||
vllm serve jinaai/jina-colbert-v2 \
|
||||
--hf-overrides '{"architectures": ["ColBERTJinaRobertaModel"]}' \
|
||||
--trust-remote-code
|
||||
```
|
||||
|
||||
Then you can use the rerank endpoint:
|
||||
|
||||
@@ -1,15 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Example of using ColBERT late interaction model for reranking.
|
||||
Example of using ColBERT late interaction models for reranking and scoring.
|
||||
|
||||
ColBERT (Contextualized Late Interaction over BERT) uses per-token embeddings
|
||||
and MaxSim scoring for document reranking, providing better accuracy than
|
||||
single-vector models while being more efficient than cross-encoders.
|
||||
|
||||
Start the server with:
|
||||
vLLM supports ColBERT with multiple encoder backbones. Start the server
|
||||
with one of the following:
|
||||
|
||||
# BERT backbone (works out of the box)
|
||||
vllm serve answerdotai/answerai-colbert-small-v1
|
||||
|
||||
# ModernBERT backbone
|
||||
vllm serve lightonai/GTE-ModernColBERT-v1 \
|
||||
--hf-overrides '{"architectures": ["ColBERTModernBertModel"]}'
|
||||
|
||||
# Jina XLM-RoBERTa backbone
|
||||
vllm serve jinaai/jina-colbert-v2 \
|
||||
--hf-overrides '{"architectures": ["ColBERTJinaRobertaModel"]}' \
|
||||
--trust-remote-code
|
||||
|
||||
Then run this script:
|
||||
python colbert_rerank_online.py
|
||||
"""
|
||||
@@ -18,39 +30,62 @@ import json
|
||||
|
||||
import requests
|
||||
|
||||
url = "http://127.0.0.1:8000/rerank"
|
||||
# Change this to match the model you started the server with
|
||||
MODEL = "answerdotai/answerai-colbert-small-v1"
|
||||
BASE_URL = "http://127.0.0.1:8000"
|
||||
|
||||
headers = {"accept": "application/json", "Content-Type": "application/json"}
|
||||
|
||||
data = {
|
||||
"model": "answerdotai/answerai-colbert-small-v1",
|
||||
"query": "What is machine learning?",
|
||||
"documents": [
|
||||
"Machine learning is a subset of artificial intelligence.",
|
||||
"Python is a programming language.",
|
||||
"Deep learning uses neural networks for complex tasks.",
|
||||
"The weather today is sunny.",
|
||||
],
|
||||
}
|
||||
documents = [
|
||||
"Machine learning is a subset of artificial intelligence.",
|
||||
"Python is a programming language.",
|
||||
"Deep learning uses neural networks for complex tasks.",
|
||||
"The weather today is sunny.",
|
||||
]
|
||||
|
||||
|
||||
def rerank_example():
|
||||
"""Use the /rerank endpoint to rank documents by query relevance."""
|
||||
print("=== Rerank Example ===")
|
||||
|
||||
data = {
|
||||
"model": MODEL,
|
||||
"query": "What is machine learning?",
|
||||
"documents": documents,
|
||||
}
|
||||
|
||||
response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)
|
||||
result = response.json()
|
||||
print(json.dumps(result, indent=2))
|
||||
|
||||
print("\nRanked documents (most relevant first):")
|
||||
for item in result["results"]:
|
||||
doc_idx = item["index"]
|
||||
score = item["relevance_score"]
|
||||
print(f" Score {score:.4f}: {documents[doc_idx]}")
|
||||
|
||||
|
||||
def score_example():
|
||||
"""Use the /score endpoint for pairwise query-document scoring."""
|
||||
print("\n=== Score Example ===")
|
||||
|
||||
data = {
|
||||
"model": MODEL,
|
||||
"text_1": "What is machine learning?",
|
||||
"text_2": [
|
||||
"Machine learning is a subset of AI.",
|
||||
"The weather is sunny.",
|
||||
],
|
||||
}
|
||||
|
||||
response = requests.post(f"{BASE_URL}/score", headers=headers, json=data)
|
||||
result = response.json()
|
||||
print(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
def main():
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
|
||||
if response.status_code == 200:
|
||||
print("ColBERT Rerank Request successful!")
|
||||
result = response.json()
|
||||
print(json.dumps(result, indent=2))
|
||||
|
||||
# Show ranked results
|
||||
print("\nRanked documents (most relevant first):")
|
||||
for item in result["results"]:
|
||||
doc_idx = item["index"]
|
||||
score = item["relevance_score"]
|
||||
print(f" Score {score:.4f}: {data['documents'][doc_idx]}")
|
||||
else:
|
||||
print(f"Request failed with status code: {response.status_code}")
|
||||
print(response.text)
|
||||
rerank_example()
|
||||
score_example()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -8,10 +8,8 @@ import requests
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse
|
||||
|
||||
# ColBERT model - using answerai-colbert-small-v1 as it's a smaller model
|
||||
MODEL_NAME = "answerdotai/answerai-colbert-small-v1"
|
||||
COLBERT_DIM = 96 # This model uses 96-dimensional output
|
||||
DTYPE = "half"
|
||||
COLBERT_DIM = 96
|
||||
MAX_MODEL_LEN = 512
|
||||
|
||||
|
||||
@@ -26,129 +24,119 @@ def server():
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_colbert_rerank(server: RemoteOpenAIServer, model_name: str):
|
||||
"""Test ColBERT rerank endpoint."""
|
||||
query = "What is the capital of France?"
|
||||
documents = [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris.",
|
||||
]
|
||||
class TestColBERTOnline:
|
||||
def test_rerank(self, server: RemoteOpenAIServer):
|
||||
"""Test ColBERT rerank endpoint."""
|
||||
query = "What is the capital of France?"
|
||||
documents = [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris.",
|
||||
]
|
||||
|
||||
rerank_response = requests.post(
|
||||
server.url_for("rerank"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
},
|
||||
)
|
||||
rerank_response.raise_for_status()
|
||||
rerank = RerankResponse.model_validate(rerank_response.json())
|
||||
rerank_response = requests.post(
|
||||
server.url_for("rerank"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
},
|
||||
)
|
||||
rerank_response.raise_for_status()
|
||||
rerank = RerankResponse.model_validate(rerank_response.json())
|
||||
|
||||
assert rerank.id is not None
|
||||
assert rerank.results is not None
|
||||
assert len(rerank.results) == 2
|
||||
assert rerank.id is not None
|
||||
assert rerank.results is not None
|
||||
assert len(rerank.results) == 2
|
||||
|
||||
# The relevant document (Paris) should have higher score
|
||||
paris_result = next(r for r in rerank.results if r.index == 1)
|
||||
brazil_result = next(r for r in rerank.results if r.index == 0)
|
||||
paris_result = next(r for r in rerank.results if r.index == 1)
|
||||
brazil_result = next(r for r in rerank.results if r.index == 0)
|
||||
|
||||
assert paris_result.relevance_score > brazil_result.relevance_score
|
||||
assert paris_result.relevance_score > brazil_result.relevance_score
|
||||
|
||||
def test_rerank_top_n(self, server: RemoteOpenAIServer):
|
||||
"""Test ColBERT rerank with top_n parameter."""
|
||||
query = "What is the capital of France?"
|
||||
documents = [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris.",
|
||||
"Machine learning is a field of AI.",
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_colbert_rerank_top_n(server: RemoteOpenAIServer, model_name: str):
|
||||
"""Test ColBERT rerank with top_n parameter."""
|
||||
query = "What is the capital of France?"
|
||||
documents = [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris.",
|
||||
"Machine learning is a field of AI.",
|
||||
]
|
||||
rerank_response = requests.post(
|
||||
server.url_for("rerank"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": 2,
|
||||
},
|
||||
)
|
||||
rerank_response.raise_for_status()
|
||||
rerank = RerankResponse.model_validate(rerank_response.json())
|
||||
|
||||
rerank_response = requests.post(
|
||||
server.url_for("rerank"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": 2,
|
||||
},
|
||||
)
|
||||
rerank_response.raise_for_status()
|
||||
rerank = RerankResponse.model_validate(rerank_response.json())
|
||||
assert len(rerank.results) == 2
|
||||
assert rerank.results[0].index == 1
|
||||
|
||||
assert len(rerank.results) == 2
|
||||
# Top result should be about Paris (index 1)
|
||||
assert rerank.results[0].index == 1
|
||||
def test_score(self, server: RemoteOpenAIServer):
|
||||
"""Test ColBERT score endpoint."""
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = ["The capital of France is Paris.", "Python is a language."]
|
||||
|
||||
score_response = requests.post(
|
||||
server.url_for("score"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
},
|
||||
)
|
||||
score_response.raise_for_status()
|
||||
score = ScoreResponse.model_validate(score_response.json())
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_colbert_score(server: RemoteOpenAIServer, model_name: str):
|
||||
"""Test ColBERT score endpoint."""
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = ["The capital of France is Paris.", "Python is a language."]
|
||||
assert score.id is not None
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 2
|
||||
|
||||
score_response = requests.post(
|
||||
server.url_for("score"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
},
|
||||
)
|
||||
score_response.raise_for_status()
|
||||
score = ScoreResponse.model_validate(score_response.json())
|
||||
assert score.data[0].score > score.data[1].score
|
||||
|
||||
assert score.id is not None
|
||||
assert score.data is not None
|
||||
assert len(score.data) == 2
|
||||
def test_token_embed(self, server: RemoteOpenAIServer):
|
||||
"""Test ColBERT token_embed task via pooling endpoint."""
|
||||
text = "What is the capital of France?"
|
||||
|
||||
# The relevant document should have higher score
|
||||
assert score.data[0].score > score.data[1].score
|
||||
pooling_response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"input": text,
|
||||
"task": "token_embed",
|
||||
},
|
||||
)
|
||||
pooling_response.raise_for_status()
|
||||
pooling = pooling_response.json()
|
||||
|
||||
assert "data" in pooling
|
||||
assert len(pooling["data"]) == 1
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_colbert_token_embed(server: RemoteOpenAIServer, model_name: str):
|
||||
"""Test ColBERT token_embed task via pooling endpoint."""
|
||||
text = "What is the capital of France?"
|
||||
embeddings = pooling["data"][0]["data"]
|
||||
assert isinstance(embeddings, list)
|
||||
assert len(embeddings) > 0
|
||||
assert len(embeddings[0]) == COLBERT_DIM
|
||||
|
||||
pooling_response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": text,
|
||||
"task": "token_embed",
|
||||
},
|
||||
)
|
||||
pooling_response.raise_for_status()
|
||||
pooling = pooling_response.json()
|
||||
def test_embed_not_supported(self, server: RemoteOpenAIServer):
|
||||
"""Test that ColBERT model does not support 'embed' task."""
|
||||
task = "embed"
|
||||
text = "What is the capital of France?"
|
||||
|
||||
assert "data" in pooling
|
||||
assert len(pooling["data"]) == 1
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"input": text,
|
||||
"task": task,
|
||||
},
|
||||
)
|
||||
|
||||
# Token embeddings should be 2D
|
||||
embeddings = pooling["data"][0]["data"]
|
||||
assert isinstance(embeddings, list)
|
||||
assert len(embeddings) > 0 # Should have tokens
|
||||
assert len(embeddings[0]) == COLBERT_DIM
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_colbert_embed_not_supported(server: RemoteOpenAIServer, model_name: str):
|
||||
"""Test that ColBERT model does not support 'embed' task."""
|
||||
task = "embed"
|
||||
text = "What is the capital of France?"
|
||||
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": text,
|
||||
"task": task,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.json()["error"]["type"] == "BadRequestError"
|
||||
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
|
||||
assert response.json()["error"]["type"] == "BadRequestError"
|
||||
assert response.json()["error"]["message"].startswith(
|
||||
f"Unsupported task: {task!r}"
|
||||
)
|
||||
|
||||
@@ -1,16 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for ColBERT late interaction scoring."""
|
||||
"""Tests for ColBERT late interaction scoring.
|
||||
|
||||
Tests are parametrized across multiple ColBERT backbones to ensure the
|
||||
generic ColBERT support works with different encoder architectures.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
|
||||
|
||||
# ColBERT model - using answerai-colbert-small-v1 as it's a smaller model
|
||||
# suitable for testing (based on BERT-base)
|
||||
COLBERT_MODEL = "answerdotai/answerai-colbert-small-v1"
|
||||
COLBERT_DIM = 96 # This model uses 96-dimensional output
|
||||
# -----------------------------------------------------------------------
|
||||
# Model definitions: (model_name, colbert_dim, extra vllm_runner kwargs)
|
||||
# -----------------------------------------------------------------------
|
||||
COLBERT_MODELS = {
|
||||
"bert": {
|
||||
"model": "answerdotai/answerai-colbert-small-v1",
|
||||
"colbert_dim": 96,
|
||||
"max_model_len": 512,
|
||||
"extra_kwargs": {},
|
||||
},
|
||||
"modernbert": {
|
||||
"model": "lightonai/GTE-ModernColBERT-v1",
|
||||
"colbert_dim": 128,
|
||||
"max_model_len": 299,
|
||||
"extra_kwargs": {
|
||||
"hf_overrides": {
|
||||
"architectures": ["ColBERTModernBertModel"],
|
||||
},
|
||||
},
|
||||
},
|
||||
"jina": {
|
||||
"model": "jinaai/jina-colbert-v2",
|
||||
"colbert_dim": 128,
|
||||
"max_model_len": 8192,
|
||||
"extra_kwargs": {
|
||||
"hf_overrides": {
|
||||
"architectures": ["ColBERTJinaRobertaModel"],
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
TEXTS_1 = [
|
||||
"What is the capital of France?",
|
||||
@@ -25,80 +56,121 @@ TEXTS_2 = [
|
||||
DTYPE = "half"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(params=list(COLBERT_MODELS.keys()), scope="module")
|
||||
def colbert_spec(request):
|
||||
"""Return the model spec dict for the current parametrization."""
|
||||
return COLBERT_MODELS[request.param]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def colbert_model_name():
|
||||
return COLBERT_MODEL
|
||||
def colbert_model_name(colbert_spec):
|
||||
return colbert_spec["model"]
|
||||
|
||||
|
||||
def test_colbert_token_embed(vllm_runner, colbert_model_name):
|
||||
@pytest.fixture(scope="module")
|
||||
def colbert_dim(colbert_spec):
|
||||
return colbert_spec["colbert_dim"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def colbert_max_model_len(colbert_spec):
|
||||
return colbert_spec["max_model_len"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def colbert_extra_kwargs(colbert_spec):
|
||||
return colbert_spec["extra_kwargs"]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Tests
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_colbert_token_embed(
|
||||
vllm_runner,
|
||||
colbert_model_name,
|
||||
colbert_dim,
|
||||
colbert_max_model_len,
|
||||
colbert_extra_kwargs,
|
||||
):
|
||||
"""Test that ColBERT model produces token embeddings."""
|
||||
with vllm_runner(
|
||||
colbert_model_name,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
max_model_len=512,
|
||||
max_model_len=colbert_max_model_len,
|
||||
enforce_eager=True,
|
||||
**colbert_extra_kwargs,
|
||||
) as vllm_model:
|
||||
# Get token embeddings for a single text
|
||||
outputs = vllm_model.token_embed([TEXTS_1[0]])
|
||||
|
||||
assert len(outputs) == 1
|
||||
# Token embeddings should be 2D: [num_tokens, colbert_dim]
|
||||
emb = torch.tensor(outputs[0])
|
||||
assert emb.dim() == 2
|
||||
assert emb.shape[1] == COLBERT_DIM
|
||||
# Should have at least a few tokens
|
||||
assert emb.shape[1] == colbert_dim
|
||||
assert emb.shape[0] > 1
|
||||
|
||||
|
||||
def test_colbert_late_interaction_1_to_1(vllm_runner, colbert_model_name):
|
||||
def test_colbert_late_interaction_1_to_1(
|
||||
vllm_runner,
|
||||
colbert_model_name,
|
||||
colbert_max_model_len,
|
||||
colbert_extra_kwargs,
|
||||
):
|
||||
"""Test ColBERT late interaction scoring with 1:1 query-document pair."""
|
||||
with vllm_runner(
|
||||
colbert_model_name,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
max_model_len=512,
|
||||
max_model_len=colbert_max_model_len,
|
||||
enforce_eager=True,
|
||||
**colbert_extra_kwargs,
|
||||
) as vllm_model:
|
||||
# Get token embeddings
|
||||
q_outputs = vllm_model.token_embed([TEXTS_1[0]])
|
||||
d_outputs = vllm_model.token_embed([TEXTS_2[0]])
|
||||
|
||||
q_emb = torch.tensor(q_outputs[0])
|
||||
d_emb = torch.tensor(d_outputs[0])
|
||||
|
||||
# Compute MaxSim manually
|
||||
manual_score = compute_maxsim_score(q_emb, d_emb).item()
|
||||
|
||||
# Use the score API (which should internally use _late_interaction_score)
|
||||
vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2[0])
|
||||
|
||||
assert len(vllm_scores) == 1
|
||||
assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)
|
||||
|
||||
|
||||
def test_colbert_late_interaction_1_to_N(vllm_runner, colbert_model_name):
|
||||
def test_colbert_late_interaction_1_to_N(
|
||||
vllm_runner,
|
||||
colbert_model_name,
|
||||
colbert_max_model_len,
|
||||
colbert_extra_kwargs,
|
||||
):
|
||||
"""Test ColBERT late interaction scoring with 1:N query-documents."""
|
||||
with vllm_runner(
|
||||
colbert_model_name,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
max_model_len=512,
|
||||
max_model_len=colbert_max_model_len,
|
||||
enforce_eager=True,
|
||||
**colbert_extra_kwargs,
|
||||
) as vllm_model:
|
||||
# Get token embeddings
|
||||
q_outputs = vllm_model.token_embed([TEXTS_1[0]])
|
||||
d_outputs = vllm_model.token_embed(TEXTS_2)
|
||||
|
||||
q_emb = torch.tensor(q_outputs[0])
|
||||
|
||||
# Compute MaxSim manually for each document
|
||||
manual_scores = []
|
||||
for d_out in d_outputs:
|
||||
d_emb = torch.tensor(d_out)
|
||||
manual_scores.append(compute_maxsim_score(q_emb, d_emb).item())
|
||||
|
||||
# Use the score API
|
||||
vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2)
|
||||
|
||||
assert len(vllm_scores) == 2
|
||||
@@ -106,27 +178,30 @@ def test_colbert_late_interaction_1_to_N(vllm_runner, colbert_model_name):
|
||||
assert vllm_scores[i] == pytest.approx(manual_scores[i], rel=0.01)
|
||||
|
||||
|
||||
def test_colbert_late_interaction_N_to_N(vllm_runner, colbert_model_name):
|
||||
def test_colbert_late_interaction_N_to_N(
|
||||
vllm_runner,
|
||||
colbert_model_name,
|
||||
colbert_max_model_len,
|
||||
colbert_extra_kwargs,
|
||||
):
|
||||
"""Test ColBERT late interaction scoring with N:N query-documents."""
|
||||
with vllm_runner(
|
||||
colbert_model_name,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
max_model_len=512,
|
||||
max_model_len=colbert_max_model_len,
|
||||
enforce_eager=True,
|
||||
**colbert_extra_kwargs,
|
||||
) as vllm_model:
|
||||
# Get token embeddings
|
||||
q_outputs = vllm_model.token_embed(TEXTS_1)
|
||||
d_outputs = vllm_model.token_embed(TEXTS_2)
|
||||
|
||||
# Compute MaxSim manually for each pair
|
||||
manual_scores = []
|
||||
for q_out, d_out in zip(q_outputs, d_outputs):
|
||||
q_emb = torch.tensor(q_out)
|
||||
d_emb = torch.tensor(d_out)
|
||||
manual_scores.append(compute_maxsim_score(q_emb, d_emb).item())
|
||||
|
||||
# Use the score API
|
||||
vllm_scores = vllm_model.score(TEXTS_1, TEXTS_2)
|
||||
|
||||
assert len(vllm_scores) == 2
|
||||
@@ -134,8 +209,13 @@ def test_colbert_late_interaction_N_to_N(vllm_runner, colbert_model_name):
|
||||
assert vllm_scores[i] == pytest.approx(manual_scores[i], rel=0.01)
|
||||
|
||||
|
||||
def test_colbert_relevance_ordering(vllm_runner, colbert_model_name):
|
||||
"""Test that ColBERT scores relevant documents higher than irrelevant ones."""
|
||||
def test_colbert_relevance_ordering(
|
||||
vllm_runner,
|
||||
colbert_model_name,
|
||||
colbert_max_model_len,
|
||||
colbert_extra_kwargs,
|
||||
):
|
||||
"""Test that ColBERT scores relevant documents higher than irrelevant."""
|
||||
query = "What is machine learning?"
|
||||
documents = [
|
||||
"Machine learning is a subset of artificial intelligence.",
|
||||
@@ -147,48 +227,73 @@ def test_colbert_relevance_ordering(vllm_runner, colbert_model_name):
|
||||
colbert_model_name,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
max_model_len=512,
|
||||
max_model_len=colbert_max_model_len,
|
||||
enforce_eager=True,
|
||||
**colbert_extra_kwargs,
|
||||
) as vllm_model:
|
||||
scores = vllm_model.score(query, documents)
|
||||
|
||||
assert len(scores) == 3
|
||||
# ML-related documents should score higher than unrelated Python doc
|
||||
# Document 0 (ML definition) should be most relevant
|
||||
# Document 2 (Deep learning) should also be relevant
|
||||
# Document 1 (Python) should be least relevant
|
||||
assert scores[0] > scores[1], "ML doc should score higher than Python doc"
|
||||
assert scores[2] > scores[1], "DL doc should score higher than Python doc"
|
||||
|
||||
|
||||
def test_colbert_embed_not_supported(vllm_runner, colbert_model_name):
|
||||
def test_colbert_embed_not_supported(
|
||||
vllm_runner,
|
||||
colbert_model_name,
|
||||
colbert_max_model_len,
|
||||
colbert_extra_kwargs,
|
||||
):
|
||||
"""Test that ColBERT model does not support 'embed' task."""
|
||||
with (
|
||||
vllm_runner(
|
||||
colbert_model_name,
|
||||
runner="pooling",
|
||||
dtype=DTYPE,
|
||||
max_model_len=512,
|
||||
max_model_len=colbert_max_model_len,
|
||||
enforce_eager=True,
|
||||
**colbert_extra_kwargs,
|
||||
) as vllm_model,
|
||||
pytest.raises(ValueError, match="Embedding API is not supported"),
|
||||
):
|
||||
vllm_model.embed([TEXTS_1[0]])
|
||||
|
||||
|
||||
def test_colbert_hf_comparison(vllm_runner, colbert_model_name):
|
||||
"""Test that vLLM ColBERT produces same embeddings as HuggingFace."""
|
||||
# -----------------------------------------------------------------------
|
||||
# Per-model HuggingFace comparison tests
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
|
||||
def _assert_embeddings_close(vllm_outputs, hf_embeddings):
|
||||
"""Assert that vLLM and HuggingFace embeddings match."""
|
||||
for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)):
|
||||
vllm_emb = torch.tensor(vllm_out).float()
|
||||
|
||||
assert hf_emb.shape == vllm_emb.shape, (
|
||||
f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}"
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
vllm_emb,
|
||||
hf_emb,
|
||||
rtol=1e-2,
|
||||
atol=1e-2,
|
||||
msg=f"Embedding mismatch for text {i}",
|
||||
)
|
||||
|
||||
|
||||
def test_colbert_hf_comparison_bert(vllm_runner):
|
||||
"""Test that vLLM ColBERT produces same embeddings as HuggingFace (BERT)."""
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoTokenizer, BertModel
|
||||
|
||||
model_name = COLBERT_MODELS["bert"]["model"]
|
||||
test_texts = [TEXTS_1[0], TEXTS_2[0]]
|
||||
|
||||
# Get vLLM embeddings first (to avoid GPU memory contention)
|
||||
# Use fp32 to match HuggingFace default precision for fair comparison
|
||||
with vllm_runner(
|
||||
colbert_model_name,
|
||||
model_name,
|
||||
runner="pooling",
|
||||
dtype="float32",
|
||||
max_model_len=512,
|
||||
@@ -196,14 +301,11 @@ def test_colbert_hf_comparison(vllm_runner, colbert_model_name):
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.token_embed(test_texts)
|
||||
|
||||
# Get HuggingFace reference embeddings on CPU
|
||||
# Load the base BERT model and manually apply the ColBERT linear projection
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(colbert_model_name)
|
||||
hf_bert = BertModel.from_pretrained(colbert_model_name)
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
hf_bert = BertModel.from_pretrained(model_name)
|
||||
hf_bert.eval()
|
||||
|
||||
# Load the ColBERT linear weights from safetensors
|
||||
weights_path = hf_hub_download(colbert_model_name, filename="model.safetensors")
|
||||
weights_path = hf_hub_download(model_name, filename="model.safetensors")
|
||||
weights = load_file(weights_path)
|
||||
linear_weight = weights["linear.weight"] # [96, 384]
|
||||
|
||||
@@ -212,36 +314,103 @@ def test_colbert_hf_comparison(vllm_runner, colbert_model_name):
|
||||
inputs = hf_tokenizer(text, return_tensors="pt")
|
||||
with torch.no_grad():
|
||||
outputs = hf_bert(**inputs)
|
||||
# Get last hidden state: [1, seq_len, 384]
|
||||
hidden_states = outputs.last_hidden_state
|
||||
# Apply ColBERT linear projection: [1, seq_len, 96]
|
||||
token_emb = F.linear(hidden_states, linear_weight)
|
||||
# L2 normalize
|
||||
token_emb = F.normalize(token_emb, p=2, dim=-1)
|
||||
hf_embeddings.append(token_emb.squeeze(0).float())
|
||||
|
||||
# Compare embeddings
|
||||
for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)):
|
||||
vllm_emb = torch.tensor(vllm_out).float()
|
||||
_assert_embeddings_close(vllm_outputs, hf_embeddings)
|
||||
|
||||
# Print first few components for debugging
|
||||
print(f"\n=== Text {i}: '{test_texts[i][:30]}...' ===")
|
||||
print(f"HF shape: {hf_emb.shape}, vLLM shape: {vllm_emb.shape}")
|
||||
print(f"HF first token, first 10 dims: {hf_emb[0, :10].tolist()}")
|
||||
print(f"vLLM first token, first 10 dims: {vllm_emb[0, :10].tolist()}")
|
||||
print(f"HF last token, first 10 dims: {hf_emb[-1, :10].tolist()}")
|
||||
print(f"vLLM last token, first 10 dims: {vllm_emb[-1, :10].tolist()}")
|
||||
|
||||
# Should have same shape
|
||||
assert hf_emb.shape == vllm_emb.shape, (
|
||||
f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}"
|
||||
)
|
||||
def test_colbert_hf_comparison_modernbert(vllm_runner):
|
||||
"""Test that vLLM ColBERT produces same embeddings as HuggingFace
|
||||
(ModernBERT)."""
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
# Should have same values (with tolerance for fp16)
|
||||
torch.testing.assert_close(
|
||||
vllm_emb,
|
||||
hf_emb,
|
||||
rtol=1e-2,
|
||||
atol=1e-2,
|
||||
msg=f"Embedding mismatch for text {i}",
|
||||
)
|
||||
spec = COLBERT_MODELS["modernbert"]
|
||||
model_name = spec["model"]
|
||||
test_texts = [TEXTS_1[0], TEXTS_2[0]]
|
||||
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="pooling",
|
||||
dtype="float32",
|
||||
max_model_len=spec["max_model_len"],
|
||||
enforce_eager=True,
|
||||
**spec["extra_kwargs"],
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.token_embed(test_texts)
|
||||
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
hf_model = AutoModel.from_pretrained(model_name)
|
||||
hf_model.eval()
|
||||
|
||||
# Load projection from sentence-transformers 1_Dense layer
|
||||
dense_path = hf_hub_download(model_name, filename="1_Dense/model.safetensors")
|
||||
dense_weights = load_file(dense_path)
|
||||
linear_weight = dense_weights["linear.weight"] # [128, 768]
|
||||
|
||||
hf_embeddings = []
|
||||
for text in test_texts:
|
||||
inputs = hf_tokenizer(text, return_tensors="pt")
|
||||
with torch.no_grad():
|
||||
outputs = hf_model(**inputs)
|
||||
hidden_states = outputs.last_hidden_state
|
||||
token_emb = F.linear(hidden_states, linear_weight)
|
||||
token_emb = F.normalize(token_emb, p=2, dim=-1)
|
||||
hf_embeddings.append(token_emb.squeeze(0).float())
|
||||
|
||||
_assert_embeddings_close(vllm_outputs, hf_embeddings)
|
||||
|
||||
|
||||
def test_colbert_hf_comparison_jina(vllm_runner):
|
||||
"""Test that vLLM ColBERT produces same embeddings as HuggingFace
|
||||
(Jina XLM-RoBERTa)."""
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
spec = COLBERT_MODELS["jina"]
|
||||
model_name = spec["model"]
|
||||
test_texts = [TEXTS_1[0], TEXTS_2[0]]
|
||||
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="pooling",
|
||||
dtype="float32",
|
||||
max_model_len=spec["max_model_len"],
|
||||
enforce_eager=True,
|
||||
**spec["extra_kwargs"],
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.token_embed(test_texts)
|
||||
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
hf_model = AutoModel.from_pretrained(
|
||||
model_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
hf_model.eval()
|
||||
|
||||
# Load projection from main checkpoint
|
||||
weights_path = hf_hub_download(model_name, filename="model.safetensors")
|
||||
weights = load_file(weights_path)
|
||||
linear_weight = weights["linear.weight"] # [128, 1024]
|
||||
|
||||
hf_embeddings = []
|
||||
for text in test_texts:
|
||||
inputs = hf_tokenizer(text, return_tensors="pt")
|
||||
with torch.no_grad():
|
||||
outputs = hf_model(**inputs)
|
||||
hidden_states = outputs.last_hidden_state
|
||||
token_emb = F.linear(hidden_states.float(), linear_weight.float())
|
||||
token_emb = F.normalize(token_emb, p=2, dim=-1)
|
||||
hf_embeddings.append(token_emb.squeeze(0).float())
|
||||
|
||||
_assert_embeddings_close(vllm_outputs, hf_embeddings)
|
||||
|
||||
@@ -529,6 +529,15 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
|
||||
"HF_ColBERT": _HfExamplesInfo("answerdotai/answerai-colbert-small-v1"),
|
||||
"ColBERTModernBertModel": _HfExamplesInfo(
|
||||
"lightonai/GTE-ModernColBERT-v1",
|
||||
hf_overrides={"architectures": ["ColBERTModernBertModel"]},
|
||||
),
|
||||
"ColBERTJinaRobertaModel": _HfExamplesInfo(
|
||||
"jinaai/jina-colbert-v2",
|
||||
trust_remote_code=True,
|
||||
hf_overrides={"architectures": ["ColBERTJinaRobertaModel"]},
|
||||
),
|
||||
"BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"),
|
||||
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
|
||||
"Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),
|
||||
|
||||
@@ -6,6 +6,14 @@ ColBERT late interaction model for retrieval and reranking.
|
||||
ColBERT uses per-token embeddings and late interaction (MaxSim) scoring
|
||||
instead of single-vector representations or cross-encoder concatenation.
|
||||
|
||||
This module provides:
|
||||
|
||||
- :class:`ColBERTMixin` — mixin that adds ColBERT late-interaction support
|
||||
to any embedding model.
|
||||
- :class:`ColBERTModel` — ColBERT with BERT backbone (original architecture).
|
||||
- :class:`ColBERTModernBertModel` — ColBERT with ModernBERT backbone.
|
||||
- :class:`ColBERTJinaRobertaModel` — ColBERT with Jina XLM-RoBERTa backbone.
|
||||
|
||||
Reference: https://arxiv.org/abs/2004.12832
|
||||
"""
|
||||
|
||||
@@ -23,51 +31,60 @@ from .bert import BertEmbeddingModel, BertModel
|
||||
from .interfaces_base import default_pooling_type
|
||||
|
||||
|
||||
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
|
||||
class ColBERTModel(BertEmbeddingModel):
|
||||
"""ColBERT late interaction model for retrieval/reranking.
|
||||
class ColBERTMixin:
|
||||
"""Mixin that adds ColBERT late interaction support to any embedding model.
|
||||
|
||||
This model extends BertEmbeddingModel with a ColBERT-style linear
|
||||
projection layer for per-token embeddings. It supports only:
|
||||
- "token_embed" task: Per-token embeddings for late interaction
|
||||
ColBERT (Contextualized Late Interaction over BERT) uses per-token
|
||||
embeddings with a linear projection layer. This mixin provides:
|
||||
|
||||
ColBERT is fundamentally a per-token embedding model - the linear
|
||||
projection is trained for per-token representations, not for CLS
|
||||
pooling. Use a dedicated dense embedding model if you need single-
|
||||
vector representations.
|
||||
- ``supports_late_interaction`` class-var
|
||||
- ColBERT linear projection initialisation / lazy creation
|
||||
- Weight loading helpers for the projection layer
|
||||
- A builder for the token-embedding pooler
|
||||
|
||||
The ColBERT scoring (MaxSim) is computed externally, either client-side
|
||||
or via the late interaction scoring path in ServingScores.
|
||||
**Integration:**
|
||||
|
||||
Attributes:
|
||||
colbert_linear: Linear projection from hidden_size to colbert_dim
|
||||
supports_late_interaction: Flag indicating this model uses late
|
||||
interaction scoring
|
||||
1. Inherit from both ``ColBERTMixin`` and ``nn.Module``.
|
||||
2. In ``__init__``: call ``super().__init__()``, then
|
||||
:meth:`_init_colbert_components`, then create ``self.model``
|
||||
(the backbone) and ``self.pooler`` via :meth:`_build_colbert_pooler`.
|
||||
3. In ``load_weights``: use :meth:`_load_colbert_weights` to separate
|
||||
the ColBERT projection weight, then delegate the rest to the backbone.
|
||||
"""
|
||||
|
||||
# Mark this model as supporting late interaction scoring
|
||||
supports_late_interaction: ClassVar[Literal[True]] = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
# Get config before calling super().__init__
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_dtype = vllm_config.model_config.head_dtype
|
||||
# Set during _init_colbert_components
|
||||
colbert_dim: int | None
|
||||
colbert_linear: nn.Linear | None
|
||||
hidden_size: int
|
||||
head_dtype: torch.dtype
|
||||
|
||||
# ColBERT dimension - check various config field names used by different
|
||||
# ColBERT implementations. If not found in config, will be inferred
|
||||
# from loaded weights in load_weights()
|
||||
self.colbert_dim: int | None = (
|
||||
getattr(config, "colbert_dim", None)
|
||||
or getattr(config, "dim", None)
|
||||
or getattr(config, "projection_dim", None)
|
||||
)
|
||||
# ------------------------------------------------------------------ init
|
||||
|
||||
# Initialize parent (this will call _build_pooler)
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
def _init_colbert_components(
|
||||
self,
|
||||
hidden_size: int,
|
||||
colbert_dim: int | None,
|
||||
head_dtype: torch.dtype,
|
||||
) -> None:
|
||||
"""Initialise ColBERT projection layer.
|
||||
|
||||
def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
|
||||
return BertModel(vllm_config=vllm_config, prefix=prefix)
|
||||
Args:
|
||||
hidden_size: Hidden dimension of the encoder backbone.
|
||||
colbert_dim: Output dimension for ColBERT embeddings. If
|
||||
``None``, will be inferred from weights during loading (or
|
||||
auto-loaded from sentence-transformers Dense layers).
|
||||
head_dtype: Data type for the projection layer.
|
||||
"""
|
||||
self.hidden_size = hidden_size
|
||||
self.colbert_dim = colbert_dim
|
||||
self.head_dtype = head_dtype
|
||||
|
||||
if colbert_dim is not None:
|
||||
self.colbert_linear = self._build_colbert_linear()
|
||||
else:
|
||||
self.colbert_linear = None
|
||||
|
||||
def _build_colbert_linear(self) -> nn.Linear:
|
||||
"""Build the ColBERT linear projection layer."""
|
||||
@@ -80,24 +97,127 @@ class ColBERTModel(BertEmbeddingModel):
|
||||
dtype=self.head_dtype,
|
||||
)
|
||||
|
||||
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||
# ColBERT linear projection: hidden_size -> colbert_dim
|
||||
# Original ColBERT uses bias=False
|
||||
# If colbert_dim is not set from config, it will be inferred during
|
||||
# load_weights and the linear layer will be created there
|
||||
if self.colbert_dim is not None:
|
||||
self.colbert_linear = self._build_colbert_linear()
|
||||
else:
|
||||
# Placeholder - will be created when weights are loaded
|
||||
self.colbert_linear = None
|
||||
# ---------------------------------------------------------------- pooler
|
||||
|
||||
# ColBERT only supports token_embed - it's fundamentally a per-token
|
||||
# embedding model.
|
||||
def _build_colbert_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||
"""Build pooler for ColBERT token embeddings.
|
||||
|
||||
When ``colbert_linear`` is set, it is used as the projector.
|
||||
Otherwise ``pooler_for_token_embed`` falls back to auto-loading
|
||||
sentence-transformers Dense layers (``1_Dense/`` etc.).
|
||||
"""
|
||||
return pooler_for_token_embed(
|
||||
pooler_config,
|
||||
projector=self.colbert_linear,
|
||||
)
|
||||
|
||||
# --------------------------------------------------------- config helper
|
||||
|
||||
@classmethod
|
||||
def get_colbert_dim_from_config(cls, hf_config) -> int | None:
|
||||
"""Extract ColBERT dimension from a HuggingFace config.
|
||||
|
||||
Checks ``colbert_dim``, ``dim`` and ``projection_dim`` in that order.
|
||||
"""
|
||||
return (
|
||||
getattr(hf_config, "colbert_dim", None)
|
||||
or getattr(hf_config, "dim", None)
|
||||
or getattr(hf_config, "projection_dim", None)
|
||||
)
|
||||
|
||||
# -------------------------------------------------------- weight loading
|
||||
|
||||
def _load_colbert_weights(
|
||||
self,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
colbert_weight_names: tuple[str, ...] = (
|
||||
"linear.weight",
|
||||
"colbert_linear.weight",
|
||||
),
|
||||
) -> tuple[list[tuple[str, torch.Tensor]], set[str]]:
|
||||
"""Separate and load ColBERT projection weights.
|
||||
|
||||
Scans *weights* for entries whose name ends with one of
|
||||
*colbert_weight_names*. The matching weight is loaded into
|
||||
``self.colbert_linear`` (creating it first if ``colbert_dim`` was
|
||||
not known at init time).
|
||||
|
||||
Args:
|
||||
weights: Iterable of ``(name, tensor)`` weight pairs.
|
||||
colbert_weight_names: Suffixes that identify the ColBERT linear
|
||||
weight.
|
||||
|
||||
Returns:
|
||||
``(remaining_weights, loaded_names)`` — the weights that were
|
||||
**not** consumed and the set of names that were loaded.
|
||||
"""
|
||||
weights_list = list(weights)
|
||||
other_weights: list[tuple[str, torch.Tensor]] = []
|
||||
colbert_weight: tuple[str, torch.Tensor] | None = None
|
||||
|
||||
for name, weight in weights_list:
|
||||
if any(name.endswith(cw) for cw in colbert_weight_names):
|
||||
colbert_weight = (name, weight)
|
||||
else:
|
||||
other_weights.append((name, weight))
|
||||
|
||||
loaded: set[str] = set()
|
||||
if colbert_weight is not None:
|
||||
_name, weight = colbert_weight
|
||||
if weight.dim() == 2:
|
||||
# Infer colbert_dim from weight shape if not set
|
||||
if self.colbert_dim is None:
|
||||
self.colbert_dim = weight.shape[0]
|
||||
self.colbert_linear = self._build_colbert_linear()
|
||||
# Update the pooler's projector
|
||||
if hasattr(self, "pooler") and hasattr(self.pooler, "head"):
|
||||
self.pooler.head.projector = self.colbert_linear
|
||||
|
||||
assert self.colbert_linear is not None
|
||||
# Move to same device as model
|
||||
if hasattr(self, "model"):
|
||||
device = next(self.model.parameters()).device
|
||||
self.colbert_linear.to(device)
|
||||
|
||||
weight = weight.to(self.colbert_linear.weight.device)
|
||||
self.colbert_linear.weight.data.copy_(weight)
|
||||
loaded.add("pooler.head.projector.weight")
|
||||
|
||||
return other_weights, loaded
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Concrete model: ColBERT + BERT backbone (original architecture)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
|
||||
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
|
||||
class ColBERTModel(ColBERTMixin, BertEmbeddingModel):
|
||||
"""ColBERT late interaction model with BERT backbone.
|
||||
|
||||
Supports the ``token_embed`` task (per-token embeddings for late
|
||||
interaction). MaxSim scoring is computed externally.
|
||||
"""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
# Must run before super().__init__ because _build_pooler reads these.
|
||||
colbert_dim = self.get_colbert_dim_from_config(config)
|
||||
self._init_colbert_components(
|
||||
hidden_size=config.hidden_size,
|
||||
colbert_dim=colbert_dim,
|
||||
head_dtype=vllm_config.model_config.head_dtype,
|
||||
)
|
||||
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
|
||||
return BertModel(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||
return self._build_colbert_pooler(pooler_config)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
def _strip(name: str) -> str:
|
||||
for p in ("model.", "bert."):
|
||||
@@ -111,7 +231,7 @@ class ColBERTModel(BertEmbeddingModel):
|
||||
|
||||
for name, weight in weights_list:
|
||||
stripped = _strip(name)
|
||||
# Handle different checkpoint naming conventions for ColBERT linear
|
||||
# Handle different checkpoint naming conventions
|
||||
if stripped in ("linear.weight", "colbert_linear.weight"):
|
||||
colbert_side.append(("colbert_linear.weight", weight))
|
||||
elif stripped.startswith("linear.") or stripped.startswith(
|
||||
@@ -122,31 +242,178 @@ class ColBERTModel(BertEmbeddingModel):
|
||||
else:
|
||||
model_side.append((stripped, weight))
|
||||
|
||||
# Load base BERT weights using BertModel.load_weights which handles QKV fusion
|
||||
loaded: set[str] = set()
|
||||
loaded_model = self.model.load_weights(model_side)
|
||||
loaded.update({"model." + n for n in loaded_model})
|
||||
|
||||
# Load ColBERT linear weights
|
||||
if colbert_side:
|
||||
for name, weight in colbert_side:
|
||||
if name == "colbert_linear.weight":
|
||||
# Infer colbert_dim from weights if not set in config
|
||||
if self.colbert_dim is None:
|
||||
# Weight shape is [colbert_dim, hidden_size]
|
||||
self.colbert_dim = weight.shape[0]
|
||||
# Create the linear layer now that we know the dimension
|
||||
self.colbert_linear = self._build_colbert_linear()
|
||||
# Move to the same device as the model's existing parameters
|
||||
device = next(self.model.parameters()).device
|
||||
self.colbert_linear.to(device)
|
||||
# Update the pooler's projector to use the new linear layer
|
||||
self.pooler.head.projector = self.colbert_linear
|
||||
|
||||
# Load weights directly into the pooler's projector
|
||||
weight = weight.to(self.pooler.head.projector.weight.device)
|
||||
self.pooler.head.projector.weight.data.copy_(weight)
|
||||
loaded.add("pooler.head.projector.weight")
|
||||
break
|
||||
_, colbert_loaded = self._load_colbert_weights(colbert_side)
|
||||
loaded.update(colbert_loaded)
|
||||
|
||||
return loaded
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Concrete model: ColBERT + ModernBERT backbone
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
from .modernbert import ModernBertModel # noqa: E402
|
||||
|
||||
|
||||
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
|
||||
class ColBERTModernBertModel(ColBERTMixin, nn.Module):
|
||||
"""ColBERT late interaction model with ModernBERT backbone.
|
||||
|
||||
For ``lightonai/GTE-ModernColBERT-v1`` and similar models.
|
||||
The projection is auto-loaded from sentence-transformers ``1_Dense/``
|
||||
when not present in the main checkpoint.
|
||||
"""
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
colbert_dim = self.get_colbert_dim_from_config(config)
|
||||
self._init_colbert_components(
|
||||
hidden_size=config.hidden_size,
|
||||
colbert_dim=colbert_dim,
|
||||
head_dtype=vllm_config.model_config.head_dtype,
|
||||
)
|
||||
|
||||
self.model = ModernBertModel(
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
self.pooler = self._build_colbert_pooler(pooler_config)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors=None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
other_weights, colbert_loaded = self._load_colbert_weights(weights)
|
||||
|
||||
# Strip "model." prefix added by the embedding adapter
|
||||
model_weights = [
|
||||
(n[len("model.") :] if n.startswith("model.") else n, w)
|
||||
for n, w in other_weights
|
||||
]
|
||||
|
||||
loaded_model = self.model.load_weights(model_weights)
|
||||
loaded = {"model." + n for n in loaded_model} | colbert_loaded
|
||||
|
||||
# When the ST projector was auto-loaded during init
|
||||
# (not from the main checkpoint), mark its params as loaded
|
||||
# so the weight validator doesn't complain.
|
||||
if hasattr(self.pooler, "head"):
|
||||
head = self.pooler.head
|
||||
projector = getattr(head, "projector", None)
|
||||
if projector is not None and isinstance(projector, nn.Module):
|
||||
for name, _ in projector.named_parameters():
|
||||
loaded.add(f"pooler.head.projector.{name}")
|
||||
|
||||
return loaded
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Concrete model: ColBERT + Jina XLM-RoBERTa backbone
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
from .bert_with_rope import JinaRobertaModel # noqa: E402
|
||||
|
||||
|
||||
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
|
||||
class ColBERTJinaRobertaModel(ColBERTMixin, nn.Module):
|
||||
"""ColBERT late interaction model with Jina XLM-RoBERTa backbone.
|
||||
|
||||
For ``jinaai/jina-colbert-v2`` and similar models.
|
||||
"""
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
colbert_dim = self.get_colbert_dim_from_config(config)
|
||||
self._init_colbert_components(
|
||||
hidden_size=config.hidden_size,
|
||||
colbert_dim=colbert_dim,
|
||||
head_dtype=vllm_config.model_config.head_dtype,
|
||||
)
|
||||
|
||||
self.model = JinaRobertaModel(
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
self.pooler = self._build_colbert_pooler(pooler_config)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors=None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
weights_list = list(weights)
|
||||
model_side: list[tuple[str, torch.Tensor]] = []
|
||||
colbert_side: list[tuple[str, torch.Tensor]] = []
|
||||
|
||||
for name, weight in weights_list:
|
||||
stripped = name
|
||||
# Strip "model." prefix added by the embedding adapter
|
||||
if stripped.startswith("model."):
|
||||
stripped = stripped[len("model.") :]
|
||||
# Strip "roberta." prefix from checkpoint
|
||||
if stripped.startswith("roberta."):
|
||||
stripped = stripped[len("roberta.") :]
|
||||
|
||||
if stripped in ("linear.weight", "colbert_linear.weight"):
|
||||
colbert_side.append(("colbert_linear.weight", weight))
|
||||
elif stripped.startswith("pooler."):
|
||||
# Skip HF pooler weights (not used in ColBERT)
|
||||
continue
|
||||
else:
|
||||
model_side.append((stripped, weight))
|
||||
|
||||
loaded: set[str] = set()
|
||||
loaded_model = self.model.load_weights(model_side)
|
||||
loaded.update({"model." + n for n in loaded_model})
|
||||
|
||||
if colbert_side:
|
||||
_, colbert_loaded = self._load_colbert_weights(colbert_side)
|
||||
loaded.update(colbert_loaded)
|
||||
|
||||
return loaded
|
||||
|
||||
@@ -629,6 +629,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
|
||||
"Qwen3VLForSequenceClassification": Qwen3VLForSequenceClassificationConfig,
|
||||
"XLMRobertaModel": JinaRobertaModelConfig,
|
||||
"ColBERTJinaRobertaModel": JinaRobertaModelConfig,
|
||||
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,
|
||||
"JambaForSequenceClassification": JambaForSequenceClassificationConfig,
|
||||
"GptOssForCausalLM": GptOssForCausalLMConfig,
|
||||
|
||||
@@ -208,6 +208,8 @@ _EMBEDDING_MODELS = {
|
||||
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||
"BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
|
||||
"HF_ColBERT": ("colbert", "ColBERTModel"),
|
||||
"ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
|
||||
"ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
|
||||
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"Gemma3TextModel": ("gemma3", "Gemma3Model"),
|
||||
|
||||
@@ -1068,9 +1068,11 @@ def try_get_dense_modules(
|
||||
if isinstance(modules, dict):
|
||||
modules = modules.get("modules", [])
|
||||
|
||||
dense_modules = [
|
||||
m for m in modules if m.get("type") == "sentence_transformers.models.Dense"
|
||||
]
|
||||
_DENSE_MODULE_TYPES = {
|
||||
"sentence_transformers.models.Dense",
|
||||
"pylate.models.Dense.Dense",
|
||||
}
|
||||
dense_modules = [m for m in modules if m.get("type") in _DENSE_MODULE_TYPES]
|
||||
if not dense_modules:
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user