[Renderer] Move InputPreprocessor into Renderer (1.5/2) (#34598)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-02-16 15:46:33 +08:00
Committed by: GitHub
Parent: bb59c90248
Commit: ec17bdd894
11 changed files with 209 additions and 136 deletions

View File

@@ -93,14 +93,14 @@ def _build_renderer(
 def _preprocess_prompt(
-    mdoel_config: ModelConfig,
+    model_config: ModelConfig,
     prompt_or_prompts: SingletonPrompt | bytes | Sequence[SingletonPrompt | bytes],
 ):
     return [
         (
             prompt
             if isinstance(prompt, bytes)
-            else parse_model_prompt(mdoel_config, prompt)
+            else parse_model_prompt(model_config, prompt)
         )
         for prompt in prompt_to_seq(prompt_or_prompts)
     ]
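
For context, the corrected helper just maps over a possibly mixed batch, passing raw bytes through untouched and parsing everything else. A minimal standalone sketch of the same shape (prompt_to_seq and parse_model_prompt are hypothetical stand-ins here, not the vLLM implementations):

    # Hypothetical stand-ins, for illustration only.
    def prompt_to_seq(prompt_or_prompts):
        # Accept a single prompt or a sequence of prompts.
        if isinstance(prompt_or_prompts, (list, tuple)):
            return prompt_or_prompts
        return [prompt_or_prompts]

    def parse_model_prompt(model_config, prompt):
        # Stand-in for the real parser, which consults model_config.
        return {"type": "text", "text": prompt}

    mixed = ["Hello!", b"\x89PNG..."]
    parsed = [
        p if isinstance(p, bytes) else parse_model_prompt(None, p)
        for p in prompt_to_seq(mixed)
    ]
    # -> [{'type': 'text', 'text': 'Hello!'}, b'\x89PNG...']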

View File

@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING, Any, Literal, TypeAlias

 import torch
-from typing_extensions import NotRequired, TypedDict
+from typing_extensions import NotRequired, TypedDict, assert_never

 if TYPE_CHECKING:
     from vllm.multimodal.inputs import (
@@ -200,15 +200,22 @@ class TokenInputs(_InputOptions):
     prompt_token_ids: list[int]
     """The token IDs of the prompt."""

+    prompt: NotRequired[str]
+    """The prompt text corresponding to the token IDs, if available."""
+

 def token_inputs(
     prompt_token_ids: list[int],
     *,
+    prompt: str | None = None,
     cache_salt: str | None = None,
 ) -> TokenInputs:
     """Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
     values."""
     inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)

+    if prompt is not None:
+        inputs["prompt"] = prompt
     if cache_salt is not None:
         inputs["cache_salt"] = cache_salt

     return inputs
@@ -224,15 +231,22 @@ class EmbedsInputs(_InputOptions):
     prompt_embeds: torch.Tensor
     """The embeddings of the prompt."""

+    prompt: NotRequired[str]
+    """The prompt text corresponding to the token IDs, if available."""
+

 def embeds_inputs(
     prompt_embeds: torch.Tensor,
     *,
+    prompt: str | None = None,
     cache_salt: str | None = None,
 ) -> EmbedsInputs:
     """Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
     values."""
     inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)

+    if prompt is not None:
+        inputs["prompt"] = prompt
     if cache_salt is not None:
         inputs["cache_salt"] = cache_salt

     return inputs
@@ -278,10 +292,12 @@ class EncoderDecoderInputs(TypedDict):
     for encoder-decoder models.
     """

-    encoder: EncoderInputs
+    type: Literal["enc_dec"]
+
+    encoder_prompt: EncoderInputs
     """The inputs for the encoder portion."""

-    decoder: DecoderInputs
+    decoder_prompt: DecoderInputs
     """The inputs for the decoder portion."""
@@ -296,3 +312,94 @@ which can be passed to
 SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs
 """The inputs for a single encoder/decoder prompt."""


+def _validate_enc_inputs(inputs: SingletonInputs) -> EncoderInputs:
+    if inputs["type"] == "embeds":
+        raise ValueError(
+            "Embedding inputs are not supported for encoder-decoder models"
+        )
+
+    if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs:
+        raise RuntimeError(
+            "You should register an encoder-decoder multi-modal processor "
+            "for encoder-decoder models."
+        )
+
+    return inputs  # type: ignore[return-value]
+
+
+def _validate_dec_inputs(inputs: SingletonInputs) -> DecoderInputs:
+    if inputs["type"] == "embeds":
+        raise ValueError(
+            "Embedding inputs are not supported for encoder-decoder models"
+        )
+
+    return inputs
+
+
+def _prepare_decoder_input_ids_for_generation(
+    decoder_input_ids: list[int],
+    decoder_start_token_id: int,
+) -> list[int]:
+    """
+    Prepare `decoder_input_ids` for generation with encoder-decoder models,
+    according to `GenerationMixin._prepare_decoder_input_ids_for_generation()`.
+
+    Source:
+    https://github.com/huggingface/transformers/blob/v5.1.0/src/transformers/generation/utils.py
+    """
+    if len(decoder_input_ids) == 0 or decoder_input_ids[0] != decoder_start_token_id:
+        decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
+
+    return decoder_input_ids
+
+
+def build_enc_dec_inputs(
+    encoder_inputs: SingletonInputs,
+    decoder_inputs: SingletonInputs | None,
+    decoder_start_token_id: int,
+) -> EncoderDecoderInputs:
+    enc_inputs = _validate_enc_inputs(encoder_inputs)
+    if decoder_inputs is None:
+        dec_inputs: DecoderInputs = enc_inputs
+    else:
+        dec_inputs = _validate_dec_inputs(decoder_inputs)
+
+    enc_inputs_new: EncoderInputs
+    dec_inputs_new: DecoderInputs
+
+    if enc_inputs["type"] == "multimodal":
+        from vllm.multimodal.inputs import mm_inputs
+
+        enc_inputs_new = token_inputs(
+            enc_inputs["encoder_prompt_token_ids"],
+            prompt=enc_inputs.get("encoder_prompt"),
+        )
+        dec_inputs_new = mm_inputs(
+            prompt_token_ids=dec_inputs["prompt_token_ids"],
+            prompt=dec_inputs.get("prompt"),
+            mm_kwargs=enc_inputs["mm_kwargs"],
+            mm_hashes=enc_inputs["mm_hashes"],
+            mm_placeholders=enc_inputs["mm_placeholders"],
+        )
+    elif enc_inputs["type"] == "token":
+        enc_inputs_new = token_inputs(prompt_token_ids=[])
+        dec_inputs_new = dec_inputs
+    else:
+        assert_never(enc_inputs)
+
+    dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation(
+        dec_inputs_new["prompt_token_ids"],
+        decoder_start_token_id,
+    )
+
+    if cache_salt := enc_inputs.get("cache_salt"):
+        dec_inputs_new["cache_salt"] = cache_salt
+
+    return EncoderDecoderInputs(
+        type="enc_dec",
+        encoder_prompt=enc_inputs_new,
+        decoder_prompt=dec_inputs_new,
+    )
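
A small worked example of the token-only path may help (a sketch assuming the token_inputs and build_enc_dec_inputs definitions above; the decoder start token id of 2 is an arbitrary illustrative value):

    enc = token_inputs([10, 11, 12])
    dec = token_inputs([5, 6])

    out = build_enc_dec_inputs(enc, dec, decoder_start_token_id=2)

    # In the "token" branch the new encoder inputs are an empty token prompt,
    # and the decoder prompt gains the start token if it is missing:
    assert out["type"] == "enc_dec"
    assert out["encoder_prompt"]["prompt_token_ids"] == []
    assert out["decoder_prompt"]["prompt_token_ids"] == [2, 5, 6]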

View File

@@ -7,11 +7,7 @@ from .data import ProcessorInputs, SingletonInputs
 def split_enc_dec_inputs(
     inputs: ProcessorInputs,
 ) -> tuple[SingletonInputs | None, SingletonInputs]:
-    if "encoder" in inputs and "decoder" in inputs:
-        # NOTE: This passes pyright but not mypy
-        return (
-            inputs["encoder"],  # type: ignore[typeddict-item]
-            inputs["decoder"],  # type: ignore[typeddict-item]
-        )
+    if inputs["type"] == "enc_dec":
+        return inputs["encoder_prompt"], inputs["decoder_prompt"]

     return None, inputs
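
With the explicit "type" tag, the old key-sniffing (and its pyright/mypy disagreement) goes away. A quick illustration, assuming the data.py helpers shown above are imported:

    enc_dec = build_enc_dec_inputs(
        token_inputs([1, 2]), None, decoder_start_token_id=0
    )
    assert split_enc_dec_inputs(enc_dec) == (
        enc_dec["encoder_prompt"],
        enc_dec["decoder_prompt"],
    )

    dec_only = token_inputs([3, 4])
    assert split_enc_dec_inputs(dec_only) == (None, dec_only)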

View File

@@ -7,6 +7,7 @@ from typing import Any, overload
 from typing_extensions import assert_never

 from vllm.config import VllmConfig
+from vllm.inputs.data import build_enc_dec_inputs
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.inputs import (
@@ -67,54 +68,6 @@ class InputPreprocessor:
     def get_tokenizer(self) -> TokenizerLike:
         return self.renderer.get_tokenizer()

-    def get_decoder_start_token_id(self) -> int:
-        """
-        Obtain the decoder start token id employed by an encoder/decoder
-        model. Raises an error if it is not available.
-        """
-        dec_start_token_id = getattr(
-            self.model_config.hf_config, "decoder_start_token_id", None
-        )
-        if dec_start_token_id is None:
-            logger.warning_once(
-                "Falling back on <BOS> for decoder start token id "
-                "because decoder start token id is not available."
-            )
-            dec_start_token_id = self.renderer.get_bos_token_id()
-        if dec_start_token_id is None:
-            raise RuntimeError("Cannot find decoder start token id or <BOS>")
-
-        return dec_start_token_id
-
-    def _prepare_decoder_input_ids(self, decoder_input_ids: list[int]) -> list[int]:
-        """
-        Prepares `decoder_input_ids` for generation with encoder-decoder models.
-
-        Based on:
-        https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
-        specifically,
-        `GenerationMixin._prepare_decoder_input_ids_for_generation()`.
-
-        Arguments:
-        * decoder_input_ids: input token ids to preprocess
-
-        Returns:
-        * Processed token list
-        """
-        decoder_start_token_id = self.get_decoder_start_token_id()
-        if (
-            len(decoder_input_ids) == 0
-            or decoder_input_ids[0] != decoder_start_token_id
-        ):
-            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
-
-        return decoder_input_ids

     def _tokenize_prompt(
         self,
         prompt: str,
@@ -332,66 +285,6 @@ class InputPreprocessor:
         assert_never(prompt)  # type: ignore[arg-type]

-    def _validate_enc_inputs(self, inputs: SingletonInputs) -> EncoderInputs:
-        if inputs["type"] == "embeds":
-            raise ValueError(
-                "Embedding inputs are not supported for encoder-decoder models"
-            )
-
-        if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs:
-            raise RuntimeError(
-                "You should register an encoder-decoder "
-                "multi-modal processor for encoder-decoder models."
-            )
-
-        return inputs  # type: ignore[return-value]
-
-    def _validate_dec_inputs(self, inputs: SingletonInputs) -> DecoderInputs:
-        if inputs["type"] == "embeds":
-            raise ValueError(
-                "Embedding inputs are not supported for encoder-decoder models"
-            )
-
-        return inputs
-
-    def _build_enc_dec_inputs(
-        self,
-        encoder_inputs: SingletonInputs,
-        decoder_inputs: SingletonInputs | None = None,
-    ) -> EncoderDecoderInputs:
-        enc_inputs = self._validate_enc_inputs(encoder_inputs)
-        if decoder_inputs is None:
-            dec_inputs: DecoderInputs = enc_inputs  # type: ignore[assignment]
-        else:
-            dec_inputs = self._validate_dec_inputs(decoder_inputs)
-
-        enc_inputs_new: EncoderInputs
-        dec_inputs_new: DecoderInputs
-
-        if enc_inputs["type"] == "multimodal":
-            enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"])
-            dec_inputs_new = MultiModalInputs(
-                type="multimodal",
-                prompt_token_ids=dec_inputs["prompt_token_ids"],
-                mm_kwargs=enc_inputs["mm_kwargs"],
-                mm_hashes=enc_inputs["mm_hashes"],
-                mm_placeholders=enc_inputs["mm_placeholders"],
-            )
-        elif enc_inputs["type"] == "token":
-            enc_inputs_new = token_inputs(prompt_token_ids=[])
-            dec_inputs_new = dec_inputs
-        else:
-            assert_never(enc_inputs)
-
-        dec_inputs_new["prompt_token_ids"] = self._prepare_decoder_input_ids(
-            dec_inputs_new["prompt_token_ids"]
-        )
-
-        if cache_salt := enc_inputs.get("cache_salt"):
-            dec_inputs_new["cache_salt"] = cache_salt
-
-        return EncoderDecoderInputs(encoder=enc_inputs_new, decoder=dec_inputs_new)

     def _process_encoder_decoder_prompt(
         self,
         prompt: EncoderDecoderDictPrompt,
@@ -417,7 +310,7 @@ class InputPreprocessor:
         encoder_prompt = prompt["encoder_prompt"]
         decoder_prompt = prompt["decoder_prompt"]

-        return self._build_enc_dec_inputs(
+        return build_enc_dec_inputs(
             encoder_inputs=self._prompt_to_llm_inputs(
                 encoder_prompt,
                 tokenization_kwargs=tokenization_kwargs,
@@ -431,6 +324,7 @@ class InputPreprocessor:
                     tokenization_kwargs=tokenization_kwargs,
                 )
             ),
+            decoder_start_token_id=self.renderer.get_dec_start_token_id(),
         )

     def _process_decoder_only_prompt(

View File

@@ -31,6 +31,7 @@ from vllm.multimodal.inputs import (
     MultiModalInputs,
     MultiModalKwargsItems,
     MultiModalUUIDDict,
+    mm_inputs,
 )

 from vllm.multimodal.parse import (
     ImageEmbeddingItems,
@@ -837,8 +838,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
             for modality, placeholders in mm_placeholders.items()
         }

-        return MultiModalInputs(
-            type="multimodal",
+        return mm_inputs(
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_kwargs,
             mm_hashes=mm_hashes,

View File

@@ -48,6 +48,7 @@ from vllm.multimodal.inputs import (
     MultiModalKwargsItems,
     MultiModalUUIDDict,
     PlaceholderRange,
+    mm_inputs,
 )

 from vllm.multimodal.parse import (
     DictEmbeddingItems,
@@ -222,8 +223,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
             ),
         )

-        return MultiModalInputs(
-            type="multimodal",
+        return mm_inputs(
             prompt_token_ids=[1],
             mm_kwargs=mm_kwargs,
             mm_hashes=mm_hashes,

View File

@@ -33,6 +33,7 @@ from vllm.multimodal.inputs import (
     MultiModalInputs,
     MultiModalUUIDDict,
     PlaceholderRange,
+    mm_inputs,
 )

 from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
 from vllm.multimodal.processing import (
@@ -260,8 +261,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
             mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
         )

-        return MultiModalInputs(
-            type="multimodal",
+        return mm_inputs(
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_kwargs,
             mm_hashes=mm_hashes,

View File

@@ -20,7 +20,7 @@ from typing import (
 import numpy as np
 from PIL.Image import Image
-from typing_extensions import TypeVar
+from typing_extensions import NotRequired, TypeVar

 from vllm.utils.collection_utils import is_list_of
 from vllm.utils.import_utils import LazyLoader
@@ -1075,6 +1075,9 @@ class MultiModalInputs(_InputOptions):
     prompt_token_ids: list[int]
     """The processed token IDs which includes placeholder tokens."""

+    prompt: NotRequired[str]
+    """The prompt text corresponding to the token IDs, if available."""
+
     mm_kwargs: MultiModalKwargsOptionalItems
     """Keyword arguments to be directly passed to the model after batching."""
@@ -1088,6 +1091,31 @@ class MultiModalInputs(_InputOptions):
"""
def mm_inputs(
prompt_token_ids: list[int],
mm_kwargs: MultiModalKwargsOptionalItems,
mm_hashes: MultiModalHashes,
mm_placeholders: MultiModalPlaceholderDict,
*,
prompt: str | None = None,
cache_salt: str | None = None,
) -> MultiModalInputs:
inputs = MultiModalInputs(
type="multimodal",
prompt_token_ids=prompt_token_ids,
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholders,
)
if prompt is not None:
inputs["prompt"] = prompt
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
return inputs
class MultiModalEncDecInputs(MultiModalInputs):
"""
Represents the outputs of
@@ -1101,3 +1129,31 @@ class MultiModalEncDecInputs(MultiModalInputs):
     encoder_prompt_token_ids: list[int]
     """The processed token IDs of the encoder prompt."""

+    encoder_prompt: NotRequired[str]
+    """The prompt text corresponding to the encoder token IDs, if available."""
+
+
+def mm_enc_dec_inputs(
+    encoder_inputs: MultiModalInputs,
+    decoder_prompt_token_ids: list[int],
+    *,
+    decoder_prompt: str | None = None,
+) -> MultiModalEncDecInputs:
+    inputs = MultiModalEncDecInputs(
+        type="multimodal",
+        prompt_token_ids=decoder_prompt_token_ids,
+        encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
+        mm_kwargs=encoder_inputs["mm_kwargs"],
+        mm_hashes=encoder_inputs["mm_hashes"],
+        mm_placeholders=encoder_inputs["mm_placeholders"],
+    )
+
+    if decoder_prompt is not None:
+        inputs["prompt"] = decoder_prompt
+    if "prompt" in encoder_inputs:
+        inputs["encoder_prompt"] = encoder_inputs["prompt"]
+    if "cache_salt" in encoder_inputs:
+        inputs["cache_salt"] = encoder_inputs["cache_salt"]
+
+    return inputs
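
These factories centralize the optional-key bookkeeping that call sites previously wrote out by hand. A usage sketch with placeholder multimodal data (the empty dicts stand in for real processor outputs and are for illustration only):

    enc = mm_inputs(
        prompt_token_ids=[101, 102],
        mm_kwargs={},        # placeholder for real processor outputs
        mm_hashes={},
        mm_placeholders={},
        prompt="<image> a caption",
    )
    assert "cache_salt" not in enc  # optional keys appear only when set

    enc_dec = mm_enc_dec_inputs(enc, decoder_prompt_token_ids=[0])
    # The encoder's token IDs and multimodal data are carried over:
    assert enc_dec["encoder_prompt_token_ids"] == [101, 102]
    assert enc_dec["prompt_token_ids"] == [0]
    assert enc_dec["encoder_prompt"] == "<image> a caption"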

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Generic, TypeVar
@@ -26,7 +26,7 @@ class MediaWithBytes(Generic[_T]):
"""
media: _T
original_bytes: bytes
original_bytes: bytes = field(repr=False)
def __array__(self, *args, **kwargs) -> np.ndarray:
"""Allow np.array(obj) to return np.array(obj.media)."""

View File

@@ -34,6 +34,8 @@ from ..inputs import (
     MultiModalKwargsOptionalItems,
     MultiModalUUIDDict,
     PlaceholderRange,
+    mm_enc_dec_inputs,
+    mm_inputs,
 )

 from ..parse import (
     DictEmbeddingItems,
@@ -1803,8 +1805,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
             for modality, placeholders in mm_placeholders.items()
         }

-        return MultiModalInputs(
-            type="multimodal",
+        return mm_inputs(
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_info.kwargs,
             mm_hashes=mm_info.hashes,
@@ -1848,12 +1849,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
         else:
             decoder_prompt_ids = decoder_prompt_raw

-        mm_inputs = MultiModalEncDecInputs(
-            encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
-            **encoder_inputs,
+        return mm_enc_dec_inputs(
+            encoder_inputs,
+            decoder_prompt_ids,
         )
-        mm_inputs["prompt_token_ids"] = decoder_prompt_ids
-        return mm_inputs

     def apply(
         self,

View File

@@ -153,6 +153,27 @@ class BaseRenderer(ABC, Generic[_T]):
         return self.tokenizer.eos_token_id

+    def get_dec_start_token_id(self) -> int:
+        """
+        Obtain the decoder start token id employed by an encoder/decoder model,
+        raising an error if it is not available.
+        """
+        dec_start_token_id = getattr(
+            self.model_config.hf_config, "decoder_start_token_id", None
+        )
+        if dec_start_token_id is None:
+            logger.warning_once(
+                "Falling back on <BOS> for decoder start token id "
+                "because decoder start token id is not available."
+            )
+            dec_start_token_id = self.get_bos_token_id()
+        if dec_start_token_id is None:
+            raise RuntimeError("Cannot find decoder start token id or <BOS>")
+
+        return dec_start_token_id
+
     @cached_property
     def default_cmpl_tok_params(self) -> TokenizeParams:
         mm_processor = self.mm_processor
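
The lookup order in get_dec_start_token_id is: decoder_start_token_id from the HF config, then the <BOS> token, then a hard error. The same fallback chain as a standalone sketch (hypothetical arguments, for illustration):

    def resolve_decoder_start(hf_config, bos_token_id):
        dec = getattr(hf_config, "decoder_start_token_id", None)
        if dec is None:
            dec = bos_token_id  # fall back on <BOS>
        if dec is None:
            raise RuntimeError("Cannot find decoder start token id or <BOS>")
        return dec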