[Refactor] Use data parser for matching data items to multi-modal UUIDs (#32955)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -8,14 +8,24 @@ from typing import Any, Literal, cast
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
|
||||
from vllm.inputs.parse import split_enc_dec_inputs
|
||||
from vllm.inputs import (
|
||||
ProcessorInputs,
|
||||
PromptType,
|
||||
SingletonInputs,
|
||||
SingletonPrompt,
|
||||
TextPrompt,
|
||||
)
|
||||
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt, split_enc_dec_inputs
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFeatureSpec,
|
||||
MultiModalUUIDDict,
|
||||
)
|
||||
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
|
||||
from vllm.multimodal.processing.context import set_request_id
|
||||
from vllm.multimodal.utils import argsort_mm_positions
|
||||
from vllm.pooling_params import PoolingParams
|
||||
@@ -188,7 +198,66 @@ class InputProcessor:
|
||||
self._validate_sampling_params(params)
|
||||
self._validate_supported_sampling_params(params)
|
||||
|
||||
def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
|
||||
def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
    """
    Parse raw multi-modal data using the data parser that belongs to the
    multi-modal processor obtained from `self.input_preprocessor`.

    `mm_data` is forwarded to `parse_mm_data` unchanged, and the parser's
    result is returned as-is.
    """
    data_parser = self.input_preprocessor._get_mm_processor().data_parser
    return data_parser.parse_mm_data(mm_data)
|
||||
|
||||
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
    """
    Check that `multi_modal_uuids` lines up with `multi_modal_data` for a
    single prompt.

    A plain string prompt is wrapped in a `TextPrompt` first. Nothing is
    checked when the prompt carries neither multi-modal data nor UUIDs.

    For every modality named in either mapping (including modalities whose
    value is `None`):

    - If the parsed data has items, the UUID list (when non-empty) must have
      the same length, and every data item that is `None` must have a
      non-`None` UUID at the same index.
    - If the parsed data has no items, UUIDs must be provided.

    Raises:
        ValueError: If any of the rules above is violated.
    """
    if isinstance(prompt, str):
        prompt = TextPrompt(prompt=prompt)

    raw_data = cast(MultiModalDataDict, prompt.get("multi_modal_data") or {})
    raw_uuids = cast(MultiModalUUIDDict, prompt.get("multi_modal_uuids") or {})
    if not (raw_data or raw_uuids):
        return

    # Entries explicitly set to `None` are not handed to the data parser.
    present_data = {
        name: value for name, value in raw_data.items() if value is not None
    }
    parsed_data = self._parse_mm_items(present_data)

    # Normalise UUIDs to lists: a lone string becomes a one-element list.
    parsed_uuids = {}
    for name, value in raw_uuids.items():
        if value is None:
            continue
        parsed_uuids[name] = [value] if isinstance(value, str) else value

    # NOTE: Built from the raw mappings so that the keys corresponding to
    # `None` are still visited.
    for modality in raw_data.keys() | raw_uuids.keys():
        data_items = cast(
            ModalityDataItems | list[Any], parsed_data.get(modality, [])
        )
        uuid_items = cast(list[str | None], parsed_uuids.get(modality, []))

        has_data = len(data_items) > 0
        has_uuids = len(uuid_items) > 0

        if not has_data:
            # No data to hash for this modality, so UUIDs are mandatory.
            if not has_uuids:
                raise ValueError(
                    f"multi_modal_data[{modality!r}] is empty but "
                    f"multi_modal_uuids[{modality!r}] is missing."
                )
            continue

        if has_uuids and len(data_items) != len(uuid_items):
            raise ValueError(
                f"If given, multi_modal_uuids[{modality!r}] must have "
                f"same length as multi_modal_data[{modality!r}], but "
                f"got {len(uuid_items)} vs {len(data_items)}."
            )

        for i, item in enumerate(data_items):
            if item is not None:
                continue

            # A missing data item must be identified by its UUID instead.
            if not uuid_items:
                raise ValueError(
                    f"multi_modal_data[{modality!r}][{i}] is empty but "
                    f"multi_modal_uuids[{modality!r}] is missing."
                )

            if uuid_items[i] is None:
                raise ValueError(
                    f"multi_modal_data[{modality!r}][{i}] is empty but "
                    f"multi_modal_uuids[{modality!r}][{i}] is missing."
                )
