[Refactor] Use data parser for matching data items to multi-modal UUIDs (#32955)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-26 15:00:28 +08:00
committed by GitHub
parent ee484b3f4b
commit 11b556878b
14 changed files with 701 additions and 604 deletions

View File

@@ -8,14 +8,24 @@ from typing import Any, Literal, cast
from vllm.config import VllmConfig
from vllm.exceptions import VLLMValidationError
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs import (
ProcessorInputs,
PromptType,
SingletonInputs,
SingletonPrompt,
TextPrompt,
)
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt, split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
MultiModalUUIDDict,
)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
@@ -188,7 +198,66 @@ class InputProcessor:
self._validate_sampling_params(params)
self._validate_supported_sampling_params(params)
def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
mm_processor = self.input_preprocessor._get_mm_processor()
return mm_processor.data_parser.parse_mm_data(mm_data)
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
if isinstance(prompt, str):
prompt = TextPrompt(prompt=prompt)
mm_data = cast(MultiModalDataDict, prompt.get("multi_modal_data") or {})
mm_uuids = cast(MultiModalUUIDDict, prompt.get("multi_modal_uuids") or {})
if not mm_data and not mm_uuids:
return
mm_data_parsed = self._parse_mm_items(
{k: v for k, v in mm_data.items() if v is not None}
)
mm_uuids_parsed = {
k: [v] if isinstance(v, str) else v
for k, v in mm_uuids.items()
if v is not None
}
# NOTE: Include the keys corresponding to `None`
modalities = mm_data.keys() | mm_uuids.keys()
for modality in modalities:
data_items = cast(
ModalityDataItems | list[Any], mm_data_parsed.get(modality, [])
)
uuid_items = cast(list[str | None], mm_uuids_parsed.get(modality, []))
if len(data_items) > 0:
if len(uuid_items) > 0 and len(data_items) != len(uuid_items):
raise ValueError(
f"If given, multi_modal_uuids[{modality!r}] must have "
f"same length as multi_modal_data[{modality!r}], but "
f"got {len(uuid_items)} vs {len(data_items)}."
)
for i, item in enumerate(data_items):
if item is None:
if not uuid_items:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
if uuid_items[i] is None:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}][{i}] is missing."
)
else:
if len(uuid_items) == 0:
raise ValueError(
f"multi_modal_data[{modality!r}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
def _validate_mm_uuids(self, prompt: PromptType) -> None:
"""
Validate that user-provided multi_modal_uuids align with
multi_modal_data in the incoming request prompt(s).
@@ -196,55 +265,13 @@ class InputProcessor:
auto-hashed downstream.
"""
def _validate_single_prompt(single_prompt: dict | str) -> None:
if not isinstance(single_prompt, dict):
return
if is_explicit_encoder_decoder_prompt(prompt):
self._validate_singleton_mm_uuids(prompt["encoder_prompt"])
mm_data = single_prompt.get("multi_modal_data")
mm_uuids = single_prompt.get("multi_modal_uuids")
if not mm_data or not mm_uuids:
return
import torch
def _get_len(items: object):
if isinstance(items, dict): # Embedding inputs
return _get_len(next(iter(items.values()))) if items else 1
if isinstance(items, list):
return len(items)
if isinstance(items, torch.Tensor):
# To keep backwards compatibility for single item embedding input
return 1 if getattr(items, "_is_single_item", False) else len(items)
return 1
for modality, items in mm_data.items():
if modality in mm_uuids:
data_len = _get_len(items)
uuid_len = _get_len(mm_uuids[modality])
if uuid_len != data_len:
raise ValueError(
f"multi_modal_uuids for modality {modality!r} "
"must have same length as data: got "
f"{uuid_len} uuids vs {data_len} items."
)
else:
raise ValueError(
f"multi_modal_uuids for modality {modality!r} must "
"be provided if multi_modal_data is provided."
)
# Handle explicit encoder/decoder prompts or singleton prompt
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
enc = prompt.get("encoder_prompt")
dec = prompt.get("decoder_prompt")
if enc is not None:
_validate_single_prompt(cast(dict | str, enc))
if dec is not None:
_validate_single_prompt(cast(dict | str, dec))
if (dec_prompt := prompt["decoder_prompt"]) is not None:
self._validate_singleton_mm_uuids(dec_prompt)
else:
_validate_single_prompt(prompt) # type: ignore[arg-type]
self._validate_singleton_mm_uuids(prompt)
def _validate_lora(self, lora_request: LoRARequest | None) -> None:
if lora_request is None:
@@ -379,6 +406,20 @@ class InputProcessor:
# roundtrip serialization/deserialization won't fail.
params.structured_outputs.__post_init__()
def _extract_singleton_mm_data(
self, prompt: SingletonPrompt
) -> MultiModalDataDict | None:
if isinstance(prompt, str):
return None
return prompt.get("multi_modal_data") # type: ignore[return-value]
def _extract_mm_data(self, prompt: PromptType) -> MultiModalDataDict | None:
if is_explicit_encoder_decoder_prompt(prompt):
return self._extract_singleton_mm_data(prompt["encoder_prompt"])
else:
return self._extract_singleton_mm_data(prompt)
def _maybe_build_mm_uuids(
self,
request_id: str,
@@ -391,31 +432,18 @@ class InputProcessor:
Returns a dictionary of modality -> list[str] of overrides, or None if
disabled or no multimodal data is present.
"""
def _extract_mm_data(p: PromptType):
if isinstance(p, dict) and "encoder_prompt" in p:
enc = p.get("encoder_prompt")
if isinstance(enc, dict):
return enc.get("multi_modal_data")
return None
if isinstance(p, dict):
return p.get("multi_modal_data")
return None
mm_data = _extract_mm_data(prompt)
mm_data = self._extract_mm_data(prompt)
if not mm_data:
return None
mm_uuids: dict[str, list[str | None] | str] = {}
for modality, data in mm_data.items():
# Hash each item for embedding inputs.
n = (
len(data)
if isinstance(data, list) or MultiModalDataParser.is_embeddings(data)
else 1
)
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
return mm_uuids
mm_items = self._parse_mm_items(
{k: v for k, v in mm_data.items() if v is not None}
)
return {
modality: [f"{request_id}-{modality}-{i}" for i in range(data_count)]
for modality, data_count in mm_items.get_all_counts().items()
}
def _get_mm_identifier(
self,
@@ -494,7 +522,7 @@ class InputProcessor:
else:
# Otherwise, use user-provided uuids as multimodal hash overrides
# if provided.
self._validate_multi_modal_uuids(prompt)
self._validate_mm_uuids(prompt)
if isinstance(prompt, dict):
mm_uuids = cast(
MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")