[Multi Modal] Configurable MM Profiling (#25631)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,15 +4,45 @@
|
||||
import hashlib
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import ConfigDict, Field, field_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config.utils import config
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseDummyOptions:
|
||||
"""Base options for generating dummy data during profiling."""
|
||||
count: int = Field(999, ge=0)
|
||||
|
||||
|
||||
@dataclass(config=ConfigDict(extra="forbid"))
|
||||
class VideoDummyOptions(BaseDummyOptions):
|
||||
"""Options for generating dummy video data during profiling."""
|
||||
num_frames: Optional[int] = Field(None, gt=0)
|
||||
width: Optional[int] = Field(None, gt=0)
|
||||
height: Optional[int] = Field(None, gt=0)
|
||||
|
||||
|
||||
@dataclass(config=ConfigDict(extra="forbid"))
|
||||
class ImageDummyOptions(BaseDummyOptions):
|
||||
"""Options for generating dummy image data during profiling."""
|
||||
width: Optional[int] = Field(None, gt=0)
|
||||
height: Optional[int] = Field(None, gt=0)
|
||||
|
||||
|
||||
@dataclass(config=ConfigDict(extra="forbid"))
|
||||
class AudioDummyOptions(BaseDummyOptions):
|
||||
"""Options for generating dummy audio data during profiling."""
|
||||
length: Optional[int] = Field(None, gt=0)
|
||||
|
||||
|
||||
MMEncoderTPMode = Literal["weights", "data"]
|
||||
MMCacheType = Literal["shm", "lru"]
|
||||
DummyOptions = Union[BaseDummyOptions, VideoDummyOptions, ImageDummyOptions,
|
||||
AudioDummyOptions]
|
||||
|
||||
|
||||
@config
|
||||
@@ -20,12 +50,22 @@ MMCacheType = Literal["shm", "lru"]
|
||||
class MultiModalConfig:
|
||||
"""Controls the behavior of multimodal models."""
|
||||
|
||||
limit_per_prompt: dict[str, int] = field(default_factory=dict)
|
||||
"""The maximum number of input items allowed per prompt for each modality.
|
||||
Defaults to 1 (V0) or 999 (V1) for each modality.
|
||||
limit_per_prompt: dict[str, DummyOptions] = field(default_factory=dict)
|
||||
"""The maximum number of input items and options allowed per
|
||||
prompt for each modality.
|
||||
Defaults to 999 for each modality.
|
||||
|
||||
For example, to allow up to 16 images and 2 videos per prompt:
|
||||
`{"image": 16, "video": 2}`"""
|
||||
Legacy format (count only):
|
||||
{"image": 16, "video": 2}
|
||||
|
||||
Configurable format (with options):
|
||||
{"video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
|
||||
"image": {"count": 5, "width": 512, "height": 512}}
|
||||
|
||||
Mixed format (combining both):
|
||||
{"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512,
|
||||
"height": 512}}
|
||||
"""
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
"""Additional args passed to process media inputs, keyed by modalities.
|
||||
For example, to set num_frames for video, set
|
||||
@@ -84,6 +124,27 @@ class MultiModalConfig:
|
||||
from each video to be pruned.
|
||||
"""
|
||||
|
||||
@field_validator("limit_per_prompt", mode="before")
|
||||
@classmethod
|
||||
def _validate_limit_per_prompt(
|
||||
cls, value: dict[str, Union[int,
|
||||
dict[str,
|
||||
int]]]) -> dict[str, DummyOptions]:
|
||||
for k, v in value.items():
|
||||
# Handle legacy format where only count is specified
|
||||
if isinstance(v, int):
|
||||
v = {"count": v}
|
||||
# Convert to the appropriate DummyOptions subclass
|
||||
if k == "video":
|
||||
value[k] = VideoDummyOptions(**v)
|
||||
elif k == "image":
|
||||
value[k] = ImageDummyOptions(**v)
|
||||
elif k == "audio":
|
||||
value[k] = AudioDummyOptions(**v)
|
||||
else:
|
||||
value[k] = BaseDummyOptions(**v)
|
||||
return value
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
@@ -106,12 +167,22 @@ class MultiModalConfig:
|
||||
def get_limit_per_prompt(self, modality: str) -> int:
|
||||
"""
|
||||
Get the maximum number of input items allowed per prompt
|
||||
for the given modality.
|
||||
for the given modality (backward compatible).
|
||||
"""
|
||||
return self.limit_per_prompt.get(
|
||||
modality,
|
||||
999 if envs.VLLM_USE_V1 else 1,
|
||||
)
|
||||
limit_data = self.limit_per_prompt.get(modality)
|
||||
|
||||
if limit_data is None:
|
||||
# Unspecified modality is set to 999 by default
|
||||
return 999
|
||||
return limit_data.count
|
||||
|
||||
def get_dummy_options(self, modality: str) -> Optional[BaseDummyOptions]:
|
||||
"""
|
||||
Get the configurable dummy data options for a modality.
|
||||
Returns None if no options are configured for this modality.
|
||||
"""
|
||||
# All values are now DummyOptions after normalization
|
||||
return self.limit_per_prompt.get(modality)
|
||||
|
||||
def merge_mm_processor_kwargs(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user