[Model] Reorganize pooling layers (#31973)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-09 19:02:14 +08:00
committed by GitHub
parent 020732800c
commit c8ed39b9dd
34 changed files with 1290 additions and 1143 deletions

2
.github/CODEOWNERS vendored
View File

@@ -153,7 +153,7 @@ mkdocs.yaml @hmellor
/vllm/entrypoints/pooling @noooop
/vllm/config/pooler.py @noooop
/vllm/pooling_params.py @noooop
/vllm/model_executor/layers/pooler.py @noooop
/vllm/model_executor/layers/pooler @noooop
# Security guide and policies
/docs/usage/security.md @russellb

View File

@@ -5,7 +5,8 @@ import os
import pytest
from vllm.model_executor.layers.pooler import CLSPool, DispatchPooler, MeanPool
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.pooler.seqwise import CLSPool, MeanPool
from vllm.model_executor.models.bert import BertEmbeddingModel
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
from vllm.platforms import current_platform

View File

@@ -7,7 +7,7 @@ import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.models.gemma2 import Gemma2Model
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.sequence import IntermediateTensors
@@ -28,12 +28,7 @@ class MyGemma2Embedding(nn.Module):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
self.pooler = DispatchPooler.for_embedding(pooler_config)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors

View File

@@ -88,6 +88,10 @@ class PoolerConfig:
# raise deprecated warning for softmax and activation
self.use_activation = get_use_activation(self)
def get_pooling_type(self) -> PoolingTypeStr:
    """Return the resolved pooling type.

    `pooling_type` may be unset at construction; ModelConfig is expected to
    resolve it before this accessor is used.
    """
    assert self.pooling_type is not None, "Should be resolved by ModelConfig"
    return self.pooling_type
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,

View File

@@ -1,845 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Set
from dataclasses import dataclass
from itertools import groupby
from typing import TypeAlias, TypeVar
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config import ModelConfig, get_current_vllm_config
from vllm.config.pooler import PoolerConfig, PoolingTypeStr
from vllm.logger import init_logger
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokenwisePoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
logger = init_logger(__name__)

# Signature shared by pooling callables: (hidden states, metadata) -> pooled.
PoolingFn = Callable[
    [torch.Tensor | list[torch.Tensor], PoolingMetadata],
    torch.Tensor | list[torch.Tensor],
]
# Maps pooled hidden states to classification logits.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]

# Outputs of pooling methods that emit one vector per request.
TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
# Outputs of pooling methods that emit per-token tensors; a None entry marks
# a request whose chunked prefill has not finished yet.
TokenwisePoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
TokenwisePoolingMethodOutputItem: TypeAlias = torch.Tensor | None
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokenwisePoolingMethodOutput
TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokenwisePoolerHeadOutput: TypeAlias = torch.Tensor | None
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Pooling configuration after the pooling type has been resolved."""

    pooling_type: PoolingTypeStr
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Build from a `PoolerConfig` whose `pooling_type` is already set."""
        pooling_type = pooler_config.pooling_type
        assert pooling_type is not None
        return cls(pooling_type=pooling_type, task=task)
@dataclass(frozen=True)
class PoolingParamsUpdate:
    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def apply(self, params: PoolingParams) -> None:
        """Propagate this update onto the request's pooling parameters."""
        params.requires_token_ids = self.requires_token_ids
def get_classification_activation_function(config: PretrainedConfig):
    """Pick the activation matching HF's ForSequenceClassification loss setup."""
    # Implement alignment with transformers ForSequenceClassificationLoss
    # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
    activations = {
        "regression": PoolerIdentity,
        "single_label_classification": PoolerClassify,
        "multi_label_classification": PoolerMultiLabelClassify,
    }
    problem_type = getattr(config, "problem_type", "")
    # Unknown/unset problem types fall back to the generic classifier.
    return activations.get(problem_type, PoolerClassify)()
def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Resolve the activation declared by a sentence-transformers config."""
    function_name: str | None = None
    if (
        hasattr(config, "sentence_transformers")
        and "activation_fn" in config.sentence_transformers
    ):
        function_name = config.sentence_transformers["activation_fn"]
    elif (
        hasattr(config, "sbert_ce_default_activation_function")
        and config.sbert_ce_default_activation_function is not None
    ):
        function_name = config.sbert_ce_default_activation_function
    if function_name is not None:
        # Only allow importing from torch.nn.modules so a (potentially
        # untrusted) config cannot name arbitrary code to execute.
        assert function_name.startswith("torch.nn.modules."), (
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons"
        )
        fn = resolve_obj_by_qualname(function_name)()
        return PoolerActivation.wraps(fn)
    # Default for cross-encoders without an explicit activation.
    return PoolerClassify()
class PoolingMethod(nn.Module, ABC):
    """Base class for strategies that reduce hidden states per request."""

    @staticmethod
    def from_pooling_type(pooling_type: PoolingTypeStr) -> "PoolingMethod":
        """Instantiate the pooling method registered for `pooling_type`."""
        if pooling_type == "STEP":
            # STEP is built by StepPooler on top of AllPool, not here.
            raise ValueError(
                "'STEP' pooling is handled by StepPooler "
                "and is not a standalone PoolingMethod."
            )
        factories = {
            "LAST": LastPool,
            "ALL": AllPool,
            "CLS": CLSPool,
            "MEAN": MeanPool,
        }
        if pooling_type not in factories:
            raise NotImplementedError(f"Unsupported method: {pooling_type!r}")
        return factories[pooling_type]()

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Tasks this pooling method can serve."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Per-task pooling-parameter update; the default is a no-op."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolingMethodOutput:
        raise NotImplementedError
class CLSPool(PoolingMethod):
    """Pool by taking the hidden state of each prompt's first ([CLS]) token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        # The first token is only present in the first prefill chunk.
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )
        return hidden_states[pooling_cursor.first_token_indices_gpu]
class LastPool(PoolingMethod):
    """Pool by taking the hidden state of each prompt's final token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        cursor = pooling_metadata.get_pooling_cursor()
        last_indices = cursor.last_token_indices_gpu
        return hidden_states[last_indices]
class AllPool(PoolingMethod):
    """Pool by returning every token's hidden state for each request."""

    def __init__(self):
        super().__init__()
        vllm_config = get_current_vllm_config()
        # With chunked prefill, a request's tokens may arrive over several
        # scheduler steps, so outputs must be buffered until prefill finishes.
        self.enable_chunked_prefill = (
            vllm_config.scheduler_config.enable_chunked_prefill
        )

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenwisePoolingMethodOutput:
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        # Split the packed hidden states back into per-request chunks.
        hidden_states_all = hidden_states.split(
            pooling_cursor.num_scheduled_tokens_cpu.tolist()
        )
        # Reorder chunks to match the cursor's request order.
        hidden_states_lst = [hidden_states_all[i] for i in pooling_cursor.index]
        if not self.enable_chunked_prefill:
            return hidden_states_lst
        pooling_states = pooling_metadata.pooling_states
        # If chunked_prefill is enabled
        # 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
        for p, hs_chunk in zip(pooling_states, hidden_states_lst):
            p.hidden_states_cache.append(hs_chunk)
        # 2. Once prefill is finished, send hidden_states_cache to PoolerHead
        output_list = list[torch.Tensor | None]()
        for p, finished in zip(pooling_states, pooling_cursor.is_finished()):
            if finished:
                hidden_states_cache = p.hidden_states_cache
                if len(hidden_states_cache) == 1:
                    # Single chunk: avoid an unnecessary concat/copy.
                    output_list.append(hidden_states_cache[0])
                else:
                    output_list.append(torch.concat(hidden_states_cache, dim=0))
                # Release the cache once the full sequence has been emitted.
                p.clean()
            else:
                # None marks a request that is still prefilling.
                output_list.append(None)
        return output_list
