diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 315ffddde..d9f9814ee 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -12,11 +12,13 @@ from vllm.sampling_params import SamplingParams if TYPE_CHECKING: from vllm.multimodal.inputs import ( MultiModalDataDict, + MultiModalEncDecInputs, MultiModalInputs, MultiModalUUIDDict, ) else: MultiModalDataDict = object + MultiModalEncDecInputs = object MultiModalInputs = object MultiModalUUIDDict = object @@ -241,7 +243,7 @@ class EncoderDecoderInputs(TypedDict): This specifies the required data for encoder-decoder models. """ - encoder: TokenInputs | MultiModalInputs + encoder: TokenInputs | MultiModalEncDecInputs """The inputs for the encoder portion.""" decoder: TokenInputs | MultiModalInputs diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 5f832afdb..7cb1eb4b4 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -69,6 +69,22 @@ def is_explicit_encoder_decoder_prompt( return isinstance(prompt, dict) and "encoder_prompt" in prompt +def split_enc_dec_prompt( + prompt: PromptType, +) -> tuple[SingletonPrompt, SingletonPrompt | None]: + if isinstance(prompt, str): + return prompt, None + + if "encoder_prompt" in prompt and "decoder_prompt" in prompt: + # NOTE: This passes pyright but not mypy + return ( + prompt["encoder_prompt"], # type: ignore[typeddict-item] + prompt["decoder_prompt"], # type: ignore[typeddict-item] + ) + + return prompt, None + + def split_enc_dec_inputs( inputs: ProcessorInputs, ) -> tuple[SingletonInputs | None, SingletonInputs]: diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 6edb26a4a..0a3b0c946 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping -from typing import Any, cast +from typing import Any from typing_extensions import assert_never from vllm.config import ModelConfig, 
ObservabilityConfig +from vllm.inputs.parse import split_enc_dec_prompt from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.cache import BaseMultiModalProcessorCache @@ -27,7 +28,6 @@ from .data import ( EmbedsInputs, EmbedsPrompt, EncoderDecoderInputs, - ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, SingletonInputs, @@ -86,30 +86,15 @@ class InputPreprocessor: return self.tokenizer.eos_token_id - def get_decoder_start_token_id(self) -> int | None: + def get_decoder_start_token_id(self) -> int: """ Obtain the decoder start token id employed by an encoder/decoder - model. Returns None for non-encoder/decoder models or if the - model config is unavailable. + model. Raises an error if it is not available. """ - - if not self.model_config.is_encoder_decoder: - logger.warning_once( - "Using None for decoder start token id because " - "this is not an encoder/decoder model." - ) - return None - - if self.model_config is None or self.model_config.hf_config is None: - logger.warning_once( - "Using None for decoder start token id because " - "model config is not available." - ) - return None - dec_start_token_id = getattr( self.model_config.hf_config, "decoder_start_token_id", None ) + if dec_start_token_id is None: logger.warning_once( "Falling back on <BOS> for decoder start token " @@ -118,48 +103,12 @@ class InputPreprocessor: ) dec_start_token_id = self.get_bos_token_id() + if dec_start_token_id is None: + raise RuntimeError("Cannot find decoder start token id or <BOS>") + return dec_start_token_id - def _get_default_enc_dec_decoder_prompt(self) -> list[int]: - """ - Specifically for encoder/decoder models: - generate a default decoder prompt for when - the user specifies only the encoder prompt. 
- - Encoder/decoder models utilize the decoder - prompt in different ways; as new models are - added, it is intended that this function - will be extended to produce differing - default decoder prompts, depending on the - model variety. - - Absent a special case, the default behavior - of this method is to mirror the behavior of - the HuggingFace (HF) GenerationMixin for a None - decoder prompt, which is to employ a logit processor - setting to force the first decoded token to be <BOS>. - Here, this behavior is approximated by having the - "default" decoder prompt be <BOS>. - - However, it is possible that in the future - other models may have different or more - complex logic for the default decoder prompt. - This motivates having a special helper method - for default decoder prompts. - - Returns: - - * prompt_token_ids - """ - - bos_token_id = self.get_bos_token_id() - assert bos_token_id is not None - return [bos_token_id] - - def _prepare_decoder_input_ids_for_generation( - self, - decoder_input_ids: list[int] | None, - ) -> list[int]: + def _prepare_decoder_input_ids(self, decoder_input_ids: list[int]) -> list[int]: """ Prepares `decoder_input_ids` for generation with encoder-decoder models. 
@@ -176,14 +125,7 @@ class InputPreprocessor: * Processed token list """ - decoder_start_token_id = self.get_decoder_start_token_id() - assert decoder_start_token_id is not None - - if decoder_input_ids is None: - # no decoder prompt input -> - # use decoder_start_token_id as decoder_input_ids - decoder_input_ids = self._get_default_enc_dec_decoder_prompt() if ( len(decoder_input_ids) == 0 @@ -428,111 +370,70 @@ class InputPreprocessor: assert_never(parsed) - def _build_enc_dec_llm_inputs( + def _validate_enc_inputs( + self, + inputs: SingletonInputs, + ) -> TokenInputs | MultiModalEncDecInputs: + if inputs["type"] == "embeds": + raise ValueError( + "Embedding inputs are not supported for encoder-decoder models" + ) + + if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs: + raise RuntimeError( + "You should register an encoder-decoder " + "multi-modal processor for encoder-decoder models." + ) + + return inputs # type: ignore[return-value] + + def _validate_dec_inputs( + self, + inputs: SingletonInputs, + ) -> TokenInputs | MultiModalInputs: + if inputs["type"] == "embeds": + raise ValueError( + "Embedding inputs are not supported for encoder-decoder models" + ) + + return inputs + + def _build_enc_dec_inputs( self, encoder_inputs: SingletonInputs, - decoder_inputs: SingletonInputs | None, + decoder_inputs: SingletonInputs | None = None, ) -> EncoderDecoderInputs: - if ( - encoder_inputs["type"] == "embeds" - or decoder_inputs - and decoder_inputs["type"] == "embeds" - ): - raise ValueError( - "Embedding inputs are not supported for encoder-decoder models" - ) - - # Needed for mypy - encoder_inputs = cast(TokenInputs | MultiModalInputs, encoder_inputs) - decoder_inputs = cast(TokenInputs | MultiModalInputs | None, decoder_inputs) - if decoder_inputs is None: - if self.model_config.hf_config.model_type == "whisper": - # For Whisper models, the text prompt should go to the decoder. 
- # If no explicit encoder/decoder inputs, then copy the prompt - # from the encoder to the decoder. The encoder tokens are later - # overridden by the audio features. - dec_token_ids = encoder_inputs["prompt_token_ids"].copy() - else: - dec_token_ids = self._prepare_decoder_input_ids_for_generation(None) - decoder_inputs = token_inputs(dec_token_ids) - else: - if "multi_modal_data" in decoder_inputs: - raise ValueError( - "Multi-modal decoder inputs of encoder-" - "decoder models are not supported yet" - ) + decoder_inputs = encoder_inputs - dec_token_ids = self._prepare_decoder_input_ids_for_generation( - decoder_inputs["prompt_token_ids"] - ) - decoder_inputs["prompt_token_ids"] = dec_token_ids + enc_inputs = self._validate_enc_inputs(encoder_inputs) + dec_inputs = self._validate_dec_inputs(decoder_inputs) - return EncoderDecoderInputs( - encoder=encoder_inputs, - decoder=decoder_inputs, - ) + enc_inputs_new: TokenInputs | MultiModalEncDecInputs + dec_inputs_new: TokenInputs | MultiModalInputs - def _split_enc_dec_mm_inputs( - self, - inputs: SingletonInputs | MultiModalEncDecInputs, - decoder_inputs_to_override: SingletonInputs | None = None, - ) -> tuple[SingletonInputs, SingletonInputs]: - """ - For encoder/decoder models only: - Separate Encoder/Decoder inputs from a MultiModalEncDecInputs - """ - if ( - inputs["type"] == "embeds" - or decoder_inputs_to_override - and decoder_inputs_to_override["type"] == "embeds" - ): - raise ValueError( - "Embedding inputs are not supported for encoder-decoder models" - ) - - # Needed for mypy - inputs = cast( - TokenInputs | MultiModalInputs | MultiModalEncDecInputs, - inputs, - ) - decoder_inputs_to_override = cast( - TokenInputs | MultiModalInputs | None, - decoder_inputs_to_override, - ) - - encoder_inputs: SingletonInputs - decoder_inputs: SingletonInputs - - if inputs["type"] == "multimodal": # Multimodal data inputs - if "encoder_prompt_token_ids" not in inputs: - raise RuntimeError( - "You should register an 
encoder-decoder " - "multi-modal processor for encoder-decoder " - "models." - ) - inputs = cast(MultiModalEncDecInputs, inputs) - - encoder_inputs = token_inputs(inputs["encoder_prompt_token_ids"]) - - decoder_prompt_inputs = decoder_inputs_to_override or inputs - decoder_inputs = MultiModalInputs( + if enc_inputs["type"] == "multimodal": + enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"]) + dec_inputs_new = MultiModalInputs( type="multimodal", - prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"], - mm_kwargs=inputs["mm_kwargs"], - mm_hashes=inputs["mm_hashes"], - mm_placeholders=inputs["mm_placeholders"], + prompt_token_ids=dec_inputs["prompt_token_ids"], + mm_kwargs=enc_inputs["mm_kwargs"], + mm_hashes=enc_inputs["mm_hashes"], + mm_placeholders=enc_inputs["mm_placeholders"], ) - if cache_salt := inputs.get("cache_salt"): - decoder_inputs["cache_salt"] = cache_salt - - elif inputs["type"] == "token": # Text-only inputs - encoder_inputs = token_inputs(prompt_token_ids=[]) - decoder_inputs = decoder_inputs_to_override or inputs + elif enc_inputs["type"] == "token": + enc_inputs_new = token_inputs(prompt_token_ids=[]) + dec_inputs_new = dec_inputs else: - assert_never(inputs) # type: ignore[arg-type] + assert_never(enc_inputs) - return encoder_inputs, decoder_inputs + dec_inputs_new["prompt_token_ids"] = self._prepare_decoder_input_ids( + dec_inputs_new["prompt_token_ids"] + ) + if cache_salt := enc_inputs.get("cache_salt"): + dec_inputs_new["cache_salt"] = cache_salt + + return EncoderDecoderInputs(encoder=enc_inputs_new, decoder=dec_inputs_new) def _process_encoder_decoder_prompt( self, @@ -574,54 +475,23 @@ class InputPreprocessor: * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] instance """ - encoder_inputs: SingletonInputs - decoder_inputs: SingletonInputs | None - if is_explicit_encoder_decoder_prompt(prompt): - # `cast` is needed for mypy, but not pyright - prompt_ = cast(ExplicitEncoderDecoderPrompt, prompt) - 
encoder_inputs = self._prompt_to_llm_inputs( - prompt_["encoder_prompt"], + encoder_prompt, decoder_prompt = split_enc_dec_prompt(prompt) + + return self._build_enc_dec_inputs( + encoder_inputs=self._prompt_to_llm_inputs( + encoder_prompt, tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, - ) - if (decoder_input := prompt_["decoder_prompt"]) is None: - decoder_inputs = None - else: - decoder_inputs = self._prompt_to_llm_inputs( - decoder_input, tokenization_kwargs=tokenization_kwargs + ), + decoder_inputs=( + None + if decoder_prompt is None + else self._prompt_to_llm_inputs( + decoder_prompt, + tokenization_kwargs=tokenization_kwargs, ) - # For multimodal model, override decoder prompt from processor - # with explicit decoder prompt. - if self.model_config.is_multimodal_model: - encoder_inputs, decoder_inputs = self._split_enc_dec_mm_inputs( - encoder_inputs, decoder_inputs - ) - else: - # `cast` is needed for mypy, but not pyright - inputs = self._prompt_to_llm_inputs( - cast(SingletonPrompt, prompt), - tokenization_kwargs=tokenization_kwargs, - mm_uuids=mm_uuids, - ) - if self.model_config.is_multimodal_model: - # Encoder-Decoder Multimodal model - encoder_inputs, decoder_inputs = self._split_enc_dec_mm_inputs(inputs) - else: - encoder_inputs = inputs - decoder_inputs = None - - return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) - - def _build_decoder_only_llm_inputs( - self, - prompt_inputs: DecoderOnlyInputs, - ) -> DecoderOnlyInputs: - if "prompt_token_ids" in prompt_inputs: - prompt_inputs = cast( - TokenInputs | MultiModalInputs, prompt_inputs - ) # Needed for mypy - - return prompt_inputs + ), + ) def _process_decoder_only_prompt( self, @@ -643,15 +513,12 @@ class InputPreprocessor: * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance """ - - prompt_comps = self._prompt_to_llm_inputs( + return self._prompt_to_llm_inputs( prompt, tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, ) - return 
self._build_decoder_only_llm_inputs(prompt_comps) - def _preprocess( self, prompt: PromptType, @@ -673,10 +540,8 @@ class InputPreprocessor: "Cannot pass encoder-decoder prompt to decoder-only models" ) - # Decoder-only operation - # `cast` is needed for mypy, but not pyright return self._process_decoder_only_prompt( - cast(SingletonPrompt, prompt), + prompt, tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, ) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index e50771e99..262def712 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -1083,6 +1083,10 @@ class MultiModalEncDecInputs(MultiModalInputs): Represents the outputs of [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor] ready to be passed to vLLM internals. + + Note: Even text-only encoder-decoder models are currently implemented + as multi-modal models for convenience. + (Example: https://github.com/neuralmagic/bart-plugin) """ encoder_prompt_token_ids: list[int]