[VLM] Support caching in merged multi-modal processor (#11396)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -1,5 +1,4 @@
 from functools import cached_property
-from types import MethodType
 from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
                     Tuple, TypedDict, Union)
 
@@ -7,7 +6,7 @@ import torch
 import torch.nn as nn
 from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                           PixtralVisionConfig, PretrainedConfig,
-                          ProcessorMixin, SiglipVisionConfig)
+                          SiglipVisionConfig)
 from transformers.models.llava import LlavaProcessor
 from transformers.models.pixtral import PixtralProcessor
 
@@ -21,10 +20,12 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
+                                    MultiModalFieldConfig, MultiModalInputsV2,
+                                    MultiModalKwargs, NestedTensors)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
-                                        PromptReplacement)
+                                        ProcessorInputs, PromptReplacement,
+                                        full_groupby_modality)
 from vllm.sequence import IntermediateTensors
 
 from .clip import (CLIPVisionModel, dummy_image_for_clip,
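full_groupby_modality, newly imported above, is used in MantisMultiModalProcessor.apply at the end of this diff to bucket placeholder entries by modality. A standalone toy sketch of its behavior, where the Placeholder dataclass and the grouping body are illustrative assumptions rather than vLLM's actual implementation:

    from collections import defaultdict
    from dataclasses import dataclass

    @dataclass
    class Placeholder:  # stand-in for vLLM's placeholder objects
        modality: str
        offset: int

    def full_groupby_modality(items):
        # Group items by modality, preserving insertion order within groups.
        groups = defaultdict(list)
        for item in items:
            groups[item.modality].append(item)
        return groups.items()

    phs = [Placeholder("image", 0), Placeholder("image", 600)]
    print({m: [p.offset for p in ps] for m, ps in full_groupby_modality(phs)})
    # {'image': [0, 600]}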
@@ -116,36 +117,54 @@ def get_max_llava_image_tokens(ctx: InputContext):
 
 class LlavaMultiModalProcessor(BaseMultiModalProcessor):
 
-    def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
-        if getattr(hf_processor, "__is_patched__", False):
-            return  # Already patched
-
-        image_processor = hf_processor.image_processor  # type: ignore
-        orig_preprocess = image_processor.preprocess
-
-        def preprocess(__self, *args, **kwargs):
-            hf_inputs = orig_preprocess(*args, **kwargs)
-            hf_inputs["is_pixtral"] = torch.tensor(True)
-            return hf_inputs
-
-        image_processor.preprocess = MethodType(preprocess, image_processor)
-
-        hf_processor.__is_patched__ = True  # type: ignore
-
     def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
-        hf_processor = self.ctx.get_hf_processor(
-            (LlavaProcessor, PixtralProcessor))
-
-        if isinstance(hf_processor, PixtralProcessor):
-            self._patch_pixtral_processor(hf_processor)
-
-        return hf_processor
+        return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        processed_outputs = super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
+        # NOTE: pixel_values=None for MLlavaProcessor
+        pixel_values = processed_outputs.get("pixel_values")
+        if pixel_values is not None:
+            images = mm_data["images"]
+            assert isinstance(images, list)
+
+            if isinstance(self._get_hf_processor(), PixtralProcessor):
+                # Original output: (1, num_images, C, H, W)
+                # New output: (num_images, C, H, W)
+                assert (isinstance(pixel_values, list)
+                        and len(pixel_values) == 1
+                        and isinstance(pixel_values[0], list)
+                        and len(pixel_values[0]) == len(images))
+
+                processed_outputs["pixel_values"] = pixel_values[0]
+
+        return processed_outputs
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
 
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
-        hf_inputs: BatchFeature,
-        mm_processor_kwargs: Mapping[str, object],
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
         hf_config = self.ctx.get_hf_config(LlavaConfig)
         image_token_id = hf_config.image_token_index
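The Pixtral branch of _call_hf_processor above un-nests pixel_values so that downstream code sees one tensor per image. A standalone toy run with made-up shapes:

    import torch

    images = ["img0", "img1"]  # stand-ins for two input images
    # HF processor output: (1, num_images, C, H, W) as nested lists
    pixel_values = [[torch.zeros(3, 32, 48), torch.zeros(3, 64, 16)]]

    assert (isinstance(pixel_values, list)
            and len(pixel_values) == 1
            and isinstance(pixel_values[0], list)
            and len(pixel_values[0]) == len(images))

    # Un-nested: one (C, H, W) tensor per image
    pixel_values = pixel_values[0]
    print([tuple(t.shape) for t in pixel_values])  # [(3, 32, 48), (3, 64, 16)]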
@@ -200,7 +219,7 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
     ) -> ProcessorInputs:
         hf_config = self.ctx.get_hf_config(LlavaConfig)
         vision_config = hf_config.vision_config
-        num_images = mm_counts["image"]
+        num_images = mm_counts.get("image", 0)
 
         if isinstance(vision_config, CLIPVisionConfig):
             data = dummy_image_for_clip(vision_config, num_images)
@@ -218,7 +237,6 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
         return ProcessorInputs(
             prompt_text=image_token * num_images,
             mm_data=data,
-            mm_processor_kwargs={},
         )
 
 
@@ -379,7 +397,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
-        is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
         image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values is None and image_embeds is None:
@@ -390,33 +407,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
-            assert isinstance(is_pixtral, torch.Tensor)
-            if is_pixtral.any():
-                images = pixel_values
-
-                def flatten_to_3d_tensors(item):
-                    if isinstance(item, torch.Tensor):
-                        if item.dim() >= 3:
-                            return [t for t in item.view(-1, *item.shape[-3:])]
-                        else:
-                            raise ValueError(
-                                f"Unexpected tensor dimension: {item.dim()}")
-                    elif isinstance(item, list):
-                        return [
-                            t for subitem in item
-                            for t in flatten_to_3d_tensors(subitem)
-                        ]
-                    else:
-                        raise ValueError(f"Unexpected type: {type(item)}")
-
-                # Restructure the batched images into a list of lists of images
-                images = flatten_to_3d_tensors(pixel_values)
-
-                return LlavaImagePixelInputs(
-                    type="pixel_values",
-                    data=images,
-                )
-
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
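The removed flatten_to_3d_tensors helper was self-contained, so a standalone toy run (made-up shapes) shows the per-image flattening that the model previously did at runtime and that the merged processor now performs once, up front, in _call_hf_processor:

    import torch

    def flatten_to_3d_tensors(item):
        # Recursively flatten nested lists/tensors into (C, H, W) tensors.
        if isinstance(item, torch.Tensor):
            if item.dim() >= 3:
                return [t for t in item.view(-1, *item.shape[-3:])]
            raise ValueError(f"Unexpected tensor dimension: {item.dim()}")
        if isinstance(item, list):
            return [t for sub in item for t in flatten_to_3d_tensors(sub)]
        raise ValueError(f"Unexpected type: {type(item)}")

    batched = [torch.zeros(1, 2, 3, 32, 48), [torch.zeros(3, 16, 64)]]
    print([tuple(t.shape) for t in flatten_to_3d_tensors(batched)])
    # [(3, 32, 48), (3, 32, 48), (3, 16, 64)]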
@@ -586,19 +576,71 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
 class MantisMultiModalProcessor(LlavaMultiModalProcessor):
 
-    def _get_hf_processor(self) -> ProcessorMixin:
-        try:
-            from mantis.models.mllava import MLlavaProcessor
-        except ModuleNotFoundError as exc:
-            raise ModuleNotFoundError(
-                "You need to `pip install "
-                "git+https://github.com/TIGER-AI-Lab/Mantis.git` "
-                "to use this model") from exc
-
-        processor = MLlavaProcessor.from_pretrained(
-            self.ctx.model_config.tokenizer)
-        assert isinstance(processor, ProcessorMixin)
-        return processor
+    def _get_hf_processor(self):
+        return self.ctx.get_hf_processor(LlavaProcessor)
+
+    def apply(
+        self,
+        prompt_text: str,
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputsV2:
+        hf_config = self.ctx.get_hf_config(LlavaConfig)
+        image_token_id = hf_config.image_token_index
+        max_image_tokens = get_max_llava_image_tokens(self.ctx)
+
+        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+
+        mm_items = self._get_mm_items(mm_data)
+        mm_item_counts = mm_items.get_item_counts()
+        mm_kwargs = result["mm_kwargs"]
+
+        # We reimplement the functionality of MLlavaProcessor from
+        # https://github.com/TIGER-AI-Lab/Mantis.git
+        def get_replacement_mantis(item_idx: int):
+            return "".join([
+                f"(image {item_idx+1}: <Image>",  # 7 tokens
+                "<image>" * max_image_tokens,
+                "</Image>)",  # 3 tokens
+            ])
+
+        mantis_repls = self._bind_prompt_replacements([
+            PromptReplacement(
+                modality="image",
+                target=[image_token_id] * max_image_tokens,
+                replacement=get_replacement_mantis,
+            )
+        ])
+
+        prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
+            result["prompt_token_ids"],
+            mantis_repls,
+            mm_item_counts,
+        )
+
+        unbound_orig_repls = self._get_prompt_replacements(
+            mm_items,
+            hf_processor_mm_kwargs,
+            mm_kwargs,
+        )
+        orig_repls = self._bind_prompt_replacements(unbound_orig_repls)
+
+        all_placeholders = self._find_placeholders(orig_repls, prompt_ids,
+                                                   mm_item_counts)
+        assert len(all_placeholders) == mm_item_counts.get("image", 0)
+
+        mm_placeholders = {
+            modality: [item.to_range() for item in items]
+            for modality, items in full_groupby_modality(all_placeholders)
+        }
+
+        return MultiModalInputsV2(
+            type="multimodal",
+            prompt=prompt_text,
+            prompt_token_ids=prompt_ids,
+            mm_kwargs=mm_kwargs,
+            mm_placeholders=mm_placeholders,
+        )
 
 
 # To use this model, please use
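A standalone toy run of get_replacement_mantis from the hunk above, with max_image_tokens shrunk to 4 for readability (the real value comes from get_max_llava_image_tokens), showing the Mantis-style text that replaces each run of image tokens:

    max_image_tokens = 4

    def get_replacement_mantis(item_idx: int) -> str:
        # Wrap the image tokens in Mantis's "(image N: <Image>...</Image>)" frame.
        return "".join([
            f"(image {item_idx+1}: <Image>",
            "<image>" * max_image_tokens,
            "</Image>)",
        ])

    print(get_replacement_mantis(0))
    # (image 1: <Image><image><image><image><image></Image>)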