[Frontend][3/n] Improve pooling entrypoints | scoring. (#28631)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
@@ -10,9 +10,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionStreamResponse,
|
||||
ChatMessage,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
||||
|
||||
|
||||
async def accumulate_streaming_response(
|
||||
|
||||
@@ -105,7 +105,7 @@ def test_pooling_params(llm: LLM):
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_score_api(llm: LLM):
|
||||
err_msg = "Score API is only enabled for num_labels == 1."
|
||||
err_msg = "Scoring API is only enabled for num_labels == 1."
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
llm.score("ping", "pong", use_tqdm=False)
|
||||
|
||||
|
||||
@@ -390,7 +390,7 @@ async def test_use_activation(server: RemoteOpenAIServer, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_score(server: RemoteOpenAIServer, model_name: str):
|
||||
# score api is only enabled for num_labels == 1.
|
||||
# Scoring API is only enabled for num_labels == 1.
|
||||
response = requests.post(
|
||||
server.url_for("score"),
|
||||
json={
|
||||
@@ -405,7 +405,7 @@ async def test_score(server: RemoteOpenAIServer, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_rerank(server: RemoteOpenAIServer, model_name: str):
|
||||
# rerank api is only enabled for num_labels == 1.
|
||||
# Scoring API is only enabled for num_labels == 1.
|
||||
response = requests.post(
|
||||
server.url_for("rerank"),
|
||||
json={
|
||||
|
||||
@@ -7,7 +7,7 @@ import requests
|
||||
from tests.entrypoints.pooling.scoring.util import EncoderScoringHfRunner
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse
|
||||
from vllm.entrypoints.pooling.scoring.protocol import RerankResponse, ScoreResponse
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_NAME = "BAAI/bge-base-en-v1.5"
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch.nn.functional as F
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse
|
||||
from vllm.entrypoints.pooling.scoring.protocol import RerankResponse, ScoreResponse
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_NAME = "BAAI/bge-reranker-base"
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import VLLM_PATH, RemoteOpenAIServer
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse
|
||||
from vllm.entrypoints.pooling.scoring.protocol import RerankResponse, ScoreResponse
|
||||
from vllm.multimodal.utils import encode_image_url, fetch_image
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import weakref
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .util import make_base64_image, make_image_mm_param
|
||||
|
||||
MODEL_NAME = "vidore/colpali-v1.3-hf"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def llm():
    """Yield a module-scoped ``LLM`` built for ``MODEL_NAME``.

    Tests receive a ``weakref.proxy`` to the engine. Once the module's tests
    are done, the strong reference is dropped and
    ``cleanup_dist_env_and_memory()`` is called.
    """
    # ROCm: Use FLEX_ATTENTION backend as it's the only attention backend
    # that supports encoder-only models on ROCm. Other platforms pass None.
    attention_config = (
        {"backend": "FLEX_ATTENTION"} if current_platform.is_rocm() else None
    )

    engine = LLM(
        model=MODEL_NAME,
        max_num_batched_tokens=32768,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.75,
        enforce_eager=True,
        seed=0,
        attention_config=attention_config,
    )

    # pytest caches the fixture value, so only a weak proxy is handed out;
    # that way the engine can be garbage collected after teardown.
    yield weakref.proxy(engine)

    del engine

    cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
def test_query_text_vs_docs_image(llm):
    """Score a text query against image documents via the multimodal path.

    The documents are a solid red and a solid blue 64x64 image, in that
    order; the red image is expected to score higher for the query.
    """
    query = "Describe the red object"
    image_docs = [
        make_image_mm_param(make_base64_image(64, 64, color=rgb))
        for rgb in ((255, 0, 0), (0, 0, 255))
    ]

    scores = llm.score(query, image_docs)

    assert len(scores) == 2
    red_result, blue_result = scores
    assert red_result.outputs.score > blue_result.outputs.score
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
def test_query_text_vs_docs_mix(llm) -> None:
    """Score a text query against a mix of text and image documents.

    The first document is a matching sentence and the second is a solid red
    image; the sentence is expected to score higher.
    """
    query = "What is the capital of France?"
    documents: list = [
        "The capital of France is Paris.",
        make_image_mm_param(make_base64_image(64, 64, color=(255, 0, 0))),
    ]

    scores = llm.score(query, documents)

    assert len(scores) == 2
    text_result, image_result = scores
    assert text_result.outputs.score > image_result.outputs.score
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
def test_query_image_vs_docs_text(llm) -> None:
    """Score an image query against text documents.

    The query is a solid red image paired with the text ``"red color"``; the
    first document (about the red object) is expected to score higher than
    the second, unrelated one.
    """
    image_query = make_image_mm_param(
        make_base64_image(64, 64, color=(255, 0, 0)),
        text="red color",
    )
    documents = [
        "Describe the red object.",
        "The capital of France is Paris.",
    ]

    scores = llm.score(image_query, documents)

    assert len(scores) == 2
    related, unrelated = scores
    assert related.outputs.score > unrelated.outputs.score
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse
|
||||
from vllm.entrypoints.pooling.scoring.protocol import RerankResponse, ScoreResponse
|
||||
|
||||
from .util import ColBERTScoringHfRunner
|
||||
|
||||
|
||||
@@ -0,0 +1,193 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.entrypoints.pooling.scoring.util import (
|
||||
make_base64_image,
|
||||
make_image_mm_param,
|
||||
)
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.entrypoints.pooling.scoring.protocol import RerankResponse, ScoreResponse
|
||||
|
||||
MODEL_NAME = "vidore/colpali-v1.3-hf"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server():
    """Yield a module-scoped ``RemoteOpenAIServer`` for ``MODEL_NAME``.

    The server is started with no extra CLI arguments and is shut down by
    the context manager when the module's tests finish.
    """
    extra_cli_args: list[str] = []
    with RemoteOpenAIServer(MODEL_NAME, extra_cli_args) as running_server:
        yield running_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_score_api_query_text_vs_docs_image(server: RemoteOpenAIServer):
    """POST a text query with two image documents to the ``score`` endpoint.

    The documents are a red and a blue 64x64 image, in that order; the
    response must contain two entries with the red one scoring higher.
    """
    payload = {
        "model": MODEL_NAME,
        "queries": "Describe the red object",
        "documents": [
            make_image_mm_param(make_base64_image(64, 64, color=rgb))
            for rgb in ((255, 0, 0), (0, 0, 255))
        ],
    }

    http_response = requests.post(server.url_for("score"), json=payload)
    http_response.raise_for_status()
    parsed = ScoreResponse.model_validate(http_response.json())

    assert parsed.id is not None
    assert parsed.data is not None
    assert len(parsed.data) == 2
    red_entry, blue_entry = parsed.data
    assert red_entry.score > blue_entry.score
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_score_api_query_text_vs_docs_mix(server: RemoteOpenAIServer):
    """POST a text query with one text and one image document to ``score``.

    The matching sentence (first document) must score higher than the solid
    red image (second document).
    """
    mixed_documents: list = [
        "The capital of France is Paris.",
        make_image_mm_param(make_base64_image(64, 64, color=(255, 0, 0))),
    ]
    payload = {
        "model": MODEL_NAME,
        "queries": "What is the capital of France?",
        "documents": mixed_documents,
    }

    http_response = requests.post(server.url_for("score"), json=payload)
    http_response.raise_for_status()
    parsed = ScoreResponse.model_validate(http_response.json())

    assert parsed.id is not None
    assert parsed.data is not None
    assert len(parsed.data) == 2
    text_entry, image_entry = parsed.data
    assert text_entry.score > image_entry.score
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_score_api_query_image_vs_docs_text(server: RemoteOpenAIServer):
    """POST an image query with two text documents to the ``score`` endpoint.

    The query is a solid red image with the text ``"red color"``; the first
    document (about the red object) must score higher than the second.
    """
    image_query = make_image_mm_param(
        make_base64_image(64, 64, color=(255, 0, 0)),
        text="red color",
    )
    payload = {
        "model": MODEL_NAME,
        "queries": image_query,
        "documents": [
            "Describe the red object.",
            "The capital of France is Paris.",
        ],
    }

    http_response = requests.post(server.url_for("score"), json=payload)
    http_response.raise_for_status()
    parsed = ScoreResponse.model_validate(http_response.json())

    assert parsed.id is not None
    assert parsed.data is not None
    assert len(parsed.data) == 2
    related, unrelated = parsed.data
    assert related.score > unrelated.score
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rerank_api_query_text_vs_docs_image(server: RemoteOpenAIServer):
    """POST a text query with two image documents to the ``rerank`` endpoint.

    Results are looked up by their ``index`` field rather than by position;
    the red image (index 0) must have a higher relevance score than the
    blue image (index 1).
    """
    payload = {
        "model": MODEL_NAME,
        "query": "Describe the red object",
        "documents": [
            make_image_mm_param(make_base64_image(64, 64, color=rgb))
            for rgb in ((255, 0, 0), (0, 0, 255))
        ],
    }

    http_response = requests.post(server.url_for("rerank"), json=payload)

    http_response.raise_for_status()
    parsed = RerankResponse.model_validate(http_response.json())

    assert parsed.id is not None
    assert parsed.results is not None
    assert len(parsed.results) == 2

    red_hit = next(hit for hit in parsed.results if hit.index == 0)
    blue_hit = next(hit for hit in parsed.results if hit.index == 1)

    assert red_hit.relevance_score > blue_hit.relevance_score
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rerank_api_query_text_vs_docs_mix(server: RemoteOpenAIServer):
    """POST a text query with one text and one image document to ``rerank``.

    Results are looked up by their ``index`` field; the matching sentence
    (index 0) must have a higher relevance score than the red image
    (index 1).
    """
    mixed_documents: list = [
        "The capital of France is Paris.",
        make_image_mm_param(make_base64_image(64, 64, color=(255, 0, 0))),
    ]
    payload = {
        "model": MODEL_NAME,
        "query": "What is the capital of France?",
        "documents": mixed_documents,
    }

    http_response = requests.post(server.url_for("rerank"), json=payload)
    http_response.raise_for_status()
    parsed = RerankResponse.model_validate(http_response.json())

    assert parsed.id is not None
    assert parsed.results is not None
    assert len(parsed.results) == 2

    text_hit = next(hit for hit in parsed.results if hit.index == 0)
    image_hit = next(hit for hit in parsed.results if hit.index == 1)

    assert text_hit.relevance_score > image_hit.relevance_score
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rerank_api_query_image_vs_docs_text(server: RemoteOpenAIServer):
    """POST an image query with two text documents to the ``rerank`` endpoint.

    Results are looked up by their ``index`` field; the document about the
    red object (index 0) must have a higher relevance score than the
    unrelated one (index 1).
    """
    image_query = make_image_mm_param(
        make_base64_image(64, 64, color=(255, 0, 0)),
        text="red color",
    )
    payload = {
        "model": MODEL_NAME,
        "query": image_query,
        "documents": [
            "Describe the red object.",
            "The capital of France is Paris.",
        ],
    }

    http_response = requests.post(server.url_for("rerank"), json=payload)
    http_response.raise_for_status()
    parsed = RerankResponse.model_validate(http_response.json())

    assert parsed.id is not None
    assert parsed.results is not None
    assert len(parsed.results) == 2

    related_hit = next(hit for hit in parsed.results if hit.index == 0)
    unrelated_hit = next(hit for hit in parsed.results if hit.index == 1)

    assert related_hit.relevance_score > unrelated_hit.relevance_score
|
||||
@@ -1,353 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateResolutionError
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
get_score_prompt,
|
||||
)
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
|
||||
# A cross-encoder model for testing
|
||||
CROSS_ENCODER_MODEL_ID = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
|
||||
|
||||
def assert_prompt_tokenization_consistent(
    tokenizer, full_prompt, engine_prompt, add_special_tokens=True
):
    """Assert that ``engine_prompt`` carries the token IDs of ``full_prompt``.

    Args:
        tokenizer: Callable invoked as
            ``tokenizer(full_prompt, add_special_tokens=...)``; its result
            must be subscriptable with ``"input_ids"``.
        full_prompt: Prompt text to re-tokenize as the reference.
        engine_prompt: Mapping whose ``"prompt_token_ids"`` entry is checked.
        add_special_tokens: Forwarded unchanged to ``tokenizer``.

    Raises:
        AssertionError: If the two token ID sequences differ.
    """
    reference_encoding = tokenizer(
        full_prompt, add_special_tokens=add_special_tokens
    )
    reference_ids = reference_encoding["input_ids"]
    observed_ids = engine_prompt["prompt_token_ids"]
    assert observed_ids == reference_ids, (
        f"Token IDs don't match.\nExpected: {reference_ids}\nActual: {observed_ids}"
    )
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def cross_encoder_model_config():
    """Module-scoped ``ModelConfig`` for ``CROSS_ENCODER_MODEL_ID`` (pooling runner)."""
    model_config = ModelConfig(CROSS_ENCODER_MODEL_ID, runner="pooling")
    return model_config
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def cross_encoder_tokenizer(cross_encoder_model_config):
    """Module-scoped tokenizer for ``CROSS_ENCODER_MODEL_ID``.

    ``trust_remote_code`` is taken from the cross-encoder model config.
    """
    allow_remote_code = cross_encoder_model_config.trust_remote_code
    return get_tokenizer(
        CROSS_ENCODER_MODEL_ID,
        trust_remote_code=allow_remote_code,
    )
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def llm_reranker_model_config():
    """Model config for LLM-as-reranker style (no pad token).

    Built from the same checkpoint as the cross-encoder config, but with
    ``use_sep_token`` switched off on its ``hf_config``.
    """
    reranker_config = ModelConfig(CROSS_ENCODER_MODEL_ID, runner="pooling")
    # use_sep_token is a property that reads from hf_config,
    # so the override of the default (True) has to be written there.
    hf_config = reranker_config.hf_config
    hf_config.use_sep_token = False
    return reranker_config
|
||||
|
||||
|
||||
@pytest.fixture
def tokenization_kwargs():
    """Common tokenization kwargs used across tests (fresh dict per test)."""
    shared_kwargs = dict(add_special_tokens=True, return_tensors=None)
    return shared_kwargs
|
||||
|
||||
|
||||
@pytest.fixture
def mock_model_with_score_template():
    """Mock model class that supports score template and tracks post_process calls.

    Returns the class itself (not an instance). The class is defined inside
    the fixture, so each test that requests the fixture gets a new class with
    its own ``post_process_called`` list.
    """

    class MockModelWithScoreTemplate:
        # Flag advertising that this model provides its own score template.
        supports_score_template = True
        # Class-level log of every prompt handed to post_process_tokens.
        post_process_called: list[TokensPrompt] = []

        @staticmethod
        def get_score_template(p1: str, p2: str) -> str:
            # Fixed, easily recognisable layout so tests can assert on it.
            return f"[QUERY]{p1}[SEP][DOC]{p2}"

        @staticmethod
        def post_process_tokens(prompt: TokensPrompt) -> None:
            # Record the call; the prompt itself is left untouched.
            MockModelWithScoreTemplate.post_process_called.append(prompt)

    return MockModelWithScoreTemplate
|
||||
|
||||
|
||||
@pytest.fixture
def mock_model_no_score_template():
    """Mock model class that does not support score template.

    Returns the class itself (not an instance); it exposes only the
    ``supports_score_template`` flag.
    """

    class MockModelNoScoreTemplate:
        # Flag advertising that this model has no score template of its own.
        supports_score_template = False

    return MockModelNoScoreTemplate
|
||||
|
||||
|
||||
class TestGetScorePrompt:
    """Tests for the get_score_prompt function.

    Every test calls ``get_score_prompt(model_config, tokenizer,
    tokenization_kwargs, data_1, data_2)`` and checks the returned
    ``(full_prompt, engine_prompt)`` pair. Where needed, the model class
    lookup and ``safe_apply_chat_template`` are patched to steer which
    prompt-construction path is taken.
    """

    def test_tokenization_kwargs_passed_through(
        self,
        llm_reranker_model_config,
        cross_encoder_tokenizer,
    ):
        """Test that tokenization kwargs are properly passed through."""
        data_1 = "Query text"
        data_2 = "Document text"

        # Test with truncation - custom kwargs for this test
        custom_tokenization_kwargs = {
            "add_special_tokens": True,
            "return_tensors": None,
            "truncation": True,
            "max_length": 20,
        }

        full_prompt, engine_prompt = get_score_prompt(
            llm_reranker_model_config,
            cross_encoder_tokenizer,
            custom_tokenization_kwargs,
            data_1,
            data_2,
        )

        assert isinstance(full_prompt, str)
        assert "prompt_token_ids" in engine_prompt
        # With max_length=20 and truncation, should not exceed this
        assert len(engine_prompt["prompt_token_ids"]) <= 20
        # Since truncation was applied, token_ids should be a prefix of full encoding
        full_ids = cross_encoder_tokenizer(full_prompt, add_special_tokens=True)[
            "input_ids"
        ]
        actual_ids = engine_prompt["prompt_token_ids"]
        assert full_ids[: len(actual_ids)] == actual_ids, (
            f"Token IDs are not a prefix of full encoding.\n"
            f"Full IDs: {full_ids}\n"
            f"Actual IDs: {actual_ids}"
        )

    def test_model_supports_score_template(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_with_score_template,
    ):
        """Test when model supports score template (no score_template arg)."""
        # Make the model-class lookup return the mock that provides a template.
        with patch(
            "vllm.model_executor.model_loader.get_model_cls",
            return_value=mock_model_with_score_template,
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query text",
                "document text",
            )

        # The prompt must be exactly what the mock's get_score_template builds.
        assert full_prompt == "[QUERY]query text[SEP][DOC]document text"
        assert "prompt_token_ids" in engine_prompt
        assert len(engine_prompt["prompt_token_ids"]) > 0
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_model_supports_score_template_but_custom_template_provided(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_with_score_template,
    ):
        """Test when model supports score template but custom template is provided."""
        # Template text that renders the two message contents after a marker.
        template = (
            'TEMPLATE_USED {{ messages[0]["content"] }} {{ messages[1]["content"] }}'
        )
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_with_score_template,
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "doc",
                score_template=template,  # Providing a template
            )

        assert "prompt_token_ids" in engine_prompt
        # The explicit template wins over the mock model's own score template.
        assert full_prompt == "TEMPLATE_USED query doc"

        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_not_using_default_template(
        self,
        llm_reranker_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_no_score_template,
    ):
        """Test the no-template case with a model that has no score template.

        ``safe_apply_chat_template`` is patched to return
        ``"test querytest doc"``, and the test asserts that ``full_prompt``
        equals that same string.
        """
        # FIXME: For now, we only apply a template when one is explicitly provided.
        # We cannot rely on the tokenizer's chat template because many models
        # inherit junk templates from their base LLM, which breaks both the models
        # and the tests that use them.
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_no_score_template,
            ),
            patch(
                "vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
                return_value="test querytest doc",
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                llm_reranker_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "test query",
                "test doc",
            )

        assert full_prompt == "test querytest doc"
        assert "prompt_token_ids" in engine_prompt
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_fallback_with_sep_token(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_no_score_template,
    ):
        """Test fallback path when ChatTemplateResolutionError
        and use_sep_token=True."""
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_no_score_template,
            ),
            patch(
                "vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
                side_effect=ChatTemplateResolutionError("No template"),
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,  # use_sep_token=True
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "document",
            )

        assert "prompt_token_ids" in engine_prompt
        # Should have token_type_ids from text_pair encoding
        assert "token_type_ids" in engine_prompt
        assert "query" in full_prompt
        assert "document" in full_prompt
        # Plain concatenation would mean the separator-based path was not used.
        assert full_prompt != "querydocument"
        assert (
            engine_prompt["prompt_token_ids"]
            == cross_encoder_tokenizer(
                "query", text_pair="document", add_special_tokens=True
            )["input_ids"]
        )

        # FIXME(?): add_special_tokens=False is needed because in this case
        # full_prompt is obtained by decoding the tokenized prompt, which includes
        # special tokens and we would get duplicated special tokens otherwise.
        # This is inconsistent with other cases.
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer,
            full_prompt,
            engine_prompt,
            add_special_tokens=False,
        )

    def test_fallback_without_sep_token(
        self,
        llm_reranker_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_no_score_template,
    ):
        """Test fallback path when ChatTemplateResolutionError
        and use_sep_token=False."""
        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_no_score_template,
            ),
            patch(
                "vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
                side_effect=ChatTemplateResolutionError("No template"),
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                llm_reranker_model_config,  # use_sep_token=False
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "document",
            )

        # Without a separator token the two texts are simply concatenated.
        assert full_prompt == "querydocument"
        assert "prompt_token_ids" in engine_prompt
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )

    def test_post_process_tokens_called(
        self,
        cross_encoder_model_config,
        cross_encoder_tokenizer,
        tokenization_kwargs,
        mock_model_with_score_template,
    ):
        """Test that post_process_tokens is called on the engine prompt."""
        # Reset the call tracker
        mock_model_with_score_template.post_process_called.clear()

        with (
            patch(
                "vllm.model_executor.model_loader.get_model_cls",
                return_value=mock_model_with_score_template,
            ),
            patch(
                "vllm.entrypoints.pooling.score.utils.safe_apply_chat_template",
                side_effect=ChatTemplateResolutionError("No template"),
            ),
        ):
            full_prompt, engine_prompt = get_score_prompt(
                cross_encoder_model_config,
                cross_encoder_tokenizer,
                tokenization_kwargs,
                "query",
                "doc",
            )

        # post_process_tokens should have been called once
        assert len(mock_model_with_score_template.post_process_called) == 1
        # Identity check: the recorded prompt is the very object that was returned.
        assert mock_model_with_score_template.post_process_called[0] is engine_prompt
        assert_prompt_tokenization_consistent(
            cross_encoder_tokenizer, full_prompt, engine_prompt
        )
|
||||
@@ -1,14 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
import pybase64 as base64
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from tests.conftest import HfRunner
|
||||
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
)
|
||||
from vllm.entrypoints.pooling.scoring.typing import ScoreMultiModalParam
|
||||
from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
|
||||
|
||||
|
||||
class ColBERTScoringHfRunner(torch.nn.Module):
|
||||
@@ -67,3 +76,32 @@ class EncoderScoringHfRunner(HfRunner):
|
||||
for pair in hf_embeddings
|
||||
]
|
||||
return torch.as_tensor(hf_outputs)
|
||||
|
||||
|
||||
def make_base64_image(
    width: int = 64, height: int = 64, color: tuple[int, int, int] = (255, 0, 0)
) -> str:
    """Create a small solid-color PNG image and return its base64 data URI.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        color: RGB fill color.

    Returns:
        A ``data:image/png;base64,...`` URI holding the encoded PNG bytes.
    """
    png_buffer = BytesIO()
    Image.new("RGB", (width, height), color).save(png_buffer, format="PNG")
    encoded_png = base64.b64encode(png_buffer.getvalue()).decode()
    return f"data:image/png;base64,{encoded_png}"
