[VLM] Generalized prompt updates for multi-modal processor (#13964)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-02-28 01:44:25 +08:00
committed by GitHub
parent 7864875879
commit f1579b229d
29 changed files with 629 additions and 486 deletions

View File

@@ -7,9 +7,10 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from abc import ABC, abstractmethod
+from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-TypedDict, TypeVar, Union)
+from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar,
+Union)
import torch
import torch.nn as nn
@@ -31,7 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
-PromptReplacementDetails)
+PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -599,12 +600,12 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs:
@@ -636,7 +637,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
if num_patches is not None:
assert isinstance(num_patches, int)
-return PromptReplacementDetails(
+return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size,
num_patches),
features=hf_processor.get_image_repl_features(