[Model] Standardize common vision encoders (#31947)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-01-08 18:33:16 +08:00
Committed by: GitHub
Parent: d1b6fe007f
Commit: 5576227bc1
19 changed files with 253 additions and 173 deletions

vllm/model_executor/models/siglip.py

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import math
 from collections.abc import Callable, Iterable, Mapping
 from functools import cached_property
 from typing import Annotated, Literal
@@ -19,7 +18,7 @@ from transformers import (
 from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import VllmConfig
-from vllm.config.multimodal import BaseDummyOptions
+from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.conv import Conv2dLayer
@@ -276,7 +275,7 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
         return image_size // patch_size


-# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
+# Adapted from https://github.com/huggingface/transformers/blob/v4.57.3/src/transformers/models/siglip/modeling_siglip.py#L216
 class SiglipVisionEmbeddings(nn.Module):
     def __init__(self, config: SiglipVisionConfig):
         super().__init__()
@@ -295,9 +294,7 @@ class SiglipVisionEmbeddings(nn.Module):
         self.num_patches = (self.image_size // self.patch_size) ** 2
         self.num_positions = self.num_patches
-        self.position_embedding = VocabParallelEmbedding(
-            self.num_positions, self.embed_dim
-        )
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
         self.register_buffer(
             "position_ids",
             torch.arange(self.num_positions, dtype=torch.int64).expand((1, -1)),
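
Note on the embedding swap: the position table is small (`num_positions x embed_dim`) and every rank needs all of its rows, so a plain replicated `nn.Embedding` is a simpler fit than a vocab-sharded `VocabParallelEmbedding`. A standalone sketch of the lookup, with illustrative sizes not taken from this diff:

```python
import torch
from torch import nn

# Toy dimensions: a 224x224 image with 16x16 patches -> 196 positions.
num_positions, embed_dim = 196, 768
position_embedding = nn.Embedding(num_positions, embed_dim)

# Same buffer as in SiglipVisionEmbeddings: shape (1, num_positions).
position_ids = torch.arange(num_positions, dtype=torch.int64).expand((1, -1))

pos = position_embedding(position_ids)  # (1, 196, 768), added to patch embeds
print(pos.shape)
```
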
@@ -307,50 +304,30 @@ class SiglipVisionEmbeddings(nn.Module):
     def interpolate_pos_encoding(
         self, embeddings: torch.Tensor, height: int, width: int
     ) -> torch.Tensor:
-        """
-        This method is an adapted method for SigLIP (due to SigLIP not having
-        class embedding unlike other ViTs) that allows the model to interpolate
-        the pre-trained position encodings such that it can be usable on higher
-        resolution images.
-
-        Source:
-        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
-        """
-        position_embeddings = self.position_embedding.weight.unsqueeze(0)
         num_patches = embeddings.shape[1]
-        num_positions = position_embeddings.shape[1]
+        num_positions = self.position_embedding.weight.shape[0]
         if num_patches == num_positions and height == width:
-            return position_embeddings
+            return self.position_embedding(self.position_ids)
+
+        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

         dim = embeddings.shape[-1]
-        height = height // self.patch_size
-        width = width // self.patch_size
-        # we add a small number to avoid floating point error
-        # in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        height, width = height + 0.1, width + 0.1
-        patch_pos_embed = position_embeddings.reshape(
-            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
+
+        new_height = height // self.patch_size
+        new_width = width // self.patch_size
+
+        sqrt_num_positions = int(num_positions**0.5)
+        patch_pos_embed = patch_pos_embed.reshape(
+            1, sqrt_num_positions, sqrt_num_positions, dim
         )
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
-            scale_factor=(
-                height / math.sqrt(num_positions),
-                width / math.sqrt(num_positions),
-            ),
+            size=(new_height, new_width),
             mode="bicubic",
             align_corners=False,
         )
-        if (
-            int(height) != patch_pos_embed.shape[-2]
-            or int(width) != patch_pos_embed.shape[-1]
-        ):
-            raise ValueError(
-                "Width or height does not match with "
-                "the interpolated position embeddings"
-            )
+
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return patch_pos_embed
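
The rewritten method asks `nn.functional.interpolate` for the target grid directly via `size=`, so the old `+0.1` epsilon trick and the post-hoc shape check become unnecessary. A self-contained sketch of the same resize path, with illustrative sizes:

```python
import torch
import torch.nn.functional as F

# Pretrained table: a 16x16 grid of 32-dim position embeddings.
num_positions, dim = 256, 32
pos = torch.randn(1, num_positions, dim)

side = int(num_positions**0.5)                               # 16
grid = pos.reshape(1, side, side, dim).permute(0, 3, 1, 2)   # (1, dim, 16, 16)

# Target: a 384x384 image with 16x16 patches -> a 24x24 grid.
new_h, new_w = 24, 24
resized = F.interpolate(
    grid, size=(new_h, new_w), mode="bicubic", align_corners=False
)

out = resized.permute(0, 2, 3, 1).view(1, -1, dim)           # (1, 576, dim)
assert out.shape == (1, new_h * new_w, dim)  # exact size, no rounding check needed
```
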
@@ -377,6 +354,7 @@ class SiglipAttention(nn.Module):
         self,
         config: SiglipVisionConfig | SiglipTextConfig,
         quant_config: QuantizationConfig | None = None,
+        multimodal_config: MultiModalConfig | None = None,
         *,
         prefix: str = "",
         attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
@@ -389,19 +367,25 @@ class SiglipAttention(nn.Module):
         self.head_dim = self.embed_dim // self.num_heads
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
-                f"embed_dim must be divisible by num_heads (got "
-                "`embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {self.num_heads})."
+                f"embed_dim must be divisible by num_heads "
+                f"(got `embed_dim`: {self.embed_dim} and "
+                f"`num_heads`: {self.num_heads})."
             )
         self.scale = self.head_dim**-0.5
         self.dropout = config.attention_dropout

+        use_data_parallel = (
+            multimodal_config.mm_encoder_tp_mode == "data"
+            if multimodal_config
+            else False
+        )
         self.qkv_proj = QKVParallelLinear(
             hidden_size=self.embed_dim,
             head_size=self.head_dim,
             total_num_heads=self.num_heads,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
+            disable_tp=use_data_parallel,
         )
         self.out_proj = RowParallelLinear(
@@ -409,17 +393,29 @@ class SiglipAttention(nn.Module):
             output_size=self.embed_dim,
             quant_config=quant_config,
             prefix=f"{prefix}.out_proj",
+            disable_tp=use_data_parallel,
         )

-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = (
+            1 if use_data_parallel else get_tensor_model_parallel_world_size()
+        )
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

-        self.attn = attn_cls(
-            self.num_heads_per_partition,
-            self.head_dim,
-            self.scale,
-            prefix=f"{prefix}.attn",
-        )
+        if attn_cls == MMEncoderAttention:
+            self.attn = attn_cls(
+                self.num_heads_per_partition,
+                self.head_dim,
+                self.scale,
+                prefix=f"{prefix}.attn",
+                multimodal_config=multimodal_config,
+            )
+        else:
+            self.attn = attn_cls(
+                self.num_heads_per_partition,
+                self.head_dim,
+                self.scale,
+                prefix=f"{prefix}.attn",
+            )

     def forward(
         self,
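
The effect of the new toggle: with `mm_encoder_tp_mode == "data"`, the QKV and output projections keep full-width weights on every rank (`disable_tp=True`) and `tp_size` is forced to 1, so no heads are split; otherwise heads divide across the tensor-parallel group as before. A plain-Python sketch of that arithmetic (the `"weights"` literal is the assumed non-data mode, not shown in this diff):

```python
def heads_per_partition(
    num_heads: int, tp_world_size: int, mm_encoder_tp_mode: str
) -> int:
    """Mirrors the tp_size / divide() logic in SiglipAttention above."""
    use_data_parallel = mm_encoder_tp_mode == "data"
    tp_size = 1 if use_data_parallel else tp_world_size
    assert num_heads % tp_size == 0, "divide() requires an exact split"
    return num_heads // tp_size


# 16 attention heads on a 4-way tensor-parallel group:
print(heads_per_partition(16, 4, "weights"))  # 4 -> heads sharded across ranks
print(heads_per_partition(16, 4, "data"))     # 16 -> full heads on every rank
```
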
@@ -439,12 +435,19 @@ class SiglipMLP(nn.Module):
         self,
         config: SiglipVisionConfig | SiglipTextConfig,
         quant_config: QuantizationConfig | None = None,
+        multimodal_config: MultiModalConfig | None = None,
         prefix: str = "",
     ) -> None:
         super().__init__()

         self.config = config
+        use_data_parallel = (
+            multimodal_config.mm_encoder_tp_mode == "data"
+            if multimodal_config
+            else False
+        )

         self.activation_fn = get_act_fn(config.hidden_act)
         # Special handling for BNB and torchao quantization
         if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
             quantizable = True
@@ -454,17 +457,20 @@ class SiglipMLP(nn.Module):
             quantizable = (
                 config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
             )
         self.fc1 = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
             quant_config=quant_config if quantizable else None,
             prefix=f"{prefix}.fc1",
+            disable_tp=use_data_parallel,
         )
         self.fc2 = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             quant_config=quant_config if quantizable else None,
             prefix=f"{prefix}.fc2",
+            disable_tp=use_data_parallel,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
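
Same pattern for the MLP: `disable_tp=True` keeps fc1/fc2 unsharded, while the TP path shards fc1 along its output dimension and fc2 along its input dimension. A sketch of the resulting per-rank weight shapes (illustrative sizes; the sharding rule is standard Megatron-style column/row parallelism):

```python
def mlp_shard_shapes(hidden: int, intermediate: int, tp: int, disable_tp: bool):
    """Per-rank (out_features, in_features) weight shapes for fc1/fc2,
    mirroring ColumnParallelLinear / RowParallelLinear sharding."""
    if disable_tp:
        tp = 1
    fc1 = (intermediate // tp, hidden)   # column-parallel: output dim sharded
    fc2 = (hidden, intermediate // tp)   # row-parallel: input dim sharded
    return fc1, fc2


# SigLIP-base sizes (hidden=768, intermediate=3072) on a 4-way TP group:
print(mlp_shard_shapes(768, 3072, 4, disable_tp=False))  # ((768, 768), (768, 768))
print(mlp_shard_shapes(768, 3072, 4, disable_tp=True))   # ((3072, 768), (768, 3072))
```
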
@@ -479,6 +485,7 @@ class SiglipEncoderLayer(nn.Module):
         self,
         config: SiglipVisionConfig | SiglipTextConfig,
         quant_config: QuantizationConfig | None = None,
+        multimodal_config: MultiModalConfig | None = None,
         *,
         prefix: str = "",
         attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
@@ -490,6 +497,7 @@ class SiglipEncoderLayer(nn.Module):
         self.self_attn = SiglipAttention(
             config,
             quant_config=quant_config,
+            multimodal_config=multimodal_config,
             prefix=f"{prefix}.self_attn",
             attn_cls=attn_cls,
         )
@@ -497,6 +505,7 @@ class SiglipEncoderLayer(nn.Module):
         self.mlp = SiglipMLP(
             config,
             quant_config=quant_config,
+            multimodal_config=multimodal_config,
             prefix=f"{prefix}.mlp",
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -524,6 +533,7 @@ class SiglipEncoder(nn.Module):
         self,
         config: SiglipVisionConfig | SiglipTextConfig,
         quant_config: QuantizationConfig | None = None,
+        multimodal_config: MultiModalConfig | None = None,
         num_hidden_layers_override: int | None = None,
         *,
         prefix: str = "",
@@ -543,6 +553,7 @@ class SiglipEncoder(nn.Module):
                 SiglipEncoderLayer(
                     config,
                     quant_config=quant_config,
+                    multimodal_config=multimodal_config,
                     prefix=f"{prefix}.layers.{layer_idx}",
                     attn_cls=attn_cls,
                 )
@@ -647,6 +658,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
         self,
         config: SiglipVisionConfig,
         quant_config: QuantizationConfig | None = None,
+        multimodal_config: MultiModalConfig | None = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -658,7 +670,10 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
         )
         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
-            config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
+            config=config,
+            quant_config=quant_config,
+            multimodal_config=multimodal_config,
+            prefix=f"{prefix}.mlp",
         )

     def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
@@ -683,6 +698,7 @@ class SiglipVisionTransformer(nn.Module):
         self,
         config: SiglipVisionConfig,
         quant_config: QuantizationConfig | None = None,
+        multimodal_config: MultiModalConfig | None = None,
         *,
         num_hidden_layers_override: int | None = None,
         require_post_norm: bool | None = None,
@@ -698,6 +714,7 @@ class SiglipVisionTransformer(nn.Module):
         self.encoder = SiglipEncoder(
             config,
             quant_config=quant_config,
+            multimodal_config=multimodal_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.encoder",
             attn_cls=MMEncoderAttention,
@@ -726,6 +743,7 @@ class SiglipVisionTransformer(nn.Module):
             self.head = SiglipMultiheadAttentionPoolingHead(
                 config=config,
                 quant_config=quant_config,
+                multimodal_config=multimodal_config,
                 prefix=f"{prefix}.head",
             )
@@ -812,13 +830,11 @@ class SiglipVisionTransformer(nn.Module):
 class SiglipVisionModel(nn.Module):
-    config_class = SiglipVisionConfig
-    main_input_name = "pixel_values"
-
     def __init__(
         self,
         config: SiglipVisionConfig,
         quant_config: QuantizationConfig | None = None,
+        multimodal_config: MultiModalConfig | None = None,
         *,
         num_hidden_layers_override: int | None = None,
         require_post_norm: bool | None = None,
@@ -829,7 +845,8 @@ class SiglipVisionModel(nn.Module):
         self.quant_config = quant_config
         self.vision_model = SiglipVisionTransformer(
             config,
-            quant_config,
+            quant_config=quant_config,
+            multimodal_config=multimodal_config,
             num_hidden_layers_override=num_hidden_layers_override,
             require_post_norm=require_post_norm,
             prefix=f"{prefix}.vision_model",
@@ -1023,6 +1040,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
         self.vision_model = SiglipVisionTransformer(
             vision_config,
             quant_config=quant_config,
+            multimodal_config=multimodal_config,
             prefix=maybe_prefix(prefix, "vision_model"),
         )
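
Taken together, the plumbing lets callers opt the vision encoder into data-parallel mode at construction time. A construction sketch using only the signatures visible in this diff; it assumes an initialized vLLM distributed environment (the parallel linear layers need process groups), so it is not runnable standalone:

```python
from transformers import SiglipVisionConfig

# Placeholder configs: real callers pull these from VllmConfig
# (e.g. vllm_config.model_config.multimodal_config).
hf_config = SiglipVisionConfig()
mm_config = None  # None -> use_data_parallel=False, i.e. the TP path

vision_tower = SiglipVisionModel(
    hf_config,
    quant_config=None,
    multimodal_config=mm_config,
    prefix="vision_tower",
)
```
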