Fix Nano Nemotron VL regressions (#38655)

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
This commit is contained in:
Netanel Haber
2026-04-03 10:22:06 +03:00
committed by GitHub
parent 5506435419
commit fa9e68022d
7 changed files with 84 additions and 52 deletions

View File

@@ -7,7 +7,6 @@
# LICENSE is in root directory.
# --------------------------------------------------------
import copy
import math
import warnings
from collections.abc import Iterable, Mapping, Sequence
@@ -17,7 +16,7 @@ from typing import Annotated, Literal, TypeAlias
import torch
import torch.nn as nn
from transformers import BatchFeature
from transformers import BatchFeature, PretrainedConfig
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
@@ -210,11 +209,15 @@ class NanoNemotronVLProcessingInfo(BaseProcessingInfo):
@cached_property
def is_dynamic_tiler(self) -> bool:
return self.get_hf_processor().dynamic_tiler is not None
return BaseNanoNemotronVLProcessor.use_dynamic_resolution(self.get_hf_config())
@cached_property
@property
def supports_video(self):
return self.get_hf_processor().supports_video
return True
@property
def supports_audio(self) -> bool:
return self.sound_config is not None
def get_video_token(self) -> str | None:
return IMG_CONTEXT
@@ -223,8 +226,8 @@ class NanoNemotronVLProcessingInfo(BaseProcessingInfo):
return self.ctx.get_mm_config().video_pruning_rate
@property
def audio_extractor(self) -> ParakeetExtractor | None:
return self.get_hf_processor().audio_extractor
def sound_config(self) -> PretrainedConfig | None:
return getattr(self.get_hf_config(), "sound_config", None)
def get_default_tok_params(self) -> TokenizeParams:
return super().get_default_tok_params().with_kwargs(add_special_tokens=False)
@@ -232,14 +235,14 @@ class NanoNemotronVLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
image_limit = {"image": None}
video_limit = {"video": None} if self.supports_video else {}
audio_limit = {"audio": None} if self.audio_extractor is not None else {}
audio_limit = {"audio": None} if self.supports_audio else {}
return {**image_limit, **video_limit, **audio_limit}
def get_data_parser(self):
target_sr = None
target_channels = None
if extractor := self.audio_extractor:
target_sr = extractor.sampling_rate
if self.sound_config:
target_sr = self.sound_config.sampling_rate
target_channels = 1
return MultiModalDataParser(
@@ -371,7 +374,7 @@ class NanoNemotronVLMultiModalProcessor(
fields = self._get_image_fields_config(hf_inputs)
if self.info.supports_video:
fields |= self._get_video_fields_config(hf_inputs)
if self.info.audio_extractor:
if self.info.supports_audio:
fields |= self._get_audio_fields_config(hf_inputs)
return fields
@@ -399,9 +402,8 @@ class NanoNemotronVLMultiModalProcessor(
if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx)
elif tiler := hf_processor.dynamic_tiler:
image = images.get(item_idx)
feature_size = tiler.get_cached_feature_size(image)
elif self.info.is_dynamic_tiler:
feature_size = out_mm_data["num_tokens_per_image"][item_idx]
else:
image_size = images.get_image_size(item_idx)
max_num_tiles = hf_processor.max_num_tiles
@@ -536,7 +538,7 @@ class NanoNemotronVLMultiModalProcessor(
prompt_repls.append(
self._get_prompt_repl_video(mm_items, hf_processor, out_mm_data)
)
if self.info.audio_extractor:
if self.info.supports_audio:
prompt_repls.append(
self._get_prompt_repl_audio(mm_items, hf_processor, out_mm_data)
)
@@ -772,12 +774,14 @@ class NanoNemotronVLDummyInputsBuilder(
else:
dummy_video = {}
if extractor := self.info.audio_extractor:
if sound_config := self.info.sound_config:
num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None
tokens_per_audio = max(1, seq_len // max(num_audios, 1))
max_audio_num_samples = MAX_AUDIO_LEN_S * extractor.sampling_rate
calculated_max_audio_num_samples = extractor.audio_length(tokens_per_audio)
max_audio_num_samples = MAX_AUDIO_LEN_S * sound_config.sampling_rate
calculated_max_audio_num_samples = ParakeetExtractor.audio_length(
sound_config, tokens_per_audio
)
audio_len = min(max_audio_num_samples, calculated_max_audio_num_samples)
dummy_audio = {
"audio": self._get_dummy_audios(
@@ -1029,9 +1033,13 @@ class NemotronH_Nano_VL_V2(
data=image_embeds,
)
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
if pixel_values_flat is None:
return None
if self.dynamic_resolution:
pixel_values_flat = DynamicResolutionImageTiler.stack(
kwargs.pop("pixel_values_flat"), self.patch_size
pixel_values_flat, self.patch_size
)
return NanoNemotronVLImagePixelInputsDynamic(
pixel_values_flat=pixel_values_flat, **kwargs
@@ -1497,15 +1505,13 @@ class NemotronH_Nano_VL_V2(
@classmethod
def get_mamba_state_shape_from_config(cls, vllm_config: "VllmConfig"):
text_config = vllm_config.model_config.hf_config.text_config
temp_vllm_config = copy.deepcopy(vllm_config)
temp_vllm_config.model_config.hf_config = text_config
temp_vllm_config = vllm_config.with_hf_config(text_config)
return NemotronHForCausalLM.get_mamba_state_shape_from_config(temp_vllm_config)
@classmethod
def get_mamba_state_dtype_from_config(cls, vllm_config: "VllmConfig"):
text_config = vllm_config.model_config.hf_config.text_config
temp_vllm_config = copy.deepcopy(vllm_config)
temp_vllm_config.model_config.hf_config = text_config
temp_vllm_config = vllm_config.with_hf_config(text_config)
return NemotronHForCausalLM.get_mamba_state_dtype_from_config(temp_vllm_config)
@classmethod

View File

@@ -159,5 +159,7 @@ class ParakeetExtractor(ParakeetFeatureExtractor):
outputs["audio_num_clips"] = audio_num_clips
return outputs
def audio_length(self, audio_tokens: int) -> int:
return int(audio_tokens * self.config.subsampling_factor * self.hop_length)
@staticmethod
def audio_length(raw_config: PretrainedConfig, audio_tokens: int) -> int:
config = ExtractorConfig.from_hf_config(raw_config)
return int(audio_tokens * config.subsampling_factor * config.hop_length)

View File

@@ -176,7 +176,6 @@ class ViTPatchGenerator(nn.Module):
temporal_patch_size=temporal_patch_size,
**factory,
)
self._video_embedder_loaded = False
if abs_pos:
scale = embed_dim**-0.5
@@ -225,12 +224,7 @@ class ViTPatchGenerator(nn.Module):
Returns:
Embedded patches with temporal compression applied.
"""
if not self._video_embedder_loaded:
raise ValueError(
"Temporal compression (video_temporal_patch_size > 1) requires "
"video_embedder weights, but they were never loaded. "
"Ensure the checkpoint was trained with temporal compression."
)
assert self.temporal_patch_size > 1
T = self.temporal_patch_size
input_size = x.shape[2:]
@@ -794,9 +788,6 @@ class RadioModel(nn.Module):
weight_loader(param, weight)
loaded_params.add(vllm_key)
if "model.patch_generator.video_embedder.weight" in loaded_params:
self.model.patch_generator._video_embedder_loaded = True
return loaded_params
def _extract_final(