class MeanPool(PoolingMethod):
    """Pool by averaging hidden states over each prompt's tokens."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )
        prompt_lens = pooling_cursor.prompt_lens_cpu.to(
            hidden_states.device, non_blocking=True
        )
        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
        start_indices = pooling_cursor.first_token_indices_gpu
        end_indices = pooling_cursor.last_token_indices_gpu
        # Per-prompt sum via cumsum differences, inclusive of both endpoints
        # (hence adding back hidden_states[start]), divided by prompt length.
        return (
            cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices]
        ) / prompt_lens.unsqueeze(1)
# Activations accept either a single stacked tensor or a per-request list.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])


class BasePoolerActivation(nn.Module, ABC):
    """Base interface for activations applied to pooled outputs."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        # (batch_size, dimensions) or list(dimensions) if using MRL
        raise NotImplementedError
class PoolerActivation(BasePoolerActivation):
    """Activation that maps a per-chunk transform over tensor or list input."""

    @staticmethod
    def wraps(module: nn.Module):
        # Map well-known torch modules onto their vLLM equivalents.
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()
        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        # Apply the chunk transform elementwise for list inputs.
        if isinstance(pooled_data, list):
            return [self.forward_chunk(data) for data in pooled_data]
        return self.forward_chunk(pooled_data)
class PoolerIdentity(PoolerActivation):
    """No-op activation (used e.g. for regression outputs)."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return pooled_data


class PoolerNormalize(PoolerActivation):
    """L2-normalizes along the last dimension (for embeddings)."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.normalize(pooled_data, p=2, dim=-1)


class PoolerMultiLabelClassify(PoolerActivation):
    """Element-wise sigmoid for multi-label classification."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.sigmoid(pooled_data)
class PoolerClassify(PoolerActivation):
    """Classification activation: sigmoid for <2 labels, softmax otherwise.

    With `static_num_labels=True`, the label count is read once from the
    model's HF config; otherwise it is inferred per input from the trailing
    dimension in `forward_chunk`.
    """

    def __init__(self, *, static_num_labels: bool = True) -> None:
        super().__init__()
        if static_num_labels:
            vllm_config = get_current_vllm_config()
            self.num_labels = getattr(
                vllm_config.model_config.hf_config, "num_labels", 0
            )
            if self.num_labels == 0:
                # Fix: the original message was missing a space after
                # "classification", logging "classificationmodels".
                logger.warning(
                    "num_labels should be > 0 for classification "
                    "models, falling back to softmax. "
                    "Please check if the configuration is correct."
                )
        else:
            # Resolved dynamically per input in forward_chunk.
            self.num_labels = None

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        num_labels = (
            self.num_labels if self.num_labels is not None else pooled_data.shape[-1]
        )
        # A single logit is treated as a binary (sigmoid) output.
        if num_labels < 2:
            return F.sigmoid(pooled_data)
        return F.softmax(pooled_data, dim=-1)
class LambdaPoolerActivation(PoolerActivation):
    """Wraps an arbitrary tensor -> tensor callable as an activation."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return self.fn(pooled_data)
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def for_token_embed(pooler_config: PoolerConfig):
        """Construct a pooler producing per-token embeddings."""
        head = TokenEmbeddingPoolerHead()
        if pooler_config.pooling_type == "STEP":
            return StepPooler(head=head)
        return AllPooler(head=head)

    @staticmethod
    def for_token_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Construct a pooler producing per-token classification scores."""
        head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
        if pooler_config.pooling_type == "STEP":
            return StepPooler(head=head)
        return AllPooler(head=head)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Construct a pooler producing one embedding per sequence."""
        resolved_config = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )
        pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
        head = EmbeddingPoolerHead()
        return SimplePooler(pooling=pooling, head=head)

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Construct a pooler producing one classification output per sequence."""
        resolved_config = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )
        pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type)
        return ClassifierPooler(
            pooling=pooling,
            classifier=classifier,
            act_fn=act_fn,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        raise NotImplementedError
class DummyPooler(Pooler):
    """Pass-through pooler: returns hidden states unchanged."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"plugin", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        return hidden_states