|
||||
|
||||
|
||||
def make_image_mm_param(
    image_uri: str,
    text: str | None = None,
) -> ScoreMultiModalParam:
    """Build a ScoreMultiModalParam containing an image (and optional text).

    Args:
        image_uri: URL placed in the ``image_url`` content part.
        text: When not ``None``, appended as a ``text`` content part after
            the image.

    Returns:
        A ``ScoreMultiModalParam`` whose content is the image part, followed
        by the text part if one was given.
    """
    image_part = ChatCompletionContentPartImageParam(
        type="image_url",
        image_url={"url": image_uri},
    )
    trailing_parts: list = (
        []
        if text is None
        else [ChatCompletionContentPartTextParam(type="text", text=text)]
    )
    return ScoreMultiModalParam(content=[image_part, *trailing_parts])
|
||||
|
||||
@@ -60,7 +60,7 @@ def test_token_ids_prompts(llm: LLM):
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_score_api(llm: LLM):
|
||||
err_msg = "Score API is only enabled for num_labels == 1."
|
||||
err_msg = "Scoring API is only enabled for num_labels == 1."
|
||||
with pytest.raises(ValueError, match=err_msg):
|
||||
llm.score("ping", "pong", use_tqdm=False)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ generic ColBERT support works with different encoder architectures.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
|
||||
from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Model definitions: (model_name, colbert_dim, extra vllm_runner kwargs)
|
||||
|
||||
@@ -10,7 +10,7 @@ embeddings for visual document retrieval.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
|
||||
from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
|
||||
|
||||
MODEL_NAME = "ModernVBERT/colmodernvbert-merged"
|
||||
COLBERT_DIM = 128
|
||||
|
||||
@@ -18,7 +18,7 @@ from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
|
||||
from vllm.entrypoints.pooling.scoring.typing import ScoreMultiModalParam
|
||||
|
||||
from ....conftest import VllmRunner
|
||||
|
||||
@@ -114,7 +114,7 @@ def _run_late_interaction_test(
|
||||
dtype: str,
|
||||
) -> None:
|
||||
"""Verify MaxSim scoring matches manual computation."""
|
||||
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
|
||||
from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
|
||||
@@ -18,7 +18,7 @@ from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
|
||||
from vllm.entrypoints.pooling.scoring.typing import ScoreMultiModalParam
|
||||
|
||||
from ....conftest import VllmRunner
|
||||
|
||||
@@ -125,7 +125,7 @@ def _run_late_interaction_test(
|
||||
dtype: str,
|
||||
) -> None:
|
||||
"""Verify MaxSim scoring matches manual computation."""
|
||||
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
|
||||
from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
|
||||
@@ -73,7 +73,7 @@ def _run_late_interaction_test(
|
||||
dtype: str,
|
||||
) -> None:
|
||||
"""Verify MaxSim scoring matches manual computation."""
|
||||
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
|
||||
from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
|
||||
@@ -11,7 +11,7 @@ from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
|
||||
from vllm.entrypoints.pooling.scoring.typing import ScoreMultiModalParam
|
||||
|
||||
from ....conftest import HfRunner, VllmRunner
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
|
||||
from vllm.entrypoints.pooling.scoring.typing import ScoreMultiModalParam
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
|
||||
|
||||
@@ -46,22 +46,16 @@ from vllm.entrypoints.chat_utils import (
|
||||
load_chat_template,
|
||||
)
|
||||
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreData,
|
||||
ScoreMultiModalParam,
|
||||
_cosine_similarity,
|
||||
compress_token_type_ids,
|
||||
compute_maxsim_score,
|
||||
get_score_prompt,
|
||||
score_data_to_prompts,
|
||||
validate_score_input,
|
||||
from vllm.entrypoints.pooling.scoring.io_processor import (
|
||||
ScoringIOProcessor,
|
||||
)
|
||||
from vllm.entrypoints.pooling.scoring.typing import ScoreInput
|
||||
from vllm.entrypoints.pooling.typing import OfflineInputsContext, OfflineOutputsContext
|
||||
from vllm.entrypoints.utils import log_non_default_args
|
||||
from vllm.inputs import (
|
||||
DataPrompt,
|
||||
EngineInput,
|
||||
PromptType,
|
||||
SingletonPrompt,
|
||||
TextPrompt,
|
||||
TokensPrompt,
|
||||
)
|
||||
@@ -1161,7 +1155,9 @@ class LLM:
|
||||
if pooling_task in self.pooling_io_processors:
|
||||
io_processor = self.pooling_io_processors[pooling_task]
|
||||
processor_inputs = io_processor.pre_process_offline(
|
||||
prompts_seq, tokenization_kwargs
|
||||
ctx=OfflineInputsContext(
|
||||
prompts=prompts_seq, tokenization_kwargs=tokenization_kwargs
|
||||
)
|
||||
)
|
||||
seq_lora_requests = self._lora_request_to_seq(
|
||||
lora_request, len(prompts_seq)
|
||||
@@ -1178,7 +1174,9 @@ class LLM:
|
||||
outputs = self._run_engine(
|
||||
use_tqdm=use_tqdm, output_type=PoolingRequestOutput
|
||||
)
|
||||
outputs = io_processor.post_process_offline(outputs)
|
||||
outputs = io_processor.post_process_offline(
|
||||
ctx=OfflineOutputsContext(outputs=outputs)
|
||||
)
|
||||
else:
|
||||
outputs = self._run_completion(
|
||||
prompts=prompts_seq,
|
||||
@@ -1378,188 +1376,10 @@ class LLM:
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
def _embedding_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm],
|
||||
pooling_params: PoolingParams | None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
) -> list[ScoringRequestOutput]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
input_texts: list[str] = []
|
||||
for text in data_1 + data_2:
|
||||
if not isinstance(text, str):
|
||||
raise NotImplementedError(
|
||||
"Embedding scores currently do not support multimodal input."
|
||||
)
|
||||
input_texts.append(text)
|
||||
|
||||
encoded_output = self.encode(
|
||||
input_texts,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
pooling_params=pooling_params,
|
||||
pooling_task="embed",
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
encoded_output_1 = encoded_output[0 : len(data_1)]
|
||||
encoded_output_2 = encoded_output[len(data_1) :]
|
||||
|
||||
if len(encoded_output_1) == 1:
|
||||
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
|
||||
|
||||
scores = _cosine_similarity(
|
||||
tokenizer=tokenizer,
|
||||
embed_1=encoded_output_1,
|
||||
embed_2=encoded_output_2,
|
||||
)
|
||||
|
||||
return [ScoringRequestOutput.from_base(item) for item in scores]
|
||||
|
||||
def _late_interaction_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm],
|
||||
pooling_params: PoolingParams | None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
) -> list[ScoringRequestOutput]:
|
||||
"""
|
||||
Late interaction scoring (ColBERT MaxSim).
|
||||
|
||||
Encodes queries and documents into per-token embeddings, then computes
|
||||
MaxSim: sum over query tokens of max similarity to any document token.
|
||||
"""
|
||||
from vllm.outputs import PoolingOutput
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
# Convert ScoreData to PromptType (handles both text and multimodal)
|
||||
model_config = self.model_config
|
||||
prompts_1 = score_data_to_prompts(data_1, "query", model_config)
|
||||
prompts_2 = score_data_to_prompts(data_2, "document", model_config)
|
||||
|
||||
encoded_output: list[PoolingRequestOutput] = self.encode(
|
||||
prompts_1 + prompts_2,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
pooling_params=pooling_params,
|
||||
pooling_task="token_embed",
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
encoded_output_1: list[PoolingRequestOutput] = encoded_output[: len(prompts_1)]
|
||||
encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(prompts_1) :]
|
||||
|
||||
if len(encoded_output_1) == 1:
|
||||
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
|
||||
|
||||
# Compute MaxSim scores
|
||||
scores: list[PoolingRequestOutput] = []
|
||||
padding: list[int] = []
|
||||
if (pad_token_id := tokenizer.pad_token_id) is not None:
|
||||
padding = [pad_token_id]
|
||||
|
||||
for emb_1, emb_2 in zip(encoded_output_1, encoded_output_2):
|
||||
# emb_1.outputs.data: [query_len, dim]
|
||||
# emb_2.outputs.data: [doc_len, dim]
|
||||
q_emb = emb_1.outputs.data
|
||||
d_emb = emb_2.outputs.data
|
||||
|
||||
maxsim_score = compute_maxsim_score(q_emb, d_emb)
|
||||
|
||||
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
|
||||
|
||||
scores.append(
|
||||
PoolingRequestOutput(
|
||||
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
|
||||
outputs=PoolingOutput(data=maxsim_score),
|
||||
prompt_token_ids=tokens,
|
||||
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
|
||||
finished=True,
|
||||
)
|
||||
)
|
||||
|
||||
return [ScoringRequestOutput.from_base(item) for item in scores]
|
||||
|
||||
def _cross_encoding_score(
|
||||
self,
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm],
|
||||
pooling_params: PoolingParams | None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
score_template: str | None,
|
||||
) -> list[ScoringRequestOutput]:
|
||||
model_config = self.model_config
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
raise ValueError("Score API is not supported for Mistral tokenizer")
|
||||
|
||||
if len(data_1) == 1:
|
||||
data_1 = data_1 * len(data_2)
|
||||
|
||||
if pooling_params is None:
|
||||
pooling_params = PoolingParams(task="classify")
|
||||
elif pooling_params.task is None:
|
||||
pooling_params.task = "classify"
|
||||
|
||||
pooling_params_list = list[PoolingParams]()
|
||||
|
||||
prompts = list[PromptType]()
|
||||
|
||||
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
|
||||
|
||||
for q, d in input_pairs:
|
||||
_, engine_prompt = get_score_prompt(
|
||||
model_config=model_config,
|
||||
data_1=q,
|
||||
data_2=d,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
score_template=score_template,
|
||||
)
|
||||
|
||||
if token_type_ids := engine_prompt.pop("token_type_ids", None):
|
||||
params = pooling_params.clone()
|
||||
compressed = compress_token_type_ids(token_type_ids)
|
||||
params.extra_kwargs = {"compressed_token_type_ids": compressed}
|
||||
pooling_params_list.append(params)
|
||||
else:
|
||||
pooling_params_list.append(pooling_params)
|
||||
|
||||
prompts.append(engine_prompt)
|
||||
|
||||
outputs = self._run_completion(
|
||||
prompts=prompts,
|
||||
params=pooling_params_list,
|
||||
output_type=PoolingRequestOutput,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
return [ScoringRequestOutput.from_base(item) for item in outputs]
|
||||
|
||||
def score(
|
||||
self,
|
||||
data_1: SingletonPrompt
|
||||
| Sequence[SingletonPrompt]
|
||||
| ScoreMultiModalParam
|
||||
| list[ScoreMultiModalParam],
|
||||
data_2: SingletonPrompt
|
||||
| Sequence[SingletonPrompt]
|
||||
| ScoreMultiModalParam
|
||||
| list[ScoreMultiModalParam],
|
||||
data_1: ScoreInput | list[ScoreInput],
|
||||
data_2: ScoreInput | list[ScoreInput],
|
||||
/,
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
@@ -1606,83 +1426,71 @@ class LLM:
|
||||
A list of `ScoringRequestOutput` objects containing the
|
||||
generated scores in the same order as the input prompts.
|
||||
"""
|
||||
model_config = self.model_config
|
||||
|
||||
runner_type = model_config.runner_type
|
||||
if runner_type != "pooling":
|
||||
if self.runner_type != "pooling":
|
||||
raise ValueError(
|
||||
"LLM.score() is only supported for pooling models. "
|
||||
"Try passing `--runner pooling` to use the model as a "
|
||||
"pooling model."
|
||||
)
|
||||
|
||||
supported_tasks = self.supported_tasks
|
||||
score_type = self.model_config.score_type
|
||||
is_late_interaction = score_type == "late-interaction"
|
||||
is_cross_encoder = score_type == "cross-encoder"
|
||||
|
||||
# Late interaction models (e.g., ColBERT) use token_embed for scoring
|
||||
if not is_late_interaction and all(
|
||||
t not in supported_tasks for t in ("embed", "classify")
|
||||
if (
|
||||
score_type == "cross-encoder"
|
||||
and getattr(self.model_config.hf_config, "num_labels", 0) != 1
|
||||
):
|
||||
raise ValueError(
|
||||
"Score API is not supported by this model. "
|
||||
"Try converting the model using "
|
||||
"`--convert embed` or `--convert classify`."
|
||||
)
|
||||
raise ValueError("Scoring API is only enabled for num_labels == 1.")
|
||||
|
||||
if is_cross_encoder and getattr(model_config.hf_config, "num_labels", 0) != 1:
|
||||
raise ValueError("Score API is only enabled for num_labels == 1.")
|
||||
if score_type is None or score_type not in self.pooling_io_processors:
|
||||
raise ValueError("This model does not support the Scoring API.")
|
||||
|
||||
if not is_cross_encoder and chat_template is not None:
|
||||
raise ValueError(
|
||||
"chat_template is only supported for cross-encoder models."
|
||||
)
|
||||
io_processor = self.pooling_io_processors[score_type]
|
||||
assert isinstance(io_processor, ScoringIOProcessor)
|
||||
|
||||
is_multimodal_model = model_config.is_multimodal_model
|
||||
architecture = model_config.architecture
|
||||
pooling_task = io_processor.pooling_task
|
||||
scoring_data = io_processor.valid_inputs(data_1, data_2)
|
||||
offset = len(scoring_data.data_1)
|
||||
|
||||
score_data_1, score_data_2 = validate_score_input(
|
||||
data_1, # type: ignore[arg-type]
|
||||
data_2, # type: ignore[arg-type]
|
||||
is_multimodal_model=is_multimodal_model,
|
||||
architecture=architecture,
|
||||
ctx = OfflineInputsContext(
|
||||
prompts=scoring_data,
|
||||
pooling_params=pooling_params,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
chat_template=chat_template,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
renderer = self.renderer
|
||||
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
|
||||
**(tokenization_kwargs or {})
|
||||
)
|
||||
encode_kwargs = tok_params.get_encode_kwargs()
|
||||
processor_inputs = io_processor.pre_process_offline(ctx)
|
||||
|
||||
if is_cross_encoder:
|
||||
return self._cross_encoding_score(
|
||||
score_data_1,
|
||||
score_data_2,
|
||||
use_tqdm=use_tqdm,
|
||||
pooling_params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=encode_kwargs,
|
||||
score_template=chat_template,
|
||||
)
|
||||
elif is_late_interaction:
|
||||
return self._late_interaction_score(
|
||||
score_data_1,
|
||||
score_data_2,
|
||||
use_tqdm=use_tqdm,
|
||||
pooling_params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=encode_kwargs,
|
||||
)
|
||||
else:
|
||||
return self._embedding_score(
|
||||
score_data_1,
|
||||
score_data_2,
|
||||
use_tqdm=use_tqdm,
|
||||
pooling_params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=encode_kwargs,
|
||||
)
|
||||
seq_lora_requests = self._lora_request_to_seq(
|
||||
lora_request, len(processor_inputs)
|
||||
)
|
||||
|
||||
if ctx.pooling_params is None:
|
||||
ctx.pooling_params = PoolingParams()
|
||||
params_seq = self._params_to_seq(ctx.pooling_params, len(processor_inputs))
|
||||
|
||||
for param in params_seq:
|
||||
if param.task is None:
|
||||
param.task = pooling_task
|
||||
elif param.task != pooling_task:
|
||||
msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
|
||||
raise ValueError(msg)
|
||||
|
||||
seq_priority = self._priority_to_seq(None, len(processor_inputs))
|
||||
|
||||
self._render_and_add_requests(
|
||||
prompts=processor_inputs,
|
||||
params=params_seq,
|
||||
lora_requests=seq_lora_requests,
|
||||
priorities=seq_priority,
|
||||
)
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm, output_type=PoolingRequestOutput)
|
||||
outputs = io_processor.post_process_offline(
|
||||
ctx=OfflineOutputsContext(outputs=outputs, offset=offset),
|
||||
)
|
||||
|
||||
return [ScoringRequestOutput.from_base(item) for item in outputs]
|
||||
|
||||
def start_profile(self, profile_prefix: str | None = None) -> None:
|
||||
"""Start profiling with optional custom trace prefix.
|
||||
|
||||
@@ -11,9 +11,7 @@ from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
from openai.types.responses import (
|
||||
ToolChoiceFunction,
|
||||
)
|
||||
from openai.types.responses import ToolChoiceFunction
|
||||
from pydantic import ConfigDict, TypeAdapter, ValidationError
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
@@ -21,9 +19,7 @@ import vllm.envs as envs
|
||||
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatTemplateContentFormatOption,
|
||||
)
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
BatchChatCompletionRequest,
|
||||
@@ -42,9 +38,7 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
GenerationError,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
@@ -56,14 +50,6 @@ from vllm.entrypoints.pooling.pooling.protocol import (
|
||||
PoolingCompletionRequest,
|
||||
PoolingResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.protocol import (
|
||||
RerankRequest,
|
||||
ScoreDataRequest,
|
||||
ScoreQueriesDocumentsRequest,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
ScoreTextRequest,
|
||||
)
|
||||
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
|
||||
from vllm.entrypoints.serve.tokenize.protocol import (
|
||||
DetokenizeRequest,
|
||||
@@ -72,8 +58,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
|
||||
TokenizeResponse,
|
||||
)
|
||||
from vllm.entrypoints.utils import create_error_response
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs import EngineInput, PromptType, TokensPrompt
|
||||
from vllm.inputs import EngineInput, PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob, PromptLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -119,8 +104,6 @@ CompletionLikeRequest: TypeAlias = (
|
||||
CompletionRequest
|
||||
| TokenizeCompletionRequest
|
||||
| DetokenizeRequest
|
||||
| RerankRequest
|
||||
| ScoreRequest
|
||||
| PoolingCompletionRequest
|
||||
)
|
||||
|
||||
@@ -148,7 +131,6 @@ AnyResponse: TypeAlias = (
|
||||
| TranscriptionResponse
|
||||
| TokenizeResponse
|
||||
| PoolingResponse
|
||||
| ScoreResponse
|
||||
| GenerateResponse
|
||||
)
|
||||
|
||||
@@ -692,88 +674,6 @@ class OpenAIServing:
|
||||
message_types.add(content_dict["type"].split("_")[0])
|
||||
return message_types
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
request: object,
|
||||
input_ids: list[int],
|
||||
input_text: str,
|
||||
) -> TokensPrompt:
|
||||
token_num = len(input_ids)
|
||||
max_model_len = self.model_config.max_model_len
|
||||
|
||||
# Note: ScoreRequest doesn't have max_tokens
|
||||
if isinstance(
|
||||
request,
|
||||
(
|
||||
ScoreDataRequest,
|
||||
ScoreTextRequest,
|
||||
ScoreQueriesDocumentsRequest,
|
||||
RerankRequest,
|
||||
),
|
||||
):
|
||||
# Note: input length can be up to the entire model context length
|
||||
# since these requests don't generate tokens.
|
||||
if token_num > max_model_len:
|
||||
operations: dict[type[AnyRequest], str] = {
|
||||
ScoreDataRequest: "score",
|
||||
ScoreTextRequest: "score",
|
||||
ScoreQueriesDocumentsRequest: "score",
|
||||
}
|
||||
operation = operations.get(type(request), "embedding generation")
|
||||
raise VLLMValidationError(
|
||||
f"This model's maximum context length is "
|
||||
f"{max_model_len} tokens. However, you requested "
|
||||
f"{token_num} tokens in the input for {operation}. "
|
||||
f"Please reduce the length of the input prompt.",
|
||||
parameter="input_tokens",
|
||||
value=token_num,
|
||||
)
|
||||
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
|
||||
|
||||
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
|
||||
# and does not require model context length validation
|
||||
if isinstance(
|
||||
request,
|
||||
(TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
|
||||
):
|
||||
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
|
||||
|
||||
# chat completion endpoint supports max_completion_tokens
|
||||
if isinstance(request, ChatCompletionRequest):
|
||||
# TODO(#9845): remove max_tokens when field dropped from OpenAI API
|
||||
max_tokens = request.max_completion_tokens or request.max_tokens
|
||||
else:
|
||||
max_tokens = getattr(request, "max_tokens", None)
|
||||
|
||||
# Note: input length can be up to model context length - 1 for
|
||||
# completion-like requests.
|
||||
if token_num >= max_model_len:
|
||||
raise VLLMValidationError(
|
||||
f"This model's maximum context length is "
|
||||
f"{max_model_len} tokens. However, your request has "
|
||||
f"{token_num} input tokens. Please reduce the length of "
|
||||
"the input messages.",
|
||||
parameter="input_tokens",
|
||||
value=token_num,
|
||||
)
|
||||
|
||||
if max_tokens is not None and token_num + max_tokens > max_model_len:
|
||||
raise VLLMValidationError(
|
||||
f"This model's maximum context length is "
|
||||
f"{max_model_len} tokens. However, you requested "
|
||||
f"{max_tokens} output tokens and your prompt contains "
|
||||
f"{token_num} input tokens, for a total of "
|
||||
f"{token_num + max_tokens} tokens "
|
||||
f"({token_num} + {max_tokens} = "
|
||||
f"{token_num + max_tokens} > {max_model_len}). "
|
||||
f"Please reduce the length of the input prompt or the "
|
||||
f"number of requested output tokens.",
|
||||
parameter="max_tokens",
|
||||
value=max_tokens,
|
||||
)
|
||||
|
||||
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
|
||||
|
||||
def _validate_chat_template(
|
||||
self,
|
||||
request_chat_template: str | None,
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import sys
|
||||
import tempfile
|
||||
from argparse import Namespace
|
||||
@@ -13,12 +15,14 @@ from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import pybase64 as base64
|
||||
import pydantic
|
||||
import torch
|
||||
from fastapi import UploadFile
|
||||
from prometheus_client import start_http_server
|
||||
from pydantic import Field, TypeAdapter, field_validator, model_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
from starlette.datastructures import State
|
||||
from starlette.responses import JSONResponse
|
||||
from tqdm import tqdm
|
||||
from urllib3.util import parse_url
|
||||
|
||||
@@ -49,7 +53,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.protocol import (
|
||||
from vllm.entrypoints.pooling.scoring.protocol import (
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
ScoreRequest,
|
||||
@@ -180,6 +184,18 @@ class BatchRequestInput(OpenAIBaseModel):
|
||||
return TypeAdapter(BatchRequestInputBody).validate_python(value)
|
||||
|
||||
|
||||
AllResponse: TypeAlias = (
|
||||
ChatCompletionResponse
|
||||
| EmbeddingResponse
|
||||
| ScoreResponse
|
||||
| RerankResponse
|
||||
| TranscriptionResponse
|
||||
| TranscriptionResponseVerbose
|
||||
| TranslationResponse
|
||||
| TranslationResponseVerbose
|
||||
)
|
||||
|
||||
|
||||
class BatchResponseData(OpenAIBaseModel):
|
||||
# HTTP status code of the response.
|
||||
status_code: int = 200
|
||||
@@ -188,17 +204,7 @@ class BatchResponseData(OpenAIBaseModel):
|
||||
request_id: str
|
||||
|
||||
# The body of the response.
|
||||
body: (
|
||||
ChatCompletionResponse
|
||||
| EmbeddingResponse
|
||||
| ScoreResponse
|
||||
| RerankResponse
|
||||
| TranscriptionResponse
|
||||
| TranscriptionResponseVerbose
|
||||
| TranslationResponse
|
||||
| TranslationResponseVerbose
|
||||
| None
|
||||
) = None
|
||||
body: AllResponse | None = None
|
||||
|
||||
|
||||
class BatchRequestOutput(OpenAIBaseModel):
|
||||
@@ -536,19 +542,13 @@ async def run_request(
|
||||
except Exception as e:
|
||||
response = create_error_response(e)
|
||||
|
||||
if isinstance(
|
||||
response,
|
||||
(
|
||||
ChatCompletionResponse,
|
||||
EmbeddingResponse,
|
||||
ScoreResponse,
|
||||
RerankResponse,
|
||||
TranscriptionResponse,
|
||||
TranscriptionResponseVerbose,
|
||||
TranslationResponse,
|
||||
TranslationResponseVerbose,
|
||||
),
|
||||
):
|
||||
if isinstance(response, JSONResponse):
|
||||
with contextlib.suppress(pydantic.ValidationError):
|
||||
response = TypeAdapter(AllResponse | ErrorResponse).validate_python(
|
||||
json.loads(response.body)
|
||||
)
|
||||
|
||||
if isinstance(response, AllResponse):
|
||||
batch_output = BatchRequestOutput(
|
||||
id=f"vllm-{random_uuid()}",
|
||||
custom_id=request.custom_id,
|
||||
@@ -745,14 +745,14 @@ async def build_endpoint_registry(
|
||||
"score": {
|
||||
"url_matcher": lambda url: url.endswith("/score"),
|
||||
"handler_getter": lambda: (
|
||||
serving_scores.create_score if serving_scores is not None else None
|
||||
serving_scores if serving_scores is not None else None
|
||||
),
|
||||
"wrapper_fn": None,
|
||||
},
|
||||
"rerank": {
|
||||
"url_matcher": lambda url: url.endswith("/rerank"),
|
||||
"handler_getter": lambda: (
|
||||
serving_scores.do_rerank if serving_scores is not None else None
|
||||
serving_scores if serving_scores is not None else None
|
||||
),
|
||||
"wrapper_fn": None,
|
||||
},
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
|
||||
from fastapi import FastAPI
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.pooling.utils import enable_scoring_api
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -23,23 +24,6 @@ else:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def enable_scoring_api(
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
model_config: ModelConfig | None = None,
|
||||
) -> bool:
|
||||
if any(t in supported_tasks for t in ("embed", "token_embed")):
|
||||
return True
|
||||
|
||||
if model_config is not None and "classify" in supported_tasks:
|
||||
num_labels = getattr(model_config.hf_config, "num_labels", 0)
|
||||
if num_labels != 1:
|
||||
logger.debug_once("Score API is only enabled for num_labels == 1.")
|
||||
return False
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def register_pooling_api_routers(
|
||||
app: FastAPI,
|
||||
supported_tasks: tuple["SupportedTask", ...],
|
||||
@@ -68,7 +52,7 @@ def register_pooling_api_routers(
|
||||
app.include_router(embed_router)
|
||||
|
||||
if enable_scoring_api(supported_tasks, model_config):
|
||||
from vllm.entrypoints.pooling.score.api_router import router as score_router
|
||||
from vllm.entrypoints.pooling.scoring.api_router import router as score_router
|
||||
|
||||
app.include_router(score_router)
|
||||
|
||||
@@ -84,7 +68,7 @@ def init_pooling_state(
|
||||
from vllm.entrypoints.pooling.classify.serving import ServingClassification
|
||||
from vllm.entrypoints.pooling.embed.serving import ServingEmbedding
|
||||
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
|
||||
from vllm.entrypoints.pooling.score.serving import ServingScores
|
||||
from vllm.entrypoints.pooling.scoring.serving import ServingScores
|
||||
from vllm.tasks import POOLING_TASKS
|
||||
|
||||
model_config = engine_client.model_config
|
||||
@@ -136,8 +120,9 @@ def init_pooling_state(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
score_template=resolved_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
)
|
||||
if enable_scoring_api(supported_tasks, model_config)
|
||||
else None
|
||||
|
||||
@@ -13,13 +13,16 @@ from vllm.entrypoints.chat_utils import (
|
||||
ConversationMessage,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest
|
||||
from vllm.entrypoints.pooling.scoring.typing import ScoringData
|
||||
from vllm.entrypoints.pooling.typing import (
|
||||
OfflineInputsContext,
|
||||
OfflineOutputsContext,
|
||||
PoolingChatLikeRequest,
|
||||
PoolingCompletionLikeRequest,
|
||||
PoolingServeContext,
|
||||
)
|
||||
from vllm.inputs import EngineInput, SingletonPrompt
|
||||
from vllm.renderers import BaseRenderer, merge_kwargs
|
||||
from vllm.renderers import BaseRenderer, TokenizeParams, merge_kwargs
|
||||
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
|
||||
from vllm.tool_parsers import ToolParser
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
@@ -96,29 +99,29 @@ class PoolingIOProcessor:
|
||||
#######################################
|
||||
# offline APIs
|
||||
|
||||
def pre_process_offline(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> Sequence[EngineInput]:
|
||||
def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
|
||||
assert not isinstance(ctx.prompts, ScoringData)
|
||||
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
|
||||
**(ctx.tokenization_kwargs or {})
|
||||
)
|
||||
return self._preprocess_completion_offline(
|
||||
prompts=prompts, tokenization_kwargs=tokenization_kwargs
|
||||
prompts=ctx.prompts, tok_params=tok_params
|
||||
)
|
||||
|
||||
async def pre_process_offline_async(self, *args, **kwargs):
|
||||
return self.pre_process_offline(*args, **kwargs)
|
||||
async def pre_process_offline_async(self, ctx: OfflineInputsContext):
|
||||
return self.pre_process_offline(ctx)
|
||||
|
||||
def post_process_offline(
|
||||
self,
|
||||
outputs: list[PoolingRequestOutput],
|
||||
ctx: OfflineOutputsContext,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
return outputs
|
||||
return ctx.outputs
|
||||
|
||||
async def post_process_offline_async(
|
||||
self,
|
||||
outputs: list[PoolingRequestOutput],
|
||||
ctx: OfflineOutputsContext,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
return self.post_process_offline(outputs)
|
||||
return self.post_process_offline(ctx)
|
||||
|
||||
#######################################
|
||||
# helpers
|
||||
@@ -204,28 +207,21 @@ class PoolingIOProcessor:
|
||||
def _preprocess_completion_offline(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
tok_params: TokenizeParams,
|
||||
prompt_extras: dict[str, Any] | None = None,
|
||||
) -> Sequence[EngineInput]:
|
||||
renderer = self.renderer
|
||||
model_config = self.model_config
|
||||
|
||||
prompts = prompt_to_seq(prompts)
|
||||
|
||||
parsed_prompts = [
|
||||
(
|
||||
prompt
|
||||
if isinstance(prompt, bytes)
|
||||
else parse_model_prompt(model_config, prompt)
|
||||
else parse_model_prompt(self.model_config, prompt)
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
|
||||
**(tokenization_kwargs or {})
|
||||
)
|
||||
|
||||
return renderer.render_cmpl(
|
||||
parsed_prompts,
|
||||
tok_params,
|
||||
return self.renderer.render_cmpl(
|
||||
parsed_prompts, tok_params, prompt_extras=prompt_extras
|
||||
)
|
||||
|
||||
def _validate_chat_template(
|
||||
|
||||
@@ -117,8 +117,16 @@ class PoolingServing:
|
||||
else await self._get_trace_headers(ctx.raw_request.headers)
|
||||
)
|
||||
|
||||
pooling_params = self.io_processor.create_pooling_params(ctx.request)
|
||||
pooling_params.verify(self.model_config)
|
||||
if ctx.pooling_params is None:
|
||||
pooling_params = self.io_processor.create_pooling_params(ctx.request)
|
||||
else:
|
||||
pooling_params = ctx.pooling_params
|
||||
|
||||
if isinstance(pooling_params, list):
|
||||
for params in pooling_params:
|
||||
params.verify(self.model_config)
|
||||
else:
|
||||
pooling_params.verify(self.model_config)
|
||||
|
||||
for i, engine_input in enumerate(ctx.engine_inputs):
|
||||
prompt_request_id = (
|
||||
@@ -127,16 +135,22 @@ class PoolingServing:
|
||||
else ctx.prompt_request_ids[i]
|
||||
)
|
||||
|
||||
params = (
|
||||
pooling_params[i]
|
||||
if isinstance(pooling_params, list)
|
||||
else pooling_params
|
||||
)
|
||||
|
||||
self._log_inputs(
|
||||
prompt_request_id,
|
||||
engine_input,
|
||||
params=pooling_params,
|
||||
params=params,
|
||||
lora_request=ctx.lora_request,
|
||||
)
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_input,
|
||||
pooling_params,
|
||||
params,
|
||||
prompt_request_id,
|
||||
lora_request=ctx.lora_request,
|
||||
trace_headers=trace_headers,
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateConfig
|
||||
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
|
||||
from vllm.entrypoints.pooling.scoring.io_processor import ScoringIOProcessors
|
||||
from vllm.entrypoints.pooling.utils import enable_scoring_api
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.tasks import SupportedTask
|
||||
|
||||
@@ -25,6 +27,11 @@ def init_pooling_io_processors(
|
||||
|
||||
processors.append(("embed", EmbedIOProcessor))
|
||||
|
||||
if enable_scoring_api(supported_tasks, model_config):
|
||||
score_type = model_config.score_type
|
||||
if score_type is not None and score_type in ScoringIOProcessors:
|
||||
processors.append((score_type, ScoringIOProcessors[score_type]))
|
||||
|
||||
return {
|
||||
task: processor_cls(
|
||||
model_config=model_config,
|
||||
|
||||
@@ -1,667 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.pooling.score.protocol import (
|
||||
RerankDocument,
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
RerankResult,
|
||||
RerankUsage,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
ScoreResponseData,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreData,
|
||||
ScoreInputs,
|
||||
_cosine_similarity,
|
||||
compress_token_type_ids,
|
||||
get_score_prompt,
|
||||
parse_score_data_single,
|
||||
validate_score_input,
|
||||
)
|
||||
from vllm.inputs import EngineInput, TokensPrompt, tokens_input
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import make_async, merge_async_iterators
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
from vllm.v1.pool.late_interaction import (
|
||||
build_late_interaction_doc_params,
|
||||
build_late_interaction_query_params,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ServingScores(OpenAIServing):
|
||||
def __init__(
    self,
    engine_client: EngineClient,
    models: OpenAIServingModels,
    *,
    request_logger: RequestLogger | None,
    score_template: str | None = None,
    log_error_stack: bool = False,
) -> None:
    """Initialise the scoring handler and pick the scoring strategy.

    The strategy stored in ``self._score_func`` is selected from
    ``model_config.score_type``: ``"cross-encoder"`` and
    ``"late-interaction"`` get their dedicated coroutines, every other
    value (including ``None``) falls back to embedding scoring.

    ``log_error_stack`` is accepted but not used by this method.
    """
    super().__init__(
        engine_client=engine_client,
        models=models,
        request_logger=request_logger,
    )
    self.score_template = score_template

    # Tokenization is off-loaded to an executor limited to one worker.
    self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

    self.score_type = self.model_config.score_type
    self.architecture = self.model_config.architecture
    self.is_multimodal_model = self.model_config.is_multimodal_model

    # Dispatch table; anything not listed is treated as "bi-encoder".
    score_funcs = {
        "cross-encoder": self._cross_encoding_score,
        "late-interaction": self._late_interaction_score,
    }
    self._score_func = score_funcs.get(self.score_type, self._embedding_score)
|
||||
|
||||
async def _embedding_score(
    self,
    data_1: list[ScoreData],
    data_2: list[ScoreData],
    request: RerankRequest | ScoreRequest,
    request_id: str,
    lora_request: LoRARequest | None = None,
    trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
    """Score text pairs via cosine similarity of their embeddings.

    Every item of ``data_1`` and ``data_2`` is tokenized, submitted to the
    engine with ``"embed"`` pooling params, and the collected outputs are
    split back into the two groups and passed to ``_cosine_similarity``.
    When ``data_1`` holds a single item it is repeated so that it is paired
    with every item of ``data_2``.

    Raises:
        NotImplementedError: If any input item is not a plain string.
        AssertionError: If the engine yielded no output for some input.
    """
    input_texts: list[str] = []
    for text in data_1 + data_2:
        if not isinstance(text, str):
            raise NotImplementedError(
                "Embedding scores currently do not support multimodal input."
            )
        input_texts.append(text)

    model_config = self.model_config
    tokenizer = self.renderer.get_tokenizer()

    encode_async = make_async(
        tokenizer.encode,
        executor=self._tokenizer_executor,
    )

    tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
    tokenized_prompts = await asyncio.gather(
        *(encode_async(t, **tokenization_kwargs) for t in input_texts)
    )

    engine_inputs: list[EngineInput] = []
    for tok_result, input_text in zip(tokenized_prompts, input_texts):
        text_token_prompt = self._validate_input(request, tok_result, input_text)
        engine_inputs.append(
            tokens_input(
                text_token_prompt["prompt_token_ids"],
                prompt=input_text,
            )
        )

    # Schedule the requests and get the result generators.
    generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
    pooling_params = request.to_pooling_params("embed")

    for i, engine_input in enumerate(engine_inputs):
        request_id_item = f"{request_id}-{i}"

        self._log_inputs(
            request_id_item,
            engine_input,
            params=pooling_params,
            lora_request=lora_request,
        )

        generators.append(
            self.engine_client.encode(
                engine_input,
                pooling_params,
                request_id_item,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )
        )

    # Non-streaming response: results are stored by the index they are
    # yielded with.
    embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_inputs)
    async for i, res in merge_async_iterators(*generators):
        embeddings[i] = res

    # Split back into the two input groups. The check is an explicit raise
    # rather than ``assert (emb := ...)``: with ``python -O`` the assert
    # (and the binding inside it) is removed, leaving ``emb`` undefined.
    num_first = len(data_1)
    emb_data_1: list[PoolingRequestOutput] = []
    emb_data_2: list[PoolingRequestOutput] = []
    for i, emb in enumerate(embeddings):
        if emb is None:
            raise AssertionError(f"No embedding output received for input {i}")
        if i < num_first:
            emb_data_1.append(emb)
        else:
            emb_data_2.append(emb)

    # 1:N scoring: repeat the single first-group embedding for every pair.
    if len(emb_data_1) == 1:
        emb_data_1 = emb_data_1 * len(emb_data_2)

    return _cosine_similarity(
        tokenizer=tokenizer, embed_1=emb_data_1, embed_2=emb_data_2
    )
|
||||
|
||||
def _preprocess_late_interaction_item(
    self,
    data: ScoreData,
    role: str,
    request: RerankRequest | ScoreRequest,
    tokenizer: TokenizerLike,
    tokenization_kwargs: dict[str, Any],
) -> TokensPrompt:
    """Turn one ScoreData item into a TokensPrompt for late interaction.

    A plain string is tokenized as-is. Any other item goes through
    ``parse_score_data_single`` first, which supplies the text together
    with optional multimodal data and uuids. Optional fields are only
    attached to the prompt when they are not ``None``.
    """
    text, mm_data, mm_uuids = data, None, None
    if not isinstance(data, str):
        text, mm_data, mm_uuids = parse_score_data_single(
            data, role, self.model_config
        )

    prompt_ids = tokenizer.encode(text, **tokenization_kwargs)
    self._validate_input(request, prompt_ids, text)

    tok_prompt = TokensPrompt(prompt_token_ids=prompt_ids, prompt=text)

    optional_fields = (
        ("multi_modal_data", mm_data),
        ("multi_modal_uuids", mm_uuids),
        ("mm_processor_kwargs", request.mm_processor_kwargs),
    )
    for key, value in optional_fields:
        if value is not None:
            tok_prompt[key] = value

    return tok_prompt
|
||||
|
||||
async def _late_interaction_score(
    self,
    data_1: list[ScoreData],
    data_2: list[ScoreData],
    request: RerankRequest | ScoreRequest,
    request_id: str,
    lora_request: LoRARequest | None = None,
    trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
    """
    Late interaction scoring (ColBERT MaxSim).

    Encodes queries and documents into per-token embeddings, then computes
    MaxSim: sum over query tokens of max similarity to any document token.

    Queries (``data_1``) are encoded first with query-side late-interaction
    params; documents (``data_2``) are encoded afterwards with params that
    reference the key of their query. With a single query, every document
    references that one query. The returned outputs carry the document
    result's ``outputs`` and the concatenated query/document token ids.

    Raises:
        AssertionError: If the engine yielded no output for a query or
            document.
    """
    model_config = self.model_config
    tokenizer = self.renderer.get_tokenizer()
    tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()

    all_data = data_1 + data_2
    roles = ["query"] * len(data_1) + ["document"] * len(data_2)

    preprocess_async = make_async(
        self._preprocess_late_interaction_item,
        executor=self._tokenizer_executor,
    )

    tok_prompts = await asyncio.gather(
        *(
            preprocess_async(
                data=d,
                role=r,
                request=request,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )
            for d, r in zip(all_data, roles)
        )
    )

    query_prompts = tok_prompts[: len(data_1)]
    doc_prompts = tok_prompts[len(data_1) :]
    single_query = len(query_prompts) == 1

    default_pooling_params = request.to_pooling_params("token_embed")

    # stage 1: encode queries and cache token embeddings on workers.
    query_keys = [f"{request_id}-query-{i}" for i in range(len(query_prompts))]
    # A single query is used once per document, otherwise once.
    uses_per_query = len(doc_prompts) if single_query else 1

    query_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
    for i, tok_prompt in enumerate(query_prompts):
        request_id_item = query_keys[i]
        pooling_params = default_pooling_params.clone()
        pooling_params.late_interaction_params = (
            build_late_interaction_query_params(
                query_key=query_keys[i],
                query_uses=uses_per_query,
            )
        )

        self._log_inputs(
            request_id_item,
            tok_prompt,
            params=pooling_params,
            lora_request=lora_request,
        )

        query_generators.append(
            self.engine_client.encode(
                tok_prompt,
                pooling_params,
                request_id_item,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )
        )

    query_outputs: list[PoolingRequestOutput | None] = [None] * len(query_prompts)
    if query_generators:
        async for i, res in merge_async_iterators(*query_generators):
            query_outputs[i] = res

    # Explicit check instead of ``assert``: with ``python -O`` an assert is
    # stripped and missing entries would be silently filtered out below,
    # shifting every later query/document pairing.
    if any(res is None for res in query_outputs):
        raise AssertionError("Missing late-interaction query output")
    query_results = [res for res in query_outputs if res is not None]

    # stage 2: encode docs and return scalar scores from workers.
    doc_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
    for i, tok_prompt in enumerate(doc_prompts):
        request_id_item = f"{request_id}-doc-{i}"
        query_idx = 0 if single_query else i
        pooling_params = default_pooling_params.clone()
        pooling_params.late_interaction_params = build_late_interaction_doc_params(
            query_key=query_keys[query_idx]
        )

        self._log_inputs(
            request_id_item,
            tok_prompt,
            params=pooling_params,
            lora_request=lora_request,
        )

        doc_generators.append(
            self.engine_client.encode(
                tok_prompt,
                pooling_params,
                request_id_item,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )
        )

    doc_outputs: list[PoolingRequestOutput | None] = [None] * len(doc_prompts)
    if doc_generators:
        async for i, res in merge_async_iterators(*doc_generators):
            doc_outputs[i] = res

    if any(res is None for res in doc_outputs):
        raise AssertionError("Missing late-interaction document output")
    doc_results = [res for res in doc_outputs if res is not None]

    # The pad token (when the tokenizer defines one) separates query and
    # document ids in the combined token list.
    padding: list[int] = []
    if (pad_token_id := tokenizer.pad_token_id) is not None:
        padding = [pad_token_id]

    # 1:N scoring: repeat the single query result for every document.
    if len(query_results) == 1:
        query_results = query_results * len(doc_results)

    return [
        PoolingRequestOutput(
            request_id=f"{query_result.request_id}_{doc_result.request_id}",
            outputs=doc_result.outputs,
            prompt_token_ids=(
                query_result.prompt_token_ids + padding + doc_result.prompt_token_ids
            ),
            num_cached_tokens=(
                query_result.num_cached_tokens + doc_result.num_cached_tokens
            ),
            finished=True,
        )
        for query_result, doc_result in zip(query_results, doc_results)
    ]
|
||||
|
||||
async def _cross_encoding_score(
    self,
    data_1: list[ScoreData],
    data_2: list[ScoreData],
    request: RerankRequest | ScoreRequest,
    request_id: str,
    lora_request: LoRARequest | None | None = None,
    trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
    """Score pairs by encoding each (data_1, data_2) pair as one prompt.

    A single ``data_1`` item is repeated for every ``data_2`` item. Each
    pair is preprocessed via ``_preprocess_score`` and submitted with
    ``"classify"`` pooling params. Prompts carrying ``token_type_ids`` get
    a cloned params object holding the compressed ids.

    Raises:
        ValueError: If the renderer's tokenizer is a Mistral tokenizer.
    """
    tokenizer = self.renderer.get_tokenizer()
    if is_mistral_tokenizer(tokenizer):
        raise ValueError("MistralTokenizer not supported for cross-encoding")

    # 1:N scoring: pair the lone first item with every second item.
    first_items = data_1 * len(data_2) if len(data_1) == 1 else data_1

    tok_kwargs = request.build_tok_params(self.model_config).get_encode_kwargs()
    pairs = list(zip(first_items, data_2))

    preprocess_async = make_async(
        self._preprocess_score,
        executor=self._tokenizer_executor,
    )
    preprocessed_prompts = await asyncio.gather(
        *(
            preprocess_async(
                request=request,
                tokenizer=tokenizer,
                tokenization_kwargs=tok_kwargs,
                data_1=first,
                data_2=second,
            )
            for first, second in pairs
        )
    )

    # Schedule the requests and collect their result generators.
    base_params = request.to_pooling_params("classify")
    streams: list[AsyncGenerator[PoolingRequestOutput, None]] = []

    for idx, (full_prompt, tok_prompt) in enumerate(preprocessed_prompts):
        item_id = f"{request_id}-{idx}"

        self._log_inputs(
            item_id,
            full_prompt,
            params=base_params,
            lora_request=lora_request,
        )

        # token_type_ids are removed from the prompt and, when non-empty,
        # forwarded through the pooling params in compressed form.
        item_params = base_params
        token_type_ids = tok_prompt.pop("token_type_ids", None)
        if token_type_ids:
            item_params = base_params.clone()
            compressed = compress_token_type_ids(token_type_ids)
            item_params.extra_kwargs = {"compressed_token_type_ids": compressed}

        streams.append(
            self.engine_client.encode(
                tok_prompt,
                item_params,
                item_id,
                lora_request=lora_request,
                trace_headers=trace_headers,
                priority=request.priority,
            )
        )

    # Non-streaming response: slot each result by its index.
    collected: list[PoolingRequestOutput | None] = [None] * len(
        preprocessed_prompts
    )
    async for idx, res in merge_async_iterators(*streams):
        collected[idx] = res

    return [res for res in collected if res is not None]
|
||||
|
||||
def _preprocess_score(
    self,
    request: RerankRequest | ScoreRequest,
    tokenizer: TokenizerLike,
    tokenization_kwargs: dict[str, Any],
    data_1: ScoreData,
    data_2: ScoreData,
) -> tuple[str, TokensPrompt]:
    """Build and validate the prompt for one (data_1, data_2) pair.

    Returns the full prompt text together with the engine input produced
    by ``get_score_prompt``. The request's ``mm_processor_kwargs`` are
    attached to the engine input when they are set.
    """
    full_prompt, engine_input = get_score_prompt(
        model_config=self.model_config,
        data_1=data_1,
        data_2=data_2,
        tokenizer=tokenizer,
        tokenization_kwargs=tokenization_kwargs,
        score_template=self.score_template,
    )
    self._validate_input(request, engine_input["prompt_token_ids"], full_prompt)

    mm_kwargs = request.mm_processor_kwargs
    if mm_kwargs is not None:
        engine_input["mm_processor_kwargs"] = mm_kwargs

    return full_prompt, engine_input
|
||||
|
||||
async def _run_scoring(
    self,
    data_1: ScoreInputs,
    data_2: ScoreInputs,
    request: ScoreRequest | RerankRequest,
    request_id: str,
    raw_request: Request | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
    """Validate the inputs and run the configured scoring function.

    Resolves the LoRA adapter and (when a raw request is given) the trace
    headers, validates both inputs with ``validate_score_input``, then
    awaits ``self._score_func``.
    """
    lora_request = self._maybe_get_adapters(request)

    if raw_request is None:
        trace_headers = None
    else:
        trace_headers = await self._get_trace_headers(raw_request.headers)

    validated_1, validated_2 = validate_score_input(
        data_1,
        data_2,
        is_multimodal_model=self.is_multimodal_model,
        architecture=self.architecture,
    )

    return await self._score_func(
        data_1=validated_1,
        data_2=validated_2,
        request=request,
        request_id=request_id,
        lora_request=lora_request,
        trace_headers=trace_headers,
    )
|
||||
|
||||
async def create_score(
    self,
    request: ScoreRequest,
    raw_request: Request | None = None,
) -> ScoreResponse | ErrorResponse:
    """
    Score API similar to Sentence Transformers cross encoder

    See https://sbert.net/docs/package_reference/cross_encoder

    Returns the error from the model check or from scoring unchanged, and
    an error response when the task is cancelled.
    """
    model_error = await self._check_model(request)
    if model_error is not None:
        return model_error

    request_id = f"score-{self._base_request_id(raw_request)}"
    created_time = int(time.time())

    try:
        outcome = await self._run_scoring(
            request.data_1,
            request.data_2,
            request,
            request_id,
            raw_request,
        )
        if not isinstance(outcome, ErrorResponse):
            return self.request_output_to_score_response(
                outcome,
                request_id,
                created_time,
                self.models.model_name(),
            )
        return outcome
    except asyncio.CancelledError:
        return self.create_error_response("Client disconnected")
|
||||
|
||||
async def do_rerank(
    self, request: RerankRequest, raw_request: Request | None = None
) -> RerankResponse | ErrorResponse:
    """
    Rerank API based on JinaAI's rerank API; implements the same
    API interface. Designed for compatibility with off-the-shelf
    tooling, since this is a common standard for reranking APIs

    See example client implementations at
    https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
    numerous clients use this standard.

    A ``top_n`` that is not positive means "return every result".
    """
    model_error = await self._check_model(request)
    if model_error is not None:
        return model_error

    request_id = f"rerank-{self._base_request_id(raw_request)}"
    documents = request.documents

    try:
        outcome = await self._run_scoring(
            request.query,
            documents,
            request,
            request_id,
            raw_request,
        )
        if isinstance(outcome, ErrorResponse):
            return outcome

        if request.top_n > 0:
            top_n = request.top_n
        else:
            top_n = len(outcome)

        return self.request_output_to_rerank_response(
            outcome,
            request_id,
            self.models.model_name(),
            documents,
            top_n,
        )
    except asyncio.CancelledError:
        return self.create_error_response("Client disconnected")
|
||||
|
||||
def request_output_to_score_response(
    self,
    final_res_batch: list[PoolingRequestOutput],
    request_id: str,
    created_time: int,
    model_name: str,
) -> ScoreResponse:
    """Convert scoring outputs into a ScoreResponse.

    Each output becomes one ``ScoreResponseData`` indexed by its position;
    usage reports the summed prompt token count as both prompt and total
    tokens.
    """
    items: list[ScoreResponseData] = [
        ScoreResponseData(
            index=position,
            score=ScoringRequestOutput.from_base(output).outputs.score,
        )
        for position, output in enumerate(final_res_batch)
    ]
    num_prompt_tokens = sum(
        len(output.prompt_token_ids) for output in final_res_batch
    )

    return ScoreResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=items,
        usage=UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        ),
    )
|
||||
|
||||
def request_output_to_rerank_response(
    self,
    final_res_batch: list[PoolingRequestOutput],
    request_id: str,
    model_name: str,
    documents: ScoreInputs,
    top_n: int,
) -> RerankResponse:
    """
    Convert the output of do_rerank to a RerankResponse

    Results keep the index of the document they belong to, are ordered by
    descending relevance score, and are cut to ``top_n`` entries when
    ``top_n`` is smaller than the number of documents.
    """
    doc_list = documents if isinstance(documents, list) else [documents]

    results: list[RerankResult] = []
    num_prompt_tokens = 0
    for idx, final_res in enumerate(final_res_batch):
        scored = ScoringRequestOutput.from_base(final_res)

        document = doc_list[idx]
        if isinstance(document, str):
            rerank_document = RerankDocument(text=document)
        else:
            # Non-string documents contribute their "content" entry.
            rerank_document = RerankDocument(
                multi_modal=document.get("content", [])
            )

        results.append(
            RerankResult(
                index=idx,
                document=rerank_document,
                relevance_score=scored.outputs.score,
            )
        )
        num_prompt_tokens += len(final_res.prompt_token_ids)

    # sort by relevance, then return the top n if set
    ranked = sorted(results, key=lambda r: r.relevance_score, reverse=True)
    if top_n < len(doc_list):
        ranked = ranked[:top_n]

    usage = RerankUsage(
        total_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens
    )
    return RerankResponse(
        id=request_id,
        model=model_name,
        results=ranked,
        usage=usage,
    )
|
||||
@@ -3,21 +3,15 @@
|
||||
from http import HTTPStatus
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.pooling.score.protocol import (
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.serving import ServingScores
|
||||
from vllm.entrypoints.utils import load_aware_call, with_cancellation
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .protocol import RerankRequest, ScoreRequest
|
||||
from .serving import ServingScores
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -46,16 +40,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
|
||||
if handler is None:
|
||||
raise NotImplementedError("The model does not support Score API")
|
||||
|
||||
generator = await handler.create_score(request, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, ScoreResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
return await handler(request, raw_request)
|
||||
|
||||
|
||||
@router.post(
|
||||
@@ -92,16 +77,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
|
||||
if handler is None:
|
||||
raise NotImplementedError("The model does not support Rerank (Score) API")
|
||||
|
||||
generator = await handler.do_rerank(request, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(
|
||||
content=generator.model_dump(), status_code=generator.error.code
|
||||
)
|
||||
elif isinstance(generator, RerankResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
return await handler(request, raw_request)
|
||||
|
||||
|
||||
@router.post(
|
||||
419
vllm/entrypoints/pooling/scoring/io_processor.py
Normal file
419
vllm/entrypoints/pooling/scoring/io_processor.py
Normal file
@@ -0,0 +1,419 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, TypeAlias, cast
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import PoolingParams, PoolingRequestOutput, TokensPrompt
|
||||
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
|
||||
from vllm.entrypoints.pooling.typing import (
|
||||
OfflineInputsContext,
|
||||
OfflineOutputsContext,
|
||||
PoolingServeContext,
|
||||
)
|
||||
from vllm.inputs import EngineInput
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.renderers.hf import safe_apply_chat_template
|
||||
from vllm.tasks import PoolingTask, ScoreType
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
from ...chat_utils import ChatTemplateResolutionError
|
||||
from .protocol import RerankRequest, ScoreRequest, ScoringRequest
|
||||
from .typing import ScoreData, ScoreInput, ScoringData
|
||||
from .utils import (
|
||||
compress_token_type_ids,
|
||||
compute_maxsim_score,
|
||||
parse_score_data,
|
||||
score_data_to_prompts,
|
||||
validate_score_input,
|
||||
)
|
||||
|
||||
ScoringServeContext: TypeAlias = PoolingServeContext[ScoringRequest]
|
||||
|
||||
|
||||
class ScoringIOProcessor(PoolingIOProcessor):
    """Shared base for the scoring IO processors.

    Subclasses set ``name`` (the score type) and ``pooling_task`` (the
    task passed to ``request.to_pooling_params``).
    """

    name: ScoreType
    pooling_task: PoolingTask

    def __init__(self, *args, **kwargs):
        """Cache tokenizer and model-config values used during scoring."""
        super().__init__(*args, **kwargs)

        tokenizer = self.renderer.get_tokenizer()
        self.tokenizer = tokenizer
        self.architecture = self.model_config.architecture
        self.is_multimodal_model = self.model_config.is_multimodal_model
        self.pad_token_id = tokenizer.pad_token_id

    def create_pooling_params(self, request):
        """Return the request's pooling params for this processor's task."""
        return request.to_pooling_params(self.pooling_task)

    def valid_inputs(
        self,
        data_1: ScoreInput | list[ScoreInput],
        data_2: ScoreInput | list[ScoreInput],
    ) -> ScoringData:
        """Validate both score inputs and return the resulting ScoringData."""
        return validate_score_input(
            data_1,
            data_2,
            is_multimodal_model=self.is_multimodal_model,
            architecture=self.architecture,
        )
|
||||
|
||||
|
||||
class BiEncoderIOProcessor(ScoringIOProcessor):
    """Scoring processor that compares separately embedded inputs.

    Both sides are encoded in one batch (``data_1`` prompts first, then
    ``data_2`` prompts); the number of ``data_1`` items is remembered so
    the outputs can be split again and scored with cosine similarity.
    """

    name: ScoreType = "bi-encoder"
    pooling_task: PoolingTask = "embed"

    #######################################
    # online APIs

    def pre_process_online(self, ctx: ScoringServeContext):
        """Fill ``ctx.engine_inputs`` and record the ``data_1`` length."""
        request = ctx.request

        if isinstance(request, ScoreRequest):
            first, second = request.data_1, request.data_2
        elif isinstance(request, RerankRequest):
            first, second = request.query, request.documents
        else:
            raise ValueError(f"Invalid {self.name} request type")

        scoring_data = self.valid_inputs(first, second)
        tok_params = request.build_tok_params(self.model_config)

        # Only forward the optional request fields that are actually set.
        extras: dict[str, Any] = {}
        for key in ("mm_processor_kwargs", "cache_salt"):
            value = getattr(request, key, None)
            if value is not None:
                extras[key] = value

        ctx.engine_inputs = self._pre_process(
            scoring_data,
            tok_params,
            prompt_extras=extras,
        )
        # Offset at which data_2 outputs start in the flat output list.
        ctx.intermediates = len(scoring_data.data_1)

    def post_process_online(
        self,
        ctx: ScoringServeContext,
    ):
        """Replace ``ctx.final_res_batch`` with the pairwise score outputs."""
        raw_outputs = ctx.final_res_batch
        if raw_outputs is None:
            raise ValueError("Final response batch not available")

        split_at = ctx.intermediates
        if split_at is None:
            raise ValueError("data_1 len not available")

        ctx.final_res_batch = self._post_process(
            outputs=raw_outputs, offset=cast(int, split_at)
        )

    #######################################
    # offline APIs

    def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
        """Build engine inputs from ``ctx.prompts`` using default tok params."""
        assert isinstance(ctx.prompts, ScoringData)
        overrides = ctx.tokenization_kwargs or {}
        tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(**overrides)
        return self._pre_process(ctx.prompts, tok_params)

    def post_process_offline(
        self,
        ctx: OfflineOutputsContext,
    ) -> list[PoolingRequestOutput]:
        """Score ``ctx.outputs`` split at ``ctx.offset``."""
        assert ctx.offset is not None
        return self._post_process(outputs=ctx.outputs, offset=ctx.offset)

    #######################################
    # helpers

    def _pre_process(
        self,
        scoring_data: ScoringData,
        tok_params: TokenizeParams,
        prompt_extras: dict[str, Any] | None = None,
    ) -> Sequence[EngineInput]:
        """Convert both sides to prompts and preprocess them as one batch."""
        query_prompts = score_data_to_prompts(
            scoring_data.data_1, "query", self.model_config
        )
        document_prompts = score_data_to_prompts(
            scoring_data.data_2, "document", self.model_config
        )

        return self._preprocess_completion_offline(
            prompts=query_prompts + document_prompts,
            tok_params=tok_params,
            prompt_extras=prompt_extras,
        )

    def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
        """Pair outputs before/after ``offset`` and score each pair.

        A single leading output is repeated so it is compared with every
        trailing output. Each result carries the cosine similarity as its
        ``outputs`` and the concatenated token ids of both sides.
        """
        first_side = outputs[:offset]
        second_side = outputs[offset:]

        if len(first_side) == 1:
            first_side = first_side * len(second_side)

        # The pad token, when defined, separates the two token id lists.
        padding: list[int] = (
            [] if self.pad_token_id is None else [self.pad_token_id]
        )

        scored: list[PoolingRequestOutput] = []
        for left, right in zip(first_side, second_side):
            similarity = F.cosine_similarity(
                left.outputs.data.float(), right.outputs.data.float(), dim=0
            )
            scored.append(
                PoolingRequestOutput(
                    request_id=f"{left.request_id}_{right.request_id}",
                    outputs=similarity,
                    prompt_token_ids=(
                        left.prompt_token_ids + padding + right.prompt_token_ids
                    ),
                    num_cached_tokens=left.num_cached_tokens
                    + right.num_cached_tokens,
                    finished=True,
                )
            )
        return scored
|
||||
|
||||
|
||||
class LateInteractionIOProcessor(BiEncoderIOProcessor):
    """Bi-encoder variant that scores pairs with ``compute_maxsim_score``.

    Uses the ``"token_embed"`` pooling task and only overrides the
    post-processing step of ``BiEncoderIOProcessor``.
    """

    name: ScoreType = "late-interaction"
    pooling_task: PoolingTask = "token_embed"

    def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
        """Split outputs at ``offset`` and compute one MaxSim score per pair."""
        # Query embeddings come first, document embeddings after ``offset``.
        query_outputs, doc_outputs = outputs[:offset], outputs[offset:]

        # 1:N scoring: repeat the single query for every document.
        if len(query_outputs) == 1:
            query_outputs = query_outputs * len(doc_outputs)

        pad_token_id = self.pad_token_id
        padding: list[int] = [] if pad_token_id is None else [pad_token_id]

        # Per the original comments: query data is [query_len, dim] and
        # document data is [doc_len, dim].
        return [
            PoolingRequestOutput(
                request_id=f"{query.request_id}_{doc.request_id}",
                outputs=compute_maxsim_score(query.outputs.data, doc.outputs.data),
                prompt_token_ids=(
                    query.prompt_token_ids + padding + doc.prompt_token_ids
                ),
                num_cached_tokens=query.num_cached_tokens + doc.num_cached_tokens,
                finished=True,
            )
            for query, doc in zip(query_outputs, doc_outputs)
        ]
|
||||
|
||||
|
||||
class CrossEncoderIOProcessor(ScoringIOProcessor):
|
||||
name: ScoreType = "cross-encoder"
|
||||
pooling_task: PoolingTask = "classify"
|
||||
|
||||
def __init__(self, *args, **kwargs):
    """Initialise the cross-encoder processor.

    Rejects Mistral tokenizers, then records whether the model class
    supports a score template (keeping the class only in that case) and
    the model config's ``use_sep_token`` flag.

    Raises:
        ValueError: If the tokenizer is a Mistral tokenizer.
    """
    super().__init__(*args, **kwargs)

    if is_mistral_tokenizer(self.tokenizer):
        raise ValueError("MistralTokenizer not supported for cross-encoding")

    # NOTE(review): imported locally in the original as well; the reason
    # (import cycle or import cost) is not visible here.
    from vllm.model_executor.model_loader import get_model_cls
    from vllm.model_executor.models.interfaces import supports_score_template

    model_cls = get_model_cls(self.model_config)
    has_score_template = supports_score_template(model_cls)

    self.supports_score_template = has_score_template
    self.model = model_cls if has_score_template else None
    self.use_sep_token = self.model_config.use_sep_token
|
||||
|
||||
#######################################
|
||||
# online APIs
|
||||
|
||||
def pre_process_online(self, ctx: ScoringServeContext):
    """Fill ``ctx.engine_inputs`` and ``ctx.pooling_params`` for a request.

    Score requests provide ``data_1``/``data_2``; rerank requests provide
    ``query``/``documents``.

    Raises:
        ValueError: If the request is neither a score nor a rerank request.
    """
    request = ctx.request

    if isinstance(request, ScoreRequest):
        first, second = request.data_1, request.data_2
    elif isinstance(request, RerankRequest):
        first, second = request.query, request.documents
    else:
        raise ValueError(f"Invalid {self.name} request type")

    scoring_data = self.valid_inputs(first, second)
    tok_params = request.build_tok_params(self.model_config)
    pooling_params = self.create_pooling_params(request)

    # Only forward the optional request fields that are actually set.
    extras: dict[str, Any] = {}
    for key in ("mm_processor_kwargs", "cache_salt"):
        value = getattr(request, key, None)
        if value is not None:
            extras[key] = value

    engine_inputs, pooling_params_list = self._pre_process(
        scoring_data,
        tok_params,
        pooling_params,
        chat_template=self.chat_template,
        prompt_extras=extras,
    )

    ctx.engine_inputs = engine_inputs
    ctx.pooling_params = pooling_params_list
|
||||
|
||||
#######################################
|
||||
# offline APIs
|
||||
|
||||
def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
|
||||
assert isinstance(ctx.prompts, ScoringData)
|
||||
assert not isinstance(ctx.pooling_params, list)
|
||||
|
||||
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
|
||||
**(ctx.tokenization_kwargs or {})
|
||||
)
|
||||
engine_inputs, pooling_params_list = self._pre_process(
|
||||
ctx.prompts, tok_params, ctx.pooling_params, ctx.chat_template
|
||||
)
|
||||
ctx.pooling_params = pooling_params_list
|
||||
return engine_inputs
|
||||
|
||||
#######################################
|
||||
# helpers
|
||||
|
||||
def _pre_process(
|
||||
self,
|
||||
scoring_data: ScoringData,
|
||||
tok_params: TokenizeParams,
|
||||
pooling_params: PoolingParams | None,
|
||||
chat_template: str | None = None,
|
||||
prompt_extras: dict[str, Any] | None = None,
|
||||
) -> tuple[Sequence[EngineInput], list[PoolingParams]]:
|
||||
# todo: support prompt_extras
|
||||
arrival_time = time.time()
|
||||
|
||||
data_1 = scoring_data.data_1
|
||||
data_2 = scoring_data.data_2
|
||||
|
||||
if len(data_1) == 1:
|
||||
data_1 = data_1 * len(data_2)
|
||||
|
||||
if pooling_params is None:
|
||||
pooling_params = PoolingParams(task="classify")
|
||||
|
||||
pooling_params_list = list[PoolingParams]()
|
||||
engine_inputs = list[EngineInput]()
|
||||
for q, d in zip(data_1, data_2):
|
||||
_, engine_prompt = self.get_score_prompt(
|
||||
data_1=q,
|
||||
data_2=d,
|
||||
encode_kwargs=tok_params.get_encode_kwargs(),
|
||||
chat_template=chat_template,
|
||||
)
|
||||
|
||||
if token_type_ids := engine_prompt.pop("token_type_ids", None):
|
||||
params = pooling_params.clone()
|
||||
compressed = compress_token_type_ids(token_type_ids)
|
||||
params.extra_kwargs = {"compressed_token_type_ids": compressed}
|
||||
pooling_params_list.append(params)
|
||||
else:
|
||||
pooling_params_list.append(pooling_params)
|
||||
|
||||
tok_params.apply_post_tokenization(self.tokenizer, engine_prompt)
|
||||
engine_inputs.append(
|
||||
self.renderer.process_for_engine(engine_prompt, arrival_time)
|
||||
)
|
||||
return engine_inputs, pooling_params_list
|
||||
|
||||
def get_score_prompt(
|
||||
self,
|
||||
data_1: ScoreData,
|
||||
data_2: ScoreData,
|
||||
encode_kwargs: dict[str, Any],
|
||||
chat_template: str | None = None,
|
||||
):
|
||||
model_config = self.model_config
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
prompt_1, prompt_2, mm_data, mm_uuids = parse_score_data(
|
||||
data_1,
|
||||
data_2,
|
||||
model_config,
|
||||
)
|
||||
|
||||
def default_tokenizer_encode():
|
||||
if self.supports_score_template:
|
||||
assert self.model is not None
|
||||
full_prompt = self.model.get_score_template(prompt_1, prompt_2)
|
||||
if full_prompt is None:
|
||||
raise ValueError("Get empty score template from model")
|
||||
|
||||
prompt_inputs = tokenizer(full_prompt, **encode_kwargs)
|
||||
else:
|
||||
if self.use_sep_token:
|
||||
# cross_encoder models defaults to using separating token.
|
||||
prompt_inputs = tokenizer(
|
||||
text=prompt_1, text_pair=prompt_2, **encode_kwargs
|
||||
)
|
||||
full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
|
||||
else:
|
||||
# `llm as reranker` defaults to not using separating token.
|
||||
full_prompt = prompt_1 + prompt_2
|
||||
prompt_inputs = tokenizer(text=full_prompt, **encode_kwargs)
|
||||
return full_prompt, prompt_inputs
|
||||
|
||||
# FIXME: For now, we only apply a template when one is explicitly provided.
|
||||
# We cannot rely on the tokenizer's chat template because many models
|
||||
# inherit junk templates from their base LLM, which breaks both the models
|
||||
# and the tests that use them.
|
||||
if chat_template is None:
|
||||
full_prompt, prompt_inputs = default_tokenizer_encode()
|
||||
else:
|
||||
# FIXME:
|
||||
# Try applying a score template from the CLI arg or tokenizer_config.json
|
||||
# If that fails because there is no such template,
|
||||
# fall back to the default implementation.
|
||||
try:
|
||||
full_prompt = safe_apply_chat_template(
|
||||
model_config,
|
||||
tokenizer,
|
||||
[
|
||||
{"role": "query", "content": prompt_1},
|
||||
{"role": "document", "content": prompt_2},
|
||||
],
|
||||
chat_template=chat_template,
|
||||
tools=None,
|
||||
tokenize=False,
|
||||
)
|
||||
prompt_inputs = tokenizer(full_prompt, **encode_kwargs)
|
||||
except ChatTemplateResolutionError:
|
||||
full_prompt, prompt_inputs = default_tokenizer_encode()
|
||||
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])
|
||||
|
||||
if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
|
||||
engine_prompt["token_type_ids"] = token_type_ids
|
||||
|
||||
if self.model is not None:
|
||||
self.model.post_process_tokens(engine_prompt)
|
||||
|
||||
if mm_data is not None:
|
||||
engine_prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
engine_prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return full_prompt, engine_prompt
|
||||
|
||||
|
||||
# Registry: score type -> IO processor class that handles it.
ScoringIOProcessors: dict[ScoreType, type[ScoringIOProcessor]] = {
    "bi-encoder": BiEncoderIOProcessor,
    "late-interaction": LateInteractionIOProcessor,
    "cross-encoder": CrossEncoderIOProcessor,
}
|
||||
@@ -12,15 +12,12 @@ from vllm.entrypoints.pooling.base.protocol import (
|
||||
ClassifyRequestMixin,
|
||||
PoolingBasicRequestMixin,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreContentPartParam,
|
||||
ScoreInput,
|
||||
ScoreInputs,
|
||||
)
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
from .typing import ScoreContentPartParam, ScoreInput
|
||||
|
||||
|
||||
class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
@@ -43,13 +40,13 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
|
||||
|
||||
|
||||
class ScoreDataRequest(ScoreRequestMixin):
|
||||
data_1: ScoreInputs
|
||||
data_2: ScoreInputs
|
||||
data_1: ScoreInput | list[ScoreInput]
|
||||
data_2: ScoreInput | list[ScoreInput]
|
||||
|
||||
|
||||
class ScoreQueriesDocumentsRequest(ScoreRequestMixin):
|
||||
queries: ScoreInputs
|
||||
documents: ScoreInputs
|
||||
queries: ScoreInput | list[ScoreInput]
|
||||
documents: ScoreInput | list[ScoreInput]
|
||||
|
||||
@property
|
||||
def data_1(self):
|
||||
@@ -61,8 +58,8 @@ class ScoreQueriesDocumentsRequest(ScoreRequestMixin):
|
||||
|
||||
|
||||
class ScoreQueriesItemsRequest(ScoreRequestMixin):
|
||||
queries: ScoreInputs
|
||||
items: ScoreInputs
|
||||
queries: ScoreInput | list[ScoreInput]
|
||||
items: ScoreInput | list[ScoreInput]
|
||||
|
||||
@property
|
||||
def data_1(self):
|
||||
@@ -74,8 +71,8 @@ class ScoreQueriesItemsRequest(ScoreRequestMixin):
|
||||
|
||||
|
||||
class ScoreTextRequest(ScoreRequestMixin):
|
||||
text_1: ScoreInputs
|
||||
text_2: ScoreInputs
|
||||
text_1: ScoreInput | list[ScoreInput]
|
||||
text_2: ScoreInput | list[ScoreInput]
|
||||
|
||||
@property
|
||||
def data_1(self):
|
||||
@@ -96,7 +93,7 @@ ScoreRequest: TypeAlias = (
|
||||
|
||||
class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
|
||||
query: ScoreInput
|
||||
documents: ScoreInputs
|
||||
documents: ScoreInput | list[ScoreInput]
|
||||
top_n: int = Field(default_factory=lambda: 0)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
@@ -118,6 +115,9 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
|
||||
)
|
||||
|
||||
|
||||
ScoringRequest: TypeAlias = ScoreRequest | RerankRequest
|
||||
|
||||
|
||||
class RerankDocument(BaseModel):
|
||||
text: str | None = None
|
||||
multi_modal: list[ScoreContentPartParam] | None = None
|
||||
@@ -154,3 +154,6 @@ class ScoreResponse(OpenAIBaseModel):
|
||||
model: str
|
||||
data: list[ScoreResponseData]
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
ScoringResponse: TypeAlias = RerankResponse | ScoreResponse
|
||||
160
vllm/entrypoints/pooling/scoring/serving.py
Normal file
160
vllm/entrypoints/pooling/scoring/serving.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateConfig
|
||||
from vllm.entrypoints.openai.engine.protocol import UsageInfo
|
||||
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
|
||||
from vllm.entrypoints.pooling.base.serving import PoolingServing
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.renderers import BaseRenderer
|
||||
|
||||
from .io_processor import ScoringIOProcessors, ScoringServeContext
|
||||
from .protocol import (
|
||||
RerankDocument,
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
RerankResult,
|
||||
RerankUsage,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
ScoreResponseData,
|
||||
)
|
||||
from .typing import ScoreInput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ServingScores(PoolingServing):
    """Serving handler that turns pooled outputs into score/rerank responses."""

    request_id_prefix = "score"

    def init_io_processor(
        self,
        model_config: ModelConfig,
        renderer: BaseRenderer,
        chat_template_config: ChatTemplateConfig,
    ) -> PoolingIOProcessor:
        """Instantiate the IO processor registered for the model's score type.

        Args:
            model_config: Model configuration; its ``score_type`` selects the
                processor class.
            renderer: Renderer handed to the processor.
            chat_template_config: Chat template configuration handed to the
                processor.

        Returns:
            The constructed scoring IO processor.
        """
        score_type = model_config.score_type
        # Same AssertionError as before, now with a diagnostic message.
        assert score_type in ScoringIOProcessors, (
            f"Unsupported score type: {score_type!r}"
        )
        processor_cls = ScoringIOProcessors[score_type]
        return processor_cls(
            model_config=model_config,
            renderer=renderer,
            chat_template_config=chat_template_config,
        )

    async def _build_response(
        self,
        ctx: ScoringServeContext,
    ) -> JSONResponse:
        """Build the JSON response matching the type of ``ctx.request``.

        Raises:
            NotImplementedError: If the request is neither a score nor a
                rerank request.
        """
        final_res_batch = ctx.final_res_batch
        request_id = ctx.request_id
        created_time = ctx.created_time
        model_name = self.models.model_name()

        if isinstance(ctx.request, ScoreRequest):
            return self._request_output_to_score_response(
                final_res_batch,
                request_id,
                created_time,
                model_name,
            )
        elif isinstance(ctx.request, RerankRequest):
            return self._request_output_to_rerank_response(
                final_res_batch,
                request_id,
                model_name,
                ctx.request.documents,
                # A non-positive top_n means "return every result".
                ctx.request.top_n if ctx.request.top_n > 0 else len(final_res_batch),
            )
        else:
            raise NotImplementedError(
                f"Unsupported scoring request type: {type(ctx.request).__name__}"
            )

    def _request_output_to_score_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        created_time: int,
        model_name: str,
    ) -> JSONResponse:
        """Build a ``ScoreResponse`` with one entry per pooled output.

        Entries keep the order of ``final_res_batch``; usage counts the prompt
        tokens summed over all outputs.
        """
        items: list[ScoreResponseData] = []
        num_prompt_tokens = 0

        for idx, final_res in enumerate(final_res_batch):
            classify_res = ScoringRequestOutput.from_base(final_res)

            item = ScoreResponseData(
                index=idx,
                score=classify_res.outputs.score,
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        response = ScoreResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            data=items,
            usage=usage,
        )

        return JSONResponse(content=response.model_dump())

    def _request_output_to_rerank_response(
        self,
        final_res_batch: list[PoolingRequestOutput],
        request_id: str,
        model_name: str,
        documents: ScoreInput | list[ScoreInput],
        top_n: int,
    ) -> JSONResponse:
        """Build a ``RerankResponse`` sorted by descending relevance score.

        Args:
            final_res_batch: Pooled outputs, indexed like ``documents``.
            request_id: Identifier copied into the response.
            model_name: Model name copied into the response.
            documents: The request's documents; a single document is wrapped
                in a list.
            top_n: Number of results to keep when smaller than the number of
                documents.
        """
        if not isinstance(documents, list):
            documents = [documents]

        results: list[RerankResult] = []
        num_prompt_tokens = 0
        for idx, final_res in enumerate(final_res_batch):
            classify_res = ScoringRequestOutput.from_base(final_res)

            document = documents[idx]
            if isinstance(document, str):
                rerank_document = RerankDocument(text=document)
            else:
                # Non-string documents carry their parts under "content".
                rerank_document = RerankDocument(
                    multi_modal=document.get("content", [])
                )

            result = RerankResult(
                index=idx,
                document=rerank_document,
                relevance_score=classify_res.outputs.score,
            )
            results.append(result)
            prompt_token_ids = final_res.prompt_token_ids
            num_prompt_tokens += len(prompt_token_ids)

        # sort by relevance, then return the top n if set
        results.sort(key=lambda x: x.relevance_score, reverse=True)
        if top_n < len(documents):
            results = results[:top_n]

        response = RerankResponse(
            id=request_id,
            model=model_name,
            results=results,
            usage=RerankUsage(
                total_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens
            ),
        )

        return JSONResponse(content=response.model_dump())
|
||||
46
vllm/entrypoints/pooling/scoring/typing.py
Normal file
46
vllm/entrypoints/pooling/scoring/typing.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeAlias
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionContentPartImageEmbedsParam,
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
ChatCompletionContentPartVideoParam,
|
||||
)
|
||||
|
||||
# Content-part types accepted inside a multimodal scoring input.
ScoreContentPartParam: TypeAlias = (
    ChatCompletionContentPartImageParam
    | ChatCompletionContentPartImageEmbedsParam
    | ChatCompletionContentPartTextParam
    | ChatCompletionContentPartVideoParam
)
|
||||
|
||||
|
||||
class ScoreMultiModalParam(TypedDict, total=False):
|
||||
"""
|
||||
A specialized parameter type for scoring multimodal content
|
||||
|
||||
The reasons why don't reuse `CustomChatCompletionMessageParam` directly:
|
||||
1. Score tasks don't need the 'role' field (user/assistant/system) that's required in chat completions
|
||||
2. Including chat-specific fields would confuse users about their purpose in scoring
|
||||
3. This is a more focused interface that only exposes what's needed for scoring
|
||||
""" # noqa: E501
|
||||
|
||||
content: Required[list[ScoreContentPartParam]]
|
||||
"""The multimodal contents"""
|
||||
|
||||
|
||||
# Raw input: plain text, or a ScoreMultiModalParam wrapping parts under "content".
ScoreInput = str | ScoreMultiModalParam
# Score data with the "content" wrapper removed: plain text or the bare part list.
ScoreData = str | list[ScoreContentPartParam]
|
||||
|
||||
|
||||
@dataclass
class ScoringData:
    """The two sides of a scoring request, each a list of ``ScoreData``."""

    # First side of the pairs.
    data_1: list[ScoreData]
    # Second side of the pairs.
    data_2: list[ScoreData]
|
||||
@@ -1,42 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, TypeAlias, cast
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
from torch.nn import CosineSimilarity
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
from vllm import PromptType, TextPrompt
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
BaseMultiModalItemTracker,
|
||||
ChatCompletionContentPartImageEmbedsParam,
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
ChatCompletionContentPartVideoParam,
|
||||
ChatTemplateResolutionError,
|
||||
ConversationMessage,
|
||||
MultiModalItemTracker,
|
||||
_parse_chat_message_content_parts,
|
||||
)
|
||||
from vllm.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalUUIDDict,
|
||||
PromptType,
|
||||
TextPrompt,
|
||||
TokensPrompt,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces import supports_score_template
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.renderers.hf import safe_apply_chat_template
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict
|
||||
|
||||
ScoreContentPartParam: TypeAlias = (
|
||||
ChatCompletionContentPartImageParam
|
||||
| ChatCompletionContentPartImageEmbedsParam
|
||||
| ChatCompletionContentPartTextParam
|
||||
| ChatCompletionContentPartVideoParam
|
||||
from .typing import (
|
||||
ScoreContentPartParam,
|
||||
ScoreData,
|
||||
ScoreInput,
|
||||
ScoringData,
|
||||
)
|
||||
|
||||
|
||||
@@ -57,72 +42,6 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens
|
||||
return token_scores.amax(dim=-1).sum()
|
||||
|
||||
|
||||
class ScoreMultiModalParam(TypedDict, total=False):
|
||||
"""
|
||||
A specialized parameter type for scoring multimodal content
|
||||
|
||||
The reasons why don't reuse `CustomChatCompletionMessageParam` directly:
|
||||
1. Score tasks don't need the 'role' field (user/assistant/system) that's required in chat completions
|
||||
2. Including chat-specific fields would confuse users about their purpose in scoring
|
||||
3. This is a more focused interface that only exposes what's needed for scoring
|
||||
""" # noqa: E501
|
||||
|
||||
content: Required[list[ScoreContentPartParam]]
|
||||
"""The multimodal contents"""
|
||||
|
||||
|
||||
# Raw input: plain text, or a ScoreMultiModalParam wrapping parts under "content".
ScoreInput = str | ScoreMultiModalParam
# One raw input or a list of them.
ScoreInputs = ScoreInput | list[ScoreInput]
# Score data with the "content" wrapper removed: plain text or the bare part list.
ScoreData = str | list[ScoreContentPartParam]
|
||||
|
||||
|
||||
def _cosine_similarity(
    tokenizer: TokenizerLike,
    embed_1: list[PoolingRequestOutput],
    embed_2: list[PoolingRequestOutput],
) -> list[PoolingRequestOutput]:
    """Score embedding pairs with cosine similarity along dim 0.

    Args:
        tokenizer: Supplies ``pad_token_id``, which (when not ``None``) is
            placed between the two prompts' token ids.
        embed_1: Pooled outputs for the first side of each pair.
        embed_2: Pooled outputs for the second side of each pair.

    Returns:
        One finished ``PoolingRequestOutput`` per zipped pair, carrying the
        pair score and the combined prompt token ids.
    """
    cosine = CosineSimilarity(0)
    results: list[PoolingRequestOutput] = []

    for left, right in zip(embed_1, embed_2):
        similarity = cosine(left.outputs.data, right.outputs.data)

        pad_id = tokenizer.pad_token_id
        separator: list[int] = [] if pad_id is None else [pad_id]
        merged_ids = left.prompt_token_ids + separator + right.prompt_token_ids

        results.append(
            PoolingRequestOutput(
                request_id=f"{left.request_id}_{right.request_id}",
                outputs=similarity,
                prompt_token_ids=merged_ids,
                num_cached_tokens=left.num_cached_tokens + right.num_cached_tokens,
                finished=True,
            )
        )

    return results
|
||||
|
||||
|
||||
def _validate_score_input_lens(
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
):
|
||||
len_1 = len(data_1)
|
||||
len_2 = len(data_2)
|
||||
|
||||
if len_1 > 1 and len_1 != len_2:
|
||||
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
|
||||
if len_1 == 0:
|
||||
raise ValueError("At least one text element must be given")
|
||||
if len_2 == 0:
|
||||
raise ValueError("At least one text_pair element must be given")
|
||||
|
||||
|
||||
def _validate_mm_score_input(
|
||||
data: list[ScoreInput],
|
||||
is_multimodal_model: bool,
|
||||
@@ -140,12 +59,27 @@ def _validate_mm_score_input(
|
||||
return out
|
||||
|
||||
|
||||
def _validate_score_input_lens(
|
||||
data_1: list[ScoreData],
|
||||
data_2: list[ScoreData],
|
||||
):
|
||||
len_1 = len(data_1)
|
||||
len_2 = len(data_2)
|
||||
|
||||
if len_1 > 1 and len_1 != len_2:
|
||||
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
|
||||
if len_1 == 0:
|
||||
raise ValueError("At least one text element must be given")
|
||||
if len_2 == 0:
|
||||
raise ValueError("At least one text_pair element must be given")
|
||||
|
||||
|
||||
def validate_score_input(
|
||||
data_1: ScoreInputs,
|
||||
data_2: ScoreInputs,
|
||||
data_1: ScoreInput | list[ScoreInput],
|
||||
data_2: ScoreInput | list[ScoreInput],
|
||||
is_multimodal_model: bool,
|
||||
architecture: str,
|
||||
) -> tuple[list[ScoreData], list[ScoreData]]:
|
||||
) -> ScoringData:
|
||||
if not isinstance(data_1, list):
|
||||
data_1 = [data_1]
|
||||
|
||||
@@ -155,62 +89,7 @@ def validate_score_input(
|
||||
score_input_1 = _validate_mm_score_input(data_1, is_multimodal_model, architecture)
|
||||
score_input_2 = _validate_mm_score_input(data_2, is_multimodal_model, architecture)
|
||||
_validate_score_input_lens(score_input_1, score_input_2)
|
||||
return score_input_1, score_input_2
|
||||
|
||||
|
||||
def _ensure_str(content: list[ConversationMessage]) -> str:
|
||||
"""Extract a single string prompt from parsed conversation content."""
|
||||
assert len(content) == 1
|
||||
prompt = content[0]["content"]
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
return cast(str, prompt)
|
||||
raise ValueError(f"Only string content is supported, but got {content}.")
|
||||
|
||||
|
||||
def parse_score_data(
    data_1: ScoreData,
    data_2: ScoreData,
    model_config: ModelConfig,
) -> tuple[str, str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
    """Parse a query-document pair into two text prompts plus shared
    multi-modal data.

    One :class:`MultiModalItemTracker` serves both inputs, so their
    multi-modal items end up in a single ``mm_data`` dict. Cross-encoder
    scoring needs this because query and document are concatenated into one
    model prompt.

    Returns:
        ``(prompt_1, prompt_2, mm_items, mm_uuids)``.
    """
    tracker = MultiModalItemTracker(model_config)

    # Both sides are parsed before either is converted to text.
    query_content = _parse_score_content("query", data_1, tracker)
    document_content = _parse_score_content("document", data_2, tracker)

    query_text = _ensure_str(query_content)
    document_text = _ensure_str(document_content)
    mm_items, mm_uuids = tracker.resolve_items()

    return query_text, document_text, mm_items, mm_uuids
|
||||
|
||||
|
||||
def parse_score_data_single(
    data: ScoreData,
    role: str,
    model_config: ModelConfig,
) -> tuple[str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
    """Parse **one** ScoreData into a text prompt with its own multi-modal
    data.

    In contrast to :func:`parse_score_data`, every call builds a fresh
    :class:`MultiModalItemTracker`, keeping multi-modal items separate. Late-
    interaction scoring needs this because query and document are encoded
    independently.

    Returns:
        ``(prompt, mm_items, mm_uuids)``.
    """
    tracker = MultiModalItemTracker(model_config)
    parsed = _parse_score_content(role, data, tracker)

    text = _ensure_str(parsed)
    mm_items, mm_uuids = tracker.resolve_items()
    return text, mm_items, mm_uuids
|
||||
return ScoringData(data_1=score_input_1, data_2=score_input_2)
|
||||
|
||||
|
||||
def score_data_to_prompts(
|
||||
@@ -243,6 +122,15 @@ def score_data_to_prompts(
|
||||
return prompts
|
||||
|
||||
|
||||
def _ensure_str(content: list[ConversationMessage]) -> str:
|
||||
"""Extract a single string prompt from parsed conversation content."""
|
||||
assert len(content) == 1
|
||||
prompt = content[0]["content"]
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
return cast(str, prompt)
|
||||
raise ValueError(f"Only string content is supported, but got {content}.")
|
||||
|
||||
|
||||
def _parse_score_content(
|
||||
role: str,
|
||||
data: ScoreData,
|
||||
@@ -278,113 +166,50 @@ def _parse_score_content(
|
||||
return next(iter(mm_placeholder_storage.values()))[0]
|
||||
|
||||
|
||||
def _apply_model_score_template(
    model_config: ModelConfig, prompt_1: str, prompt_2: str
) -> str:
    """Render a prompt pair with the model class's own score template.

    Raises:
        ValueError: If the model class does not support score templates, or
            if its template returns ``None``.
    """
    # NOTE(Simon): lazy import to avoid bringing in all dependencies (e.g. gguf)
    from vllm.model_executor.model_loader import get_model_cls

    model_cls = get_model_cls(model_config)
    if not supports_score_template(model_cls):
        raise ValueError(
            f"Unsupported model architecture: {model_config.architecture}"
        )

    rendered = model_cls.get_score_template(prompt_1, prompt_2)
    if rendered is None:
        raise ValueError("Get empty score template from model")
    return rendered
|
||||
|
||||
|
||||
def post_process_tokens(
|
||||
def parse_score_data_single(
|
||||
data: ScoreData,
|
||||
role: str,
|
||||
model_config: ModelConfig,
|
||||
prompt: TokensPrompt,
|
||||
) -> None:
|
||||
) -> tuple[str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
|
||||
"""Parse **one** ScoreData into a text prompt and its own multi-modal
|
||||
data.
|
||||
|
||||
Unlike :func:`parse_score_data`, each call creates an **independent**
|
||||
:class:`MultiModalItemTracker` so multi-modal items are kept separate.
|
||||
This is the correct behaviour for late-interaction scoring, where
|
||||
query and document are encoded independently.
|
||||
"""
|
||||
Perform architecture-specific manipulations on the input tokens.
|
||||
mm_tracker = MultiModalItemTracker(model_config)
|
||||
content = _parse_score_content(role, data, mm_tracker)
|
||||
|
||||
Note:
|
||||
This is an in-place operation.
|
||||
"""
|
||||
# NOTE(Simon): lazy import to avoid bring in all dependencies (e.g. gguf)
|
||||
from vllm.model_executor.model_loader import get_model_cls
|
||||
|
||||
model = get_model_cls(model_config)
|
||||
if supports_score_template(model):
|
||||
model.post_process_tokens(prompt)
|
||||
prompt = _ensure_str(content)
|
||||
mm_items, mm_uuids = mm_tracker.resolve_items()
|
||||
return prompt, mm_items, mm_uuids
|
||||
|
||||
|
||||
def get_score_prompt(
|
||||
model_config: ModelConfig,
|
||||
tokenizer: TokenizerLike,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
def parse_score_data(
|
||||
data_1: ScoreData,
|
||||
data_2: ScoreData,
|
||||
score_template: str | None = None,
|
||||
) -> tuple[str, TokensPrompt]:
|
||||
prompt_1, prompt_2, mm_data, mm_uuids = parse_score_data(
|
||||
data_1,
|
||||
data_2,
|
||||
model_config,
|
||||
)
|
||||
from vllm.model_executor.model_loader import get_model_cls
|
||||
model_config: ModelConfig,
|
||||
) -> tuple[str, str, MultiModalDataDict | None, MultiModalUUIDDict | None]:
|
||||
"""Parse a query-document pair into text prompts and shared multi-modal
|
||||
data.
|
||||
|
||||
model = get_model_cls(model_config)
|
||||
Uses a **single** :class:`MultiModalItemTracker` so that multi-modal
|
||||
items from both inputs are merged into one ``mm_data`` dict. This is
|
||||
the correct behaviour for cross-encoder scoring, where query and
|
||||
document are concatenated into a single model prompt.
|
||||
"""
|
||||
mm_tracker = MultiModalItemTracker(model_config)
|
||||
|
||||
def default_tokenizer_encode():
|
||||
if supports_score_template(model):
|
||||
full_prompt = _apply_model_score_template(model_config, prompt_1, prompt_2)
|
||||
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
|
||||
else:
|
||||
if model_config.use_sep_token:
|
||||
# cross_encoder models defaults to using separating token.
|
||||
prompt_inputs = tokenizer(
|
||||
text=prompt_1, text_pair=prompt_2, **tokenization_kwargs
|
||||
)
|
||||
full_prompt = tokenizer.decode(prompt_inputs["input_ids"])
|
||||
else:
|
||||
# `llm as reranker` defaults to not using separating token.
|
||||
full_prompt = prompt_1 + prompt_2
|
||||
prompt_inputs = tokenizer(text=full_prompt, **tokenization_kwargs)
|
||||
return full_prompt, prompt_inputs
|
||||
content_1 = _parse_score_content("query", data_1, mm_tracker)
|
||||
content_2 = _parse_score_content("document", data_2, mm_tracker)
|
||||
|
||||
# FIXME: For now, we only apply a template when one is explicitly provided.
|
||||
# We cannot rely on the tokenizer's chat template because many models
|
||||
# inherit junk templates from their base LLM, which breaks both the models
|
||||
# and the tests that use them.
|
||||
if score_template is None:
|
||||
full_prompt, prompt_inputs = default_tokenizer_encode()
|
||||
else:
|
||||
# FIXME: Try applying a score template from the CLI arg or tokenizer_config.json
|
||||
# If that fails because there is no such template,
|
||||
# fall back to the default implementation.
|
||||
try:
|
||||
full_prompt = safe_apply_chat_template(
|
||||
model_config,
|
||||
tokenizer,
|
||||
[
|
||||
{"role": "query", "content": prompt_1},
|
||||
{"role": "document", "content": prompt_2},
|
||||
],
|
||||
chat_template=score_template,
|
||||
tools=None,
|
||||
tokenize=False,
|
||||
)
|
||||
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
|
||||
except ChatTemplateResolutionError:
|
||||
full_prompt, prompt_inputs = default_tokenizer_encode()
|
||||
prompt_1 = _ensure_str(content_1)
|
||||
prompt_2 = _ensure_str(content_2)
|
||||
mm_items, mm_uuids = mm_tracker.resolve_items()
|
||||
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"])
|
||||
|
||||
if (token_type_ids := prompt_inputs.get("token_type_ids")) is not None:
|
||||
engine_prompt["token_type_ids"] = token_type_ids
|
||||
|
||||
post_process_tokens(model_config, engine_prompt)
|
||||
|
||||
if mm_data is not None:
|
||||
engine_prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
engine_prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return full_prompt, engine_prompt
|
||||
return prompt_1, prompt_2, mm_items, mm_uuids
|
||||
|
||||
|
||||
def compress_token_type_ids(token_type_ids: list[int]) -> int:
|
||||
@@ -1,14 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Generic, TypeAlias, TypeVar
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from vllm import PoolingRequestOutput
|
||||
from vllm import PoolingParams, PoolingRequestOutput, PromptType
|
||||
from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationChatRequest,
|
||||
ClassificationCompletionRequest,
|
||||
@@ -23,15 +23,13 @@ from vllm.entrypoints.pooling.embed.protocol import (
|
||||
)
|
||||
from vllm.entrypoints.pooling.pooling.protocol import (
|
||||
IOProcessorRequest,
|
||||
PoolingBytesResponse,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.protocol import (
|
||||
RerankRequest,
|
||||
ScoreRequest,
|
||||
ScoreResponse,
|
||||
)
|
||||
from vllm.entrypoints.pooling.scoring.protocol import ScoringRequest, ScoringResponse
|
||||
from vllm.entrypoints.pooling.scoring.typing import ScoringData
|
||||
from vllm.inputs import EngineInput
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
@@ -49,8 +47,7 @@ AnyPoolingRequest: TypeAlias = (
|
||||
PoolingCompletionLikeRequest
|
||||
| PoolingChatLikeRequest
|
||||
| IOProcessorRequest
|
||||
| RerankRequest
|
||||
| ScoreRequest
|
||||
| ScoringRequest
|
||||
| CohereEmbedRequest
|
||||
)
|
||||
|
||||
@@ -59,7 +56,8 @@ AnyPoolingResponse: TypeAlias = (
|
||||
| EmbeddingResponse
|
||||
| EmbeddingBytesResponse
|
||||
| PoolingResponse
|
||||
| ScoreResponse
|
||||
| PoolingBytesResponse
|
||||
| ScoringResponse
|
||||
)
|
||||
|
||||
PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)
|
||||
@@ -73,8 +71,8 @@ class PoolingServeContext(Generic[PoolingRequestT]):
|
||||
request_id: str
|
||||
created_time: int = field(default_factory=lambda: int(time.time()))
|
||||
lora_request: LoRARequest | None = None
|
||||
|
||||
engine_inputs: list[EngineInput] | None = None
|
||||
pooling_params: PoolingParams | list[PoolingParams] | None = None
|
||||
engine_inputs: Sequence[EngineInput] | None = None
|
||||
prompt_request_ids: list[str] | None = None
|
||||
intermediates: Any | None = None
|
||||
|
||||
@@ -84,3 +82,22 @@ class PoolingServeContext(Generic[PoolingRequestT]):
|
||||
final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
@dataclass
class OfflineInputsContext:
    """Group the inputs that accompany one offline pooling/scoring call.

    NOTE(review): the "offline" role is inferred from the class name; the
    consumers of this container are not visible here -- confirm against callers.

    Only ``prompts`` is required; every other field defaults to ``None``.
    """

    # Either a single prompt, a sequence of prompts, or a ScoringData object.
    prompts: PromptType | Sequence[PromptType] | ScoringData

    # One PoolingParams for the whole call, a list of them, or None.
    pooling_params: PoolingParams | list[PoolingParams] | None = None

    # Optional keyword arguments for tokenization (string keys, arbitrary values).
    tokenization_kwargs: dict[str, Any] | None = None

    # Optional chat template string.
    chat_template: str | None = None

    # For bi-encoder & late-interaction.
    offset: int | None = None
|
||||
|
||||
|
||||
@dataclass
class OfflineOutputsContext:
    """Group the outputs produced by one offline pooling/scoring call.

    NOTE(review): the "offline" role is inferred from the class name; the
    producers and consumers of this container are not visible here -- confirm.
    """

    # Pooling results carried by this context (required).
    outputs: list[PoolingRequestOutput]

    # For bi-encoder & late-interaction.
    offset: int | None = None
|
||||
|
||||
@@ -11,8 +11,10 @@ import pybase64
|
||||
import torch
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.serial_utils import (
|
||||
EMBED_DTYPES,
|
||||
EmbedDType,
|
||||
@@ -133,3 +135,20 @@ def get_json_response_cls() -> type[JSONResponse]:
|
||||
"To make v1/embeddings API fast, please install orjson by `pip install orjson`"
|
||||
)
|
||||
return JSONResponse
|
||||
|
||||
|
||||
def enable_scoring_api(
    supported_tasks: tuple["SupportedTask", ...],
    model_config: ModelConfig | None = None,
) -> bool:
    """Decide whether the scoring API should be enabled.

    Args:
        supported_tasks: Task names to inspect.
        model_config: Optional model configuration; only consulted for the
            "classify" task, where ``hf_config.num_labels`` is read.

    Returns:
        True if ``supported_tasks`` contains "embed" or "token_embed".
        Otherwise True only if a ``model_config`` is given, "classify" is in
        ``supported_tasks``, and ``hf_config.num_labels`` equals 1 (a missing
        attribute is treated as 0). False in every other case.
    """
    # Embedding-style tasks qualify without looking at the model config.
    if "embed" in supported_tasks or "token_embed" in supported_tasks:
        return True

    # Without a config, or without the classify task, nothing else qualifies.
    if model_config is None or "classify" not in supported_tasks:
        return False

    label_count = getattr(model_config.hf_config, "num_labels", 0)
    if label_count == 1:
        return True

    logger.debug_once("Scoring API is only enabled for num_labels == 1.")
    return False
|
||||
|
||||
@@ -14,8 +14,8 @@ from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.pooling import enable_scoring_api
|
||||
from vllm.entrypoints.pooling.base.serving import PoolingServing
|
||||
from vllm.entrypoints.pooling.utils import enable_scoring_api
|
||||
from vllm.entrypoints.serve.instrumentator.basic import base
|
||||
from vllm.entrypoints.serve.instrumentator.health import health
|
||||
from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
@@ -76,15 +76,15 @@ def get_invocation_types(
|
||||
]
|
||||
|
||||
if enable_scoring_api(supported_tasks, model_config):
|
||||
from vllm.entrypoints.pooling.score.api_router import do_rerank, rerank
|
||||
from vllm.entrypoints.pooling.score.protocol import RerankRequest
|
||||
from vllm.entrypoints.pooling.scoring.api_router import do_rerank, rerank
|
||||
from vllm.entrypoints.pooling.scoring.protocol import RerankRequest
|
||||
|
||||
INVOCATION_TYPES += [
|
||||
(RerankRequest, (rerank, do_rerank)),
|
||||
]
|
||||
|
||||
from vllm.entrypoints.pooling.score.api_router import create_score, score
|
||||
from vllm.entrypoints.pooling.score.protocol import ScoreRequest
|
||||
from vllm.entrypoints.pooling.scoring.api_router import create_score, score
|
||||
from vllm.entrypoints.pooling.scoring.protocol import ScoreRequest
|
||||
|
||||
INVOCATION_TYPES += [
|
||||
(ScoreRequest, (score, create_score)),
|
||||
|
||||
Reference in New Issue
Block a user