[CI] Add mteb testing for rerank models (#19344)
This commit is contained in:
@@ -1,14 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import tempfile
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
import mteb
|
||||
import numpy as np
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.models.utils import EmbedModelInfo
|
||||
from tests.models.utils import EmbedModelInfo, RerankModelInfo
|
||||
|
||||
# Most models on the STS12 task (See #17175):
|
||||
# Most embedding models on the STS12 task (See #17175):
|
||||
# - Model implementation and minor changes in tensor dtype
|
||||
# results in differences less than 1e-4
|
||||
# - Different model results in differences more than 1e-3
|
||||
@@ -16,6 +20,11 @@ from tests.models.utils import EmbedModelInfo
|
||||
MTEB_EMBED_TASKS = ["STS12"]
|
||||
MTEB_EMBED_TOL = 1e-4
|
||||
|
||||
# See #19344
|
||||
MTEB_RERANK_TASKS = ["NFCorpus"]
|
||||
MTEB_RERANK_LANGS = ["en"]
|
||||
MTEB_RERANK_TOL = 1e-3
|
||||
|
||||
|
||||
class VllmMtebEncoder(mteb.Encoder):
|
||||
|
||||
@@ -39,6 +48,27 @@ class VllmMtebEncoder(mteb.Encoder):
|
||||
embeds = embeds[np.argsort(r)]
|
||||
return embeds
|
||||
|
||||
def predict(
|
||||
self,
|
||||
sentences: list[tuple[str, str,
|
||||
Optional[str]]], # query, corpus, prompt
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
r = self.rng.permutation(len(sentences))
|
||||
sentences = [sentences[i] for i in r]
|
||||
|
||||
queries = [s[0] for s in sentences]
|
||||
corpus = [s[1] for s in sentences]
|
||||
|
||||
outputs = self.model.score(queries,
|
||||
corpus,
|
||||
truncate_prompt_tokens=-1,
|
||||
use_tqdm=False)
|
||||
scores = np.array(outputs)
|
||||
scores = scores[np.argsort(r)]
|
||||
return scores
|
||||
|
||||
|
||||
class OpenAIClientMtebEncoder(mteb.Encoder):
|
||||
|
||||
@@ -62,21 +92,72 @@ class OpenAIClientMtebEncoder(mteb.Encoder):
|
||||
return embeds
|
||||
|
||||
|
||||
class ScoreClientMtebEncoder(mteb.Encoder):
|
||||
|
||||
def __init__(self, model_name: str, url):
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
self.url = url
|
||||
self.rng = np.random.default_rng(seed=42)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
sentences: list[tuple[str, str,
|
||||
Optional[str]]], # query, corpus, prompt
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
r = self.rng.permutation(len(sentences))
|
||||
sentences = [sentences[i] for i in r]
|
||||
|
||||
outputs = []
|
||||
for query, corpus, prompt in sentences:
|
||||
outputs.append(self.get_score(query, corpus))
|
||||
|
||||
scores = np.array(outputs)
|
||||
scores = scores[np.argsort(r)]
|
||||
return scores
|
||||
|
||||
def get_score(self, query, corpus):
|
||||
response = requests.post(self.url,
|
||||
json={
|
||||
"model": self.model_name,
|
||||
"text_1": query,
|
||||
"text_2": corpus,
|
||||
"truncate_prompt_tokens": -1,
|
||||
}).json()
|
||||
return response['data'][0]["score"]
|
||||
|
||||
|
||||
class RerankClientMtebEncoder(ScoreClientMtebEncoder):
|
||||
|
||||
def get_score(self, query, corpus):
|
||||
response = requests.post(self.url,
|
||||
json={
|
||||
"model": self.model_name,
|
||||
"query": query,
|
||||
"documents": [corpus],
|
||||
"truncate_prompt_tokens": -1,
|
||||
}).json()
|
||||
return response['results'][0]["relevance_score"]
|
||||
|
||||
|
||||
def run_mteb_embed_task(encoder, tasks):
|
||||
tasks = mteb.get_tasks(tasks=tasks)
|
||||
evaluation = mteb.MTEB(tasks=tasks)
|
||||
results = evaluation.run(encoder, verbosity=0, output_folder=None)
|
||||
results = evaluation.run(
|
||||
encoder,
|
||||
verbosity=0,
|
||||
output_folder=None,
|
||||
encode_kwargs={
|
||||
"show_progress_bar": False,
|
||||
},
|
||||
)
|
||||
|
||||
main_score = results[0].scores["test"][0]["main_score"]
|
||||
return main_score
|
||||
|
||||
|
||||
def run_mteb_embed_task_st(model_name, tasks):
|
||||
from sentence_transformers import SentenceTransformer
|
||||
model = SentenceTransformer(model_name)
|
||||
return run_mteb_embed_task(model, tasks)
|
||||
|
||||
|
||||
def mteb_test_embed_models(hf_runner,
|
||||
vllm_runner,
|
||||
model_info: EmbedModelInfo,
|
||||
@@ -118,3 +199,96 @@ def mteb_test_embed_models(hf_runner,
|
||||
print("Difference:", st_main_score - vllm_main_score)
|
||||
|
||||
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)
|
||||
|
||||
|
||||
def run_mteb_rerank(cross_encoder, tasks, languages):
|
||||
with tempfile.TemporaryDirectory() as results_folder:
|
||||
bm25s = mteb.get_model("bm25s")
|
||||
tasks = mteb.get_tasks(tasks=tasks, languages=languages)
|
||||
|
||||
subset = "default"
|
||||
eval_splits = ["test"]
|
||||
|
||||
evaluation = mteb.MTEB(tasks=tasks)
|
||||
evaluation.run(
|
||||
bm25s,
|
||||
verbosity=0,
|
||||
eval_splits=eval_splits,
|
||||
save_predictions=True,
|
||||
output_folder=f"{results_folder}/stage1",
|
||||
encode_kwargs={"show_progress_bar": False},
|
||||
)
|
||||
|
||||
results = evaluation.run(
|
||||
cross_encoder,
|
||||
verbosity=0,
|
||||
eval_splits=eval_splits,
|
||||
top_k=10,
|
||||
save_predictions=True,
|
||||
output_folder=f"{results_folder}/stage2",
|
||||
previous_results=
|
||||
f"{results_folder}/stage1/NFCorpus_{subset}_predictions.json",
|
||||
encode_kwargs={"show_progress_bar": False},
|
||||
)
|
||||
main_score = results[0].scores["test"][0]["main_score"]
|
||||
return main_score
|
||||
|
||||
|
||||
def mteb_test_rerank_models(hf_runner,
|
||||
vllm_runner,
|
||||
model_info: RerankModelInfo,
|
||||
vllm_extra_kwargs=None,
|
||||
hf_model_callback=None):
|
||||
if not model_info.enable_test:
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
pytest.skip("Skipping test.")
|
||||
|
||||
vllm_extra_kwargs = vllm_extra_kwargs or {}
|
||||
vllm_extra_kwargs["dtype"] = model_info.dtype
|
||||
|
||||
with vllm_runner(model_info.name,
|
||||
task="score",
|
||||
max_model_len=None,
|
||||
**vllm_extra_kwargs) as vllm_model:
|
||||
|
||||
if model_info.architecture:
|
||||
assert (model_info.architecture
|
||||
in vllm_model.model.llm_engine.model_config.architectures)
|
||||
|
||||
vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model),
|
||||
tasks=MTEB_RERANK_TASKS,
|
||||
languages=MTEB_RERANK_LANGS)
|
||||
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
|
||||
|
||||
with hf_runner(model_info.name, is_cross_encoder=True,
|
||||
dtype="float32") as hf_model:
|
||||
|
||||
original_predict = hf_model.predict
|
||||
|
||||
def _predict(
|
||||
sentences: list[tuple[str, str,
|
||||
Optional[str]]], # query, corpus, prompt
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
# vllm and st both remove the prompt, fair comparison.
|
||||
prompts = [(s[0], s[1]) for s in sentences]
|
||||
return original_predict(prompts, *args, **kwargs, batch_size=8)
|
||||
|
||||
hf_model.predict = _predict
|
||||
hf_model.original_predict = original_predict
|
||||
|
||||
if hf_model_callback is not None:
|
||||
hf_model_callback(hf_model)
|
||||
|
||||
st_main_score = run_mteb_rerank(hf_model,
|
||||
tasks=MTEB_RERANK_TASKS,
|
||||
languages=MTEB_RERANK_LANGS)
|
||||
st_dtype = next(hf_model.model.model.parameters()).dtype
|
||||
|
||||
print("VLLM:", vllm_dtype, vllm_main_score)
|
||||
print("SentenceTransformers:", st_dtype, st_main_score)
|
||||
print("Difference:", st_main_score - vllm_main_score)
|
||||
|
||||
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL)
|
||||
|
||||
Reference in New Issue
Block a user