[Refactor] Define MM data parser in processing info instead of processor itself (#33260)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-29 13:55:17 +08:00
committed by GitHub
parent 07ea184f00
commit 51550179fc
34 changed files with 399 additions and 347 deletions

View File

@@ -19,6 +19,7 @@
from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any
import torch
@@ -38,7 +39,6 @@ from vllm.model_executor.layers.pooler import IdentityPooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import (
ImageItem,
ModalityData,
@@ -89,7 +89,45 @@ def _terratorch_field_factory(input_definition: InputDefinition):
return _terratorch_field_config
class TerratorchMultiModalDataParser(MultiModalDataParser):
    """Multi-modal data parser configured by a Terratorch ``InputDefinition``.

    Dict-valued image payloads are wrapped in ``DictEmbeddingItems`` whose
    required fields and field factory are derived from the input definition;
    every other payload is handed to the base-class parser.
    """

    def __init__(self, input_definition: InputDefinition, *args, **kwargs):
        """Forward ``*args``/``**kwargs`` to the base class, then keep the definition."""
        super().__init__(*args, **kwargs)
        self.input_definition = input_definition

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse the payload supplied for the ``"image"`` modality."""
        # Guard clause: anything that is not a dict takes the base-class path.
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        definition = self.input_definition
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields=_terratorch_field_names(definition),
            fields_factory=_terratorch_field_factory(definition),
        )

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse ``mm_data``, nesting it under ``"image"`` if that key is absent."""
        wrapped = mm_data if "image" in mm_data else {"image": mm_data}
        return super().parse_mm_data(wrapped)
class TerratorchProcessingInfo(BaseProcessingInfo):
    """Processing info that exposes the Terratorch input definition and parser."""

    @cached_property
    def input_definition(self) -> InputDefinition:
        """Build the ``InputDefinition`` from the HF config.

        Reads ``pretrained_cfg["input"]`` from the dict form of the HF config
        and expands it as keyword arguments. The result is cached per instance
        by ``cached_property``.
        """
        hf_config_dict = self.get_hf_config().to_dict()
        input_spec = hf_config_dict["pretrained_cfg"]["input"]
        return InputDefinition(**input_spec)

    def get_data_parser(self):
        """Return a ``TerratorchMultiModalDataParser`` for this model's inputs."""
        # Resolve the definition first, then the hidden size, matching the
        # argument evaluation order of a direct call.
        definition = self.input_definition
        hidden_size = self._get_expected_hidden_size()
        return TerratorchMultiModalDataParser(
            definition,
            expected_hidden_size=hidden_size,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits: only ``"image"``, mapped to ``None``."""
        limits: dict[str, int | None] = {"image": None}
        return limits
@@ -123,55 +161,13 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
return self.dummy_data_generator.get_dummy_mm_data()
class TerratorchMultiModalDataParser(MultiModalDataParser):
    """Multi-modal data parser configured by a Terratorch ``InputDefinition``.

    Dict-valued image payloads are wrapped in ``DictEmbeddingItems`` whose
    required fields and field factory are derived from the input definition;
    every other payload is handed to the base-class parser.
    """

    def __init__(self, input_definition: InputDefinition, *args, **kwargs):
        """Forward ``*args``/``**kwargs`` to the base class, then keep the definition."""
        super().__init__(*args, **kwargs)
        self.input_definition = input_definition

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse the payload supplied for the ``"image"`` modality."""
        # Guard clause: anything that is not a dict takes the base-class path.
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        definition = self.input_definition
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields=_terratorch_field_names(definition),
            fields_factory=_terratorch_field_factory(definition),
        )

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse ``mm_data``, nesting it under ``"image"`` if that key is absent."""
        wrapped = mm_data if "image" in mm_data else {"image": mm_data}
        return super().parse_mm_data(wrapped)
class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
    """Processor that builds its own ``InputDefinition`` from the HF config."""

    def __init__(
        self,
        info: TerratorchProcessingInfo,
        dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
        *,
        cache: MultiModalProcessorOnlyCache | None = None,
    ) -> None:
        """Derive the input definition from ``info``, then run the base initializer."""
        hf_config_dict = info.get_hf_config().to_dict()
        input_spec = hf_config_dict["pretrained_cfg"]["input"]
        # Assigned before the base initializer runs, preserving the original
        # ordering (presumably the base class may request the data parser
        # during construction -- TODO confirm).
        self._input_definition = InputDefinition(**input_spec)
        super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser bound to the stored input definition."""
        definition = self._input_definition
        return TerratorchMultiModalDataParser(definition)
class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessingInfo]):
# Build the multi-modal field configuration for `hf_inputs` by applying the
# factory returned from `_terratorch_field_factory` to them.
# `hf_processor_mm_kwargs` is accepted but not read in the visible body.
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
# NOTE(review): the two consecutive returns below look like the removed and
# added sides of a single changed line (`self._input_definition` versus
# `self.info.input_definition`). Read literally, only the first one executes
# and the second is unreachable -- confirm which side is intended.
return _terratorch_field_factory(self._input_definition)(hf_inputs)
return _terratorch_field_factory(self.info.input_definition)(hf_inputs)
def _get_prompt_updates(
self,