[Multi Modal] Configurable MM Profiling (#25631)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -12,6 +12,8 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from PIL import Image
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions,
|
||||
ImageDummyOptions, VideoDummyOptions)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
|
||||
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
|
||||
from vllm.multimodal.inputs import MultiModalInputs
|
||||
@@ -112,12 +114,26 @@ def _test_processing_correctness(
|
||||
|
||||
processing_info = factories.info(ctx)
|
||||
supported_mm_limits = processing_info.get_supported_mm_limits()
|
||||
limit_mm_per_prompt = {
|
||||
# Keep integer limits for local data generation
|
||||
limit_mm_per_prompt_ints = {
|
||||
modality: 3 if limit is None else limit
|
||||
for modality, limit in supported_mm_limits.items()
|
||||
}
|
||||
|
||||
model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
|
||||
def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions:
|
||||
if modality == "video":
|
||||
return VideoDummyOptions(count=count)
|
||||
if modality == "image":
|
||||
return ImageDummyOptions(count=count)
|
||||
if modality == "audio":
|
||||
return AudioDummyOptions(count=count)
|
||||
return BaseDummyOptions(count=count)
|
||||
|
||||
# Assign normalized DummyOptions to the model config
|
||||
model_config.get_multimodal_config().limit_per_prompt = {
|
||||
modality: _to_dummy_options(modality, count)
|
||||
for modality, count in limit_mm_per_prompt_ints.items()
|
||||
}
|
||||
|
||||
baseline_processor = factories.build_processor(ctx, cache=None)
|
||||
cached_processor = factories.build_processor(ctx, cache=cache)
|
||||
@@ -150,7 +166,7 @@ def _test_processing_correctness(
|
||||
k:
|
||||
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
|
||||
for _ in range(rng.randint(limit + 1))]
|
||||
for k, limit in limit_mm_per_prompt.items()
|
||||
for k, limit in limit_mm_per_prompt_ints.items()
|
||||
}
|
||||
|
||||
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
|
||||
|
||||
Reference in New Issue
Block a user