[Core][VLM] Add precise multi-modal placeholder tracking (#8346)
Signed-off-by: Peter Salas <peter@fixie.ai>
This commit is contained in:
@@ -5,8 +5,8 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs
|
||||
from vllm.inputs.registry import InputRegistry
|
||||
from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
|
||||
InputRegistry, token_inputs)
|
||||
from vllm.multimodal import MultiModalRegistry
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
|
||||
@@ -56,7 +56,7 @@ def use_dummy_data_mock():
|
||||
num_crops=DEFAULT_NUM_CROPS):
|
||||
seq_data = SequenceData(
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
|
||||
return seq_data, None
|
||||
return DummyData(seq_data, None)
|
||||
|
||||
with patch(
|
||||
"vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
|
||||
@@ -177,9 +177,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
|
||||
# NOTE: seq_len is thrown away here since this will leverage the
|
||||
# default dummy data factory that we have patched in, whose seq
|
||||
# len is solely dependent on the value of the mm_processor_kwargs.
|
||||
seq_data, _ = dummy_registry.dummy_data_for_profiling(
|
||||
dummy_data = dummy_registry.dummy_data_for_profiling(
|
||||
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
|
||||
assert len(seq_data.prompt_token_ids) == expected_seq_count
|
||||
assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -206,9 +206,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
|
||||
# NOTE: seq_len is thrown away here since this will leverage the
|
||||
# default dummy data factory that we have patched in, whose seq
|
||||
# len is solely dependent on the value of the mm_processor_kwargs.
|
||||
seq_data, _ = dummy_registry.dummy_data_for_profiling(
|
||||
dummy_data = dummy_registry.dummy_data_for_profiling(
|
||||
ctx.model_config, seq_len=-1, mm_registry=mm_registry)
|
||||
assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
|
||||
assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
|
||||
|
||||
|
||||
### Test overrides for the max token count per multimodal instance
|
||||
|
||||
Reference in New Issue
Block a user