[Core] Support image processor (#4197)

This commit is contained in:
Cyrus Leung
2024-06-03 13:56:41 +08:00
committed by GitHub
parent dfbe60dc62
commit 7a64d24aad
29 changed files with 1042 additions and 256 deletions

View File

@@ -0,0 +1,7 @@
from .base import MultiModalData, MultiModalPlugin
from .registry import MULTIMODAL_REGISTRY, MultiModalRegistry
__all__ = [
"MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY",
"MultiModalRegistry"
]

126
vllm/multimodal/base.py Normal file
View File

@@ -0,0 +1,126 @@
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
TypeVar)
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
if TYPE_CHECKING:
import torch
from torch import nn
logger = init_logger(__name__)
class MultiModalData:
"""
Base class that contains multi-modal data.
To add a new modality, add a new file under ``multimodal`` directory.
In this new file, subclass :class:`~MultiModalData` and
:class:`~MultiModalPlugin`.
Finally, register the new plugin to
:const:`vllm.multimodal.MULTIMODAL_REGISTRY`.
This enables models to call :meth:`MultiModalRegistry.register_input` for
the new modality.
"""
pass
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])
MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
Dict[str, "torch.Tensor"]]
"""Return a dictionary to be passed as keyword arguments to
:meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers."""
class MultiModalPlugin(ABC, Generic[D]):
"""
Base class that defines data processing logic for a specific modality.
In particular, we adopt a registry pattern to dispatch data processing
according to the model being used (considering that different models may
process the same data differently). This registry is in turn used by
:class:`~MultiModalRegistry` which acts at a higher level
(i.e., the modality of the data).
"""
@classmethod
def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]:
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
return get_model_architecture(model_config)[0]
def __init__(self) -> None:
self._input_processors: Dict[Type["nn.Module"],
MultiModalInputProcessor[D]] = {}
@abstractmethod
def get_data_type(self) -> Type[D]:
"""
Get the modality (subclass of :class:`~MultiModalData`) served by
this plugin.
"""
raise NotImplementedError
@abstractmethod
def _default_input_processor(
self, data: D, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
"""Return a dictionary to be passed as keyword arguments to
:meth:`torch.nn.Module.forward`. This is similar in concept to
tokenizers and processors in HuggingFace Transformers.
"""
raise NotImplementedError
def register_input_processor(self,
processor: Optional[
MultiModalInputProcessor[D]] = None):
"""
Register an input processor to a model class.
When the model receives input data that matches the modality served by
this plugin (see :meth:`get_data_type`), the provided input processor is
applied to preprocess the data. If `None` is provided, then the default
input processor is applied instead.
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._input_processors:
logger.warning(
"Model class %s already has an input processor "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._input_processors[model_cls] = processor \
or self._default_input_processor
return model_cls
return wrapper
def process_input(
self, data: D, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
"""
Apply an input processor to a :class:`~MultiModalData` instance passed
to the model.
The model is identified by ``model_config``. ``vlm_config`` is
for compatibility purposes and may be merged into ``model_config``
in the near future.
"""
model_cls = self.get_model_cls(model_config)
processor = self._input_processors.get(model_cls)
if processor is None:
raise KeyError(f"No input processor in {self} is registered for "
f"model class {model_cls.__name__}.")
return processor(data, model_config, vlm_config)

141
vllm/multimodal/image.py Normal file
View File

@@ -0,0 +1,141 @@
from typing import Dict, Tuple, Type, Union
import torch
from PIL import Image
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
from vllm.sequence import SequenceData
from vllm.transformers_utils.image_processor import cached_get_image_processor
from .base import MultiModalData, MultiModalPlugin
logger = init_logger(__name__)
def _get_dummy_seq_data(seq_len: int,
vlm_config: VisionLanguageConfig) -> SequenceData:
# NOTE: We assume that <image> token is repeated `image_feature_size` times
# and then concatenated with the text prompt
# TODO: Enable other ways of inserting the image into the prompt
token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size
token_ids += [0] * (seq_len - vlm_config.image_feature_size)
return SequenceData(token_ids)
def _get_dummy_values(vlm_config: VisionLanguageConfig) -> torch.Tensor:
if vlm_config.image_processor is None:
values_dtype = torch.float16
else:
values_dtype = torch.uint8
return torch.zeros(vlm_config.image_input_shape, dtype=values_dtype)
def get_dummy_image_data(
seq_len: int,
model_config: ModelConfig,
vlm_config: VisionLanguageConfig,
) -> Tuple[SequenceData, MultiModalData]:
"""Standard dummy data factory for image data (to be used in
:meth:`vlm.multimodal.MultiModalRegistry.register_dummy_data`)."""
seq_data = _get_dummy_seq_data(seq_len, vlm_config)
values = _get_dummy_values(vlm_config)
config_input_type = vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
fake_mm_data: MultiModalData
if config_input_type == ImageInputType.PIXEL_VALUES:
fake_mm_data = ImagePixelData(values)
elif config_input_type == ImageInputType.IMAGE_FEATURES:
fake_mm_data = ImageFeatureData(values)
else:
raise NotImplementedError
return seq_data, fake_mm_data
class ImagePixelData(MultiModalData):
"""
The pixel data of an image. Can be one of:
- :class:``PIL.Image``: An image object. Requires that a HuggingFace
processor is available to the model.
- :class:``torch.Tensor``: The raw pixel data which is passed to the model
without additional pre-processing.
"""
def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None:
if isinstance(image, Image.Image):
# So that this class can be created inside the Image context manager
image.load()
self.image = image
class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
def get_data_type(self) -> Type[ImagePixelData]:
return ImagePixelData
def _get_hf_image_processor(self, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
if vlm_config is None or vlm_config.image_processor is None:
return None
return cached_get_image_processor(
vlm_config.image_processor,
trust_remote_code=model_config.trust_remote_code,
revision=vlm_config.image_processor_revision,
)
def _default_input_processor(
self, data: ImagePixelData, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
image = data.image
image_processor = self._get_hf_image_processor(model_config,
vlm_config)
if isinstance(image, Image.Image):
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available"
"to process the image object")
try:
return image_processor.preprocess(image, return_tensors="pt") \
.to(model_config.dtype).data
except Exception:
logger.error("Failed to process image (%s)", image)
raise
elif isinstance(image, torch.Tensor):
pixel_values = image.to(model_config.dtype)
return {"pixel_values": pixel_values}
raise TypeError(f"Invalid image type: {type(image)}")
class ImageFeatureData(MultiModalData):
"""
The feature vector of an image, passed directly to the model.
This should be the output of the vision tower.
"""
def __init__(self, image_features: torch.Tensor) -> None:
self.image_features = image_features
class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):
def get_data_type(self) -> Type[ImageFeatureData]:
return ImageFeatureData
def _default_input_processor(
self, data: ImageFeatureData, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
image_features = data.image_features.to(model_config.dtype)
return {"image_features": image_features}

156
vllm/multimodal/registry.py Normal file
View File

@@ -0,0 +1,156 @@
import functools
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence,
Tuple, Type, TypeVar)
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
from .base import MultiModalData, MultiModalPlugin
from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
ImagePixelPlugin)
if TYPE_CHECKING:
import torch
from torch import nn
from vllm.sequence import SequenceData
logger = init_logger(__name__)
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])
MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
Dict[str, "torch.Tensor"]]
MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig],
Tuple["SequenceData", MultiModalData]]
class MultiModalRegistry:
"""
This registry is used by model runners to dispatch data processing
according to its modality and the target model.
"""
DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())
def __init__(self,
*,
plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS
) -> None:
self._plugins_by_data_type = {p.get_data_type(): p for p in plugins}
self._dummy_factories_by_model_type: Dict[Type["nn.Module"],
MultiModalDummyFactory] = {}
def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
data_type = plugin.get_data_type()
if data_type in self._plugins_by_data_type:
logger.warning(
"A plugin is already registered for data type %s, "
"and will be overwritten by the new plugin %s.", data_type,
plugin)
self._plugins_by_data_type[data_type] = plugin
def _get_plugin_for_data_type(self, data_type: Type[MultiModalData]):
for typ in data_type.mro():
plugin = self._plugins_by_data_type.get(typ)
if plugin is not None:
return plugin
msg = f"Unknown multi-modal data type: {data_type}"
raise NotImplementedError(msg)
def register_dummy_data(self, factory: MultiModalDummyFactory):
"""
Register a dummy data factory to a model class.
During memory profiling, the provided function is invoked to create
dummy data to be inputted into the model. The modality and shape of
the dummy data should be an upper bound of what the model would receive
at inference time.
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_factories_by_model_type:
logger.warning(
"Model class %s already has dummy data "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._dummy_factories_by_model_type[model_cls] = factory
return model_cls
return wrapper
def dummy_data_for_profiling(self, seq_len: int, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
"""Create dummy data for memory profiling."""
model_cls = MultiModalPlugin.get_model_cls(model_config)
dummy_factory = self._dummy_factories_by_model_type.get(model_cls)
if dummy_factory is None:
msg = f"No dummy data defined for model class: {model_cls}"
raise NotImplementedError(msg)
return dummy_factory(seq_len, model_config, vlm_config)
def register_input(
self,
data_type: Type[D],
processor: Optional[MultiModalInputProcessor[D]] = None):
"""
Register an input processor for a specific modality to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details.
"""
return self._get_plugin_for_data_type(data_type) \
.register_input_processor(processor)
def register_image_pixel_input(
self,
processor: Optional[
MultiModalInputProcessor[ImagePixelData]] = None):
"""
Register an input processor for image pixel data to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details.
"""
return self.register_input(ImagePixelData, processor)
def register_image_feature_input(
self,
processor: Optional[
MultiModalInputProcessor[ImageFeatureData]] = None):
"""
Register an input processor for image feature data to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details.
"""
return self.register_input(ImageFeatureData, processor)
def process_input(self, data: MultiModalData, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
"""
Apply an input processor to a :class:`~MultiModalData` instance passed
to the model.
See :meth:`MultiModalPlugin.process_input` for more details.
"""
return self._get_plugin_for_data_type(type(data)) \
.process_input(data, model_config, vlm_config)
def create_input_processor(self, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
"""
Create an input processor (see :meth:`process_input`) for a
specific model.
"""
return functools.partial(self.process_input,
model_config=model_config,
vlm_config=vlm_config)
MULTIMODAL_REGISTRY = MultiModalRegistry()
"""The global :class:`~MultiModalRegistry` which is used by model runners."""