[V1][Molmo] Fix get_multimodal_embeddings() in molmo.py (#14161)
This commit is contained in:
@@ -50,7 +50,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptInsertion, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import JSONTree, json_map_leaves
|
||||
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
|
||||
|
||||
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
|
||||
SupportsQuant)
|
||||
@@ -1576,14 +1576,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
|
||||
return embeds_in_batch
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs
|
||||
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
image_features = self._process_image_input(image_input)
|
||||
|
||||
return [
|
||||
nested_embeds = [
|
||||
self._get_mm_embeds(*args) for args in zip(
|
||||
image_features,
|
||||
image_input["feat_is_patch"],
|
||||
@@ -1591,6 +1593,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
]
|
||||
return flatten_2d_lists(nested_embeds)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user