diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 46cf7fe97..51b36b1ca 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -44,6 +44,7 @@ from vllm.model_executor.models.internvl import (
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
+from vllm.model_executor.models.parakeet import ParakeetExtractor, ProjectedParakeet
from vllm.model_executor.models.radio import RadioModel, calc_seq_lens
from vllm.model_executor.models.utils import (
init_vllm_registered_model,
@@ -55,12 +56,14 @@ from vllm.multimodal.evs import (
compute_retention_mask,
)
from vllm.multimodal.inputs import (
+ AudioItem,
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
VideoItem,
)
from vllm.multimodal.parse import (
+ AudioProcessorItems,
ImageEmbeddingItems,
ImageProcessorItems,
ImageSize,
@@ -91,9 +94,29 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
# Alternative: Set a specific higher limit
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
+
+class NanoNemotronVLAudioFeatureInputs(TensorSchema):
+ """
+ Dimensions:
+ - b: Number of audio clips
+ - t: Audio feature length
+ - f: Feature size (mel bins)
+ """
+
+ type: Literal["audio_features"] = "audio_features"
+ input_audio_features: Annotated[torch.Tensor, TensorShape("b", "t", "f")]
+ feature_attention_mask: Annotated[torch.Tensor, TensorShape("b", "t")]
+ audio_feature_lengths: Annotated[torch.Tensor, TensorShape("b")]
+
+
+MAX_AUDIO_LEN_S = 10 * 60 # 10 minutes
+
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<image>"
+# FIXME(review): the literal token strings below were stripped in transit
+# (angle-bracket tokens removed as HTML tags) — confirm exact values against
+# the model's tokenizer / chat template before merging.
+AUDIO_START = "<audio>"
+AUDIO_END = "</audio>"
+AUDIO_CONTEXT = "<audio_context>"
# Profiling
# MAX_FRAMES = 16
@@ -820,6 +843,11 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
self.video_token = video_token
self.video_pruning_rate = video_pruning_rate
+ self.audio_extractor: ParakeetExtractor | None = None
+ raw_sound_config = getattr(config, "sound_config", None)
+ if raw_sound_config is not None:
+ self.audio_extractor = ParakeetExtractor(raw_sound_config)
+
# Pre-tokenize special tokens for video processing
# to avoid repeated tokenization
self._img_start_token_ids = tokenizer.encode(
@@ -952,11 +980,53 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
text = [t.replace("