[Model] Multi-input support for LLaVA (#8238)
@@ -4,6 +4,7 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
 import torch
 import torch.nn as nn
+from PIL import Image
 from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
 
 from vllm.attention import AttentionMetadata
@@ -16,6 +17,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
+from vllm.utils import is_list_of
 
 from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_max_clip_image_tokens,
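The newly imported is_list_of helper is what lets the input processor below tell a list of PIL images apart from a list of embedding tensors. Conceptually it behaves like the following sketch; the real helper lives in vllm.utils and may differ in detail:

    from typing import Any, Type

    def is_list_of(value: Any, typ: Type) -> bool:
        # Sketch only: True iff value is a list whose elements are all
        # instances of typ. The actual vllm.utils helper may expose
        # extra options (e.g. checking only the first element).
        return isinstance(value, list) and all(
            isinstance(v, typ) for v in value)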
@@ -24,7 +26,7 @@ from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
                     merge_multimodal_embeddings)
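For context, flatten_bn merges the batch dimension with the per-prompt image dimension, which is what removes the need for the squeeze(1) calls deleted further down. A minimal sketch consistent with how it is called in this diff (the real implementation lives in .utils and may differ):

    from typing import List, Union
    import torch

    def flatten_bn(x: Union[torch.Tensor, List[torch.Tensor]],
                   *, concat: bool = False) -> torch.Tensor:
        # Sketch only, inferred from the call sites below.
        if isinstance(x, torch.Tensor):
            # (B, N, ...) -> (B * N, ...): merge batch and image dims.
            return x.flatten(0, 1)
        # A list of per-prompt (N_i, ...) tensors; with concat=True,
        # join them into a single (sum(N_i), ...) tensor.
        assert concat
        return torch.cat(x)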
@@ -133,7 +135,18 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
     hf_config = ctx.get_hf_config(LlavaConfig)
     vision_config = hf_config.vision_config
 
-    image_feature_size = get_max_llava_image_tokens(ctx)
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        image_feature_size = get_max_llava_image_tokens(ctx)
+    elif is_list_of(image_data, Image.Image):
+        image_feature_size = [get_max_llava_image_tokens(ctx)
+                              ] * len(image_data)
+    elif isinstance(image_data, torch.Tensor):
+        num_images, image_feature_size, hidden_size = image_data.shape
+    elif is_list_of(image_data, torch.Tensor):
+        image_feature_size = [item.shape[1] for item in image_data]
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")
 
     if isinstance(vision_config, CLIPVisionConfig):
         return input_processor_for_clip(
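With this dispatch, multi_modal_data["image"] may now be a single image, a list of images, a pre-computed embedding tensor, or a list of embedding tensors, and image_feature_size becomes a per-image list in the multi-image cases. A hypothetical illustration of the four accepted forms (feature_size and hidden_size are placeholder values, not taken from the diff):

    import torch
    from PIL import Image

    feature_size, hidden_size = 576, 4096  # illustrative only

    valid_image_inputs = [
        Image.new("RGB", (336, 336)),                      # single PIL image
        [Image.new("RGB", (336, 336)) for _ in range(2)],  # list of PIL images
        torch.empty(2, feature_size, hidden_size),         # (num_images, feature_size, hidden_size)
        [torch.empty(1, feature_size, hidden_size)         # list of per-image
         for _ in range(2)],                               # embedding tensors
    ]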
@@ -230,29 +243,24 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
             return None
 
         if pixel_values is not None:
-            if not isinstance(pixel_values, torch.Tensor):
+            if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
-            # Remove the N dimension until multiple images are supported.
-            pixel_values = pixel_values.squeeze(1)
-
             return LlavaImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(pixel_values),
+                data=self._validate_pixel_values(
+                    flatten_bn(pixel_values, concat=True)),
             )
 
         if image_embeds is not None:
-            if not isinstance(image_embeds, torch.Tensor):
+            if not isinstance(image_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
 
-            # Remove the N dimension until multiple images are supported.
-            image_embeds = image_embeds.squeeze(1)
-
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
-                data=image_embeds,
+                data=flatten_bn(image_embeds, concat=True),
             )
 
         raise AssertionError("This line should be unreachable.")
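Because the N dimension is no longer squeezed away, batched pixel values reach this method either as a single (batch, num_images, ...) tensor or as a list of per-prompt tensors with differing image counts, and flatten_bn(..., concat=True) reduces both to one flat batch of images. Illustrative shapes only (the concrete sizes are placeholders):

    import torch

    batched = torch.empty(2, 3, 3, 336, 336)  # 2 prompts x 3 images (C, H, W)
    ragged = [torch.empty(1, 3, 336, 336),    # prompt 0: one image
              torch.empty(4, 3, 336, 336)]    # prompt 1: four images

    # Assuming the flatten_bn semantics sketched earlier:
    #   flatten_bn(batched, concat=True) -> (6, 3, 336, 336)
    #   flatten_bn(ragged, concat=True)  -> (5, 3, 336, 336)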