[Chore] Further cleanup pooler (#31951)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-08 18:16:21 +08:00
committed by GitHub
parent 04a49669d1
commit d1b6fe007f
7 changed files with 47 additions and 62 deletions

View File

@@ -3,7 +3,6 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Set
from dataclasses import dataclass
from enum import IntEnum
from itertools import groupby
from typing import TypeAlias, TypeVar
@@ -12,13 +11,14 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
from vllm.config import ModelConfig, get_current_vllm_config
from vllm.config.pooler import PoolerConfig, PoolingTypeStr
from vllm.logger import init_logger
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokensPoolerOutput
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokenwisePoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
logger = init_logger(__name__)
@@ -31,27 +31,17 @@ ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
TokensPoolingMethodOutputItem: TypeAlias = torch.Tensor | None
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokensPoolingMethodOutput
TokenwisePoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
TokenwisePoolingMethodOutputItem: TypeAlias = torch.Tensor | None
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokenwisePoolingMethodOutput
TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolerHeadOutput: TypeAlias = torch.Tensor | None
class PoolingType(IntEnum):
"""Enumeration for different types of pooling methods."""
LAST = 0
ALL = 1
CLS = 2
STEP = 3
MEAN = 4
TokenwisePoolerHeadOutput: TypeAlias = torch.Tensor | None
@dataclass(frozen=True)
class ResolvedPoolingConfig:
pooling_type: PoolingType
pooling_type: PoolingTypeStr
task: PoolingTask
@classmethod
@@ -61,7 +51,7 @@ class ResolvedPoolingConfig:
pooler_config: PoolerConfig,
) -> "ResolvedPoolingConfig":
assert pooler_config.pooling_type is not None
return cls(task=task, pooling_type=PoolingType[pooler_config.pooling_type])
return cls(task=task, pooling_type=pooler_config.pooling_type)
@dataclass(frozen=True)
@@ -112,17 +102,22 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
class PoolingMethod(nn.Module, ABC):
@staticmethod
def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
if pooling_type == PoolingType.LAST:
def from_pooling_type(pooling_type: PoolingTypeStr) -> "PoolingMethod":
if pooling_type == "LAST":
return LastPool()
if pooling_type == PoolingType.ALL:
if pooling_type == "ALL":
return AllPool()
if pooling_type == PoolingType.CLS:
if pooling_type == "CLS":
return CLSPool()
if pooling_type == PoolingType.MEAN:
if pooling_type == "MEAN":
return MeanPool()
if pooling_type == "STEP":
raise ValueError(
"'STEP' pooling is handled by StepPooler "
"and is not a standalone PoolingMethod."
)
raise NotImplementedError(f"Unsupported method: {pooling_type}")
raise NotImplementedError(f"Unsupported method: {pooling_type!r}")
@abstractmethod
def get_supported_tasks(self) -> Set[PoolingTask]:
@@ -186,13 +181,12 @@ class AllPool(PoolingMethod):
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokensPoolingMethodOutput:
) -> TokenwisePoolingMethodOutput:
pooling_cursor = pooling_metadata.get_pooling_cursor()
is_finished = pooling_cursor.is_finished()
hidden_states_lst = list(
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
hidden_states_all = hidden_states.split(
pooling_cursor.num_scheduled_tokens_cpu.tolist()
)
hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index]
hidden_states_lst = [hidden_states_all[i] for i in pooling_cursor.index]
if not self.enable_chunked_prefill:
return hidden_states_lst
@@ -206,7 +200,7 @@ class AllPool(PoolingMethod):
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list = list[torch.Tensor | None]()
for p, finished in zip(pooling_states, is_finished):
for p, finished in zip(pooling_states, pooling_cursor.is_finished()):
if finished:
hidden_states_cache = p.hidden_states_cache
if len(hidden_states_cache) == 1:
@@ -620,19 +614,19 @@ class ClassifierPooler(Pooler):
return scores
class TokensPoolerHead(nn.Module, ABC):
class TokenwisePoolerHead(nn.Module, ABC):
"""Applicable to pooling strategies that output multiple tokens."""
@abstractmethod
def forward(
self,
pooled_data: TokensPoolingMethodOutputItem,
pooled_data: TokenwisePoolingMethodOutputItem,
pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput:
) -> TokenwisePoolerHeadOutput:
raise NotImplementedError
class TokenEmbeddingPoolerHead(TokensPoolerHead):
class TokenEmbeddingPoolerHead(TokenwisePoolerHead):
def __init__(self) -> None:
super().__init__()
@@ -647,9 +641,9 @@ class TokenEmbeddingPoolerHead(TokensPoolerHead):
def forward(
self,
pooled_data: TokensPoolingMethodOutputItem,
pooled_data: TokenwisePoolingMethodOutputItem,
pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput:
) -> TokenwisePoolerHeadOutput:
# for unfinished chunked prefill
if pooled_data is None:
return None
@@ -673,7 +667,7 @@ class TokenEmbeddingPoolerHead(TokensPoolerHead):
return pooled_data
class TokenClassifierPoolerHead(TokensPoolerHead):
class TokenClassifierPoolerHead(TokenwisePoolerHead):
def __init__(
self,
classifier: ClassifierFn | None,
@@ -695,9 +689,9 @@ class TokenClassifierPoolerHead(TokensPoolerHead):
def forward(
self,
pooled_data: TokensPoolingMethodOutputItem,
pooled_data: TokenwisePoolingMethodOutputItem,
pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput:
) -> TokenwisePoolerHeadOutput:
# for unfinished chunked prefill
if pooled_data is None:
return None
@@ -722,7 +716,7 @@ class TokenClassifierPoolerHead(TokensPoolerHead):
class AllPooler(Pooler):
def __init__(self, head: TokensPoolerHead) -> None:
def __init__(self, head: TokenwisePoolerHead) -> None:
super().__init__()
self.pooling = AllPool()
@@ -735,7 +729,7 @@ class AllPooler(Pooler):
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokensPoolerOutput:
) -> TokenwisePoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooling_params = pooling_metadata.pooling_params
assert len(pooled_data) == len(pooling_params)
@@ -744,7 +738,7 @@ class AllPooler(Pooler):
class StepPooler(Pooler):
def __init__(self, head: TokensPoolerHead) -> None:
def __init__(self, head: TokenwisePoolerHead) -> None:
super().__init__()
self.pooling = AllPool()
@@ -790,7 +784,7 @@ class StepPooler(Pooler):
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokensPoolerOutput:
) -> TokenwisePoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooling_params = pooling_metadata.pooling_params
assert len(pooled_data) == len(pooling_params)