feat: Add ColBERT late interaction model support (#33686)
Signed-off-by: Ilya Boytsov <ilyaboytsov1805@gmail.com> Signed-off-by: Ilya Boytsov <boytsovpanamera@mail.ru> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
@@ -307,6 +307,62 @@ An OpenAI client example can be found here: [examples/pooling/embed/openai_embed
|
||||
|
||||
## Specific models
|
||||
|
||||
### ColBERT Late Interaction Models
|
||||
|
||||
[ColBERT](https://arxiv.org/abs/2004.12832) (Contextualized Late Interaction over BERT) is a retrieval model that uses per-token embeddings and MaxSim scoring for document ranking. Unlike single-vector embedding models, ColBERT retains token-level representations and computes relevance scores through late interaction, providing better accuracy while being more efficient than cross-encoders.
|
||||
|
||||
vLLM supports ColBERT models for reranking tasks, automatically applying MaxSim scoring for query-document relevance:
|
||||
|
||||
```shell
|
||||
vllm serve answerdotai/answerai-colbert-small-v1
|
||||
```
|
||||
|
||||
Currently supports ColBERT models with standard BERT encoders (e.g., `answerdotai/answerai-colbert-small-v1`, `colbert-ir/colbertv2.0`).
|
||||
|
||||
ColBERT models with modified encoder architectures are not yet supported, including BERT variants with rotary embeddings (e.g., `jinaai/jina-colbert-v2`) or other custom encoders (e.g., `LiquidAI/LFM2-ColBERT-350M`).
|
||||
|
||||
If your standard BERT ColBERT model's config doesn't specify the architecture as `HF_ColBERT`, override it with:
|
||||
|
||||
```shell
|
||||
vllm serve your-colbert-model --hf-overrides '{"architectures": ["HF_ColBERT"]}'
|
||||
```
|
||||
|
||||
Then you can use the rerank endpoint:
|
||||
|
||||
```shell
|
||||
curl -s http://localhost:8000/rerank -H "Content-Type: application/json" -d '{
|
||||
"model": "answerdotai/answerai-colbert-small-v1",
|
||||
"query": "What is machine learning?",
|
||||
"documents": [
|
||||
"Machine learning is a subset of artificial intelligence.",
|
||||
"Python is a programming language.",
|
||||
"Deep learning uses neural networks."
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
Or the score endpoint:
|
||||
|
||||
```shell
|
||||
curl -s http://localhost:8000/score -H "Content-Type: application/json" -d '{
|
||||
"model": "answerdotai/answerai-colbert-small-v1",
|
||||
"text_1": "What is machine learning?",
|
||||
"text_2": ["Machine learning is a subset of AI.", "The weather is sunny."]
|
||||
}'
|
||||
```
|
||||
|
||||
You can also get the raw token embeddings using the pooling endpoint with `token_embed` task:
|
||||
|
||||
```shell
|
||||
curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
|
||||
"model": "answerdotai/answerai-colbert-small-v1",
|
||||
"input": "What is machine learning?",
|
||||
"task": "token_embed"
|
||||
}'
|
||||
```
|
||||
|
||||
An example can be found here: [examples/pooling/score/colbert_rerank_online.py](../../examples/pooling/score/colbert_rerank_online.py)
|
||||
|
||||
### BAAI/bge-m3
|
||||
|
||||
The `BAAI/bge-m3` model comes with extra weights for sparse and colbert embeddings but unfortunately in its `config.json`
|
||||
|
||||
57
examples/pooling/score/colbert_rerank_online.py
Normal file
57
examples/pooling/score/colbert_rerank_online.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Example of using ColBERT late interaction model for reranking.
|
||||
|
||||
ColBERT (Contextualized Late Interaction over BERT) uses per-token embeddings
|
||||
and MaxSim scoring for document reranking, providing better accuracy than
|
||||
single-vector models while being more efficient than cross-encoders.
|
||||
|
||||
Start the server with:
|
||||
vllm serve answerdotai/answerai-colbert-small-v1
|
||||
|
||||
Then run this script:
|
||||
python colbert_rerank_online.py
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
url = "http://127.0.0.1:8000/rerank"
|
||||
|
||||
headers = {"accept": "application/json", "Content-Type": "application/json"}
|
||||
|
||||
data = {
|
||||
"model": "answerdotai/answerai-colbert-small-v1",
|
||||
"query": "What is machine learning?",
|
||||
"documents": [
|
||||
"Machine learning is a subset of artificial intelligence.",
|
||||
"Python is a programming language.",
|
||||
"Deep learning uses neural networks for complex tasks.",
|
||||
"The weather today is sunny.",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def main():
    """Send a rerank request to the vLLM server and print the ranked results."""
    response = requests.post(url, headers=headers, json=data)

    # Guard clause: report the failure and bail out early.
    if response.status_code != 200:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)
        return

    print("ColBERT Rerank Request successful!")
    result = response.json()
    print(json.dumps(result, indent=2))

    # Show ranked results
    print("\nRanked documents (most relevant first):")
    for item in result["results"]:
        document = data["documents"][item["index"]]
        relevance = item["relevance_score"]
        print(f" Score {relevance:.4f}: {document}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
154
tests/entrypoints/pooling/score/test_online_colbert.py
Normal file
154
tests/entrypoints/pooling/score/test_online_colbert.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Online API tests for ColBERT late interaction scoring."""
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse
|
||||
|
||||
# ColBERT model - using answerai-colbert-small-v1 as it's a smaller model
|
||||
MODEL_NAME = "answerdotai/answerai-colbert-small-v1"
|
||||
COLBERT_DIM = 96 # This model uses 96-dimensional output
|
||||
DTYPE = "half"
|
||||
MAX_MODEL_LEN = 512
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server():
    """Start one shared vLLM server for the ColBERT model under test."""
    cli_args = ["--max-model-len", str(MAX_MODEL_LEN)]
    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_rerank(server: RemoteOpenAIServer, model_name: str):
    """Test ColBERT rerank endpoint."""
    payload = {
        "model": model_name,
        "query": "What is the capital of France?",
        "documents": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
        ],
    }

    raw_response = requests.post(server.url_for("rerank"), json=payload)
    raw_response.raise_for_status()
    rerank = RerankResponse.model_validate(raw_response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 2

    # The relevant document (Paris, index 1) should outrank the Brazil one.
    by_index = {result.index: result for result in rerank.results}
    assert by_index[1].relevance_score > by_index[0].relevance_score
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_rerank_top_n(server: RemoteOpenAIServer, model_name: str):
    """Test ColBERT rerank with top_n parameter."""
    payload = {
        "model": model_name,
        "query": "What is the capital of France?",
        "documents": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
            "Machine learning is a field of AI.",
        ],
        "top_n": 2,
    }

    raw_response = requests.post(server.url_for("rerank"), json=payload)
    raw_response.raise_for_status()
    rerank = RerankResponse.model_validate(raw_response.json())

    # top_n=2 must truncate the three candidates down to two results.
    assert len(rerank.results) == 2
    # The Paris document (index 1) should be ranked first.
    assert rerank.results[0].index == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_score(server: RemoteOpenAIServer, model_name: str):
    """Test ColBERT score endpoint."""
    payload = {
        "model": model_name,
        "text_1": "What is the capital of France?",
        "text_2": ["The capital of France is Paris.", "Python is a language."],
    }

    raw_response = requests.post(server.url_for("score"), json=payload)
    raw_response.raise_for_status()
    score = ScoreResponse.model_validate(raw_response.json())

    assert score.id is not None
    assert score.data is not None
    assert len(score.data) == 2

    # The first pairing is topically relevant; the second is not.
    assert score.data[0].score > score.data[1].score
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_token_embed(server: RemoteOpenAIServer, model_name: str):
    """Test ColBERT token_embed task via pooling endpoint."""
    raw_response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": "What is the capital of France?",
            "task": "token_embed",
        },
    )
    raw_response.raise_for_status()
    body = raw_response.json()

    assert "data" in body
    assert len(body["data"]) == 1

    # Per-token embeddings: a 2D list with one vector per token, each of
    # the model's output dimension.
    token_vectors = body["data"][0]["data"]
    assert isinstance(token_vectors, list)
    assert len(token_vectors) > 0  # Should have tokens
    assert len(token_vectors[0]) == COLBERT_DIM
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_embed_not_supported(server: RemoteOpenAIServer, model_name: str):
    """Test that ColBERT model does not support 'embed' task."""
    raw_response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": "What is the capital of France?",
            "task": "embed",
        },
    )

    # The single-vector 'embed' task must be rejected with a client error.
    assert raw_response.status_code == 400
    assert "Task embed is not supported" in raw_response.text
|
||||
247
tests/models/language/pooling/test_colbert.py
Normal file
247
tests/models/language/pooling/test_colbert.py
Normal file
@@ -0,0 +1,247 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for ColBERT late interaction scoring."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
|
||||
|
||||
# ColBERT model - using answerai-colbert-small-v1 as it's a smaller model
|
||||
# suitable for testing (based on BERT-base)
|
||||
COLBERT_MODEL = "answerdotai/answerai-colbert-small-v1"
|
||||
COLBERT_DIM = 96 # This model uses 96-dimensional output
|
||||
|
||||
TEXTS_1 = [
|
||||
"What is the capital of France?",
|
||||
"What is the capital of Germany?",
|
||||
]
|
||||
|
||||
TEXTS_2 = [
|
||||
"The capital of France is Paris.",
|
||||
"The capital of Germany is Berlin.",
|
||||
]
|
||||
|
||||
DTYPE = "half"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def colbert_model_name():
    """Name of the ColBERT checkpoint exercised by every test in this module."""
    return COLBERT_MODEL
|
||||
|
||||
|
||||
def test_colbert_token_embed(vllm_runner, colbert_model_name):
    """Test that ColBERT model produces token embeddings."""
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=512,
        enforce_eager=True,
    ) as vllm_model:
        # Encode a single text and inspect its per-token output.
        outputs = vllm_model.token_embed([TEXTS_1[0]])

        assert len(outputs) == 1
        matrix = torch.tensor(outputs[0])
        # Token embeddings are 2D: [num_tokens, colbert_dim], and the
        # sentence must tokenize to more than one token.
        assert matrix.dim() == 2
        assert matrix.shape[1] == COLBERT_DIM
        assert matrix.shape[0] > 1
|
||||
|
||||
|
||||
def test_colbert_late_interaction_1_to_1(vllm_runner, colbert_model_name):
    """Test ColBERT late interaction scoring with 1:1 query-document pair."""
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=512,
        enforce_eager=True,
    ) as vllm_model:
        # Reference: compute MaxSim directly from the token embeddings.
        query_emb = torch.tensor(vllm_model.token_embed([TEXTS_1[0]])[0])
        doc_emb = torch.tensor(vllm_model.token_embed([TEXTS_2[0]])[0])
        expected = compute_maxsim_score(query_emb, doc_emb).item()

        # The score API should internally use late-interaction scoring.
        vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2[0])

        assert len(vllm_scores) == 1
        assert vllm_scores[0] == pytest.approx(expected, rel=0.01)
|
||||
|
||||
|
||||
def test_colbert_late_interaction_1_to_N(vllm_runner, colbert_model_name):
    """Test ColBERT late interaction scoring with 1:N query-documents."""
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=512,
        enforce_eager=True,
    ) as vllm_model:
        # Reference: MaxSim of the single query against every document.
        query_emb = torch.tensor(vllm_model.token_embed([TEXTS_1[0]])[0])
        doc_embs = [torch.tensor(d) for d in vllm_model.token_embed(TEXTS_2)]
        expected = [
            compute_maxsim_score(query_emb, doc_emb).item() for doc_emb in doc_embs
        ]

        # One query broadcast across N documents via the score API.
        vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2)

        assert len(vllm_scores) == 2
        for got, want in zip(vllm_scores, expected):
            assert got == pytest.approx(want, rel=0.01)
|
||||
|
||||
|
||||
def test_colbert_late_interaction_N_to_N(vllm_runner, colbert_model_name):
    """Test ColBERT late interaction scoring with N:N query-documents."""
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=512,
        enforce_eager=True,
    ) as vllm_model:
        # Reference: pairwise MaxSim for each (query, document) pair.
        query_outs = vllm_model.token_embed(TEXTS_1)
        doc_outs = vllm_model.token_embed(TEXTS_2)
        expected = [
            compute_maxsim_score(torch.tensor(q), torch.tensor(d)).item()
            for q, d in zip(query_outs, doc_outs)
        ]

        # Pairwise scoring through the score API.
        vllm_scores = vllm_model.score(TEXTS_1, TEXTS_2)

        assert len(vllm_scores) == 2
        for got, want in zip(vllm_scores, expected):
            assert got == pytest.approx(want, rel=0.01)
|
||||
|
||||
|
||||
def test_colbert_relevance_ordering(vllm_runner, colbert_model_name):
    """Test that ColBERT scores relevant documents higher than irrelevant ones."""
    query = "What is machine learning?"
    documents = [
        "Machine learning is a subset of artificial intelligence.",
        "Python is a programming language.",
        "Deep learning uses neural networks.",
    ]

    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=512,
        enforce_eager=True,
    ) as vllm_model:
        scores = vllm_model.score(query, documents)

    assert len(scores) == 3
    # Documents 0 (ML definition) and 2 (deep learning) are on-topic;
    # document 1 (Python) is the off-topic distractor and should rank lowest.
    assert scores[0] > scores[1], "ML doc should score higher than Python doc"
    assert scores[2] > scores[1], "DL doc should score higher than Python doc"
|
||||
|
||||
|
||||
def test_colbert_embed_not_supported(vllm_runner, colbert_model_name):
    """Test that ColBERT model does not support 'embed' task."""
    # ColBERT only produces per-token embeddings (token_embed); the
    # single-vector Embedding API must therefore be rejected.
    with (
        vllm_runner(
            colbert_model_name,
            runner="pooling",
            dtype=DTYPE,
            max_model_len=512,
            enforce_eager=True,
        ) as vllm_model,
        # The unsupported task surfaces as a ValueError from the runner.
        pytest.raises(ValueError, match="Embedding API is not supported"),
    ):
        vllm_model.embed([TEXTS_1[0]])
|
||||
|
||||
|
||||
def test_colbert_hf_comparison(vllm_runner, colbert_model_name):
    """Test that vLLM ColBERT produces same embeddings as HuggingFace."""
    import torch.nn.functional as F
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file
    from transformers import AutoTokenizer, BertModel

    test_texts = [TEXTS_1[0], TEXTS_2[0]]

    # Get vLLM embeddings first (to avoid GPU memory contention)
    # Use fp32 to match HuggingFace default precision for fair comparison
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype="float32",
        max_model_len=512,
        enforce_eager=True,
    ) as vllm_model:
        vllm_outputs = vllm_model.token_embed(test_texts)

    # Get HuggingFace reference embeddings on CPU
    # Load the base BERT model and manually apply the ColBERT linear projection
    hf_tokenizer = AutoTokenizer.from_pretrained(colbert_model_name)
    hf_bert = BertModel.from_pretrained(colbert_model_name)
    hf_bert.eval()

    # Load the ColBERT linear weights from safetensors
    weights_path = hf_hub_download(colbert_model_name, filename="model.safetensors")
    weights = load_file(weights_path)
    linear_weight = weights["linear.weight"]  # [96, 384]

    hf_embeddings = []
    for text in test_texts:
        inputs = hf_tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = hf_bert(**inputs)
            # Get last hidden state: [1, seq_len, 384]
            hidden_states = outputs.last_hidden_state
            # Apply ColBERT linear projection: [1, seq_len, 96]
            token_emb = F.linear(hidden_states, linear_weight)
            # L2 normalize (ColBERT embeddings are unit-norm per token)
            token_emb = F.normalize(token_emb, p=2, dim=-1)
            # NOTE(review): the projection/normalize steps need no gradients
            # either way, so keeping them under no_grad is equivalent.
            hf_embeddings.append(token_emb.squeeze(0).float())

    # Compare embeddings
    for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)):
        vllm_emb = torch.tensor(vllm_out).float()

        # Print first few components for debugging
        print(f"\n=== Text {i}: '{test_texts[i][:30]}...' ===")
        print(f"HF shape: {hf_emb.shape}, vLLM shape: {vllm_emb.shape}")
        print(f"HF first token, first 10 dims: {hf_emb[0, :10].tolist()}")
        print(f"vLLM first token, first 10 dims: {vllm_emb[0, :10].tolist()}")
        print(f"HF last token, first 10 dims: {hf_emb[-1, :10].tolist()}")
        print(f"vLLM last token, first 10 dims: {vllm_emb[-1, :10].tolist()}")

        # Should have same shape
        assert hf_emb.shape == vllm_emb.shape, (
            f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}"
        )

        # Should have same values (with tolerance for fp16)
        torch.testing.assert_close(
            vllm_emb,
            hf_emb,
            rtol=1e-2,
            atol=1e-2,
            msg=f"Embedding mismatch for text {i}",
        )
|
||||
@@ -520,6 +520,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
_EMBEDDING_EXAMPLE_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
|
||||
"HF_ColBERT": _HfExamplesInfo("answerdotai/answerai-colbert-small-v1"),
|
||||
"BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"),
|
||||
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
|
||||
"Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),
|
||||
|
||||
@@ -1411,6 +1411,11 @@ class ModelConfig:
|
||||
self._model_info.supports_cross_encoding or self.convert_type == "classify"
|
||||
)
|
||||
|
||||
@property
|
||||
def is_late_interaction(self) -> bool:
|
||||
"""Check if model uses late interaction (ColBERT-style) scoring."""
|
||||
return self._model_info.supports_late_interaction
|
||||
|
||||
@property
|
||||
def is_pp_supported(self) -> bool:
|
||||
return self._model_info.supports_pp
|
||||
|
||||
@@ -44,6 +44,7 @@ from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreMultiModalParam,
|
||||
_cosine_similarity,
|
||||
compress_token_type_ids,
|
||||
compute_maxsim_score,
|
||||
get_score_prompt,
|
||||
validate_score_input,
|
||||
)
|
||||
@@ -1368,6 +1369,87 @@ class LLM:
|
||||
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
|
||||
return [ScoringRequestOutput.from_base(item) for item in items]
|
||||
|
||||
def _late_interaction_score(
    self,
    data_1: list[ScoreData],
    data_2: list[ScoreData],
    *,
    use_tqdm: bool | Callable[..., tqdm],
    pooling_params: PoolingParams | None,
    lora_request: list[LoRARequest] | LoRARequest | None,
    tokenization_kwargs: dict[str, Any],
) -> list[ScoringRequestOutput]:
    """
    Late interaction scoring (ColBERT MaxSim).

    Encodes queries and documents into per-token embeddings, then computes
    MaxSim: sum over query tokens of max similarity to any document token.

    Args:
        data_1: Queries (text only; multimodal input is rejected).
        data_2: Documents (text only; multimodal input is rejected).
        use_tqdm: Progress-bar flag/factory forwarded to ``encode``.
        pooling_params: Optional pooling params forwarded to ``encode``.
        lora_request: Optional LoRA request(s) forwarded to ``encode``.
        tokenization_kwargs: Tokenizer kwargs forwarded to ``encode``.

    Returns:
        One ``ScoringRequestOutput`` per (query, document) pair.

    Raises:
        NotImplementedError: If any input is not a plain string.
    """
    from vllm.outputs import PoolingOutput

    def _as_texts(data: list[ScoreData]) -> list[str]:
        # Shared validation: late interaction only supports plain text.
        texts: list[str] = []
        for text in data:
            if not isinstance(text, str):
                raise NotImplementedError(
                    "Late interaction scores currently do not support multimodal input."
                )
            texts.append(text)
        return texts

    tokenizer = self.get_tokenizer()
    text_1 = _as_texts(data_1)
    text_2 = _as_texts(data_2)

    # Encode queries and documents together in one batch, then split.
    encoded_output: list[PoolingRequestOutput] = self.encode(
        text_1 + text_2,
        use_tqdm=use_tqdm,
        lora_request=lora_request,
        pooling_params=pooling_params,
        pooling_task="token_embed",
        tokenization_kwargs=tokenization_kwargs,
    )

    encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
    encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]

    # Broadcast a single query across all documents (1:N scoring).
    if len(encoded_output_1) == 1:
        encoded_output_1 = encoded_output_1 * len(encoded_output_2)

    # Join query/document token ids with the pad token (when defined) so the
    # reported prompt reflects both inputs of the pair.
    padding: list[int] = []
    if (pad_token_id := tokenizer.pad_token_id) is not None:
        padding = [pad_token_id]

    # Compute MaxSim scores
    scores: list[PoolingRequestOutput] = []
    for emb_1, emb_2 in zip(encoded_output_1, encoded_output_2):
        # emb_1.outputs.data: [query_len, dim]
        # emb_2.outputs.data: [doc_len, dim]
        maxsim_score = compute_maxsim_score(emb_1.outputs.data, emb_2.outputs.data)

        tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids

        scores.append(
            PoolingRequestOutput(
                request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                outputs=PoolingOutput(data=maxsim_score),
                prompt_token_ids=tokens,
                num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
                finished=True,
            )
        )

    items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
    return [ScoringRequestOutput.from_base(item) for item in items]
|
||||
|
||||
def _cross_encoding_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
@@ -1497,7 +1579,11 @@ class LLM:
|
||||
)
|
||||
|
||||
supported_tasks = self.supported_tasks
|
||||
if all(t not in supported_tasks for t in ("embed", "classify")):
|
||||
# Late interaction models (e.g., ColBERT) use token_embed for scoring
|
||||
is_late_interaction = model_config.is_late_interaction
|
||||
if not is_late_interaction and all(
|
||||
t not in supported_tasks for t in ("embed", "classify")
|
||||
):
|
||||
raise ValueError(
|
||||
"Score API is not supported by this model. "
|
||||
"Try converting the model using "
|
||||
@@ -1538,6 +1624,15 @@ class LLM:
|
||||
tokenization_kwargs=encode_kwargs,
|
||||
score_template=chat_template,
|
||||
)
|
||||
elif is_late_interaction:
|
||||
return self._late_interaction_score(
|
||||
score_data_1,
|
||||
score_data_2,
|
||||
use_tqdm=use_tqdm,
|
||||
pooling_params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=encode_kwargs,
|
||||
)
|
||||
else:
|
||||
return self._embedding_score(
|
||||
score_data_1,
|
||||
|
||||
@@ -37,7 +37,11 @@ def register_pooling_api_routers(
|
||||
|
||||
app.include_router(embed_router)
|
||||
|
||||
if "score" in supported_tasks or "embed" in supported_tasks:
|
||||
# Score/rerank endpoints are available for:
|
||||
# - "score" task (cross-encoder models)
|
||||
# - "embed" task (bi-encoder models)
|
||||
# - "token_embed" task (late interaction models like ColBERT)
|
||||
if any(t in supported_tasks for t in ("score", "embed", "token_embed")):
|
||||
from vllm.entrypoints.pooling.score.api_router import router as score_router
|
||||
|
||||
app.include_router(score_router)
|
||||
@@ -101,6 +105,10 @@ def init_pooling_state(
|
||||
if "classify" in supported_tasks
|
||||
else None
|
||||
)
|
||||
# ServingScores handles score/rerank for:
|
||||
# - "score" task (cross-encoder models)
|
||||
# - "embed" task (bi-encoder models)
|
||||
# - "token_embed" task (late interaction models like ColBERT)
|
||||
state.openai_serving_scores = (
|
||||
ServingScores(
|
||||
engine_client,
|
||||
@@ -109,6 +117,6 @@ def init_pooling_state(
|
||||
score_template=resolved_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if ("embed" in supported_tasks or "score" in supported_tasks)
|
||||
if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -31,6 +31,7 @@ from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreInputs,
|
||||
_cosine_similarity,
|
||||
compress_token_type_ids,
|
||||
compute_maxsim_score,
|
||||
get_score_prompt,
|
||||
validate_score_input,
|
||||
)
|
||||
@@ -68,9 +69,12 @@ class ServingScores(OpenAIServing):
|
||||
self.is_cross_encoder = self.model_config.is_cross_encoder
|
||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||
self.architecture = self.model_config.architecture
|
||||
self.is_late_interaction = self.model_config.is_late_interaction
|
||||
|
||||
if self.is_cross_encoder:
|
||||
self._score_func = self._cross_encoding_score
|
||||
elif self.is_late_interaction:
|
||||
self._score_func = self._late_interaction_score
|
||||
else:
|
||||
self._score_func = self._embedding_score
|
||||
|
||||
@@ -172,6 +176,142 @@ class ServingScores(OpenAIServing):
|
||||
|
||||
return final_res_batch
|
||||
|
||||
async def _late_interaction_score(
    self,
    data_1: list[ScoreData],
    data_2: list[ScoreData],
    request: RerankRequest | ScoreRequest,
    request_id: str,
    lora_request: LoRARequest | None = None,
    trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
    """
    Late interaction scoring (ColBERT MaxSim).

    Encodes queries and documents into per-token embeddings, then computes
    MaxSim: sum over query tokens of max similarity to any document token.
    """
    # Only plain-text inputs are supported for late interaction scoring.
    # NOTE(review): invalid pooling params return an ErrorResponse below,
    # but multimodal input raises instead — confirm the caller handles both.
    input_texts: list[str] = []
    for text in data_1 + data_2:
        if not isinstance(text, str):
            raise NotImplementedError(
                "Late interaction scores currently do not support multimodal input."
            )
        input_texts.append(text)

    model_config = self.model_config
    tokenizer = self.renderer.get_tokenizer()

    # Tokenize in the executor so blocking work stays off the event loop.
    encode_async = make_async(
        tokenizer.encode,
        executor=self._tokenizer_executor,
    )

    tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
    tokenized_prompts = await asyncio.gather(
        *(encode_async(t, **tokenization_kwargs) for t in input_texts)
    )

    # Validate each tokenization result and wrap it as an engine prompt.
    engine_prompts: list[TokensPrompt] = []
    for tok_result, input_text in zip(tokenized_prompts, input_texts):
        text_token_prompt = self._validate_input(request, tok_result, input_text)

        engine_prompts.append(
            TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
        )

    # Schedule the request and get the result generator.
    generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []

    # Use token_embed task for late interaction models
    from vllm import PoolingParams

    pooling_params = PoolingParams(
        task="token_embed",
        truncate_prompt_tokens=request.truncate_prompt_tokens,
        use_activation=request.use_activation,
    )

    try:
        pooling_params.verify("token_embed", self.model_config)
    except ValueError as e:
        return self.create_error_response(str(e))

    for i, engine_prompt in enumerate(engine_prompts):
        # One sub-request per input; suffix keeps ids unique and ordered.
        request_id_item = f"{request_id}-{i}"

        self._log_inputs(
            request_id_item,
            input_texts[i],
            params=pooling_params,
            lora_request=lora_request,
        )

        generators.append(
            self.engine_client.encode(
                engine_prompt,
                pooling_params,
                request_id_item,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )
        )

    result_generator = merge_async_iterators(*generators)

    # Collect token embeddings; results may arrive out of order, so place
    # each by its generator index.
    embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts)

    async for i, res in result_generator:
        embeddings[i] = res

    # Split into query and document embeddings (queries were sent first).
    emb_data_1: list[PoolingRequestOutput] = []
    emb_data_2: list[PoolingRequestOutput] = []

    for i in range(0, len(data_1)):
        assert (emb := embeddings[i]) is not None
        emb_data_1.append(emb)

    for i in range(len(data_1), len(embeddings)):
        assert (emb := embeddings[i]) is not None
        emb_data_2.append(emb)

    # Expand queries if 1:N scoring
    if len(emb_data_1) == 1:
        emb_data_1 = emb_data_1 * len(emb_data_2)

    # Compute MaxSim scores
    from vllm.outputs import PoolingOutput

    scores: list[PoolingRequestOutput] = []
    # Join query/document token ids with the pad token (when defined).
    padding: list[int] = []
    if (pad_token_id := tokenizer.pad_token_id) is not None:
        padding = [pad_token_id]

    for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
        # emb_1.outputs.data: [query_len, dim]
        # emb_2.outputs.data: [doc_len, dim]
        q_emb = emb_1.outputs.data
        d_emb = emb_2.outputs.data

        maxsim_score = compute_maxsim_score(q_emb, d_emb)

        tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids

        scores.append(
            PoolingRequestOutput(
                request_id=f"{emb_1.request_id}_{emb_2.request_id}",
                outputs=PoolingOutput(data=maxsim_score),
                prompt_token_ids=tokens,
                num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
                finished=True,
            )
        )

    return scores
|
||||
|
||||
async def _cross_encoding_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, TypeAlias, cast
|
||||
|
||||
import torch
|
||||
from torch.nn import CosineSimilarity
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
@@ -34,6 +35,23 @@ ScoreContentPartParam: TypeAlias = (
|
||||
)
|
||||
|
||||
|
||||
def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
    """
    Compute the ColBERT MaxSim relevance score between a query and a document.

    Each query token is matched against its most similar document token
    (dot-product similarity), and the per-token maxima are summed.

    Args:
        q_emb: Query token embeddings [query_len, dim]
        d_emb: Document token embeddings [doc_len, dim]

    Returns:
        MaxSim score (sum over query tokens of max similarity to any doc token)
    """
    # Pairwise dot products between every query token and every document
    # token: shape [query_len, doc_len].
    similarity = q_emb @ d_emb.transpose(0, 1)
    # For each query token, keep only its best-matching document token,
    # then aggregate over all query tokens.
    best_per_query_token = similarity.max(dim=-1).values
    return best_per_query_token.sum()
|
||||
|
||||
|
||||
class ScoreMultiModalParam(TypedDict, total=False):
|
||||
"""
|
||||
A specialized parameter type for scoring multimodal content
|
||||
|
||||
152
vllm/model_executor/models/colbert.py
Normal file
152
vllm/model_executor/models/colbert.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
ColBERT late interaction model for retrieval and reranking.
|
||||
|
||||
ColBERT uses per-token embeddings and late interaction (MaxSim) scoring
|
||||
instead of single-vector representations or cross-encoder concatenation.
|
||||
|
||||
Reference: https://arxiv.org/abs/2004.12832
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import PoolerConfig, VllmConfig
|
||||
from vllm.model_executor.layers.pooler import Pooler
|
||||
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
|
||||
|
||||
from .bert import BertEmbeddingModel, BertModel
|
||||
from .interfaces_base import default_pooling_type
|
||||
|
||||
|
||||
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColBERTModel(BertEmbeddingModel):
    """ColBERT late interaction model for retrieval/reranking.

    This model extends BertEmbeddingModel with a ColBERT-style linear
    projection layer for per-token embeddings. It supports only:
    - "token_embed" task: Per-token embeddings for late interaction

    ColBERT is fundamentally a per-token embedding model - the linear
    projection is trained for per-token representations, not for CLS
    pooling. Use a dedicated dense embedding model if you need single-
    vector representations.

    The ColBERT scoring (MaxSim) is computed externally, either client-side
    or via the late interaction scoring path in ServingScores.

    Attributes:
        colbert_linear: Linear projection from hidden_size to colbert_dim
        supports_late_interaction: Flag indicating this model uses late
            interaction scoring
    """

    # Mark this model as supporting late interaction scoring
    supports_late_interaction: ClassVar[Literal[True]] = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Read config *before* calling super().__init__, because the parent
        # constructor invokes _build_pooler, which needs these attributes.
        config = vllm_config.model_config.hf_config
        self.hidden_size = config.hidden_size
        self.head_dtype = vllm_config.model_config.head_dtype

        # ColBERT dimension - check various config field names used by different
        # ColBERT implementations. If not found in config, it will be inferred
        # from loaded weights in load_weights().
        self.colbert_dim: int | None = (
            getattr(config, "colbert_dim", None)
            or getattr(config, "dim", None)
            or getattr(config, "projection_dim", None)
        )

        # Initialize parent (this will call _build_pooler)
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
        # Plain BERT encoder; the ColBERT projection lives in the pooler.
        return BertModel(vllm_config=vllm_config, prefix=prefix)

    def _build_colbert_linear(self) -> nn.Linear:
        """Build the ColBERT linear projection layer (hidden_size -> colbert_dim)."""
        if self.colbert_dim is None:
            raise ValueError("colbert_dim must be set before building the linear layer")
        # Original ColBERT uses bias=False.
        return nn.Linear(
            self.hidden_size,
            self.colbert_dim,
            bias=False,
            dtype=self.head_dtype,
        )

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        # ColBERT linear projection: hidden_size -> colbert_dim.
        # If colbert_dim is not set from config, the layer is created lazily
        # in load_weights() once the checkpoint reveals the dimension.
        if self.colbert_dim is not None:
            self.colbert_linear = self._build_colbert_linear()
        else:
            # Placeholder - will be created when weights are loaded
            self.colbert_linear = None

        # ColBERT only supports token_embed - it's fundamentally a per-token
        # embedding model.
        return pooler_for_token_embed(
            pooler_config,
            projector=self.colbert_linear,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights, splitting them between the BERT encoder
        and the ColBERT projection, inferring ``colbert_dim`` if needed.

        Returns:
            The set of parameter names that were loaded.
        """

        def _strip(name: str) -> str:
            # Drop common checkpoint prefixes ("model." / "bert.").
            for p in ("model.", "bert."):
                if name.startswith(p):
                    name = name[len(p) :]
            return name

        weights_list = list(weights)
        model_side: list[tuple[str, torch.Tensor]] = []
        colbert_side: list[tuple[str, torch.Tensor]] = []

        for name, weight in weights_list:
            stripped = _strip(name)
            # Handle different checkpoint naming conventions for ColBERT linear
            if stripped in ("linear.weight", "colbert_linear.weight"):
                colbert_side.append(("colbert_linear.weight", weight))
            elif stripped.startswith(("linear.", "colbert_linear.")):
                # Rename by prefix only. A plain
                # stripped.replace("linear.", "colbert_linear.") would corrupt
                # names that already start with "colbert_linear." (the embedded
                # "linear." also matches, yielding "colbert_colbert_linear.*").
                suffix = stripped.split(".", 1)[1]
                colbert_side.append((f"colbert_linear.{suffix}", weight))
            else:
                model_side.append((stripped, weight))

        # Load base BERT weights using BertModel.load_weights which handles QKV fusion
        loaded: set[str] = set()
        loaded_model = self.model.load_weights(model_side)
        loaded.update({"model." + n for n in loaded_model})

        # Load ColBERT linear weights
        if colbert_side:
            for name, weight in colbert_side:
                if name == "colbert_linear.weight":
                    # Infer colbert_dim from weights if not set in config
                    if self.colbert_dim is None:
                        # Weight shape is [colbert_dim, hidden_size]
                        self.colbert_dim = weight.shape[0]
                        # Create the linear layer now that we know the dimension
                        self.colbert_linear = self._build_colbert_linear()
                        # Move to the same device as the model's existing parameters
                        device = next(self.model.parameters()).device
                        self.colbert_linear.to(device)
                        # Update the pooler's projector to use the new linear layer
                        self.pooler.head.projector = self.colbert_linear

                    # Load weights directly into the pooler's projector
                    weight = weight.to(self.pooler.head.projector.weight.device)
                    self.pooler.head.projector.weight.data.copy_(weight)
                    loaded.add("pooler.head.projector.weight")
                    # Only the projection weight is consumed (bias=False);
                    # any remaining colbert_side tensors are intentionally
                    # skipped.
                    break

        return loaded
|
||||
@@ -981,6 +981,40 @@ def supports_cross_encoding(
|
||||
return is_pooling_model(model) and _supports_cross_encoding(model)
|
||||
|
||||
|
||||
@runtime_checkable
class SupportsLateInteraction(Protocol):
    """The interface required for all models that support late interaction.

    Late interaction models (like ColBERT) encode queries and documents
    separately into per-token embeddings, then compute similarity via
    MaxSim (max over document tokens, sum over query tokens).
    """

    # Opt-in flag: a model class declares this (truthy) attribute to be
    # recognized as a late interaction model by supports_late_interaction().
    supports_late_interaction: ClassVar[Literal[True]] = True
|
||||
|
||||
|
||||
@overload
def supports_late_interaction(
    model: type[object],
) -> TypeIs[type[SupportsLateInteraction]]: ...


@overload
def supports_late_interaction(model: object) -> TypeIs[SupportsLateInteraction]: ...


def _supports_late_interaction(
    model: type[object] | object,
) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]:
    # Duck-typed check: any class or instance exposing a truthy
    # `supports_late_interaction` attribute qualifies.
    return getattr(model, "supports_late_interaction", False)


def supports_late_interaction(
    model: type[object] | object,
) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]:
    """Return whether *model* (a class or an instance) supports late
    interaction (per-token embedding) scoring, e.g. ColBERT-style MaxSim.

    A late interaction model must also be a pooling model.
    """
    return is_pooling_model(model) and _supports_late_interaction(model)
|
||||
|
||||
|
||||
class SupportsQuant:
|
||||
"""The interface required for all models that support quantization."""
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ from .interfaces import (
|
||||
is_hybrid,
|
||||
requires_raw_input_tokens,
|
||||
supports_cross_encoding,
|
||||
supports_late_interaction,
|
||||
supports_mamba_prefix_caching,
|
||||
supports_multimodal,
|
||||
supports_multimodal_encoder_tp_data,
|
||||
@@ -205,6 +206,7 @@ _EMBEDDING_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||
"BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
|
||||
"HF_ColBERT": ("colbert", "ColBERTModel"),
|
||||
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"Gemma3TextModel": ("gemma3", "Gemma3Model"),
|
||||
@@ -593,6 +595,7 @@ class _ModelInfo:
|
||||
default_seq_pooling_type: SequencePoolingType
|
||||
default_tok_pooling_type: TokenPoolingType
|
||||
supports_cross_encoding: bool
|
||||
supports_late_interaction: bool
|
||||
supports_multimodal: bool
|
||||
supports_multimodal_raw_input_only: bool
|
||||
requires_raw_input_tokens: bool
|
||||
@@ -616,6 +619,7 @@ class _ModelInfo:
|
||||
default_tok_pooling_type=get_default_tok_pooling_type(model),
|
||||
attn_type=get_attn_type(model),
|
||||
supports_cross_encoding=supports_cross_encoding(model),
|
||||
supports_late_interaction=supports_late_interaction(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
|
||||
model
|
||||
|
||||
Reference in New Issue
Block a user