class TokenPoolerHead(nn.Module, ABC):
    """Applicable to pooling strategies that output one token."""

    @abstractmethod
    def forward(
        self,
        pooled_data: TokenPoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        raise NotImplementedError
class EmbeddingPoolerHead(TokenPoolerHead):
    """Projects, truncates (matryoshka), and optionally normalizes embeddings."""

    def __init__(self) -> None:
        super().__init__()
        # Load ST projector if available
        vllm_config = get_current_vllm_config()
        self.projector = (
            _load_st_projector(vllm_config.model_config) if vllm_config else None
        )
        self.head_dtype = vllm_config.model_config.head_dtype
        self.activation = PoolerNormalize()

    def forward(
        self,
        pooled_data: TokenPoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_dimension]
        pooled_data = pooled_data.to(self.head_dtype)
        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [batchsize, embedding_dimension]
        pooling_params = pooling_metadata.pooling_params
        # for matryoshka representation
        dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
        if any(d is not None for d in dimensions_list):
            # change the output dimension
            assert len(pooled_data) == len(dimensions_list)
            if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list):
                # if all dimensions are the same
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                # Per-request truncation when requests ask for mixed dims.
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]
        # for normalize
        flags = [p.normalize for p in pooling_params]
        if len(set(flags)) == 1:
            # All requests agree: apply (or skip) normalization in one shot.
            if flags[0]:
                pooled_data = self.activation(pooled_data)
        else:
            pooled_data = [
                self.activation(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]
        # pooled_data shape: [batchsize, embedding_dimension]
        return pooled_data
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Normalizes output if specified.
    3. Returns structured results as `PoolerOutput`.
    """

    def __init__(self, pooling: PoolingMethod, head: TokenPoolerHead) -> None:
        super().__init__()
        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Delegate to the pooling method.
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        pooled = self.pooling(hidden_states, pooling_metadata)
        return self.head(pooled, pooling_metadata)
class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Applies a classification layer to the hidden states.
    2. Optionally applies a pooler layer.
    3. Applies an activation function to the output.
    """

    @staticmethod
    def act_fn_for_seq_cls(model_config: ModelConfig):
        """Activation matching HF sequence-classification conventions."""
        return get_classification_activation_function(model_config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(model_config: ModelConfig):
        """Activation matching the sentence-transformers cross-encoder config."""
        return get_cross_encoder_activation_function(model_config.hf_config)

    @staticmethod
    def resolve_act_fn(
        model_config: ModelConfig,
        static_num_labels: bool = True,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Resolve `act_fn` (string alias, None, or callable) to an activation."""
        if isinstance(act_fn, str):
            if act_fn == "classify":
                return ClassifierPooler.act_fn_for_seq_cls(model_config)
            elif act_fn == "score":
                return ClassifierPooler.act_fn_for_cross_encoder(model_config)
            else:
                raise ValueError(f"act_fn [{act_fn=}] not supported.")
        elif act_fn is None:
            return PoolerClassify(static_num_labels=static_num_labels)
        else:
            assert callable(act_fn)
            return act_fn

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()
        vllm_config = get_current_vllm_config()
        self.pooling = pooling
        self.classifier = classifier
        self.act_fn = self.resolve_act_fn(
            vllm_config.model_config, static_num_labels=True, act_fn=act_fn
        )
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]
        pooled_data = pooled_data.to(self.head_dtype)
        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]
        if self.logit_bias is not None:
            pooled_data -= self.logit_bias
        pooling_params = pooling_metadata.pooling_params
        # Apply the activation only to requests that ask for it.
        flags = [p.use_activation for p in pooling_params]
        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
            ]
        # scores shape: [batchsize, num_labels]
        return scores
class TokenwisePoolerHead(nn.Module, ABC):
    """Applicable to pooling strategies that output multiple tokens."""

    @abstractmethod
    def forward(
        self,
        pooled_data: TokenwisePoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenwisePoolerHeadOutput:
        raise NotImplementedError
class TokenEmbeddingPoolerHead(TokenwisePoolerHead):
    """Per-token embedding head: project, truncate (matryoshka), normalize."""

    def __init__(self) -> None:
        super().__init__()
        # Load ST projector if available
        vllm_config = get_current_vllm_config()
        self.projector = (
            _load_st_projector(vllm_config.model_config) if vllm_config else None
        )
        self.head_dtype = vllm_config.model_config.head_dtype
        self.activation = PoolerNormalize()

    def forward(
        self,
        pooled_data: TokenwisePoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenwisePoolerHeadOutput:
        # for unfinished chunked prefill
        if pooled_data is None:
            return None
        pooled_data = pooled_data.to(self.head_dtype)
        # pooled_data shape: [n_tokens, hidden_dimension]
        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [n_tokens, embedding_dimension]
        # for matryoshka representation
        pooled_data = pooled_data[..., : pooling_param.dimensions]
        # for normalize
        if pooling_param.normalize:
            pooled_data = self.activation(pooled_data)
        # pooled_data shape: [n_tokens, embedding_dimension]
        return pooled_data
class TokenClassifierPoolerHead(TokenwisePoolerHead):
    """Per-token classification head: classify, bias, then activate."""

    def __init__(
        self,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()
        vllm_config = get_current_vllm_config()
        self.classifier = classifier
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype
        # static_num_labels=False: the label count is taken from each
        # input's trailing dimension rather than the model config.
        self.activation = ClassifierPooler.resolve_act_fn(
            vllm_config.model_config, static_num_labels=False, act_fn=act_fn
        )

    def forward(
        self,
        pooled_data: TokenwisePoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenwisePoolerHeadOutput:
        # for unfinished chunked prefill
        if pooled_data is None:
            return None
        pooled_data = pooled_data.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]
        if self.classifier is not None:
            scores = self.classifier(pooled_data)
        else:
            scores = pooled_data
        # scores shape: [n_token, num_labels]
        if self.logit_bias is not None:
            scores -= self.logit_bias
        if pooling_param.use_activation:
            scores = self.activation(scores)
        # scores shape: [n_token, num_labels]
        return scores
class AllPooler(Pooler):
    """Applies a tokenwise head to every token of every request."""

    def __init__(self, head: TokenwisePoolerHead) -> None:
        super().__init__()
        self.pooling = AllPool()
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenwisePoolerOutput:
        pooled = self.pooling(hidden_states, pooling_metadata)
        params = pooling_metadata.pooling_params
        assert len(pooled) == len(params)
        outputs = []
        for data, param in zip(pooled, params):
            outputs.append(self.head(data, param))
        return outputs
class StepPooler(Pooler):
    """Tokenwise pooler that keeps only selected (step-tag) token positions."""

    def __init__(self, head: TokenwisePoolerHead) -> None:
        super().__init__()
        self.pooling = AllPool()
        self.head = head

    def extract_states(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor | None]:
        """Filter per-request token states by `returned_token_ids`/`step_tag_id`."""
        pooled_data_lst = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = pooling_metadata.get_prompt_token_ids()
        pooling_params = pooling_metadata.pooling_params
        pooled_data = list[torch.Tensor | None]()
        for data, token_id, pooling_param in zip(
            pooled_data_lst, prompt_token_ids, pooling_params
        ):
            # for unfinished chunked prefill
            if data is None:
                pooled_data.append(data)
                continue
            step_tag_id = pooling_param.step_tag_id
            returned_token_ids = pooling_param.returned_token_ids
            # Restrict the label dimension to the requested vocabulary ids.
            if returned_token_ids is not None and len(returned_token_ids) > 0:
                data = data[:, returned_token_ids]
            # Keep only positions whose prompt token is the step tag.
            if step_tag_id is not None:
                data = data[token_id == step_tag_id]
            pooled_data.append(data)
        return pooled_data

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Prompt token ids are needed to locate step-tag positions.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenwisePoolerOutput:
        pooled_data = self.extract_states(hidden_states, pooling_metadata)
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)
        return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)]
class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        super().__init__()
        # Validate up front that each sub-pooler supports its assigned task.
        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}"
                )
        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        poolers_by_task = self.poolers_by_task
        outputs = list[torch.Tensor | None]()
        offset = 0
        # NOTE: groupby only merges *contiguous* equal keys, so requests with
        # the same task are assumed to be adjacent in the metadata ordering.
        for task, group in groupby(pooling_metadata.tasks):
            if not (pooler := poolers_by_task.get(task)):
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )
            num_items = len(list(group))
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[offset : offset + num_items],
            )
            outputs.extend(group_output)
            offset += num_items
        return outputs

    def extra_repr(self) -> str:
        s = f"supported_task={self.get_supported_tasks()}"
        return s

View File

@@ -0,0 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .abstract import *
from .common import *
from .special import *

View File

@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
import torch
import torch.nn as nn
from vllm.tasks import PoolingTask
from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .common import PoolingParamsUpdate
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool `hidden_states` as described by `pooling_metadata`."""
        raise NotImplementedError


__all__ = ["Pooler"]

View File

@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TypeVar
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config import ModelConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
def get_classification_act_fn(
    config: PretrainedConfig,
) -> "PoolerActivation":
    """Pick the activation matching HF's ForSequenceClassification loss setup."""
    # Implement alignment with transformers ForSequenceClassificationLoss
    # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
    by_problem_type = {
        "regression": PoolerIdentity,
        "single_label_classification": PoolerClassify,
        "multi_label_classification": PoolerMultiLabelClassify,
    }
    problem_type = getattr(config, "problem_type", "")
    # Unknown/unset problem types fall back to the generic classifier.
    return by_problem_type.get(problem_type, PoolerClassify)()
def get_cross_encoder_act_fn(
    config: PretrainedConfig,
) -> "PoolerActivation":
    """Resolve the activation declared by a sentence-transformers config."""
    function_name: str | None = None
    if (
        hasattr(config, "sentence_transformers")
        and "activation_fn" in config.sentence_transformers
    ):
        function_name = config.sentence_transformers["activation_fn"]
    elif (
        hasattr(config, "sbert_ce_default_activation_function")
        and config.sbert_ce_default_activation_function is not None
    ):
        function_name = config.sbert_ce_default_activation_function
    if function_name is not None:
        # Only allow importing from torch.nn.modules so a (potentially
        # untrusted) config cannot name arbitrary code to execute.
        assert function_name.startswith("torch.nn.modules."), (
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons"
        )
        fn = resolve_obj_by_qualname(function_name)()
        return PoolerActivation.wraps(fn)
    # Default for cross-encoders without an explicit activation.
    return PoolerClassify()
def resolve_classifier_act_fn(
    model_config: ModelConfig,
    static_num_labels: bool = True,
    act_fn: "PoolerActivation | str | None" = None,
):
    """Resolve `act_fn` (string alias, None, or callable) to an activation."""
    if act_fn is None:
        # Default: generic classification activation.
        return PoolerClassify(static_num_labels=static_num_labels)
    if isinstance(act_fn, str):
        resolvers = {
            "classify": get_classification_act_fn,
            "score": get_cross_encoder_act_fn,
        }
        if act_fn not in resolvers:
            raise ValueError(f"act_fn [{act_fn=}] not supported.")
        return resolvers[act_fn](model_config.hf_config)
    assert callable(act_fn)
    return act_fn
# Activations accept either a single stacked tensor or a per-request list.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])


class PoolerActivation(nn.Module, ABC):
    """Activation applied to pooled outputs; maps over list inputs chunk-wise."""

    @staticmethod
    def wraps(module: nn.Module):
        # Map well-known torch modules onto their vLLM equivalents.
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()
        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        # (batch_size, dimensions) or list(dimensions) if using MRL
        if isinstance(pooled_data, list):
            return [self.forward_chunk(data) for data in pooled_data]
        return self.forward_chunk(pooled_data)
class PoolerIdentity(PoolerActivation):
    """No-op activation (used e.g. for regression outputs)."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return pooled_data


class PoolerNormalize(PoolerActivation):
    """L2-normalizes along the last dimension (for embeddings)."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.normalize(pooled_data, p=2, dim=-1)


class PoolerMultiLabelClassify(PoolerActivation):
    """Element-wise sigmoid for multi-label classification."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.sigmoid(pooled_data)
class PoolerClassify(PoolerActivation):
    """Classification activation: sigmoid for <2 labels, softmax otherwise."""

    def __init__(self, *, static_num_labels: bool = True) -> None:
        super().__init__()
        if static_num_labels:
            # Read the label count once from the model's HF config.
            vllm_config = get_current_vllm_config()
            model_config = vllm_config.model_config
            num_labels = getattr(model_config.hf_config, "num_labels", 0)
        else:
            # Resolved per input from its trailing dimension in forward_chunk.
            num_labels = None
        if num_labels == 0:
            logger.warning(
                "num_labels should be > 0 for classification "
                "models, falling back to softmax. "
                "Please check if the configuration is correct."
            )
        self.num_labels = num_labels

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        num_labels = self.num_labels
        if num_labels is None:
            num_labels = pooled_data.shape[-1]
        # A single logit is treated as a binary (sigmoid) output.
        if num_labels < 2:
            return F.sigmoid(pooled_data)
        return F.softmax(pooled_data, dim=-1)
class LambdaPoolerActivation(PoolerActivation):
    """Wraps an arbitrary tensor -> tensor callable as an activation."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        # Delegate directly to the wrapped callable.
        result = self.fn(pooled_data)
        return result

View File

@@ -0,0 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from dataclasses import dataclass
import torch
from vllm.pooling_params import PoolingParams
# Maps pooled hidden states to classification logits.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]


@dataclass(frozen=True)
class PoolingParamsUpdate:
    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def __or__(self, other: "PoolingParamsUpdate") -> "PoolingParamsUpdate":
        # Combine updates: a flag is set if either side requires it.
        return PoolingParamsUpdate(
            requires_token_ids=self.requires_token_ids or other.requires_token_ids,
        )

    def apply(self, params: PoolingParams) -> None:
        """Propagate this update onto the request's pooling parameters."""
        params.requires_token_ids = self.requires_token_ids


__all__ = ["ClassifierFn", "PoolingParamsUpdate"]

View File

@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Poolers that produce an output aggregating all tokens in the sequence."""
from .heads import (
ClassifierPoolerHead,
EmbeddingPoolerHead,
SequencePoolerHead,
SequencePoolerHeadOutput,
)
from .methods import (
CLSPool,
LastPool,
MeanPool,
SequencePoolingMethod,
SequencePoolingMethodOutput,
get_seq_pooling_method,
)
from .poolers import (
SequencePooler,
SequencePoolerOutput,
SequencePoolingFn,
SequencePoolingHeadFn,
pooler_for_classify,
pooler_for_embed,
)
__all__ = [
"SequencePoolerHead",
"SequencePoolerHeadOutput",
"ClassifierPoolerHead",
"EmbeddingPoolerHead",
"SequencePoolingMethod",
"SequencePoolingMethodOutput",
"CLSPool",
"LastPool",
"MeanPool",
"get_seq_pooling_method",
"SequencePooler",
"SequencePoolingFn",
"SequencePoolingHeadFn",
"SequencePoolerOutput",
"pooler_for_classify",
"pooler_for_embed",
]

View File

@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from typing import TypeAlias
import torch
import torch.nn as nn
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.pooler import ClassifierFn
from vllm.model_executor.layers.pooler.activations import (
PoolerActivation,
PoolerNormalize,
resolve_classifier_act_fn,
)
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .methods import SequencePoolingMethodOutput
# Either a batched tensor or one tensor per request.
SequencePoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]


class SequencePoolerHead(nn.Module, ABC):
    """Post-processes sequence-level pooled data (one vector per request)."""

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this head can serve."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        """Transform pooled vectors into the final per-request outputs."""
        raise NotImplementedError
class EmbeddingPoolerHead(SequencePoolerHead):
    """Head for the `embed` task.

    Applies an optional SentenceTransformers projector, optional matryoshka
    truncation, and optional normalization to each request's pooled vector.
    """

    def __init__(self) -> None:
        super().__init__()

        # Load ST projector if available
        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config
        self.projector = _load_st_projector(model_config)
        # dtype in which the head computation is performed
        self.head_dtype = model_config.head_dtype
        self.activation = PoolerNormalize()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"embed"}

    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_dimension]

        pooled_data = pooled_data.to(self.head_dtype)

        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [batchsize, embedding_dimension]

        # for matryoshka representation
        dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
        if any(d is not None for d in dimensions_list):
            # change the output dimension
            assert len(pooled_data) == len(dimensions_list)
            if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list):
                # if all dimensions are the same, truncate in one batched slice
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                # per-request truncation yields a list of vectors
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]

        # for normalize
        flags = [p.normalize for p in pooling_params]
        if len(set(flags)) == 1:
            # all requests agree: normalize (or not) in one batched call
            if flags[0]:
                pooled_data = self.activation(pooled_data)
        else:
            # mixed flags: normalize each request's vector individually
            pooled_data = [
                self.activation(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        # pooled_data shape: [batchsize, embedding_dimension]
        return pooled_data
class ClassifierPoolerHead(SequencePoolerHead):
    """Head for the `classify`/`score` tasks.

    Applies an optional classifier, optional logit bias, and the resolved
    activation function to each request's pooled vector.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config

        # Optional callable mapping hidden states to per-label logits.
        self.classifier = classifier
        # Constant subtracted from logits when configured.
        self.logit_bias: float | None = model_config.pooler_config.logit_bias
        # dtype in which the head computation is performed
        self.head_dtype = model_config.head_dtype
        self.act_fn = resolve_classifier_act_fn(
            model_config, static_num_labels=True, act_fn=act_fn
        )

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            pooled_data -= self.logit_bias

        flags = [p.use_activation for p in pooling_params]
        if len(set(flags)) == 1:
            # all requests agree: apply activation in one batched call
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            # mixed flags: apply the activation per request
            scores = [
                self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
            ]

        # scores shape: [batchsize, num_labels]
        return scores

View File

@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from typing import TypeAlias
import torch
import torch.nn as nn
from vllm.config.pooler import PoolingTypeStr
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
# Either a batched tensor or one tensor per request.
SequencePoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]


class SequencePoolingMethod(nn.Module, ABC):
    """Aggregates per-token hidden states into one vector per sequence."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Sequence-level methods can serve every pooling task by default.
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # No parameter updates are required by default.
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Pool `hidden_states` into one vector per sequence."""
        raise NotImplementedError
class CLSPool(SequencePoolingMethod):
    """Pool by taking the hidden state of each sequence's first (CLS) token."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        cursor = pooling_metadata.get_pooling_cursor()
        # The first token must be visible in this pass.
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        return hidden_states[cursor.first_token_indices_gpu]
class LastPool(SequencePoolingMethod):
    """Pool by taking the hidden state of each sequence's last token."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        cursor = pooling_metadata.get_pooling_cursor()
        # Works with partial prefill: the last token is always scheduled last.
        return hidden_states[cursor.last_token_indices_gpu]
class MeanPool(SequencePoolingMethod):
    """Pool by averaging the hidden states over each sequence's tokens."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        # The mean requires every token of the prompt in this pass.
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        prompt_lens = pooling_cursor.prompt_lens_cpu.to(
            hidden_states.device, non_blocking=True
        )

        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        start_indices = pooling_cursor.first_token_indices_gpu
        end_indices = pooling_cursor.last_token_indices_gpu
        # cumsum is inclusive, so the per-sequence sum over [start, end] is
        # cumsum[end] - cumsum[start] + hidden_states[start];
        # unsqueeze broadcasts the per-sequence lengths over the hidden dim.
        return (
            cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices]
        ) / prompt_lens.unsqueeze(1)
def get_seq_pooling_method(pooling_type: PoolingTypeStr | str):
    """Map a pooling-type string to a sequence pooling method instance."""
    method_cls = {
        "LAST": LastPool,
        "CLS": CLSPool,
        "MEAN": MeanPool,
    }.get(pooling_type)

    if method_cls is None:
        raise NotImplementedError(f"Unknown sequence pooling type: {pooling_type!r}")

    return method_cls()

View File

@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Set
from typing import TypeAlias
import torch
from vllm.config import PoolerConfig
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
from vllm.model_executor.layers.pooler.abstract import Pooler
from vllm.model_executor.layers.pooler.activations import PoolerActivation
from vllm.tasks import POOLING_TASKS, PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .heads import (
ClassifierPoolerHead,
EmbeddingPoolerHead,
SequencePoolerHead,
SequencePoolerHeadOutput,
)
from .methods import (
SequencePoolingMethod,
SequencePoolingMethodOutput,
get_seq_pooling_method,
)
# A pooling step as a plain callable (hidden states -> pooled data).
SequencePoolingFn: TypeAlias = Callable[
    [torch.Tensor, PoolingMetadata],
    SequencePoolingMethodOutput,
]
# A head step as a plain callable (pooled data -> final outputs).
SequencePoolingHeadFn: TypeAlias = Callable[
    [SequencePoolingMethodOutput, PoolingMetadata],
    SequencePoolerHeadOutput,
]

SequencePoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor]


class SequencePooler(Pooler):
    """
    A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Postprocesses the output based on pooling head.
    3. Returns structured results as `PoolerOutput`.
    """

    def __init__(
        self,
        pooling: SequencePoolingMethod | SequencePoolingFn,
        head: SequencePoolerHead | SequencePoolingHeadFn,
    ) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Start from every task; intersect with what each typed component
        # declares. Plain callables don't restrict the task set.
        supported = set(POOLING_TASKS)
        if isinstance(self.pooling, SequencePoolingMethod):
            supported &= self.pooling.get_supported_tasks()
        if isinstance(self.head, SequencePoolerHead):
            supported &= self.head.get_supported_tasks()
        return supported

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Plain callables cannot request parameter updates.
        if isinstance(self.pooling, SequencePoolingMethod):
            return PoolingParamsUpdate() | self.pooling.get_pooling_updates(task)
        return PoolingParamsUpdate()

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerOutput:
        # Aggregate per-sequence data, then post-process it with the head.
        pooled = self.pooling(hidden_states, pooling_metadata)
        return self.head(pooled, pooling_metadata)
def pooler_for_embed(pooler_config: PoolerConfig):
    """Build a `SequencePooler` for the `embed` task from the config."""
    return SequencePooler(
        pooling=get_seq_pooling_method(pooler_config.get_pooling_type()),
        head=EmbeddingPoolerHead(),
    )
def pooler_for_classify(
    pooler_config: PoolerConfig,
    *,
    pooling: SequencePoolingMethod | SequencePoolingFn | None = None,
    classifier: ClassifierFn | None = None,
    act_fn: PoolerActivation | str | None = None,
):
    """Build a `SequencePooler` for `classify`/`score` tasks from the config.

    A pooling method override may be supplied; otherwise it is derived
    from the pooler config.
    """
    resolved_pooling = (
        get_seq_pooling_method(pooler_config.get_pooling_type())
        if pooling is None
        else pooling
    )

    return SequencePooler(
        pooling=resolved_pooling,
        head=ClassifierPoolerHead(classifier=classifier, act_fn=act_fn),
    )

View File

@@ -0,0 +1,128 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping, Set
from itertools import groupby
import torch
from vllm.config import PoolerConfig
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .abstract import Pooler, PoolerOutput
from .common import ClassifierFn
from .seqwise import (
SequencePoolingFn,
SequencePoolingMethod,
pooler_for_classify,
pooler_for_embed,
)
from .tokwise import AllPool, pooler_for_token_classify, pooler_for_token_embed
class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    @classmethod
    def for_embedding(cls, pooler_config: PoolerConfig):
        """Construct the standard sub-pooler set for embedding models."""
        return cls(
            {
                "token_embed": pooler_for_token_embed(pooler_config),
                "embed": pooler_for_embed(pooler_config),
            },
        )

    @classmethod
    def for_seq_cls(
        cls,
        pooler_config: PoolerConfig,
        *,
        pooling: SequencePoolingMethod | SequencePoolingFn | None = None,
        classifier: ClassifierFn | None = None,
    ):
        """Construct the standard sub-pooler set for sequence classification.

        Args:
            pooler_config: The model's pooler configuration.
            pooling: Optional pooling override for the `classify`/`score`
                sub-poolers.
            classifier: Optional callable mapping hidden states to logits,
                shared by all sub-poolers.
        """
        return cls(
            {
                # Tokenwise classification always keeps every token.
                "token_classify": pooler_for_token_classify(
                    pooler_config,
                    pooling=AllPool(),
                    classifier=classifier,
                ),
                "classify": pooler_for_classify(
                    pooler_config,
                    pooling=pooling,
                    classifier=classifier,
                    act_fn="classify",
                ),
                "score": pooler_for_classify(
                    pooler_config,
                    pooling=pooling,
                    classifier=classifier,
                    act_fn="score",
                ),
            }
        )

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        super().__init__()

        # Fail fast if a sub-pooler is registered under a task it rejects.
        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}"
                )

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Delegate to the sub-pooler registered for this task.
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        poolers_by_task = self.poolers_by_task

        outputs = list[torch.Tensor | None]()
        offset = 0
        # Dispatch each run of consecutive equal tasks to its sub-pooler,
        # slicing the metadata to the matching request range.
        for task, group in groupby(pooling_metadata.tasks):
            if not (pooler := poolers_by_task.get(task)):
                raise ValueError(
                    f"Unsupported task: {task!r} "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )

            num_items = len(list(group))
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[offset : offset + num_items],
            )

            outputs.extend(group_output)
            offset += num_items

        return outputs

    def extra_repr(self) -> str:
        s = f"supported_task={self.get_supported_tasks()}"
        return s
class IdentityPooler(Pooler):
    """Pass-through pooler that returns the hidden states unchanged."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"plugin", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        # No pooling is performed; the model output is forwarded as-is.
        return hidden_states
__all__ = ["DispatchPooler", "IdentityPooler"]

View File

@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Poolers that produce an output for each token in the sequence."""
from .heads import (
TokenClassifierPoolerHead,
TokenEmbeddingPoolerHead,
TokenPoolerHead,
TokenPoolerHeadOutputItem,
)
from .methods import (
AllPool,
StepPool,
TokenPoolingMethod,
TokenPoolingMethodOutputItem,
get_tok_pooling_method,
)
from .poolers import (
TokenPooler,
TokenPoolerOutput,
pooler_for_token_classify,
pooler_for_token_embed,
)
__all__ = [
"TokenPoolerHead",
"TokenPoolerHeadOutputItem",
"TokenClassifierPoolerHead",
"TokenEmbeddingPoolerHead",
"TokenPoolingMethod",
"TokenPoolingMethodOutputItem",
"AllPool",
"StepPool",
"get_tok_pooling_method",
"TokenPooler",
"TokenPoolerOutput",
"pooler_for_token_classify",
"pooler_for_token_embed",
]

View File

@@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from typing import TypeAlias
import torch
import torch.nn as nn
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.pooler import ClassifierFn
from vllm.model_executor.layers.pooler.activations import (
PoolerActivation,
PoolerNormalize,
resolve_classifier_act_fn,
)
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.pooling_params import PoolingParams
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .methods import TokenPoolingMethodOutputItem
# One request's output; `None` marks an unfinished chunked prefill.
TokenPoolerHeadOutputItem: TypeAlias = torch.Tensor | None


class TokenPoolerHead(nn.Module, ABC):
    """Post-processes tokenwise pooled data, one item per request."""

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this head can serve."""
        raise NotImplementedError

    @abstractmethod
    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        """Transform one request's tokenwise data using its params."""
        raise NotImplementedError

    def forward(
        self,
        pooled_data: list[TokenPoolingMethodOutputItem],
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolerHeadOutputItem]:
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        # Process each request with its own pooling parameters.
        return [self.forward_chunk(d, p) for d, p in zip(pooled_data, pooling_params)]
class TokenEmbeddingPoolerHead(TokenPoolerHead):
    """Head for the `token_embed` task.

    Applies an optional SentenceTransformers projector, matryoshka
    truncation, and optional normalization to each token's vector.
    """

    def __init__(self) -> None:
        super().__init__()

        # Load ST projector if available
        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config
        self.projector = _load_st_projector(model_config)
        # dtype in which the head computation is performed
        self.head_dtype = model_config.head_dtype
        self.activation = PoolerNormalize()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        pooled_data = pooled_data.to(self.head_dtype)
        # pooled_data shape: [n_tokens, hidden_dimension]

        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [n_tokens, embedding_dimension]

        # for matryoshka representation
        # (slicing with dimensions=None keeps the full dimension)
        pooled_data = pooled_data[..., : pooling_param.dimensions]

        # for normalize
        if pooling_param.normalize:
            pooled_data = self.activation(pooled_data)

        # pooled_data shape: [n_tokens, embedding_dimension]
        return pooled_data
class TokenClassifierPoolerHead(TokenPoolerHead):
    """Head for the `token_classify` task.

    Applies an optional classifier, optional logit bias, and the resolved
    activation function to each token's hidden state.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config

        # Optional callable mapping hidden states to per-label logits.
        self.classifier = classifier
        # Constant subtracted from logits when configured.
        self.logit_bias: float | None = model_config.pooler_config.logit_bias
        # dtype in which the head computation is performed
        self.head_dtype = model_config.head_dtype
        self.act_fn = resolve_classifier_act_fn(
            model_config, static_num_labels=False, act_fn=act_fn
        )

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_classify"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        pooled_data = pooled_data.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(pooled_data)
        else:
            scores = pooled_data
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            scores -= self.logit_bias

        if pooling_param.use_activation:
            scores = self.act_fn(scores)

        # scores shape: [n_token, num_labels]
        return scores

View File

@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from typing import TypeAlias
import torch
import torch.nn as nn
from vllm.config import get_current_vllm_config
from vllm.config.pooler import PoolingTypeStr
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
# One request's pooled tokens; `None` marks an unfinished chunked prefill.
TokenPoolingMethodOutputItem: TypeAlias = torch.Tensor | None


class TokenPoolingMethod(nn.Module, ABC):
    """Selects per-token hidden states for each sequence."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Tokenwise methods serve only the tokenwise tasks by default.
        return {"token_embed", "token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # No parameter updates are required by default.
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        """Return per-request token hidden states."""
        raise NotImplementedError
class AllPool(TokenPoolingMethod):
    """Return the hidden states of every token of each sequence."""

    def __init__(self):
        super().__init__()

        vllm_config = get_current_vllm_config()
        scheduler_config = vllm_config.scheduler_config
        # With chunked prefill, a request's tokens can arrive over multiple
        # forward passes and must be accumulated before returning them.
        self.enable_chunked_prefill = scheduler_config.enable_chunked_prefill

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        # Split the flat hidden states into per-request chunks, then
        # reorder them to match the cursor's request order.
        hidden_states_all = hidden_states.split(
            pooling_cursor.num_scheduled_tokens_cpu.tolist()
        )
        hidden_states_lst = [hidden_states_all[i] for i in pooling_cursor.index]

        if not self.enable_chunked_prefill:
            return hidden_states_lst

        pooling_states = pooling_metadata.pooling_states

        # If chunked_prefill is enabled
        # 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
        for p, hs_chunk in zip(pooling_states, hidden_states_lst):
            p.hidden_states_cache.append(hs_chunk)

        # 2. Once prefill is finished, send hidden_states_cache to PoolerHead
        output_list = list[TokenPoolingMethodOutputItem]()
        for p, finished in zip(pooling_states, pooling_cursor.is_finished()):
            if finished:
                hidden_states_cache = p.hidden_states_cache
                if len(hidden_states_cache) == 1:
                    # single chunk: avoid an unnecessary concat copy
                    output_list.append(hidden_states_cache[0])
                else:
                    output_list.append(torch.concat(hidden_states_cache, dim=0))
                p.clean()
            else:
                # `None` marks a request whose prefill is still in progress
                output_list.append(None)

        return output_list
class StepPool(AllPool):
    """`AllPool` variant that keeps only selected "step" token positions.

    Per-request `PoolingParams` may restrict the label dimension to
    `returned_token_ids` and/or keep only positions whose prompt token
    equals `step_tag_id`.
    """

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Prompt token IDs are needed to locate the step-tag positions.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        pooled_data_lst = super().forward(hidden_states, pooling_metadata)
        prompt_token_ids = pooling_metadata.get_prompt_token_ids()
        pooling_params = pooling_metadata.pooling_params

        pooled_data = list[torch.Tensor | None]()
        for data, token_id, pooling_param in zip(
            pooled_data_lst, prompt_token_ids, pooling_params
        ):
            # `None` marks an unfinished chunked prefill; pass it through.
            if data is not None:
                step_tag_id = pooling_param.step_tag_id
                returned_token_ids = pooling_param.returned_token_ids

                # Optionally restrict the label dimension to a subset of ids.
                if returned_token_ids is not None and len(returned_token_ids) > 0:
                    data = data[:, returned_token_ids]
                # Optionally keep only positions matching the step tag.
                if step_tag_id is not None:
                    data = data[token_id == step_tag_id]

            pooled_data.append(data)

        return pooled_data
def get_tok_pooling_method(pooling_type: PoolingTypeStr | str):
    """Map a pooling-type string to a tokenwise pooling method instance.

    Sequence-level pooling types (e.g. "CLS", "LAST", "MEAN") fall back to
    `AllPool`, which is the tokenwise behavior they all share.
    """
    if pooling_type == "ALL":
        return AllPool()
    if pooling_type == "STEP":
        return StepPool()

    # TODO: Separate seq and tok pooling types so we don't need this fallback
    # NOTE: The original `raise NotImplementedError` after this return was
    # unreachable dead code and has been removed.
    return AllPool()

View File

@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Set
from typing import TypeAlias
import torch
from vllm.config import PoolerConfig
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
from vllm.model_executor.layers.pooler.abstract import Pooler
from vllm.model_executor.layers.pooler.activations import PoolerActivation
from vllm.tasks import POOLING_TASKS, PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from .heads import (
TokenClassifierPoolerHead,
TokenEmbeddingPoolerHead,
TokenPoolerHead,
TokenPoolerHeadOutputItem,
)
from .methods import (
TokenPoolingMethod,
TokenPoolingMethodOutputItem,
get_tok_pooling_method,
)
# A pooling step as a plain callable (hidden states -> per-request items).
TokenPoolingFn: TypeAlias = Callable[
    [torch.Tensor, PoolingMetadata],
    list[TokenPoolingMethodOutputItem],
]
# A head step as a plain callable (per-request items -> final items).
TokenPoolingHeadFn: TypeAlias = Callable[
    [list[TokenPoolingMethodOutputItem], PoolingMetadata],
    list[TokenPoolerHeadOutputItem],
]

TokenPoolerOutput: TypeAlias = list[torch.Tensor | None]


class TokenPooler(Pooler):
    """
    A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Postprocesses the output based on pooling head.
    3. Returns structured results as `PoolerOutput`.
    """

    def __init__(
        self,
        pooling: TokenPoolingMethod | TokenPoolingFn,
        head: TokenPoolerHead | TokenPoolingHeadFn,
    ) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Start from every task; intersect with what each typed component
        # declares. Plain callables don't restrict the task set.
        supported = set(POOLING_TASKS)
        if isinstance(self.pooling, TokenPoolingMethod):
            supported &= self.pooling.get_supported_tasks()
        if isinstance(self.head, TokenPoolerHead):
            supported &= self.head.get_supported_tasks()
        return supported

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Plain callables cannot request parameter updates.
        if isinstance(self.pooling, TokenPoolingMethod):
            return PoolingParamsUpdate() | self.pooling.get_pooling_updates(task)
        return PoolingParamsUpdate()

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerOutput:
        # Select per-token data, then post-process it with the head.
        pooled = self.pooling(hidden_states, pooling_metadata)
        return self.head(pooled, pooling_metadata)
def pooler_for_token_embed(pooler_config: PoolerConfig):
    """Build a `TokenPooler` for the `token_embed` task from the config."""
    return TokenPooler(
        pooling=get_tok_pooling_method(pooler_config.get_pooling_type()),
        head=TokenEmbeddingPoolerHead(),
    )
def pooler_for_token_classify(
    pooler_config: PoolerConfig,
    *,
    pooling: TokenPoolingMethod | TokenPoolingFn | None = None,
    classifier: ClassifierFn | None = None,
    act_fn: PoolerActivation | str | None = None,
):
    """Build a `TokenPooler` for the `token_classify` task from the config.

    A pooling method override may be supplied; otherwise it is derived
    from the pooler config.
    """
    resolved_pooling = (
        get_tok_pooling_method(pooler_config.get_pooling_type())
        if pooling is None
        else pooling
    )

    return TokenPooler(
        pooling=resolved_pooling,
        head=TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn),
    )

View File

@@ -252,19 +252,14 @@ def as_embedding_model(cls: _T) -> _T:
return cls
# Lazy import
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler import DispatchPooler
class ModelForEmbedding(_create_pooling_model_cls(cls)):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
},
)
self.pooler = DispatchPooler.for_embedding(pooler_config)
ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")
@@ -289,10 +284,7 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.pooler import (
DispatchPooler,
Pooler,
)
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from .utils import maybe_prefix
@@ -318,18 +310,8 @@ def as_seq_cls_model(cls: _T) -> _T:
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
self.pooler = DispatchPooler.for_seq_cls(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
}
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

