[2/N] Initialize MM components in context managers (E-H) (#32641)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-20 16:12:56 +08:00
committed by GitHub
parent 148117ea2e
commit e1a34c3a5d
12 changed files with 161 additions and 189 deletions

View File

@@ -287,16 +287,20 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self.image_token_id = _IMAGE_TOKEN_ID
self.image_feature_size = config.patch_size**2 * config.num_channels
self.vision_embed_tokens = ColumnParallelLinear(
self.image_feature_size,
config.hidden_size,
quant_config=quant_config,
gather_output=True,
)
self.language_model = PersimmonForCausalLM(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_embed_tokens = ColumnParallelLinear(
self.image_feature_size,
config.hidden_size,
quant_config=quant_config,
gather_output=True,
)
with self._mark_language_model(vllm_config):
self.language_model = PersimmonForCausalLM(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
@@ -323,14 +327,10 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_patches_flat = image_input["image_patches_flat"]
patches_per_image = image_input["patches_per_image"]
assert self.vision_embed_tokens is not None
vision_embeddings_flat, _ = self.vision_embed_tokens(image_patches_flat)
return vision_embeddings_flat.split(patches_per_image.tolist(), dim=0)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
@@ -361,10 +361,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.language_model.logits_processor(
self.language_model.lm_head, hidden_states
)
return logits
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)