[Model] Support SigLIP encoder and alternative decoders for LLaVA models (#7153)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Author: Cyrus Leung
Date: 2024-08-06 16:55:31 +08:00
Committed by: GitHub
Parent: 9118217f58
Commit: 1f26efbb3a
14 changed files with 453 additions and 267 deletions

File: vllm/model_executor/models/siglip.py

@@ -2,12 +2,12 @@
 within a vision language model."""
 
 import math
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 from PIL import Image
 from torch import nn
-from transformers import SiglipConfig, SiglipVisionConfig
+from transformers import SiglipVisionConfig
 from transformers.models.siglip.modeling_siglip import SiglipAttention
 from vllm_flash_attn import flash_attn_func
 from xformers.ops import memory_efficient_attention
@@ -22,13 +22,15 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.image import (cached_get_tokenizer,
                                    repeat_and_pad_image_tokens)
 from vllm.sequence import SequenceData
 
 
 def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
-    assert image_size % patch_size == 0
+    # Since interpolation is applied, the image size need not be divisible
+    # assert image_size % patch_size == 0
     return image_size // patch_size
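
A quick sanity check on why the divisibility assert had to go, using SigLIP
"so400m"-style sizes (image_size=384, patch_size=14) as an illustrative
example; these concrete numbers are not part of the diff itself:

    # 384 % 14 == 6, so the old assert would have rejected this config.
    # The stride-14 patch convolution simply truncates the remainder,
    # which is exactly what floor division computes:
    grid = get_siglip_patch_grid_length(image_size=384, patch_size=14)
    assert grid == 27  # a 27 x 27 grid -> 729 patch tokens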
@@ -454,7 +456,7 @@ class SiglipEncoderLayer(nn.Module):
 
     def __init__(
         self,
-        config: SiglipConfig,
+        config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -474,7 +476,7 @@ class SiglipEncoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor]:
) -> Tuple[torch.Tensor, None]:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
@@ -493,22 +495,27 @@ class SiglipEncoder(nn.Module):
 
     def __init__(
         self,
-        config: SiglipConfig,
+        config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
     ):
         super().__init__()
         self.config = config
 
+        if num_hidden_layers_override is None:
+            num_hidden_layers = config.num_hidden_layers
+        else:
+            num_hidden_layers = num_hidden_layers_override
+
         self.layers = nn.ModuleList([
-            SiglipEncoderLayer(
-                config,
-                quant_config=quant_config,
-            ) for _ in range(config.num_hidden_layers)
+            SiglipEncoderLayer(config, quant_config=quant_config)
+            for _ in range(num_hidden_layers)
         ])
 
     def forward(
         self,
         inputs_embeds: torch.Tensor,
-    ) -> Tuple:
+    ) -> torch.Tensor:
         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
             hidden_states, _ = encoder_layer(hidden_states)
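
The num_hidden_layers_override hook exists because LLaVA-style models usually
read an intermediate feature layer (vision_feature_layer, typically -2) and
never use the trailing encoder layers. A hedged sketch of how a caller might
compute the override; the helper name and config fields follow the HF LLaVA
config convention but are assumptions, not part of this diff:

    def _get_num_hidden_layers(hf_config) -> int:
        """Translate vision_feature_layer into the number of layers to build."""
        total = hf_config.vision_config.num_hidden_layers
        feature_layer = hf_config.vision_feature_layer  # e.g. -2
        # A negative index counts back from the end; +1 converts the selected
        # layer index into a layer count.
        return total + feature_layer + 1 if feature_layer < 0 else feature_layer + 1

    vision_tower = SiglipVisionModel(
        hf_config.vision_config,
        num_hidden_layers_override=_get_num_hidden_layers(hf_config),
    )

With vision_feature_layer=-2 and a 27-layer tower, this builds 26 layers and
skips instantiating the final one entirely.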
@@ -553,6 +560,7 @@ class SiglipVisionTransformer(nn.Module):
         self,
         config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
     ):
         super().__init__()
         self.config = config
@@ -562,6 +570,7 @@ class SiglipVisionTransformer(nn.Module):
         self.encoder = SiglipEncoder(
             config,
             quant_config=quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
         )
         self.post_layernorm = nn.LayerNorm(embed_dim,
                                            eps=config.layer_norm_eps)
@@ -600,11 +609,13 @@ class SiglipVisionModel(nn.Module):
         self,
         config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
     ):
         super().__init__()
 
         self.vision_model = SiglipVisionTransformer(
             config,
             quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
         )
 
     def get_input_embeddings(self) -> nn.Module:
@@ -619,3 +630,19 @@ class SiglipVisionModel(nn.Module):
             pixel_values=pixel_values,
             interpolate_pos_encoding=interpolate_pos_encoding,
         )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+        layer_count = len(self.vision_model.encoder.layers)
+
+        for name, loaded_weight in weights:
+            # omit layers when num_hidden_layers_override is set
+            if "vision_model.encoder.layers." in name:
+                layer_idx = int(name.split(".")[3])
+                if layer_idx >= layer_count:
+                    continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
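
A hedged usage sketch for the new load_weights; state_dict here is a
hypothetical stand-in for whatever iterable of (name, tensor) pairs the
caller streams in:

    # Build a tower truncated to 26 of the checkpoint's 27 layers, then feed
    # it the full checkpoint. Weights for layers >= 26 hit the `continue`
    # branch above instead of raising a KeyError against the shortened
    # nn.ModuleList.
    model = SiglipVisionModel(vision_config, num_hidden_layers_override=26)
    model.load_weights(state_dict.items())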