[Core] Rename input data types (#8688)

This commit is contained in:
Cyrus Leung
2024-10-16 18:49:37 +08:00
committed by GitHub
parent 1de76a0e55
commit cee711fdbb
32 changed files with 438 additions and 340 deletions

View File

@@ -5,7 +5,7 @@ from unittest.mock import patch
import pytest
import torch
from vllm.inputs import InputContext, LLMInputs
from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs
from vllm.inputs.registry import InputRegistry
from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
@@ -31,7 +31,7 @@ def use_processor_mock():
"""Patches the internal model input processor with an override callable."""
def custom_processor(ctx: InputContext,
llm_inputs: LLMInputs,
inputs: DecoderOnlyInputs,
*,
num_crops=DEFAULT_NUM_CROPS):
# For testing purposes, we don't worry about the llm inputs / return
@@ -84,7 +84,7 @@ def test_default_processor_is_a_noop():
dummy_registry = InputRegistry()
ctx = build_model_context(DUMMY_MODEL_ID)
processor = dummy_registry.create_input_processor(ctx.model_config)
proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
proc_inputs = token_inputs(prompt_token_ids=[], prompt="")
proc_outputs = processor(inputs=proc_inputs)
assert proc_inputs is proc_outputs
@@ -125,9 +125,9 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops,
ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor(
LLMInputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=inference_kwargs))
token_inputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=inference_kwargs))
assert num_crops_val == expected_seq_count
@@ -154,9 +154,9 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
processor = dummy_registry.create_input_processor(ctx.model_config)
# Should filter out the inference time kwargs
num_crops_val = processor(
LLMInputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=mm_processor_kwargs))
token_inputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=mm_processor_kwargs))
assert num_crops_val == DEFAULT_NUM_CROPS