[3/N] Support and implement merged input processor for LLaVA model (#10676)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <ywang@roblox.com>
@@ -1,17 +1,19 @@
 from functools import cached_property
+from types import MethodType
 from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
                     Tuple, TypedDict, Union)

 import torch
 import torch.nn as nn
-from PIL import Image
-from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
-                          PretrainedConfig, SiglipVisionConfig)
+from PIL.Image import Image
+from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
+                          PixtralVisionConfig, PretrainedConfig,
+                          ProcessorMixin, SiglipVisionConfig)
+from transformers.models.pixtral import PixtralProcessor

 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext)
+from vllm.inputs import InputContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -19,21 +21,20 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
+from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
+from vllm.multimodal.processing import (InputProcessingContext,
+                                        ModalityProcessingMetadata,
+                                        MultiModalProcessingMetadata,
+                                        MultiModalProcessor, PromptReplacement)
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of

 from .clip import (CLIPVisionModel, dummy_image_for_clip,
-                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
-                   input_processor_for_clip)
+                   get_max_clip_image_tokens)
 from .interfaces import SupportsMultiModal, SupportsPP
 from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
-                      dummy_seq_data_for_pixtral_hf,
-                      get_max_pixtral_hf_image_tokens,
-                      input_processor_for_pixtral_hf)
+                      get_max_pixtral_hf_image_tokens)
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
-                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
-                     input_processor_for_siglip)
+                     get_max_siglip_image_tokens)
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)

@@ -113,102 +114,86 @@ def get_max_llava_image_tokens(ctx: InputContext):
     raise ValueError(f"Unexpected select feature strategy: {strategy}")


-def dummy_data_for_llava(ctx: InputContext, seq_len: int,
-                         mm_counts: Mapping[str, int]):
+def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext,
+                              mm_counts: Mapping[str, int]):
     hf_config = ctx.get_hf_config(LlavaConfig)
     vision_config = hf_config.vision_config
     num_images = mm_counts["image"]

-    image_feature_size = get_max_llava_image_tokens(ctx)
-
     if isinstance(vision_config, CLIPVisionConfig):
-        seq_data, ranges = dummy_seq_data_for_clip(
-            vision_config,
-            seq_len,
-            num_images,
-            image_token_id=hf_config.image_token_index,
-            image_feature_size_override=image_feature_size,
-        )
-
-        mm_data = dummy_image_for_clip(vision_config, num_images)
-        return DummyData(seq_data, mm_data, ranges)
+        data = dummy_image_for_clip(vision_config, num_images)
     elif isinstance(vision_config, SiglipVisionConfig):
-        seq_data, ranges = dummy_seq_data_for_siglip(
-            vision_config,
-            seq_len,
-            num_images,
-            image_token_id=hf_config.image_token_index,
-            image_feature_size_override=image_feature_size,
-        )
-
-        mm_data = dummy_image_for_siglip(vision_config, num_images)
-        return DummyData(seq_data, mm_data, ranges)
+        data = dummy_image_for_siglip(vision_config, num_images)
     elif isinstance(vision_config, PixtralVisionConfig):
-        seq_data, ranges = dummy_seq_data_for_pixtral_hf(
-            vision_config,
-            seq_len,
-            num_images,
-            image_token_id=hf_config.image_token_index,
-            image_feature_size_override=image_feature_size,
-        )
-
-        mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
-        return DummyData(seq_data, mm_data, ranges)
-
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
-
-
-def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(LlavaConfig)
-    vision_config = hf_config.vision_config
-
-    image_data = multi_modal_data["image"]
-    if isinstance(image_data, Image.Image):
-        image_feature_size = get_max_llava_image_tokens(ctx)
-    elif is_list_of(image_data, Image.Image):
-        image_feature_size = [get_max_llava_image_tokens(ctx)
-                              ] * len(image_data)
-    elif isinstance(image_data, torch.Tensor):
-        num_images, image_feature_size, hidden_size = image_data.shape
-    elif is_list_of(image_data, torch.Tensor):
-        image_feature_size = [item.shape[1] for item in image_data]
+        data = dummy_image_for_pixtral_hf(vision_config, num_images)
     else:
-        raise TypeError(f"Invalid image type: {type(image_data)}")
+        msg = f"Unsupported vision config: {type(vision_config)}"
+        raise NotImplementedError(msg)

-    if isinstance(vision_config, CLIPVisionConfig):
-        return input_processor_for_clip(
-            model_config,
-            vision_config,
-            inputs,
-            image_token_id=hf_config.image_token_index,
-            image_feature_size_override=image_feature_size,
-        )
-    elif isinstance(vision_config, SiglipVisionConfig):
-        return input_processor_for_siglip(
-            model_config,
-            vision_config,
-            inputs,
-            image_token_id=hf_config.image_token_index,
-            image_feature_size_override=image_feature_size,
-        )
-    elif isinstance(vision_config, PixtralVisionConfig):
-        # We ignore image_feature_size_override since we have non-uniform
-        # image sizes for Pixtral
-        return input_processor_for_pixtral_hf(
-            model_config,
-            vision_config,
-            inputs,
-            image_token_id=hf_config.image_token_index,
-        )
+    hf_processor = ctx.get_hf_processor()
+    image_processor = hf_processor.image_processor  # type: ignore
+    hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
+    is_pixtral = isinstance(hf_processor, PixtralProcessor)

-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
+    return MultiModalKwargs(
+        **hf_inputs,
+        is_pixtral=torch.tensor(is_pixtral),
+    )
+
+
+def create_metadata_for_llava(
+        ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
+    hf_config = ctx.get_hf_config(LlavaConfig)
+    image_token_id = hf_config.image_token_index
+
+    def get_repl_count(
+        mm_items: list[Image],
+        hf_inputs: BatchFeature,
+        item_idx: int,
+    ) -> int:
+        return get_max_llava_image_tokens(ctx)
+
+    return {
+        "image":
+        ModalityProcessingMetadata(prompt_repls=[
+            PromptReplacement(target=[image_token_id],
+                              repl_unit=[image_token_id],
+                              repl_count=get_repl_count),
+        ]),
+    }
+
+
+class LlavaProcessor(MultiModalProcessor):
+
+    def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
+        if getattr(hf_processor, "__is_patched__", False):
+            return  # Already patched
+
+        image_processor = hf_processor.image_processor  # type: ignore
+        orig_preprocess = image_processor.preprocess
+
+        def preprocess(__self, *args, **kwargs):
+            hf_inputs = orig_preprocess(*args, **kwargs)
+            hf_inputs["is_pixtral"] = torch.tensor(True)
+            return hf_inputs
+
+        image_processor.preprocess = MethodType(preprocess, image_processor)
+
+        hf_processor.__is_patched__ = True  # type: ignore
+
+    def _get_hf_processor(self) -> ProcessorMixin:
+        hf_processor = self.ctx.get_hf_processor()
+
+        if isinstance(hf_processor, PixtralProcessor):
+            self._patch_pixtral_processor(hf_processor)
+
+        return hf_processor
+
+    def _get_dummy_mm_kwargs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> MultiModalKwargs:
+        return dummy_mm_kwargs_for_llava(self.ctx, mm_counts)


 class LlavaLikeConfig(Protocol):
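
Note: the hunk above is the heart of the merged input processor. Instead of a
standalone dummy-data hook and input processor, dummy_mm_kwargs_for_llava runs
the HF image processor directly to produce MultiModalKwargs, while
create_metadata_for_llava declares how the prompt is rewritten: the image
token is the target of a PromptReplacement whose repl_count callback returns
the image feature size. A minimal sketch of that replacement semantics
follows; expand_image_tokens is a hypothetical helper for illustration (not
vLLM API), and the constants assume a llava-hf 1.5 checkpoint
(image_token_index=32000, 576 CLIP patch features at 336px).

    IMAGE_TOKEN_ID = 32000  # assumed image_token_index; varies by checkpoint

    def expand_image_tokens(token_ids: list[int], repl_count: int) -> list[int]:
        # Mirrors PromptReplacement(target=[id], repl_unit=[id],
        # repl_count=...): every image token becomes repl_count copies.
        out: list[int] = []
        for tok in token_ids:
            out.extend([tok] * repl_count if tok == IMAGE_TOKEN_ID else [tok])
        return out

    # One <image> placeholder expands to 576 feature positions (24x24 patches):
    print(len(expand_image_tokens([IMAGE_TOKEN_ID, 13, 1724], repl_count=576)))
    # -> 578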
@@ -291,10 +276,11 @@ def init_vision_tower_for_llava(
     raise NotImplementedError(msg)


-@MULTIMODAL_REGISTRY.register_image_input_mapper()
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
+@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor(
+    ctx=ctx,
+    metadata=create_metadata_for_llava(ctx),
+))
 class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     # BitandBytes specific attributes
     bitsandbytes_stacked_params_mapping = {
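
Note: three legacy hooks (register_image_input_mapper, register_dummy_data,
register_input_processor) collapse into a single register_processor call whose
factory receives the processing context; register_max_image_tokens is kept.
A sketch of the same registration shape for a hypothetical model, where
MyProcessor and create_metadata_for_mymodel are placeholder names:

    @MULTIMODAL_REGISTRY.register_processor(lambda ctx: MyProcessor(
        ctx=ctx,
        metadata=create_metadata_for_mymodel(ctx),
    ))
    class MyModelForConditionalGeneration(nn.Module, SupportsMultiModal,
                                          SupportsPP):
        ...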
@@ -367,38 +353,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

         return data

-    def _validate_image_sizes(self, images: List[torch.Tensor],
-                              sizes: List[torch.Tensor]) -> List[torch.Tensor]:
-        if not isinstance(sizes, list):
-            sizes = [sizes]
-
-        total_images = sum(size.numel() // 2 for size in sizes)
-        if total_images != len(images):
-            raise ValueError("Mismatch in number of images. "
-                             f"Expected {total_images}, got {len(images)}")
-        img_idx = 0
-        for size in sizes:
-            # Flatten the size tensor to a list of (height, width) pairs
-            size = size.view(-1, 2).tolist()
-            for expected_h, expected_w in size:
-                if img_idx >= len(images):
-                    raise ValueError("Ran out of images before sizes. "
-                                     f"{img_idx} >= {len(images)}")
-                img = images[img_idx]
-                if img.shape[-2:] != (expected_h, expected_w):
-                    raise ValueError(
-                        "Image size mismatch. Expected "
-                        f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
-                if img.shape[-3] != 3:
-                    raise ValueError("Image channel mismatch. Expected 3, "
-                                     f"got {img.shape[-3]}")
-                img_idx += 1
-        return images
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
-        image_sizes = kwargs.pop("image_sizes", None)
+        is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
         image_embeds = kwargs.pop("image_embeds", None)

         if pixel_values is None and image_embeds is None:
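
Note: the image_sizes kwarg (and its _validate_image_sizes consumer) gives way
to an is_pixtral flag, which the patched Pixtral preprocessor above injects
into its own outputs. Below is a self-contained demo of that MethodType
instance-patching pattern, using a toy class rather than the real HF
processor:

    from types import MethodType

    class ToyImageProcessor:
        def preprocess(self, data):
            return {"pixel_values": data}

    proc = ToyImageProcessor()
    orig_preprocess = proc.preprocess  # bound method; already carries `proc`

    def preprocess(_self, *args, **kwargs):
        out = orig_preprocess(*args, **kwargs)
        out["is_pixtral"] = True  # tag outputs so downstream code can branch
        return out

    # Rebind on this instance only; other instances keep the original method.
    proc.preprocess = MethodType(preprocess, proc)
    print(proc.preprocess([1, 2]))  # {'pixel_values': [1, 2], 'is_pixtral': True}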
@@ -409,9 +367,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

-            # Case for models like PixtralHF that have dynamic image sizes
-            # so we need to produce a list of tensors
-            if image_sizes is not None:
+            assert isinstance(is_pixtral, torch.Tensor)
+            if is_pixtral.any():
                 images = pixel_values

                 def flatten_to_3d_tensors(item):
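
Note: the Pixtral branch keeps pixel_values as a list of per-image tensors
because variable-resolution images cannot be stacked into a single 4D batch.
A toy illustration (shapes are arbitrary examples, not real model inputs):

    import torch

    imgs = [torch.randn(3, 336, 336), torch.randn(3, 224, 448)]
    try:
        torch.stack(imgs)  # raises: stacked tensors must share one shape
    except RuntimeError as err:
        print("cannot batch:", err)
    # A list of (C, H, W) tensors lets each image pass through the vision
    # tower at its native resolution instead.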
@@ -434,7 +391,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

                 return LlavaImagePixelInputs(
                     type="pixel_values",
-                    data=self._validate_image_sizes(images, image_sizes),
+                    data=images,
                 )

             return LlavaImagePixelInputs(
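
Note: a hedged end-to-end sketch of exercising the merged processor through
vLLM's public offline API. The model id and prompt template assume a llava-hf
1.5 checkpoint and an image file named example.jpg; adjust both for your
setup.

    from PIL import Image
    from vllm import LLM, SamplingParams

    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
    image = Image.open("example.jpg")

    outputs = llm.generate(
        {
            "prompt": "USER: <image>\nWhat is in this picture? ASSISTANT:",
            "multi_modal_data": {"image": image},
        },
        SamplingParams(max_tokens=64),
    )
    print(outputs[0].outputs[0].text)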