[V1] Support interleaved modality items (#15605)
Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
@@ -234,22 +234,11 @@ class Processor:
|
||||
if decoder_inputs["type"] == "multimodal":
|
||||
decoder_mm_inputs = decoder_inputs["mm_kwargs"]
|
||||
|
||||
# The output of merged multi-modal processor (`decoder_mm_inputs`)
|
||||
# contains the kwargs for all items from all modalities.
|
||||
# This code separates them so that there is one set of kwargs
|
||||
# per item per modality.
|
||||
individual_mm_inputs = [
|
||||
MultiModalKwargs.from_items([item])
|
||||
for modality in decoder_mm_inputs.modalities
|
||||
for item in decoder_mm_inputs.get_items(modality)
|
||||
]
|
||||
|
||||
# Merge and flatten multimodal placeholders, hashes and inputs
|
||||
# from dictionaries to lists, and sort them by each item's position
|
||||
# in the input sequence.
|
||||
# NOTE: interleaved modalities are not supported.
|
||||
(
|
||||
sorted_modalities,
|
||||
sorted_item_modalities,
|
||||
sorted_mm_positions,
|
||||
sorted_mm_hashes,
|
||||
) = merge_and_sort_multimodal_metadata(
|
||||
@@ -257,26 +246,26 @@ class Processor:
|
||||
decoder_inputs["mm_hashes"] if self.use_hash else None,
|
||||
)
|
||||
|
||||
# NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
|
||||
# modalities involved.
|
||||
if len(sorted_modalities) > 1:
|
||||
modality_order_dict = {
|
||||
modality: order
|
||||
for order, modality in enumerate(sorted_modalities)
|
||||
}
|
||||
|
||||
# Sanity check to make sure each multimodal input has only one
|
||||
# modality key.
|
||||
for mm_input in individual_mm_inputs:
|
||||
assert len(mm_input.modalities) == 1
|
||||
|
||||
# Sort MultiModalKwargs to match sorted_mm_positions
|
||||
sorted_mm_inputs = sorted(
|
||||
individual_mm_inputs,
|
||||
key=lambda mm_input: modality_order_dict[list(
|
||||
mm_input.modalities)[0]])
|
||||
# The output of merged multi-modal processor (`decoder_mm_inputs`)
|
||||
# is a single MultiModalKwargs for all items from all modalities.
|
||||
# This code flattens kwargs for individual items in a list and
|
||||
# sorts them by each item's position in the input sequence if there
|
||||
# are multiple modalities.
|
||||
unique_modalities = set(sorted_item_modalities)
|
||||
if len(unique_modalities) > 1:
|
||||
sorted_mm_inputs = []
|
||||
used_indices = {modality: 0 for modality in unique_modalities}
|
||||
for modality in sorted_item_modalities:
|
||||
items = decoder_mm_inputs.get_items(modality)
|
||||
item = items[used_indices[modality]]
|
||||
sorted_mm_inputs.append(MultiModalKwargs.from_items([item
|
||||
]))
|
||||
used_indices[modality] += 1
|
||||
else:
|
||||
sorted_mm_inputs = individual_mm_inputs
|
||||
sorted_mm_inputs = [
|
||||
MultiModalKwargs.from_items([item]) for item in
|
||||
decoder_mm_inputs.get_items(sorted_item_modalities[0])
|
||||
]
|
||||
|
||||
return EngineCoreRequest(
|
||||
request_id=request_id,
|
||||
|
||||
Reference in New Issue
Block a user