[V1][Molmo] Fix get_multimodal_embeddings() in molmo.py (#14161)
This commit is contained in:
@@ -18,7 +18,7 @@
|
||||
""" PyTorch Fuyu model."""
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import List, Literal, Optional, Set, Tuple, TypedDict
|
||||
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -327,7 +327,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
image_patches_flat)
|
||||
return vision_embeddings_flat.split(patches_per_image, dim=0)
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs
|
||||
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user