2809 lines
94 KiB
Python
2809 lines
94 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import math
|
|
from collections.abc import Iterable, Mapping, Sequence
|
|
from dataclasses import dataclass, fields
|
|
from functools import cached_property, partial
|
|
from itertools import islice
|
|
from typing import Annotated, Any
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from PIL import ImageOps
|
|
from PIL.Image import Image
|
|
from transformers import (
|
|
BatchFeature,
|
|
PretrainedConfig,
|
|
ProcessorMixin,
|
|
TensorType,
|
|
)
|
|
from transformers.image_utils import ImageInput
|
|
from transformers.tokenization_utils_base import TextInput
|
|
from transformers.video_utils import VideoInput, VideoMetadata
|
|
|
|
from vllm.attention.layer import Attention
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
|
|
from vllm.distributed import (
|
|
get_pp_group,
|
|
get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
split_tensor_along_last_dim,
|
|
tensor_model_parallel_all_gather,
|
|
)
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.activation import MulAndSilu, SiluAndMul, get_act_fn
|
|
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (
|
|
ColumnParallelLinear,
|
|
MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead,
|
|
VocabParallelEmbedding,
|
|
)
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
|
from vllm.multimodal.inputs import (
|
|
MultiModalDataDict,
|
|
MultiModalFieldConfig,
|
|
MultiModalKwargsItems,
|
|
VideoItem,
|
|
)
|
|
from vllm.multimodal.parse import (
|
|
ImageProcessorItems,
|
|
ImageSize,
|
|
MultiModalDataItems,
|
|
MultiModalDataParser,
|
|
)
|
|
from vllm.multimodal.processing import (
|
|
BaseMultiModalProcessor,
|
|
BaseProcessingInfo,
|
|
PromptReplacement,
|
|
PromptUpdate,
|
|
PromptUpdateDetails,
|
|
)
|
|
from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.utils.math_utils import round_down
|
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|
|
|
from .interfaces import (
|
|
MultiModalEmbeddings,
|
|
SupportsLoRA,
|
|
SupportsMultiModal,
|
|
SupportsPP,
|
|
SupportsQuant,
|
|
)
|
|
from .utils import (
|
|
AutoWeightsLoader,
|
|
WeightsMapper,
|
|
_merge_multimodal_embeddings,
|
|
extract_layer_index,
|
|
is_pp_missing_parameter,
|
|
make_empty_intermediate_tensors_factory,
|
|
make_layers,
|
|
maybe_prefix,
|
|
)
|
|
|
|
logger = init_logger(__name__)


# Placeholder strings for multimodal inputs. Every tokenizer used with this
# model is expected to contain them, because the preprocessor relies on them.
IMAGE_PROMPT = "<|image|>"
VIDEO_PROMPT = "<|video|>"
_MAX_VIDEO_FPS = 8
|
|
|
|
|
|
class Molmo2ImageInputs(TensorSchema):
    """
    Tensor schema for the image inputs of Molmo2.

    Dimensions:
        - nc: The total number of crops (dynamic)
        - np: The total number of patches per crop
        - cps: Number of channels * patch_size * patch_size
        - npp: Number of pooled patches (dynamic)
        - pp: pooling_size * pooling_size
        - ni: Number of images
        - nt: Number of image tokens (dynamic)
    """

    pixel_values: Annotated[torch.Tensor, TensorShape("nc", "np", "cps")]

    # Index tensor mapping image features to their corresponding patch
    # tokens before pooling.
    token_pooling: Annotated[torch.Tensor, TensorShape("npp", "pp")]

    num_pooled_patches: Annotated[torch.Tensor, TensorShape("ni")]

    image_tokens: Annotated[torch.BoolTensor, TensorShape("nt")]

    num_image_tokens: Annotated[torch.Tensor, TensorShape("ni")]
|
|
|
|
|
|
class Molmo2VideoInputs(TensorSchema):
    """
    Tensor schema for the video inputs of Molmo2.

    Dimensions:
        - nc: The total number of frames (dynamic)
        - np: The total number of patches per frame
        - cps: Number of channels * patch_size * patch_size
        - npp: Number of pooled patches (dynamic)
        - pp: pooling_size * pooling_size
        - nv: Number of videos
        - nt: Number of video tokens (dynamic)
    """

    pixel_values_videos: Annotated[torch.Tensor, TensorShape("nc", "np", "cps")]

    # Index tensor mapping image features to their corresponding patch
    # tokens before pooling.
    token_pooling: Annotated[torch.Tensor, TensorShape("npp", "pp")]

    num_pooled_patches: Annotated[torch.Tensor, TensorShape("nv")]

    video_tokens: Annotated[torch.BoolTensor, TensorShape("nt")]

    num_video_tokens: Annotated[torch.Tensor, TensorShape("nv")]
|
|
|
|
|
|
@dataclass
class VitConfig:
    """Config for a vision transformer.

    `image_default_input_size` is stored as a tuple; any two-element
    sequence (for example a list) passed to the constructor is converted
    in `__post_init__`.
    """

    hidden_size: int = 1152
    intermediate_size: int = 4304
    num_hidden_layers: int = 27
    num_attention_heads: int = 16
    num_key_value_heads: int = 16
    head_dim: int = 72
    hidden_act: str = "gelu_pytorch_tanh"
    layer_norm_eps: float = 1e-6
    image_default_input_size: tuple[int, int] = (378, 378)
    image_patch_size: int = 14
    image_num_pos: int = 577

    def __post_init__(self) -> None:
        # Normalize list-like values to a tuple so the field always matches
        # its annotation.
        self.image_default_input_size = tuple(self.image_default_input_size)  # type: ignore[assignment]

    @property
    def image_num_patch(self) -> tuple[int, int]:
        """Return the (rows, columns) patch grid for the default input size."""
        height, width = self.image_default_input_size
        return height // self.image_patch_size, width // self.image_patch_size
|
|
|
|
|
|
@dataclass
class AdapterConfig:
    """Config for a vit-llm adapter.

    `vit_layers` is stored as a tuple; a list passed to the constructor is
    converted in `__post_init__`.
    """

    vit_layers: tuple[int, int] = (-3, -9)
    pooling_attention_mask: bool = False
    hidden_size: int = 1152
    num_attention_heads: int = 16
    num_key_value_heads: int = 16
    head_dim: int = 72
    hidden_act: str = "silu"
    intermediate_size: int = 18944
    text_hidden_size: int = 3584

    def __post_init__(self) -> None:
        # Normalize list-like values to a tuple so the field always matches
        # its annotation.
        self.vit_layers = tuple(self.vit_layers)  # type: ignore[assignment]
|
|
|
|
|
|
@dataclass
|
|
class TextConfig:
|
|
"""Configuration for a text model transformer"""
|
|
|
|
hidden_size: int = 3584
|
|
"""
|
|
The hidden size of the model.
|
|
"""
|
|
|
|
num_attention_heads: int = 28
|
|
"""
|
|
The number of self-attention heads.
|
|
"""
|
|
|
|
num_key_value_heads: int = 4
|
|
"""
|
|
The number of heads to use for keys and values.
|
|
"""
|
|
|
|
head_dim: int = 128
|
|
"""
|
|
The head dimensionality for the attention mechanism.
|
|
"""
|
|
|
|
vocab_size: int = 152064
|
|
"""Vocabulary size of the model."""
|
|
|
|
additional_vocab_size: int = 128
|
|
"""Number of additional tokens to have the input embeddings for"""
|
|
|
|
qkv_bias: bool = True
|
|
"""
|
|
Do QKV projection a bias
|
|
"""
|
|
|
|
num_hidden_layers: int = 48
|
|
"""
|
|
The number of layers/blocks.
|
|
"""
|
|
|
|
intermediate_size: int = 18944
|
|
"""
|
|
The hidden size for the MLP.
|
|
"""
|
|
|
|
hidden_act: str = "silu"
|
|
"""
|
|
The activation function to use within the MLP layers.
|
|
"""
|
|
|
|
max_position_embeddings: int = 4096
|
|
"""
|
|
Max positional embeddings to use in RoPE cache
|
|
"""
|
|
|
|
rope_theta: float = 1000000.0
|
|
"""
|
|
RoPE theta parameter.
|
|
"""
|
|
|
|
use_qk_norm: bool = False
|
|
"""
|
|
Apply layer norm to the keys and queries within the attention mechanism.
|
|
This can help stabilize training.
|
|
"""
|
|
|
|
qk_norm_type: str = "olmo"
|
|
"""
|
|
The type of layer norm to use for the keys and queries.
|
|
Can be "olmo" or "qwen3".
|
|
"""
|
|
|
|
layer_norm_eps: float = 1e-6
|
|
"""
|
|
epsilon for layer norms
|
|
"""
|
|
|
|
norm_after: bool = False
|
|
"""Apply layer norm before and after the attention and MLP blocks."""
|
|
|
|
rope_scaling_layers: tuple[int, ...] | None = None
|
|
"""
|
|
RoPE scaling layers.
|
|
"""
|
|
|
|
|
|
class ViTMLP(nn.Module):
    """Two-layer feed-forward network used in the Vision Transformer.

    The input is projected to `hidden_dim` by `w1`, passed through the
    activation named by `hidden_act`, and projected back to `dim` by `w2`.
    Both projections carry a bias.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Up projection.
        self.w1 = ColumnParallelLinear(
            dim,
            hidden_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.w1",
        )
        # Non-linearity looked up by name.
        self.act = get_act_fn(hidden_act)
        # Down projection.
        self.w2 = RowParallelLinear(
            hidden_dim,
            dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply `w1`, the activation and `w2` in sequence."""
        # The parallel linear layers return a pair; only the first element
        # (the output tensor) is used here.
        hidden, _ = self.w1(x)
        activated = self.act(hidden)
        projected, _ = self.w2(activated)
        return projected
|
|
|
|
|
|
class ViTMultiHeadDotProductAttention(nn.Module):
    """Multi-head attention layer of the Vision Transformer.

    Query, key and value come from one fused projection (`merged_qkv`), and
    the attention output is projected back to `hidden_size` by `wo`. Head
    counts are divided by the tensor-parallel world size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        use_bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads must divide the hidden size and be evenly shardable.
        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tp_size == 0

        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = head_dim

        assert self.head_dim == self.hidden_size // self.total_num_heads

        # KV heads are either split across ranks or, when there are fewer
        # KV heads than ranks, the rank count must be a multiple of them.
        self.total_num_kv_heads = num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank widths used to split the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.merged_qkv = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.merged_qkv",
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.wo",
        )
        self.scale = self.head_dim**-0.5
        self.attn = MMEncoderAttention(
            self.num_heads,
            self.head_dim,
            self.scale,
            num_kv_heads=self.num_kv_heads,
            prefix=f"{prefix}.attn",
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run self-attention over `inputs` and apply the output projection."""
        fused, _ = self.merged_qkv(inputs)
        # Slice the fused projection into query, key and value parts.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(split_sizes, dim=-1)

        attended = self.attn(query, key, value)

        projected, _ = self.wo(attended)
        return projected
