From 5576227bc132996d9e4ca279a8a8cf9e7b8a8ca4 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 8 Jan 2026 18:33:16 +0800 Subject: [PATCH] [Model] Standardize common vision encoders (#31947) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/clip.py | 69 ++++++++-- vllm/model_executor/models/deepencoder.py | 3 + vllm/model_executor/models/deepseek_ocr.py | 1 + .../models/hyperclovax_vision.py | 19 ++- vllm/model_executor/models/isaac.py | 2 +- vllm/model_executor/models/keye.py | 2 + vllm/model_executor/models/lightonocr.py | 4 +- vllm/model_executor/models/llava.py | 9 +- vllm/model_executor/models/llava_next.py | 4 +- .../model_executor/models/llava_next_video.py | 4 +- vllm/model_executor/models/llava_onevision.py | 3 +- vllm/model_executor/models/minimax_vl_01.py | 3 +- vllm/model_executor/models/mistral3.py | 7 +- vllm/model_executor/models/paddleocr_vl.py | 44 +----- vllm/model_executor/models/phi3v.py | 62 ++++----- vllm/model_executor/models/pixtral.py | 45 ++++++- vllm/model_executor/models/siglip.py | 126 ++++++++++-------- vllm/model_executor/models/siglip2navit.py | 8 +- vllm/model_executor/models/tarsier.py | 11 +- 19 files changed, 253 insertions(+), 173 deletions(-) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 8e77b36e6..6ec700a1c 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -17,7 +17,7 @@ from transformers import ( from vllm.attention.layer import Attention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions +from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.conv import Conv2dLayer @@ -353,6 +353,7 @@ class CLIPAttention(nn.Module): self, config: CLIPTextConfig | CLIPVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, prefix: str = "", attn_cls: type[Attention] | type[MMEncoderAttention], @@ -365,18 +366,24 @@ class CLIPAttention(nn.Module): self.head_dim = self.embed_dim // self.num_heads if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( - "embed_dim must be divisible by num_heads " - f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." + f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and " + f"`num_heads`: {self.num_heads})." 
) self.scale = self.head_dim**-0.5 + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.qkv_proj = QKVParallelLinear( hidden_size=self.embed_dim, head_size=self.head_dim, total_num_heads=self.num_heads, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", + disable_tp=use_data_parallel, ) self.out_proj = RowParallelLinear( @@ -384,17 +391,29 @@ class CLIPAttention(nn.Module): output_size=self.embed_dim, quant_config=quant_config, prefix=f"{prefix}.out_proj", + disable_tp=use_data_parallel, ) - self.tp_size = get_tensor_model_parallel_world_size() + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = attn_cls( - self.num_heads_per_partition, - self.head_dim, - self.scale, - prefix=f"{prefix}.attn", - ) + if attn_cls == MMEncoderAttention: + self.attn = attn_cls( + self.num_heads_per_partition, + self.head_dim, + self.scale, + prefix=f"{prefix}.attn", + multimodal_config=multimodal_config, + ) + else: + self.attn = attn_cls( + self.num_heads_per_partition, + self.head_dim, + self.scale, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -415,17 +434,26 @@ class CLIPMLP(nn.Module): self, config: CLIPTextConfig | CLIPVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", ) -> None: super().__init__() + self.config = config + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( config.hidden_size, config.intermediate_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, ) self.fc2 = RowParallelLinear( config.intermediate_size, @@ -433,6 +461,7 @@ class CLIPMLP(nn.Module): bias=True, quant_config=quant_config, prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -448,19 +477,27 @@ class CLIPEncoderLayer(nn.Module): self, config: CLIPTextConfig | CLIPVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, prefix: str = "", attn_cls: type[Attention] | type[MMEncoderAttention], ) -> None: super().__init__() + self.self_attn = CLIPAttention( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.self_attn", attn_cls=attn_cls, ) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.mlp = CLIPMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp") + self.mlp = CLIPMLP( + config, + quant_config=quant_config, + multimodal_config=multimodal_config, + prefix=f"{prefix}.mlp", + ) self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -491,6 +528,7 @@ class CLIPEncoder(nn.Module): self, config: CLIPTextConfig | CLIPVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, num_hidden_layers_override: int | None = None, *, prefix: str = "", @@ -504,11 +542,13 @@ class CLIPEncoder(nn.Module): num_hidden_layers = config.num_hidden_layers else: num_hidden_layers = num_hidden_layers_override + self.layers = nn.ModuleList( [ CLIPEncoderLayer( config=config, quant_config=quant_config, + 
multimodal_config=multimodal_config, prefix=f"{prefix}.layers.{layer_idx}", attn_cls=attn_cls, ) @@ -618,6 +658,7 @@ class CLIPVisionTransformer(nn.Module): self, config: CLIPVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, num_hidden_layers_override: int | None = None, require_post_norm: bool | None = None, @@ -637,6 +678,7 @@ class CLIPVisionTransformer(nn.Module): self.encoder = CLIPEncoder( config=config, quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.encoder", attn_cls=MMEncoderAttention, @@ -738,6 +780,7 @@ class CLIPVisionModel(nn.Module): self, config: CLIPVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, num_hidden_layers_override: int | None = None, require_post_norm: bool | None = None, @@ -748,6 +791,7 @@ class CLIPVisionModel(nn.Module): self.vision_model = CLIPVisionTransformer( config=config, quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers_override, require_post_norm=require_post_norm, prefix=f"{prefix}.vision_model", @@ -817,6 +861,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): self.vision_model = CLIPVisionTransformer( vision_config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "vision_model"), ) diff --git a/vllm/model_executor/models/deepencoder.py b/vllm/model_executor/models/deepencoder.py index 045445d23..6b9d09e88 100644 --- a/vllm/model_executor/models/deepencoder.py +++ b/vllm/model_executor/models/deepencoder.py @@ -19,6 +19,7 @@ import torch.nn.functional as F from transformers import CLIPVisionConfig from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention +from vllm.config import MultiModalConfig from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -608,6 +609,7 @@ class DeepCLIPVisionTransformer(nn.Module): self, config: CLIPVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, num_hidden_layers_override: int | None = None, prefix: str = "", @@ -626,6 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module): self.transformer = CLIPEncoder( config=config, quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.encoder", attn_cls=MMEncoderAttention, diff --git a/vllm/model_executor/models/deepseek_ocr.py b/vllm/model_executor/models/deepseek_ocr.py index 146c673dd..87afec0d3 100644 --- a/vllm/model_executor/models/deepseek_ocr.py +++ b/vllm/model_executor/models/deepseek_ocr.py @@ -397,6 +397,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports self.vision_model = DeepCLIPVisionTransformer( config=clip_vision_config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "vision_model"), ) diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index 3a083870e..f5226baba 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -6,7 +6,7 @@ from collections import defaultdict from collections.abc 
import Iterable, Mapping, Sequence from functools import partial from itertools import accumulate -from typing import Annotated, Any, Literal +from typing import Annotated, Literal import numpy as np import torch @@ -18,7 +18,7 @@ from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig from transformers.modeling_utils import no_init_weights from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions +from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import BaseMultiModalProcessorCache @@ -361,6 +361,7 @@ def _build_hcxvision_hf_processor( def init_vision_tower_for_hcxvision( vision_config, quant_config: QuantizationConfig | None, + multimodal_config: MultiModalConfig | None, *, use_nth_layer: int | None = None, require_post_norm: bool | None = None, @@ -378,6 +379,7 @@ def init_vision_tower_for_hcxvision( return CLIPVisionModel( vision_config, quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers, require_post_norm=require_post_norm, prefix=prefix, @@ -386,6 +388,7 @@ def init_vision_tower_for_hcxvision( return SiglipVisionModel( vision_config, quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers, require_post_norm=require_post_norm, prefix=prefix, @@ -597,18 +600,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): "gate_up_proj": ["gate_proj", "up_proj"], } - def __init__( - self, - *, - vllm_config: VllmConfig, - prefix: str = "", - **kwargs: Any | None, - ) -> None: + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() # init configs config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config # text_config text_config = config.text_config if text_config.model_type in ["gpt2", "hyperclovax", "llama"]: @@ -631,7 +629,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): with no_init_weights(): # weight will be loaded in from_pretrained self.vision_model = init_vision_tower_for_hcxvision( vision_config, - quant_config, + quant_config=quant_config, + multimodal_config=multimodal_config, use_nth_layer=getattr(config, "use_nth_layer", -1), require_post_norm=False, prefix=maybe_prefix(prefix, "vision_model"), diff --git a/vllm/model_executor/models/isaac.py b/vllm/model_executor/models/isaac.py index d3bdb1370..c95a57faf 100644 --- a/vllm/model_executor/models/isaac.py +++ b/vllm/model_executor/models/isaac.py @@ -1226,8 +1226,8 @@ class IsaacVisionEmbedding(nn.Module): self.transformer = Siglip2VisionTransformer( vision_cfg, quant_config=quant_config, - prefix=maybe_prefix(prefix, "0"), multimodal_config=multimodal_config, + prefix=maybe_prefix(prefix, "0"), ) self.linear_fc1 = ColumnParallelLinear( hidden_dim, diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index fcf88953b..bd148059e 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -404,6 +404,7 @@ class KeyeSiglipAttention(nn.Module): self.attn = MMEncoderAttention( num_heads=self.num_heads, head_size=self.head_dim, + scale=self.scale, num_kv_heads=self.num_kv_heads, prefix=f"{prefix}.attn", multimodal_config=multimodal_config, @@ -511,6 +512,7 @@ class 
KeyeSiglipEncoderLayer(nn.Module): self.mlp = SiglipMLP( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", ) diff --git a/vllm/model_executor/models/lightonocr.py b/vllm/model_executor/models/lightonocr.py index 353ee7806..27ec12a8f 100644 --- a/vllm/model_executor/models/lightonocr.py +++ b/vllm/model_executor/models/lightonocr.py @@ -155,6 +155,7 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration): def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: nn.Module.__init__(self) + config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config @@ -164,7 +165,8 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration): self.vision_tower = init_vision_tower_for_llava( config, - quant_config, + quant_config=quant_config, + multimodal_config=multimodal_config, require_post_norm=False, prefix=maybe_prefix(prefix, "vision_tower"), ) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 386c5216e..ba54623d9 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor from transformers.models.pixtral import PixtralProcessor from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions +from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig @@ -468,6 +468,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: def init_vision_tower_for_llava( hf_config: LlavaLikeConfig, quant_config: QuantizationConfig | None, + multimodal_config: MultiModalConfig | None, *, require_post_norm: bool | None = None, prefix: str = "", @@ -481,6 +482,7 @@ def init_vision_tower_for_llava( return CLIPVisionModel( vision_config, quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers, require_post_norm=require_post_norm, prefix=prefix, @@ -489,6 +491,7 @@ def init_vision_tower_for_llava( return SiglipVisionModel( vision_config, quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers, require_post_norm=require_post_norm, prefix=prefix, @@ -497,6 +500,7 @@ def init_vision_tower_for_llava( return PixtralHFVisionModel( vision_config, quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers, require_post_norm=require_post_norm, prefix=prefix, @@ -563,7 +567,8 @@ class LlavaForConditionalGeneration( if multimodal_config.get_limit_per_prompt("image"): self.vision_tower = init_vision_tower_for_llava( config, - quant_config, + quant_config=quant_config, + multimodal_config=multimodal_config, require_post_norm=False, prefix=maybe_prefix(prefix, "vision_tower"), ) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 526846d0d..21a9c2f28 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -243,6 +243,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: 
super().__init__() + config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config @@ -270,7 +271,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP # TODO: Optionally initializes this for supporting embeddings. self.vision_tower = init_vision_tower_for_llava( config, - quant_config, + quant_config=quant_config, + multimodal_config=multimodal_config, require_post_norm=False, prefix=maybe_prefix(prefix, "vision_tower"), ) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index cd55cfec6..b146a144e 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -321,6 +321,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config @@ -331,7 +332,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp # Initialize the vision tower only up to the required feature layer self.vision_tower = init_vision_tower_for_llava( config, - quant_config, + quant_config=quant_config, + multimodal_config=multimodal_config, require_post_norm=False, prefix=maybe_prefix(prefix, "vision_tower"), ) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 5aa8de7dc..a89f456ef 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -511,7 +511,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp # Initialize the vision tower only up to the required feature layer self.vision_tower = init_vision_tower_for_llava( config, - quant_config, + quant_config=quant_config, + multimodal_config=multimodal_config, require_post_norm=False, prefix=maybe_prefix(prefix, "vision_tower"), ) diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py index e48045495..b4a496dcb 100644 --- a/vllm/model_executor/models/minimax_vl_01.py +++ b/vllm/model_executor/models/minimax_vl_01.py @@ -204,7 +204,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support # TODO: Optionally initializes this for supporting embeddings. 
self.vision_tower = init_vision_tower_for_llava( config, - quant_config, + quant_config=quant_config, + multimodal_config=multimodal_config, require_post_norm=False, prefix=maybe_prefix(prefix, "vision_tower"), ) diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index e9161e69e..d3617fd4b 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -16,7 +16,7 @@ from transformers import ( from transformers.models.pixtral import PixtralProcessor from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions +from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear @@ -395,6 +395,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: def init_vision_tower_for_llava( hf_config: LlavaLikeConfig, quant_config: QuantizationConfig | None, + multimodal_config: MultiModalConfig | None, *, require_post_norm: bool | None = None, prefix: str = "", @@ -409,6 +410,7 @@ def init_vision_tower_for_llava( return PixtralHFVisionModel( vision_config, quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers, require_post_norm=require_post_norm, prefix=prefix, @@ -472,7 +474,8 @@ class Mistral3ForConditionalGeneration( if multimodal_config.get_limit_per_prompt("image"): self.vision_tower = init_vision_tower_for_llava( config, - quant_config, + quant_config=quant_config, + multimodal_config=multimodal_config, require_post_norm=False, prefix=maybe_prefix(prefix, "vision_tower"), ) diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index 56565266c..dc70c5a85 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -38,10 +38,8 @@ from vllm.config import MultiModalConfig, VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils -from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) @@ -77,6 +75,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape from .ernie45 import Ernie4_5ForCausalLM from .interfaces import MultiModalEmbeddings, SupportsMRoPE, SupportsMultiModal +from .siglip import SiglipMLP from .utils import ( AutoWeightsLoader, PPMissingLayer, @@ -657,46 +656,6 @@ class SigLIPRotaryEmbedding(nn.Module): return freqs -class SiglipMLP(nn.Module): - def __init__( - self, - config: PretrainedConfig, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ) -> None: - super().__init__() - - self.config = config - self.activation_fn = get_act_fn(config.hidden_act) - # Special handling for BNB and torchao quantization - if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]: - quantizable = True - else: - # For other quantization, we require the hidden size to be a - # multiple of 64 - quantizable = ( - config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0 - ) - self.fc1 = ColumnParallelLinear( - config.hidden_size, - config.intermediate_size, - quant_config=quant_config if 
quantizable else None, - prefix=f"{prefix}.fc1", - ) - self.fc2 = RowParallelLinear( - config.intermediate_size, - config.hidden_size, - quant_config=quant_config if quantizable else None, - prefix=f"{prefix}.fc2", - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states, _ = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states, _ = self.fc2(hidden_states) - return hidden_states - - class SiglipEncoderLayer(nn.Module): def __init__( self, @@ -720,6 +679,7 @@ class SiglipEncoderLayer(nn.Module): self.mlp = SiglipMLP( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", ) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 900b0eade..75823ec58 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -29,7 +29,7 @@ from transformers import ( ) from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions +from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding @@ -96,6 +96,7 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig( def _init_img_processor( hf_config: PretrainedConfig, quant_config: QuantizationConfig | None, + multimodal_config: MultiModalConfig | None, prefix: str = "", ) -> CLIPVisionModel: clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG @@ -109,7 +110,8 @@ def _init_img_processor( img_processor = CLIPVisionModel( clip_config, - quant_config, + quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers, prefix=prefix, ) @@ -160,38 +162,15 @@ class Phi3VImageEmbeddingInputs(TensorSchema): Phi3VImageInputs: TypeAlias = Phi3VImagePixelInputs | Phi3VImageEmbeddingInputs -class Phi3ImageEmbeddingBase(nn.Module): - def __init__(self) -> None: - super().__init__() - self.layer_idx: int - self.type_feature: str - self.img_processor: CLIPVisionModel - - def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor: - TYPE_FEATURE = self.type_feature - - # NOTE: we skip the step to select the vision feature layer since - # this is already done inside the img_processor - img_feature = self.img_processor(img_embeds) - - if TYPE_FEATURE == "patch": - patch_feature = img_feature[:, 1:] - return patch_feature - - if TYPE_FEATURE == "cls_patch": - return img_feature - - raise NotImplementedError - - # adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py -class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): +class Phi3HDImageEmbedding(nn.Module): """Phi3 Image embedding with HD transform.""" def __init__( self, config: PretrainedConfig, quant_config: QuantizationConfig | None, + multimodal_config: MultiModalConfig | None, prefix: str = "", ) -> None: super().__init__() @@ -200,7 +179,10 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size self.img_processor = _init_img_processor( - config, quant_config, prefix=f"{prefix}.img_processor" + config, + quant_config=quant_config, + multimodal_config=multimodal_config, + prefix=f"{prefix}.img_processor", ) image_dim_out = config.img_processor["image_dim_out"] @@ -223,13 +205,29 @@ class 
Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): dim_projection = hidden_size depth = 2 - layers = [nn.Linear(image_dim_out * 4, dim_projection)] + layers: list[nn.Module] = [nn.Linear(image_dim_out * 4, dim_projection)] for _ in range(1, depth): layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) self.img_projection = nn.Sequential(*layers) self.type_feature = config.img_processor.get("type_feature", "patch") + def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor: + type_feature = self.type_feature + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the img_processor + img_feature = self.img_processor(img_embeds) + + if type_feature == "patch": + patch_feature = img_feature[:, 1:] + return patch_feature + + if type_feature == "cls_patch": + return img_feature + + raise NotImplementedError(type_feature) + def forward( self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor ) -> torch.FloatTensor: @@ -582,6 +580,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config @@ -590,14 +589,15 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant) self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, - quant_config=self.quant_config, + quant_config=quant_config, prefix=maybe_prefix(prefix, "model.embed_tokens"), ) # TODO: Optionally initializes this for supporting input embeddings. 
self.vision_embed_tokens = Phi3HDImageEmbedding( config, - self.quant_config, + quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "model.vision_embed_tokens"), ) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 555e6ea4b..fdb940d76 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -28,7 +28,7 @@ from transformers.models.pixtral.modeling_pixtral import ( from transformers.tokenization_utils_base import TextInput from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions +from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.conv import Conv2dLayer @@ -1043,11 +1043,18 @@ class PixtralHFMLP(nn.Module): self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, prefix: str = "", ) -> None: super().__init__() + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) + assert config.intermediate_size is not None self.gate_up_proj = MergedColumnParallelLinear( input_size=config.hidden_size, @@ -1055,6 +1062,7 @@ class PixtralHFMLP(nn.Module): bias=False, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", + disable_tp=use_data_parallel, ) self.down_proj = RowParallelLinear( input_size=config.intermediate_size, @@ -1062,6 +1070,7 @@ class PixtralHFMLP(nn.Module): bias=False, quant_config=quant_config, prefix=f"{prefix}.down_proj", + disable_tp=use_data_parallel, ) self.act_and_mul = get_act_and_mul_fn(config.hidden_act) @@ -1077,6 +1086,7 @@ class PixtralHFAttention(nn.Module): self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, prefix: str = "", ) -> None: @@ -1085,10 +1095,14 @@ class PixtralHFAttention(nn.Module): self.config = config assert not config.hidden_size % config.num_attention_heads self.total_num_heads = config.num_attention_heads - tp_size = get_tensor_model_parallel_world_size() - self.n_heads = divide(config.num_attention_heads, tp_size) self.head_dim = config.hidden_size // config.num_attention_heads + assert self.total_num_heads * self.head_dim == config.hidden_size + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.qkv_proj = QKVParallelLinear( hidden_size=config.hidden_size, head_size=self.head_dim, @@ -1096,16 +1110,22 @@ class PixtralHFAttention(nn.Module): bias=False, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", + disable_tp=use_data_parallel, ) - assert self.total_num_heads * self.head_dim == config.hidden_size self.o_proj = RowParallelLinear( input_size=config.hidden_size, output_size=config.hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj", + disable_tp=use_data_parallel, ) + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) + self.n_heads = divide(config.num_attention_heads, self.tp_size) + def forward( self, hidden_states: torch.Tensor, @@ -1147,6 +1167,7 @@ class PixtralHFTransformerBlock(nn.Module): self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, prefix: 
str = "", ) -> None: @@ -1154,10 +1175,16 @@ class PixtralHFTransformerBlock(nn.Module): self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5) self.attention = PixtralHFAttention( - config, quant_config=quant_config, prefix=f"{prefix}.attention" + config, + quant_config=quant_config, + multimodal_config=multimodal_config, + prefix=f"{prefix}.attention", ) self.feed_forward = PixtralHFMLP( - config, quant_config=quant_config, prefix=f"{prefix}.feed_forward" + config, + quant_config=quant_config, + multimodal_config=multimodal_config, + prefix=f"{prefix}.feed_forward", ) self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5) @@ -1183,6 +1210,7 @@ class PixtralHFTransformer(nn.Module): self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, num_hidden_layers_override: int | None = None, prefix: str = "", @@ -1199,6 +1227,7 @@ class PixtralHFTransformer(nn.Module): PixtralHFTransformerBlock( config=config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.layers.{layer_idx}", ) for layer_idx in range(num_hidden_layers) @@ -1230,6 +1259,7 @@ class PixtralHFVisionModel(nn.Module): self, config: PixtralVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, num_hidden_layers_override: int | None = None, require_post_norm: bool | None = None, @@ -1249,7 +1279,8 @@ class PixtralHFVisionModel(nn.Module): self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5) self.transformer = PixtralHFTransformer( config, - quant_config, + quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.transformer", ) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 799afc7ca..85772c11a 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math from collections.abc import Callable, Iterable, Mapping from functools import cached_property from typing import Annotated, Literal @@ -19,7 +18,7 @@ from transformers import ( from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions +from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.conv import Conv2dLayer @@ -276,7 +275,7 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): return image_size // patch_size -# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa +# Adapted from https://github.com/huggingface/transformers/blob/v4.57.3/src/transformers/models/siglip/modeling_siglip.py#L216 class SiglipVisionEmbeddings(nn.Module): def __init__(self, config: SiglipVisionConfig): super().__init__() @@ -295,9 +294,7 @@ class SiglipVisionEmbeddings(nn.Module): self.num_patches = (self.image_size // self.patch_size) ** 2 self.num_positions = self.num_patches - self.position_embedding = VocabParallelEmbedding( - self.num_positions, self.embed_dim - ) + 
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
         self.register_buffer(
             "position_ids",
             torch.arange(self.num_positions, dtype=torch.int64).expand((1, -1)),
@@ -307,50 +304,30 @@ class SiglipVisionEmbeddings(nn.Module):
     def interpolate_pos_encoding(
         self, embeddings: torch.Tensor, height: int, width: int
     ) -> torch.Tensor:
-        """
-        This method is an adapted method for SigLIP (due to SigLIP not having
-        class embedding unlike other ViTs) that allows the model to interpolate
-        the pre-trained position encodings such that it can be usable on higher
-        resolution images.
-
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
-        """
-        position_embeddings = self.position_embedding.weight.unsqueeze(0)
         num_patches = embeddings.shape[1]
-        num_positions = position_embeddings.shape[1]
+        num_positions = self.position_embedding.weight.shape[0]
         if num_patches == num_positions and height == width:
-            return position_embeddings
+            return self.position_embedding(self.position_ids)
+
+        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
 
         dim = embeddings.shape[-1]
-        height = height // self.patch_size
-        width = width // self.patch_size
-        # we add a small number to avoid floating point error
-        # in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        height, width = height + 0.1, width + 0.1
-
-        patch_pos_embed = position_embeddings.reshape(
-            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(
+            1, sqrt_num_positions, sqrt_num_positions, dim
         )
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(
-                height / math.sqrt(num_positions),
-                width / math.sqrt(num_positions),
-            ),
+            size=(new_height, new_width),
             mode="bicubic",
             align_corners=False,
         )
-        if (
-            int(height) != patch_pos_embed.shape[-2]
-            or int(width) != patch_pos_embed.shape[-1]
-        ):
-            raise ValueError(
-                "Width or height does not match with "
-                "the interpolated position embeddings"
-            )
+
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return patch_pos_embed
 
@@ -377,6 +354,7 @@ class SiglipAttention(nn.Module):
         self,
         config: SiglipVisionConfig | SiglipTextConfig,
         quant_config: QuantizationConfig | None = None,
+        multimodal_config: MultiModalConfig | None = None,
         *,
         prefix: str = "",
         attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
@@ -389,19 +367,25 @@ class SiglipAttention(nn.Module):
         self.head_dim = self.embed_dim // self.num_heads
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
-                f"embed_dim must be divisible by num_heads (got "
-                "`embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {self.num_heads})."
+                f"embed_dim must be divisible by num_heads "
+                f"(got `embed_dim`: {self.embed_dim} and "
+                f"`num_heads`: {self.num_heads})."
) self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout + + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.qkv_proj = QKVParallelLinear( hidden_size=self.embed_dim, head_size=self.head_dim, total_num_heads=self.num_heads, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", + disable_tp=use_data_parallel, ) self.out_proj = RowParallelLinear( @@ -409,17 +393,29 @@ class SiglipAttention(nn.Module): output_size=self.embed_dim, quant_config=quant_config, prefix=f"{prefix}.out_proj", + disable_tp=use_data_parallel, ) - self.tp_size = get_tensor_model_parallel_world_size() + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn = attn_cls( - self.num_heads_per_partition, - self.head_dim, - self.scale, - prefix=f"{prefix}.attn", - ) + if attn_cls == MMEncoderAttention: + self.attn = attn_cls( + self.num_heads_per_partition, + self.head_dim, + self.scale, + prefix=f"{prefix}.attn", + multimodal_config=multimodal_config, + ) + else: + self.attn = attn_cls( + self.num_heads_per_partition, + self.head_dim, + self.scale, + prefix=f"{prefix}.attn", + ) def forward( self, @@ -439,12 +435,19 @@ class SiglipMLP(nn.Module): self, config: SiglipVisionConfig | SiglipTextConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", ) -> None: super().__init__() self.config = config + use_data_parallel = ( + multimodal_config.mm_encoder_tp_mode == "data" + if multimodal_config + else False + ) self.activation_fn = get_act_fn(config.hidden_act) + # Special handling for BNB and torchao quantization if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]: quantizable = True @@ -454,17 +457,20 @@ class SiglipMLP(nn.Module): quantizable = ( config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0 ) + self.fc1 = ColumnParallelLinear( config.hidden_size, config.intermediate_size, quant_config=quant_config if quantizable else None, prefix=f"{prefix}.fc1", + disable_tp=use_data_parallel, ) self.fc2 = RowParallelLinear( config.intermediate_size, config.hidden_size, quant_config=quant_config if quantizable else None, prefix=f"{prefix}.fc2", + disable_tp=use_data_parallel, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -479,6 +485,7 @@ class SiglipEncoderLayer(nn.Module): self, config: SiglipVisionConfig | SiglipTextConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, prefix: str = "", attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention], @@ -490,6 +497,7 @@ class SiglipEncoderLayer(nn.Module): self.self_attn = SiglipAttention( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.self_attn", attn_cls=attn_cls, ) @@ -497,6 +505,7 @@ class SiglipEncoderLayer(nn.Module): self.mlp = SiglipMLP( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.mlp", ) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -524,6 +533,7 @@ class SiglipEncoder(nn.Module): self, config: SiglipVisionConfig | SiglipTextConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, num_hidden_layers_override: int | None = None, *, prefix: str = "", @@ -543,6 +553,7 @@ class SiglipEncoder(nn.Module): 
SiglipEncoderLayer( config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.layers.{layer_idx}", attn_cls=attn_cls, ) @@ -647,6 +658,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): self, config: SiglipVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, prefix: str = "", ) -> None: super().__init__() @@ -658,7 +670,10 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module): ) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = SiglipMLP( - config=config, quant_config=quant_config, prefix=f"{prefix}.mlp" + config=config, + quant_config=quant_config, + multimodal_config=multimodal_config, + prefix=f"{prefix}.mlp", ) def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: @@ -683,6 +698,7 @@ class SiglipVisionTransformer(nn.Module): self, config: SiglipVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, num_hidden_layers_override: int | None = None, require_post_norm: bool | None = None, @@ -698,6 +714,7 @@ class SiglipVisionTransformer(nn.Module): self.encoder = SiglipEncoder( config, quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.encoder", attn_cls=MMEncoderAttention, @@ -726,6 +743,7 @@ class SiglipVisionTransformer(nn.Module): self.head = SiglipMultiheadAttentionPoolingHead( config=config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=f"{prefix}.head", ) @@ -812,13 +830,11 @@ class SiglipVisionTransformer(nn.Module): class SiglipVisionModel(nn.Module): - config_class = SiglipVisionConfig - main_input_name = "pixel_values" - def __init__( self, config: SiglipVisionConfig, quant_config: QuantizationConfig | None = None, + multimodal_config: MultiModalConfig | None = None, *, num_hidden_layers_override: int | None = None, require_post_norm: bool | None = None, @@ -829,7 +845,8 @@ class SiglipVisionModel(nn.Module): self.quant_config = quant_config self.vision_model = SiglipVisionTransformer( config, - quant_config, + quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers_override, require_post_norm=require_post_norm, prefix=f"{prefix}.vision_model", @@ -1023,6 +1040,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): self.vision_model = SiglipVisionTransformer( vision_config, quant_config=quant_config, + multimodal_config=multimodal_config, prefix=maybe_prefix(prefix, "vision_model"), ) diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index 15d0ff30e..ff5ef6c15 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -11,7 +11,6 @@ from torch.nn import functional as F from transformers import Siglip2VisionConfig from transformers.configuration_utils import PretrainedConfig -from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.config import MultiModalConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size @@ -186,7 +185,6 @@ class Siglip2Attention(nn.Module): multimodal_config: MultiModalConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: AttentionBackendEnum | None = None, ): super().__init__() self.config = 
config @@ -196,12 +194,11 @@ class Siglip2Attention(nn.Module): if self.head_dim * self.num_heads != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads " - f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." + f"(got `embed_dim`: {self.embed_dim} and " + f"`num_heads`: {self.num_heads})." ) self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout - self.is_causal = False use_data_parallel = ( multimodal_config.mm_encoder_tp_mode == "data" @@ -233,6 +230,7 @@ class Siglip2Attention(nn.Module): self.attn = MMEncoderAttention( num_heads=self.num_heads_per_partition, head_size=self.head_dim, + scale=self.scale, prefix=f"{prefix}.attn", multimodal_config=multimodal_config, ) diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 7e82a4d72..dcfd43272 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor from transformers.processing_utils import ProcessingKwargs, Unpack from transformers.tokenization_utils_base import PreTokenizedInput, TextInput -from vllm.config import VllmConfig +from vllm.config import MultiModalConfig, VllmConfig from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig @@ -346,6 +346,7 @@ def _build_tarsier_hf_processor( def init_vision_tower_for_tarsier( hf_config: TarsierHfConfig, # Use the Tarsier specific config protocol quant_config: QuantizationConfig | None, + multimodal_config: MultiModalConfig | None, *, require_post_norm: bool | None = None, prefix: str = "", @@ -377,6 +378,7 @@ def init_vision_tower_for_tarsier( return CLIPVisionModel( vision_config, quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers_to_init, require_post_norm=require_post_norm, prefix=prefix, @@ -385,6 +387,7 @@ def init_vision_tower_for_tarsier( return SiglipVisionModel( vision_config, quant_config=quant_config, + multimodal_config=multimodal_config, num_hidden_layers_override=num_hidden_layers_to_init, require_post_norm=require_post_norm, prefix=prefix, @@ -414,12 +417,16 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + config: TarsierHfConfig = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config # Storing the Tarsier-specific HF config self.vision_tower = init_vision_tower_for_tarsier( config, - quant_config, + quant_config=quant_config, + multimodal_config=multimodal_config, require_post_norm=False, prefix=maybe_prefix(prefix, "vision_tower"), )
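
Reviewer note (illustrative, not part of the applied diff): the change that repeats
across clip.py, siglip.py, and pixtral.py is a single `use_data_parallel` switch
derived from `multimodal_config.mm_encoder_tp_mode == "data"`, passed to each
parallel linear layer as `disable_tp` and mirrored in the head partitioning. A
minimal sketch of that pattern, using only names that appear in this diff; the
`ToyVisionMLP` wrapper itself is hypothetical, and actually constructing the
layers requires vLLM's distributed environment to be initialized:

    import torch
    import torch.nn as nn

    from vllm.config.multimodal import MultiModalConfig
    from vllm.distributed import get_tensor_model_parallel_world_size
    from vllm.model_executor.layers.linear import (
        ColumnParallelLinear,
        RowParallelLinear,
    )


    class ToyVisionMLP(nn.Module):
        """Hypothetical module showing the standardized DP/TP toggle."""

        def __init__(
            self,
            hidden_size: int,
            intermediate_size: int,
            multimodal_config: MultiModalConfig | None = None,
            prefix: str = "",
        ) -> None:
            super().__init__()
            # Under mm_encoder_tp_mode == "data", every rank holds the full
            # encoder weights, so tensor parallelism is disabled per layer.
            use_data_parallel = (
                multimodal_config.mm_encoder_tp_mode == "data"
                if multimodal_config
                else False
            )
            self.fc1 = ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                prefix=f"{prefix}.fc1",
                disable_tp=use_data_parallel,
            )
            self.act = nn.GELU()
            self.fc2 = RowParallelLinear(
                intermediate_size,
                hidden_size,
                prefix=f"{prefix}.fc2",
                disable_tp=use_data_parallel,
            )
            # Attention modules apply the same switch to head partitioning:
            # with data parallelism there is no TP sharding of heads, so the
            # effective TP size is treated as 1.
            self.tp_size = (
                1 if use_data_parallel else get_tensor_model_parallel_world_size()
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x, _ = self.fc1(x)
            x = self.act(x)
            x, _ = self.fc2(x)
            return x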
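The model-side edits are mechanical by design: each constructor now reads
`multimodal_config` off `vllm_config.model_config` and forwards it by keyword
next to `quant_config`, as the LLaVA-family call sites above show. A sketch of
the resulting call shape, assuming the `init_vision_tower_for_llava` signature
introduced in llava.py; the `build_vision_tower` helper is hypothetical:

    from vllm.config import VllmConfig
    from vllm.model_executor.models.llava import init_vision_tower_for_llava
    from vllm.model_executor.models.utils import maybe_prefix


    def build_vision_tower(vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        # Both configs are forwarded by keyword so the growing argument
        # list stays unambiguous at every call site.
        return init_vision_tower_for_llava(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )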
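The siglip.py rewrite of `interpolate_pos_encoding` drops the DINO-derived
fractional `scale_factor` (and its `+ 0.1` rounding workaround) in favor of an
exact `size=`-based resize, matching transformers v4.57. A standalone numeric
sketch of the new path; the shapes are illustrative (24x24 pretrained grid,
448x448 input, SigLIP-so400m width):

    import torch
    import torch.nn.functional as F

    num_positions, dim = 576, 1152      # 24 x 24 pretrained position grid
    patch_size, height, width = 16, 448, 448

    pos_embed = torch.randn(1, num_positions, dim)
    side = int(num_positions**0.5)      # 24

    grid = pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2)
    # Resize directly to the target grid; no fractional scale_factor and
    # no post-hoc shape check is required.
    new_h, new_w = height // patch_size, width // patch_size   # 28, 28
    grid = F.interpolate(
        grid, size=(new_h, new_w), mode="bicubic", align_corners=False
    )
    out = grid.permute(0, 2, 3, 1).view(1, -1, dim)
    assert out.shape == (1, new_h * new_w, dim)                # (1, 784, 1152)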