[1/N] Initial prototype for multi-modal processor (#10044)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
from array import array
|
||||
from typing import Mapping
|
||||
from typing import Callable, Dict, Mapping, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
|
||||
InputRegistry, token_inputs)
|
||||
InputRegistry, ProcessorInputs, token_inputs)
|
||||
from vllm.multimodal import MultiModalRegistry
|
||||
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
|
||||
|
||||
@@ -34,10 +34,9 @@ def use_processor_mock():
|
||||
inputs: DecoderOnlyInputs,
|
||||
*,
|
||||
num_crops=DEFAULT_NUM_CROPS):
|
||||
# For testing purposes, we don't worry about the llm inputs / return
|
||||
# type validation, and just return the value of the kwarg that we
|
||||
# clobber.
|
||||
return num_crops
|
||||
# For testing purposes, we don't worry about the prompt
|
||||
return token_inputs(prompt_token_ids=[],
|
||||
mm_processor_kwargs={"num_crops": num_crops})
|
||||
|
||||
with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
|
||||
return_value=custom_processor):
|
||||
@@ -109,6 +108,21 @@ def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
|
||||
return init_kwargs, inference_kwargs, expected_seq_count
|
||||
|
||||
|
||||
def _get_processed_num_crops(
    processor: Callable[[ProcessorInputs], ProcessorInputs],
    inference_kwargs: Optional[Dict[str, int]],
) -> int:
    """Run ``processor`` on an empty token prompt and return its ``num_crops``.

    The inputs handed to ``processor`` are built with ``token_inputs`` using
    an empty token id list, an empty prompt string, and ``inference_kwargs``
    as the ``mm_processor_kwargs``.

    Args:
        processor: Callable that maps processor inputs to processed inputs.
        inference_kwargs: Per-request multi-modal processor kwargs to forward,
            or ``None``.

    Returns:
        The ``"num_crops"`` entry of the ``"mm_processor_kwargs"`` found in
        the processed inputs.
    """
    raw_inputs = token_inputs(prompt_token_ids=[],
                              prompt="",
                              mm_processor_kwargs=inference_kwargs)
    outputs = processor(raw_inputs)

    # The processed result must be token-type inputs carrying processor kwargs.
    assert "type" in outputs
    assert outputs["type"] == "token"
    assert "mm_processor_kwargs" in outputs

    returned_kwargs = outputs["mm_processor_kwargs"]
    return returned_kwargs["num_crops"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("init_num_crops,inference_num_crops", [
|
||||
(None, None),
|
||||
(NUM_CROPS_OVERRIDE, None),
|
||||
@@ -124,10 +138,8 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops,
|
||||
|
||||
ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
|
||||
processor = dummy_registry.create_input_processor(ctx.model_config)
|
||||
num_crops_val = processor(
|
||||
token_inputs(prompt_token_ids=[],
|
||||
prompt="",
|
||||
mm_processor_kwargs=inference_kwargs))
|
||||
num_crops_val = _get_processed_num_crops(processor, inference_kwargs)
|
||||
|
||||
assert num_crops_val == expected_seq_count
|
||||
|
||||
|
||||
@@ -153,10 +165,7 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
|
||||
|
||||
processor = dummy_registry.create_input_processor(ctx.model_config)
|
||||
# Should filter out the inference time kwargs
|
||||
num_crops_val = processor(
|
||||
token_inputs(prompt_token_ids=[],
|
||||
prompt="",
|
||||
mm_processor_kwargs=mm_processor_kwargs))
|
||||
num_crops_val = _get_processed_num_crops(processor, mm_processor_kwargs)
|
||||
assert num_crops_val == DEFAULT_NUM_CROPS
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user