[Model] Let more models support the score template. (#31335)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
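The HF reference runner for the mxbai rerankers now builds its prompts through the same score (chat) template that the vLLM side uses (mxbai_rerank_v2.jinja), instead of tokenizing raw (query, document) pairs, so both sides of the MTEB comparison score identical inputs. A minimal sketch of that prompt-building step, assuming a hypothetical template body (the real mxbai_rerank_v2.jinja is not reproduced here):

# Illustration only: how a (query, document) pair becomes a scoring prompt
# via a chat template. The template body below is a hypothetical stand-in
# for mxbai_rerank_v2.jinja, whose exact contents are not shown in this diff.
from jinja2 import Template

SCORE_TEMPLATE = Template(
    "{% for m in messages %}<|{{ m.role }}|>{{ m.content }}{% endfor %}<|score|>"
)

def render_score_prompt(query: str, document: str) -> str:
    # Mirrors the diff: the pair is expressed as "query"/"document" chat
    # roles, then rendered to a single string (tokenize=False in the test).
    conversation = [
        {"role": "query", "content": query},
        {"role": "document", "content": document},
    ]
    return SCORE_TEMPLATE.render(messages=conversation)

print(render_score_prompt("what is vLLM?", "vLLM is a fast LLM inference engine."))

In the test itself the rendered prompt is produced by tokenizer.apply_chat_template(..., tokenize=False) and tokenized separately before scoring.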
@@ -2,13 +2,16 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from typing import Any
 
+import mteb
+import numpy as np
 import pytest
 import torch
+from torch.utils.data import DataLoader
 
 from tests.conftest import HfRunner
 from tests.models.utils import RerankModelInfo
 
-from .mteb_score_utils import mteb_test_rerank_models
+from .mteb_score_utils import MtebCrossEncoderMixin, mteb_test_rerank_models
 
 mxbai_rerank_hf_overrides = {
     "architectures": ["Qwen2ForSequenceClassification"],
@@ -21,50 +24,69 @@ RERANK_MODELS = [
         "mixedbread-ai/mxbai-rerank-base-v2",
         architecture="Qwen2ForSequenceClassification",
         hf_overrides=mxbai_rerank_hf_overrides,
-        mteb_score=0.273,
         pooling_type="LAST",
         attn_type="decoder",
         is_prefix_caching_supported=True,
         is_chunked_prefill_supported=True,
+        chat_template_name="mxbai_rerank_v2.jinja",
+        mteb_score=0.33651,
         enable_test=True,
     ),
     RerankModelInfo(
         "mixedbread-ai/mxbai-rerank-large-v2",
         architecture="Qwen2ForSequenceClassification",
         hf_overrides=mxbai_rerank_hf_overrides,
+        chat_template_name="mxbai_rerank_v2.jinja",
         enable_test=False,
     ),
 ]
 
 
-class MxbaiRerankerHfRunner(HfRunner):
+class MxbaiRerankerHfRunner(MtebCrossEncoderMixin, HfRunner):
     def __init__(
         self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any
     ) -> None:
         from transformers import AutoModelForCausalLM, AutoTokenizer
 
-        super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
+        HfRunner.__init__(
+            self,
+            model_name=model_name,
+            auto_cls=AutoModelForCausalLM,
+            dtype=dtype,
+            **kwargs,
+        )
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
         self.yes_loc = self.tokenizer.convert_tokens_to_ids("1")
         self.no_loc = self.tokenizer.convert_tokens_to_ids("0")
 
-    def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
-        def process_inputs(pairs):
-            inputs = self.tokenizer(
-                pairs,
-                padding=False,
-                truncation="longest_first",
-                return_attention_mask=False,
-            )
-            for i, ele in enumerate(inputs["input_ids"]):
-                inputs["input_ids"][i] = ele
-            inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt")
-            for key in inputs:
-                inputs[key] = inputs[key].to(self.model.device)
-            return inputs
+    @torch.no_grad
+    def predict(
+        self,
+        inputs1: DataLoader[mteb.types.BatchedInput],
+        inputs2: DataLoader[mteb.types.BatchedInput],
+        *args,
+        **kwargs,
+    ) -> np.ndarray:
+        queries = [text for batch in inputs1 for text in batch["text"]]
+        corpus = [text for batch in inputs2 for text in batch["text"]]
+
+        tokenizer = self.tokenizer
+        prompts = []
+        for query, document in zip(queries, corpus):
+            conversation = [
+                {"role": "query", "content": query},
+                {"role": "document", "content": document},
+            ]
+
+            prompt = tokenizer.apply_chat_template(
+                conversation=conversation,
+                tools=None,
+                chat_template=self.chat_template,
+                tokenize=False,
+            )
+            prompts.append(prompt)
 
         @torch.no_grad()
         def compute_logits(inputs):
             logits = self.model(**inputs).logits[:, -1, :]
             yes_logits = logits[:, self.yes_loc]
@@ -74,9 +96,9 @@ class MxbaiRerankerHfRunner(HfRunner):
             return scores
 
         scores = []
-        for query, doc, *_ in prompts:
-            pairs = [(query, doc)]
-            inputs = process_inputs(pairs)
+        for prompt in prompts:
+            inputs = tokenizer([prompt], return_tensors="pt")
+            inputs = self.wrap_device(inputs)
             score = compute_logits(inputs)
             scores.append(score[0].item())
         return torch.Tensor(scores)
@@ -84,4 +106,4 @@ class MxbaiRerankerHfRunner(HfRunner):
 
 @pytest.mark.parametrize("model_info", RERANK_MODELS)
 def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
-    mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info)
+    mteb_test_rerank_models(vllm_runner, model_info, hf_runner=MxbaiRerankerHfRunner)
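The body of compute_logits between the second and third hunks is elided above. For orientation, here is a self-contained sketch of one common way such yes/no-token scoring is done (the logit difference between the "1" and "0" tokens squashed by a sigmoid); this is an assumption for illustration, not necessarily the exact elided code:

# Sketch of yes/no-token scoring for a causal-LM reranker. The exact body of
# compute_logits is elided in the diff above; this assumes a yes-minus-no
# logit difference passed through a sigmoid, which may differ from the real code.
import torch

def score_from_last_token_logits(
    logits: torch.Tensor,  # (batch, vocab) logits at the final position
    yes_loc: int,          # token id for "1" (relevant)
    no_loc: int,           # token id for "0" (not relevant)
) -> torch.Tensor:
    yes_logits = logits[:, yes_loc]
    no_logits = logits[:, no_loc]
    # Relevance as the probability the model favors "1" over "0".
    return torch.sigmoid(yes_logits - no_logits)

# Example: a mock batch of 2 prompts over a 10-token vocabulary.
fake_logits = torch.randn(2, 10)
print(score_from_last_token_logits(fake_logits, yes_loc=7, no_loc=3))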