[1/N] Initialize MM components in context managers (A-D) (#32632)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-01-20 14:12:42 +08:00
Committed by: GitHub
parent 4753f3bf69
commit b75e85dede
11 changed files with 240 additions and 268 deletions


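Of the 11 changed files, only the BLIP-2 model (Blip2ForConditionalGeneration) is shown below. The change wraps construction of each multimodal component in a marker context manager: the vision tower, query tokens, Q-Former, and projection are built under _mark_tower_model(vllm_config, "image"), and the text backbone under _mark_language_model(vllm_config). The markers themselves are not part of this diff; what follows is a minimal sketch, assuming they merely record which attributes are assigned inside each block. All names and bookkeeping here are assumptions for illustration, not the verified vLLM implementation.

from contextlib import contextmanager


class MultiModalMarkerSketch:
    """Hypothetical stand-in for the marker helpers; the real versions
    live in vLLM's multimodal interfaces and may do different work."""

    def __init__(self) -> None:
        self.tower_model_attrs: dict[str, set[str]] = {}
        self.language_model_attrs: set[str] = set()

    @contextmanager
    def _mark_tower_model(self, vllm_config, modality: str):
        # Snapshot the attributes that exist before the block runs...
        before = set(vars(self))
        yield
        # ...so every attribute assigned inside the block is tagged as
        # part of this modality's tower (e.g. "image" for BLIP-2).
        self.tower_model_attrs[modality] = set(vars(self)) - before

    @contextmanager
    def _mark_language_model(self, vllm_config):
        # Same bookkeeping for the text backbone.
        before = set(vars(self))
        yield
        self.language_model_attrs = set(vars(self)) - before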
@@ -549,31 +549,31 @@ class Blip2ForConditionalGeneration(
             + 1  # include class token
         )

         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_model = BlipVisionModel(vision_config, quant_config)
-        self.query_tokens = nn.Parameter(
-            torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
-        )
-        self.qformer = Blip2QFormerModel(
-            config.qformer_config,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.qformer",
-        )
-        self.language_projection = nn.Linear(
-            config.qformer_config.hidden_size,
-            config.text_config.hidden_size,
-            bias=True,
-        )
-        self.language_model = init_vllm_registered_model(
-            vllm_config=vllm_config,
-            hf_config=config.text_config,
-            prefix=maybe_prefix(prefix, "language_model"),
-        )
+        with self._mark_tower_model(vllm_config, "image"):
+            self.vision_model = BlipVisionModel(vision_config, quant_config)
+
+            self.query_tokens = nn.Parameter(
+                torch.zeros(
+                    1, config.num_query_tokens, config.qformer_config.hidden_size
+                )
+            )
+
+            self.qformer = Blip2QFormerModel(
+                config.qformer_config,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qformer",
+            )
+
+            self.language_projection = nn.Linear(
+                config.qformer_config.hidden_size,
+                config.text_config.hidden_size,
+                bias=True,
+            )
+
+        with self._mark_language_model(vllm_config):
+            self.language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                hf_config=config.text_config,
+                prefix=maybe_prefix(prefix, "language_model"),
+            )

         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
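For reference, this is how the MultiModalMarkerSketch above behaves on a toy module laid out like the new constructor (purely illustrative; ToyModel is not part of the commit):

import torch.nn as nn

class ToyModel(MultiModalMarkerSketch):
    def __init__(self) -> None:
        super().__init__()
        # Mirrors the new BLIP-2 layout: tower parts in one marked block,
        # text backbone in the other.
        with self._mark_tower_model(vllm_config=None, modality="image"):
            self.vision_model = nn.Linear(8, 8)
            self.language_projection = nn.Linear(8, 4)
        with self._mark_language_model(vllm_config=None):
            self.language_model = nn.Linear(4, 4)

model = ToyModel()
print(model.tower_model_attrs)     # {'image': {'vision_model', 'language_projection'}} (set order may vary)
print(model.language_model_attrs)  # {'language_model'}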
@@ -614,8 +614,6 @@ class Blip2ForConditionalGeneration(
         return image_features

     def _process_image_pixels(self, inputs: Blip2ImagePixelInputs) -> torch.Tensor:
-        assert self.vision_model is not None
-
         pixel_values = inputs["data"]

         return self._image_pixels_to_features(self.vision_model, pixel_values)
@@ -624,7 +622,6 @@ class Blip2ForConditionalGeneration(
         if image_input["type"] == "image_embeds":
             return image_input["data"]

-        assert self.vision_model is not None
         image_features = self._process_image_pixels(image_input)

         query_tokens = self.query_tokens.expand(image_features.shape[0], -1, -1)
@@ -635,9 +632,6 @@ class Blip2ForConditionalGeneration(
         return self.language_projection(query_output)

-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
-
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
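The hunk above also drops BLIP-2's per-model get_language_model override, which suggests the base interface can now answer that query from whatever _mark_language_model recorded. A sketch under that assumption, reusing the hypothetical bookkeeping from above (again not the verified vLLM code):

import torch.nn as nn

class MarkerSketchWithDefault(MultiModalMarkerSketch):
    def get_language_model(self) -> nn.Module:
        # In this sketch the marker recorded exactly one attribute name,
        # so the override removed in the diff becomes a generic default.
        (attr,) = self.language_model_attrs
        return getattr(self, attr)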