[Refactor] Use data parser for matching data items to multi-modal UUIDs (#32955)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-01-26 15:00:28 +08:00
Committed by: GitHub
Commit: 11b556878b
Parent: ee484b3f4b
14 changed files with 701 additions and 604 deletions

File: vllm/model_executor/models/qwen3_vl.py

@@ -113,7 +113,11 @@ from .qwen2_5_vl import (
     Qwen2_5_VLVideoInputs,
     Qwen2_5_VLVideoPixelInputs,
 )
-from .qwen2_vl import Qwen2VLMultiModalDataParser, Qwen2VLProcessingInfo
+from .qwen2_vl import (
+    Qwen2VLMultiModalDataParser,
+    Qwen2VLProcessingInfo,
+    _create_qwen2vl_field_factory,
+)
 from .qwen3 import Qwen3ForCausalLM, Qwen3Model
 from .utils import (
     AutoWeightsLoader,
@@ -985,28 +989,9 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
-        image_grid_sizes = image_grid_thw.prod(-1)
-
-        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
-        video_grid_sizes = video_grid_thw.prod(-1)
-
-        return dict(
-            pixel_values=MultiModalFieldConfig.flat_from_sizes(
-                "image", image_grid_sizes
-            ),
-            image_embeds=MultiModalFieldConfig.flat_from_sizes(
-                "image", image_grid_sizes
-            ),
-            image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
-            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
-                "video", video_grid_sizes
-            ),
-            video_embeds=MultiModalFieldConfig.flat_from_sizes(
-                "video", video_grid_sizes
-            ),
-            video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
-        )
+        return _create_qwen2vl_field_factory(
+            self.info.get_hf_config().vision_config.spatial_merge_size
+        )(hf_inputs)

     def _get_prompt_updates(
         self,
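
Note: the removed per-model body above is replaced by a field factory shared with the Qwen2-VL family. Judging only from the call site, a minimal sketch of what such a factory could look like follows; the real _create_qwen2vl_field_factory lives in qwen2_vl.py, and using spatial_merge_size to shrink the embedding sizes is an assumption, not something this diff shows.

import torch

from vllm.multimodal.inputs import MultiModalFieldConfig


def _create_qwen2vl_field_factory(spatial_merge_size: int):
    # Sketch only: applying the merged-patch ratio to *_embeds is assumed.
    merge_ratio = spatial_merge_size**2

    def _field_config(hf_inputs):
        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        image_grid_sizes = image_grid_thw.prod(-1)
        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
        video_grid_sizes = video_grid_thw.prod(-1)
        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "image", image_grid_sizes
            ),
            image_embeds=MultiModalFieldConfig.flat_from_sizes(
                "image", image_grid_sizes // merge_ratio
            ),
            image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
                "video", video_grid_sizes
            ),
            video_embeds=MultiModalFieldConfig.flat_from_sizes(
                "video", video_grid_sizes // merge_ratio
            ),
            video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
        )

    return _field_config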

File: vllm/model_executor/models/terratorch.py

@@ -18,7 +18,7 @@
"""Wrapper around `Terratorch` models"""
from collections import OrderedDict
from collections.abc import Callable, Iterable, Mapping, Sequence
from collections.abc import Iterable, Mapping, Sequence
from typing import Any
import torch
@@ -62,6 +62,7 @@ from vllm.multimodal.processing import (
     PromptUpdate,
 )
 from vllm.sequence import IntermediateTensors
+from vllm.utils import length_from_prompt_token_ids_or_embeds

 from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
 from .interfaces_base import attn_type
@@ -69,28 +70,21 @@ from .interfaces_base import attn_type
 logger = init_logger(__name__)


-def _terratorch_field_names(pretrained_cfg: dict):
-    input_definition = InputDefinition(**pretrained_cfg["input"])
+def _terratorch_field_names(input_definition: InputDefinition):
     return set(input_definition.data.keys())


-def _terratorch_field_factory(
-    pretrained_cfg: dict,
-) -> Callable[
-    [Mapping[str, torch.Tensor]],
-    Mapping[str, MultiModalFieldConfig],
-]:
-    def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-        input_definition = InputDefinition(**pretrained_cfg["input"])
-        fields = {}
-        for input_name, input in input_definition.data.items():
+def _terratorch_field_factory(input_definition: InputDefinition):
+    def _terratorch_field_config(
+        hf_inputs: Mapping[str, torch.Tensor],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        fields = dict[str, MultiModalFieldConfig]()
+        for name, input in input_definition.data.items():
+            modality = "image"
             if input.type == InputTypeEnum.tensor:
-                fields[input_name] = "image"
-        return {
-            field_name: MultiModalFieldConfig.batched(modality=field_modality)
-            for field_name, field_modality in fields.items()
-        }
+                fields[name] = MultiModalFieldConfig.shared(modality, batch_size=1)
+        return fields

     return _terratorch_field_config
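
To make the change concrete: for a definition with two tensor inputs (names invented for illustration), the factory now returns per-field shared configs directly, instead of mapping names to modality strings and batching them in a second pass. A sketch of the resulting mapping:

from vllm.multimodal.inputs import MultiModalFieldConfig

# Illustrative field names only; real ones come from pretrained_cfg["input"].
expected_fields = {
    "pixel_values": MultiModalFieldConfig.shared("image", batch_size=1),
    "location_coords": MultiModalFieldConfig.shared("image", batch_size=1),
}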
@@ -130,26 +124,31 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
 class TerratorchMultiModalDataParser(MultiModalDataParser):
-    def __init__(self, pretrained_cfg: dict, *args, **kwargs):
-        self._pretrained_cfg = pretrained_cfg
+    def __init__(self, input_definition: InputDefinition, *args, **kwargs):
         super().__init__(*args, **kwargs)

+        self.input_definition = input_definition
+
     def _parse_image_data(
         self,
         data: dict[str, torch.Tensor] | ModalityData[ImageItem],
     ) -> ModalityDataItems[Any, Any] | None:
         if isinstance(data, dict):
-            terratorch_fields = _terratorch_field_names(self._pretrained_cfg)
-
             return DictEmbeddingItems(
                 data,
                 modality="image",
-                required_fields=terratorch_fields,
-                fields_factory=_terratorch_field_factory(self._pretrained_cfg),
+                required_fields=_terratorch_field_names(self.input_definition),
+                fields_factory=_terratorch_field_factory(self.input_definition),
             )

         return super()._parse_image_data(data)

+    def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
+        if "image" not in mm_data:
+            mm_data = {"image": mm_data}
+
+        return super().parse_mm_data(mm_data)
+

 class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
     def __init__(
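
The new parse_mm_data override lets callers pass the raw tensor dict without nesting it under an "image" key. A usage sketch, assuming input_definition is built as in the processor below and with an invented field name:

import torch

parser = TerratorchMultiModalDataParser(input_definition)
pixel_values = torch.randn(6, 512, 512)  # hypothetical shape

# Both calls now parse to equivalent MultiModalDataItems; previously only
# the nested form was handled, by preprocessing inside apply().
items_a = parser.parse_mm_data({"pixel_values": pixel_values})
items_b = parser.parse_mm_data({"image": {"pixel_values": pixel_values}})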
@@ -159,18 +158,20 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
         *,
         cache: MultiModalProcessorOnlyCache | None = None,
     ) -> None:
-        self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
+        pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
+        self._input_definition = InputDefinition(**pretrained_cfg["input"])

         super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)

     def _get_data_parser(self) -> MultiModalDataParser:
-        return TerratorchMultiModalDataParser(pretrained_cfg=self.pretrained_cfg)
+        return TerratorchMultiModalDataParser(self._input_definition)

     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs)
+        return _terratorch_field_factory(self._input_definition)(hf_inputs)

     def _get_prompt_updates(
         self,
@@ -188,23 +189,16 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
         tokenization_kwargs: Mapping[str, object] | None = None,
         mm_uuids: MultiModalUUIDDict | None = None,
     ) -> MultiModalInputs:
-        if "image" in mm_data:
-            image_data = mm_data["image"]
-            image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
-        else:
-            image_data = mm_data
-            image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
-            mm_data = {"image": image_data}
-
         mm_items = self._to_mm_items(mm_data)
         tokenization_kwargs = tokenization_kwargs or {}
         mm_hashes = self._hash_mm_items(
             mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
         )

-        mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
-        mm_processed_data = BatchFeature(image_data)
+        mm_processed_data = BatchFeature(
+            mm_data.get("image", mm_data), tensor_type="pt"
+        )
+        mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}

         mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
             mm_processed_data,
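
On the BatchFeature change: tensor_type="pt" makes transformers convert any non-tensor values to torch tensors at construction time, so the manual unsqueeze(0) loop is no longer needed here (the batch dimension is now added in forward(), see the last hunk). A minimal sketch of the conversion, assuming only standard transformers behavior:

import torch
from transformers import BatchFeature

# Nested lists are converted to tensors; values that are already tensors
# pass through unchanged.
feat = BatchFeature({"pixel_values": [[0.1, 0.2], [0.3, 0.4]]}, tensor_type="pt")
assert isinstance(feat["pixel_values"], torch.Tensor)
assert feat["pixel_values"].shape == (2, 2)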
@@ -272,9 +266,15 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
         inputs_embeds: torch.Tensor | None = None,
         **kwargs: object,
     ):
-        model_output = self.inference_runner.forward(**kwargs)
+        input_len = length_from_prompt_token_ids_or_embeds(input_ids, inputs_embeds)

-        return model_output.output
+        batched_kwargs = {k: v.unsqueeze(0) for k, v in kwargs.items()}
+        model_output = self.inference_runner.forward(**batched_kwargs).output
+
+        # The leading dimension of hidden states needs to equal input length
+        return model_output.expand(
+            input_len, *(-1 for _ in range(model_output.ndim - 1))
+        )

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         params_list = []
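
The trailing expand exists because vLLM expects the leading dimension of the returned hidden states to match the prompt length. A minimal sketch of the broadcasting involved (shapes are hypothetical):

import torch

# A size-1 leading dim broadcasts to input_len without copying data;
# -1 leaves every other dimension unchanged.
model_output = torch.randn(1, 16, 32)  # hypothetical output shape
input_len = 10
hidden = model_output.expand(input_len, *(-1 for _ in range(model_output.ndim - 1)))
assert hidden.shape == (10, 16, 32)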