[Refactor] Use data parser for matching data items to multi-modal UUIDs (#32955)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-01-26 15:00:28 +08:00
Committed by: GitHub
Parent: ee484b3f4b
Commit: 11b556878b
14 changed files with 701 additions and 604 deletions
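
Note: the core of this refactor is that each tracked multi-modal item now carries its UUID with it as an (item, uuid) tuple, instead of being mirrored in a separate parallel list, and embedding inputs are merged through the processor's data parser. A minimal sketch of the bookkeeping idea (all names below are illustrative, not taken from the diff):

    class _SketchTracker:
        def __init__(self) -> None:
            # One list per modality; data and UUID travel together.
            self._items: list[tuple[object, str | None]] = []

        def add(self, item: object, uuid: str | None = None) -> None:
            self._items.append((item, uuid))

        def resolve_items(self) -> tuple[list[object], list[str | None]]:
            # Unzipping at the end keeps the two lists aligned by construction.
            data = [item for item, _ in self._items]
            uuids = [uuid for _, uuid in self._items]
            return data, uuids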


@@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
 from collections import Counter, defaultdict
 from collections.abc import Awaitable, Callable, Iterable
 from functools import cached_property, lru_cache, partial
+from itertools import accumulate
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
@@ -41,6 +42,12 @@ from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.model_executor.models import SupportsMultiModal
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
+from vllm.multimodal.inputs import (
+    MultiModalBatchedField,
+    MultiModalFlatField,
+    MultiModalSharedField,
+)
+from vllm.multimodal.processing import BaseMultiModalProcessor
 from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
 from vllm.utils import random_uuid
 from vllm.utils.collection_utils import is_list_of
@@ -48,7 +55,9 @@ from vllm.utils.import_utils import LazyLoader
 if TYPE_CHECKING:
     import torch
+    import transformers
 else:
+    transformers = LazyLoader("transformers", globals(), "transformers")
     torch = LazyLoader("torch", globals(), "torch")

 logger = init_logger(__name__)
@@ -331,42 +340,113 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
 _T = TypeVar("_T")


-def _extract_embeds(tensors: list[torch.Tensor]):
-    if len(tensors) == 0:
-        return tensors
-
-    if len(tensors) == 1:
-        tensors[0]._is_single_item = True  # type: ignore
-        return tensors[0]  # To keep backwards compatibility for single item input
-
-    first_shape = tensors[0].shape
-    if all(t.shape == first_shape for t in tensors):
-        return torch.stack(tensors)
-
-    return tensors
-
-
-def _get_embeds_data(items_by_modality: dict[str, list[Any]], modality: str):
-    embeds_key = f"{modality}_embeds"
-    embeds = items_by_modality[embeds_key]
-
-    if len(embeds) == 0:
-        return embeds
-
-    if is_list_of(embeds, torch.Tensor):
-        return _extract_embeds(embeds)
-
-    if is_list_of(embeds, dict):
-        if not embeds:
-            return {}
-
-        first_keys = set(embeds[0].keys())
-        if any(set(item.keys()) != first_keys for item in embeds[1:]):
-            raise ValueError(
-                "All dictionaries in the list of embeddings must have the same keys."
-            )
-
-        return {k: _extract_embeds([item[k] for item in embeds]) for k in first_keys}
-
-    return embeds
+# Backward compatibility for single item input
+class _BatchedSingleItemField(MultiModalSharedField):
+    pass
+
+
+def _detect_field(
+    tensors: list[torch.Tensor],
+    mm_processor: BaseMultiModalProcessor,
+):
+    first_item = tensors[0]
+    hidden_size = mm_processor.info.ctx.model_config.get_inputs_embeds_size()
+    if (
+        len(tensors) == 1
+        and first_item.ndim == 3
+        and first_item.shape[0] == 1
+        and first_item.shape[-1] == hidden_size
+    ):
+        logger.warning(
+            "Batched multi-modal embedding inputs are deprecated for Chat API. "
+            "Please pass a separate content part for each multi-modal item."
+        )
+        return _BatchedSingleItemField(batch_size=1)
+
+    first_shape = first_item.shape
+    if all(t.shape == first_shape for t in tensors):
+        return MultiModalBatchedField()
+
+    size_per_item = [len(tensor) for tensor in tensors]
+    slice_idxs = [0, *accumulate(size_per_item)]
+    slices = [
+        (slice(slice_idxs[i], slice_idxs[i + 1]),) for i in range(len(size_per_item))
+    ]
+    return MultiModalFlatField(slices=slices)
+
+
+def _merge_embeds(
+    data_items: list[dict[str, "torch.Tensor"]],
+    mm_processor: BaseMultiModalProcessor,
+):
+    if not data_items:
+        return {}
+
+    first_keys = set(data_items[0].keys())
+    if any(set(item.keys()) != first_keys for item in data_items[1:]):
+        raise ValueError(
+            "All dictionaries in the list of embeddings must have the same keys."
+        )
+
+    fields = {
+        key: _detect_field([item[key] for item in data_items], mm_processor)
+        for key in first_keys
+    }
+    data_merged = {
+        key: field._reduce_data([item[key] for item in data_items], pin_memory=False)
+        for key, field in fields.items()
+    }
+
+    try:
+        # TODO: Support per-request mm_processor_kwargs
+        parsed_configs = mm_processor._get_mm_fields_config(
+            transformers.BatchFeature(data_merged),
+            {},
+        )
+        parsed_fields = {key: parsed_configs[key].field for key in first_keys}
+
+        keys_to_update = [
+            key
+            for key in first_keys
+            if (
+                fields[key] != parsed_fields[key]
+                and not isinstance(fields[key], _BatchedSingleItemField)
+            )
+        ]
+        for key in keys_to_update:
+            data_merged[key] = parsed_fields[key]._reduce_data(
+                [item[key] for item in data_items], pin_memory=False
+            )
+    except Exception:
+        logger.exception(
+            "Error when parsing merged embeddings. "
+            "Falling back to auto-detected fields."
+        )
+
+    return data_merged
+
+
+def _get_embeds_data(
+    modality: str,
+    data_items: list[Any],
+    mm_processor: BaseMultiModalProcessor,
+):
+    if len(data_items) == 0:
+        return data_items
+    if all(item is None for item in data_items):
+        return data_items
+
+    if is_list_of(data_items, torch.Tensor):
+        embeds_key = f"{modality}_embeds"
+        dict_items = [{embeds_key: item} for item in data_items]
+        return _merge_embeds(dict_items, mm_processor)[embeds_key]
+
+    if is_list_of(data_items, dict):
+        return _merge_embeds(data_items, mm_processor)
+
+    raise NotImplementedError(type(data_items))


 class BaseMultiModalItemTracker(ABC, Generic[_T]):
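
In the hunk above, `_detect_field` picks a field type from tensor shapes: a `MultiModalBatchedField` when every item shares the first item's shape, otherwise a `MultiModalFlatField` whose slices are built with `itertools.accumulate` (plus the `_BatchedSingleItemField` escape hatch for the deprecated batched single-item input). A small worked example of the slice arithmetic (values are illustrative):

    from itertools import accumulate

    # Two ragged embeddings of 3 and 5 rows concatenate to 8 rows total.
    size_per_item = [3, 5]
    slice_idxs = [0, *accumulate(size_per_item)]  # [0, 3, 8]
    slices = [
        (slice(slice_idxs[i], slice_idxs[i + 1]),)
        for i in range(len(size_per_item))
    ]
    # slices == [(slice(0, 3),), (slice(3, 8),)], i.e. item i occupies
    # rows slice_idxs[i]:slice_idxs[i + 1] of the flattened tensor.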
@@ -381,8 +461,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         self._model_config = model_config

-        self._items_by_modality = defaultdict[str, list[_T | None]](list)
-        self._uuids_by_modality = defaultdict[str, list[str | None]](list)
+        self._items_by_modality = defaultdict[str, list[_T]](list)

     @property
     def model_config(self) -> ModelConfig:
@@ -411,12 +490,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
     def mm_processor(self):
         return self.mm_registry.create_processor(self.model_config)

-    def add(
-        self,
-        modality: ModalityStr,
-        item: _T | None,
-        uuid: str | None = None,
-    ) -> str | None:
+    def add(self, modality: ModalityStr, item: _T) -> str | None:
         """
         Add a multi-modal item to the current prompt and returns the
         placeholder string to use, if any.
@@ -430,99 +504,80 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         self.mm_processor.validate_num_items(input_modality, num_items)

         self._items_by_modality[modality].append(item)
-        self._uuids_by_modality[modality].append(uuid)

         return self.model_cls.get_placeholder_str(modality, num_items)

-    def all_mm_uuids(self) -> MultiModalUUIDDict | None:
-        if not self._items_by_modality:
-            return None
-
-        uuids_by_modality = dict(self._uuids_by_modality)
-        if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
-            raise ValueError("Mixing raw image and embedding inputs is not allowed")
-        if "audio" in uuids_by_modality and "audio_embeds" in uuids_by_modality:
-            raise ValueError("Mixing raw audio and embedding inputs is not allowed")
-
-        mm_uuids = {}
-        if "image_embeds" in uuids_by_modality:
-            mm_uuids["image"] = uuids_by_modality["image_embeds"]
-        if "image" in uuids_by_modality:
-            mm_uuids["image"] = uuids_by_modality["image"]  # UUIDs of images
-        if "audio_embeds" in uuids_by_modality:
-            mm_uuids["audio"] = uuids_by_modality["audio_embeds"]
-        if "audio" in uuids_by_modality:
-            mm_uuids["audio"] = uuids_by_modality["audio"]  # UUIDs of audios
-        if "video" in uuids_by_modality:
-            mm_uuids["video"] = uuids_by_modality["video"]  # UUIDs of videos
-
-        return mm_uuids
-
     @abstractmethod
     def create_parser(self) -> "BaseMultiModalContentParser":
         raise NotImplementedError


-class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
-    def all_mm_data(self) -> MultiModalDataDict | None:
-        if not self._items_by_modality:
-            return None
-
-        items_by_modality = dict(self._items_by_modality)
-        if "image" in items_by_modality and "image_embeds" in items_by_modality:
-            raise ValueError("Mixing raw image and embedding inputs is not allowed")
-        if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
-            raise ValueError("Mixing raw audio and embedding inputs is not allowed")
-
-        mm_inputs = {}
-        if "image_embeds" in items_by_modality:
-            mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
-        if "image" in items_by_modality:
-            mm_inputs["image"] = items_by_modality["image"]  # A list of images
-        if "audio_embeds" in items_by_modality:
-            mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
-        if "audio" in items_by_modality:
-            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
-        if "video" in items_by_modality:
-            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
-
-        return mm_inputs
+def _resolve_items(
+    items_by_modality: dict[str, list[tuple[object, str | None]]],
+    mm_processor: BaseMultiModalProcessor,
+) -> tuple[MultiModalDataDict, MultiModalUUIDDict]:
+    if "image" in items_by_modality and "image_embeds" in items_by_modality:
+        raise ValueError("Mixing raw image and embedding inputs is not allowed")
+    if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
+        raise ValueError("Mixing raw audio and embedding inputs is not allowed")
+
+    mm_data = {}
+    mm_uuids = {}
+    if "image_embeds" in items_by_modality:
+        mm_data["image"] = _get_embeds_data(
+            "image",
+            [data for data, uuid in items_by_modality["image_embeds"]],
+            mm_processor,
+        )
+        mm_uuids["image"] = [uuid for data, uuid in items_by_modality["image_embeds"]]
+    if "image" in items_by_modality:
+        mm_data["image"] = [data for data, uuid in items_by_modality["image"]]
+        mm_uuids["image"] = [uuid for data, uuid in items_by_modality["image"]]
+    if "audio_embeds" in items_by_modality:
+        mm_data["audio"] = _get_embeds_data(
+            "audio",
+            [data for data, uuid in items_by_modality["audio_embeds"]],
+            mm_processor,
+        )
+        mm_uuids["audio"] = [uuid for data, uuid in items_by_modality["audio_embeds"]]
+    if "audio" in items_by_modality:
+        mm_data["audio"] = [data for data, uuid in items_by_modality["audio"]]
+        mm_uuids["audio"] = [uuid for data, uuid in items_by_modality["audio"]]
+    if "video" in items_by_modality:
+        mm_data["video"] = [data for data, uuid in items_by_modality["video"]]
+        mm_uuids["video"] = [uuid for data, uuid in items_by_modality["video"]]
+
+    return mm_data, mm_uuids
+
+
+class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]):
+    def resolve_items(
+        self,
+    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
+        if not self._items_by_modality:
+            return None, None
+
+        return _resolve_items(dict(self._items_by_modality), self.mm_processor)

     def create_parser(self) -> "BaseMultiModalContentParser":
         return MultiModalContentParser(self)


-class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
-    async def all_mm_data(self) -> MultiModalDataDict | None:
+class AsyncMultiModalItemTracker(
+    BaseMultiModalItemTracker[Awaitable[tuple[object, str | None]]]
+):
+    async def resolve_items(
+        self,
+    ) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
         if not self._items_by_modality:
-            return None
+            return None, None

-        coros_by_modality = {
-            modality: [item or asyncio.sleep(0) for item in items]
-            for modality, items in self._items_by_modality.items()
-        }
-        items_by_modality: dict[str, list[object | None]] = {
+        resolved_items_by_modality = {
             modality: await asyncio.gather(*coros)
-            for modality, coros in coros_by_modality.items()
+            for modality, coros in self._items_by_modality.items()
         }

-        if "image" in items_by_modality and "image_embeds" in items_by_modality:
-            raise ValueError("Mixing raw image and embedding inputs is not allowed")
-        if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
-            raise ValueError("Mixing raw audio and embedding inputs is not allowed")
-
-        mm_inputs = {}
-        if "image_embeds" in items_by_modality:
-            mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
-        if "image" in items_by_modality:
-            mm_inputs["image"] = items_by_modality["image"]  # A list of images
-        if "audio_embeds" in items_by_modality:
-            mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
-        if "audio" in items_by_modality:
-            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
-        if "video" in items_by_modality:
-            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
-
-        return mm_inputs
+        return _resolve_items(resolved_items_by_modality, self.mm_processor)

     def create_parser(self) -> "BaseMultiModalContentParser":
         return AsyncMultiModalContentParser(self)
@@ -611,7 +666,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
     def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
         image = self._connector.fetch_image(image_url) if image_url else None

-        placeholder = self._tracker.add("image", image, uuid)
+        placeholder = self._tracker.add("image", (image, uuid))
         self._add_placeholder("image", placeholder)

     def parse_image_embeds(
@@ -630,14 +685,14 @@ class MultiModalContentParser(BaseMultiModalContentParser):
                 k: self._connector.fetch_image_embedding(v)
                 for k, v in image_embeds.items()
             }
-            placeholder = self._tracker.add("image_embeds", embeds, uuid)
+            placeholder = self._tracker.add("image_embeds", (embeds, uuid))

         if isinstance(image_embeds, str):
             embedding = self._connector.fetch_image_embedding(image_embeds)
-            placeholder = self._tracker.add("image_embeds", embedding, uuid)
+            placeholder = self._tracker.add("image_embeds", (embedding, uuid))

         if image_embeds is None:
-            placeholder = self._tracker.add("image_embeds", None, uuid)
+            placeholder = self._tracker.add("image_embeds", (None, uuid))

         self._add_placeholder("image", placeholder)
@@ -657,25 +712,25 @@ class MultiModalContentParser(BaseMultiModalContentParser):
                 k: self._connector.fetch_audio_embedding(v)
                 for k, v in audio_embeds.items()
             }
-            placeholder = self._tracker.add("audio_embeds", embeds, uuid)
+            placeholder = self._tracker.add("audio_embeds", (embeds, uuid))
         elif isinstance(audio_embeds, str):
             embedding = self._connector.fetch_audio_embedding(audio_embeds)
-            placeholder = self._tracker.add("audio_embeds", embedding, uuid)
+            placeholder = self._tracker.add("audio_embeds", (embedding, uuid))
         else:
-            placeholder = self._tracker.add("audio_embeds", None, uuid)
+            placeholder = self._tracker.add("audio_embeds", (None, uuid))

         self._add_placeholder("audio", placeholder)

     def parse_image_pil(
         self, image_pil: Image.Image | None, uuid: str | None = None
     ) -> None:
-        placeholder = self._tracker.add("image", image_pil, uuid)
+        placeholder = self._tracker.add("image", (image_pil, uuid))
         self._add_placeholder("image", placeholder)

     def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
         audio = self._connector.fetch_audio(audio_url) if audio_url else None

-        placeholder = self._tracker.add("audio", audio, uuid)
+        placeholder = self._tracker.add("audio", (audio, uuid))
         self._add_placeholder("audio", placeholder)

     def parse_input_audio(
@@ -697,7 +752,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
     def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
         video = self._connector.fetch_video(video_url=video_url) if video_url else None

-        placeholder = self._tracker.add("video", video, uuid)
+        placeholder = self._tracker.add("video", (video, uuid))
         self._add_placeholder("video", placeholder)
@@ -719,10 +774,16 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
     def model_config(self) -> ModelConfig:
         return self._tracker.model_config

-    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
-        image_coro = self._connector.fetch_image_async(image_url) if image_url else None
+    async def _image_with_uuid_async(self, image_url: str | None, uuid: str | None):
+        image = (
+            await self._connector.fetch_image_async(image_url) if image_url else None
+        )
+        return image, uuid

-        placeholder = self._tracker.add("image", image_coro, uuid)
+    def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
+        coro = self._image_with_uuid_async(image_url, uuid)
+        placeholder = self._tracker.add("image", coro)
         self._add_placeholder("image", placeholder)

     def parse_image_embeds(
def parse_image_embeds(
@@ -736,23 +797,25 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
"You must set `--enable-mm-embeds` to input `image_embeds`"
)
future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()
future = asyncio.Future[
tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
]()
if isinstance(image_embeds, dict):
embeds = {
k: self._connector.fetch_image_embedding(v)
for k, v in image_embeds.items()
}
future.set_result(embeds)
future.set_result((embeds, uuid))
if isinstance(image_embeds, str):
embedding = self._connector.fetch_image_embedding(image_embeds)
future.set_result(embedding)
future.set_result((embedding, uuid))
if image_embeds is None:
future.set_result(None)
future.set_result((None, uuid))
placeholder = self._tracker.add("image_embeds", future, uuid)
placeholder = self._tracker.add("image_embeds", future)
self._add_placeholder("image", placeholder)
def parse_audio_embeds(
@@ -766,72 +829,51 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
                 "You must set `--enable-mm-embeds` to input `audio_embeds`"
             )

-        logger.info(
-            "🎵 Parsing audio_embeds: type=%s, uuid=%s, is_dict=%s, "
-            "is_str=%s, is_none=%s",
-            type(audio_embeds).__name__,
-            uuid,
-            isinstance(audio_embeds, dict),
-            isinstance(audio_embeds, str),
-            audio_embeds is None,
-        )
-
-        future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()
+        future = asyncio.Future[
+            tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
+        ]()

         if isinstance(audio_embeds, dict):
-            logger.info(
-                "🎵 Processing dict audio_embeds with %d entries",
-                len(audio_embeds),
-            )
             embeds = {
                 k: self._connector.fetch_audio_embedding(v)
                 for k, v in audio_embeds.items()
             }
-            future.set_result(embeds)
-            logger.info(
-                "🎵 Successfully loaded %d audio embeddings from dict",
-                len(embeds),
-            )
+            future.set_result((embeds, uuid))

         if isinstance(audio_embeds, str):
-            base64_size = len(audio_embeds)
-            logger.info(
-                "🎵 Processing base64 audio_embeds: %d chars (%.2f KB)",
-                base64_size,
-                base64_size / 1024,
-            )
             embedding = self._connector.fetch_audio_embedding(audio_embeds)
-            future.set_result(embedding)
-            logger.info(
-                "🎵 Successfully loaded audio embedding tensor: shape=%s, dtype=%s",
-                embedding.shape,
-                embedding.dtype,
-            )
+            future.set_result((embedding, uuid))

         if audio_embeds is None:
-            logger.info("🎵 Audio embeds is None (UUID-only reference)")
-            future.set_result(None)
+            future.set_result((None, uuid))

-        placeholder = self._tracker.add("audio_embeds", future, uuid)
+        placeholder = self._tracker.add("audio_embeds", future)
         self._add_placeholder("audio", placeholder)
-        logger.info("🎵 Added audio_embeds placeholder with uuid=%s", uuid)

     def parse_image_pil(
-        self, image_pil: Image.Image | None, uuid: str | None = None
+        self,
+        image_pil: Image.Image | None,
+        uuid: str | None = None,
     ) -> None:
-        future: asyncio.Future[Image.Image | None] = asyncio.Future()
+        future = asyncio.Future[tuple[Image.Image | None, str | None]]()
         if image_pil:
-            future.set_result(image_pil)
+            future.set_result((image_pil, uuid))
         else:
-            future.set_result(None)
+            future.set_result((None, uuid))

-        placeholder = self._tracker.add("image", future, uuid)
+        placeholder = self._tracker.add("image", future)
         self._add_placeholder("image", placeholder)

-    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
-        audio_coro = self._connector.fetch_audio_async(audio_url) if audio_url else None
+    async def _audio_with_uuid_async(self, audio_url: str | None, uuid: str | None):
+        audio = (
+            await self._connector.fetch_audio_async(audio_url) if audio_url else None
+        )
+        return audio, uuid

-        placeholder = self._tracker.add("audio", audio_coro, uuid)
+    def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
+        coro = self._audio_with_uuid_async(audio_url, uuid)
+        placeholder = self._tracker.add("audio", coro)
         self._add_placeholder("audio", placeholder)

     def parse_input_audio(
@@ -850,14 +892,16 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
         return self.parse_audio(audio_url, uuid)

-    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
+    async def _video_with_uuid_async(self, video_url: str | None, uuid: str | None):
         video = (
-            self._connector.fetch_video_async(video_url=video_url)
-            if video_url
-            else None
+            await self._connector.fetch_video_async(video_url) if video_url else None
         )
+        return video, uuid

-        placeholder = self._tracker.add("video", video, uuid)
+    def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
+        coro = self._video_with_uuid_async(video_url, uuid)
+        placeholder = self._tracker.add("video", coro)
         self._add_placeholder("video", placeholder)
@@ -1380,7 +1424,9 @@ def parse_chat_messages(
     _postprocess_messages(conversation)

-    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
+    mm_data, mm_uuids = mm_tracker.resolve_items()
+
+    return conversation, mm_data, mm_uuids


 async def parse_chat_messages_async(
@@ -1411,7 +1457,9 @@ async def parse_chat_messages_async(
     _postprocess_messages(conversation)

-    return conversation, await mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
+    mm_data, mm_uuids = await mm_tracker.resolve_items()
+
+    return conversation, mm_data, mm_uuids


 def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
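
With both entry points now returning the data and UUID dictionaries as a pair, callers unpack them together. A hedged usage sketch (the argument list is abbreviated and illustrative; see the real signatures in the hunks above):

    conversation, mm_data, mm_uuids = parse_chat_messages(
        messages, model_config, tokenizer, content_format
    )
    # Both dicts are derived from the same (item, uuid) tuples, so per
    # modality mm_uuids["image"][i] is the UUID (or None) of mm_data["image"][i].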