[Model][VLM] Initialize support for Mono-InternVL model (#9528)

This commit is contained in:
Isotr0py
2024-10-23 00:01:46 +08:00
committed by GitHub
parent 9dbcce84a7
commit bb392ea2d2
6 changed files with 253 additions and 27 deletions

View File

@@ -7,7 +7,6 @@ from PIL.Image import Image
from transformers import AutoConfig
from vllm.multimodal.utils import rescale_image_size
from vllm.platforms import current_platform
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
@@ -19,15 +18,20 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"cherry_blossom":
"<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
})
HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: <image>\nImage-2: <image>\nDescribe the two images in detail.<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501
HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: <image>\nImage-2: <image>\nDescribe the two images in short.<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501
models = [
"OpenGVLab/InternVL2-1B",
"OpenGVLab/InternVL2-2B",
# NOTE: Mono-InternVL-2B doesn't work with fp16,
# it will result NaN during inference.
# See: https://huggingface.co/OpenGVLab/Mono-InternVL-2B/discussions/9
"OpenGVLab/Mono-InternVL-2B",
# Broken due to outdated implementation of Phi-3
# See: https://huggingface.co/OpenGVLab/InternVL2-4B/discussions/3
# "OpenGVLab/InternVL2-4B",
]
target_dtype = "bfloat16"
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
@@ -52,9 +56,15 @@ def generate(
input_embeds = input_embeds.reshape(B, N, C)
outputs = self.language_model.generate(
forward_kwargs = dict(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
)
if getattr(self, "use_visual_token_mask", False):
visual_token_mask = selected.reshape(B, N, 1).to(input_embeds.dtype)
forward_kwargs["visual_token_mask"] = visual_token_mask
outputs = self.language_model.generate(
**forward_kwargs,
**generate_kwargs,
)
@@ -243,11 +253,6 @@ def run_awq_test(
)
target_dtype = "half"
if current_platform.is_cpu():
target_dtype = "bfloat16"
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",