[Models]: Make Multimodal config implicit in ViT implementation (#31972)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -16,7 +16,7 @@ from transformers import (
|
||||
)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention.encoder_only_attention import (
|
||||
@@ -64,6 +64,7 @@ from .vision import (
|
||||
VisionFeatureSelectStrategy,
|
||||
VisionFeatureSelectStrategyStr,
|
||||
get_num_selected_vision_tokens,
|
||||
is_vit_use_data_parallel,
|
||||
resolve_visual_encoder_outputs,
|
||||
)
|
||||
|
||||
@@ -356,7 +357,6 @@ class SiglipAttention(nn.Module):
|
||||
self,
|
||||
config: SiglipVisionConfig | SiglipTextConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
|
||||
@@ -376,11 +376,7 @@ class SiglipAttention(nn.Module):
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
use_data_parallel = (
|
||||
multimodal_config.mm_encoder_tp_mode == "data"
|
||||
if multimodal_config
|
||||
else False
|
||||
)
|
||||
use_data_parallel = is_vit_use_data_parallel()
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=self.embed_dim,
|
||||
head_size=self.head_dim,
|
||||
@@ -409,7 +405,6 @@ class SiglipAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.scale,
|
||||
prefix=f"{prefix}.attn",
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
else:
|
||||
self.attn = attn_cls(
|
||||
@@ -437,17 +432,12 @@ class SiglipMLP(nn.Module):
|
||||
self,
|
||||
config: SiglipVisionConfig | SiglipTextConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
use_data_parallel = (
|
||||
multimodal_config.mm_encoder_tp_mode == "data"
|
||||
if multimodal_config
|
||||
else False
|
||||
)
|
||||
use_data_parallel = is_vit_use_data_parallel()
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
|
||||
# Special handling for BNB and torchao quantization
|
||||
@@ -487,7 +477,6 @@ class SiglipEncoderLayer(nn.Module):
|
||||
self,
|
||||
config: SiglipVisionConfig | SiglipTextConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
|
||||
@@ -499,7 +488,6 @@ class SiglipEncoderLayer(nn.Module):
|
||||
self.self_attn = SiglipAttention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_cls=attn_cls,
|
||||
)
|
||||
@@ -507,7 +495,6 @@ class SiglipEncoderLayer(nn.Module):
|
||||
self.mlp = SiglipMLP(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
@@ -535,7 +522,6 @@ class SiglipEncoder(nn.Module):
|
||||
self,
|
||||
config: SiglipVisionConfig | SiglipTextConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
num_hidden_layers_override: int | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
@@ -555,7 +541,6 @@ class SiglipEncoder(nn.Module):
|
||||
SiglipEncoderLayer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.layers.{layer_idx}",
|
||||
attn_cls=attn_cls,
|
||||
)
|
||||
@@ -660,7 +645,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -674,7 +658,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
|
||||
self.mlp = SiglipMLP(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
|
||||
@@ -700,7 +683,6 @@ class SiglipVisionTransformer(nn.Module):
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
*,
|
||||
num_hidden_layers_override: int | None = None,
|
||||
require_post_norm: bool | None = None,
|
||||
@@ -717,7 +699,6 @@ class SiglipVisionTransformer(nn.Module):
|
||||
self.encoder = SiglipEncoder(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
prefix=f"{prefix}.encoder",
|
||||
attn_cls=MMEncoderAttention,
|
||||
@@ -756,7 +737,6 @@ class SiglipVisionTransformer(nn.Module):
|
||||
SiglipMultiheadAttentionPoolingHead(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.head",
|
||||
)
|
||||
if self.use_head
|
||||
@@ -870,7 +850,6 @@ class SiglipVisionModel(nn.Module):
|
||||
self,
|
||||
config: SiglipVisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
*,
|
||||
num_hidden_layers_override: int | None = None,
|
||||
require_post_norm: bool | None = None,
|
||||
@@ -883,7 +862,6 @@ class SiglipVisionModel(nn.Module):
|
||||
self.vision_model = SiglipVisionTransformer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
require_post_norm=require_post_norm,
|
||||
prefix=f"{prefix}.vision_model",
|
||||
@@ -1062,9 +1040,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
|
||||
config: SiglipConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
if hasattr(config, "num_labels"):
|
||||
config.num_labels = 0
|
||||
@@ -1087,7 +1063,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
self.vision_model = SiglipVisionTransformer(
|
||||
vision_config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=maybe_prefix(prefix, "vision_model"),
|
||||
use_head=None, # Allows potential pooling head
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user