Support parakeet as audio encoder for nemotron-nano-vl (#35100)
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -44,6 +44,7 @@ from vllm.model_executor.models.internvl import (
|
||||
)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
|
||||
from vllm.model_executor.models.parakeet import ParakeetExtractor, ProjectedParakeet
|
||||
from vllm.model_executor.models.radio import RadioModel, calc_seq_lens
|
||||
from vllm.model_executor.models.utils import (
|
||||
init_vllm_registered_model,
|
||||
@@ -55,12 +56,14 @@ from vllm.multimodal.evs import (
|
||||
compute_retention_mask,
|
||||
)
|
||||
from vllm.multimodal.inputs import (
|
||||
AudioItem,
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
VideoItem,
|
||||
)
|
||||
from vllm.multimodal.parse import (
|
||||
AudioProcessorItems,
|
||||
ImageEmbeddingItems,
|
||||
ImageProcessorItems,
|
||||
ImageSize,
|
||||
@@ -91,9 +94,29 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
|
||||
# Alternative: Set a specific higher limit
|
||||
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
|
||||
|
||||
|
||||
class NanoNemotronVLAudioFeatureInputs(TensorSchema):
    """
    Schema for batched audio feature inputs (mel features produced by the
    Parakeet feature extractor) fed to the sound encoder.

    Dimensions:
        - b: Number of audio clips
        - t: Audio feature length
        - f: Feature size (mel bins)
    """

    # Discriminator identifying this input variant when parsing kwargs.
    type: Literal["audio_features"] = "audio_features"
    # Padded per-clip features, shape (b, t, f).
    input_audio_features: Annotated[torch.Tensor, TensorShape("b", "t", "f")]
    # 1 for valid time steps, 0 for padding, shape (b, t).
    feature_attention_mask: Annotated[torch.Tensor, TensorShape("b", "t")]
    # Valid feature length per clip (sum of the mask rows), shape (b,).
    audio_feature_lengths: Annotated[torch.Tensor, TensorShape("b")]
|
||||
|
||||
|
||||
# Cap on a single audio clip's duration, used when sizing dummy/profiling
# audio inputs.
MAX_AUDIO_LEN_S = 10 * 60  # 10 minutes

# Special tokens delimiting image content in the prompt.
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<image>"
# Special tokens delimiting audio ("sound") content in the prompt.
AUDIO_START = "<so_start>"
AUDIO_END = "<so_end>"
AUDIO_CONTEXT = "<so_embedding>"
|
||||
|
||||
# Profiling
|
||||
# MAX_FRAMES = 16
|
||||
@@ -820,6 +843,11 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
self.video_token = video_token
|
||||
self.video_pruning_rate = video_pruning_rate
|
||||
|
||||
self.audio_extractor: ParakeetExtractor | None = None
|
||||
raw_sound_config = getattr(config, "sound_config", None)
|
||||
if raw_sound_config is not None:
|
||||
self.audio_extractor = ParakeetExtractor(raw_sound_config)
|
||||
|
||||
# Pre-tokenize special tokens for video processing
|
||||
# to avoid repeated tokenization
|
||||
self._img_start_token_ids = tokenizer.encode(
|
||||
@@ -952,11 +980,53 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
text = [t.replace("<video>", video_repl_text, 1) for t in text]
|
||||
return text, video_inputs
|
||||
|
||||
    def _preprocess_audio(
        self,
        text: list[str],
        audios: list[npt.NDArray],
    ):
        """Expand audio placeholders in the prompt and extract audio features.

        Each ``AUDIO_CONTEXT`` occurrence in the prompt is replaced with the
        full per-clip replacement string (start token + N context tokens +
        end token), then the Parakeet extractor is run over the raw waveforms.

        Args:
            text: Batched prompt strings. NOTE(review): only ``text[0]`` is
                scanned and rewritten here — this assumes a single prompt per
                call; confirm against ``__call__``'s batching behavior.
            audios: Raw mono waveforms, one per ``AUDIO_CONTEXT`` token.

        Returns:
            The rewritten one-element text list and a dict with the extracted
            feature tensors (empty dict when ``audios`` is empty).

        Raises:
            ValueError: If the number of ``AUDIO_CONTEXT`` tokens in the
                prompt does not match ``len(audios)``.
        """
        if len(audios) == 0:
            return text, {}
        assert self.audio_extractor is not None

        extractor = self.audio_extractor

        # Split while keeping the delimiter (capturing group), dropping empty
        # fragments so counting and in-place replacement are straightforward.
        parts = [x for x in re.split(f"({re.escape(AUDIO_CONTEXT)})", text[0]) if x]
        token_count = parts.count(AUDIO_CONTEXT)
        if token_count != len(audios):
            raise ValueError(
                "Number of audio tokens in text does not match the number "
                f"of audios (tokens={token_count}, audios={len(audios)})."
            )
        audio_index = 0
        for idx, part in enumerate(parts):
            if part == AUDIO_CONTEXT:
                # Replace the bare placeholder with the clip-specific
                # replacement, sized from that clip's length.
                audio_repl = self.get_audio_repl(audios[audio_index])
                parts[idx] = audio_repl.full
                audio_index += 1
        text = ["".join(parts)]
        audio_inputs = extractor(
            audios,
            sampling_rate=extractor.sampling_rate,
            return_tensors="pt",
        )
        input_audio_features = audio_inputs.input_features
        feature_attention_mask = audio_inputs.attention_mask
        # Valid (non-padding) feature length per clip.
        audio_feature_lengths = feature_attention_mask.sum(dim=1)
        audio_inputs = {
            "input_audio_features": input_audio_features,
            "feature_attention_mask": feature_attention_mask,
            "audio_feature_lengths": audio_feature_lengths,
        }

        return text, audio_inputs
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: str | list[str] | None = None,
|
||||
images: Image.Image | list[Image.Image] | None = None,
|
||||
videos: list[tuple[npt.NDArray, dict[str, Any]]] | None = None,
|
||||
audios: AudioItem | list[AudioItem] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
max_num_tiles: int | None = None,
|
||||
) -> BatchFeature:
|
||||
@@ -964,8 +1034,8 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
if max_num_tiles is None:
|
||||
max_num_tiles = self.max_num_tiles
|
||||
|
||||
text, images, videos = [
|
||||
self._make_batch_input(x) for x in (text, images, videos)
|
||||
text, images, videos, audios = [
|
||||
self._make_batch_input(x) for x in (text, images, videos, audios)
|
||||
]
|
||||
|
||||
text, image_inputs = self._preprocess_image(
|
||||
@@ -980,17 +1050,22 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
max_num_tiles=1,
|
||||
)
|
||||
|
||||
text, audio_inputs = self._preprocess_audio(
|
||||
text=text,
|
||||
audios=audios,
|
||||
)
|
||||
|
||||
text_inputs = self.tokenizer(text, add_special_tokens=False)
|
||||
|
||||
combined_inputs = {**text_inputs, **video_inputs, **audio_inputs}
|
||||
|
||||
if self.dynamic_tiler is None:
|
||||
batch = BatchFeature(
|
||||
{**text_inputs, **video_inputs, **image_inputs},
|
||||
{**combined_inputs, **image_inputs},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
else:
|
||||
batch = BatchFeature(
|
||||
{**text_inputs, **video_inputs}, tensor_type=return_tensors
|
||||
)
|
||||
batch = BatchFeature(combined_inputs, tensor_type=return_tensors)
|
||||
# allow images to be exempt from the BatchFeature validation:
|
||||
# We will .stack() them in _parse_and_validate_image_input
|
||||
batch.update(image_inputs)
|
||||
@@ -1006,6 +1081,15 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
|
||||
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
|
||||
|
||||
def get_audio_repl(
|
||||
self,
|
||||
audio: npt.NDArray,
|
||||
) -> PromptUpdateDetails[str]:
|
||||
assert self.audio_extractor is not None
|
||||
num_tokens = self.audio_extractor.audio_token_count(len(audio))
|
||||
repl_full = f"{AUDIO_START}{AUDIO_CONTEXT * num_tokens}{AUDIO_END}"
|
||||
return PromptUpdateDetails.select_text(repl_full, AUDIO_CONTEXT)
|
||||
|
||||
@classmethod
|
||||
def get_video_repl(
|
||||
cls,
|
||||
@@ -1147,15 +1231,28 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
|
||||
def supports_video(self):
|
||||
return self.get_hf_processor().supports_video
|
||||
|
||||
@property
|
||||
def audio_extractor(self) -> ParakeetExtractor | None:
|
||||
return self.get_hf_processor().audio_extractor
|
||||
|
||||
def get_data_parser(self):
|
||||
target_sr = None
|
||||
target_channels = None
|
||||
if extractor := self.audio_extractor:
|
||||
target_sr = extractor.sampling_rate
|
||||
target_channels = 1
|
||||
|
||||
return MultiModalDataParser(
|
||||
video_needs_metadata=True,
|
||||
target_sr=target_sr,
|
||||
target_channels=target_channels,
|
||||
expected_hidden_size=self._get_expected_hidden_size(),
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self):
|
||||
video_limit = {"video": None} if self.supports_video else {}
|
||||
return {**super().get_supported_mm_limits(), **video_limit}
|
||||
audio_limit = {"audio": None} if self.audio_extractor is not None else {}
|
||||
return {**super().get_supported_mm_limits(), **video_limit, **audio_limit}
|
||||
|
||||
def get_video_token(self) -> str | None:
|
||||
return IMG_CONTEXT
|
||||
@@ -1304,7 +1401,16 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
else:
|
||||
video_fields = {}
|
||||
|
||||
return image_fields | video_fields
|
||||
if self.info.audio_extractor is not None:
|
||||
audio_fields = dict(
|
||||
input_audio_features=MultiModalFieldConfig.batched("audio"),
|
||||
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
|
||||
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
|
||||
)
|
||||
else:
|
||||
audio_fields = {}
|
||||
|
||||
return image_fields | video_fields | audio_fields
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
@@ -1373,6 +1479,20 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
),
|
||||
]
|
||||
|
||||
def get_audio_replacement(item_idx: int):
|
||||
audios = mm_items.get_items("audio", AudioProcessorItems)
|
||||
return hf_processor.get_audio_repl(audios.get(item_idx))
|
||||
|
||||
if self.info.audio_extractor is not None:
|
||||
prompt_repl = [
|
||||
*prompt_repl,
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target=AUDIO_CONTEXT,
|
||||
replacement=get_audio_replacement,
|
||||
),
|
||||
]
|
||||
|
||||
return prompt_repl
|
||||
|
||||
|
||||
@@ -1422,8 +1542,13 @@ class NanoNemotronVLDummyInputsBuilder(
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_videos = mm_counts.get("video", 0)
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
|
||||
return super().get_dummy_text(mm_counts) + "<video>" * num_videos
|
||||
return (
|
||||
super().get_dummy_text(mm_counts)
|
||||
+ "<video>" * num_videos
|
||||
+ AUDIO_CONTEXT * num_audios
|
||||
)
|
||||
|
||||
def _get_dummy_videos(
|
||||
self,
|
||||
@@ -1482,7 +1607,25 @@ class NanoNemotronVLDummyInputsBuilder(
|
||||
}
|
||||
else:
|
||||
dummy_video = {}
|
||||
return {**dummy_image, **dummy_video}
|
||||
|
||||
if extractor := self.info.audio_extractor:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
audio_overrides = mm_options.get("audio") if mm_options else None
|
||||
tokens_per_audio = max(1, seq_len // max(num_audios, 1))
|
||||
max_audio_num_samples = MAX_AUDIO_LEN_S * extractor.sampling_rate
|
||||
calculated_max_audio_num_samples = extractor.audio_length(tokens_per_audio)
|
||||
audio_len = min(max_audio_num_samples, calculated_max_audio_num_samples)
|
||||
dummy_audio = {
|
||||
"audio": self._get_dummy_audios(
|
||||
length=audio_len,
|
||||
num_audios=num_audios,
|
||||
overrides=audio_overrides,
|
||||
)
|
||||
}
|
||||
else:
|
||||
dummy_audio = {}
|
||||
|
||||
return {**dummy_image, **dummy_video, **dummy_audio}
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
@@ -1499,12 +1642,15 @@ class NemotronH_Nano_VL_V2(
|
||||
return "<image>"
|
||||
if modality.startswith("video"):
|
||||
return "<video>"
|
||||
if modality.startswith("audio"):
|
||||
return AUDIO_CONTEXT
|
||||
return None
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
model_config = vllm_config.model_config
|
||||
config = model_config.hf_config
|
||||
multimodal_config = model_config.multimodal_config
|
||||
image_size = config.force_image_size
|
||||
patch_size = config.patch_size
|
||||
self.patch_size = patch_size
|
||||
@@ -1523,10 +1669,12 @@ class NemotronH_Nano_VL_V2(
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
llm_dtype = self.language_model.config.dtype
|
||||
assert isinstance(llm_dtype, torch.dtype)
|
||||
self.llm_dtype = llm_dtype
|
||||
with self._mark_tower_model(vllm_config, {"image", "video", "audio"}):
|
||||
self.vision_model = self.get_vit_model_from_radio_config(config).to(
|
||||
self.language_model.config.dtype
|
||||
llm_dtype
|
||||
)
|
||||
|
||||
# Construct the vision projection.
|
||||
@@ -1547,14 +1695,26 @@ class NemotronH_Nano_VL_V2(
|
||||
ReLUSquaredActivation(),
|
||||
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
|
||||
)
|
||||
self.mlp1 = mlp1.to(self.language_model.config.dtype)
|
||||
self.mlp1 = mlp1.to(llm_dtype)
|
||||
self.sound_encoder: ProjectedParakeet | None = None
|
||||
if getattr(config, "sound_config", None) is not None:
|
||||
logger.info_once(
|
||||
"Found sound config, initializing sound encoder for Nemotron AVLM",
|
||||
scope="global",
|
||||
)
|
||||
self.sound_encoder = ProjectedParakeet(
|
||||
config.sound_config,
|
||||
dtype=llm_dtype,
|
||||
llm_hidden_size=llm_hidden_size,
|
||||
max_model_len=model_config.max_model_len,
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.model_config = vllm_config.model_config
|
||||
|
||||
# Pre-tokenize special tokens for video processing
|
||||
# to avoid repeated tokenization
|
||||
tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
self._img_start_token_ids = tokenizer.encode(
|
||||
IMG_START, add_special_tokens=False
|
||||
)
|
||||
@@ -1566,7 +1726,10 @@ class NemotronH_Nano_VL_V2(
|
||||
config
|
||||
)
|
||||
if self.dynamic_resolution:
|
||||
logger.info("Dynamic resolution is enabled for NanoNemotronVLProcessor")
|
||||
logger.info_once(
|
||||
"Dynamic resolution is enabled for NanoNemotronVLProcessor",
|
||||
scope="global",
|
||||
)
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
@@ -1780,6 +1943,51 @@ class NemotronH_Nano_VL_V2(
|
||||
|
||||
return final_video_embeddings
|
||||
|
||||
    def _process_audio_input(
        self, audio_input: NanoNemotronVLAudioFeatureInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode batched audio features into per-clip embedding tensors.

        Runs the sound encoder over the (possibly ragged) feature batch and
        trims each clip's output back to its valid, non-padded token length.

        Returns:
            One embedding tensor per audio clip; lengths differ per clip.
        """
        assert self.sound_encoder is not None
        input_audio_features = audio_input.input_audio_features
        feature_attention_mask = audio_input.feature_attention_mask
        target_device = next(self.sound_encoder.parameters()).device

        # When cross-request batching combines audio clips with different
        # time dimensions, _reduce_data returns a list instead of a stacked
        # tensor. Pad to the max time dim and stack; the attention mask
        # already marks valid positions so zero-padding is safe.
        if isinstance(input_audio_features, list):
            feature_sizes = [f.shape[-2] for f in input_audio_features]
            max_t = max(feature_sizes)
            padded_feats = [
                torch.nn.functional.pad(feat, (0, 0, 0, max_t - feat_size))
                for feat, feat_size in zip(
                    input_audio_features, feature_sizes, strict=True
                )
            ]
            padded_masks = [
                torch.nn.functional.pad(mask, (0, max_t - mask.shape[-1]))
                for mask in feature_attention_mask
            ]
            input_audio_features = torch.stack(padded_feats)
            feature_attention_mask = torch.stack(padded_masks)

        # Move to the encoder's device; features are also cast to the LLM
        # dtype so they match the sound encoder's weights.
        input_audio_features = input_audio_features.to(
            dtype=self.llm_dtype, device=target_device
        )
        feature_attention_mask = feature_attention_mask.to(device=target_device)
        sound_embeds = self.sound_encoder(input_audio_features, feature_attention_mask)

        # The encoder subsamples in time: map valid input frame counts to
        # valid output token counts, then strip the padded tail per clip.
        valid_input_lens = feature_attention_mask.sum(dim=1)
        valid_output_lens = self.sound_encoder.encoder._get_subsampling_output_length(
            valid_input_lens
        )
        truncated_embeds = []
        for i in range(sound_embeds.shape[0]):
            valid_len = valid_output_lens[i].item()
            truncated_embeds.append(sound_embeds[i, :valid_len])

        return tuple(truncated_embeds)
|
||||
|
||||
def _create_final_video_embeddings(
|
||||
self,
|
||||
video_embeddings: torch.Tensor,
|
||||
@@ -1887,6 +2095,18 @@ class NemotronH_Nano_VL_V2(
|
||||
modalities["images"] = self._parse_and_validate_image_input(**kwargs)
|
||||
if input_key in ("pixel_values_flat_video",) and "videos" not in modalities:
|
||||
modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
|
||||
if (
|
||||
input_key
|
||||
in (
|
||||
"input_audio_features",
|
||||
"feature_attention_mask",
|
||||
"audio_feature_lengths",
|
||||
)
|
||||
and "audios" not in modalities
|
||||
):
|
||||
modalities["audios"] = NanoNemotronVLAudioFeatureInputs(
|
||||
**kwargs, validate=False
|
||||
)
|
||||
|
||||
return modalities
|
||||
|
||||
@@ -1917,6 +2137,10 @@ class NemotronH_Nano_VL_V2(
|
||||
video_input = modalities["videos"]
|
||||
video_embeddings = self._process_video_input(video_input)
|
||||
multimodal_embeddings += tuple(video_embeddings)
|
||||
if modality == "audios":
|
||||
audio_input = modalities["audios"]
|
||||
audio_embeddings = self._process_audio_input(audio_input)
|
||||
multimodal_embeddings += tuple(audio_embeddings)
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
@@ -1947,8 +2171,8 @@ class NemotronH_Nano_VL_V2(
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="mlp1",
|
||||
tower_model="vision_model",
|
||||
connector=["mlp1", "sound_encoder.projection"],
|
||||
tower_model=["vision_model", "sound_encoder.encoder"],
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
@@ -1969,9 +2193,13 @@ class NemotronH_Nano_VL_V2(
|
||||
def is_vision_weights(name: str) -> bool:
|
||||
return name.startswith("vision_model.radio_model.")
|
||||
|
||||
def is_sound_weights(name: str) -> bool:
|
||||
return name.startswith("sound")
|
||||
|
||||
# Separate weights by component
|
||||
llm_weights = []
|
||||
vision_weights = []
|
||||
sound_weights = []
|
||||
|
||||
for name, w in weights:
|
||||
if is_llm(name):
|
||||
@@ -1987,9 +2215,15 @@ class NemotronH_Nano_VL_V2(
|
||||
# Convert: vision_model.radio_model.* → radio_model.*
|
||||
hf_key = name[len("vision_model.") :] # Remove "vision_model." prefix
|
||||
vision_weights.append((hf_key, w))
|
||||
elif is_sound_weights(name):
|
||||
assert self.sound_encoder is not None
|
||||
sound_weights.append((name, w))
|
||||
|
||||
self.language_model.load_weights(llm_weights)
|
||||
self.vision_model.load_weights(vision_weights)
|
||||
if self.sound_encoder is not None:
|
||||
assert len(sound_weights) > 0
|
||||
self.sound_encoder.load_weights(sound_weights)
|
||||
|
||||
def print_architecture(self, detailed: bool = True, save_to_file: str = None):
|
||||
"""
|
||||
|
||||
145
vllm/model_executor/models/parakeet.py
Normal file
145
vllm/model_executor/models/parakeet.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Modules below used for the audio encoder component in: models/nano_nemotron_vl.py
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import asdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import ParakeetEncoder as HFParakeetEncoder
|
||||
from transformers import ParakeetFeatureExtractor, PretrainedConfig
|
||||
|
||||
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.transformers_utils.configs.parakeet import ExtractorConfig, ParakeetConfig
|
||||
|
||||
|
||||
class ParakeetProjection(nn.Module):
    """Projects Parakeet encoder states into the LLM hidden space.

    Pipeline: LayerNorm -> Linear -> ReLU^2 -> Linear.
    """

    def __init__(self, config: ParakeetConfig) -> None:
        super().__init__()
        in_dim = config.hidden_size
        mid_dim = config.projection_hidden_size
        out_dim = config.llm_hidden_size
        use_bias = config.projection_bias

        # Attribute names are load-bearing: checkpoint keys are remapped onto
        # "projection.norm/linear1/linear2" during weight loading.
        self.norm = nn.LayerNorm(in_dim, eps=config.projection_eps)
        self.linear1 = nn.Linear(in_dim, mid_dim, bias=use_bias)
        self.activation = ReLUSquaredActivation()
        self.linear2 = nn.Linear(mid_dim, out_dim, bias=use_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply norm, up-projection, squared-ReLU, and down-projection."""
        normed = self.norm(hidden_states)
        activated = self.activation(self.linear1(normed))
        return self.linear2(activated)
|
||||
|
||||
|
||||
class ProjectedParakeet(nn.Module):
    """Parakeet audio encoder plus a projection into the LLM hidden space.

    Wraps the HF ParakeetEncoder and a ParakeetProjection, both cast to the
    caller-supplied working dtype.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        *,
        dtype: torch.dtype,
        llm_hidden_size: int,
        max_model_len: int,
    ) -> None:
        super().__init__()
        self.config = ParakeetConfig.from_hf_config(
            config, llm_hidden_size=llm_hidden_size, max_model_len=max_model_len
        )
        # Remember the working dtype so forward() can cast encoder outputs to
        # match the projection weights (both submodules are cast to it below).
        self.dtype = dtype
        self.encoder = HFParakeetEncoder(self.config)
        self.encoder = self.encoder.to(dtype)
        self.projection = ParakeetProjection(self.config)
        self.projection = self.projection.to(dtype)

    def forward(
        self, input_features: torch.Tensor, attention_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Encode audio features and project them to LLM hidden size.

        Args:
            input_features: Batched audio features for the encoder.
            attention_mask: Optional padding mask passed to the encoder.

        Returns:
            Projected hidden states in the module's working dtype.
        """
        outputs = self.encoder(
            input_features=input_features, attention_mask=attention_mask
        )
        outputs = outputs.last_hidden_state
        # BUGFIX: the original cast to a hardcoded torch.bfloat16 here, which
        # mismatches the projection weights whenever the configured dtype is
        # not bfloat16. Cast to the dtype both submodules were built with.
        outputs = outputs.to(dtype=self.dtype)
        outputs = self.projection(outputs)
        return outputs

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into the encoder and projection.

        Checkpoint names use "sound_encoder.*" (encoder) and
        "sound_projection.*" (projection) prefixes; names outside those
        namespaces are skipped.

        Returns:
            The set of local parameter/buffer names that were loaded.

        Raises:
            ValueError: If a recognized name maps to no local param/buffer.
        """
        loaded_params: set[str] = set()
        params_dict = dict(self.named_parameters())
        buffers_dict = dict(self.named_buffers())

        # Accept either a mapping or an iterable of (name, tensor) pairs.
        if isinstance(weights, dict):
            weights_list = list(weights.items())
        else:
            weights_list = list(weights)

        for name, weight in weights_list:
            if name.startswith("sound_encoder.encoder.feature_extractor."):
                # Feature extractor buffers are handled outside the encoder.
                continue
            if name.startswith("sound_encoder."):
                target_name = name[len("sound_encoder.") :]
            elif name.startswith("sound_projection."):
                target_name = f"projection.{name[len('sound_projection.') :]}"
            else:
                continue

            # Parameters take precedence; fall back to buffers.
            target = params_dict.get(target_name)
            if target is None:
                target = buffers_dict.get(target_name)
            if target is None:
                raise ValueError(f"Unknown weight: {name}")
            weight_loader = getattr(target, "weight_loader", default_weight_loader)
            with torch.no_grad():
                weight_loader(target, weight)
            loaded_params.add(target_name)

        return loaded_params
|
||||
|
||||
|
||||
class ParakeetExtractor(ParakeetFeatureExtractor):
    """Feature extractor that pads clips to Parakeet's clip-length rules.

    Wraps the HF ParakeetFeatureExtractor, normalizing each waveform's
    length before extraction and exposing token-count helpers.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        self.config = ExtractorConfig.from_hf_config(config)
        super().__init__(**asdict(self.config))
        # Full-clip length in samples; audio is padded relative to multiples
        # of this value.
        self._clip_target_samples = int(
            round(self.config.clip_duration_s * self.sampling_rate)
        )
        # Minimum allowed length (in samples) for a trailing partial clip.
        self._tail_min_samples = int(
            round(self.config.clip_min_duration_s * self.sampling_rate)
        )

    def _normalize_audio_length(self, audio_len: int) -> int:
        """Return the padded sample count for a clip of ``audio_len`` samples.

        Match mcore's compute_params() logic for clip/minduration handling:
        enforce a minimum total length, and if the tail past the last full
        clip is non-empty but shorter than the minimum, pad it up.
        """
        target_len = max(audio_len, self._tail_min_samples)
        tail_remainder = target_len % self._clip_target_samples
        if 0 < tail_remainder < self._tail_min_samples:
            padding = self._tail_min_samples - tail_remainder
            target_len += padding
        assert isinstance(target_len, int)
        return target_len

    def audio_token_count(self, audio_len: int) -> int:
        """Number of encoder output tokens for a clip of ``audio_len`` samples."""
        audio_len = self._normalize_audio_length(audio_len)
        num_frames = audio_len // self.hop_length
        n_tokens = HFParakeetEncoder._get_subsampling_output_length(
            self, torch.tensor([num_frames], dtype=torch.float)
        )
        # BUGFIX: .item() on a float tensor yields a Python float; cast to
        # int so callers can use the count for e.g. string repetition.
        return max(1, int(n_tokens.item()))

    def __call__(self, raw_speech: list[np.ndarray], *args, **kwargs):
        """Pad each mono waveform to its normalized length, then extract."""
        padded = []
        for p in raw_speech:
            assert p.ndim == 1
            audio_len = int(p.shape[0])
            target_len = self._normalize_audio_length(audio_len)
            p = np.pad(p, (0, target_len - audio_len))
            padded.append(p)
        return super().__call__(padded, *args, **kwargs)

    def audio_length(self, audio_tokens: int) -> int:
        """Approximate sample count corresponding to ``audio_tokens`` tokens
        (inverse of audio_token_count, ignoring clip padding)."""
        return int(audio_tokens * self.config.subsampling_factor * self.hop_length)
|
||||
49
vllm/transformers_utils/configs/parakeet.py
Normal file
49
vllm/transformers_utils/configs/parakeet.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
from transformers import ParakeetEncoderConfig, PretrainedConfig
|
||||
|
||||
|
||||
class ParakeetConfig(ParakeetEncoderConfig):
    """Parakeet encoder config extended with projection/LLM-side fields."""

    llm_hidden_size: int
    projection_hidden_size: int
    projection_bias: bool
    projection_eps: float = 1e-5
    sampling_rate: int

    @staticmethod
    def from_hf_config(
        config: PretrainedConfig, *, llm_hidden_size: int, max_model_len: int
    ) -> "ParakeetConfig":
        """Build a ParakeetConfig from the HF sound config plus LLM info."""
        assert isinstance(config, PretrainedConfig)
        overrides = dict(
            scale_input=False,
            attention_bias=False,
            llm_hidden_size=llm_hidden_size,
            # + 1 because it seems like max_model_len+1 can be passed
            max_position_embeddings=max_model_len + 1,
        )
        return ParakeetConfig(**config.to_dict(), **overrides)
|
||||
|
||||
|
||||
@dataclass(kw_only=True, frozen=True)
|
||||
class ExtractorConfig:
|
||||
feature_size: int
|
||||
sampling_rate: int
|
||||
subsampling_factor: int
|
||||
subsampling_conv_kernel_size: int
|
||||
subsampling_conv_stride: int
|
||||
clip_duration_s: int = 30
|
||||
clip_min_duration_s: float = 0.1
|
||||
|
||||
@staticmethod
|
||||
def from_hf_config(config: PretrainedConfig) -> "ExtractorConfig":
|
||||
assert isinstance(config, PretrainedConfig)
|
||||
return ExtractorConfig(
|
||||
feature_size=config.num_mel_bins,
|
||||
sampling_rate=config.sampling_rate,
|
||||
subsampling_factor=config.subsampling_factor,
|
||||
subsampling_conv_kernel_size=config.subsampling_conv_kernel_size,
|
||||
subsampling_conv_stride=config.subsampling_conv_stride,
|
||||
)
|
||||
Reference in New Issue
Block a user