[Model] Standardize common vision encoders (#31947)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -17,7 +17,7 @@ from transformers import (
from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
@@ -353,6 +353,7 @@ class CLIPAttention(nn.Module):
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[Attention] | type[MMEncoderAttention],
@@ -365,18 +366,24 @@ class CLIPAttention(nn.Module):
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and "
                f"`num_heads`: {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5

        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )

        self.out_proj = RowParallelLinear(
@@ -384,17 +391,29 @@ class CLIPAttention(nn.Module):
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = attn_cls(
            self.num_heads_per_partition,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )
        if attn_cls == MMEncoderAttention:
            self.attn = attn_cls(
                self.num_heads_per_partition,
                self.head_dim,
                self.scale,
                prefix=f"{prefix}.attn",
                multimodal_config=multimodal_config,
            )
        else:
            self.attn = attn_cls(
                self.num_heads_per_partition,
                self.head_dim,
                self.scale,
                prefix=f"{prefix}.attn",
            )

    def forward(
        self,
@@ -415,17 +434,26 @@ class CLIPMLP(nn.Module):
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.activation_fn = get_act_fn(config.hidden_act)

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
@@ -433,6 +461,7 @@ class CLIPMLP(nn.Module):
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -448,19 +477,27 @@ class CLIPEncoderLayer(nn.Module):
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[Attention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        self.self_attn = CLIPAttention(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.self_attn",
            attn_cls=attn_cls,
        )
        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp")
        self.mlp = CLIPMLP(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -491,6 +528,7 @@ class CLIPEncoder(nn.Module):
        self,
        config: CLIPTextConfig | CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        num_hidden_layers_override: int | None = None,
        *,
        prefix: str = "",
@@ -504,11 +542,13 @@ class CLIPEncoder(nn.Module):
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList(
            [
                CLIPEncoderLayer(
                    config=config,
                    quant_config=quant_config,
                    multimodal_config=multimodal_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    attn_cls=attn_cls,
                )
@@ -618,6 +658,7 @@ class CLIPVisionTransformer(nn.Module):
        self,
        config: CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
@@ -637,6 +678,7 @@ class CLIPVisionTransformer(nn.Module):
        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
            attn_cls=MMEncoderAttention,
@@ -738,6 +780,7 @@ class CLIPVisionModel(nn.Module):
        self,
        config: CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
@@ -748,6 +791,7 @@ class CLIPVisionModel(nn.Module):
        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model",
@@ -817,6 +861,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
        self.vision_model = CLIPVisionTransformer(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )
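Note: the `attn_cls` branch added above exists because only MMEncoderAttention accepts the multimodal config, while the plain Attention path keeps its original signature. A minimal sketch of that dispatch, using only names that appear in the diff (the helper name is hypothetical, for illustration only):

def _build_encoder_attn(attn_cls, num_heads, head_dim, scale, prefix, multimodal_config):
    # Sketch only: mirrors the branch in CLIPAttention / SiglipAttention above.
    if attn_cls == MMEncoderAttention:
        # The multimodal encoder attention needs the config to pick its behavior.
        return attn_cls(
            num_heads,
            head_dim,
            scale,
            prefix=f"{prefix}.attn",
            multimodal_config=multimodal_config,
        )
    # Any other attention class keeps the original call shape.
    return attn_cls(num_heads, head_dim, scale, prefix=f"{prefix}.attn")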
@@ -19,6 +19,7 @@ import torch.nn.functional as F
from transformers import CLIPVisionConfig

from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -608,6 +609,7 @@ class DeepCLIPVisionTransformer(nn.Module):
        self,
        config: CLIPVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        prefix: str = "",
@@ -626,6 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module):
        self.transformer = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
            attn_cls=MMEncoderAttention,
@@ -397,6 +397,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
        self.vision_model = DeepCLIPVisionTransformer(
            config=clip_vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )
@@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from itertools import accumulate
from typing import Annotated, Any, Literal
from typing import Annotated, Literal

import numpy as np
import torch
@@ -18,7 +18,7 @@ from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
from transformers.modeling_utils import no_init_weights

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
@@ -361,6 +361,7 @@ def _build_hcxvision_hf_processor(
def init_vision_tower_for_hcxvision(
    vision_config,
    quant_config: QuantizationConfig | None,
    multimodal_config: MultiModalConfig | None,
    *,
    use_nth_layer: int | None = None,
    require_post_norm: bool | None = None,
@@ -378,6 +379,7 @@ def init_vision_tower_for_hcxvision(
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
@@ -386,6 +388,7 @@ def init_vision_tower_for_hcxvision(
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
@@ -597,18 +600,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        **kwargs: Any | None,
    ) -> None:
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        # init configs
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        # text_config
        text_config = config.text_config
        if text_config.model_type in ["gpt2", "hyperclovax", "llama"]:
@@ -631,7 +629,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
        with no_init_weights():  # weight will be loaded in from_pretrained
            self.vision_model = init_vision_tower_for_hcxvision(
                vision_config,
                quant_config,
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                use_nth_layer=getattr(config, "use_nth_layer", -1),
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_model"),
@@ -1226,8 +1226,8 @@ class IsaacVisionEmbedding(nn.Module):
        self.transformer = Siglip2VisionTransformer(
            vision_cfg,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "0"),
            multimodal_config=multimodal_config,
            prefix=maybe_prefix(prefix, "0"),
        )
        self.linear_fc1 = ColumnParallelLinear(
            hidden_dim,
@@ -404,6 +404,7 @@ class KeyeSiglipAttention(nn.Module):
        self.attn = MMEncoderAttention(
            num_heads=self.num_heads,
            head_size=self.head_dim,
            scale=self.scale,
            num_kv_heads=self.num_kv_heads,
            prefix=f"{prefix}.attn",
            multimodal_config=multimodal_config,
@@ -511,6 +512,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
        )
@@ -155,6 +155,7 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        nn.Module.__init__(self)

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
@@ -164,7 +165,8 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration):

        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
@@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -468,6 +468,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: QuantizationConfig | None,
    multimodal_config: MultiModalConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
@@ -481,6 +482,7 @@ def init_vision_tower_for_llava(
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
@@ -489,6 +491,7 @@ def init_vision_tower_for_llava(
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
@@ -497,6 +500,7 @@ def init_vision_tower_for_llava(
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
@@ -563,7 +567,8 @@ class LlavaForConditionalGeneration(
        if multimodal_config.get_limit_per_prompt("image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config,
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
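Call sites now pass both configs by keyword so the new multimodal_config can be threaded down to the encoder. A sketch of the standardized call shape inside a Llava-style __init__ (variable names follow the diff; the surrounding model is assumed):

config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config

vision_tower = init_vision_tower_for_llava(
    config,
    quant_config=quant_config,            # previously positional
    multimodal_config=multimodal_config,  # new, forwarded into the vision tower
    require_post_norm=False,
    prefix=maybe_prefix(prefix, "vision_tower"),
)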
@@ -243,6 +243,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
@@ -270,7 +271,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
@@ -321,6 +321,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
@@ -331,7 +332,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
        # Initialize the vision tower only up to the required feature layer
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
@@ -511,7 +511,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
        # Initialize the vision tower only up to the required feature layer
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
@@ -204,7 +204,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
@@ -16,7 +16,7 @@ from transformers import (
from transformers.models.pixtral import PixtralProcessor

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -395,6 +395,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: QuantizationConfig | None,
    multimodal_config: MultiModalConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
@@ -409,6 +410,7 @@ def init_vision_tower_for_llava(
        return PixtralHFVisionModel(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers,
            require_post_norm=require_post_norm,
            prefix=prefix,
@@ -472,7 +474,8 @@ class Mistral3ForConditionalGeneration(
        if multimodal_config.get_limit_per_prompt("image"):
            self.vision_tower = init_vision_tower_for_llava(
                config,
                quant_config,
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                require_post_norm=False,
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
@@ -38,10 +38,8 @@ from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
@@ -77,6 +75,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .ernie45 import Ernie4_5ForCausalLM
from .interfaces import MultiModalEmbeddings, SupportsMRoPE, SupportsMultiModal
from .siglip import SiglipMLP
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
@@ -657,46 +656,6 @@ class SigLIPRotaryEmbedding(nn.Module):
        return freqs


class SiglipMLP(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        # Special handling for BNB and torchao quantization
        if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
            quantizable = True
        else:
            # For other quantization, we require the hidden size to be a
            # multiple of 64
            quantizable = (
                config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
            )
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config if quantizable else None,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config if quantizable else None,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class SiglipEncoderLayer(nn.Module):
    def __init__(
        self,
@@ -720,6 +679,7 @@ class SiglipEncoderLayer(nn.Module):
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
        )
@@ -29,7 +29,7 @@ from transformers import (
)

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -96,6 +96,7 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
def _init_img_processor(
    hf_config: PretrainedConfig,
    quant_config: QuantizationConfig | None,
    multimodal_config: MultiModalConfig | None,
    prefix: str = "",
) -> CLIPVisionModel:
    clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
@@ -109,7 +110,8 @@ def _init_img_processor(

    img_processor = CLIPVisionModel(
        clip_config,
        quant_config,
        quant_config=quant_config,
        multimodal_config=multimodal_config,
        num_hidden_layers_override=num_hidden_layers,
        prefix=prefix,
    )
@@ -160,38 +162,15 @@ class Phi3VImageEmbeddingInputs(TensorSchema):
Phi3VImageInputs: TypeAlias = Phi3VImagePixelInputs | Phi3VImageEmbeddingInputs


class Phi3ImageEmbeddingBase(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        TYPE_FEATURE = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if TYPE_FEATURE == "patch":
            patch_feature = img_feature[:, 1:]
            return patch_feature

        if TYPE_FEATURE == "cls_patch":
            return img_feature

        raise NotImplementedError


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
class Phi3HDImageEmbedding(nn.Module):
    """Phi3 Image embedding with HD transform."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        multimodal_config: MultiModalConfig | None,
        prefix: str = "",
    ) -> None:
        super().__init__()
@@ -200,7 +179,10 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
        hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size

        self.img_processor = _init_img_processor(
            config, quant_config, prefix=f"{prefix}.img_processor"
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.img_processor",
        )

        image_dim_out = config.img_processor["image_dim_out"]
@@ -223,13 +205,29 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):

        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        layers: list[nn.Module] = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get("type_feature", "patch")

    def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        type_feature = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if type_feature == "patch":
            patch_feature = img_feature[:, 1:]
            return patch_feature

        if type_feature == "cls_patch":
            return img_feature

        raise NotImplementedError(type_feature)

    def forward(
        self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor
    ) -> torch.FloatTensor:
@@ -582,6 +580,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
@@ -590,14 +589,15 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=self.quant_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "model.embed_tokens"),
        )

        # TODO: Optionally initializes this for supporting input embeddings.
        self.vision_embed_tokens = Phi3HDImageEmbedding(
            config,
            self.quant_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
        )
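The get_img_features logic moved from the removed base class into Phi3HDImageEmbedding itself; the type_feature switch simply keeps or drops the leading CLS token of the CLIP output. A small illustrative sketch (tensor shapes are assumptions, not taken from the diff):

import torch

# img_feature: (batch, 1 + num_patches, hidden) from the CLIP tower, where
# index 0 is the CLS token and the rest are patch embeddings (assumed layout).
img_feature = torch.randn(2, 1 + 576, 1024)

type_feature = "patch"
if type_feature == "patch":
    feature = img_feature[:, 1:]   # drop CLS -> (2, 576, 1024)
elif type_feature == "cls_patch":
    feature = img_feature          # keep CLS + patches
else:
    raise NotImplementedError(type_feature)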
@@ -28,7 +28,7 @@ from transformers.models.pixtral.modeling_pixtral import (
from transformers.tokenization_utils_base import TextInput

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv2dLayer
@@ -1043,11 +1043,18 @@ class PixtralHFMLP(nn.Module):
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
        super().__init__()

        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )

        assert config.intermediate_size is not None
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=config.hidden_size,
@@ -1055,6 +1062,7 @@ class PixtralHFMLP(nn.Module):
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
            disable_tp=use_data_parallel,
        )
        self.down_proj = RowParallelLinear(
            input_size=config.intermediate_size,
@@ -1062,6 +1070,7 @@ class PixtralHFMLP(nn.Module):
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
            disable_tp=use_data_parallel,
        )
        self.act_and_mul = get_act_and_mul_fn(config.hidden_act)

@@ -1077,6 +1086,7 @@ class PixtralHFAttention(nn.Module):
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
@@ -1085,10 +1095,14 @@ class PixtralHFAttention(nn.Module):
        self.config = config
        assert not config.hidden_size % config.num_attention_heads
        self.total_num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        self.n_heads = divide(config.num_attention_heads, tp_size)
        self.head_dim = config.hidden_size // config.num_attention_heads
        assert self.total_num_heads * self.head_dim == config.hidden_size

        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.qkv_proj = QKVParallelLinear(
            hidden_size=config.hidden_size,
            head_size=self.head_dim,
@@ -1096,16 +1110,22 @@ class PixtralHFAttention(nn.Module):
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )
        assert self.total_num_heads * self.head_dim == config.hidden_size
        self.o_proj = RowParallelLinear(
            input_size=config.hidden_size,
            output_size=config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
            disable_tp=use_data_parallel,
        )

        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.n_heads = divide(config.num_attention_heads, self.tp_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -1147,6 +1167,7 @@ class PixtralHFTransformerBlock(nn.Module):
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        prefix: str = "",
    ) -> None:
@@ -1154,10 +1175,16 @@ class PixtralHFTransformerBlock(nn.Module):

        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
        self.attention = PixtralHFAttention(
            config, quant_config=quant_config, prefix=f"{prefix}.attention"
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.attention",
        )
        self.feed_forward = PixtralHFMLP(
            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.feed_forward",
        )
        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)

@@ -1183,6 +1210,7 @@ class PixtralHFTransformer(nn.Module):
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        prefix: str = "",
@@ -1199,6 +1227,7 @@ class PixtralHFTransformer(nn.Module):
                PixtralHFTransformerBlock(
                    config=config,
                    quant_config=quant_config,
                    multimodal_config=multimodal_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                )
                for layer_idx in range(num_hidden_layers)
@@ -1230,6 +1259,7 @@ class PixtralHFVisionModel(nn.Module):
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
@@ -1249,7 +1279,8 @@ class PixtralHFVisionModel(nn.Module):
        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
        self.transformer = PixtralHFTransformer(
            config,
            quant_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.transformer",
        )
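As in the CLIP and SigLIP attention modules, Pixtral now derives its effective tensor-parallel size from the data-parallel flag instead of always sharding heads. A sketch of that head-partition arithmetic under the same assumptions:

from vllm.distributed import divide, get_tensor_model_parallel_world_size

def effective_heads(total_num_heads: int, multimodal_config) -> int:
    # Sketch: when the encoder runs data-parallel, every rank keeps all heads
    # (tp_size is treated as 1); otherwise heads are split across TP ranks.
    use_data_parallel = (
        multimodal_config.mm_encoder_tp_mode == "data"
        if multimodal_config
        else False
    )
    tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
    return divide(total_num_heads, tp_size)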
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from collections.abc import Callable, Iterable, Mapping
from functools import cached_property
from typing import Annotated, Literal
@@ -19,7 +18,7 @@ from transformers import (
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
@@ -276,7 +275,7 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
        return image_size // patch_size


# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
# Adapted from https://github.com/huggingface/transformers/blob/v4.57.3/src/transformers/models/siglip/modeling_siglip.py#L216
class SiglipVisionEmbeddings(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
@@ -295,9 +294,7 @@ class SiglipVisionEmbeddings(nn.Module):

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim
        )
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions, dtype=torch.int64).expand((1, -1)),
@@ -307,50 +304,30 @@ class SiglipVisionEmbeddings(nn.Module):
    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        This method is an adapted method for SigLIP (due to SigLIP not having
        class embedding unlike other ViTs) that allows the model to interpolate
        the pre-trained position encodings such that it can be usable on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        position_embeddings = self.position_embedding.weight.unsqueeze(0)
        num_patches = embeddings.shape[1]
        num_positions = position_embeddings.shape[1]
        num_positions = self.position_embedding.weight.shape[1]
        if num_patches == num_positions and height == width:
            return position_embeddings
            return self.position_embedding(self.position_ids)

        patch_pos_embed = self.position_embedding.weight.unsqueeze(0)

        dim = embeddings.shape[-1]
        height = height // self.patch_size
        width = width // self.patch_size
        # we add a small number to avoid floating point error
        # in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        height, width = height + 0.1, width + 0.1

        patch_pos_embed = position_embeddings.reshape(
            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(
            1, sqrt_num_positions, sqrt_num_positions, dim
        )
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(
                height / math.sqrt(num_positions),
                width / math.sqrt(num_positions),
            ),
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )
        if (
            int(height) != patch_pos_embed.shape[-2]
            or int(width) != patch_pos_embed.shape[-1]
        ):
            raise ValueError(
                "Width or height does not match with "
                "the interpolated position embeddings"
            )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed
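The rewritten interpolate_pos_encoding resizes the square grid of pre-trained position embeddings to the new patch grid with bicubic interpolation and an explicit target size. A standalone sketch of the core resize (shapes are illustrative assumptions):

import torch
import torch.nn.functional as F

# position_embedding weight: (num_positions, dim) over a square patch grid.
num_positions, dim = 576, 768            # e.g. a 24x24 grid (assumed values)
weight = torch.randn(num_positions, dim)

new_height, new_width = 32, 20           # target patch grid for the input image
side = int(num_positions**0.5)

patch_pos_embed = weight.reshape(1, side, side, dim).permute(0, 3, 1, 2)
patch_pos_embed = F.interpolate(
    patch_pos_embed,
    size=(new_height, new_width),
    mode="bicubic",
    align_corners=False,
)
# Back to (1, new_height * new_width, dim) for adding to the patch embeddings.
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)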
@@ -377,6 +354,7 @@ class SiglipAttention(nn.Module):
        self,
        config: SiglipVisionConfig | SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
@@ -389,19 +367,25 @@ class SiglipAttention(nn.Module):
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got "
                "`embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and "
                f"`num_heads`: {self.num_heads})."
            )

        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )

        self.out_proj = RowParallelLinear(
@@ -409,17 +393,29 @@ class SiglipAttention(nn.Module):
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = attn_cls(
            self.num_heads_per_partition,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )
        if attn_cls == MMEncoderAttention:
            self.attn = attn_cls(
                self.num_heads_per_partition,
                self.head_dim,
                self.scale,
                prefix=f"{prefix}.attn",
                multimodal_config=multimodal_config,
            )
        else:
            self.attn = attn_cls(
                self.num_heads_per_partition,
                self.head_dim,
                self.scale,
                prefix=f"{prefix}.attn",
            )
    def forward(
        self,
@@ -439,12 +435,19 @@ class SiglipMLP(nn.Module):
        self,
        config: SiglipVisionConfig | SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
            if multimodal_config
            else False
        )
        self.activation_fn = get_act_fn(config.hidden_act)

        # Special handling for BNB and torchao quantization
        if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
            quantizable = True
@@ -454,17 +457,20 @@ class SiglipMLP(nn.Module):
            quantizable = (
                config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
            )

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config if quantizable else None,
            prefix=f"{prefix}.fc1",
            disable_tp=use_data_parallel,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config if quantizable else None,
            prefix=f"{prefix}.fc2",
            disable_tp=use_data_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -479,6 +485,7 @@ class SiglipEncoderLayer(nn.Module):
        self,
        config: SiglipVisionConfig | SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
@@ -490,6 +497,7 @@ class SiglipEncoderLayer(nn.Module):
        self.self_attn = SiglipAttention(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.self_attn",
            attn_cls=attn_cls,
        )
@@ -497,6 +505,7 @@ class SiglipEncoderLayer(nn.Module):
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -524,6 +533,7 @@ class SiglipEncoder(nn.Module):
        self,
        config: SiglipVisionConfig | SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        num_hidden_layers_override: int | None = None,
        *,
        prefix: str = "",
@@ -543,6 +553,7 @@ class SiglipEncoder(nn.Module):
                SiglipEncoderLayer(
                    config,
                    quant_config=quant_config,
                    multimodal_config=multimodal_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                    attn_cls=attn_cls,
                )
@@ -647,6 +658,7 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
@@ -658,7 +670,10 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
        )
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config=config, quant_config=quant_config, prefix=f"{prefix}.mlp"
            config=config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
@@ -683,6 +698,7 @@ class SiglipVisionTransformer(nn.Module):
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
@@ -698,6 +714,7 @@ class SiglipVisionTransformer(nn.Module):
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
            attn_cls=MMEncoderAttention,
@@ -726,6 +743,7 @@ class SiglipVisionTransformer(nn.Module):
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config,
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                prefix=f"{prefix}.head",
            )

@@ -812,13 +830,11 @@ class SiglipVisionTransformer(nn.Module):


class SiglipVisionModel(nn.Module):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        multimodal_config: MultiModalConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
@@ -829,7 +845,8 @@ class SiglipVisionModel(nn.Module):
        self.quant_config = quant_config
        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers_override,
            require_post_norm=require_post_norm,
            prefix=f"{prefix}.vision_model",
@@ -1023,6 +1040,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
        self.vision_model = SiglipVisionTransformer(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )
@@ -11,7 +11,6 @@ from torch.nn import functional as F
from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -186,7 +185,6 @@ class Siglip2Attention(nn.Module):
        multimodal_config: MultiModalConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend_override: AttentionBackendEnum | None = None,
    ):
        super().__init__()
        self.config = config
@@ -196,12 +194,11 @@ class Siglip2Attention(nn.Module):
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
                f"(got `embed_dim`: {self.embed_dim} and "
                f"`num_heads`: {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.is_causal = False

        use_data_parallel = (
            multimodal_config.mm_encoder_tp_mode == "data"
@@ -233,6 +230,7 @@ class Siglip2Attention(nn.Module):
        self.attn = MMEncoderAttention(
            num_heads=self.num_heads_per_partition,
            head_size=self.head_dim,
            scale=self.scale,
            prefix=f"{prefix}.attn",
            multimodal_config=multimodal_config,
        )
@@ -19,7 +19,7 @@ from transformers.models.llava import LlavaProcessor
from transformers.processing_utils import ProcessingKwargs, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput

from vllm.config import VllmConfig
from vllm.config import MultiModalConfig, VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -346,6 +346,7 @@ def _build_tarsier_hf_processor(
def init_vision_tower_for_tarsier(
    hf_config: TarsierHfConfig,  # Use the Tarsier specific config protocol
    quant_config: QuantizationConfig | None,
    multimodal_config: MultiModalConfig | None,
    *,
    require_post_norm: bool | None = None,
    prefix: str = "",
@@ -377,6 +378,7 @@ def init_vision_tower_for_tarsier(
        return CLIPVisionModel(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers_to_init,
            require_post_norm=require_post_norm,
            prefix=prefix,
@@ -385,6 +387,7 @@ def init_vision_tower_for_tarsier(
        return SiglipVisionModel(
            vision_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            num_hidden_layers_override=num_hidden_layers_to_init,
            require_post_norm=require_post_norm,
            prefix=prefix,
@@ -414,12 +417,16 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config: TarsierHfConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config  # Storing the Tarsier-specific HF config
        self.vision_tower = init_vision_tower_for_tarsier(
            config,
            quant_config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )