[Model] Multi-input support for LLaVA (#8238)

commit 2f707fcb35 (parent 41e95c5247)
Author: Cyrus Leung
Date: 2024-09-07 10:57:24 +08:00
Committed by: GitHub

10 changed files with 176 additions and 45 deletions

vllm/model_executor/models/clip.py

@@ -105,7 +105,7 @@ def input_processor_for_clip(
         if isinstance(image_data, Image.Image):
             image_feature_size = get_clip_image_feature_size(hf_config)
         elif isinstance(image_data, torch.Tensor):
-            image_feature_size = image_data.shape[0]
+            num_images, image_feature_size, hidden_size = image_data.shape
         else:
             raise TypeError(f"Invalid image type: {type(image_data)}")
     else:
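
The same shape change recurs in every vision tower touched by this commit: a pre-computed image-embedding tensor now carries an explicit leading num_images dimension instead of being a single image's (feature_size, hidden_size) matrix. A minimal sketch of the convention (all concrete values are illustrative, not taken from the commit):

```python
import torch

# Hypothetical shapes: 2 images per prompt, 576 feature tokens per image,
# hidden size 1024.
image_data = torch.randn(2, 576, 1024)

# Old convention: shape[0] was the feature count of a single image.
# New convention: unpack an explicit leading num_images dimension.
num_images, image_feature_size, hidden_size = image_data.shape
assert (num_images, image_feature_size, hidden_size) == (2, 576, 1024)
```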

vllm/model_executor/models/internvl.py

@@ -209,7 +209,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
         image_feature_size = num_blocks * num_patches
     elif isinstance(image_data, torch.Tensor):
-        image_feature_size = image_data.shape[0]
+        num_images, image_feature_size, hidden_size = image_data.shape
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

vllm/model_executor/models/llava.py

@@ -4,6 +4,7 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
 import torch
 import torch.nn as nn
+from PIL import Image
 from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig

 from vllm.attention import AttentionMetadata
@@ -16,6 +17,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
+from vllm.utils import is_list_of

 from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_max_clip_image_tokens,
@@ -24,7 +26,7 @@ from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
                     merge_multimodal_embeddings)
@@ -133,7 +135,18 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
     hf_config = ctx.get_hf_config(LlavaConfig)
     vision_config = hf_config.vision_config

-    image_feature_size = get_max_llava_image_tokens(ctx)
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        image_feature_size = get_max_llava_image_tokens(ctx)
+    elif is_list_of(image_data, Image.Image):
+        image_feature_size = [get_max_llava_image_tokens(ctx)
+                              ] * len(image_data)
+    elif isinstance(image_data, torch.Tensor):
+        num_images, image_feature_size, hidden_size = image_data.shape
+    elif is_list_of(image_data, torch.Tensor):
+        image_feature_size = [item.shape[1] for item in image_data]
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")

     if isinstance(vision_config, CLIPVisionConfig):
         return input_processor_for_clip(
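
The two new list branches rely on is_list_of, imported from vllm.utils above. A rough sketch of its behavior as the dispatch uses it (the real helper may handle empty lists or subclass checks differently):

```python
from typing import Any

def is_list_of_sketch(value: Any, typ: type) -> bool:
    # Sketch only: True when `value` is a list whose elements are all
    # instances of `typ`.
    return isinstance(value, list) and all(isinstance(v, typ) for v in value)

# Matching the dispatch above (illustrative):
#   is_list_of_sketch([img_a, img_b], Image.Image)  -> multi-image PIL path
#   is_list_of_sketch([emb_a, emb_b], torch.Tensor) -> per-image embeddings
```
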
@@ -230,29 +243,24 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
             return None

         if pixel_values is not None:
-            if not isinstance(pixel_values, torch.Tensor):
+            if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

-            # Remove the N dimension until multiple images are supported.
-            pixel_values = pixel_values.squeeze(1)
-
             return LlavaImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
+                data=self._validate_pixel_values(
+                    flatten_bn(pixel_values, concat=True)),
             )

         if image_embeds is not None:
-            if not isinstance(image_embeds, torch.Tensor):
+            if not isinstance(image_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")

-            # Remove the N dimension until multiple images are supported.
-            image_embeds = image_embeds.squeeze(1)
-
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds, concat=True),
             )

         raise AssertionError("This line should be unreachable.")
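
Note how flatten_bn(..., concat=True) replaces the old squeeze(1): instead of stripping a singleton image dimension, it collapses the batch and per-prompt image dimensions into one flat image axis. A minimal sketch of the idea, assuming inputs are either a batched tensor or a list of per-prompt tensors (the real .utils.flatten_bn is the authoritative version):

```python
import torch
from typing import List, Union

def flatten_bn_sketch(
    x: Union[torch.Tensor, List[torch.Tensor]],
    *,
    concat: bool = False,
) -> torch.Tensor:
    # Sketch only. A batched tensor (B, N, ...) collapses to (B*N, ...).
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)
    # A list of per-prompt tensors (N_i, ...) concatenates along dim 0,
    # yielding (sum N_i, ...), which is what concat=True requests above.
    if concat:
        return torch.cat(x)
    raise NotImplementedError("non-concat list handling omitted in sketch")
```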

vllm/model_executor/models/llava_next.py

@@ -234,7 +234,9 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
             for img in image_data
         ]
     elif isinstance(image_data, torch.Tensor):
-        image_feature_size = image_data.shape[0]
+        num_images, image_feature_size, hidden_size = image_data.shape
+    elif is_list_of(image_data, torch.Tensor):
+        image_feature_size = [item.shape[1] for item in image_data]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

vllm/model_executor/models/phi3v.py

@@ -424,7 +424,9 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
                                                   input_width=w,
                                                   input_height=h))
     elif isinstance(image_data, torch.Tensor):
-        image_feature_size = image_data.shape[0]
+        num_images, image_feature_size, hidden_size = image_data.shape
+    elif is_list_of(image_data, torch.Tensor):
+        image_feature_size = [item.shape[1] for item in image_data]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

vllm/model_executor/models/siglip.py

@@ -110,7 +110,7 @@ def input_processor_for_siglip(
         if isinstance(image_data, Image.Image):
             image_feature_size = get_siglip_image_feature_size(hf_config)
         elif isinstance(image_data, torch.Tensor):
-            image_feature_size = image_data.shape[0]
+            num_images, image_feature_size, hidden_size = image_data.shape
         else:
             raise TypeError(f"Invalid image type: {type(image_data)}")
     else:
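
Taken together, these changes let a single LLaVA prompt reference more than one image. A hedged usage sketch follows; the model name, image files, prompt template, and the limit_mm_per_prompt setting are illustrative assumptions, not taken from this commit, so check the vLLM docs for the exact API of your version:

```python
from PIL import Image
from vllm import LLM

# Illustrative only: raise the per-prompt image cap, then pass a list of
# images whose order matches the <image> placeholders in the prompt.
llm = LLM(model="llava-hf/llava-1.5-7b-hf",
          limit_mm_per_prompt={"image": 2})

outputs = llm.generate({
    "prompt": "USER: <image><image>\nWhat differs between the two images? "
              "ASSISTANT:",
    "multi_modal_data": {"image": [Image.open("a.jpg"), Image.open("b.jpg")]},
})
print(outputs[0].outputs[0].text)
```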