[Frontend] Use new Renderer for Completions and Tokenize API (#32863)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-31 20:51:15 +08:00
committed by GitHub
parent 8980001c93
commit f0a1c8453a
64 changed files with 2116 additions and 2003 deletions

View File

@@ -60,9 +60,7 @@ def main():
completion = client.completions.create(
model=model_name,
# NOTE: The OpenAI client does not allow `None` as an input to
# `prompt`. Use an empty string if you have no text prompts.
prompt="",
prompt=None,
max_tokens=5,
temperature=0.0,
# NOTE: The OpenAI client allows passing in extra JSON body via the

View File

@@ -22,7 +22,11 @@ def test_context_length_too_short(vllm_runner, image_assets, model):
with pytest.raises(ValueError, match="longer than the maximum model length"):
vllm_model = vllm_runner(
model,
max_model_len=128, # LLaVA has a feature size of 576
# LLaVA has a feature size of 576
# For the HF processor to execute successfully but still
# failing the overall context length check, we need the
# max_model_len to at least contain all image tokens
max_model_len=579,
enforce_eager=True,
load_format="dummy",
)

View File

@@ -205,7 +205,7 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test):
valid_msg,
]
sampling_params = SamplingParams(temperature=0, max_tokens=10)
with pytest.raises(ValueError, match="longer than the maximum model length"):
with pytest.raises(ValueError, match="context length is only"):
llm.chat(batch_1, sampling_params=sampling_params)
outputs_2 = llm.chat(batch_2, sampling_params=sampling_params)
assert len(outputs_2) == len(batch_2)

View File

@@ -15,7 +15,8 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
@@ -57,6 +58,15 @@ class MockModelConfig:
return self.diff_sampling_param or {}
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
models = OpenAIServingModels(
engine_client=engine,
@@ -71,18 +81,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
chat_template_content_format="auto",
)
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
data_parallel_rank,
):
return dict(engine_prompt), {}
async def _fake_preprocess_chat(*args, **kwargs):
# return conversation, engine_prompts
return (
@@ -90,7 +88,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
[{"prompt_token_ids": [1, 2, 3]}],
)
serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat)
return serving_chat
@@ -99,11 +96,11 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
async def test_chat_error_non_stream():
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)
@@ -153,11 +150,11 @@ async def test_chat_error_non_stream():
async def test_chat_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_chat = _build_serving_chat(mock_engine)

View File

@@ -4,7 +4,7 @@
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import MagicMock
import pytest
@@ -15,7 +15,8 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
@@ -61,37 +62,31 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
)
serving_completion = OpenAIServingCompletion(
return OpenAIServingCompletion(
engine,
models,
request_logger=None,
)
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
data_parallel_rank,
):
return dict(engine_prompt), {}
serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
return serving_completion
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
@pytest.mark.asyncio
async def test_completion_error_non_stream():
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine)
@@ -141,11 +136,11 @@ async def test_completion_error_non_stream():
async def test_completion_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
serving_completion = _build_serving_completion(mock_engine)

View File

