[Misc] Abstract the logic for reading and writing media content (#11527)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -6,7 +6,7 @@ from collections import defaultdict, deque
|
||||
from functools import lru_cache, partial
|
||||
from pathlib import Path
|
||||
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
|
||||
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
|
||||
Literal, Optional, Tuple, TypeVar, Union, cast)
|
||||
|
||||
import jinja2.nodes
|
||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||
@@ -23,6 +23,8 @@ from openai.types.chat import (
|
||||
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
|
||||
from openai.types.chat import (ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionToolMessageParam)
|
||||
from openai.types.chat.chat_completion_content_part_input_audio_param import (
|
||||
InputAudio)
|
||||
# yapf: enable
|
||||
# pydantic needs the TypedDict from typing_extensions
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
@@ -31,11 +33,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.multimodal.utils import (async_get_and_parse_audio,
|
||||
async_get_and_parse_image,
|
||||
async_get_and_parse_video,
|
||||
get_and_parse_audio, get_and_parse_image,
|
||||
get_and_parse_video)
|
||||
from vllm.multimodal.utils import MediaConnector
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
@@ -368,14 +366,17 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
self._tokenizer = tokenizer
|
||||
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
|
||||
if model_config.multimodal_config else {})
|
||||
self._consumed_items = {k: 0 for k in self._allowed_items}
|
||||
|
||||
self._items: List[_T] = []
|
||||
self._items_by_modality = defaultdict[str, list[_T]](list)
|
||||
|
||||
@property
|
||||
def model_config(self) -> ModelConfig:
|
||||
return self._model_config
|
||||
|
||||
@property
|
||||
def allowed_local_media_path(self):
|
||||
return self._model_config.allowed_local_media_path
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=None)
|
||||
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
|
||||
@@ -435,38 +436,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
else:
|
||||
raise TypeError(f"Unknown modality: {modality}")
|
||||
|
||||
@staticmethod
|
||||
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
|
||||
mm_lists: Mapping[str, List[object]] = defaultdict(list)
|
||||
|
||||
# Merge all the multi-modal items
|
||||
for single_mm_data in items:
|
||||
for mm_key, mm_item in single_mm_data.items():
|
||||
if isinstance(mm_item, list):
|
||||
mm_lists[mm_key].extend(mm_item)
|
||||
else:
|
||||
mm_lists[mm_key].append(mm_item)
|
||||
|
||||
# Unpack any single item lists for models that don't expect multiple.
|
||||
return {
|
||||
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
|
||||
for mm_key, mm_list in mm_lists.items()
|
||||
}
|
||||
|
||||
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
|
||||
"""
|
||||
Add a multi-modal item to the current prompt and returns the
|
||||
placeholder string to use, if any.
|
||||
"""
|
||||
allowed_count = self._allowed_items.get(modality, 1)
|
||||
current_count = self._consumed_items.get(modality, 0) + 1
|
||||
current_count = len(self._items_by_modality[modality]) + 1
|
||||
if current_count > allowed_count:
|
||||
raise ValueError(
|
||||
f"At most {allowed_count} {modality}(s) may be provided in "
|
||||
"one request.")
|
||||
|
||||
self._consumed_items[modality] = current_count
|
||||
self._items.append(item)
|
||||
self._items_by_modality[modality].append(item)
|
||||
|
||||
return self._placeholder_str(modality, current_count)
|
||||
|
||||
@@ -475,22 +457,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
|
||||
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
|
||||
|
||||
def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
||||
return self._combine(self._items) if self._items else None
|
||||
if self._items_by_modality:
|
||||
return dict(self._items_by_modality)
|
||||
|
||||
return None
|
||||
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
return MultiModalContentParser(self)
|
||||
|
||||
|
||||
class AsyncMultiModalItemTracker(
|
||||
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
|
||||
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
|
||||
|
||||
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
||||
if self._items:
|
||||
items = await asyncio.gather(*self._items)
|
||||
return self._combine(items)
|
||||
if self._items_by_modality:
|
||||
return {
|
||||
modality: await asyncio.gather(*items)
|
||||
for modality, items in self._items_by_modality.items()
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
@@ -522,7 +508,7 @@ class BaseMultiModalContentParser(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
|
||||
def parse_input_audio(self, input_audio: InputAudio) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@@ -537,31 +523,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
|
||||
self._tracker = tracker
|
||||
|
||||
self._connector = MediaConnector(
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
)
|
||||
|
||||
def parse_image(self, image_url: str) -> None:
|
||||
image = get_and_parse_image(image_url,
|
||||
allowed_local_media_path=self._tracker.
|
||||
_model_config.allowed_local_media_path)
|
||||
image = self._connector.fetch_image(image_url)
|
||||
|
||||
placeholder = self._tracker.add("image", image)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
audio = get_and_parse_audio(audio_url)
|
||||
audio = self._connector.fetch_audio(audio_url)
|
||||
|
||||
placeholder = self._tracker.add("audio", audio)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
|
||||
input_audio_data = input_audio.get("data","")
|
||||
input_audio_format = input_audio.get("format","")
|
||||
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
|
||||
audio = get_and_parse_audio(audio_url)
|
||||
def parse_input_audio(self, input_audio: InputAudio) -> None:
|
||||
audio_data = input_audio.get("data", "")
|
||||
audio_format = input_audio.get("format", "")
|
||||
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
|
||||
|
||||
placeholder = self._tracker.add("audio", audio)
|
||||
self._add_placeholder(placeholder)
|
||||
return self.parse_audio(audio_url)
|
||||
|
||||
def parse_video(self, video_url: str) -> None:
|
||||
video = get_and_parse_video(video_url)
|
||||
video = self._connector.fetch_video(video_url)
|
||||
|
||||
placeholder = self._tracker.add("video", video)
|
||||
self._add_placeholder(placeholder)
|
||||
@@ -573,33 +559,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
super().__init__()
|
||||
|
||||
self._tracker = tracker
|
||||
self._connector = MediaConnector(
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
)
|
||||
|
||||
def parse_image(self, image_url: str) -> None:
|
||||
image_coro = async_get_and_parse_image(
|
||||
image_url,
|
||||
allowed_local_media_path=self._tracker._model_config.
|
||||
allowed_local_media_path)
|
||||
image_coro = self._connector.fetch_image_async(image_url)
|
||||
|
||||
placeholder = self._tracker.add("image", image_coro)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
audio_coro = async_get_and_parse_audio(audio_url)
|
||||
audio_coro = self._connector.fetch_audio_async(audio_url)
|
||||
|
||||
placeholder = self._tracker.add("audio", audio_coro)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
|
||||
input_audio_data = input_audio.get("data","")
|
||||
input_audio_format = input_audio.get("format","")
|
||||
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
|
||||
audio_coro = async_get_and_parse_audio(audio_url)
|
||||
def parse_input_audio(self, input_audio: InputAudio) -> None:
|
||||
audio_data = input_audio.get("data", "")
|
||||
audio_format = input_audio.get("format", "")
|
||||
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
|
||||
|
||||
placeholder = self._tracker.add("audio", audio_coro)
|
||||
self._add_placeholder(placeholder)
|
||||
return self.parse_audio(audio_url)
|
||||
|
||||
def parse_video(self, video_url: str) -> None:
|
||||
video = async_get_and_parse_video(video_url)
|
||||
video = self._connector.fetch_video_async(video_url)
|
||||
|
||||
placeholder = self._tracker.add("video", video)
|
||||
self._add_placeholder(placeholder)
|
||||
@@ -695,10 +679,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
|
||||
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
|
||||
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
|
||||
|
||||
_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
|
||||
|
||||
# Define a mapping from part types to their corresponding parsing functions.
|
||||
MM_PARSER_MAP: Dict[str,
|
||||
Callable[[ChatCompletionContentPartParam],
|
||||
Union[str, Dict[str,str]]]] = {
|
||||
MM_PARSER_MAP: Dict[
|
||||
str,
|
||||
Callable[[ChatCompletionContentPartParam], _ContentPart],
|
||||
] = {
|
||||
"text":
|
||||
lambda part: _TextParser(part).get("text", ""),
|
||||
"image_url":
|
||||
@@ -715,8 +702,7 @@ MM_PARSER_MAP: Dict[str,
|
||||
|
||||
|
||||
def _parse_chat_message_content_mm_part(
|
||||
part: ChatCompletionContentPartParam) -> Tuple[str,
|
||||
Union[str, Dict[str, str]]]:
|
||||
part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
|
||||
"""
|
||||
Parses a given multi-modal content part based on its type.
|
||||
|
||||
@@ -783,7 +769,7 @@ def _parse_chat_message_content_parts(
|
||||
*,
|
||||
wrap_dicts: bool,
|
||||
) -> List[ConversationMessage]:
|
||||
content: List[Union[str, Dict[str, str]]] = []
|
||||
content = list[_ContentPart]()
|
||||
|
||||
mm_parser = mm_tracker.create_parser()
|
||||
|
||||
@@ -814,7 +800,7 @@ def _parse_chat_message_content_part(
|
||||
mm_parser: BaseMultiModalContentParser,
|
||||
*,
|
||||
wrap_dicts: bool,
|
||||
) -> Optional[Union[str, Dict[str, str]]]:
|
||||
) -> Optional[_ContentPart]:
|
||||
"""Parses a single part of a conversation. If wrap_dicts is True,
|
||||
structured dictionary pieces for texts and images will be
|
||||
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
|
||||
@@ -823,8 +809,7 @@ def _parse_chat_message_content_part(
|
||||
with multimodal placeholders.
|
||||
"""
|
||||
if isinstance(part, str): # Handle plain text parts
|
||||
text = _TextParser(part)
|
||||
return text
|
||||
return part
|
||||
|
||||
# Handle structured dictionary parts
|
||||
part_type, content = _parse_chat_message_content_mm_part(part)
|
||||
@@ -855,7 +840,7 @@ def _parse_chat_message_content_part(
|
||||
return {'type': 'audio'} if wrap_dicts else None
|
||||
|
||||
if part_type == "input_audio":
|
||||
dict_content = cast(Dict[str, str], content)
|
||||
dict_content = cast(InputAudio, content)
|
||||
mm_parser.parse_input_audio(dict_content)
|
||||
return {'type': 'audio'} if wrap_dicts else None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user