[CI] Add mteb testing for rerank models (#19344)

This commit is contained in:
wang.yuqi
2025-06-16 16:36:43 +08:00
committed by GitHub
parent 26bc46ef89
commit f40f763f12
15 changed files with 418 additions and 246 deletions

View File

@@ -1,87 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import pytest
import torch
model_name = "Qwen/Qwen3-Reranker-4B"
from tests.conftest import HfRunner
text_1 = "What is the capital of France?"
texts_2 = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
from .mteb_utils import RerankModelInfo, mteb_test_rerank_models
RERANK_MODELS = [
RerankModelInfo("Qwen/Qwen3-Reranker-0.6B",
architecture="Qwen3ForSequenceClassification",
dtype="float32",
enable_test=True),
RerankModelInfo("Qwen/Qwen3-Reranker-4B",
architecture="Qwen3ForSequenceClassification",
dtype="float32",
enable_test=False)
]
def vllm_reranker(model_name):
from vllm import LLM
class Qwen3RerankerHfRunner(HfRunner):
model = LLM(model=model_name,
task="score",
hf_overrides={
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
},
dtype="float32")
def __init__(self,
model_name: str,
dtype: str = "auto",
*args: Any,
**kwargs: Any) -> None:
from transformers import AutoModelForCausalLM, AutoTokenizer
super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
text_1 = "What is the capital of France?"
texts_2 = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
]
self.tokenizer = AutoTokenizer.from_pretrained(model_name,
padding_side='left')
self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
outputs = model.score(text_1, texts_2)
def predict(self, prompts: list[list[str]], *args,
**kwargs) -> torch.Tensor:
return [output.outputs.score for output in outputs]
def process_inputs(pairs):
inputs = self.tokenizer(pairs,
padding=False,
truncation='longest_first',
return_attention_mask=False)
for i, ele in enumerate(inputs['input_ids']):
inputs['input_ids'][i] = ele
inputs = self.tokenizer.pad(inputs,
padding=True,
return_tensors="pt")
for key in inputs:
inputs[key] = inputs[key].to(self.model.device)
return inputs
@torch.no_grad()
def compute_logits(inputs):
batch_scores = self.model(**inputs).logits[:, -1, :]
true_vector = batch_scores[:, self.token_true_id]
false_vector = batch_scores[:, self.token_false_id]
batch_scores = torch.stack([false_vector, true_vector], dim=1)
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
scores = batch_scores[:, 1].exp()
return scores
scores = []
for prompt in prompts:
inputs = process_inputs([prompt])
score = compute_logits(inputs)
scores.append(score[0].item())
return torch.Tensor(scores)
def hf_reranker(model_name):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(model_name).eval()
assert model_info.architecture == "Qwen3ForSequenceClassification"
token_false_id = tokenizer.convert_tokens_to_ids("no")
token_true_id = tokenizer.convert_tokens_to_ids("yes")
vllm_extra_kwargs: dict[str, Any] = {
"hf_overrides": {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
}
}
max_length = 8192
if model_info.name == "Qwen/Qwen3-Reranker-4B":
vllm_extra_kwargs["max_num_seqs"] = 1
def process_inputs(pairs):
inputs = tokenizer(pairs,
padding=False,
truncation='longest_first',
return_attention_mask=False,
max_length=max_length)
for i, ele in enumerate(inputs['input_ids']):
inputs['input_ids'][i] = ele
inputs = tokenizer.pad(inputs,
padding=True,
return_tensors="pt",
max_length=max_length)
for key in inputs:
inputs[key] = inputs[key].to(model.device)
return inputs
@torch.no_grad()
def compute_logits(inputs, **kwargs):
batch_scores = model(**inputs).logits[:, -1, :]
true_vector = batch_scores[:, token_true_id]
false_vector = batch_scores[:, token_false_id]
batch_scores = torch.stack([false_vector, true_vector], dim=1)
batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
scores = batch_scores[:, 1].exp().tolist()
return scores
pairs = [(text_1, texts_2[0]), (text_1, texts_2[1])]
inputs = process_inputs(pairs)
scores = compute_logits(inputs)
return scores
@pytest.mark.parametrize("model_name", [model_name])
def test_model(model_name):
hf_outputs = hf_reranker(model_name)
vllm_outputs = vllm_reranker(model_name)
assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.01)
assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info,
vllm_extra_kwargs)