[Refactor] Separate sequence and token pooling types (#32026)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-10 12:53:24 +08:00
committed by GitHub
parent 52d428295d
commit 583a90e005
42 changed files with 324 additions and 204 deletions

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Literal
from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass
@@ -11,7 +11,11 @@ from vllm.utils.hashing import safe_hash
logger = init_logger(__name__)
PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]
SequencePoolingType = Literal["CLS", "LAST", "MEAN"]
SEQ_POOLING_TYPES: tuple[SequencePoolingType, ...] = get_args(SequencePoolingType)
TokenPoolingType = Literal["ALL", "STEP"]
TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)
@config
@@ -19,9 +23,26 @@ PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]
class PoolerConfig:
"""Controls the behavior of output pooling in pooling models."""
pooling_type: PoolingTypeStr | None = None
pooling_type: SequencePoolingType | TokenPoolingType | None = None
"""
The pooling method of the pooling model.
The pooling method used for pooling.
If set, `seq_pooling_type` or `tok_pooling_type` are automatically populated
with this field. Alternatively, users can set `seq_pooling_type` and
`tok_pooling_type` explicitly.
This field is mainly for user convenience. Internal code should always use
`seq_pooling_type` or `tok_pooling_type` instead of `pooling_type`.
"""
seq_pooling_type: SequencePoolingType | None = None
"""
The pooling method used for sequence pooling.
"""
tok_pooling_type: TokenPoolingType | None = None
"""
The pooling method used for tokenwise pooling.
"""
## for embeddings models
@@ -88,9 +109,40 @@ class PoolerConfig:
# raise deprecated warning for softmax and activation
self.use_activation = get_use_activation(self)
def get_pooling_type(self) -> PoolingTypeStr:
assert self.pooling_type is not None, "Should be resolved by ModelConfig"
return self.pooling_type
if pooling_type := self.pooling_type:
if self.seq_pooling_type is not None:
raise ValueError(
"Cannot set both `pooling_type` and `seq_pooling_type`"
)
if self.tok_pooling_type is not None:
raise ValueError(
"Cannot set both `pooling_type` and `tok_pooling_type`"
)
if pooling_type in SEQ_POOLING_TYPES:
logger.debug(
"Resolved `pooling_type=%r` to `seq_pooling_type=%r`.",
pooling_type,
pooling_type,
)
self.seq_pooling_type = pooling_type
elif pooling_type in TOK_POOLING_TYPES:
logger.debug(
"Resolved `pooling_type=%r` to `tok_pooling_type=%r`.",
pooling_type,
pooling_type,
)
self.tok_pooling_type = pooling_type
else:
raise NotImplementedError(pooling_type)
def get_seq_pooling_type(self) -> SequencePoolingType:
assert self.seq_pooling_type is not None, "Should be resolved by ModelConfig"
return self.seq_pooling_type
def get_tok_pooling_type(self) -> TokenPoolingType:
assert self.tok_pooling_type is not None, "Should be resolved by ModelConfig"
return self.tok_pooling_type
def compute_hash(self) -> str:
"""