[Core] Enable inputs_embeds_size separate from hidden_size (#29741)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-11-30 17:31:12 +08:00
committed by GitHub
parent 47539cfd3e
commit 64bc09ba27
10 changed files with 123 additions and 18 deletions

View File

@@ -1,10 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""
import math
from collections.abc import Iterable, Mapping
from collections.abc import Callable, Iterable, Mapping
from functools import cached_property
from typing import Annotated, Literal
@@ -976,6 +974,7 @@ class SiglipTextEmbeddings(nn.Module):
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
@@ -1145,6 +1144,41 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
def get_language_model(self) -> torch.nn.Module:
return self.text_model
def _embed_text_input_ids(
self,
input_ids: torch.Tensor,
embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
*,
is_multimodal: torch.Tensor | None,
handle_oov_mm_token: bool,
) -> torch.Tensor:
inputs_embeds = super()._embed_text_input_ids(
input_ids,
embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
# NOTE: inputs_embeds in model runner has size text_config.projection_size
# (instead of text_config.hidden_size) to accommodate image embeddings
inputs_embeds_size = self.text_projection_size
if inputs_embeds.shape[1] < inputs_embeds_size:
inputs_embeds = torch.cat(
[
inputs_embeds,
inputs_embeds.new_empty(
inputs_embeds.shape[0],
inputs_embeds_size - inputs_embeds.shape[1],
),
],
dim=1,
)
elif inputs_embeds.shape[1] > inputs_embeds_size:
# No need to handle this case for now
raise NotImplementedError
return inputs_embeds
def embed_input_ids(
self,
input_ids: torch.Tensor,
@@ -1190,6 +1224,15 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
if not self._is_text_input:
return inputs_embeds
# NOTE: inputs_embeds in model runner has size text_config.projection_size
# (instead of text_config.hidden_size) to accommodate image embeddings
hidden_size = self.text_embed_dim
if inputs_embeds.shape[1] > hidden_size:
inputs_embeds = inputs_embeds[:, :hidden_size]
elif inputs_embeds.shape[1] < hidden_size:
# No need to handle this case for now
raise NotImplementedError
return self.get_text_features(input_ids, positions, inputs_embeds)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):