[Model] Support nested structures for TensorSchema (#26212)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -2,27 +2,16 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# copied from : https://github.com/huggingface/transformers
|
||||
import ast
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import Any, Literal, Optional, TypedDict, Union
|
||||
from itertools import accumulate
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import typing
|
||||
Unpack = typing.Unpack
|
||||
else:
|
||||
import typing_extensions
|
||||
Unpack = typing_extensions.Unpack
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from timm.layers import LayerNorm, LayerNorm2d
|
||||
from timm.models.regnet import RegStage
|
||||
from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
|
||||
@@ -42,11 +31,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptReplacement, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix)
|
||||
from .vision import get_vision_encoder_info
|
||||
|
||||
EOT = "<|endofturn|>"
|
||||
@@ -69,28 +60,42 @@ def get_num_combined_frames(
|
||||
return num_canvases + (leftover_frames > 0)
|
||||
|
||||
|
||||
class HCXVisionMultimodalPixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
pixel_values_images: list[torch.Tensor]
|
||||
class HCXVisionImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Shape: `[(num_grids, num_channels, height, width), ...]` if anyres
|
||||
|
||||
Note that `height` or `width` may be different per batch and image,
|
||||
in which case the data is passed as a list instead of a batched tensor.
|
||||
Dimensions:
|
||||
- n: Number of images
|
||||
- g: Number of grids
|
||||
- c: Number of channels (3)
|
||||
- h: Height
|
||||
- w: Width
|
||||
"""
|
||||
image_sizes_images: list[tuple[Union[int, float]]]
|
||||
"""
|
||||
Shape: `[(height, width), ...]`
|
||||
"""
|
||||
vision_query_lengths_images: list[Union[int, float]]
|
||||
pixel_values_videos: list[tuple[Union[int, float]]]
|
||||
"""
|
||||
Shape: `[(num_grids, num_channels, height, width), ...]` if anyres
|
||||
"""
|
||||
vision_query_lengths_videos: list[Union[int, float]]
|
||||
type: Literal["pixel_values"] = "pixel_values"
|
||||
pixel_values_images: Annotated[
|
||||
list[torch.Tensor],
|
||||
TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"})]
|
||||
image_sizes_images: Annotated[torch.Tensor, TensorShape("n", 2)]
|
||||
|
||||
|
||||
HCXVisionMultimodalInputs = Union[HCXVisionMultimodalPixelInputs]
|
||||
HCXVisionImageInputs = HCXVisionImagePixelInputs
|
||||
|
||||
|
||||
class HCXVisionVideoPixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- n: Number of videos
|
||||
- f: Number of frames
|
||||
- g: Number of grids
|
||||
- c: Number of channels (3)
|
||||
- h: Height
|
||||
- w: Width
|
||||
"""
|
||||
type: Literal["pixel_values_videos"] = "pixel_values_videos"
|
||||
pixel_values_videos: Annotated[
|
||||
list[list[torch.Tensor]],
|
||||
TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"})]
|
||||
|
||||
|
||||
HCXVisionVideoInputs = HCXVisionVideoPixelInputs
|
||||
|
||||
|
||||
class HCXVisionProcessingInfo(BaseProcessingInfo):
|
||||
@@ -191,27 +196,9 @@ class HCXVisionMultiModalProcessor(
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
|
||||
def replace_multimodal_token(
|
||||
token_ids: torch.Tensor,
|
||||
target_token: int,
|
||||
repeats: list[int],
|
||||
):
|
||||
output = list[int]()
|
||||
_repeats_idx = 0
|
||||
for token_id in token_ids:
|
||||
if token_id == target_token:
|
||||
output += [token_id.item()] * repeats[_repeats_idx]
|
||||
_repeats_idx += 1
|
||||
else:
|
||||
output += [token_id.item()]
|
||||
|
||||
return torch.tensor(output, device=token_ids.device)
|
||||
|
||||
for video_idx, video_arr in enumerate(mm_data.get("videos", [])):
|
||||
if video_arr.dtype == np.uint8:
|
||||
continue
|
||||
mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
|
||||
if video_arr.dtype != np.uint8:
|
||||
mm_data["videos"][video_idx] = video_arr.astype(np.uint8)
|
||||
|
||||
processed_outputs = self.info.ctx.call_hf_processor(
|
||||
hf_processor=self.info.get_hf_processor(**mm_kwargs),
|
||||
@@ -223,20 +210,16 @@ class HCXVisionMultiModalProcessor(
|
||||
) # text-only
|
||||
|
||||
if len(mm_data) > 0:
|
||||
images = mm_data.get("images")
|
||||
videos = mm_data.get("videos")
|
||||
|
||||
# batchify input as a single item
|
||||
images = mm_data.get("images", None)
|
||||
batched_images = None if images is None else [images]
|
||||
|
||||
# list of video in single conversation
|
||||
videos = mm_data.get("videos", None)
|
||||
batched_videos = None if videos is None else [videos]
|
||||
|
||||
_processed_outputs = self.info.ctx.call_hf_processor(
|
||||
hf_processor=self.info.get_hf_processor(**mm_kwargs),
|
||||
data=dict(
|
||||
text=None,
|
||||
images=batched_images,
|
||||
videos=batched_videos,
|
||||
images=None if images is None else [images],
|
||||
videos=None if videos is None else [videos],
|
||||
),
|
||||
) # mm-only
|
||||
|
||||
@@ -246,51 +229,43 @@ class HCXVisionMultiModalProcessor(
|
||||
_processed_outputs[k] = v[0]
|
||||
|
||||
if images:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
|
||||
processed_outputs["input_ids"] = torch.stack([
|
||||
replace_multimodal_token(
|
||||
token_ids=_input_ids,
|
||||
target_token=image_token_id,
|
||||
repeats=_processed_outputs[
|
||||
"vision_query_lengths_images"],
|
||||
) for _input_ids in processed_outputs["input_ids"]
|
||||
],
|
||||
dim=0)
|
||||
_processed_outputs["image_sizes_images"] = torch.tensor(
|
||||
_processed_outputs["image_sizes_images"])
|
||||
_processed_outputs[
|
||||
"vision_query_lengths_images"] = torch.tensor(
|
||||
_processed_outputs["vision_query_lengths_images"])
|
||||
|
||||
if videos:
|
||||
_num_per_videos = [
|
||||
get_num_combined_frames(len(video)) for video in videos
|
||||
_idx_per_video = [
|
||||
0, *accumulate(
|
||||
get_num_combined_frames(len(video))
|
||||
for video in videos)
|
||||
]
|
||||
_processed_outputs["pixel_values_videos"] = [
|
||||
_processed_outputs["pixel_values_videos"]
|
||||
[sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
|
||||
for _i in range(len(videos))
|
||||
[_idx_per_video[i]:_idx_per_video[i + 1]]
|
||||
for i in range(len(videos))
|
||||
]
|
||||
_processed_outputs["vision_query_lengths_videos"] = [
|
||||
_processed_outputs["vision_query_lengths_videos"]
|
||||
[sum(_num_per_videos[:_i]):sum(_num_per_videos[:_i + 1])]
|
||||
for _i in range(len(videos))
|
||||
torch.tensor(
|
||||
_processed_outputs["vision_query_lengths_videos"]
|
||||
[_idx_per_video[i]:_idx_per_video[i + 1]])
|
||||
for i in range(len(videos))
|
||||
]
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
video_token_id = tokenizer.convert_tokens_to_ids(VIDEO_TOKEN)
|
||||
processed_outputs["input_ids"] = torch.stack([
|
||||
replace_multimodal_token(
|
||||
token_ids=_input_ids,
|
||||
target_token=video_token_id,
|
||||
repeats=[
|
||||
sum(lens) for lens in
|
||||
_processed_outputs["vision_query_lengths_videos"]
|
||||
],
|
||||
) for _input_ids in processed_outputs["input_ids"]
|
||||
],
|
||||
dim=0)
|
||||
|
||||
processed_outputs.update(_processed_outputs)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _hf_processor_applies_updates(
|
||||
self,
|
||||
prompt_text: str,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
@@ -311,11 +286,11 @@ class HCXVisionMultiModalProcessor(
|
||||
out_item = out_mm_kwargs[modality][item_idx]
|
||||
|
||||
if modality == "image":
|
||||
lens = out_item["vision_query_lengths_images"].data
|
||||
lens = out_item["vision_query_lengths_images"].data.tolist()
|
||||
num_tokens = self.info.get_num_image_tokens(
|
||||
vision_query_length=lens)
|
||||
elif modality == "video":
|
||||
lens = out_item["vision_query_lengths_videos"].data
|
||||
lens = out_item["vision_query_lengths_videos"].data.tolist()
|
||||
num_tokens = self.info.get_num_video_tokens(
|
||||
vision_query_length=lens)
|
||||
else:
|
||||
@@ -343,26 +318,11 @@ class HCXVisionMultiModalProcessor(
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
# image
|
||||
pixel_values_images=MultiModalFieldConfig.batched("image"),
|
||||
image_sizes_images=MultiModalFieldConfig.batched("image"),
|
||||
vision_query_lengths_images=MultiModalFieldConfig.batched("image"),
|
||||
num_queries_vis_abstractors_images=MultiModalFieldConfig.batched(
|
||||
"image"),
|
||||
num_queries_vis_abstractors_slow_images=MultiModalFieldConfig.
|
||||
batched("image"),
|
||||
first_last_frames_slows_images=MultiModalFieldConfig.batched(
|
||||
"image"),
|
||||
# video
|
||||
pixel_values_videos=MultiModalFieldConfig.batched("video"),
|
||||
image_sizes_videos=MultiModalFieldConfig.batched("video"),
|
||||
vision_query_lengths_videos=MultiModalFieldConfig.batched("video"),
|
||||
num_queries_vis_abstractors_videos=MultiModalFieldConfig.batched(
|
||||
"video"),
|
||||
num_queries_vis_abstractors_slow_videos=MultiModalFieldConfig.
|
||||
batched("video"),
|
||||
first_last_frames_slows_videos=MultiModalFieldConfig.batched(
|
||||
"video"),
|
||||
)
|
||||
|
||||
|
||||
@@ -617,6 +577,7 @@ class HCXVisionCAbstractor(nn.Module):
|
||||
info=_build_hcxvision_hf_info,
|
||||
dummy_inputs=HCXVisionDummyInputsBuilder)
|
||||
class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
merge_by_field_config = True
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
@@ -692,55 +653,94 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
raise ValueError("Only image or video modality is supported")
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self,
|
||||
**kwargs: object,
|
||||
) -> Optional[HCXVisionImageInputs]:
|
||||
pixel_values_images = kwargs.pop("pixel_values_images", None)
|
||||
|
||||
if pixel_values_images is None:
|
||||
return None
|
||||
|
||||
image_sizes_images = kwargs.pop("image_sizes_images")
|
||||
|
||||
return HCXVisionImagePixelInputs(
|
||||
pixel_values_images=pixel_values_images,
|
||||
image_sizes_images=image_sizes_images,
|
||||
)
|
||||
|
||||
def _parse_and_validate_video_input(
|
||||
self,
|
||||
**kwargs: object,
|
||||
) -> Optional[HCXVisionVideoInputs]:
|
||||
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
|
||||
|
||||
if pixel_values_videos is None:
|
||||
return None
|
||||
|
||||
return HCXVisionVideoPixelInputs(
|
||||
pixel_values_videos=pixel_values_videos, )
|
||||
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: HCXVisionImageInputs,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
return self.forward_images(
|
||||
pixel_values_images=image_input["pixel_values_images"],
|
||||
image_sizes_images=image_input["image_sizes_images"],
|
||||
)
|
||||
|
||||
def _process_video_input(
|
||||
self,
|
||||
video_input: HCXVisionVideoInputs,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
return self.forward_videos(
|
||||
pixel_values_videos=video_input["pixel_values_videos"], )
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
modalities = {}
|
||||
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if (input_key == "pixel_values_images"
|
||||
and "images" not in modalities):
|
||||
modalities["images"] = self._parse_and_validate_image_input(
|
||||
**kwargs)
|
||||
if (input_key == "pixel_values_videos"
|
||||
and "videos" not in modalities):
|
||||
modalities["videos"] = self._parse_and_validate_video_input(
|
||||
**kwargs)
|
||||
|
||||
return modalities
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self,
|
||||
**kwargs: Unpack[HCXVisionMultimodalInputs],
|
||||
**kwargs: object,
|
||||
) -> MultiModalEmbeddings:
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
return []
|
||||
|
||||
multimodal_embeddings = list()
|
||||
if kwargs.get("pixel_values_images") is not None:
|
||||
for _pixel_values_images, _image_sizes_images in zip(
|
||||
kwargs["pixel_values_images"],
|
||||
kwargs["image_sizes_images"]):
|
||||
_pixel_values_images = _pixel_values_images.unsqueeze(dim=0)
|
||||
_image_sizes_images = _image_sizes_images.unsqueeze(dim=0)
|
||||
_len_pixel_values_images = [
|
||||
len(pixel_value) for pixel_value in _pixel_values_images
|
||||
]
|
||||
if isinstance(_image_sizes_images, torch.Tensor):
|
||||
_image_sizes_images = _image_sizes_images.detach().cpu(
|
||||
).tolist()
|
||||
_multimodal_embeddings_images = self.forward_images(
|
||||
pixel_values_images=_pixel_values_images,
|
||||
image_sizes_images=_image_sizes_images,
|
||||
len_pixel_values_images=_len_pixel_values_images,
|
||||
)
|
||||
_multimodal_embeddings_images = torch.cat(
|
||||
_multimodal_embeddings_images, dim=0)
|
||||
multimodal_embeddings.append(_multimodal_embeddings_images)
|
||||
# The result multimodal_embeddings is tuple of tensors, with each
|
||||
# tensor correspoending to a multimodal data item (image or video).
|
||||
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||
|
||||
# NOTE: It is important to iterate over the keys in this dictionary
|
||||
# to preserve the order of the modalities.
|
||||
for modality in modalities:
|
||||
if modality == "images":
|
||||
image_input = modalities["images"]
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
multimodal_embeddings += vision_embeddings
|
||||
if modality == "videos":
|
||||
video_input = modalities["videos"]
|
||||
video_embeddings = self._process_video_input(video_input)
|
||||
multimodal_embeddings += video_embeddings
|
||||
|
||||
if kwargs.get("pixel_values_videos") is not None:
|
||||
for _pixel_values_videos, _vision_query_lengths_videos in zip(
|
||||
kwargs["pixel_values_videos"],
|
||||
kwargs["vision_query_lengths_videos"]):
|
||||
_len_pixel_values_videos = [
|
||||
len(_vision_query_lengths)
|
||||
for _vision_query_lengths in _vision_query_lengths_videos
|
||||
]
|
||||
_c, _w, _h = _pixel_values_videos.shape[-3:]
|
||||
_pixel_values_videos = _pixel_values_videos.reshape(
|
||||
sum(_len_pixel_values_videos), -1, _c, _w,
|
||||
_h).unsqueeze(dim=0)
|
||||
_multimodal_embeddings_videos = self.forward_videos(
|
||||
pixel_values_videos=_pixel_values_videos,
|
||||
len_pixel_values_videos=_len_pixel_values_videos,
|
||||
)
|
||||
_multimodal_embeddings_videos = torch.cat(
|
||||
_multimodal_embeddings_videos, dim=0)
|
||||
multimodal_embeddings.append(_multimodal_embeddings_videos)
|
||||
return multimodal_embeddings
|
||||
|
||||
def forward(
|
||||
@@ -762,28 +762,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def forward_images(
|
||||
self,
|
||||
pixel_values_images: list[list[torch.FloatTensor]],
|
||||
image_sizes_images: list[list[tuple[int, int]]],
|
||||
len_pixel_values_images: list[int],
|
||||
) -> list[list[torch.Tensor]]:
|
||||
if sum(len_pixel_values_images) == 0:
|
||||
return None
|
||||
|
||||
concat_pixel_values_images = torch.cat(list(
|
||||
chain(*pixel_values_images)),
|
||||
dim=0)
|
||||
pixel_values_images: list[torch.Tensor],
|
||||
image_sizes_images: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True)
|
||||
|
||||
visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
|
||||
image_forward_outs = self.vision_model(
|
||||
concat_pixel_values_images)[:, visual_token_idx:]
|
||||
pixel_values_image_flat)[:, visual_token_idx:]
|
||||
|
||||
image_forward_outs = image_forward_outs.to(
|
||||
dtype=self.mm_projector.dtype)
|
||||
image_forward_outs = self.mm_projector(image_forward_outs) # b (h w) d
|
||||
|
||||
split_sizes = [
|
||||
pixel_value.shape[0] for pixel_value in chain(*pixel_values_images)
|
||||
]
|
||||
split_sizes = [len(item) for item in pixel_values_images]
|
||||
image_forward_outs = torch.split(image_forward_outs,
|
||||
split_sizes,
|
||||
dim=0)
|
||||
@@ -791,10 +783,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# newline for anyres postprocessing
|
||||
image_features = anyres_postprocessing(
|
||||
image_forward_outs=image_forward_outs,
|
||||
image_sizes=[
|
||||
image_size for image_sizes in image_sizes_images
|
||||
for image_size in image_sizes
|
||||
],
|
||||
image_sizes=image_sizes_images.tolist(),
|
||||
num_queries_vis_abstractor=self.config.
|
||||
num_queries_vis_abstractor_image,
|
||||
unpad=self.config.unpad,
|
||||
@@ -803,26 +792,21 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
image_newline=self.image_newline,
|
||||
possible_resolutions=self.config.possible_resolutions,
|
||||
)
|
||||
return image_features
|
||||
|
||||
return tuple(image_features)
|
||||
|
||||
def forward_videos(
|
||||
self,
|
||||
pixel_values_videos: list[list[torch.FloatTensor]],
|
||||
len_pixel_values_videos: list[int],
|
||||
) -> list[torch.Tensor]:
|
||||
|
||||
len_video_grids = sum(len_pixel_values_videos)
|
||||
if len_video_grids == 0:
|
||||
return None
|
||||
|
||||
# Run Vision Model
|
||||
concat_pixel_values_videos = torch.cat(list(
|
||||
chain(*pixel_values_videos)),
|
||||
dim=0)
|
||||
pixel_values_videos: list[list[torch.Tensor]],
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
pixel_values_videos_flat = flatten_bn(
|
||||
[frame for frames in pixel_values_videos for frame in frames],
|
||||
concat=True,
|
||||
)
|
||||
|
||||
visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
|
||||
video_forward_outs = self.vision_model(
|
||||
concat_pixel_values_videos)[:, visual_token_idx:]
|
||||
pixel_values_videos_flat)[:, visual_token_idx:]
|
||||
|
||||
video_forward_outs = video_forward_outs.to(
|
||||
dtype=self.mm_projector.dtype)
|
||||
@@ -905,7 +889,11 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
) == 0, f"target_features is not empty!! {target_features}"
|
||||
assert len(video_groups) == len(video_features)
|
||||
|
||||
return video_features
|
||||
feats_per_video = [len(video) for video in pixel_values_videos]
|
||||
idxs_per_video = [0, *accumulate(feats_per_video)]
|
||||
return tuple(
|
||||
torch.cat(video_features[idxs_per_video[i]:idxs_per_video[i + 1]])
|
||||
for i in range(len(feats_per_video)))
|
||||
|
||||
def _prepare_multimodal_kwargs(self, **kwargs: object):
|
||||
output = defaultdict(list)
|
||||
@@ -1111,15 +1099,15 @@ def reshape_and_unpad_image_features(
|
||||
|
||||
|
||||
def anyres_postprocessing(
|
||||
image_forward_outs: list[torch.FloatTensor],
|
||||
image_forward_outs: list[torch.Tensor],
|
||||
image_sizes: list[list[int]],
|
||||
possible_resolutions: list[tuple[int, int]],
|
||||
patch_size: int,
|
||||
grid_size: int,
|
||||
image_newline: torch.FloatTensor,
|
||||
image_newline: torch.Tensor,
|
||||
num_queries_vis_abstractor: int = -1,
|
||||
unpad: bool = False,
|
||||
) -> list[torch.FloatTensor]:
|
||||
) -> list[torch.Tensor]:
|
||||
height = width = grid_size // patch_size
|
||||
|
||||
if num_queries_vis_abstractor > 0:
|
||||
@@ -1147,26 +1135,5 @@ def anyres_postprocessing(
|
||||
(image_feature, image_newline[None].to(image_feature.device)),
|
||||
dim=0)
|
||||
new_image_features.append(image_feature)
|
||||
image_features = new_image_features
|
||||
return image_features
|
||||
|
||||
|
||||
def resize_image(
|
||||
image: Union[np.ndarray, PIL.Image.Image],
|
||||
max_side: int = 378,
|
||||
) -> np.ndarray:
|
||||
image_arr = image
|
||||
if isinstance(image, np.ndarray):
|
||||
image = Image.fromarray(image)
|
||||
|
||||
width, height = image.size
|
||||
cur_max_size = max(width, height)
|
||||
if cur_max_size <= max_side:
|
||||
return image_arr
|
||||
|
||||
scale = max_side / cur_max_size
|
||||
width = int(width * scale)
|
||||
height = int(height * scale)
|
||||
image = image.resize((width, height), Image.LANCZOS)
|
||||
image_arr = np.array(image)
|
||||
return image_arr
|
||||
return new_image_features
|
||||
|
||||
Reference in New Issue
Block a user