[Model] Standardize common vision encoders (#31947)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -38,10 +38,8 @@ from vllm.config import MultiModalConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
@@ -77,6 +75,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .ernie45 import Ernie4_5ForCausalLM
 from .interfaces import MultiModalEmbeddings, SupportsMRoPE, SupportsMultiModal
+from .siglip import SiglipMLP
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
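The import added above replaces the per-model copy deleted in the next hunk: the MLP now lives in one place and is shared across vision encoders. A minimal sketch of reusing it (the relative form is what this diff uses; the absolute path is inferred from it):

    # Relative import, as used inside vllm/model_executor/models/:
    from .siglip import SiglipMLP

    # Equivalent absolute import from elsewhere (path assumed from the above):
    from vllm.model_executor.models.siglip import SiglipMLP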
@@ -657,46 +656,6 @@ class SigLIPRotaryEmbedding(nn.Module):
         return freqs
 
 
-class SiglipMLP(nn.Module):
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        quant_config: QuantizationConfig | None = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-
-        self.config = config
-        self.activation_fn = get_act_fn(config.hidden_act)
-        # Special handling for BNB and torchao quantization
-        if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
-            quantizable = True
-        else:
-            # For other quantization, we require the hidden size to be a
-            # multiple of 64
-            quantizable = (
-                config.hidden_size % 64 == 0 and config.intermediate_size % 64 == 0
-            )
-        self.fc1 = ColumnParallelLinear(
-            config.hidden_size,
-            config.intermediate_size,
-            quant_config=quant_config if quantizable else None,
-            prefix=f"{prefix}.fc1",
-        )
-        self.fc2 = RowParallelLinear(
-            config.intermediate_size,
-            config.hidden_size,
-            quant_config=quant_config if quantizable else None,
-            prefix=f"{prefix}.fc2",
-        )
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states, _ = self.fc1(hidden_states)
-        hidden_states = self.activation_fn(hidden_states)
-        hidden_states, _ = self.fc2(hidden_states)
-        return hidden_states
-
-
 class SiglipEncoderLayer(nn.Module):
     def __init__(
         self,
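For reference, the deleted block gates quantization of fc1/fc2 on the backend: bitsandbytes and torchao can quantize weights of any shape, while other backends are applied only when both hidden_size and intermediate_size are multiples of 64. A standalone sketch of that gate (the helper name is hypothetical; quant_config stands in for vLLM's QuantizationConfig):

    def mlp_quantizable(quant_config, hidden_size: int, intermediate_size: int) -> bool:
        # bitsandbytes and torchao handle arbitrary weight shapes.
        if quant_config and quant_config.get_name() in ["bitsandbytes", "torchao"]:
            return True
        # Other backends require both projection dims to be multiples of 64.
        return hidden_size % 64 == 0 and intermediate_size % 64 == 0

With SigLIP-so400m-style sizes (hidden_size=1152, intermediate_size=4304), 4304 % 64 != 0, so fc1 and fc2 would be built with quant_config=None and stay unquantized under such backends.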
@@ -720,6 +679,7 @@ class SiglipEncoderLayer(nn.Module):
         self.mlp = SiglipMLP(
             config,
             quant_config=quant_config,
+            multimodal_config=multimodal_config,
             prefix=f"{prefix}.mlp",
         )
 
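Functionally, the consolidated module is the standard SigLIP feed-forward block visible in the deleted code: fc1 expands hidden_size to intermediate_size, the configured activation is applied (gelu_pytorch_tanh in SigLIP checkpoints), and fc2 projects back down. A plain-PyTorch sketch of that data flow, omitting the tensor parallelism and quantization handling of the vLLM class:

    import torch
    import torch.nn as nn

    class SiglipMLPSketch(nn.Module):
        # Illustrative stand-in for the shared vLLM SiglipMLP: fc1 -> act -> fc2.
        def __init__(self, hidden_size: int, intermediate_size: int) -> None:
            super().__init__()
            self.fc1 = nn.Linear(hidden_size, intermediate_size)
            self.act = nn.GELU(approximate="tanh")  # hidden_act="gelu_pytorch_tanh"
            self.fc2 = nn.Linear(intermediate_size, hidden_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.fc2(self.act(self.fc1(x)))

    x = torch.randn(1, 16, 768)
    print(SiglipMLPSketch(768, 3072)(x).shape)  # torch.Size([1, 16, 768])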