[Model] Standardize common vision encoders (#31947)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from functools import cached_property
|
||||
from typing import Annotated, Literal
|
||||
@@ -19,7 +18,7 @@ from transformers import (
|
||||
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.conv import Conv2dLayer
|
||||
@@ -276,7 +275,7 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
|
||||
return image_size // patch_size
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/v4.57.3/src/transformers/models/siglip/modeling_siglip.py#L216
|
||||
class SiglipVisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: SiglipVisionConfig):
|
||||
super().__init__()
|
||||
@@ -295,9 +294,7 @@ class SiglipVisionEmbeddings(nn.Module):
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches
|
||||
self.position_embedding = VocabParallelEmbedding(
|
||||
self.num_positions, self.embed_dim
|
||||
)
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
self.register_buffer(
|
||||
"position_ids",
|
||||
torch.arange(self.num_positions, dtype=torch.int64).expand((1, -1)),
|
||||
@@ -307,50 +304,30 @@ class SiglipVisionEmbeddings(nn.Module):
|
||||
def interpolate_pos_encoding(
|
||||
self, embeddings: torch.Tensor, height: int, width: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This method is an adapted method for SigLIP (due to SigLIP not having
|
||||
class embedding unlike other ViTs) that allows the model to interpolate
|
||||
the pre-trained position encodings such that it can be usable on higher
|
||||
resolution images.
|
||||
|
||||
Source:
|
||||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
||||
"""
|
||||
position_embeddings = self.position_embedding.weight.unsqueeze(0)
|
||||
num_patches = embeddings.shape[1]
|
||||
num_positions = position_embeddings.shape[1]
|
||||
num_positions = self.position_embedding.weight.shape[1]
|
||||
if num_patches == num_positions and height == width:
|
||||
return position_embeddings
|
||||
return self.position_embedding(self.position_ids)
|
||||
|
||||
patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
height = height // self.patch_size
|
||||
width = width // self.patch_size
|
||||
# we add a small number to avoid floating point error
|
||||
# in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
height, width = height + 0.1, width + 0.1
|
||||
|
||||
patch_pos_embed = position_embeddings.reshape(
|
||||
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(
|
||||
1, sqrt_num_positions, sqrt_num_positions, dim
|
||||
)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(
|
||||
height / math.sqrt(num_positions),
|
||||
width / math.sqrt(num_positions),
|
||||
),
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
if (
|
||||
int(height) != patch_pos_embed.shape[-2]
|
||||
or int(width) != patch_pos_embed.shape[-1]
|
||||
):
|
||||
raise ValueError(
|
||||
"Width or height does not match with "
|
||||
"the interpolated position embeddings"
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return patch_pos_embed
|
||||
@@ -377,6 +354,7 @@ class SiglipAttention(nn.Module):
|
||||
self,
|
||||
config: SiglipVisionConfig | SiglipTextConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
|
||||
@@ -389,19 +367,25 @@ class SiglipAttention(nn.Module):
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got "
|
||||
"`embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
f"embed_dim must be divisible by num_heads "
|
||||
f"(got `embed_dim`: {self.embed_dim} and "
|
||||
f"`num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
use_data_parallel = (
|
||||
multimodal_config.mm_encoder_tp_mode == "data"
|
||||
if multimodal_config
|
||||
else False
|
||||
)
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=self.embed_dim,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.num_heads,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
@@ -409,17 +393,29 @@ class SiglipAttention(nn.Module):
|
||||
output_size=self.embed_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_size = (
|
||||
1 if use_data_parallel else get_tensor_model_parallel_world_size()
|
||||
)
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
self.attn = attn_cls(
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim,
|
||||
self.scale,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
if attn_cls == MMEncoderAttention:
|
||||
self.attn = attn_cls(
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim,
|
||||
self.scale,
|
||||
prefix=f"{prefix}.attn",
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
else:
|
||||
self.attn = attn_cls(
|
||||
self.num_heads_per_partition,
|
||||
self.head_dim,
|
||||
self.scale,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -439,12 +435,19 @@ class SiglipMLP(nn.Module):
|
||||
self,
|
||||
config: SiglipVisionConfig | SiglipTextConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
use_data_parallel = (
|
||||
multimodal_config.mm_encoder_tp_mode == "data"
|
||||
if multimodal_config
|
||||
else False
|
||||
)
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
|
||||
# Special handling for BNB and torchao quantization
|
||||
if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
|
||||
quantizable = True
|
||||
@@ -454,17 +457,20 @@ class SiglipMLP(nn.Module):
|
||||
quantizable = (
|
||||
config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
|
||||
)
|
||||
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
quant_config=quant_config if quantizable else None,
|
||||
prefix=f"{prefix}.fc1",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config if quantizable else None,
|
||||
prefix=f"{prefix}.fc2",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -479,6 +485,7 @@ class SiglipEncoderLayer(nn.Module):
|
||||
self,
|
||||
config: SiglipVisionConfig | SiglipTextConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
|
||||
@@ -490,6 +497,7 @@ class SiglipEncoderLayer(nn.Module):
|
||||
self.self_attn = SiglipAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_cls=attn_cls,
|
||||
)
|
||||
@@ -497,6 +505,7 @@ class SiglipEncoderLayer(nn.Module):
|
||||
self.mlp = SiglipMLP(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
@@ -524,6 +533,7 @@ class SiglipEncoder(nn.Module):
|
||||
self,
|
||||
config: SiglipVisionConfig | SiglipTextConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
num_hidden_layers_override: int | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
@@ -543,6 +553,7 @@ class SiglipEncoder(nn.Module):
|
||||
SiglipEncoderLayer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
attn_cls=attn_cls,
|
||||
)
|
||||
@@ -647,6 +658,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -658,7 +670,10 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
||||
)
|
||||
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.mlp = SiglipMLP(
|
||||
config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
|
||||
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
@@ -683,6 +698,7 @@ class SiglipVisionTransformer(nn.Module):
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
*,
|
||||
num_hidden_layers_override: int | None = None,
|
||||
require_post_norm: bool | None = None,
|
||||
@@ -698,6 +714,7 @@ class SiglipVisionTransformer(nn.Module):
|
||||
self.encoder = SiglipEncoder(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
prefix=f"{prefix}.encoder",
|
||||
attn_cls=MMEncoderAttention,
|
||||
@@ -726,6 +743,7 @@ class SiglipVisionTransformer(nn.Module):
|
||||
self.head = SiglipMultiheadAttentionPoolingHead(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.head",
|
||||
)
|
||||
|
||||
@@ -812,13 +830,11 @@ class SiglipVisionTransformer(nn.Module):
|
||||
|
||||
|
||||
class SiglipVisionModel(nn.Module):
|
||||
config_class = SiglipVisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
*,
|
||||
num_hidden_layers_override: int | None = None,
|
||||
require_post_norm: bool | None = None,
|
||||
@@ -829,7 +845,8 @@ class SiglipVisionModel(nn.Module):
|
||||
self.quant_config = quant_config
|
||||
self.vision_model = SiglipVisionTransformer(
|
||||
config,
|
||||
quant_config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
require_post_norm=require_post_norm,
|
||||
prefix=f"{prefix}.vision_model",
|
||||
@@ -1023,6 +1040,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
self.vision_model = SiglipVisionTransformer(
|
||||
vision_config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=maybe_prefix(prefix, "vision_model"),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user