[Model] Ultravox Model: Support v0.5 Release (#12912)
Signed-off-by: Farzad Abdolhosseini <farzad@fixie.ai>
This commit is contained in:
committed by
GitHub
parent
2ae889052c
commit
08b2d845d6
@@ -258,27 +258,35 @@ class UltravoxProjector(nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_dim = config.hidden_size
|
||||
self._pad_and_stack = StackAudioFrames(config.stack_factor)
|
||||
dim = config.audio_config.hidden_size * config.stack_factor
|
||||
self.ln_pre = RMSNorm(dim)
|
||||
self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
|
||||
dim = self.hidden_dim
|
||||
dim_in = config.audio_config.hidden_size * config.stack_factor
|
||||
self.ln_pre = RMSNorm(dim_in)
|
||||
self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
|
||||
dim_mid = self.hidden_dim
|
||||
|
||||
if config.projector_act == "swiglu":
|
||||
self.act = MulAndSilu()
|
||||
dim = dim // 2
|
||||
dim_mid = dim_mid // 2
|
||||
else:
|
||||
self.act = get_act_fn(config.projector_act)
|
||||
|
||||
self.linear_2 = nn.Linear(dim,
|
||||
config.text_config.hidden_size,
|
||||
bias=False)
|
||||
self.ln_post = RMSNorm(config.text_config.hidden_size)
|
||||
dim_out = config.text_config.hidden_size
|
||||
self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
|
||||
|
||||
# Ultravox v0.4.1 and below use layer_norm after the second linear layer
|
||||
# while v0.5.0 and above use layer_norm after the first linear layer.
|
||||
if config.projector_ln_mid:
|
||||
self.ln_mid: nn.Module = RMSNorm(dim_mid)
|
||||
self.ln_post = nn.Identity()
|
||||
else:
|
||||
self.ln_mid = nn.Identity()
|
||||
self.ln_post = RMSNorm(dim_out)
|
||||
|
||||
def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
    """Project stacked audio-encoder features into the text embedding space.

    The input frames are first padded and stacked along the feature
    dimension, then pushed through a norm -> linear -> activation ->
    (optional) mid-norm -> linear -> (optional) post-norm pipeline.
    Exactly one of ln_mid / ln_post is a real RMSNorm; the other is an
    Identity, depending on the model version (v0.5+ vs v0.4.1-).
    """
    x = self._pad_and_stack(audio_features)
    x = self.ln_pre(x)
    # Apply the projector stages in order; each stage maps tensor -> tensor.
    for stage in (self.linear_1, self.act, self.ln_mid, self.linear_2,
                  self.ln_post):
        x = stage(x)
    return x
|
||||
|
||||
@@ -37,6 +37,10 @@ class UltravoxConfig(transformers.PretrainedConfig):
|
||||
The LoRA configuration for finetuning the text model.
|
||||
audio_model_lora_config (`LoraConfigSimplified`, *optional*):
|
||||
The LoRA configuration for finetuning the audio model.
|
||||
projector_ln_mid (`bool`, *optional*, defaults to `False`):
|
||||
Whether to apply layer normalization at the middle of the
|
||||
projector or at the end. Versions v0.4.1 and below
|
||||
use `False`, but v0.5 and above use `True`.
|
||||
"""
|
||||
|
||||
model_type = "ultravox"
|
||||
@@ -56,6 +60,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
|
||||
projector_act: str = "swiglu",
|
||||
text_model_lora_config: Optional[Dict[str, Any]] = None,
|
||||
audio_model_lora_config: Optional[Dict[str, Any]] = None,
|
||||
projector_ln_mid: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
@@ -68,6 +73,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
|
||||
self.stack_factor = stack_factor
|
||||
self.norm_init = norm_init
|
||||
self.projector_act = projector_act
|
||||
self.projector_ln_mid = projector_ln_mid
|
||||
|
||||
if text_model_id is not None:
|
||||
# Avoid circular import
|
||||
|
||||
Reference in New Issue
Block a user