diff --git a/tests/entrypoints/openai/test_chat_error.py b/tests/entrypoints/openai/test_chat_error.py index de5d96d5b..8a2894154 100644 --- a/tests/entrypoints/openai/test_chat_error.py +++ b/tests/entrypoints/openai/test_chat_error.py @@ -54,6 +54,7 @@ class MockModelConfig: generation_config: str = "auto" media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) skip_tokenizer_init = False + is_encoder_decoder: bool = False def get_diff_sampling_param(self): return self.diff_sampling_param or {} diff --git a/tests/entrypoints/openai/test_completion_error.py b/tests/entrypoints/openai/test_completion_error.py index b60397cd7..bbf97534f 100644 --- a/tests/entrypoints/openai/test_completion_error.py +++ b/tests/entrypoints/openai/test_completion_error.py @@ -53,6 +53,7 @@ class MockModelConfig: generation_config: str = "auto" media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) skip_tokenizer_init = False + is_encoder_decoder: bool = False def get_diff_sampling_param(self): return self.diff_sampling_param or {} diff --git a/tests/entrypoints/openai/test_lora_resolvers.py b/tests/entrypoints/openai/test_lora_resolvers.py index b43577406..db7fbe2f8 100644 --- a/tests/entrypoints/openai/test_lora_resolvers.py +++ b/tests/entrypoints/openai/test_lora_resolvers.py @@ -52,6 +52,7 @@ class MockModelConfig: encoder_config = None generation_config: str = "auto" skip_tokenizer_init: bool = False + is_encoder_decoder: bool = False def get_diff_sampling_param(self): return self.diff_sampling_param or {} diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 4365075f6..ef9d944ab 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -529,6 +529,7 @@ class MockModelConfig: generation_config: str = "auto" media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) skip_tokenizer_init: bool = False + is_encoder_decoder: bool = False def get_diff_sampling_param(self): return self.diff_sampling_param or {} diff --git a/tests/renderers/inputs/__init__.py b/tests/renderers/inputs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/renderers/inputs/test_preprocess.py b/tests/renderers/inputs/test_preprocess.py new file mode 100644 index 000000000..707f9eedf --- /dev/null +++ b/tests/renderers/inputs/test_preprocess.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.renderers.inputs.preprocess import prompt_to_seq + + +def test_empty_input(): + assert prompt_to_seq([]) == [] + assert prompt_to_seq([[]]) == [[]] + assert prompt_to_seq([[], []]) == [[], []] + + +def test_text_input(): + assert prompt_to_seq("foo") == ["foo"] + assert prompt_to_seq(["foo"]) == ["foo"] + assert prompt_to_seq(["foo", "bar"]) == ["foo", "bar"] + + +def test_token_input(): + assert prompt_to_seq([1, 2]) == [[1, 2]] + assert prompt_to_seq([[1, 2]]) == [[1, 2]] + assert prompt_to_seq([[1, 2], [3, 4]]) == [[1, 2], [3, 4]] + + +def test_text_token_input(): + assert prompt_to_seq([[1, 2], "foo"]) == [[1, 2], "foo"] + assert prompt_to_seq(["foo", [1, 2]]) == ["foo", [1, 2]] + + +def test_bytes_input(): + assert prompt_to_seq(b"foo") == [b"foo"] + assert prompt_to_seq([b"foo"]) == [b"foo"] + assert prompt_to_seq([b"foo", b"bar"]) == [b"foo", b"bar"] + + +def test_dict_input(): + assert prompt_to_seq({"prompt": "foo"}) == [{"prompt": "foo"}] + assert prompt_to_seq([{"prompt": "foo"}]) == [{"prompt": "foo"}] + assert prompt_to_seq([{"prompt": "foo"}, {"prompt_token_ids": [1, 2]}]) == [ + {"prompt": "foo"}, + {"prompt_token_ids": [1, 2]}, + ] diff --git a/tests/renderers/test_completions.py b/tests/renderers/test_completions.py index 84b7230d9..1cef8551c 100644 --- a/tests/renderers/test_completions.py +++ b/tests/renderers/test_completions.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import io +from collections.abc import Sequence from dataclasses import dataclass from typing import Any @@ -9,8 +10,11 @@ import pybase64 import pytest import torch +from vllm.config import ModelConfig +from vllm.inputs import SingletonPrompt from vllm.renderers import TokenizeParams from vllm.renderers.hf import HfRenderer +from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq from vllm.tokenizers.registry import tokenizer_args_from_config MODEL_NAME = "openai-community/gpt2" @@ -33,6 +37,7 @@ class MockModelConfig: encoder_config: dict[str, Any] | None = None enable_prompt_embeds: bool = True skip_tokenizer_init: bool = False + is_encoder_decoder: bool = False @dataclass @@ -80,65 +85,34 @@ def _build_renderer( return renderer +def _preprocess_prompt( + mdoel_config: ModelConfig, + prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes], +): + return [ + ( + prompt + if isinstance(prompt, bytes) + else parse_model_prompt(mdoel_config, prompt) + ) + for prompt in prompt_to_seq(prompt_or_prompts) + ] + + class TestValidatePrompt: - STRING_INPUTS = [ - "", - "foo", - "foo bar", - "foo baz bar", - "foo bar qux baz", - ] - - TOKEN_INPUTS = [ - [-1], - [1], - [1, 2], - [1, 3, 4], - [1, 2, 4, 3], - ] - - INPUTS_SLICES = [ - slice(None, None, -1), - slice(None, None, 2), - slice(None, None, -2), - ] - - # Test that a nested mixed-type list of lists raises a TypeError. def test_empty_input(self): renderer = _build_renderer(MockModelConfig()) with pytest.raises(ValueError, match="at least one prompt"): - renderer.render_completions([]) + renderer.render_prompts(_preprocess_prompt(renderer.config, [])) def test_invalid_type(self): renderer = _build_renderer(MockModelConfig()) - with pytest.raises(TypeError, match="string or an array of tokens"): - renderer.render_completions([[1, 2], ["foo", "bar"]]) - - @pytest.mark.parametrize("string_input", STRING_INPUTS) - def test_string_consistent(self, string_input: str): - renderer = _build_renderer(MockModelConfig()) - - assert renderer.render_completions(string_input) == renderer.render_completions( - [string_input] - ) - - @pytest.mark.parametrize("token_input", TOKEN_INPUTS) - def test_token_consistent(self, token_input: list[int]): - renderer = _build_renderer(MockModelConfig()) - - assert renderer.render_completions(token_input) == renderer.render_completions( - [token_input] - ) - - @pytest.mark.parametrize("inputs_slice", INPUTS_SLICES) - def test_string_slice(self, inputs_slice: slice): - renderer = _build_renderer(MockModelConfig()) - - assert renderer.render_completions(self.STRING_INPUTS)[ - inputs_slice - ] == renderer.render_completions(self.STRING_INPUTS[inputs_slice]) + with pytest.raises(TypeError, match="should be a list of integers"): + renderer.render_prompts( + _preprocess_prompt(renderer.config, [[1, 2], ["foo", "bar"]]) # type: ignore[arg-type] + ) class TestRenderPrompt: @@ -146,7 +120,7 @@ class TestRenderPrompt: renderer = _build_renderer(MockModelConfig()) tokens = [101, 7592, 2088] - prompts = renderer.render_completions(tokens) + prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens)) results = renderer.tokenize_prompts( prompts, TokenizeParams(max_total_tokens=100), @@ -159,7 +133,9 @@ class TestRenderPrompt: renderer = _build_renderer(MockModelConfig()) token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]] - prompts = renderer.render_completions(token_lists) + prompts = renderer.render_prompts( + _preprocess_prompt(renderer.config, token_lists) + ) results = renderer.tokenize_prompts( prompts, TokenizeParams(max_total_tokens=100), @@ -174,7 +150,9 @@ class TestRenderPrompt: renderer = _build_renderer(MockModelConfig()) text_input = "x" * 10 - prompts = renderer.render_completions(text_input) + prompts = renderer.render_prompts( + _preprocess_prompt(renderer.config, text_input) + ) results = renderer.tokenize_prompts( prompts, TokenizeParams(max_total_tokens=100), @@ -187,7 +165,9 @@ class TestRenderPrompt: renderer = _build_renderer(MockModelConfig()) text_list_input = ["x" * 10, "x" * 12, "x" * 14] - prompts = renderer.render_completions(text_list_input) + prompts = renderer.render_prompts( + _preprocess_prompt(renderer.config, text_list_input) + ) results = renderer.tokenize_prompts( prompts, TokenizeParams(max_total_tokens=100), @@ -200,7 +180,9 @@ class TestRenderPrompt: def test_zero_truncation(self): renderer = _build_renderer(MockModelConfig()) - prompts = renderer.render_completions("x" * 200) + prompts = renderer.render_prompts( + _preprocess_prompt(renderer.config, "x" * 200) + ) results = renderer.tokenize_prompts( prompts, TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=0), @@ -212,7 +194,9 @@ class TestRenderPrompt: def test_pos_truncation(self): renderer = _build_renderer(MockModelConfig()) - prompts = renderer.render_completions("x" * 200) + prompts = renderer.render_prompts( + _preprocess_prompt(renderer.config, "x" * 200) + ) results = renderer.tokenize_prompts( prompts, TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=50), @@ -224,7 +208,9 @@ class TestRenderPrompt: def test_neg_truncation(self): renderer = _build_renderer(MockModelConfig()) - prompts = renderer.render_completions("x" * 200) + prompts = renderer.render_prompts( + _preprocess_prompt(renderer.config, "x" * 200) + ) results = renderer.tokenize_prompts( prompts, TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=-1), @@ -237,7 +223,9 @@ class TestRenderPrompt: renderer = _build_renderer(MockModelConfig(), truncation_side="left") long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens - prompts = renderer.render_completions(long_tokens) + prompts = renderer.render_prompts( + _preprocess_prompt(renderer.config, long_tokens) + ) results = renderer.tokenize_prompts( prompts, TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5), @@ -251,7 +239,9 @@ class TestRenderPrompt: renderer = _build_renderer(MockModelConfig(), truncation_side="right") long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] # 10 tokens - prompts = renderer.render_completions(long_tokens) + prompts = renderer.render_prompts( + _preprocess_prompt(renderer.config, long_tokens) + ) results = renderer.tokenize_prompts( prompts, TokenizeParams(max_total_tokens=100, truncate_prompt_tokens=5), @@ -266,7 +256,9 @@ class TestRenderPrompt: # Exceeds max_total_tokens and max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN long_tokens = "x" * 150 - prompts = renderer.render_completions(long_tokens) + prompts = renderer.render_prompts( + _preprocess_prompt(renderer.config, long_tokens) + ) with pytest.raises( ValueError, @@ -285,7 +277,9 @@ class TestRenderPrompt: # Exceeds max_total_tokens but not max_total_tokens * VLLM_MAX_CHARS_PER_TOKEN long_tokens = "x" * 150 - prompts = renderer.render_completions(long_tokens) + prompts = renderer.render_prompts( + _preprocess_prompt(renderer.config, long_tokens) + ) with pytest.raises( ValueError, @@ -304,7 +298,9 @@ class TestRenderPrompt: renderer = _build_renderer(MockModelConfig()) long_tokens = list(range(150)) # Exceeds max_total_tokens=100 - prompts = renderer.render_completions(long_tokens) + prompts = renderer.render_prompts( + _preprocess_prompt(renderer.config, long_tokens) + ) with pytest.raises( ValueError, @@ -318,7 +314,9 @@ class TestRenderPrompt: def test_no_tokenizer_for_text(self): renderer = _build_renderer(MockModelConfig(skip_tokenizer_init=True)) - prompts = renderer.render_completions("Hello world") + prompts = renderer.render_prompts( + _preprocess_prompt(renderer.config, "Hello world") + ) with pytest.raises(ValueError, match="`skip_tokenizer_init=True`"): renderer.tokenize_prompts( @@ -330,7 +328,7 @@ class TestRenderPrompt: renderer = _build_renderer(MockModelConfig()) tokens = [1, 2, 3, 4] - prompts = renderer.render_completions(tokens) + prompts = renderer.render_prompts(_preprocess_prompt(renderer.config, tokens)) results = renderer.tokenize_prompts( prompts, TokenizeParams( @@ -359,7 +357,9 @@ class TestRenderEmbedPrompt: tensor_input = torch.randn(10, 768, dtype=torch.float32) embed_bytes = self._create_test_embed_bytes(tensor_input) - prompts = renderer.render_completions(prompt_embeds=embed_bytes) + prompts = renderer.render_prompts( + _preprocess_prompt(renderer.config, embed_bytes) + ) results = renderer.tokenize_prompts( prompts, TokenizeParams(max_total_tokens=100), @@ -377,8 +377,11 @@ class TestRenderEmbedPrompt: torch.randn(12, 512, dtype=torch.float32), ] - prompts = renderer.render_completions( - prompt_embeds=[self._create_test_embed_bytes(t) for t in tensor_inputs], + prompts = renderer.render_prompts( + _preprocess_prompt( + renderer.config, + [self._create_test_embed_bytes(t) for t in tensor_inputs], + ) ) results = renderer.tokenize_prompts( prompts, @@ -395,8 +398,10 @@ class TestRenderEmbedPrompt: # Create tensor with more tokens than truncation limit tensor_input = torch.randn(20, 768, dtype=torch.float32) - prompts = renderer.render_completions( - prompt_embeds=self._create_test_embed_bytes(tensor_input), + prompts = renderer.render_prompts( + _preprocess_prompt( + renderer.config, self._create_test_embed_bytes(tensor_input) + ) ) results = renderer.tokenize_prompts( prompts, @@ -420,8 +425,10 @@ class TestRenderEmbedPrompt: for dtype in dtypes: tensor_input = torch.randn(5, 256, dtype=dtype) - prompts = renderer.render_completions( - prompt_embeds=self._create_test_embed_bytes(tensor_input), + prompts = renderer.render_prompts( + _preprocess_prompt( + renderer.config, self._create_test_embed_bytes(tensor_input) + ) ) results = renderer.tokenize_prompts( prompts, @@ -437,8 +444,10 @@ class TestRenderEmbedPrompt: # Test tensor with batch dimension gets squeezed tensor_input = torch.randn(1, 10, 768, dtype=torch.float32) - prompts = renderer.render_completions( - prompt_embeds=self._create_test_embed_bytes(tensor_input), + prompts = renderer.render_prompts( + _preprocess_prompt( + renderer.config, self._create_test_embed_bytes(tensor_input) + ) ) results = renderer.tokenize_prompts( prompts, @@ -455,9 +464,11 @@ class TestRenderEmbedPrompt: text_input = "Hello world" tensor_input = torch.randn(5, 256, dtype=torch.float32) - prompts = renderer.render_completions( - text_input, - prompt_embeds=self._create_test_embed_bytes(tensor_input), + prompts = renderer.render_prompts( + _preprocess_prompt( + renderer.config, + [text_input, self._create_test_embed_bytes(tensor_input)], + ) ) results = renderer.tokenize_prompts( prompts, @@ -465,8 +476,8 @@ class TestRenderEmbedPrompt: ) assert len(results) == 2 - # First should be embed prompt - assert torch.equal(results[0]["prompt_embeds"], tensor_input) - # Second should be tokens prompt - assert "prompt_token_ids" in results[1] - assert len(results[1]["prompt_token_ids"]) == len(text_input) + # First should be tokens prompt + assert "prompt_token_ids" in results[0] + assert len(results[0]["prompt_token_ids"]) == len(text_input) + # Second should be embed prompt + assert torch.equal(results[1]["prompt_embeds"], tensor_input) diff --git a/tests/renderers/test_mistral.py b/tests/renderers/test_mistral.py index 9346582bf..f1d73e738 100644 --- a/tests/renderers/test_mistral.py +++ b/tests/renderers/test_mistral.py @@ -3,16 +3,40 @@ import asyncio import time +from dataclasses import dataclass +from typing import Any from unittest.mock import Mock import pytest from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy -from vllm.config import ModelConfig from vllm.renderers import ChatParams from vllm.renderers.mistral import MistralRenderer, safe_apply_chat_template from vllm.tokenizers.mistral import MistralTokenizer +MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3" + + +@dataclass +class MockHFConfig: + model_type: str = "any" + + +@dataclass +class MockModelConfig: + runner_type = "generate" + model: str = MODEL_NAME + tokenizer: str = MODEL_NAME + trust_remote_code: bool = False + max_model_len: int = 100 + tokenizer_revision = None + tokenizer_mode = "mistral" + hf_config = MockHFConfig() + encoder_config: dict[str, Any] | None = None + enable_prompt_embeds: bool = True + skip_tokenizer_init: bool = False + is_encoder_decoder: bool = False + @pytest.mark.asyncio async def test_async_mistral_tokenizer_does_not_block_event_loop(): @@ -23,9 +47,10 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop(): time.sleep(2) return expected_tokens + mock_model_config = MockModelConfig(skip_tokenizer_init=True) mock_tokenizer = Mock(spec=MistralTokenizer) mock_tokenizer.apply_chat_template = mocked_apply_chat_template - mock_renderer = MistralRenderer(Mock(spec=ModelConfig), tokenizer_kwargs={}) + mock_renderer = MistralRenderer(mock_model_config, tokenizer_kwargs={}) mock_renderer._tokenizer = mock_tokenizer task = mock_renderer.render_messages_async([], ChatParams()) diff --git a/tests/test_inputs.py b/tests/test_inputs.py index a051bc54b..03e470427 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -4,52 +4,13 @@ import pytest from vllm.config import ModelConfig -from vllm.inputs import zip_enc_dec_prompts from vllm.inputs.preprocess import InputPreprocessor pytestmark = pytest.mark.cpu_test -@pytest.mark.parametrize( - "mm_processor_kwargs,expected_mm_kwargs", - [ - (None, [{}, {}]), - ({}, [{}, {}]), - ({"foo": 100}, [{"foo": 100}, {"foo": 100}]), - ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]), - ], -) -def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs): - """Test mm_processor_kwargs init for zipping enc/dec prompts.""" - encoder_prompts = ["An encoder prompt", "Another encoder prompt"] - decoder_prompts = ["A decoder prompt", "Another decoder prompt"] - zipped_prompts = zip_enc_dec_prompts( - encoder_prompts, decoder_prompts, mm_processor_kwargs - ) - assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts) - for enc, dec, exp_kwargs, zipped in zip( - encoder_prompts, decoder_prompts, expected_mm_kwargs, zipped_prompts - ): - assert isinstance(zipped, dict) - assert len(zipped.keys()) == 3 - assert zipped["encoder_prompt"] == enc - assert zipped["decoder_prompt"] == dec - assert zipped["mm_processor_kwargs"] == exp_kwargs - - -@pytest.mark.parametrize( - "model_id", - [ - "facebook/chameleon-7b", - ], -) -@pytest.mark.parametrize( - "prompt", - [ - "", - {"prompt_token_ids": []}, - ], -) +@pytest.mark.parametrize("model_id", ["facebook/chameleon-7b"]) +@pytest.mark.parametrize("prompt", ["", {"prompt_token_ids": []}]) @pytest.mark.skip( reason=( "Applying huggingface processor on text inputs results in " diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 253cfc42d..372c4c81a 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -16,6 +16,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.plugins.io_processors import IOProcessor from vllm.pooling_params import PoolingParams from vllm.renderers import BaseRenderer +from vllm.renderers.inputs import DictPrompt, TokPrompt from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask from vllm.v1.engine import EngineCoreRequest @@ -53,7 +54,11 @@ class EngineClient(ABC): @abstractmethod def generate( self, - prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None], + prompt: EngineCoreRequest + | PromptType + | DictPrompt + | TokPrompt + | AsyncGenerator[StreamingInput, None], sampling_params: SamplingParams, request_id: str, *, @@ -70,7 +75,7 @@ class EngineClient(ABC): @abstractmethod def encode( self, - prompt: PromptType, + prompt: PromptType | DictPrompt | TokPrompt, pooling_params: PoolingParams, request_id: str, lora_request: LoRARequest | None = None, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 2dd0c7b48..4a6162cd2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -4,7 +4,7 @@ import itertools import warnings from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, TypeAlias, cast +from typing import TYPE_CHECKING, Any, cast import cloudpickle import torch.nn as nn @@ -53,16 +53,13 @@ from vllm.entrypoints.pooling.score.utils import ( validate_score_input, ) from vllm.entrypoints.utils import log_non_default_args -from vllm.inputs import ( +from vllm.inputs.data import ( DataPrompt, - EmbedsPrompt, - ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt, TextPrompt, TokensPrompt, ) -from vllm.inputs.parse import get_prompt_components, is_explicit_encoder_decoder_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.quantization import QuantizationMethods @@ -76,6 +73,13 @@ from vllm.outputs import ( from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs +from vllm.renderers.inputs import DictPrompt, SingletonDictPrompt, TokPrompt +from vllm.renderers.inputs.preprocess import ( + conversation_to_seq, + extract_prompt_components, + parse_model_prompt, + prompt_to_seq, +) from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.tasks import PoolingTask from vllm.tokenizers import TokenizerLike @@ -93,9 +97,6 @@ logger = init_logger(__name__) _R = TypeVar("_R", default=Any) -EnginePrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt -EngineEncDecPrompt: TypeAlias = ExplicitEncoderDecoderPrompt[EnginePrompt, EnginePrompt] - class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -445,21 +446,20 @@ class LLM: if sampling_params is None: sampling_params = self.get_default_sampling_params() - self._validate_and_add_requests( + outputs = self._run_completion( prompts=prompts, params=sampling_params, use_tqdm=use_tqdm, - lora_request=self._get_modality_specific_lora_reqs(prompts, lora_request), + lora_request=lora_request, tokenization_kwargs=tokenization_kwargs, priority=priority, ) - outputs = self._run_engine(use_tqdm=use_tqdm) return self.engine_class.validate_outputs(outputs, RequestOutput) def _get_modality_specific_lora_reqs( self, - prompts: PromptType | Sequence[PromptType], + prompts: Sequence[DictPrompt | TokPrompt], lora_request: list[LoRARequest] | LoRARequest | None, ): # Grab the lora config off the vllm config on the engine, @@ -475,9 +475,6 @@ class LLM: ): return lora_request - if not isinstance(prompts, Sequence) or isinstance(prompts, str): - prompts = [prompts] - optional_loras = ( [lora_request] * len(prompts) if not isinstance(lora_request, Sequence) @@ -495,14 +492,12 @@ class LLM: def _resolve_single_prompt_mm_lora( self, - prompt: PromptType, + prompt: DictPrompt | TokPrompt, lora_request: LoRARequest | None, default_mm_loras: dict[str, str] | None, ): - if ( - not default_mm_loras - or not isinstance(prompt, dict) - or not (mm_data := prompt.get("multi_modal_data") or {}) + if not default_mm_loras or not ( + mm_data := prompt.get("multi_modal_data") or {} ): return lora_request @@ -806,61 +801,11 @@ class LLM: add_special_tokens=not model_config.is_encoder_decoder, ).with_kwargs(tokenization_kwargs) - def _normalize_prompts( - self, - prompts: PromptType | Sequence[PromptType], - ) -> list[EnginePrompt | EngineEncDecPrompt]: - if isinstance(prompts, str): - prompts = TextPrompt(prompt=prompts) - - return prompts if isinstance(prompts, Sequence) else [prompts] # type: ignore[return-value] - - def _preprocess_cmpl_singleton( - self, - prompt: SingletonPrompt, - tok_params: TokenizeParams, - *, - tokenize: bool, - ) -> EnginePrompt: - renderer = self.llm_engine.renderer - - if not isinstance(prompt, dict): - prompt = renderer.render_completion(prompt) - - return renderer.tokenize_prompt(prompt, tok_params) if tokenize else prompt - - def _preprocess_cmpl_enc_dec( - self, - prompt: ExplicitEncoderDecoderPrompt, - tok_params: TokenizeParams, - ) -> EngineEncDecPrompt: - enc_prompt = prompt["encoder_prompt"] - dec_prompt = prompt["decoder_prompt"] - - return EngineEncDecPrompt( - encoder_prompt=self._preprocess_cmpl_singleton( - enc_prompt, - tok_params, - # TODO: Move multi-modal processor into tokenization - tokenize=not self.model_config.is_multimodal_model, - ), - decoder_prompt=( - None - if dec_prompt is None - else self._preprocess_cmpl_singleton( - dec_prompt, - tok_params, - # TODO: Move multi-modal processor into tokenization - tokenize=not self.model_config.is_multimodal_model, - ) - ), - ) - def _preprocess_completion( self, - prompts: PromptType | Sequence[PromptType], + prompts: Sequence[PromptType], tokenization_kwargs: dict[str, Any] | None = None, - ) -> list[EnginePrompt | EngineEncDecPrompt]: + ) -> list[DictPrompt | TokPrompt]: """ Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into a format that can be passed to `_add_request`. @@ -871,32 +816,26 @@ class LLM: A list of `TokensPrompts` objects containing the tokenized prompt after chat template interpolation, and the raw multi-modal inputs. """ + renderer = self.llm_engine.renderer + model_config = self.model_config + tok_params = self._get_cmpl_tok_params(tokenization_kwargs) - engine_prompts = list[EnginePrompt | EngineEncDecPrompt]() - for prompt in self._normalize_prompts(prompts): - if is_explicit_encoder_decoder_prompt(prompt): - engine_prompts.append(self._preprocess_cmpl_enc_dec(prompt, tok_params)) - else: - # Some MM models have non-default `add_special_tokens` - # TODO: Move multi-modal processor into tokenization - engine_prompts.append( - self._preprocess_cmpl_singleton( - prompt, - tok_params, - tokenize=not self.model_config.is_multimodal_model, - ) - ) + engine_prompts = list[DictPrompt | TokPrompt]() + for prompt in prompts: + parsed_prompt = parse_model_prompt(model_config, prompt) + in_prompt = renderer.render_prompt(parsed_prompt) + + # Some MM models have non-default `add_special_tokens` + # TODO: Move multi-modal processor into tokenization + engine_prompts.append( + in_prompt + if model_config.is_multimodal_model + else renderer.tokenize_prompt(in_prompt, tok_params) + ) return engine_prompts - def _normalize_conversations( - self, - conversations: list[ChatCompletionMessageParam] - | list[list[ChatCompletionMessageParam]], - ) -> list[list[ChatCompletionMessageParam]]: - return conversations if is_list_of(conversations, list) else [conversations] # type: ignore[list-item,return-value] - def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None): model_config = self.model_config encoder_config = model_config.encoder_config or {} @@ -909,8 +848,7 @@ class LLM: def _preprocess_chat( self, - conversations: list[ChatCompletionMessageParam] - | list[list[ChatCompletionMessageParam]], + conversations: Sequence[list[ChatCompletionMessageParam]], chat_template: str | None = None, chat_template_content_format: ChatTemplateContentFormatOption = "auto", chat_template_kwargs: dict[str, Any] | None = None, @@ -919,7 +857,7 @@ class LLM: tools: list[dict[str, Any]] | None = None, tokenization_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None, - ) -> list[EnginePrompt]: + ) -> list[DictPrompt | TokPrompt]: """ Convert a list of conversations into prompts so that they can then be used as input for other LLM APIs. @@ -947,11 +885,14 @@ class LLM: ) tok_params = self._get_chat_tok_params(tokenization_kwargs) - engine_prompts = list[EnginePrompt]() - for conversation in self._normalize_conversations(conversations): + engine_prompts = list[DictPrompt | TokPrompt]() + for conversation in conversations: _, in_prompt = renderer.render_messages(conversation, chat_params) if mm_processor_kwargs is not None: - in_prompt["mm_processor_kwargs"] = mm_processor_kwargs + target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore + "encoder_prompt", in_prompt + ) + target_prompt["mm_processor_kwargs"] = mm_processor_kwargs # type: ignore engine_prompts.append(renderer.tokenize_prompt(in_prompt, tok_params)) @@ -960,8 +901,8 @@ class LLM: def chat( self, messages: list[ChatCompletionMessageParam] - | list[list[ChatCompletionMessageParam]], - sampling_params: SamplingParams | list[SamplingParams] | None = None, + | Sequence[list[ChatCompletionMessageParam]], + sampling_params: SamplingParams | Sequence[SamplingParams] | None = None, use_tqdm: bool | Callable[..., tqdm] = True, lora_request: LoRARequest | None = None, chat_template: str | None = None, @@ -984,7 +925,7 @@ class LLM: to the OpenAI API. Args: - messages: A list of conversations or a single conversation. + messages: A sequence of conversations or a single conversation. - Each conversation is represented as a list of messages. - Each message is a dictionary with 'role' and 'content' keys. @@ -1023,8 +964,23 @@ class LLM: A list of `RequestOutput` objects containing the generated responses in the same order as the input messages. """ - prompts = self._preprocess_chat( - messages, + model_config = self.model_config + runner_type = model_config.runner_type + if runner_type != "generate": + raise ValueError( + "LLM.chat() is only supported for generative models. " + "Try passing `--runner generate` to use the model as a " + "generative model." + ) + + if sampling_params is None: + sampling_params = self.get_default_sampling_params() + + outputs = self._run_chat( + messages=messages, + params=sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, chat_template=chat_template, chat_template_content_format=chat_template_content_format, chat_template_kwargs=chat_template_kwargs, @@ -1035,13 +991,7 @@ class LLM: mm_processor_kwargs=mm_processor_kwargs, ) - return self.generate( - prompts, - sampling_params=sampling_params, - use_tqdm=use_tqdm, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - ) + return self.engine_class.validate_outputs(outputs, RequestOutput) def encode( self, @@ -1163,7 +1113,7 @@ class LLM: msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!" raise ValueError(msg) - self._validate_and_add_requests( + outputs = self._run_completion( prompts=prompts, params=pooling_params, use_tqdm=use_tqdm, @@ -1171,8 +1121,6 @@ class LLM: tokenization_kwargs=tokenization_kwargs, ) - outputs = self._run_engine(use_tqdm=use_tqdm) - model_outputs = self.engine_class.validate_outputs( outputs, PoolingRequestOutput ) @@ -1523,14 +1471,13 @@ class LLM: prompts.append(engine_prompt) - self._validate_and_add_requests( + outputs = self._run_completion( prompts=prompts, params=pooling_params_list, use_tqdm=use_tqdm, lora_request=lora_request, ) - outputs = self._run_engine(use_tqdm=use_tqdm) items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput) return [ScoringRequestOutput.from_base(item) for item in items] @@ -1727,33 +1674,29 @@ class LLM: """ return self.llm_engine.get_metrics() - def _validate_and_add_requests( + def _params_to_seq( self, - prompts: PromptType | Sequence[PromptType], params: SamplingParams - | Sequence[SamplingParams] | PoolingParams - | Sequence[PoolingParams], - *, - use_tqdm: bool | Callable[..., tqdm] = True, - lora_request: Sequence[LoRARequest | None] | LoRARequest | None, - tokenization_kwargs: dict[str, Any] | None = None, - priority: list[int] | None = None, - ) -> None: - in_prompts = self._normalize_prompts(prompts) - num_requests = len(in_prompts) - + | Sequence[SamplingParams | PoolingParams], + num_requests: int, + ) -> Sequence[SamplingParams | PoolingParams]: if isinstance(params, Sequence): if len(params) != num_requests: raise ValueError( f"The lengths of prompts ({params}) " - f"and lora_request ({len(params)}) must be the same." + f"and params ({len(params)}) must be the same." ) - engine_params = params - else: - engine_params = [params] * num_requests + return params + return [params] * num_requests + + def _lora_request_to_seq( + self, + lora_request: LoRARequest | None | Sequence[LoRARequest | None], + num_requests: int, + ) -> Sequence[LoRARequest | None]: if isinstance(lora_request, Sequence): if len(lora_request) != num_requests: raise ValueError( @@ -1761,28 +1704,50 @@ class LLM: f"and lora_request ({len(lora_request)}) must be the same." ) - engine_lora_requests: Sequence[LoRARequest | None] = lora_request - else: - engine_lora_requests = [lora_request] * num_requests + return lora_request + return [lora_request] * num_requests + + def _priority_to_seq( + self, + priority: list[int] | None, + num_requests: int, + ) -> Sequence[int]: if priority is not None: if len(priority) != num_requests: raise ValueError( f"The lengths of prompts ({num_requests}) " f"and priority ({len(priority)}) must be the same." ) - else: - priority = [0] * num_requests - if any(param.truncate_prompt_tokens is not None for param in engine_params): + return priority + + return [0] * num_requests + + def _run_completion( + self, + prompts: PromptType | Sequence[PromptType], + params: SamplingParams + | PoolingParams + | Sequence[SamplingParams | PoolingParams], + *, + use_tqdm: bool | Callable[..., tqdm] = True, + lora_request: list[LoRARequest] | LoRARequest | None = None, + priority: list[int] | None = None, + tokenization_kwargs: dict[str, Any] | None = None, + ): + seq_prompts = prompt_to_seq(prompts) + seq_params = self._params_to_seq(params, len(seq_prompts)) + + if any(param.truncate_prompt_tokens is not None for param in seq_params): # TODO: Remove this after deprecating `param.truncate_prompt_tokens` # Then, move the code from the `else` block to the top and let # `self._preprocess_completion` handle prompt normalization engine_prompts = [ engine_prompt - for in_prompt, param in zip(in_prompts, engine_params) + for prompt, param in zip(seq_prompts, seq_params) for engine_prompt in self._preprocess_completion( - [in_prompt], + [prompt], tokenization_kwargs=merge_kwargs( tokenization_kwargs, dict(truncate_prompt_tokens=param.truncate_prompt_tokens), @@ -1791,17 +1756,90 @@ class LLM: ] else: engine_prompts = self._preprocess_completion( - in_prompts, + seq_prompts, tokenization_kwargs=tokenization_kwargs, ) - for sp in engine_params: + self._validate_and_add_requests( + prompts=engine_prompts, + params=seq_params, + use_tqdm=use_tqdm, + lora_request=self._get_modality_specific_lora_reqs( + engine_prompts, lora_request + ), + tokenization_kwargs=tokenization_kwargs, + priority=priority, + ) + + return self._run_engine(use_tqdm=use_tqdm) + + def _run_chat( + self, + messages: list[ChatCompletionMessageParam] + | Sequence[list[ChatCompletionMessageParam]], + params: SamplingParams + | PoolingParams + | Sequence[SamplingParams | PoolingParams], + *, + use_tqdm: bool | Callable[..., tqdm] = True, + lora_request: LoRARequest | None = None, + chat_template: str | None = None, + chat_template_content_format: ChatTemplateContentFormatOption = "auto", + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tools: list[dict[str, Any]] | None = None, + chat_template_kwargs: dict[str, Any] | None = None, + tokenization_kwargs: dict[str, Any] | None = None, + mm_processor_kwargs: dict[str, Any] | None = None, + ): + engine_prompts = self._preprocess_chat( + conversation_to_seq(messages), + chat_template=chat_template, + chat_template_content_format=chat_template_content_format, + chat_template_kwargs=chat_template_kwargs, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tools, + tokenization_kwargs=tokenization_kwargs, + mm_processor_kwargs=mm_processor_kwargs, + ) + + self._validate_and_add_requests( + prompts=engine_prompts, + params=params, + use_tqdm=use_tqdm, + lora_request=self._get_modality_specific_lora_reqs( + engine_prompts, lora_request + ), + tokenization_kwargs=tokenization_kwargs, + ) + + return self._run_engine(use_tqdm=use_tqdm) + + def _validate_and_add_requests( + self, + prompts: Sequence[DictPrompt | TokPrompt], + params: SamplingParams + | PoolingParams + | Sequence[SamplingParams | PoolingParams], + *, + use_tqdm: bool | Callable[..., tqdm] = True, + lora_request: Sequence[LoRARequest | None] | LoRARequest | None, + tokenization_kwargs: dict[str, Any] | None = None, + priority: list[int] | None = None, + ) -> None: + num_requests = len(prompts) + seq_params = self._params_to_seq(params, num_requests) + seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests) + seq_priority = self._priority_to_seq(priority, num_requests) + + for sp in seq_params: if isinstance(sp, SamplingParams): # We only care about the final output sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. - it = engine_prompts + it = prompts if use_tqdm: tqdm_func = use_tqdm if callable(use_tqdm) else tqdm it = tqdm_func(it, desc="Adding requests") @@ -1812,10 +1850,10 @@ class LLM: for i, prompt in enumerate(it): request_id = self._add_request( prompt, - engine_params[i], - lora_request=engine_lora_requests[i], + seq_params[i], + lora_request=seq_lora_requests[i], tokenization_kwargs=tokenization_kwargs, - priority=priority[i], + priority=seq_priority[i], ) added_request_ids.append(request_id) except Exception as e: @@ -1825,13 +1863,13 @@ class LLM: def _add_request( self, - prompt: PromptType, + prompt: PromptType | DictPrompt | TokPrompt, params: SamplingParams | PoolingParams, lora_request: LoRARequest | None = None, tokenization_kwargs: dict[str, Any] | None = None, priority: int = 0, ) -> str: - prompt_text, _, _ = get_prompt_components(prompt) + prompt_text, _, _ = extract_prompt_components(self.model_config, prompt) request_id = str(next(self.request_counter)) if params.truncate_prompt_tokens is not None: diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 8ff686516..433fe961a 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -67,12 +67,13 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ) from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.utils import get_max_tokens, should_include_usage -from vllm.inputs.data import EmbedsPrompt, TokensPrompt +from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob from vllm.outputs import CompletionOutput, RequestOutput from vllm.parser import ParserManager from vllm.reasoning import ReasoningParser +from vllm.renderers.inputs import TokPrompt from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.tokenizers import TokenizerLike from vllm.tokenizers.mistral import ( @@ -218,10 +219,7 @@ class OpenAIServingChat(OpenAIServing): async def render_chat_request( self, request: ChatCompletionRequest, - ) -> ( - tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]] - | ErrorResponse - ): + ) -> tuple[list[ConversationMessage], list[TokPrompt]] | ErrorResponse: """ render chat request by validating and preprocessing inputs. @@ -380,7 +378,7 @@ class OpenAIServingChat(OpenAIServing): generators: list[AsyncGenerator[RequestOutput, None]] = [] try: for i, engine_prompt in enumerate(engine_prompts): - prompt_text = engine_prompt.get("prompt") + prompt_text = self._extract_prompt_text(engine_prompt) # If we are creating sub requests for multiple prompts, ensure that they # have unique request ids. @@ -389,10 +387,10 @@ class OpenAIServingChat(OpenAIServing): ) max_tokens = get_max_tokens( - max_model_len=self.max_model_len, - request=request, - prompt=engine_prompt, - default_sampling_params=self.default_sampling_params, + self.max_model_len, + request, + self._extract_prompt_len(engine_prompt), + self.default_sampling_params, ) sampling_params: SamplingParams | BeamSearchParams diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py index 8981b8662..d2fa2f931 100644 --- a/vllm/entrypoints/openai/completion/serving.py +++ b/vllm/entrypoints/openai/completion/serving.py @@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import ( from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.exceptions import VLLMValidationError -from vllm.inputs.data import EmbedsPrompt, TokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob from vllm.outputs import RequestOutput +from vllm.renderers.inputs import TokPrompt from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.tokenizers import TokenizerLike from vllm.utils.async_utils import merge_async_iterators @@ -78,7 +78,7 @@ class OpenAIServingCompletion(OpenAIServing): async def render_completion_request( self, request: CompletionRequest, - ) -> list[TokensPrompt | EmbedsPrompt] | ErrorResponse: + ) -> list[TokPrompt] | ErrorResponse: """ render completion request by validating and preprocessing inputs. @@ -160,13 +160,13 @@ class OpenAIServingCompletion(OpenAIServing): generators: list[AsyncGenerator[RequestOutput, None]] = [] try: for i, engine_prompt in enumerate(engine_prompts): - prompt_text = engine_prompt.get("prompt") + prompt_text = self._extract_prompt_text(engine_prompt) max_tokens = get_max_tokens( - max_model_len=self.max_model_len, - request=request, - prompt=engine_prompt, - default_sampling_params=self.default_sampling_params, + self.max_model_len, + request, + self._extract_prompt_len(engine_prompt), + self.default_sampling_params, ) sampling_params: SamplingParams | BeamSearchParams @@ -277,7 +277,7 @@ class OpenAIServingCompletion(OpenAIServing): # with the inputs token IDs if final_res.prompt is None: engine_prompt = engine_prompts[i] - final_res.prompt = engine_prompt.get("prompt") + final_res.prompt = self._extract_prompt_text(engine_prompt) final_res_batch_checked = cast(list[RequestOutput], final_res_batch) @@ -313,7 +313,7 @@ class OpenAIServingCompletion(OpenAIServing): async def completion_stream_generator( self, request: CompletionRequest, - engine_prompts: list[TokensPrompt | EmbedsPrompt], + engine_prompts: list[TokPrompt], result_generator: AsyncIterator[tuple[int, RequestOutput]], request_id: str, created_time: int, @@ -347,7 +347,7 @@ class OpenAIServingCompletion(OpenAIServing): prompt_text = res.prompt if prompt_text is None: engine_prompt = engine_prompts[prompt_idx] - prompt_text = engine_prompt.get("prompt") + prompt_text = self._extract_prompt_text(engine_prompt) # Prompt details are excluded from later streamed outputs if prompt_token_ids is not None: diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index f87ac5804..ebd629afa 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -96,11 +96,7 @@ from vllm.entrypoints.serve.tokenize.protocol import ( ) from vllm.entrypoints.utils import get_max_tokens, sanitize_message from vllm.exceptions import VLLMValidationError -from vllm.inputs.data import EmbedsPrompt, PromptType, TokensPrompt -from vllm.inputs.parse import ( - get_prompt_components, - is_explicit_encoder_decoder_prompt, -) +from vllm.inputs.data import PromptType, SingletonPrompt, TokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob, PromptLogprobs from vllm.lora.request import LoRARequest @@ -108,6 +104,14 @@ from vllm.multimodal import MultiModalDataDict from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs +from vllm.renderers.inputs import TokPrompt +from vllm.renderers.inputs.preprocess import ( + SingletonDictPrompt, + extract_prompt_components, + extract_prompt_len, + parse_model_prompt, + prompt_to_seq, +) from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.tokenizers import TokenizerLike from vllm.tool_parsers import ToolParser @@ -203,7 +207,7 @@ class ServeContext(Generic[RequestT]): request_id: str created_time: int = field(default_factory=lambda: int(time.time())) lora_request: LoRARequest | None = None - engine_prompts: list[TokensPrompt | EmbedsPrompt] | None = None + engine_prompts: list[TokPrompt] | None = None result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = ( None @@ -247,7 +251,7 @@ class OpenAIServing: async def beam_search( self, - prompt: PromptType, + prompt: TokPrompt, request_id: str, params: BeamSearchParams, lora_request: LoRARequest | None = None, @@ -271,20 +275,12 @@ class OpenAIServing: eos_token_id: int = tokenizer.eos_token_id # type: ignore - if is_explicit_encoder_decoder_prompt(prompt): - raise NotImplementedError + if isinstance(prompt, dict) and "encoder_prompt" in prompt: + raise NotImplementedError("Encoder-decoder prompt not supported") - prompt_text: str | None - prompt_token_ids: list[int] - multi_modal_data: MultiModalDataDict | None - if isinstance(prompt, str): - prompt_text = prompt - prompt_token_ids = [] - multi_modal_data = None - else: - prompt_text = prompt.get("prompt") # type: ignore - prompt_token_ids = prompt.get("prompt_token_ids", []) # type: ignore - multi_modal_data = prompt.get("multi_modal_data") # type: ignore + prompt_text: str | None = prompt.get("prompt") # type: ignore + prompt_token_ids: list[int] = prompt.get("prompt_token_ids", []) # type: ignore + multi_modal_data: MultiModalDataDict | None = prompt.get("multi_modal_data") # type: ignore mm_processor_kwargs: dict[str, Any] | None = None @@ -963,22 +959,40 @@ class OpenAIServing: request: RendererRequest, prompt_input: str | list[str] | list[int] | list[list[int]] | None, prompt_embeds: bytes | list[bytes] | None, - ) -> list[TokensPrompt | EmbedsPrompt]: + ) -> list[TokPrompt]: renderer = self.renderer - tok_params = request.build_tok_params(self.model_config) + model_config = self.model_config - in_prompts = await renderer.render_completions_async( - prompt_input, prompt_embeds - ) - engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params) + tok_params = request.build_tok_params(model_config) + + prompts = list[SingletonPrompt | bytes]() + if prompt_embeds is not None: # embeds take higher priority + prompts.extend(prompt_to_seq(prompt_embeds)) + if prompt_input is not None: + prompts.extend(prompt_to_seq(prompt_input)) + + parsed_prompts = [ + ( + prompt + if isinstance(prompt, bytes) + else parse_model_prompt(model_config, prompt) + ) + for prompt in prompts + ] + in_prompts = await renderer.render_prompts_async(parsed_prompts) extra_items = { k: v for k in ("mm_processor_kwargs", "cache_salt") if (v := getattr(request, k, None)) is not None } - for prompt in engine_prompts: - prompt.update(extra_items) # type: ignore + for in_prompt in in_prompts: + target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore + "encoder_prompt", in_prompt + ) + target_prompt.update(extra_items) # type: ignore + + engine_prompts = await renderer.tokenize_prompts_async(in_prompts, tok_params) return engine_prompts @@ -991,7 +1005,7 @@ class OpenAIServing: default_template_kwargs: dict[str, Any] | None, tool_dicts: list[dict[str, Any]] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, - ) -> tuple[list[ConversationMessage], list[TokensPrompt | EmbedsPrompt]]: + ) -> tuple[list[ConversationMessage], list[TokPrompt]]: from vllm.tokenizers.mistral import MistralTokenizer renderer = self.renderer @@ -1009,17 +1023,21 @@ class OpenAIServing: default_template, default_template_content_format ).with_defaults(default_template_kwargs) - conversation, prompt = await renderer.render_messages_async( + conversation, in_prompt = await renderer.render_messages_async( messages, chat_params ) - engine_prompt = await renderer.tokenize_prompt_async(prompt, tok_params) + target_prompt: SingletonDictPrompt = in_prompt.get( # type: ignore + "encoder_prompt", in_prompt + ) extra_items = { k: v for k in ("mm_processor_kwargs", "cache_salt") if (v := getattr(request, k, None)) is not None } - engine_prompt.update(extra_items) # type: ignore + target_prompt.update(extra_items) # type: ignore + + engine_prompt = await renderer.tokenize_prompt_async(target_prompt, tok_params) # tool parsing is done only if a tool_parser has been set and if # tool_choice is not "none" (if tool_choice is "none" but a tool_parser @@ -1040,6 +1058,15 @@ class OpenAIServing: return conversation, [engine_prompt] + def _extract_prompt_components(self, prompt: object): + return extract_prompt_components(self.model_config, prompt) + + def _extract_prompt_text(self, prompt: object): + return self._extract_prompt_components(prompt).text + + def _extract_prompt_len(self, prompt: object): + return extract_prompt_len(self.model_config, prompt) + async def _render_next_turn( self, request: ResponsesRequest, @@ -1067,7 +1094,7 @@ class OpenAIServing: async def _generate_with_builtin_tools( self, request_id: str, - engine_prompt: TokensPrompt | EmbedsPrompt, + engine_prompt: TokPrompt, sampling_params: SamplingParams, tok_params: TokenizeParams, context: ConversationContext, @@ -1075,7 +1102,7 @@ class OpenAIServing: priority: int = 0, trace_headers: Mapping[str, str] | None = None, ): - prompt_text = engine_prompt.get("prompt") + prompt_text = self._extract_prompt_text(engine_prompt) orig_priority = priority sub_request = 0 @@ -1145,12 +1172,12 @@ class OpenAIServing: context.chat_template_content_format, ) engine_prompt = engine_prompts[0] - prompt_text = engine_prompt.get("prompt") + prompt_text = self._extract_prompt_text(engine_prompt) sampling_params.max_tokens = get_max_tokens( self.max_model_len, context.request, - engine_prompt, + self._extract_prompt_len(engine_prompt), self.default_sampling_params, # type: ignore ) @@ -1161,20 +1188,20 @@ class OpenAIServing: def _log_inputs( self, request_id: str, - inputs: PromptType, + inputs: PromptType | TokPrompt, params: SamplingParams | PoolingParams | BeamSearchParams | None, lora_request: LoRARequest | None, ) -> None: if self.request_logger is None: return - prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs) + components = self._extract_prompt_components(inputs) self.request_logger.log_inputs( request_id, - prompt, - prompt_token_ids, - prompt_embeds, + components.text, + components.token_ids, + components.embeds, params=params, lora_request=lora_request, ) diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index 1ed7a79cc..500401468 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -116,13 +116,13 @@ from vllm.entrypoints.openai.responses.utils import ( ) from vllm.entrypoints.utils import get_max_tokens from vllm.exceptions import VLLMValidationError -from vllm.inputs.data import EmbedsPrompt, TokensPrompt -from vllm.inputs.parse import get_prompt_len +from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import SampleLogprobs from vllm.outputs import CompletionOutput from vllm.parser import ParserManager +from vllm.renderers.inputs import TokPrompt from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.tokenizers import TokenizerLike from vllm.utils import random_uuid @@ -292,10 +292,10 @@ class OpenAIServingResponses(OpenAIServing): def _validate_generator_input( self, - engine_prompt: TokensPrompt | EmbedsPrompt, + engine_prompt: TokPrompt, ) -> ErrorResponse | None: """Add validations to the input to the generator here.""" - prompt_len = get_prompt_len(engine_prompt) + prompt_len = self._extract_prompt_len(engine_prompt) if self.max_model_len <= prompt_len: error_message = ( f"The engine prompt length {prompt_len} " @@ -442,7 +442,7 @@ class OpenAIServingResponses(OpenAIServing): default_max_tokens = get_max_tokens( self.max_model_len, request, - engine_prompt, + self._extract_prompt_len(engine_prompt), self.default_sampling_params, ) diff --git a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py index 19dccbb17..454359ffd 100644 --- a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py @@ -7,7 +7,7 @@ import time import zlib from collections.abc import AsyncGenerator, Callable from functools import cached_property -from typing import Literal, TypeAlias, TypeVar, cast +from typing import Final, Literal, TypeAlias, TypeVar, cast import numpy as np from fastapi import Request @@ -37,12 +37,13 @@ from vllm.entrypoints.openai.speech_to_text.protocol import ( TranslationStreamResponse, ) from vllm.exceptions import VLLMValidationError -from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType -from vllm.inputs.parse import is_explicit_encoder_decoder_prompt +from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.logprobs import FlatLogprobs, Logprob from vllm.model_executor.models import SupportsTranscription, supports_transcription from vllm.outputs import RequestOutput +from vllm.renderers.inputs import EncoderDecoderDictPrompt +from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt from vllm.tokenizers import get_tokenizer from vllm.utils.import_utils import PlaceholderModule @@ -94,7 +95,7 @@ class OpenAISpeechToText(OpenAIServing): ) self.default_sampling_params = self.model_config.get_diff_sampling_param() - self.task_type = task_type + self.task_type: Final = task_type self.asr_config = self.model_cls.get_speech_to_text_config( self.model_config, task_type @@ -298,35 +299,26 @@ class OpenAISpeechToText(OpenAIServing): to_language=to_language, ) if request.response_format == "verbose_json": - if not is_explicit_encoder_decoder_prompt(prompt): - raise VLLMValidationError( - "Expected prompt to be an encoder-decoder prompt", - parameter="prompt", - value=type(prompt).__name__, - ) - - prompt = self._preprocess_verbose_prompt(prompt) + prompt = self._preprocess_verbose_prompt(parse_enc_dec_prompt(prompt)) prompts.append(prompt) + return prompts, duration - def _repl_verbose_text(self, text: str): - return text.replace("<|notimestamps|>", "<|0.00|>") - - def _preprocess_verbose_prompt(self, prompt: ExplicitEncoderDecoderPrompt): + def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt): dec_prompt = prompt["decoder_prompt"] - if isinstance(dec_prompt, str): - prompt["decoder_prompt"] = self._repl_verbose_text(dec_prompt) - elif isinstance(dec_prompt, dict) and "prompt" in dec_prompt: - dec_prompt["prompt"] = self._repl_verbose_text(dec_prompt["prompt"]) - else: + if not (isinstance(dec_prompt, dict) and "prompt" in dec_prompt): raise VLLMValidationError( "Expected decoder_prompt to contain text", parameter="decoder_prompt", value=type(dec_prompt).__name__, ) + dec_prompt["prompt"] = dec_prompt["prompt"].replace( + "<|notimestamps|>", "<|0.00|>" + ) + return prompt def _get_verbose_segments( diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index e1b776377..f06ed9ad7 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -28,10 +28,11 @@ from vllm.entrypoints.pooling.utils import ( encode_pooling_output_base64, encode_pooling_output_float, ) -from vllm.inputs.data import EmbedsPrompt, TokensPrompt +from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.pooling_params import PoolingParams +from vllm.renderers.inputs import TokPrompt from vllm.utils.async_utils import merge_async_iterators from vllm.utils.collection_utils import chunk_list from vllm.utils.serial_utils import EmbedDType, Endianness @@ -369,7 +370,7 @@ class OpenAIServingEmbedding(OpenAIServing): async def _create_single_prompt_generator( self, ctx: EmbeddingServeContext, - engine_prompt: TokensPrompt | EmbedsPrompt, + engine_prompt: TokPrompt, pooling_params: PoolingParams, trace_headers: Mapping[str, str] | None, prompt_index: int, diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py index faf5a09d4..3ad5786db 100644 --- a/vllm/entrypoints/pooling/pooling/serving.py +++ b/vllm/entrypoints/pooling/pooling/serving.py @@ -33,8 +33,11 @@ from vllm.entrypoints.pooling.utils import ( encode_pooling_output_base64, encode_pooling_output_float, ) +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput +from vllm.renderers.inputs import TokPrompt +from vllm.renderers.inputs.preprocess import prompt_to_seq from vllm.utils.async_utils import merge_async_iterators from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness @@ -91,6 +94,7 @@ class OpenAIServingPooling(OpenAIServing): "dimensions is currently not supported" ) + engine_prompts: Sequence[PromptType | TokPrompt] if is_io_processor_request: if self.io_processor is None: raise ValueError( @@ -102,14 +106,10 @@ class OpenAIServingPooling(OpenAIServing): validated_prompt = self.io_processor.parse_request(request) - engine_prompts = await self.io_processor.pre_process_async( + raw_prompts = await self.io_processor.pre_process_async( prompt=validated_prompt, request_id=request_id ) - if not isinstance(engine_prompts, Sequence) or isinstance( - engine_prompts, (str, bytes, bytearray) - ): - engine_prompts = [engine_prompts] - + engine_prompts = prompt_to_seq(raw_prompts) elif isinstance(request, PoolingChatRequest): error_check_ret = self._validate_chat_template( request_chat_template=request.chat_template, diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 290f8fd89..4900bfa7d 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -17,8 +17,6 @@ from starlette.background import BackgroundTask, BackgroundTasks from vllm import envs from vllm.engine.arg_utils import EngineArgs -from vllm.inputs import EmbedsPrompt, TokensPrompt -from vllm.inputs.parse import get_prompt_len from vllm.logger import current_formatter_type, init_logger from vllm.platforms import current_platform from vllm.utils.argparse_utils import FlexibleArgumentParser @@ -189,7 +187,7 @@ def cli_env_setup(): def get_max_tokens( max_model_len: int, request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest", - prompt: TokensPrompt | EmbedsPrompt, + input_length: int, default_sampling_params: dict, ) -> int: # NOTE: Avoid isinstance() for better efficiency @@ -204,7 +202,6 @@ def get_max_tokens( # CompletionRequest (also a fallback for ChatCompletionRequest) max_tokens = getattr(request, "max_tokens", None) - input_length = get_prompt_len(prompt) default_max_tokens = max_model_len - input_length max_output_tokens = current_platform.get_max_output_tokens(input_length) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 0fdb3ab5e..de8ddc615 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -16,11 +16,8 @@ from .data import ( TextPrompt, TokenInputs, TokensPrompt, - build_explicit_enc_dec_prompt, embeds_inputs, - to_enc_dec_tuple_list, token_inputs, - zip_enc_dec_prompts, ) __all__ = [ @@ -39,8 +36,5 @@ __all__ = [ "EncoderDecoderInputs", "ProcessorInputs", "SingletonInputs", - "build_explicit_enc_dec_prompt", - "to_enc_dec_tuple_list", - "zip_enc_dec_prompts", "StreamingInput", ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index d9f9814ee..7848c2c03 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, cast +from typing import TYPE_CHECKING, Any, Literal, TypeAlias import torch -from typing_extensions import NotRequired, TypedDict, TypeVar +from typing_extensions import NotRequired, TypedDict from vllm.sampling_params import SamplingParams @@ -23,7 +22,13 @@ else: MultiModalUUIDDict = object -class _CommonKeys(TypedDict): +# Inputs to LLM API +class _PromptOptions(TypedDict): + """ + Additional options available to all + [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt]. + """ + multi_modal_data: NotRequired[MultiModalDataDict | None] """ Optional multi-modal data to pass to the model, @@ -53,14 +58,14 @@ class _CommonKeys(TypedDict): """ -class TextPrompt(_CommonKeys): +class TextPrompt(_PromptOptions): """Schema for a text prompt.""" prompt: str """The input text to be tokenized before passing to the model.""" -class TokensPrompt(_CommonKeys): +class TokensPrompt(_PromptOptions): """Schema for a tokenized prompt.""" prompt_token_ids: list[int] @@ -73,7 +78,7 @@ class TokensPrompt(_CommonKeys): """A list of token type IDs to pass to the cross encoder model.""" -class EmbedsPrompt(_CommonKeys): +class EmbedsPrompt(_PromptOptions): """Schema for a prompt provided via token embeddings.""" prompt_embeds: torch.Tensor @@ -83,93 +88,113 @@ class EmbedsPrompt(_CommonKeys): """The prompt text corresponding to the token embeddings, if available.""" -class DataPrompt(_CommonKeys): - """Represents generic inputs handled by IO processor plugins.""" +DecoderOnlyPrompt: TypeAlias = ( + str | TextPrompt | list[int] | TokensPrompt | EmbedsPrompt +) +""" +Schema of a prompt for a decoder-only model: + +- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt]) +- A tokenized prompt (list of token IDs, or + [`TokensPrompt`][vllm.inputs.data.TokensPrompt]) +- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt]) + +For encoder-decoder models, passing a singleton prompt is shorthand for passing +`ExplicitEncoderDecoderPrompt(encoder_prompt=prompt, decoder_prompt=None)`. +""" + + +EncoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt +""" +Schema of a prompt for the encoder part of a encoder-decoder model: + +- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt]) +- A tokenized prompt (list of token IDs, or + [`TokensPrompt`][vllm.inputs.data.TokensPrompt]) +""" + + +DecoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt +""" +Schema of a prompt for the decoder part of an encoder-decoder model: + +- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt]) +- A tokenized prompt (list of token IDs, or + [`TokensPrompt`][vllm.inputs.data.TokensPrompt]) + +Note: + Multi-modal inputs are not supported for decoder prompts. +""" + + +class ExplicitEncoderDecoderPrompt(TypedDict): + """ + Schema for a pair of encoder and decoder singleton prompts. + + Note: + This schema is not valid for decoder-only models. + """ + + encoder_prompt: EncoderPrompt + """The prompt for the encoder part of the model.""" + + decoder_prompt: DecoderPrompt | None + """ + The prompt for the decoder part of the model. + + Passing `None` will cause the prompt to be inferred automatically. + """ + + +EncoderDecoderPrompt: TypeAlias = EncoderPrompt | ExplicitEncoderDecoderPrompt +""" +Schema for a prompt for an encoder-decoder model. + +You can pass a singleton encoder prompt, in which case the decoder prompt is +considered to be `None` (i.e., infer automatically). +""" + + +SingletonPrompt: TypeAlias = DecoderOnlyPrompt | EncoderPrompt | DecoderPrompt +""" +Schema for a single prompt. This is as opposed to a data structure +which encapsulates multiple prompts, such as +[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]. +""" + + +PromptType: TypeAlias = DecoderOnlyPrompt | EncoderDecoderPrompt +""" +Schema for any prompt, regardless of model type. + +This is the input format accepted by most [`LLM`][vllm.entrypoints.llm.LLM] APIs. +""" + + +class DataPrompt(_PromptOptions): + """ + Represents generic inputs that are converted to + [`PromptType`][vllm.inputs.data.PromptType] by IO processor plugins. + """ data: Any - """The input data""" + """The input data.""" data_format: str - """The input data format""" + """The input data format.""" -SingletonPrompt: TypeAlias = str | TextPrompt | TokensPrompt | EmbedsPrompt -""" -Set of possible schemas for a single prompt: - -- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt]) -- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt]) -- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt]) - -Note that "singleton" is as opposed to a data structure -which encapsulates multiple prompts, i.e. of the sort -which may be utilized for encoder/decoder models when -the user desires to express both the encoder & decoder -prompts explicitly, i.e. -[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] - -A prompt of type [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] may be -employed as (1) input to a decoder-only model, (2) input to -the encoder of an encoder/decoder model, in the scenario -where the decoder-prompt is not specified explicitly, or -(3) as a member of a larger data structure encapsulating -more than one prompt, i.e. -[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] -""" - - -_T1_co = TypeVar( - "_T1_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True -) -_T2_co = TypeVar( - "_T2_co", bound=SingletonPrompt, default=SingletonPrompt, covariant=True -) - - -# TODO: Make fields ReadOnly once mypy supports it -class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): +# Outputs of processor +class _InputOptions(TypedDict): """ - Represents an encoder/decoder model input prompt, - comprising an explicit encoder prompt and a decoder prompt. - - The encoder and decoder prompts, respectively, may be formatted - according to any of the - [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas, - and are not required to have the same schema. - - Only the encoder prompt may have multi-modal data. mm_processor_kwargs - should be at the top-level, and should not be set in the encoder/decoder - prompts, since they are agnostic to the encoder/decoder. - - Note that an - [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] - may not be used as an input to a decoder-only model, - and that the `encoder_prompt` and `decoder_prompt` - fields of this data structure themselves must be - [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] instances. + Additional options available to all input types. """ - encoder_prompt: _T1_co - - decoder_prompt: _T2_co | None - - mm_processor_kwargs: NotRequired[dict[str, Any]] + cache_salt: NotRequired[str] + """Optional cache salt to be used for prefix caching.""" -PromptType: TypeAlias = SingletonPrompt | ExplicitEncoderDecoderPrompt[Any, Any] -""" -Set of possible schemas for an LLM input, including -both decoder-only and encoder/decoder input types: - -- A text prompt ([`str`][] or [`TextPrompt`][vllm.inputs.data.TextPrompt]) -- A tokenized prompt ([`TokensPrompt`][vllm.inputs.data.TokensPrompt]) -- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt]) -- A single data structure containing both an encoder and a decoder prompt - ([`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]) -""" - - -class TokenInputs(TypedDict): +class TokenInputs(_InputOptions): """Represents token-based inputs.""" type: Literal["token"] @@ -178,11 +203,6 @@ class TokenInputs(TypedDict): prompt_token_ids: list[int] """The token IDs of the prompt.""" - cache_salt: NotRequired[str] - """ - Optional cache salt to be used for prefix caching. - """ - def token_inputs( prompt_token_ids: list[int], @@ -198,7 +218,7 @@ def token_inputs( return inputs -class EmbedsInputs(TypedDict): +class EmbedsInputs(_InputOptions): """Represents embeddings-based inputs.""" type: Literal["embeds"] @@ -207,11 +227,6 @@ class EmbedsInputs(TypedDict): prompt_embeds: torch.Tensor """The embeddings of the prompt.""" - cache_salt: NotRequired[str] - """ - Optional cache salt to be used for prefix caching. - """ - def embeds_inputs( prompt_embeds: torch.Tensor, @@ -229,96 +244,60 @@ def embeds_inputs( DecoderOnlyInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs """ -The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they are -passed to the model executor. -This specifies the data required for decoder-only models. +A processed prompt from +[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor] +which can be passed to +[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor] +for decoder-only models. +""" + + +EncoderInputs: TypeAlias = TokenInputs | MultiModalEncDecInputs +""" +A processed encoder prompt from +[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor] +which can be passed to +[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor] +for encoder-decoder models. +""" + + +DecoderInputs: TypeAlias = TokenInputs | MultiModalInputs +""" +A processed decoder prompt from +[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor] +which can be passed to +[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor] +for encoder-decoder models. """ class EncoderDecoderInputs(TypedDict): """ - The inputs in [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] before they - are passed to the model executor. - - This specifies the required data for encoder-decoder models. + A processed pair of encoder and decoder singleton prompts. + [`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor] + which can be passed to + [`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor] + for encoder-decoder models. """ - encoder: TokenInputs | MultiModalEncDecInputs + encoder: EncoderInputs """The inputs for the encoder portion.""" - decoder: TokenInputs | MultiModalInputs + decoder: DecoderInputs """The inputs for the decoder portion.""" -SingletonInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs -""" -A processed [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] which can be -passed to [`Sequence`][collections.abc.Sequence]. -""" - ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs """ -The outputs from [`vllm.inputs.preprocess.InputPreprocessor`][]. +A processed prompt from +[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor] +which can be passed to +[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]. """ -_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) -_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) - -def build_explicit_enc_dec_prompt( - encoder_prompt: _T1, - decoder_prompt: _T2 | None, - mm_processor_kwargs: dict[str, Any] | None = None, -) -> ExplicitEncoderDecoderPrompt[_T1, _T2]: - if mm_processor_kwargs is None: - mm_processor_kwargs = {} - return ExplicitEncoderDecoderPrompt( - encoder_prompt=encoder_prompt, - decoder_prompt=decoder_prompt, - mm_processor_kwargs=mm_processor_kwargs, - ) - - -def zip_enc_dec_prompts( - enc_prompts: Iterable[_T1], - dec_prompts: Iterable[_T2 | None], - mm_processor_kwargs: Iterable[dict[str, Any]] | dict[str, Any] | None = None, -) -> list[ExplicitEncoderDecoderPrompt[_T1, _T2]]: - """ - Zip encoder and decoder prompts together into a list of - [`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt] - instances. - - `mm_processor_kwargs` may also be provided; if a dict is passed, the same - dictionary will be used for every encoder/decoder prompt. If an iterable is - provided, it will be zipped with the encoder/decoder prompts. - """ - if mm_processor_kwargs is None: - mm_processor_kwargs = cast(dict[str, Any], {}) - if isinstance(mm_processor_kwargs, dict): - return [ - build_explicit_enc_dec_prompt( - encoder_prompt, - decoder_prompt, - cast(dict[str, Any], mm_processor_kwargs), - ) - for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts) - ] - return [ - build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, mm_proc_kwargs) - for (encoder_prompt, decoder_prompt, mm_proc_kwargs) in zip( - enc_prompts, dec_prompts, mm_processor_kwargs - ) - ] - - -def to_enc_dec_tuple_list( - enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]], -) -> list[tuple[_T1, _T2 | None]]: - return [ - (enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"]) - for enc_dec_prompt in enc_dec_prompts - ] +SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs @dataclass diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 7cb1eb4b4..611a470ba 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -1,88 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict -from typing_extensions import TypeIs - -from vllm.utils import length_from_prompt_token_ids_or_embeds - -from .data import ( - EmbedsPrompt, - ExplicitEncoderDecoderPrompt, - ProcessorInputs, - PromptType, - SingletonInputs, - SingletonPrompt, - TextPrompt, - TokensPrompt, -) - -if TYPE_CHECKING: - import torch - - -class ParsedStrPrompt(TypedDict): - type: Literal["str"] - content: str - - -class ParsedTextPrompt(TypedDict): - type: Literal["text"] - content: TextPrompt - - -class ParsedTokensPrompt(TypedDict): - type: Literal["tokens"] - content: TokensPrompt - - -class ParsedEmbedsPrompt(TypedDict): - type: Literal["embeds"] - content: EmbedsPrompt - - -ParsedSingletonPrompt: TypeAlias = ( - ParsedStrPrompt | ParsedTextPrompt | ParsedTokensPrompt | ParsedEmbedsPrompt -) - - -def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt: - if isinstance(prompt, str): - return ParsedStrPrompt(type="str", content=prompt) - elif isinstance(prompt, dict): - # Type ignores are because mypy does not correctly infer the TypedDicts - # Pyright does succeed. - if "prompt_embeds" in prompt: - return ParsedEmbedsPrompt(type="embeds", content=prompt) # type: ignore[typeddict-item] - elif "prompt_token_ids" in prompt: - return ParsedTokensPrompt(type="tokens", content=prompt) # type: ignore[typeddict-item] - elif "prompt" in prompt: - return ParsedTextPrompt(type="text", content=prompt) - raise TypeError( - "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt" - ) - - -def is_explicit_encoder_decoder_prompt( - prompt: PromptType, -) -> TypeIs[ExplicitEncoderDecoderPrompt]: - return isinstance(prompt, dict) and "encoder_prompt" in prompt - - -def split_enc_dec_prompt( - prompt: PromptType, -) -> tuple[SingletonPrompt, SingletonPrompt | None]: - if isinstance(prompt, str): - return prompt, None - - if "encoder_prompt" in prompt and "decoder_prompt" in prompt: - # NOTE: This passes pyright but not mypy - return ( - prompt["encoder_prompt"], # type: ignore[typeddict-item] - prompt["decoder_prompt"], # type: ignore[typeddict-item] - ) - - return prompt, None +from .data import ProcessorInputs, SingletonInputs def split_enc_dec_inputs( @@ -96,30 +15,3 @@ def split_enc_dec_inputs( ) return None, inputs - - -class PromptComponents(NamedTuple): - text: str | None = None - token_ids: list[int] | None = None - embeds: "torch.Tensor | None" = None - - -def get_prompt_components(prompt: PromptType) -> PromptComponents: - if isinstance(prompt, str): - return PromptComponents(text=prompt) - - if encoder_prompt := prompt.get("encoder_prompt"): - return get_prompt_components(encoder_prompt) # type: ignore[arg-type] - - return PromptComponents( - text=prompt.get("prompt"), # type: ignore[arg-type] - token_ids=prompt.get("prompt_token_ids"), # type: ignore[arg-type] - embeds=prompt.get("prompt_embeds"), - ) - - -def get_prompt_len(prompt: TokensPrompt | EmbedsPrompt): - return length_from_prompt_token_ids_or_embeds( - prompt.get("prompt_token_ids"), # type: ignore[arg-type] - prompt.get("prompt_embeds"), # type: ignore[arg-type] - ) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 0a3b0c946..1d085cabb 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -2,43 +2,51 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping -from typing import Any +from typing import Any, overload from typing_extensions import assert_never from vllm.config import ModelConfig, ObservabilityConfig -from vllm.inputs.parse import split_enc_dec_prompt from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import ( MultiModalDataDict, - MultiModalEncDecInputs, MultiModalInputs, MultiModalUUIDDict, ) from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.renderers import renderer_from_config +from vllm.renderers.inputs import ( + DecoderDictPrompt, + DecoderOnlyDictPrompt, + DictPrompt, + EncoderDecoderDictPrompt, + EncoderDictPrompt, + SingletonDictPrompt, + TokPrompt, +) +from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt from vllm.tokenizers import TokenizerLike from vllm.utils.jsontree import json_iter_leaves from vllm.v1.metrics.stats import MultiModalCacheStats from .data import ( + DecoderInputs, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, EncoderDecoderInputs, + EncoderInputs, ProcessorInputs, PromptType, SingletonInputs, - SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt, embeds_inputs, token_inputs, ) -from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt logger = init_logger(__name__) @@ -328,9 +336,36 @@ class InputPreprocessor: return inputs + @overload def _prompt_to_llm_inputs( self, - prompt: SingletonPrompt, + prompt: EncoderDictPrompt, + tokenization_kwargs: dict[str, Any] | None = None, + *, + mm_uuids: MultiModalUUIDDict | None = None, + ) -> EncoderInputs: ... + + @overload + def _prompt_to_llm_inputs( # type: ignore[misc] + self, + prompt: DecoderDictPrompt, + tokenization_kwargs: dict[str, Any] | None = None, + *, + mm_uuids: MultiModalUUIDDict | None = None, + ) -> DecoderInputs: ... + + @overload + def _prompt_to_llm_inputs( # type: ignore[misc] + self, + prompt: DecoderOnlyDictPrompt, + tokenization_kwargs: dict[str, Any] | None = None, + *, + mm_uuids: MultiModalUUIDDict | None = None, + ) -> DecoderOnlyInputs: ... + + def _prompt_to_llm_inputs( + self, + prompt: SingletonDictPrompt, tokenization_kwargs: dict[str, Any] | None = None, *, mm_uuids: MultiModalUUIDDict | None = None, @@ -346,34 +381,25 @@ class InputPreprocessor: * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance """ - parsed = parse_singleton_prompt(prompt) + if "prompt_embeds" in prompt: + return self._process_embeds(prompt) # type: ignore[arg-type] - if parsed["type"] == "embeds": - return self._process_embeds(parsed["content"]) - if parsed["type"] == "tokens": + if "prompt_token_ids" in prompt: return self._process_tokens( - parsed["content"], + prompt, # type: ignore[arg-type] mm_uuids=mm_uuids, ) - if parsed["type"] == "text": + + if "prompt" in prompt: return self._process_text( - parsed["content"], - tokenization_kwargs=tokenization_kwargs, - mm_uuids=mm_uuids, - ) - if parsed["type"] == "str": - return self._process_text( - TextPrompt(prompt=parsed["content"]), + prompt, # type: ignore[arg-type] tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, ) - assert_never(parsed) + assert_never(prompt) # type: ignore[arg-type] - def _validate_enc_inputs( - self, - inputs: SingletonInputs, - ) -> TokenInputs | MultiModalEncDecInputs: + def _validate_enc_inputs(self, inputs: SingletonInputs) -> EncoderInputs: if inputs["type"] == "embeds": raise ValueError( "Embedding inputs are not supported for encoder-decoder models" @@ -387,10 +413,7 @@ class InputPreprocessor: return inputs # type: ignore[return-value] - def _validate_dec_inputs( - self, - inputs: SingletonInputs, - ) -> TokenInputs | MultiModalInputs: + def _validate_dec_inputs(self, inputs: SingletonInputs) -> DecoderInputs: if inputs["type"] == "embeds": raise ValueError( "Embedding inputs are not supported for encoder-decoder models" @@ -403,14 +426,15 @@ class InputPreprocessor: encoder_inputs: SingletonInputs, decoder_inputs: SingletonInputs | None = None, ) -> EncoderDecoderInputs: - if decoder_inputs is None: - decoder_inputs = encoder_inputs - enc_inputs = self._validate_enc_inputs(encoder_inputs) - dec_inputs = self._validate_dec_inputs(decoder_inputs) - enc_inputs_new: TokenInputs | MultiModalEncDecInputs - dec_inputs_new: TokenInputs | MultiModalInputs + if decoder_inputs is None: + dec_inputs: DecoderInputs = enc_inputs # type: ignore[assignment] + else: + dec_inputs = self._validate_dec_inputs(decoder_inputs) + + enc_inputs_new: EncoderInputs + dec_inputs_new: DecoderInputs if enc_inputs["type"] == "multimodal": enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"]) @@ -437,7 +461,7 @@ class InputPreprocessor: def _process_encoder_decoder_prompt( self, - prompt: PromptType, + prompt: EncoderDecoderDictPrompt, tokenization_kwargs: dict[str, Any] | None = None, *, mm_uuids: MultiModalUUIDDict | None = None, @@ -448,24 +472,6 @@ class InputPreprocessor: [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] instance. - There are two types of input prompts: - singleton prompts which carry only the - encoder prompt, and explicit encoder/decoder - prompts which carry both the encoder and the - decoder prompts as member variables. - - This function handles the following scenarios: - * Singleton encoder prompt: extract encoder prompt - token ids & infer default decoder prompt token ids - * Explicit encoder/decoder prompt: extract encoder - and decoder prompt token ids - - Note that for Explicit encoder/decoder prompts, - each sub-prompt (encoder or decoder prompt) can - have any possible singleton type; thus this - method relies on helper functions to obtain - token ids for the sub-prompts. - Arguments: * prompt: an input prompt @@ -475,7 +481,8 @@ class InputPreprocessor: * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] instance """ - encoder_prompt, decoder_prompt = split_enc_dec_prompt(prompt) + encoder_prompt = prompt["encoder_prompt"] + decoder_prompt = prompt["decoder_prompt"] return self._build_enc_dec_inputs( encoder_inputs=self._prompt_to_llm_inputs( @@ -495,7 +502,7 @@ class InputPreprocessor: def _process_decoder_only_prompt( self, - prompt: SingletonPrompt, + prompt: DecoderOnlyDictPrompt, tokenization_kwargs: dict[str, Any] | None = None, *, mm_uuids: MultiModalUUIDDict | None = None, @@ -521,7 +528,7 @@ class InputPreprocessor: def _preprocess( self, - prompt: PromptType, + prompt: PromptType | DictPrompt | TokPrompt, tokenization_kwargs: dict[str, Any] | None = None, *, mm_uuids: MultiModalUUIDDict | None = None, @@ -530,25 +537,20 @@ class InputPreprocessor: # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder. return self._process_encoder_decoder_prompt( - prompt, + parse_enc_dec_prompt(prompt), tokenization_kwargs, mm_uuids=mm_uuids, ) - if is_explicit_encoder_decoder_prompt(prompt): - raise ValueError( - "Cannot pass encoder-decoder prompt to decoder-only models" - ) - return self._process_decoder_only_prompt( - prompt, + parse_dec_only_prompt(prompt), tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, ) def preprocess( self, - prompt: PromptType, + prompt: PromptType | DictPrompt | TokPrompt, tokenization_kwargs: dict[str, Any] | None = None, *, mm_uuids: MultiModalUUIDDict | None = None, diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 9f01af9a8..a3f8b21c2 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -20,7 +20,7 @@ from typing import ( import numpy as np from PIL.Image import Image -from typing_extensions import NotRequired, TypeVar +from typing_extensions import TypeVar from vllm.utils.collection_utils import is_list_of from vllm.utils.import_utils import LazyLoader @@ -32,9 +32,13 @@ if TYPE_CHECKING: import torch import torch.types from transformers.feature_extraction_utils import BatchFeature + + from vllm.inputs.data import _InputOptions else: torch = LazyLoader("torch", globals(), "torch") + _InputOptions = dict + _T = TypeVar("_T") HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"] @@ -1059,7 +1063,7 @@ A dictionary containing per-item placeholder ranges for each modality. """ -class MultiModalInputs(TypedDict): +class MultiModalInputs(_InputOptions): """ Represents the outputs of [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor], @@ -1084,11 +1088,6 @@ class MultiModalInputs(TypedDict): `prompt_token_ids`. """ - cache_salt: NotRequired[str] - """ - Optional cache salt to be used for prefix caching. - """ - class MultiModalEncDecInputs(MultiModalInputs): """ diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index d50d2f69c..45dde6e47 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from vllm.config import VllmConfig from vllm.inputs import ProcessorInputs, PromptType from vllm.pooling_params import PoolingParams + from vllm.renderers.inputs import DictPrompt, TokPrompt from vllm.sampling_params import SamplingParams from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.v1.attention.selector import AttentionSelectorConfig @@ -565,7 +566,7 @@ class Platform: @classmethod def validate_request( cls, - prompt: "PromptType", + prompt: "PromptType | DictPrompt | TokPrompt", params: "SamplingParams | PoolingParams", processed_inputs: "ProcessorInputs", ) -> None: diff --git a/vllm/renderers/deepseek_v32.py b/vllm/renderers/deepseek_v32.py index 9e6008c55..f83edd16f 100644 --- a/vllm/renderers/deepseek_v32.py +++ b/vllm/renderers/deepseek_v32.py @@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import ( parse_chat_messages, parse_chat_messages_async, ) -from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.logger import init_logger from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer +from .inputs import DictPrompt +from .inputs.preprocess import parse_dec_only_prompt from .params import ChatParams from .protocol import BaseRenderer @@ -61,7 +62,7 @@ class DeepseekV32Renderer(BaseRenderer): self, messages: list[ChatCompletionMessageParam], params: ChatParams, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: + ) -> tuple[list[ConversationMessage], DictPrompt]: tokenizer = self.get_tokenizer() conversation, mm_data, mm_uuids = parse_chat_messages( messages, @@ -75,7 +76,7 @@ class DeepseekV32Renderer(BaseRenderer): **params.get_apply_chat_template_kwargs(), ) - prompt = self.render_completion(prompt_raw) + prompt = parse_dec_only_prompt(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: @@ -87,7 +88,7 @@ class DeepseekV32Renderer(BaseRenderer): self, messages: list[ChatCompletionMessageParam], params: ChatParams, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: + ) -> tuple[list[ConversationMessage], DictPrompt]: tokenizer = self.get_tokenizer() conversation, mm_data, mm_uuids = await parse_chat_messages_async( messages, @@ -101,7 +102,7 @@ class DeepseekV32Renderer(BaseRenderer): **params.get_apply_chat_template_kwargs(), ) - prompt = self.render_completion(prompt_raw) + prompt = parse_dec_only_prompt(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: diff --git a/vllm/renderers/grok2.py b/vllm/renderers/grok2.py index c1b56b397..c5c3afe86 100644 --- a/vllm/renderers/grok2.py +++ b/vllm/renderers/grok2.py @@ -9,11 +9,12 @@ from vllm.entrypoints.chat_utils import ( parse_chat_messages, parse_chat_messages_async, ) -from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.logger import init_logger from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers.grok2 import Grok2Tokenizer +from .inputs import DictPrompt +from .inputs.preprocess import parse_dec_only_prompt from .params import ChatParams from .protocol import BaseRenderer @@ -61,7 +62,7 @@ class Grok2Renderer(BaseRenderer): self, messages: list[ChatCompletionMessageParam], params: ChatParams, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: + ) -> tuple[list[ConversationMessage], DictPrompt]: tokenizer = self.get_tokenizer() conversation, mm_data, mm_uuids = parse_chat_messages( messages, @@ -75,7 +76,7 @@ class Grok2Renderer(BaseRenderer): **params.get_apply_chat_template_kwargs(), ) - prompt = self.render_completion(prompt_raw) + prompt = parse_dec_only_prompt(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: @@ -87,7 +88,7 @@ class Grok2Renderer(BaseRenderer): self, messages: list[ChatCompletionMessageParam], params: ChatParams, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: + ) -> tuple[list[ConversationMessage], DictPrompt]: tokenizer = self.get_tokenizer() conversation, mm_data, mm_uuids = await parse_chat_messages_async( messages, @@ -101,7 +102,7 @@ class Grok2Renderer(BaseRenderer): **params.get_apply_chat_template_kwargs(), ) - prompt = self.render_completion(prompt_raw) + prompt = parse_dec_only_prompt(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: diff --git a/vllm/renderers/hf.py b/vllm/renderers/hf.py index 079dd0953..5425bd888 100644 --- a/vllm/renderers/hf.py +++ b/vllm/renderers/hf.py @@ -25,7 +25,6 @@ from vllm.entrypoints.chat_utils import ( parse_chat_messages, parse_chat_messages_async, ) -from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.logger import init_logger from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer @@ -33,6 +32,8 @@ from vllm.transformers_utils.chat_templates import get_chat_template_fallback_pa from vllm.transformers_utils.processor import cached_get_processor from vllm.utils.func_utils import supports_kw +from .inputs import DictPrompt +from .inputs.preprocess import parse_dec_only_prompt from .params import ChatParams from .protocol import BaseRenderer @@ -632,7 +633,7 @@ class HfRenderer(BaseRenderer): self, messages: list[ChatCompletionMessageParam], params: ChatParams, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: + ) -> tuple[list[ConversationMessage], DictPrompt]: model_config = self.config tokenizer = self.get_tokenizer() @@ -674,7 +675,7 @@ class HfRenderer(BaseRenderer): video_placeholder, ) - prompt = self.render_completion(prompt_raw) + prompt = parse_dec_only_prompt(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: @@ -686,7 +687,7 @@ class HfRenderer(BaseRenderer): self, messages: list[ChatCompletionMessageParam], params: ChatParams, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: + ) -> tuple[list[ConversationMessage], DictPrompt]: model_config = self.config tokenizer = self.get_tokenizer() @@ -726,7 +727,7 @@ class HfRenderer(BaseRenderer): video_placeholder, ) - prompt = self.render_completion(prompt_raw) + prompt = parse_dec_only_prompt(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: diff --git a/vllm/renderers/inputs/__init__.py b/vllm/renderers/inputs/__init__.py new file mode 100644 index 000000000..31abda66a --- /dev/null +++ b/vllm/renderers/inputs/__init__.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .preprocess import ( + DecoderDictPrompt, + DecoderOnlyDictPrompt, + DictPrompt, + EncoderDecoderDictPrompt, + EncoderDictPrompt, + SingletonDictPrompt, +) +from .tokenize import ( + DecoderOnlyTokPrompt, + DecoderTokPrompt, + EncoderDecoderTokPrompt, + EncoderTokPrompt, + SingletonTokPrompt, + TokPrompt, +) + +__all__ = [ + "DecoderOnlyDictPrompt", + "EncoderDictPrompt", + "DecoderDictPrompt", + "EncoderDecoderDictPrompt", + "SingletonDictPrompt", + "DictPrompt", + "DecoderOnlyTokPrompt", + "EncoderTokPrompt", + "DecoderTokPrompt", + "EncoderDecoderTokPrompt", + "SingletonTokPrompt", + "TokPrompt", +] diff --git a/vllm/renderers/inputs/preprocess.py b/vllm/renderers/inputs/preprocess.py new file mode 100644 index 000000000..eaac6aeb5 --- /dev/null +++ b/vllm/renderers/inputs/preprocess.py @@ -0,0 +1,255 @@ +""" +Schemas and utilites for preprocessing inputs. +""" + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence +from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload + +from vllm.inputs import ( + EmbedsPrompt, + ExplicitEncoderDecoderPrompt, + PromptType, + SingletonPrompt, + TextPrompt, + TokensPrompt, +) +from vllm.utils import length_from_prompt_token_ids_or_embeds +from vllm.utils.collection_utils import is_list_of + +if TYPE_CHECKING: + import torch + + from vllm.config import ModelConfig + from vllm.entrypoints.chat_utils import ChatCompletionMessageParam + + +@overload +def prompt_to_seq( + prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes], +) -> Sequence[SingletonPrompt]: ... + + +@overload +def prompt_to_seq( # type: ignore[misc] + prompt_or_prompts: ExplicitEncoderDecoderPrompt + | Sequence[ExplicitEncoderDecoderPrompt], +) -> Sequence[ExplicitEncoderDecoderPrompt]: ... + + +@overload +def prompt_to_seq( # type: ignore[misc] + prompt_or_prompts: PromptType | Sequence[PromptType], +) -> Sequence[PromptType]: ... + + +def prompt_to_seq( + prompt_or_prompts: PromptType | bytes | Sequence[PromptType | bytes], +) -> Sequence[PromptType]: + if isinstance(prompt_or_prompts, (dict, str, bytes)) or ( + len(prompt_or_prompts) > 0 and is_list_of(prompt_or_prompts, int) + ): + return [prompt_or_prompts] # type: ignore[list-item] + + return prompt_or_prompts # type: ignore[return-value] + + +def conversation_to_seq( + conversation_or_conversations: list["ChatCompletionMessageParam"] + | Sequence[list["ChatCompletionMessageParam"]], +) -> Sequence[list["ChatCompletionMessageParam"]]: + if len(conversation_or_conversations) > 0 and is_list_of( + conversation_or_conversations, dict + ): + return [conversation_or_conversations] # type: ignore[list-item] + + return conversation_or_conversations # type: ignore[return-value] + + +DecoderOnlyDictPrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt +""" +A [`DecoderOnlyPrompt`][vllm.inputs.data.DecoderOnlyPrompt] +that has been standardized into a dictionary. +""" + + +EncoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt +""" +A [`EncoderPrompt`][vllm.inputs.data.EncoderPrompt] +that has been standardized into a dictionary. +""" + + +DecoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt +""" +A [`DecoderPrompt`][vllm.inputs.data.DecoderPrompt] +that has been standardized into a dictionary. +""" + + +class EncoderDecoderDictPrompt(TypedDict): + """ + A [`EncoderDecoderPrompt`][vllm.inputs.data.EncoderDecoderPrompt] + that has been standardized into a dictionary. + """ + + encoder_prompt: EncoderDictPrompt + + decoder_prompt: DecoderDictPrompt | None + + +SingletonDictPrompt: TypeAlias = ( + DecoderOnlyDictPrompt | EncoderDictPrompt | DecoderDictPrompt +) +""" +A [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] +that has been standardized into a dictionary. +""" + + +DictPrompt: TypeAlias = DecoderOnlyDictPrompt | EncoderDecoderDictPrompt +""" +A [`PromptType`][vllm.inputs.data.PromptType] +that has been standardized into a dictionary. +""" + + +def parse_dec_only_prompt(prompt: object) -> DecoderOnlyDictPrompt: + """ + Parse a prompt for a decoder-only model and normalize it to a dictionary. + """ + if isinstance(prompt, str): + return TextPrompt(prompt=prompt) + + if isinstance(prompt, list): + if not is_list_of(prompt, int): + raise TypeError("Token prompt should be a list of integers") + + return TokensPrompt(prompt_token_ids=prompt) + + if isinstance(prompt, dict): + if "encoder_prompt" in prompt: + raise TypeError("Cannot pass encoder-decoder prompt to decoder-only models") + + if ( + "prompt" in prompt + or "prompt_token_ids" in prompt + or "prompt_embeds" in prompt + ): + return prompt # type: ignore[return-value] + + raise TypeError("Prompt dictionary must contain text, tokens, or embeddings") + + raise TypeError("Prompt should be a string, list of tokens, or dictionary") + + +def _parse_enc_prompt(prompt: object) -> EncoderDictPrompt: + if isinstance(prompt, str): + return TextPrompt(prompt=prompt) + + if isinstance(prompt, list): + if not is_list_of(prompt, int): + raise TypeError("Token prompt should be a list of integers") + + return TokensPrompt(prompt_token_ids=prompt) + + if isinstance(prompt, dict): + if "prompt_embeds" in prompt: + raise TypeError("Cannot pass embeddings prompt to encoder-decoder models") + + if "prompt" in prompt or "prompt_token_ids" in prompt: + return prompt # type: ignore[return-value] + + raise TypeError("Prompt dictionary must contain text or tokens") + + raise TypeError("Prompt should be a string, list of tokens, or dictionary") + + +def _parse_dec_prompt(prompt: object) -> DecoderDictPrompt: + if isinstance(prompt, str): + return TextPrompt(prompt=prompt) + + if isinstance(prompt, list): + if not is_list_of(prompt, int): + raise TypeError("Token prompt should be a list of integers") + + return TokensPrompt(prompt_token_ids=prompt) + + if isinstance(prompt, dict): + if "prompt_embeds" in prompt: + raise TypeError("Cannot pass embeddings prompt to encoder-decoder models") + + if ( + "multi_modal_data" in prompt + or "mm_processor_kwargs" in prompt + or "multi_modal_uuids" in prompt + ): + raise TypeError("Cannot pass multi-modal inputs to decoder prompt") + + if "prompt" in prompt or "prompt_token_ids" in prompt: + return prompt # type: ignore[return-value] + + raise TypeError("Prompt dictionary must contain text or tokens") + + raise TypeError("Prompt should be a string, list of tokens, or dictionary") + + +def parse_enc_dec_prompt(prompt: object) -> EncoderDecoderDictPrompt: + """ + Parse a prompt for an encoder-decoder model and normalize it to a dictionary. + """ + if isinstance(prompt, dict) and "encoder_prompt" in prompt: + enc_prompt: object = prompt["encoder_prompt"] # type: ignore[typeddict-item] + dec_prompt: object | None = prompt["decoder_prompt"] # type: ignore[typeddict-item] + else: + enc_prompt = prompt + dec_prompt = None + + return EncoderDecoderDictPrompt( + encoder_prompt=_parse_enc_prompt(enc_prompt), + decoder_prompt=None if dec_prompt is None else _parse_dec_prompt(dec_prompt), + ) + + +def parse_model_prompt(model_config: "ModelConfig", prompt: object): + if model_config.is_encoder_decoder: + return parse_enc_dec_prompt(prompt) + + return parse_dec_only_prompt(prompt) + + +class PromptComponents(NamedTuple): + text: str | None = None + token_ids: list[int] | None = None + embeds: "torch.Tensor | None" = None + + +def extract_prompt_components( + model_config: "ModelConfig", + prompt: object, +) -> PromptComponents: + target_prompt = ( + parse_enc_dec_prompt(prompt)["encoder_prompt"] + if model_config.is_encoder_decoder + else parse_dec_only_prompt(prompt) + ) + + return PromptComponents( + text=target_prompt.get("prompt"), + token_ids=target_prompt.get("prompt_token_ids"), # type: ignore[arg-type] + embeds=target_prompt.get("prompt_embeds"), + ) + + +def extract_prompt_len(model_config: "ModelConfig", prompt: object): + target_prompt = ( + parse_enc_dec_prompt(prompt)["encoder_prompt"] + if model_config.is_encoder_decoder + else parse_dec_only_prompt(prompt) + ) + + return length_from_prompt_token_ids_or_embeds( + target_prompt.get("prompt_token_ids"), # type: ignore[arg-type] + target_prompt.get("prompt_embeds"), + ) diff --git a/vllm/renderers/inputs/tokenize.py b/vllm/renderers/inputs/tokenize.py new file mode 100644 index 000000000..3734fac99 --- /dev/null +++ b/vllm/renderers/inputs/tokenize.py @@ -0,0 +1,57 @@ +""" +Schemas and utilites for tokenization inputs. +""" + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TypeAlias, TypedDict + +from vllm.inputs import EmbedsPrompt, TokensPrompt + +DecoderOnlyTokPrompt: TypeAlias = TokensPrompt | EmbedsPrompt +""" +A [`DecoderOnlyDictPrompt`][vllm.renderers.inputs.preprocess.DecoderOnlyDictPrompt] +that has been tokenized. +""" + + +EncoderTokPrompt: TypeAlias = TokensPrompt +""" +A [`EncoderDictPrompt`][vllm.renderers.inputs.preprocess.EncoderDictPrompt] +that has been tokenized. +""" + + +DecoderTokPrompt: TypeAlias = TokensPrompt +""" +A [`DecoderDictPrompt`][vllm.renderers.inputs.preprocess.DecoderDictPrompt] +that has been tokenized. +""" + + +class EncoderDecoderTokPrompt(TypedDict): + """ + A + [`EncoderDecoderDictPrompt`][vllm.renderers.inputs.preprocess.EncoderDecoderDictPrompt] + that has been tokenized. + """ + + encoder_prompt: EncoderTokPrompt + + decoder_prompt: DecoderTokPrompt | None + + +SingletonTokPrompt: TypeAlias = ( + DecoderOnlyTokPrompt | EncoderTokPrompt | DecoderTokPrompt +) +""" +A [`SingletonDictPrompt`][vllm.renderers.inputs.preprocess.SingletonDictPrompt] +that has been tokenized. +""" + + +TokPrompt: TypeAlias = DecoderOnlyTokPrompt | EncoderDecoderTokPrompt +""" +A [`DictPrompt`][vllm.renderers.inputs.preprocess.DictPrompt] +that has been tokenized. +""" diff --git a/vllm/renderers/mistral.py b/vllm/renderers/mistral.py index dcccb09b9..0d15b37e0 100644 --- a/vllm/renderers/mistral.py +++ b/vllm/renderers/mistral.py @@ -10,12 +10,13 @@ from vllm.entrypoints.chat_utils import ( parse_chat_messages, parse_chat_messages_async, ) -from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.logger import init_logger from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils.async_utils import make_async +from .inputs import DictPrompt +from .inputs.preprocess import parse_dec_only_prompt from .params import ChatParams from .protocol import BaseRenderer @@ -95,7 +96,7 @@ class MistralRenderer(BaseRenderer): self, messages: list[ChatCompletionMessageParam], params: ChatParams, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: + ) -> tuple[list[ConversationMessage], DictPrompt]: tokenizer = self.get_tokenizer() conversation, mm_data, mm_uuids = parse_chat_messages( messages, @@ -109,7 +110,7 @@ class MistralRenderer(BaseRenderer): **params.get_apply_chat_template_kwargs(), ) - prompt = self.render_completion(prompt_raw) + prompt = parse_dec_only_prompt(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: @@ -121,7 +122,7 @@ class MistralRenderer(BaseRenderer): self, messages: list[ChatCompletionMessageParam], params: ChatParams, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: + ) -> tuple[list[ConversationMessage], DictPrompt]: tokenizer = self.get_tokenizer() conversation, mm_data, mm_uuids = await parse_chat_messages_async( messages, @@ -135,7 +136,7 @@ class MistralRenderer(BaseRenderer): **params.get_apply_chat_template_kwargs(), ) - prompt = self.render_completion(prompt_raw) + prompt = parse_dec_only_prompt(prompt_raw) if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: diff --git a/vllm/renderers/protocol.py b/vllm/renderers/protocol.py index 676026965..5d84ac546 100644 --- a/vllm/renderers/protocol.py +++ b/vllm/renderers/protocol.py @@ -2,14 +2,20 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, overload from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.tokenizers import TokenizerLike from vllm.utils.async_utils import AsyncMicrobatchTokenizer -from vllm.utils.collection_utils import is_list_of from .embed_utils import safe_load_prompt_embeds +from .inputs import ( + DictPrompt, + EncoderDecoderDictPrompt, + EncoderDecoderTokPrompt, + TokPrompt, +) from .params import ChatParams, TokenizeParams if TYPE_CHECKING: @@ -57,140 +63,217 @@ class BaseRenderer(ABC): return self._async_tokenizer # Step 1: Convert raw inputs to prompts - def render_completion( + def render_prompt( self, - prompt_raw: str | list[int] | bytes, - ) -> TextPrompt | TokensPrompt | EmbedsPrompt: - error_msg = "Each prompt must be a string or an array of tokens" + prompt: DictPrompt | bytes, + ) -> DictPrompt: + if isinstance(prompt, bytes): + embeds = safe_load_prompt_embeds(self.config, prompt) + prompt = EmbedsPrompt(prompt_embeds=embeds) - if isinstance(prompt_raw, str): - return TextPrompt(prompt=prompt_raw) + return prompt - if isinstance(prompt_raw, list): - if not is_list_of(prompt_raw, int): - raise TypeError(error_msg) - - return TokensPrompt(prompt_token_ids=prompt_raw) - - if isinstance(prompt_raw, bytes): - embeds = safe_load_prompt_embeds(self.config, prompt_raw) - return EmbedsPrompt(prompt_embeds=embeds) - - raise TypeError(error_msg) - - def render_completions( + def render_prompts( self, - prompt_input: str | list[str] | list[int] | list[list[int]] | None = None, - prompt_embeds: bytes | list[bytes] | None = None, - ) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]: - prompts_raw = list[str | list[int] | bytes]() - - if prompt_embeds is not None: # embeds take higher priority - if isinstance(prompt_embeds, bytes): - prompts_raw.append(prompt_embeds) - else: - prompts_raw.extend(prompt_embeds) - - if prompt_input is not None: - if isinstance(prompt_input, str) or ( - len(prompt_input) > 0 and is_list_of(prompt_input, int) - ): - prompts_raw.append(prompt_input) # type: ignore[arg-type] - else: - prompts_raw.extend(prompt_input) # type: ignore[arg-type] - - if len(prompts_raw) == 0: + prompts: Sequence[DictPrompt | bytes], + ) -> list[DictPrompt]: + if len(prompts) == 0: raise ValueError("You must pass at least one prompt") - return [self.render_completion(prompt) for prompt in prompts_raw] + return [self.render_prompt(prompt) for prompt in prompts] - async def render_completions_async( + async def render_prompts_async( self, - prompt_input: str | list[str] | list[int] | list[list[int]] | None = None, - prompt_embeds: bytes | list[bytes] | None = None, - ) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]: - return self.render_completions(prompt_input, prompt_embeds) + prompts: Sequence[DictPrompt | bytes], + ) -> list[DictPrompt]: + return self.render_prompts(prompts) @abstractmethod def render_messages( self, messages: list["ChatCompletionMessageParam"], params: ChatParams, - ) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]: + ) -> tuple[list["ConversationMessage"], DictPrompt]: raise NotImplementedError async def render_messages_async( self, messages: list["ChatCompletionMessageParam"], params: ChatParams, - ) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt | EmbedsPrompt]: + ) -> tuple[list["ConversationMessage"], DictPrompt]: return self.render_messages(messages, params) # Step 2: Tokenize prompts if necessary + def _tokenize_prompt( + self, + prompt: TextPrompt, + params: TokenizeParams, + ) -> TokensPrompt: + tokenizer = self.get_tokenizer() + prompt_token_ids = tokenizer.encode( + prompt["prompt"], + **params.get_encode_kwargs(), + ) + + return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt) + + async def _tokenize_prompt_async( + self, + prompt: TextPrompt, + params: TokenizeParams, + ) -> TokensPrompt: + tokenizer = self.get_async_tokenizer() + prompt_token_ids = await tokenizer.encode( + prompt["prompt"], + **params.get_encode_kwargs(), + ) + + return TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt) + + def _detokenize_prompt(self, prompt: TokensPrompt) -> TokensPrompt: + tokenizer = self.get_tokenizer() + prompt["prompt"] = tokenizer.decode(prompt["prompt_token_ids"]) + + return prompt + + async def _detokenize_prompt_async(self, prompt: TokensPrompt) -> TokensPrompt: + tokenizer = self.get_async_tokenizer() + prompt["prompt"] = await tokenizer.decode(prompt["prompt_token_ids"]) + + return prompt + + def _tokenize_enc_dec_prompt( + self, + prompt: EncoderDecoderDictPrompt, + params: TokenizeParams, + ) -> EncoderDecoderTokPrompt: + enc_prompt, dec_prompt = ( + self.tokenize_prompt(prompt["encoder_prompt"], params), + ( + None + if prompt["decoder_prompt"] is None + else self.tokenize_prompt(prompt["decoder_prompt"], params) + ), + ) + + return EncoderDecoderTokPrompt( + encoder_prompt=enc_prompt, + decoder_prompt=dec_prompt, + ) + + async def _tokenize_enc_dec_prompt_async( + self, + prompt: EncoderDecoderDictPrompt, + params: TokenizeParams, + ) -> EncoderDecoderTokPrompt: + enc_prompt, dec_prompt = await asyncio.gather( + self.tokenize_prompt_async(prompt["encoder_prompt"], params), + ( + asyncio.sleep(0) + if prompt["decoder_prompt"] is None + else self.tokenize_prompt_async(prompt["decoder_prompt"], params) + ), + ) + + return EncoderDecoderTokPrompt( + encoder_prompt=enc_prompt, + decoder_prompt=dec_prompt, + ) + + @overload def tokenize_prompt( self, - prompt: TextPrompt | TokensPrompt | EmbedsPrompt, + prompt: TextPrompt | TokensPrompt, params: TokenizeParams, - ) -> TokensPrompt | EmbedsPrompt: + ) -> TokensPrompt: ... + + @overload + def tokenize_prompt( # type: ignore[misc] + self, + prompt: EmbedsPrompt, + params: TokenizeParams, + ) -> EmbedsPrompt: ... + + @overload + def tokenize_prompt( # type: ignore[misc] + self, + prompt: EncoderDecoderDictPrompt, + params: TokenizeParams, + ) -> EncoderDecoderTokPrompt: ... + + def tokenize_prompt( + self, + prompt: DictPrompt, + params: TokenizeParams, + ) -> TokPrompt: + if "encoder_prompt" in prompt: + return self._tokenize_enc_dec_prompt(prompt, params) # type: ignore[arg-type] + if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt: prompt = params.apply_pre_tokenization(self.tokenizer, prompt) - - tokenizer = self.get_tokenizer() - prompt_token_ids = tokenizer.encode( - prompt["prompt"], - **params.get_encode_kwargs(), - ) - - prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt) + prompt = self._tokenize_prompt(prompt, params) if params.needs_detokenization and "prompt" not in prompt: if "prompt_token_ids" not in prompt: raise RuntimeError("Cannot run detokenization on embeddings") - tokenizer = self.get_tokenizer() - prompt_text = tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item] - prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key] + prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type] return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type] def tokenize_prompts( self, - prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt], + prompts: Sequence[DictPrompt], params: TokenizeParams, - ) -> list[TokensPrompt | EmbedsPrompt]: + ) -> list[TokPrompt]: return [self.tokenize_prompt(prompt, params) for prompt in prompts] + @overload + async def tokenize_prompt_async( + self, + prompt: TextPrompt | TokensPrompt, + params: TokenizeParams, + ) -> TokensPrompt: ... + + @overload + async def tokenize_prompt_async( # type: ignore[misc] + self, + prompt: EmbedsPrompt, + params: TokenizeParams, + ) -> EmbedsPrompt: ... + + @overload + async def tokenize_prompt_async( # type: ignore[misc] + self, + prompt: EncoderDecoderDictPrompt, + params: TokenizeParams, + ) -> EncoderDecoderTokPrompt: ... + async def tokenize_prompt_async( self, - prompt: TextPrompt | TokensPrompt | EmbedsPrompt, + prompt: DictPrompt, params: TokenizeParams, - ) -> TokensPrompt | EmbedsPrompt: + ) -> TokPrompt: + if "encoder_prompt" in prompt: + return await self._tokenize_enc_dec_prompt_async(prompt, params) # type: ignore[arg-type] + if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt: prompt = params.apply_pre_tokenization(self.tokenizer, prompt) - - tokenizer = self.get_async_tokenizer() - prompt_token_ids = await tokenizer.encode( - prompt["prompt"], - **params.get_encode_kwargs(), - ) - - prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, **prompt) + prompt = await self._tokenize_prompt_async(prompt, params) if params.needs_detokenization and "prompt" not in prompt: if "prompt_token_ids" not in prompt: raise RuntimeError("Cannot run detokenization on embeddings") - tokenizer = self.get_async_tokenizer() - prompt_text = await tokenizer.decode(prompt["prompt_token_ids"]) # type: ignore[typeddict-item] - prompt["prompt"] = prompt_text # type: ignore[typeddict-unknown-key] + prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type] return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type] async def tokenize_prompts_async( self, - prompts: list[TextPrompt | TokensPrompt | EmbedsPrompt], + prompts: Sequence[DictPrompt], params: TokenizeParams, - ) -> list[TokensPrompt | EmbedsPrompt]: + ) -> list[TokPrompt]: return await asyncio.gather( *(self.tokenize_prompt_async(prompt, params) for prompt in prompts) ) diff --git a/vllm/renderers/terratorch.py b/vllm/renderers/terratorch.py index 8753d1f0f..58c1459d2 100644 --- a/vllm/renderers/terratorch.py +++ b/vllm/renderers/terratorch.py @@ -9,10 +9,11 @@ from vllm.entrypoints.chat_utils import ( parse_chat_messages, parse_chat_messages_async, ) -from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike +from .inputs import DictPrompt +from .inputs.preprocess import parse_dec_only_prompt from .params import ChatParams from .protocol import BaseRenderer @@ -45,7 +46,7 @@ class TerratorchRenderer(BaseRenderer): self, messages: list[ChatCompletionMessageParam], params: ChatParams, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: + ) -> tuple[list[ConversationMessage], DictPrompt]: model_config = self.config conversation, mm_data, mm_uuids = parse_chat_messages( @@ -54,7 +55,7 @@ class TerratorchRenderer(BaseRenderer): content_format="string", ) - prompt = self.render_completion([1]) # Dummy token IDs + prompt = parse_dec_only_prompt([1]) # Dummy token IDs if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: @@ -66,7 +67,7 @@ class TerratorchRenderer(BaseRenderer): self, messages: list[ChatCompletionMessageParam], params: ChatParams, - ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt | EmbedsPrompt]: + ) -> tuple[list[ConversationMessage], DictPrompt]: model_config = self.config conversation, mm_data, mm_uuids = await parse_chat_messages_async( @@ -75,7 +76,7 @@ class TerratorchRenderer(BaseRenderer): content_format="string", ) - prompt = self.render_completion([1]) # Dummy token IDs + prompt = parse_dec_only_prompt([1]) # Dummy token IDs if mm_data is not None: prompt["multi_modal_data"] = mm_data if mm_uuids is not None: diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 28c95777c..e6da4f335 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -28,6 +28,8 @@ from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.renderers import BaseRenderer, merge_kwargs +from vllm.renderers.inputs import DictPrompt, TokPrompt +from vllm.renderers.inputs.preprocess import extract_prompt_components from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.tasks import SupportedTask from vllm.tokenizers import TokenizerLike @@ -42,7 +44,6 @@ from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector from vllm.v1.engine.parallel_sampling import ParentRequest -from vllm.v1.engine.utils import get_prompt_text from vllm.v1.executor import Executor from vllm.v1.metrics.loggers import ( StatLoggerFactory, @@ -284,7 +285,11 @@ class AsyncLLM(EngineClient): async def add_request( self, request_id: str, - prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None], + prompt: EngineCoreRequest + | PromptType + | DictPrompt + | TokPrompt + | AsyncGenerator[StreamingInput, None], params: SamplingParams | PoolingParams, arrival_time: float | None = None, lora_request: LoRARequest | None = None, @@ -367,7 +372,7 @@ class AsyncLLM(EngineClient): data_parallel_rank=data_parallel_rank, supported_tasks=await self.get_supported_tasks(), ) - prompt_text = get_prompt_text(prompt) + prompt_text, _, _ = extract_prompt_components(self.model_config, prompt) self.input_processor.assign_request_id(request) @@ -484,7 +489,9 @@ class AsyncLLM(EngineClient): raise ValueError( "prompt_embeds not supported for streaming inputs" ) - prompt_text = get_prompt_text(input_chunk.prompt) + prompt_text, _, _ = extract_prompt_components( + self.model_config, input_chunk.prompt + ) await self._add_request(req, prompt_text, None, 0, queue) except (asyncio.CancelledError, GeneratorExit): cancelled = True @@ -528,7 +535,11 @@ class AsyncLLM(EngineClient): # re-multiplexed in the API server anyhow. async def generate( self, - prompt: EngineCoreRequest | PromptType | AsyncGenerator[StreamingInput, None], + prompt: EngineCoreRequest + | PromptType + | DictPrompt + | TokPrompt + | AsyncGenerator[StreamingInput, None], sampling_params: SamplingParams, request_id: str, *, @@ -769,7 +780,7 @@ class AsyncLLM(EngineClient): async def encode( self, - prompt: PromptType, + prompt: PromptType | DictPrompt | TokPrompt, pooling_params: PoolingParams, request_id: str, lora_request: LoRARequest | None = None, diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index 01cf999c7..64dc1831b 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -7,14 +7,13 @@ from typing import Any, Literal, cast from vllm.config import VllmConfig from vllm.exceptions import VLLMValidationError -from vllm.inputs import ( +from vllm.inputs.data import ( ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, - TextPrompt, ) -from vllm.inputs.parse import is_explicit_encoder_decoder_prompt, split_enc_dec_inputs +from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -30,6 +29,7 @@ from vllm.multimodal.processing.context import set_request_id from vllm.multimodal.utils import argsort_mm_positions from vllm.pooling_params import PoolingParams from vllm.renderers import BaseRenderer +from vllm.renderers.inputs import DictPrompt, TokPrompt from vllm.sampling_params import _SAMPLING_EPS, SamplingParams from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tokenizers import TokenizerLike @@ -243,8 +243,8 @@ class InputProcessor: return mm_processor.info.parse_mm_data(mm_data) def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None: - if isinstance(prompt, str): - prompt = TextPrompt(prompt=prompt) + if not isinstance(prompt, dict): + return mm_data = cast(MultiModalDataDict, prompt.get("multi_modal_data") or {}) mm_uuids = cast(MultiModalUUIDDict, prompt.get("multi_modal_uuids") or {}) @@ -297,7 +297,7 @@ class InputProcessor: f"multi_modal_uuids[{modality!r}] is missing." ) - def _validate_mm_uuids(self, prompt: PromptType) -> None: + def _validate_mm_uuids(self, prompt: PromptType | DictPrompt | TokPrompt) -> None: """ Validate that user-provided multi_modal_uuids align with multi_modal_data in the incoming request prompt(s). @@ -305,10 +305,10 @@ class InputProcessor: auto-hashed downstream. """ - if is_explicit_encoder_decoder_prompt(prompt): - self._validate_singleton_mm_uuids(prompt["encoder_prompt"]) + if isinstance(prompt, dict) and "encoder_prompt" in prompt: + self._validate_singleton_mm_uuids(prompt["encoder_prompt"]) # type: ignore[typeddict-item] - if (dec_prompt := prompt["decoder_prompt"]) is not None: + if (dec_prompt := prompt["decoder_prompt"]) is not None: # type: ignore[typeddict-item] self._validate_singleton_mm_uuids(dec_prompt) else: self._validate_singleton_mm_uuids(prompt) @@ -449,21 +449,23 @@ class InputProcessor: def _extract_singleton_mm_data( self, prompt: SingletonPrompt ) -> MultiModalDataDict | None: - if isinstance(prompt, str): + if not isinstance(prompt, dict): return None - return prompt.get("multi_modal_data") # type: ignore[return-value] + return prompt.get("multi_modal_data") - def _extract_mm_data(self, prompt: PromptType) -> MultiModalDataDict | None: - if is_explicit_encoder_decoder_prompt(prompt): - return self._extract_singleton_mm_data(prompt["encoder_prompt"]) + def _extract_mm_data( + self, prompt: PromptType | DictPrompt | TokPrompt + ) -> MultiModalDataDict | None: + if isinstance(prompt, dict) and "encoder_prompt" in prompt: + return self._extract_singleton_mm_data(prompt["encoder_prompt"]) # type: ignore[typeddict-item] else: return self._extract_singleton_mm_data(prompt) def _maybe_build_mm_uuids( self, request_id: str, - prompt: PromptType, + prompt: PromptType | DictPrompt | TokPrompt, ) -> MultiModalUUIDDict | None: """Build per-item multimodal hash overrides when enabled. In this case, multimodal data items are identified by their request id, modality and @@ -519,7 +521,7 @@ class InputProcessor: def process_inputs( self, request_id: str, - prompt: PromptType, + prompt: PromptType | DictPrompt | TokPrompt, params: SamplingParams | PoolingParams, arrival_time: float | None = None, lora_request: LoRARequest | None = None, diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 4f44e7101..294c9ff62 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -22,6 +22,8 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.plugins.io_processors import get_io_processor from vllm.pooling_params import PoolingParams from vllm.renderers import BaseRenderer +from vllm.renderers.inputs import DictPrompt, TokPrompt +from vllm.renderers.inputs.preprocess import extract_prompt_components from vllm.sampling_params import SamplingParams from vllm.tasks import SupportedTask from vllm.tokenizers import TokenizerLike @@ -32,7 +34,6 @@ from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.parallel_sampling import ParentRequest -from vllm.v1.engine.utils import get_prompt_text from vllm.v1.executor import Executor from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager from vllm.v1.metrics.reader import Metric, get_metrics_snapshot @@ -216,7 +217,7 @@ class LLMEngine: def add_request( self, request_id: str, - prompt: EngineCoreRequest | PromptType, + prompt: EngineCoreRequest | PromptType | DictPrompt | TokPrompt, params: SamplingParams | PoolingParams, arrival_time: float | None = None, lora_request: LoRARequest | None = None, @@ -251,7 +252,7 @@ class LLMEngine: priority, supported_tasks=self.get_supported_tasks(), ) - prompt_text = get_prompt_text(prompt) + prompt_text, _, _ = extract_prompt_components(self.model_config, prompt) self.input_processor.assign_request_id(request) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 44465f4d0..6c11087a3 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -17,8 +17,6 @@ import zmq from vllm import envs from vllm.config import CacheConfig, ParallelConfig, VllmConfig -from vllm.inputs import PromptType -from vllm.inputs.parse import get_prompt_components from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy @@ -226,10 +224,6 @@ def get_device_indices( return value -def get_prompt_text(prompt: PromptType) -> str | None: - return get_prompt_components(prompt)[0] - - class CoreEngineActorManager: """ Utility class to handle creation, readiness, and shutdown