[VLM][Core] Support profiling with multiple multi-modal inputs per prompt (#7126)

This commit is contained in:
Cyrus Leung
2024-08-15 01:55:42 +08:00
committed by GitHub
parent 70b746efcf
commit 3f674a49b5
38 changed files with 572 additions and 216 deletions

View File

@@ -1,6 +1,8 @@
import functools
from collections import UserDict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type
from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol,
Tuple, Type)
from torch import nn
from transformers import PretrainedConfig
@@ -12,7 +14,7 @@ from .data import LLMInputs
if TYPE_CHECKING:
from vllm.config import ModelConfig, MultiModalConfig
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal import MultiModalDataDict, MultiModalRegistry
from vllm.sequence import SequenceData
logger = init_logger(__name__)
@@ -65,15 +67,38 @@ class InputContext:
N = TypeVar("N", bound=Type[nn.Module])
DummyDataFactory = Callable[[InputContext, int],
Tuple["SequenceData",
Optional["MultiModalDataDict"]]]
"""
Create dummy data to be inputted into the model.
Note:
:data:`InputProcessor` is not applied to the dummy data.
"""
class DummyDataFactory(Protocol):
    """
    Call signature expected of a dummy data factory.

    Any callable accepting ``(ctx, seq_len, mm_counts)`` and returning a
    ``(SequenceData, Optional[MultiModalDataDict])`` pair satisfies this
    protocol.
    """

    def __call__(
            self,
            ctx: InputContext,
            seq_len: int,
            mm_counts: Mapping[str, int],
    ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
        """
        Create dummy data to be inputted into the model.

        Args:
            ctx: Context object handed to the factory.
            seq_len: Requested length of the dummy sequence.
            mm_counts: Mapping from a multi-modal key to an integer count
                for that key.

        Returns:
            The dummy sequence data, together with the dummy multi-modal
            data (or ``None``).

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        ...
class _MultiModalCounts(UserDict):
"""
Wraps `mm_counts` for a more informative error message
when attempting to access a plugin that does not exist.
"""
def __getitem__(self, key: str) -> int:
try:
return super().__getitem__(key)
except KeyError as exc:
msg = (f"There is no multi-modal plugin with the key: {key}. "
f"Available keys: {set(self.keys())}")
raise KeyError(msg) from exc
# Type alias for an input processor: a callable that receives an
# ``InputContext`` and an ``LLMInputs`` and returns an ``LLMInputs``.
InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
"""Preprocess the inputs to the model."""
@@ -95,6 +120,7 @@ class InputRegistry:
self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
The default dummy data factory represents the longest possible text
@@ -133,8 +159,12 @@ class InputRegistry:
return wrapper
def dummy_data_for_profiling(self, model_config: "ModelConfig",
seq_len: int):
def dummy_data_for_profiling(
self,
model_config: "ModelConfig",
seq_len: int,
mm_registry: "MultiModalRegistry",
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
Create dummy data for profiling the memory usage of a model.
@@ -142,6 +172,10 @@ class InputRegistry:
See also:
:ref:`enabling_multimodal_inputs`
Note:
This should be called after
:meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
@@ -149,8 +183,29 @@ class InputRegistry:
model_cls, _ = get_model_architecture(model_config)
dummy_factory = self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
return dummy_factory(InputContext(model_config), seq_len)
seq_data, mm_data = dummy_factory(
InputContext(model_config),
seq_len,
_MultiModalCounts(mm_counts),
)
# Having more tokens is over-conservative but otherwise fine
num_tokens = seq_data.prompt_token_ids
assert len(num_tokens) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(num_tokens)} tokens instead.")
if mm_data is not None:
for k, v in mm_data.items():
num_items = len(v) if isinstance(v, list) else 1
num_expected = mm_counts[k]
assert num_items >= num_expected, (
f"Expected at least {num_expected} dummy '{k}' instances "
f"for profiling, but found {num_items} instances instead.")
return seq_data, mm_data
def _default_input_processor(self, ctx: InputContext,
inputs: LLMInputs) -> LLMInputs: