[Chore] Further cleanup pooler (#31951)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-08 18:16:21 +08:00
committed by GitHub
parent 04a49669d1
commit d1b6fe007f
7 changed files with 47 additions and 62 deletions

View File

@@ -5,12 +5,7 @@ import os
import pytest
from vllm.model_executor.layers.pooler import (
CLSPool,
DispatchPooler,
MeanPool,
PoolingType,
)
from vllm.model_executor.layers.pooler import CLSPool, DispatchPooler, MeanPool
from vllm.model_executor.models.bert import BertEmbeddingModel
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
from vllm.platforms import current_platform
@@ -50,7 +45,7 @@ def test_model_loading_with_params(vllm_runner, monkeypatch):
assert model_config.encoder_config["do_lower_case"]
# asserts on the pooling config files
assert model_config.pooler_config.pooling_type == PoolingType.CLS.name
assert model_config.pooler_config.pooling_type == "CLS"
assert model_config.pooler_config.normalize
# asserts on the tokenizer loaded
@@ -94,7 +89,7 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
assert not model_config.encoder_config["do_lower_case"]
# asserts on the pooling config files
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
assert model_config.pooler_config.pooling_type == "MEAN"
assert model_config.pooler_config.normalize
# asserts on the tokenizer loaded

View File

@@ -25,7 +25,6 @@ from vllm.config.vllm import (
OPTIMIZATION_LEVEL_TO_CONFIG,
OptimizationLevel,
)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
@@ -162,7 +161,7 @@ def test_get_pooling_config():
assert model_config.pooler_config is not None
assert model_config.pooler_config.normalize
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
assert model_config.pooler_config.pooling_type == "MEAN"
@pytest.mark.skipif(

View File

@@ -21,8 +21,7 @@ class PoolerConfig:
pooling_type: PoolingTypeStr | None = None
"""
The pooling method of the pooling model. This should be a key in
[`vllm.model_executor.layers.pooler.PoolingType`][].
The pooling method of the pooling model.
"""
## for embeddings models

View File

@@ -3,7 +3,6 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Set
from dataclasses import dataclass
from enum import IntEnum
from itertools import groupby
from typing import TypeAlias, TypeVar
@@ -12,13 +11,14 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
from vllm.config import ModelConfig, get_current_vllm_config
from vllm.config.pooler import PoolerConfig, PoolingTypeStr
from vllm.logger import init_logger
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokensPoolerOutput
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokenwisePoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
logger = init_logger(__name__)
@@ -31,27 +31,17 @@ ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
TokensPoolingMethodOutputItem: TypeAlias = torch.Tensor | None
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokensPoolingMethodOutput
TokenwisePoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
TokenwisePoolingMethodOutputItem: TypeAlias = torch.Tensor | None
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokenwisePoolingMethodOutput
TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolerHeadOutput: TypeAlias = torch.Tensor | None
class PoolingType(IntEnum):
"""Enumeration for different types of pooling methods."""
LAST = 0
ALL = 1
CLS = 2
STEP = 3
MEAN = 4
TokenwisePoolerHeadOutput: TypeAlias = torch.Tensor | None
@dataclass(frozen=True)
class ResolvedPoolingConfig:
pooling_type: PoolingType
pooling_type: PoolingTypeStr
task: PoolingTask
@classmethod
@@ -61,7 +51,7 @@ class ResolvedPoolingConfig:
pooler_config: PoolerConfig,
) -> "ResolvedPoolingConfig":
assert pooler_config.pooling_type is not None
return cls(task=task, pooling_type=PoolingType[pooler_config.pooling_type])
return cls(task=task, pooling_type=pooler_config.pooling_type)
@dataclass(frozen=True)
@@ -112,17 +102,22 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
class PoolingMethod(nn.Module, ABC):
@staticmethod
def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
if pooling_type == PoolingType.LAST:
def from_pooling_type(pooling_type: PoolingTypeStr) -> "PoolingMethod":
if pooling_type == "LAST":
return LastPool()
if pooling_type == PoolingType.ALL:
if pooling_type == "ALL":
return AllPool()
if pooling_type == PoolingType.CLS:
if pooling_type == "CLS":
return CLSPool()
if pooling_type == PoolingType.MEAN:
if pooling_type == "MEAN":
return MeanPool()
if pooling_type == "STEP":
raise ValueError(
"'STEP' pooling is handled by StepPooler "
"and is not a standalone PoolingMethod."
)
raise NotImplementedError(f"Unsupported method: {pooling_type}")
raise NotImplementedError(f"Unsupported method: {pooling_type!r}")
@abstractmethod
def get_supported_tasks(self) -> Set[PoolingTask]:
@@ -186,13 +181,12 @@ class AllPool(PoolingMethod):
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokensPoolingMethodOutput:
) -> TokenwisePoolingMethodOutput:
pooling_cursor = pooling_metadata.get_pooling_cursor()
is_finished = pooling_cursor.is_finished()
hidden_states_lst = list(
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
hidden_states_all = hidden_states.split(
pooling_cursor.num_scheduled_tokens_cpu.tolist()
)
hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index]
hidden_states_lst = [hidden_states_all[i] for i in pooling_cursor.index]
if not self.enable_chunked_prefill:
return hidden_states_lst
@@ -206,7 +200,7 @@ class AllPool(PoolingMethod):
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list = list[torch.Tensor | None]()
for p, finished in zip(pooling_states, is_finished):
for p, finished in zip(pooling_states, pooling_cursor.is_finished()):
if finished:
hidden_states_cache = p.hidden_states_cache
if len(hidden_states_cache) == 1:
@@ -620,19 +614,19 @@ class ClassifierPooler(Pooler):
return scores
class TokensPoolerHead(nn.Module, ABC):
class TokenwisePoolerHead(nn.Module, ABC):
"""Applicable to pooling strategies that output multiple tokens."""
@abstractmethod
def forward(
self,
pooled_data: TokensPoolingMethodOutputItem,
pooled_data: TokenwisePoolingMethodOutputItem,
pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput:
) -> TokenwisePoolerHeadOutput:
raise NotImplementedError
class TokenEmbeddingPoolerHead(TokensPoolerHead):
class TokenEmbeddingPoolerHead(TokenwisePoolerHead):
def __init__(self) -> None:
super().__init__()
@@ -647,9 +641,9 @@ class TokenEmbeddingPoolerHead(TokensPoolerHead):
def forward(
self,
pooled_data: TokensPoolingMethodOutputItem,
pooled_data: TokenwisePoolingMethodOutputItem,
pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput:
) -> TokenwisePoolerHeadOutput:
# for unfinished chunked prefill
if pooled_data is None:
return None
@@ -673,7 +667,7 @@ class TokenEmbeddingPoolerHead(TokensPoolerHead):
return pooled_data
class TokenClassifierPoolerHead(TokensPoolerHead):
class TokenClassifierPoolerHead(TokenwisePoolerHead):
def __init__(
self,
classifier: ClassifierFn | None,
@@ -695,9 +689,9 @@ class TokenClassifierPoolerHead(TokensPoolerHead):
def forward(
self,
pooled_data: TokensPoolingMethodOutputItem,
pooled_data: TokenwisePoolingMethodOutputItem,
pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput:
) -> TokenwisePoolerHeadOutput:
# for unfinished chunked prefill
if pooled_data is None:
return None
@@ -722,7 +716,7 @@ class TokenClassifierPoolerHead(TokensPoolerHead):
class AllPooler(Pooler):
def __init__(self, head: TokensPoolerHead) -> None:
def __init__(self, head: TokenwisePoolerHead) -> None:
super().__init__()
self.pooling = AllPool()
@@ -735,7 +729,7 @@ class AllPooler(Pooler):
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokensPoolerOutput:
) -> TokenwisePoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooling_params = pooling_metadata.pooling_params
assert len(pooled_data) == len(pooling_params)
@@ -744,7 +738,7 @@ class AllPooler(Pooler):
class StepPooler(Pooler):
def __init__(self, head: TokensPoolerHead) -> None:
def __init__(self, head: TokenwisePoolerHead) -> None:
super().__init__()
self.pooling = AllPool()
@@ -790,7 +784,7 @@ class StepPooler(Pooler):
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokensPoolerOutput:
) -> TokenwisePoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooling_params = pooling_metadata.pooling_params
assert len(pooled_data) == len(pooling_params)

