diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index 620b6b6e2..87d33d1b7 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -1221,49 +1221,33 @@ class Ernie4_5_VLDummyInputsBuilder(BaseDummyInputsBuilder[Ernie4_5_VLProcessing
         num_videos: int,
         overrides: VideoDummyOptions | None = None,
     ):
-        if overrides:
-            if overrides.num_frames:
-                if overrides.num_frames > num_frames:
-                    logger.warning(
-                        "video.num_frames override (%d) exceeds model's "
-                        "maximum number of frames (%d), will be ignored",
-                        overrides.num_frames,
-                        num_frames,
-                    )
-                num_frames = min(num_frames, overrides.num_frames)
-            if overrides.width:
-                if overrides.width > width:
-                    logger.warning(
-                        "video.width override (%d) exceeds model's "
-                        "maximum width (%d), will be ignored",
-                        overrides.width,
-                        width,
-                    )
-                width = min(width, overrides.width)
-            if overrides.height:
-                if overrides.height > height:
-                    logger.warning(
-                        "video.height override (%d) exceeds model's "
-                        "maximum height (%d), will be ignored",
-                        overrides.height,
-                        height,
-                    )
-                height = min(height, overrides.height)
-        num_frames = max(num_frames, 2)  # ernie4.5-vl requires at least 2 frames
+        # ernie4.5-vl requires at least 2 frames
+        num_frames = max(num_frames, 2)
+        if overrides and overrides.num_frames:
+            overrides.num_frames = max(overrides.num_frames, 2)
+
+        videos = super()._get_dummy_videos(
+            width=width,
+            height=height,
+            num_frames=num_frames,
+            num_videos=num_videos,
+            overrides=overrides,
+        )
+        videos = [v.copy() for v in videos]

-        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
         video_items = []
-        for i in range(num_videos):
+        for video in videos:
+            video_num_frames = video.shape[0]
             video_metadata = {
                 "fps": 2.0,
-                "duration": num_frames / 2.0,
-                "total_num_frames": num_frames,
-                "frames_indices": [i for i in range(num_frames)],
+                "duration": video_num_frames / 2.0,
+                "total_num_frames": video_num_frames,
+                "frames_indices": list(range(video_num_frames)),
                 "video_backend": "opencv",
                 "do_sample_frames": False,
             }
-            video_item = (video.copy(), video_metadata)
-            video_items.append(video_item)
+            video_items.append((video, video_metadata))
+
         return video_items
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 4722b6e3d..d806562e0 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -1206,49 +1206,32 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
         num_videos: int,
         overrides: VideoDummyOptions | None = None,
     ) -> list[VideoItem]:
-        if overrides:
-            if overrides.num_frames:
-                if overrides.num_frames > num_frames:
-                    logger.warning(
-                        "video.num_frames override (%d) exceeds model's "
-                        "maximum number of frames (%d), will be ignored",
-                        overrides.num_frames,
-                        num_frames,
-                    )
-                num_frames = min(num_frames, overrides.num_frames)
-            if overrides.width:
-                if overrides.width > width:
-                    logger.warning(
-                        "video.width override (%d) exceeds model's "
-                        "maximum width (%d), will be ignored",
-                        overrides.width,
-                        width,
-                    )
-                width = min(width, overrides.width)
-            if overrides.height:
-                if overrides.height > height:
-                    logger.warning(
-                        "video.height override (%d) exceeds model's "
-                        "maximum height (%d), will be ignored",
-                        overrides.height,
-                        height,
-                    )
-                height = min(height, overrides.height)
+        # GLM 4.6V requires at least 2 frames
+        num_frames = max(num_frames, 2)
+        if overrides and overrides.num_frames:
+            overrides.num_frames = max(overrides.num_frames, 2)
+
+        videos = super()._get_dummy_videos(
+            width=width,
+            height=height,
+            num_frames=num_frames,
+            num_videos=num_videos,
+            overrides=overrides,
+        )
+        videos = [v.copy() for v in videos]

-        num_frames = max(num_frames, 2)  # GLM 4.6V requires 2 frames
-        video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
         video_items = []
-        for i in range(num_videos):
+        for video in videos:
+            video_num_frames = video.shape[0]
             video_metadata = {
                 "fps": 2.0,
-                "duration": num_frames / 2.0,
-                "total_num_frames": num_frames,
-                "frames_indices": [i for i in range(num_frames)],
+                "duration": video_num_frames / 2.0,
+                "total_num_frames": video_num_frames,
+                "frames_indices": list(range(video_num_frames)),
                 "video_backend": "opencv",
                 "do_sample_frames": False,
             }
-            video_item = (video.copy(), video_metadata)
-            video_items.append(video_item)
+            video_items.append((video, video_metadata))

         return video_items
diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py
index e684280fe..1e3629eb4 100644
--- a/vllm/model_executor/models/h2ovl.py
+++ b/vllm/model_executor/models/h2ovl.py
@@ -8,14 +8,13 @@
 # Copyright (c) 2024 H2O.AI
 # Licensed under Apache 2.0 License [see LICENSE for details]
 # --------------------------------------------------------
-from collections.abc import Mapping, Sequence

 import torch
 from transformers import PretrainedConfig

 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargsItems
