[Model] Let more models to support the score template. (#31335)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
wang.yuqi
2026-01-05 19:54:26 +08:00
committed by GitHub
parent caaa482aca
commit 911d38ed99
23 changed files with 764 additions and 334 deletions

View File

@@ -540,21 +540,28 @@ If your model is not in the above list, we will try to automatically convert the
Cross-encoder and reranker models are a subset of classification models that accept two prompts as input.
These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API.
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|-------------------|----------------------|---------------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | |
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ |
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | |
| `LlamaBidirectionalForSequenceClassification`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-rerank-1b-v2` (see note), etc. | ✅︎ | ✅︎ |
| `Qwen2ForSequenceClassification`<sup>C</sup> | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ |
| `Qwen3ForSequenceClassification`<sup>C</sup> | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | |
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | |
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |
| Architecture | Models | Example HF Models | Score template (see note) | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|-------------------|---------------------------|-----------------------------|-----------------------------------------|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | N/A | | |
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma`(see note), etc. | [bge-reranker-v2-gemma.jinja](../../examples/pooling/score/template/bge-reranker-v2-gemma.jinja) | ✅︎ | ✅︎ |
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | N/A | | |
| `LlamaBidirectionalForSequenceClassification`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-rerank-1b-v2`, etc. | [nemotron-rerank.jinja](../../examples/pooling/score/template/nemotron-rerank.jinja) | ✅︎ | ✅︎ |
| `Qwen2ForSequenceClassification`<sup>C</sup> | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2`(see note), etc. | [mxbai_rerank_v2.jinja](../../examples/pooling/score/template/mxbai_rerank_v2.jinja) | ✅︎ | ✅︎ |
| `Qwen3ForSequenceClassification`<sup>C</sup> | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B`(see note), etc. | [qwen3_reranker.jinja](../../examples/pooling/score/template/qwen3_reranker.jinja) | ✅︎ | ✅︎ |
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | N/A | | |
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | N/A | | |
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | N/A | \* | \* |
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion))
\* Feature support is the same as that of the original model.
!!! note
Some models require a specific prompt format to work correctly.
You can find Example HF Models's corresponding score template in [examples/pooling/score/template/](../../examples/pooling/score/template)
Examples : [examples/pooling/score/using_template_offline.py](../../examples/pooling/score/using_template_offline.py) [examples/pooling/score/using_template_online.py](../../examples/pooling/score/using_template_online.py)
!!! note
Load the official original `BAAI/bge-reranker-v2-gemma` by using the following command.
@@ -565,11 +572,6 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
!!! note
The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. The name `NewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture.
!!! note
`nvidia/llama-nemotron-rerank-1b-v2` require a specific prompt format to work correctly.
Examples : [offline_using_template.py](../../examples/pooling/score/offline_using_template.py) [online_using_template.py](../../examples/pooling/score/online_using_template.py)
!!! note
Load the official original `mxbai-rerank-v2` by using the following command.
@@ -578,7 +580,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
```
!!! note
Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/offline_reranker.py](../../examples/pooling/score/offline_reranker.py).
Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/qwen3_reranker_offline.py](../../examples/pooling/score/qwen3_reranker_offline.py) [examples/pooling/score/qwen3_reranker_online.py](../../examples/pooling/score/qwen3_reranker_online.py).
```bash
vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'

View File

@@ -2,35 +2,70 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
Script to convert Large Language Models (LLMs) to Sequence Classification models.
This is particularly useful for converting reranker models that use next-token
prediction to a sequence classification format for compatibility with standard
classification and rerank pipelines.
Usage examples:
- For BAAI/bge-reranker-v2-gemma:
python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma \
--classifier_from_tokens '["Yes"]' --method no_post_processing \
--path ./bge-reranker-v2-gemma-seq-cls
- For mxbai-rerank-v2:
python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 \
--classifier_from_tokens '["0", "1"]' --method from_2_way_softmax \
--path ./mxbai-rerank-base-v2-seq-cls
- For Qwen3-Reranker:
python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B \
--classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax \
--path ./Qwen3-Reranker-0.6B-seq-cls
Note: For BAAI/bge-reranker-v2-gemma, "Yes" and "yes" are different tokens.
"""
import argparse
import json
import torch
import transformers
# Usage:
# for BAAI/bge-reranker-v2-gemma
# Caution: "Yes" and "yes" are two different tokens
# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
# for mxbai-rerank-v2
# python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls
# for Qwen3-Reranker
# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls
def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device):
# refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
assert len(tokens) == 2
"""
This method extracts the difference between weights for 'true' and 'false' tokens
from the language model head to create a single classification weight vector.
Args:
causal_lm: The original causal language model
seq_cls_model: The target sequence classification model
tokenizer: Model tokenizer
tokens: List of two tokens representing [false_token, true_token]
device: Target device (cpu/cuda)
Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
"""
assert len(tokens) == 2, (
"Method requires exactly two tokens for binary classification"
)
# Get the language model head weights (vocabulary_size x hidden_size)
lm_head_weights = causal_lm.lm_head.weight
# Convert token strings to their corresponding token IDs
false_id = tokenizer.convert_tokens_to_ids(tokens[0])
true_id = tokenizer.convert_tokens_to_ids(tokens[1])
# Compute the classification weight as the difference between true and false token weights
# This follows the approach in: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
score_weight = lm_head_weights[true_id].to(device).to(
torch.float32
) - lm_head_weights[false_id].to(device).to(torch.float32)
# Copy the computed weights to the sequence classification model
with torch.no_grad():
seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0))
if seq_cls_model.score.bias is not None:
@@ -38,12 +73,29 @@ def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device):
def no_post_processing(causal_lm, seq_cls_model, tokenizer, tokens, device):
"""
Directly use token weights from the language model head for classification.
This method maps each classification label directly to a corresponding token
in the vocabulary without additional transformation.
Args:
causal_lm: The original causal language model
seq_cls_model: The target sequence classification model
tokenizer: Model tokenizer
tokens: List of tokens representing class labels
device: Target device (cpu/cuda)
"""
# Get the language model head weights (vocabulary_size x hidden_size)
lm_head_weights = causal_lm.lm_head.weight
# Convert all tokens to their corresponding token IDs
token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
# Extract weights for the specific tokens (num_tokens x hidden_size)
score_weight = lm_head_weights[token_ids].to(device)
# Copy the weights to the sequence classification model
with torch.no_grad():
seq_cls_model.score.weight.copy_(score_weight)
if seq_cls_model.score.bias is not None:
@@ -58,19 +110,33 @@ method_map = {
def converting(
model_name, classifier_from_tokens, path, method, use_pad_token=False, device="cpu"
):
assert method in method_map
"""
Main conversion function to transform a CausalLM model to SequenceClassification.
Args:
model_name: Name or path of the pretrained model
classifier_from_tokens: List of tokens used for classification
path: Output path to save the converted model
method: Conversion method ('from_2_way_softmax' or 'no_post_processing')
use_pad_token: Whether to use padding token in the sequence classification model
device: Device to load the model on ('cpu' or 'cuda')
"""
assert method in method_map, f"Unknown method: {method}"
# Determine number of labels based on conversion method
if method == "from_2_way_softmax":
assert len(classifier_from_tokens) == 2
num_labels = 1
else:
num_labels = len(classifier_from_tokens)
# Load tokenizer and original causal language model
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
model_name, device_map=device
)
# Load an empty sequence classification model with the same architecture
seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
model_name,
num_labels=num_labels,
@@ -78,14 +144,17 @@ def converting(
device_map=device,
)
# Apply the selected conversion method to transfer weights
method_map[method](
causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
)
# `llm as reranker` defaults to not using pad_token
# Configure padding token settings
# Note: Reranker models typically don't use padding tokens by default
seq_cls_model.config.use_pad_token = use_pad_token
seq_cls_model.config.pad_token_id = tokenizer.pad_token_id
# Save the converted model and tokenizer
seq_cls_model.save_pretrained(path)
tokenizer.save_pretrained(path)
@@ -99,25 +168,30 @@ def parse_args():
"--model_name",
type=str,
default="BAAI/bge-reranker-v2-gemma",
help="Model name",
help="HuggingFace model name or local path",
)
parser.add_argument(
"--classifier_from_tokens",
type=str,
default='["Yes"]',
help="classifier from tokens",
help="JSON string of tokens used for classification labels",
)
parser.add_argument(
"--method", type=str, default="no_post_processing", help="Converting converting"
"--method",
type=str,
default="no_post_processing",
help="Conversion method to use",
)
parser.add_argument(
"--use-pad-token", action="store_true", help="Whether to use pad_token"
"--use-pad-token",
action="store_true",
help="Enable padding token in the sequence classification model",
)
parser.add_argument(
"--path",
type=str,
default="./bge-reranker-v2-gemma-seq-cls",
help="Path to save converted model",
help="Output directory to save the converted model",
)
return parser.parse_args()

View File

@@ -1,89 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
from vllm import LLM
model_name = "Qwen/Qwen3-Reranker-0.6B"
# What is the difference between the official original version and one
# that has been converted into a sequence classification model?
# Qwen3-Reranker is a language model that doing reranker by using the
# logits of "no" and "yes" tokens.
# It needs to computing 151669 tokens logits, making this method extremely
# inefficient, not to mention incompatible with the vllm score API.
# A method for converting the original model into a sequence classification
# model was proposed. Seehttps://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
# Models converted offline using this method can not only be more efficient
# and support the vllm score API, but also make the init parameters more
# concise, for example.
# llm = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", runner="pooling")
# If you want to load the official original version, the init parameters are
# as follows.
def get_llm() -> LLM:
"""Initializes and returns the LLM model for Qwen3-Reranker."""
return LLM(
model=model_name,
runner="pooling",
hf_overrides={
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
},
)
# Why do we need hf_overrides for the official original version:
# vllm converts it to Qwen3ForSequenceClassification when loaded for
# better performance.
# - Firstly, we need using `"architectures": ["Qwen3ForSequenceClassification"],`
# to manually route to Qwen3ForSequenceClassification.
# - Then, we will extract the vector corresponding to classifier_from_token
# from lm_head using `"classifier_from_token": ["no", "yes"]`.
# - Third, we will convert these two vectors into one vector. The use of
# conversion logic is controlled by `using "is_original_qwen3_reranker": True`.
# Please use the query_template and document_template to format the query and
# document for better reranker results.
prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
document_template = "<Document>: {doc}{suffix}"
def main() -> None:
instruction = (
"Given a web search query, retrieve relevant passages that answer the query"
)
queries = [
"What is the capital of China?",
"Explain gravity",
]
documents = [
"The capital of China is Beijing.",
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]
queries = [
query_template.format(prefix=prefix, instruction=instruction, query=query)
for query in queries
]
documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents]
llm = get_llm()
outputs = llm.score(queries, documents)
print("-" * 30)
print([output.outputs.score for output in outputs])
print("-" * 30)
if __name__ == "__main__":
main()

View File

@@ -1,27 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
from pathlib import Path
from vllm import LLM
model_name = "nvidia/llama-nemotron-rerank-1b-v2"
# Path to template file
template_path = Path(__file__).parent / "template" / "nemotron-rerank.jinja"
chat_template = template_path.read_text()
llm = LLM(model=model_name, runner="pooling", trust_remote_code=True)
query = "how much protein should a female eat?"
documents = [
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
"Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
]
outputs = llm.score(query, documents, chat_template=chat_template)
print("-" * 30)
print([output.outputs.score for output in outputs])
print("-" * 30)

View File

@@ -1,46 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
Example of using the rerank API with template.
run:
vllm serve nvidia/llama-nemotron-rerank-1b-v2 --runner pooling --trust-remote-code --chat-template examples/pooling/score/template/nemotron-rerank.jinja
"""
import json
import requests
url = "http://127.0.0.1:8000/rerank"
headers = {"accept": "application/json", "Content-Type": "application/json"}
query = "how much protein should a female eat?"
documents = [
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
"Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
]
data = {
"model": "nvidia/llama-nemotron-rerank-1b-v2",
"query": query,
"documents": documents,
}
def main():
response = requests.post(url, headers=headers, json=data)
# Check the response
if response.status_code == 200:
print("Request successful!")
print(json.dumps(response.json(), indent=2))
else:
print(f"Request failed with status code: {response.status_code}")
print(response.text)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
What is the difference between the official original version and one
that has been converted into a sequence classification model?
Qwen3-Reranker is a language model that doing reranker by using the
logits of "no" and "yes" tokens.
This requires computing logits for all 151,669 tokens in the vocabulary,
making it inefficient and incompatible with vLLM's score() API.
A conversion method has been proposed to transform the original model into a
sequence classification model. This converted model:
1. Is significantly more efficient
2. Fully supports vLLM's score() API
3. Simplifies initialization parameters
Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/convert_model_to_seq_cls.py
For the converted model, initialization would simply be:
llm = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", runner="pooling")
This example demonstrates loading the ORIGINAL model with special overrides
to make it compatible with vLLM's score API.
"""
from pathlib import Path
from vllm import LLM
model_name = "Qwen/Qwen3-Reranker-0.6B"
def get_llm() -> LLM:
"""
Initializes and returns the LLM model for Qwen3-Reranker.
Returns:
LLM: Configured vLLM instance for reranking tasks.
Note:
This function loads the ORIGINAL Qwen3-Reranker model with specific
overrides to make it compatible with vLLM's score API.
"""
return LLM(
# Specify the original model from HuggingFace
model=model_name,
# Use pooling runner for score task
runner="pooling",
# HuggingFace model configuration overrides required for compatibility
hf_overrides={
# Manually route to sequence classification architecture
# This tells vLLM to use Qwen3ForSequenceClassification instead of
# the default Qwen3ForCausalLM
"architectures": ["Qwen3ForSequenceClassification"],
# Specify which token logits to extract from the language model head
# The original reranker uses "no" and "yes" token logits for scoring
"classifier_from_token": ["no", "yes"],
# Enable special handling for original Qwen3-Reranker models
# This flag triggers conversion logic that transforms the two token
# vectors into a single classification vector
"is_original_qwen3_reranker": True,
},
)
def main() -> None:
# Load the Jinja template for formatting query-document pairs
# The template ensures proper formatting for the reranker model
template_home = Path(__file__).parent / "template"
template_path = "qwen3_reranker.jinja"
chat_template = (template_home / template_path).read_text()
# Sample queries for testing the reranker
queries = [
"What is the capital of China?",
"Explain gravity",
]
# Corresponding documents to be scored against each query
documents = [
"The capital of China is Beijing.",
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]
# Initialize the LLM model with the original Qwen3-Reranker configuration
llm = get_llm()
# Compute relevance scores for each query-document pair
# The score() method returns a relevance score for each pair
# Higher scores indicate better relevance
outputs = llm.score(queries, documents, chat_template=chat_template)
# Extract and print the relevance scores from the outputs
# Each output contains a score representing query-document relevance
print("-" * 30)
print("Relevance scores:", [output.outputs.score for output in outputs])
print("-" * 30)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
What is the difference between the official original version and one
that has been converted into a sequence classification model?
Qwen3-Reranker is a language model that doing reranker by using the
logits of "no" and "yes" tokens.
This requires computing logits for all 151,669 tokens in the vocabulary,
making it inefficient and incompatible with vLLM's score() API.
A conversion method has been proposed to transform the original model into a
sequence classification model. This converted model:
1. Is significantly more efficient
2. Fully supports vLLM's score() API
3. Simplifies initialization parameters
Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/convert_model_to_seq_cls.py
For the converted model, initialization would simply be:
vllm serve tomaarsen/Qwen3-Reranker-0.6B-seq-cls --runner pooling --chat-template examples/pooling/score/template/qwen3_reranker.jinja
This example demonstrates loading the ORIGINAL model with special overrides
to make it compatible with vLLM's score API.
vllm serve Qwen/Qwen3-Reranker-0.6B --runner pooling --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' --chat-template examples/pooling/score/template/qwen3_reranker.jinja
"""
import json
import requests
# URL of the vLLM server's score endpoint
# Default vLLM server runs on localhost port 8000
url = "http://127.0.0.1:8000/score"
# HTTP headers for the request
headers = {"accept": "application/json", "Content-Type": "application/json"}
# Example queries & documents
queries = [
"What is the capital of China?",
"Explain gravity",
]
documents = [
"The capital of China is Beijing.",
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]
# Request payload for the score API
data = {
"model": "Qwen/Qwen3-Reranker-0.6B",
"text_1": queries,
"text_2": documents,
}
def main():
"""Main function to send a score request to the vLLM server.
This function sends a POST request to the /score endpoint with
the query and documents, then prints the relevance scores.
"""
# Send POST request to the vLLM server's score endpoint
response = requests.post(url, headers=headers, json=data)
# Check if the request was successful
if response.status_code == 200:
print("Request successful!")
# Pretty print the JSON response containing relevance scores
# The response includes scores for each document's relevance to the query
print(json.dumps(response.json(), indent=2))
else:
# Handle request failure
print(f"Request failed with status code: {response.status_code}")
print(response.text)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,3 @@
A: {{ (messages | selectattr("role", "eq", "query") | first).content }}
B: {{ (messages | selectattr("role", "eq", "document") | first).content }}
Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.

View File

@@ -0,0 +1,8 @@
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
query: {{ (messages | selectattr("role", "eq", "query") | first).content }}
document: {{ (messages | selectattr("role", "eq", "document") | first).content }}
You are a search relevance expert who evaluates how well documents match search queries. For each query-document pair, carefully analyze the semantic relationship between them, then provide your binary relevance judgment (0 for not relevant, 1 for relevant).
Relevance:<|im_end|>
<|im_start|>assistant

View File

@@ -0,0 +1,11 @@
<|im_start|>system
Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>
<|im_start|>user
<Instruct>: {{ messages | selectattr("role", "eq", "system") | map(attribute="content") | first | default("Given a web search query, retrieve relevant passages that answer the query") }}
<Query>: {{ messages | selectattr("role", "eq", "query") | map(attribute="content") | first }}
<Document>: {{ messages | selectattr("role", "eq", "document") | map(attribute="content") | first }}<|im_end|>
<|im_start|>assistant
<think>
</think>

View File

@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
from argparse import Namespace
from pathlib import Path
from typing import Any
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
def parse_args():
"""Parse command line arguments for the reranking example.
This function sets up the argument parser with default values
specific to reranking models, including the model name and
runner type.
"""
parser = FlexibleArgumentParser()
# Add all EngineArgs command line arguments to the parser
parser = EngineArgs.add_cli_args(parser)
# Set default values specific to this reranking example
# These defaults ensure the script works out-of-the-box for reranking tasks
parser.set_defaults(
model="nvidia/llama-nemotron-rerank-1b-v2", # Default reranking model
runner="pooling", # Required for cross-encoder/reranking models
trust_remote_code=True, # Allow loading models with custom code
)
return parser.parse_args()
def get_chat_template(model: str) -> str:
"""Load the appropriate chat template for the specified model.
Reranking models require specific prompt templates to format
query-document pairs correctly. This function maps model names
to their corresponding template files.
"""
# Directory containing all chat template files
template_home = Path(__file__).parent / "template"
# Mapping from model names to their corresponding template files
# Each reranking model has its own specific prompt format
model_name_to_template_path_map = {
"BAAI/bge-reranker-v2-gemma": "bge-reranker-v2-gemma.jinja",
"Qwen/Qwen3-Reranker-0.6B": "qwen3_reranker.jinja",
"Qwen/Qwen3-Reranker-4B": "qwen3_reranker.jinja",
"Qwen/Qwen3-Reranker-8B": "qwen3_reranker.jinja",
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls": "qwen3_reranker.jinja",
"tomaarsen/Qwen3-Reranker-4B-seq-cls": "qwen3_reranker.jinja",
"tomaarsen/Qwen3-Reranker-8B-seq-cls": "qwen3_reranker.jinja",
"mixedbread-ai/mxbai-rerank-base-v2": "mxbai_rerank_v2.jinja",
"mixedbread-ai/mxbai-rerank-large-v2": "mxbai_rerank_v2.jinja",
"nvidia/llama-nemotron-rerank-1b-v2": "nemotron-rerank.jinja",
}
# Get the template filename for the specified model
template_path = model_name_to_template_path_map.get(model)
if template_path is None:
raise ValueError(f"This demo does not support model name: {model}.")
# Read and return the template content
return (template_home / template_path).read_text()
def get_hf_overrides(model: str) -> dict[str, Any]:
"""Convert Large Language Models (LLMs) to Sequence Classification models.
note:
Some reranking models require special configuration overrides to work
correctly with vLLM's score API.
Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/qwen3_reranker_offline.py
Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/convert_model_to_seq_cls.py
"""
model_name_to_hf_overrides_map = {
"BAAI/bge-reranker-v2-gemma": {
"architectures": ["GemmaForSequenceClassification"],
"classifier_from_token": ["Yes"],
"method": "no_post_processing",
},
"Qwen/Qwen3-Reranker-0.6B": {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
},
"Qwen/Qwen3-Reranker-4B": {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
},
"Qwen/Qwen3-Reranker-8B": {
"architectures": ["Qwen3ForSequenceClassification"],
"classifier_from_token": ["no", "yes"],
"is_original_qwen3_reranker": True,
},
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls": {},
"tomaarsen/Qwen3-Reranker-4B-seq-cls": {},
"tomaarsen/Qwen3-Reranker-8B-seq-cls": {},
"mixedbread-ai/mxbai-rerank-base-v2": {
"architectures": ["Qwen2ForSequenceClassification"],
"classifier_from_token": ["0", "1"],
"method": "from_2_way_softmax",
},
"mixedbread-ai/mxbai-rerank-large-v2": {
"architectures": ["Qwen2ForSequenceClassification"],
"classifier_from_token": ["0", "1"],
"method": "from_2_way_softmax",
},
"nvidia/llama-nemotron-rerank-1b-v2": {},
}
hf_overrides = model_name_to_hf_overrides_map.get(model)
if hf_overrides is None:
raise ValueError(f"This demo does not support model name: {model}.")
return hf_overrides
def main(args: Namespace):
"""Main execution function for the reranking example."""
# Get the overrides for the specified model
args.hf_overrides = get_hf_overrides(args.model)
# Initialize the LLM with all provided arguments
llm = LLM(**vars(args))
# Example query for demonstration
query = "how much protein should a female eat?"
# Example documents to be reranked based on relevance to the query
documents = [
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
"Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
]
# Load the appropriate chat template for the selected model
# The template formats query-document pairs for the reranking model
chat_template = get_chat_template(args.model)
# Score documents based on relevance to the query
# The score method returns relevance scores for each document
outputs = llm.score(query, documents, chat_template=chat_template)
# Display the relevance scores
# Higher scores indicate more relevant documents
print("-" * 30)
print([output.outputs.score for output in outputs])
print("-" * 30)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
Example of using the rerank API with template.
This script demonstrates how to interact with a vLLM server running
a reranking model via the REST API.
Before running this script, start the vLLM server with one of the
supported reranking models using the commands below.
note:
Some reranking models require special configuration overrides to work correctly
with vLLM's score API.
Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/qwen3_reranker_online.py
Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/convert_model_to_seq_cls.py
run:
vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}' --chat-template examples/pooling/score/template/bge-reranker-v2-gemma.jinja
vllm serve tomaarsen/Qwen3-Reranker-0.6B-seq-cls --chat-template examples/pooling/score/template/qwen3_reranker.jinja
vllm serve mixedbread-ai/mxbai-rerank-base-v2 --hf_overrides '{"architectures": ["Qwen2ForSequenceClassification"],"classifier_from_token": ["0", "1"], "method": "from_2_way_softmax"}' --chat-template examples/pooling/score/template/mxbai_rerank_v2.jinja
vllm serve nvidia/llama-nemotron-rerank-1b-v2 --runner pooling --trust-remote-code --chat-template examples/pooling/score/template/nemotron-rerank.jinja
vllm serve Qwen/Qwen3-Reranker-0.6B --runner pooling --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' --chat-template examples/pooling/score/template/qwen3_reranker.jinja
"""
import json
import requests
# URL of the vLLM server's rerank endpoint
# Default vLLM server runs on localhost port 8000
url = "http://127.0.0.1:8000/rerank"
# HTTP headers for the request
headers = {"accept": "application/json", "Content-Type": "application/json"}
# Example query & documents
query = "how much protein should a female eat?"
documents = [
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
"Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
]
# Request payload for the rerank API
data = {
"model": "nvidia/llama-nemotron-rerank-1b-v2", # Model to use for reranking
"query": query, # The query to score documents against
"documents": documents, # List of documents to be scored
}
def main():
"""Main function to send a rerank request to the vLLM server.
This function sends a POST request to the /rerank endpoint with
the query and documents, then prints the relevance scores.
"""
# Send POST request to the vLLM server's rerank endpoint
response = requests.post(url, headers=headers, json=data)
# Check if the request was successful
if response.status_code == 200:
print("Request successful!")
# Pretty print the JSON response containing relevance scores
# The response includes scores for each document's relevance to the query
print(json.dumps(response.json(), indent=2))
else:
# Handle request failure
print(f"Request failed with status code: {response.status_code}")
print(response.text)
if __name__ == "__main__":
main()

View File

@@ -45,7 +45,11 @@ from transformers import (
)
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
from tests.models.utils import (
TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
softmax,
)
from vllm import LLM, SamplingParams, envs
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
@@ -513,7 +517,7 @@ class HfRunner:
elif problem_type == "multi_label_classification":
logits = output.logits.sigmoid()[0].tolist()
else:
logits = output.logits.softmax(dim=-1)[0].tolist()
logits = softmax(output.logits)[0].tolist()
outputs.append(logits)
return outputs

View File

@@ -3,13 +3,16 @@
import tempfile
from pathlib import Path
from typing import Any
import mteb
import numpy as np
import requests
import torch
from mteb.models import ModelMeta
from torch.utils.data import DataLoader
from tests.conftest import HfRunner
from tests.models.utils import (
RerankModelInfo,
get_vllm_extra_kwargs,
@@ -67,6 +70,12 @@ class VllmMtebCrossEncoder(MtebCrossEncoderMixin):
queries = [text for batch in inputs1 for text in batch["text"]]
corpus = [text for batch in inputs2 for text in batch["text"]]
# Hoping to discover potential scheduling
# issues by randomizing the order.
r = self.rng.permutation(len(queries))
queries = [queries[i] for i in r]
corpus = [corpus[i] for i in r]
outputs = self.llm.score(
queries,
corpus,
@@ -75,6 +84,7 @@ class VllmMtebCrossEncoder(MtebCrossEncoderMixin):
chat_template=self.chat_template,
)
scores = np.array(outputs)
scores = scores[np.argsort(r)]
return scores
@@ -84,7 +94,6 @@ class ScoreClientMtebEncoder(MtebCrossEncoderMixin):
def __init__(self, model_name: str, url):
self.model_name = model_name
self.url = url
self.rng = np.random.default_rng(seed=42)
def predict(
self,
@@ -130,6 +139,50 @@ class RerankClientMtebEncoder(ScoreClientMtebEncoder):
return response["results"][0]["relevance_score"]
class HFMtebCrossEncoder(MtebCrossEncoderMixin, HfRunner):
chat_template: str | None = None
def __init__(self, model_name: str, dtype: str = "auto", **kwargs: Any) -> None:
HfRunner.__init__(
self, model_name=model_name, is_cross_encoder=True, dtype=dtype, **kwargs
)
@torch.no_grad
def predict(
self,
inputs1: DataLoader[mteb.types.BatchedInput],
inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
queries = [text for batch in inputs1 for text in batch["text"]]
corpus = [text for batch in inputs2 for text in batch["text"]]
if self.chat_template is not None:
tokenizer = self.model.tokenizer
prompts = []
for query, document in zip(queries, corpus):
conversation = [
{"role": "query", "content": query},
{"role": "document", "content": document},
]
prompt = tokenizer.apply_chat_template(
conversation=conversation,
tools=None,
chat_template=self.chat_template,
tokenize=False,
)
prompts.append(prompt)
outputs_list = HfRunner.classify(self, prompts)
scores = np.array(outputs_list).squeeze(-1)
return scores
else:
prompts = list(zip(queries, corpus))
outputs_tensor = HfRunner.predict(self, prompts, show_progress_bar=False)
return outputs_tensor.cpu().numpy()
def run_mteb_rerank(cross_encoder: mteb.CrossEncoderProtocol, tasks, languages):
with tempfile.TemporaryDirectory() as prediction_folder:
bm25s = mteb.get_model("bm25s")
@@ -168,31 +221,21 @@ def run_mteb_rerank(cross_encoder: mteb.CrossEncoderProtocol, tasks, languages):
return main_score
def mteb_test_rerank_models_hf(
hf_runner, model_name, hf_dtype="float32", hf_model_callback=None
):
with hf_runner(model_name, is_cross_encoder=True, dtype=hf_dtype) as hf_model:
if hf_model_callback is not None:
hf_model_callback(hf_model)
st_main_score = run_mteb_rerank(
hf_model, tasks=MTEB_RERANK_TASKS, languages=MTEB_RERANK_LANGS
)
st_dtype = next(hf_model.model.model.parameters()).dtype
return st_main_score, st_dtype
def mteb_test_rerank_models(
hf_runner,
vllm_runner,
model_info: RerankModelInfo,
hf_runner=HFMtebCrossEncoder,
vllm_extra_kwargs=None,
hf_model_callback=None,
vllm_mteb_encoder=VllmMtebCrossEncoder,
atol=MTEB_RERANK_TOL,
):
vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)
# Maybe load chat_template.
chat_template: str | None = None
if model_info.chat_template_name is not None:
chat_template = (template_home / model_info.chat_template_name).read_text()
with vllm_runner(
model_info.name,
runner="pooling",
@@ -201,6 +244,7 @@ def mteb_test_rerank_models(
**vllm_extra_kwargs,
) as vllm_model:
model_config = vllm_model.llm.llm_engine.model_config
vllm_model.chat_template = chat_template
# Confirm whether vllm is using the correct architecture
if model_info.architecture:
@@ -209,12 +253,6 @@ def mteb_test_rerank_models(
# Score API is only enabled for num_labels == 1
assert model_config.hf_config.num_labels == 1
# Maybe load chat_template.
chat_template: str | None = None
if model_info.chat_template_name is not None:
chat_template = (template_home / model_info.chat_template_name).read_text()
vllm_model.chat_template = chat_template
# Confirm whether the important configs in model_config are correct.
if model_info.pooling_type is not None:
assert model_config.pooler_config.pooling_type == model_info.pooling_type
@@ -242,9 +280,14 @@ def mteb_test_rerank_models(
# Accelerate mteb test by setting
# SentenceTransformers mteb score to a constant
if model_info.mteb_score is None:
st_main_score, st_dtype = mteb_test_rerank_models_hf(
hf_runner, model_info.name, model_info.hf_dtype, hf_model_callback
)
with hf_runner(model_info.name, dtype=model_info.hf_dtype) as hf_model:
hf_model.chat_template = chat_template
st_main_score = run_mteb_rerank(
hf_model,
tasks=MTEB_RERANK_TASKS,
languages=MTEB_RERANK_LANGS,
)
st_dtype = next(hf_model.model.model.parameters()).dtype
else:
st_main_score = model_info.mteb_score
st_dtype = "Constant"

View File

@@ -112,7 +112,5 @@ def test_embed_models_correctness(
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(
hf_runner, vllm_runner, model_info: RerankModelInfo
) -> None:
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
mteb_test_rerank_models(vllm_runner, model_info)

View File

@@ -11,40 +11,60 @@ from torch.utils.data import DataLoader
from tests.conftest import HfRunner
from tests.models.utils import RerankModelInfo
from .mteb_score_utils import VllmMtebCrossEncoder, mteb_test_rerank_models
from .mteb_score_utils import (
MtebCrossEncoderMixin,
mteb_test_rerank_models,
)
RERANK_MODELS = [
RerankModelInfo(
"BAAI/bge-reranker-v2-gemma",
architecture="GemmaForSequenceClassification",
mteb_score=0.33757,
hf_overrides={
"architectures": ["GemmaForSequenceClassification"],
"classifier_from_token": ["Yes"],
"method": "no_post_processing",
},
mteb_score=0.33757,
pooling_type="LAST",
attn_type="decoder",
is_prefix_caching_supported=True,
is_chunked_prefill_supported=True,
chat_template_name="bge-reranker-v2-gemma.jinja",
),
]
PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
class GemmaRerankerHfRunner(HfRunner):
class GemmaRerankerHfRunner(MtebCrossEncoderMixin, HfRunner):
def __init__(
self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any
) -> None:
from transformers import AutoModelForCausalLM, AutoTokenizer
super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
HfRunner.__init__(
self,
model_name=model_name,
auto_cls=AutoModelForCausalLM,
dtype=dtype,
**kwargs,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes")
@torch.no_grad()
def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
@torch.no_grad
def predict(
self,
inputs1: DataLoader[mteb.types.BatchedInput],
inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
queries = [text for batch in inputs1 for text in batch["text"]]
corpus = [text for batch in inputs2 for text in batch["text"]]
def get_inputs(pairs, tokenizer, prompt=None):
if prompt is None:
prompt = PROMPT
@@ -89,8 +109,8 @@ class GemmaRerankerHfRunner(HfRunner):
)
scores = []
for query, doc, *_ in prompts:
pairs = [(query, doc)]
for query, document in zip(queries, corpus):
pairs = [(query, document)]
inputs = get_inputs(pairs, self.tokenizer)
inputs = inputs.to(self.model.device)
_n_tokens = inputs["input_ids"].shape[1]
@@ -107,41 +127,10 @@ class GemmaRerankerHfRunner(HfRunner):
return torch.Tensor(scores)
class GemmaMtebEncoder(VllmMtebCrossEncoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.query_template = "A: {query}\n"
self.document_template = "B: {doc}\n{prompt}"
def predict(
self,
inputs1: DataLoader[mteb.types.BatchedInput],
inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
queries = [
self.query_template.format(query=text)
for batch in inputs1
for text in batch["text"]
]
corpus = [
self.document_template.format(doc=text, prompt=PROMPT)
for batch in inputs2
for text in batch["text"]
]
outputs = self.llm.score(
queries, corpus, truncate_prompt_tokens=-1, use_tqdm=False
)
scores = np.array(outputs)
return scores
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
mteb_test_rerank_models(
GemmaRerankerHfRunner,
vllm_runner,
model_info,
vllm_mteb_encoder=GemmaMtebEncoder,
hf_runner=GemmaRerankerHfRunner,
)

View File

@@ -11,27 +11,26 @@ from .mteb_score_utils import mteb_test_rerank_models
RERANK_MODELS = [
RerankModelInfo(
"cross-encoder/ms-marco-TinyBERT-L-2-v2",
mteb_score=0.32898,
architecture="BertForSequenceClassification",
pooling_type="CLS",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
mteb_score=0.32898,
),
RerankModelInfo(
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
mteb_score=0.25736,
architecture="Qwen3ForSequenceClassification",
pooling_type="LAST",
attn_type="decoder",
is_prefix_caching_supported=True,
is_chunked_prefill_supported=True,
chat_template_name="qwen3_reranker.jinja",
mteb_score=0.33459,
),
]
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(
hf_runner, vllm_runner, model_info: RerankModelInfo
) -> None:
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
mteb_test_rerank_models(vllm_runner, model_info)

View File

@@ -143,7 +143,5 @@ def test_embed_models_correctness(
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(
hf_runner, vllm_runner, model_info: RerankModelInfo
) -> None:
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
mteb_test_rerank_models(vllm_runner, model_info)

View File

@@ -72,10 +72,8 @@ def test_embed_models_correctness(
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(
hf_runner, vllm_runner, model_info: RerankModelInfo
) -> None:
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
mteb_test_rerank_models(vllm_runner, model_info)
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)

View File

@@ -2,13 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import mteb
import numpy as np
import pytest
import torch
from torch.utils.data import DataLoader
from tests.conftest import HfRunner
from tests.models.utils import RerankModelInfo
from .mteb_score_utils import mteb_test_rerank_models
from .mteb_score_utils import MtebCrossEncoderMixin, mteb_test_rerank_models
mxbai_rerank_hf_overrides = {
"architectures": ["Qwen2ForSequenceClassification"],
@@ -21,50 +24,69 @@ RERANK_MODELS = [
"mixedbread-ai/mxbai-rerank-base-v2",
architecture="Qwen2ForSequenceClassification",
hf_overrides=mxbai_rerank_hf_overrides,
mteb_score=0.273,
pooling_type="LAST",
attn_type="decoder",
is_prefix_caching_supported=True,
is_chunked_prefill_supported=True,
chat_template_name="mxbai_rerank_v2.jinja",
mteb_score=0.33651,
enable_test=True,
),
RerankModelInfo(
"mixedbread-ai/mxbai-rerank-large-v2",
architecture="Qwen2ForSequenceClassification",
hf_overrides=mxbai_rerank_hf_overrides,
chat_template_name="mxbai_rerank_v2.jinja",
enable_test=False,
),
]
class MxbaiRerankerHfRunner(HfRunner):
class MxbaiRerankerHfRunner(MtebCrossEncoderMixin, HfRunner):
def __init__(
self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any
) -> None:
from transformers import AutoModelForCausalLM, AutoTokenizer
super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
HfRunner.__init__(
self,
model_name=model_name,
auto_cls=AutoModelForCausalLM,
dtype=dtype,
**kwargs,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.yes_loc = self.tokenizer.convert_tokens_to_ids("1")
self.no_loc = self.tokenizer.convert_tokens_to_ids("0")
def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
def process_inputs(pairs):
inputs = self.tokenizer(
pairs,
padding=False,
truncation="longest_first",
return_attention_mask=False,
)
for i, ele in enumerate(inputs["input_ids"]):
inputs["input_ids"][i] = ele
inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt")
for key in inputs:
inputs[key] = inputs[key].to(self.model.device)
return inputs
@torch.no_grad
def predict(
self,
inputs1: DataLoader[mteb.types.BatchedInput],
inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
queries = [text for batch in inputs1 for text in batch["text"]]
corpus = [text for batch in inputs2 for text in batch["text"]]
tokenizer = self.tokenizer
prompts = []
for query, document in zip(queries, corpus):
conversation = [
{"role": "query", "content": query},
{"role": "document", "content": document},
]
prompt = tokenizer.apply_chat_template(
conversation=conversation,
tools=None,
chat_template=self.chat_template,
tokenize=False,
)
prompts.append(prompt)
@torch.no_grad()
def compute_logits(inputs):
logits = self.model(**inputs).logits[:, -1, :]
yes_logits = logits[:, self.yes_loc]
@@ -74,9 +96,9 @@ class MxbaiRerankerHfRunner(HfRunner):
return scores
scores = []
for query, doc, *_ in prompts:
pairs = [(query, doc)]
inputs = process_inputs(pairs)
for prompt in prompts:
inputs = tokenizer([prompt], return_tensors="pt")
inputs = self.wrap_device(inputs)
score = compute_logits(inputs)
scores.append(score[0].item())
return torch.Tensor(scores)
@@ -84,4 +106,4 @@ class MxbaiRerankerHfRunner(HfRunner):
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info)
mteb_test_rerank_models(vllm_runner, model_info, hf_runner=MxbaiRerankerHfRunner)

View File

@@ -46,7 +46,5 @@ def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(
hf_runner, vllm_runner, model_info: RerankModelInfo
) -> None:
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
mteb_test_rerank_models(vllm_runner, model_info)

View File

@@ -1,15 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
from typing import Any
import mteb
import numpy as np
import pytest
import torch
from torch.utils.data import DataLoader
from tests.conftest import HfRunner
from tests.models.utils import RerankModelInfo
from tests.utils import multi_gpu_test
from .mteb_score_utils import mteb_test_rerank_models
from .mteb_score_utils import MtebCrossEncoderMixin, mteb_test_rerank_models
qwen3_reranker_hf_overrides = {
"architectures": ["Qwen3ForSequenceClassification"],
@@ -21,51 +25,71 @@ RERANK_MODELS = [
RerankModelInfo(
"Qwen/Qwen3-Reranker-0.6B",
architecture="Qwen3ForSequenceClassification",
mteb_score=0.25736,
hf_overrides=qwen3_reranker_hf_overrides,
chat_template_name="qwen3_reranker.jinja",
pooling_type="LAST",
attn_type="decoder",
is_prefix_caching_supported=True,
is_chunked_prefill_supported=True,
mteb_score=0.33459,
enable_test=True,
),
RerankModelInfo(
"Qwen/Qwen3-Reranker-4B",
architecture="Qwen3ForSequenceClassification",
chat_template_name="qwen3_reranker.jinja",
hf_overrides=qwen3_reranker_hf_overrides,
enable_test=False,
),
]
class Qwen3RerankerHfRunner(HfRunner):
class Qwen3RerankerHfRunner(MtebCrossEncoderMixin, HfRunner):
def __init__(
self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any
) -> None:
from transformers import AutoModelForCausalLM, AutoTokenizer
super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM)
HfRunner.__init__(
self,
model_name=model_name,
auto_cls=AutoModelForCausalLM,
dtype=dtype,
**kwargs,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
self.max_length = 40960
def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
def process_inputs(pairs):
inputs = self.tokenizer(
pairs,
padding=False,
truncation="longest_first",
return_attention_mask=False,
@torch.no_grad
def predict(
self,
inputs1: DataLoader[mteb.types.BatchedInput],
inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
queries = [text for batch in inputs1 for text in batch["text"]]
corpus = [text for batch in inputs2 for text in batch["text"]]
tokenizer = self.tokenizer
prompts = []
for query, document in zip(queries, corpus):
conversation = [
{"role": "query", "content": query},
{"role": "document", "content": document},
]
prompt = tokenizer.apply_chat_template(
conversation=conversation,
tools=None,
chat_template=self.chat_template,
tokenize=False,
)
for i, ele in enumerate(inputs["input_ids"]):
inputs["input_ids"][i] = ele
inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt")
for key in inputs:
inputs[key] = inputs[key].to(self.model.device)
return inputs
prompts.append(prompt)
@torch.no_grad()
def compute_logits(inputs):
batch_scores = self.model(**inputs).logits[:, -1, :]
true_vector = batch_scores[:, self.token_true_id]
@@ -76,9 +100,9 @@ class Qwen3RerankerHfRunner(HfRunner):
return scores
scores = []
for query, doc, *_ in prompts:
pairs = [(query, doc)]
inputs = process_inputs(pairs)
for prompt in prompts:
inputs = tokenizer([prompt], return_tensors="pt")
inputs = self.wrap_device(inputs)
score = compute_logits(inputs)
scores.append(score[0].item())
return torch.Tensor(scores)
@@ -86,7 +110,7 @@ class Qwen3RerankerHfRunner(HfRunner):
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info)
mteb_test_rerank_models(vllm_runner, model_info, hf_runner=Qwen3RerankerHfRunner)
@pytest.mark.parametrize("model_info", RERANK_MODELS)
@@ -99,5 +123,8 @@ def test_rerank_models_mteb_tp(vllm_runner, model_info: RerankModelInfo) -> None
}
mteb_test_rerank_models(
Qwen3RerankerHfRunner, vllm_runner, model_info, vllm_extra_kwargs
vllm_runner,
model_info,
vllm_extra_kwargs=vllm_extra_kwargs,
hf_runner=Qwen3RerankerHfRunner,
)

View File

@@ -251,7 +251,7 @@ class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
tokens = getattr(config, "classifier_from_token", None)
assert tokens is not None and len(tokens) == 2, (
"Try loading the original Qwen3 Reranker?, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/offline_reranker.py"
"https://github.com/vllm-project/vllm/tree/main/examples/pooling/score/qwen3_reranker_offline.py"
)
model_config.hf_config.method = "from_2_way_softmax"