[Bugfix] Check dimensions of multimodal embeddings in V1 (#15816)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-04-01 00:01:35 +08:00
committed by GitHub
parent e5ef4fa99a
commit 09e974d483
14 changed files with 98 additions and 37 deletions

View File

@@ -39,7 +39,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
@@ -66,10 +65,13 @@ class FuyuImagePatchInputs(TypedDict):
This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
@@ -322,16 +324,18 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
image_patches = kwargs.pop("image_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
if image_patches is not None:
if not isinstance(image_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image patches. "
f"Got type: {type(image_patches)}")
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
image_patches_flat = flatten_bn(image_patches)
embed_is_patch = flatten_bn(embed_is_patch)
return FuyuImagePatchInputs(
type="image_patches",
@@ -351,6 +355,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
assert self.vision_embed_tokens is not None
vision_embeddings_flat, _ = self.vision_embed_tokens(
image_patches_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_multimodal_embeddings(
@@ -358,13 +363,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
#return vision_embeddings
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
vision_embeddings,
image_input["embed_is_patch"],
))
image_features = self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,