[Misc] Clean up processing logic (#37541)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-03-19 21:30:20 +08:00
committed by GitHub
parent c63ca2b2e6
commit 9515c20868
9 changed files with 477 additions and 679 deletions

View File

@@ -27,6 +27,7 @@ from vllm.model_executor.models.intern_vit import (
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
BatchedTensorInputs,
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
@@ -238,11 +239,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
def _get_image_fields_config(self, hf_inputs: BatchFeature):
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
num_images = len(image_num_patches)
@@ -255,15 +252,19 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
def _get_prompt_updates(
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return self._get_image_fields_config(hf_inputs)
def _get_prompt_repl_image(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
out_mm_data = out_mm_kwargs.get_data()
hf_processor: InternVLProcessor,
out_mm_data: BatchedTensorInputs,
):
if "image_num_patches" in out_mm_data:
image_num_patches = out_mm_data["image_num_patches"]
assert isinstance(image_num_patches, torch.Tensor)
@@ -296,12 +297,23 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
return hf_processor.get_image_repl(num_patches, num_features=feature_size)
return PromptReplacement(
modality="image",
target="<image>",
replacement=get_replacement_internvl,
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
out_mm_data = out_mm_kwargs.get_data()
return [
PromptReplacement(
modality="image",
target="<image>",
replacement=get_replacement_internvl,
)
self._get_prompt_repl_image(mm_items, hf_processor, out_mm_data),
]
@@ -455,44 +467,35 @@ class InternVLMultiModalProcessor(
return processed_outputs
def _get_video_fields_config(self, hf_inputs: BatchFeature):
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
num_videos = len(video_num_patches)
return dict(
pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_patches
),
video_num_patches=MultiModalFieldConfig.batched("video"),
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)
fields = self._get_image_fields_config(hf_inputs)
if self.info.ctx_video_token:
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
num_videos = len(video_num_patches)
video_fields = dict(
pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_patches
),
video_num_patches=MultiModalFieldConfig.batched("video"),
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
)
else:
video_fields = {}
fields |= self._get_video_fields_config(hf_inputs)
return image_fields | video_fields
return fields
def _get_prompt_updates(
def _get_prompt_repl_video(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
prompt_repl = super()._get_prompt_updates(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
out_mm_kwargs=out_mm_kwargs,
)
if self.info.ctx_video_token is None:
return prompt_repl
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
out_mm_data = out_mm_kwargs.get_data()
hf_processor: InternVLProcessor,
out_mm_data: BatchedTensorInputs,
):
if "video_num_patches" in out_mm_data:
video_num_patches = out_mm_data["video_num_patches"]
assert isinstance(video_num_patches, torch.Tensor)
@@ -507,14 +510,30 @@ class InternVLMultiModalProcessor(
return hf_processor.get_video_repl(num_patches)
return [
*prompt_repl,
PromptReplacement(
modality="video",
target="<video>",
replacement=get_video_replacement_internvl,
),
return PromptReplacement(
modality="video",
target="<video>",
replacement=get_video_replacement_internvl,
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
out_mm_data = out_mm_kwargs.get_data()
prompt_repls = [
self._get_prompt_repl_image(mm_items, hf_processor, out_mm_data),
]
if self.info.ctx_video_token is not None:
prompt_repls.append(
self._get_prompt_repl_video(mm_items, hf_processor, out_mm_data)
)
return prompt_repls
@MULTIMODAL_REGISTRY.register_processor(