Support temporal compression for Nemotron-3-VL videos (#36808)
Signed-off-by: Collin McCarthy <cmccarthy@nvidia.com>
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
# --------------------------------------------------------
|
||||
|
||||
import copy
|
||||
import math
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
@@ -77,6 +78,7 @@ from vllm.renderers import TokenizeParams
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
from vllm.transformers_utils.configs.radio import RadioConfig
|
||||
from vllm.transformers_utils.processors.internvl import get_internvl_target_ratios
|
||||
from vllm.transformers_utils.processors.nano_nemotron_vl import (
|
||||
AUDIO_CONTEXT,
|
||||
IMG_CONTEXT,
|
||||
@@ -85,7 +87,7 @@ from vllm.transformers_utils.processors.nano_nemotron_vl import (
|
||||
BaseNanoNemotronVLProcessor,
|
||||
DynamicResolutionImageTiler,
|
||||
NanoNemotronVLProcessor,
|
||||
get_internvl_target_ratios,
|
||||
get_video_target_size_and_feature_size,
|
||||
)
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
@@ -295,10 +297,13 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
|
||||
max_videos = mm_counts.get("video", 0)
|
||||
|
||||
processor = self.get_hf_processor() # we get the CustomProcessor here
|
||||
T = processor.video_temporal_patch_size
|
||||
|
||||
max_image_tokens = self.get_max_image_tokens() * max_images
|
||||
max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token
|
||||
max_frames_per_video = max_total_frames // max(max_videos, 1)
|
||||
tokens_per_tubelet = processor.num_video_token
|
||||
max_total_tubelets = (seq_len - max_image_tokens) // tokens_per_tubelet
|
||||
max_tubelets_per_video = max_total_tubelets // max(max_videos, 1)
|
||||
max_frames_per_video = max_tubelets_per_video * T
|
||||
return max(max_frames_per_video, 1)
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor:
|
||||
@@ -589,28 +594,49 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
video_num_patches = []
|
||||
|
||||
def get_video_replacement_internvl(item_idx: int):
|
||||
feature_size = hf_processor.num_image_token
|
||||
video, metadata = mm_items["video"][item_idx]
|
||||
patch_size = hf_processor.config.patch_size
|
||||
downsample_ratio = hf_processor.config.downsample_ratio
|
||||
target_patches = hf_processor.video_target_num_patches
|
||||
|
||||
if target_patches is not None and video is not None and video.shape[0] > 0:
|
||||
orig_h, orig_w = video.shape[1], video.shape[2]
|
||||
_, _, feature_size = get_video_target_size_and_feature_size(
|
||||
orig_w=orig_w,
|
||||
orig_h=orig_h,
|
||||
target_patches=target_patches,
|
||||
maintain_aspect_ratio=hf_processor.video_maintain_aspect_ratio,
|
||||
patch_size=patch_size,
|
||||
downsample_ratio=downsample_ratio,
|
||||
)
|
||||
else:
|
||||
feature_size = hf_processor.num_image_token
|
||||
num_patches = video_num_patches[item_idx]
|
||||
if num_patches is not None:
|
||||
assert isinstance(num_patches, int)
|
||||
|
||||
T = hf_processor.video_temporal_patch_size
|
||||
if T > 1 and num_patches is not None:
|
||||
num_tubelets = math.ceil(num_patches / T)
|
||||
else:
|
||||
num_tubelets = num_patches
|
||||
|
||||
video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
|
||||
if video_pruning_rate is not None and video_pruning_rate > 0.0:
|
||||
# Start of EVS-specific code
|
||||
num_tokens = compute_retained_tokens_count(
|
||||
tokens_per_frame=feature_size,
|
||||
num_frames=num_patches,
|
||||
num_frames=num_tubelets,
|
||||
q=video_pruning_rate,
|
||||
)
|
||||
# Here we just need placeholders that won't actually be replaced -
|
||||
# we just need to make sure the total number of tokens is correct
|
||||
# assign all tokens to the first frame
|
||||
tokens_per_frame = [num_tokens] + [0] * (num_patches - 1)
|
||||
tokens_per_frame = [num_tokens] + [0] * (num_tubelets - 1)
|
||||
|
||||
# End of EVS-specific code
|
||||
else:
|
||||
tokens_per_frame = [feature_size] * num_patches
|
||||
tokens_per_frame = [feature_size] * num_tubelets
|
||||
|
||||
frame_duration_ms = int(1000 / metadata["fps"])
|
||||
return hf_processor.get_video_repl(
|
||||
@@ -621,6 +647,7 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
img_start_token_ids=hf_processor._img_start_token_ids,
|
||||
img_end_token_ids=hf_processor._img_end_token_ids,
|
||||
img_context_token_ids=hf_processor._img_context_token_ids,
|
||||
video_temporal_patch_size=T,
|
||||
)
|
||||
|
||||
if self.info.supports_video:
|
||||
@@ -745,15 +772,39 @@ class NanoNemotronVLDummyInputsBuilder(
|
||||
if self.info.supports_video:
|
||||
config = self.info.get_hf_config()
|
||||
image_size: int = config.force_image_size
|
||||
processor = self.info.get_hf_processor()
|
||||
|
||||
# When video_target_num_patches is set the per-frame pixel
|
||||
# resolution can exceed image_size. Use the actual target
|
||||
# dimensions so that profiling sees the correct upper bound.
|
||||
if processor.video_target_num_patches is not None:
|
||||
target_w, target_h, _ = get_video_target_size_and_feature_size(
|
||||
orig_w=image_size,
|
||||
orig_h=image_size,
|
||||
target_patches=processor.video_target_num_patches,
|
||||
maintain_aspect_ratio=processor.video_maintain_aspect_ratio,
|
||||
patch_size=config.patch_size,
|
||||
downsample_ratio=config.downsample_ratio,
|
||||
)
|
||||
video_width, video_height = target_w, target_h
|
||||
else:
|
||||
video_width, video_height = image_size, image_size
|
||||
|
||||
target_num_frames = self.info.get_num_frames_with_most_features(
|
||||
seq_len, mm_counts
|
||||
)
|
||||
mm_config = self.info.ctx.get_mm_config()
|
||||
if num_frames := mm_config.media_io_kwargs.get("video", {}).get(
|
||||
"num_frames"
|
||||
):
|
||||
assert num_frames > 0
|
||||
target_num_frames = num_frames
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
video_overrides = mm_options.get("video")
|
||||
dummy_video = {
|
||||
"video": self._get_dummy_videos(
|
||||
width=image_size,
|
||||
height=image_size,
|
||||
width=video_width,
|
||||
height=video_height,
|
||||
num_frames=target_num_frames,
|
||||
num_videos=num_videos,
|
||||
overrides=video_overrides,
|
||||
@@ -790,6 +841,9 @@ class NanoNemotronVLDummyInputsBuilder(
|
||||
class NemotronH_Nano_VL_V2(
|
||||
nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning
|
||||
):
|
||||
requires_sequential_video_encoding = True
|
||||
"""Temporarily needed for dynamic res video w/ conv3d, doesn't support bs>1 yet"""
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("image"):
|
||||
@@ -817,6 +871,11 @@ class NemotronH_Nano_VL_V2(
|
||||
self.image_tag_type = config.image_tag_type
|
||||
self.video_pruning_rate = multimodal_config.video_pruning_rate
|
||||
|
||||
vision_config = getattr(config, "vision_config", config)
|
||||
self.video_temporal_patch_size: int = getattr(
|
||||
vision_config, "video_temporal_patch_size", 1
|
||||
)
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
@@ -838,11 +897,12 @@ class NemotronH_Nano_VL_V2(
|
||||
|
||||
mlp1 = nn.Sequential(
|
||||
RMSNorm(
|
||||
hidden_size=vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
|
||||
hidden_size=vit_hidden_size
|
||||
* int(round(1 / self.downsample_ratio)) ** 2,
|
||||
eps=1e-5,
|
||||
),
|
||||
nn.Linear(
|
||||
vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
|
||||
vit_hidden_size * int(round(1 / self.downsample_ratio)) ** 2,
|
||||
vision_projection_hidden_size,
|
||||
bias=False,
|
||||
),
|
||||
@@ -958,19 +1018,37 @@ class NemotronH_Nano_VL_V2(
|
||||
vit_embeds = self.mlp1(vit_embeds)
|
||||
return vit_embeds
|
||||
|
||||
def extract_feature(self, pixel_values: torch.Tensor):
|
||||
def extract_feature(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
num_frames: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Process images in a micro-batch of at most 128 frames per call
|
||||
# This is done on purpose to ensure peak GPU ram usage of huge batch
|
||||
# (namely for really long videos with EVS ON) won't cause any problems
|
||||
# as we don't support chunked prefill for video media
|
||||
micro_batch_size = 128
|
||||
n = pixel_values.shape[0]
|
||||
# This is done on purpose to ensure peak GPU ram usage of huge batch
|
||||
# (namely for really long videos with EVS ON) won't cause any problems
|
||||
# as we don't support chunked prefill for video media
|
||||
# When num_frames is provided and temporal_patch_size > 1, consecutive
|
||||
# frames are grouped into tubelets — the batch size must be a multiple
|
||||
# of T so chunk boundaries don't split a tubelet.
|
||||
N, _C, H, W = pixel_values.shape
|
||||
|
||||
T = self.video_temporal_patch_size if num_frames is not None else 1
|
||||
micro_batch_size = 128 - (128 % T)
|
||||
patch_size = self.patch_size
|
||||
H_patches = H // patch_size
|
||||
W_patches = W // patch_size
|
||||
|
||||
vit_embeds_list = []
|
||||
for i in range(0, n, micro_batch_size):
|
||||
_, vit_embeds = self.vision_model(pixel_values[i : i + micro_batch_size])
|
||||
for i in range(0, N, micro_batch_size):
|
||||
chunk = pixel_values[i : i + micro_batch_size]
|
||||
if num_frames is not None and T > 1:
|
||||
_, vit_embeds = self.vision_model(chunk, num_frames=chunk.shape[0])
|
||||
else:
|
||||
_, vit_embeds = self.vision_model(chunk)
|
||||
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
|
||||
h = w = int(vit_embeds.shape[1] ** 0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
vit_embeds = vit_embeds.reshape(
|
||||
vit_embeds.shape[0], H_patches, W_patches, -1
|
||||
)
|
||||
vit_embeds = self.pixel_shuffle(
|
||||
vit_embeds, scale_factor=self.downsample_ratio
|
||||
)
|
||||
@@ -1042,16 +1120,21 @@ class NemotronH_Nano_VL_V2(
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""Process video input and create final embeddings with video content
|
||||
and indicator tokens."""
|
||||
# Get video embeddings using the same processing as images
|
||||
video_embeddings = self._process_image_input(video_input)
|
||||
T = self.video_temporal_patch_size
|
||||
|
||||
if T > 1:
|
||||
video_embeddings = self._extract_video_embeddings_temporal(video_input)
|
||||
else:
|
||||
video_embeddings = self._process_image_input(video_input)
|
||||
|
||||
final_video_embeddings: tuple[torch.Tensor, ...] = ()
|
||||
|
||||
image_rows = image_cols = self.config.force_image_size
|
||||
downsample_ratio = self.config.downsample_ratio
|
||||
patch_size = self.config.patch_size
|
||||
rows = int(image_rows * downsample_ratio // patch_size)
|
||||
cols = int(image_cols * downsample_ratio // patch_size)
|
||||
pixel_values = video_input["pixel_values_flat"]
|
||||
frame_h, frame_w = pixel_values.shape[-2], pixel_values.shape[-1]
|
||||
rows = int(frame_h * downsample_ratio // patch_size)
|
||||
cols = int(frame_w * downsample_ratio // patch_size)
|
||||
video_pruning_rate = self.video_pruning_rate
|
||||
video_num_frames = video_input["num_patches"].tolist()
|
||||
video_frames_indices = video_input["frames_indices"].split(video_num_frames)
|
||||
@@ -1062,13 +1145,14 @@ class NemotronH_Nano_VL_V2(
|
||||
num_frames = video_num_frames[i]
|
||||
frames_indices = video_frames_indices[i].tolist()
|
||||
frame_duration_ms = video_input["frame_duration_ms"][i].item()
|
||||
assert single_video_embeddings.shape[0] % num_frames == 0
|
||||
num_tubelets = math.ceil(num_frames / T) if T > 1 else num_frames
|
||||
assert single_video_embeddings.shape[0] % num_tubelets == 0
|
||||
|
||||
if video_pruning_rate is not None and video_pruning_rate > 0.0:
|
||||
# Start of EVS-specific code
|
||||
retention_mask = compute_retention_mask(
|
||||
single_video_embeddings,
|
||||
video_size_thw=(num_frames, rows, cols),
|
||||
video_size_thw=(num_tubelets, rows, cols),
|
||||
spatial_merge_size=1,
|
||||
q=video_pruning_rate,
|
||||
)
|
||||
@@ -1077,14 +1161,14 @@ class NemotronH_Nano_VL_V2(
|
||||
single_video_embeddings = single_video_embeddings[retention_mask]
|
||||
|
||||
# calculate the actual number of retained tokens per frame
|
||||
retention_mask_thw = retention_mask.reshape(num_frames, rows, cols)
|
||||
retention_mask_thw = retention_mask.reshape(num_tubelets, rows, cols)
|
||||
num_tokens_per_frame = (
|
||||
retention_mask_thw.sum(dim=(1, 2)).long().tolist()
|
||||
)
|
||||
# End of EVS-specific code
|
||||
else:
|
||||
feature_size = single_video_embeddings.shape[0] // num_frames
|
||||
num_tokens_per_frame = [feature_size] * num_frames
|
||||
feature_size = single_video_embeddings.shape[0] // num_tubelets
|
||||
num_tokens_per_frame = [feature_size] * num_tubelets
|
||||
|
||||
final_video_embeddings += (
|
||||
self._create_final_video_embeddings(
|
||||
@@ -1092,11 +1176,36 @@ class NemotronH_Nano_VL_V2(
|
||||
num_tokens_per_frame,
|
||||
frames_indices,
|
||||
frame_duration_ms,
|
||||
video_temporal_patch_size=T,
|
||||
),
|
||||
)
|
||||
|
||||
return final_video_embeddings
|
||||
|
||||
def _extract_video_embeddings_temporal(
|
||||
self, video_input: NanoNemotronVLVideoPixelInputs
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""Extract per-video embeddings with temporal compression.
|
||||
|
||||
Each video is processed separately through extract_feature with
|
||||
num_frames, which uses the fixed-resolution temporal path in RADIO
|
||||
(no attention mask, flash attention).
|
||||
"""
|
||||
pixel_values = video_input["pixel_values_flat"]
|
||||
num_frames_per_video = video_input["num_patches"].tolist()
|
||||
hidden_size = self.config.text_config.hidden_size
|
||||
|
||||
results: list[torch.Tensor] = []
|
||||
frame_offset = 0
|
||||
for nf in num_frames_per_video:
|
||||
video_frames = pixel_values[frame_offset : frame_offset + nf]
|
||||
frame_offset += nf
|
||||
|
||||
vit_embeds = self.extract_feature(video_frames, num_frames=nf)
|
||||
results.append(vit_embeds.view(-1, hidden_size))
|
||||
|
||||
return tuple(results)
|
||||
|
||||
def _process_audio_input(
|
||||
self, audio_input: NanoNemotronVLAudioFeatureInputs
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
@@ -1134,6 +1243,7 @@ class NemotronH_Nano_VL_V2(
|
||||
num_tokens_per_frame: list[int],
|
||||
frames_indices: list[int],
|
||||
frame_duration_ms: int,
|
||||
video_temporal_patch_size: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""Create final embeddings that combine video embeddings with
|
||||
text embeddings of indicator tokens.
|
||||
@@ -1161,6 +1271,7 @@ class NemotronH_Nano_VL_V2(
|
||||
img_start_token_ids=self._img_start_token_ids,
|
||||
img_end_token_ids=self._img_end_token_ids,
|
||||
img_context_token_ids=self._img_context_token_ids,
|
||||
video_temporal_patch_size=video_temporal_patch_size,
|
||||
)
|
||||
|
||||
# video_repl.full is a list of token IDs
|
||||
@@ -1207,8 +1318,27 @@ class NemotronH_Nano_VL_V2(
|
||||
else:
|
||||
frames_indices = torch.cat([f.flatten() for f in frames_indices], dim=0)
|
||||
|
||||
frame_duration_ms = frame_duration_ms.flatten()
|
||||
expected_h = expected_w = self.config.force_image_size
|
||||
if torch.is_tensor(frame_duration_ms):
|
||||
frame_duration_ms = frame_duration_ms.flatten()
|
||||
else:
|
||||
frame_duration_ms = torch.cat(
|
||||
[f.flatten() for f in frame_duration_ms], dim=0
|
||||
)
|
||||
|
||||
if (
|
||||
torch.is_tensor(pixel_values_flat_video)
|
||||
and pixel_values_flat_video.ndim == 5
|
||||
):
|
||||
# batched._reduce_data stacked same-shape videos into
|
||||
# [num_videos, nf, 3, H, W]; unstack back to a list so the
|
||||
# same-H,W cat path below handles it uniformly.
|
||||
pixel_values_flat_video = list(pixel_values_flat_video)
|
||||
|
||||
if not torch.is_tensor(pixel_values_flat_video):
|
||||
pixel_values_flat_video = torch.cat(pixel_values_flat_video, dim=0)
|
||||
|
||||
expected_h = pixel_values_flat_video.shape[-2]
|
||||
expected_w = pixel_values_flat_video.shape[-1]
|
||||
num_frames = video_num_patches[0].item()
|
||||
resolve_bindings = {"h": expected_h, "w": expected_w, "f": num_frames}
|
||||
|
||||
@@ -1361,8 +1491,7 @@ class NemotronH_Nano_VL_V2(
|
||||
|
||||
self.language_model.load_weights(llm_weights)
|
||||
self.vision_model.load_weights(vision_weights)
|
||||
if self.sound_encoder is not None:
|
||||
assert len(sound_weights) > 0
|
||||
if self.sound_encoder is not None and len(sound_weights) > 0:
|
||||
self.sound_encoder.load_weights(sound_weights)
|
||||
|
||||
def get_vit_model_from_radio_config(self, hf_config):
|
||||
@@ -1375,12 +1504,23 @@ class NemotronH_Nano_VL_V2(
|
||||
image_size = preferred_resolution[0] if preferred_resolution else 224
|
||||
patch_size = getattr(hf_config_vision, "patch_size", 16)
|
||||
|
||||
# video_temporal_patch_size and separate_video_embedder are
|
||||
# top-level vision_config attributes, not inside args.
|
||||
video_temporal_patch_size = getattr(
|
||||
hf_config_vision, "video_temporal_patch_size", 1
|
||||
)
|
||||
separate_video_embedder = getattr(
|
||||
hf_config_vision, "separate_video_embedder", True
|
||||
)
|
||||
|
||||
radio_config = RadioConfig(
|
||||
model_name=model_name,
|
||||
image_size=image_size,
|
||||
patch_size=patch_size,
|
||||
norm_mean=hf_config.norm_mean,
|
||||
norm_std=hf_config.norm_std,
|
||||
video_temporal_patch_size=video_temporal_patch_size,
|
||||
separate_video_embedder=separate_video_embedder,
|
||||
**hf_config_vision.args,
|
||||
)
|
||||
|
||||
|
||||
@@ -123,6 +123,8 @@ class ViTPatchGenerator(nn.Module):
|
||||
register_multiple: int | None = None,
|
||||
num_registers: int | None = None,
|
||||
patch_bias: bool = False,
|
||||
temporal_patch_size: int = 1,
|
||||
separate_video_embedder: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
@@ -148,6 +150,7 @@ class ViTPatchGenerator(nn.Module):
|
||||
self.patch_size = patch_size
|
||||
self.abs_pos = abs_pos
|
||||
self.embed_dim = embed_dim
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
|
||||
self.num_rows = max_input_dims[0] // patch_size
|
||||
self.num_cols = max_input_dims[1] // patch_size
|
||||
@@ -160,6 +163,21 @@ class ViTPatchGenerator(nn.Module):
|
||||
patch_size, embed_dim, bias=patch_bias, **factory
|
||||
)
|
||||
|
||||
if temporal_patch_size > 1:
|
||||
if not separate_video_embedder:
|
||||
raise NotImplementedError(
|
||||
"Only separate_video_embedder=True is supported for"
|
||||
" temporal compression (temporal_patch_size > 1)"
|
||||
)
|
||||
self.video_embedder = ViTPatchLinear(
|
||||
patch_size,
|
||||
embed_dim,
|
||||
bias=patch_bias,
|
||||
temporal_patch_size=temporal_patch_size,
|
||||
**factory,
|
||||
)
|
||||
self._video_embedder_loaded = False
|
||||
|
||||
if abs_pos:
|
||||
scale = embed_dim**-0.5
|
||||
self.pos_embed = nn.Parameter(
|
||||
@@ -196,6 +214,60 @@ class ViTPatchGenerator(nn.Module):
|
||||
return patches, pos_enc
|
||||
return patches
|
||||
|
||||
    def forward_video(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Process video frames with temporal compression.

        Groups T consecutive frames into tubelets before embedding, where
        T = ``self.temporal_patch_size``. Requires the dedicated
        ``video_embedder`` weights to have been loaded.

        Args:
            x: [num_frames, 3, H, W] tensor of video frames

        Returns:
            Embedded patches with temporal compression applied; when
            ``self.return_pos_enc`` is set, a ``(patches, pos_enc)`` tuple.
        """
        # Fail fast if the checkpoint never provided video_embedder weights;
        # running with an uninitialized embedder would silently produce garbage.
        if not self._video_embedder_loaded:
            raise ValueError(
                "Temporal compression (video_temporal_patch_size > 1) requires "
                "video_embedder weights, but they were never loaded. "
                "Ensure the checkpoint was trained with temporal compression."
            )
        T = self.temporal_patch_size
        # Spatial (H, W) size, needed later to build the positional encoding.
        input_size = x.shape[2:]

        patches = self.im_to_patches(x)  # [N, num_patches, 3*P*P]
        num_frames, num_spatial, feat_dim = patches.shape

        # Pad to a multiple of T by repeating the last frame so that
        # all tubelets have exactly T frames.
        num_pad_frames = (-num_frames) % T
        if num_pad_frames > 0:
            # expand() creates a broadcast view; torch.cat materializes it.
            last_frame_dup = patches[-1:].expand(num_pad_frames, -1, -1)
            patches = torch.cat([patches, last_frame_dup], dim=0)

        # Group T frames per tubelet: for each spatial position, concatenate
        # features across T consecutive frames; order follows Megatron training
        num_frames_padded = patches.shape[0]
        num_tublets = num_frames_padded // T
        patches = rearrange(
            patches,
            "(tubelets frames) spatial feat -> tubelets spatial (frames feat)",
            tubelets=num_tublets,
            frames=T,
            spatial=num_spatial,
            feat=feat_dim,
        )

        # Project concatenated tubelet pixels (3*T*P*P) to the embed dim.
        patches = self.video_embedder(patches)

        # Positional encoding is purely spatial here; it is computed from the
        # original frame (H, W), shared across tubelets.
        patches, pos_enc = self.apply_pos_enc(patches, input_size=input_size)

        patches = self.cls_token(patches)

        patches = self.patch_normalizer(patches)
        if self.return_pos_enc:
            return patches, pos_enc
        return patches
|
||||
|
||||
def apply_pos_enc_dynamic(
|
||||
self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
@@ -381,66 +453,21 @@ class ViTPatchGenerator(nn.Module):
|
||||
return pos_embed
|
||||
|
||||
if self.cpe_mode:
|
||||
if self.training:
|
||||
min_scale = math.sqrt(0.1)
|
||||
scale = (
|
||||
torch.rand(batch_size, 1, 1, device=pos_embed.device)
|
||||
* (1 - min_scale)
|
||||
+ min_scale
|
||||
)
|
||||
aspect_min = math.log(3 / 4)
|
||||
aspect_max = -aspect_min
|
||||
aspect = torch.exp(
|
||||
torch.rand(batch_size, 1, 1, device=pos_embed.device)
|
||||
* (aspect_max - aspect_min)
|
||||
+ aspect_min
|
||||
)
|
||||
max_dim = max(input_dims)
|
||||
pos_embed = F.interpolate(
|
||||
pos_embed.float(),
|
||||
size=(max_dim, max_dim),
|
||||
align_corners=False,
|
||||
mode="bilinear",
|
||||
).to(pos_embed.dtype)
|
||||
|
||||
scale_x = scale * aspect
|
||||
scale_y = scale * (1 / aspect)
|
||||
scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)
|
||||
|
||||
pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (
|
||||
1 - scale_xy
|
||||
)
|
||||
|
||||
lin_x = torch.linspace(
|
||||
0, 1, steps=input_dims[1], device=pos_embed.device
|
||||
)[None, None].expand(batch_size, input_dims[0], -1)
|
||||
lin_y = torch.linspace(
|
||||
0, 1, steps=input_dims[0], device=pos_embed.device
|
||||
)[None, :, None].expand(batch_size, -1, input_dims[1])
|
||||
|
||||
lin_xy = torch.stack([lin_x, lin_y], dim=-1)
|
||||
|
||||
grid_xy = lin_xy * scale_xy + pos_xy
|
||||
|
||||
# Convert to [-1, 1] range
|
||||
grid_xy.mul_(2).sub_(1)
|
||||
|
||||
pos_embed = F.grid_sample(
|
||||
pos_embed.float().expand(batch_size, -1, -1, -1),
|
||||
grid=grid_xy,
|
||||
mode="bilinear",
|
||||
padding_mode="zeros",
|
||||
align_corners=True,
|
||||
).to(pos_embed.dtype)
|
||||
else:
|
||||
max_dim = max(input_dims)
|
||||
pos_embed = F.interpolate(
|
||||
pos_embed.float(),
|
||||
size=(max_dim, max_dim),
|
||||
align_corners=True,
|
||||
mode="bilinear",
|
||||
).to(pos_embed.dtype)
|
||||
|
||||
pos_embed = window_select(pos_embed)
|
||||
pos_embed = window_select(pos_embed)
|
||||
else:
|
||||
pos_embed = window_select(pos_embed)
|
||||
|
||||
if pos_embed.shape[-2:] != input_dims:
|
||||
pos_embed = F.interpolate(
|
||||
pos_embed.float(), size=input_dims, align_corners=True, mode="bilinear"
|
||||
pos_embed.float(), size=input_dims, align_corners=False, mode="bilinear"
|
||||
).to(pos_embed.dtype)
|
||||
|
||||
pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
|
||||
@@ -473,9 +500,19 @@ class Im2Patches(nn.Module):
|
||||
|
||||
|
||||
class ViTPatchLinear(nn.Linear):
    """Linear patch embedder: flattened patch pixels -> ``embed_dim``.

    The input feature size is ``3 * temporal_patch_size * patch_size**2``,
    i.e. the RGB pixels of ``temporal_patch_size`` stacked frames of a single
    spatial patch (tubelet). With the default ``temporal_patch_size=1`` this
    is the ordinary per-frame patch embedder.
    """

    def __init__(
        self,
        patch_size: int,
        embed_dim: int,
        bias: bool = False,
        temporal_patch_size: int = 1,
        **factory,
    ):
        in_features = 3 * temporal_patch_size * patch_size**2
        super().__init__(in_features, embed_dim, bias=bias, **factory)
        # Kept as attributes so weight loaders / callers can inspect the
        # expected patch geometry.
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
@@ -560,6 +597,7 @@ class RadioInternVisionModel(nn.Module):
|
||||
max_img_size = int(
|
||||
round(config.cpe_max_size / config.patch_size) * config.patch_size
|
||||
)
|
||||
self.temporal_patch_size = config.video_temporal_patch_size
|
||||
unique_teachers = set(t["name"] for t in config.teachers)
|
||||
self.patch_generator = ViTPatchGenerator(
|
||||
config.patch_size,
|
||||
@@ -569,6 +607,8 @@ class RadioInternVisionModel(nn.Module):
|
||||
cls_token=True,
|
||||
num_cls_tokens=len(unique_teachers) if config.cls_token_per_teacher else 1,
|
||||
register_multiple=config.register_multiple,
|
||||
temporal_patch_size=self.temporal_patch_size,
|
||||
separate_video_embedder=config.separate_video_embedder,
|
||||
)
|
||||
|
||||
self.encoder = RadioVisionEncoder(
|
||||
@@ -593,33 +633,68 @@ class RadioInternVisionModel(nn.Module):
|
||||
def inter_image_mask_metadata(
|
||||
self, imgs_sizes: list[tuple[int, int]], device: torch.device
|
||||
) -> MaskMetadata:
|
||||
"""Build mask metadata from image pixel sizes. Adds num_skip to each
|
||||
sequence length (cls/register tokens) to match patch generator output."""
|
||||
patch_size = self.patch_generator.patch_size
|
||||
num_skip = self.patch_generator.num_skip
|
||||
|
||||
seq_lens = calc_seq_lens(imgs_sizes, patch_size)
|
||||
adjusted = [s + num_skip for s in seq_lens]
|
||||
return self._inter_image_mask_metadata_from_seq_lens(adjusted, device=device)
|
||||
|
||||
def _inter_image_mask_metadata_from_seq_lens(
|
||||
self, seq_lens: list[int], device: torch.device
|
||||
) -> MaskMetadata:
|
||||
"""Build mask metadata from actual sequence lengths (already including
|
||||
cls/register tokens, i.e. patch_count + num_skip per item).
|
||||
Use inter_image_mask_metadata() when you only have imgs_sizes."""
|
||||
assert len(seq_lens) > 0
|
||||
cu_seqlens = torch.tensor(
|
||||
list(accumulate(adjusted, initial=0)), dtype=torch.int32, device=device
|
||||
list(accumulate(seq_lens, initial=0)), dtype=torch.int32, device=device
|
||||
)
|
||||
# Keep max_seqlen on CPU to avoid .item() sync
|
||||
# See: https://github.com/vllm-project/vllm/blob/20b6b01/vllm/v1/attention/ops/vit_attn_wrappers.py#L48
|
||||
max_seqlen = torch.tensor(max(adjusted), dtype=torch.int32)
|
||||
max_seqlen = torch.tensor(max(seq_lens), dtype=torch.int32)
|
||||
return MaskMetadata(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
imgs_sizes: list[tuple[int, int]] | None = None,
|
||||
num_frames: int | None = None,
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
|
||||
T = self.temporal_patch_size
|
||||
|
||||
# Build packed-sequence metadata for MMEncoderAttention when needed.
|
||||
mask_meta = None
|
||||
if imgs_sizes is not None:
|
||||
assert len(imgs_sizes) > 0
|
||||
# Dynamic resolution: process each image as an independent sequence.
|
||||
mask_meta = self.inter_image_mask_metadata(
|
||||
imgs_sizes, device=hidden_states.device
|
||||
packed_batch_size = None # Original batch size before packing
|
||||
|
||||
if num_frames is not None and T > 1:
|
||||
# Conv3d video: all tubelets have the same sequence length.
|
||||
# Pack [num_tubelets, seq_per_tubelet, hidden] → [1, total, hidden]
|
||||
hidden_states = self.patch_generator.forward_video(x)
|
||||
packed_batch_size, seq_per_tubelet, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.reshape(1, -1, hidden_dim)
|
||||
mask_meta = self._inter_image_mask_metadata_from_seq_lens(
|
||||
[seq_per_tubelet] * packed_batch_size, device=hidden_states.device
|
||||
)
|
||||
else:
|
||||
# Images for any model, or video for non-conv3d model
|
||||
hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
|
||||
if imgs_sizes is not None and len(imgs_sizes) > 1:
|
||||
# Dynamic resolution w/ > 1 image, create attn mask
|
||||
mask_meta = self.inter_image_mask_metadata(
|
||||
imgs_sizes, device=hidden_states.device
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(inputs_embeds=hidden_states, mask_meta=mask_meta)
|
||||
|
||||
# Unpack back to original batch shape if we packed for video
|
||||
if packed_batch_size is not None:
|
||||
encoder_outputs = encoder_outputs.reshape(
|
||||
packed_batch_size, seq_per_tubelet, -1
|
||||
)
|
||||
|
||||
return encoder_outputs
|
||||
|
||||
|
||||
@@ -663,8 +738,13 @@ class RadioModel(nn.Module):
|
||||
pixel_embeds: torch.Tensor | None = None,
|
||||
*,
|
||||
imgs_sizes: list[tuple[int, int]] | None = None,
|
||||
num_frames: int | None = None,
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
y = self.model(pixel_values, imgs_sizes=imgs_sizes)
|
||||
y = self.model(
|
||||
pixel_values,
|
||||
imgs_sizes=imgs_sizes,
|
||||
num_frames=num_frames,
|
||||
)
|
||||
return self._extract_final(y, imgs_sizes=imgs_sizes)
|
||||
|
||||
def load_weights(self, weights) -> set[str]:
|
||||
@@ -714,6 +794,9 @@ class RadioModel(nn.Module):
|
||||
weight_loader(param, weight)
|
||||
loaded_params.add(vllm_key)
|
||||
|
||||
if "model.patch_generator.video_embedder.weight" in loaded_params:
|
||||
self.model.patch_generator._video_embedder_loaded = True
|
||||
|
||||
return loaded_params
|
||||
|
||||
def _extract_final(
|
||||
|
||||
@@ -47,6 +47,14 @@ class RadioConfig(PretrainedConfig):
|
||||
teachers: A list of teacher model configurations. Each teacher configuration is
|
||||
a dict with keys like "name" and some may have "use_summary".
|
||||
cls_token_per_teacher: Whether to use a separate CLS token for each teacher.
|
||||
video_temporal_patch_size: Number of consecutive video frames grouped into
|
||||
a single tubelet for temporal compression. Default 1 (no compression).
|
||||
When > 1, a dedicated video_embedder (3*T*P*P -> hidden) is created
|
||||
alongside the image embedder (3*P*P -> hidden).
|
||||
separate_video_embedder: When True and video_temporal_patch_size > 1, use a
|
||||
dedicated video patch embedder (3*T*P*P -> hidden) separate from the
|
||||
image embedder (3*P*P -> hidden). When False, a single embedder with
|
||||
input size 3*T*P*P is used for both (images are duplicated T times).
|
||||
"""
|
||||
|
||||
model_type = "radio"
|
||||
@@ -68,6 +76,8 @@ class RadioConfig(PretrainedConfig):
|
||||
register_multiple: int | None = None,
|
||||
teachers: list[dict[str, Any]] | None = None,
|
||||
cls_token_per_teacher: bool = False,
|
||||
video_temporal_patch_size: int = 1,
|
||||
separate_video_embedder: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
self.model_name = model_name
|
||||
@@ -95,4 +105,6 @@ class RadioConfig(PretrainedConfig):
|
||||
self.register_multiple = register_multiple
|
||||
self.teachers = teachers if teachers is not None else []
|
||||
self.cls_token_per_teacher = cls_token_per_teacher
|
||||
self.video_temporal_patch_size = video_temporal_patch_size
|
||||
self.separate_video_embedder = separate_video_embedder
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -11,6 +11,7 @@ import math
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import einops
|
||||
@@ -43,6 +44,12 @@ AUDIO_CONTEXT = "<so_embedding>"
|
||||
# MAX_FRAMES = 16
|
||||
DEFAULT_NUM_TILES = 12
|
||||
|
||||
# Configure PIL to handle large images without warnings
|
||||
# This prevents DecompressionBombWarning for legitimate large images
|
||||
Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
|
||||
# Alternative: Set a specific higher limit
|
||||
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
|
||||
|
||||
|
||||
def calculate_timestamps(
|
||||
indices: list[int] | torch.Tensor,
|
||||
@@ -138,19 +145,110 @@ def image_to_pixel_values(
|
||||
return pixel_values
|
||||
|
||||
|
||||
def _compute_aspect_preserving_size(
|
||||
orig_w: int,
|
||||
orig_h: int,
|
||||
target_num_patches: int,
|
||||
patch_size: int,
|
||||
downsample_ratio: float,
|
||||
) -> tuple[int, int]:
|
||||
"""Compute target pixel dimensions that preserve aspect ratio.
|
||||
|
||||
Mirrors Megatron-LM image_processing.py video frame resizing:
|
||||
target area in patch-grid space is *target_num_patches*, distributed
|
||||
according to the source aspect ratio, then snapped to a multiple of
|
||||
the required divisor (2 for pixel-shuffle).
|
||||
"""
|
||||
aspect_wh = orig_w / max(orig_h, 1)
|
||||
ph = round(math.sqrt(target_num_patches / aspect_wh))
|
||||
pw = round(math.sqrt(target_num_patches * aspect_wh))
|
||||
ph = max(ph, 1)
|
||||
pw = max(pw, 1)
|
||||
|
||||
reduction_factor = int(round(1 / downsample_ratio))
|
||||
required_divisor = reduction_factor # 2 for pixel-shuffle
|
||||
if required_divisor > 1:
|
||||
rem_h = ph % required_divisor
|
||||
rem_w = pw % required_divisor
|
||||
ph_up = ph + (required_divisor - rem_h if rem_h else 0)
|
||||
ph_down = ph - rem_h
|
||||
pw_up = pw + (required_divisor - rem_w if rem_w else 0)
|
||||
pw_down = pw - rem_w
|
||||
if ph_up * pw_up <= target_num_patches:
|
||||
ph, pw = ph_up, pw_up
|
||||
else:
|
||||
ph = max(required_divisor, ph_down)
|
||||
pw = max(required_divisor, pw_down)
|
||||
|
||||
return pw * patch_size, ph * patch_size # (width, height) in pixels
|
||||
|
||||
|
||||
def get_video_target_size_and_feature_size(
|
||||
orig_w: int,
|
||||
orig_h: int,
|
||||
target_patches: int,
|
||||
maintain_aspect_ratio: bool,
|
||||
patch_size: int,
|
||||
downsample_ratio: float,
|
||||
) -> tuple[int, int, int]:
|
||||
"""Compute target (width, height) and feature_size for video resize and token count.
|
||||
|
||||
Used by video_to_pixel_values (resize) and get_video_replacement_internvl
|
||||
(seq length calc) so both use the same dimensions.
|
||||
"""
|
||||
if maintain_aspect_ratio:
|
||||
target_w, target_h = _compute_aspect_preserving_size(
|
||||
orig_w=orig_w,
|
||||
orig_h=orig_h,
|
||||
target_num_patches=target_patches,
|
||||
patch_size=patch_size,
|
||||
downsample_ratio=downsample_ratio,
|
||||
)
|
||||
else:
|
||||
reduction_factor = int(round(1 / downsample_ratio))
|
||||
side = int(math.sqrt(target_patches))
|
||||
side = max(reduction_factor, (side // reduction_factor) * reduction_factor)
|
||||
target_w = side * patch_size
|
||||
target_h = side * patch_size
|
||||
|
||||
feature_size = int((target_h // patch_size) * downsample_ratio) * int(
|
||||
(target_w // patch_size) * downsample_ratio
|
||||
)
|
||||
return target_w, target_h, feature_size
|
||||
|
||||
|
||||
def video_to_pixel_values(
|
||||
video: npt.NDArray,
|
||||
*,
|
||||
input_size: int,
|
||||
max_num_tiles: int = 1,
|
||||
use_thumbnail: bool,
|
||||
video_target_num_patches: int | None = None,
|
||||
video_maintain_aspect_ratio: bool = False,
|
||||
patch_size: int = 16,
|
||||
downsample_ratio: float = 0.5,
|
||||
) -> torch.Tensor:
|
||||
assert max_num_tiles == 1, "Video modality always uses one tile"
|
||||
|
||||
# (num_frames, H, W, C) -> (num_frames, C, H, W)
|
||||
video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2)
|
||||
|
||||
if video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size:
|
||||
if video_target_num_patches is not None:
|
||||
# Resize to target patch count (aspect-preserving or square).
|
||||
orig_h, orig_w = video_tensor.shape[2], video_tensor.shape[3]
|
||||
target_w, target_h, _ = get_video_target_size_and_feature_size(
|
||||
orig_w=orig_w,
|
||||
orig_h=orig_h,
|
||||
target_patches=video_target_num_patches,
|
||||
maintain_aspect_ratio=video_maintain_aspect_ratio,
|
||||
patch_size=patch_size,
|
||||
downsample_ratio=downsample_ratio,
|
||||
)
|
||||
if video_tensor.shape[2] != target_h or video_tensor.shape[3] != target_w:
|
||||
video_tensor = torch.nn.functional.interpolate(
|
||||
video_tensor,
|
||||
size=(target_h, target_w),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
antialias=True,
|
||||
)
|
||||
elif video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size:
|
||||
video_tensor = torch.nn.functional.interpolate(
|
||||
video_tensor,
|
||||
size=(input_size, input_size),
|
||||
@@ -645,9 +743,9 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
"which should be a single string"
|
||||
)
|
||||
parts = [x for x in re.split(r"(<image>)", text[0]) if x]
|
||||
assert parts.count("<image>") == len(pixel_values_lst), (
|
||||
"the number of <image> tokens in the text should be the "
|
||||
"same as the number of images"
|
||||
assert parts.count("<image>") == len(num_tokens_per_image), (
|
||||
f"Expected {len(num_tokens_per_image)} <image> tokens in text "
|
||||
f"but found {parts.count('<image>')}"
|
||||
)
|
||||
|
||||
for i, (feature_size, num_patches) in enumerate(
|
||||
@@ -706,6 +804,33 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
self.video_token = video_token
|
||||
self.video_pruning_rate = video_pruning_rate
|
||||
|
||||
# Video params live exclusively in vision_config
|
||||
vision_config = getattr(config, "vision_config", config)
|
||||
self.video_temporal_patch_size: int = getattr(
|
||||
vision_config, "video_temporal_patch_size", 1
|
||||
)
|
||||
self.video_maintain_aspect_ratio: bool = getattr(
|
||||
vision_config, "video_maintain_aspect_ratio", False
|
||||
)
|
||||
|
||||
# Resolve video frame target size: exactly one of video_target_num_patches
|
||||
# or video_target_img_size may be set (mirrors Megatron's
|
||||
# DynamicResolutionImageTilingStrategy validation).
|
||||
target_num_patches = getattr(vision_config, "video_target_num_patches", None)
|
||||
target_img_size = getattr(vision_config, "video_target_img_size", None)
|
||||
if target_num_patches is not None and target_img_size is not None:
|
||||
raise ValueError(
|
||||
"Exactly one of video_target_num_patches or "
|
||||
"video_target_img_size must be set, got both"
|
||||
)
|
||||
if target_num_patches is not None:
|
||||
self.video_target_num_patches: int | None = target_num_patches
|
||||
elif target_img_size is not None:
|
||||
base_patches = math.ceil(target_img_size / config.patch_size)
|
||||
self.video_target_num_patches = base_patches * base_patches
|
||||
else:
|
||||
self.video_target_num_patches = None
|
||||
|
||||
self.audio_extractor: ParakeetExtractor | None = None
|
||||
raw_sound_config = getattr(config, "sound_config", None)
|
||||
if raw_sound_config is not None:
|
||||
@@ -721,6 +846,27 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
IMG_CONTEXT, add_special_tokens=False
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def num_video_token(self) -> int:
|
||||
"""Token count per video frame, accounting for video_target_num_patches.
|
||||
|
||||
When video_target_num_patches is set the per-frame feature count
|
||||
differs from the image-based num_image_token. We use a square
|
||||
dummy (1:1) to compute the feature_size because the dummy video is
|
||||
square and the user confirmed that is acceptable.
|
||||
"""
|
||||
if self.video_target_num_patches is not None:
|
||||
_, _, feature_size = get_video_target_size_and_feature_size(
|
||||
orig_w=self.image_size,
|
||||
orig_h=self.image_size,
|
||||
target_patches=self.video_target_num_patches,
|
||||
maintain_aspect_ratio=self.video_maintain_aspect_ratio,
|
||||
patch_size=self.config.patch_size,
|
||||
downsample_ratio=self.config.downsample_ratio,
|
||||
)
|
||||
return feature_size
|
||||
return self.num_image_token
|
||||
|
||||
@property
|
||||
def supports_video(self) -> bool:
|
||||
return self.video_token_id is not None
|
||||
@@ -738,14 +884,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
def _videos_to_pixel_values_lst(
|
||||
self,
|
||||
videos: list[npt.NDArray],
|
||||
max_num_tiles: int,
|
||||
) -> list[torch.Tensor]:
|
||||
return [
|
||||
video_to_pixel_values(
|
||||
video,
|
||||
input_size=self.image_size,
|
||||
max_num_tiles=max_num_tiles,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
video_target_num_patches=self.video_target_num_patches,
|
||||
video_maintain_aspect_ratio=self.video_maintain_aspect_ratio,
|
||||
patch_size=self.config.patch_size,
|
||||
downsample_ratio=self.config.downsample_ratio,
|
||||
)
|
||||
for video in videos
|
||||
]
|
||||
@@ -754,7 +901,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
self,
|
||||
text: list[str],
|
||||
videos: list[tuple[npt.NDArray, dict[str, Any]]],
|
||||
max_num_tiles: int,
|
||||
) -> tuple[list[str], dict[str, Any]]:
|
||||
if len(videos) == 0 or not self.supports_video:
|
||||
return text, {}
|
||||
@@ -763,7 +909,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
video_metadata_lst = [v[1] for v in videos]
|
||||
pixel_values_lst_video = self._videos_to_pixel_values_lst(
|
||||
videos_lst,
|
||||
max_num_tiles=max_num_tiles,
|
||||
)
|
||||
|
||||
# We use frame duration in milliseconds (as integer) to ensure
|
||||
@@ -788,12 +933,10 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
"frame_duration_ms": torch.tensor(frame_duration_ms_lst),
|
||||
}
|
||||
|
||||
image_size: int = self.config.force_image_size
|
||||
patch_size: int = self.config.patch_size
|
||||
downsample_ratio = self.config.downsample_ratio
|
||||
tokens_in_single_frame = int(
|
||||
(image_size * image_size // patch_size**2) * (downsample_ratio**2)
|
||||
)
|
||||
|
||||
T = self.video_temporal_patch_size
|
||||
|
||||
for pixel_values, video_metadata, frames_indices, frame_duration_ms in zip(
|
||||
pixel_values_lst_video,
|
||||
@@ -802,23 +945,28 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
frame_duration_ms_lst,
|
||||
):
|
||||
num_frames = pixel_values.shape[0]
|
||||
frame_h, frame_w = pixel_values.shape[-2], pixel_values.shape[-1]
|
||||
tokens_in_single_frame = int(
|
||||
(frame_h * frame_w // patch_size**2) * (downsample_ratio**2)
|
||||
)
|
||||
num_tubelets = math.ceil(num_frames / T) if T > 1 else num_frames
|
||||
|
||||
if self.video_pruning_rate is not None and self.video_pruning_rate > 0.0:
|
||||
# Start of EVS-specific code
|
||||
num_tokens = compute_retained_tokens_count(
|
||||
tokens_per_frame=tokens_in_single_frame,
|
||||
num_frames=num_frames,
|
||||
num_frames=num_tubelets,
|
||||
q=self.video_pruning_rate,
|
||||
)
|
||||
|
||||
# Here we just need placeholders that won't actually be replaced -
|
||||
# we just need to make sure the total number of tokens is correct
|
||||
# assign all tokens to the first frame
|
||||
tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
|
||||
tokens_per_frame = [num_tokens] + [0] * (num_tubelets - 1)
|
||||
|
||||
# End of EVS-specific code
|
||||
else:
|
||||
tokens_per_frame = [tokens_in_single_frame] * num_frames
|
||||
tokens_per_frame = [tokens_in_single_frame] * num_tubelets
|
||||
|
||||
video_repl = self.get_video_repl(
|
||||
tokens_per_frame=tokens_per_frame,
|
||||
@@ -828,6 +976,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
img_start_token_ids=self._img_start_token_ids,
|
||||
img_end_token_ids=self._img_end_token_ids,
|
||||
img_context_token_ids=self._img_context_token_ids,
|
||||
video_temporal_patch_size=T,
|
||||
)
|
||||
|
||||
# video_repl.full is a list of token IDs
|
||||
@@ -908,7 +1057,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
text, video_inputs = self._preprocess_video(
|
||||
text=text,
|
||||
videos=videos,
|
||||
max_num_tiles=1,
|
||||
)
|
||||
|
||||
text, audio_inputs = self._preprocess_audio(
|
||||
@@ -962,6 +1110,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
img_start_token_ids: list[int],
|
||||
img_end_token_ids: list[int],
|
||||
img_context_token_ids: list[int],
|
||||
video_temporal_patch_size: int = 1,
|
||||
) -> PromptUpdateDetails[list[int]]:
|
||||
"""
|
||||
Build prompt replacement for a video.
|
||||
@@ -981,31 +1130,60 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
- EVS real (called from get_real_video_repl_for_evs) - different value per frame
|
||||
Args:
|
||||
tokens_per_frame (list[int]): number of tokens per frame
|
||||
frames_indices (list[int]): frame indices
|
||||
(one per tubelet when T > 1)
|
||||
frames_indices (list[int]): orig. frame indices
|
||||
(one per frame, before tubelet subsampling)
|
||||
frame_duration_ms (int): duration of each frame in milliseconds
|
||||
tokenizer (HfTokenizer): tokenizer to use for tokenizing frame separators
|
||||
tokenizer (TokenizerLike): tokenizer to use for tokenizing frame separators
|
||||
img_start_token_ids (list[int]): pre-tokenized IMG_START tokens
|
||||
img_end_token_ids (list[int]): pre-tokenized IMG_END tokens
|
||||
img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens
|
||||
video_temporal_patch_size (int): temporal patch size for videos
|
||||
"""
|
||||
# TODO: Add support of frame_duration_ms to be None
|
||||
# At preprocessing step we should allow absent / metadata without
|
||||
# frames_indices field.
|
||||
timestamps_enabled = frame_duration_ms is not None
|
||||
T = video_temporal_patch_size
|
||||
num_frames = len(frames_indices)
|
||||
|
||||
if timestamps_enabled:
|
||||
if T > 1 and timestamps_enabled:
|
||||
all_timestamps = calculate_timestamps(frames_indices, frame_duration_ms)
|
||||
|
||||
frame_separators = []
|
||||
for group_idx, i in enumerate(range(0, num_frames, T)):
|
||||
group_frames = []
|
||||
for j in range(T): # Every frame in the group
|
||||
frame_idx = i + j
|
||||
if frame_idx < num_frames:
|
||||
# Valid idx (haven't padded to mult. of T yet)
|
||||
ts = all_timestamps[frame_idx]
|
||||
frame_str = "Frame" if j == 0 else "frame"
|
||||
group_frames.append(
|
||||
f"{frame_str} {frame_idx + 1} sampled at {ts:.2f} seconds"
|
||||
)
|
||||
if group_frames:
|
||||
# Join by `and` if there are >1 frame, otherwise no `and`
|
||||
# Prepend \n to match training format (except first group)
|
||||
sep = " and ".join(group_frames) + ": "
|
||||
if group_idx > 0:
|
||||
sep = "\n" + sep
|
||||
frame_separators.append(sep)
|
||||
elif timestamps_enabled:
|
||||
timestamps = calculate_timestamps(frames_indices, frame_duration_ms)
|
||||
|
||||
assert len(timestamps) == len(tokens_per_frame), (
|
||||
"timestamps and tokens_per_frame must have the same length"
|
||||
)
|
||||
frame_separators = [
|
||||
f"Frame {i + 1} sampled at {timestamp:.2f} seconds: "
|
||||
("\n" if i > 0 else "")
|
||||
+ f"Frame {i + 1} sampled at {timestamp:.2f} seconds: "
|
||||
for i, timestamp in enumerate(timestamps)
|
||||
]
|
||||
else:
|
||||
frame_separators = [
|
||||
f"Frame {i + 1}: " for i, _ in enumerate(tokens_per_frame)
|
||||
("\n" if i > 0 else "") + f"Frame {i + 1}: "
|
||||
for i, _ in enumerate(tokens_per_frame)
|
||||
]
|
||||
|
||||
# Tokenize frame separator independently
|
||||
|
||||
@@ -420,8 +420,9 @@ class GPUModelRunner(
|
||||
self.is_multimodal_raw_input_only_model = (
|
||||
model_config.is_multimodal_raw_input_only_model
|
||||
)
|
||||
# This will be overridden in load_model()
|
||||
# These will be overridden in load_model()
|
||||
self.is_multimodal_pruning_enabled = False
|
||||
self.requires_sequential_video_encoding = False
|
||||
# Set to True after init_routed_experts_capturer() completes.
|
||||
# Prevents routed experts code from running during profiling/dummy run.
|
||||
self.routed_experts_initialized = False
|
||||
@@ -2625,17 +2626,23 @@ class GPUModelRunner(
|
||||
):
|
||||
batch_outputs: MultiModalEmbeddings
|
||||
|
||||
# EVS-related change.
|
||||
# EVS and dynamic res video related change.
|
||||
# (ekhvedchenia): Temporary hack to limit peak memory usage when
|
||||
# processing multimodal data. This solves the issue with scheduler
|
||||
# putting too many video samples into a single batch. Scheduler
|
||||
# uses pruned vision tokens count to compare it versus compute
|
||||
# budget which is incorrect (Either input media size or non-pruned
|
||||
# output vision tokens count should be considered)
|
||||
# dynamic res video for nemotron temporarily uses this hack via
|
||||
# requires_sequential_video_encoding
|
||||
# because it doesn't yet support video batching.
|
||||
# TODO(ywang96): Fix memory profiling to take EVS into account and
|
||||
# remove this hack.
|
||||
if (
|
||||
self.is_multimodal_pruning_enabled
|
||||
(
|
||||
self.is_multimodal_pruning_enabled
|
||||
or self.requires_sequential_video_encoding
|
||||
)
|
||||
and modality == "video"
|
||||
and num_items > 1
|
||||
):
|
||||
@@ -4609,6 +4616,9 @@ class GPUModelRunner(
|
||||
and mm_config is not None
|
||||
and mm_config.is_multimodal_pruning_enabled()
|
||||
)
|
||||
self.requires_sequential_video_encoding = hasattr(
|
||||
self.get_model(), "requires_sequential_video_encoding"
|
||||
) # Temporary hack for dynamic res video w/o support for bs>1 yet
|
||||
|
||||
if (
|
||||
is_mixture_of_experts(self.model)
|
||||
|
||||
Reference in New Issue
Block a user