View File

@@ -18,19 +18,25 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.pooler import (
ClassifierPooler,
DispatchPooler,
Pooler,
PoolingMethod,
PoolingParamsUpdate,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
)
from vllm.model_executor.layers.pooler.seqwise import (
CLSPool,
SequencePooler,
SequencePoolerHeadOutput,
SequencePoolerOutput,
SequencePoolingMethodOutput,
)
from vllm.model_executor.layers.pooler.tokwise import (
pooler_for_token_classify,
pooler_for_token_embed,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding, SupportsQuant
@@ -85,25 +91,21 @@ class BertEmbedding(nn.Module):
return embeddings
class BertPooler(Pooler):
class BertPooler(SequencePooler):
def __init__(self, config: BertConfig):
super().__init__()
super().__init__(
pooling=CLSPool(),
head=self.head,
)
self.pooling = PoolingMethod.from_pooling_type("CLS")
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks()
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def head(
self,
pooled_data: TokenPoolingMethodOutput,
pooled_data: SequencePoolingMethodOutput,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput:
) -> SequencePoolerHeadOutput:
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
@@ -111,15 +113,6 @@ class BertPooler(Pooler):
pooled_data = self.activation(pooled_data)
return pooled_data
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
class BertEncoder(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
@@ -524,12 +517,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
)
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
return DispatchPooler.for_embedding(pooler_config)
# Here we encode the token type ids together with the input ids.
@@ -620,6 +608,7 @@ class SPLADESparsePooler(Pooler):
remove_cls_sep: bool = True,
):
super().__init__()
assert pooling in ("max", "sum")
self.mlm_head = mlm_head
self.cls_token_id = cls_token_id
@@ -637,10 +626,8 @@ class SPLADESparsePooler(Pooler):
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> torch.Tensor:
assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2
lens_tensor: torch.Tensor = pooling_metadata.prompt_lens
) -> SequencePoolerOutput:
lens_tensor = pooling_metadata.prompt_lens
lens: list[int] = lens_tensor.tolist()
B: int = len(lens)
@@ -729,7 +716,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
return DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"token_embed": pooler_for_token_embed(pooler_config),
"embed": SPLADESparsePooler(
mlm_head=self.mlm_head,
cls_token_id=cls_id,
@@ -824,20 +811,10 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
self.pooler = DispatchPooler.for_seq_cls(
pooler_config,
pooling=self.bert.pooler,
classifier=self.classifier,
act_fn="classify",
),
"score": ClassifierPooler(
pooling=self.bert.pooler, classifier=self.classifier, act_fn="score"
),
}
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -891,13 +868,7 @@ class BertForTokenClassification(nn.Module):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
),
}
)
self.pooler = pooler_for_token_classify(pooler_config)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.bert.embed_input_ids(input_ids)

