[VLM] Merged multi-modal processor for LLaVA-NeXT (#11682)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-01-03 00:39:27 +08:00
Committed by: GitHub
Parent: b6087a6bee
Commit: 8c38ee7007
14 changed files with 609 additions and 555 deletions

vllm/model_executor/models/llava_next.py

@@ -4,31 +4,25 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import torch
 import torch.nn as nn
-from PIL import Image
-from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig
+from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
 from typing_extensions import NotRequired
 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext)
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
+from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
+from vllm.multimodal.parse import ImageSize
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 
-from .clip import (CLIPVisionModel, dummy_image_for_clip,
-                   dummy_seq_data_for_clip, get_clip_image_feature_size,
-                   get_clip_patch_grid_length, input_processor_for_clip)
+from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
-from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava
-from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
-                     dummy_seq_data_for_siglip, get_siglip_image_feature_size,
-                     get_siglip_patch_grid_length, input_processor_for_siglip)
+from .llava import (LlavaMultiModalProcessor, LlavaMultiModalProjector,
+                    init_vision_tower_for_llava)
+from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
                     init_vllm_registered_model, maybe_prefix)
@@ -65,218 +59,127 @@ LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
 LlavaNextImageEmbeddingInputs]
 
 
-# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
-def _get_llava_next_num_unpadded_features(
-    original_height: int,
-    original_width: int,
-    npatches: int,
-    num_patch_height: int,
-    num_patch_width: int,
-) -> Tuple[int, int]:
-    current_height = npatches * num_patch_height
-    current_width = npatches * num_patch_width
-
-    original_aspect_ratio = original_width / original_height
-    current_aspect_ratio = current_width / current_height
-
-    if original_aspect_ratio > current_aspect_ratio:
-        scale_factor = current_width / original_width
-        new_height = int(original_height * scale_factor)
-        padding = (current_height - new_height) // 2
-        current_height -= 2 * padding
-    else:
-        scale_factor = current_height / original_height
-        new_width = int(original_width * scale_factor)
-        padding = (current_width - new_width) // 2
-        current_width -= 2 * padding
-
-    unpadded_features = current_height * current_width
-    newline_features = current_height
-    return (unpadded_features, newline_features)
-
-
-# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
-def get_llava_next_image_feature_size(
-    hf_config: LlavaNextConfig,
-    *,
-    input_height: int,
-    input_width: int,
-) -> int:
-    vision_config = hf_config.vision_config
-
-    if isinstance(vision_config, CLIPVisionConfig):
-        num_patches = get_clip_patch_grid_length(
-            image_size=vision_config.image_size,
-            patch_size=vision_config.patch_size,
-        )
-        base_feature_size = get_clip_image_feature_size(vision_config)
-    elif isinstance(vision_config, SiglipVisionConfig):
-        num_patches = get_siglip_patch_grid_length(
-            image_size=vision_config.image_size,
-            patch_size=vision_config.patch_size,
-        )
-        base_feature_size = get_siglip_image_feature_size(vision_config)
-    else:
-        msg = f"Unsupported vision config: {type(vision_config)}"
-        raise NotImplementedError(msg)
-
-    strategy = hf_config.vision_feature_select_strategy
-    if strategy == "default":
-        base_feature_size -= 1
-    elif strategy == "full":
-        pass
-    else:
-        raise ValueError(f"Unexpected select feature strategy: {strategy}")
-
-    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
-        image_size=(input_height, input_width),
-        grid_pinpoints=hf_config.image_grid_pinpoints,
-        patch_size=vision_config.image_size,
-    )
-
-    (
-        unpadded_feature_size,
-        newline_feature_size,
-    ) = _get_llava_next_num_unpadded_features(input_height, input_width,
-                                              num_patches, num_patch_height,
-                                              num_patch_width)
-
-    return unpadded_feature_size + newline_feature_size + base_feature_size
-
-
-def get_max_llava_next_image_tokens(ctx: InputContext):
-    """Compute the max feature size for all possible image grid pinpoints."""
-    return _get_pinpoint_with_largest_features(ctx)[0]
-
-
-def _get_pinpoint_with_largest_features(
-        ctx: InputContext) -> Tuple[int, Tuple[int, int]]:
-    """Get the grid pinpoint with the largest features & its feature size."""
-    hf_config = ctx.get_hf_config(LlavaNextConfig)
-    largest_feature_size = 0
-    largest_feature_pinpoint = None
-    for (height, width) in hf_config.image_grid_pinpoints:
-        feat_size = get_llava_next_image_feature_size(
-            hf_config,
-            input_height=height,
-            input_width=width,
-        )
-        if feat_size > largest_feature_size:
-            largest_feature_size = feat_size
-            largest_feature_pinpoint = (height, width)
-    if not largest_feature_size or largest_feature_pinpoint is None:
-        raise ValueError("Cannot have a largest feature size of 0!")
-    return largest_feature_size, largest_feature_pinpoint
-
-
-def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
-                              mm_counts: Mapping[str, int]):
-    hf_config = ctx.get_hf_config(LlavaNextConfig)
-    vision_config = hf_config.vision_config
-    num_images = mm_counts["image"]
-
-    image_feature_size, pinpoint = _get_pinpoint_with_largest_features(ctx)
-    max_feat_height, max_feat_width = pinpoint
-
-    if isinstance(vision_config, CLIPVisionConfig):
-        seq_data, ranges = dummy_seq_data_for_clip(
-            vision_config,
-            seq_len,
-            num_images,
-            image_token_id=hf_config.image_token_index,
-            image_feature_size_override=image_feature_size,
-        )
-
-        mm_data = dummy_image_for_clip(
-            vision_config,
-            num_images,
-            image_width_override=max_feat_width,
-            image_height_override=max_feat_height,
-        )
-
-        return DummyData(seq_data, mm_data, ranges)
-    elif isinstance(vision_config, SiglipVisionConfig):
-        seq_data, ranges = dummy_seq_data_for_siglip(
-            vision_config,
-            seq_len,
-            num_images,
-            image_token_id=hf_config.image_token_index,
-            image_feature_size_override=image_feature_size,
-        )
-
-        mm_data = dummy_image_for_siglip(
-            vision_config,
-            num_images,
-            image_width_override=max_feat_width,
-            image_height_override=max_feat_height,
-        )
-
-        return DummyData(seq_data, mm_data, ranges)
-
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
-
-
-def input_processor_for_llava_next(ctx: InputContext,
-                                   inputs: DecoderOnlyInputs):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(LlavaNextConfig)
-    vision_config = hf_config.vision_config
-
-    image_data = multi_modal_data["image"]
-    if isinstance(image_data, Image.Image):
-        width, height = image_data.size
-
-        image_feature_size = get_llava_next_image_feature_size(
-            hf_config,
-            input_height=height,
-            input_width=width,
-        )
-    elif is_list_of(image_data, Image.Image):
-        image_feature_size = [
-            get_llava_next_image_feature_size(hf_config,
-                                              input_height=img.height,
-                                              input_width=img.width)
-            for img in image_data
-        ]
-    elif isinstance(image_data, torch.Tensor):
-        num_images, image_feature_size, hidden_size = image_data.shape
-    elif is_list_of(image_data, torch.Tensor):
-        image_feature_size = [item.shape[1] for item in image_data]
-    else:
-        raise TypeError(f"Invalid image type: {type(image_data)}")
-
-    vision_config = hf_config.vision_config
-
-    if isinstance(vision_config, CLIPVisionConfig):
-        return input_processor_for_clip(
-            model_config,
-            vision_config,
-            inputs,
-            image_token_id=hf_config.image_token_index,
-            image_feature_size_override=image_feature_size,
-        )
-    elif isinstance(vision_config, SiglipVisionConfig):
-        return input_processor_for_siglip(
-            model_config,
-            vision_config,
-            inputs,
-            image_token_id=hf_config.image_token_index,
-            image_feature_size_override=image_feature_size,
-        )
-
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
-
-
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
+class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
+
+    def _get_hf_config(self) -> LlavaNextConfig:
+        return self.ctx.get_hf_config(LlavaNextConfig)
+
+    def _get_hf_processor(self) -> LlavaNextProcessor:
+        return self.ctx.get_hf_processor(LlavaNextProcessor)
+
+    def _get_image_token(self) -> str:
+        return self._get_hf_processor().image_token
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_sizes=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
+    def _get_max_image_tokens(self) -> int:
+        largest_feature_size, _ = self._get_pinpoint_with_most_features()
+        return largest_feature_size
+
+    def _get_dummy_image_size(self) -> ImageSize:
+        _, pinpoint = self._get_pinpoint_with_most_features()
+        return pinpoint
+
+    # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
+    def _get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> int:
+        hf_config = self._get_hf_config()
+
+        base_feature_size = self._apply_feature_select_strategy(
+            hf_config.vision_feature_select_strategy,
+            self._vision_encoder_info.get_num_image_tokens(
+                image_width=image_width,
+                image_height=image_height,
+            ),
+        )
+
+        num_patches = self._vision_encoder_info.get_num_patches()
+        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+            image_size=(image_height, image_width),
+            grid_pinpoints=hf_config.image_grid_pinpoints,
+            patch_size=self._vision_encoder_info.get_image_size(),
+        )
+
+        (
+            unpadded_feature_size,
+            newline_feature_size,
+        ) = self._get_num_unpadded_features(
+            original_height=image_height,
+            original_width=image_width,
+            npatches=num_patches,
+            num_patch_height=num_patch_height,
+            num_patch_width=num_patch_width,
+        )
+
+        return unpadded_feature_size + newline_feature_size + base_feature_size
+
+    # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
+    def _get_num_unpadded_features(
+        self,
+        *,
+        original_height: int,
+        original_width: int,
+        npatches: int,
+        num_patch_height: int,
+        num_patch_width: int,
+    ) -> tuple[int, int]:
+        current_height = npatches * num_patch_height
+        current_width = npatches * num_patch_width
+
+        original_aspect_ratio = original_width / original_height
+        current_aspect_ratio = current_width / current_height
+
+        if original_aspect_ratio > current_aspect_ratio:
+            scale_factor = current_width / original_width
+            new_height = int(original_height * scale_factor)
+            padding = (current_height - new_height) // 2
+            current_height -= 2 * padding
+        else:
+            scale_factor = current_height / original_height
+            new_width = int(original_width * scale_factor)
+            padding = (current_width - new_width) // 2
+            current_width -= 2 * padding
+
+        unpadded_features = current_height * current_width
+        newline_features = current_height
+        return (unpadded_features, newline_features)
+
+    def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]:
+        """
+        Get the grid pinpoint with the most features and
+        the corresponding feature size.
+        """
+        hf_config = self._get_hf_config()
+
+        largest_feature_size, largest_feature_pinpoint = 0, None
+        for (height, width) in hf_config.image_grid_pinpoints:
+            feat_size = self._get_num_image_tokens(image_width=width,
+                                                   image_height=height)
+            if feat_size > largest_feature_size:
+                largest_feature_size = feat_size
+                largest_feature_pinpoint = ImageSize(width=width,
+                                                     height=height)
+
+        if largest_feature_size == 0 or largest_feature_pinpoint is None:
+            raise ValueError("Cannot have a largest feature size of 0!")
+
+        return largest_feature_size, largest_feature_pinpoint
+
+
+@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
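A note on the arithmetic carried over into _get_num_unpadded_features: the sketch below is illustrative only, not part of this commit. It restates the same computation as a standalone function so the numbers can be checked by hand, assuming a CLIP ViT-L/14-336 tower (336 / 14 = 24 patches per tile side) and the 2x2 anyres grid that llava-v1.6 selects for a 1920x1080 input.

def num_unpadded_features(
    original_height: int,
    original_width: int,
    npatches: int,
    num_patch_height: int,
    num_patch_width: int,
) -> tuple[int, int]:
    # Same logic as _get_num_unpadded_features above, standalone.
    current_height = npatches * num_patch_height
    current_width = npatches * num_patch_width

    # Drop the rows (or columns) that anyres letterboxing padded in,
    # keeping only the region covered by the scaled image.
    if original_width / original_height > current_width / current_height:
        scale = current_width / original_width
        new_height = int(original_height * scale)
        padding = (current_height - new_height) // 2
        current_height -= 2 * padding
    else:
        scale = current_height / original_height
        new_width = int(original_width * scale)
        padding = (current_width - new_width) // 2
        current_width -= 2 * padding

    # One "newline" feature is appended per surviving feature-map row.
    return current_height * current_width, current_height

# 1920x1080 input on a 2x2 grid of 24x24-patch tiles: the 48x48 feature
# map is cropped to 28 rows, leaving 1344 unpadded + 28 newline features.
assert num_unpadded_features(1080, 1920, 24, 2, 2) == (1344, 28)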
@@ -507,7 +410,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
     def _process_image_pixels(
         self,
         inputs: LlavaNextImagePixelInputs,
-    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         assert self.vision_tower is not None
 
         pixel_values = inputs["data"]
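As a cross-check on _get_max_image_tokens: with the standard llava-v1.6 image_grid_pinpoints, the square (672, 672) pinpoint yields the most features. The arithmetic below is again illustrative, assuming a CLIP ViT-L/14-336 tower and the "default" vision_feature_select_strategy; it reproduces the familiar 2928-token LLaVA-NeXT maximum.

# Illustrative arithmetic (assumes CLIP ViT-L/14-336 and the "default"
# strategy, which drops the CLS token from the 577 per-tile features).
npatches = 336 // 14              # 24 patches per side of each 336px tile
base = npatches**2 + 1 - 1        # 577 CLIP tokens minus CLS -> 576

# The square (672, 672) pinpoint maps to a 2x2 tile grid with no
# letterboxing, so nothing is unpadded away.
grid = 672 // 336                 # 2 tiles per side
side = npatches * grid            # 48x48 feature map
unpadded = side * side            # 2304
newline = side                    # 48 newline features, one per row

print(base + unpadded + newline)  # 2928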