|
||||
|
||||
def _validate_mm_uuids(self, prompt: PromptType) -> None:
|
||||
"""
|
||||
Validate that user-provided multi_modal_uuids align with
|
||||
multi_modal_data in the incoming request prompt(s).
|
||||
@@ -196,55 +265,13 @@ class InputProcessor:
|
||||
auto-hashed downstream.
|
||||
"""
|
||||
|
||||
def _validate_single_prompt(single_prompt: dict | str) -> None:
|
||||
if not isinstance(single_prompt, dict):
|
||||
return
|
||||
if is_explicit_encoder_decoder_prompt(prompt):
|
||||
self._validate_singleton_mm_uuids(prompt["encoder_prompt"])
|
||||
|
||||
mm_data = single_prompt.get("multi_modal_data")
|
||||
mm_uuids = single_prompt.get("multi_modal_uuids")
|
||||
if not mm_data or not mm_uuids:
|
||||
return
|
||||
|
||||
import torch
|
||||
|
||||
def _get_len(items: object):
|
||||
if isinstance(items, dict): # Embedding inputs
|
||||
return _get_len(next(iter(items.values()))) if items else 1
|
||||
|
||||
if isinstance(items, list):
|
||||
return len(items)
|
||||
if isinstance(items, torch.Tensor):
|
||||
# To keep backwards compatibility for single item embedding input
|
||||
return 1 if getattr(items, "_is_single_item", False) else len(items)
|
||||
|
||||
return 1
|
||||
|
||||
for modality, items in mm_data.items():
|
||||
if modality in mm_uuids:
|
||||
data_len = _get_len(items)
|
||||
uuid_len = _get_len(mm_uuids[modality])
|
||||
if uuid_len != data_len:
|
||||
raise ValueError(
|
||||
f"multi_modal_uuids for modality {modality!r} "
|
||||
"must have same length as data: got "
|
||||
f"{uuid_len} uuids vs {data_len} items."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"multi_modal_uuids for modality {modality!r} must "
|
||||
"be provided if multi_modal_data is provided."
|
||||
)
|
||||
|
||||
# Handle explicit encoder/decoder prompts or singleton prompt
|
||||
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
|
||||
enc = prompt.get("encoder_prompt")
|
||||
dec = prompt.get("decoder_prompt")
|
||||
if enc is not None:
|
||||
_validate_single_prompt(cast(dict | str, enc))
|
||||
if dec is not None:
|
||||
_validate_single_prompt(cast(dict | str, dec))
|
||||
if (dec_prompt := prompt["decoder_prompt"]) is not None:
|
||||
self._validate_singleton_mm_uuids(dec_prompt)
|
||||
else:
|
||||
_validate_single_prompt(prompt) # type: ignore[arg-type]
|
||||
self._validate_singleton_mm_uuids(prompt)
|
||||
|
||||
def _validate_lora(self, lora_request: LoRARequest | None) -> None:
|
||||
if lora_request is None:
|
||||
@@ -379,6 +406,20 @@ class InputProcessor:
|
||||
# roundtrip serialization/deserialization won't fail.
|
||||
params.structured_outputs.__post_init__()
|
||||
|
||||
def _extract_singleton_mm_data(
    self, prompt: SingletonPrompt
) -> MultiModalDataDict | None:
    """
    Return the `multi_modal_data` entry of a single prompt.

    A plain string prompt has no multi-modal data, so `None` is returned
    for it; otherwise the result of looking up the `"multi_modal_data"` key
    is returned (which is `None` when the key is absent).
    """
    return (
        None
        if isinstance(prompt, str)
        else prompt.get("multi_modal_data")  # type: ignore[return-value]
    )
|
||||
|
||||
def _extract_mm_data(self, prompt: PromptType) -> MultiModalDataDict | None:
    """
    Return the multi-modal data attached to `prompt`, if any.

    For an explicit encoder/decoder prompt only the encoder prompt is
    consulted; any other prompt is treated as a single prompt.
    """
    if is_explicit_encoder_decoder_prompt(prompt):
        singleton = prompt["encoder_prompt"]
    else:
        singleton = prompt

    return self._extract_singleton_mm_data(singleton)
|
||||
|
||||
def _maybe_build_mm_uuids(
|
||||
self,
|
||||
request_id: str,
|
||||
@@ -391,31 +432,18 @@ class InputProcessor:
|
||||
Returns a dictionary of modality -> list[str] of overrides, or None if
|
||||
disabled or no multimodal data is present.
|
||||
"""
|
||||
|
||||
def _extract_mm_data(p: PromptType):
|
||||
if isinstance(p, dict) and "encoder_prompt" in p:
|
||||
enc = p.get("encoder_prompt")
|
||||
if isinstance(enc, dict):
|
||||
return enc.get("multi_modal_data")
|
||||
return None
|
||||
if isinstance(p, dict):
|
||||
return p.get("multi_modal_data")
|
||||
return None
|
||||
|
||||
mm_data = _extract_mm_data(prompt)
|
||||
mm_data = self._extract_mm_data(prompt)
|
||||
if not mm_data:
|
||||
return None
|
||||
|
||||
mm_uuids: dict[str, list[str | None] | str] = {}
|
||||
for modality, data in mm_data.items():
|
||||
# Hash each item for embedding inputs.
|
||||
n = (
|
||||
len(data)
|
||||
if isinstance(data, list) or MultiModalDataParser.is_embeddings(data)
|
||||
else 1
|
||||
)
|
||||
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
|
||||
return mm_uuids
|
||||
mm_items = self._parse_mm_items(
|
||||
{k: v for k, v in mm_data.items() if v is not None}
|
||||
)
|
||||
|
||||
return {
|
||||
modality: [f"{request_id}-{modality}-{i}" for i in range(data_count)]
|
||||
for modality, data_count in mm_items.get_all_counts().items()
|
||||
}
|
||||
|
||||
def _get_mm_identifier(
|
||||
self,
|
||||
@@ -494,7 +522,7 @@ class InputProcessor:
|
||||
else:
|
||||
# Otherwise, use user-provided uuids as multimodal hash overrides
|
||||
# if provided.
|
||||
self._validate_multi_modal_uuids(prompt)
|
||||
self._validate_mm_uuids(prompt)
|
||||
if isinstance(prompt, dict):
|
||||
mm_uuids = cast(
|
||||
MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")
|
||||
|
||||
Reference in New Issue
Block a user