[Misc] Clean up processing logic (#37541)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -27,6 +27,7 @@ from vllm.model_executor.models.intern_vit import (
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
BatchedTensorInputs,
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
@@ -238,11 +239,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
def _get_image_fields_config(self, hf_inputs: BatchFeature):
|
||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
||||
num_images = len(image_num_patches)
|
||||
|
||||
@@ -255,15 +252,19 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return self._get_image_fields_config(hf_inputs)
|
||||
|
||||
def _get_prompt_repl_image(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
out_mm_data = out_mm_kwargs.get_data()
|
||||
hf_processor: InternVLProcessor,
|
||||
out_mm_data: BatchedTensorInputs,
|
||||
):
|
||||
if "image_num_patches" in out_mm_data:
|
||||
image_num_patches = out_mm_data["image_num_patches"]
|
||||
assert isinstance(image_num_patches, torch.Tensor)
|
||||
@@ -296,12 +297,23 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
return hf_processor.get_image_repl(num_patches, num_features=feature_size)
|
||||
|
||||
return PromptReplacement(
|
||||
modality="image",
|
||||
target="<image>",
|
||||
replacement=get_replacement_internvl,
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
out_mm_data = out_mm_kwargs.get_data()
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target="<image>",
|
||||
replacement=get_replacement_internvl,
|
||||
)
|
||||
self._get_prompt_repl_image(mm_items, hf_processor, out_mm_data),
|
||||
]
|
||||
|
||||
|
||||
@@ -455,44 +467,35 @@ class InternVLMultiModalProcessor(
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_video_fields_config(self, hf_inputs: BatchFeature):
|
||||
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
|
||||
num_videos = len(video_num_patches)
|
||||
|
||||
return dict(
|
||||
pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_num_patches
|
||||
),
|
||||
video_num_patches=MultiModalFieldConfig.batched("video"),
|
||||
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
|
||||
)
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)
|
||||
fields = self._get_image_fields_config(hf_inputs)
|
||||
if self.info.ctx_video_token:
|
||||
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
|
||||
num_videos = len(video_num_patches)
|
||||
video_fields = dict(
|
||||
pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_num_patches
|
||||
),
|
||||
video_num_patches=MultiModalFieldConfig.batched("video"),
|
||||
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
|
||||
)
|
||||
else:
|
||||
video_fields = {}
|
||||
fields |= self._get_video_fields_config(hf_inputs)
|
||||
|
||||
return image_fields | video_fields
|
||||
return fields
|
||||
|
||||
def _get_prompt_updates(
|
||||
def _get_prompt_repl_video(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
prompt_repl = super()._get_prompt_updates(
|
||||
mm_items=mm_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
out_mm_kwargs=out_mm_kwargs,
|
||||
)
|
||||
if self.info.ctx_video_token is None:
|
||||
return prompt_repl
|
||||
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
out_mm_data = out_mm_kwargs.get_data()
|
||||
hf_processor: InternVLProcessor,
|
||||
out_mm_data: BatchedTensorInputs,
|
||||
):
|
||||
if "video_num_patches" in out_mm_data:
|
||||
video_num_patches = out_mm_data["video_num_patches"]
|
||||
assert isinstance(video_num_patches, torch.Tensor)
|
||||
@@ -507,14 +510,30 @@ class InternVLMultiModalProcessor(
|
||||
|
||||
return hf_processor.get_video_repl(num_patches)
|
||||
|
||||
return [
|
||||
*prompt_repl,
|
||||
PromptReplacement(
|
||||
modality="video",
|
||||
target="<video>",
|
||||
replacement=get_video_replacement_internvl,
|
||||
),
|
||||
return PromptReplacement(
|
||||
modality="video",
|
||||
target="<video>",
|
||||
replacement=get_video_replacement_internvl,
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
out_mm_data = out_mm_kwargs.get_data()
|
||||
|
||||
prompt_repls = [
|
||||
self._get_prompt_repl_image(mm_items, hf_processor, out_mm_data),
|
||||
]
|
||||
if self.info.ctx_video_token is not None:
|
||||
prompt_repls.append(
|
||||
self._get_prompt_repl_video(mm_items, hf_processor, out_mm_data)
|
||||
)
|
||||
|
||||
return prompt_repls
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
|
||||
Reference in New Issue
Block a user