diff --git a/tests/renderers/test_completions.py b/tests/renderers/test_completions.py
index 03e1a655a..492f539e4 100644
--- a/tests/renderers/test_completions.py
+++ b/tests/renderers/test_completions.py
@@ -93,14 +93,14 @@ def _build_renderer(
 
 
 def _preprocess_prompt(
-    mdoel_config: ModelConfig,
+    model_config: ModelConfig,
     prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
 ):
     return [
         (
             prompt
             if isinstance(prompt, bytes)
-            else parse_model_prompt(mdoel_config, prompt)
+            else parse_model_prompt(model_config, prompt)
         )
         for prompt in prompt_to_seq(prompt_or_prompts)
     ]
diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py
index 4f1b3b9ca..07ed9f1d0 100644
--- a/vllm/inputs/data.py
+++ b/vllm/inputs/data.py
@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING, Any, Literal, TypeAlias
 
 import torch
-from typing_extensions import NotRequired, TypedDict
+from typing_extensions import NotRequired, TypedDict, assert_never
 
 if TYPE_CHECKING:
     from vllm.multimodal.inputs import (
@@ -200,15 +200,22 @@ class TokenInputs(_InputOptions):
     prompt_token_ids: list[int]
     """The token IDs of the prompt."""
 
+    prompt: NotRequired[str]
+    """The prompt text corresponding to the token IDs, if available."""
+
 
 def token_inputs(
     prompt_token_ids: list[int],
+    *,
+    prompt: str | None = None,
     cache_salt: str | None = None,
 ) -> TokenInputs:
     """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional values."""
     inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
 
+    if prompt is not None:
+        inputs["prompt"] = prompt
     if cache_salt is not None:
         inputs["cache_salt"] = cache_salt
 
@@ -224,15 +231,22 @@ class EmbedsInputs(_InputOptions):
     prompt_embeds: torch.Tensor
     """The embeddings of the prompt."""
 
+    prompt: NotRequired[str]
+    """The prompt text corresponding to the embeddings, if available."""
+
 
 def embeds_inputs(
     prompt_embeds: torch.Tensor,
+    *,
+    prompt: str | None = None,
     cache_salt: str | None = None,
 ) -> EmbedsInputs:
     """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional values."""
     inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)
 
+    if prompt is not None:
+        inputs["prompt"] = prompt
     if cache_salt is not None:
         inputs["cache_salt"] = cache_salt
 
@@ -278,10 +292,12 @@ class EncoderDecoderInputs(TypedDict):
     for encoder-decoder models.
     """
 
-    encoder: EncoderInputs
+    type: Literal["enc_dec"]
+
+    encoder_prompt: EncoderInputs
     """The inputs for the encoder portion."""
 
-    decoder: DecoderInputs
+    decoder_prompt: DecoderInputs
     """The inputs for the decoder portion."""
 
 
@@ -296,3 +312,94 @@ which can be passed to
 
 SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs
 """The inputs for a single encoder/decoder prompt."""
+
+
+def _validate_enc_inputs(inputs: SingletonInputs) -> EncoderInputs:
+    if inputs["type"] == "embeds":
+        raise ValueError(
+            "Embedding inputs are not supported for encoder-decoder models"
+        )
+
+    if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs:
+        raise RuntimeError(
+            "You should register an encoder-decoder multi-modal processor "
+            "for encoder-decoder models."
+        )
+
+    return inputs  # type: ignore[return-value]
+
+
+def _validate_dec_inputs(inputs: SingletonInputs) -> DecoderInputs:
+    if inputs["type"] == "embeds":
+        raise ValueError(
+            "Embedding inputs are not supported for encoder-decoder models"
+        )
+
+    return inputs
+
+
+def _prepare_decoder_input_ids_for_generation(
+    decoder_input_ids: list[int],
+    decoder_start_token_id: int,
+) -> list[int]:
+    """
+    Prepare `decoder_input_ids` for generation with encoder-decoder models,
+    according to `GenerationMixin._prepare_decoder_input_ids_for_generation()`.
+
+    Source:
+    https://github.com/huggingface/transformers/blob/v5.1.0/src/transformers/generation/utils.py
+    """
+    if len(decoder_input_ids) == 0 or decoder_input_ids[0] != decoder_start_token_id:
+        decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
+
+    return decoder_input_ids
+
+
+def build_enc_dec_inputs(
+    encoder_inputs: SingletonInputs,
+    decoder_inputs: SingletonInputs | None,
+    decoder_start_token_id: int,
+) -> EncoderDecoderInputs:
+    enc_inputs = _validate_enc_inputs(encoder_inputs)
+
+    if decoder_inputs is None:
+        dec_inputs: DecoderInputs = enc_inputs
+    else:
+        dec_inputs = _validate_dec_inputs(decoder_inputs)
+
+    enc_inputs_new: EncoderInputs
+    dec_inputs_new: DecoderInputs
+
+    if enc_inputs["type"] == "multimodal":
+        from vllm.multimodal.inputs import mm_inputs
+
+        enc_inputs_new = token_inputs(
+            enc_inputs["encoder_prompt_token_ids"],
+            prompt=enc_inputs.get("encoder_prompt"),
+        )
+        dec_inputs_new = mm_inputs(
+            prompt_token_ids=dec_inputs["prompt_token_ids"],
+            prompt=dec_inputs.get("prompt"),
+            mm_kwargs=enc_inputs["mm_kwargs"],
+            mm_hashes=enc_inputs["mm_hashes"],
+            mm_placeholders=enc_inputs["mm_placeholders"],
+        )
+    elif enc_inputs["type"] == "token":
+        enc_inputs_new = token_inputs(prompt_token_ids=[])
+        dec_inputs_new = dec_inputs
+    else:
+        assert_never(enc_inputs)
+
+    dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation(
+        dec_inputs_new["prompt_token_ids"],
+        decoder_start_token_id,
+    )
+
+    if cache_salt := enc_inputs.get("cache_salt"):
+        dec_inputs_new["cache_salt"] = cache_salt
+
+    return EncoderDecoderInputs(
+        type="enc_dec",
+        encoder_prompt=enc_inputs_new,
+        decoder_prompt=dec_inputs_new,
+    )
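Note on the `data.py` changes above: `build_enc_dec_inputs` is now a pure function that takes the decoder start token id explicitly instead of reaching into `InputPreprocessor`. A minimal usage sketch (the token IDs are made up, and the imports assume a vLLM build that already includes this diff):

```python
from vllm.inputs.data import build_enc_dec_inputs, token_inputs

DECODER_START = 0  # e.g. hf_config.decoder_start_token_id

enc = token_inputs([101, 2023, 102], prompt="This is")
dec = token_inputs([7592])

out = build_enc_dec_inputs(enc, dec, decoder_start_token_id=DECODER_START)

assert out["type"] == "enc_dec"
# "token"-type encoder inputs are replaced with an empty token list,
# matching the behavior of the removed _build_enc_dec_inputs:
assert out["encoder_prompt"]["prompt_token_ids"] == []
# The decoder start token is prepended exactly once:
assert out["decoder_prompt"]["prompt_token_ids"] == [DECODER_START, 7592]
```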
diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py
index 611a470ba..ab29935ac 100644
--- a/vllm/inputs/parse.py
+++ b/vllm/inputs/parse.py
@@ -7,11 +7,7 @@ from .data import ProcessorInputs, SingletonInputs
 def split_enc_dec_inputs(
     inputs: ProcessorInputs,
 ) -> tuple[SingletonInputs | None, SingletonInputs]:
-    if "encoder" in inputs and "decoder" in inputs:
-        # NOTE: This passes pyright but not mypy
-        return (
-            inputs["encoder"],  # type: ignore[typeddict-item]
-            inputs["decoder"],  # type: ignore[typeddict-item]
-        )
+    if inputs["type"] == "enc_dec":
+        return inputs["encoder_prompt"], inputs["decoder_prompt"]
 
     return None, inputs
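The tagged union also simplifies call sites: `split_enc_dec_inputs` now dispatches on `inputs["type"]` instead of probing for keys, which is why the `type: ignore` comments could be dropped. A quick sketch, under the same assumption that this diff is applied:

```python
from vllm.inputs.data import build_enc_dec_inputs, token_inputs
from vllm.inputs.parse import split_enc_dec_inputs

# Decoder-only inputs pass through unchanged:
dec_only = token_inputs([1, 2, 3])
assert split_enc_dec_inputs(dec_only) == (None, dec_only)

# Encoder-decoder inputs are split into their two halves:
enc_dec = build_enc_dec_inputs(token_inputs([4, 5]), None, decoder_start_token_id=0)
enc, dec = split_enc_dec_inputs(enc_dec)
assert enc is not None
assert dec["prompt_token_ids"][0] == 0  # decoder start token was prepended
```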
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py
index 08a37b6da..95089623e 100644
--- a/vllm/inputs/preprocess.py
+++ b/vllm/inputs/preprocess.py
@@ -7,6 +7,7 @@ from typing import Any, overload
 from typing_extensions import assert_never
 
 from vllm.config import VllmConfig
+from vllm.inputs.data import build_enc_dec_inputs
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.inputs import (
@@ -67,54 +68,6 @@ class InputPreprocessor:
     def get_tokenizer(self) -> TokenizerLike:
         return self.renderer.get_tokenizer()
 
-    def get_decoder_start_token_id(self) -> int:
-        """
-        Obtain the decoder start token id employed by an encoder/decoder
-        model. Raises an error if it is not available.
-        """
-        dec_start_token_id = getattr(
-            self.model_config.hf_config, "decoder_start_token_id", None
-        )
-
-        if dec_start_token_id is None:
-            logger.warning_once(
-                "Falling back on <BOS> for decoder start token id "
-                "because decoder start token id is not available."
-            )
-            dec_start_token_id = self.renderer.get_bos_token_id()
-
-        if dec_start_token_id is None:
-            raise RuntimeError("Cannot find decoder start token id or <BOS>")
-
-        return dec_start_token_id
-
-    def _prepare_decoder_input_ids(self, decoder_input_ids: list[int]) -> list[int]:
-        """
-        Prepares `decoder_input_ids` for generation with encoder-decoder models.
-
-        Based on:
-        https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
-        specifically,
-        `GenerationMixin._prepare_decoder_input_ids_for_generation()`.
-
-        Arguments:
-
-        * decoder_input_ids: input token ids to preprocess
-
-        Returns:
-
-        * Processed token list
-        """
-        decoder_start_token_id = self.get_decoder_start_token_id()
-
-        if (
-            len(decoder_input_ids) == 0
-            or decoder_input_ids[0] != decoder_start_token_id
-        ):
-            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
-
-        return decoder_input_ids
-
     def _tokenize_prompt(
         self,
         prompt: str,
@@ -332,66 +285,6 @@ class InputPreprocessor:
 
         assert_never(prompt)  # type: ignore[arg-type]
 
-    def _validate_enc_inputs(self, inputs: SingletonInputs) -> EncoderInputs:
-        if inputs["type"] == "embeds":
-            raise ValueError(
-                "Embedding inputs are not supported for encoder-decoder models"
-            )
-
-        if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs:
-            raise RuntimeError(
-                "You should register an encoder-decoder "
-                "multi-modal processor for encoder-decoder models."
-            )
-
-        return inputs  # type: ignore[return-value]
-
-    def _validate_dec_inputs(self, inputs: SingletonInputs) -> DecoderInputs:
-        if inputs["type"] == "embeds":
-            raise ValueError(
-                "Embedding inputs are not supported for encoder-decoder models"
-            )
-
-        return inputs
-
-    def _build_enc_dec_inputs(
-        self,
-        encoder_inputs: SingletonInputs,
-        decoder_inputs: SingletonInputs | None = None,
-    ) -> EncoderDecoderInputs:
-        enc_inputs = self._validate_enc_inputs(encoder_inputs)
-
-        if decoder_inputs is None:
-            dec_inputs: DecoderInputs = enc_inputs  # type: ignore[assignment]
-        else:
-            dec_inputs = self._validate_dec_inputs(decoder_inputs)
-
-        enc_inputs_new: EncoderInputs
-        dec_inputs_new: DecoderInputs
-
-        if enc_inputs["type"] == "multimodal":
-            enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"])
-            dec_inputs_new = MultiModalInputs(
-                type="multimodal",
-                prompt_token_ids=dec_inputs["prompt_token_ids"],
-                mm_kwargs=enc_inputs["mm_kwargs"],
-                mm_hashes=enc_inputs["mm_hashes"],
-                mm_placeholders=enc_inputs["mm_placeholders"],
-            )
-        elif enc_inputs["type"] == "token":
-            enc_inputs_new = token_inputs(prompt_token_ids=[])
-            dec_inputs_new = dec_inputs
-        else:
-            assert_never(enc_inputs)
-
-        dec_inputs_new["prompt_token_ids"] = self._prepare_decoder_input_ids(
-            dec_inputs_new["prompt_token_ids"]
-        )
-        if cache_salt := enc_inputs.get("cache_salt"):
-            dec_inputs_new["cache_salt"] = cache_salt
-
-        return EncoderDecoderInputs(encoder=enc_inputs_new, decoder=dec_inputs_new)
-
     def _process_encoder_decoder_prompt(
         self,
         prompt: EncoderDecoderDictPrompt,
@@ -417,7 +310,7 @@
         encoder_prompt = prompt["encoder_prompt"]
         decoder_prompt = prompt["decoder_prompt"]
 
-        return self._build_enc_dec_inputs(
+        return build_enc_dec_inputs(
             encoder_inputs=self._prompt_to_llm_inputs(
                 encoder_prompt,
                 tokenization_kwargs=tokenization_kwargs,
             ),
             decoder_inputs=(
                 None
                 if decoder_prompt is None
                 else self._prompt_to_llm_inputs(
                     decoder_prompt,
                     tokenization_kwargs=tokenization_kwargs,
                )
            ),
+            decoder_start_token_id=self.renderer.get_dec_start_token_id(),
         )
 
     def _process_decoder_only_prompt(
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index ecd2c895b..07e8dac85 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -31,6 +31,7 @@ from vllm.multimodal.inputs import (
     MultiModalInputs,
     MultiModalKwargsItems,
     MultiModalUUIDDict,
+    mm_inputs,
 )
 from vllm.multimodal.parse import (
     ImageEmbeddingItems,
@@ -837,8 +838,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
             for modality, placeholders in mm_placeholders.items()
         }
 
-        return MultiModalInputs(
-            type="multimodal",
+        return mm_inputs(
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_kwargs,
             mm_hashes=mm_hashes,
diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py
index 804eccbc4..016cdd742 100644
--- a/vllm/model_executor/models/terratorch.py
+++ b/vllm/model_executor/models/terratorch.py
@@ -48,6 +48,7 @@ from vllm.multimodal.inputs import (
     MultiModalKwargsItems,
     MultiModalUUIDDict,
     PlaceholderRange,
+    mm_inputs,
 )
 from vllm.multimodal.parse import (
     DictEmbeddingItems,
@@ -222,8 +223,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessingInfo]):
             ),
         )
 
-        return MultiModalInputs(
-            type="multimodal",
+        return mm_inputs(
             prompt_token_ids=[1],
             mm_kwargs=mm_kwargs,
             mm_hashes=mm_hashes,
diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py
index 64dc5bf8b..6fb5827a8 100644
--- a/vllm/model_executor/models/transformers/multimodal.py
+++ b/vllm/model_executor/models/transformers/multimodal.py
@@ -33,6 +33,7 @@ from vllm.multimodal.inputs import (
     MultiModalInputs,
     MultiModalUUIDDict,
     PlaceholderRange,
+    mm_inputs,
 )
 from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
 from vllm.multimodal.processing import (
@@ -260,8 +261,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
             mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
         )
 
-        return MultiModalInputs(
-            type="multimodal",
+        return mm_inputs(
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_kwargs,
             mm_hashes=mm_hashes,
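The three model files above all apply the same mechanical migration: a hand-rolled `MultiModalInputs(type="multimodal", ...)` literal becomes a call to the `mm_inputs` helper added in `vllm/multimodal/inputs.py` below. A sketch of the call shape, with empty placeholders standing in for real processor outputs:

```python
from vllm.multimodal.inputs import mm_inputs

out = mm_inputs(
    prompt_token_ids=[1, 2, 3],
    mm_kwargs={},        # real processors pass MultiModalKwargsOptionalItems
    mm_hashes={},        # ... MultiModalHashes
    mm_placeholders={},  # ... MultiModalPlaceholderDict
)

assert out["type"] == "multimodal"  # the helper fills in the tag
assert "prompt" not in out          # optional keys stay absent unless provided
```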
diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index 221baba6d..be9f7e652 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -20,7 +20,7 @@ from typing import (
 
 import numpy as np
 from PIL.Image import Image
-from typing_extensions import TypeVar
+from typing_extensions import NotRequired, TypeVar
 
 from vllm.utils.collection_utils import is_list_of
 from vllm.utils.import_utils import LazyLoader
@@ -1075,6 +1075,9 @@ class MultiModalInputs(_InputOptions):
     prompt_token_ids: list[int]
     """The processed token IDs which includes placeholder tokens."""
 
+    prompt: NotRequired[str]
+    """The prompt text corresponding to the token IDs, if available."""
+
     mm_kwargs: MultiModalKwargsOptionalItems
     """Keyword arguments to be directly passed to the model after batching."""
 
@@ -1088,6 +1091,31 @@
     """
 
 
+def mm_inputs(
+    prompt_token_ids: list[int],
+    mm_kwargs: MultiModalKwargsOptionalItems,
+    mm_hashes: MultiModalHashes,
+    mm_placeholders: MultiModalPlaceholderDict,
+    *,
+    prompt: str | None = None,
+    cache_salt: str | None = None,
+) -> MultiModalInputs:
+    inputs = MultiModalInputs(
+        type="multimodal",
+        prompt_token_ids=prompt_token_ids,
+        mm_kwargs=mm_kwargs,
+        mm_hashes=mm_hashes,
+        mm_placeholders=mm_placeholders,
+    )
+
+    if prompt is not None:
+        inputs["prompt"] = prompt
+    if cache_salt is not None:
+        inputs["cache_salt"] = cache_salt
+
+    return inputs
+
+
 class MultiModalEncDecInputs(MultiModalInputs):
     """
     Represents the outputs of
@@ -1101,3 +1129,31 @@
 
     encoder_prompt_token_ids: list[int]
     """The processed token IDs of the encoder prompt."""
+
+    encoder_prompt: NotRequired[str]
+    """The prompt text corresponding to the encoder token IDs, if available."""
+
+
+def mm_enc_dec_inputs(
+    encoder_inputs: MultiModalInputs,
+    decoder_prompt_token_ids: list[int],
+    *,
+    decoder_prompt: str | None = None,
+) -> MultiModalEncDecInputs:
+    inputs = MultiModalEncDecInputs(
+        type="multimodal",
+        prompt_token_ids=decoder_prompt_token_ids,
+        encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
+        mm_kwargs=encoder_inputs["mm_kwargs"],
+        mm_hashes=encoder_inputs["mm_hashes"],
+        mm_placeholders=encoder_inputs["mm_placeholders"],
+    )
+
+    if decoder_prompt is not None:
+        inputs["prompt"] = decoder_prompt
+    if "prompt" in encoder_inputs:
+        inputs["encoder_prompt"] = encoder_inputs["prompt"]
+    if "cache_salt" in encoder_inputs:
+        inputs["cache_salt"] = encoder_inputs["cache_salt"]
+
+    return inputs
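`mm_enc_dec_inputs` centralizes the field shuffling that `EncDecMultiModalProcessor` used to do inline (see `processing/processor.py` below). A sketch with made-up placeholder values, again assuming this diff is applied:

```python
from vllm.multimodal.inputs import mm_enc_dec_inputs, mm_inputs

encoder = mm_inputs(
    prompt_token_ids=[10, 11, 12],
    mm_kwargs={},
    mm_hashes={},
    mm_placeholders={},
    prompt="<audio>",
)

out = mm_enc_dec_inputs(encoder, decoder_prompt_token_ids=[50258])

# The encoder's token IDs and text move to the encoder_* keys, while the
# decoder's token IDs become the primary prompt:
assert out["encoder_prompt_token_ids"] == [10, 11, 12]
assert out["encoder_prompt"] == "<audio>"
assert out["prompt_token_ids"] == [50258]
assert "prompt" not in out  # no decoder_prompt was passed
```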
diff --git a/vllm/multimodal/media/base.py b/vllm/multimodal/media/base.py
index 909a6eb93..576355255 100644
--- a/vllm/multimodal/media/base.py
+++ b/vllm/multimodal/media/base.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Generic, TypeVar
 
@@ -26,7 +26,7 @@ class MediaWithBytes(Generic[_T]):
     """
 
     media: _T
-    original_bytes: bytes
+    original_bytes: bytes = field(repr=False)
 
     def __array__(self, *args, **kwargs) -> np.ndarray:
         """Allow np.array(obj) to return np.array(obj.media)."""
diff --git a/vllm/multimodal/processing/processor.py b/vllm/multimodal/processing/processor.py
index e1a164d4e..50b288cd7 100644
--- a/vllm/multimodal/processing/processor.py
+++ b/vllm/multimodal/processing/processor.py
@@ -34,6 +34,8 @@ from ..inputs import (
     MultiModalKwargsOptionalItems,
     MultiModalUUIDDict,
     PlaceholderRange,
+    mm_enc_dec_inputs,
+    mm_inputs,
 )
 from ..parse import (
     DictEmbeddingItems,
@@ -1803,8 +1805,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             for modality, placeholders in mm_placeholders.items()
         }
 
-        return MultiModalInputs(
-            type="multimodal",
+        return mm_inputs(
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_info.kwargs,
             mm_hashes=mm_info.hashes,
@@ -1848,12 +1849,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
         else:
             decoder_prompt_ids = decoder_prompt_raw
 
-        mm_inputs = MultiModalEncDecInputs(
-            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
-            **encoder_inputs,
+        return mm_enc_dec_inputs(
+            encoder_inputs,
+            decoder_prompt_ids,
         )
-        mm_inputs["prompt_token_ids"] = decoder_prompt_ids
-        return mm_inputs
 
     def apply(
         self,
diff --git a/vllm/renderers/base.py b/vllm/renderers/base.py
index bd60450ff..2a1549be0 100644
--- a/vllm/renderers/base.py
+++ b/vllm/renderers/base.py
@@ -153,6 +153,27 @@ class BaseRenderer(ABC, Generic[_T]):
 
         return self.tokenizer.eos_token_id
 
+    def get_dec_start_token_id(self) -> int:
+        """
+        Obtain the decoder start token id employed by an encoder/decoder model,
+        raising an error if it is not available.
+        """
+        dec_start_token_id = getattr(
+            self.model_config.hf_config, "decoder_start_token_id", None
+        )
+
+        if dec_start_token_id is None:
+            logger.warning_once(
+                "Falling back on <BOS> for decoder start token id "
+                "because decoder start token id is not available."
+            )
+            dec_start_token_id = self.get_bos_token_id()
+
+        if dec_start_token_id is None:
+            raise RuntimeError("Cannot find decoder start token id or <BOS>")
+
+        return dec_start_token_id
+
     @cached_property
     def default_cmpl_tok_params(self) -> TokenizeParams:
         mm_processor = self.mm_processor
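`get_dec_start_token_id` moves to the renderer with its behavior unchanged: the config value wins, then BOS, then a hard error. A toy reproduction of the fallback chain (not vLLM code; `resolve_dec_start` is a hypothetical stand-in):

```python
def resolve_dec_start(
    decoder_start_token_id: int | None,
    bos_token_id: int | None,
) -> int:
    token_id = decoder_start_token_id
    if token_id is None:
        token_id = bos_token_id  # fall back on <BOS>
    if token_id is None:
        raise RuntimeError("Cannot find decoder start token id or <BOS>")
    return token_id

assert resolve_dec_start(2, 0) == 2     # hf_config value wins when set
assert resolve_dec_start(None, 0) == 0  # otherwise <BOS> is used
```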
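Aside on the `media/base.py` change above: `field(repr=False)` keeps the raw byte payload out of generated reprs, so logging a `MediaWithBytes` no longer dumps the whole buffer. Standalone demonstration (`MediaWithBytesDemo` is a hypothetical stand-in):

```python
from dataclasses import dataclass, field

@dataclass
class MediaWithBytesDemo:
    media: str
    original_bytes: bytes = field(repr=False)

item = MediaWithBytesDemo(media="image", original_bytes=b"\x00" * 1_000_000)
print(item)  # MediaWithBytesDemo(media='image') -- no megabyte blob in logs
```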