View File

@@ -23,7 +23,6 @@ from vllm.model_executor.layers.pooler import (
Pooler,
PoolingMethod,
PoolingParamsUpdate,
PoolingType,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
)
@@ -90,7 +89,7 @@ class BertPooler(Pooler):
def __init__(self, config: BertConfig):
super().__init__()
self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS)
self.pooling = PoolingMethod.from_pooling_type("CLS")
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()

View File

@@ -18,7 +18,6 @@ from vllm.model_executor.layers.pooler import (
Pooler,
PoolingMethod,
PoolingParamsUpdate,
PoolingType,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
)
@@ -287,7 +286,7 @@ class ModernBertPooler(Pooler):
def __init__(self, config: ModernBertConfig):
super().__init__()
pooling_type = PoolingType[config.classifier_pooling.upper()]
pooling_type = config.classifier_pooling.upper()
self.pooling = PoolingMethod.from_pooling_type(pooling_type)
self.dense = nn.Linear(
config.hidden_size, config.hidden_size, config.classifier_bias

View File

@@ -92,8 +92,8 @@ class LogprobsTensors(NamedTuple):
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
TokenPoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
PoolerOutput: TypeAlias = TokenPoolerOutput | TokensPoolerOutput
TokenwisePoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
PoolerOutput: TypeAlias = TokenPoolerOutput | TokenwisePoolerOutput
@dataclass