Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -13,8 +13,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
|
||||
TensorType)
|
||||
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorType
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
@@ -23,43 +22,65 @@ from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
|
||||
SiluAndMul)
|
||||
from vllm.distributed import (
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import MulAndSilu, QuickGELU, SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptIndexTargets,
|
||||
PromptInsertion, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
)
|
||||
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
|
||||
from vllm.multimodal.processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
PromptIndexTargets,
|
||||
PromptInsertion,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP, SupportsQuant)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsQuant,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
flatten_bn,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
# TODO: hard-coded for now. Consider making it configurable.
|
||||
VIT_LAYERS = [-2, -9]
|
||||
@@ -81,16 +102,22 @@ class MolmoImageInputs(TensorSchema):
|
||||
- tp: Token sequence positions
|
||||
- pd: Patch dimension
|
||||
"""
|
||||
images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"})]
|
||||
|
||||
images: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"}),
|
||||
]
|
||||
# Number of crops may vary per batch and image, so pass it as a list.
|
||||
|
||||
image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]],
|
||||
TensorShape("bn", "nc", "np", dynamic_dims={"nc"})]
|
||||
image_masks: Annotated[
|
||||
Optional[Union[torch.Tensor, list[torch.Tensor]]],
|
||||
TensorShape("bn", "nc", "np", dynamic_dims={"nc"}),
|
||||
]
|
||||
|
||||
feat_is_patch: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"})]
|
||||
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
|
||||
]
|
||||
# A boolean mask indicating which image features correspond to patch tokens.
|
||||
num_crops: Annotated[torch.Tensor, TensorShape("bn")]
|
||||
|
||||
@@ -110,8 +137,7 @@ class VisionBackboneConfig:
|
||||
image_norm_eps: float = 1e-5
|
||||
|
||||
def __post_init__(self):
|
||||
self.image_default_input_size = tuple(
|
||||
self.image_default_input_size) # type: ignore[assignment]
|
||||
self.image_default_input_size = tuple(self.image_default_input_size) # type: ignore[assignment]
|
||||
|
||||
@property
|
||||
def image_num_patch(self):
|
||||
@@ -207,15 +233,13 @@ class MultiHeadDotProductAttention(nn.Module):
|
||||
)
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.attn = MultiHeadAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scale,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
|
||||
def forward(self,
|
||||
inputs_q: torch.Tensor,
|
||||
inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
self.attn = MultiHeadAttention(
|
||||
self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
if inputs_kv is not None:
|
||||
inputs_k = inputs_kv
|
||||
inputs_v = inputs_kv
|
||||
@@ -242,8 +266,7 @@ class ResidualAttentionBlock(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.attention = MultiHeadDotProductAttention(
|
||||
config, quant_config=quant_config)
|
||||
self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config)
|
||||
self.feed_forward = ViTMLP(config, quant_config)
|
||||
self.attention_norm = nn.LayerNorm(
|
||||
config.image_emb_dim,
|
||||
@@ -269,10 +292,12 @@ class BlockCollection(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.resblocks = nn.ModuleList([
|
||||
ResidualAttentionBlock(config, quant_config)
|
||||
for _ in range(config.image_num_layers)
|
||||
])
|
||||
self.resblocks = nn.ModuleList(
|
||||
[
|
||||
ResidualAttentionBlock(config, quant_config)
|
||||
for _ in range(config.image_num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
||||
hidden_states = []
|
||||
@@ -297,19 +322,18 @@ class VisionTransformer(nn.Module):
|
||||
super().__init__()
|
||||
scale = config.image_emb_dim**-0.5
|
||||
self.patch_num = config.image_num_patch
|
||||
self.class_embedding = nn.Parameter(
|
||||
torch.randn(config.image_emb_dim) * scale)
|
||||
self.class_embedding = nn.Parameter(torch.randn(config.image_emb_dim) * scale)
|
||||
self.num_prefix_tokens: int = NUM_PREFIX_TOKENS
|
||||
self.positional_embedding = nn.Parameter(
|
||||
torch.randn(config.image_num_pos, config.image_emb_dim) * scale)
|
||||
torch.randn(config.image_num_pos, config.image_emb_dim) * scale
|
||||
)
|
||||
image_patch_size = config.image_patch_size
|
||||
self.patch_embedding = nn.Linear(
|
||||
image_patch_size * image_patch_size * 3,
|
||||
config.image_emb_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.pre_ln = nn.LayerNorm(config.image_emb_dim,
|
||||
eps=config.image_norm_eps)
|
||||
self.pre_ln = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps)
|
||||
self.transformer = BlockCollection(config, quant_config)
|
||||
|
||||
def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
|
||||
@@ -317,8 +341,12 @@ class VisionTransformer(nn.Module):
|
||||
pos_emb = self.positional_embedding[1:]
|
||||
|
||||
pos_emb = pos_emb.reshape(
|
||||
(int(math.sqrt(pos_emb.shape[0])),
|
||||
int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]))
|
||||
(
|
||||
int(math.sqrt(pos_emb.shape[0])),
|
||||
int(math.sqrt(pos_emb.shape[0])),
|
||||
pos_emb.shape[1],
|
||||
)
|
||||
)
|
||||
|
||||
(patch_num_0, patch_num_1) = patch_num
|
||||
|
||||
@@ -335,13 +363,12 @@ class VisionTransformer(nn.Module):
|
||||
pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
|
||||
|
||||
pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
|
||||
x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]],
|
||||
dim=1).to(x.dtype)
|
||||
x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype)
|
||||
return x
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
patch_num: Optional[int] = None) -> list[torch.Tensor]:
|
||||
def forward(
|
||||
self, x: torch.Tensor, patch_num: Optional[int] = None
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
: param x: (batch_size, num_patch, n_pixels)
|
||||
"""
|
||||
@@ -353,8 +380,8 @@ class VisionTransformer(nn.Module):
|
||||
|
||||
# class embeddings and positional embeddings
|
||||
x = torch.cat(
|
||||
[_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x],
|
||||
dim=1)
|
||||
[_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1
|
||||
)
|
||||
x = self.add_pos_emb(x, patch_num)
|
||||
|
||||
x = self.pre_ln(x)
|
||||
@@ -382,8 +409,7 @@ class MolmoAttention(nn.Module):
|
||||
assert self.total_num_heads % self.tp_size == 0
|
||||
|
||||
self.num_heads = self.total_num_heads // self.tp_size
|
||||
self.total_num_kv_heads = config.num_key_value_heads \
|
||||
or self.total_num_heads
|
||||
self.total_num_kv_heads = config.num_key_value_heads or self.total_num_heads
|
||||
if self.total_num_kv_heads >= self.tp_size:
|
||||
assert self.total_num_kv_heads % self.tp_size == 0
|
||||
else:
|
||||
@@ -411,10 +437,10 @@ class MolmoAttention(nn.Module):
|
||||
self.q_norm: Optional[nn.Module] = None
|
||||
if config.attention_layer_norm:
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.q_norm = RMSNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.k_norm = RMSNorm(
|
||||
self.total_num_kv_heads * self.head_dim, eps=config.layer_norm_eps
|
||||
)
|
||||
self.q_norm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
# Rotary embeddings.
|
||||
self.rotary_emb = get_rope(
|
||||
@@ -424,13 +450,15 @@ class MolmoAttention(nn.Module):
|
||||
base=self.rope_theta,
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
# Attention output projection.
|
||||
self.o_proj = RowParallelLinear(
|
||||
@@ -440,16 +468,16 @@ class MolmoAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def _apply_qk_norm(self, q: torch.Tensor,
|
||||
k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def _apply_qk_norm(
|
||||
self, q: torch.Tensor, k: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.tp_size > 1:
|
||||
q = tensor_model_parallel_all_gather(q.contiguous())
|
||||
k = tensor_model_parallel_all_gather(k.contiguous())
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
if self.tp_size > 1:
|
||||
splitter = partial(split_tensor_along_last_dim,
|
||||
num_partitions=self.tp_size)
|
||||
splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
|
||||
q = splitter(q)[self.tp_rank]
|
||||
k = splitter(k)[self.tp_rank]
|
||||
return q, k
|
||||
@@ -472,10 +500,12 @@ class MolmoAttention(nn.Module):
|
||||
class LanguageModelMLP(nn.Module):
|
||||
"""Molmo's LLM mlp."""
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
input_dim: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
input_dim: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size // 2
|
||||
@@ -547,7 +577,6 @@ class ImageProjectorMLP(nn.Module):
|
||||
|
||||
|
||||
class MolmoDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
@@ -557,20 +586,19 @@ class MolmoDecoderLayer(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# Attention block.
|
||||
self.self_attn = MolmoAttention(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
self.self_attn = MolmoAttention(
|
||||
config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
|
||||
)
|
||||
|
||||
# MLP block.
|
||||
self.mlp = LanguageModelMLP(config, quant_config=quant_config)
|
||||
|
||||
# LayerNorm
|
||||
assert config.layer_norm_type == "rms"
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -583,21 +611,18 @@ class MolmoDecoderLayer(nn.Module):
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -638,16 +663,14 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
(self.image_num_patch[0] + 1) // POOLING_SIZE,
|
||||
(self.image_num_patch[1] + 1) // POOLING_SIZE,
|
||||
)
|
||||
self.image_vit = VisionTransformer(vision_config,
|
||||
quant_config=quant_config)
|
||||
self.image_vit = VisionTransformer(vision_config, quant_config=quant_config)
|
||||
self.num_prefix_tokens = self.image_vit.num_prefix_tokens
|
||||
assert self.num_prefix_tokens in {
|
||||
0, 1
|
||||
}, "Only 0 or 1 prefix tokens are supported"
|
||||
assert self.num_prefix_tokens in {0, 1}, (
|
||||
"Only 0 or 1 prefix tokens are supported"
|
||||
)
|
||||
self.image_pooling_2d = MultiHeadDotProductAttention(
|
||||
vision_config,
|
||||
nlayers=len(self.vit_layers),
|
||||
quant_config=quant_config)
|
||||
vision_config, nlayers=len(self.vit_layers), quant_config=quant_config
|
||||
)
|
||||
self.image_projector = ImageProjectorMLP(
|
||||
config,
|
||||
input_dim=vision_config.image_emb_dim,
|
||||
@@ -671,8 +694,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
"""
|
||||
B, T, N, D = images.shape
|
||||
|
||||
mask = ~torch.all(
|
||||
images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)
|
||||
mask = ~torch.all(images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)
|
||||
|
||||
images = images.view(B * T, N, D)
|
||||
image_features = self.image_vit(images)
|
||||
@@ -707,21 +729,22 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
assert image_masks is not None
|
||||
pad_embed = self.pad_embed[:, None, None, None, :]
|
||||
all_pad = image_masks == 0
|
||||
partial_pad = torch.logical_and(
|
||||
image_masks < 1,
|
||||
torch.logical_not(all_pad)).to(dtype=torch.float32)
|
||||
partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
all_pad = all_pad.to(dtype=torch.float32)
|
||||
image_features = image_features + pad_embed[0] * torch.unsqueeze(
|
||||
all_pad, -1)
|
||||
image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
|
||||
image_features = image_features + pad_embed[1] * torch.unsqueeze(
|
||||
partial_pad, -1)
|
||||
partial_pad, -1
|
||||
)
|
||||
|
||||
image_features = image_features.to(og_dtype)
|
||||
|
||||
image_features = image_features.reshape(
|
||||
(batch_size, num_image) + self.image_num_patch + (-1, ), )
|
||||
(batch_size, num_image) + self.image_num_patch + (-1,),
|
||||
)
|
||||
|
||||
if (missing_w := self.image_num_patch[0] % POOLING_SIZE):
|
||||
if missing_w := self.image_num_patch[0] % POOLING_SIZE:
|
||||
# Padding for image pooling (see below)
|
||||
image_features = F.pad(
|
||||
image_features,
|
||||
@@ -731,7 +754,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
# image pooling
|
||||
image_features = rearrange(
|
||||
image_features,
|
||||
'b n (h dh) (w dw) c -> (b n h w) (dh dw) c',
|
||||
"b n (h dh) (w dw) c -> (b n h w) (dh dw) c",
|
||||
dh=POOLING_SIZE,
|
||||
dw=POOLING_SIZE,
|
||||
)
|
||||
@@ -747,8 +770,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
# image_features: (batch_size, num_image, num_patch, d_model)
|
||||
return image_features
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("merged_linear", "gate_proj", 0),
|
||||
@@ -758,7 +780,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
@@ -777,8 +799,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
@@ -786,7 +807,6 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
|
||||
@support_torch_compile
|
||||
class MolmoModel(nn.Module, SupportsQuant):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
@@ -804,21 +824,23 @@ class MolmoModel(nn.Module, SupportsQuant):
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after \
|
||||
else MolmoDecoderLayer
|
||||
decoder_layer = (
|
||||
MolmoDecoderNormAfterLayer if config.norm_after else MolmoDecoderLayer
|
||||
)
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: decoder_layer(
|
||||
config, cache_config, quant_config, prefix=prefix),
|
||||
config, cache_config, quant_config, prefix=prefix
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
|
||||
assert config.layer_norm_type == "rms"
|
||||
self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
@@ -849,18 +871,16 @@ class MolmoModel(nn.Module, SupportsQuant):
|
||||
residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
if residual is not None:
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
else:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
@@ -871,8 +891,7 @@ class MolmoModel(nn.Module, SupportsQuant):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
@@ -939,8 +958,12 @@ def get_patches_grid_size(
|
||||
|
||||
|
||||
def get_candidate_tilings(max_num: int) -> list[tuple[int, int]]:
|
||||
tilings = [(i, j) for i in range(1, max_num + 1)
|
||||
for j in range(1, max_num + 1) if i * j <= max_num]
|
||||
tilings = [
|
||||
(i, j)
|
||||
for i in range(1, max_num + 1)
|
||||
for j in range(1, max_num + 1)
|
||||
if i * j <= max_num
|
||||
]
|
||||
return sorted(tilings, key=lambda x: x[0] * x[1])
|
||||
|
||||
|
||||
@@ -1128,7 +1151,8 @@ class MolmoProcessorWrapper:
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
outputs = self.processor.process( # type: ignore
|
||||
text, images, **kwargs)
|
||||
text, images, **kwargs
|
||||
)
|
||||
|
||||
if images is None:
|
||||
images = []
|
||||
@@ -1146,7 +1170,8 @@ class MolmoProcessorWrapper:
|
||||
self.select_tiling(
|
||||
image_width=image.size[0],
|
||||
image_height=image.size[1],
|
||||
) for image in images
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
# For each image: tiling_h * tiling_w + extra
|
||||
num_crops = torch.tensor(tilings).prod(-1) + 1
|
||||
@@ -1160,7 +1185,6 @@ class MolmoProcessorWrapper:
|
||||
|
||||
|
||||
class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper:
|
||||
processor = self.ctx.get_hf_processor(**kwargs)
|
||||
return MolmoProcessorWrapper(processor)
|
||||
@@ -1209,8 +1233,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
)
|
||||
if feat_size > largest_feature_size:
|
||||
largest_feature_size = feat_size
|
||||
largest_feature_pinpoint = ImageSize(width=width,
|
||||
height=height)
|
||||
largest_feature_pinpoint = ImageSize(width=width, height=height)
|
||||
|
||||
if largest_feature_size == 0 or largest_feature_pinpoint is None:
|
||||
raise ValueError("Cannot have a largest feature size of 0!")
|
||||
@@ -1219,7 +1242,6 @@ class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
|
||||
class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
return ""
|
||||
|
||||
@@ -1229,23 +1251,22 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
target_width, target_height = self.info.get_image_size_with_most_features()
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides)
|
||||
"image": self._get_dummy_images(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
|
||||
def _apply_hf_processor_tokens_only(
|
||||
self,
|
||||
prompt_tokens: list[int],
|
||||
@@ -1263,7 +1284,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
processor, # type: ignore
|
||||
dict(tokens=tokens),
|
||||
)
|
||||
prompt_ids, = processed_data.pop("input_ids").tolist()
|
||||
(prompt_ids,) = processed_data.pop("input_ids").tolist()
|
||||
|
||||
return prompt_ids
|
||||
|
||||
@@ -1277,10 +1298,8 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
|
||||
return dict(
|
||||
images=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
|
||||
image_masks=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", num_crops),
|
||||
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", num_crops),
|
||||
image_masks=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
|
||||
feat_is_patch=MultiModalFieldConfig.flat_from_sizes("image", num_crops),
|
||||
num_crops=MultiModalFieldConfig.batched("image"),
|
||||
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
)
|
||||
@@ -1303,8 +1322,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
img_end_id = processor.im_end_id
|
||||
|
||||
extra_row = [img_patch_id] * image_token_length_w + [img_col_id]
|
||||
extra_joint = ([img_start_id] + extra_row * image_token_length_h +
|
||||
[img_end_id])
|
||||
extra_joint = [img_start_id] + extra_row * image_token_length_h + [img_end_id]
|
||||
|
||||
def get_insertion_molmo(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
@@ -1315,10 +1333,12 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
image_height=image_size.height,
|
||||
)
|
||||
|
||||
joint_row = ([img_patch_id] * ((ncols + 1) // pooling_size) +
|
||||
[img_col_id])
|
||||
joint = ([img_start_id] + joint_row *
|
||||
((nrows + 1) // pooling_size) + [img_end_id])
|
||||
joint_row = [img_patch_id] * ((ncols + 1) // pooling_size) + [img_col_id]
|
||||
joint = (
|
||||
[img_start_id]
|
||||
+ joint_row * ((nrows + 1) // pooling_size)
|
||||
+ [img_end_id]
|
||||
)
|
||||
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
extra_joint + joint,
|
||||
@@ -1334,11 +1354,14 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
]
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(MolmoMultiModalProcessor,
|
||||
info=MolmoProcessingInfo,
|
||||
dummy_inputs=MolmoDummyInputsBuilder)
|
||||
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
SupportsQuant):
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
MolmoMultiModalProcessor,
|
||||
info=MolmoProcessingInfo,
|
||||
dummy_inputs=MolmoDummyInputsBuilder,
|
||||
)
|
||||
class MolmoForCausalLM(
|
||||
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant
|
||||
):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
# vision backbone mapping
|
||||
@@ -1370,7 +1393,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["qkv_proj"],
|
||||
"gate_up_proj": ["gate_up_proj"], # language model
|
||||
"merged_linear": ["gate_proj", "up_proj"] # image_projector
|
||||
"merged_linear": ["gate_proj", "up_proj"], # image_projector
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -1391,10 +1414,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
self.lora_config = lora_config
|
||||
|
||||
vision_config = VisionBackboneConfig()
|
||||
self.vision_backbone = MolmoVisionBackbone(config, vision_config,
|
||||
quant_config)
|
||||
self.model = MolmoModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config)
|
||||
self.model = MolmoModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
self.img_patch_id = None
|
||||
|
||||
if self.config.weight_tying:
|
||||
@@ -1407,11 +1430,13 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.embedding_size
|
||||
or config.vocab_size)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
config.embedding_size or config.vocab_size
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self,
|
||||
@@ -1426,14 +1451,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
return None
|
||||
|
||||
if not isinstance(num_crops, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_crops. "
|
||||
f"Got type: {type(num_crops)}")
|
||||
raise ValueError(
|
||||
f"Incorrect type of num_crops. Got type: {type(num_crops)}"
|
||||
)
|
||||
num_crops = flatten_bn(num_crops, concat=True)
|
||||
|
||||
img_patch_id = kwargs.pop("img_patch_id", None)
|
||||
if not isinstance(img_patch_id, torch.Tensor):
|
||||
raise ValueError("Incorrect type of img_patch_id. "
|
||||
f"Got type: {type(img_patch_id)}")
|
||||
raise ValueError(
|
||||
f"Incorrect type of img_patch_id. Got type: {type(img_patch_id)}"
|
||||
)
|
||||
self.img_patch_id = img_patch_id.flatten().unique().item()
|
||||
|
||||
return MolmoImageInputs(
|
||||
@@ -1454,19 +1481,22 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
|
||||
# Call the vision backbone on the whole batch at once
|
||||
images_flat = flatten_bn(images, concat=True)
|
||||
image_masks_flat = (None if image_masks is None else flatten_bn(
|
||||
image_masks, concat=True))
|
||||
image_masks_flat = (
|
||||
None if image_masks is None else flatten_bn(image_masks, concat=True)
|
||||
)
|
||||
feat_is_patch_flat = flatten_bn(feat_is_patch, concat=True)
|
||||
|
||||
image_features_flat = self.vision_backbone(
|
||||
images=images_flat.unsqueeze(0),
|
||||
image_masks=(None if image_masks_flat is None else
|
||||
image_masks_flat.unsqueeze(0)),
|
||||
image_masks=(
|
||||
None if image_masks_flat is None else image_masks_flat.unsqueeze(0)
|
||||
),
|
||||
).squeeze(0)
|
||||
|
||||
# Only the features corresponding to patch tokens are relevant
|
||||
return [
|
||||
feats[f_is_patch] for feats, f_is_patch in zip(
|
||||
feats[f_is_patch]
|
||||
for feats, f_is_patch in zip(
|
||||
image_features_flat.split(num_crops.tolist()),
|
||||
feat_is_patch_flat.split(num_crops.tolist()),
|
||||
)
|
||||
@@ -1475,8 +1505,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.model
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return []
|
||||
@@ -1491,14 +1520,12 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.model(input_ids,
|
||||
positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -1507,7 +1534,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
|
||||
loader = AutoWeightsLoader(self)
|
||||
weights = _get_weights_with_merged_embedding(weights)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
@@ -1524,7 +1550,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
|
||||
|
||||
def _get_weights_with_merged_embedding(
|
||||
weights: Iterable[tuple[str, torch.Tensor]]
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
) -> Iterable[tuple[str, torch.Tensor]]:
|
||||
embedding_weights = {}
|
||||
for name, weight in weights:
|
||||
|
||||
Reference in New Issue
Block a user