+from vllm.multimodal.inputs import BatchedTensorInputs
 from vllm.multimodal.parse import (
     ImageEmbeddingItems,
     ImageProcessorItems,
@@ -25,7 +24,6 @@ from vllm.multimodal.processing.processor import (
     MultiModalProcessingInfo,
     ProcessorInputs,
     PromptReplacement,
-    PromptUpdate,
     TimingContext,
 )
 from vllm.transformers_utils.processors.h2ovl import H2OVLImageProcessor, H2OVLProcessor
@@ -86,15 +84,12 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):


 class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingInfo]):
-    def _get_prompt_updates(
+    def _get_prompt_repl_image(
         self,
         mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargsItems,
-    ) -> Sequence[PromptUpdate]:
-        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
-
-        out_mm_data = out_mm_kwargs.get_data()
+        hf_processor: H2OVLProcessor,
+        out_mm_data: BatchedTensorInputs,
+    ):
         if "image_num_patches" in out_mm_data:
             image_num_patches = out_mm_data["image_num_patches"]
             assert isinstance(image_num_patches, torch.Tensor)
@@ -130,13 +125,11 @@ class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingIn

             return hf_processor.get_image_repl(num_patches, num_features=feature_size)

-        return [
-            PromptReplacement(
-                modality="image",
-                target="<image>",
-                replacement=get_replacement_internvl,
-            )
-        ]
+        return PromptReplacement(
+            modality="image",
+            target="<image>",
+            replacement=get_replacement_internvl,
+        )

     def _cached_apply_hf_processor(
         self,
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 3c33da212..5cb7f462d 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -27,6 +27,7 @@ from vllm.model_executor.models.intern_vit import (
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
+    BatchedTensorInputs,
     MultiModalDataDict,
     MultiModalFieldConfig,
     MultiModalKwargsItems,
@@ -238,11 +239,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):

         return processed_outputs

-    def _get_mm_fields_config(
-        self,
-        hf_inputs: BatchFeature,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, MultiModalFieldConfig]:
+    def _get_image_fields_config(self, hf_inputs: BatchFeature):
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
         num_images = len(image_num_patches)

@@ -255,15 +252,19 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
             image_token_id=MultiModalFieldConfig.shared("image", num_images),
         )

-    def _get_prompt_updates(
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return self._get_image_fields_config(hf_inputs)
+
+    def _get_prompt_repl_image(
         self,
         mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargsItems,
-    ) -> Sequence[PromptUpdate]:
-        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
-
-        out_mm_data = out_mm_kwargs.get_data()
+        hf_processor: InternVLProcessor,
+        out_mm_data: BatchedTensorInputs,
+    ):
         if "image_num_patches" in out_mm_data:
             image_num_patches = out_mm_data["image_num_patches"]
             assert isinstance(image_num_patches, torch.Tensor)
@@ -296,12 +297,23 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):

             return hf_processor.get_image_repl(num_patches, num_features=feature_size)

+        return PromptReplacement(
+            modality="image",
+            target="<image>",
+            replacement=get_replacement_internvl,
+        )
+
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargsItems,
+    ) -> Sequence[PromptUpdate]:
+        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+        out_mm_data = out_mm_kwargs.get_data()
+
         return [
-            PromptReplacement(
-                modality="image",
-                target="<image>",
-                replacement=get_replacement_internvl,
-            )
+            self._get_prompt_repl_image(mm_items, hf_processor, out_mm_data),
         ]


@@ -455,44 +467,35 @@ class InternVLMultiModalProcessor(

         return processed_outputs

+    def _get_video_fields_config(self, hf_inputs: BatchFeature):
+        video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
+        num_videos = len(video_num_patches)
+
+        return dict(
+            pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
+                "video", video_num_patches
+            ),
+            video_num_patches=MultiModalFieldConfig.batched("video"),
+            video_token_id=MultiModalFieldConfig.shared("video", num_videos),
+        )
+
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)
+        fields = self._get_image_fields_config(hf_inputs)
         if self.info.ctx_video_token:
-            video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
-            num_videos = len(video_num_patches)
-            video_fields = dict(
-                pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
-                    "video", video_num_patches
-                ),
-                video_num_patches=MultiModalFieldConfig.batched("video"),
-                video_token_id=MultiModalFieldConfig.shared("video", num_videos),
-            )
-        else:
-            video_fields = {}
+            fields |= self._get_video_fields_config(hf_inputs)

-        return image_fields | video_fields
+        return fields

-    def _get_prompt_updates(
+    def _get_prompt_repl_video(
         self,
         mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargsItems,
-    ) -> Sequence[PromptUpdate]:
-        prompt_repl = super()._get_prompt_updates(
-            mm_items=mm_items,
-            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-            out_mm_kwargs=out_mm_kwargs,
-        )
-        if self.info.ctx_video_token is None:
-            return prompt_repl
-
-        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
-
-        out_mm_data = out_mm_kwargs.get_data()
+        hf_processor: InternVLProcessor,
+        out_mm_data: BatchedTensorInputs,
+    ):
         if "video_num_patches" in out_mm_data:
             video_num_patches = out_mm_data["video_num_patches"]
             assert isinstance(video_num_patches, torch.Tensor)
@@ -507,14 +510,30 @@ class InternVLMultiModalProcessor(

             return hf_processor.get_video_repl(num_patches)

-        return [
-            *prompt_repl,
-            PromptReplacement(
-                modality="video",
-                target="