[Models]: Make Multimodal config implicit in ViT implementation (#31972)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -11,7 +11,6 @@ from torch.nn import functional as F
|
||||
from transformers import Siglip2VisionConfig
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import MultiModalConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
|
||||
@@ -23,7 +22,7 @@ from vllm.model_executor.layers.linear import (
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from .vision import should_torch_compile_mm_vit
|
||||
from .vision import is_vit_use_data_parallel, should_torch_compile_mm_vit
|
||||
|
||||
|
||||
class Siglip2VisionEmbeddings(nn.Module):
|
||||
@@ -154,7 +153,6 @@ class Siglip2Attention(nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -171,10 +169,7 @@ class Siglip2Attention(nn.Module):
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
use_data_parallel = (
|
||||
multimodal_config is not None
|
||||
and multimodal_config.mm_encoder_tp_mode == "data"
|
||||
)
|
||||
use_data_parallel = is_vit_use_data_parallel()
|
||||
tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
|
||||
assert self.num_heads % tp_size == 0
|
||||
self.num_heads_per_partition = self.num_heads // tp_size
|
||||
@@ -199,7 +194,6 @@ class Siglip2Attention(nn.Module):
|
||||
head_size=self.head_dim,
|
||||
scale=self.scale,
|
||||
prefix=f"{prefix}.attn",
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -241,16 +235,12 @@ class Siglip2MLP(nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
use_data_parallel = (
|
||||
multimodal_config is not None
|
||||
and multimodal_config.mm_encoder_tp_mode == "data"
|
||||
)
|
||||
use_data_parallel = is_vit_use_data_parallel()
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
@@ -282,7 +272,6 @@ class Siglip2EncoderLayer(nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -291,14 +280,12 @@ class Siglip2EncoderLayer(nn.Module):
|
||||
self.self_attn = Siglip2Attention(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = Siglip2MLP(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
|
||||
@@ -344,7 +331,6 @@ class Siglip2Encoder(nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -354,7 +340,6 @@ class Siglip2Encoder(nn.Module):
|
||||
Siglip2EncoderLayer(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.layers.{idx}",
|
||||
)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
@@ -383,7 +368,6 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -397,7 +381,6 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
self.encoder = Siglip2Encoder(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.encoder",
|
||||
)
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
@@ -438,7 +421,6 @@ class Siglip2Model(torch.nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
multimodal_config: MultiModalConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -446,7 +428,6 @@ class Siglip2Model(torch.nn.Module):
|
||||
self.vision_model = Siglip2VisionTransformer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.vision_model",
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user