Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
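Background for reviewers: yapf and isort format code and sort imports in two separate passes, while ruff covers both jobs in one tool (`ruff format` for formatting, lint rule family "I" for isort-compatible import sorting). A minimal pyproject.toml sketch of such a setup is shown below; the exact rule selection and line length vLLM adopts may differ from this illustration:

    [tool.ruff]
    line-length = 88

    [tool.ruff.lint]
    # "I" enables ruff's isort-compatible import sorting,
    # replacing the standalone isort tool.
    extend-select = ["I"]

With a config along these lines, `ruff format .` takes the place of `yapf --in-place` and `ruff check --fix .` applies the import sorting; the hunks below are the mechanical result of that reformatting.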
@@ -22,13 +22,19 @@ from vllm.config.multimodal import BaseDummyOptions
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.cache import BaseMultiModalProcessorCache
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems)
+from vllm.multimodal.inputs import (
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+)
 from vllm.multimodal.parse import ImageSize, MultiModalDataItems
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo,
-                                        InputProcessingContext,
-                                        PromptReplacement, PromptUpdate)
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    InputProcessingContext,
+    PromptReplacement,
+    PromptUpdate,
+)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -36,8 +42,12 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix)
+from .utils import (
+    AutoWeightsLoader,
+    flatten_bn,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
 from .vision import get_vision_encoder_info

 EOT = "<|endofturn|>"
@@ -48,8 +58,8 @@ VIDEO_TOKEN: str = "<|_unuse_missing_100270|>"
 # Based on combine_frames_into_images in
 # https://huggingface.co/naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B/blob/main/processing_hyperclovax.py
 def get_num_combined_frames(
-        num_frames: int,
-        max_grid_shape: tuple[int, int] = (3, 3),
+    num_frames: int,
+    max_grid_shape: tuple[int, int] = (3, 3),
 ) -> int:
     max_num_grids = max_grid_shape[0] * max_grid_shape[1]

@@ -69,10 +79,11 @@ class HCXVisionImagePixelInputs(TensorSchema):
         - h: Height
         - w: Width
     """

     type: Literal["pixel_values"] = "pixel_values"
     pixel_values_images: Annotated[
-        list[torch.Tensor],
-        TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"})]
+        list[torch.Tensor], TensorShape("n", "g", 3, "h", "w", dynamic_dims={"g"})
+    ]
     image_sizes_images: Annotated[torch.Tensor, TensorShape("n", 2)]

@@ -89,17 +100,18 @@ class HCXVisionVideoPixelInputs(TensorSchema):
         - h: Height
         - w: Width
     """

     type: Literal["pixel_values_videos"] = "pixel_values_videos"
     pixel_values_videos: Annotated[
         list[list[torch.Tensor]],
-        TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"})]
+        TensorShape("n", "f", "g", 3, "h", "w", dynamic_dims={"f", "g"}),
+    ]


 HCXVisionVideoInputs = HCXVisionVideoPixelInputs


 class HCXVisionProcessingInfo(BaseProcessingInfo):
-
     def get_vision_encoder_info(self):
         return get_vision_encoder_info(self.get_hf_config())

@@ -140,15 +152,14 @@ class HCXVisionProcessingInfo(BaseProcessingInfo):
         )


-class HCXVisionDummyInputsBuilder(
-        BaseDummyInputsBuilder[HCXVisionProcessingInfo]):
-
+class HCXVisionDummyInputsBuilder(BaseDummyInputsBuilder[HCXVisionProcessingInfo]):
     def get_dummy_text(
         self,
         mm_counts: Mapping[str, int],
     ) -> str:
         dummy_text = IMAGE_TOKEN * mm_counts.get(
-            "image", 0) + VIDEO_TOKEN * mm_counts.get("video", 0)
+            "image", 0
+        ) + VIDEO_TOKEN * mm_counts.get("video", 0)
         return dummy_text

     def get_dummy_mm_data(
@@ -160,35 +171,30 @@ class HCXVisionDummyInputsBuilder(
         num_images = mm_counts.get("image", 0)
         num_videos = mm_counts.get("video", 0)

-        target_width, target_height = \
-            self.info.get_image_size_with_most_features()
+        target_width, target_height = self.info.get_image_size_with_most_features()
         target_num_frames = 32

         image_overrides = mm_options.get("image") if mm_options else None
         video_overrides = mm_options.get("video") if mm_options else None

         return {
-            "image":
-            self._get_dummy_images(
+            "image": self._get_dummy_images(
                 width=target_width,
                 height=target_height,
                 num_images=num_images,
                 overrides=image_overrides,
             ),
-            "video":
-            self._get_dummy_videos(
+            "video": self._get_dummy_videos(
                 width=target_width - 1,
                 height=target_height - 1,
                 num_frames=target_num_frames,
                 num_videos=num_videos,
                 overrides=video_overrides,
-            )
+            ),
         }


-class HCXVisionMultiModalProcessor(
-        BaseMultiModalProcessor[HCXVisionProcessingInfo]):
-
+class HCXVisionMultiModalProcessor(BaseMultiModalProcessor[HCXVisionProcessingInfo]):
     def _call_hf_processor(
         self,
         prompt: str,
@@ -230,26 +236,31 @@ class HCXVisionMultiModalProcessor(

         if images:
             _processed_outputs["image_sizes_images"] = torch.tensor(
-                _processed_outputs["image_sizes_images"])
-            _processed_outputs[
-                "vision_query_lengths_images"] = torch.tensor(
-                    _processed_outputs["vision_query_lengths_images"])
+                _processed_outputs["image_sizes_images"]
+            )
+            _processed_outputs["vision_query_lengths_images"] = torch.tensor(
+                _processed_outputs["vision_query_lengths_images"]
+            )

         if videos:
             _idx_per_video = [
-                0, *accumulate(
-                    get_num_combined_frames(len(video))
-                    for video in videos)
+                0,
+                *accumulate(
+                    get_num_combined_frames(len(video)) for video in videos
+                ),
             ]
             _processed_outputs["pixel_values_videos"] = [
-                _processed_outputs["pixel_values_videos"]
-                [_idx_per_video[i]:_idx_per_video[i + 1]]
+                _processed_outputs["pixel_values_videos"][
+                    _idx_per_video[i] : _idx_per_video[i + 1]
+                ]
                 for i in range(len(videos))
             ]
             _processed_outputs["vision_query_lengths_videos"] = [
                 torch.tensor(
-                    _processed_outputs["vision_query_lengths_videos"]
-                    [_idx_per_video[i]:_idx_per_video[i + 1]])
+                    _processed_outputs["vision_query_lengths_videos"][
+                        _idx_per_video[i] : _idx_per_video[i + 1]
+                    ]
+                )
                 for i in range(len(videos))
             ]

@@ -287,12 +298,10 @@ class HCXVisionMultiModalProcessor(

         if modality == "image":
             lens = out_item["vision_query_lengths_images"].data.tolist()
-            num_tokens = self.info.get_num_image_tokens(
-                vision_query_length=lens)
+            num_tokens = self.info.get_num_image_tokens(vision_query_length=lens)
         elif modality == "video":
             lens = out_item["vision_query_lengths_videos"].data.tolist()
-            num_tokens = self.info.get_num_video_tokens(
-                vision_query_length=lens)
+            num_tokens = self.info.get_num_video_tokens(vision_query_length=lens)
         else:
             raise NotImplementedError(modality)

@@ -309,7 +318,8 @@ class HCXVisionMultiModalProcessor(
                     modality=modality,
                     out_mm_kwargs=out_mm_kwargs,
                 ),
-            ) for modality in ("image", "video")
+            )
+            for modality in ("image", "video")
         ]

     def _get_mm_fields_config(
@@ -327,7 +337,8 @@ class HCXVisionMultiModalProcessor(


 def _build_hcxvision_hf_info(
-        ctx: InputProcessingContext, ) -> HCXVisionProcessingInfo:
+    ctx: InputProcessingContext,
+) -> HCXVisionProcessingInfo:
     return HCXVisionProcessingInfo(ctx)


@@ -385,7 +396,6 @@ def init_vision_tower_for_hcxvision(


 class HCXVisionMlp(nn.Module):
-
     def __init__(
         self,
         mm_projector_type,
@@ -407,8 +417,9 @@ class HCXVisionMlp(nn.Module):
             self.act = act_layer()
             self.fc2 = nn.Linear(2 * hidden_features, out_features)
         else:
-            raise NotImplementedError("{} is not implemented".format(
-                self.mm_projector_type))
+            raise NotImplementedError(
+                "{} is not implemented".format(self.mm_projector_type)
+            )

     def forward(self, x):
         x = self.fc1(x)
@@ -420,7 +431,7 @@ class HCXVisionMlp(nn.Module):
 class HCXVisionCAbstractor(nn.Module):
     """
     This module is based on C-Abstractor, whose license is under apache-2.0.
-    You can check the original code at 
+    You can check the original code at
     https://github.com/khanrc/honeybee/blob/main/honeybee/projectors/projectors.py
     and we made necessary modifications.
     """
@@ -442,7 +453,8 @@ class HCXVisionCAbstractor(nn.Module):
         # Positional embedding
         if pos_emb:
             self.pos_emb = torch.nn.Parameter(
-                torch.zeros(1, num_input_tokens, encoder_hidden_size))
+                torch.zeros(1, num_input_tokens, encoder_hidden_size)
+            )
             self.pos_emb.data.normal_(mean=0.0, std=0.02)
         else:
             self.pos_emb = None
@@ -453,8 +465,9 @@ class HCXVisionCAbstractor(nn.Module):
         else:
             self.prenorm = None

-        self.build_net(num_queries, encoder_hidden_size, hidden_size,
-                       output_hidden_size)
+        self.build_net(
+            num_queries, encoder_hidden_size, hidden_size, output_hidden_size
+        )
         self.dtype = next(self.parameters()).dtype

     def forward(
@@ -491,7 +504,8 @@ class HCXVisionCAbstractor(nn.Module):
         if num_queries_vis_abstractors is not None:
             assert num_grids is not None
             return self._forward_adaptive_num_query(
-                x, num_queries_vis_abstractors, num_grids)
+                x, num_queries_vis_abstractors, num_grids
+            )

         x = self.net(x)
         x = rearrange(x, "b d h w -> b (h w) d")
@@ -512,7 +526,7 @@ class HCXVisionCAbstractor(nn.Module):
         for i, num_queries in enumerate(num_queries_vis_abstractors):
             hw = int(num_queries**0.5)
             sampler = nn.AdaptiveAvgPool2d((hw, hw))
-            out = sampler(x[num_grids[i]:num_grids[i + 1], :])
+            out = sampler(x[num_grids[i] : num_grids[i + 1], :])
             out = self.net[2](out)  # s2

             out = rearrange(out, "b d h w -> b (h w) d")
@@ -530,8 +544,9 @@ class HCXVisionCAbstractor(nn.Module):
         depth: int = 3,
         mlp_depth: int = 2,
     ):
-        assert (n_queries**0.5).is_integer(
-        ), f"n_queries must be square number. n_queries: {n_queries}"
+        assert (n_queries**0.5).is_integer(), (
+            f"n_queries must be square number. n_queries: {n_queries}"
+        )
         hw = int(n_queries**0.5)

         # RegBlock = ResBlock + SE
@@ -556,8 +571,7 @@ class HCXVisionCAbstractor(nn.Module):
         )

         self.net = nn.Sequential(s1, sampler, s2)
-        self.readout = self.build_mlp(mlp_depth, hidden_size,
-                                      output_hidden_size)
+        self.readout = self.build_mlp(mlp_depth, hidden_size, output_hidden_size)

     def build_mlp(
         self,
@@ -575,13 +589,14 @@
 @MULTIMODAL_REGISTRY.register_processor(
     _build_hcxvision_hf_processor,
     info=_build_hcxvision_hf_info,
-    dummy_inputs=HCXVisionDummyInputsBuilder)
+    dummy_inputs=HCXVisionDummyInputsBuilder,
+)
 class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     merge_by_field_config = True

     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "gate_up_proj": ["gate_proj", "up_proj"]
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }

     def __init__(
@@ -611,7 +626,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

         ## possible_resolution should be matched with preprocessor_config.json
         config.possible_resolutions = self._init_possible_resolutions(
-            config, vision_config)
+            config, vision_config
+        )

         # init models & parameters
         with no_init_weights():  # weight will be loaded in from_pretrained
@@ -622,11 +638,11 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             require_post_norm=False,
             prefix=maybe_prefix(prefix, "vision_model"),
         )
-        self.mm_projector = self._init_mm_projector(config, text_config,
-                                                    vision_config)
+        self.mm_projector = self._init_mm_projector(config, text_config, vision_config)

-        self.lm_head_vocab_size = getattr(text_config, "padded_vocab_size",
-                                          text_config.vocab_size)
+        self.lm_head_vocab_size = getattr(
+            text_config, "padded_vocab_size", text_config.vocab_size
+        )
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=text_config,
@@ -635,7 +651,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

         if config.anyres:
             self.image_newline = nn.Parameter(
-                torch.empty(text_config.hidden_size, dtype=self.dtype))
+                torch.empty(text_config.hidden_size, dtype=self.dtype)
+            )

         self.config = config
         self.vision_config = vision_config
@@ -679,7 +696,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             return None

         return HCXVisionVideoPixelInputs(
-            pixel_values_videos=pixel_values_videos, )
+            pixel_values_videos=pixel_values_videos,
+        )

     def _process_image_input(
         self,
@@ -695,7 +713,8 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         video_input: HCXVisionVideoInputs,
     ) -> tuple[torch.Tensor, ...]:
         return self.forward_videos(
-            pixel_values_videos=video_input["pixel_values_videos"], )
+            pixel_values_videos=video_input["pixel_values_videos"],
+        )

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
         modalities = {}
@@ -703,14 +722,10 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         # Preserve the order of modalities if there are multiple of them
         # from the order of kwargs.
         for input_key in kwargs:
-            if (input_key == "pixel_values_images"
-                    and "images" not in modalities):
-                modalities["images"] = self._parse_and_validate_image_input(
-                    **kwargs)
-            if (input_key == "pixel_values_videos"
-                    and "videos" not in modalities):
-                modalities["videos"] = self._parse_and_validate_video_input(
-                    **kwargs)
+            if input_key == "pixel_values_images" and "images" not in modalities:
+                modalities["images"] = self._parse_and_validate_image_input(**kwargs)
+            if input_key == "pixel_values_videos" and "videos" not in modalities:
+                modalities["videos"] = self._parse_and_validate_video_input(**kwargs)

         return modalities

@@ -754,10 +769,9 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         if intermediate_tensors is not None:
             inputs_embeds = None

-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(
+            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
+        )
         return hidden_states

     def forward_images(
@@ -768,24 +782,21 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         pixel_values_image_flat = flatten_bn(pixel_values_images, concat=True)

         visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
-        image_forward_outs = self.vision_model(
-            pixel_values_image_flat)[:, visual_token_idx:]
+        image_forward_outs = self.vision_model(pixel_values_image_flat)[
+            :, visual_token_idx:
+        ]

-        image_forward_outs = image_forward_outs.to(
-            dtype=self.mm_projector.dtype)
+        image_forward_outs = image_forward_outs.to(dtype=self.mm_projector.dtype)
         image_forward_outs = self.mm_projector(image_forward_outs)  # b (h w) d

         split_sizes = [len(item) for item in pixel_values_images]
-        image_forward_outs = torch.split(image_forward_outs,
-                                         split_sizes,
-                                         dim=0)
+        image_forward_outs = torch.split(image_forward_outs, split_sizes, dim=0)

         # newline for anyres postprocessing
         image_features = anyres_postprocessing(
             image_forward_outs=image_forward_outs,
             image_sizes=image_sizes_images.tolist(),
-            num_queries_vis_abstractor=self.config.
-            num_queries_vis_abstractor_image,
+            num_queries_vis_abstractor=self.config.num_queries_vis_abstractor_image,
             unpad=self.config.unpad,
             patch_size=self.vision_config.patch_size,
             grid_size=self.vision_config.image_size,
@@ -805,11 +816,11 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         )

         visual_token_idx = 0 if "siglip" in self.vision_config.model_type else 1
-        video_forward_outs = self.vision_model(
-            pixel_values_videos_flat)[:, visual_token_idx:]
+        video_forward_outs = self.vision_model(pixel_values_videos_flat)[
+            :, visual_token_idx:
+        ]

-        video_forward_outs = video_forward_outs.to(
-            dtype=self.mm_projector.dtype)
+        video_forward_outs = video_forward_outs.to(dtype=self.mm_projector.dtype)

         # Run MM-Projector
         # len(num_grids) == len(num_queries_vis_abstractors) + 1
@@ -817,8 +828,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         num_grids = [
             grid_idx
         ]  # e.g. [0, 9, 18, 19, 27, 28, 36, 37, 45, 46, 54, 55, 56]
-        num_queries_vis_abstractors = [
-        ]  # e.g. [81, 81, 81, 9, 81, 9, 81, 9, 81, 9, 81, 9]
+        num_queries_vis_abstractors = []  # e.g. [81, 81, 81, 9, 81, 9, 81, 9, 81, 9, 81, 9]
         len_total_frames = video_forward_outs.shape[0]

         if self.config.first_last_frames_slow:
@@ -826,22 +836,26 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             assert len_total_frames != 0
             if len_total_frames <= 2:
                 num_queries_vis_abstractors.append(
-                    self.config.num_queries_vis_abstractor_video_slow)
+                    self.config.num_queries_vis_abstractor_video_slow
+                )
                 grid_idx += len_total_frames
                 num_grids.append(grid_idx)
             else:
                 num_queries_vis_abstractors.append(
-                    self.config.num_queries_vis_abstractor_video_slow)
+                    self.config.num_queries_vis_abstractor_video_slow
+                )
                 grid_idx += 1
                 num_grids.append(grid_idx)

                 num_queries_vis_abstractors.append(
-                    self.config.num_queries_vis_abstractor_video_fast)
+                    self.config.num_queries_vis_abstractor_video_fast
+                )
                 grid_idx += len_total_frames - 2
                 num_grids.append(grid_idx)

                 num_queries_vis_abstractors.append(
-                    self.config.num_queries_vis_abstractor_video_slow)
+                    self.config.num_queries_vis_abstractor_video_slow
+                )
                 grid_idx += 1
                 num_grids.append(grid_idx)
         else:
@@ -850,17 +864,19 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
             for pixel_values_frame in pixel_values_frames:
                 if len(pixel_values_frame) > 0:
                     num_queries_vis_abstractors.append(
-                        self.config.num_queries_vis_abstractor_video_slow)
+                        self.config.num_queries_vis_abstractor_video_slow
+                    )
                     grid_idx += 1
                     num_grids.append(grid_idx)
                     num_queries_vis_abstractors.append(
-                        self.config.num_queries_vis_abstractor_video_fast)
+                        self.config.num_queries_vis_abstractor_video_fast
+                    )
                     grid_idx = grid_idx + len(pixel_values_frame) - 1
                     num_grids.append(grid_idx)

-        video_forward_outs = self.mm_projector(video_forward_outs,
-                                               num_queries_vis_abstractors,
-                                               num_grids)
+        video_forward_outs = self.mm_projector(
+            video_forward_outs, num_queries_vis_abstractors, num_grids
+        )

         video_features = []  # what we want to return
         target_features = []
@@ -882,18 +898,19 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
                 target_group_size = 0

             elif video_group_size < target_group_size:
-                raise RuntimeError(
-                    f"{video_group_size=} < {target_group_size=}")
+                raise RuntimeError(f"{video_group_size=} < {target_group_size=}")

-        assert len(target_features
-                   ) == 0, f"target_features is not empty!! {target_features}"
+        assert len(target_features) == 0, (
+            f"target_features is not empty!! {target_features}"
+        )
         assert len(video_groups) == len(video_features)

         feats_per_video = [len(video) for video in pixel_values_videos]
         idxs_per_video = [0, *accumulate(feats_per_video)]
         return tuple(
-            torch.cat(video_features[idxs_per_video[i]:idxs_per_video[i + 1]])
-            for i in range(len(feats_per_video)))
+            torch.cat(video_features[idxs_per_video[i] : idxs_per_video[i + 1]])
+            for i in range(len(feats_per_video))
+        )

     def _prepare_multimodal_kwargs(self, **kwargs: object):
         output = defaultdict(list)
@@ -902,7 +919,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
                 continue  # if empty batch of empty sample

             new_k, is_video = k, False
-            if (not k.endswith("_images") and not k.endswith("_videos")):
+            if not k.endswith("_images") and not k.endswith("_videos"):
                 pass
             else:
                 new_k, is_video = k.split("_")[:-1], k.split("_")[-1]
@@ -955,10 +972,10 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
                 if i * j <= config.max_num_grids:
                     possible_resolutions.append([i, j])

-            possible_resolutions = [[
-                ys * vision_config.image_size,
-                xs * vision_config.image_size
-            ] for ys, xs in possible_resolutions]
+            possible_resolutions = [
+                [ys * vision_config.image_size, xs * vision_config.image_size]
+                for ys, xs in possible_resolutions
+            ]
             return possible_resolutions
         else:
             return config.possible_resolutions
@@ -971,14 +988,13 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     ):
         input_hidden_size = vision_config.hidden_size
         if config.mm_projector_type == "linear":
-            mm_projector = nn.Linear(input_hidden_size,
-                                     text_config.hidden_size)
+            mm_projector = nn.Linear(input_hidden_size, text_config.hidden_size)
             mm_projector.dtype = next(mm_projector.parameters()).dtype
         elif config.mm_projector_type == "cabstractor":
             mm_projector = HCXVisionCAbstractor(
                 num_queries=config.num_queries_vis_abstractor_image,
-                num_input_tokens=(vision_config.image_size //
-                                  vision_config.patch_size)**2,
+                num_input_tokens=(vision_config.image_size // vision_config.patch_size)
+                ** 2,
                 encoder_hidden_size=input_hidden_size,
                 hidden_size=input_hidden_size,
                 output_hidden_size=text_config.hidden_size,
@@ -995,8 +1011,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         return mm_projector


-def unpad_image(tensor: torch.Tensor,
-                original_size: tuple[int, int]) -> torch.Tensor:
+def unpad_image(tensor: torch.Tensor, original_size: tuple[int, int]) -> torch.Tensor:
     original_width, original_height = original_size
     current_height, current_width = tensor.shape[1:]

@@ -1007,18 +1022,17 @@ def unpad_image(tensor: torch.Tensor,
         scale_factor = current_width / original_width
         new_height = int(original_height * scale_factor)
         padding = (current_height - new_height) // 2
-        unpadded_tensor = tensor[:, padding:current_height - padding, :]
+        unpadded_tensor = tensor[:, padding : current_height - padding, :]
     else:
         scale_factor = current_height / original_height
         new_width = int(original_width * scale_factor)
         padding = (current_width - new_width) // 2
-        unpadded_tensor = tensor[:, :, padding:current_width - padding]
+        unpadded_tensor = tensor[:, :, padding : current_width - padding]

     return unpadded_tensor


-def select_best_resolution(original_size: tuple,
-                           possible_resolutions: list) -> tuple:
+def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
     original_height, original_width = original_size
     best_fit = None
     max_effective_resolution = 0
@@ -1026,15 +1040,19 @@ def select_best_resolution(original_size: tuple,

     for height, width in possible_resolutions:
         scale = min(width / original_width, height / original_height)
-        downscaled_width, downscaled_height = int(original_width * scale), int(
-            original_height * scale)
-        effective_resolution = min(downscaled_width * downscaled_height,
-                                   original_width * original_height)
+        downscaled_width, downscaled_height = (
+            int(original_width * scale),
+            int(original_height * scale),
+        )
+        effective_resolution = min(
+            downscaled_width * downscaled_height, original_width * original_height
+        )
         wasted_resolution = (width * height) - effective_resolution

         if effective_resolution > max_effective_resolution or (
-                effective_resolution == max_effective_resolution
-                and wasted_resolution < min_wasted_resolution):
+            effective_resolution == max_effective_resolution
+            and wasted_resolution < min_wasted_resolution
+        ):
             max_effective_resolution = effective_resolution
             min_wasted_resolution = wasted_resolution
             best_fit = (height, width)
@@ -1047,12 +1065,16 @@ def get_anyres_image_grid_shape(
     grid_pinpoints: Union[str, list[tuple[int, int]]],
     patch_size: int,
 ) -> tuple[int, int]:
-    possible_resolutions = grid_pinpoints if isinstance(
-        grid_pinpoints, list) else ast.literal_eval(grid_pinpoints)
+    possible_resolutions = (
+        grid_pinpoints
+        if isinstance(grid_pinpoints, list)
+        else ast.literal_eval(grid_pinpoints)
+    )

     original_width, original_height = image_size
-    height, width = select_best_resolution((original_height, original_width),
-                                           possible_resolutions)
+    height, width = select_best_resolution(
+        (original_height, original_width), possible_resolutions
+    )
     return width // patch_size, height // patch_size


@@ -1070,12 +1092,15 @@ def reshape_and_unpad_image_features(
     image_feature = image_feature[1:]

     assert height * width == base_image_feature.shape[0], (
-        f"{height=} * {width=} != {base_image_feature.shape[0]=}")
+        f"{height=} * {width=} != {base_image_feature.shape[0]=}"
+    )

     num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-        image_size, possible_resolutions, grid_size)
-    image_feature = image_feature.view(num_patch_height, num_patch_width,
-                                       height, width, -1)
+        image_size, possible_resolutions, grid_size
+    )
+    image_feature = image_feature.view(
+        num_patch_height, num_patch_width, height, width, -1
+    )

     if unpad:
         image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
@@ -1084,8 +1109,9 @@ def reshape_and_unpad_image_features(
         image_feature = torch.cat(
             (
                 image_feature,
-                image_newline[:, None, None].expand(
-                    *image_feature.shape[:-1], 1).to(image_feature.device),
+                image_newline[:, None, None]
+                .expand(*image_feature.shape[:-1], 1)
+                .to(image_feature.device),
             ),
             dim=-1,
         )
@@ -1111,8 +1137,9 @@ def anyres_postprocessing(
     height = width = grid_size // patch_size

     if num_queries_vis_abstractor > 0:
-        assert (num_queries_vis_abstractor**0.5
-                ).is_integer(), "n_queries must be square number"
+        assert (num_queries_vis_abstractor**0.5).is_integer(), (
+            "n_queries must be square number"
+        )
         height = width = int(num_queries_vis_abstractor**0.5)

     # post-processing (unpad, add newline)
@@ -1132,8 +1159,8 @@ def anyres_postprocessing(
     else:
         image_feature = image_feature[0]
         image_feature = torch.cat(
-            (image_feature, image_newline[None].to(image_feature.device)),
-            dim=0)
+            (image_feature, image_newline[None].to(image_feature.device)), dim=0
+        )
         new_image_features.append(image_feature)

     return new_image_features