ParakeetProjection.norm = RMSNorm instead of nn.LayerNorm (#36133)
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,7 @@ from transformers import ParakeetEncoder as HFParakeetEncoder
 from transformers import ParakeetFeatureExtractor, PretrainedConfig

 from vllm.model_executor.layers.activation import ReLUSquaredActivation
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.parakeet import ExtractorConfig, ParakeetConfig
@@ -26,7 +27,7 @@ class ParakeetProjection(nn.Module):
         llm_hidden_size = config.llm_hidden_size
         bias = config.projection_bias

-        self.norm = nn.LayerNorm(sound_hidden_size, eps=config.projection_eps)
+        self.norm = RMSNorm(sound_hidden_size, eps=config.projection_eps)
         self.linear1 = nn.Linear(sound_hidden_size, proj_hidden_size, bias=bias)
         self.activation = ReLUSquaredActivation()
         self.linear2 = nn.Linear(proj_hidden_size, llm_hidden_size, bias=bias)
Reference in New Issue
Block a user