chunk parakeet into 30s clips to prevent OOMs on long audios (#36671)

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
This commit is contained in:
Netanel Haber
2026-03-18 23:22:24 +02:00
committed by GitHub
parent a913b612d8
commit 6ae4c8d6fc
3 changed files with 72 additions and 64 deletions

View File

@@ -99,15 +99,16 @@ MAX_AUDIO_LEN_S = 10 * 60 # 10 minutes
class NanoNemotronVLAudioFeatureInputs(TensorSchema):
    """
    Audio features for clip-chunked audio inputs.

    Dimensions:
        - c: Number of audio clips (possibly flattened across audio items)
        - b: Number of original audio items
        - t: Audio feature length
        - f: Feature size (mel bins)
    """

    type: Literal["audio_features"] = "audio_features"
    # Mel features, one row per clip (not per original audio item).
    input_audio_features: Annotated[torch.Tensor, TensorShape("c", "t", "f")]
    # 1 for valid feature frames, 0 for padding, per clip.
    feature_attention_mask: Annotated[torch.Tensor, TensorShape("c", "t")]
    # Number of clips contributed by each original audio item; sums to c.
    audio_num_clips: list[int]
class NanoNemotronVLImagePixelInputs(TensorSchema):
@@ -548,10 +549,17 @@ class NanoNemotronVLMultiModalProcessor(
video_fields = {}
if self.info.audio_extractor is not None:
audio_num_clips = torch.as_tensor(hf_inputs["audio_num_clips"])
audio_fields = dict(
input_audio_features=MultiModalFieldConfig.batched("audio"),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_clips
),
feature_attention_mask=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_clips
),
audio_num_clips=MultiModalFieldConfig.batched(
"audio", keep_on_cpu=True
),
)
else:
audio_fields = {}
@@ -1095,28 +1103,9 @@ class NemotronH_Nano_VL_V2(
assert self.sound_encoder is not None
input_audio_features = audio_input.input_audio_features
feature_attention_mask = audio_input.feature_attention_mask
audio_num_clips = audio_input.audio_num_clips
target_device = next(self.sound_encoder.parameters()).device
# When cross-request batching combines audio clips with different
# time dimensions, _reduce_data returns a list instead of a stacked
# tensor. Pad to the max time dim and stack; the attention mask
# already marks valid positions so zero-padding is safe.
if isinstance(input_audio_features, list):
feature_sizes = [f.shape[-2] for f in input_audio_features]
max_t = max(feature_sizes)
padded_feats = [
torch.nn.functional.pad(feat, (0, 0, 0, max_t - feat_size))
for feat, feat_size in zip(
input_audio_features, feature_sizes, strict=True
)
]
padded_masks = [
torch.nn.functional.pad(mask, (0, max_t - mask.shape[-1]))
for mask in feature_attention_mask
]
input_audio_features = torch.stack(padded_feats)
feature_attention_mask = torch.stack(padded_masks)
input_audio_features = input_audio_features.to(
dtype=self.llm_dtype, device=target_device
)
@@ -1126,13 +1115,18 @@ class NemotronH_Nano_VL_V2(
valid_input_lens = feature_attention_mask.sum(dim=1)
valid_output_lens = self.sound_encoder.encoder._get_subsampling_output_length(
valid_input_lens
)
truncated_embeds = []
for i in range(sound_embeds.shape[0]):
valid_len = valid_output_lens[i].item()
truncated_embeds.append(sound_embeds[i, :valid_len])
).tolist()
grouped_embeds = []
clip_offset = 0
for num_clips in audio_num_clips:
embeds = []
for clip_idx in range(clip_offset, clip_offset + num_clips):
valid_len = valid_output_lens[clip_idx]
embeds.append(sound_embeds[clip_idx, :valid_len])
grouped_embeds.append(torch.cat(embeds, dim=0))
clip_offset += num_clips
return tuple(truncated_embeds)
return tuple(grouped_embeds)
def _create_final_video_embeddings(
self,
@@ -1246,7 +1240,7 @@ class NemotronH_Nano_VL_V2(
in (
"input_audio_features",
"feature_attention_mask",
"audio_feature_lengths",
"audio_num_clips",
)
and "audios" not in modalities
):

View File

@@ -114,33 +114,50 @@ class ParakeetExtractor(ParakeetFeatureExtractor):
round(self.config.clip_min_duration_s * self.sampling_rate)
)
def _normalize_audio_length(self, audio_len: int) -> int:
# Match mcore's compute_params() logic for clip/minduration handling.
target_len = max(audio_len, self._tail_min_samples)
tail_remainder = target_len % self._clip_target_samples
if 0 < tail_remainder < self._tail_min_samples:
padding = self._tail_min_samples - tail_remainder
target_len += padding
assert isinstance(target_len, int)
return target_len
def _clip_sizes(self, audio_len: int) -> list[int]:
audio_len = max(audio_len, self._tail_min_samples)
num_full_clips, remainder = divmod(audio_len, self._clip_target_samples)
clip_sizes = [self._clip_target_samples] * num_full_clips
if remainder > 0:
clip_sizes.append(max(remainder, self._tail_min_samples))
return clip_sizes
def audio_token_count(self, audio_len: int) -> int:
    """Return the total number of encoder output tokens for `audio_len` samples.

    The audio is first partitioned into clips (see `_clip_sizes`); each clip
    is subsampled independently by the encoder, so per-clip token counts are
    summed. Always returns at least 1.
    """
    total_tokens = 0
    for clip_size in self._clip_sizes(audio_len):
        num_frames = clip_size // self.hop_length
        # NOTE(review): relies on the HF encoder's subsampling formula;
        # called unbound with `self` standing in for the encoder config.
        n_tokens = HFParakeetEncoder._get_subsampling_output_length(
            self, torch.tensor([num_frames], dtype=torch.float)
        )
        total_tokens += int(n_tokens.item())
    return max(1, total_tokens)
def split_audio_into_clips(self, audio: np.ndarray) -> list[np.ndarray]:
    """Cut a 1-D waveform into clips according to `self._clip_sizes`.

    The waveform is zero-padded at the end when the clip sizes sum to more
    than its length (the final clip may be widened to the tail minimum).
    """
    assert audio.ndim == 1
    num_samples = int(audio.shape[0])
    sizes = self._clip_sizes(num_samples)
    shortfall = sum(sizes) - num_samples
    if shortfall > 0:
        audio = np.pad(audio, (0, shortfall))
    # Consecutive [start, stop) boundaries derived from the clip sizes.
    boundaries = np.cumsum([0, *sizes])
    return [audio[start:stop] for start, stop in zip(boundaries[:-1], boundaries[1:])]
def __call__(self, raw_speech: list[np.ndarray], *args, **kwargs):
    """Chunk each waveform into clips and extract features over the flat clip list.

    Returns the base extractor's output dict, augmented with an
    `audio_num_clips` list recording how many clips each original audio item
    produced, so downstream code can regroup per-clip features per audio.
    """
    audio_clips: list[np.ndarray] = []
    audio_num_clips: list[int] = []
    for audio in raw_speech:
        clips = self.split_audio_into_clips(audio)
        audio_clips.extend(clips)
        audio_num_clips.append(len(clips))
    outputs = super().__call__(audio_clips, *args, **kwargs)
    outputs["audio_num_clips"] = audio_num_clips
    return outputs
def audio_length(self, audio_tokens: int) -> int:
return int(audio_tokens * self.config.subsampling_factor * self.hop_length)

View File

@@ -845,7 +845,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
audios: list[npt.NDArray],
) -> tuple[list[str], dict[str, Any]]:
if len(audios) == 0:
return text, {}
return text, {"audio_num_clips": []}
assert self.audio_extractor is not None
extractor = self.audio_extractor
@@ -869,13 +869,10 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
sampling_rate=extractor.sampling_rate,
return_tensors="pt",
)
input_audio_features = audio_inputs.input_features
feature_attention_mask = audio_inputs.attention_mask
audio_feature_lengths = feature_attention_mask.sum(dim=1)
audio_inputs = {
"input_audio_features": input_audio_features,
"feature_attention_mask": feature_attention_mask,
"audio_feature_lengths": audio_feature_lengths,
"input_audio_features": audio_inputs.input_features,
"feature_attention_mask": audio_inputs.attention_mask,
"audio_num_clips": audio_inputs.audio_num_clips,
}
return text, audio_inputs