[Frontend] Use new Renderer for Completions and Tokenize API (#32863)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -60,9 +60,7 @@ def main():
|
||||
|
||||
completion = client.completions.create(
|
||||
model=model_name,
|
||||
# NOTE: The OpenAI client does not allow `None` as an input to
|
||||
# `prompt`. Use an empty string if you have no text prompts.
|
||||
prompt="",
|
||||
prompt=None,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
# NOTE: The OpenAI client allows passing in extra JSON body via the
|
||||
|
||||
@@ -22,7 +22,11 @@ def test_context_length_too_short(vllm_runner, image_assets, model):
|
||||
with pytest.raises(ValueError, match="longer than the maximum model length"):
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
max_model_len=128, # LLaVA has a feature size of 576
|
||||
# LLaVA has a feature size of 576
|
||||
# For the HF processor to execute successfully but still
|
||||
# failing the overall context length check, we need the
|
||||
# max_model_len to at least contain all image tokens
|
||||
max_model_len=579,
|
||||
enforce_eager=True,
|
||||
load_format="dummy",
|
||||
)
|
||||
|
||||
@@ -205,7 +205,7 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test):
|
||||
valid_msg,
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=10)
|
||||
with pytest.raises(ValueError, match="longer than the maximum model length"):
|
||||
with pytest.raises(ValueError, match="context length is only"):
|
||||
llm.chat(batch_1, sampling_params=sampling_params)
|
||||
outputs_2 = llm.chat(batch_2, sampling_params=sampling_params)
|
||||
assert len(outputs_2) == len(batch_2)
|
||||
|
||||
@@ -15,7 +15,8 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.renderers.hf import HfRenderer
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
@@ -57,6 +58,15 @@ class MockModelConfig:
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
|
||||
def _build_renderer(model_config: MockModelConfig):
|
||||
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
|
||||
|
||||
return HfRenderer(
|
||||
model_config,
|
||||
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
|
||||
)
|
||||
|
||||
|
||||
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
|
||||
models = OpenAIServingModels(
|
||||
engine_client=engine,
|
||||
@@ -71,18 +81,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
|
||||
chat_template_content_format="auto",
|
||||
)
|
||||
|
||||
async def _fake_process_inputs(
|
||||
request_id,
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
*,
|
||||
lora_request,
|
||||
trace_headers,
|
||||
priority,
|
||||
data_parallel_rank,
|
||||
):
|
||||
return dict(engine_prompt), {}
|
||||
|
||||
async def _fake_preprocess_chat(*args, **kwargs):
|
||||
# return conversation, engine_prompts
|
||||
return (
|
||||
@@ -90,7 +88,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
|
||||
[{"prompt_token_ids": [1, 2, 3]}],
|
||||
)
|
||||
|
||||
serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
|
||||
serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat)
|
||||
return serving_chat
|
||||
|
||||
@@ -99,11 +96,11 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
|
||||
async def test_chat_error_non_stream():
|
||||
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.input_processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
mock_engine.renderer = _build_renderer(mock_engine.model_config)
|
||||
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
@@ -153,11 +150,11 @@ async def test_chat_error_non_stream():
|
||||
async def test_chat_error_stream():
|
||||
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.input_processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
mock_engine.renderer = _build_renderer(mock_engine.model_config)
|
||||
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -15,7 +15,8 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.renderers.hf import HfRenderer
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
@@ -61,37 +62,31 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
|
||||
engine_client=engine,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
)
|
||||
serving_completion = OpenAIServingCompletion(
|
||||
return OpenAIServingCompletion(
|
||||
engine,
|
||||
models,
|
||||
request_logger=None,
|
||||
)
|
||||
|
||||
async def _fake_process_inputs(
|
||||
request_id,
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
*,
|
||||
lora_request,
|
||||
trace_headers,
|
||||
priority,
|
||||
data_parallel_rank,
|
||||
):
|
||||
return dict(engine_prompt), {}
|
||||
|
||||
serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
|
||||
return serving_completion
|
||||
def _build_renderer(model_config: MockModelConfig):
|
||||
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
|
||||
|
||||
return HfRenderer(
|
||||
model_config,
|
||||
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_completion_error_non_stream():
|
||||
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.input_processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
mock_engine.renderer = _build_renderer(mock_engine.model_config)
|
||||
|
||||
serving_completion = _build_serving_completion(mock_engine)
|
||||
|
||||
@@ -141,11 +136,11 @@ async def test_completion_error_non_stream():
|
||||
async def test_completion_error_stream():
|
||||
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.input_processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
mock_engine.renderer = _build_renderer(mock_engine.model_config)
|
||||
|
||||
serving_completion = _build_serving_completion(mock_engine)
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ async def test_completions_with_prompt_embeds(
|
||||
# Test case: Single prompt embeds input
|
||||
completion = await client_with_prompt_embeds.completions.create(
|
||||
model=model_name,
|
||||
prompt="", # Add empty prompt as required parameter
|
||||
prompt=None,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": encoded_embeds},
|
||||
@@ -121,7 +121,7 @@ async def test_completions_with_prompt_embeds(
|
||||
# Test case: batch completion with prompt_embeds
|
||||
completion = await client_with_prompt_embeds.completions.create(
|
||||
model=model_name,
|
||||
prompt="", # Add empty prompt as required parameter
|
||||
prompt=None,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]},
|
||||
@@ -133,7 +133,7 @@ async def test_completions_with_prompt_embeds(
|
||||
# Test case: streaming with prompt_embeds
|
||||
single_completion = await client_with_prompt_embeds.completions.create(
|
||||
model=model_name,
|
||||
prompt="", # Add empty prompt as required parameter
|
||||
prompt=None,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": encoded_embeds},
|
||||
@@ -142,7 +142,7 @@ async def test_completions_with_prompt_embeds(
|
||||
|
||||
stream = await client_with_prompt_embeds.completions.create(
|
||||
model=model_name,
|
||||
prompt="", # Add empty prompt as required parameter
|
||||
prompt=None,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
@@ -162,7 +162,7 @@ async def test_completions_with_prompt_embeds(
|
||||
# Test case: batch streaming with prompt_embeds
|
||||
stream = await client_with_prompt_embeds.completions.create(
|
||||
model=model_name,
|
||||
prompt="", # Add empty prompt as required parameter
|
||||
prompt=None,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
@@ -197,7 +197,7 @@ async def test_completions_with_prompt_embeds(
|
||||
)
|
||||
completion_embeds_only = await client_with_prompt_embeds.completions.create(
|
||||
model=model_name,
|
||||
prompt="",
|
||||
prompt=None,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": encoded_embeds},
|
||||
@@ -215,7 +215,7 @@ async def test_completions_errors_with_prompt_embeds(
|
||||
# Test error case: invalid prompt_embeds
|
||||
with pytest.raises(BadRequestError):
|
||||
await client_with_prompt_embeds.completions.create(
|
||||
prompt="",
|
||||
prompt=None,
|
||||
model=model_name,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
@@ -237,7 +237,7 @@ async def test_completions_with_logprobs_and_prompt_embeds(
|
||||
# Test case: Logprobs using prompt_embeds
|
||||
completion = await client_with_prompt_embeds.completions.create(
|
||||
model=model_name,
|
||||
prompt="", # Add empty prompt as required parameter
|
||||
prompt=None,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
echo=False,
|
||||
@@ -257,7 +257,7 @@ async def test_completions_with_logprobs_and_prompt_embeds(
|
||||
# Test case: Log probs with batch completion and prompt_embeds
|
||||
completion = await client_with_prompt_embeds.completions.create(
|
||||
model=model_name,
|
||||
prompt="", # Add empty prompt as required parameter
|
||||
prompt=None,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
echo=False,
|
||||
@@ -287,7 +287,7 @@ async def test_prompt_logprobs_raises_error(
|
||||
with pytest.raises(BadRequestError, match="not compatible"):
|
||||
await client_with_prompt_embeds.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt="",
|
||||
prompt=None,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": encoded_embeds, "prompt_logprobs": True},
|
||||
|
||||
@@ -7,7 +7,7 @@ Tests verify that embeddings with correct ndim but incorrect hidden_size
|
||||
are rejected before they can cause crashes during model inference.
|
||||
|
||||
Validation is performed by the parser (MultiModalDataParser) and EmbeddingItems
|
||||
classes, not by CompletionRenderer or MediaIO classes.
|
||||
classes, not by MediaIO classes.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.renderers.hf import HfRenderer
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
@@ -35,6 +36,7 @@ class MockModelConfig:
|
||||
"""Minimal mock ModelConfig for testing."""
|
||||
|
||||
model: str = MODEL_NAME
|
||||
runner_type = "generate"
|
||||
tokenizer: str = MODEL_NAME
|
||||
trust_remote_code: bool = False
|
||||
tokenizer_mode: str = "auto"
|
||||
@@ -85,15 +87,21 @@ def register_mock_resolver():
|
||||
del LoRAResolverRegistry.resolvers[MOCK_RESOLVER_NAME]
|
||||
|
||||
|
||||
def _build_renderer(model_config: MockModelConfig):
|
||||
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
|
||||
|
||||
return HfRenderer(
|
||||
model_config,
|
||||
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_serving_setup():
|
||||
"""Provides a mocked engine and serving completion instance."""
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.errored = False
|
||||
|
||||
tokenizer = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.get_tokenizer = AsyncMock(return_value=tokenizer)
|
||||
|
||||
async def mock_add_lora_side_effect(lora_request: LoRARequest):
|
||||
"""Simulate engine behavior when adding LoRAs."""
|
||||
if lora_request.lora_name == "test-lora":
|
||||
@@ -118,6 +126,7 @@ def mock_serving_setup():
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.input_processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
mock_engine.renderer = _build_renderer(mock_engine.model_config)
|
||||
|
||||
models = OpenAIServingModels(
|
||||
engine_client=mock_engine,
|
||||
@@ -128,10 +137,6 @@ def mock_serving_setup():
|
||||
mock_engine, models, request_logger=None
|
||||
)
|
||||
|
||||
serving_completion._process_inputs = AsyncMock(
|
||||
return_value=(MagicMock(name="engine_request"), {})
|
||||
)
|
||||
|
||||
return mock_engine, serving_completion
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import regex as re
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.renderer import CompletionRenderer
|
||||
from vllm.renderers.embed_utils import safe_load_prompt_embeds
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
@@ -30,7 +30,7 @@ async def test_empty_prompt():
|
||||
):
|
||||
await client.completions.create(
|
||||
model=model_name,
|
||||
prompt="",
|
||||
prompt=None,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body={"prompt_embeds": []},
|
||||
@@ -63,7 +63,6 @@ def test_load_prompt_embeds(
|
||||
):
|
||||
model_config = Mock(spec=ModelConfig)
|
||||
model_config.enable_prompt_embeds = True
|
||||
renderer = CompletionRenderer(model_config, tokenizer=None)
|
||||
|
||||
# construct arbitrary tensors of various dtypes, layouts, and sizes.
|
||||
# We need to check against different layouts to make sure that if a user
|
||||
@@ -89,9 +88,7 @@ def test_load_prompt_embeds(
|
||||
buffer.seek(0)
|
||||
encoded_tensor = pybase64.b64encode(buffer.getvalue())
|
||||
|
||||
loaded_prompt_embeds = renderer.load_prompt_embeds(encoded_tensor)
|
||||
assert len(loaded_prompt_embeds) == 1
|
||||
loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
|
||||
loaded_tensor = safe_load_prompt_embeds(model_config, encoded_tensor)
|
||||
assert loaded_tensor.device.type == "cpu"
|
||||
assert loaded_tensor.layout == torch.strided
|
||||
torch.testing.assert_close(
|
||||
@@ -105,7 +102,6 @@ def test_load_prompt_embeds(
|
||||
def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: int):
|
||||
model_config = Mock(spec=ModelConfig)
|
||||
model_config.enable_prompt_embeds = False
|
||||
renderer = CompletionRenderer(model_config, tokenizer=None)
|
||||
|
||||
tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
|
||||
|
||||
@@ -115,4 +111,4 @@ def test_disable_prompt_embeds(dtype: torch.dtype, seq_len: int, hidden_size: in
|
||||
encoded_tensor = pybase64.b64encode(buffer.getvalue())
|
||||
|
||||
with pytest.raises(ValueError, match="--enable-prompt-embeds"):
|
||||
renderer.load_prompt_embeds(encoded_tensor)
|
||||
safe_load_prompt_embeds(model_config, encoded_tensor)
|
||||
|
||||
@@ -556,19 +556,6 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
|
||||
request_logger=None,
|
||||
)
|
||||
|
||||
async def _fake_process_inputs(
|
||||
request_id,
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
*,
|
||||
lora_request,
|
||||
trace_headers,
|
||||
priority,
|
||||
data_parallel_rank,
|
||||
):
|
||||
return dict(engine_prompt), {}
|
||||
|
||||
serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
|
||||
return serving_chat
|
||||
|
||||
|
||||
@@ -784,7 +771,7 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
|
||||
|
||||
resp = await serving_chat.create_chat_completion(req)
|
||||
assert isinstance(resp, ErrorResponse)
|
||||
assert "max_tokens" in resp.error.message
|
||||
assert "context length is only" in resp.error.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -824,7 +811,7 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
|
||||
|
||||
resp = await serving_chat.create_chat_completion(req)
|
||||
assert isinstance(resp, ErrorResponse)
|
||||
assert "maximum context length" in resp.error.message
|
||||
assert "context length is only" in resp.error.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -890,6 +877,20 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
|
||||
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
orig_render_chat_request = serving_chat.render_chat_request
|
||||
captured_prompts = []
|
||||
|
||||
async def render_chat_request(request):
|
||||
result = await orig_render_chat_request(request)
|
||||
|
||||
assert isinstance(result, tuple)
|
||||
conversation, engine_prompts = result
|
||||
captured_prompts.extend(engine_prompts)
|
||||
|
||||
return result
|
||||
|
||||
serving_chat.render_chat_request = render_chat_request
|
||||
|
||||
# Test cache_salt
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
@@ -899,15 +900,19 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
|
||||
# By default, cache_salt in the engine prompt is not set
|
||||
with suppress(Exception):
|
||||
await serving_chat.create_chat_completion(req)
|
||||
engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
|
||||
assert "cache_salt" not in engine_prompt
|
||||
|
||||
assert len(captured_prompts) == 1
|
||||
assert "cache_salt" not in captured_prompts[0]
|
||||
|
||||
captured_prompts.clear()
|
||||
|
||||
# Test with certain cache_salt
|
||||
req.cache_salt = "test_salt"
|
||||
with suppress(Exception):
|
||||
await serving_chat.create_chat_completion(req)
|
||||
engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
|
||||
assert engine_prompt.get("cache_salt") == "test_salt"
|
||||
|
||||
assert len(captured_prompts) == 1
|
||||
assert captured_prompts[0]["cache_salt"] == "test_salt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -1007,11 +1012,11 @@ class TestServingChatWithHarmony:
|
||||
@pytest.fixture()
|
||||
def mock_engine(self) -> AsyncLLM:
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.input_processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
mock_engine.renderer = _build_renderer(mock_engine.model_config)
|
||||
return mock_engine
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -1618,11 +1623,11 @@ async def test_tool_choice_validation_without_parser():
|
||||
"""Test that tool_choice='required' or named tool without tool_parser
|
||||
returns an appropriate error message."""
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.input_processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
mock_engine.renderer = _build_renderer(mock_engine.model_config)
|
||||
|
||||
models = OpenAIServingModels(
|
||||
engine_client=mock_engine,
|
||||
|
||||
@@ -67,20 +67,6 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI):
|
||||
assert response["usage"]["prompt_tokens"] == truncation_size
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_truncation_size(client: openai.AsyncOpenAI):
|
||||
truncation_size = 0
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": MODEL_NAME,
|
||||
"input": input,
|
||||
"truncate_prompt_tokens": truncation_size,
|
||||
}
|
||||
|
||||
response = await client.post(path="embeddings", cast_to=object, body={**kwargs})
|
||||
|
||||
assert response["usage"]["prompt_tokens"] == truncation_size
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bigger_truncation_size(client: openai.AsyncOpenAI):
|
||||
truncation_size = max_model_len + 1
|
||||
|
||||
@@ -128,12 +128,10 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
|
||||
server.url_for("classify"),
|
||||
json={"model": model_name, "input": []},
|
||||
)
|
||||
classification_response.raise_for_status()
|
||||
output = ClassificationResponse.model_validate(classification_response.json())
|
||||
|
||||
assert output.object == "list"
|
||||
assert isinstance(output.data, list)
|
||||
assert len(output.data) == 0
|
||||
error = classification_response.json()
|
||||
assert classification_response.status_code == 400
|
||||
assert "error" in error
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
|
||||
@@ -247,7 +247,7 @@ class TestModel:
|
||||
},
|
||||
)
|
||||
assert score_response.status_code == 400
|
||||
assert "Please, select a smaller truncation size." in score_response.text
|
||||
assert "Please request a smaller truncation size." in score_response.text
|
||||
|
||||
def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, Any]):
|
||||
queries = "What is the capital of France?"
|
||||
|
||||
@@ -1,325 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pybase64
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig
|
||||
from vllm.inputs.data import is_embeds_prompt
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockModelConfig:
|
||||
max_model_len: int = 100
|
||||
encoder_config: dict | None = None
|
||||
enable_prompt_embeds: bool = True
|
||||
|
||||
|
||||
class MockTokenizerResult:
|
||||
def __init__(self, input_ids):
|
||||
self.input_ids = input_ids
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_config():
|
||||
return MockModelConfig()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tokenizer():
|
||||
tokenizer = MagicMock()
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_tokenizer():
|
||||
async_tokenizer = AsyncMock()
|
||||
return async_tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def renderer(mock_model_config, mock_tokenizer):
|
||||
return CompletionRenderer(
|
||||
model_config=mock_model_config,
|
||||
tokenizer=mock_tokenizer,
|
||||
async_tokenizer_pool={},
|
||||
)
|
||||
|
||||
|
||||
class TestRenderPrompt:
|
||||
"""Test Category A: Basic Functionality Tests"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_input(self, renderer):
|
||||
tokens = [101, 7592, 2088]
|
||||
results = await renderer.render_prompt(
|
||||
prompt_or_prompts=tokens, config=RenderConfig(max_length=100)
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["prompt_token_ids"] == tokens
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_list_input(self, renderer):
|
||||
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
|
||||
results = await renderer.render_prompt(
|
||||
prompt_or_prompts=token_lists, config=RenderConfig(max_length=100)
|
||||
)
|
||||
|
||||
assert len(results) == 3
|
||||
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
|
||||
assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
|
||||
assert results[2]["prompt_token_ids"] == [103, 4567]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_input(self, renderer, mock_async_tokenizer):
|
||||
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
|
||||
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
|
||||
|
||||
results = await renderer.render_prompt(
|
||||
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
|
||||
mock_async_tokenizer.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_list_input(self, renderer, mock_async_tokenizer):
|
||||
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
|
||||
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
|
||||
|
||||
text_list_input = ["Hello world", "How are you?", "Good morning"]
|
||||
results = await renderer.render_prompt(
|
||||
prompt_or_prompts=text_list_input, config=RenderConfig(max_length=100)
|
||||
)
|
||||
|
||||
assert len(results) == 3
|
||||
for result in results:
|
||||
assert result["prompt_token_ids"] == [101, 7592, 2088]
|
||||
assert mock_async_tokenizer.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_truncation(self, renderer, mock_async_tokenizer):
|
||||
mock_async_tokenizer.return_value = MockTokenizerResult([101, 7592, 2088])
|
||||
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
|
||||
|
||||
results = await renderer.render_prompt(
|
||||
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
call_args = mock_async_tokenizer.call_args
|
||||
assert (
|
||||
"truncation" not in call_args.kwargs
|
||||
or call_args.kwargs["truncation"] is False
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncation_positive(self, renderer, mock_async_tokenizer):
|
||||
mock_async_tokenizer.return_value = MockTokenizerResult(
|
||||
[101, 7592, 2088]
|
||||
) # Truncated
|
||||
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
|
||||
|
||||
results = await renderer.render_prompt(
|
||||
prompt_or_prompts="Hello world",
|
||||
config=RenderConfig(max_length=100, truncate_prompt_tokens=50),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
call_args = mock_async_tokenizer.call_args
|
||||
assert call_args.kwargs["truncation"] is True
|
||||
assert call_args.kwargs["max_length"] == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncation_negative(self, renderer, mock_async_tokenizer):
|
||||
# Test that negative truncation uses model's max_model_len
|
||||
mock_async_tokenizer.return_value = MockTokenizerResult(
|
||||
[101, 7592, 2088]
|
||||
) # Truncated to max_model_len
|
||||
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
|
||||
|
||||
results = await renderer.render_prompt(
|
||||
prompt_or_prompts="Hello world",
|
||||
config=RenderConfig(max_length=200, truncate_prompt_tokens=-1),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
call_args = mock_async_tokenizer.call_args
|
||||
assert call_args.kwargs["truncation"] is True
|
||||
assert call_args.kwargs["max_length"] == 100 # model's max_model_len
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_truncation_last_elements(self, renderer):
|
||||
# Test that token truncation keeps the last N elements
|
||||
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
|
||||
results = await renderer.render_prompt(
|
||||
prompt_or_prompts=long_tokens,
|
||||
config=RenderConfig(max_length=100, truncate_prompt_tokens=5),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
# Should keep the last 5 tokens: [105, 106, 107, 108, 109]
|
||||
assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_length_exceeded(self, renderer):
|
||||
long_tokens = list(range(150)) # Exceeds max_model_len=100
|
||||
|
||||
with pytest.raises(ValueError, match="maximum context length"):
|
||||
await renderer.render_prompt(
|
||||
prompt_or_prompts=long_tokens, config=RenderConfig(max_length=100)
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tokenizer_for_text(self, mock_model_config):
|
||||
renderer_no_tokenizer = CompletionRenderer(
|
||||
model_config=mock_model_config, tokenizer=None, async_tokenizer_pool={}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="No tokenizer available"):
|
||||
await renderer_no_tokenizer.render_prompt(
|
||||
prompt_or_prompts="Hello world", config=RenderConfig(max_length=100)
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_input_with_needs_detokenization(
|
||||
self, renderer, mock_async_tokenizer
|
||||
):
|
||||
# When needs_detokenization=True for token inputs, renderer should
|
||||
# use the async tokenizer to decode and include the original text
|
||||
# in the returned prompt object.
|
||||
mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
|
||||
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
|
||||
|
||||
tokens = [1, 2, 3, 4]
|
||||
results = await renderer.render_prompt(
|
||||
prompt_or_prompts=tokens,
|
||||
config=RenderConfig(needs_detokenization=True),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["prompt_token_ids"] == tokens
|
||||
assert results[0]["prompt"] == "decoded text"
|
||||
mock_async_tokenizer.decode.assert_awaited_once()
|
||||
|
||||
|
||||
class TestRenderEmbedPrompt:
|
||||
def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
|
||||
"""Helper to create base64-encoded tensor bytes"""
|
||||
buffer = io.BytesIO()
|
||||
torch.save(tensor, buffer)
|
||||
buffer.seek(0)
|
||||
return pybase64.b64encode(buffer.read())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_prompt_embed(self, renderer):
|
||||
# Create a test tensor
|
||||
test_tensor = torch.randn(10, 768, dtype=torch.float32)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
results = await renderer.render_prompt_and_embeds(
|
||||
prompt_embeds=embed_bytes,
|
||||
config=RenderConfig(cache_salt="test_salt"),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert is_embeds_prompt(results[0])
|
||||
assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
|
||||
assert results[0]["cache_salt"] == "test_salt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_prompt_embeds(self, renderer):
|
||||
# Create multiple test tensors
|
||||
test_tensors = [
|
||||
torch.randn(8, 512, dtype=torch.float32),
|
||||
torch.randn(12, 512, dtype=torch.float32),
|
||||
]
|
||||
embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors]
|
||||
|
||||
results = await renderer.render_prompt_and_embeds(
|
||||
prompt_embeds=embed_bytes_list,
|
||||
config=RenderConfig(),
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
for i, result in enumerate(results):
|
||||
assert is_embeds_prompt(result)
|
||||
assert torch.allclose(result["prompt_embeds"], test_tensors[i])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_embed_truncation(self, renderer):
|
||||
# Create tensor with more tokens than truncation limit
|
||||
test_tensor = torch.randn(20, 768, dtype=torch.float32)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
results = await renderer.render_prompt_and_embeds(
|
||||
prompt_embeds=embed_bytes,
|
||||
config=RenderConfig(truncate_prompt_tokens=10),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
# Should keep last 10 tokens
|
||||
expected = test_tensor[-10:]
|
||||
assert torch.allclose(results[0]["prompt_embeds"], expected)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_embed_different_dtypes(self, renderer):
|
||||
# Test different supported dtypes
|
||||
dtypes = [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
for dtype in dtypes:
|
||||
test_tensor = torch.randn(5, 256, dtype=dtype)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
results = await renderer.render_prompt_and_embeds(
|
||||
prompt_embeds=embed_bytes,
|
||||
config=RenderConfig(),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["prompt_embeds"].dtype == dtype
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_embed_squeeze_batch_dim(self, renderer):
|
||||
# Test tensor with batch dimension gets squeezed
|
||||
test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
results = await renderer.render_prompt_and_embeds(
|
||||
prompt_embeds=embed_bytes,
|
||||
config=RenderConfig(),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
# Should be squeezed to 2D
|
||||
assert results[0]["prompt_embeds"].shape == (10, 768)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer):
|
||||
# Set up text tokenization
|
||||
mock_async_tokenizer.return_value = MockTokenizerResult([101, 102, 103])
|
||||
renderer.async_tokenizer_pool[renderer.tokenizer] = mock_async_tokenizer
|
||||
|
||||
# Create embed
|
||||
test_tensor = torch.randn(5, 256, dtype=torch.float32)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
results = await renderer.render_prompt_and_embeds(
|
||||
prompt_or_prompts="Hello world",
|
||||
prompt_embeds=embed_bytes,
|
||||
config=RenderConfig(),
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
# First should be embed prompt
|
||||
assert is_embeds_prompt(results[0])
|
||||
# Second should be tokens prompt
|
||||
assert "prompt_token_ids" in results[1]
|
||||
assert results[1]["prompt_token_ids"] == [101, 102, 103]
|
||||
@@ -96,7 +96,7 @@ def test_gemma_multimodal(
|
||||
dtype="bfloat16",
|
||||
) as vllm_model:
|
||||
llm = vllm_model.get_llm()
|
||||
prompts = llm.preprocess_chat(messages)
|
||||
prompts = llm._preprocess_chat([messages])
|
||||
|
||||
result = llm.classify(prompts)
|
||||
assert result[0].outputs.probs[0] > 0.95
|
||||
|
||||
@@ -29,7 +29,8 @@ def test_smaller_truncation_size(
|
||||
model_name, runner="pooling", max_model_len=max_model_len
|
||||
) as vllm_model:
|
||||
vllm_output = vllm_model.llm.embed(
|
||||
input_str, truncate_prompt_tokens=truncate_prompt_tokens
|
||||
input_str,
|
||||
tokenization_kwargs=dict(truncate_prompt_tokens=truncate_prompt_tokens),
|
||||
)
|
||||
|
||||
prompt_tokens = vllm_output[0].prompt_token_ids
|
||||
@@ -44,7 +45,8 @@ def test_max_truncation_size(vllm_runner, model_name=MODEL_NAME, input_str=input
|
||||
model_name, runner="pooling", max_model_len=max_model_len
|
||||
) as vllm_model:
|
||||
vllm_output = vllm_model.llm.embed(
|
||||
input_str, truncate_prompt_tokens=truncate_prompt_tokens
|
||||
input_str,
|
||||
tokenization_kwargs=dict(truncate_prompt_tokens=truncate_prompt_tokens),
|
||||
)
|
||||
|
||||
prompt_tokens = vllm_output[0].prompt_token_ids
|
||||
@@ -64,7 +66,8 @@ def test_bigger_truncation_size(
|
||||
) as vllm_model,
|
||||
):
|
||||
llm_output = vllm_model.llm.embed(
|
||||
input_str, truncate_prompt_tokens=truncate_prompt_tokens
|
||||
input_str,
|
||||
tokenization_kwargs=dict(truncate_prompt_tokens=truncate_prompt_tokens),
|
||||
)
|
||||
|
||||
assert (
|
||||
|
||||
@@ -187,7 +187,10 @@ def mteb_test_embed_models(
|
||||
head_dtype = model_config.head_dtype
|
||||
|
||||
# Test embedding_size, isnan and whether to use normalize
|
||||
vllm_outputs = vllm_model.embed(example_prompts, truncate_prompt_tokens=-1)
|
||||
vllm_outputs = vllm_model.embed(
|
||||
example_prompts,
|
||||
tokenization_kwargs=dict(truncate_prompt_tokens=-1),
|
||||
)
|
||||
outputs_tensor = torch.tensor(vllm_outputs)
|
||||
assert not torch.any(torch.isnan(outputs_tensor))
|
||||
embedding_size = model_config.embedding_size
|
||||
|
||||
@@ -79,9 +79,9 @@ class VllmMtebCrossEncoder(MtebCrossEncoderMixin):
|
||||
outputs = self.llm.score(
|
||||
queries,
|
||||
corpus,
|
||||
truncate_prompt_tokens=-1,
|
||||
use_tqdm=False,
|
||||
chat_template=self.chat_template,
|
||||
tokenization_kwargs={"truncate_prompt_tokens": -1},
|
||||
)
|
||||
scores = np.array(outputs)
|
||||
scores = scores[np.argsort(r)]
|
||||
|
||||
426
tests/renderers/test_completions.py
Normal file
426
tests/renderers/test_completions.py
Normal file
@@ -0,0 +1,426 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pybase64
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.inputs.data import is_embeds_prompt
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.renderers.hf import HfRenderer
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockHFConfig:
|
||||
model_type: str = "any"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockModelConfig:
|
||||
runner_type = "generate"
|
||||
model: str = MODEL_NAME
|
||||
tokenizer: str = MODEL_NAME
|
||||
trust_remote_code: bool = False
|
||||
max_model_len: int = 100
|
||||
tokenizer_revision = None
|
||||
tokenizer_mode = "auto"
|
||||
hf_config = MockHFConfig()
|
||||
encoder_config: dict[str, Any] | None = None
|
||||
enable_prompt_embeds: bool = True
|
||||
skip_tokenizer_init: bool = False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_config():
|
||||
return MockModelConfig()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_tokenizer():
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def renderer(mock_model_config):
|
||||
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(mock_model_config)
|
||||
|
||||
return HfRenderer(
|
||||
mock_model_config,
|
||||
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
|
||||
)
|
||||
|
||||
|
||||
class TestValidatePrompt:
|
||||
STRING_INPUTS = [
|
||||
"",
|
||||
"foo",
|
||||
"foo bar",
|
||||
"foo baz bar",
|
||||
"foo bar qux baz",
|
||||
]
|
||||
|
||||
TOKEN_INPUTS = [
|
||||
[-1],
|
||||
[1],
|
||||
[1, 2],
|
||||
[1, 3, 4],
|
||||
[1, 2, 4, 3],
|
||||
]
|
||||
|
||||
INPUTS_SLICES = [
|
||||
slice(None, None, -1),
|
||||
slice(None, None, 2),
|
||||
slice(None, None, -2),
|
||||
]
|
||||
|
||||
# Test that a nested mixed-type list of lists raises a TypeError.
|
||||
def test_empty_input(self, renderer):
|
||||
with pytest.raises(ValueError, match="at least one prompt"):
|
||||
renderer.render_completions([])
|
||||
|
||||
def test_invalid_type(self, renderer):
|
||||
with pytest.raises(TypeError, match="string or an array of tokens"):
|
||||
renderer.render_completions([[1, 2], ["foo", "bar"]])
|
||||
|
||||
@pytest.mark.parametrize("string_input", STRING_INPUTS)
|
||||
def test_string_consistent(self, renderer, string_input: str):
|
||||
assert renderer.render_completions(string_input) == renderer.render_completions(
|
||||
[string_input]
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
|
||||
def test_token_consistent(self, renderer, token_input: list[int]):
|
||||
assert renderer.render_completions(token_input) == renderer.render_completions(
|
||||
[token_input]
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
|
||||
def test_string_slice(self, renderer, inputs_slice: slice):
|
||||
assert renderer.render_completions(self.STRING_INPUTS)[
|
||||
inputs_slice
|
||||
] == renderer.render_completions(self.STRING_INPUTS[inputs_slice])
|
||||
|
||||
|
||||
class TestRenderPrompt:
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_input(self, renderer):
|
||||
tokens = [101, 7592, 2088]
|
||||
prompts = await renderer.render_completions_async(tokens)
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["prompt_token_ids"] == tokens
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_list_input(self, renderer):
|
||||
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
|
||||
prompts = await renderer.render_completions_async(token_lists)
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100),
|
||||
)
|
||||
|
||||
assert len(results) == 3
|
||||
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
|
||||
assert results[1]["prompt_token_ids"] == [102, 1234, 5678, 9012]
|
||||
assert results[2]["prompt_token_ids"] == [103, 4567]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_input(self, renderer, mock_async_tokenizer):
|
||||
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
|
||||
renderer._async_tokenizer = mock_async_tokenizer
|
||||
|
||||
prompts = await renderer.render_completions_async("Hello world")
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
|
||||
mock_async_tokenizer.encode.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_list_input(self, renderer, mock_async_tokenizer):
|
||||
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
|
||||
renderer._async_tokenizer = mock_async_tokenizer
|
||||
|
||||
text_list_input = ["Hello world", "How are you?", "Good morning"]
|
||||
prompts = await renderer.render_completions_async(text_list_input)
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100),
|
||||
)
|
||||
|
||||
assert len(results) == 3
|
||||
for result in results:
|
||||
assert result["prompt_token_ids"] == [101, 7592, 2088]
|
||||
assert mock_async_tokenizer.encode.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_truncation(self, renderer, mock_async_tokenizer):
|
||||
mock_async_tokenizer.encode.return_value = [101, 7592, 2088]
|
||||
renderer._async_tokenizer = mock_async_tokenizer
|
||||
|
||||
prompts = await renderer.render_completions_async("Hello world")
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
call_args = mock_async_tokenizer.encode.call_args
|
||||
assert (
|
||||
"truncation" not in call_args.kwargs
|
||||
or call_args.kwargs["truncation"] is False
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncation_positive(self, renderer, mock_async_tokenizer):
|
||||
mock_async_tokenizer.encode.return_value = [101, 7592, 2088] # Truncated
|
||||
renderer._async_tokenizer = mock_async_tokenizer
|
||||
|
||||
prompts = await renderer.render_completions_async("Hello world")
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(
|
||||
max_total_tokens=200,
|
||||
truncate_prompt_tokens=50,
|
||||
),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
call_args = mock_async_tokenizer.encode.call_args
|
||||
assert call_args.kwargs["truncation"] is True
|
||||
assert call_args.kwargs["max_length"] == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncation_negative(self, renderer, mock_async_tokenizer):
|
||||
# Test that negative truncation uses model's max_model_len
|
||||
mock_async_tokenizer.encode.return_value = [
|
||||
101,
|
||||
7592,
|
||||
2088,
|
||||
] # Truncated to max_model_len
|
||||
renderer._async_tokenizer = mock_async_tokenizer
|
||||
|
||||
prompts = await renderer.render_completions_async("Hello world")
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(
|
||||
max_total_tokens=200,
|
||||
truncate_prompt_tokens=-1,
|
||||
),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
call_args = mock_async_tokenizer.encode.call_args
|
||||
assert call_args.kwargs["truncation"] is True
|
||||
assert call_args.kwargs["max_length"] == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_truncation_last_elements(self, renderer):
|
||||
# Test that token truncation keeps the last N elements
|
||||
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
|
||||
prompts = await renderer.render_completions_async(long_tokens)
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(
|
||||
max_total_tokens=100,
|
||||
truncate_prompt_tokens=5,
|
||||
),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
# Should keep the last 5 tokens: [105, 106, 107, 108, 109]
|
||||
assert results[0]["prompt_token_ids"] == [105, 106, 107, 108, 109]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_length_exceeded(self, renderer):
|
||||
long_tokens = list(range(150)) # Exceeds max_model_len=100
|
||||
|
||||
prompts = await renderer.render_completions_async(long_tokens)
|
||||
|
||||
with pytest.raises(ValueError, match="context length is only"):
|
||||
await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tokenizer_for_text(self, renderer):
|
||||
renderer_no_tokenizer = HfRenderer.from_config(
|
||||
MockModelConfig(skip_tokenizer_init=True),
|
||||
tokenizer_kwargs={},
|
||||
)
|
||||
|
||||
prompts = await renderer_no_tokenizer.render_completions_async("Hello world")
|
||||
|
||||
with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
|
||||
await renderer_no_tokenizer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=100),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_input_with_needs_detokenization(
|
||||
self, renderer, mock_async_tokenizer
|
||||
):
|
||||
# When needs_detokenization=True for token inputs, renderer should
|
||||
# use the async tokenizer to decode and include the original text
|
||||
# in the returned prompt object.
|
||||
mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
|
||||
renderer._async_tokenizer = mock_async_tokenizer
|
||||
|
||||
tokens = [1, 2, 3, 4]
|
||||
prompts = await renderer.render_completions_async(tokens)
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(
|
||||
max_total_tokens=renderer.config.max_model_len,
|
||||
needs_detokenization=True,
|
||||
),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["prompt_token_ids"] == tokens
|
||||
assert results[0]["prompt"] == "decoded text"
|
||||
mock_async_tokenizer.decode.assert_awaited_once()
|
||||
|
||||
|
||||
class TestRenderEmbedPrompt:
|
||||
def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
|
||||
"""Helper to create base64-encoded tensor bytes"""
|
||||
buffer = io.BytesIO()
|
||||
torch.save(tensor, buffer)
|
||||
buffer.seek(0)
|
||||
return pybase64.b64encode(buffer.read())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_prompt_embed(self, renderer):
|
||||
# Create a test tensor
|
||||
test_tensor = torch.randn(10, 768, dtype=torch.float32)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert is_embeds_prompt(results[0])
|
||||
assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_prompt_embeds(self, renderer):
|
||||
# Create multiple test tensors
|
||||
test_tensors = [
|
||||
torch.randn(8, 512, dtype=torch.float32),
|
||||
torch.randn(12, 512, dtype=torch.float32),
|
||||
]
|
||||
embed_bytes_list = [self._create_test_embed_bytes(t) for t in test_tensors]
|
||||
|
||||
prompts = await renderer.render_completions_async(
|
||||
prompt_embeds=embed_bytes_list
|
||||
)
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
for i, result in enumerate(results):
|
||||
assert is_embeds_prompt(result)
|
||||
assert torch.allclose(result["prompt_embeds"], test_tensors[i])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_embed_truncation(self, renderer):
|
||||
# Create tensor with more tokens than truncation limit
|
||||
test_tensor = torch.randn(20, 768, dtype=torch.float32)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(
|
||||
max_total_tokens=renderer.config.max_model_len,
|
||||
truncate_prompt_tokens=10,
|
||||
),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
# Should keep last 10 tokens
|
||||
expected = test_tensor[-10:]
|
||||
assert torch.allclose(results[0]["prompt_embeds"], expected)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_embed_different_dtypes(self, renderer):
|
||||
# Test different supported dtypes
|
||||
dtypes = [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
for dtype in dtypes:
|
||||
test_tensor = torch.randn(5, 256, dtype=dtype)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["prompt_embeds"].dtype == dtype
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_embed_squeeze_batch_dim(self, renderer):
|
||||
# Test tensor with batch dimension gets squeezed
|
||||
test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
prompts = await renderer.render_completions_async(prompt_embeds=embed_bytes)
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
# Should be squeezed to 2D
|
||||
assert results[0]["prompt_embeds"].shape == (10, 768)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_both_prompts_and_embeds(self, renderer, mock_async_tokenizer):
|
||||
# Set up text tokenization
|
||||
mock_async_tokenizer.encode.return_value = [101, 102, 103]
|
||||
renderer._async_tokenizer = mock_async_tokenizer
|
||||
|
||||
# Create embed
|
||||
test_tensor = torch.randn(5, 256, dtype=torch.float32)
|
||||
embed_bytes = self._create_test_embed_bytes(test_tensor)
|
||||
|
||||
prompts = await renderer.render_completions_async(
|
||||
"Hello world",
|
||||
prompt_embeds=embed_bytes,
|
||||
)
|
||||
results = await renderer.tokenize_prompts_async(
|
||||
prompts,
|
||||
TokenizeParams(max_total_tokens=renderer.config.max_model_len),
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
# First should be embed prompt
|
||||
assert is_embeds_prompt(results[0])
|
||||
# Second should be tokens prompt
|
||||
assert "prompt_token_ids" in results[1]
|
||||
assert results[1]["prompt_token_ids"] == [101, 102, 103]
|
||||
@@ -9,6 +9,7 @@ import pytest
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.renderers import ChatParams
|
||||
from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
@@ -27,7 +28,7 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop():
|
||||
mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={})
|
||||
mock_renderer._tokenizer = mock_tokenizer
|
||||
|
||||
task = mock_renderer.render_messages_async([])
|
||||
task = mock_renderer.render_messages_async([], ChatParams())
|
||||
|
||||
# Ensure the event loop is not blocked
|
||||
blocked_count = 0
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Sparse tensor validation in embedding APIs.
|
||||
|
||||
Tests verify that malicious sparse tensors are rejected before they can trigger
|
||||
out-of-bounds memory writes during to_dense() operations.
|
||||
"""
|
||||
@@ -13,8 +11,24 @@ import io
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.entrypoints.renderer import CompletionRenderer
|
||||
from vllm.multimodal.media import AudioEmbeddingMediaIO, ImageEmbeddingMediaIO
|
||||
from vllm.renderers.embed_utils import safe_load_prompt_embeds
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_config():
|
||||
"""Mock ModelConfig for testing."""
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
return ModelConfig(
|
||||
model="facebook/opt-125m",
|
||||
tokenizer="facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float32",
|
||||
seed=0,
|
||||
enable_prompt_embeds=True, # Required for prompt embeds tests
|
||||
)
|
||||
|
||||
|
||||
def _encode_tensor(tensor: torch.Tensor) -> bytes:
|
||||
@@ -63,15 +77,12 @@ class TestPromptEmbedsValidation:
|
||||
|
||||
def test_valid_dense_tensor_accepted(self, model_config):
|
||||
"""Baseline: Valid dense tensors should work normally."""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
valid_tensor = _create_valid_dense_tensor()
|
||||
encoded = _encode_tensor(valid_tensor)
|
||||
|
||||
# Should not raise any exception
|
||||
result = renderer.load_prompt_embeds(encoded)
|
||||
assert len(result) == 1
|
||||
assert result[0]["prompt_embeds"].shape == valid_tensor.shape
|
||||
result = safe_load_prompt_embeds(model_config, encoded)
|
||||
assert result.shape == valid_tensor.shape
|
||||
|
||||
def test_valid_sparse_tensor_accepted(self):
|
||||
"""Baseline: Valid sparse tensors should load successfully."""
|
||||
@@ -86,14 +97,12 @@ class TestPromptEmbedsValidation:
|
||||
|
||||
def test_malicious_sparse_tensor_rejected(self, model_config):
|
||||
"""Security: Malicious sparse tensors should be rejected."""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
malicious_tensor = _create_malicious_sparse_tensor()
|
||||
encoded = _encode_tensor(malicious_tensor)
|
||||
|
||||
# Should raise RuntimeError due to invalid sparse tensor
|
||||
with pytest.raises((RuntimeError, ValueError)) as exc_info:
|
||||
renderer.load_prompt_embeds(encoded)
|
||||
safe_load_prompt_embeds(model_config, encoded)
|
||||
|
||||
# Error should indicate sparse tensor validation failure
|
||||
error_msg = str(exc_info.value).lower()
|
||||
@@ -101,8 +110,6 @@ class TestPromptEmbedsValidation:
|
||||
|
||||
def test_extremely_large_indices_rejected(self, model_config):
|
||||
"""Security: Sparse tensors with extremely large indices should be rejected."""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
# Create tensor with indices far beyond reasonable bounds
|
||||
indices = torch.tensor([[999999], [999999]])
|
||||
values = torch.tensor([1.0])
|
||||
@@ -114,12 +121,10 @@ class TestPromptEmbedsValidation:
|
||||
encoded = _encode_tensor(malicious_tensor)
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
renderer.load_prompt_embeds(encoded)
|
||||
safe_load_prompt_embeds(model_config, encoded)
|
||||
|
||||
def test_negative_indices_rejected(self, model_config):
|
||||
"""Security: Sparse tensors with negative indices should be rejected."""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
# Create tensor with negative indices
|
||||
indices = torch.tensor([[-1], [-1]])
|
||||
values = torch.tensor([1.0])
|
||||
@@ -131,7 +136,7 @@ class TestPromptEmbedsValidation:
|
||||
encoded = _encode_tensor(malicious_tensor)
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
renderer.load_prompt_embeds(encoded)
|
||||
safe_load_prompt_embeds(model_config, encoded)
|
||||
|
||||
|
||||
class TestImageEmbedsValidation:
|
||||
@@ -253,14 +258,12 @@ class TestSparseTensorValidationIntegration:
|
||||
3. Sends to /v1/completions with prompt_embeds parameter
|
||||
4. Server should reject before memory corruption occurs
|
||||
"""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
# Step 1-2: Attacker creates malicious payload
|
||||
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
|
||||
|
||||
# Step 3-4: Server processes and should reject
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
renderer.load_prompt_embeds(attack_payload)
|
||||
safe_load_prompt_embeds(model_config, attack_payload)
|
||||
|
||||
def test_attack_scenario_chat_api_image(self):
|
||||
"""
|
||||
@@ -285,57 +288,3 @@ class TestSparseTensorValidationIntegration:
|
||||
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
io_handler.load_base64("", attack_payload.decode("utf-8"))
|
||||
|
||||
def test_multiple_valid_embeddings_in_batch(self, model_config):
|
||||
"""
|
||||
Regression test: Multiple valid embeddings should still work.
|
||||
|
||||
Ensures the fix doesn't break legitimate batch processing.
|
||||
"""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
valid_tensors = [
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
]
|
||||
|
||||
# Should process all without error
|
||||
result = renderer.load_prompt_embeds(valid_tensors)
|
||||
assert len(result) == 3
|
||||
|
||||
def test_mixed_valid_and_malicious_rejected(self, model_config):
|
||||
"""
|
||||
Security: Batch with one malicious tensor should be rejected.
|
||||
|
||||
Even if most tensors are valid, a single malicious one should
|
||||
cause rejection of the entire batch.
|
||||
"""
|
||||
renderer = CompletionRenderer(model_config)
|
||||
|
||||
mixed_batch = [
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
_encode_tensor(_create_malicious_sparse_tensor()), # Malicious
|
||||
_encode_tensor(_create_valid_dense_tensor()),
|
||||
]
|
||||
|
||||
# Should fail on the malicious tensor
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
renderer.load_prompt_embeds(mixed_batch)
|
||||
|
||||
|
||||
# Pytest fixtures
|
||||
@pytest.fixture
|
||||
def model_config():
|
||||
"""Mock ModelConfig for testing."""
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
return ModelConfig(
|
||||
model="facebook/opt-125m",
|
||||
tokenizer="facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float32",
|
||||
seed=0,
|
||||
enable_prompt_embeds=True, # Required for prompt embeds tests
|
||||
)
|
||||
@@ -5,65 +5,10 @@ import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import zip_enc_dec_prompts
|
||||
from vllm.inputs.parse import parse_raw_prompts
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
STRING_INPUTS = [
|
||||
"",
|
||||
"foo",
|
||||
"foo bar",
|
||||
"foo baz bar",
|
||||
"foo bar qux baz",
|
||||
]
|
||||
|
||||
TOKEN_INPUTS = [
|
||||
[-1],
|
||||
[1],
|
||||
[1, 2],
|
||||
[1, 3, 4],
|
||||
[1, 2, 4, 3],
|
||||
]
|
||||
|
||||
INPUTS_SLICES = [
|
||||
slice(None, None, -1),
|
||||
slice(None, None, 2),
|
||||
slice(None, None, -2),
|
||||
]
|
||||
|
||||
|
||||
# Test that a nested mixed-type list of lists raises a TypeError.
|
||||
@pytest.mark.parametrize("invalid_input", [[[1, 2], ["foo", "bar"]]])
|
||||
def test_invalid_input_raise_type_error(invalid_input):
|
||||
with pytest.raises(TypeError):
|
||||
parse_raw_prompts(invalid_input)
|
||||
|
||||
|
||||
def test_parse_raw_single_batch_empty():
|
||||
with pytest.raises(ValueError, match="at least one prompt"):
|
||||
parse_raw_prompts([])
|
||||
|
||||
with pytest.raises(ValueError, match="at least one prompt"):
|
||||
parse_raw_prompts([[]])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("string_input", STRING_INPUTS)
|
||||
def test_parse_raw_single_batch_string_consistent(string_input: str):
|
||||
assert parse_raw_prompts(string_input) == parse_raw_prompts([string_input])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
|
||||
def test_parse_raw_single_batch_token_consistent(token_input: list[int]):
|
||||
assert parse_raw_prompts(token_input) == parse_raw_prompts([token_input])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
|
||||
def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
|
||||
assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] == parse_raw_prompts(
|
||||
STRING_INPUTS[inputs_slice]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mm_processor_kwargs,expected_mm_kwargs",
|
||||
|
||||
@@ -768,7 +768,7 @@ class ModelConfig:
|
||||
)
|
||||
self.tokenizer = object_storage_tokenizer.dir
|
||||
|
||||
def _get_encoder_config(self):
|
||||
def _get_encoder_config(self) -> dict[str, Any] | None:
|
||||
model = self.model
|
||||
if is_remote_gguf(model):
|
||||
model, _ = split_remote_gguf(model)
|
||||
@@ -1918,7 +1918,7 @@ def _get_and_verify_max_len(
|
||||
disable_sliding_window: bool,
|
||||
sliding_window: int | None,
|
||||
spec_target_max_model_len: int | None = None,
|
||||
encoder_config: Any | None = None,
|
||||
encoder_config: dict[str, Any] | None = None,
|
||||
) -> int:
|
||||
"""Get and verify the model's maximum length."""
|
||||
(derived_max_model_len, max_len_key) = (
|
||||
|
||||
@@ -72,14 +72,9 @@ class EngineClient(ABC):
|
||||
lora_request: LoRARequest | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
priority: int = 0,
|
||||
truncate_prompt_tokens: int | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
||||
"""Generate outputs for a request from a pooling model.
|
||||
|
||||
NOTE: truncate_prompt_tokens is deprecated in v0.14.
|
||||
TODO: Remove this argument in v0.15.
|
||||
"""
|
||||
"""Generate outputs for a request from a pooling model."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
import warnings
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, cast
|
||||
|
||||
import cloudpickle
|
||||
import torch.nn as nn
|
||||
@@ -46,15 +47,17 @@ from vllm.entrypoints.pooling.score.utils import (
|
||||
compress_token_type_ids,
|
||||
get_score_prompt,
|
||||
)
|
||||
from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args
|
||||
from vllm.entrypoints.utils import log_non_default_args
|
||||
from vllm.inputs import (
|
||||
DataPrompt,
|
||||
EmbedsPrompt,
|
||||
ExplicitEncoderDecoderPrompt,
|
||||
PromptType,
|
||||
SingletonPrompt,
|
||||
TextPrompt,
|
||||
TokensPrompt,
|
||||
)
|
||||
from vllm.inputs.parse import get_prompt_components
|
||||
from vllm.inputs.parse import get_prompt_components, is_explicit_encoder_decoder_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -67,6 +70,7 @@ from vllm.outputs import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
|
||||
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
@@ -74,7 +78,6 @@ from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.collection_utils import as_iter, is_list_of
|
||||
from vllm.utils.counter import Counter
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||
|
||||
@@ -85,6 +88,9 @@ logger = init_logger(__name__)
|
||||
|
||||
_R = TypeVar("_R", default=Any)
|
||||
|
||||
EnginePrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
|
||||
EngineEncDecPrompt: TypeAlias = ExplicitEncoderDecoderPrompt[EnginePrompt, EnginePrompt]
|
||||
|
||||
|
||||
class LLM:
|
||||
"""An LLM for generating texts from given prompts and sampling parameters.
|
||||
@@ -372,6 +378,7 @@ class LLM:
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
priority: list[int] | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> list[RequestOutput]:
|
||||
"""Generates the completions for the input prompts.
|
||||
|
||||
@@ -398,15 +405,11 @@ class LLM:
|
||||
If provided, must be a list of integers matching the length
|
||||
of `prompts`, where each priority value corresponds to the prompt
|
||||
at the same index.
|
||||
tokenization_kwargs: Overrides for `tokenizer.encode`.
|
||||
|
||||
Returns:
|
||||
A list of `RequestOutput` objects containing the
|
||||
generated completions in the same order as the input prompts.
|
||||
|
||||
Note:
|
||||
Using `prompts` and `prompt_token_ids` as keyword parameters is
|
||||
considered legacy and may be deprecated in the future. You should
|
||||
instead pass them via the `inputs` parameter.
|
||||
"""
|
||||
model_config = self.model_config
|
||||
runner_type = model_config.runner_type
|
||||
@@ -418,17 +421,14 @@ class LLM:
|
||||
)
|
||||
|
||||
if sampling_params is None:
|
||||
# Use default sampling params.
|
||||
sampling_params = self.get_default_sampling_params()
|
||||
|
||||
# Add any modality specific loras to the corresponding prompts
|
||||
lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request)
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=prompts,
|
||||
params=sampling_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
lora_request=self._get_modality_specific_lora_reqs(prompts, lora_request),
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
@@ -771,65 +771,169 @@ class LLM:
|
||||
|
||||
return outputs
|
||||
|
||||
def preprocess_chat(
|
||||
def _get_cmpl_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
|
||||
model_config = self.model_config
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
# For Whisper, special tokens should be provided by the user based
|
||||
# on the task and language of their request. Also needed to avoid
|
||||
# appending an EOS token to the prompt which disrupts generation.
|
||||
add_special_tokens=not model_config.is_encoder_decoder,
|
||||
).with_kwargs(tokenization_kwargs)
|
||||
|
||||
def _normalize_prompts(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
) -> list[EnginePrompt | EngineEncDecPrompt]:
|
||||
if isinstance(prompts, str):
|
||||
prompts = TextPrompt(prompt=prompts)
|
||||
|
||||
return prompts if isinstance(prompts, Sequence) else [prompts] # type: ignore[return-value]
|
||||
|
||||
def _preprocess_cmpl_singleton(
|
||||
self,
|
||||
prompt: SingletonPrompt,
|
||||
tok_params: TokenizeParams,
|
||||
*,
|
||||
tokenize: bool,
|
||||
) -> EnginePrompt:
|
||||
renderer = self.llm_engine.renderer
|
||||
|
||||
if not isinstance(prompt, dict):
|
||||
prompt = renderer.render_completion(prompt)
|
||||
|
||||
return renderer.tokenize_prompt(prompt, tok_params) if tokenize else prompt
|
||||
|
||||
def _preprocess_cmpl_enc_dec(
|
||||
self,
|
||||
prompt: ExplicitEncoderDecoderPrompt,
|
||||
tok_params: TokenizeParams,
|
||||
) -> EngineEncDecPrompt:
|
||||
enc_prompt = prompt["encoder_prompt"]
|
||||
dec_prompt = prompt["decoder_prompt"]
|
||||
|
||||
return EngineEncDecPrompt(
|
||||
encoder_prompt=self._preprocess_cmpl_singleton(
|
||||
enc_prompt,
|
||||
tok_params,
|
||||
# TODO: Move multi-modal processor into tokenization
|
||||
tokenize=not self.model_config.is_multimodal_model,
|
||||
),
|
||||
decoder_prompt=(
|
||||
None
|
||||
if dec_prompt is None
|
||||
else self._preprocess_cmpl_singleton(
|
||||
dec_prompt,
|
||||
tok_params,
|
||||
# TODO: Move multi-modal processor into tokenization
|
||||
tokenize=not self.model_config.is_multimodal_model,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def _preprocess_completion(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> list[EnginePrompt | EngineEncDecPrompt]:
|
||||
"""
|
||||
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
|
||||
a format that can be passed to `_add_request`.
|
||||
|
||||
Refer to [LLM.generate][] for a complete description of the arguments.
|
||||
|
||||
Returns:
|
||||
A list of `TokensPrompts` objects containing the tokenized prompt
|
||||
after chat template interpolation, and the raw multi-modal inputs.
|
||||
"""
|
||||
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
|
||||
|
||||
engine_prompts = list[EnginePrompt | EngineEncDecPrompt]()
|
||||
for prompt in self._normalize_prompts(prompts):
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
engine_prompts.append(self._preprocess_cmpl_enc_dec(prompt, tok_params))
|
||||
else:
|
||||
# Some MM models have non-default `add_special_tokens`
|
||||
# TODO: Move multi-modal processor into tokenization
|
||||
engine_prompts.append(
|
||||
self._preprocess_cmpl_singleton(
|
||||
prompt,
|
||||
tok_params,
|
||||
tokenize=not self.model_config.is_multimodal_model,
|
||||
)
|
||||
)
|
||||
|
||||
return engine_prompts
|
||||
|
||||
def _normalize_conversations(
|
||||
self,
|
||||
conversations: list[ChatCompletionMessageParam]
|
||||
| list[list[ChatCompletionMessageParam]],
|
||||
) -> list[list[ChatCompletionMessageParam]]:
|
||||
return conversations if is_list_of(conversations, list) else [conversations] # type: ignore[list-item,return-value]
|
||||
|
||||
def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
|
||||
model_config = self.model_config
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=False,
|
||||
).with_kwargs(tokenization_kwargs)
|
||||
|
||||
def _preprocess_chat(
|
||||
self,
|
||||
conversations: list[ChatCompletionMessageParam]
|
||||
| list[list[ChatCompletionMessageParam]],
|
||||
chat_template: str | None = None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
add_generation_prompt: bool = True,
|
||||
continue_final_message: bool = False,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
mm_processor_kwargs: dict[str, Any] | None = None,
|
||||
) -> list[TextPrompt | TokensPrompt]:
|
||||
) -> list[EnginePrompt]:
|
||||
"""
|
||||
Generate prompt for a chat conversation. The pre-processed
|
||||
prompt can then be used as input for the other LLM methods.
|
||||
Convert a list of conversations into prompts so that they can then
|
||||
be used as input for other LLM APIs.
|
||||
|
||||
Refer to [LLM.chat][] for a complete description of the arguments.
|
||||
|
||||
Refer to `chat` for a complete description of the arguments.
|
||||
Returns:
|
||||
A list of `TokensPrompts` objects containing the tokenized
|
||||
prompt after chat template interpolation, and the
|
||||
pre-processed multi-modal inputs.
|
||||
A list of `TokensPrompts` objects containing the tokenized prompt
|
||||
after chat template interpolation, and the raw multi-modal inputs.
|
||||
"""
|
||||
list_of_messages: list[list[ChatCompletionMessageParam]]
|
||||
|
||||
# Handle multi and single conversations
|
||||
if is_list_of(messages, list):
|
||||
# messages is list[list[...]]
|
||||
list_of_messages = cast(list[list[ChatCompletionMessageParam]], messages)
|
||||
else:
|
||||
# messages is list[...]
|
||||
list_of_messages = [cast(list[ChatCompletionMessageParam], messages)]
|
||||
|
||||
renderer = self.llm_engine.renderer
|
||||
|
||||
chat_template_kwargs = {
|
||||
"chat_template": chat_template,
|
||||
"add_generation_prompt": add_generation_prompt,
|
||||
"continue_final_message": continue_final_message,
|
||||
"tools": tools,
|
||||
**(chat_template_kwargs or {}),
|
||||
}
|
||||
chat_params = ChatParams(
|
||||
chat_template=chat_template,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
chat_template_kwargs=merge_kwargs(
|
||||
chat_template_kwargs,
|
||||
dict(
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tools,
|
||||
tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
|
||||
),
|
||||
),
|
||||
)
|
||||
tok_params = self._get_chat_tok_params(tokenization_kwargs)
|
||||
|
||||
prompts = list[TextPrompt | TokensPrompt]()
|
||||
|
||||
for msgs in list_of_messages:
|
||||
# NOTE: renderer.render_messages() currently doesn't
|
||||
# handle mm_processor_kwargs, since there is no implementation in
|
||||
# the chat message parsing for it.
|
||||
_, prompt = renderer.render_messages(
|
||||
msgs,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
engine_prompts = list[EnginePrompt]()
|
||||
for conversation in self._normalize_conversations(conversations):
|
||||
_, in_prompt = renderer.render_messages(conversation, chat_params)
|
||||
if mm_processor_kwargs is not None:
|
||||
prompt["mm_processor_kwargs"] = mm_processor_kwargs
|
||||
in_prompt["mm_processor_kwargs"] = mm_processor_kwargs
|
||||
|
||||
prompts.append(prompt)
|
||||
engine_prompts.append(renderer.tokenize_prompt(in_prompt, tok_params))
|
||||
|
||||
return prompts
|
||||
return engine_prompts
|
||||
|
||||
def chat(
|
||||
self,
|
||||
@@ -844,6 +948,7 @@ class LLM:
|
||||
continue_final_message: bool = False,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
mm_processor_kwargs: dict[str, Any] | None = None,
|
||||
) -> list[RequestOutput]:
|
||||
"""
|
||||
@@ -889,22 +994,22 @@ class LLM:
|
||||
`True` if `add_generation_prompt` is also `True`.
|
||||
chat_template_kwargs: Additional kwargs to pass to the chat
|
||||
template.
|
||||
mm_processor_kwargs: Multimodal processor kwarg overrides for this
|
||||
chat request. Only used for offline requests.
|
||||
tokenization_kwargs: Overrides for `tokenizer.encode`.
|
||||
mm_processor_kwargs: Overrides for `processor.__call__`.
|
||||
|
||||
Returns:
|
||||
A list of `RequestOutput` objects containing the generated
|
||||
responses in the same order as the input messages.
|
||||
"""
|
||||
|
||||
prompts = self.preprocess_chat(
|
||||
messages=messages,
|
||||
prompts = self._preprocess_chat(
|
||||
messages,
|
||||
chat_template=chat_template,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tools,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
@@ -913,6 +1018,7 @@ class LLM:
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
def encode(
|
||||
@@ -945,37 +1051,29 @@ class LLM:
|
||||
If `False`, no progress bar is created.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
pooling_task: Override the pooling task to use.
|
||||
tokenization_kwargs: overrides tokenization_kwargs set in
|
||||
pooling_params
|
||||
tokenization_kwargs: Overrides for `tokenizer.encode`.
|
||||
|
||||
Returns:
|
||||
A list of `PoolingRequestOutput` objects containing the
|
||||
pooled hidden states in the same order as the input prompts.
|
||||
|
||||
Note:
|
||||
Using `prompts` and `prompt_token_ids` as keyword parameters is
|
||||
considered legacy and may be deprecated in the future. You should
|
||||
instead pass them via the `inputs` parameter.
|
||||
"""
|
||||
|
||||
error_str = (
|
||||
"pooling_task required for `LLM.encode`\n"
|
||||
"Please use one of the more specific methods or set the "
|
||||
"pooling_task when using `LLM.encode`:\n"
|
||||
" - For embeddings, use `LLM.embed(...)` "
|
||||
'or `pooling_task="embed"`.\n'
|
||||
" - For classification logits, use `LLM.classify(...)` "
|
||||
'or `pooling_task="classify"`.\n'
|
||||
" - For similarity scores, use `LLM.score(...)`.\n"
|
||||
" - For rewards, use `LLM.reward(...)` "
|
||||
'or `pooling_task="token_classify"`\n'
|
||||
" - For token classification, "
|
||||
'use `pooling_task="token_classify"`\n'
|
||||
' - For multi-vector retrieval, use `pooling_task="token_embed"`'
|
||||
)
|
||||
|
||||
if pooling_task is None:
|
||||
raise ValueError(error_str)
|
||||
raise ValueError(
|
||||
"pooling_task required for `LLM.encode`\n"
|
||||
"Please use one of the more specific methods or set the "
|
||||
"pooling_task when using `LLM.encode`:\n"
|
||||
" - For embeddings, use `LLM.embed(...)` "
|
||||
'or `pooling_task="embed"`.\n'
|
||||
" - For classification logits, use `LLM.classify(...)` "
|
||||
'or `pooling_task="classify"`.\n'
|
||||
" - For similarity scores, use `LLM.score(...)`.\n"
|
||||
" - For rewards, use `LLM.reward(...)` "
|
||||
'or `pooling_task="token_classify"`\n'
|
||||
" - For token classification, "
|
||||
'use `pooling_task="token_classify"`\n'
|
||||
' - For multi-vector retrieval, use `pooling_task="token_embed"`'
|
||||
)
|
||||
|
||||
model_config = self.model_config
|
||||
runner_type = model_config.runner_type
|
||||
@@ -986,6 +1084,20 @@ class LLM:
|
||||
"pooling model."
|
||||
)
|
||||
|
||||
if truncate_prompt_tokens is not None:
|
||||
warnings.warn(
|
||||
"The `truncate_prompt_tokens` parameter in `LLM.encode()` "
|
||||
"is deprecated and will be removed in v0.16. "
|
||||
"Please pass it via `tokenization_kwargs` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
tokenization_kwargs = merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=truncate_prompt_tokens),
|
||||
)
|
||||
|
||||
io_processor_prompt = False
|
||||
if isinstance(prompts, dict) and "data" in prompts:
|
||||
io_processor_prompt = True
|
||||
@@ -1017,19 +1129,16 @@ class LLM:
|
||||
pooling_params = self.io_processor.validate_or_generate_params(
|
||||
pooling_params
|
||||
)
|
||||
else:
|
||||
if pooling_params is None:
|
||||
# Use default pooling params.
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
if pooling_params is None:
|
||||
# Use default pooling params.
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
if pooling_task not in self.supported_tasks:
|
||||
raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
|
||||
|
||||
for param in as_iter(pooling_params):
|
||||
param.verify(pooling_task, model_config)
|
||||
# for backwards compatibility
|
||||
if truncate_prompt_tokens is not None:
|
||||
param.truncate_prompt_tokens = truncate_prompt_tokens
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=prompts,
|
||||
@@ -1094,6 +1203,7 @@ class LLM:
|
||||
it is used to create the progress bar.
|
||||
If `False`, no progress bar is created.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
tokenization_kwargs: Overrides for `tokenizer.encode`.
|
||||
|
||||
Returns:
|
||||
A list of `EmbeddingRequestOutput` objects containing the
|
||||
@@ -1105,9 +1215,14 @@ class LLM:
|
||||
"Try converting the model using `--convert embed`."
|
||||
)
|
||||
|
||||
if truncate_prompt_tokens is not None:
|
||||
tokenization_kwargs = merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=truncate_prompt_tokens),
|
||||
)
|
||||
|
||||
items = self.encode(
|
||||
prompts,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
use_tqdm=use_tqdm,
|
||||
pooling_params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
@@ -1121,8 +1236,8 @@ class LLM:
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> list[ClassificationRequestOutput]:
|
||||
@@ -1137,13 +1252,15 @@ class LLM:
|
||||
prompts: The prompts to the LLM. You may pass a sequence of prompts
|
||||
for batch inference. See [PromptType][vllm.inputs.PromptType]
|
||||
for more details about the format of each prompt.
|
||||
pooling_params: The pooling parameters for pooling. If None, we
|
||||
use the default pooling parameters.
|
||||
use_tqdm: If `True`, shows a tqdm progress bar.
|
||||
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
|
||||
it is used to create the progress bar.
|
||||
If `False`, no progress bar is created.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
pooling_params: The pooling parameters for pooling. If None, we
|
||||
use the default pooling parameters.
|
||||
tokenization_kwargs: Overrides for `tokenizer.encode`.
|
||||
|
||||
Returns:
|
||||
A list of `ClassificationRequestOutput` objects containing the
|
||||
embedding vectors in the same order as the input prompts.
|
||||
@@ -1170,9 +1287,9 @@ class LLM:
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
/,
|
||||
*,
|
||||
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
|
||||
truncate_prompt_tokens: int | None = None,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> list[PoolingRequestOutput]:
|
||||
@@ -1183,13 +1300,15 @@ class LLM:
|
||||
prompts: The prompts to the LLM. You may pass a sequence of prompts
|
||||
for batch inference. See [PromptType][vllm.inputs.PromptType]
|
||||
for more details about the format of each prompt.
|
||||
pooling_params: The pooling parameters for pooling. If None, we
|
||||
use the default pooling parameters.
|
||||
use_tqdm: If `True`, shows a tqdm progress bar.
|
||||
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
|
||||
it is used to create the progress bar.
|
||||
If `False`, no progress bar is created.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
pooling_params: The pooling parameters for pooling. If None, we
|
||||
use the default pooling parameters.
|
||||
tokenization_kwargs: Overrides for `tokenizer.encode`.
|
||||
|
||||
Returns:
|
||||
A list of `PoolingRequestOutput` objects containing the
|
||||
pooled hidden states in the same order as the input prompts.
|
||||
@@ -1207,18 +1326,18 @@ class LLM:
|
||||
|
||||
def _embedding_score(
|
||||
self,
|
||||
tokenizer: TokenizerLike,
|
||||
text_1: list[str | TextPrompt | TokensPrompt],
|
||||
text_2: list[str | TextPrompt | TokensPrompt],
|
||||
truncate_prompt_tokens: int | None = None,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
pooling_params: PoolingParams | None = None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
text_1: list[SingletonPrompt],
|
||||
text_2: list[SingletonPrompt],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm],
|
||||
pooling_params: PoolingParams | None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
) -> list[ScoringRequestOutput]:
|
||||
encoded_output: list[PoolingRequestOutput] = self.encode(
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
encoded_output = self.encode(
|
||||
text_1 + text_2,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=lora_request,
|
||||
pooling_params=pooling_params,
|
||||
@@ -1226,14 +1345,16 @@ class LLM:
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
|
||||
encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
|
||||
encoded_output_1 = encoded_output[0 : len(text_1)]
|
||||
encoded_output_2 = encoded_output[len(text_1) :]
|
||||
|
||||
if len(encoded_output_1) == 1:
|
||||
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
|
||||
|
||||
scores = _cosine_similarity(
|
||||
tokenizer=tokenizer, embed_1=encoded_output_1, embed_2=encoded_output_2
|
||||
tokenizer=tokenizer,
|
||||
embed_1=encoded_output_1,
|
||||
embed_2=encoded_output_2,
|
||||
)
|
||||
|
||||
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
|
||||
@@ -1241,17 +1362,17 @@ class LLM:
|
||||
|
||||
def _cross_encoding_score(
|
||||
self,
|
||||
tokenizer: TokenizerLike,
|
||||
data_1: list[str] | list[ScoreContentPartParam],
|
||||
data_2: list[str] | list[ScoreContentPartParam],
|
||||
truncate_prompt_tokens: int | None = None,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
pooling_params: PoolingParams | None = None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
score_template: str | None = None,
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm],
|
||||
pooling_params: PoolingParams | None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None,
|
||||
tokenization_kwargs: dict[str, Any],
|
||||
score_template: str | None,
|
||||
) -> list[ScoringRequestOutput]:
|
||||
model_config = self.model_config
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError("Score API is not supported for Mistral tokenizer")
|
||||
@@ -1265,13 +1386,6 @@ class LLM:
|
||||
pooling_params.verify("score", model_config)
|
||||
pooling_params_list = list[PoolingParams]()
|
||||
|
||||
local_kwargs = tokenization_kwargs or {}
|
||||
tokenization_kwargs = local_kwargs.copy()
|
||||
|
||||
_validate_truncation_size(
|
||||
model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs
|
||||
)
|
||||
|
||||
prompts = list[PromptType]()
|
||||
|
||||
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
|
||||
@@ -1314,10 +1428,10 @@ class LLM:
|
||||
data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
|
||||
/,
|
||||
*,
|
||||
truncate_prompt_tokens: int | None = None,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
pooling_params: PoolingParams | None = None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
chat_template: str | None = None,
|
||||
) -> list[ScoringRequestOutput]:
|
||||
"""Generate similarity scores for all pairs `<text,text_pair>` or
|
||||
@@ -1344,20 +1458,22 @@ class LLM:
|
||||
the LLM. Can be text or multi-modal data. See [PromptType]
|
||||
[vllm.inputs.PromptType] for more details about the format of
|
||||
each prompt.
|
||||
pooling_params: The pooling parameters for pooling. If None, we
|
||||
use the default pooling parameters.
|
||||
use_tqdm: If `True`, shows a tqdm progress bar.
|
||||
If a callable (e.g., `functools.partial(tqdm, leave=False)`),
|
||||
it is used to create the progress bar.
|
||||
If `False`, no progress bar is created.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
pooling_params: The pooling parameters for pooling. If None, we
|
||||
use the default pooling parameters.
|
||||
chat_template: The chat template to use for the scoring. If None, we
|
||||
use the model's default chat template.
|
||||
tokenization_kwargs: Overrides for `tokenizer.encode`.
|
||||
Returns:
|
||||
A list of `ScoringRequestOutput` objects containing the
|
||||
generated scores in the same order as the input prompts.
|
||||
"""
|
||||
model_config = self.model_config
|
||||
|
||||
runner_type = model_config.runner_type
|
||||
if runner_type != "pooling":
|
||||
raise ValueError(
|
||||
@@ -1445,26 +1561,27 @@ class LLM:
|
||||
|
||||
_validate_score_input_lens(data_1, data_2) # type: ignore[arg-type]
|
||||
|
||||
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
|
||||
encode_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
if model_config.is_cross_encoder:
|
||||
return self._cross_encoding_score(
|
||||
tokenizer,
|
||||
data_1, # type: ignore[arg-type]
|
||||
data_2, # type: ignore[arg-type]
|
||||
truncate_prompt_tokens,
|
||||
use_tqdm,
|
||||
pooling_params,
|
||||
lora_request,
|
||||
use_tqdm=use_tqdm,
|
||||
pooling_params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=encode_kwargs,
|
||||
score_template=chat_template,
|
||||
)
|
||||
else:
|
||||
return self._embedding_score(
|
||||
tokenizer,
|
||||
data_1, # type: ignore[arg-type]
|
||||
data_2, # type: ignore[arg-type]
|
||||
truncate_prompt_tokens,
|
||||
use_tqdm,
|
||||
pooling_params,
|
||||
lora_request,
|
||||
use_tqdm=use_tqdm,
|
||||
pooling_params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=encode_kwargs,
|
||||
)
|
||||
|
||||
def start_profile(self) -> None:
|
||||
@@ -1530,42 +1647,79 @@ class LLM:
|
||||
|
||||
def _validate_and_add_requests(
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType] | DataPrompt,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
params: SamplingParams
|
||||
| Sequence[SamplingParams]
|
||||
| PoolingParams
|
||||
| Sequence[PoolingParams],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: Sequence[LoRARequest] | LoRARequest | None,
|
||||
priority: list[int] | None = None,
|
||||
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
priority: list[int] | None = None,
|
||||
) -> None:
|
||||
if isinstance(prompts, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
prompts = [prompts] # type: ignore[list-item]
|
||||
in_prompts = self._normalize_prompts(prompts)
|
||||
num_requests = len(in_prompts)
|
||||
|
||||
num_requests = len(prompts)
|
||||
if isinstance(params, Sequence) and len(params) != num_requests:
|
||||
raise ValueError("The lengths of prompts and params must be the same.")
|
||||
if isinstance(lora_request, Sequence) and len(lora_request) != num_requests:
|
||||
raise ValueError(
|
||||
"The lengths of prompts and lora_request must be the same."
|
||||
)
|
||||
if priority is not None and len(priority) != num_requests:
|
||||
raise ValueError(
|
||||
"The lengths of prompts "
|
||||
f"({num_requests}) and priority ({len(priority)}) "
|
||||
"must be the same."
|
||||
if isinstance(params, Sequence):
|
||||
if len(params) != num_requests:
|
||||
raise ValueError(
|
||||
f"The lengths of prompts ({params}) "
|
||||
f"and lora_request ({len(params)}) must be the same."
|
||||
)
|
||||
|
||||
engine_params = params
|
||||
else:
|
||||
engine_params = [params] * num_requests
|
||||
|
||||
if isinstance(lora_request, Sequence):
|
||||
if len(lora_request) != num_requests:
|
||||
raise ValueError(
|
||||
f"The lengths of prompts ({num_requests}) "
|
||||
f"and lora_request ({len(lora_request)}) must be the same."
|
||||
)
|
||||
|
||||
engine_lora_requests: Sequence[LoRARequest | None] = lora_request
|
||||
else:
|
||||
engine_lora_requests = [lora_request] * num_requests
|
||||
|
||||
if priority is not None:
|
||||
if len(priority) != num_requests:
|
||||
raise ValueError(
|
||||
f"The lengths of prompts ({num_requests}) "
|
||||
f"and priority ({len(priority)}) must be the same."
|
||||
)
|
||||
else:
|
||||
priority = [0] * num_requests
|
||||
|
||||
if any(param.truncate_prompt_tokens is not None for param in engine_params):
|
||||
# TODO: Remove this after deprecating `param.truncate_prompt_tokens`
|
||||
# Then, move the code from the `else` block to the top and let
|
||||
# `self._preprocess_completion` handle prompt normalization
|
||||
engine_prompts = [
|
||||
engine_prompt
|
||||
for in_prompt, param in zip(in_prompts, engine_params)
|
||||
for engine_prompt in self._preprocess_completion(
|
||||
[in_prompt],
|
||||
tokenization_kwargs=merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
|
||||
),
|
||||
)
|
||||
]
|
||||
else:
|
||||
engine_prompts = self._preprocess_completion(
|
||||
in_prompts,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
for sp in params if isinstance(params, Sequence) else (params,):
|
||||
for sp in engine_params:
|
||||
if isinstance(sp, SamplingParams):
|
||||
# We only care about the final output
|
||||
sp.output_kind = RequestOutputKind.FINAL_ONLY
|
||||
|
||||
# Add requests to the engine.
|
||||
it = prompts
|
||||
it = engine_prompts
|
||||
if use_tqdm:
|
||||
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
|
||||
it = tqdm_func(it, desc="Adding requests")
|
||||
@@ -1576,12 +1730,10 @@ class LLM:
|
||||
for i, prompt in enumerate(it):
|
||||
request_id = self._add_request(
|
||||
prompt,
|
||||
params[i] if isinstance(params, Sequence) else params,
|
||||
lora_request=lora_request[i]
|
||||
if isinstance(lora_request, Sequence)
|
||||
else lora_request,
|
||||
priority=priority[i] if priority else 0,
|
||||
engine_params[i],
|
||||
lora_request=engine_lora_requests[i],
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=priority[i],
|
||||
)
|
||||
added_request_ids.append(request_id)
|
||||
except Exception as e:
|
||||
@@ -1589,54 +1741,42 @@ class LLM:
|
||||
self.llm_engine.abort_request(added_request_ids, internal=True)
|
||||
raise e
|
||||
|
||||
def _process_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
engine_prompt: PromptType,
|
||||
params: SamplingParams | PoolingParams,
|
||||
*,
|
||||
lora_request: LoRARequest | None,
|
||||
priority: int,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> tuple[EngineCoreRequest, dict[str, Any]]:
|
||||
"""Use the Processor to process inputs for LLMEngine."""
|
||||
|
||||
local_kwargs = tokenization_kwargs or {}
|
||||
tokenization_kwargs = local_kwargs.copy()
|
||||
_validate_truncation_size(
|
||||
self.model_config.max_model_len,
|
||||
params.truncate_prompt_tokens,
|
||||
tokenization_kwargs,
|
||||
)
|
||||
|
||||
engine_request = self.input_processor.process_inputs(
|
||||
request_id,
|
||||
engine_prompt,
|
||||
params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=priority,
|
||||
)
|
||||
return engine_request, tokenization_kwargs
|
||||
|
||||
def _add_request(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
params: SamplingParams | PoolingParams,
|
||||
lora_request: LoRARequest | None = None,
|
||||
priority: int = 0,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
priority: int = 0,
|
||||
) -> str:
|
||||
prompt_text, _, _ = get_prompt_components(prompt)
|
||||
request_id = str(next(self.request_counter))
|
||||
|
||||
engine_request, tokenization_kwargs = self._process_inputs(
|
||||
if params.truncate_prompt_tokens is not None:
|
||||
params_type = type(params).__name__
|
||||
warnings.warn(
|
||||
f"The `truncate_prompt_tokens` parameter in `{params_type}` "
|
||||
"is deprecated and will be removed in v0.16. "
|
||||
"Please pass it via `tokenization_kwargs` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
tokenization_kwargs = merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
|
||||
)
|
||||
|
||||
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
|
||||
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
engine_request = self.input_processor.process_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
params,
|
||||
lora_request=lora_request,
|
||||
priority=priority,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
self.llm_engine.add_request(
|
||||
|
||||
@@ -13,12 +13,13 @@ from openai.types.chat.chat_completion_audio import (
|
||||
ChatCompletionAudio as OpenAIChatCompletionAudio,
|
||||
)
|
||||
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
|
||||
from pydantic import (
|
||||
Field,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormatOption,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
AnyResponseFormat,
|
||||
DeltaMessage,
|
||||
@@ -36,6 +37,7 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
|
||||
from vllm.sampling_params import (
|
||||
BeamSearchParams,
|
||||
RequestOutputKind,
|
||||
@@ -348,6 +350,43 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
|
||||
# --8<-- [end:chat-completion-extra-params]
|
||||
|
||||
def build_chat_params(
|
||||
self,
|
||||
default_template: str | None,
|
||||
default_template_content_format: ChatTemplateContentFormatOption,
|
||||
) -> ChatParams:
|
||||
return ChatParams(
|
||||
chat_template=self.chat_template or default_template,
|
||||
chat_template_content_format=default_template_content_format,
|
||||
chat_template_kwargs=merge_kwargs(
|
||||
self.chat_template_kwargs,
|
||||
dict(
|
||||
add_generation_prompt=self.add_generation_prompt,
|
||||
continue_final_message=self.continue_final_message,
|
||||
documents=self.documents,
|
||||
reasoning_effort=self.reasoning_effort,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
if self.max_completion_tokens is not None:
|
||||
max_output_tokens: int | None = self.max_completion_tokens
|
||||
max_output_tokens_param = "max_completion_tokens"
|
||||
else:
|
||||
max_output_tokens = self.max_tokens
|
||||
max_output_tokens_param = "max_tokens"
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=max_output_tokens or 0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
needs_detokenization=bool(self.echo and not self.return_token_ids),
|
||||
max_total_tokens_param="max_model_len",
|
||||
max_output_tokens_param=max_output_tokens_param,
|
||||
)
|
||||
|
||||
# Default sampling parameters for chat completion requests
|
||||
_DEFAULT_SAMPLING_PARAMS: dict = {
|
||||
"repetition_penalty": 1.0,
|
||||
|
||||
@@ -67,7 +67,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
)
|
||||
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
|
||||
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
|
||||
from vllm.inputs.parse import get_prompt_components
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
@@ -185,8 +185,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
renderer = self.engine_client.renderer
|
||||
|
||||
# Create a minimal dummy request
|
||||
dummy_request = ChatCompletionRequest(
|
||||
messages=[{"role": "user", "content": "warmup"}],
|
||||
@@ -201,18 +199,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# 3. Tokenizer initialization for chat
|
||||
await self._preprocess_chat(
|
||||
dummy_request,
|
||||
renderer,
|
||||
dummy_request.messages,
|
||||
chat_template=self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
add_generation_prompt=True,
|
||||
continue_final_message=False,
|
||||
tool_dicts=None,
|
||||
documents=None,
|
||||
chat_template_kwargs=None,
|
||||
default_chat_template_kwargs=self.default_chat_template_kwargs,
|
||||
tool_parser=None,
|
||||
add_special_tokens=False,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=self.default_chat_template_kwargs,
|
||||
)
|
||||
|
||||
elapsed = (time.perf_counter() - start_time) * 1000
|
||||
@@ -225,7 +215,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
async def render_chat_request(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> tuple[list[ConversationMessage], list[Any]] | ErrorResponse:
|
||||
) -> (
|
||||
tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]
|
||||
| ErrorResponse
|
||||
):
|
||||
"""
|
||||
render chat request by validating and preprocessing inputs.
|
||||
|
||||
@@ -302,23 +295,14 @@ class OpenAIServingChat(OpenAIServing):
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
chat_template_kwargs = request.chat_template_kwargs or {}
|
||||
chat_template_kwargs.update(reasoning_effort=request.reasoning_effort)
|
||||
|
||||
conversation, engine_prompts = await self._preprocess_chat(
|
||||
request,
|
||||
renderer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=self.default_chat_template_kwargs,
|
||||
tool_dicts=tool_dicts,
|
||||
documents=request.documents,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
default_chat_template_kwargs=self.default_chat_template_kwargs,
|
||||
tool_parser=tool_parser,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
# For GPT-OSS.
|
||||
@@ -428,11 +412,15 @@ class OpenAIServingChat(OpenAIServing):
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
else:
|
||||
engine_request, tokenization_kwargs = await self._process_inputs(
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
engine_request = self.input_processor.process_inputs(
|
||||
sub_request_id,
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
|
||||
@@ -9,11 +9,9 @@ from dataclasses import replace
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import (
|
||||
Field,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
AnyResponseFormat,
|
||||
LegacyStructuralTagResponseFormat,
|
||||
@@ -27,6 +25,7 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.sampling_params import (
|
||||
BeamSearchParams,
|
||||
RequestOutputKind,
|
||||
@@ -178,6 +177,17 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
|
||||
# --8<-- [end:completion-extra-params]
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=self.max_tokens or 0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
needs_detokenization=bool(self.echo and not self.return_token_ids),
|
||||
max_total_tokens_param="max_model_len",
|
||||
max_output_tokens_param="max_tokens",
|
||||
)
|
||||
|
||||
# Default sampling parameters for completion requests
|
||||
_DEFAULT_SAMPLING_PARAMS: dict = {
|
||||
"repetition_penalty": 1.0,
|
||||
|
||||
@@ -32,7 +32,6 @@ from vllm.entrypoints.openai.engine.serving import (
|
||||
clamp_prompt_logprobs,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
|
||||
@@ -111,11 +110,10 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
)
|
||||
|
||||
try:
|
||||
renderer = self._get_completion_renderer()
|
||||
engine_prompts = await renderer.render_prompt_and_embeds(
|
||||
prompt_or_prompts=request.prompt,
|
||||
engine_prompts = await self._preprocess_completion(
|
||||
request,
|
||||
prompt_input=request.prompt,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
config=self._build_render_config(request),
|
||||
)
|
||||
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
@@ -203,10 +201,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
# Mypy inconsistently requires this second cast in different
|
||||
# environments. It shouldn't be necessary (redundant from above)
|
||||
# but pre-commit in CI fails without it.
|
||||
engine_prompt = cast(EmbedsPrompt | TokensPrompt, engine_prompt)
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
generator = self.beam_search(
|
||||
prompt=engine_prompt,
|
||||
@@ -216,11 +210,15 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
else:
|
||||
engine_request, tokenization_kwargs = await self._process_inputs(
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
engine_request = self.input_processor.process_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
@@ -709,26 +707,3 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
tokens=out_tokens,
|
||||
top_logprobs=out_top_logprobs,
|
||||
)
|
||||
|
||||
def _build_render_config(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
max_input_length: int | None = None,
|
||||
) -> RenderConfig:
|
||||
# Validate max_tokens before using it
|
||||
if request.max_tokens is not None and request.max_tokens > self.max_model_len:
|
||||
raise VLLMValidationError(
|
||||
f"'max_tokens' ({request.max_tokens}) cannot be greater than "
|
||||
f"the model's maximum context length ({self.max_model_len}).",
|
||||
parameter="max_tokens",
|
||||
value=request.max_tokens,
|
||||
)
|
||||
|
||||
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
|
||||
return RenderConfig(
|
||||
max_length=max_input_tokens_len,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
cache_salt=request.cache_salt,
|
||||
needs_detokenization=bool(request.echo and not request.return_token_ids),
|
||||
)
|
||||
|
||||
@@ -16,9 +16,7 @@ from pydantic import (
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import (
|
||||
SamplingParams,
|
||||
)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@ import json
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping
|
||||
from collections.abc import AsyncGenerator, Callable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar, cast
|
||||
from typing import Any, ClassVar, Generic, Protocol, TypeAlias, TypeVar
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
@@ -20,6 +20,7 @@ from starlette.datastructures import Headers
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
@@ -86,7 +87,6 @@ from vllm.entrypoints.pooling.score.protocol import (
|
||||
ScoreResponse,
|
||||
ScoreTextRequest,
|
||||
)
|
||||
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
|
||||
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
|
||||
from vllm.entrypoints.serve.tokenize.protocol import (
|
||||
DetokenizeRequest,
|
||||
@@ -94,13 +94,9 @@ from vllm.entrypoints.serve.tokenize.protocol import (
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeResponse,
|
||||
)
|
||||
from vllm.entrypoints.utils import (
|
||||
_validate_truncation_size,
|
||||
get_max_tokens,
|
||||
sanitize_message,
|
||||
)
|
||||
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import PromptType, TokensPrompt
|
||||
from vllm.inputs.data import EmbedsPrompt, PromptType, TokensPrompt
|
||||
from vllm.inputs.parse import (
|
||||
get_prompt_components,
|
||||
is_explicit_encoder_decoder_prompt,
|
||||
@@ -112,7 +108,7 @@ from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from vllm.renderers import RendererLike
|
||||
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser, ToolParserManager
|
||||
@@ -123,11 +119,9 @@ from vllm.tracing import (
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.async_utils import (
|
||||
AsyncMicrobatchTokenizer,
|
||||
collect_from_async_generator,
|
||||
merge_async_iterators,
|
||||
)
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
|
||||
class GenerationError(Exception):
|
||||
@@ -140,6 +134,21 @@ class GenerationError(Exception):
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RendererRequest(Protocol):
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RendererChatRequest(RendererRequest, Protocol):
|
||||
def build_chat_params(
|
||||
self,
|
||||
default_template: str | None,
|
||||
default_template_content_format: ChatTemplateContentFormatOption,
|
||||
) -> ChatParams:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
CompletionLikeRequest: TypeAlias = (
|
||||
CompletionRequest
|
||||
| TokenizeCompletionRequest
|
||||
@@ -158,7 +167,9 @@ ChatLikeRequest: TypeAlias = (
|
||||
| ClassificationChatRequest
|
||||
| PoolingChatRequest
|
||||
)
|
||||
|
||||
SpeechToTextRequest: TypeAlias = TranscriptionRequest | TranslationRequest
|
||||
|
||||
AnyRequest: TypeAlias = (
|
||||
CompletionLikeRequest
|
||||
| ChatLikeRequest
|
||||
@@ -193,7 +204,7 @@ class ServeContext(Generic[RequestT]):
|
||||
request_id: str
|
||||
created_time: int = field(default_factory=lambda: int(time.time()))
|
||||
lora_request: LoRARequest | None = None
|
||||
engine_prompts: list[TokensPrompt] | None = None
|
||||
engine_prompts: list[TokensPrompt | EmbedsPrompt] | None = None
|
||||
|
||||
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
|
||||
None
|
||||
@@ -227,7 +238,6 @@ class OpenAIServing:
|
||||
self.request_logger = request_logger
|
||||
self.return_tokens_as_token_ids = return_tokens_as_token_ids
|
||||
|
||||
self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
|
||||
self.log_error_stack = log_error_stack
|
||||
|
||||
self.input_processor = self.models.input_processor
|
||||
@@ -519,41 +529,6 @@ class OpenAIServing:
|
||||
prompt_logprobs=None,
|
||||
)
|
||||
|
||||
def _get_completion_renderer(self) -> BaseRenderer:
|
||||
"""
|
||||
Get a Renderer instance with the provided tokenizer.
|
||||
Uses shared async tokenizer pool for efficiency.
|
||||
"""
|
||||
return CompletionRenderer(
|
||||
model_config=self.model_config,
|
||||
tokenizer=self.renderer.tokenizer,
|
||||
async_tokenizer_pool=self._async_tokenizer_pool,
|
||||
)
|
||||
|
||||
def _build_render_config(
|
||||
self,
|
||||
request: Any,
|
||||
) -> RenderConfig:
|
||||
"""
|
||||
Build and return a `RenderConfig` for an endpoint.
|
||||
|
||||
Used by the renderer to control how prompts are prepared
|
||||
(e.g., tokenization and length handling). Endpoints should
|
||||
implement this with logic appropriate to their request type.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
|
||||
"""
|
||||
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
|
||||
given tokenizer.
|
||||
"""
|
||||
async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
|
||||
if async_tokenizer is None:
|
||||
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
|
||||
self._async_tokenizer_pool[tokenizer] = async_tokenizer
|
||||
return async_tokenizer
|
||||
|
||||
async def _preprocess(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
@@ -912,71 +887,6 @@ class OpenAIServing:
|
||||
message_types.add(content_dict["type"].split("_")[0])
|
||||
return message_types
|
||||
|
||||
async def _normalize_prompt_text_to_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
prompt: str,
|
||||
tokenizer: TokenizerLike,
|
||||
add_special_tokens: bool,
|
||||
) -> TokensPrompt:
|
||||
async_tokenizer = self._get_async_tokenizer(tokenizer)
|
||||
|
||||
if (
|
||||
self.model_config.encoder_config is not None
|
||||
and self.model_config.encoder_config.get("do_lower_case", False)
|
||||
):
|
||||
prompt = prompt.lower()
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
|
||||
|
||||
if truncate_prompt_tokens is None:
|
||||
encoded = await async_tokenizer(
|
||||
prompt, add_special_tokens=add_special_tokens
|
||||
)
|
||||
elif truncate_prompt_tokens < 0:
|
||||
# Negative means we cap at the model's max length
|
||||
encoded = await async_tokenizer(
|
||||
prompt,
|
||||
add_special_tokens=add_special_tokens,
|
||||
truncation=True,
|
||||
max_length=self.max_model_len,
|
||||
)
|
||||
else:
|
||||
encoded = await async_tokenizer(
|
||||
prompt,
|
||||
add_special_tokens=add_special_tokens,
|
||||
truncation=True,
|
||||
max_length=truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
input_ids = encoded.input_ids
|
||||
input_text = prompt
|
||||
|
||||
return self._validate_input(request, input_ids, input_text)
|
||||
|
||||
async def _normalize_prompt_tokens_to_input(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
prompt_ids: list[int],
|
||||
tokenizer: TokenizerLike | None,
|
||||
) -> TokensPrompt:
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
|
||||
|
||||
if truncate_prompt_tokens is None:
|
||||
input_ids = prompt_ids
|
||||
elif truncate_prompt_tokens < 0:
|
||||
input_ids = prompt_ids[-self.max_model_len :]
|
||||
else:
|
||||
input_ids = prompt_ids[-truncate_prompt_tokens:]
|
||||
|
||||
if tokenizer is None:
|
||||
input_text = ""
|
||||
else:
|
||||
async_tokenizer = self._get_async_tokenizer(tokenizer)
|
||||
input_text = await async_tokenizer.decode(input_ids)
|
||||
|
||||
return self._validate_input(request, input_ids, input_text)
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
request: object,
|
||||
@@ -1061,50 +971,6 @@ class OpenAIServing:
|
||||
|
||||
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
|
||||
|
||||
async def _tokenize_prompt_input_async(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: TokenizerLike,
|
||||
prompt_input: str | list[int],
|
||||
add_special_tokens: bool = True,
|
||||
) -> TokensPrompt:
|
||||
"""
|
||||
A simpler implementation that tokenizes a single prompt input.
|
||||
"""
|
||||
async for result in self._tokenize_prompt_inputs_async(
|
||||
request,
|
||||
tokenizer,
|
||||
[prompt_input],
|
||||
add_special_tokens=add_special_tokens,
|
||||
):
|
||||
return result
|
||||
raise ValueError("No results yielded from tokenization")
|
||||
|
||||
async def _tokenize_prompt_inputs_async(
|
||||
self,
|
||||
request: AnyRequest,
|
||||
tokenizer: TokenizerLike,
|
||||
prompt_inputs: Iterable[str | list[int]],
|
||||
add_special_tokens: bool = True,
|
||||
) -> AsyncGenerator[TokensPrompt, None]:
|
||||
"""
|
||||
A simpler implementation that tokenizes multiple prompt inputs.
|
||||
"""
|
||||
for prompt in prompt_inputs:
|
||||
if isinstance(prompt, str):
|
||||
yield await self._normalize_prompt_text_to_input(
|
||||
request,
|
||||
prompt=prompt,
|
||||
tokenizer=tokenizer,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
else:
|
||||
yield await self._normalize_prompt_tokens_to_input(
|
||||
request,
|
||||
prompt_ids=prompt,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
def _validate_chat_template(
|
||||
self,
|
||||
request_chat_template: str | None,
|
||||
@@ -1137,131 +1003,94 @@ class OpenAIServing:
|
||||
# Apply server defaults first, then request kwargs override.
|
||||
return default_chat_template_kwargs | request_chat_template_kwargs
|
||||
|
||||
async def _preprocess_completion(
|
||||
self,
|
||||
request: RendererRequest,
|
||||
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
|
||||
prompt_embeds: bytes | list[bytes] | None,
|
||||
) -> list[TokensPrompt | EmbedsPrompt]:
|
||||
renderer = self.renderer
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
|
||||
in_prompts = await renderer.render_completions_async(
|
||||
prompt_input, prompt_embeds
|
||||
)
|
||||
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
|
||||
|
||||
extra_items = {
|
||||
k: v
|
||||
for k in ("mm_processor_kwargs", "cache_salt")
|
||||
if (v := getattr(request, k, None)) is not None
|
||||
}
|
||||
for prompt in engine_prompts:
|
||||
prompt.update(extra_items) # type: ignore
|
||||
|
||||
return engine_prompts
|
||||
|
||||
async def _preprocess_chat(
|
||||
self,
|
||||
request: ChatLikeRequest | ResponsesRequest,
|
||||
renderer: RendererLike,
|
||||
request: RendererChatRequest,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
add_generation_prompt: bool = True,
|
||||
continue_final_message: bool = False,
|
||||
default_template: str | None,
|
||||
default_template_content_format: ChatTemplateContentFormatOption,
|
||||
default_template_kwargs: dict[str, Any] | None,
|
||||
tool_dicts: list[dict[str, Any]] | None = None,
|
||||
documents: list[dict[str, str]] | None = None,
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
default_chat_template_kwargs: dict[str, Any] | None = None,
|
||||
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
|
||||
add_special_tokens: bool = False,
|
||||
) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
|
||||
chat_template_kwargs = {
|
||||
"chat_template": chat_template,
|
||||
"add_generation_prompt": add_generation_prompt,
|
||||
"continue_final_message": continue_final_message,
|
||||
"tools": tool_dicts,
|
||||
"documents": documents,
|
||||
**(chat_template_kwargs or {}),
|
||||
}
|
||||
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
|
||||
chat_template_kwargs,
|
||||
default_chat_template_kwargs,
|
||||
)
|
||||
|
||||
# Use the async tokenizer in `OpenAIServing` if possible.
|
||||
# Later we can move it into the renderer so that we can return both
|
||||
# text and token IDs in the same prompt from `render_messages_async`
|
||||
# which is used for logging and `enable_response_messages`.
|
||||
) -> tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]:
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
conversation, engine_prompt = await renderer.render_messages_async(
|
||||
messages,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
tokenize=(
|
||||
chat_template_kwargs.pop("tokenize", False)
|
||||
or isinstance(renderer.tokenizer, MistralTokenizer)
|
||||
renderer = self.renderer
|
||||
|
||||
default_template_kwargs = merge_kwargs(
|
||||
default_template_kwargs,
|
||||
dict(
|
||||
tools=tool_dicts,
|
||||
tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
|
||||
),
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
|
||||
if "prompt_token_ids" not in engine_prompt:
|
||||
extra_data = engine_prompt
|
||||
engine_prompt = await self._tokenize_prompt_input_async(
|
||||
request,
|
||||
renderer.get_tokenizer(),
|
||||
engine_prompt["prompt"],
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
# Fill in other keys like MM data
|
||||
engine_prompt.update(extra_data) # type: ignore
|
||||
else:
|
||||
self._validate_input(
|
||||
request=request,
|
||||
input_ids=engine_prompt["prompt_token_ids"], # type: ignore
|
||||
input_text="",
|
||||
)
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
chat_params = request.build_chat_params(
|
||||
default_template, default_template_content_format
|
||||
).with_defaults(default_template_kwargs)
|
||||
|
||||
engine_prompt = cast(TokensPrompt, engine_prompt)
|
||||
conversation, prompt = await renderer.render_messages_async(
|
||||
messages, chat_params
|
||||
)
|
||||
engine_prompt = await renderer.tokenize_prompt_async(prompt, tok_params)
|
||||
|
||||
if request.mm_processor_kwargs is not None:
|
||||
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
|
||||
if (cache_salt := getattr(request, "cache_salt", None)) is not None:
|
||||
engine_prompt["cache_salt"] = cache_salt
|
||||
extra_items = {
|
||||
k: v
|
||||
for k in ("mm_processor_kwargs", "cache_salt")
|
||||
if (v := getattr(request, k, None)) is not None
|
||||
}
|
||||
engine_prompt.update(extra_items) # type: ignore
|
||||
|
||||
# tool parsing is done only if a tool_parser has been set and if
|
||||
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
|
||||
# is set, we want to prevent parsing a tool_call hallucinated by the LLM
|
||||
should_parse_tools = tool_parser is not None and (
|
||||
hasattr(request, "tool_choice") and request.tool_choice != "none"
|
||||
)
|
||||
if tool_parser is not None:
|
||||
tool_choice = getattr(request, "tool_choice", "none")
|
||||
if tool_choice != "none":
|
||||
if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
|
||||
msg = (
|
||||
"Tool usage is only supported for Chat Completions API "
|
||||
"or Responses API requests."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
if should_parse_tools:
|
||||
if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
|
||||
msg = (
|
||||
"Tool usage is only supported for Chat Completions API "
|
||||
"or Responses API requests."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
tokenizer = renderer.get_tokenizer()
|
||||
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore
|
||||
# TODO: Update adjust_request to accept ResponsesRequest
|
||||
tokenizer = renderer.get_tokenizer()
|
||||
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type]
|
||||
|
||||
return conversation, [engine_prompt]
|
||||
|
||||
async def _process_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
engine_prompt: PromptType,
|
||||
params: SamplingParams | PoolingParams,
|
||||
*,
|
||||
lora_request: LoRARequest | None,
|
||||
trace_headers: Mapping[str, str] | None,
|
||||
priority: int,
|
||||
data_parallel_rank: int | None = None,
|
||||
) -> tuple[EngineCoreRequest, dict[str, Any]]:
|
||||
"""Use the Processor to process inputs for AsyncLLM."""
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
_validate_truncation_size(
|
||||
self.max_model_len, params.truncate_prompt_tokens, tokenization_kwargs
|
||||
)
|
||||
|
||||
engine_request = self.input_processor.process_inputs(
|
||||
request_id,
|
||||
engine_prompt,
|
||||
params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
return engine_request, tokenization_kwargs
|
||||
|
||||
async def _render_next_turn(
|
||||
self,
|
||||
request: ResponsesRequest,
|
||||
renderer: RendererLike,
|
||||
messages: list[ResponseInputOutputItem],
|
||||
tool_dicts: list[dict[str, Any]] | None,
|
||||
tool_parser,
|
||||
tool_parser: Callable[[TokenizerLike], ToolParser] | None,
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
):
|
||||
@@ -1271,24 +1100,25 @@ class OpenAIServing:
|
||||
|
||||
_, engine_prompts = await self._preprocess_chat(
|
||||
request,
|
||||
renderer,
|
||||
new_messages,
|
||||
default_template=chat_template,
|
||||
default_template_content_format=chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
tool_dicts=tool_dicts,
|
||||
tool_parser=tool_parser,
|
||||
chat_template=chat_template,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
)
|
||||
return engine_prompts
|
||||
|
||||
async def _generate_with_builtin_tools(
|
||||
self,
|
||||
request_id: str,
|
||||
engine_prompt: TokensPrompt,
|
||||
engine_prompt: TokensPrompt | EmbedsPrompt,
|
||||
sampling_params: SamplingParams,
|
||||
tok_params: TokenizeParams,
|
||||
context: ConversationContext,
|
||||
lora_request: LoRARequest | None = None,
|
||||
priority: int = 0,
|
||||
**kwargs,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
prompt_text, _, _ = get_prompt_components(engine_prompt)
|
||||
|
||||
@@ -1297,18 +1127,21 @@ class OpenAIServing:
|
||||
while True:
|
||||
# Ensure that each sub-request has a unique request id.
|
||||
sub_request_id = f"{request_id}_{sub_request}"
|
||||
|
||||
self._log_inputs(
|
||||
sub_request_id,
|
||||
engine_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
trace_headers = kwargs.get("trace_headers")
|
||||
engine_request, tokenization_kwargs = await self._process_inputs(
|
||||
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
engine_request = self.input_processor.process_inputs(
|
||||
sub_request_id,
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
)
|
||||
@@ -1318,10 +1151,10 @@ class OpenAIServing:
|
||||
sampling_params,
|
||||
sub_request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
prompt_text=prompt_text,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async for res in generator:
|
||||
@@ -1350,7 +1183,6 @@ class OpenAIServing:
|
||||
elif isinstance(context, ParsableContext):
|
||||
engine_prompts = await self._render_next_turn(
|
||||
context.request,
|
||||
context.renderer,
|
||||
context.parser.response_messages,
|
||||
context.tool_dicts,
|
||||
context.tool_parser_cls,
|
||||
|
||||
@@ -43,7 +43,6 @@ from vllm.entrypoints.openai.responses.protocol import (
|
||||
from vllm.entrypoints.openai.responses.utils import construct_tool_dicts
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
||||
from vllm.renderers import RendererLike
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import ToolParser
|
||||
from vllm.utils import random_uuid
|
||||
@@ -261,7 +260,7 @@ class ParsableContext(ConversationContext):
|
||||
self,
|
||||
*,
|
||||
response_messages: list[ResponseInputOutputItem],
|
||||
renderer: RendererLike,
|
||||
tokenizer: TokenizerLike,
|
||||
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None,
|
||||
request: ResponsesRequest,
|
||||
available_tools: list[str] | None,
|
||||
@@ -280,7 +279,6 @@ class ParsableContext(ConversationContext):
|
||||
if reasoning_parser_cls is None:
|
||||
raise ValueError("reasoning_parser_cls must be provided.")
|
||||
|
||||
tokenizer = renderer.get_tokenizer()
|
||||
self.parser = get_responses_parser_for_simple_context(
|
||||
tokenizer=tokenizer,
|
||||
reasoning_parser_cls=reasoning_parser_cls,
|
||||
@@ -290,8 +288,6 @@ class ParsableContext(ConversationContext):
|
||||
)
|
||||
self.tool_parser_cls = tool_parser_cls
|
||||
self.request = request
|
||||
self.renderer = renderer
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.available_tools = available_tools or []
|
||||
self._tool_sessions: dict[str, ClientSession | Tool] = {}
|
||||
|
||||
@@ -59,12 +59,15 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
OpenAIBaseModel,
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormatOption,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
|
||||
from vllm.sampling_params import (
|
||||
RequestOutputKind,
|
||||
SamplingParams,
|
||||
@@ -230,6 +233,42 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None
|
||||
# --8<-- [end:responses-extra-params]
|
||||
|
||||
def build_chat_params(
|
||||
self,
|
||||
default_template: str | None,
|
||||
default_template_content_format: ChatTemplateContentFormatOption,
|
||||
) -> ChatParams:
|
||||
from .utils import should_continue_final_message
|
||||
|
||||
# Check if we should continue the final message (partial completion)
|
||||
# This enables Anthropic-style partial message completion where the
|
||||
# user provides an incomplete assistant message to continue from.
|
||||
continue_final = should_continue_final_message(self.input)
|
||||
|
||||
reasoning = self.reasoning
|
||||
|
||||
return ChatParams(
|
||||
chat_template=default_template,
|
||||
chat_template_content_format=default_template_content_format,
|
||||
chat_template_kwargs=merge_kwargs( # To remove unset values
|
||||
{},
|
||||
dict(
|
||||
add_generation_prompt=not continue_final,
|
||||
continue_final_message=continue_final,
|
||||
reasoning_effort=None if reasoning is None else reasoning.effort,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=self.max_output_tokens or 0,
|
||||
truncate_prompt_tokens=-1 if self.truncation != "disabled" else None,
|
||||
max_total_tokens_param="max_model_len",
|
||||
max_output_tokens_param="max_output_tokens",
|
||||
)
|
||||
|
||||
_DEFAULT_SAMPLING_PARAMS = {
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
|
||||
@@ -114,16 +114,15 @@ from vllm.entrypoints.openai.responses.utils import (
|
||||
construct_input_messages,
|
||||
construct_tool_dicts,
|
||||
extract_tool_types,
|
||||
should_continue_final_message,
|
||||
)
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
|
||||
from vllm.inputs.parse import get_prompt_len
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob as SampleLogprob
|
||||
from vllm.logprobs import SampleLogprobs
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.renderers import RendererLike
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils import random_uuid
|
||||
@@ -291,13 +290,14 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
self.tool_server = tool_server
|
||||
|
||||
def _validate_generator_input(
|
||||
self, engine_prompt: TokensPrompt
|
||||
self,
|
||||
engine_prompt: TokensPrompt | EmbedsPrompt,
|
||||
) -> ErrorResponse | None:
|
||||
"""Add validations to the input to the generator here."""
|
||||
if self.max_model_len <= len(engine_prompt["prompt_token_ids"]):
|
||||
prompt_len = get_prompt_len(engine_prompt)
|
||||
if self.max_model_len <= prompt_len:
|
||||
error_message = (
|
||||
"The engine prompt length"
|
||||
f" {len(engine_prompt['prompt_token_ids'])} "
|
||||
f"The engine prompt length {prompt_len} "
|
||||
f"exceeds the max_model_len {self.max_model_len}. "
|
||||
"Please reduce prompt."
|
||||
)
|
||||
@@ -307,6 +307,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
param="input",
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _validate_create_responses_input(
|
||||
@@ -387,8 +388,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
model_name = self.models.model_name(lora_request)
|
||||
renderer = self.engine_client.renderer
|
||||
tokenizer = renderer.get_tokenizer()
|
||||
|
||||
if self.use_harmony:
|
||||
messages, engine_prompts = self._make_request_with_harmony(
|
||||
@@ -396,7 +395,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
)
|
||||
else:
|
||||
messages, engine_prompts = await self._make_request(
|
||||
request, prev_response, renderer
|
||||
request, prev_response
|
||||
)
|
||||
|
||||
except (
|
||||
@@ -431,6 +430,9 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
assert len(builtin_tool_list) == 0
|
||||
available_tools = []
|
||||
try:
|
||||
renderer = self.engine_client.renderer
|
||||
tokenizer = renderer.get_tokenizer()
|
||||
|
||||
for engine_prompt in engine_prompts:
|
||||
maybe_error = self._validate_generator_input(engine_prompt)
|
||||
if maybe_error is not None:
|
||||
@@ -446,6 +448,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens, self.default_sampling_params
|
||||
)
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
@@ -465,7 +468,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
# tokens during generation instead of at the end
|
||||
context = ParsableContext(
|
||||
response_messages=messages,
|
||||
renderer=renderer,
|
||||
tokenizer=tokenizer,
|
||||
reasoning_parser_cls=self.reasoning_parser,
|
||||
request=request,
|
||||
tool_parser_cls=self.tool_parser,
|
||||
@@ -495,6 +498,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
request_id=request.request_id,
|
||||
engine_prompt=engine_prompt,
|
||||
sampling_params=sampling_params,
|
||||
tok_params=tok_params,
|
||||
context=context,
|
||||
lora_request=lora_request,
|
||||
priority=request.priority,
|
||||
@@ -596,7 +600,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
self,
|
||||
request: ResponsesRequest,
|
||||
prev_response: ResponsesResponse | None,
|
||||
renderer: RendererLike,
|
||||
):
|
||||
tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
|
||||
# Construct the input messages.
|
||||
@@ -606,30 +609,15 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
prev_msg=self.msg_store.get(prev_response.id) if prev_response else None,
|
||||
prev_response_output=prev_response.output if prev_response else None,
|
||||
)
|
||||
# Check if we should continue the final message (partial completion)
|
||||
# This enables Anthropic-style partial message completion where the
|
||||
# user provides an incomplete assistant message to continue from.
|
||||
continue_final = should_continue_final_message(request.input)
|
||||
chat_template_kwargs = dict(
|
||||
reasoning_effort=None
|
||||
if request.reasoning is None
|
||||
else request.reasoning.effort
|
||||
)
|
||||
|
||||
_, engine_prompts = await self._preprocess_chat(
|
||||
request,
|
||||
renderer,
|
||||
messages,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
tool_dicts=tool_dicts,
|
||||
tool_parser=self.tool_parser,
|
||||
chat_template=self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
# When continuing a partial message, we set continue_final_message=True
|
||||
# and add_generation_prompt=False so the model continues the message
|
||||
# rather than starting a new one.
|
||||
add_generation_prompt=not continue_final,
|
||||
continue_final_message=continue_final,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
)
|
||||
return messages, engine_prompts
|
||||
|
||||
|
||||
@@ -8,8 +8,12 @@ from pydantic import Field, model_validator
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.config.pooler import get_use_activation
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormatOption,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
|
||||
from vllm.renderers import ChatParams, merge_kwargs
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
|
||||
|
||||
@@ -119,6 +123,23 @@ class ChatRequestMixin(OpenAIBaseModel):
|
||||
)
|
||||
return data
|
||||
|
||||
def build_chat_params(
|
||||
self,
|
||||
default_template: str | None,
|
||||
default_template_content_format: ChatTemplateContentFormatOption,
|
||||
) -> ChatParams:
|
||||
return ChatParams(
|
||||
chat_template=self.chat_template or default_template,
|
||||
chat_template_content_format=default_template_content_format,
|
||||
chat_template_kwargs=merge_kwargs(
|
||||
self.chat_template_kwargs,
|
||||
dict(
|
||||
add_generation_prompt=self.add_generation_prompt,
|
||||
continue_final_message=self.continue_final_message,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class EncodingRequestMixin(OpenAIBaseModel):
|
||||
# --8<-- [start:encoding-params]
|
||||
|
||||
@@ -4,10 +4,9 @@
|
||||
import time
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from pydantic import (
|
||||
Field,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.entrypoints.pooling.base.protocol import (
|
||||
ChatRequestMixin,
|
||||
@@ -15,13 +14,24 @@ from vllm.entrypoints.pooling.base.protocol import (
|
||||
CompletionRequestMixin,
|
||||
PoolingBasicRequestMixin,
|
||||
)
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
class ClassificationCompletionRequest(
|
||||
PoolingBasicRequestMixin, CompletionRequestMixin, ClassifyRequestMixin
|
||||
):
|
||||
pass
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
)
|
||||
|
||||
|
||||
class ClassificationChatRequest(
|
||||
@@ -33,6 +43,18 @@ class ClassificationChatRequest(
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
)
|
||||
|
||||
|
||||
ClassificationRequest: TypeAlias = (
|
||||
ClassificationCompletionRequest | ClassificationChatRequest
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Final, cast
|
||||
from typing import Final, TypeAlias
|
||||
|
||||
import jinja2
|
||||
import numpy as np
|
||||
@@ -21,15 +20,14 @@ from vllm.entrypoints.pooling.classify.protocol import (
|
||||
ClassificationRequest,
|
||||
ClassificationResponse,
|
||||
)
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
|
||||
from vllm.outputs import ClassificationOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
ClassificationServeContext = ServeContext[ClassificationRequest]
|
||||
ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
|
||||
|
||||
|
||||
class ServingClassification(OpenAIServing):
|
||||
@@ -77,34 +75,18 @@ class ServingClassification(OpenAIServing):
|
||||
if error_check_ret:
|
||||
return error_check_ret
|
||||
|
||||
_, engine_prompts = await self._preprocess_chat(
|
||||
_, ctx.engine_prompts = await self._preprocess_chat(
|
||||
ctx.request,
|
||||
self.renderer,
|
||||
ctx.request.messages,
|
||||
chat_template=ctx.request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
add_generation_prompt=ctx.request.add_generation_prompt,
|
||||
continue_final_message=ctx.request.continue_final_message,
|
||||
add_special_tokens=ctx.request.add_special_tokens,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
)
|
||||
ctx.engine_prompts = engine_prompts
|
||||
|
||||
elif isinstance(ctx.request, ClassificationCompletionRequest):
|
||||
input_data = ctx.request.input
|
||||
if input_data in (None, ""):
|
||||
return self.create_error_response(
|
||||
"Input or messages must be provided",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
if isinstance(input_data, list) and not input_data:
|
||||
ctx.engine_prompts = []
|
||||
return None
|
||||
|
||||
renderer = self._get_completion_renderer()
|
||||
prompt_input = cast(str | list[str], input_data)
|
||||
ctx.engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=prompt_input,
|
||||
config=self._build_render_config(ctx.request),
|
||||
ctx.engine_prompts = await self._preprocess_completion(
|
||||
ctx.request,
|
||||
prompt_input=ctx.request.input,
|
||||
prompt_embeds=None,
|
||||
)
|
||||
else:
|
||||
return self.create_error_response("Invalid classification request type")
|
||||
@@ -128,7 +110,7 @@ class ServingClassification(OpenAIServing):
|
||||
items: list[ClassificationData] = []
|
||||
num_prompt_tokens = 0
|
||||
|
||||
final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)
|
||||
final_res_batch_checked = ctx.final_res_batch
|
||||
|
||||
for idx, final_res in enumerate(final_res_batch_checked):
|
||||
classify_res = ClassificationOutput.from_base(final_res.outputs)
|
||||
@@ -161,13 +143,6 @@ class ServingClassification(OpenAIServing):
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _build_render_config(self, request: ClassificationRequest) -> RenderConfig:
|
||||
return RenderConfig(
|
||||
max_length=self.max_model_len,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
|
||||
async def create_classify(
|
||||
self,
|
||||
request: ClassificationRequest,
|
||||
|
||||
@@ -3,10 +3,9 @@
|
||||
import time
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from pydantic import (
|
||||
Field,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.entrypoints.pooling.base.protocol import (
|
||||
ChatRequestMixin,
|
||||
@@ -14,15 +13,47 @@ from vllm.entrypoints.pooling.base.protocol import (
|
||||
EmbedRequestMixin,
|
||||
PoolingBasicRequestMixin,
|
||||
)
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
def _get_max_total_output_tokens(
|
||||
model_config: ModelConfig,
|
||||
) -> tuple[int | None, int]:
|
||||
max_total_tokens = model_config.max_model_len
|
||||
pooler_config = model_config.pooler_config
|
||||
|
||||
if pooler_config is None:
|
||||
return max_total_tokens, 0
|
||||
|
||||
if pooler_config.enable_chunked_processing:
|
||||
return None, 0
|
||||
|
||||
max_embed_len = pooler_config.max_embed_len or max_total_tokens
|
||||
max_output_tokens = max_total_tokens - max_embed_len
|
||||
return max_total_tokens, max_output_tokens
|
||||
|
||||
|
||||
class EmbeddingCompletionRequest(
|
||||
PoolingBasicRequestMixin, CompletionRequestMixin, EmbedRequestMixin
|
||||
):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/embeddings
|
||||
pass
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
(
|
||||
max_total_tokens,
|
||||
max_output_tokens,
|
||||
) = _get_max_total_output_tokens(model_config)
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=max_total_tokens,
|
||||
max_output_tokens=max_output_tokens,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
max_output_tokens_param="max_model_len - max_embed_len",
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingChatRequest(
|
||||
@@ -33,6 +64,24 @@ class EmbeddingChatRequest(
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
(
|
||||
max_total_tokens,
|
||||
max_output_tokens,
|
||||
) = _get_max_total_output_tokens(model_config)
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=max_total_tokens,
|
||||
max_output_tokens=max_output_tokens,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
max_output_tokens_param="max_model_len - max_embed_len",
|
||||
)
|
||||
|
||||
|
||||
EmbeddingRequest: TypeAlias = EmbeddingCompletionRequest | EmbeddingChatRequest
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from typing import Any, Final, cast
|
||||
from typing import Any, Final, TypeAlias
|
||||
|
||||
import torch
|
||||
from fastapi import Request
|
||||
@@ -22,8 +22,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
)
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
@@ -37,7 +36,7 @@ from vllm.utils.serial_utils import (
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
EmbeddingServeContext = ServeContext[EmbeddingRequest]
|
||||
EmbeddingServeContext: TypeAlias = ServeContext[EmbeddingRequest]
|
||||
|
||||
|
||||
class OpenAIServingEmbedding(OpenAIServing):
|
||||
@@ -95,19 +94,16 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
|
||||
_, ctx.engine_prompts = await self._preprocess_chat(
|
||||
ctx.request,
|
||||
self.renderer,
|
||||
ctx.request.messages,
|
||||
chat_template=ctx.request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
add_generation_prompt=ctx.request.add_generation_prompt,
|
||||
continue_final_message=ctx.request.continue_final_message,
|
||||
add_special_tokens=ctx.request.add_special_tokens,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
)
|
||||
elif isinstance(ctx.request, EmbeddingCompletionRequest):
|
||||
renderer = self._get_completion_renderer()
|
||||
ctx.engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=ctx.request.input,
|
||||
config=self._build_render_config(ctx.request),
|
||||
ctx.engine_prompts = await self._preprocess_completion(
|
||||
ctx.request,
|
||||
prompt_input=ctx.request.input,
|
||||
prompt_embeds=None,
|
||||
)
|
||||
else:
|
||||
return self.create_error_response("Invalid classification request type")
|
||||
@@ -117,19 +113,6 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig:
|
||||
# Set max_length based on chunked processing capability
|
||||
if self._should_use_chunked_processing(request):
|
||||
max_length = None
|
||||
else:
|
||||
max_length = self.max_embed_len or self.max_model_len
|
||||
|
||||
return RenderConfig(
|
||||
max_length=max_length,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
|
||||
def _build_response(
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
@@ -246,14 +229,18 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
lora_request=ctx.lora_request,
|
||||
)
|
||||
|
||||
tok_params = ctx.request.build_tok_params(self.model_config)
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
# Create generator for this chunk and wrap it to return indices
|
||||
original_generator = self.engine_client.encode(
|
||||
chunk_engine_prompt,
|
||||
pooling_params,
|
||||
chunk_request_id,
|
||||
lora_request=ctx.lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
priority=ctx.request.priority,
|
||||
)
|
||||
|
||||
generators.append(original_generator)
|
||||
@@ -338,7 +325,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
async def _create_single_prompt_generator(
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
engine_prompt: TokensPrompt,
|
||||
engine_prompt: TokensPrompt | EmbedsPrompt,
|
||||
pooling_params: PoolingParams,
|
||||
trace_headers: Mapping[str, str] | None,
|
||||
prompt_index: int,
|
||||
@@ -353,23 +340,25 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
lora_request=ctx.lora_request,
|
||||
)
|
||||
|
||||
tok_params = ctx.request.build_tok_params(self.model_config)
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
# Return the original generator without wrapping
|
||||
return self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=ctx.lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=getattr(ctx.request, "priority", 0),
|
||||
priority=ctx.request.priority,
|
||||
)
|
||||
|
||||
async def _prepare_generators(
|
||||
self,
|
||||
ctx: ServeContext,
|
||||
ctx: EmbeddingServeContext,
|
||||
) -> ErrorResponse | None:
|
||||
"""Override to support chunked processing."""
|
||||
ctx = cast(EmbeddingServeContext, ctx)
|
||||
|
||||
# Check if we should use chunked processing
|
||||
use_chunked = self._should_use_chunked_processing(ctx.request)
|
||||
|
||||
@@ -405,7 +394,8 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
for i, engine_prompt in enumerate(ctx.engine_prompts):
|
||||
# Check if this specific prompt needs chunked processing
|
||||
if "prompt_token_ids" in engine_prompt:
|
||||
prompt_token_ids = engine_prompt["prompt_token_ids"]
|
||||
prompt_token_ids = engine_prompt["prompt_token_ids"] # type: ignore[typeddict-item]
|
||||
|
||||
if len(prompt_token_ids) > max_pos_embeddings:
|
||||
# Use chunked processing for this prompt
|
||||
chunk_generators = await self._process_chunked_request(
|
||||
@@ -573,7 +563,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
"token IDs"
|
||||
)
|
||||
|
||||
original_token_ids = original_prompt["prompt_token_ids"]
|
||||
original_token_ids = original_prompt["prompt_token_ids"] # type: ignore[typeddict-item]
|
||||
|
||||
pooling_request_output = PoolingRequestOutput(
|
||||
request_id=aggregator["request_id"],
|
||||
|
||||
@@ -3,11 +3,10 @@
|
||||
import time
|
||||
from typing import Any, Generic, TypeAlias, TypeVar
|
||||
|
||||
from pydantic import (
|
||||
Field,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.pooler import get_use_activation
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.entrypoints.pooling.base.protocol import (
|
||||
@@ -18,6 +17,7 @@ from vllm.entrypoints.pooling.base.protocol import (
|
||||
EncodingRequestMixin,
|
||||
PoolingBasicRequestMixin,
|
||||
)
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
@@ -30,6 +30,18 @@ class PoolingCompletionRequest(
|
||||
):
|
||||
task: PoolingTask | None = None
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
)
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
@@ -48,6 +60,18 @@ class PoolingChatRequest(
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
max_total_tokens_param="max_model_len",
|
||||
)
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
|
||||
@@ -5,7 +5,7 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from typing import Final, cast
|
||||
from typing import Any, Final, cast
|
||||
|
||||
import jinja2
|
||||
from fastapi import Request
|
||||
@@ -14,10 +14,7 @@ from typing_extensions import assert_never
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.pooling.pooling.protocol import (
|
||||
@@ -30,8 +27,6 @@ from vllm.entrypoints.pooling.pooling.protocol import (
|
||||
PoolingResponse,
|
||||
PoolingResponseData,
|
||||
)
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.tasks import PoolingTask, SupportedTask
|
||||
@@ -99,11 +94,6 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
"dimensions is currently not supported"
|
||||
)
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
|
||||
truncate_prompt_tokens = _validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens
|
||||
)
|
||||
|
||||
if is_io_processor_request:
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
@@ -134,19 +124,16 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
|
||||
_, engine_prompts = await self._preprocess_chat(
|
||||
request,
|
||||
self.renderer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
)
|
||||
elif isinstance(request, PoolingCompletionRequest):
|
||||
renderer = self._get_completion_renderer()
|
||||
engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=request.input,
|
||||
config=self._build_render_config(request),
|
||||
engine_prompts = await self._preprocess_completion(
|
||||
request,
|
||||
prompt_input=request.input,
|
||||
prompt_embeds=None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported request of type {type(request)}")
|
||||
@@ -207,11 +194,18 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
if is_io_processor_request:
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
else:
|
||||
tok_params = request.build_tok_params(self.model_config) # type: ignore
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
generator = self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
@@ -338,10 +332,3 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
return encode_bytes(bytes_only=encoding_format == "bytes_only")
|
||||
else:
|
||||
assert_never(encoding_format)
|
||||
|
||||
def _build_render_config(self, request: PoolingCompletionRequest) -> RenderConfig:
|
||||
return RenderConfig(
|
||||
max_length=self.max_model_len,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
|
||||
@@ -3,12 +3,10 @@
|
||||
import time
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.pooler import get_use_activation
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.entrypoints.pooling.base.protocol import (
|
||||
@@ -19,6 +17,7 @@ from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreContentPartParam,
|
||||
ScoreMultiModalParam,
|
||||
)
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
@@ -30,6 +29,17 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
|
||||
)
|
||||
# --8<-- [end:score-extra-params]
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
max_total_tokens_param="max_model_len",
|
||||
)
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
@@ -85,6 +95,17 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
|
||||
)
|
||||
# --8<-- [end:rerank-extra-params]
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
encoder_config = model_config.encoder_config or {}
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=model_config.max_model_len,
|
||||
max_output_tokens=0,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
do_lower_case=encoder_config.get("do_lower_case", False),
|
||||
max_total_tokens_param="max_model_len",
|
||||
)
|
||||
|
||||
|
||||
class RerankDocument(BaseModel):
|
||||
text: str | None = None
|
||||
|
||||
@@ -34,7 +34,6 @@ from vllm.entrypoints.pooling.score.utils import (
|
||||
compress_token_type_ids,
|
||||
get_score_prompt,
|
||||
)
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -68,31 +67,31 @@ class ServingScores(OpenAIServing):
|
||||
|
||||
async def _embedding_score(
|
||||
self,
|
||||
tokenizer: TokenizerLike,
|
||||
data_1: list[str],
|
||||
data_2: list[str],
|
||||
request: RerankRequest | ScoreRequest,
|
||||
request_id: str,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
lora_request: LoRARequest | None | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
model_config = self.model_config
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
|
||||
encode_async = make_async(
|
||||
tokenizer.encode,
|
||||
executor=self._tokenizer_executor,
|
||||
)
|
||||
|
||||
input_texts = data_1 + data_2
|
||||
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
tokenize_async = make_async(
|
||||
tokenizer.__call__, executor=self._tokenizer_executor
|
||||
)
|
||||
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
|
||||
tokenized_prompts = await asyncio.gather(
|
||||
*(tokenize_async(t, **tokenization_kwargs) for t in input_texts)
|
||||
*(encode_async(t, **tokenization_kwargs) for t in input_texts)
|
||||
)
|
||||
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
for tok_result, input_text in zip(tokenized_prompts, input_texts):
|
||||
text_token_prompt = self._validate_input(
|
||||
request, tok_result["input_ids"], input_text
|
||||
)
|
||||
text_token_prompt = self._validate_input(request, tok_result, input_text)
|
||||
|
||||
engine_prompts.append(
|
||||
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
|
||||
@@ -184,15 +183,16 @@ class ServingScores(OpenAIServing):
|
||||
|
||||
async def _cross_encoding_score(
|
||||
self,
|
||||
tokenizer: TokenizerLike,
|
||||
data_1: list[str] | list[ScoreContentPartParam],
|
||||
data_2: list[str] | list[ScoreContentPartParam],
|
||||
request: RerankRequest | ScoreRequest,
|
||||
request_id: str,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
lora_request: LoRARequest | None | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
model_config = self.model_config
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
|
||||
request_prompts: list[str] = []
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
|
||||
@@ -202,12 +202,13 @@ class ServingScores(OpenAIServing):
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError("MistralTokenizer not supported for cross-encoding")
|
||||
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
tok_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
|
||||
|
||||
input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)]
|
||||
|
||||
preprocess_async = make_async(
|
||||
self._preprocess_score, executor=self._tokenizer_executor
|
||||
self._preprocess_score,
|
||||
executor=self._tokenizer_executor,
|
||||
)
|
||||
|
||||
preprocessed_prompts = await asyncio.gather(
|
||||
@@ -215,7 +216,7 @@ class ServingScores(OpenAIServing):
|
||||
preprocess_async(
|
||||
request=request,
|
||||
tokenizer=tokenizer,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
tokenization_kwargs=tok_kwargs,
|
||||
data_1=t1,
|
||||
data_2=t2,
|
||||
)
|
||||
@@ -286,14 +287,6 @@ class ServingScores(OpenAIServing):
|
||||
raw_request: Request | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
|
||||
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
_validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens, tokenization_kwargs
|
||||
)
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
@@ -322,24 +315,20 @@ class ServingScores(OpenAIServing):
|
||||
|
||||
if self.model_config.is_cross_encoder:
|
||||
return await self._cross_encoding_score(
|
||||
tokenizer=tokenizer,
|
||||
data_1=data_1, # type: ignore[arg-type]
|
||||
data_2=data_2, # type: ignore[arg-type]
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
|
||||
else:
|
||||
return await self._embedding_score(
|
||||
tokenizer=tokenizer,
|
||||
data_1=data_1, # type: ignore[arg-type]
|
||||
data_2=data_2, # type: ignore[arg-type]
|
||||
request=request,
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
|
||||
@@ -1,411 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated
|
||||
|
||||
import pybase64
|
||||
import torch
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RenderConfig:
|
||||
"""Configuration to control how prompts are prepared."""
|
||||
|
||||
max_length: int | None = None
|
||||
"""Maximum allowable total input token length. If provided,
|
||||
token inputs longer than this raise `ValueError`."""
|
||||
|
||||
truncate_prompt_tokens: int | None = None
|
||||
"""Number of tokens to keep. `None` means no truncation.
|
||||
`0` yields an empty list (and skips embeds).
|
||||
`-1` maps to `model_config.max_model_len`."""
|
||||
|
||||
add_special_tokens: bool = True
|
||||
"""Whether to add model-specific special tokens during tokenization."""
|
||||
|
||||
cache_salt: str | None = None
|
||||
"""String to disambiguate prefix cache entries."""
|
||||
|
||||
needs_detokenization: bool | None = False
|
||||
"""If True, detokenize IDs back to text for inclusion in outputs."""
|
||||
|
||||
def verify_truncate_prompt_tokens(self, model_config: ModelConfig) -> int | None:
|
||||
"""Validate and normalize `truncate_prompt_tokens` parameter."""
|
||||
truncate_prompt_tokens = self.truncate_prompt_tokens
|
||||
if truncate_prompt_tokens is None or truncate_prompt_tokens == 0:
|
||||
return truncate_prompt_tokens
|
||||
|
||||
if truncate_prompt_tokens < 0:
|
||||
truncate_prompt_tokens = model_config.max_model_len
|
||||
|
||||
max_length = self.max_length
|
||||
if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator]
|
||||
raise ValueError(
|
||||
f"{truncate_prompt_tokens=} cannot be greater than "
|
||||
f"{max_length=}. Please select a smaller truncation size."
|
||||
)
|
||||
|
||||
return truncate_prompt_tokens
|
||||
|
||||
|
||||
class BaseRenderer(ABC):
|
||||
"""
|
||||
Base class for unified input processing and rendering.
|
||||
|
||||
The Renderer serves as a unified input processor that consolidates
|
||||
tokenization, chat template formatting, and multimodal input handling
|
||||
into a single component.
|
||||
It converts high-level API requests (OpenAI-style JSON) into token IDs and
|
||||
multimodal features ready for engine consumption.
|
||||
|
||||
Key responsibilities:
|
||||
- Convert text prompts to token sequences with proper special tokens
|
||||
- Apply chat templates and format conversations
|
||||
- Handle multimodal inputs (images, audio, etc.) when applicable
|
||||
- Manage prompt truncation and length validation
|
||||
- Provide clean separation between API layer and engine core
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
tokenizer: TokenizerLike | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_config = model_config
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@abstractmethod
|
||||
async def render_prompt(
|
||||
self,
|
||||
*,
|
||||
prompt_or_prompts: str | list[str] | list[int] | list[list[int]],
|
||||
config: RenderConfig,
|
||||
) -> list[TokensPrompt]:
|
||||
"""
|
||||
Convert text or token inputs into engine-ready TokensPrompt objects.
|
||||
|
||||
This method accepts text or token inputs and produces a
|
||||
list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects
|
||||
for the engine.
|
||||
|
||||
Args:
|
||||
prompt_or_prompts: One of:
|
||||
- `str`: Single text prompt.
|
||||
- `list[str]`: Batch of text prompts.
|
||||
- `list[int]`: Single pre-tokenized sequence.
|
||||
- `list[list[int]]`: Batch of pre-tokenized sequences.
|
||||
config: Render configuration controlling how prompts are prepared
|
||||
(e.g., tokenization and length handling).
|
||||
|
||||
Returns:
|
||||
list[TokensPrompt]: Engine-ready token prompts.
|
||||
|
||||
Raises:
|
||||
ValueError: If input formats are invalid or length limits exceeded.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def render_prompt_and_embeds(
|
||||
self,
|
||||
*,
|
||||
prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None,
|
||||
prompt_embeds: bytes | list[bytes] | None = None,
|
||||
config: RenderConfig,
|
||||
) -> list[TokensPrompt | EmbedsPrompt]:
|
||||
"""
|
||||
Convert text/token and/or base64-encoded embeddings inputs into
|
||||
engine-ready prompt objects using a unified RenderConfig.
|
||||
|
||||
At least one of `prompt_or_prompts` or `prompt_embeds` must be
|
||||
provided and non-empty. If both are omitted or empty (e.g., empty
|
||||
string and empty list), a `ValueError` is raised.
|
||||
|
||||
Args:
|
||||
prompt_or_prompts: Text or token inputs to include.
|
||||
prompt_embeds: Base64-encoded bytes (or list thereof) containing a
|
||||
torch-saved tensor to be used as prompt embeddings.
|
||||
config: Render configuration controlling how prompts are prepared
|
||||
(e.g., tokenization and length handling).
|
||||
|
||||
Returns:
|
||||
list[Union[TokensPrompt, EmbedsPrompt]]:
|
||||
Engine-ready prompt objects.
|
||||
|
||||
Raises:
|
||||
ValueError: If both `prompt_or_prompts` and `prompt_embeds`
|
||||
are omitted or empty (decoder prompt cannot be empty), or if
|
||||
length limits are exceeded.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def load_prompt_embeds(
|
||||
self,
|
||||
prompt_embeds: bytes | list[bytes],
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None,
|
||||
cache_salt: str | None = None,
|
||||
) -> list[EmbedsPrompt]:
|
||||
"""Load and validate base64-encoded embeddings into prompt objects."""
|
||||
if not self.model_config.enable_prompt_embeds:
|
||||
raise VLLMValidationError(
|
||||
"You must set `--enable-prompt-embeds` to input `prompt_embeds`.",
|
||||
parameter="prompt_embeds",
|
||||
)
|
||||
|
||||
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(
|
||||
io.BytesIO(pybase64.b64decode(embed, validate=True)),
|
||||
weights_only=True,
|
||||
map_location=torch.device("cpu"),
|
||||
)
|
||||
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
)
|
||||
tensor = tensor.to_dense()
|
||||
if tensor.dim() > 2:
|
||||
tensor = tensor.squeeze(0)
|
||||
assert tensor.dim() == 2
|
||||
if truncate_prompt_tokens is not None:
|
||||
tensor = tensor[-truncate_prompt_tokens:]
|
||||
embeds_prompt = EmbedsPrompt(prompt_embeds=tensor)
|
||||
if cache_salt is not None:
|
||||
embeds_prompt["cache_salt"] = cache_salt
|
||||
return embeds_prompt
|
||||
|
||||
if isinstance(prompt_embeds, list):
|
||||
return [_load_and_validate_embed(embed) for embed in prompt_embeds]
|
||||
|
||||
return [_load_and_validate_embed(prompt_embeds)]
|
||||
|
||||
|
||||
class CompletionRenderer(BaseRenderer):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
tokenizer: TokenizerLike | None = None,
|
||||
async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer]
|
||||
| None = None,
|
||||
):
|
||||
super().__init__(model_config, tokenizer)
|
||||
self.async_tokenizer_pool = async_tokenizer_pool
|
||||
self.async_tokenizer: AsyncMicrobatchTokenizer | None = None
|
||||
|
||||
async def render_prompt(
|
||||
self,
|
||||
*,
|
||||
prompt_or_prompts: str | list[str] | list[int] | list[list[int]],
|
||||
config: RenderConfig,
|
||||
) -> list[TokensPrompt]:
|
||||
"""Implementation of prompt rendering for completion-style requests.
|
||||
|
||||
Uses async tokenizer pooling for improved performance. See base class
|
||||
for detailed parameter documentation.
|
||||
"""
|
||||
truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config)
|
||||
if truncate_prompt_tokens == 0:
|
||||
return []
|
||||
|
||||
tasks = (
|
||||
self._create_prompt(
|
||||
prompt_input,
|
||||
config=config,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
)
|
||||
for prompt_input in parse_raw_prompts(prompt_or_prompts)
|
||||
)
|
||||
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
async def render_prompt_and_embeds(
|
||||
self,
|
||||
*,
|
||||
prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None,
|
||||
prompt_embeds: bytes | list[bytes] | None = None,
|
||||
config: RenderConfig,
|
||||
) -> list[TokensPrompt | EmbedsPrompt]:
|
||||
"""
|
||||
Render text/token prompts and/or precomputed embedding prompts. At
|
||||
least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
|
||||
"""
|
||||
truncate_prompt_tokens = config.verify_truncate_prompt_tokens(self.model_config)
|
||||
if truncate_prompt_tokens == 0:
|
||||
return []
|
||||
|
||||
rendered: list[TokensPrompt | EmbedsPrompt] = []
|
||||
|
||||
if prompt_embeds is not None:
|
||||
rendered.extend(
|
||||
self.load_prompt_embeds(
|
||||
prompt_embeds, truncate_prompt_tokens, config.cache_salt
|
||||
)
|
||||
)
|
||||
if prompt_or_prompts is None or prompt_or_prompts == "":
|
||||
return rendered
|
||||
|
||||
token_prompts = await self.render_prompt(
|
||||
prompt_or_prompts=prompt_or_prompts,
|
||||
config=config,
|
||||
)
|
||||
rendered.extend(token_prompts)
|
||||
|
||||
return rendered
|
||||
|
||||
def _maybe_apply_truncation(
|
||||
self, token_ids: list[int], truncate_prompt_tokens: int | None
|
||||
) -> list[int]:
|
||||
"""Apply truncation to token sequence."""
|
||||
if truncate_prompt_tokens is None:
|
||||
return token_ids
|
||||
if truncate_prompt_tokens >= len(token_ids):
|
||||
return token_ids
|
||||
|
||||
return token_ids[-truncate_prompt_tokens:]
|
||||
|
||||
async def _create_prompt(
|
||||
self,
|
||||
prompt_input: TextPrompt | TokensPrompt,
|
||||
config: RenderConfig,
|
||||
truncate_prompt_tokens: int | None,
|
||||
) -> TokensPrompt:
|
||||
prompt, prompt_token_ids, _ = get_prompt_components(prompt_input)
|
||||
|
||||
if prompt_token_ids is not None:
|
||||
# NOTE: detokenization is needed when echo is enabled,
|
||||
# where the input token IDs are decoded back to text.
|
||||
return await self._create_prompt_from_token_ids(
|
||||
prompt_token_ids,
|
||||
config.max_length,
|
||||
truncate_prompt_tokens,
|
||||
config.cache_salt,
|
||||
config.needs_detokenization,
|
||||
)
|
||||
|
||||
if prompt is not None:
|
||||
return await self._create_prompt_from_text(
|
||||
prompt,
|
||||
config.max_length,
|
||||
truncate_prompt_tokens,
|
||||
config.add_special_tokens,
|
||||
config.cache_salt,
|
||||
)
|
||||
|
||||
# TODO: Also handle embeds prompt using this method
|
||||
raise NotImplementedError
|
||||
|
||||
async def _create_prompt_from_text(
|
||||
self,
|
||||
text: str,
|
||||
max_length: int | None,
|
||||
truncate_prompt_tokens: int | None,
|
||||
add_special_tokens: bool,
|
||||
cache_salt: str | None,
|
||||
) -> TokensPrompt:
|
||||
"""Tokenize text input asynchronously."""
|
||||
async_tokenizer = self._get_async_tokenizer()
|
||||
|
||||
# Handle encoder-specific preprocessing
|
||||
if (
|
||||
self.model_config.encoder_config is not None
|
||||
and self.model_config.encoder_config.get("do_lower_case", False)
|
||||
):
|
||||
text = text.lower()
|
||||
|
||||
# Tokenize texts
|
||||
if truncate_prompt_tokens is None:
|
||||
encoded = await async_tokenizer(text, add_special_tokens=add_special_tokens)
|
||||
else:
|
||||
encoded = await async_tokenizer(
|
||||
text,
|
||||
add_special_tokens=add_special_tokens,
|
||||
truncation=True,
|
||||
max_length=truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
return self._create_tokens_prompt(
|
||||
encoded.input_ids, max_length, cache_salt, text
|
||||
)
|
||||
|
||||
async def _create_prompt_from_token_ids(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
max_length: int | None,
|
||||
truncate_prompt_tokens: int | None,
|
||||
cache_salt: str | None,
|
||||
needs_detokenization: bool | None = False,
|
||||
) -> TokensPrompt:
|
||||
"""Optionally detokenize token IDs and build a tokens prompt."""
|
||||
token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens)
|
||||
|
||||
prompt = None
|
||||
if needs_detokenization:
|
||||
async_tokenizer = self._get_async_tokenizer()
|
||||
prompt = await async_tokenizer.decode(token_ids)
|
||||
|
||||
return self._create_tokens_prompt(
|
||||
token_ids=token_ids,
|
||||
max_length=max_length,
|
||||
cache_salt=cache_salt,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
|
||||
"""Get or create async tokenizer using shared pool."""
|
||||
async_tokenizer = self.async_tokenizer
|
||||
if async_tokenizer is not None:
|
||||
return async_tokenizer
|
||||
|
||||
tokenizer = self.tokenizer
|
||||
if tokenizer is None:
|
||||
raise ValueError("No tokenizer available for text input processing")
|
||||
|
||||
if self.async_tokenizer_pool is None:
|
||||
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
|
||||
else:
|
||||
async_tokenizer = self.async_tokenizer_pool.get(tokenizer)
|
||||
if async_tokenizer is None:
|
||||
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
|
||||
self.async_tokenizer_pool[tokenizer] = async_tokenizer
|
||||
self.async_tokenizer = async_tokenizer
|
||||
return async_tokenizer
|
||||
|
||||
def _create_tokens_prompt(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
max_length: int | None = None,
|
||||
cache_salt: str | None = None,
|
||||
prompt: str | None = None,
|
||||
) -> TokensPrompt:
|
||||
"""Create validated TokensPrompt."""
|
||||
if max_length is not None and len(token_ids) > max_length:
|
||||
raise VLLMValidationError(
|
||||
f"This model's maximum context length is {max_length} tokens. "
|
||||
f"However, your request has {len(token_ids)} input tokens. "
|
||||
"Please reduce the length of the input messages.",
|
||||
parameter="input_tokens",
|
||||
value=len(token_ids),
|
||||
)
|
||||
|
||||
tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)
|
||||
if cache_salt is not None:
|
||||
tokens_prompt["cache_salt"] = cache_salt
|
||||
if prompt is not None:
|
||||
tokens_prompt["prompt"] = prompt
|
||||
return tokens_prompt
|
||||
@@ -4,12 +4,14 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionLogProbs
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
SamplingParams,
|
||||
StreamOptions,
|
||||
)
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
@@ -62,6 +64,12 @@ class GenerateRequest(BaseModel):
|
||||
description="KVTransfer parameters used for disaggregated serving.",
|
||||
)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
return TokenizeParams(
|
||||
max_total_tokens=None,
|
||||
max_output_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
class GenerateResponseChoice(BaseModel):
|
||||
index: int
|
||||
|
||||
@@ -101,12 +101,13 @@ class ServingTokens(OpenAIServing):
|
||||
|
||||
# TODO(NickLucche): Change to EngineCoreRequest once Renderer work is
|
||||
# completed
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=request.token_ids)
|
||||
if request.features is not None:
|
||||
engine_prompt["multi_modal_data"] = None
|
||||
|
||||
if hasattr(request, "cache_salt") and request.cache_salt is not None:
|
||||
engine_prompt["cache_salt"] = request.cache_salt
|
||||
engine_prompts = await self._preprocess_completion(
|
||||
request,
|
||||
prompt_input=request.token_ids,
|
||||
prompt_embeds=None,
|
||||
)
|
||||
assert len(engine_prompts) == 1
|
||||
engine_prompt = engine_prompts[0]
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
result_generator: AsyncGenerator[RequestOutput, None] | None = None
|
||||
@@ -128,11 +129,15 @@ class ServingTokens(OpenAIServing):
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
result_generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
@@ -6,8 +6,10 @@ from typing import Any, TypeAlias
|
||||
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormatOption,
|
||||
)
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionToolsParam,
|
||||
@@ -15,6 +17,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
OpenAIBaseModel,
|
||||
)
|
||||
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
|
||||
|
||||
|
||||
class TokenizeCompletionRequest(OpenAIBaseModel):
|
||||
@@ -35,6 +38,13 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
|
||||
),
|
||||
)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
return TokenizeParams(
|
||||
max_total_tokens=None,
|
||||
max_output_tokens=0,
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
)
|
||||
|
||||
|
||||
class TokenizeChatRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
@@ -109,6 +119,30 @@ class TokenizeChatRequest(OpenAIBaseModel):
|
||||
)
|
||||
return data
|
||||
|
||||
def build_chat_params(
|
||||
self,
|
||||
default_template: str | None,
|
||||
default_template_content_format: ChatTemplateContentFormatOption,
|
||||
) -> ChatParams:
|
||||
return ChatParams(
|
||||
chat_template=self.chat_template or default_template,
|
||||
chat_template_content_format=default_template_content_format,
|
||||
chat_template_kwargs=merge_kwargs(
|
||||
self.chat_template_kwargs,
|
||||
dict(
|
||||
add_generation_prompt=self.add_generation_prompt,
|
||||
continue_final_message=self.continue_final_message,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
return TokenizeParams(
|
||||
max_total_tokens=None,
|
||||
max_output_tokens=0,
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
)
|
||||
|
||||
|
||||
TokenizeRequest: TypeAlias = TokenizeCompletionRequest | TokenizeChatRequest
|
||||
|
||||
@@ -124,6 +158,13 @@ class DetokenizeRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
tokens: list[int]
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
return TokenizeParams(
|
||||
max_total_tokens=None,
|
||||
max_output_tokens=0,
|
||||
needs_detokenization=True,
|
||||
)
|
||||
|
||||
|
||||
class DetokenizeResponse(OpenAIBaseModel):
|
||||
prompt: str
|
||||
|
||||
@@ -9,12 +9,9 @@ from fastapi import Request
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.engine.serving import OpenAIServing
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.entrypoints.serve.tokenize.protocol import (
|
||||
DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
@@ -83,21 +80,17 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
_, engine_prompts = await self._preprocess_chat(
|
||||
request,
|
||||
self.renderer,
|
||||
request.messages,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
tool_dicts=tool_dicts,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
chat_template_kwargs=request.chat_template_kwargs,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
renderer = self._get_completion_renderer()
|
||||
engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=request.prompt,
|
||||
config=self._build_render_config(request),
|
||||
engine_prompts = await self._preprocess_completion(
|
||||
request,
|
||||
prompt_input=request.prompt,
|
||||
prompt_embeds=None,
|
||||
)
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
@@ -106,11 +99,14 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
input_ids: list[int] = []
|
||||
for engine_prompt in engine_prompts:
|
||||
self._log_inputs(
|
||||
request_id, engine_prompt, params=None, lora_request=lora_request
|
||||
request_id,
|
||||
engine_prompt,
|
||||
params=None,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
if isinstance(engine_prompt, dict) and "prompt_token_ids" in engine_prompt:
|
||||
input_ids.extend(engine_prompt["prompt_token_ids"])
|
||||
if "prompt_token_ids" in engine_prompt:
|
||||
input_ids.extend(engine_prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
|
||||
|
||||
token_strs = None
|
||||
if request.return_token_strs:
|
||||
@@ -136,7 +132,6 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
request_id = f"tokenize-{self._base_request_id(raw_request)}"
|
||||
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
@@ -145,14 +140,13 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
prompt_input = await self._tokenize_prompt_input_async(
|
||||
request,
|
||||
tokenizer,
|
||||
request.tokens,
|
||||
engine_prompt = await self.renderer.tokenize_prompt_async(
|
||||
TokensPrompt(prompt_token_ids=request.tokens),
|
||||
request.build_tok_params(self.model_config),
|
||||
)
|
||||
input_text = prompt_input["prompt"]
|
||||
prompt_text = engine_prompt["prompt"] # type: ignore[typeddict-item]
|
||||
|
||||
return DetokenizeResponse(prompt=input_text)
|
||||
return DetokenizeResponse(prompt=prompt_text)
|
||||
|
||||
async def get_tokenizer_info(
|
||||
self,
|
||||
@@ -165,9 +159,6 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
except Exception as e:
|
||||
return self.create_error_response(f"Failed to get tokenizer info: {str(e)}")
|
||||
|
||||
def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
|
||||
return RenderConfig(add_special_tokens=request.add_special_tokens)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenizerInfo:
|
||||
|
||||
@@ -8,7 +8,7 @@ import os
|
||||
from argparse import Namespace
|
||||
from logging import Logger
|
||||
from string import Template
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import regex as re
|
||||
from fastapi import Request
|
||||
@@ -18,9 +18,9 @@ from starlette.background import BackgroundTask, BackgroundTasks
|
||||
from vllm import envs
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import EmbedsPrompt, TokensPrompt
|
||||
from vllm.inputs.parse import get_prompt_len
|
||||
from vllm.logger import current_formatter_type, init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -34,9 +34,7 @@ if TYPE_CHECKING:
|
||||
StreamOptions,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
|
||||
from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
else:
|
||||
ChatCompletionRequest = object
|
||||
CompletionRequest = object
|
||||
@@ -188,33 +186,6 @@ def cli_env_setup():
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
|
||||
def _validate_truncation_size(
|
||||
max_model_len: int,
|
||||
truncate_prompt_tokens: int | None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> int | None:
|
||||
if truncate_prompt_tokens is not None:
|
||||
if truncate_prompt_tokens <= -1:
|
||||
truncate_prompt_tokens = max_model_len
|
||||
|
||||
if truncate_prompt_tokens > max_model_len:
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
|
||||
f"is greater than max_model_len ({max_model_len})."
|
||||
f" Please, select a smaller truncation size."
|
||||
)
|
||||
|
||||
if tokenization_kwargs is not None:
|
||||
tokenization_kwargs["truncation"] = True
|
||||
tokenization_kwargs["max_length"] = truncate_prompt_tokens
|
||||
|
||||
else:
|
||||
if tokenization_kwargs is not None:
|
||||
tokenization_kwargs["truncation"] = False
|
||||
|
||||
return truncate_prompt_tokens
|
||||
|
||||
|
||||
def get_max_tokens(
|
||||
max_model_len: int,
|
||||
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
|
||||
@@ -233,10 +204,7 @@ def get_max_tokens(
|
||||
# CompletionRequest (also a fallback for ChatCompletionRequest)
|
||||
max_tokens = getattr(request, "max_tokens", None)
|
||||
|
||||
input_length = length_from_prompt_token_ids_or_embeds(
|
||||
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
|
||||
prompt.get("prompt_embeds"), # type: ignore[arg-type]
|
||||
)
|
||||
input_length = get_prompt_len(prompt)
|
||||
default_max_tokens = max_model_len - input_length
|
||||
max_output_tokens = current_platform.get_max_output_tokens(input_length)
|
||||
|
||||
|
||||
@@ -21,12 +21,7 @@ else:
|
||||
MultiModalUUIDDict = object
|
||||
|
||||
|
||||
class TextPrompt(TypedDict):
|
||||
"""Schema for a text prompt."""
|
||||
|
||||
prompt: str
|
||||
"""The input text to be tokenized before passing to the model."""
|
||||
|
||||
class _CommonKeys(TypedDict):
|
||||
multi_modal_data: NotRequired[MultiModalDataDict | None]
|
||||
"""
|
||||
Optional multi-modal data to pass to the model,
|
||||
@@ -56,7 +51,14 @@ class TextPrompt(TypedDict):
|
||||
"""
|
||||
|
||||
|
||||
class TokensPrompt(TypedDict):
|
||||
class TextPrompt(_CommonKeys):
|
||||
"""Schema for a text prompt."""
|
||||
|
||||
prompt: str
|
||||
"""The input text to be tokenized before passing to the model."""
|
||||
|
||||
|
||||
class TokensPrompt(_CommonKeys):
|
||||
"""Schema for a tokenized prompt."""
|
||||
|
||||
prompt_token_ids: list[int]
|
||||
@@ -68,47 +70,15 @@ class TokensPrompt(TypedDict):
|
||||
token_type_ids: NotRequired[list[int]]
|
||||
"""A list of token type IDs to pass to the cross encoder model."""
|
||||
|
||||
multi_modal_data: NotRequired[MultiModalDataDict | None]
|
||||
"""
|
||||
Optional multi-modal data to pass to the model,
|
||||
if the model supports it.
|
||||
"""
|
||||
|
||||
mm_processor_kwargs: NotRequired[dict[str, Any] | None]
|
||||
"""
|
||||
Optional multi-modal processor kwargs to be forwarded to the
|
||||
multimodal input mapper & processor. Note that if multiple modalities
|
||||
have registered mappers etc for the model being considered, we attempt
|
||||
to pass the mm_processor_kwargs to each of them.
|
||||
"""
|
||||
|
||||
multi_modal_uuids: NotRequired[MultiModalUUIDDict]
|
||||
"""
|
||||
Optional user-specified UUIDs for multimodal items, mapped by modality.
|
||||
Lists must match the number of items per modality and may contain `None`.
|
||||
For `None` entries, the hasher will compute IDs automatically; non-None
|
||||
entries override the default hashes for caching.
|
||||
"""
|
||||
|
||||
cache_salt: NotRequired[str]
|
||||
"""
|
||||
Optional cache salt to be used for prefix caching.
|
||||
"""
|
||||
|
||||
|
||||
class EmbedsPrompt(TypedDict):
|
||||
class EmbedsPrompt(_CommonKeys):
|
||||
"""Schema for a prompt provided via token embeddings."""
|
||||
|
||||
prompt_embeds: torch.Tensor
|
||||
"""The embeddings of the prompt."""
|
||||
|
||||
cache_salt: NotRequired[str]
|
||||
"""
|
||||
Optional cache salt to be used for prefix caching.
|
||||
"""
|
||||
|
||||
|
||||
class DataPrompt(TypedDict):
|
||||
class DataPrompt(_CommonKeys):
|
||||
"""Represents generic inputs handled by IO processor plugins."""
|
||||
|
||||
data: Any
|
||||
@@ -197,7 +167,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
|
||||
mm_processor_kwargs: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt
|
||||
PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt[Any, Any]
|
||||
"""
|
||||
Set of possible schemas for an LLM input, including
|
||||
both decoder-only and encoder/decoder input types:
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict, cast
|
||||
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
|
||||
from .data import (
|
||||
EmbedsPrompt,
|
||||
@@ -22,50 +21,6 @@ if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
|
||||
def parse_raw_prompts(
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
) -> Sequence[TextPrompt] | Sequence[TokensPrompt]:
|
||||
if isinstance(prompt, str):
|
||||
# case 1: a string
|
||||
return [TextPrompt(prompt=prompt)]
|
||||
|
||||
if isinstance(prompt, list):
|
||||
if len(prompt) == 0:
|
||||
raise ValueError("please provide at least one prompt")
|
||||
|
||||
# case 2: array of strings
|
||||
if is_list_of(prompt, str):
|
||||
prompt = cast(list[str], prompt)
|
||||
return [TextPrompt(prompt=elem) for elem in prompt]
|
||||
|
||||
# case 3: array of tokens
|
||||
if is_list_of(prompt, int):
|
||||
prompt = cast(list[int], prompt)
|
||||
return [TokensPrompt(prompt_token_ids=prompt)]
|
||||
|
||||
# case 4: array of token arrays
|
||||
if is_list_of(prompt, list):
|
||||
if len(prompt) == 1 and isinstance(prompt[0], list) and len(prompt[0]) == 0:
|
||||
raise ValueError("please provide at least one prompt")
|
||||
for elem in prompt:
|
||||
if not isinstance(elem, list):
|
||||
raise TypeError(
|
||||
"prompt must be a list of lists, but found a non-list element."
|
||||
)
|
||||
if not is_list_of(elem, int):
|
||||
raise TypeError(
|
||||
"Nested lists of tokens must contain only integers."
|
||||
)
|
||||
|
||||
prompt = cast(list[list[int]], prompt)
|
||||
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
|
||||
|
||||
raise TypeError(
|
||||
"prompt must be a string, array of strings, "
|
||||
"array of tokens, or array of token arrays"
|
||||
)
|
||||
|
||||
|
||||
class ParsedStrPrompt(TypedDict):
|
||||
type: Literal["str"]
|
||||
content: str
|
||||
@@ -145,3 +100,10 @@ def get_prompt_components(prompt: PromptType) -> PromptComponents:
|
||||
token_ids=prompt.get("prompt_token_ids"), # type: ignore[arg-type]
|
||||
embeds=prompt.get("prompt_embeds"),
|
||||
)
|
||||
|
||||
|
||||
def get_prompt_len(prompt: TokensPrompt | EmbedsPrompt):
|
||||
return length_from_prompt_token_ids_or_embeds(
|
||||
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
|
||||
prompt.get("prompt_embeds"), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
@@ -209,6 +209,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo
|
||||
item_processor_data = dict(**mm_data, audios=audios)
|
||||
|
||||
# some tokenizer kwargs are incompatible with UltravoxProcessor
|
||||
tok_kwargs.pop("add_special_tokens", None)
|
||||
tok_kwargs.pop("padding", None)
|
||||
tok_kwargs.pop("truncation", None)
|
||||
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .params import ChatParams, TokenizeParams, merge_kwargs
|
||||
from .protocol import RendererLike
|
||||
from .registry import RendererRegistry, renderer_from_config
|
||||
|
||||
__all__ = ["RendererLike", "RendererRegistry", "renderer_from_config"]
|
||||
__all__ = [
|
||||
"RendererLike",
|
||||
"RendererRegistry",
|
||||
"renderer_from_config",
|
||||
"ChatParams",
|
||||
"TokenizeParams",
|
||||
"merge_kwargs",
|
||||
]
|
||||
|
||||
@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import (
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
|
||||
|
||||
from .params import ChatParams
|
||||
from .protocol import RendererLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -61,8 +62,8 @@ class DeepseekV32Renderer(RendererLike):
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
**kwargs,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
messages,
|
||||
@@ -74,26 +75,22 @@ class DeepseekV32Renderer(RendererLike):
|
||||
prompt_raw = tokenizer.apply_chat_template(
|
||||
conversation=conversation,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
prompt = (
|
||||
TextPrompt(prompt=prompt_raw)
|
||||
if isinstance(prompt_raw, str)
|
||||
else TokensPrompt(prompt_token_ids=prompt_raw)
|
||||
)
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return conversation, prompt # type: ignore[return-value]
|
||||
return conversation, prompt
|
||||
|
||||
async def render_messages_async(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
**kwargs,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
|
||||
messages,
|
||||
@@ -105,17 +102,13 @@ class DeepseekV32Renderer(RendererLike):
|
||||
prompt_raw = tokenizer.apply_chat_template(
|
||||
conversation=conversation,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
prompt = (
|
||||
TextPrompt(prompt=prompt_raw)
|
||||
if isinstance(prompt_raw, str)
|
||||
else TokensPrompt(prompt_token_ids=prompt_raw)
|
||||
)
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return conversation, prompt # type: ignore[return-value]
|
||||
return conversation, prompt
|
||||
|
||||
44
vllm/renderers/embed_utils.py
Normal file
44
vllm/renderers/embed_utils.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pybase64
|
||||
import torch
|
||||
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
|
||||
def safe_load_prompt_embeds(
|
||||
model_config: "ModelConfig",
|
||||
embed: bytes,
|
||||
) -> torch.Tensor:
|
||||
if not model_config.enable_prompt_embeds:
|
||||
raise VLLMValidationError(
|
||||
"You must set `--enable-prompt-embeds` to input `prompt_embeds`.",
|
||||
parameter="prompt_embeds",
|
||||
)
|
||||
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(
|
||||
BytesIO(pybase64.b64decode(embed, validate=True)),
|
||||
weights_only=True,
|
||||
map_location=torch.device("cpu"),
|
||||
)
|
||||
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
)
|
||||
tensor = tensor.to_dense()
|
||||
|
||||
if tensor.dim() > 2:
|
||||
tensor = tensor.squeeze(0)
|
||||
assert tensor.dim() == 2
|
||||
|
||||
return tensor
|
||||
@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import (
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.grok2 import Grok2Tokenizer
|
||||
|
||||
from .params import ChatParams
|
||||
from .protocol import RendererLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -61,8 +62,8 @@ class Grok2Renderer(RendererLike):
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
**kwargs,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
messages,
|
||||
@@ -74,26 +75,22 @@ class Grok2Renderer(RendererLike):
|
||||
prompt_raw = tokenizer.apply_chat_template(
|
||||
conversation=conversation,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
prompt = (
|
||||
TextPrompt(prompt=prompt_raw)
|
||||
if isinstance(prompt_raw, str)
|
||||
else TokensPrompt(prompt_token_ids=prompt_raw)
|
||||
)
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return conversation, prompt # type: ignore[return-value]
|
||||
return conversation, prompt
|
||||
|
||||
async def render_messages_async(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
**kwargs,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
|
||||
messages,
|
||||
@@ -105,17 +102,13 @@ class Grok2Renderer(RendererLike):
|
||||
prompt_raw = tokenizer.apply_chat_template(
|
||||
conversation=conversation,
|
||||
messages=messages,
|
||||
**kwargs,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
prompt = (
|
||||
TextPrompt(prompt=prompt_raw)
|
||||
if isinstance(prompt_raw, str)
|
||||
else TokensPrompt(prompt_token_ids=prompt_raw)
|
||||
)
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return conversation, prompt # type: ignore[return-value]
|
||||
return conversation, prompt
|
||||
|
||||
@@ -25,7 +25,7 @@ from vllm.entrypoints.chat_utils import (
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
|
||||
@@ -33,6 +33,7 @@ from vllm.transformers_utils.chat_templates import get_chat_template_fallback_pa
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
|
||||
from .params import ChatParams
|
||||
from .protocol import RendererLike
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -632,9 +633,8 @@ class HfRenderer(RendererLike):
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
**kwargs,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
model_config = self.config
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
@@ -642,9 +642,9 @@ class HfRenderer(RendererLike):
|
||||
messages,
|
||||
model_config,
|
||||
content_format=resolve_chat_template_content_format(
|
||||
chat_template=kwargs.get("chat_template"),
|
||||
tools=kwargs.get("tools"),
|
||||
given_format=chat_template_content_format,
|
||||
chat_template=params.chat_template,
|
||||
tools=params.chat_template_kwargs.get("tools"),
|
||||
given_format=params.chat_template_content_format,
|
||||
tokenizer=tokenizer,
|
||||
model_config=model_config,
|
||||
),
|
||||
@@ -654,7 +654,7 @@ class HfRenderer(RendererLike):
|
||||
model_config,
|
||||
tokenizer,
|
||||
conversation,
|
||||
**kwargs,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
# NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
|
||||
@@ -666,7 +666,7 @@ class HfRenderer(RendererLike):
|
||||
):
|
||||
mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)
|
||||
|
||||
# get video placehoder, replace it with runtime video-chunk prompts
|
||||
# get video placeholder, replace it with runtime video-chunk prompts
|
||||
video_placeholder = getattr(
|
||||
model_config.hf_config, "video_placeholder", None
|
||||
)
|
||||
@@ -676,24 +676,19 @@ class HfRenderer(RendererLike):
|
||||
video_placeholder,
|
||||
)
|
||||
|
||||
prompt = (
|
||||
TextPrompt(prompt=prompt_raw)
|
||||
if isinstance(prompt_raw, str)
|
||||
else TokensPrompt(prompt_token_ids=prompt_raw)
|
||||
)
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return conversation, prompt # type: ignore[return-value]
|
||||
return conversation, prompt
|
||||
|
||||
async def render_messages_async(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
**kwargs,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
model_config = self.config
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
@@ -701,9 +696,9 @@ class HfRenderer(RendererLike):
|
||||
messages,
|
||||
model_config,
|
||||
content_format=resolve_chat_template_content_format(
|
||||
chat_template=kwargs.get("chat_template"),
|
||||
tools=kwargs.get("tools"),
|
||||
given_format=chat_template_content_format,
|
||||
chat_template=params.chat_template,
|
||||
tools=params.chat_template_kwargs.get("tools"),
|
||||
given_format=params.chat_template_content_format,
|
||||
tokenizer=tokenizer,
|
||||
model_config=model_config,
|
||||
),
|
||||
@@ -713,7 +708,7 @@ class HfRenderer(RendererLike):
|
||||
model_config,
|
||||
tokenizer,
|
||||
conversation,
|
||||
**kwargs,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
# NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
|
||||
@@ -723,9 +718,7 @@ class HfRenderer(RendererLike):
|
||||
and mm_uuids is not None
|
||||
and mm_data is not None
|
||||
):
|
||||
mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)
|
||||
|
||||
# get video placehoder, replace it with runtime video-chunk prompts
|
||||
# get video placeholder, replace it with runtime video-chunk prompts
|
||||
video_placeholder = getattr(
|
||||
model_config.hf_config, "video_placeholder", None
|
||||
)
|
||||
@@ -735,14 +728,10 @@ class HfRenderer(RendererLike):
|
||||
video_placeholder,
|
||||
)
|
||||
|
||||
prompt = (
|
||||
TextPrompt(prompt=prompt_raw)
|
||||
if isinstance(prompt_raw, str)
|
||||
else TokensPrompt(prompt_token_ids=prompt_raw)
|
||||
)
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return conversation, prompt # type: ignore[return-value]
|
||||
return conversation, prompt
|
||||
|
||||
@@ -10,12 +10,13 @@ from vllm.entrypoints.chat_utils import (
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils.async_utils import make_async
|
||||
|
||||
from .params import ChatParams
|
||||
from .protocol import RendererLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -95,8 +96,8 @@ class MistralRenderer(RendererLike):
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
**kwargs,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
messages,
|
||||
@@ -104,25 +105,25 @@ class MistralRenderer(RendererLike):
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
prompt_raw = safe_apply_chat_template(tokenizer, messages, **kwargs)
|
||||
|
||||
prompt = (
|
||||
TextPrompt(prompt=prompt_raw)
|
||||
if isinstance(prompt_raw, str)
|
||||
else TokensPrompt(prompt_token_ids=prompt_raw)
|
||||
prompt_raw = safe_apply_chat_template(
|
||||
tokenizer,
|
||||
messages,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return conversation, prompt # type: ignore[return-value]
|
||||
return conversation, prompt
|
||||
|
||||
async def render_messages_async(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
**kwargs,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
tokenizer = self.get_tokenizer()
|
||||
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
|
||||
messages,
|
||||
@@ -131,17 +132,15 @@ class MistralRenderer(RendererLike):
|
||||
)
|
||||
|
||||
prompt_raw = await self._apply_chat_template_async(
|
||||
tokenizer, messages, **kwargs
|
||||
tokenizer,
|
||||
messages,
|
||||
**params.get_apply_chat_template_kwargs(),
|
||||
)
|
||||
|
||||
prompt = (
|
||||
TextPrompt(prompt=prompt_raw)
|
||||
if isinstance(prompt_raw, str)
|
||||
else TokensPrompt(prompt_token_ids=prompt_raw)
|
||||
)
|
||||
prompt = self.render_completion(prompt_raw)
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
return conversation, prompt # type: ignore[return-value]
|
||||
return conversation, prompt
|
||||
|
||||
351
vllm/renderers/params.py
Normal file
351
vllm/renderers/params.py
Normal file
@@ -0,0 +1,351 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
_S = TypeVar("_S", list[int], "torch.Tensor")
|
||||
|
||||
|
||||
def merge_kwargs(
|
||||
defaults: dict[str, Any] | None,
|
||||
overrides: dict[str, Any] | None,
|
||||
/,
|
||||
*,
|
||||
unset_values: tuple[object, ...] = (None, "auto"),
|
||||
) -> dict[str, Any]:
|
||||
if defaults is None:
|
||||
defaults = {}
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
|
||||
return defaults | {k: v for k, v in overrides.items() if v not in unset_values}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ChatParams:
|
||||
"""Configuration to control how to parse chat messages."""
|
||||
|
||||
chat_template: str | None = None
|
||||
"""The chat template to apply."""
|
||||
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
|
||||
"""The format of the chat template."""
|
||||
|
||||
chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
"""The kwargs to pass to the chat template."""
|
||||
|
||||
def with_defaults(self, default_chat_template_kwargs: dict[str, Any] | None):
|
||||
if not default_chat_template_kwargs:
|
||||
return self
|
||||
|
||||
return ChatParams(
|
||||
chat_template=self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
chat_template_kwargs=merge_kwargs(
|
||||
default_chat_template_kwargs,
|
||||
self.chat_template_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
def get_apply_chat_template_kwargs(self) -> dict[str, Any]:
|
||||
"""The arguments to pass to `tokenizer.apply_chat_template`."""
|
||||
return merge_kwargs(
|
||||
self.chat_template_kwargs,
|
||||
dict(chat_template=self.chat_template),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TokenizeParams:
|
||||
"""Configuration to control how prompts are tokenized."""
|
||||
|
||||
max_total_tokens: int | None
|
||||
"""
|
||||
Maximum allowed number of input + output tokens.
|
||||
|
||||
Usually, this refers to the model's context length.
|
||||
"""
|
||||
|
||||
max_output_tokens: int = 0
|
||||
"""Maximum requested number of output tokens."""
|
||||
|
||||
pad_prompt_tokens: int | None = None
|
||||
"""
|
||||
Number of tokens to pad to:
|
||||
- `None` means no padding.
|
||||
- `-1` maps to `max_input_tokens`.
|
||||
"""
|
||||
|
||||
truncate_prompt_tokens: int | None = None
|
||||
"""
|
||||
Number of tokens to keep:
|
||||
- `None` means no truncation.
|
||||
- `-1` maps to `max_input_tokens`.
|
||||
"""
|
||||
|
||||
do_lower_case: bool = False
|
||||
"""Whether to normalize text to lower case before tokenization."""
|
||||
|
||||
add_special_tokens: bool = True
|
||||
"""Whether to add special tokens."""
|
||||
|
||||
needs_detokenization: bool = False
|
||||
"""
|
||||
Whether the tokenized prompt needs to contain the original text.
|
||||
|
||||
Not to be confused with `SamplingParams.detokenize` which deals
|
||||
with the output generated by the model.
|
||||
"""
|
||||
|
||||
max_total_tokens_param: str = "max_total_tokens"
|
||||
"""Override this to edit the message for validation errors."""
|
||||
|
||||
max_output_tokens_param: str = "max_output_tokens"
|
||||
"""Override this to edit the message for validation errors."""
|
||||
|
||||
truncate_prompt_tokens_param: str = "truncate_prompt_tokens"
|
||||
"""Override this to edit the message for validation errors."""
|
||||
|
||||
@property
|
||||
def max_input_tokens(self) -> int | None:
|
||||
"""Maximum allowed number of input tokens."""
|
||||
if self.max_total_tokens is None:
|
||||
return None
|
||||
|
||||
return self.max_total_tokens - self.max_output_tokens
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
max_total_tokens = self.max_total_tokens
|
||||
max_output_tokens = self.max_output_tokens
|
||||
max_input_tokens = self.max_input_tokens
|
||||
truncate_prompt_tokens = self.truncate_prompt_tokens
|
||||
|
||||
if (
|
||||
max_output_tokens is not None
|
||||
and max_total_tokens is not None
|
||||
and max_output_tokens > max_total_tokens
|
||||
):
|
||||
raise VLLMValidationError(
|
||||
f"{self.max_output_tokens_param}={max_output_tokens}"
|
||||
f"cannot be greater than "
|
||||
f"{self.max_total_tokens_param}={max_total_tokens=}. "
|
||||
f"Please request fewer output tokens.",
|
||||
parameter=self.max_output_tokens_param,
|
||||
value=max_output_tokens,
|
||||
)
|
||||
|
||||
if (
|
||||
max_input_tokens is not None
|
||||
and truncate_prompt_tokens is not None
|
||||
and truncate_prompt_tokens > max_input_tokens
|
||||
):
|
||||
raise VLLMValidationError(
|
||||
f"{self.truncate_prompt_tokens_param}={truncate_prompt_tokens} "
|
||||
f"cannot be greater than {self.max_total_tokens_param} - "
|
||||
f"{self.max_output_tokens_param} = {max_input_tokens}. "
|
||||
f"Please request a smaller truncation size.",
|
||||
parameter=self.truncate_prompt_tokens_param,
|
||||
value=truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
def with_kwargs(self, tokenization_kwargs: dict[str, Any] | None):
|
||||
if tokenization_kwargs is None:
|
||||
tokenization_kwargs = {}
|
||||
|
||||
max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens)
|
||||
pad_prompt_tokens = tokenization_kwargs.pop(
|
||||
"pad_prompt_tokens", self.pad_prompt_tokens
|
||||
)
|
||||
truncate_prompt_tokens = tokenization_kwargs.pop(
|
||||
"truncate_prompt_tokens", self.truncate_prompt_tokens
|
||||
)
|
||||
do_lower_case = tokenization_kwargs.pop("do_lower_case", self.do_lower_case)
|
||||
add_special_tokens = tokenization_kwargs.pop(
|
||||
"add_special_tokens", self.add_special_tokens
|
||||
)
|
||||
needs_detokenization = tokenization_kwargs.pop(
|
||||
"needs_detokenization", self.needs_detokenization
|
||||
)
|
||||
|
||||
# https://huggingface.co/docs/transformers/en/pad_truncation
|
||||
if padding := tokenization_kwargs.pop("padding", None):
|
||||
if padding == "max_length":
|
||||
pad_prompt_tokens = max_length
|
||||
elif padding in (False, "do_not_pad"):
|
||||
pad_prompt_tokens = None
|
||||
else:
|
||||
# To emit the below warning
|
||||
tokenization_kwargs["padding"] = padding
|
||||
|
||||
if truncation := tokenization_kwargs.pop("truncation", None):
|
||||
if truncation in (True, "longest_first"):
|
||||
truncate_prompt_tokens = max_length
|
||||
elif truncation in (False, "do_not_truncate"):
|
||||
truncate_prompt_tokens = None
|
||||
else:
|
||||
# To emit the below warning
|
||||
tokenization_kwargs["truncation"] = truncation
|
||||
|
||||
if tokenization_kwargs:
|
||||
logger.warning(
|
||||
"The following tokenization arguments are not supported "
|
||||
"by vLLM Renderer and will be ignored: %s",
|
||||
tokenization_kwargs,
|
||||
)
|
||||
|
||||
max_total_tokens = self.max_total_tokens
|
||||
|
||||
return TokenizeParams(
|
||||
max_total_tokens=max_total_tokens,
|
||||
max_output_tokens=(
|
||||
0
|
||||
if max_total_tokens is None or max_length is None
|
||||
else max_total_tokens - max_length
|
||||
),
|
||||
pad_prompt_tokens=pad_prompt_tokens,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
do_lower_case=do_lower_case,
|
||||
add_special_tokens=add_special_tokens,
|
||||
needs_detokenization=needs_detokenization,
|
||||
)
|
||||
|
||||
def get_encode_kwargs(self) -> dict[str, Any]:
|
||||
"""The arguments to pass to `tokenizer.encode`."""
|
||||
max_length = self.truncate_prompt_tokens
|
||||
if max_length is not None and max_length < 0:
|
||||
max_length = self.max_input_tokens
|
||||
|
||||
return dict(
|
||||
truncation=self.truncate_prompt_tokens is not None,
|
||||
max_length=max_length,
|
||||
add_special_tokens=self.add_special_tokens,
|
||||
)
|
||||
|
||||
def _apply_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
|
||||
if self.do_lower_case:
|
||||
text = text.lower()
|
||||
|
||||
return text
|
||||
|
||||
def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str:
|
||||
"""Apply all validators to prompt text."""
|
||||
# TODO: Implement https://github.com/vllm-project/vllm/pull/31366
|
||||
for validator in (self._apply_lowercase,):
|
||||
text = validator(tokenizer, text)
|
||||
|
||||
return text
|
||||
|
||||
def apply_pre_tokenization(
|
||||
self,
|
||||
tokenizer: TokenizerLike | None,
|
||||
prompt: TextPrompt,
|
||||
) -> TextPrompt:
|
||||
"""
|
||||
Ensure that the prompt meets the requirements set out by this config.
|
||||
If that is not possible, raise a `VLLMValidationError`.
|
||||
|
||||
This method is run before tokenization occurs.
|
||||
"""
|
||||
prompt["prompt"] = self._validate_text(tokenizer, prompt["prompt"])
|
||||
|
||||
return prompt
|
||||
|
||||
def _apply_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
|
||||
"""Apply padding to a token sequence."""
|
||||
pad_length = self.pad_prompt_tokens
|
||||
if pad_length is not None and pad_length < 0:
|
||||
pad_length = self.max_input_tokens
|
||||
|
||||
if pad_length is None or pad_length <= len(tokens):
|
||||
return tokens
|
||||
|
||||
if tokenizer is None:
|
||||
raise ValueError("Cannot pad tokens when `skip_tokenizer_init=True`")
|
||||
if not isinstance(tokens, list):
|
||||
raise ValueError("Cannot pad tokens for embedding inputs")
|
||||
|
||||
return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens))
|
||||
|
||||
def _apply_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
|
||||
"""Apply truncation to a token sequence."""
|
||||
max_length = self.truncate_prompt_tokens
|
||||
if max_length is not None and max_length < 0:
|
||||
max_length = self.max_input_tokens
|
||||
|
||||
if max_length is None or max_length >= len(tokens):
|
||||
return tokens
|
||||
if max_length == 0:
|
||||
return tokens[:0]
|
||||
|
||||
if getattr(tokenizer, "truncation_side", "left") == "left":
|
||||
return tokens[-max_length:]
|
||||
|
||||
return tokens[:max_length]
|
||||
|
||||
def _apply_length_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
|
||||
"""Apply length checks to a token sequence."""
|
||||
max_input_tokens = self.max_input_tokens
|
||||
|
||||
if max_input_tokens is not None and len(tokens) > max_input_tokens:
|
||||
raise VLLMValidationError(
|
||||
f"You passed {len(tokens)} input tokens and "
|
||||
f"requested {self.max_output_tokens} output tokens. "
|
||||
f"However, the model's context length is only "
|
||||
f"{self.max_total_tokens}, resulting in a maximum "
|
||||
f"input length of {max_input_tokens}. "
|
||||
f"Please reduce the length of the input messages.",
|
||||
parameter="input_tokens",
|
||||
value=len(tokens),
|
||||
)
|
||||
|
||||
return tokens
|
||||
|
||||
def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
|
||||
"""Apply all validators to a token sequence."""
|
||||
for validator in (
|
||||
self._apply_padding,
|
||||
self._apply_truncation,
|
||||
self._apply_length_check,
|
||||
):
|
||||
tokens = validator(tokenizer, tokens)
|
||||
|
||||
return tokens
|
||||
|
||||
def apply_post_tokenization(
|
||||
self,
|
||||
tokenizer: TokenizerLike | None,
|
||||
prompt: TokensPrompt | EmbedsPrompt,
|
||||
) -> TokensPrompt | EmbedsPrompt:
|
||||
"""
|
||||
Ensure that the prompt meets the requirements set out by this config.
|
||||
If that is not possible, raise a `VLLMValidationError`.
|
||||
|
||||
This method is run after tokenization occurs.
|
||||
"""
|
||||
if "prompt_token_ids" in prompt:
|
||||
prompt["prompt_token_ids"] = self._validate_tokens( # type: ignore[typeddict-unknown-key]
|
||||
tokenizer,
|
||||
prompt["prompt_token_ids"], # type: ignore[typeddict-item]
|
||||
)
|
||||
if "prompt_embeds" in prompt:
|
||||
prompt["prompt_embeds"] = self._validate_tokens( # type: ignore[typeddict-unknown-key]
|
||||
tokenizer,
|
||||
prompt["prompt_embeds"], # type: ignore[typeddict-item]
|
||||
)
|
||||
|
||||
return prompt
|
||||
@@ -1,9 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
|
||||
from .embed_utils import safe_load_prompt_embeds
|
||||
from .params import ChatParams, TokenizeParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
@@ -14,6 +20,9 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class RendererLike(Protocol):
|
||||
config: "ModelConfig"
|
||||
_async_tokenizer: AsyncMicrobatchTokenizer
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
@@ -33,16 +42,147 @@ class RendererLike(Protocol):
|
||||
|
||||
return tokenizer
|
||||
|
||||
def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
|
||||
# Lazy initialization since offline LLM doesn't use async
|
||||
if not hasattr(self, "_async_tokenizer"):
|
||||
self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())
|
||||
|
||||
return self._async_tokenizer
|
||||
|
||||
# Step 1: Convert raw inputs to prompts
|
||||
def render_completion(
|
||||
self,
|
||||
prompt_raw: str | list[int] | bytes,
|
||||
) -> TextPrompt | TokensPrompt | EmbedsPrompt:
|
||||
error_msg = "Each prompt must be a string or an array of tokens"
|
||||
|
||||
if isinstance(prompt_raw, str):
|
||||
return TextPrompt(prompt=prompt_raw)
|
||||
|
||||
if isinstance(prompt_raw, list):
|
||||
if not is_list_of(prompt_raw, int):
|
||||
raise TypeError(error_msg)
|
||||
|
||||
return TokensPrompt(prompt_token_ids=prompt_raw)
|
||||
|
||||
if isinstance(prompt_raw, bytes):
|
||||
embeds = safe_load_prompt_embeds(self.config, prompt_raw)
|
||||
return EmbedsPrompt(prompt_embeds=embeds)
|
||||
|
||||
raise TypeError(error_msg)
|
||||
|
||||
def render_completions(
|
||||
self,
|
||||
prompt_input: str | list[str] | list[int] | list[list[int]] | None = None,
|
||||
prompt_embeds: bytes | list[bytes] | None = None,
|
||||
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
prompts_raw = list[str | list[int] | bytes]()
|
||||
|
||||
if prompt_embeds is not None: # embeds take higher priority
|
||||
if isinstance(prompt_embeds, bytes):
|
||||
prompts_raw.append(prompt_embeds)
|
||||
else:
|
||||
prompts_raw.extend(prompt_embeds)
|
||||
|
||||
if prompt_input is not None:
|
||||
if isinstance(prompt_input, str) or (
|
||||
len(prompt_input) > 0 and is_list_of(prompt_input, int)
|
||||
):
|
||||
prompts_raw.append(prompt_input) # type: ignore[arg-type]
|
||||
else:
|
||||
prompts_raw.extend(prompt_input) # type: ignore[arg-type]
|
||||
|
||||
if len(prompts_raw) == 0:
|
||||
raise ValueError("You must pass at least one prompt")
|
||||
|
||||
return [self.render_completion(prompt) for prompt in prompts_raw]
|
||||
|
||||
async def render_completions_async(
|
||||
self,
|
||||
prompt_input: str | list[str] | list[int] | list[list[int]] | None = None,
|
||||
prompt_embeds: bytes | list[bytes] | None = None,
|
||||
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
return self.render_completions(prompt_input, prompt_embeds)
|
||||
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list["ChatCompletionMessageParam"],
|
||||
**kwargs,
|
||||
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]:
|
||||
params: ChatParams,
|
||||
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def render_messages_async(
|
||||
self,
|
||||
messages: list["ChatCompletionMessageParam"],
|
||||
**kwargs,
|
||||
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]:
|
||||
return self.render_messages(messages, **kwargs)
|
||||
params: ChatParams,
|
||||
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
return self.render_messages(messages, params)
|
||||
|
||||
# Step 2: Tokenize prompts if necessary
|
||||
def tokenize_prompt(
|
||||
self,
|
||||
prompt: TextPrompt | TokensPrompt | EmbedsPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokensPrompt | EmbedsPrompt:
|
||||
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
|
||||
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
prompt_token_ids = tokenizer.encode(
|
||||
prompt["prompt"],
|
||||
**params.get_encode_kwargs(),
|
||||
)
|
||||
|
||||
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
|
||||
|
||||
if params.needs_detokenization and "prompt" not in prompt:
|
||||
if "prompt_token_ids" not in prompt:
|
||||
raise RuntimeError("Cannot run detokenization on embeddings")
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
prompt_text = tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
|
||||
prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key]
|
||||
|
||||
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
||||
|
||||
def tokenize_prompts(
|
||||
self,
|
||||
prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt],
|
||||
params: TokenizeParams,
|
||||
) -> list[TokensPrompt | EmbedsPrompt]:
|
||||
return [self.tokenize_prompt(prompt, params) for prompt in prompts]
|
||||
|
||||
async def tokenize_prompt_async(
|
||||
self,
|
||||
prompt: TextPrompt | TokensPrompt | EmbedsPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokensPrompt | EmbedsPrompt:
|
||||
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
|
||||
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
|
||||
|
||||
tokenizer = self.get_async_tokenizer()
|
||||
prompt_token_ids = await tokenizer.encode(
|
||||
prompt["prompt"],
|
||||
**params.get_encode_kwargs(),
|
||||
)
|
||||
|
||||
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
|
||||
|
||||
if params.needs_detokenization and "prompt" not in prompt:
|
||||
if "prompt_token_ids" not in prompt:
|
||||
raise RuntimeError("Cannot run detokenization on embeddings")
|
||||
|
||||
tokenizer = self.get_async_tokenizer()
|
||||
prompt_text = await tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
|
||||
prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key]
|
||||
|
||||
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
||||
|
||||
async def tokenize_prompts_async(
|
||||
self,
|
||||
prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt],
|
||||
params: TokenizeParams,
|
||||
) -> list[TokensPrompt | EmbedsPrompt]:
|
||||
return await asyncio.gather(
|
||||
*(self.tokenize_prompt_async(prompt, params) for prompt in prompts)
|
||||
)
|
||||
|
||||
@@ -9,10 +9,11 @@ from vllm.entrypoints.chat_utils import (
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
from .params import ChatParams
|
||||
from .protocol import RendererLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -45,8 +46,8 @@ class TerratorchRenderer(RendererLike):
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
**kwargs,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
model_config = self.config
|
||||
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
@@ -55,7 +56,7 @@ class TerratorchRenderer(RendererLike):
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
prompt = TokensPrompt(prompt_token_ids=[1])
|
||||
prompt = self.render_completion([1]) # Dummy token IDs
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
@@ -66,8 +67,8 @@ class TerratorchRenderer(RendererLike):
|
||||
async def render_messages_async(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
**kwargs,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
|
||||
params: ChatParams,
|
||||
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
model_config = self.config
|
||||
|
||||
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
|
||||
@@ -76,7 +77,7 @@ class TerratorchRenderer(RendererLike):
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
prompt = TokensPrompt(prompt_token_ids=[1]) # Dummy token IDs
|
||||
prompt = self.render_completion([1]) # Dummy token IDs
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
if mm_uuids is not None:
|
||||
|
||||
@@ -834,7 +834,7 @@ def parse_pooling_type(pooling_name: str):
|
||||
@cache
|
||||
def get_sentence_transformer_tokenizer_config(
|
||||
model: str | Path, revision: str | None = "main"
|
||||
):
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Returns the tokenization configuration dictionary for a
|
||||
given Sentence Transformer BERT model.
|
||||
|
||||
@@ -50,14 +50,17 @@ class AsyncMicrobatchTokenizer:
|
||||
self._executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
# === Public async API ===
|
||||
async def __call__(self, prompt, **kwargs):
|
||||
async def __call__(self, prompt, **kwargs) -> BatchEncoding:
|
||||
result_future: Future = self._loop.create_future()
|
||||
key = self._queue_key("encode", kwargs)
|
||||
queue = self._get_queue(self._loop, key)
|
||||
await queue.put((prompt, kwargs, result_future))
|
||||
return await result_future
|
||||
|
||||
async def decode(self, token_ids, **kwargs):
|
||||
async def encode(self, prompt, **kwargs) -> list[int]:
|
||||
return (await self(prompt, **kwargs)).input_ids
|
||||
|
||||
async def decode(self, token_ids, **kwargs) -> str:
|
||||
result_future: Future = self._loop.create_future()
|
||||
key = self._queue_key("decode", kwargs)
|
||||
queue = self._get_queue(self._loop, key)
|
||||
|
||||
@@ -16,7 +16,6 @@ from vllm import TokensPrompt
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.inputs.data import StreamingInput
|
||||
from vllm.logger import init_logger
|
||||
@@ -25,7 +24,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import RendererLike
|
||||
from vllm.renderers import RendererLike, merge_kwargs
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
@@ -304,13 +303,20 @@ class AsyncLLM(EngineClient):
|
||||
"prompt logprobs"
|
||||
)
|
||||
|
||||
if tokenization_kwargs is None:
|
||||
tokenization_kwargs = {}
|
||||
_validate_truncation_size(
|
||||
self.model_config.max_model_len,
|
||||
params.truncate_prompt_tokens,
|
||||
tokenization_kwargs,
|
||||
)
|
||||
if params.truncate_prompt_tokens is not None:
|
||||
params_type = type(params).__name__
|
||||
warnings.warn(
|
||||
f"The `truncate_prompt_tokens` parameter in `{params_type}` "
|
||||
"is deprecated and will be removed in v0.16. "
|
||||
"Please pass it via `tokenization_kwargs` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
tokenization_kwargs = merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
|
||||
)
|
||||
|
||||
if isinstance(prompt, AsyncGenerator):
|
||||
# Streaming input case.
|
||||
@@ -344,12 +350,12 @@ class AsyncLLM(EngineClient):
|
||||
request_id,
|
||||
prompt,
|
||||
params,
|
||||
arrival_time,
|
||||
lora_request,
|
||||
tokenization_kwargs,
|
||||
trace_headers,
|
||||
priority,
|
||||
data_parallel_rank,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
prompt_text = get_prompt_text(prompt)
|
||||
|
||||
@@ -757,7 +763,6 @@ class AsyncLLM(EngineClient):
|
||||
lora_request: LoRARequest | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
priority: int = 0,
|
||||
truncate_prompt_tokens: int | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
||||
"""
|
||||
@@ -772,22 +777,10 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
The caller of generate() iterates the returned AsyncGenerator,
|
||||
returning the RequestOutput back to the caller.
|
||||
|
||||
NOTE: truncate_prompt_tokens is deprecated in v0.14.
|
||||
TODO: Remove truncate_prompt_tokens in v0.15.
|
||||
"""
|
||||
|
||||
q: RequestOutputCollector | None = None
|
||||
try:
|
||||
if truncate_prompt_tokens is not None:
|
||||
warnings.warn(
|
||||
"The `truncate_prompt_tokens` parameter in `AsyncLLM.encode()` "
|
||||
"is deprecated and will be removed in v0.15. "
|
||||
"Please use `pooling_params.truncate_prompt_tokens` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
q = await self.add_request(
|
||||
request_id,
|
||||
prompt,
|
||||
|
||||
Reference in New Issue
Block a user