|
|
|
|
|
|
class Molmo2VisionBlock(nn.Module):
    """One pre-norm residual block of the Vision Transformer.

    Each sub-layer (attention, then feed-forward) is applied to a
    layer-normalized copy of its input and the result is added back.
    """

    def __init__(
        self,
        config: VitConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_eps

        self.attention = ViTMultiHeadDotProductAttention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.feed_forward = ViTMLP(
            dim=width,
            hidden_dim=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        self.attention_norm = nn.LayerNorm(width, eps=norm_eps)
        self.ffn_norm = nn.LayerNorm(width, eps=norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers with residuals."""
        attn_out = self.attention(self.attention_norm(x))
        x = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(x))
        x = x + ffn_out
        return x
|
|
|
|
|
|
class Molmo2VisionBlockCollection(nn.Module):
    """Stack of `config.num_hidden_layers` Vision Transformer blocks."""

    def __init__(
        self,
        config: VitConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        blocks = []
        for layer_idx in range(config.num_hidden_layers):
            block = Molmo2VisionBlock(
                config,
                quant_config,
                prefix=f"{prefix}.resblocks.{layer_idx}",
            )
            blocks.append(block)
        self.resblocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Run all blocks in order and return the output of every block.

        Element `i` of the returned list is the hidden state after block
        `i`, so callers can select features from intermediate layers.
        """
        per_layer_outputs: list[torch.Tensor] = []
        current = x
        for block in self.resblocks:
            current = block(current)
            per_layer_outputs.append(current)
        return per_layer_outputs
|
|
|
|
|
|
class Molmo2VisionTransformer(nn.Module):
    """Vision Transformer used in the vision backbone.

    Flattened patches are linearly embedded, a learned positional embedding
    is added, and the result is passed through the stack of residual blocks.
    """

    def __init__(
        self,
        config: VitConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        scale = config.hidden_size**-0.5
        self.num_prefix_tokens: int = 0  # no class embeddings
        # Default (rows, columns) patch grid, used when `forward` gets none.
        self.patch_num = config.image_num_patch
        self.positional_embedding = nn.Parameter(
            torch.randn(config.image_num_pos, config.hidden_size) * scale,
        )
        image_patch_size = config.image_patch_size
        # Each patch arrives flattened as patch_size * patch_size * 3 values.
        self.patch_embedding = nn.Linear(
            image_patch_size * image_patch_size * 3,
            config.hidden_size,
            bias=True,
        )
        self.transformer = Molmo2VisionBlockCollection(
            config,
            quant_config,
            prefix=f"{prefix}.transformer",
        )

    def add_pos_emb(
        self,
        x: torch.Tensor,
        patch_num: tuple[int, int],
    ) -> torch.Tensor:
        """Add the positional embedding to `x`.

        The stored embedding is viewed as a square grid. If that grid does
        not match `patch_num`, it is resized with bicubic interpolation
        before being added.

        :param x: patch embeddings; the positional embedding is broadcast
            over the leading dimension.
        :param patch_num: target (rows, columns) patch grid.
        """
        pos_emb = self.positional_embedding

        # View the flat table as a (side, side, hidden) grid.
        side = int(math.sqrt(pos_emb.shape[0]))
        pos_emb = pos_emb.reshape((side, side, pos_emb.shape[1]))

        (patch_num_0, patch_num_1) = patch_num

        if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
            # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
            # Move the hidden dimension to the channel position for
            # interpolation, then move it back.
            pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
            pos_emb = F.interpolate(
                pos_emb,
                size=(patch_num_0, patch_num_1),
                mode="bicubic",
                align_corners=False,
                antialias=True,
            )
            pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)

        pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
        x = x + pos_emb[None, :, :].to(x.dtype)
        return x

    def forward(
        self,
        x: torch.Tensor,
        patch_num: tuple[int, int] | None = None,
    ) -> list[torch.Tensor]:
        """Encode patches and return the hidden state after every block.

        :param x: (batch_size, num_patch, n_pixels)
        :param patch_num: (rows, columns) patch grid; defaults to the grid
            derived from the config.
        """
        if patch_num is None:
            patch_num = self.patch_num

        x = self.patch_embedding(x)

        x = self.add_pos_emb(x, patch_num)

        hidden_states = self.transformer(x)
        return hidden_states
|
|
|
|
|
|
class ImagePoolingAttention(nn.Module):
    """Multi-head cross-attention used for image pooling.

    Queries and keys/values are projected from separate inputs of width
    `input_dim`. Attention runs through PyTorch SDPA when
    `use_pytorch_sdpa` is set (the only path that honors `attn_mask`),
    otherwise through `MMEncoderAttention`.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_size: int,
        num_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        use_bias: bool = True,
        use_pytorch_sdpa: bool = False,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.input_dim = input_dim
        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tp_size == 0

        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = head_dim

        assert self.head_dim == self.hidden_size // self.total_num_heads

        self.total_num_kv_heads = num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Per-rank width of the key (and value) part of `merged_kv`.
        self.kv_size = self.num_kv_heads * self.head_dim

        self.q_proj = ColumnParallelLinear(
            self.input_dim,
            self.total_num_heads * self.head_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )
        self.merged_kv = MergedColumnParallelLinear(
            self.input_dim,
            [self.total_num_kv_heads * self.head_dim] * 2,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.merged_kv",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.scale = self.head_dim**-0.5
        self.use_pytorch_sdpa = use_pytorch_sdpa
        if use_pytorch_sdpa:
            self.attn = None
        else:
            # Pass the layer prefix, as the ViT attention layer in this file
            # does, so the backend sees a properly named layer.
            self.attn = MMEncoderAttention(
                self.num_heads,
                self.head_dim,
                self.scale,
                num_kv_heads=self.num_kv_heads,
                prefix=f"{prefix}.attn",
            )

    def forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute attention with `F.scaled_dot_product_attention`.

        `query`, `key` and `value` are 3-D (batch, length, heads * head_dim)
        tensors. The result has shape (batch, q_len, heads * head_dim).
        """
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        # Expose the head dimension: (batch, length, heads, head_dim).
        query = query.view(bsz, q_len, self.num_heads, self.head_dim)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_dim)

        if self.num_heads != self.num_kv_heads:
            # Repeat KV heads so each query head has a matching KV head.
            repeats = self.num_heads // self.num_kv_heads
            key = torch.repeat_interleave(key, repeats, dim=2)
            value = torch.repeat_interleave(value, repeats, dim=2)

        # SDPA expects (batch, heads, length, head_dim).
        query, key, value = (x.transpose(1, 2) for x in (query, key, value))

        out = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attn_mask,
            is_causal=False,
        ).transpose(1, 2)

        return out.reshape(bsz, q_len, -1)

    def forward(
        self,
        inputs_q: torch.Tensor,
        inputs_kv: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Attend from `inputs_q` to `inputs_kv` and project the result.

        `attn_mask` is only used on the PyTorch SDPA path; the
        `MMEncoderAttention` path does not receive it.
        """
        xq, _ = self.q_proj(inputs_q)
        kv, _ = self.merged_kv(inputs_kv)
        xk, xv = kv.split([self.kv_size, self.kv_size], dim=-1)

        if self.use_pytorch_sdpa:
            output = self.forward_sdpa(xq, xk, xv, attn_mask)
        else:
            output = self.attn(xq, xk, xv)

        output, _ = self.o_proj(output)

        return output
|
|
|
|
|
|
class ImageProjectorMLP(nn.Module):
    """Gated MLP that projects pooled image features to `output_dim`.

    A single merged projection produces two `hidden_dim`-wide halves that
    are combined by `SiluAndMul`, followed by a bias-free down projection.
    Only the "silu" activation is supported.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Fused projection producing both halves in one matmul.
        self.merged_linear = MergedColumnParallelLinear(
            input_dim,
            [hidden_dim, hidden_dim],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.merged_linear",
        )

        # The fused activation below is SiLU-specific.
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

        # Output projection back to the target width.
        self.down_proj = RowParallelLinear(
            hidden_dim,
            output_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the merged projection, gated activation and down projection."""
        merged, _ = self.merged_linear(x)
        gated = self.act_fn(merged)
        projected, _ = self.down_proj(gated)
        return projected
|
|
|
|
|
|
class Molmo2VisionBackbone(nn.Module, SupportsQuant):
    """Vision tower: ViT encoder, attention pooling and projector MLP.

    Features from the ViT layers listed in `adapter_config.vit_layers` are
    concatenated, pooled by `image_pooling_2d` according to a
    `token_pooling` index tensor, and mapped by `image_projector`.
    """

    packed_modules_mapping = {
        "merged_qkv": ["wq", "wk", "wv"],  # vision backbone
        "merged_kv": ["k_proj", "v_proj"],  # image_pooling_2d
        "merged_linear": ["gate_proj", "up_proj"],
    }

    def __init__(
        self,
        vit_config: VitConfig,
        adapter_config: AdapterConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.vit_config = vit_config
        self.adapter_config = adapter_config

        # Turn negative layer indices into absolute ones.
        self.vit_layers = [
            layer if layer >= 0 else layer + vit_config.num_hidden_layers
            for layer in adapter_config.vit_layers
        ]

        # Layers after the deepest selected one are never read, so the ViT
        # is built without them. NOTE: this updates `vit_config` in place.
        last_layer_needed = max(self.vit_layers) + 1
        if last_layer_needed < vit_config.num_hidden_layers:
            vit_config.num_hidden_layers = last_layer_needed
        self.image_vit = Molmo2VisionTransformer(
            vit_config,
            quant_config,
            prefix=f"{prefix}.image_vit",
        )

        self.num_prefix_tokens: int = self.image_vit.num_prefix_tokens

        # Selected layer outputs are concatenated along the feature axis.
        pool_dim = vit_config.hidden_size * len(adapter_config.vit_layers)
        self.image_pooling_2d = ImagePoolingAttention(
            input_dim=pool_dim,
            hidden_size=adapter_config.hidden_size,
            num_heads=adapter_config.num_attention_heads,
            num_key_value_heads=adapter_config.num_key_value_heads,
            head_dim=adapter_config.head_dim,
            use_pytorch_sdpa=adapter_config.pooling_attention_mask,
            quant_config=quant_config,
            prefix=f"{prefix}.image_pooling_2d",
        )
        self.image_projector = ImageProjectorMLP(
            input_dim=adapter_config.hidden_size,
            hidden_dim=adapter_config.intermediate_size,
            output_dim=adapter_config.text_hidden_size,
            hidden_act=adapter_config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.image_projector",
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the ViT patch-embedding weight."""
        return self.image_vit.patch_embedding.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the ViT patch-embedding weight."""
        return self.image_vit.patch_embedding.weight.device

    def encode_image(self, images: torch.Tensor) -> torch.Tensor:
        """Encode crops with the ViT and concatenate the selected layers.

        : param images: (batch_size, num_crops, num_patch, n_pixels)
        """
        B, T, N, D = images.shape
        # Fold the crop axis into the batch axis for the ViT.
        flat_images = images.view(B * T, N, D)
        layer_outputs = self.image_vit(flat_images)

        selected = [layer_outputs[layer] for layer in self.vit_layers]
        image_features = torch.cat(selected, dim=-1)

        if self.num_prefix_tokens > 0:
            image_features = image_features[:, 1:]
        image_features = image_features.view(B, T, N, -1)
        return image_features

    def forward(
        self,
        images: torch.Tensor,
        token_pooling: torch.Tensor,
    ) -> torch.Tensor:
        """Encode, pool and project image features.

        `token_pooling` is read as a 3-D index tensor whose negative entries
        mark padding. Rows of the final 2-D result that correspond to
        pooled tokens with no valid index are dropped.
        """
        # image_features shape:
        # (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim)
        batch_size, _num_image = images.shape[:2]
        images = images.to(device=self.device, dtype=self.dtype)
        image_features = self.encode_image(images)

        dim = image_features.shape[-1]
        valid = token_pooling >= 0
        valid_token = torch.any(valid, -1)

        # Use `token_pooling` to arrange the features for image pooling.
        batch_idx = torch.arange(
            token_pooling.shape[0],
            dtype=torch.long,
            device=token_pooling.device,
        )
        batch_idx = torch.tile(
            batch_idx.view(batch_size, 1, 1),
            [1, token_pooling.shape[1], token_pooling.shape[2]],
        )

        # Now [batch, num_features, num_pooled_patches, dim]
        flat_features = image_features.reshape(batch_size, -1, dim)
        # Negative (padding) indices are clipped to 0 for the gather and
        # their gathered features are zeroed right after.
        to_pool = flat_features[batch_idx, torch.clip(token_pooling, 0)]
        to_pool = to_pool * valid.to(self.dtype)[:, :, :, None]
        to_pool = to_pool.reshape([-1, token_pooling.shape[-1], dim])
        if self.adapter_config.pooling_attention_mask:
            attn_mask = valid.reshape([-1, 1, 1, valid.shape[-1]])
            # Mean over valid entries only; a zero count is replaced by 1
            # to avoid dividing by zero.
            denom = valid.view(-1, to_pool.shape[-2]).float().sum(-1)
            denom = torch.where(denom == 0, 1, denom)
            query = to_pool.sum(-2, keepdim=True) / denom[:, None, None].to(
                to_pool.dtype
            )
        else:
            attn_mask = None
            query = to_pool.mean(-2, keepdim=True)

        pooled_features = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask)
        pooled_features = pooled_features.reshape(
            [batch_size, -1, pooled_features.shape[-1]]
        )

        # MLP layer to map the feature.
        pooled_features = self.image_projector(pooled_features)
        flat_pooled = pooled_features.view(-1, pooled_features.shape[-1])
        return flat_pooled[valid_token.flatten()]

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors and return the parameter names loaded.

        Checkpoint names containing one of the per-shard names below are
        renamed to the fused parameter and loaded with the matching shard
        id; every other tensor is loaded directly.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("merged_qkv", "wq", "q"),
            ("merged_qkv", "wk", "k"),
            ("merged_qkv", "wv", "v"),
            ("merged_kv", "k_proj", 0),
            ("merged_kv", "v_proj", 1),
            ("merged_linear", "gate_proj", 0),
            ("merged_linear", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No fused mapping was applied via `break`: load as-is.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
|
|
|
|
|
class Molmo2Attention(nn.Module):
    """Molmo2's LLM Attention.

    A fused QKV projection is followed by optional query/key RMS
    normalization ("olmo": over the full query/key width; otherwise per
    head), rotary embeddings, attention, and an output projection.
    """

    def __init__(
        self,
        config: TextConfig,
        rope_parameters: dict[str, Any],
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % self.tp_size == 0

        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= self.tp_size:
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            assert self.tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = config.head_dim

        # Per-rank widths used to split the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        # Attention input projection. Projects x -> (q, k, v)
        # The layer prefix is forwarded, as for the other parallel linear
        # layers in this file, so prefix-based quantization lookups see the
        # real layer name.
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.tp_rank: int | None = None
        self.k_norm: nn.Module | None = None
        self.q_norm: nn.Module | None = None
        self.qk_norm_type: str | None = None
        if config.use_qk_norm:
            # "qwen3" normalizes each head separately; the other type
            # normalizes over the full (unsharded) key/query width.
            k_norm_size = (
                self.head_dim
                if config.qk_norm_type == "qwen3"
                else self.total_num_kv_heads * self.head_dim
            )
            self.tp_rank = get_tensor_model_parallel_rank()
            self.k_norm = RMSNorm(k_norm_size, eps=config.layer_norm_eps)
            q_norm_size = (
                self.head_dim
                if config.qk_norm_type == "qwen3"
                else self.total_num_heads * self.head_dim
            )
            self.q_norm = RMSNorm(q_norm_size, eps=config.layer_norm_eps)
            self.qk_norm_type = config.qk_norm_type
        # Rotary embeddings. When `rope_scaling_layers` is set, layers not
        # listed there fall back to default RoPE with the same theta.
        layer_idx = extract_layer_index(prefix)
        if (
            config.rope_scaling_layers is not None
            and layer_idx not in config.rope_scaling_layers
        ):
            rope_theta = rope_parameters["rope_theta"]
            rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=rope_parameters,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def _apply_qk_norm(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize `q` and `k` over their full width.

        Under tensor parallelism the shards are all-gathered first, because
        the norms are sized for the unsharded width, and this rank's slice
        is taken again afterwards.
        """
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm(q)
        k = self.k_norm(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run attention over `hidden_states`; extra `kwargs` are ignored."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if (
            self.q_norm is not None
            and self.k_norm is not None
            and self.qk_norm_type == "olmo"
        ):
            q, k = self._apply_qk_norm(q, k)
        elif self.q_norm is not None and self.k_norm is not None:
            # Per-head norm: expose the head dimension, normalize, then
            # restore the original layout.
            q_by_head = q.view(
                *q.shape[:-1],
                q.shape[-1] // self.head_dim,
                self.head_dim,
            )
            q_by_head = self.q_norm(q_by_head)
            q = q_by_head.view(q.shape)
            k_by_head = k.view(
                *k.shape[:-1],
                k.shape[-1] // self.head_dim,
                self.head_dim,
            )
            k_by_head = self.k_norm(k_by_head)
            k = k_by_head.view(k.shape)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)

        output, _ = self.o_proj(attn_output)
        return output
|
|
|
|
|
|
class LanguageModelMLP(nn.Module):
    """Gated feed-forward block of Molmo2's language model.

    A merged projection yields two `intermediate_size`-wide halves that are
    combined by `MulAndSilu`, followed by a bias-free projection back to
    `input_dim`. Only the "silu" activation is supported.
    """

    def __init__(
        self,
        input_dim: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        super().__init__()

        # Fused projection producing both halves in one matmul.
        self.up_gate_proj = MergedColumnParallelLinear(
            input_dim,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
        )
        # The fused activation below is SiLU-specific.
        assert hidden_act == "silu"
        self.act_fn = MulAndSilu()
        # Output projection back to the model width.
        self.down_proj = RowParallelLinear(
            intermediate_size,
            input_dim,
            bias=False,
            quant_config=quant_config,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the merged projection, gated activation and down projection."""
        merged, _ = self.up_gate_proj(x)
        gated = self.act_fn(merged)
        projected, _ = self.down_proj(gated)
        return projected
|
|
|
|
|
|
class Molmo2DecoderLayer(nn.Module):
    """Decoder layer that normalizes before attention and before the MLP.

    The residual stream is threaded through the layer as a separate tensor
    and returned alongside the hidden states.
    """

    def __init__(
        self,
        config: TextConfig,
        rope_parameters: dict[str, Any],
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Attention block.
        self.self_attn = Molmo2Attention(
            config,
            rope_parameters,
            cache_config,
            quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # MLP block.
        self.mlp = LanguageModelMLP(
            config.hidden_size,
            config.intermediate_size,
            config.hidden_act,
            quant_config,
        )

        # LayerNorm
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.layer_norm_eps,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs: object,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run one decoder layer.

        :param residual: residual stream from the previous layer, or `None`
            when there is none yet (then `hidden_states` becomes the
            residual).
        :return: `(hidden_states, residual)`; the residual is a single
            tensor here.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # With a residual argument the norm returns both the normalized
            # hidden states and the updated residual.
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            **kwargs,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
|
|
|
|
|
|
class Molmo2DecoderNormAfterLayer(Molmo2DecoderLayer):
    """Decoder layer variant that normalizes each sub-block's *output*.

    The norm is applied to the attention / MLP output and the result is
    then added to that sub-block's input.
    """

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        **kwargs: object,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
        """Return ``(hidden_states, None)``.

        The incoming ``residual`` argument is not read; the residual
        additions are done inside this layer, so ``None`` is returned.
        """
        # Self Attention
        attn_input = hidden_states
        attn_output = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            **kwargs,
        )
        hidden_states = self.input_layernorm(attn_output) + attn_input

        # Feed-forward, again normalized before the residual addition.
        mlp_input = hidden_states
        mlp_output = self.mlp(hidden_states)
        hidden_states = self.post_attention_layernorm(mlp_output) + mlp_input

        return hidden_states, None
|
|
|
|
|
|
@support_torch_compile
class Molmo2TextModel(nn.Module, SupportsQuant):
    """Text backbone: token embedding, a stack of decoder layers, final norm.

    ``forward`` only runs the layers in ``[start_layer, end_layer)`` as
    returned by ``make_layers``; on non-final pipeline ranks it returns
    the intermediate tensors instead of the normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        # The text sub-config lives under ``text_config`` when present,
        # otherwise under ``llm_config``.
        hf_text_config = (
            config.text_config if hasattr(config, "text_config") else config.llm_config
        )

        # Copy exactly the attributes declared on the TextConfig dataclass.
        text_config = TextConfig(
            **{f.name: getattr(hf_text_config, f.name) for f in fields(TextConfig)}
        )

        self.embedding_size = text_config.vocab_size + (
            text_config.additional_vocab_size or 0
        )
        self.embed_tokens = VocabParallelEmbedding(
            self.embedding_size,
            text_config.hidden_size,
            quant_config=quant_config,
        )

        layer_cls = (
            Molmo2DecoderNormAfterLayer
            if text_config.norm_after
            else Molmo2DecoderLayer
        )

        # NOTE: the parameter is deliberately still called ``prefix`` in
        # case ``make_layers`` passes it by keyword.
        def build_layer(prefix: str) -> nn.Module:
            return layer_cls(
                text_config,
                hf_text_config.rope_parameters,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            text_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(text_config.hidden_size, eps=text_config.layer_norm_eps)

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"],
            text_config.hidden_size,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run this rank's decoder layers.

        On the first pipeline rank the input is ``inputs_embeds`` when
        given, otherwise the embedding of ``input_ids``. Other ranks read
        ``hidden_states`` / ``residual`` from ``intermediate_tensors``.
        Non-final ranks return an ``IntermediateTensors``; the final rank
        returns the normalized hidden states.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            if inputs_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None

        # Apply blocks one-by-one.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
                **kwargs,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        # Layers that handle the residual internally hand back ``None``.
        if residual is None:
            return self.norm(hidden_states)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` pairs; return the names that were loaded.

        Bias entries with no matching parameter and parameters reported
        missing by ``is_pp_missing_parameter`` are skipped.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            unknown_bias = name.endswith(".bias") and name not in params_dict
            if unknown_bias or is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # Parameters may carry their own loader.
            load = getattr(param, "weight_loader", default_weight_loader)
            load(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params
|
|
|
|
|
|
def get_patches_grid_size(
    *,
    image_h: int,
    image_w: int,
    patch_size: int,
    pool_h: int,
    pool_w: int,
) -> tuple[int, int]:
    """Return ``(nrows, ncols)`` of the pooled patch grid for an image.

    Each side is split into ``side // patch_size`` patches, padded up to
    the pooling size and divided by it.
    NOTE(review): assumes ``round_down(x, m)`` rounds ``x`` down to a
    multiple of ``m`` (making this a ceiling division) - confirm.
    """

    def pooled_count(num_patches: int, pool: int) -> int:
        padding = round_down(num_patches + pool - 1, pool) - num_patches
        return (num_patches + padding) // pool

    nrows = pooled_count(image_h // patch_size, pool_h)
    ncols = pooled_count(image_w // patch_size, pool_w)
    return nrows, ncols
|
|
|
|
|
|
def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]:
    """List every ``(i, j)`` tiling with ``i, j >= 1`` and ``i * j <= max_num``.

    The result is ordered by tile count ``i * j`` and then by ``i``.
    """
    candidates = [
        (rows, cols)
        for rows in range(1, max_num + 1)
        # cols <= max_num // rows is the same bound as rows * cols <= max_num.
        for cols in range(1, max_num // rows + 1)
    ]
    candidates.sort(key=lambda tiling: (tiling[0] * tiling[1], tiling[0]))
    return candidates
|
|
|
|
|
|
def select_tiling(
    *,
    height: int,
    width: int,
    patch_size: int,
    max_num_patches: int,
):
    """Choose the candidate tiling that best fits a ``height`` x ``width`` image.

    For each candidate, the required scale is the smaller of the two
    per-axis ratios ``tiling * patch_size / (height, width)``. If every
    candidate needs downscaling (scale < 1), the one with the largest
    scale wins; otherwise the smallest scale that is >= 1 wins.

    Returns the selected row of the candidate array as a length-2
    ``int32`` array, in the same axis order as ``(height, width)``.
    """
    candidate_tilings = np.array(
        get_candidate_tilings(max_num_patches), dtype=np.int32
    )
    resolutions = (candidate_tilings * patch_size).astype(np.float32)
    image_size = np.array([height, width], dtype=np.float32)

    scale = (resolutions / image_size).min(axis=-1, keepdims=True)
    needs_downscale = scale < 1

    if not needs_downscale.all():
        # Mask out the downscaling candidates with a huge sentinel.
        best = np.where(scale < 1.0, 10e9, scale).argmin()
    else:
        best = scale.argmax()

    return candidate_tilings[best]
|
|
|
|
|
|
def get_image_size(image: ImageInput) -> ImageSize:
    """Return the size of ``image`` as ``ImageSize(width, height)``.

    PIL images use their ``size`` attribute. Arrays and tensors must be
    3-dimensional and are unpacked as ``(height, width, channels)`` with
    1 or 3 channels.

    Raises:
        ValueError: if ``image`` is neither a PIL image, a numpy array
            nor a torch tensor.
    """
    if isinstance(image, Image):
        return ImageSize(*image.size)

    if not isinstance(image, (np.ndarray, torch.Tensor)):
        raise ValueError(f"Unknown image type: {type(image)}")

    assert image.ndim == 3
    height, width, channels = image.shape
    assert channels in [1, 3]
    return ImageSize(width, height)
|
|
|
|
|
|
def exif_tranpose(
    images: ImageInput | None,
) -> ImageInput | None:
    """Apply ``ImageOps.exif_transpose`` to PIL images.

    A list or tuple yields a new list in which every PIL image has been
    transposed and other items are kept as they are. A single PIL image
    is transposed directly. Anything else, including ``None``, is
    returned unchanged.
    """
    if isinstance(images, (list, tuple)):
        return [
            ImageOps.exif_transpose(item) if isinstance(item, Image) else item
            for item in images
        ]
    if isinstance(images, Image):
        return ImageOps.exif_transpose(images)
    return images
|
|
|
|
|
|
def build_flat_image_bool_length(
    image_grids: torch.LongTensor,
    image_patch_id: int,
    low_res_image_start_id: int,
    image_start_id: int,
    image_col_id: int,
    image_end_id: int,
) -> tuple[torch.LongTensor, torch.LongTensor]:
    """Build the flattened placeholder token ids for a batch of images.

    Each row of ``image_grids`` is ``(resized_h, resized_w, h, w)``. The
    tokens emitted per image are::

        low_res_start, patch * (resized_h * resized_w), end,
        start, (patch * w, col) * h, end

    Returns:
        ``(flat, lengths)``: all per-image sequences concatenated into
        one 1-D long tensor, and the per-image token counts.
    """
    device = image_grids.device

    low_h, low_w = image_grids[:, 0], image_grids[:, 1]
    high_h, high_w = image_grids[:, 2], image_grids[:, 3]

    # 4 = low-res start + its end + high-res start + its end.
    lengths = low_h * low_w + high_h * (high_w + 1) + 4  # [B]
    flat = torch.empty(int(lengths.sum().item()), dtype=torch.long, device=device)

    cursor = 0
    for grid, length in zip(image_grids.tolist(), lengths.tolist()):
        low_rows, low_cols, rows, cols = grid
        expected_len = int(length)
        begin = cursor

        # Low-resolution view: start, patches, end.
        flat[cursor] = low_res_image_start_id
        cursor += 1

        num_low_res_patches = low_rows * low_cols
        if num_low_res_patches > 0:
            flat[cursor : cursor + num_low_res_patches] = image_patch_id
            cursor += num_low_res_patches

        flat[cursor] = image_end_id
        cursor += 1

        # High-resolution view: start, rows of (patches + col token), end.
        flat[cursor] = image_start_id
        cursor += 1

        row_len = cols + 1
        if row_len > 0 and rows > 0:
            row = torch.full(
                (row_len,), image_patch_id, dtype=torch.long, device=device
            )
            row[cols] = image_col_id
            span = rows * row_len
            flat[cursor : cursor + span] = row.repeat(rows)
            cursor += span

        flat[cursor] = image_end_id
        cursor += 1

        assert cursor - begin == expected_len

    return flat, lengths
|
|
|
|
|
|
def build_flat_video_bool_length(
    video_grids: torch.LongTensor,
    image_patch_id: int,
    frame_start_id: int,
    frame_end_id: int,
) -> tuple[torch.LongTensor, torch.LongTensor]:
    """Build the flattened placeholder token ids for a batch of videos.

    Each row of ``video_grids`` is ``(t, resized_h, resized_w)``. Every
    one of the ``t`` frames contributes
    ``frame_start, patch * (resized_h * resized_w), frame_end``.

    Returns:
        ``(flat, lengths)``: all per-video sequences concatenated into
        one 1-D long tensor, and the per-video token counts.
    """
    device = video_grids.device

    frame_counts = video_grids[:, 0]
    patches_per_frame = video_grids[:, 1] * video_grids[:, 2]
    # +2 for the frame start / end tokens.
    lengths = frame_counts * (patches_per_frame + 2)

    flat = torch.empty(int(lengths.sum().item()), dtype=torch.long, device=device)

    cursor = 0
    per_video = zip(
        frame_counts.tolist(), patches_per_frame.tolist(), lengths.tolist()
    )
    for num_frames, num_patches, length in per_video:
        num_frames, num_patches, length = int(num_frames), int(num_patches), int(length)

        # One frame: patches framed by the start / end markers.
        frame = torch.full(
            (num_patches + 2,), image_patch_id, dtype=torch.long, device=device
        )
        frame[0] = frame_start_id
        frame[-1] = frame_end_id

        flat[cursor : cursor + length] = frame.repeat(num_frames)
        cursor += length

    return flat, lengths
|
|
|
|
|
|
class Molmo2ProcessorWrapper:
    """
    Wraps :class:`Molmo2Processor` so that it can be called directly.

    Besides forwarding calls to the wrapped processor, the wrapper exposes
    image/video processor settings and special token ids as cached
    properties, and post-processes the processor outputs in ``__call__``.
    """

    def __init__(self, processor: ProcessorMixin, hf_config: PretrainedConfig):
        super().__init__()

        self.processor = processor
        self.hf_config = hf_config

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _image_or_video_processor(self):
        """Return the image processor when present, else the video processor."""
        image_processor = getattr(self.processor, "image_processor", None)
        if image_processor is not None:
            return image_processor
        return self.processor.video_processor  # type: ignore

    def _crop_window_geometry(self) -> tuple[int, int]:
        """Return ``(total_margin_pixels, crop_window_size)`` in pixels.

        Both are derived from the overlap margins (counted in patches),
        the patch size and the first entry of the base input size.
        """
        left_margin, right_margin = self.overlap_margins
        patch_size = self.image_patch_size
        margin_patches = right_margin + left_margin

        crop_patches = self.base_image_input_size[0] // patch_size
        crop_window_patches = crop_patches - margin_patches
        return patch_size * margin_patches, crop_window_patches * patch_size

    # ------------------------------------------------------------------
    # Processor settings
    # ------------------------------------------------------------------
    @cached_property
    def vocab(self) -> dict[str, int]:
        tokenizer = self.processor.tokenizer  # type: ignore
        return tokenizer.vocab

    @cached_property
    def max_crops(self) -> int:
        value = self.processor.image_processor.max_crops  # type: ignore
        assert isinstance(value, int)
        return value

    @cached_property
    def image_pooling_h(self) -> int:
        value = self.processor.image_processor.pooling_size[0]  # type: ignore
        assert isinstance(value, int)
        return value

    @cached_property
    def image_pooling_w(self) -> int:
        value = self.processor.image_processor.pooling_size[1]  # type: ignore
        assert isinstance(value, int)
        return value

    @cached_property
    def video_pooling_h(self) -> int:
        value = self.processor.video_processor.pooling_size[0]  # type: ignore
        assert isinstance(value, int)
        return value

    @cached_property
    def video_pooling_w(self) -> int:
        value = self.processor.video_processor.pooling_size[1]  # type: ignore
        assert isinstance(value, int)
        return value

    @cached_property
    def base_image_input_size(self) -> tuple[int, int]:
        size = self._image_or_video_processor().size
        return (size["height"], size["width"])

    @cached_property
    def image_patch_size(self) -> int:
        value = self._image_or_video_processor().patch_size
        assert isinstance(value, int)
        return value

    @cached_property
    def overlap_margins(self) -> tuple[int, int]:
        margins = self.processor.image_processor.overlap_margins  # type: ignore
        left_margin, right_margin = margins
        assert isinstance(left_margin, int)
        assert isinstance(right_margin, int)
        return left_margin, right_margin

    # ------------------------------------------------------------------
    # Special tokens
    # ------------------------------------------------------------------
    @cached_property
    def bos_token(self) -> str:
        tokenizer = self.processor.tokenizer
        # Fall back to the EOS token when no BOS token is set.
        return tokenizer.bos_token or tokenizer.eos_token

    @cached_property
    def image_patch_id(self) -> int:
        return self.hf_config.image_patch_id

    @cached_property
    def im_col_id(self) -> int:
        return self.hf_config.image_col_id

    @cached_property
    def im_start_id(self) -> int:
        return self.hf_config.image_start_token_id

    @cached_property
    def im_end_id(self) -> int:
        return self.hf_config.image_end_token_id

    @cached_property
    def low_res_im_start_id(self) -> int:
        return self.hf_config.low_res_image_start_token_id

    @cached_property
    def frame_start_id(self) -> int:
        return self.hf_config.frame_start_token_id

    @cached_property
    def frame_end_id(self) -> int:
        return self.hf_config.frame_end_token_id

    @cached_property
    def im_low_res_id(self) -> int:
        return self.hf_config.image_low_res_id

    @cached_property
    def image_placeholder_id(self) -> int:
        return self.vocab[IMAGE_PROMPT]

    @cached_property
    def video_placeholder_id(self) -> int:
        return self.vocab[VIDEO_PROMPT]

    @cached_property
    def image_token_ids(self) -> list[int]:
        return [
            self.image_patch_id,
            self.im_col_id,
            self.im_start_id,
            self.low_res_im_start_id,
            self.frame_start_id,
            self.im_end_id,
            self.frame_end_id,
            self.im_low_res_id,
        ]

    # ------------------------------------------------------------------
    # Grid computations
    # ------------------------------------------------------------------
    def select_tiling(
        self,
        *,
        image_height: int,
        image_width: int,
    ) -> tuple[int, int]:
        """Pick the crop tiling for an image of the given size.

        The overlap margin is subtracted from both sides before the
        module-level ``select_tiling`` is consulted with the crop window
        size and ``max_crops``.
        """
        max_crops = self.max_crops
        total_margin_pixels, crop_window_size = self._crop_window_geometry()

        tiling_h, tiling_w = select_tiling(
            height=image_height - total_margin_pixels,
            width=image_width - total_margin_pixels,
            patch_size=crop_window_size,
            max_num_patches=max_crops,
        )

        return tiling_h, tiling_w

    def get_base_grid_size(self, is_video: bool) -> tuple[int, int]:
        """Pooled grid size of a single base-resolution crop or frame."""
        base_h, base_w = self.base_image_input_size

        # Only touch the pooling settings of the modality that is asked for.
        if is_video:
            pool_h, pool_w = self.video_pooling_h, self.video_pooling_w
        else:
            pool_h, pool_w = self.image_pooling_h, self.image_pooling_w

        return get_patches_grid_size(
            image_h=base_h,
            image_w=base_w,
            patch_size=self.image_patch_size,
            pool_h=pool_h,
            pool_w=pool_w,
        )

    def get_patches_grid_size(
        self,
        *,
        image_height: int,
        image_width: int,
    ) -> tuple[int, int]:
        """Pooled grid size of the tiled (multi-crop) view of an image."""
        total_margin_pixels, crop_window_size = self._crop_window_geometry()

        tiling_h, tiling_w = self.select_tiling(
            image_height=image_height,
            image_width=image_width,
        )

        # Size of the canvas covered by the selected tiling.
        tiled_h = tiling_h * crop_window_size + total_margin_pixels
        tiled_w = tiling_w * crop_window_size + total_margin_pixels

        nrows, ncols = get_patches_grid_size(
            image_h=tiled_h,
            image_w=tiled_w,
            patch_size=self.image_patch_size,
            pool_h=self.image_pooling_h,
            pool_w=self.image_pooling_w,
        )

        return nrows, ncols

    # ------------------------------------------------------------------
    # Call
    # ------------------------------------------------------------------
    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        images: ImageInput | None = None,
        videos: VideoInput | None = None,
        return_tensors: str | TensorType = None,
        **kwargs: object,
    ) -> BatchFeature:
        """Run the wrapped processor and derive per-item bookkeeping outputs.

        Images are EXIF-transposed first. ``images`` / ``videos`` are
        only passed on when the wrapped processor has the matching
        sub-processor. From the result, ``attention_mask``,
        ``token_type_ids``, ``image_grids`` and ``video_grids`` are
        removed, and crop / patch / token counts plus the flattened
        placeholder tokens are added.
        """
        images = exif_tranpose(images)

        inputs = [text]
        if getattr(self.processor, "image_processor", None) is not None:
            inputs.append(images)
        if getattr(self.processor, "video_processor", None) is not None:
            inputs.append(videos)

        outputs = self.processor(  # type: ignore
            *inputs,
            return_tensors=return_tensors,
            **kwargs,
        )

        # revert insert bos token
        input_ids = outputs["input_ids"]
        if input_ids[0, 0] == self.vocab[self.bos_token]:
            outputs["input_ids"] = input_ids[:, 1:]

        def as_list(items):
            # ``None`` -> empty list, single item -> one-element list.
            if items is None:
                return []
            return items if isinstance(items, list) else [items]

        images = as_list(images)
        videos = as_list(videos)

        assert len(videos) in {0, 1}, "At most one video is supported for Molmo2"

        # These two entries are dropped from the outputs.
        outputs.pop("attention_mask")
        outputs.pop("token_type_ids", None)

        if len(images) > 0:
            # For each image: tiling_h * tiling_w + global view
            num_crops = []
            for image in images:
                size = get_image_size(image)
                tiling = self.select_tiling(
                    image_height=size.height,
                    image_width=size.width,
                )
                num_crops.append(np.prod(tiling) + 1)

            total_crops = sum(num_crops)
            assert total_crops == len(outputs["pixel_values"])
            assert total_crops == outputs["image_num_crops"].sum().item()

            image_grids: torch.Tensor = outputs.pop("image_grids")
            low_res_patches = image_grids[:, :2].prod(dim=1)
            high_res_patches = image_grids[:, 2:].prod(dim=1)
            outputs["image_num_pooled_patches"] = low_res_patches + high_res_patches

            patches_per_crop = outputs["pixel_values"].shape[1]
            outputs["image_num_patches"] = (
                outputs["image_num_crops"] * patches_per_crop
            )

            image_tokens, num_image_tokens = build_flat_image_bool_length(
                image_grids,
                self.image_patch_id,
                self.low_res_im_start_id,
                self.im_start_id,
                self.im_col_id,
                self.im_end_id,
            )
            outputs["image_tokens"] = image_tokens
            outputs["num_image_tokens"] = num_image_tokens

        if len(videos) > 0:
            video_grids: torch.Tensor = outputs.pop("video_grids")
            frames_per_video = video_grids[:, 0]
            assert frames_per_video.sum() == len(outputs["pixel_values_videos"])

            outputs["video_num_crops"] = frames_per_video
            outputs["video_num_pooled_patches"] = video_grids.prod(dim=1)

            patches_per_frame = outputs["pixel_values_videos"].shape[1]
            outputs["video_num_patches"] = (
                outputs["video_num_crops"] * patches_per_frame
            )

            video_tokens, num_video_tokens = build_flat_video_bool_length(
                video_grids,
                self.image_patch_id,
                self.frame_start_id,
                self.frame_end_id,
            )
            outputs["video_tokens"] = video_tokens
            outputs["num_video_tokens"] = num_video_tokens

        return BatchFeature(outputs)
|
|
|
|
|
|
def get_candidate_target_fps(
    video_fps: int | float,
    sampling_fps: int | float,
    max_fps: int | float = _MAX_VIDEO_FPS,
) -> list[float]:
    """
    Return the multiples of `sampling_fps` that divide `video_fps` and do
    not exceed `max_fps`, in increasing order.

    All three arguments are truncated to ``int`` before use and the
    candidates are returned as floats.

    Raises:
        ValueError: if `sampling_fps` is ``None``, if `video_fps` or
            `sampling_fps` is not positive after truncation, or if
            `sampling_fps` does not divide `video_fps`.

    Examples:
        >>> get_candidate_target_fps(video_fps=6, sampling_fps=2, max_fps=8)
        [2.0, 6.0]
        >>> get_candidate_target_fps(video_fps=5, sampling_fps=1, max_fps=8)
        [1.0, 5.0]
        >>> get_candidate_target_fps(video_fps=2, sampling_fps=2, max_fps=8)
        [2.0]
        >>> get_candidate_target_fps(video_fps=6, sampling_fps=2, max_fps=4)
        [2.0]
        >>> get_candidate_target_fps(video_fps=5, sampling_fps=2, max_fps=8)
        Traceback (most recent call last):
            ...
        ValueError: sampling_fps=2 must divide video_fps=5.
    """
    # This check has to precede the int() casts: int(None) raises
    # TypeError, which made the check unreachable when it came after them.
    if sampling_fps is None:
        raise ValueError("sampling_fps must be provided")

    video_fps = int(video_fps)
    sampling_fps = int(sampling_fps)
    max_fps = int(max_fps)

    if video_fps <= 0 or sampling_fps <= 0:
        raise ValueError(
            "video_fps and sampling_fps must be positive "
            f"(got {video_fps}, {sampling_fps})"
        )
    if video_fps % sampling_fps != 0:
        raise ValueError(
            f"sampling_fps={sampling_fps} must divide video_fps={video_fps}."
        )

    # Candidates ascend, so capping the range at max_fps is the same as
    # stopping at the first candidate above it.
    upper = min(video_fps, max_fps)
    return [
        float(candidate)
        for candidate in range(sampling_fps, upper + 1, sampling_fps)
        if video_fps % candidate == 0
    ]
|
|
|
|
|
|
def get_target_fps(
|
|
video_fps: float,
|
|
max_frames: int,
|
|
total_frames: int,
|
|
frame_sample_mode: str,
|
|
candidate_target_fps: list[float],
|
|
) -> float | None:
|
|
"""
|
|
Get the target fps that best spans the video and has the most frames sampled
|
|
"""
|
|
num_frames_sampled = 0
|
|
selected_target_fps = None
|
|
for target_fps in candidate_target_fps:
|
|
step_size = max(int(video_fps / target_fps), 1)
|
|
num_frames_sampled_at_fps = int(total_frames / step_size)
|
|
if num_frames_sampled == 0:
|
|
if (
|
|
"uniform" in frame_sample_mode
|
|
and num_frames_sampled_at_fps > max_frames
|
|
):
|
|
break
|
|
selected_target_fps = target_fps
|
|
num_frames_sampled = num_frames_sampled_at_fps
|
|
|
|
else:
|
|
# the candidate sampling fps increases so frame count can't decrease
|
|
assert num_frames_sampled <= num_frames_sampled_at_fps
|
|
if num_frames_sampled_at_fps > max_frames:
|
|
# choose the sampling fps that spans the video
|
|
continue
|
|
|
|
elif num_frames_sampled_at_fps > num_frames_sampled:
|
|
# both are less than max_frames; choose the one with higher
|
|
# density of frames sampled
|
|
selected_target_fps = target_fps
|
|
num_frames_sampled = num_frames_sampled_at_fps
|
|
return selected_target_fps
|
|
|
|
|
|
def get_frame_times_and_chosen_fps(
    selected_target_fps, total_frames, max_frames, video_fps
):
    """Return ``(selected_target_fps, frame_indices)`` for a video.

    With no selected fps, ``max_frames`` indices are spread evenly over
    ``[0, total_frames)``. Otherwise every ``step``-th frame is taken,
    where ``step = max(int(video_fps / selected_target_fps), 1)``, and
    the result is truncated to at most ``max_frames`` indices.
    """
    if selected_target_fps is not None:
        step = max(int(video_fps / selected_target_fps), 1)
        indices = np.arange(0, total_frames, step)
    else:
        indices = np.linspace(
            0, total_frames, max_frames, endpoint=False, dtype=int
        )

    if len(indices) > max_frames:
        indices = indices[:max_frames]

    return selected_target_fps, indices
|
|
|
|
|
|
class Molmo2ProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Molmo2: token counts, sizes, frame sampling."""

    def get_hf_processor(self, **kwargs: object) -> Molmo2ProcessorWrapper:
        """Wrap the HF processor together with the HF config."""
        return Molmo2ProcessorWrapper(
            self.ctx.get_hf_processor(**kwargs),
            self.ctx.get_hf_config(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None, "video": 1}

    def get_num_image_tokens(
        self,
        *,
        image_height: int,
        image_width: int,
        processor: Molmo2ProcessorWrapper | None = None,
    ) -> int:
        """Count placeholder tokens for one image of the given size.

        The total is the single-crop (base grid) part plus the
        multi-crop part; each contributes 2 start/end tokens plus one
        token per pooled patch, and optionally one column token per row.
        """
        if processor is None:
            processor = self.get_hf_processor()

        hf_processor = processor.processor  # type: ignore

        base_rows, base_cols = processor.get_base_grid_size(is_video=False)

        # The single-crop setting wins when set; otherwise the general
        # image setting is used.
        single_crop_col_tokens = hf_processor.use_single_crop_col_tokens
        if single_crop_col_tokens is None:
            single_crop_col_tokens = hf_processor.image_use_col_tokens

        # start/end tokens + image patch token + col tokens
        single_crop_tokens = 2 + base_rows * (base_cols + int(single_crop_col_tokens))

        tiled_rows, tiled_cols = processor.get_patches_grid_size(
            image_height=image_height,
            image_width=image_width,
        )
        tiled_tokens = 2 + tiled_rows * (
            tiled_cols + int(hf_processor.image_use_col_tokens)
        )

        return single_crop_tokens + tiled_tokens

    def get_num_video_tokens(
        self,
        *,
        num_frames: int,
        processor: Molmo2ProcessorWrapper | None = None,
    ) -> int:
        """Count placeholder tokens for a video with ``num_frames`` frames."""
        if processor is None:
            processor = self.get_hf_processor()

        rows, cols = processor.get_base_grid_size(is_video=True)
        col_tokens = int(processor.processor.video_use_col_tokens)

        # start/end tokens
        tokens_per_frame = 2 + rows * (cols + col_tokens)
        return num_frames * tokens_per_frame

    def get_image_size_with_most_features(self) -> ImageSize:
        """Find the candidate tiling size that yields the most image tokens.

        Raises:
            ValueError: if no candidate produces a positive token count.
        """
        processor = self.get_hf_processor()

        left_margin, right_margin = processor.overlap_margins
        patch_size = processor.image_patch_size
        margin_patches = right_margin + left_margin

        total_margin_pixels = patch_size * margin_patches
        crop_patches = processor.base_image_input_size[0] // patch_size
        crop_window_size = (crop_patches - margin_patches) * patch_size

        best_tokens = 0
        best_size = None

        for tiles_h, tiles_w in get_candidate_tilings(processor.max_crops):
            height = tiles_h * crop_window_size + total_margin_pixels
            width = tiles_w * crop_window_size + total_margin_pixels

            num_tokens = self.get_num_image_tokens(
                image_height=height, image_width=width, processor=processor
            )
            # Strict comparison: the first of equally large candidates wins.
            if num_tokens > best_tokens:
                best_tokens = num_tokens
                best_size = ImageSize(width=width, height=height)

        if best_tokens == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size

    def _get_max_video_frames(self, max_tokens: int) -> int:
        """Number of frames that fit in ``max_tokens`` (at least 1)."""
        tokens_per_frame = self.get_num_video_tokens(num_frames=1)
        return max(max_tokens // tokens_per_frame, 1)

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        """Frames per video for profiling, bounded by ``seq_len`` and the
        video processor's ``num_frames``; never less than 1."""
        video_processor = self.get_hf_processor().processor.video_processor
        frame_limit = video_processor.num_frames

        num_videos = max(mm_counts.get("video", 0), 1)
        frame_budget = self._get_max_video_frames(seq_len) // num_videos

        return max(min(frame_budget, frame_limit), 1)

    def _sample_frames(
        self,
        total_num_frames: int,
        video_fps: float,
        duration: float,
        frame_sample_mode: str,
        num_frames: int,
        max_fps: int,
        sampling_fps: int,
    ) -> np.ndarray:
        """Compute the frame indices to sample for a video.

        Supported modes are ``"uniform_last_frame"`` and ``"fps"``.

        Raises:
            NotImplementedError: for any other ``frame_sample_mode``.
        """

        def uniform_with_last_frame() -> np.ndarray:
            # Evenly spaced indices that include the final frame.
            return np.linspace(
                0,
                total_num_frames - 1,
                num=min(num_frames, total_num_frames),
                endpoint=True,
            ).astype(int)

        if frame_sample_mode == "uniform_last_frame":
            if max_fps is None:
                return uniform_with_last_frame()

            if total_num_frames <= 2:
                return np.arange(total_num_frames).astype(int)

            if duration > (num_frames - 1) / max_fps:  # -1 to include the last frame
                # uniform fallback
                return uniform_with_last_frame()

            # Short video: sample at max_fps and make sure the last
            # frame is part of the selection.
            float_indices = np.arange(
                0.0,
                stop=total_num_frames - 1,
                step=float(video_fps / max_fps),
            )
            if np.round(float_indices[-1]) != total_num_frames - 1:
                float_indices = np.concatenate(
                    [float_indices, [total_num_frames - 1]], axis=0
                )
            indices = np.round(float_indices).astype(int)
            assert indices[-1] < total_num_frames
            assert len(float_indices) <= num_frames
            return indices

        if frame_sample_mode == "fps":
            candidates = get_candidate_target_fps(video_fps, sampling_fps)
            target_fps = get_target_fps(
                video_fps,
                num_frames,
                total_num_frames,
                frame_sample_mode,
                candidates,
            )
            _, indices = get_frame_times_and_chosen_fps(
                target_fps,
                total_num_frames,
                num_frames,
                video_fps,
            )
            return indices

        raise NotImplementedError(frame_sample_mode)

    def _get_video_second_idx(
        self,
        metadata: dict[str, Any],
        do_sample_frames: bool | None = None,
    ) -> list[float]:
        """Return the timestamp in seconds of every frame that is used.

        When ``do_sample_frames`` is ``None`` the flag is read from
        ``metadata`` (default ``False``). With sampling enabled the
        indices are recomputed via ``_sample_frames``; otherwise
        ``metadata["frames_indices"]`` must be present.
        """
        video_processor = self.get_hf_processor().processor.video_processor
        # metadata["fps"] refers to the true fps of the input video.
        video_fps = metadata["fps"]
        frames_indices = metadata.get("frames_indices")

        if do_sample_frames is None:
            do_sample_frames = metadata.get("do_sample_frames", False)

        if not do_sample_frames:
            # Time-based sampling is done in vllm molmo2 video loader or molmo_utils
            assert frames_indices is not None
        else:
            # Frame-based sampling is applied in HF video processor
            total_num_frames = metadata["total_num_frames"]
            frames_indices = self._sample_frames(
                total_num_frames,
                video_fps,
                total_num_frames / video_fps,
                video_processor.frame_sample_mode,
                video_processor.num_frames,
                video_processor.max_fps,
                video_processor.sampling_fps,
            )

        return [frame_idx / video_fps for frame_idx in frames_indices]
|
|
|
|
|
|
class Molmo2DummyInputsBuilder(BaseDummyInputsBuilder[Molmo2ProcessingInfo]):
    """Builds the dummy prompt text and multimodal data for Molmo2."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt for the requested item counts.

        A single image is just the image placeholder; otherwise each
        image placeholder is prefixed with ``"Image <n>"``. One video
        placeholder per video is appended.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        if num_images == 1:
            image_string = IMAGE_PROMPT
        else:
            image_string = "".join(
                f"Image {i + 1}" + IMAGE_PROMPT for i in range(num_images)
            )

        return image_string + VIDEO_PROMPT * num_videos

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Create dummy images and videos for the requested counts.

        Images use the size with the most features. Videos use the base
        input size and the largest frame count that fits. A
        ``video.num_frames`` override may lower that frame count; values
        above the maximum or below 2 are ignored with a warning.
        """
        num_images = mm_counts.get("image", 0)
        num_videos = mm_counts.get("video", 0)

        dummy_images = []
        dummy_videos = []

        if num_images > 0:
            target_width, target_height = self.info.get_image_size_with_most_features()

            image_overrides = mm_options.get("image") if mm_options else None

            dummy_images = self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )

        if num_videos > 0:
            processor = self.info.get_hf_processor()
            base_image_input_size = processor.base_image_input_size
            target_num_frames = self.info.get_num_frames_with_most_features(
                seq_len, mm_counts
            )

            video_overrides = mm_options.get("video") if mm_options else None

            if video_overrides:
                assert isinstance(video_overrides, VideoDummyOptions)
                num_frames_override = video_overrides.num_frames
                if num_frames_override:
                    if num_frames_override > target_num_frames:
                        logger.warning(
                            "video.num_frames override (%d) exceeds model's "
                            "maximum number of frames (%d), will be ignored",
                            num_frames_override,
                            target_num_frames,
                        )
                    if num_frames_override < 2:
                        # Really ignore the override, as the message says.
                        # It used to be applied through min() regardless.
                        logger.warning(
                            "video.num_frames override (%d) cannot be less "
                            "than 2, will be ignored",
                            num_frames_override,
                        )
                    else:
                        # min() also discards an override above the maximum.
                        target_num_frames = min(
                            target_num_frames, num_frames_override
                        )

            dummy_videos = self._get_dummy_videos(
                width=base_image_input_size[1],
                height=base_image_input_size[0],
                num_frames=target_num_frames,
                num_videos=num_videos,
            )

        return {
            "image": dummy_images,
            "video": dummy_videos,
        }

    def _get_dummy_videos(
        self,
        *,
        width: int,
        height: int,
        num_frames: int,
        num_videos: int,
    ) -> list[VideoItem]:
        """Create ``num_videos`` all-white videos paired with metadata.

        Every item is ``(frames, metadata)`` where ``frames`` is a
        ``(num_frames, height, width, 3)`` uint8 array filled with 255.
        The metadata marks the video as pre-sampled at 2 fps with every
        frame index listed.
        """
        video = np.full((num_frames, height, width, 3), 255, dtype=np.uint8)
        video_items = []
        for _ in range(num_videos):
            # A fresh dict and array copy per item so items share no state.
            video_metadata = {
                "fps": 2.0,
                "duration": num_frames / 2.0,
                "total_num_frames": num_frames,
                "frames_indices": list(range(num_frames)),
                "video_backend": "decord",
                "do_sample_frames": False,
                "height": height,
                "width": width,
            }
            video_items.append((video.copy(), video_metadata))
        return video_items
|
|
|
|
|
|
class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
|
|
def _apply_hf_processor_tokens_only(
    self,
    prompt_tokens: list[int],
) -> list[int]:
    """Make sure a non-empty token prompt starts with the BOS token id.

    The tokenizer's ``bos_token_id`` is used, falling back to
    ``eos_token_id`` only when no BOS id is defined. An empty prompt, or
    one that already starts with that id, is returned unchanged;
    otherwise a new list with the id prepended is returned.
    """
    processor = self.info.get_hf_processor()
    tokenizer = processor.processor.tokenizer

    # Compare against None explicitly: ``bos or eos`` would discard a
    # valid BOS id of 0 because 0 is falsy.
    bos_token_id = tokenizer.bos_token_id
    if bos_token_id is None:
        bos_token_id = tokenizer.eos_token_id

    if len(prompt_tokens) > 0 and prompt_tokens[0] != bos_token_id:
        # Prepend the bos token to the prompt tokens
        prompt_tokens = [bos_token_id] + prompt_tokens

    return prompt_tokens
|
|
|
|
def _get_data_parser(self) -> MultiModalDataParser:
    """Return a data parser configured with ``video_needs_metadata=True``."""
    parser = MultiModalDataParser(video_needs_metadata=True)
    return parser
|
|
|
|
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Run the HF processor, handling videos one at a time.

    Every entry under ``mm_data["videos"]`` is processed on its own through
    the parent implementation using ``VIDEO_PROMPT`` as the prompt. The
    decoded token string of that call replaces the next ``VIDEO_PROMPT``
    occurrence in ``prompt``, and the per-video tensors are concatenated.
    The remaining data (with the rewritten prompt) is then processed by the
    parent implementation, a BOS token is prepended to ``input_ids`` if it is
    missing, and both results are merged into one ``BatchFeature``.
    """
    mm_data = dict(mm_data)
    processor = self.info.get_hf_processor(**mm_kwargs)

    # Output fields gathered per video; this order is also the key order of
    # the concatenated video outputs.
    video_field_names = (
        "pixel_values_videos",
        "video_token_pooling",
        "video_num_crops",
        "video_num_pooled_patches",
        "video_num_patches",
        "video_tokens",
        "num_video_tokens",
    )

    videos = mm_data.pop("videos", [])
    if videos:
        gathered = {name: [] for name in video_field_names}

        for video_array, raw_metadata in videos:
            # NOTE: metadata.frames_indices indicates
            # the sampled frames indices of pre-sampled videos, which is
            # used to calculate the timestamps. Make sure that
            # do_sample_frames in mm_kwargs is false for presampled videos.

            # NOTE: a copy of mm_kwargs is created to update do_sample_frames,
            # otherwise mm_hash for the object will be incorrect.
            per_video_kwargs = dict(**mm_kwargs)
            if "do_sample_frames" not in per_video_kwargs:
                # molmo_utils already has "do_sample_frames" in
                # mm_kwargs, don't overwrite it.
                per_video_kwargs["do_sample_frames"] = raw_metadata.get(
                    "do_sample_frames", False
                )

            # "do_sample_frames" is forwarded through the kwargs above, so it
            # is left out of the VideoMetadata fields.
            hf_metadata = VideoMetadata(
                **{
                    key: raw_metadata[key]
                    for key in raw_metadata
                    if key != "do_sample_frames"
                }
            )

            single_video_data = {
                "videos": [[video_array]],
                "video_metadata": [[hf_metadata]],
            }
            single_outputs = super()._call_hf_processor(
                prompt=VIDEO_PROMPT,
                mm_data=single_video_data,
                mm_kwargs=per_video_kwargs,
                tok_kwargs=tok_kwargs,
            )

            video_input_ids = single_outputs.pop("input_ids")
            expanded_video = processor.processor.tokenizer.batch_decode(
                video_input_ids
            )[0]
            # Substitute only the first remaining placeholder so that videos
            # are matched to placeholders in order.
            prompt = prompt.replace(VIDEO_PROMPT, expanded_video, 1)

            for name in video_field_names:
                gathered[name].append(single_outputs[name])

        video_outputs = {
            name: torch.cat(gathered[name]) for name in video_field_names
        }
    else:
        video_outputs = dict()

    processed_outputs = super()._call_hf_processor(
        prompt=prompt,
        mm_data=mm_data,
        mm_kwargs=mm_kwargs,
        tok_kwargs=tok_kwargs,
    )

    bos_token_id = processor.vocab[processor.bos_token]
    input_ids = processed_outputs["input_ids"]
    # add bos token back to prompt start
    if input_ids.numel() > 0 and input_ids[0, 0] != bos_token_id:
        bos_column = torch.tensor(
            [[bos_token_id]], device=input_ids.device, dtype=input_ids.dtype
        )
        processed_outputs["input_ids"] = torch.concat(
            [bos_column, input_ids], dim=1
        )

    return BatchFeature(dict(processed_outputs, **video_outputs))
|
|
|
|
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """Describe how each processor output field maps onto multimodal items.

    Fields configured with ``flat_from_sizes`` use a size tensor read from
    ``hf_inputs`` (an empty tensor when the key is absent); the remaining
    fields are configured with ``batched``.
    """

    def sizes_for(key: str):
        # Missing keys (e.g. no images in the request) fall back to an empty
        # size tensor.
        return hf_inputs.get(key, torch.empty(0))

    image_num_crops = sizes_for("image_num_crops")
    image_num_pooled_patches = sizes_for("image_num_pooled_patches")
    video_num_crops = sizes_for("video_num_crops")
    video_num_pooled_patches = sizes_for("video_num_pooled_patches")
    num_image_tokens = sizes_for("num_image_tokens")
    num_video_tokens = sizes_for("num_video_tokens")

    flat = MultiModalFieldConfig.flat_from_sizes
    batched = MultiModalFieldConfig.batched

    return {
        # image fields
        "pixel_values": flat("image", image_num_crops),
        "image_token_pooling": flat("image", image_num_pooled_patches),
        "image_num_crops": batched("image"),
        "image_num_pooled_patches": batched("image"),
        "image_num_patches": batched("image"),
        "image_tokens": flat("image", num_image_tokens),
        "num_image_tokens": batched("image"),
        # video fields
        "pixel_values_videos": flat("video", video_num_crops),
        "video_token_pooling": flat("video", video_num_pooled_patches),
        "video_num_crops": batched("video"),
        "video_num_pooled_patches": batched("video"),
        "video_num_patches": batched("video"),
        "video_tokens": flat("video", num_video_tokens),
        "num_video_tokens": batched("video"),
    }
|
|
|
|
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
    """Build the prompt replacements for image and video placeholders.

    Args:
        mm_items: Parsed multimodal items of the request.
        hf_processor_mm_kwargs: Kwargs used to obtain the HF processor; the
            ``"do_sample_frames"`` entry is also read for videos.
        out_mm_kwargs: Unused here.

    Returns:
        Two ``PromptReplacement`` objects, image first and video second, each
        targeting the processor's placeholder id for that modality.
    """
    processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
    inner_processor = processor.processor

    img_patch_id = processor.image_patch_id
    img_col_id = processor.im_col_id
    img_start_id = processor.im_start_id
    img_end_id = processor.im_end_id
    image_use_col_tokens = inner_processor.image_use_col_tokens
    use_single_crop_col_tokens = inner_processor.use_single_crop_col_tokens
    use_single_crop_start_token = inner_processor.use_single_crop_start_token
    video_use_col_tokens = inner_processor.video_use_col_tokens
    use_frame_special_tokens = inner_processor.use_frame_special_tokens

    def get_image_replacement_molmo2(item_idx: int) -> PromptUpdateDetails:
        """Token ids for one image: base-grid block, then patch-grid block."""
        images = mm_items.get_items("image", ImageProcessorItems)
        image = exif_tranpose(images.get(item_idx))

        # First block: laid out on the processor's base grid. Its column-token
        # and start-token settings can be overridden by the "single crop"
        # options.
        base_nrows, base_ncols = processor.get_base_grid_size(is_video=False)
        if use_single_crop_col_tokens is not None:
            base_use_col_tokens = use_single_crop_col_tokens
        else:
            base_use_col_tokens = image_use_col_tokens
        if use_single_crop_start_token:
            base_start_id = processor.low_res_im_start_id
        else:
            base_start_id = img_start_id
        base_row = [img_patch_id] * base_ncols + [img_col_id] * int(
            base_use_col_tokens
        )
        base_block = [base_start_id] + base_row * base_nrows + [img_end_id]

        # Second block: grid size depends on the actual image size.
        image_size = get_image_size(image)
        nrows, ncols = processor.get_patches_grid_size(
            image_height=image_size.height,
            image_width=image_size.width,
        )
        patch_row = [img_patch_id] * ncols + [img_col_id] * int(
            image_use_col_tokens
        )
        patch_block = [img_start_id] + patch_row * nrows + [img_end_id]

        return PromptUpdateDetails.select_token_ids(
            base_block + patch_block,
            processor.image_token_ids,
        )

    def get_video_replacement_molmo2(item_idx: int) -> PromptUpdateDetails:
        """Token ids for one video: a timestamp prefix and a block per frame."""
        _, metadata = mm_items["video"][item_idx]
        do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames")

        timestamps = self.info._get_video_second_idx(metadata, do_sample_frames)
        nrows, ncols = processor.get_base_grid_size(is_video=True)

        if use_frame_special_tokens:
            start_id = processor.frame_start_id
            end_id = processor.frame_end_id
        else:
            start_id = img_start_id
            end_id = img_end_id

        # The patch layout is the same for every frame, so it is built once
        # instead of once per timestamp.
        frame_row = [img_patch_id] * ncols + [img_col_id] * int(
            video_use_col_tokens
        )
        frame_block = [start_id] + nrows * frame_row + [end_id]
        tokenizer = processor.processor.tokenizer

        frame_token_ids: list[int] = []
        for frame_idx, frame_time in enumerate(timestamps):
            # explicit whitespace before/after image tokens
            separator = " " if frame_idx > 0 else ""
            frame_token_ids += tokenizer.encode(
                f"{separator}{frame_time:.1f} ",
                add_special_tokens=False,
            )
            frame_token_ids += frame_block

        return PromptUpdateDetails.select_token_ids(
            frame_token_ids,
            processor.image_token_ids,
        )

    return [
        PromptReplacement(
            modality="image",
            target=[processor.image_placeholder_id],
            replacement=get_image_replacement_molmo2,
        ),
        PromptReplacement(
            modality="video",
            target=[processor.video_placeholder_id],
            replacement=get_video_replacement_molmo2,
        ),
    ]
|
|
|
|
|
|
@MULTIMODAL_REGISTRY.register_processor(
|
|
Molmo2MultiModalProcessor,
|
|
info=Molmo2ProcessingInfo,
|
|
dummy_inputs=Molmo2DummyInputsBuilder,
|
|
)
|
|
class Molmo2ForConditionalGeneration(
|
|
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant
|
|
):
|
|
hf_to_vllm_mapper = WeightsMapper(
|
|
orig_to_new_substr={
|
|
# vision backbone mapping
|
|
"image_pooling_2d.wq": "image_pooling_2d.q_proj",
|
|
"image_pooling_2d.wk": "image_pooling_2d.k_proj",
|
|
"image_pooling_2d.wv": "image_pooling_2d.v_proj",
|
|
"image_pooling_2d.wo": "image_pooling_2d.o_proj",
|
|
"image_projector.w1": "image_projector.gate_proj",
|
|
"image_projector.w3": "image_projector.up_proj",
|
|
"image_projector.w2": "image_projector.down_proj",
|
|
# language backbone mapping
|
|
"att_proj": "qkv_proj",
|
|
"attn_out": "o_proj",
|
|
"q_norm": "q_norm",
|
|
"k_norm": "k_norm",
|
|
"ff_proj": "up_gate_proj",
|
|
"ff_out": "down_proj",
|
|
"attn_norm": "input_layernorm",
|
|
"ff_norm": "post_attention_layernorm",
|
|
},
|
|
orig_to_new_prefix={
|
|
# vision backbone mapping
|
|
"model.vision_backbone.": "vision_backbone.",
|
|
# language backbone mapping
|
|
"model.transformer.blocks.": "model.layers.",
|
|
"model.transformer.ln_f.": "model.norm.",
|
|
},
|
|
)
|
|
|
|
packed_modules_mapping = {
|
|
"qkv_proj": ["qkv_proj"],
|
|
"up_gate_proj": ["up_gate_proj"], # language model
|
|
"merged_qkv": ["wq", "wk", "wv"], # vision backbone
|
|
"merged_kv": ["k_proj", "v_proj"], # image_pooling_2d
|
|
"merged_linear": ["gate_proj", "up_proj"], # image_projector
|
|
}
|
|
|
|
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
    """Return the prompt placeholder string for ``modality``.

    Modalities starting with ``"image"`` map to ``IMAGE_PROMPT`` and those
    starting with ``"video"`` map to ``VIDEO_PROMPT``; ``i`` is not used.

    Raises:
        ValueError: If ``modality`` starts with neither prefix.
    """
    is_image = modality.startswith("image")
    is_video = modality.startswith("video")
    if not (is_image or is_video):
        raise ValueError("Only image or video modality is supported")

    return IMAGE_PROMPT if is_image else VIDEO_PROMPT
|
|
|
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the vision backbone, the text model and the LM head.

    Args:
        vllm_config: Engine configuration; the HF config, quantization config
            and multimodal config are read from it.
        prefix: Module-name prefix passed to ``maybe_prefix`` for the
            sub-modules.
    """
    super().__init__()
    model_config = vllm_config.model_config
    config = model_config.hf_config
    quant_config = vllm_config.quant_config

    self.config = config
    self.multimodal_config = model_config.multimodal_config

    # Copy exactly the dataclass fields from the HF sub-configs into the
    # local config dataclasses.
    vit_config = VitConfig(
        **{f.name: getattr(config.vit_config, f.name) for f in fields(VitConfig)}
    )
    adapter_config = AdapterConfig(
        **{
            f.name: getattr(config.adapter_config, f.name)
            for f in fields(AdapterConfig)
        }
    )

    with self._mark_tower_model(vllm_config, {"image", "video"}):
        self.vision_backbone = Molmo2VisionBackbone(
            vit_config,
            adapter_config,
            quant_config,
            prefix=maybe_prefix(prefix, "vision_backbone"),
        )

    with self._mark_language_model(vllm_config):
        self.model = Molmo2TextModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )

    self.img_patch_id = config.image_patch_id

    # The text config is exposed as `text_config` when present, otherwise as
    # `llm_config`.
    hf_text_config = (
        config.text_config if hasattr(config, "text_config") else config.llm_config
    )

    self.lm_head = ParallelLMHead(
        hf_text_config.vocab_size,
        hf_text_config.hidden_size,
        quant_config=quant_config,
    )
    self.logits_processor = LogitsProcessor(hf_text_config.vocab_size)

    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors
    )
|
|
|
|
@property
def dtype(self):
    """Dtype of the first parameter yielded by ``self.parameters()``."""
    first_param = next(self.parameters())
    return first_param.dtype
|
|
|
|
def _parse_and_validate_image_input(
    self,
    **kwargs: object,
) -> Molmo2ImageInputs | None:
    """Collect the image tensors from ``kwargs`` into ``Molmo2ImageInputs``.

    Returns ``None`` when ``pixel_values`` is absent. Otherwise the
    non-negative entries of ``image_token_pooling`` are shifted, per image,
    by the number of patches of all preceding images; negative entries are
    left untouched.
    """
    pixel_values = kwargs.get("pixel_values")
    if pixel_values is None:
        return None

    token_pooling = kwargs.get("image_token_pooling")
    num_pooled_patches = kwargs.get("image_num_pooled_patches")
    num_patches = kwargs.get("image_num_patches")
    image_tokens = kwargs.get("image_tokens")
    num_image_tokens = kwargs.get("num_image_tokens")

    # Exclusive prefix sum of the per-image patch counts: entry i is the
    # amount added to the pooling indices of image i.
    # NOTE(review): presumably this turns per-image patch indices into
    # indices over the concatenated patches - confirm against the backbone.
    patch_starts = [0] + num_patches.cumsum(dim=0)[:-1].tolist()

    shifted_pooling = token_pooling.clone()
    row_begin = 0
    for image_idx, num_rows in enumerate(num_pooled_patches):
        row_end = row_begin + num_rows
        rows = token_pooling[row_begin:row_end]
        shift = int(patch_starts[image_idx])
        shifted_pooling[row_begin:row_end] = torch.where(
            rows >= 0,
            rows + shift,
            rows,
        )
        row_begin = row_end

    return Molmo2ImageInputs(
        pixel_values=pixel_values,
        token_pooling=shifted_pooling,
        num_pooled_patches=num_pooled_patches,
        image_tokens=image_tokens,
        num_image_tokens=num_image_tokens,
    )
|
|
|
|
def _parse_and_validate_video_input(
    self,
    **kwargs: object,
) -> Molmo2VideoInputs | None:
    """Collect the video tensors from ``kwargs`` into ``Molmo2VideoInputs``.

    Returns ``None`` when ``pixel_values_videos`` is absent. Otherwise the
    non-negative entries of ``video_token_pooling`` are shifted, per video,
    by the number of patches of all preceding videos; negative entries are
    left untouched.
    """
    pixel_values_videos = kwargs.get("pixel_values_videos")
    if pixel_values_videos is None:
        return None

    token_pooling = kwargs.get("video_token_pooling")
    num_pooled_patches = kwargs.get("video_num_pooled_patches")
    num_patches = kwargs.get("video_num_patches")
    video_tokens = kwargs.get("video_tokens")
    num_video_tokens = kwargs.get("num_video_tokens")

    # Exclusive prefix sum of the per-video patch counts: entry i is the
    # amount added to the pooling indices of video i.
    # NOTE(review): presumably this turns per-video patch indices into
    # indices over the concatenated patches - confirm against the backbone.
    patch_starts = [0] + num_patches.cumsum(dim=0)[:-1].tolist()

    shifted_pooling = token_pooling.clone()
    row_begin = 0
    for video_idx, num_rows in enumerate(num_pooled_patches):
        row_end = row_begin + num_rows
        rows = token_pooling[row_begin:row_end]
        shift = int(patch_starts[video_idx])
        shifted_pooling[row_begin:row_end] = torch.where(
            rows >= 0,
            rows + shift,
            rows,
        )
        row_begin = row_end

    return Molmo2VideoInputs(
        pixel_values_videos=pixel_values_videos,
        token_pooling=shifted_pooling,
        num_pooled_patches=num_pooled_patches,
        video_tokens=video_tokens,
        num_video_tokens=num_video_tokens,
    )
|
|
|
|
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
    """Parse the modalities present in ``kwargs``.

    Returns a dict with an ``"images"`` entry when ``pixel_values`` is among
    the kwargs and a ``"videos"`` entry when ``pixel_values_videos`` is.
    Entries are inserted in the order in which those keys appear in
    ``kwargs``.
    """
    parsers = {
        "pixel_values": ("images", self._parse_and_validate_image_input),
        "pixel_values_videos": ("videos", self._parse_and_validate_video_input),
    }

    parsed = {}
    # Walk the kwargs (rather than the table) so the result keeps the order
    # of the incoming keys.
    for input_key in kwargs:
        entry = parsers.get(input_key)
        if entry is None:
            continue
        modality, parse = entry
        if modality not in parsed:
            parsed[modality] = parse(**kwargs)
    return parsed
|
|
|
|
def _process_image_input(
    self,
    image_input: Molmo2ImageInputs,
) -> tuple[torch.Tensor, ...]:
    """Compute one embedding tensor per image.

    All images go through the vision backbone in a single call. For each
    image, its token ids are embedded with the language model and the rows
    whose token id equals ``self.img_patch_id`` are overwritten with that
    image's pooled features.
    """
    num_pooled_patches = image_input["num_pooled_patches"]
    image_tokens = image_input["image_tokens"]
    num_image_tokens = image_input["num_image_tokens"]

    pooled_features = self.vision_backbone(
        images=image_input["pixel_values"].unsqueeze(0),
        token_pooling=image_input["token_pooling"].unsqueeze(0),
    )

    assert len(pooled_features) == num_pooled_patches.sum()

    # Cut the flat tensors back into per-image pieces.
    features_per_image = pooled_features.split(num_pooled_patches.tolist(), dim=0)
    tokens_per_image = image_tokens.split(num_image_tokens.tolist(), dim=0)

    merged = []
    for features, token_ids in zip(features_per_image, tokens_per_image):
        embeddings = self.get_language_model().embed_input_ids(token_ids)
        embeddings[token_ids == self.img_patch_id] = features
        merged.append(embeddings)
    return tuple(merged)
|
|
|
|
def _process_video_input(
    self,
    video_input: Molmo2VideoInputs,
) -> tuple[torch.Tensor, ...]:
    """Compute one embedding tensor per video.

    All videos go through the vision backbone in a single call. For each
    video, its token ids are embedded with the language model and the rows
    whose token id equals ``self.img_patch_id`` are overwritten with that
    video's pooled features.
    """
    num_pooled_patches = video_input["num_pooled_patches"]
    video_tokens = video_input["video_tokens"]
    num_video_tokens = video_input["num_video_tokens"]

    pooled_features = self.vision_backbone(
        images=video_input["pixel_values_videos"].unsqueeze(0),
        token_pooling=video_input["token_pooling"].unsqueeze(0),
    )

    assert len(pooled_features) == num_pooled_patches.sum()

    # Cut the flat tensors back into per-video pieces.
    features_per_video = pooled_features.split(num_pooled_patches.tolist(), dim=0)
    tokens_per_video = video_tokens.split(num_video_tokens.tolist(), dim=0)

    merged = []
    for features, token_ids in zip(features_per_video, tokens_per_video):
        embeddings = self.get_language_model().embed_input_ids(token_ids)
        embeddings[token_ids == self.img_patch_id] = features
        merged.append(embeddings)
    return tuple(merged)
|
|
|
|
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
    """Compute the embeddings of all multimodal items found in ``kwargs``.

    Returns an empty list when no modality is present; otherwise a tuple
    with the per-item embeddings, grouped by modality in the order the
    modalities were parsed.
    """
    modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
    if not modalities:
        return []

    handlers = {
        "images": self._process_image_input,
        "videos": self._process_video_input,
    }

    multimodal_embeddings: tuple[torch.Tensor, ...] = ()
    for modality, parsed_input in modalities.items():
        handler = handlers.get(modality)
        if handler is not None:
            multimodal_embeddings += handler(parsed_input)

    return multimodal_embeddings
|
|
|
|
def embed_input_ids(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: MultiModalEmbeddings | None = None,
    *,
    is_multimodal: torch.Tensor | None = None,
    handle_oov_mm_token: bool = False,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in the multimodal embeddings, if any.

    Raises:
        ValueError: If multimodal embeddings are given without
            ``is_multimodal``.
    """
    text_embeds = self._embed_text_input_ids(
        input_ids,
        self.get_language_model().embed_input_ids,
        is_multimodal=is_multimodal,
        handle_oov_mm_token=handle_oov_mm_token,
    )

    has_mm_embeddings = (
        multimodal_embeddings is not None and len(multimodal_embeddings) != 0
    )
    if not has_mm_embeddings:
        # Text-only input: nothing to merge.
        return text_embeds

    if is_multimodal is None:
        raise ValueError(
            "`embed_input_ids` now requires `is_multimodal` arg, "
            "please update your model runner according to "
            "https://github.com/vllm-project/vllm/pull/16229."
        )

    return _merge_multimodal_embeddings(
        inputs_embeds=text_embeds,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_multimodal,
    )
|
|
|
|
def forward(
    self,
    input_ids: torch.LongTensor,
    positions: torch.LongTensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor:
    """Run the text model and return its output.

    ``inputs_embeds`` is dropped (passed as ``None``) whenever
    ``intermediate_tensors`` is provided.
    """
    effective_embeds = None if intermediate_tensors is not None else inputs_embeds

    return self.model(
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds=effective_embeds,
        **kwargs,
    )
|
|
|
|
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Apply the logits processor with the LM head to ``hidden_states``."""
    return self.logits_processor(self.lm_head, hidden_states)
|
|
|
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    """Load checkpoint weights into this model.

    The weights are first passed through
    ``_get_weights_with_merged_embedding`` and then loaded with
    ``AutoWeightsLoader`` using ``hf_to_vllm_mapper`` for name mapping.
    """
    loader = AutoWeightsLoader(self)
    merged_weights = _get_weights_with_merged_embedding(weights)
    return loader.load_weights(merged_weights, mapper=self.hf_to_vllm_mapper)
|
|
|
|
def get_mm_mapping(self) -> MultiModelKeys:
    """
    Get the module prefix in multimodal models
    """
    module_prefixes = {
        "language_model": "model",
        "connector": "vision_backbone.image_projector",
        "tower_model": "vision_backbone",
    }
    return MultiModelKeys.from_string_field(**module_prefixes)
|
|
|
|
|
|
def _get_weights_with_merged_embedding(
    weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
    """Yield ``weights`` with the two ``wte`` embedding tables merged.

    Entries whose name contains ``wte.embedding`` or ``wte.new_embedding``
    are held back; every other entry is yielded unchanged as it arrives.
    After the input is exhausted, the two held-back tensors are concatenated
    along dim 0 (``embedding`` first) and yielded as
    ``model.embed_tokens.weight``.

    Raises:
        ValueError: If either of the two embedding tensors was not found.
    """
    held_back: dict[str, torch.Tensor] = {}
    for name, weight in weights:
        if "wte.embedding" in name:
            held_back["embedding"] = weight
        elif "wte.new_embedding" in name:
            held_back["new_embedding"] = weight
        else:
            yield (name, weight)

    # this is compatible with most of quantization,
    # because they won't quantize embed_tokens
    missing = {"embedding", "new_embedding"} - held_back.keys()
    if missing:
        raise ValueError(
            "Checkpoint is missing 'wte.embedding' or "
            "'wte.new_embedding' weights required for Molmo2."
        )

    merged = torch.cat(
        [held_back["embedding"], held_back["new_embedding"]],
        dim=0,
    )
    yield ("model.embed_tokens.weight", merged)
|