Support multiple image/audio embeddings per requests (#29988)

Signed-off-by: Jeremy Teboul <jeremyteboul@fb.com>
Co-authored-by: Jeremy Teboul <jeremyteboul@fb.com>
This commit is contained in:
jeremyteboul
2025-12-06 20:34:24 -08:00
committed by GitHub
parent cbedb703cc
commit dce6d229f7
3 changed files with 198 additions and 20 deletions

View File

@@ -694,16 +694,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise ValueError("Mixing raw image and embedding inputs is not allowed")
if "image_embeds" in uuids_by_modality:
image_embeds_uuids = uuids_by_modality["image_embeds"]
if len(image_embeds_uuids) > 1:
raise ValueError("Only one message can have {'type': 'image_embeds'}")
mm_uuids["image"] = uuids_by_modality["image_embeds"]
if "image" in uuids_by_modality:
mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images
if "audio_embeds" in uuids_by_modality:
audio_embeds_uuids = uuids_by_modality["audio_embeds"]
if len(audio_embeds_uuids) > 1:
raise ValueError("Only one message can have {'type': 'audio_embeds'}")
mm_uuids["audio"] = uuids_by_modality["audio_embeds"]
if "audio" in uuids_by_modality:
mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios
@@ -729,16 +723,16 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
if len(image_embeds_lst) > 1:
raise ValueError("Only one message can have {'type': 'image_embeds'}")
mm_inputs["image"] = image_embeds_lst[0]
mm_inputs["image"] = (
image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
)
if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio_embeds" in items_by_modality:
audio_embeds_lst = items_by_modality["audio_embeds"]
if len(audio_embeds_lst) > 1:
raise ValueError("Only one message can have {'type': 'audio_embeds'}")
mm_inputs["audio"] = audio_embeds_lst[0]
mm_inputs["audio"] = (
audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
)
if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality:
@@ -771,16 +765,16 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
if len(image_embeds_lst) > 1:
raise ValueError("Only one message can have {'type': 'image_embeds'}")
mm_inputs["image"] = image_embeds_lst[0]
mm_inputs["image"] = (
image_embeds_lst if len(image_embeds_lst) != 1 else image_embeds_lst[0]
)
if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio_embeds" in items_by_modality:
audio_embeds_lst = items_by_modality["audio_embeds"]
if len(audio_embeds_lst) > 1:
raise ValueError("Only one message can have {'type': 'audio_embeds'}")
mm_inputs["audio"] = audio_embeds_lst[0]
mm_inputs["audio"] = (
audio_embeds_lst if len(audio_embeds_lst) != 1 else audio_embeds_lst[0]
)
if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality: