Files
vllm/vllm/model_executor/models/molmo2.py
2026-03-11 03:07:20 -07:00

2636 lines
88 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields
from functools import partial
from itertools import islice
from typing import Annotated, Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import ImageOps
from PIL.Image import Image
from transformers import (
BaseImageProcessor,
BaseVideoProcessor,
BatchFeature,
PretrainedConfig,
ProcessorMixin,
)
from transformers.image_utils import ImageInput
from transformers.video_utils import VideoMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import MulAndSilu, SiluAndMul, get_act_fn
from vllm.model_executor.layers.attention import Attention, MMEncoderAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
VideoItem,
)
from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import round_down
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
SupportsQuant,
)
from .utils import (
AutoWeightsLoader,
WeightsMapper,
_merge_multimodal_embeddings,
extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
logger = init_logger(__name__)

# Special tokens. These should be present in any tokenizer we use
# because the preprocessor relies on them.
IMAGE_PROMPT = "<|image|>"
VIDEO_PROMPT = "<|video|>"

# Upper bound on the frames-per-second value used for video inputs.
_MAX_VIDEO_FPS = 8
class Molmo2ImageInputs(TensorSchema):
    """
    Dimensions:
        - nc: The total number of crops (dynamic)
        - np: The total number of patches per crop
        - cps: Number of channels * patch_size * patch_size
        - npp: Number of pooled patches (dynamic)
        - pp: pooling_size * pooling_size
        - ni: Number of images
        - nt: Number of image tokens (dynamic)
    """

    # Flattened pixel patches for every crop of every image.
    pixel_values: Annotated[torch.Tensor, TensorShape("nc", "np", "cps")]

    token_pooling: Annotated[torch.Tensor, TensorShape("npp", "pp")]
    """
    An index tensor that maps image features to their corresponding
    patch tokens before pooling.
    """

    # Number of pooled patches contributed by each image.
    num_pooled_patches: Annotated[torch.Tensor, TensorShape("ni")]
    # Boolean mask over prompt positions — presumably marks which prompt
    # tokens are image tokens; verify against the processor.
    image_tokens: Annotated[torch.BoolTensor, TensorShape("nt")]
    # Number of image tokens produced for each image.
    num_image_tokens: Annotated[torch.Tensor, TensorShape("ni")]
class Molmo2VideoInputs(TensorSchema):
    """
    Dimensions:
        - nc: The total number of frames (dynamic)
        - np: The total number of patches per frame
        - cps: Number of channels * patch_size * patch_size
        - npp: Number of pooled patches (dynamic)
        - pp: pooling_size * pooling_size
        - nv: Number of videos
        - nt: Number of video tokens (dynamic)
    """

    # Flattened pixel patches for every frame of every video.
    pixel_values_videos: Annotated[torch.Tensor, TensorShape("nc", "np", "cps")]

    token_pooling: Annotated[torch.Tensor, TensorShape("npp", "pp")]
    """
    An index tensor that maps image features to their corresponding
    patch tokens before pooling.
    """

    # Number of pooled patches contributed by each video.
    num_pooled_patches: Annotated[torch.Tensor, TensorShape("nv")]
    # Boolean mask over prompt positions — presumably marks which prompt
    # tokens are video tokens; verify against the processor.
    video_tokens: Annotated[torch.BoolTensor, TensorShape("nt")]
    # Number of video tokens produced for each video.
    num_video_tokens: Annotated[torch.Tensor, TensorShape("nv")]
@dataclass
class VitConfig:
    """Config for a vision transformer"""

    hidden_size: int = 1152
    intermediate_size: int = 4304
    num_hidden_layers: int = 27
    num_attention_heads: int = 16
    num_key_value_heads: int = 16
    head_dim: int = 72
    hidden_act: str = "gelu_pytorch_tanh"
    layer_norm_eps: float = 1e-6
    # Default (height, width) of the model's input image, in pixels.
    image_default_input_size: tuple[int, int] = (378, 378)
    image_patch_size: int = 14
    # Number of entries in the learned position-embedding table.
    image_num_pos: int = 577

    def __post_init__(self):
        # Configs deserialized from JSON may carry the size as a list;
        # normalize to a tuple.
        self.image_default_input_size = tuple(self.image_default_input_size)  # type: ignore[assignment]

    @property
    def image_num_patch(self):
        # (rows, cols) of the patch grid implied by image size / patch size.
        h, w = self.image_default_input_size
        return h // self.image_patch_size, w // self.image_patch_size
@dataclass
class AdapterConfig:
    """Config for a vit-llm adapter"""

    # ViT layer indices (negative counts from the end) whose hidden states
    # are concatenated to form the pooling input.
    vit_layers: tuple[int, int] = (-3, -9)
    # When True, pooling attention runs through PyTorch SDPA with an
    # explicit validity mask (see ImagePoolingAttention / backbone forward).
    pooling_attention_mask: bool = False
    hidden_size: int = 1152
    num_attention_heads: int = 16
    num_key_value_heads: int = 16
    head_dim: int = 72
    hidden_act: str = "silu"
    intermediate_size: int = 18944
    # Hidden size of the text model that the adapter projects into.
    text_hidden_size: int = 3584
@dataclass
class TextConfig:
    """Configuration for a text model transformer"""

    hidden_size: int = 3584
    """
    The hidden size of the model.
    """

    num_attention_heads: int = 28
    """
    The number of self-attention heads.
    """

    num_key_value_heads: int = 4
    """
    The number of heads to use for keys and values.
    """

    head_dim: int = 128
    """
    The head dimensionality for the attention mechanism.
    """

    vocab_size: int = 152064
    """Vocabulary size of the model."""

    additional_vocab_size: int = 128
    """Number of additional tokens to have the input embeddings for"""

    qkv_bias: bool = True
    """
    Whether the QKV projection has a bias.
    """

    num_hidden_layers: int = 48
    """
    The number of layers/blocks.
    """

    intermediate_size: int = 18944
    """
    The hidden size for the MLP.
    """

    hidden_act: str = "silu"
    """
    The activation function to use within the MLP layers.
    """

    max_position_embeddings: int = 4096
    """
    Max positional embeddings to use in RoPE cache
    """

    rope_theta: float = 1000000.0
    """
    RoPE theta parameter.
    """

    use_qk_norm: bool = False
    """
    Apply layer norm to the keys and queries within the attention mechanism.
    This can help stabilize training.
    """

    qk_norm_type: str = "olmo"
    """
    The type of layer norm to use for the keys and queries.
    Can be "olmo" or "qwen3".
    """

    layer_norm_eps: float = 1e-6
    """
    epsilon for layer norms
    """

    norm_after: bool = False
    """Apply layer norm before and after the attention and MLP blocks."""

    rope_scaling_layers: tuple[int, ...] | None = None
    """
    Layer indices that keep RoPE scaling; layers not listed fall back to
    unscaled (default) RoPE. None applies the configured RoPE everywhere.
    """
class ViTMLP(nn.Module):
    """MLP used in Vision Transformer."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Up-projection into the MLP hidden dimension.
        self.w1 = ColumnParallelLinear(
            dim,
            hidden_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.w1",
        )
        # Non-linearity applied between the two projections.
        self.act = get_act_fn(hidden_act)
        # Down-projection back to the model dimension.
        self.w2 = RowParallelLinear(
            hidden_dim,
            dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute w2(act(w1(x)))."""
        hidden, _ = self.w1(x)
        out, _ = self.w2(self.act(hidden))
        return out
class ViTMultiHeadDotProductAttention(nn.Module):
    """Multi-head attention used in Vision Transformer."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        use_bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tp_size == 0
        # Per-rank number of query heads under tensor parallelism.
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = head_dim
        assert self.head_dim == self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # KV heads are sharded across TP ranks.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: each KV head is replicated.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Fused q/k/v projection, column-sharded over heads.
        self.merged_qkv = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.merged_qkv",
        )
        # Output projection back to the hidden size.
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.wo",
        )
        self.scale = self.head_dim**-0.5
        # Non-causal encoder self-attention.
        self.attn = MMEncoderAttention(
            self.num_heads,
            self.head_dim,
            self.scale,
            num_kv_heads=self.num_kv_heads,
            prefix=f"{prefix}.attn",
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Self-attention over the patch sequence."""
        qkv, _ = self.merged_qkv(inputs)
        # Split the fused projection into this rank's q/k/v shards.
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        output = self.attn(xq, xk, xv)
        output, _ = self.wo(output)
        return output
class Molmo2VisionBlock(nn.Module):
    """Residual attention block used in Vision Transformer."""

    def __init__(
        self,
        config: VitConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Self-attention sub-block.
        self.attention = ViTMultiHeadDotProductAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        # MLP sub-block.
        self.feed_forward = ViTMLP(
            dim=config.hidden_size,
            hidden_dim=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        # Pre-norms for each sub-block (pre-LN transformer layout).
        self.attention_norm = nn.LayerNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
        )
        self.ffn_norm = nn.LayerNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pre-norm residual attention, then pre-norm residual MLP."""
        attn_out = self.attention(self.attention_norm(x))
        x = x + attn_out
        mlp_out = self.feed_forward(self.ffn_norm(x))
        return x + mlp_out
class Molmo2VisionBlockCollection(nn.Module):
    """Collection of residual attention blocks used in Vision Transformer."""

    def __init__(
        self,
        config: VitConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        blocks = [
            Molmo2VisionBlock(
                config,
                quant_config,
                prefix=f"{prefix}.resblocks.{layer_idx}",
            )
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.resblocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Run every block in sequence, returning each layer's output."""
        outputs: list[torch.Tensor] = []
        for block in self.resblocks:
            x = block(x)
            outputs.append(x)
        return outputs
class Molmo2VisionTransformer(nn.Module):
    """Vision Transformer used in Vision Backbone."""

    def __init__(
        self,
        config: VitConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        scale = config.hidden_size**-0.5
        self.num_prefix_tokens: int = 0  # no class embeddings
        # Default (rows, cols) patch grid, from the config's image size.
        self.patch_num = config.image_num_patch
        # Learned position embeddings, stored flat as (num_pos, hidden).
        self.positional_embedding = nn.Parameter(
            torch.randn(config.image_num_pos, config.hidden_size) * scale,
        )
        image_patch_size = config.image_patch_size
        # Linear patch embedding over flattened RGB patches.
        self.patch_embedding = nn.Linear(
            image_patch_size * image_patch_size * 3,
            config.hidden_size,
            bias=True,
        )
        self.transformer = Molmo2VisionBlockCollection(
            config,
            quant_config,
            prefix=f"{prefix}.transformer",
        )

    def add_pos_emb(self, x: torch.Tensor, patch_num: tuple[int, int]) -> torch.Tensor:
        """Add position embeddings, resizing them if the patch grid differs
        from the table's native square grid."""
        pos_emb = self.positional_embedding
        # View the flat table as a square (side, side, hidden) grid.
        pos_emb = pos_emb.reshape(
            (
                int(math.sqrt(pos_emb.shape[0])),
                int(math.sqrt(pos_emb.shape[0])),
                pos_emb.shape[1],
            )
        )
        (patch_num_0, patch_num_1) = patch_num
        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
            # Bicubic-resize the embedding grid to the actual patch layout.
            # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
            pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
            pos_emb = F.interpolate(
                pos_emb,
                size=(patch_num_0, patch_num_1),
                mode="bicubic",
                align_corners=False,
                antialias=True,
            )
            pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
        # Flatten back to (num_patches, hidden) and broadcast over the batch.
        pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
        x = x + pos_emb[None, :, :].to(x.dtype)
        return x

    def forward(
        self,
        x: torch.Tensor,
        patch_num: tuple[int, int] | None = None,
    ) -> list[torch.Tensor]:
        """
        : param x: (batch_size, num_patch, n_pixels)

        Returns the hidden states of every transformer layer.
        """
        if patch_num is None:
            patch_num = self.patch_num
        x = self.patch_embedding(x)
        x = self.add_pos_emb(x, patch_num)
        hidden_states = self.transformer(x)
        return hidden_states
class ImagePoolingAttention(nn.Module):
    """Multi-head attention used for image pooling"""

    def __init__(
        self,
        input_dim: int,
        hidden_size: int,
        num_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        use_bias: bool = True,
        use_pytorch_sdpa: bool = False,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tp_size == 0
        # Per-rank number of query heads under tensor parallelism.
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = head_dim
        assert self.head_dim == self.hidden_size // self.total_num_heads
        self.total_num_kv_heads = num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # KV heads are sharded across TP ranks.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: each KV head is replicated.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.kv_size = self.num_kv_heads * self.head_dim
        # Separate query projection (queries come from the pooled mean).
        self.q_proj = ColumnParallelLinear(
            self.input_dim,
            self.total_num_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )
        # Fused key/value projection over the patch features.
        self.merged_kv = MergedColumnParallelLinear(
            self.input_dim,
            [self.total_num_kv_heads * self.head_dim] * 2,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.merged_kv",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.scale = self.head_dim**-0.5
        # SDPA path is used when an explicit attention mask is needed
        # (adapter_config.pooling_attention_mask); see forward_sdpa.
        self.use_pytorch_sdpa = use_pytorch_sdpa
        if use_pytorch_sdpa:
            self.attn = None
        else:
            self.attn = MMEncoderAttention(
                self.num_heads,
                self.head_dim,
                self.scale,
                num_kv_heads=self.num_kv_heads,
                prefix=f"{prefix}.attn",
            )

    def forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Masked, non-causal attention via torch SDPA."""
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)
        # (B, len, H*D) -> (B, H, len, D) as SDPA expects.
        query = query.view(bsz, q_len, self.num_heads, self.head_dim)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
        query, key, value = (x.transpose(1, 2) for x in (query, key, value))
        out = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attn_mask,
            is_causal=False,
            # Let SDPA broadcast KV heads when there are fewer than Q heads.
            enable_gqa=self.num_heads > self.num_kv_heads,
        ).transpose(1, 2)
        return out.reshape(bsz, q_len, -1)

    def forward(
        self,
        inputs_q: torch.Tensor,
        inputs_kv: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Cross-attention: the pooled query attends over patch features."""
        xq, _ = self.q_proj(inputs_q)
        kv, _ = self.merged_kv(inputs_kv)
        xk, xv = kv.split([self.kv_size, self.kv_size], dim=-1)
        if self.use_pytorch_sdpa:
            output = self.forward_sdpa(xq, xk, xv, attn_mask)
        else:
            output = self.attn(xq, xk, xv)
        output, _ = self.o_proj(output)
        return output
class ImageProjectorMLP(nn.Module):
    """MLP used for the image projector"""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Fused gate+up projection (two shards of one column-parallel linear).
        self.merged_linear = MergedColumnParallelLinear(
            input_dim,
            [hidden_dim] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.merged_linear",
        )
        # Only SiLU gating is supported here.
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()
        # Down-projection to the text model's hidden size.
        self.down_proj = RowParallelLinear(
            hidden_dim,
            output_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute down_proj(silu_and_mul(merged_linear(x)))."""
        fused, _ = self.merged_linear(x)
        gated = self.act_fn(fused)
        out, _ = self.down_proj(gated)
        return out
class Molmo2VisionBackbone(nn.Module, SupportsQuant):
    """Vision backbone: ViT encoder + attention pooling + MLP projector.

    Maps raw image crops to embedding vectors in the text model's space.
    """

    # Maps each fused parameter to the checkpoint shard names it absorbs.
    packed_modules_mapping = {
        "merged_qkv": ["wq", "wk", "wv"],  # vision backbone
        "merged_kv": ["k_proj", "v_proj"],  # image_pooling_2d
        "merged_linear": ["gate_proj", "up_proj"],
    }

    def __init__(
        self,
        vit_config: VitConfig,
        adapter_config: AdapterConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.vit_config = vit_config
        self.adapter_config = adapter_config
        # Normalize (possibly negative) layer indices to absolute indices.
        self.vit_layers = []
        for layer in adapter_config.vit_layers:
            if layer >= 0:
                self.vit_layers.append(layer)
            else:
                self.vit_layers.append(layer + vit_config.num_hidden_layers)
        # Truncate the ViT: layers past the last one we read are never used.
        last_layer_needed = max(self.vit_layers) + 1
        if last_layer_needed < vit_config.num_hidden_layers:
            vit_config.num_hidden_layers = last_layer_needed
        self.image_vit = Molmo2VisionTransformer(
            vit_config,
            quant_config,
            prefix=f"{prefix}.image_vit",
        )
        self.num_prefix_tokens: int = self.image_vit.num_prefix_tokens
        # Pooling input is the concatenation of the selected ViT layers.
        pool_dim = vit_config.hidden_size * len(adapter_config.vit_layers)
        self.image_pooling_2d = ImagePoolingAttention(
            input_dim=pool_dim,
            hidden_size=adapter_config.hidden_size,
            num_heads=adapter_config.num_attention_heads,
            num_key_value_heads=adapter_config.num_key_value_heads,
            head_dim=adapter_config.head_dim,
            use_pytorch_sdpa=adapter_config.pooling_attention_mask,
            quant_config=quant_config,
            prefix=f"{prefix}.image_pooling_2d",
        )
        # Projects pooled features into the text model's hidden size.
        self.image_projector = ImageProjectorMLP(
            input_dim=adapter_config.hidden_size,
            hidden_dim=adapter_config.intermediate_size,
            output_dim=adapter_config.text_hidden_size,
            hidden_act=adapter_config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.image_projector",
        )

    @property
    def dtype(self) -> torch.dtype:
        # Mirror the patch embedding's dtype for input casting.
        return self.image_vit.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.image_vit.patch_embedding.weight.device

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """
        : param images: (batch_size, num_crops, num_patch, n_pixels)
        """
        B, T, N, D = images.shape
        # Fold crops into the batch dimension for the ViT.
        images = images.view(B * T, N, D)
        image_features = self.image_vit(images)
        # Concatenate hidden states from the selected ViT layers.
        features = []
        for layer in self.vit_layers:
            features.append(image_features[layer])
        image_features = torch.cat(features, dim=-1)
        if self.num_prefix_tokens > 0:
            # Drop prefix (class) tokens before pooling.
            image_features = image_features[:, 1:]
        image_features = image_features.view(B, T, N, -1)
        return image_features

    def forward(
        self,
        images: torch.Tensor,
        token_pooling: torch.Tensor,
    ) -> torch.Tensor:
        """Encode crops and pool patch features into per-token embeddings.

        token_pooling gathers, for each pooled patch, the flat indices of
        the ViT patches it aggregates; negative entries mark padding.
        """
        # image_features shape:
        # (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
        batch_size, num_image = images.shape[:2]
        images = images.to(device=self.device, dtype=self.dtype)
        image_features = self.encode_image(images)
        dim = image_features.shape[-1]
        # Negative pooling indices are padding slots.
        valid = token_pooling >= 0
        valid_token = torch.any(valid, -1)
        # Use `token_pooling` to arange the features for image pooling
        batch_idx = torch.arange(
            token_pooling.shape[0],
            dtype=torch.long,
            device=token_pooling.device,
        )
        batch_idx = torch.tile(
            batch_idx.view(batch_size, 1, 1),
            [1, token_pooling.shape[1], token_pooling.shape[2]],
        )
        # Now [batch, num_features, num_pooled_patches, dim]
        to_pool = image_features.reshape(batch_size, -1, dim)[
            batch_idx, torch.clip(token_pooling, 0)
        ]
        # Zero out padded entries so they cannot contribute to pooling.
        to_pool = to_pool * valid.to(self.dtype)[:, :, :, None]
        to_pool = to_pool.reshape([-1, token_pooling.shape[-1], dim])
        if self.adapter_config.pooling_attention_mask:
            # Mask padded positions and average the query over valid ones
            # only (guarding against all-padding windows with denom >= 1).
            attn_mask = valid.reshape([-1, 1, 1, valid.shape[-1]])
            denom = valid.view(-1, to_pool.shape[-2]).float().sum(-1)
            denom = torch.where(denom == 0, 1, denom)
            query = to_pool.sum(-2, keepdim=True) / denom[:, None, None].to(
                to_pool.dtype
            )
        else:
            attn_mask = None
            query = to_pool.mean(-2, keepdim=True)
        pooled_features = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask)
        pooled_features = pooled_features.reshape(
            [batch_size, -1, pooled_features.shape[-1]]
        )
        # MLP layer to map the feature.
        pooled_features = self.image_projector(pooled_features)
        # Keep only pooled slots that had at least one valid patch.
        return pooled_features.view(-1, pooled_features.shape[-1])[
            valid_token.flatten()
        ]

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v, k/v and gate/up shards."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("merged_qkv", "wq", "q"),
            ("merged_qkv", "wk", "k"),
            ("merged_qkv", "wv", "v"),
            ("merged_kv", "k_proj", 0),
            ("merged_kv", "v_proj", 1),
            ("merged_linear", "gate_proj", 0),
            ("merged_linear", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not part of a fused parameter: load directly.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class Molmo2Attention(nn.Module):
    """Molmo2's LLM Attention."""

    def __init__(
        self,
        config: TextConfig,
        rope_parameters: dict[str, Any],
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % self.tp_size == 0
        # Per-rank number of query heads.
        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= self.tp_size:
            # KV heads are sharded across TP ranks.
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            # Fewer KV heads than ranks: each KV head is replicated.
            assert self.tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = config.head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
        )
        self.tp_rank: int | None = None
        self.k_norm: nn.Module | None = None
        self.q_norm: nn.Module | None = None
        self.qk_norm_type: str | None = None
        if config.use_qk_norm:
            # "qwen3" norms each head independently (size = head_dim);
            # "olmo" norms the full unsharded projection, which requires an
            # all-gather under TP (see _apply_qk_norm).
            k_norm_size = (
                self.head_dim
                if config.qk_norm_type == "qwen3"
                else self.total_num_kv_heads * self.head_dim
            )
            self.tp_rank = get_tensor_model_parallel_rank()
            self.k_norm = RMSNorm(k_norm_size, eps=config.layer_norm_eps)
            q_norm_size = (
                self.head_dim
                if config.qk_norm_type == "qwen3"
                else self.total_num_heads * self.head_dim
            )
            self.q_norm = RMSNorm(q_norm_size, eps=config.layer_norm_eps)
            self.qk_norm_type = config.qk_norm_type
        # Rotary embeddings. Rope scaling is only applied on full attention layers.
        layer_idx = extract_layer_index(prefix)
        if (
            config.rope_scaling_layers is not None
            and layer_idx not in config.rope_scaling_layers
        ):
            # Layer excluded from scaling: plain RoPE with the same theta.
            rope_theta = rope_parameters["rope_theta"]
            rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=rope_parameters,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
        )

    def _apply_qk_norm(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply "olmo"-style RMSNorm spanning all heads.

        Under TP the norm covers the full projection, so q/k are
        all-gathered, normalized, then re-split to this rank's shard.
        """
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm(q)
        k = self.k_norm(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs: object,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if (
            self.q_norm is not None
            and self.k_norm is not None
            and self.qk_norm_type == "olmo"
        ):
            q, k = self._apply_qk_norm(q, k)
        elif self.q_norm is not None and self.k_norm is not None:
            # "qwen3"-style: normalize each head independently.
            q_by_head = q.view(
                *q.shape[:-1],
                q.shape[-1] // self.head_dim,
                self.head_dim,
            )
            q_by_head = self.q_norm(q_by_head)
            q = q_by_head.view(q.shape)
            k_by_head = k.view(
                *k.shape[:-1],
                k.shape[-1] // self.head_dim,
                self.head_dim,
            )
            k_by_head = self.k_norm(k_by_head)
            k = k_by_head.view(k.shape)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
class LanguageModelMLP(nn.Module):
    """Molmo2's LLM mlp."""

    def __init__(
        self,
        input_dim: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        super().__init__()
        # Fused up+gate projection (two shards of one column-parallel linear).
        self.up_gate_proj = MergedColumnParallelLinear(
            input_dim,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        # Only SiLU gating is supported here.
        # NOTE(review): MulAndSilu (not SiluAndMul) suggests the merged
        # layout is (up, gate) — confirm against the checkpoint mapping.
        assert hidden_act == "silu"
        self.act_fn = MulAndSilu()
        # Down-projection back to the model dimension.
        self.down_proj = RowParallelLinear(
            intermediate_size,
            input_dim,
            bias=False,
            quant_config=quant_config,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Compute down_proj(act_fn(up_gate_proj(x)))."""
        fused, _ = self.up_gate_proj(x)
        gated = self.act_fn(fused)
        out, _ = self.down_proj(gated)
        return out
class Molmo2DecoderLayer(nn.Module):
    """Pre-norm decoder layer: self-attention then MLP, using fused
    residual-add RMSNorms between sub-blocks."""

    def __init__(
        self,
        config: TextConfig,
        rope_parameters: dict[str, Any],
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Attention block.
        self.self_attn = Molmo2Attention(
            config,
            rope_parameters,
            cache_config,
            quant_config,
            prefix=f"{prefix}.self_attn",
        )
        # MLP block.
        self.mlp = LanguageModelMLP(
            config.hidden_size,
            config.intermediate_size,
            config.hidden_act,
            quant_config,
        )
        # Pre-attention and pre-MLP RMSNorms.
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs: object,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
        if residual is None:
            # First layer: seed the residual stream, then normalize.
            residual, hidden_states = (
                hidden_states,
                self.input_layernorm(hidden_states),
            )
        else:
            # Fused residual-add + norm of the previous layer's outputs.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            **kwargs,
        )
        # Fused residual-add + norm, then the MLP.
        hidden_states, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(hidden_states), residual
class Molmo2DecoderNormAfterLayer(Molmo2DecoderLayer):
    """Decoder layer variant that normalizes each sub-block's *output*
    before the residual add (norm-after), instead of pre-norm."""

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs: object,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
        # Attention sub-block: x + norm(attn(x)).
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            **kwargs,
        )
        hidden_states = hidden_states + self.input_layernorm(attn_out)
        # MLP sub-block: x + norm(mlp(x)).
        mlp_out = self.mlp(hidden_states)
        hidden_states = hidden_states + self.post_attention_layernorm(mlp_out)
        # Residual is already folded in, so downstream gets None.
        return hidden_states, None
@support_torch_compile
class Molmo2TextModel(nn.Module, SupportsQuant):
    """Molmo2's decoder-only language model: embeddings + decoder stack."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config
        # The HF config nests the text config under either key depending on
        # the checkpoint.
        if hasattr(config, "text_config"):
            hf_text_config = config.text_config
        else:
            hf_text_config = config.llm_config
        # Copy only the fields TextConfig declares from the HF config.
        kwargs = {}
        for field in fields(TextConfig):
            kwargs[field.name] = getattr(hf_text_config, field.name)
        text_config = TextConfig(**kwargs)
        # Embedding table covers the base vocab plus any additional tokens.
        self.embedding_size = text_config.vocab_size
        self.embedding_size += text_config.additional_vocab_size or 0
        self.embed_tokens = VocabParallelEmbedding(
            self.embedding_size,
            text_config.hidden_size,
            quant_config=quant_config,
        )
        # norm_after checkpoints use the post-norm layer variant.
        decoder_layer = (
            Molmo2DecoderNormAfterLayer
            if text_config.norm_after
            else Molmo2DecoderLayer
        )
        # make_layers handles pipeline-parallel partitioning of the stack.
        self.start_layer, self.end_layer, self.layers = make_layers(
            text_config.num_hidden_layers,
            lambda prefix: decoder_layer(
                text_config,
                hf_text_config.rope_parameters,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(text_config.hidden_size, eps=text_config.layer_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"],
            text_config.hidden_size,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for the given ids."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            # First PP rank embeds tokens — unless embeddings were already
            # computed (e.g. with merged multimodal features).
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            # Later PP ranks resume from the previous rank's activations.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        # Apply blocks one-by-one.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
                **kwargs,
            )
        if not get_pp_group().is_last_rank:
            # Hand activations off to the next pipeline stage.
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        # Final norm; residual is None for norm-after layers, which have
        # already folded the residual in.
        if residual is not None:
            hidden_states, _ = self.norm(hidden_states, residual)
        else:
            hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights by parameter name; returns the names actually loaded."""
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Skip parameters owned by other pipeline stages.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
def get_patches_grid_size(
    *,
    image_h: int,
    image_w: int,
    patch_size: int,
    pool_h: int,
    pool_w: int,
) -> tuple[int, int]:
    """Return the (rows, cols) of the pooled patch grid for an image.

    The image is divided into ``patch_size`` x ``patch_size`` patches
    (any remainder is truncated), then the patch grid is padded up to a
    multiple of the pooling window so every pooled cell is complete.

    The round-up is plain ceil-division: equivalent to the previous
    ``round_down(patch + pool - 1, pool) // pool`` formulation, without
    the helper indirection.
    """
    patch_h = image_h // patch_size
    patch_w = image_w // patch_size
    # Ceil-division: number of pooling windows needed to cover the grid.
    nrows = (patch_h + pool_h - 1) // pool_h
    ncols = (patch_w + pool_w - 1) // pool_w
    return nrows, ncols
def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]:
    """Enumerate all (rows, cols) tilings using at most ``max_num`` tiles,
    ordered by total tile count, then by row count."""
    tilings: list[tuple[int, int]] = []
    for rows in range(1, max_num + 1):
        for cols in range(1, max_num + 1):
            if rows * cols <= max_num:
                tilings.append((rows, cols))
    tilings.sort(key=lambda t: (t[0] * t[1], t[0]))
    return tilings
def select_tiling(
    *,
    height: int,
    width: int,
    patch_size: int,
    max_num_patches: int,
):
    """Pick the candidate tiling that best fits the given image size.

    Prefers the tiling needing the smallest upscale on its most-constrained
    axis; if every tiling would require downscaling, picks the one that
    downscales the least.
    """
    # Enumerate candidate tilings (at most max_num_patches tiles), ordered
    # by tile count and then row count.
    tilings = sorted(
        (
            (rows, cols)
            for rows in range(1, max_num_patches + 1)
            for cols in range(1, max_num_patches + 1)
            if rows * cols <= max_num_patches
        ),
        key=lambda t: (t[0] * t[1], t[0]),
    )
    candidate_tilings = np.array(tilings, dtype=np.int32)
    candidate_resolutions = candidate_tilings * patch_size
    original_size = np.array([height, width], dtype=np.float32)
    # Scale each tiling would need on its most-constrained axis.
    required_scale = (
        candidate_resolutions.astype(np.float32) / original_size
    ).min(axis=-1, keepdims=True)
    if (required_scale < 1).all():
        # Every tiling downscales: take the gentlest downscale.
        ix = required_scale.argmax()
    else:
        # Exclude downscaling tilings and take the smallest upscale.
        ix = np.where(required_scale < 1.0, 10e9, required_scale).argmin()
    return candidate_tilings[ix]
def get_image_size(image: ImageInput) -> ImageSize:
    """Return the ImageSize (width, height) of a PIL image or HWC array."""
    if isinstance(image, Image):
        # PIL's ``size`` attribute is already (width, height).
        width, height = image.size
        return ImageSize(width, height)
    if isinstance(image, (np.ndarray, torch.Tensor)):
        assert image.ndim == 3
        height, width, channels = image.shape
        assert channels in [1, 3]
        return ImageSize(width, height)
    raise ValueError(f"Unknown image type: {type(image)}")
def exif_transpose(
    images: ImageInput | None,
) -> ImageInput | None:
    """Apply EXIF orientation to PIL images, recursing into lists/tuples.

    Non-PIL items (e.g. arrays) are returned unchanged. ``None`` passes
    through as ``None``.

    Note: the original re-checked ``images is not None`` in each branch
    even though the early return already guarantees it; those dead
    conditions are removed here.
    """
    if images is None:
        return None
    if isinstance(images, (list, tuple)):
        return [
            exif_transpose(img) if isinstance(img, Image) else img for img in images
        ]
    if isinstance(images, Image):
        return ImageOps.exif_transpose(images)
    return images
def build_flat_image_bool_length(
    image_grids: torch.LongTensor,
    hf_config: PretrainedConfig,
) -> tuple[torch.LongTensor, torch.LongTensor]:
    """Expand per-image grids into a flat sequence of image token ids.

    Each row of ``image_grids`` is (resized_h, resized_w, h, w): the
    low-res (global) grid followed by the high-res grid. Per image the
    layout is::

        [low_res_start] [patch]*rh*rw [end] [start] ([patch]*w [col])*h [end]

    Returns the concatenated ids and per-image segment lengths.
    """
    patch_id = hf_config.image_patch_id
    low_res_start_id = hf_config.low_res_image_start_token_id
    start_id = hf_config.image_start_token_id
    col_id = hf_config.image_col_id
    end_id = hf_config.image_end_token_id
    device = image_grids.device
    # low-res patches + per-row (patches + col token) + 4 special tokens
    lengths = (
        image_grids[:, 0] * image_grids[:, 1]
        + image_grids[:, 2] * (image_grids[:, 3] + 1)
        + 4
    )
    segments: list[torch.Tensor] = []
    for resized_h, resized_w, h, w in image_grids.tolist():
        ids = [low_res_start_id]
        ids += [patch_id] * (resized_h * resized_w)
        ids += [end_id, start_id]
        # One column-separator token terminates each high-res row.
        ids += ([patch_id] * w + [col_id]) * h
        ids.append(end_id)
        segments.append(torch.tensor(ids, dtype=torch.long, device=device))
    if segments:
        flat = torch.cat(segments)
    else:
        flat = torch.empty(0, dtype=torch.long, device=device)
    return flat, lengths
def build_flat_video_bool_length(
    video_grids: torch.LongTensor,
    hf_config: PretrainedConfig,
) -> tuple[torch.LongTensor, torch.LongTensor]:
    """Expand per-video grids into a flat sequence of video token ids.

    Each row of ``video_grids`` is (t, resized_h, resized_w); every frame
    contributes ``[frame_start] [patch]*rh*rw [frame_end]``.

    Returns the concatenated ids and per-video segment lengths.
    """
    patch_id = hf_config.image_patch_id
    start_id = hf_config.frame_start_token_id
    end_id = hf_config.frame_end_token_id
    device = video_grids.device
    frames = video_grids[:, 0]
    patches_per_frame = video_grids[:, 1] * video_grids[:, 2]
    # Each frame: pooled patches plus the start/end special tokens.
    lengths = frames * (patches_per_frame + 2)
    segments: list[torch.Tensor] = []
    for t, h, w in video_grids.tolist():
        frame_ids = [start_id] + [patch_id] * (h * w) + [end_id]
        segments.append(torch.tensor(frame_ids * t, dtype=torch.long, device=device))
    if segments:
        flat = torch.cat(segments)
    else:
        flat = torch.empty(0, dtype=torch.long, device=device)
    return flat, lengths
def get_candidate_target_fps(
    video_fps: int | float,
    sampling_fps: int | float,
    max_fps: int | float = _MAX_VIDEO_FPS,
) -> list[float]:
    """
    Return the subset of `video_fps` factors that remain multiples
    of `sampling_fps`.

    Examples:
    >>> get_candidate_target_fps(video_fps=6, sampling_fps=2)
    [2.0, 6.0]
    >>> get_candidate_target_fps(video_fps=5, sampling_fps=1)
    [1.0, 5.0]
    >>> get_candidate_target_fps(video_fps=2, sampling_fps=2)
    [2.0]
    >>> get_candidate_target_fps(video_fps=5, sampling_fps=2)
    Traceback (most recent call last):
    ...
    ValueError: sampling_fps=2 must divide video_fps=5.
    """
    # Validate before converting: previously the None check came after
    # int(sampling_fps), so a None value raised TypeError instead of the
    # intended ValueError (the check was dead code).
    if sampling_fps is None:
        raise ValueError("sampling_fps must be provided")
    video_fps = int(video_fps)
    sampling_fps = int(sampling_fps)
    max_fps = int(max_fps)
    if video_fps <= 0 or sampling_fps <= 0:
        raise ValueError(
            "video_fps and sampling_fps must be positive "
            f"(got {video_fps}, {sampling_fps})"
        )
    if video_fps % sampling_fps != 0:
        raise ValueError(
            f"sampling_fps={sampling_fps} must divide video_fps={video_fps}."
        )
    # Multiples of sampling_fps that also divide video_fps, capped at max_fps.
    candidates = []
    for candidate in range(sampling_fps, video_fps + 1, sampling_fps):
        if candidate > max_fps:
            break
        if video_fps % candidate == 0:
            candidates.append(float(candidate))
    return candidates
def get_target_fps(
video_fps: float,
max_frames: int,
total_frames: int,
frame_sample_mode: str,
candidate_target_fps: list[float],
) -> float | None:
"""
Get the target fps that best spans the video and has the most frames sampled
"""
num_frames_sampled = 0
selected_target_fps = None
for target_fps in candidate_target_fps:
step_size = max(int(video_fps / target_fps), 1)
num_frames_sampled_at_fps = int(total_frames / step_size)
if num_frames_sampled == 0:
if (
"uniform" in frame_sample_mode
and num_frames_sampled_at_fps > max_frames
):
break
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
else:
# the candidate sampling fps increases so frame count can't decrease
assert num_frames_sampled <= num_frames_sampled_at_fps
if num_frames_sampled_at_fps > max_frames:
# choose the sampling fps that spans the video
continue
elif num_frames_sampled_at_fps > num_frames_sampled:
# both are less than max_frames; choose the one with higher
# density of frames sampled
selected_target_fps = target_fps
num_frames_sampled = num_frames_sampled_at_fps
return selected_target_fps
def get_frame_times_and_chosen_fps(
    selected_target_fps, total_frames, max_frames, video_fps
):
    """Return (chosen_fps, frame_indices) for the selected sampling fps.

    With no selected fps, falls back to uniform sampling of ``max_frames``
    frames; otherwise samples every ``video_fps / fps`` frames, truncated
    to ``max_frames``.
    """
    if selected_target_fps is None:
        indices = np.linspace(0, total_frames, max_frames, endpoint=False, dtype=int)
        return selected_target_fps, indices
    step = max(int(video_fps / selected_target_fps), 1)
    indices = np.arange(0, total_frames, step)[:max_frames]
    return selected_target_fps, indices
class Molmo2ProcessingInfo(BaseProcessingInfo):
    """Token-count and grid-geometry helpers for the Molmo2 processor.

    Sizes are expressed in pooled vision patches; the ``*_use_col_tokens``
    processor options add one column-separator token per grid row.
    """

    def get_data_parser(self):
        # Videos must arrive with metadata (fps, frame indices) so that
        # frame timestamps can be recovered in ``_get_video_second_idx``.
        return MultiModalDataParser(
            video_needs_metadata=True,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Any number of images per prompt, but at most one video.
        return {"image": None, "video": 1}

    def select_tiling(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: BaseImageProcessor,
    ) -> tuple[int, int]:
        """Choose the (tiling_w, tiling_h) crop grid for a high-res image.

        Overlap margins shared between adjacent crops are subtracted
        before delegating to the module-level ``select_tiling``, which
        returns (rows, cols) — note the swapped return order here.
        """
        max_crops = image_processor.max_crops
        left_margin, right_margin = image_processor.overlap_margins
        base_image_input_d = image_processor.patch_size
        total_margin_pixels = base_image_input_d * (right_margin + left_margin)
        # Usable (non-margin) window of each square crop, in pixels.
        crop_patches = image_processor.size["height"] // base_image_input_d
        crop_window_patches = crop_patches - (right_margin + left_margin)
        crop_window_size = crop_window_patches * base_image_input_d
        tiling_h, tiling_w = select_tiling(
            height=image_height - total_margin_pixels,
            width=image_width - total_margin_pixels,
            patch_size=crop_window_size,
            max_num_patches=max_crops,
        )
        return tiling_w, tiling_h

    def get_base_grid_size(
        self,
        image_processor: BaseImageProcessor | BaseVideoProcessor,
    ) -> tuple[int, int]:
        """Pooled grid size (ncols, nrows) of the low-res/global view."""
        nrows, ncols = get_patches_grid_size(
            image_h=image_processor.size["height"],
            image_w=image_processor.size["width"],
            patch_size=image_processor.patch_size,
            pool_h=image_processor.pooling_size[0],
            pool_w=image_processor.pooling_size[1],
        )
        return ncols, nrows

    def get_patches_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
        image_processor: BaseImageProcessor,
    ) -> tuple[int, int]:
        """Pooled grid size (ncols, nrows) of the tiled high-res view."""
        left_margin, right_margin = image_processor.overlap_margins
        base_image_input_d = image_processor.patch_size
        total_margin_pixels = base_image_input_d * (right_margin + left_margin)
        crop_patches = image_processor.size["height"] // base_image_input_d
        crop_window_patches = crop_patches - (right_margin + left_margin)
        crop_window_size = crop_window_patches * base_image_input_d
        tiling_w, tiling_h = self.select_tiling(
            image_height=image_height,
            image_width=image_width,
            image_processor=image_processor,
        )
        # Effective stitched resolution: tiled crop windows plus the
        # margin pixels reinstated once for the whole mosaic.
        nrows, ncols = get_patches_grid_size(
            image_h=tiling_h * crop_window_size + total_margin_pixels,
            image_w=tiling_w * crop_window_size + total_margin_pixels,
            patch_size=base_image_input_d,
            pool_h=image_processor.pooling_size[0],
            pool_w=image_processor.pooling_size[1],
        )
        return ncols, nrows

    def get_num_image_tokens(
        self,
        *,
        image_height: int,
        image_width: int,
        processor: ProcessorMixin,
    ) -> int:
        """Total prompt tokens used by one image (low-res + high-res)."""
        image_processor = processor.image_processor
        resize_ncols, resize_nrows = self.get_base_grid_size(image_processor)
        # start/end tokens + image patch token + col tokens
        if processor.use_single_crop_col_tokens is not None:
            use_col_tokens = processor.use_single_crop_col_tokens
        else:
            use_col_tokens = processor.image_use_col_tokens
        extra = 2 + resize_nrows * (resize_ncols + int(use_col_tokens))
        overlap_ncols, overlap_nrows = self.get_patches_grid_size(
            image_height=image_height,
            image_width=image_width,
            image_processor=image_processor,
        )
        joint = 2 + overlap_nrows * (
            overlap_ncols + int(processor.image_use_col_tokens)
        )
        return extra + joint

    def get_num_video_tokens(
        self,
        *,
        num_frames: int,
        processor: ProcessorMixin,
    ) -> int:
        """Image tokens used by a video: per-frame grid plus start/end.

        Note: timestamp text tokens added per frame in the prompt are not
        counted here.
        """
        video_processor = processor.video_processor
        resize_ncols, resize_nrows = self.get_base_grid_size(video_processor)
        # start/end tokens
        extra = 2 + resize_nrows * (resize_ncols + int(processor.video_use_col_tokens))
        return num_frames * extra

    def get_image_size_with_most_features(self) -> ImageSize:
        """Image size (over all candidate tilings) maximizing token count."""
        processor = self.get_hf_processor()
        image_processor = processor.image_processor
        left_margin, right_margin = image_processor.overlap_margins
        base_image_input_d = image_processor.patch_size
        total_margin_pixels = base_image_input_d * (right_margin + left_margin)
        crop_patches = image_processor.size["height"] // base_image_input_d
        crop_window_patches = crop_patches - (right_margin + left_margin)
        crop_window_size = crop_window_patches * base_image_input_d
        tilings = get_candidate_tilings(image_processor.max_crops)
        largest_feature_size, largest_feature_pinpoint = 0, None
        for hr, wr in tilings:
            # Exact pixel size that maps onto this tiling with no resize.
            height = hr * crop_window_size + total_margin_pixels
            width = wr * crop_window_size + total_margin_pixels
            feat_size = self.get_num_image_tokens(
                image_height=height,
                image_width=width,
                processor=processor,
            )
            if feat_size > largest_feature_size:
                largest_feature_size = feat_size
                largest_feature_pinpoint = ImageSize(width=width, height=height)
        if largest_feature_size == 0 or largest_feature_pinpoint is None:
            raise ValueError("Cannot have a largest feature size of 0!")
        return largest_feature_pinpoint

    def _get_max_video_frames(
        self,
        max_tokens: int,
        processor: ProcessorMixin,
    ) -> int:
        """Max frames whose tokens fit into ``max_tokens`` (at least 1)."""
        num_tokens_per_frame = self.get_num_video_tokens(
            num_frames=1,
            processor=processor,
        )
        max_frames = max_tokens // num_tokens_per_frame
        return max(max_frames, 1)

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Frames per video for the worst-case (dummy) input."""
        processor = self.get_hf_processor()
        video_processor = processor.video_processor
        num_frames = video_processor.num_frames
        max_videos = mm_counts.get("video", 0)
        max_total_frames = self._get_max_video_frames(seq_len, processor)
        # Budget is split across videos and capped by the processor's
        # configured frame count.
        max_frames_per_video = min(
            max_total_frames // max(max_videos, 1),
            num_frames,
        )
        return max(max_frames_per_video, 1)

    def _sample_frames(
        self,
        total_num_frames: int,
        video_fps: float,
        duration: float,
        frame_sample_mode: str,
        num_frames: int,
        max_fps: int,
        sampling_fps: int,
    ) -> np.ndarray:
        """Reproduce the HF video processor's frame-index sampling.

        Supported modes: "uniform_last_frame" (with/without a max_fps
        cap) and "fps".
        """
        if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
            if total_num_frames <= 2:
                indices = np.arange(total_num_frames).astype(int)
            elif duration > (num_frames - 1) / max_fps:  # -1 to include the last frame
                # uniform fallback
                indices = np.linspace(
                    0,
                    total_num_frames - 1,
                    num=min(num_frames, total_num_frames),
                    endpoint=True,
                ).astype(int)
            else:
                # Sample at max_fps, forcing the final frame into the set.
                float_indices = np.arange(
                    0.0,
                    stop=total_num_frames - 1,
                    step=float(video_fps / max_fps),
                )
                if np.round(float_indices[-1]) != total_num_frames - 1:
                    float_indices = np.concatenate(
                        [float_indices, [total_num_frames - 1]], axis=0
                    )
                indices = np.round(float_indices).astype(int)
                assert indices[-1] < total_num_frames
                assert len(float_indices) <= num_frames
        elif frame_sample_mode == "uniform_last_frame":
            indices = np.linspace(
                0,
                total_num_frames - 1,
                num=min(num_frames, total_num_frames),
                endpoint=True,
            ).astype(int)
        elif frame_sample_mode == "fps":
            candidate_target_fps = get_candidate_target_fps(video_fps, sampling_fps)
            selected_target_fps = get_target_fps(
                video_fps,
                num_frames,
                total_num_frames,
                frame_sample_mode,
                candidate_target_fps,
            )
            _, indices = get_frame_times_and_chosen_fps(
                selected_target_fps,
                total_num_frames,
                num_frames,
                video_fps,
            )
        else:
            raise NotImplementedError(frame_sample_mode)
        return indices

    def _get_video_second_idx(
        self,
        metadata: dict[str, Any],
        do_sample_frames: bool | None = None,
    ) -> list[float]:
        """Timestamps (seconds) of the frames the processor will keep."""
        processor = self.get_hf_processor()
        video_processor = processor.video_processor
        # metadata["fps"] refers to the true fps of the input video.
        video_fps = metadata["fps"]
        frames_indices = metadata.get("frames_indices")
        if do_sample_frames is None:
            do_sample_frames = metadata.get("do_sample_frames", False)
        if do_sample_frames:
            # Frame-based sampling is applied in HF video processor
            total_num_frames = metadata["total_num_frames"]
            duration = total_num_frames / video_fps
            frame_sample_mode = video_processor.frame_sample_mode
            num_frames = video_processor.num_frames
            max_fps = video_processor.max_fps
            sampling_fps = video_processor.sampling_fps
            frames_indices = self._sample_frames(
                total_num_frames,
                video_fps,
                duration,
                frame_sample_mode,
                num_frames,
                max_fps,
                sampling_fps,
            )
        else:
            # Time-based sampling is done in vllm molmo2 video loader or molmo_utils
            assert frames_indices is not None
        timestamps = [frame_idx / video_fps for frame_idx in frames_indices]
        return timestamps
class Molmo2DummyInputsBuilder(BaseDummyInputsBuilder[Molmo2ProcessingInfo]):
    """Builds worst-case dummy prompts and media for profiling Molmo2."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Prompt text containing one placeholder per requested item."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)
        image_placeholder_token = IMAGE_PROMPT
        video_placeholder_token = VIDEO_PROMPT
        if num_images == 1:
            image_string = image_placeholder_token
        else:
            # Multiple images get "Image N" prefixes, matching the
            # multi-image prompt convention.
            image_string = "".join(
                [f"Image {i + 1}" + image_placeholder_token for i in range(num_images)]
            )
        return image_string + video_placeholder_token * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Worst-case images/videos (maximal token footprint)."""
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)
        dummy_images = []
        dummy_videos = []
        if num_images > 0:
            # Largest-token-count image size over all candidate tilings.
            target_width, target_height = self.info.get_image_size_with_most_features()
            image_overrides = mm_options.get("image")
            dummy_images = self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        if num_videos > 0:
            processor = self.info.get_hf_processor()
            video_size = processor.video_processor.size
            target_num_frames = self.info.get_num_frames_with_most_features(
                seq_len, mm_counts
            )
            video_overrides = mm_options.get("video")
            if video_overrides:
                assert isinstance(video_overrides, VideoDummyOptions)
                num_frames_override = video_overrides.num_frames
                if num_frames_override:
                    # Overrides may only shrink the frame count; invalid
                    # values are warned about and then clamped below.
                    if num_frames_override > target_num_frames:
                        logger.warning(
                            "video.num_frames override (%d) exceeds model's "
                            "maximum number of frames (%d), will be ignored",
                            num_frames_override,
                            target_num_frames,
                        )
                    if num_frames_override < 2:
                        logger.warning(
                            "video.num_frames override (%d) cannot be less "
                            "than 2, will be ignored",
                            num_frames_override,
                        )
                    target_num_frames = min(target_num_frames, num_frames_override)
            dummy_videos = self._get_dummy_videos(
                width=video_size["width"],
                height=video_size["height"],
                num_frames=target_num_frames,
                num_videos=num_videos,
            )
        return {
            "image": dummy_images,
            "video": dummy_videos,
        }

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[VideoItem]:
        """White dummy videos paired with synthetic pre-sampled metadata."""
        video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
        video_items = []
        for i in range(num_videos):
            # do_sample_frames=False marks the frames as pre-sampled, so
            # the processor uses frames_indices for timestamps directly.
            video_metadata = {
                "fps": 2.0,
                "duration": num_frames / 2.0,
                "total_num_frames": num_frames,
                "frames_indices": list(range(num_frames)),
                "video_backend": "decord",
                "do_sample_frames": False,
                "height": height,
                "width": width,
            }
            video_item = (video.copy(), video_metadata)
            video_items.append(video_item)
        return video_items
class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
    """Runs the HF Molmo2 processor and reshapes its outputs into vLLM's
    flat per-item tensor layout (concatenated tensors + per-item lengths).
    """

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Mirror the HF processor's BOS insertion for token-only prompts."""
        processor = self.info.get_hf_processor()
        tokenizer = processor.tokenizer
        bos_token_id = tokenizer.bos_token_id or tokenizer.eos_token_id
        if len(prompt_tokens) == 0 or prompt_tokens[0] != bos_token_id:
            # Prepend the bos token to the prompt tokens
            prompt_tokens = [bos_token_id] + prompt_tokens
        return prompt_tokens

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Invoke the HF processor.

        Videos are processed one at a time (each replaces one
        VIDEO_PROMPT occurrence in the prompt with its decoded token
        string); images are processed together with the final text pass.
        Per-item grid tensors are flattened with explicit length tensors
        so they can later be split back per item.
        """
        mm_data = dict(mm_data)
        hf_config = self.info.get_hf_config()
        hf_processor = self.info.get_hf_processor(**mm_kwargs)

        def patched_call(text=None, images=None, videos=None, **kwargs) -> BatchFeature:
            res = hf_processor(text=text, images=images, videos=videos, **kwargs)
            # Molmo2Processor.insert_bos results in float outputs
            # if the input text is empty
            if not text:
                res["input_ids"] = res["input_ids"].long()
            return res

        tokenizer = hf_processor.tokenizer
        image_processor = hf_processor.image_processor
        if videos := mm_data.pop("videos", []):
            bos_token_id = tokenizer.bos_token_id or tokenizer.eos_token_id
            pixel_values_videos_lst = []
            video_token_pooling_lst = []
            video_num_crops_lst = []
            video_num_pooled_patches_lst = []
            video_num_patches_lst = []
            video_tokens_lst = []
            num_video_tokens_lst = []
            for item in videos:
                video_array, metadata = item
                # NOTE: metadata.frames_indices indicates
                # the sampled frames indices of pre-sampled videos, which is
                # used to calculate the timestamps. Make sure that
                # do_sample_frames in mm_kwargs is false for presampled videos.
                # NOTE: a copy of mm_kwargs is created to update do_sample_frames,
                # otherwise mm_hash for the object will be incorrect.
                video_mm_kwargs = dict(**mm_kwargs)
                if "do_sample_frames" not in video_mm_kwargs:
                    # molmo_utils already has "do_sample_frames" in
                    # mm_kwargs, don't overwrite it.
                    video_mm_kwargs["do_sample_frames"] = metadata.get(
                        "do_sample_frames", False
                    )
                metadata = VideoMetadata(
                    **{k: metadata[k] for k in metadata if k != "do_sample_frames"}
                )
                video_mm_data = dict()
                video_mm_data["videos"] = [[video_array]]
                video_mm_data["video_metadata"] = [[metadata]]
                video_outputs = self.info.ctx.call_hf_processor(
                    patched_call,
                    dict(text=VIDEO_PROMPT, **video_mm_data),
                    dict(**video_mm_kwargs, **tok_kwargs),
                )
                input_ids = video_outputs.pop("input_ids")
                # Drop the BOS the HF processor added before decoding the
                # placeholder expansion back to a string.
                if input_ids[0, 0] == bos_token_id:
                    input_ids = input_ids[:, 1:]
                video_string = tokenizer.batch_decode(input_ids)[0]
                prompt = prompt.replace(VIDEO_PROMPT, video_string, 1)
                # video_grids rows are (t, resized_h, resized_w).
                video_grids = video_outputs.pop("video_grids")
                assert video_grids[:, 0].sum() == len(
                    video_outputs["pixel_values_videos"]
                )
                video_outputs["video_num_crops"] = video_grids[:, 0]
                video_outputs["video_num_pooled_patches"] = video_grids.prod(dim=1)
                n_patches = video_outputs["pixel_values_videos"].shape[1]
                video_outputs["video_num_patches"] = (
                    video_outputs["video_num_crops"] * n_patches
                )
                (video_outputs["video_tokens"], video_outputs["num_video_tokens"]) = (
                    build_flat_video_bool_length(video_grids, hf_config)
                )
                pixel_values_videos_lst.append(video_outputs["pixel_values_videos"])
                video_token_pooling_lst.append(video_outputs["video_token_pooling"])
                video_num_crops_lst.append(video_outputs["video_num_crops"])
                video_num_pooled_patches_lst.append(
                    video_outputs["video_num_pooled_patches"]
                )
                video_num_patches_lst.append(video_outputs["video_num_patches"])
                video_tokens_lst.append(video_outputs["video_tokens"])
                num_video_tokens_lst.append(video_outputs["num_video_tokens"])
            all_video_outputs = dict(
                pixel_values_videos=torch.cat(pixel_values_videos_lst),
                video_token_pooling=torch.cat(video_token_pooling_lst),
                video_num_crops=torch.cat(video_num_crops_lst),
                video_num_pooled_patches=torch.cat(video_num_pooled_patches_lst),
                video_num_patches=torch.cat(video_num_patches_lst),
                video_tokens=torch.cat(video_tokens_lst),
                num_video_tokens=torch.cat(num_video_tokens_lst),
            )
        else:
            all_video_outputs = dict()
        # Text (with expanded video strings) and images in a single pass.
        processed_outputs = self.info.ctx.call_hf_processor(
            patched_call,
            dict(text=prompt, **mm_data),
            dict(**mm_kwargs, **tok_kwargs),
        )
        if (images := mm_data.get("images")) is not None:
            mm_items = self.info.parse_mm_data({"image": images}, validate=False)
            parsed_images = mm_items.get_items("image", ImageProcessorItems)
            image_sizes = [
                parsed_images.get_image_size(i) for i in range(len(parsed_images))
            ]
            # For each image: tiling_h * tiling_w + global view
            tilings = [
                self.info.select_tiling(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    image_processor=image_processor,
                )
                for image_size in image_sizes
            ]
            num_crops = torch.tensor(tilings).prod(-1) + 1
            assert sum(num_crops) == len(processed_outputs["pixel_values"])
            assert sum(num_crops) == processed_outputs["image_num_crops"].sum().item()
            # image_grids rows are (resized_h, resized_w, h, w): low-res
            # grid followed by high-res grid.
            image_grids = processed_outputs.pop("image_grids")
            image_num_pooled_patches = image_grids[:, :2].prod(dim=1) + image_grids[
                :, 2:
            ].prod(dim=1)
            processed_outputs["image_num_pooled_patches"] = image_num_pooled_patches
            n_patches = processed_outputs["pixel_values"].shape[1]
            processed_outputs["image_num_patches"] = (
                processed_outputs["image_num_crops"] * n_patches
            )
            (
                processed_outputs["image_tokens"],
                processed_outputs["num_image_tokens"],
            ) = build_flat_image_bool_length(image_grids, hf_config)
        return BatchFeature({**processed_outputs, **all_video_outputs})

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how flattened tensors split back into per-item fields."""
        image_num_crops = hf_inputs.get("image_num_crops", torch.empty(0))
        image_num_pooled_patches = hf_inputs.get(
            "image_num_pooled_patches", torch.empty(0)
        )
        video_num_crops = hf_inputs.get("video_num_crops", torch.empty(0))
        video_num_pooled_patches = hf_inputs.get(
            "video_num_pooled_patches", torch.empty(0)
        )
        num_image_tokens = hf_inputs.get("num_image_tokens", torch.empty(0))
        num_video_tokens = hf_inputs.get("num_video_tokens", torch.empty(0))
        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_num_crops
            ),
            image_token_pooling=MultiModalFieldConfig.flat_from_sizes(
                "image", image_num_pooled_patches
            ),
            image_num_crops=MultiModalFieldConfig.batched("image"),
            image_num_pooled_patches=MultiModalFieldConfig.batched("image"),
            image_num_patches=MultiModalFieldConfig.batched("image"),
            image_tokens=MultiModalFieldConfig.flat_from_sizes(
                "image", num_image_tokens
            ),
            num_image_tokens=MultiModalFieldConfig.batched("image"),
            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
                "video", video_num_crops
            ),
            video_token_pooling=MultiModalFieldConfig.flat_from_sizes(
                "video", video_num_pooled_patches
            ),
            video_num_crops=MultiModalFieldConfig.batched("video"),
            video_num_pooled_patches=MultiModalFieldConfig.batched("video"),
            video_num_patches=MultiModalFieldConfig.batched("video"),
            video_tokens=MultiModalFieldConfig.flat_from_sizes(
                "video", num_video_tokens
            ),
            num_video_tokens=MultiModalFieldConfig.batched("video"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replacements that expand IMAGE_PROMPT / VIDEO_PROMPT into the
        full per-item token sequences."""
        hf_config = self.info.get_hf_config()
        img_patch_id = hf_config.image_patch_id
        img_col_id = hf_config.image_col_id
        img_start_id = hf_config.image_start_token_id
        img_end_id = hf_config.image_end_token_id
        low_res_im_start_id = hf_config.low_res_image_start_token_id
        frame_start_id = hf_config.frame_start_token_id
        frame_end_id = hf_config.frame_end_token_id
        im_low_res_id = hf_config.image_low_res_id
        # Token ids whose positions receive multimodal embeddings.
        emb_tok_ids = [
            img_patch_id,
            img_col_id,
            img_start_id,
            low_res_im_start_id,
            frame_start_id,
            img_end_id,
            frame_end_id,
            im_low_res_id,
        ]
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_use_col_tokens = processor.image_use_col_tokens
        use_single_crop_col_tokens = processor.use_single_crop_col_tokens
        use_single_crop_start_token = processor.use_single_crop_start_token
        video_use_col_tokens = processor.video_use_col_tokens
        use_frame_special_tokens = processor.use_frame_special_tokens
        tokenizer = processor.tokenizer
        vocab = tokenizer.get_vocab()
        image_processor = processor.image_processor
        video_processor = processor.video_processor

        def get_image_replacement_molmo2(item_idx: int):
            # Low-res (global) section followed by the tiled high-res
            # section, mirroring get_num_image_tokens.
            images = mm_items.get_items("image", ImageProcessorItems)
            image = images.get(item_idx)
            image = exif_transpose(image)
            resize_ncols, resize_nrows = self.info.get_base_grid_size(image_processor)
            if use_single_crop_col_tokens is not None:
                use_col_tokens = use_single_crop_col_tokens
            else:
                use_col_tokens = image_use_col_tokens
            if use_single_crop_start_token:
                start_id = low_res_im_start_id
            else:
                start_id = img_start_id
            extra_row = [img_patch_id] * resize_ncols + [img_col_id] * int(
                use_col_tokens
            )
            extra_joint = [start_id] + extra_row * resize_nrows + [img_end_id]
            image_size = get_image_size(image)
            ncols, nrows = self.info.get_patches_grid_size(
                image_height=image_size.height,
                image_width=image_size.width,
                image_processor=image_processor,
            )
            joint_row = [img_patch_id] * ncols + [img_col_id] * int(
                image_use_col_tokens
            )
            joint = [img_start_id] + joint_row * nrows + [img_end_id]
            img_token_ids = extra_joint + joint
            return PromptUpdateDetails.select_token_ids(img_token_ids, emb_tok_ids)

        def get_video_replacement_molmo2(item_idx: int):
            # One timestamp-text prefix plus one patch grid per frame.
            video, metadata = mm_items["video"][item_idx]
            do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames")
            timestamps = self.info._get_video_second_idx(metadata, do_sample_frames)
            ncols, nrows = self.info.get_base_grid_size(video_processor)
            if use_frame_special_tokens:
                start_id = frame_start_id
                end_id = frame_end_id
            else:
                start_id = img_start_id
                end_id = img_end_id
            img_token_ids = []
            for frame_idx, frame_time in enumerate(timestamps):
                prev_space = " " if frame_idx > 0 else ""
                frame_prefix = (
                    prev_space + f"{frame_time:.1f} "
                )  # explicit whitespace before/after image tokens
                img_token_ids += tokenizer.encode(
                    frame_prefix,
                    add_special_tokens=False,
                )
                joint_row = [img_patch_id] * ncols + [img_col_id] * int(
                    video_use_col_tokens
                )
                joint = [start_id] + nrows * joint_row + [end_id]
                img_token_ids += joint
            return PromptUpdateDetails.select_token_ids(img_token_ids, emb_tok_ids)

        return [
            PromptReplacement(
                modality=modality,
                target=[target],
                replacement=replacement_fn,
            )
            for modality, target, replacement_fn in zip(
                ["image", "video"],
                [vocab[IMAGE_PROMPT], vocab[VIDEO_PROMPT]],
                [get_image_replacement_molmo2, get_video_replacement_molmo2],
            )
        ]
@MULTIMODAL_REGISTRY.register_processor(
    Molmo2MultiModalProcessor,
    info=Molmo2ProcessingInfo,
    dummy_inputs=Molmo2DummyInputsBuilder,
)
class Molmo2ForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant
):
    """Molmo2 vision-language model: ViT vision backbone + adapter feeding
    image/video features into a text decoder."""

    # Renames HF checkpoint parameter names to this module's names:
    # substring replacements for attention/MLP projections, prefix
    # replacements for the backbone containers.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            # vision backbone mapping
            "image_pooling_2d.wq": "image_pooling_2d.q_proj",
            "image_pooling_2d.wk": "image_pooling_2d.k_proj",
            "image_pooling_2d.wv": "image_pooling_2d.v_proj",
            "image_pooling_2d.wo": "image_pooling_2d.o_proj",
            "image_projector.w1": "image_projector.gate_proj",
            "image_projector.w3": "image_projector.up_proj",
            "image_projector.w2": "image_projector.down_proj",
            # language backbone mapping
            "att_proj": "qkv_proj",
            "attn_out": "o_proj",
            "q_norm": "q_norm",
            "k_norm": "k_norm",
            "ff_proj": "up_gate_proj",
            "ff_out": "down_proj",
            "attn_norm": "input_layernorm",
            "ff_norm": "post_attention_layernorm",
        },
        orig_to_new_prefix={
            # vision backbone mapping
            "model.vision_backbone.": "vision_backbone.",
            # language backbone mapping
            "model.transformer.blocks.": "model.layers.",
            "model.transformer.ln_f.": "model.norm.",
        },
    )
    # Fused-parameter groups (e.g. for quantization/LoRA): which original
    # projections are packed into each merged linear layer.
    packed_modules_mapping = {
        "qkv_proj": ["qkv_proj"],
        "up_gate_proj": ["up_gate_proj"],  # language model
        "merged_qkv": ["wq", "wk", "wv"],  # vision backbone
        "merged_kv": ["k_proj", "v_proj"],  # image_pooling_2d
        "merged_linear": ["gate_proj", "up_proj"],  # image_projector
    }
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return IMAGE_PROMPT
if modality.startswith("video"):
return VIDEO_PROMPT
raise ValueError("Only image or video modality is supported")
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision backbone, language model, LM head, and logits
        processor from the vLLM config."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        # Copy the HF sub-config fields into the local dataclass configs.
        kwargs = {}
        for field in fields(VitConfig):
            kwargs[field.name] = getattr(config.vit_config, field.name)
        vit_config = VitConfig(**kwargs)
        kwargs = {}
        for field in fields(AdapterConfig):
            kwargs[field.name] = getattr(config.adapter_config, field.name)
        adapter_config = AdapterConfig(**kwargs)
        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.vision_backbone = Molmo2VisionBackbone(
                vit_config,
                adapter_config,
                quant_config,
                prefix=maybe_prefix(prefix, "vision_backbone"),
            )
        with self._mark_language_model(vllm_config):
            self.model = Molmo2TextModel(
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, "model"),
            )
        self.img_patch_id = config.image_patch_id
        # Newer configs expose ``text_config``; older ones ``llm_config``.
        if hasattr(config, "text_config"):
            hf_text_config = config.text_config
        else:
            hf_text_config = config.llm_config
        self.lm_head = ParallelLMHead(
            hf_text_config.vocab_size,
            hf_text_config.hidden_size,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(hf_text_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )
    @property
    def dtype(self):
        # Parameter dtype of the model (taken from the first parameter).
        return next(self.parameters()).dtype
    def _parse_and_validate_image_input(
        self,
        **kwargs: object,
    ) -> Molmo2ImageInputs | None:
        """Collect image tensors from kwargs into a Molmo2ImageInputs.

        Returns None when the batch contains no image data. The
        per-image pooling indices are rebased so they index into the
        batch-flattened patch dimension.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None
        token_pooling = kwargs.pop("image_token_pooling", None)
        num_pooled_patches = kwargs.pop("image_num_pooled_patches", None)
        num_patches = kwargs.pop("image_num_patches", None)
        image_tokens = kwargs.pop("image_tokens", None)
        num_image_tokens = kwargs.pop("num_image_tokens", None)
        # Per-image start offsets into the flattened patch dimension.
        accum_patches = [0] + num_patches.cumsum(dim=0)[:-1].tolist()
        patch_offset = 0
        new_token_pooling = token_pooling.clone()
        for i, n in enumerate(num_pooled_patches):
            cur_slice = token_pooling[patch_offset : patch_offset + n]
            index_offset = int(accum_patches[i])
            # Shift non-negative pooling indices by this image's patch
            # offset; negative entries (padding markers, presumably) are
            # left untouched.
            new_token_pooling[patch_offset : patch_offset + n] = torch.where(
                cur_slice >= 0,
                cur_slice + index_offset,
                cur_slice,
            )
            patch_offset += n
        return Molmo2ImageInputs(
            pixel_values=pixel_values,
            token_pooling=new_token_pooling,
            num_pooled_patches=num_pooled_patches,
            image_tokens=image_tokens,
            num_image_tokens=num_image_tokens,
        )
    def _parse_and_validate_video_input(
        self,
        **kwargs: object,
    ) -> Molmo2VideoInputs | None:
        """Collect video tensors from kwargs into a Molmo2VideoInputs.

        Returns None when the batch contains no video data. Mirrors
        ``_parse_and_validate_image_input``: pooling indices are rebased
        into the batch-flattened patch dimension.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        if pixel_values_videos is None:
            return None
        token_pooling = kwargs.pop("video_token_pooling", None)
        num_pooled_patches = kwargs.pop("video_num_pooled_patches", None)
        num_patches = kwargs.pop("video_num_patches", None)
        video_tokens = kwargs.pop("video_tokens", None)
        num_video_tokens = kwargs.pop("num_video_tokens", None)
        # Per-video start offsets into the flattened patch dimension.
        accum_patches = [0] + num_patches.cumsum(dim=0)[:-1].tolist()
        patch_offset = 0
        new_token_pooling = token_pooling.clone()
        for i, n in enumerate(num_pooled_patches):
            cur_slice = token_pooling[patch_offset : patch_offset + n]
            index_offset = int(accum_patches[i])
            # Shift non-negative pooling indices by this video's patch
            # offset; negative entries are left untouched.
            new_token_pooling[patch_offset : patch_offset + n] = torch.where(
                cur_slice >= 0,
                cur_slice + index_offset,
                cur_slice,
            )
            patch_offset += n
        return Molmo2VideoInputs(
            pixel_values_videos=pixel_values_videos,
            token_pooling=new_token_pooling,
            num_pooled_patches=num_pooled_patches,
            video_tokens=video_tokens,
            num_video_tokens=num_video_tokens,
        )
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
    """Dispatch raw keyword inputs to the per-modality parsers.

    Returns a dict keyed by modality name ("images"/"videos"), populated
    in the order the triggering keys appear in ``kwargs``; each modality
    is parsed at most once.
    """
    modalities: dict = {}
    for key in kwargs:
        if key == "pixel_values" and "images" not in modalities:
            modalities["images"] = self._parse_and_validate_image_input(**kwargs)
        elif key == "pixel_values_videos" and "videos" not in modalities:
            modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
    return modalities
def _process_image_input(
    self,
    image_input: Molmo2ImageInputs,
) -> tuple[torch.Tensor, ...]:
    """Encode image patches and splice them into per-image token embeddings.

    Runs the vision backbone once over all patches, splits the flat
    feature rows back into per-image chunks, then for every image embeds
    its token ids and overwrites the positions whose id equals
    ``self.img_patch_id`` with the visual features.

    Returns one embedding tensor per image.
    """
    num_pooled_patches = image_input["num_pooled_patches"]

    features_flat = self.vision_backbone(
        images=image_input["pixel_values"].unsqueeze(0),
        token_pooling=image_input["token_pooling"].unsqueeze(0),
    )
    # The backbone must emit exactly one feature row per pooled patch.
    assert len(features_flat) == num_pooled_patches.sum()

    per_image_features = features_flat.split(num_pooled_patches.tolist(), dim=0)
    per_image_tokens = image_input["image_tokens"].split(
        image_input["num_image_tokens"].tolist(), dim=0
    )

    embed = self.get_language_model().embed_input_ids
    results = []
    for feats, tokens in zip(per_image_features, per_image_tokens):
        token_embeds = embed(tokens)
        token_embeds[tokens == self.img_patch_id] = feats
        results.append(token_embeds)
    return tuple(results)
def _process_video_input(
    self,
    video_input: Molmo2VideoInputs,
) -> tuple[torch.Tensor, ...]:
    """Encode video patches and splice them into per-video token embeddings.

    Mirrors the image path: the vision backbone runs once over all
    patches, the flat feature rows are split per video, and positions in
    each video's token sequence whose id equals ``self.img_patch_id`` are
    replaced by the visual features.

    Returns one embedding tensor per video.
    """
    num_pooled_patches = video_input["num_pooled_patches"]

    features_flat = self.vision_backbone(
        images=video_input["pixel_values_videos"].unsqueeze(0),
        token_pooling=video_input["token_pooling"].unsqueeze(0),
    )
    # The backbone must emit exactly one feature row per pooled patch.
    assert len(features_flat) == num_pooled_patches.sum()

    per_video_features = features_flat.split(num_pooled_patches.tolist(), dim=0)
    per_video_tokens = video_input["video_tokens"].split(
        video_input["num_video_tokens"].tolist(), dim=0
    )

    embed = self.get_language_model().embed_input_ids
    results = []
    for feats, tokens in zip(per_video_features, per_video_tokens):
        token_embeds = embed(tokens)
        token_embeds[tokens == self.img_patch_id] = feats
        results.append(token_embeds)
    return tuple(results)
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
    """Parse multimodal kwargs and return one embedding tensor per item.

    Returns an empty list when no multimodal inputs are present;
    otherwise a tuple of per-image / per-video embedding tensors, in the
    order the modalities were parsed.
    """
    parsed = self._parse_and_validate_multimodal_inputs(**kwargs)
    if not parsed:
        return []

    embeddings: tuple[torch.Tensor, ...] = ()
    for modality, mm_input in parsed.items():
        if modality == "images":
            embeddings += self._process_image_input(mm_input)
        elif modality == "videos":
            embeddings += self._process_video_input(mm_input)
    return embeddings
def embed_input_ids(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: MultiModalEmbeddings | None = None,
    *,
    is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
    """Embed token ids, merging multimodal embeddings into flagged slots.

    Raises:
        ValueError: if multimodal embeddings are supplied without the
            ``is_multimodal`` mask identifying their positions.
    """
    inputs_embeds = self._embed_text_input_ids(
        input_ids,
        self.get_language_model().embed_input_ids,
        is_multimodal=is_multimodal,
    )

    # Text-only batch: nothing to merge. (Use an explicit length check —
    # the embeddings container may be a tensor, whose truthiness is
    # ambiguous.)
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return inputs_embeds

    if is_multimodal is None:
        raise ValueError(
            "`embed_input_ids` now requires `is_multimodal` arg, "
            "please update your model runner according to "
            "https://github.com/vllm-project/vllm/pull/16229."
        )

    return _merge_multimodal_embeddings(
        inputs_embeds=inputs_embeds,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_multimodal,
    )
def forward(
    self,
    input_ids: torch.LongTensor,
    positions: torch.LongTensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor:
    """Run the language model and return its hidden states.

    When ``intermediate_tensors`` is provided (non-first pipeline stage —
    NOTE(review): inferred from vLLM convention), any precomputed
    ``inputs_embeds`` are discarded since the stage consumes the previous
    stage's activations instead.
    """
    if intermediate_tensors is not None:
        inputs_embeds = None
    return self.model(
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds=inputs_embeds,
        **kwargs,
    )
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Project hidden states to vocabulary logits via the LM head."""
    return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load checkpoint weights, first merging the split embedding tables."""
    merged_weights = _get_weights_with_merged_embedding(weights)
    loader = AutoWeightsLoader(self)
    return loader.load_weights(merged_weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
    """Return the module prefixes for the language model, the multimodal
    connector (image projector), and the vision tower."""
    return MultiModelKeys.from_string_field(
        language_model="model",
        connector="vision_backbone.image_projector",
        tower_model="vision_backbone",
    )
def _get_weights_with_merged_embedding(
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
embedding_weights = {}
for name, weight in weights:
if "wte.embedding" in name:
embedding_weights["embedding"] = weight
elif "wte.new_embedding" in name:
embedding_weights["new_embedding"] = weight
else:
yield (name, weight)
# this is compatible with most of quantization,
# because they won't quantize embed_tokens
if "embedding" not in embedding_weights or "new_embedding" not in embedding_weights:
raise ValueError(
"Checkpoint is missing 'wte.embedding' or "
"'wte.new_embedding' weights required for Molmo2."
)
embedding_weights = torch.cat(
[embedding_weights["embedding"], embedding_weights["new_embedding"]],
dim=0,
)
yield ("model.embed_tokens.weight", embedding_weights)