[Model] Standardize common vision encoders (#31947)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-08 18:33:16 +08:00
committed by GitHub
parent d1b6fe007f
commit 5576227bc1
19 changed files with 253 additions and 173 deletions

View File

@@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from itertools import accumulate
from typing import Annotated, Any, Literal
from typing import Annotated, Literal
import numpy as np
import torch
@@ -18,7 +18,7 @@ from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
from transformers.modeling_utils import no_init_weights
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
@@ -361,6 +361,7 @@ def _build_hcxvision_hf_processor(
def init_vision_tower_for_hcxvision(
vision_config,
quant_config: QuantizationConfig | None,
multimodal_config: MultiModalConfig | None,
*,
use_nth_layer: int | None = None,
require_post_norm: bool | None = None,
@@ -378,6 +379,7 @@ def init_vision_tower_for_hcxvision(
return CLIPVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
@@ -386,6 +388,7 @@ def init_vision_tower_for_hcxvision(
return SiglipVisionModel(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
prefix=prefix,
@@ -597,18 +600,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
**kwargs: Any | None,
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
# init configs
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
# text_config
text_config = config.text_config
if text_config.model_type in ["gpt2", "hyperclovax", "llama"]:
@@ -631,7 +629,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
with no_init_weights(): # weight will be loaded in from_pretrained
self.vision_model = init_vision_tower_for_hcxvision(
vision_config,
quant_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
use_nth_layer=getattr(config, "use_nth_layer", -1),
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_model"),