View File

@@ -24,6 +24,7 @@ from vllm.model_executor.layers.linear import (
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -37,7 +38,6 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from ..layers.pooler import ClassifierPooler, DispatchPooler, Pooler
from .bert import BertPooler
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces_base import default_pooling_type
@@ -693,20 +693,10 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
self.pooler = DispatchPooler.for_seq_cls(
pooler_config,
pooling=self.new.pooler,
classifier=self.classifier,
act_fn="classify",
),
"score": ClassifierPooler(
pooling=self.new.pooler, classifier=self.classifier, act_fn="score"
),
}
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

View File

@@ -26,7 +26,7 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -880,12 +880,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
assert pooler_config is not None
self.pooler_config = pooler_config
self.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
self.pooler = DispatchPooler.for_embedding(pooler_config)
# Assumes that self.forward is called after self.embed_input_ids
self._is_text_input = True

View File

@@ -41,6 +41,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
@@ -49,7 +50,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from ..layers.pooler import DispatchPooler, Pooler
from .interfaces import SupportsCrossEncoding, SupportsPP
from .utils import (
AutoWeightsLoader,
@@ -351,19 +351,7 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
}
)
self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.embed_input_ids(input_ids)

View File

@@ -9,17 +9,19 @@ from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import (
DispatchPooler,
Pooler,
PoolerNormalize,
PoolingMethod,
PoolingParamsUpdate,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
)
from vllm.model_executor.layers.pooler.activations import PoolerNormalize
from vllm.model_executor.layers.pooler.seqwise import (
SequencePooler,
SequencePoolerHeadOutput,
SequencePoolingMethod,
SequencePoolingMethodOutput,
)
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.tasks import PoolingTask
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces_base import default_pooling_type
@@ -27,7 +29,7 @@ from .interfaces_base import default_pooling_type
logger = init_logger(__name__)
class GritLMMeanPool(PoolingMethod):
class GritLMMeanPool(SequencePoolingMethod):
"""As `MeanPool`, but only includes non-instruction tokens."""
def __init__(self, model_config: ModelConfig):
@@ -151,7 +153,7 @@ class GritLMMeanPool(PoolingMethod):
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokenPoolingMethodOutput:
) -> SequencePoolingMethodOutput:
prompt_lens = pooling_metadata.prompt_lens
instr_lens = torch.tensor(
[
@@ -174,35 +176,22 @@ class GritLMMeanPool(PoolingMethod):
return pooled_data
class GritLMPooler(Pooler):
class GritLMPooler(SequencePooler):
def __init__(self, model_config: ModelConfig):
super().__init__()
super().__init__(
pooling=GritLMMeanPool(model_config),
head=self.head,
)
self.pooling = GritLMMeanPool(model_config)
self.activation = PoolerNormalize()
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks()
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def head(
self,
pooled_data: TokenPoolingMethodOutput,
pooled_data: SequencePoolingMethodOutput,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput:
) -> SequencePoolerHeadOutput:
return self.activation(pooled_data)
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
@default_pooling_type("MEAN")
class GritLM(LlamaForCausalLM):
@@ -245,7 +234,7 @@ class GritLM(LlamaForCausalLM):
if pooler_config is not None:
self.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"token_embed": pooler_for_token_embed(pooler_config),
"embed": GritLMPooler(vllm_config.model_config),
}
)

