[CI] Add mteb testing for rerank models (#19344)
This commit is contained in:
@@ -1,14 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import tempfile
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
import mteb
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.models.utils import EmbedModelInfo
|
||||
from tests.models.utils import EmbedModelInfo, RerankModelInfo
|
||||
|
||||
# Most models on the STS12 task (See #17175):
|
||||
# Most embedding models on the STS12 task (See #17175):
|
||||
# - Model implementation and minor changes in tensor dtype
|
||||
# results in differences less than 1e-4
|
||||
# - Different model results in differences more than 1e-3
|
||||
@@ -16,6 +20,11 @@ from tests.models.utils import EmbedModelInfo
|
||||
MTEB_EMBED_TASKS = ["STS12"]
|
||||
MTEB_EMBED_TOL = 1e-4
|
||||
|
||||
# See #19344
|
||||
MTEB_RERANK_TASKS = ["NFCorpus"]
|
||||
MTEB_RERANK_LANGS = ["en"]
|
||||
MTEB_RERANK_TOL = 1e-3
|
||||
|
||||
|
||||
class VllmMtebEncoder(mteb.Encoder):
|
||||
|
||||
@@ -39,6 +48,27 @@ class VllmMtebEncoder(mteb.Encoder):
|
||||
embeds = embeds[np.argsort(r)]
|
||||
return embeds
|
||||
|
||||
def predict(
|
||||
self,
|
||||
sentences: list[tuple[str, str,
|
||||
Optional[str]]], # query, corpus, prompt
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
r = self.rng.permutation(len(sentences))
|
||||
sentences = [sentences[i] for i in r]
|
||||
|
||||
queries = [s[0] for s in sentences]
|
||||
corpus = [s[1] for s in sentences]
|
||||
|
||||
outputs = self.model.score(queries,
|
||||
corpus,
|
||||
truncate_prompt_tokens=-1,
|
||||
use_tqdm=False)
|
||||
scores = np.array(outputs)
|
||||
scores = scores[np.argsort(r)]
|
||||
return scores
|
||||
|
||||
|
||||
class OpenAIClientMtebEncoder(mteb.Encoder):
|
||||
|
||||
@@ -62,21 +92,72 @@ class OpenAIClientMtebEncoder(mteb.Encoder):
|
||||
return embeds
|
||||
|
||||
|
||||
class ScoreClientMtebEncoder(mteb.Encoder):
|
||||
|
||||
def __init__(self, model_name: str, url):
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
self.url = url
|
||||
self.rng = np.random.default_rng(seed=42)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
sentences: list[tuple[str, str,
|
||||
Optional[str]]], # query, corpus, prompt
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
r = self.rng.permutation(len(sentences))
|
||||
sentences = [sentences[i] for i in r]
|
||||
|
||||
outputs = []
|
||||
for query, corpus, prompt in sentences:
|
||||
outputs.append(self.get_score(query, corpus))
|
||||
|
||||
scores = np.array(outputs)
|
||||
scores = scores[np.argsort(r)]
|
||||
return scores
|
||||
|
||||
def get_score(self, query, corpus):
|
||||
response = requests.post(self.url,
|
||||
json={
|
||||
"model": self.model_name,
|
||||
"text_1": query,
|
||||
"text_2": corpus,
|
||||
"truncate_prompt_tokens": -1,
|
||||
}).json()
|
||||
return response['data'][0]["score"]
|
||||
|
||||
|
||||
class RerankClientMtebEncoder(ScoreClientMtebEncoder):
|
||||
|
||||
def get_score(self, query, corpus):
|
||||
response = requests.post(self.url,
|
||||
json={
|
||||
"model": self.model_name,
|
||||
"query": query,
|
||||
"documents": [corpus],
|
||||
"truncate_prompt_tokens": -1,
|
||||
}).json()
|
||||
return response['results'][0]["relevance_score"]
|
||||
|
||||
|
||||
def run_mteb_embed_task(encoder, tasks):
|
||||
tasks = mteb.get_tasks(tasks=tasks)
|
||||
evaluation = mteb.MTEB(tasks=tasks)
|
||||
results = evaluation.run(encoder, verbosity=0, output_folder=None)
|
||||
results = evaluation.run(
|
||||
encoder,
|
||||
verbosity=0,
|
||||
output_folder=None,
|
||||
encode_kwargs={
|
||||
"show_progress_bar": False,
|
||||
},
|
||||
)
|
||||
|
||||
main_score = results[0].scores["test"][0]["main_score"]
|
||||
return main_score
|
||||
|
||||
|
||||
def run_mteb_embed_task_st(model_name, tasks):
|
||||
from sentence_transformers import SentenceTransformer
|
||||
model = SentenceTransformer(model_name)
|
||||
return run_mteb_embed_task(model, tasks)
|
||||
|
||||
|
||||
def mteb_test_embed_models(hf_runner,
|
||||
vllm_runner,
|
||||
model_info: EmbedModelInfo,
|
||||
@@ -118,3 +199,96 @@ def mteb_test_embed_models(hf_runner,
|
||||
print("Difference:", st_main_score - vllm_main_score)
|
||||
|
||||
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)
|
||||
|
||||
|
||||
def run_mteb_rerank(cross_encoder, tasks, languages):
|
||||
with tempfile.TemporaryDirectory() as results_folder:
|
||||
bm25s = mteb.get_model("bm25s")
|
||||
tasks = mteb.get_tasks(tasks=tasks, languages=languages)
|
||||
|
||||
subset = "default"
|
||||
eval_splits = ["test"]
|
||||
|
||||
evaluation = mteb.MTEB(tasks=tasks)
|
||||
evaluation.run(
|
||||
bm25s,
|
||||
verbosity=0,
|
||||
eval_splits=eval_splits,
|
||||
save_predictions=True,
|
||||
output_folder=f"{results_folder}/stage1",
|
||||
encode_kwargs={"show_progress_bar": False},
|
||||
)
|
||||
|
||||
results = evaluation.run(
|
||||
cross_encoder,
|
||||
verbosity=0,
|
||||
eval_splits=eval_splits,
|
||||
top_k=10,
|
||||
save_predictions=True,
|
||||
output_folder=f"{results_folder}/stage2",
|
||||
previous_results=
|
||||
f"{results_folder}/stage1/NFCorpus_{subset}_predictions.json",
|
||||
encode_kwargs={"show_progress_bar": False},
|
||||
)
|
||||
main_score = results[0].scores["test"][0]["main_score"]
|
||||
return main_score
|
||||
|
||||
|
||||
def mteb_test_rerank_models(hf_runner,
|
||||
vllm_runner,
|
||||
model_info: RerankModelInfo,
|
||||
vllm_extra_kwargs=None,
|
||||
hf_model_callback=None):
|
||||
if not model_info.enable_test:
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
pytest.skip("Skipping test.")
|
||||
|
||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
||||
|
||||
with vllm_runner(model_info.name,
|
||||
task="score",
|
||||
max_model_len=None,
|
||||
**vllm_extra_kwargs) as vllm_model:
|
||||
|
||||
if model_info.architecture:
|
||||
assert (model_info.architecture
|
||||
in vllm_model.model.llm_engine.model_config.architectures)
|
||||
|
||||
vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model),
|
||||
tasks=MTEB_RERANK_TASKS,
|
||||
languages=MTEB_RERANK_LANGS)
|
||||
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
|
||||
|
||||
with hf_runner(model_info.name, is_cross_encoder=True,
|
||||
dtype="float32") as hf_model:
|
||||
|
||||
original_predict = hf_model.predict
|
||||
|
||||
def _predict(
|
||||
sentences: list[tuple[str, str,
|
||||
Optional[str]]], # query, corpus, prompt
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
# vllm and st both remove the prompt, fair comparison.
|
||||
prompts = [(s[0], s[1]) for s in sentences]
|
||||
return original_predict(prompts, *args, **kwargs, batch_size=8)
|
||||
|
||||
hf_model.predict = _predict
|
||||
hf_model.original_predict = original_predict
|
||||
|
||||
if hf_model_callback is not None:
|
||||
hf_model_callback(hf_model)
|
||||
|
||||
st_main_score = run_mteb_rerank(hf_model,
|
||||
tasks=MTEB_RERANK_TASKS,
|
||||
languages=MTEB_RERANK_LANGS)
|
||||
st_dtype = next(hf_model.model.model.parameters()).dtype
|
||||
|
||||
print("VLLM:", vllm_dtype, vllm_main_score)
|
||||
print("SentenceTransformers:", st_dtype, st_main_score)
|
||||
print("Difference:", st_main_score - vllm_main_score)
|
||||
|
||||
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL)
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
|
||||
from .mteb_utils import mteb_test_embed_models
|
||||
from ...utils import EmbedModelInfo, RerankModelInfo
|
||||
from .embed_utils import correctness_test_embed_models
|
||||
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
|
||||
|
||||
MODELS = [
|
||||
########## BertModel
|
||||
@@ -57,6 +58,20 @@ MODELS = [
|
||||
enable_test=True),
|
||||
]
|
||||
|
||||
RERANK_MODELS = [
|
||||
########## XLMRobertaForSequenceClassification
|
||||
RerankModelInfo("BAAI/bge-reranker-base",
|
||||
architecture="XLMRobertaForSequenceClassification",
|
||||
enable_test=True),
|
||||
RerankModelInfo("BAAI/bge-reranker-large",
|
||||
architecture="XLMRobertaForSequenceClassification",
|
||||
enable_test=False),
|
||||
RerankModelInfo("BAAI/bge-reranker-v2-m3",
|
||||
architecture="XLMRobertaForSequenceClassification",
|
||||
dtype="float32",
|
||||
enable_test=False)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", MODELS)
|
||||
def test_embed_models_mteb(hf_runner, vllm_runner,
|
||||
@@ -70,3 +85,9 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
|
||||
example_prompts) -> None:
|
||||
correctness_test_embed_models(hf_runner, vllm_runner, model_info,
|
||||
example_prompts)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
||||
def test_rerank_models_mteb(hf_runner, vllm_runner,
|
||||
model_info: RerankModelInfo) -> None:
|
||||
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
|
||||
|
||||
18
tests/models/language/pooling/test_cross_encoder.py
Normal file
18
tests/models/language/pooling/test_cross_encoder.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
|
||||
|
||||
RERANK_MODELS = [
|
||||
RerankModelInfo("cross-encoder/ms-marco-TinyBERT-L-2-v2",
|
||||
architecture="BertForSequenceClassification"),
|
||||
RerankModelInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
|
||||
architecture="Qwen3ForSequenceClassification")
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
||||
def test_rerank_models_mteb(hf_runner, vllm_runner,
|
||||
model_info: RerankModelInfo) -> None:
|
||||
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
|
||||
@@ -6,28 +6,10 @@ import pytest
|
||||
|
||||
from vllm import PoolingParams
|
||||
|
||||
from .embed_utils import (EmbedModelInfo, check_embeddings_close,
|
||||
from ...utils import EmbedModelInfo, RerankModelInfo
|
||||
from .embed_utils import (check_embeddings_close,
|
||||
correctness_test_embed_models, matryoshka_fy)
|
||||
from .mteb_utils import mteb_test_embed_models
|
||||
|
||||
SCORING_MODELS = [
|
||||
"jinaai/jina-reranker-v2-base-multilingual", # Roberta
|
||||
]
|
||||
|
||||
TEXTS_1 = ["Organic skincare products for sensitive skin"]
|
||||
|
||||
TEXTS_2 = [
|
||||
"Organic skincare for sensitive skin with aloe vera and chamomile.",
|
||||
"New makeup trends focus on bold colors and innovative techniques",
|
||||
"Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille",
|
||||
"Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken", # noqa: E501
|
||||
"Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla", # noqa: E501
|
||||
"Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras", # noqa: E501
|
||||
"针对敏感肌专门设计的天然有机护肤产品",
|
||||
"新的化妆趋势注重鲜艳的颜色和创新的技巧",
|
||||
"敏感肌のために特別に設計された天然有機スキンケア製品",
|
||||
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています",
|
||||
]
|
||||
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
|
||||
|
||||
EMBEDDING_MODELS = [
|
||||
EmbedModelInfo("jinaai/jina-embeddings-v3",
|
||||
@@ -35,47 +17,13 @@ EMBEDDING_MODELS = [
|
||||
is_matryoshka=True)
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=SCORING_MODELS)
|
||||
def model_name(request):
|
||||
yield request.param
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str):
|
||||
|
||||
text_pair = [TEXTS_1[0], TEXTS_2[0]]
|
||||
|
||||
with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
|
||||
hf_outputs = hf_model.predict([text_pair]).tolist()
|
||||
|
||||
with vllm_runner(model_name, task="score", dtype=dtype,
|
||||
max_model_len=None) as vllm_model:
|
||||
vllm_outputs = vllm_model.score(text_pair[0], text_pair[1])
|
||||
|
||||
assert len(vllm_outputs) == 1
|
||||
assert len(hf_outputs) == 1
|
||||
|
||||
assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
|
||||
|
||||
text_pairs = [[TEXTS_1[0], text] for text in TEXTS_2]
|
||||
|
||||
with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model:
|
||||
hf_outputs = hf_model.predict(text_pairs).tolist()
|
||||
|
||||
with vllm_runner(model_name, task="score", dtype=dtype,
|
||||
max_model_len=None) as vllm_model:
|
||||
vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2)
|
||||
|
||||
assert len(vllm_outputs) == 10
|
||||
assert len(hf_outputs) == 10
|
||||
|
||||
assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
|
||||
assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)
|
||||
RERANK_MODELS = [
|
||||
RerankModelInfo(
|
||||
"jinaai/jina-reranker-v2-base-multilingual",
|
||||
architecture="XLMRobertaForSequenceClassification",
|
||||
dtype="float32",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
|
||||
@@ -106,6 +54,12 @@ def test_embed_models_correctness(hf_runner, vllm_runner,
|
||||
hf_model_callback=hf_model_callback)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
||||
def test_rerank_models_mteb(hf_runner, vllm_runner,
|
||||
model_info: RerankModelInfo) -> None:
|
||||
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("dimensions", [16, 32])
|
||||
|
||||
@@ -1,87 +1,91 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
model_name = "Qwen/Qwen3-Reranker-4B"
|
||||
from tests.conftest import HfRunner
|
||||
|
||||
text_1 = "What is the capital of France?"
|
||||
texts_2 = [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris.",
|
||||
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
|
||||
|
||||
RERANK_MODELS = [
|
||||
RerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
|
||||
architecture="Qwen3ForSequenceClassification",
|
||||
dtype="float32",
|
||||
enable_test=True),
|
||||
RerankModelInfo("Qwen/Qwen3-Reranker-4B",
|
||||
architecture="Qwen3ForSequenceClassification",
|
||||
dtype="float32",
|
||||
enable_test=False)
|
||||
]
|
||||
|
||||
|
||||
def vllm_reranker(model_name):
|
||||
from vllm import LLM
|
||||
class Qwen3RerankerHfRunner(HfRunner):
|
||||
|
||||
model = LLM(model=model_name,
|
||||
task="score",
|
||||
hf_overrides={
|
||||
"architectures": ["Qwen3ForSequenceClassification"],
|
||||
"classifier_from_token": ["no", "yes"],
|
||||
"is_original_qwen3_reranker": True,
|
||||
},
|
||||
dtype="float32")
|
||||
def __init__(self,
|
||||
model_name: str,
|
||||
dtype: str = "auto",
|
||||
*args: Any,
|
||||
**kwargs: Any) -> None:
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
|
||||
|
||||
text_1 = "What is the capital of France?"
|
||||
texts_2 = [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris.",
|
||||
]
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
padding_side='left')
|
||||
self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
|
||||
self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
|
||||
|
||||
outputs = model.score(text_1, texts_2)
|
||||
def predict(self, prompts: list[list[str]], *args,
|
||||
**kwargs) -> torch.Tensor:
|
||||
|
||||
return [output.outputs.score for output in outputs]
|
||||
def process_inputs(pairs):
|
||||
inputs = self.tokenizer(pairs,
|
||||
padding=False,
|
||||
truncation='longest_first',
|
||||
return_attention_mask=False)
|
||||
for i, ele in enumerate(inputs['input_ids']):
|
||||
inputs['input_ids'][i] = ele
|
||||
inputs = self.tokenizer.pad(inputs,
|
||||
padding=True,
|
||||
return_tensors="pt")
|
||||
for key in inputs:
|
||||
inputs[key] = inputs[key].to(self.model.device)
|
||||
return inputs
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_logits(inputs):
|
||||
batch_scores = self.model(**inputs).logits[:, -1, :]
|
||||
true_vector = batch_scores[:, self.token_true_id]
|
||||
false_vector = batch_scores[:, self.token_false_id]
|
||||
batch_scores = torch.stack([false_vector, true_vector], dim=1)
|
||||
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
|
||||
scores = batch_scores[:, 1].exp()
|
||||
return scores
|
||||
|
||||
scores = []
|
||||
for prompt in prompts:
|
||||
inputs = process_inputs([prompt])
|
||||
score = compute_logits(inputs)
|
||||
scores.append(score[0].item())
|
||||
return torch.Tensor(scores)
|
||||
|
||||
|
||||
def hf_reranker(model_name):
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
@pytest.mark.parametrize("model_info", RERANK_MODELS)
|
||||
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name).eval()
|
||||
assert model_info.architecture == "Qwen3ForSequenceClassification"
|
||||
|
||||
token_false_id = tokenizer.convert_tokens_to_ids("no")
|
||||
token_true_id = tokenizer.convert_tokens_to_ids("yes")
|
||||
vllm_extra_kwargs: dict[str, Any] = {
|
||||
"hf_overrides": {
|
||||
"architectures": ["Qwen3ForSequenceClassification"],
|
||||
"classifier_from_token": ["no", "yes"],
|
||||
"is_original_qwen3_reranker": True,
|
||||
}
|
||||
}
|
||||
|
||||
max_length = 8192
|
||||
if model_info.name == "Qwen/Qwen3-Reranker-4B":
|
||||
vllm_extra_kwargs["max_num_seqs"] = 1
|
||||
|
||||
def process_inputs(pairs):
|
||||
inputs = tokenizer(pairs,
|
||||
padding=False,
|
||||
truncation='longest_first',
|
||||
return_attention_mask=False,
|
||||
max_length=max_length)
|
||||
for i, ele in enumerate(inputs['input_ids']):
|
||||
inputs['input_ids'][i] = ele
|
||||
inputs = tokenizer.pad(inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
max_length=max_length)
|
||||
for key in inputs:
|
||||
inputs[key] = inputs[key].to(model.device)
|
||||
return inputs
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_logits(inputs, **kwargs):
|
||||
batch_scores = model(**inputs).logits[:, -1, :]
|
||||
true_vector = batch_scores[:, token_true_id]
|
||||
false_vector = batch_scores[:, token_false_id]
|
||||
batch_scores = torch.stack([false_vector, true_vector], dim=1)
|
||||
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
|
||||
scores = batch_scores[:, 1].exp().tolist()
|
||||
return scores
|
||||
|
||||
pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
|
||||
inputs = process_inputs(pairs)
|
||||
scores = compute_logits(inputs)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [model_name])
|
||||
def test_model(model_name):
|
||||
hf_outputs = hf_reranker(model_name)
|
||||
vllm_outputs = vllm_reranker(model_name)
|
||||
|
||||
assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
|
||||
assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)
|
||||
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info,
|
||||
vllm_extra_kwargs)
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
|
||||
model_name = "tomaarsen/Qwen3-Reranker-0.6B-seq-cls"
|
||||
|
||||
text_1 = "What is the capital of France?"
|
||||
texts_2 = [
|
||||
"The capital of Brazil is Brasilia.",
|
||||
"The capital of France is Paris.",
|
||||
]
|
||||
|
||||
|
||||
def vllm_reranker(model_name):
|
||||
from vllm import LLM
|
||||
|
||||
model = LLM(model=model_name, task="score")
|
||||
outputs = model.score(text_1, texts_2)
|
||||
|
||||
return [output.outputs.score for output in outputs]
|
||||
|
||||
|
||||
def hf_reranker(model_name):
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name).eval()
|
||||
|
||||
token_false_id = tokenizer.convert_tokens_to_ids("no")
|
||||
token_true_id = tokenizer.convert_tokens_to_ids("yes")
|
||||
|
||||
max_length = 8192
|
||||
|
||||
def process_inputs(pairs):
|
||||
inputs = tokenizer(pairs,
|
||||
padding=False,
|
||||
truncation='longest_first',
|
||||
return_attention_mask=False,
|
||||
max_length=max_length)
|
||||
for i, ele in enumerate(inputs['input_ids']):
|
||||
inputs['input_ids'][i] = ele
|
||||
inputs = tokenizer.pad(inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
max_length=max_length)
|
||||
for key in inputs:
|
||||
inputs[key] = inputs[key].to(model.device)
|
||||
return inputs
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_logits(inputs, **kwargs):
|
||||
batch_scores = model(**inputs).logits[:, -1, :]
|
||||
true_vector = batch_scores[:, token_true_id]
|
||||
false_vector = batch_scores[:, token_false_id]
|
||||
batch_scores = torch.stack([false_vector, true_vector], dim=1)
|
||||
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
|
||||
scores = batch_scores[:, 1].exp().tolist()
|
||||
return scores
|
||||
|
||||
pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
|
||||
inputs = process_inputs(pairs)
|
||||
scores = compute_logits(inputs)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [model_name])
|
||||
def test_model(model_name):
|
||||
hf_outputs = hf_reranker(model_name)
|
||||
vllm_outputs = vllm_reranker(model_name)
|
||||
|
||||
assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
|
||||
assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)
|
||||
Reference in New Issue
Block a user