[Refactor] Define MM data parser in processing info instead of processor itself (#33260)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -19,6 +19,7 @@
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -38,7 +39,6 @@ from vllm.model_executor.layers.pooler import IdentityPooler
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
|
||||
from vllm.multimodal.inputs import (
|
||||
ImageItem,
|
||||
ModalityData,
|
||||
@@ -89,7 +89,45 @@ def _terratorch_field_factory(input_definition: InputDefinition):
|
||||
return _terratorch_field_config
|
||||
|
||||
|
||||
class TerratorchMultiModalDataParser(MultiModalDataParser):
    """Multi-modal data parser that accepts dict-shaped image inputs.

    When the image data is a ``dict``, it is wrapped in
    ``DictEmbeddingItems`` whose required fields and field factory are both
    derived from the stored ``InputDefinition``. Any other image data is
    handed to the base parser unchanged.
    """

    def __init__(self, input_definition: InputDefinition, *args, **kwargs):
        """Store ``input_definition``; forward all other arguments to the base."""
        super().__init__(*args, **kwargs)

        self.input_definition = input_definition

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse the ``"image"`` entry of the multi-modal data."""
        # Everything that is not a dict goes through the default image path.
        if not isinstance(data, dict):
            return super()._parse_image_data(data)

        definition = self.input_definition
        return DictEmbeddingItems(
            data,
            modality="image",
            required_fields=_terratorch_field_names(definition),
            fields_factory=_terratorch_field_factory(definition),
        )

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse ``mm_data``, nesting it under ``"image"`` when that key is absent."""
        # A mapping without an "image" key is treated as the image payload itself.
        wrapped = mm_data if "image" in mm_data else {"image": mm_data}
        return super().parse_mm_data(wrapped)
|
||||
|
||||
|
||||
class TerratorchProcessingInfo(BaseProcessingInfo):
    """Processing info that exposes the input definition and the data parser."""

    @cached_property
    def input_definition(self) -> InputDefinition:
        """Build the ``InputDefinition`` from ``pretrained_cfg["input"]``.

        The value is read from the dict form of the HF config and cached on
        the instance after the first access.
        """
        hf_config_dict = self.get_hf_config().to_dict()
        input_cfg = hf_config_dict["pretrained_cfg"]["input"]
        return InputDefinition(**input_cfg)

    def get_data_parser(self):
        """Return a ``TerratorchMultiModalDataParser`` for this model."""
        definition = self.input_definition
        hidden_size = self._get_expected_hidden_size()
        return TerratorchMultiModalDataParser(
            definition,
            expected_hidden_size=hidden_size,
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits: only ``"image"``, mapped to ``None``."""
        # NOTE(review): None presumably means "no fixed limit" - confirm
        # against the BaseProcessingInfo contract.
        return {"image": None}
|
||||
|
||||
@@ -123,55 +161,13 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
|
||||
return self.dummy_data_generator.get_dummy_mm_data()
|
||||
|
||||
|
||||
class TerratorchMultiModalDataParser(MultiModalDataParser):
    """Parser for Terratorch inputs given as a dict of tensors.

    Dict-typed image data becomes ``DictEmbeddingItems`` configured from the
    parser's ``input_definition``; other image data falls back to the base
    class implementation.
    """

    def __init__(self, input_definition: InputDefinition, *args, **kwargs):
        """Initialise the base parser, then remember ``input_definition``."""
        super().__init__(*args, **kwargs)

        self.input_definition = input_definition

    def _parse_image_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Return parsed items for the image modality."""
        if isinstance(data, dict):
            # Field names are resolved before the field factory, both from
            # the same input definition.
            field_names = _terratorch_field_names(self.input_definition)
            field_factory = _terratorch_field_factory(self.input_definition)
            return DictEmbeddingItems(
                data,
                modality="image",
                required_fields=field_names,
                fields_factory=field_factory,
            )

        return super()._parse_image_data(data)

    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """Parse ``mm_data``; input lacking an ``"image"`` key is wrapped under it."""
        if "image" in mm_data:
            payload = mm_data
        else:
            payload = {"image": mm_data}

        return super().parse_mm_data(payload)
|
||||
|
||||
|
||||
class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
    """Multi-modal processor that builds its data parser from the HF config."""

    def __init__(
        self,
        info: TerratorchProcessingInfo,
        dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
        *,
        cache: MultiModalProcessorOnlyCache | None = None,
    ) -> None:
        """Read ``pretrained_cfg["input"]`` from ``info`` and initialise the base.

        Args:
            info: Processing info whose HF config supplies the input definition.
            dummy_inputs: Dummy inputs builder passed through to the base class.
            cache: Optional processor cache passed through to the base class.
        """
        hf_config_dict = info.get_hf_config().to_dict()
        input_cfg = hf_config_dict["pretrained_cfg"]["input"]
        # Assigned before the base initialiser runs.
        # NOTE(review): presumably the base __init__ may call
        # _get_data_parser(), which reads this attribute - confirm.
        self._input_definition = InputDefinition(**input_cfg)

        super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return a parser bound to the input definition captured in ``__init__``."""
        definition = self._input_definition
        return TerratorchMultiModalDataParser(definition)
|
||||
|
||||
class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessingInfo]):
|
||||
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """Return the field configuration for ``hf_inputs``.

    The configuration is produced by the callable that
    ``_terratorch_field_factory`` returns for the input definition held by
    the processing info (``self.info.input_definition``).

    Args:
        hf_inputs: Processed inputs passed to the field-config callable.
        hf_processor_mm_kwargs: Part of the signature; not used here.

    Returns:
        Mapping from field name to its ``MultiModalFieldConfig``.
    """
    # The original had two consecutive returns: the live one read
    # ``self._input_definition``, which this processor class never sets,
    # and the intended ``self.info.input_definition`` one was unreachable.
    # Only the latter is kept.
    field_config = _terratorch_field_factory(self.info.input_definition)
    return field_config(hf_inputs)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user