[Frontend] Support using chat template as custom score template for reranking models (#30550)

Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
Jakub Zakrzewski
2025-12-23 12:19:16 +01:00
committed by GitHub
parent 27c6c2f98c
commit 23daef548d
19 changed files with 663 additions and 46 deletions

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from pathlib import Path
import mteb
import numpy as np
@@ -19,6 +20,11 @@ from tests.models.utils import (
get_vllm_extra_kwargs,
)
template_home = (
Path(__file__).parent.parent.parent.parent.parent
/ "examples/pooling/score/template"
)
# Most embedding models on the STS12 task (See #17175):
# - Model implementation and minor changes in tensor dtype
# results in differences less than 1e-4
@@ -102,30 +108,6 @@ class VllmMtebEncoder(mteb.EncoderProtocol):
return sim
class VllmMtebCrossEncoder(mteb.CrossEncoderProtocol):
mteb_model_meta = _empty_model_meta
def __init__(self, vllm_model):
self.llm = vllm_model
self.rng = np.random.default_rng(seed=42)
def predict(
self,
inputs1: DataLoader[mteb.types.BatchedInput],
inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
queries = [text for batch in inputs1 for text in batch["text"]]
corpus = [text for batch in inputs2 for text in batch["text"]]
outputs = self.llm.score(
queries, corpus, truncate_prompt_tokens=-1, use_tqdm=False
)
scores = np.array(outputs)
return scores
class OpenAIClientMtebEncoder(VllmMtebEncoder):
def __init__(self, model_name: str, client):
self.model_name = model_name
@@ -153,6 +135,35 @@ class OpenAIClientMtebEncoder(VllmMtebEncoder):
return embeds
class VllmMtebCrossEncoder(mteb.CrossEncoderProtocol):
mteb_model_meta = _empty_model_meta
def __init__(self, vllm_model):
self.llm = vllm_model
self.rng = np.random.default_rng(seed=42)
self.chat_template: str | None = getattr(vllm_model, "chat_template", None)
def predict(
self,
inputs1: DataLoader[mteb.types.BatchedInput],
inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
queries = [text for batch in inputs1 for text in batch["text"]]
corpus = [text for batch in inputs2 for text in batch["text"]]
outputs = self.llm.score(
queries,
corpus,
truncate_prompt_tokens=-1,
use_tqdm=False,
chat_template=self.chat_template,
)
scores = np.array(outputs)
return scores
class ScoreClientMtebEncoder(mteb.CrossEncoderProtocol):
mteb_model_meta = _empty_model_meta
@@ -387,6 +398,11 @@ def mteb_test_rerank_models(
== model_info.default_pooling_type
)
chat_template: str | None = None
if model_info.chat_template_name is not None:
chat_template = (template_home / model_info.chat_template_name).read_text()
vllm_model.chat_template = chat_template
vllm_main_score = run_mteb_rerank(
vllm_mteb_encoder(vllm_model),
tasks=MTEB_RERANK_TASKS,

View File

@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.models.utils import (
EmbedModelInfo,
LASTPoolingEmbedModelInfo,
LASTPoolingRerankModelInfo,
RerankModelInfo,
)
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
EMBEDDING_MODELS = [
LASTPoolingEmbedModelInfo(
"nvidia/llama-nemotron-embed-1b-v2",
architecture="LlamaBidirectionalModel",
mteb_score=0.689164662128673,
)
]
RERANK_MODELS = [
LASTPoolingRerankModelInfo(
"nvidia/llama-nemotron-rerank-1b-v2",
architecture="LlamaBidirectionalForSequenceClassification",
chat_template_name="nemotron-rerank.jinja",
mteb_score=0.33994,
),
]
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None:
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(
hf_runner, vllm_runner, model_info: RerankModelInfo
) -> None:
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)