[Bugfix] Replace PoolingParams.normalize with use_activation (#32243)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -116,8 +116,8 @@ class BertPooler(SequencePooler):
|
||||
|
||||
# Use lambdas so that weights are not registered under `self.head`
|
||||
self.head = EmbeddingPoolerHead(
|
||||
projector=lambda x: self.dense(x),
|
||||
head_dtype=head_dtype,
|
||||
projector=lambda x: self.dense(x),
|
||||
activation=LambdaPoolerActivation(self.act_fn),
|
||||
)
|
||||
|
||||
|
||||
@@ -309,12 +309,13 @@ class ModernBertPooler(SequencePooler):
|
||||
config.hidden_size,
|
||||
eps=config.norm_eps,
|
||||
bias=config.norm_bias,
|
||||
dtype=head_dtype,
|
||||
)
|
||||
|
||||
# Use lambdas so that weights are not registered under `self.head`
|
||||
self.head = EmbeddingPoolerHead(
|
||||
projector=lambda x: self.dense(x),
|
||||
head_dtype=head_dtype,
|
||||
projector=lambda x: self.dense(x),
|
||||
activation=LambdaPoolerActivation(lambda x: self.norm(self.act(x))),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user