[Refactor] Dynamic target and content for prompt updates (#23411)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-08-25 14:39:58 +08:00
committed by GitHub
parent 49ab23b3cc
commit 712d0f88d8
9 changed files with 463 additions and 456 deletions

View File

@@ -22,10 +22,12 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BoundPromptUpdate,
BaseProcessingInfo,
MultiModalPromptUpdates,
MultiModalPromptUpdatesApplyResult,
PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch,
PromptUpdate, PromptUpdateDetails,
PromptReplacement, PromptUpdate,
PromptUpdateDetails,
find_mm_placeholders,
replace_token_matches)
# yapf: enable
@@ -337,14 +339,10 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
def _apply_token_matches(
self,
prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[int]:
token_ids = super()._apply_token_matches(
prompt,
mm_matches,
mm_item_counts,
)
mm_prompt_updates: MultiModalPromptUpdates,
) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
token_ids, res = super()._apply_token_matches(prompt,
mm_prompt_updates)
# "\n\n\n" and "\n\n\n\n" are single tokens
# Since our replacement can insert "\n\n" next to "\n"
@@ -373,13 +371,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
[newline_4],
)
return token_ids
return token_ids, res
def _find_mm_placeholders(
self,
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int],
mm_item_counts: Mapping[str, int],
mm_prompt_updates: MultiModalPromptUpdates,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
tokenizer = self.info.get_tokenizer()
@@ -404,8 +401,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
mm_item_counts)
repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
return {
modality: [