View File

@@ -28,7 +28,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -434,9 +434,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
self.pooler = pooler_for_token_classify(pooler_config)
def forward(
self,

View File

@@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
@@ -596,16 +596,4 @@ class JambaForSequenceClassification(JambaForCausalLM):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
}
)
self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)

View File

@@ -10,7 +10,7 @@ from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
@@ -105,19 +105,7 @@ class JinaVLForSequenceClassification(
self.score = JinaVLScorer(
vllm_config.model_config, prefix=maybe_prefix(prefix, "score")
)
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.score
),
"classify": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="classify"
),
"score": Pooler.for_classify(
pooler_config, classifier=self.score, act_fn="score"
),
}
)
self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Set
from collections.abc import Iterable
import torch
from torch import nn
@@ -12,21 +12,18 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.pooler import (
ClassifierPooler,
DispatchPooler,
Pooler,
PoolingMethod,
PoolingParamsUpdate,
TokenPoolerHeadOutput,
TokenPoolingMethodOutput,
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.pooler.seqwise import (
SequencePooler,
SequencePoolerHeadOutput,
SequencePoolingMethodOutput,
get_seq_pooling_method,
)
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.outputs import TokenPoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata
from .interfaces import SupportsCrossEncoding
@@ -282,12 +279,13 @@ class ModernBertModel(nn.Module):
return norm_outputs
class ModernBertPooler(Pooler):
class ModernBertPooler(SequencePooler):
def __init__(self, config: ModernBertConfig):
super().__init__()
super().__init__(
pooling=get_seq_pooling_method(config.classifier_pooling.upper()),
head=self.head,
)
pooling_type = config.classifier_pooling.upper()
self.pooling = PoolingMethod.from_pooling_type(pooling_type)
self.dense = nn.Linear(
config.hidden_size, config.hidden_size, config.classifier_bias
)
@@ -296,32 +294,17 @@ class ModernBertPooler(Pooler):
config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
)
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooling.get_supported_tasks()
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return self.pooling.get_pooling_updates(task)
def head(
self,
pooled_data: TokenPoolingMethodOutput,
pooled_data: SequencePoolingMethodOutput,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerHeadOutput:
) -> SequencePoolerHeadOutput:
if isinstance(pooled_data, list):
pooled_data = torch.stack(pooled_data)
pooled_data = pooled_data.to(self.dense.weight.dtype)
return self.norm(self.act(self.dense(pooled_data)))
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data
@default_pooling_type("CLS")
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
@@ -344,18 +327,10 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=self.pooling, classifier=self.classifier, act_fn="classify"
),
"score": ClassifierPooler(
pooling=self.pooling, classifier=self.classifier, act_fn="score"
),
}
self.pooler = DispatchPooler.for_seq_cls(
pooler_config,
pooling=self.pooling,
classifier=self.classifier,
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -438,13 +413,7 @@ class ModernBertForTokenClassification(nn.Module):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
),
}
)
self.pooler = pooler_for_token_classify(pooler_config)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)

