[VLM] Reorganize profiling/processing-related code (#11812)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -10,12 +10,17 @@ from PIL import Image
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
|
||||
_PlaceholderInfo, find_mm_placeholders,
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.multimodal.processing import (PlaceholderInfo, ProcessingCache,
|
||||
PromptReplacement,
|
||||
find_mm_placeholders,
|
||||
find_text_matches, find_token_matches,
|
||||
iter_token_matches,
|
||||
replace_text_matches,
|
||||
replace_token_matches)
|
||||
# yapf: enable
|
||||
from vllm.multimodal.profiling import MultiModalProfiler
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import full_groupby
|
||||
@@ -431,7 +436,7 @@ def test_find_replace_tokens(
|
||||
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
{
|
||||
"pattern_1": [
|
||||
_PlaceholderInfo(
|
||||
PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=0,
|
||||
start_idx=6,
|
||||
@@ -445,13 +450,13 @@ def test_find_replace_tokens(
|
||||
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
|
||||
{
|
||||
"pattern_1": [
|
||||
_PlaceholderInfo(
|
||||
PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=0,
|
||||
start_idx=1,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=1,
|
||||
start_idx=5,
|
||||
@@ -459,7 +464,7 @@ def test_find_replace_tokens(
|
||||
),
|
||||
],
|
||||
"pattern_3": [
|
||||
_PlaceholderInfo(
|
||||
PlaceholderInfo(
|
||||
modality="pattern_3",
|
||||
item_idx=0,
|
||||
start_idx=7,
|
||||
@@ -472,13 +477,13 @@ def test_find_replace_tokens(
|
||||
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
|
||||
{
|
||||
"pattern_1": [
|
||||
_PlaceholderInfo(
|
||||
PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=0,
|
||||
start_idx=1,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=1,
|
||||
start_idx=3,
|
||||
@@ -486,7 +491,7 @@ def test_find_replace_tokens(
|
||||
),
|
||||
],
|
||||
"pattern_3": [
|
||||
_PlaceholderInfo(
|
||||
PlaceholderInfo(
|
||||
modality="pattern_3",
|
||||
item_idx=0,
|
||||
start_idx=6,
|
||||
@@ -577,19 +582,15 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
|
||||
revision=None,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
)
|
||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||
|
||||
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
||||
ctx = InputProcessingContext(
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(
|
||||
model_config,
|
||||
tokenizer=cached_get_tokenizer(model_config.tokenizer),
|
||||
)
|
||||
|
||||
processor = processor_factory(ctx, cache=None)
|
||||
profiler = processor.profiling_info
|
||||
profiler = MultiModalProfiler(processor)
|
||||
|
||||
mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
|
||||
profiler.get_supported_mm_limits = mock_supported_mm_limits
|
||||
processor.info.get_supported_mm_limits = mock_supported_mm_limits
|
||||
|
||||
if is_valid:
|
||||
exc_ctx = nullcontext()
|
||||
@@ -597,7 +598,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
|
||||
exc_ctx = pytest.raises(ValueError, match="this model only supports")
|
||||
|
||||
with exc_ctx:
|
||||
profiler.get_mm_limits()
|
||||
profiler.get_dummy_data(model_config.max_model_len)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
|
||||
@@ -620,16 +621,12 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
|
||||
revision=None,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
)
|
||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||
|
||||
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
||||
ctx = InputProcessingContext(
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(
|
||||
model_config,
|
||||
tokenizer=cached_get_tokenizer(model_config.tokenizer),
|
||||
)
|
||||
|
||||
processor = processor_factory(ctx, cache=None)
|
||||
|
||||
rng = np.random.RandomState(0)
|
||||
image = _rand_img(rng, min_wh=128, max_wh=256)
|
||||
if num_images == 0:
|
||||
@@ -681,9 +678,9 @@ def _test_processing_cache_correctness(
|
||||
hf_overrides=hf_overrides,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
)
|
||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||
|
||||
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
||||
ctx = InputProcessingContext(
|
||||
model_config,
|
||||
tokenizer=cached_get_tokenizer(model_config.tokenizer),
|
||||
@@ -691,8 +688,9 @@ def _test_processing_cache_correctness(
|
||||
# Ensure that it can fit all of the data
|
||||
cache = ProcessingCache(capacity=1 << 30)
|
||||
|
||||
baseline_processor = processor_factory(ctx, cache=None)
|
||||
cached_processor = processor_factory(ctx, cache=cache)
|
||||
baseline_processor = factories.build_processor(ctx, cache=None)
|
||||
cached_processor = factories.build_processor(ctx, cache=cache)
|
||||
dummy_inputs = baseline_processor.dummy_inputs
|
||||
|
||||
rng = np.random.RandomState(0)
|
||||
|
||||
@@ -724,7 +722,7 @@ def _test_processing_cache_correctness(
|
||||
}
|
||||
|
||||
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
|
||||
prompt = baseline_processor.profiling_info.get_dummy_processor_inputs(
|
||||
prompt = dummy_inputs.get_dummy_processor_inputs(
|
||||
model_config.max_model_len,
|
||||
mm_counts,
|
||||
).prompt_text
|
||||
|
||||
Reference in New Issue
Block a user