[Core][VLM] Add precise multi-modal placeholder tracking (#8346)

Signed-off-by: Peter Salas <peter@fixie.ai>
This commit is contained in:
Peter Salas
2024-11-01 16:21:10 -07:00
committed by GitHub
parent d151fde834
commit 6c0b7f548d
53 changed files with 913 additions and 281 deletions

View File

@@ -5,8 +5,8 @@ from unittest.mock import patch
import pytest
import torch
from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs
from vllm.inputs.registry import InputRegistry
from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
InputRegistry, token_inputs)
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
@@ -56,7 +56,7 @@ def use_dummy_data_mock():
num_crops=DEFAULT_NUM_CROPS):
seq_data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
return seq_data, None
return DummyData(seq_data, None)
with patch(
"vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
@@ -177,9 +177,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
dummy_data = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == expected_seq_count
assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count
@pytest.mark.parametrize(
@@ -206,9 +206,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
# NOTE: seq_len is thrown away here since this will leverage the
# default dummy data factory that we have patched in, whose seq
# len is solely dependent on the value of the mm_processor_kwargs.
seq_data, _ = dummy_registry.dummy_data_for_profiling(
dummy_data = dummy_registry.dummy_data_for_profiling(
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
### Test overrides for the max token count per multimodal instance