[Refactor] Use data parser for matching data items to multi-modal UUIDs (#32955)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -8,6 +8,7 @@ from abc import ABC, abstractmethod
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from functools import cached_property, lru_cache, partial
|
||||
from itertools import accumulate
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
|
||||
|
||||
@@ -41,6 +42,12 @@ from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import SupportsMultiModal
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalBatchedField,
|
||||
MultiModalFlatField,
|
||||
MultiModalSharedField,
|
||||
)
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
@@ -48,7 +55,9 @@ from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
import transformers
|
||||
else:
|
||||
transformers = LazyLoader("transformers", globals(), "transformers")
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -331,42 +340,113 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def _extract_embeds(tensors: list[torch.Tensor]):
|
||||
if len(tensors) == 0:
|
||||
return tensors
|
||||
# Backward compatibility for single item input
|
||||
class _BatchedSingleItemField(MultiModalSharedField):
|
||||
pass
|
||||
|
||||
if len(tensors) == 1:
|
||||
tensors[0]._is_single_item = True # type: ignore
|
||||
return tensors[0] # To keep backwards compatibility for single item input
|
||||
|
||||
first_shape = tensors[0].shape
|
||||
def _detect_field(
|
||||
tensors: list[torch.Tensor],
|
||||
mm_processor: BaseMultiModalProcessor,
|
||||
):
|
||||
first_item = tensors[0]
|
||||
hidden_size = mm_processor.info.ctx.model_config.get_inputs_embeds_size()
|
||||
|
||||
if (
|
||||
len(tensors) == 1
|
||||
and first_item.ndim == 3
|
||||
and first_item.shape[0] == 1
|
||||
and first_item.shape[-1] == hidden_size
|
||||
):
|
||||
logger.warning(
|
||||
"Batched multi-modal embedding inputs are deprecated for Chat API. "
|
||||
"Please pass a separate content part for each multi-modal item."
|
||||
)
|
||||
return _BatchedSingleItemField(batch_size=1)
|
||||
|
||||
first_shape = first_item.shape
|
||||
if all(t.shape == first_shape for t in tensors):
|
||||
return torch.stack(tensors)
|
||||
return MultiModalBatchedField()
|
||||
|
||||
return tensors
|
||||
size_per_item = [len(tensor) for tensor in tensors]
|
||||
slice_idxs = [0, *accumulate(size_per_item)]
|
||||
slices = [
|
||||
(slice(slice_idxs[i], slice_idxs[i + 1]),) for i in range(len(size_per_item))
|
||||
]
|
||||
return MultiModalFlatField(slices=slices)
|
||||
|
||||
|
||||
def _get_embeds_data(items_by_modality: dict[str, list[Any]], modality: str):
|
||||
embeds_key = f"{modality}_embeds"
|
||||
embeds = items_by_modality[embeds_key]
|
||||
def _merge_embeds(
|
||||
data_items: list[dict[str, "torch.Tensor"]],
|
||||
mm_processor: BaseMultiModalProcessor,
|
||||
):
|
||||
if not data_items:
|
||||
return {}
|
||||
|
||||
if len(embeds) == 0:
|
||||
return embeds
|
||||
if is_list_of(embeds, torch.Tensor):
|
||||
return _extract_embeds(embeds)
|
||||
if is_list_of(embeds, dict):
|
||||
if not embeds:
|
||||
return {}
|
||||
first_keys = set(data_items[0].keys())
|
||||
if any(set(item.keys()) != first_keys for item in data_items[1:]):
|
||||
raise ValueError(
|
||||
"All dictionaries in the list of embeddings must have the same keys."
|
||||
)
|
||||
|
||||
first_keys = set(embeds[0].keys())
|
||||
if any(set(item.keys()) != first_keys for item in embeds[1:]):
|
||||
raise ValueError(
|
||||
"All dictionaries in the list of embeddings must have the same keys."
|
||||
fields = {
|
||||
key: _detect_field([item[key] for item in data_items], mm_processor)
|
||||
for key in first_keys
|
||||
}
|
||||
data_merged = {
|
||||
key: field._reduce_data([item[key] for item in data_items], pin_memory=False)
|
||||
for key, field in fields.items()
|
||||
}
|
||||
|
||||
try:
|
||||
# TODO: Support per-request mm_processor_kwargs
|
||||
parsed_configs = mm_processor._get_mm_fields_config(
|
||||
transformers.BatchFeature(data_merged),
|
||||
{},
|
||||
)
|
||||
parsed_fields = {key: parsed_configs[key].field for key in first_keys}
|
||||
keys_to_update = [
|
||||
key
|
||||
for key in first_keys
|
||||
if (
|
||||
fields[key] != parsed_fields[key]
|
||||
and not isinstance(fields[key], _BatchedSingleItemField)
|
||||
)
|
||||
]
|
||||
|
||||
return {k: _extract_embeds([item[k] for item in embeds]) for k in first_keys}
|
||||
for key in keys_to_update:
|
||||
data_merged[key] = parsed_fields[key]._reduce_data(
|
||||
[item[key] for item in data_items], pin_memory=False
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error when parsing merged embeddings. "
|
||||
"Falling back to auto-detected fields."
|
||||
)
|
||||
|
||||
return embeds
|
||||
return data_merged
|
||||
|
||||
|
||||
def _get_embeds_data(
|
||||
modality: str,
|
||||
data_items: list[Any],
|
||||
mm_processor: BaseMultiModalProcessor,
|
||||
):
|
||||
if len(data_items) == 0:
|
||||
return data_items
|
||||
|
||||
if all(item is None for item in data_items):
|
||||
return data_items
|
||||
|
||||
if is_list_of(data_items, torch.Tensor):
|
||||
embeds_key = f"{modality}_embeds"
|
||||
dict_items = [{embeds_key: item} for item in data_items]
|
||||
return _merge_embeds(dict_items, mm_processor)[embeds_key]
|
||||
|
||||
if is_list_of(data_items, dict):
|
||||
return _merge_embeds(data_items, mm_processor)
|
||||
|
||||
raise NotImplementedError(type(data_items))
|
||||
|
||||
|
||||
class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
@@ -381,8 +461,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
|
||||
self._model_config = model_config
|
||||
|
||||
self._items_by_modality = defaultdict[str, list[_T | None]](list)
|
||||
self._uuids_by_modality = defaultdict[str, list[str | None]](list)
|
||||
self._items_by_modality = defaultdict[str, list[_T]](list)
|
||||
|
||||
@property
|
||||
def model_config(self) -> ModelConfig:
|
||||
@@ -411,12 +490,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
def mm_processor(self):
|
||||
return self.mm_registry.create_processor(self.model_config)
|
||||
|
||||
def add(
|
||||
self,
|
||||
modality: ModalityStr,
|
||||
item: _T | None,
|
||||
uuid: str | None = None,
|
||||
) -> str | None:
|
||||
def add(self, modality: ModalityStr, item: _T) -> str | None:
|
||||
"""
|
||||
Add a multi-modal item to the current prompt and returns the
|
||||
placeholder string to use, if any.
|
||||
@@ -430,99 +504,80 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
self.mm_processor.validate_num_items(input_modality, num_items)
|
||||
|
||||
self._items_by_modality[modality].append(item)
|
||||
self._uuids_by_modality[modality].append(uuid)
|
||||
|
||||
return self.model_cls.get_placeholder_str(modality, num_items)
|
||||
|
||||
def all_mm_uuids(self) -> MultiModalUUIDDict | None:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
|
||||
uuids_by_modality = dict(self._uuids_by_modality)
|
||||
if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
if "audio" in uuids_by_modality and "audio_embeds" in uuids_by_modality:
|
||||
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
|
||||
|
||||
mm_uuids = {}
|
||||
if "image_embeds" in uuids_by_modality:
|
||||
mm_uuids["image"] = uuids_by_modality["image_embeds"]
|
||||
if "image" in uuids_by_modality:
|
||||
mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images
|
||||
if "audio_embeds" in uuids_by_modality:
|
||||
mm_uuids["audio"] = uuids_by_modality["audio_embeds"]
|
||||
if "audio" in uuids_by_modality:
|
||||
mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios
|
||||
if "video" in uuids_by_modality:
|
||||
mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos
|
||||
|
||||
return mm_uuids
|
||||
|
||||
@abstractmethod
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
|
||||
def all_mm_data(self) -> MultiModalDataDict | None:
|
||||
def _resolve_items(
|
||||
items_by_modality: dict[str, list[tuple[object, str | None]]],
|
||||
mm_processor: BaseMultiModalProcessor,
|
||||
) -> tuple[MultiModalDataDict, MultiModalUUIDDict]:
|
||||
if "image" in items_by_modality and "image_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
|
||||
|
||||
mm_data = {}
|
||||
mm_uuids = {}
|
||||
if "image_embeds" in items_by_modality:
|
||||
mm_data["image"] = _get_embeds_data(
|
||||
"image",
|
||||
[data for data, uuid in items_by_modality["image_embeds"]],
|
||||
mm_processor,
|
||||
)
|
||||
mm_uuids["image"] = [uuid for data, uuid in items_by_modality["image_embeds"]]
|
||||
if "image" in items_by_modality:
|
||||
mm_data["image"] = [data for data, uuid in items_by_modality["image"]]
|
||||
mm_uuids["image"] = [uuid for data, uuid in items_by_modality["image"]]
|
||||
if "audio_embeds" in items_by_modality:
|
||||
mm_data["audio"] = _get_embeds_data(
|
||||
"audio",
|
||||
[data for data, uuid in items_by_modality["audio_embeds"]],
|
||||
mm_processor,
|
||||
)
|
||||
mm_uuids["audio"] = [uuid for data, uuid in items_by_modality["audio_embeds"]]
|
||||
if "audio" in items_by_modality:
|
||||
mm_data["audio"] = [data for data, uuid in items_by_modality["audio"]]
|
||||
mm_uuids["audio"] = [uuid for data, uuid in items_by_modality["audio"]]
|
||||
if "video" in items_by_modality:
|
||||
mm_data["video"] = [data for data, uuid in items_by_modality["video"]]
|
||||
mm_uuids["video"] = [uuid for data, uuid in items_by_modality["video"]]
|
||||
|
||||
return mm_data, mm_uuids
|
||||
|
||||
|
||||
class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]):
|
||||
def resolve_items(
|
||||
self,
|
||||
) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
return None, None
|
||||
|
||||
items_by_modality = dict(self._items_by_modality)
|
||||
if "image" in items_by_modality and "image_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
|
||||
|
||||
mm_inputs = {}
|
||||
if "image_embeds" in items_by_modality:
|
||||
mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
|
||||
if "image" in items_by_modality:
|
||||
mm_inputs["image"] = items_by_modality["image"] # A list of images
|
||||
if "audio_embeds" in items_by_modality:
|
||||
mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
|
||||
if "audio" in items_by_modality:
|
||||
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
|
||||
if "video" in items_by_modality:
|
||||
mm_inputs["video"] = items_by_modality["video"] # A list of videos
|
||||
|
||||
return mm_inputs
|
||||
return _resolve_items(dict(self._items_by_modality), self.mm_processor)
|
||||
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
return MultiModalContentParser(self)
|
||||
|
||||
|
||||
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
|
||||
async def all_mm_data(self) -> MultiModalDataDict | None:
|
||||
class AsyncMultiModalItemTracker(
|
||||
BaseMultiModalItemTracker[Awaitable[tuple[object, str | None]]]
|
||||
):
|
||||
async def resolve_items(
|
||||
self,
|
||||
) -> tuple[MultiModalDataDict | None, MultiModalUUIDDict | None]:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
return None, None
|
||||
|
||||
coros_by_modality = {
|
||||
modality: [item or asyncio.sleep(0) for item in items]
|
||||
for modality, items in self._items_by_modality.items()
|
||||
}
|
||||
items_by_modality: dict[str, list[object | None]] = {
|
||||
resolved_items_by_modality = {
|
||||
modality: await asyncio.gather(*coros)
|
||||
for modality, coros in coros_by_modality.items()
|
||||
for modality, coros in self._items_by_modality.items()
|
||||
}
|
||||
if "image" in items_by_modality and "image_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
|
||||
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
|
||||
|
||||
mm_inputs = {}
|
||||
if "image_embeds" in items_by_modality:
|
||||
mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
|
||||
if "image" in items_by_modality:
|
||||
mm_inputs["image"] = items_by_modality["image"] # A list of images
|
||||
if "audio_embeds" in items_by_modality:
|
||||
mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
|
||||
if "audio" in items_by_modality:
|
||||
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
|
||||
if "video" in items_by_modality:
|
||||
mm_inputs["video"] = items_by_modality["video"] # A list of videos
|
||||
|
||||
return mm_inputs
|
||||
return _resolve_items(resolved_items_by_modality, self.mm_processor)
|
||||
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
return AsyncMultiModalContentParser(self)
|
||||
@@ -611,7 +666,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
|
||||
image = self._connector.fetch_image(image_url) if image_url else None
|
||||
|
||||
placeholder = self._tracker.add("image", image, uuid)
|
||||
placeholder = self._tracker.add("image", (image, uuid))
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_embeds(
|
||||
@@ -630,14 +685,14 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
k: self._connector.fetch_image_embedding(v)
|
||||
for k, v in image_embeds.items()
|
||||
}
|
||||
placeholder = self._tracker.add("image_embeds", embeds, uuid)
|
||||
placeholder = self._tracker.add("image_embeds", (embeds, uuid))
|
||||
|
||||
if isinstance(image_embeds, str):
|
||||
embedding = self._connector.fetch_image_embedding(image_embeds)
|
||||
placeholder = self._tracker.add("image_embeds", embedding, uuid)
|
||||
placeholder = self._tracker.add("image_embeds", (embedding, uuid))
|
||||
|
||||
if image_embeds is None:
|
||||
placeholder = self._tracker.add("image_embeds", None, uuid)
|
||||
placeholder = self._tracker.add("image_embeds", (None, uuid))
|
||||
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
@@ -657,25 +712,25 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
k: self._connector.fetch_audio_embedding(v)
|
||||
for k, v in audio_embeds.items()
|
||||
}
|
||||
placeholder = self._tracker.add("audio_embeds", embeds, uuid)
|
||||
placeholder = self._tracker.add("audio_embeds", (embeds, uuid))
|
||||
elif isinstance(audio_embeds, str):
|
||||
embedding = self._connector.fetch_audio_embedding(audio_embeds)
|
||||
placeholder = self._tracker.add("audio_embeds", embedding, uuid)
|
||||
placeholder = self._tracker.add("audio_embeds", (embedding, uuid))
|
||||
else:
|
||||
placeholder = self._tracker.add("audio_embeds", None, uuid)
|
||||
placeholder = self._tracker.add("audio_embeds", (None, uuid))
|
||||
|
||||
self._add_placeholder("audio", placeholder)
|
||||
|
||||
def parse_image_pil(
|
||||
self, image_pil: Image.Image | None, uuid: str | None = None
|
||||
) -> None:
|
||||
placeholder = self._tracker.add("image", image_pil, uuid)
|
||||
placeholder = self._tracker.add("image", (image_pil, uuid))
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
|
||||
audio = self._connector.fetch_audio(audio_url) if audio_url else None
|
||||
|
||||
placeholder = self._tracker.add("audio", audio, uuid)
|
||||
placeholder = self._tracker.add("audio", (audio, uuid))
|
||||
self._add_placeholder("audio", placeholder)
|
||||
|
||||
def parse_input_audio(
|
||||
@@ -697,7 +752,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
|
||||
video = self._connector.fetch_video(video_url=video_url) if video_url else None
|
||||
|
||||
placeholder = self._tracker.add("video", video, uuid)
|
||||
placeholder = self._tracker.add("video", (video, uuid))
|
||||
self._add_placeholder("video", placeholder)
|
||||
|
||||
|
||||
@@ -719,10 +774,16 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
def model_config(self) -> ModelConfig:
|
||||
return self._tracker.model_config
|
||||
|
||||
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
|
||||
image_coro = self._connector.fetch_image_async(image_url) if image_url else None
|
||||
async def _image_with_uuid_async(self, image_url: str | None, uuid: str | None):
|
||||
image = (
|
||||
await self._connector.fetch_image_async(image_url) if image_url else None
|
||||
)
|
||||
return image, uuid
|
||||
|
||||
placeholder = self._tracker.add("image", image_coro, uuid)
|
||||
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
|
||||
coro = self._image_with_uuid_async(image_url, uuid)
|
||||
|
||||
placeholder = self._tracker.add("image", coro)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_embeds(
|
||||
@@ -736,23 +797,25 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
"You must set `--enable-mm-embeds` to input `image_embeds`"
|
||||
)
|
||||
|
||||
future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()
|
||||
future = asyncio.Future[
|
||||
tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
|
||||
]()
|
||||
|
||||
if isinstance(image_embeds, dict):
|
||||
embeds = {
|
||||
k: self._connector.fetch_image_embedding(v)
|
||||
for k, v in image_embeds.items()
|
||||
}
|
||||
future.set_result(embeds)
|
||||
future.set_result((embeds, uuid))
|
||||
|
||||
if isinstance(image_embeds, str):
|
||||
embedding = self._connector.fetch_image_embedding(image_embeds)
|
||||
future.set_result(embedding)
|
||||
future.set_result((embedding, uuid))
|
||||
|
||||
if image_embeds is None:
|
||||
future.set_result(None)
|
||||
future.set_result((None, uuid))
|
||||
|
||||
placeholder = self._tracker.add("image_embeds", future, uuid)
|
||||
placeholder = self._tracker.add("image_embeds", future)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_audio_embeds(
|
||||
@@ -766,72 +829,51 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
"You must set `--enable-mm-embeds` to input `audio_embeds`"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"🎵 Parsing audio_embeds: type=%s, uuid=%s, is_dict=%s, "
|
||||
"is_str=%s, is_none=%s",
|
||||
type(audio_embeds).__name__,
|
||||
uuid,
|
||||
isinstance(audio_embeds, dict),
|
||||
isinstance(audio_embeds, str),
|
||||
audio_embeds is None,
|
||||
)
|
||||
|
||||
future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()
|
||||
future = asyncio.Future[
|
||||
tuple[torch.Tensor | dict[str, torch.Tensor] | None, str | None]
|
||||
]()
|
||||
|
||||
if isinstance(audio_embeds, dict):
|
||||
logger.info(
|
||||
"🎵 Processing dict audio_embeds with %d entries",
|
||||
len(audio_embeds),
|
||||
)
|
||||
embeds = {
|
||||
k: self._connector.fetch_audio_embedding(v)
|
||||
for k, v in audio_embeds.items()
|
||||
}
|
||||
future.set_result(embeds)
|
||||
logger.info(
|
||||
"🎵 Successfully loaded %d audio embeddings from dict",
|
||||
len(embeds),
|
||||
)
|
||||
future.set_result((embeds, uuid))
|
||||
|
||||
if isinstance(audio_embeds, str):
|
||||
base64_size = len(audio_embeds)
|
||||
logger.info(
|
||||
"🎵 Processing base64 audio_embeds: %d chars (%.2f KB)",
|
||||
base64_size,
|
||||
base64_size / 1024,
|
||||
)
|
||||
embedding = self._connector.fetch_audio_embedding(audio_embeds)
|
||||
future.set_result(embedding)
|
||||
logger.info(
|
||||
"🎵 Successfully loaded audio embedding tensor: shape=%s, dtype=%s",
|
||||
embedding.shape,
|
||||
embedding.dtype,
|
||||
)
|
||||
future.set_result((embedding, uuid))
|
||||
|
||||
if audio_embeds is None:
|
||||
logger.info("🎵 Audio embeds is None (UUID-only reference)")
|
||||
future.set_result(None)
|
||||
future.set_result((None, uuid))
|
||||
|
||||
placeholder = self._tracker.add("audio_embeds", future, uuid)
|
||||
placeholder = self._tracker.add("audio_embeds", future)
|
||||
self._add_placeholder("audio", placeholder)
|
||||
logger.info("🎵 Added audio_embeds placeholder with uuid=%s", uuid)
|
||||
|
||||
def parse_image_pil(
|
||||
self, image_pil: Image.Image | None, uuid: str | None = None
|
||||
self,
|
||||
image_pil: Image.Image | None,
|
||||
uuid: str | None = None,
|
||||
) -> None:
|
||||
future: asyncio.Future[Image.Image | None] = asyncio.Future()
|
||||
future = asyncio.Future[tuple[Image.Image | None, str | None]]()
|
||||
if image_pil:
|
||||
future.set_result(image_pil)
|
||||
future.set_result((image_pil, uuid))
|
||||
else:
|
||||
future.set_result(None)
|
||||
future.set_result((None, uuid))
|
||||
|
||||
placeholder = self._tracker.add("image", future, uuid)
|
||||
placeholder = self._tracker.add("image", future)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
|
||||
audio_coro = self._connector.fetch_audio_async(audio_url) if audio_url else None
|
||||
async def _audio_with_uuid_async(self, audio_url: str | None, uuid: str | None):
|
||||
audio = (
|
||||
await self._connector.fetch_audio_async(audio_url) if audio_url else None
|
||||
)
|
||||
return audio, uuid
|
||||
|
||||
placeholder = self._tracker.add("audio", audio_coro, uuid)
|
||||
def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
|
||||
coro = self._audio_with_uuid_async(audio_url, uuid)
|
||||
|
||||
placeholder = self._tracker.add("audio", coro)
|
||||
self._add_placeholder("audio", placeholder)
|
||||
|
||||
def parse_input_audio(
|
||||
@@ -850,14 +892,16 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
|
||||
return self.parse_audio(audio_url, uuid)
|
||||
|
||||
def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
|
||||
async def _video_with_uuid_async(self, video_url: str | None, uuid: str | None):
|
||||
video = (
|
||||
self._connector.fetch_video_async(video_url=video_url)
|
||||
if video_url
|
||||
else None
|
||||
await self._connector.fetch_video_async(video_url) if video_url else None
|
||||
)
|
||||
return video, uuid
|
||||
|
||||
placeholder = self._tracker.add("video", video, uuid)
|
||||
def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
|
||||
coro = self._video_with_uuid_async(video_url, uuid)
|
||||
|
||||
placeholder = self._tracker.add("video", coro)
|
||||
self._add_placeholder("video", placeholder)
|
||||
|
||||
|
||||
@@ -1380,7 +1424,9 @@ def parse_chat_messages(
|
||||
|
||||
_postprocess_messages(conversation)
|
||||
|
||||
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||
mm_data, mm_uuids = mm_tracker.resolve_items()
|
||||
|
||||
return conversation, mm_data, mm_uuids
|
||||
|
||||
|
||||
async def parse_chat_messages_async(
|
||||
@@ -1411,7 +1457,9 @@ async def parse_chat_messages_async(
|
||||
|
||||
_postprocess_messages(conversation)
|
||||
|
||||
return conversation, await mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||
mm_data, mm_uuids = await mm_tracker.resolve_items()
|
||||
|
||||
return conversation, mm_data, mm_uuids
|
||||
|
||||
|
||||
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
|
||||
|
||||
Reference in New Issue
Block a user