Signed-off-by: Ilya Boytsov <ilyaboytsov1805@gmail.com> Signed-off-by: Ilya Boytsov <boytsovpanamera@mail.ru> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
153 lines
6.4 KiB
Python
153 lines
6.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""
|
|
ColBERT late interaction model for retrieval and reranking.
|
|
|
|
ColBERT uses per-token embeddings and late interaction (MaxSim) scoring
|
|
instead of single-vector representations or cross-encoder concatenation.
|
|
|
|
Reference: https://arxiv.org/abs/2004.12832
|
|
"""
|
|
|
|
from collections.abc import Iterable
|
|
from typing import ClassVar, Literal
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from vllm.config import PoolerConfig, VllmConfig
|
|
from vllm.model_executor.layers.pooler import Pooler
|
|
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
|
|
|
|
from .bert import BertEmbeddingModel, BertModel
|
|
from .interfaces_base import default_pooling_type
|
|
|
|
|
|
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColBERTModel(BertEmbeddingModel):
    """ColBERT late interaction model for retrieval/reranking.

    This model extends BertEmbeddingModel with a ColBERT-style linear
    projection layer for per-token embeddings. It supports only:
    - "token_embed" task: Per-token embeddings for late interaction

    ColBERT is fundamentally a per-token embedding model - the linear
    projection is trained for per-token representations, not for CLS
    pooling. Use a dedicated dense embedding model if you need single-
    vector representations.

    The ColBERT scoring (MaxSim) is computed externally, either client-side
    or via the late interaction scoring path in ServingScores.

    Attributes:
        hidden_size: ``hidden_size`` read from the HF config; input width
            of the projection.
        head_dtype: ``model_config.head_dtype``; dtype used for the
            projection layer.
        colbert_dim: Output width of the projection, or ``None`` until it
            is inferred from the checkpoint in ``load_weights``.
        colbert_linear: Linear projection from hidden_size to colbert_dim,
            or ``None`` while ``colbert_dim`` is still unknown.
        supports_late_interaction: Flag indicating this model uses late
            interaction scoring.
    """

    # Mark this model as supporting late interaction scoring
    supports_late_interaction: ClassVar[Literal[True]] = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Read projection settings from the config, then build the model.

        Args:
            vllm_config: Engine configuration; ``model_config.hf_config``
                and ``model_config.head_dtype`` are read from it.
            prefix: Passed through unchanged to the parent constructor.
        """
        # These attributes must exist before super().__init__ runs, because
        # the parent constructor is what triggers _build_pooler.
        config = vllm_config.model_config.hf_config
        self.hidden_size = config.hidden_size
        self.head_dtype = vllm_config.model_config.head_dtype

        # ColBERT dimension - check various config field names used by
        # different ColBERT implementations. If none is present (or truthy),
        # the dimension is inferred from the checkpoint in load_weights().
        self.colbert_dim: int | None = (
            getattr(config, "colbert_dim", None)
            or getattr(config, "dim", None)
            or getattr(config, "projection_dim", None)
        )

        # Initialize parent (this will call _build_pooler)
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
        """Return the BertModel backbone built from ``vllm_config``."""
        return BertModel(vllm_config=vllm_config, prefix=prefix)

    def _build_colbert_linear(self) -> nn.Linear:
        """Build the ColBERT linear projection layer.

        Returns:
            A bias-free ``nn.Linear(hidden_size, colbert_dim)`` created in
            ``head_dtype``.

        Raises:
            ValueError: If ``colbert_dim`` has not been determined yet.
        """
        if self.colbert_dim is None:
            raise ValueError("colbert_dim must be set before building the linear layer")
        return nn.Linear(
            self.hidden_size,
            self.colbert_dim,
            bias=False,
            dtype=self.head_dtype,
        )

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        """Create the token-embedding pooler, with the projection if known.

        Args:
            pooler_config: Pooler configuration forwarded to
                ``pooler_for_token_embed``.

        Returns:
            The pooler returned by ``pooler_for_token_embed``. Its projector
            is ``None`` when ``colbert_dim`` is not yet known; in that case
            ``load_weights`` installs the projector later.
        """
        # ColBERT linear projection: hidden_size -> colbert_dim
        # Original ColBERT uses bias=False
        if self.colbert_dim is not None:
            self.colbert_linear = self._build_colbert_linear()
        else:
            # Placeholder - will be created when weights are loaded
            self.colbert_linear = None

        # ColBERT only supports token_embed - it's fundamentally a per-token
        # embedding model.
        return pooler_for_token_embed(
            pooler_config,
            projector=self.colbert_linear,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load backbone and projection weights from a checkpoint stream.

        Checkpoint names may carry a leading ``model.`` and/or ``bert.``
        prefix, and the projection may be stored as either ``linear.*`` or
        ``colbert_linear.*``. Only the projection ``weight`` is consumed
        (the layer has no bias); any other entry in the projection namespace
        is ignored. If several projection weights are present, the first one
        encountered wins.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs. It is consumed
                exactly once.

        Returns:
            Names of the parameters that were loaded: backbone names
            prefixed with ``model.``, plus ``pooler.head.projector.weight``
            when a projection weight was found.

        Raises:
            ValueError: If the checkpoint projection weight does not have
                the same shape as the projector parameter.
        """

        def _strip(name: str) -> str:
            # Prefixes are removed in order, so "model.bert.x" becomes "x".
            for p in ("model.", "bert."):
                if name.startswith(p):
                    name = name[len(p) :]
            return name

        def _projection_name(name: str) -> str | None:
            # Map a stripped name onto the "colbert_linear." namespace, or
            # return None when it is not a projection entry. Only the
            # leading prefix is rewritten, so names that are already
            # canonical are returned unchanged.
            if name.startswith("colbert_linear."):
                return name
            if name.startswith("linear."):
                return "colbert_" + name
            return None

        model_side: list[tuple[str, torch.Tensor]] = []
        projection_weight: torch.Tensor | None = None

        # Single pass over the stream; no intermediate copy of the iterable.
        for name, weight in weights:
            stripped = _strip(name)
            canonical = _projection_name(stripped)
            if canonical is None:
                model_side.append((stripped, weight))
            elif canonical == "colbert_linear.weight" and projection_weight is None:
                projection_weight = weight

        # Load base BERT weights using BertModel.load_weights which handles
        # QKV fusion
        loaded: set[str] = set()
        loaded_model = self.model.load_weights(model_side)
        loaded.update({"model." + n for n in loaded_model})

        if projection_weight is None:
            return loaded

        # Infer colbert_dim from weights if not set in config
        if self.colbert_dim is None:
            # Weight shape is [colbert_dim, hidden_size]
            self.colbert_dim = projection_weight.shape[0]
            # Create the linear layer now that we know the dimension
            self.colbert_linear = self._build_colbert_linear()
            # Move to the same device as the model's existing parameters
            device = next(self.model.parameters()).device
            self.colbert_linear.to(device)
            # Update the pooler's projector to use the new linear layer
            self.pooler.head.projector = self.colbert_linear

        target = self.pooler.head.projector.weight
        # copy_ broadcasts its source, so a mismatched but broadcastable
        # tensor would otherwise load without any error.
        if tuple(projection_weight.shape) != tuple(target.shape):
            raise ValueError(
                "ColBERT projection weight has shape "
                f"{tuple(projection_weight.shape)}, expected {tuple(target.shape)}"
            )

        # Load weights directly into the pooler's projector
        target.data.copy_(projection_weight.to(target.device))
        loaded.add("pooler.head.projector.weight")

        return loaded
|