[Model] Nano Nemotron VL - fast media preprocessing (#35657)

Signed-off-by: Natan Bagrov <nbagrov@nvidia.com>
This commit is contained in:
nvnbagrov
2026-03-08 12:04:05 +02:00
committed by GitHub
parent 40077ea3de
commit b7332b058c

View File

@@ -17,11 +17,11 @@ from functools import cached_property
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import einops
import numpy as np
import numpy.typing as npt
import regex as re
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType
@@ -214,7 +214,12 @@ NanoNemotronVLVideoInputs: TypeAlias = (
def dynamic_preprocess(
image, *, image_size=512, max_num_tiles=12, use_thumbnail=True, idx=0
image,
*,
image_size=512,
max_num_tiles=12,
use_thumbnail=True,
idx=0,
):
orig_width, orig_height = image.size
@@ -227,35 +232,44 @@ def dynamic_preprocess(
image_size=image_size,
use_thumbnail=False,
)
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
processed_images = [
img.convert("RGB") if img.mode != "RGB" else img for img in processed_images
]
processed_images = [
T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC)(
img
image = np.asarray(
image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8
)
image = torch.from_numpy(image).unsqueeze(0) # (1, H, W, 3)
image = image.permute(0, 3, 1, 2) # (1, 3, H, W)
resized_img = torch.nn.functional.interpolate(
image,
size=(target_height, target_width),
mode="bicubic",
align_corners=False,
antialias=True,
)
B, C, H, W = resized_img.shape
hp, wp = H // image_size, W // image_size
patches = (
resized_img.reshape(B, C, hp, image_size, wp, image_size)
.permute(0, 2, 4, 1, 3, 5)
.reshape(B * hp * wp, C, image_size, image_size)
/ 255.0
)
if use_thumbnail and patches.shape[0] > 1:
thumb = (
torch.nn.functional.interpolate(
image,
size=(image_size, image_size),
mode="bicubic",
align_corners=False,
antialias=True,
)
/ 255.0
)
for img in processed_images
]
processed_images = [T.ToTensor()(img) for img in processed_images]
return processed_images
patches = torch.cat([patches, thumb], dim=0)
return list(patches)
def image_to_pixel_values(
@@ -287,22 +301,21 @@ def video_to_pixel_values(
) -> torch.Tensor:
assert max_num_tiles == 1, "Video modality always uses one tile"
# Convert each frame to a single resized tile tensor consistent
# with image path
frames_tensors: list[torch.Tensor] = []
for frame in video:
pil_frame = dynamic_preprocess(
Image.fromarray(frame, mode="RGB"),
image_size=input_size,
max_num_tiles=max_num_tiles,
use_thumbnail=use_thumbnail,
idx=0,
)
# dynamic_preprocess returns tensors already; take the single tile
assert len(pil_frame) >= 1
frames_tensors.append(pil_frame[-1])
# (num_frames, H, W, C) -> (num_frames, C, H, W)
video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2)
return torch.stack(frames_tensors)
if video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size:
video_tensor = torch.nn.functional.interpolate(
video_tensor,
size=(input_size, input_size),
mode="bicubic",
align_corners=False,
antialias=True,
)
video_tensor = video_tensor / 255.0
return video_tensor
def input_conditioner(x, norm_mean, norm_std):
@@ -346,12 +359,6 @@ class DynamicResolutionImageTiler:
self._factor_max = factor_max
self.norm_mean = torch.tensor(norm_mean).reshape(3, 1, 1)
self.norm_std = torch.tensor(norm_std).reshape(3, 1, 1)
self._transform = T.Compose(
[
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
T.ToTensor(),
]
)
assert downsample_ratio < 1
reduction_factor = 1 / downsample_ratio
assert reduction_factor == 2.0
@@ -441,15 +448,25 @@ class DynamicResolutionImageTiler:
patch_size: tuple[int, int]
def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]:
resized_img = params.media.resize(
(
params.patch_size[0] * self._patch_size,
params.patch_size[1] * self._patch_size,
)
target_size = (
params.patch_size[1] * self._patch_size,
params.patch_size[0] * self._patch_size,
)
processed_images = [resized_img]
return [self._transform(img) for img in processed_images]
image = np.asarray(
params.media.convert("RGB") if params.media.mode != "RGB" else params.media,
dtype=np.uint8,
)
resized_img = (
torch.nn.functional.interpolate(
torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2),
size=target_size,
mode="bicubic",
align_corners=False,
antialias=True,
)
/ 255.0
)
return list(resized_img)
def process_media(
self,
@@ -803,6 +820,7 @@ class BaseNanoNemotronVLProcessor(ABC):
image_repl = self.get_image_repl(feature_size, num_patches)
parts[i] = parts[i].replace("<image>", image_repl.full)
text = ["".join(parts)]
return text, image_inputs
def _make_batch_input(self, input_item: Any | list[Any] | None = None):
@@ -922,14 +940,14 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
frames_indices_lst = [
metadata["frames_indices"] for metadata in video_metadata_lst
]
video_num_patches = torch.tensor(
[len(item) for item in pixel_values_lst_video]
)
video_inputs = {
"pixel_values_flat_video": input_conditioner(
torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std
),
"video_num_patches": torch.tensor(
[len(item) for item in pixel_values_lst_video]
),
"video_num_patches": video_num_patches,
"frames_indices": frames_indices_lst,
"frame_duration_ms": torch.tensor(frame_duration_ms_lst),
}
@@ -985,6 +1003,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
video_repl.full, skip_special_tokens=False
)
text = [t.replace("<video>", video_repl_text, 1) for t in text]
return text, video_inputs
def _preprocess_audio(