[Refactor] Use data parser for matching data items to multi-modal UUIDs (#32955)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -113,7 +113,11 @@ from .qwen2_5_vl import (
|
||||
Qwen2_5_VLVideoInputs,
|
||||
Qwen2_5_VLVideoPixelInputs,
|
||||
)
|
||||
from .qwen2_vl import Qwen2VLMultiModalDataParser, Qwen2VLProcessingInfo
|
||||
from .qwen2_vl import (
|
||||
Qwen2VLMultiModalDataParser,
|
||||
Qwen2VLProcessingInfo,
|
||||
_create_qwen2vl_field_factory,
|
||||
)
|
||||
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
@@ -985,28 +989,9 @@ class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo])
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
|
||||
image_grid_sizes = image_grid_thw.prod(-1)
|
||||
|
||||
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
|
||||
video_grid_sizes = video_grid_thw.prod(-1)
|
||||
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_grid_sizes
|
||||
),
|
||||
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_grid_sizes
|
||||
),
|
||||
image_grid_thw=MultiModalFieldConfig.batched("image", keep_on_cpu=True),
|
||||
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_grid_sizes
|
||||
),
|
||||
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_grid_sizes
|
||||
),
|
||||
video_grid_thw=MultiModalFieldConfig.batched("video", keep_on_cpu=True),
|
||||
)
|
||||
return _create_qwen2vl_field_factory(
|
||||
self.info.get_hf_config().vision_config.spatial_merge_size
|
||||
)(hf_inputs)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
"""Wrapper around `Terratorch` models"""
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -62,6 +62,7 @@ from vllm.multimodal.processing import (
|
||||
PromptUpdate,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
|
||||
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
|
||||
from .interfaces_base import attn_type
|
||||
@@ -69,28 +70,21 @@ from .interfaces_base import attn_type
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _terratorch_field_names(pretrained_cfg: dict):
|
||||
input_definition = InputDefinition(**pretrained_cfg["input"])
|
||||
def _terratorch_field_names(input_definition: InputDefinition):
|
||||
return set(input_definition.data.keys())
|
||||
|
||||
|
||||
def _terratorch_field_factory(
|
||||
pretrained_cfg: dict,
|
||||
) -> Callable[
|
||||
[Mapping[str, torch.Tensor]],
|
||||
Mapping[str, MultiModalFieldConfig],
|
||||
]:
|
||||
def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
input_definition = InputDefinition(**pretrained_cfg["input"])
|
||||
fields = {}
|
||||
for input_name, input in input_definition.data.items():
|
||||
def _terratorch_field_factory(input_definition: InputDefinition):
|
||||
def _terratorch_field_config(
|
||||
hf_inputs: Mapping[str, torch.Tensor],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
fields = dict[str, MultiModalFieldConfig]()
|
||||
for name, input in input_definition.data.items():
|
||||
modality = "image"
|
||||
if input.type == InputTypeEnum.tensor:
|
||||
fields[input_name] = "image"
|
||||
fields[name] = MultiModalFieldConfig.shared(modality, batch_size=1)
|
||||
|
||||
return {
|
||||
field_name: MultiModalFieldConfig.batched(modality=field_modality)
|
||||
for field_name, field_modality in fields.items()
|
||||
}
|
||||
return fields
|
||||
|
||||
return _terratorch_field_config
|
||||
|
||||
@@ -130,26 +124,31 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
|
||||
|
||||
|
||||
class TerratorchMultiModalDataParser(MultiModalDataParser):
|
||||
def __init__(self, pretrained_cfg: dict, *args, **kwargs):
|
||||
self._pretrained_cfg = pretrained_cfg
|
||||
def __init__(self, input_definition: InputDefinition, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.input_definition = input_definition
|
||||
|
||||
def _parse_image_data(
|
||||
self,
|
||||
data: dict[str, torch.Tensor] | ModalityData[ImageItem],
|
||||
) -> ModalityDataItems[Any, Any] | None:
|
||||
if isinstance(data, dict):
|
||||
terratorch_fields = _terratorch_field_names(self._pretrained_cfg)
|
||||
|
||||
return DictEmbeddingItems(
|
||||
data,
|
||||
modality="image",
|
||||
required_fields=terratorch_fields,
|
||||
fields_factory=_terratorch_field_factory(self._pretrained_cfg),
|
||||
required_fields=_terratorch_field_names(self.input_definition),
|
||||
fields_factory=_terratorch_field_factory(self.input_definition),
|
||||
)
|
||||
|
||||
return super()._parse_image_data(data)
|
||||
|
||||
def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
|
||||
if "image" not in mm_data:
|
||||
mm_data = {"image": mm_data}
|
||||
|
||||
return super().parse_mm_data(mm_data)
|
||||
|
||||
|
||||
class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
|
||||
def __init__(
|
||||
@@ -159,18 +158,20 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
|
||||
*,
|
||||
cache: MultiModalProcessorOnlyCache | None = None,
|
||||
) -> None:
|
||||
self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
|
||||
pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
|
||||
self._input_definition = InputDefinition(**pretrained_cfg["input"])
|
||||
|
||||
super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)
|
||||
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
return TerratorchMultiModalDataParser(pretrained_cfg=self.pretrained_cfg)
|
||||
return TerratorchMultiModalDataParser(self._input_definition)
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs)
|
||||
return _terratorch_field_factory(self._input_definition)(hf_inputs)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
@@ -188,23 +189,16 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
|
||||
tokenization_kwargs: Mapping[str, object] | None = None,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> MultiModalInputs:
|
||||
if "image" in mm_data:
|
||||
image_data = mm_data["image"]
|
||||
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
|
||||
else:
|
||||
image_data = mm_data
|
||||
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
|
||||
|
||||
mm_data = {"image": image_data}
|
||||
|
||||
mm_items = self._to_mm_items(mm_data)
|
||||
tokenization_kwargs = tokenization_kwargs or {}
|
||||
mm_hashes = self._hash_mm_items(
|
||||
mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
|
||||
)
|
||||
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
|
||||
|
||||
mm_processed_data = BatchFeature(image_data)
|
||||
mm_processed_data = BatchFeature(
|
||||
mm_data.get("image", mm_data), tensor_type="pt"
|
||||
)
|
||||
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
|
||||
|
||||
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
|
||||
mm_processed_data,
|
||||
@@ -272,9 +266,15 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
):
|
||||
model_output = self.inference_runner.forward(**kwargs)
|
||||
input_len = length_from_prompt_token_ids_or_embeds(input_ids, inputs_embeds)
|
||||
|
||||
return model_output.output
|
||||
batched_kwargs = {k: v.unsqueeze(0) for k, v in kwargs.items()}
|
||||
model_output = self.inference_runner.forward(**batched_kwargs).output
|
||||
|
||||
# The leading dimension of hidden states needs to equal input length
|
||||
return model_output.expand(
|
||||
input_len, *(-1 for _ in range(model_output.ndim - 1))
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
params_list = []
|
||||
|
||||
Reference in New Issue
Block a user