[Model] Standardize common vision encoders (#31947)

commit 5576227bc1
parent d1b6fe007f
Author: Cyrus Leung
Date: 2026-01-08 18:33:16 +08:00
Committed by: GitHub
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

19 changed files with 253 additions and 173 deletions

@@ -38,10 +38,8 @@ from vllm.config import MultiModalConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
-from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.linear import (
-    ColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
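
The two dropped imports, get_act_fn and ColumnParallelLinear, appear to lose their last users in this file once the local SiglipMLP below is deleted, while QKVParallelLinear and RowParallelLinear stay behind for the attention layers. As background on why an MLP pairs a column-parallel fc1 with a row-parallel fc2, here is a minimal single-process sketch in plain torch (illustrative sizes, no vLLM imports): sharding fc1 along its output dim and fc2 along its input dim reproduces the serial result with a single reduction at the end.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, intermediate = 8, 16
x = torch.randn(4, hidden)
w1 = torch.randn(intermediate, hidden)  # fc1 weight (column-parallel in vLLM)
w2 = torch.randn(hidden, intermediate)  # fc2 weight (row-parallel in vLLM)

# Serial reference: fc2(act(fc1(x))).
serial = F.gelu(x @ w1.t()) @ w2.t()

# Two-"rank" simulation: each shard computes a partial output; summing the
# partials plays the role of the final all-reduce.
partials = [
    F.gelu(x @ w1_shard.t()) @ w2_shard.t()
    for w1_shard, w2_shard in zip(w1.chunk(2, dim=0), w2.chunk(2, dim=1))
]
print(torch.allclose(serial, sum(partials), atol=1e-5))  # True

The equivalence holds because the activation is elementwise, so it commutes with the column split of fc1's output.
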
@@ -77,6 +75,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .ernie45 import Ernie4_5ForCausalLM
 from .interfaces import MultiModalEmbeddings, SupportsMRoPE, SupportsMultiModal
+from .siglip import SiglipMLP
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -657,46 +656,6 @@ class SigLIPRotaryEmbedding(nn.Module):
         return freqs
 
 
-class SiglipMLP(nn.Module):
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        quant_config: QuantizationConfig | None = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.activation_fn = get_act_fn(config.hidden_act)
-
-        # Special handling for BNB and torchao quantization
-        if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
-            quantizable = True
-        else:
-            # For other quantization, we require the hidden size to be a
-            # multiple of 64
-            quantizable = (
-                config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
-            )
-        self.fc1 = ColumnParallelLinear(
-            config.hidden_size,
-            config.intermediate_size,
-            quant_config=quant_config if quantizable else None,
-            prefix=f"{prefix}.fc1",
-        )
-        self.fc2 = RowParallelLinear(
-            config.intermediate_size,
-            config.hidden_size,
-            quant_config=quant_config if quantizable else None,
-            prefix=f"{prefix}.fc2",
-        )
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states, _ = self.fc1(hidden_states)
-        hidden_states = self.activation_fn(hidden_states)
-        hidden_states, _ = self.fc2(hidden_states)
-        return hidden_states
-
-
 class SiglipEncoderLayer(nn.Module):
     def __init__(
         self,
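
The deleted class carried one behavior worth keeping in mind when reviewing the shared replacement: quant_config is forwarded to fc1/fc2 only when the weights are considered quantizable. A standalone sketch of that predicate (mlp_is_quantizable is a hypothetical helper name; the sample sizes are SigLIP-so400m-like and purely illustrative):

def mlp_is_quantizable(
    quant_name: str | None, hidden_size: int, intermediate_size: int
) -> bool:
    # bitsandbytes and torchao quantize layer by layer, so no shape
    # constraint is needed.
    if quant_name in ("bitsandbytes", "torchao"):
        return True
    # Other methods require both dimensions to be multiples of 64.
    return hidden_size % 64 == 0 and intermediate_size % 64 == 0

print(mlp_is_quantizable("bitsandbytes", 1152, 4304))  # True
print(mlp_is_quantizable("fp8", 1152, 4304))           # False: 4304 % 64 == 16

Note that a non-quantizable MLP silently falls back to unquantized linears (quant_config=None) rather than raising.
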
@@ -720,6 +679,7 @@ class SiglipEncoderLayer(nn.Module):
         self.mlp = SiglipMLP(
             config,
             quant_config=quant_config,
+            multimodal_config=multimodal_config,
             prefix=f"{prefix}.mlp",
         )
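
For reference, the dataflow of SiglipMLP is unchanged by this refactor: fc1, then the activation, then fc2. A plain-PyTorch sketch with the tensor-parallel linears swapped for nn.Linear (SiglipMLPSketch is a hypothetical name; gelu-tanh stands in for get_act_fn(config.hidden_act), and the sizes are illustrative). The multimodal_config keyword threaded through above is a constructor argument of the shared implementation and is not modeled here.

import torch
import torch.nn as nn

class SiglipMLPSketch(nn.Module):
    def __init__(self, hidden_size: int = 1152, intermediate_size: int = 4304):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)  # ColumnParallelLinear in vLLM
        self.act = nn.GELU(approximate="tanh")                # get_act_fn(config.hidden_act)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)  # RowParallelLinear in vLLM

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))

x = torch.randn(2, 16, 1152)
print(SiglipMLPSketch()(x).shape)  # torch.Size([2, 16, 1152])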