[Model] Nano Nemotron VL - fast media preprocessing (#35657)
Signed-off-by: Natan Bagrov <nbagrov@nvidia.com>
This commit is contained in:
@@ -17,11 +17,11 @@ from functools import cached_property
|
|||||||
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
|
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
import regex as re
|
import regex as re
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision.transforms as T
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||||
|
|
||||||
@@ -214,7 +214,12 @@ NanoNemotronVLVideoInputs: TypeAlias = (
|
|||||||
|
|
||||||
|
|
||||||
def dynamic_preprocess(
|
def dynamic_preprocess(
|
||||||
image, *, image_size=512, max_num_tiles=12, use_thumbnail=True, idx=0
|
image,
|
||||||
|
*,
|
||||||
|
image_size=512,
|
||||||
|
max_num_tiles=12,
|
||||||
|
use_thumbnail=True,
|
||||||
|
idx=0,
|
||||||
):
|
):
|
||||||
orig_width, orig_height = image.size
|
orig_width, orig_height = image.size
|
||||||
|
|
||||||
@@ -227,35 +232,44 @@ def dynamic_preprocess(
|
|||||||
image_size=image_size,
|
image_size=image_size,
|
||||||
use_thumbnail=False,
|
use_thumbnail=False,
|
||||||
)
|
)
|
||||||
# resize the image
|
|
||||||
resized_img = image.resize((target_width, target_height))
|
|
||||||
processed_images = []
|
|
||||||
for i in range(blocks):
|
|
||||||
box = (
|
|
||||||
(i % (target_width // image_size)) * image_size,
|
|
||||||
(i // (target_width // image_size)) * image_size,
|
|
||||||
((i % (target_width // image_size)) + 1) * image_size,
|
|
||||||
((i // (target_width // image_size)) + 1) * image_size,
|
|
||||||
)
|
|
||||||
# split the image
|
|
||||||
split_img = resized_img.crop(box)
|
|
||||||
processed_images.append(split_img)
|
|
||||||
assert len(processed_images) == blocks
|
|
||||||
if use_thumbnail and len(processed_images) != 1:
|
|
||||||
thumbnail_img = image.resize((image_size, image_size))
|
|
||||||
processed_images.append(thumbnail_img)
|
|
||||||
|
|
||||||
processed_images = [
|
image = np.asarray(
|
||||||
img.convert("RGB") if img.mode != "RGB" else img for img in processed_images
|
image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8
|
||||||
]
|
)
|
||||||
processed_images = [
|
|
||||||
T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC)(
|
image = torch.from_numpy(image).unsqueeze(0) # (1, H, W, 3)
|
||||||
img
|
image = image.permute(0, 3, 1, 2) # (1, 3, H, W)
|
||||||
|
|
||||||
|
resized_img = torch.nn.functional.interpolate(
|
||||||
|
image,
|
||||||
|
size=(target_height, target_width),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
antialias=True,
|
||||||
|
)
|
||||||
|
B, C, H, W = resized_img.shape
|
||||||
|
hp, wp = H // image_size, W // image_size
|
||||||
|
patches = (
|
||||||
|
resized_img.reshape(B, C, hp, image_size, wp, image_size)
|
||||||
|
.permute(0, 2, 4, 1, 3, 5)
|
||||||
|
.reshape(B * hp * wp, C, image_size, image_size)
|
||||||
|
/ 255.0
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_thumbnail and patches.shape[0] > 1:
|
||||||
|
thumb = (
|
||||||
|
torch.nn.functional.interpolate(
|
||||||
|
image,
|
||||||
|
size=(image_size, image_size),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
antialias=True,
|
||||||
|
)
|
||||||
|
/ 255.0
|
||||||
)
|
)
|
||||||
for img in processed_images
|
patches = torch.cat([patches, thumb], dim=0)
|
||||||
]
|
|
||||||
processed_images = [T.ToTensor()(img) for img in processed_images]
|
return list(patches)
|
||||||
return processed_images
|
|
||||||
|
|
||||||
|
|
||||||
def image_to_pixel_values(
|
def image_to_pixel_values(
|
||||||
@@ -287,22 +301,21 @@ def video_to_pixel_values(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert max_num_tiles == 1, "Video modality always uses one tile"
|
assert max_num_tiles == 1, "Video modality always uses one tile"
|
||||||
|
|
||||||
# Convert each frame to a single resized tile tensor consistent
|
# (num_frames, H, W, C) -> (num_frames, C, H, W)
|
||||||
# with image path
|
video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2)
|
||||||
frames_tensors: list[torch.Tensor] = []
|
|
||||||
for frame in video:
|
|
||||||
pil_frame = dynamic_preprocess(
|
|
||||||
Image.fromarray(frame, mode="RGB"),
|
|
||||||
image_size=input_size,
|
|
||||||
max_num_tiles=max_num_tiles,
|
|
||||||
use_thumbnail=use_thumbnail,
|
|
||||||
idx=0,
|
|
||||||
)
|
|
||||||
# dynamic_preprocess returns tensors already; take the single tile
|
|
||||||
assert len(pil_frame) >= 1
|
|
||||||
frames_tensors.append(pil_frame[-1])
|
|
||||||
|
|
||||||
return torch.stack(frames_tensors)
|
if video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size:
|
||||||
|
video_tensor = torch.nn.functional.interpolate(
|
||||||
|
video_tensor,
|
||||||
|
size=(input_size, input_size),
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
antialias=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
video_tensor = video_tensor / 255.0
|
||||||
|
|
||||||
|
return video_tensor
|
||||||
|
|
||||||
|
|
||||||
def input_conditioner(x, norm_mean, norm_std):
|
def input_conditioner(x, norm_mean, norm_std):
|
||||||
@@ -346,12 +359,6 @@ class DynamicResolutionImageTiler:
|
|||||||
self._factor_max = factor_max
|
self._factor_max = factor_max
|
||||||
self.norm_mean = torch.tensor(norm_mean).reshape(3, 1, 1)
|
self.norm_mean = torch.tensor(norm_mean).reshape(3, 1, 1)
|
||||||
self.norm_std = torch.tensor(norm_std).reshape(3, 1, 1)
|
self.norm_std = torch.tensor(norm_std).reshape(3, 1, 1)
|
||||||
self._transform = T.Compose(
|
|
||||||
[
|
|
||||||
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
|
|
||||||
T.ToTensor(),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
assert downsample_ratio < 1
|
assert downsample_ratio < 1
|
||||||
reduction_factor = 1 / downsample_ratio
|
reduction_factor = 1 / downsample_ratio
|
||||||
assert reduction_factor == 2.0
|
assert reduction_factor == 2.0
|
||||||
@@ -441,15 +448,25 @@ class DynamicResolutionImageTiler:
|
|||||||
patch_size: tuple[int, int]
|
patch_size: tuple[int, int]
|
||||||
|
|
||||||
def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]:
|
def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]:
|
||||||
resized_img = params.media.resize(
|
target_size = (
|
||||||
(
|
params.patch_size[1] * self._patch_size,
|
||||||
params.patch_size[0] * self._patch_size,
|
params.patch_size[0] * self._patch_size,
|
||||||
params.patch_size[1] * self._patch_size,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
processed_images = [resized_img]
|
image = np.asarray(
|
||||||
|
params.media.convert("RGB") if params.media.mode != "RGB" else params.media,
|
||||||
return [self._transform(img) for img in processed_images]
|
dtype=np.uint8,
|
||||||
|
)
|
||||||
|
resized_img = (
|
||||||
|
torch.nn.functional.interpolate(
|
||||||
|
torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2),
|
||||||
|
size=target_size,
|
||||||
|
mode="bicubic",
|
||||||
|
align_corners=False,
|
||||||
|
antialias=True,
|
||||||
|
)
|
||||||
|
/ 255.0
|
||||||
|
)
|
||||||
|
return list(resized_img)
|
||||||
|
|
||||||
def process_media(
|
def process_media(
|
||||||
self,
|
self,
|
||||||
@@ -803,6 +820,7 @@ class BaseNanoNemotronVLProcessor(ABC):
|
|||||||
image_repl = self.get_image_repl(feature_size, num_patches)
|
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||||
parts[i] = parts[i].replace("<image>", image_repl.full)
|
parts[i] = parts[i].replace("<image>", image_repl.full)
|
||||||
text = ["".join(parts)]
|
text = ["".join(parts)]
|
||||||
|
|
||||||
return text, image_inputs
|
return text, image_inputs
|
||||||
|
|
||||||
def _make_batch_input(self, input_item: Any | list[Any] | None = None):
|
def _make_batch_input(self, input_item: Any | list[Any] | None = None):
|
||||||
@@ -922,14 +940,14 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
|||||||
frames_indices_lst = [
|
frames_indices_lst = [
|
||||||
metadata["frames_indices"] for metadata in video_metadata_lst
|
metadata["frames_indices"] for metadata in video_metadata_lst
|
||||||
]
|
]
|
||||||
|
video_num_patches = torch.tensor(
|
||||||
|
[len(item) for item in pixel_values_lst_video]
|
||||||
|
)
|
||||||
video_inputs = {
|
video_inputs = {
|
||||||
"pixel_values_flat_video": input_conditioner(
|
"pixel_values_flat_video": input_conditioner(
|
||||||
torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std
|
torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std
|
||||||
),
|
),
|
||||||
"video_num_patches": torch.tensor(
|
"video_num_patches": video_num_patches,
|
||||||
[len(item) for item in pixel_values_lst_video]
|
|
||||||
),
|
|
||||||
"frames_indices": frames_indices_lst,
|
"frames_indices": frames_indices_lst,
|
||||||
"frame_duration_ms": torch.tensor(frame_duration_ms_lst),
|
"frame_duration_ms": torch.tensor(frame_duration_ms_lst),
|
||||||
}
|
}
|
||||||
@@ -985,6 +1003,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
|||||||
video_repl.full, skip_special_tokens=False
|
video_repl.full, skip_special_tokens=False
|
||||||
)
|
)
|
||||||
text = [t.replace("<video>", video_repl_text, 1) for t in text]
|
text = [t.replace("<video>", video_repl_text, 1) for t in text]
|
||||||
|
|
||||||
return text, video_inputs
|
return text, video_inputs
|
||||||
|
|
||||||
def _preprocess_audio(
|
def _preprocess_audio(
|
||||||
|
|||||||
Reference in New Issue
Block a user