diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index f6ac29877..0d413f115 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -698,6 +698,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + IE+ | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | | `MolmoForCausalLM` | Molmo | T + I+ | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | +| `Molmo2ForConditionalGeneration` | Molmo2 | T + I+ / V | `allenai/Molmo2-4B`, `allenai/Molmo2-8B`, `allenai/Molmo2-O-7B` | ✅︎ | ✅︎ | | `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | | `OpenCUAForConditionalGeneration` | OpenCUA-7B | T + IE+ | `xlangai/OpenCUA-7B` | ✅︎ | ✅︎ | | `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 2d8c6081e..ece830603 100755 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1227,6 +1227,36 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData: ) +# Molmo2 +def run_molmo2(questions: list[str], modality: str) -> ModelRequestData: + model_name = "allenai/Molmo2-8B" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + dtype="bfloat16", + limit_mm_per_prompt={modality: 1}, + max_num_batched_tokens=36864, + ) + + if modality == "image": + placeholder = "<|image|>" + elif modality == "video": + placeholder = "<|video|>" + else: + raise ValueError(f"Unsupported modality for molmo2: {modality}") + + prompts = [ + f"{placeholder}<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n" + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Nemontron_VL def run_nemotron_vl(questions: list[str], modality: str) -> ModelRequestData: model_name = "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1" @@ -1920,6 +1950,7 @@ model_example_map = { "minimax_vl_01": run_minimax_vl_01, "mistral3": run_mistral3, "molmo": run_molmo, + "molmo2": run_molmo2, "nemotron_vl": run_nemotron_vl, "NVLM_D": run_nvlm_d, "ovis": run_ovis, @@ -1949,6 +1980,7 @@ MODELS_NEED_VIDEO_METADATA = [ "glm4_1v", "glm4_5v", "glm4_5v_fp8", + "molmo2", "qwen3_vl", "qwen3_vl_moe", ] diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 2d7aece52..db213d1ff 100755 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -1301,6 +1301,43 @@ def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_molmo2(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "allenai/Molmo2-8B" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + dtype="bfloat16", + limit_mm_per_prompt={"image": len(image_urls)}, + max_num_batched_tokens=36864, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + }, + ] + + processor = AutoProcessor.from_pretrained(model_name) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + image_data = [fetch_image(url) for url in image_urls] + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=image_data, + ) + + model_example_map = { "aria": load_aria, "aya_vision": load_aya_vision, @@ -1323,6 +1360,7 @@ model_example_map = { "llava-next": load_llava_next, "llava-onevision": load_llava_onevision, "mistral3": load_mistral3, + "molmo2": load_molmo2, "NVLM_D": load_nvlm_d, "ovis": load_ovis, "ovis2_5": load_ovis2_5, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index caeafac21..308784564 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -123,6 +123,7 @@ MM_DATA_PATCHES = { "glm4v": glm4_1v_patch_mm_data, "glm4v_moe": glm4_1v_patch_mm_data, "glmasr": glmasr_patch_mm_data, + "molmo2": qwen3_vl_patch_mm_data, "qwen3_vl": qwen3_vl_patch_mm_data, "qwen3_vl_moe": qwen3_vl_patch_mm_data, } diff --git a/tests/models/registry.py b/tests/models/registry.py index a506408a0..62f6c92f4 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -92,6 +92,11 @@ class _HfExamplesInfo: length that is too large to fit into memory in CI. """ + max_num_batched_tokens: int | None = None + """ + The maximum number of tokens to be processed in a single batch. + """ + revision: str | None = None """ The specific revision (commit hash, tag, or branch) to use for the model. @@ -817,6 +822,14 @@ _MULTIMODAL_EXAMPLE_MODELS = { extras={"olmo": "allenai/Molmo-7B-O-0924"}, trust_remote_code=True, ), + "Molmo2ForConditionalGeneration": _HfExamplesInfo( + "allenai/Molmo2-8B", + extras={"olmo": "allenai/Molmo2-O-7B"}, + min_transformers_version="4.51", + trust_remote_code=True, + # required by current PrefixLM implementation + max_num_batched_tokens=31872, + ), "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", trust_remote_code=True), "Llama_Nemotron_Nano_VL": _HfExamplesInfo( "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 61e8c601f..3efa504c7 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -140,6 +140,7 @@ def can_initialize( else None, trust_remote_code=model_info.trust_remote_code, max_model_len=model_info.max_model_len, + max_num_batched_tokens=model_info.max_num_batched_tokens, # these tests seem to produce leftover memory gpu_memory_utilization=0.80, load_format="dummy", diff --git a/vllm/config/model.py b/vllm/config/model.py index 166fc950c..df25e900c 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1127,6 +1127,7 @@ class ModelConfig: """Whether to use bidirectional attention for mm positions.""" MM_PREFIX_LM_MODELS = ( "gemma3", + "molmo2", "paligemma", ) if not hasattr(self.hf_config, "model_type"): diff --git a/vllm/model_executor/models/molmo2.py b/vllm/model_executor/models/molmo2.py new file mode 100644 index 000000000..33ee1d343 --- /dev/null +++ b/vllm/model_executor/models/molmo2.py @@ -0,0 +1,2793 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from collections.abc import Iterable, Mapping, Sequence +from dataclasses import dataclass, fields +from functools import cached_property, partial +from itertools import islice +from typing import Annotated, Any + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import ImageOps +from PIL.Image import Image +from transformers import ( + BatchFeature, + PretrainedConfig, + ProcessorMixin, + TensorType, +) +from transformers.image_utils import ImageInput +from transformers.tokenization_utils_base import TextInput +from transformers.video_utils import VideoInput, VideoMetadata + +from vllm.attention.layer import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import MulAndSilu, SiluAndMul, get_act_fn +from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ( + ImageProcessorItems, + ImageSize, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.utils.math_utils import round_down +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import ( + MultiModalEmbeddings, + SupportsLoRA, + SupportsMultiModal, + SupportsPP, + SupportsQuant, +) +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + _merge_multimodal_embeddings, + extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + +logger = init_logger(__name__) + + +# Special tokens. These should be present in any tokenizer we use +# because the preprocessor relies on them. +IMAGE_PROMPT = "<|image|>" +VIDEO_PROMPT = "<|video|>" +_MAX_VIDEO_FPS = 8 + + +class Molmo2ImageInputs(TensorSchema): + """ + Dimensions: + - nc: The total number of crops (dynamic) + - np: The total number of patches per crop + - cps: Number of channels * patch_size * patch_size + - npp: Number of pooled patches (dynamic) + - pp: pooling_size * pooling_size + - ni: Number of images + - nt: Number of image tokens (dynamic) + """ + + pixel_values: Annotated[torch.Tensor, TensorShape("nc", "np", "cps")] + + token_pooling: Annotated[torch.Tensor, TensorShape("npp", "pp")] + """ + An index tensor that maps image features to their corresponding + patch tokens before pooling. + """ + + num_pooled_patches: Annotated[torch.Tensor, TensorShape("ni")] + + image_tokens: Annotated[torch.BoolTensor, TensorShape("nt")] + + num_image_tokens: Annotated[torch.Tensor, TensorShape("ni")] + + +class Molmo2VideoInputs(TensorSchema): + """ + Dimensions: + - nc: The total number of frames (dynamic) + - np: The total number of patches per frame + - cps: Number of channels * patch_size * patch_size + - npp: Number of pooled patches (dynamic) + - pp: pooling_size * pooling_size + - nv: Number of videos + - nt: Number of video tokens (dynamic) + """ + + pixel_values_videos: Annotated[torch.Tensor, TensorShape("nc", "np", "cps")] + + token_pooling: Annotated[torch.Tensor, TensorShape("npp", "pp")] + """ + An index tensor that maps image features to their corresponding + patch tokens before pooling. + """ + + num_pooled_patches: Annotated[torch.Tensor, TensorShape("nv")] + + video_tokens: Annotated[torch.BoolTensor, TensorShape("nt")] + + num_video_tokens: Annotated[torch.Tensor, TensorShape("nv")] + + +@dataclass +class VitConfig: + """Config for a vision transformer""" + + hidden_size: int = 1152 + intermediate_size: int = 4304 + num_hidden_layers: int = 27 + num_attention_heads: int = 16 + num_key_value_heads: int = 16 + head_dim: int = 72 + hidden_act: str = "gelu_pytorch_tanh" + layer_norm_eps: float = 1e-6 + image_default_input_size: tuple[int, int] = (378, 378) + image_patch_size: int = 14 + image_num_pos: int = 577 + + def __post_init__(self): + self.image_default_input_size = tuple(self.image_default_input_size) # type: ignore[assignment] + + @property + def image_num_patch(self): + h, w = self.image_default_input_size + return h // self.image_patch_size, w // self.image_patch_size + + +@dataclass +class AdapterConfig: + """Config for a vit-llm adapter""" + + vit_layers: tuple[int, int] = (-3, -9) + pooling_attention_mask: bool = False + hidden_size: int = 1152 + num_attention_heads: int = 16 + num_key_value_heads: int = 16 + head_dim: int = 72 + hidden_act: str = "silu" + intermediate_size: int = 18944 + text_hidden_size: int = 3584 + + +@dataclass +class TextConfig: + """Configuration for a text model transformer""" + + hidden_size: int = 3584 + """ + The hidden size of the model. + """ + + num_attention_heads: int = 28 + """ + The number of self-attention heads. + """ + + num_key_value_heads: int = 4 + """ + The number of heads to use for keys and values. + """ + + head_dim: int = 128 + """ + The head dimensionality for the attention mechanism. + """ + + vocab_size: int = 152064 + """Vocabulary size of the model.""" + + additional_vocab_size: int = 128 + """Number of additional tokens to have the input embeddings for""" + + qkv_bias: bool = True + """ + Do QKV projection a bias + """ + + num_hidden_layers: int = 48 + """ + The number of layers/blocks. + """ + + intermediate_size: int = 18944 + """ + The hidden size for the MLP. + """ + + hidden_act: str = "silu" + """ + The activation function to use within the MLP layers. + """ + + max_position_embeddings: int = 4096 + """ + Max positional embeddings to use in RoPE cache + """ + + rope_theta: float = 1000000.0 + """ + RoPE theta parameter. + """ + + use_qk_norm: bool = False + """ + Apply layer norm to the keys and queries within the attention mechanism. + This can help stabilize training. + """ + + qk_norm_type: str = "olmo" + """ + The type of layer norm to use for the keys and queries. + Can be "olmo" or "qwen3". + """ + + layer_norm_eps: float = 1e-6 + """ + epsilon for layer norms + """ + + norm_after: bool = False + """Apply layer norm before and after the attention and MLP blocks.""" + + rope_scaling_layers: tuple[int, ...] | None = None + """ + RoPE scaling layers. + """ + + +class ViTMLP(nn.Module): + """MLP used in Vision Transformer.""" + + def __init__( + self, + dim: int, + hidden_dim: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + ) -> None: + super().__init__() + self.w1 = ColumnParallelLinear( + dim, + hidden_dim, + bias=True, + quant_config=quant_config, + ) + # Activation function. + self.act = get_act_fn(hidden_act) + self.w2 = RowParallelLinear( + hidden_dim, + dim, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.w1(x) + x = self.act(x) + x, _ = self.w2(x) + return x + + +class ViTMultiHeadDotProductAttention(nn.Module): + """Multi-head attention used in Vision Transformer.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_key_value_heads: int, + head_dim: int, + use_bias: bool = True, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.hidden_size = hidden_size + self.total_num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tp_size == 0 + + self.num_heads = self.total_num_heads // tp_size + self.head_dim = head_dim + + assert self.head_dim == self.hidden_size // self.total_num_heads + + self.total_num_kv_heads = num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.merged_qkv = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=use_bias, + quant_config=quant_config, + ) + self.wo = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=use_bias, + quant_config=quant_config, + ) + self.scale = self.head_dim**-0.5 + self.attn = MMEncoderAttention( + self.num_heads, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads, + prefix=f"{prefix}.attn", + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + qkv, _ = self.merged_qkv(inputs) + xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + output = self.attn(xq, xk, xv) + + output, _ = self.wo(output) + + return output + + +class Molmo2VisionBlock(nn.Module): + """Residual attention block used in Vision Transformer.""" + + def __init__( + self, + config: VitConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.attention = ViTMultiHeadDotProductAttention( + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim, + quant_config=quant_config, + prefix=f"{prefix}.attention", + ) + self.feed_forward = ViTMLP( + dim=config.hidden_size, + hidden_dim=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.attention_norm = nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + ) + self.ffn_norm = nn.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attention(self.attention_norm(x)) + x = x + self.feed_forward(self.ffn_norm(x)) + return x + + +class Molmo2VisionBlockCollection(nn.Module): + """Collection of residual attention blocks used in Vision Transformer.""" + + def __init__( + self, + config: VitConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.resblocks = nn.ModuleList( + [ + Molmo2VisionBlock( + config, + quant_config, + prefix=f"{prefix}.resblocks.{layer_idx}", + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + hidden_states = [] + for r in self.resblocks: + x = r(x) + hidden_states.append(x) + return hidden_states + + +class Molmo2VisionTransformer(nn.Module): + """Vision Transformer used in Vision Backbone.""" + + def __init__( + self, + config: VitConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + scale = config.hidden_size**-0.5 + self.num_prefix_tokens: int = 0 # no class embeddings + self.patch_num = config.image_num_patch + self.positional_embedding = nn.Parameter( + torch.randn(config.image_num_pos, config.hidden_size) * scale, + ) + image_patch_size = config.image_patch_size + self.patch_embedding = nn.Linear( + image_patch_size * image_patch_size * 3, + config.hidden_size, + bias=True, + ) + self.transformer = Molmo2VisionBlockCollection( + config, + quant_config, + prefix=f"{prefix}.transformer", + ) + + def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: + pos_emb = self.positional_embedding + + pos_emb = pos_emb.reshape( + ( + int(math.sqrt(pos_emb.shape[0])), + int(math.sqrt(pos_emb.shape[0])), + pos_emb.shape[1], + ) + ) + + (patch_num_0, patch_num_1) = patch_num + + if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1: + # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2) + pos_emb = F.interpolate( + pos_emb, + size=(patch_num_0, patch_num_1), + mode="bicubic", + align_corners=False, + antialias=True, + ) + pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0) + + pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1]) + x = x + pos_emb[None, :, :].to(x.dtype) + return x + + def forward( + self, + x: torch.Tensor, + patch_num: int | None = None, + ) -> list[torch.Tensor]: + """ + : param x: (batch_size, num_patch, n_pixels) + """ + if patch_num is None: + patch_num = self.patch_num + + x = self.patch_embedding(x) + + x = self.add_pos_emb(x, patch_num) + + hidden_states = self.transformer(x) + return hidden_states + + +class ImagePoolingAttention(nn.Module): + """Multi-head attention used for image pooling""" + + def __init__( + self, + input_dim: int, + hidden_size: int, + num_heads: int, + num_key_value_heads: int, + head_dim: int, + use_bias: bool = True, + use_pytorch_sdpa: bool = False, + quant_config: QuantizationConfig | None = None, + ) -> None: + super().__init__() + + self.input_dim = input_dim + self.hidden_size = hidden_size + self.total_num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tp_size == 0 + + self.num_heads = self.total_num_heads // tp_size + self.head_dim = head_dim + + assert self.head_dim == self.hidden_size // self.total_num_heads + + self.total_num_kv_heads = num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.kv_size = self.num_kv_heads * self.head_dim + + self.q_proj = ColumnParallelLinear( + self.input_dim, + self.total_num_heads * self.head_dim, + bias=use_bias, + quant_config=quant_config, + ) + self.merged_kv = MergedColumnParallelLinear( + self.input_dim, + [self.total_num_kv_heads * self.head_dim] * 2, + bias=use_bias, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=use_bias, + quant_config=quant_config, + ) + self.scale = self.head_dim**-0.5 + self.use_pytorch_sdpa = use_pytorch_sdpa + if use_pytorch_sdpa: + self.attn = None + else: + self.attn = MMEncoderAttention( + self.num_heads, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads, + ) + + def forward_sdpa( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + bsz, q_len, _ = query.size() + kv_len = key.size(1) + + query = query.view(bsz, q_len, self.num_heads, self.head_dim) + key = key.view(bsz, kv_len, self.num_kv_heads, self.head_dim) + value = value.view(bsz, kv_len, self.num_kv_heads, self.head_dim) + + if self.num_heads != self.num_kv_heads: + key = torch.repeat_interleave( + key, + self.num_heads // self.num_kv_heads, + dim=2, + ) + value = torch.repeat_interleave( + value, + self.num_heads // self.num_kv_heads, + dim=2, + ) + + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) + + out = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + is_causal=False, + ).transpose(1, 2) + + return out.reshape(bsz, q_len, -1) + + def forward( + self, + inputs_q: torch.Tensor, + inputs_kv: torch.Tensor, + attn_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + xq, _ = self.q_proj(inputs_q) + kv, _ = self.merged_kv(inputs_kv) + xk, xv = kv.split([self.kv_size, self.kv_size], dim=-1) + + if self.use_pytorch_sdpa: + output = self.forward_sdpa(xq, xk, xv, attn_mask) + else: + output = self.attn(xq, xk, xv) + + output, _ = self.o_proj(output) + + return output + + +class ImageProjectorMLP(nn.Module): + """MLP used for the image projector""" + + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + ) -> None: + super().__init__() + + self.merged_linear = MergedColumnParallelLinear( + input_dim, + [hidden_dim] * 2, + bias=False, + quant_config=quant_config, + ) + # Activation function. + assert hidden_act == "silu" + self.act_fn = SiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + hidden_dim, + output_dim, + bias=False, + quant_config=quant_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.merged_linear(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class Molmo2VisionBackbone(nn.Module, SupportsQuant): + packed_modules_mapping = { + "merged_qkv": ["wq", "wk", "wv"], # vision backbone + "merged_kv": ["k_proj", "v_proj"], # image_pooling_2d + "merged_linear": ["gate_proj", "up_proj"], + } + + def __init__( + self, + vit_config: VitConfig, + adapter_config: AdapterConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.vit_config = vit_config + self.adapter_config = adapter_config + + self.vit_layers = [] + for layer in adapter_config.vit_layers: + if layer >= 0: + self.vit_layers.append(layer) + else: + self.vit_layers.append(layer + vit_config.num_hidden_layers) + + last_layer_needed = max(self.vit_layers) + 1 + if last_layer_needed < vit_config.num_hidden_layers: + vit_config.num_hidden_layers = last_layer_needed + self.image_vit = Molmo2VisionTransformer( + vit_config, + quant_config, + prefix=f"{prefix}.image_vit", + ) + + self.num_prefix_tokens: int = self.image_vit.num_prefix_tokens + + pool_dim = vit_config.hidden_size * len(adapter_config.vit_layers) + self.image_pooling_2d = ImagePoolingAttention( + input_dim=pool_dim, + hidden_size=adapter_config.hidden_size, + num_heads=adapter_config.num_attention_heads, + num_key_value_heads=adapter_config.num_key_value_heads, + head_dim=adapter_config.head_dim, + use_pytorch_sdpa=adapter_config.pooling_attention_mask, + quant_config=quant_config, + ) + self.image_projector = ImageProjectorMLP( + input_dim=adapter_config.hidden_size, + hidden_dim=adapter_config.intermediate_size, + output_dim=adapter_config.text_hidden_size, + hidden_act=adapter_config.hidden_act, + quant_config=quant_config, + ) + + @property + def dtype(self) -> torch.dtype: + return self.image_vit.patch_embedding.weight.dtype + + @property + def device(self) -> torch.device: + return self.image_vit.patch_embedding.weight.device + + def encode_image(self, images: torch.Tensor) -> torch.Tensor: + """ + : param images: (batch_size, num_crops, num_patch, n_pixels) + """ + B, T, N, D = images.shape + images = images.view(B * T, N, D) + image_features = self.image_vit(images) + + features = [] + for layer in self.vit_layers: + features.append(image_features[layer]) + image_features = torch.cat(features, dim=-1) + + if self.num_prefix_tokens > 0: + image_features = image_features[:, 1:] + image_features = image_features.view(B, T, N, -1) + return image_features + + def forward( + self, + images: torch.Tensor, + token_pooling: torch.Tensor, + ) -> torch.Tensor: + # image_features shape: + # (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) + batch_size, num_image = images.shape[:2] + images = images.to(device=self.device, dtype=self.dtype) + image_features = self.encode_image(images) + + dim = image_features.shape[-1] + valid = token_pooling >= 0 + valid_token = torch.any(valid, -1) + + # Use `token_pooling` to arange the features for image pooling + batch_idx = torch.arange( + token_pooling.shape[0], + dtype=torch.long, + device=token_pooling.device, + ) + batch_idx = torch.tile( + batch_idx.view(batch_size, 1, 1), + [1, token_pooling.shape[1], token_pooling.shape[2]], + ) + + # Now [batch, num_features, num_pooled_patches, dim] + to_pool = image_features.reshape(batch_size, -1, dim)[ + batch_idx, torch.clip(token_pooling, 0) + ] + to_pool = to_pool * valid.to(self.dtype)[:, :, :, None] + to_pool = to_pool.reshape([-1, token_pooling.shape[-1], dim]) + if self.adapter_config.pooling_attention_mask: + attn_mask = valid.reshape([-1, 1, 1, valid.shape[-1]]) + denom = valid.view(-1, to_pool.shape[-2]).float().sum(-1) + denom = torch.where(denom == 0, 1, denom) + query = to_pool.sum(-2, keepdim=True) / denom[:, None, None].to( + to_pool.dtype + ) + else: + attn_mask = None + query = to_pool.mean(-2, keepdim=True) + + pooled_features = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask) + pooled_features = pooled_features.reshape( + [batch_size, -1, pooled_features.shape[-1]] + ) + + # MLP layer to map the feature. + pooled_features = self.image_projector(pooled_features) + return pooled_features.view(-1, pooled_features.shape[-1])[ + valid_token.flatten() + ] + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("merged_qkv", "wq", "q"), + ("merged_qkv", "wk", "k"), + ("merged_qkv", "wv", "v"), + ("merged_kv", "k_proj", 0), + ("merged_kv", "v_proj", 1), + ("merged_linear", "gate_proj", 0), + ("merged_linear", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Molmo2Attention(nn.Module): + """Molmo2's LLM Attention.""" + + def __init__( + self, + config: TextConfig, + rope_parameters: dict[str, Any], + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % self.tp_size == 0 + + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= self.tp_size: + assert self.total_num_kv_heads % self.tp_size == 0 + else: + assert self.tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.head_dim = config.head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) + + self.tp_rank: int | None = None + self.k_norm: nn.Module | None = None + self.q_norm: nn.Module | None = None + self.qk_norm_type: str | None = None + if config.use_qk_norm: + k_norm_size = ( + self.head_dim + if config.qk_norm_type == "qwen3" + else self.total_num_kv_heads * self.head_dim + ) + self.tp_rank = get_tensor_model_parallel_rank() + self.k_norm = RMSNorm(k_norm_size, eps=config.layer_norm_eps) + q_norm_size = ( + self.head_dim + if config.qk_norm_type == "qwen3" + else self.total_num_heads * self.head_dim + ) + self.q_norm = RMSNorm(q_norm_size, eps=config.layer_norm_eps) + self.qk_norm_type = config.qk_norm_type + # Rotary embeddings. Rope scaling is only applied on full attention layers. + layer_idx = extract_layer_index(prefix) + if ( + config.rope_scaling_layers is not None + and layer_idx not in config.rope_scaling_layers + ): + rope_theta = rope_parameters["rope_theta"] + rope_parameters = {"rope_type": "default", "rope_theta": rope_theta} + self.rotary_emb = get_rope( + self.head_dim, + max_position=self.max_position_embeddings, + rope_parameters=rope_parameters, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + # Attention output projection. + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def _apply_qk_norm( + self, + q: torch.Tensor, + k: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm(q) + k = self.k_norm(k) + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + **kwargs: object, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if ( + self.q_norm is not None + and self.k_norm is not None + and self.qk_norm_type == "olmo" + ): + q, k = self._apply_qk_norm(q, k) + elif self.q_norm is not None and self.k_norm is not None: + q_by_head = q.view( + *q.shape[:-1], + q.shape[-1] // self.head_dim, + self.head_dim, + ) + q_by_head = self.q_norm(q_by_head) + q = q_by_head.view(q.shape) + k_by_head = k.view( + *k.shape[:-1], + k.shape[-1] // self.head_dim, + self.head_dim, + ) + k_by_head = self.k_norm(k_by_head) + k = k_by_head.view(k.shape) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + + output, _ = self.o_proj(attn_output) + return output + + +class LanguageModelMLP(nn.Module): + """Molmo2's LLM mlp.""" + + def __init__( + self, + input_dim: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + ) -> None: + super().__init__() + + self.up_gate_proj = MergedColumnParallelLinear( + input_dim, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + # Activation function. + assert hidden_act == "silu" + self.act_fn = MulAndSilu() + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + intermediate_size, + input_dim, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + up_gate, _ = self.up_gate_proj(x) + x = self.act_fn(up_gate) + x, _ = self.down_proj(x) + return x + + +class Molmo2DecoderLayer(nn.Module): + def __init__( + self, + config: TextConfig, + rope_parameters: dict[str, Any], + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + # Attention block. + self.self_attn = Molmo2Attention( + config, + rope_parameters, + cache_config, + quant_config, + prefix=f"{prefix}.self_attn", + ) + + # MLP block. + self.mlp = LanguageModelMLP( + config.hidden_size, + config.intermediate_size, + config.hidden_act, + quant_config, + ) + + # LayerNorm + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, + eps=config.layer_norm_eps, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + **kwargs: object, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + **kwargs, + ) + + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Molmo2DecoderNormAfterLayer(Molmo2DecoderLayer): + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + **kwargs: object, + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: + # Self Attention + residual = hidden_states + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + **kwargs, + ) + + hidden_states = self.input_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = hidden_states + + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = None + return hidden_states, residual + + +@support_torch_compile +class Molmo2TextModel(nn.Module, SupportsQuant): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + + if hasattr(config, "text_config"): + hf_text_config = config.text_config + else: + hf_text_config = config.llm_config + + kwargs = {} + for field in fields(TextConfig): + kwargs[field.name] = getattr(hf_text_config, field.name) + text_config = TextConfig(**kwargs) + + self.embedding_size = text_config.vocab_size + self.embedding_size += text_config.additional_vocab_size or 0 + self.embed_tokens = VocabParallelEmbedding( + self.embedding_size, + text_config.hidden_size, + quant_config=quant_config, + ) + + decoder_layer = ( + Molmo2DecoderNormAfterLayer + if text_config.norm_after + else Molmo2DecoderLayer + ) + self.start_layer, self.end_layer, self.layers = make_layers( + text_config.num_hidden_layers, + lambda prefix: decoder_layer( + text_config, + hf_text_config.rope_parameters, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + + self.norm = RMSNorm(text_config.hidden_size, eps=text_config.layer_norm_eps) + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], + text_config.hidden_size, + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + # Apply blocks one-by-one. + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions, + hidden_states, + residual, + **kwargs, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + if residual is not None: + hidden_states, _ = self.norm(hidden_states, residual) + else: + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +def get_patches_grid_size( + *, + image_h: int, + image_w: int, + patch_size: int, + pool_h: int, + pool_w: int, +) -> tuple[int, int]: + patch_h = image_h // patch_size + patch_w = image_w // patch_size + h_pad = round_down(patch_h + pool_h - 1, pool_h) - patch_h + w_pad = round_down(patch_w + pool_w - 1, pool_w) - patch_w + nrows = (patch_h + h_pad) // pool_h + ncols = (patch_w + w_pad) // pool_w + + return nrows, ncols + + +def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]: + tilings = [ + (i, j) + for i in range(1, max_num + 1) + for j in range(1, max_num + 1) + if i * j <= max_num + ] + return sorted(tilings, key=lambda x: (x[0] * x[1], x[0])) + + +def select_tiling( + *, + height: int, + width: int, + patch_size: int, + max_num_patches: int, +): + tilings = get_candidate_tilings(max_num_patches) + candidate_tilings = np.array(tilings, dtype=np.int32) + candidate_resolutions = candidate_tilings * patch_size + + original_size = np.array([height, width], dtype=np.float32) + required_scale_d = candidate_resolutions.astype(np.float32) / original_size + required_scale = required_scale_d.min(axis=-1, keepdims=True) + + if (required_scale < 1).all(): + ix = required_scale.argmax() + else: + ix = np.where(required_scale < 1.0, 10e9, required_scale).argmin() + + return candidate_tilings[ix] + + +def get_image_size(image: ImageInput) -> ImageSize: + if isinstance(image, Image): + return ImageSize(*image.size) + elif isinstance(image, (np.ndarray, torch.Tensor)): + assert image.ndim == 3 + h, w, c = image.shape + assert c in [1, 3] + return ImageSize(w, h) + else: + raise ValueError(f"Unknown image type: {type(image)}") + + +def exif_tranpose( + images: ImageInput | None, +) -> ImageInput | None: + if images is None: + return None + if images is not None and isinstance(images, (list, tuple)): + images = [ + exif_tranpose(img) if isinstance(img, Image) else img for img in images + ] + elif images is not None and isinstance(images, Image): + images = ImageOps.exif_transpose(images) + return images + + +def build_flat_image_bool_length( + image_grids: torch.LongTensor, + image_patch_id: int, + low_res_image_start_id: int, + image_start_id: int, + image_col_id: int, + image_end_id: int, +) -> tuple[torch.LongTensor, torch.LongTensor]: + device = image_grids.device + B = image_grids.shape[0] + + resized_h = image_grids[:, 0] + resized_w = image_grids[:, 1] + h = image_grids[:, 2] + w = image_grids[:, 3] + + lengths = resized_h * resized_w + h * (w + 1) + 4 # [B] + total_len = int(lengths.sum().item()) + + flat = torch.empty(total_len, dtype=torch.long, device=device) + + offset = 0 + for i in range(B): + resized_h_i, resized_w_i, h_i, w_i = image_grids[i].tolist() + L_i = int(lengths[i].item()) + + num_low_res_patches = resized_h_i * resized_w_i + + idx = offset + + flat[idx] = low_res_image_start_id + idx += 1 + + if num_low_res_patches > 0: + flat[idx : idx + num_low_res_patches] = image_patch_id + idx += num_low_res_patches + + flat[idx] = image_end_id + idx += 1 + + flat[idx] = image_start_id + idx += 1 + + block_len = w_i + 1 + if block_len > 0 and h_i > 0: + line = torch.empty(block_len, dtype=torch.long, device=device) + if w_i > 0: + line[:w_i] = image_patch_id + line[w_i] = image_col_id + + block = line.repeat(h_i) + flat[idx : idx + h_i * block_len] = block + idx += h_i * block_len + + flat[idx] = image_end_id + idx += 1 + + assert idx - offset == L_i + + offset += L_i + + return flat, lengths + + +def build_flat_video_bool_length( + video_grids: torch.LongTensor, + image_patch_id: int, + frame_start_id: int, + frame_end_id: int, +) -> tuple[torch.LongTensor, torch.LongTensor]: + device = video_grids.device + B = video_grids.shape[0] + + t = video_grids[:, 0] + resized_h = video_grids[:, 1] + resized_w = video_grids[:, 2] + + P = resized_h * resized_w + block_len = P + 2 + lengths = t * block_len + + total_len = int(lengths.sum().item()) + flat = torch.empty(total_len, dtype=torch.long, device=device) + + offset = 0 + for i in range(B): + ti = int(t[i].item()) + Pi = int(P[i].item()) + Li = int(lengths[i].item()) + + block = torch.empty(Pi + 2, dtype=torch.long, device=device) + block[0] = frame_start_id + if Pi > 0: + block[1 : 1 + Pi] = image_patch_id + block[-1] = frame_end_id + + seq = block.repeat(ti) + + flat[offset : offset + Li] = seq + offset += Li + + return flat, lengths + + +class Molmo2ProcessorWrapper: + """ + Wraps :class:`Molmo2Processor` so that it can be called directly. + """ + + def __init__(self, processor: ProcessorMixin, hf_config: PretrainedConfig): + super().__init__() + + self.processor = processor + self.hf_config = hf_config + + @cached_property + def vocab(self) -> dict[str, int]: + return self.processor.tokenizer.vocab # type: ignore + + @cached_property + def max_crops(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + max_crops = image_processor.max_crops + assert isinstance(max_crops, int) + + return max_crops + + @cached_property + def image_pooling_h(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + image_pooling_h = image_processor.pooling_size[0] + assert isinstance(image_pooling_h, int) + + return image_pooling_h + + @cached_property + def image_pooling_w(self) -> int: + image_processor = self.processor.image_processor # type: ignore + + image_pooling_w = image_processor.pooling_size[1] + assert isinstance(image_pooling_w, int) + + return image_pooling_w + + @cached_property + def video_pooling_h(self) -> int: + video_processor = self.processor.video_processor # type: ignore + + video_pooling_h = video_processor.pooling_size[0] + assert isinstance(video_pooling_h, int) + + return video_pooling_h + + @cached_property + def video_pooling_w(self) -> int: + video_processor = self.processor.video_processor # type: ignore + + video_pooling_w = video_processor.pooling_size[1] + assert isinstance(video_pooling_w, int) + + return video_pooling_w + + @cached_property + def base_image_input_size(self) -> tuple[int, int]: + if getattr(self.processor, "image_processor", None) is not None: + processor = self.processor.image_processor # type: ignore + else: + processor = self.processor.video_processor # type: ignore + + base_image_input_size = (processor.size["height"], processor.size["width"]) + + return base_image_input_size + + @cached_property + def image_patch_size(self) -> int: + if getattr(self.processor, "image_processor", None) is not None: + processor = self.processor.image_processor # type: ignore + else: + processor = self.processor.video_processor # type: ignore + + image_patch_size = processor.patch_size + assert isinstance(image_patch_size, int) + + return image_patch_size + + @cached_property + def overlap_margins(self) -> tuple[int, int]: + image_processor = self.processor.image_processor # type: ignore + + left_margin, right_margin = image_processor.overlap_margins + assert isinstance(left_margin, int) + assert isinstance(right_margin, int) + + return left_margin, right_margin + + @cached_property + def bos_token(self) -> str: + return self.processor.tokenizer.bos_token or self.processor.tokenizer.eos_token + + @cached_property + def image_patch_id(self) -> int: + return self.hf_config.image_patch_id + + @cached_property + def im_col_id(self) -> int: + return self.hf_config.image_col_id + + @cached_property + def im_start_id(self) -> int: + return self.hf_config.image_start_token_id + + @cached_property + def im_end_id(self) -> int: + return self.hf_config.image_end_token_id + + @cached_property + def low_res_im_start_id(self) -> int: + return self.hf_config.low_res_image_start_token_id + + @cached_property + def frame_start_id(self) -> int: + return self.hf_config.frame_start_token_id + + @cached_property + def frame_end_id(self) -> int: + return self.hf_config.frame_end_token_id + + @cached_property + def im_low_res_id(self) -> int: + return self.hf_config.image_low_res_id + + @cached_property + def image_placeholder_id(self) -> int: + return self.vocab[IMAGE_PROMPT] + + @cached_property + def video_placeholder_id(self) -> int: + return self.vocab[VIDEO_PROMPT] + + @cached_property + def image_token_ids(self) -> list[int]: + return [ + self.image_patch_id, + self.im_col_id, + self.im_start_id, + self.low_res_im_start_id, + self.frame_start_id, + self.im_end_id, + self.frame_end_id, + self.im_low_res_id, + ] + + def select_tiling( + self, + *, + image_height: int, + image_width: int, + ) -> tuple[int, int]: + max_crops = self.max_crops + left_margin, right_margin = self.overlap_margins + base_image_input_size = self.base_image_input_size + base_image_input_d = self.image_patch_size + + total_margin_pixels = base_image_input_d * (right_margin + left_margin) + crop_patches = base_image_input_size[0] // base_image_input_d + crop_window_patches = crop_patches - (right_margin + left_margin) + crop_window_size = crop_window_patches * base_image_input_d + tiling_h, tiling_w = select_tiling( + height=image_height - total_margin_pixels, + width=image_width - total_margin_pixels, + patch_size=crop_window_size, + max_num_patches=max_crops, + ) + + return tiling_h, tiling_w + + def get_base_grid_size(self, is_video: bool) -> tuple[int, int]: + base_image_input_size = self.base_image_input_size + + return get_patches_grid_size( + image_h=base_image_input_size[0], + image_w=base_image_input_size[1], + patch_size=self.image_patch_size, + pool_h=self.video_pooling_h if is_video else self.image_pooling_h, + pool_w=self.video_pooling_w if is_video else self.image_pooling_w, + ) + + def get_patches_grid_size( + self, + *, + image_height: int, + image_width: int, + ) -> tuple[int, int]: + left_margin, right_margin = self.overlap_margins + base_image_input_size = self.base_image_input_size + base_image_input_d = self.image_patch_size + + total_margin_pixels = base_image_input_d * (right_margin + left_margin) + crop_patches = base_image_input_size[0] // base_image_input_d + crop_window_patches = crop_patches - (right_margin + left_margin) + crop_window_size = crop_window_patches * base_image_input_d + + tiling_h, tiling_w = self.select_tiling( + image_height=image_height, + image_width=image_width, + ) + + h, w = [ + tiling_h * crop_window_size + total_margin_pixels, + tiling_w * crop_window_size + total_margin_pixels, + ] + nrows, ncols = get_patches_grid_size( + image_h=h, + image_w=w, + patch_size=base_image_input_d, + pool_h=self.image_pooling_h, + pool_w=self.image_pooling_w, + ) + + return nrows, ncols + + def __call__( + self, + text: TextInput | list[TextInput] | None = None, + images: ImageInput | None = None, + videos: VideoInput | None = None, + return_tensors: str | TensorType = None, + **kwargs: object, + ) -> BatchFeature: + inputs = [text] + images = exif_tranpose(images) + if getattr(self.processor, "image_processor", None) is not None: + inputs.append(images) + if getattr(self.processor, "video_processor", None) is not None: + inputs.append(videos) + outputs = self.processor( # type: ignore + *inputs, + return_tensors=return_tensors, + **kwargs, + ) + + # revert insert bos token + if outputs["input_ids"][0, 0] == self.vocab[self.bos_token]: + outputs["input_ids"] = outputs["input_ids"][:, 1:] + + if images is None: + images = [] + if not isinstance(images, list): + images = [images] + + if videos is None: + videos = [] + if not isinstance(videos, list): + videos = [videos] + + assert len(videos) in {0, 1}, "At most one video is supported for Molmo2" + + _attention_mask: torch.Tensor = outputs.pop("attention_mask") + _token_type_ids: torch.Tensor = outputs.pop("token_type_ids", None) + + if len(images) > 0: + # For each image: tiling_h * tiling_w + global view + num_crops = [] + for image in images: + image_size = get_image_size(image) + tiling = self.select_tiling( + image_height=image_size.height, + image_width=image_size.width, + ) + num_crops.append(np.prod(tiling) + 1) + + assert sum(num_crops) == len(outputs["pixel_values"]) + assert sum(num_crops) == outputs["image_num_crops"].sum().item() + image_grids: torch.Tensor = outputs.pop("image_grids") + image_num_pooled_patches: torch.Tensor = image_grids[:, :2].prod( + dim=1 + ) + image_grids[:, 2:].prod(dim=1) + outputs["image_num_pooled_patches"] = image_num_pooled_patches + n_patches = outputs["pixel_values"].shape[1] + outputs["image_num_patches"] = outputs["image_num_crops"] * n_patches + image_tokens, num_image_tokens = build_flat_image_bool_length( + image_grids, + self.image_patch_id, + self.low_res_im_start_id, + self.im_start_id, + self.im_col_id, + self.im_end_id, + ) + outputs["image_tokens"] = image_tokens + outputs["num_image_tokens"] = num_image_tokens + + if len(videos) > 0: + video_grids: torch.Tensor = outputs.pop("video_grids") + assert video_grids[:, 0].sum() == len(outputs["pixel_values_videos"]) + outputs["video_num_crops"] = video_grids[:, 0] + outputs["video_num_pooled_patches"] = video_grids.prod(dim=1) + n_patches = outputs["pixel_values_videos"].shape[1] + outputs["video_num_patches"] = outputs["video_num_crops"] * n_patches + video_tokens, num_video_tokens = build_flat_video_bool_length( + video_grids, + self.image_patch_id, + self.frame_start_id, + self.frame_end_id, + ) + outputs["video_tokens"] = video_tokens + outputs["num_video_tokens"] = num_video_tokens + + return BatchFeature(outputs) + + +def get_candidate_target_fps( + video_fps: int | float, + sampling_fps: int | float, + max_fps: int | float = _MAX_VIDEO_FPS, +) -> list[float]: + """ + Return the subset of `video_fps` factors that remain multiples + of `sampling_fps`. + + Examples: + >>> get_candidate_target_fps(video_fps=6, sampling_fps=2) + [2, 6] + >>> get_candidate_target_fps(video_fps=5, sampling_fps=1) + [1, 5] + >>> get_candidate_target_fps(video_fps=2, sampling_fps=2) + [2] + >>> get_candidate_target_fps(video_fps=5, sampling_fps=2) + Traceback (most recent call last): + ... + ValueError: sampling_fps=2 must divide video_fps=5 to produce + consistent frame steps. + """ + video_fps = int(video_fps) + sampling_fps = int(sampling_fps) + max_fps = int(max_fps) + + if sampling_fps is None: + raise ValueError("sampling_fps must be provided") + if video_fps <= 0 or sampling_fps <= 0: + raise ValueError( + "video_fps and sampling_fps must be positive " + f"(got {video_fps}, {sampling_fps})" + ) + if video_fps % sampling_fps != 0: + raise ValueError( + f"sampling_fps={sampling_fps} must divide video_fps={video_fps}." + ) + + candidates = [] + for candidate in range(sampling_fps, video_fps + 1, sampling_fps): + if candidate > max_fps: + break + if video_fps % candidate == 0: + candidates.append(float(candidate)) + + return candidates + + +def get_target_fps( + video_fps: float, + max_frames: int, + total_frames: int, + frame_sample_mode: str, + candidate_target_fps: list[float], +) -> float | None: + """ + Get the target fps that best spans the video and has the most frames sampled + """ + num_frames_sampled = 0 + selected_target_fps = None + for target_fps in candidate_target_fps: + step_size = max(int(video_fps / target_fps), 1) + num_frames_sampled_at_fps = int(total_frames / step_size) + if num_frames_sampled == 0: + if ( + "uniform" in frame_sample_mode + and num_frames_sampled_at_fps > max_frames + ): + break + selected_target_fps = target_fps + num_frames_sampled = num_frames_sampled_at_fps + + else: + # the candidate sampling fps increases so frame count can't decrease + assert num_frames_sampled <= num_frames_sampled_at_fps + if num_frames_sampled_at_fps > max_frames: + # choose the sampling fps that spans the video + continue + + elif num_frames_sampled_at_fps > num_frames_sampled: + # both are less than max_frames; choose the one with higher + # density of frames sampled + selected_target_fps = target_fps + num_frames_sampled = num_frames_sampled_at_fps + return selected_target_fps + + +def get_frame_times_and_chosen_fps( + selected_target_fps, total_frames, max_frames, video_fps +): + if selected_target_fps is None: + frame_indices = np.linspace( + 0, total_frames, max_frames, endpoint=False, dtype=int + ) + else: + step_size = max(int(video_fps / selected_target_fps), 1) + frame_indices = np.arange(0, total_frames, step_size) + if len(frame_indices) > max_frames: + frame_indices = frame_indices[:max_frames] + return selected_target_fps, frame_indices + + +class Molmo2ProcessingInfo(BaseProcessingInfo): + def get_hf_processor(self, **kwargs: object) -> Molmo2ProcessorWrapper: + processor = self.ctx.get_hf_processor(**kwargs) + hf_config = self.ctx.get_hf_config() + return Molmo2ProcessorWrapper(processor, hf_config) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": None, "video": 1} + + def get_num_image_tokens( + self, + *, + image_height: int, + image_width: int, + processor: Molmo2ProcessorWrapper | None = None, + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + hf_processor = processor.processor # type: ignore + + resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False) + # start/end tokens + image patch token + col tokens + if hf_processor.use_single_crop_col_tokens is not None: + use_col_tokens = hf_processor.use_single_crop_col_tokens + else: + use_col_tokens = hf_processor.image_use_col_tokens + extra = 2 + resize_nrows * (resize_cols + int(use_col_tokens)) + overlap_nrows, overlap_ncols = processor.get_patches_grid_size( + image_height=image_height, + image_width=image_width, + ) + joint = 2 + overlap_nrows * ( + overlap_ncols + int(hf_processor.image_use_col_tokens) + ) + + return extra + joint + + def get_num_video_tokens( + self, + *, + num_frames: int, + processor: Molmo2ProcessorWrapper | None = None, + ) -> int: + if processor is None: + processor = self.get_hf_processor() + + resize_nrows, resize_cols = processor.get_base_grid_size(is_video=True) + # start/end tokens + extra = 2 + resize_nrows * ( + resize_cols + int(processor.processor.video_use_col_tokens) + ) + return num_frames * extra + + def get_image_size_with_most_features(self) -> ImageSize: + processor = self.get_hf_processor() + + left_margin, right_margin = processor.overlap_margins + base_image_input_size = processor.base_image_input_size + base_image_input_d = processor.image_patch_size + + total_margin_pixels = base_image_input_d * (right_margin + left_margin) + crop_patches = base_image_input_size[0] // base_image_input_d + crop_window_patches = crop_patches - (right_margin + left_margin) + crop_window_size = crop_window_patches * base_image_input_d + + tilings = get_candidate_tilings(processor.max_crops) + largest_feature_size, largest_feature_pinpoint = 0, None + + for hr, wr in tilings: + height = hr * crop_window_size + total_margin_pixels + width = wr * crop_window_size + total_margin_pixels + + feat_size = self.get_num_image_tokens( + image_height=height, image_width=width, processor=processor + ) + if feat_size > largest_feature_size: + largest_feature_size = feat_size + largest_feature_pinpoint = ImageSize(width=width, height=height) + + if largest_feature_size == 0 or largest_feature_pinpoint is None: + raise ValueError("Cannot have a largest feature size of 0!") + + return largest_feature_pinpoint + + def _get_max_video_frames(self, max_tokens: int) -> int: + num_tokens_per_frame = self.get_num_video_tokens(num_frames=1) + max_frames = max_tokens // num_tokens_per_frame + return max(max_frames, 1) + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + video_processor = self.get_hf_processor().processor.video_processor + num_frames = video_processor.num_frames + max_videos = mm_counts.get("video", 0) + max_total_frames = self._get_max_video_frames(seq_len) + max_frames_per_video = min( + max_total_frames // max(max_videos, 1), + num_frames, + ) + return max(max_frames_per_video, 1) + + def _sample_frames( + self, + total_num_frames: int, + video_fps: float, + duration: float, + frame_sample_mode: str, + num_frames: int, + max_fps: int, + sampling_fps: int, + ) -> np.ndarray: + if frame_sample_mode == "uniform_last_frame" and max_fps is not None: + if total_num_frames <= 2: + indices = np.arange(total_num_frames).astype(int) + elif duration > (num_frames - 1) / max_fps: # -1 to include the last frame + # uniform fallback + indices = np.linspace( + 0, + total_num_frames - 1, + num=min(num_frames, total_num_frames), + endpoint=True, + ).astype(int) + else: + float_indices = np.arange( + 0.0, + stop=total_num_frames - 1, + step=float(video_fps / max_fps), + ) + if np.round(float_indices[-1]) != total_num_frames - 1: + float_indices = np.concatenate( + [float_indices, [total_num_frames - 1]], axis=0 + ) + indices = np.round(float_indices).astype(int) + assert indices[-1] < total_num_frames + assert len(float_indices) <= num_frames + elif frame_sample_mode == "uniform_last_frame": + indices = np.linspace( + 0, + total_num_frames - 1, + num=min(num_frames, total_num_frames), + endpoint=True, + ).astype(int) + elif frame_sample_mode == "fps": + candidate_target_fps = get_candidate_target_fps(video_fps, sampling_fps) + selected_target_fps = get_target_fps( + video_fps, + num_frames, + total_num_frames, + frame_sample_mode, + candidate_target_fps, + ) + _, indices = get_frame_times_and_chosen_fps( + selected_target_fps, + total_num_frames, + num_frames, + video_fps, + ) + else: + raise NotImplementedError(frame_sample_mode) + + return indices + + def _get_video_second_idx( + self, + metadata: dict[str, Any], + do_sample_frames: bool | None = None, + ) -> list[float]: + video_processor = self.get_hf_processor().processor.video_processor + # metadata["fps"] refers to the true fps of the input video. + video_fps = metadata["fps"] + frames_indices = metadata.get("frames_indices") + if do_sample_frames is None: + do_sample_frames = metadata.get("do_sample_frames", False) + + if do_sample_frames: + # Frame-based sampling is applied in HF video processor + total_num_frames = metadata["total_num_frames"] + duration = total_num_frames / video_fps + frame_sample_mode = video_processor.frame_sample_mode + num_frames = video_processor.num_frames + max_fps = video_processor.max_fps + sampling_fps = video_processor.sampling_fps + frames_indices = self._sample_frames( + total_num_frames, + video_fps, + duration, + frame_sample_mode, + num_frames, + max_fps, + sampling_fps, + ) + else: + # Time-based sampling is done in vllm molmo2 video loader or molmo_utils + assert frames_indices is not None + timestamps = [frame_idx / video_fps for frame_idx in frames_indices] + return timestamps + + +class Molmo2DummyInputsBuilder(BaseDummyInputsBuilder[Molmo2ProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + image_placeholder_token = IMAGE_PROMPT + video_placeholder_token = VIDEO_PROMPT + + if num_images == 1: + image_string = image_placeholder_token + else: + image_string = "".join( + [f"Image {i + 1}" + image_placeholder_token for i in range(num_images)] + ) + + return image_string + video_placeholder_token * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + dummy_images = [] + dummy_videos = [] + + if num_images > 0: + target_width, target_height = self.info.get_image_size_with_most_features() + + image_overrides = mm_options.get("image") if mm_options else None + + dummy_images = self._get_dummy_images( + width=target_width, + height=target_height, + num_images=num_images, + overrides=image_overrides, + ) + + if num_videos > 0: + processor = self.info.get_hf_processor() + base_image_input_size = processor.base_image_input_size + target_num_frames = self.info.get_num_frames_with_most_features( + seq_len, mm_counts + ) + + video_overrides = mm_options.get("video") if mm_options else None + + if video_overrides: + assert isinstance(video_overrides, VideoDummyOptions) + num_frames_override = video_overrides.num_frames + if num_frames_override: + if num_frames_override > target_num_frames: + logger.warning( + "video.num_frames override (%d) exceeds model's " + "maximum number of frames (%d), will be ignored", + num_frames_override, + target_num_frames, + ) + if num_frames_override < 2: + logger.warning( + "video.num_frames override (%d) cannot be less " + "than 2, will be ignored", + num_frames_override, + ) + target_num_frames = min(target_num_frames, num_frames_override) + + dummy_videos = self._get_dummy_videos( + width=base_image_input_size[1], + height=base_image_input_size[0], + num_frames=target_num_frames, + num_videos=num_videos, + ) + + return { + "image": dummy_images, + "video": dummy_videos, + } + + def _get_dummy_videos( + self, + *, + width: int, + height: int, + num_frames: int, + num_videos: int, + ) -> list[VideoItem]: + video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8) + video_items = [] + for i in range(num_videos): + video_metadata = { + "fps": 2.0, + "duration": num_frames / 2.0, + "total_num_frames": num_frames, + "frames_indices": list(range(num_frames)), + "video_backend": "decord", + "do_sample_frames": False, + "height": height, + "width": width, + } + video_item = (video.copy(), video_metadata) + video_items.append(video_item) + return video_items + + +class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]): + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + processor = self.info.get_hf_processor() + tokenizer = processor.processor.tokenizer + bos_token_id = tokenizer.bos_token_id or tokenizer.eos_token_id + + if len(prompt_tokens) > 0 and prompt_tokens[0] != bos_token_id: + # Prepend the bos token to the prompt tokens + prompt_tokens = [bos_token_id] + prompt_tokens + + return prompt_tokens + + def _get_data_parser(self) -> MultiModalDataParser: + return MultiModalDataParser(video_needs_metadata=True) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + processor = self.info.get_hf_processor(**mm_kwargs) + + if videos := mm_data.pop("videos", []): + pixel_values_videos_lst = [] + video_token_pooling_lst = [] + video_num_crops_lst = [] + video_num_pooled_patches_lst = [] + video_num_patches_lst = [] + video_tokens_lst = [] + num_video_tokens_lst = [] + + for item in videos: + video_array, metadata = item + + # NOTE: metadata.frames_indices indicates + # the sampled frames indices of pre-sampled videos, which is + # used to calculate the timestamps. Make sure that + # do_sample_frames in mm_kwargs is false for presampled videos. + + # NOTE: a copy of mm_kwargs is created to update do_sample_frames, + # otherwise mm_hash for the object will be incorrect. + video_mm_kwargs = dict(**mm_kwargs) + if "do_sample_frames" not in video_mm_kwargs: + # molmo_utils already has "do_sample_frames" in + # mm_kwargs, don't overwrite it. + video_mm_kwargs["do_sample_frames"] = metadata.get( + "do_sample_frames", False + ) + + metadata = VideoMetadata( + **{k: metadata[k] for k in metadata if k != "do_sample_frames"} + ) + + video_mm_data = dict() + video_mm_data["videos"] = [[video_array]] + video_mm_data["video_metadata"] = [[metadata]] + + video_outputs = super()._call_hf_processor( + prompt=VIDEO_PROMPT, + mm_data=video_mm_data, + mm_kwargs=video_mm_kwargs, + tok_kwargs=tok_kwargs, + ) + input_ids = video_outputs.pop("input_ids") + video_string = processor.processor.tokenizer.batch_decode(input_ids)[0] + prompt = prompt.replace( + VIDEO_PROMPT, + video_string, + 1, + ) + + pixel_values_videos_lst.append(video_outputs["pixel_values_videos"]) + video_token_pooling_lst.append(video_outputs["video_token_pooling"]) + video_num_crops_lst.append(video_outputs["video_num_crops"]) + video_num_pooled_patches_lst.append( + video_outputs["video_num_pooled_patches"] + ) + video_num_patches_lst.append(video_outputs["video_num_patches"]) + video_tokens_lst.append(video_outputs["video_tokens"]) + num_video_tokens_lst.append(video_outputs["num_video_tokens"]) + + video_outputs = dict( + pixel_values_videos=torch.cat(pixel_values_videos_lst), + video_token_pooling=torch.cat(video_token_pooling_lst), + video_num_crops=torch.cat(video_num_crops_lst), + video_num_pooled_patches=torch.cat(video_num_pooled_patches_lst), + video_num_patches=torch.cat(video_num_patches_lst), + video_tokens=torch.cat(video_tokens_lst), + num_video_tokens=torch.cat(num_video_tokens_lst), + ) + else: + video_outputs = dict() + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + bos_token_id = processor.vocab[processor.bos_token] + input_ids = processed_outputs["input_ids"] + # add bos token back to prompt start + if input_ids.numel() > 0 and input_ids[0, 0] != bos_token_id: + bos_token_id_tensor = torch.tensor( + [[bos_token_id]], device=input_ids.device, dtype=input_ids.dtype + ) + processed_outputs["input_ids"] = torch.concat( + [bos_token_id_tensor, input_ids], dim=1 + ) + combined_outputs = dict( + processed_outputs, + **video_outputs, + ) + return BatchFeature(combined_outputs) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + image_num_crops = hf_inputs.get("image_num_crops", torch.empty(0)) + image_num_pooled_patches = hf_inputs.get( + "image_num_pooled_patches", torch.empty(0) + ) + video_num_crops = hf_inputs.get("video_num_crops", torch.empty(0)) + video_num_pooled_patches = hf_inputs.get( + "video_num_pooled_patches", torch.empty(0) + ) + num_image_tokens = hf_inputs.get("num_image_tokens", torch.empty(0)) + num_video_tokens = hf_inputs.get("num_video_tokens", torch.empty(0)) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_num_crops + ), + image_token_pooling=MultiModalFieldConfig.flat_from_sizes( + "image", image_num_pooled_patches + ), + image_num_crops=MultiModalFieldConfig.batched("image"), + image_num_pooled_patches=MultiModalFieldConfig.batched("image"), + image_num_patches=MultiModalFieldConfig.batched("image"), + image_tokens=MultiModalFieldConfig.flat_from_sizes( + "image", num_image_tokens + ), + num_image_tokens=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_crops + ), + video_token_pooling=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_pooled_patches + ), + video_num_crops=MultiModalFieldConfig.batched("video"), + video_num_pooled_patches=MultiModalFieldConfig.batched("video"), + video_num_patches=MultiModalFieldConfig.batched("video"), + video_tokens=MultiModalFieldConfig.flat_from_sizes( + "video", num_video_tokens + ), + num_video_tokens=MultiModalFieldConfig.batched("video"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + img_patch_id = processor.image_patch_id + img_col_id = processor.im_col_id + img_start_id = processor.im_start_id + img_end_id = processor.im_end_id + image_use_col_tokens = processor.processor.image_use_col_tokens + use_single_crop_col_tokens = processor.processor.use_single_crop_col_tokens + use_single_crop_start_token = processor.processor.use_single_crop_start_token + video_use_col_tokens = processor.processor.video_use_col_tokens + use_frame_special_tokens = processor.processor.use_frame_special_tokens + + def get_image_replacement_molmo2(item_idx: int) -> list[int]: + images = mm_items.get_items("image", ImageProcessorItems) + image = images.get(item_idx) + image = exif_tranpose(image) + + resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False) + if use_single_crop_col_tokens is not None: + use_col_tokens = use_single_crop_col_tokens + else: + use_col_tokens = image_use_col_tokens + if use_single_crop_start_token: + start_id = processor.low_res_im_start_id + else: + start_id = img_start_id + extra_row = [img_patch_id] * resize_cols + [img_col_id] * int( + use_col_tokens + ) + extra_joint = [start_id] + extra_row * resize_nrows + [img_end_id] + + image_size = get_image_size(image) + + nrows, ncols = processor.get_patches_grid_size( + image_height=image_size.height, + image_width=image_size.width, + ) + + joint_row = [img_patch_id] * ncols + [img_col_id] * int( + image_use_col_tokens + ) + joint = [img_start_id] + joint_row * nrows + [img_end_id] + img_token_ids = extra_joint + joint + + return PromptUpdateDetails.select_token_ids( + img_token_ids, + processor.image_token_ids, + ) + + def get_video_replacement_molmo2(item_idx: int) -> list[int]: + video, metadata = mm_items["video"][item_idx] + do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames") + + timestamps = self.info._get_video_second_idx(metadata, do_sample_frames) + nrows, ncols = processor.get_base_grid_size(is_video=True) + + if use_frame_special_tokens: + start_id = processor.frame_start_id + end_id = processor.frame_end_id + else: + start_id = img_start_id + end_id = img_end_id + + img_token_ids = [] + + for frame_idx, frame_time in enumerate(timestamps): + prev_space = " " if frame_idx > 0 else "" + frame_prefix = ( + prev_space + f"{frame_time:.1f} " + ) # explicit whitespace before/after image tokens + + img_token_ids += processor.processor.tokenizer.encode( + frame_prefix, + add_special_tokens=False, + ) + + joint_row = [img_patch_id] * ncols + [img_col_id] * int( + video_use_col_tokens + ) + joint = [start_id] + nrows * joint_row + [end_id] + img_token_ids += joint + + return PromptUpdateDetails.select_token_ids( + img_token_ids, + processor.image_token_ids, + ) + + return [ + PromptReplacement( + modality=modality, + target=[target], + replacement=replacement_fn, + ) + for modality, target, replacement_fn in zip( + ["image", "video"], + [processor.image_placeholder_id, processor.video_placeholder_id], + [get_image_replacement_molmo2, get_video_replacement_molmo2], + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + Molmo2MultiModalProcessor, + info=Molmo2ProcessingInfo, + dummy_inputs=Molmo2DummyInputsBuilder, +) +class Molmo2ForConditionalGeneration( + nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant +): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + # vision backbone mapping + "image_pooling_2d.wq.": "image_pooling_2d.q_proj.", + "image_pooling_2d.wk.": "image_pooling_2d.k_proj.", + "image_pooling_2d.wv.": "image_pooling_2d.v_proj.", + "image_pooling_2d.wo.": "image_pooling_2d.o_proj.", + "image_projector.w1.": "image_projector.gate_proj.", + "image_projector.w3.": "image_projector.up_proj.", + "image_projector.w2.": "image_projector.down_proj.", + # language backbone mapping + "att_proj": "qkv_proj", + "attn_out": "o_proj", + "q_norm": "q_norm", + "k_norm": "k_norm", + "ff_proj": "up_gate_proj", + "ff_out": "down_proj", + "attn_norm": "input_layernorm", + "ff_norm": "post_attention_layernorm", + }, + orig_to_new_prefix={ + # vision backbone mapping + "model.vision_backbone.": "vision_backbone.", + # language backbone mapping + "model.transformer.blocks.": "model.layers.", + "model.transformer.ln_f.": "model.norm.", + }, + ) + + packed_modules_mapping = { + "qkv_proj": ["qkv_proj"], + "up_gate_proj": ["up_gate_proj"], # language model + "merged_qkv": ["wq", "wk", "wv"], # vision backbone + "merged_kv": ["k_proj", "v_proj"], # image_pooling_2d + "merged_linear": ["gate_proj", "up_proj"], # image_projector + } + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return IMAGE_PROMPT + if modality.startswith("video"): + return VIDEO_PROMPT + + raise ValueError("Only image or video modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + + kwargs = {} + for field in fields(VitConfig): + kwargs[field.name] = getattr(config.vit_config, field.name) + vit_config = VitConfig(**kwargs) + + kwargs = {} + for field in fields(AdapterConfig): + kwargs[field.name] = getattr(config.adapter_config, field.name) + adapter_config = AdapterConfig(**kwargs) + + self.vision_backbone = Molmo2VisionBackbone( + vit_config, + adapter_config, + quant_config, + prefix=maybe_prefix(prefix, "vision_backbone"), + ) + self.model = Molmo2TextModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + + self.img_patch_id = config.image_patch_id + + if hasattr(config, "text_config"): + hf_text_config = config.text_config + else: + hf_text_config = config.llm_config + + self.lm_head = ParallelLMHead( + hf_text_config.vocab_size, + hf_text_config.hidden_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(hf_text_config.vocab_size) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + @property + def dtype(self): + return next(self.parameters()).dtype + + def _parse_and_validate_image_input( + self, + **kwargs: object, + ) -> Molmo2ImageInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + if pixel_values is None: + return None + + token_pooling = kwargs.pop("image_token_pooling", None) + num_pooled_patches = kwargs.pop("image_num_pooled_patches", None) + num_patches = kwargs.pop("image_num_patches", None) + image_tokens = kwargs.pop("image_tokens", None) + num_image_tokens = kwargs.pop("num_image_tokens", None) + + accum_patches = [0] + num_patches.cumsum(dim=0)[:-1].tolist() + patch_offset = 0 + new_token_pooling = token_pooling.clone() + for i, n in enumerate(num_pooled_patches): + cur_slice = token_pooling[patch_offset : patch_offset + n] + index_offset = int(accum_patches[i]) + new_token_pooling[patch_offset : patch_offset + n] = torch.where( + cur_slice >= 0, + cur_slice + index_offset, + cur_slice, + ) + patch_offset += n + + return Molmo2ImageInputs( + pixel_values=pixel_values, + token_pooling=new_token_pooling, + num_pooled_patches=num_pooled_patches, + image_tokens=image_tokens, + num_image_tokens=num_image_tokens, + ) + + def _parse_and_validate_video_input( + self, + **kwargs: object, + ) -> Molmo2VideoInputs | None: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + if pixel_values_videos is None: + return None + + token_pooling = kwargs.pop("video_token_pooling", None) + num_pooled_patches = kwargs.pop("video_num_pooled_patches", None) + num_patches = kwargs.pop("video_num_patches", None) + video_tokens = kwargs.pop("video_tokens", None) + num_video_tokens = kwargs.pop("num_video_tokens", None) + + accum_patches = [0] + num_patches.cumsum(dim=0)[:-1].tolist() + patch_offset = 0 + new_token_pooling = token_pooling.clone() + for i, n in enumerate(num_pooled_patches): + cur_slice = token_pooling[patch_offset : patch_offset + n] + index_offset = int(accum_patches[i]) + new_token_pooling[patch_offset : patch_offset + n] = torch.where( + cur_slice >= 0, + cur_slice + index_offset, + cur_slice, + ) + patch_offset += n + + return Molmo2VideoInputs( + pixel_values_videos=pixel_values_videos, + token_pooling=new_token_pooling, + num_pooled_patches=num_pooled_patches, + video_tokens=video_tokens, + num_video_tokens=num_video_tokens, + ) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + for input_key in kwargs: + if input_key in ("pixel_values",) and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input(**kwargs) + if input_key in ("pixel_values_videos",) and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input(**kwargs) + return modalities + + def _process_image_input( + self, + image_input: Molmo2ImageInputs, + ) -> tuple[torch.Tensor, ...]: + pixel_values = image_input["pixel_values"] + token_pooling = image_input["token_pooling"] + num_pooled_patches = image_input["num_pooled_patches"] + image_tokens = image_input["image_tokens"] + num_image_tokens = image_input["num_image_tokens"] + + image_features_flat = self.vision_backbone( + images=pixel_values.unsqueeze(0), + token_pooling=token_pooling.unsqueeze(0), + ) + + assert len(image_features_flat) == num_pooled_patches.sum() + image_features_list = image_features_flat.split( + num_pooled_patches.tolist(), dim=0 + ) + image_tokens_list = image_tokens.split(num_image_tokens.tolist(), dim=0) + out = [] + for image_features_i, image_tokens_i in zip( + image_features_list, image_tokens_list + ): + out_features = self.get_language_model().embed_input_ids(image_tokens_i) + is_image_patch = image_tokens_i == self.img_patch_id + out_features[is_image_patch] = image_features_i + out.append(out_features) + return tuple(out) + + def _process_video_input( + self, + video_input: Molmo2VideoInputs, + ) -> tuple[torch.Tensor, ...]: + pixel_values_videos = video_input["pixel_values_videos"] + token_pooling = video_input["token_pooling"] + num_pooled_patches = video_input["num_pooled_patches"] + video_tokens = video_input["video_tokens"] + num_video_tokens = video_input["num_video_tokens"] + + image_features_flat = self.vision_backbone( + images=pixel_values_videos.unsqueeze(0), + token_pooling=token_pooling.unsqueeze(0), + ) + + assert len(image_features_flat) == num_pooled_patches.sum() + image_features_list = image_features_flat.split( + num_pooled_patches.tolist(), dim=0 + ) + video_tokens_list = video_tokens.split(num_video_tokens.tolist(), dim=0) + out = [] + for image_features_i, video_tokens_i in zip( + image_features_list, video_tokens_list + ): + out_features = self.get_language_model().embed_input_ids(video_tokens_i) + is_image_patch = video_tokens_i == self.img_patch_id + out_features[is_image_patch] = image_features_i + out.append(out_features) + return tuple(out) + + def get_language_model(self) -> torch.nn.Module: + return self.model + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return [] + + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + image_embeddings = self._process_image_input(image_input) + multimodal_embeddings += image_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + + return multimodal_embeddings + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + inputs_embeds = self._embed_text_input_ids( + input_ids, + self.get_language_model().embed_input_ids, + is_multimodal=is_multimodal, + handle_oov_mm_token=handle_oov_mm_token, + ) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + if is_multimodal is None: + raise ValueError( + "`embed_input_ids` now requires `is_multimodal` arg, " + "please update your model runner according to " + "https://github.com/vllm-project/vllm/pull/16229." + ) + + inputs_embeds = _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.LongTensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor: + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + weights = _get_weights_with_merged_embedding(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="model", + connector="vision_backbone.image_projector", + tower_model="vision_backbone", + ) + + +def _get_weights_with_merged_embedding( + weights: Iterable[tuple[str, torch.Tensor]], +) -> Iterable[tuple[str, torch.Tensor]]: + embedding_weights = {} + for name, weight in weights: + if "wte.embedding" in name: + embedding_weights["embedding"] = weight + elif "wte.new_embedding" in name: + embedding_weights["new_embedding"] = weight + else: + yield (name, weight) + # this is compatible with most of quantization, + # because they won't quantize embed_tokens + if "embedding" not in embedding_weights or "new_embedding" not in embedding_weights: + raise ValueError( + "Checkpoint is missing 'wte.embedding' or " + "'wte.new_embedding' weights required for Molmo2." + ) + + embedding_weights = torch.cat( + [embedding_weights["embedding"], embedding_weights["new_embedding"]], + dim=0, + ) + yield ("model.embed_tokens.weight", embedding_weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 362028ebf..51e7b9133 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -384,6 +384,7 @@ _MULTIMODAL_MODELS = { "Mistral3ForConditionalGeneration", ), "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), + "Molmo2ForConditionalGeneration": ("molmo2", "Molmo2ForConditionalGeneration"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "Ovis": ("ovis", "Ovis"), "Ovis2_5": ("ovis2_5", "Ovis2_5"), diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index be6c2468f..3ef445f07 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -386,6 +386,21 @@ class PromptUpdateDetails(Generic[_S]): return PromptUpdateDetails(full=seq, is_embed=is_embed) + @staticmethod + def select_token_ids( + seq: _S, + embed_token_ids: list[int], + ) -> "PromptUpdateDetails[_S]": + def is_embed(tokenizer: TokenizerLike | None, full: PromptSeq) -> torch.Tensor: + token_ids = _seq2tokens(tokenizer, full) + + return torch.isin( + torch.tensor(token_ids), + torch.tensor(embed_token_ids), + ) + + return PromptUpdateDetails(full=seq, is_embed=is_embed) + PromptUpdateInfo: TypeAlias = PromptSeq | PromptUpdateDetails """ diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 8204cdfbc..0fee0f539 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -6,7 +6,7 @@ from abc import abstractmethod from functools import partial from io import BytesIO from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import numpy as np import numpy.typing as npt @@ -439,6 +439,324 @@ class OpenCVDynamicVideoBackend(OpenCVVideoBackend): return frames, metadata +@VIDEO_LOADER_REGISTRY.register("molmo2") +class Molmo2VideoBackend(VideoLoader): + def get_cv2_video_api(self): + import cv2.videoio_registry as vr + + api_pref = None + for backend in vr.getStreamBufferedBackends(): + if not vr.hasBackend(backend): + continue + if not vr.isBackendBuiltIn(backend): + _, abi, api = vr.getStreamBufferedBackendPluginVersion(backend) + if abi < 1 or (abi == 1 and api < 2): + continue + api_pref = backend + break + return api_pref + + @classmethod + def get_candidate_target_fps( + cls, + video_fps: float, + sampling_fps: float, + max_fps: float = 8.0, + ) -> list[float]: + """ + Return the subset of `video_fps` factors that remain multiples + of `sampling_fps`. + + Examples: + >>> get_candidate_target_fps(video_fps=6, sampling_fps=2) + [2, 6] + >>> get_candidate_target_fps(video_fps=5, sampling_fps=1) + [1, 5] + >>> get_candidate_target_fps(video_fps=2, sampling_fps=2) + [2] + >>> get_candidate_target_fps(video_fps=5, sampling_fps=2) + Traceback (most recent call last): + ... + ValueError: sampling_fps=2 must divide video_fps=5 to produce + consistent frame steps. + """ + video_fps = int(video_fps) + sampling_fps = int(sampling_fps) + max_fps = int(max_fps) + + if sampling_fps is None: + raise ValueError("sampling_fps must be provided") + if video_fps <= 0 or sampling_fps <= 0: + raise ValueError( + "video_fps and sampling_fps must be positive " + f"(got {video_fps}, {sampling_fps})" + ) + if video_fps % sampling_fps != 0: + raise ValueError( + f"sampling_fps={sampling_fps} must divide video_fps={video_fps}." + ) + + candidates = [] + for candidate in range(sampling_fps, video_fps + 1, sampling_fps): + if candidate > max_fps: + break + if video_fps % candidate == 0: + candidates.append(float(candidate)) + + return candidates + + @classmethod + def get_target_fps( + cls, + video_fps: float, + max_frames: int, + total_frames: int, + frame_sample_mode: str, + candidate_target_fps: list[float], + ) -> float | None: + """ + Get the target fps that best spans the videoand has the most frames sampled + """ + num_frames_sampled = 0 + selected_target_fps = None + for target_fps in candidate_target_fps: + step_size = max(int(video_fps / target_fps), 1) + num_frames_sampled_at_fps = int(total_frames / step_size) + if num_frames_sampled == 0: + if ( + "uniform" in frame_sample_mode + and num_frames_sampled_at_fps > max_frames + ): + break + selected_target_fps = target_fps + num_frames_sampled = num_frames_sampled_at_fps + + else: + # the candidate sampling fps increases so frame count can't decrease + assert num_frames_sampled <= num_frames_sampled_at_fps + if num_frames_sampled_at_fps > max_frames: + # choose the sampling fps that spans the video + continue + + elif num_frames_sampled_at_fps > num_frames_sampled: + # both are less than max_frames; choose the one with higher + # density of frames sampled + selected_target_fps = target_fps + num_frames_sampled = num_frames_sampled_at_fps + return selected_target_fps + + @classmethod + def get_frame_times_and_chosen_fps( + cls, + selected_target_fps: float | None, + total_frames: int, + max_frames: int, + video_fps: float, + ) -> tuple[float | None, npt.NDArray]: + if selected_target_fps is None: + frame_indices = np.linspace( + 0, total_frames, max_frames, endpoint=False, dtype=int + ) + else: + step_size = max(int(video_fps / selected_target_fps), 1) + frame_indices = np.arange(0, total_frames, step_size) + if len(frame_indices) > max_frames: + frame_indices = frame_indices[:max_frames] + return selected_target_fps, frame_indices + + @classmethod + def sample_times( + cls, + duration: float, + max_frames: int, + frame_sample_mode: str, + max_fps: int | None, + candidate_target_fps: list[float] | None = None, + **kwargs, + ) -> npt.NDArray: + if frame_sample_mode == "fps": + assert candidate_target_fps is not None + # Try larger and larger FPSs until we hit one that can't span the video + sampling_fps = candidate_target_fps[0] + for candidate_fps in candidate_target_fps[1:]: + if max_frames / candidate_fps < duration: + break + sampling_fps = candidate_fps + times = np.arange(0, max_frames) / sampling_fps + times = times[times < duration] + return times + elif frame_sample_mode == "uniform_last_frame": + if max_fps is not None: + max_duration = ( + max_frames - 1 + ) / max_fps # -1 to include the last frame + if max_duration < duration: + times = np.linspace( + 0, duration, num=max_frames, endpoint=True, dtype=np.float64 + ) + else: + times = np.arange(0.0, stop=duration, step=1 / max_fps) + times = np.concatenate([times, [duration]], axis=0) + assert len(times) <= max_frames + else: + times = np.linspace( + 0, duration, num=max_frames, endpoint=True, dtype=np.float64 + ) + return times + else: + raise NotImplementedError(frame_sample_mode) + + @classmethod + def _sample_frames( + cls, + total_num_frames: int, + video_fps: float, + duration: float, + frame_sample_mode: str, + num_frames: int, + max_fps: int, + sampling_fps: int, + ) -> npt.NDArray: + if frame_sample_mode == "uniform_last_frame" and max_fps is not None: + if total_num_frames <= 2: + indices = np.arange(total_num_frames).astype(int) + elif duration > (num_frames - 1) / max_fps: # -1 to include the last frame + # uniform fallback + indices = np.linspace( + 0, + total_num_frames - 1, + num=min(num_frames, total_num_frames), + endpoint=True, + ).astype(int) + else: + float_indices = np.arange( + 0.0, + stop=total_num_frames - 1, + step=float(video_fps / max_fps), + ) + if np.round(float_indices[-1]) != total_num_frames - 1: + float_indices = np.concatenate( + [float_indices, [total_num_frames - 1]], axis=0 + ) + indices = np.round(float_indices).astype(int) + assert indices[-1] < total_num_frames + assert len(float_indices) <= num_frames + elif frame_sample_mode == "uniform_last_frame": + indices = np.linspace( + 0, + total_num_frames - 1, + num=min(num_frames, total_num_frames), + endpoint=True, + ).astype(int) + elif frame_sample_mode == "fps": + candidate_target_fps = cls.get_candidate_target_fps(video_fps, sampling_fps) + selected_target_fps = cls.get_target_fps( + video_fps, + num_frames, + total_num_frames, + frame_sample_mode, + candidate_target_fps, + ) + _, indices = cls.get_frame_times_and_chosen_fps( + selected_target_fps, + total_num_frames, + num_frames, + video_fps, + ) + else: + raise NotImplementedError(frame_sample_mode) + + return indices + + @classmethod + def load_bytes_opencv( + cls, + data: bytes, + frame_sample_mode: str | None = None, + num_frames: int = -1, + max_fps: int = 2, + sampling_fps: int = 2, + **kwargs, + ) -> tuple[npt.NDArray, dict[str, Any]]: + import cv2 + + backend = cls().get_cv2_video_api() + cap = cv2.VideoCapture(BytesIO(data), backend, []) + if not cap.isOpened(): + raise ValueError("Could not open video stream") + + total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + original_fps = cap.get(cv2.CAP_PROP_FPS) + duration = total_frames_num / original_fps if original_fps > 0 else 0 + + if frame_sample_mode is None: + # Use transformers transformers.video_utils.VideoMetadata format + frame_idx = list(range(0, total_frames_num)) + frame_idx_set = set(frame_idx) + frames, valid_num_frames, valid_frame_indices = cls._read_frames( + cap, frame_idx_set, total_frames_num, max(frame_idx) + ) + do_sample_frames = valid_num_frames == total_frames_num + metadata = { + "total_num_frames": total_frames_num, + "fps": original_fps, + "duration": duration, + "video_backend": "opencv", + "do_sample_frames": do_sample_frames, + } + if not do_sample_frames: + metadata["frames_indices"] = valid_frame_indices + return frames, metadata + + frame_idx = cls._sample_frames( + total_frames_num, + original_fps, + duration, + frame_sample_mode, + num_frames, + max_fps, + sampling_fps, + ).tolist() + + frames, valid_num_frames, valid_frame_indices = cls._read_frames( + cap, + set(frame_idx), + len(frame_idx), + total_frames_num - 1, + ) + + metadata = { + "total_num_frames": total_frames_num, + "fps": original_fps, + "duration": duration, + "video_backend": "opencv", + "frames_indices": valid_frame_indices, + "do_sample_frames": False, + } + + return frames, metadata + + @classmethod + def load_bytes( + cls, + data: bytes, + num_frames: int = -1, + **kwargs, + ) -> tuple[npt.NDArray, dict[str, Any]]: + frame_sample_mode = cast(str | None, kwargs.pop("frame_sample_mode", None)) + max_fps = cast(int, kwargs.pop("max_fps", 2)) + sampling_fps = cast(int, kwargs.pop("sampling_fps", 2)) + out = cls.load_bytes_opencv( + data, + frame_sample_mode, + num_frames, + max_fps, + sampling_fps, + **kwargs, + ) + return out + + class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]): def __init__( self,