[Bugfix] Dictionary MM embeddings for online chat (#30507)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-12-13 15:48:56 +08:00
committed by GitHub
parent fdc135d768
commit b09806e28f
3 changed files with 193 additions and 44 deletions

View File

@@ -9,7 +9,7 @@ from collections import Counter, defaultdict, deque
from collections.abc import Awaitable, Callable, Iterable
from functools import cached_property, lru_cache, partial
from pathlib import Path
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
import jinja2
import jinja2.ext
@@ -53,7 +53,14 @@ from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import random_uuid
from vllm.utils.collection_utils import is_list_of
from vllm.utils.func_utils import supports_kw
from vllm.utils.import_utils import LazyLoader
if TYPE_CHECKING:
import torch
else:
torch = LazyLoader("torch", globals(), "torch")
logger = init_logger(__name__)
@@ -620,6 +627,44 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
_T = TypeVar("_T")
def _extract_embeds(tensors: list[torch.Tensor]):
if len(tensors) == 0:
return tensors
if len(tensors) == 1:
tensors[0]._is_single_item = True # type: ignore
return tensors[0] # To keep backwards compatibility for single item input
first_shape = tensors[0].shape
if all(t.shape == first_shape for t in tensors):
return torch.stack(tensors)
return tensors
def _get_embeds_data(items_by_modality: dict[str, list[Any]], modality: str):
embeds_key = f"{modality}_embeds"
embeds = items_by_modality[embeds_key]
if len(embeds) == 0:
return embeds
if is_list_of(embeds, torch.Tensor):
return _extract_embeds(embeds)
if is_list_of(embeds, dict):
if not embeds:
return {}
first_keys = set(embeds[0].keys())
if any(set(item.keys()) != first_keys for item in embeds[1:]):
raise ValueError(
"All dictionaries in the list of embeddings must have the same keys."
)
return {k: _extract_embeds([item[k] for item in embeds]) for k in first_keys}
return embeds
class BaseMultiModalItemTracker(ABC, Generic[_T]):
"""
Tracks multi-modal items in a given request and ensures that the number
@@ -688,11 +733,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def all_mm_uuids(self) -> MultiModalUUIDDict | None:
if not self._items_by_modality:
return None
mm_uuids = {}
uuids_by_modality = dict(self._uuids_by_modality)
if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed")
if "audio" in uuids_by_modality and "audio_embeds" in uuids_by_modality:
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
mm_uuids = {}
if "image_embeds" in uuids_by_modality:
mm_uuids["image"] = uuids_by_modality["image_embeds"]
if "image" in uuids_by_modality:
@@ -703,6 +751,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios
if "video" in uuids_by_modality:
mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos
return mm_uuids
@abstractmethod
@@ -714,29 +763,25 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> MultiModalDataDict | None:
if not self._items_by_modality:
return None
mm_inputs = {}
items_by_modality = dict(self._items_by_modality)
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed")
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
mm_inputs = {}
if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
mm_inputs["image"] = (
image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
)
mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio_embeds" in items_by_modality:
audio_embeds_lst = items_by_modality["audio_embeds"]
mm_inputs["audio"] = (
audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
)
mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs
def create_parser(self) -> "BaseMultiModalContentParser":
@@ -747,38 +792,32 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
async def all_mm_data(self) -> MultiModalDataDict | None:
if not self._items_by_modality:
return None
mm_inputs = {}
items_by_modality = {}
for modality, items in self._items_by_modality.items():
coros = []
for item in items:
if item is not None:
coros.append(item)
else:
coros.append(asyncio.sleep(0))
items_by_modality[modality] = await asyncio.gather(*coros)
coros_by_modality = {
modality: [item or asyncio.sleep(0) for item in items]
for modality, items in self._items_by_modality.items()
}
items_by_modality: dict[str, list[object | None]] = {
modality: await asyncio.gather(*coros)
for modality, coros in coros_by_modality.items()
}
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed")
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
mm_inputs = {}
if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
mm_inputs["image"] = (
image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
)
mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio_embeds" in items_by_modality:
audio_embeds_lst = items_by_modality["audio_embeds"]
mm_inputs["audio"] = (
audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
)
mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs
def create_parser(self) -> "BaseMultiModalContentParser":