[V1][Molmo] Fix get_multimodal_embeddings() in molmo.py (#14161)

This commit is contained in:
lkchen
2025-03-04 07:43:59 -08:00
committed by GitHub
parent c8525f06fc
commit b3cf368d79
22 changed files with 249 additions and 150 deletions

View File

@@ -18,7 +18,7 @@
""" PyTorch Fuyu model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@@ -327,7 +327,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_patches_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None