[Models]: Make Multimodal config implicit in ViT implementation (#31972)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@@ -28,7 +28,7 @@ from transformers.models.pixtral.modeling_pixtral import (
 from transformers.tokenization_utils_base import TextInput
 
 from vllm.config import VllmConfig
-from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
+from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.conv import Conv2dLayer
@@ -74,6 +74,7 @@ from .utils import init_vllm_registered_model, maybe_prefix
 from .vision import (
     VisionEncoderInfo,
     VisionFeatureSelectStrategy,
+    is_vit_use_data_parallel,
     resolve_visual_encoder_outputs,
 )
 
@@ -1065,17 +1066,12 @@ class PixtralHFMLP(nn.Module):
         self,
         config: PixtralVisionConfig,
         quant_config: QuantizationConfig | None = None,
-        multimodal_config: MultiModalConfig | None = None,
         *,
         prefix: str = "",
     ) -> None:
         super().__init__()
 
-        use_data_parallel = (
-            multimodal_config.mm_encoder_tp_mode == "data"
-            if multimodal_config
-            else False
-        )
+        use_data_parallel = is_vit_use_data_parallel()
 
         assert config.intermediate_size is not None
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -1108,7 +1104,6 @@ class PixtralHFAttention(nn.Module):
         self,
         config: PixtralVisionConfig,
         quant_config: QuantizationConfig | None = None,
-        multimodal_config: MultiModalConfig | None = None,
         *,
         prefix: str = "",
     ) -> None:
@@ -1120,11 +1115,7 @@ class PixtralHFAttention(nn.Module):
         self.head_dim = config.hidden_size // config.num_attention_heads
         assert self.total_num_heads * self.head_dim == config.hidden_size
 
-        use_data_parallel = (
-            multimodal_config.mm_encoder_tp_mode == "data"
-            if multimodal_config
-            else False
-        )
+        use_data_parallel = is_vit_use_data_parallel()
         self.qkv_proj = QKVParallelLinear(
             hidden_size=config.hidden_size,
             head_size=self.head_dim,
@@ -1189,7 +1180,6 @@ class PixtralHFTransformerBlock(nn.Module):
         self,
         config: PixtralVisionConfig,
         quant_config: QuantizationConfig | None = None,
-        multimodal_config: MultiModalConfig | None = None,
         *,
         prefix: str = "",
     ) -> None:
@@ -1199,13 +1189,11 @@ class PixtralHFTransformerBlock(nn.Module):
         self.attention = PixtralHFAttention(
             config,
             quant_config=quant_config,
-            multimodal_config=multimodal_config,
             prefix=f"{prefix}.attention",
         )
         self.feed_forward = PixtralHFMLP(
             config,
             quant_config=quant_config,
-            multimodal_config=multimodal_config,
             prefix=f"{prefix}.feed_forward",
         )
         self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -1232,7 +1220,6 @@ class PixtralHFTransformer(nn.Module):
         self,
         config: PixtralVisionConfig,
         quant_config: QuantizationConfig | None = None,
-        multimodal_config: MultiModalConfig | None = None,
         *,
         num_hidden_layers_override: int | None = None,
         prefix: str = "",
@@ -1249,7 +1236,6 @@ class PixtralHFTransformer(nn.Module):
                 PixtralHFTransformerBlock(
                     config=config,
                     quant_config=quant_config,
-                    multimodal_config=multimodal_config,
                     prefix=f"{prefix}.layers.{layer_idx}",
                 )
                 for layer_idx in range(num_hidden_layers)
@@ -1281,7 +1267,6 @@ class PixtralHFVisionModel(nn.Module):
         self,
         config: PixtralVisionConfig,
         quant_config: QuantizationConfig | None = None,
-        multimodal_config: MultiModalConfig | None = None,
         *,
         num_hidden_layers_override: int | None = None,
         require_post_norm: bool | None = None,
@@ -1302,7 +1287,6 @@ class PixtralHFVisionModel(nn.Module):
         self.transformer = PixtralHFTransformer(
             config,
             quant_config=quant_config,
-            multimodal_config=multimodal_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.transformer",
         )
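For context: the diff stops threading MultiModalConfig through every Pixtral ViT constructor and instead queries the is_vit_use_data_parallel() helper imported from .vision. A minimal sketch of how such an implicit flag could be wired up, assuming a module-level toggle; the names _VIT_USE_DATA_PARALLEL and set_vit_data_parallel below are hypothetical illustrations, not the actual vLLM implementation:

# Hypothetical sketch of an "implicit" multimodal setting: a module-level
# flag that ViT layers read instead of receiving multimodal_config explicitly.
from contextlib import contextmanager

_VIT_USE_DATA_PARALLEL = False  # hypothetical module-level state


def is_vit_use_data_parallel() -> bool:
    # Layers such as PixtralHFMLP / PixtralHFAttention call this instead of
    # computing multimodal_config.mm_encoder_tp_mode == "data" themselves.
    return _VIT_USE_DATA_PARALLEL


@contextmanager
def set_vit_data_parallel(enabled: bool):
    # Hypothetical setter: the caller that knows mm_encoder_tp_mode == "data"
    # flips the flag once; nested use restores the previous value on exit.
    global _VIT_USE_DATA_PARALLEL
    prev = _VIT_USE_DATA_PARALLEL
    _VIT_USE_DATA_PARALLEL = enabled
    try:
        yield
    finally:
        _VIT_USE_DATA_PARALLEL = prev

With a helper along these lines, the data-parallel decision is made once at model construction time, which is why the diff can delete the multimodal_config parameter from every layer signature and call site.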