[CI] Fix ColBERT HF comparison tests on AMD CI + refactor (#34567)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-02-20 22:12:05 -06:00
committed by GitHub
parent a0fe7ea2f0
commit 89358f0d35

View File

@@ -20,6 +20,12 @@ COLBERT_MODELS = {
"colbert_dim": 96,
"max_model_len": 512,
"extra_kwargs": {},
"hf_comparison": {
"weights_file": "model.safetensors",
"weights_key": "linear.weight",
"trust_remote_code": False,
"model_cls": "BertModel",
},
},
"modernbert": {
"model": "lightonai/GTE-ModernColBERT-v1",
@@ -30,6 +36,12 @@ COLBERT_MODELS = {
"architectures": ["ColBERTModernBertModel"],
},
},
"hf_comparison": {
"weights_file": "1_Dense/model.safetensors",
"weights_key": "linear.weight",
"trust_remote_code": False,
"model_cls": "AutoModel",
},
},
"jina": {
"model": "jinaai/jina-colbert-v2",
@@ -40,9 +52,16 @@ COLBERT_MODELS = {
"architectures": ["ColBERTJinaRobertaModel"],
},
},
"hf_comparison": {
"weights_file": "model.safetensors",
"weights_key": "linear.weight",
"trust_remote_code": True,
"model_cls": "AutoModel",
},
},
}
TEXTS_1 = [
"What is the capital of France?",
"What is the capital of Germany?",
@@ -56,9 +75,68 @@ TEXTS_2 = [
DTYPE = "half"
# -----------------------------------------------------------------------
# Fixtures
# -----------------------------------------------------------------------
def _load_hf_model(model_name: str, hf_spec: dict, device: torch.device):
"""Load HF model on the given device with a compatible attention impl."""
from transformers import AutoModel, BertModel
cls = BertModel if hf_spec["model_cls"] == "BertModel" else AutoModel
trust = hf_spec.get("trust_remote_code", False)
# Flash / Triton kernels require GPU tensors; fall back to eager on CPU.
extra = {}
if device.type == "cpu":
extra["attn_implementation"] = "eager"
model = cls.from_pretrained(
model_name,
trust_remote_code=trust,
**extra,
).to(device)
model.eval()
return model
def _load_projection_weight(model_name: str, hf_spec: dict, device: torch.device):
"""Download and return the ColBERT linear projection weight."""
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
path = hf_hub_download(model_name, filename=hf_spec["weights_file"])
weights = load_file(path)
return weights[hf_spec["weights_key"]].to(device)
def _compute_hf_colbert_embeddings(model, tokenizer, linear_weight, texts, device):
"""Run HF model + projection and return L2-normalised token embeddings."""
import torch.nn.functional as F
embeddings = []
for text in texts:
inputs = tokenizer(text, return_tensors="pt").to(device)
with torch.no_grad():
hidden = model(**inputs).last_hidden_state.float()
projected = F.linear(hidden, linear_weight.float())
normalised = F.normalize(projected, p=2, dim=-1)
embeddings.append(normalised.squeeze(0).cpu())
return embeddings
def _assert_embeddings_close(vllm_outputs, hf_embeddings):
"""Assert that vLLM and HuggingFace embeddings match."""
for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)):
vllm_emb = torch.as_tensor(vllm_out).float()
assert hf_emb.shape == vllm_emb.shape, (
f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}"
)
torch.testing.assert_close(
vllm_emb,
hf_emb,
rtol=1e-2,
atol=1e-2,
msg=f"Embedding mismatch for text {i}",
)
@pytest.fixture(params=list(COLBERT_MODELS.keys()), scope="module")
@@ -87,11 +165,6 @@ def colbert_extra_kwargs(colbert_spec):
return colbert_spec["extra_kwargs"]
# -----------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------
def test_colbert_token_embed(
vllm_runner,
colbert_model_name,
@@ -111,7 +184,7 @@ def test_colbert_token_embed(
outputs = vllm_model.token_embed([TEXTS_1[0]])
assert len(outputs) == 1
emb = torch.tensor(outputs[0])
emb = torch.as_tensor(outputs[0])
assert emb.dim() == 2
assert emb.shape[1] == colbert_dim
assert emb.shape[0] > 1
@@ -135,8 +208,8 @@ def test_colbert_late_interaction_1_to_1(
q_outputs = vllm_model.token_embed([TEXTS_1[0]])
d_outputs = vllm_model.token_embed([TEXTS_2[0]])
q_emb = torch.tensor(q_outputs[0])
d_emb = torch.tensor(d_outputs[0])
q_emb = torch.as_tensor(q_outputs[0])
d_emb = torch.as_tensor(d_outputs[0])
manual_score = compute_maxsim_score(q_emb, d_emb).item()
@@ -164,11 +237,11 @@ def test_colbert_late_interaction_1_to_N(
q_outputs = vllm_model.token_embed([TEXTS_1[0]])
d_outputs = vllm_model.token_embed(TEXTS_2)
q_emb = torch.tensor(q_outputs[0])
q_emb = torch.as_tensor(q_outputs[0])
manual_scores = []
for d_out in d_outputs:
d_emb = torch.tensor(d_out)
d_emb = torch.as_tensor(d_out)
manual_scores.append(compute_maxsim_score(q_emb, d_emb).item())
vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2)
@@ -198,8 +271,8 @@ def test_colbert_late_interaction_N_to_N(
manual_scores = []
for q_out, d_out in zip(q_outputs, d_outputs):
q_emb = torch.tensor(q_out)
d_emb = torch.tensor(d_out)
q_emb = torch.as_tensor(q_out)
d_emb = torch.as_tensor(d_out)
manual_scores.append(compute_maxsim_score(q_emb, d_emb).item())
vllm_scores = vllm_model.score(TEXTS_1, TEXTS_2)
@@ -259,79 +332,16 @@ def test_colbert_embed_not_supported(
vllm_model.embed([TEXTS_1[0]])
# -----------------------------------------------------------------------
# Per-model HuggingFace comparison tests
# -----------------------------------------------------------------------
@pytest.mark.parametrize("backend", list(COLBERT_MODELS.keys()))
def test_colbert_hf_comparison(vllm_runner, backend):
"""Test that vLLM ColBERT embeddings match HuggingFace for each backend."""
from transformers import AutoTokenizer
def _assert_embeddings_close(vllm_outputs, hf_embeddings):
"""Assert that vLLM and HuggingFace embeddings match."""
for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)):
vllm_emb = torch.tensor(vllm_out).float()
assert hf_emb.shape == vllm_emb.shape, (
f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}"
)
torch.testing.assert_close(
vllm_emb,
hf_emb,
rtol=1e-2,
atol=1e-2,
msg=f"Embedding mismatch for text {i}",
)
def test_colbert_hf_comparison_bert(vllm_runner):
"""Test that vLLM ColBERT produces same embeddings as HuggingFace (BERT)."""
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, BertModel
model_name = COLBERT_MODELS["bert"]["model"]
test_texts = [TEXTS_1[0], TEXTS_2[0]]
with vllm_runner(
model_name,
runner="pooling",
dtype="float32",
max_model_len=512,
enforce_eager=True,
) as vllm_model:
vllm_outputs = vllm_model.token_embed(test_texts)
hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_bert = BertModel.from_pretrained(model_name)
hf_bert.eval()
weights_path = hf_hub_download(model_name, filename="model.safetensors")
weights = load_file(weights_path)
linear_weight = weights["linear.weight"] # [96, 384]
hf_embeddings = []
for text in test_texts:
inputs = hf_tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = hf_bert(**inputs)
hidden_states = outputs.last_hidden_state
token_emb = F.linear(hidden_states, linear_weight)
token_emb = F.normalize(token_emb, p=2, dim=-1)
hf_embeddings.append(token_emb.squeeze(0).float())
_assert_embeddings_close(vllm_outputs, hf_embeddings)
def test_colbert_hf_comparison_modernbert(vllm_runner):
"""Test that vLLM ColBERT produces same embeddings as HuggingFace
(ModernBERT)."""
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
spec = COLBERT_MODELS["modernbert"]
spec = COLBERT_MODELS[backend]
hf_spec = spec["hf_comparison"]
model_name = spec["model"]
assert isinstance(model_name, str)
assert isinstance(hf_spec, dict)
test_texts = [TEXTS_1[0], TEXTS_2[0]]
with vllm_runner(
@@ -344,73 +354,21 @@ def test_colbert_hf_comparison_modernbert(vllm_runner):
) as vllm_model:
vllm_outputs = vllm_model.token_embed(test_texts)
hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModel.from_pretrained(model_name)
hf_model.eval()
# Load projection from sentence-transformers 1_Dense layer
dense_path = hf_hub_download(model_name, filename="1_Dense/model.safetensors")
dense_weights = load_file(dense_path)
linear_weight = dense_weights["linear.weight"] # [128, 768]
hf_embeddings = []
for text in test_texts:
inputs = hf_tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = hf_model(**inputs)
hidden_states = outputs.last_hidden_state
token_emb = F.linear(hidden_states, linear_weight)
token_emb = F.normalize(token_emb, p=2, dim=-1)
hf_embeddings.append(token_emb.squeeze(0).float())
_assert_embeddings_close(vllm_outputs, hf_embeddings)
def test_colbert_hf_comparison_jina(vllm_runner):
"""Test that vLLM ColBERT produces same embeddings as HuggingFace
(Jina XLM-RoBERTa)."""
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
spec = COLBERT_MODELS["jina"]
model_name = spec["model"]
test_texts = [TEXTS_1[0], TEXTS_2[0]]
with vllm_runner(
model_name,
runner="pooling",
dtype="float32",
max_model_len=spec["max_model_len"],
enforce_eager=True,
**spec["extra_kwargs"],
) as vllm_model:
vllm_outputs = vllm_model.token_embed(test_texts)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hf_tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
trust_remote_code=hf_spec.get("trust_remote_code", False),
)
hf_model = AutoModel.from_pretrained(
model_name,
trust_remote_code=True,
hf_model = _load_hf_model(model_name, hf_spec, device)
linear_weight = _load_projection_weight(model_name, hf_spec, device)
hf_embeddings = _compute_hf_colbert_embeddings(
hf_model,
hf_tokenizer,
linear_weight,
test_texts,
device,
)
hf_model.eval()
# Load projection from main checkpoint
weights_path = hf_hub_download(model_name, filename="model.safetensors")
weights = load_file(weights_path)
linear_weight = weights["linear.weight"] # [128, 1024]
hf_embeddings = []
for text in test_texts:
inputs = hf_tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = hf_model(**inputs)
hidden_states = outputs.last_hidden_state
token_emb = F.linear(hidden_states.float(), linear_weight.float())
token_emb = F.normalize(token_emb, p=2, dim=-1)
hf_embeddings.append(token_emb.squeeze(0).float())
_assert_embeddings_close(vllm_outputs, hf_embeddings)