[VLM] Add max-count checking in data parser for single image models (#11661)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Cyrus Leung
2025-01-01 14:15:21 +08:00
committed by GitHub
parent 4db72e57f6
commit 365801fedd
6 changed files with 48 additions and 11 deletions

View File

@@ -18,6 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
@@ -404,6 +405,9 @@ def get_max_blip2_image_tokens(ctx: InputContext):
class Blip2MultiModalProcessor(BaseMultiModalProcessor):
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(max_mm_counts={"image": 1})
def _get_hf_processor(self) -> Blip2Processor:
return self.ctx.get_hf_processor(Blip2Processor)

View File

@@ -31,6 +31,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
@@ -60,6 +61,9 @@ def get_max_chameleon_image_tokens(ctx: InputContext):
class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(max_mm_counts={"image": 1})
def _get_hf_processor(self) -> ChameleonProcessor:
return self.ctx.get_hf_processor(ChameleonProcessor)

View File

@@ -34,7 +34,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
@@ -54,7 +54,7 @@ MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920
class FuyuImagePatchInputs(TypedDict):
type: Literal["image_patches"]
data: torch.Tensor
flat_data: torch.Tensor
"""
Shape:
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
@@ -63,7 +63,7 @@ class FuyuImagePatchInputs(TypedDict):
patches_per_image: List[int]
"""
List of number of total patches for each image in the batch.
This is used to restore the first two dimensions of `data`.
This is used to restore the first two dimensions of `flat_data`.
"""
@@ -102,6 +102,9 @@ def get_max_fuyu_image_tokens(ctx: InputContext):
class FuyuMultiModalProcessor(BaseMultiModalProcessor):
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(max_mm_counts={"image": 1})
def _get_hf_processor(self) -> FuyuProcessor:
return self.ctx.get_hf_processor(FuyuProcessor)
@@ -304,7 +307,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return FuyuImagePatchInputs(
type="image_patches",
data=self._validate_pixel_values(
flat_data=self._validate_pixel_values(
flatten_bn(image_patches_flat, concat=True)),
patches_per_image=[x.size(0) for x in image_patches_flat],
)
@@ -313,12 +316,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def _process_image_input(
self, image_input: FuyuImagePatchInputs) -> NestedTensors:
image_patches = image_input["data"]
image_patches_flat = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"]
assert self.vision_embed_tokens is not None
vision_embeddings, _ = self.vision_embed_tokens(image_patches)
return vision_embeddings.split(patches_per_image, dim=0)
vision_embeddings_flat, _ = self.vision_embed_tokens(
image_patches_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)