[Model] Standardize common vision encoders (#31947)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung
2026-01-08 18:33:16 +08:00
committed by GitHub
parent d1b6fe007f
commit 5576227bc1
19 changed files with 253 additions and 173 deletions

View File

@@ -17,7 +17,7 @@ from transformers import (
from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
@@ -353,6 +353,7 @@ class CLIPAttention(nn.Module):
self,
config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
attn_cls: type[Attention] | type[MMEncoderAttention],
@@ -365,18 +366,24 @@ class CLIPAttention(nn.Module):
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and "
f"`num_heads`: {self.num_heads})."
)
self.scale = self.head_dim**-0.5
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
disable_tp=use_data_parallel,
)
self.out_proj = RowParallelLinear(
@@ -384,17 +391,29 @@ class CLIPAttention(nn.Module):
output_size=self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel,
)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = attn_cls(
self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=f"{prefix}.attn",
)
if attn_cls == MMEncoderAttention:
self.attn = attn_cls(
self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
else:
self.attn = attn_cls(
self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=f"{prefix}.attn",
)
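
The recurring pattern in these encoder changes: when the multimodal config requests mm_encoder_tp_mode == "data", the parallel linear layers are created with disable_tp=True and the head count is no longer split across tensor-parallel ranks. A minimal sketch of that toggle, using a stand-in config class and an explicit world_size argument in place of vLLM's MultiModalConfig and get_tensor_model_parallel_world_size():

from dataclasses import dataclass


@dataclass
class MMConfigStub:
    # Stand-in for vLLM's MultiModalConfig; only the field used here.
    mm_encoder_tp_mode: str = "weights"  # "weights" (TP) or "data" (DP)


def plan_encoder_parallelism(
    num_heads: int, mm_config: MMConfigStub | None, world_size: int
) -> tuple[bool, int, int]:
    # Under data parallelism every rank keeps the full encoder weights,
    # so TP is disabled on the linear layers and heads are not divided.
    use_data_parallel = mm_config is not None and mm_config.mm_encoder_tp_mode == "data"
    tp_size = 1 if use_data_parallel else world_size
    assert num_heads % tp_size == 0, "num_heads must be divisible by tp_size"
    return use_data_parallel, tp_size, num_heads // tp_size


print(plan_encoder_parallelism(16, MMConfigStub("data"), 4))     # (True, 1, 16)
print(plan_encoder_parallelism(16, MMConfigStub("weights"), 4))  # (False, 4, 4)
print(plan_encoder_parallelism(16, None, 4))                     # (False, 4, 4)
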
def forward(
self,
@@ -415,17 +434,26 @@ class CLIPMLP(nn.Module):
self,
config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel,
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
@@ -433,6 +461,7 @@ class CLIPMLP(nn.Module):
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -448,19 +477,27 @@ class CLIPEncoderLayer(nn.Module):
self,
config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
attn_cls: type[Attention] | type[MMEncoderAttention],
) -> None:
super().__init__()
self.self_attn = CLIPAttention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
attn_cls=attn_cls,
)
self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp")
self.mlp = CLIPMLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -491,6 +528,7 @@ class CLIPEncoder(nn.Module):
self,
config: CLIPTextConfig | CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
num_hidden_layers_override: int | None = None,
*,
prefix: str = "",
@@ -504,11 +542,13 @@ class CLIPEncoder(nn.Module):
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList(
[
CLIPEncoderLayer(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
attn_cls=attn_cls,
)
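
num_hidden_layers_override lets callers build only as many encoder layers as their vision feature layer needs. A rough sketch of how a signed feature-layer index can be turned into a layer count, in the spirit of the _get_layer_index helpers referenced in the LLaVA hunks below; the exact convention shown here is an assumption for illustration, not copied from this commit:

def layers_needed(num_hidden_layers: int, feature_layer_index: int) -> int:
    # Assumed convention: negative indices count from the end (-1 is the
    # final layer's output, -2 the second-to-last); positive indices are
    # taken as 0-based layer indices.
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index + 1


print(layers_needed(24, -2))  # 23: drop only the last encoder layer
print(layers_needed(24, -1))  # 24: keep the full stack
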
@@ -618,6 +658,7 @@ class CLIPVisionTransformer(nn.Module):
self,
config: CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
@@ -637,6 +678,7 @@ class CLIPVisionTransformer(nn.Module):
self.encoder = CLIPEncoder(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
attn_cls=MMEncoderAttention,
@@ -738,6 +780,7 @@ class CLIPVisionModel(nn.Module):
self,
config: CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
@@ -748,6 +791,7 @@ class CLIPVisionModel(nn.Module):
self.vision_model = CLIPVisionTransformer(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model",
@@ -817,6 +861,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.vision_model = CLIPVisionTransformer(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)

View File

@@ -19,6 +19,7 @@ import torch.nn.functional as F
from transformers import CLIPVisionConfig
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -608,6 +609,7 @@ class DeepCLIPVisionTransformer(nn.Module):
self,
config: CLIPVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
prefix: str = "",
@@ -626,6 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module):
self.transformer = CLIPEncoder(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
attn_cls=MMEncoderAttention,

View File

@@ -397,6 +397,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
self.vision_model = DeepCLIPVisionTransformer(
config=clip_vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)

View File

@@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from itertools import accumulate
from typing import Annotated, Any, Literal
from typing import Annotated, Literal
import numpy as np
import torch
@@ -18,7 +18,7 @@ from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
from transformers.modeling_utils import no_init_weights
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
@@ -361,6 +361,7 @@ def _build_hcxvision_hf_processor(
def init_vision_tower_for_hcxvision(
vision_config,
quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*,
use_nth_layer: int | None = None,
require_post_norm: bool | None = None,
@@ -378,6 +379,7 @@ def init_vision_tower_for_hcxvision(
return CLIPVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
@@ -386,6 +388,7 @@ def init_vision_tower_for_hcxvision(
return SiglipVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
@@ -597,18 +600,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
**kwargs: Any | None,
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
# init configs
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
# text_config
text_config = config.text_config
if text_config.model_type in ["gpt2", "hyperclovax", "llama"]:
@@ -631,7 +629,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
with no_init_weights(): # weight will be loaded in from_pretrained
self.vision_model = init_vision_tower_for_hcxvision(
vision_config,
quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
use_nth_layer=getattr(config, "use_nth_layer", -1),
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_model"),

View File

@@ -1226,8 +1226,8 @@ class IsaacVisionEmbedding(nn.Module):
self.transformer = Siglip2VisionTransformer(
vision_cfg,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "0"),
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "0"),
)
self.linear_fc1 = ColumnParallelLinear(
hidden_dim,

View File

@@ -404,6 +404,7 @@ class KeyeSiglipAttention(nn.Module):
self.attn = MMEncoderAttention(
num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scale,
num_kv_heads=self.num_kv_heads,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
@@ -511,6 +512,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
self.mlp = SiglipMLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
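
The scale=self.scale argument added above is the usual head_dim ** -0.5 softmax scaling. A minimal, generic scaled-dot-product sketch (not vLLM's MMEncoderAttention) showing where that factor enters:

import torch


def sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float) -> torch.Tensor:
    # q, k, v: [batch, heads, seq, head_dim]; scale is typically head_dim ** -0.5.
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    return torch.matmul(scores.softmax(dim=-1), v)


head_dim = 64
q = k = v = torch.randn(1, 8, 16, head_dim)
print(sdpa(q, k, v, scale=head_dim ** -0.5).shape)  # torch.Size([1, 8, 16, 64])
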

View File

@@ -155,6 +155,7 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
@@ -164,7 +165,8 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration):
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)

View File

@@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -468,6 +468,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
def init_vision_tower_for_llava(
hf_config: LlavaLikeConfig,
quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*,
require_post_norm: bool | None = None,
prefix: str = "",
@@ -481,6 +482,7 @@ def init_vision_tower_for_llava(
return CLIPVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
@@ -489,6 +491,7 @@ def init_vision_tower_for_llava(
return SiglipVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
@@ -497,6 +500,7 @@ def init_vision_tower_for_llava(
return PixtralHFVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
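
init_vision_tower_for_llava (like its HCX and Tarsier counterparts) dispatches on the HF vision config type and now threads both quant_config and multimodal_config through to the tower. A minimal sketch of that dispatch with placeholder return values instead of the real vLLM model classes; the Pixtral branch is elided to keep the imports simple:

from transformers import CLIPVisionConfig, SiglipVisionConfig


def init_vision_tower(vision_config, *, quant_config=None, multimodal_config=None):
    # Every tower receives the same keyword arguments so quantization and
    # DP/TP settings reach the encoder layers.
    kwargs = dict(quant_config=quant_config, multimodal_config=multimodal_config)
    if isinstance(vision_config, CLIPVisionConfig):
        return ("CLIPVisionModel", kwargs)    # placeholder for the real class
    if isinstance(vision_config, SiglipVisionConfig):
        return ("SiglipVisionModel", kwargs)  # placeholder for the real class
    raise NotImplementedError(f"Unsupported vision config: {type(vision_config)}")


print(init_vision_tower(SiglipVisionConfig())[0])  # SiglipVisionModel
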
@@ -563,7 +567,8 @@ class LlavaForConditionalGeneration(
if multimodal_config.get_limit_per_prompt("image"):
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)

View File

@@ -243,6 +243,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
@@ -270,7 +271,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)

View File

@@ -321,6 +321,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
@@ -331,7 +332,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
# Initialize the vision tower only up to the required feature layer
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)

View File

@@ -511,7 +511,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
# Initialize the vision tower only up to the required feature layer
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)

View File

@@ -204,7 +204,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)

View File

@@ -16,7 +16,7 @@ from transformers import (
from transformers.models.pixtral import PixtralProcessor
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -395,6 +395,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
def init_vision_tower_for_llava(
hf_config: LlavaLikeConfig,
quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*,
require_post_norm: bool | None = None,
prefix: str = "",
@@ -409,6 +410,7 @@ def init_vision_tower_for_llava(
return PixtralHFVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
@@ -472,7 +474,8 @@ class Mistral3ForConditionalGeneration(
if multimodal_config.get_limit_per_prompt("image"):
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)

View File

@@ -38,10 +38,8 @@ from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
@@ -77,6 +75,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .ernie45 import Ernie4_5ForCausalLM
from .interfaces import MultiModalEmbeddings, SupportsMRoPE, SupportsMultiModal
from .siglip import SiglipMLP
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
@@ -657,46 +656,6 @@ class SigLIPRotaryEmbedding(nn.Module):
return freqs
class SiglipMLP(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
# Special handling for BNB and torchao quantization
if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
quantizable = True
else:
# For other quantization, we require the hidden size to be a
# multiple of 64
quantizable = (
config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
)
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config if quantizable else None,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config if quantizable else None,
prefix=f"{prefix}.fc2",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class SiglipEncoderLayer(nn.Module):
def __init__(
self,
@@ -720,6 +679,7 @@ class SiglipEncoderLayer(nn.Module):
self.mlp = SiglipMLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)

View File

@@ -29,7 +29,7 @@ from transformers import (
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -96,6 +96,7 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
def _init_img_processor(
hf_config: PretrainedConfig,
quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
prefix: str = "",
) -> CLIPVisionModel:
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
@@ -109,7 +110,8 @@ def _init_img_processor(
img_processor = CLIPVisionModel(
clip_config,
quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
prefix=prefix,
)
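
The Phi-3-V image processor is built from a fixed CLIP_VIT_LARGE_PATCH14_336_CONFIG. For reference, a vision config with the shape of the public openai/clip-vit-large-patch14-336 checkpoint looks roughly like the sketch below (values quoted from that checkpoint, not copied from this file):

from transformers import CLIPVisionConfig

clip_vit_l_336 = CLIPVisionConfig(
    hidden_size=1024,
    intermediate_size=4096,
    num_attention_heads=16,
    num_hidden_layers=24,
    image_size=336,
    patch_size=14,
)
# 336 / 14 = 24 patches per side -> 576 patch tokens plus one CLS token.
print(clip_vit_l_336.image_size // clip_vit_l_336.patch_size)  # 24
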
@@ -160,38 +162,15 @@ class Phi3VImageEmbeddingInputs(TensorSchema):
Phi3VImageInputs: TypeAlias = Phi3VImagePixelInputs | Phi3VImageEmbeddingInputs
class Phi3ImageEmbeddingBase(nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer_idx: int
self.type_feature: str
self.img_processor: CLIPVisionModel
def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
TYPE_FEATURE = self.type_feature
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the img_processor
img_feature = self.img_processor(img_embeds)
if TYPE_FEATURE == "patch":
patch_feature = img_feature[:, 1:]
return patch_feature
if TYPE_FEATURE == "cls_patch":
return img_feature
raise NotImplementedError
# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
class Phi3HDImageEmbedding(nn.Module):
"""Phi3 Image embedding with HD transform."""
def __init__(
self,
config: PretrainedConfig,
quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
prefix: str = "",
) -> None:
super().__init__()
@@ -200,7 +179,10 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size
self.img_processor = _init_img_processor(
config, quant_config, prefix=f"{prefix}.img_processor"
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.img_processor",
)
image_dim_out = config.img_processor["image_dim_out"]
@@ -223,13 +205,29 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
dim_projection = hidden_size
depth = 2
layers = [nn.Linear(image_dim_out * 4, dim_projection)]
layers: list[nn.Module] = [nn.Linear(image_dim_out * 4, dim_projection)]
for _ in range(1, depth):
layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
self.img_projection = nn.Sequential(*layers)
self.type_feature = config.img_processor.get("type_feature", "patch")
def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
type_feature = self.type_feature
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the img_processor
img_feature = self.img_processor(img_embeds)
if type_feature == "patch":
patch_feature = img_feature[:, 1:]
return patch_feature
if type_feature == "cls_patch":
return img_feature
raise NotImplementedError(type_feature)
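
get_img_features now lives on Phi3HDImageEmbedding itself; the selection logic is unchanged. A minimal sketch of what the two type_feature modes do to a ViT output of shape [batch, 1 + num_patches, hidden]:

import torch


def select_features(img_feature: torch.Tensor, type_feature: str) -> torch.Tensor:
    if type_feature == "patch":
        return img_feature[:, 1:]  # drop the leading CLS token, keep patch tokens
    if type_feature == "cls_patch":
        return img_feature         # keep CLS token and patch tokens
    raise NotImplementedError(type_feature)


# 336px / 14px patches -> 24 * 24 = 576 patch tokens plus CLS.
feat = torch.randn(2, 577, 1024)
print(select_features(feat, "patch").shape)  # torch.Size([2, 576, 1024])
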
def forward(
self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor
) -> torch.FloatTensor:
@@ -582,6 +580,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
@@ -590,14 +589,15 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=self.quant_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "model.embed_tokens"),
)
# TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
config,
self.quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
)

View File

@@ -28,7 +28,7 @@ from transformers.models.pixtral.modeling_pixtral import (
from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv2dLayer
@@ -1043,11 +1043,18 @@ class PixtralHFMLP(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
assert config.intermediate_size is not None
self.gate_up_proj = MergedColumnParallelLinear(
input_size=config.hidden_size,
@@ -1055,6 +1062,7 @@ class PixtralHFMLP(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
disable_tp=use_data_parallel,
)
self.down_proj = RowParallelLinear(
input_size=config.intermediate_size,
@@ -1062,6 +1070,7 @@ class PixtralHFMLP(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
disable_tp=use_data_parallel,
)
self.act_and_mul = get_act_and_mul_fn(config.hidden_act)
@@ -1077,6 +1086,7 @@ class PixtralHFAttention(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
) -> None:
@@ -1085,10 +1095,14 @@ class PixtralHFAttention(nn.Module):
self.config = config
assert not config.hidden_size % config.num_attention_heads
self.total_num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
self.n_heads = divide(config.num_attention_heads, tp_size)
self.head_dim = config.hidden_size // config.num_attention_heads
assert self.total_num_heads * self.head_dim == config.hidden_size
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear(
hidden_size=config.hidden_size,
head_size=self.head_dim,
@@ -1096,16 +1110,22 @@ class PixtralHFAttention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
disable_tp=use_data_parallel,
)
assert self.total_num_heads * self.head_dim == config.hidden_size
self.o_proj = RowParallelLinear(
input_size=config.hidden_size,
output_size=config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
disable_tp=use_data_parallel,
)
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
self.n_heads = divide(config.num_attention_heads, self.tp_size)
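
With tp_size forced to 1 under data parallelism, n_heads is simply the full head count; otherwise it is the per-rank share. A generic sketch (shapes illustrative, not Pixtral's exact forward pass) of how the fused QKV output is split into those per-partition heads:

import torch


def split_qkv(qkv: torch.Tensor, n_heads: int, head_dim: int):
    # qkv: [seq_len, 3 * n_heads * head_dim] from the fused projection.
    seq_len = qkv.shape[0]
    q, k, v = qkv.chunk(3, dim=-1)
    shape = (seq_len, n_heads, head_dim)
    return q.reshape(shape), k.reshape(shape), v.reshape(shape)


qkv = torch.randn(100, 3 * 16 * 64)  # 16 heads on this rank, head_dim 64
q, k, v = split_qkv(qkv, 16, 64)
print(q.shape)  # torch.Size([100, 16, 64])
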
def forward(
self,
hidden_states: torch.Tensor,
@@ -1147,6 +1167,7 @@ class PixtralHFTransformerBlock(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
) -> None:
@@ -1154,10 +1175,16 @@ class PixtralHFTransformerBlock(nn.Module):
self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
self.attention = PixtralHFAttention(
config, quant_config=quant_config, prefix=f"{prefix}.attention"
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attention",
)
self.feed_forward = PixtralHFMLP(
config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.feed_forward",
)
self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -1183,6 +1210,7 @@ class PixtralHFTransformer(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
prefix: str = "",
@@ -1199,6 +1227,7 @@ class PixtralHFTransformer(nn.Module):
PixtralHFTransformerBlock(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
)
for layer_idx in range(num_hidden_layers)
@@ -1230,6 +1259,7 @@ class PixtralHFVisionModel(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
@@ -1249,7 +1279,8 @@ class PixtralHFVisionModel(nn.Module):
self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
self.transformer = PixtralHFTransformer(
config,
quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.transformer",
)

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Callable, Iterable, Mapping
from functools import cached_property
from typing import Annotated, Literal
@@ -19,7 +18,7 @@ from transformers import (
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
@@ -276,7 +275,7 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
return image_size // patch_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
# Adapted from https://github.com/huggingface/transformers/blob/v4.57.3/src/transformers/models/siglip/modeling_siglip.py#L216
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
@@ -295,9 +294,7 @@ class SiglipVisionEmbeddings(nn.Module):
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches
self.position_embedding = VocabParallelEmbedding(
self.num_positions, self.embed_dim
)
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer(
"position_ids",
torch.arange(self.num_positions, dtype=torch.int64).expand((1, -1)),
@@ -307,50 +304,30 @@ class SiglipVisionEmbeddings(nn.Module):
def interpolate_pos_encoding(
self, embeddings: torch.Tensor, height: int, width: int
) -> torch.Tensor:
"""
This method is an adapted method for SigLIP (due to SigLIP not having
class embedding unlike other ViTs) that allows the model to interpolate
the pre-trained position encodings such that it can be usable on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
position_embeddings = self.position_embedding.weight.unsqueeze(0)
num_patches = embeddings.shape[1]
num_positions = position_embeddings.shape[1]
num_positions = self.position_embedding.weight.shape[1]
if num_patches == num_positions and height == width:
return position_embeddings
return self.position_embedding(self.position_ids)
patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
dim = embeddings.shape[-1]
height = height // self.patch_size
width = width // self.patch_size
# we add a small number to avoid floating point error
# in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
height, width = height + 0.1, width + 0.1
patch_pos_embed = position_embeddings.reshape(
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
new_height = height // self.patch_size
new_width = width // self.patch_size
sqrt_num_positions = int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(
1, sqrt_num_positions, sqrt_num_positions, dim
)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(
height / math.sqrt(num_positions),
width / math.sqrt(num_positions),
),
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
if (
int(height) != patch_pos_embed.shape[-2]
or int(width) != patch_pos_embed.shape[-1]
):
raise ValueError(
"Width or height does not match with "
"the interpolated position embeddings"
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed
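
The rewrite above replaces the scale-factor-based interpolation with an explicit target size. A standalone sketch of the same idea: resize a square grid of learned position embeddings to a new patch grid with bicubic interpolation (dimensions illustrative):

import torch
import torch.nn.functional as F


def interpolate_pos_embed(pos_embed: torch.Tensor, new_h: int, new_w: int) -> torch.Tensor:
    # pos_embed: [num_positions, dim], where num_positions is a perfect square.
    num_positions, dim = pos_embed.shape
    side = int(num_positions ** 0.5)
    grid = pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2)  # [1, dim, side, side]
    grid = F.interpolate(grid, size=(new_h, new_w), mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)


pos = torch.randn(16 * 16, 768)                   # pretrained 16x16 grid
print(interpolate_pos_embed(pos, 24, 24).shape)   # torch.Size([1, 576, 768])
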
@@ -377,6 +354,7 @@ class SiglipAttention(nn.Module):
self,
config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
@@ -389,19 +367,25 @@ class SiglipAttention(nn.Module):
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got "
"`embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and "
f"`num_heads`: {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
disable_tp=use_data_parallel,
)
self.out_proj = RowParallelLinear(
@@ -409,17 +393,29 @@ class SiglipAttention(nn.Module):
output_size=self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel,
)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.attn = attn_cls(
self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=f"{prefix}.attn",
)
if attn_cls == MMEncoderAttention:
self.attn = attn_cls(
self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
else:
self.attn = attn_cls(
self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=f"{prefix}.attn",
)
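
The branch above exists because MMEncoderAttention accepts a multimodal_config keyword that the plain attention class does not. A small, generic sketch of that construction pattern with illustrative stand-in classes:

class PlainAttention:
    def __init__(self, num_heads: int, head_dim: int, scale: float, *, prefix: str = ""):
        self.num_heads, self.head_dim, self.scale = num_heads, head_dim, scale


class EncoderAttention(PlainAttention):
    def __init__(self, num_heads, head_dim, scale, *, prefix="", multimodal_config=None):
        super().__init__(num_heads, head_dim, scale, prefix=prefix)
        self.multimodal_config = multimodal_config


def build_attn(attn_cls, num_heads, head_dim, scale, prefix, multimodal_config=None):
    # Only pass the extra keyword to the class that actually accepts it.
    if issubclass(attn_cls, EncoderAttention):
        return attn_cls(num_heads, head_dim, scale, prefix=prefix,
                        multimodal_config=multimodal_config)
    return attn_cls(num_heads, head_dim, scale, prefix=prefix)


print(type(build_attn(EncoderAttention, 16, 64, 64 ** -0.5, "vision.attn")).__name__)
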
def forward(
self,
@@ -439,12 +435,19 @@ class SiglipMLP(nn.Module):
self,
config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.activation_fn = get_act_fn(config.hidden_act)
# Special handling for BNB and torchao quantization
if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
quantizable = True
@@ -454,17 +457,20 @@ class SiglipMLP(nn.Module):
quantizable = (
config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
)
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config if quantizable else None,
prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel,
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config if quantizable else None,
prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel,
)
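
The quantizable gate above only forwards the quant config to fc1/fc2 when the method can handle SigLIP's shapes: bitsandbytes and torchao always qualify, other methods require both dimensions to be multiples of 64. A minimal sketch of that rule (the so400m numbers are just an example):

def quantizable_mlp(quant_name: str | None, hidden_size: int, intermediate_size: int) -> bool:
    # bitsandbytes / torchao tolerate arbitrary shapes; other quantization
    # methods need both dimensions to be multiples of 64.
    if quant_name in ("bitsandbytes", "torchao"):
        return True
    return hidden_size % 64 == 0 and intermediate_size % 64 == 0


print(quantizable_mlp("awq", 1152, 4304))           # False: 4304 % 64 != 0
print(quantizable_mlp("bitsandbytes", 1152, 4304))  # True
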
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -479,6 +485,7 @@ class SiglipEncoderLayer(nn.Module):
self,
config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
@@ -490,6 +497,7 @@ class SiglipEncoderLayer(nn.Module):
self.self_attn = SiglipAttention(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
attn_cls=attn_cls,
)
@@ -497,6 +505,7 @@ class SiglipEncoderLayer(nn.Module):
self.mlp = SiglipMLP(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -524,6 +533,7 @@ class SiglipEncoder(nn.Module):
self,
config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
num_hidden_layers_override: int | None = None,
*,
prefix: str = "",
@@ -543,6 +553,7 @@ class SiglipEncoder(nn.Module):
SiglipEncoderLayer(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
attn_cls=attn_cls,
)
@@ -647,6 +658,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -658,7 +670,10 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(
config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
@@ -683,6 +698,7 @@ class SiglipVisionTransformer(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
@@ -698,6 +714,7 @@ class SiglipVisionTransformer(nn.Module):
self.encoder = SiglipEncoder(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
attn_cls=MMEncoderAttention,
@@ -726,6 +743,7 @@ class SiglipVisionTransformer(nn.Module):
self.head = SiglipMultiheadAttentionPoolingHead(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.head",
)
@@ -812,13 +830,11 @@ class SiglipVisionTransformer(nn.Module):
class SiglipVisionModel(nn.Module):
config_class = SiglipVisionConfig
main_input_name = "pixel_values"
def __init__(
self,
config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
@@ -829,7 +845,8 @@ class SiglipVisionModel(nn.Module):
self.quant_config = quant_config
self.vision_model = SiglipVisionTransformer(
config,
quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model",
@@ -1023,6 +1040,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.vision_model = SiglipVisionTransformer(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)

View File

@@ -11,7 +11,6 @@ from torch.nn import functional as F
from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -186,7 +185,6 @@ class Siglip2Attention(nn.Module):
multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: AttentionBackendEnum | None = None,
):
super().__init__()
self.config = config
@@ -196,12 +194,11 @@ class Siglip2Attention(nn.Module):
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
f"(got `embed_dim`: {self.embed_dim} and "
f"`num_heads`: {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.is_causal = False
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
@@ -233,6 +230,7 @@ class Siglip2Attention(nn.Module):
self.attn = MMEncoderAttention(
num_heads=self.num_heads_per_partition,
head_size=self.head_dim,
scale=self.scale,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)

View File

@@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor
from transformers.processing_utils import ProcessingKwargs, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -346,6 +346,7 @@ def _build_tarsier_hf_processor(
def init_vision_tower_for_tarsier(
hf_config: TarsierHfConfig, # Use the Tarsier specific config protocol
quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*,
require_post_norm: bool | None = None,
prefix: str = "",
@@ -377,6 +378,7 @@ def init_vision_tower_for_tarsier(
return CLIPVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_to_init,
require_post_norm=require_post_norm,
prefix=prefix,
@@ -385,6 +387,7 @@ def init_vision_tower_for_tarsier(
return SiglipVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_to_init,
require_post_norm=require_post_norm,
prefix=prefix,
@@ -414,12 +417,16 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config: TarsierHfConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config # Storing the Tarsier-specific HF config
self.vision_tower = init_vision_tower_for_tarsier(
config,
quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)