[VLM] Abstract out multi-modal data parsing in merged processor (#11620)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung
2024-12-30 23:01:35 +08:00
committed by GitHub
parent b12e87f942
commit 8d9b6721e7
15 changed files with 559 additions and 311 deletions
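
In broad strokes: the merged multi-modal processor previously let each model normalize its raw multi-modal data itself, via overrides such as _get_mm_items and _get_hf_mm_data (both removed below); that parsing now lives in a MultiModalDataParser from the new vllm.multimodal.parse module, which a processor supplies by overriding _get_data_parser(). A minimal sketch of the new hook (the subclass name is made up, and it is assumed that the base class invokes the parser from _to_mm_items()):

from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import BaseMultiModalProcessor


class MyModelMultiModalProcessor(BaseMultiModalProcessor):
    """Hypothetical processor showing the new extension point."""

    def _get_data_parser(self) -> MultiModalDataParser:
        # The parser, rather than the processor subclass, now turns a raw
        # MultiModalDataDict into per-modality item collections; models with
        # special needs return their own parser here (see Qwen2-VL below).
        return MultiModalDataParser()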

View File

@@ -33,7 +33,7 @@ from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs,
+from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs,
                                     NestedTensors)
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
@@ -54,7 +54,7 @@ def calculate_image_placeholder(vision_config):
 def mm_input_mapper_for_glmv(
     ctx: InputContext,
-    data: MultiModalData[object],
+    data: ModalityData[object],
 ) -> Dict:
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(

View File

@@ -20,11 +20,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
-                                    MultiModalFieldConfig, MultiModalInputsV2,
-                                    MultiModalKwargs, NestedTensors)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputsV2, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import ImageProcessorItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        ProcessorInputs, PromptReplacement,
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement,
                                         full_groupby_modality)
 from vllm.sequence import IntermediateTensors
@@ -179,7 +181,9 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
         assert isinstance(vision_config, PixtralVisionConfig)
         def get_replacement_pixtral(item_idx: int):
-            image_size = mm_items.get_image_size(item_idx)
+            images = mm_items.get_items("image", ImageProcessorItems)
+            image_size = images.get_image_size(item_idx)
             (
                 num_width_tokens,
                 num_height_tokens,
@@ -591,8 +595,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
         result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
-        mm_items = self._get_mm_items(mm_data)
-        mm_item_counts = mm_items.get_item_counts()
+        mm_items = self._to_mm_items(mm_data)
+        mm_item_counts = mm_items.get_all_counts()
         mm_kwargs = result["mm_kwargs"]
         # We reimplement the functionality of MLlavaProcessor from
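
The LLaVA and Mantis hunks above also show the renamed item API: _get_mm_items/get_item_counts become _to_mm_items/get_all_counts, and instead of asking MultiModalDataItems directly for an image size, a processor first requests the image items as a typed ImageProcessorItems collection. A sketch of a prompt-replacement callback written against that API (the helper and its token-count callback are assumptions, not code from this commit):

from typing import Callable

from vllm.multimodal.parse import ImageProcessorItems


def make_image_replacement(mm_items, image_token_id: int,
                           tokens_for: Callable[[int, int], int]):
    # mm_items is the result of self._to_mm_items(mm_data); tokens_for stands
    # in for the model-specific token-count calculation.

    def get_replacement(item_idx: int):
        # Fetch the parsed image items as a typed collection, then read the
        # size of the item being replaced, as in the LLaVA diff above.
        images = mm_items.get_items("image", ImageProcessorItems)
        image_size = images.get_image_size(item_idx)
        return [image_token_id] * tokens_for(image_size.width,
                                             image_size.height)

    return get_replacement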

View File

@@ -32,12 +32,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
-                                    MultiModalFieldConfig, MultiModalInputsV2,
-                                    MultiModalKwargs, NestedTensors,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputsV2, MultiModalKwargs,
+                                    NestedTensors, PlaceholderRange)
+from vllm.multimodal.parse import ImageProcessorItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        ProcessorInputs, PromptReplacement,
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement,
                                         _BoundPromptReplacement,
                                         _PlaceholderInfo)
 from vllm.sequence import IntermediateTensors
@@ -381,7 +382,9 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
         assert isinstance(bos_token_id, int)
         def get_replacement_phi3v(item_idx: int):
-            image_size = mm_items.get_image_size(item_idx)
+            images = mm_items.get_items("image", ImageProcessorItems)
+            image_size = images.get_image_size(item_idx)
             num_tokens = image_processor.calc_num_image_tokens_from_image_size(
                 width=image_size.width,
                 height=image_size.height,
@@ -389,12 +392,14 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
             return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
+        num_images = mm_items.get_count("image", strict=False)
         return [
             PromptReplacement(
                 modality="image",
                 target=image_token,
                 replacement=get_replacement_phi3v,
-            ) for image_token in image_tokens[:len(mm_items.images)]
+            ) for image_token in image_tokens[:num_images]
         ]
     def _apply_prompt_replacements(

View File

@@ -20,8 +20,8 @@
 # limitations under the License.
 """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
 from functools import cached_property
-from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union)
+from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
+                    Union)
 import numpy as np
 import torch
@@ -38,10 +38,12 @@ from vllm.inputs import InputContext
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
-                                    MultiModalKwargs, NestedTensors)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        ProcessorInputs, PromptReplacement)
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement)
 from vllm.sequence import IntermediateTensors
 from .interfaces import SupportsMultiModal, SupportsPP
@@ -99,15 +101,9 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
     def _get_feature_extractor(self) -> WhisperFeatureExtractor:
         return self._get_hf_processor().feature_extractor  # type: ignore
-    def _get_hf_mm_data(
-        self,
-        mm_items: MultiModalDataItems,
-    ) -> tuple[dict[str, Any], dict[str, Any]]:
-        # resample audio to the model's sampling rate
+    def _get_data_parser(self) -> MultiModalDataParser:
         feature_extractor = self._get_feature_extractor()
-        mm_items.resample_audios(feature_extractor.sampling_rate)
-        return super()._get_hf_mm_data(mm_items)
+        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
     def _call_hf_processor(
         self,
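
Qwen2-Audio (here) and Ultravox (further down) make the same change: rather than overriding _get_hf_mm_data and calling resample_audios on the parsed items, they hand the feature extractor's sampling rate to the parser via target_sr, so resampling happens while the raw audio is parsed. A consolidated sketch of that pattern (the class name is invented; everything else mirrors the two diffs):

from transformers import WhisperFeatureExtractor

from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import BaseMultiModalProcessor


class MyAudioMultiModalProcessor(BaseMultiModalProcessor):
    """Hypothetical audio processor following the shared pattern."""

    def _get_feature_extractor(self) -> WhisperFeatureExtractor:
        return self._get_hf_processor().feature_extractor  # type: ignore

    def _get_data_parser(self) -> MultiModalDataParser:
        # target_sr makes the parser resample incoming audio to the model's
        # sampling rate, replacing the removed mm_items.resample_audios(...)
        # call from the old _get_hf_mm_data override.
        feature_extractor = self._get_feature_extractor()
        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)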

View File

@@ -25,7 +25,6 @@ from functools import cached_property, partial
 from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
                     Set, Tuple, Type, TypedDict, Union)
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -55,15 +54,16 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
+from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors)
+                                    NestedTensors, VideoItem)
+from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        ProcessorInputs, PromptReplacement)
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement)
 from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
-from vllm.utils import is_list_of
 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
@@ -719,61 +719,81 @@ get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
                                         data_type_key="video")
-class Qwen2VLMultiModalDataItems(MultiModalDataItems):
-    @staticmethod
-    def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
-        """
-        Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
-        """
-        multi_data = Qwen2VLMultiModalDataItems()
-        for k, v in data.items():
-            # TODO: Make a separate modality for embedding inputs
-            # to avoid confusion
-            # yapf: disable
-            if k == "video":
-                # Special case since even a single item can be a list
-                multi_data[k] = (  # type: ignore[index]
-                    v if (
-                        isinstance(v, (dict, torch.Tensor))  # type: ignore[assignment]
-                        or is_list_of(v, list)
-                        or isinstance(v[0], (np.ndarray, torch.Tensor))
-                        and v[0].ndim == 4
-                    ) else [v]
-                )
-            elif k in ("image", "audio"):
-                multi_data[k] = (  # type: ignore[index]
-                    v if isinstance(v, (dict, torch.Tensor, list)) else [v]
-                )
-            else:
-                multi_data[k] = v if isinstance(v, list) else [v]  # type: ignore[index]
-            # yapf: enable
-        return multi_data
-    def get_item_counts(self) -> Mapping[str, int]:
-        return {
-            m: (
-                len(items[f"{m}_grid_thw"])  # type: ignore
-                if isinstance(items, dict) else len(items))
-            for m, items in self.items()
-        }
-    def has_embedding_inputs(self) -> bool:
-        return any(
-            isinstance(items, dict) or any(
-                isinstance(item, torch.Tensor) for item in items)
-            for items in self.values())
+class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
+                                            dict[str, torch.Tensor]]):
+    def __init__(self, data: dict, modality: str) -> None:
+        super().__init__(data)
+        self.modality = modality
+        grid_thw = data[f"{modality}_grid_thw"]
+        slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
+        self._slices = [
+            slice(slice_idxs[i], slice_idxs[i + 1])
+            for i in range(len(grid_thw))
+        ]
+    def __repr__(self) -> str:
+        return (f"{type(self).__name__}(modality={self.modality!r})")
+    def get_count(self) -> int:
+        return len(self.data[f"{self.modality}_grid_thw"])
+    def get(self, index: int) -> dict[str, torch.Tensor]:
+        out = {}
+        for k, v in self.data.items():
+            if v != f"{self.modality}_grid_thw":
+                v = v[self._slices[index]]
+            out[k] = v
+        return out
+    def get_processor_data(self) -> Mapping[str, object]:
+        return {}
+    def get_passthrough_data(self) -> Mapping[str, object]:
+        return self.data
+class Qwen2ImageEmbeddingItems(Qwen2EmbeddingItems):
+    def __init__(self, data: dict) -> None:
+        super().__init__(data, "image")
+class Qwen2VideoEmbeddingItems(Qwen2EmbeddingItems):
+    def __init__(self, data: dict) -> None:
+        super().__init__(data, "video")
+class Qwen2MultiModalDataParser(MultiModalDataParser):
+    def _parse_image_data(
+        self,
+        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
+    ) -> ModalityDataItems[Any, Any]:
+        if isinstance(data, dict):
+            return Qwen2EmbeddingItems(data, modality="image")
+        return super()._parse_image_data(data)
+    def _parse_video_data(
+        self,
+        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
+    ) -> ModalityDataItems[Any, Any]:
+        if isinstance(data, dict):
+            return Qwen2EmbeddingItems(data, modality="video")
+        return super()._parse_video_data(data)
 class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
-    def _get_mm_items(
-        self,
-        mm_data: MultiModalDataDict,
-    ) -> MultiModalDataItems:
-        return Qwen2VLMultiModalDataItems.from_dict(mm_data)
+    def _get_data_parser(self) -> MultiModalDataParser:
+        return Qwen2MultiModalDataParser()
     def _get_hf_processor(
         self,
@@ -796,35 +816,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
         return hf_processor
-    def _get_hf_mm_data(
-        self,
-        mm_items: MultiModalDataItems,
-    ) -> tuple[dict[str, Any], dict[str, Any]]:
-        processor_data = dict[str, Any]()
-        passthrough_data = dict[str, Any]()
-        for k, v in mm_items.items():
-            # TODO: Make a separate modality for embedding inputs
-            # to avoid confusion
-            if k in ("image", "video", "audio"):
-                if isinstance(v, dict):
-                    # Pass through embedding inputs (dict)
-                    passthrough_data.update(v)
-                elif isinstance(v, torch.Tensor) and v.ndim == 3:
-                    # Pass through embedding inputs (single)
-                    passthrough_data[f"{k}_embeds"] = [v]
-                elif (is_list_of(v, torch.Tensor) and len(v) > 0
-                      and v[0].ndim == 2):
-                    # Pass through embedding inputs (multi)
-                    passthrough_data[f"{k}_embeds"] = v
-                elif len(v) > 0:
-                    # Map keys to plural form, e.g.: image -> images
-                    processor_data[f"{k}s"] = v
-            else:
-                processor_data[k] = v
-        return processor_data, passthrough_data
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
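
For Qwen2-VL, pre-computed embedding inputs (a dict carrying *_embeds together with *_grid_thw) were previously special-cased inside from_dict and _get_hf_mm_data; they now get a dedicated ModalityDataItems subclass, and the model-specific parser decides per modality whether a dict should be wrapped as Qwen2EmbeddingItems or handed to the default parsing path. A small sketch of that dispatch with made-up shapes, calling the _parse_image_data hook directly since the parser's public entry point is not shown in this diff:

import torch

# Qwen2MultiModalDataParser / Qwen2EmbeddingItems as defined in the hunk above.
# Two images whose patch grids are (1, 4, 4) and (1, 2, 2): 16 + 4 flattened
# patches, each with a (hypothetical) 128-dim embedding.
image_grid_thw = torch.tensor([[1, 4, 4], [1, 2, 2]])
image_embeds = torch.randn(16 + 4, 128)

parser = Qwen2MultiModalDataParser()
items = parser._parse_image_data({
    "image_embeds": image_embeds,
    "image_grid_thw": image_grid_thw,
})

# A dict is treated as embedding input and routed to Qwen2EmbeddingItems
# instead of the regular image parsing path.
assert type(items).__name__ == "Qwen2EmbeddingItems"
assert items.get_count() == 2  # one item per grid_thw row
assert items.get_processor_data() == {}  # nothing is sent to the HF processor
assert items.get_passthrough_data()["image_embeds"] is image_embeds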

View File

@@ -3,8 +3,8 @@
 import math
 from functools import cached_property, lru_cache
-from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
-                    Tuple, TypedDict, Union)
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
+                    TypedDict, Union)
 import numpy as np
 import torch
@@ -24,10 +24,12 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
-                                    MultiModalKwargs, NestedTensors)
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        ProcessorInputs, PromptReplacement)
+                                        MultiModalDataItems, ProcessorInputs,
+                                        PromptReplacement)
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 from vllm.utils import is_list_of
@@ -85,15 +87,9 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
         hf_processor = self._get_hf_processor()
         return hf_processor.audio_processor.feature_extractor  # type: ignore
-    def _get_hf_mm_data(
-        self,
-        mm_items: MultiModalDataItems,
-    ) -> tuple[dict[str, Any], dict[str, Any]]:
-        # resample audio to the model's sampling rate
+    def _get_data_parser(self) -> MultiModalDataParser:
         feature_extractor = self._get_feature_extractor()
-        mm_items.resample_audios(feature_extractor.sampling_rate)
-        return super()._get_hf_mm_data(mm_items)
+        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
     def _call_hf_processor(
         self,