support dynamic resolution image encoding for Nemotron Nano VL (#32121)
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
This commit is contained in:
@@ -282,12 +282,14 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
num_dummy_heads: int = 0,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_cls: type[InternParallelAttention] = InternParallelAttention,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.norm_type = config.norm_type
|
||||
self.attn_cls = attn_cls
|
||||
|
||||
self.attn = self._init_attn(
|
||||
config,
|
||||
@@ -327,7 +329,7 @@ class InternVisionEncoderLayer(nn.Module):
|
||||
use_data_parallel = (
|
||||
use_data_parallel or (num_heads + num_dummy_heads) % tp_size != 0
|
||||
)
|
||||
return InternParallelAttention(
|
||||
return self.attn_cls(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
num_dummy_heads=num_dummy_heads,
|
||||
@@ -356,10 +358,12 @@ class InternVisionEncoder(nn.Module):
|
||||
num_dummy_heads: int = 0,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
layer_cls: type[InternVisionEncoderLayer] = InternVisionEncoderLayer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.layer_cls = layer_cls
|
||||
|
||||
if num_hidden_layers_override is None:
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
@@ -368,7 +372,7 @@ class InternVisionEncoder(nn.Module):
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
InternVisionEncoderLayer(
|
||||
self.layer_cls(
|
||||
config,
|
||||
quant_config,
|
||||
num_dummy_heads=num_dummy_heads,
|
||||
|
||||
@@ -8,11 +8,15 @@
|
||||
# --------------------------------------------------------
|
||||
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
|
||||
|
||||
import einops
|
||||
import numpy.typing as npt
|
||||
import regex as re
|
||||
import torch
|
||||
@@ -23,6 +27,7 @@ from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@@ -39,7 +44,7 @@ from vllm.model_executor.models.internvl import (
|
||||
)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
|
||||
from vllm.model_executor.models.radio import RadioModel
|
||||
from vllm.model_executor.models.radio import RadioModel, calc_seq_lens
|
||||
from vllm.model_executor.models.utils import (
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
@@ -78,6 +83,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .utils import _merge_multimodal_embeddings
|
||||
|
||||
logger = init_logger(__name__)
|
||||
# Configure PIL to handle large images without warnings
|
||||
# This prevents DecompressionBombWarning for legitimate large images
|
||||
Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
|
||||
@@ -103,11 +109,25 @@ class NanoNemotronVLImagePixelInputs(TensorSchema):
|
||||
- w: Width of each image patch
|
||||
"""
|
||||
|
||||
type: Literal["pixel_values"]
|
||||
type: Literal["pixel_values"] = "pixel_values"
|
||||
pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
|
||||
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
|
||||
|
||||
|
||||
class NanoNemotronVLImagePixelInputsDynamic(TensorSchema):
    """
    Dynamic-resolution image inputs.

    Unlike the fixed-tile pixel inputs, every image keeps its own resolution,
    so the per-image geometry is carried next to the pixel data.

    imgs_sizes: per-image (height, width) in pixels.
    num_tokens_per_image: per-image number of embedding tokens (post downsample).
    """

    # Tag identifying this input variant; defaults to its only legal value.
    type: Literal["pixel_values_dynamic"] = "pixel_values_dynamic"
    # Pixel data, annotated with the symbolic dims ("bn", "h", "w").
    # NOTE(review): presumably the flattened patch rows produced by
    # DynamicResolutionImageTiler.stack -- confirm against the caller.
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bn", "h", "w")]
    # One (height, width) pair per image, in pixels.
    imgs_sizes: list[tuple[int, int]]
    # One embedding-token count per image.
    num_tokens_per_image: list[int]
|
||||
|
||||
|
||||
class NanoNemotronVLImageEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
@@ -121,7 +141,9 @@ class NanoNemotronVLImageEmbeddingInputs(TensorSchema):
|
||||
|
||||
|
||||
NanoNemotronVLImageInputs: TypeAlias = (
|
||||
NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddingInputs
|
||||
NanoNemotronVLImagePixelInputs
|
||||
| NanoNemotronVLImagePixelInputsDynamic
|
||||
| NanoNemotronVLImageEmbeddingInputs
|
||||
)
|
||||
|
||||
|
||||
@@ -267,6 +289,329 @@ def calculate_timestamps(
|
||||
return timestamps
|
||||
|
||||
|
||||
class DynamicResolutionImageTiler:
    """Plan and apply per-image resizing for dynamic-resolution image encoding.

    For every image the tiler chooses a patch grid (width x height, counted in
    ``patch_size``-pixel patches) that fits a token budget, resizes the image
    to that grid and converts it to a ``(C, H, W)`` tensor.  The budget is
    derived from ``max_model_len`` minus the length of the text prompt.
    """

    # Flags selecting which 2x spatial reductions are accounted for when
    # converting a patch count into a token count.  ``__init__`` asserts that
    # the combined reduction is exactly 2 per side.
    CONV_MERGING = False
    PIXEL_SHUFFLE = True
    USE_THUMBNAIL = False

    def __init__(
        self,
        *,
        max_model_len: int,
        patch_size: int,
        min_num_patches: int,
        max_num_patches: int,
        downsample_ratio: float,
        norm_mean: Sequence[float],
        norm_std: Sequence[float],
        factor_max: float = 1.0,
        use_thumbnail: bool = False,
    ) -> None:
        """Configure the tiler.

        Args:
            max_model_len: Context length used to derive the token budget.
            patch_size: Side of one square patch, in pixels.
            min_num_patches: Lower bound on the patches given to one image.
            max_num_patches: Upper bound on the patches given to one image;
                any value ``<= 0`` means "no upper bound".
            downsample_ratio: Fractional reduction per side; only ``0.5``
                passes the assertions below.
            norm_mean: Per-channel mean, stored as a ``(3, 1, 1)`` tensor.
            norm_std: Per-channel std, stored as a ``(3, 1, 1)`` tensor.
            factor_max: Cap on the resize factor applied to the original
                patch grid (``1.0`` means images are never upscaled by it).
            use_thumbnail: Must be ``False``; thumbnails are not supported.
        """
        assert use_thumbnail is False, "use_thumbnail is not supported"
        self._patch_size: int = patch_size
        self._max_model_len = max_model_len
        self._min_num_patches = min_num_patches
        self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf")
        self._factor_max = factor_max
        self.norm_mean = torch.tensor(norm_mean).reshape(3, 1, 1)
        self.norm_std = torch.tensor(norm_std).reshape(3, 1, 1)
        self._transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.ToTensor(),
            ]
        )
        assert downsample_ratio < 1
        reduction_factor = 1 / downsample_ratio
        assert reduction_factor == 2.0
        # One factor of 2 per enabled reduction stage.
        self._downsample_ratio = int(reduction_factor) ** (
            self.PIXEL_SHUFFLE + self.CONV_MERGING
        )
        assert self._downsample_ratio == 2

    def _get_num_embeddings(self, width: int, height: int) -> int:
        """Return the number of tokens produced by a ``width`` x ``height`` image.

        Both sides are given in pixels; the patch count is divided by the
        squared per-side reduction.
        """
        num_patches = (width // self._patch_size) * (height // self._patch_size)
        return num_patches // (self._downsample_ratio**2)

    def width_and_height_for_max_num_tokens_available(
        self,
        target_num_tokens_post_shuffle: int,
    ) -> tuple[int, int]:
        """
        TODO: optimize this so it squeezes closer to target number of tokens.
        Calculate image dimensions that produce approximately `target` tokens after
        pixel_shuffle.

        With pixel_shuffle enabled, each 2x2 patch grid becomes 1 token, so we
        need 4*B patches to get B tokens.

        The result is always a square whose side is a multiple of the patch
        size, so the token count is the largest perfect square not exceeding
        the target.

        Examples:
            >>> PATCH_SIZE = 16
            >>> DOWNSAMPLE_RATIO = 0.5
            >>> tiler = DynamicResolutionImageTiler(
            ...     max_model_len=16384,
            ...     patch_size=PATCH_SIZE,
            ...     downsample_ratio=DOWNSAMPLE_RATIO,
            ...     min_num_patches=4,
            ...     max_num_patches=0,
            ...     norm_mean=(0.5, 0.5, 0.5),
            ...     norm_std=(0.5, 0.5, 0.5),
            ... )
            >>> width, height = tiler.width_and_height_for_max_num_tokens_available(
            ...     target_num_tokens_post_shuffle=8192,
            ... )
            >>> assert (width, height) == (2880, 2880)
            >>> assert (width // PATCH_SIZE) * (
            ...     height // PATCH_SIZE
            ... ) // 2**2 == 8100  # tokens post-shuffle
            >>> assert tiler._get_num_embeddings(width=width, height=height) == 8100
        """
        side_pixels = (
            math.isqrt(target_num_tokens_post_shuffle)
            * self._downsample_ratio
            * self._patch_size
        )
        assert isinstance(side_pixels, int) and side_pixels % self._patch_size == 0
        return side_pixels, side_pixels

    def max_num_tokens_available(self, text_prompt_length: int) -> int:
        """Return the token budget left for images after the text prompt.

        Four additional tokens are held back from ``max_model_len``.
        """
        return self._max_model_len - text_prompt_length - 4

    def _images_to_pixel_values_lst(
        self,
        text_prompt_length: int,
        images: list[Image.Image],
    ) -> tuple[list[torch.Tensor], list[int]]:
        """Resize ``images`` to fit the budget and convert them to tensors.

        Returns:
            A pair ``(tensors, feature_sizes)``: one 3-dim tensor per produced
            image and, aligned with it, that image's number of embeddings.
        """
        num_tokens_available = self.max_num_tokens_available(text_prompt_length)
        params_per_image = self.compute_params(images, num_tokens_available)

        # Separate local name: do not shadow the ``images`` argument.
        pixel_values: list[torch.Tensor] = []
        feature_sizes: list[int] = []
        for param in params_per_image:
            for t in self.apply_params(param):
                assert t.ndim == 3, f"{t.ndim=}: expected 3 dim tensor"
                pixel_values.append(t)
                feature_sizes.append(param.num_embeddings)
        return pixel_values, feature_sizes

    # Class-level (shared by all instances) map from ``id(image)`` to the
    # number of embeddings last computed for that image by ``compute_params``.
    # Entries are removed when read through ``get_cached_feature_size``.
    feature_size_cache: dict[int, int] = {}

    @classmethod
    def get_cached_feature_size(cls, image: Image.Image) -> int:
        """Return and forget the cached feature size of ``image``.

        Raises:
            KeyError: If no size is cached for this exact object, e.g. when it
                was never passed to ``compute_params`` or was already consumed.
        """
        # pop() enforces that each cached feature size is used exactly once.
        return cls.feature_size_cache.pop(id(image))

    @dataclass
    class DynamicResolutionParams:
        # The source image.
        media: Image.Image
        # Number of tiles produced for the image (``process_media`` always sets 1).
        num_tiles: int
        # Number of embedding tokens after the spatial reduction.
        num_embeddings: int
        # Target grid as (width, height), counted in patches (not pixels).
        patch_size: tuple[int, int]

    def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]:
        """Resize ``params.media`` to its target grid and convert it to a tensor.

        Returns:
            A one-element list holding the transformed image.
        """
        target_size = (
            params.patch_size[0] * self._patch_size,
            params.patch_size[1] * self._patch_size,
        )
        resized_img = params.media.resize(target_size)
        return [self._transform(resized_img)]

    def _fit_patch_grid(
        self,
        *,
        orig_width: int,
        orig_height: int,
        num_tokens_available: float,
    ) -> tuple[int, int]:
        """Choose the ``(width, height)`` patch grid for an image of the given size.

        The original grid is scaled to fit ``num_tokens_available`` patches
        (never by more than ``factor_max``), raised to ``min_num_patches`` when
        the budget allows it, and finally rounded so both sides are divisible
        by the required reduction divisor.
        """
        closest_patch_height = round(orig_height / self._patch_size + 0.5)
        closest_patch_width = round(orig_width / self._patch_size + 0.5)
        patches = closest_patch_height * closest_patch_width

        factor = min(math.sqrt(num_tokens_available / patches), self._factor_max)
        # Keep at least one patch per side: for very elongated images the
        # floor can reach 0, which would divide by zero below and request a
        # zero-sized resize.
        target_patch_height = max(1, math.floor(factor * closest_patch_height))
        target_patch_width = max(1, math.floor(factor * closest_patch_width))

        # Consider self._min_num_patches if > current_num_tokens_available.
        if (
            num_tokens_available > self._min_num_patches
            and target_patch_height * target_patch_width < self._min_num_patches
        ):
            up_factor = math.sqrt(
                self._min_num_patches / (target_patch_height * target_patch_width)
            )
            target_patch_height = math.ceil(up_factor * target_patch_height)
            target_patch_width = math.ceil(up_factor * target_patch_width)

        # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
        # or by 4 when BOTH are enabled (two successive 2x reductions)
        if self.PIXEL_SHUFFLE or self.CONV_MERGING:
            required_divisor = 4 if (self.PIXEL_SHUFFLE and self.CONV_MERGING) else 2

            # Prefer rounding up; fall back to rounding down (but never below
            # one full divisor) when rounding up would exceed the budget.
            rem_h = target_patch_height % required_divisor
            if rem_h != 0:
                inc_h = required_divisor - rem_h
                if (
                    target_patch_height + inc_h
                ) * target_patch_width <= num_tokens_available:
                    target_patch_height += inc_h
                else:
                    target_patch_height = max(
                        required_divisor, target_patch_height - rem_h
                    )

            rem_w = target_patch_width % required_divisor
            if rem_w != 0:
                inc_w = required_divisor - rem_w
                if (
                    target_patch_height * (target_patch_width + inc_w)
                    <= num_tokens_available
                ):
                    target_patch_width += inc_w
                else:
                    target_patch_width = max(
                        required_divisor, target_patch_width - rem_w
                    )

        return target_patch_width, target_patch_height

    def process_media(
        self,
        media: Image.Image,
        num_tokens_available: int,
    ) -> tuple[DynamicResolutionParams, int]:
        """Process a single media item and return its parameters.

        Args:
            media: The media item to process
            num_tokens_available: Number of tokens available for this media
        Returns:
            A pair ``(params, token_count)`` where ``params`` is the
            DynamicResolutionParams for the media and ``token_count`` is the
            number of patches in the chosen grid (before spatial reduction).
        """
        assert isinstance(media, Image.Image), (
            "Dynamic resolution is only supported for image media"
        )
        target_patch_width, target_patch_height = self._fit_patch_grid(
            orig_width=media.width,
            orig_height=media.height,
            num_tokens_available=num_tokens_available,
        )

        # Calculate embeddings for the main dynamic resolution image
        num_embeddings = self._get_num_embeddings(
            target_patch_width * self._patch_size,
            target_patch_height * self._patch_size,
        )
        token_count = target_patch_width * target_patch_height

        # Thumbnails are not supported, so there is only the base image.
        num_tiles = 1

        return self.DynamicResolutionParams(
            media=media,
            num_tiles=num_tiles,
            num_embeddings=num_embeddings,
            patch_size=(target_patch_width, target_patch_height),
        ), token_count

    def compute_params(
        self,
        media_list: list[Image.Image],
        num_tokens_available: int | None = None,
    ) -> list[DynamicResolutionParams]:
        """Compute parameters for all media with iterative token budgeting.

        As a side effect, the number of embeddings of every processed media
        is recorded in ``feature_size_cache`` under ``id(media)``.

        Args:
            media_list: List of media items to process
            num_tokens_available: Total number of tokens available across all
                media. ``None`` means there is no overall budget; only the
                per-media patch limits apply.
        Returns:
            List of DynamicResolutionParams for each media item
        Raises:
            ValueError: If no assignment within budget is found after the
                bounded number of refinement rounds.
        """
        if num_tokens_available is None:
            # No overall budget was given.
            num_tokens_available = math.inf
        else:
            # Convert the token budget into a patch budget: every enabled
            # reduction stage merges 4 patches into one.
            num_tokens_available = (
                num_tokens_available
                * (4 if self.PIXEL_SHUFFLE else 1)
                * (4 if self.CONV_MERGING else 1)
            )
        # When the number of available token is too small,
        # allow self._min_num_patches per media and let the sample be truncated.
        num_tokens_available = max(
            num_tokens_available, self._min_num_patches * len(media_list)
        )

        # Clip the number of tokens available per media to >min and <max patches.
        per_media_budget = max(
            min(num_tokens_available, self._max_num_patches), self._min_num_patches
        )
        num_tokens_available_per_media = [per_media_budget] * len(media_list)

        # prevent infinite loop in any case
        for _ in range(10):
            # Step 1: Process each media with current token budget
            params = []
            token_counts = []

            for media, tokens_for_media in zip(
                media_list, num_tokens_available_per_media
            ):
                param, token_count = self.process_media(media, tokens_for_media)
                params.append(param)
                token_counts.append(token_count)
                self.feature_size_cache[id(param.media)] = param.num_embeddings

            # Step 2: Check if total tokens is within budget
            total_tokens = sum(token_counts)
            if total_tokens <= num_tokens_available:
                return params

            # Step 3: We're over budget, need to scale down.
            # Each media gets a proportional share of the total budget.
            scaling_factor = num_tokens_available / total_tokens
            scaled_down_num_tokens_available_per_media = [
                max(self._min_num_patches, int(token_count * scaling_factor))
                for token_count in token_counts
            ]
            scaled_down = any(
                scaled < current
                for scaled, current in zip(
                    scaled_down_num_tokens_available_per_media,
                    num_tokens_available_per_media,
                )
            )
            # If there wasn't scaling down, we're stuck with min_num_patches per media,
            # else try with the scaled down num_tokens_available_per_media.
            if scaled_down:
                num_tokens_available_per_media = (
                    scaled_down_num_tokens_available_per_media
                )
            else:
                num_tokens_available_per_media = [self._min_num_patches] * len(
                    media_list
                )

        ctx = f"{params=} {total_tokens=} {num_tokens_available=}"
        raise ValueError(
            f"Should be unreachable - `return params` above must be reached: {ctx}"
        )

    @staticmethod
    def stack(images: list[torch.Tensor], patch_size: int) -> torch.Tensor:
        """Flatten images into patch rows and concatenate them.

        Each ``(C, H, W)`` image becomes ``(num_patches, C * patch_size**2)``;
        the rows of all images are concatenated and a leading dimension of
        size 1 is added.
        """
        assert len(images) > 0, "No images to stack"

        def rearrange_img(x):
            py = x.shape[-2] // patch_size
            px = x.shape[-1] // patch_size
            return einops.rearrange(
                x,
                "c (py yy) (px xx) -> (py px) (c yy xx)",
                py=py,
                yy=patch_size,
                px=px,
                xx=patch_size,
            )

        imgs = [rearrange_img(img) for img in images]
        return torch.cat(imgs, dim=0).unsqueeze(0)
|
||||
|
||||
|
||||
class BaseNanoNemotronVLProcessor(ABC):
|
||||
"""
|
||||
This model doesn't define its own HF processor,
|
||||
@@ -281,6 +626,7 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
config: PretrainedConfig,
|
||||
tokenizer: TokenizerLike,
|
||||
*args,
|
||||
max_model_len: int,
|
||||
max_num_tiles: int | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@@ -292,15 +638,32 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
self.max_num_tiles = max_num_tiles or DEFAULT_NUM_TILES
|
||||
image_size: int = config.force_image_size
|
||||
patch_size: int = config.patch_size
|
||||
downsample_ratio: int = config.downsample_ratio
|
||||
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size) ** 2 * (config.downsample_ratio**2)
|
||||
(image_size // patch_size) ** 2 * (downsample_ratio**2)
|
||||
)
|
||||
self.image_size = image_size
|
||||
self.use_thumbnail: bool = config.use_thumbnail
|
||||
self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1)
|
||||
self.norm_std = torch.Tensor(config.norm_std).reshape(1, 3, 1, 1)
|
||||
|
||||
self.dynamic_tiler: DynamicResolutionImageTiler | None = None
|
||||
if self.use_dynamic_resolution(config):
|
||||
self.dynamic_tiler = DynamicResolutionImageTiler(
|
||||
max_model_len=max_model_len,
|
||||
patch_size=patch_size,
|
||||
downsample_ratio=downsample_ratio,
|
||||
min_num_patches=config.vision_config.args["min_num_patches"],
|
||||
max_num_patches=config.vision_config.args["max_num_patches"],
|
||||
norm_mean=config.norm_mean,
|
||||
norm_std=config.norm_std,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def use_dynamic_resolution(config: PretrainedConfig) -> bool:
|
||||
return "min_num_patches" in config.vision_config.args
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def image_token_id(self) -> int:
|
||||
@@ -354,36 +717,61 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
text: list[str],
|
||||
images: list[Image.Image],
|
||||
max_num_tiles: int,
|
||||
) -> tuple[list[str], dict[str, torch.Tensor]]:
|
||||
) -> tuple[list[str], dict[str, Any]]:
|
||||
if len(images) == 0:
|
||||
image_inputs = {}
|
||||
return text, image_inputs
|
||||
|
||||
if tiler := self.dynamic_tiler:
|
||||
sans_images = text[0].replace("<image>", "")
|
||||
text_prompt_length = len(
|
||||
self.tokenizer(sans_images, add_special_tokens=False).input_ids
|
||||
)
|
||||
pixel_values_lst, num_tokens_per_image = tiler._images_to_pixel_values_lst(
|
||||
text_prompt_length=text_prompt_length,
|
||||
images=images,
|
||||
)
|
||||
imgs_sizes = [(pv.shape[-2], pv.shape[-1]) for pv in pixel_values_lst]
|
||||
normalized = [
|
||||
input_conditioner(img, tiler.norm_mean, tiler.norm_std)
|
||||
for img in pixel_values_lst
|
||||
]
|
||||
image_num_patches = torch.tensor([1] * len(num_tokens_per_image))
|
||||
image_inputs = {
|
||||
"pixel_values_flat": normalized,
|
||||
"imgs_sizes": imgs_sizes,
|
||||
"num_tokens_per_image": num_tokens_per_image,
|
||||
}
|
||||
else:
|
||||
pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
|
||||
image_num_patches = torch.tensor([len(item) for item in pixel_values_lst])
|
||||
pixel_values_flat = input_conditioner(
|
||||
torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
|
||||
)
|
||||
image_inputs = {
|
||||
"pixel_values_flat": input_conditioner(
|
||||
torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
|
||||
),
|
||||
"image_num_patches": torch.tensor(
|
||||
[len(item) for item in pixel_values_lst]
|
||||
),
|
||||
"pixel_values_flat": pixel_values_flat,
|
||||
"image_num_patches": image_num_patches,
|
||||
}
|
||||
num_tokens_per_image = [
|
||||
self.num_image_token * len(item) for item in pixel_values_lst
|
||||
]
|
||||
|
||||
assert len(text) == 1, (
|
||||
"hf_processor is called on the output of get_dummy_text, "
|
||||
"which should be a single string"
|
||||
)
|
||||
parts = [x for x in re.split(r"(<image>)", text[0]) if x]
|
||||
assert parts.count("<image>") == len(pixel_values_lst), (
|
||||
"the number of <image> tokens in the text should be the "
|
||||
"same as the number of images"
|
||||
)
|
||||
assert len(text) == 1, (
|
||||
"hf_processor is called on the output of get_dummy_text, "
|
||||
"which should be a single string"
|
||||
)
|
||||
parts = [x for x in re.split(r"(<image>)", text[0]) if x]
|
||||
assert parts.count("<image>") == len(pixel_values_lst), (
|
||||
"the number of <image> tokens in the text should be the "
|
||||
"same as the number of images"
|
||||
)
|
||||
|
||||
for i, pixel_values in enumerate(pixel_values_lst):
|
||||
num_patches = pixel_values.shape[0]
|
||||
feature_size = num_patches * self.num_image_token
|
||||
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||
parts[i] = parts[i].replace("<image>", image_repl.full)
|
||||
text = ["".join(parts)]
|
||||
for i, (feature_size, num_patches) in enumerate(
|
||||
zip(num_tokens_per_image, image_num_patches, strict=True)
|
||||
):
|
||||
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||
parts[i] = parts[i].replace("<image>", image_repl.full)
|
||||
text = ["".join(parts)]
|
||||
return text, image_inputs
|
||||
|
||||
def _make_batch_input(self, input_item: Any | list[Any] | None = None):
|
||||
@@ -393,6 +781,7 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
input_item = [input_item]
|
||||
return input_item
|
||||
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self,
|
||||
text: str | list[str] | None = None,
|
||||
@@ -400,23 +789,7 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
return_tensors: str | TensorType | None = None,
|
||||
max_num_tiles: int | None = None,
|
||||
) -> BatchFeature:
|
||||
# Use default if not provided
|
||||
if max_num_tiles is None:
|
||||
max_num_tiles = self.max_num_tiles
|
||||
|
||||
text, images = [self._make_batch_input(x) for x in (text, images)]
|
||||
|
||||
text, image_inputs = self._preprocess_image(
|
||||
text=text,
|
||||
images=images,
|
||||
max_num_tiles=max_num_tiles,
|
||||
)
|
||||
|
||||
text_inputs = self.tokenizer(text, add_special_tokens=False)
|
||||
|
||||
combined_outputs = {**text_inputs, **image_inputs}
|
||||
|
||||
return BatchFeature(combined_outputs, tensor_type=return_tensors)
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
@@ -431,20 +804,16 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
config: PretrainedConfig,
|
||||
tokenizer: TokenizerLike,
|
||||
*,
|
||||
max_model_len: int,
|
||||
max_num_tiles: int | None = None,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
dynamic_image_size: bool | None = None,
|
||||
video_token: str | None = None,
|
||||
video_pruning_rate: float | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
max_model_len=max_model_len,
|
||||
max_num_tiles=max_num_tiles,
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
# add extra video token for video processing
|
||||
self.video_token = video_token
|
||||
@@ -478,7 +847,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
self,
|
||||
videos: list[npt.NDArray],
|
||||
max_num_tiles: int,
|
||||
dynamic_image_size: bool | None = None,
|
||||
) -> list[torch.Tensor]:
|
||||
return [
|
||||
video_to_pixel_values(
|
||||
@@ -495,7 +863,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
text: list[str],
|
||||
videos: list[tuple[npt.NDArray, dict[str, Any]]],
|
||||
max_num_tiles: int,
|
||||
dynamic_image_size: bool | None = None,
|
||||
):
|
||||
if len(videos) == 0 or not self.supports_video:
|
||||
video_inputs = {}
|
||||
@@ -505,7 +872,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
pixel_values_lst_video = self._videos_to_pixel_values_lst(
|
||||
videos_lst,
|
||||
max_num_tiles=max_num_tiles,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
|
||||
# We use frame duration in milliseconds (as integer) to ensure
|
||||
@@ -592,7 +958,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
max_num_tiles: int | None = None,
|
||||
dynamic_image_size: bool | None = None,
|
||||
) -> BatchFeature:
|
||||
# Use default if not provided
|
||||
if max_num_tiles is None:
|
||||
@@ -612,14 +977,23 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
text=text,
|
||||
videos=videos,
|
||||
max_num_tiles=1,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
|
||||
text_inputs = self.tokenizer(text, add_special_tokens=False)
|
||||
|
||||
combined_outputs = {**text_inputs, **image_inputs, **video_inputs}
|
||||
|
||||
return BatchFeature(combined_outputs, tensor_type=return_tensors)
|
||||
if self.dynamic_tiler is None:
|
||||
batch = BatchFeature(
|
||||
{**text_inputs, **video_inputs, **image_inputs},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
else:
|
||||
batch = BatchFeature(
|
||||
{**text_inputs, **video_inputs}, tensor_type=return_tensors
|
||||
)
|
||||
# allow images to be exempt from the BatchFeature validation:
|
||||
# We will .stack() them in _parse_and_validate_image_input
|
||||
batch.update(image_inputs)
|
||||
return batch
|
||||
|
||||
def get_image_repl(
|
||||
self,
|
||||
@@ -722,23 +1096,6 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": None}
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
max_num_tiles: int,
|
||||
processor: BaseNanoNemotronVLProcessor | None,
|
||||
) -> int:
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
return processor.get_num_image_tokens(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
max_num_tiles=max_num_tiles,
|
||||
)
|
||||
|
||||
def get_image_size_with_most_features(self, max_num_tiles: int) -> ImageSize:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
@@ -749,11 +1106,8 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
|
||||
for wr, hr in target_ratios:
|
||||
width, height = base_size * wr, base_size * hr
|
||||
|
||||
feat_size = self.get_num_image_tokens(
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
max_num_tiles=max_num_tiles,
|
||||
processor=processor,
|
||||
feat_size = processor.get_num_image_tokens(
|
||||
image_width=width, image_height=height, max_num_tiles=max_num_tiles
|
||||
)
|
||||
if feat_size > largest_feature_size:
|
||||
largest_feature_size = feat_size
|
||||
@@ -772,11 +1126,10 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
|
||||
max_num_tiles
|
||||
)
|
||||
|
||||
return self.get_num_image_tokens(
|
||||
return processor.get_num_image_tokens(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
max_num_tiles=max_num_tiles,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
|
||||
@@ -822,6 +1175,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
|
||||
tokenizer=self.get_tokenizer(),
|
||||
video_token=self.get_video_token(),
|
||||
video_pruning_rate=self.get_video_pruning_rate(),
|
||||
max_model_len=self.ctx.model_config.max_model_len,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -829,19 +1183,29 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
|
||||
class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
"""Basic image-only MultiModalProcessor for InternVL-style models."""
|
||||
|
||||
@cached_property
def is_dynamic_tiler(self) -> bool:
    """Whether the HF processor was built with a dynamic-resolution tiler.

    Evaluated once per instance and then cached.
    """
    processor = self.info.get_hf_processor()
    return processor.dynamic_tiler is not None
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
||||
if self.is_dynamic_tiler:
|
||||
pixel_values_flat = MultiModalFieldConfig.batched("image")
|
||||
else:
|
||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
||||
pixel_values_flat = MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_num_patches
|
||||
)
|
||||
|
||||
return dict(
|
||||
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_num_patches
|
||||
),
|
||||
pixel_values_flat=pixel_values_flat,
|
||||
image_num_patches=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
num_tokens_per_image=MultiModalFieldConfig.batched("image"),
|
||||
imgs_sizes=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
@@ -870,17 +1234,19 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
if isinstance(images, ImageEmbeddingItems):
|
||||
feature_size = images.get_feature_size(item_idx)
|
||||
elif tiler := hf_processor.dynamic_tiler:
|
||||
image = images.get(item_idx)
|
||||
feature_size = tiler.get_cached_feature_size(image)
|
||||
else:
|
||||
image_size = images.get_image_size(item_idx)
|
||||
# Extract max_num_tiles from kwargs, default to 12
|
||||
max_num_tiles = hf_processor_mm_kwargs.get(
|
||||
"max_num_tiles", hf_processor.max_num_tiles
|
||||
)
|
||||
feature_size = self.info.get_num_image_tokens(
|
||||
feature_size = hf_processor.get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
max_num_tiles=max_num_tiles,
|
||||
processor=hf_processor,
|
||||
)
|
||||
|
||||
num_patches = None
|
||||
@@ -1017,12 +1383,18 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> MultiModalDataDict:
|
||||
# Use default max_num_tiles for dummy data generation
|
||||
max_num_tiles = 12
|
||||
target_width, target_height = self.info.get_image_size_with_most_features(
|
||||
max_num_tiles
|
||||
)
|
||||
num_images = mm_counts.get("image", 0)
|
||||
processor = self.info.get_hf_processor()
|
||||
if tiler := processor.dynamic_tiler:
|
||||
budget = tiler.max_num_tokens_available(text_prompt_length=num_images)
|
||||
target_width, target_height = (
|
||||
tiler.width_and_height_for_max_num_tokens_available(budget)
|
||||
)
|
||||
else:
|
||||
max_num_tiles = 12
|
||||
target_width, target_height = self.info.get_image_size_with_most_features(
|
||||
max_num_tiles
|
||||
)
|
||||
|
||||
image_overrides = mm_options.get("image") if mm_options else None
|
||||
|
||||
@@ -1181,6 +1553,11 @@ class NemotronH_Nano_VL_V2(
|
||||
self._img_context_token_ids = tokenizer.encode(
|
||||
IMG_CONTEXT, add_special_tokens=False
|
||||
)
|
||||
self.dynamic_resolution = BaseNanoNemotronVLProcessor.use_dynamic_resolution(
|
||||
config
|
||||
)
|
||||
if self.dynamic_resolution:
|
||||
logger.info("Dynamic resolution is enabled for NanoNemotronVLProcessor")
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
@@ -1211,7 +1588,51 @@ class NemotronH_Nano_VL_V2(
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
return x
|
||||
|
||||
def extract_feature(self, pixel_values):
|
||||
def pixel_shuffle_dynamic_res(
    self, x: torch.Tensor, *, imgs_sizes: list[tuple[int, int]]
) -> torch.Tensor:
    """Apply the pixel-shuffle fold to each image of a packed token sequence.

    ``x`` is split along dim -2 into one chunk per entry of ``imgs_sizes``,
    using the lengths returned by ``calc_seq_lens``. Every chunk is reshaped
    to its patch grid, folded by ``self.downsample_ratio`` and flattened back
    to a token sequence; the chunks are then concatenated along dim -2.

    Args:
        x: Packed features whose dim -2 indexes tokens.
        imgs_sizes: One ``(height, width)`` pair per image; both values are
            floor-divided by ``self.patch_size`` to obtain the patch grid.

    Returns:
        The folded chunks concatenated along dim -2.
    """
    ratio = self.downsample_ratio
    patch = self.patch_size
    chunks = torch.split(x, calc_seq_lens(imgs_sizes, patch), dim=-2)

    folded = []
    for chunk, size in zip(chunks, imgs_sizes):
        grid = chunk.reshape(
            chunk.shape[0], size[0] // patch, size[1] // patch, -1
        )
        n, h, w, c = grid.size()

        # Fold the width axis into channels, swap the two grid axes, then
        # fold the (former) height axis into channels as well.
        grid = grid.view(n, h, int(w * ratio), int(c / ratio))
        grid = grid.permute(0, 2, 1, 3).contiguous()
        grid = grid.view(
            n,
            int(w * ratio),
            int(h * ratio),
            int(c / (ratio * ratio)),
        )

        if self.ps_version == "v2":
            # The "v2" variant swaps the grid axes back after folding.
            grid = grid.permute(0, 2, 1, 3).contiguous()

        folded.append(grid.reshape(grid.shape[0], -1, grid.shape[-1]))

    return torch.cat(folded, dim=-2)
|
||||
|
||||
def extract_feature_dynamic(
    self, pixel_values: torch.Tensor, imgs_sizes: list[tuple[int, int]]
):
    """Encode dynamic-resolution images into projected embeddings.

    The second output of ``self.vision_model`` is cast to bfloat16, folded
    by ``self.pixel_shuffle_dynamic_res`` and passed through ``self.mlp1``.

    Args:
        pixel_values: Input forwarded unchanged to ``self.vision_model``.
        imgs_sizes: Per-image sizes, forwarded to both the vision model and
            the pixel shuffle.

    Returns:
        The output of ``self.mlp1``.
    """
    # Only the second element returned by the vision model is used.
    _, features = self.vision_model(pixel_values, imgs_sizes=imgs_sizes)
    shuffled = self.pixel_shuffle_dynamic_res(
        features.to(dtype=torch.bfloat16), imgs_sizes=imgs_sizes
    )
    return self.mlp1(shuffled)
|
||||
|
||||
def extract_feature(self, pixel_values: torch.Tensor):
|
||||
# Process images in a micro-batch of at most 128 frames per call
|
||||
# This is done on purpose to ensure peak GPU ram usage of huge batch
|
||||
# (namely for really long videos with EVS ON) won't cause any problems
|
||||
@@ -1239,36 +1660,39 @@ class NemotronH_Nano_VL_V2(
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object
|
||||
) -> NanoNemotronVLImageInputs | None:
|
||||
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
|
||||
image_num_patches = kwargs.pop("image_num_patches", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
if pixel_values_flat is None and image_embeds is None:
|
||||
return None
|
||||
|
||||
if image_embeds is not None:
|
||||
if image_embeds := kwargs.pop("image_embeds", None):
|
||||
return NanoNemotronVLImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=image_embeds,
|
||||
)
|
||||
|
||||
if pixel_values_flat is not None:
|
||||
return NanoNemotronVLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values_flat=pixel_values_flat,
|
||||
num_patches=image_num_patches,
|
||||
if self.dynamic_resolution:
|
||||
pixel_values_flat = DynamicResolutionImageTiler.stack(
|
||||
kwargs.pop("pixel_values_flat"), self.patch_size
|
||||
)
|
||||
return NanoNemotronVLImagePixelInputsDynamic(
|
||||
pixel_values_flat=pixel_values_flat, **kwargs
|
||||
)
|
||||
else:
|
||||
return NanoNemotronVLImagePixelInputs(**kwargs)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
def _process_image_input_dynamic(
|
||||
self, image_input: NanoNemotronVLImagePixelInputsDynamic
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
image_embeds = self.extract_feature_dynamic(
|
||||
image_input.pixel_values_flat, image_input.imgs_sizes
|
||||
)
|
||||
num_tokens_per_image = image_input.num_tokens_per_image
|
||||
|
||||
if len(num_tokens_per_image) == 1:
|
||||
return (image_embeds.view(-1, self.config.text_config.hidden_size),)
|
||||
|
||||
image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size)
|
||||
return image_embeds.split(num_tokens_per_image)
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: NanoNemotronVLImageInputs
|
||||
self, image_input: NanoNemotronVLImagePixelInputs
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
assert self.vision_model is not None
|
||||
|
||||
image_embeds = self.extract_feature(image_input["pixel_values_flat"])
|
||||
num_patches = image_input["num_patches"]
|
||||
|
||||
@@ -1470,7 +1894,13 @@ class NemotronH_Nano_VL_V2(
|
||||
for modality in modalities:
|
||||
if modality == "images":
|
||||
image_input = modalities["images"]
|
||||
image_embeddings = self._process_image_input(image_input)
|
||||
if image_input["type"] == "image_embeds":
|
||||
image_embeddings = image_input["data"]
|
||||
elif self.dynamic_resolution:
|
||||
assert image_input["type"] == "pixel_values_dynamic"
|
||||
image_embeddings = self._process_image_input_dynamic(image_input)
|
||||
else:
|
||||
image_embeddings = self._process_image_input(image_input)
|
||||
multimodal_embeddings += tuple(image_embeddings)
|
||||
if modality == "videos":
|
||||
video_input = modalities["videos"]
|
||||
@@ -1652,33 +2082,6 @@ class NemotronH_Nano_VL_V2(
|
||||
if save_to_file and sys.stdout != original_stdout:
|
||||
sys.stdout = original_stdout
|
||||
|
||||
def get_model_info(self):
    """
    Summarize parameter counts and image settings as a dictionary.

    Parameters are grouped by the first dotted component of their name. For
    each group, "params" counts parameter tensors and "size" sums their
    element counts.
    """
    total_params = sum(p.numel() for p in self.parameters())

    component_info = {}
    for qualified_name, tensor in self.named_parameters():
        top_level = qualified_name.partition(".")[0]
        entry = component_info.setdefault(top_level, {"params": 0, "size": 0})
        entry["params"] += 1
        entry["size"] += tensor.numel()

    config = self.config
    image_settings = {
        "image_size": getattr(config, "force_image_size", None),
        "patch_size": getattr(config, "patch_size", None),
        "num_image_token": self.num_image_token,
        "downsample_ratio": self.downsample_ratio,
    }

    return {
        "model_name": "NemotronH_Nano_VL_V2",
        "total_parameters": total_params,
        # Two bytes per parameter (bfloat16), reported in MiB.
        "memory_estimate_mb": total_params * 2 / (1024**2),
        "components": component_info,
        "config": image_settings,
    }
|
||||
|
||||
def get_vit_model_from_radio_config(self, hf_config):
|
||||
hf_config_vision = hf_config.vision_config
|
||||
model_name = hf_config_vision.args.get("model")
|
||||
|
||||
@@ -21,7 +21,11 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.intern_vit import InternVisionEncoder
|
||||
from vllm.model_executor.models.intern_vit import (
|
||||
InternParallelAttention,
|
||||
InternVisionEncoder,
|
||||
InternVisionEncoderLayer,
|
||||
)
|
||||
|
||||
input_dim_t: TypeAlias = int | tuple[int, int]
|
||||
norm_t: TypeAlias = tuple[float, float, float] | torch.Tensor
|
||||
@@ -43,6 +47,15 @@ to_4tuple = _ntuple(4)
|
||||
to_ntuple = _ntuple
|
||||
|
||||
|
||||
def calc_seq_len(size: tuple[int, int], patch_size: int) -> int:
    """Return how many whole patches fit into an image of ``size``.

    ``size`` is unpacked as ``(h, w)``; each side is floor-divided by
    ``patch_size`` and the two counts are multiplied.
    """
    height, width = size
    rows = height // patch_size
    cols = width // patch_size
    return rows * cols
|
||||
|
||||
|
||||
def calc_seq_lens(sizes: list[tuple[int, int]], patch_size: int) -> list[int]:
    """Return ``calc_seq_len`` for every entry of ``sizes``, in order."""
    return [calc_seq_len(image_size, patch_size) for image_size in sizes]
|
||||
|
||||
|
||||
class ClsToken(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -164,15 +177,73 @@ class ViTPatchGenerator(nn.Module):
|
||||
nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
patches = self.embed_patches(x)
|
||||
patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
|
||||
patches = self.cls_token(patches)
|
||||
def forward(
|
||||
self, x: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
|
||||
) -> torch.Tensor:
|
||||
if imgs_sizes is not None:
|
||||
patches = self.embedder(x)
|
||||
patches, pos_enc = self.apply_pos_enc_dynamic(
|
||||
patches, imgs_sizes=imgs_sizes
|
||||
)
|
||||
patches = self.cls_token_dynamic(patches, imgs_sizes=imgs_sizes)
|
||||
else:
|
||||
patches = self.embed_patches(x)
|
||||
patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
|
||||
patches = self.cls_token(patches)
|
||||
patches = self.patch_normalizer(patches)
|
||||
if self.return_pos_enc:
|
||||
return patches, pos_enc
|
||||
return patches
|
||||
|
||||
def apply_pos_enc_dynamic(
    self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Add a per-image position encoding to a packed patch sequence.

    ``patches`` holds the patches of all images back to back along dim 1.
    For each entry of ``imgs_sizes`` the matching slice (its length comes
    from ``calc_seq_len``) gets ``self.get_pos_enc(batch, input_size=size)``
    added to it.

    The per-image results are collected and concatenated once at the end,
    rather than rebuilding the full packed tensor for every image.

    Args:
        patches: Packed patch embeddings; dim 1 indexes tokens.
        imgs_sizes: One size per image, in packing order.

    Returns:
        ``(patches_with_pos, pos_enc)`` where ``pos_enc`` is the per-image
        encodings concatenated along dim 1. ``pos_enc`` is ``None`` when
        ``self.abs_pos`` is falsy or ``imgs_sizes`` is empty; in both cases
        ``patches`` is returned unchanged.
    """
    if not self.abs_pos:
        return patches, None

    batch = patches.shape[0]
    offset = 0
    pieces = []
    pos_enc_list = []

    for size in imgs_sizes:
        seq_length = calc_seq_len(size, self.patch_size)
        pos_enc = self.get_pos_enc(batch, input_size=size)
        pieces.append(patches[:, offset : offset + seq_length, :] + pos_enc)
        pos_enc_list.append(pos_enc)
        offset += seq_length

    if not pos_enc_list:
        return patches, None

    # Keep any tokens beyond the last image untouched (normally an empty
    # slice), matching what the slice-and-concatenate form preserved.
    pieces.append(patches[:, offset:, :])

    return torch.cat(pieces, dim=1), torch.cat(pos_enc_list, dim=1)
|
||||
|
||||
def cls_token_dynamic(
    self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> torch.Tensor:
    """Insert the class token(s) in front of every image's patches.

    ``patches`` is a packed sequence (dim 1 indexes tokens). The result
    interleaves ``self.cls_token.token``, expanded to the batch size, with
    each image's slice of ``patches``. When the class token is disabled,
    ``patches`` is returned unchanged.
    """
    if not self.cls_token.enabled:
        return patches

    # The expanded token is identical for every image, so build it once.
    class_token = self.cls_token.token.unsqueeze(0).expand(
        patches.shape[0], -1, -1
    )

    pieces = []
    offset = 0
    for seq_len in calc_seq_lens(imgs_sizes, self.patch_size):
        pieces.append(class_token)
        pieces.append(patches[:, offset : offset + seq_len, :])
        offset += seq_len

    return torch.cat(pieces, dim=1)
|
||||
|
||||
@property
def apply_cls_token(self):
    """Whether the class token is enabled (mirrors ``self.cls_token.enabled``)."""
    cls_token = self.cls_token
    return cls_token.enabled
|
||||
@@ -406,6 +477,66 @@ class ViTPatchLinear(nn.Linear):
|
||||
self.patch_size = patch_size
|
||||
|
||||
|
||||
class RadioParallelAttention(InternParallelAttention):
    """``InternParallelAttention`` whose forward also accepts an attention mask."""

    def forward(
        self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run self-attention, optionally restricted by ``attn_mask``.

        Without a mask the call is delegated unchanged to the parent class.
        With a mask, q/k/v are projected here and passed to
        ``F.scaled_dot_product_attention`` together with the mask.
        """
        if attn_mask is None:
            return super().forward(x)

        batch, seq_len, _ = x.shape
        fused_qkv, _ = self.qkv(x)
        query, key, value = fused_qkv.chunk(3, dim=-1)

        if self.qk_normalization:
            query, key = self._apply_qk_norm(query, key)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return t.view(
                batch, seq_len, self.num_heads_per_partition, self.head_dim
            ).transpose(1, 2)

        attended = F.scaled_dot_product_attention(
            split_heads(query),
            split_heads(key),
            split_heads(value),
            attn_mask=attn_mask,
            scale=self.scale,
        )

        # Merge the heads back into the channel dimension before projecting.
        merged = attended.transpose(1, 2).reshape(batch, seq_len, -1)
        projected, _ = self.proj(merged)
        return projected
|
||||
|
||||
|
||||
class RadioVisionEncoderLayer(InternVisionEncoderLayer):
    """Encoder layer that uses ``RadioParallelAttention`` and forwards a mask."""

    def __init__(self, *args, **kwargs) -> None:
        """Build the parent layer with ``RadioParallelAttention`` as ``attn_cls``."""
        super().__init__(*args, attn_cls=RadioParallelAttention, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ):
        """Apply the attention and MLP residual branches, scaled by ``ls1``/``ls2``."""
        attn_out = self.attn(self.norm1(hidden_states), attn_mask=attn_mask)
        hidden_states = hidden_states + attn_out * self.ls1

        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out * self.ls2
|
||||
|
||||
|
||||
class RadioVisionEncoder(InternVisionEncoder):
    """Encoder stack built from ``RadioVisionEncoderLayer`` that threads a mask through."""

    def __init__(self, *args, **kwargs) -> None:
        """Build the parent encoder with ``RadioVisionEncoderLayer`` as ``layer_cls``."""
        super().__init__(*args, layer_cls=RadioVisionEncoderLayer, **kwargs)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ):
        """Pass ``inputs_embeds`` through every layer, giving each the same ``attn_mask``."""
        states = inputs_embeds
        for layer in self.layers:
            states = layer(states, attn_mask=attn_mask)
        return states
|
||||
|
||||
|
||||
class RadioInternVisionModel(nn.Module):
|
||||
packed_modules_mapping = {
|
||||
"qkv": ["qkv"],
|
||||
@@ -440,7 +571,7 @@ class RadioInternVisionModel(nn.Module):
|
||||
register_multiple=config.register_multiple,
|
||||
)
|
||||
|
||||
self.encoder = InternVisionEncoder(
|
||||
self.encoder = RadioVisionEncoder(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
@@ -459,10 +590,45 @@ class RadioInternVisionModel(nn.Module):
|
||||
def get_input_embeddings(self):
    """Return the module stored as ``self.embeddings``."""
    embeddings = self.embeddings
    return embeddings
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.FloatTensor:
|
||||
def create_inter_image_attention_mask(
    self, imgs_sizes: list[tuple[int, int]], device: torch.device
) -> torch.Tensor:
    """Build a boolean mask that keeps attention inside each image.

    Every image contributes its patch count (from ``calc_seq_lens``) plus
    ``patch_generator.num_skip`` extra tokens. The result is a square
    boolean tensor over all tokens in which ``True`` marks pairs that
    belong to the same image.

    Args:
        imgs_sizes: One size per image, in packing order.
        device: Device on which the mask is allocated.
    """
    patch_size = self.patch_generator.patch_size
    num_skip = self.patch_generator.num_skip

    tokens_per_image = [
        n_patches + num_skip for n_patches in calc_seq_lens(imgs_sizes, patch_size)
    ]
    total_tokens = sum(tokens_per_image)

    # Start fully masked out; only the diagonal blocks are opened below.
    mask = torch.zeros(total_tokens, total_tokens, dtype=torch.bool, device=device)

    offset = 0
    for count in tokens_per_image:
        block = slice(offset, offset + count)
        mask[block, block] = True
        offset += count

    return mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
imgs_sizes: torch.Tensor | None = None,
|
||||
) -> torch.FloatTensor:
|
||||
assert self.patch_generator is not None
|
||||
hidden_states = self.patch_generator(x)
|
||||
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
|
||||
hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
|
||||
attn_mask = None
|
||||
if imgs_sizes is not None and len(imgs_sizes) > 1:
|
||||
# Dynamic Resolution
|
||||
attn_mask = self.create_inter_image_attention_mask(
|
||||
imgs_sizes, device=x.device
|
||||
)
|
||||
encoder_outputs = self.encoder(inputs_embeds=hidden_states, attn_mask=attn_mask)
|
||||
return encoder_outputs
|
||||
|
||||
|
||||
@@ -504,9 +670,11 @@ class RadioModel(nn.Module):
|
||||
self,
|
||||
pixel_values: torch.Tensor | None = None,
|
||||
pixel_embeds: torch.Tensor | None = None,
|
||||
*,
|
||||
imgs_sizes: torch.Tensor | None = None,
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
y = self.model(pixel_values)
|
||||
return self._extract_final(y)
|
||||
y = self.model(pixel_values, imgs_sizes=imgs_sizes)
|
||||
return self._extract_final(y, imgs_sizes=imgs_sizes)
|
||||
|
||||
def load_weights(self, weights) -> set[str]:
|
||||
loaded_params: set[str] = set()
|
||||
@@ -558,16 +726,32 @@ class RadioModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
def _extract_final(
|
||||
self, y: torch.Tensor
|
||||
self, y: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
# Remove CLS + REGISTERS tokens
|
||||
patch_gen = getattr(self.model, "patch_generator", None)
|
||||
if patch_gen is not None:
|
||||
all_summary = y[:, : patch_gen.num_cls_tokens]
|
||||
if self.summary_idxs is not None:
|
||||
bb_summary = all_summary[:, self.summary_idxs]
|
||||
else:
|
||||
bb_summary = all_summary
|
||||
all_feat = y[:, patch_gen.num_skip :]
|
||||
num_skip = self.model.patch_generator.num_skip
|
||||
patch_size = self.model.patch_generator.patch_size
|
||||
num_cls_tokens = self.model.patch_generator.num_cls_tokens
|
||||
if imgs_sizes is None:
|
||||
all_summary = y[:, :num_cls_tokens]
|
||||
all_feat = y[:, num_skip:]
|
||||
else:
|
||||
all_patches = []
|
||||
summaries = []
|
||||
current_pos = 0
|
||||
for num_patches in calc_seq_lens(imgs_sizes, patch_size):
|
||||
patches = y[
|
||||
:, current_pos + num_skip : current_pos + num_skip + num_patches, :
|
||||
]
|
||||
all_patches.append(patches)
|
||||
summary = y[:, current_pos : current_pos + num_cls_tokens, :]
|
||||
summaries.append(summary)
|
||||
current_pos += num_skip + num_patches
|
||||
all_summary = torch.cat(summaries, dim=1)
|
||||
all_feat = torch.cat(all_patches, dim=1)
|
||||
|
||||
if self.summary_idxs is not None:
|
||||
bb_summary = all_summary[:, self.summary_idxs]
|
||||
else:
|
||||
bb_summary = all_summary
|
||||
return bb_summary.flatten(1), all_feat
|
||||
|
||||
Reference in New Issue
Block a user