@@ -110,7 +110,7 @@ async def test_completions_with_prompt_embeds(
# Test case: Single prompt embeds input
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds},
@@ -121,7 +121,7 @@ async def test_completions_with_prompt_embeds(
# Test case: batch completion with prompt_embeds
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]},
@@ -133,7 +133,7 @@ async def test_completions_with_prompt_embeds(
# Test case: streaming with prompt_embeds
single_completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds},
@@ -142,7 +142,7 @@ async def test_completions_with_prompt_embeds(
stream = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
prompt=None,
max_tokens=5,
temperature=0.0,
stream=True,
@@ -162,7 +162,7 @@ async def test_completions_with_prompt_embeds(
# Test case: batch streaming with prompt_embeds
stream = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
prompt=None,
max_tokens=5,
temperature=0.0,
stream=True,
@@ -197,7 +197,7 @@ async def test_completions_with_prompt_embeds(
)
completion_embeds_only = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="",
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds},
@@ -215,7 +215,7 @@ async def test_completions_errors_with_prompt_embeds(
# Test error case: invalid prompt_embeds
with pytest.raises(BadRequestError):
await client_with_prompt_embeds.completions.create(
prompt="",
prompt=None,
model=model_name,
max_tokens=5,
temperature=0.0,
@@ -237,7 +237,7 @@ async def test_completions_with_logprobs_and_prompt_embeds(
# Test case: Logprobs using prompt_embeds
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
prompt=None,
max_tokens=5,
temperature=0.0,
echo=False,
@@ -257,7 +257,7 @@ async def test_completions_with_logprobs_and_prompt_embeds(
# Test case: Log probs with batch completion and prompt_embeds
completion = await client_with_prompt_embeds.completions.create(
model=model_name,
prompt="", # Add empty prompt as required parameter
prompt=None,
max_tokens=5,
temperature=0.0,
echo=False,
@@ -287,7 +287,7 @@ async def test_prompt_logprobs_raises_error(
with pytest.raises(BadRequestError, match="not compatible"):
await client_with_prompt_embeds.completions.create(
model=MODEL_NAME,
prompt="",
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True},

View File

@@ -7,7 +7,7 @@ Tests verify that embeddings with correct ndim but incorrect hidden_size
are rejected before they can cause crashes during model inference.
Validation is performed by the parser (MultiModalDataParser) and EmbeddingItems
classes, not by CompletionRenderer or MediaIO classes.
classes, not by MediaIO classes.
"""
import pytest

View File

@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.tokenizers import get_tokenizer
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
@@ -35,6 +36,7 @@ class MockModelConfig:
"""Minimal mock ModelConfig for testing."""
model: str = MODEL_NAME
runner_type = "generate"
tokenizer: str = MODEL_NAME
trust_remote_code: bool = False
tokenizer_mode: str = "auto"
@@ -85,15 +87,21 @@ def register_mock_resolver():
del LoRAResolverRegistry.resolvers[MOCK_RESOLVER_NAME]
def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer(
model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
@pytest.fixture
def mock_serving_setup():
"""Provides a mocked engine and serving completion instance."""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.errored = False
tokenizer = get_tokenizer(MODEL_NAME)
mock_engine.get_tokenizer = AsyncMock(return_value=tokenizer)
async def mock_add_lora_side_effect(lora_request: LoRARequest):
"""Simulate engine behavior when adding LoRAs."""
if lora_request.lora_name == "test-lora":
@@ -118,6 +126,7 @@ def mock_serving_setup():
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels(
engine_client=mock_engine,
@@ -128,10 +137,6 @@ def mock_serving_setup():
mock_engine, models, request_logger=None
)
serving_completion._process_inputs = AsyncMock(
return_value=(MagicMock(name="engine_request"), {})
)
return mock_engine, serving_completion

View File

@@ -12,7 +12,7 @@ import regex as re
import torch
from vllm.config import ModelConfig
from vllm.entrypoints.renderer import CompletionRenderer
from vllm.renderers.embed_utils import safe_load_prompt_embeds
from ...utils import RemoteOpenAIServer
@@ -30,7 +30,7 @@ async def test_empty_prompt():
):
await client.completions.create(
model=model_name,
prompt="",
prompt=None,
max_tokens=5,
temperature=0.0,
extra_body={"prompt_embeds": []},
@@ -63,7 +63,6 @@ def test_load_prompt_embeds(
):
model_config = Mock(spec=ModelConfig)
model_config.enable_prompt_embeds = True
renderer = CompletionRenderer(model_config, tokenizer=None)
# construct arbitrary tensors of various dtypes, layouts, and sizes.
# We need to check against different layouts to make sure that if a user
@@ -89,9 +88,7 @@ def test_load_prompt_embeds(
buffer.seek(0)
encoded_tensor = pybase64.b64encode(buffer.getvalue())
loaded_prompt_embeds = renderer.load_prompt_embeds(encoded_tensor)
assert len(loaded_prompt_embeds) == 1
loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
loaded_tensor = safe_load_prompt_embeds(model_config, encoded_tensor)
assert loaded_tensor.device.type == "cpu"
assert loaded_tensor.layout == torch.strided
torch.testing.assert_close(
@@ -105,7 +102,6 @@ def test_load_prompt_embeds(
def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: int):
model_config = Mock(spec=ModelConfig)
model_config.enable_prompt_embeds = False
renderer = CompletionRenderer(model_config, tokenizer=None)
tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
@@ -115,4 +111,4 @@ def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: in
encoded_tensor = pybase64.b64encode(buffer.getvalue())
with pytest.raises(ValueError, match="--enable-prompt-embeds"):
renderer.load_prompt_embeds(encoded_tensor)
safe_load_prompt_embeds(model_config, encoded_tensor)

View File

@@ -556,19 +556,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
request_logger=None,
)
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
data_parallel_rank,
):
return dict(engine_prompt), {}
serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
return serving_chat
@@ -784,7 +771,7 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
resp = await serving_chat.create_chat_completion(req)
assert isinstance(resp, ErrorResponse)
assert "max_tokens" in resp.error.message
assert "context length is only" in resp.error.message
@pytest.mark.asyncio
@@ -824,7 +811,7 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
resp = await serving_chat.create_chat_completion(req)
assert isinstance(resp, ErrorResponse)
assert "maximum context length" in resp.error.message
assert "context length is only" in resp.error.message
@pytest.mark.asyncio
@@ -890,6 +877,20 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
serving_chat = _build_serving_chat(mock_engine)
orig_render_chat_request = serving_chat.render_chat_request
captured_prompts = []
async def render_chat_request(request):
result = await orig_render_chat_request(request)
assert isinstance(result, tuple)
conversation, engine_prompts = result
captured_prompts.extend(engine_prompts)
return result
serving_chat.render_chat_request = render_chat_request
# Test cache_salt
req = ChatCompletionRequest(
model=MODEL_NAME,
@@ -899,15 +900,19 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
# By default, cache_salt in the engine prompt is not set
with suppress(Exception):
await serving_chat.create_chat_completion(req)
engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
assert "cache_salt" not in engine_prompt
assert len(captured_prompts) == 1
assert "cache_salt" not in captured_prompts[0]
captured_prompts.clear()
# Test with certain cache_salt
req.cache_salt = "test_salt"
with suppress(Exception):
await serving_chat.create_chat_completion(req)
engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
assert engine_prompt.get("cache_salt") == "test_salt"
assert len(captured_prompts) == 1
assert captured_prompts[0]["cache_salt"] == "test_salt"
@pytest.mark.asyncio
@@ -1007,11 +1012,11 @@ class TestServingChatWithHarmony:
@pytest.fixture()
def mock_engine(self) -> AsyncLLM:
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
return mock_engine
@pytest.fixture()
@@ -1618,11 +1623,11 @@ async def test_tool_choice_validation_without_parser():
"""Test that tool_choice='required' or named tool without tool_parser
returns an appropriate error message."""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
mock_engine.renderer = _build_renderer(mock_engine.model_config)
models = OpenAIServingModels(
engine_client=mock_engine,

View File

@@ -67,20 +67,6 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI):
assert response["usage"]["prompt_tokens"] == truncation_size
@pytest.mark.asyncio
async def test_zero_truncation_size(client: openai.AsyncOpenAI):
truncation_size = 0
kwargs: dict[str, Any] = {
"model": MODEL_NAME,
"input": input,
"truncate_prompt_tokens": truncation_size,
}
response = await client.post(path="embeddings", cast_to=object, body={**kwargs})
assert response["usage"]["prompt_tokens"] == truncation_size
@pytest.mark.asyncio
async def test_bigger_truncation_size(client: openai.AsyncOpenAI):
truncation_size = max_model_len + 1

View File

@@ -128,12 +128,10 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
server.url_for("classify"),
json={"model": model_name, "input": []},
)
classification_response.raise_for_status()
output = ClassificationResponse.model_validate(classification_response.json())
assert output.object == "list"
assert isinstance(output.data, list)
assert len(output.data) == 0
error = classification_response.json()
assert classification_response.status_code == 400
assert "error" in error
@pytest.mark.parametrize("model_name", [MODEL_NAME])

View File

@@ -247,7 +247,7 @@ class TestModel:
},
)
assert score_response.status_code == 400
assert "Please, select a smaller truncation size." in score_response.text
assert "Please request a smaller truncation size." in score_response.text
def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, Any]):
queries = "What is the capital of France?"

View File

@@ -1,325 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock
import pybase64
import pytest
import torch
from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig
from vllm.inputs.data import is_embeds_prompt
@dataclass
class MockModelConfig:
max_model_len: int = 100
encoder_config: dict | None = None
enable_prompt_embeds: bool = True
class MockTokenizerResult:
def __init__(self, input_ids):
self.input_ids = input_ids
@pytest.fixture
def mock_model_config():
return MockModelConfig()
@pytest.fixture
def mock_tokenizer():
tokenizer = MagicMock()
return tokenizer
@pytest.fixture
def mock_async_tokenizer():
async_tokenizer = AsyncMock()
return async_tokenizer
@pytest.fixture
def renderer(mock_model_config, mock_tokenizer):
return CompletionRenderer(
model_config=mock_model_config,
tokenizer=mock_tokenizer,
async_tokenizer_pool={},
)
class TestRenderPrompt:
"""Test Category A: Basic Functionality Tests"""
@pytest.mark.asyncio
async def test_token_input(self, renderer):
tokens = [101, 7592, 2088]
results = await renderer.render_prompt(
prompt_or_prompts=tokens, config=RenderConfig(max_length=100)
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens
@pytest.mark.asyncio
async def test_token_list_input(self, renderer):
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
results = await renderer.render_prompt(
prompt_or_prompts=token_lists, config=RenderConfig(max_length=100)
)
assert len(results) == 3
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
assert results[2]["prompt_token_ids"] == [103, 4567]
@pytest.mark.asyncio
async def test_text_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
mock_async_tokenizer.assert_called_once()
@pytest.mark.asyncio
async def test_text_list_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
text_list_input = ["Hello world", "How are you?", "Good morning"]
results = await renderer.render_prompt(
prompt_or_prompts=text_list_input, config=RenderConfig(max_length=100)
)
assert len(results) == 3
for result in results:
assert result["prompt_token_ids"] == [101, 7592, 2088]
assert mock_async_tokenizer.call_count == 3
@pytest.mark.asyncio
async def test_no_truncation(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
)
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
assert (
"truncation" not in call_args.kwargs
or call_args.kwargs["truncation"] is False
)
@pytest.mark.asyncio
async def test_truncation_positive(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 7592, 2088]
) # Truncated
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=100, truncate_prompt_tokens=50),
)
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
assert call_args.kwargs["truncation"] is True
assert call_args.kwargs["max_length"] == 50
@pytest.mark.asyncio
async def test_truncation_negative(self, renderer, mock_async_tokenizer):
# Test that negative truncation uses model's max_model_len
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 7592, 2088]
) # Truncated to max_model_len
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=200, truncate_prompt_tokens=-1),
)
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
assert call_args.kwargs["truncation"] is True
assert call_args.kwargs["max_length"] == 100 # model's max_model_len
@pytest.mark.asyncio
async def test_token_truncation_last_elements(self, renderer):
# Test that token truncation keeps the last N elements
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
results = await renderer.render_prompt(
prompt_or_prompts=long_tokens,
config=RenderConfig(max_length=100, truncate_prompt_tokens=5),
)
assert len(results) == 1
# Should keep the last 5 tokens: [105, 106, 107, 108, 109]
assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]
@pytest.mark.asyncio
async def test_max_length_exceeded(self, renderer):
long_tokens = list(range(150)) # Exceeds max_model_len=100
with pytest.raises(ValueError, match="maximum context length"):
await renderer.render_prompt(
prompt_or_prompts=long_tokens, config=RenderConfig(max_length=100)
)
@pytest.mark.asyncio
async def test_no_tokenizer_for_text(self, mock_model_config):
renderer_no_tokenizer = CompletionRenderer(
model_config=mock_model_config, tokenizer=None, async_tokenizer_pool={}
)
with pytest.raises(ValueError, match="No tokenizer available"):
await renderer_no_tokenizer.render_prompt(
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
)
@pytest.mark.asyncio
async def test_token_input_with_needs_detokenization(
self, renderer, mock_async_tokenizer
):
# When needs_detokenization=True for token inputs, renderer should
# use the async tokenizer to decode and include the original text
# in the returned prompt object.
mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
tokens = [1, 2, 3, 4]
results = await renderer.render_prompt(
prompt_or_prompts=tokens,
config=RenderConfig(needs_detokenization=True),
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens
assert results[0]["prompt"] == "decoded text"
mock_async_tokenizer.decode.assert_awaited_once()
class TestRenderEmbedPrompt:
def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
"""Helper to create base64-encoded tensor bytes"""
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
return pybase64.b64encode(buffer.read())
@pytest.mark.asyncio
async def test_single_prompt_embed(self, renderer):
# Create a test tensor
test_tensor = torch.randn(10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes,
config=RenderConfig(cache_salt="test_salt"),
)
assert len(results) == 1
assert is_embeds_prompt(results[0])
assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
assert results[0]["cache_salt"] == "test_salt"
@pytest.mark.asyncio
async def test_multiple_prompt_embeds(self, renderer):
# Create multiple test tensors
test_tensors = [
torch.randn(8, 512, dtype=torch.float32),
torch.randn(12, 512, dtype=torch.float32),
]
embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors]
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes_list,
config=RenderConfig(),
)
assert len(results) == 2
for i, result in enumerate(results):
assert is_embeds_prompt(result)
assert torch.allclose(result["prompt_embeds"], test_tensors[i])
@pytest.mark.asyncio
async def test_prompt_embed_truncation(self, renderer):
# Create tensor with more tokens than truncation limit
test_tensor = torch.randn(20, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes,
config=RenderConfig(truncate_prompt_tokens=10),
)
assert len(results) == 1
# Should keep last 10 tokens
expected = test_tensor[-10:]
assert torch.allclose(results[0]["prompt_embeds"], expected)
@pytest.mark.asyncio
async def test_prompt_embed_different_dtypes(self, renderer):
# Test different supported dtypes
dtypes = [torch.float32, torch.float16, torch.bfloat16]
for dtype in dtypes:
test_tensor = torch.randn(5, 256, dtype=dtype)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes,
config=RenderConfig(),
)
assert len(results) == 1
assert results[0]["prompt_embeds"].dtype == dtype
@pytest.mark.asyncio
async def test_prompt_embed_squeeze_batch_dim(self, renderer):
# Test tensor with batch dimension gets squeezed
test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes,
config=RenderConfig(),
)
assert len(results) == 1
# Should be squeezed to 2D
assert results[0]["prompt_embeds"].shape == (10, 768)
@pytest.mark.asyncio
async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer):
# Set up text tokenization
mock_async_tokenizer.return_value = MockTokenizerResult([101, 102, 103])
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
# Create embed
test_tensor = torch.randn(5, 256, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_or_prompts="Hello world",
prompt_embeds=embed_bytes,
config=RenderConfig(),
)
assert len(results) == 2
# First should be embed prompt
assert is_embeds_prompt(results[0])
# Second should be tokens prompt
assert "prompt_token_ids" in results[1]
assert results[1]["prompt_token_ids"] == [101, 102, 103]

View File

@@ -96,7 +96,7 @@ def test_gemma_multimodal(
dtype="bfloat16",
) as vllm_model:
llm = vllm_model.get_llm()
prompts = llm.preprocess_chat(messages)
prompts = llm._preprocess_chat([messages])
result = llm.classify(prompts)
assert result[0].outputs.probs[0] > 0.95

View File

@@ -29,7 +29,8 @@ def test_smaller_truncation_size(
model_name, runner="pooling", max_model_len=max_model_len
) as vllm_model:
vllm_output = vllm_model.llm.embed(
input_str, truncate_prompt_tokens=truncate_prompt_tokens
input_str,
tokenization_kwargs=dict(truncate_prompt_tokens=truncate_prompt_tokens),
)
prompt_tokens = vllm_output[0].prompt_token_ids
@@ -44,7 +45,8 @@ def test_max_truncation_size(vllm_runner, model_name=MODEL_NAME, input_str=input
model_name, runner="pooling", max_model_len=max_model_len
) as vllm_model:
vllm_output = vllm_model.llm.embed(
input_str, truncate_prompt_tokens=truncate_prompt_tokens
input_str,
tokenization_kwargs=dict(truncate_prompt_tokens=truncate_prompt_tokens),
)
prompt_tokens = vllm_output[0].prompt_token_ids
@@ -64,7 +66,8 @@ def test_bigger_truncation_size(
) as vllm_model,
):
llm_output = vllm_model.llm.embed(
input_str, truncate_prompt_tokens=truncate_prompt_tokens
input_str,
tokenization_kwargs=dict(truncate_prompt_tokens=truncate_prompt_tokens),
)
assert (

View File

@@ -187,7 +187,10 @@ def mteb_test_embed_models(
head_dtype = model_config.head_dtype
# Test embedding_size, isnan and whether to use normalize
vllm_outputs = vllm_model.embed(example_prompts, truncate_prompt_tokens=-1)
vllm_outputs = vllm_model.embed(
example_prompts,
tokenization_kwargs=dict(truncate_prompt_tokens=-1),
)
outputs_tensor = torch.tensor(vllm_outputs)
assert not torch.any(torch.isnan(outputs_tensor))
embedding_size = model_config.embedding_size

View File

@@ -79,9 +79,9 @@ class VllmMtebCrossEncoder(MtebCrossEncoderMixin):
outputs = self.llm.score(
queries,
corpus,
truncate_prompt_tokens=-1,
use_tqdm=False,
chat_template=self.chat_template,
tokenization_kwargs={"truncate_prompt_tokens": -1},
)
scores = np.array(outputs)
scores = scores[np.argsort(r)]

View File

@@ -0,0 +1,426 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io
from dataclasses import dataclass
from typing import Any
from unittest.mock import AsyncMock
import pybase64
import pytest
import torch
from vllm.inputs.data import is_embeds_prompt
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config
MODEL_NAME = "openai-community/gpt2"
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
runner_type = "generate"
model: str = MODEL_NAME
tokenizer: str = MODEL_NAME
trust_remote_code: bool = False
max_model_len: int = 100
tokenizer_revision = None
tokenizer_mode = "auto"
hf_config = MockHFConfig()
encoder_config: dict[str, Any] | None = None
enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False
@pytest.fixture
def mock_model_config():
return MockModelConfig()
@pytest.fixture
def mock_async_tokenizer():
return AsyncMock()
@pytest.fixture
def renderer(mock_model_config):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(mock_model_config)
return HfRenderer(
mock_model_config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
class TestValidatePrompt:
STRING_INPUTS = [
"",
"foo",
"foo bar",
"foo baz bar",
"foo bar qux baz",
]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
]
INPUTS_SLICES = [
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
# Test that a nested mixed-type list of lists raises a TypeError.
def test_empty_input(self, renderer):
with pytest.raises(ValueError, match="at least one prompt"):
renderer.render_completions([])
def test_invalid_type(self, renderer):
with pytest.raises(TypeError, match="string or an array of tokens"):
renderer.render_completions([[1, 2], ["foo", "bar"]])
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_string_consistent(self, renderer, string_input: str):
assert renderer.render_completions(string_input) == renderer.render_completions(
[string_input]
)
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_token_consistent(self, renderer, token_input: list[int]):
assert renderer.render_completions(token_input) == renderer.render_completions(
[token_input]
)
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_string_slice(self, renderer, inputs_slice: slice):
assert renderer.render_completions(self.STRING_INPUTS)[
inputs_slice
] == renderer.render_completions(self.STRING_INPUTS[inputs_slice])
class TestRenderPrompt:
@pytest.mark.asyncio
async def test_token_input(self, renderer):
tokens = [101, 7592, 2088]
prompts = await renderer.render_completions_async(tokens)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens
@pytest.mark.asyncio
async def test_token_list_input(self, renderer):
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
prompts = await renderer.render_completions_async(token_lists)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 3
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
assert results[2]["prompt_token_ids"] == [103, 4567]
@pytest.mark.asyncio
async def test_text_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer._async_tokenizer = mock_async_tokenizer
prompts = await renderer.render_completions_async("Hello world")
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
mock_async_tokenizer.encode.assert_called_once()
@pytest.mark.asyncio
async def test_text_list_input(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer._async_tokenizer = mock_async_tokenizer
text_list_input = ["Hello world", "How are you?", "Good morning"]
prompts = await renderer.render_completions_async(text_list_input)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 3
for result in results:
assert result["prompt_token_ids"] == [101, 7592, 2088]
assert mock_async_tokenizer.encode.call_count == 3
@pytest.mark.asyncio
async def test_no_truncation(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
renderer._async_tokenizer = mock_async_tokenizer
prompts = await renderer.render_completions_async("Hello world")
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)
assert len(results) == 1
call_args = mock_async_tokenizer.encode.call_args
assert (
"truncation" not in call_args.kwargs
or call_args.kwargs["truncation"] is False
)
@pytest.mark.asyncio
async def test_truncation_positive(self, renderer, mock_async_tokenizer):
mock_async_tokenizer.encode.return_value = [101, 7592, 2088] # Truncated
renderer._async_tokenizer = mock_async_tokenizer
prompts = await renderer.render_completions_async("Hello world")
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(
max_total_tokens=200,
truncate_prompt_tokens=50,
),
)
assert len(results) == 1
call_args = mock_async_tokenizer.encode.call_args
assert call_args.kwargs["truncation"] is True
assert call_args.kwargs["max_length"] == 50
@pytest.mark.asyncio
async def test_truncation_negative(self, renderer, mock_async_tokenizer):
# Test that negative truncation uses model's max_model_len
mock_async_tokenizer.encode.return_value = [
101,
7592,
2088,
] # Truncated to max_model_len
renderer._async_tokenizer = mock_async_tokenizer
prompts = await renderer.render_completions_async("Hello world")
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(
max_total_tokens=200,
truncate_prompt_tokens=-1,
),
)
assert len(results) == 1
call_args = mock_async_tokenizer.encode.call_args
assert call_args.kwargs["truncation"] is True
assert call_args.kwargs["max_length"] == 200
@pytest.mark.asyncio
async def test_token_truncation_last_elements(self, renderer):
# Test that token truncation keeps the last N elements
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = await renderer.render_completions_async(long_tokens)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(
max_total_tokens=100,
truncate_prompt_tokens=5,
),
)
assert len(results) == 1
# Should keep the last 5 tokens: [105, 106, 107, 108, 109]
assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]
@pytest.mark.asyncio
async def test_max_length_exceeded(self, renderer):
long_tokens = list(range(150)) # Exceeds max_model_len=100
prompts = await renderer.render_completions_async(long_tokens)
with pytest.raises(ValueError, match="context length is only"):
await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)
@pytest.mark.asyncio
async def test_no_tokenizer_for_text(self, renderer):
renderer_no_tokenizer = HfRenderer.from_config(
MockModelConfig(skip_tokenizer_init=True),
tokenizer_kwargs={},
)
prompts = await renderer_no_tokenizer.render_completions_async("Hello world")
with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
await renderer_no_tokenizer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=100),
)
@pytest.mark.asyncio
async def test_token_input_with_needs_detokenization(
self, renderer, mock_async_tokenizer
):
# When needs_detokenization=True for token inputs, renderer should
# use the async tokenizer to decode and include the original text
# in the returned prompt object.
mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
renderer._async_tokenizer = mock_async_tokenizer
tokens = [1, 2, 3, 4]
prompts = await renderer.render_completions_async(tokens)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(
max_total_tokens=renderer.config.max_model_len,
needs_detokenization=True,
),
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens
assert results[0]["prompt"] == "decoded text"
mock_async_tokenizer.decode.assert_awaited_once()
class TestRenderEmbedPrompt:
def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
"""Helper to create base64-encoded tensor bytes"""
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
return pybase64.b64encode(buffer.read())
@pytest.mark.asyncio
async def test_single_prompt_embed(self, renderer):
# Create a test tensor
test_tensor = torch.randn(10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
)
assert len(results) == 1
assert is_embeds_prompt(results[0])
assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
@pytest.mark.asyncio
async def test_multiple_prompt_embeds(self, renderer):
# Create multiple test tensors
test_tensors = [
torch.randn(8, 512, dtype=torch.float32),
torch.randn(12, 512, dtype=torch.float32),
]
embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors]
prompts = await renderer.render_completions_async(
prompt_embeds=embed_bytes_list
)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
)
assert len(results) == 2
for i, result in enumerate(results):
assert is_embeds_prompt(result)
assert torch.allclose(result["prompt_embeds"], test_tensors[i])
@pytest.mark.asyncio
async def test_prompt_embed_truncation(self, renderer):
# Create tensor with more tokens than truncation limit
test_tensor = torch.randn(20, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(
max_total_tokens=renderer.config.max_model_len,
truncate_prompt_tokens=10,
),
)
assert len(results) == 1
# Should keep last 10 tokens
expected = test_tensor[-10:]
assert torch.allclose(results[0]["prompt_embeds"], expected)
@pytest.mark.asyncio
async def test_prompt_embed_different_dtypes(self, renderer):
# Test different supported dtypes
dtypes = [torch.float32, torch.float16, torch.bfloat16]
for dtype in dtypes:
test_tensor = torch.randn(5, 256, dtype=dtype)
embed_bytes = self._create_test_embed_bytes(test_tensor)
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
)
assert len(results) == 1
assert results[0]["prompt_embeds"].dtype == dtype
@pytest.mark.asyncio
async def test_prompt_embed_squeeze_batch_dim(self, renderer):
# Test tensor with batch dimension gets squeezed
test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
)
assert len(results) == 1
# Should be squeezed to 2D
assert results[0]["prompt_embeds"].shape == (10, 768)
@pytest.mark.asyncio
async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer):
# Set up text tokenization
mock_async_tokenizer.encode.return_value = [101, 102, 103]
renderer._async_tokenizer = mock_async_tokenizer
# Create embed
test_tensor = torch.randn(5, 256, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
prompts = await renderer.render_completions_async(
"Hello world",
prompt_embeds=embed_bytes,
)
results = await renderer.tokenize_prompts_async(
prompts,
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
)
assert len(results) == 2
# First should be embed prompt
assert is_embeds_prompt(results[0])
# Second should be tokens prompt
assert "prompt_token_ids" in results[1]
assert results[1]["prompt_token_ids"] == [101, 102, 103]

View File

@@ -9,6 +9,7 @@ import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.config import ModelConfig
from vllm.renderers import ChatParams
from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template
from vllm.tokenizers.mistral import MistralTokenizer
@@ -27,7 +28,7 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop():
mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={})
mock_renderer._tokenizer = mock_tokenizer
task = mock_renderer.render_messages_async([])
task = mock_renderer.render_messages_async([], ChatParams())
# Ensure the event loop is not blocked
blocked_count = 0

View File

@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Sparse tensor validation in embedding APIs.
Tests verify that malicious sparse tensors are rejected before they can trigger
out-of-bounds memory writes during to_dense() operations.
"""
@@ -13,8 +11,24 @@ import io
import pytest
import torch
from vllm.entrypoints.renderer import CompletionRenderer
from vllm.multimodal.media import AudioEmbeddingMediaIO, ImageEmbeddingMediaIO
from vllm.renderers.embed_utils import safe_load_prompt_embeds
@pytest.fixture
def model_config():
"""Mock ModelConfig for testing."""
from vllm.config import ModelConfig
return ModelConfig(
model="facebook/opt-125m",
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float32",
seed=0,
enable_prompt_embeds=True, # Required for prompt embeds tests
)
def _encode_tensor(tensor: torch.Tensor) -> bytes:
@@ -63,15 +77,12 @@ class TestPromptEmbedsValidation:
def test_valid_dense_tensor_accepted(self, model_config):
"""Baseline: Valid dense tensors should work normally."""
renderer = CompletionRenderer(model_config)
valid_tensor = _create_valid_dense_tensor()
encoded = _encode_tensor(valid_tensor)
# Should not raise any exception
result = renderer.load_prompt_embeds(encoded)
assert len(result) == 1
assert result[0]["prompt_embeds"].shape == valid_tensor.shape
result = safe_load_prompt_embeds(model_config, encoded)
assert result.shape == valid_tensor.shape
def test_valid_sparse_tensor_accepted(self):
"""Baseline: Valid sparse tensors should load successfully."""
@@ -86,14 +97,12 @@ class TestPromptEmbedsValidation:
def test_malicious_sparse_tensor_rejected(self, model_config):
"""Security: Malicious sparse tensors should be rejected."""
renderer = CompletionRenderer(model_config)
malicious_tensor = _create_malicious_sparse_tensor()
encoded = _encode_tensor(malicious_tensor)
# Should raise RuntimeError due to invalid sparse tensor
with pytest.raises((RuntimeError, ValueError)) as exc_info:
renderer.load_prompt_embeds(encoded)
safe_load_prompt_embeds(model_config, encoded)
# Error should indicate sparse tensor validation failure
error_msg = str(exc_info.value).lower()
@@ -101,8 +110,6 @@ class TestPromptEmbedsValidation:
def test_extremely_large_indices_rejected(self, model_config):
"""Security: Sparse tensors with extremely large indices should be rejected."""
renderer = CompletionRenderer(model_config)
# Create tensor with indices far beyond reasonable bounds
indices = torch.tensor([[999999], [999999]])
values = torch.tensor([1.0])
@@ -114,12 +121,10 @@ class TestPromptEmbedsValidation:
encoded = _encode_tensor(malicious_tensor)
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(encoded)
safe_load_prompt_embeds(model_config, encoded)
def test_negative_indices_rejected(self, model_config):
"""Security: Sparse tensors with negative indices should be rejected."""
renderer = CompletionRenderer(model_config)
# Create tensor with negative indices
indices = torch.tensor([[-1], [-1]])
values = torch.tensor([1.0])
@@ -131,7 +136,7 @@ class TestPromptEmbedsValidation:
encoded = _encode_tensor(malicious_tensor)
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(encoded)
safe_load_prompt_embeds(model_config, encoded)
class TestImageEmbedsValidation:
@@ -253,14 +258,12 @@ class TestSparseTensorValidationIntegration:
3. Sends to /v1/completions with prompt_embeds parameter
4. Server should reject before memory corruption occurs
"""
renderer = CompletionRenderer(model_config)
# Step 1-2: Attacker creates malicious payload
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
# Step 3-4: Server processes and should reject
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(attack_payload)
safe_load_prompt_embeds(model_config, attack_payload)
def test_attack_scenario_chat_api_image(self):
"""
@@ -285,57 +288,3 @@ class TestSparseTensorValidationIntegration:
with pytest.raises((RuntimeError, ValueError)):
io_handler.load_base64("", attack_payload.decode("utf-8"))
def test_multiple_valid_embeddings_in_batch(self, model_config):
"""
Regression test: Multiple valid embeddings should still work.
Ensures the fix doesn't break legitimate batch processing.
"""
renderer = CompletionRenderer(model_config)
valid_tensors = [
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_valid_dense_tensor()),
]
# Should process all without error
result = renderer.load_prompt_embeds(valid_tensors)
assert len(result) == 3
def test_mixed_valid_and_malicious_rejected(self, model_config):
"""
Security: Batch with one malicious tensor should be rejected.
Even if most tensors are valid, a single malicious one should
cause rejection of the entire batch.
"""
renderer = CompletionRenderer(model_config)
mixed_batch = [
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_malicious_sparse_tensor()), # Malicious
_encode_tensor(_create_valid_dense_tensor()),
]
# Should fail on the malicious tensor
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(mixed_batch)
# Pytest fixtures
@pytest.fixture
def model_config():
"""Mock ModelConfig for testing."""
from vllm.config import ModelConfig
return ModelConfig(
model="facebook/opt-125m",
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float32",
seed=0,
enable_prompt_embeds=True, # Required for prompt embeds tests
)

View File

@@ -5,65 +5,10 @@ import pytest
from vllm.config import ModelConfig
from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.parse import parse_raw_prompts
from vllm.inputs.preprocess import InputPreprocessor
pytestmark = pytest.mark.cpu_test
STRING_INPUTS = [
"",
"foo",
"foo bar",
"foo baz bar",
"foo bar qux baz",
]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
]
INPUTS_SLICES = [
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
# Test that a nested mixed-type list of lists raises a TypeError.
@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]])
def test_invalid_input_raise_type_error(invalid_input):
with pytest.raises(TypeError):
parse_raw_prompts(invalid_input)
def test_parse_raw_single_batch_empty():
with pytest.raises(ValueError, match="at least one prompt"):
parse_raw_prompts([])
with pytest.raises(ValueError, match="at least one prompt"):
parse_raw_prompts([[]])
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_parse_raw_single_batch_string_consistent(string_input: str):
assert parse_raw_prompts(string_input) == parse_raw_prompts([string_input])
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_parse_raw_single_batch_token_consistent(token_input: list[int]):
assert parse_raw_prompts(token_input) == parse_raw_prompts([token_input])
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] == parse_raw_prompts(
STRING_INPUTS[inputs_slice]
)
@pytest.mark.parametrize(
"mm_processor_kwargs,expected_mm_kwargs",

View File

@@ -768,7 +768,7 @@ class ModelConfig:
)
self.tokenizer = object_storage_tokenizer.dir
def _get_encoder_config(self):
def _get_encoder_config(self) -> dict[str, Any] | None:
model = self.model
if is_remote_gguf(model):
model, _ = split_remote_gguf(model)
@@ -1918,7 +1918,7 @@ def _get_and_verify_max_len(
disable_sliding_window: bool,
sliding_window: int | None,
spec_target_max_model_len: int | None = None,
encoder_config: Any | None = None,
encoder_config: dict[str, Any] | None = None,
) -> int:
"""Get and verify the model's maximum length."""
(derived_max_model_len, max_len_key) = (

View File

@@ -72,14 +72,9 @@ class EngineClient(ABC):
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
truncate_prompt_tokens: int | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.
NOTE: truncate_prompt_tokens is deprecated in v0.14.
TODO: Remove this argument in v0.15.
"""
"""Generate outputs for a request from a pooling model."""
...
@abstractmethod

View File

@@ -2,8 +2,9 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import warnings
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, TypeAlias, cast
import cloudpickle
import torch.nn as nn
@@ -46,15 +47,17 @@ from vllm.entrypoints.pooling.score.utils import (
compress_token_type_ids,
get_score_prompt,
)
from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args
from vllm.entrypoints.utils import log_non_default_args
from vllm.inputs import (
DataPrompt,
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
PromptType,
SingletonPrompt,
TextPrompt,
TokensPrompt,
)
from vllm.inputs.parse import get_prompt_components
from vllm.inputs.parse import get_prompt_components, is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -67,6 +70,7 @@ from vllm.outputs import (
)
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask
from vllm.tokenizers import TokenizerLike
@@ -74,7 +78,6 @@ from vllm.tokenizers.mistral import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils.collection_utils import as_iter, is_list_of
from vllm.utils.counter import Counter
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor
@@ -85,6 +88,9 @@ logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
EnginePrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
EngineEncDecPrompt: TypeAlias = ExplicitEncoderDecoderPrompt[EnginePrompt, EnginePrompt]
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
@@ -372,6 +378,7 @@ class LLM:
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[RequestOutput]:
"""Generates the completions for the input prompts.
@@ -398,15 +405,11 @@ class LLM:
If provided, must be a list of integers matching the length
of `prompts`, where each priority value corresponds to the prompt
at the same index.
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns:
A list of `RequestOutput` objects containing the
generated completions in the same order as the input prompts.
Note:
Using `prompts` and `prompt_token_ids` as keyword parameters is
considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter.
"""
model_config = self.model_config
runner_type = model_config.runner_type
@@ -418,17 +421,14 @@ class LLM:
)
if sampling_params is None:
# Use default sampling params.
sampling_params = self.get_default_sampling_params()
# Add any modality specific loras to the corresponding prompts
lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request)
self._validate_and_add_requests(
prompts=prompts,
params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
lora_request=self._get_modality_specific_lora_reqs(prompts, lora_request),
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
@@ -771,65 +771,169 @@ class LLM:
return outputs
def preprocess_chat(
def _get_cmpl_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens=not model_config.is_encoder_decoder,
).with_kwargs(tokenization_kwargs)
def _normalize_prompts(
self,
messages: list[ChatCompletionMessageParam]
prompts: PromptType | Sequence[PromptType],
) -> list[EnginePrompt | EngineEncDecPrompt]:
if isinstance(prompts, str):
prompts = TextPrompt(prompt=prompts)
return prompts if isinstance(prompts, Sequence) else [prompts] # type: ignore[return-value]
def _preprocess_cmpl_singleton(
self,
prompt: SingletonPrompt,
tok_params: TokenizeParams,
*,
tokenize: bool,
) -> EnginePrompt:
renderer = self.llm_engine.renderer
if not isinstance(prompt, dict):
prompt = renderer.render_completion(prompt)
return renderer.tokenize_prompt(prompt, tok_params) if tokenize else prompt
def _preprocess_cmpl_enc_dec(
self,
prompt: ExplicitEncoderDecoderPrompt,
tok_params: TokenizeParams,
) -> EngineEncDecPrompt:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
return EngineEncDecPrompt(
encoder_prompt=self._preprocess_cmpl_singleton(
enc_prompt,
tok_params,
# TODO: Move multi-modal processor into tokenization
tokenize=not self.model_config.is_multimodal_model,
),
decoder_prompt=(
None
if dec_prompt is None
else self._preprocess_cmpl_singleton(
dec_prompt,
tok_params,
# TODO: Move multi-modal processor into tokenization
tokenize=not self.model_config.is_multimodal_model,
)
),
)
def _preprocess_completion(
self,
prompts: PromptType | Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[EnginePrompt | EngineEncDecPrompt]:
"""
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
a format that can be passed to `_add_request`.
Refer to [LLM.generate][] for a complete description of the arguments.
Returns:
A list of `TokensPrompts` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
"""
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
engine_prompts = list[EnginePrompt | EngineEncDecPrompt]()
for prompt in self._normalize_prompts(prompts):
if is_explicit_encoder_decoder_prompt(prompt):
engine_prompts.append(self._preprocess_cmpl_enc_dec(prompt, tok_params))
else:
# Some MM models have non-default `add_special_tokens`
# TODO: Move multi-modal processor into tokenization
engine_prompts.append(
self._preprocess_cmpl_singleton(
prompt,
tok_params,
tokenize=not self.model_config.is_multimodal_model,
)
)
return engine_prompts
def _normalize_conversations(
self,
conversations: list[ChatCompletionMessageParam]
| list[list[ChatCompletionMessageParam]],
) -> list[list[ChatCompletionMessageParam]]:
return conversations if is_list_of(conversations, list) else [conversations] # type: ignore[list-item,return-value]
def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=False,
).with_kwargs(tokenization_kwargs)
def _preprocess_chat(
self,
conversations: list[ChatCompletionMessageParam]
| list[list[ChatCompletionMessageParam]],
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
chat_template_kwargs: dict[str, Any] | None = None,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[TextPrompt | TokensPrompt]:
) -> list[EnginePrompt]:
"""
Generate prompt for a chat conversation. The pre-processed
prompt can then be used as input for the other LLM methods.
Convert a list of conversations into prompts so that they can then
be used as input for other LLM APIs.
Refer to [LLM.chat][] for a complete description of the arguments.
Refer to `chat` for a complete description of the arguments.
Returns:
A list of `TokensPrompts` objects containing the tokenized
prompt after chat template interpolation, and the
pre-processed multi-modal inputs.
A list of `TokensPrompts` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
"""
list_of_messages: list[list[ChatCompletionMessageParam]]
# Handle multi and single conversations
if is_list_of(messages, list):
# messages is list[list[...]]
list_of_messages = cast(list[list[ChatCompletionMessageParam]], messages)
else:
# messages is list[...]
list_of_messages = [cast(list[ChatCompletionMessageParam], messages)]
renderer = self.llm_engine.renderer
chat_template_kwargs = {
"chat_template": chat_template,
"add_generation_prompt": add_generation_prompt,
"continue_final_message": continue_final_message,
"tools": tools,
**(chat_template_kwargs or {}),
}
chat_params = ChatParams(
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=merge_kwargs(
chat_template_kwargs,
dict(
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
),
),
)
tok_params = self._get_chat_tok_params(tokenization_kwargs)
prompts = list[TextPrompt | TokensPrompt]()
for msgs in list_of_messages:
# NOTE: renderer.render_messages() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it.
_, prompt = renderer.render_messages(
msgs,
chat_template_content_format=chat_template_content_format,
**chat_template_kwargs,
)
engine_prompts = list[EnginePrompt]()
for conversation in self._normalize_conversations(conversations):
_, in_prompt = renderer.render_messages(conversation, chat_params)
if mm_processor_kwargs is not None:
prompt["mm_processor_kwargs"] = mm_processor_kwargs
in_prompt["mm_processor_kwargs"] = mm_processor_kwargs
prompts.append(prompt)
engine_prompts.append(renderer.tokenize_prompt(in_prompt, tok_params))
return prompts
return engine_prompts
def chat(
self,
@@ -844,6 +948,7 @@ class LLM:
continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[RequestOutput]:
"""
@@ -889,22 +994,22 @@ class LLM:
`True` if `add_generation_prompt` is also `True`.
chat_template_kwargs: Additional kwargs to pass to the chat
template.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
tokenization_kwargs: Overrides for `tokenizer.encode`.
mm_processor_kwargs: Overrides for `processor.__call__`.
Returns:
A list of `RequestOutput` objects containing the generated
responses in the same order as the input messages.
"""
prompts = self.preprocess_chat(
messages=messages,
prompts = self._preprocess_chat(
messages,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
chat_template_kwargs=chat_template_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
@@ -913,6 +1018,7 @@ class LLM:
sampling_params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
def encode(
@@ -945,37 +1051,29 @@ class LLM:
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
pooling_task: Override the pooling task to use.
tokenization_kwargs: overrides tokenization_kwargs set in
pooling_params
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns:
A list of `PoolingRequestOutput` objects containing the
pooled hidden states in the same order as the input prompts.
Note:
Using `prompts` and `prompt_token_ids` as keyword parameters is
considered legacy and may be deprecated in the future. You should
instead pass them via the `inputs` parameter.
"""
error_str = (
"pooling_task required for `LLM.encode`\n"
"Please use one of the more specific methods or set the "
"pooling_task when using `LLM.encode`:\n"
" - For embeddings, use `LLM.embed(...)` "
'or `pooling_task="embed"`.\n'
" - For classification logits, use `LLM.classify(...)` "
'or `pooling_task="classify"`.\n'
" - For similarity scores, use `LLM.score(...)`.\n"
" - For rewards, use `LLM.reward(...)` "
'or `pooling_task="token_classify"`\n'
" - For token classification, "
'use `pooling_task="token_classify"`\n'
' - For multi-vector retrieval, use `pooling_task="token_embed"`'
)
if pooling_task is None:
raise ValueError(error_str)
raise ValueError(
"pooling_task required for `LLM.encode`\n"
"Please use one of the more specific methods or set the "
"pooling_task when using `LLM.encode`:\n"
" - For embeddings, use `LLM.embed(...)` "
'or `pooling_task="embed"`.\n'
" - For classification logits, use `LLM.classify(...)` "
'or `pooling_task="classify"`.\n'
" - For similarity scores, use `LLM.score(...)`.\n"
" - For rewards, use `LLM.reward(...)` "
'or `pooling_task="token_classify"`\n'
" - For token classification, "
'use `pooling_task="token_classify"`\n'
' - For multi-vector retrieval, use `pooling_task="token_embed"`'
)
model_config = self.model_config
runner_type = model_config.runner_type
@@ -986,6 +1084,20 @@ class LLM:
"pooling model."
)
if truncate_prompt_tokens is not None:
warnings.warn(
"The `truncate_prompt_tokens` parameter in `LLM.encode()` "
"is deprecated and will be removed in v0.16. "
"Please pass it via `tokenization_kwargs` instead.",
DeprecationWarning,
stacklevel=2,
)
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=truncate_prompt_tokens),
)
io_processor_prompt = False
if isinstance(prompts, dict) and "data" in prompts:
io_processor_prompt = True
@@ -1017,19 +1129,16 @@ class LLM:
pooling_params = self.io_processor.validate_or_generate_params(
pooling_params
)
else:
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
if pooling_task not in self.supported_tasks:
raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
for param in as_iter(pooling_params):
param.verify(pooling_task, model_config)
# for backwards compatibility
if truncate_prompt_tokens is not None:
param.truncate_prompt_tokens = truncate_prompt_tokens
self._validate_and_add_requests(
prompts=prompts,
@@ -1094,6 +1203,7 @@ class LLM:
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
@@ -1105,9 +1215,14 @@ class LLM:
"Try converting the model using `--convert embed`."
)
if truncate_prompt_tokens is not None:
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=truncate_prompt_tokens),
)
items = self.encode(
prompts,
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
@@ -1121,8 +1236,8 @@ class LLM:
self,
prompts: PromptType | Sequence[PromptType],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[ClassificationRequestOutput]:
@@ -1137,13 +1252,15 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns:
A list of `ClassificationRequestOutput` objects containing the
embedding vectors in the same order as the input prompts.
@@ -1170,9 +1287,9 @@ class LLM:
prompts: PromptType | Sequence[PromptType],
/,
*,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[PoolingRequestOutput]:
@@ -1183,13 +1300,15 @@ class LLM:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See [PromptType][vllm.inputs.PromptType]
for more details about the format of each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns:
A list of `PoolingRequestOutput` objects containing the
pooled hidden states in the same order as the input prompts.
@@ -1207,18 +1326,18 @@ class LLM:
def _embedding_score(
self,
tokenizer: TokenizerLike,
text_1: list[str | TextPrompt | TokensPrompt],
text_2: list[str | TextPrompt | TokensPrompt],
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
text_1: list[SingletonPrompt],
text_2: list[SingletonPrompt],
*,
use_tqdm: bool | Callable[..., tqdm],
pooling_params: PoolingParams | None,
lora_request: list[LoRARequest] | LoRARequest | None,
tokenization_kwargs: dict[str, Any],
) -> list[ScoringRequestOutput]:
encoded_output: list[PoolingRequestOutput] = self.encode(
tokenizer = self.get_tokenizer()
encoded_output = self.encode(
text_1 + text_2,
truncate_prompt_tokens=truncate_prompt_tokens,
use_tqdm=use_tqdm,
lora_request=lora_request,
pooling_params=pooling_params,
@@ -1226,14 +1345,16 @@ class LLM:
tokenization_kwargs=tokenization_kwargs,
)
encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
encoded_output_1 = encoded_output[0 : len(text_1)]
encoded_output_2 = encoded_output[len(text_1) :]
if len(encoded_output_1) == 1:
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
scores = _cosine_similarity(
tokenizer=tokenizer, embed_1=encoded_output_1, embed_2=encoded_output_2
tokenizer=tokenizer,
embed_1=encoded_output_1,
embed_2=encoded_output_2,
)
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
@@ -1241,17 +1362,17 @@ class LLM:
def _cross_encoding_score(
self,
tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam],
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
score_template: str | None = None,
*,
use_tqdm: bool | Callable[..., tqdm],
pooling_params: PoolingParams | None,
lora_request: list[LoRARequest] | LoRARequest | None,
tokenization_kwargs: dict[str, Any],
score_template: str | None,
) -> list[ScoringRequestOutput]:
model_config = self.model_config
tokenizer = self.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer):
raise ValueError("Score API is not supported for Mistral tokenizer")
@@ -1265,13 +1386,6 @@ class LLM:
pooling_params.verify("score", model_config)
pooling_params_list = list[PoolingParams]()
local_kwargs = tokenization_kwargs or {}
tokenization_kwargs = local_kwargs.copy()
_validate_truncation_size(
model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs
)
prompts = list[PromptType]()
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
@@ -1314,10 +1428,10 @@ class LLM:
data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
/,
*,
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
chat_template: str | None = None,
) -> list[ScoringRequestOutput]:
"""Generate similarity scores for all pairs `<text,text_pair>` or
@@ -1344,20 +1458,22 @@ class LLM:
the LLM. Can be text or multi-modal data. See [PromptType]
[vllm.inputs.PromptType] for more details about the format of
each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: If `True`, shows a tqdm progress bar.
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
it is used to create the progress bar.
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
chat_template: The chat template to use for the scoring. If None, we
use the model's default chat template.
tokenization_kwargs: Overrides for `tokenizer.encode`.
Returns:
A list of `ScoringRequestOutput` objects containing the
generated scores in the same order as the input prompts.
"""
model_config = self.model_config
runner_type = model_config.runner_type
if runner_type != "pooling":
raise ValueError(
@@ -1445,26 +1561,27 @@ class LLM:
_validate_score_input_lens(data_1, data_2) # type: ignore[arg-type]
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
encode_kwargs = tok_params.get_encode_kwargs()
if model_config.is_cross_encoder:
return self._cross_encoding_score(
tokenizer,
data_1, # type: ignore[arg-type]
data_2, # type: ignore[arg-type]
truncate_prompt_tokens,
use_tqdm,
pooling_params,
lora_request,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
tokenization_kwargs=encode_kwargs,
score_template=chat_template,
)
else:
return self._embedding_score(
tokenizer,
data_1, # type: ignore[arg-type]
data_2, # type: ignore[arg-type]
truncate_prompt_tokens,
use_tqdm,
pooling_params,
lora_request,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
tokenization_kwargs=encode_kwargs,
)
def start_profile(self) -> None:
@@ -1530,42 +1647,79 @@ class LLM:
def _validate_and_add_requests(
self,
prompts: PromptType | Sequence[PromptType] | DataPrompt,
prompts: PromptType | Sequence[PromptType],
params: SamplingParams
| Sequence[SamplingParams]
| PoolingParams
| Sequence[PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Sequence[LoRARequest] | LoRARequest | None,
priority: list[int] | None = None,
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: list[int] | None = None,
) -> None:
if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list.
prompts = [prompts] # type: ignore[list-item]
in_prompts = self._normalize_prompts(prompts)
num_requests = len(in_prompts)
num_requests = len(prompts)
if isinstance(params, Sequence) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params must be the same.")
if isinstance(lora_request, Sequence) and len(lora_request) != num_requests:
raise ValueError(
"The lengths of prompts and lora_request must be the same."
)
if priority is not None and len(priority) != num_requests:
raise ValueError(
"The lengths of prompts "
f"({num_requests}) and priority ({len(priority)}) "
"must be the same."
if isinstance(params, Sequence):
if len(params) != num_requests:
raise ValueError(
f"The lengths of prompts ({params}) "
f"and lora_request ({len(params)}) must be the same."
)
engine_params = params
else:
engine_params = [params] * num_requests
if isinstance(lora_request, Sequence):
if len(lora_request) != num_requests:
raise ValueError(
f"The lengths of prompts ({num_requests}) "
f"and lora_request ({len(lora_request)}) must be the same."
)
engine_lora_requests: Sequence[LoRARequest | None] = lora_request
else:
engine_lora_requests = [lora_request] * num_requests
if priority is not None:
if len(priority) != num_requests:
raise ValueError(
f"The lengths of prompts ({num_requests}) "
f"and priority ({len(priority)}) must be the same."
)
else:
priority = [0] * num_requests
if any(param.truncate_prompt_tokens is not None for param in engine_params):
# TODO: Remove this after deprecating `param.truncate_prompt_tokens`
# Then, move the code from the `else` block to the top and let
# `self._preprocess_completion` handle prompt normalization
engine_prompts = [
engine_prompt
for in_prompt, param in zip(in_prompts, engine_params)
for engine_prompt in self._preprocess_completion(
[in_prompt],
tokenization_kwargs=merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
),
)
]
else:
engine_prompts = self._preprocess_completion(
in_prompts,
tokenization_kwargs=tokenization_kwargs,
)
for sp in params if isinstance(params, Sequence) else (params,):
for sp in engine_params:
if isinstance(sp, SamplingParams):
# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
it = prompts
it = engine_prompts
if use_tqdm:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests")
@@ -1576,12 +1730,10 @@ class LLM:
for i, prompt in enumerate(it):
request_id = self._add_request(
prompt,
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i]
if isinstance(lora_request, Sequence)
else lora_request,
priority=priority[i] if priority else 0,
engine_params[i],
lora_request=engine_lora_requests[i],
tokenization_kwargs=tokenization_kwargs,
priority=priority[i],
)
added_request_ids.append(request_id)
except Exception as e:
@@ -1589,54 +1741,42 @@ class LLM:
self.llm_engine.abort_request(added_request_ids, internal=True)
raise e
def _process_inputs(
self,
request_id: str,
engine_prompt: PromptType,
params: SamplingParams | PoolingParams,
*,
lora_request: LoRARequest | None,
priority: int,
tokenization_kwargs: dict[str, Any] | None = None,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""Use the Processor to process inputs for LLMEngine."""
local_kwargs = tokenization_kwargs or {}
tokenization_kwargs = local_kwargs.copy()
_validate_truncation_size(
self.model_config.max_model_len,
params.truncate_prompt_tokens,
tokenization_kwargs,
)
engine_request = self.input_processor.process_inputs(
request_id,
engine_prompt,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
return engine_request, tokenization_kwargs
def _add_request(
self,
prompt: PromptType,
params: SamplingParams | PoolingParams,
lora_request: LoRARequest | None = None,
priority: int = 0,
tokenization_kwargs: dict[str, Any] | None = None,
priority: int = 0,
) -> str:
prompt_text, _, _ = get_prompt_components(prompt)
request_id = str(next(self.request_counter))
engine_request, tokenization_kwargs = self._process_inputs(
if params.truncate_prompt_tokens is not None:
params_type = type(params).__name__
warnings.warn(
f"The `truncate_prompt_tokens` parameter in `{params_type}` "
"is deprecated and will be removed in v0.16. "
"Please pass it via `tokenization_kwargs` instead.",
DeprecationWarning,
stacklevel=2,
)
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
)
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id,
prompt,
params,
lora_request=lora_request,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
self.llm_engine.add_request(

View File

@@ -13,12 +13,13 @@ from openai.types.chat.chat_completion_audio import (
ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
from pydantic import (
Field,
model_validator,
)
from pydantic import Field, model_validator
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import (
AnyResponseFormat,
DeltaMessage,
@@ -36,6 +37,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
BeamSearchParams,
RequestOutputKind,
@@ -348,6 +350,43 @@ class ChatCompletionRequest(OpenAIBaseModel):
# --8<-- [end:chat-completion-extra-params]
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
return ChatParams(
chat_template=self.chat_template or default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs(
self.chat_template_kwargs,
dict(
add_generation_prompt=self.add_generation_prompt,
continue_final_message=self.continue_final_message,
documents=self.documents,
reasoning_effort=self.reasoning_effort,
),
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
if self.max_completion_tokens is not None:
max_output_tokens: int | None = self.max_completion_tokens
max_output_tokens_param = "max_completion_tokens"
else:
max_output_tokens = self.max_tokens
max_output_tokens_param = "max_tokens"
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=max_output_tokens or 0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
add_special_tokens=self.add_special_tokens,
needs_detokenization=bool(self.echo and not self.return_token_ids),
max_total_tokens_param="max_model_len",
max_output_tokens_param=max_output_tokens_param,
)
# Default sampling parameters for chat completion requests
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,

View File

@@ -67,7 +67,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
)
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.logprobs import Logprob
@@ -185,8 +185,6 @@ class OpenAIServingChat(OpenAIServing):
start_time = time.perf_counter()
try:
renderer = self.engine_client.renderer
# Create a minimal dummy request
dummy_request = ChatCompletionRequest(
messages=[{"role": "user", "content": "warmup"}],
@@ -201,18 +199,10 @@ class OpenAIServingChat(OpenAIServing):
# 3. Tokenizer initialization for chat
await self._preprocess_chat(
dummy_request,
renderer,
dummy_request.messages,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=True,
continue_final_message=False,
tool_dicts=None,
documents=None,
chat_template_kwargs=None,
default_chat_template_kwargs=self.default_chat_template_kwargs,
tool_parser=None,
add_special_tokens=False,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=self.default_chat_template_kwargs,
)
elapsed = (time.perf_counter() - start_time) * 1000
@@ -225,7 +215,10 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request(
self,
request: ChatCompletionRequest,
) -> tuple[list[ConversationMessage], list[Any]] | ErrorResponse:
) -> (
tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]
| ErrorResponse
):
"""
render chat request by validating and preprocessing inputs.
@@ -302,23 +295,14 @@ class OpenAIServingChat(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
chat_template_kwargs = request.chat_template_kwargs or {}
chat_template_kwargs.update(reasoning_effort=request.reasoning_effort)
conversation, engine_prompts = await self._preprocess_chat(
request,
renderer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts,
documents=request.documents,
chat_template_kwargs=chat_template_kwargs,
default_chat_template_kwargs=self.default_chat_template_kwargs,
tool_parser=tool_parser,
add_special_tokens=request.add_special_tokens,
)
else:
# For GPT-OSS.
@@ -428,11 +412,15 @@ class OpenAIServingChat(OpenAIServing):
trace_headers=trace_headers,
)
else:
engine_request, tokenization_kwargs = await self._process_inputs(
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,

View File

@@ -9,11 +9,9 @@ from dataclasses import replace
from typing import Annotated, Any, Literal
import torch
from pydantic import (
Field,
model_validator,
)
from pydantic import Field, model_validator
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import (
AnyResponseFormat,
LegacyStructuralTagResponseFormat,
@@ -27,6 +25,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.sampling_params import (
BeamSearchParams,
RequestOutputKind,
@@ -178,6 +177,17 @@ class CompletionRequest(OpenAIBaseModel):
# --8<-- [end:completion-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=self.max_tokens or 0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
add_special_tokens=self.add_special_tokens,
needs_detokenization=bool(self.echo and not self.return_token_ids),
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_tokens",
)
# Default sampling parameters for completion requests
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,

View File

@@ -32,7 +32,6 @@ from vllm.entrypoints.openai.engine.serving import (
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
@@ -111,11 +110,10 @@ class OpenAIServingCompletion(OpenAIServing):
)
try:
renderer = self._get_completion_renderer()
engine_prompts = await renderer.render_prompt_and_embeds(
prompt_or_prompts=request.prompt,
engine_prompts = await self._preprocess_completion(
request,
prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds,
config=self._build_render_config(request),
)
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
@@ -203,10 +201,6 @@ class OpenAIServingCompletion(OpenAIServing):
else await self._get_trace_headers(raw_request.headers)
)
# Mypy inconsistently requires this second cast in different
# environments. It shouldn't be necessary (redundant from above)
# but pre-commit in CI fails without it.
engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search(
prompt=engine_prompt,
@@ -216,11 +210,15 @@ class OpenAIServingCompletion(OpenAIServing):
trace_headers=trace_headers,
)
else:
engine_request, tokenization_kwargs = await self._process_inputs(
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
request_id_item,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
@@ -709,26 +707,3 @@ class OpenAIServingCompletion(OpenAIServing):
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)
def _build_render_config(
self,
request: CompletionRequest,
max_input_length: int | None = None,
) -> RenderConfig:
# Validate max_tokens before using it
if request.max_tokens is not None and request.max_tokens > self.max_model_len:
raise VLLMValidationError(
f"'max_tokens' ({request.max_tokens}) cannot be greater than "
f"the model's maximum context length ({self.max_model_len}).",
parameter="max_tokens",
value=request.max_tokens,
)
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
return RenderConfig(
max_length=max_input_tokens_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
cache_salt=request.cache_salt,
needs_detokenization=bool(request.echo and not request.return_token_ids),
)

View File

@@ -16,9 +16,7 @@ from pydantic import (
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.logger import init_logger
from vllm.sampling_params import (
SamplingParams,
)
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname

View File

@@ -5,10 +5,10 @@ import json
import sys
import time
import traceback
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping
from collections.abc import AsyncGenerator, Callable, Mapping
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar, cast
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
import numpy as np
from fastapi import Request
@@ -20,6 +20,7 @@ from starlette.datastructures import Headers
import vllm.envs as envs
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
@@ -86,7 +87,6 @@ from vllm.entrypoints.pooling.score.protocol import (
ScoreResponse,
ScoreTextRequest,
)
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.serve.tokenize.protocol import (
DetokenizeRequest,
@@ -94,13 +94,9 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest,
TokenizeResponse,
)
from vllm.entrypoints.utils import (
_validate_truncation_size,
get_max_tokens,
sanitize_message,
)
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.data import EmbedsPrompt, PromptType, TokensPrompt
from vllm.inputs.parse import (
get_prompt_components,
is_explicit_encoder_decoder_prompt,
@@ -112,7 +108,7 @@ from vllm.multimodal import MultiModalDataDict
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.renderers import RendererLike
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
@@ -123,11 +119,9 @@ from vllm.tracing import (
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import (
AsyncMicrobatchTokenizer,
collect_from_async_generator,
merge_async_iterators,
)
from vllm.v1.engine import EngineCoreRequest
class GenerationError(Exception):
@@ -140,6 +134,21 @@ class GenerationError(Exception):
logger = init_logger(__name__)
class RendererRequest(Protocol):
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
raise NotImplementedError
class RendererChatRequest(RendererRequest, Protocol):
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
raise NotImplementedError
CompletionLikeRequest: TypeAlias = (
CompletionRequest
| TokenizeCompletionRequest
@@ -158,7 +167,9 @@ ChatLikeRequest: TypeAlias = (
| ClassificationChatRequest
| PoolingChatRequest
)
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
AnyRequest: TypeAlias = (
CompletionLikeRequest
| ChatLikeRequest
@@ -193,7 +204,7 @@ class ServeContext(Generic[RequestT]):
request_id: str
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt] | None = None
engine_prompts: list[TokensPrompt | EmbedsPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
@@ -227,7 +238,6 @@ class OpenAIServing:
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack
self.input_processor = self.models.input_processor
@@ -519,41 +529,6 @@ class OpenAIServing:
prompt_logprobs=None,
)
def _get_completion_renderer(self) -> BaseRenderer:
"""
Get a Renderer instance with the provided tokenizer.
Uses shared async tokenizer pool for efficiency.
"""
return CompletionRenderer(
model_config=self.model_config,
tokenizer=self.renderer.tokenizer,
async_tokenizer_pool=self._async_tokenizer_pool,
)
def _build_render_config(
self,
request: Any,
) -> RenderConfig:
"""
Build and return a `RenderConfig` for an endpoint.
Used by the renderer to control how prompts are prepared
(e.g., tokenization and length handling). Endpoints should
implement this with logic appropriate to their request type.
"""
raise NotImplementedError
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
"""
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
given tokenizer.
"""
async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
if async_tokenizer is None:
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
self._async_tokenizer_pool[tokenizer] = async_tokenizer
return async_tokenizer
async def _preprocess(
self,
ctx: ServeContext,
@@ -912,71 +887,6 @@ class OpenAIServing:
message_types.add(content_dict["type"].split("_")[0])
return message_types
async def _normalize_prompt_text_to_input(
self,
request: AnyRequest,
prompt: str,
tokenizer: TokenizerLike,
add_special_tokens: bool,
) -> TokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer)
if (
self.model_config.encoder_config is not None
and self.model_config.encoder_config.get("do_lower_case", False)
):
prompt = prompt.lower()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
if truncate_prompt_tokens is None:
encoded = await async_tokenizer(
prompt, add_special_tokens=add_special_tokens
)
elif truncate_prompt_tokens < 0:
# Negative means we cap at the model's max length
encoded = await async_tokenizer(
prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=self.max_model_len,
)
else:
encoded = await async_tokenizer(
prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens,
)
input_ids = encoded.input_ids
input_text = prompt
return self._validate_input(request, input_ids, input_text)
async def _normalize_prompt_tokens_to_input(
self,
request: AnyRequest,
prompt_ids: list[int],
tokenizer: TokenizerLike | None,
) -> TokensPrompt:
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
if truncate_prompt_tokens is None:
input_ids = prompt_ids
elif truncate_prompt_tokens < 0:
input_ids = prompt_ids[-self.max_model_len :]
else:
input_ids = prompt_ids[-truncate_prompt_tokens:]
if tokenizer is None:
input_text = ""
else:
async_tokenizer = self._get_async_tokenizer(tokenizer)
input_text = await async_tokenizer.decode(input_ids)
return self._validate_input(request, input_ids, input_text)
def _validate_input(
self,
request: object,
@@ -1061,50 +971,6 @@ class OpenAIServing:
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
async def _tokenize_prompt_input_async(
self,
request: AnyRequest,
tokenizer: TokenizerLike,
prompt_input: str | list[int],
add_special_tokens: bool = True,
) -> TokensPrompt:
"""
A simpler implementation that tokenizes a single prompt input.
"""
async for result in self._tokenize_prompt_inputs_async(
request,
tokenizer,
[prompt_input],
add_special_tokens=add_special_tokens,
):
return result
raise ValueError("No results yielded from tokenization")
async def _tokenize_prompt_inputs_async(
self,
request: AnyRequest,
tokenizer: TokenizerLike,
prompt_inputs: Iterable[str | list[int]],
add_special_tokens: bool = True,
) -> AsyncGenerator[TokensPrompt, None]:
"""
A simpler implementation that tokenizes multiple prompt inputs.
"""
for prompt in prompt_inputs:
if isinstance(prompt, str):
yield await self._normalize_prompt_text_to_input(
request,
prompt=prompt,
tokenizer=tokenizer,
add_special_tokens=add_special_tokens,
)
else:
yield await self._normalize_prompt_tokens_to_input(
request,
prompt_ids=prompt,
tokenizer=tokenizer,
)
def _validate_chat_template(
self,
request_chat_template: str | None,
@@ -1137,131 +1003,94 @@ class OpenAIServing:
# Apply server defaults first, then request kwargs override.
return default_chat_template_kwargs | request_chat_template_kwargs
async def _preprocess_completion(
self,
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[TokensPrompt | EmbedsPrompt]:
renderer = self.renderer
tok_params = request.build_tok_params(self.model_config)
in_prompts = await renderer.render_completions_async(
prompt_input, prompt_embeds
)
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
extra_items = {
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
}
for prompt in engine_prompts:
prompt.update(extra_items) # type: ignore
return engine_prompts
async def _preprocess_chat(
self,
request: ChatLikeRequest | ResponsesRequest,
renderer: RendererLike,
request: RendererChatRequest,
messages: list[ChatCompletionMessageParam],
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None,
documents: list[dict[str, str]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
default_chat_template_kwargs: dict[str, Any] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
add_special_tokens: bool = False,
) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
chat_template_kwargs = {
"chat_template": chat_template,
"add_generation_prompt": add_generation_prompt,
"continue_final_message": continue_final_message,
"tools": tool_dicts,
"documents": documents,
**(chat_template_kwargs or {}),
}
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
chat_template_kwargs,
default_chat_template_kwargs,
)
# Use the async tokenizer in `OpenAIServing` if possible.
# Later we can move it into the renderer so that we can return both
# text and token IDs in the same prompt from `render_messages_async`
# which is used for logging and `enable_response_messages`.
) -> tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]:
from vllm.tokenizers.mistral import MistralTokenizer
conversation, engine_prompt = await renderer.render_messages_async(
messages,
chat_template_content_format=chat_template_content_format,
tokenize=(
chat_template_kwargs.pop("tokenize", False)
or isinstance(renderer.tokenizer, MistralTokenizer)
renderer = self.renderer
default_template_kwargs = merge_kwargs(
default_template_kwargs,
dict(
tools=tool_dicts,
tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
),
**chat_template_kwargs,
)
if "prompt_token_ids" not in engine_prompt:
extra_data = engine_prompt
engine_prompt = await self._tokenize_prompt_input_async(
request,
renderer.get_tokenizer(),
engine_prompt["prompt"],
add_special_tokens=add_special_tokens,
)
# Fill in other keys like MM data
engine_prompt.update(extra_data) # type: ignore
else:
self._validate_input(
request=request,
input_ids=engine_prompt["prompt_token_ids"], # type: ignore
input_text="",
)
tok_params = request.build_tok_params(self.model_config)
chat_params = request.build_chat_params(
default_template, default_template_content_format
).with_defaults(default_template_kwargs)
engine_prompt = cast(TokensPrompt, engine_prompt)
conversation, prompt = await renderer.render_messages_async(
messages, chat_params
)
engine_prompt = await renderer.tokenize_prompt_async(prompt, tok_params)
if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
if (cache_salt := getattr(request, "cache_salt", None)) is not None:
engine_prompt["cache_salt"] = cache_salt
extra_items = {
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
}
engine_prompt.update(extra_items) # type: ignore
# tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
# is set, we want to prevent parsing a tool_call hallucinated by the LLM
should_parse_tools = tool_parser is not None and (
hasattr(request, "tool_choice") and request.tool_choice != "none"
)
if tool_parser is not None:
tool_choice = getattr(request, "tool_choice", "none")
if tool_choice != "none":
if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
msg = (
"Tool usage is only supported for Chat Completions API "
"or Responses API requests."
)
raise NotImplementedError(msg)
if should_parse_tools:
if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
msg = (
"Tool usage is only supported for Chat Completions API "
"or Responses API requests."
)
raise NotImplementedError(msg)
tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore
# TODO: Update adjust_request to accept ResponsesRequest
tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type]
return conversation, [engine_prompt]
async def _process_inputs(
self,
request_id: str,
engine_prompt: PromptType,
params: SamplingParams | PoolingParams,
*,
lora_request: LoRARequest | None,
trace_headers: Mapping[str, str] | None,
priority: int,
data_parallel_rank: int | None = None,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""Use the Processor to process inputs for AsyncLLM."""
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(
self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs
)
engine_request = self.input_processor.process_inputs(
request_id,
engine_prompt,
params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
return engine_request, tokenization_kwargs
async def _render_next_turn(
self,
request: ResponsesRequest,
renderer: RendererLike,
messages: list[ResponseInputOutputItem],
tool_dicts: list[dict[str, Any]] | None,
tool_parser,
tool_parser: Callable[[TokenizerLike], ToolParser] | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
):
@@ -1271,24 +1100,25 @@ class OpenAIServing:
_, engine_prompts = await self._preprocess_chat(
request,
renderer,
new_messages,
default_template=chat_template,
default_template_content_format=chat_template_content_format,
default_template_kwargs=None,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
)
return engine_prompts
async def _generate_with_builtin_tools(
self,
request_id: str,
engine_prompt: TokensPrompt,
engine_prompt: TokensPrompt | EmbedsPrompt,
sampling_params: SamplingParams,
tok_params: TokenizeParams,
context: ConversationContext,
lora_request: LoRARequest | None = None,
priority: int = 0,
**kwargs,
trace_headers: Mapping[str, str] | None = None,
):
prompt_text, _, _ = get_prompt_components(engine_prompt)
@@ -1297,18 +1127,21 @@ class OpenAIServing:
while True:
# Ensure that each sub-request has a unique request id.
sub_request_id = f"{request_id}_{sub_request}"
self._log_inputs(
sub_request_id,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
trace_headers = kwargs.get("trace_headers")
engine_request, tokenization_kwargs = await self._process_inputs(
tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs(
sub_request_id,
engine_prompt,
sampling_params,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
)
@@ -1318,10 +1151,10 @@ class OpenAIServing:
sampling_params,
sub_request_id,
lora_request=lora_request,
trace_headers=trace_headers,
priority=priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
**kwargs,
)
async for res in generator:
@@ -1350,7 +1183,6 @@ class OpenAIServing:
elif isinstance(context, ParsableContext):
engine_prompts = await self._render_next_turn(
context.request,
context.renderer,
context.parser.response_messages,
context.tool_dicts,
context.tool_parser_cls,

View File

@@ -43,7 +43,6 @@ from vllm.entrypoints.openai.responses.protocol import (
from vllm.entrypoints.openai.responses.utils import construct_tool_dicts
from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.renderers import RendererLike
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.utils import random_uuid
@@ -261,7 +260,7 @@ class ParsableContext(ConversationContext):
self,
*,
response_messages: list[ResponseInputOutputItem],
renderer: RendererLike,
tokenizer: TokenizerLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None,
request: ResponsesRequest,
available_tools: list[str] | None,
@@ -280,7 +279,6 @@ class ParsableContext(ConversationContext):
if reasoning_parser_cls is None:
raise ValueError("reasoning_parser_cls must be provided.")
tokenizer = renderer.get_tokenizer()
self.parser = get_responses_parser_for_simple_context(
tokenizer=tokenizer,
reasoning_parser_cls=reasoning_parser_cls,
@@ -290,8 +288,6 @@ class ParsableContext(ConversationContext):
)
self.tool_parser_cls = tool_parser_cls
self.request = request
self.renderer = renderer
self.tokenizer = tokenizer
self.available_tools = available_tools or []
self._tool_sessions: dict[str, ClientSession | Tool] = {}

View File

@@ -59,12 +59,15 @@ from pydantic import (
model_validator,
)
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.engine.protocol import (
OpenAIBaseModel,
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
RequestOutputKind,
SamplingParams,
@@ -230,6 +233,42 @@ class ResponsesRequest(OpenAIBaseModel):
previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
# --8<-- [end:responses-extra-params]
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
from .utils import should_continue_final_message
# Check if we should continue the final message (partial completion)
# This enables Anthropic-style partial message completion where the
# user provides an incomplete assistant message to continue from.
continue_final = should_continue_final_message(self.input)
reasoning = self.reasoning
return ChatParams(
chat_template=default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs( # To remove unset values
{},
dict(
add_generation_prompt=not continue_final,
continue_final_message=continue_final,
reasoning_effort=None if reasoning is None else reasoning.effort,
),
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=self.max_output_tokens or 0,
truncate_prompt_tokens=-1 if self.truncation != "disabled" else None,
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_output_tokens",
)
_DEFAULT_SAMPLING_PARAMS = {
"temperature": 1.0,
"top_p": 1.0,

View File

@@ -114,16 +114,15 @@ from vllm.entrypoints.openai.responses.utils import (
construct_input_messages,
construct_tool_dicts,
extract_tool_types,
should_continue_final_message,
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
@@ -291,13 +290,14 @@ class OpenAIServingResponses(OpenAIServing):
self.tool_server = tool_server
def _validate_generator_input(
self, engine_prompt: TokensPrompt
self,
engine_prompt: TokensPrompt | EmbedsPrompt,
) -> ErrorResponse | None:
"""Add validations to the input to the generator here."""
if self.max_model_len <= len(engine_prompt["prompt_token_ids"]):
prompt_len = get_prompt_len(engine_prompt)
if self.max_model_len <= prompt_len:
error_message = (
"The engine prompt length"
f" {len(engine_prompt['prompt_token_ids'])} "
f"The engine prompt length {prompt_len} "
f"exceeds the max_model_len {self.max_model_len}. "
"Please reduce prompt."
)
@@ -307,6 +307,7 @@ class OpenAIServingResponses(OpenAIServing):
status_code=HTTPStatus.BAD_REQUEST,
param="input",
)
return None
def _validate_create_responses_input(
@@ -387,8 +388,6 @@ class OpenAIServingResponses(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request)
renderer = self.engine_client.renderer
tokenizer = renderer.get_tokenizer()
if self.use_harmony:
messages, engine_prompts = self._make_request_with_harmony(
@@ -396,7 +395,7 @@ class OpenAIServingResponses(OpenAIServing):
)
else:
messages, engine_prompts = await self._make_request(
request, prev_response, renderer
request, prev_response
)
except (
@@ -431,6 +430,9 @@ class OpenAIServingResponses(OpenAIServing):
assert len(builtin_tool_list) == 0
available_tools = []
try:
renderer = self.engine_client.renderer
tokenizer = renderer.get_tokenizer()
for engine_prompt in engine_prompts:
maybe_error = self._validate_generator_input(engine_prompt)
if maybe_error is not None:
@@ -446,6 +448,7 @@ class OpenAIServingResponses(OpenAIServing):
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
)
tok_params = request.build_tok_params(self.model_config)
trace_headers = (
None
@@ -465,7 +468,7 @@ class OpenAIServingResponses(OpenAIServing):
# tokens during generation instead of at the end
context = ParsableContext(
response_messages=messages,
renderer=renderer,
tokenizer=tokenizer,
reasoning_parser_cls=self.reasoning_parser,
request=request,
tool_parser_cls=self.tool_parser,
@@ -495,6 +498,7 @@ class OpenAIServingResponses(OpenAIServing):
request_id=request.request_id,
engine_prompt=engine_prompt,
sampling_params=sampling_params,
tok_params=tok_params,
context=context,
lora_request=lora_request,
priority=request.priority,
@@ -596,7 +600,6 @@ class OpenAIServingResponses(OpenAIServing):
self,
request: ResponsesRequest,
prev_response: ResponsesResponse | None,
renderer: RendererLike,
):
tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
# Construct the input messages.
@@ -606,30 +609,15 @@ class OpenAIServingResponses(OpenAIServing):
prev_msg=self.msg_store.get(prev_response.id) if prev_response else None,
prev_response_output=prev_response.output if prev_response else None,
)
# Check if we should continue the final message (partial completion)
# This enables Anthropic-style partial message completion where the
# user provides an incomplete assistant message to continue from.
continue_final = should_continue_final_message(request.input)
chat_template_kwargs = dict(
reasoning_effort=None
if request.reasoning is None
else request.reasoning.effort
)
_, engine_prompts = await self._preprocess_chat(
request,
renderer,
messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
tool_dicts=tool_dicts,
tool_parser=self.tool_parser,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
# When continuing a partial message, we set continue_final_message=True
# and add_generation_prompt=False so the model continues the message
# rather than starting a new one.
add_generation_prompt=not continue_final,
continue_final_message=continue_final,
chat_template_kwargs=chat_template_kwargs,
)
return messages, engine_prompts

View File

@@ -8,8 +8,12 @@ from pydantic import Field, model_validator
from vllm import PoolingParams
from vllm.config.pooler import get_use_activation
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.renderers import ChatParams, merge_kwargs
from vllm.utils import random_uuid
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
@@ -119,6 +123,23 @@ class ChatRequestMixin(OpenAIBaseModel):
)
return data
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
return ChatParams(
chat_template=self.chat_template or default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs(
self.chat_template_kwargs,
dict(
add_generation_prompt=self.add_generation_prompt,
continue_final_message=self.continue_final_message,
),
),
)
class EncodingRequestMixin(OpenAIBaseModel):
# --8<-- [start:encoding-params]

View File

@@ -4,10 +4,9 @@
import time
from typing import Any, TypeAlias
from pydantic import (
Field,
)
from pydantic import Field
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
@@ -15,13 +14,24 @@ from vllm.entrypoints.pooling.base.protocol import (
CompletionRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
class ClassificationCompletionRequest(
PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
):
pass
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
class ClassificationChatRequest(
@@ -33,6 +43,18 @@ class ClassificationChatRequest(
description=("Additional kwargs to pass to the HF processor."),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
ClassificationRequest: TypeAlias = (
ClassificationCompletionRequest | ClassificationChatRequest

View File

@@ -1,8 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from typing import Final, cast
from typing import Final, TypeAlias
import jinja2
import numpy as np
@@ -21,15 +20,14 @@ from vllm.entrypoints.pooling.classify.protocol import (
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
from vllm.outputs import ClassificationOutput
from vllm.pooling_params import PoolingParams
logger = init_logger(__name__)
ClassificationServeContext = ServeContext[ClassificationRequest]
ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
class ServingClassification(OpenAIServing):
@@ -77,34 +75,18 @@ class ServingClassification(OpenAIServing):
if error_check_ret:
return error_check_ret
_, engine_prompts = await self._preprocess_chat(
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
self.renderer,
ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
ctx.engine_prompts = engine_prompts
elif isinstance(ctx.request, ClassificationCompletionRequest):
input_data = ctx.request.input
if input_data in (None, ""):
return self.create_error_response(
"Input or messages must be provided",
status_code=HTTPStatus.BAD_REQUEST,
)
if isinstance(input_data, list) and not input_data:
ctx.engine_prompts = []
return None
renderer = self._get_completion_renderer()
prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input,
config=self._build_render_config(ctx.request),
ctx.engine_prompts = await self._preprocess_completion(
ctx.request,
prompt_input=ctx.request.input,
prompt_embeds=None,
)
else:
return self.create_error_response("Invalid classification request type")
@@ -128,7 +110,7 @@ class ServingClassification(OpenAIServing):
items: list[ClassificationData] = []
num_prompt_tokens = 0
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
final_res_batch_checked = ctx.final_res_batch
for idx, final_res in enumerate(final_res_batch_checked):
classify_res = ClassificationOutput.from_base(final_res.outputs)
@@ -161,13 +143,6 @@ class ServingClassification(OpenAIServing):
usage=usage,
)
def _build_render_config(self, request: ClassificationRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
async def create_classify(
self,
request: ClassificationRequest,

View File

@@ -3,10 +3,9 @@
import time
from typing import Any, TypeAlias
from pydantic import (
Field,
)
from pydantic import Field
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
@@ -14,15 +13,47 @@ from vllm.entrypoints.pooling.base.protocol import (
EmbedRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
def _get_max_total_output_tokens(
model_config: ModelConfig,
) -> tuple[int | None, int]:
max_total_tokens = model_config.max_model_len
pooler_config = model_config.pooler_config
if pooler_config is None:
return max_total_tokens, 0
if pooler_config.enable_chunked_processing:
return None, 0
max_embed_len = pooler_config.max_embed_len or max_total_tokens
max_output_tokens = max_total_tokens - max_embed_len
return max_total_tokens, max_output_tokens
class EmbeddingCompletionRequest(
PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin
):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
pass
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
(
max_total_tokens,
max_output_tokens,
) = _get_max_total_output_tokens(model_config)
return TokenizeParams(
max_total_tokens=max_total_tokens,
max_output_tokens=max_output_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_model_len - max_embed_len",
)
class EmbeddingChatRequest(
@@ -33,6 +64,24 @@ class EmbeddingChatRequest(
description=("Additional kwargs to pass to the HF processor."),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
(
max_total_tokens,
max_output_tokens,
) = _get_max_total_output_tokens(model_config)
return TokenizeParams(
max_total_tokens=max_total_tokens,
max_output_tokens=max_output_tokens,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
max_output_tokens_param="max_model_len - max_embed_len",
)
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import AsyncGenerator, Mapping
from typing import Any, Final, cast
from typing import Any, Final, TypeAlias
import torch
from fastapi import Request
@@ -22,8 +22,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingResponse,
EmbeddingResponseData,
)
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
@@ -37,7 +36,7 @@ from vllm.utils.serial_utils import (
logger = init_logger(__name__)
EmbeddingServeContext = ServeContext[EmbeddingRequest]
EmbeddingServeContext: TypeAlias = ServeContext[EmbeddingRequest]
class OpenAIServingEmbedding(OpenAIServing):
@@ -95,19 +94,16 @@ class OpenAIServingEmbedding(OpenAIServing):
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
self.renderer,
ctx.request.messages,
chat_template=ctx.request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=ctx.request.add_generation_prompt,
continue_final_message=ctx.request.continue_final_message,
add_special_tokens=ctx.request.add_special_tokens,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(ctx.request, EmbeddingCompletionRequest):
renderer = self._get_completion_renderer()
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request),
ctx.engine_prompts = await self._preprocess_completion(
ctx.request,
prompt_input=ctx.request.input,
prompt_embeds=None,
)
else:
return self.create_error_response("Invalid classification request type")
@@ -117,19 +113,6 @@ class OpenAIServingEmbedding(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig:
# Set max_length based on chunked processing capability
if self._should_use_chunked_processing(request):
max_length = None
else:
max_length = self.max_embed_len or self.max_model_len
return RenderConfig(
max_length=max_length,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
def _build_response(
self,
ctx: EmbeddingServeContext,
@@ -246,14 +229,18 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request,
)
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Create generator for this chunk and wrap it to return indices
original_generator = self.engine_client.encode(
chunk_engine_prompt,
pooling_params,
chunk_request_id,
lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
priority=ctx.request.priority,
)
generators.append(original_generator)
@@ -338,7 +325,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async def _create_single_prompt_generator(
self,
ctx: EmbeddingServeContext,
engine_prompt: TokensPrompt,
engine_prompt: TokensPrompt | EmbedsPrompt,
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_index: int,
@@ -353,23 +340,25 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_request=ctx.lora_request,
)
tok_params = ctx.request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
# Return the original generator without wrapping
return self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=ctx.lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=getattr(ctx.request, "priority", 0),
priority=ctx.request.priority,
)
async def _prepare_generators(
self,
ctx: ServeContext,
ctx: EmbeddingServeContext,
) -> ErrorResponse | None:
"""Override to support chunked processing."""
ctx = cast(EmbeddingServeContext, ctx)
# Check if we should use chunked processing
use_chunked = self._should_use_chunked_processing(ctx.request)
@@ -405,7 +394,8 @@ class OpenAIServingEmbedding(OpenAIServing):
for i, engine_prompt in enumerate(ctx.engine_prompts):
# Check if this specific prompt needs chunked processing
if "prompt_token_ids" in engine_prompt:
prompt_token_ids = engine_prompt["prompt_token_ids"]
prompt_token_ids = engine_prompt["prompt_token_ids"] # type: ignore[typeddict-item]
if len(prompt_token_ids) > max_pos_embeddings:
# Use chunked processing for this prompt
chunk_generators = await self._process_chunked_request(
@@ -573,7 +563,7 @@ class OpenAIServingEmbedding(OpenAIServing):
"token IDs"
)
original_token_ids = original_prompt["prompt_token_ids"]
original_token_ids = original_prompt["prompt_token_ids"] # type: ignore[typeddict-item]
pooling_request_output = PoolingRequestOutput(
request_id=aggregator["request_id"],

View File

@@ -3,11 +3,10 @@
import time
from typing import Any, Generic, TypeAlias, TypeVar
from pydantic import (
Field,
)
from pydantic import Field
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.config.pooler import get_use_activation
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
@@ -18,6 +17,7 @@ from vllm.entrypoints.pooling.base.protocol import (
EncodingRequestMixin,
PoolingBasicRequestMixin,
)
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid
@@ -30,6 +30,18 @@ class PoolingCompletionRequest(
):
task: PoolingTask | None = None
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self):
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
@@ -48,6 +60,18 @@ class PoolingChatRequest(
description=("Additional kwargs to pass to the HF processor."),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=self.add_special_tokens,
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self):
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,

View File

@@ -5,7 +5,7 @@ import asyncio
import json
import time
from collections.abc import AsyncGenerator, Sequence
from typing import Final, cast
from typing import Any, Final, cast
import jinja2
from fastapi import Request
@@ -14,10 +14,7 @@ from typing_extensions import assert_never
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
UsageInfo,
)
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.pooling.protocol import (
@@ -30,8 +27,6 @@ from vllm.entrypoints.pooling.pooling.protocol import (
PoolingResponse,
PoolingResponseData,
)
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.tasks import PoolingTask, SupportedTask
@@ -99,11 +94,6 @@ class OpenAIServingPooling(OpenAIServing):
"dimensions is currently not supported"
)
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens
)
if is_io_processor_request:
if self.io_processor is None:
raise ValueError(
@@ -134,19 +124,16 @@ class OpenAIServingPooling(OpenAIServing):
_, engine_prompts = await self._preprocess_chat(
request,
self.renderer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
add_special_tokens=request.add_special_tokens,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
)
elif isinstance(request, PoolingCompletionRequest):
renderer = self._get_completion_renderer()
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.input,
config=self._build_render_config(request),
engine_prompts = await self._preprocess_completion(
request,
prompt_input=request.input,
prompt_embeds=None,
)
else:
raise ValueError(f"Unsupported request of type {type(request)}")
@@ -207,11 +194,18 @@ class OpenAIServingPooling(OpenAIServing):
else await self._get_trace_headers(raw_request.headers)
)
if is_io_processor_request:
tokenization_kwargs: dict[str, Any] = {}
else:
tok_params = request.build_tok_params(self.model_config) # type: ignore
tokenization_kwargs = tok_params.get_encode_kwargs()
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
)
@@ -338,10 +332,3 @@ class OpenAIServingPooling(OpenAIServing):
return encode_bytes(bytes_only=encoding_format == "bytes_only")
else:
assert_never(encoding_format)
def _build_render_config(self, request: PoolingCompletionRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)

View File

@@ -3,12 +3,10 @@
import time
from typing import Any, TypeAlias
from pydantic import (
BaseModel,
Field,
)
from pydantic import BaseModel, Field
from vllm import PoolingParams
from vllm.config import ModelConfig
from vllm.config.pooler import get_use_activation
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
@@ -19,6 +17,7 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreContentPartParam,
ScoreMultiModalParam,
)
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
@@ -30,6 +29,17 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
)
# --8<-- [end:score-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self):
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
@@ -85,6 +95,17 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
)
# --8<-- [end:rerank-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
max_output_tokens=0,
truncate_prompt_tokens=self.truncate_prompt_tokens,
do_lower_case=encoder_config.get("do_lower_case", False),
max_total_tokens_param="max_model_len",
)
class RerankDocument(BaseModel):
text: str | None = None

View File

@@ -34,7 +34,6 @@ from vllm.entrypoints.pooling.score.utils import (
compress_token_type_ids,
get_score_prompt,
)
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -68,31 +67,31 @@ class ServingScores(OpenAIServing):
async def _embedding_score(
self,
tokenizer: TokenizerLike,
data_1: list[str],
data_2: list[str],
request: RerankRequest | ScoreRequest,
request_id: str,
tokenization_kwargs: dict[str, Any] | None = None,
lora_request: LoRARequest | None | None = None,
trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
model_config = self.model_config
tokenizer = self.renderer.get_tokenizer()
encode_async = make_async(
tokenizer.encode,
executor=self._tokenizer_executor,
)
input_texts = data_1 + data_2
engine_prompts: list[TokensPrompt] = []
tokenize_async = make_async(
tokenizer.__call__, executor=self._tokenizer_executor
)
tokenization_kwargs = tokenization_kwargs or {}
tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
tokenized_prompts = await asyncio.gather(
*(tokenize_async(t, **tokenization_kwargs) for t in input_texts)
*(encode_async(t, **tokenization_kwargs) for t in input_texts)
)
engine_prompts: list[TokensPrompt] = []
for tok_result, input_text in zip(tokenized_prompts, input_texts):
text_token_prompt = self._validate_input(
request, tok_result["input_ids"], input_text
)
text_token_prompt = self._validate_input(request, tok_result, input_text)
engine_prompts.append(
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
@@ -184,15 +183,16 @@ class ServingScores(OpenAIServing):
async def _cross_encoding_score(
self,
tokenizer: TokenizerLike,
data_1: list[str] | list[ScoreContentPartParam],
data_2: list[str] | list[ScoreContentPartParam],
request: RerankRequest | ScoreRequest,
request_id: str,
tokenization_kwargs: dict[str, Any] | None = None,
lora_request: LoRARequest | None | None = None,
trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
model_config = self.model_config
tokenizer = self.renderer.get_tokenizer()
request_prompts: list[str] = []
engine_prompts: list[TokensPrompt] = []
@@ -202,12 +202,13 @@ class ServingScores(OpenAIServing):
if isinstance(tokenizer, MistralTokenizer):
raise ValueError("MistralTokenizer not supported for cross-encoding")
tokenization_kwargs = tokenization_kwargs or {}
tok_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
preprocess_async = make_async(
self._preprocess_score, executor=self._tokenizer_executor
self._preprocess_score,
executor=self._tokenizer_executor,
)
preprocessed_prompts = await asyncio.gather(
@@ -215,7 +216,7 @@ class ServingScores(OpenAIServing):
preprocess_async(
request=request,
tokenizer=tokenizer,
tokenization_kwargs=tokenization_kwargs,
tokenization_kwargs=tok_kwargs,
data_1=t1,
data_2=t2,
)
@@ -286,14 +287,6 @@ class ServingScores(OpenAIServing):
raw_request: Request | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
lora_request = self._maybe_get_adapters(request)
tokenizer = self.renderer.get_tokenizer()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(
self.max_model_len, truncate_prompt_tokens, tokenization_kwargs
)
trace_headers = (
None
@@ -322,24 +315,20 @@ class ServingScores(OpenAIServing):
if self.model_config.is_cross_encoder:
return await self._cross_encoding_score(
tokenizer=tokenizer,
data_1=data_1, # type: ignore[arg-type]
data_2=data_2, # type: ignore[arg-type]
request=request,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
trace_headers=trace_headers,
)
else:
return await self._embedding_score(
tokenizer=tokenizer,
data_1=data_1, # type: ignore[arg-type]
data_2=data_2, # type: ignore[arg-type]
request=request,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
trace_headers=trace_headers,
)

View File

@@ -1,411 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Annotated
import pybase64
import torch
from pydantic import Field
from vllm.config import ModelConfig
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
@dataclass(frozen=True)
class RenderConfig:
"""Configuration to control how prompts are prepared."""
max_length: int | None = None
"""Maximum allowable total input token length. If provided,
token inputs longer than this raise `ValueError`."""
truncate_prompt_tokens: int | None = None
"""Number of tokens to keep. `None` means no truncation.
`0` yields an empty list (and skips embeds).
`-1` maps to `model_config.max_model_len`."""
add_special_tokens: bool = True
"""Whether to add model-specific special tokens during tokenization."""
cache_salt: str | None = None
"""String to disambiguate prefix cache entries."""
needs_detokenization: bool | None = False
"""If True, detokenize IDs back to text for inclusion in outputs."""
def verify_truncate_prompt_tokens(self, model_config: ModelConfig) -> int | None:
"""Validate and normalize `truncate_prompt_tokens` parameter."""
truncate_prompt_tokens = self.truncate_prompt_tokens
if truncate_prompt_tokens is None or truncate_prompt_tokens == 0:
return truncate_prompt_tokens
if truncate_prompt_tokens < 0:
truncate_prompt_tokens = model_config.max_model_len
max_length = self.max_length
if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator]
raise ValueError(
f"{truncate_prompt_tokens=} cannot be greater than "
f"{max_length=}. Please select a smaller truncation size."
)
return truncate_prompt_tokens
class BaseRenderer(ABC):
"""
Base class for unified input processing and rendering.
The Renderer serves as a unified input processor that consolidates
tokenization, chat template formatting, and multimodal input handling
into a single component.
It converts high-level API requests (OpenAI-style JSON) into token IDs and
multimodal features ready for engine consumption.
Key responsibilities:
- Convert text prompts to token sequences with proper special tokens
- Apply chat templates and format conversations
- Handle multimodal inputs (images, audio, etc.) when applicable
- Manage prompt truncation and length validation
- Provide clean separation between API layer and engine core
"""
def __init__(
self,
model_config: ModelConfig,
tokenizer: TokenizerLike | None = None,
):
super().__init__()
self.model_config = model_config
self.tokenizer = tokenizer
@abstractmethod
async def render_prompt(
self,
*,
prompt_or_prompts: str | list[str] | list[int] | list[list[int]],
config: RenderConfig,
) -> list[TokensPrompt]:
"""
Convert text or token inputs into engine-ready TokensPrompt objects.
This method accepts text or token inputs and produces a
list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects
for the engine.
Args:
prompt_or_prompts: One of:
- `str`: Single text prompt.
- `list[str]`: Batch of text prompts.
- `list[int]`: Single pre-tokenized sequence.
- `list[list[int]]`: Batch of pre-tokenized sequences.
config: Render configuration controlling how prompts are prepared
(e.g., tokenization and length handling).
Returns:
list[TokensPrompt]: Engine-ready token prompts.
Raises:
ValueError: If input formats are invalid or length limits exceeded.
"""
raise NotImplementedError
@abstractmethod
async def render_prompt_and_embeds(
self,
*,
prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None,
prompt_embeds: bytes | list[bytes] | None = None,
config: RenderConfig,
) -> list[TokensPrompt | EmbedsPrompt]:
"""
Convert text/token and/or base64-encoded embeddings inputs into
engine-ready prompt objects using a unified RenderConfig.
At least one of `prompt_or_prompts` or `prompt_embeds` must be
provided and non-empty. If both are omitted or empty (e.g., empty
string and empty list), a `ValueError` is raised.
Args:
prompt_or_prompts: Text or token inputs to include.
prompt_embeds: Base64-encoded bytes (or list thereof) containing a
torch-saved tensor to be used as prompt embeddings.
config: Render configuration controlling how prompts are prepared
(e.g., tokenization and length handling).
Returns:
list[Union[TokensPrompt, EmbedsPrompt]]:
Engine-ready prompt objects.
Raises:
ValueError: If both `prompt_or_prompts` and `prompt_embeds`
are omitted or empty (decoder prompt cannot be empty), or if
length limits are exceeded.
"""
raise NotImplementedError
def load_prompt_embeds(
self,
prompt_embeds: bytes | list[bytes],
truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None,
cache_salt: str | None = None,
) -> list[EmbedsPrompt]:
"""Load and validate base64-encoded embeddings into prompt objects."""
if not self.model_config.enable_prompt_embeds:
raise VLLMValidationError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`.",
parameter="prompt_embeds",
)
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
# Enable sparse tensor integrity checks to prevent out-of-bounds
# writes from maliciously crafted tensors
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.load(
io.BytesIO(pybase64.b64decode(embed, validate=True)),
weights_only=True,
map_location=torch.device("cpu"),
)
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2
if truncate_prompt_tokens is not None:
tensor = tensor[-truncate_prompt_tokens:]
embeds_prompt = EmbedsPrompt(prompt_embeds=tensor)
if cache_salt is not None:
embeds_prompt["cache_salt"] = cache_salt
return embeds_prompt
if isinstance(prompt_embeds, list):
return [_load_and_validate_embed(embed) for embed in prompt_embeds]
return [_load_and_validate_embed(prompt_embeds)]
class CompletionRenderer(BaseRenderer):
def __init__(
self,
model_config: ModelConfig,
tokenizer: TokenizerLike | None = None,
async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer]
| None = None,
):
super().__init__(model_config, tokenizer)
self.async_tokenizer_pool = async_tokenizer_pool
self.async_tokenizer: AsyncMicrobatchTokenizer | None = None
async def render_prompt(
self,
*,
prompt_or_prompts: str | list[str] | list[int] | list[list[int]],
config: RenderConfig,
) -> list[TokensPrompt]:
"""Implementation of prompt rendering for completion-style requests.
Uses async tokenizer pooling for improved performance. See base class
for detailed parameter documentation.
"""
truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config)
if truncate_prompt_tokens == 0:
return []
tasks = (
self._create_prompt(
prompt_input,
config=config,
truncate_prompt_tokens=truncate_prompt_tokens,
)
for prompt_input in parse_raw_prompts(prompt_or_prompts)
)
return await asyncio.gather(*tasks)
async def render_prompt_and_embeds(
self,
*,
prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None,
prompt_embeds: bytes | list[bytes] | None = None,
config: RenderConfig,
) -> list[TokensPrompt | EmbedsPrompt]:
"""
Render text/token prompts and/or precomputed embedding prompts. At
least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
"""
truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config)
if truncate_prompt_tokens == 0:
return []
rendered: list[TokensPrompt | EmbedsPrompt] = []
if prompt_embeds is not None:
rendered.extend(
self.load_prompt_embeds(
prompt_embeds, truncate_prompt_tokens, config.cache_salt
)
)
if prompt_or_prompts is None or prompt_or_prompts == "":
return rendered
token_prompts = await self.render_prompt(
prompt_or_prompts=prompt_or_prompts,
config=config,
)
rendered.extend(token_prompts)
return rendered
def _maybe_apply_truncation(
self, token_ids: list[int], truncate_prompt_tokens: int | None
) -> list[int]:
"""Apply truncation to token sequence."""
if truncate_prompt_tokens is None:
return token_ids
if truncate_prompt_tokens >= len(token_ids):
return token_ids
return token_ids[-truncate_prompt_tokens:]
async def _create_prompt(
self,
prompt_input: TextPrompt | TokensPrompt,
config: RenderConfig,
truncate_prompt_tokens: int | None,
) -> TokensPrompt:
prompt, prompt_token_ids, _ = get_prompt_components(prompt_input)
if prompt_token_ids is not None:
# NOTE: detokenization is needed when echo is enabled,
# where the input token IDs are decoded back to text.
return await self._create_prompt_from_token_ids(
prompt_token_ids,
config.max_length,
truncate_prompt_tokens,
config.cache_salt,
config.needs_detokenization,
)
if prompt is not None:
return await self._create_prompt_from_text(
prompt,
config.max_length,
truncate_prompt_tokens,
config.add_special_tokens,
config.cache_salt,
)
# TODO: Also handle embeds prompt using this method
raise NotImplementedError
async def _create_prompt_from_text(
self,
text: str,
max_length: int | None,
truncate_prompt_tokens: int | None,
add_special_tokens: bool,
cache_salt: str | None,
) -> TokensPrompt:
"""Tokenize text input asynchronously."""
async_tokenizer = self._get_async_tokenizer()
# Handle encoder-specific preprocessing
if (
self.model_config.encoder_config is not None
and self.model_config.encoder_config.get("do_lower_case", False)
):
text = text.lower()
# Tokenize texts
if truncate_prompt_tokens is None:
encoded = await async_tokenizer(text, add_special_tokens=add_special_tokens)
else:
encoded = await async_tokenizer(
text,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens,
)
return self._create_tokens_prompt(
encoded.input_ids, max_length, cache_salt, text
)
async def _create_prompt_from_token_ids(
self,
token_ids: list[int],
max_length: int | None,
truncate_prompt_tokens: int | None,
cache_salt: str | None,
needs_detokenization: bool | None = False,
) -> TokensPrompt:
"""Optionally detokenize token IDs and build a tokens prompt."""
token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens)
prompt = None
if needs_detokenization:
async_tokenizer = self._get_async_tokenizer()
prompt = await async_tokenizer.decode(token_ids)
return self._create_tokens_prompt(
token_ids=token_ids,
max_length=max_length,
cache_salt=cache_salt,
prompt=prompt,
)
def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
"""Get or create async tokenizer using shared pool."""
async_tokenizer = self.async_tokenizer
if async_tokenizer is not None:
return async_tokenizer
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("No tokenizer available for text input processing")
if self.async_tokenizer_pool is None:
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
else:
async_tokenizer = self.async_tokenizer_pool.get(tokenizer)
if async_tokenizer is None:
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
self.async_tokenizer_pool[tokenizer] = async_tokenizer
self.async_tokenizer = async_tokenizer
return async_tokenizer
def _create_tokens_prompt(
self,
token_ids: list[int],
max_length: int | None = None,
cache_salt: str | None = None,
prompt: str | None = None,
) -> TokensPrompt:
"""Create validated TokensPrompt."""
if max_length is not None and len(token_ids) > max_length:
raise VLLMValidationError(
f"This model's maximum context length is {max_length} tokens. "
f"However, your request has {len(token_ids)} input tokens. "
"Please reduce the length of the input messages.",
parameter="input_tokens",
value=len(token_ids),
)
tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)
if cache_salt is not None:
tokens_prompt["cache_salt"] = cache_salt
if prompt is not None:
tokens_prompt["prompt"] = prompt
return tokens_prompt

View File

@@ -4,12 +4,14 @@ from typing import Any
from pydantic import BaseModel, Field
from vllm.config import ModelConfig
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionLogProbs
from vllm.entrypoints.openai.engine.protocol import (
SamplingParams,
StreamOptions,
)
from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
@@ -62,6 +64,12 @@ class GenerateRequest(BaseModel):
description="KVTransfer parameters used for disaggregated serving.",
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
return TokenizeParams(
max_total_tokens=None,
max_output_tokens=0,
)
class GenerateResponseChoice(BaseModel):
index: int

View File

@@ -101,12 +101,13 @@ class ServingTokens(OpenAIServing):
# TODO(NickLucche): Change to EngineCoreRequest once Renderer work is
# completed
engine_prompt = TokensPrompt(prompt_token_ids=request.token_ids)
if request.features is not None:
engine_prompt["multi_modal_data"] = None
if hasattr(request, "cache_salt") and request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
engine_prompts = await self._preprocess_completion(
request,
prompt_input=request.token_ids,
prompt_embeds=None,
)
assert len(engine_prompts) == 1
engine_prompt = engine_prompts[0]
# Schedule the request and get the result generator.
result_generator: AsyncGenerator[RequestOutput, None] | None = None
@@ -128,11 +129,15 @@ class ServingTokens(OpenAIServing):
else await self._get_trace_headers(raw_request.headers)
)
tok_params = request.build_tok_params(self.model_config)
tokenization_kwargs = tok_params.get_encode_kwargs()
result_generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=request.priority,
)

View File

@@ -6,8 +6,10 @@ from typing import Any, TypeAlias
from pydantic import ConfigDict, Field, model_validator
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionToolsParam,
@@ -15,6 +17,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
from vllm.entrypoints.openai.engine.protocol import (
OpenAIBaseModel,
)
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
class TokenizeCompletionRequest(OpenAIBaseModel):
@@ -35,6 +38,13 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
return TokenizeParams(
max_total_tokens=None,
max_output_tokens=0,
add_special_tokens=self.add_special_tokens,
)
class TokenizeChatRequest(OpenAIBaseModel):
model: str | None = None
@@ -109,6 +119,30 @@ class TokenizeChatRequest(OpenAIBaseModel):
)
return data
def build_chat_params(
self,
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
return ChatParams(
chat_template=self.chat_template or default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs(
self.chat_template_kwargs,
dict(
add_generation_prompt=self.add_generation_prompt,
continue_final_message=self.continue_final_message,
),
),
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
return TokenizeParams(
max_total_tokens=None,
max_output_tokens=0,
add_special_tokens=self.add_special_tokens,
)
TokenizeRequest: TypeAlias = TokenizeCompletionRequest | TokenizeChatRequest
@@ -124,6 +158,13 @@ class DetokenizeRequest(OpenAIBaseModel):
model: str | None = None
tokens: list[int]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
return TokenizeParams(
max_total_tokens=None,
max_output_tokens=0,
needs_detokenization=True,
)
class DetokenizeResponse(OpenAIBaseModel):
prompt: str

View File

@@ -9,12 +9,9 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
)
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.serve.tokenize.protocol import (
DetokenizeRequest,
DetokenizeResponse,
@@ -83,21 +80,17 @@ class OpenAIServingTokenization(OpenAIServing):
_, engine_prompts = await self._preprocess_chat(
request,
self.renderer,
request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
tool_dicts=tool_dicts,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
chat_template_kwargs=request.chat_template_kwargs,
add_special_tokens=request.add_special_tokens,
)
else:
renderer = self._get_completion_renderer()
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.prompt,
config=self._build_render_config(request),
engine_prompts = await self._preprocess_completion(
request,
prompt_input=request.prompt,
prompt_embeds=None,
)
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
@@ -106,11 +99,14 @@ class OpenAIServingTokenization(OpenAIServing):
input_ids: list[int] = []
for engine_prompt in engine_prompts:
self._log_inputs(
request_id, engine_prompt, params=None, lora_request=lora_request
request_id,
engine_prompt,
params=None,
lora_request=lora_request,
)
if isinstance(engine_prompt, dict) and "prompt_token_ids" in engine_prompt:
input_ids.extend(engine_prompt["prompt_token_ids"])
if "prompt_token_ids" in engine_prompt:
input_ids.extend(engine_prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
token_strs = None
if request.return_token_strs:
@@ -136,7 +132,6 @@ class OpenAIServingTokenization(OpenAIServing):
request_id = f"tokenize-{self._base_request_id(raw_request)}"
lora_request = self._maybe_get_adapters(request)
tokenizer = self.renderer.get_tokenizer()
self._log_inputs(
request_id,
@@ -145,14 +140,13 @@ class OpenAIServingTokenization(OpenAIServing):
lora_request=lora_request,
)
prompt_input = await self._tokenize_prompt_input_async(
request,
tokenizer,
request.tokens,
engine_prompt = await self.renderer.tokenize_prompt_async(
TokensPrompt(prompt_token_ids=request.tokens),
request.build_tok_params(self.model_config),
)
input_text = prompt_input["prompt"]
prompt_text = engine_prompt["prompt"] # type: ignore[typeddict-item]
return DetokenizeResponse(prompt=input_text)
return DetokenizeResponse(prompt=prompt_text)
async def get_tokenizer_info(
self,
@@ -165,9 +159,6 @@ class OpenAIServingTokenization(OpenAIServing):
except Exception as e:
return self.create_error_response(f"Failed to get tokenizer info: {str(e)}")
def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
return RenderConfig(add_special_tokens=request.add_special_tokens)
@dataclass
class TokenizerInfo:

View File

@@ -8,7 +8,7 @@ import os
from argparse import Namespace
from logging import Logger
from string import Template
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
import regex as re
from fastapi import Request
@@ -18,9 +18,9 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
@@ -34,9 +34,7 @@ if TYPE_CHECKING:
StreamOptions,
)
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
else:
ChatCompletionRequest = object
CompletionRequest = object
@@ -188,33 +186,6 @@ def cli_env_setup():
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def _validate_truncation_size(
max_model_len: int,
truncate_prompt_tokens: int | None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> int | None:
if truncate_prompt_tokens is not None:
if truncate_prompt_tokens <= -1:
truncate_prompt_tokens = max_model_len
if truncate_prompt_tokens > max_model_len:
raise ValueError(
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
f"is greater than max_model_len ({max_model_len})."
f" Please, select a smaller truncation size."
)
if tokenization_kwargs is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
else:
if tokenization_kwargs is not None:
tokenization_kwargs["truncation"] = False
return truncate_prompt_tokens
def get_max_tokens(
max_model_len: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
@@ -233,10 +204,7 @@ def get_max_tokens(
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
input_length = length_from_prompt_token_ids_or_embeds(
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
prompt.get("prompt_embeds"), # type: ignore[arg-type]
)
input_length = get_prompt_len(prompt)
default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length)

View File

@@ -21,12 +21,7 @@ else:
MultiModalUUIDDict = object
class TextPrompt(TypedDict):
"""Schema for a text prompt."""
prompt: str
"""The input text to be tokenized before passing to the model."""
class _CommonKeys(TypedDict):
multi_modal_data: NotRequired[MultiModalDataDict | None]
"""
Optional multi-modal data to pass to the model,
@@ -56,7 +51,14 @@ class TextPrompt(TypedDict):
"""
class TokensPrompt(TypedDict):
class TextPrompt(_CommonKeys):
"""Schema for a text prompt."""
prompt: str
"""The input text to be tokenized before passing to the model."""
class TokensPrompt(_CommonKeys):
"""Schema for a tokenized prompt."""
prompt_token_ids: list[int]
@@ -68,47 +70,15 @@ class TokensPrompt(TypedDict):
token_type_ids: NotRequired[list[int]]
"""A list of token type IDs to pass to the cross encoder model."""
multi_modal_data: NotRequired[MultiModalDataDict | None]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
mm_processor_kwargs: NotRequired[dict[str, Any] | None]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
multi_modal_uuids: NotRequired[MultiModalUUIDDict]
"""
Optional user-specified UUIDs for multimodal items, mapped by modality.
Lists must match the number of items per modality and may contain `None`.
For `None` entries, the hasher will compute IDs automatically; non-None
entries override the default hashes for caching.
"""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
class EmbedsPrompt(TypedDict):
class EmbedsPrompt(_CommonKeys):
"""Schema for a prompt provided via token embeddings."""
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
class DataPrompt(TypedDict):
class DataPrompt(_CommonKeys):
"""Represents generic inputs handled by IO processor plugins."""
data: Any
@@ -197,7 +167,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
mm_processor_kwargs: NotRequired[dict[str, Any]]
PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt
PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt[Any, Any]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:

View File

@@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict, cast
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict
from typing_extensions import TypeIs
from vllm.utils.collection_utils import is_list_of
from vllm.utils import length_from_prompt_token_ids_or_embeds
from .data import (
EmbedsPrompt,
@@ -22,50 +21,6 @@ if TYPE_CHECKING:
import torch
def parse_raw_prompts(
prompt: str | list[str] | list[int] | list[list[int]],
) -> Sequence[TextPrompt] | Sequence[TokensPrompt]:
if isinstance(prompt, str):
# case 1: a string
return [TextPrompt(prompt=prompt)]
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
# case 2: array of strings
if is_list_of(prompt, str):
prompt = cast(list[str], prompt)
return [TextPrompt(prompt=elem) for elem in prompt]
# case 3: array of tokens
if is_list_of(prompt, int):
prompt = cast(list[int], prompt)
return [TokensPrompt(prompt_token_ids=prompt)]
# case 4: array of token arrays
if is_list_of(prompt, list):
if len(prompt) == 1 and isinstance(prompt[0], list) and len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
for elem in prompt:
if not isinstance(elem, list):
raise TypeError(
"prompt must be a list of lists, but found a non-list element."
)
if not is_list_of(elem, int):
raise TypeError(
"Nested lists of tokens must contain only integers."
)
prompt = cast(list[list[int]], prompt)
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
raise TypeError(
"prompt must be a string, array of strings, "
"array of tokens, or array of token arrays"
)
class ParsedStrPrompt(TypedDict):
type: Literal["str"]
content: str
@@ -145,3 +100,10 @@ def get_prompt_components(prompt: PromptType) -> PromptComponents:
token_ids=prompt.get("prompt_token_ids"), # type: ignore[arg-type]
embeds=prompt.get("prompt_embeds"),
)
def get_prompt_len(prompt: TokensPrompt | EmbedsPrompt):
return length_from_prompt_token_ids_or_embeds(
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
prompt.get("prompt_embeds"), # type: ignore[arg-type]
)

View File

@@ -209,6 +209,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo
item_processor_data = dict(**mm_data, audios=audios)
# some tokenizer kwargs are incompatible with UltravoxProcessor
tok_kwargs.pop("add_special_tokens", None)
tok_kwargs.pop("padding", None)
tok_kwargs.pop("truncation", None)

View File

@@ -1,7 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .params import ChatParams, TokenizeParams, merge_kwargs
from .protocol import RendererLike
from .registry import RendererRegistry, renderer_from_config
__all__ = ["RendererLike", "RendererRegistry", "renderer_from_config"]
__all__ = [
"RendererLike",
"RendererRegistry",
"renderer_from_config",
"ChatParams",
"TokenizeParams",
"merge_kwargs",
]

View File

@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from .params import ChatParams
from .protocol import RendererLike
logger = init_logger(__name__)
@@ -61,8 +62,8 @@ class DeepseekV32Renderer(RendererLike):
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
@@ -74,26 +75,22 @@ class DeepseekV32Renderer(RendererLike):
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**kwargs,
**params.get_apply_chat_template_kwargs(),
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
@@ -105,17 +102,13 @@ class DeepseekV32Renderer(RendererLike):
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**kwargs,
**params.get_apply_chat_template_kwargs(),
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt

View File

@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from io import BytesIO
from typing import TYPE_CHECKING
import pybase64
import torch
from vllm.exceptions import VLLMValidationError
if TYPE_CHECKING:
from vllm.config import ModelConfig
def safe_load_prompt_embeds(
model_config: "ModelConfig",
embed: bytes,
) -> torch.Tensor:
if not model_config.enable_prompt_embeds:
raise VLLMValidationError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`.",
parameter="prompt_embeds",
)
# Enable sparse tensor integrity checks to prevent out-of-bounds
# writes from maliciously crafted tensors
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.load(
BytesIO(pybase64.b64decode(embed, validate=True)),
weights_only=True,
map_location=torch.device("cpu"),
)
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2
return tensor

View File

@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer
from .params import ChatParams
from .protocol import RendererLike
logger = init_logger(__name__)
@@ -61,8 +62,8 @@ class Grok2Renderer(RendererLike):
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
@@ -74,26 +75,22 @@ class Grok2Renderer(RendererLike):
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**kwargs,
**params.get_apply_chat_template_kwargs(),
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
@@ -105,17 +102,13 @@ class Grok2Renderer(RendererLike):
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**kwargs,
**params.get_apply_chat_template_kwargs(),
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt

View File

@@ -25,7 +25,7 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
@@ -33,6 +33,7 @@ from vllm.transformers_utils.chat_templates import get_chat_template_fallback_pa
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.func_utils import supports_kw
from .params import ChatParams
from .protocol import RendererLike
if TYPE_CHECKING:
@@ -632,9 +633,8 @@ class HfRenderer(RendererLike):
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
model_config = self.config
tokenizer = self.get_tokenizer()
@@ -642,9 +642,9 @@ class HfRenderer(RendererLike):
messages,
model_config,
content_format=resolve_chat_template_content_format(
chat_template=kwargs.get("chat_template"),
tools=kwargs.get("tools"),
given_format=chat_template_content_format,
chat_template=params.chat_template,
tools=params.chat_template_kwargs.get("tools"),
given_format=params.chat_template_content_format,
tokenizer=tokenizer,
model_config=model_config,
),
@@ -654,7 +654,7 @@ class HfRenderer(RendererLike):
model_config,
tokenizer,
conversation,
**kwargs,
**params.get_apply_chat_template_kwargs(),
)
# NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
@@ -666,7 +666,7 @@ class HfRenderer(RendererLike):
):
mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)
# get video placehoder, replace it with runtime video-chunk prompts
# get video placeholder, replace it with runtime video-chunk prompts
video_placeholder = getattr(
model_config.hf_config, "video_placeholder", None
)
@@ -676,24 +676,19 @@ class HfRenderer(RendererLike):
video_placeholder,
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
model_config = self.config
tokenizer = self.get_tokenizer()
@@ -701,9 +696,9 @@ class HfRenderer(RendererLike):
messages,
model_config,
content_format=resolve_chat_template_content_format(
chat_template=kwargs.get("chat_template"),
tools=kwargs.get("tools"),
given_format=chat_template_content_format,
chat_template=params.chat_template,
tools=params.chat_template_kwargs.get("tools"),
given_format=params.chat_template_content_format,
tokenizer=tokenizer,
model_config=model_config,
),
@@ -713,7 +708,7 @@ class HfRenderer(RendererLike):
model_config,
tokenizer,
conversation,
**kwargs,
**params.get_apply_chat_template_kwargs(),
)
# NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
@@ -723,9 +718,7 @@ class HfRenderer(RendererLike):
and mm_uuids is not None
and mm_data is not None
):
mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)
# get video placehoder, replace it with runtime video-chunk prompts
# get video placeholder, replace it with runtime video-chunk prompts
video_placeholder = getattr(
model_config.hf_config, "video_placeholder", None
)
@@ -735,14 +728,10 @@ class HfRenderer(RendererLike):
video_placeholder,
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt

View File

@@ -10,12 +10,13 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async
from .params import ChatParams
from .protocol import RendererLike
logger = init_logger(__name__)
@@ -95,8 +96,8 @@ class MistralRenderer(RendererLike):
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
@@ -104,25 +105,25 @@ class MistralRenderer(RendererLike):
content_format="string",
)
prompt_raw = safe_apply_chat_template(tokenizer, messages, **kwargs)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
prompt_raw = safe_apply_chat_template(
tokenizer,
messages,
**params.get_apply_chat_template_kwargs(),
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
@@ -131,17 +132,15 @@ class MistralRenderer(RendererLike):
)
prompt_raw = await self._apply_chat_template_async(
tokenizer, messages, **kwargs
tokenizer,
messages,
**params.get_apply_chat_template_kwargs(),
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
prompt = self.render_completion(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
return conversation, prompt

351
vllm/renderers/params.py Normal file
View File

@@ -0,0 +1,351 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, TypeVar
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.exceptions import VLLMValidationError
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import LazyLoader
if TYPE_CHECKING:
import torch
else:
torch = LazyLoader("torch", globals(), "torch")
logger = init_logger(__name__)
_S = TypeVar("_S", list[int], "torch.Tensor")
def merge_kwargs(
defaults: dict[str, Any] | None,
overrides: dict[str, Any] | None,
/,
*,
unset_values: tuple[object, ...] = (None, "auto"),
) -> dict[str, Any]:
if defaults is None:
defaults = {}
if overrides is None:
overrides = {}
return defaults | {k: v for k, v in overrides.items() if v not in unset_values}
@dataclass(frozen=True)
class ChatParams:
"""Configuration to control how to parse chat messages."""
chat_template: str | None = None
"""The chat template to apply."""
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
"""The format of the chat template."""
chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
"""The kwargs to pass to the chat template."""
def with_defaults(self, default_chat_template_kwargs: dict[str, Any] | None):
if not default_chat_template_kwargs:
return self
return ChatParams(
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
chat_template_kwargs=merge_kwargs(
default_chat_template_kwargs,
self.chat_template_kwargs,
),
)
def get_apply_chat_template_kwargs(self) -> dict[str, Any]:
"""The arguments to pass to `tokenizer.apply_chat_template`."""
return merge_kwargs(
self.chat_template_kwargs,
dict(chat_template=self.chat_template),
)
@dataclass(frozen=True)
class TokenizeParams:
"""Configuration to control how prompts are tokenized."""
max_total_tokens: int | None
"""
Maximum allowed number of input + output tokens.
Usually, this refers to the model's context length.
"""
max_output_tokens: int = 0
"""Maximum requested number of output tokens."""
pad_prompt_tokens: int | None = None
"""
Number of tokens to pad to:
- `None` means no padding.
- `-1` maps to `max_input_tokens`.
"""
truncate_prompt_tokens: int | None = None
"""
Number of tokens to keep:
- `None` means no truncation.
- `-1` maps to `max_input_tokens`.
"""
do_lower_case: bool = False
"""Whether to normalize text to lower case before tokenization."""
add_special_tokens: bool = True
"""Whether to add special tokens."""
needs_detokenization: bool = False
"""
Whether the tokenized prompt needs to contain the original text.
Not to be confused with `SamplingParams.detokenize` which deals
with the output generated by the model.
"""
max_total_tokens_param: str = "max_total_tokens"
"""Override this to edit the message for validation errors."""
max_output_tokens_param: str = "max_output_tokens"
"""Override this to edit the message for validation errors."""
truncate_prompt_tokens_param: str = "truncate_prompt_tokens"
"""Override this to edit the message for validation errors."""
@property
def max_input_tokens(self) -> int | None:
"""Maximum allowed number of input tokens."""
if self.max_total_tokens is None:
return None
return self.max_total_tokens - self.max_output_tokens
def __post_init__(self) -> None:
max_total_tokens = self.max_total_tokens
max_output_tokens = self.max_output_tokens
max_input_tokens = self.max_input_tokens
truncate_prompt_tokens = self.truncate_prompt_tokens
if (
max_output_tokens is not None
and max_total_tokens is not None
and max_output_tokens > max_total_tokens
):
raise VLLMValidationError(
f"{self.max_output_tokens_param}={max_output_tokens}"
f"cannot be greater than "
f"{self.max_total_tokens_param}={max_total_tokens=}. "
f"Please request fewer output tokens.",
parameter=self.max_output_tokens_param,
value=max_output_tokens,
)
if (
max_input_tokens is not None
and truncate_prompt_tokens is not None
and truncate_prompt_tokens > max_input_tokens
):
raise VLLMValidationError(
f"{self.truncate_prompt_tokens_param}={truncate_prompt_tokens} "
f"cannot be greater than {self.max_total_tokens_param} - "
f"{self.max_output_tokens_param} = {max_input_tokens}. "
f"Please request a smaller truncation size.",
parameter=self.truncate_prompt_tokens_param,
value=truncate_prompt_tokens,
)
def with_kwargs(self, tokenization_kwargs: dict[str, Any] | None):
if tokenization_kwargs is None:
tokenization_kwargs = {}
max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens)
pad_prompt_tokens = tokenization_kwargs.pop(
"pad_prompt_tokens", self.pad_prompt_tokens
)
truncate_prompt_tokens = tokenization_kwargs.pop(
"truncate_prompt_tokens", self.truncate_prompt_tokens
)
do_lower_case = tokenization_kwargs.pop("do_lower_case", self.do_lower_case)
add_special_tokens = tokenization_kwargs.pop(
"add_special_tokens", self.add_special_tokens
)
needs_detokenization = tokenization_kwargs.pop(
"needs_detokenization", self.needs_detokenization
)
# https://huggingface.co/docs/transformers/en/pad_truncation
if padding := tokenization_kwargs.pop("padding", None):
if padding == "max_length":
pad_prompt_tokens = max_length
elif padding in (False, "do_not_pad"):
pad_prompt_tokens = None
else:
# To emit the below warning
tokenization_kwargs["padding"] = padding
if truncation := tokenization_kwargs.pop("truncation", None):
if truncation in (True, "longest_first"):
truncate_prompt_tokens = max_length
elif truncation in (False, "do_not_truncate"):
truncate_prompt_tokens = None
else:
# To emit the below warning
tokenization_kwargs["truncation"] = truncation
if tokenization_kwargs:
logger.warning(
"The following tokenization arguments are not supported "
"by vLLM Renderer and will be ignored: %s",
tokenization_kwargs,
)
max_total_tokens = self.max_total_tokens
return TokenizeParams(
max_total_tokens=max_total_tokens,
max_output_tokens=(
0
if max_total_tokens is None or max_length is None
else max_total_tokens - max_length
),
pad_prompt_tokens=pad_prompt_tokens,
truncate_prompt_tokens=truncate_prompt_tokens,
do_lower_case=do_lower_case,
add_special_tokens=add_special_tokens,
needs_detokenization=needs_detokenization,
)
def get_encode_kwargs(self) -> dict[str, Any]:
"""The arguments to pass to `tokenizer.encode`."""
max_length = self.truncate_prompt_tokens
if max_length is not None and max_length < 0:
max_length = self.max_input_tokens
return dict(
truncation=self.truncate_prompt_tokens is not None,
max_length=max_length,
add_special_tokens=self.add_special_tokens,
)
def _apply_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
if self.do_lower_case:
text = text.lower()
return text
def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str:
"""Apply all validators to prompt text."""
# TODO: Implement https://github.com/vllm-project/vllm/pull/31366
for validator in (self._apply_lowercase,):
text = validator(tokenizer, text)
return text
def apply_pre_tokenization(
self,
tokenizer: TokenizerLike | None,
prompt: TextPrompt,
) -> TextPrompt:
"""
Ensure that the prompt meets the requirements set out by this config.
If that is not possible, raise a `VLLMValidationError`.
This method is run before tokenization occurs.
"""
prompt["prompt"] = self._validate_text(tokenizer, prompt["prompt"])
return prompt
def _apply_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply padding to a token sequence."""
pad_length = self.pad_prompt_tokens
if pad_length is not None and pad_length < 0:
pad_length = self.max_input_tokens
if pad_length is None or pad_length <= len(tokens):
return tokens
if tokenizer is None:
raise ValueError("Cannot pad tokens when `skip_tokenizer_init=True`")
if not isinstance(tokens, list):
raise ValueError("Cannot pad tokens for embedding inputs")
return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens))
def _apply_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply truncation to a token sequence."""
max_length = self.truncate_prompt_tokens
if max_length is not None and max_length < 0:
max_length = self.max_input_tokens
if max_length is None or max_length >= len(tokens):
return tokens
if max_length == 0:
return tokens[:0]
if getattr(tokenizer, "truncation_side", "left") == "left":
return tokens[-max_length:]
return tokens[:max_length]
def _apply_length_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply length checks to a token sequence."""
max_input_tokens = self.max_input_tokens
if max_input_tokens is not None and len(tokens) > max_input_tokens:
raise VLLMValidationError(
f"You passed {len(tokens)} input tokens and "
f"requested {self.max_output_tokens} output tokens. "
f"However, the model's context length is only "
f"{self.max_total_tokens}, resulting in a maximum "
f"input length of {max_input_tokens}. "
f"Please reduce the length of the input messages.",
parameter="input_tokens",
value=len(tokens),
)
return tokens
def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
"""Apply all validators to a token sequence."""
for validator in (
self._apply_padding,
self._apply_truncation,
self._apply_length_check,
):
tokens = validator(tokenizer, tokens)
return tokens
def apply_post_tokenization(
self,
tokenizer: TokenizerLike | None,
prompt: TokensPrompt | EmbedsPrompt,
) -> TokensPrompt | EmbedsPrompt:
"""
Ensure that the prompt meets the requirements set out by this config.
If that is not possible, raise a `VLLMValidationError`.
This method is run after tokenization occurs.
"""
if "prompt_token_ids" in prompt:
prompt["prompt_token_ids"] = self._validate_tokens( # type: ignore[typeddict-unknown-key]
tokenizer,
prompt["prompt_token_ids"], # type: ignore[typeddict-item]
)
if "prompt_embeds" in prompt:
prompt["prompt_embeds"] = self._validate_tokens( # type: ignore[typeddict-unknown-key]
tokenizer,
prompt["prompt_embeds"], # type: ignore[typeddict-item]
)
return prompt

View File

@@ -1,9 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from typing import TYPE_CHECKING, Any, Protocol
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.collection_utils import is_list_of
from .embed_utils import safe_load_prompt_embeds
from .params import ChatParams, TokenizeParams
if TYPE_CHECKING:
from vllm.config import ModelConfig
@@ -14,6 +20,9 @@ if TYPE_CHECKING:
class RendererLike(Protocol):
config: "ModelConfig"
_async_tokenizer: AsyncMicrobatchTokenizer
@classmethod
def from_config(
cls,
@@ -33,16 +42,147 @@ class RendererLike(Protocol):
return tokenizer
def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
# Lazy initialization since offline LLM doesn't use async
if not hasattr(self, "_async_tokenizer"):
self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())
return self._async_tokenizer
# Step 1: Convert raw inputs to prompts
def render_completion(
self,
prompt_raw: str | list[int] | bytes,
) -> TextPrompt | TokensPrompt | EmbedsPrompt:
error_msg = "Each prompt must be a string or an array of tokens"
if isinstance(prompt_raw, str):
return TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, list):
if not is_list_of(prompt_raw, int):
raise TypeError(error_msg)
return TokensPrompt(prompt_token_ids=prompt_raw)
if isinstance(prompt_raw, bytes):
embeds = safe_load_prompt_embeds(self.config, prompt_raw)
return EmbedsPrompt(prompt_embeds=embeds)
raise TypeError(error_msg)
def render_completions(
self,
prompt_input: str | list[str] | list[int] | list[list[int]] | None = None,
prompt_embeds: bytes | list[bytes] | None = None,
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
prompts_raw = list[str | list[int] | bytes]()
if prompt_embeds is not None: # embeds take higher priority
if isinstance(prompt_embeds, bytes):
prompts_raw.append(prompt_embeds)
else:
prompts_raw.extend(prompt_embeds)
if prompt_input is not None:
if isinstance(prompt_input, str) or (
len(prompt_input) > 0 and is_list_of(prompt_input, int)
):
prompts_raw.append(prompt_input) # type: ignore[arg-type]
else:
prompts_raw.extend(prompt_input) # type: ignore[arg-type]
if len(prompts_raw) == 0:
raise ValueError("You must pass at least one prompt")
return [self.render_completion(prompt) for prompt in prompts_raw]
async def render_completions_async(
self,
prompt_input: str | list[str] | list[int] | list[list[int]] | None = None,
prompt_embeds: bytes | list[bytes] | None = None,
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
return self.render_completions(prompt_input, prompt_embeds)
def render_messages(
self,
messages: list["ChatCompletionMessageParam"],
**kwargs,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]:
raise NotImplementedError
async def render_messages_async(
self,
messages: list["ChatCompletionMessageParam"],
**kwargs,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]:
return self.render_messages(messages, **kwargs)
params: ChatParams,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]:
return self.render_messages(messages, params)
# Step 2: Tokenize prompts if necessary
def tokenize_prompt(
self,
prompt: TextPrompt | TokensPrompt | EmbedsPrompt,
params: TokenizeParams,
) -> TokensPrompt | EmbedsPrompt:
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
tokenizer = self.get_tokenizer()
prompt_token_ids = tokenizer.encode(
prompt["prompt"],
**params.get_encode_kwargs(),
)
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
tokenizer = self.get_tokenizer()
prompt_text = tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
def tokenize_prompts(
self,
prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt],
params: TokenizeParams,
) -> list[TokensPrompt | EmbedsPrompt]:
return [self.tokenize_prompt(prompt, params) for prompt in prompts]
async def tokenize_prompt_async(
self,
prompt: TextPrompt | TokensPrompt | EmbedsPrompt,
params: TokenizeParams,
) -> TokensPrompt | EmbedsPrompt:
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
tokenizer = self.get_async_tokenizer()
prompt_token_ids = await tokenizer.encode(
prompt["prompt"],
**params.get_encode_kwargs(),
)
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
tokenizer = self.get_async_tokenizer()
prompt_text = await tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
async def tokenize_prompts_async(
self,
prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt],
params: TokenizeParams,
) -> list[TokensPrompt | EmbedsPrompt]:
return await asyncio.gather(
*(self.tokenize_prompt_async(prompt, params) for prompt in prompts)
)

View File

@@ -9,10 +9,11 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from .params import ChatParams
from .protocol import RendererLike
logger = init_logger(__name__)
@@ -45,8 +46,8 @@ class TerratorchRenderer(RendererLike):
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
model_config = self.config
conversation, mm_data, mm_uuids = parse_chat_messages(
@@ -55,7 +56,7 @@ class TerratorchRenderer(RendererLike):
content_format="string",
)
prompt = TokensPrompt(prompt_token_ids=[1])
prompt = self.render_completion([1]) # Dummy token IDs
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
@@ -66,8 +67,8 @@ class TerratorchRenderer(RendererLike):
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
model_config = self.config
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
@@ -76,7 +77,7 @@ class TerratorchRenderer(RendererLike):
content_format="string",
)
prompt = TokensPrompt(prompt_token_ids=[1]) # Dummy token IDs
prompt = self.render_completion([1]) # Dummy token IDs
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:

View File

@@ -834,7 +834,7 @@ def parse_pooling_type(pooling_name: str):
@cache
def get_sentence_transformer_tokenizer_config(
model: str | Path, revision: str | None = "main"
):
) -> dict[str, Any] | None:
"""
Returns the tokenization configuration dictionary for a
given Sentence Transformer BERT model.

View File

@@ -50,14 +50,17 @@ class AsyncMicrobatchTokenizer:
self._executor = ThreadPoolExecutor(max_workers=1)
# === Public async API ===
async def __call__(self, prompt, **kwargs):
async def __call__(self, prompt, **kwargs) -> BatchEncoding:
result_future: Future = self._loop.create_future()
key = self._queue_key("encode", kwargs)
queue = self._get_queue(self._loop, key)
await queue.put((prompt, kwargs, result_future))
return await result_future
async def decode(self, token_ids, **kwargs):
async def encode(self, prompt, **kwargs) -> list[int]:
return (await self(prompt, **kwargs)).input_ids
async def decode(self, token_ids, **kwargs) -> str:
result_future: Future = self._loop.create_future()
key = self._queue_key("decode", kwargs)
queue = self._get_queue(self._loop, key)

View File

@@ -16,7 +16,6 @@ from vllm import TokensPrompt
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType
from vllm.inputs.data import StreamingInput
from vllm.logger import init_logger
@@ -25,7 +24,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.renderers import RendererLike, merge_kwargs
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
@@ -304,13 +303,20 @@ class AsyncLLM(EngineClient):
"prompt logprobs"
)
if tokenization_kwargs is None:
tokenization_kwargs = {}
_validate_truncation_size(
self.model_config.max_model_len,
params.truncate_prompt_tokens,
tokenization_kwargs,
)
if params.truncate_prompt_tokens is not None:
params_type = type(params).__name__
warnings.warn(
f"The `truncate_prompt_tokens` parameter in `{params_type}` "
"is deprecated and will be removed in v0.16. "
"Please pass it via `tokenization_kwargs` instead.",
DeprecationWarning,
stacklevel=2,
)
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
)
if isinstance(prompt, AsyncGenerator):
# Streaming input case.
@@ -344,12 +350,12 @@ class AsyncLLM(EngineClient):
request_id,
prompt,
params,
arrival_time,
lora_request,
tokenization_kwargs,
trace_headers,
priority,
data_parallel_rank,
arrival_time=arrival_time,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
prompt_text = get_prompt_text(prompt)
@@ -757,7 +763,6 @@ class AsyncLLM(EngineClient):
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
priority: int = 0,
truncate_prompt_tokens: int | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]:
"""
@@ -772,22 +777,10 @@ class AsyncLLM(EngineClient):
The caller of generate() iterates the returned AsyncGenerator,
returning the RequestOutput back to the caller.
NOTE: truncate_prompt_tokens is deprecated in v0.14.
TODO: Remove truncate_prompt_tokens in v0.15.
"""
q: RequestOutputCollector | None = None
try:
if truncate_prompt_tokens is not None:
warnings.warn(
"The `truncate_prompt_tokens` parameter in `AsyncLLM.encode()` "
"is deprecated and will be removed in v0.15. "
"Please use `pooling_params.truncate_prompt_tokens` instead.",
DeprecationWarning,
stacklevel=2,
)
q = await self.add_request(
request_id,
prompt,