[Core] Support image processor (#4197)

vllm/multimodal/__init__.py (new file)
@@ -0,0 +1,7 @@
from .base import MultiModalData, MultiModalPlugin
from .registry import MULTIMODAL_REGISTRY, MultiModalRegistry

__all__ = [
    "MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY",
    "MultiModalRegistry"
]

vllm/multimodal/base.py (new file)
@@ -0,0 +1,126 @@
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
                    TypeVar)

from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger

if TYPE_CHECKING:
    import torch
    from torch import nn

logger = init_logger(__name__)


class MultiModalData:
    """
    Base class that contains multi-modal data.

    To add a new modality, add a new file under the ``multimodal`` directory.

    In this new file, subclass :class:`~MultiModalData` and
    :class:`~MultiModalPlugin`.

    Finally, register the new plugin to
    :const:`vllm.multimodal.MULTIMODAL_REGISTRY`.
    This enables models to call :meth:`MultiModalRegistry.register_input` for
    the new modality.
    """
    pass


D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])

MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
                                    Dict[str, "torch.Tensor"]]
"""Return a dictionary to be passed as keyword arguments to
:meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers."""
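For illustration, a callable satisfying this alias could look like the sketch below; the `RGBData` type, its `tensor` attribute, and the "pixel_values" key are hypothetical, not part of this commit.

def process_rgb(data: "RGBData", model_config: ModelConfig,
                vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
    # Return the keyword arguments that the model's forward() expects.
    return {"pixel_values": data.tensor.to(model_config.dtype)}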


class MultiModalPlugin(ABC, Generic[D]):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).
    """

    @classmethod
    def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]:
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        return get_model_architecture(model_config)[0]

    def __init__(self) -> None:
        self._input_processors: Dict[Type["nn.Module"],
                                     MultiModalInputProcessor[D]] = {}

    @abstractmethod
    def get_data_type(self) -> Type[D]:
        """
        Get the modality (subclass of :class:`~MultiModalData`) served by
        this plugin.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_processor(
            self, data: D, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
        """Return a dictionary to be passed as keyword arguments to
        :meth:`torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.
        """
        raise NotImplementedError

    def register_input_processor(self,
                                 processor: Optional[
                                     MultiModalInputProcessor[D]] = None):
        """
        Register an input processor to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_type`), the provided input processor
        is applied to preprocess the data. If `None` is provided, then the
        default input processor is applied instead.
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_processors:
                logger.warning(
                    "Model class %s already has an input processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._input_processors[model_cls] = processor \
                or self._default_input_processor

            return model_cls

        return wrapper

    def process_input(
            self, data: D, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
        """
        Apply an input processor to a :class:`~MultiModalData` instance passed
        to the model.

        The model is identified by ``model_config``. ``vlm_config`` is
        for compatibility purposes and may be merged into ``model_config``
        in the near future.
        """
        model_cls = self.get_model_cls(model_config)

        processor = self._input_processors.get(model_cls)
        if processor is None:
            raise KeyError(f"No input processor in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        return processor(data, model_config, vlm_config)
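The docstrings above sketch the extension recipe: subclass MultiModalData for the new payload, subclass MultiModalPlugin for its processing logic, then register the plugin so that models can attach processors. A minimal sketch under those assumptions, reusing the imports of this file; the AudioData/AudioPlugin names and the "audio_input" key are hypothetical, not part of this commit:

class AudioData(MultiModalData):
    """Hypothetical modality wrapping a raw waveform tensor."""

    def __init__(self, waveform: "torch.Tensor") -> None:
        self.waveform = waveform


class AudioPlugin(MultiModalPlugin[AudioData]):

    def get_data_type(self) -> Type[AudioData]:
        return AudioData

    def _default_input_processor(
            self, data: AudioData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
        # Fallback used when a model does not register its own processor.
        return {"audio_input": data.waveform.to(model_config.dtype)}

An instance of such a plugin would then be handed to MultiModalRegistry (see registry.py below), after which models can call register_input(AudioData, ...).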

vllm/multimodal/image.py (new file)
@@ -0,0 +1,141 @@
from typing import Dict, Tuple, Type, Union

import torch
from PIL import Image

from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
from vllm.sequence import SequenceData
from vllm.transformers_utils.image_processor import cached_get_image_processor

from .base import MultiModalData, MultiModalPlugin

logger = init_logger(__name__)


def _get_dummy_seq_data(seq_len: int,
                        vlm_config: VisionLanguageConfig) -> SequenceData:
    # NOTE: We assume that the <image> token is repeated
    # `image_feature_size` times and then concatenated with the text prompt.
    # TODO: Enable other ways of inserting the image into the prompt.

    token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size
    token_ids += [0] * (seq_len - vlm_config.image_feature_size)

    return SequenceData(token_ids)


def _get_dummy_values(vlm_config: VisionLanguageConfig) -> torch.Tensor:
    if vlm_config.image_processor is None:
        values_dtype = torch.float16
    else:
        values_dtype = torch.uint8

    return torch.zeros(vlm_config.image_input_shape, dtype=values_dtype)


def get_dummy_image_data(
    seq_len: int,
    model_config: ModelConfig,
    vlm_config: VisionLanguageConfig,
) -> Tuple[SequenceData, MultiModalData]:
    """Standard dummy data factory for image data (to be used in
    :meth:`vllm.multimodal.MultiModalRegistry.register_dummy_data`)."""
    seq_data = _get_dummy_seq_data(seq_len, vlm_config)
    values = _get_dummy_values(vlm_config)

    config_input_type = vlm_config.image_input_type
    ImageInputType = VisionLanguageConfig.ImageInputType

    fake_mm_data: MultiModalData
    if config_input_type == ImageInputType.PIXEL_VALUES:
        fake_mm_data = ImagePixelData(values)
    elif config_input_type == ImageInputType.IMAGE_FEATURES:
        fake_mm_data = ImageFeatureData(values)
    else:
        raise NotImplementedError

    return seq_data, fake_mm_data
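As the docstring notes, this factory is meant to be handed to :meth:`MultiModalRegistry.register_dummy_data` (defined in registry.py below). A sketch of the model-side registration; `MyLlava` is a placeholder model class and `nn` is assumed to be `torch.nn`:

from vllm.multimodal import MULTIMODAL_REGISTRY

@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class MyLlava(nn.Module):  # placeholder model class
    ...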

class ImagePixelData(MultiModalData):
    """
    The pixel data of an image. Can be one of:

    - :class:`PIL.Image.Image`: An image object. Requires that a HuggingFace
      processor is available to the model.
    - :class:`torch.Tensor`: The raw pixel data which is passed to the model
      without additional pre-processing.
    """

    def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None:
        if isinstance(image, Image.Image):
            # So that this class can be created inside the Image context
            # manager
            image.load()

        self.image = image

class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):

    def get_data_type(self) -> Type[ImagePixelData]:
        return ImagePixelData

    def _get_hf_image_processor(self, model_config: ModelConfig,
                                vlm_config: VisionLanguageConfig):
        if vlm_config is None or vlm_config.image_processor is None:
            return None

        return cached_get_image_processor(
            vlm_config.image_processor,
            trust_remote_code=model_config.trust_remote_code,
            revision=vlm_config.image_processor_revision,
        )

    def _default_input_processor(
            self, data: ImagePixelData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
        image = data.image
        image_processor = self._get_hf_image_processor(model_config,
                                                       vlm_config)

        if isinstance(image, Image.Image):
            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")
            try:
                return image_processor.preprocess(image, return_tensors="pt") \
                    .to(model_config.dtype).data
            except Exception:
                logger.error("Failed to process image (%s)", image)
                raise
        elif isinstance(image, torch.Tensor):
            pixel_values = image.to(model_config.dtype)

            return {"pixel_values": pixel_values}

        raise TypeError(f"Invalid image type: {type(image)}")

class ImageFeatureData(MultiModalData):
    """
    The feature vector of an image, passed directly to the model.

    This should be the output of the vision tower.
    """

    def __init__(self, image_features: torch.Tensor) -> None:
        self.image_features = image_features


class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):

    def get_data_type(self) -> Type[ImageFeatureData]:
        return ImageFeatureData

    def _default_input_processor(
            self, data: ImageFeatureData, model_config: ModelConfig,
            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
        image_features = data.image_features.to(model_config.dtype)

        return {"image_features": image_features}
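Either payload can be constructed by the caller ahead of dispatch. A sketch; the file name and feature shape are made up, and the configs come from the engine in practice:

with Image.open("demo.jpg") as f:  # hypothetical local file
    pixel_data = ImagePixelData(f)  # image.load() lets it outlive the context

features = torch.rand(1, 576, 1024)  # hypothetical vision-tower output
feature_data = ImageFeatureData(features)

# The model runner then dispatches on the payload type, e.g.:
# MULTIMODAL_REGISTRY.process_input(pixel_data, model_config, vlm_config)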

vllm/multimodal/registry.py (new file)
@@ -0,0 +1,156 @@
import functools
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence,
                    Tuple, Type, TypeVar)

from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger

from .base import MultiModalData, MultiModalPlugin
from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
                    ImagePixelPlugin)

if TYPE_CHECKING:
    import torch
    from torch import nn

    from vllm.sequence import SequenceData

logger = init_logger(__name__)

D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])

MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
                                    Dict[str, "torch.Tensor"]]
MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig],
                                  Tuple["SequenceData", MultiModalData]]


class MultiModalRegistry:
    """
    This registry is used by model runners to dispatch data processing
    according to the modality of the data and the target model.
    """

    DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())

    def __init__(self,
                 *,
                 plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS
                 ) -> None:
        self._plugins_by_data_type = {p.get_data_type(): p for p in plugins}
        self._dummy_factories_by_model_type: Dict[Type["nn.Module"],
                                                  MultiModalDummyFactory] = {}

    def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
        data_type = plugin.get_data_type()

        if data_type in self._plugins_by_data_type:
            logger.warning(
                "A plugin is already registered for data type %s, "
                "and will be overwritten by the new plugin %s.", data_type,
                plugin)

        self._plugins_by_data_type[data_type] = plugin

    def _get_plugin_for_data_type(self, data_type: Type[MultiModalData]):
        for typ in data_type.mro():
            plugin = self._plugins_by_data_type.get(typ)
            if plugin is not None:
                return plugin

        msg = f"Unknown multi-modal data type: {data_type}"
        raise NotImplementedError(msg)

    def register_dummy_data(self, factory: MultiModalDummyFactory):
        """
        Register a dummy data factory to a model class.

        During memory profiling, the provided function is invoked to create
        dummy data to be input to the model. The modality and shape of the
        dummy data should be an upper bound of what the model would receive
        at inference time.
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._dummy_factories_by_model_type:
                logger.warning(
                    "Model class %s already has dummy data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def dummy_data_for_profiling(self, seq_len: int,
                                 model_config: ModelConfig,
                                 vlm_config: VisionLanguageConfig):
        """Create dummy data for memory profiling."""
        model_cls = MultiModalPlugin.get_model_cls(model_config)
        dummy_factory = self._dummy_factories_by_model_type.get(model_cls)
        if dummy_factory is None:
            msg = f"No dummy data defined for model class: {model_cls}"
            raise NotImplementedError(msg)

        return dummy_factory(seq_len, model_config, vlm_config)

    def register_input(
            self,
            data_type: Type[D],
            processor: Optional[MultiModalInputProcessor[D]] = None):
        """
        Register an input processor for a specific modality to a model class.

        See :meth:`MultiModalPlugin.register_input_processor` for more
        details.
        """
        return self._get_plugin_for_data_type(data_type) \
            .register_input_processor(processor)

    def register_image_pixel_input(
            self,
            processor: Optional[
                MultiModalInputProcessor[ImagePixelData]] = None):
        """
        Register an input processor for image pixel data to a model class.

        See :meth:`MultiModalPlugin.register_input_processor` for more
        details.
        """
        return self.register_input(ImagePixelData, processor)

    def register_image_feature_input(
            self,
            processor: Optional[
                MultiModalInputProcessor[ImageFeatureData]] = None):
        """
        Register an input processor for image feature data to a model class.

        See :meth:`MultiModalPlugin.register_input_processor` for more
        details.
        """
        return self.register_input(ImageFeatureData, processor)

    def process_input(self, data: MultiModalData, model_config: ModelConfig,
                      vlm_config: VisionLanguageConfig):
        """
        Apply an input processor to a :class:`~MultiModalData` instance passed
        to the model.

        See :meth:`MultiModalPlugin.process_input` for more details.
        """
        return self._get_plugin_for_data_type(type(data)) \
            .process_input(data, model_config, vlm_config)

    def create_input_processor(self, model_config: ModelConfig,
                               vlm_config: VisionLanguageConfig):
        """
        Create an input processor (see :meth:`process_input`) for a
        specific model.
        """
        return functools.partial(self.process_input,
                                 model_config=model_config,
                                 vlm_config=vlm_config)


MULTIMODAL_REGISTRY = MultiModalRegistry()
"""The global :class:`~MultiModalRegistry` which is used by model runners."""
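Putting it together, a vision-language model would typically attach both an input processor and a dummy-data factory through the global registry. An end-to-end sketch; `MyLlava` is a placeholder, and calling the registration methods with no processor falls back to the plugins' defaults:

import torch.nn as nn

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import get_dummy_image_data

@MULTIMODAL_REGISTRY.register_image_pixel_input()
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class MyLlava(nn.Module):  # placeholder; a real model also takes text inputs

    def forward(self, pixel_values: "torch.Tensor", **kwargs):
        ...

# Model runners create a per-model processor once and reuse it:
# processor = MULTIMODAL_REGISTRY.create_input_processor(model_config,
#                                                        vlm_config)
# kwargs = processor(ImagePixelData(image))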