[Model] Initial support for LLaVA-NeXT (#4199)
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn as nn
|
||||
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
|
||||
# transformers' impl.
|
||||
from transformers import CLIPVisionModel, LlavaConfig
|
||||
@@ -51,10 +51,10 @@ class LlavaMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _merge_vision_embeddings(input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
vision_embeddings: torch.Tensor,
|
||||
image_token_id: int) -> torch.Tensor:
|
||||
def merge_vision_embeddings(input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
vision_embeddings: torch.Tensor,
|
||||
image_token_id: int) -> torch.Tensor:
|
||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
mask = (input_ids == image_token_id)
|
||||
|
||||
@@ -151,7 +151,8 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
|
||||
return None
|
||||
|
||||
if not isinstance(pixel_values, torch.Tensor):
|
||||
raise ValueError("Incorrect type of pixel values")
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
return LlavaImagePixelInputs(
|
||||
type="pixel_values",
|
||||
@@ -166,7 +167,8 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
|
||||
return None
|
||||
|
||||
if not isinstance(image_features, torch.Tensor):
|
||||
raise ValueError("Incorrect type of image features")
|
||||
raise ValueError("Incorrect type of image features. "
|
||||
f"Got type: {type(image_features)}")
|
||||
|
||||
return LlavaImageFeatureInputs(
|
||||
type="image_features",
|
||||
@@ -268,7 +270,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
|
||||
inputs_embeds = _merge_vision_embeddings(
|
||||
inputs_embeds = merge_vision_embeddings(
|
||||
input_ids, inputs_embeds, vision_embeddings,
|
||||
self.vision_language_config.image_token_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user