[Bugfix] Replace PoolingParams.normalize with use_activation (#32243)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -44,14 +44,14 @@ class TokenPoolerHead(nn.Module, ABC):
|
||||
class TokenEmbeddingPoolerHead(TokenPoolerHead):
|
||||
def __init__(
|
||||
self,
|
||||
projector: ProjectorFn | None = None,
|
||||
head_dtype: torch.dtype | str | None = None,
|
||||
projector: ProjectorFn | None = None,
|
||||
activation: ActivationFn | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.projector = projector
|
||||
self.head_dtype = head_dtype
|
||||
self.projector = projector
|
||||
self.activation = activation
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
@@ -79,7 +79,7 @@ class TokenEmbeddingPoolerHead(TokenPoolerHead):
|
||||
pooled_data = pooled_data[..., : pooling_param.dimensions]
|
||||
|
||||
# for normalize
|
||||
if self.activation is not None and pooling_param.normalize:
|
||||
if self.activation is not None and pooling_param.use_activation:
|
||||
pooled_data = self.activation(pooled_data)
|
||||
|
||||
# pooled_data shape: [n_tokens, embedding_dimension]
|
||||
|
||||
@@ -95,8 +95,8 @@ def pooler_for_token_embed(pooler_config: PoolerConfig):
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
head = TokenEmbeddingPoolerHead(
|
||||
projector=_load_st_projector(model_config),
|
||||
head_dtype=model_config.head_dtype,
|
||||
projector=_load_st_projector(model_config),
|
||||
activation=PoolerNormalize(),
|
||||
)
|
||||
|
||||
@@ -116,9 +116,9 @@ def pooler_for_token_classify(
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_config = vllm_config.model_config
|
||||
head = TokenClassifierPoolerHead(
|
||||
head_dtype=model_config.head_dtype,
|
||||
classifier=classifier,
|
||||
logit_bias=model_config.pooler_config.logit_bias,
|
||||
head_dtype=model_config.head_dtype,
|
||||
activation=resolve_classifier_act_fn(
|
||||
model_config, static_num_labels=False, act_fn=act_fn
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user