[Refactor] Separate sequence and token pooling types (#32026)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
@@ -11,7 +11,11 @@ from vllm.utils.hashing import safe_hash
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]
|
||||
SequencePoolingType = Literal["CLS", "LAST", "MEAN"]
|
||||
SEQ_POOLING_TYPES: tuple[SequencePoolingType, ...] = get_args(SequencePoolingType)
|
||||
|
||||
TokenPoolingType = Literal["ALL", "STEP"]
|
||||
TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)
|
||||
|
||||
|
||||
@config
|
||||
@@ -19,9 +23,26 @@ PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]
|
||||
class PoolerConfig:
|
||||
"""Controls the behavior of output pooling in pooling models."""
|
||||
|
||||
pooling_type: PoolingTypeStr | None = None
|
||||
pooling_type: SequencePoolingType | TokenPoolingType | None = None
|
||||
"""
|
||||
The pooling method of the pooling model.
|
||||
The pooling method to use.
|
||||
|
||||
If set, `seq_pooling_type` or `tok_pooling_type` is automatically populated
|
||||
with this field. Alternatively, users can set `seq_pooling_type` and
|
||||
`tok_pooling_type` explicitly.
|
||||
|
||||
This field is mainly for user convenience. Internal code should always use
|
||||
`seq_pooling_type` or `tok_pooling_type` instead of `pooling_type`.
|
||||
"""
|
||||
|
||||
seq_pooling_type: SequencePoolingType | None = None
|
||||
"""
|
||||
The pooling method used for sequence pooling.
|
||||
"""
|
||||
|
||||
tok_pooling_type: TokenPoolingType | None = None
|
||||
"""
|
||||
The pooling method used for tokenwise pooling.
|
||||
"""
|
||||
|
||||
## for embeddings models
|
||||
@@ -88,9 +109,40 @@ class PoolerConfig:
|
||||
# raise deprecated warning for softmax and activation
|
||||
self.use_activation = get_use_activation(self)
|
||||
|
||||
def get_pooling_type(self) -> PoolingTypeStr:
|
||||
assert self.pooling_type is not None, "Should be resolved by ModelConfig"
|
||||
return self.pooling_type
|
||||
if pooling_type := self.pooling_type:
|
||||
if self.seq_pooling_type is not None:
|
||||
raise ValueError(
|
||||
"Cannot set both `pooling_type` and `seq_pooling_type`"
|
||||
)
|
||||
if self.tok_pooling_type is not None:
|
||||
raise ValueError(
|
||||
"Cannot set both `pooling_type` and `tok_pooling_type`"
|
||||
)
|
||||
|
||||
if pooling_type in SEQ_POOLING_TYPES:
|
||||
logger.debug(
|
||||
"Resolved `pooling_type=%r` to `seq_pooling_type=%r`.",
|
||||
pooling_type,
|
||||
pooling_type,
|
||||
)
|
||||
self.seq_pooling_type = pooling_type
|
||||
elif pooling_type in TOK_POOLING_TYPES:
|
||||
logger.debug(
|
||||
"Resolved `pooling_type=%r` to `tok_pooling_type=%r`.",
|
||||
pooling_type,
|
||||
pooling_type,
|
||||
)
|
||||
self.tok_pooling_type = pooling_type
|
||||
else:
|
||||
raise NotImplementedError(pooling_type)
|
||||
|
||||
def get_seq_pooling_type(self) -> SequencePoolingType:
|
||||
assert self.seq_pooling_type is not None, "Should be resolved by ModelConfig"
|
||||
return self.seq_pooling_type
|
||||
|
||||
def get_tok_pooling_type(self) -> TokenPoolingType:
|
||||
assert self.tok_pooling_type is not None, "Should be resolved by ModelConfig"
|
||||
return self.tok_pooling_type
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user