[2/N] Initialize MM components in context managers (E-H) (#32641)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-20 16:12:56 +08:00
committed by GitHub
parent 148117ea2e
commit e1a34c3a5d
12 changed files with 161 additions and 189 deletions

View File

@@ -522,25 +522,27 @@ class Gemma3ForConditionalGeneration(
self.quant_config = quant_config
self.multimodal_config = multimodal_config
self.vision_tower = SiglipVisionModel(
config.vision_config,
quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.multi_modal_projector = Gemma3MultiModalProjector(config)
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = SiglipVisionModel(
config.vision_config,
quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.multi_modal_projector = Gemma3MultiModalProjector(config)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Gemma3ForCausalLM"],
)
logit_scale = getattr(config, "logit_scale", 1.0)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Gemma3ForCausalLM"],
)
if hasattr(self.language_model, "logits_processor"):
# The logits processor can be unset if we're using
# automatic conversion to pooling model.
self.language_model.logits_processor.scale *= logit_scale
logit_scale = getattr(config, "logit_scale", 1.0)
if hasattr(self.language_model, "logits_processor"):
# The logits processor can be unset if we're using
# automatic conversion to pooling model.
self.language_model.logits_processor.scale *= logit_scale
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -579,8 +581,6 @@ class Gemma3ForConditionalGeneration(
self,
image_input: Gemma3ImageInputs,
) -> list[torch.Tensor]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
num_patches = image_input["num_patches"]
@@ -592,9 +592,6 @@ class Gemma3ForConditionalGeneration(
return [e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: