Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -13,8 +13,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
TensorType)
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
@@ -23,43 +22,65 @@ from vllm.attention.layer import MultiHeadAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
SiluAndMul)
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
)
from vllm.model_executor.layers.activation import MulAndSilu, QuickGELU, SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptIndexTargets,
PromptInsertion,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsQuant)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
SupportsQuant,
)
from .utils import (
AutoWeightsLoader,
WeightsMapper,
flatten_bn,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
@@ -81,16 +102,22 @@ class MolmoImageInputs(TensorSchema):
- tp: Token sequence positions
- pd: Patch dimension
"""
images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"})]
images: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"}),
]
# Number of crops may vary per batch and image, so pass it as a list.
image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]],
TensorShape("bn", "nc", "np", dynamic_dims={"nc"})]
image_masks: Annotated[
Optional[Union[torch.Tensor, list[torch.Tensor]]],
TensorShape("bn", "nc", "np", dynamic_dims={"nc"}),
]
feat_is_patch: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"})]
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
]
# A boolean mask indicating which image features correspond to patch tokens.
num_crops: Annotated[torch.Tensor, TensorShape("bn")]
@@ -110,8 +137,7 @@ class VisionBackboneConfig:
image_norm_eps: float = 1e-5
def __post_init__(self):
self.image_default_input_size = tuple(
self.image_default_input_size) # type: ignore[assignment]
self.image_default_input_size = tuple(self.image_default_input_size) # type: ignore[assignment]
@property
def image_num_patch(self):
@@ -207,15 +233,13 @@ class MultiHeadDotProductAttention(nn.Module):
)
self.scale = self.head_dim**-0.5
self.attn = MultiHeadAttention(self.num_heads,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads)
def forward(self,
inputs_q: torch.Tensor,
inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
self.attn = MultiHeadAttention(
self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads
)
def forward(
self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None
) -> torch.Tensor:
if inputs_kv is not None:
inputs_k = inputs_kv
inputs_v = inputs_kv
@@ -242,8 +266,7 @@ class ResidualAttentionBlock(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.attention = MultiHeadDotProductAttention(
config, quant_config=quant_config)
self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config)
self.feed_forward = ViTMLP(config, quant_config)
self.attention_norm = nn.LayerNorm(
config.image_emb_dim,
@@ -269,10 +292,12 @@ class BlockCollection(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.resblocks = nn.ModuleList([
ResidualAttentionBlock(config, quant_config)
for _ in range(config.image_num_layers)
])
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(config, quant_config)
for _ in range(config.image_num_layers)
]
)
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
hidden_states = []
@@ -297,19 +322,18 @@ class VisionTransformer(nn.Module):
super().__init__()
scale = config.image_emb_dim**-0.5
self.patch_num = config.image_num_patch
self.class_embedding = nn.Parameter(
torch.randn(config.image_emb_dim) * scale)
self.class_embedding = nn.Parameter(torch.randn(config.image_emb_dim) * scale)
self.num_prefix_tokens: int = NUM_PREFIX_TOKENS
self.positional_embedding = nn.Parameter(
torch.randn(config.image_num_pos, config.image_emb_dim) * scale)
torch.randn(config.image_num_pos, config.image_emb_dim) * scale
)
image_patch_size = config.image_patch_size
self.patch_embedding = nn.Linear(
image_patch_size * image_patch_size * 3,
config.image_emb_dim,
bias=False,
)
self.pre_ln = nn.LayerNorm(config.image_emb_dim,
eps=config.image_norm_eps)
self.pre_ln = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps)
self.transformer = BlockCollection(config, quant_config)
def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
@@ -317,8 +341,12 @@ class VisionTransformer(nn.Module):
pos_emb = self.positional_embedding[1:]
pos_emb = pos_emb.reshape(
(int(math.sqrt(pos_emb.shape[0])),
int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]))
(
int(math.sqrt(pos_emb.shape[0])),
int(math.sqrt(pos_emb.shape[0])),
pos_emb.shape[1],
)
)
(patch_num_0, patch_num_1) = patch_num
@@ -335,13 +363,12 @@ class VisionTransformer(nn.Module):
pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]],
dim=1).to(x.dtype)
x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype)
return x
def forward(self,
x: torch.Tensor,
patch_num: Optional[int] = None) -> list[torch.Tensor]:
def forward(
self, x: torch.Tensor, patch_num: Optional[int] = None
) -> list[torch.Tensor]:
"""
: param x: (batch_size, num_patch, n_pixels)
"""
@@ -353,8 +380,8 @@ class VisionTransformer(nn.Module):
# class embeddings and positional embeddings
x = torch.cat(
[_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x],
dim=1)
[_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1
)
x = self.add_pos_emb(x, patch_num)
x = self.pre_ln(x)
@@ -382,8 +409,7 @@ class MolmoAttention(nn.Module):
assert self.total_num_heads % self.tp_size == 0
self.num_heads = self.total_num_heads // self.tp_size
self.total_num_kv_heads = config.num_key_value_heads \
or self.total_num_heads
self.total_num_kv_heads = config.num_key_value_heads or self.total_num_heads
if self.total_num_kv_heads >= self.tp_size:
assert self.total_num_kv_heads % self.tp_size == 0
else:
@@ -411,10 +437,10 @@ class MolmoAttention(nn.Module):
self.q_norm: Optional[nn.Module] = None
if config.attention_layer_norm:
self.tp_rank = get_tensor_model_parallel_rank()
self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
eps=config.layer_norm_eps)
self.q_norm = RMSNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.k_norm = RMSNorm(
self.total_num_kv_heads * self.head_dim, eps=config.layer_norm_eps
)
self.q_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
# Rotary embeddings.
self.rotary_emb = get_rope(
@@ -424,13 +450,15 @@ class MolmoAttention(nn.Module):
base=self.rope_theta,
)
self.scaling = self.head_dim**-0.5
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
# Attention output projection.
self.o_proj = RowParallelLinear(
@@ -440,16 +468,16 @@ class MolmoAttention(nn.Module):
quant_config=quant_config,
)
def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
def _apply_qk_norm(
self, q: torch.Tensor, k: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous())
k = tensor_model_parallel_all_gather(k.contiguous())
q = self.q_norm(q)
k = self.k_norm(k)
if self.tp_size > 1:
splitter = partial(split_tensor_along_last_dim,
num_partitions=self.tp_size)
splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
q = splitter(q)[self.tp_rank]
k = splitter(k)[self.tp_rank]
return q, k
@@ -472,10 +500,12 @@ class MolmoAttention(nn.Module):
class LanguageModelMLP(nn.Module):
"""Molmo's LLM mlp."""
def __init__(self,
config: PretrainedConfig,
input_dim: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
def __init__(
self,
config: PretrainedConfig,
input_dim: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size // 2
@@ -547,7 +577,6 @@ class ImageProjectorMLP(nn.Module):
class MolmoDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
@@ -557,20 +586,19 @@ class MolmoDecoderLayer(nn.Module):
) -> None:
super().__init__()
# Attention block.
self.self_attn = MolmoAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.self_attn")
self.self_attn = MolmoAttention(
config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
)
# MLP block.
self.mlp = LanguageModelMLP(config, quant_config=quant_config)
# LayerNorm
assert config.layer_norm_type == "rms"
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.layer_norm_eps
)
def forward(
self,
@@ -583,21 +611,18 @@ class MolmoDecoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
def forward(
self,
positions: torch.Tensor,
@@ -638,16 +663,14 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
(self.image_num_patch[0] + 1) // POOLING_SIZE,
(self.image_num_patch[1] + 1) // POOLING_SIZE,
)
self.image_vit = VisionTransformer(vision_config,
quant_config=quant_config)
self.image_vit = VisionTransformer(vision_config, quant_config=quant_config)
self.num_prefix_tokens = self.image_vit.num_prefix_tokens
assert self.num_prefix_tokens in {
0, 1
}, "Only 0 or 1 prefix tokens are supported"
assert self.num_prefix_tokens in {0, 1}, (
"Only 0 or 1 prefix tokens are supported"
)
self.image_pooling_2d = MultiHeadDotProductAttention(
vision_config,
nlayers=len(self.vit_layers),
quant_config=quant_config)
vision_config, nlayers=len(self.vit_layers), quant_config=quant_config
)
self.image_projector = ImageProjectorMLP(
config,
input_dim=vision_config.image_emb_dim,
@@ -671,8 +694,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
"""
B, T, N, D = images.shape
mask = ~torch.all(
images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)
mask = ~torch.all(images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)
images = images.view(B * T, N, D)
image_features = self.image_vit(images)
@@ -707,21 +729,22 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
assert image_masks is not None
pad_embed = self.pad_embed[:, None, None, None, :]
all_pad = image_masks == 0
partial_pad = torch.logical_and(
image_masks < 1,
torch.logical_not(all_pad)).to(dtype=torch.float32)
partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(
dtype=torch.float32
)
all_pad = all_pad.to(dtype=torch.float32)
image_features = image_features + pad_embed[0] * torch.unsqueeze(
all_pad, -1)
image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
image_features = image_features + pad_embed[1] * torch.unsqueeze(
partial_pad, -1)
partial_pad, -1
)
image_features = image_features.to(og_dtype)
image_features = image_features.reshape(
(batch_size, num_image) + self.image_num_patch + (-1, ), )
(batch_size, num_image) + self.image_num_patch + (-1,),
)
if (missing_w := self.image_num_patch[0] % POOLING_SIZE):
if missing_w := self.image_num_patch[0] % POOLING_SIZE:
# Padding for image pooling (see below)
image_features = F.pad(
image_features,
@@ -731,7 +754,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
# image pooling
image_features = rearrange(
image_features,
'b n (h dh) (w dw) c -> (b n h w) (dh dw) c',
"b n (h dh) (w dw) c -> (b n h w) (dh dw) c",
dh=POOLING_SIZE,
dw=POOLING_SIZE,
)
@@ -747,8 +770,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
# image_features: (batch_size, num_image, num_patch, d_model)
return image_features
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("merged_linear", "gate_proj", 0),
@@ -758,7 +780,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
@@ -777,8 +799,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
@@ -786,7 +807,6 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
@support_torch_compile
class MolmoModel(nn.Module, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -804,21 +824,23 @@ class MolmoModel(nn.Module, SupportsQuant):
quant_config=quant_config,
)
decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after \
else MolmoDecoderLayer
decoder_layer = (
MolmoDecoderNormAfterLayer if config.norm_after else MolmoDecoderLayer
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: decoder_layer(
config, cache_config, quant_config, prefix=prefix),
config, cache_config, quant_config, prefix=prefix
),
prefix=f"{prefix}.layers",
)
assert config.layer_norm_type == "rms"
self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@@ -849,18 +871,16 @@ class MolmoModel(nn.Module, SupportsQuant):
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
if residual is not None:
hidden_states, _ = self.norm(hidden_states, residual)
else:
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
@@ -871,8 +891,7 @@ class MolmoModel(nn.Module, SupportsQuant):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
@@ -939,8 +958,12 @@ def get_patches_grid_size(
def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]:
tilings = [(i, j) for i in range(1, max_num + 1)
for j in range(1, max_num + 1) if i * j <= max_num]
tilings = [
(i, j)
for i in range(1, max_num + 1)
for j in range(1, max_num + 1)
if i * j <= max_num
]
return sorted(tilings, key=lambda x: x[0] * x[1])
@@ -1128,7 +1151,8 @@ class MolmoProcessorWrapper:
**kwargs,
) -> BatchFeature:
outputs = self.processor.process( # type: ignore
text, images, **kwargs)
text, images, **kwargs
)
if images is None:
images = []
@@ -1146,7 +1170,8 @@ class MolmoProcessorWrapper:
self.select_tiling(
image_width=image.size[0],
image_height=image.size[1],
) for image in images
)
for image in images
]
# For each image: tiling_h * tiling_w + extra
num_crops = torch.tensor(tilings).prod(-1) + 1
@@ -1160,7 +1185,6 @@ class MolmoProcessorWrapper:
class MolmoProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper:
processor = self.ctx.get_hf_processor(**kwargs)
return MolmoProcessorWrapper(processor)
@@ -1209,8 +1233,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width,
height=height)
largest_feature_pinpoint = ImageSize(width=width, height=height)
if largest_feature_size == 0 or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
@@ -1219,7 +1242,6 @@ class MolmoProcessingInfo(BaseProcessingInfo):
class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
@@ -1229,23 +1251,22 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_width, target_height = self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides)
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides,
)
}
class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
@@ -1263,7 +1284,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
processor, # type: ignore
dict(tokens=tokens),
)
prompt_ids, = processed_data.pop("input_ids").tolist()
(prompt_ids,) = processed_data.pop("input_ids").tolist()
return prompt_ids
@@ -1277,10 +1298,8 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
return dict(
images=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
image_masks=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
feat_is_patch=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
num_crops=MultiModalFieldConfig.batched("image"),
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
)
@@ -1303,8 +1322,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
img_end_id = processor.im_end_id
extra_row = [img_patch_id] * image_token_length_w + [img_col_id]
extra_joint = ([img_start_id] + extra_row * image_token_length_h +
[img_end_id])
extra_joint = [img_start_id] + extra_row * image_token_length_h + [img_end_id]
def get_insertion_molmo(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
@@ -1315,10 +1333,12 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
image_height=image_size.height,
)
joint_row = ([img_patch_id] * ((ncols + 1) // pooling_size) +
[img_col_id])
joint = ([img_start_id] + joint_row *
((nrows + 1) // pooling_size) + [img_end_id])
joint_row = [img_patch_id] * ((ncols + 1) // pooling_size) + [img_col_id]
joint = (
[img_start_id]
+ joint_row * ((nrows + 1) // pooling_size)
+ [img_end_id]
)
return PromptUpdateDetails.select_token_id(
extra_joint + joint,
@@ -1334,11 +1354,14 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
]
@MULTIMODAL_REGISTRY.register_processor(MolmoMultiModalProcessor,
info=MolmoProcessingInfo,
dummy_inputs=MolmoDummyInputsBuilder)
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
SupportsQuant):
@MULTIMODAL_REGISTRY.register_processor(
MolmoMultiModalProcessor,
info=MolmoProcessingInfo,
dummy_inputs=MolmoDummyInputsBuilder,
)
class MolmoForCausalLM(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant
):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
# vision backbone mapping
@@ -1370,7 +1393,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
packed_modules_mapping = {
"qkv_proj": ["qkv_proj"],
"gate_up_proj": ["gate_up_proj"], # language model
"merged_linear": ["gate_proj", "up_proj"] # image_projector
"merged_linear": ["gate_proj", "up_proj"], # image_projector
}
@classmethod
@@ -1391,10 +1414,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
self.lora_config = lora_config
vision_config = VisionBackboneConfig()
self.vision_backbone = MolmoVisionBackbone(config, vision_config,
quant_config)
self.model = MolmoModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config)
self.model = MolmoModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.img_patch_id = None
if self.config.weight_tying:
@@ -1407,11 +1430,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(config.embedding_size
or config.vocab_size)
self.logits_processor = LogitsProcessor(
config.embedding_size or config.vocab_size
)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.model.make_empty_intermediate_tensors
)
def _parse_and_validate_image_input(
self,
@@ -1426,14 +1451,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
return None
if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")
raise ValueError(
f"Incorrect type of num_crops. Got type: {type(num_crops)}"
)
num_crops = flatten_bn(num_crops, concat=True)
img_patch_id = kwargs.pop("img_patch_id", None)
if not isinstance(img_patch_id, torch.Tensor):
raise ValueError("Incorrect type of img_patch_id. "
f"Got type: {type(img_patch_id)}")
raise ValueError(
f"Incorrect type of img_patch_id. Got type: {type(img_patch_id)}"
)
self.img_patch_id = img_patch_id.flatten().unique().item()
return MolmoImageInputs(
@@ -1454,19 +1481,22 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
# Call the vision backbone on the whole batch at once
images_flat = flatten_bn(images, concat=True)
image_masks_flat = (None if image_masks is None else flatten_bn(
image_masks, concat=True))
image_masks_flat = (
None if image_masks is None else flatten_bn(image_masks, concat=True)
)
feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0),
image_masks=(None if image_masks_flat is None else
image_masks_flat.unsqueeze(0)),
image_masks=(
None if image_masks_flat is None else image_masks_flat.unsqueeze(0)
),
).squeeze(0)
# Only the features corresponding to patch tokens are relevant
return [
feats[f_is_patch] for feats, f_is_patch in zip(
feats[f_is_patch]
for feats, f_is_patch in zip(
image_features_flat.split(num_crops.tolist()),
feat_is_patch_flat.split(num_crops.tolist()),
)
@@ -1475,8 +1505,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
@@ -1491,14 +1520,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> torch.Tensor:
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.model(input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds)
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)
return hidden_states
@@ -1507,7 +1534,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
weights = _get_weights_with_merged_embedding(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@@ -1524,7 +1550,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
def _get_weights_with_merged_embedding(
weights: Iterable[tuple[str, torch.Tensor]]
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
embedding_weights = {}
for name, weight in weights: