[Frontend][Multimodal] Allow skipping media data when UUIDs are provided. (#23950)

Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.me>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
Chenheli Hua
2025-09-12 19:16:06 -07:00
committed by GitHub
parent 4fdd6f5cbf
commit 7f2ea7074e
9 changed files with 970 additions and 96 deletions

View File

@@ -73,15 +73,10 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
type: Required[Literal["audio_url"]]
"""The type of the content part."""
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
image_embeds: Required[Union[str, dict[str, str]]]
image_embeds: Optional[Union[str, dict[str, str]]]
"""
The image embeddings. It can be either:
- A single base64 string.
@@ -108,11 +103,6 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
type: Required[Literal["video_url"]]
"""The type of the content part."""
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class PILImage(BaseModel):
@@ -133,7 +123,7 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
}
"""
image_pil: Required[PILImage]
image_pil: Optional[PILImage]
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
@@ -151,7 +141,7 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
}
"""
image_url: Required[str]
image_url: Optional[str]
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
@@ -168,7 +158,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
}
"""
audio_url: Required[str]
audio_url: Optional[str]
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
@@ -180,7 +170,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
}
"""
video_url: Required[str]
video_url: Optional[str]
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
@@ -597,7 +587,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._model_config = model_config
self._tokenizer = tokenizer
self._items_by_modality = defaultdict[str, list[_T]](list)
self._items_by_modality = defaultdict[str, list[Optional[_T]]](list)
self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list)
@property
@@ -624,14 +614,17 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return self.mm_registry.create_processor(self.model_config)
def add(
self, modality: ModalityStr, item: _T, uuid: Optional[str] = None
self,
modality: ModalityStr,
item: Optional[_T],
uuid: Optional[str] = None,
) -> Optional[str]:
"""
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
An optional uuid can be added which serves as a unique identifier of the
media.
media.
"""
input_modality = modality.replace("_embeds", "")
num_items = len(self._items_by_modality[modality]) + 1
@@ -708,10 +701,15 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
if not self._items_by_modality:
return None
mm_inputs = {}
items_by_modality = {
modality: await asyncio.gather(*items)
for modality, items in self._items_by_modality.items()
}
items_by_modality = {}
for modality, items in self._items_by_modality.items():
coros = []
for item in items:
if item is not None:
coros.append(item)
else:
coros.append(asyncio.sleep(0))
items_by_modality[modality] = await asyncio.gather(*coros)
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError(
@@ -760,35 +758,40 @@ class BaseMultiModalContentParser(ABC):
return dict(self._placeholder_storage)
@abstractmethod
def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
def parse_image(
self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
raise NotImplementedError
@abstractmethod
def parse_image_embeds(
self,
image_embeds: Union[str, dict[str, str]],
image_embeds: Union[str, dict[str, str], None],
uuid: Optional[str] = None,
) -> None:
raise NotImplementedError
@abstractmethod
def parse_image_pil(
self, image_pil: Image.Image, uuid: Optional[str] = None
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
) -> None:
raise NotImplementedError
@abstractmethod
def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
def parse_audio(
self, audio_url: Optional[str], uuid: Optional[str] = None
) -> None:
raise NotImplementedError
@abstractmethod
def parse_input_audio(
self, input_audio: InputAudio, uuid: Optional[str] = None
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
) -> None:
raise NotImplementedError
@abstractmethod
def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
def parse_video(
self, video_url: Optional[str], uuid: Optional[str] = None
) -> None:
raise NotImplementedError
@@ -803,15 +806,17 @@ class MultiModalContentParser(BaseMultiModalContentParser):
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
image = self._connector.fetch_image(image_url)
def parse_image(
self, image_url: Optional[str], uuid: Optional[str] = None
) -> None:
image = self._connector.fetch_image(image_url) if image_url else None
placeholder = self._tracker.add("image", image, uuid)
self._add_placeholder("image", placeholder)
def parse_image_embeds(
self,
image_embeds: Union[str, dict[str, str]],
image_embeds: Union[str, dict[str, str], None],
uuid: Optional[str] = None,
) -> None:
if isinstance(image_embeds, dict):
@@ -825,31 +830,49 @@ class MultiModalContentParser(BaseMultiModalContentParser):
embedding = self._connector.fetch_image_embedding(image_embeds)
placeholder = self._tracker.add("image_embeds", embedding, uuid)
if image_embeds is None:
placeholder = self._tracker.add("image_embeds", None, uuid)
self._add_placeholder("image", placeholder)
def parse_image_pil(
self, image_pil: Image.Image, uuid: Optional[str] = None
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
) -> None:
placeholder = self._tracker.add("image", image_pil, uuid)
self._add_placeholder("image", placeholder)
def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
audio = self._connector.fetch_audio(audio_url)
def parse_audio(
self, audio_url: Optional[str], uuid: Optional[str] = None
) -> None:
audio = self._connector.fetch_audio(audio_url) if audio_url else None
placeholder = self._tracker.add("audio", audio, uuid)
self._add_placeholder("audio", placeholder)
def parse_input_audio(
self, input_audio: InputAudio, uuid: Optional[str] = None
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
) -> None:
audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
if input_audio:
audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "")
if audio_data:
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
else:
# If a UUID is provided, audio data may be empty.
audio_url = None
else:
audio_url = None
return self.parse_audio(audio_url, uuid)
def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
video = self._connector.fetch_video(video_url=video_url)
def parse_video(
self, video_url: Optional[str], uuid: Optional[str] = None
) -> None:
video = (
self._connector.fetch_video(video_url=video_url)
if video_url
else None
)
placeholder = self._tracker.add("video", video, uuid)
self._add_placeholder("video", placeholder)
@@ -865,18 +888,24 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
image_coro = self._connector.fetch_image_async(image_url)
def parse_image(
self, image_url: Optional[str], uuid: Optional[str] = None
) -> None:
image_coro = (
self._connector.fetch_image_async(image_url) if image_url else None
)
placeholder = self._tracker.add("image", image_coro, uuid)
self._add_placeholder("image", placeholder)
def parse_image_embeds(
self,
image_embeds: Union[str, dict[str, str]],
image_embeds: Union[str, dict[str, str], None],
uuid: Optional[str] = None,
) -> None:
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
future: asyncio.Future[Union[str, dict[str, str], None]] = (
asyncio.Future()
)
if isinstance(image_embeds, dict):
embeds = {
@@ -889,35 +918,58 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
embedding = self._connector.fetch_image_embedding(image_embeds)
future.set_result(embedding)
if image_embeds is None:
future.set_result(None)
placeholder = self._tracker.add("image_embeds", future, uuid)
self._add_placeholder("image", placeholder)
def parse_image_pil(
self, image_pil: Image.Image, uuid: Optional[str] = None
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
) -> None:
future: asyncio.Future[Image.Image] = asyncio.Future()
future.set_result(image_pil)
future: asyncio.Future[Optional[Image.Image]] = asyncio.Future()
if image_pil:
future.set_result(image_pil)
else:
future.set_result(None)
placeholder = self._tracker.add("image", future, uuid)
self._add_placeholder("image", placeholder)
def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
audio_coro = self._connector.fetch_audio_async(audio_url)
def parse_audio(
self, audio_url: Optional[str], uuid: Optional[str] = None
) -> None:
audio_coro = (
self._connector.fetch_audio_async(audio_url) if audio_url else None
)
placeholder = self._tracker.add("audio", audio_coro, uuid)
self._add_placeholder("audio", placeholder)
def parse_input_audio(
self, input_audio: InputAudio, uuid: Optional[str] = None
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
) -> None:
audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
if input_audio:
audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "")
if audio_data:
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
else:
# If a UUID is provided, audio data may be empty.
audio_url = None
else:
audio_url = None
return self.parse_audio(audio_url, uuid)
def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
video = self._connector.fetch_video_async(video_url=video_url)
def parse_video(
self, video_url: Optional[str], uuid: Optional[str] = None
) -> None:
video = (
self._connector.fetch_video_async(video_url=video_url)
if video_url
else None
)
placeholder = self._tracker.add("video", video, uuid)
self._add_placeholder("video", placeholder)
@@ -1130,8 +1182,9 @@ def _parse_chat_message_content_mm_part(
part, dict
) # This is needed to avoid mypy errors: part.get() from str
part_type = part.get("type", None)
uuid = part.get("uuid", None)
if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None: # noqa: E501
content = MM_PARSER_MAP[part_type](part)
# Special case for 'image_url.detail'
@@ -1146,25 +1199,54 @@ def _parse_chat_message_content_mm_part(
# Handle missing 'type' but provided direct URL fields.
# 'type' is required field by pydantic
if part_type is None:
if part.get("image_url") is not None:
if part_type is None or uuid is not None:
if "image_url" in part:
image_params = cast(
CustomChatCompletionContentSimpleImageParam, part
)
return "image_url", image_params.get("image_url", "")
if part.get("audio_url") is not None:
image_url = image_params.get("image_url", None)
if isinstance(image_url, dict):
# Can potentially happen if user provides a uuid
# with url as a dict of {"url": url}
image_url = image_url.get("url", None)
return "image_url", image_url
if "image_pil" in part:
# "image_pil" could be None if UUID is provided.
image_params = cast( # type: ignore
CustomChatCompletionContentPILImageParam, part
)
image_pil = image_params.get("image_pil", None)
return "image_pil", image_pil
if "image_embeds" in part:
# "image_embeds" could be None if UUID is provided.
image_params = cast( # type: ignore
ChatCompletionContentPartImageEmbedsParam, part
)
image_embeds = image_params.get("image_embeds", None)
return "image_embeds", image_embeds
if "audio_url" in part:
audio_params = cast(
CustomChatCompletionContentSimpleAudioParam, part
)
return "audio_url", audio_params.get("audio_url", "")
audio_url = audio_params.get("audio_url", None)
if isinstance(audio_url, dict):
# Can potentially happen if user provides a uuid
# with url as a dict of {"url": url}
audio_url = audio_url.get("url", None)
return "audio_url", audio_url
if part.get("input_audio") is not None:
input_audio_params = cast(dict[str, str], part)
return "input_audio", input_audio_params
if part.get("video_url") is not None:
if "video_url" in part:
video_params = cast(
CustomChatCompletionContentSimpleVideoParam, part
)
return "video_url", video_params.get("video_url", "")
video_url = video_params.get("video_url", None)
if isinstance(video_url, dict):
# Can potentially happen if user provides a uuid
# with url as a dict of {"url": url}
video_url = video_url.get("url", None)
return "video_url", video_url
# Raise an error if no 'type' or direct URL is found.
raise ValueError("Missing 'type' field in multimodal part.")
@@ -1173,15 +1255,9 @@ def _parse_chat_message_content_mm_part(
return part_type, "unknown part_type content"
VALID_MESSAGE_CONTENT_MM_PART_TYPES = (
PART_TYPES_TO_SKIP_NONE_CONTENT = (
"text",
"refusal",
"image_url",
"image_embeds",
"image_pil",
"audio_url",
"input_audio",
"video_url",
)
@@ -1242,7 +1318,7 @@ def _parse_chat_message_content_part(
part_type, content = _parse_chat_message_content_mm_part(part)
# if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
# content is None, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
logger.warning(
"Skipping multimodal part '%s' (type: '%s') "
"with empty / unparsable content.",
@@ -1266,7 +1342,10 @@ def _parse_chat_message_content_part(
modality = None
if part_type == "image_pil":
image_content = cast(Image.Image, content)
if content is not None:
image_content = cast(Image.Image, content)
else:
image_content = None
mm_parser.parse_image_pil(image_content, uuid)
modality = "image"
elif part_type in ("image_url", "input_image"):
@@ -1274,7 +1353,10 @@ def _parse_chat_message_content_part(
mm_parser.parse_image(str_content, uuid)
modality = "image"
elif part_type == "image_embeds":
content = cast(Union[str, dict[str, str]], content)
if content is not None:
content = cast(Union[str, dict[str, str]], content)
else:
content = None
mm_parser.parse_image_embeds(content, uuid)
modality = "image"
elif part_type == "audio_url":