[Model] Initial support for LLaVA-NeXT (#4199)

Co-authored-by: Roger Wang <ywang@roblox.com>
This commit was authored by Cyrus Leung on 2024-06-10 20:47:15 +08:00 and committed via GitHub.
parent 0bfa1c4f13
commit 6b29d6fe70
7 changed files with 640 additions and 18 deletions

View File

@@ -1,7 +1,7 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch
from torch import nn
import torch.nn as nn
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaConfig
@@ -51,10 +51,10 @@ class LlavaMultiModalProjector(nn.Module):
return hidden_states
def _merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
image_token_id: int) -> torch.Tensor:
def merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
image_token_id: int) -> torch.Tensor:
"""In place merges in vision_embeddings with inputs_embeds."""
mask = (input_ids == image_token_id)
@@ -151,7 +151,8 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values")
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaImagePixelInputs(
type="pixel_values",
@@ -166,7 +167,8 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
return None
if not isinstance(image_features, torch.Tensor):
raise ValueError("Incorrect type of image features")
raise ValueError("Incorrect type of image features. "
f"Got type: {type(image_features)}")
return LlavaImageFeatureInputs(
type="image_features",
@@ -268,7 +270,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = _merge_vision_embeddings(
inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id)