[V1] Enable multi-input by default (#15799)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Author: Cyrus Leung
Date: 2025-04-12 16:52:39 +08:00
Committed by: GitHub
Parent: f069f3ea74
Commit: d9fc8cd9da

21 changed files with 214 additions and 105 deletions

View File

@@ -2667,14 +2667,20 @@ class MultiModalConfig:
                                 usedforsecurity=False).hexdigest()
         return hash_str
 
+    def get_default_limit_per_prompt(self) -> int:
+        """
+        Return the default number of input items allowed per prompt
+        for any modality if not specified by the user.
+        """
+        return 999 if envs.VLLM_USE_V1 else 1
+
     def get_limit_per_prompt(self, modality: str) -> int:
         """
         Get the maximum number of input items allowed per prompt
         for the given modality.
-
-        If not set by the user, this defaults to `1`.
         """
-        return self.limit_per_prompt.get(modality, 1)
+        default = self.get_default_limit_per_prompt()
+        return self.limit_per_prompt.get(modality, default)
 
     # TODO: Add configs to init vision tower or not.
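For context, the effective limit now resolves in two steps: an explicit `--limit-mm-per-prompt` entry always wins, and the engine-wide default only fills the gaps. A minimal standalone sketch of that logic (plain Python, independent of vLLM; `use_v1` stands in for the `VLLM_USE_V1` environment flag):

    from typing import Mapping

    def default_limit_per_prompt(use_v1: bool) -> int:
        # V1 effectively removes the cap; V0 keeps the old one-item default.
        return 999 if use_v1 else 1

    def limit_per_prompt(user_limits: Mapping[str, int], modality: str,
                         use_v1: bool) -> int:
        # An explicit user setting always wins over the default.
        return user_limits.get(modality, default_limit_per_prompt(use_v1))

    assert limit_per_prompt({}, "image", use_v1=True) == 999
    assert limit_per_prompt({}, "image", use_v1=False) == 1
    assert limit_per_prompt({"image": 4}, "image", use_v1=True) == 4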

View File

@@ -671,13 +671,13 @@ class EngineArgs:
             type=nullable_kvs,
             default=EngineArgs.limit_mm_per_prompt,
             # The default value is given in
-            # MultiModalConfig.get_limit_per_prompt
+            # MultiModalConfig.get_default_limit_per_prompt
             help=('For each multimodal plugin, limit how many '
                   'input instances to allow for each prompt. '
                   'Expects a comma-separated list of items, '
                   'e.g.: `image=16,video=2` allows a maximum of 16 '
-                  'images and 2 videos per prompt. Defaults to 1 for '
-                  'each modality.'))
+                  'images and 2 videos per prompt. Defaults to '
+                  '1 (V0) or 999 (V1) for each modality.'))
         parser.add_argument(
             '--mm-processor-kwargs',
             default=None,
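The flag keeps its comma-separated `key=value` syntax (e.g. `--limit-mm-per-prompt image=16,video=2`); only the fallback value changes. The same cap can be set from Python through the engine arguments; a brief sketch (the model name is purely illustrative):

    from vllm import LLM

    # Explicit per-modality caps; any modality not listed here falls back
    # to the new default (1 on V0, 999 on V1).
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",  # illustrative
        limit_mm_per_prompt={"image": 16, "video": 2},
    )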

View File

@@ -35,7 +35,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
-from vllm.multimodal import MultiModalDataDict
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
 from vllm.multimodal.utils import MediaConnector
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@@ -452,8 +452,6 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         self._model_config = model_config
         self._tokenizer = tokenizer
-        self._allowed_items = (model_config.multimodal_config.limit_per_prompt
-                               if model_config.multimodal_config else {})
 
         self._items_by_modality = defaultdict[str, list[_T]](list)
@@ -465,6 +463,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
     def allowed_local_media_path(self):
         return self._model_config.allowed_local_media_path
 
+    @property
+    def mm_registry(self):
+        return MULTIMODAL_REGISTRY
+
     @staticmethod
     @cache
     def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
@@ -540,12 +542,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         Add a multi-modal item to the current prompt and returns the
         placeholder string to use, if any.
         """
-        allowed_count = self._allowed_items.get(modality, 1)
+        mm_registry = self.mm_registry
+        model_config = self.model_config
+
+        input_modality = modality.replace("_embeds", "")
+
+        if mm_registry.has_processor(model_config):
+            mm_processor = mm_registry.create_processor(model_config)
+            allowed_counts = mm_processor.info.get_allowed_mm_limits()
+            allowed_count = allowed_counts.get(input_modality, 0)
+        else:
+            mm_config = model_config.multimodal_config
+            if mm_config is None:
+                msg = "This model does not support multi-modal inputs"
+                raise ValueError(msg)
+
+            allowed_count = mm_config.get_limit_per_prompt(input_modality)
+
         current_count = len(self._items_by_modality[modality]) + 1
         if current_count > allowed_count:
             raise ValueError(
                 f"At most {allowed_count} {modality}(s) may be provided in "
-                "one request.")
+                "one request. You can set `--limit-mm-per-prompt` to "
+                "increase this limit if the model supports it.")
 
         self._items_by_modality[modality].append(item)
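Two details of the enforcement above are easy to miss: embedding inputs such as `image_embeds` are mapped back to their base modality before the limit lookup, and the check runs before the item is stored. A condensed standalone sketch of that flow (names simplified from the tracker above):

    from collections import defaultdict

    def check_and_add(items_by_modality, modality: str, item,
                      allowed_counts: dict) -> None:
        # "image_embeds" and "image" draw from the same per-prompt budget.
        input_modality = modality.replace("_embeds", "")
        allowed_count = allowed_counts.get(input_modality, 0)

        current_count = len(items_by_modality[modality]) + 1
        if current_count > allowed_count:
            raise ValueError(
                f"At most {allowed_count} {modality}(s) may be provided in "
                "one request.")
        items_by_modality[modality].append(item)

    tracker = defaultdict(list)
    check_and_add(tracker, "image", "img-0", {"image": 1})  # accepted
    # A second image would now raise ValueError.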

View File

@@ -126,7 +126,7 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
     def _parse_audio_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
-    ) -> ModalityDataItems[Any, Any]:
+    ) -> Optional[ModalityDataItems[Any, Any]]:
         if isinstance(data, dict):
             return MiniCPMOAudioEmbeddingItems(
                 data,

View File

@@ -290,7 +290,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
     def _parse_image_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
-    ) -> ModalityDataItems[Any, Any]:
+    ) -> Optional[ModalityDataItems[Any, Any]]:
         if isinstance(data, dict):
             return MiniCPMVImageEmbeddingItems(
                 data,
@@ -302,7 +302,7 @@ class MiniCPMVMultiModalDataParser(MultiModalDataParser):
     def _parse_video_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
-    ) -> ModalityDataItems[Any, Any]:
+    ) -> Optional[ModalityDataItems[Any, Any]]:
         if isinstance(data, dict):
             return MiniCPMVVideoEmbeddingItems(
                 data,

View File

@@ -720,7 +720,7 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
     def _parse_image_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
-    ) -> ModalityDataItems[Any, Any]:
+    ) -> Optional[ModalityDataItems[Any, Any]]:
         if isinstance(data, dict):
             return DictEmbeddingItems(
                 data,
@@ -734,7 +734,7 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
     def _parse_video_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
-    ) -> ModalityDataItems[Any, Any]:
+    ) -> Optional[ModalityDataItems[Any, Any]]:
         if isinstance(data, dict):
             return DictEmbeddingItems(
                 data,
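These three model files make the same mechanical fix: the overridden `_parse_*_data` methods now declare `Optional[...]`, matching a contract in which a parser may return `None` when it has nothing to handle for that modality. A generic sketch of the pattern (illustrative names, not vLLM's actual classes):

    from typing import Any, Optional, Union

    def parse_image_data(
            data: Union[dict, list],  # embeddings dict or raw item list
    ) -> Optional[Any]:
        # Returning None signals "no items for this modality" to the caller
        # rather than raising or returning an empty container.
        if isinstance(data, dict):
            return ("embedding_items", data)
        if not data:
            return None
        return ("raw_items", data)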

View File

@@ -1034,6 +1034,20 @@ class BaseProcessingInfo:
         """
         raise NotImplementedError
 
+    def get_allowed_mm_limits(self) -> Mapping[str, int]:
+        """Return the maximum allowed number of items for each modality."""
+        supported_mm_limits = self.get_supported_mm_limits()
+        mm_config = self.ctx.get_mm_config()
+
+        allowed_limits = dict[str, int]()
+        for modality, supported_limit in supported_mm_limits.items():
+            user_limit = mm_config.get_limit_per_prompt(modality)
+
+            allowed_limits[modality] = (user_limit if supported_limit is None
+                                        else min(user_limit, supported_limit))
+
+        return allowed_limits
+
 
 _I = TypeVar("_I", bound=BaseProcessingInfo)
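The merge rule: if the model declares no hard ceiling (`supported_limit is None`), the user limit passes through unchanged; otherwise the effective limit is the minimum of the two. A standalone sketch with concrete numbers (`user_limit` stands in for `mm_config.get_limit_per_prompt(modality)`):

    from typing import Mapping, Optional

    def allowed_mm_limits(supported: Mapping[str, Optional[int]],
                          user_limit: int) -> dict:
        # None means "the model imposes no ceiling"; otherwise take the min.
        return {
            modality: user_limit if cap is None else min(user_limit, cap)
            for modality, cap in supported.items()
        }

    # With the V1 default of 999, a single-image model still caps at 1:
    assert (allowed_mm_limits({"image": 1, "audio": None}, 999)
            == {"image": 1, "audio": 999})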
@@ -1087,14 +1101,24 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         before passing them to :meth:`_get_hf_mm_data`.
         """
         mm_items = self.data_parser.parse_mm_data(mm_data)
-        mm_config = self.info.ctx.get_mm_config()
 
+        supported_mm_limits = self.info.get_supported_mm_limits()
+        allowed_mm_limits = self.info.get_allowed_mm_limits()
+
         for modality, items in mm_items.items():
-            limit = mm_config.get_limit_per_prompt(modality)
-            if len(items) > limit:
+            supported_limit = supported_mm_limits.get(modality, 0)
+            allowed_limit = allowed_mm_limits.get(modality, 0)
+            num_items = len(items)
+
+            if supported_limit is not None and num_items > supported_limit:
                 raise ValueError(
-                    f"You set {modality}={limit} (or defaulted to 1) in "
-                    f"`--limit-mm-per-prompt`, but passed {len(items)} "
+                    f"The model only supports at most {supported_limit} "
+                    f"{modality} items, but you passed {num_items} "
                     f"{modality} items in the same prompt.")
+
+            if num_items > allowed_limit:
+                raise ValueError(
+                    f"You set or defaulted to {modality}={allowed_limit} "
+                    f"in `--limit-mm-per-prompt`, but passed {num_items} "
+                    f"{modality} items in the same prompt.")
 
         return mm_items
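With the permissive V1 default, splitting the check in two keeps the errors meaningful: exceeding the model's hard ceiling and exceeding a user-configured cap now produce different messages. Hypothetical numbers for a single-image model on V1:

    supported_limit = 1  # from get_supported_mm_limits()
    allowed_limit = 1    # min(999 default, 1) via get_allowed_mm_limits()
    num_items = 3

    # The first check fires ("The model only supports at most 1 ..."):
    assert supported_limit is not None and num_items > supported_limit
    # The second check ("You set or defaulted to ...") is what fires for
    # models without a hard ceiling when a user-set cap is exceeded.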

View File

@@ -162,23 +162,7 @@ class MultiModalProfiler(Generic[_I]):
         return self.processor.dummy_inputs
 
     def get_mm_limits(self) -> Mapping[str, int]:
-        mm_config = self.processing_info.ctx.get_mm_config()
-        supported_mm_limits = self.processing_info.get_supported_mm_limits()
-
-        mm_limits = {
-            modality: mm_config.get_limit_per_prompt(modality)
-            for modality in supported_mm_limits
-        }
-
-        for modality, supported_limit in supported_mm_limits.items():
-            limit = mm_limits[modality]
-            if supported_limit is not None and supported_limit < limit:
-                raise ValueError(
-                    f"You set {modality}={limit} (or defaulted to 1) in "
-                    f"`--limit-mm-per-prompt`, but this model only supports "
-                    f"at most {supported_limit} {modality} items.")
-
-        return mm_limits
+        return self.processing_info.get_allowed_mm_limits()
 
     def _get_dummy_mm_inputs(
         self,

View File

@@ -265,8 +265,10 @@ class MultiModalRegistry:
 
         return profiler.get_mm_max_tokens(
             seq_len,
-            {modality: 1
-             for modality in mm_limits},
+            {
+                modality: 1
+                for modality, limit in mm_limits.items() if limit > 0
+            },
         )
 
         return {
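The reworked comprehension drops modalities whose effective limit is 0 before profiling, so a modality disabled by the user (e.g. `image=0` in `--limit-mm-per-prompt`) no longer contributes dummy items. A quick illustration:

    mm_limits = {"image": 0, "video": 2}

    # Old behaviour profiled one item for every modality:
    assert {m: 1 for m in mm_limits} == {"image": 1, "video": 1}

    # New behaviour skips disabled modalities:
    assert {m: 1 for m, limit in mm_limits.items() if limit > 0} == {"video": 1}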

View File

@@ -264,7 +264,7 @@ fetch_video = global_media_connector.fetch_video
 
 def encode_audio_base64(
     audio: np.ndarray,
-    sampling_rate: int,
+    sampling_rate: float,
 ) -> str:
     """Encode audio as base64."""
     audio_io = AudioMediaIO()
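Widening `sampling_rate` to `float` accommodates sources whose rates are not integral (resampling ratios, for instance); writers can round at the boundary. A minimal sketch of base64-encoding PCM audio without vLLM's `AudioMediaIO` (stdlib plus numpy, illustrative only):

    import base64
    import io
    import wave

    import numpy as np

    def encode_audio_base64(audio: np.ndarray, sampling_rate: float) -> str:
        # Convert float samples in [-1, 1] to 16-bit PCM.
        pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype("<i2")
        buf = io.BytesIO()
        with wave.open(buf, "wb") as wav:
            wav.setnchannels(1)
            wav.setsampwidth(2)
            # wave accepts a float rate and rounds it to an integer.
            wav.setframerate(sampling_rate)
            wav.writeframes(pcm.tobytes())
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    print(len(encode_audio_base64(np.zeros(16000), 16000.0)))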