[Models]: Make Multimodal config implicit in ViT implementation (#31972)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py
2026-01-24 20:34:26 +08:00
committed by GitHub
parent 6450b536a6
commit 9ad7f89f55
38 changed files with 118 additions and 470 deletions
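
This commit drops the explicit multimodal_config plumbing from the SigLIP ViT stack: instead of each layer deriving data parallelism from multimodal_config.mm_encoder_tp_mode == "data", layers now call is_vit_use_data_parallel(), imported from .vision. Below is a minimal sketch of what such a helper could look like, assuming the running config is reachable via get_current_vllm_config(); apart from the helper name and the mm_encoder_tp_mode field (both taken from the diff), the details are illustrative rather than the actual vision.py implementation.

# Hedged sketch, not the real vllm/model_executor/models/vision.py code.
from vllm.config import get_current_vllm_config


def is_vit_use_data_parallel() -> bool:
    """True when the multimodal encoder should run in data-parallel mode."""
    mm_config = get_current_vllm_config().model_config.multimodal_config
    # Mirrors the old per-layer check that the diff below removes.
    return mm_config is not None and mm_config.mm_encoder_tp_mode == "data"
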


@@ -16,7 +16,7 @@ from transformers import (
)
from vllm.config import VllmConfig
-from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
+from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention.encoder_only_attention import (
@@ -64,6 +64,7 @@ from .vision import (
VisionFeatureSelectStrategy,
VisionFeatureSelectStrategyStr,
get_num_selected_vision_tokens,
+is_vit_use_data_parallel,
resolve_visual_encoder_outputs,
)
@@ -356,7 +357,6 @@ class SiglipAttention(nn.Module):
self,
config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None,
-multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
@@ -376,11 +376,7 @@ class SiglipAttention(nn.Module):
self.scale = self.head_dim**-0.5
-use_data_parallel = (
-    multimodal_config.mm_encoder_tp_mode == "data"
-    if multimodal_config
-    else False
-)
+use_data_parallel = is_vit_use_data_parallel()
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
@@ -409,7 +405,6 @@ class SiglipAttention(nn.Module):
self.head_dim,
self.scale,
prefix=f"{prefix}.attn",
-multimodal_config=multimodal_config,
)
else:
self.attn = attn_cls(
@@ -437,17 +432,12 @@ class SiglipMLP(nn.Module):
self,
config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None,
-multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
-use_data_parallel = (
-    multimodal_config.mm_encoder_tp_mode == "data"
-    if multimodal_config
-    else False
-)
+use_data_parallel = is_vit_use_data_parallel()
self.activation_fn = get_act_fn(config.hidden_act)
# Special handling for BNB and torchao quantization
@@ -487,7 +477,6 @@ class SiglipEncoderLayer(nn.Module):
self,
config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None,
-multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
@@ -499,7 +488,6 @@ class SiglipEncoderLayer(nn.Module):
self.self_attn = SiglipAttention(
config,
quant_config=quant_config,
-multimodal_config=multimodal_config,
prefix=f"{prefix}.self_attn",
attn_cls=attn_cls,
)
@@ -507,7 +495,6 @@ class SiglipEncoderLayer(nn.Module):
self.mlp = SiglipMLP(
config,
quant_config=quant_config,
-multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
@@ -535,7 +522,6 @@ class SiglipEncoder(nn.Module):
self,
config: SiglipVisionConfig | SiglipTextConfig,
quant_config: QuantizationConfig | None = None,
-multimodal_config: MultiModalConfig | None = None,
num_hidden_layers_override: int | None = None,
*,
prefix: str = "",
@@ -555,7 +541,6 @@ class SiglipEncoder(nn.Module):
SiglipEncoderLayer(
config,
quant_config=quant_config,
-multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
attn_cls=attn_cls,
)
@@ -660,7 +645,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None,
-multimodal_config: MultiModalConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -674,7 +658,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
self.mlp = SiglipMLP(
config=config,
quant_config=quant_config,
-multimodal_config=multimodal_config,
prefix=f"{prefix}.mlp",
)
@@ -700,7 +683,6 @@ class SiglipVisionTransformer(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None,
-multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
@@ -717,7 +699,6 @@ class SiglipVisionTransformer(nn.Module):
self.encoder = SiglipEncoder(
config,
quant_config=quant_config,
-multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
attn_cls=MMEncoderAttention,
@@ -756,7 +737,6 @@ class SiglipVisionTransformer(nn.Module):
SiglipMultiheadAttentionPoolingHead(
config=config,
quant_config=quant_config,
-multimodal_config=multimodal_config,
prefix=f"{prefix}.head",
)
if self.use_head
@@ -870,7 +850,6 @@ class SiglipVisionModel(nn.Module):
self,
config: SiglipVisionConfig,
quant_config: QuantizationConfig | None = None,
-multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
@@ -883,7 +862,6 @@ class SiglipVisionModel(nn.Module):
self.vision_model = SiglipVisionTransformer(
config,
quant_config=quant_config,
-multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
require_post_norm=require_post_norm,
prefix=f"{prefix}.vision_model",
@@ -1062,9 +1040,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
config: SiglipConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
-multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
-self.multimodal_config = multimodal_config
if hasattr(config, "num_labels"):
config.num_labels = 0
@@ -1087,7 +1063,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.vision_model = SiglipVisionTransformer(
vision_config,
quant_config=quant_config,
-multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
use_head=None, # Allows potential pooling head
)
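
Net effect for callers: the SigLIP vision tower is now constructed without threading multimodal_config through at all. A hedged usage sketch under that assumption follows; hf_vision_config and quant_config are placeholder names for values the surrounding model code already holds, not identifiers from this diff.

# Illustrative calling convention only; not copied from a specific call site.
vision_model = SiglipVisionModel(
    hf_vision_config,                 # SiglipVisionConfig
    quant_config=quant_config,
    num_hidden_layers_override=None,  # keep every encoder layer
    require_post_norm=None,           # default post-norm behaviour
    prefix="vision_model",
)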