[Model] Standardize common vision encoders (#31947)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-08 18:33:16 +08:00
committed by GitHub
parent d1b6fe007f
commit 5576227bc1
19 changed files with 253 additions and 173 deletions

View File

@@ -28,7 +28,7 @@ from transformers.models.pixtral.modeling_pixtral import (
from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv2dLayer
@@ -1043,11 +1043,18 @@ class PixtralHFMLP(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
) -> None:
super().__init__()
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
assert config.intermediate_size is not None
self.gate_up_proj = MergedColumnParallelLinear(
input_size=config.hidden_size,
@@ -1055,6 +1062,7 @@ class PixtralHFMLP(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
disable_tp=use_data_parallel,
)
self.down_proj = RowParallelLinear(
input_size=config.intermediate_size,
@@ -1062,6 +1070,7 @@ class PixtralHFMLP(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
disable_tp=use_data_parallel,
)
self.act_and_mul = get_act_and_mul_fn(config.hidden_act)
@@ -1077,6 +1086,7 @@ class PixtralHFAttention(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
) -> None:
@@ -1085,10 +1095,14 @@ class PixtralHFAttention(nn.Module):
self.config = config
assert not config.hidden_size % config.num_attention_heads
self.total_num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
self.n_heads = divide(config.num_attention_heads, tp_size)
self.head_dim = config.hidden_size // config.num_attention_heads
assert self.total_num_heads * self.head_dim == config.hidden_size
use_data_parallel = (
multimodal_config.mm_encoder_tp_mode == "data"
if multimodal_config
else False
)
self.qkv_proj = QKVParallelLinear(
hidden_size=config.hidden_size,
head_size=self.head_dim,
@@ -1096,16 +1110,22 @@ class PixtralHFAttention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
disable_tp=use_data_parallel,
)
assert self.total_num_heads * self.head_dim == config.hidden_size
self.o_proj = RowParallelLinear(
input_size=config.hidden_size,
output_size=config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
disable_tp=use_data_parallel,
)
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
self.n_heads = divide(config.num_attention_heads, self.tp_size)
def forward(
self,
hidden_states: torch.Tensor,
@@ -1147,6 +1167,7 @@ class PixtralHFTransformerBlock(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
prefix: str = "",
) -> None:
@@ -1154,10 +1175,16 @@ class PixtralHFTransformerBlock(nn.Module):
self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
self.attention = PixtralHFAttention(
config, quant_config=quant_config, prefix=f"{prefix}.attention"
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.attention",
)
self.feed_forward = PixtralHFMLP(
config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.feed_forward",
)
self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -1183,6 +1210,7 @@ class PixtralHFTransformer(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
prefix: str = "",
@@ -1199,6 +1227,7 @@ class PixtralHFTransformer(nn.Module):
PixtralHFTransformerBlock(
config=config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=f"{prefix}.layers.{layer_idx}",
)
for layer_idx in range(num_hidden_layers)
@@ -1230,6 +1259,7 @@ class PixtralHFVisionModel(nn.Module):
self,
config: PixtralVisionConfig,
quant_config: QuantizationConfig | None = None,
multimodal_config: MultiModalConfig | None = None,
*,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
@@ -1249,7 +1279,8 @@ class PixtralHFVisionModel(nn.Module):
self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
self.transformer = PixtralHFTransformer(
config,
quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.transformer",
)