[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-04 19:40:53 +08:00
committed by GitHub
parent 300acb8347
commit eed11ebee9
31 changed files with 1104 additions and 973 deletions

View File

@@ -3,47 +3,36 @@ from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import (CLIPVisionConfig, LlavaOnevisionConfig,
SiglipVisionConfig)
from transformers import (BatchFeature, LlavaOnevisionConfig,
LlavaOnevisionProcessor)
from transformers.models.llava_onevision.modeling_llava_onevision import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems,
VideoProcessorItems)
from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
dummy_video_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .llava_next import LlavaNextMultiModalProcessor
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
class LlavaOnevisionVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
@@ -92,286 +81,251 @@ LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs,
LlavaOnevisionVideoPixelInputs]
def _get_llava_onevision_image_unppaded_feature_size(height, width, patches,
scale_height,
scale_width):
current_height = patches * scale_height
current_width = patches * scale_width
class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
original_aspect_ratio = width / height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
new_height = int(height * (current_width / width))
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
new_width = int(width * (current_height / height))
padding = (current_width - new_width) // 2
current_width -= padding * 2
def _get_hf_config(self) -> LlavaOnevisionConfig:
return self.ctx.get_hf_config(LlavaOnevisionConfig)
unpadded_features = current_height * current_width
newline_features = current_height
def _get_hf_processor(self) -> LlavaOnevisionProcessor:
return self.ctx.get_hf_processor(LlavaOnevisionProcessor)
ratio = math.sqrt(current_height * current_width / (9 * patches**2))
if ratio > 1.1:
unpadded_features = int(current_height // ratio) * int(
current_width // ratio)
newline_features = int(current_height // ratio)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
return (unpadded_features, newline_features)
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
max_image_tokens = self._get_max_image_tokens()
num_frames = self._get_dummy_num_frames(seq_len)
max_video_tokens = self._get_max_video_tokens(num_frames)
def get_llava_onevision_image_feature_size(
hf_config: LlavaOnevisionConfig,
*,
input_height: int,
input_width: int,
) -> int:
vision_config = hf_config.vision_config
return {
"image": max_image_tokens,
"video": max_video_tokens,
}
if isinstance(vision_config, CLIPVisionConfig):
num_patches = get_clip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.batched("video"),
)
base_feature_size = get_clip_image_feature_size(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_patches = get_siglip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
def _get_num_unpadded_features(
self,
*,
original_height: int,
original_width: int,
npatches: int,
num_patch_height: int,
num_patch_width: int,
) -> tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
new_height = int(original_height *
(current_width / original_width))
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
new_width = int(original_width *
(current_height / original_height))
padding = (current_width - new_width) // 2
current_width -= padding * 2
unpadded_features = current_height * current_width
newline_features = current_height
ratio = math.sqrt(current_height * current_width / (9 * npatches**2))
if ratio > 1.1:
unpadded_features = int(current_height // ratio) * int(
current_width // ratio)
newline_features = int(current_height // ratio)
return (unpadded_features, newline_features)
def _get_num_frame_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)
patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
return pooled_grid_length * pooled_grid_length
def _get_num_video_tokens(
self,
*,
image_width: int,
image_height: int,
num_frames: int,
) -> int:
num_frame_tokens = self._get_num_frame_tokens(
image_width=image_width,
image_height=image_height,
)
base_feature_size = get_siglip_image_feature_size(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
strategy = hf_config.vision_feature_select_strategy
if strategy == "default":
base_feature_size -= 1
elif strategy == "full":
pass
else:
raise ValueError(f"Unexpected select feature strategy: {strategy}")
return num_frame_tokens * num_frames + 1 # Newline token
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(input_height, input_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
)
def _get_max_video_tokens(self, num_frames: int) -> int:
return self._get_num_video_tokens(image_width=999999,
image_height=999999,
num_frames=num_frames)
(
unpadded_feature_size,
newline_feature_size,
) = _get_llava_onevision_image_unppaded_feature_size(
input_height, input_width, num_patches, num_patch_height,
num_patch_width)
def _get_max_video_frames(self, max_tokens: int) -> int:
num_frames = 0
return unpadded_feature_size + newline_feature_size + base_feature_size
while True:
next_num_frames = num_frames + 1
if self._get_max_video_tokens(next_num_frames) > max_tokens:
break
def get_max_llava_onevision_image_tokens(ctx: InputContext):
return get_llava_onevision_image_feature_size(
ctx.get_hf_config(LlavaOnevisionConfig),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
num_frames = next_num_frames
return num_frames
def get_llava_onevision_video_frame_feature_size(
hf_config: LlavaOnevisionConfig) -> int:
# Support both CLIPVisionConfig and SiglipVisionConfig
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
spatial_pool_stride = hf_config.spatial_pool_stride if hasattr(
hf_config, "spatial_pool_stride") else 2
def _get_dummy_num_frames(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
height = width = image_size // patch_size
return math.ceil(height / spatial_pool_stride) * math.ceil(
width / spatial_pool_stride)
max_image_tokens = self._get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
return max(max_total_frames // max(max_videos, 1), 1)
def get_llava_onevision_video_tokens(ctx: InputContext,
num_frames: int) -> int:
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
def _get_video_token(self) -> str:
return self._get_hf_processor().video_token
# TODO: support configuring (not supported by HF right now)
num_token_image_newline = 1
tokens_per_frame = get_llava_onevision_video_frame_feature_size(hf_config)
video_feature_size = num_frames * tokens_per_frame + num_token_image_newline
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_data = dict(mm_data)
videos = mm_data.pop("videos", [])
assert isinstance(videos, list)
return video_feature_size
if not videos:
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
video_token = self._get_video_token()
def get_max_llava_onevision_video_tokens(ctx: InputContext) -> int:
return get_llava_onevision_video_tokens(ctx, _MAX_FRAMES_PER_VIDEO)
def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
num_videos = mm_counts["video"]
# TODO: support configuring the number of frames
num_frames = _MAX_FRAMES_PER_VIDEO
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
if isinstance(vision_config, CLIPVisionConfig):
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
mm_key="video")
mm_data = dummy_video_for_clip(vision_config,
num_frames=num_frames,
num_videos=num_videos)
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig):
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
mm_key="video")
mm_data = dummy_video_for_siglip(vision_config,
num_frames=num_frames,
num_videos=num_videos)
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_when_multimodal_input_image(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
width, height = image_data.size
image_feature_size = get_llava_onevision_image_feature_size(
hf_config,
input_height=height,
input_width=width,
# LLaVA-OneVision processor doesn't support multiple videos
# with different sizes when converting back to tensors
text_image_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
elif is_list_of(image_data, Image.Image):
image_feature_size = [
get_llava_onevision_image_feature_size(hf_config,
input_height=img.height,
input_width=img.width)
for img in image_data
pixel_values_videos = []
for video in videos:
item_processor_data = dict(prompt=video_token, videos=video)
item_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=item_processor_data,
mm_kwargs=mm_kwargs,
)
pixel_values_videos.append(
item_outputs.pop("pixel_values_videos")[0])
combined_outputs = dict(
**text_image_outputs,
pixel_values_videos=pixel_values_videos,
)
return BatchFeature(combined_outputs)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
image_repls = super()._get_prompt_replacements(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
out_mm_kwargs=out_mm_kwargs,
)
hf_config = self._get_hf_config()
video_token_id = hf_config.video_token_index
def get_video_replacement(item_idx: int):
videos = mm_items.get_items(
"video", (VideoEmbeddingItems, VideoProcessorItems))
if isinstance(videos, VideoEmbeddingItems):
num_video_tokens = videos.get_feature_size(item_idx)
else:
image_size = videos.get_frame_size(item_idx)
num_video_tokens = self._get_num_video_tokens(
image_width=image_size.width,
image_height=image_size.height,
num_frames=videos.get_num_frames(item_idx),
)
return [video_token_id] * num_video_tokens
return image_repls + [
PromptReplacement(
modality="video",
target=[video_token_id],
replacement=get_video_replacement,
),
]
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
vision_config = hf_config.vision_config
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
image_token = self._get_image_token()
video_token = self._get_video_token()
target_width, target_height = self._get_dummy_image_size()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
)
}
return ProcessorInputs(
prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data,
)
elif isinstance(vision_config, SiglipVisionConfig):
return input_processor_for_siglip(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_when_multimodal_input_video(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data:
return inputs
video_data = multi_modal_data["video"]
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
if isinstance(video_data, np.ndarray):
# Supports both CLIP and Siglip
num_frames = video_data.shape[0]
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
)
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
elif is_list_of(video_data, np.ndarray):
video_feature_size = []
for video in video_data:
num_frames = video.shape[0]
video_feature_size.append(
get_llava_onevision_video_tokens(ctx, num_frames))
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
)
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
else:
raise TypeError(f"Invalid video type: {type(video_data)}")
msg = f"Unsupported video type: {type(video_data)}"
raise NotImplementedError(msg)
def input_processor_for_llava_onevision(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or ("video" not in multi_modal_data
and "image" not in multi_modal_data):
return inputs
if "image" in multi_modal_data:
return input_processor_when_multimodal_input_image(ctx, inputs)
if "video" in multi_modal_data:
return input_processor_when_multimodal_input_video(ctx, inputs)
msg = "Unsupported multi data type"
raise NotImplementedError(msg)
class LlavaOnevisionMultiModalProjector(nn.Module):
@@ -394,14 +348,7 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_input_mapper("video")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"image", get_max_llava_onevision_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"video", get_max_llava_onevision_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision)
@MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):