[Misc] Add fully interleaved support for multimodal 'string' content format (#14047)

Signed-off-by: drobyshev.anton <drobyshev.anton@wb.ru>
Co-authored-by: drobyshev.anton <drobyshev.anton@wb.ru>
This commit is contained in:
Anton
2025-07-07 22:43:08 +03:00
committed by GitHub
parent 22dd9c2730
commit e601efcb10
4 changed files with 478 additions and 43 deletions

View File

@@ -4,7 +4,7 @@
import asyncio
import json
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections import Counter, defaultdict, deque
from collections.abc import Awaitable, Iterable
from functools import cached_property, lru_cache, partial
from pathlib import Path
@@ -52,6 +52,12 @@ from vllm.utils import deprecate_kwargs, random_uuid
logger = init_logger(__name__)
MODALITY_PLACEHOLDERS_MAP = {
"image": "<##IMAGE##>",
"audio": "<##AUDIO##>",
"video": "<##VIDEO##>",
}
class AudioURL(TypedDict, total=False):
url: Required[str]
@@ -354,6 +360,7 @@ def resolve_mistral_chat_template(
"so it will be ignored.")
return None
@deprecate_kwargs(
"trust_remote_code",
additional_message="Please use `model_config.trust_remote_code` instead.",
@@ -633,15 +640,22 @@ class BaseMultiModalContentParser(ABC):
def __init__(self) -> None:
super().__init__()
# multimodal placeholder_string : count
self._placeholder_counts: dict[str, int] = defaultdict(lambda: 0)
# stores model placehodlers list with corresponding
# general MM placeholder:
# {
# "<##IMAGE##>": ["<image>", "<image>", "<image>"],
# "<##AUDIO##>": ["<audio>", "<audio>"]
# }
self._placeholder_storage: dict[str, list] = defaultdict(list)
def _add_placeholder(self, placeholder: Optional[str]):
def _add_placeholder(self, modality: ModalityStr,
placeholder: Optional[str]):
mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
if placeholder:
self._placeholder_counts[placeholder] += 1
self._placeholder_storage[mod_placeholder].append(placeholder)
def mm_placeholder_counts(self) -> dict[str, int]:
return dict(self._placeholder_counts)
def mm_placeholder_storage(self) -> dict[str, list]:
return dict(self._placeholder_storage)
@abstractmethod
def parse_image(self, image_url: str) -> None:
@@ -685,7 +699,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
image = self._connector.fetch_image(image_url)
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
self._add_placeholder("image", placeholder)
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
@@ -700,17 +714,17 @@ class MultiModalContentParser(BaseMultiModalContentParser):
embedding = self._connector.fetch_image_embedding(image_embeds)
placeholder = self._tracker.add("image_embeds", embedding)
self._add_placeholder(placeholder)
self._add_placeholder("image", placeholder)
def parse_image_pil(self, image_pil: Image.Image) -> None:
placeholder = self._tracker.add("image", image_pil)
self._add_placeholder(placeholder)
self._add_placeholder("image", placeholder)
def parse_audio(self, audio_url: str) -> None:
audio = self._connector.fetch_audio(audio_url)
placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)
self._add_placeholder("audio", placeholder)
def parse_input_audio(self, input_audio: InputAudio) -> None:
audio_data = input_audio.get("data", "")
@@ -723,7 +737,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
video = self._connector.fetch_video(video_url=video_url)
placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)
self._add_placeholder("video", placeholder)
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
@@ -741,7 +755,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
image_coro = self._connector.fetch_image_async(image_url)
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
self._add_placeholder("image", placeholder)
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
@@ -760,20 +774,20 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
future.set_result(embedding)
placeholder = self._tracker.add("image_embeds", future)
self._add_placeholder(placeholder)
self._add_placeholder("image", placeholder)
def parse_image_pil(self, image_pil: Image.Image) -> None:
future: asyncio.Future[Image.Image] = asyncio.Future()
future.set_result(image_pil)
placeholder = self._tracker.add("image", future)
self._add_placeholder(placeholder)
self._add_placeholder("image", placeholder)
def parse_audio(self, audio_url: str) -> None:
audio_coro = self._connector.fetch_audio_async(audio_url)
placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)
self._add_placeholder("audio", placeholder)
def parse_input_audio(self, input_audio: InputAudio) -> None:
audio_data = input_audio.get("data", "")
@@ -786,7 +800,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
video = self._connector.fetch_video_async(video_url=video_url)
placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)
self._add_placeholder("video", placeholder)
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
@@ -856,12 +870,40 @@ def load_chat_template(
return _cached_load_chat_template(chat_template, is_literal=is_literal)
def _get_interleaved_text_prompt(placeholder_storage: dict[str, list],
texts: list[str]) -> str:
for idx, elem in enumerate(texts):
if elem in placeholder_storage:
texts[idx] = placeholder_storage[elem].pop(0)
return "\n".join(texts)
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
text_prompt: str) -> str:
def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
texts: list[str],
interleave_strings: bool
) -> str:
"""Combine multimodal prompts for a multimodal language model."""
# flatten storage to make it looks like
# {
# "<|image|>": 2,
# "<|audio|>": 1
# }
placeholder_counts = Counter(
[v for elem in placeholder_storage.values() for v in elem]
)
if interleave_strings:
text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts)
else:
text_prompt = "\n".join(texts)
# Pass interleaved text further in case the user used image placeholders
# himself, but forgot to disable the 'interleave_strings' flag
# Look through the text prompt to check for missing placeholders
missing_placeholders: list[str] = []
for placeholder in placeholder_counts:
@@ -870,6 +912,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
placeholder_counts[placeholder] -= text_prompt.count(placeholder)
if placeholder_counts[placeholder] < 0:
logger.error(
"Placeholder count is negative! "
"Ensure that the 'interleave_strings' flag is disabled "
"(current value: %s) "
"when manually placing image placeholders.", interleave_strings
)
logger.debug("Input prompt: %s", text_prompt)
raise ValueError(
f"Found more '{placeholder}' placeholders in input prompt than "
"actual multimodal data items.")
@@ -877,8 +926,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
missing_placeholders.extend([placeholder] *
placeholder_counts[placeholder])
# NOTE: For now we always add missing placeholders at the front of
# the prompt. This may change to be customizable in the future.
# NOTE: Default behaviour: we always add missing placeholders
# at the front of the prompt, if interleave_strings=False
return "\n".join(missing_placeholders + [text_prompt])
@@ -988,6 +1037,7 @@ def _parse_chat_message_content_parts(
mm_tracker: BaseMultiModalItemTracker,
*,
wrap_dicts: bool,
interleave_strings: bool,
) -> list[ConversationMessage]:
content = list[_ContentPart]()
@@ -998,6 +1048,7 @@ def _parse_chat_message_content_parts(
part,
mm_parser,
wrap_dicts=wrap_dicts,
interleave_strings=interleave_strings
)
if parse_res:
content.append(parse_res)
@@ -1007,11 +1058,14 @@ def _parse_chat_message_content_parts(
return [ConversationMessage(role=role,
content=content)] # type: ignore
texts = cast(list[str], content)
text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)
mm_placeholder_storage = mm_parser.mm_placeholder_storage()
if mm_placeholder_storage:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage,
texts,
interleave_strings)
else:
text_prompt = "\n".join(texts)
return [ConversationMessage(role=role, content=text_prompt)]
@@ -1020,6 +1074,7 @@ def _parse_chat_message_content_part(
mm_parser: BaseMultiModalContentParser,
*,
wrap_dicts: bool,
interleave_strings: bool,
) -> Optional[_ContentPart]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
@@ -1049,34 +1104,37 @@ def _parse_chat_message_content_part(
else:
return str_content
modality = None
if part_type == "image_pil":
image_content = cast(Image.Image, content)
mm_parser.parse_image_pil(image_content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "image_url":
modality = "image"
elif part_type == "image_url":
str_content = cast(str, content)
mm_parser.parse_image(str_content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "image_embeds":
modality = "image"
elif part_type == "image_embeds":
content = cast(Union[str, dict[str, str]], content)
mm_parser.parse_image_embeds(content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "audio_url":
modality = "image"
elif part_type == "audio_url":
str_content = cast(str, content)
mm_parser.parse_audio(str_content)
return {'type': 'audio'} if wrap_dicts else None
if part_type == "input_audio":
modality = "audio"
elif part_type == "input_audio":
dict_content = cast(InputAudio, content)
mm_parser.parse_input_audio(dict_content)
return {'type': 'audio'} if wrap_dicts else None
if part_type == "video_url":
modality = "audio"
elif part_type == "video_url":
str_content = cast(str, content)
mm_parser.parse_video(str_content)
return {'type': 'video'} if wrap_dicts else None
modality = "video"
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
raise NotImplementedError(f"Unknown part type: {part_type}")
return {'type': modality} if wrap_dicts else (
MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
)
# No need to validate using Pydantic again
@@ -1088,6 +1146,7 @@ def _parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker,
content_format: _ChatTemplateContentFormat,
interleave_strings: bool,
) -> list[ConversationMessage]:
role = message["role"]
content = message.get("content")
@@ -1103,6 +1162,7 @@ def _parse_chat_message_content(
content, # type: ignore
mm_tracker,
wrap_dicts=(content_format == "openai"),
interleave_strings=interleave_strings,
)
for result_msg in result:
@@ -1155,6 +1215,11 @@ def parse_chat_messages(
msg,
mm_tracker,
content_format,
interleave_strings=(
content_format == "string"
and model_config.multimodal_config is not None
and model_config.multimodal_config.interleave_mm_strings
)
)
conversation.extend(sub_messages)
@@ -1178,6 +1243,11 @@ def parse_chat_messages_futures(
msg,
mm_tracker,
content_format,
interleave_strings=(
content_format == "string"
and model_config.multimodal_config is not None
and model_config.multimodal_config.interleave_mm_strings
)
)
conversation.extend(sub_messages)