support dynamic resolution image encoding for Nemotron Nano VL (#32121)

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
This commit is contained in:
Netanel Haber
2026-01-19 20:15:58 +02:00
committed by GitHub
parent 2636d76257
commit cd3ac5b797
3 changed files with 754 additions and 163 deletions

View File

@@ -282,12 +282,14 @@ class InternVisionEncoderLayer(nn.Module):
num_dummy_heads: int = 0,
prefix: str = "",
use_data_parallel: bool = False,
attn_cls: type[InternParallelAttention] = InternParallelAttention,
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type
self.attn_cls = attn_cls
self.attn = self._init_attn(
config,
@@ -327,7 +329,7 @@ class InternVisionEncoderLayer(nn.Module):
use_data_parallel = (
use_data_parallel or (num_heads + num_dummy_heads) % tp_size != 0
)
return InternParallelAttention(
return self.attn_cls(
config,
quant_config=quant_config,
num_dummy_heads=num_dummy_heads,
@@ -356,10 +358,12 @@ class InternVisionEncoder(nn.Module):
num_dummy_heads: int = 0,
prefix: str = "",
use_data_parallel: bool = False,
layer_cls: type[InternVisionEncoderLayer] = InternVisionEncoderLayer,
):
super().__init__()
self.config = config
self.layer_cls = layer_cls
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
@@ -368,7 +372,7 @@ class InternVisionEncoder(nn.Module):
self.layers = nn.ModuleList(
[
InternVisionEncoderLayer(
self.layer_cls(
config,
quant_config,
num_dummy_heads=num_dummy_heads,

View File

@@ -8,11 +8,15 @@
# --------------------------------------------------------
import copy
import math
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import einops
import numpy.typing as npt
import regex as re
import torch
@@ -23,6 +27,7 @@ from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -39,7 +44,7 @@ from vllm.model_executor.models.internvl import (
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
from vllm.model_executor.models.radio import RadioModel
from vllm.model_executor.models.radio import RadioModel, calc_seq_lens
from vllm.model_executor.models.utils import (
init_vllm_registered_model,
maybe_prefix,
@@ -78,6 +83,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .utils import _merge_multimodal_embeddings
logger = init_logger(__name__)
# Configure PIL to handle large images without warnings
# This prevents DecompressionBombWarning for legitimate large images
Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
@@ -103,11 +109,25 @@ class NanoNemotronVLImagePixelInputs(TensorSchema):
- w: Width of each image patch
"""
type: Literal["pixel_values"]
type: Literal["pixel_values"] = "pixel_values"
pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
class NanoNemotronVLImagePixelInputsDynamic(TensorSchema):
    """
    Dynamic-resolution image inputs (one variably-sized tensor of flattened
    patches instead of a fixed tile grid).

    Fields:
        pixel_values_flat: patch pixels for all images, flattened and
            concatenated along the sequence dimension
            (shape symbols: "bn" batch, "h"/"w" spatial).
        imgs_sizes: per-image (height, width) in pixels.
        num_tokens_per_image: per-image number of embedding tokens
            (post downsample / pixel shuffle).
    """

    # Discriminator used by _parse_and_validate_image_input / embedding merge.
    type: Literal["pixel_values_dynamic"] = "pixel_values_dynamic"
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bn", "h", "w")]
    # Pixel sizes, needed to undo the flattening per image downstream.
    imgs_sizes: list[tuple[int, int]]
    # Used to split the concatenated embeddings back into per-image chunks.
    num_tokens_per_image: list[int]
class NanoNemotronVLImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
@@ -121,7 +141,9 @@ class NanoNemotronVLImageEmbeddingInputs(TensorSchema):
NanoNemotronVLImageInputs: TypeAlias = (
NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddingInputs
NanoNemotronVLImagePixelInputs
| NanoNemotronVLImagePixelInputsDynamic
| NanoNemotronVLImageEmbeddingInputs
)
@@ -267,6 +289,329 @@ def calculate_timestamps(
return timestamps
class DynamicResolutionImageTiler:
    """Sizes, resizes and normalizes images for dynamic-resolution encoding.

    Instead of tiling each image into a fixed grid, this helper chooses a
    per-image target patch grid so that the total number of vision tokens
    across all images of a request fits the model's context budget
    (``max_model_len`` minus the text-prompt length), then resizes each
    image to exactly that grid and converts it to a tensor.
    """

    # Token-reduction stages. Each enabled stage merges a 2x2 patch
    # neighborhood into one token. Only pixel shuffle is enabled here;
    # __init__ asserts the resulting reduction factor is exactly 2 per axis.
    CONV_MERGING = False
    PIXEL_SHUFFLE = True
    USE_THUMBNAIL = False

    def __init__(
        self,
        *,
        max_model_len: int,
        patch_size: int,
        min_num_patches: int,
        max_num_patches: int,
        downsample_ratio: float,
        norm_mean: Sequence[float],
        norm_std: Sequence[float],
        factor_max: float = 1.0,
        use_thumbnail: bool = False,
    ) -> None:
        """
        Args:
            max_model_len: model context length; upper bound for the token
                budget computed in :meth:`max_num_tokens_available`.
            patch_size: side length (pixels) of one ViT patch.
            min_num_patches: per-image lower bound on the patch grid area.
            max_num_patches: per-image upper bound; ``<= 0`` means unbounded.
            downsample_ratio: fractional token-downsample ratio; must be
                exactly 0.5 (asserted below) so each axis is halved once.
            norm_mean / norm_std: per-channel normalization constants.
            factor_max: cap on the resize scale factor (1.0 = never upscale
                beyond the patch-rounded original size).
            use_thumbnail: unsupported; must be False.
        """
        assert use_thumbnail is False, "use_thumbnail is not supported"
        self._patch_size: int = patch_size
        self._max_model_len = max_model_len
        self._min_num_patches = min_num_patches
        # <= 0 encodes "no upper bound".
        self._max_num_patches = max_num_patches if max_num_patches > 0 else float("inf")
        self._factor_max = factor_max
        # (3, 1, 1) so they broadcast over a CHW image tensor.
        self.norm_mean = torch.tensor(norm_mean).reshape(3, 1, 1)
        self.norm_std = torch.tensor(norm_std).reshape(3, 1, 1)
        # NOTE(review): `T` is presumably torchvision.transforms (import not
        # visible in this chunk) — PIL -> RGB -> float CHW tensor in [0, 1].
        self._transform = T.Compose(
            [
                T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
                T.ToTensor(),
            ]
        )
        assert downsample_ratio < 1
        reduction_factor = 1 / downsample_ratio
        assert reduction_factor == 2.0
        # bool arithmetic: 2 ** (number of enabled reduction stages).
        self._downsample_ratio = int(reduction_factor) ** (
            self.PIXEL_SHUFFLE + self.CONV_MERGING
        )
        assert self._downsample_ratio == 2

    def _get_num_embeddings(self, width: int, height: int) -> int:
        """Number of embedding tokens an image of (width, height) px yields."""
        num_patches = (width // self._patch_size) * (height // self._patch_size)
        # Each downsample stage merges a ratio x ratio patch block into 1 token.
        num_tokens = num_patches // (self._downsample_ratio**2)
        return num_tokens

    def width_and_height_for_max_num_tokens_available(
        self,
        target_num_tokens_post_shuffle: int,
    ) -> tuple[int, int]:
        """
        TODO: optimize this so it squeezes closer to target number of tokens.
        Calculate square image dimensions that produce approximately
        `target_num_tokens_post_shuffle` tokens after pixel_shuffle
        (an underestimate, because of the integer square root).
        With pixel_shuffle enabled, each 2x2 patch grid becomes 1 token, so we
        need 4*B patches to get B tokens.
        Examples:
        >>> PATCH_SIZE = 16
        >>> DOWNSAMPLE_RATIO = 0.5
        >>> tiler = DynamicResolutionImageTiler(
        ...     max_model_len=16384,
        ...     patch_size=PATCH_SIZE,
        ...     downsample_ratio=DOWNSAMPLE_RATIO,
        ...     min_num_patches=4,
        ...     max_num_patches=0,
        ... )
        >>> width, height = tiler.width_and_height_for_max_num_tokens_available(
        ...     target_num_tokens_post_shuffle=8192,
        ... )
        >>> assert (width, height) == (2880, 2880)
        >>> assert (width // PATCH_SIZE) * (
        ...     height // PATCH_SIZE
        ... ) // 2**2 == 8100  # tokens post-shuffle
        >>> assert tiler._get_num_embeddings(width=width, height=height) == 8100
        """
        # isqrt floors, so the produced grid never exceeds the target budget.
        side_pixels = (
            math.isqrt(target_num_tokens_post_shuffle)
            * self._downsample_ratio
            * self._patch_size
        )
        assert isinstance(side_pixels, int) and side_pixels % self._patch_size == 0
        return side_pixels, side_pixels

    def max_num_tokens_available(self, text_prompt_length: int) -> int:
        """Context tokens left for image embeddings (minus a small reserve)."""
        # NOTE(review): the constant 4 presumably reserves special/boundary
        # tokens — confirm against the prompt template.
        return self._max_model_len - text_prompt_length - 4

    def _images_to_pixel_values_lst(
        self,
        text_prompt_length: int,
        images: list[Image.Image],
    ) -> tuple[list[torch.Tensor], list[int]]:
        """Resize+tensorize all images within budget; also return per-image
        embedding-token counts (feature sizes)."""
        num_tokens_available = self.max_num_tokens_available(text_prompt_length)
        params_per_image = self.compute_params(images, num_tokens_available)
        feature_sizes = []
        # The parameter name is re-bound here to collect the output tensors;
        # the original PIL list is no longer needed past compute_params.
        images = []
        for param in params_per_image:
            for t in self.apply_params(param):
                assert t.ndim == 3, f"{t.ndim=}: expected 3 dim tensor"
                images.append(t)
            feature_sizes.append(param.num_embeddings)
        return images, feature_sizes

    # Class-level cache keyed by id(PIL image) -> feature size, written by
    # compute_params and consumed exactly once by get_cached_feature_size.
    feature_size_cache: dict[int, int] = {}

    @classmethod
    def get_cached_feature_size(cls, image: Image.Image) -> int:
        """Pop (read-once) the cached feature size for `image`.

        Raises KeyError if compute_params was not called for this image
        object or the entry was already consumed.
        """
        feature_size = cls.feature_size_cache[id(image)]
        # hard assert that we only use the feature size once
        del cls.feature_size_cache[id(image)]
        return feature_size

    @dataclass
    class DynamicResolutionParams:
        """Resize plan for a single image."""

        media: Image.Image
        num_tiles: int
        num_embeddings: int
        # Target grid as (width, height) in PATCHES, not pixels.
        patch_size: tuple[int, int]

    def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]:
        """Resize the image to its planned patch grid and tensorize it."""
        # patch_size is (grid_w, grid_h); PIL resize takes (width, height) px.
        resized_img = params.media.resize(
            (
                params.patch_size[0] * self._patch_size,
                params.patch_size[1] * self._patch_size,
            )
        )
        processed_images = [resized_img]
        return [self._transform(img) for img in processed_images]

    def process_media(
        self,
        media: Image.Image,
        num_tokens_available: int,
    ) -> tuple[DynamicResolutionParams, int]:
        """Plan the resize for a single media item.

        Args:
            media: The media item to process
            num_tokens_available: Patch budget available for this media

        Returns:
            (DynamicResolutionParams for the media,
             actual patch count of the chosen grid)
        """
        current_num_tokens_available = num_tokens_available
        assert isinstance(media, Image.Image), (
            "Dynamic resolution is only supported for image media"
        )
        orig_width, orig_height = media.width, media.height
        # NOTE(review): round(x + 0.5) is intended as ceil; with banker's
        # rounding it overshoots by 1 for odd exact multiples of patch_size
        # (e.g. 3.0 -> 4) — confirm against the reference implementation.
        closest_patch_height = round(orig_height / self._patch_size + 0.5)
        closest_patch_width = round(orig_width / self._patch_size + 0.5)
        patches = closest_patch_height * closest_patch_width
        # Uniform down-scale so the grid area fits the budget, never
        # up-scaling beyond factor_max.
        factor = min(
            math.sqrt(current_num_tokens_available / patches), self._factor_max
        )
        target_patch_height = math.floor(factor * closest_patch_height)
        target_patch_width = math.floor(factor * closest_patch_width)
        # Consider self._min_num_patches if > current_num_tokens_available.
        if (
            current_num_tokens_available > self._min_num_patches
            and target_patch_height * target_patch_width < self._min_num_patches
        ):
            up_factor = math.sqrt(
                self._min_num_patches / (target_patch_height * target_patch_width)
            )
            target_patch_height = math.ceil(up_factor * target_patch_height)
            target_patch_width = math.ceil(up_factor * target_patch_width)
        # Round patch grid to be divisible by 2 (pixel-shuffle OR conv-merging)
        # or by 4 when BOTH are enabled (two successive 2x reductions)
        if self.PIXEL_SHUFFLE or self.CONV_MERGING:
            required_divisor = 4 if (self.PIXEL_SHUFFLE and self.CONV_MERGING) else 2
            # Prefer rounding UP when it still fits the budget, else round
            # down (but never below the divisor itself).
            rem_h = target_patch_height % required_divisor
            if rem_h != 0:
                inc_h = required_divisor - rem_h
                if (
                    target_patch_height + inc_h
                ) * target_patch_width <= current_num_tokens_available:
                    target_patch_height += inc_h
                else:
                    target_patch_height = max(
                        required_divisor, target_patch_height - rem_h
                    )
            rem_w = target_patch_width % required_divisor
            if rem_w != 0:
                inc_w = required_divisor - rem_w
                if (
                    target_patch_height * (target_patch_width + inc_w)
                    <= current_num_tokens_available
                ):
                    target_patch_width += inc_w
                else:
                    target_patch_width = max(
                        required_divisor, target_patch_width - rem_w
                    )
        # Calculate embeddings for the main dynamic resolution image
        num_embeddings = self._get_num_embeddings(
            target_patch_width * self._patch_size,
            target_patch_height * self._patch_size,
        )
        token_count = target_patch_width * target_patch_height
        # Add thumbnail embeddings if enabled and image area is below threshold
        num_tiles = 1  # Base dynamic resolution image
        return self.DynamicResolutionParams(
            media=media,
            num_tiles=num_tiles,
            num_embeddings=num_embeddings,
            patch_size=(target_patch_width, target_patch_height),
        ), token_count

    def compute_params(
        self,
        media_list: list[Image.Image],
        num_tokens_available: int | None = None,
    ) -> list[DynamicResolutionParams]:
        """Compute parameters for all media with iterative token budgeting.

        Args:
            media_list: List of media items to process
            num_tokens_available: Total number of tokens available across all
                media. NOTE(review): despite the Optional annotation, None
                would fail the arithmetic below — callers always pass an int.

        Returns:
            List of DynamicResolutionParams for each media item.
            Side effect: records each image's feature size in
            `feature_size_cache` (keyed by id) for later prompt expansion.
        """
        # Convert the post-shuffle token budget into a raw PATCH budget.
        num_tokens_available = (
            num_tokens_available
            * (4 if self.PIXEL_SHUFFLE else 1)
            * (4 if self.CONV_MERGING else 1)
        )
        # When the number of available tokens is too small,
        # allow self._min_num_patches per media and let the sample be truncated.
        num_tokens_available = max(
            num_tokens_available, self._min_num_patches * len(media_list)
        )
        # Clip the number of tokens available per media to >min and <max patches.
        num_tokens_available_per_media = [
            max(min(num_tokens_available, self._max_num_patches), self._min_num_patches)
            for _ in range(len(media_list))
        ]
        # prevent infinite loop in any case
        for _ in range(10):
            # Step 1: Process each media with current token budget
            params = []
            token_counts = []
            for media, tokens_for_media in zip(
                media_list, num_tokens_available_per_media
            ):
                param, token_count = self.process_media(media, tokens_for_media)
                params.append(param)
                token_counts.append(token_count)
                self.feature_size_cache[id(param.media)] = param.num_embeddings
            # Step 2: Check if total tokens is within budget
            total_tokens = sum(token_counts)
            if total_tokens <= num_tokens_available:
                # We're within budget, return the params
                return params
            # Step 3: We're over budget, need to scale down
            # Calculate scaling factor to get under budget
            scaling_factor = num_tokens_available / total_tokens
            # Recalculate token budgets for each media based on scaling
            # Each media gets a proportional share of the total budget
            scaled_down_num_tokens_available_per_media = [
                max(self._min_num_patches, int(token_count * scaling_factor))
                for token_count in token_counts
            ]
            scaled_down = any(
                [
                    scaled_down_num_tokens_available_per_media[i]
                    < num_tokens_available_per_media[i]
                    for i in range(len(num_tokens_available_per_media))
                ]
            )
            # If there wasn't scaling down, we're stuck with min_num_patches per media,
            # else try with the scaled down num_tokens_available_per_media.
            if not scaled_down:
                num_tokens_available_per_media = [self._min_num_patches] * len(
                    media_list
                )
            else:
                num_tokens_available_per_media = (
                    scaled_down_num_tokens_available_per_media
                )
        # The min-patches fallback above guarantees convergence well within
        # 10 iterations; reaching here indicates a logic error.
        ctx = f"{params=} {total_tokens=} {num_tokens_available=}"
        raise ValueError(
            f"Should be unreachable - `return params` above must be reached: {ctx}"
        )

    @staticmethod
    def stack(images: list[torch.Tensor], patch_size: int) -> torch.Tensor:
        """Flatten each CHW image into a (num_patches, c*p*p) patch sequence
        and concatenate all images into one (1, total_patches, c*p*p) tensor.
        """
        assert len(images) > 0, "No images to stack"

        def rearrange_img(x):
            # Row-major patch order: (patch_row, patch_col) -> sequence index.
            py = x.shape[-2] // patch_size
            px = x.shape[-1] // patch_size
            x = einops.rearrange(
                x,
                "c (py yy) (px xx) -> (py px) (c yy xx)",
                py=py,
                yy=patch_size,
                px=px,
                xx=patch_size,
            )
            return x

        imgs = [rearrange_img(img) for img in images]
        pixel_values_flat = torch.cat(imgs, dim=0).unsqueeze(0)
        return pixel_values_flat
class BaseNanoNemotronVLProcessor(ABC):
"""
This model doesn't define its own HF processor,
@@ -281,6 +626,7 @@ class BaseNanoNemotronVLProcessor(ABC):
config: PretrainedConfig,
tokenizer: TokenizerLike,
*args,
max_model_len: int,
max_num_tiles: int | None = None,
**kwargs,
) -> None:
@@ -292,15 +638,32 @@ class BaseNanoNemotronVLProcessor(ABC):
self.max_num_tiles = max_num_tiles or DEFAULT_NUM_TILES
image_size: int = config.force_image_size
patch_size: int = config.patch_size
downsample_ratio: int = config.downsample_ratio
self.num_image_token = int(
(image_size // patch_size) ** 2 * (config.downsample_ratio**2)
(image_size // patch_size) ** 2 * (downsample_ratio**2)
)
self.image_size = image_size
self.use_thumbnail: bool = config.use_thumbnail
self.norm_mean = torch.Tensor(config.norm_mean).reshape(1, 3, 1, 1)
self.norm_std = torch.Tensor(config.norm_std).reshape(1, 3, 1, 1)
self.dynamic_tiler: DynamicResolutionImageTiler | None = None
if self.use_dynamic_resolution(config):
self.dynamic_tiler = DynamicResolutionImageTiler(
max_model_len=max_model_len,
patch_size=patch_size,
downsample_ratio=downsample_ratio,
min_num_patches=config.vision_config.args["min_num_patches"],
max_num_patches=config.vision_config.args["max_num_patches"],
norm_mean=config.norm_mean,
norm_std=config.norm_std,
)
@staticmethod
def use_dynamic_resolution(config: PretrainedConfig) -> bool:
return "min_num_patches" in config.vision_config.args
@property
@abstractmethod
def image_token_id(self) -> int:
@@ -354,36 +717,61 @@ class BaseNanoNemotronVLProcessor(ABC):
text: list[str],
images: list[Image.Image],
max_num_tiles: int,
) -> tuple[list[str], dict[str, torch.Tensor]]:
) -> tuple[list[str], dict[str, Any]]:
if len(images) == 0:
image_inputs = {}
return text, image_inputs
if tiler := self.dynamic_tiler:
sans_images = text[0].replace("<image>", "")
text_prompt_length = len(
self.tokenizer(sans_images, add_special_tokens=False).input_ids
)
pixel_values_lst, num_tokens_per_image = tiler._images_to_pixel_values_lst(
text_prompt_length=text_prompt_length,
images=images,
)
imgs_sizes = [(pv.shape[-2], pv.shape[-1]) for pv in pixel_values_lst]
normalized = [
input_conditioner(img, tiler.norm_mean, tiler.norm_std)
for img in pixel_values_lst
]
image_num_patches = torch.tensor([1] * len(num_tokens_per_image))
image_inputs = {
"pixel_values_flat": normalized,
"imgs_sizes": imgs_sizes,
"num_tokens_per_image": num_tokens_per_image,
}
else:
pixel_values_lst = self._images_to_pixel_values_lst(images, max_num_tiles)
image_num_patches = torch.tensor([len(item) for item in pixel_values_lst])
pixel_values_flat = input_conditioner(
torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
)
image_inputs = {
"pixel_values_flat": input_conditioner(
torch.cat(pixel_values_lst), self.norm_mean, self.norm_std
),
"image_num_patches": torch.tensor(
[len(item) for item in pixel_values_lst]
),
"pixel_values_flat": pixel_values_flat,
"image_num_patches": image_num_patches,
}
num_tokens_per_image = [
self.num_image_token * len(item) for item in pixel_values_lst
]
assert len(text) == 1, (
"hf_processor is called on the output of get_dummy_text, "
"which should be a single string"
)
parts = [x for x in re.split(r"(<image>)", text[0]) if x]
assert parts.count("<image>") == len(pixel_values_lst), (
"the number of <image> tokens in the text should be the "
"same as the number of images"
)
assert len(text) == 1, (
"hf_processor is called on the output of get_dummy_text, "
"which should be a single string"
)
parts = [x for x in re.split(r"(<image>)", text[0]) if x]
assert parts.count("<image>") == len(pixel_values_lst), (
"the number of <image> tokens in the text should be the "
"same as the number of images"
)
for i, pixel_values in enumerate(pixel_values_lst):
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches)
parts[i] = parts[i].replace("<image>", image_repl.full)
text = ["".join(parts)]
for i, (feature_size, num_patches) in enumerate(
zip(num_tokens_per_image, image_num_patches, strict=True)
):
image_repl = self.get_image_repl(feature_size, num_patches)
parts[i] = parts[i].replace("<image>", image_repl.full)
text = ["".join(parts)]
return text, image_inputs
def _make_batch_input(self, input_item: Any | list[Any] | None = None):
@@ -393,6 +781,7 @@ class BaseNanoNemotronVLProcessor(ABC):
input_item = [input_item]
return input_item
@abstractmethod
def __call__(
self,
text: str | list[str] | None = None,
@@ -400,23 +789,7 @@ class BaseNanoNemotronVLProcessor(ABC):
return_tensors: str | TensorType | None = None,
max_num_tiles: int | None = None,
) -> BatchFeature:
# Use default if not provided
if max_num_tiles is None:
max_num_tiles = self.max_num_tiles
text, images = [self._make_batch_input(x) for x in (text, images)]
text, image_inputs = self._preprocess_image(
text=text,
images=images,
max_num_tiles=max_num_tiles,
)
text_inputs = self.tokenizer(text, add_special_tokens=False)
combined_outputs = {**text_inputs, **image_inputs}
return BatchFeature(combined_outputs, tensor_type=return_tensors)
raise NotImplementedError
class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
@@ -431,20 +804,16 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
config: PretrainedConfig,
tokenizer: TokenizerLike,
*,
max_model_len: int,
max_num_tiles: int | None = None,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
dynamic_image_size: bool | None = None,
video_token: str | None = None,
video_pruning_rate: float | None = None,
) -> None:
super().__init__(
config=config,
tokenizer=tokenizer,
max_model_len=max_model_len,
max_num_tiles=max_num_tiles,
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
# add extra video token for video processing
self.video_token = video_token
@@ -478,7 +847,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
self,
videos: list[npt.NDArray],
max_num_tiles: int,
dynamic_image_size: bool | None = None,
) -> list[torch.Tensor]:
return [
video_to_pixel_values(
@@ -495,7 +863,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
text: list[str],
videos: list[tuple[npt.NDArray, dict[str, Any]]],
max_num_tiles: int,
dynamic_image_size: bool | None = None,
):
if len(videos) == 0 or not self.supports_video:
video_inputs = {}
@@ -505,7 +872,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
pixel_values_lst_video = self._videos_to_pixel_values_lst(
videos_lst,
max_num_tiles=max_num_tiles,
dynamic_image_size=dynamic_image_size,
)
# We use frame duration in milliseconds (as integer) to ensure
@@ -592,7 +958,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None,
return_tensors: str | TensorType | None = None,
max_num_tiles: int | None = None,
dynamic_image_size: bool | None = None,
) -> BatchFeature:
# Use default if not provided
if max_num_tiles is None:
@@ -612,14 +977,23 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
text=text,
videos=videos,
max_num_tiles=1,
dynamic_image_size=dynamic_image_size,
)
text_inputs = self.tokenizer(text, add_special_tokens=False)
combined_outputs = {**text_inputs, **image_inputs, **video_inputs}
return BatchFeature(combined_outputs, tensor_type=return_tensors)
if self.dynamic_tiler is None:
batch = BatchFeature(
{**text_inputs, **video_inputs, **image_inputs},
tensor_type=return_tensors,
)
else:
batch = BatchFeature(
{**text_inputs, **video_inputs}, tensor_type=return_tensors
)
# allow images to be exempt from the BatchFeature validation:
# We will .stack() them in _parse_and_validate_image_input
batch.update(image_inputs)
return batch
def get_image_repl(
self,
@@ -722,23 +1096,6 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
max_num_tiles: int,
processor: BaseNanoNemotronVLProcessor | None,
) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
max_num_tiles=max_num_tiles,
)
def get_image_size_with_most_features(self, max_num_tiles: int) -> ImageSize:
processor = self.get_hf_processor()
@@ -749,11 +1106,8 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
for wr, hr in target_ratios:
width, height = base_size * wr, base_size * hr
feat_size = self.get_num_image_tokens(
image_width=width,
image_height=height,
max_num_tiles=max_num_tiles,
processor=processor,
feat_size = processor.get_num_image_tokens(
image_width=width, image_height=height, max_num_tiles=max_num_tiles
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
@@ -772,11 +1126,10 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
max_num_tiles
)
return self.get_num_image_tokens(
return processor.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
max_num_tiles=max_num_tiles,
processor=processor,
)
@@ -822,6 +1175,7 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
tokenizer=self.get_tokenizer(),
video_token=self.get_video_token(),
video_pruning_rate=self.get_video_pruning_rate(),
max_model_len=self.ctx.model_config.max_model_len,
**kwargs,
)
@@ -829,19 +1183,29 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""Basic image-only MultiModalProcessor for InternVL-style models."""
@cached_property
def is_dynamic_tiler(self) -> bool:
return self.info.get_hf_processor().dynamic_tiler is not None
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
if self.is_dynamic_tiler:
pixel_values_flat = MultiModalFieldConfig.batched("image")
else:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
pixel_values_flat = MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches
)
return dict(
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches
),
pixel_values_flat=pixel_values_flat,
image_num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
num_tokens_per_image=MultiModalFieldConfig.batched("image"),
imgs_sizes=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@@ -870,17 +1234,19 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx)
elif tiler := hf_processor.dynamic_tiler:
image = images.get(item_idx)
feature_size = tiler.get_cached_feature_size(image)
else:
image_size = images.get_image_size(item_idx)
# Extract max_num_tiles from kwargs, default to 12
max_num_tiles = hf_processor_mm_kwargs.get(
"max_num_tiles", hf_processor.max_num_tiles
)
feature_size = self.info.get_num_image_tokens(
feature_size = hf_processor.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
max_num_tiles=max_num_tiles,
processor=hf_processor,
)
num_patches = None
@@ -1017,12 +1383,18 @@ class NanoNemotronVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
# Use default max_num_tiles for dummy data generation
max_num_tiles = 12
target_width, target_height = self.info.get_image_size_with_most_features(
max_num_tiles
)
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
if tiler := processor.dynamic_tiler:
budget = tiler.max_num_tokens_available(text_prompt_length=num_images)
target_width, target_height = (
tiler.width_and_height_for_max_num_tokens_available(budget)
)
else:
max_num_tiles = 12
target_width, target_height = self.info.get_image_size_with_most_features(
max_num_tiles
)
image_overrides = mm_options.get("image") if mm_options else None
@@ -1181,6 +1553,11 @@ class NemotronH_Nano_VL_V2(
self._img_context_token_ids = tokenizer.encode(
IMG_CONTEXT, add_special_tokens=False
)
self.dynamic_resolution = BaseNanoNemotronVLProcessor.use_dynamic_resolution(
config
)
if self.dynamic_resolution:
logger.info("Dynamic resolution is enabled for NanoNemotronVLProcessor")
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
@@ -1211,7 +1588,51 @@ class NemotronH_Nano_VL_V2(
x = x.permute(0, 2, 1, 3).contiguous()
return x
def extract_feature(self, pixel_values):
def pixel_shuffle_dynamic_res(
    self, x: torch.Tensor, *, imgs_sizes: list[tuple[int, int]]
) -> torch.Tensor:
    """Apply per-image pixel shuffle to a concatenated patch sequence.

    `x` holds the patch embeddings of all images concatenated along the
    sequence dimension (dim=-2). Each image's slice is reshaped back to its
    2D patch grid, pixel-shuffled (2x2 patches -> 1 token with 4x channels),
    flattened again, and the results are re-concatenated.

    Assumes each image's patch grid (size // patch_size per axis) is even —
    the tiler rounds grids to a divisor of 2 upstream.
    """
    scale_factor = self.downsample_ratio  # e.g. 0.5: halves each grid axis
    patch_dim = self.patch_size
    # Per-image sequence lengths, used to slice x back into images.
    seq_lens = calc_seq_lens(imgs_sizes, patch_dim)
    splits = torch.split(x, seq_lens, dim=-2)
    out = []
    for i, sv in enumerate(splits):
        # Restore this image's (batch, grid_h, grid_w, channels) layout.
        h = imgs_sizes[i][0] // patch_dim
        w = imgs_sizes[i][1] // patch_dim
        sv = sv.reshape(sv.shape[0], h, w, -1)
        n, h, w, c = sv.size()
        # Fold pairs of columns into channels, then pairs of rows.
        sv = sv.view(n, h, int(w * scale_factor), int(c / scale_factor))
        sv = sv.permute(0, 2, 1, 3).contiguous()
        sv = sv.view(
            n,
            int(w * scale_factor),
            int(h * scale_factor),
            int(c / (scale_factor * scale_factor)),
        )
        if self.ps_version == "v2":
            # v2 restores (h, w) orientation after the shuffle.
            sv = sv.permute(0, 2, 1, 3).contiguous()
        # Flatten back to a (batch, tokens, channels) sequence.
        sv = sv.reshape(sv.shape[0], -1, sv.shape[-1])
        out.append(sv)
    x = torch.cat(out, dim=-2)
    return x
def extract_feature_dynamic(
    self, pixel_values: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> torch.Tensor:
    """Dynamic-resolution counterpart of extract_feature for images.

    Runs the vision tower on the flattened patch sequence, pixel-shuffles
    per image, and projects through mlp1 to the language hidden size.
    """
    # The vision model returns a (summary, per-patch features) pair; only
    # the per-patch features are used.
    _, vit_embeds = self.vision_model(pixel_values, imgs_sizes=imgs_sizes)
    # NOTE(review): unconditional cast to bfloat16 — presumably to match
    # mlp1's weights; confirm for non-bf16 deployments.
    vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
    vit_embeds = self.pixel_shuffle_dynamic_res(vit_embeds, imgs_sizes=imgs_sizes)
    vit_embeds = self.mlp1(vit_embeds)
    return vit_embeds
def extract_feature(self, pixel_values: torch.Tensor):
# Process images in a micro-batch of at most 128 frames per call
# This is done on purpose to ensure peak GPU ram usage of huge batch
# (namely for really long videos with EVS ON) won't cause any problems
@@ -1239,36 +1660,39 @@ class NemotronH_Nano_VL_V2(
def _parse_and_validate_image_input(
self, **kwargs: object
) -> NanoNemotronVLImageInputs | None:
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat is None and image_embeds is None:
return None
if image_embeds is not None:
if image_embeds := kwargs.pop("image_embeds", None):
return NanoNemotronVLImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
if pixel_values_flat is not None:
return NanoNemotronVLImagePixelInputs(
type="pixel_values",
pixel_values_flat=pixel_values_flat,
num_patches=image_num_patches,
if self.dynamic_resolution:
pixel_values_flat = DynamicResolutionImageTiler.stack(
kwargs.pop("pixel_values_flat"), self.patch_size
)
return NanoNemotronVLImagePixelInputsDynamic(
pixel_values_flat=pixel_values_flat, **kwargs
)
else:
return NanoNemotronVLImagePixelInputs(**kwargs)
raise AssertionError("This line should be unreachable.")
def _process_image_input_dynamic(
    self, image_input: NanoNemotronVLImagePixelInputsDynamic
) -> tuple[torch.Tensor, ...]:
    """Encode dynamic-resolution pixels into per-image embedding tensors.

    Returns one (num_tokens_i, hidden_size) tensor per image, split from
    the concatenated vision output via num_tokens_per_image.
    """
    image_embeds = self.extract_feature_dynamic(
        image_input.pixel_values_flat, image_input.imgs_sizes
    )
    num_tokens_per_image = image_input.num_tokens_per_image
    # Single image: no split needed, just flatten to (tokens, hidden).
    if len(num_tokens_per_image) == 1:
        return (image_embeds.view(-1, self.config.text_config.hidden_size),)
    image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size)
    return image_embeds.split(num_tokens_per_image)
def _process_image_input(
self, image_input: NanoNemotronVLImageInputs
self, image_input: NanoNemotronVLImagePixelInputs
) -> tuple[torch.Tensor, ...]:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["pixel_values_flat"])
num_patches = image_input["num_patches"]
@@ -1470,7 +1894,13 @@ class NemotronH_Nano_VL_V2(
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
image_embeddings = self._process_image_input(image_input)
if image_input["type"] == "image_embeds":
image_embeddings = image_input["data"]
elif self.dynamic_resolution:
assert image_input["type"] == "pixel_values_dynamic"
image_embeddings = self._process_image_input_dynamic(image_input)
else:
image_embeddings = self._process_image_input(image_input)
multimodal_embeddings += tuple(image_embeddings)
if modality == "videos":
video_input = modalities["videos"]
@@ -1652,33 +2082,6 @@ class NemotronH_Nano_VL_V2(
if save_to_file and sys.stdout != original_stdout:
sys.stdout = original_stdout
def get_model_info(self):
    """
    Get basic model information as a dictionary.

    Returns total parameter count, a bfloat16-based memory estimate (MiB),
    per-top-level-component stats, and selected config fields.
    """
    total_params = sum(p.numel() for p in self.parameters())
    component_info = {}
    # Group parameters by the first dotted name component (e.g.
    # "vision_model"). "params" counts parameter TENSORS; "size" counts
    # scalar elements.
    for name, param in self.named_parameters():
        component = name.split(".")[0]
        if component not in component_info:
            component_info[component] = {"params": 0, "size": 0}
        component_info[component]["params"] += 1
        component_info[component]["size"] += param.numel()
    return {
        "model_name": "NemotronH_Nano_VL_V2",
        "total_parameters": total_params,
        # 2 bytes per parameter, assuming bfloat16 weights.
        "memory_estimate_mb": total_params * 2 / (1024**2),  # bfloat16
        "components": component_info,
        "config": {
            "image_size": getattr(self.config, "force_image_size", None),
            "patch_size": getattr(self.config, "patch_size", None),
            "num_image_token": self.num_image_token,
            "downsample_ratio": self.downsample_ratio,
        },
    }
def get_vit_model_from_radio_config(self, hf_config):
hf_config_vision = hf_config.vision_config
model_name = hf_config_vision.args.get("model")

View File

@@ -21,7 +21,11 @@ from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.intern_vit import InternVisionEncoder
from vllm.model_executor.models.intern_vit import (
InternParallelAttention,
InternVisionEncoder,
InternVisionEncoderLayer,
)
input_dim_t: TypeAlias = int | tuple[int, int]
norm_t: TypeAlias = tuple[float, float, float] | torch.Tensor
@@ -43,6 +47,15 @@ to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def calc_seq_len(size: tuple[int, int], patch_size: int) -> int:
    """Number of non-overlapping patch_size x patch_size patches in an image.

    Args:
        size: Image (height, width) in pixels.
        patch_size: Side length of a square patch.

    Returns:
        Patch-sequence length, i.e. floor(h / p) * floor(w / p).
    """
    height, width = size
    return (height // patch_size) * (width // patch_size)


def calc_seq_lens(sizes: list[tuple[int, int]], patch_size: int) -> list[int]:
    """Per-image patch-sequence lengths for a batch of image sizes."""
    return [calc_seq_len(image_size, patch_size) for image_size in sizes]
class ClsToken(nn.Module):
def __init__(
self,
@@ -164,15 +177,73 @@ class ViTPatchGenerator(nn.Module):
nn.LayerNorm(embed_dim) if normalize_patches else nn.Identity()
)
def forward(
    self, x: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
) -> torch.Tensor:
    """Turn pixel input into (optionally normalized) patch embeddings.

    Fix: the pre-change fixed-resolution ``forward`` definition had been
    left above this one, where it was dead code (immediately shadowed by
    this redefinition); it is removed.

    Args:
        x: Pixel tensor. When ``imgs_sizes`` is given, several images'
            patch sequences are assumed packed along the sequence dim
            in that order.
        imgs_sizes: Per-image (height, width) sizes selecting the dynamic
            resolution path; ``None`` selects the fixed-resolution path.

    Returns:
        Patch embeddings, or ``(patches, pos_enc)`` when
        ``self.return_pos_enc`` is set.
    """
    if imgs_sizes is not None:
        # Dynamic resolution: per-image positional encodings and CLS tokens.
        patches = self.embedder(x)
        patches, pos_enc = self.apply_pos_enc_dynamic(
            patches, imgs_sizes=imgs_sizes
        )
        patches = self.cls_token_dynamic(patches, imgs_sizes=imgs_sizes)
    else:
        patches = self.embed_patches(x)
        patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
        patches = self.cls_token(patches)
    patches = self.patch_normalizer(patches)
    if self.return_pos_enc:
        return patches, pos_enc
    return patches
def apply_pos_enc_dynamic(
    self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Add per-image absolute positional encodings to packed patches.

    ``patches`` is assumed to hold each image's patch sequence back to
    back along dim 1, in the order of ``imgs_sizes``.

    Fix: the original rebuilt the entire packed tensor with ``torch.cat``
    on every iteration (O(n_images^2) copies); this version adds the
    positional encoding to each image's slice once and concatenates a
    single time, preserving any trailing positions unchanged.

    Returns:
        (patches with positional encodings added, concatenated positional
        encodings) — or ``(patches, None)`` when ``abs_pos`` is disabled.
    """
    if not self.abs_pos:
        return patches, None
    segments: list[torch.Tensor] = []
    pos_enc_list: list[torch.Tensor] = []
    offset = 0
    for size in imgs_sizes:
        seq_length = calc_seq_len(size, self.patch_size)
        pos_enc = self.get_pos_enc(patches.shape[0], input_size=size)
        segments.append(patches[:, offset : offset + seq_length, :] + pos_enc)
        pos_enc_list.append(pos_enc)
        offset += seq_length
    # Keep any positions past the covered length untouched, matching the
    # original slice-splice behavior.
    if offset < patches.shape[1]:
        segments.append(patches[:, offset:, :])
    if segments:
        patches = torch.cat(segments, dim=1)
    full_pos_enc = torch.cat(pos_enc_list, dim=1) if pos_enc_list else None
    return patches, full_pos_enc
def cls_token_dynamic(
    self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> torch.Tensor:
    """Prepend the CLS token to every image's patch segment.

    ``patches`` packs the images' patch sequences along dim 1 in the order
    of ``imgs_sizes``. No-op when the CLS token is disabled.
    """
    if not self.cls_token.enabled:
        return patches
    batch = patches.shape[0]
    pieces: list[torch.Tensor] = []
    offset = 0
    for seq_len in calc_seq_lens(imgs_sizes, self.patch_size):
        pieces.append(self.cls_token.token.unsqueeze(0).expand(batch, -1, -1))
        pieces.append(patches[:, offset : offset + seq_len, :])
        offset += seq_len
    return torch.cat(pieces, dim=1)
@property
def apply_cls_token(self):
    """Whether a CLS token is prepended to the patch sequence."""
    return self.cls_token.enabled
@@ -406,6 +477,66 @@ class ViTPatchLinear(nn.Linear):
self.patch_size = patch_size
class RadioParallelAttention(InternParallelAttention):
    """InternParallelAttention that accepts an optional attention mask.

    The mask is used in the dynamic-resolution path to keep images packed
    into one sequence from attending to each other.
    """

    def forward(
        self, x: torch.Tensor, attn_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        # Fast path: without a mask the parent implementation is equivalent.
        if attn_mask is None:
            return super().forward(x)

        batch, seq_len, _ = x.shape
        qkv, _ = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        if self.qk_normalization:
            q, k = self._apply_qk_norm(q, k)

        def _to_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, N, H*D) -> (B, H, N, D)
            return t.view(
                batch, seq_len, self.num_heads_per_partition, self.head_dim
            ).transpose(1, 2)

        attn_out = F.scaled_dot_product_attention(
            _to_heads(q),
            _to_heads(k),
            _to_heads(v),
            attn_mask=attn_mask,
            scale=self.scale,
        )
        attn_out = attn_out.transpose(1, 2).reshape(batch, seq_len, -1)
        out, _ = self.proj(attn_out)
        return out
class RadioVisionEncoderLayer(InternVisionEncoderLayer):
    """InternVisionEncoderLayer whose attention takes an optional mask."""

    def __init__(self, *args, **kwargs) -> None:
        # Force the mask-aware attention implementation.
        super().__init__(*args, attn_cls=RadioParallelAttention, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ):
        # Pre-norm residual blocks with layer scale; attn_mask is threaded
        # through the attention sub-layer only.
        attn_out = self.attn(self.norm1(hidden_states), attn_mask=attn_mask)
        hidden_states = hidden_states + attn_out * self.ls1
        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out * self.ls2
class RadioVisionEncoder(InternVisionEncoder):
    """InternVisionEncoder built from RadioVisionEncoderLayer so an optional
    attention mask can be forwarded to every layer."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, layer_cls=RadioVisionEncoderLayer, **kwargs)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ):
        out = inputs_embeds
        for layer in self.layers:
            out = layer(out, attn_mask=attn_mask)
        return out
class RadioInternVisionModel(nn.Module):
packed_modules_mapping = {
"qkv": ["qkv"],
@@ -440,7 +571,7 @@ class RadioInternVisionModel(nn.Module):
register_multiple=config.register_multiple,
)
self.encoder = InternVisionEncoder(
self.encoder = RadioVisionEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
@@ -459,10 +590,45 @@ class RadioInternVisionModel(nn.Module):
def get_input_embeddings(self):
    """Return the input-embeddings attribute of this vision model."""
    return self.embeddings
def create_inter_image_attention_mask(
    self, imgs_sizes: list[tuple[int, int]], device: torch.device
) -> torch.Tensor:
    """Build a block-diagonal boolean attention mask for a packed sequence.

    Fix: a stray, body-less ``def forward(...)`` header (a leftover of the
    pre-change single-image forward) preceded this method; it was dead and
    invalid, and is removed.

    Each image contributes ``seq_len + num_skip`` positions (its patches
    plus CLS/register tokens), and positions may only attend within their
    own image's span.

    Args:
        imgs_sizes: Per-image (height, width) sizes, in packing order.
        device: Device on which to allocate the mask.

    Returns:
        Bool tensor of shape (total, total); True means attention allowed.
    """
    patch_size = self.patch_generator.patch_size
    num_skip = self.patch_generator.num_skip
    patch_counts = [
        seq_len + num_skip for seq_len in calc_seq_lens(imgs_sizes, patch_size)
    ]
    total_patches = sum(patch_counts)
    # Default to False (masked out), then enable per-image diagonal blocks.
    mask = torch.zeros(
        total_patches, total_patches, dtype=torch.bool, device=device
    )
    start_idx = 0
    for patch_count in patch_counts:
        end_idx = start_idx + patch_count
        mask[start_idx:end_idx, start_idx:end_idx] = True
        start_idx = end_idx
    return mask
def forward(
    self,
    x: torch.Tensor,
    imgs_sizes: torch.Tensor | None = None,
) -> torch.FloatTensor:
    """Encode pixels through the patch generator and the vision encoder.

    Fix: two stale pre-change statements (a patch_generator call without
    ``imgs_sizes`` and an early encoder call without the mask) were left
    interleaved here as dead stores and duplicated work; they are removed.

    Args:
        x: Pixel tensor; per-image patch sequences are packed when
            ``imgs_sizes`` is provided.
        imgs_sizes: Optional per-image sizes enabling the dynamic
            resolution path.

    Returns:
        Encoder output hidden states.
    """
    assert self.patch_generator is not None
    hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
    attn_mask = None
    if imgs_sizes is not None and len(imgs_sizes) > 1:
        # Dynamic resolution: block-diagonal mask prevents cross-image
        # attention inside the packed sequence.
        attn_mask = self.create_inter_image_attention_mask(
            imgs_sizes, device=x.device
        )
    encoder_outputs = self.encoder(inputs_embeds=hidden_states, attn_mask=attn_mask)
    return encoder_outputs
@@ -504,9 +670,11 @@ class RadioModel(nn.Module):
self,
pixel_values: torch.Tensor | None = None,
pixel_embeds: torch.Tensor | None = None,
*,
imgs_sizes: torch.Tensor | None = None,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
y = self.model(pixel_values)
return self._extract_final(y)
y = self.model(pixel_values, imgs_sizes=imgs_sizes)
return self._extract_final(y, imgs_sizes=imgs_sizes)
def load_weights(self, weights) -> set[str]:
loaded_params: set[str] = set()
@@ -558,16 +726,32 @@ class RadioModel(nn.Module):
return loaded_params
def _extract_final(
    self, y: torch.Tensor, imgs_sizes: list[tuple[int, int]] | None = None
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    """Split encoder output into summary (CLS) tokens and patch features.

    Fix: the old and new versions of this method had been left interleaved
    (both signature lines present — a syntax error — plus the stale old
    body and a duplicated ``summary_idxs`` check); reconstructed to the
    intended new implementation.

    Per-image token layout: ``num_cls_tokens`` summary tokens, then the
    remaining skip tokens (``num_skip`` counts CLS + registers), then the
    patch tokens.

    Args:
        y: Encoder output of shape (batch, tokens, dim).
        imgs_sizes: Per-image sizes when several images are packed along
            the token dimension; ``None`` for the single fixed-size layout.

    Returns:
        (summary flattened from dim 1 onward, patch features).
    """
    num_skip = self.model.patch_generator.num_skip
    patch_size = self.model.patch_generator.patch_size
    num_cls_tokens = self.model.patch_generator.num_cls_tokens
    if imgs_sizes is None:
        all_summary = y[:, :num_cls_tokens]
        all_feat = y[:, num_skip:]
    else:
        all_patches = []
        summaries = []
        current_pos = 0
        for num_patches in calc_seq_lens(imgs_sizes, patch_size):
            all_patches.append(
                y[
                    :,
                    current_pos + num_skip : current_pos + num_skip + num_patches,
                    :,
                ]
            )
            summaries.append(y[:, current_pos : current_pos + num_cls_tokens, :])
            current_pos += num_skip + num_patches
        all_summary = torch.cat(summaries, dim=1)
        all_feat = torch.cat(all_patches, dim=1)
    # Optionally keep only a subset of summary tokens.
    if self.summary_idxs is not None:
        bb_summary = all_summary[:, self.summary_idxs]
    else:
        bb_summary = all_summary
    return bb_summary.flatten(1), all_feat