[VLM] Generalized prompt updates for multi-modal processor (#13964)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-02-28 01:44:25 +08:00
committed by GitHub
parent 7864875879
commit f1579b229d
29 changed files with 629 additions and 486 deletions

View File

@@ -1,10 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union, cast)
from typing import List, Optional, Set, Tuple, TypedDict, Union, cast
import numpy as np
import torch
@@ -46,8 +46,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
BaseProcessingInfo, PromptInsertion,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, json_map_leaves
@@ -1190,6 +1190,8 @@ class MolmoProcessingInfo(BaseProcessingInfo):
return MolmoProcessorWrapper(processor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
# Returns a mapping from modality name to its supported limit; here the
# "image" modality is assigned a limit of 1.
# NOTE(review): the limit presumably stays at 1 until the TODO below is
# resolved -- confirm before raising it.
# TODO: Investigate different `embed_is_patch` between cache/no-cache
# in multi-image case
return {"image": 1}
def get_mm_max_tokens_per_item(
@@ -1328,25 +1330,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
image_token_length_w = processor.image_token_length_w
image_token_length_h = processor.image_token_length_h
pooling_size = processor.pooling_size
user_str = "User:"
if processor.always_start_with_space:
user_str = " " + user_str
user_tokens = tokenizer.encode(user_str, add_special_tokens=False)
img_patch_id = processor.image_patch_id
img_col_id = processor.im_col_id
img_start_id = processor.im_start_id
@@ -1356,7 +1351,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
extra_joint = ([img_start_id] + extra_row * image_token_length_h +
[img_end_id])
def get_replacement_molmo(item_idx: int):
def get_insertion_molmo(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
@@ -1371,17 +1366,13 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
((nrows + 1) // pooling_size) + [img_end_id])
image_tokens = extra_joint + joint
return PromptReplacementDetails(
full=image_tokens + user_tokens,
features=image_tokens,
)
return image_tokens
return [
PromptReplacement(
PromptInsertion(
modality="image",
target=user_str,
replacement=get_replacement_molmo,
target="<|endoftext|>",
insertion=get_insertion_molmo,
)
]