[Bugfix] Check dimensions of multimodal embeddings in V1 (#15816)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -613,7 +613,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: Gemma3ImageInputs,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
) -> list[torch.Tensor]:
|
||||
assert self.vision_tower is not None
|
||||
|
||||
pixel_values = image_input["pixel_values"]
|
||||
@@ -625,7 +625,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
)
|
||||
image_embeds = self.multi_modal_projector(image_features)
|
||||
|
||||
return image_embeds.split(num_patches.tolist())
|
||||
return [
|
||||
e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
|
||||
]
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
|
||||
Reference in New Issue
Block a user