[VLM][Core] Support profiling with multiple multi-modal inputs per prompt (#7126)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import functools
|
||||
from collections import UserDict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type
|
||||
from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol,
|
||||
Tuple, Type)
|
||||
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
@@ -12,7 +14,7 @@ from .data import LLMInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, MultiModalConfig
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.multimodal import MultiModalDataDict, MultiModalRegistry
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -65,15 +67,38 @@ class InputContext:
|
||||
|
||||
N = TypeVar("N", bound=Type[nn.Module])
|
||||
|
||||
DummyDataFactory = Callable[[InputContext, int],
|
||||
Tuple["SequenceData",
|
||||
Optional["MultiModalDataDict"]]]
|
||||
"""
|
||||
Create dummy data to be inputted into the model.
|
||||
|
||||
Note:
|
||||
:data:`InputProcessor` is not applied to the dummy data.
|
||||
"""
|
||||
class DummyDataFactory(Protocol):
    """
    Structural type of a callable that builds dummy inputs for a model.

    Any callable accepting ``(ctx, seq_len, mm_counts)`` and returning a
    ``(SequenceData, Optional[MultiModalDataDict])`` pair satisfies this
    protocol.
    """

    def __call__(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
        """
        Create dummy data to be inputted into the model.

        Args:
            ctx: The input context the dummy data is created for.
            seq_len: The sequence length requested for the dummy data.
            mm_counts: Per-key item counts for multi-modal data
                (presumably keyed by multi-modal plugin name; confirm
                against the caller).

        Returns:
            The dummy sequence data together with the dummy multi-modal
            data, or ``None`` in place of the latter.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        ...
|
||||
|
||||
|
||||
class _MultiModalCounts(UserDict):
|
||||
"""
|
||||
Wraps `mm_counts` for a more informative error message
|
||||
when attempting to access a plugin that does not exist.
|
||||
"""
|
||||
|
||||
def __getitem__(self, key: str) -> int:
|
||||
try:
|
||||
return super().__getitem__(key)
|
||||
except KeyError as exc:
|
||||
msg = (f"There is no multi-modal plugin with the key: {key}. "
|
||||
f"Available keys: {set(self.keys())}")
|
||||
raise KeyError(msg) from exc
|
||||
|
||||
|
||||
InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
|
||||
"""Preprocess the inputs to the model."""
|
||||
@@ -95,6 +120,7 @@ class InputRegistry:
|
||||
self,
|
||||
ctx: InputContext,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
|
||||
"""
|
||||
The default dummy data factory represents the longest possible text
|
||||
@@ -133,8 +159,12 @@ class InputRegistry:
|
||||
|
||||
return wrapper
|
||||
|
||||
def dummy_data_for_profiling(self, model_config: "ModelConfig",
|
||||
seq_len: int):
|
||||
def dummy_data_for_profiling(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
seq_len: int,
|
||||
mm_registry: "MultiModalRegistry",
|
||||
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
|
||||
"""
|
||||
Create dummy data for profiling the memory usage of a model.
|
||||
|
||||
@@ -142,6 +172,10 @@ class InputRegistry:
|
||||
|
||||
See also:
|
||||
:ref:`enabling_multimodal_inputs`
|
||||
|
||||
Note:
|
||||
This should be called after
|
||||
:meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
|
||||
"""
|
||||
# Avoid circular import
|
||||
from vllm.model_executor.model_loader import get_model_architecture
|
||||
@@ -149,8 +183,29 @@ class InputRegistry:
|
||||
model_cls, _ = get_model_architecture(model_config)
|
||||
dummy_factory = self._dummy_factories_by_model_type \
|
||||
.get(model_cls, self._default_dummy_data_factory)
|
||||
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
|
||||
|
||||
return dummy_factory(InputContext(model_config), seq_len)
|
||||
seq_data, mm_data = dummy_factory(
|
||||
InputContext(model_config),
|
||||
seq_len,
|
||||
_MultiModalCounts(mm_counts),
|
||||
)
|
||||
|
||||
# Having more tokens is over-conservative but otherwise fine
|
||||
num_tokens = seq_data.prompt_token_ids
|
||||
assert len(num_tokens) >= seq_len, (
|
||||
f"Expected at least {seq_len} dummy tokens for profiling, "
|
||||
f"but found {len(num_tokens)} tokens instead.")
|
||||
|
||||
if mm_data is not None:
|
||||
for k, v in mm_data.items():
|
||||
num_items = len(v) if isinstance(v, list) else 1
|
||||
num_expected = mm_counts[k]
|
||||
assert num_items >= num_expected, (
|
||||
f"Expected at least {num_expected} dummy '{k}' instances "
|
||||
f"for profiling, but found {num_items} instances instead.")
|
||||
|
||||
return seq_data, mm_data
|
||||
|
||||
def _default_input_processor(self, ctx: InputContext,
|
||||
inputs: LLMInputs) -> LLMInputs:
|
||||
|
||||
Reference in New Issue
Block a user