[VLM] Generalized prompt updates for multi-modal processor (#13964)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
|
||||
Protocol, Set, Tuple, TypedDict, TypeVar, Union)
|
||||
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
|
||||
TypedDict, TypeVar, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -31,7 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, ProcessingCache,
|
||||
PromptReplacement)
|
||||
PromptReplacement, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@@ -222,12 +223,12 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_prompt_replacements(
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
|
||||
@@ -328,12 +329,12 @@ class PixtralHFMultiModalProcessor(
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
) -> Sequence[PromptUpdate]:
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
hf_config = self.info.get_hf_config()
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
@@ -789,7 +790,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
"</Image>)", # 3 tokens
|
||||
])
|
||||
|
||||
mantis_mm_repls = self._bind_and_group_repls([
|
||||
mantis_mm_repls = self._bind_and_group_updates([
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id] * num_image_tokens,
|
||||
@@ -797,18 +798,18 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
)
|
||||
])
|
||||
|
||||
prompt_ids, prompt, _ = self._apply_prompt_replacements(
|
||||
prompt_ids, prompt, _ = self._apply_prompt_updates(
|
||||
result["prompt_token_ids"],
|
||||
mantis_mm_repls,
|
||||
mm_item_counts,
|
||||
)
|
||||
|
||||
unbound_orig_repls = self._get_prompt_replacements(
|
||||
unbound_orig_repls = self._get_prompt_updates(
|
||||
mm_items,
|
||||
hf_processor_mm_kwargs,
|
||||
mm_kwargs,
|
||||
)
|
||||
orig_repls = self._bind_and_group_repls(unbound_orig_repls)
|
||||
orig_repls = self._bind_and_group_updates(unbound_orig_repls)
|
||||
|
||||
mm_placeholders = self._find_mm_placeholders(
|
||||
orig_repls,
|
||||
|
||||
Reference in New Issue
Block a user