[Bugfix] Dictionary MM embeddings for online chat (#30507)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -9,7 +9,7 @@ from collections import Counter, defaultdict, deque
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from functools import cached_property, lru_cache, partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
|
||||
|
||||
import jinja2
|
||||
import jinja2.ext
|
||||
@@ -53,7 +53,14 @@ from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -620,6 +627,44 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _extract_embeds(tensors: list[torch.Tensor]):
|
||||
if len(tensors) == 0:
|
||||
return tensors
|
||||
|
||||
if len(tensors) == 1:
|
||||
tensors[0]._is_single_item = True # type: ignore
|
||||
return tensors[0] # To keep backwards compatibility for single item input
|
||||
|
||||
first_shape = tensors[0].shape
|
||||
if all(t.shape == first_shape for t in tensors):
|
||||
return torch.stack(tensors)
|
||||
|
||||
return tensors
|
||||
|
||||
|
||||
def _get_embeds_data(items_by_modality: dict[str, list[Any]], modality: str):
|
||||
embeds_key = f"{modality}_embeds"
|
||||
embeds = items_by_modality[embeds_key]
|
||||
|
||||
if len(embeds) == 0:
|
||||
return embeds
|
||||
if is_list_of(embeds, torch.Tensor):
|
||||
return _extract_embeds(embeds)
|
||||
if is_list_of(embeds, dict):
|
||||
if not embeds:
|
||||
return {}
|
||||
|
||||
first_keys = set(embeds[0].keys())
|
||||
if any(set(item.keys()) != first_keys for item in embeds[1:]):
|
||||
raise ValueError(
|
||||
"All dictionaries in the list of embeddings must have the same keys."
|
||||
)
|
||||
|
||||
return {k: _extract_embeds([item[k] for item in embeds]) for k in first_keys}
|
||||
|
||||
return embeds
|
||||
|
||||
|
||||
class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
"""
|
||||
Tracks multi-modal items in a given request and ensures that the number
|
||||
@@ -688,11 +733,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
def all_mm_uuids(self) -> MultiModalUUIDDict | None:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
mm_uuids = {}
|
||||
|
||||
uuids_by_modality = dict(self._uuids_by_modality)
|
||||
if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
if "audio" in uuids_by_modality and "audio_embeds" in uuids_by_modality:
|
||||
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
|
||||
|
||||
mm_uuids = {}
|
||||
if "image_embeds" in uuids_by_modality:
|
||||
mm_uuids["image"] = uuids_by_modality["image_embeds"]
|
||||
if "image" in uuids_by_modality:
|
||||
@@ -703,6 +751,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios
|
||||
if "video" in uuids_by_modality:
|
||||
mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos
|
||||
|
||||
return mm_uuids
|
||||
|
||||
@abstractmethod
|
||||
@@ -714,29 +763,25 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
|
||||
def all_mm_data(self) -> MultiModalDataDict | None:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
mm_inputs = {}
|
||||
|
||||
items_by_modality = dict(self._items_by_modality)
|
||||
if "image" in items_by_modality and "image_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
|
||||
|
||||
mm_inputs = {}
|
||||
if "image_embeds" in items_by_modality:
|
||||
image_embeds_lst = items_by_modality["image_embeds"]
|
||||
mm_inputs["image"] = (
|
||||
image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
|
||||
)
|
||||
mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
|
||||
if "image" in items_by_modality:
|
||||
mm_inputs["image"] = items_by_modality["image"] # A list of images
|
||||
if "audio_embeds" in items_by_modality:
|
||||
audio_embeds_lst = items_by_modality["audio_embeds"]
|
||||
mm_inputs["audio"] = (
|
||||
audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
|
||||
)
|
||||
mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
|
||||
if "audio" in items_by_modality:
|
||||
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
|
||||
if "video" in items_by_modality:
|
||||
mm_inputs["video"] = items_by_modality["video"] # A list of videos
|
||||
|
||||
return mm_inputs
|
||||
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
@@ -747,38 +792,32 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
|
||||
async def all_mm_data(self) -> MultiModalDataDict | None:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
mm_inputs = {}
|
||||
items_by_modality = {}
|
||||
for modality, items in self._items_by_modality.items():
|
||||
coros = []
|
||||
for item in items:
|
||||
if item is not None:
|
||||
coros.append(item)
|
||||
else:
|
||||
coros.append(asyncio.sleep(0))
|
||||
items_by_modality[modality] = await asyncio.gather(*coros)
|
||||
|
||||
coros_by_modality = {
|
||||
modality: [item or asyncio.sleep(0) for item in items]
|
||||
for modality, items in self._items_by_modality.items()
|
||||
}
|
||||
items_by_modality: dict[str, list[object | None]] = {
|
||||
modality: await asyncio.gather(*coros)
|
||||
for modality, coros in coros_by_modality.items()
|
||||
}
|
||||
if "image" in items_by_modality and "image_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
|
||||
|
||||
mm_inputs = {}
|
||||
if "image_embeds" in items_by_modality:
|
||||
image_embeds_lst = items_by_modality["image_embeds"]
|
||||
mm_inputs["image"] = (
|
||||
image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
|
||||
)
|
||||
mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
|
||||
if "image" in items_by_modality:
|
||||
mm_inputs["image"] = items_by_modality["image"] # A list of images
|
||||
if "audio_embeds" in items_by_modality:
|
||||
audio_embeds_lst = items_by_modality["audio_embeds"]
|
||||
mm_inputs["audio"] = (
|
||||
audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
|
||||
)
|
||||
mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
|
||||
if "audio" in items_by_modality:
|
||||
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
|
||||
if "video" in items_by_modality:
|
||||
mm_inputs["video"] = items_by_modality["video"] # A list of videos
|
||||
|
||||
return mm_inputs
|
||||
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
|
||||
Reference in New Issue
Block a user