Support parakeet as audio encoder for nemotron-nano-vl (#35100)

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Netanel Haber
2026-02-27 20:07:38 +02:00
committed by GitHub
parent b602e4f299
commit c8aca0c9e1
3 changed files with 448 additions and 20 deletions

vllm/model_executor/models/nano_nemotron_vl.py

@@ -44,6 +44,7 @@ from vllm.model_executor.models.internvl import (
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
from vllm.model_executor.models.parakeet import ParakeetExtractor, ProjectedParakeet
from vllm.model_executor.models.radio import RadioModel, calc_seq_lens
from vllm.model_executor.models.utils import (
init_vllm_registered_model,
@@ -55,12 +56,14 @@ from vllm.multimodal.evs import (
compute_retention_mask,
)
from vllm.multimodal.inputs import (
AudioItem,
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
VideoItem,
)
from vllm.multimodal.parse import (
AudioProcessorItems,
ImageEmbeddingItems,
ImageProcessorItems,
ImageSize,
@@ -91,9 +94,29 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
# Alternative: Set a specific higher limit
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
class NanoNemotronVLAudioFeatureInputs(TensorSchema):
"""
Dimensions:
- b: Number of audio clips
- t: Audio feature length
- f: Feature size (mel bins)
"""
type: Literal["audio_features"] = "audio_features"
input_audio_features: Annotated[torch.Tensor, TensorShape("b", "t", "f")]
feature_attention_mask: Annotated[torch.Tensor, TensorShape("b", "t")]
audio_feature_lengths: Annotated[torch.Tensor, TensorShape("b")]
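# Illustrative sketch (not from the diff): the schema ties the three fields
# together. For an assumed batch of two clips with 128 mel bins:
#
#   input_audio_features: (2, t, 128) padded mel features
#   feature_attention_mask: (2, t), 1 over valid frames, 0 over padding
#   audio_feature_lengths: (2,), equal to feature_attention_mask.sum(dim=1)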
MAX_AUDIO_LEN_S = 10 * 60 # 10 minutes
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<image>"
AUDIO_START = "<so_start>"
AUDIO_END = "<so_end>"
AUDIO_CONTEXT = "<so_embedding>"
# Profiling
# MAX_FRAMES = 16
@@ -820,6 +843,11 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
self.video_token = video_token
self.video_pruning_rate = video_pruning_rate
self.audio_extractor: ParakeetExtractor | None = None
raw_sound_config = getattr(config, "sound_config", None)
if raw_sound_config is not None:
self.audio_extractor = ParakeetExtractor(raw_sound_config)
# Pre-tokenize special tokens for video processing
# to avoid repeated tokenization
self._img_start_token_ids = tokenizer.encode(
@@ -952,11 +980,53 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
text = [t.replace("<video>", video_repl_text, 1) for t in text]
return text, video_inputs
def _preprocess_audio(
self,
text: list[str],
audios: list[npt.NDArray],
):
if len(audios) == 0:
return text, {}
assert self.audio_extractor is not None
extractor = self.audio_extractor
parts = [x for x in re.split(f"({re.escape(AUDIO_CONTEXT)})", text[0]) if x]
token_count = parts.count(AUDIO_CONTEXT)
if token_count != len(audios):
raise ValueError(
"Number of audio tokens in text does not match the number "
f"of audios (tokens={token_count}, audios={len(audios)})."
)
audio_index = 0
for idx, part in enumerate(parts):
if part == AUDIO_CONTEXT:
audio_repl = self.get_audio_repl(audios[audio_index])
parts[idx] = audio_repl.full
audio_index += 1
text = ["".join(parts)]
audio_inputs = extractor(
audios,
sampling_rate=extractor.sampling_rate,
return_tensors="pt",
)
input_audio_features = audio_inputs.input_features
feature_attention_mask = audio_inputs.attention_mask
audio_feature_lengths = feature_attention_mask.sum(dim=1)
audio_inputs = {
"input_audio_features": input_audio_features,
"feature_attention_mask": feature_attention_mask,
"audio_feature_lengths": audio_feature_lengths,
}
return text, audio_inputs
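# Illustrative sketch (not from the diff): _preprocess_audio splits the prompt
# on <so_embedding>, checks that the placeholder count matches len(audios),
# then expands each placeholder in place. Assuming a clip whose features
# subsample to 3 tokens:
#
#   "Transcribe: <so_embedding>"
#   -> "Transcribe: <so_start><so_embedding><so_embedding><so_embedding><so_end>"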
def __call__(
self,
text: str | list[str] | None = None,
images: Image.Image | list[Image.Image] | None = None,
videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None,
audios: AudioItem | list[AudioItem] | None = None,
return_tensors: str | TensorType | None = None,
max_num_tiles: int | None = None,
) -> BatchFeature:
@@ -964,8 +1034,8 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
if max_num_tiles is None:
max_num_tiles = self.max_num_tiles
text, images, videos = [
self._make_batch_input(x) for x in (text, images, videos)
text, images, videos, audios = [
self._make_batch_input(x) for x in (text, images, videos, audios)
]
text, image_inputs = self._preprocess_image(
@@ -980,17 +1050,22 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
max_num_tiles=1,
)
text, audio_inputs = self._preprocess_audio(
text=text,
audios=audios,
)
text_inputs = self.tokenizer(text, add_special_tokens=False)
combined_inputs = {**text_inputs, **video_inputs, **audio_inputs}
if self.dynamic_tiler is None:
batch = BatchFeature(
{**text_inputs, **video_inputs, **image_inputs},
{**combined_inputs, **image_inputs},
tensor_type=return_tensors,
)
else:
batch = BatchFeature(
{**text_inputs, **video_inputs}, tensor_type=return_tensors
)
batch = BatchFeature(combined_inputs, tensor_type=return_tensors)
# allow images to be exempt from the BatchFeature validation:
# We will .stack() them in _parse_and_validate_image_input
batch.update(image_inputs)
@@ -1006,6 +1081,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
def get_audio_repl(
self,
audio: npt.NDArray,
) -> PromptUpdateDetails[str]:
assert self.audio_extractor is not None
num_tokens = self.audio_extractor.audio_token_count(len(audio))
repl_full = f"{AUDIO_START}{AUDIO_CONTEXT * num_tokens}{AUDIO_END}"
return PromptUpdateDetails.select_text(repl_full, AUDIO_CONTEXT)
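# Illustrative sketch (not from the diff): for a clip yielding 3 tokens,
#
#   get_audio_repl(audio).full == f"{AUDIO_START}{AUDIO_CONTEXT * 3}{AUDIO_END}"
#
# and select_text marks only the <so_embedding> occurrences as embedding
# positions; <so_start> and <so_end> remain ordinary text tokens.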
@classmethod
def get_video_repl(
cls,
@@ -1147,15 +1231,28 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
def supports_video(self):
return self.get_hf_processor().supports_video
@property
def audio_extractor(self) -> ParakeetExtractor | None:
return self.get_hf_processor().audio_extractor
def get_data_parser(self):
target_sr = None
target_channels = None
if extractor := self.audio_extractor:
target_sr = extractor.sampling_rate
target_channels = 1
return MultiModalDataParser(
video_needs_metadata=True,
target_sr=target_sr,
target_channels=target_channels,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self):
video_limit = {"video": None} if self.supports_video else {}
return {**super().get_supported_mm_limits(), **video_limit}
audio_limit = {"audio": None} if self.audio_extractor is not None else {}
return {**super().get_supported_mm_limits(), **video_limit, **audio_limit}
def get_video_token(self) -> str | None:
return IMG_CONTEXT
@@ -1304,7 +1401,16 @@ class NanoNemotronVLMultiModalProcessor(
else:
video_fields = {}
return image_fields | video_fields
if self.info.audio_extractor is not None:
audio_fields = dict(
input_audio_features=MultiModalFieldConfig.batched("audio"),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
)
else:
audio_fields = {}
return image_fields | video_fields | audio_fields
def _get_prompt_updates(
self,
@@ -1373,6 +1479,20 @@ class NanoNemotronVLMultiModalProcessor(
),
]
def get_audio_replacement(item_idx: int):
audios = mm_items.get_items("audio", AudioProcessorItems)
return hf_processor.get_audio_repl(audios.get(item_idx))
if self.info.audio_extractor is not None:
prompt_repl = [
*prompt_repl,
PromptReplacement(
modality="audio",
target=AUDIO_CONTEXT,
replacement=get_audio_replacement,
),
]
return prompt_repl
@@ -1422,8 +1542,13 @@ class NanoNemotronVLDummyInputsBuilder(
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_videos = mm_counts.get("video", 0)
num_audios = mm_counts.get("audio", 0)
return super().get_dummy_text(mm_counts) + "<video>" * num_videos
return (
super().get_dummy_text(mm_counts)
+ "<video>" * num_videos
+ AUDIO_CONTEXT * num_audios
)
def _get_dummy_videos(
self,
@@ -1482,7 +1607,25 @@ class NanoNemotronVLDummyInputsBuilder(
}
else:
dummy_video = {}
return {**dummy_image, **dummy_video}
if extractor := self.info.audio_extractor:
num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None
tokens_per_audio = max(1, seq_len // max(num_audios, 1))
max_audio_num_samples = MAX_AUDIO_LEN_S * extractor.sampling_rate
calculated_max_audio_num_samples = extractor.audio_length(tokens_per_audio)
audio_len = min(max_audio_num_samples, calculated_max_audio_num_samples)
dummy_audio = {
"audio": self._get_dummy_audios(
length=audio_len,
num_audios=num_audios,
overrides=audio_overrides,
)
}
else:
dummy_audio = {}
return {**dummy_image, **dummy_video, **dummy_audio}
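# Illustrative sketch (not from the diff), with assumed extractor values
# (16 kHz sampling rate, hop_length=160, subsampling_factor=8) for
# seq_len=8192 and num_audios=2:
#
#   tokens_per_audio = max(1, 8192 // 2) = 4096
#   calculated = audio_length(4096) = 4096 * 8 * 160 = 5_242_880 samples (~328 s)
#   cap = MAX_AUDIO_LEN_S * 16_000 = 9_600_000 samples (600 s)
#   audio_len = min(9_600_000, 5_242_880) = 5_242_880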
@MULTIMODAL_REGISTRY.register_processor(
@@ -1499,12 +1642,15 @@ class NemotronH_Nano_VL_V2(
return "<image>"
if modality.startswith("video"):
return "<video>"
if modality.startswith("audio"):
return AUDIO_CONTEXT
return None
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
model_config = vllm_config.model_config
config = model_config.hf_config
multimodal_config = model_config.multimodal_config
image_size = config.force_image_size
patch_size = config.patch_size
self.patch_size = patch_size
@@ -1523,10 +1669,12 @@ class NemotronH_Nano_VL_V2(
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_tower_model(vllm_config, {"image", "video"}):
llm_dtype = self.language_model.config.dtype
assert isinstance(llm_dtype, torch.dtype)
self.llm_dtype = llm_dtype
with self._mark_tower_model(vllm_config, {"image", "video", "audio"}):
self.vision_model = self.get_vit_model_from_radio_config(config).to(
self.language_model.config.dtype
llm_dtype
)
# Construct the vision projection.
@@ -1547,14 +1695,26 @@ class NemotronH_Nano_VL_V2(
ReLUSquaredActivation(),
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
)
self.mlp1 = mlp1.to(self.language_model.config.dtype)
self.mlp1 = mlp1.to(llm_dtype)
self.sound_encoder: ProjectedParakeet | None = None
if getattr(config, "sound_config", None) is not None:
logger.info_once(
"Found sound config, initializing sound encoder for Nemotron AVLM",
scope="global",
)
self.sound_encoder = ProjectedParakeet(
config.sound_config,
dtype=llm_dtype,
llm_hidden_size=llm_hidden_size,
max_model_len=model_config.max_model_len,
)
self.config = config
self.model_config = vllm_config.model_config
# Pre-tokenize special tokens for video processing
# to avoid repeated tokenization
tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
tokenizer = cached_tokenizer_from_config(model_config)
self._img_start_token_ids = tokenizer.encode(
IMG_START, add_special_tokens=False
)
@@ -1566,7 +1726,10 @@ class NemotronH_Nano_VL_V2(
config
)
if self.dynamic_resolution:
logger.info("Dynamic resolution is enabled for NanoNemotronVLProcessor")
logger.info_once(
"Dynamic resolution is enabled for NanoNemotronVLProcessor",
scope="global",
)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
@@ -1780,6 +1943,51 @@ class NemotronH_Nano_VL_V2(
return final_video_embeddings
def _process_audio_input(
self, audio_input: NanoNemotronVLAudioFeatureInputs
) -> tuple[torch.Tensor, ...]:
assert self.sound_encoder is not None
input_audio_features = audio_input.input_audio_features
feature_attention_mask = audio_input.feature_attention_mask
target_device = next(self.sound_encoder.parameters()).device
# When cross-request batching combines audio clips with different
# time dimensions, _reduce_data returns a list instead of a stacked
# tensor. Pad to the max time dim and stack; the attention mask
# already marks valid positions so zero-padding is safe.
if isinstance(input_audio_features, list):
feature_sizes = [f.shape[-2] for f in input_audio_features]
max_t = max(feature_sizes)
padded_feats = [
torch.nn.functional.pad(feat, (0, 0, 0, max_t - feat_size))
for feat, feat_size in zip(
input_audio_features, feature_sizes, strict=True
)
]
padded_masks = [
torch.nn.functional.pad(mask, (0, max_t - mask.shape[-1]))
for mask in feature_attention_mask
]
input_audio_features = torch.stack(padded_feats)
feature_attention_mask = torch.stack(padded_masks)
input_audio_features = input_audio_features.to(
dtype=self.llm_dtype, device=target_device
)
feature_attention_mask = feature_attention_mask.to(device=target_device)
sound_embeds = self.sound_encoder(input_audio_features, feature_attention_mask)
valid_input_lens = feature_attention_mask.sum(dim=1)
valid_output_lens = self.sound_encoder.encoder._get_subsampling_output_length(
valid_input_lens
)
truncated_embeds = []
for i in range(sound_embeds.shape[0]):
valid_len = valid_output_lens[i].item()
truncated_embeds.append(sound_embeds[i, :valid_len])
return tuple(truncated_embeds)
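# Illustrative sketch (not from the diff): two clips of 100 and 60 feature
# frames arriving as a list are padded to a common time dim before the
# encoder runs; the masks keep the padding inert, and each output is then
# truncated back to its valid subsampled length:
#
#   feats = [torch.randn(100, 128), torch.randn(60, 128)]
#   max_t = max(f.shape[-2] for f in feats)  # 100
#   padded = torch.stack(
#       [torch.nn.functional.pad(f, (0, 0, 0, max_t - f.shape[-2])) for f in feats]
#   )  # -> shape (2, 100, 128)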
def _create_final_video_embeddings(
self,
video_embeddings: torch.Tensor,
@@ -1887,6 +2095,18 @@ class NemotronH_Nano_VL_V2(
modalities["images"] = self._parse_and_validate_image_input(**kwargs)
if input_key in ("pixel_values_flat_video",) and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
if (
input_key
in (
"input_audio_features",
"feature_attention_mask",
"audio_feature_lengths",
)
and "audios" not in modalities
):
modalities["audios"] = NanoNemotronVLAudioFeatureInputs(
**kwargs, validate=False
)
return modalities
@@ -1917,6 +2137,10 @@ class NemotronH_Nano_VL_V2(
video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += tuple(video_embeddings)
if modality == "audios":
audio_input = modalities["audios"]
audio_embeddings = self._process_audio_input(audio_input)
multimodal_embeddings += tuple(audio_embeddings)
return multimodal_embeddings
@@ -1947,8 +2171,8 @@ class NemotronH_Nano_VL_V2(
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="mlp1",
tower_model="vision_model",
connector=["mlp1", "sound_encoder.projection"],
tower_model=["vision_model", "sound_encoder.encoder"],
)
def compute_logits(
@@ -1969,9 +2193,13 @@ class NemotronH_Nano_VL_V2(
def is_vision_weights(name: str) -> bool:
return name.startswith("vision_model.radio_model.")
def is_sound_weights(name: str) -> bool:
return name.startswith("sound")
# Separate weights by component
llm_weights = []
vision_weights = []
sound_weights = []
for name, w in weights:
if is_llm(name):
@@ -1987,9 +2215,15 @@ class NemotronH_Nano_VL_V2(
# Convert: vision_model.radio_model.* → radio_model.*
hf_key = name[len("vision_model.") :] # Remove "vision_model." prefix
vision_weights.append((hf_key, w))
elif is_sound_weights(name):
assert self.sound_encoder is not None
sound_weights.append((name, w))
self.language_model.load_weights(llm_weights)
self.vision_model.load_weights(vision_weights)
if self.sound_encoder is not None:
assert len(sound_weights) > 0
self.sound_encoder.load_weights(sound_weights)
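# Illustrative sketch (not from the diff; layer names are hypothetical):
# checkpoint names route by prefix, e.g.
#
#   "language_model.model.layers.0..."      -> language_model.load_weights
#   "vision_model.radio_model.blocks.0..."  -> vision_model ("vision_model." stripped)
#   "sound_encoder.encoder.layers.0..."     -> sound_encoder.load_weights
#   "sound_projection.linear1.weight"       -> sound_encoder.load_weights
#
# is_sound_weights matches on "sound", so both the "sound_encoder." and
# "sound_projection." prefixes land in sound_weights.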
def print_architecture(self, detailed: bool = True, save_to_file: str = None):
"""

vllm/model_executor/models/parakeet.py

@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Modules used by the audio encoder component in models/nano_nemotron_vl.py.
"""
from collections.abc import Iterable
from dataclasses import asdict
import numpy as np
import torch
import torch.nn as nn
from transformers import ParakeetEncoder as HFParakeetEncoder
from transformers import ParakeetFeatureExtractor, PretrainedConfig
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs.parakeet import ExtractorConfig, ParakeetConfig
class ParakeetProjection(nn.Module):
def __init__(self, config: ParakeetConfig) -> None:
super().__init__()
sound_hidden_size = config.hidden_size
proj_hidden_size = config.projection_hidden_size
llm_hidden_size = config.llm_hidden_size
bias = config.projection_bias
self.norm = nn.LayerNorm(sound_hidden_size, eps=config.projection_eps)
self.linear1 = nn.Linear(sound_hidden_size, proj_hidden_size, bias=bias)
self.activation = ReLUSquaredActivation()
self.linear2 = nn.Linear(proj_hidden_size, llm_hidden_size, bias=bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.norm(hidden_states)
hidden_states = self.linear1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.linear2(hidden_states)
return hidden_states
class ProjectedParakeet(nn.Module):
def __init__(
self,
config: PretrainedConfig,
*,
dtype: torch.dtype,
llm_hidden_size: int,
max_model_len: int,
) -> None:
super().__init__()
self.config = ParakeetConfig.from_hf_config(
config, llm_hidden_size=llm_hidden_size, max_model_len=max_model_len
)
self.encoder = HFParakeetEncoder(self.config)
self.encoder = self.encoder.to(dtype)
self.projection = ParakeetProjection(self.config)
self.projection = self.projection.to(dtype)
def forward(
self, input_features: torch.Tensor, attention_mask: torch.Tensor | None = None
) -> torch.Tensor:
outputs = self.encoder(
input_features=input_features, attention_mask=attention_mask
)
outputs = outputs.last_hidden_state
outputs = outputs.to(dtype=torch.bfloat16)
outputs = self.projection(outputs)
return outputs
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loaded_params: set[str] = set()
params_dict = dict(self.named_parameters())
buffers_dict = dict(self.named_buffers())
if isinstance(weights, dict):
weights_list = list(weights.items())
else:
weights_list = list(weights)
for name, weight in weights_list:
if name.startswith("sound_encoder.encoder.feature_extractor."):
# Feature extractor buffers are handled outside the encoder.
continue
if name.startswith("sound_encoder."):
target_name = name[len("sound_encoder.") :]
elif name.startswith("sound_projection."):
target_name = f"projection.{name[len('sound_projection.') :]}"
else:
continue
target = params_dict.get(target_name)
if target is None:
target = buffers_dict.get(target_name)
if target is None:
raise ValueError(f"Unknown weight: {name}")
weight_loader = getattr(target, "weight_loader", default_weight_loader)
with torch.no_grad():
weight_loader(target, weight)
loaded_params.add(target_name)
return loaded_params
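# Illustrative sketch (not from the diff; layer names are hypothetical):
# checkpoint keys map onto this module's parameters as
#
#   "sound_encoder.encoder.layers.0.self_attn.q_proj.weight"
#       -> "encoder.layers.0.self_attn.q_proj.weight"
#   "sound_projection.linear1.weight" -> "projection.linear1.weight"
#   "sound_encoder.encoder.feature_extractor.*" -> skipped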
class ParakeetExtractor(ParakeetFeatureExtractor):
def __init__(self, config: PretrainedConfig) -> None:
self.config = ExtractorConfig.from_hf_config(config)
super().__init__(**asdict(self.config))
self._clip_target_samples = int(
round(self.config.clip_duration_s * self.sampling_rate)
)
self._tail_min_samples = int(
round(self.config.clip_min_duration_s * self.sampling_rate)
)
def _normalize_audio_length(self, audio_len: int) -> int:
# Match mcore's compute_params() logic for clip/min-duration handling.
target_len = max(audio_len, self._tail_min_samples)
tail_remainder = target_len % self._clip_target_samples
if 0 < tail_remainder < self._tail_min_samples:
padding = self._tail_min_samples - tail_remainder
target_len += padding
assert isinstance(target_len, int)
return target_len
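# Illustrative sketch (not from the diff), assuming the defaults
# clip_duration_s=30 and clip_min_duration_s=0.1 at an assumed 16 kHz
# (clip=480_000 samples, tail_min=1_600 samples):
#
#   audio_len = 480_500 -> remainder 500 < 1_600, padded to 481_600
#   audio_len = 485_000 -> remainder 5_000 >= 1_600, unchanged
#   audio_len = 1_000   -> below tail_min, raised to 1_600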
def audio_token_count(self, audio_len: int) -> int:
audio_len = self._normalize_audio_length(audio_len)
num_frames = audio_len // self.hop_length
n_tokens = HFParakeetEncoder._get_subsampling_output_length(
self, torch.tensor([num_frames], dtype=torch.float)
)
return max(1, int(n_tokens.item()))
def __call__(self, raw_speech: list[np.ndarray], *args, **kwargs):
padded = []
for p in raw_speech:
assert p.ndim == 1
audio_len = int(p.shape[0])
target_len = self._normalize_audio_length(audio_len)
p = np.pad(p, (0, target_len - audio_len))
padded.append(p)
return super().__call__(padded, *args, **kwargs)
def audio_length(self, audio_tokens: int) -> int:
return int(audio_tokens * self.config.subsampling_factor * self.hop_length)
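# Illustrative sketch (not from the diff): audio_length is the coarse inverse
# of audio_token_count. With assumed hop_length=160 and subsampling_factor=8,
# one token spans 8 * 160 = 1_280 samples, so audio_length(100) == 128_000
# samples (8 s at 16 kHz), and audio_token_count(128_000) is ~100 modulo
# clip padding and convolutional edge effects.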

vllm/transformers_utils/configs/parakeet.py

@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from transformers import ParakeetEncoderConfig, PretrainedConfig
class ParakeetConfig(ParakeetEncoderConfig):
llm_hidden_size: int
projection_hidden_size: int
projection_bias: bool
projection_eps: float = 1e-5
sampling_rate: int
@staticmethod
def from_hf_config(
config: PretrainedConfig, *, llm_hidden_size: int, max_model_len: int
) -> "ParakeetConfig":
assert isinstance(config, PretrainedConfig)
return ParakeetConfig(
**config.to_dict(),
scale_input=False,
attention_bias=False,
llm_hidden_size=llm_hidden_size,
max_position_embeddings=max_model_len
+ 1,  # +1: position indices up to and including max_model_len may be passed
)
@dataclass(kw_only=True, frozen=True)
class ExtractorConfig:
feature_size: int
sampling_rate: int
subsampling_factor: int
subsampling_conv_kernel_size: int
subsampling_conv_stride: int
clip_duration_s: int = 30
clip_min_duration_s: float = 0.1
@staticmethod
def from_hf_config(config: PretrainedConfig) -> "ExtractorConfig":
assert isinstance(config, PretrainedConfig)
return ExtractorConfig(
feature_size=config.num_mel_bins,
sampling_rate=config.sampling_rate,
subsampling_factor=config.subsampling_factor,
subsampling_conv_kernel_size=config.subsampling_conv_kernel_size,
subsampling_conv_stride=config.subsampling_conv_stride,
)
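# Illustrative sketch (not from the diff; values are hypothetical): for a
# sound_config with num_mel_bins=128, sampling_rate=16_000,
# subsampling_factor=8, subsampling_conv_kernel_size=3 and
# subsampling_conv_stride=2, from_hf_config yields
#
#   ExtractorConfig(feature_size=128, sampling_rate=16_000,
#                   subsampling_factor=8, subsampling_conv_kernel_size=3,
#                   subsampling_conv_stride=2, clip_duration_s=30,
#                   clip_min_duration_s=0.1)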