[Model] Add multi-image input support for LLaVA-Next offline inference (#7230)

This commit is contained in:
zifeitong
2024-08-27 16:09:02 -07:00
committed by GitHub
parent 345be0e244
commit 5340a2dccf
7 changed files with 173 additions and 51 deletions

View File

@@ -6,24 +6,22 @@ from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
_PREFACE = (
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's "
"questions.")
_LIMIT_IMAGE_PER_PROMPT = 4
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
f"{_PREFACE} USER: <image>\nWhat's the content of the image? ASSISTANT:",
"[INST] <image>\nWhat's the content of the image? [/INST]",
"cherry_blossom":
f"{_PREFACE} USER: <image>\nWhat is the season? ASSISTANT:",
"[INST] <image>\nWhat is the season? [/INST]",
})
models = ["llava-hf/llava-v1.6-vicuna-7b-hf"]
models = ["llava-hf/llava-v1.6-mistral-7b-hf"]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
@@ -114,19 +112,43 @@ def run_test(
else:
raise ValueError("You must provide either `size_factors` or `sizes`")
_run_test(hf_runner,
vllm_runner,
inputs_per_image,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend)
def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput]],
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
max_model_len=4096,
max_model_len=10240,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
enforce_eager=True,
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
}) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
for prompts, images in inputs
]
with hf_runner(model, dtype=dtype,
@@ -136,7 +158,7 @@ def run_test(
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
for prompts, images in inputs
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
@@ -177,7 +199,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
@@ -216,3 +238,48 @@ def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets,
model, dtype, max_tokens,
num_logprobs) -> None:
stop_sign = image_assets[0].pil_image
cherry_blossom = image_assets[1].pil_image
inputs = [(
[
"[INST] <image><image>\nDescribe 2 images. [/INST]",
"[INST] <image><image>\nDescribe 2 images. [/INST]",
"[INST] <image><image><image><image>\nDescribe 4 images. [/INST]",
"[INST] <image>\nWhat is the season? [/INST]"
],
[
[stop_sign, cherry_blossom],
# Images with different sizes and aspect-ratios
[
rescale_image_size(stop_sign, 0.1),
stop_sign,
],
[
stop_sign,
rescale_image_size(stop_sign, 0.25),
cherry_blossom.resize((183, 488)),
cherry_blossom.resize((488, 183))
],
cherry_blossom,
])]
_run_test(
hf_runner,
vllm_runner,
inputs,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)