[VLM] Support multimodal inputs for Florence-2 models (#13320)

Author: Isotr0py
Date: 2025-02-27 18:06:41 +08:00
Committed by: GitHub
Parent: 788f284b53
Commit: edf309ebbe
13 changed files with 1075 additions and 114 deletions

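For context, the new test exercises the prompt format this commit enables: an ExplicitEncoderDecoderPrompt whose encoder side is a TextPrompt carrying the image under multi_modal_data, with the decoder prompt left as None and the facebook/bart-base tokenizer substituted for Florence-2's own. The sketch below shows the same structure through vLLM's offline LLM API; it is a minimal, hypothetical usage example rather than part of this commit, and the engine arguments (including the trust_remote_code flag) are assumptions.

from PIL import Image

from vllm import LLM, SamplingParams

# Sketch only: engine arguments beyond the tokenizer override are assumptions.
llm = LLM(
    model="microsoft/Florence-2-base",
    # Florence-2's fast tokenizer can't be loaded via AutoTokenizer, so the
    # plain BART tokenizer is substituted, exactly as the test below does.
    tokenizer="facebook/bart-base",
    trust_remote_code=True,  # assumed; Florence-2 checkpoints ship custom code
)

image = Image.open("stop_sign.jpg")  # placeholder path

# The image rides on the encoder prompt's multi_modal_data; the encoder text
# is either a Florence-2 task token (e.g. "<CAPTION>") or a free-form prompt.
prompt = {
    "encoder_prompt": {
        "prompt": "<CAPTION>",
        "multi_modal_data": {"image": image},
    },
    "decoder_prompt": None,
}

outputs = llm.generate(prompt, SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)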

@@ -1,52 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from typing import List, Optional, Tuple, Type
from typing import Optional, Type
import pytest
from PIL import Image
from vllm.inputs.data import ExplicitEncoderDecoderPrompt
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt
from vllm.multimodal.image import rescale_image_size
from vllm.sequence import SampleLogprobs
from ....conftest import HfRunner, VllmRunner
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ...utils import check_logprobs_close
Florence2Prompt = partial(ExplicitEncoderDecoderPrompt,
                          decoder_prompt=None,
                          mm_processor_kwargs=None)
MODELS = ["microsoft/Florence-2-base"]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
TOKENIZER = "facebook/bart-base"
PROMPTS = [
    Florence2Prompt(encoder_prompt="<CAPTION>"),
    Florence2Prompt(encoder_prompt="<DETAILED_CAPTION>"),
    Florence2Prompt(encoder_prompt="<MORE_DETAILED_CAPTION>"),
    Florence2Prompt(encoder_prompt="<CAPTION_TO_PHRASE_GROUNDING>"),
    Florence2Prompt(encoder_prompt="<DENSE_REGION_CAPTION>"),
    Florence2Prompt(encoder_prompt="<REGION_PROPOSAL>"),
    Florence2Prompt(encoder_prompt="<OCR_WITH_REGION>"),
    Florence2Prompt(encoder_prompt="<OCR>"),
    Florence2Prompt(encoder_prompt="<OD>"),
]
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<CAPTION>",  # special task token
    "cherry_blossom":
    "Describe in detail what is shown in the image.",
})
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]], ):
    """Sanitize vllm output to be comparable with hf output."""
    output_ids, output_str, out_logprobs = vllm_output
def get_hf_images_prompts(
    prompts_: list[ExplicitEncoderDecoderPrompt[str, TextPrompt]],
) -> tuple[list[ExplicitEncoderDecoderPrompt[str, str]], list[Image.Image]]:
    prompts, images = [], []
    for prompt in prompts_:
        encoder_prompt = prompt["encoder_prompt"]
        prompts.append(
            ExplicitEncoderDecoderPrompt(
                encoder_prompt=encoder_prompt["prompt"],
                decoder_prompt=None,
            ))
        images.append(encoder_prompt["multi_modal_data"]["image"])
    return prompts, images
    hf_output_str = "</s><s>" + output_str + "</s>"
    return output_ids, hf_output_str, out_logprobs
def hf_to_vllm_output(hf_output: tuple[list[int], str,
                                       Optional[SampleLogprobs]]):
    """Sanitize hf output to be comparable with vllm output."""
    output_ids, output_str, out_logprobs = hf_output
    output_str = output_str.replace("</s>", "").replace("<s>", "")
    output_ids = [ids for ids in output_ids if ids not in [0, 2]]
    return output_ids, output_str, out_logprobs
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    prompts: List[ExplicitEncoderDecoderPrompt],
    inputs: list[list[ExplicitEncoderDecoderPrompt]],
    model: str,
    *,
    dtype: str,
@@ -56,46 +63,76 @@ def run_test(
    distributed_executor_backend: Optional[str] = None,
) -> None:
    with vllm_runner(model,
                     max_num_seqs=8,
                     tokenizer_name=TOKENIZER,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:
        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
            prompts, max_tokens, num_logprobs)
        vllm_outputs_per_case = [
            vllm_model.generate_encoder_decoder_greedy_logprobs(
                prompts, max_tokens, num_logprobs=num_logprobs)
            for prompts in inputs
        ]
    hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs]
    # Florence-2 processors require image inputs
    dummy_image = Image.new(mode="RGB", size=(2, 2))
    with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model:
        hf_model.model.get_output_embeddings = lambda: \
            hf_model.model.language_model.lm_head
        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
            prompts,
            max_tokens,
            num_logprobs,
            images=[dummy_image] * len(prompts),
        ))
        hf_outputs_per_case = [
            hf_model.generate_encoder_decoder_greedy_logprobs_limit(
                prompts, max_tokens, num_logprobs=num_logprobs, images=images)
            for prompts, images in hf_inputs
        ]
    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=[
            vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs
        ],
        name_0="hf",
        name_1="vllm",
    )
    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
                                        vllm_outputs_per_case):
        check_logprobs_close(
            outputs_0_lst=[hf_to_vllm_output(output) for output in hf_outputs],
            outputs_1_lst=vllm_outputs,
            name_0="hf",
            name_1="vllm",
        )
@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,
                num_logprobs) -> None:
def test_models(hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner],
                image_assets: _ImageAssets, model: str,
                size_factors: list[float], dtype: str, max_tokens: int,
                num_logprobs: int) -> None:
    images = [asset.pil_image for asset in image_assets]
    inputs_per_image = [[
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=TextPrompt(
                prompt=prompt,
                multi_modal_data={"image": rescale_image_size(image, factor)}),
            decoder_prompt=None,
        ) for factor in size_factors
    ] for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
    run_test(
        hf_runner,
        vllm_runner,
        PROMPTS,
        inputs_per_image,
        model,
        dtype=dtype,
        max_tokens=max_tokens,