[Model] Ultravox Model: Support v0.5 Release (#12912)

Signed-off-by: Farzad Abdolhosseini <farzad@fixie.ai>
This commit is contained in:
Farzad Abdolhosseini
2025-02-10 14:02:48 -08:00
committed by GitHub
parent 2ae889052c
commit 08b2d845d6
12 changed files with 36 additions and 22 deletions

View File

@@ -258,27 +258,35 @@ class UltravoxProjector(nn.Module):
super().__init__()
self.hidden_dim = config.hidden_size
self._pad_and_stack = StackAudioFrames(config.stack_factor)
dim = config.audio_config.hidden_size * config.stack_factor
self.ln_pre = RMSNorm(dim)
self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
dim = self.hidden_dim
dim_in = config.audio_config.hidden_size * config.stack_factor
self.ln_pre = RMSNorm(dim_in)
self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
dim_mid = self.hidden_dim
if config.projector_act == "swiglu":
self.act = MulAndSilu()
dim = dim // 2
dim_mid = dim_mid // 2
else:
self.act = get_act_fn(config.projector_act)
self.linear_2 = nn.Linear(dim,
config.text_config.hidden_size,
bias=False)
self.ln_post = RMSNorm(config.text_config.hidden_size)
dim_out = config.text_config.hidden_size
self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
# Ultravox v0.4.1 and below use layer_norm after the second linear layer
# while v0.5.0 and above uses layer_norm after the first linear layer.
if config.projector_ln_mid:
self.ln_mid: nn.Module = RMSNorm(dim_mid)
self.ln_post = nn.Identity()
else:
self.ln_mid = nn.Identity()
self.ln_post = RMSNorm(dim_out)
def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
audio_features = self._pad_and_stack(audio_features)
audio_features = self.ln_pre(audio_features)
hidden_states = self.linear_1(audio_features)
hidden_states = self.act(hidden_states)
hidden_states = self.ln_mid(hidden_states)
hidden_states = self.linear_2(hidden_states)
hidden_states = self.ln_post(hidden_states)
return hidden_states

View File

@@ -37,6 +37,10 @@ class UltravoxConfig(transformers.PretrainedConfig):
The LoRA configuration for finetuning the text model.
audio_model_lora_config (`LoraConfigSimplified`, *optional*):
The LoRA configuration for finetuning the audio model.
projector_ln_mid (`bool`, *optional*, defaults to `False`):
Whether to apply layer normalization at the middle of the
projector or at the end. Versions v0.4.1 and below
use `False`, but v0.5 and above use `True`.
"""
model_type = "ultravox"
@@ -56,6 +60,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
projector_act: str = "swiglu",
text_model_lora_config: Optional[Dict[str, Any]] = None,
audio_model_lora_config: Optional[Dict[str, Any]] = None,
projector_ln_mid: bool = False,
**kwargs,
):
self.ignore_index = ignore_index
@@ -68,6 +73,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
self.stack_factor = stack_factor
self.norm_init = norm_init
self.projector_act = projector_act
self.projector_ln_mid = projector_ln_mid
if text_model_id is not None:
# Avoid circular import