[VLM] Support caching in merged multi-modal processor (#11396)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
commit 101418096f
parent 5ce4627a7e
Author: Cyrus Leung
Date: 2024-12-28 01:22:48 +08:00
Committed by: GitHub
20 changed files with 1459 additions and 452 deletions
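
The commit title is terse; the technique it refers to is memoizing the HF processor outputs for repeated multi-modal items, so that identical images (plus identical processor kwargs) only pay the preprocessing cost once. Below is a minimal standalone sketch of that idea — all names here (`ProcessorOutputCache`, `_hash_item`, `get_or_process`) are hypothetical illustrations, not vLLM's actual caching API:

```python
import hashlib
import pickle
from typing import Any, Callable


class ProcessorOutputCache:
    """Hypothetical sketch: memoize per-item processor outputs by content hash."""

    def __init__(self) -> None:
        self._cache: dict[str, Any] = {}

    def _hash_item(self, item: Any, kwargs: dict[str, Any]) -> str:
        # Hash the serialized item + kwargs. A real implementation would use
        # a stable serialization (e.g. raw image bytes) rather than pickle.
        data = pickle.dumps((item, sorted(kwargs.items())))
        return hashlib.blake2b(data).hexdigest()

    def get_or_process(
        self,
        item: Any,
        kwargs: dict[str, Any],
        process: Callable[[Any], Any],
    ) -> Any:
        key = self._hash_item(item, kwargs)
        if key not in self._cache:
            # Only invoke the (expensive) HF processor on a cache miss.
            self._cache[key] = process(item)
        return self._cache[key]
```

Caching per item, rather than per request, is what motivates the new `_get_mm_fields_config` hook in the diff below: the processor outputs must first be attributable to individual images before they can be cached and re-batched.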

vllm/model_executor/models/llava.py

@@ -1,5 +1,4 @@
 from functools import cached_property
-from types import MethodType
 from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
                     Tuple, TypedDict, Union)
@@ -7,7 +6,7 @@ import torch
 import torch.nn as nn
 from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                           PixtralVisionConfig, PretrainedConfig,
-                          ProcessorMixin, SiglipVisionConfig)
+                          SiglipVisionConfig)
 from transformers.models.llava import LlavaProcessor
 from transformers.models.pixtral import PixtralProcessor
@@ -21,10 +20,12 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
-                                        PromptReplacement)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
+                                    MultiModalFieldConfig, MultiModalInputsV2,
+                                    MultiModalKwargs, NestedTensors)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        ProcessorInputs, PromptReplacement,
+                                        full_groupby_modality)
 from vllm.sequence import IntermediateTensors

 from .clip import (CLIPVisionModel, dummy_image_for_clip,
@@ -116,36 +117,54 @@ def get_max_llava_image_tokens(ctx: InputContext):
 class LlavaMultiModalProcessor(BaseMultiModalProcessor):

-    def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
-        if getattr(hf_processor, "__is_patched__", False):
-            return  # Already patched
-
-        image_processor = hf_processor.image_processor  # type: ignore
-        orig_preprocess = image_processor.preprocess
-
-        def preprocess(__self, *args, **kwargs):
-            hf_inputs = orig_preprocess(*args, **kwargs)
-            hf_inputs["is_pixtral"] = torch.tensor(True)
-            return hf_inputs
-
-        image_processor.preprocess = MethodType(preprocess, image_processor)
-
-        hf_processor.__is_patched__ = True  # type: ignore
-
     def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
-        hf_processor = self.ctx.get_hf_processor(
-            (LlavaProcessor, PixtralProcessor))
-
-        if isinstance(hf_processor, PixtralProcessor):
-            self._patch_pixtral_processor(hf_processor)
-
-        return hf_processor
+        return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        processed_outputs = super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
+        # NOTE: pixel_values=None for MLlavaProcessor
+        pixel_values = processed_outputs.get("pixel_values")
+        if pixel_values is not None:
+            images = mm_data["images"]
+            assert isinstance(images, list)
+
+            if isinstance(self._get_hf_processor(), PixtralProcessor):
+                # Original output: (1, num_images, C, H, W)
+                # New output: (num_images, C, H, W)
+                assert (isinstance(pixel_values, list)
+                        and len(pixel_values) == 1
+                        and isinstance(pixel_values[0], list)
+                        and len(pixel_values[0]) == len(images))
+
+                processed_outputs["pixel_values"] = pixel_values[0]
+
+        return processed_outputs
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )

     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
-        hf_inputs: BatchFeature,
-        mm_processor_kwargs: Mapping[str, object],
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
         hf_config = self.ctx.get_hf_config(LlavaConfig)
         image_token_id = hf_config.image_token_index
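
The new `_get_mm_fields_config` hook added above declares how each key of the HF processor output is split across individual multi-modal items: `MultiModalFieldConfig.batched("image")` means element i along the leading dimension belongs to image i, which is the granularity the cache operates at. A rough standalone illustration of that batched-field semantics (hypothetical helper, not vLLM internals):

```python
import torch


def split_batched_field(batched, num_items: int):
    """Sketch of "batched" field semantics: element i along the leading
    dimension is the slice belonging to multi-modal item i."""
    assert len(batched) == num_items
    return [batched[i] for i in range(num_items)]


# e.g. a CLIP-style output for 2 images, each (3, 336, 336):
pixel_values = torch.rand(2, 3, 336, 336)
per_image = split_batched_field(pixel_values, num_items=2)
assert per_image[0].shape == (3, 336, 336)
```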
@@ -200,7 +219,7 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
     ) -> ProcessorInputs:
         hf_config = self.ctx.get_hf_config(LlavaConfig)
         vision_config = hf_config.vision_config
-        num_images = mm_counts["image"]
+        num_images = mm_counts.get("image", 0)

         if isinstance(vision_config, CLIPVisionConfig):
             data = dummy_image_for_clip(vision_config, num_images)
@@ -218,7 +237,6 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
         return ProcessorInputs(
             prompt_text=image_token * num_images,
             mm_data=data,
-            mm_processor_kwargs={},
         )
@@ -379,7 +397,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
-        is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
         image_embeds = kwargs.pop("image_embeds", None)

         if pixel_values is None and image_embeds is None:
@@ -390,33 +407,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
assert isinstance(is_pixtral, torch.Tensor)
if is_pixtral.any():
images = pixel_values
def flatten_to_3d_tensors(item):
if isinstance(item, torch.Tensor):
if item.dim() >= 3:
return [t for t in item.view(-1, *item.shape[-3:])]
else:
raise ValueError(
f"Unexpected tensor dimension: {item.dim()}")
elif isinstance(item, list):
return [
t for subitem in item
for t in flatten_to_3d_tensors(subitem)
]
else:
raise ValueError(f"Unexpected type: {type(item)}")
# Restructure the batched images into a list of lists of images
images = flatten_to_3d_tensors(pixel_values)
return LlavaImagePixelInputs(
type="pixel_values",
data=images,
)
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
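
The deleted `is_pixtral` branch is no longer needed because `_call_hf_processor` above already unwraps Pixtral's extra batch dimension at processing time, so the model now receives a plain list of per-image tensors. A small sketch of that unwrapping in isolation (shapes are illustrative only):

```python
import torch

# Pixtral's HF image processor nests pixel_values as (1, num_images, C, H, W);
# the new _call_hf_processor strips the singleton batch dim up front.
num_images = 2
nested = [[torch.rand(3, 64, 80) for _ in range(num_images)]]  # per-image H/W may vary
assert len(nested) == 1 and len(nested[0]) == num_images
pixel_values = nested[0]  # now: a list of num_images (C, H, W) tensors
```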
@@ -586,19 +576,71 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 class MantisMultiModalProcessor(LlavaMultiModalProcessor):

-    def _get_hf_processor(self) -> ProcessorMixin:
-        try:
-            from mantis.models.mllava import MLlavaProcessor
-        except ModuleNotFoundError as exc:
-            raise ModuleNotFoundError(
-                "You need to `pip install "
-                "git+https://github.com/TIGER-AI-Lab/Mantis.git` "
-                "to use this model") from exc
-
-        processor = MLlavaProcessor.from_pretrained(
-            self.ctx.model_config.tokenizer)
-        assert isinstance(processor, ProcessorMixin)
-        return processor
+    def _get_hf_processor(self):
+        return self.ctx.get_hf_processor(LlavaProcessor)
+
+    def apply(
+        self,
+        prompt_text: str,
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputsV2:
+        hf_config = self.ctx.get_hf_config(LlavaConfig)
+        image_token_id = hf_config.image_token_index
+        max_image_tokens = get_max_llava_image_tokens(self.ctx)
+
+        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+
+        mm_items = self._get_mm_items(mm_data)
+        mm_item_counts = mm_items.get_item_counts()
+        mm_kwargs = result["mm_kwargs"]
+
+        # We reimplement the functionality of MLlavaProcessor from
+        # https://github.com/TIGER-AI-Lab/Mantis.git
+        def get_replacement_mantis(item_idx: int):
+            return "".join([
+                f"(image {item_idx+1}: <Image>",  # 7 tokens
+                "<image>" * max_image_tokens,
+                "</Image>)",  # 3 tokens
+            ])
+
+        mantis_repls = self._bind_prompt_replacements([
+            PromptReplacement(
+                modality="image",
+                target=[image_token_id] * max_image_tokens,
+                replacement=get_replacement_mantis,
+            )
+        ])
+
+        prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
+            result["prompt_token_ids"],
+            mantis_repls,
+            mm_item_counts,
+        )
+
+        unbound_orig_repls = self._get_prompt_replacements(
+            mm_items,
+            hf_processor_mm_kwargs,
+            mm_kwargs,
+        )
+        orig_repls = self._bind_prompt_replacements(unbound_orig_repls)
+
+        all_placeholders = self._find_placeholders(orig_repls, prompt_ids,
+                                                   mm_item_counts)
+        assert len(all_placeholders) == mm_item_counts.get("image", 0)
+
+        mm_placeholders = {
+            modality: [item.to_range() for item in items]
+            for modality, items in full_groupby_modality(all_placeholders)
+        }
+
+        return MultiModalInputsV2(
+            type="multimodal",
+            prompt=prompt_text,
+            prompt_token_ids=prompt_ids,
+            mm_kwargs=mm_kwargs,
+            mm_placeholders=mm_placeholders,
+        )

# To use this model, please use
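
For reference, the replacement text that `get_replacement_mantis` builds for each image can be reproduced standalone; the 4-token budget below is illustrative only (the real value comes from `get_max_llava_image_tokens`):

```python
def get_replacement_mantis(item_idx: int, max_image_tokens: int) -> str:
    # Mirrors the diff above: wrap the repeated <image> placeholders in
    # Mantis's "(image N: <Image>...</Image>)" framing.
    return "".join([
        f"(image {item_idx + 1}: <Image>",
        "<image>" * max_image_tokens,
        "</Image>)",
    ])


assert (get_replacement_mantis(0, 4) ==
        "(image 1: <Image><image><image><image><image></Image>)")
```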