diff --git a/docs/api/README.md b/docs/api/README.md index da734ce1a..2e97b745b 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -27,11 +27,9 @@ LLM Class. - [vllm.LLM][] -LLM Inputs. +Prompt schema for LLM APIs. -- [vllm.inputs.PromptType][] -- [vllm.inputs.TextPrompt][] -- [vllm.inputs.TokensPrompt][] +- [vllm.inputs.llm][] ## vLLM Engines @@ -58,13 +56,7 @@ Looking to add your own multi-modal model? Please follow the instructions listed - [vllm.multimodal.MULTIMODAL_REGISTRY][] -### Inputs - -User-facing inputs. - -- [vllm.multimodal.inputs.MultiModalDataDict][] - -Internal data structures. +### Internal data structures - [vllm.multimodal.inputs.PlaceholderRange][] - [vllm.multimodal.inputs.NestedTensors][] @@ -72,7 +64,6 @@ Internal data structures. - [vllm.multimodal.inputs.MultiModalFieldConfig][] - [vllm.multimodal.inputs.MultiModalKwargsItem][] - [vllm.multimodal.inputs.MultiModalKwargsItems][] -- [vllm.multimodal.inputs.MultiModalInputs][] ### Data Parsing diff --git a/docs/contributing/model/transcription.md b/docs/contributing/model/transcription.md index 7fe010e5f..a23de100d 100644 --- a/docs/contributing/model/transcription.md +++ b/docs/contributing/model/transcription.md @@ -23,7 +23,7 @@ Declare supported languages and capabilities: from torch import nn from vllm.config import ModelConfig, SpeechToTextConfig - from vllm.inputs.data import PromptType + from vllm.inputs import PromptType from vllm.model_executor.models.interfaces import SupportsTranscription class YourASRModel(nn.Module, SupportsTranscription): @@ -66,7 +66,7 @@ This is for controlling general behavior of the API when serving your model: See [Audio preprocessing and chunking](#audio-preprocessing-and-chunking) for what each field controls. -Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server passes you the resampled waveform and task parameters; you return a valid [PromptType][vllm.inputs.data.PromptType]. There are two common patterns: +Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server passes you the resampled waveform and task parameters; you return a valid [PromptType][vllm.inputs.llm.PromptType]. There are two common patterns: #### Multimodal LLM with audio embeddings (e.g., Voxtral, Gemma3n) diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md index 6b92181fd..ee82c34fa 100644 --- a/docs/features/multimodal_inputs.md +++ b/docs/features/multimodal_inputs.md @@ -18,7 +18,7 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models](../ To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]: - `prompt`: The prompt should follow the format that is documented on HuggingFace. -- `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.multimodal.inputs.MultiModalDataDict][]. +- `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.inputs.MultiModalDataDict][]. 
### Image Inputs diff --git a/examples/pooling/token_embed/jina_embeddings_v4_offline.py b/examples/pooling/token_embed/jina_embeddings_v4_offline.py index 83d4c446d..3822d2b42 100644 --- a/examples/pooling/token_embed/jina_embeddings_v4_offline.py +++ b/examples/pooling/token_embed/jina_embeddings_v4_offline.py @@ -4,7 +4,7 @@ import torch from vllm import LLM -from vllm.inputs.data import TextPrompt +from vllm.inputs import TextPrompt from vllm.multimodal.utils import fetch_image # Initialize model diff --git a/tests/entrypoints/openai/chat_completion/test_chat_error.py b/tests/entrypoints/openai/chat_completion/test_chat_error.py index 5fd7bc09c..4b5003c41 100644 --- a/tests/entrypoints/openai/chat_completion/test_chat_error.py +++ b/tests/entrypoints/openai/chat_completion/test_chat_error.py @@ -105,7 +105,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: ) async def _fake_preprocess_chat(*args, **kwargs): - # return conversation, engine_prompts + # return conversation, engine_inputs return ( [{"role": "user", "content": "Test"}], [{"prompt_token_ids": [1, 2, 3]}], diff --git a/tests/entrypoints/openai/chat_completion/test_serving_chat.py b/tests/entrypoints/openai/chat_completion/test_serving_chat.py index ebfcb675c..ff78c9c09 100644 --- a/tests/entrypoints/openai/chat_completion/test_serving_chat.py +++ b/tests/entrypoints/openai/chat_completion/test_serving_chat.py @@ -958,14 +958,14 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): serving_chat = _build_serving_chat(mock_engine) orig_render_chat_request = serving_chat.render_chat_request - captured_prompts = [] + captured_inputs = [] async def render_chat_request(request): result = await orig_render_chat_request(request) assert isinstance(result, tuple) - conversation, engine_prompts = result - captured_prompts.extend(engine_prompts) + conversation, engine_inputs = result + captured_inputs.extend(engine_inputs) return result @@ -981,18 +981,18 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type): with suppress(Exception): await serving_chat.create_chat_completion(req) - assert len(captured_prompts) == 1 - assert "cache_salt" not in captured_prompts[0] + assert len(captured_inputs) == 1 + assert "cache_salt" not in captured_inputs[0] - captured_prompts.clear() + captured_inputs.clear() # Test with certain cache_salt req.cache_salt = "test_salt" with suppress(Exception): await serving_chat.create_chat_completion(req) - assert len(captured_prompts) == 1 - assert captured_prompts[0]["cache_salt"] == "test_salt" + assert len(captured_inputs) == 1 + assert captured_inputs[0]["cache_salt"] == "test_salt" @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/responses/test_serving_responses.py b/tests/entrypoints/openai/responses/test_serving_responses.py index b5d2b24a6..157f7f12f 100644 --- a/tests/entrypoints/openai/responses/test_serving_responses.py +++ b/tests/entrypoints/openai/responses/test_serving_responses.py @@ -37,7 +37,7 @@ from vllm.entrypoints.openai.responses.serving import ( from vllm.entrypoints.openai.responses.streaming_events import ( StreamingState, ) -from vllm.inputs.data import TokensPrompt +from vllm.inputs import tokens_input from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams @@ -258,20 +258,20 @@ class TestValidateGeneratorInput: """Test _validate_generator_input with valid prompt length""" # Create an engine prompt with valid length (less than max_model_len) valid_prompt_token_ids = 
list(range(5)) # 5 tokens < 100 max_model_len - engine_prompt = TokensPrompt(prompt_token_ids=valid_prompt_token_ids) + engine_input = tokens_input(valid_prompt_token_ids) # Call the method - result = serving_responses_instance._validate_generator_input(engine_prompt) + result = serving_responses_instance._validate_generator_input(engine_input) # Should return None for valid input assert result is None # create an invalid engine prompt invalid_prompt_token_ids = list(range(200)) # 100 tokens >= 100 max_model_len - engine_prompt = TokensPrompt(prompt_token_ids=invalid_prompt_token_ids) + engine_input = tokens_input(invalid_prompt_token_ids) # Call the method - result = serving_responses_instance._validate_generator_input(engine_prompt) + result = serving_responses_instance._validate_generator_input(engine_input) # Should return an ErrorResponse assert result is not None diff --git a/tests/entrypoints/serve/render/test_launch_render.py b/tests/entrypoints/serve/render/test_launch_render.py index 37859e01f..fa452a105 100644 --- a/tests/entrypoints/serve/render/test_launch_render.py +++ b/tests/entrypoints/serve/render/test_launch_render.py @@ -73,20 +73,6 @@ async def test_chat_render_multi_turn(client): assert len(data["token_ids"]) > 0 -@pytest.mark.asyncio -async def test_chat_render_invalid_model(client): - response = await client.post( - "/v1/chat/completions/render", - json={ - "model": "nonexistent-model", - "messages": [{"role": "user", "content": "Hello"}], - }, - ) - - assert response.status_code == 404 - assert "error" in response.json() - - # -- Completion Render -- diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 015770991..afda75d4f 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -16,7 +16,7 @@ from vllm.entrypoints.chat_utils import ( parse_chat_messages, parse_chat_messages_async, ) -from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict +from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.utils import ( encode_audio_url, encode_image_url, diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index 375099f43..48329d9ae 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -13,8 +13,8 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk from transformers import AutoProcessor from vllm import SamplingParams, TextPrompt, TokensPrompt +from vllm.inputs import MultiModalDataBuiltins from vllm.logprobs import Logprob, SampleLogprobs -from vllm.multimodal import MultiModalDataBuiltins from vllm.platforms import current_platform from ....utils import VLLM_PATH, large_gpu_test diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index a623e1b06..cce69e15b 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -15,13 +15,11 @@ from vllm.config.multimodal import ( ImageDummyOptions, VideoDummyOptions, ) -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.inputs import MultiModalDataDict, MultiModalInput +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import MultiModalProcessorOnlyCache -from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal -from vllm.multimodal.processing import ( - 
BaseMultiModalProcessor, - InputProcessingContext, -) +from vllm.multimodal.inputs import batched_tensors_equal +from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.utils.mistral import is_mistral_tokenizer @@ -420,8 +418,8 @@ def test_processing_correctness( def _assert_inputs_equal( - a: MultiModalInputs, - b: MultiModalInputs, + a: MultiModalInput, + b: MultiModalInput, *, ignore_mm_keys: set[str] | None = None, msg: str = "", diff --git a/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/sparse_embeddings_processor.py b/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/sparse_embeddings_processor.py index 8ce9a9b52..77f8884fe 100644 --- a/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/sparse_embeddings_processor.py +++ b/tests/plugins/bge_m3_sparse_plugin/bge_m3_sparse_processor/sparse_embeddings_processor.py @@ -6,11 +6,9 @@ from collections.abc import Sequence from vllm.config import ModelConfig, PoolerConfig, VllmConfig from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.pooling.base.protocol import EmbedRequestMixin -from vllm.inputs.data import PromptType +from vllm.inputs import PromptType from vllm.outputs import PoolingRequestOutput -from vllm.plugins.io_processors.interface import ( - IOProcessor, -) +from vllm.plugins.io_processors.interface import IOProcessor from vllm.pooling_params import PoolingParams from vllm.renderers import BaseRenderer from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py index a1262c28b..89bdab8c3 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py @@ -18,7 +18,7 @@ from einops import rearrange from terratorch.datamodules import Sen1Floods11NonGeoDataModule from vllm.config import VllmConfig -from vllm.inputs.data import PromptType +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput from vllm.plugins.io_processors.interface import IOProcessor diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py index 19a013bd1..14e1a2d57 100644 --- a/tests/plugins_tests/test_io_processor_plugins.py +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch import pytest from vllm.config import VllmConfig -from vllm.inputs.data import PromptType +from vllm.inputs import PromptType from vllm.outputs import PoolingRequestOutput from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors.interface import IOProcessor diff --git a/tests/renderers/inputs/test_preprocess.py b/tests/renderers/inputs/test_preprocess.py index 707f9eedf..98219bb14 100644 --- a/tests/renderers/inputs/test_preprocess.py +++ b/tests/renderers/inputs/test_preprocess.py @@ -15,13 +15,13 @@ def test_text_input(): assert prompt_to_seq(["foo", "bar"]) == ["foo", "bar"] -def test_token_input(): +def test_tokens_input(): assert prompt_to_seq([1, 2]) == [[1, 2]] assert prompt_to_seq([[1, 2]]) == [[1, 2]] assert prompt_to_seq([[1, 2], [3, 4]]) == [[1, 2], [3, 4]] -def test_text_token_input(): 
+def test_text_tokens_input(): assert prompt_to_seq([[1, 2], "foo"]) == [[1, 2], "foo"] assert prompt_to_seq(["foo", [1, 2]]) == ["foo", [1, 2]] diff --git a/tests/renderers/test_completions.py b/tests/renderers/test_completions.py index 5a48cd15d..2daeb4d3d 100644 --- a/tests/renderers/test_completions.py +++ b/tests/renderers/test_completions.py @@ -129,7 +129,7 @@ class TestValidatePrompt: class TestRenderPrompt: - def test_token_input(self): + def test_tokens_input(self): renderer = _build_renderer(MockModelConfig()) tokens = [101, 7592, 2088] @@ -339,7 +339,7 @@ class TestRenderPrompt: TokenizeParams(max_total_tokens=100), ) - def test_token_input_with_needs_detokenization(self): + def test_tokens_input_with_needs_detokenization(self): renderer = _build_renderer(MockModelConfig()) tokens = [1, 2, 3, 4] diff --git a/tests/v1/shutdown/test_processor_error.py b/tests/v1/shutdown/test_processor_error.py index 013b929e3..0aba775e4 100644 --- a/tests/v1/shutdown/test_processor_error.py +++ b/tests/v1/shutdown/test_processor_error.py @@ -9,7 +9,7 @@ import pytest from tests.v1.shutdown.utils import SHUTDOWN_TEST_TIMEOUT_SEC from vllm import SamplingParams from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.inputs.data import TokensPrompt +from vllm.inputs import TokensPrompt from vllm.sampling_params import RequestOutputKind from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.exceptions import EngineGenerateError diff --git a/vllm/beam_search.py b/vllm/beam_search.py index 230f5a123..47f69cbf7 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -3,11 +3,16 @@ from dataclasses import dataclass -from vllm.inputs import EncoderDecoderInputs, TokenInputs, token_inputs -from vllm.inputs.data import DecoderInputs +from vllm.inputs import ( + DecoderOnlyEngineInput, + EncoderDecoderInput, + MultiModalInput, + TokensInput, + mm_input, + tokens_input, +) from vllm.logprobs import Logprob from vllm.lora.request import LoRARequest -from vllm.multimodal.inputs import MultiModalInputs, mm_inputs @dataclass @@ -18,7 +23,7 @@ class BeamSearchSequence: about to be returned to the user. """ - orig_prompt: TokenInputs | MultiModalInputs | EncoderDecoderInputs + orig_prompt: TokensInput | MultiModalInput | EncoderDecoderInput # NOTE: Tokens represents decoder tokens in the encoder / decoder case tokens: list[int] @@ -40,13 +45,13 @@ class BeamSearchSequence: cache_salt = prompt.get("cache_salt") if prompt["type"] == "token": - return token_inputs( + return tokens_input( self.tokens, prompt=prompt_text, cache_salt=cache_salt, ) - return mm_inputs( + return mm_input( prompt_token_ids=self.tokens, mm_kwargs=prompt["mm_kwargs"], mm_hashes=prompt["mm_hashes"], @@ -56,8 +61,8 @@ class BeamSearchSequence: ) def _build_encoder_decoder_inputs( - self, prompt: EncoderDecoderInputs - ) -> EncoderDecoderInputs: + self, prompt: EncoderDecoderInput + ) -> EncoderDecoderInput: """Rebuild the encoder-decoder inputs with the current beam search sequence's tokens. @@ -70,9 +75,9 @@ class BeamSearchSequence: # Rebuild decoder prompt with updated tokens, # but keep everything else the same. 
- new_dec_prompt: DecoderInputs + new_dec_prompt: DecoderOnlyEngineInput if dec_prompt["type"] == "multimodal": - new_dec_prompt = mm_inputs( + new_dec_prompt = mm_input( self.tokens, mm_kwargs=dec_prompt["mm_kwargs"], mm_hashes=dec_prompt["mm_hashes"], @@ -81,13 +86,13 @@ class BeamSearchSequence: cache_salt=dec_prompt.get("cache_salt"), ) else: - new_dec_prompt = token_inputs( + new_dec_prompt = tokens_input( self.tokens, prompt=dec_prompt.get("prompt"), cache_salt=dec_prompt.get("cache_salt"), ) - return EncoderDecoderInputs( + return EncoderDecoderInput( type="enc_dec", encoder_prompt=prompt["encoder_prompt"], decoder_prompt=new_dec_prompt, @@ -107,7 +112,7 @@ class BeamSearchOutput: class BeamSearchInstance: def __init__( self, - prompt: TokenInputs | MultiModalInputs | EncoderDecoderInputs, + prompt: TokensInput | MultiModalInput | EncoderDecoderInput, lora_request: LoRARequest | None = None, logprobs: list[dict[int, Logprob]] | None = None, **kwargs, diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 8304e8703..dd71762b5 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -35,9 +35,9 @@ from huggingface_hub import snapshot_download from PIL import Image from typing_extensions import deprecated +from vllm.inputs import MultiModalDataDict from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path -from vllm.multimodal import MultiModalDataDict from vllm.multimodal.audio import get_audio_duration from vllm.multimodal.image import convert_image_mode from vllm.tokenizers import TokenizerLike diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 0b3b29cd6..3d466e3fc 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -11,7 +11,7 @@ from vllm.distributed.weight_transfer.base import ( WeightTransferInitRequest, WeightTransferUpdateRequest, ) -from vllm.inputs.data import ProcessorInputs, PromptType +from vllm.inputs import EngineInput, PromptType from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.plugins.io_processors import IOProcessor @@ -34,7 +34,7 @@ class StreamingInput: where inputs are provided via an async generator. 
""" - prompt: ProcessorInputs + prompt: EngineInput sampling_params: SamplingParams | None = None @@ -68,7 +68,7 @@ class EngineClient(ABC): self, prompt: EngineCoreRequest | PromptType - | ProcessorInputs + | EngineInput | AsyncGenerator[StreamingInput, None], sampling_params: SamplingParams, request_id: str, @@ -87,7 +87,7 @@ class EngineClient(ABC): @abstractmethod def encode( self, - prompt: PromptType | ProcessorInputs, + prompt: PromptType | EngineInput, pooling_params: PoolingParams, request_id: str, lora_request: LoRARequest | None = None, diff --git a/vllm/entrypoints/anthropic/serving.py b/vllm/entrypoints/anthropic/serving.py index 7216ae2a4..9270a49d1 100644 --- a/vllm/entrypoints/anthropic/serving.py +++ b/vllm/entrypoints/anthropic/serving.py @@ -797,12 +797,12 @@ class AnthropicServingMessages(OpenAIServingChat): if isinstance(result, ErrorResponse): return result - _, engine_prompts = result + _, engine_inputs = result input_tokens = sum( # type: ignore - len(prompt["prompt_token_ids"]) # type: ignore[typeddict-item, misc] - for prompt in engine_prompts - if "prompt_token_ids" in prompt + len(engine_input["prompt_token_ids"]) # type: ignore[typeddict-item, misc] + for engine_input in engine_inputs + if "prompt_token_ids" in engine_input ) response = AnthropicCountTokensResponse( diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 6af762991..51e62042f 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -40,9 +40,10 @@ from typing_extensions import Required, TypedDict from vllm import envs from vllm.config import ModelConfig +from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict from vllm.logger import init_logger from vllm.model_executor.models import SupportsMultiModal -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalBatchedField, MultiModalFlatField, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index e9e7cb91c..3e9e2f6d4 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -57,9 +57,9 @@ from vllm.entrypoints.pooling.score.utils import ( validate_score_input, ) from vllm.entrypoints.utils import log_non_default_args -from vllm.inputs.data import ( +from vllm.inputs import ( DataPrompt, - ProcessorInputs, + EngineInput, PromptType, SingletonPrompt, TextPrompt, @@ -589,7 +589,7 @@ class LLM: def _resolve_mm_lora( self, - prompt: ProcessorInputs, + prompt: EngineInput, lora_request: LoRARequest | None, ) -> LoRARequest | None: if prompt["type"] != "multimodal": @@ -716,8 +716,8 @@ class LLM: eos_token_id = tokenizer.eos_token_id sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) - engine_prompts = self._preprocess_cmpl(prompts) - lora_requests = self._lora_request_to_seq(lora_request, len(engine_prompts)) + engine_inputs = self._preprocess_cmpl(prompts) + lora_requests = self._lora_request_to_seq(lora_request, len(engine_inputs)) if use_tqdm and concurrency_limit is not None: logger.warning( @@ -727,7 +727,7 @@ class LLM: use_tqdm = False if concurrency_limit is None: - concurrency_limit = len(engine_prompts) + concurrency_limit = len(engine_inputs) # generate 2 * beam_width candidates at each step # following the huggingface transformers implementation @@ -740,7 +740,7 @@ class LLM: ) instances: list[BeamSearchInstance] = [] - for lora_req, prompt in zip(lora_requests, engine_prompts): + for lora_req, prompt 
in zip(lora_requests, engine_inputs): if prompt["type"] == "embeds": raise NotImplementedError( "Embedding prompt not supported for beam search" @@ -845,7 +845,7 @@ class LLM: self, prompts: Sequence[PromptType], tokenization_kwargs: dict[str, Any] | None = None, - ) -> Sequence[ProcessorInputs]: + ) -> Sequence[EngineInput]: """ Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into a format that can be passed to `_add_request`. @@ -853,7 +853,7 @@ class LLM: Refer to [LLM.generate][] for a complete description of the arguments. Returns: - A list of `ProcessorInputs` objects ready to be passed into LLMEngine. + A list of `EngineInput` objects ready to be passed into LLMEngine. """ renderer = self.renderer model_config = self.model_config @@ -871,9 +871,9 @@ class LLM: self, prompt: PromptType, tokenization_kwargs: dict[str, Any] | None = None, - ) -> ProcessorInputs: - (engine_prompt,) = self._preprocess_cmpl([prompt], tokenization_kwargs) - return engine_prompt + ) -> EngineInput: + (engine_input,) = self._preprocess_cmpl([prompt], tokenization_kwargs) + return engine_input def _preprocess_chat( self, @@ -886,7 +886,7 @@ class LLM: tools: list[dict[str, Any]] | None = None, tokenization_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None, - ) -> Sequence[ProcessorInputs]: + ) -> Sequence[EngineInput]: """ Convert a list of conversations into prompts so that they can then be used as input for other LLM APIs. @@ -894,7 +894,7 @@ class LLM: Refer to [LLM.chat][] for a complete description of the arguments. Returns: - A list of `ProcessorInputs` objects ready to be passed into LLMEngine. + A list of `EngineInput` objects ready to be passed into LLMEngine. """ renderer = self.renderer @@ -915,14 +915,14 @@ class LLM: **(tokenization_kwargs or {}) ) - _, engine_prompts = renderer.render_chat( + _, engine_inputs = renderer.render_chat( conversations, chat_params, tok_params, prompt_extras={"mm_processor_kwargs": mm_processor_kwargs}, ) - return engine_prompts + return engine_inputs def _preprocess_chat_one( self, @@ -935,8 +935,8 @@ class LLM: tools: list[dict[str, Any]] | None = None, tokenization_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None, - ) -> ProcessorInputs: - (engine_prompt,) = self._preprocess_chat( + ) -> EngineInput: + (engine_input,) = self._preprocess_chat( [conversation], chat_template=chat_template, chat_template_content_format=chat_template_content_format, @@ -948,7 +948,7 @@ class LLM: mm_processor_kwargs=mm_processor_kwargs, ) - return engine_prompt + return engine_input def chat( self, @@ -1909,7 +1909,7 @@ class LLM: def _render_and_run_requests( self, - prompts: Iterable[ProcessorInputs], + prompts: Iterable[EngineInput], params: Sequence[SamplingParams | PoolingParams], output_type: type[_O], *, @@ -1938,7 +1938,7 @@ class LLM: def _render_and_add_requests( self, - prompts: Iterable[ProcessorInputs], + prompts: Iterable[EngineInput], params: Sequence[SamplingParams | PoolingParams], *, lora_requests: Sequence[LoRARequest | None] | None = None, @@ -1967,7 +1967,7 @@ class LLM: def _add_request( self, - prompt: ProcessorInputs, + prompt: EngineInput, params: SamplingParams | PoolingParams, lora_request: LoRARequest | None = None, priority: int = 0, diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 6b39bc5ec..135aaf13c 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ 
b/vllm/entrypoints/openai/chat_completion/serving.py @@ -63,7 +63,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ) from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.utils import get_max_tokens, should_include_usage -from vllm.inputs.data import ProcessorInputs +from vllm.inputs import EngineInput from vllm.logger import init_logger from vllm.logprobs import Logprob from vllm.outputs import CompletionOutput, RequestOutput @@ -177,7 +177,7 @@ class OpenAIServingChat(OpenAIServing): async def render_chat_request( self, request: ChatCompletionRequest, - ) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse: + ) -> tuple[list[ConversationMessage], list[EngineInput]] | ErrorResponse: """ Validate the model and preprocess a chat completion request. @@ -185,7 +185,7 @@ class OpenAIServingChat(OpenAIServing): engine-aware checks (LoRA model validation, engine health). Returns: - A tuple of (conversation, engine_prompts) on success, + A tuple of (conversation, engine_inputs) on success, or an ErrorResponse on failure. """ error_check_ret = await self._check_model(request) @@ -231,7 +231,7 @@ class OpenAIServingChat(OpenAIServing): if isinstance(result, ErrorResponse): return result - conversation, engine_prompts = result + conversation, engine_inputs = result request_id = ( f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}" @@ -251,13 +251,13 @@ class OpenAIServingChat(OpenAIServing): # Schedule the request and get the result generator. max_model_len = self.model_config.max_model_len generators: list[AsyncGenerator[RequestOutput, None]] = [] - for i, engine_prompt in enumerate(engine_prompts): - prompt_token_ids = self._extract_prompt_components(engine_prompt).token_ids + for i, engine_input in enumerate(engine_inputs): + prompt_token_ids = self._extract_prompt_components(engine_input).token_ids # If we are creating sub requests for multiple prompts, ensure that they # have unique request ids. 
sub_request_id = ( - request_id if len(engine_prompts) == 1 else f"{request_id}_{i}" + request_id if len(engine_inputs) == 1 else f"{request_id}_{i}" ) max_tokens = get_max_tokens( @@ -265,7 +265,7 @@ class OpenAIServingChat(OpenAIServing): request.max_completion_tokens if request.max_completion_tokens is not None else request.max_tokens, - self._extract_prompt_len(engine_prompt), + self._extract_prompt_len(engine_input), self.default_sampling_params, self.override_max_tokens, ) @@ -283,7 +283,7 @@ class OpenAIServingChat(OpenAIServing): self._log_inputs( sub_request_id, - engine_prompt, + engine_input, params=sampling_params, lora_request=lora_request, ) @@ -296,7 +296,7 @@ class OpenAIServingChat(OpenAIServing): if isinstance(sampling_params, BeamSearchParams): generator = self.beam_search( - prompt=engine_prompt, + prompt=engine_input, request_id=sub_request_id, params=sampling_params, lora_request=lora_request, @@ -313,7 +313,7 @@ class OpenAIServingChat(OpenAIServing): reasoning_ended = None generator = self.engine_client.generate( - engine_prompt, + engine_input, sampling_params, sub_request_id, lora_request=lora_request, diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py index 96cd7797c..fb7f253c7 100644 --- a/vllm/entrypoints/openai/completion/serving.py +++ b/vllm/entrypoints/openai/completion/serving.py @@ -33,7 +33,7 @@ from vllm.entrypoints.openai.engine.serving import ( from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.exceptions import VLLMValidationError -from vllm.inputs.data import ProcessorInputs +from vllm.inputs import EngineInput from vllm.logger import init_logger from vllm.logprobs import Logprob from vllm.outputs import RequestOutput @@ -82,7 +82,7 @@ class OpenAIServingCompletion(OpenAIServing): async def render_completion_request( self, request: CompletionRequest, - ) -> list[ProcessorInputs] | ErrorResponse: + ) -> list[EngineInput] | ErrorResponse: """ Validate the model and preprocess a completion request. @@ -90,8 +90,7 @@ class OpenAIServingCompletion(OpenAIServing): engine-aware checks (LoRA model validation, engine health). Returns: - A list of engine_prompts on success, - or an ErrorResponse on failure. + A list of engine_inputs on success, or an ErrorResponse on failure. """ error_check_ret = await self._check_model(request) if error_check_ret is not None: @@ -128,7 +127,7 @@ class OpenAIServingCompletion(OpenAIServing): if isinstance(result, ErrorResponse): return result - engine_prompts = result + engine_inputs = result request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}" created_time = int(time.time()) @@ -145,11 +144,11 @@ class OpenAIServingCompletion(OpenAIServing): # Schedule the request and get the result generator. 
max_model_len = self.model_config.max_model_len generators: list[AsyncGenerator[RequestOutput, None]] = [] - for i, engine_prompt in enumerate(engine_prompts): + for i, engine_input in enumerate(engine_inputs): max_tokens = get_max_tokens( max_model_len, request.max_tokens, - self._extract_prompt_len(engine_prompt), + self._extract_prompt_len(engine_input), self.default_sampling_params, self.override_max_tokens, ) @@ -169,7 +168,7 @@ class OpenAIServingCompletion(OpenAIServing): self._log_inputs( request_id_item, - engine_prompt, + engine_input, params=sampling_params, lora_request=lora_request, ) @@ -182,7 +181,7 @@ class OpenAIServingCompletion(OpenAIServing): if isinstance(sampling_params, BeamSearchParams): generator = self.beam_search( - prompt=engine_prompt, + prompt=engine_input, request_id=request_id, params=sampling_params, lora_request=lora_request, @@ -190,7 +189,7 @@ class OpenAIServingCompletion(OpenAIServing): ) else: generator = self.engine_client.generate( - engine_prompt, + engine_input, sampling_params, request_id_item, lora_request=lora_request, @@ -204,7 +203,7 @@ class OpenAIServingCompletion(OpenAIServing): result_generator = merge_async_iterators(*generators) model_name = self.models.model_name(lora_request) - num_prompts = len(engine_prompts) + num_prompts = len(engine_inputs) # Streaming response tokenizer = self.renderer.tokenizer @@ -212,7 +211,7 @@ class OpenAIServingCompletion(OpenAIServing): if request.stream: return self.completion_stream_generator( request, - engine_prompts, + engine_inputs, result_generator, request_id, created_time, @@ -235,8 +234,7 @@ class OpenAIServingCompletion(OpenAIServing): # We did not pass it into vLLM engine to avoid being redundant # with the inputs token IDs if final_res.prompt is None: - engine_prompt = engine_prompts[i] - final_res.prompt = self._extract_prompt_text(engine_prompt) + final_res.prompt = self._extract_prompt_text(engine_inputs[i]) final_res_batch_checked = cast(list[RequestOutput], final_res_batch) @@ -268,7 +266,7 @@ class OpenAIServingCompletion(OpenAIServing): async def completion_stream_generator( self, request: CompletionRequest, - engine_prompts: list[ProcessorInputs], + engine_inputs: list[EngineInput], result_generator: AsyncIterator[tuple[int, RequestOutput]], request_id: str, created_time: int, @@ -301,8 +299,8 @@ class OpenAIServingCompletion(OpenAIServing): prompt_text = res.prompt if prompt_text is None: - engine_prompt = engine_prompts[prompt_idx] - prompt_text = self._extract_prompt_text(engine_prompt) + engine_input = engine_inputs[prompt_idx] + prompt_text = self._extract_prompt_text(engine_input) # Prompt details are excluded from later streamed outputs if prompt_token_ids is not None: diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index c19910c51..d8df1d3c4 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -72,11 +72,7 @@ from vllm.entrypoints.serve.tokenize.protocol import ( ) from vllm.entrypoints.utils import create_error_response from vllm.exceptions import VLLMValidationError -from vllm.inputs.data import ( - ProcessorInputs, - PromptType, - TokensPrompt, -) +from vllm.inputs import EngineInput, PromptType, TokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob, PromptLogprobs from vllm.lora.request import LoRARequest @@ -163,7 +159,7 @@ class ServeContext(Generic[RequestT]): request_id: str created_time: int = field(default_factory=lambda: 
int(time.time())) lora_request: LoRARequest | None = None - engine_prompts: list[ProcessorInputs] | None = None + engine_inputs: list[EngineInput] | None = None result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = ( None @@ -202,7 +198,7 @@ class OpenAIServing: async def beam_search( self, - prompt: ProcessorInputs, + prompt: EngineInput, request_id: str, params: BeamSearchParams, lora_request: LoRARequest | None = None, @@ -493,21 +489,21 @@ class OpenAIServing: if isinstance(pooling_params, ErrorResponse): return pooling_params - if ctx.engine_prompts is None: + if ctx.engine_inputs is None: return self.create_error_response("Engine prompts not available") - for i, engine_prompt in enumerate(ctx.engine_prompts): + for i, engine_input in enumerate(ctx.engine_inputs): request_id_item = f"{ctx.request_id}-{i}" self._log_inputs( request_id_item, - engine_prompt, + engine_input, params=pooling_params, lora_request=ctx.lora_request, ) generator = self.engine_client.encode( - engine_prompt, + engine_input, pooling_params, request_id_item, lora_request=ctx.lora_request, @@ -526,10 +522,10 @@ class OpenAIServing: ctx: ServeContext, ) -> ErrorResponse | None: """Collect batch results from the result generator.""" - if ctx.engine_prompts is None: + if ctx.engine_inputs is None: return self.create_error_response("Engine prompts not available") - num_prompts = len(ctx.engine_prompts) + num_prompts = len(ctx.engine_inputs) final_res_batch: list[PoolingRequestOutput | None] final_res_batch = [None] * num_prompts @@ -806,19 +802,19 @@ class OpenAIServing: # Apply server defaults first, then request kwargs override. return default_chat_template_kwargs | request_chat_template_kwargs - def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs): + def _extract_prompt_components(self, prompt: PromptType | EngineInput): return extract_prompt_components(self.model_config, prompt) - def _extract_prompt_text(self, prompt: ProcessorInputs): + def _extract_prompt_text(self, prompt: PromptType | EngineInput): return self._extract_prompt_components(prompt).text - def _extract_prompt_len(self, prompt: ProcessorInputs): + def _extract_prompt_len(self, prompt: EngineInput): return extract_prompt_len(self.model_config, prompt) def _log_inputs( self, request_id: str, - inputs: PromptType | ProcessorInputs, + inputs: PromptType | EngineInput, params: SamplingParams | PoolingParams | BeamSearchParams | None, lora_request: LoRARequest | None, ) -> None: diff --git a/vllm/entrypoints/openai/realtime/serving.py b/vllm/entrypoints/openai/realtime/serving.py index 5aead4d00..710d1907a 100644 --- a/vllm/entrypoints/openai/realtime/serving.py +++ b/vllm/entrypoints/openai/realtime/serving.py @@ -12,7 +12,7 @@ from vllm.engine.protocol import EngineClient, StreamingInput from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.models.serving import OpenAIServingModels -from vllm.inputs.data import PromptType +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.model_executor.models.interfaces import SupportsRealtime from vllm.renderers.inputs.preprocess import parse_model_prompt @@ -83,6 +83,6 @@ class OpenAIServingRealtime(OpenAIServing): async for prompt in stream_input_iter: parsed_prompt = parse_model_prompt(model_config, prompt) - (engine_prompt,) = await renderer.render_cmpl_async([parsed_prompt]) + (engine_input,) = await renderer.render_cmpl_async([parsed_prompt]) 
- yield StreamingInput(prompt=engine_prompt) + yield StreamingInput(prompt=engine_input) diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index a130d3686..e71a62461 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -110,7 +110,7 @@ from vllm.entrypoints.openai.responses.utils import ( from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.entrypoints.utils import get_max_tokens from vllm.exceptions import VLLMValidationError -from vllm.inputs.data import ProcessorInputs, token_inputs +from vllm.inputs import EngineInput, tokens_input from vllm.logger import init_logger from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import SampleLogprobs @@ -269,10 +269,10 @@ class OpenAIServingResponses(OpenAIServing): def _validate_generator_input( self, - engine_prompt: ProcessorInputs, + engine_input: EngineInput, ) -> ErrorResponse | None: """Add validations to the input to the generator here.""" - prompt_len = self._extract_prompt_len(engine_prompt) + prompt_len = self._extract_prompt_len(engine_input) max_model_len = self.model_config.max_model_len if prompt_len >= max_model_len: @@ -369,11 +369,11 @@ class OpenAIServingResponses(OpenAIServing): model_name = self.models.model_name(lora_request) if self.use_harmony: - messages, engine_prompts = self._make_request_with_harmony( + messages, engine_inputs = self._make_request_with_harmony( request, prev_response ) else: - messages, engine_prompts = await self._make_request(request, prev_response) + messages, engine_inputs = await self._make_request(request, prev_response) request_metadata = RequestResponseMetadata(request_id=request.request_id) if raw_request: @@ -413,15 +413,15 @@ class OpenAIServingResponses(OpenAIServing): available_tools = [] tokenizer = self.renderer.get_tokenizer() - for engine_prompt in engine_prompts: - maybe_error = self._validate_generator_input(engine_prompt) + for engine_input in engine_inputs: + maybe_error = self._validate_generator_input(engine_input) if maybe_error is not None: return maybe_error default_max_tokens = get_max_tokens( max_model_len, request.max_output_tokens, - self._extract_prompt_len(engine_prompt), + self._extract_prompt_len(engine_input), self.default_sampling_params, self.override_max_tokens, ) @@ -480,7 +480,7 @@ class OpenAIServingResponses(OpenAIServing): ) generator = self._generate_with_builtin_tools( request_id=request.request_id, - engine_prompt=engine_prompt, + engine_input=engine_input, sampling_params=sampling_params, context=context, lora_request=lora_request, @@ -586,7 +586,7 @@ class OpenAIServingResponses(OpenAIServing): prev_response_output=prev_response.output if prev_response else None, ) - _, engine_prompts = await self.openai_serving_render.preprocess_chat( + _, engine_inputs = await self.openai_serving_render.preprocess_chat( request, messages, default_template=self.chat_template, @@ -595,7 +595,7 @@ class OpenAIServingResponses(OpenAIServing): tool_dicts=tool_dicts, tool_parser=self.parser.tool_parser_cls if self.parser else None, ) - return messages, engine_prompts + return messages, engine_inputs async def _render_next_turn( self, @@ -610,7 +610,7 @@ class OpenAIServingResponses(OpenAIServing): request_input=messages, ) - _, engine_prompts = await self.openai_serving_render.preprocess_chat( + _, engine_inputs = await self.openai_serving_render.preprocess_chat( request, new_messages, default_template=chat_template, @@ 
-619,12 +619,12 @@ class OpenAIServingResponses(OpenAIServing): tool_dicts=tool_dicts, tool_parser=tool_parser, ) - return engine_prompts + return engine_inputs async def _generate_with_builtin_tools( self, request_id: str, - engine_prompt: ProcessorInputs, + engine_input: EngineInput, sampling_params: SamplingParams, context: ConversationContext, lora_request: LoRARequest | None = None, @@ -641,13 +641,13 @@ class OpenAIServingResponses(OpenAIServing): self._log_inputs( sub_request_id, - engine_prompt, + engine_input, params=sampling_params, lora_request=lora_request, ) generator = self.engine_client.generate( - engine_prompt, + engine_input, sampling_params, sub_request_id, lora_request=lora_request, @@ -675,11 +675,11 @@ class OpenAIServingResponses(OpenAIServing): # Render the next prompt token ids and update sampling_params. if isinstance(context, (HarmonyContext, StreamingHarmonyContext)): token_ids = context.render_for_completion() - engine_prompt = token_inputs(token_ids) + engine_input = tokens_input(token_ids) sampling_params.max_tokens = max_model_len - len(token_ids) elif isinstance(context, ParsableContext): - (engine_prompt,) = await self._render_next_turn( + (engine_input,) = await self._render_next_turn( context.request, context.parser.response_messages, context.tool_dicts, @@ -691,7 +691,7 @@ class OpenAIServingResponses(OpenAIServing): sampling_params.max_tokens = get_max_tokens( max_model_len, context.request.max_output_tokens, - self._extract_prompt_len(engine_prompt), + self._extract_prompt_len(engine_input), self.default_sampling_params, # type: ignore self.override_max_tokens, # type: ignore ) @@ -713,14 +713,10 @@ class OpenAIServingResponses(OpenAIServing): arrival_time = time.time() messages = self._construct_input_messages_with_harmony(request, prev_response) prompt_token_ids = render_for_completion(messages) - engine_prompt = token_inputs(prompt_token_ids) - engine_prompt["arrival_time"] = arrival_time + engine_input = tokens_input(prompt_token_ids, cache_salt=request.cache_salt) + engine_input["arrival_time"] = arrival_time - # Add cache_salt if provided in the request - if request.cache_salt is not None: - engine_prompt["cache_salt"] = request.cache_salt - - return messages, [engine_prompt] + return messages, [engine_input] async def _initialize_tool_sessions( self, diff --git a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py index bf58273f7..e0a3cf0dc 100644 --- a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py @@ -38,7 +38,7 @@ from vllm.entrypoints.openai.speech_to_text.protocol import ( ) from vllm.entrypoints.utils import get_max_tokens from vllm.exceptions import VLLMValidationError -from vllm.inputs import EncoderDecoderInputs, ProcessorInputs +from vllm.inputs import EncoderDecoderInput, EngineInput from vllm.logger import init_logger from vllm.logprobs import FlatLogprobs, Logprob from vllm.model_executor.models import SupportsTranscription @@ -171,7 +171,7 @@ class OpenAISpeechToText(OpenAIServing): request: SpeechToTextRequest, audio_data: bytes, request_id: str, - ) -> tuple[list[ProcessorInputs], float]: + ) -> tuple[list[EngineInput], float]: # Validate request language = self.model_cls.validate_language(request.language) # Skip to_language validation to avoid extra logging for Whisper. 
@@ -250,9 +250,9 @@ class OpenAISpeechToText(OpenAIServing): parsed_prompts.append(parsed_prompt) - engine_prompts = await self.renderer.render_cmpl_async(parsed_prompts) + engine_inputs = await self.renderer.render_cmpl_async(parsed_prompts) - return engine_prompts, duration + return engine_inputs, duration def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt): dec_prompt = prompt["decoder_prompt"] @@ -271,7 +271,7 @@ class OpenAISpeechToText(OpenAIServing): return prompt @staticmethod - def _get_decoder_prompt_len(engine_prompts: list[ProcessorInputs]) -> int: + def _get_decoder_prompt_len(engine_inputs: list[EngineInput]) -> int: """Get the length of the decoder prompt. Currently we need to offset by the decoder prompt length when running beam search because the mm encoder is not currently cached and runs on decode calls; because of @@ -282,12 +282,13 @@ class OpenAISpeechToText(OpenAIServing): encoder/decoder caching is implemented. """ input_len = 0 - assert len(engine_prompts) > 0 - first_eng_prompt = engine_prompts[0] + assert len(engine_inputs) > 0 + first_input = engine_inputs[0] + + if first_input.get("type") == "enc_dec": + first_input = cast(EncoderDecoderInput, first_input) + input_len = len(first_input["decoder_prompt"]["prompt_token_ids"]) - if first_eng_prompt.get("type") == "enc_dec": - first_eng_prompt = cast(EncoderDecoderInputs, first_eng_prompt) - input_len = len(first_eng_prompt["decoder_prompt"]["prompt_token_ids"]) return input_len def _get_verbose_segments( @@ -409,7 +410,7 @@ class OpenAISpeechToText(OpenAIServing): lora_request = self._maybe_get_adapters(request) - engine_prompts, duration_s = await self._preprocess_speech_to_text( + engine_inputs, duration_s = await self._preprocess_speech_to_text( request=request, audio_data=audio_data, request_id=request_id, @@ -420,7 +421,7 @@ class OpenAISpeechToText(OpenAIServing): list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None input_len = ( - OpenAISpeechToText._get_decoder_prompt_len(engine_prompts) + OpenAISpeechToText._get_decoder_prompt_len(engine_inputs) if request.use_beam_search else 0 ) @@ -450,12 +451,12 @@ class OpenAISpeechToText(OpenAIServing): sampling_params.logprobs = 1 list_result_generator = [] - for i, engine_prompt in enumerate(engine_prompts): + for i, engine_input in enumerate(engine_inputs): request_id_item = f"{request_id}_{i}" self._log_inputs( request_id_item, - engine_prompt, + engine_input, params=sampling_params, lora_request=lora_request, ) @@ -468,7 +469,7 @@ class OpenAISpeechToText(OpenAIServing): if isinstance(sampling_params, BeamSearchParams): generator = self.beam_search( - prompt=engine_prompt, + prompt=engine_input, params=sampling_params, request_id=request_id_item, lora_request=lora_request, @@ -476,7 +477,7 @@ class OpenAISpeechToText(OpenAIServing): ) else: generator = self.engine_client.generate( - engine_prompt, + engine_input, sampling_params, request_id_item, lora_request=lora_request, diff --git a/vllm/entrypoints/pooling/base/io_processor.py b/vllm/entrypoints/pooling/base/io_processor.py index 5b09ffb49..09e22156e 100644 --- a/vllm/entrypoints/pooling/base/io_processor.py +++ b/vllm/entrypoints/pooling/base/io_processor.py @@ -18,7 +18,7 @@ from vllm.entrypoints.pooling.typing import ( PoolingCompletionLikeRequest, PoolingServeContext, ) -from vllm.inputs.data import ProcessorInputs, SingletonPrompt +from vllm.inputs import EngineInput, SingletonPrompt from vllm.renderers import BaseRenderer, merge_kwargs from 
vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq from vllm.tool_parsers import ToolParser @@ -60,7 +60,7 @@ class PoolingIOProcessor: chat_template_kwargs=request.chat_template_kwargs, trust_request_chat_template=self.trust_request_chat_template, ) - _, engine_prompts = self._preprocess_chat_online( + _, engine_inputs = self._preprocess_chat_online( request, request.messages, default_template=self.chat_template, @@ -68,7 +68,7 @@ class PoolingIOProcessor: default_template_kwargs=None, ) elif isinstance(request, PoolingCompletionLikeRequest): - engine_prompts = self._preprocess_completion_online( + engine_inputs = self._preprocess_completion_online( request, prompt_input=request.input, prompt_embeds=None, @@ -76,7 +76,7 @@ class PoolingIOProcessor: else: raise ValueError(f"Invalid {self.name} request type") - ctx.engine_prompts = engine_prompts + ctx.engine_inputs = engine_inputs async def pre_process_online_async(self, ctx: PoolingServeContext): self.pre_process_online(ctx) @@ -100,7 +100,7 @@ class PoolingIOProcessor: self, prompts: PromptType | Sequence[PromptType], tokenization_kwargs: dict[str, Any] | None = None, - ) -> Sequence[ProcessorInputs]: + ) -> Sequence[EngineInput]: return self._preprocess_completion_offline( prompts=prompts, tokenization_kwargs=tokenization_kwargs ) @@ -128,7 +128,7 @@ class PoolingIOProcessor: request: RendererRequest, prompt_input: str | list[str] | list[int] | list[list[int]] | None, prompt_embeds: bytes | list[bytes] | None, - ) -> list[ProcessorInputs]: + ) -> list[EngineInput]: renderer = self.renderer model_config = self.model_config @@ -167,7 +167,7 @@ class PoolingIOProcessor: default_template_kwargs: dict[str, Any] | None, tool_dicts: list[dict[str, Any]] | None = None, tool_parser: type[ToolParser] | None = None, - ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]: + ) -> tuple[list[ConversationMessage], list[EngineInput]]: renderer = self.renderer default_template_kwargs = merge_kwargs( @@ -188,7 +188,7 @@ class PoolingIOProcessor: default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None), ) - (conversation,), (engine_prompt,) = renderer.render_chat( + (conversation,), (engine_input,) = renderer.render_chat( [messages], chat_params, tok_params, @@ -199,13 +199,13 @@ class PoolingIOProcessor: }, ) - return conversation, [engine_prompt] + return conversation, [engine_input] def _preprocess_completion_offline( self, prompts: PromptType | Sequence[PromptType], tokenization_kwargs: dict[str, Any] | None = None, - ) -> Sequence[ProcessorInputs]: + ) -> Sequence[EngineInput]: renderer = self.renderer model_config = self.model_config diff --git a/vllm/entrypoints/pooling/base/serving.py b/vllm/entrypoints/pooling/base/serving.py index 312eed6bf..1f7238e27 100644 --- a/vllm/entrypoints/pooling/base/serving.py +++ b/vllm/entrypoints/pooling/base/serving.py @@ -20,7 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.pooling.typing import AnyPoolingRequest, PoolingServeContext from vllm.exceptions import VLLMNotFoundError -from vllm.inputs.data import ProcessorInputs +from vllm.inputs import EngineInput from vllm.lora.request import LoRARequest from vllm.renderers.base import BaseRenderer from vllm.renderers.inputs.preprocess import extract_prompt_components @@ -106,7 +106,7 @@ class PoolingServing: self, ctx: PoolingServeContext, ): - if ctx.engine_prompts is None: + if ctx.engine_inputs 
is None: raise ValueError("Engine prompts not available") generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] @@ -120,7 +120,7 @@ class PoolingServing: pooling_params = self.io_processor.create_pooling_params(ctx.request) pooling_params.verify(self.model_config) - for i, engine_prompt in enumerate(ctx.engine_prompts): + for i, engine_input in enumerate(ctx.engine_inputs): prompt_request_id = ( f"{ctx.request_id}-{i}" if ctx.prompt_request_ids is None @@ -129,13 +129,13 @@ class PoolingServing: self._log_inputs( prompt_request_id, - engine_prompt, + engine_input, params=pooling_params, lora_request=ctx.lora_request, ) generator = self.engine_client.encode( - engine_prompt, + engine_input, pooling_params, prompt_request_id, lora_request=ctx.lora_request, @@ -151,13 +151,13 @@ class PoolingServing: self, ctx: PoolingServeContext, ): - if ctx.engine_prompts is None: + if ctx.engine_inputs is None: raise ValueError("Engine prompts not available") if ctx.result_generator is None: raise ValueError("Result generator not available") - num_inputs = len(ctx.engine_prompts) + num_inputs = len(ctx.engine_inputs) final_res_batch: list[PoolingRequestOutput | None] final_res_batch = [None] * num_inputs @@ -317,7 +317,7 @@ class PoolingServing: def _log_inputs( self, request_id: str, - inputs: ProcessorInputs, + inputs: EngineInput, params: PoolingParams, lora_request: LoRARequest | None, ) -> None: diff --git a/vllm/entrypoints/pooling/embed/io_processor.py b/vllm/entrypoints/pooling/embed/io_processor.py index 9342013bf..f9383d6a6 100644 --- a/vllm/entrypoints/pooling/embed/io_processor.py +++ b/vllm/entrypoints/pooling/embed/io_processor.py @@ -24,7 +24,7 @@ from vllm.entrypoints.pooling.embed.protocol import ( EmbeddingCompletionRequest, ) from vllm.entrypoints.pooling.typing import PoolingServeContext -from vllm.inputs.data import ProcessorInputs, token_inputs +from vllm.inputs import EngineInput, tokens_input from vllm.logger import init_logger from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.renderers import merge_kwargs @@ -83,20 +83,20 @@ class EmbedIOProcessor(PoolingIOProcessor): ################################################################# def _pre_process_chunked(self, ctx: PoolingServeContext) -> None: - if ctx.engine_prompts is None: + if ctx.engine_inputs is None: raise ValueError("Engine prompts not available") - ctx.intermediates = ctx.engine_prompts + ctx.intermediates = ctx.engine_inputs request_id = ctx.request_id max_model_len = self.model_config.max_model_len - chunked_engine_prompts: list[ProcessorInputs] = [] + chunked_engine_inputs: list[EngineInput] = [] prompt_request_ids: list[str] = [] - for prompt_idx, engine_prompt in enumerate(ctx.engine_prompts): - token_ids = engine_prompt.get("prompt_token_ids", None) + for prompt_idx, engine_input in enumerate(ctx.engine_inputs): + token_ids = engine_input.get("prompt_token_ids", None) if token_ids is None: raise NotImplementedError( "Long Text Embedding with Chunked Processing does " - "not support EmbedsPrompt and EncoderDecoderInputs." + "not support EmbedsPrompt and EncoderDecoderInput." 
) prompt_token_ids = cast(list[int], token_ids) @@ -104,14 +104,14 @@ class EmbedIOProcessor(PoolingIOProcessor): for chunk_idx, chunk_tokens in enumerate( chunk_list(prompt_token_ids, max_model_len) ): - chunked_engine_prompts.append( - token_inputs(prompt_token_ids=chunk_tokens) + chunked_engine_inputs.append( + tokens_input(prompt_token_ids=chunk_tokens) ) prompt_request_ids.append( f"{request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}" ) - ctx.engine_prompts = chunked_engine_prompts + ctx.engine_inputs = chunked_engine_inputs ctx.prompt_request_ids = prompt_request_ids return None @@ -184,8 +184,8 @@ class EmbedIOProcessor(PoolingIOProcessor): if ctx.intermediates is None: raise ValueError("Original prompts inputs not available") - original_engine_prompts = cast(list[ProcessorInputs], ctx.intermediates) - num_prompts = len(original_engine_prompts) + original_engine_inputs = cast(list[EngineInput], ctx.intermediates) + num_prompts = len(original_engine_inputs) # Finalize aggregated results final_res_batch: list[PoolingRequestOutput] = [] @@ -211,12 +211,12 @@ class EmbedIOProcessor(PoolingIOProcessor): pooling_output_data = PoolingOutput(data=final_embedding) # Get original prompt token IDs for this prompt - original_prompt = original_engine_prompts[prompt_idx] + original_prompt = original_engine_inputs[prompt_idx] token_ids = original_prompt.get("prompt_token_ids", None) if token_ids is None: raise NotImplementedError( "Long Text Embedding with Chunked Processing does " - "not support EmbedsPrompt and EncoderDecoderInputs." + "not support EmbedsPrompt and EncoderDecoderInput." ) original_token_ids = cast(list[int], token_ids) @@ -372,7 +372,7 @@ class EmbedIOProcessor(PoolingIOProcessor): ] for uri in request.images ] - ctx.engine_prompts = self._batch_render_chat( + ctx.engine_inputs = self._batch_render_chat( request, all_messages, truncate_prompt_tokens, truncation_side ) @@ -382,7 +382,7 @@ class EmbedIOProcessor(PoolingIOProcessor): self._mixed_input_to_messages(inp, task_prefix=task_prefix) for inp in request.inputs ] - ctx.engine_prompts = self._batch_render_chat( + ctx.engine_inputs = self._batch_render_chat( request, all_messages, truncate_prompt_tokens, truncation_side ) @@ -396,7 +396,7 @@ class EmbedIOProcessor(PoolingIOProcessor): truncate_prompt_tokens=truncate_prompt_tokens, truncation_side=truncation_side, ) - ctx.engine_prompts = self._preprocess_completion_online( + ctx.engine_inputs = self._preprocess_completion_online( proxy, prompt_input=proxy.input, prompt_embeds=None ) @@ -406,7 +406,7 @@ class EmbedIOProcessor(PoolingIOProcessor): all_messages: Sequence[list[ChatCompletionMessageParam]], truncate_prompt_tokens: int | None, truncation_side: Literal["left", "right"] | None, - ) -> list[ProcessorInputs]: + ) -> list[EngineInput]: """Batch-render multiple conversations through the chat template.""" if not all_messages: return [] @@ -438,8 +438,8 @@ class EmbedIOProcessor(PoolingIOProcessor): default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None), ) - _, engine_prompts = renderer.render_chat(all_messages, chat_params, tok_params) - return engine_prompts + _, engine_inputs = renderer.render_chat(all_messages, chat_params, tok_params) + return engine_inputs def _validate_input_type(self, input_type: str | None) -> None: """Raise if *input_type* is not supported by this model.""" diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py index d9f8ea166..4706684f3 100644 --- 
a/vllm/entrypoints/pooling/pooling/serving.py +++ b/vllm/entrypoints/pooling/pooling/serving.py @@ -33,7 +33,7 @@ from vllm.entrypoints.pooling.utils import ( encode_pooling_output_float, ) from vllm.entrypoints.serve.render.serving import OpenAIServingRender -from vllm.inputs import ProcessorInputs +from vllm.inputs import EngineInput from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput from vllm.renderers.inputs.preprocess import prompt_to_seq @@ -110,7 +110,7 @@ class OpenAIServingPooling(OpenAIServing): request.task, ) - engine_prompts: Sequence[ProcessorInputs] + engine_inputs: Sequence[EngineInput] if use_io_processor := isinstance(request, IOProcessorRequest): if self.io_processor is None: raise ValueError( @@ -125,7 +125,7 @@ class OpenAIServingPooling(OpenAIServing): raw_prompts = await self.io_processor.pre_process_async( prompt=validated_prompt, request_id=request_id ) - engine_prompts = await self.openai_serving_render.preprocess_cmpl( + engine_inputs = await self.openai_serving_render.preprocess_cmpl( request, prompt_to_seq(raw_prompts), ) @@ -138,7 +138,7 @@ class OpenAIServingPooling(OpenAIServing): if error_check_ret is not None: return error_check_ret - _, engine_prompts = await self.openai_serving_render.preprocess_chat( + _, engine_inputs = await self.openai_serving_render.preprocess_chat( request, request.messages, default_template=self.chat_template, @@ -146,7 +146,7 @@ class OpenAIServingPooling(OpenAIServing): default_template_kwargs=None, ) elif isinstance(request, PoolingCompletionRequest): - engine_prompts = await self.openai_serving_render.preprocess_completion( + engine_inputs = await self.openai_serving_render.preprocess_completion( request, prompt_input=request.input, prompt_embeds=None, @@ -165,12 +165,12 @@ class OpenAIServingPooling(OpenAIServing): else: pooling_params = request.to_pooling_params() # type: ignore - for i, engine_prompt in enumerate(engine_prompts): + for i, engine_input in enumerate(engine_inputs): request_id_item = f"{request_id}-{i}" self._log_inputs( request_id_item, - engine_prompt, + engine_input, params=pooling_params, lora_request=lora_request, ) @@ -182,7 +182,7 @@ class OpenAIServingPooling(OpenAIServing): ) generator = self.engine_client.encode( - engine_prompt, + engine_input, pooling_params, request_id_item, lora_request=lora_request, @@ -221,7 +221,7 @@ class OpenAIServingPooling(OpenAIServing): return IOProcessorResponse(request_id=request_id, data=output) assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest)) - num_prompts = len(engine_prompts) + num_prompts = len(engine_inputs) # Non-streaming response final_res_batch: list[PoolingRequestOutput | None] diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py index d8cbff99d..d6b70c7ac 100644 --- a/vllm/entrypoints/pooling/score/serving.py +++ b/vllm/entrypoints/pooling/score/serving.py @@ -35,7 +35,7 @@ from vllm.entrypoints.pooling.score.utils import ( parse_score_data_single, validate_score_input, ) -from vllm.inputs.data import ProcessorInputs, TokensPrompt, token_inputs +from vllm.inputs import EngineInput, TokensPrompt, tokens_input from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput @@ -110,12 +110,12 @@ class ServingScores(OpenAIServing): *(encode_async(t, **tokenization_kwargs) for t in input_texts) ) - engine_prompts: list[ProcessorInputs] = [] + engine_inputs: list[EngineInput] = [] 
for tok_result, input_text in zip(tokenized_prompts, input_texts): text_token_prompt = self._validate_input(request, tok_result, input_text) - engine_prompts.append( - token_inputs( + engine_inputs.append( + tokens_input( text_token_prompt["prompt_token_ids"], prompt=input_text, ) @@ -125,19 +125,19 @@ class ServingScores(OpenAIServing): generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] pooling_params = request.to_pooling_params("embed") - for i, engine_prompt in enumerate(engine_prompts): + for i, engine_input in enumerate(engine_inputs): request_id_item = f"{request_id}-{i}" self._log_inputs( request_id_item, - engine_prompt, + engine_input, params=pooling_params, lora_request=lora_request, ) generators.append( self.engine_client.encode( - engine_prompt, + engine_input, pooling_params, request_id_item, lora_request=lora_request, @@ -151,7 +151,7 @@ class ServingScores(OpenAIServing): # Non-streaming response final_res_batch: list[PoolingRequestOutput] = [] - embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts) + embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_inputs) async for i, res in result_generator: embeddings[i] = res @@ -183,7 +183,7 @@ class ServingScores(OpenAIServing): request: RerankRequest | ScoreRequest, tokenizer: TokenizerLike, tokenization_kwargs: dict[str, Any], - ) -> tuple[str, TokensPrompt]: + ) -> TokensPrompt: """Parse a single ScoreData into a text + optional multimodal TokensPrompt for late-interaction encoding. @@ -197,21 +197,22 @@ class ServingScores(OpenAIServing): else: text, mm_data, mm_uuids = parse_score_data_single(data, role, model_config) - prompt_inputs = tokenizer(text, **tokenization_kwargs) - self._validate_input(request, prompt_inputs["input_ids"], text) + prompt_ids = tokenizer.encode(text, **tokenization_kwargs) + self._validate_input(request, prompt_ids, text) - engine_prompt = TokensPrompt( - prompt_token_ids=prompt_inputs["input_ids"], + tok_prompt = TokensPrompt( + prompt_token_ids=prompt_ids, + prompt=text, ) if mm_data is not None: - engine_prompt["multi_modal_data"] = mm_data + tok_prompt["multi_modal_data"] = mm_data if mm_uuids is not None: - engine_prompt["multi_modal_uuids"] = mm_uuids + tok_prompt["multi_modal_uuids"] = mm_uuids if request.mm_processor_kwargs is not None: - engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs + tok_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs - return text, engine_prompt + return tok_prompt async def _late_interaction_score( self, @@ -240,7 +241,7 @@ class ServingScores(OpenAIServing): executor=self._tokenizer_executor, ) - preprocessed = await asyncio.gather( + tok_prompts = await asyncio.gather( *( preprocess_async( data=d, @@ -253,12 +254,8 @@ class ServingScores(OpenAIServing): ) ) - query_prompts: list[TokensPrompt] = [ - prompt for _, prompt in preprocessed[: len(data_1)] - ] - doc_prompts: list[TokensPrompt] = [ - prompt for _, prompt in preprocessed[len(data_1) :] - ] + query_prompts = tok_prompts[: len(data_1)] + doc_prompts = tok_prompts[len(data_1) :] default_pooling_params = request.to_pooling_params("token_embed") @@ -268,7 +265,7 @@ class ServingScores(OpenAIServing): query_prompts ) query_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - for i, engine_prompt in enumerate(query_prompts): + for i, tok_prompt in enumerate(query_prompts): request_id_item = f"{request_id}-query-{i}" pooling_params = default_pooling_params.clone() pooling_params.late_interaction_params = ( @@ -280,14 
+277,14 @@ class ServingScores(OpenAIServing): self._log_inputs( request_id_item, - engine_prompt, + tok_prompt, params=pooling_params, lora_request=lora_request, ) query_generators.append( self.engine_client.encode( - engine_prompt, + tok_prompt, pooling_params, request_id_item, lora_request=lora_request, @@ -306,7 +303,7 @@ class ServingScores(OpenAIServing): # stage 2: encode docs and return scalar scores from workers. doc_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - for i, engine_prompt in enumerate(doc_prompts): + for i, tok_prompt in enumerate(doc_prompts): request_id_item = f"{request_id}-doc-{i}" query_idx = 0 if len(query_prompts) == 1 else i pooling_params = default_pooling_params.clone() @@ -316,14 +313,14 @@ class ServingScores(OpenAIServing): self._log_inputs( request_id_item, - engine_prompt, + tok_prompt, params=pooling_params, lora_request=lora_request, ) doc_generators.append( self.engine_client.encode( - engine_prompt, + tok_prompt, pooling_params, request_id_item, lora_request=lora_request, @@ -404,28 +401,22 @@ class ServingScores(OpenAIServing): ) ) - request_prompts: list[str] = [] - engine_prompts: list[TokensPrompt] = [] - for full_prompt, engine_prompt in preprocessed_prompts: - request_prompts.append(full_prompt) - engine_prompts.append(engine_prompt) - # Schedule the request and get the result generator. generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] default_pooling_params = request.to_pooling_params("classify") - for i, engine_prompt in enumerate(engine_prompts): + for i, (full_prompt, tok_prompt) in enumerate(preprocessed_prompts): request_id_item = f"{request_id}-{i}" self._log_inputs( request_id_item, - request_prompts[i], + full_prompt, params=default_pooling_params, lora_request=lora_request, ) - if token_type_ids := engine_prompt.pop("token_type_ids", None): + if token_type_ids := tok_prompt.pop("token_type_ids", None): pooling_params = default_pooling_params.clone() compressed = compress_token_type_ids(token_type_ids) pooling_params.extra_kwargs = {"compressed_token_type_ids": compressed} @@ -433,7 +424,7 @@ class ServingScores(OpenAIServing): pooling_params = default_pooling_params generator = self.engine_client.encode( - engine_prompt, + tok_prompt, pooling_params, request_id_item, lora_request=lora_request, @@ -447,7 +438,7 @@ class ServingScores(OpenAIServing): # Non-streaming response final_res_batch: list[PoolingRequestOutput | None] = [None] * len( - engine_prompts + preprocessed_prompts ) async for i, res in result_generator: @@ -464,7 +455,7 @@ class ServingScores(OpenAIServing): data_2: ScoreData, ) -> tuple[str, TokensPrompt]: model_config = self.model_config - full_prompt, engine_prompt = get_score_prompt( + full_prompt, engine_input = get_score_prompt( model_config=model_config, data_1=data_1, data_2=data_2, @@ -472,11 +463,11 @@ class ServingScores(OpenAIServing): tokenization_kwargs=tokenization_kwargs, score_template=self.score_template, ) - self._validate_input(request, engine_prompt["prompt_token_ids"], full_prompt) + self._validate_input(request, engine_input["prompt_token_ids"], full_prompt) if request.mm_processor_kwargs is not None: - engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs + engine_input["mm_processor_kwargs"] = request.mm_processor_kwargs - return full_prompt, engine_prompt + return full_prompt, engine_input async def _run_scoring( self, diff --git a/vllm/entrypoints/pooling/score/utils.py b/vllm/entrypoints/pooling/score/utils.py index 60e71ff73..f620e3790 100644 
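For illustration, a minimal sketch (not part of the diff; the text, token IDs, and optional fields below are hypothetical) of what the refactored `ServingScores._preprocess_score_data` now produces — a single `TokensPrompt` that carries the prompt text directly, instead of the earlier `(text, engine_prompt)` tuple:

```python
# Illustrative sketch only. Mirrors the shape of the TokensPrompt built in
# _preprocess_score_data after this change; values are made up.
from vllm.inputs import TokensPrompt

text = "query: what is vLLM?"          # hypothetical input text
prompt_ids = [101, 2054, 2003, 102]    # hypothetical tokenizer output

tok_prompt = TokensPrompt(
    prompt_token_ids=prompt_ids,
    prompt=text,
)
# Optional fields are only attached when the parsed score data / request
# actually provides them:
# tok_prompt["multi_modal_data"] = mm_data
# tok_prompt["multi_modal_uuids"] = mm_uuids
# tok_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
```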
--- a/vllm/entrypoints/pooling/score/utils.py +++ b/vllm/entrypoints/pooling/score/utils.py @@ -20,10 +20,14 @@ from vllm.entrypoints.chat_utils import ( MultiModalItemTracker, _parse_chat_message_content_parts, ) -from vllm.inputs import TokensPrompt -from vllm.inputs.data import PromptType, TextPrompt +from vllm.inputs import ( + MultiModalDataDict, + MultiModalUUIDDict, + PromptType, + TextPrompt, + TokensPrompt, +) from vllm.model_executor.models.interfaces import supports_score_template -from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict from vllm.outputs import PoolingRequestOutput from vllm.renderers.hf import safe_apply_chat_template from vllm.tokenizers import TokenizerLike diff --git a/vllm/entrypoints/pooling/typing.py b/vllm/entrypoints/pooling/typing.py index f9f361824..1df72ca5c 100644 --- a/vllm/entrypoints/pooling/typing.py +++ b/vllm/entrypoints/pooling/typing.py @@ -32,7 +32,7 @@ from vllm.entrypoints.pooling.score.protocol import ( ScoreRequest, ScoreResponse, ) -from vllm.inputs import ProcessorInputs +from vllm.inputs import EngineInput from vllm.lora.request import LoRARequest PoolingCompletionLikeRequest: TypeAlias = ( @@ -74,7 +74,7 @@ class PoolingServeContext(Generic[PoolingRequestT]): created_time: int = field(default_factory=lambda: int(time.time())) lora_request: LoRARequest | None = None - engine_prompts: list[ProcessorInputs] | None = None + engine_inputs: list[EngineInput] | None = None prompt_request_ids: list[str] | None = None intermediates: Any | None = None diff --git a/vllm/entrypoints/serve/disagg/protocol.py b/vllm/entrypoints/serve/disagg/protocol.py index 028e8dee7..af4e8c20c 100644 --- a/vllm/entrypoints/serve/disagg/protocol.py +++ b/vllm/entrypoints/serve/disagg/protocol.py @@ -33,19 +33,20 @@ class MultiModalFeatures(BaseModel): """Lightweight multimodal metadata produced by the render step. Carries hashes (for cache lookup / identification) and placeholder - positions so the downstream ``/generate`` service knows *where* in + positions so the downstream `/generate` service knows *where* in the token sequence each multimodal item lives. - .. note:: Phase 1 — metadata only. - Phase 2 should add ``mm_kwargs`` (processed tensor data) using a - binary transport so the ``/generate`` side can skip re-processing. - The ``/generate`` endpoint must also be updated to inject these - features into ``ProcessorInputs`` before passing to - ``InputProcessor.process_inputs``. + Note: + Phase 1 — metadata only. + Phase 2 should add `mm_kwargs` (processed tensor data) using a + binary transport so the ``/generate` side can skip re-processing. + The `/generate` endpoint must also be updated to inject these + features into `EngineInput` before passing to + `InputProcessor.process_inputs`. """ mm_hashes: dict[str, list[str]] - """Per-modality item hashes, e.g. ``{"image": ["abc", "def"]}``.""" + """Per-modality item hashes, e.g. 
`{"image": ["abc", "def"]}`.""" mm_placeholders: dict[str, list[PlaceholderRangeInfo]] """Per-modality placeholder ranges in the token sequence.""" diff --git a/vllm/entrypoints/serve/disagg/serving.py b/vllm/entrypoints/serve/disagg/serving.py index 46f68d535..79367622c 100644 --- a/vllm/entrypoints/serve/disagg/serving.py +++ b/vllm/entrypoints/serve/disagg/serving.py @@ -99,13 +99,11 @@ class ServingTokens(OpenAIServing): if raw_request: raw_request.state.request_metadata = request_metadata - engine_prompts = await self.openai_serving_render.preprocess_completion( + (engine_input,) = await self.openai_serving_render.preprocess_completion( request, prompt_input=request.token_ids, prompt_embeds=None, ) - assert len(engine_prompts) == 1 - engine_prompt = engine_prompts[0] # Schedule the request and get the result generator. result_generator: AsyncGenerator[RequestOutput, None] | None = None @@ -115,7 +113,7 @@ class ServingTokens(OpenAIServing): self._log_inputs( request_id, - engine_prompt, + engine_input, params=sampling_params, lora_request=lora_request, ) @@ -127,7 +125,7 @@ class ServingTokens(OpenAIServing): ) result_generator = self.engine_client.generate( - engine_prompt, + engine_input, sampling_params, request_id, lora_request=lora_request, diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index a6d2f5040..6009b666d 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -34,9 +34,15 @@ from vllm.entrypoints.utils import ( create_error_response, get_max_tokens, ) -from vllm.inputs.data import ProcessorInputs, PromptType, SingletonPrompt, TokensPrompt +from vllm.inputs import ( + EngineInput, + MultiModalHashes, + MultiModalPlaceholders, + PromptType, + SingletonPrompt, + tokens_input, +) from vllm.logger import init_logger -from vllm.multimodal.inputs import MultiModalHashes, MultiModalPlaceholderDict from vllm.parser import ParserManager from vllm.renderers import BaseRenderer, merge_kwargs from vllm.renderers.inputs.preprocess import ( @@ -127,22 +133,22 @@ class OpenAIServingRender: if isinstance(result, ErrorResponse): return result - _, engine_prompts = result + _, engine_inputs = result - if len(engine_prompts) != 1: + if len(engine_inputs) != 1: return self.create_error_response( - f"Expected exactly 1 engine prompt, got {len(engine_prompts)}" + f"Expected exactly 1 engine prompt, got {len(engine_inputs)}" ) - engine_prompt = engine_prompts[0] + engine_input = engine_inputs[0] - prompt_components = extract_prompt_components(self.model_config, engine_prompt) + prompt_components = extract_prompt_components(self.model_config, engine_input) token_ids = prompt_components.token_ids if not token_ids: return self.create_error_response("No token_ids rendered") token_ids = list(token_ids) - input_length = extract_prompt_len(self.model_config, engine_prompt) + input_length = extract_prompt_len(self.model_config, engine_input) max_tokens = get_max_tokens( self.model_config.max_model_len, request.max_completion_tokens @@ -159,7 +165,7 @@ class OpenAIServingRender: return GenerateRequest( request_id=request_id, token_ids=token_ids, - features=self._extract_mm_features(engine_prompt), + features=self._extract_mm_features(engine_input), sampling_params=params, model=request.model, stream=bool(request.stream), @@ -171,7 +177,7 @@ class OpenAIServingRender: async def render_chat( self, request: ChatCompletionRequest, - ) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse: + 
) -> tuple[list[ConversationMessage], list[EngineInput]] | ErrorResponse: """Core preprocessing logic for chat requests (no model/engine check). Called directly by render_chat_request and delegated to by @@ -184,7 +190,6 @@ class OpenAIServingRender: if is_mistral_tokenizer(tokenizer): # because of issues with pydantic we need to potentially # re-serialize the tool_calls field of the request - # for more info: see comment in `maybe_serialize_tool_calls` _mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type] _mt.truncate_tool_call_ids(request) # type: ignore[arg-type] _mt.validate_request_params(request) @@ -232,7 +237,7 @@ class OpenAIServingRender: if error_check_ret is not None: return error_check_ret - conversation, engine_prompts = await self.preprocess_chat( + conversation, engine_inputs = await self.preprocess_chat( request, request.messages, default_template=self.chat_template, @@ -244,11 +249,11 @@ class OpenAIServingRender: else: # For GPT-OSS. should_include_tools = tool_dicts is not None - conversation, engine_prompts = self._make_request_with_harmony( + conversation, engine_inputs = self._make_request_with_harmony( request, should_include_tools ) - return conversation, engine_prompts + return conversation, engine_inputs async def render_completion_request( self, @@ -266,16 +271,16 @@ class OpenAIServingRender: if isinstance(result, ErrorResponse): return result generate_requests: list[GenerateRequest] = [] - for engine_prompt in result: + for engine_input in result: prompt_components = extract_prompt_components( - self.model_config, engine_prompt + self.model_config, engine_input ) token_ids = prompt_components.token_ids if not token_ids: return self.create_error_response("No token_ids rendered") token_ids = list(token_ids) - input_length = extract_prompt_len(self.model_config, engine_prompt) + input_length = extract_prompt_len(self.model_config, engine_input) max_tokens = get_max_tokens( self.model_config.max_model_len, request.max_tokens, @@ -293,7 +298,7 @@ class OpenAIServingRender: GenerateRequest( request_id=request_id, token_ids=token_ids, - features=self._extract_mm_features(engine_prompt), + features=self._extract_mm_features(engine_input), sampling_params=params, model=request.model, stream=bool(request.stream), @@ -308,7 +313,7 @@ class OpenAIServingRender: async def render_completion( self, request: CompletionRequest, - ) -> list[ProcessorInputs] | ErrorResponse: + ) -> list[EngineInput] | ErrorResponse: """Core preprocessing logic for completion requests (no model/engine check). Called directly by render_completion_request and delegated to by @@ -326,28 +331,28 @@ class OpenAIServingRender: "prompt_logprobs is not compatible with prompt embeds." ) - engine_prompts = await self.preprocess_completion( + engine_inputs = await self.preprocess_completion( request, prompt_input=request.prompt, prompt_embeds=request.prompt_embeds, ) - return engine_prompts + return engine_inputs @staticmethod def _extract_mm_features( - engine_prompt: ProcessorInputs, + engine_input: EngineInput, ) -> MultiModalFeatures | None: """Extract multimodal metadata from a rendered engine prompt. Returns ``None`` for text-only prompts. """ - if engine_prompt.get("type") != "multimodal": + if engine_input.get("type") != "multimodal": return None - # At this point engine_prompt is a MultiModalInputs TypedDict. 
- mm_hashes: MultiModalHashes = engine_prompt["mm_hashes"] # type: ignore[typeddict-item] - raw_placeholders: MultiModalPlaceholderDict = engine_prompt["mm_placeholders"] # type: ignore[typeddict-item] + # At this point engine_input is a MultiModalInputs TypedDict. + mm_hashes: MultiModalHashes = engine_input["mm_hashes"] # type: ignore[typeddict-item] + raw_placeholders: MultiModalPlaceholders = engine_input["mm_placeholders"] # type: ignore[typeddict-item] mm_placeholders = { modality: [ @@ -401,13 +406,9 @@ class OpenAIServingRender: # Render prompt token ids. prompt_token_ids = render_for_completion(messages) - engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) + engine_input = tokens_input(prompt_token_ids, cache_salt=request.cache_salt) - # Add cache_salt if provided in the request - if request.cache_salt is not None: - engine_prompt["cache_salt"] = request.cache_salt - - return messages, [engine_prompt] + return messages, [engine_input] def create_error_response( self, @@ -450,7 +451,7 @@ class OpenAIServingRender: request: Any, prompt_input: str | list[str] | list[int] | list[list[int]] | None, prompt_embeds: bytes | list[bytes] | None, - ) -> list[ProcessorInputs]: + ) -> list[EngineInput]: """Copied from OpenAIServing._preprocess_completion.""" prompts = list[SingletonPrompt | bytes]() if prompt_embeds is not None: # embeds take higher priority @@ -463,7 +464,7 @@ class OpenAIServingRender: self, request: Any, prompts: Sequence[PromptType | bytes], - ) -> list[ProcessorInputs]: + ) -> list[EngineInput]: """Copied from OpenAIServing._preprocess_cmpl.""" renderer = self.renderer model_config = self.model_config @@ -497,7 +498,7 @@ class OpenAIServingRender: default_template_kwargs: dict[str, Any] | None, tool_dicts: list[dict[str, Any]] | None = None, tool_parser: type[ToolParser] | None = None, - ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]: + ) -> tuple[list[ConversationMessage], list[EngineInput]]: """Copied from OpenAIServing._preprocess_chat.""" renderer = self.renderer mm_config = self.model_config.multimodal_config @@ -519,7 +520,7 @@ class OpenAIServingRender: default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None), ) - (conversation,), (engine_prompt,) = await renderer.render_chat_async( + (conversation,), (engine_input,) = await renderer.render_chat_async( [messages], chat_params, tok_params, @@ -546,4 +547,4 @@ class OpenAIServingRender: tokenizer = renderer.get_tokenizer() request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type] - return conversation, [engine_prompt] + return conversation, [engine_input] diff --git a/vllm/entrypoints/serve/tokenize/serving.py b/vllm/entrypoints/serve/tokenize/serving.py index d68651da8..22b852d27 100644 --- a/vllm/entrypoints/serve/tokenize/serving.py +++ b/vllm/entrypoints/serve/tokenize/serving.py @@ -20,7 +20,7 @@ from vllm.entrypoints.serve.tokenize.protocol import ( TokenizeResponse, TokenizerInfoResponse, ) -from vllm.inputs import TokensPrompt, token_inputs +from vllm.inputs import TokensPrompt, tokens_input from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike @@ -79,7 +79,7 @@ class OpenAIServingTokenization(OpenAIServing): if error_check_ret is not None: return error_check_ret - _, engine_prompts = await self.openai_serving_render.preprocess_chat( + _, engine_inputs = await self.openai_serving_render.preprocess_chat( request, request.messages, default_template=self.chat_template, @@ -88,22 +88,22 @@ class 
OpenAIServingTokenization(OpenAIServing): tool_dicts=tool_dicts, ) else: - engine_prompts = await self.openai_serving_render.preprocess_completion( + engine_inputs = await self.openai_serving_render.preprocess_completion( request, prompt_input=request.prompt, prompt_embeds=None, ) input_ids: list[int] = [] - for engine_prompt in engine_prompts: + for engine_input in engine_inputs: self._log_inputs( request_id, - engine_prompt, + engine_input, params=None, lora_request=lora_request, ) - prompt_components = self._extract_prompt_components(engine_prompt) + prompt_components = self._extract_prompt_components(engine_input) if prompt_components.token_ids is not None: input_ids.extend(prompt_components.token_ids) @@ -134,16 +134,16 @@ class OpenAIServingTokenization(OpenAIServing): self._log_inputs( request_id, - token_inputs(request.tokens), + tokens_input(request.tokens), params=None, lora_request=lora_request, ) - engine_prompt = await self.renderer.tokenize_prompt_async( + tok_prompt = await self.renderer.tokenize_prompt_async( TokensPrompt(prompt_token_ids=request.tokens), request.build_tok_params(self.model_config), ) - prompt_text = engine_prompt["prompt"] # type: ignore[typeddict-item] + prompt_text = tok_prompt["prompt"] # type: ignore[typeddict-item] return DetokenizeResponse(prompt=prompt_text) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 2f9db8bdd..7f89a598b 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,38 +1,64 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from .data import ( +from .engine import ( + DecoderOnlyEngineInput, + EmbedsInput, + EncoderDecoderInput, + EngineInput, + MultiModalEncDecInput, + MultiModalHashes, + MultiModalInput, + MultiModalPlaceholders, + SingletonInput, + TokensInput, + build_enc_dec_input, + embeds_input, + mm_enc_dec_input, + mm_input, + split_enc_dec_input, + tokens_input, +) +from .llm import ( DataPrompt, - DecoderOnlyInputs, - EmbedsInputs, EmbedsPrompt, - EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, - ProcessorInputs, + ModalityData, + MultiModalDataBuiltins, + MultiModalDataDict, + MultiModalUUIDDict, PromptType, - SingletonInputs, SingletonPrompt, TextPrompt, - TokenInputs, TokensPrompt, - embeds_inputs, - token_inputs, ) __all__ = [ + "ModalityData", + "MultiModalDataBuiltins", + "MultiModalDataDict", + "MultiModalUUIDDict", "DataPrompt", "TextPrompt", "TokensPrompt", "PromptType", "SingletonPrompt", "ExplicitEncoderDecoderPrompt", - "TokenInputs", - "EmbedsInputs", "EmbedsPrompt", - "token_inputs", - "embeds_inputs", - "DecoderOnlyInputs", - "EncoderDecoderInputs", - "ProcessorInputs", - "SingletonInputs", + "MultiModalHashes", + "MultiModalPlaceholders", + "TokensInput", + "EmbedsInput", + "MultiModalInput", + "MultiModalEncDecInput", + "tokens_input", + "embeds_input", + "mm_input", + "mm_enc_dec_input", + "build_enc_dec_input", + "split_enc_dec_input", + "DecoderOnlyEngineInput", + "EncoderDecoderInput", + "SingletonInput", + "EngineInput", ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py deleted file mode 100644 index a3d3e2198..000000000 --- a/vllm/inputs/data.py +++ /dev/null @@ -1,413 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any, Literal, TypeAlias - -import torch -from typing_extensions import NotRequired, TypedDict, assert_never - -if TYPE_CHECKING: - from vllm.multimodal.inputs import ( - 
MultiModalDataDict, - MultiModalEncDecInputs, - MultiModalInputs, - MultiModalUUIDDict, - ) -else: - MultiModalDataDict = object - MultiModalEncDecInputs = object - MultiModalInputs = object - MultiModalUUIDDict = object - - -# Inputs to LLM API -class _PromptOptions(TypedDict): - """ - Additional options available to all - [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt]. - """ - - multi_modal_data: NotRequired[MultiModalDataDict | None] - """ - Optional multi-modal data to pass to the model, - if the model supports it. - """ - - mm_processor_kwargs: NotRequired[dict[str, Any] | None] - """ - Optional multi-modal processor kwargs to be forwarded to the - multimodal input mapper & processor. Note that if multiple modalities - have registered mappers etc for the model being considered, we attempt - to pass the mm_processor_kwargs to each of them. - """ - - multi_modal_uuids: NotRequired[MultiModalUUIDDict] - """ - Optional user-specified UUIDs for multimodal items, mapped by modality. - Lists must match the number of items per modality and may contain `None`. - For `None` entries, the hasher will compute IDs automatically; non-None - entries override the default hashes for caching, and MUST be unique per - multimodal item. - """ - - cache_salt: NotRequired[str] - """ - Optional cache salt to be used for prefix caching. - """ - - -class TextPrompt(_PromptOptions): - """Schema for a text prompt.""" - - prompt: str - """The input text to be tokenized before passing to the model.""" - - -class TokensPrompt(_PromptOptions): - """Schema for a tokenized prompt.""" - - prompt_token_ids: list[int] - """A list of token IDs to pass to the model.""" - - prompt: NotRequired[str] - """The prompt text corresponding to the token IDs, if available.""" - - token_type_ids: NotRequired[list[int]] - """A list of token type IDs to pass to the cross encoder model.""" - - -class EmbedsPrompt(_PromptOptions): - """Schema for a prompt provided via token embeddings.""" - - prompt_embeds: torch.Tensor - """The embeddings of the prompt.""" - - prompt: NotRequired[str] - """The prompt text corresponding to the token embeddings, if available.""" - - -DecoderOnlyPrompt: TypeAlias = ( - str | TextPrompt | list[int] | TokensPrompt | EmbedsPrompt -) -""" -Schema of a prompt for a decoder-only model: - -- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt]) -- A tokenized prompt (list of token IDs, or - [`TokensPrompt`][vllm.inputs.data.TokensPrompt]) -- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.data.EmbedsPrompt]) - -For encoder-decoder models, passing a singleton prompt is shorthand for passing -`ExplicitEncoderDecoderPrompt(encoder_prompt=prompt, decoder_prompt=None)`. -""" - - -EncoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt -""" -Schema of a prompt for the encoder part of a encoder-decoder model: - -- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt]) -- A tokenized prompt (list of token IDs, or - [`TokensPrompt`][vllm.inputs.data.TokensPrompt]) -""" - - -DecoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt -""" -Schema of a prompt for the decoder part of an encoder-decoder model: - -- A text prompt (string or [`TextPrompt`][vllm.inputs.data.TextPrompt]) -- A tokenized prompt (list of token IDs, or - [`TokensPrompt`][vllm.inputs.data.TokensPrompt]) - -Note: - Multi-modal inputs are not supported for decoder prompts. 
-""" - - -class ExplicitEncoderDecoderPrompt(TypedDict): - """ - Schema for a pair of encoder and decoder singleton prompts. - - Note: - This schema is not valid for decoder-only models. - """ - - encoder_prompt: EncoderPrompt - """The prompt for the encoder part of the model.""" - - decoder_prompt: DecoderPrompt | None - """ - The prompt for the decoder part of the model. - - Passing `None` will cause the prompt to be inferred automatically. - """ - - -EncoderDecoderPrompt: TypeAlias = EncoderPrompt | ExplicitEncoderDecoderPrompt -""" -Schema for a prompt for an encoder-decoder model. - -You can pass a singleton encoder prompt, in which case the decoder prompt is -considered to be `None` (i.e., infer automatically). -""" - - -SingletonPrompt: TypeAlias = DecoderOnlyPrompt | EncoderPrompt | DecoderPrompt -""" -Schema for a single prompt. This is as opposed to a data structure -which encapsulates multiple prompts, such as -[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]. -""" - - -PromptType: TypeAlias = DecoderOnlyPrompt | EncoderDecoderPrompt -""" -Schema for any prompt, regardless of model type. - -This is the input format accepted by most [`LLM`][vllm.entrypoints.llm.LLM] APIs. -""" - - -class DataPrompt(_PromptOptions): - """ - Represents generic inputs that are converted to - [`PromptType`][vllm.inputs.data.PromptType] by IO processor plugins. - """ - - data: Any - """The input data.""" - - data_format: str - """The input data format.""" - - -# Outputs of processor -class _InputOptions(TypedDict): - """ - Additional options available to all input types. - """ - - arrival_time: NotRequired[float] - """The time when the input was received (before rendering).""" - - cache_salt: NotRequired[str] - """Optional cache salt to be used for prefix caching.""" - - -class TokenInputs(_InputOptions): - """Represents token-based inputs.""" - - type: Literal["token"] - """The type of inputs.""" - - prompt_token_ids: list[int] - """The token IDs of the prompt.""" - - prompt: NotRequired[str] - """The prompt text corresponding to the token IDs, if available.""" - - -def token_inputs( - prompt_token_ids: list[int], - *, - prompt: str | None = None, - cache_salt: str | None = None, -) -> TokenInputs: - """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional - values.""" - inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) - - if prompt is not None: - inputs["prompt"] = prompt - if cache_salt is not None: - inputs["cache_salt"] = cache_salt - - return inputs - - -class EmbedsInputs(_InputOptions): - """Represents embeddings-based inputs.""" - - type: Literal["embeds"] - """The type of inputs.""" - - prompt_embeds: torch.Tensor - """The embeddings of the prompt.""" - - prompt: NotRequired[str] - """The prompt text corresponding to the token IDs, if available.""" - - -def embeds_inputs( - prompt_embeds: torch.Tensor, - *, - prompt: str | None = None, - cache_salt: str | None = None, -) -> EmbedsInputs: - """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional - values.""" - inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds) - - if prompt is not None: - inputs["prompt"] = prompt - if cache_salt is not None: - inputs["cache_salt"] = cache_salt - - return inputs - - -DecoderOnlyInputs: TypeAlias = TokenInputs | EmbedsInputs | MultiModalInputs -""" -A processed prompt from -[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor] -which can be passed to 
-[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor] -for decoder-only models. -""" - - -EncoderInputs: TypeAlias = TokenInputs | MultiModalEncDecInputs -""" -A processed encoder prompt from -[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor] -which can be passed to -[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor] -for encoder-decoder models. -""" - - -DecoderInputs: TypeAlias = TokenInputs | MultiModalInputs -""" -A processed decoder prompt from -[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor] -which can be passed to -[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor] -for encoder-decoder models. -""" - - -class EncoderDecoderInputs(TypedDict): - """ - A processed pair of encoder and decoder singleton prompts. - [`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor] - which can be passed to - [`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor] - for encoder-decoder models. - """ - - type: Literal["enc_dec"] - - encoder_prompt: EncoderInputs - """The inputs for the encoder portion.""" - - decoder_prompt: DecoderInputs - """The inputs for the decoder portion.""" - - arrival_time: NotRequired[float] - """The time when the input was received (before rendering).""" - - -ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs -""" -A processed prompt from -[`InputPreprocessor`][vllm.inputs.preprocess.InputPreprocessor] -which can be passed to -[`InputProcessor`][vllm.v1.engine.input_processor.InputProcessor]. -""" - - -SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs -"""The inputs for a single encoder/decoder prompt.""" - - -def _validate_enc_inputs(inputs: SingletonInputs) -> EncoderInputs: - if inputs["type"] == "embeds": - raise ValueError( - "Embedding inputs are not supported for encoder-decoder models" - ) - - if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs: - raise RuntimeError( - "You should register an encoder-decoder multi-modal processor " - "for encoder-decoder models." - ) - - return inputs # type: ignore[return-value] - - -def _validate_dec_inputs(inputs: SingletonInputs) -> DecoderInputs: - if inputs["type"] == "embeds": - raise ValueError( - "Embedding inputs are not supported for encoder-decoder models" - ) - - return inputs - - -def _prepare_decoder_input_ids_for_generation( - decoder_input_ids: list[int], - decoder_start_token_id: int, -) -> list[int]: - """ - Prepare `decoder_input_ids` for generation with encoder-decoder models, - according to `GenerationMixin._prepare_decoder_input_ids_for_generation()`. 
- - Source: - https://github.com/huggingface/transformers/blob/v5.1.0/src/transformers/generation/utils.py - """ - if len(decoder_input_ids) == 0 or decoder_input_ids[0] != decoder_start_token_id: - decoder_input_ids = [decoder_start_token_id] + decoder_input_ids - - return decoder_input_ids - - -def build_enc_dec_inputs( - encoder_inputs: SingletonInputs, - decoder_inputs: SingletonInputs | None, - decoder_start_token_id: int, - skip_decoder_start_token: bool = False, -) -> EncoderDecoderInputs: - enc_inputs = _validate_enc_inputs(encoder_inputs) - - if decoder_inputs is None: - dec_inputs: DecoderInputs = enc_inputs - else: - dec_inputs = _validate_dec_inputs(decoder_inputs) - - enc_inputs_new: EncoderInputs - dec_inputs_new: DecoderInputs - - if enc_inputs["type"] == "multimodal": - from vllm.multimodal.inputs import mm_inputs - - enc_inputs_new = token_inputs( - enc_inputs["encoder_prompt_token_ids"], - prompt=enc_inputs.get("encoder_prompt"), - ) - dec_inputs_new = mm_inputs( - prompt_token_ids=dec_inputs["prompt_token_ids"], - prompt=dec_inputs.get("prompt"), - mm_kwargs=enc_inputs["mm_kwargs"], - mm_hashes=enc_inputs["mm_hashes"], - mm_placeholders=enc_inputs["mm_placeholders"], - ) - elif enc_inputs["type"] == "token": - enc_inputs_new = token_inputs(prompt_token_ids=[]) - dec_inputs_new = dec_inputs - else: - assert_never(enc_inputs) - - if not skip_decoder_start_token: - dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation( - dec_inputs_new["prompt_token_ids"], - decoder_start_token_id, - ) - - if cache_salt := enc_inputs.get("cache_salt"): - dec_inputs_new["cache_salt"] = cache_salt - - return EncoderDecoderInputs( - type="enc_dec", - encoder_prompt=enc_inputs_new, - decoder_prompt=dec_inputs_new, - ) diff --git a/vllm/inputs/engine.py b/vllm/inputs/engine.py new file mode 100644 index 000000000..2b426eba8 --- /dev/null +++ b/vllm/inputs/engine.py @@ -0,0 +1,352 @@ +"""Schema and utilities for inputs to the engine client (`LLMEngine`/`AsyncLLM`).""" + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Literal, TypeAlias + +from typing_extensions import NotRequired, TypedDict, assert_never + +if TYPE_CHECKING: + import torch + + from vllm.multimodal.inputs import MultiModalKwargsOptionalItems, PlaceholderRange + + +class _InputOptions(TypedDict): + """ + Additional options available to all + [`SingletonInput`][vllm.inputs.engine.SingletonInput] types. + """ + + arrival_time: NotRequired[float] + """The time when the input was received (before rendering).""" + + cache_salt: NotRequired[str] + """Optional cache salt to be used for prefix caching.""" + + +class TokensInput(_InputOptions): + """Represents token-based input to the engine.""" + + type: Literal["token"] + """The type of input.""" + + prompt_token_ids: list[int] + """The token IDs of the prompt.""" + + prompt: NotRequired[str] + """The prompt text corresponding to the token IDs, if available.""" + + +def tokens_input( + prompt_token_ids: list[int], + *, + prompt: str | None = None, + cache_salt: str | None = None, +) -> TokensInput: + """ + Construct [`TokensInput`][vllm.inputs.engine.TokensInput] + from optional values. 
+ """ + inputs = TokensInput(type="token", prompt_token_ids=prompt_token_ids) + + if prompt is not None: + inputs["prompt"] = prompt + if cache_salt is not None: + inputs["cache_salt"] = cache_salt + + return inputs + + +class EmbedsInput(_InputOptions): + """Represents embeddings-based input to the engine.""" + + type: Literal["embeds"] + """The type of input.""" + + prompt_embeds: "torch.Tensor" + """The embeddings of the prompt.""" + + prompt: NotRequired[str] + """The prompt text corresponding to the token IDs, if available.""" + + +def embeds_input( + prompt_embeds: "torch.Tensor", + *, + prompt: str | None = None, + cache_salt: str | None = None, +) -> EmbedsInput: + """ + Construct [`EmbedsInput`][vllm.inputs.engine.EmbedsInput] + from optional values. + """ + inputs = EmbedsInput(type="embeds", prompt_embeds=prompt_embeds) + + if prompt is not None: + inputs["prompt"] = prompt + if cache_salt is not None: + inputs["cache_salt"] = cache_salt + + return inputs + + +MultiModalHashes: TypeAlias = Mapping[str, list[str]] +""" +A dictionary containing per-item hashes for each modality. +""" + + +MultiModalPlaceholders: TypeAlias = Mapping[str, Sequence["PlaceholderRange"]] +""" +A dictionary containing per-item placeholder ranges for each modality. +""" + + +class MultiModalInput(_InputOptions): + """Represents multi-modal input to the engine.""" + + type: Literal["multimodal"] + """The type of input.""" + + prompt_token_ids: list[int] + """The processed token IDs which includes placeholder tokens.""" + + prompt: NotRequired[str] + """The prompt text corresponding to the token IDs, if available.""" + + mm_kwargs: "MultiModalKwargsOptionalItems" + """Keyword arguments to be directly passed to the model after batching.""" + + mm_hashes: MultiModalHashes + """The hashes of the multi-modal data.""" + + mm_placeholders: MultiModalPlaceholders + """ + For each modality, information about the placeholder tokens in + `prompt_token_ids`. + """ + + +def mm_input( + prompt_token_ids: list[int], + mm_kwargs: "MultiModalKwargsOptionalItems", + mm_hashes: MultiModalHashes, + mm_placeholders: MultiModalPlaceholders, + *, + prompt: str | None = None, + cache_salt: str | None = None, +) -> MultiModalInput: + inputs = MultiModalInput( + type="multimodal", + prompt_token_ids=prompt_token_ids, + mm_kwargs=mm_kwargs, + mm_hashes=mm_hashes, + mm_placeholders=mm_placeholders, + ) + + if prompt is not None: + inputs["prompt"] = prompt + if cache_salt is not None: + inputs["cache_salt"] = cache_salt + + return inputs + + +class MultiModalEncDecInput(MultiModalInput): + """ + Represents multi-modal input to the engine for encoder-decoder models. + + Note: + Even text-only encoder-decoder models are currently implemented + as multi-modal models for convenience. 
+ (Example: https://github.com/vllm-project/bart-plugin) + """ + + encoder_prompt_token_ids: list[int] + """The processed token IDs of the encoder prompt.""" + + encoder_prompt: NotRequired[str] + """The prompt text corresponding to the encoder token IDs, if available.""" + + +def mm_enc_dec_input( + encoder_inputs: MultiModalInput, + decoder_prompt_token_ids: list[int], + *, + decoder_prompt: str | None = None, +) -> MultiModalEncDecInput: + inputs = MultiModalEncDecInput( + type="multimodal", + prompt_token_ids=decoder_prompt_token_ids, + encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"], + mm_kwargs=encoder_inputs["mm_kwargs"], + mm_hashes=encoder_inputs["mm_hashes"], + mm_placeholders=encoder_inputs["mm_placeholders"], + ) + + if decoder_prompt is not None: + inputs["prompt"] = decoder_prompt + if "prompt" in encoder_inputs: + inputs["encoder_prompt"] = encoder_inputs["prompt"] + if "cache_salt" in encoder_inputs: + inputs["cache_salt"] = encoder_inputs["cache_salt"] + + return inputs + + +DecoderOnlyEngineInput: TypeAlias = TokensInput | EmbedsInput | MultiModalInput +""" +A rendered [`DecoderOnlyPrompt`][vllm.inputs.llm.DecoderOnlyPrompt] +which can be passed to `LLMEngine.add_request` or `AsyncLLM.add_request`. +""" + + +EncoderInput: TypeAlias = TokensInput | MultiModalEncDecInput +""" +A rendered [`EncoderPrompt`][vllm.inputs.llm.EncoderPrompt] +which can be passed to `LLMEngine.add_request` or `AsyncLLM.add_request`. +""" + + +DecoderEngineInput: TypeAlias = TokensInput | MultiModalInput +""" +A rendered [`DecoderPrompt`][vllm.inputs.llm.DecoderPrompt] +which can be passed to `LLMEngine.add_request` or `AsyncLLM.add_request`. +""" + + +class EncoderDecoderInput(TypedDict): + """ + A rendered [`EncoderDecoderPrompt`][vllm.inputs.llm.EncoderDecoderPrompt] + which can be passed to `LLMEngine.add_request` or `AsyncLLM.add_request`. + """ + + type: Literal["enc_dec"] + + encoder_prompt: EncoderInput + """The inputs for the encoder portion.""" + + decoder_prompt: DecoderEngineInput + """The inputs for the decoder portion.""" + + arrival_time: NotRequired[float] + """The time when the input was received (before rendering).""" + + +SingletonInput: TypeAlias = DecoderOnlyEngineInput | MultiModalEncDecInput +""" +A rendered [`SingletonPrompt`][vllm.inputs.llm.SingletonPrompt] +which can be passed to `LLMEngine.add_request` or `AsyncLLM.add_request`. +""" + + +EngineInput: TypeAlias = DecoderOnlyEngineInput | EncoderDecoderInput +""" +A rendered [`PromptType`][vllm.inputs.llm.PromptType] +which can be passed to `LLMEngine.add_request` or `AsyncLLM.add_request`. +""" + + +def _validate_enc_input(enc_input: SingletonInput) -> EncoderInput: + if enc_input["type"] == "embeds": + raise ValueError( + "Embedding inputs are not supported for encoder-decoder models" + ) + + if ( + enc_input["type"] == "multimodal" + and "encoder_prompt_token_ids" not in enc_input + ): + raise RuntimeError( + "You should register an encoder-decoder multi-modal processor " + "for encoder-decoder models." 
+ ) + + return enc_input # type: ignore[return-value] + + +def _validate_dec_input(dec_input: SingletonInput) -> DecoderEngineInput: + if dec_input["type"] == "embeds": + raise ValueError( + "Embedding inputs are not supported for encoder-decoder models" + ) + + return dec_input + + +def _prepare_decoder_input_ids_for_generation( + decoder_input_ids: list[int], + decoder_start_token_id: int, +) -> list[int]: + """ + Prepare `decoder_input_ids` for generation with encoder-decoder models, + according to `GenerationMixin._prepare_decoder_input_ids_for_generation()`. + + Source: + https://github.com/huggingface/transformers/blob/v5.1.0/src/transformers/generation/utils.py + """ + if len(decoder_input_ids) == 0 or decoder_input_ids[0] != decoder_start_token_id: + decoder_input_ids = [decoder_start_token_id] + decoder_input_ids + + return decoder_input_ids + + +def build_enc_dec_input( + encoder_input: SingletonInput, + decoder_input: SingletonInput | None, + decoder_start_token_id: int, + skip_decoder_start_token: bool = False, +) -> EncoderDecoderInput: + enc_input = _validate_enc_input(encoder_input) + + if decoder_input is None: + dec_input: DecoderEngineInput = enc_input + else: + dec_input = _validate_dec_input(decoder_input) + + enc_input_new: EncoderInput + dec_input_new: DecoderEngineInput + + if enc_input["type"] == "multimodal": + enc_input_new = tokens_input( + enc_input["encoder_prompt_token_ids"], + prompt=enc_input.get("encoder_prompt"), + ) + dec_input_new = mm_input( + prompt_token_ids=dec_input["prompt_token_ids"], + prompt=dec_input.get("prompt"), + mm_kwargs=enc_input["mm_kwargs"], + mm_hashes=enc_input["mm_hashes"], + mm_placeholders=enc_input["mm_placeholders"], + ) + elif enc_input["type"] == "token": + enc_input_new = tokens_input(prompt_token_ids=[]) + dec_input_new = dec_input + else: + assert_never(enc_input) + + if not skip_decoder_start_token: + dec_input_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation( + dec_input_new["prompt_token_ids"], + decoder_start_token_id, + ) + + if cache_salt := enc_input.get("cache_salt"): + dec_input_new["cache_salt"] = cache_salt + + return EncoderDecoderInput( + type="enc_dec", + encoder_prompt=enc_input_new, + decoder_prompt=dec_input_new, + ) + + +def split_enc_dec_input( + inputs: EngineInput, +) -> tuple[SingletonInput | None, SingletonInput]: + if inputs["type"] == "enc_dec": + return inputs["encoder_prompt"], inputs["decoder_prompt"] + + return None, inputs diff --git a/vllm/inputs/llm.py b/vllm/inputs/llm.py new file mode 100644 index 000000000..ff22af819 --- /dev/null +++ b/vllm/inputs/llm.py @@ -0,0 +1,222 @@ +"""Schema and utilities for input prompts to the LLM API.""" + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, final + +from typing_extensions import NotRequired, TypedDict + +if TYPE_CHECKING: + import torch + + from vllm.multimodal.inputs import AudioItem, ImageItem, VideoItem, VisionChunk + + +_T = TypeVar("_T") + +ModalityData: TypeAlias = _T | list[_T | None] | None +""" +Either a single data item, or a list of data items. Can only be None if UUID +is provided. + +The number of data items allowed per modality is restricted by +`--limit-mm-per-prompt`. 
+""" + + +@final +class MultiModalDataBuiltins(TypedDict, total=False): + """Type annotations for modality types predefined by vLLM.""" + + image: ModalityData["ImageItem"] + """The input image(s).""" + + video: ModalityData["VideoItem"] + """The input video(s).""" + + audio: ModalityData["AudioItem"] + """The input audio(s).""" + + vision_chunk: ModalityData["VisionChunk"] + """The input visual atom(s) - unified modality for images and video chunks.""" + + +MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]] +""" +A dictionary containing an entry for each modality type to input. + +The built-in modalities are defined by +[`MultiModalDataBuiltins`][vllm.inputs.llm.MultiModalDataBuiltins]. +""" + +MultiModalUUIDDict: TypeAlias = Mapping[str, Sequence[str | None] | str] +""" +A dictionary containing user-provided UUIDs for items in each modality. +If a UUID for an item is not provided, its entry will be `None` and +MultiModalHasher will compute a hash for the item. + +The UUID will be used to identify the item for all caching purposes +(input processing caching, embedding caching, prefix caching, etc). +""" + + +class _PromptOptions(TypedDict): + """ + Additional options available to all + [`SingletonPrompt`][vllm.inputs.llm.SingletonPrompt] types. + """ + + multi_modal_data: NotRequired[MultiModalDataDict | None] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + mm_processor_kwargs: NotRequired[dict[str, Any] | None] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. + """ + + multi_modal_uuids: NotRequired[MultiModalUUIDDict] + """ + Optional user-specified UUIDs for multimodal items, mapped by modality. + Lists must match the number of items per modality and may contain `None`. + For `None` entries, the hasher will compute IDs automatically; non-None + entries override the default hashes for caching, and MUST be unique per + multimodal item. + """ + + cache_salt: NotRequired[str] + """ + Optional cache salt to be used for prefix caching. 
+ """ + + +class TextPrompt(_PromptOptions): + """Schema for a text prompt.""" + + prompt: str + """The input text to be tokenized before passing to the model.""" + + +class TokensPrompt(_PromptOptions): + """Schema for a tokenized prompt.""" + + prompt_token_ids: list[int] + """A list of token IDs to pass to the model.""" + + prompt: NotRequired[str] + """The prompt text corresponding to the token IDs, if available.""" + + token_type_ids: NotRequired[list[int]] + """A list of token type IDs to pass to the cross encoder model.""" + + +class EmbedsPrompt(_PromptOptions): + """Schema for a prompt provided via token embeddings.""" + + prompt_embeds: "torch.Tensor" + """The embeddings of the prompt.""" + + prompt: NotRequired[str] + """The prompt text corresponding to the token embeddings, if available.""" + + +DecoderOnlyPrompt: TypeAlias = ( + str | TextPrompt | list[int] | TokensPrompt | EmbedsPrompt +) +""" +Schema of a prompt for a decoder-only model: + +- A text prompt (string or [`TextPrompt`][vllm.inputs.llm.TextPrompt]) +- A tokenized prompt (list of token IDs, or + [`TokensPrompt`][vllm.inputs.llm.TokensPrompt]) +- An embeddings prompt ([`EmbedsPrompt`][vllm.inputs.llm.EmbedsPrompt]) + +For encoder-decoder models, passing a singleton prompt is shorthand for passing +`ExplicitEncoderDecoderPrompt(encoder_prompt=prompt, decoder_prompt=None)`. +""" + + +EncoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt +""" +Schema of a prompt for the encoder part of a encoder-decoder model: + +- A text prompt (string or [`TextPrompt`][vllm.inputs.llm.TextPrompt]) +- A tokenized prompt (list of token IDs, or + [`TokensPrompt`][vllm.inputs.llm.TokensPrompt]) +""" + + +DecoderPrompt: TypeAlias = str | TextPrompt | list[int] | TokensPrompt +""" +Schema of a prompt for the decoder part of an encoder-decoder model: + +- A text prompt (string or [`TextPrompt`][vllm.inputs.llm.TextPrompt]) +- A tokenized prompt (list of token IDs, or + [`TokensPrompt`][vllm.inputs.llm.TokensPrompt]) + +Note: + Multi-modal inputs are not supported for decoder prompts. +""" + + +class ExplicitEncoderDecoderPrompt(TypedDict): + """ + Schema for a pair of encoder and decoder singleton prompts. + + Note: + This schema is not valid for decoder-only models. + """ + + encoder_prompt: EncoderPrompt + """The prompt for the encoder part of the model.""" + + decoder_prompt: DecoderPrompt | None + """ + The prompt for the decoder part of the model. + + Passing `None` will cause the prompt to be inferred automatically. + """ + + +EncoderDecoderPrompt: TypeAlias = EncoderPrompt | ExplicitEncoderDecoderPrompt +""" +Schema for a prompt for an encoder-decoder model. + +You can pass a singleton encoder prompt, in which case the decoder prompt is +considered to be `None` (i.e., infer automatically). +""" + + +SingletonPrompt: TypeAlias = DecoderOnlyPrompt | EncoderPrompt | DecoderPrompt +""" +Schema for a single prompt. This is as opposed to a data structure +which encapsulates multiple prompts, such as +[`ExplicitEncoderDecoderPrompt`][vllm.inputs.llm.ExplicitEncoderDecoderPrompt]. +""" + + +PromptType: TypeAlias = DecoderOnlyPrompt | EncoderDecoderPrompt +""" +Schema for any prompt, regardless of model type. + +This is the input format accepted by most [`LLM`][vllm.entrypoints.llm.LLM] APIs. +""" + + +class DataPrompt(_PromptOptions): + """ + Represents generic inputs that are converted to + [`PromptType`][vllm.inputs.llm.PromptType] by IO processor plugins. 
+ """ + + data: Any + """The input data.""" + + data_format: str + """The input data format.""" diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py deleted file mode 100644 index ab29935ac..000000000 --- a/vllm/inputs/parse.py +++ /dev/null @@ -1,13 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from .data import ProcessorInputs, SingletonInputs - - -def split_enc_dec_inputs( - inputs: ProcessorInputs, -) -> tuple[SingletonInputs | None, SingletonInputs]: - if inputs["type"] == "enc_dec": - return inputs["encoder_prompt"], inputs["decoder_prompt"] - - return None, inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index a722bb3bf..7722014f9 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -7,14 +7,9 @@ from typing import Any, overload from typing_extensions import assert_never from vllm.config import VllmConfig -from vllm.inputs.data import build_enc_dec_inputs +from vllm.inputs import build_enc_dec_input from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.multimodal.inputs import ( - MultiModalDataDict, - MultiModalInputs, - MultiModalUUIDDict, -) from vllm.renderers import BaseRenderer, renderer_from_config from vllm.renderers.inputs import ( DecoderDictPrompt, @@ -26,20 +21,25 @@ from vllm.renderers.inputs import ( from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt from vllm.tokenizers import TokenizerLike -from .data import ( - DecoderInputs, - DecoderOnlyInputs, - EmbedsInputs, +from .engine import ( + DecoderEngineInput, + DecoderOnlyEngineInput, + EmbedsInput, + EncoderDecoderInput, + EncoderInput, + EngineInput, + MultiModalInput, + SingletonInput, + TokensInput, + tokens_input, +) +from .llm import ( EmbedsPrompt, - EncoderDecoderInputs, - EncoderInputs, - ProcessorInputs, + MultiModalDataDict, + MultiModalUUIDDict, PromptType, - SingletonInputs, TextPrompt, - TokenInputs, TokensPrompt, - token_inputs, ) logger = init_logger(__name__) @@ -95,7 +95,7 @@ class InputPreprocessor: tokenization_kwargs: dict[str, Any] | None = None, *, mm_uuids: MultiModalUUIDDict | None = None, - ) -> MultiModalInputs: + ) -> MultiModalInput: """ Apply the model's multi-modal processor to a multi-modal prompt, returning the corresponding token IDs and metadata. 
@@ -111,7 +111,7 @@ class InputPreprocessor: def _process_embeds( self, parsed_content: EmbedsPrompt, - ) -> EmbedsInputs: + ) -> EmbedsInput: return self.renderer._process_embeds(parsed_content) def _truncate_inputs( @@ -134,12 +134,12 @@ class InputPreprocessor: self, parsed_content: TokensPrompt, tokenization_kwargs: dict[str, Any] | None = None, - ) -> TokenInputs | MultiModalInputs: + ) -> TokensInput | MultiModalInput: prompt_token_ids = self._truncate_inputs( parsed_content["prompt_token_ids"], tokenization_kwargs ) - inputs: TokenInputs | MultiModalInputs + inputs: TokensInput | MultiModalInput if multi_modal_data := parsed_content.get("multi_modal_data"): inputs = self._process_multimodal( prompt_token_ids, @@ -149,7 +149,7 @@ class InputPreprocessor: mm_uuids=parsed_content.get("multi_modal_uuids"), ) else: - inputs = token_inputs(prompt_token_ids) + inputs = tokens_input(prompt_token_ids) if prompt_text := parsed_content.get("prompt"): inputs["prompt"] = prompt_text @@ -162,10 +162,10 @@ class InputPreprocessor: self, parsed_content: TextPrompt, tokenization_kwargs: dict[str, Any] | None = None, - ) -> TokenInputs | MultiModalInputs: + ) -> TokensInput | MultiModalInput: prompt_text = parsed_content["prompt"] - inputs: TokenInputs | MultiModalInputs + inputs: TokensInput | MultiModalInput if multi_modal_data := parsed_content.get("multi_modal_data"): inputs = self._process_multimodal( prompt_text, @@ -178,7 +178,7 @@ class InputPreprocessor: prompt_text, tokenization_kwargs=tokenization_kwargs, ) - inputs = token_inputs(prompt_token_ids) + inputs = tokens_input(prompt_token_ids) inputs["prompt"] = prompt_text @@ -192,38 +192,27 @@ class InputPreprocessor: self, prompt: EncoderDictPrompt, tokenization_kwargs: dict[str, Any] | None = None, - ) -> EncoderInputs: ... + ) -> EncoderInput: ... @overload def _prompt_to_llm_inputs( # type: ignore[misc] self, prompt: DecoderDictPrompt, tokenization_kwargs: dict[str, Any] | None = None, - ) -> DecoderInputs: ... + ) -> DecoderEngineInput: ... @overload def _prompt_to_llm_inputs( # type: ignore[misc] self, prompt: DecoderOnlyDictPrompt, tokenization_kwargs: dict[str, Any] | None = None, - ) -> DecoderOnlyInputs: ... + ) -> DecoderOnlyEngineInput: ... def _prompt_to_llm_inputs( self, prompt: SingletonDictPrompt, tokenization_kwargs: dict[str, Any] | None = None, - ) -> SingletonInputs: - """ - Extract the singleton inputs from a prompt. - - Arguments: - - * prompt: single encoder or decoder input prompt - - Returns: - - * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance - """ + ) -> SingletonInput: if "prompt_embeds" in prompt: return self._process_embeds(prompt) # type: ignore[arg-type] @@ -242,22 +231,7 @@ class InputPreprocessor: self, prompt: EncoderDecoderDictPrompt, tokenization_kwargs: dict[str, Any] | None = None, - ) -> EncoderDecoderInputs: - """ - For encoder/decoder models only: - Process an input prompt into an - [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] - instance. 
- - Arguments: - - * prompt: an input prompt - - Returns: - - * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] - instance - """ + ) -> EncoderDecoderInput: encoder_prompt = prompt["encoder_prompt"] decoder_prompt = prompt["decoder_prompt"] @@ -270,12 +244,12 @@ class InputPreprocessor: self.renderer.mm_processor.skip_decoder_start_token ) - return build_enc_dec_inputs( - encoder_inputs=self._prompt_to_llm_inputs( + return build_enc_dec_input( + encoder_input=self._prompt_to_llm_inputs( encoder_prompt, tokenization_kwargs=tokenization_kwargs, ), - decoder_inputs=( + decoder_input=( None if decoder_prompt is None else self._prompt_to_llm_inputs( @@ -291,20 +265,7 @@ class InputPreprocessor: self, prompt: DecoderOnlyDictPrompt, tokenization_kwargs: dict[str, Any] | None = None, - ) -> DecoderOnlyInputs: - """ - For decoder-only models: - Process an input prompt into a - [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance. - - Arguments: - - * prompt: input prompt - - Returns: - - * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance - """ + ) -> DecoderOnlyEngineInput: return self._prompt_to_llm_inputs( prompt, tokenization_kwargs=tokenization_kwargs, @@ -314,7 +275,7 @@ class InputPreprocessor: self, prompt: PromptType, tokenization_kwargs: dict[str, Any] | None = None, - ) -> ProcessorInputs: + ) -> EngineInput: """Preprocess the input prompt.""" if self.model_config.is_encoder_decoder: # Encoder-decoder model requires special mapping of diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 908581786..7b891f8ee 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -12,6 +12,7 @@ from transformers.models.aria.processing_aria import AriaProcessor from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_rank +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear @@ -24,7 +25,6 @@ from vllm.model_executor.model_loader.weight_utils import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/audioflamingo3.py b/vllm/model_executor/models/audioflamingo3.py index 82906a6fa..9fbbd7be7 100644 --- a/vllm/model_executor/models/audioflamingo3.py +++ b/vllm/model_executor/models/audioflamingo3.py @@ -31,17 +31,16 @@ from transformers.models.qwen2_audio import Qwen2AudioEncoder from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import ModalityData, MultiModalDataDict from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) from vllm.multimodal.parse import ( DictEmbeddingItems, - ModalityData, ModalityDataItems, MultiModalDataItems, MultiModalDataParser, diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index c1806beec..65f306393 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -17,9 
+17,9 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import ( from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/bagel.py b/vllm/model_executor/models/bagel.py index 425342e8b..97cfe75e6 100644 --- a/vllm/model_executor/models/bagel.py +++ b/vllm/model_executor/models/bagel.py @@ -15,6 +15,7 @@ import torch.nn as nn from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import ( @@ -24,7 +25,6 @@ from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/bee.py b/vllm/model_executor/models/bee.py index ecb645edf..af44c34f4 100644 --- a/vllm/model_executor/models/bee.py +++ b/vllm/model_executor/models/bee.py @@ -9,8 +9,8 @@ from transformers.activations import GELUActivation from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict from .llava_next import ( LlavaDummyInputsBuilder, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 8f79c1aae..8b5fd452e 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -15,11 +15,11 @@ from transformers import ( from vllm.config import CacheConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index e09a4eac7..a150428ba 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -19,6 +19,7 @@ from transformers import ( from vllm.config import CacheConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import Attention @@ -43,7 +44,6 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 597f6a8c1..05a494683 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -17,6 
+17,7 @@ from transformers import ( from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.inputs import MultiModalDataDict, MultiModalInput from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import Attention, MMEncoderAttention from vllm.model_executor.layers.conv import Conv2dLayer @@ -32,9 +33,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsQuant from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems, ) from vllm.multimodal.parse import ( @@ -207,7 +206,7 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]): self, inputs: ProcessorInputs, timing_ctx: TimingContext, - ) -> MultiModalInputs: + ) -> MultiModalInput: if inputs.mm_data_items: if isinstance(inputs.prompt, str): if len(inputs.prompt) > 0: diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index 69b2abb5f..c3118ee77 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -19,6 +19,7 @@ from transformers.models.cohere2_vision.processing_cohere2_vision import ( from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import MulAndSilu from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, @@ -28,7 +29,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/cohere_asr.py b/vllm/model_executor/models/cohere_asr.py index 716215a34..2f8513823 100644 --- a/vllm/model_executor/models/cohere_asr.py +++ b/vllm/model_executor/models/cohere_asr.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, cast +from typing import Literal import numpy as np import torch @@ -14,7 +14,7 @@ from transformers import PretrainedConfig from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs.data import PromptType +from vllm.inputs import MultiModalDataDict, PromptType, TextPrompt from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import ( @@ -32,7 +32,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) @@ -2047,14 +2046,11 @@ class CohereASRForConditionalGeneration( f"<|noitn|><|notimestamp|><|nodiarize|>" ) prompt_text = request_prompt if request_prompt else default_prompt - prompt = { - "prompt": prompt_text, - "multi_modal_data": { - "audio": (audio, 
stt_config.sample_rate), - }, - } - return cast(PromptType, prompt) + return TextPrompt( + prompt=prompt_text, + multi_modal_data={"audio": (audio, stt_config.sample_rate)}, + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: diff --git a/vllm/model_executor/models/colmodernvbert.py b/vllm/model_executor/models/colmodernvbert.py index 39dca6edd..1e8477e12 100644 --- a/vllm/model_executor/models/colmodernvbert.py +++ b/vllm/model_executor/models/colmodernvbert.py @@ -16,11 +16,11 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/deepseek_ocr.py b/vllm/model_executor/models/deepseek_ocr.py index 756d7acde..2575d3dcd 100644 --- a/vllm/model_executor/models/deepseek_ocr.py +++ b/vllm/model_executor/models/deepseek_ocr.py @@ -12,6 +12,7 @@ from transformers import BatchFeature, CLIPVisionConfig from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.models.interfaces import ( MultiModalEmbeddings, SupportsLoRA, @@ -27,7 +28,6 @@ from vllm.model_executor.models.utils import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors, diff --git a/vllm/model_executor/models/deepseek_ocr2.py b/vllm/model_executor/models/deepseek_ocr2.py index d76e2aa40..70f50ea7d 100644 --- a/vllm/model_executor/models/deepseek_ocr2.py +++ b/vllm/model_executor/models/deepseek_ocr2.py @@ -12,6 +12,7 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.models.interfaces import ( MultiModalEmbeddings, SupportsLoRA, @@ -27,7 +28,6 @@ from vllm.model_executor.models.utils import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors, diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 469d7fb71..b5ad00914 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -17,11 +17,11 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.transformers.utils import replace_linear_class from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 25b4087d3..b5c1616ac 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -15,6 +15,7 @@ from 
vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import ( MMEncoderAttention, @@ -54,7 +55,6 @@ from vllm.model_executor.models.utils import ( ) from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig, DotsVisionConfig from vllm.utils.tensor_schema import TensorSchema, TensorShape diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 87d33d1b7..08a4c4862 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -40,6 +40,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils +from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.attention import ( @@ -58,7 +59,6 @@ from vllm.model_executor.layers.rotary_embedding.common import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, diff --git a/vllm/model_executor/models/fireredasr2.py b/vllm/model_executor/models/fireredasr2.py index 26ede3e80..217bb5b2d 100644 --- a/vllm/model_executor/models/fireredasr2.py +++ b/vllm/model_executor/models/fireredasr2.py @@ -15,7 +15,7 @@ from transformers import ( from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions -from vllm.inputs.data import PromptType +from vllm.inputs import MultiModalDataDict, PromptType from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.linear import ( @@ -27,7 +27,6 @@ from vllm.model_executor.models.whisper_utils import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/funasr.py b/vllm/model_executor/models/funasr.py index 78acca3c2..98313db79 100644 --- a/vllm/model_executor/models/funasr.py +++ b/vllm/model_executor/models/funasr.py @@ -17,7 +17,7 @@ from transformers import ( from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs.data import PromptType +from vllm.inputs import MultiModalDataDict, PromptType from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.attention.mm_encoder_attention import ( @@ -37,7 +37,6 @@ from vllm.model_executor.models.whisper_utils import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git 
a/vllm/model_executor/models/funaudiochat.py b/vllm/model_executor/models/funaudiochat.py index 2265d0424..9557ca680 100644 --- a/vllm/model_executor/models/funaudiochat.py +++ b/vllm/model_executor/models/funaudiochat.py @@ -27,12 +27,12 @@ from transformers.modeling_outputs import BaseModelOutput from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index cc15cee59..3c0b911c3 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -28,11 +28,11 @@ from transformers import BatchFeature, FuyuConfig, FuyuImageProcessor, FuyuProce from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index cbc5ebc7d..0f059b6d1 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -12,12 +12,12 @@ from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index 4b6f53788..342d6c476 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -19,7 +19,7 @@ from transformers.models.siglip import SiglipImageProcessorFast from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions -from vllm.inputs.data import PromptType, TextPrompt +from vllm.inputs import MultiModalDataDict, PromptType, TextPrompt from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import RowParallelLinear @@ -32,7 +32,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 786b1175c..3275329f0 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ 
b/vllm/model_executor/models/glm4_1v.py @@ -50,6 +50,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state from vllm.distributed import utils as dist_utils +from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.attention import ( MMEncoderAttention, @@ -74,7 +75,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 83af8ea86..9d08df4df 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -18,6 +18,7 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.attention import MMEncoderAttention from vllm.model_executor.layers.conv import Conv2dLayer @@ -32,7 +33,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py index fd47a014a..0d54588cc 100644 --- a/vllm/model_executor/models/glmasr.py +++ b/vllm/model_executor/models/glmasr.py @@ -14,7 +14,7 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size -from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs import ModalityData, MultiModalDataDict, PromptType, TokensPrompt from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import MMEncoderAttention from vllm.model_executor.layers.linear import ( @@ -27,13 +27,11 @@ from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) from vllm.multimodal.parse import ( DictEmbeddingItems, - ModalityData, ModalityDataItems, MultiModalDataItems, MultiModalDataParser, diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index b97fc67f1..dca54425c 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -36,13 +36,12 @@ from transformers import BatchFeature, PretrainedConfig from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions -from vllm.inputs.data import PromptType, TokensPrompt +from 
vllm.inputs import MultiModalDataDict, PromptType, TokensPrompt from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/hunyuan_vision.py b/vllm/model_executor/models/hunyuan_vision.py index ec0f10ea6..5125406ab 100644 --- a/vllm/model_executor/models/hunyuan_vision.py +++ b/vllm/model_executor/models/hunyuan_vision.py @@ -37,6 +37,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils +from vllm.inputs import ModalityData, MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import MMEncoderAttention @@ -52,8 +53,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( ImageItem, - ModalityData, - MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index f0eeed7f1..53923d884 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -18,10 +18,10 @@ from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/hyperclovax_vision_v2.py b/vllm/model_executor/models/hyperclovax_vision_v2.py index 40b459a64..6cfeac67a 100644 --- a/vllm/model_executor/models/hyperclovax_vision_v2.py +++ b/vllm/model_executor/models/hyperclovax_vision_v2.py @@ -21,9 +21,9 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.forward_context import set_forward_context +from vllm.inputs import MultiModalDataDict from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index a59c45654..d3ffdd4cf 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -30,6 +30,7 @@ from transformers import ( from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -37,7 +38,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.models.module_mapping import MultiModelKeys from 
vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index eb7c9693a..fa59931d0 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -30,8 +30,7 @@ from transformers.models.whisper.tokenization_whisper import LANGUAGES from typing_extensions import Self, TypeIs from vllm.config import ModelConfig, SpeechToTextConfig -from vllm.inputs import TokensPrompt -from vllm.inputs.data import PromptType +from vllm.inputs import PromptType, TokensPrompt from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc from vllm.model_executor.layers.quantization import QuantizationConfig diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py index e1e67b047..8b3828a91 100644 --- a/vllm/model_executor/models/interns1.py +++ b/vllm/model_executor/models/interns1.py @@ -23,12 +23,12 @@ from transformers.models.internvl.video_processing_internvl import ( from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.interns1_vit import InternS1VisionModel from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 5cb7f462d..c8611a499 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -18,6 +18,7 @@ from transformers import BatchFeature, PretrainedConfig from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.models.intern_vit import ( @@ -28,7 +29,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( BatchedTensorInputs, - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/isaac.py b/vllm/model_executor/models/isaac.py index e29646182..412948c48 100644 --- a/vllm/model_executor/models/isaac.py +++ b/vllm/model_executor/models/isaac.py @@ -16,6 +16,7 @@ from vllm.config import ModelConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.attention import MMEncoderAttention from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -44,7 +45,6 @@ from vllm.model_executor.models.utils import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, diff --git a/vllm/model_executor/models/kanana_v.py b/vllm/model_executor/models/kanana_v.py index 991fa28d9..125d7e71c 100644 --- 
a/vllm/model_executor/models/kanana_v.py +++ b/vllm/model_executor/models/kanana_v.py @@ -19,10 +19,10 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionCon from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 5e062fa74..ebeed00e4 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -19,6 +19,7 @@ from transformers.utils import torch_int from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import ModalityData, MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.attention import ( MMEncoderAttention, @@ -41,8 +42,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( ImageItem, - ModalityData, - MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py index d304b245e..5c9bdcbf5 100644 --- a/vllm/model_executor/models/keye_vl1_5.py +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -14,13 +14,13 @@ from transformers.activations import GELUActivation from transformers.feature_extraction_utils import BatchFeature from vllm.config import VllmConfig +from vllm.inputs import ModalityData from vllm.logger import init_logger from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( ImageItem, - ModalityData, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, diff --git a/vllm/model_executor/models/kimi_audio.py b/vllm/model_executor/models/kimi_audio.py index 05a20950c..fc5065065 100644 --- a/vllm/model_executor/models/kimi_audio.py +++ b/vllm/model_executor/models/kimi_audio.py @@ -14,7 +14,7 @@ from transformers import WhisperConfig as HFWhisperConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions -from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs import PromptType, TokensPrompt from vllm.model_executor.model_loader import DefaultModelLoader from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import ( diff --git a/vllm/model_executor/models/kimi_k25.py b/vllm/model_executor/models/kimi_k25.py index 10d21aab0..1c0320c0d 100644 --- a/vllm/model_executor/models/kimi_k25.py +++ b/vllm/model_executor/models/kimi_k25.py @@ -16,6 +16,7 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors import ( @@ -35,7 +36,6 @@ from vllm.model_executor.models.kimi_k25_vit 
import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors, diff --git a/vllm/model_executor/models/kimi_vl.py b/vllm/model_executor/models/kimi_vl.py index 4ff8f11ab..e3bc08c65 100644 --- a/vllm/model_executor/models/kimi_vl.py +++ b/vllm/model_executor/models/kimi_vl.py @@ -54,12 +54,12 @@ from transformers.activations import GELUActivation from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP from vllm.model_executor.models.moonvit import MoonVitPretrainedModel from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors, diff --git a/vllm/model_executor/models/lfm2_vl.py b/vllm/model_executor/models/lfm2_vl.py index 63f546c5a..9be8c5c1e 100644 --- a/vllm/model_executor/models/lfm2_vl.py +++ b/vllm/model_executor/models/lfm2_vl.py @@ -21,6 +21,7 @@ from transformers.models.lfm2_vl.image_processing_lfm2_vl_fast import ( from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.forward_context import set_forward_context +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateCopyFunc, MambaStateCopyFuncCalculator, @@ -30,7 +31,6 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 450af2587..2a5067867 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -20,17 +20,15 @@ from transformers.models.pixtral import PixtralProcessor from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict, MultiModalInput, mm_input from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems, - mm_inputs, ) from vllm.multimodal.parse import ( ImageEmbeddingItems, @@ -777,7 +775,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): self, inputs: ProcessorInputs, timing_ctx: TimingContext, - ) -> MultiModalInputs: + ) -> MultiModalInput: hf_config = self.info.get_hf_config() image_token_id = hf_config.image_token_index @@ -833,7 +831,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): for modality, placeholders in mm_placeholders.items() } - return mm_inputs( + return mm_input( prompt_token_ids=prompt_ids, mm_kwargs=mm_kwargs, mm_hashes=mm_hashes, diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 54558e123..e256768ef 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ 
b/vllm/model_executor/models/llava_next_video.py @@ -11,11 +11,11 @@ from transformers import BatchFeature, LlavaNextVideoConfig, LlavaNextVideoProce from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.clip import CLIPVisionModel from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index f747df09c..638d9ba9d 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -15,10 +15,10 @@ from transformers.models.llava_onevision.modeling_llava_onevision import ( from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import get_act_fn from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py index 08b955c81..62aaed46f 100644 --- a/vllm/model_executor/models/midashenglm.py +++ b/vllm/model_executor/models/midashenglm.py @@ -38,6 +38,7 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ( @@ -48,7 +49,6 @@ from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index f176e50f8..9251b1472 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -41,9 +41,9 @@ from transformers.models.whisper.modeling_whisper import ( from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import ModalityData, MultiModalDataDict from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, NestedTensors, ) @@ -51,7 +51,6 @@ from vllm.multimodal.parse import ( AudioItem, AudioProcessorItems, DictEmbeddingItems, - ModalityData, ModalityDataItems, MultiModalDataItems, ) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index bb7f8490d..79162eef3 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -41,6 +41,7 @@ from typing_extensions import TypeVar from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import ModalityData, MultiModalDataDict from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import ( BaseResampler, @@ -54,7 +55,6 @@ from 
vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors, @@ -64,7 +64,6 @@ from vllm.multimodal.parse import ( ImageItem, ImageProcessorItems, ImageSize, - ModalityData, ModalityDataItems, MultiModalDataItems, MultiModalDataParser, diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 2c12d5a75..0ece3dda2 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -11,6 +11,7 @@ from transformers.models.pixtral import PixtralProcessor from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear @@ -18,7 +19,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index c8cbb5890..227ef2fa6 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -38,6 +38,7 @@ from vllm.compilation.decorators import ( from vllm.config import VllmConfig, set_current_vllm_config from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.attention import MMEncoderAttention from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import ( @@ -53,7 +54,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index faac00a4e..1d756a2ad 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -29,6 +29,7 @@ from vllm.distributed import ( split_tensor_along_last_dim, tensor_model_parallel_all_gather, ) +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import MulAndSilu, QuickGELU, SiluAndMul from vllm.model_executor.layers.attention import Attention, MMEncoderAttention from vllm.model_executor.layers.layernorm import RMSNorm @@ -49,7 +50,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/molmo2.py b/vllm/model_executor/models/molmo2.py index 7d7fa38b5..aa58fa6d1 100644 --- a/vllm/model_executor/models/molmo2.py +++ b/vllm/model_executor/models/molmo2.py @@ -33,6 +33,7 @@ from vllm.distributed import 
( split_tensor_along_last_dim, tensor_model_parallel_all_gather, ) +from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.activation import MulAndSilu, SiluAndMul, get_act_fn from vllm.model_executor.layers.attention import Attention, MMEncoderAttention @@ -54,7 +55,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, VideoItem, diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 1741e18fd..182e0f159 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -21,6 +21,7 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions +from vllm.inputs import MultiModalDataDict, MultiModalInput from vllm.logger import init_logger from vllm.model_executor.layers.activation import ReLUSquaredActivation from vllm.model_executor.layers.layernorm import RMSNorm @@ -48,9 +49,7 @@ from vllm.multimodal.evs import ( from vllm.multimodal.inputs import ( AudioItem, BatchedTensorInputs, - MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems, VideoItem, ) @@ -576,7 +575,7 @@ class NanoNemotronVLMultiModalProcessor( self, inputs: ProcessorInputs, timing_ctx: TimingContext, - ) -> MultiModalInputs: + ) -> MultiModalInput: use_audio_in_video = bool( inputs.hf_processor_mm_kwargs.get("use_audio_in_video", False) ) @@ -632,7 +631,7 @@ class NanoNemotronVLMultiModalProcessor( for modality, placeholders in mm_placeholders.items() } - return MultiModalInputs( + return MultiModalInput( type="multimodal", prompt_token_ids=prompt_ids, mm_kwargs=mm_info.kwargs, diff --git a/vllm/model_executor/models/nemotron_parse.py b/vllm/model_executor/models/nemotron_parse.py index f4837185f..ae417f095 100644 --- a/vllm/model_executor/models/nemotron_parse.py +++ b/vllm/model_executor/models/nemotron_parse.py @@ -23,6 +23,7 @@ from transformers import ( from vllm.config import CacheConfig, VllmConfig from vllm.config.lora import LoRAConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear @@ -41,7 +42,6 @@ from vllm.model_executor.models.radio import RadioModel from vllm.model_executor.models.whisper import WhisperAttention, WhisperCrossAttention from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py index ea8b083ff..9fd4cf079 100644 --- a/vllm/model_executor/models/nvlm_d.py +++ b/vllm/model_executor/models/nvlm_d.py @@ -14,12 +14,10 @@ import torch.nn as nn from transformers import PretrainedConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import ( - BatchedTensorInputs, - 
MultiModalDataDict, -) +from vllm.multimodal.inputs import BatchedTensorInputs from vllm.multimodal.parse import ( ImageEmbeddingItems, ImageProcessorItems, diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py index 2807c634b..3d9cf1c34 100644 --- a/vllm/model_executor/models/ovis.py +++ b/vllm/model_executor/models/ovis.py @@ -30,6 +30,7 @@ from transformers import BatchFeature, PretrainedConfig from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.aimv2 import AIMv2Model @@ -42,7 +43,6 @@ from vllm.model_executor.models.utils import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index 57559ba99..4acad73c5 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -12,6 +12,7 @@ from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.ovis import VisualEmbedding @@ -24,7 +25,6 @@ from vllm.model_executor.models.utils import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 33b54185c..2a5a0e6a1 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -35,6 +35,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.attention import ( MMEncoderAttention, ) @@ -53,7 +54,6 @@ from vllm.model_executor.model_loader.weight_utils import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 90db5d695..d7b8e77c6 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -9,12 +9,11 @@ from transformers import BatchFeature, PaliGemmaConfig from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict, MultiModalInput from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems, ) from vllm.multimodal.parse import ( @@ -231,7 +230,7 @@ class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingIn self, inputs: ProcessorInputs, timing_ctx: TimingContext, - ) -> MultiModalInputs: + ) -> MultiModalInput: mm_inputs = 
super().apply(inputs, timing_ctx) prompt_token_ids = mm_inputs["prompt_token_ids"] diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index cb1e0ab83..95689ef32 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -30,12 +30,12 @@ from transformers import ( from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 5ccac92e3..2db95b857 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -18,6 +18,7 @@ from transformers import ( from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_pp_group +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -27,7 +28,6 @@ from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index eaf5843a3..75545e44e 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -25,6 +25,7 @@ from transformers.models.pixtral.modeling_pixtral import ( from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.layernorm import RMSNorm @@ -37,7 +38,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, NestedTensors, ) diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index ff7dbb703..8e106baec 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -46,6 +46,7 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.forward_context import set_forward_context +from vllm.inputs import ModalityData, MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2_5_vl import ( @@ -66,8 +67,6 @@ from vllm.model_executor.models.qwen2_vl import 
Qwen2VLMultiModalDataParser from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( ImageItem, - ModalityData, - MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index d125570a1..e7e8d7471 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -38,11 +38,10 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import ModalityData, MultiModalDataDict from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( AudioItem, - ModalityData, - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index a8840022a..176f45781 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -47,6 +47,7 @@ from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils +from vllm.inputs import ModalityData, MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.attention import MMEncoderAttention @@ -65,8 +66,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( ImageItem, - ModalityData, - MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, diff --git a/vllm/model_executor/models/qwen3_asr.py b/vllm/model_executor/models/qwen3_asr.py index 5c7b4a567..3015ae031 100644 --- a/vllm/model_executor/models/qwen3_asr.py +++ b/vllm/model_executor/models/qwen3_asr.py @@ -33,7 +33,7 @@ from transformers.models.whisper import WhisperFeatureExtractor from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions -from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs import ModalityData, MultiModalDataDict, PromptType, TokensPrompt from vllm.logger import init_logger from vllm.model_executor.models.interfaces import ( MultiModalEmbeddings, @@ -59,8 +59,6 @@ from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( AudioItem, - ModalityData, - MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItems, diff --git a/vllm/model_executor/models/qwen3_asr_realtime.py b/vllm/model_executor/models/qwen3_asr_realtime.py index 4fb6ef5d9..6405c101b 100644 --- a/vllm/model_executor/models/qwen3_asr_realtime.py +++ b/vllm/model_executor/models/qwen3_asr_realtime.py @@ -23,7 +23,7 @@ import numpy as np import torch from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig -from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs import PromptType, TokensPrompt from vllm.logger import init_logger from vllm.model_executor.models.interfaces import ( SupportsRealtime, diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index fc097ffdd..d9bc02c65 100755 --- 
a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -31,6 +31,8 @@ import torch import torch.nn as nn import torch.nn.functional as F from packaging.version import Version +from transformers import PretrainedConfig +from transformers import __version__ as TRANSFORMERS_VERSION from transformers.feature_extraction_utils import BatchFeature from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import ( Qwen3OmniMoeAudioEncoderConfig, @@ -42,15 +44,10 @@ from transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe import ( ) from transformers.models.whisper import WhisperFeatureExtractor -# isort: off -from transformers import PretrainedConfig -from transformers import __version__ as TRANSFORMERS_VERSION -# isort: on - from vllm.compilation.decorators import support_torch_compile from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs.data import PromptType +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.attention.mm_encoder_attention import ( diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 55841e30e..418b75e38 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -52,6 +52,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import get_pp_group, parallel_state +from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.attention.mm_encoder_attention import ( @@ -76,7 +77,6 @@ from vllm.multimodal.evs import ( recompute_mrope_positions, ) from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalFieldElem, diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 335b62e2b..e2232956e 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -18,6 +18,7 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ( @@ -30,7 +31,6 @@ from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/rvl.py b/vllm/model_executor/models/rvl.py index 72f68659c..bf9661bca 100644 --- a/vllm/model_executor/models/rvl.py +++ b/vllm/model_executor/models/rvl.py @@ -9,8 +9,8 @@ from transformers.activations import GELUActivation from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict from 
.llava_next import ( LlavaDummyInputsBuilder, diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 8b7dfd51c..ce3a260d0 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -18,6 +18,7 @@ from transformers import ( from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.inputs import MultiModalDataDict, MultiModalInput from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import ( EncoderOnlyAttention, @@ -38,9 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems, ) from vllm.multimodal.parse import ( @@ -193,7 +192,7 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]): self, inputs: ProcessorInputs, timing_ctx: TimingContext, - ) -> MultiModalInputs: + ) -> MultiModalInput: if inputs.mm_data_items: if isinstance(inputs.prompt, str): if len(inputs.prompt) > 0: diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index a1666c647..a3415a20a 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -16,6 +16,7 @@ from transformers import PretrainedConfig from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig @@ -24,7 +25,6 @@ from vllm.model_executor.models.intern_vit import ( InternVisionPatchModel, ) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalDataDict from vllm.multimodal.processing import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processors.internvl import ( diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 9a0d6d215..dc4b42961 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -13,6 +13,7 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import MMEncoderAttention from vllm.model_executor.layers.conv import Conv2dLayer @@ -24,7 +25,6 @@ from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index 1b63c55f9..e863b0bb5 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -34,6 +34,7 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs 
import ModalityData, MultiModalDataDict, MultiModalInput, mm_input from vllm.logger import init_logger from vllm.model_executor.layers.pooler import IdentityPooler from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -41,13 +42,9 @@ from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( ImageItem, - ModalityData, - MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargsItems, PlaceholderRange, - mm_inputs, ) from vllm.multimodal.parse import ( DictEmbeddingItems, @@ -196,7 +193,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing self, inputs: ProcessorInputs, timing_ctx: TimingContext, - ) -> MultiModalInputs: + ) -> MultiModalInput: mm_items = inputs.mm_data_items hf_processor_mm_kwargs = inputs.hf_processor_mm_kwargs @@ -224,7 +221,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} - return mm_inputs( + return mm_input( prompt_token_ids=[1], mm_kwargs=mm_kwargs, mm_hashes=mm_hashes, diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py index 9ad271427..ddcd91f61 100644 --- a/vllm/model_executor/models/transformers/multimodal.py +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -22,16 +22,14 @@ from typing import TYPE_CHECKING import torch from vllm.config.utils import getattr_iter +from vllm.inputs import MultiModalDataDict, MultiModalInput, mm_input from vllm.logger import init_logger from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal from vllm.multimodal import MultiModalKwargsItems from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, - MultiModalInputs, PlaceholderRange, - mm_inputs, ) from vllm.multimodal.parse import ( ImageProcessorItems, @@ -179,7 +177,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): self, inputs: ProcessorInputs, timing_ctx: TimingContext, - ) -> MultiModalInputs: + ) -> MultiModalInput: """ Process multi-modal inputs to be used in vLLM. 
@@ -261,7 +259,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): with timing_ctx.record("get_mm_hashes"): mm_hashes = inputs.get_mm_hashes(self.info.model_id) - return mm_inputs( + return mm_input( prompt_token_ids=prompt_ids, mm_kwargs=mm_kwargs, mm_hashes=mm_hashes, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index a66bda3c1..83241b329 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -23,13 +23,13 @@ from transformers.models.whisper.modeling_whisper import ( from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs import MultiModalDataDict from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.model_loader import DefaultModelLoader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors, diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index dba52d106..aa9116906 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -19,7 +19,7 @@ from transformers import BatchFeature, WhisperConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions -from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs import MultiModalDataDict, PromptType, TokensPrompt from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -32,7 +32,6 @@ from vllm.model_executor.models.whisper import ( from vllm.model_executor.models.whisper_causal import WhisperCausalEncoder from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, NestedTensors, diff --git a/vllm/model_executor/models/voxtral_realtime.py b/vllm/model_executor/models/voxtral_realtime.py index bb2c701e9..b70714a0d 100644 --- a/vllm/model_executor/models/voxtral_realtime.py +++ b/vllm/model_executor/models/voxtral_realtime.py @@ -20,7 +20,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.engine.protocol import StreamingInput from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S -from vllm.inputs.data import PromptType, TokensPrompt +from vllm.inputs import PromptType, TokensPrompt from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime from vllm.model_executor.models.voxtral import ( @@ -31,9 +31,7 @@ from vllm.model_executor.models.voxtral import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import _I, BaseMultiModalProcessorCache -from vllm.multimodal.inputs import ( - MultiModalKwargsOptionalItems, -) +from vllm.multimodal.inputs import MultiModalKwargsOptionalItems from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import BaseDummyInputsBuilder from vllm.multimodal.processing.processor import ( diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 
631a829cf..f0f6f619b 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -21,7 +21,12 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType, TextPrompt +from vllm.inputs import ( + ExplicitEncoderDecoderPrompt, + MultiModalDataDict, + PromptType, + TextPrompt, +) from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import ( @@ -44,7 +49,6 @@ from vllm.model_executor.models.whisper_utils import ( ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, ) diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index be28c728c..34438a4fb 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,16 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from .hasher import MultiModalHasher -from .inputs import ( - BatchedTensorInputs, - ModalityData, - MultiModalDataBuiltins, - MultiModalDataDict, - MultiModalKwargsItems, - MultiModalPlaceholderDict, - MultiModalUUIDDict, - NestedTensors, -) +from .inputs import BatchedTensorInputs, MultiModalKwargsItems, NestedTensors from .registry import MultiModalRegistry MULTIMODAL_REGISTRY = MultiModalRegistry() @@ -25,13 +16,8 @@ Info: __all__ = [ "BatchedTensorInputs", - "ModalityData", - "MultiModalDataBuiltins", - "MultiModalDataDict", "MultiModalHasher", "MultiModalKwargsItems", - "MultiModalPlaceholderDict", - "MultiModalUUIDDict", "NestedTensors", "MULTIMODAL_REGISTRY", "MultiModalRegistry", diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 1e25142f3..750893272 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -15,12 +15,11 @@ from typing import ( TypedDict, Union, cast, - final, ) import numpy as np from PIL.Image import Image -from typing_extensions import NotRequired, TypeVar +from typing_extensions import TypeVar from vllm.utils.collection_utils import is_list_of from vllm.utils.import_utils import LazyLoader @@ -32,14 +31,9 @@ if TYPE_CHECKING: import torch import torch.types from transformers.feature_extraction_utils import BatchFeature - - from vllm.inputs.data import _InputOptions else: torch = LazyLoader("torch", globals(), "torch") - _InputOptions = dict - -_T = TypeVar("_T") HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"] """ @@ -98,15 +92,6 @@ which are treated as audio embeddings; these are directly passed to the model without HF processing. """ -ModalityData: TypeAlias = _T | list[_T | None] | None -""" -Either a single data item, or a list of data items. Can only be None if UUID -is provided. - -The number of data items allowed per modality is restricted by -`--limit-mm-per-prompt`. 
-""" - class VisionChunkImage(TypedDict): """Represents an image wrapped as a vision chunk.""" @@ -126,46 +111,10 @@ class VisionChunkVideo(TypedDict): video_idx: int -VisionChunk = VisionChunkImage | VisionChunkVideo +VisionChunk: TypeAlias = VisionChunkImage | VisionChunkVideo """A vision chunk is either an image or a video chunk.""" -@final -class MultiModalDataBuiltins(TypedDict, total=False): - """Type annotations for modality types predefined by vLLM.""" - - image: ModalityData[ImageItem] - """The input image(s).""" - - video: ModalityData[VideoItem] - """The input video(s).""" - - audio: ModalityData[AudioItem] - """The input audio(s).""" - - vision_chunk: ModalityData[VisionChunk] - """The input visual atom(s) - unified modality for images and video chunks.""" - - -MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]] -""" -A dictionary containing an entry for each modality type to input. - -The built-in modalities are defined by -[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins]. -""" - -MultiModalUUIDDict: TypeAlias = Mapping[str, Sequence[str | None] | str] -""" -A dictionary containing user-provided UUIDs for items in each modality. -If a UUID for an item is not provided, its entry will be `None` and -MultiModalHasher will compute a hash for the item. - -The UUID will be used to identify the item for all caching purposes -(input processing caching, embedding caching, prefix caching, etc). -""" - - @dataclass(frozen=True) class PlaceholderRange: """ @@ -1048,112 +997,3 @@ MultiModalKwargsOptionalItems: TypeAlias = ( MultiModalKwargsItems[MultiModalKwargsItem] | MultiModalKwargsItems[MultiModalKwargsItem | None] ) - - -MultiModalHashes = dict[str, list[str]] -""" -A dictionary containing per-item hashes for each modality. -""" - - -MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]] -""" -A dictionary containing per-item placeholder ranges for each modality. -""" - - -class MultiModalInputs(_InputOptions): - """ - Represents the outputs of - [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor], - ready to be passed to vLLM internals. - """ - - type: Literal["multimodal"] - """The type of inputs.""" - - prompt_token_ids: list[int] - """The processed token IDs which includes placeholder tokens.""" - - prompt: NotRequired[str] - """The prompt text corresponding to the token IDs, if available.""" - - mm_kwargs: MultiModalKwargsOptionalItems - """Keyword arguments to be directly passed to the model after batching.""" - - mm_hashes: MultiModalHashes - """The hashes of the multi-modal data.""" - - mm_placeholders: MultiModalPlaceholderDict - """ - For each modality, information about the placeholder tokens in - `prompt_token_ids`. 
- """ - - -def mm_inputs( - prompt_token_ids: list[int], - mm_kwargs: MultiModalKwargsOptionalItems, - mm_hashes: MultiModalHashes, - mm_placeholders: MultiModalPlaceholderDict, - *, - prompt: str | None = None, - cache_salt: str | None = None, -) -> MultiModalInputs: - inputs = MultiModalInputs( - type="multimodal", - prompt_token_ids=prompt_token_ids, - mm_kwargs=mm_kwargs, - mm_hashes=mm_hashes, - mm_placeholders=mm_placeholders, - ) - - if prompt is not None: - inputs["prompt"] = prompt - if cache_salt is not None: - inputs["cache_salt"] = cache_salt - - return inputs - - -class MultiModalEncDecInputs(MultiModalInputs): - """ - Represents the outputs of - [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor] - ready to be passed to vLLM internals. - - Note: Even text-only encoder-decoder models are currently implemented - as multi-modal models for convenience. - (Example: https://github.com/vllm-project/bart-plugin) - """ - - encoder_prompt_token_ids: list[int] - """The processed token IDs of the encoder prompt.""" - - encoder_prompt: NotRequired[str] - """The prompt text corresponding to the encoder token IDs, if available.""" - - -def mm_enc_dec_inputs( - encoder_inputs: MultiModalInputs, - decoder_prompt_token_ids: list[int], - *, - decoder_prompt: str | None = None, -) -> MultiModalEncDecInputs: - inputs = MultiModalEncDecInputs( - type="multimodal", - prompt_token_ids=decoder_prompt_token_ids, - encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"], - mm_kwargs=encoder_inputs["mm_kwargs"], - mm_hashes=encoder_inputs["mm_hashes"], - mm_placeholders=encoder_inputs["mm_placeholders"], - ) - - if decoder_prompt is not None: - inputs["prompt"] = decoder_prompt - if "prompt" in encoder_inputs: - inputs["encoder_prompt"] = encoder_inputs["prompt"] - if "cache_salt" in encoder_inputs: - inputs["cache_salt"] = encoder_inputs["cache_salt"] - - return inputs diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 9e1774e39..f2187effa 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -19,6 +19,7 @@ import numpy as np import torch from typing_extensions import assert_never +from vllm.inputs import ModalityData, MultiModalDataDict, MultiModalUUIDDict from vllm.utils.collection_utils import is_list_of from vllm.utils.import_utils import LazyLoader @@ -29,11 +30,8 @@ from .inputs import ( HfImageItem, HfVideoItem, ImageItem, - ModalityData, - MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, - MultiModalUUIDDict, VideoItem, ) from .media import MediaWithBytes @@ -407,8 +405,8 @@ _D = TypeVar("_D", bound=ModalityDataItems[Any, Any]) class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]): """ - As [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict], but - normalized such that each entry corresponds to a list. + A normalized [`MultiModalDataDict`][vllm.inputs.MultiModalDataDict] + such that each entry corresponds to a list. """ def select(self, modalities: Set[str]): @@ -477,7 +475,7 @@ ModalityDataParser: TypeAlias = Callable[ class MultiModalDataParser: """ - Parses [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict] + Parses [`MultiModalDataDict`][vllm.inputs.MultiModalDataDict] into [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]. 
Args: @@ -695,8 +693,8 @@ class MultiModalDataParser: MultiModalUUIDItems: TypeAlias = dict[str, Sequence[str | None]] """ -As [`MultiModalUUIDDict`][vllm.multimodal.inputs.MultiModalUUIDDict], but -normalized such that each entry corresponds to a list. +A normalized [`MultiModalUUIDDict`][vllm.inputs.MultiModalUUIDDict] +such that each entry corresponds to a list. """ diff --git a/vllm/multimodal/processing/context.py b/vllm/multimodal/processing/context.py index 98a41f69b..6a9b997f9 100644 --- a/vllm/multimodal/processing/context.py +++ b/vllm/multimodal/processing/context.py @@ -11,15 +11,14 @@ from typing import TYPE_CHECKING, Any, overload import torch from typing_extensions import TypeVar +from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger -from vllm.multimodal.inputs import MultiModalDataDict from vllm.multimodal.parse import ( DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, MultiModalDataParser, ) -from vllm.renderers import TokenizeParams from vllm.tokenizers import TokenizerLike from vllm.transformers_utils.processor import cached_processor_from_config from vllm.utils.func_utils import get_allowed_kwarg_only_overrides @@ -32,12 +31,14 @@ if TYPE_CHECKING: from transformers.processing_utils import ProcessorMixin from vllm.config import ModelConfig + from vllm.renderers import TokenizeParams else: PretrainedConfig = object BatchFeature = object ProcessorMixin = object ModelConfig = object + TokenizeParams = object logger = init_logger(__name__) @@ -339,6 +340,8 @@ class BaseProcessingInfo: def get_default_tok_params(self) -> TokenizeParams: """Construct the default parameters for tokenization.""" + from vllm.renderers import TokenizeParams + model_config = self.ctx.model_config encoder_config = model_config.encoder_config or {} @@ -451,8 +454,7 @@ class BaseProcessingInfo: validate: bool = True, ) -> MultiModalDataItems: """ - Normalize - [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict] + Normalize [`MultiModalDataDict`][vllm.inputs.MultiModalDataDict] to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems] before passing them to [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. 
diff --git a/vllm/multimodal/processing/dummy_inputs.py b/vllm/multimodal/processing/dummy_inputs.py index 0f1029b76..970cc431e 100644 --- a/vllm/multimodal/processing/dummy_inputs.py +++ b/vllm/multimodal/processing/dummy_inputs.py @@ -14,9 +14,9 @@ from vllm.config.multimodal import ( ImageDummyOptions, VideoDummyOptions, ) +from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger -from ..inputs import MultiModalDataDict from .context import BaseProcessingInfo from .inputs import ProcessorInputs diff --git a/vllm/multimodal/processing/inputs.py b/vllm/multimodal/processing/inputs.py index 7c5d2fde8..ae12f4e51 100644 --- a/vllm/multimodal/processing/inputs.py +++ b/vllm/multimodal/processing/inputs.py @@ -3,8 +3,9 @@ from collections.abc import Mapping from dataclasses import dataclass, field +from vllm.inputs import MultiModalHashes + from ..hasher import MultiModalHasher -from ..inputs import MultiModalHashes from ..parse import MultiModalDataItems, MultiModalUUIDItems @@ -26,7 +27,7 @@ class ProcessorInputs: mm_uuid_items = self.mm_uuid_items or {} hf_processor_mm_kwargs = self.hf_processor_mm_kwargs - mm_hashes: MultiModalHashes = {} + mm_hashes = dict[str, list[str]]() hasher = MultiModalHasher for modality, data_items in mm_data_items.items(): diff --git a/vllm/multimodal/processing/processor.py b/vllm/multimodal/processing/processor.py index f26c17964..f0d2efe29 100644 --- a/vllm/multimodal/processing/processor.py +++ b/vllm/multimodal/processing/processor.py @@ -6,34 +6,29 @@ from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, S from dataclasses import dataclass, field, replace from enum import Enum from functools import lru_cache -from typing import ( - TYPE_CHECKING, - Generic, - NamedTuple, - Protocol, - TypeAlias, - cast, -) +from typing import TYPE_CHECKING, Generic, NamedTuple, Protocol, TypeAlias, cast import regex as re import torch from typing_extensions import TypeVar, assert_never +from vllm.inputs import ( + MultiModalEncDecInput, + MultiModalHashes, + MultiModalInput, + mm_enc_dec_input, + mm_input, +) from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike from vllm.utils.collection_utils import flatten_2d_lists, full_groupby from ..inputs import ( - MultiModalEncDecInputs, MultiModalFieldConfig, - MultiModalHashes, - MultiModalInputs, MultiModalKwargsItem, MultiModalKwargsItems, MultiModalKwargsOptionalItems, PlaceholderRange, - mm_enc_dec_inputs, - mm_inputs, ) from ..parse import ( DictEmbeddingItems, @@ -994,7 +989,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_items: MultiModalDataItems, mm_uuid_items: MultiModalUUIDItems | None = None, hf_processor_mm_kwargs: Mapping[str, object] | None = None, - ) -> MultiModalInputs: + ) -> MultiModalInput: processor_inputs = ProcessorInputs( prompt, mm_items, @@ -1638,7 +1633,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): self, inputs: ProcessorInputs, timing_ctx: TimingContext, - ) -> MultiModalInputs: + ) -> MultiModalInput: """ Process multi-modal inputs to be used in vLLM. 
@@ -1673,7 +1668,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): for modality, placeholders in mm_placeholders.items() } - return mm_inputs( + return mm_input( prompt_token_ids=prompt_ids, mm_kwargs=mm_info.kwargs, mm_hashes=mm_info.hashes, @@ -1708,7 +1703,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): self, prompt: str | list[int], mm_items: MultiModalDataItems, - encoder_inputs: MultiModalInputs, + encoder_inputs: MultiModalInput, ): tokenizer = self.info.get_tokenizer() decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_items) @@ -1721,7 +1716,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): decoder_prompt_text = None decoder_prompt_ids = decoder_prompt_raw - return mm_enc_dec_inputs( + return mm_enc_dec_input( encoder_inputs, decoder_prompt_ids, decoder_prompt=decoder_prompt_text, @@ -1731,7 +1726,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): self, inputs: ProcessorInputs, timing_ctx: TimingContext, - ) -> MultiModalEncDecInputs: + ) -> MultiModalEncDecInput: """ Process multi-modal inputs to be used in vLLM. The main processing steps are modified to fit encoder-decoder model: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 60c92d263..fa414a592 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from multiprocessing.synchronize import Lock as LockType from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast +from vllm.inputs import MultiModalInput from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config @@ -19,7 +20,6 @@ from .cache import ( ShmObjectStoreReceiverCache, ShmObjectStoreSenderCache, ) -from .inputs import MultiModalInputs from .processing import ( BaseDummyInputsBuilder, BaseMultiModalProcessor, @@ -220,7 +220,7 @@ class MultiModalRegistry: *, cache: BaseMultiModalProcessorCache | None = None, processor: BaseMultiModalProcessor | None = None, - ) -> MultiModalInputs: + ) -> MultiModalInput: """ Create dummy data for profiling the memory usage of a model. diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index c9f6b98bd..2d321cb67 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -12,6 +12,7 @@ import numpy.typing as npt from PIL import Image from typing_extensions import deprecated +from vllm.inputs import MultiModalPlaceholders from vllm.utils.import_utils import LazyLoader from .hasher import MultiModalHasher @@ -19,7 +20,6 @@ from .inputs import ( BatchedTensorInputs, MultiModalFieldElem, MultiModalKwargsItem, - MultiModalPlaceholderDict, MultiModalSharedField, ) from .media import AudioMediaIO, ImageMediaIO, MediaConnector, VideoMediaIO @@ -110,10 +110,10 @@ def encode_video_url( def argsort_mm_positions( - mm_positions: MultiModalPlaceholderDict, + mm_positions: MultiModalPlaceholders, ) -> list[tuple[str, int]]: """ - Given a `MultiModalPlaceholderDict`, output a sequence of keys to + Given a `MultiModalPlaceholders`, output a sequence of keys to sort the dictionary by `offset` (starting index in the input sequence) in ascending order. 
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 39688bb8b..281e91999 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from torch.distributed import PrefixStore, ProcessGroup from vllm.config import VllmConfig - from vllm.inputs import ProcessorInputs + from vllm.inputs import EngineInput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils.argparse_utils import FlexibleArgumentParser @@ -635,7 +635,7 @@ class Platform: @classmethod def validate_request( cls, - processed_inputs: "ProcessorInputs", + processed_inputs: "EngineInput", params: "SamplingParams | PoolingParams", ) -> None: """Raises if this request is unsupported on this platform""" diff --git a/vllm/plugins/io_processors/interface.py b/vllm/plugins/io_processors/interface.py index f73eb99ab..d75ac8e94 100644 --- a/vllm/plugins/io_processors/interface.py +++ b/vllm/plugins/io_processors/interface.py @@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, Sequence from typing import Generic, TypeVar from vllm.config import VllmConfig -from vllm.inputs.data import PromptType +from vllm.inputs import PromptType from vllm.outputs import PoolingRequestOutput from vllm.pooling_params import PoolingParams from vllm.renderers import BaseRenderer diff --git a/vllm/renderers/base.py b/vllm/renderers/base.py index 63946e8fd..35c3986e6 100644 --- a/vllm/renderers/base.py +++ b/vllm/renderers/base.py @@ -11,17 +11,32 @@ from typing import TYPE_CHECKING, Any, Generic, overload from typing_extensions import TypeVar from vllm.inputs import ( - EmbedsInputs, + EmbedsInput, EmbedsPrompt, - EncoderDecoderInputs, - ProcessorInputs, - SingletonInputs, + EncoderDecoderInput, + EngineInput, + MultiModalDataDict, + MultiModalInput, + MultiModalUUIDDict, + SingletonInput, TextPrompt, - TokenInputs, + TokensInput, TokensPrompt, + build_enc_dec_input, + embeds_input, + tokens_input, ) -from vllm.inputs.data import build_enc_dec_inputs, embeds_inputs, token_inputs from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry +from vllm.multimodal.cache import BaseMultiModalProcessorCache +from vllm.multimodal.parse import ( + MultiModalDataItems, + MultiModalUUIDItems, + parse_mm_uuids, +) +from vllm.multimodal.processing import BaseMultiModalProcessor +from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs +from vllm.multimodal.registry import MultiModalTimingRegistry from vllm.tokenizers import TokenizerLike from vllm.utils.async_utils import AsyncMicrobatchTokenizer from vllm.utils.counter import AtomicCounter @@ -46,14 +61,6 @@ if TYPE_CHECKING: ChatCompletionMessageParam, ConversationMessage, ) - from vllm.multimodal.cache import BaseMultiModalProcessorCache - from vllm.multimodal.inputs import ( - MultiModalDataDict, - MultiModalInputs, - MultiModalUUIDDict, - ) - from vllm.multimodal.parse import MultiModalDataItems, MultiModalUUIDItems - from vllm.multimodal.processing import BaseMultiModalProcessor logger = init_logger(__name__) @@ -86,9 +93,6 @@ class BaseRenderer(ABC, Generic[_T]): self.mm_processor: BaseMultiModalProcessor | None = None self._mm_cache_stats: MultiModalCacheStats | None = None if config.model_config.is_multimodal_model: - from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry - from vllm.multimodal.registry import MultiModalTimingRegistry - mm_processor_cache = mm_registry.processor_cache_from_config(config) # 
Deep-copy the tokenizer so the multimodal processor gets its @@ -524,9 +528,9 @@ class BaseRenderer(ABC, Generic[_T]): # Step 4: Convert to engine inputs def _validate_mm_uuids( self, - mm_data: "MultiModalDataDict", - mm_data_items: "MultiModalDataItems", - mm_uuid_items: "MultiModalUUIDItems", + mm_data: MultiModalDataDict, + mm_data_items: MultiModalDataItems, + mm_uuid_items: MultiModalUUIDItems, ) -> None: # NOTE: Keys corresponding to `None` in `mm_data` don't appear in # `mm_data_items` @@ -560,11 +564,11 @@ class BaseRenderer(ABC, Generic[_T]): def _process_mm_uuids( self, - mm_data: "MultiModalDataDict", - mm_data_items: "MultiModalDataItems", - mm_uuid_items: "MultiModalUUIDItems", + mm_data: MultiModalDataDict, + mm_data_items: MultiModalDataItems, + mm_uuid_items: MultiModalUUIDItems, mm_req_id: str, - ): + ) -> MultiModalUUIDItems: model_config = self.model_config # NOTE: When users explicitly turn off BOTH prefix caching and input @@ -590,14 +594,11 @@ class BaseRenderer(ABC, Generic[_T]): def _process_multimodal( self, prompt: list[int] | str, - mm_data: "MultiModalDataDict", - mm_uuids: "MultiModalUUIDDict | None", + mm_data: MultiModalDataDict, + mm_uuids: MultiModalUUIDDict | None, mm_processor_kwargs: Mapping[str, object] | None, tokenization_kwargs: dict[str, Any] | None, - ) -> "MultiModalInputs": - from vllm.multimodal.parse import parse_mm_uuids - from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs - + ) -> "MultiModalInput": mm_req_id = f"renderer{self.api_process_rank}-mm-{self._mm_req_counter.inc(1)}" mm_processor = self.get_mm_processor() @@ -628,12 +629,12 @@ class BaseRenderer(ABC, Generic[_T]): def _process_tokens( self, prompt: TokensPrompt, - ) -> "TokenInputs | MultiModalInputs": + ) -> TokensInput | MultiModalInput: prompt_token_ids = prompt["prompt_token_ids"] - inputs: TokenInputs | MultiModalInputs + engine_input: TokensInput | MultiModalInput if multi_modal_data := prompt.get("multi_modal_data"): - inputs = self._process_multimodal( + engine_input = self._process_multimodal( prompt_token_ids, multi_modal_data, mm_processor_kwargs=prompt.get("mm_processor_kwargs"), @@ -641,19 +642,16 @@ class BaseRenderer(ABC, Generic[_T]): mm_uuids=prompt.get("multi_modal_uuids"), ) else: - inputs = token_inputs(prompt_token_ids) + engine_input = tokens_input(prompt_token_ids) if prompt_text := prompt.get("prompt"): - inputs["prompt"] = prompt_text + engine_input["prompt"] = prompt_text if cache_salt := prompt.get("cache_salt"): - inputs["cache_salt"] = cache_salt + engine_input["cache_salt"] = cache_salt - return inputs + return engine_input - def _process_embeds( - self, - prompt: EmbedsPrompt, - ) -> EmbedsInputs: + def _process_embeds(self, prompt: EmbedsPrompt) -> EmbedsInput: if not self.model_config.enable_prompt_embeds: raise ValueError( "You must set `--enable-prompt-embeds` to input `prompt_embeds`." @@ -676,15 +674,12 @@ class BaseRenderer(ABC, Generic[_T]): # hidden device transfer in the critical path of generation. 
prompt_embeds = prompt_embeds.cpu() - return embeds_inputs( + return embeds_input( prompt_embeds=prompt_embeds, cache_salt=prompt.get("cache_salt"), ) - def _process_singleton( - self, - prompt: SingletonTokPrompt, - ) -> SingletonInputs: + def _process_singleton(self, prompt: SingletonTokPrompt) -> SingletonInput: if "prompt_embeds" in prompt: return self._process_embeds(prompt) # type: ignore[arg-type] @@ -693,7 +688,7 @@ class BaseRenderer(ABC, Generic[_T]): def _process_enc_dec( self, prompt: EncoderDecoderTokPrompt, - ) -> EncoderDecoderInputs: + ) -> EncoderDecoderInput: enc_prompt = prompt["encoder_prompt"] dec_prompt = prompt["decoder_prompt"] @@ -704,27 +699,25 @@ class BaseRenderer(ABC, Generic[_T]): if isinstance(self.mm_processor, EncDecMultiModalProcessor): skip_decoder_start_token = self.mm_processor.skip_decoder_start_token - return build_enc_dec_inputs( - encoder_inputs=self._process_singleton(enc_prompt), - decoder_inputs=( + return build_enc_dec_input( + encoder_input=self._process_singleton(enc_prompt), + decoder_input=( None if dec_prompt is None else self._process_singleton(dec_prompt) ), decoder_start_token_id=self.get_dec_start_token_id(), skip_decoder_start_token=skip_decoder_start_token, ) - def process_for_engine( - self, prompt: TokPrompt, arrival_time: float - ) -> ProcessorInputs: - engine_prompt: ProcessorInputs + def process_for_engine(self, prompt: TokPrompt, arrival_time: float) -> EngineInput: + engine_input: EngineInput if "encoder_prompt" in prompt: - engine_prompt = self._process_enc_dec(prompt) # type: ignore[arg-type] + engine_input = self._process_enc_dec(prompt) # type: ignore[arg-type] else: - engine_prompt = self._process_singleton(prompt) + engine_input = self._process_singleton(prompt) - engine_prompt["arrival_time"] = arrival_time + engine_input["arrival_time"] = arrival_time - return engine_prompt + return engine_input # Top-level methods def render_cmpl( diff --git a/vllm/renderers/hf.py b/vllm/renderers/hf.py index 02395b775..b6ee93910 100644 --- a/vllm/renderers/hf.py +++ b/vllm/renderers/hf.py @@ -5,7 +5,7 @@ import itertools from collections import defaultdict, deque from collections.abc import Set from functools import lru_cache -from typing import TYPE_CHECKING, Any, Literal, cast, overload +from typing import Any, Literal, cast, overload import jinja2 import jinja2.ext @@ -25,6 +25,7 @@ from vllm.entrypoints.chat_utils import ( parse_chat_messages, parse_chat_messages_async, ) +from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict from vllm.logger import init_logger from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer @@ -37,13 +38,6 @@ from .inputs import DictPrompt from .inputs.preprocess import parse_dec_only_prompt from .params import ChatParams -if TYPE_CHECKING: - from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict -else: - MultiModalDataDict = dict[str, Any] - MultiModalUUIDDict = dict[str, Any] - - logger = init_logger(__name__) @@ -512,9 +506,9 @@ def safe_apply_chat_template( def rebuild_mm_uuids_from_mm_data( - mm_uuids: "MultiModalUUIDDict", - mm_data: "MultiModalDataDict", -) -> "MultiModalUUIDDict": + mm_uuids: MultiModalUUIDDict, + mm_data: MultiModalDataDict, +) -> MultiModalUUIDDict: """Rebuild mm_uuids after vision_chunk processing. 
When videos are split into chunks, the original UUIDs need to be updated @@ -547,7 +541,7 @@ def rebuild_mm_uuids_from_mm_data( def build_video_prompts_from_mm_data( - mm_data: "MultiModalDataDict", + mm_data: MultiModalDataDict, ) -> list[str]: """Build video prompts from vision_chunk data. @@ -585,7 +579,7 @@ def build_video_prompts_from_mm_data( def replace_vision_chunk_video_placeholder( prompt_raw: str | list[int], - mm_data: "MultiModalDataDict", + mm_data: MultiModalDataDict, video_placeholder: str | None, ) -> str | list[int]: # get video placeholder, replace it with runtime video-chunk prompts diff --git a/vllm/renderers/inputs/preprocess.py b/vllm/renderers/inputs/preprocess.py index e972d0755..1828c4ff5 100644 --- a/vllm/renderers/inputs/preprocess.py +++ b/vllm/renderers/inputs/preprocess.py @@ -9,8 +9,8 @@ from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload from vllm.inputs import ( EmbedsPrompt, + EngineInput, ExplicitEncoderDecoderPrompt, - ProcessorInputs, PromptType, SingletonPrompt, TextPrompt, @@ -70,28 +70,28 @@ def conversation_to_seq( DecoderOnlyDictPrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt """ -A [`DecoderOnlyPrompt`][vllm.inputs.data.DecoderOnlyPrompt] +A [`DecoderOnlyPrompt`][vllm.inputs.llm.DecoderOnlyPrompt] that has been standardized into a dictionary. """ EncoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt """ -A [`EncoderPrompt`][vllm.inputs.data.EncoderPrompt] +A [`EncoderPrompt`][vllm.inputs.llm.EncoderPrompt] that has been standardized into a dictionary. """ DecoderDictPrompt: TypeAlias = TextPrompt | TokensPrompt """ -A [`DecoderPrompt`][vllm.inputs.data.DecoderPrompt] +A [`DecoderPrompt`][vllm.inputs.llm.DecoderPrompt] that has been standardized into a dictionary. """ class EncoderDecoderDictPrompt(TypedDict): """ - A [`EncoderDecoderPrompt`][vllm.inputs.data.EncoderDecoderPrompt] + A [`EncoderDecoderPrompt`][vllm.inputs.llm.EncoderDecoderPrompt] that has been standardized into a dictionary. """ @@ -104,14 +104,14 @@ SingletonDictPrompt: TypeAlias = ( DecoderOnlyDictPrompt | EncoderDictPrompt | DecoderDictPrompt ) """ -A [`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] +A [`SingletonPrompt`][vllm.inputs.llm.SingletonPrompt] that has been standardized into a dictionary. """ DictPrompt: TypeAlias = DecoderOnlyDictPrompt | EncoderDecoderDictPrompt """ -A [`PromptType`][vllm.inputs.data.PromptType] +A [`PromptType`][vllm.inputs.llm.PromptType] that has been standardized into a dictionary. 
""" @@ -236,7 +236,7 @@ def extract_target_prompt(model_config: "ModelConfig", prompt: object): def extract_prompt_components( model_config: "ModelConfig", - prompt: PromptType | ProcessorInputs, + prompt: PromptType | EngineInput, ) -> PromptComponents: target_prompt = extract_target_prompt(model_config, prompt) @@ -248,7 +248,8 @@ def extract_prompt_components( def extract_prompt_len( - model_config: "ModelConfig", prompt: PromptType | ProcessorInputs + model_config: "ModelConfig", + prompt: PromptType | EngineInput, ): target_prompt = extract_target_prompt(model_config, prompt) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index a9c42e78e..c6d34eda9 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -21,7 +21,7 @@ from vllm.distributed.weight_transfer.base import ( from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient, StreamingInput from vllm.entrypoints.serve.elastic_ep.middleware import set_scaling_elastic_ep -from vllm.inputs import ProcessorInputs, PromptType +from vllm.inputs import EngineInput, PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry @@ -139,7 +139,7 @@ class AsyncLLM(EngineClient): self.model_config.io_processor_plugin, ) - # Convert TokPrompt --> EngineCoreRequest. + # Convert EngineInput --> EngineCoreRequest. self.input_processor = InputProcessor(self.vllm_config, renderer) # Converts EngineCoreOutputs --> RequestOutput. @@ -290,7 +290,7 @@ class AsyncLLM(EngineClient): request_id: str, prompt: EngineCoreRequest | PromptType - | ProcessorInputs + | EngineInput | AsyncGenerator[StreamingInput, None], params: SamplingParams | PoolingParams, arrival_time: float | None = None, @@ -530,7 +530,7 @@ class AsyncLLM(EngineClient): self, prompt: EngineCoreRequest | PromptType - | ProcessorInputs + | EngineInput | AsyncGenerator[StreamingInput, None], sampling_params: SamplingParams, request_id: str, @@ -776,7 +776,7 @@ class AsyncLLM(EngineClient): async def encode( self, - prompt: PromptType | ProcessorInputs, + prompt: PromptType | EngineInput, pooling_params: PoolingParams, request_id: str, lora_request: LoRARequest | None = None, diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index b77b9277a..b59d02a46 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -7,20 +7,18 @@ from typing import Any, Literal import vllm.envs as envs from vllm.config import VllmConfig -from vllm.inputs.data import ( - ProcessorInputs, +from vllm.inputs import ( + EngineInput, PromptType, - SingletonInputs, + SingletonInput, + split_enc_dec_input, ) -from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.encoder_budget import MultiModalBudget -from vllm.multimodal.inputs import ( - MultiModalFeatureSpec, -) +from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.multimodal.utils import argsort_mm_positions from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams @@ -197,7 +195,7 @@ class InputProcessor: def process_inputs( self, request_id: str, - prompt: PromptType | ProcessorInputs, + prompt: PromptType | EngineInput, params: SamplingParams | PoolingParams, 
supported_tasks: tuple[SupportedTask, ...], arrival_time: float | None = None, @@ -232,7 +230,7 @@ class InputProcessor: if arrival_time is None: arrival_time = prompt.get("arrival_time", time.time()) # type: ignore[assignment] - processed_inputs: ProcessorInputs = prompt # type: ignore[assignment] + processed_inputs: EngineInput = prompt # type: ignore[assignment] else: logger.warning_once( "Passing raw prompts to InputProcessor is deprecated " @@ -250,7 +248,7 @@ class InputProcessor: current_platform.validate_request(processed_inputs, params) - encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) + encoder_inputs, decoder_inputs = split_enc_dec_input(processed_inputs) self._validate_model_inputs(encoder_inputs, decoder_inputs) # Mypy can be conservative for TypedDict unions; normalize access. @@ -385,7 +383,7 @@ class InputProcessor: def _validate_model_input( self, - prompt_inputs: SingletonInputs, + prompt_input: SingletonInput, prompt_type: Literal["encoder", "decoder"], ) -> None: model_config = self.model_config @@ -393,20 +391,18 @@ class InputProcessor: prompt_ids = ( None - if prompt_inputs["type"] == "embeds" - else prompt_inputs["prompt_token_ids"] + if prompt_input["type"] == "embeds" + else prompt_input["prompt_token_ids"] ) prompt_embeds = ( - prompt_inputs["prompt_embeds"] - if prompt_inputs["type"] == "embeds" - else None + prompt_input["prompt_embeds"] if prompt_input["type"] == "embeds" else None ) prompt_len = length_from_prompt_token_ids_or_embeds(prompt_ids, prompt_embeds) self._validate_prompt_len(prompt_len, prompt_type) - if prompt_inputs["type"] == "multimodal": - decoder_mm_positions = prompt_inputs["mm_placeholders"] + if prompt_input["type"] == "multimodal": + decoder_mm_positions = prompt_input["mm_placeholders"] for modality, mm_positions in decoder_mm_positions.items(): for mm_position in mm_positions: embed_length = mm_position.get_num_embeds() @@ -439,10 +435,10 @@ class InputProcessor: def _validate_model_inputs( self, - encoder_inputs: SingletonInputs | None, - decoder_inputs: SingletonInputs, + encoder_input: SingletonInput | None, + decoder_input: SingletonInput, ): - if encoder_inputs is not None: - self._validate_model_input(encoder_inputs, prompt_type="encoder") + if encoder_input is not None: + self._validate_model_input(encoder_input, prompt_type="encoder") - self._validate_model_input(decoder_inputs, prompt_type="decoder") + self._validate_model_input(decoder_input, prompt_type="decoder") diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 0d9279331..4b6a7ba44 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -14,7 +14,7 @@ from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.distributed.parallel_state import get_dp_group from vllm.engine.arg_utils import EngineArgs -from vllm.inputs import ProcessorInputs, PromptType +from vllm.inputs import EngineInput, PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry @@ -96,7 +96,7 @@ class LLMEngine: self.model_config.io_processor_plugin, ) - # Convert TokPrompt --> EngineCoreRequest. + # Convert EngineInput --> EngineCoreRequest. self.input_processor = InputProcessor(self.vllm_config, renderer) # Converts EngineCoreOutputs --> RequestOutput. 
@@ -216,7 +216,7 @@ class LLMEngine: def add_request( self, request_id: str, - prompt: EngineCoreRequest | PromptType | ProcessorInputs, + prompt: EngineCoreRequest | PromptType | EngineInput, params: SamplingParams | PoolingParams, arrival_time: float | None = None, lora_request: LoRARequest | None = None,