[Misc] Add fully interleaved support for multimodal 'string' content format (#14047)
Signed-off-by: drobyshev.anton <drobyshev.anton@wb.ru> Co-authored-by: drobyshev.anton <drobyshev.anton@wb.ru>
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict, deque
|
||||
from collections import Counter, defaultdict, deque
|
||||
from collections.abc import Awaitable, Iterable
|
||||
from functools import cached_property, lru_cache, partial
|
||||
from pathlib import Path
|
||||
@@ -52,6 +52,12 @@ from vllm.utils import deprecate_kwargs, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MODALITY_PLACEHOLDERS_MAP = {
|
||||
"image": "<##IMAGE##>",
|
||||
"audio": "<##AUDIO##>",
|
||||
"video": "<##VIDEO##>",
|
||||
}
|
||||
|
||||
|
||||
class AudioURL(TypedDict, total=False):
|
||||
url: Required[str]
|
||||
@@ -354,6 +360,7 @@ def resolve_mistral_chat_template(
|
||||
"so it will be ignored.")
|
||||
return None
|
||||
|
||||
|
||||
@deprecate_kwargs(
|
||||
"trust_remote_code",
|
||||
additional_message="Please use `model_config.trust_remote_code` instead.",
|
||||
@@ -633,15 +640,22 @@ class BaseMultiModalContentParser(ABC):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
# multimodal placeholder_string : count
|
||||
self._placeholder_counts: dict[str, int] = defaultdict(lambda: 0)
|
||||
# stores model placehodlers list with corresponding
|
||||
# general MM placeholder:
|
||||
# {
|
||||
# "<##IMAGE##>": ["<image>", "<image>", "<image>"],
|
||||
# "<##AUDIO##>": ["<audio>", "<audio>"]
|
||||
# }
|
||||
self._placeholder_storage: dict[str, list] = defaultdict(list)
|
||||
|
||||
def _add_placeholder(self, placeholder: Optional[str]):
|
||||
def _add_placeholder(self, modality: ModalityStr,
|
||||
placeholder: Optional[str]):
|
||||
mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
|
||||
if placeholder:
|
||||
self._placeholder_counts[placeholder] += 1
|
||||
self._placeholder_storage[mod_placeholder].append(placeholder)
|
||||
|
||||
def mm_placeholder_counts(self) -> dict[str, int]:
|
||||
return dict(self._placeholder_counts)
|
||||
def mm_placeholder_storage(self) -> dict[str, list]:
|
||||
return dict(self._placeholder_storage)
|
||||
|
||||
@abstractmethod
|
||||
def parse_image(self, image_url: str) -> None:
|
||||
@@ -685,7 +699,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
image = self._connector.fetch_image(image_url)
|
||||
|
||||
placeholder = self._tracker.add("image", image)
|
||||
self._add_placeholder(placeholder)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_embeds(self,
|
||||
image_embeds: Union[str, dict[str, str]]) -> None:
|
||||
@@ -700,17 +714,17 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
embedding = self._connector.fetch_image_embedding(image_embeds)
|
||||
placeholder = self._tracker.add("image_embeds", embedding)
|
||||
|
||||
self._add_placeholder(placeholder)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_pil(self, image_pil: Image.Image) -> None:
|
||||
placeholder = self._tracker.add("image", image_pil)
|
||||
self._add_placeholder(placeholder)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
audio = self._connector.fetch_audio(audio_url)
|
||||
|
||||
placeholder = self._tracker.add("audio", audio)
|
||||
self._add_placeholder(placeholder)
|
||||
self._add_placeholder("audio", placeholder)
|
||||
|
||||
def parse_input_audio(self, input_audio: InputAudio) -> None:
|
||||
audio_data = input_audio.get("data", "")
|
||||
@@ -723,7 +737,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
video = self._connector.fetch_video(video_url=video_url)
|
||||
|
||||
placeholder = self._tracker.add("video", video)
|
||||
self._add_placeholder(placeholder)
|
||||
self._add_placeholder("video", placeholder)
|
||||
|
||||
|
||||
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
@@ -741,7 +755,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
image_coro = self._connector.fetch_image_async(image_url)
|
||||
|
||||
placeholder = self._tracker.add("image", image_coro)
|
||||
self._add_placeholder(placeholder)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_embeds(self,
|
||||
image_embeds: Union[str, dict[str, str]]) -> None:
|
||||
@@ -760,20 +774,20 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
future.set_result(embedding)
|
||||
|
||||
placeholder = self._tracker.add("image_embeds", future)
|
||||
self._add_placeholder(placeholder)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_pil(self, image_pil: Image.Image) -> None:
|
||||
future: asyncio.Future[Image.Image] = asyncio.Future()
|
||||
future.set_result(image_pil)
|
||||
|
||||
placeholder = self._tracker.add("image", future)
|
||||
self._add_placeholder(placeholder)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
audio_coro = self._connector.fetch_audio_async(audio_url)
|
||||
|
||||
placeholder = self._tracker.add("audio", audio_coro)
|
||||
self._add_placeholder(placeholder)
|
||||
self._add_placeholder("audio", placeholder)
|
||||
|
||||
def parse_input_audio(self, input_audio: InputAudio) -> None:
|
||||
audio_data = input_audio.get("data", "")
|
||||
@@ -786,7 +800,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
video = self._connector.fetch_video_async(video_url=video_url)
|
||||
|
||||
placeholder = self._tracker.add("video", video)
|
||||
self._add_placeholder(placeholder)
|
||||
self._add_placeholder("video", placeholder)
|
||||
|
||||
|
||||
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
|
||||
@@ -856,12 +870,40 @@ def load_chat_template(
|
||||
return _cached_load_chat_template(chat_template, is_literal=is_literal)
|
||||
|
||||
|
||||
def _get_interleaved_text_prompt(placeholder_storage: dict[str, list],
|
||||
texts: list[str]) -> str:
|
||||
for idx, elem in enumerate(texts):
|
||||
if elem in placeholder_storage:
|
||||
texts[idx] = placeholder_storage[elem].pop(0)
|
||||
|
||||
return "\n".join(texts)
|
||||
|
||||
|
||||
# TODO: Let user specify how to insert multimodal tokens into prompt
|
||||
# (similar to chat template)
|
||||
def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
|
||||
text_prompt: str) -> str:
|
||||
def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
|
||||
texts: list[str],
|
||||
interleave_strings: bool
|
||||
) -> str:
|
||||
"""Combine multimodal prompts for a multimodal language model."""
|
||||
|
||||
# flatten storage to make it looks like
|
||||
# {
|
||||
# "<|image|>": 2,
|
||||
# "<|audio|>": 1
|
||||
# }
|
||||
placeholder_counts = Counter(
|
||||
[v for elem in placeholder_storage.values() for v in elem]
|
||||
)
|
||||
|
||||
if interleave_strings:
|
||||
text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
|
||||
else:
|
||||
text_prompt = "\n".join(texts)
|
||||
|
||||
# Pass interleaved text further in case the user used image placeholders
|
||||
# himself, but forgot to disable the 'interleave_strings' flag
|
||||
|
||||
# Look through the text prompt to check for missing placeholders
|
||||
missing_placeholders: list[str] = []
|
||||
for placeholder in placeholder_counts:
|
||||
@@ -870,6 +912,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
|
||||
placeholder_counts[placeholder] -= text_prompt.count(placeholder)
|
||||
|
||||
if placeholder_counts[placeholder] < 0:
|
||||
logger.error(
|
||||
"Placeholder count is negative! "
|
||||
"Ensure that the 'interleave_strings' flag is disabled "
|
||||
"(current value: %s) "
|
||||
"when manually placing image placeholders.", interleave_strings
|
||||
)
|
||||
logger.debug("Input prompt: %s", text_prompt)
|
||||
raise ValueError(
|
||||
f"Found more '{placeholder}' placeholders in input prompt than "
|
||||
"actual multimodal data items.")
|
||||
@@ -877,8 +926,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
|
||||
missing_placeholders.extend([placeholder] *
|
||||
placeholder_counts[placeholder])
|
||||
|
||||
# NOTE: For now we always add missing placeholders at the front of
|
||||
# the prompt. This may change to be customizable in the future.
|
||||
# NOTE: Default behaviour: we always add missing placeholders
|
||||
# at the front of the prompt, if interleave_strings=False
|
||||
return "\n".join(missing_placeholders + [text_prompt])
|
||||
|
||||
|
||||
@@ -988,6 +1037,7 @@ def _parse_chat_message_content_parts(
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
*,
|
||||
wrap_dicts: bool,
|
||||
interleave_strings: bool,
|
||||
) -> list[ConversationMessage]:
|
||||
content = list[_ContentPart]()
|
||||
|
||||
@@ -998,6 +1048,7 @@ def _parse_chat_message_content_parts(
|
||||
part,
|
||||
mm_parser,
|
||||
wrap_dicts=wrap_dicts,
|
||||
interleave_strings=interleave_strings
|
||||
)
|
||||
if parse_res:
|
||||
content.append(parse_res)
|
||||
@@ -1007,11 +1058,14 @@ def _parse_chat_message_content_parts(
|
||||
return [ConversationMessage(role=role,
|
||||
content=content)] # type: ignore
|
||||
texts = cast(list[str], content)
|
||||
text_prompt = "\n".join(texts)
|
||||
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
|
||||
if mm_placeholder_counts:
|
||||
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
|
||||
text_prompt)
|
||||
mm_placeholder_storage = mm_parser.mm_placeholder_storage()
|
||||
if mm_placeholder_storage:
|
||||
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage,
|
||||
texts,
|
||||
interleave_strings)
|
||||
else:
|
||||
text_prompt = "\n".join(texts)
|
||||
|
||||
return [ConversationMessage(role=role, content=text_prompt)]
|
||||
|
||||
|
||||
@@ -1020,6 +1074,7 @@ def _parse_chat_message_content_part(
|
||||
mm_parser: BaseMultiModalContentParser,
|
||||
*,
|
||||
wrap_dicts: bool,
|
||||
interleave_strings: bool,
|
||||
) -> Optional[_ContentPart]:
|
||||
"""Parses a single part of a conversation. If wrap_dicts is True,
|
||||
structured dictionary pieces for texts and images will be
|
||||
@@ -1049,34 +1104,37 @@ def _parse_chat_message_content_part(
|
||||
else:
|
||||
return str_content
|
||||
|
||||
modality = None
|
||||
if part_type == "image_pil":
|
||||
image_content = cast(Image.Image, content)
|
||||
mm_parser.parse_image_pil(image_content)
|
||||
return {'type': 'image'} if wrap_dicts else None
|
||||
if part_type == "image_url":
|
||||
modality = "image"
|
||||
elif part_type == "image_url":
|
||||
str_content = cast(str, content)
|
||||
mm_parser.parse_image(str_content)
|
||||
return {'type': 'image'} if wrap_dicts else None
|
||||
if part_type == "image_embeds":
|
||||
modality = "image"
|
||||
elif part_type == "image_embeds":
|
||||
content = cast(Union[str, dict[str, str]], content)
|
||||
mm_parser.parse_image_embeds(content)
|
||||
return {'type': 'image'} if wrap_dicts else None
|
||||
if part_type == "audio_url":
|
||||
modality = "image"
|
||||
elif part_type == "audio_url":
|
||||
str_content = cast(str, content)
|
||||
mm_parser.parse_audio(str_content)
|
||||
return {'type': 'audio'} if wrap_dicts else None
|
||||
|
||||
if part_type == "input_audio":
|
||||
modality = "audio"
|
||||
elif part_type == "input_audio":
|
||||
dict_content = cast(InputAudio, content)
|
||||
mm_parser.parse_input_audio(dict_content)
|
||||
return {'type': 'audio'} if wrap_dicts else None
|
||||
|
||||
if part_type == "video_url":
|
||||
modality = "audio"
|
||||
elif part_type == "video_url":
|
||||
str_content = cast(str, content)
|
||||
mm_parser.parse_video(str_content)
|
||||
return {'type': 'video'} if wrap_dicts else None
|
||||
modality = "video"
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||
|
||||
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||
return {'type': modality} if wrap_dicts else (
|
||||
MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
|
||||
)
|
||||
|
||||
|
||||
# No need to validate using Pydantic again
|
||||
@@ -1088,6 +1146,7 @@ def _parse_chat_message_content(
|
||||
message: ChatCompletionMessageParam,
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
interleave_strings: bool,
|
||||
) -> list[ConversationMessage]:
|
||||
role = message["role"]
|
||||
content = message.get("content")
|
||||
@@ -1103,6 +1162,7 @@ def _parse_chat_message_content(
|
||||
content, # type: ignore
|
||||
mm_tracker,
|
||||
wrap_dicts=(content_format == "openai"),
|
||||
interleave_strings=interleave_strings,
|
||||
)
|
||||
|
||||
for result_msg in result:
|
||||
@@ -1155,6 +1215,11 @@ def parse_chat_messages(
|
||||
msg,
|
||||
mm_tracker,
|
||||
content_format,
|
||||
interleave_strings=(
|
||||
content_format == "string"
|
||||
and model_config.multimodal_config is not None
|
||||
and model_config.multimodal_config.interleave_mm_strings
|
||||
)
|
||||
)
|
||||
|
||||
conversation.extend(sub_messages)
|
||||
@@ -1178,6 +1243,11 @@ def parse_chat_messages_futures(
|
||||
msg,
|
||||
mm_tracker,
|
||||
content_format,
|
||||
interleave_strings=(
|
||||
content_format == "string"
|
||||
and model_config.multimodal_config is not None
|
||||
and model_config.multimodal_config.interleave_mm_strings
|
||||
)
|
||||
)
|
||||
|
||||
conversation.extend(sub_messages)
|
||||
|
||||
Reference in New Issue
Block a user