[Refactor] Clean up pooler modules (#31897)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-08 00:07:43 +08:00
committed by GitHub
parent cc6dafaef2
commit b7036c87a1
7 changed files with 167 additions and 120 deletions

View File

@@ -5,7 +5,7 @@ from collections.abc import Callable, Mapping, Set
from dataclasses import dataclass
from enum import IntEnum
from itertools import groupby
from typing import TypeVar
from typing import TypeAlias, TypeVar
import torch
import torch.nn as nn
@@ -18,8 +18,8 @@ from vllm.model_executor.models.adapters import _load_st_projector
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokensPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
logger = init_logger(__name__)
@@ -30,6 +30,15 @@ PoolingFn = Callable[
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
TokensPoolingMethodOutputItem: TypeAlias = torch.Tensor | None
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokensPoolingMethodOutput
TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolerHeadOutput: TypeAlias = torch.Tensor | None
class PoolingType(IntEnum):
"""Enumeration for different types of pooling methods."""
@@ -123,31 +132,24 @@ class PoolingMethod(nn.Module, ABC):
return PoolingParamsUpdate()
@abstractmethod
def forward_all(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> PoolerOutput:
raise NotImplementedError
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooling_cursor = pooling_metadata.pooling_cursor
return self.forward_all(hidden_states, pooling_cursor)
) -> PoolingMethodOutput:
raise NotImplementedError
class CLSPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all(
def forward(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> PoolerOutput:
pooling_metadata: PoolingMetadata,
) -> TokenPoolingMethodOutput:
pooling_cursor = pooling_metadata.get_pooling_cursor()
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with CLS pooling"
)
@@ -159,11 +161,12 @@ class LastPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all(
def forward(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> PoolerOutput:
pooling_metadata: PoolingMetadata,
) -> TokenPoolingMethodOutput:
pooling_cursor = pooling_metadata.get_pooling_cursor()
return hidden_states[pooling_cursor.last_token_indices_gpu]
@@ -179,19 +182,12 @@ class AllPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify"}
def forward_all(
self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
) -> PoolerOutput:
raise NotImplementedError(
"forward_all is not implemented for AllPool. Use forward instead."
)
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooling_cursor = pooling_metadata.pooling_cursor
) -> TokensPoolingMethodOutput:
pooling_cursor = pooling_metadata.get_pooling_cursor()
is_finished = pooling_cursor.is_finished()
hidden_states_lst = list(
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
@@ -209,7 +205,7 @@ class AllPool(PoolingMethod):
p.hidden_states_cache.append(hs_chunk)
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list: PoolerOutput = []
output_list = list[torch.Tensor | None]()
for p, finished in zip(pooling_states, is_finished):
if finished:
hidden_states_cache = p.hidden_states_cache
@@ -228,11 +224,12 @@ class MeanPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all(
def forward(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> PoolerOutput:
pooling_metadata: PoolingMetadata,
) -> TokenPoolingMethodOutput:
pooling_cursor = pooling_metadata.get_pooling_cursor()
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with MEAN pooling"
)
@@ -410,7 +407,7 @@ class Pooler(nn.Module, ABC):
@abstractmethod
def forward(
self,
hidden_states: list[torch.Tensor] | torch.Tensor,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
raise NotImplementedError
@@ -422,41 +419,42 @@ class DummyPooler(Pooler):
def forward(
self,
hidden_states: list[torch.Tensor] | torch.Tensor,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
return hidden_states
class PoolerHead(nn.Module):
def __init__(self, activation: PoolerActivation) -> None:
super().__init__()
self.activation = activation
class TokenPoolerHead(nn.Module, ABC):
"""Applicable to pooling strategies that output one token."""
@abstractmethod
def forward(
self,
pooled_data: list[torch.Tensor] | torch.Tensor,
pooled_data: TokenPoolingMethodOutput,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
return self.activation(pooled_data)
) -> TokenPoolerHeadOutput:
raise NotImplementedError
class EmbeddingPoolerHead(PoolerHead):
class EmbeddingPoolerHead(TokenPoolerHead):
def __init__(self) -> None:
super().__init__(activation=PoolerNormalize())
super().__init__()
# Load ST projector if available
vllm_config = get_current_vllm_config()
self.projector: nn.Module | None = (
self.projector = (
_load_st_projector(vllm_config.model_config) if vllm_config else None
)
self.head_dtype = vllm_config.model_config.head_dtype
self.activation = PoolerNormalize()
def forward(
self,
pooled_data: list[torch.Tensor] | torch.Tensor,
pooled_data: TokenPoolingMethodOutput,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
) -> TokenPoolerHeadOutput:
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension]
@@ -509,7 +507,7 @@ class SimplePooler(Pooler):
3. Returns structured results as `PoolerOutput`.
"""
def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
def __init__(self, pooling: PoolingMethod, head: TokenPoolerHead) -> None:
super().__init__()
self.pooling = pooling
@@ -523,9 +521,9 @@ class SimplePooler(Pooler):
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
) -> TokenPoolerHeadOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
@@ -591,9 +589,9 @@ class ClassifierPooler(Pooler):
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
@@ -622,10 +620,36 @@ class ClassifierPooler(Pooler):
return scores
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
class TokensPoolerHead(nn.Module, ABC):
"""Applicable to pooling strategies that output multiple tokens."""
@abstractmethod
def forward(
self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
) -> PoolerOutput:
self,
pooled_data: TokensPoolingMethodOutputItem,
pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput:
raise NotImplementedError
class TokenEmbeddingPoolerHead(TokensPoolerHead):
def __init__(self) -> None:
super().__init__()
# Load ST projector if available
vllm_config = get_current_vllm_config()
self.projector = (
_load_st_projector(vllm_config.model_config) if vllm_config else None
)
self.head_dtype = vllm_config.model_config.head_dtype
self.activation = PoolerNormalize()
def forward(
self,
pooled_data: TokensPoolingMethodOutputItem,
pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput:
# for unfinished chunked prefill
if pooled_data is None:
return None
@@ -649,57 +673,56 @@ class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
return pooled_data
class TokenClassifierPoolerHead(nn.Module):
class TokenClassifierPoolerHead(TokensPoolerHead):
def __init__(
self,
classifier: ClassifierFn | None,
act_fn: PoolerActivation | str | None = None,
) -> None:
super().__init__()
vllm_config = get_current_vllm_config()
self.classifier = classifier
self.act_fn = ClassifierPooler.resolve_act_fn(
vllm_config.model_config, static_num_labels=False, act_fn=act_fn
)
self.logit_bias: float | None = (
vllm_config.model_config.pooler_config.logit_bias
)
self.head_dtype = vllm_config.model_config.head_dtype
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_classify"}
self.activation = ClassifierPooler.resolve_act_fn(
vllm_config.model_config, static_num_labels=False, act_fn=act_fn
)
def forward(
self,
hidden_states: torch.Tensor | None,
pooled_data: TokensPoolingMethodOutputItem,
pooling_param: PoolingParams,
) -> PoolerOutput:
) -> TokensPoolerHeadOutput:
# for unfinished chunked prefill
if hidden_states is None:
if pooled_data is None:
return None
hidden_states = hidden_states.to(self.head_dtype)
pooled_data = pooled_data.to(self.head_dtype)
# pooled_data shape: [n_token, hidden_size]
if self.classifier is not None:
scores = self.classifier(hidden_states)
scores = self.classifier(pooled_data)
else:
scores = hidden_states
scores = pooled_data
# scores shape: [n_token, num_labels]
if self.logit_bias is not None:
scores -= self.logit_bias
if pooling_param.use_activation:
scores = self.act_fn(scores)
scores = self.activation(scores)
# scores shape: [n_token, num_labels]
return scores
class AllPooler(Pooler):
def __init__(self, head: nn.Module | PoolerHead) -> None:
def __init__(self, head: TokensPoolerHead) -> None:
super().__init__()
self.pooling = AllPool()
@@ -712,17 +735,16 @@ class AllPooler(Pooler):
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
) -> TokensPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooling_params = pooling_metadata.pooling_params
assert len(pooled_data) == len(pooling_params)
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
return pooled_data
return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
class StepPooler(Pooler):
def __init__(self, head: nn.Module | PoolerHead) -> None:
def __init__(self, head: TokensPoolerHead) -> None:
super().__init__()
self.pooling = AllPool()
@@ -730,14 +752,14 @@ class StepPooler(Pooler):
def extract_states(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
) -> list[torch.Tensor | None]:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = pooling_metadata.get_prompt_token_ids()
pooling_params = pooling_metadata.pooling_params
pooled_data: PoolerOutput = []
pooled_data = list[torch.Tensor | None]()
for data, token_id, pooling_param in zip(
pooled_data_lst, prompt_token_ids, pooling_params
):
@@ -766,15 +788,14 @@ class StepPooler(Pooler):
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
) -> TokensPoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooling_params = pooling_metadata.pooling_params
assert len(pooled_data) == len(pooling_params)
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
return pooled_data
return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
class DispatchPooler(Pooler):
@@ -800,12 +821,12 @@ class DispatchPooler(Pooler):
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
poolers_by_task = self.poolers_by_task
outputs = list[torch.Tensor]()
outputs = list[torch.Tensor | None]()
offset = 0
for task, group in groupby(pooling_metadata.tasks):
if not (pooler := poolers_by_task.get(task)):