Support temporal compression for Nemotron-3-VL videos (#36808)

Signed-off-by: Collin McCarthy <cmccarthy@nvidia.com>
This commit is contained in:
Collin McCarthy
2026-03-19 01:02:19 -07:00
committed by GitHub
parent d3cc379567
commit 0b6d52629f
5 changed files with 552 additions and 129 deletions

View File

@@ -8,6 +8,7 @@
# --------------------------------------------------------
import copy
import math
import warnings
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
@@ -77,6 +78,7 @@ from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.transformers_utils.processors.internvl import get_internvl_target_ratios
from vllm.transformers_utils.processors.nano_nemotron_vl import (
AUDIO_CONTEXT,
IMG_CONTEXT,
@@ -85,7 +87,7 @@ from vllm.transformers_utils.processors.nano_nemotron_vl import (
BaseNanoNemotronVLProcessor,
DynamicResolutionImageTiler,
NanoNemotronVLProcessor,
get_internvl_target_ratios,
get_video_target_size_and_feature_size,
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -295,10 +297,13 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
max_videos = mm_counts.get("video", 0)
processor = self.get_hf_processor() # we get the CustomProcessor here
T = processor.video_temporal_patch_size
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token
max_frames_per_video = max_total_frames // max(max_videos, 1)
tokens_per_tubelet = processor.num_video_token
max_total_tubelets = (seq_len - max_image_tokens) // tokens_per_tubelet
max_tubelets_per_video = max_total_tubelets // max(max_videos, 1)
max_frames_per_video = max_tubelets_per_video * T
return max(max_frames_per_video, 1)
def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor:
@@ -589,28 +594,49 @@ class NanoNemotronVLMultiModalProcessor(
video_num_patches = []
def get_video_replacement_internvl(item_idx: int):
feature_size = hf_processor.num_image_token
video, metadata = mm_items["video"][item_idx]
patch_size = hf_processor.config.patch_size
downsample_ratio = hf_processor.config.downsample_ratio
target_patches = hf_processor.video_target_num_patches
if target_patches is not None and video is not None and video.shape[0] > 0:
orig_h, orig_w = video.shape[1], video.shape[2]
_, _, feature_size = get_video_target_size_and_feature_size(
orig_w=orig_w,
orig_h=orig_h,
target_patches=target_patches,
maintain_aspect_ratio=hf_processor.video_maintain_aspect_ratio,
patch_size=patch_size,
downsample_ratio=downsample_ratio,
)
else:
feature_size = hf_processor.num_image_token
num_patches = video_num_patches[item_idx]
if num_patches is not None:
assert isinstance(num_patches, int)
T = hf_processor.video_temporal_patch_size
if T > 1 and num_patches is not None:
num_tubelets = math.ceil(num_patches / T)
else:
num_tubelets = num_patches
video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
if video_pruning_rate is not None and video_pruning_rate > 0.0:
# Start of EVS-specific code
num_tokens = compute_retained_tokens_count(
tokens_per_frame=feature_size,
num_frames=num_patches,
num_frames=num_tubelets,
q=video_pruning_rate,
)
# Here we just need placeholders that won't actually be replaced -
# we just need to make sure the total number of tokens is correct
# assign all tokens to the first frame
tokens_per_frame = [num_tokens] + [0] * (num_patches - 1)
tokens_per_frame = [num_tokens] + [0] * (num_tubelets - 1)
# End of EVS-specific code
else:
tokens_per_frame = [feature_size] * num_patches
tokens_per_frame = [feature_size] * num_tubelets
frame_duration_ms = int(1000 / metadata["fps"])
return hf_processor.get_video_repl(
@@ -621,6 +647,7 @@ class NanoNemotronVLMultiModalProcessor(
img_start_token_ids=hf_processor._img_start_token_ids,
img_end_token_ids=hf_processor._img_end_token_ids,
img_context_token_ids=hf_processor._img_context_token_ids,
video_temporal_patch_size=T,
)
if self.info.supports_video:
@@ -745,15 +772,39 @@ class NanoNemotronVLDummyInputsBuilder(
if self.info.supports_video:
config = self.info.get_hf_config()
image_size: int = config.force_image_size
processor = self.info.get_hf_processor()
# When video_target_num_patches is set the per-frame pixel
# resolution can exceed image_size. Use the actual target
# dimensions so that profiling sees the correct upper bound.
if processor.video_target_num_patches is not None:
target_w, target_h, _ = get_video_target_size_and_feature_size(
orig_w=image_size,
orig_h=image_size,
target_patches=processor.video_target_num_patches,
maintain_aspect_ratio=processor.video_maintain_aspect_ratio,
patch_size=config.patch_size,
downsample_ratio=config.downsample_ratio,
)
video_width, video_height = target_w, target_h
else:
video_width, video_height = image_size, image_size
target_num_frames = self.info.get_num_frames_with_most_features(
seq_len, mm_counts
)
mm_config = self.info.ctx.get_mm_config()
if num_frames := mm_config.media_io_kwargs.get("video", {}).get(
"num_frames"
):
assert num_frames > 0
target_num_frames = num_frames
num_videos = mm_counts.get("video", 0)
video_overrides = mm_options.get("video")
dummy_video = {
"video": self._get_dummy_videos(
width=image_size,
height=image_size,
width=video_width,
height=video_height,
num_frames=target_num_frames,
num_videos=num_videos,
overrides=video_overrides,
@@ -790,6 +841,9 @@ class NanoNemotronVLDummyInputsBuilder(
class NemotronH_Nano_VL_V2(
nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning
):
requires_sequential_video_encoding = True
"""Temporarily needed for dynamic res video w/ conv3d, doesn't support bs>1 yet"""
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
@@ -817,6 +871,11 @@ class NemotronH_Nano_VL_V2(
self.image_tag_type = config.image_tag_type
self.video_pruning_rate = multimodal_config.video_pruning_rate
vision_config = getattr(config, "vision_config", config)
self.video_temporal_patch_size: int = getattr(
vision_config, "video_temporal_patch_size", 1
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -838,11 +897,12 @@ class NemotronH_Nano_VL_V2(
mlp1 = nn.Sequential(
RMSNorm(
hidden_size=vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
hidden_size=vit_hidden_size
* int(round(1 / self.downsample_ratio)) ** 2,
eps=1e-5,
),
nn.Linear(
vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
vit_hidden_size * int(round(1 / self.downsample_ratio)) ** 2,
vision_projection_hidden_size,
bias=False,
),
@@ -958,19 +1018,37 @@ class NemotronH_Nano_VL_V2(
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
def extract_feature(self, pixel_values: torch.Tensor):
def extract_feature(
self,
pixel_values: torch.Tensor,
num_frames: int | None = None,
) -> torch.Tensor:
# Process images in a micro-batch of at most 128 frames per call
# This is done on purpose to ensure peak GPU ram usage of huge batch
# (namely for really long videos with EVS ON) won't cause any problems
# as we don't support chunked prefill for video media
micro_batch_size = 128
n = pixel_values.shape[0]
# This is done on purpose to ensure peak GPU ram usage of huge batch
# (namely for really long videos with EVS ON) won't cause any problems
# as we don't support chunked prefill for video media
# When num_frames is provided and temporal_patch_size > 1, consecutive
# frames are grouped into tubelets — the batch size must be a multiple
# of T so chunk boundaries don't split a tubelet.
N, _C, H, W = pixel_values.shape
T = self.video_temporal_patch_size if num_frames is not None else 1
micro_batch_size = 128 - (128 % T)
patch_size = self.patch_size
H_patches = H // patch_size
W_patches = W // patch_size
vit_embeds_list = []
for i in range(0, n, micro_batch_size):
_, vit_embeds = self.vision_model(pixel_values[i : i + micro_batch_size])
for i in range(0, N, micro_batch_size):
chunk = pixel_values[i : i + micro_batch_size]
if num_frames is not None and T > 1:
_, vit_embeds = self.vision_model(chunk, num_frames=chunk.shape[0])
else:
_, vit_embeds = self.vision_model(chunk)
vit_embeds = vit_embeds.to(dtype=torch.bfloat16)
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = vit_embeds.reshape(
vit_embeds.shape[0], H_patches, W_patches, -1
)
vit_embeds = self.pixel_shuffle(
vit_embeds, scale_factor=self.downsample_ratio
)
@@ -1042,16 +1120,21 @@ class NemotronH_Nano_VL_V2(
) -> tuple[torch.Tensor, ...]:
"""Process video input and create final embeddings with video content
and indicator tokens."""
# Get video embeddings using the same processing as images
video_embeddings = self._process_image_input(video_input)
T = self.video_temporal_patch_size
if T > 1:
video_embeddings = self._extract_video_embeddings_temporal(video_input)
else:
video_embeddings = self._process_image_input(video_input)
final_video_embeddings: tuple[torch.Tensor, ...] = ()
image_rows = image_cols = self.config.force_image_size
downsample_ratio = self.config.downsample_ratio
patch_size = self.config.patch_size
rows = int(image_rows * downsample_ratio // patch_size)
cols = int(image_cols * downsample_ratio // patch_size)
pixel_values = video_input["pixel_values_flat"]
frame_h, frame_w = pixel_values.shape[-2], pixel_values.shape[-1]
rows = int(frame_h * downsample_ratio // patch_size)
cols = int(frame_w * downsample_ratio // patch_size)
video_pruning_rate = self.video_pruning_rate
video_num_frames = video_input["num_patches"].tolist()
video_frames_indices = video_input["frames_indices"].split(video_num_frames)
@@ -1062,13 +1145,14 @@ class NemotronH_Nano_VL_V2(
num_frames = video_num_frames[i]
frames_indices = video_frames_indices[i].tolist()
frame_duration_ms = video_input["frame_duration_ms"][i].item()
assert single_video_embeddings.shape[0] % num_frames == 0
num_tubelets = math.ceil(num_frames / T) if T > 1 else num_frames
assert single_video_embeddings.shape[0] % num_tubelets == 0
if video_pruning_rate is not None and video_pruning_rate > 0.0:
# Start of EVS-specific code
retention_mask = compute_retention_mask(
single_video_embeddings,
video_size_thw=(num_frames, rows, cols),
video_size_thw=(num_tubelets, rows, cols),
spatial_merge_size=1,
q=video_pruning_rate,
)
@@ -1077,14 +1161,14 @@ class NemotronH_Nano_VL_V2(
single_video_embeddings = single_video_embeddings[retention_mask]
# calculate the actual number of retained tokens per frame
retention_mask_thw = retention_mask.reshape(num_frames, rows, cols)
retention_mask_thw = retention_mask.reshape(num_tubelets, rows, cols)
num_tokens_per_frame = (
retention_mask_thw.sum(dim=(1, 2)).long().tolist()
)
# End of EVS-specific code
else:
feature_size = single_video_embeddings.shape[0] // num_frames
num_tokens_per_frame = [feature_size] * num_frames
feature_size = single_video_embeddings.shape[0] // num_tubelets
num_tokens_per_frame = [feature_size] * num_tubelets
final_video_embeddings += (
self._create_final_video_embeddings(
@@ -1092,11 +1176,36 @@ class NemotronH_Nano_VL_V2(
num_tokens_per_frame,
frames_indices,
frame_duration_ms,
video_temporal_patch_size=T,
),
)
return final_video_embeddings
def _extract_video_embeddings_temporal(
    self, video_input: NanoNemotronVLVideoPixelInputs
) -> tuple[torch.Tensor, ...]:
    """Extract per-video embeddings with temporal compression.

    Each video is run separately through ``extract_feature`` with
    ``num_frames`` set, which selects the fixed-resolution temporal path
    in RADIO (no attention mask, flash attention).
    """
    pixel_values = video_input["pixel_values_flat"]
    num_frames_per_video = video_input["num_patches"].tolist()
    hidden_size = self.config.text_config.hidden_size
    # Split the flat [total_frames, 3, H, W] tensor into consecutive
    # per-video chunks whose sizes come from num_patches.
    per_video_frames = pixel_values.split(num_frames_per_video)
    return tuple(
        self.extract_feature(frames, num_frames=frames.shape[0]).view(
            -1, hidden_size
        )
        for frames in per_video_frames
    )
def _process_audio_input(
self, audio_input: NanoNemotronVLAudioFeatureInputs
) -> tuple[torch.Tensor, ...]:
@@ -1134,6 +1243,7 @@ class NemotronH_Nano_VL_V2(
num_tokens_per_frame: list[int],
frames_indices: list[int],
frame_duration_ms: int,
video_temporal_patch_size: int = 1,
) -> torch.Tensor:
"""Create final embeddings that combine video embeddings with
text embeddings of indicator tokens.
@@ -1161,6 +1271,7 @@ class NemotronH_Nano_VL_V2(
img_start_token_ids=self._img_start_token_ids,
img_end_token_ids=self._img_end_token_ids,
img_context_token_ids=self._img_context_token_ids,
video_temporal_patch_size=video_temporal_patch_size,
)
# video_repl.full is a list of token IDs
@@ -1207,8 +1318,27 @@ class NemotronH_Nano_VL_V2(
else:
frames_indices = torch.cat([f.flatten() for f in frames_indices], dim=0)
frame_duration_ms = frame_duration_ms.flatten()
expected_h = expected_w = self.config.force_image_size
if torch.is_tensor(frame_duration_ms):
frame_duration_ms = frame_duration_ms.flatten()
else:
frame_duration_ms = torch.cat(
[f.flatten() for f in frame_duration_ms], dim=0
)
if (
torch.is_tensor(pixel_values_flat_video)
and pixel_values_flat_video.ndim == 5
):
# batched._reduce_data stacked same-shape videos into
# [num_videos, nf, 3, H, W]; unstack back to a list so the
# same-H,W cat path below handles it uniformly.
pixel_values_flat_video = list(pixel_values_flat_video)
if not torch.is_tensor(pixel_values_flat_video):
pixel_values_flat_video = torch.cat(pixel_values_flat_video, dim=0)
expected_h = pixel_values_flat_video.shape[-2]
expected_w = pixel_values_flat_video.shape[-1]
num_frames = video_num_patches[0].item()
resolve_bindings = {"h": expected_h, "w": expected_w, "f": num_frames}
@@ -1361,8 +1491,7 @@ class NemotronH_Nano_VL_V2(
self.language_model.load_weights(llm_weights)
self.vision_model.load_weights(vision_weights)
if self.sound_encoder is not None:
assert len(sound_weights) > 0
if self.sound_encoder is not None and len(sound_weights) > 0:
self.sound_encoder.load_weights(sound_weights)
def get_vit_model_from_radio_config(self, hf_config):
@@ -1375,12 +1504,23 @@ class NemotronH_Nano_VL_V2(
image_size = preferred_resolution[0] if preferred_resolution else 224
patch_size = getattr(hf_config_vision, "patch_size", 16)
# video_temporal_patch_size and separate_video_embedder are
# top-level vision_config attributes, not inside args.
video_temporal_patch_size = getattr(
hf_config_vision, "video_temporal_patch_size", 1
)
separate_video_embedder = getattr(
hf_config_vision, "separate_video_embedder", True
)
radio_config = RadioConfig(
model_name=model_name,
image_size=image_size,
patch_size=patch_size,
norm_mean=hf_config.norm_mean,
norm_std=hf_config.norm_std,
video_temporal_patch_size=video_temporal_patch_size,
separate_video_embedder=separate_video_embedder,
**hf_config_vision.args,
)

View File

@@ -123,6 +123,8 @@ class ViTPatchGenerator(nn.Module):
register_multiple: int | None = None,
num_registers: int | None = None,
patch_bias: bool = False,
temporal_patch_size: int = 1,
separate_video_embedder: bool = True,
device=None,
dtype=None,
):
@@ -148,6 +150,7 @@ class ViTPatchGenerator(nn.Module):
self.patch_size = patch_size
self.abs_pos = abs_pos
self.embed_dim = embed_dim
self.temporal_patch_size = temporal_patch_size
self.num_rows = max_input_dims[0] // patch_size
self.num_cols = max_input_dims[1] // patch_size
@@ -160,6 +163,21 @@ class ViTPatchGenerator(nn.Module):
patch_size, embed_dim, bias=patch_bias, **factory
)
if temporal_patch_size > 1:
if not separate_video_embedder:
raise NotImplementedError(
"Only separate_video_embedder=True is supported for"
" temporal compression (temporal_patch_size > 1)"
)
self.video_embedder = ViTPatchLinear(
patch_size,
embed_dim,
bias=patch_bias,
temporal_patch_size=temporal_patch_size,
**factory,
)
self._video_embedder_loaded = False
if abs_pos:
scale = embed_dim**-0.5
self.pos_embed = nn.Parameter(
@@ -196,6 +214,60 @@ class ViTPatchGenerator(nn.Module):
return patches, pos_enc
return patches
def forward_video(self, x: torch.Tensor) -> torch.Tensor:
    """Process video frames with temporal compression.

    Groups T consecutive frames into tubelets before projecting them
    through the dedicated video embedder.

    Args:
        x: [num_frames, 3, H, W] tensor of video frames

    Returns:
        Embedded patches with temporal compression applied (and the
        positional encodings when ``self.return_pos_enc`` is set).
    """
    if not self._video_embedder_loaded:
        raise ValueError(
            "Temporal compression (video_temporal_patch_size > 1) requires "
            "video_embedder weights, but they were never loaded. "
            "Ensure the checkpoint was trained with temporal compression."
        )
    T = self.temporal_patch_size
    input_size = x.shape[2:]
    patches = self.im_to_patches(x)  # [N, num_patches, 3*P*P]
    num_frames, num_spatial, feat_dim = patches.shape

    # Pad the frame count up to a multiple of T by repeating the last
    # frame, so every tubelet holds exactly T frames.
    num_pad_frames = (-num_frames) % T
    if num_pad_frames:
        last_frame_dup = patches[-1:].expand(num_pad_frames, -1, -1)
        patches = torch.cat([patches, last_frame_dup], dim=0)

    # Group T frames per tubelet: for each spatial position, concatenate
    # features across the T consecutive frames; the (frames, feat)
    # ordering of the concatenated axis follows Megatron training.
    num_tubelets = patches.shape[0] // T
    patches = (
        patches.reshape(num_tubelets, T, num_spatial, feat_dim)
        .permute(0, 2, 1, 3)
        .reshape(num_tubelets, num_spatial, T * feat_dim)
    )

    patches = self.video_embedder(patches)
    patches, pos_enc = self.apply_pos_enc(patches, input_size=input_size)
    patches = self.cls_token(patches)
    patches = self.patch_normalizer(patches)
    if self.return_pos_enc:
        return patches, pos_enc
    return patches
def apply_pos_enc_dynamic(
self, patches: torch.Tensor, imgs_sizes: list[tuple[int, int]]
) -> tuple[torch.Tensor, torch.Tensor | None]:
@@ -381,66 +453,21 @@ class ViTPatchGenerator(nn.Module):
return pos_embed
if self.cpe_mode:
if self.training:
min_scale = math.sqrt(0.1)
scale = (
torch.rand(batch_size, 1, 1, device=pos_embed.device)
* (1 - min_scale)
+ min_scale
)
aspect_min = math.log(3 / 4)
aspect_max = -aspect_min
aspect = torch.exp(
torch.rand(batch_size, 1, 1, device=pos_embed.device)
* (aspect_max - aspect_min)
+ aspect_min
)
max_dim = max(input_dims)
pos_embed = F.interpolate(
pos_embed.float(),
size=(max_dim, max_dim),
align_corners=False,
mode="bilinear",
).to(pos_embed.dtype)
scale_x = scale * aspect
scale_y = scale * (1 / aspect)
scale_xy = torch.stack([scale_x, scale_y], dim=-1).clamp_(0, 1)
pos_xy = torch.rand(batch_size, 1, 1, 2, device=pos_embed.device) * (
1 - scale_xy
)
lin_x = torch.linspace(
0, 1, steps=input_dims[1], device=pos_embed.device
)[None, None].expand(batch_size, input_dims[0], -1)
lin_y = torch.linspace(
0, 1, steps=input_dims[0], device=pos_embed.device
)[None, :, None].expand(batch_size, -1, input_dims[1])
lin_xy = torch.stack([lin_x, lin_y], dim=-1)
grid_xy = lin_xy * scale_xy + pos_xy
# Convert to [-1, 1] range
grid_xy.mul_(2).sub_(1)
pos_embed = F.grid_sample(
pos_embed.float().expand(batch_size, -1, -1, -1),
grid=grid_xy,
mode="bilinear",
padding_mode="zeros",
align_corners=True,
).to(pos_embed.dtype)
else:
max_dim = max(input_dims)
pos_embed = F.interpolate(
pos_embed.float(),
size=(max_dim, max_dim),
align_corners=True,
mode="bilinear",
).to(pos_embed.dtype)
pos_embed = window_select(pos_embed)
pos_embed = window_select(pos_embed)
else:
pos_embed = window_select(pos_embed)
if pos_embed.shape[-2:] != input_dims:
pos_embed = F.interpolate(
pos_embed.float(), size=input_dims, align_corners=True, mode="bilinear"
pos_embed.float(), size=input_dims, align_corners=False, mode="bilinear"
).to(pos_embed.dtype)
pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
@@ -473,9 +500,19 @@ class Im2Patches(nn.Module):
class ViTPatchLinear(nn.Linear):
def __init__(self, patch_size: int, embed_dim: int, bias: bool = False, **factory):
super().__init__(3 * (patch_size**2), embed_dim, bias=bias, **factory)
def __init__(
self,
patch_size: int,
embed_dim: int,
bias: bool = False,
temporal_patch_size: int = 1,
**factory,
):
super().__init__(
3 * temporal_patch_size * (patch_size**2), embed_dim, bias=bias, **factory
)
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
@dataclass(frozen=True, kw_only=True)
@@ -560,6 +597,7 @@ class RadioInternVisionModel(nn.Module):
max_img_size = int(
round(config.cpe_max_size / config.patch_size) * config.patch_size
)
self.temporal_patch_size = config.video_temporal_patch_size
unique_teachers = set(t["name"] for t in config.teachers)
self.patch_generator = ViTPatchGenerator(
config.patch_size,
@@ -569,6 +607,8 @@ class RadioInternVisionModel(nn.Module):
cls_token=True,
num_cls_tokens=len(unique_teachers) if config.cls_token_per_teacher else 1,
register_multiple=config.register_multiple,
temporal_patch_size=self.temporal_patch_size,
separate_video_embedder=config.separate_video_embedder,
)
self.encoder = RadioVisionEncoder(
@@ -593,33 +633,68 @@ class RadioInternVisionModel(nn.Module):
def inter_image_mask_metadata(
self, imgs_sizes: list[tuple[int, int]], device: torch.device
) -> MaskMetadata:
"""Build mask metadata from image pixel sizes. Adds num_skip to each
sequence length (cls/register tokens) to match patch generator output."""
patch_size = self.patch_generator.patch_size
num_skip = self.patch_generator.num_skip
seq_lens = calc_seq_lens(imgs_sizes, patch_size)
adjusted = [s + num_skip for s in seq_lens]
return self._inter_image_mask_metadata_from_seq_lens(adjusted, device=device)
def _inter_image_mask_metadata_from_seq_lens(
self, seq_lens: list[int], device: torch.device
) -> MaskMetadata:
"""Build mask metadata from actual sequence lengths (already including
cls/register tokens, i.e. patch_count + num_skip per item).
Use inter_image_mask_metadata() when you only have imgs_sizes."""
assert len(seq_lens) > 0
cu_seqlens = torch.tensor(
list(accumulate(adjusted, initial=0)), dtype=torch.int32, device=device
list(accumulate(seq_lens, initial=0)), dtype=torch.int32, device=device
)
# Keep max_seqlen on CPU to avoid .item() sync
# See: https://github.com/vllm-project/vllm/blob/20b6b01/vllm/v1/attention/ops/vit_attn_wrappers.py#L48
max_seqlen = torch.tensor(max(adjusted), dtype=torch.int32)
max_seqlen = torch.tensor(max(seq_lens), dtype=torch.int32)
return MaskMetadata(cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
def forward(
self,
x: torch.Tensor,
imgs_sizes: list[tuple[int, int]] | None = None,
num_frames: int | None = None,
) -> torch.FloatTensor:
hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
T = self.temporal_patch_size
# Build packed-sequence metadata for MMEncoderAttention when needed.
mask_meta = None
if imgs_sizes is not None:
assert len(imgs_sizes) > 0
# Dynamic resolution: process each image as an independent sequence.
mask_meta = self.inter_image_mask_metadata(
imgs_sizes, device=hidden_states.device
packed_batch_size = None # Original batch size before packing
if num_frames is not None and T > 1:
# Conv3d video: all tubelets have the same sequence length.
# Pack [num_tubelets, seq_per_tubelet, hidden] → [1, total, hidden]
hidden_states = self.patch_generator.forward_video(x)
packed_batch_size, seq_per_tubelet, hidden_dim = hidden_states.shape
hidden_states = hidden_states.reshape(1, -1, hidden_dim)
mask_meta = self._inter_image_mask_metadata_from_seq_lens(
[seq_per_tubelet] * packed_batch_size, device=hidden_states.device
)
else:
# Images for any model, or video for non-conv3d model
hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
if imgs_sizes is not None and len(imgs_sizes) > 1:
# Dynamic resolution w/ > 1 image, create attn mask
mask_meta = self.inter_image_mask_metadata(
imgs_sizes, device=hidden_states.device
)
encoder_outputs = self.encoder(inputs_embeds=hidden_states, mask_meta=mask_meta)
# Unpack back to original batch shape if we packed for video
if packed_batch_size is not None:
encoder_outputs = encoder_outputs.reshape(
packed_batch_size, seq_per_tubelet, -1
)
return encoder_outputs
@@ -663,8 +738,13 @@ class RadioModel(nn.Module):
pixel_embeds: torch.Tensor | None = None,
*,
imgs_sizes: list[tuple[int, int]] | None = None,
num_frames: int | None = None,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
y = self.model(pixel_values, imgs_sizes=imgs_sizes)
y = self.model(
pixel_values,
imgs_sizes=imgs_sizes,
num_frames=num_frames,
)
return self._extract_final(y, imgs_sizes=imgs_sizes)
def load_weights(self, weights) -> set[str]:
@@ -714,6 +794,9 @@ class RadioModel(nn.Module):
weight_loader(param, weight)
loaded_params.add(vllm_key)
if "model.patch_generator.video_embedder.weight" in loaded_params:
self.model.patch_generator._video_embedder_loaded = True
return loaded_params
def _extract_final(

View File

@@ -47,6 +47,14 @@ class RadioConfig(PretrainedConfig):
teachers: A list of teacher model configurations. Each teacher configuration is
a dict with keys like "name" and some may have "use_summary".
cls_token_per_teacher: Whether to use a separate CLS token for each teacher.
video_temporal_patch_size: Number of consecutive video frames grouped into
a single tubelet for temporal compression. Default 1 (no compression).
When > 1, a dedicated video_embedder (3*T*P*P -> hidden) is created
alongside the image embedder (3*P*P -> hidden).
separate_video_embedder: When True and video_temporal_patch_size > 1, use a
dedicated video patch embedder (3*T*P*P -> hidden) separate from the
image embedder (3*P*P -> hidden). When False, a single embedder with
input size 3*T*P*P is used for both (images are duplicated T times).
"""
model_type = "radio"
@@ -68,6 +76,8 @@ class RadioConfig(PretrainedConfig):
register_multiple: int | None = None,
teachers: list[dict[str, Any]] | None = None,
cls_token_per_teacher: bool = False,
video_temporal_patch_size: int = 1,
separate_video_embedder: bool = True,
**kwargs,
):
self.model_name = model_name
@@ -95,4 +105,6 @@ class RadioConfig(PretrainedConfig):
self.register_multiple = register_multiple
self.teachers = teachers if teachers is not None else []
self.cls_token_per_teacher = cls_token_per_teacher
self.video_temporal_patch_size = video_temporal_patch_size
self.separate_video_embedder = separate_video_embedder
super().__init__(**kwargs)

View File

@@ -11,6 +11,7 @@ import math
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Any, TypeVar
import einops
@@ -43,6 +44,12 @@ AUDIO_CONTEXT = "<so_embedding>"
# MAX_FRAMES = 16
DEFAULT_NUM_TILES = 12
# Configure PIL to handle large images without warnings
# This prevents DecompressionBombWarning for legitimate large images
Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
# Alternative: Set a specific higher limit
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
def calculate_timestamps(
indices: list[int] | torch.Tensor,
@@ -138,19 +145,110 @@ def image_to_pixel_values(
return pixel_values
def _compute_aspect_preserving_size(
orig_w: int,
orig_h: int,
target_num_patches: int,
patch_size: int,
downsample_ratio: float,
) -> tuple[int, int]:
"""Compute target pixel dimensions that preserve aspect ratio.
Mirrors Megatron-LM image_processing.py video frame resizing:
target area in patch-grid space is *target_num_patches*, distributed
according to the source aspect ratio, then snapped to a multiple of
the required divisor (2 for pixel-shuffle).
"""
aspect_wh = orig_w / max(orig_h, 1)
ph = round(math.sqrt(target_num_patches / aspect_wh))
pw = round(math.sqrt(target_num_patches * aspect_wh))
ph = max(ph, 1)
pw = max(pw, 1)
reduction_factor = int(round(1 / downsample_ratio))
required_divisor = reduction_factor # 2 for pixel-shuffle
if required_divisor > 1:
rem_h = ph % required_divisor
rem_w = pw % required_divisor
ph_up = ph + (required_divisor - rem_h if rem_h else 0)
ph_down = ph - rem_h
pw_up = pw + (required_divisor - rem_w if rem_w else 0)
pw_down = pw - rem_w
if ph_up * pw_up <= target_num_patches:
ph, pw = ph_up, pw_up
else:
ph = max(required_divisor, ph_down)
pw = max(required_divisor, pw_down)
return pw * patch_size, ph * patch_size # (width, height) in pixels
def get_video_target_size_and_feature_size(
    orig_w: int,
    orig_h: int,
    target_patches: int,
    maintain_aspect_ratio: bool,
    patch_size: int,
    downsample_ratio: float,
) -> tuple[int, int, int]:
    """Compute target (width, height) and feature_size for video resize and token count.

    Used by video_to_pixel_values (resize) and get_video_replacement_internvl
    (seq length calc) so both use the same dimensions.
    """
    if maintain_aspect_ratio:
        target_w, target_h = _compute_aspect_preserving_size(
            orig_w=orig_w,
            orig_h=orig_h,
            target_num_patches=target_patches,
            patch_size=patch_size,
            downsample_ratio=downsample_ratio,
        )
    else:
        # Square frame: side is the largest multiple of the pixel-shuffle
        # divisor whose square fits within the patch budget.
        divisor = int(round(1 / downsample_ratio))
        side = max(divisor, int(math.sqrt(target_patches)) // divisor * divisor)
        target_w = target_h = side * patch_size
    # Tokens after patchify + downsample (pixel-shuffle) in each dimension.
    tokens_h = int((target_h // patch_size) * downsample_ratio)
    tokens_w = int((target_w // patch_size) * downsample_ratio)
    return target_w, target_h, tokens_h * tokens_w
def video_to_pixel_values(
video: npt.NDArray,
*,
input_size: int,
max_num_tiles: int = 1,
use_thumbnail: bool,
video_target_num_patches: int | None = None,
video_maintain_aspect_ratio: bool = False,
patch_size: int = 16,
downsample_ratio: float = 0.5,
) -> torch.Tensor:
assert max_num_tiles == 1, "Video modality always uses one tile"
# (num_frames, H, W, C) -> (num_frames, C, H, W)
video_tensor = torch.from_numpy(video).permute(0, 3, 1, 2)
if video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size:
if video_target_num_patches is not None:
# Resize to target patch count (aspect-preserving or square).
orig_h, orig_w = video_tensor.shape[2], video_tensor.shape[3]
target_w, target_h, _ = get_video_target_size_and_feature_size(
orig_w=orig_w,
orig_h=orig_h,
target_patches=video_target_num_patches,
maintain_aspect_ratio=video_maintain_aspect_ratio,
patch_size=patch_size,
downsample_ratio=downsample_ratio,
)
if video_tensor.shape[2] != target_h or video_tensor.shape[3] != target_w:
video_tensor = torch.nn.functional.interpolate(
video_tensor,
size=(target_h, target_w),
mode="bicubic",
align_corners=False,
antialias=True,
)
elif video_tensor.shape[2] != input_size or video_tensor.shape[3] != input_size:
video_tensor = torch.nn.functional.interpolate(
video_tensor,
size=(input_size, input_size),
@@ -645,9 +743,9 @@ class BaseNanoNemotronVLProcessor(ABC):
"which should be a single string"
)
parts = [x for x in re.split(r"(<image>)", text[0]) if x]
assert parts.count("<image>") == len(pixel_values_lst), (
"the number of <image> tokens in the text should be the "
"same as the number of images"
assert parts.count("<image>") == len(num_tokens_per_image), (
f"Expected {len(num_tokens_per_image)} <image> tokens in text "
f"but found {parts.count('<image>')}"
)
for i, (feature_size, num_patches) in enumerate(
@@ -706,6 +804,33 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
self.video_token = video_token
self.video_pruning_rate = video_pruning_rate
# Video params live exclusively in vision_config
vision_config = getattr(config, "vision_config", config)
self.video_temporal_patch_size: int = getattr(
vision_config, "video_temporal_patch_size", 1
)
self.video_maintain_aspect_ratio: bool = getattr(
vision_config, "video_maintain_aspect_ratio", False
)
# Resolve video frame target size: exactly one of video_target_num_patches
# or video_target_img_size may be set (mirrors Megatron's
# DynamicResolutionImageTilingStrategy validation).
target_num_patches = getattr(vision_config, "video_target_num_patches", None)
target_img_size = getattr(vision_config, "video_target_img_size", None)
if target_num_patches is not None and target_img_size is not None:
raise ValueError(
"Exactly one of video_target_num_patches or "
"video_target_img_size must be set, got both"
)
if target_num_patches is not None:
self.video_target_num_patches: int | None = target_num_patches
elif target_img_size is not None:
base_patches = math.ceil(target_img_size / config.patch_size)
self.video_target_num_patches = base_patches * base_patches
else:
self.video_target_num_patches = None
self.audio_extractor: ParakeetExtractor | None = None
raw_sound_config = getattr(config, "sound_config", None)
if raw_sound_config is not None:
@@ -721,6 +846,27 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
IMG_CONTEXT, add_special_tokens=False
)
@cached_property
def num_video_token(self) -> int:
    """Per-frame video token count.

    Falls back to ``num_image_token`` when ``video_target_num_patches``
    is unset. Otherwise the feature size is computed from a square
    (1:1) dummy frame of side ``image_size``, since the per-frame
    feature count then differs from the image-based value.
    """
    if self.video_target_num_patches is None:
        # No dynamic-resolution video tiling configured: frames are
        # tokenized exactly like still images.
        return self.num_image_token
    # Square dummy frame: orig_w == orig_h == image_size.
    _, _, feature_size = get_video_target_size_and_feature_size(
        orig_w=self.image_size,
        orig_h=self.image_size,
        target_patches=self.video_target_num_patches,
        maintain_aspect_ratio=self.video_maintain_aspect_ratio,
        patch_size=self.config.patch_size,
        downsample_ratio=self.config.downsample_ratio,
    )
    return feature_size
@property
def supports_video(self) -> bool:
    """Whether video input is supported, i.e. ``video_token_id`` is set."""
    has_video_token_id = self.video_token_id is not None
    return has_video_token_id
@@ -738,14 +884,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
# NOTE(review): this span comes from a diff rendered without +/- markers,
# so it appears to interleave the pre-change and post-change versions of
# this method: the ``max_num_tiles`` parameter/kwarg and ``use_thumbnail``
# look like the removed side, while the ``video_target_num_patches`` /
# ``video_maintain_aspect_ratio`` / ``patch_size`` / ``downsample_ratio``
# kwargs look like the added side. Reconcile against the committed file
# before relying on this exact signature.
def _videos_to_pixel_values_lst(
    self,
    videos: list[npt.NDArray],
    max_num_tiles: int,
) -> list[torch.Tensor]:
    """Convert each raw video array into a pixel-value tensor (one per video)."""
    return [
        video_to_pixel_values(
            video,
            input_size=self.image_size,
            max_num_tiles=max_num_tiles,
            use_thumbnail=self.use_thumbnail,
            # Dynamic-resolution video tiling knobs — presumably sourced
            # from vision_config in __init__; verify against that code.
            video_target_num_patches=self.video_target_num_patches,
            video_maintain_aspect_ratio=self.video_maintain_aspect_ratio,
            patch_size=self.config.patch_size,
            downsample_ratio=self.config.downsample_ratio,
        )
        for video in videos
    ]
@@ -754,7 +901,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
self,
text: list[str],
videos: list[tuple[npt.NDArray, dict[str, Any]]],
max_num_tiles: int,
) -> tuple[list[str], dict[str, Any]]:
if len(videos) == 0 or not self.supports_video:
return text, {}
@@ -763,7 +909,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
video_metadata_lst = [v[1] for v in videos]
pixel_values_lst_video = self._videos_to_pixel_values_lst(
videos_lst,
max_num_tiles=max_num_tiles,
)
# We use frame duration in milliseconds (as integer) to ensure
@@ -788,12 +933,10 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
"frame_duration_ms": torch.tensor(frame_duration_ms_lst),
}
image_size: int = self.config.force_image_size
patch_size: int = self.config.patch_size
downsample_ratio = self.config.downsample_ratio
tokens_in_single_frame = int(
(image_size * image_size // patch_size**2) * (downsample_ratio**2)
)
T = self.video_temporal_patch_size
for pixel_values, video_metadata, frames_indices, frame_duration_ms in zip(
pixel_values_lst_video,
@@ -802,23 +945,28 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
frame_duration_ms_lst,
):
num_frames = pixel_values.shape[0]
frame_h, frame_w = pixel_values.shape[-2], pixel_values.shape[-1]
tokens_in_single_frame = int(
(frame_h * frame_w // patch_size**2) * (downsample_ratio**2)
)
num_tubelets = math.ceil(num_frames / T) if T > 1 else num_frames
if self.video_pruning_rate is not None and self.video_pruning_rate > 0.0:
# Start of EVS-specific code
num_tokens = compute_retained_tokens_count(
tokens_per_frame=tokens_in_single_frame,
num_frames=num_frames,
num_frames=num_tubelets,
q=self.video_pruning_rate,
)
# Here we just need placeholders that won't actually be replaced -
# we just need to make sure the total number of tokens is correct
# assign all tokens to the first frame
tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
tokens_per_frame = [num_tokens] + [0] * (num_tubelets - 1)
# End of EVS-specific code
else:
tokens_per_frame = [tokens_in_single_frame] * num_frames
tokens_per_frame = [tokens_in_single_frame] * num_tubelets
video_repl = self.get_video_repl(
tokens_per_frame=tokens_per_frame,
@@ -828,6 +976,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
img_start_token_ids=self._img_start_token_ids,
img_end_token_ids=self._img_end_token_ids,
img_context_token_ids=self._img_context_token_ids,
video_temporal_patch_size=T,
)
# video_repl.full is a list of token IDs
@@ -908,7 +1057,6 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
text, video_inputs = self._preprocess_video(
text=text,
videos=videos,
max_num_tiles=1,
)
text, audio_inputs = self._preprocess_audio(
@@ -962,6 +1110,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
img_start_token_ids: list[int],
img_end_token_ids: list[int],
img_context_token_ids: list[int],
video_temporal_patch_size: int = 1,
) -> PromptUpdateDetails[list[int]]:
"""
Build prompt replacement for a video.
@@ -981,31 +1130,60 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
- EVS real (called from get_real_video_repl_for_evs) - different value per frame
Args:
tokens_per_frame (list[int]): number of tokens per frame
frames_indices (list[int]): frame indices
(one per tubelet when T > 1)
frames_indices (list[int]): orig. frame indices
(one per frame, before tubelet subsampling)
frame_duration_ms (int): duration of each frame in milliseconds
tokenizer (HfTokenizer): tokenizer to use for tokenizing frame separators
tokenizer (TokenizerLike): tokenizer to use for tokenizing frame separators
img_start_token_ids (list[int]): pre-tokenized IMG_START tokens
img_end_token_ids (list[int]): pre-tokenized IMG_END tokens
img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens
video_temporal_patch_size (int): temporal patch size for videos
"""
# TODO: Add support of frame_duration_ms to be None
# At preprocessing step we should allow absent / metadata without
# frames_indices field.
timestamps_enabled = frame_duration_ms is not None
T = video_temporal_patch_size
num_frames = len(frames_indices)
if timestamps_enabled:
if T > 1 and timestamps_enabled:
all_timestamps = calculate_timestamps(frames_indices, frame_duration_ms)
frame_separators = []
for group_idx, i in enumerate(range(0, num_frames, T)):
group_frames = []
for j in range(T): # Every frame in the group
frame_idx = i + j
if frame_idx < num_frames:
# Valid idx (haven't padded to mult. of T yet)
ts = all_timestamps[frame_idx]
frame_str = "Frame" if j == 0 else "frame"
group_frames.append(
f"{frame_str} {frame_idx + 1} sampled at {ts:.2f} seconds"
)
if group_frames:
# Join by `and` if there are >1 frame, otherwise no `and`
# Prepend \n to match training format (except first group)
sep = " and ".join(group_frames) + ": "
if group_idx > 0:
sep = "\n" + sep
frame_separators.append(sep)
elif timestamps_enabled:
timestamps = calculate_timestamps(frames_indices, frame_duration_ms)
assert len(timestamps) == len(tokens_per_frame), (
"timestamps and tokens_per_frame must have the same length"
)
frame_separators = [
f"Frame {i + 1} sampled at {timestamp:.2f} seconds: "
("\n" if i > 0 else "")
+ f"Frame {i + 1} sampled at {timestamp:.2f} seconds: "
for i, timestamp in enumerate(timestamps)
]
else:
frame_separators = [
f"Frame {i + 1}: " for i, _ in enumerate(tokens_per_frame)
("\n" if i > 0 else "") + f"Frame {i + 1}: "
for i, _ in enumerate(tokens_per_frame)
]
# Tokenize frame separator independently

View File

@@ -420,8 +420,9 @@ class GPUModelRunner(
self.is_multimodal_raw_input_only_model = (
model_config.is_multimodal_raw_input_only_model
)
# This will be overridden in load_model()
# These will be overridden in load_model()
self.is_multimodal_pruning_enabled = False
self.requires_sequential_video_encoding = False
# Set to True after init_routed_experts_capturer() completes.
# Prevents routed experts code from running during profiling/dummy run.
self.routed_experts_initialized = False
@@ -2625,17 +2626,23 @@ class GPUModelRunner(
):
batch_outputs: MultiModalEmbeddings
# EVS-related change.
# EVS and dynamic res video related change.
# (ekhvedchenia): Temporary hack to limit peak memory usage when
# processing multimodal data. This solves the issue with scheduler
# putting too many video samples into a single batch. Scheduler
# uses pruned vision tokens count to compare it versus compute
# budget which is incorrect (Either input media size or non-pruned
# output vision tokens count should be considered)
# dynamic res video for nemotron temporarily uses this hack via
# requires_sequential_video_encoding
# because it doesn't yet support video batching.
# TODO(ywang96): Fix memory profiling to take EVS into account and
# remove this hack.
if (
self.is_multimodal_pruning_enabled
(
self.is_multimodal_pruning_enabled
or self.requires_sequential_video_encoding
)
and modality == "video"
and num_items > 1
):
@@ -4609,6 +4616,9 @@ class GPUModelRunner(
and mm_config is not None
and mm_config.is_multimodal_pruning_enabled()
)
self.requires_sequential_video_encoding = hasattr(
self.get_model(), "requires_sequential_video_encoding"
) # Temporary hack for dynamic res video w/o support for bs>1 yet
if (
is_mixture_of_experts(self.model)