[2/N] Initialize MM components in context managers (E-H) (#32641)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-20 16:12:56 +08:00
committed by GitHub
parent 148117ea2e
commit e1a34c3a5d
12 changed files with 161 additions and 189 deletions

View File

@@ -15,7 +15,6 @@ from einops import rearrange
from timm.layers import LayerNorm, LayerNorm2d
from timm.models.regnet import RegStage
from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
from transformers.modeling_utils import no_init_weights
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
@@ -625,8 +624,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
config, vision_config
)
# init models & parameters
with no_init_weights(): # weight will be loaded in from_pretrained
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vision_model = init_vision_tower_for_hcxvision(
vision_config,
quant_config=quant_config,
@@ -635,20 +633,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.mm_projector = self._init_mm_projector(config, text_config, vision_config)
self.mm_projector = self._init_mm_projector(
config, text_config, vision_config
)
self.lm_head_vocab_size = getattr(
text_config, "padded_vocab_size", text_config.vocab_size
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
if config.anyres:
self.image_newline = nn.Parameter(
torch.empty(text_config.hidden_size, dtype=self.dtype)
)
if config.anyres:
self.image_newline = nn.Parameter(
torch.empty(text_config.hidden_size, dtype=self.dtype)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.config = config
@@ -726,9 +724,6 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(
self,
**kwargs: object,