[V1] Support interleaved modality items (#15605)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang
2025-03-29 06:30:09 -07:00
committed by GitHub
parent 6fa7cd3dbc
commit c67abd614f
6 changed files with 205 additions and 115 deletions

View File

@@ -234,22 +234,11 @@ class Processor:
if decoder_inputs["type"] == "multimodal":
decoder_mm_inputs = decoder_inputs["mm_kwargs"]
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# contains the kwargs for all items from all modalities.
# This code separates them so that there is one set of kwargs
# per item per modality.
individual_mm_inputs = [
MultiModalKwargs.from_items([item])
for modality in decoder_mm_inputs.modalities
for item in decoder_mm_inputs.get_items(modality)
]
# Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position
# in the input sequence.
# NOTE: interleaved modalities are not supported.
(
sorted_modalities,
sorted_item_modalities,
sorted_mm_positions,
sorted_mm_hashes,
) = merge_and_sort_multimodal_metadata(
@@ -257,26 +246,26 @@ class Processor:
decoder_inputs["mm_hashes"] if self.use_hash else None,
)
# NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
# modalities involved.
if len(sorted_modalities) > 1:
modality_order_dict = {
modality: order
for order, modality in enumerate(sorted_modalities)
}
# Sanity check to make sure each multimodal input has only one
# modality key.
for mm_input in individual_mm_inputs:
assert len(mm_input.modalities) == 1
# Sort MultiModalKwargs to match sorted_mm_positions
sorted_mm_inputs = sorted(
individual_mm_inputs,
key=lambda mm_input: modality_order_dict[list(
mm_input.modalities)[0]])
# The output of merged multi-modal processor (`decoder_mm_inputs`)
# is a single MultiModalKwargs for all items from all modalities.
# This code flattens kwargs for individual items in a list and
# sorts them by each item's position in the input sequence if there
# are multiple modalities.
unique_modalities = set(sorted_item_modalities)
if len(unique_modalities) > 1:
sorted_mm_inputs = []
used_indices = {modality: 0 for modality in unique_modalities}
for modality in sorted_item_modalities:
items = decoder_mm_inputs.get_items(modality)
item = items[used_indices[modality]]
sorted_mm_inputs.append(MultiModalKwargs.from_items([item
]))
used_indices[modality] += 1
else:
sorted_mm_inputs = individual_mm_inputs
sorted_mm_inputs = [
MultiModalKwargs.from_items([item]) for item in
decoder_mm_inputs.get_items(sorted_item_modalities[0])
]
return EngineCoreRequest(
request_id=request_id,