[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-01-04 19:40:53 +08:00
Committed by: GitHub
Parent: 300acb8347
Commit: eed11ebee9
31 changed files with 1104 additions and 973 deletions


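Both the removed helper get_llava_next_video_frame_feature_size and its replacement _get_num_frame_tokens in the diff below compute the number of LLM tokens contributed by one video frame: the vision encoder's patch grid is downsampled by spatial_pool_stride and the result is squared. A quick check of that arithmetic, assuming the CLIP ViT-L/14 backbone at 336 px commonly used by LLaVA-NeXT-Video (the config values here are illustrative, not taken from the diff):

    import math

    # Illustrative config values; in vLLM they come from hf_config.vision_config
    # and hf_config.spatial_pool_stride.
    image_size = 336          # CLIP ViT-L/14-336 input resolution
    patch_size = 14           # ViT patch size
    spatial_pool_stride = 2   # LLaVA-NeXT-Video spatial pooling stride

    # Removed helper: int((image_size / patch_size / spatial_pool_stride) ** 2)
    old_tokens = int((image_size / patch_size / spatial_pool_stride) ** 2)

    # New method: square of ceil(patch_grid_length / spatial_pool_stride)
    patch_grid_length = image_size // patch_size               # 24
    pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
    new_tokens = pooled_grid_length ** 2

    assert old_tokens == new_tokens == 144  # 12 x 12 tokens per video frame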
@@ -3,38 +3,32 @@ from functools import cached_property
 from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)

 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
-                          SiglipVisionConfig)
+from transformers import (BatchFeature, LlavaNextVideoConfig,
+                          LlavaNextVideoProcessor)

 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
-from vllm.multimodal.utils import (cached_get_tokenizer,
-                                   repeat_and_pad_placeholder_tokens)
+from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
+from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
+                                   VideoEmbeddingItems, VideoProcessorItems)
+from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs,
+                                        PromptReplacement)
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of

-from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal, SupportsPP
 from .llava import init_vision_tower_for_llava
-from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
-                     dummy_seq_data_for_siglip)
+from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
-
-# For profile run
-_MAX_FRAMES_PER_VIDEO = 32
-_MAX_NUM_VIDEOS = 1
+from .vision import BaseVisionLanguageMultiModalProcessor


 class LlavaNextVideoPixelInputs(TypedDict):
@@ -50,144 +44,149 @@ class LlavaNextVideoPixelInputs(TypedDict):
     """


-def get_llava_next_video_frame_feature_size(
-        hf_config: LlavaNextVideoConfig) -> int:
-    # Support both CLIPVisionConfig and SiglipVisionConfig
-    image_size = hf_config.vision_config.image_size
-    patch_size = hf_config.vision_config.patch_size
-    spatial_pool_stride = hf_config.spatial_pool_stride
-
-    return int((image_size / patch_size / spatial_pool_stride)**2)
-
-
-def _get_max_llm_tokens(ctx: InputContext) -> int:
-    """
-    Calculated from the maximum video frames under the context length
-    constraints of the language model.
-    """
-    hf_text_config = ctx.model_config.hf_text_config
-    model_config = ctx.model_config
-    max_tokens = model_config.max_model_len
-    rope_scaling = model_config.rope_scaling
-
-    if rope_scaling:
-        rope_scaling_factor = hf_text_config.rope_scaling["factor"]
-    else:
-        rope_scaling_factor = 1
-
-    max_tokens *= rope_scaling_factor
-
-    return max_tokens
-
-
-def get_max_llava_next_video_tokens(ctx: InputContext) -> int:
-    # Currently set to 32 frames
-    # TODO: max_tokens = _get_max_llm_tokens(ctx)
-    hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
-    tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
-    return _MAX_FRAMES_PER_VIDEO * tokens_per_frame
-
-
-def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
-                                    mm_counts: Mapping[str, int]):
-    hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
-    vision_config = hf_config.vision_config
-
-    # TODO: support multiple videos
-    num_videos = mm_counts["video"]
-    if num_videos != _MAX_NUM_VIDEOS:
-        raise NotImplementedError(
-            f"Only {_MAX_NUM_VIDEOS} videos are supported")
-
-    # TODO: support configuring the number of frames
-    frames_per_video = _MAX_FRAMES_PER_VIDEO
-    # num_images = num_videos * frames_per_video
-
-    # fills the sequence with as longer video data as possible
-    tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
-    video_feature_size = frames_per_video * tokens_per_frame
-
-    if isinstance(vision_config, CLIPVisionConfig):
-        seq_data, ranges = dummy_seq_data_for_clip(
-            vision_config,
-            seq_len,
-            num_videos,
-            image_token_id=hf_config.video_token_index,
-            image_feature_size_override=video_feature_size,
-            mm_key="video",
-        )
-
-        pil_frame = dummy_image_for_clip(vision_config, num_images=1)
-        np_frame = np.array(pil_frame["image"])
-        mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
-        mm_data = {"video": mm_data_per_video}
-        return DummyData(seq_data, mm_data, ranges)
-    elif isinstance(vision_config, SiglipVisionConfig):
-        seq_data, ranges = dummy_seq_data_for_siglip(
-            vision_config,
-            seq_len,
-            num_videos,
-            image_token_id=hf_config.video_token_index,
-            image_feature_size_override=video_feature_size,
-            mm_key="video",
-        )
-
-        pil_frame = dummy_image_for_siglip(vision_config, num_images=1)
-        np_frame = np.array(pil_frame["image"])
-        mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
-        mm_data = {"video": mm_data_per_video}
-        return DummyData(seq_data, mm_data, ranges)
-
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
-
-
-def input_processor_for_llava_next_video(ctx: InputContext,
-                                         inputs: DecoderOnlyInputs):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "video" not in multi_modal_data:
-        return inputs
-
-    if "multi_modal_placeholders" in inputs and "video" in inputs[
-            "multi_modal_placeholders"]:
-        # The inputs already have placeholders.
-        return inputs
-
-    video_data = multi_modal_data["video"]
-
-    model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
-    vision_config = hf_config.vision_config
-
-    if isinstance(video_data, np.ndarray):
-        # Supports both CLIP and Siglip
-        num_frames = video_data.shape[0]
-        frame_feature_size = \
-            get_llava_next_video_frame_feature_size(hf_config)
-        video_feature_size = num_frames * frame_feature_size
-
-        tokenizer = cached_get_tokenizer(model_config.tokenizer)
-
-        new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
-            tokenizer,
-            inputs.get("prompt"),
-            inputs["prompt_token_ids"],
-            placeholder_token_id=hf_config.video_token_index,
-            repeat_count=video_feature_size,
-        )
-
-        return token_inputs(prompt_token_ids=new_token_ids,
-                            prompt=new_prompt,
-                            multi_modal_data=multi_modal_data,
-                            multi_modal_placeholders={"video": ranges})
-
-    elif is_list_of(video_data, np.ndarray):
-        raise NotImplementedError(
-            "Processing multiple videos is not supported")
-
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
+class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
+
+    def _get_hf_config(self) -> LlavaNextVideoConfig:
+        return self.ctx.get_hf_config(LlavaNextVideoConfig)
+
+    def _get_hf_processor(self) -> LlavaNextVideoProcessor:
+        return self.ctx.get_hf_processor(LlavaNextVideoProcessor)
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"video": 1}
+
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        num_frames = self._get_dummy_num_frames(seq_len)
+        max_video_tokens = self._get_max_video_tokens(num_frames)
+
+        return {"video": max_video_tokens}
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(pixel_values_videos=MultiModalFieldConfig.batched("video"))
+
+    def _get_num_frame_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> int:
+        hf_config = self._get_hf_config()
+        spatial_pool_stride = hf_config.spatial_pool_stride
+
+        patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
+        pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
+
+        return pooled_grid_length * pooled_grid_length
+
+    def _get_num_video_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        num_frames: int,
+    ) -> int:
+        num_frame_tokens = self._get_num_frame_tokens(
+            image_width=image_width,
+            image_height=image_height,
+        )
+
+        return num_frame_tokens * num_frames
+
+    def _get_max_video_tokens(self, num_frames: int) -> int:
+        return self._get_num_video_tokens(image_width=999999,
+                                          image_height=999999,
+                                          num_frames=num_frames)
+
+    def _get_max_video_frames(self, max_tokens: int) -> int:
+        num_frames = 0
+
+        while True:
+            next_num_frames = num_frames + 1
+            if self._get_max_video_tokens(next_num_frames) > max_tokens:
+                break
+
+            num_frames = next_num_frames
+
+        return num_frames
+
+    def _get_dummy_num_frames(self, seq_len: int) -> int:
+        mm_config = self.ctx.get_mm_config()
+        max_videos = mm_config.limit_per_prompt.get("video", 1)
+
+        max_total_frames = self._get_max_video_frames(seq_len)
+
+        return max(max_total_frames // max(max_videos, 1), 1)
+
+    def _get_dummy_image_size(self) -> ImageSize:
+        image_size = self._vision_encoder_info.get_image_size()
+        return ImageSize(image_size, image_size)
+
+    def _get_video_token(self) -> str:
+        return self._get_hf_processor().video_token
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        hf_config = self._get_hf_config()
+        video_token_id = hf_config.video_token_index
+
+        def get_replacement(item_idx: int):
+            videos = mm_items.get_items(
+                "video", (VideoEmbeddingItems, VideoProcessorItems))
+
+            if isinstance(videos, VideoEmbeddingItems):
+                num_video_tokens = videos.get_feature_size(item_idx)
+            else:
+                image_size = videos.get_frame_size(item_idx)
+                num_video_tokens = self._get_num_video_tokens(
+                    image_width=image_size.width,
+                    image_height=image_size.height,
+                    num_frames=videos.get_num_frames(item_idx),
+                )
+
+            return [video_token_id] * num_video_tokens
+
+        return [
+            PromptReplacement(
+                modality="video",
+                target=[video_token_id],
+                replacement=get_replacement,
+            ),
+        ]
+
+    def _get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        num_videos = mm_counts.get("video", 0)
+
+        video_token = self._get_video_token()
+        target_width, target_height = self._get_dummy_image_size()
+
+        mm_data = {
+            "video":
+            self._get_dummy_videos(
+                width=target_width,
+                height=target_height,
+                num_frames=self._get_dummy_num_frames(seq_len),
+                num_videos=num_videos,
+            )
+        }
+
+        return ProcessorInputs(
+            prompt_text=video_token * num_videos,
+            mm_data=mm_data,
+        )


 # adopted from transformers modeling_llava_next_video.py
 class LlavaNextVideoPooler(nn.Module):
@@ -246,11 +245,7 @@ class LlavaNextMultiModalProjector(nn.Module):
         return hidden_states


-@MULTIMODAL_REGISTRY.register_input_mapper("video")
-@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
-    "video", get_max_llava_next_video_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
+@MULTIMODAL_REGISTRY.register_processor(LlavaNextVideoMultiModalProcessor)
 class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsPP):
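The profiling path above no longer hard-codes _MAX_FRAMES_PER_VIDEO = 32; instead, _get_max_video_frames grows the frame count until one more frame would exceed the token budget, and _get_dummy_num_frames divides that total across the allowed number of videos. A standalone rehearsal of that logic, assuming 144 tokens per frame as in the example before the diff (the helper names here are illustrative, not part of vLLM):

    def max_video_frames(max_tokens: int, tokens_per_frame: int = 144) -> int:
        # Mirrors _get_max_video_frames: add frames until the next one overflows.
        num_frames = 0
        while (num_frames + 1) * tokens_per_frame <= max_tokens:
            num_frames += 1
        return num_frames

    def dummy_num_frames(seq_len: int, max_videos: int = 1) -> int:
        # Mirrors _get_dummy_num_frames: split the frame budget across videos,
        # always keeping at least one frame.
        return max(max_video_frames(seq_len) // max(max_videos, 1), 1)

    # With a 4096-token budget and one video per prompt, profiling would use
    # 28 frames (28 * 144 = 4032 <= 4096, while 29 * 144 = 4176 > 4096).
    assert max_video_frames(4096) == 28
    assert dummy_num_frames(4096, max_videos=1) == 28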