[4/N] Initialize MM components in context managers (M-P) (#32663)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-01-20 22:06:32 +08:00
Committed by: GitHub
Parent: bb9172030e
Commit: fda3f03eb2
24 changed files with 290 additions and 353 deletions

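For context on the pattern this series introduces: each multimodal component is created inside a context manager that records what was built, so shared infrastructure can later locate (or skip) the vision tower, projector, and language model without per-model boilerplate. Below is a minimal, hypothetical sketch of such a marker; the real `_mark_tower_model` / `_mark_language_model` helpers in vLLM almost certainly differ, and `MarkedInitMixin`, `_mark_component`, and `ToyVLM` are invented names for illustration only.

from contextlib import contextmanager

import torch
import torch.nn as nn


class MarkedInitMixin:
    """Sketch: records which submodules/parameters a marked block creates."""

    def _init_marker_state(self) -> None:
        self._mm_components: dict[str, list[str]] = {}

    @contextmanager
    def _mark_component(self, component: str):
        # nn.Module registers children in _modules/_parameters, so diffing
        # those registries around the block reveals what the block created.
        before = set(self._modules) | set(self._parameters)
        try:
            yield
        finally:
            created = (set(self._modules) | set(self._parameters)) - before
            self._mm_components.setdefault(component, []).extend(sorted(created))


class ToyVLM(MarkedInitMixin, nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._init_marker_state()
        with self._mark_component("image"):  # stand-in for _mark_tower_model
            self.vision_tower = nn.Linear(16, 32)
            self.multi_modal_projector = nn.Linear(32, 64)
            self.image_newline = nn.Parameter(torch.empty(64))
        with self._mark_component("language"):  # stand-in for _mark_language_model
            self.language_model = nn.Linear(64, 64)


print(ToyVLM()._mm_components)
# {'image': ['image_newline', 'multi_modal_projector', 'vision_tower'],
#  'language': ['language_model']}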

@@ -201,28 +201,33 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
         self.config = config
         self.multimodal_config = multimodal_config
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = init_vision_tower_for_llava(
-            config,
-            quant_config=quant_config,
-            multimodal_config=multimodal_config,
-            require_post_norm=False,
-            prefix=maybe_prefix(prefix, "vision_tower"),
-        )
-        self.multi_modal_projector = MiniMaxVL01MultiModalProjector(
-            vision_hidden_size=config.vision_config.hidden_size,
-            text_hidden_size=config.text_config.hidden_size,
-            projector_hidden_act=config.projector_hidden_act,
-            multimodal_projector_bias=True,
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "multi_modal_projector"),
-        )
-        self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size))
-        self.language_model = init_vllm_registered_model(
-            vllm_config=vllm_config,
-            hf_config=config.text_config,
-            prefix=maybe_prefix(prefix, "language_model"),
-        )
+        with self._mark_tower_model(vllm_config, "image"):
+            self.vision_tower = init_vision_tower_for_llava(
+                config,
+                quant_config=quant_config,
+                multimodal_config=multimodal_config,
+                require_post_norm=False,
+                prefix=maybe_prefix(prefix, "vision_tower"),
+            )
+            self.multi_modal_projector = MiniMaxVL01MultiModalProjector(
+                vision_hidden_size=config.vision_config.hidden_size,
+                text_hidden_size=config.text_config.hidden_size,
+                projector_hidden_act=config.projector_hidden_act,
+                multimodal_projector_bias=True,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "multi_modal_projector"),
+            )
+            self.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size)
+            )
+        with self._mark_language_model(vllm_config):
+            self.language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                hf_config=config.text_config,
+                prefix=maybe_prefix(prefix, "language_model"),
+            )
         self.vision_feature_layer = config.vision_feature_layer
         self.vocab_size = config.text_config.vocab_size
         self.pad_token_id = -1
@@ -233,9 +238,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
             self.language_model.make_empty_intermediate_tensors
         )
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
-
     def _image_pixels_to_features(
         self,
         vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
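This hunk drops the hand-written `get_language_model` accessor. Plausibly the base class can now resolve it from the components tagged at init time; a hypothetical sketch of such a shared accessor, continuing the invented `_mm_components` registry from the sketch above (not the actual vLLM API):

import torch.nn as nn


class SharedAccessorMixin:
    # Filled in by the marker context managers at construction time.
    _mm_components: dict[str, list[str]]

    def get_language_model(self) -> nn.Module:
        # Assumes exactly one submodule was created under the "language" mark.
        (name,) = self._mm_components["language"]
        return getattr(self, name)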
@@ -302,8 +304,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
         self,
         inputs: MiniMaxVL01ImagePixelInputs,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
-        assert self.vision_tower is not None
-
         pixel_values = inputs["pixel_values"]
 
         return self._image_pixels_to_features(self.vision_tower, pixel_values)
@@ -314,7 +314,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
-        assert self.vision_tower is not None
         image_features = self._process_image_pixels(image_input)
 
         if isinstance(image_features, torch.Tensor):
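Both removed asserts guarded the vision path against a `vision_tower` that was never built (see the TODO above about embedding-only initialization). A guess at the rationale: once tower construction is marked, that check can live once in shared code rather than being repeated in every model. A hypothetical sketch of such a centralized guard (names invented):

import torch.nn as nn


def require_component(model: nn.Module, name: str) -> nn.Module:
    # Return the named submodule, failing loudly if it was skipped at
    # init time (e.g. an embedding-only instantiation).
    component = getattr(model, name, None)
    if component is None:
        raise RuntimeError(f"{name} was not initialized on this instance")
    return component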
@@ -369,14 +368,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
     ) -> torch.Tensor | IntermediateTensors:
         if intermediate_tensors is not None:
             inputs_embeds = None
-        elif inputs_embeds is None:
-            vision_embeddings = self.embed_multimodal(**kwargs)
-            inputs_embeds = self.embed_input_ids(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.image_token_index,
-            )
-            input_ids = None
 
         hidden_states = self.language_model.model(
             input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
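The deleted branch built `inputs_embeds` by encoding the images (`embed_multimodal`) and splicing the results into the text embeddings wherever `input_ids` equals the image token (the `is_multimodal` mask), then nulling `input_ids` so the language model consumes embeddings only; presumably that logic now runs in shared code outside the model class. A self-contained sketch of just the splice step, with illustrative shapes and an invented helper name:

import torch


def merge_multimodal_embeddings(
    text_embeds: torch.Tensor,    # [num_tokens, hidden]
    mm_embeds: torch.Tensor,      # [num_image_tokens, hidden]
    is_multimodal: torch.Tensor,  # [num_tokens] bool mask
) -> torch.Tensor:
    # Rows of mm_embeds replace text embeddings at masked positions;
    # mm_embeds must have exactly is_multimodal.sum() rows.
    merged = text_embeds.clone()
    merged[is_multimodal] = mm_embeds.to(dtype=merged.dtype)
    return merged


text_embeds = torch.zeros(5, 8)
mm_embeds = torch.ones(2, 8)
mask = torch.tensor([False, True, False, True, False])
print(merge_multimodal_embeddings(text_embeds, mm_embeds, mask)[:, 0])
# tensor([0., 1., 0., 1., 0.])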