[Frontend][3/n] Improve pooling entrypoints | scoring. (#28631)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
@@ -10,9 +10,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionStreamResponse,
|
||||
ChatMessage,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
||||
|
||||
|
||||
async def accumulate_streaming_response(
|
||||
|
||||
@@ -105,7 +105,7 @@ def test_pooling_params(llm: LLM):
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_score_api(llm: LLM):
|
||||
err_msg = "Score API is only enabled for num_labels == 1."
|
||||
err_msg = "Scoring API is only enabled for num_labels == 1."
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
llm.score("ping", "pong", use_tqdm=False)
|
||||
|
||||
|
||||
@@ -390,7 +390,7 @@ async def test_use_activation(server: RemoteOpenAIServer, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_score(server: RemoteOpenAIServer, model_name: str):
|
||||
# score api is only enabled for num_labels == 1.
|
||||
# Scoring API is only enabled for num_labels == 1.
|
||||
response = requests.post(
|
||||
server.url_for("score"),
|
||||
json={
|
||||
@@ -405,7 +405,7 @@ async def test_score(server: RemoteOpenAIServer, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_rerank(server: RemoteOpenAIServer, model_name: str):
|
||||
# rerank api is only enabled for num_labels == 1.
|
||||
# Scoring API is only enabled for num_labels == 1.
|
||||
response = requests.post(
|
||||
server.url_for("rerank"),
|
||||
json={
|
||||
|
||||
@@ -7,7 +7,7 @@ import requests
|
||||
from tests.entrypoints.pooling.scoring.util import EncoderScoringHfRunner
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse
|
||||
from vllm.entrypoints.pooling.scoring.protocol import RerankResponse, ScoreResponse
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_NAME = "BAAI/bge-base-en-v1.5"
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch.nn.functional as F
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse
|
||||
from vllm.entrypoints.pooling.scoring.protocol import RerankResponse, ScoreResponse
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_NAME = "BAAI/bge-reranker-base"
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import VLLM_PATH, RemoteOpenAIServer
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse
|
||||
from vllm.entrypoints.pooling.scoring.protocol import RerankResponse, ScoreResponse
|
||||
from vllm.multimodal.utils import encode_image_url, fetch_image
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import weakref
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .util import make_base64_image, make_image_mm_param
|
||||
|
||||
MODEL_NAME = "vidore/colpali-v1.3-hf"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def llm():
|
||||
# ROCm: Use FLEX_ATTENTION backend as it's the only attention backend
|
||||
# that supports encoder-only models on ROCm.
|
||||
attention_config = None
|
||||
if current_platform.is_rocm():
|
||||
attention_config = {"backend": "FLEX_ATTENTION"}
|
||||
|
||||
# pytest caches the fixture so we use weakref.proxy to
|
||||
# enable garbage collection
|
||||
llm = LLM(
|
||||
model=MODEL_NAME,
|
||||
max_num_batched_tokens=32768,
|
||||
tensor_parallel_size=1,
|
||||
gpu_memory_utilization=0.75,
|
||||
enforce_eager=True,
|
||||
seed=0,
|
||||
attention_config=attention_config,
|
||||
)
|
||||
|
||||
yield weakref.proxy(llm)
|
||||
|
||||
del llm
|
||||
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_query_text_vs_docs_image(llm):
|
||||
"""Score a text query against image documents via the multimodal path."""
|
||||
red_image = make_base64_image(64, 64, color=(255, 0, 0))
|
||||
blue_image = make_base64_image(64, 64, color=(0, 0, 255))
|
||||
|
||||
query = "Describe the red object"
|
||||
image_docs = [
|
||||
make_image_mm_param(red_image),
|
||||
make_image_mm_param(blue_image),
|
||||
]
|
||||
|
||||
scores = llm.score(query, image_docs)
|
||||
|
||||
assert len(scores) == 2
|
||||
assert scores[0].outputs.score > scores[1].outputs.score
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_query_text_vs_docs_mix(llm) -> None:
|
||||
"""Score a text query against a mix of text and image documents."""
|
||||
red_image = make_base64_image(64, 64, color=(255, 0, 0))
|
||||
|
||||
query = "What is the capital of France?"
|
||||
documents: list = [
|
||||
"The capital of France is Paris.",
|
||||
make_image_mm_param(red_image),
|
||||
]
|
||||
|
||||
scores = llm.score(query, documents)
|
||||
|
||||
assert len(scores) == 2
|
||||
assert scores[0].outputs.score > scores[1].outputs.score
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_query_image_vs_docs_text(llm) -> None:
|
||||
"""Score an image query against text documents."""
|
||||
red_image = make_base64_image(64, 64, color=(255, 0, 0))
|
||||
image_query = make_image_mm_param(red_image, text="red color")
|
||||
|
||||
documents = [
|
||||
"Describe the red object.",
|
||||
"The capital of France is Paris.",
|
||||
]
|
||||
|
||||
scores = llm.score(image_query, documents)
|
||||
|
||||
assert len(scores) == 2
|
||||
assert scores[0].outputs.score > scores[1].outputs.score
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse
|
||||
from vllm.entrypoints.pooling.scoring.protocol import RerankResponse, ScoreResponse
|
||||
|
||||
from .util import ColBERTScoringHfRunner
|
||||
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.entrypoints.pooling.scoring.util import (
|
||||
make_base64_image,
|
||||
make_image_mm_param,
|
||||
)
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.pooling.scoring.protocol import RerankResponse, ScoreResponse
|
||||
|
||||
MODEL_NAME = "vidore/colpali-v1.3-hf"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
with RemoteOpenAIServer(MODEL_NAME, []) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_api_query_text_vs_docs_image(server: RemoteOpenAIServer):
|
||||
query = "Describe the red object"
|
||||
|
||||
red_image = make_base64_image(64, 64, color=(255, 0, 0))
|
||||
blue_image = make_base64_image(64, 64, color=(0, 0, 255))
|
||||
|
||||
documents = [
|
||||
make_image_mm_param(red_image),
|
||||
make_image_mm_param(blue_image),
|
||||
]
|
||||
|
||||
score_response = requests.post(
|
||||
server.url_for("score"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"queries": query,
|
||||
"documents": documents,
|
||||
},
|
||||
)
|
||||
score_response.raise_for_status()
|
||||
scores = ScoreResponse.model_validate(score_response.json())
|
||||
|
||||
assert scores.id is not None
|
||||
assert scores.data is not None
|
||||
assert len(scores.data) == 2
|
||||
assert scores.data[0].score > scores.data[1].score
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_api_query_text_vs_docs_mix(server: RemoteOpenAIServer):
|
||||
red_image = make_base64_image(64, 64, color=(255, 0, 0))
|
||||
query = "What is the capital of France?"
|
||||
documents: list = [
|
||||
"The capital of France is Paris.",
|
||||
make_image_mm_param(red_image),
|
||||
]
|
||||
|
||||
score_response = requests.post(
|
||||
server.url_for("score"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"queries": query,
|
||||
"documents": documents,
|
||||
},
|
||||
)
|
||||
score_response.raise_for_status()
|
||||
scores = ScoreResponse.model_validate(score_response.json())
|
||||
|
||||
assert scores.id is not None
|
||||
assert scores.data is not None
|
||||
assert len(scores.data) == 2
|
||||
assert scores.data[0].score > scores.data[1].score
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_score_api_query_image_vs_docs_text(server: RemoteOpenAIServer):
|
||||
red_image = make_base64_image(64, 64, color=(255, 0, 0))
|
||||
image_query = make_image_mm_param(red_image, text="red color")
|
||||
|
||||
documents = [
|
||||
"Describe the red object.",
|
||||
"The capital of France is Paris.",
|
||||
]
|
||||
|
||||
score_response = requests.post(
|
||||
server.url_for("score"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"queries": image_query,
|
||||
"documents": documents,
|
||||
},
|
||||
)
|
||||
score_response.raise_for_status()
|
||||
scores = ScoreResponse.model_validate(score_response.json())
|
||||
|
||||
assert scores.id is not None
|
||||
assert scores.data is not None
|
||||
assert len(scores.data) == 2
|
||||
assert scores.data[0].score > scores.data[1].score
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rerank_api_query_text_vs_docs_image(server: RemoteOpenAIServer):
|
||||
query = "Describe the red object"
|
||||
|
||||
red_image = make_base64_image(64, 64, color=(255, 0, 0))
|
||||
blue_image = make_base64_image(64, 64, color=(0, 0, 255))
|
||||
|
||||
documents = [
|
||||
make_image_mm_param(red_image),
|
||||
make_image_mm_param(blue_image),
|
||||
]
|
||||
|
||||
rerank_response = requests.post(
|
||||
server.url_for("rerank"),
|
||||
json={"model": MODEL_NAME, "query": query, "documents": documents},
|
||||
)
|
||||
|
||||
rerank_response.raise_for_status()
|
||||
rerank = RerankResponse.model_validate(rerank_response.json())
|
||||
|
||||
assert rerank.id is not None
|
||||
assert rerank.results is not None
|
||||
assert len(rerank.results) == 2
|
||||
|
||||
red_result = next(r for r in rerank.results if r.index == 0)
|
||||
blue_result = next(r for r in rerank.results if r.index == 1)
|
||||
|
||||
assert red_result.relevance_score > blue_result.relevance_score
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rerank_api_query_text_vs_docs_mix(server: RemoteOpenAIServer):
|
||||
red_image = make_base64_image(64, 64, color=(255, 0, 0))
|
||||
query = "What is the capital of France?"
|
||||
documents: list = [
|
||||
"The capital of France is Paris.",
|
||||
make_image_mm_param(red_image),
|
||||
]
|
||||
|
||||
rerank_response = requests.post(
|
||||
server.url_for("rerank"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
},
|
||||
)
|
||||
rerank_response.raise_for_status()
|
||||
rerank = RerankResponse.model_validate(rerank_response.json())
|
||||
|
||||
assert rerank.id is not None
|
||||
assert rerank.results is not None
|
||||
assert len(rerank.results) == 2
|
||||
|
||||
result0 = next(r for r in rerank.results if r.index == 0)
|
||||
result1 = next(r for r in rerank.results if r.index == 1)
|
||||
|
||||
assert result0.relevance_score > result1.relevance_score
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rerank_api_query_image_vs_docs_text(server: RemoteOpenAIServer):
|
||||
red_image = make_base64_image(64, 64, color=(255, 0, 0))
|
||||
image_query = make_image_mm_param(red_image, text="red color")
|
||||
|
||||
documents = [
|
||||
"Describe the red object.",
|
||||
"The capital of France is Paris.",
|
||||
]
|
||||
|
||||
rerank_response = requests.post(
|
||||
server.url_for("rerank"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"query": image_query,
|
||||
"documents": documents,
|
||||
},
|
||||
)
|
||||
rerank_response.raise_for_status()
|
||||
rerank = RerankResponse.model_validate(rerank_response.json())
|
||||
|
||||
assert rerank.id is not None
|
||||
assert rerank.results is not None
|
||||
assert len(rerank.results) == 2
|
||||
|
||||
result0 = next(r for r in rerank.results if r.index == 0)
|
||||
result1 = next(r for r in rerank.results if r.index == 1)
|
||||
|
||||
assert result0.relevance_score > result1.relevance_score
|
||||
@@ -1,353 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateResolutionError
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
get_score_prompt,
|
||||
)
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
|
||||
# A cross-encoder model for testing
|
||||
CROSS_ENCODER_MODEL_ID = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
|
||||
|
||||
def assert_prompt_tokenization_consistent(
|
||||
tokenizer, full_prompt, engine_prompt, add_special_tokens=True
|
||||
):
|
||||
"""Verify that engine_prompt token_ids match tokenizing full_prompt."""
|
||||
expected_ids = tokenizer(full_prompt, add_special_tokens=add_special_tokens)[
|
||||
"input_ids"
|
||||
]
|
||||
actual_ids = engine_prompt["prompt_token_ids"]
|
||||
assert actual_ids == expected_ids, (
|
||||
f"Token IDs don't match.\nExpected: {expected_ids}\nActual: {actual_ids}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def cross_encoder_model_config():
|
||||
return ModelConfig(
|
||||
CROSS_ENCODER_MODEL_ID,
|
||||
runner="pooling",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def cross_encoder_tokenizer(cross_encoder_model_config):
|
||||
return get_tokenizer(
|
||||
CROSS_ENCODER_MODEL_ID,
|
||||
trust_remote_code=cross_encoder_model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def llm_reranker_model_config():
|
||||
"""Model config for LLM-as-reranker style (no pad token)."""
|
||||
config = ModelConfig(
|
||||
CROSS_ENCODER_MODEL_ID,
|
||||
runner="pooling",
|
||||
)
|
||||
# use_sep_token is a property that reads from hf_config,
|
||||
# so we set it there to override the default (True)
|
||||
config.hf_config.use_sep_token = False
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tokenization_kwargs():
|
||||
"""Common tokenization kwargs used across tests."""
|
||||
return {"add_special_tokens": True, "return_tensors": None}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_with_score_template():
|
||||
"""Mock model class that supports score template and tracks post_process calls."""
|
||||
|
||||
class MockModelWithScoreTemplate:
|
||||
supports_score_template = True
|
||||
post_process_called: list[TokensPrompt] = []
|
||||
|
||||
@staticmethod
|
||||
def get_score_template(p1: str, p2: str) -> str:
|
||||
return f"[QUERY]{p1}[SEP][DOC]{p2}"
|
||||
|
||||
@staticmethod
|
||||
def post_process_tokens(prompt: TokensPrompt) -> None:
|
||||
MockModelWithScoreTemplate.post_process_called.append(prompt)
|
||||
|
||||
return MockModelWithScoreTemplate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_no_score_template():
|
||||
"""Mock model class that does not support score template."""
|
||||
|
||||
class MockModelNoScoreTemplate:
|
||||
supports_score_template = False
|
||||
|
||||
return MockModelNoScoreTemplate
|
||||
|
||||
|
||||
class TestGetScorePrompt:
|
||||
"""Tests for the get_score_prompt function."""
|
||||
|
||||
def test_tokenization_kwargs_passed_through(
|
||||
self,
|
||||
llm_reranker_model_config,
|
||||
cross_encoder_tokenizer,
|
||||
):
|
||||
"""Test that tokenization kwargs are properly passed through."""
|
||||
data_1 = "Query text"
|
||||
data_2 = "Document text"
|
||||
|
||||
# Test with truncation - custom kwargs for this test
|
||||
custom_tokenization_kwargs = {
|
||||
"add_special_tokens": True,
|
||||
"return_tensors": None,
|
||||
"truncation": True,
|
||||
"max_length": 20,
|
||||
}
|
||||
|
||||
full_prompt, engine_prompt = get_score_prompt(
|
||||
llm_reranker_model_config,
|
||||
cross_encoder_tokenizer,
|
||||
custom_tokenization_kwargs,
|
||||
data_1,
|
||||
data_2,
|
||||
)
|
||||
|
||||
assert isinstance(full_prompt, str)
|
||||
assert "prompt_token_ids" in engine_prompt
|
||||
# With max_length=20 and truncation, should not exceed this
|
||||
assert len(engine_prompt["prompt_token_ids"]) <= 20
|
||||
# Since truncation was applied, token_ids should be a prefix of full encoding
|
||||
full_ids = cross_encoder_tokenizer(full_prompt, add_special_tokens=True)[
|
||||
"input_ids"
|
||||
]
|
||||
actual_ids = engine_prompt["prompt_token_ids"]
|
||||
assert full_ids[: len(actual_ids)] == actual_ids, (
|
||||
f"Token IDs are not a prefix of full encoding.\n"
|
||||
f"Full IDs: {full_ids}\n"
|
||||
f"Actual IDs: {actual_ids}"
|
||||
)
|
||||
|
||||
def test_model_supports_score_template(
|
||||
self,
|
||||
cross_encoder_model_config,
|
||||
cross_encoder_tokenizer,
|
||||
tokenization_kwargs,
|
||||
mock_model_with_score_template,
|
||||
):
|
||||
"""Test when model supports score template (no score_template arg)."""
|
||||
with patch(
|
||||
"vllm.model_executor.model_loader.get_model_cls",
|
||||
return_value=mock_model_with_score_template,
|
||||
):
|
||||
full_prompt, engine_prompt = get_score_prompt(
|
||||
cross_encoder_model_config,
|
||||
cross_encoder_tokenizer,
|
||||
tokenization_kwargs,
|
||||
"query text",
|
||||
"document text",
|
||||
)
|
||||
|
||||
assert full_prompt == "[QUERY]query text[SEP][DOC]document text"
|
||||
assert "prompt_token_ids" in engine_prompt
|
||||
assert len(engine_prompt["prompt_token_ids"]) > 0
|
||||
assert_prompt_tokenization_consistent(
|
||||
cross_encoder_tokenizer, full_prompt, engine_prompt
|
||||
)
|
||||
|
||||
def test_model_supports_score_template_but_custom_template_provided(
|
||||
self,
|
||||
cross_encoder_model_config,
|
||||
cross_encoder_tokenizer,
|
||||
tokenization_kwargs,
|
||||
mock_model_with_score_template,
|
||||
):
|
||||
"""Test when model supports score template but custom template is provided."""
|
||||
template = (
|
||||
'TEMPLATE_USED {{ messages[0]["content"] }} {{ messages[1]["content"] }}'
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"vllm.model_executor.model_loader.get_model_cls",
|
||||
return_value=mock_model_with_score_template,
|
||||
),
|
||||
):
|
||||
full_prompt, engine_prompt = get_score_prompt(
|
||||
cross_encoder_model_config,
|
||||
cross_encoder_tokenizer,
|
||||
tokenization_kwargs,
|
||||
"query",
|
||||
"doc",
|
||||
score_template=template, # Providing a template
|
||||
)
|
||||
|
||||
assert "prompt_token_ids" in engine_prompt
|
||||
assert full_prompt == "TEMPLATE_USED query doc"
|
||||
|
||||
assert_prompt_tokenization_consistent(
|
||||
cross_encoder_tokenizer, full_prompt, engine_prompt
|
||||
)
|
||||
|
||||
def test_not_using_default_template(
|
||||
self,
|
||||
llm_reranker_model_config,
|
||||
cross_encoder_tokenizer,
|
||||
tokenization_kwargs,
|
||||
mock_model_no_score_template,
|
||||
):
|
||||
# FIXME: For now, we only apply a template when one is explicitly provided.
|
||||
# We cannot rely on the tokenizer's chat template because many models
|
||||
# inherit junk templates from their base LLM, which breaks both the models
|
||||
# and the tests that use them.
|
||||
with (
|
||||
patch(
|
||||
"vllm.model_executor.model_loader.get_model_cls",
|
||||
return_value=mock_model_no_score_template,
|
||||
),
|
||||
patch(
|
||||
"vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
|
||||
return_value="test querytest doc",
|
||||
),
|
||||
):
|
||||
full_prompt, engine_prompt = get_score_prompt(
|
||||
llm_reranker_model_config,
|
||||
cross_encoder_tokenizer,
|
||||
tokenization_kwargs,
|
||||
"test query",
|
||||
"test doc",
|
||||
)
|
||||
|
||||
assert full_prompt == "test querytest doc"
|
||||
assert "prompt_token_ids" in engine_prompt
|
||||
assert_prompt_tokenization_consistent(
|
||||
cross_encoder_tokenizer, full_prompt, engine_prompt
|
||||
)
|
||||
|
||||
def test_fallback_with_sep_token(
|
||||
self,
|
||||
cross_encoder_model_config,
|
||||
cross_encoder_tokenizer,
|
||||
tokenization_kwargs,
|
||||
mock_model_no_score_template,
|
||||
):
|
||||
"""Test fallback path when ChatTemplateResolutionError
|
||||
and use_sep_token=True."""
|
||||
with (
|
||||
patch(
|
||||
"vllm.model_executor.model_loader.get_model_cls",
|
||||
return_value=mock_model_no_score_template,
|
||||
),
|
||||
patch(
|
||||
"vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
|
||||
side_effect=ChatTemplateResolutionError("No template"),
|
||||
),
|
||||
):
|
||||
full_prompt, engine_prompt = get_score_prompt(
|
||||
cross_encoder_model_config, # use_sep_token=True
|
||||
cross_encoder_tokenizer,
|
||||
tokenization_kwargs,
|
||||
"query",
|
||||
"document",
|
||||
)
|
||||
|
||||
assert "prompt_token_ids" in engine_prompt
|
||||
# Should have token_type_ids from text_pair encoding
|
||||
assert "token_type_ids" in engine_prompt
|
||||
assert "query" in full_prompt
|
||||
assert "document" in full_prompt
|
||||
assert full_prompt != "querydocument"
|
||||
assert (
|
||||
engine_prompt["prompt_token_ids"]
|
||||
== cross_encoder_tokenizer(
|
||||
"query", text_pair="document", add_special_tokens=True
|
||||
)["input_ids"]
|
||||
)
|
||||
|
||||
# FIXME(?): add_special_tokens=False is needed because in this case
|
||||
# full_prompt is obtained by decoding the tokenized prompt, which includes
|
||||
# special tokens and we would get duplicated special tokens otherwise.
|
||||
# This is inconsistent with other cases.
|
||||
assert_prompt_tokenization_consistent(
|
||||
cross_encoder_tokenizer,
|
||||
full_prompt,
|
||||
engine_prompt,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
|
||||
def test_fallback_without_sep_token(
|
||||
self,
|
||||
llm_reranker_model_config,
|
||||
cross_encoder_tokenizer,
|
||||
tokenization_kwargs,
|
||||
mock_model_no_score_template,
|
||||
):
|
||||
"""Test fallback path when ChatTemplateResolutionError
|
||||
and use_sep_token=False."""
|
||||
with (
|
||||
patch(
|
||||
"vllm.model_executor.model_loader.get_model_cls",
|
||||
return_value=mock_model_no_score_template,
|
||||
),
|
||||
patch(
|
||||
"vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
|
||||
side_effect=ChatTemplateResolutionError("No template"),
|
||||
),
|
||||
):
|
||||
full_prompt, engine_prompt = get_score_prompt(
|
||||
llm_reranker_model_config, # use_sep_token=False
|
||||
cross_encoder_tokenizer,
|
||||
tokenization_kwargs,
|
||||
"query",
|
||||
"document",
|
||||
)
|
||||
|
||||
assert full_prompt == "querydocument"
|
||||
assert "prompt_token_ids" in engine_prompt
|
||||
assert_prompt_tokenization_consistent(
|
||||
cross_encoder_tokenizer, full_prompt, engine_prompt
|
||||
)
|
||||
|
||||
def test_post_process_tokens_called(
|
||||
self,
|
||||
cross_encoder_model_config,
|
||||
cross_encoder_tokenizer,
|
||||
tokenization_kwargs,
|
||||
mock_model_with_score_template,
|
||||
):
|
||||
"""Test that post_process_tokens is called on the engine prompt."""
|
||||
# Reset the call tracker
|
||||
mock_model_with_score_template.post_process_called.clear()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"vllm.model_executor.model_loader.get_model_cls",
|
||||
return_value=mock_model_with_score_template,
|
||||
),
|
||||
patch(
|
||||
"vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
|
||||
side_effect=ChatTemplateResolutionError("No template"),
|
||||
),
|
||||
):
|
||||
full_prompt, engine_prompt = get_score_prompt(
|
||||
cross_encoder_model_config,
|
||||
cross_encoder_tokenizer,
|
||||
tokenization_kwargs,
|
||||
"query",
|
||||
"doc",
|
||||
)
|
||||
|
||||
# post_process_tokens should have been called once
|
||||
assert len(mock_model_with_score_template.post_process_called) == 1
|
||||
assert mock_model_with_score_template.post_process_called[0] is engine_prompt
|
||||
assert_prompt_tokenization_consistent(
|
||||
cross_encoder_tokenizer, full_prompt, engine_prompt
|
||||
)
|
||||
@@ -1,14 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
import pybase64 as base64
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from tests.conftest import HfRunner
|
||||
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
)
|
||||
from vllm.entrypoints.pooling.scoring.typing import ScoreMultiModalParam
|
||||
from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
|
||||
|
||||
|
||||
class ColBERTScoringHfRunner(torch.nn.Module):
|
||||
@@ -67,3 +76,32 @@ class EncoderScoringHfRunner(HfRunner):
|
||||
for pair in hf_embeddings
|
||||
]
|
||||
return torch.as_tensor(hf_outputs)
|
||||
|
||||
|
||||
def make_base64_image(
|
||||
width: int = 64, height: int = 64, color: tuple[int, int, int] = (255, 0, 0)
|
||||
) -> str:
|
||||
"""Create a small solid-color PNG image and return its base64 data URI."""
|
||||
img = Image.new("RGB", (width, height), color)
|
||||
buf = BytesIO()
|
||||
img.save(buf, format="PNG")
|
||||
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
return f"data:image/png;base64,{b64}"
|
||||
|
||||
|
||||
def make_image_mm_param(
|
||||
image_uri: str,
|
||||
text: str | None = None,
|
||||
) -> ScoreMultiModalParam:
|
||||
"""Build a ScoreMultiModalParam containing an image (and optional text)."""
|
||||
content: list = [
|
||||
ChatCompletionContentPartImageParam(
|
||||
type="image_url",
|
||||
image_url={"url": image_uri},
|
||||
),
|
||||
]
|
||||
if text is not None:
|
||||
content.append(
|
||||
ChatCompletionContentPartTextParam(type="text", text=text),
|
||||
)
|
||||
return ScoreMultiModalParam(content=content)
|
||||
|
||||
@@ -60,7 +60,7 @@ def test_token_ids_prompts(llm: LLM):
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_score_api(llm: LLM):
|
||||
err_msg = "Score API is only enabled for num_labels == 1."
|
||||
err_msg = "Scoring API is only enabled for num_labels == 1."
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
llm.score("ping", "pong", use_tqdm=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user