[Misc] Abstract the logic for reading and writing media content (#11527)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-27 19:21:23 +08:00
committed by GitHub
parent 2c9b8ea2b0
commit 7af553ea30
10 changed files with 495 additions and 389 deletions

View File

@@ -6,7 +6,7 @@ from collections import defaultdict, deque
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
Literal, Optional, Tuple, TypeVar, Union, cast)
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
@@ -23,6 +23,8 @@ from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -31,11 +33,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image,
async_get_and_parse_video,
get_and_parse_audio, get_and_parse_image,
get_and_parse_video)
from vllm.multimodal.utils import MediaConnector
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import print_warning_once
@@ -368,14 +366,17 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {})
self._consumed_items = {k: 0 for k in self._allowed_items}
self._items: List[_T] = []
self._items_by_modality = defaultdict[str, list[_T]](list)
@property
def model_config(self) -> ModelConfig:
return self._model_config
@property
def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path
@staticmethod
@lru_cache(maxsize=None)
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
@@ -435,38 +436,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
else:
raise TypeError(f"Unknown modality: {modality}")
@staticmethod
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
mm_lists: Mapping[str, List[object]] = defaultdict(list)
# Merge all the multi-modal items
for single_mm_data in items:
for mm_key, mm_item in single_mm_data.items():
if isinstance(mm_item, list):
mm_lists[mm_key].extend(mm_item)
else:
mm_lists[mm_key].append(mm_item)
# Unpack any single item lists for models that don't expect multiple.
return {
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
for mm_key, mm_list in mm_lists.items()
}
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
"""
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count = self._allowed_items.get(modality, 1)
current_count = self._consumed_items.get(modality, 0) + 1
current_count = len(self._items_by_modality[modality]) + 1
if current_count > allowed_count:
raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in "
"one request.")
self._consumed_items[modality] = current_count
self._items.append(item)
self._items_by_modality[modality].append(item)
return self._placeholder_str(modality, current_count)
@@ -475,22 +457,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise NotImplementedError
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> Optional[MultiModalDataDict]:
return self._combine(self._items) if self._items else None
if self._items_by_modality:
return dict(self._items_by_modality)
return None
def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self)
class AsyncMultiModalItemTracker(
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
if self._items:
items = await asyncio.gather(*self._items)
return self._combine(items)
if self._items_by_modality:
return {
modality: await asyncio.gather(*items)
for modality, items in self._items_by_modality.items()
}
return None
@@ -522,7 +508,7 @@ class BaseMultiModalContentParser(ABC):
raise NotImplementedError
@abstractmethod
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
def parse_input_audio(self, input_audio: InputAudio) -> None:
raise NotImplementedError
@abstractmethod
@@ -537,31 +523,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._tracker = tracker
self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str) -> None:
image = get_and_parse_image(image_url,
allowed_local_media_path=self._tracker.
_model_config.allowed_local_media_path)
image = self._connector.fetch_image(image_url)
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None:
audio = get_and_parse_audio(audio_url)
audio = self._connector.fetch_audio(audio_url)
placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
input_audio_data = input_audio.get("data","")
input_audio_format = input_audio.get("format","")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
audio = get_and_parse_audio(audio_url)
def parse_input_audio(self, input_audio: InputAudio) -> None:
audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)
return self.parse_audio(audio_url)
def parse_video(self, video_url: str) -> None:
video = get_and_parse_video(video_url)
video = self._connector.fetch_video(video_url)
placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)
@@ -573,33 +559,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
super().__init__()
self._tracker = tracker
self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str) -> None:
image_coro = async_get_and_parse_image(
image_url,
allowed_local_media_path=self._tracker._model_config.
allowed_local_media_path)
image_coro = self._connector.fetch_image_async(image_url)
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None:
audio_coro = async_get_and_parse_audio(audio_url)
audio_coro = self._connector.fetch_audio_async(audio_url)
placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
input_audio_data = input_audio.get("data","")
input_audio_format = input_audio.get("format","")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
audio_coro = async_get_and_parse_audio(audio_url)
def parse_input_audio(self, input_audio: InputAudio) -> None:
audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)
return self.parse_audio(audio_url)
def parse_video(self, video_url: str) -> None:
video = async_get_and_parse_video(video_url)
video = self._connector.fetch_video_async(video_url)
placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)
@@ -695,10 +679,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str,
Callable[[ChatCompletionContentPartParam],
Union[str, Dict[str,str]]]] = {
MM_PARSER_MAP: Dict[
str,
Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
"text":
lambda part: _TextParser(part).get("text", ""),
"image_url":
@@ -715,8 +702,7 @@ MM_PARSER_MAP: Dict[str,
def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> Tuple[str,
Union[str, Dict[str, str]]]:
part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
"""
Parses a given multi-modal content part based on its type.
@@ -783,7 +769,7 @@ def _parse_chat_message_content_parts(
*,
wrap_dicts: bool,
) -> List[ConversationMessage]:
content: List[Union[str, Dict[str, str]]] = []
content = list[_ContentPart]()
mm_parser = mm_tracker.create_parser()
@@ -814,7 +800,7 @@ def _parse_chat_message_content_part(
mm_parser: BaseMultiModalContentParser,
*,
wrap_dicts: bool,
) -> Optional[Union[str, Dict[str, str]]]:
) -> Optional[_ContentPart]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
@@ -823,8 +809,7 @@ def _parse_chat_message_content_part(
with multimodal placeholders.
"""
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
return text
return part
# Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)
@@ -855,7 +840,7 @@ def _parse_chat_message_content_part(
return {'type': 'audio'} if wrap_dicts else None
if part_type == "input_audio":
dict_content = cast(Dict[str, str], content)
dict_content = cast(InputAudio, content)
mm_parser.parse_input_audio(dict_content)
return {'type': 'audio'} if wrap_dicts else None