[Refactor] Consolidate sequence normalization and enc-dec parsing (#33928)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-02-06 23:43:47 +08:00
committed by GitHub
parent 4707f7ebb4
commit cd8b405bd0
38 changed files with 1271 additions and 863 deletions

View File

@@ -54,6 +54,7 @@ class MockModelConfig:
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}

View File

@@ -53,6 +53,7 @@ class MockModelConfig:
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}

View File

@@ -52,6 +52,7 @@ class MockModelConfig:
encoder_config = None
generation_config: str = "auto"
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}

View File

@@ -529,6 +529,7 @@ class MockModelConfig:
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}

View File

View File

@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.renderers.inputs.preprocess import prompt_to_seq
def test_empty_input():
assert prompt_to_seq([]) == []
assert prompt_to_seq([[]]) == [[]]
assert prompt_to_seq([[], []]) == [[], []]
def test_text_input():
assert prompt_to_seq("foo") == ["foo"]
assert prompt_to_seq(["foo"]) == ["foo"]
assert prompt_to_seq(["foo", "bar"]) == ["foo", "bar"]
def test_token_input():
assert prompt_to_seq([1, 2]) == [[1, 2]]
assert prompt_to_seq([[1, 2]]) == [[1, 2]]
assert prompt_to_seq([[1, 2], [3, 4]]) == [[1, 2], [3, 4]]
def test_text_token_input():
assert prompt_to_seq([[1, 2], "foo"]) == [[1, 2], "foo"]
assert prompt_to_seq(["foo", [1, 2]]) == ["foo", [1, 2]]
def test_bytes_input():
assert prompt_to_seq(b"foo") == [b"foo"]
assert prompt_to_seq([b"foo"]) == [b"foo"]
assert prompt_to_seq([b"foo", b"bar"]) == [b"foo", b"bar"]
def test_dict_input():
assert prompt_to_seq({"prompt": "foo"}) == [{"prompt": "foo"}]
assert prompt_to_seq([{"prompt": "foo"}]) == [{"prompt": "foo"}]
assert prompt_to_seq([{"prompt": "foo"}, {"prompt_token_ids": [1, 2]}]) == [
{"prompt": "foo"},
{"prompt_token_ids": [1, 2]},
]

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
@@ -9,8 +10,11 @@ import pybase64
import pytest
import torch
from vllm.config import ModelConfig
from vllm.inputs import SingletonPrompt
from vllm.renderers import TokenizeParams
from vllm.renderers.hf import HfRenderer
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tokenizers.registry import tokenizer_args_from_config
MODEL_NAME = "openai-community/gpt2"
@@ -33,6 +37,7 @@ class MockModelConfig:
encoder_config: dict[str, Any] | None = None
enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
@dataclass
@@ -80,65 +85,34 @@ def _build_renderer(
return renderer
def _preprocess_prompt(
mdoel_config: ModelConfig,
prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
):
return [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(mdoel_config, prompt)
)
for prompt in prompt_to_seq(prompt_or_prompts)
]
class TestValidatePrompt:
STRING_INPUTS = [
"",
"foo",
"foo bar",
"foo baz bar",
"foo bar qux baz",
]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
]
INPUTS_SLICES = [
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
# Test that a nested mixed-type list of lists raises a TypeError.
def test_empty_input(self):
renderer = _build_renderer(MockModelConfig())
with pytest.raises(ValueError, match="at least one prompt"):
renderer.render_completions([])
renderer.render_prompts(_preprocess_prompt(renderer.config, []))
def test_invalid_type(self):
renderer = _build_renderer(MockModelConfig())
with pytest.raises(TypeError, match="string or an array of tokens"):
renderer.render_completions([[1, 2], ["foo", "bar"]])
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_string_consistent(self, string_input: str):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(string_input) == renderer.render_completions(
[string_input]
)
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_token_consistent(self, token_input: list[int]):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(token_input) == renderer.render_completions(
[token_input]
)
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_string_slice(self, inputs_slice: slice):
renderer = _build_renderer(MockModelConfig())
assert renderer.render_completions(self.STRING_INPUTS)[
inputs_slice
] == renderer.render_completions(self.STRING_INPUTS[inputs_slice])
with pytest.raises(TypeError, match="should be a list of integers"):
renderer.render_prompts(
_preprocess_prompt(renderer.config, [[1, 2], ["foo", "bar"]]) # type: ignore[arg-type]
)
class TestRenderPrompt:
@@ -146,7 +120,7 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
tokens = [101, 7592, 2088]
prompts = renderer.render_completions(tokens)
prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
@@ -159,7 +133,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
prompts = renderer.render_completions(token_lists)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, token_lists)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
@@ -174,7 +150,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
text_input = "x" * 10
prompts = renderer.render_completions(text_input)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, text_input)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
@@ -187,7 +165,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
text_list_input = ["x" * 10, "x" * 12, "x" * 14]
prompts = renderer.render_completions(text_list_input)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, text_list_input)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
@@ -200,7 +180,9 @@ class TestRenderPrompt:
def test_zero_truncation(self):
renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0),
@@ -212,7 +194,9 @@ class TestRenderPrompt:
def test_pos_truncation(self):
renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50),
@@ -224,7 +208,9 @@ class TestRenderPrompt:
def test_neg_truncation(self):
renderer = _build_renderer(MockModelConfig())
prompts = renderer.render_completions("x" * 200)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "x" * 200)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1),
@@ -237,7 +223,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig(), truncation_side="left")
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
@@ -251,7 +239,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig(), truncation_side="right")
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5),
@@ -266,7 +256,9 @@ class TestRenderPrompt:
# Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
long_tokens = "x" * 150
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises(
ValueError,
@@ -285,7 +277,9 @@ class TestRenderPrompt:
# Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN
long_tokens = "x" * 150
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises(
ValueError,
@@ -304,7 +298,9 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
long_tokens = list(range(150)) # Exceeds max_total_tokens=100
prompts = renderer.render_completions(long_tokens)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, long_tokens)
)
with pytest.raises(
ValueError,
@@ -318,7 +314,9 @@ class TestRenderPrompt:
def test_no_tokenizer_for_text(self):
renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True))
prompts = renderer.render_completions("Hello world")
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, "Hello world")
)
with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"):
renderer.tokenize_prompts(
@@ -330,7 +328,7 @@ class TestRenderPrompt:
renderer = _build_renderer(MockModelConfig())
tokens = [1, 2, 3, 4]
prompts = renderer.render_completions(tokens)
prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens))
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(
@@ -359,7 +357,9 @@ class TestRenderEmbedPrompt:
tensor_input = torch.randn(10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(tensor_input)
prompts = renderer.render_completions(prompt_embeds=embed_bytes)
prompts = renderer.render_prompts(
_preprocess_prompt(renderer.config, embed_bytes)
)
results = renderer.tokenize_prompts(
prompts,
TokenizeParams(max_total_tokens=100),
@@ -377,8 +377,11 @@ class TestRenderEmbedPrompt:
torch.randn(12, 512, dtype=torch.float32),
]
prompts = renderer.render_completions(
prompt_embeds=[self._create_test_embed_bytes(t) for t in tensor_inputs],
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config,
[self._create_test_embed_bytes(t) for t in tensor_inputs],
)
)
results = renderer.tokenize_prompts(
prompts,
@@ -395,8 +398,10 @@ class TestRenderEmbedPrompt:
# Create tensor with more tokens than truncation limit
tensor_input = torch.randn(20, 768, dtype=torch.float32)
prompts = renderer.render_completions(
prompt_embeds=self._create_test_embed_bytes(tensor_input),
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
)
results = renderer.tokenize_prompts(
prompts,
@@ -420,8 +425,10 @@ class TestRenderEmbedPrompt:
for dtype in dtypes:
tensor_input = torch.randn(5, 256, dtype=dtype)
prompts = renderer.render_completions(
prompt_embeds=self._create_test_embed_bytes(tensor_input),
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
)
results = renderer.tokenize_prompts(
prompts,
@@ -437,8 +444,10 @@ class TestRenderEmbedPrompt:
# Test tensor with batch dimension gets squeezed
tensor_input = torch.randn(1, 10, 768, dtype=torch.float32)
prompts = renderer.render_completions(
prompt_embeds=self._create_test_embed_bytes(tensor_input),
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config, self._create_test_embed_bytes(tensor_input)
)
)
results = renderer.tokenize_prompts(
prompts,
@@ -455,9 +464,11 @@ class TestRenderEmbedPrompt:
text_input = "Hello world"
tensor_input = torch.randn(5, 256, dtype=torch.float32)
prompts = renderer.render_completions(
text_input,
prompt_embeds=self._create_test_embed_bytes(tensor_input),
prompts = renderer.render_prompts(
_preprocess_prompt(
renderer.config,
[text_input, self._create_test_embed_bytes(tensor_input)],
)
)
results = renderer.tokenize_prompts(
prompts,
@@ -465,8 +476,8 @@ class TestRenderEmbedPrompt:
)
assert len(results) == 2
# First should be embed prompt
assert torch.equal(results[0]["prompt_embeds"], tensor_input)
# Second should be tokens prompt
assert "prompt_token_ids" in results[1]
assert len(results[1]["prompt_token_ids"]) == len(text_input)
# First should be tokens prompt
assert "prompt_token_ids" in results[0]
assert len(results[0]["prompt_token_ids"]) == len(text_input)
# Second should be embed prompt
assert torch.equal(results[1]["prompt_embeds"], tensor_input)

View File

@@ -3,16 +3,40 @@
import asyncio
import time
from dataclasses import dataclass
from typing import Any
from unittest.mock import Mock
import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.config import ModelConfig
from vllm.renderers import ChatParams
from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template
from vllm.tokenizers.mistral import MistralTokenizer
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
runner_type = "generate"
model: str = MODEL_NAME
tokenizer: str = MODEL_NAME
trust_remote_code: bool = False
max_model_len: int = 100
tokenizer_revision = None
tokenizer_mode = "mistral"
hf_config = MockHFConfig()
encoder_config: dict[str, Any] | None = None
enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
@pytest.mark.asyncio
async def test_async_mistral_tokenizer_does_not_block_event_loop():
@@ -23,9 +47,10 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop():
time.sleep(2)
return expected_tokens
mock_model_config = MockModelConfig(skip_tokenizer_init=True)
mock_tokenizer = Mock(spec=MistralTokenizer)
mock_tokenizer.apply_chat_template = mocked_apply_chat_template
mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={})
mock_renderer = MistralRenderer(mock_model_config, tokenizer_kwargs={})
mock_renderer._tokenizer = mock_tokenizer
task = mock_renderer.render_messages_async([], ChatParams())

View File

@@ -4,52 +4,13 @@
import pytest
from vllm.config import ModelConfig
from vllm.inputs import zip_enc_dec_prompts
from vllm.inputs.preprocess import InputPreprocessor
pytestmark = pytest.mark.cpu_test
@pytest.mark.parametrize(
"mm_processor_kwargs,expected_mm_kwargs",
[
(None, [{}, {}]),
({}, [{}, {}]),
({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
],
)
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
"""Test mm_processor_kwargs init for zipping enc/dec prompts."""
encoder_prompts = ["An encoder prompt", "Another encoder prompt"]
decoder_prompts = ["A decoder prompt", "Another decoder prompt"]
zipped_prompts = zip_enc_dec_prompts(
encoder_prompts, decoder_prompts, mm_processor_kwargs
)
assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
for enc, dec, exp_kwargs, zipped in zip(
encoder_prompts, decoder_prompts, expected_mm_kwargs, zipped_prompts
):
assert isinstance(zipped, dict)
assert len(zipped.keys()) == 3
assert zipped["encoder_prompt"] == enc
assert zipped["decoder_prompt"] == dec
assert zipped["mm_processor_kwargs"] == exp_kwargs
@pytest.mark.parametrize(
"model_id",
[
"facebook/chameleon-7b",
],
)
@pytest.mark.parametrize(
"prompt",
[
"",
{"prompt_token_ids": []},
],
)
@pytest.mark.parametrize("model_id", ["facebook/chameleon-7b"])
@pytest.mark.parametrize("prompt", ["", {"prompt_token_ids": []}])
@pytest.mark.skip(
reason=(
"Applying huggingface processor on text inputs results in "

View File

@@ -16,6 +16,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.v1.engine import EngineCoreRequest
@@ -53,7 +54,11 @@ class EngineClient(ABC):
@abstractmethod
def generate(
self,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams,
request_id: str,
*,
@@ -70,7 +75,7 @@ class EngineClient(ABC):
@abstractmethod
def encode(
self,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
pooling_params: PoolingParams,
request_id: str,
lora_request: LoRARequest | None = None,

View File

@@ -4,7 +4,7 @@
import itertools
import warnings
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, TypeAlias, cast
from typing import TYPE_CHECKING, Any, cast
import cloudpickle
import torch.nn as nn
@@ -53,16 +53,13 @@ from vllm.entrypoints.pooling.score.utils import (
validate_score_input,
)
from vllm.entrypoints.utils import log_non_default_args
from vllm.inputs import (
from vllm.inputs.data import (
DataPrompt,
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
PromptType,
SingletonPrompt,
TextPrompt,
TokensPrompt,
)
from vllm.inputs.parse import get_prompt_components, is_explicit_encoder_decoder_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -76,6 +73,13 @@ from vllm.outputs import (
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers.inputs import DictPrompt, SingletonDictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import (
conversation_to_seq,
extract_prompt_components,
parse_model_prompt,
prompt_to_seq,
)
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask
from vllm.tokenizers import TokenizerLike
@@ -93,9 +97,6 @@ logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
EnginePrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
EngineEncDecPrompt: TypeAlias = ExplicitEncoderDecoderPrompt[EnginePrompt, EnginePrompt]
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
@@ -445,21 +446,20 @@ class LLM:
if sampling_params is None:
sampling_params = self.get_default_sampling_params()
self._validate_and_add_requests(
outputs = self._run_completion(
prompts=prompts,
params=sampling_params,
use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(prompts, lora_request),
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs, RequestOutput)
def _get_modality_specific_lora_reqs(
self,
prompts: PromptType | Sequence[PromptType],
prompts: Sequence[DictPrompt | TokPrompt],
lora_request: list[LoRARequest] | LoRARequest | None,
):
# Grab the lora config off the vllm config on the engine,
@@ -475,9 +475,6 @@ class LLM:
):
return lora_request
if not isinstance(prompts, Sequence) or isinstance(prompts, str):
prompts = [prompts]
optional_loras = (
[lora_request] * len(prompts)
if not isinstance(lora_request, Sequence)
@@ -495,14 +492,12 @@ class LLM:
def _resolve_single_prompt_mm_lora(
self,
prompt: PromptType,
prompt: DictPrompt | TokPrompt,
lora_request: LoRARequest | None,
default_mm_loras: dict[str, str] | None,
):
if (
not default_mm_loras
or not isinstance(prompt, dict)
or not (mm_data := prompt.get("multi_modal_data") or {})
if not default_mm_loras or not (
mm_data := prompt.get("multi_modal_data") or {}
):
return lora_request
@@ -806,61 +801,11 @@ class LLM:
add_special_tokens=not model_config.is_encoder_decoder,
).with_kwargs(tokenization_kwargs)
def _normalize_prompts(
self,
prompts: PromptType | Sequence[PromptType],
) -> list[EnginePrompt | EngineEncDecPrompt]:
if isinstance(prompts, str):
prompts = TextPrompt(prompt=prompts)
return prompts if isinstance(prompts, Sequence) else [prompts] # type: ignore[return-value]
def _preprocess_cmpl_singleton(
self,
prompt: SingletonPrompt,
tok_params: TokenizeParams,
*,
tokenize: bool,
) -> EnginePrompt:
renderer = self.llm_engine.renderer
if not isinstance(prompt, dict):
prompt = renderer.render_completion(prompt)
return renderer.tokenize_prompt(prompt, tok_params) if tokenize else prompt
def _preprocess_cmpl_enc_dec(
self,
prompt: ExplicitEncoderDecoderPrompt,
tok_params: TokenizeParams,
) -> EngineEncDecPrompt:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
return EngineEncDecPrompt(
encoder_prompt=self._preprocess_cmpl_singleton(
enc_prompt,
tok_params,
# TODO: Move multi-modal processor into tokenization
tokenize=not self.model_config.is_multimodal_model,
),
decoder_prompt=(
None
if dec_prompt is None
else self._preprocess_cmpl_singleton(
dec_prompt,
tok_params,
# TODO: Move multi-modal processor into tokenization
tokenize=not self.model_config.is_multimodal_model,
)
),
)
def _preprocess_completion(
self,
prompts: PromptType | Sequence[PromptType],
prompts: Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None,
) -> list[EnginePrompt | EngineEncDecPrompt]:
) -> list[DictPrompt | TokPrompt]:
"""
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
a format that can be passed to `_add_request`.
@@ -871,32 +816,26 @@ class LLM:
A list of `TokensPrompts` objects containing the tokenized prompt
after chat template interpolation, and the raw multi-modal inputs.
"""
renderer = self.llm_engine.renderer
model_config = self.model_config
tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
engine_prompts = list[EnginePrompt | EngineEncDecPrompt]()
for prompt in self._normalize_prompts(prompts):
if is_explicit_encoder_decoder_prompt(prompt):
engine_prompts.append(self._preprocess_cmpl_enc_dec(prompt, tok_params))
else:
# Some MM models have non-default `add_special_tokens`
# TODO: Move multi-modal processor into tokenization
engine_prompts.append(
self._preprocess_cmpl_singleton(
prompt,
tok_params,
tokenize=not self.model_config.is_multimodal_model,
)
)
engine_prompts = list[DictPrompt | TokPrompt]()
for prompt in prompts:
parsed_prompt = parse_model_prompt(model_config, prompt)
in_prompt = renderer.render_prompt(parsed_prompt)
# Some MM models have non-default `add_special_tokens`
# TODO: Move multi-modal processor into tokenization
engine_prompts.append(
in_prompt
if model_config.is_multimodal_model
else renderer.tokenize_prompt(in_prompt, tok_params)
)
return engine_prompts
def _normalize_conversations(
self,
conversations: list[ChatCompletionMessageParam]
| list[list[ChatCompletionMessageParam]],
) -> list[list[ChatCompletionMessageParam]]:
return conversations if is_list_of(conversations, list) else [conversations] # type: ignore[list-item,return-value]
def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
@@ -909,8 +848,7 @@ class LLM:
def _preprocess_chat(
self,
conversations: list[ChatCompletionMessageParam]
| list[list[ChatCompletionMessageParam]],
conversations: Sequence[list[ChatCompletionMessageParam]],
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
chat_template_kwargs: dict[str, Any] | None = None,
@@ -919,7 +857,7 @@ class LLM:
tools: list[dict[str, Any]] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[EnginePrompt]:
) -> list[DictPrompt | TokPrompt]:
"""
Convert a list of conversations into prompts so that they can then
be used as input for other LLM APIs.
@@ -947,11 +885,14 @@ class LLM:
)
tok_params = self._get_chat_tok_params(tokenization_kwargs)
engine_prompts = list[EnginePrompt]()
for conversation in self._normalize_conversations(conversations):
engine_prompts = list[DictPrompt | TokPrompt]()
for conversation in conversations:
_, in_prompt = renderer.render_messages(conversation, chat_params)
if mm_processor_kwargs is not None:
in_prompt["mm_processor_kwargs"] = mm_processor_kwargs
target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
"encoder_prompt", in_prompt
)
target_prompt["mm_processor_kwargs"] = mm_processor_kwargs # type: ignore
engine_prompts.append(renderer.tokenize_prompt(in_prompt, tok_params))
@@ -960,8 +901,8 @@ class LLM:
def chat(
self,
messages: list[ChatCompletionMessageParam]
| list[list[ChatCompletionMessageParam]],
sampling_params: SamplingParams | list[SamplingParams] | None = None,
| Sequence[list[ChatCompletionMessageParam]],
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: LoRARequest | None = None,
chat_template: str | None = None,
@@ -984,7 +925,7 @@ class LLM:
to the OpenAI API.
Args:
messages: A list of conversations or a single conversation.
messages: A sequence of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
@@ -1023,8 +964,23 @@ class LLM:
A list of `RequestOutput` objects containing the generated
responses in the same order as the input messages.
"""
prompts = self._preprocess_chat(
messages,
model_config = self.model_config
runner_type = model_config.runner_type
if runner_type != "generate":
raise ValueError(
"LLM.chat() is only supported for generative models. "
"Try passing `--runner generate` to use the model as a "
"generative model."
)
if sampling_params is None:
sampling_params = self.get_default_sampling_params()
outputs = self._run_chat(
messages=messages,
params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
@@ -1035,13 +991,7 @@ class LLM:
mm_processor_kwargs=mm_processor_kwargs,
)
return self.generate(
prompts,
sampling_params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
return self.engine_class.validate_outputs(outputs, RequestOutput)
def encode(
self,
@@ -1163,7 +1113,7 @@ class LLM:
msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
raise ValueError(msg)
self._validate_and_add_requests(
outputs = self._run_completion(
prompts=prompts,
params=pooling_params,
use_tqdm=use_tqdm,
@@ -1171,8 +1121,6 @@ class LLM:
tokenization_kwargs=tokenization_kwargs,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
model_outputs = self.engine_class.validate_outputs(
outputs, PoolingRequestOutput
)
@@ -1523,14 +1471,13 @@ class LLM:
prompts.append(engine_prompt)
self._validate_and_add_requests(
outputs = self._run_completion(
prompts=prompts,
params=pooling_params_list,
use_tqdm=use_tqdm,
lora_request=lora_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)
return [ScoringRequestOutput.from_base(item) for item in items]
@@ -1727,33 +1674,29 @@ class LLM:
"""
return self.llm_engine.get_metrics()
def _validate_and_add_requests(
def _params_to_seq(
self,
prompts: PromptType | Sequence[PromptType],
params: SamplingParams
| Sequence[SamplingParams]
| PoolingParams
| Sequence[PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: list[int] | None = None,
) -> None:
in_prompts = self._normalize_prompts(prompts)
num_requests = len(in_prompts)
| Sequence[SamplingParams | PoolingParams],
num_requests: int,
) -> Sequence[SamplingParams | PoolingParams]:
if isinstance(params, Sequence):
if len(params) != num_requests:
raise ValueError(
f"The lengths of prompts ({params}) "
f"and lora_request ({len(params)}) must be the same."
f"and params ({len(params)}) must be the same."
)
engine_params = params
else:
engine_params = [params] * num_requests
return params
return [params] * num_requests
def _lora_request_to_seq(
self,
lora_request: LoRARequest | None | Sequence[LoRARequest | None],
num_requests: int,
) -> Sequence[LoRARequest | None]:
if isinstance(lora_request, Sequence):
if len(lora_request) != num_requests:
raise ValueError(
@@ -1761,28 +1704,50 @@ class LLM:
f"and lora_request ({len(lora_request)}) must be the same."
)
engine_lora_requests: Sequence[LoRARequest | None] = lora_request
else:
engine_lora_requests = [lora_request] * num_requests
return lora_request
return [lora_request] * num_requests
def _priority_to_seq(
self,
priority: list[int] | None,
num_requests: int,
) -> Sequence[int]:
if priority is not None:
if len(priority) != num_requests:
raise ValueError(
f"The lengths of prompts ({num_requests}) "
f"and priority ({len(priority)}) must be the same."
)
else:
priority = [0] * num_requests
if any(param.truncate_prompt_tokens is not None for param in engine_params):
return priority
return [0] * num_requests
def _run_completion(
self,
prompts: PromptType | Sequence[PromptType],
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
):
seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(params, len(seq_prompts))
if any(param.truncate_prompt_tokens is not None for param in seq_params):
# TODO: Remove this after deprecating `param.truncate_prompt_tokens`
# Then, move the code from the `else` block to the top and let
# `self._preprocess_completion` handle prompt normalization
engine_prompts = [
engine_prompt
for in_prompt, param in zip(in_prompts, engine_params)
for prompt, param in zip(seq_prompts, seq_params)
for engine_prompt in self._preprocess_completion(
[in_prompt],
[prompt],
tokenization_kwargs=merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
@@ -1791,17 +1756,90 @@ class LLM:
]
else:
engine_prompts = self._preprocess_completion(
in_prompts,
seq_prompts,
tokenization_kwargs=tokenization_kwargs,
)
for sp in engine_params:
self._validate_and_add_requests(
prompts=engine_prompts,
params=seq_params,
use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(
engine_prompts, lora_request
),
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
return self._run_engine(use_tqdm=use_tqdm)
def _run_chat(
self,
messages: list[ChatCompletionMessageParam]
| Sequence[list[ChatCompletionMessageParam]],
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: LoRARequest | None = None,
chat_template: str | None = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: list[dict[str, Any]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
):
engine_prompts = self._preprocess_chat(
conversation_to_seq(messages),
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
chat_template_kwargs=chat_template_kwargs,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
self._validate_and_add_requests(
prompts=engine_prompts,
params=params,
use_tqdm=use_tqdm,
lora_request=self._get_modality_specific_lora_reqs(
engine_prompts, lora_request
),
tokenization_kwargs=tokenization_kwargs,
)
return self._run_engine(use_tqdm=use_tqdm)
def _validate_and_add_requests(
self,
prompts: Sequence[DictPrompt | TokPrompt],
params: SamplingParams
| PoolingParams
| Sequence[SamplingParams | PoolingParams],
*,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: list[int] | None = None,
) -> None:
num_requests = len(prompts)
seq_params = self._params_to_seq(params, num_requests)
seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
seq_priority = self._priority_to_seq(priority, num_requests)
for sp in seq_params:
if isinstance(sp, SamplingParams):
# We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
it = engine_prompts
it = prompts
if use_tqdm:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests")
@@ -1812,10 +1850,10 @@ class LLM:
for i, prompt in enumerate(it):
request_id = self._add_request(
prompt,
engine_params[i],
lora_request=engine_lora_requests[i],
seq_params[i],
lora_request=seq_lora_requests[i],
tokenization_kwargs=tokenization_kwargs,
priority=priority[i],
priority=seq_priority[i],
)
added_request_ids.append(request_id)
except Exception as e:
@@ -1825,13 +1863,13 @@ class LLM:
def _add_request(
self,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
params: SamplingParams | PoolingParams,
lora_request: LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: int = 0,
) -> str:
prompt_text, _, _ = get_prompt_components(prompt)
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
request_id = str(next(self.request_counter))
if params.truncate_prompt_tokens is not None:

View File

@@ -67,12 +67,13 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
)
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.parser import ParserManager
from vllm.reasoning import ReasoningParser
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import (
@@ -218,10 +219,7 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request(
self,
request: ChatCompletionRequest,
) -> (
tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]
| ErrorResponse
):
) -> tuple[list[ConversationMessage], list[TokPrompt]] | ErrorResponse:
"""
render chat request by validating and preprocessing inputs.
@@ -380,7 +378,7 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
@@ -389,10 +387,10 @@ class OpenAIServingChat(OpenAIServing):
)
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
prompt=engine_prompt,
default_sampling_params=self.default_sampling_params,
self.max_model_len,
request,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
)
sampling_params: SamplingParams | BeamSearchParams

View File

@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import (
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import merge_async_iterators
@@ -78,7 +78,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def render_completion_request(
self,
request: CompletionRequest,
) -> list[TokensPrompt | EmbedsPrompt] | ErrorResponse:
) -> list[TokPrompt] | ErrorResponse:
"""
render completion request by validating and preprocessing inputs.
@@ -160,13 +160,13 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
prompt=engine_prompt,
default_sampling_params=self.default_sampling_params,
self.max_model_len,
request,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
)
sampling_params: SamplingParams | BeamSearchParams
@@ -277,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServing):
# with the inputs token IDs
if final_res.prompt is None:
engine_prompt = engine_prompts[i]
final_res.prompt = engine_prompt.get("prompt")
final_res.prompt = self._extract_prompt_text(engine_prompt)
final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
@@ -313,7 +313,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator(
self,
request: CompletionRequest,
engine_prompts: list[TokensPrompt | EmbedsPrompt],
engine_prompts: list[TokPrompt],
result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str,
created_time: int,
@@ -347,7 +347,7 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_text = res.prompt
if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx]
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
# Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None:

View File

@@ -96,11 +96,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
)
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, PromptType, TokensPrompt
from vllm.inputs.parse import (
get_prompt_components,
is_explicit_encoder_decoder_prompt,
)
from vllm.inputs.data import PromptType, SingletonPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
@@ -108,6 +104,14 @@ from vllm.multimodal import MultiModalDataDict
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import (
SingletonDictPrompt,
extract_prompt_components,
extract_prompt_len,
parse_model_prompt,
prompt_to_seq,
)
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
@@ -203,7 +207,7 @@ class ServeContext(Generic[RequestT]):
request_id: str
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
engine_prompts: list[TokensPrompt | EmbedsPrompt] | None = None
engine_prompts: list[TokPrompt] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None
@@ -247,7 +251,7 @@ class OpenAIServing:
async def beam_search(
self,
prompt: PromptType,
prompt: TokPrompt,
request_id: str,
params: BeamSearchParams,
lora_request: LoRARequest | None = None,
@@ -271,20 +275,12 @@ class OpenAIServing:
eos_token_id: int = tokenizer.eos_token_id # type: ignore
if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
raise NotImplementedError("Encoder-decoder prompt not supported")
prompt_text: str | None
prompt_token_ids: list[int]
multi_modal_data: MultiModalDataDict | None
if isinstance(prompt, str):
prompt_text = prompt
prompt_token_ids = []
multi_modal_data = None
else:
prompt_text = prompt.get("prompt") # type: ignore
prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data = prompt.get("multi_modal_data") # type: ignore
prompt_text: str | None = prompt.get("prompt") # type: ignore
prompt_token_ids: list[int] = prompt.get("prompt_token_ids", []) # type: ignore
multi_modal_data: MultiModalDataDict | None = prompt.get("multi_modal_data") # type: ignore
mm_processor_kwargs: dict[str, Any] | None = None
@@ -963,22 +959,40 @@ class OpenAIServing:
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None,
) -> list[TokensPrompt | EmbedsPrompt]:
) -> list[TokPrompt]:
renderer = self.renderer
tok_params = request.build_tok_params(self.model_config)
model_config = self.model_config
in_prompts = await renderer.render_completions_async(
prompt_input, prompt_embeds
)
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
tok_params = request.build_tok_params(model_config)
prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))
parsed_prompts = [
(
prompt
if isinstance(prompt, bytes)
else parse_model_prompt(model_config, prompt)
)
for prompt in prompts
]
in_prompts = await renderer.render_prompts_async(parsed_prompts)
extra_items = {
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
}
for prompt in engine_prompts:
prompt.update(extra_items) # type: ignore
for in_prompt in in_prompts:
target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
"encoder_prompt", in_prompt
)
target_prompt.update(extra_items) # type: ignore
engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params)
return engine_prompts
@@ -991,7 +1005,7 @@ class OpenAIServing:
default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]:
) -> tuple[list[ConversationMessage], list[TokPrompt]]:
from vllm.tokenizers.mistral import MistralTokenizer
renderer = self.renderer
@@ -1009,17 +1023,21 @@ class OpenAIServing:
default_template, default_template_content_format
).with_defaults(default_template_kwargs)
conversation, prompt = await renderer.render_messages_async(
conversation, in_prompt = await renderer.render_messages_async(
messages, chat_params
)
engine_prompt = await renderer.tokenize_prompt_async(prompt, tok_params)
target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore
"encoder_prompt", in_prompt
)
extra_items = {
k: v
for k in ("mm_processor_kwargs", "cache_salt")
if (v := getattr(request, k, None)) is not None
}
engine_prompt.update(extra_items) # type: ignore
target_prompt.update(extra_items) # type: ignore
engine_prompt = await renderer.tokenize_prompt_async(target_prompt, tok_params)
# tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
@@ -1040,6 +1058,15 @@ class OpenAIServing:
return conversation, [engine_prompt]
def _extract_prompt_components(self, prompt: object):
return extract_prompt_components(self.model_config, prompt)
def _extract_prompt_text(self, prompt: object):
return self._extract_prompt_components(prompt).text
def _extract_prompt_len(self, prompt: object):
return extract_prompt_len(self.model_config, prompt)
async def _render_next_turn(
self,
request: ResponsesRequest,
@@ -1067,7 +1094,7 @@ class OpenAIServing:
async def _generate_with_builtin_tools(
self,
request_id: str,
engine_prompt: TokensPrompt | EmbedsPrompt,
engine_prompt: TokPrompt,
sampling_params: SamplingParams,
tok_params: TokenizeParams,
context: ConversationContext,
@@ -1075,7 +1102,7 @@ class OpenAIServing:
priority: int = 0,
trace_headers: Mapping[str, str] | None = None,
):
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
orig_priority = priority
sub_request = 0
@@ -1145,12 +1172,12 @@ class OpenAIServing:
context.chat_template_content_format,
)
engine_prompt = engine_prompts[0]
prompt_text = engine_prompt.get("prompt")
prompt_text = self._extract_prompt_text(engine_prompt)
sampling_params.max_tokens = get_max_tokens(
self.max_model_len,
context.request,
engine_prompt,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params, # type: ignore
)
@@ -1161,20 +1188,20 @@ class OpenAIServing:
def _log_inputs(
self,
request_id: str,
inputs: PromptType,
inputs: PromptType | TokPrompt,
params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None,
) -> None:
if self.request_logger is None:
return
prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs)
components = self._extract_prompt_components(inputs)
self.request_logger.log_inputs(
request_id,
prompt,
prompt_token_ids,
prompt_embeds,
components.text,
components.token_ids,
components.embeds,
params=params,
lora_request=lora_request,
)

View File

@@ -116,13 +116,13 @@ from vllm.entrypoints.openai.responses.utils import (
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput
from vllm.parser import ParserManager
from vllm.renderers.inputs import TokPrompt
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
@@ -292,10 +292,10 @@ class OpenAIServingResponses(OpenAIServing):
def _validate_generator_input(
self,
engine_prompt: TokensPrompt | EmbedsPrompt,
engine_prompt: TokPrompt,
) -> ErrorResponse | None:
"""Add validations to the input to the generator here."""
prompt_len = get_prompt_len(engine_prompt)
prompt_len = self._extract_prompt_len(engine_prompt)
if self.max_model_len <= prompt_len:
error_message = (
f"The engine prompt length {prompt_len} "
@@ -442,7 +442,7 @@ class OpenAIServingResponses(OpenAIServing):
default_max_tokens = get_max_tokens(
self.max_model_len,
request,
engine_prompt,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
)

View File

@@ -7,7 +7,7 @@ import time
import zlib
from collections.abc import AsyncGenerator, Callable
from functools import cached_property
from typing import Literal, TypeAlias, TypeVar, cast
from typing import Final, Literal, TypeAlias, TypeVar, cast
import numpy as np
from fastapi import Request
@@ -37,12 +37,13 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranslationStreamResponse,
)
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription, supports_transcription
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt
from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule
@@ -94,7 +95,7 @@ class OpenAISpeechToText(OpenAIServing):
)
self.default_sampling_params = self.model_config.get_diff_sampling_param()
self.task_type = task_type
self.task_type: Final = task_type
self.asr_config = self.model_cls.get_speech_to_text_config(
self.model_config, task_type
@@ -298,35 +299,26 @@ class OpenAISpeechToText(OpenAIServing):
to_language=to_language,
)
if request.response_format == "verbose_json":
if not is_explicit_encoder_decoder_prompt(prompt):
raise VLLMValidationError(
"Expected prompt to be an encoder-decoder prompt",
parameter="prompt",
value=type(prompt).__name__,
)
prompt = self._preprocess_verbose_prompt(prompt)
prompt = self._preprocess_verbose_prompt(parse_enc_dec_prompt(prompt))
prompts.append(prompt)
return prompts, duration
def _repl_verbose_text(self, text: str):
return text.replace("<|notimestamps|>", "<|0.00|>")
def _preprocess_verbose_prompt(self, prompt: ExplicitEncoderDecoderPrompt):
def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
dec_prompt = prompt["decoder_prompt"]
if isinstance(dec_prompt, str):
prompt["decoder_prompt"] = self._repl_verbose_text(dec_prompt)
elif isinstance(dec_prompt, dict) and "prompt" in dec_prompt:
dec_prompt["prompt"] = self._repl_verbose_text(dec_prompt["prompt"])
else:
if not (isinstance(dec_prompt, dict) and "prompt" in dec_prompt):
raise VLLMValidationError(
"Expected decoder_prompt to contain text",
parameter="decoder_prompt",
value=type(dec_prompt).__name__,
)
dec_prompt["prompt"] = dec_prompt["prompt"].replace(
"<|notimestamps|>", "<|0.00|>"
)
return prompt
def _get_verbose_segments(

View File

@@ -28,10 +28,11 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.inputs.data import EmbedsPrompt, TokensPrompt
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
from vllm.renderers.inputs import TokPrompt
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.collection_utils import chunk_list
from vllm.utils.serial_utils import EmbedDType, Endianness
@@ -369,7 +370,7 @@ class OpenAIServingEmbedding(OpenAIServing):
async def _create_single_prompt_generator(
self,
ctx: EmbeddingServeContext,
engine_prompt: TokensPrompt | EmbedsPrompt,
engine_prompt: TokPrompt,
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_index: int,

View File

@@ -33,8 +33,11 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.renderers.inputs import TokPrompt
from vllm.renderers.inputs.preprocess import prompt_to_seq
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
@@ -91,6 +94,7 @@ class OpenAIServingPooling(OpenAIServing):
"dimensions is currently not supported"
)
engine_prompts: Sequence[PromptType | TokPrompt]
if is_io_processor_request:
if self.io_processor is None:
raise ValueError(
@@ -102,14 +106,10 @@ class OpenAIServingPooling(OpenAIServing):
validated_prompt = self.io_processor.parse_request(request)
engine_prompts = await self.io_processor.pre_process_async(
raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id
)
if not isinstance(engine_prompts, Sequence) or isinstance(
engine_prompts, (str, bytes, bytearray)
):
engine_prompts = [engine_prompts]
engine_prompts = prompt_to_seq(raw_prompts)
elif isinstance(request, PoolingChatRequest):
error_check_ret = self._validate_chat_template(
request_chat_template=request.chat_template,

View File

@@ -17,8 +17,6 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_len
from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
@@ -189,7 +187,7 @@ def cli_env_setup():
def get_max_tokens(
max_model_len: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
prompt: TokensPrompt | EmbedsPrompt,
input_length: int,
default_sampling_params: dict,
) -> int:
# NOTE: Avoid isinstance() for better efficiency
@@ -204,7 +202,6 @@ def get_max_tokens(
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
input_length = get_prompt_len(prompt)
default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length)

View File

@@ -16,11 +16,8 @@ from .data import (
TextPrompt,
TokenInputs,
TokensPrompt,
build_explicit_enc_dec_prompt,
embeds_inputs,
to_enc_dec_tuple_list,
token_inputs,
zip_enc_dec_prompts,
)
__all__ = [
@@ -39,8 +36,5 @@ __all__ = [
"EncoderDecoderInputs",
"ProcessorInputs",
"SingletonInputs",
"build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list",
"zip_enc_dec_prompts",
"StreamingInput",
]

View File

@@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
import torch
from typing_extensions import NotRequired, TypedDict, TypeVar
from typing_extensions import NotRequired, TypedDict
from vllm.sampling_params import SamplingParams
@@ -23,7 +22,13 @@ else:
MultiModalUUIDDict = object
class _CommonKeys(TypedDict):
# Inputs to LLM API
class _PromptOptions(TypedDict):
"""
Additional options available to all
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt].
"""
multi_modal_data: NotRequired[MultiModalDataDict | None]
"""
Optional multi-modal data to pass to the model,
@@ -53,14 +58,14 @@ class _CommonKeys(TypedDict):
"""
class TextPrompt(_CommonKeys):
class TextPrompt(_PromptOptions):
"""Schema for a text prompt."""
prompt: str
"""The input text to be tokenized before passing to the model."""
class TokensPrompt(_CommonKeys):
class TokensPrompt(_PromptOptions):
"""Schema for a tokenized prompt."""
prompt_token_ids: list[int]
@@ -73,7 +78,7 @@ class TokensPrompt(_CommonKeys):
"""A list of token type IDs to pass to the cross encoder model."""
class EmbedsPrompt(_CommonKeys):
class EmbedsPrompt(_PromptOptions):
"""Schema for a prompt provided via token embeddings."""
prompt_embeds: torch.Tensor
@@ -83,93 +88,113 @@ class EmbedsPrompt(_CommonKeys):
"""The prompt text corresponding to the token embeddings, if available."""
class DataPrompt(_CommonKeys):
"""Represents generic inputs handled by IO processor plugins."""
DecoderOnlyPrompt: TypeAlias = (
str | TextPrompt | list[int] | TokensPrompt | EmbedsPrompt
)
"""
Schema of a prompt for a decoder-only model:
- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
[`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
For encoder-decoder models, passing a singleton prompt is shorthand for passing
`ExplicitEncoderDecoderPrompt(encoder_prompt=prompt, decoder_prompt=None)`.
"""
EncoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt
"""
Schema of a prompt for the encoder part of a encoder-decoder model:
- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
[`TokensPrompt`][vllm.inputs.data.TokensPrompt])
"""
DecoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt
"""
Schema of a prompt for the decoder part of an encoder-decoder model:
- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt (list of token IDs, or
[`TokensPrompt`][vllm.inputs.data.TokensPrompt])
Note:
Multi-modal inputs are not supported for decoder prompts.
"""
class ExplicitEncoderDecoderPrompt(TypedDict):
"""
Schema for a pair of encoder and decoder singleton prompts.
Note:
This schema is not valid for decoder-only models.
"""
encoder_prompt: EncoderPrompt
"""The prompt for the encoder part of the model."""
decoder_prompt: DecoderPrompt | None
"""
The prompt for the decoder part of the model.
Passing `None` will cause the prompt to be inferred automatically.
"""
EncoderDecoderPrompt: TypeAlias = EncoderPrompt | ExplicitEncoderDecoderPrompt
"""
Schema for a prompt for an encoder-decoder model.
You can pass a singleton encoder prompt, in which case the decoder prompt is
considered to be `None` (i.e., infer automatically).
"""
SingletonPrompt: TypeAlias = DecoderOnlyPrompt | EncoderPrompt | DecoderPrompt
"""
Schema for a single prompt. This is as opposed to a data structure
which encapsulates multiple prompts, such as
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt].
"""
PromptType: TypeAlias = DecoderOnlyPrompt | EncoderDecoderPrompt
"""
Schema for any prompt, regardless of model type.
This is the input format accepted by most [`LLM`][vllm.entrypoints.llm.LLM] APIs.
"""
class DataPrompt(_PromptOptions):
"""
Represents generic inputs that are converted to
[`PromptType`][vllm.inputs.data.PromptType] by IO processor plugins.
"""
data: Any
"""The input data"""
"""The input data."""
data_format: str
"""The input data format"""
"""The input data format."""
SingletonPrompt: TypeAlias = str | TextPrompt | TokensPrompt | EmbedsPrompt
"""
Set of possible schemas for a single prompt:
- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be
employed as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
"""
_T1_co = TypeVar(
"_T1_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True
)
_T2_co = TypeVar(
"_T2_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True
)
# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
# Outputs of processor
class _InputOptions(TypedDict):
"""
Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a decoder prompt.
The encoder and decoder prompts, respectively, may be formatted
according to any of the
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas,
and are not required to have the same schema.
Only the encoder prompt may have multi-modal data. mm_processor_kwargs
should be at the top-level, and should not be set in the encoder/decoder
prompts, since they are agnostic to the encoder/decoder.
Note that an
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
may not be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure themselves must be
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances.
Additional options available to all input types.
"""
encoder_prompt: _T1_co
decoder_prompt: _T2_co | None
mm_processor_kwargs: NotRequired[dict[str, Any]]
cache_salt: NotRequired[str]
"""Optional cache salt to be used for prefix caching."""
PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt[Any, Any]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt])
- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt])
- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt])
- A single data structure containing both an encoder and a decoder prompt
([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt])
"""
class TokenInputs(TypedDict):
class TokenInputs(_InputOptions):
"""Represents token-based inputs."""
type: Literal["token"]
@@ -178,11 +203,6 @@ class TokenInputs(TypedDict):
prompt_token_ids: list[int]
"""The token IDs of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
def token_inputs(
prompt_token_ids: list[int],
@@ -198,7 +218,7 @@ def token_inputs(
return inputs
class EmbedsInputs(TypedDict):
class EmbedsInputs(_InputOptions):
"""Represents embeddings-based inputs."""
type: Literal["embeds"]
@@ -207,11 +227,6 @@ class EmbedsInputs(TypedDict):
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
def embeds_inputs(
prompt_embeds: torch.Tensor,
@@ -229,96 +244,60 @@ def embeds_inputs(
DecoderOnlyInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
"""
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are
passed to the model executor.
This specifies the data required for decoder-only models.
A processed prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for decoder-only models.
"""
EncoderInputs: TypeAlias = TokenInputs | MultiModalEncDecInputs
"""
A processed encoder prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
"""
DecoderInputs: TypeAlias = TokenInputs | MultiModalInputs
"""
A processed decoder prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
"""
class EncoderDecoderInputs(TypedDict):
"""
The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they
are passed to the model executor.
This specifies the required data for encoder-decoder models.
A processed pair of encoder and decoder singleton prompts.
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]
for encoder-decoder models.
"""
encoder: TokenInputs | MultiModalEncDecInputs
encoder: EncoderInputs
"""The inputs for the encoder portion."""
decoder: TokenInputs | MultiModalInputs
decoder: DecoderInputs
"""The inputs for the decoder portion."""
SingletonInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs
"""
A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be
passed to [`Sequence`][collections.abc.Sequence].
"""
ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs
"""
The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][].
A processed prompt from
[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor]
which can be passed to
[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor].
"""
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
def build_explicit_enc_dec_prompt(
encoder_prompt: _T1,
decoder_prompt: _T2 | None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
return ExplicitEncoderDecoderPrompt(
encoder_prompt=encoder_prompt,
decoder_prompt=decoder_prompt,
mm_processor_kwargs=mm_processor_kwargs,
)
def zip_enc_dec_prompts(
enc_prompts: Iterable[_T1],
dec_prompts: Iterable[_T2 | None],
mm_processor_kwargs: Iterable[dict[str, Any]] | dict[str, Any] | None = None,
) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
"""
Zip encoder and decoder prompts together into a list of
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
instances.
`mm_processor_kwargs` may also be provided; if a dict is passed, the same
dictionary will be used for every encoder/decoder prompt. If an iterable is
provided, it will be zipped with the encoder/decoder prompts.
"""
if mm_processor_kwargs is None:
mm_processor_kwargs = cast(dict[str, Any], {})
if isinstance(mm_processor_kwargs, dict):
return [
build_explicit_enc_dec_prompt(
encoder_prompt,
decoder_prompt,
cast(dict[str, Any], mm_processor_kwargs),
)
for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
]
return [
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, mm_proc_kwargs)
for (encoder_prompt, decoder_prompt, mm_proc_kwargs) in zip(
enc_prompts, dec_prompts, mm_processor_kwargs
)
]
def to_enc_dec_tuple_list(
enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
) -> list[tuple[_T1, _T2 | None]]:
return [
(enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"])
for enc_dec_prompt in enc_dec_prompts
]
SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs
@dataclass

View File

@@ -1,88 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict
from typing_extensions import TypeIs
from vllm.utils import length_from_prompt_token_ids_or_embeds
from .data import (
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
ProcessorInputs,
PromptType,
SingletonInputs,
SingletonPrompt,
TextPrompt,
TokensPrompt,
)
if TYPE_CHECKING:
import torch
class ParsedStrPrompt(TypedDict):
type: Literal["str"]
content: str
class ParsedTextPrompt(TypedDict):
type: Literal["text"]
content: TextPrompt
class ParsedTokensPrompt(TypedDict):
type: Literal["tokens"]
content: TokensPrompt
class ParsedEmbedsPrompt(TypedDict):
type: Literal["embeds"]
content: EmbedsPrompt
ParsedSingletonPrompt: TypeAlias = (
ParsedStrPrompt | ParsedTextPrompt | ParsedTokensPrompt | ParsedEmbedsPrompt
)
def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
if isinstance(prompt, str):
return ParsedStrPrompt(type="str", content=prompt)
elif isinstance(prompt, dict):
# Type ignores are because mypy does not correctly infer the TypedDicts
# Pyright does succeed.
if "prompt_embeds" in prompt:
return ParsedEmbedsPrompt(type="embeds", content=prompt) # type: ignore[typeddict-item]
elif "prompt_token_ids" in prompt:
return ParsedTokensPrompt(type="tokens", content=prompt) # type: ignore[typeddict-item]
elif "prompt" in prompt:
return ParsedTextPrompt(type="text", content=prompt)
raise TypeError(
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt"
)
def is_explicit_encoder_decoder_prompt(
prompt: PromptType,
) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt
def split_enc_dec_prompt(
prompt: PromptType,
) -> tuple[SingletonPrompt, SingletonPrompt | None]:
if isinstance(prompt, str):
return prompt, None
if "encoder_prompt" in prompt and "decoder_prompt" in prompt:
# NOTE: This passes pyright but not mypy
return (
prompt["encoder_prompt"], # type: ignore[typeddict-item]
prompt["decoder_prompt"], # type: ignore[typeddict-item]
)
return prompt, None
from .data import ProcessorInputs, SingletonInputs
def split_enc_dec_inputs(
@@ -96,30 +15,3 @@ def split_enc_dec_inputs(
)
return None, inputs
class PromptComponents(NamedTuple):
text: str | None = None
token_ids: list[int] | None = None
embeds: "torch.Tensor | None" = None
def get_prompt_components(prompt: PromptType) -> PromptComponents:
if isinstance(prompt, str):
return PromptComponents(text=prompt)
if encoder_prompt := prompt.get("encoder_prompt"):
return get_prompt_components(encoder_prompt) # type: ignore[arg-type]
return PromptComponents(
text=prompt.get("prompt"), # type: ignore[arg-type]
token_ids=prompt.get("prompt_token_ids"), # type: ignore[arg-type]
embeds=prompt.get("prompt_embeds"),
)
def get_prompt_len(prompt: TokensPrompt | EmbedsPrompt):
return length_from_prompt_token_ids_or_embeds(
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
prompt.get("prompt_embeds"), # type: ignore[arg-type]
)

View File

@@ -2,43 +2,51 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from typing import Any
from typing import Any, overload
from typing_extensions import assert_never
from vllm.config import ModelConfig, ObservabilityConfig
from vllm.inputs.parse import split_enc_dec_prompt
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalEncDecInputs,
MultiModalInputs,
MultiModalUUIDDict,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.renderers import renderer_from_config
from vllm.renderers.inputs import (
DecoderDictPrompt,
DecoderOnlyDictPrompt,
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDictPrompt,
SingletonDictPrompt,
TokPrompt,
)
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
from vllm.tokenizers import TokenizerLike
from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats
from .data import (
DecoderInputs,
DecoderOnlyInputs,
EmbedsInputs,
EmbedsPrompt,
EncoderDecoderInputs,
EncoderInputs,
ProcessorInputs,
PromptType,
SingletonInputs,
SingletonPrompt,
TextPrompt,
TokenInputs,
TokensPrompt,
embeds_inputs,
token_inputs,
)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
logger = init_logger(__name__)
@@ -328,9 +336,36 @@ class InputPreprocessor:
return inputs
@overload
def _prompt_to_llm_inputs(
self,
prompt: SingletonPrompt,
prompt: EncoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> EncoderInputs: ...
@overload
def _prompt_to_llm_inputs( # type: ignore[misc]
self,
prompt: DecoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderInputs: ...
@overload
def _prompt_to_llm_inputs( # type: ignore[misc]
self,
prompt: DecoderOnlyDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderOnlyInputs: ...
def _prompt_to_llm_inputs(
self,
prompt: SingletonDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
@@ -346,34 +381,25 @@ class InputPreprocessor:
* [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
"""
parsed = parse_singleton_prompt(prompt)
if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type]
if parsed["type"] == "embeds":
return self._process_embeds(parsed["content"])
if parsed["type"] == "tokens":
if "prompt_token_ids" in prompt:
return self._process_tokens(
parsed["content"],
prompt, # type: ignore[arg-type]
mm_uuids=mm_uuids,
)
if parsed["type"] == "text":
if "prompt" in prompt:
return self._process_text(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
if parsed["type"] == "str":
return self._process_text(
TextPrompt(prompt=parsed["content"]),
prompt, # type: ignore[arg-type]
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
assert_never(parsed)
assert_never(prompt) # type: ignore[arg-type]
def _validate_enc_inputs(
self,
inputs: SingletonInputs,
) -> TokenInputs | MultiModalEncDecInputs:
def _validate_enc_inputs(self, inputs: SingletonInputs) -> EncoderInputs:
if inputs["type"] == "embeds":
raise ValueError(
"Embedding inputs are not supported for encoder-decoder models"
@@ -387,10 +413,7 @@ class InputPreprocessor:
return inputs # type: ignore[return-value]
def _validate_dec_inputs(
self,
inputs: SingletonInputs,
) -> TokenInputs | MultiModalInputs:
def _validate_dec_inputs(self, inputs: SingletonInputs) -> DecoderInputs:
if inputs["type"] == "embeds":
raise ValueError(
"Embedding inputs are not supported for encoder-decoder models"
@@ -403,14 +426,15 @@ class InputPreprocessor:
encoder_inputs: SingletonInputs,
decoder_inputs: SingletonInputs | None = None,
) -> EncoderDecoderInputs:
if decoder_inputs is None:
decoder_inputs = encoder_inputs
enc_inputs = self._validate_enc_inputs(encoder_inputs)
dec_inputs = self._validate_dec_inputs(decoder_inputs)
enc_inputs_new: TokenInputs | MultiModalEncDecInputs
dec_inputs_new: TokenInputs | MultiModalInputs
if decoder_inputs is None:
dec_inputs: DecoderInputs = enc_inputs # type: ignore[assignment]
else:
dec_inputs = self._validate_dec_inputs(decoder_inputs)
enc_inputs_new: EncoderInputs
dec_inputs_new: DecoderInputs
if enc_inputs["type"] == "multimodal":
enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"])
@@ -437,7 +461,7 @@ class InputPreprocessor:
def _process_encoder_decoder_prompt(
self,
prompt: PromptType,
prompt: EncoderDecoderDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
@@ -448,24 +472,6 @@ class InputPreprocessor:
[`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance.
There are two types of input prompts:
singleton prompts which carry only the
encoder prompt, and explicit encoder/decoder
prompts which carry both the encoder and the
decoder prompts as member variables.
This function handles the following scenarios:
* Singleton encoder prompt: extract encoder prompt
token ids & infer default decoder prompt token ids
* Explicit encoder/decoder prompt: extract encoder
and decoder prompt token ids
Note that for Explicit encoder/decoder prompts,
each sub-prompt (encoder or decoder prompt) can
have any possible singleton type; thus this
method relies on helper functions to obtain
token ids for the sub-prompts.
Arguments:
* prompt: an input prompt
@@ -475,7 +481,8 @@ class InputPreprocessor:
* [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance
"""
encoder_prompt, decoder_prompt = split_enc_dec_prompt(prompt)
encoder_prompt = prompt["encoder_prompt"]
decoder_prompt = prompt["decoder_prompt"]
return self._build_enc_dec_inputs(
encoder_inputs=self._prompt_to_llm_inputs(
@@ -495,7 +502,7 @@ class InputPreprocessor:
def _process_decoder_only_prompt(
self,
prompt: SingletonPrompt,
prompt: DecoderOnlyDictPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
@@ -521,7 +528,7 @@ class InputPreprocessor:
def _preprocess(
self,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
@@ -530,25 +537,20 @@ class InputPreprocessor:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder.
return self._process_encoder_decoder_prompt(
prompt,
parse_enc_dec_prompt(prompt),
tokenization_kwargs,
mm_uuids=mm_uuids,
)
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError(
"Cannot pass encoder-decoder prompt to decoder-only models"
)
return self._process_decoder_only_prompt(
prompt,
parse_dec_only_prompt(prompt),
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
def preprocess(
self,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
tokenization_kwargs: dict[str, Any] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,

View File

@@ -20,7 +20,7 @@ from typing import (
import numpy as np
from PIL.Image import Image
from typing_extensions import NotRequired, TypeVar
from typing_extensions import TypeVar
from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader
@@ -32,9 +32,13 @@ if TYPE_CHECKING:
import torch
import torch.types
from transformers.feature_extraction_utils import BatchFeature
from vllm.inputs.data import _InputOptions
else:
torch = LazyLoader("torch", globals(), "torch")
_InputOptions = dict
_T = TypeVar("_T")
HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
@@ -1059,7 +1063,7 @@ A dictionary containing per-item placeholder ranges for each modality.
"""
class MultiModalInputs(TypedDict):
class MultiModalInputs(_InputOptions):
"""
Represents the outputs of
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor],
@@ -1084,11 +1088,6 @@ class MultiModalInputs(TypedDict):
`prompt_token_ids`.
"""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
class MultiModalEncDecInputs(MultiModalInputs):
"""

View File

@@ -19,6 +19,7 @@ if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.attention.selector import AttentionSelectorConfig
@@ -565,7 +566,7 @@ class Platform:
@classmethod
def validate_request(
cls,
prompt: "PromptType",
prompt: "PromptType | DictPrompt | TokPrompt",
params: "SamplingParams | PoolingParams",
processed_inputs: "ProcessorInputs",
) -> None:

View File

@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
from .protocol import BaseRenderer
@@ -61,7 +62,7 @@ class DeepseekV32Renderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
@@ -75,7 +76,7 @@ class DeepseekV32Renderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(),
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
@@ -87,7 +88,7 @@ class DeepseekV32Renderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
@@ -101,7 +102,7 @@ class DeepseekV32Renderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(),
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:

View File

@@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
from .protocol import BaseRenderer
@@ -61,7 +62,7 @@ class Grok2Renderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
@@ -75,7 +76,7 @@ class Grok2Renderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(),
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
@@ -87,7 +88,7 @@ class Grok2Renderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
@@ -101,7 +102,7 @@ class Grok2Renderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(),
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:

View File

@@ -25,7 +25,6 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
@@ -33,6 +32,8 @@ from vllm.transformers_utils.chat_templates import get_chat_template_fallback_pa
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.func_utils import supports_kw
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
from .protocol import BaseRenderer
@@ -632,7 +633,7 @@ class HfRenderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.config
tokenizer = self.get_tokenizer()
@@ -674,7 +675,7 @@ class HfRenderer(BaseRenderer):
video_placeholder,
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
@@ -686,7 +687,7 @@ class HfRenderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.config
tokenizer = self.get_tokenizer()
@@ -726,7 +727,7 @@ class HfRenderer(BaseRenderer):
video_placeholder,
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:

View File

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .preprocess import (
DecoderDictPrompt,
DecoderOnlyDictPrompt,
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDictPrompt,
SingletonDictPrompt,
)
from .tokenize import (
DecoderOnlyTokPrompt,
DecoderTokPrompt,
EncoderDecoderTokPrompt,
EncoderTokPrompt,
SingletonTokPrompt,
TokPrompt,
)
__all__ = [
"DecoderOnlyDictPrompt",
"EncoderDictPrompt",
"DecoderDictPrompt",
"EncoderDecoderDictPrompt",
"SingletonDictPrompt",
"DictPrompt",
"DecoderOnlyTokPrompt",
"EncoderTokPrompt",
"DecoderTokPrompt",
"EncoderDecoderTokPrompt",
"SingletonTokPrompt",
"TokPrompt",
]

View File

@@ -0,0 +1,255 @@
"""
Schemas and utilites for preprocessing inputs.
"""
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload
from vllm.inputs import (
EmbedsPrompt,
ExplicitEncoderDecoderPrompt,
PromptType,
SingletonPrompt,
TextPrompt,
TokensPrompt,
)
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collection_utils import is_list_of
if TYPE_CHECKING:
import torch
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@overload
def prompt_to_seq(
prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
) -> Sequence[SingletonPrompt]: ...
@overload
def prompt_to_seq( # type: ignore[misc]
prompt_or_prompts: ExplicitEncoderDecoderPrompt
| Sequence[ExplicitEncoderDecoderPrompt],
) -> Sequence[ExplicitEncoderDecoderPrompt]: ...
@overload
def prompt_to_seq( # type: ignore[misc]
prompt_or_prompts: PromptType | Sequence[PromptType],
) -> Sequence[PromptType]: ...
def prompt_to_seq(
prompt_or_prompts: PromptType | bytes | Sequence[PromptType | bytes],
) -> Sequence[PromptType]:
if isinstance(prompt_or_prompts, (dict, str, bytes)) or (
len(prompt_or_prompts) > 0 and is_list_of(prompt_or_prompts, int)
):
return [prompt_or_prompts] # type: ignore[list-item]
return prompt_or_prompts # type: ignore[return-value]
def conversation_to_seq(
conversation_or_conversations: list["ChatCompletionMessageParam"]
| Sequence[list["ChatCompletionMessageParam"]],
) -> Sequence[list["ChatCompletionMessageParam"]]:
if len(conversation_or_conversations) > 0 and is_list_of(
conversation_or_conversations, dict
):
return [conversation_or_conversations] # type: ignore[list-item]
return conversation_or_conversations # type: ignore[return-value]
DecoderOnlyDictPrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
"""
A [`DecoderOnlyPrompt`][vllm.inputs.data.DecoderOnlyPrompt]
that has been standardized into a dictionary.
"""
EncoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
"""
A [`EncoderPrompt`][vllm.inputs.data.EncoderPrompt]
that has been standardized into a dictionary.
"""
DecoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt
"""
A [`DecoderPrompt`][vllm.inputs.data.DecoderPrompt]
that has been standardized into a dictionary.
"""
class EncoderDecoderDictPrompt(TypedDict):
"""
A [`EncoderDecoderPrompt`][vllm.inputs.data.EncoderDecoderPrompt]
that has been standardized into a dictionary.
"""
encoder_prompt: EncoderDictPrompt
decoder_prompt: DecoderDictPrompt | None
SingletonDictPrompt: TypeAlias = (
DecoderOnlyDictPrompt | EncoderDictPrompt | DecoderDictPrompt
)
"""
A [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt]
that has been standardized into a dictionary.
"""
DictPrompt: TypeAlias = DecoderOnlyDictPrompt | EncoderDecoderDictPrompt
"""
A [`PromptType`][vllm.inputs.data.PromptType]
that has been standardized into a dictionary.
"""
def parse_dec_only_prompt(prompt: object) -> DecoderOnlyDictPrompt:
"""
Parse a prompt for a decoder-only model and normalize it to a dictionary.
"""
if isinstance(prompt, str):
return TextPrompt(prompt=prompt)
if isinstance(prompt, list):
if not is_list_of(prompt, int):
raise TypeError("Token prompt should be a list of integers")
return TokensPrompt(prompt_token_ids=prompt)
if isinstance(prompt, dict):
if "encoder_prompt" in prompt:
raise TypeError("Cannot pass encoder-decoder prompt to decoder-only models")
if (
"prompt" in prompt
or "prompt_token_ids" in prompt
or "prompt_embeds" in prompt
):
return prompt # type: ignore[return-value]
raise TypeError("Prompt dictionary must contain text, tokens, or embeddings")
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def _parse_enc_prompt(prompt: object) -> EncoderDictPrompt:
if isinstance(prompt, str):
return TextPrompt(prompt=prompt)
if isinstance(prompt, list):
if not is_list_of(prompt, int):
raise TypeError("Token prompt should be a list of integers")
return TokensPrompt(prompt_token_ids=prompt)
if isinstance(prompt, dict):
if "prompt_embeds" in prompt:
raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")
if "prompt" in prompt or "prompt_token_ids" in prompt:
return prompt # type: ignore[return-value]
raise TypeError("Prompt dictionary must contain text or tokens")
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def _parse_dec_prompt(prompt: object) -> DecoderDictPrompt:
if isinstance(prompt, str):
return TextPrompt(prompt=prompt)
if isinstance(prompt, list):
if not is_list_of(prompt, int):
raise TypeError("Token prompt should be a list of integers")
return TokensPrompt(prompt_token_ids=prompt)
if isinstance(prompt, dict):
if "prompt_embeds" in prompt:
raise TypeError("Cannot pass embeddings prompt to encoder-decoder models")
if (
"multi_modal_data" in prompt
or "mm_processor_kwargs" in prompt
or "multi_modal_uuids" in prompt
):
raise TypeError("Cannot pass multi-modal inputs to decoder prompt")
if "prompt" in prompt or "prompt_token_ids" in prompt:
return prompt # type: ignore[return-value]
raise TypeError("Prompt dictionary must contain text or tokens")
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
def parse_enc_dec_prompt(prompt: object) -> EncoderDecoderDictPrompt:
"""
Parse a prompt for an encoder-decoder model and normalize it to a dictionary.
"""
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
enc_prompt: object = prompt["encoder_prompt"] # type: ignore[typeddict-item]
dec_prompt: object | None = prompt["decoder_prompt"] # type: ignore[typeddict-item]
else:
enc_prompt = prompt
dec_prompt = None
return EncoderDecoderDictPrompt(
encoder_prompt=_parse_enc_prompt(enc_prompt),
decoder_prompt=None if dec_prompt is None else _parse_dec_prompt(dec_prompt),
)
def parse_model_prompt(model_config: "ModelConfig", prompt: object):
if model_config.is_encoder_decoder:
return parse_enc_dec_prompt(prompt)
return parse_dec_only_prompt(prompt)
class PromptComponents(NamedTuple):
text: str | None = None
token_ids: list[int] | None = None
embeds: "torch.Tensor | None" = None
def extract_prompt_components(
model_config: "ModelConfig",
prompt: object,
) -> PromptComponents:
target_prompt = (
parse_enc_dec_prompt(prompt)["encoder_prompt"]
if model_config.is_encoder_decoder
else parse_dec_only_prompt(prompt)
)
return PromptComponents(
text=target_prompt.get("prompt"),
token_ids=target_prompt.get("prompt_token_ids"), # type: ignore[arg-type]
embeds=target_prompt.get("prompt_embeds"),
)
def extract_prompt_len(model_config: "ModelConfig", prompt: object):
target_prompt = (
parse_enc_dec_prompt(prompt)["encoder_prompt"]
if model_config.is_encoder_decoder
else parse_dec_only_prompt(prompt)
)
return length_from_prompt_token_ids_or_embeds(
target_prompt.get("prompt_token_ids"), # type: ignore[arg-type]
target_prompt.get("prompt_embeds"),
)

View File

@@ -0,0 +1,57 @@
"""
Schemas and utilites for tokenization inputs.
"""
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TypeAlias, TypedDict
from vllm.inputs import EmbedsPrompt, TokensPrompt
DecoderOnlyTokPrompt: TypeAlias = TokensPrompt | EmbedsPrompt
"""
A [`DecoderOnlyDictPrompt`][vllm.renderers.inputs.preprocess.DecoderOnlyDictPrompt]
that has been tokenized.
"""
EncoderTokPrompt: TypeAlias = TokensPrompt
"""
A [`EncoderDictPrompt`][vllm.renderers.inputs.preprocess.EncoderDictPrompt]
that has been tokenized.
"""
DecoderTokPrompt: TypeAlias = TokensPrompt
"""
A [`DecoderDictPrompt`][vllm.renderers.inputs.preprocess.DecoderDictPrompt]
that has been tokenized.
"""
class EncoderDecoderTokPrompt(TypedDict):
"""
A
[`EncoderDecoderDictPrompt`][vllm.renderers.inputs.preprocess.EncoderDecoderDictPrompt]
that has been tokenized.
"""
encoder_prompt: EncoderTokPrompt
decoder_prompt: DecoderTokPrompt | None
SingletonTokPrompt: TypeAlias = (
DecoderOnlyTokPrompt | EncoderTokPrompt | DecoderTokPrompt
)
"""
A [`SingletonDictPrompt`][vllm.renderers.inputs.preprocess.SingletonDictPrompt]
that has been tokenized.
"""
TokPrompt: TypeAlias = DecoderOnlyTokPrompt | EncoderDecoderTokPrompt
"""
A [`DictPrompt`][vllm.renderers.inputs.preprocess.DictPrompt]
that has been tokenized.
"""

View File

@@ -10,12 +10,13 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
from .protocol import BaseRenderer
@@ -95,7 +96,7 @@ class MistralRenderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
@@ -109,7 +110,7 @@ class MistralRenderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(),
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
@@ -121,7 +122,7 @@ class MistralRenderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
@@ -135,7 +136,7 @@ class MistralRenderer(BaseRenderer):
**params.get_apply_chat_template_kwargs(),
)
prompt = self.render_completion(prompt_raw)
prompt = parse_dec_only_prompt(prompt_raw)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:

View File

@@ -2,14 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, overload
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.collection_utils import is_list_of
from .embed_utils import safe_load_prompt_embeds
from .inputs import (
DictPrompt,
EncoderDecoderDictPrompt,
EncoderDecoderTokPrompt,
TokPrompt,
)
from .params import ChatParams, TokenizeParams
if TYPE_CHECKING:
@@ -57,140 +63,217 @@ class BaseRenderer(ABC):
return self._async_tokenizer
# Step 1: Convert raw inputs to prompts
def render_completion(
def render_prompt(
self,
prompt_raw: str | list[int] | bytes,
) -> TextPrompt | TokensPrompt | EmbedsPrompt:
error_msg = "Each prompt must be a string or an array of tokens"
prompt: DictPrompt | bytes,
) -> DictPrompt:
if isinstance(prompt, bytes):
embeds = safe_load_prompt_embeds(self.config, prompt)
prompt = EmbedsPrompt(prompt_embeds=embeds)
if isinstance(prompt_raw, str):
return TextPrompt(prompt=prompt_raw)
return prompt
if isinstance(prompt_raw, list):
if not is_list_of(prompt_raw, int):
raise TypeError(error_msg)
return TokensPrompt(prompt_token_ids=prompt_raw)
if isinstance(prompt_raw, bytes):
embeds = safe_load_prompt_embeds(self.config, prompt_raw)
return EmbedsPrompt(prompt_embeds=embeds)
raise TypeError(error_msg)
def render_completions(
def render_prompts(
self,
prompt_input: str | list[str] | list[int] | list[list[int]] | None = None,
prompt_embeds: bytes | list[bytes] | None = None,
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
prompts_raw = list[str | list[int] | bytes]()
if prompt_embeds is not None: # embeds take higher priority
if isinstance(prompt_embeds, bytes):
prompts_raw.append(prompt_embeds)
else:
prompts_raw.extend(prompt_embeds)
if prompt_input is not None:
if isinstance(prompt_input, str) or (
len(prompt_input) > 0 and is_list_of(prompt_input, int)
):
prompts_raw.append(prompt_input) # type: ignore[arg-type]
else:
prompts_raw.extend(prompt_input) # type: ignore[arg-type]
if len(prompts_raw) == 0:
prompts: Sequence[DictPrompt | bytes],
) -> list[DictPrompt]:
if len(prompts) == 0:
raise ValueError("You must pass at least one prompt")
return [self.render_completion(prompt) for prompt in prompts_raw]
return [self.render_prompt(prompt) for prompt in prompts]
async def render_completions_async(
async def render_prompts_async(
self,
prompt_input: str | list[str] | list[int] | list[list[int]] | None = None,
prompt_embeds: bytes | list[bytes] | None = None,
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
return self.render_completions(prompt_input, prompt_embeds)
prompts: Sequence[DictPrompt | bytes],
) -> list[DictPrompt]:
return self.render_prompts(prompts)
@abstractmethod
def render_messages(
self,
messages: list["ChatCompletionMessageParam"],
params: ChatParams,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list["ConversationMessage"], DictPrompt]:
raise NotImplementedError
async def render_messages_async(
self,
messages: list["ChatCompletionMessageParam"],
params: ChatParams,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list["ConversationMessage"], DictPrompt]:
return self.render_messages(messages, params)
# Step 2: Tokenize prompts if necessary
def _tokenize_prompt(
self,
prompt: TextPrompt,
params: TokenizeParams,
) -> TokensPrompt:
tokenizer = self.get_tokenizer()
prompt_token_ids = tokenizer.encode(
prompt["prompt"],
**params.get_encode_kwargs(),
)
return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
async def _tokenize_prompt_async(
self,
prompt: TextPrompt,
params: TokenizeParams,
) -> TokensPrompt:
tokenizer = self.get_async_tokenizer()
prompt_token_ids = await tokenizer.encode(
prompt["prompt"],
**params.get_encode_kwargs(),
)
return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
def _detokenize_prompt(self, prompt: TokensPrompt) -> TokensPrompt:
tokenizer = self.get_tokenizer()
prompt["prompt"] = tokenizer.decode(prompt["prompt_token_ids"])
return prompt
async def _detokenize_prompt_async(self, prompt: TokensPrompt) -> TokensPrompt:
tokenizer = self.get_async_tokenizer()
prompt["prompt"] = await tokenizer.decode(prompt["prompt_token_ids"])
return prompt
def _tokenize_enc_dec_prompt(
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
enc_prompt, dec_prompt = (
self.tokenize_prompt(prompt["encoder_prompt"], params),
(
None
if prompt["decoder_prompt"] is None
else self.tokenize_prompt(prompt["decoder_prompt"], params)
),
)
return EncoderDecoderTokPrompt(
encoder_prompt=enc_prompt,
decoder_prompt=dec_prompt,
)
async def _tokenize_enc_dec_prompt_async(
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt:
enc_prompt, dec_prompt = await asyncio.gather(
self.tokenize_prompt_async(prompt["encoder_prompt"], params),
(
asyncio.sleep(0)
if prompt["decoder_prompt"] is None
else self.tokenize_prompt_async(prompt["decoder_prompt"], params)
),
)
return EncoderDecoderTokPrompt(
encoder_prompt=enc_prompt,
decoder_prompt=dec_prompt,
)
@overload
def tokenize_prompt(
self,
prompt: TextPrompt | TokensPrompt | EmbedsPrompt,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt | EmbedsPrompt:
) -> TokensPrompt: ...
@overload
def tokenize_prompt( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
@overload
def tokenize_prompt( # type: ignore[misc]
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt: ...
def tokenize_prompt(
self,
prompt: DictPrompt,
params: TokenizeParams,
) -> TokPrompt:
if "encoder_prompt" in prompt:
return self._tokenize_enc_dec_prompt(prompt, params) # type: ignore[arg-type]
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
tokenizer = self.get_tokenizer()
prompt_token_ids = tokenizer.encode(
prompt["prompt"],
**params.get_encode_kwargs(),
)
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
prompt = self._tokenize_prompt(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
tokenizer = self.get_tokenizer()
prompt_text = tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key]
prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
def tokenize_prompts(
self,
prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt],
prompts: Sequence[DictPrompt],
params: TokenizeParams,
) -> list[TokensPrompt | EmbedsPrompt]:
) -> list[TokPrompt]:
return [self.tokenize_prompt(prompt, params) for prompt in prompts]
@overload
async def tokenize_prompt_async(
self,
prompt: TextPrompt | TokensPrompt,
params: TokenizeParams,
) -> TokensPrompt: ...
@overload
async def tokenize_prompt_async( # type: ignore[misc]
self,
prompt: EmbedsPrompt,
params: TokenizeParams,
) -> EmbedsPrompt: ...
@overload
async def tokenize_prompt_async( # type: ignore[misc]
self,
prompt: EncoderDecoderDictPrompt,
params: TokenizeParams,
) -> EncoderDecoderTokPrompt: ...
async def tokenize_prompt_async(
self,
prompt: TextPrompt | TokensPrompt | EmbedsPrompt,
prompt: DictPrompt,
params: TokenizeParams,
) -> TokensPrompt | EmbedsPrompt:
) -> TokPrompt:
if "encoder_prompt" in prompt:
return await self._tokenize_enc_dec_prompt_async(prompt, params) # type: ignore[arg-type]
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
tokenizer = self.get_async_tokenizer()
prompt_token_ids = await tokenizer.encode(
prompt["prompt"],
**params.get_encode_kwargs(),
)
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt)
prompt = await self._tokenize_prompt_async(prompt, params)
if params.needs_detokenization and "prompt" not in prompt:
if "prompt_token_ids" not in prompt:
raise RuntimeError("Cannot run detokenization on embeddings")
tokenizer = self.get_async_tokenizer()
prompt_text = await tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key]
prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
async def tokenize_prompts_async(
self,
prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt],
prompts: Sequence[DictPrompt],
params: TokenizeParams,
) -> list[TokensPrompt | EmbedsPrompt]:
) -> list[TokPrompt]:
return await asyncio.gather(
*(self.tokenize_prompt_async(prompt, params) for prompt in prompts)
)

View File

@@ -9,10 +9,11 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
from .protocol import BaseRenderer
@@ -45,7 +46,7 @@ class TerratorchRenderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.config
conversation, mm_data, mm_uuids = parse_chat_messages(
@@ -54,7 +55,7 @@ class TerratorchRenderer(BaseRenderer):
content_format="string",
)
prompt = self.render_completion([1]) # Dummy token IDs
prompt = parse_dec_only_prompt([1]) # Dummy token IDs
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
@@ -66,7 +67,7 @@ class TerratorchRenderer(BaseRenderer):
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]:
) -> tuple[list[ConversationMessage], DictPrompt]:
model_config = self.config
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
@@ -75,7 +76,7 @@ class TerratorchRenderer(BaseRenderer):
content_format="string",
)
prompt = self.render_completion([1]) # Dummy token IDs
prompt = parse_dec_only_prompt([1]) # Dummy token IDs
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:

View File

@@ -28,6 +28,8 @@ from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
@@ -42,7 +44,6 @@ from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.utils import get_prompt_text
from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import (
StatLoggerFactory,
@@ -284,7 +285,11 @@ class AsyncLLM(EngineClient):
async def add_request(
self,
request_id: str,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| AsyncGenerator[StreamingInput, None],
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
lora_request: LoRARequest | None = None,
@@ -367,7 +372,7 @@ class AsyncLLM(EngineClient):
data_parallel_rank=data_parallel_rank,
supported_tasks=await self.get_supported_tasks(),
)
prompt_text = get_prompt_text(prompt)
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
self.input_processor.assign_request_id(request)
@@ -484,7 +489,9 @@ class AsyncLLM(EngineClient):
raise ValueError(
"prompt_embeds not supported for streaming inputs"
)
prompt_text = get_prompt_text(input_chunk.prompt)
prompt_text, _, _ = extract_prompt_components(
self.model_config, input_chunk.prompt
)
await self._add_request(req, prompt_text, None, 0, queue)
except (asyncio.CancelledError, GeneratorExit):
cancelled = True
@@ -528,7 +535,11 @@ class AsyncLLM(EngineClient):
# re-multiplexed in the API server anyhow.
async def generate(
self,
prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None],
prompt: EngineCoreRequest
| PromptType
| DictPrompt
| TokPrompt
| AsyncGenerator[StreamingInput, None],
sampling_params: SamplingParams,
request_id: str,
*,
@@ -769,7 +780,7 @@ class AsyncLLM(EngineClient):
async def encode(
self,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
pooling_params: PoolingParams,
request_id: str,
lora_request: LoRARequest | None = None,

View File

@@ -7,14 +7,13 @@ from typing import Any, Literal, cast
from vllm.config import VllmConfig
from vllm.exceptions import VLLMValidationError
from vllm.inputs import (
from vllm.inputs.data import (
ProcessorInputs,
PromptType,
SingletonInputs,
SingletonPrompt,
TextPrompt,
)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt, split_enc_dec_inputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -30,6 +29,7 @@ from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike
@@ -243,8 +243,8 @@ class InputProcessor:
return mm_processor.info.parse_mm_data(mm_data)
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
if isinstance(prompt, str):
prompt = TextPrompt(prompt=prompt)
if not isinstance(prompt, dict):
return
mm_data = cast(MultiModalDataDict, prompt.get("multi_modal_data") or {})
mm_uuids = cast(MultiModalUUIDDict, prompt.get("multi_modal_uuids") or {})
@@ -297,7 +297,7 @@ class InputProcessor:
f"multi_modal_uuids[{modality!r}] is missing."
)
def _validate_mm_uuids(self, prompt: PromptType) -> None:
def _validate_mm_uuids(self, prompt: PromptType | DictPrompt | TokPrompt) -> None:
"""
Validate that user-provided multi_modal_uuids align with
multi_modal_data in the incoming request prompt(s).
@@ -305,10 +305,10 @@ class InputProcessor:
auto-hashed downstream.
"""
if is_explicit_encoder_decoder_prompt(prompt):
self._validate_singleton_mm_uuids(prompt["encoder_prompt"])
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
self._validate_singleton_mm_uuids(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
if (dec_prompt := prompt["decoder_prompt"]) is not None:
if (dec_prompt := prompt["decoder_prompt"]) is not None: # type: ignore[typeddict-item]
self._validate_singleton_mm_uuids(dec_prompt)
else:
self._validate_singleton_mm_uuids(prompt)
@@ -449,21 +449,23 @@ class InputProcessor:
def _extract_singleton_mm_data(
self, prompt: SingletonPrompt
) -> MultiModalDataDict | None:
if isinstance(prompt, str):
if not isinstance(prompt, dict):
return None
return prompt.get("multi_modal_data") # type: ignore[return-value]
return prompt.get("multi_modal_data")
def _extract_mm_data(self, prompt: PromptType) -> MultiModalDataDict | None:
if is_explicit_encoder_decoder_prompt(prompt):
return self._extract_singleton_mm_data(prompt["encoder_prompt"])
def _extract_mm_data(
self, prompt: PromptType | DictPrompt | TokPrompt
) -> MultiModalDataDict | None:
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
return self._extract_singleton_mm_data(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
else:
return self._extract_singleton_mm_data(prompt)
def _maybe_build_mm_uuids(
self,
request_id: str,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
) -> MultiModalUUIDDict | None:
"""Build per-item multimodal hash overrides when enabled. In this case,
multimodal data items are identified by their request id, modality and
@@ -519,7 +521,7 @@ class InputProcessor:
def process_inputs(
self,
request_id: str,
prompt: PromptType,
prompt: PromptType | DictPrompt | TokPrompt,
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
lora_request: LoRARequest | None = None,

View File

@@ -22,6 +22,8 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import BaseRenderer
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
@@ -32,7 +34,6 @@ from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.input_processor import InputProcessor
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.utils import get_prompt_text
from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
@@ -216,7 +217,7 @@ class LLMEngine:
def add_request(
self,
request_id: str,
prompt: EngineCoreRequest | PromptType,
prompt: EngineCoreRequest | PromptType | DictPrompt | TokPrompt,
params: SamplingParams | PoolingParams,
arrival_time: float | None = None,
lora_request: LoRARequest | None = None,
@@ -251,7 +252,7 @@ class LLMEngine:
priority,
supported_tasks=self.get_supported_tasks(),
)
prompt_text = get_prompt_text(prompt)
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
self.input_processor.assign_request_id(request)

View File

@@ -17,8 +17,6 @@ import zmq
from vllm import envs
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.inputs import PromptType
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy
@@ -226,10 +224,6 @@ def get_device_indices(
return value
def get_prompt_text(prompt: PromptType) -> str | None:
return get_prompt_components(prompt)[0]
class CoreEngineActorManager:
"""
Utility class to handle creation, readiness, and shutdown