[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -3,38 +3,32 @@ from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
|
||||
SiglipVisionConfig)
|
||||
from transformers import (BatchFeature, LlavaNextVideoConfig,
|
||||
LlavaNextVideoProcessor)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.models.clip import CLIPVisionModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
|
||||
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
|
||||
VideoEmbeddingItems, VideoProcessorItems)
|
||||
from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .llava import init_vision_tower_for_llava
|
||||
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||
dummy_seq_data_for_siglip)
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
# For profile run
|
||||
_MAX_FRAMES_PER_VIDEO = 32
|
||||
_MAX_NUM_VIDEOS = 1
|
||||
from .vision import BaseVisionLanguageMultiModalProcessor
|
||||
|
||||
|
||||
class LlavaNextVideoPixelInputs(TypedDict):
|
||||
@@ -50,144 +44,149 @@ class LlavaNextVideoPixelInputs(TypedDict):
|
||||
"""
|
||||
|
||||
|
||||
def get_llava_next_video_frame_feature_size(
|
||||
hf_config: LlavaNextVideoConfig) -> int:
|
||||
# Support both CLIPVisionConfig and SiglipVisionConfig
|
||||
image_size = hf_config.vision_config.image_size
|
||||
patch_size = hf_config.vision_config.patch_size
|
||||
spatial_pool_stride = hf_config.spatial_pool_stride
|
||||
class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
|
||||
|
||||
return int((image_size / patch_size / spatial_pool_stride)**2)
|
||||
def _get_hf_config(self) -> LlavaNextVideoConfig:
|
||||
return self.ctx.get_hf_config(LlavaNextVideoConfig)
|
||||
|
||||
def _get_hf_processor(self) -> LlavaNextVideoProcessor:
|
||||
return self.ctx.get_hf_processor(LlavaNextVideoProcessor)
|
||||
|
||||
def _get_max_llm_tokens(ctx: InputContext) -> int:
|
||||
"""
|
||||
Calculated from the maximum video frames under the context length
|
||||
constraints of the language model.
|
||||
"""
|
||||
hf_text_config = ctx.model_config.hf_text_config
|
||||
model_config = ctx.model_config
|
||||
max_tokens = model_config.max_model_len
|
||||
rope_scaling = model_config.rope_scaling
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"video": 1}
|
||||
|
||||
if rope_scaling:
|
||||
rope_scaling_factor = hf_text_config.rope_scaling["factor"]
|
||||
else:
|
||||
rope_scaling_factor = 1
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
num_frames = self._get_dummy_num_frames(seq_len)
|
||||
max_video_tokens = self._get_max_video_tokens(num_frames)
|
||||
|
||||
max_tokens *= rope_scaling_factor
|
||||
return {"video": max_video_tokens}
|
||||
|
||||
return max_tokens
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(pixel_values_videos=MultiModalFieldConfig.batched("video"))
|
||||
|
||||
def _get_num_frame_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
hf_config = self._get_hf_config()
|
||||
spatial_pool_stride = hf_config.spatial_pool_stride
|
||||
|
||||
def get_max_llava_next_video_tokens(ctx: InputContext) -> int:
|
||||
# Currently set to 32 frames
|
||||
# TODO: max_tokens = _get_max_llm_tokens(ctx)
|
||||
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
|
||||
tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
|
||||
return _MAX_FRAMES_PER_VIDEO * tokens_per_frame
|
||||
patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
|
||||
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
|
||||
|
||||
return pooled_grid_length * pooled_grid_length
|
||||
|
||||
def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
|
||||
mm_counts: Mapping[str, int]):
|
||||
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
# TODO: support multiple videos
|
||||
num_videos = mm_counts["video"]
|
||||
if num_videos != _MAX_NUM_VIDEOS:
|
||||
raise NotImplementedError(
|
||||
f"Only {_MAX_NUM_VIDEOS} videos are supported")
|
||||
|
||||
# TODO: support configuring the number of frames
|
||||
frames_per_video = _MAX_FRAMES_PER_VIDEO
|
||||
# num_images = num_videos * frames_per_video
|
||||
|
||||
# fills the sequence with as longer video data as possible
|
||||
tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
|
||||
video_feature_size = frames_per_video * tokens_per_frame
|
||||
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
seq_data, ranges = dummy_seq_data_for_clip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_videos,
|
||||
image_token_id=hf_config.video_token_index,
|
||||
image_feature_size_override=video_feature_size,
|
||||
mm_key="video",
|
||||
def _get_num_video_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
num_frames: int,
|
||||
) -> int:
|
||||
num_frame_tokens = self._get_num_frame_tokens(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
)
|
||||
|
||||
pil_frame = dummy_image_for_clip(vision_config, num_images=1)
|
||||
np_frame = np.array(pil_frame["image"])
|
||||
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
|
||||
mm_data = {"video": mm_data_per_video}
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
elif isinstance(vision_config, SiglipVisionConfig):
|
||||
seq_data, ranges = dummy_seq_data_for_siglip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
num_videos,
|
||||
image_token_id=hf_config.video_token_index,
|
||||
image_feature_size_override=video_feature_size,
|
||||
mm_key="video",
|
||||
return num_frame_tokens * num_frames
|
||||
|
||||
def _get_max_video_tokens(self, num_frames: int) -> int:
|
||||
return self._get_num_video_tokens(image_width=999999,
|
||||
image_height=999999,
|
||||
num_frames=num_frames)
|
||||
|
||||
def _get_max_video_frames(self, max_tokens: int) -> int:
|
||||
num_frames = 0
|
||||
|
||||
while True:
|
||||
next_num_frames = num_frames + 1
|
||||
|
||||
if self._get_max_video_tokens(next_num_frames) > max_tokens:
|
||||
break
|
||||
|
||||
num_frames = next_num_frames
|
||||
|
||||
return num_frames
|
||||
|
||||
def _get_dummy_num_frames(self, seq_len: int) -> int:
|
||||
mm_config = self.ctx.get_mm_config()
|
||||
max_videos = mm_config.limit_per_prompt.get("video", 1)
|
||||
|
||||
max_total_frames = self._get_max_video_frames(seq_len)
|
||||
|
||||
return max(max_total_frames // max(max_videos, 1), 1)
|
||||
|
||||
def _get_dummy_image_size(self) -> ImageSize:
|
||||
image_size = self._vision_encoder_info.get_image_size()
|
||||
return ImageSize(image_size, image_size)
|
||||
|
||||
def _get_video_token(self) -> str:
|
||||
return self._get_hf_processor().video_token
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self._get_hf_config()
|
||||
video_token_id = hf_config.video_token_index
|
||||
|
||||
def get_replacement(item_idx: int):
|
||||
videos = mm_items.get_items(
|
||||
"video", (VideoEmbeddingItems, VideoProcessorItems))
|
||||
|
||||
if isinstance(videos, VideoEmbeddingItems):
|
||||
num_video_tokens = videos.get_feature_size(item_idx)
|
||||
else:
|
||||
image_size = videos.get_frame_size(item_idx)
|
||||
num_video_tokens = self._get_num_video_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
num_frames=videos.get_num_frames(item_idx),
|
||||
)
|
||||
|
||||
return [video_token_id] * num_video_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="video",
|
||||
target=[video_token_id],
|
||||
replacement=get_replacement,
|
||||
),
|
||||
]
|
||||
|
||||
def _get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
|
||||
video_token = self._get_video_token()
|
||||
target_width, target_height = self._get_dummy_image_size()
|
||||
|
||||
mm_data = {
|
||||
"video":
|
||||
self._get_dummy_videos(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_frames=self._get_dummy_num_frames(seq_len),
|
||||
num_videos=num_videos,
|
||||
)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text=video_token * num_videos,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
pil_frame = dummy_image_for_siglip(vision_config, num_images=1)
|
||||
np_frame = np.array(pil_frame["image"])
|
||||
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
|
||||
mm_data = {"video": mm_data_per_video}
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def input_processor_for_llava_next_video(ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "video" not in multi_modal_data:
|
||||
return inputs
|
||||
|
||||
if "multi_modal_placeholders" in inputs and "video" in inputs[
|
||||
"multi_modal_placeholders"]:
|
||||
# The inputs already have placeholders.
|
||||
return inputs
|
||||
|
||||
video_data = multi_modal_data["video"]
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
if isinstance(video_data, np.ndarray):
|
||||
# Supports both CLIP and Siglip
|
||||
num_frames = video_data.shape[0]
|
||||
frame_feature_size = \
|
||||
get_llava_next_video_frame_feature_size(hf_config)
|
||||
video_feature_size = num_frames * frame_feature_size
|
||||
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer)
|
||||
|
||||
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer,
|
||||
inputs.get("prompt"),
|
||||
inputs["prompt_token_ids"],
|
||||
placeholder_token_id=hf_config.video_token_index,
|
||||
repeat_count=video_feature_size,
|
||||
)
|
||||
|
||||
return token_inputs(prompt_token_ids=new_token_ids,
|
||||
prompt=new_prompt,
|
||||
multi_modal_data=multi_modal_data,
|
||||
multi_modal_placeholders={"video": ranges})
|
||||
|
||||
elif is_list_of(video_data, np.ndarray):
|
||||
raise NotImplementedError(
|
||||
"Processing multiple videos is not supported")
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
# adopted from transformers modeling_llava_next_video.py
|
||||
class LlavaNextVideoPooler(nn.Module):
|
||||
@@ -246,11 +245,7 @@ class LlavaNextMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_input_mapper("video")
|
||||
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
|
||||
"video", get_max_llava_next_video_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
|
||||
@MULTIMODAL_REGISTRY.register_processor(LlavaNextVideoMultiModalProcessor)
|
||||
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user