[Model] Let more models support the score template. (#31335)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
@@ -3,13 +3,16 @@
import tempfile
from pathlib import Path
from typing import Any

import mteb
import numpy as np
import requests
import torch
from mteb.models import ModelMeta
from torch.utils.data import DataLoader

from tests.conftest import HfRunner
from tests.models.utils import (
    RerankModelInfo,
    get_vllm_extra_kwargs,

@@ -67,6 +70,12 @@ class VllmMtebCrossEncoder(MtebCrossEncoderMixin):
        queries = [text for batch in inputs1 for text in batch["text"]]
        corpus = [text for batch in inputs2 for text in batch["text"]]

        # Hoping to discover potential scheduling
        # issues by randomizing the order.
        r = self.rng.permutation(len(queries))
        queries = [queries[i] for i in r]
        corpus = [corpus[i] for i in r]

        outputs = self.llm.score(
            queries,
            corpus,

@@ -75,6 +84,7 @@ class VllmMtebCrossEncoder(MtebCrossEncoderMixin):
            chat_template=self.chat_template,
        )
        scores = np.array(outputs)
        scores = scores[np.argsort(r)]
        return scores
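
The unshuffle above works because np.argsort of a permutation yields its inverse; a minimal standalone sketch of the round-trip (illustrative, not part of the diff):

# np.argsort(r) is the inverse permutation of r, so indexing the
# shuffled scores with it restores the original order.
import numpy as np

rng = np.random.default_rng(seed=42)
scores = np.array([0.1, 0.9, 0.5, 0.7])
r = rng.permutation(len(scores))
shuffled = scores[r]
restored = shuffled[np.argsort(r)]
assert np.array_equal(restored, scores)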

@@ -84,7 +94,6 @@ class ScoreClientMtebEncoder(MtebCrossEncoderMixin):
    def __init__(self, model_name: str, url):
        self.model_name = model_name
        self.url = url
        self.rng = np.random.default_rng(seed=42)

    def predict(
        self,

@@ -130,6 +139,50 @@ class RerankClientMtebEncoder(ScoreClientMtebEncoder):
        return response["results"][0]["relevance_score"]


class HFMtebCrossEncoder(MtebCrossEncoderMixin, HfRunner):
    chat_template: str | None = None

    def __init__(self, model_name: str, dtype: str = "auto", **kwargs: Any) -> None:
        HfRunner.__init__(
            self, model_name=model_name, is_cross_encoder=True, dtype=dtype, **kwargs
        )

    @torch.no_grad
    def predict(
        self,
        inputs1: DataLoader[mteb.types.BatchedInput],
        inputs2: DataLoader[mteb.types.BatchedInput],
        *args,
        **kwargs,
    ) -> np.ndarray:
        queries = [text for batch in inputs1 for text in batch["text"]]
        corpus = [text for batch in inputs2 for text in batch["text"]]

        if self.chat_template is not None:
            tokenizer = self.model.tokenizer
            prompts = []
            for query, document in zip(queries, corpus):
                conversation = [
                    {"role": "query", "content": query},
                    {"role": "document", "content": document},
                ]

                prompt = tokenizer.apply_chat_template(
                    conversation=conversation,
                    tools=None,
                    chat_template=self.chat_template,
                    tokenize=False,
                )
                prompts.append(prompt)
            outputs_list = HfRunner.classify(self, prompts)
            scores = np.array(outputs_list).squeeze(-1)
            return scores
        else:
            prompts = list(zip(queries, corpus))
            outputs_tensor = HfRunner.predict(self, prompts, show_progress_bar=False)
            return outputs_tensor.cpu().numpy()


def run_mteb_rerank(cross_encoder: mteb.CrossEncoderProtocol, tasks, languages):
    with tempfile.TemporaryDirectory() as prediction_folder:
        bm25s = mteb.get_model("bm25s")
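
With chat_template set, the HF baseline above renders "query"/"document" role messages into a single scoring prompt; a toy illustration of that rendering step with a made-up template (not the actual bge/mxbai/qwen3 template shipped with the PR):

# Toy score template: the real templates live in files such as
# qwen3_reranker.jinja and are model-specific; this one is invented.
import jinja2

toy_template = (
    "{% for m in messages %}"
    "{% if m.role == 'query' %}Query: {{ m.content }}\n{% endif %}"
    "{% if m.role == 'document' %}Document: {{ m.content }}\n{% endif %}"
    "{% endfor %}"
)
conversation = [
    {"role": "query", "content": "what is vLLM?"},
    {"role": "document", "content": "vLLM is an inference engine."},
]
prompt = jinja2.Template(toy_template).render(messages=conversation)
# prompt == "Query: what is vLLM?\nDocument: vLLM is an inference engine.\n"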

@@ -168,31 +221,21 @@ def run_mteb_rerank(cross_encoder: mteb.CrossEncoderProtocol, tasks, languages):
    return main_score


def mteb_test_rerank_models_hf(
    hf_runner, model_name, hf_dtype="float32", hf_model_callback=None
):
    with hf_runner(model_name, is_cross_encoder=True, dtype=hf_dtype) as hf_model:
        if hf_model_callback is not None:
            hf_model_callback(hf_model)

        st_main_score = run_mteb_rerank(
            hf_model, tasks=MTEB_RERANK_TASKS, languages=MTEB_RERANK_LANGS
        )
        st_dtype = next(hf_model.model.model.parameters()).dtype
        return st_main_score, st_dtype


def mteb_test_rerank_models(
    hf_runner,
    vllm_runner,
    model_info: RerankModelInfo,
    hf_runner=HFMtebCrossEncoder,
    vllm_extra_kwargs=None,
    hf_model_callback=None,
    vllm_mteb_encoder=VllmMtebCrossEncoder,
    atol=MTEB_RERANK_TOL,
):
    vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)

    # Maybe load chat_template.
    chat_template: str | None = None
    if model_info.chat_template_name is not None:
        chat_template = (template_home / model_info.chat_template_name).read_text()

    with vllm_runner(
        model_info.name,
        runner="pooling",

@@ -201,6 +244,7 @@ def mteb_test_rerank_models(
        **vllm_extra_kwargs,
    ) as vllm_model:
        model_config = vllm_model.llm.llm_engine.model_config
        vllm_model.chat_template = chat_template

        # Confirm whether vllm is using the correct architecture
        if model_info.architecture:

@@ -209,12 +253,6 @@
        # Score API is only enabled for num_labels == 1
        assert model_config.hf_config.num_labels == 1

        # Maybe load chat_template.
        chat_template: str | None = None
        if model_info.chat_template_name is not None:
            chat_template = (template_home / model_info.chat_template_name).read_text()
        vllm_model.chat_template = chat_template

        # Confirm whether the important configs in model_config are correct.
        if model_info.pooling_type is not None:
            assert model_config.pooler_config.pooling_type == model_info.pooling_type

@@ -242,9 +280,14 @@
    # Accelerate mteb test by setting
    # SentenceTransformers mteb score to a constant
    if model_info.mteb_score is None:
        st_main_score, st_dtype = mteb_test_rerank_models_hf(
            hf_runner, model_info.name, model_info.hf_dtype, hf_model_callback
        )
        with hf_runner(model_info.name, dtype=model_info.hf_dtype) as hf_model:
            hf_model.chat_template = chat_template
            st_main_score = run_mteb_rerank(
                hf_model,
                tasks=MTEB_RERANK_TASKS,
                languages=MTEB_RERANK_LANGS,
            )
            st_dtype = next(hf_model.model.model.parameters()).dtype
    else:
        st_main_score = model_info.mteb_score
        st_dtype = "Constant"

@@ -112,7 +112,5 @@ def test_embed_models_correctness(


@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(
    hf_runner, vllm_runner, model_info: RerankModelInfo
) -> None:
    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
    mteb_test_rerank_models(vllm_runner, model_info)

@@ -11,40 +11,60 @@ from torch.utils.data import DataLoader
from tests.conftest import HfRunner
from tests.models.utils import RerankModelInfo

from .mteb_score_utils import VllmMtebCrossEncoder, mteb_test_rerank_models
from .mteb_score_utils import (
    MtebCrossEncoderMixin,
    mteb_test_rerank_models,
)

RERANK_MODELS = [
    RerankModelInfo(
        "BAAI/bge-reranker-v2-gemma",
        architecture="GemmaForSequenceClassification",
        mteb_score=0.33757,
        hf_overrides={
            "architectures": ["GemmaForSequenceClassification"],
            "classifier_from_token": ["Yes"],
            "method": "no_post_processing",
        },
        mteb_score=0.33757,
        pooling_type="LAST",
        attn_type="decoder",
        is_prefix_caching_supported=True,
        is_chunked_prefill_supported=True,
        chat_template_name="bge-reranker-v2-gemma.jinja",
    ),
]

PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."  # noqa: E501


class GemmaRerankerHfRunner(HfRunner):
class GemmaRerankerHfRunner(MtebCrossEncoderMixin, HfRunner):
    def __init__(
        self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any
    ) -> None:
        from transformers import AutoModelForCausalLM, AutoTokenizer

        super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
        HfRunner.__init__(
            self,
            model_name=model_name,
            auto_cls=AutoModelForCausalLM,
            dtype=dtype,
            **kwargs,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes")

    @torch.no_grad()
    def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
    @torch.no_grad
    def predict(
        self,
        inputs1: DataLoader[mteb.types.BatchedInput],
        inputs2: DataLoader[mteb.types.BatchedInput],
        *args,
        **kwargs,
    ) -> np.ndarray:
        queries = [text for batch in inputs1 for text in batch["text"]]
        corpus = [text for batch in inputs2 for text in batch["text"]]

        def get_inputs(pairs, tokenizer, prompt=None):
            if prompt is None:
                prompt = PROMPT

@@ -89,8 +109,8 @@ class GemmaRerankerHfRunner(HfRunner):
        )

        scores = []
        for query, doc, *_ in prompts:
            pairs = [(query, doc)]
        for query, document in zip(queries, corpus):
            pairs = [(query, document)]
            inputs = get_inputs(pairs, self.tokenizer)
            inputs = inputs.to(self.model.device)
            _n_tokens = inputs["input_ids"].shape[1]

@@ -107,41 +127,10 @@ class GemmaRerankerHfRunner(HfRunner):
        return torch.Tensor(scores)


class GemmaMtebEncoder(VllmMtebCrossEncoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.query_template = "A: {query}\n"
        self.document_template = "B: {doc}\n{prompt}"

    def predict(
        self,
        inputs1: DataLoader[mteb.types.BatchedInput],
        inputs2: DataLoader[mteb.types.BatchedInput],
        *args,
        **kwargs,
    ) -> np.ndarray:
        queries = [
            self.query_template.format(query=text)
            for batch in inputs1
            for text in batch["text"]
        ]
        corpus = [
            self.document_template.format(doc=text, prompt=PROMPT)
            for batch in inputs2
            for text in batch["text"]
        ]
        outputs = self.llm.score(
            queries, corpus, truncate_prompt_tokens=-1, use_tqdm=False
        )
        scores = np.array(outputs)
        return scores


@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
    mteb_test_rerank_models(
        GemmaRerankerHfRunner,
        vllm_runner,
        model_info,
        vllm_mteb_encoder=GemmaMtebEncoder,
        hf_runner=GemmaRerankerHfRunner,
    )
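
For reference, the removed GemmaMtebEncoder built plain-text pairs rather than chat-template prompts; a standalone sketch of what its formatting produced for one pair (PROMPT abbreviated here):

# What the removed "A: {query}" / "B: {doc}" formatting yielded.
PROMPT = "Given a query A and a passage B, determine whether ..."  # abbreviated
query_template = "A: {query}\n"
document_template = "B: {doc}\n{prompt}"

query_text = query_template.format(query="what is vLLM?")
doc_text = document_template.format(doc="vLLM is an inference engine.", prompt=PROMPT)
# llm.score(queries, corpus, ...) then scored each query/document pair.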

@@ -11,27 +11,26 @@ from .mteb_score_utils import mteb_test_rerank_models
RERANK_MODELS = [
    RerankModelInfo(
        "cross-encoder/ms-marco-TinyBERT-L-2-v2",
        mteb_score=0.32898,
        architecture="BertForSequenceClassification",
        pooling_type="CLS",
        attn_type="encoder_only",
        is_prefix_caching_supported=False,
        is_chunked_prefill_supported=False,
        mteb_score=0.32898,
    ),
    RerankModelInfo(
        "tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
        mteb_score=0.25736,
        architecture="Qwen3ForSequenceClassification",
        pooling_type="LAST",
        attn_type="decoder",
        is_prefix_caching_supported=True,
        is_chunked_prefill_supported=True,
        chat_template_name="qwen3_reranker.jinja",
        mteb_score=0.33459,
    ),
]


@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(
    hf_runner, vllm_runner, model_info: RerankModelInfo
) -> None:
    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
    mteb_test_rerank_models(vllm_runner, model_info)

@@ -143,7 +143,5 @@ def test_embed_models_correctness(


@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(
    hf_runner, vllm_runner, model_info: RerankModelInfo
) -> None:
    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
    mteb_test_rerank_models(vllm_runner, model_info)

@@ -72,10 +72,8 @@ def test_embed_models_correctness(


@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(
    hf_runner, vllm_runner, model_info: RerankModelInfo
) -> None:
    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
    mteb_test_rerank_models(vllm_runner, model_info)


@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)

@@ -2,13 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import mteb
import numpy as np
import pytest
import torch
from torch.utils.data import DataLoader

from tests.conftest import HfRunner
from tests.models.utils import RerankModelInfo

from .mteb_score_utils import mteb_test_rerank_models
from .mteb_score_utils import MtebCrossEncoderMixin, mteb_test_rerank_models

mxbai_rerank_hf_overrides = {
    "architectures": ["Qwen2ForSequenceClassification"],

@@ -21,50 +24,69 @@ RERANK_MODELS = [
        "mixedbread-ai/mxbai-rerank-base-v2",
        architecture="Qwen2ForSequenceClassification",
        hf_overrides=mxbai_rerank_hf_overrides,
        mteb_score=0.273,
        pooling_type="LAST",
        attn_type="decoder",
        is_prefix_caching_supported=True,
        is_chunked_prefill_supported=True,
        chat_template_name="mxbai_rerank_v2.jinja",
        mteb_score=0.33651,
        enable_test=True,
    ),
    RerankModelInfo(
        "mixedbread-ai/mxbai-rerank-large-v2",
        architecture="Qwen2ForSequenceClassification",
        hf_overrides=mxbai_rerank_hf_overrides,
        chat_template_name="mxbai_rerank_v2.jinja",
        enable_test=False,
    ),
]


class MxbaiRerankerHfRunner(HfRunner):
class MxbaiRerankerHfRunner(MtebCrossEncoderMixin, HfRunner):
    def __init__(
        self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any
    ) -> None:
        from transformers import AutoModelForCausalLM, AutoTokenizer

        super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
        HfRunner.__init__(
            self,
            model_name=model_name,
            auto_cls=AutoModelForCausalLM,
            dtype=dtype,
            **kwargs,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.yes_loc = self.tokenizer.convert_tokens_to_ids("1")
        self.no_loc = self.tokenizer.convert_tokens_to_ids("0")

    def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
        def process_inputs(pairs):
            inputs = self.tokenizer(
                pairs,
                padding=False,
                truncation="longest_first",
                return_attention_mask=False,
            )
            for i, ele in enumerate(inputs["input_ids"]):
                inputs["input_ids"][i] = ele
            inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt")
            for key in inputs:
                inputs[key] = inputs[key].to(self.model.device)
            return inputs
    @torch.no_grad
    def predict(
        self,
        inputs1: DataLoader[mteb.types.BatchedInput],
        inputs2: DataLoader[mteb.types.BatchedInput],
        *args,
        **kwargs,
    ) -> np.ndarray:
        queries = [text for batch in inputs1 for text in batch["text"]]
        corpus = [text for batch in inputs2 for text in batch["text"]]

        tokenizer = self.tokenizer
        prompts = []
        for query, document in zip(queries, corpus):
            conversation = [
                {"role": "query", "content": query},
                {"role": "document", "content": document},
            ]

            prompt = tokenizer.apply_chat_template(
                conversation=conversation,
                tools=None,
                chat_template=self.chat_template,
                tokenize=False,
            )
            prompts.append(prompt)

        @torch.no_grad()
        def compute_logits(inputs):
            logits = self.model(**inputs).logits[:, -1, :]
            yes_logits = logits[:, self.yes_loc]
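
The rest of compute_logits is elided by this hunk; presumably it combines the "1"/"0" token logits into a relevance score. A hedged standalone sketch of that kind of step (the exact elided code may differ):

import torch

def score_from_token_logits(yes_logits: torch.Tensor, no_logits: torch.Tensor) -> torch.Tensor:
    # Softmax over the two candidate tokens; the "yes"/"1" probability
    # serves as the relevance score.
    stacked = torch.stack([no_logits, yes_logits], dim=1)
    return torch.softmax(stacked, dim=1)[:, 1]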

@@ -74,9 +96,9 @@ class MxbaiRerankerHfRunner(HfRunner):
            return scores

        scores = []
        for query, doc, *_ in prompts:
            pairs = [(query, doc)]
            inputs = process_inputs(pairs)
        for prompt in prompts:
            inputs = tokenizer([prompt], return_tensors="pt")
            inputs = self.wrap_device(inputs)
            score = compute_logits(inputs)
            scores.append(score[0].item())
        return torch.Tensor(scores)

@@ -84,4 +106,4 @@ class MxbaiRerankerHfRunner(HfRunner):

@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
    mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info)
    mteb_test_rerank_models(vllm_runner, model_info, hf_runner=MxbaiRerankerHfRunner)

@@ -46,7 +46,5 @@ def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -


@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(
    hf_runner, vllm_runner, model_info: RerankModelInfo
) -> None:
    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
    mteb_test_rerank_models(vllm_runner, model_info)

@@ -1,15 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
from typing import Any

import mteb
import numpy as np
import pytest
import torch
from torch.utils.data import DataLoader

from tests.conftest import HfRunner
from tests.models.utils import RerankModelInfo
from tests.utils import multi_gpu_test

from .mteb_score_utils import mteb_test_rerank_models
from .mteb_score_utils import MtebCrossEncoderMixin, mteb_test_rerank_models

qwen3_reranker_hf_overrides = {
    "architectures": ["Qwen3ForSequenceClassification"],

@@ -21,51 +25,71 @@ RERANK_MODELS = [
    RerankModelInfo(
        "Qwen/Qwen3-Reranker-0.6B",
        architecture="Qwen3ForSequenceClassification",
        mteb_score=0.25736,
        hf_overrides=qwen3_reranker_hf_overrides,
        chat_template_name="qwen3_reranker.jinja",
        pooling_type="LAST",
        attn_type="decoder",
        is_prefix_caching_supported=True,
        is_chunked_prefill_supported=True,
        mteb_score=0.33459,
        enable_test=True,
    ),
    RerankModelInfo(
        "Qwen/Qwen3-Reranker-4B",
        architecture="Qwen3ForSequenceClassification",
        chat_template_name="qwen3_reranker.jinja",
        hf_overrides=qwen3_reranker_hf_overrides,
        enable_test=False,
    ),
]


class Qwen3RerankerHfRunner(HfRunner):
class Qwen3RerankerHfRunner(MtebCrossEncoderMixin, HfRunner):
    def __init__(
        self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any
    ) -> None:
        from transformers import AutoModelForCausalLM, AutoTokenizer

        super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
        HfRunner.__init__(
            self,
            model_name=model_name,
            auto_cls=AutoModelForCausalLM,
            dtype=dtype,
            **kwargs,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.max_length = 40960

    def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
        def process_inputs(pairs):
            inputs = self.tokenizer(
                pairs,
                padding=False,
                truncation="longest_first",
                return_attention_mask=False,
    @torch.no_grad
    def predict(
        self,
        inputs1: DataLoader[mteb.types.BatchedInput],
        inputs2: DataLoader[mteb.types.BatchedInput],
        *args,
        **kwargs,
    ) -> np.ndarray:
        queries = [text for batch in inputs1 for text in batch["text"]]
        corpus = [text for batch in inputs2 for text in batch["text"]]

        tokenizer = self.tokenizer
        prompts = []
        for query, document in zip(queries, corpus):
            conversation = [
                {"role": "query", "content": query},
                {"role": "document", "content": document},
            ]

            prompt = tokenizer.apply_chat_template(
                conversation=conversation,
                tools=None,
                chat_template=self.chat_template,
                tokenize=False,
            )
            for i, ele in enumerate(inputs["input_ids"]):
                inputs["input_ids"][i] = ele
            inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt")
            for key in inputs:
                inputs[key] = inputs[key].to(self.model.device)
            return inputs
            prompts.append(prompt)

        @torch.no_grad()
        def compute_logits(inputs):
            batch_scores = self.model(**inputs).logits[:, -1, :]
            true_vector = batch_scores[:, self.token_true_id]

@@ -76,9 +100,9 @@ class Qwen3RerankerHfRunner(HfRunner):
            return scores

        scores = []
        for query, doc, *_ in prompts:
            pairs = [(query, doc)]
            inputs = process_inputs(pairs)
        for prompt in prompts:
            inputs = tokenizer([prompt], return_tensors="pt")
            inputs = self.wrap_device(inputs)
            score = compute_logits(inputs)
            scores.append(score[0].item())
        return torch.Tensor(scores)

@@ -86,7 +110,7 @@ class Qwen3RerankerHfRunner(HfRunner):

@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
    mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info)
    mteb_test_rerank_models(vllm_runner, model_info, hf_runner=Qwen3RerankerHfRunner)


@pytest.mark.parametrize("model_info", RERANK_MODELS)

@@ -99,5 +123,8 @@ def test_rerank_models_mteb_tp(vllm_runner, model_info: RerankModelInfo) -> None
    }

    mteb_test_rerank_models(
        Qwen3RerankerHfRunner, vllm_runner, model_info, vllm_extra_kwargs
        vllm_runner,
        model_info,
        vllm_extra_kwargs=vllm_extra_kwargs,
        hf_runner=Qwen3RerankerHfRunner,
    )
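
Taken together, the PR routes a per-model score template into the score path end to end; a hedged sketch of the call shape the tests exercise (model and template path illustrative, assuming LLM.score forwards chat_template as the wrappers above do):

from pathlib import Path

from vllm import LLM

llm = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", runner="pooling")
chat_template = Path("qwen3_reranker.jinja").read_text()  # hypothetical local copy
outputs = llm.score(
    ["what is vLLM?"],
    ["vLLM is a fast inference engine."],
    chat_template=chat_template,
)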