View File

@@ -14,7 +14,8 @@ from torch import nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler import Pooler
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
@@ -104,9 +105,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
self.pooler = pooler_for_token_classify(pooler_config)
@default_pooling_type("STEP")
@@ -118,6 +117,4 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
self.pooler = pooler_for_token_classify(pooler_config)

View File

@@ -8,12 +8,8 @@ from torch import nn
from transformers import RobertaConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.pooler import (
ClassifierPooler,
CLSPool,
DispatchPooler,
Pooler,
)
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.pooler.seqwise import CLSPool
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.bert import (
TOKEN_TYPE_SHIFT,
@@ -196,18 +192,10 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config=pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
),
"score": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
),
}
self.pooler = DispatchPooler.for_seq_cls(
pooler_config,
pooling=CLSPool(),
classifier=self.classifier,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

View File

@@ -27,7 +27,7 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import (
@@ -1050,12 +1050,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
assert pooler_config is not None
self.pooler_config = pooler_config
self.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
self.pooler = DispatchPooler.for_embedding(pooler_config)
self._is_text_input = True

View File

@@ -34,7 +34,7 @@ from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import DispatchPooler, DummyPooler
from vllm.model_executor.layers.pooler import IdentityPooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -248,7 +248,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler({"plugin": DummyPooler()})
self.pooler = IdentityPooler()
def embed_input_ids(
self,

View File

@@ -22,12 +22,8 @@ import torch
from transformers import AutoModelForSequenceClassification
from vllm.config.utils import getattr_iter
from vllm.model_executor.layers.pooler import (
ClassifierPooler,
CLSPool,
DispatchPooler,
Pooler,
)
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.pooler.seqwise import CLSPool
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
@@ -47,12 +43,7 @@ class EmbeddingMixin(VllmModelForPooling):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
self.pooler = DispatchPooler.for_embedding(pooler_config)
class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
@@ -104,16 +95,8 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
self.classifier.__class__ = ClassifierWithReshape
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
),
"score": ClassifierPooler(
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
),
}
self.pooler = DispatchPooler.for_seq_cls(
pooler_config,
pooling=CLSPool(),
classifier=self.classifier,
)

View File

@@ -91,9 +91,7 @@ class LogprobsTensors(NamedTuple):
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
TokenPoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokenwisePoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
PoolerOutput: TypeAlias = TokenPoolerOutput | TokenwisePoolerOutput
PoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor] | list[torch.Tensor | None]
@dataclass