[Core] Rename input data types (#8688)

This commit is contained in:
Cyrus Leung
2024-10-16 18:49:37 +08:00
committed by GitHub
parent 1de76a0e55
commit cee711fdbb
32 changed files with 438 additions and 340 deletions

View File

@@ -1,12 +1,12 @@
import os
import re
from typing import Callable, List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type
import pytest
import torch
from transformers import AutoImageProcessor, AutoTokenizer
from vllm.inputs import InputContext, LLMInputs
from vllm.inputs import InputContext, token_inputs
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.utils import rescale_image_size
@@ -311,7 +311,7 @@ def test_input_mapper_override(model: str, image_assets: _ImageAssets,
(4, 781),
(16, 2653),
])
def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,
num_crops: int, expected_max_tokens: int):
"""Ensure get_max_phi3v_image_tokens handles num_crops properly."""
# NOTE: mm_processor_kwargs on the context in this test is unused, since
@@ -343,8 +343,8 @@ def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
(16, 2653, 1),
(16, 2653, 2),
])
def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
num_crops: int, toks_per_img: int, num_imgs: int):
def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
toks_per_img: int, num_imgs: int):
"""Ensure dummy_data_for_phi3v handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
@@ -374,7 +374,7 @@ def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
(16, 1921, 1),
(16, 1921, 2),
])
def test_input_processor_override(input_processor_for_phi3v: Callable,
def test_input_processor_override(input_processor_for_phi3v,
image_assets: _ImageAssets, model: str,
num_crops: int, expected_toks_per_img: int,
num_imgs: int):
@@ -393,16 +393,14 @@ def test_input_processor_override(input_processor_for_phi3v: Callable,
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
images = [image_assets[0].pil_image] * num_imgs
llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt,
multi_modal_data={"image": images})
inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt,
multi_modal_data={"image": images})
proc_llm_inputs = input_processor_for_phi3v(
ctx=ctx,
llm_inputs=llm_inputs,
num_crops=num_crops,
)
processed_inputs = input_processor_for_phi3v(ctx,
inputs,
num_crops=num_crops)
# Ensure we have the right number of placeholders per num_crops size
img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
assert img_tok_count == expected_toks_per_img * num_imgs

View File

@@ -5,7 +5,7 @@ import pytest
import torch
from PIL.Image import Image
from vllm.inputs import InputContext, LLMInputs
from vllm.inputs import InputContext, token_inputs
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size
@@ -71,12 +71,12 @@ def test_input_processor_valid_mm_data(input_processor_for_qwen,
"""Happy cases for image inputs to Qwen's multimodal input processor."""
prompt = "".join(
[f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1)])
inputs = LLMInputs(
inputs = token_inputs(
prompt=prompt,
# When processing multimodal data for a multimodal model, the qwen
# input processor will overwrite the provided prompt_token_ids with
# the image prompts
prompt_token_ids=None,
prompt_token_ids=[],
multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)},
)
proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
@@ -134,9 +134,9 @@ def test_input_processor_invalid_mm_data(input_processor_for_qwen,
trust_remote_code=True)
prompt = "Picture 1: <img></img>\n"
prompt_token_ids = tokenizer.encode(prompt)
inputs = LLMInputs(prompt=prompt,
prompt_token_ids=prompt_token_ids,
multi_modal_data=mm_data)
inputs = token_inputs(prompt=prompt,
prompt_token_ids=prompt_token_ids,
multi_modal_data=mm_data)
# Should fail since we have too many or too few dimensions for embeddings
with pytest.raises(ValueError):
input_processor_for_qwen(qwen_vl_context, inputs)