[Chore] Further cleanup pooler (#31951)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -5,12 +5,7 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
CLSPool,
|
||||
DispatchPooler,
|
||||
MeanPool,
|
||||
PoolingType,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler import CLSPool, DispatchPooler, MeanPool
|
||||
from vllm.model_executor.models.bert import BertEmbeddingModel
|
||||
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
|
||||
from vllm.platforms import current_platform
|
||||
@@ -50,7 +45,7 @@ def test_model_loading_with_params(vllm_runner, monkeypatch):
|
||||
assert model_config.encoder_config["do_lower_case"]
|
||||
|
||||
# asserts on the pooling config files
|
||||
assert model_config.pooler_config.pooling_type == PoolingType.CLS.name
|
||||
assert model_config.pooler_config.pooling_type == "CLS"
|
||||
assert model_config.pooler_config.normalize
|
||||
|
||||
# asserts on the tokenizer loaded
|
||||
@@ -94,7 +89,7 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
|
||||
assert not model_config.encoder_config["do_lower_case"]
|
||||
|
||||
# asserts on the pooling config files
|
||||
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
|
||||
assert model_config.pooler_config.pooling_type == "MEAN"
|
||||
assert model_config.pooler_config.normalize
|
||||
|
||||
# asserts on the tokenizer loaded
|
||||
|
||||
@@ -25,7 +25,6 @@ from vllm.config.vllm import (
|
||||
OPTIMIZATION_LEVEL_TO_CONFIG,
|
||||
OptimizationLevel,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler import PoolingType
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@@ -162,7 +161,7 @@ def test_get_pooling_config():
|
||||
|
||||
assert model_config.pooler_config is not None
|
||||
assert model_config.pooler_config.normalize
|
||||
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
|
||||
assert model_config.pooler_config.pooling_type == "MEAN"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
||||
@@ -21,8 +21,7 @@ class PoolerConfig:
|
||||
|
||||
pooling_type: PoolingTypeStr | None = None
|
||||
"""
|
||||
The pooling method of the pooling model. This should be a key in
|
||||
[`vllm.model_executor.layers.pooler.PoolingType`][].
|
||||
The pooling method of the pooling model.
|
||||
"""
|
||||
|
||||
## for embeddings models
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Mapping, Set
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from itertools import groupby
|
||||
from typing import TypeAlias, TypeVar
|
||||
|
||||
@@ -12,13 +11,14 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
|
||||
from vllm.config import ModelConfig, get_current_vllm_config
|
||||
from vllm.config.pooler import PoolerConfig, PoolingTypeStr
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.adapters import _load_st_projector
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokensPoolerOutput
|
||||
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokenwisePoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -31,27 +31,17 @@ ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
|
||||
|
||||
|
||||
TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
TokensPoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
|
||||
TokensPoolingMethodOutputItem: TypeAlias = torch.Tensor | None
|
||||
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokensPoolingMethodOutput
|
||||
TokenwisePoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
|
||||
TokenwisePoolingMethodOutputItem: TypeAlias = torch.Tensor | None
|
||||
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokenwisePoolingMethodOutput
|
||||
|
||||
TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
TokensPoolerHeadOutput: TypeAlias = torch.Tensor | None
|
||||
|
||||
|
||||
class PoolingType(IntEnum):
|
||||
"""Enumeration for different types of pooling methods."""
|
||||
|
||||
LAST = 0
|
||||
ALL = 1
|
||||
CLS = 2
|
||||
STEP = 3
|
||||
MEAN = 4
|
||||
TokenwisePoolerHeadOutput: TypeAlias = torch.Tensor | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedPoolingConfig:
|
||||
pooling_type: PoolingType
|
||||
pooling_type: PoolingTypeStr
|
||||
task: PoolingTask
|
||||
|
||||
@classmethod
|
||||
@@ -61,7 +51,7 @@ class ResolvedPoolingConfig:
|
||||
pooler_config: PoolerConfig,
|
||||
) -> "ResolvedPoolingConfig":
|
||||
assert pooler_config.pooling_type is not None
|
||||
return cls(task=task, pooling_type=PoolingType[pooler_config.pooling_type])
|
||||
return cls(task=task, pooling_type=pooler_config.pooling_type)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -112,17 +102,22 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
|
||||
|
||||
class PoolingMethod(nn.Module, ABC):
|
||||
@staticmethod
|
||||
def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod":
|
||||
if pooling_type == PoolingType.LAST:
|
||||
def from_pooling_type(pooling_type: PoolingTypeStr) -> "PoolingMethod":
|
||||
if pooling_type == "LAST":
|
||||
return LastPool()
|
||||
if pooling_type == PoolingType.ALL:
|
||||
if pooling_type == "ALL":
|
||||
return AllPool()
|
||||
if pooling_type == PoolingType.CLS:
|
||||
if pooling_type == "CLS":
|
||||
return CLSPool()
|
||||
if pooling_type == PoolingType.MEAN:
|
||||
if pooling_type == "MEAN":
|
||||
return MeanPool()
|
||||
if pooling_type == "STEP":
|
||||
raise ValueError(
|
||||
"'STEP' pooling is handled by StepPooler "
|
||||
"and is not a standalone PoolingMethod."
|
||||
)
|
||||
|
||||
raise NotImplementedError(f"Unsupported method: {pooling_type}")
|
||||
raise NotImplementedError(f"Unsupported method: {pooling_type!r}")
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
@@ -186,13 +181,12 @@ class AllPool(PoolingMethod):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokensPoolingMethodOutput:
|
||||
) -> TokenwisePoolingMethodOutput:
|
||||
pooling_cursor = pooling_metadata.get_pooling_cursor()
|
||||
is_finished = pooling_cursor.is_finished()
|
||||
hidden_states_lst = list(
|
||||
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
|
||||
hidden_states_all = hidden_states.split(
|
||||
pooling_cursor.num_scheduled_tokens_cpu.tolist()
|
||||
)
|
||||
hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index]
|
||||
hidden_states_lst = [hidden_states_all[i] for i in pooling_cursor.index]
|
||||
|
||||
if not self.enable_chunked_prefill:
|
||||
return hidden_states_lst
|
||||
@@ -206,7 +200,7 @@ class AllPool(PoolingMethod):
|
||||
|
||||
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
|
||||
output_list = list[torch.Tensor | None]()
|
||||
for p, finished in zip(pooling_states, is_finished):
|
||||
for p, finished in zip(pooling_states, pooling_cursor.is_finished()):
|
||||
if finished:
|
||||
hidden_states_cache = p.hidden_states_cache
|
||||
if len(hidden_states_cache) == 1:
|
||||
@@ -620,19 +614,19 @@ class ClassifierPooler(Pooler):
|
||||
return scores
|
||||
|
||||
|
||||
class TokensPoolerHead(nn.Module, ABC):
|
||||
class TokenwisePoolerHead(nn.Module, ABC):
|
||||
"""Applicable to pooling strategies that output multiple tokens."""
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
pooled_data: TokensPoolingMethodOutputItem,
|
||||
pooled_data: TokenwisePoolingMethodOutputItem,
|
||||
pooling_param: PoolingParams,
|
||||
) -> TokensPoolerHeadOutput:
|
||||
) -> TokenwisePoolerHeadOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TokenEmbeddingPoolerHead(TokensPoolerHead):
|
||||
class TokenEmbeddingPoolerHead(TokenwisePoolerHead):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -647,9 +641,9 @@ class TokenEmbeddingPoolerHead(TokensPoolerHead):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pooled_data: TokensPoolingMethodOutputItem,
|
||||
pooled_data: TokenwisePoolingMethodOutputItem,
|
||||
pooling_param: PoolingParams,
|
||||
) -> TokensPoolerHeadOutput:
|
||||
) -> TokenwisePoolerHeadOutput:
|
||||
# for unfinished chunked prefill
|
||||
if pooled_data is None:
|
||||
return None
|
||||
@@ -673,7 +667,7 @@ class TokenEmbeddingPoolerHead(TokensPoolerHead):
|
||||
return pooled_data
|
||||
|
||||
|
||||
class TokenClassifierPoolerHead(TokensPoolerHead):
|
||||
class TokenClassifierPoolerHead(TokenwisePoolerHead):
|
||||
def __init__(
|
||||
self,
|
||||
classifier: ClassifierFn | None,
|
||||
@@ -695,9 +689,9 @@ class TokenClassifierPoolerHead(TokensPoolerHead):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pooled_data: TokensPoolingMethodOutputItem,
|
||||
pooled_data: TokenwisePoolingMethodOutputItem,
|
||||
pooling_param: PoolingParams,
|
||||
) -> TokensPoolerHeadOutput:
|
||||
) -> TokenwisePoolerHeadOutput:
|
||||
# for unfinished chunked prefill
|
||||
if pooled_data is None:
|
||||
return None
|
||||
@@ -722,7 +716,7 @@ class TokenClassifierPoolerHead(TokensPoolerHead):
|
||||
|
||||
|
||||
class AllPooler(Pooler):
|
||||
def __init__(self, head: TokensPoolerHead) -> None:
|
||||
def __init__(self, head: TokenwisePoolerHead) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pooling = AllPool()
|
||||
@@ -735,7 +729,7 @@ class AllPooler(Pooler):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokensPoolerOutput:
|
||||
) -> TokenwisePoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
assert len(pooled_data) == len(pooling_params)
|
||||
@@ -744,7 +738,7 @@ class AllPooler(Pooler):
|
||||
|
||||
|
||||
class StepPooler(Pooler):
|
||||
def __init__(self, head: TokensPoolerHead) -> None:
|
||||
def __init__(self, head: TokenwisePoolerHead) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pooling = AllPool()
|
||||
@@ -790,7 +784,7 @@ class StepPooler(Pooler):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokensPoolerOutput:
|
||||
) -> TokenwisePoolerOutput:
|
||||
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
assert len(pooled_data) == len(pooling_params)
|
||||
|
||||
@@ -23,7 +23,6 @@ from vllm.model_executor.layers.pooler import (
|
||||
Pooler,
|
||||
PoolingMethod,
|
||||
PoolingParamsUpdate,
|
||||
PoolingType,
|
||||
TokenPoolerHeadOutput,
|
||||
TokenPoolingMethodOutput,
|
||||
)
|
||||
@@ -90,7 +89,7 @@ class BertPooler(Pooler):
|
||||
def __init__(self, config: BertConfig):
|
||||
super().__init__()
|
||||
|
||||
self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS)
|
||||
self.pooling = PoolingMethod.from_pooling_type("CLS")
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from vllm.model_executor.layers.pooler import (
|
||||
Pooler,
|
||||
PoolingMethod,
|
||||
PoolingParamsUpdate,
|
||||
PoolingType,
|
||||
TokenPoolerHeadOutput,
|
||||
TokenPoolingMethodOutput,
|
||||
)
|
||||
@@ -287,7 +286,7 @@ class ModernBertPooler(Pooler):
|
||||
def __init__(self, config: ModernBertConfig):
|
||||
super().__init__()
|
||||
|
||||
pooling_type = PoolingType[config.classifier_pooling.upper()]
|
||||
pooling_type = config.classifier_pooling.upper()
|
||||
self.pooling = PoolingMethod.from_pooling_type(pooling_type)
|
||||
self.dense = nn.Linear(
|
||||
config.hidden_size, config.hidden_size, config.classifier_bias
|
||||
|
||||
@@ -92,8 +92,8 @@ class LogprobsTensors(NamedTuple):
|
||||
# [num_reqs, <dynamic>]
|
||||
# The shape of each element depends on the pooler used
|
||||
TokenPoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
TokensPoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
|
||||
PoolerOutput: TypeAlias = TokenPoolerOutput | TokensPoolerOutput
|
||||
TokenwisePoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
|
||||
PoolerOutput: TypeAlias = TokenPoolerOutput | TokenwisePoolerOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user