diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 4d7a366f0..8122c525f 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -153,7 +153,7 @@ mkdocs.yaml @hmellor /vllm/entrypoints/pooling @noooop /vllm/config/pooler.py @noooop /vllm/pooling_params.py @noooop -/vllm/model_executor/layers/pooler.py @noooop +/vllm/model_executor/layers/pooler @noooop # Security guide and policies /docs/usage/security.md @russellb diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py index 85ef9d9be..4aeae8e36 100644 --- a/tests/model_executor/test_model_load_with_params.py +++ b/tests/model_executor/test_model_load_with_params.py @@ -5,7 +5,8 @@ import os import pytest -from vllm.model_executor.layers.pooler import CLSPool, DispatchPooler, MeanPool +from vllm.model_executor.layers.pooler import DispatchPooler +from vllm.model_executor.layers.pooler.seqwise import CLSPool, MeanPool from vllm.model_executor.models.bert import BertEmbeddingModel from vllm.model_executor.models.roberta import RobertaEmbeddingModel from vllm.platforms import current_platform diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index 98245cdf0..b99c9629a 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import DispatchPooler, Pooler +from vllm.model_executor.layers.pooler import DispatchPooler from vllm.model_executor.models.gemma2 import Gemma2Model from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix from vllm.sequence import IntermediateTensors @@ -28,12 +28,7 @@ class MyGemma2Embedding(nn.Module): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - { - "token_embed": Pooler.for_token_embed(pooler_config), - "embed": Pooler.for_embed(pooler_config), - } - ) + self.pooler = DispatchPooler.for_embedding(pooler_config) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/config/pooler.py b/vllm/config/pooler.py index 0c1569c5a..008fefadf 100644 --- a/vllm/config/pooler.py +++ b/vllm/config/pooler.py @@ -88,6 +88,10 @@ class PoolerConfig: # raise deprecated warning for softmax and activation self.use_activation = get_use_activation(self) + def get_pooling_type(self) -> PoolingTypeStr: + assert self.pooling_type is not None, "Should be resolved by ModelConfig" + return self.pooling_type + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py deleted file mode 100644 index 7bb7d8865..000000000 --- a/vllm/model_executor/layers/pooler.py +++ /dev/null @@ -1,845 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod -from collections.abc import Callable, Mapping, Set -from dataclasses import dataclass -from itertools import groupby -from typing import TypeAlias, TypeVar - -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import PretrainedConfig - -from vllm.config import ModelConfig, get_current_vllm_config -from vllm.config.pooler import PoolerConfig, PoolingTypeStr -from vllm.logger import init_logger -from vllm.model_executor.models.adapters import _load_st_projector -from vllm.pooling_params import PoolingParams -from vllm.tasks import PoolingTask -from vllm.utils.import_utils import resolve_obj_by_qualname -from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokenwisePoolerOutput -from vllm.v1.pool.metadata import PoolingMetadata - -logger = init_logger(__name__) - -PoolingFn = Callable[ - [torch.Tensor | list[torch.Tensor], PoolingMetadata], - torch.Tensor | list[torch.Tensor], -] -ClassifierFn = Callable[[torch.Tensor], torch.Tensor] - - -TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor] -TokenwisePoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None] -TokenwisePoolingMethodOutputItem: TypeAlias = torch.Tensor | None -PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokenwisePoolingMethodOutput - -TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor] -TokenwisePoolerHeadOutput: TypeAlias = torch.Tensor | None - - -@dataclass(frozen=True) -class ResolvedPoolingConfig: - pooling_type: PoolingTypeStr - task: PoolingTask - - @classmethod - def from_config( - cls, - task: PoolingTask, - pooler_config: PoolerConfig, - ) -> "ResolvedPoolingConfig": - assert pooler_config.pooling_type is not None - return cls(task=task, pooling_type=pooler_config.pooling_type) - - -@dataclass(frozen=True) -class PoolingParamsUpdate: - requires_token_ids: bool = False - """Set this flag to enable `get_prompt_token_ids` for your pooler.""" - - def apply(self, params: PoolingParams) -> None: - params.requires_token_ids = self.requires_token_ids - - -def get_classification_activation_function(config: PretrainedConfig): - # Implement alignment with transformers ForSequenceClassificationLoss - # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92 - problem_type = getattr(config, "problem_type", "") - if problem_type == "regression": - return PoolerIdentity() - if problem_type == "single_label_classification": - return PoolerClassify() - if problem_type == "multi_label_classification": - return PoolerMultiLabelClassify() - return PoolerClassify() - - -def get_cross_encoder_activation_function(config: PretrainedConfig): - function_name: str | None = None - if ( - hasattr(config, "sentence_transformers") - and "activation_fn" in config.sentence_transformers - ): - function_name = config.sentence_transformers["activation_fn"] - elif ( - hasattr(config, "sbert_ce_default_activation_function") - and config.sbert_ce_default_activation_function is not None - ): - function_name = config.sbert_ce_default_activation_function - - if function_name is not None: - assert function_name.startswith("torch.nn.modules."), ( - "Loading of activation functions is restricted to " - "torch.nn.modules for security reasons" - ) - fn = resolve_obj_by_qualname(function_name)() - return PoolerActivation.wraps(fn) - - return PoolerClassify() - - -class PoolingMethod(nn.Module, ABC): - @staticmethod - def from_pooling_type(pooling_type: PoolingTypeStr) -> "PoolingMethod": - if pooling_type == "LAST": - return LastPool() - if pooling_type == "ALL": - return AllPool() - if pooling_type == "CLS": - return CLSPool() - if pooling_type == "MEAN": - return MeanPool() - if pooling_type == "STEP": - raise ValueError( - "'STEP' pooling is handled by StepPooler " - "and is not a standalone PoolingMethod." - ) - - raise NotImplementedError(f"Unsupported method: {pooling_type!r}") - - @abstractmethod - def get_supported_tasks(self) -> Set[PoolingTask]: - raise NotImplementedError - - def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: - return PoolingParamsUpdate() - - @abstractmethod - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> PoolingMethodOutput: - raise NotImplementedError - - -class CLSPool(PoolingMethod): - def get_supported_tasks(self) -> Set[PoolingTask]: - return {"token_embed", "token_classify", "embed", "classify", "score"} - - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> TokenPoolingMethodOutput: - pooling_cursor = pooling_metadata.get_pooling_cursor() - assert not pooling_cursor.is_partial_prefill(), ( - "partial prefill not supported with CLS pooling" - ) - - return hidden_states[pooling_cursor.first_token_indices_gpu] - - -class LastPool(PoolingMethod): - def get_supported_tasks(self) -> Set[PoolingTask]: - return {"token_embed", "token_classify", "embed", "classify", "score"} - - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> TokenPoolingMethodOutput: - pooling_cursor = pooling_metadata.get_pooling_cursor() - return hidden_states[pooling_cursor.last_token_indices_gpu] - - -class AllPool(PoolingMethod): - def __init__(self): - super().__init__() - - vllm_config = get_current_vllm_config() - self.enable_chunked_prefill = ( - vllm_config.scheduler_config.enable_chunked_prefill - ) - - def get_supported_tasks(self) -> Set[PoolingTask]: - return {"token_embed", "token_classify"} - - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> TokenwisePoolingMethodOutput: - pooling_cursor = pooling_metadata.get_pooling_cursor() - hidden_states_all = hidden_states.split( - pooling_cursor.num_scheduled_tokens_cpu.tolist() - ) - hidden_states_lst = [hidden_states_all[i] for i in pooling_cursor.index] - - if not self.enable_chunked_prefill: - return hidden_states_lst - - pooling_states = pooling_metadata.pooling_states - - # If chunked_prefill is enabled - # 1. first store the chunked hidden_states in pooling_states.hidden_states_cache - for p, hs_chunk in zip(pooling_states, hidden_states_lst): - p.hidden_states_cache.append(hs_chunk) - - # 2. Once prefill is finished, send hidden_states_cache to PoolerHead - output_list = list[torch.Tensor | None]() - for p, finished in zip(pooling_states, pooling_cursor.is_finished()): - if finished: - hidden_states_cache = p.hidden_states_cache - if len(hidden_states_cache) == 1: - output_list.append(hidden_states_cache[0]) - else: - output_list.append(torch.concat(hidden_states_cache, dim=0)) - p.clean() - else: - output_list.append(None) - - return output_list - - -class MeanPool(PoolingMethod): - def get_supported_tasks(self) -> Set[PoolingTask]: - return {"token_embed", "token_classify", "embed", "classify", "score"} - - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> TokenPoolingMethodOutput: - pooling_cursor = pooling_metadata.get_pooling_cursor() - assert not pooling_cursor.is_partial_prefill(), ( - "partial prefill not supported with MEAN pooling" - ) - - prompt_lens = pooling_cursor.prompt_lens_cpu.to( - hidden_states.device, non_blocking=True - ) - - # Use float32 for torch.cumsum in MeanPool, - # otherwise precision will be lost significantly. - cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32) - - start_indices = pooling_cursor.first_token_indices_gpu - end_indices = pooling_cursor.last_token_indices_gpu - return ( - cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices] - ) / prompt_lens.unsqueeze(1) - - -_T = TypeVar("_T", torch.Tensor, list[torch.Tensor]) - - -class BasePoolerActivation(nn.Module, ABC): - @abstractmethod - def forward(self, pooled_data: _T) -> _T: - # shape: - # classify (& score) -> (batch_size, num_classes) - # embed -> (batch_size, embedding_dim) or list(embedding_dim) - # (batch_size, dimensions) or list(dimensions) if using MRL - raise NotImplementedError - - -class PoolerActivation(BasePoolerActivation): - @staticmethod - def wraps(module: nn.Module): - if isinstance(module, nn.Identity): - return PoolerIdentity() - if isinstance(module, (nn.Sigmoid, nn.Softmax)): - return PoolerClassify() - - return LambdaPoolerActivation(module) - - @abstractmethod - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - raise NotImplementedError - - def forward(self, pooled_data: _T) -> _T: - if isinstance(pooled_data, list): - return [self.forward_chunk(data) for data in pooled_data] - - return self.forward_chunk(pooled_data) - - -class PoolerIdentity(PoolerActivation): - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - return pooled_data - - -class PoolerNormalize(PoolerActivation): - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - return F.normalize(pooled_data, p=2, dim=-1) - - -class PoolerMultiLabelClassify(PoolerActivation): - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - return F.sigmoid(pooled_data) - - -class PoolerClassify(PoolerActivation): - def __init__(self, *, static_num_labels: bool = True) -> None: - super().__init__() - - if static_num_labels: - vllm_config = get_current_vllm_config() - self.num_labels = getattr( - vllm_config.model_config.hf_config, "num_labels", 0 - ) - if self.num_labels == 0: - logger.warning( - "num_labels should be > 0 for classification" - "models, falling back to softmax. " - "Please check if the configuration is correct." - ) - else: - self.num_labels = None - - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - num_labels = ( - self.num_labels if self.num_labels is not None else pooled_data.shape[-1] - ) - - if num_labels < 2: - return F.sigmoid(pooled_data) - - return F.softmax(pooled_data, dim=-1) - - -class LambdaPoolerActivation(PoolerActivation): - def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]): - super().__init__() - - self.fn = fn - - def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: - return self.fn(pooled_data) - - -class Pooler(nn.Module, ABC): - """The interface required for all poolers used in pooling models in vLLM.""" - - @staticmethod - def for_token_embed(pooler_config: PoolerConfig): - head = TokenEmbeddingPoolerHead() - - if pooler_config.pooling_type == "STEP": - return StepPooler(head=head) - - return AllPooler(head=head) - - @staticmethod - def for_token_classify( - pooler_config: PoolerConfig, - classifier: ClassifierFn | None = None, - act_fn: PoolerActivation | str | None = None, - ): - head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn) - - if pooler_config.pooling_type == "STEP": - return StepPooler(head=head) - - return AllPooler(head=head) - - @staticmethod - def for_embed(pooler_config: PoolerConfig): - resolved_config = ResolvedPoolingConfig.from_config( - task="embed", - pooler_config=pooler_config, - ) - - pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type) - head = EmbeddingPoolerHead() - - return SimplePooler(pooling=pooling, head=head) - - @staticmethod - def for_classify( - pooler_config: PoolerConfig, - classifier: ClassifierFn | None, - act_fn: PoolerActivation | str | None = None, - ): - resolved_config = ResolvedPoolingConfig.from_config( - task="classify", - pooler_config=pooler_config, - ) - - pooling = PoolingMethod.from_pooling_type(resolved_config.pooling_type) - - return ClassifierPooler( - pooling=pooling, - classifier=classifier, - act_fn=act_fn, - ) - - @abstractmethod - def get_supported_tasks(self) -> Set[PoolingTask]: - """Determine which pooling tasks are supported.""" - raise NotImplementedError - - def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: - """ - Construct the updated pooling parameters to use for a supported task. - """ - return PoolingParamsUpdate() - - @abstractmethod - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - raise NotImplementedError - - -class DummyPooler(Pooler): - def get_supported_tasks(self) -> Set[PoolingTask]: - return {"plugin", "score"} - - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - return hidden_states - - -class TokenPoolerHead(nn.Module, ABC): - """Applicable to pooling strategies that output one token.""" - - @abstractmethod - def forward( - self, - pooled_data: TokenPoolingMethodOutput, - pooling_metadata: PoolingMetadata, - ) -> TokenPoolerHeadOutput: - raise NotImplementedError - - -class EmbeddingPoolerHead(TokenPoolerHead): - def __init__(self) -> None: - super().__init__() - - # Load ST projector if available - vllm_config = get_current_vllm_config() - self.projector = ( - _load_st_projector(vllm_config.model_config) if vllm_config else None - ) - self.head_dtype = vllm_config.model_config.head_dtype - - self.activation = PoolerNormalize() - - def forward( - self, - pooled_data: TokenPoolingMethodOutput, - pooling_metadata: PoolingMetadata, - ) -> TokenPoolerHeadOutput: - if isinstance(pooled_data, list): - pooled_data = torch.stack(pooled_data) - # pooled_data shape: [batchsize, hidden_dimension] - - pooled_data = pooled_data.to(self.head_dtype) - - # Apply ST projector - if self.projector is not None: - pooled_data = self.projector(pooled_data) - # pooled_data shape: [batchsize, embedding_dimension] - - pooling_params = pooling_metadata.pooling_params - - # for matryoshka representation - dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params] - if any(d is not None for d in dimensions_list): - # change the output dimension - assert len(pooled_data) == len(dimensions_list) - if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list): - # if all dimensions are the same - d = dimensions_list[0] - pooled_data = pooled_data[..., :d] - else: - pooled_data = [ - vecs if d is None else vecs[..., :d] - for vecs, d in zip(pooled_data, dimensions_list) - ] - - # for normalize - flags = [p.normalize for p in pooling_params] - if len(set(flags)) == 1: - if flags[0]: - pooled_data = self.activation(pooled_data) - else: - pooled_data = [ - self.activation(vecs) if f else vecs - for vecs, f in zip(pooled_data, flags) - ] - - # pooled_data shape: [batchsize, embedding_dimension] - return pooled_data - - -class SimplePooler(Pooler): - """A layer that pools specific information from hidden states. - - This layer does the following: - 1. Extracts specific tokens or aggregates data based on pooling method. - 2. Normalizes output if specified. - 3. Returns structured results as `PoolerOutput`. - """ - - def __init__(self, pooling: PoolingMethod, head: TokenPoolerHead) -> None: - super().__init__() - - self.pooling = pooling - self.head = head - - def get_supported_tasks(self) -> Set[PoolingTask]: - return self.pooling.get_supported_tasks() - - def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: - return self.pooling.get_pooling_updates(task) - - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> TokenPoolerHeadOutput: - pooled_data = self.pooling(hidden_states, pooling_metadata) - pooled_data = self.head(pooled_data, pooling_metadata) - return pooled_data - - -class ClassifierPooler(Pooler): - """A pooling layer for classification tasks. - - This layer does the following: - 1. Applies a classification layer to the hidden states. - 2. Optionally applies a pooler layer. - 3. Applies an activation function to the output. - """ - - @staticmethod - def act_fn_for_seq_cls(model_config: ModelConfig): - return get_classification_activation_function(model_config.hf_config) - - @staticmethod - def act_fn_for_cross_encoder(model_config: ModelConfig): - return get_cross_encoder_activation_function(model_config.hf_config) - - @staticmethod - def resolve_act_fn( - model_config: ModelConfig, - static_num_labels: bool = True, - act_fn: PoolerActivation | str | None = None, - ): - if isinstance(act_fn, str): - if act_fn == "classify": - return ClassifierPooler.act_fn_for_seq_cls(model_config) - elif act_fn == "score": - return ClassifierPooler.act_fn_for_cross_encoder(model_config) - else: - raise ValueError(f"act_fn [{act_fn=}] not supported.") - elif act_fn is None: - return PoolerClassify(static_num_labels=static_num_labels) - else: - assert callable(act_fn) - return act_fn - - def __init__( - self, - pooling: PoolingFn, - classifier: ClassifierFn | None, - act_fn: PoolerActivation | str | None = None, - ) -> None: - super().__init__() - - vllm_config = get_current_vllm_config() - self.pooling = pooling - self.classifier = classifier - self.act_fn = self.resolve_act_fn( - vllm_config.model_config, static_num_labels=True, act_fn=act_fn - ) - self.logit_bias: float | None = ( - vllm_config.model_config.pooler_config.logit_bias - ) - self.head_dtype = vllm_config.model_config.head_dtype - - def get_supported_tasks(self) -> Set[PoolingTask]: - return {"classify", "score"} - - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> TokenPoolerOutput: - pooled_data = self.pooling(hidden_states, pooling_metadata) - if isinstance(pooled_data, list): - pooled_data = torch.stack(pooled_data) - # pooled_data shape: [batchsize, hidden_size] - - pooled_data = pooled_data.to(self.head_dtype) - - if self.classifier is not None: - pooled_data = self.classifier(pooled_data) - # pooled_data shape: [batchsize, num_labels] - - if self.logit_bias is not None: - pooled_data -= self.logit_bias - - pooling_params = pooling_metadata.pooling_params - flags = [p.use_activation for p in pooling_params] - - if len(set(flags)) == 1: - scores = self.act_fn(pooled_data) if flags[0] else pooled_data - else: - scores = [ - self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags) - ] - - # scores shape: [batchsize, num_labels] - return scores - - -class TokenwisePoolerHead(nn.Module, ABC): - """Applicable to pooling strategies that output multiple tokens.""" - - @abstractmethod - def forward( - self, - pooled_data: TokenwisePoolingMethodOutputItem, - pooling_param: PoolingParams, - ) -> TokenwisePoolerHeadOutput: - raise NotImplementedError - - -class TokenEmbeddingPoolerHead(TokenwisePoolerHead): - def __init__(self) -> None: - super().__init__() - - # Load ST projector if available - vllm_config = get_current_vllm_config() - self.projector = ( - _load_st_projector(vllm_config.model_config) if vllm_config else None - ) - self.head_dtype = vllm_config.model_config.head_dtype - - self.activation = PoolerNormalize() - - def forward( - self, - pooled_data: TokenwisePoolingMethodOutputItem, - pooling_param: PoolingParams, - ) -> TokenwisePoolerHeadOutput: - # for unfinished chunked prefill - if pooled_data is None: - return None - - pooled_data = pooled_data.to(self.head_dtype) - # pooled_data shape: [n_tokens, hidden_dimension] - - # Apply ST projector - if self.projector is not None: - pooled_data = self.projector(pooled_data) - # pooled_data shape: [n_tokens, embedding_dimension] - - # for matryoshka representation - pooled_data = pooled_data[..., : pooling_param.dimensions] - - # for normalize - if pooling_param.normalize: - pooled_data = self.activation(pooled_data) - - # pooled_data shape: [n_tokens, embedding_dimension] - return pooled_data - - -class TokenClassifierPoolerHead(TokenwisePoolerHead): - def __init__( - self, - classifier: ClassifierFn | None, - act_fn: PoolerActivation | str | None = None, - ) -> None: - super().__init__() - - vllm_config = get_current_vllm_config() - - self.classifier = classifier - self.logit_bias: float | None = ( - vllm_config.model_config.pooler_config.logit_bias - ) - self.head_dtype = vllm_config.model_config.head_dtype - - self.activation = ClassifierPooler.resolve_act_fn( - vllm_config.model_config, static_num_labels=False, act_fn=act_fn - ) - - def forward( - self, - pooled_data: TokenwisePoolingMethodOutputItem, - pooling_param: PoolingParams, - ) -> TokenwisePoolerHeadOutput: - # for unfinished chunked prefill - if pooled_data is None: - return None - - pooled_data = pooled_data.to(self.head_dtype) - # hidden_states shape: [n_token, hidden_size] - - if self.classifier is not None: - scores = self.classifier(pooled_data) - else: - scores = pooled_data - # scores shape: [n_token, num_labels] - - if self.logit_bias is not None: - scores -= self.logit_bias - - if pooling_param.use_activation: - scores = self.activation(scores) - - # scores shape: [n_token, num_labels] - return scores - - -class AllPooler(Pooler): - def __init__(self, head: TokenwisePoolerHead) -> None: - super().__init__() - - self.pooling = AllPool() - self.head = head - - def get_supported_tasks(self) -> Set[PoolingTask]: - return {"token_embed", "token_classify"} - - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> TokenwisePoolerOutput: - pooled_data = self.pooling(hidden_states, pooling_metadata) - pooling_params = pooling_metadata.pooling_params - assert len(pooled_data) == len(pooling_params) - - return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] - - -class StepPooler(Pooler): - def __init__(self, head: TokenwisePoolerHead) -> None: - super().__init__() - - self.pooling = AllPool() - self.head = head - - def extract_states( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> list[torch.Tensor | None]: - pooled_data_lst = self.pooling(hidden_states, pooling_metadata) - prompt_token_ids = pooling_metadata.get_prompt_token_ids() - pooling_params = pooling_metadata.pooling_params - - pooled_data = list[torch.Tensor | None]() - for data, token_id, pooling_param in zip( - pooled_data_lst, prompt_token_ids, pooling_params - ): - # for unfinished chunked prefill - if data is None: - pooled_data.append(data) - continue - - step_tag_id = pooling_param.step_tag_id - returned_token_ids = pooling_param.returned_token_ids - - if returned_token_ids is not None and len(returned_token_ids) > 0: - data = data[:, returned_token_ids] - - if step_tag_id is not None: - data = data[token_id == step_tag_id] - pooled_data.append(data) - - return pooled_data - - def get_supported_tasks(self) -> Set[PoolingTask]: - return {"token_embed", "token_classify"} - - def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: - return PoolingParamsUpdate(requires_token_ids=True) - - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> TokenwisePoolerOutput: - pooled_data = self.extract_states(hidden_states, pooling_metadata) - pooling_params = pooling_metadata.pooling_params - assert len(pooled_data) == len(pooling_params) - - return [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] - - -class DispatchPooler(Pooler): - """Dispatches calls to a sub-pooler based on the pooling task.""" - - def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None: - super().__init__() - - for task, pooler in poolers_by_task.items(): - if task not in pooler.get_supported_tasks(): - raise ValueError( - f"{pooler=} does not support {task=}. " - f"Supported tasks: {pooler.get_supported_tasks()}" - ) - - self.poolers_by_task = poolers_by_task - - def get_supported_tasks(self) -> Set[PoolingTask]: - return set(self.poolers_by_task) - - def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: - return self.poolers_by_task[task].get_pooling_updates(task) - - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - poolers_by_task = self.poolers_by_task - - outputs = list[torch.Tensor | None]() - offset = 0 - for task, group in groupby(pooling_metadata.tasks): - if not (pooler := poolers_by_task.get(task)): - raise ValueError( - f"Unsupported task: {task} " - f"Supported tasks: {self.get_supported_tasks()}" - ) - - num_items = len(list(group)) - group_output: PoolerOutput = pooler( - hidden_states, - pooling_metadata[offset : offset + num_items], - ) - - outputs.extend(group_output) - offset += num_items - - return outputs - - def extra_repr(self) -> str: - s = f"supported_task={self.get_supported_tasks()}" - return s diff --git a/vllm/model_executor/layers/pooler/__init__.py b/vllm/model_executor/layers/pooler/__init__.py new file mode 100644 index 000000000..2be361338 --- /dev/null +++ b/vllm/model_executor/layers/pooler/__init__.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .abstract import * +from .common import * +from .special import * diff --git a/vllm/model_executor/layers/pooler/abstract.py b/vllm/model_executor/layers/pooler/abstract.py new file mode 100644 index 000000000..82abef4f6 --- /dev/null +++ b/vllm/model_executor/layers/pooler/abstract.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from collections.abc import Set + +import torch +import torch.nn as nn + +from vllm.tasks import PoolingTask +from vllm.v1.outputs import PoolerOutput +from vllm.v1.pool.metadata import PoolingMetadata + +from .common import PoolingParamsUpdate + + +class Pooler(nn.Module, ABC): + """The interface required for all poolers used in pooling models in vLLM.""" + + @abstractmethod + def get_supported_tasks(self) -> Set[PoolingTask]: + """Determine which pooling tasks are supported.""" + raise NotImplementedError + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + """ + Construct the updated pooling parameters to use for a supported task. + """ + return PoolingParamsUpdate() + + @abstractmethod + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + raise NotImplementedError + + +__all__ = ["Pooler"] diff --git a/vllm/model_executor/layers/pooler/activations.py b/vllm/model_executor/layers/pooler/activations.py new file mode 100644 index 000000000..b57e6ba68 --- /dev/null +++ b/vllm/model_executor/layers/pooler/activations.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import TypeVar + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.config import ModelConfig, get_current_vllm_config +from vllm.logger import init_logger +from vllm.utils.import_utils import resolve_obj_by_qualname + +logger = init_logger(__name__) + + +def get_classification_act_fn( + config: PretrainedConfig, +) -> "PoolerActivation": + # Implement alignment with transformers ForSequenceClassificationLoss + # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92 + problem_type = getattr(config, "problem_type", "") + if problem_type == "regression": + return PoolerIdentity() + if problem_type == "single_label_classification": + return PoolerClassify() + if problem_type == "multi_label_classification": + return PoolerMultiLabelClassify() + + return PoolerClassify() + + +def get_cross_encoder_act_fn( + config: PretrainedConfig, +) -> "PoolerActivation": + function_name: str | None = None + if ( + hasattr(config, "sentence_transformers") + and "activation_fn" in config.sentence_transformers + ): + function_name = config.sentence_transformers["activation_fn"] + elif ( + hasattr(config, "sbert_ce_default_activation_function") + and config.sbert_ce_default_activation_function is not None + ): + function_name = config.sbert_ce_default_activation_function + + if function_name is not None: + assert function_name.startswith("torch.nn.modules."), ( + "Loading of activation functions is restricted to " + "torch.nn.modules for security reasons" + ) + fn = resolve_obj_by_qualname(function_name)() + return PoolerActivation.wraps(fn) + + return PoolerClassify() + + +def resolve_classifier_act_fn( + model_config: ModelConfig, + static_num_labels: bool = True, + act_fn: "PoolerActivation | str | None" = None, +): + if isinstance(act_fn, str): + if act_fn == "classify": + return get_classification_act_fn(model_config.hf_config) + if act_fn == "score": + return get_cross_encoder_act_fn(model_config.hf_config) + + raise ValueError(f"act_fn [{act_fn=}] not supported.") + + if act_fn is None: + return PoolerClassify(static_num_labels=static_num_labels) + + assert callable(act_fn) + return act_fn + + +_T = TypeVar("_T", torch.Tensor, list[torch.Tensor]) + + +class PoolerActivation(nn.Module, ABC): + @staticmethod + def wraps(module: nn.Module): + if isinstance(module, nn.Identity): + return PoolerIdentity() + if isinstance(module, (nn.Sigmoid, nn.Softmax)): + return PoolerClassify() + + return LambdaPoolerActivation(module) + + @abstractmethod + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def forward(self, pooled_data: _T) -> _T: + # shape: + # classify (& score) -> (batch_size, num_classes) + # embed -> (batch_size, embedding_dim) or list(embedding_dim) + # (batch_size, dimensions) or list(dimensions) if using MRL + if isinstance(pooled_data, list): + return [self.forward_chunk(data) for data in pooled_data] + + return self.forward_chunk(pooled_data) + + +class PoolerIdentity(PoolerActivation): + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + return pooled_data + + +class PoolerNormalize(PoolerActivation): + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + return F.normalize(pooled_data, p=2, dim=-1) + + +class PoolerMultiLabelClassify(PoolerActivation): + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + return F.sigmoid(pooled_data) + + +class PoolerClassify(PoolerActivation): + def __init__(self, *, static_num_labels: bool = True) -> None: + super().__init__() + + if static_num_labels: + vllm_config = get_current_vllm_config() + model_config = vllm_config.model_config + num_labels = getattr(model_config.hf_config, "num_labels", 0) + else: + num_labels = None + + if num_labels == 0: + logger.warning( + "num_labels should be > 0 for classification " + "models, falling back to softmax. " + "Please check if the configuration is correct." + ) + + self.num_labels = num_labels + + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + num_labels = self.num_labels + if num_labels is None: + num_labels = pooled_data.shape[-1] + + if num_labels < 2: + return F.sigmoid(pooled_data) + + return F.softmax(pooled_data, dim=-1) + + +class LambdaPoolerActivation(PoolerActivation): + def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]): + super().__init__() + + self.fn = fn + + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + return self.fn(pooled_data) diff --git a/vllm/model_executor/layers/pooler/common.py b/vllm/model_executor/layers/pooler/common.py new file mode 100644 index 000000000..7dc77cf79 --- /dev/null +++ b/vllm/model_executor/layers/pooler/common.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable +from dataclasses import dataclass + +import torch + +from vllm.pooling_params import PoolingParams + +ClassifierFn = Callable[[torch.Tensor], torch.Tensor] + + +@dataclass(frozen=True) +class PoolingParamsUpdate: + requires_token_ids: bool = False + """Set this flag to enable `get_prompt_token_ids` for your pooler.""" + + def __or__(self, other: "PoolingParamsUpdate") -> "PoolingParamsUpdate": + return PoolingParamsUpdate( + requires_token_ids=self.requires_token_ids or other.requires_token_ids, + ) + + def apply(self, params: PoolingParams) -> None: + params.requires_token_ids = self.requires_token_ids + + +__all__ = ["ClassifierFn", "PoolingParamsUpdate"] diff --git a/vllm/model_executor/layers/pooler/seqwise/__init__.py b/vllm/model_executor/layers/pooler/seqwise/__init__.py new file mode 100644 index 000000000..e1b0476a5 --- /dev/null +++ b/vllm/model_executor/layers/pooler/seqwise/__init__.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Poolers that produce an output aggregating all tokens in the sequence.""" + +from .heads import ( + ClassifierPoolerHead, + EmbeddingPoolerHead, + SequencePoolerHead, + SequencePoolerHeadOutput, +) +from .methods import ( + CLSPool, + LastPool, + MeanPool, + SequencePoolingMethod, + SequencePoolingMethodOutput, + get_seq_pooling_method, +) +from .poolers import ( + SequencePooler, + SequencePoolerOutput, + SequencePoolingFn, + SequencePoolingHeadFn, + pooler_for_classify, + pooler_for_embed, +) + +__all__ = [ + "SequencePoolerHead", + "SequencePoolerHeadOutput", + "ClassifierPoolerHead", + "EmbeddingPoolerHead", + "SequencePoolingMethod", + "SequencePoolingMethodOutput", + "CLSPool", + "LastPool", + "MeanPool", + "get_seq_pooling_method", + "SequencePooler", + "SequencePoolingFn", + "SequencePoolingHeadFn", + "SequencePoolerOutput", + "pooler_for_classify", + "pooler_for_embed", +] diff --git a/vllm/model_executor/layers/pooler/seqwise/heads.py b/vllm/model_executor/layers/pooler/seqwise/heads.py new file mode 100644 index 000000000..24aed94fd --- /dev/null +++ b/vllm/model_executor/layers/pooler/seqwise/heads.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from collections.abc import Set +from typing import TypeAlias + +import torch +import torch.nn as nn + +from vllm.config import get_current_vllm_config +from vllm.model_executor.layers.pooler import ClassifierFn +from vllm.model_executor.layers.pooler.activations import ( + PoolerActivation, + PoolerNormalize, + resolve_classifier_act_fn, +) +from vllm.model_executor.models.adapters import _load_st_projector +from vllm.tasks import PoolingTask +from vllm.v1.pool.metadata import PoolingMetadata + +from .methods import SequencePoolingMethodOutput + +SequencePoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor] + + +class SequencePoolerHead(nn.Module, ABC): + @abstractmethod + def get_supported_tasks(self) -> Set[PoolingTask]: + raise NotImplementedError + + @abstractmethod + def forward( + self, + pooled_data: SequencePoolingMethodOutput, + pooling_metadata: PoolingMetadata, + ) -> SequencePoolerHeadOutput: + raise NotImplementedError + + +class EmbeddingPoolerHead(SequencePoolerHead): + def __init__(self) -> None: + super().__init__() + + # Load ST projector if available + vllm_config = get_current_vllm_config() + model_config = vllm_config.model_config + + self.projector = _load_st_projector(model_config) + self.head_dtype = model_config.head_dtype + + self.activation = PoolerNormalize() + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"embed"} + + def forward( + self, + pooled_data: SequencePoolingMethodOutput, + pooling_metadata: PoolingMetadata, + ) -> SequencePoolerHeadOutput: + pooling_params = pooling_metadata.pooling_params + assert len(pooled_data) == len(pooling_params) + + if isinstance(pooled_data, list): + pooled_data = torch.stack(pooled_data) + # pooled_data shape: [batchsize, hidden_dimension] + + pooled_data = pooled_data.to(self.head_dtype) + + # Apply ST projector + if self.projector is not None: + pooled_data = self.projector(pooled_data) + # pooled_data shape: [batchsize, embedding_dimension] + + # for matryoshka representation + dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params] + if any(d is not None for d in dimensions_list): + # change the output dimension + assert len(pooled_data) == len(dimensions_list) + if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list): + # if all dimensions are the same + d = dimensions_list[0] + pooled_data = pooled_data[..., :d] + else: + pooled_data = [ + vecs if d is None else vecs[..., :d] + for vecs, d in zip(pooled_data, dimensions_list) + ] + + # for normalize + flags = [p.normalize for p in pooling_params] + if len(set(flags)) == 1: + if flags[0]: + pooled_data = self.activation(pooled_data) + else: + pooled_data = [ + self.activation(vecs) if f else vecs + for vecs, f in zip(pooled_data, flags) + ] + + # pooled_data shape: [batchsize, embedding_dimension] + return pooled_data + + +class ClassifierPoolerHead(SequencePoolerHead): + def __init__( + self, + classifier: ClassifierFn | None = None, + act_fn: PoolerActivation | str | None = None, + ) -> None: + super().__init__() + + vllm_config = get_current_vllm_config() + model_config = vllm_config.model_config + + self.classifier = classifier + self.logit_bias: float | None = model_config.pooler_config.logit_bias + self.head_dtype = model_config.head_dtype + + self.act_fn = resolve_classifier_act_fn( + model_config, static_num_labels=True, act_fn=act_fn + ) + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"classify", "score"} + + def forward( + self, + pooled_data: SequencePoolingMethodOutput, + pooling_metadata: PoolingMetadata, + ) -> SequencePoolerHeadOutput: + pooling_params = pooling_metadata.pooling_params + assert len(pooled_data) == len(pooling_params) + + if isinstance(pooled_data, list): + pooled_data = torch.stack(pooled_data) + # pooled_data shape: [batchsize, hidden_size] + + pooled_data = pooled_data.to(self.head_dtype) + + if self.classifier is not None: + pooled_data = self.classifier(pooled_data) + # pooled_data shape: [batchsize, num_labels] + + if self.logit_bias is not None: + pooled_data -= self.logit_bias + + flags = [p.use_activation for p in pooling_params] + if len(set(flags)) == 1: + scores = self.act_fn(pooled_data) if flags[0] else pooled_data + else: + scores = [ + self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags) + ] + + # scores shape: [batchsize, num_labels] + return scores diff --git a/vllm/model_executor/layers/pooler/seqwise/methods.py b/vllm/model_executor/layers/pooler/seqwise/methods.py new file mode 100644 index 000000000..e71a9de3f --- /dev/null +++ b/vllm/model_executor/layers/pooler/seqwise/methods.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from collections.abc import Set +from typing import TypeAlias + +import torch +import torch.nn as nn + +from vllm.config.pooler import PoolingTypeStr +from vllm.model_executor.layers.pooler import PoolingParamsUpdate +from vllm.tasks import PoolingTask +from vllm.v1.pool.metadata import PoolingMetadata + +SequencePoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor] + + +class SequencePoolingMethod(nn.Module, ABC): + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"token_embed", "token_classify", "embed", "classify", "score"} + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate() + + @abstractmethod + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> SequencePoolingMethodOutput: + raise NotImplementedError + + +class CLSPool(SequencePoolingMethod): + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> SequencePoolingMethodOutput: + pooling_cursor = pooling_metadata.get_pooling_cursor() + assert not pooling_cursor.is_partial_prefill(), ( + "partial prefill not supported with CLS pooling" + ) + + return hidden_states[pooling_cursor.first_token_indices_gpu] + + +class LastPool(SequencePoolingMethod): + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> SequencePoolingMethodOutput: + pooling_cursor = pooling_metadata.get_pooling_cursor() + return hidden_states[pooling_cursor.last_token_indices_gpu] + + +class MeanPool(SequencePoolingMethod): + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> SequencePoolingMethodOutput: + pooling_cursor = pooling_metadata.get_pooling_cursor() + assert not pooling_cursor.is_partial_prefill(), ( + "partial prefill not supported with MEAN pooling" + ) + + prompt_lens = pooling_cursor.prompt_lens_cpu.to( + hidden_states.device, non_blocking=True + ) + + # Use float32 for torch.cumsum in MeanPool, + # otherwise precision will be lost significantly. + cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32) + + start_indices = pooling_cursor.first_token_indices_gpu + end_indices = pooling_cursor.last_token_indices_gpu + + return ( + cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices] + ) / prompt_lens.unsqueeze(1) + + +def get_seq_pooling_method(pooling_type: PoolingTypeStr | str): + if pooling_type == "LAST": + return LastPool() + if pooling_type == "CLS": + return CLSPool() + if pooling_type == "MEAN": + return MeanPool() + + raise NotImplementedError(f"Unknown sequence pooling type: {pooling_type!r}") diff --git a/vllm/model_executor/layers/pooler/seqwise/poolers.py b/vllm/model_executor/layers/pooler/seqwise/poolers.py new file mode 100644 index 000000000..586dcfb99 --- /dev/null +++ b/vllm/model_executor/layers/pooler/seqwise/poolers.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable, Set +from typing import TypeAlias + +import torch + +from vllm.config import PoolerConfig +from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate +from vllm.model_executor.layers.pooler.abstract import Pooler +from vllm.model_executor.layers.pooler.activations import PoolerActivation +from vllm.tasks import POOLING_TASKS, PoolingTask +from vllm.v1.pool.metadata import PoolingMetadata + +from .heads import ( + ClassifierPoolerHead, + EmbeddingPoolerHead, + SequencePoolerHead, + SequencePoolerHeadOutput, +) +from .methods import ( + SequencePoolingMethod, + SequencePoolingMethodOutput, + get_seq_pooling_method, +) + +SequencePoolingFn: TypeAlias = Callable[ + [torch.Tensor, PoolingMetadata], + SequencePoolingMethodOutput, +] +SequencePoolingHeadFn: TypeAlias = Callable[ + [SequencePoolingMethodOutput, PoolingMetadata], + SequencePoolerHeadOutput, +] + +SequencePoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor] + + +class SequencePooler(Pooler): + """ + A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Postprocesses the output based on pooling head. + 3. Returns structured results as `PoolerOutput`. + """ + + def __init__( + self, + pooling: SequencePoolingMethod | SequencePoolingFn, + head: SequencePoolerHead | SequencePoolingHeadFn, + ) -> None: + super().__init__() + + self.pooling = pooling + self.head = head + + def get_supported_tasks(self) -> Set[PoolingTask]: + tasks = set(POOLING_TASKS) + + if isinstance(self.pooling, SequencePoolingMethod): + tasks &= self.pooling.get_supported_tasks() + if isinstance(self.head, SequencePoolerHead): + tasks &= self.head.get_supported_tasks() + + return tasks + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + updates = PoolingParamsUpdate() + + if isinstance(self.pooling, SequencePoolingMethod): + updates |= self.pooling.get_pooling_updates(task) + + return updates + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> SequencePoolerOutput: + pooled_data = self.pooling(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data, pooling_metadata) + return pooled_data + + +def pooler_for_embed(pooler_config: PoolerConfig): + pooling = get_seq_pooling_method(pooler_config.get_pooling_type()) + head = EmbeddingPoolerHead() + + return SequencePooler(pooling=pooling, head=head) + + +def pooler_for_classify( + pooler_config: PoolerConfig, + *, + pooling: SequencePoolingMethod | SequencePoolingFn | None = None, + classifier: ClassifierFn | None = None, + act_fn: PoolerActivation | str | None = None, +): + if pooling is None: + pooling = get_seq_pooling_method(pooler_config.get_pooling_type()) + + head = ClassifierPoolerHead(classifier=classifier, act_fn=act_fn) + + return SequencePooler(pooling=pooling, head=head) diff --git a/vllm/model_executor/layers/pooler/special.py b/vllm/model_executor/layers/pooler/special.py new file mode 100644 index 000000000..425f61a98 --- /dev/null +++ b/vllm/model_executor/layers/pooler/special.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Mapping, Set +from itertools import groupby + +import torch + +from vllm.config import PoolerConfig +from vllm.model_executor.layers.pooler import PoolingParamsUpdate +from vllm.tasks import PoolingTask +from vllm.v1.pool.metadata import PoolingMetadata + +from .abstract import Pooler, PoolerOutput +from .common import ClassifierFn +from .seqwise import ( + SequencePoolingFn, + SequencePoolingMethod, + pooler_for_classify, + pooler_for_embed, +) +from .tokwise import AllPool, pooler_for_token_classify, pooler_for_token_embed + + +class DispatchPooler(Pooler): + """Dispatches calls to a sub-pooler based on the pooling task.""" + + @classmethod + def for_embedding(cls, pooler_config: PoolerConfig): + return cls( + { + "token_embed": pooler_for_token_embed(pooler_config), + "embed": pooler_for_embed(pooler_config), + }, + ) + + @classmethod + def for_seq_cls( + cls, + pooler_config: PoolerConfig, + *, + pooling: SequencePoolingMethod | SequencePoolingFn | None = None, + classifier: ClassifierFn | None = None, + ): + return cls( + { + "token_classify": pooler_for_token_classify( + pooler_config, + pooling=AllPool(), + classifier=classifier, + ), + "classify": pooler_for_classify( + pooler_config, + pooling=pooling, + classifier=classifier, + act_fn="classify", + ), + "score": pooler_for_classify( + pooler_config, + pooling=pooling, + classifier=classifier, + act_fn="score", + ), + } + ) + + def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None: + super().__init__() + + for task, pooler in poolers_by_task.items(): + if task not in pooler.get_supported_tasks(): + raise ValueError( + f"{pooler=} does not support {task=}. " + f"Supported tasks: {pooler.get_supported_tasks()}" + ) + + self.poolers_by_task = poolers_by_task + + def get_supported_tasks(self) -> Set[PoolingTask]: + return set(self.poolers_by_task) + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return self.poolers_by_task[task].get_pooling_updates(task) + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + poolers_by_task = self.poolers_by_task + + outputs = list[torch.Tensor | None]() + offset = 0 + for task, group in groupby(pooling_metadata.tasks): + if not (pooler := poolers_by_task.get(task)): + raise ValueError( + f"Unsupported task: {task!r} " + f"Supported tasks: {self.get_supported_tasks()}" + ) + + num_items = len(list(group)) + group_output: PoolerOutput = pooler( + hidden_states, + pooling_metadata[offset : offset + num_items], + ) + + outputs.extend(group_output) + offset += num_items + + return outputs + + def extra_repr(self) -> str: + s = f"supported_task={self.get_supported_tasks()}" + return s + + +class IdentityPooler(Pooler): + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"plugin", "score"} + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + return hidden_states + + +__all__ = ["DispatchPooler", "IdentityPooler"] diff --git a/vllm/model_executor/layers/pooler/tokwise/__init__.py b/vllm/model_executor/layers/pooler/tokwise/__init__.py new file mode 100644 index 000000000..fbc610c85 --- /dev/null +++ b/vllm/model_executor/layers/pooler/tokwise/__init__.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Poolers that produce an output for each token in the sequence.""" + +from .heads import ( + TokenClassifierPoolerHead, + TokenEmbeddingPoolerHead, + TokenPoolerHead, + TokenPoolerHeadOutputItem, +) +from .methods import ( + AllPool, + StepPool, + TokenPoolingMethod, + TokenPoolingMethodOutputItem, + get_tok_pooling_method, +) +from .poolers import ( + TokenPooler, + TokenPoolerOutput, + pooler_for_token_classify, + pooler_for_token_embed, +) + +__all__ = [ + "TokenPoolerHead", + "TokenPoolerHeadOutputItem", + "TokenClassifierPoolerHead", + "TokenEmbeddingPoolerHead", + "TokenPoolingMethod", + "TokenPoolingMethodOutputItem", + "AllPool", + "StepPool", + "get_tok_pooling_method", + "TokenPooler", + "TokenPoolerOutput", + "pooler_for_token_classify", + "pooler_for_token_embed", +] diff --git a/vllm/model_executor/layers/pooler/tokwise/heads.py b/vllm/model_executor/layers/pooler/tokwise/heads.py new file mode 100644 index 000000000..7421ff5c2 --- /dev/null +++ b/vllm/model_executor/layers/pooler/tokwise/heads.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from collections.abc import Set +from typing import TypeAlias + +import torch +import torch.nn as nn + +from vllm.config import get_current_vllm_config +from vllm.model_executor.layers.pooler import ClassifierFn +from vllm.model_executor.layers.pooler.activations import ( + PoolerActivation, + PoolerNormalize, + resolve_classifier_act_fn, +) +from vllm.model_executor.models.adapters import _load_st_projector +from vllm.pooling_params import PoolingParams +from vllm.tasks import PoolingTask +from vllm.v1.pool.metadata import PoolingMetadata + +from .methods import TokenPoolingMethodOutputItem + +TokenPoolerHeadOutputItem: TypeAlias = torch.Tensor | None + + +class TokenPoolerHead(nn.Module, ABC): + @abstractmethod + def get_supported_tasks(self) -> Set[PoolingTask]: + raise NotImplementedError + + @abstractmethod + def forward_chunk( + self, + pooled_data: TokenPoolingMethodOutputItem, + pooling_param: PoolingParams, + ) -> TokenPoolerHeadOutputItem: + raise NotImplementedError + + def forward( + self, + pooled_data: list[TokenPoolingMethodOutputItem], + pooling_metadata: PoolingMetadata, + ) -> list[TokenPoolerHeadOutputItem]: + pooling_params = pooling_metadata.pooling_params + assert len(pooled_data) == len(pooling_params) + + return [self.forward_chunk(d, p) for d, p in zip(pooled_data, pooling_params)] + + +class TokenEmbeddingPoolerHead(TokenPoolerHead): + def __init__(self) -> None: + super().__init__() + + # Load ST projector if available + vllm_config = get_current_vllm_config() + model_config = vllm_config.model_config + + self.projector = _load_st_projector(model_config) + self.head_dtype = model_config.head_dtype + + self.activation = PoolerNormalize() + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"token_embed"} + + def forward_chunk( + self, + pooled_data: TokenPoolingMethodOutputItem, + pooling_param: PoolingParams, + ) -> TokenPoolerHeadOutputItem: + # for unfinished chunked prefill + if pooled_data is None: + return None + + pooled_data = pooled_data.to(self.head_dtype) + # pooled_data shape: [n_tokens, hidden_dimension] + + # Apply ST projector + if self.projector is not None: + pooled_data = self.projector(pooled_data) + # pooled_data shape: [n_tokens, embedding_dimension] + + # for matryoshka representation + pooled_data = pooled_data[..., : pooling_param.dimensions] + + # for normalize + if pooling_param.normalize: + pooled_data = self.activation(pooled_data) + + # pooled_data shape: [n_tokens, embedding_dimension] + return pooled_data + + +class TokenClassifierPoolerHead(TokenPoolerHead): + def __init__( + self, + classifier: ClassifierFn | None = None, + act_fn: PoolerActivation | str | None = None, + ) -> None: + super().__init__() + + vllm_config = get_current_vllm_config() + model_config = vllm_config.model_config + + self.classifier = classifier + self.logit_bias: float | None = model_config.pooler_config.logit_bias + self.head_dtype = model_config.head_dtype + + self.act_fn = resolve_classifier_act_fn( + model_config, static_num_labels=False, act_fn=act_fn + ) + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"token_classify"} + + def forward_chunk( + self, + pooled_data: TokenPoolingMethodOutputItem, + pooling_param: PoolingParams, + ) -> TokenPoolerHeadOutputItem: + # for unfinished chunked prefill + if pooled_data is None: + return None + + pooled_data = pooled_data.to(self.head_dtype) + # hidden_states shape: [n_token, hidden_size] + + if self.classifier is not None: + scores = self.classifier(pooled_data) + else: + scores = pooled_data + # scores shape: [n_token, num_labels] + + if self.logit_bias is not None: + scores -= self.logit_bias + + if pooling_param.use_activation: + scores = self.act_fn(scores) + + # scores shape: [n_token, num_labels] + return scores diff --git a/vllm/model_executor/layers/pooler/tokwise/methods.py b/vllm/model_executor/layers/pooler/tokwise/methods.py new file mode 100644 index 000000000..4e84f57d7 --- /dev/null +++ b/vllm/model_executor/layers/pooler/tokwise/methods.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod +from collections.abc import Set +from typing import TypeAlias + +import torch +import torch.nn as nn + +from vllm.config import get_current_vllm_config +from vllm.config.pooler import PoolingTypeStr +from vllm.model_executor.layers.pooler import PoolingParamsUpdate +from vllm.tasks import PoolingTask +from vllm.v1.pool.metadata import PoolingMetadata + +TokenPoolingMethodOutputItem: TypeAlias = torch.Tensor | None + + +class TokenPoolingMethod(nn.Module, ABC): + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"token_embed", "token_classify"} + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate() + + @abstractmethod + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> list[TokenPoolingMethodOutputItem]: + raise NotImplementedError + + +class AllPool(TokenPoolingMethod): + def __init__(self): + super().__init__() + + vllm_config = get_current_vllm_config() + scheduler_config = vllm_config.scheduler_config + + self.enable_chunked_prefill = scheduler_config.enable_chunked_prefill + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> list[TokenPoolingMethodOutputItem]: + pooling_cursor = pooling_metadata.get_pooling_cursor() + hidden_states_all = hidden_states.split( + pooling_cursor.num_scheduled_tokens_cpu.tolist() + ) + hidden_states_lst = [hidden_states_all[i] for i in pooling_cursor.index] + + if not self.enable_chunked_prefill: + return hidden_states_lst + + pooling_states = pooling_metadata.pooling_states + + # If chunked_prefill is enabled + # 1. first store the chunked hidden_states in pooling_states.hidden_states_cache + for p, hs_chunk in zip(pooling_states, hidden_states_lst): + p.hidden_states_cache.append(hs_chunk) + + # 2. Once prefill is finished, send hidden_states_cache to PoolerHead + output_list = list[TokenPoolingMethodOutputItem]() + for p, finished in zip(pooling_states, pooling_cursor.is_finished()): + if finished: + hidden_states_cache = p.hidden_states_cache + if len(hidden_states_cache) == 1: + output_list.append(hidden_states_cache[0]) + else: + output_list.append(torch.concat(hidden_states_cache, dim=0)) + p.clean() + else: + output_list.append(None) + + return output_list + + +class StepPool(AllPool): + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate(requires_token_ids=True) + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> list[TokenPoolingMethodOutputItem]: + pooled_data_lst = super().forward(hidden_states, pooling_metadata) + prompt_token_ids = pooling_metadata.get_prompt_token_ids() + pooling_params = pooling_metadata.pooling_params + + pooled_data = list[torch.Tensor | None]() + for data, token_id, pooling_param in zip( + pooled_data_lst, prompt_token_ids, pooling_params + ): + # for unfinished chunked prefill + if data is None: + pass + else: + step_tag_id = pooling_param.step_tag_id + returned_token_ids = pooling_param.returned_token_ids + + if returned_token_ids is not None and len(returned_token_ids) > 0: + data = data[:, returned_token_ids] + + if step_tag_id is not None: + data = data[token_id == step_tag_id] + + pooled_data.append(data) + + return pooled_data + + +def get_tok_pooling_method(pooling_type: PoolingTypeStr | str): + if pooling_type == "ALL": + return AllPool() + if pooling_type == "STEP": + return StepPool() + + # TODO: Separate seq and tok pooling types so we don't need this fallback + return AllPool() + raise NotImplementedError(f"Unknown tokenwise pooling type: {pooling_type!r}") diff --git a/vllm/model_executor/layers/pooler/tokwise/poolers.py b/vllm/model_executor/layers/pooler/tokwise/poolers.py new file mode 100644 index 000000000..ff68359bb --- /dev/null +++ b/vllm/model_executor/layers/pooler/tokwise/poolers.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable, Set +from typing import TypeAlias + +import torch + +from vllm.config import PoolerConfig +from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate +from vllm.model_executor.layers.pooler.abstract import Pooler +from vllm.model_executor.layers.pooler.activations import PoolerActivation +from vllm.tasks import POOLING_TASKS, PoolingTask +from vllm.v1.pool.metadata import PoolingMetadata + +from .heads import ( + TokenClassifierPoolerHead, + TokenEmbeddingPoolerHead, + TokenPoolerHead, + TokenPoolerHeadOutputItem, +) +from .methods import ( + TokenPoolingMethod, + TokenPoolingMethodOutputItem, + get_tok_pooling_method, +) + +TokenPoolingFn: TypeAlias = Callable[ + [torch.Tensor, PoolingMetadata], + list[TokenPoolingMethodOutputItem], +] +TokenPoolingHeadFn: TypeAlias = Callable[ + [list[TokenPoolingMethodOutputItem], PoolingMetadata], + list[TokenPoolerHeadOutputItem], +] + +TokenPoolerOutput: TypeAlias = list[torch.Tensor | None] + + +class TokenPooler(Pooler): + """ + A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Postprocesses the output based on pooling head. + 3. Returns structured results as `PoolerOutput`. + """ + + def __init__( + self, + pooling: TokenPoolingMethod | TokenPoolingFn, + head: TokenPoolerHead | TokenPoolingHeadFn, + ) -> None: + super().__init__() + + self.pooling = pooling + self.head = head + + def get_supported_tasks(self) -> Set[PoolingTask]: + tasks = set(POOLING_TASKS) + + if isinstance(self.pooling, TokenPoolingMethod): + tasks &= self.pooling.get_supported_tasks() + if isinstance(self.head, TokenPoolerHead): + tasks &= self.head.get_supported_tasks() + + return tasks + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + updates = PoolingParamsUpdate() + + if isinstance(self.pooling, TokenPoolingMethod): + updates |= self.pooling.get_pooling_updates(task) + + return updates + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> TokenPoolerOutput: + pooled_data = self.pooling(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data, pooling_metadata) + return pooled_data + + +def pooler_for_token_embed(pooler_config: PoolerConfig): + pooling = get_tok_pooling_method(pooler_config.get_pooling_type()) + head = TokenEmbeddingPoolerHead() + + return TokenPooler(pooling=pooling, head=head) + + +def pooler_for_token_classify( + pooler_config: PoolerConfig, + *, + pooling: TokenPoolingMethod | TokenPoolingFn | None = None, + classifier: ClassifierFn | None = None, + act_fn: PoolerActivation | str | None = None, +): + if pooling is None: + pooling = get_tok_pooling_method(pooler_config.get_pooling_type()) + + head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn) + + return TokenPooler(pooling=pooling, head=head) diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 07fa72561..a2c9554c7 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -252,19 +252,14 @@ def as_embedding_model(cls: _T) -> _T: return cls # Lazy import - from vllm.model_executor.layers.pooler import DispatchPooler, Pooler + from vllm.model_executor.layers.pooler import DispatchPooler class ModelForEmbedding(_create_pooling_model_cls(cls)): def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - { - "token_embed": Pooler.for_token_embed(pooler_config), - "embed": Pooler.for_embed(pooler_config), - }, - ) + self.pooler = DispatchPooler.for_embedding(pooler_config) ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding") @@ -289,10 +284,7 @@ def as_seq_cls_model(cls: _T) -> _T: # Lazy import from vllm.model_executor.layers.linear import ReplicatedLinear - from vllm.model_executor.layers.pooler import ( - DispatchPooler, - Pooler, - ) + from vllm.model_executor.layers.pooler import DispatchPooler from vllm.model_executor.models.interfaces import SupportsCrossEncoding from .utils import maybe_prefix @@ -318,18 +310,8 @@ def as_seq_cls_model(cls: _T) -> _T: pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - { - "token_classify": Pooler.for_token_classify( - pooler_config, classifier=self.score - ), - "classify": Pooler.for_classify( - pooler_config, classifier=self.score, act_fn="classify" - ), - "score": Pooler.for_classify( - pooler_config, classifier=self.score, act_fn="score" - ), - } + self.pooler = DispatchPooler.for_seq_cls( + pooler_config, classifier=self.score ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index fd2f02641..b52f6d2bf 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -18,19 +18,25 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.pooler import ( - ClassifierPooler, DispatchPooler, Pooler, - PoolingMethod, PoolingParamsUpdate, - TokenPoolerHeadOutput, - TokenPoolingMethodOutput, +) +from vllm.model_executor.layers.pooler.seqwise import ( + CLSPool, + SequencePooler, + SequencePoolerHeadOutput, + SequencePoolerOutput, + SequencePoolingMethodOutput, +) +from vllm.model_executor.layers.pooler.tokwise import ( + pooler_for_token_classify, + pooler_for_token_embed, ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask -from vllm.v1.outputs import TokenPoolerOutput from vllm.v1.pool.metadata import PoolingMetadata from .interfaces import SupportsCrossEncoding, SupportsQuant @@ -85,25 +91,21 @@ class BertEmbedding(nn.Module): return embeddings -class BertPooler(Pooler): +class BertPooler(SequencePooler): def __init__(self, config: BertConfig): - super().__init__() + super().__init__( + pooling=CLSPool(), + head=self.head, + ) - self.pooling = PoolingMethod.from_pooling_type("CLS") self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() - def get_supported_tasks(self) -> Set[PoolingTask]: - return self.pooling.get_supported_tasks() - - def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: - return self.pooling.get_pooling_updates(task) - def head( self, - pooled_data: TokenPoolingMethodOutput, + pooled_data: SequencePoolingMethodOutput, pooling_metadata: PoolingMetadata, - ) -> TokenPoolerHeadOutput: + ) -> SequencePoolerHeadOutput: if isinstance(pooled_data, list): pooled_data = torch.stack(pooled_data) @@ -111,15 +113,6 @@ class BertPooler(Pooler): pooled_data = self.activation(pooled_data) return pooled_data - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> TokenPoolerOutput: - pooled_data = self.pooling(hidden_states, pooling_metadata) - pooled_data = self.head(pooled_data, pooling_metadata) - return pooled_data - class BertEncoder(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = ""): @@ -524,12 +517,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant): ) def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: - return DispatchPooler( - { - "token_embed": Pooler.for_token_embed(pooler_config), - "embed": Pooler.for_embed(pooler_config), - } - ) + return DispatchPooler.for_embedding(pooler_config) # Here we encode the token type ids together with the input ids. @@ -620,6 +608,7 @@ class SPLADESparsePooler(Pooler): remove_cls_sep: bool = True, ): super().__init__() + assert pooling in ("max", "sum") self.mlm_head = mlm_head self.cls_token_id = cls_token_id @@ -637,10 +626,8 @@ class SPLADESparsePooler(Pooler): self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> torch.Tensor: - assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2 - - lens_tensor: torch.Tensor = pooling_metadata.prompt_lens + ) -> SequencePoolerOutput: + lens_tensor = pooling_metadata.prompt_lens lens: list[int] = lens_tensor.tolist() B: int = len(lens) @@ -729,7 +716,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel): return DispatchPooler( { - "token_embed": Pooler.for_token_embed(pooler_config), + "token_embed": pooler_for_token_embed(pooler_config), "embed": SPLADESparsePooler( mlm_head=self.mlm_head, cls_token_id=cls_id, @@ -824,20 +811,10 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - { - "token_classify": Pooler.for_token_classify( - pooler_config, classifier=self.classifier - ), - "classify": ClassifierPooler( - pooling=self.bert.pooler, - classifier=self.classifier, - act_fn="classify", - ), - "score": ClassifierPooler( - pooling=self.bert.pooler, classifier=self.classifier, act_fn="score" - ), - } + self.pooler = DispatchPooler.for_seq_cls( + pooler_config, + pooling=self.bert.pooler, + classifier=self.classifier, ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -891,13 +868,7 @@ class BertForTokenClassification(nn.Module): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - { - "token_classify": Pooler.for_token_classify( - pooler_config=pooler_config - ), - } - ) + self.pooler = pooler_for_token_classify(pooler_config) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.bert.embed_input_ids(input_ids) diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 131cb6891..14794fd6a 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -24,6 +24,7 @@ from vllm.model_executor.layers.linear import ( ReplicatedLinear, RowParallelLinear, ) +from vllm.model_executor.layers.pooler import DispatchPooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding @@ -37,7 +38,6 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from ..layers.pooler import ClassifierPooler, DispatchPooler, Pooler from .bert import BertPooler from .interfaces import SupportsCrossEncoding, SupportsQuant from .interfaces_base import default_pooling_type @@ -693,20 +693,10 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - { - "token_classify": Pooler.for_token_classify( - pooler_config, classifier=self.classifier - ), - "classify": ClassifierPooler( - pooling=self.new.pooler, - classifier=self.classifier, - act_fn="classify", - ), - "score": ClassifierPooler( - pooling=self.new.pooler, classifier=self.classifier, act_fn="score" - ), - } + self.pooler = DispatchPooler.for_seq_cls( + pooler_config, + pooling=self.new.pooler, + classifier=self.classifier, ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 6ec700a1c..1eae71f3a 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -26,7 +26,7 @@ from vllm.model_executor.layers.linear import ( QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.pooler import DispatchPooler, Pooler +from vllm.model_executor.layers.pooler import DispatchPooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -880,12 +880,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): assert pooler_config is not None self.pooler_config = pooler_config - self.pooler = DispatchPooler( - { - "token_embed": Pooler.for_token_embed(pooler_config), - "embed": Pooler.for_embed(pooler_config), - } - ) + self.pooler = DispatchPooler.for_embedding(pooler_config) # Assumes that self.forward is called after self.embed_input_ids self._is_text_input = True diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index da5d48a94..bacf30d12 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -41,6 +41,7 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.pooler import DispatchPooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -49,7 +50,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors -from ..layers.pooler import DispatchPooler, Pooler from .interfaces import SupportsCrossEncoding, SupportsPP from .utils import ( AutoWeightsLoader, @@ -351,19 +351,7 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - { - "token_classify": Pooler.for_token_classify( - pooler_config, classifier=self.score - ), - "classify": Pooler.for_classify( - pooler_config, classifier=self.score, act_fn="classify" - ), - "score": Pooler.for_classify( - pooler_config, classifier=self.score, act_fn="score" - ), - } - ) + self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.transformer.embed_input_ids(input_ids) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 5bd731e6e..34dbd8050 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -9,17 +9,19 @@ from vllm.config import ModelConfig, VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.pooler import ( DispatchPooler, - Pooler, - PoolerNormalize, - PoolingMethod, PoolingParamsUpdate, - TokenPoolerHeadOutput, - TokenPoolingMethodOutput, ) +from vllm.model_executor.layers.pooler.activations import PoolerNormalize +from vllm.model_executor.layers.pooler.seqwise import ( + SequencePooler, + SequencePoolerHeadOutput, + SequencePoolingMethod, + SequencePoolingMethodOutput, +) +from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.tasks import PoolingTask from vllm.tokenizers import cached_tokenizer_from_config -from vllm.v1.outputs import TokenPoolerOutput from vllm.v1.pool.metadata import PoolingMetadata from .interfaces_base import default_pooling_type @@ -27,7 +29,7 @@ from .interfaces_base import default_pooling_type logger = init_logger(__name__) -class GritLMMeanPool(PoolingMethod): +class GritLMMeanPool(SequencePoolingMethod): """As `MeanPool`, but only includes non-instruction tokens.""" def __init__(self, model_config: ModelConfig): @@ -151,7 +153,7 @@ class GritLMMeanPool(PoolingMethod): self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> TokenPoolingMethodOutput: + ) -> SequencePoolingMethodOutput: prompt_lens = pooling_metadata.prompt_lens instr_lens = torch.tensor( [ @@ -174,35 +176,22 @@ class GritLMMeanPool(PoolingMethod): return pooled_data -class GritLMPooler(Pooler): +class GritLMPooler(SequencePooler): def __init__(self, model_config: ModelConfig): - super().__init__() + super().__init__( + pooling=GritLMMeanPool(model_config), + head=self.head, + ) - self.pooling = GritLMMeanPool(model_config) self.activation = PoolerNormalize() - def get_supported_tasks(self) -> Set[PoolingTask]: - return self.pooling.get_supported_tasks() - - def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: - return self.pooling.get_pooling_updates(task) - def head( self, - pooled_data: TokenPoolingMethodOutput, + pooled_data: SequencePoolingMethodOutput, pooling_metadata: PoolingMetadata, - ) -> TokenPoolerHeadOutput: + ) -> SequencePoolerHeadOutput: return self.activation(pooled_data) - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> TokenPoolerOutput: - pooled_data = self.pooling(hidden_states, pooling_metadata) - pooled_data = self.head(pooled_data, pooling_metadata) - return pooled_data - @default_pooling_type("MEAN") class GritLM(LlamaForCausalLM): @@ -245,7 +234,7 @@ class GritLM(LlamaForCausalLM): if pooler_config is not None: self.pooler = DispatchPooler( { - "token_embed": Pooler.for_token_embed(pooler_config), + "token_embed": pooler_for_token_embed(pooler_config), "embed": GritLMPooler(vllm_config.model_config), } ) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 3ca886461..37309cd09 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -28,7 +28,7 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import DispatchPooler, Pooler +from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -434,9 +434,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - {"token_classify": Pooler.for_token_classify(pooler_config)} - ) + self.pooler = pooler_for_token_classify(pooler_config) def forward( self, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 39e1226b5..91b58a83e 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator, ) -from vllm.model_executor.layers.pooler import DispatchPooler, Pooler +from vllm.model_executor.layers.pooler import DispatchPooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -596,16 +596,4 @@ class JambaForSequenceClassification(JambaForCausalLM): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - { - "token_classify": Pooler.for_token_classify( - pooler_config, classifier=self.score - ), - "classify": Pooler.for_classify( - pooler_config, classifier=self.score, act_fn="classify" - ), - "score": Pooler.for_classify( - pooler_config, classifier=self.score, act_fn="score" - ), - } - ) + self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score) diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py index 7be3d4778..c03fa211a 100644 --- a/vllm/model_executor/models/jina_vl.py +++ b/vllm/model_executor/models/jina_vl.py @@ -10,7 +10,7 @@ from vllm.config import ModelConfig, VllmConfig from vllm.inputs import TokensPrompt from vllm.logger import init_logger from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear -from vllm.model_executor.layers.pooler import DispatchPooler, Pooler +from vllm.model_executor.layers.pooler import DispatchPooler from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors @@ -105,19 +105,7 @@ class JinaVLForSequenceClassification( self.score = JinaVLScorer( vllm_config.model_config, prefix=maybe_prefix(prefix, "score") ) - self.pooler = DispatchPooler( - { - "token_classify": Pooler.for_token_classify( - pooler_config, classifier=self.score - ), - "classify": Pooler.for_classify( - pooler_config, classifier=self.score, act_fn="classify" - ), - "score": Pooler.for_classify( - pooler_config, classifier=self.score, act_fn="score" - ), - } - ) + self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 580dbb830..d72b4800c 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable, Set +from collections.abc import Iterable import torch from torch import nn @@ -12,21 +12,18 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear -from vllm.model_executor.layers.pooler import ( - ClassifierPooler, - DispatchPooler, - Pooler, - PoolingMethod, - PoolingParamsUpdate, - TokenPoolerHeadOutput, - TokenPoolingMethodOutput, +from vllm.model_executor.layers.pooler import DispatchPooler +from vllm.model_executor.layers.pooler.seqwise import ( + SequencePooler, + SequencePoolerHeadOutput, + SequencePoolingMethodOutput, + get_seq_pooling_method, ) +from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.sequence import IntermediateTensors -from vllm.tasks import PoolingTask -from vllm.v1.outputs import TokenPoolerOutput from vllm.v1.pool.metadata import PoolingMetadata from .interfaces import SupportsCrossEncoding @@ -282,12 +279,13 @@ class ModernBertModel(nn.Module): return norm_outputs -class ModernBertPooler(Pooler): +class ModernBertPooler(SequencePooler): def __init__(self, config: ModernBertConfig): - super().__init__() + super().__init__( + pooling=get_seq_pooling_method(config.classifier_pooling.upper()), + head=self.head, + ) - pooling_type = config.classifier_pooling.upper() - self.pooling = PoolingMethod.from_pooling_type(pooling_type) self.dense = nn.Linear( config.hidden_size, config.hidden_size, config.classifier_bias ) @@ -296,32 +294,17 @@ class ModernBertPooler(Pooler): config.hidden_size, eps=config.norm_eps, bias=config.norm_bias ) - def get_supported_tasks(self) -> Set[PoolingTask]: - return self.pooling.get_supported_tasks() - - def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: - return self.pooling.get_pooling_updates(task) - def head( self, - pooled_data: TokenPoolingMethodOutput, + pooled_data: SequencePoolingMethodOutput, pooling_metadata: PoolingMetadata, - ) -> TokenPoolerHeadOutput: + ) -> SequencePoolerHeadOutput: if isinstance(pooled_data, list): pooled_data = torch.stack(pooled_data) pooled_data = pooled_data.to(self.dense.weight.dtype) return self.norm(self.act(self.dense(pooled_data))) - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> TokenPoolerOutput: - pooled_data = self.pooling(hidden_states, pooling_metadata) - pooled_data = self.head(pooled_data, pooling_metadata) - return pooled_data - @default_pooling_type("CLS") class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): @@ -344,18 +327,10 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - { - "token_classify": Pooler.for_token_classify( - pooler_config, classifier=self.classifier - ), - "classify": ClassifierPooler( - pooling=self.pooling, classifier=self.classifier, act_fn="classify" - ), - "score": ClassifierPooler( - pooling=self.pooling, classifier=self.classifier, act_fn="score" - ), - } + self.pooler = DispatchPooler.for_seq_cls( + pooler_config, + pooling=self.pooling, + classifier=self.classifier, ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: @@ -438,13 +413,7 @@ class ModernBertForTokenClassification(nn.Module): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - { - "token_classify": Pooler.for_token_classify( - pooler_config=pooler_config - ), - } - ) + self.pooler = pooler_for_token_classify(pooler_config) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index eac46e0f8..963edcb75 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -14,7 +14,8 @@ from torch import nn from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear -from vllm.model_executor.layers.pooler import DispatchPooler, Pooler +from vllm.model_executor.layers.pooler import Pooler +from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -104,9 +105,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - {"token_classify": Pooler.for_token_classify(pooler_config)} - ) + self.pooler = pooler_for_token_classify(pooler_config) @default_pooling_type("STEP") @@ -118,6 +117,4 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - {"token_classify": Pooler.for_token_classify(pooler_config)} - ) + self.pooler = pooler_for_token_classify(pooler_config) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 45b6e9330..647fb70ef 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -8,12 +8,8 @@ from torch import nn from transformers import RobertaConfig from vllm.config import ModelConfig, VllmConfig -from vllm.model_executor.layers.pooler import ( - ClassifierPooler, - CLSPool, - DispatchPooler, - Pooler, -) +from vllm.model_executor.layers.pooler import DispatchPooler +from vllm.model_executor.layers.pooler.seqwise import CLSPool from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.models.bert import ( TOKEN_TYPE_SHIFT, @@ -196,18 +192,10 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - { - "token_classify": Pooler.for_token_classify( - pooler_config=pooler_config, classifier=self.classifier - ), - "classify": ClassifierPooler( - pooling=CLSPool(), classifier=self.classifier, act_fn="classify" - ), - "score": ClassifierPooler( - pooling=CLSPool(), classifier=self.classifier, act_fn="score" - ), - } + self.pooler = DispatchPooler.for_seq_cls( + pooler_config, + pooling=CLSPool(), + classifier=self.classifier, ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 85772c11a..e39ae4340 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -27,7 +27,7 @@ from vllm.model_executor.layers.linear import ( QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.pooler import DispatchPooler, Pooler +from vllm.model_executor.layers.pooler import DispatchPooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import ( @@ -1050,12 +1050,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant): assert pooler_config is not None self.pooler_config = pooler_config - self.pooler = DispatchPooler( - { - "token_embed": Pooler.for_token_embed(pooler_config), - "embed": Pooler.for_embed(pooler_config), - } - ) + self.pooler = DispatchPooler.for_embedding(pooler_config) self._is_text_input = True diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index 402081a70..c97af0db5 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -34,7 +34,7 @@ from transformers import BatchFeature from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import DispatchPooler, DummyPooler +from vllm.model_executor.layers.pooler import IdentityPooler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.multimodal import MULTIMODAL_REGISTRY @@ -248,7 +248,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler({"plugin": DummyPooler()}) + self.pooler = IdentityPooler() def embed_input_ids( self, diff --git a/vllm/model_executor/models/transformers/pooling.py b/vllm/model_executor/models/transformers/pooling.py index 4c2a74bcc..470ca48ee 100644 --- a/vllm/model_executor/models/transformers/pooling.py +++ b/vllm/model_executor/models/transformers/pooling.py @@ -22,12 +22,8 @@ import torch from transformers import AutoModelForSequenceClassification from vllm.config.utils import getattr_iter -from vllm.model_executor.layers.pooler import ( - ClassifierPooler, - CLSPool, - DispatchPooler, - Pooler, -) +from vllm.model_executor.layers.pooler import DispatchPooler +from vllm.model_executor.layers.pooler.seqwise import CLSPool from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.model_executor.models.interfaces_base import VllmModelForPooling @@ -47,12 +43,7 @@ class EmbeddingMixin(VllmModelForPooling): pooler_config = vllm_config.model_config.pooler_config assert pooler_config is not None - self.pooler = DispatchPooler( - { - "token_embed": Pooler.for_token_embed(pooler_config), - "embed": Pooler.for_embed(pooler_config), - } - ) + self.pooler = DispatchPooler.for_embedding(pooler_config) class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling): @@ -104,16 +95,8 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling): self.classifier.__class__ = ClassifierWithReshape - self.pooler = DispatchPooler( - { - "token_classify": Pooler.for_token_classify( - pooler_config, classifier=self.classifier - ), - "classify": ClassifierPooler( - pooling=CLSPool(), classifier=self.classifier, act_fn="classify" - ), - "score": ClassifierPooler( - pooling=CLSPool(), classifier=self.classifier, act_fn="score" - ), - } + self.pooler = DispatchPooler.for_seq_cls( + pooler_config, + pooling=CLSPool(), + classifier=self.classifier, ) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 60643044c..b95ba3ad7 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -91,9 +91,7 @@ class LogprobsTensors(NamedTuple): # [num_reqs, ] # The shape of each element depends on the pooler used -TokenPoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor] -TokenwisePoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None] -PoolerOutput: TypeAlias = TokenPoolerOutput | TokenwisePoolerOutput +PoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor] | list[torch.Tensor | None] @dataclass