[VLM] Reorganize profiling/processing-related code (#11812)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-01-08 18:59:58 +08:00 (committed by GitHub)
Parent: f12141170a
Commit: 2a0596bc48
23 changed files with 833 additions and 760 deletions

Changed file shown below: tests/multimodal/test_processing.py (the other 22 files are omitted here).

@@ -10,12 +10,17 @@ from PIL import Image
 from vllm.config import ModelConfig
 from vllm.inputs import InputProcessingContext
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
-                                        _PlaceholderInfo, find_mm_placeholders,
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.multimodal.processing import (PlaceholderInfo, ProcessingCache,
+                                        PromptReplacement,
+                                        find_mm_placeholders,
                                         find_text_matches, find_token_matches,
                                         iter_token_matches,
                                         replace_text_matches,
                                         replace_token_matches)
+# yapf: enable
+from vllm.multimodal.profiling import MultiModalProfiler
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import full_groupby
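
This import hunk carries the two structural changes of the commit: _PlaceholderInfo becomes the public PlaceholderInfo, and profiling moves out of vllm.multimodal.processing into a dedicated vllm.multimodal.profiling module. A minimal sketch of the import surface callers see after this commit, using nothing beyond what the hunk itself introduces:

# After this commit: the placeholder dataclass is public, and the
# profiler lives in its own module rather than inside processing.
from vllm.multimodal.processing import PlaceholderInfo, ProcessingCache
from vllm.multimodal.profiling import MultiModalProfiler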
@@ -431,7 +436,7 @@ def test_find_replace_tokens(
             [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
             {
                 "pattern_1": [
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality="pattern_1",
                         item_idx=0,
                         start_idx=6,
@@ -445,13 +450,13 @@ def test_find_replace_tokens(
             [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
             {
                 "pattern_1": [
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality="pattern_1",
                         item_idx=0,
                         start_idx=1,
                         replacement=[32000, 32000],
                     ),
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality="pattern_1",
                         item_idx=1,
                         start_idx=5,
@@ -459,7 +464,7 @@ def test_find_replace_tokens(
                     ),
                 ],
                 "pattern_3": [
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality="pattern_3",
                         item_idx=0,
                         start_idx=7,
@@ -472,13 +477,13 @@ def test_find_replace_tokens(
             [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
             {
                 "pattern_1": [
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality="pattern_1",
                         item_idx=0,
                         start_idx=1,
                         replacement=[32000, 32000],
                     ),
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality="pattern_1",
                         item_idx=1,
                         start_idx=3,
@@ -486,7 +491,7 @@ def test_find_replace_tokens(
                     ),
                 ],
                 "pattern_3": [
-                    _PlaceholderInfo(
+                    PlaceholderInfo(
                         modality="pattern_3",
                         item_idx=0,
                         start_idx=6,
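
The hunks above are the mechanical half of the rename: every _PlaceholderInfo literal in the expected test outputs becomes PlaceholderInfo, with all field values unchanged. Reading the test data, each placeholder records the modality it belongs to, the index of the item within that modality, the offset where its replacement begins in the token sequence, and the replacement token ids. A sketch mirroring one expected value from the second case above:

# Sketch only: same field values as the first "pattern_1" entry in the
# second test case; PlaceholderInfo is imported as shown earlier.
placeholder = PlaceholderInfo(
    modality="pattern_1",        # modality this placeholder belongs to
    item_idx=0,                  # index of the item within that modality
    start_idx=1,                 # offset of the replacement in the token ids
    replacement=[32000, 32000],  # token ids standing in for the item
)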
@@ -577,19 +582,15 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
         revision=None,
         limit_mm_per_prompt=limit_mm_per_prompt,
     )
 
-    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
-    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
-    ctx = InputProcessingContext(
+    processor = MULTIMODAL_REGISTRY.create_processor(
         model_config,
         tokenizer=cached_get_tokenizer(model_config.tokenizer),
     )
-
-    processor = processor_factory(ctx, cache=None)
-    profiler = processor.profiling_info
+    profiler = MultiModalProfiler(processor)
 
     mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
-    profiler.get_supported_mm_limits = mock_supported_mm_limits
+    processor.info.get_supported_mm_limits = mock_supported_mm_limits
 
     if is_valid:
         exc_ctx = nullcontext()
@@ -597,7 +598,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
         exc_ctx = pytest.raises(ValueError, match="this model only supports")
 
     with exc_ctx:
-        profiler.get_mm_limits()
+        profiler.get_dummy_data(model_config.max_model_len)
 
 
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
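
Together, the two hunks above show the new profiling flow: instead of reaching into the registry's private _processor_factories and reading processor.profiling_info, the test builds a processor through MULTIMODAL_REGISTRY.create_processor and wraps it in a standalone MultiModalProfiler; model-specific limits now live on processor.info. A hedged sketch of that flow, composed only of calls that appear in this diff and assuming model_config is set up as in the test:

# Sketch: build a processor via the registry, then profile it.
processor = MULTIMODAL_REGISTRY.create_processor(
    model_config,
    tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
profiler = MultiModalProfiler(processor)

# Per-modality limits for the prompt.
mm_limits = profiler.get_mm_limits()
# Dummy inputs for memory profiling; the limit validation exercised by
# the test above fires inside this call.
dummy_data = profiler.get_dummy_data(model_config.max_model_len)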
@@ -620,16 +621,12 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
         revision=None,
         limit_mm_per_prompt=limit_mm_per_prompt,
     )
 
-    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
-    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
-    ctx = InputProcessingContext(
+    processor = MULTIMODAL_REGISTRY.create_processor(
         model_config,
         tokenizer=cached_get_tokenizer(model_config.tokenizer),
     )
 
-    processor = processor_factory(ctx, cache=None)
-
     rng = np.random.RandomState(0)
     image = _rand_img(rng, min_wh=128, max_wh=256)
     if num_images == 0:
if num_images == 0:
@@ -681,9 +678,9 @@ def _test_processing_cache_correctness(
         hf_overrides=hf_overrides,
         limit_mm_per_prompt=limit_mm_per_prompt,
     )
 
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
-    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
+    factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
     ctx = InputProcessingContext(
         model_config,
         tokenizer=cached_get_tokenizer(model_config.tokenizer),
@@ -691,8 +688,9 @@ def _test_processing_cache_correctness(
     # Ensure that it can fit all of the data
     cache = ProcessingCache(capacity=1 << 30)
 
-    baseline_processor = processor_factory(ctx, cache=None)
-    cached_processor = processor_factory(ctx, cache=cache)
+    baseline_processor = factories.build_processor(ctx, cache=None)
+    cached_processor = factories.build_processor(ctx, cache=cache)
+    dummy_inputs = baseline_processor.dummy_inputs
 
     rng = np.random.RandomState(0)
 
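
The factory lookup itself is unchanged here, but construction now goes through an explicit build_processor method on the factories object, which makes the cached/uncached pairing explicit. A sketch under the same assumptions as the test (ctx is the InputProcessingContext built above):

# Capacity of 1 << 30 so the cache can hold everything, per the
# test's "fit all of the data" comment.
cache = ProcessingCache(capacity=1 << 30)

# The same factories object builds both variants; only the cache differs.
baseline_processor = factories.build_processor(ctx, cache=None)
cached_processor = factories.build_processor(ctx, cache=cache)
dummy_inputs = baseline_processor.dummy_inputs  # used below to generate prompts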
@@ -724,7 +722,7 @@ def _test_processing_cache_correctness(
     }
 
     mm_counts = {k: len(vs) for k, vs in mm_data.items()}
-    prompt = baseline_processor.profiling_info.get_dummy_processor_inputs(
+    prompt = dummy_inputs.get_dummy_processor_inputs(
         model_config.max_model_len,
         mm_counts,
     ).prompt_text
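
The last hunk swaps the source of dummy prompts: they now come from the processor's dummy_inputs object rather than profiling_info. A sketch of the call as the test uses it (argument values as above; mm_counts maps each modality to its item count):

# Sketch: dummy_inputs was taken from baseline_processor earlier.
dummy = dummy_inputs.get_dummy_processor_inputs(
    model_config.max_model_len,  # target sequence length
    mm_counts,                   # e.g. {"image": 2}
)
prompt = dummy.prompt_text  # the generated dummy prompt string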