diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 9b9beadc0..b32067557 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -17,11 +17,11 @@ from functools import cached_property from typing import Annotated, Any, Literal, TypeAlias, TypeVar import einops +import numpy as np import numpy.typing as npt import regex as re import torch import torch.nn as nn -import torchvision.transforms as T from PIL import Image from transformers import BatchFeature, PretrainedConfig, TensorType @@ -214,7 +214,12 @@ NanoNemotronVLVideoInputs: TypeAlias = ( def dynamic_preprocess( - image, *, image_size=512, max_num_tiles=12, use_thumbnail=True, idx=0 + image, + *, + image_size=512, + max_num_tiles=12, + use_thumbnail=True, + idx=0, ): orig_width, orig_height = image.size @@ -227,35 +232,44 @@ def dynamic_preprocess( image_size=image_size, use_thumbnail=False, ) - # resize the image - resized_img = image.resize((target_width, target_height)) - processed_images = [] - for i in range(blocks): - box = ( - (i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // (target_width // image_size)) + 1) * image_size, - ) - # split the image - split_img = resized_img.crop(box) - processed_images.append(split_img) - assert len(processed_images) == blocks - if use_thumbnail and len(processed_images) != 1: - thumbnail_img = image.resize((image_size, image_size)) - processed_images.append(thumbnail_img) - processed_images = [ - img.convert("RGB") if img.mode != "RGB" else img for img in processed_images - ] - processed_images = [ - T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC)( - img + image = np.asarray( + image.convert("RGB") if image.mode != "RGB" else image, dtype=np.uint8 + ) + + image = torch.from_numpy(image).unsqueeze(0) # (1, H, W, 3) + image = image.permute(0, 3, 1, 2) # (1, 3, H, W) + + resized_img = torch.nn.functional.interpolate( + image, + size=(target_height, target_width), + mode="bicubic", + align_corners=False, + antialias=True, + ) + B, C, H, W = resized_img.shape + hp, wp = H // image_size, W // image_size + patches = ( + resized_img.reshape(B, C, hp, image_size, wp, image_size) + .permute(0, 2, 4, 1, 3, 5) + .reshape(B * hp * wp, C, image_size, image_size) + / 255.0 + ) + + if use_thumbnail and patches.shape[0] > 1: + thumb = ( + torch.nn.functional.interpolate( + image, + size=(image_size, image_size), + mode="bicubic", + align_corners=False, + antialias=True, + ) + / 255.0 ) - for img in processed_images - ] - processed_images = [T.ToTensor()(img) for img in processed_images] - return processed_images + patches = torch.cat([patches, thumb], dim=0) + + return list(patches) def image_to_pixel_values( @@ -287,22 +301,21 @@ def video_to_pixel_values( ) -> torch.Tensor: assert max_num_tiles == 1, "Video modality always uses one tile" - # Convert each frame to a single resized tile tensor consistent - # with image path - frames_tensors: list[torch.Tensor] = [] - for frame in video: - pil_frame = dynamic_preprocess( - Image.fromarray(frame, mode="RGB"), - image_size=input_size, - max_num_tiles=max_num_tiles, - use_thumbnail=use_thumbnail, - idx=0, - ) - # dynamic_preprocess returns tensors already; take the single tile - assert len(pil_frame) >= 1 - frames_tensors.append(pil_frame[-1]) + # (num_frames, H, W, C) -> (num_frames, C, H, W) + video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2) - return torch.stack(frames_tensors) + if video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size: + video_tensor = torch.nn.functional.interpolate( + video_tensor, + size=(input_size, input_size), + mode="bicubic", + align_corners=False, + antialias=True, + ) + + video_tensor = video_tensor / 255.0 + + return video_tensor def input_conditioner(x, norm_mean, norm_std): @@ -346,12 +359,6 @@ class DynamicResolutionImageTiler: self._factor_max = factor_max self.norm_mean = torch.tensor(norm_mean).reshape(3, 1, 1) self.norm_std = torch.tensor(norm_std).reshape(3, 1, 1) - self._transform = T.Compose( - [ - T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), - T.ToTensor(), - ] - ) assert downsample_ratio < 1 reduction_factor = 1 / downsample_ratio assert reduction_factor == 2.0 @@ -441,15 +448,25 @@ class DynamicResolutionImageTiler: patch_size: tuple[int, int] def apply_params(self, params: DynamicResolutionParams) -> list[torch.Tensor]: - resized_img = params.media.resize( - ( - params.patch_size[0] * self._patch_size, - params.patch_size[1] * self._patch_size, - ) + target_size = ( + params.patch_size[1] * self._patch_size, + params.patch_size[0] * self._patch_size, ) - processed_images = [resized_img] - - return [self._transform(img) for img in processed_images] + image = np.asarray( + params.media.convert("RGB") if params.media.mode != "RGB" else params.media, + dtype=np.uint8, + ) + resized_img = ( + torch.nn.functional.interpolate( + torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2), + size=target_size, + mode="bicubic", + align_corners=False, + antialias=True, + ) + / 255.0 + ) + return list(resized_img) def process_media( self, @@ -803,6 +820,7 @@ class BaseNanoNemotronVLProcessor(ABC): image_repl = self.get_image_repl(feature_size, num_patches) parts[i] = parts[i].replace("", image_repl.full) text = ["".join(parts)] + return text, image_inputs def _make_batch_input(self, input_item: Any | list[Any] | None = None): @@ -922,14 +940,14 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): frames_indices_lst = [ metadata["frames_indices"] for metadata in video_metadata_lst ] - + video_num_patches = torch.tensor( + [len(item) for item in pixel_values_lst_video] + ) video_inputs = { "pixel_values_flat_video": input_conditioner( torch.cat(pixel_values_lst_video), self.norm_mean, self.norm_std ), - "video_num_patches": torch.tensor( - [len(item) for item in pixel_values_lst_video] - ), + "video_num_patches": video_num_patches, "frames_indices": frames_indices_lst, "frame_duration_ms": torch.tensor(frame_duration_ms_lst), } @@ -985,6 +1003,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): video_repl.full, skip_special_tokens=False ) text = [t.replace("