Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -22,13 +22,19 @@ from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
InputProcessingContext,
PromptReplacement, PromptUpdate)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
InputProcessingContext,
PromptReplacement,
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -36,8 +42,12 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix)
from .utils import (
AutoWeightsLoader,
flatten_bn,
init_vllm_registered_model,
maybe_prefix,
)
from .vision import get_vision_encoder_info
EOT = "<|endofturn|>"
@@ -48,8 +58,8 @@ VIDEO_TOKEN: str = "<|_unuse_missing_100270|>"
# Based on combine_frames_into_images in
# https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py
def get_num_combined_frames(
num_frames: int,
max_grid_shape: tuple[int, int] = (3, 3),
num_frames: int,
max_grid_shape: tuple[int, int] = (3, 3),
) -> int:
max_num_grids = max_grid_shape[0] * max_grid_shape[1]
@@ -69,10 +79,11 @@ class HCXVisionImagePixelInputs(TensorSchema):
- h: Height
- w: Width
"""
type: Literal["pixel_values"] = "pixel_values"
pixel_values_images: Annotated[
list[torch.Tensor],
TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"})]
list[torch.Tensor], TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"})
]
image_sizes_images: Annotated[torch.Tensor, TensorShape("n", 2)]
@@ -89,17 +100,18 @@ class HCXVisionVideoPixelInputs(TensorSchema):
- h: Height
- w: Width
"""
type: Literal["pixel_values_videos"] = "pixel_values_videos"
pixel_values_videos: Annotated[
list[list[torch.Tensor]],
TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"})]
TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"}),
]
HCXVisionVideoInputs = HCXVisionVideoPixelInputs
class HCXVisionProcessingInfo(BaseProcessingInfo):
def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config())
@@ -140,15 +152,14 @@ class HCXVisionProcessingInfo(BaseProcessingInfo):
)
class HCXVisionDummyInputsBuilder(
BaseDummyInputsBuilder[HCXVisionProcessingInfo]):
class HCXVisionDummyInputsBuilder(BaseDummyInputsBuilder[HCXVisionProcessingInfo]):
def get_dummy_text(
self,
mm_counts: Mapping[str, int],
) -> str:
dummy_text = IMAGE_TOKEN * mm_counts.get(
"image", 0) + VIDEO_TOKEN * mm_counts.get("video", 0)
"image", 0
) + VIDEO_TOKEN * mm_counts.get("video", 0)
return dummy_text
def get_dummy_mm_data(
@@ -160,35 +171,30 @@ class HCXVisionDummyInputsBuilder(
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_width, target_height = self.info.get_image_size_with_most_features()
target_num_frames = 32
image_overrides = mm_options.get("image") if mm_options else None
video_overrides = mm_options.get("video") if mm_options else None
return {
"image":
self._get_dummy_images(
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides,
),
"video":
self._get_dummy_videos(
"video": self._get_dummy_videos(
width=target_width - 1,
height=target_height - 1,
num_frames=target_num_frames,
num_videos=num_videos,
overrides=video_overrides,
)
),
}
class HCXVisionMultiModalProcessor(
BaseMultiModalProcessor[HCXVisionProcessingInfo]):
class HCXVisionMultiModalProcessor(BaseMultiModalProcessor[HCXVisionProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
@@ -230,26 +236,31 @@ class HCXVisionMultiModalProcessor(
if images:
_processed_outputs["image_sizes_images"] = torch.tensor(
_processed_outputs["image_sizes_images"])
_processed_outputs[
"vision_query_lengths_images"] = torch.tensor(
_processed_outputs["vision_query_lengths_images"])
_processed_outputs["image_sizes_images"]
)
_processed_outputs["vision_query_lengths_images"] = torch.tensor(
_processed_outputs["vision_query_lengths_images"]
)
if videos:
_idx_per_video = [
0, *accumulate(
get_num_combined_frames(len(video))
for video in videos)
0,
*accumulate(
get_num_combined_frames(len(video)) for video in videos
),
]
_processed_outputs["pixel_values_videos"] = [
_processed_outputs["pixel_values_videos"]
[_idx_per_video[i]:_idx_per_video[i + 1]]
_processed_outputs["pixel_values_videos"][
_idx_per_video[i] : _idx_per_video[i + 1]
]
for i in range(len(videos))
]
_processed_outputs["vision_query_lengths_videos"] = [
torch.tensor(
_processed_outputs["vision_query_lengths_videos"]
[_idx_per_video[i]:_idx_per_video[i + 1]])
_processed_outputs["vision_query_lengths_videos"][
_idx_per_video[i] : _idx_per_video[i + 1]
]
)
for i in range(len(videos))
]
@@ -287,12 +298,10 @@ class HCXVisionMultiModalProcessor(
if modality == "image":
lens = out_item["vision_query_lengths_images"].data.tolist()
num_tokens = self.info.get_num_image_tokens(
vision_query_length=lens)
num_tokens = self.info.get_num_image_tokens(vision_query_length=lens)
elif modality == "video":
lens = out_item["vision_query_lengths_videos"].data.tolist()
num_tokens = self.info.get_num_video_tokens(
vision_query_length=lens)
num_tokens = self.info.get_num_video_tokens(vision_query_length=lens)
else:
raise NotImplementedError(modality)
@@ -309,7 +318,8 @@ class HCXVisionMultiModalProcessor(
modality=modality,
out_mm_kwargs=out_mm_kwargs,
),
) for modality in ("image", "video")
)
for modality in ("image", "video")
]
def _get_mm_fields_config(
@@ -327,7 +337,8 @@ class HCXVisionMultiModalProcessor(
def _build_hcxvision_hf_info(
ctx: InputProcessingContext, ) -> HCXVisionProcessingInfo:
ctx: InputProcessingContext,
) -> HCXVisionProcessingInfo:
return HCXVisionProcessingInfo(ctx)
@@ -385,7 +396,6 @@ def init_vision_tower_for_hcxvision(
class HCXVisionMlp(nn.Module):
def __init__(
self,
mm_projector_type,
@@ -407,8 +417,9 @@ class HCXVisionMlp(nn.Module):
self.act = act_layer()
self.fc2 = nn.Linear(2 * hidden_features, out_features)
else:
raise NotImplementedError("{} is not implemented".format(
self.mm_projector_type))
raise NotImplementedError(
"{} is not implemented".format(self.mm_projector_type)
)
def forward(self, x):
x = self.fc1(x)
@@ -420,7 +431,7 @@ class HCXVisionMlp(nn.Module):
class HCXVisionCAbstractor(nn.Module):
"""
This module is based on C-Abstractor, whose license is under apache-2.0.
You can check the original code at
You can check the original code at
https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py
and we made necessary modifications.
"""
@@ -442,7 +453,8 @@ class HCXVisionCAbstractor(nn.Module):
# Positional embedding
if pos_emb:
self.pos_emb = torch.nn.Parameter(
torch.zeros(1, num_input_tokens, encoder_hidden_size))
torch.zeros(1, num_input_tokens, encoder_hidden_size)
)
self.pos_emb.data.normal_(mean=0.0, std=0.02)
else:
self.pos_emb = None
@@ -453,8 +465,9 @@ class HCXVisionCAbstractor(nn.Module):
else:
self.prenorm = None
self.build_net(num_queries, encoder_hidden_size, hidden_size,
output_hidden_size)
self.build_net(
num_queries, encoder_hidden_size, hidden_size, output_hidden_size
)
self.dtype = next(self.parameters()).dtype
def forward(
@@ -491,7 +504,8 @@ class HCXVisionCAbstractor(nn.Module):
if num_queries_vis_abstractors is not None:
assert num_grids is not None
return self._forward_adaptive_num_query(
x, num_queries_vis_abstractors, num_grids)
x, num_queries_vis_abstractors, num_grids
)
x = self.net(x)
x = rearrange(x, "b d h w -> b (h w) d")
@@ -512,7 +526,7 @@ class HCXVisionCAbstractor(nn.Module):
for i, num_queries in enumerate(num_queries_vis_abstractors):
hw = int(num_queries**0.5)
sampler = nn.AdaptiveAvgPool2d((hw, hw))
out = sampler(x[num_grids[i]:num_grids[i + 1], :])
out = sampler(x[num_grids[i] : num_grids[i + 1], :])
out = self.net[2](out) # s2
out = rearrange(out, "b d h w -> b (h w) d")
@@ -530,8 +544,9 @@ class HCXVisionCAbstractor(nn.Module):
depth: int = 3,
mlp_depth: int = 2,
):
assert (n_queries**0.5).is_integer(
), f"n_queries must be square number. n_queries: {n_queries}"
assert (n_queries**0.5).is_integer(), (
f"n_queries must be square number. n_queries: {n_queries}"
)
hw = int(n_queries**0.5)
# RegBlock = ResBlock + SE
@@ -556,8 +571,7 @@ class HCXVisionCAbstractor(nn.Module):
)
self.net = nn.Sequential(s1, sampler, s2)
self.readout = self.build_mlp(mlp_depth, hidden_size,
output_hidden_size)
self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size)
def build_mlp(
self,
@@ -575,13 +589,14 @@ class HCXVisionCAbstractor(nn.Module):
@MULTIMODAL_REGISTRY.register_processor(
_build_hcxvision_hf_processor,
info=_build_hcxvision_hf_info,
dummy_inputs=HCXVisionDummyInputsBuilder)
dummy_inputs=HCXVisionDummyInputsBuilder,
)
class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(
@@ -611,7 +626,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
## possible_resolution should be matched with preprocessor_config.json
config.possible_resolutions = self._init_possible_resolutions(
config, vision_config)
config, vision_config
)
# init models & parameters
with no_init_weights(): # weight will be loaded in from_pretrained
@@ -622,11 +638,11 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.mm_projector = self._init_mm_projector(config, text_config,
vision_config)
self.mm_projector = self._init_mm_projector(config, text_config, vision_config)
self.lm_head_vocab_size = getattr(text_config, "padded_vocab_size",
text_config.vocab_size)
self.lm_head_vocab_size = getattr(
text_config, "padded_vocab_size", text_config.vocab_size
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=text_config,
@@ -635,7 +651,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
if config.anyres:
self.image_newline = nn.Parameter(
torch.empty(text_config.hidden_size, dtype=self.dtype))
torch.empty(text_config.hidden_size, dtype=self.dtype)
)
self.config = config
self.vision_config = vision_config
@@ -679,7 +696,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return None
return HCXVisionVideoPixelInputs(
pixel_values_videos=pixel_values_videos, )
pixel_values_videos=pixel_values_videos,
)
def _process_image_input(
self,
@@ -695,7 +713,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
video_input: HCXVisionVideoInputs,
) -> tuple[torch.Tensor, ...]:
return self.forward_videos(
pixel_values_videos=video_input["pixel_values_videos"], )
pixel_values_videos=video_input["pixel_values_videos"],
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
@@ -703,14 +722,10 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if (input_key == "pixel_values_images"
and "images" not in modalities):
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if (input_key == "pixel_values_videos"
and "videos" not in modalities):
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
if input_key == "pixel_values_images" and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(**kwargs)
if input_key == "pixel_values_videos" and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
return modalities
@@ -754,10 +769,9 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds)
hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)
return hidden_states
def forward_images(
@@ -768,24 +782,21 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True)
visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
image_forward_outs = self.vision_model(
pixel_values_image_flat)[:, visual_token_idx:]
image_forward_outs = self.vision_model(pixel_values_image_flat)[
:, visual_token_idx:
]
image_forward_outs = image_forward_outs.to(
dtype=self.mm_projector.dtype)
image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype)
image_forward_outs = self.mm_projector(image_forward_outs) # b (h w) d
split_sizes = [len(item) for item in pixel_values_images]
image_forward_outs = torch.split(image_forward_outs,
split_sizes,
dim=0)
image_forward_outs = torch.split(image_forward_outs, split_sizes, dim=0)
# newline for anyres postprocessing
image_features = anyres_postprocessing(
image_forward_outs=image_forward_outs,
image_sizes=image_sizes_images.tolist(),
num_queries_vis_abstractor=self.config.
num_queries_vis_abstractor_image,
num_queries_vis_abstractor=self.config.num_queries_vis_abstractor_image,
unpad=self.config.unpad,
patch_size=self.vision_config.patch_size,
grid_size=self.vision_config.image_size,
@@ -805,11 +816,11 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
)
visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
video_forward_outs = self.vision_model(
pixel_values_videos_flat)[:, visual_token_idx:]
video_forward_outs = self.vision_model(pixel_values_videos_flat)[
:, visual_token_idx:
]
video_forward_outs = video_forward_outs.to(
dtype=self.mm_projector.dtype)
video_forward_outs = video_forward_outs.to(dtype=self.mm_projector.dtype)
# Run MM-Projector
# len(num_grids) == len(num_queries_vis_abstractors) + 1
@@ -817,8 +828,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
num_grids = [
grid_idx
] # e.g. [0, 9, 18, 19, 27, 28, 36, 37, 45, 46, 54, 55, 56]
num_queries_vis_abstractors = [
] # e.g. [81, 81, 81, 9, 81, 9, 81, 9, 81, 9, 81, 9]
num_queries_vis_abstractors = [] # e.g. [81, 81, 81, 9, 81, 9, 81, 9, 81, 9, 81, 9]
len_total_frames = video_forward_outs.shape[0]
if self.config.first_last_frames_slow:
@@ -826,22 +836,26 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
assert len_total_frames != 0
if len_total_frames <= 2:
num_queries_vis_abstractors.append(
self.config.num_queries_vis_abstractor_video_slow)
self.config.num_queries_vis_abstractor_video_slow
)
grid_idx += len_total_frames
num_grids.append(grid_idx)
else:
num_queries_vis_abstractors.append(
self.config.num_queries_vis_abstractor_video_slow)
self.config.num_queries_vis_abstractor_video_slow
)
grid_idx += 1
num_grids.append(grid_idx)
num_queries_vis_abstractors.append(
self.config.num_queries_vis_abstractor_video_fast)
self.config.num_queries_vis_abstractor_video_fast
)
grid_idx += len_total_frames - 2
num_grids.append(grid_idx)
num_queries_vis_abstractors.append(
self.config.num_queries_vis_abstractor_video_slow)
self.config.num_queries_vis_abstractor_video_slow
)
grid_idx += 1
num_grids.append(grid_idx)
else:
@@ -850,17 +864,19 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
for pixel_values_frame in pixel_values_frames:
if len(pixel_values_frame) > 0:
num_queries_vis_abstractors.append(
self.config.num_queries_vis_abstractor_video_slow)
self.config.num_queries_vis_abstractor_video_slow
)
grid_idx += 1
num_grids.append(grid_idx)
num_queries_vis_abstractors.append(
self.config.num_queries_vis_abstractor_video_fast)
self.config.num_queries_vis_abstractor_video_fast
)
grid_idx = grid_idx + len(pixel_values_frame) - 1
num_grids.append(grid_idx)
video_forward_outs = self.mm_projector(video_forward_outs,
num_queries_vis_abstractors,
num_grids)
video_forward_outs = self.mm_projector(
video_forward_outs, num_queries_vis_abstractors, num_grids
)
video_features = [] # what we want to return
target_features = []
@@ -882,18 +898,19 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
target_group_size = 0
elif video_group_size < target_group_size:
raise RuntimeError(
f"{video_group_size=} < {target_group_size=}")
raise RuntimeError(f"{video_group_size=} < {target_group_size=}")
assert len(target_features
) == 0, f"target_features is not empty!! {target_features}"
assert len(target_features) == 0, (
f"target_features is not empty!! {target_features}"
)
assert len(video_groups) == len(video_features)
feats_per_video = [len(video) for video in pixel_values_videos]
idxs_per_video = [0, *accumulate(feats_per_video)]
return tuple(
torch.cat(video_features[idxs_per_video[i]:idxs_per_video[i + 1]])
for i in range(len(feats_per_video)))
torch.cat(video_features[idxs_per_video[i] : idxs_per_video[i + 1]])
for i in range(len(feats_per_video))
)
def _prepare_multimodal_kwargs(self, **kwargs: object):
output = defaultdict(list)
@@ -902,7 +919,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
continue # if empty batch of empty sample
new_k, is_video = k, False
if (not k.endswith("_images") and not k.endswith("_videos")):
if not k.endswith("_images") and not k.endswith("_videos"):
pass
else:
new_k, is_video = k.split("_")[:-1], k.split("_")[-1]
@@ -955,10 +972,10 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
if i * j <= config.max_num_grids:
possible_resolutions.append([i, j])
possible_resolutions = [[
ys * vision_config.image_size,
xs * vision_config.image_size
] for ys, xs in possible_resolutions]
possible_resolutions = [
[ys * vision_config.image_size, xs * vision_config.image_size]
for ys, xs in possible_resolutions
]
return possible_resolutions
else:
return config.possible_resolutions
@@ -971,14 +988,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
):
input_hidden_size = vision_config.hidden_size
if config.mm_projector_type == "linear":
mm_projector = nn.Linear(input_hidden_size,
text_config.hidden_size)
mm_projector = nn.Linear(input_hidden_size, text_config.hidden_size)
mm_projector.dtype = next(mm_projector.parameters()).dtype
elif config.mm_projector_type == "cabstractor":
mm_projector = HCXVisionCAbstractor(
num_queries=config.num_queries_vis_abstractor_image,
num_input_tokens=(vision_config.image_size //
vision_config.patch_size)**2,
num_input_tokens=(vision_config.image_size // vision_config.patch_size)
** 2,
encoder_hidden_size=input_hidden_size,
hidden_size=input_hidden_size,
output_hidden_size=text_config.hidden_size,
@@ -995,8 +1011,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return mm_projector
def unpad_image(tensor: torch.Tensor,
original_size: tuple[int, int]) -> torch.Tensor:
def unpad_image(tensor: torch.Tensor, original_size: tuple[int, int]) -> torch.Tensor:
original_width, original_height = original_size
current_height, current_width = tensor.shape[1:]
@@ -1007,18 +1022,17 @@ def unpad_image(tensor: torch.Tensor,
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding:current_height - padding, :]
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding:current_width - padding]
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor
def select_best_resolution(original_size: tuple,
possible_resolutions: list) -> tuple:
def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
original_height, original_width = original_size
best_fit = None
max_effective_resolution = 0
@@ -1026,15 +1040,19 @@ def select_best_resolution(original_size: tuple,
for height, width in possible_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(
original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height,
original_width * original_height)
downscaled_width, downscaled_height = (
int(original_width * scale),
int(original_height * scale),
)
effective_resolution = min(
downscaled_width * downscaled_height, original_width * original_height
)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (
effective_resolution == max_effective_resolution
and wasted_resolution < min_wasted_resolution):
effective_resolution == max_effective_resolution
and wasted_resolution < min_wasted_resolution
):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (height, width)
@@ -1047,12 +1065,16 @@ def get_anyres_image_grid_shape(
grid_pinpoints: Union[str, list[tuple[int, int]]],
patch_size: int,
) -> tuple[int, int]:
possible_resolutions = grid_pinpoints if isinstance(
grid_pinpoints, list) else ast.literal_eval(grid_pinpoints)
possible_resolutions = (
grid_pinpoints
if isinstance(grid_pinpoints, list)
else ast.literal_eval(grid_pinpoints)
)
original_width, original_height = image_size
height, width = select_best_resolution((original_height, original_width),
possible_resolutions)
height, width = select_best_resolution(
(original_height, original_width), possible_resolutions
)
return width // patch_size, height // patch_size
@@ -1070,12 +1092,15 @@ def reshape_and_unpad_image_features(
image_feature = image_feature[1:]
assert height * width == base_image_feature.shape[0], (
f"{height=} * {width=} != {base_image_feature.shape[0]=}")
f"{height=} * {width=} != {base_image_feature.shape[0]=}"
)
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_size, possible_resolutions, grid_size)
image_feature = image_feature.view(num_patch_height, num_patch_width,
height, width, -1)
image_size, possible_resolutions, grid_size
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
if unpad:
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
@@ -1084,8 +1109,9 @@ def reshape_and_unpad_image_features(
image_feature = torch.cat(
(
image_feature,
image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1).to(image_feature.device),
image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device),
),
dim=-1,
)
@@ -1111,8 +1137,9 @@ def anyres_postprocessing(
height = width = grid_size // patch_size
if num_queries_vis_abstractor > 0:
assert (num_queries_vis_abstractor**0.5
).is_integer(), "n_queries must be square number"
assert (num_queries_vis_abstractor**0.5).is_integer(), (
"n_queries must be square number"
)
height = width = int(num_queries_vis_abstractor**0.5)
# post-processing (unpad, add newline)
@@ -1132,8 +1159,8 @@ def anyres_postprocessing(
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, image_newline[None].to(image_feature.device)),
dim=0)
(image_feature, image_newline[None].to(image_feature.device)), dim=0
)
new_image_features.append(image_feature)
return new_image_features