[Refactor] Dynamic target and content for prompt updates (#23411)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -22,10 +22,12 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
# yapf: disable
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, BoundPromptUpdate,
|
||||
BaseProcessingInfo,
|
||||
MultiModalPromptUpdates,
|
||||
MultiModalPromptUpdatesApplyResult,
|
||||
PlaceholderFeaturesInfo,
|
||||
PromptReplacement, PromptTargetMatch,
|
||||
PromptUpdate, PromptUpdateDetails,
|
||||
PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
find_mm_placeholders,
|
||||
replace_token_matches)
|
||||
# yapf: enable
|
||||
@@ -337,14 +339,10 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
def _apply_token_matches(
|
||||
self,
|
||||
prompt: list[int],
|
||||
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
) -> list[int]:
|
||||
token_ids = super()._apply_token_matches(
|
||||
prompt,
|
||||
mm_matches,
|
||||
mm_item_counts,
|
||||
)
|
||||
mm_prompt_updates: MultiModalPromptUpdates,
|
||||
) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
|
||||
token_ids, res = super()._apply_token_matches(prompt,
|
||||
mm_prompt_updates)
|
||||
|
||||
# "\n\n\n" and "\n\n\n\n" are single tokens
|
||||
# Since our replacement can insert "\n\n" next to "\n"
|
||||
@@ -373,13 +371,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
[newline_4],
|
||||
)
|
||||
|
||||
return token_ids
|
||||
return token_ids, res
|
||||
|
||||
def _find_mm_placeholders(
|
||||
self,
|
||||
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
|
||||
new_token_ids: list[int],
|
||||
mm_item_counts: Mapping[str, int],
|
||||
mm_prompt_updates: MultiModalPromptUpdates,
|
||||
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
|
||||
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
@@ -404,8 +401,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
repl_token_ids.extend(repl_toks)
|
||||
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
|
||||
|
||||
repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
|
||||
mm_item_counts)
|
||||
repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
|
||||
|
||||
return {
|
||||
modality: [
|
||||
|
||||
Reference in New Issue
Block a user