[Refactor] Clean up pooler modules (#31897)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-08 00:07:43 +08:00
committed by GitHub
parent cc6dafaef2
commit b7036c87a1
7 changed files with 167 additions and 120 deletions

View File

@@ -5,7 +5,7 @@ from collections.abc import Callable, Mapping, Set
from dataclasses import dataclass
from enum import IntEnum
from itertools import groupby
from typing import TypeVar
from typing import TypeAlias, TypeVar
import torch
import torch.nn as nn
@@ -18,8 +18,8 @@ from vllm.model_executor.models.adapters import _load_st_projector
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokensPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
logger = init_logger(__name__)
@@ -30,6 +30,15 @@ PoolingFn = Callable[
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
TokensPoolingMethodOutputItem: TypeAlias = torch.Tensor | None
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokensPoolingMethodOutput
TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolerHeadOutput: TypeAlias = torch.Tensor | None
class PoolingType(IntEnum):
"""Enumeration for different types of pooling methods."""
@@ -123,31 +132,24 @@ class PoolingMethod(nn.Module, ABC):
return PoolingParamsUpdate()
@abstractmethod
def forward_all(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> PoolerOutput:
raise NotImplementedError
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooling_cursor = pooling_metadata.pooling_cursor
return self.forward_all(hidden_states, pooling_cursor)
) -> PoolingMethodOutput:
raise NotImplementedError
class CLSPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all(
def forward(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> PoolerOutput:
pooling_metadata: PoolingMetadata,
) -> TokenPoolingMethodOutput:
pooling_cursor = pooling_metadata.get_pooling_cursor()
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with CLS pooling"
)
@@ -159,11 +161,12 @@ class LastPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all(
def forward(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> PoolerOutput:
pooling_metadata: PoolingMetadata,
) -> TokenPoolingMethodOutput:
pooling_cursor = pooling_metadata.get_pooling_cursor()
return hidden_states[pooling_cursor.last_token_indices_gpu]
@@ -179,19 +182,12 @@ class AllPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify"}
def forward_all(
self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor
) -> PoolerOutput:
raise NotImplementedError(
"forward_all is not implemented for AllPool. Use forward instead."
)
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooling_cursor = pooling_metadata.pooling_cursor
) -> TokensPoolingMethodOutput:
pooling_cursor = pooling_metadata.get_pooling_cursor()
is_finished = pooling_cursor.is_finished()
hidden_states_lst = list(
hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist())
@@ -209,7 +205,7 @@ class AllPool(PoolingMethod):
p.hidden_states_cache.append(hs_chunk)
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list: PoolerOutput = []
output_list = list[torch.Tensor | None]()
for p, finished in zip(pooling_states, is_finished):
if finished:
hidden_states_cache = p.hidden_states_cache
@@ -228,11 +224,12 @@ class MeanPool(PoolingMethod):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_embed", "token_classify", "embed", "classify", "score"}
def forward_all(
def forward(
self,
hidden_states: torch.Tensor,
pooling_cursor: PoolingCursor,
) -> PoolerOutput:
pooling_metadata: PoolingMetadata,
) -> TokenPoolingMethodOutput:
pooling_cursor = pooling_metadata.get_pooling_cursor()
assert not pooling_cursor.is_partial_prefill(), (
"partial prefill not supported with MEAN pooling"
)
@@ -410,7 +407,7 @@ class Pooler(nn.Module, ABC):
@abstractmethod
def forward(
self,
hidden_states: list[torch.Tensor] | torch.Tensor,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
raise NotImplementedError
@@ -422,41 +419,42 @@ class DummyPooler(Pooler):
def forward(
self,
hidden_states: list[torch.Tensor] | torch.Tensor,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
return hidden_states
class PoolerHead(nn.Module):
def __init__(self, activation: PoolerActivation) -> None:
super().__init__()
self.activation = activation
class TokenPoolerHead(nn.Module, ABC):
"""Applicable to pooling strategies that output one token."""
@abstractmethod
def forward(
self,
pooled_data: list[torch.Tensor] | torch.Tensor,
pooled_data: TokenPoolingMethodOutput,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
return self.activation(pooled_data)
) -> TokenPoolerHeadOutput:
raise NotImplementedError
class EmbeddingPoolerHead(PoolerHead):
class EmbeddingPoolerHead(TokenPoolerHead):
def __init__(self) -> None:
super().__init__(activation=PoolerNormalize())
super().__init__()
# Load ST projector if available
vllm_config = get_current_vllm_config()
self.projector: nn.Module | None = (
self.projector = (
_load_st_projector(vllm_config.model_config) if vllm_config else None
)
self.head_dtype = vllm_config.model_config.head_dtype
self.activation = PoolerNormalize()
def forward(
self,
pooled_data: list[torch.Tensor] | torch.Tensor,
pooled_data: TokenPoolingMethodOutput,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
) -> TokenPoolerHeadOutput:
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
# pooled_data shape: [batchsize, hidden_dimension]
@@ -509,7 +507,7 @@ class SimplePooler(Pooler):
3. Returns structured results as `PoolerOutput`.
"""
def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None:
def __init__(self, pooling: PoolingMethod, head: TokenPoolerHead) -> None:
super().__init__()
self.pooling = pooling
@@ -523,9 +521,9 @@ class SimplePooler(Pooler):
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
) -> TokenPoolerHeadOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
@@ -591,9 +589,9 @@ class ClassifierPooler(Pooler):
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
@@ -622,10 +620,36 @@ class ClassifierPooler(Pooler):
return scores
class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
class TokensPoolerHead(nn.Module, ABC):
"""Applicable to pooling strategies that output multiple tokens."""
@abstractmethod
def forward(
self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams
) -> PoolerOutput:
self,
pooled_data: TokensPoolingMethodOutputItem,
pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput:
raise NotImplementedError
class TokenEmbeddingPoolerHead(TokensPoolerHead):
def __init__(self) -> None:
    """Set up the projector, head dtype and activation for token embeddings."""
    super().__init__()

    vllm_config = get_current_vllm_config()

    # Load ST projector if available
    if vllm_config:
        projector = _load_st_projector(vllm_config.model_config)
    else:
        projector = None
    self.projector = projector

    # NOTE(review): the projector lookup above tolerates a missing config,
    # but this attribute access does not — confirm a config is always set.
    self.head_dtype = vllm_config.model_config.head_dtype
    self.activation = PoolerNormalize()
def forward(
self,
pooled_data: TokensPoolingMethodOutputItem,
pooling_param: PoolingParams,
) -> TokensPoolerHeadOutput:
# for unfinished chunked prefill
if pooled_data is None:
return None
@@ -649,57 +673,56 @@ class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
return pooled_data
class TokenClassifierPoolerHead(nn.Module):
class TokenClassifierPoolerHead(TokensPoolerHead):
def __init__(
self,
classifier: ClassifierFn | None,
act_fn: PoolerActivation | str | None = None,
) -> None:
super().__init__()
vllm_config = get_current_vllm_config()
self.classifier = classifier
self.act_fn = ClassifierPooler.resolve_act_fn(
vllm_config.model_config, static_num_labels=False, act_fn=act_fn
)
self.logit_bias: float | None = (
vllm_config.model_config.pooler_config.logit_bias
)
self.head_dtype = vllm_config.model_config.head_dtype
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"token_classify"}
self.activation = ClassifierPooler.resolve_act_fn(
vllm_config.model_config, static_num_labels=False, act_fn=act_fn
)
def forward(
self,
hidden_states: torch.Tensor | None,
pooled_data: TokensPoolingMethodOutputItem,
pooling_param: PoolingParams,
) -> PoolerOutput:
) -> TokensPoolerHeadOutput:
# for unfinished chunked prefill
if hidden_states is None:
if pooled_data is None:
return None
hidden_states = hidden_states.to(self.head_dtype)
pooled_data = pooled_data.to(self.head_dtype)
# pooled_data shape: [n_token, hidden_size]
if self.classifier is not None:
scores = self.classifier(hidden_states)
scores = self.classifier(pooled_data)
else:
scores = hidden_states
scores = pooled_data
# scores shape: [n_token, num_labels]
if self.logit_bias is not None:
scores -= self.logit_bias
if pooling_param.use_activation:
scores = self.act_fn(scores)
scores = self.activation(scores)
# scores shape: [n_token, num_labels]
return scores
class AllPooler(Pooler):
def __init__(self, head: nn.Module | PoolerHead) -> None:
def __init__(self, head: TokensPoolerHead) -> None:
super().__init__()
self.pooling = AllPool()
@@ -712,17 +735,16 @@ class AllPooler(Pooler):
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
) -> TokensPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooling_params = pooling_metadata.pooling_params
assert len(pooled_data) == len(pooling_params)
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
return pooled_data
return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
class StepPooler(Pooler):
def __init__(self, head: nn.Module | PoolerHead) -> None:
def __init__(self, head: TokensPoolerHead) -> None:
super().__init__()
self.pooling = AllPool()
@@ -730,14 +752,14 @@ class StepPooler(Pooler):
def extract_states(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
) -> list[torch.Tensor | None]:
pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
prompt_token_ids = pooling_metadata.get_prompt_token_ids()
pooling_params = pooling_metadata.pooling_params
pooled_data: PoolerOutput = []
pooled_data = list[torch.Tensor | None]()
for data, token_id, pooling_param in zip(
pooled_data_lst, prompt_token_ids, pooling_params
):
@@ -766,15 +788,14 @@ class StepPooler(Pooler):
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
) -> TokensPoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooling_params = pooling_metadata.pooling_params
assert len(pooled_data) == len(pooling_params)
pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
return pooled_data
return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
class DispatchPooler(Pooler):
@@ -800,12 +821,12 @@ class DispatchPooler(Pooler):
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
poolers_by_task = self.poolers_by_task
outputs = list[torch.Tensor]()
outputs = list[torch.Tensor | None]()
offset = 0
for task, group in groupby(pooling_metadata.tasks):
if not (pooler := poolers_by_task.get(task)):

View File

@@ -24,11 +24,14 @@ from vllm.model_executor.layers.pooler import (
PoolingMethod,
PoolingParamsUpdate,
PoolingType,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding, SupportsQuant
@@ -97,24 +100,26 @@ class BertPooler(Pooler):
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def _head(self, pooled_output: torch.Tensor):
pooled_output = self.dense(pooled_output)
pooled_output = self.activation(pooled_output)
return pooled_output
def head(
    self,
    pooled_data: TokenPoolingMethodOutput,
    pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput:
    """Run the pooled output through the dense layer, then the activation.

    When `pooled_data` is a list of tensors it is first combined into one
    tensor with `torch.stack`. `pooling_metadata` is not read by this
    method.
    """
    batch = torch.stack(pooled_data) if isinstance(pooled_data, list) else pooled_data
    return self.activation(self.dense(batch))
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> torch.Tensor | list[torch.Tensor]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
if isinstance(pooled_output, list):
pooled_output = [self._head(output) for output in pooled_output]
else:
pooled_output = self._head(pooled_output)
return pooled_output
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
class BertEncoder(nn.Module):

View File

@@ -4,21 +4,22 @@ from collections.abc import Set
import numpy as np
import torch
import torch.nn as nn
from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import (
DispatchPooler,
Pooler,
PoolerHead,
PoolerNormalize,
PoolingMethod,
PoolingParamsUpdate,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.tasks import PoolingTask
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.v1.outputs import PoolerOutput
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces_base import default_pooling_type
@@ -26,7 +27,7 @@ from .interfaces_base import default_pooling_type
logger = init_logger(__name__)
class GritLMMeanPool(nn.Module):
class GritLMMeanPool(PoolingMethod):
"""As `MeanPool`, but only includes non-instruction tokens."""
def __init__(self, model_config: ModelConfig):
@@ -141,16 +142,16 @@ class GritLMMeanPool(nn.Module):
return instruction_len
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"encode", "embed"}
return {"embed"}
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return PoolingParamsUpdate(requires_token_ids=True)
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> list[torch.Tensor] | torch.Tensor:
) -> TokenPoolingMethodOutput:
prompt_lens = pooling_metadata.prompt_lens
instr_lens = torch.tensor(
[
@@ -178,7 +179,7 @@ class GritLMPooler(Pooler):
super().__init__()
self.pooling = GritLMMeanPool(model_config)
self.head = PoolerHead(PoolerNormalize())
self.activation = PoolerNormalize()
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks()
@@ -186,11 +187,18 @@ class GritLMPooler(Pooler):
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def head(
    self,
    pooled_data: TokenPoolingMethodOutput,
    pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput:
    """Apply this pooler's activation to the pooled data.

    `pooling_metadata` is not read by this method.
    """
    activated = self.activation(pooled_data)
    return activated
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data

View File

@@ -19,12 +19,15 @@ from vllm.model_executor.layers.pooler import (
PoolingMethod,
PoolingParamsUpdate,
PoolingType,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding
@@ -300,23 +303,25 @@ class ModernBertPooler(Pooler):
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def _head(self, pooled_output: torch.Tensor):
pooled_output = pooled_output.to(self.dense.weight.dtype)
return self.norm(self.act(self.dense(pooled_output)))
def head(
    self,
    pooled_data: TokenPoolingMethodOutput,
    pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput:
    """Apply dense -> act -> norm to the pooled output.

    A list input is combined into one tensor with `torch.stack`, and the
    result is cast to the dtype of the dense layer's weight before the
    layers run. `pooling_metadata` is not read by this method.
    """
    batch = torch.stack(pooled_data) if isinstance(pooled_data, list) else pooled_data

    # Match the dense layer's parameter dtype before projecting.
    batch = batch.to(self.dense.weight.dtype)

    projected = self.dense(batch)
    activated = self.act(projected)
    return self.norm(activated)
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> torch.Tensor | list[torch.Tensor]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
if isinstance(pooled_output, list):
pooled_output = [self._head(output) for output in pooled_output]
else:
pooled_output = self._head(pooled_output)
return pooled_output
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
@default_pooling_type("CLS")

View File

@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, NamedTuple, TypeAlias
import numpy as np
import torch
@@ -91,7 +91,9 @@ class LogprobsTensors(NamedTuple):
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
PoolerOutput = list[torch.Tensor | None] | torch.Tensor | None
TokenPoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokensPoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
PoolerOutput: TypeAlias = TokenPoolerOutput | TokensPoolerOutput
@dataclass

View File

@@ -90,6 +90,12 @@ class PoolingMetadata:
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
def get_pooling_cursor(self) -> PoolingCursor:
    """Return `self.pooling_cursor`, asserting that it is not None.

    The assertion message directs the caller to `build_pooling_cursor`.
    """
    cursor = self.pooling_cursor
    assert cursor is not None, "Should call `build_pooling_cursor` first"
    return cursor
def build_pooling_cursor(
self,
num_scheduled_tokens_np: np.ndarray,

View File

@@ -4680,7 +4680,7 @@ class GPUModelRunner(
for task in supported_pooling_tasks:
# Run a full batch with each task to ensure none of them OOMs
output = self._dummy_pooler_run_task(hidden_states, task)
output_size[task] = sum(o.nbytes for o in output)
output_size[task] = sum(o.nbytes for o in output if o is not None)
del output # Allow GC
max_task = max(output_size.items(), key=lambda x: x[1])[0]