Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -7,19 +7,27 @@ from typing import Annotated, Final, Literal, Optional, Protocol, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import (BatchFeature, LlavaOnevisionConfig,
|
||||
LlavaOnevisionProcessor)
|
||||
from transformers import BatchFeature, LlavaOnevisionConfig, LlavaOnevisionProcessor
|
||||
from transformers.models.llava_onevision.modeling_llava_onevision import (
|
||||
get_anyres_image_grid_shape, unpad_image)
|
||||
get_anyres_image_grid_shape,
|
||||
unpad_image,
|
||||
)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
|
||||
VideoEmbeddingItems, VideoProcessorItems)
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
)
|
||||
from vllm.multimodal.parse import (
|
||||
ImageSize,
|
||||
MultiModalDataItems,
|
||||
VideoEmbeddingItems,
|
||||
VideoProcessorItems,
|
||||
)
|
||||
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
@@ -27,11 +35,19 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
|
||||
from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
|
||||
LlavaNextProcessingInfo)
|
||||
from .llava_next import (
|
||||
BaseLlavaNextMultiModalProcessor,
|
||||
LlavaNextLikeConfig,
|
||||
LlavaNextProcessingInfo,
|
||||
)
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
flatten_bn,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
# For profile run
|
||||
_MAX_FRAMES_PER_VIDEO = 16
|
||||
@@ -50,6 +66,7 @@ class LlavaOnevisionVideoPixelInputs(TensorSchema):
|
||||
may be different for each video, in which case the data is passed as a
|
||||
list instead of a batched tensor.
|
||||
"""
|
||||
|
||||
type: Literal["pixel_values_videos"] = "pixel_values_videos"
|
||||
|
||||
pixel_values_videos: Annotated[
|
||||
@@ -70,6 +87,7 @@ class LlavaOnevisionImagePixelInputs(TensorSchema):
|
||||
Note that `num_patches` may be different per batch and image,
|
||||
in which case the data is passed as a list instead of a batched tensor.
|
||||
"""
|
||||
|
||||
type: Literal["pixel_values"] = "pixel_values"
|
||||
|
||||
pixel_values: Annotated[
|
||||
@@ -87,6 +105,7 @@ class LlavaOnevisionImageEmbeddingInputs(TensorSchema):
|
||||
- ifs: Image feature size
|
||||
- hs: Hidden size (must match language model backbone)
|
||||
"""
|
||||
|
||||
type: Literal["image_embeds"] = "image_embeds"
|
||||
|
||||
data: Annotated[
|
||||
@@ -95,11 +114,13 @@ class LlavaOnevisionImageEmbeddingInputs(TensorSchema):
|
||||
]
|
||||
|
||||
|
||||
LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs,
|
||||
LlavaOnevisionImageEmbeddingInputs]
|
||||
LlavaOnevisionImageInputs = Union[
|
||||
LlavaOnevisionImagePixelInputs, LlavaOnevisionImageEmbeddingInputs
|
||||
]
|
||||
|
||||
LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs,
|
||||
LlavaOnevisionVideoPixelInputs]
|
||||
LlavaOnevisionMultiInputs = Union[
|
||||
LlavaOnevisionImageInputs, LlavaOnevisionVideoPixelInputs
|
||||
]
|
||||
|
||||
|
||||
class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
|
||||
@@ -107,7 +128,6 @@ class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
|
||||
|
||||
|
||||
class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
|
||||
|
||||
def get_hf_config(self) -> LlavaOnevisionLikeConfig:
|
||||
return self.ctx.get_hf_config(LlavaOnevisionConfig)
|
||||
|
||||
@@ -136,12 +156,14 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
|
||||
|
||||
if aspect_ratio > current_aspect_ratio:
|
||||
new_height = int(
|
||||
round(original_height * (current_width / original_width), 7))
|
||||
round(original_height * (current_width / original_width), 7)
|
||||
)
|
||||
padding = (current_height - new_height) // 2
|
||||
current_height = current_height - (2 * padding)
|
||||
else:
|
||||
new_width = int(
|
||||
round(original_width * (current_height / original_height), 7))
|
||||
round(original_width * (current_height / original_height), 7)
|
||||
)
|
||||
padding = (current_width - new_width) // 2
|
||||
current_width = current_width - (2 * padding)
|
||||
|
||||
@@ -218,8 +240,9 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
|
||||
max_videos = mm_counts.get("video", 0)
|
||||
|
||||
max_total_frames = self._get_max_video_frames(seq_len)
|
||||
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
|
||||
_MAX_FRAMES_PER_VIDEO)
|
||||
max_frames_per_video = min(
|
||||
max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO
|
||||
)
|
||||
|
||||
return max(max_frames_per_video, 1)
|
||||
|
||||
@@ -233,14 +256,13 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
|
||||
return self.get_num_video_tokens(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
num_frames=self.get_num_frames_with_most_features(
|
||||
seq_len, mm_counts),
|
||||
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
|
||||
)
|
||||
|
||||
|
||||
class LlavaOnevisionDummyInputsBuilder(
|
||||
LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]):
|
||||
|
||||
LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]
|
||||
):
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
@@ -260,35 +282,34 @@ class LlavaOnevisionDummyInputsBuilder(
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
target_num_frames = \
|
||||
self.info.get_num_frames_with_most_features(seq_len,
|
||||
mm_counts)
|
||||
target_width, target_height = self.info.get_image_size_with_most_features()
|
||||
target_num_frames = self.info.get_num_frames_with_most_features(
|
||||
seq_len, mm_counts
|
||||
)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
video_overrides = mm_options.get("video") if mm_options else None
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides),
|
||||
"video":
|
||||
self._get_dummy_videos(
|
||||
"image": self._get_dummy_images(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images,
|
||||
overrides=image_overrides,
|
||||
),
|
||||
"video": self._get_dummy_videos(
|
||||
width=target_width,
|
||||
height=target_height,
|
||||
num_frames=target_num_frames,
|
||||
num_videos=num_videos,
|
||||
overrides=video_overrides,
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class LlavaOnevisionMultiModalProcessor(
|
||||
BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]):
|
||||
|
||||
BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]
|
||||
):
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
@@ -405,7 +426,8 @@ class LlavaOnevisionMultiModalProcessor(
|
||||
|
||||
def get_video_replacement(item_idx: int):
|
||||
videos = mm_items.get_items(
|
||||
"video", (VideoEmbeddingItems, VideoProcessorItems))
|
||||
"video", (VideoEmbeddingItems, VideoProcessorItems)
|
||||
)
|
||||
|
||||
if isinstance(videos, VideoEmbeddingItems):
|
||||
num_video_tokens = videos.get_feature_size(item_idx)
|
||||
@@ -430,17 +452,20 @@ class LlavaOnevisionMultiModalProcessor(
|
||||
|
||||
|
||||
class LlavaOnevisionMultiModalProjector(nn.Module):
|
||||
|
||||
def __init__(self, config: LlavaOnevisionConfig):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size,
|
||||
config.text_config.hidden_size,
|
||||
bias=config.multimodal_projector_bias)
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size,
|
||||
config.text_config.hidden_size,
|
||||
bias=config.multimodal_projector_bias,
|
||||
)
|
||||
self.act = get_act_fn(config.projector_hidden_act)
|
||||
self.linear_2 = nn.Linear(config.text_config.hidden_size,
|
||||
config.text_config.hidden_size,
|
||||
bias=config.multimodal_projector_bias)
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size,
|
||||
config.text_config.hidden_size,
|
||||
bias=config.multimodal_projector_bias,
|
||||
)
|
||||
|
||||
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.linear_1(image_features)
|
||||
@@ -452,10 +477,9 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
LlavaOnevisionMultiModalProcessor,
|
||||
info=LlavaOnevisionProcessingInfo,
|
||||
dummy_inputs=LlavaOnevisionDummyInputsBuilder)
|
||||
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
|
||||
dummy_inputs=LlavaOnevisionDummyInputsBuilder,
|
||||
)
|
||||
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
# mapping for new names in checkpoint saved after transformers v4.52
|
||||
@@ -464,7 +488,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
"model.multi_modal_projector.": "multi_modal_projector.",
|
||||
"model.image_newline": "image_newline",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
@@ -489,21 +514,23 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
config,
|
||||
quant_config,
|
||||
require_post_norm=False,
|
||||
prefix=maybe_prefix(prefix, "vision_tower"))
|
||||
prefix=maybe_prefix(prefix, "vision_tower"),
|
||||
)
|
||||
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
self.image_newline = nn.Parameter(
|
||||
torch.empty(config.text_config.hidden_size))
|
||||
self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size))
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.model.make_empty_intermediate_tensors)
|
||||
self.language_model.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]:
|
||||
self, **kwargs: object
|
||||
) -> Optional[LlavaOnevisionImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_sizes = kwargs.pop("image_sizes", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
@@ -513,12 +540,14 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
raise ValueError(
|
||||
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
|
||||
)
|
||||
|
||||
if not isinstance(image_sizes, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image sizes. "
|
||||
f"Got type: {type(image_sizes)}")
|
||||
raise ValueError(
|
||||
f"Incorrect type of image sizes. Got type: {type(image_sizes)}"
|
||||
)
|
||||
|
||||
return LlavaOnevisionImagePixelInputs(
|
||||
type="pixel_values",
|
||||
@@ -526,13 +555,15 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
image_sizes=flatten_bn(image_sizes, concat=True),
|
||||
resolve_bindings={
|
||||
"h": self.config.vision_config.image_size,
|
||||
"w": self.config.vision_config.image_size
|
||||
})
|
||||
"w": self.config.vision_config.image_size,
|
||||
},
|
||||
)
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
raise ValueError("Incorrect type of image embeds. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
raise ValueError(
|
||||
f"Incorrect type of image embeds. Got type: {type(image_embeds)}"
|
||||
)
|
||||
|
||||
return LlavaOnevisionImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
@@ -542,12 +573,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _parse_and_validate_video_input(
|
||||
self,
|
||||
**kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]:
|
||||
self, **kwargs: object
|
||||
) -> Optional[LlavaOnevisionVideoPixelInputs]:
|
||||
"""
|
||||
A legal video input should have the following dimensions:
|
||||
{
|
||||
"pixel_values_videos" :
|
||||
"pixel_values_videos" :
|
||||
list[b, Tensor(nb_frames, nb_channels, height, width)]
|
||||
}
|
||||
"""
|
||||
@@ -556,16 +587,19 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return None
|
||||
|
||||
if not isinstance(pixel_values_videos, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel_values_videos. "
|
||||
f"Got type: {type(pixel_values_videos)}")
|
||||
raise ValueError(
|
||||
"Incorrect type of pixel_values_videos. "
|
||||
f"Got type: {type(pixel_values_videos)}"
|
||||
)
|
||||
|
||||
return LlavaOnevisionVideoPixelInputs(
|
||||
type="pixel_values_videos",
|
||||
pixel_values_videos=flatten_bn(pixel_values_videos),
|
||||
resolve_bindings={
|
||||
"h": self.config.vision_config.image_size,
|
||||
"w": self.config.vision_config.image_size
|
||||
})
|
||||
"w": self.config.vision_config.image_size,
|
||||
},
|
||||
)
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
mm_input_by_modality = {}
|
||||
@@ -573,14 +607,20 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if input_key in ("pixel_values", "image_embeds"
|
||||
) and "image" not in mm_input_by_modality:
|
||||
mm_input_by_modality[
|
||||
"image"] = self._parse_and_validate_image_input(**kwargs)
|
||||
if input_key in ("pixel_values_videos", "video_embeds"
|
||||
) and "video" not in mm_input_by_modality:
|
||||
mm_input_by_modality[
|
||||
"video"] = self._parse_and_validate_video_input(**kwargs)
|
||||
if (
|
||||
input_key in ("pixel_values", "image_embeds")
|
||||
and "image" not in mm_input_by_modality
|
||||
):
|
||||
mm_input_by_modality["image"] = self._parse_and_validate_image_input(
|
||||
**kwargs
|
||||
)
|
||||
if (
|
||||
input_key in ("pixel_values_videos", "video_embeds")
|
||||
and "video" not in mm_input_by_modality
|
||||
):
|
||||
mm_input_by_modality["video"] = self._parse_and_validate_video_input(
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return mm_input_by_modality
|
||||
|
||||
@@ -597,25 +637,29 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
)
|
||||
|
||||
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
|
||||
def _merge_image_patch_embeddings(self,
|
||||
image_size: torch.Tensor,
|
||||
patch_embeddings: torch.Tensor,
|
||||
*,
|
||||
image_newline=None,
|
||||
vision_aspect_ratio="anyres_max_9",
|
||||
strategy: str) -> torch.Tensor:
|
||||
def _merge_image_patch_embeddings(
|
||||
self,
|
||||
image_size: torch.Tensor,
|
||||
patch_embeddings: torch.Tensor,
|
||||
*,
|
||||
image_newline=None,
|
||||
vision_aspect_ratio="anyres_max_9",
|
||||
strategy: str,
|
||||
) -> torch.Tensor:
|
||||
if strategy == "flat":
|
||||
return patch_embeddings.flatten(0, 1)
|
||||
|
||||
if strategy.startswith("spatial"):
|
||||
height = width = self.config.vision_config.image_size \
|
||||
height = width = (
|
||||
self.config.vision_config.image_size
|
||||
// self.config.vision_config.patch_size
|
||||
)
|
||||
|
||||
base_patch_embeds = patch_embeddings[0]
|
||||
if height * width != base_patch_embeds.shape[0]:
|
||||
raise ValueError(
|
||||
"The number of patches is not consistent with the "
|
||||
"image size.")
|
||||
"The number of patches is not consistent with the image size."
|
||||
)
|
||||
|
||||
if patch_embeddings.shape[0] > 1:
|
||||
other_patch_embeds = patch_embeddings[1:]
|
||||
@@ -632,53 +676,66 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
num_patches = num_patch_height * num_patch_width
|
||||
|
||||
# Image patches might be padded for batch processing
|
||||
other_patch_embeds = other_patch_embeds[:num_patches] \
|
||||
.view(num_patch_height, num_patch_width, height, width, -1)
|
||||
other_patch_embeds = other_patch_embeds[:num_patches].view(
|
||||
num_patch_height, num_patch_width, height, width, -1
|
||||
)
|
||||
|
||||
if "unpad" in strategy:
|
||||
other_patch_embeds = other_patch_embeds \
|
||||
.permute(4, 0, 2, 1, 3).contiguous() \
|
||||
.flatten(1, 2).flatten(2, 3)
|
||||
other_patch_embeds = unpad_image(other_patch_embeds,
|
||||
(orig_height, orig_width))
|
||||
other_patch_embeds = (
|
||||
other_patch_embeds.permute(4, 0, 2, 1, 3)
|
||||
.contiguous()
|
||||
.flatten(1, 2)
|
||||
.flatten(2, 3)
|
||||
)
|
||||
other_patch_embeds = unpad_image(
|
||||
other_patch_embeds, (orig_height, orig_width)
|
||||
)
|
||||
max_num_patches = int(
|
||||
vision_aspect_ratio.removeprefix("anyres_max_"))
|
||||
vision_aspect_ratio.removeprefix("anyres_max_")
|
||||
)
|
||||
channels, curr_height, curr_width = other_patch_embeds.shape
|
||||
ratio = math.sqrt(curr_height * curr_width /
|
||||
(max_num_patches * height**2))
|
||||
ratio = math.sqrt(
|
||||
curr_height * curr_width / (max_num_patches * height**2)
|
||||
)
|
||||
if ratio > 1.1:
|
||||
other_patch_embeds = other_patch_embeds[None]
|
||||
other_patch_embeds = nn.functional.interpolate(
|
||||
other_patch_embeds, [
|
||||
int(curr_height // ratio),
|
||||
int(curr_width // ratio)
|
||||
],
|
||||
mode="bilinear")[0]
|
||||
other_patch_embeds,
|
||||
[int(curr_height // ratio), int(curr_width // ratio)],
|
||||
mode="bilinear",
|
||||
)[0]
|
||||
if image_newline is not None:
|
||||
other_patch_embeds = torch.cat(
|
||||
(
|
||||
other_patch_embeds,
|
||||
image_newline[:, None, None] \
|
||||
.expand(*other_patch_embeds.shape[:-1], 1) \
|
||||
image_newline[:, None, None]
|
||||
.expand(*other_patch_embeds.shape[:-1], 1)
|
||||
.to(other_patch_embeds.device),
|
||||
),
|
||||
dim=-1)
|
||||
other_patch_embeds = other_patch_embeds \
|
||||
.flatten(1, 2).transpose(0, 1)
|
||||
dim=-1,
|
||||
)
|
||||
other_patch_embeds = other_patch_embeds.flatten(1, 2).transpose(
|
||||
0, 1
|
||||
)
|
||||
else:
|
||||
other_patch_embeds = other_patch_embeds \
|
||||
.permute(0, 2, 1, 3, 4).contiguous() \
|
||||
other_patch_embeds = (
|
||||
other_patch_embeds.permute(0, 2, 1, 3, 4)
|
||||
.contiguous()
|
||||
.flatten(0, 3)
|
||||
)
|
||||
|
||||
merged_patch_embeddings = torch.cat(
|
||||
(base_patch_embeds, other_patch_embeds), dim=0)
|
||||
(base_patch_embeds, other_patch_embeds), dim=0
|
||||
)
|
||||
else:
|
||||
if "unpad" in strategy:
|
||||
merged_patch_embeddings = torch.cat(
|
||||
(base_patch_embeds,
|
||||
self.image_newline[None] \
|
||||
.to(base_patch_embeds.device)
|
||||
), dim=0)
|
||||
(
|
||||
base_patch_embeds,
|
||||
self.image_newline[None].to(base_patch_embeds.device),
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
else:
|
||||
merged_patch_embeddings = base_patch_embeds
|
||||
|
||||
@@ -698,21 +755,27 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
b, num_patches, c, h, w = pixel_values.shape
|
||||
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
|
||||
stacked_image_features = self._image_pixels_to_features(
|
||||
self.vision_tower, stacked_pixel_values)
|
||||
self.vision_tower, stacked_pixel_values
|
||||
)
|
||||
stacked_patch_embeddings = self.multi_modal_projector(
|
||||
stacked_image_features)
|
||||
stacked_image_features
|
||||
)
|
||||
|
||||
return stacked_patch_embeddings.view(
|
||||
b, num_patches, *stacked_patch_embeddings.shape[1:])
|
||||
b, num_patches, *stacked_patch_embeddings.shape[1:]
|
||||
)
|
||||
|
||||
num_patches_per_batch = [v.shape[0] for v in pixel_values]
|
||||
stacked_pixel_values = torch.cat(pixel_values)
|
||||
stacked_image_features = self._image_pixels_to_features(
|
||||
self.vision_tower, stacked_pixel_values)
|
||||
self.vision_tower, stacked_pixel_values
|
||||
)
|
||||
|
||||
return [
|
||||
self.multi_modal_projector(image_features) for image_features in
|
||||
torch.split(stacked_image_features, num_patches_per_batch)
|
||||
self.multi_modal_projector(image_features)
|
||||
for image_features in torch.split(
|
||||
stacked_image_features, num_patches_per_batch
|
||||
)
|
||||
]
|
||||
|
||||
def _process_image_input(
|
||||
@@ -729,15 +792,17 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
batch_size = len(image_input["pixel_values"])
|
||||
vision_config = self.config.vision_config
|
||||
default_height = default_width = vision_config.image_size
|
||||
image_sizes = torch.as_tensor([[default_height, default_width]
|
||||
for _ in range(batch_size)])
|
||||
image_sizes = torch.as_tensor(
|
||||
[[default_height, default_width] for _ in range(batch_size)]
|
||||
)
|
||||
|
||||
return [
|
||||
self._merge_image_patch_embeddings(
|
||||
image_sizes[i],
|
||||
patch_features_batch,
|
||||
image_newline=self.image_newline,
|
||||
strategy="spatial_unpad")
|
||||
strategy="spatial_unpad",
|
||||
)
|
||||
for i, patch_features_batch in enumerate(patch_embeddings)
|
||||
]
|
||||
|
||||
@@ -763,36 +828,39 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
if isinstance(video_pixels, torch.Tensor):
|
||||
total_videos, frames, c, h, w = video_pixels.shape
|
||||
video_pixels_flat = video_pixels.view(total_videos * frames, c, h,
|
||||
w)
|
||||
video_pixels_flat = video_pixels.view(total_videos * frames, c, h, w)
|
||||
|
||||
embeddings_flat = self._video_pixels_to_features(
|
||||
self.vision_tower, video_pixels_flat)
|
||||
self.vision_tower, video_pixels_flat
|
||||
)
|
||||
|
||||
embeddings_flat = embeddings_flat.reshape(
|
||||
total_videos, frames * embeddings_flat.shape[1], -1)
|
||||
total_videos, frames * embeddings_flat.shape[1], -1
|
||||
)
|
||||
|
||||
image_newline = self.image_newline[None, None, :].expand(
|
||||
total_videos, -1, -1)
|
||||
total_videos, -1, -1
|
||||
)
|
||||
return torch.cat((embeddings_flat, image_newline), dim=1)
|
||||
|
||||
frames_per_video = [len(video) for video in video_pixels]
|
||||
video_pixels_flat = torch.cat(video_pixels)
|
||||
|
||||
embeddings_flat = self._video_pixels_to_features(
|
||||
self.vision_tower, video_pixels_flat)
|
||||
self.vision_tower, video_pixels_flat
|
||||
)
|
||||
|
||||
image_newline = self.image_newline[None, None, :]
|
||||
|
||||
return [
|
||||
torch.cat(
|
||||
(
|
||||
embeds.reshape(1, num_frame * embeddings_flat.shape[1],
|
||||
-1),
|
||||
embeds.reshape(1, num_frame * embeddings_flat.shape[1], -1),
|
||||
image_newline,
|
||||
),
|
||||
dim=1,
|
||||
) for num_frame, embeds in zip(
|
||||
)
|
||||
for num_frame, embeds in zip(
|
||||
frames_per_video,
|
||||
torch.split(embeddings_flat, frames_per_video),
|
||||
)
|
||||
@@ -808,9 +876,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# TODO support other pooling types config
|
||||
height, width = image_features.shape[2:]
|
||||
scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
|
||||
image_feature = nn.functional.interpolate(image_features,
|
||||
size=scaled_shape,
|
||||
mode='bilinear')
|
||||
image_feature = nn.functional.interpolate(
|
||||
image_features, size=scaled_shape, mode="bilinear"
|
||||
)
|
||||
image_feature = image_feature.permute(0, 2, 3, 1)
|
||||
image_feature = image_feature.view(batch_frames, -1, dim)
|
||||
return image_feature
|
||||
@@ -818,10 +886,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
|
||||
**kwargs)
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not mm_input_by_modality:
|
||||
return []
|
||||
return None
|
||||
@@ -860,10 +926,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
hidden_states = self.language_model.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -873,7 +938,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
Reference in New Issue
Block a user