diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 85a03efd5..691eff9ac 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -727,18 +727,35 @@ def _find_matches( return mode, matches_to_apply +def _all_items_found( + mm_item_counts: dict[str, int], + mm_found_counts: dict[str, int], +) -> bool: + return all( + item_idx >= mm_item_counts[modality] + for modality, item_idx in mm_found_counts.items() + ) + + def _apply_matches( prompt: _S, mm_prompt_updates: "MultiModalPromptUpdates", tokenizer: AnyTokenizer, ) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]: prompt_len = len(prompt) + mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()} out_seqs = list[str | list[int]]() out_result: MultiModalPromptUpdatesApplyResult = { m: [None] * len(items) for m, items in mm_prompt_updates.items() } + mm_found_counts = { + m: sum(r is not None for r in res) for m, res in out_result.items() + } + if _all_items_found(mm_item_counts, mm_found_counts): + return [prompt], out_result + start_idx = prev_end_idx = 0 while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt found = False @@ -776,6 +793,12 @@ def _apply_matches( # Exclude overlapping matches start_idx = prev_end_idx = match.end_idx + mm_found_counts = { + m: sum(r is not None for r in res) for m, res in out_result.items() + } + if _all_items_found(mm_item_counts, mm_found_counts): + break + if not found: start_idx += 1 @@ -832,12 +855,15 @@ def _iter_placeholders( Note that empty matches are ignored. """ - prompt_len = len(prompt) mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()} + item_idx_by_modality = {modality: 0 for modality in mm_prompt_updates} - item_idx_by_modality = defaultdict[str, int](lambda: 0) + if _all_items_found(mm_item_counts, item_idx_by_modality): + return + prompt_len = len(prompt) start_idx = 0 + while start_idx < prompt_len: found = False @@ -875,6 +901,9 @@ def _iter_placeholders( break if found: + if _all_items_found(mm_item_counts, item_idx_by_modality): + return + break # Go back to the outer while loop if not found: