[VLM] Generalized prompt updates for multi-modal processor (#13964)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, partial
|
||||
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
|
||||
Union, cast)
|
||||
from typing import List, Optional, Set, Tuple, TypedDict, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -46,8 +46,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
||||
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptReplacementDetails)
|
||||
BaseProcessingInfo, PromptInsertion,
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import JSONTree, json_map_leaves
|
||||
@@ -1190,6 +1190,8 @@ class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
return MolmoProcessorWrapper(processor)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
# TODO: Investigate different `embed_is_patch` between cache/no-cache
|
||||
# in multi-image case
|
||||
return {"image": 1}
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
@@ -1328,25 +1330,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
) -> Sequence[PromptUpdate]:
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
|
||||
image_token_length_w = processor.image_token_length_w
|
||||
image_token_length_h = processor.image_token_length_h
|
||||
pooling_size = processor.pooling_size
|
||||
|
||||
user_str = "User:"
|
||||
if processor.always_start_with_space:
|
||||
user_str = " " + user_str
|
||||
|
||||
user_tokens = tokenizer.encode(user_str, add_special_tokens=False)
|
||||
|
||||
img_patch_id = processor.image_patch_id
|
||||
img_col_id = processor.im_col_id
|
||||
img_start_id = processor.im_start_id
|
||||
@@ -1356,7 +1351,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
extra_joint = ([img_start_id] + extra_row * image_token_length_h +
|
||||
[img_end_id])
|
||||
|
||||
def get_replacement_molmo(item_idx: int):
|
||||
def get_insertion_molmo(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
|
||||
@@ -1371,17 +1366,13 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
((nrows + 1) // pooling_size) + [img_end_id])
|
||||
|
||||
image_tokens = extra_joint + joint
|
||||
|
||||
return PromptReplacementDetails(
|
||||
full=image_tokens + user_tokens,
|
||||
features=image_tokens,
|
||||
)
|
||||
return image_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
PromptInsertion(
|
||||
modality="image",
|
||||
target=user_str,
|
||||
replacement=get_replacement_molmo,
|
||||
target="<|endoftext|>",
|
||||
insertion=get_insertion_molmo,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user