Files
vllm/vllm/model_executor/layers/pooler/common.py
Cyrus Leung 8863c2b25c [Model] Standardize pooling heads (#32148)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2026-01-12 17:01:49 +00:00

33 lines
1.0 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar
import torch
from vllm.pooling_params import PoolingParams
_T = TypeVar("_T", bound=torch.Tensor | list[torch.Tensor])
ProjectorFn = Callable[[torch.Tensor], torch.Tensor]
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
ActivationFn = Callable[[_T], _T]
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Immutable bundle of flags to be written onto a `PoolingParams` object.

    Two updates can be merged with ``|``, and `apply` copies the flags onto
    the given parameters object.
    """

    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def __or__(self, other: "PoolingParamsUpdate") -> "PoolingParamsUpdate":
        """Merge two updates; the flag is set when either operand sets it."""
        merged_flag = self.requires_token_ids or other.requires_token_ids
        return PoolingParamsUpdate(requires_token_ids=merged_flag)

    def apply(self, params: PoolingParams) -> None:
        """Overwrite `params.requires_token_ids` with this update's flag."""
        flag = self.requires_token_ids
        params.requires_token_ids = flag
# Names exported by a star-import of this module.
__all__ = [
    "ActivationFn",
    "ClassifierFn",
    "ProjectorFn",
    "PoolingParamsUpdate",
]