[Frontend][Multimodal] Allow skipping media data when UUIDs are provided. (#23950)
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.me>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
@@ -73,15 +73,10 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
|
||||
|
||||
type: Required[Literal["audio_url"]]
|
||||
"""The type of the content part."""
|
||||
uuid: Optional[str]
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
generated and unique for different medias.
|
||||
"""
|
||||
|
||||
|
||||
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
|
||||
image_embeds: Required[Union[str, dict[str, str]]]
|
||||
image_embeds: Optional[Union[str, dict[str, str]]]
|
||||
"""
|
||||
The image embeddings. It can be either:
|
||||
- A single base64 string.
|
||||
@@ -108,11 +103,6 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
|
||||
|
||||
type: Required[Literal["video_url"]]
|
||||
"""The type of the content part."""
|
||||
uuid: Optional[str]
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
generated and unique for different medias.
|
||||
"""
|
||||
|
||||
|
||||
class PILImage(BaseModel):
|
||||
@@ -133,7 +123,7 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
|
||||
}
|
||||
"""
|
||||
|
||||
image_pil: Required[PILImage]
|
||||
image_pil: Optional[PILImage]
|
||||
uuid: Optional[str]
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
@@ -151,7 +141,7 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
|
||||
}
|
||||
"""
|
||||
|
||||
image_url: Required[str]
|
||||
image_url: Optional[str]
|
||||
uuid: Optional[str]
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
@@ -168,7 +158,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
|
||||
}
|
||||
"""
|
||||
|
||||
audio_url: Required[str]
|
||||
audio_url: Optional[str]
|
||||
|
||||
|
||||
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
|
||||
@@ -180,7 +170,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
|
||||
}
|
||||
"""
|
||||
|
||||
video_url: Required[str]
|
||||
video_url: Optional[str]
|
||||
uuid: Optional[str]
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
@@ -597,7 +587,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
self._model_config = model_config
|
||||
self._tokenizer = tokenizer
|
||||
|
||||
self._items_by_modality = defaultdict[str, list[_T]](list)
|
||||
self._items_by_modality = defaultdict[str, list[Optional[_T]]](list)
|
||||
self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list)
|
||||
|
||||
@property
|
||||
@@ -624,14 +614,17 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
return self.mm_registry.create_processor(self.model_config)
|
||||
|
||||
def add(
|
||||
self, modality: ModalityStr, item: _T, uuid: Optional[str] = None
|
||||
self,
|
||||
modality: ModalityStr,
|
||||
item: Optional[_T],
|
||||
uuid: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Add a multi-modal item to the current prompt and return the
|
||||
placeholder string to use, if any.
|
||||
|
||||
An optional uuid can be added which serves as a unique identifier of the
|
||||
media.
|
||||
media.
|
||||
"""
|
||||
input_modality = modality.replace("_embeds", "")
|
||||
num_items = len(self._items_by_modality[modality]) + 1
|
||||
@@ -708,10 +701,15 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
mm_inputs = {}
|
||||
items_by_modality = {
|
||||
modality: await asyncio.gather(*items)
|
||||
for modality, items in self._items_by_modality.items()
|
||||
}
|
||||
items_by_modality = {}
|
||||
for modality, items in self._items_by_modality.items():
|
||||
coros = []
|
||||
for item in items:
|
||||
if item is not None:
|
||||
coros.append(item)
|
||||
else:
|
||||
coros.append(asyncio.sleep(0))
|
||||
items_by_modality[modality] = await asyncio.gather(*coros)
|
||||
|
||||
if "image" in items_by_modality and "image_embeds" in items_by_modality:
|
||||
raise ValueError(
|
||||
@@ -760,35 +758,40 @@ class BaseMultiModalContentParser(ABC):
|
||||
return dict(self._placeholder_storage)
|
||||
|
||||
@abstractmethod
|
||||
def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
|
||||
def parse_image(
|
||||
self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_image_embeds(
|
||||
self,
|
||||
image_embeds: Union[str, dict[str, str]],
|
||||
image_embeds: Union[str, dict[str, str], None],
|
||||
uuid: Optional[str] = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_image_pil(
|
||||
self, image_pil: Image.Image, uuid: Optional[str] = None
|
||||
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
|
||||
def parse_audio(
|
||||
self, audio_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_input_audio(
|
||||
self, input_audio: InputAudio, uuid: Optional[str] = None
|
||||
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
|
||||
def parse_video(
|
||||
self, video_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -803,15 +806,17 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
)
|
||||
|
||||
def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
|
||||
image = self._connector.fetch_image(image_url)
|
||||
def parse_image(
|
||||
self, image_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
image = self._connector.fetch_image(image_url) if image_url else None
|
||||
|
||||
placeholder = self._tracker.add("image", image, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_embeds(
|
||||
self,
|
||||
image_embeds: Union[str, dict[str, str]],
|
||||
image_embeds: Union[str, dict[str, str], None],
|
||||
uuid: Optional[str] = None,
|
||||
) -> None:
|
||||
if isinstance(image_embeds, dict):
|
||||
@@ -825,31 +830,49 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
embedding = self._connector.fetch_image_embedding(image_embeds)
|
||||
placeholder = self._tracker.add("image_embeds", embedding, uuid)
|
||||
|
||||
if image_embeds is None:
|
||||
placeholder = self._tracker.add("image_embeds", None, uuid)
|
||||
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_pil(
|
||||
self, image_pil: Image.Image, uuid: Optional[str] = None
|
||||
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
placeholder = self._tracker.add("image", image_pil, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
|
||||
audio = self._connector.fetch_audio(audio_url)
|
||||
def parse_audio(
|
||||
self, audio_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
audio = self._connector.fetch_audio(audio_url) if audio_url else None
|
||||
|
||||
placeholder = self._tracker.add("audio", audio, uuid)
|
||||
self._add_placeholder("audio", placeholder)
|
||||
|
||||
def parse_input_audio(
|
||||
self, input_audio: InputAudio, uuid: Optional[str] = None
|
||||
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
audio_data = input_audio.get("data", "")
|
||||
audio_format = input_audio.get("format", "")
|
||||
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
|
||||
if input_audio:
|
||||
audio_data = input_audio.get("data", "")
|
||||
audio_format = input_audio.get("format", "")
|
||||
if audio_data:
|
||||
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
|
||||
else:
|
||||
# If a UUID is provided, audio data may be empty.
|
||||
audio_url = None
|
||||
else:
|
||||
audio_url = None
|
||||
|
||||
return self.parse_audio(audio_url, uuid)
|
||||
|
||||
def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
|
||||
video = self._connector.fetch_video(video_url=video_url)
|
||||
def parse_video(
|
||||
self, video_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
video = (
|
||||
self._connector.fetch_video(video_url=video_url)
|
||||
if video_url
|
||||
else None
|
||||
)
|
||||
|
||||
placeholder = self._tracker.add("video", video, uuid)
|
||||
self._add_placeholder("video", placeholder)
|
||||
@@ -865,18 +888,24 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
)
|
||||
|
||||
def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
|
||||
image_coro = self._connector.fetch_image_async(image_url)
|
||||
def parse_image(
|
||||
self, image_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
image_coro = (
|
||||
self._connector.fetch_image_async(image_url) if image_url else None
|
||||
)
|
||||
|
||||
placeholder = self._tracker.add("image", image_coro, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_embeds(
|
||||
self,
|
||||
image_embeds: Union[str, dict[str, str]],
|
||||
image_embeds: Union[str, dict[str, str], None],
|
||||
uuid: Optional[str] = None,
|
||||
) -> None:
|
||||
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
|
||||
future: asyncio.Future[Union[str, dict[str, str], None]] = (
|
||||
asyncio.Future()
|
||||
)
|
||||
|
||||
if isinstance(image_embeds, dict):
|
||||
embeds = {
|
||||
@@ -889,35 +918,58 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
embedding = self._connector.fetch_image_embedding(image_embeds)
|
||||
future.set_result(embedding)
|
||||
|
||||
if image_embeds is None:
|
||||
future.set_result(None)
|
||||
|
||||
placeholder = self._tracker.add("image_embeds", future, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_pil(
|
||||
self, image_pil: Image.Image, uuid: Optional[str] = None
|
||||
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
future: asyncio.Future[Image.Image] = asyncio.Future()
|
||||
future.set_result(image_pil)
|
||||
future: asyncio.Future[Optional[Image.Image]] = asyncio.Future()
|
||||
if image_pil:
|
||||
future.set_result(image_pil)
|
||||
else:
|
||||
future.set_result(None)
|
||||
|
||||
placeholder = self._tracker.add("image", future, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
|
||||
audio_coro = self._connector.fetch_audio_async(audio_url)
|
||||
def parse_audio(
|
||||
self, audio_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
audio_coro = (
|
||||
self._connector.fetch_audio_async(audio_url) if audio_url else None
|
||||
)
|
||||
|
||||
placeholder = self._tracker.add("audio", audio_coro, uuid)
|
||||
self._add_placeholder("audio", placeholder)
|
||||
|
||||
def parse_input_audio(
|
||||
self, input_audio: InputAudio, uuid: Optional[str] = None
|
||||
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
audio_data = input_audio.get("data", "")
|
||||
audio_format = input_audio.get("format", "")
|
||||
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
|
||||
if input_audio:
|
||||
audio_data = input_audio.get("data", "")
|
||||
audio_format = input_audio.get("format", "")
|
||||
if audio_data:
|
||||
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
|
||||
else:
|
||||
# If a UUID is provided, audio data may be empty.
|
||||
audio_url = None
|
||||
else:
|
||||
audio_url = None
|
||||
|
||||
return self.parse_audio(audio_url, uuid)
|
||||
|
||||
def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
|
||||
video = self._connector.fetch_video_async(video_url=video_url)
|
||||
def parse_video(
|
||||
self, video_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
video = (
|
||||
self._connector.fetch_video_async(video_url=video_url)
|
||||
if video_url
|
||||
else None
|
||||
)
|
||||
|
||||
placeholder = self._tracker.add("video", video, uuid)
|
||||
self._add_placeholder("video", placeholder)
|
||||
@@ -1130,8 +1182,9 @@ def _parse_chat_message_content_mm_part(
|
||||
part, dict
|
||||
) # This is needed to avoid mypy errors: part.get() from str
|
||||
part_type = part.get("type", None)
|
||||
uuid = part.get("uuid", None)
|
||||
|
||||
if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
|
||||
if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None: # noqa: E501
|
||||
content = MM_PARSER_MAP[part_type](part)
|
||||
|
||||
# Special case for 'image_url.detail'
|
||||
@@ -1146,25 +1199,54 @@ def _parse_chat_message_content_mm_part(
|
||||
|
||||
# Handle missing 'type' but provided direct URL fields.
|
||||
# 'type' is required field by pydantic
|
||||
if part_type is None:
|
||||
if part.get("image_url") is not None:
|
||||
if part_type is None or uuid is not None:
|
||||
if "image_url" in part:
|
||||
image_params = cast(
|
||||
CustomChatCompletionContentSimpleImageParam, part
|
||||
)
|
||||
return "image_url", image_params.get("image_url", "")
|
||||
if part.get("audio_url") is not None:
|
||||
image_url = image_params.get("image_url", None)
|
||||
if isinstance(image_url, dict):
|
||||
# Can potentially happen if user provides a uuid
|
||||
# with url as a dict of {"url": url}
|
||||
image_url = image_url.get("url", None)
|
||||
return "image_url", image_url
|
||||
if "image_pil" in part:
|
||||
# "image_pil" could be None if UUID is provided.
|
||||
image_params = cast( # type: ignore
|
||||
CustomChatCompletionContentPILImageParam, part
|
||||
)
|
||||
image_pil = image_params.get("image_pil", None)
|
||||
return "image_pil", image_pil
|
||||
if "image_embeds" in part:
|
||||
# "image_embeds" could be None if UUID is provided.
|
||||
image_params = cast( # type: ignore
|
||||
ChatCompletionContentPartImageEmbedsParam, part
|
||||
)
|
||||
image_embeds = image_params.get("image_embeds", None)
|
||||
return "image_embeds", image_embeds
|
||||
if "audio_url" in part:
|
||||
audio_params = cast(
|
||||
CustomChatCompletionContentSimpleAudioParam, part
|
||||
)
|
||||
return "audio_url", audio_params.get("audio_url", "")
|
||||
audio_url = audio_params.get("audio_url", None)
|
||||
if isinstance(audio_url, dict):
|
||||
# Can potentially happen if user provides a uuid
|
||||
# with url as a dict of {"url": url}
|
||||
audio_url = audio_url.get("url", None)
|
||||
return "audio_url", audio_url
|
||||
if part.get("input_audio") is not None:
|
||||
input_audio_params = cast(dict[str, str], part)
|
||||
return "input_audio", input_audio_params
|
||||
if part.get("video_url") is not None:
|
||||
if "video_url" in part:
|
||||
video_params = cast(
|
||||
CustomChatCompletionContentSimpleVideoParam, part
|
||||
)
|
||||
return "video_url", video_params.get("video_url", "")
|
||||
video_url = video_params.get("video_url", None)
|
||||
if isinstance(video_url, dict):
|
||||
# Can potentially happen if user provides a uuid
|
||||
# with url as a dict of {"url": url}
|
||||
video_url = video_url.get("url", None)
|
||||
return "video_url", video_url
|
||||
# Raise an error if no 'type' or direct URL is found.
|
||||
raise ValueError("Missing 'type' field in multimodal part.")
|
||||
|
||||
@@ -1173,15 +1255,9 @@ def _parse_chat_message_content_mm_part(
|
||||
return part_type, "unknown part_type content"
|
||||
|
||||
|
||||
VALID_MESSAGE_CONTENT_MM_PART_TYPES = (
|
||||
PART_TYPES_TO_SKIP_NONE_CONTENT = (
|
||||
"text",
|
||||
"refusal",
|
||||
"image_url",
|
||||
"image_embeds",
|
||||
"image_pil",
|
||||
"audio_url",
|
||||
"input_audio",
|
||||
"video_url",
|
||||
)
|
||||
|
||||
|
||||
@@ -1242,7 +1318,7 @@ def _parse_chat_message_content_part(
|
||||
part_type, content = _parse_chat_message_content_mm_part(part)
|
||||
# if part_type is text/refusal but
|
||||
# content is None, log a warning and skip
|
||||
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
|
||||
if part_type in PART_TYPES_TO_SKIP_NONE_CONTENT and content is None:
|
||||
logger.warning(
|
||||
"Skipping multimodal part '%s' (type: '%s') "
|
||||
"with empty / unparsable content.",
|
||||
@@ -1266,7 +1342,10 @@ def _parse_chat_message_content_part(
|
||||
|
||||
modality = None
|
||||
if part_type == "image_pil":
|
||||
image_content = cast(Image.Image, content)
|
||||
if content is not None:
|
||||
image_content = cast(Image.Image, content)
|
||||
else:
|
||||
image_content = None
|
||||
mm_parser.parse_image_pil(image_content, uuid)
|
||||
modality = "image"
|
||||
elif part_type in ("image_url", "input_image"):
|
||||
@@ -1274,7 +1353,10 @@ def _parse_chat_message_content_part(
|
||||
mm_parser.parse_image(str_content, uuid)
|
||||
modality = "image"
|
||||
elif part_type == "image_embeds":
|
||||
content = cast(Union[str, dict[str, str]], content)
|
||||
if content is not None:
|
||||
content = cast(Union[str, dict[str, str]], content)
|
||||
else:
|
||||
content = None
|
||||
mm_parser.parse_image_embeds(content, uuid)
|
||||
modality = "image"
|
||||
elif part_type == "audio_url":
|
||||
|
||||
Reference in New Issue
Block a user