[Model] Nano Nemotron VL - fast media preprocessing (#35657)
Signed-off-by: Natan Bagrov <nbagrov@nvidia.com>
This commit is contained in:
@@ -17,11 +17,11 @@ from functools import cached_property
|
||||
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import regex as re
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
|
||||
@@ -214,7 +214,12 @@ NanoNemotronVLVideoInputs: TypeAlias = (
|
||||
|
||||
|
||||
def dynamic_preprocess(
|
||||
image, *, image_size=512, max_num_tiles=12, use_thumbnail=True, idx=0
|
||||
image,
|
||||
*,
|
||||
image_size=512,
|
||||
max_num_tiles=12,
|
||||
use_thumbnail=True,
|
||||
idx=0,
|
||||
):
|
||||
orig_width, orig_height = image.size
|
||||
|
||||
@@ -227,35 +232,44 @@ def dynamic_preprocess(
|
||||
image_size=image_size,
|
||||
use_thumbnail=False,
|
||||
)
|
||||
# resize the image
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // image_size)) * image_size,
|
||||
(i // (target_width // image_size)) * image_size,
|
||||
((i % (target_width // image_size)) + 1) * image_size,
|
||||
((i // (target_width // image_size)) + 1) * image_size,
|
||||
)
|
||||
# split the image
|
||||
split_img = resized_img.crop(box)
|
||||
processed_images.append(split_img)
|
||||
assert len(processed_images) == blocks
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = image.resize((image_size, image_size))
|
||||
processed_images.append(thumbnail_img)
|
||||
|
||||
processed_images = [
|
||||
img.convert("RGB") if img.mode != "RGB" else img for img in processed_images
|
||||
]
|
||||
processed_images = [
|
||||
T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC)(
|
||||
img
|
||||
image = np.asarray(
|
||||
image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8
|
||||
)
|
||||
|
||||
image = torch.from_numpy(image).unsqueeze(0) # (1, H, W, 3)
|
||||
image = image.permute(0, 3, 1, 2) # (1, 3, H, W)
|
||||
|
||||
resized_img = torch.nn.functional.interpolate(
|
||||
image,
|
||||
size=(target_height, target_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
antialias=True,
|
||||
)
|
||||
B, C, H, W = resized_img.shape
|
||||
hp, wp = H // image_size, W // image_size
|
||||
patches = (
|
||||
resized_img.reshape(B, C, hp, image_size, wp, image_size)
|
||||
.permute(0, 2, 4, 1, 3, 5)
|
||||
.reshape(B * hp * wp, C, image_size, image_size)
|
||||
/ 255.0
|
||||
)
|
||||
|
||||
if use_thumbnail and patches.shape[0] > 1:
|
||||
thumb = (
|
||||
torch.nn.functional.interpolate(
|
||||
image,
|
||||
size=(image_size, image_size),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
antialias=True,
|
||||
)
|
||||
/ 255.0
|
||||
)
|
||||
for img in processed_images
|
||||
]
|
||||
processed_images = [T.ToTensor()(img) for img in processed_images]
|
||||
return processed_images
|
||||
patches = torch.cat([patches, thumb], dim=0)
|
||||
|
||||
return list(patches)
|
||||
|
||||
|
||||
def image_to_pixel_values(
|
||||
@@ -287,22 +301,21 @@ def video_to_pixel_values(
|
||||
) -> torch.Tensor:
|
||||
assert max_num_tiles == 1, "Video modality always uses one tile"
|
||||
|
||||
# Convert each frame to a single resized tile tensor consistent
|
||||
# with image path
|
||||
frames_tensors: list[torch.Tensor] = []
|
||||
for frame in video:
|
||||
pil_frame = dynamic_preprocess(
|
||||
Image.fromarray(frame, mode="RGB"),
|
||||
image_size=input_size,
|
||||
max_num_tiles=max_num_tiles,
|
||||
use_thumbnail=use_thumbnail,
|
||||
idx=0,
|
||||
)
|
||||
# dynamic_preprocess returns tensors already; take the single tile
|
||||
assert len(pil_frame) >= 1
|
||||
frames_tensors.append(pil_frame[-1])
|
||||
# (num_frames, H, W, C) -> (num_frames, C, H, W)
|
||||
video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2)
|
||||
|
||||
return torch.stack(frames_tensors)
|
||||
if video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size:
|
||||
video_tensor = torch.nn.functional.interpolate(
|
||||
video_tensor,
|
||||
size=(input_size, input_size),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
antialias=True,
|
||||
)
|
||||
|
||||
video_tensor = video_tensor / 255.0
|
||||
|
||||
return video_tensor
|
||||
|
||||
|
||||
def input_conditioner(x, norm_mean, norm_std):
|
||||
@@ -346,12 +359,6 @@ class DynamicResolutionImageTiler:
|
||||
self._factor_max = factor_max
|
||||
self.norm_mean = torch.tensor(norm_mean).reshape(3, 1, 1)
|
||||
self.norm_std = torch.tensor(norm_std).reshape(3, 1, 1)
|
||||
self._transform = T.Compose(
|
||||
[
|
||||
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
|
||||
T.ToTensor(),
|
||||
]
|
||||
)
|
||||
assert downsample_ratio < 1
|
||||
reduction_factor = 1 / downsample_ratio
|
||||
assert reduction_factor == 2.0
|
||||
@@ -441,15 +448,25 @@ class DynamicResolutionImageTiler:
|
||||
patch_size: tuple[int, int]
|
||||
|
||||
def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]:
    """Resize ``params.media`` to its assigned patch grid and return it as tensors.

    The media is converted to RGB when its mode is not already ``"RGB"``,
    read as a ``uint8`` array, resized with antialiased bicubic
    interpolation, and divided by 255.0.

    Args:
        params: Carries ``media`` (an image exposing ``mode`` and
            ``convert``) and ``patch_size``, a pair of patch counts.
            NOTE(review): ``patch_size`` is presumably (width, height) in
            patches -- confirm against where the params are built.

    Returns:
        The resized batch split along its first dimension. The batch
        dimension comes from ``unsqueeze(0)``, so the list holds one tensor.
    """
    # interpolate() takes ``size`` as (height, width) for 4-D input, which is
    # why index 1 of patch_size comes before index 0 here.
    target_size = (
        params.patch_size[1] * self._patch_size,
        params.patch_size[0] * self._patch_size,
    )

    media = params.media
    if media.mode != "RGB":
        media = media.convert("RGB")
    image = np.asarray(media, dtype=np.uint8)

    # Add a leading batch dimension, then move the trailing channel axis
    # forward; assumes the array is laid out (H, W, 3) -- TODO confirm.
    batch = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2)

    resized_img = torch.nn.functional.interpolate(
        batch,
        size=target_size,
        mode="bicubic",
        align_corners=False,
        antialias=True,
    )

    # Scale only after resizing, so the interpolation runs on the uint8 data.
    return list(resized_img / 255.0)
|
||||
|
||||
def process_media(
|
||||
self,
|
||||
@@ -803,6 +820,7 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||
parts[i] = parts[i].replace("<image>", image_repl.full)
|
||||
text = ["".join(parts)]
|
||||
|
||||
return text, image_inputs
|
||||
|
||||
def _make_batch_input(self, input_item: Any | list[Any] | None = None):
|
||||
@@ -922,14 +940,14 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
frames_indices_lst = [
|
||||
metadata["frames_indices"] for metadata in video_metadata_lst
|
||||
]
|
||||
|
||||
video_num_patches = torch.tensor(
|
||||
[len(item) for item in pixel_values_lst_video]
|
||||
)
|
||||
video_inputs = {
|
||||
"pixel_values_flat_video": input_conditioner(
|
||||
torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std
|
||||
),
|
||||
"video_num_patches": torch.tensor(
|
||||
[len(item) for item in pixel_values_lst_video]
|
||||
),
|
||||
"video_num_patches": video_num_patches,
|
||||
"frames_indices": frames_indices_lst,
|
||||
"frame_duration_ms": torch.tensor(frame_duration_ms_lst),
|
||||
}
|
||||
@@ -985,6 +1003,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
video_repl.full, skip_special_tokens=False
|
||||
)
|
||||
text = [t.replace("<video>", video_repl_text, 1) for t in text]
|
||||
|
||||
return text, video_inputs
|
||||
|
||||
def _preprocess_audio(
|
||||
|
||||
Reference in New Issue
Block a user