# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from functools import partial from itertools import islice from typing import Annotated, Any import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from PIL import ImageOps from PIL.Image import Image from transformers import ( BaseImageProcessor, BaseVideoProcessor, BatchFeature, PretrainedConfig, ProcessorMixin, ) from transformers.image_utils import ImageInput from transformers.video_utils import VideoMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.distributed import ( get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather, ) from vllm.inputs import MultiModalDataDict from vllm.logger import init_logger from vllm.model_executor.layers.activation import MulAndSilu, SiluAndMul, get_act_fn from vllm.model_executor.layers.attention import Attention, MMEncoderAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalFieldConfig, MultiModalKwargsItems, VideoItem, ) from vllm.multimodal.parse import ( ImageProcessorItems, ImageSize, MultiModalDataItems, MultiModalDataParser, ) from vllm.multimodal.processing import ( BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails, ) from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.math_utils import round_down from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import ( MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant, ) from .utils import ( AutoWeightsLoader, WeightsMapper, _merge_multimodal_embeddings, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, ) logger = init_logger(__name__) # Special tokens. These should be present in any tokenizer we use # because the preprocessor relies on them. IMAGE_PROMPT = "<|image|>" VIDEO_PROMPT = "<|video|>" _MAX_VIDEO_FPS = 8 class Molmo2ImageInputs(TensorSchema): """ Dimensions: - nc: The total number of crops (dynamic) - np: The total number of patches per crop - cps: Number of channels * patch_size * patch_size - npp: Number of pooled patches (dynamic) - pp: pooling_size * pooling_size - ni: Number of images - nt: Number of image tokens (dynamic) """ pixel_values: Annotated[torch.Tensor, TensorShape("nc", "np", "cps")] token_pooling: Annotated[torch.Tensor, TensorShape("npp", "pp")] """ An index tensor that maps image features to their corresponding patch tokens before pooling. """ num_pooled_patches: Annotated[torch.Tensor, TensorShape("ni")] image_tokens: Annotated[torch.BoolTensor, TensorShape("nt")] num_image_tokens: Annotated[torch.Tensor, TensorShape("ni")] class Molmo2VideoInputs(TensorSchema): """ Dimensions: - nc: The total number of frames (dynamic) - np: The total number of patches per frame - cps: Number of channels * patch_size * patch_size - npp: Number of pooled patches (dynamic) - pp: pooling_size * pooling_size - nv: Number of videos - nt: Number of video tokens (dynamic) """ pixel_values_videos: Annotated[torch.Tensor, TensorShape("nc", "np", "cps")] token_pooling: Annotated[torch.Tensor, TensorShape("npp", "pp")] """ An index tensor that maps image features to their corresponding patch tokens before pooling. """ num_pooled_patches: Annotated[torch.Tensor, TensorShape("nv")] video_tokens: Annotated[torch.BoolTensor, TensorShape("nt")] num_video_tokens: Annotated[torch.Tensor, TensorShape("nv")] @dataclass class VitConfig: """Config for a vision transformer""" hidden_size: int = 1152 intermediate_size: int = 4304 num_hidden_layers: int = 27 num_attention_heads: int = 16 num_key_value_heads: int = 16 head_dim: int = 72 hidden_act: str = "gelu_pytorch_tanh" layer_norm_eps: float = 1e-6 image_default_input_size: tuple[int, int] = (378, 378) image_patch_size: int = 14 image_num_pos: int = 577 def __post_init__(self): self.image_default_input_size = tuple(self.image_default_input_size) # type: ignore[assignment] @property def image_num_patch(self): h, w = self.image_default_input_size return h // self.image_patch_size, w // self.image_patch_size @dataclass class AdapterConfig: """Config for a vit-llm adapter""" vit_layers: tuple[int, int] = (-3, -9) pooling_attention_mask: bool = False hidden_size: int = 1152 num_attention_heads: int = 16 num_key_value_heads: int = 16 head_dim: int = 72 hidden_act: str = "silu" intermediate_size: int = 18944 text_hidden_size: int = 3584 @dataclass class TextConfig: """Configuration for a text model transformer""" hidden_size: int = 3584 """ The hidden size of the model. """ num_attention_heads: int = 28 """ The number of self-attention heads. """ num_key_value_heads: int = 4 """ The number of heads to use for keys and values. """ head_dim: int = 128 """ The head dimensionality for the attention mechanism. """ vocab_size: int = 152064 """Vocabulary size of the model.""" additional_vocab_size: int = 128 """Number of additional tokens to have the input embeddings for""" qkv_bias: bool = True """ Do QKV projection a bias """ num_hidden_layers: int = 48 """ The number of layers/blocks. """ intermediate_size: int = 18944 """ The hidden size for the MLP. """ hidden_act: str = "silu" """ The activation function to use within the MLP layers. """ max_position_embeddings: int = 4096 """ Max positional embeddings to use in RoPE cache """ rope_theta: float = 1000000.0 """ RoPE theta parameter. """ use_qk_norm: bool = False """ Apply layer norm to the keys and queries within the attention mechanism. This can help stabilize training. """ qk_norm_type: str = "olmo" """ The type of layer norm to use for the keys and queries. Can be "olmo" or "qwen3". """ layer_norm_eps: float = 1e-6 """ epsilon for layer norms """ norm_after: bool = False """Apply layer norm before and after the attention and MLP blocks.""" rope_scaling_layers: tuple[int, ...] | None = None """ RoPE scaling layers. """ class ViTMLP(nn.Module): """MLP used in Vision Transformer.""" def __init__( self, dim: int, hidden_dim: int, hidden_act: str, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.w1 = ColumnParallelLinear( dim, hidden_dim, bias=True, quant_config=quant_config, prefix=f"{prefix}.w1", ) # Activation function. self.act = get_act_fn(hidden_act) self.w2 = RowParallelLinear( hidden_dim, dim, bias=True, quant_config=quant_config, prefix=f"{prefix}.w2", ) def forward(self, x: torch.Tensor) -> torch.Tensor: x, _ = self.w1(x) x = self.act(x) x, _ = self.w2(x) return x class ViTMultiHeadDotProductAttention(nn.Module): """Multi-head attention used in Vision Transformer.""" def __init__( self, hidden_size: int, num_heads: int, num_key_value_heads: int, head_dim: int, use_bias: bool = True, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size self.total_num_heads = num_heads tp_size = get_tensor_model_parallel_world_size() assert self.hidden_size % self.total_num_heads == 0 assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.head_dim = head_dim assert self.head_dim == self.hidden_size // self.total_num_heads self.total_num_kv_heads = num_key_value_heads if self.total_num_kv_heads >= tp_size: assert self.total_num_kv_heads % tp_size == 0 else: assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.merged_qkv = QKVParallelLinear( self.hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, bias=use_bias, quant_config=quant_config, prefix=f"{prefix}.merged_qkv", ) self.wo = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=use_bias, quant_config=quant_config, prefix=f"{prefix}.wo", ) self.scale = self.head_dim**-0.5 self.attn = MMEncoderAttention( self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads, prefix=f"{prefix}.attn", ) def forward(self, inputs: torch.Tensor) -> torch.Tensor: qkv, _ = self.merged_qkv(inputs) xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) output = self.attn(xq, xk, xv) output, _ = self.wo(output) return output class Molmo2VisionBlock(nn.Module): """Residual attention block used in Vision Transformer.""" def __init__( self, config: VitConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.attention = ViTMultiHeadDotProductAttention( hidden_size=config.hidden_size, num_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, head_dim=config.head_dim, quant_config=quant_config, prefix=f"{prefix}.attention", ) self.feed_forward = ViTMLP( dim=config.hidden_size, hidden_dim=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, prefix=f"{prefix}.feed_forward", ) self.attention_norm = nn.LayerNorm( config.hidden_size, eps=config.layer_norm_eps, ) self.ffn_norm = nn.LayerNorm( config.hidden_size, eps=config.layer_norm_eps, ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.attention(self.attention_norm(x)) x = x + self.feed_forward(self.ffn_norm(x)) return x class Molmo2VisionBlockCollection(nn.Module): """Collection of residual attention blocks used in Vision Transformer.""" def __init__( self, config: VitConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.resblocks = nn.ModuleList( [ Molmo2VisionBlock( config, quant_config, prefix=f"{prefix}.resblocks.{layer_idx}", ) for layer_idx in range(config.num_hidden_layers) ] ) def forward(self, x: torch.Tensor) -> list[torch.Tensor]: hidden_states = [] for r in self.resblocks: x = r(x) hidden_states.append(x) return hidden_states class Molmo2VisionTransformer(nn.Module): """Vision Transformer used in Vision Backbone.""" def __init__( self, config: VitConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() scale = config.hidden_size**-0.5 self.num_prefix_tokens: int = 0 # no class embeddings self.patch_num = config.image_num_patch self.positional_embedding = nn.Parameter( torch.randn(config.image_num_pos, config.hidden_size) * scale, ) image_patch_size = config.image_patch_size self.patch_embedding = nn.Linear( image_patch_size * image_patch_size * 3, config.hidden_size, bias=True, ) self.transformer = Molmo2VisionBlockCollection( config, quant_config, prefix=f"{prefix}.transformer", ) def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: pos_emb = self.positional_embedding pos_emb = pos_emb.reshape( ( int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1], ) ) (patch_num_0, patch_num_1) = patch_num if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1: # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2) pos_emb = F.interpolate( pos_emb, size=(patch_num_0, patch_num_1), mode="bicubic", align_corners=False, antialias=True, ) pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0) pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1]) x = x + pos_emb[None, :, :].to(x.dtype) return x def forward( self, x: torch.Tensor, patch_num: int | None = None, ) -> list[torch.Tensor]: """ : param x: (batch_size, num_patch, n_pixels) """ if patch_num is None: patch_num = self.patch_num x = self.patch_embedding(x) x = self.add_pos_emb(x, patch_num) hidden_states = self.transformer(x) return hidden_states class ImagePoolingAttention(nn.Module): """Multi-head attention used for image pooling""" def __init__( self, input_dim: int, hidden_size: int, num_heads: int, num_key_value_heads: int, head_dim: int, use_bias: bool = True, use_pytorch_sdpa: bool = False, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.input_dim = input_dim self.hidden_size = hidden_size self.total_num_heads = num_heads tp_size = get_tensor_model_parallel_world_size() assert self.hidden_size % self.total_num_heads == 0 assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.head_dim = head_dim assert self.head_dim == self.hidden_size // self.total_num_heads self.total_num_kv_heads = num_key_value_heads if self.total_num_kv_heads >= tp_size: assert self.total_num_kv_heads % tp_size == 0 else: assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.kv_size = self.num_kv_heads * self.head_dim self.q_proj = ColumnParallelLinear( self.input_dim, self.total_num_heads * self.head_dim, bias=use_bias, quant_config=quant_config, prefix=f"{prefix}.q_proj", ) self.merged_kv = MergedColumnParallelLinear( self.input_dim, [self.total_num_kv_heads * self.head_dim] * 2, bias=use_bias, quant_config=quant_config, prefix=f"{prefix}.merged_kv", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=use_bias, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) self.scale = self.head_dim**-0.5 self.use_pytorch_sdpa = use_pytorch_sdpa if use_pytorch_sdpa: self.attn = None else: self.attn = MMEncoderAttention( self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads, prefix=f"{prefix}.attn", ) def forward_sdpa( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: bsz, q_len, _ = query.size() kv_len = key.size(1) query = query.view(bsz, q_len, self.num_heads, self.head_dim) key = key.view(bsz, kv_len, self.num_kv_heads, self.head_dim) value = value.view(bsz, kv_len, self.num_kv_heads, self.head_dim) query, key, value = (x.transpose(1, 2) for x in (query, key, value)) out = F.scaled_dot_product_attention( query, key, value, attn_mask=attn_mask, is_causal=False, enable_gqa=self.num_heads > self.num_kv_heads, ).transpose(1, 2) return out.reshape(bsz, q_len, -1) def forward( self, inputs_q: torch.Tensor, inputs_kv: torch.Tensor, attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: xq, _ = self.q_proj(inputs_q) kv, _ = self.merged_kv(inputs_kv) xk, xv = kv.split([self.kv_size, self.kv_size], dim=-1) if self.use_pytorch_sdpa: output = self.forward_sdpa(xq, xk, xv, attn_mask) else: output = self.attn(xq, xk, xv) output, _ = self.o_proj(output) return output class ImageProjectorMLP(nn.Module): """MLP used for the image projector""" def __init__( self, input_dim: int, hidden_dim: int, output_dim: int, hidden_act: str, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.merged_linear = MergedColumnParallelLinear( input_dim, [hidden_dim] * 2, bias=False, quant_config=quant_config, prefix=f"{prefix}.merged_linear", ) # Activation function. assert hidden_act == "silu" self.act_fn = SiluAndMul() # Feed-forward output projection. self.down_proj = RowParallelLinear( hidden_dim, output_dim, bias=False, quant_config=quant_config, prefix=f"{prefix}.down_proj", ) def forward(self, x: torch.Tensor) -> torch.Tensor: x, _ = self.merged_linear(x) x = self.act_fn(x) x, _ = self.down_proj(x) return x class Molmo2VisionBackbone(nn.Module, SupportsQuant): packed_modules_mapping = { "merged_qkv": ["wq", "wk", "wv"], # vision backbone "merged_kv": ["k_proj", "v_proj"], # image_pooling_2d "merged_linear": ["gate_proj", "up_proj"], } def __init__( self, vit_config: VitConfig, adapter_config: AdapterConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.vit_config = vit_config self.adapter_config = adapter_config self.vit_layers = [] for layer in adapter_config.vit_layers: if layer >= 0: self.vit_layers.append(layer) else: self.vit_layers.append(layer + vit_config.num_hidden_layers) last_layer_needed = max(self.vit_layers) + 1 if last_layer_needed < vit_config.num_hidden_layers: vit_config.num_hidden_layers = last_layer_needed self.image_vit = Molmo2VisionTransformer( vit_config, quant_config, prefix=f"{prefix}.image_vit", ) self.num_prefix_tokens: int = self.image_vit.num_prefix_tokens pool_dim = vit_config.hidden_size * len(adapter_config.vit_layers) self.image_pooling_2d = ImagePoolingAttention( input_dim=pool_dim, hidden_size=adapter_config.hidden_size, num_heads=adapter_config.num_attention_heads, num_key_value_heads=adapter_config.num_key_value_heads, head_dim=adapter_config.head_dim, use_pytorch_sdpa=adapter_config.pooling_attention_mask, quant_config=quant_config, prefix=f"{prefix}.image_pooling_2d", ) self.image_projector = ImageProjectorMLP( input_dim=adapter_config.hidden_size, hidden_dim=adapter_config.intermediate_size, output_dim=adapter_config.text_hidden_size, hidden_act=adapter_config.hidden_act, quant_config=quant_config, prefix=f"{prefix}.image_projector", ) @property def dtype(self) -> torch.dtype: return self.image_vit.patch_embedding.weight.dtype @property def device(self) -> torch.device: return self.image_vit.patch_embedding.weight.device def encode_image(self, images: torch.Tensor) -> torch.Tensor: """ : param images: (batch_size, num_crops, num_patch, n_pixels) """ B, T, N, D = images.shape images = images.view(B * T, N, D) image_features = self.image_vit(images) features = [] for layer in self.vit_layers: features.append(image_features[layer]) image_features = torch.cat(features, dim=-1) if self.num_prefix_tokens > 0: image_features = image_features[:, 1:] image_features = image_features.view(B, T, N, -1) return image_features def forward( self, images: torch.Tensor, token_pooling: torch.Tensor, ) -> torch.Tensor: # image_features shape: # (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) batch_size, num_image = images.shape[:2] images = images.to(device=self.device, dtype=self.dtype) image_features = self.encode_image(images) dim = image_features.shape[-1] valid = token_pooling >= 0 valid_token = torch.any(valid, -1) # Use `token_pooling` to arange the features for image pooling batch_idx = torch.arange( token_pooling.shape[0], dtype=torch.long, device=token_pooling.device, ) batch_idx = torch.tile( batch_idx.view(batch_size, 1, 1), [1, token_pooling.shape[1], token_pooling.shape[2]], ) # Now [batch, num_features, num_pooled_patches, dim] to_pool = image_features.reshape(batch_size, -1, dim)[ batch_idx, torch.clip(token_pooling, 0) ] to_pool = to_pool * valid.to(self.dtype)[:, :, :, None] to_pool = to_pool.reshape([-1, token_pooling.shape[-1], dim]) if self.adapter_config.pooling_attention_mask: attn_mask = valid.reshape([-1, 1, 1, valid.shape[-1]]) denom = valid.view(-1, to_pool.shape[-2]).float().sum(-1) denom = torch.where(denom == 0, 1, denom) query = to_pool.sum(-2, keepdim=True) / denom[:, None, None].to( to_pool.dtype ) else: attn_mask = None query = to_pool.mean(-2, keepdim=True) pooled_features = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask) pooled_features = pooled_features.reshape( [batch_size, -1, pooled_features.shape[-1]] ) # MLP layer to map the feature. pooled_features = self.image_projector(pooled_features) return pooled_features.view(-1, pooled_features.shape[-1])[ valid_token.flatten() ] def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("merged_qkv", "wq", "q"), ("merged_qkv", "wk", "k"), ("merged_qkv", "wv", "v"), ("merged_kv", "k_proj", 0), ("merged_kv", "v_proj", 1), ("merged_linear", "gate_proj", 0), ("merged_linear", "up_proj", 1), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: if name.endswith(".bias") and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params class Molmo2Attention(nn.Module): """Molmo2's LLM Attention.""" def __init__( self, config: TextConfig, rope_parameters: dict[str, Any], cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size self.tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads assert self.hidden_size % self.total_num_heads == 0 assert self.total_num_heads % self.tp_size == 0 self.num_heads = self.total_num_heads // self.tp_size self.total_num_kv_heads = config.num_key_value_heads if self.total_num_kv_heads >= self.tp_size: assert self.total_num_kv_heads % self.tp_size == 0 else: assert self.tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) self.head_dim = config.head_dim self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta # Attention input projection. Projects x -> (q, k, v) self.qkv_proj = QKVParallelLinear( self.hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, bias=config.qkv_bias, quant_config=quant_config, ) self.tp_rank: int | None = None self.k_norm: nn.Module | None = None self.q_norm: nn.Module | None = None self.qk_norm_type: str | None = None if config.use_qk_norm: k_norm_size = ( self.head_dim if config.qk_norm_type == "qwen3" else self.total_num_kv_heads * self.head_dim ) self.tp_rank = get_tensor_model_parallel_rank() self.k_norm = RMSNorm(k_norm_size, eps=config.layer_norm_eps) q_norm_size = ( self.head_dim if config.qk_norm_type == "qwen3" else self.total_num_heads * self.head_dim ) self.q_norm = RMSNorm(q_norm_size, eps=config.layer_norm_eps) self.qk_norm_type = config.qk_norm_type # Rotary embeddings. Rope scaling is only applied on full attention layers. layer_idx = extract_layer_index(prefix) if ( config.rope_scaling_layers is not None and layer_idx not in config.rope_scaling_layers ): rope_theta = rope_parameters["rope_theta"] rope_parameters = {"rope_type": "default", "rope_theta": rope_theta} self.rotary_emb = get_rope( self.head_dim, max_position=self.max_position_embeddings, rope_parameters=rope_parameters, ) self.scaling = self.head_dim**-0.5 self.attn = Attention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", ) # Attention output projection. self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=False, quant_config=quant_config, ) def _apply_qk_norm( self, q: torch.Tensor, k: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: if self.tp_size > 1: q = tensor_model_parallel_all_gather(q.contiguous()) k = tensor_model_parallel_all_gather(k.contiguous()) q = self.q_norm(q) k = self.k_norm(k) if self.tp_size > 1: splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) q = splitter(q)[self.tp_rank] k = splitter(k)[self.tp_rank] return q, k def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, **kwargs: object, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) if ( self.q_norm is not None and self.k_norm is not None and self.qk_norm_type == "olmo" ): q, k = self._apply_qk_norm(q, k) elif self.q_norm is not None and self.k_norm is not None: q_by_head = q.view( *q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim, ) q_by_head = self.q_norm(q_by_head) q = q_by_head.view(q.shape) k_by_head = k.view( *k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim, ) k_by_head = self.k_norm(k_by_head) k = k_by_head.view(k.shape) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output class LanguageModelMLP(nn.Module): """Molmo2's LLM mlp.""" def __init__( self, input_dim: int, intermediate_size: int, hidden_act: str, quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() self.up_gate_proj = MergedColumnParallelLinear( input_dim, [intermediate_size] * 2, bias=False, quant_config=quant_config, ) # Activation function. assert hidden_act == "silu" self.act_fn = MulAndSilu() # Feed-forward output projection. self.down_proj = RowParallelLinear( intermediate_size, input_dim, bias=False, quant_config=quant_config, ) def forward( self, x: torch.Tensor, ) -> torch.Tensor: up_gate, _ = self.up_gate_proj(x) x = self.act_fn(up_gate) x, _ = self.down_proj(x) return x class Molmo2DecoderLayer(nn.Module): def __init__( self, config: TextConfig, rope_parameters: dict[str, Any], cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", ) -> None: super().__init__() # Attention block. self.self_attn = Molmo2Attention( config, rope_parameters, cache_config, quant_config, prefix=f"{prefix}.self_attn", ) # MLP block. self.mlp = LanguageModelMLP( config.hidden_size, config.intermediate_size, config.hidden_act, quant_config, ) # LayerNorm self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.layer_norm_eps, ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None, **kwargs: object, ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, **kwargs, ) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual class Molmo2DecoderNormAfterLayer(Molmo2DecoderLayer): def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None, **kwargs: object, ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: # Self Attention residual = hidden_states hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, **kwargs, ) hidden_states = self.input_layernorm(hidden_states) hidden_states = hidden_states + residual residual = hidden_states hidden_states = self.mlp(hidden_states) hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = hidden_states + residual residual = None return hidden_states, residual @support_torch_compile class Molmo2TextModel(nn.Module, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.config = config if hasattr(config, "text_config"): hf_text_config = config.text_config else: hf_text_config = config.llm_config kwargs = {} for field in fields(TextConfig): kwargs[field.name] = getattr(hf_text_config, field.name) text_config = TextConfig(**kwargs) self.embedding_size = text_config.vocab_size self.embedding_size += text_config.additional_vocab_size or 0 self.embed_tokens = VocabParallelEmbedding( self.embedding_size, text_config.hidden_size, quant_config=quant_config, ) decoder_layer = ( Molmo2DecoderNormAfterLayer if text_config.norm_after else Molmo2DecoderLayer ) self.start_layer, self.end_layer, self.layers = make_layers( text_config.num_hidden_layers, lambda prefix: decoder_layer( text_config, hf_text_config.rope_parameters, cache_config=cache_config, quant_config=quant_config, prefix=prefix, ), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(text_config.hidden_size, eps=text_config.layer_norm_eps) self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], text_config.hidden_size, ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( self, input_ids: torch.Tensor | None, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> torch.Tensor: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.embed_tokens(input_ids) residual = None else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] # Apply blocks one-by-one. for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states, residual = layer( positions, hidden_states, residual, **kwargs, ) if not get_pp_group().is_last_rank: return IntermediateTensors( {"hidden_states": hidden_states, "residual": residual} ) if residual is not None: hidden_states, _ = self.norm(hidden_states, residual) else: hidden_states = self.norm(hidden_states) return hidden_states def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: if name.endswith(".bias") and name not in params_dict: continue if is_pp_missing_parameter(name, self): continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params def get_patches_grid_size( *, image_h: int, image_w: int, patch_size: int, pool_h: int, pool_w: int, ) -> tuple[int, int]: patch_h = image_h // patch_size patch_w = image_w // patch_size h_pad = round_down(patch_h + pool_h - 1, pool_h) - patch_h w_pad = round_down(patch_w + pool_w - 1, pool_w) - patch_w nrows = (patch_h + h_pad) // pool_h ncols = (patch_w + w_pad) // pool_w return nrows, ncols def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]: tilings = [ (i, j) for i in range(1, max_num + 1) for j in range(1, max_num + 1) if i * j <= max_num ] return sorted(tilings, key=lambda x: (x[0] * x[1], x[0])) def select_tiling( *, height: int, width: int, patch_size: int, max_num_patches: int, ): tilings = get_candidate_tilings(max_num_patches) candidate_tilings = np.array(tilings, dtype=np.int32) candidate_resolutions = candidate_tilings * patch_size original_size = np.array([height, width], dtype=np.float32) required_scale_d = candidate_resolutions.astype(np.float32) / original_size required_scale = required_scale_d.min(axis=-1, keepdims=True) if (required_scale < 1).all(): ix = required_scale.argmax() else: ix = np.where(required_scale < 1.0, 10e9, required_scale).argmin() return candidate_tilings[ix] def get_image_size(image: ImageInput) -> ImageSize: if isinstance(image, Image): return ImageSize(*image.size) elif isinstance(image, (np.ndarray, torch.Tensor)): assert image.ndim == 3 h, w, c = image.shape assert c in [1, 3] return ImageSize(w, h) else: raise ValueError(f"Unknown image type: {type(image)}") def exif_transpose( images: ImageInput | None, ) -> ImageInput | None: if images is None: return None if images is not None and isinstance(images, (list, tuple)): images = [ exif_transpose(img) if isinstance(img, Image) else img for img in images ] elif images is not None and isinstance(images, Image): images = ImageOps.exif_transpose(images) return images def build_flat_image_bool_length( image_grids: torch.LongTensor, hf_config: PretrainedConfig, ) -> tuple[torch.LongTensor, torch.LongTensor]: image_patch_id = hf_config.image_patch_id low_res_image_start_id = hf_config.low_res_image_start_token_id image_start_id = hf_config.image_start_token_id image_col_id = hf_config.image_col_id image_end_id = hf_config.image_end_token_id device = image_grids.device B = image_grids.shape[0] resized_h = image_grids[:, 0] resized_w = image_grids[:, 1] h = image_grids[:, 2] w = image_grids[:, 3] lengths = resized_h * resized_w + h * (w + 1) + 4 # [B] total_len = int(lengths.sum().item()) flat = torch.empty(total_len, dtype=torch.long, device=device) offset = 0 for i in range(B): resized_h_i, resized_w_i, h_i, w_i = image_grids[i].tolist() L_i = int(lengths[i].item()) num_low_res_patches = resized_h_i * resized_w_i idx = offset flat[idx] = low_res_image_start_id idx += 1 if num_low_res_patches > 0: flat[idx : idx + num_low_res_patches] = image_patch_id idx += num_low_res_patches flat[idx] = image_end_id idx += 1 flat[idx] = image_start_id idx += 1 block_len = w_i + 1 if block_len > 0 and h_i > 0: line = torch.empty(block_len, dtype=torch.long, device=device) if w_i > 0: line[:w_i] = image_patch_id line[w_i] = image_col_id block = line.repeat(h_i) flat[idx : idx + h_i * block_len] = block idx += h_i * block_len flat[idx] = image_end_id idx += 1 assert idx - offset == L_i offset += L_i return flat, lengths def build_flat_video_bool_length( video_grids: torch.LongTensor, hf_config: PretrainedConfig, ) -> tuple[torch.LongTensor, torch.LongTensor]: image_patch_id = hf_config.image_patch_id frame_start_id = hf_config.frame_start_token_id frame_end_id = hf_config.frame_end_token_id device = video_grids.device B = video_grids.shape[0] t = video_grids[:, 0] resized_h = video_grids[:, 1] resized_w = video_grids[:, 2] P = resized_h * resized_w block_len = P + 2 lengths = t * block_len total_len = int(lengths.sum().item()) flat = torch.empty(total_len, dtype=torch.long, device=device) offset = 0 for i in range(B): ti = int(t[i].item()) Pi = int(P[i].item()) Li = int(lengths[i].item()) block = torch.empty(Pi + 2, dtype=torch.long, device=device) block[0] = frame_start_id if Pi > 0: block[1 : 1 + Pi] = image_patch_id block[-1] = frame_end_id seq = block.repeat(ti) flat[offset : offset + Li] = seq offset += Li return flat, lengths def get_candidate_target_fps( video_fps: int | float, sampling_fps: int | float, max_fps: int | float = _MAX_VIDEO_FPS, ) -> list[float]: """ Return the subset of `video_fps` factors that remain multiples of `sampling_fps`. Examples: >>> get_candidate_target_fps(video_fps=6, sampling_fps=2) [2, 6] >>> get_candidate_target_fps(video_fps=5, sampling_fps=1) [1, 5] >>> get_candidate_target_fps(video_fps=2, sampling_fps=2) [2] >>> get_candidate_target_fps(video_fps=5, sampling_fps=2) Traceback (most recent call last): ... ValueError: sampling_fps=2 must divide video_fps=5 to produce consistent frame steps. """ video_fps = int(video_fps) sampling_fps = int(sampling_fps) max_fps = int(max_fps) if sampling_fps is None: raise ValueError("sampling_fps must be provided") if video_fps <= 0 or sampling_fps <= 0: raise ValueError( "video_fps and sampling_fps must be positive " f"(got {video_fps}, {sampling_fps})" ) if video_fps % sampling_fps != 0: raise ValueError( f"sampling_fps={sampling_fps} must divide video_fps={video_fps}." ) candidates = [] for candidate in range(sampling_fps, video_fps + 1, sampling_fps): if candidate > max_fps: break if video_fps % candidate == 0: candidates.append(float(candidate)) return candidates def get_target_fps( video_fps: float, max_frames: int, total_frames: int, frame_sample_mode: str, candidate_target_fps: list[float], ) -> float | None: """ Get the target fps that best spans the video and has the most frames sampled """ num_frames_sampled = 0 selected_target_fps = None for target_fps in candidate_target_fps: step_size = max(int(video_fps / target_fps), 1) num_frames_sampled_at_fps = int(total_frames / step_size) if num_frames_sampled == 0: if ( "uniform" in frame_sample_mode and num_frames_sampled_at_fps > max_frames ): break selected_target_fps = target_fps num_frames_sampled = num_frames_sampled_at_fps else: # the candidate sampling fps increases so frame count can't decrease assert num_frames_sampled <= num_frames_sampled_at_fps if num_frames_sampled_at_fps > max_frames: # choose the sampling fps that spans the video continue elif num_frames_sampled_at_fps > num_frames_sampled: # both are less than max_frames; choose the one with higher # density of frames sampled selected_target_fps = target_fps num_frames_sampled = num_frames_sampled_at_fps return selected_target_fps def get_frame_times_and_chosen_fps( selected_target_fps, total_frames, max_frames, video_fps ): if selected_target_fps is None: frame_indices = np.linspace( 0, total_frames, max_frames, endpoint=False, dtype=int ) else: step_size = max(int(video_fps / selected_target_fps), 1) frame_indices = np.arange(0, total_frames, step_size) if len(frame_indices) > max_frames: frame_indices = frame_indices[:max_frames] return selected_target_fps, frame_indices class Molmo2ProcessingInfo(BaseProcessingInfo): def get_data_parser(self): return MultiModalDataParser( video_needs_metadata=True, expected_hidden_size=self._get_expected_hidden_size(), ) def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None, "video": 1} def select_tiling( self, *, image_width: int, image_height: int, image_processor: BaseImageProcessor, ) -> tuple[int, int]: max_crops = image_processor.max_crops left_margin, right_margin = image_processor.overlap_margins base_image_input_d = image_processor.patch_size total_margin_pixels = base_image_input_d * (right_margin + left_margin) crop_patches = image_processor.size["height"] // base_image_input_d crop_window_patches = crop_patches - (right_margin + left_margin) crop_window_size = crop_window_patches * base_image_input_d tiling_h, tiling_w = select_tiling( height=image_height - total_margin_pixels, width=image_width - total_margin_pixels, patch_size=crop_window_size, max_num_patches=max_crops, ) return tiling_w, tiling_h def get_base_grid_size( self, image_processor: BaseImageProcessor | BaseVideoProcessor, ) -> tuple[int, int]: nrows, ncols = get_patches_grid_size( image_h=image_processor.size["height"], image_w=image_processor.size["width"], patch_size=image_processor.patch_size, pool_h=image_processor.pooling_size[0], pool_w=image_processor.pooling_size[1], ) return ncols, nrows def get_patches_grid_size( self, *, image_width: int, image_height: int, image_processor: BaseImageProcessor, ) -> tuple[int, int]: left_margin, right_margin = image_processor.overlap_margins base_image_input_d = image_processor.patch_size total_margin_pixels = base_image_input_d * (right_margin + left_margin) crop_patches = image_processor.size["height"] // base_image_input_d crop_window_patches = crop_patches - (right_margin + left_margin) crop_window_size = crop_window_patches * base_image_input_d tiling_w, tiling_h = self.select_tiling( image_height=image_height, image_width=image_width, image_processor=image_processor, ) nrows, ncols = get_patches_grid_size( image_h=tiling_h * crop_window_size + total_margin_pixels, image_w=tiling_w * crop_window_size + total_margin_pixels, patch_size=base_image_input_d, pool_h=image_processor.pooling_size[0], pool_w=image_processor.pooling_size[1], ) return ncols, nrows def get_num_image_tokens( self, *, image_height: int, image_width: int, processor: ProcessorMixin, ) -> int: image_processor = processor.image_processor resize_ncols, resize_nrows = self.get_base_grid_size(image_processor) # start/end tokens + image patch token + col tokens if processor.use_single_crop_col_tokens is not None: use_col_tokens = processor.use_single_crop_col_tokens else: use_col_tokens = processor.image_use_col_tokens extra = 2 + resize_nrows * (resize_ncols + int(use_col_tokens)) overlap_ncols, overlap_nrows = self.get_patches_grid_size( image_height=image_height, image_width=image_width, image_processor=image_processor, ) joint = 2 + overlap_nrows * ( overlap_ncols + int(processor.image_use_col_tokens) ) return extra + joint def get_num_video_tokens( self, *, num_frames: int, processor: ProcessorMixin, ) -> int: video_processor = processor.video_processor resize_ncols, resize_nrows = self.get_base_grid_size(video_processor) # start/end tokens extra = 2 + resize_nrows * (resize_ncols + int(processor.video_use_col_tokens)) return num_frames * extra def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() image_processor = processor.image_processor left_margin, right_margin = image_processor.overlap_margins base_image_input_d = image_processor.patch_size total_margin_pixels = base_image_input_d * (right_margin + left_margin) crop_patches = image_processor.size["height"] // base_image_input_d crop_window_patches = crop_patches - (right_margin + left_margin) crop_window_size = crop_window_patches * base_image_input_d tilings = get_candidate_tilings(image_processor.max_crops) largest_feature_size, largest_feature_pinpoint = 0, None for hr, wr in tilings: height = hr * crop_window_size + total_margin_pixels width = wr * crop_window_size + total_margin_pixels feat_size = self.get_num_image_tokens( image_height=height, image_width=width, processor=processor, ) if feat_size > largest_feature_size: largest_feature_size = feat_size largest_feature_pinpoint = ImageSize(width=width, height=height) if largest_feature_size == 0 or largest_feature_pinpoint is None: raise ValueError("Cannot have a largest feature size of 0!") return largest_feature_pinpoint def _get_max_video_frames( self, max_tokens: int, processor: ProcessorMixin, ) -> int: num_tokens_per_frame = self.get_num_video_tokens( num_frames=1, processor=processor, ) max_frames = max_tokens // num_tokens_per_frame return max(max_frames, 1) def get_num_frames_with_most_features( self, seq_len: int, mm_counts: Mapping[str, int], ) -> int: processor = self.get_hf_processor() video_processor = processor.video_processor num_frames = video_processor.num_frames max_videos = mm_counts.get("video", 0) max_total_frames = self._get_max_video_frames(seq_len, processor) max_frames_per_video = min( max_total_frames // max(max_videos, 1), num_frames, ) return max(max_frames_per_video, 1) def _sample_frames( self, total_num_frames: int, video_fps: float, duration: float, frame_sample_mode: str, num_frames: int, max_fps: int, sampling_fps: int, ) -> np.ndarray: if frame_sample_mode == "uniform_last_frame" and max_fps is not None: if total_num_frames <= 2: indices = np.arange(total_num_frames).astype(int) elif duration > (num_frames - 1) / max_fps: # -1 to include the last frame # uniform fallback indices = np.linspace( 0, total_num_frames - 1, num=min(num_frames, total_num_frames), endpoint=True, ).astype(int) else: float_indices = np.arange( 0.0, stop=total_num_frames - 1, step=float(video_fps / max_fps), ) if np.round(float_indices[-1]) != total_num_frames - 1: float_indices = np.concatenate( [float_indices, [total_num_frames - 1]], axis=0 ) indices = np.round(float_indices).astype(int) assert indices[-1] < total_num_frames assert len(float_indices) <= num_frames elif frame_sample_mode == "uniform_last_frame": indices = np.linspace( 0, total_num_frames - 1, num=min(num_frames, total_num_frames), endpoint=True, ).astype(int) elif frame_sample_mode == "fps": candidate_target_fps = get_candidate_target_fps(video_fps, sampling_fps) selected_target_fps = get_target_fps( video_fps, num_frames, total_num_frames, frame_sample_mode, candidate_target_fps, ) _, indices = get_frame_times_and_chosen_fps( selected_target_fps, total_num_frames, num_frames, video_fps, ) else: raise NotImplementedError(frame_sample_mode) return indices def _get_video_second_idx( self, metadata: dict[str, Any], do_sample_frames: bool | None = None, ) -> list[float]: processor = self.get_hf_processor() video_processor = processor.video_processor # metadata["fps"] refers to the true fps of the input video. video_fps = metadata["fps"] frames_indices = metadata.get("frames_indices") if do_sample_frames is None: do_sample_frames = metadata.get("do_sample_frames", False) if do_sample_frames: # Frame-based sampling is applied in HF video processor total_num_frames = metadata["total_num_frames"] duration = total_num_frames / video_fps frame_sample_mode = video_processor.frame_sample_mode num_frames = video_processor.num_frames max_fps = video_processor.max_fps sampling_fps = video_processor.sampling_fps frames_indices = self._sample_frames( total_num_frames, video_fps, duration, frame_sample_mode, num_frames, max_fps, sampling_fps, ) else: # Time-based sampling is done in vllm molmo2 video loader or molmo_utils assert frames_indices is not None timestamps = [frame_idx / video_fps for frame_idx in frames_indices] return timestamps class Molmo2DummyInputsBuilder(BaseDummyInputsBuilder[Molmo2ProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) image_placeholder_token = IMAGE_PROMPT video_placeholder_token = VIDEO_PROMPT if num_images == 1: image_string = image_placeholder_token else: image_string = "".join( [f"Image {i + 1}" + image_placeholder_token for i in range(num_images)] ) return image_string + video_placeholder_token * num_videos def get_dummy_mm_data( self, seq_len: int, mm_counts: Mapping[str, int], mm_options: Mapping[str, BaseDummyOptions], ) -> MultiModalDataDict: num_images = mm_counts.get("image", 0) num_videos = mm_counts.get("video", 0) dummy_images = [] dummy_videos = [] if num_images > 0: target_width, target_height = self.info.get_image_size_with_most_features() image_overrides = mm_options.get("image") dummy_images = self._get_dummy_images( width=target_width, height=target_height, num_images=num_images, overrides=image_overrides, ) if num_videos > 0: processor = self.info.get_hf_processor() video_size = processor.video_processor.size target_num_frames = self.info.get_num_frames_with_most_features( seq_len, mm_counts ) video_overrides = mm_options.get("video") if video_overrides: assert isinstance(video_overrides, VideoDummyOptions) num_frames_override = video_overrides.num_frames if num_frames_override: if num_frames_override > target_num_frames: logger.warning( "video.num_frames override (%d) exceeds model's " "maximum number of frames (%d), will be ignored", num_frames_override, target_num_frames, ) if num_frames_override < 2: logger.warning( "video.num_frames override (%d) cannot be less " "than 2, will be ignored", num_frames_override, ) target_num_frames = min(target_num_frames, num_frames_override) dummy_videos = self._get_dummy_videos( width=video_size["width"], height=video_size["height"], num_frames=target_num_frames, num_videos=num_videos, ) return { "image": dummy_images, "video": dummy_videos, } def _get_dummy_videos( self, *, width: int, height: int, num_frames: int, num_videos: int, overrides: VideoDummyOptions | None = None, ) -> list[VideoItem]: videos = super()._get_dummy_videos( width=width, height=height, num_frames=num_frames, num_videos=num_videos, overrides=overrides, ) videos = [v.copy() for v in videos] video_items = [] for video in videos: video_num_frames = video.shape[0] video_metadata = { "fps": 2.0, "duration": video_num_frames / 2.0, "total_num_frames": video_num_frames, "frames_indices": list(range(video_num_frames)), "video_backend": "decord", "do_sample_frames": False, "height": height, "width": width, } video_items.append((video, video_metadata)) return video_items class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]): def _apply_hf_processor_tokens_only( self, prompt_tokens: list[int], ) -> list[int]: processor = self.info.get_hf_processor() tokenizer = processor.tokenizer bos_token_id = tokenizer.bos_token_id or tokenizer.eos_token_id if len(prompt_tokens) == 0 or prompt_tokens[0] != bos_token_id: # Prepend the bos token to the prompt tokens prompt_tokens = [bos_token_id] + prompt_tokens return prompt_tokens def _call_hf_processor( self, prompt: str, mm_data: Mapping[str, object], mm_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object], ) -> BatchFeature: mm_data = dict(mm_data) hf_config = self.info.get_hf_config() hf_processor = self.info.get_hf_processor(**mm_kwargs) def patched_call(text=None, images=None, videos=None, **kwargs) -> BatchFeature: res = hf_processor(text=text, images=images, videos=videos, **kwargs) # Molmo2Processor.insert_bos results in float outputs # if the input text is empty if not text: res["input_ids"] = res["input_ids"].long() return res tokenizer = hf_processor.tokenizer image_processor = hf_processor.image_processor if videos := mm_data.pop("videos", []): bos_token_id = tokenizer.bos_token_id or tokenizer.eos_token_id pixel_values_videos_lst = [] video_token_pooling_lst = [] video_num_crops_lst = [] video_num_pooled_patches_lst = [] video_num_patches_lst = [] video_tokens_lst = [] num_video_tokens_lst = [] for item in videos: video_array, metadata = item # NOTE: metadata.frames_indices indicates # the sampled frames indices of pre-sampled videos, which is # used to calculate the timestamps. Make sure that # do_sample_frames in mm_kwargs is false for presampled videos. # NOTE: a copy of mm_kwargs is created to update do_sample_frames, # otherwise mm_hash for the object will be incorrect. video_mm_kwargs = dict(**mm_kwargs) if "do_sample_frames" not in video_mm_kwargs: # molmo_utils already has "do_sample_frames" in # mm_kwargs, don't overwrite it. video_mm_kwargs["do_sample_frames"] = metadata.get( "do_sample_frames", False ) metadata = VideoMetadata( **{k: metadata[k] for k in metadata if k != "do_sample_frames"} ) video_mm_data = dict() video_mm_data["videos"] = [[video_array]] video_mm_data["video_metadata"] = [[metadata]] video_outputs = self.info.ctx.call_hf_processor( patched_call, dict(text=VIDEO_PROMPT, **video_mm_data), dict(**video_mm_kwargs, **tok_kwargs), ) input_ids = video_outputs.pop("input_ids") if input_ids[0, 0] == bos_token_id: input_ids = input_ids[:, 1:] video_string = tokenizer.batch_decode(input_ids)[0] prompt = prompt.replace(VIDEO_PROMPT, video_string, 1) video_grids = video_outputs.pop("video_grids") assert video_grids[:, 0].sum() == len( video_outputs["pixel_values_videos"] ) video_outputs["video_num_crops"] = video_grids[:, 0] video_outputs["video_num_pooled_patches"] = video_grids.prod(dim=1) n_patches = video_outputs["pixel_values_videos"].shape[1] video_outputs["video_num_patches"] = ( video_outputs["video_num_crops"] * n_patches ) (video_outputs["video_tokens"], video_outputs["num_video_tokens"]) = ( build_flat_video_bool_length(video_grids, hf_config) ) pixel_values_videos_lst.append(video_outputs["pixel_values_videos"]) video_token_pooling_lst.append(video_outputs["video_token_pooling"]) video_num_crops_lst.append(video_outputs["video_num_crops"]) video_num_pooled_patches_lst.append( video_outputs["video_num_pooled_patches"] ) video_num_patches_lst.append(video_outputs["video_num_patches"]) video_tokens_lst.append(video_outputs["video_tokens"]) num_video_tokens_lst.append(video_outputs["num_video_tokens"]) all_video_outputs = dict( pixel_values_videos=torch.cat(pixel_values_videos_lst), video_token_pooling=torch.cat(video_token_pooling_lst), video_num_crops=torch.cat(video_num_crops_lst), video_num_pooled_patches=torch.cat(video_num_pooled_patches_lst), video_num_patches=torch.cat(video_num_patches_lst), video_tokens=torch.cat(video_tokens_lst), num_video_tokens=torch.cat(num_video_tokens_lst), ) else: all_video_outputs = dict() processed_outputs = self.info.ctx.call_hf_processor( patched_call, dict(text=prompt, **mm_data), dict(**mm_kwargs, **tok_kwargs), ) if (images := mm_data.get("images")) is not None: mm_items = self.info.parse_mm_data({"image": images}, validate=False) parsed_images = mm_items.get_items("image", ImageProcessorItems) image_sizes = [ parsed_images.get_image_size(i) for i in range(len(parsed_images)) ] # For each image: tiling_h * tiling_w + global view tilings = [ self.info.select_tiling( image_width=image_size.width, image_height=image_size.height, image_processor=image_processor, ) for image_size in image_sizes ] num_crops = torch.tensor(tilings).prod(-1) + 1 assert sum(num_crops) == len(processed_outputs["pixel_values"]) assert sum(num_crops) == processed_outputs["image_num_crops"].sum().item() image_grids = processed_outputs.pop("image_grids") image_num_pooled_patches = image_grids[:, :2].prod(dim=1) + image_grids[ :, 2: ].prod(dim=1) processed_outputs["image_num_pooled_patches"] = image_num_pooled_patches n_patches = processed_outputs["pixel_values"].shape[1] processed_outputs["image_num_patches"] = ( processed_outputs["image_num_crops"] * n_patches ) ( processed_outputs["image_tokens"], processed_outputs["num_image_tokens"], ) = build_flat_image_bool_length(image_grids, hf_config) return BatchFeature({**processed_outputs, **all_video_outputs}) def _get_mm_fields_config( self, hf_inputs: BatchFeature, hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_num_crops = hf_inputs.get("image_num_crops", torch.empty(0)) image_num_pooled_patches = hf_inputs.get( "image_num_pooled_patches", torch.empty(0) ) video_num_crops = hf_inputs.get("video_num_crops", torch.empty(0)) video_num_pooled_patches = hf_inputs.get( "video_num_pooled_patches", torch.empty(0) ) num_image_tokens = hf_inputs.get("num_image_tokens", torch.empty(0)) num_video_tokens = hf_inputs.get("num_video_tokens", torch.empty(0)) return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes( "image", image_num_crops ), image_token_pooling=MultiModalFieldConfig.flat_from_sizes( "image", image_num_pooled_patches ), image_num_crops=MultiModalFieldConfig.batched("image"), image_num_pooled_patches=MultiModalFieldConfig.batched("image"), image_num_patches=MultiModalFieldConfig.batched("image"), image_tokens=MultiModalFieldConfig.flat_from_sizes( "image", num_image_tokens ), num_image_tokens=MultiModalFieldConfig.batched("image"), pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( "video", video_num_crops ), video_token_pooling=MultiModalFieldConfig.flat_from_sizes( "video", video_num_pooled_patches ), video_num_crops=MultiModalFieldConfig.batched("video"), video_num_pooled_patches=MultiModalFieldConfig.batched("video"), video_num_patches=MultiModalFieldConfig.batched("video"), video_tokens=MultiModalFieldConfig.flat_from_sizes( "video", num_video_tokens ), num_video_tokens=MultiModalFieldConfig.batched("video"), ) def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() img_patch_id = hf_config.image_patch_id img_col_id = hf_config.image_col_id img_start_id = hf_config.image_start_token_id img_end_id = hf_config.image_end_token_id low_res_im_start_id = hf_config.low_res_image_start_token_id frame_start_id = hf_config.frame_start_token_id frame_end_id = hf_config.frame_end_token_id im_low_res_id = hf_config.image_low_res_id emb_tok_ids = [ img_patch_id, img_col_id, img_start_id, low_res_im_start_id, frame_start_id, img_end_id, frame_end_id, im_low_res_id, ] processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) image_use_col_tokens = processor.image_use_col_tokens use_single_crop_col_tokens = processor.use_single_crop_col_tokens use_single_crop_start_token = processor.use_single_crop_start_token video_use_col_tokens = processor.video_use_col_tokens use_frame_special_tokens = processor.use_frame_special_tokens tokenizer = processor.tokenizer vocab = tokenizer.get_vocab() image_processor = processor.image_processor video_processor = processor.video_processor def get_image_replacement_molmo2(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) image = images.get(item_idx) image = exif_transpose(image) resize_ncols, resize_nrows = self.info.get_base_grid_size(image_processor) if use_single_crop_col_tokens is not None: use_col_tokens = use_single_crop_col_tokens else: use_col_tokens = image_use_col_tokens if use_single_crop_start_token: start_id = low_res_im_start_id else: start_id = img_start_id extra_row = [img_patch_id] * resize_ncols + [img_col_id] * int( use_col_tokens ) extra_joint = [start_id] + extra_row * resize_nrows + [img_end_id] image_size = get_image_size(image) ncols, nrows = self.info.get_patches_grid_size( image_height=image_size.height, image_width=image_size.width, image_processor=image_processor, ) joint_row = [img_patch_id] * ncols + [img_col_id] * int( image_use_col_tokens ) joint = [img_start_id] + joint_row * nrows + [img_end_id] img_token_ids = extra_joint + joint return PromptUpdateDetails.select_token_ids(img_token_ids, emb_tok_ids) def get_video_replacement_molmo2(item_idx: int): video, metadata = mm_items["video"][item_idx] do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames") timestamps = self.info._get_video_second_idx(metadata, do_sample_frames) ncols, nrows = self.info.get_base_grid_size(video_processor) if use_frame_special_tokens: start_id = frame_start_id end_id = frame_end_id else: start_id = img_start_id end_id = img_end_id img_token_ids = [] for frame_idx, frame_time in enumerate(timestamps): prev_space = " " if frame_idx > 0 else "" frame_prefix = ( prev_space + f"{frame_time:.1f} " ) # explicit whitespace before/after image tokens img_token_ids += tokenizer.encode( frame_prefix, add_special_tokens=False, ) joint_row = [img_patch_id] * ncols + [img_col_id] * int( video_use_col_tokens ) joint = [start_id] + nrows * joint_row + [end_id] img_token_ids += joint return PromptUpdateDetails.select_token_ids(img_token_ids, emb_tok_ids) return [ PromptReplacement( modality=modality, target=[target], replacement=replacement_fn, ) for modality, target, replacement_fn in zip( ["image", "video"], [vocab[IMAGE_PROMPT], vocab[VIDEO_PROMPT]], [get_image_replacement_molmo2, get_video_replacement_molmo2], ) ] @MULTIMODAL_REGISTRY.register_processor( Molmo2MultiModalProcessor, info=Molmo2ProcessingInfo, dummy_inputs=Molmo2DummyInputsBuilder, ) class Molmo2ForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant ): hf_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ # vision backbone mapping "image_pooling_2d.wq": "image_pooling_2d.q_proj", "image_pooling_2d.wk": "image_pooling_2d.k_proj", "image_pooling_2d.wv": "image_pooling_2d.v_proj", "image_pooling_2d.wo": "image_pooling_2d.o_proj", "image_projector.w1": "image_projector.gate_proj", "image_projector.w3": "image_projector.up_proj", "image_projector.w2": "image_projector.down_proj", # language backbone mapping "att_proj": "qkv_proj", "attn_out": "o_proj", "q_norm": "q_norm", "k_norm": "k_norm", "ff_proj": "up_gate_proj", "ff_out": "down_proj", "attn_norm": "input_layernorm", "ff_norm": "post_attention_layernorm", }, orig_to_new_prefix={ # vision backbone mapping "model.vision_backbone.": "vision_backbone.", # language backbone mapping "model.transformer.blocks.": "model.layers.", "model.transformer.ln_f.": "model.norm.", }, ) packed_modules_mapping = { "qkv_proj": ["qkv_proj"], "up_gate_proj": ["up_gate_proj"], # language model "merged_qkv": ["wq", "wk", "wv"], # vision backbone "merged_kv": ["k_proj", "v_proj"], # image_pooling_2d "merged_linear": ["gate_proj", "up_proj"], # image_projector } @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): return IMAGE_PROMPT if modality.startswith("video"): return VIDEO_PROMPT raise ValueError("Only image or video modality is supported") def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config kwargs = {} for field in fields(VitConfig): kwargs[field.name] = getattr(config.vit_config, field.name) vit_config = VitConfig(**kwargs) kwargs = {} for field in fields(AdapterConfig): kwargs[field.name] = getattr(config.adapter_config, field.name) adapter_config = AdapterConfig(**kwargs) with self._mark_tower_model(vllm_config, {"image", "video"}): self.vision_backbone = Molmo2VisionBackbone( vit_config, adapter_config, quant_config, prefix=maybe_prefix(prefix, "vision_backbone"), ) with self._mark_language_model(vllm_config): self.model = Molmo2TextModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"), ) self.img_patch_id = config.image_patch_id if hasattr(config, "text_config"): hf_text_config = config.text_config else: hf_text_config = config.llm_config self.lm_head = ParallelLMHead( hf_text_config.vocab_size, hf_text_config.hidden_size, quant_config=quant_config, ) self.logits_processor = LogitsProcessor(hf_text_config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors ) @property def dtype(self): return next(self.parameters()).dtype def _parse_and_validate_image_input( self, **kwargs: object, ) -> Molmo2ImageInputs | None: pixel_values = kwargs.pop("pixel_values", None) if pixel_values is None: return None token_pooling = kwargs.pop("image_token_pooling", None) num_pooled_patches = kwargs.pop("image_num_pooled_patches", None) num_patches = kwargs.pop("image_num_patches", None) image_tokens = kwargs.pop("image_tokens", None) num_image_tokens = kwargs.pop("num_image_tokens", None) accum_patches = [0] + num_patches.cumsum(dim=0)[:-1].tolist() patch_offset = 0 new_token_pooling = token_pooling.clone() for i, n in enumerate(num_pooled_patches): cur_slice = token_pooling[patch_offset : patch_offset + n] index_offset = int(accum_patches[i]) new_token_pooling[patch_offset : patch_offset + n] = torch.where( cur_slice >= 0, cur_slice + index_offset, cur_slice, ) patch_offset += n return Molmo2ImageInputs( pixel_values=pixel_values, token_pooling=new_token_pooling, num_pooled_patches=num_pooled_patches, image_tokens=image_tokens, num_image_tokens=num_image_tokens, ) def _parse_and_validate_video_input( self, **kwargs: object, ) -> Molmo2VideoInputs | None: pixel_values_videos = kwargs.pop("pixel_values_videos", None) if pixel_values_videos is None: return None token_pooling = kwargs.pop("video_token_pooling", None) num_pooled_patches = kwargs.pop("video_num_pooled_patches", None) num_patches = kwargs.pop("video_num_patches", None) video_tokens = kwargs.pop("video_tokens", None) num_video_tokens = kwargs.pop("num_video_tokens", None) accum_patches = [0] + num_patches.cumsum(dim=0)[:-1].tolist() patch_offset = 0 new_token_pooling = token_pooling.clone() for i, n in enumerate(num_pooled_patches): cur_slice = token_pooling[patch_offset : patch_offset + n] index_offset = int(accum_patches[i]) new_token_pooling[patch_offset : patch_offset + n] = torch.where( cur_slice >= 0, cur_slice + index_offset, cur_slice, ) patch_offset += n return Molmo2VideoInputs( pixel_values_videos=pixel_values_videos, token_pooling=new_token_pooling, num_pooled_patches=num_pooled_patches, video_tokens=video_tokens, num_video_tokens=num_video_tokens, ) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: modalities = {} for input_key in kwargs: if input_key in ("pixel_values",) and "images" not in modalities: modalities["images"] = self._parse_and_validate_image_input(**kwargs) if input_key in ("pixel_values_videos",) and "videos" not in modalities: modalities["videos"] = self._parse_and_validate_video_input(**kwargs) return modalities def _process_image_input( self, image_input: Molmo2ImageInputs, ) -> tuple[torch.Tensor, ...]: pixel_values = image_input["pixel_values"] token_pooling = image_input["token_pooling"] num_pooled_patches = image_input["num_pooled_patches"] image_tokens = image_input["image_tokens"] num_image_tokens = image_input["num_image_tokens"] image_features_flat = self.vision_backbone( images=pixel_values.unsqueeze(0), token_pooling=token_pooling.unsqueeze(0), ) assert len(image_features_flat) == num_pooled_patches.sum() image_features_list = image_features_flat.split( num_pooled_patches.tolist(), dim=0 ) image_tokens_list = image_tokens.split(num_image_tokens.tolist(), dim=0) out = [] for image_features_i, image_tokens_i in zip( image_features_list, image_tokens_list ): out_features = self.get_language_model().embed_input_ids(image_tokens_i) is_image_patch = image_tokens_i == self.img_patch_id out_features[is_image_patch] = image_features_i out.append(out_features) return tuple(out) def _process_video_input( self, video_input: Molmo2VideoInputs, ) -> tuple[torch.Tensor, ...]: pixel_values_videos = video_input["pixel_values_videos"] token_pooling = video_input["token_pooling"] num_pooled_patches = video_input["num_pooled_patches"] video_tokens = video_input["video_tokens"] num_video_tokens = video_input["num_video_tokens"] image_features_flat = self.vision_backbone( images=pixel_values_videos.unsqueeze(0), token_pooling=token_pooling.unsqueeze(0), ) assert len(image_features_flat) == num_pooled_patches.sum() image_features_list = image_features_flat.split( num_pooled_patches.tolist(), dim=0 ) video_tokens_list = video_tokens.split(num_video_tokens.tolist(), dim=0) out = [] for image_features_i, video_tokens_i in zip( image_features_list, video_tokens_list ): out_features = self.get_language_model().embed_input_ids(video_tokens_i) is_image_patch = video_tokens_i == self.img_patch_id out_features[is_image_patch] = image_features_i out.append(out_features) return tuple(out) def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if not modalities: return [] multimodal_embeddings: tuple[torch.Tensor, ...] = () for modality in modalities: if modality == "images": image_input = modalities["images"] image_embeddings = self._process_image_input(image_input) multimodal_embeddings += image_embeddings if modality == "videos": video_input = modalities["videos"] video_embeddings = self._process_video_input(video_input) multimodal_embeddings += video_embeddings return multimodal_embeddings def embed_input_ids( self, input_ids: torch.Tensor, multimodal_embeddings: MultiModalEmbeddings | None = None, *, is_multimodal: torch.Tensor | None = None, ) -> torch.Tensor: inputs_embeds = self._embed_text_input_ids( input_ids, self.get_language_model().embed_input_ids, is_multimodal=is_multimodal, ) if multimodal_embeddings is None or len(multimodal_embeddings) == 0: return inputs_embeds if is_multimodal is None: raise ValueError( "`embed_input_ids` now requires `is_multimodal` arg, " "please update your model runner according to " "https://github.com/vllm-project/vllm/pull/16229." ) inputs_embeds = _merge_multimodal_embeddings( inputs_embeds=inputs_embeds, multimodal_embeddings=multimodal_embeddings, is_multimodal=is_multimodal, ) return inputs_embeds def forward( self, input_ids: torch.LongTensor, positions: torch.LongTensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, **kwargs: object, ) -> torch.Tensor: if intermediate_tensors is not None: inputs_embeds = None hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds, **kwargs, ) return hidden_states def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states) return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) weights = _get_weights_with_merged_embedding(weights) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def get_mm_mapping(self) -> MultiModelKeys: """ Get the module prefix in multimodal models """ return MultiModelKeys.from_string_field( language_model="model", connector="vision_backbone.image_projector", tower_model="vision_backbone", ) def _get_weights_with_merged_embedding( weights: Iterable[tuple[str, torch.Tensor]], ) -> Iterable[tuple[str, torch.Tensor]]: embedding_weights = {} for name, weight in weights: if "wte.embedding" in name: embedding_weights["embedding"] = weight elif "wte.new_embedding" in name: embedding_weights["new_embedding"] = weight else: yield (name, weight) # this is compatible with most of quantization, # because they won't quantize embed_tokens if "embedding" not in embedding_weights or "new_embedding" not in embedding_weights: raise ValueError( "Checkpoint is missing 'wte.embedding' or " "'wte.new_embedding' weights required for Molmo2." ) embedding_weights = torch.cat( [embedding_weights["embedding"], embedding_weights["new_embedding"]], dim=0, ) yield ("model.embed_tokens.weight", embedding_weights)