[Model] Standardize common vision encoders (#31947)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -28,7 +28,7 @@ from transformers.models.pixtral.modeling_pixtral import (
 from transformers.tokenization_utils_base import TextInput
 
 from vllm.config import VllmConfig
-from vllm.config.multimodal import BaseDummyOptions
+from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.conv import Conv2dLayer
@@ -1043,11 +1043,18 @@ class PixtralHFMLP(nn.Module):
         self,
         config: PixtralVisionConfig,
         quant_config: QuantizationConfig | None = None,
+        multimodal_config: MultiModalConfig | None = None,
         *,
         prefix: str = "",
     ) -> None:
         super().__init__()
 
+        use_data_parallel = (
+            multimodal_config.mm_encoder_tp_mode == "data"
+            if multimodal_config
+            else False
+        )
+
         assert config.intermediate_size is not None
         self.gate_up_proj = MergedColumnParallelLinear(
             input_size=config.hidden_size,
@@ -1055,6 +1062,7 @@ class PixtralHFMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel,
         )
         self.down_proj = RowParallelLinear(
             input_size=config.intermediate_size,
@@ -1062,6 +1070,7 @@ class PixtralHFMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
+            disable_tp=use_data_parallel,
         )
         self.act_and_mul = get_act_and_mul_fn(config.hidden_act)
 
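The pattern above repeats in each encoder submodule: the optional `multimodal_config` is reduced to a single `use_data_parallel` flag, which is then forwarded to the parallel linear layers as `disable_tp`. A minimal standalone sketch of that derivation, with `ToyMultiModalConfig` as a stand-in for vLLM's `MultiModalConfig` (not the real class):

```python
from dataclasses import dataclass


@dataclass
class ToyMultiModalConfig:
    # "weights" shards encoder weights across TP ranks (tensor parallel);
    # "data" replicates the weights and splits the image batch instead.
    mm_encoder_tp_mode: str = "weights"


def resolve_use_data_parallel(
    multimodal_config: ToyMultiModalConfig | None,
) -> bool:
    # Mirrors the diff: the config is optional, so callers that never
    # pass it keep the old tensor-parallel behavior.
    return (
        multimodal_config.mm_encoder_tp_mode == "data"
        if multimodal_config
        else False
    )


assert resolve_use_data_parallel(None) is False
assert resolve_use_data_parallel(ToyMultiModalConfig("data")) is True
assert resolve_use_data_parallel(ToyMultiModalConfig("weights")) is False
```

Keeping the parameter optional (defaulting to `None`, hence tensor-parallel behavior) is what lets the standardized signature roll out across encoders without breaking existing call sites.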
@@ -1077,6 +1086,7 @@ class PixtralHFAttention(nn.Module):
         self,
         config: PixtralVisionConfig,
         quant_config: QuantizationConfig | None = None,
+        multimodal_config: MultiModalConfig | None = None,
         *,
         prefix: str = "",
     ) -> None:
@@ -1085,10 +1095,14 @@ class PixtralHFAttention(nn.Module):
         self.config = config
         assert not config.hidden_size % config.num_attention_heads
         self.total_num_heads = config.num_attention_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        self.n_heads = divide(config.num_attention_heads, tp_size)
         self.head_dim = config.hidden_size // config.num_attention_heads
+        assert self.total_num_heads * self.head_dim == config.hidden_size
 
+        use_data_parallel = (
+            multimodal_config.mm_encoder_tp_mode == "data"
+            if multimodal_config
+            else False
+        )
         self.qkv_proj = QKVParallelLinear(
             hidden_size=config.hidden_size,
             head_size=self.head_dim,
@@ -1096,16 +1110,22 @@ class PixtralHFAttention(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
+            disable_tp=use_data_parallel,
         )
-        assert self.total_num_heads * self.head_dim == config.hidden_size
         self.o_proj = RowParallelLinear(
             input_size=config.hidden_size,
             output_size=config.hidden_size,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
+            disable_tp=use_data_parallel,
         )
 
+        self.tp_size = (
+            1 if use_data_parallel else get_tensor_model_parallel_world_size()
+        )
+        self.n_heads = divide(config.num_attention_heads, self.tp_size)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
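The head bookkeeping above is the main behavioral change in this class: with `disable_tp=use_data_parallel` set, the effective `tp_size` is forced to 1, so every rank keeps the full set of attention heads and parallelism comes from splitting the image batch instead. A small arithmetic sketch, with `divide` re-implemented here to stand in for vLLM's helper (which likewise asserts even divisibility):

```python
def divide(numerator: int, denominator: int) -> int:
    # Stand-in for vllm.distributed.divide: integer division that
    # refuses uneven splits.
    assert numerator % denominator == 0
    return numerator // denominator


num_attention_heads = 16
world_size = 4  # tensor-parallel world size

# Tensor-parallel mode: heads are sharded across ranks.
tp_heads = divide(num_attention_heads, world_size)
assert tp_heads == 4

# Data-parallel mode: tp_size is forced to 1, so each rank keeps all
# heads and instead processes a disjoint slice of the image batch.
dp_heads = divide(num_attention_heads, 1)
assert dp_heads == 16
```

Computing `self.tp_size` and `self.n_heads` after the projections are built (rather than from the raw world size up front, as before) keeps the per-rank head count consistent with however the linear layers were actually sharded.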
@@ -1147,6 +1167,7 @@ class PixtralHFTransformerBlock(nn.Module):
         self,
         config: PixtralVisionConfig,
         quant_config: QuantizationConfig | None = None,
+        multimodal_config: MultiModalConfig | None = None,
         *,
         prefix: str = "",
     ) -> None:
@@ -1154,10 +1175,16 @@ class PixtralHFTransformerBlock(nn.Module):
 
         self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
         self.attention = PixtralHFAttention(
-            config, quant_config=quant_config, prefix=f"{prefix}.attention"
+            config,
+            quant_config=quant_config,
+            multimodal_config=multimodal_config,
+            prefix=f"{prefix}.attention",
         )
         self.feed_forward = PixtralHFMLP(
-            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+            config,
+            quant_config=quant_config,
+            multimodal_config=multimodal_config,
+            prefix=f"{prefix}.feed_forward",
         )
         self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
 
@@ -1183,6 +1210,7 @@ class PixtralHFTransformer(nn.Module):
         self,
         config: PixtralVisionConfig,
         quant_config: QuantizationConfig | None = None,
+        multimodal_config: MultiModalConfig | None = None,
         *,
         num_hidden_layers_override: int | None = None,
         prefix: str = "",
@@ -1199,6 +1227,7 @@ class PixtralHFTransformer(nn.Module):
                 PixtralHFTransformerBlock(
                     config=config,
                     quant_config=quant_config,
+                    multimodal_config=multimodal_config,
                     prefix=f"{prefix}.layers.{layer_idx}",
                 )
                 for layer_idx in range(num_hidden_layers)
@@ -1230,6 +1259,7 @@ class PixtralHFVisionModel(nn.Module):
         self,
         config: PixtralVisionConfig,
         quant_config: QuantizationConfig | None = None,
+        multimodal_config: MultiModalConfig | None = None,
         *,
         num_hidden_layers_override: int | None = None,
         require_post_norm: bool | None = None,
@@ -1249,7 +1279,8 @@ class PixtralHFVisionModel(nn.Module):
         self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
         self.transformer = PixtralHFTransformer(
             config,
-            quant_config,
+            quant_config=quant_config,
+            multimodal_config=multimodal_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.transformer",
         )
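For reference, a hypothetical user-facing sketch (not part of this diff): assuming `mm_encoder_tp_mode` is exposed through the engine arguments as the `MultiModalConfig` import suggests, enabling data-parallel vision encoding would look roughly like this. The model name is an example checkpoint, not one named by this commit.

```python
from vllm import LLM

llm = LLM(
    model="mistral-community/pixtral-12b",  # example Pixtral checkpoint
    tensor_parallel_size=4,
    # "data" replicates the vision encoder on every rank and splits the
    # image batch across them; the default keeps the encoder
    # tensor-parallel alongside the language model.
    mm_encoder_tp_mode="data",
)
```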