[Model] Reorganize pooling layers (#31973)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
2
.github/CODEOWNERS
vendored
2
.github/CODEOWNERS
vendored
@@ -153,7 +153,7 @@ mkdocs.yaml @hmellor
|
||||
/vllm/entrypoints/pooling @noooop
|
||||
/vllm/config/pooler.py @noooop
|
||||
/vllm/pooling_params.py @noooop
|
||||
/vllm/model_executor/layers/pooler.py @noooop
|
||||
/vllm/model_executor/layers/pooler @noooop
|
||||
|
||||
# Security guide and policies
|
||||
/docs/usage/security.md @russellb
|
||||
|
||||
@@ -5,7 +5,8 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.pooler import CLSPool, DispatchPooler, MeanPool
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.pooler.seqwise import CLSPool, MeanPool
|
||||
from vllm.model_executor.models.bert import BertEmbeddingModel
|
||||
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.models.gemma2 import Gemma2Model
|
||||
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -28,12 +28,7 @@ class MyGemma2Embedding(nn.Module):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
}
|
||||
)
|
||||
self.pooler = DispatchPooler.for_embedding(pooler_config)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors
|
||||
|
||||
@@ -88,6 +88,10 @@ class PoolerConfig:
|
||||
# raise deprecated warning for softmax and activation
|
||||
self.use_activation = get_use_activation(self)
|
||||
|
||||
def get_pooling_type(self) -> PoolingTypeStr:
|
||||
assert self.pooling_type is not None, "Should be resolved by ModelConfig"
|
||||
return self.pooling_type
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
|
||||
@@ -1,845 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Mapping, Set
|
||||
from dataclasses import dataclass
|
||||
from itertools import groupby
|
||||
from typing import TypeAlias, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig, get_current_vllm_config
|
||||
from vllm.config.pooler import PoolerConfig, PoolingTypeStr
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.adapters import _load_st_projector
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.v1.outputs import PoolerOutput, TokenPoolerOutput, TokenwisePoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Callable signature shared by all pooling steps: packed hidden states
# (one tensor, or a list of per-request tensors) plus metadata in,
# pooled data out.
PoolingFn = Callable[
    [torch.Tensor | list[torch.Tensor], PoolingMetadata],
    torch.Tensor | list[torch.Tensor],
]
# Maps pooled hidden states to classification logits.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]


# Output of token-level pooling methods: one vector per request.
TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
# Output of token-wise pooling methods; a `None` entry marks a request
# whose chunked prefill has not finished yet.
TokenwisePoolingMethodOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
TokenwisePoolingMethodOutputItem: TypeAlias = torch.Tensor | None
PoolingMethodOutput: TypeAlias = TokenPoolingMethodOutput | TokenwisePoolingMethodOutput

TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokenwisePoolerHeadOutput: TypeAlias = torch.Tensor | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Pooling configuration with the pooling type fully resolved for a task."""

    pooling_type: PoolingTypeStr
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Build from a `PoolerConfig` whose `pooling_type` is already set."""
        pooling_type = pooler_config.pooling_type
        # The pooling type is resolved earlier (by ModelConfig) — it must
        # never be None by the time a pooler is constructed.
        assert pooling_type is not None
        return cls(task=task, pooling_type=pooling_type)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PoolingParamsUpdate:
    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def apply(self, params: PoolingParams) -> None:
        """Copy this update onto *params* in place."""
        params.requires_token_ids = self.requires_token_ids
|
||||
|
||||
|
||||
def get_classification_activation_function(config: PretrainedConfig):
    """Pick the activation matching the HF `problem_type` convention.

    Implements alignment with transformers ForSequenceClassificationLoss:
    https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
    """
    activation_by_problem_type: dict[str, type[PoolerActivation]] = {
        "regression": PoolerIdentity,
        "single_label_classification": PoolerClassify,
        "multi_label_classification": PoolerMultiLabelClassify,
    }

    problem_type = getattr(config, "problem_type", "")
    # Unknown or unset problem types fall back to standard classification.
    activation_cls = activation_by_problem_type.get(problem_type, PoolerClassify)
    return activation_cls()
|
||||
|
||||
|
||||
def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Resolve the activation used by a Sentence-Transformers cross-encoder.

    The activation's qualified name may come from the `sentence_transformers`
    mapping or from the legacy `sbert_ce_default_activation_function`
    attribute; when neither is present, fall back to `PoolerClassify`.
    """
    function_name: str | None = None
    if hasattr(config, "sentence_transformers"):
        st_extras = config.sentence_transformers
        if "activation_fn" in st_extras:
            function_name = st_extras["activation_fn"]
    if function_name is None and hasattr(
        config, "sbert_ce_default_activation_function"
    ):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerClassify()

    # Only torch.nn.modules activations may be loaded dynamically.
    assert function_name.startswith("torch.nn.modules."), (
        "Loading of activation functions is restricted to "
        "torch.nn.modules for security reasons"
    )
    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
|
||||
|
||||
|
||||
class PoolingMethod(nn.Module, ABC):
    """Base class for strategies that reduce hidden states per request."""

    @staticmethod
    def from_pooling_type(pooling_type: PoolingTypeStr) -> "PoolingMethod":
        """Instantiate the pooling method named by *pooling_type*."""
        if pooling_type == "STEP":
            # STEP is built on AllPool by StepPooler rather than being a
            # standalone method.
            raise ValueError(
                "'STEP' pooling is handled by StepPooler "
                "and is not a standalone PoolingMethod."
            )

        method_by_type: dict[str, type[PoolingMethod]] = {
            "LAST": LastPool,
            "ALL": AllPool,
            "CLS": CLSPool,
            "MEAN": MeanPool,
        }
        method_cls = method_by_type.get(pooling_type)
        if method_cls is None:
            raise NotImplementedError(f"Unsupported method: {pooling_type!r}")
        return method_cls()

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this method can serve."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Parameter updates required for *task*; none by default."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolingMethodOutput:
        raise NotImplementedError
|
||||
|
||||
|
||||
class CLSPool(PoolingMethod):
    """Pool by taking the hidden state of each request's first (CLS) token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        cursor = pooling_metadata.get_pooling_cursor()
        # The CLS position is only meaningful once the whole prompt has
        # been scheduled.
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        return hidden_states[cursor.first_token_indices_gpu]
|
||||
|
||||
|
||||
class LastPool(PoolingMethod):
    """Pool by taking the hidden state of each request's last token."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        cursor = pooling_metadata.get_pooling_cursor()
        return hidden_states[cursor.last_token_indices_gpu]
|
||||
|
||||
|
||||
class AllPool(PoolingMethod):
    """Return every token's hidden state for each request.

    When chunked prefill is enabled, per-chunk states are accumulated in
    `pooling_states.hidden_states_cache` and only emitted once a request's
    prefill finishes; unfinished requests yield `None`.
    """

    def __init__(self):
        super().__init__()

        scheduler_config = get_current_vllm_config().scheduler_config
        self.enable_chunked_prefill = scheduler_config.enable_chunked_prefill

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenwisePoolingMethodOutput:
        cursor = pooling_metadata.get_pooling_cursor()
        # Split the packed hidden states into per-request chunks, then
        # reorder them to match the cursor's request ordering.
        chunks = hidden_states.split(cursor.num_scheduled_tokens_cpu.tolist())
        per_request = [chunks[i] for i in cursor.index]

        if not self.enable_chunked_prefill:
            return per_request

        pooling_states = pooling_metadata.pooling_states

        # With chunked prefill:
        # 1. stash this step's chunk in each request's hidden_states_cache;
        for state, chunk in zip(pooling_states, per_request):
            state.hidden_states_cache.append(chunk)

        # 2. once a request's prefill has finished, hand the full cache to
        #    the PoolerHead; otherwise emit None as a placeholder.
        outputs = list[torch.Tensor | None]()
        for state, finished in zip(pooling_states, cursor.is_finished()):
            if not finished:
                outputs.append(None)
                continue

            cache = state.hidden_states_cache
            if len(cache) == 1:
                outputs.append(cache[0])
            else:
                outputs.append(torch.concat(cache, dim=0))
            state.clean()

        return outputs
|
||||
|
||||
|
||||
class MeanPool(PoolingMethod):
    """Pool by averaging hidden states over each request's prompt tokens."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        cursor = pooling_metadata.get_pooling_cursor()
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        prompt_lens = cursor.prompt_lens_cpu.to(
            hidden_states.device, non_blocking=True
        )

        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        first = cursor.first_token_indices_gpu
        last = cursor.last_token_indices_gpu
        # Per-request sum via cumsum differences, inclusive of both ends
        # (hence adding back the first token's own state).
        totals = cumsum[last] - cumsum[first] + hidden_states[first]
        return totals / prompt_lens.unsqueeze(1)
|
||||
|
||||
|
||||
# Type preserved through activations: a batch tensor or a list of tensors.
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])


class BasePoolerActivation(nn.Module, ABC):
    """Base interface for activations applied to pooled outputs."""

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        # (batch_size, dimensions) or list(dimensions) if using MRL
        raise NotImplementedError
|
||||
|
||||
|
||||
class PoolerActivation(BasePoolerActivation):
    """Activation that maps `forward_chunk` elementwise over list inputs."""

    @staticmethod
    def wraps(module: nn.Module):
        """Wrap a torch activation module in the matching pooler activation."""
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()

        # Anything else is applied as an opaque callable.
        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        if not isinstance(pooled_data, list):
            return self.forward_chunk(pooled_data)

        return [self.forward_chunk(item) for item in pooled_data]
|
||||
|
||||
|
||||
class PoolerIdentity(PoolerActivation):
    """No-op activation (used e.g. for regression outputs)."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return pooled_data
|
||||
|
||||
|
||||
class PoolerNormalize(PoolerActivation):
    """L2-normalize each pooled vector along its last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.normalize(pooled_data, p=2, dim=-1)
|
||||
|
||||
|
||||
class PoolerMultiLabelClassify(PoolerActivation):
    """Sigmoid activation for multi-label classification outputs."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.sigmoid(pooled_data)
|
||||
|
||||
|
||||
class PoolerClassify(PoolerActivation):
    """Sigmoid for binary, softmax for multi-class classification outputs.

    When `static_num_labels` is True, the label count is read once from the
    model's HF config; otherwise it is inferred per call from the last
    dimension of the input.
    """

    def __init__(self, *, static_num_labels: bool = True) -> None:
        super().__init__()

        if static_num_labels:
            vllm_config = get_current_vllm_config()
            self.num_labels = getattr(
                vllm_config.model_config.hf_config, "num_labels", 0
            )
            if self.num_labels == 0:
                # Fix: the original message concatenated
                # "...classification" + "models..." without a space.
                logger.warning(
                    "num_labels should be > 0 for classification "
                    "models, falling back to softmax. "
                    "Please check if the configuration is correct."
                )
        else:
            # Resolved lazily from the data in forward_chunk.
            self.num_labels = None

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        num_labels = (
            self.num_labels if self.num_labels is not None else pooled_data.shape[-1]
        )

        # A single logit is treated as binary (sigmoid); otherwise softmax
        # over the label dimension.
        if num_labels < 2:
            return F.sigmoid(pooled_data)

        return F.softmax(pooled_data, dim=-1)
|
||||
|
||||
|
||||
class LambdaPoolerActivation(PoolerActivation):
    """Adapter that applies an arbitrary tensor->tensor callable per chunk."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()
        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return self.fn(pooled_data)
|
||||
|
||||
|
||||
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def for_token_embed(pooler_config: PoolerConfig):
        """Build a pooler that emits per-token embeddings."""
        head = TokenEmbeddingPoolerHead()

        # STEP pooling filters tokens by step tags; everything else keeps
        # all tokens.
        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_token_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build a pooler that emits per-token classification scores."""
        head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)

        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Build a pooler that emits one embedding per request."""
        resolved = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )

        return SimplePooler(
            pooling=PoolingMethod.from_pooling_type(resolved.pooling_type),
            head=EmbeddingPoolerHead(),
        )

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Build a pooler that emits one score vector per request."""
        resolved = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )

        return ClassifierPooler(
            pooling=PoolingMethod.from_pooling_type(resolved.pooling_type),
            classifier=classifier,
            act_fn=act_fn,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        raise NotImplementedError
|
||||
|
||||
|
||||
class DummyPooler(Pooler):
    """Pass-through pooler: returns the hidden states unchanged."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"plugin", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        return hidden_states
|
||||
|
||||
|
||||
class TokenPoolerHead(nn.Module, ABC):
    """Applicable to pooling strategies that output one token."""

    @abstractmethod
    def forward(
        self,
        pooled_data: TokenPoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        raise NotImplementedError
|
||||
|
||||
|
||||
class EmbeddingPoolerHead(TokenPoolerHead):
    """Head that projects, truncates (matryoshka) and normalizes embeddings."""

    def __init__(self) -> None:
        super().__init__()

        # Load ST projector if available
        vllm_config = get_current_vllm_config()
        self.projector = (
            _load_st_projector(vllm_config.model_config) if vllm_config else None
        )
        self.head_dtype = vllm_config.model_config.head_dtype

        self.activation = PoolerNormalize()

    def forward(
        self,
        pooled_data: TokenPoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_dimension]

        pooled_data = pooled_data.to(self.head_dtype)

        # Apply the Sentence-Transformers projector when present.
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
            # pooled_data shape: [batchsize, embedding_dimension]

        pooling_params = pooling_metadata.pooling_params

        # Matryoshka representation: optionally truncate each embedding to
        # the per-request requested dimension.
        dims = [param.dimensions for param in pooling_params]
        if any(dim is not None for dim in dims):
            assert len(pooled_data) == len(dims)
            if len(set(dims)) == 1 and not isinstance(pooled_data, list):
                # All requests share one dimension: slice the batch at once.
                pooled_data = pooled_data[..., : dims[0]]
            else:
                pooled_data = [
                    vec if dim is None else vec[..., :dim]
                    for vec, dim in zip(pooled_data, dims)
                ]

        # Normalization, honoring each request's flag.
        norm_flags = [param.normalize for param in pooling_params]
        if len(set(norm_flags)) == 1:
            if norm_flags[0]:
                pooled_data = self.activation(pooled_data)
        else:
            pooled_data = [
                self.activation(vec) if flag else vec
                for vec, flag in zip(pooled_data, norm_flags)
            ]

        # pooled_data shape: [batchsize, embedding_dimension]
        return pooled_data
|
||||
|
||||
|
||||
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Normalizes output if specified.
    3. Returns structured results as `PoolerOutput`.
    """

    def __init__(self, pooling: PoolingMethod, head: TokenPoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Delegated to the underlying pooling method.
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        # Pool first, then post-process with the head.
        return self.head(
            self.pooling(hidden_states, pooling_metadata),
            pooling_metadata,
        )
|
||||
|
||||
|
||||
class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Applies a classification layer to the hidden states.
    2. Optionally applies a pooler layer.
    3. Applies an activation function to the output.
    """

    @staticmethod
    def act_fn_for_seq_cls(model_config: ModelConfig):
        """Activation aligned with HF sequence-classification semantics."""
        return get_classification_activation_function(model_config.hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(model_config: ModelConfig):
        """Activation aligned with Sentence-Transformers cross-encoders."""
        return get_cross_encoder_activation_function(model_config.hf_config)

    @staticmethod
    def resolve_act_fn(
        model_config: ModelConfig,
        static_num_labels: bool = True,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Resolve *act_fn* (name, callable or None) to an activation."""
        if act_fn is None:
            return PoolerClassify(static_num_labels=static_num_labels)

        if isinstance(act_fn, str):
            if act_fn == "classify":
                return ClassifierPooler.act_fn_for_seq_cls(model_config)
            if act_fn == "score":
                return ClassifierPooler.act_fn_for_cross_encoder(model_config)
            raise ValueError(f"act_fn [{act_fn=}] not supported.")

        assert callable(act_fn)
        return act_fn

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        self.pooling = pooling
        self.classifier = classifier
        self.act_fn = self.resolve_act_fn(
            vllm_config.model_config, static_num_labels=True, act_fn=act_fn
        )
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            pooled_data -= self.logit_bias

        flags = [param.use_activation for param in pooling_metadata.pooling_params]

        if len(set(flags)) == 1:
            # Homogeneous batch: apply (or skip) the activation in one call.
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vec) if flag else vec
                for vec, flag in zip(pooled_data, flags)
            ]

        # scores shape: [batchsize, num_labels]
        return scores
|
||||
|
||||
|
||||
class TokenwisePoolerHead(nn.Module, ABC):
    """Applicable to pooling strategies that output multiple tokens."""

    @abstractmethod
    def forward(
        self,
        pooled_data: TokenwisePoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenwisePoolerHeadOutput:
        raise NotImplementedError
|
||||
|
||||
|
||||
class TokenEmbeddingPoolerHead(TokenwisePoolerHead):
    """Per-token variant of `EmbeddingPoolerHead` (project, truncate, normalize)."""

    def __init__(self) -> None:
        super().__init__()

        # Load ST projector if available
        vllm_config = get_current_vllm_config()
        self.projector = (
            _load_st_projector(vllm_config.model_config) if vllm_config else None
        )
        self.head_dtype = vllm_config.model_config.head_dtype

        self.activation = PoolerNormalize()

    def forward(
        self,
        pooled_data: TokenwisePoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenwisePoolerHeadOutput:
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        pooled_data = pooled_data.to(self.head_dtype)
        # pooled_data shape: [n_tokens, hidden_dimension]

        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
            # pooled_data shape: [n_tokens, embedding_dimension]

        # Matryoshka representation: slicing with None keeps all dims.
        pooled_data = pooled_data[..., : pooling_param.dimensions]

        if pooling_param.normalize:
            pooled_data = self.activation(pooled_data)

        # pooled_data shape: [n_tokens, embedding_dimension]
        return pooled_data
|
||||
|
||||
|
||||
class TokenClassifierPoolerHead(TokenwisePoolerHead):
    """Per-token classification head: classifier, logit bias, activation."""

    def __init__(
        self,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()

        self.classifier = classifier
        self.logit_bias: float | None = (
            vllm_config.model_config.pooler_config.logit_bias
        )
        self.head_dtype = vllm_config.model_config.head_dtype

        # num_labels is inferred per call here, not from the static config.
        self.activation = ClassifierPooler.resolve_act_fn(
            vllm_config.model_config, static_num_labels=False, act_fn=act_fn
        )

    def forward(
        self,
        pooled_data: TokenwisePoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenwisePoolerHeadOutput:
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        pooled_data = pooled_data.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]

        scores = (
            self.classifier(pooled_data)
            if self.classifier is not None
            else pooled_data
        )
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            scores -= self.logit_bias

        if pooling_param.use_activation:
            scores = self.activation(scores)

        # scores shape: [n_token, num_labels]
        return scores
|
||||
|
||||
|
||||
class AllPooler(Pooler):
    """Token-wise pooler: keeps all tokens, applies the head per request."""

    def __init__(self, head: TokenwisePoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenwisePoolerOutput:
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        # One head call per request, honoring that request's parameters.
        return [
            self.head(data, params)
            for data, params in zip(pooled_data, pooling_params)
        ]
|
||||
|
||||
|
||||
class StepPooler(Pooler):
    """Token-wise pooler that keeps only step-tag positions per request."""

    def __init__(self, head: TokenwisePoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def extract_states(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor | None]:
        """Filter each request's token states by its step-pooling params."""
        all_states = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = pooling_metadata.get_prompt_token_ids()
        pooling_params = pooling_metadata.pooling_params

        extracted = list[torch.Tensor | None]()
        for states, token_ids, params in zip(
            all_states, prompt_token_ids, pooling_params
        ):
            # for unfinished chunked prefill
            if states is None:
                extracted.append(None)
                continue

            returned_token_ids = params.returned_token_ids
            if returned_token_ids is not None and len(returned_token_ids) > 0:
                # Restrict logits to the requested vocabulary subset.
                states = states[:, returned_token_ids]

            if params.step_tag_id is not None:
                # Keep only positions whose prompt token is the step tag.
                states = states[token_ids == params.step_tag_id]
            extracted.append(states)

        return extracted

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed", "token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        # Prompt token ids are needed to locate step-tag positions.
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenwisePoolerOutput:
        pooled_data = self.extract_states(hidden_states, pooling_metadata)
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        return [
            self.head(data, params)
            for data, params in zip(pooled_data, pooling_params)
        ]
|
||||
|
||||
|
||||
class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        super().__init__()

        # Validate up front that every registered pooler supports its task.
        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}"
                )

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        outputs = list[torch.Tensor | None]()

        # Requests arrive grouped by task; dispatch each contiguous run to
        # its pooler with a metadata slice covering just that run.
        offset = 0
        for task, group in groupby(pooling_metadata.tasks):
            pooler = self.poolers_by_task.get(task)
            if not pooler:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )

            num_items = len(list(group))
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[offset : offset + num_items],
            )

            outputs.extend(group_output)
            offset += num_items

        return outputs

    def extra_repr(self) -> str:
        return f"supported_task={self.get_supported_tasks()}"
|
||||
5
vllm/model_executor/layers/pooler/__init__.py
Normal file
5
vllm/model_executor/layers/pooler/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from .abstract import *
|
||||
from .common import *
|
||||
from .special import *
|
||||
39
vllm/model_executor/layers/pooler/abstract.py
Normal file
39
vllm/model_executor/layers/pooler/abstract.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Set
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.outputs import PoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .common import PoolingParamsUpdate
|
||||
|
||||
|
||||
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        # Default: no parameter changes are required for any task.
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` into per-request outputs."""
        raise NotImplementedError
|
||||
|
||||
|
||||
__all__ = ["Pooler"]
|
||||
162
vllm/model_executor/layers/pooler/activations.py
Normal file
162
vllm/model_executor/layers/pooler/activations.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_classification_act_fn(
    config: PretrainedConfig,
) -> "PoolerActivation":
    """Select the activation matching the HF config's ``problem_type``."""
    # Implement alignment with transformers ForSequenceClassificationLoss
    # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
    problem_type = getattr(config, "problem_type", "")
    if problem_type == "regression":
        return PoolerIdentity()
    if problem_type == "single_label_classification":
        return PoolerClassify()
    if problem_type == "multi_label_classification":
        return PoolerMultiLabelClassify()

    # Unset or unrecognized problem_type: fall back to the
    # single-label classification activation.
    return PoolerClassify()
|
||||
|
||||
|
||||
def get_cross_encoder_act_fn(
    config: PretrainedConfig,
) -> "PoolerActivation":
    """Resolve the activation declared by a (Sentence-Transformers-style)
    cross-encoder config, falling back to classification activation."""
    function_name: str | None = None
    # Newer ST configs store the activation under `sentence_transformers`.
    if (
        hasattr(config, "sentence_transformers")
        and "activation_fn" in config.sentence_transformers
    ):
        function_name = config.sentence_transformers["activation_fn"]
    # Legacy sbert configs use a top-level default activation field.
    elif (
        hasattr(config, "sbert_ce_default_activation_function")
        and config.sbert_ce_default_activation_function is not None
    ):
        function_name = config.sbert_ce_default_activation_function

    if function_name is not None:
        # Security: only allow loading callables from torch.nn.modules,
        # since the qualname comes from an external model config.
        assert function_name.startswith("torch.nn.modules."), (
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons"
        )
        fn = resolve_obj_by_qualname(function_name)()
        return PoolerActivation.wraps(fn)

    return PoolerClassify()
|
||||
|
||||
|
||||
def resolve_classifier_act_fn(
    model_config: ModelConfig,
    static_num_labels: bool = True,
    act_fn: "PoolerActivation | str | None" = None,
):
    """Resolve ``act_fn`` into a concrete activation.

    ``act_fn`` may be the string "classify"/"score" (resolved from the HF
    config), an already-constructed callable (returned as-is), or None
    (defaults to ``PoolerClassify``).
    """
    if isinstance(act_fn, str):
        if act_fn == "classify":
            return get_classification_act_fn(model_config.hf_config)
        if act_fn == "score":
            return get_cross_encoder_act_fn(model_config.hf_config)

        raise ValueError(f"act_fn [{act_fn=}] not supported.")

    if act_fn is None:
        return PoolerClassify(static_num_labels=static_num_labels)

    # Any other value must already be a usable activation callable.
    assert callable(act_fn)
    return act_fn
|
||||
|
||||
|
||||
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])
|
||||
|
||||
|
||||
class PoolerActivation(nn.Module, ABC):
    """Base class for activations applied to pooled outputs.

    Subclasses implement ``forward_chunk`` for a single tensor;
    ``forward`` transparently maps it over a list of tensors.
    """

    @staticmethod
    def wraps(module: nn.Module):
        """Wrap a plain ``nn.Module`` activation in the pooler interface."""
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()

        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single pooled tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        # (batch_size, dimensions) or list(dimensions) if using MRL
        if isinstance(pooled_data, list):
            return [self.forward_chunk(data) for data in pooled_data]

        return self.forward_chunk(pooled_data)
|
||||
|
||||
|
||||
class PoolerIdentity(PoolerActivation):
    """No-op activation (used e.g. for regression outputs)."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return pooled_data
|
||||
|
||||
|
||||
class PoolerNormalize(PoolerActivation):
    """L2-normalize pooled vectors along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.normalize(pooled_data, p=2, dim=-1)
|
||||
|
||||
|
||||
class PoolerMultiLabelClassify(PoolerActivation):
    """Element-wise sigmoid for multi-label classification logits."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return F.sigmoid(pooled_data)
|
||||
|
||||
|
||||
class PoolerClassify(PoolerActivation):
    """Classification activation: sigmoid for binary (num_labels < 2),
    softmax over the last dimension otherwise.

    With ``static_num_labels=True`` the label count is read once from the
    current vLLM model config; otherwise it is inferred per call from the
    input's last dimension.
    """

    def __init__(self, *, static_num_labels: bool = True) -> None:
        super().__init__()

        if static_num_labels:
            vllm_config = get_current_vllm_config()
            model_config = vllm_config.model_config
            num_labels = getattr(model_config.hf_config, "num_labels", 0)
        else:
            # Defer to the input shape at forward time.
            num_labels = None

        if num_labels == 0:
            logger.warning(
                "num_labels should be > 0 for classification "
                "models, falling back to softmax. "
                "Please check if the configuration is correct."
            )

        # int when static, None when resolved per call.
        self.num_labels = num_labels

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        num_labels = self.num_labels
        if num_labels is None:
            num_labels = pooled_data.shape[-1]

        if num_labels < 2:
            return F.sigmoid(pooled_data)

        return F.softmax(pooled_data, dim=-1)
|
||||
|
||||
|
||||
class LambdaPoolerActivation(PoolerActivation):
    """Adapter wrapping an arbitrary tensor->tensor callable as an activation."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()

        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        return self.fn(pooled_data)
|
||||
27
vllm/model_executor/layers/pooler/common.py
Normal file
27
vllm/model_executor/layers/pooler/common.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.pooling_params import PoolingParams
|
||||
|
||||
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Immutable set of updates a pooler requests on `PoolingParams`."""

    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def __or__(self, other: "PoolingParamsUpdate") -> "PoolingParamsUpdate":
        # Combine updates from multiple poolers; a flag is set if any
        # contributor sets it.
        return PoolingParamsUpdate(
            requires_token_ids=self.requires_token_ids or other.requires_token_ids,
        )

    def apply(self, params: PoolingParams) -> None:
        """Write these updates onto ``params`` in place."""
        params.requires_token_ids = self.requires_token_ids
|
||||
|
||||
|
||||
__all__ = ["ClassifierFn", "PoolingParamsUpdate"]
|
||||
45
vllm/model_executor/layers/pooler/seqwise/__init__.py
Normal file
45
vllm/model_executor/layers/pooler/seqwise/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Poolers that produce an output aggregating all tokens in the sequence."""
|
||||
|
||||
from .heads import (
|
||||
ClassifierPoolerHead,
|
||||
EmbeddingPoolerHead,
|
||||
SequencePoolerHead,
|
||||
SequencePoolerHeadOutput,
|
||||
)
|
||||
from .methods import (
|
||||
CLSPool,
|
||||
LastPool,
|
||||
MeanPool,
|
||||
SequencePoolingMethod,
|
||||
SequencePoolingMethodOutput,
|
||||
get_seq_pooling_method,
|
||||
)
|
||||
from .poolers import (
|
||||
SequencePooler,
|
||||
SequencePoolerOutput,
|
||||
SequencePoolingFn,
|
||||
SequencePoolingHeadFn,
|
||||
pooler_for_classify,
|
||||
pooler_for_embed,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SequencePoolerHead",
|
||||
"SequencePoolerHeadOutput",
|
||||
"ClassifierPoolerHead",
|
||||
"EmbeddingPoolerHead",
|
||||
"SequencePoolingMethod",
|
||||
"SequencePoolingMethodOutput",
|
||||
"CLSPool",
|
||||
"LastPool",
|
||||
"MeanPool",
|
||||
"get_seq_pooling_method",
|
||||
"SequencePooler",
|
||||
"SequencePoolingFn",
|
||||
"SequencePoolingHeadFn",
|
||||
"SequencePoolerOutput",
|
||||
"pooler_for_classify",
|
||||
"pooler_for_embed",
|
||||
]
|
||||
157
vllm/model_executor/layers/pooler/seqwise/heads.py
Normal file
157
vllm/model_executor/layers/pooler/seqwise/heads.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Set
|
||||
from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.pooler import ClassifierFn
|
||||
from vllm.model_executor.layers.pooler.activations import (
|
||||
PoolerActivation,
|
||||
PoolerNormalize,
|
||||
resolve_classifier_act_fn,
|
||||
)
|
||||
from vllm.model_executor.models.adapters import _load_st_projector
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .methods import SequencePoolingMethodOutput
|
||||
|
||||
SequencePoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
|
||||
|
||||
class SequencePoolerHead(nn.Module, ABC):
    """Post-processing head applied after a sequence pooling method."""

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Tasks this head can serve."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        """Transform pooled per-sequence data into the final output."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class EmbeddingPoolerHead(SequencePoolerHead):
    """Head for the "embed" task: optional ST projection, matryoshka
    truncation, and optional L2 normalization per request."""

    def __init__(self) -> None:
        super().__init__()

        # Load ST projector if available
        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config

        self.projector = _load_st_projector(model_config)
        self.head_dtype = model_config.head_dtype

        self.activation = PoolerNormalize()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"embed"}

    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_dimension]

        pooled_data = pooled_data.to(self.head_dtype)

        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [batchsize, embedding_dimension]

        # for matryoshka representation
        dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params]
        if any(d is not None for d in dimensions_list):
            # change the output dimension
            assert len(pooled_data) == len(dimensions_list)
            if len(set(dimensions_list)) == 1 and not isinstance(pooled_data, list):
                # if all dimensions are the same
                d = dimensions_list[0]
                pooled_data = pooled_data[..., :d]
            else:
                # Mixed dimensions: fall back to a per-request list of
                # individually truncated vectors.
                pooled_data = [
                    vecs if d is None else vecs[..., :d]
                    for vecs, d in zip(pooled_data, dimensions_list)
                ]

        # for normalize
        flags = [p.normalize for p in pooling_params]
        if len(set(flags)) == 1:
            # All requests agree: apply (or skip) normalization in one shot.
            if flags[0]:
                pooled_data = self.activation(pooled_data)
        else:
            pooled_data = [
                self.activation(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        # pooled_data shape: [batchsize, embedding_dimension]
        return pooled_data
|
||||
|
||||
|
||||
class ClassifierPoolerHead(SequencePoolerHead):
    """Head for "classify"/"score" tasks: optional classifier projection,
    logit bias, and per-request activation."""

    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config

        self.classifier = classifier
        self.logit_bias: float | None = model_config.pooler_config.logit_bias
        self.head_dtype = model_config.head_dtype

        # num_labels is known statically for sequence classification.
        self.act_fn = resolve_classifier_act_fn(
            model_config, static_num_labels=True, act_fn=act_fn
        )

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"classify", "score"}

    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            pooled_data -= self.logit_bias

        # Apply the activation only for requests that asked for it.
        flags = [p.use_activation for p in pooling_params]
        if len(set(flags)) == 1:
            scores = self.act_fn(pooled_data) if flags[0] else pooled_data
        else:
            scores = [
                self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
            ]

        # scores shape: [batchsize, num_labels]
        return scores
|
||||
93
vllm/model_executor/layers/pooler/seqwise/methods.py
Normal file
93
vllm/model_executor/layers/pooler/seqwise/methods.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Set
|
||||
from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config.pooler import PoolingTypeStr
|
||||
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
SequencePoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
|
||||
|
||||
class SequencePoolingMethod(nn.Module, ABC):
    """Base class for methods that reduce a sequence to one vector."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        # Sequence pooling methods can feed any pooling task by default.
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """No parameter updates are required by default."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Reduce per-token hidden states to one vector per sequence."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class CLSPool(SequencePoolingMethod):
    """Pool by taking the hidden state of each sequence's first (CLS) token."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        # The CLS token must be present in this step's hidden states.
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        return hidden_states[pooling_cursor.first_token_indices_gpu]
|
||||
|
||||
|
||||
class LastPool(SequencePoolingMethod):
    """Pool by taking the hidden state of each sequence's last token."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        return hidden_states[pooling_cursor.last_token_indices_gpu]
|
||||
|
||||
|
||||
class MeanPool(SequencePoolingMethod):
    """Pool by averaging the hidden states of each sequence's tokens."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        pooling_cursor = pooling_metadata.get_pooling_cursor()
        # All of a sequence's tokens must be present to compute its mean.
        assert not pooling_cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        prompt_lens = pooling_cursor.prompt_lens_cpu.to(
            hidden_states.device, non_blocking=True
        )

        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        start_indices = pooling_cursor.first_token_indices_gpu
        end_indices = pooling_cursor.last_token_indices_gpu

        # Per-sequence sum via cumsum differencing (add back the first
        # token, which the subtraction excludes), divided by length.
        return (
            cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices]
        ) / prompt_lens.unsqueeze(1)
|
||||
|
||||
|
||||
def get_seq_pooling_method(pooling_type: PoolingTypeStr | str):
    """Map a pooling-type name to a new sequence pooling method instance."""
    method_cls = {
        "LAST": LastPool,
        "CLS": CLSPool,
        "MEAN": MeanPool,
    }.get(pooling_type)

    if method_cls is None:
        raise NotImplementedError(f"Unknown sequence pooling type: {pooling_type!r}")

    return method_cls()
|
||||
106
vllm/model_executor/layers/pooler/seqwise/poolers.py
Normal file
106
vllm/model_executor/layers/pooler/seqwise/poolers.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable, Set
|
||||
from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
|
||||
from vllm.model_executor.layers.pooler.abstract import Pooler
|
||||
from vllm.model_executor.layers.pooler.activations import PoolerActivation
|
||||
from vllm.tasks import POOLING_TASKS, PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .heads import (
|
||||
ClassifierPoolerHead,
|
||||
EmbeddingPoolerHead,
|
||||
SequencePoolerHead,
|
||||
SequencePoolerHeadOutput,
|
||||
)
|
||||
from .methods import (
|
||||
SequencePoolingMethod,
|
||||
SequencePoolingMethodOutput,
|
||||
get_seq_pooling_method,
|
||||
)
|
||||
|
||||
SequencePoolingFn: TypeAlias = Callable[
|
||||
[torch.Tensor, PoolingMetadata],
|
||||
SequencePoolingMethodOutput,
|
||||
]
|
||||
SequencePoolingHeadFn: TypeAlias = Callable[
|
||||
[SequencePoolingMethodOutput, PoolingMetadata],
|
||||
SequencePoolerHeadOutput,
|
||||
]
|
||||
|
||||
SequencePoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
|
||||
|
||||
class SequencePooler(Pooler):
    """
    A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Postprocesses the output based on pooling head.
    3. Returns structured results as `PoolerOutput`.
    """

    def __init__(
        self,
        pooling: SequencePoolingMethod | SequencePoolingFn,
        head: SequencePoolerHead | SequencePoolingHeadFn,
    ) -> None:
        super().__init__()

        # Both may be modules or plain callables with the same signature.
        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Intersect task support of the pooling method and head (callables
        place no restriction)."""
        tasks = set(POOLING_TASKS)

        if isinstance(self.pooling, SequencePoolingMethod):
            tasks &= self.pooling.get_supported_tasks()
        if isinstance(self.head, SequencePoolerHead):
            tasks &= self.head.get_supported_tasks()

        return tasks

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        updates = PoolingParamsUpdate()

        if isinstance(self.pooling, SequencePoolingMethod):
            updates |= self.pooling.get_pooling_updates(task)

        return updates

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerOutput:
        # Pooling method first, then the post-processing head.
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        pooled_data = self.head(pooled_data, pooling_metadata)
        return pooled_data
|
||||
|
||||
|
||||
def pooler_for_embed(pooler_config: PoolerConfig):
    """Build a ``SequencePooler`` for the "embed" task from ``pooler_config``."""
    return SequencePooler(
        pooling=get_seq_pooling_method(pooler_config.get_pooling_type()),
        head=EmbeddingPoolerHead(),
    )
|
||||
|
||||
|
||||
def pooler_for_classify(
    pooler_config: PoolerConfig,
    *,
    pooling: SequencePoolingMethod | SequencePoolingFn | None = None,
    classifier: ClassifierFn | None = None,
    act_fn: PoolerActivation | str | None = None,
):
    """Build a ``SequencePooler`` for classification/scoring tasks.

    When ``pooling`` is None, the method is derived from the config's
    pooling type. ``classifier`` and ``act_fn`` are forwarded to the head.
    """
    if pooling is None:
        pooling = get_seq_pooling_method(pooler_config.get_pooling_type())

    head = ClassifierPoolerHead(classifier=classifier, act_fn=act_fn)

    return SequencePooler(pooling=pooling, head=head)
|
||||
128
vllm/model_executor/layers/pooler/special.py
Normal file
128
vllm/model_executor/layers/pooler/special.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Mapping, Set
|
||||
from itertools import groupby
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .abstract import Pooler, PoolerOutput
|
||||
from .common import ClassifierFn
|
||||
from .seqwise import (
|
||||
SequencePoolingFn,
|
||||
SequencePoolingMethod,
|
||||
pooler_for_classify,
|
||||
pooler_for_embed,
|
||||
)
|
||||
from .tokwise import AllPool, pooler_for_token_classify, pooler_for_token_embed
|
||||
|
||||
|
||||
class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    @classmethod
    def for_embedding(cls, pooler_config: PoolerConfig):
        """Construct the standard task->pooler mapping for embedding models."""
        return cls(
            {
                "token_embed": pooler_for_token_embed(pooler_config),
                "embed": pooler_for_embed(pooler_config),
            },
        )

    @classmethod
    def for_seq_cls(
        cls,
        pooler_config: PoolerConfig,
        *,
        pooling: SequencePoolingMethod | SequencePoolingFn | None = None,
        classifier: ClassifierFn | None = None,
    ):
        """Construct the standard task->pooler mapping for sequence
        classification models; ``classifier`` is shared across tasks."""
        return cls(
            {
                # Token-level classification always pools all tokens.
                "token_classify": pooler_for_token_classify(
                    pooler_config,
                    pooling=AllPool(),
                    classifier=classifier,
                ),
                "classify": pooler_for_classify(
                    pooler_config,
                    pooling=pooling,
                    classifier=classifier,
                    act_fn="classify",
                ),
                "score": pooler_for_classify(
                    pooler_config,
                    pooling=pooling,
                    classifier=classifier,
                    act_fn="score",
                ),
            }
        )

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        super().__init__()

        # Fail fast if a sub-pooler is registered under a task it can't serve.
        for task, pooler in poolers_by_task.items():
            if task not in pooler.get_supported_tasks():
                raise ValueError(
                    f"{pooler=} does not support {task=}. "
                    f"Supported tasks: {pooler.get_supported_tasks()}"
                )

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks that have a registered sub-pooler."""
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate parameter updates to the sub-pooler for ``task``."""
        return self.poolers_by_task[task].get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Dispatch contiguous runs of same-task requests to sub-poolers."""
        poolers_by_task = self.poolers_by_task

        outputs = list[torch.Tensor | None]()
        offset = 0
        # `groupby` yields runs of consecutive identical tasks, so each
        # sub-pooler receives a contiguous metadata slice.
        for task, group in groupby(pooling_metadata.tasks):
            if not (pooler := poolers_by_task.get(task)):
                raise ValueError(
                    f"Unsupported task: {task!r} "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )

            num_items = len(list(group))
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[offset : offset + num_items],
            )

            outputs.extend(group_output)
            offset += num_items

        return outputs

    def extra_repr(self) -> str:
        s = f"supported_task={self.get_supported_tasks()}"
        return s
|
||||
|
||||
|
||||
class IdentityPooler(Pooler):
    """Pooler that returns hidden states unchanged (pass-through)."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"plugin", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        return hidden_states
|
||||
|
||||
|
||||
__all__ = ["DispatchPooler", "IdentityPooler"]
|
||||
39
vllm/model_executor/layers/pooler/tokwise/__init__.py
Normal file
39
vllm/model_executor/layers/pooler/tokwise/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Poolers that produce an output for each token in the sequence."""
|
||||
|
||||
from .heads import (
|
||||
TokenClassifierPoolerHead,
|
||||
TokenEmbeddingPoolerHead,
|
||||
TokenPoolerHead,
|
||||
TokenPoolerHeadOutputItem,
|
||||
)
|
||||
from .methods import (
|
||||
AllPool,
|
||||
StepPool,
|
||||
TokenPoolingMethod,
|
||||
TokenPoolingMethodOutputItem,
|
||||
get_tok_pooling_method,
|
||||
)
|
||||
from .poolers import (
|
||||
TokenPooler,
|
||||
TokenPoolerOutput,
|
||||
pooler_for_token_classify,
|
||||
pooler_for_token_embed,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TokenPoolerHead",
|
||||
"TokenPoolerHeadOutputItem",
|
||||
"TokenClassifierPoolerHead",
|
||||
"TokenEmbeddingPoolerHead",
|
||||
"TokenPoolingMethod",
|
||||
"TokenPoolingMethodOutputItem",
|
||||
"AllPool",
|
||||
"StepPool",
|
||||
"get_tok_pooling_method",
|
||||
"TokenPooler",
|
||||
"TokenPoolerOutput",
|
||||
"pooler_for_token_classify",
|
||||
"pooler_for_token_embed",
|
||||
]
|
||||
142
vllm/model_executor/layers/pooler/tokwise/heads.py
Normal file
142
vllm/model_executor/layers/pooler/tokwise/heads.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Set
|
||||
from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.pooler import ClassifierFn
|
||||
from vllm.model_executor.layers.pooler.activations import (
|
||||
PoolerActivation,
|
||||
PoolerNormalize,
|
||||
resolve_classifier_act_fn,
|
||||
)
|
||||
from vllm.model_executor.models.adapters import _load_st_projector
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .methods import TokenPoolingMethodOutputItem
|
||||
|
||||
TokenPoolerHeadOutputItem: TypeAlias = torch.Tensor | None
|
||||
|
||||
|
||||
class TokenPoolerHead(nn.Module, ABC):
    """Post-processing head applied per request to token-level pooled data.

    Subclasses implement ``forward_chunk`` for one request; ``forward``
    maps it over the batch.
    """

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Tasks this head can serve."""
        raise NotImplementedError

    @abstractmethod
    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        """Process one request's token-level data (may be None)."""
        raise NotImplementedError

    def forward(
        self,
        pooled_data: list[TokenPoolingMethodOutputItem],
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolerHeadOutputItem]:
        pooling_params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(pooling_params)

        return [self.forward_chunk(d, p) for d, p in zip(pooled_data, pooling_params)]
|
||||
|
||||
|
||||
class TokenEmbeddingPoolerHead(TokenPoolerHead):
    """Head for the "token_embed" task: per-request ST projection,
    matryoshka truncation, and optional L2 normalization."""

    def __init__(self) -> None:
        super().__init__()

        # Load ST projector if available
        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config

        self.projector = _load_st_projector(model_config)
        self.head_dtype = model_config.head_dtype

        self.activation = PoolerNormalize()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_embed"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        pooled_data = pooled_data.to(self.head_dtype)
        # pooled_data shape: [n_tokens, hidden_dimension]

        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [n_tokens, embedding_dimension]

        # for matryoshka representation
        # (slicing with dimensions=None is a no-op)
        pooled_data = pooled_data[..., : pooling_param.dimensions]

        # for normalize
        if pooling_param.normalize:
            pooled_data = self.activation(pooled_data)

        # pooled_data shape: [n_tokens, embedding_dimension]
        return pooled_data
|
||||
|
||||
|
||||
class TokenClassifierPoolerHead(TokenPoolerHead):
    """Head for the "token_classify" task: optional classifier projection,
    logit bias, and per-request activation on each token's logits."""

    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config

        self.classifier = classifier
        self.logit_bias: float | None = model_config.pooler_config.logit_bias
        self.head_dtype = model_config.head_dtype

        # num_labels is resolved per call from the input shape here.
        self.act_fn = resolve_classifier_act_fn(
            model_config, static_num_labels=False, act_fn=act_fn
        )

    def get_supported_tasks(self) -> Set[PoolingTask]:
        return {"token_classify"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        pooled_data = pooled_data.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(pooled_data)
        else:
            scores = pooled_data
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            scores -= self.logit_bias

        if pooling_param.use_activation:
            scores = self.act_fn(scores)

        # scores shape: [n_token, num_labels]
        return scores
|
||||
124
vllm/model_executor/layers/pooler/tokwise/methods.py
Normal file
124
vllm/model_executor/layers/pooler/tokwise/methods.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Set
|
||||
from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config.pooler import PoolingTypeStr
|
||||
from vllm.model_executor.layers.pooler import PoolingParamsUpdate
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
TokenPoolingMethodOutputItem: TypeAlias = torch.Tensor | None
|
||||
|
||||
|
||||
class TokenPoolingMethod(nn.Module, ABC):
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return {"token_embed", "token_classify"}
|
||||
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return PoolingParamsUpdate()
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> list[TokenPoolingMethodOutputItem]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AllPool(TokenPoolingMethod):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
|
||||
self.enable_chunked_prefill = scheduler_config.enable_chunked_prefill
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> list[TokenPoolingMethodOutputItem]:
|
||||
pooling_cursor = pooling_metadata.get_pooling_cursor()
|
||||
hidden_states_all = hidden_states.split(
|
||||
pooling_cursor.num_scheduled_tokens_cpu.tolist()
|
||||
)
|
||||
hidden_states_lst = [hidden_states_all[i] for i in pooling_cursor.index]
|
||||
|
||||
if not self.enable_chunked_prefill:
|
||||
return hidden_states_lst
|
||||
|
||||
pooling_states = pooling_metadata.pooling_states
|
||||
|
||||
# If chunked_prefill is enabled
|
||||
# 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
|
||||
for p, hs_chunk in zip(pooling_states, hidden_states_lst):
|
||||
p.hidden_states_cache.append(hs_chunk)
|
||||
|
||||
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
|
||||
output_list = list[TokenPoolingMethodOutputItem]()
|
||||
for p, finished in zip(pooling_states, pooling_cursor.is_finished()):
|
||||
if finished:
|
||||
hidden_states_cache = p.hidden_states_cache
|
||||
if len(hidden_states_cache) == 1:
|
||||
output_list.append(hidden_states_cache[0])
|
||||
else:
|
||||
output_list.append(torch.concat(hidden_states_cache, dim=0))
|
||||
p.clean()
|
||||
else:
|
||||
output_list.append(None)
|
||||
|
||||
return output_list
|
||||
|
||||
|
||||
class StepPool(AllPool):
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return PoolingParamsUpdate(requires_token_ids=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> list[TokenPoolingMethodOutputItem]:
|
||||
pooled_data_lst = super().forward(hidden_states, pooling_metadata)
|
||||
prompt_token_ids = pooling_metadata.get_prompt_token_ids()
|
||||
pooling_params = pooling_metadata.pooling_params
|
||||
|
||||
pooled_data = list[torch.Tensor | None]()
|
||||
for data, token_id, pooling_param in zip(
|
||||
pooled_data_lst, prompt_token_ids, pooling_params
|
||||
):
|
||||
# for unfinished chunked prefill
|
||||
if data is None:
|
||||
pass
|
||||
else:
|
||||
step_tag_id = pooling_param.step_tag_id
|
||||
returned_token_ids = pooling_param.returned_token_ids
|
||||
|
||||
if returned_token_ids is not None and len(returned_token_ids) > 0:
|
||||
data = data[:, returned_token_ids]
|
||||
|
||||
if step_tag_id is not None:
|
||||
data = data[token_id == step_tag_id]
|
||||
|
||||
pooled_data.append(data)
|
||||
|
||||
return pooled_data
|
||||
|
||||
|
||||
def get_tok_pooling_method(pooling_type: PoolingTypeStr | str):
|
||||
if pooling_type == "ALL":
|
||||
return AllPool()
|
||||
if pooling_type == "STEP":
|
||||
return StepPool()
|
||||
|
||||
# TODO: Separate seq and tok pooling types so we don't need this fallback
|
||||
return AllPool()
|
||||
raise NotImplementedError(f"Unknown tokenwise pooling type: {pooling_type!r}")
|
||||
106
vllm/model_executor/layers/pooler/tokwise/poolers.py
Normal file
106
vllm/model_executor/layers/pooler/tokwise/poolers.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable, Set
|
||||
from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate
|
||||
from vllm.model_executor.layers.pooler.abstract import Pooler
|
||||
from vllm.model_executor.layers.pooler.activations import PoolerActivation
|
||||
from vllm.tasks import POOLING_TASKS, PoolingTask
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .heads import (
|
||||
TokenClassifierPoolerHead,
|
||||
TokenEmbeddingPoolerHead,
|
||||
TokenPoolerHead,
|
||||
TokenPoolerHeadOutputItem,
|
||||
)
|
||||
from .methods import (
|
||||
TokenPoolingMethod,
|
||||
TokenPoolingMethodOutputItem,
|
||||
get_tok_pooling_method,
|
||||
)
|
||||
|
||||
TokenPoolingFn: TypeAlias = Callable[
|
||||
[torch.Tensor, PoolingMetadata],
|
||||
list[TokenPoolingMethodOutputItem],
|
||||
]
|
||||
TokenPoolingHeadFn: TypeAlias = Callable[
|
||||
[list[TokenPoolingMethodOutputItem], PoolingMetadata],
|
||||
list[TokenPoolerHeadOutputItem],
|
||||
]
|
||||
|
||||
TokenPoolerOutput: TypeAlias = list[torch.Tensor | None]
|
||||
|
||||
|
||||
class TokenPooler(Pooler):
|
||||
"""
|
||||
A layer that pools specific information from hidden states.
|
||||
|
||||
This layer does the following:
|
||||
1. Extracts specific tokens or aggregates data based on pooling method.
|
||||
2. Postprocesses the output based on pooling head.
|
||||
3. Returns structured results as `PoolerOutput`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pooling: TokenPoolingMethod | TokenPoolingFn,
|
||||
head: TokenPoolerHead | TokenPoolingHeadFn,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pooling = pooling
|
||||
self.head = head
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
tasks = set(POOLING_TASKS)
|
||||
|
||||
if isinstance(self.pooling, TokenPoolingMethod):
|
||||
tasks &= self.pooling.get_supported_tasks()
|
||||
if isinstance(self.head, TokenPoolerHead):
|
||||
tasks &= self.head.get_supported_tasks()
|
||||
|
||||
return tasks
|
||||
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
updates = PoolingParamsUpdate()
|
||||
|
||||
if isinstance(self.pooling, TokenPoolingMethod):
|
||||
updates |= self.pooling.get_pooling_updates(task)
|
||||
|
||||
return updates
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||
return pooled_data
|
||||
|
||||
|
||||
def pooler_for_token_embed(pooler_config: PoolerConfig):
|
||||
pooling = get_tok_pooling_method(pooler_config.get_pooling_type())
|
||||
head = TokenEmbeddingPoolerHead()
|
||||
|
||||
return TokenPooler(pooling=pooling, head=head)
|
||||
|
||||
|
||||
def pooler_for_token_classify(
|
||||
pooler_config: PoolerConfig,
|
||||
*,
|
||||
pooling: TokenPoolingMethod | TokenPoolingFn | None = None,
|
||||
classifier: ClassifierFn | None = None,
|
||||
act_fn: PoolerActivation | str | None = None,
|
||||
):
|
||||
if pooling is None:
|
||||
pooling = get_tok_pooling_method(pooler_config.get_pooling_type())
|
||||
|
||||
head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)
|
||||
|
||||
return TokenPooler(pooling=pooling, head=head)
|
||||
@@ -252,19 +252,14 @@ def as_embedding_model(cls: _T) -> _T:
|
||||
return cls
|
||||
|
||||
# Lazy import
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
|
||||
class ModelForEmbedding(_create_pooling_model_cls(cls)):
|
||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
},
|
||||
)
|
||||
self.pooler = DispatchPooler.for_embedding(pooler_config)
|
||||
|
||||
ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")
|
||||
|
||||
@@ -289,10 +284,7 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
|
||||
# Lazy import
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
DispatchPooler,
|
||||
Pooler,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||
|
||||
from .utils import maybe_prefix
|
||||
@@ -318,18 +310,8 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.score
|
||||
),
|
||||
"classify": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="classify"
|
||||
),
|
||||
"score": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="score"
|
||||
),
|
||||
}
|
||||
self.pooler = DispatchPooler.for_seq_cls(
|
||||
pooler_config, classifier=self.score
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -18,19 +18,25 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
ClassifierPooler,
|
||||
DispatchPooler,
|
||||
Pooler,
|
||||
PoolingMethod,
|
||||
PoolingParamsUpdate,
|
||||
TokenPoolerHeadOutput,
|
||||
TokenPoolingMethodOutput,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.seqwise import (
|
||||
CLSPool,
|
||||
SequencePooler,
|
||||
SequencePoolerHeadOutput,
|
||||
SequencePoolerOutput,
|
||||
SequencePoolingMethodOutput,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.tokwise import (
|
||||
pooler_for_token_classify,
|
||||
pooler_for_token_embed,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.outputs import TokenPoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .interfaces import SupportsCrossEncoding, SupportsQuant
|
||||
@@ -85,25 +91,21 @@ class BertEmbedding(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class BertPooler(Pooler):
|
||||
class BertPooler(SequencePooler):
|
||||
def __init__(self, config: BertConfig):
|
||||
super().__init__()
|
||||
super().__init__(
|
||||
pooling=CLSPool(),
|
||||
head=self.head,
|
||||
)
|
||||
|
||||
self.pooling = PoolingMethod.from_pooling_type("CLS")
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return self.pooling.get_supported_tasks()
|
||||
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return self.pooling.get_pooling_updates(task)
|
||||
|
||||
def head(
|
||||
self,
|
||||
pooled_data: TokenPoolingMethodOutput,
|
||||
pooled_data: SequencePoolingMethodOutput,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolerHeadOutput:
|
||||
) -> SequencePoolerHeadOutput:
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
|
||||
@@ -111,15 +113,6 @@ class BertPooler(Pooler):
|
||||
pooled_data = self.activation(pooled_data)
|
||||
return pooled_data
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||
return pooled_data
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@@ -524,12 +517,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
|
||||
)
|
||||
|
||||
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||
return DispatchPooler(
|
||||
{
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
}
|
||||
)
|
||||
return DispatchPooler.for_embedding(pooler_config)
|
||||
|
||||
|
||||
# Here we encode the token type ids together with the input ids.
|
||||
@@ -620,6 +608,7 @@ class SPLADESparsePooler(Pooler):
|
||||
remove_cls_sep: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert pooling in ("max", "sum")
|
||||
self.mlm_head = mlm_head
|
||||
self.cls_token_id = cls_token_id
|
||||
@@ -637,10 +626,8 @@ class SPLADESparsePooler(Pooler):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2
|
||||
|
||||
lens_tensor: torch.Tensor = pooling_metadata.prompt_lens
|
||||
) -> SequencePoolerOutput:
|
||||
lens_tensor = pooling_metadata.prompt_lens
|
||||
lens: list[int] = lens_tensor.tolist()
|
||||
B: int = len(lens)
|
||||
|
||||
@@ -729,7 +716,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
|
||||
|
||||
return DispatchPooler(
|
||||
{
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"token_embed": pooler_for_token_embed(pooler_config),
|
||||
"embed": SPLADESparsePooler(
|
||||
mlm_head=self.mlm_head,
|
||||
cls_token_id=cls_id,
|
||||
@@ -824,20 +811,10 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.classifier
|
||||
),
|
||||
"classify": ClassifierPooler(
|
||||
pooling=self.bert.pooler,
|
||||
classifier=self.classifier,
|
||||
act_fn="classify",
|
||||
),
|
||||
"score": ClassifierPooler(
|
||||
pooling=self.bert.pooler, classifier=self.classifier, act_fn="score"
|
||||
),
|
||||
}
|
||||
self.pooler = DispatchPooler.for_seq_cls(
|
||||
pooler_config,
|
||||
pooling=self.bert.pooler,
|
||||
classifier=self.classifier,
|
||||
)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
@@ -891,13 +868,7 @@ class BertForTokenClassification(nn.Module):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config=pooler_config
|
||||
),
|
||||
}
|
||||
)
|
||||
self.pooler = pooler_for_token_classify(pooler_config)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.bert.embed_input_ids(input_ids)
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm.model_executor.layers.linear import (
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
@@ -37,7 +38,6 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from ..layers.pooler import ClassifierPooler, DispatchPooler, Pooler
|
||||
from .bert import BertPooler
|
||||
from .interfaces import SupportsCrossEncoding, SupportsQuant
|
||||
from .interfaces_base import default_pooling_type
|
||||
@@ -693,20 +693,10 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.classifier
|
||||
),
|
||||
"classify": ClassifierPooler(
|
||||
pooling=self.new.pooler,
|
||||
classifier=self.classifier,
|
||||
act_fn="classify",
|
||||
),
|
||||
"score": ClassifierPooler(
|
||||
pooling=self.new.pooler, classifier=self.classifier, act_fn="score"
|
||||
),
|
||||
}
|
||||
self.pooler = DispatchPooler.for_seq_cls(
|
||||
pooler_config,
|
||||
pooling=self.new.pooler,
|
||||
classifier=self.classifier,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -26,7 +26,7 @@ from vllm.model_executor.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@@ -880,12 +880,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
assert pooler_config is not None
|
||||
self.pooler_config = pooler_config
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
}
|
||||
)
|
||||
self.pooler = DispatchPooler.for_embedding(pooler_config)
|
||||
|
||||
# Assumes that self.forward is called after self.embed_input_ids
|
||||
self._is_text_input = True
|
||||
|
||||
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
@@ -49,7 +50,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from ..layers.pooler import DispatchPooler, Pooler
|
||||
from .interfaces import SupportsCrossEncoding, SupportsPP
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
@@ -351,19 +351,7 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.score
|
||||
),
|
||||
"classify": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="classify"
|
||||
),
|
||||
"score": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="score"
|
||||
),
|
||||
}
|
||||
)
|
||||
self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.transformer.embed_input_ids(input_ids)
|
||||
|
||||
@@ -9,17 +9,19 @@ from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
DispatchPooler,
|
||||
Pooler,
|
||||
PoolerNormalize,
|
||||
PoolingMethod,
|
||||
PoolingParamsUpdate,
|
||||
TokenPoolerHeadOutput,
|
||||
TokenPoolingMethodOutput,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.activations import PoolerNormalize
|
||||
from vllm.model_executor.layers.pooler.seqwise import (
|
||||
SequencePooler,
|
||||
SequencePoolerHeadOutput,
|
||||
SequencePoolingMethod,
|
||||
SequencePoolingMethodOutput,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
from vllm.v1.outputs import TokenPoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .interfaces_base import default_pooling_type
|
||||
@@ -27,7 +29,7 @@ from .interfaces_base import default_pooling_type
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GritLMMeanPool(PoolingMethod):
|
||||
class GritLMMeanPool(SequencePoolingMethod):
|
||||
"""As `MeanPool`, but only includes non-instruction tokens."""
|
||||
|
||||
def __init__(self, model_config: ModelConfig):
|
||||
@@ -151,7 +153,7 @@ class GritLMMeanPool(PoolingMethod):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolingMethodOutput:
|
||||
) -> SequencePoolingMethodOutput:
|
||||
prompt_lens = pooling_metadata.prompt_lens
|
||||
instr_lens = torch.tensor(
|
||||
[
|
||||
@@ -174,35 +176,22 @@ class GritLMMeanPool(PoolingMethod):
|
||||
return pooled_data
|
||||
|
||||
|
||||
class GritLMPooler(Pooler):
|
||||
class GritLMPooler(SequencePooler):
|
||||
def __init__(self, model_config: ModelConfig):
|
||||
super().__init__()
|
||||
super().__init__(
|
||||
pooling=GritLMMeanPool(model_config),
|
||||
head=self.head,
|
||||
)
|
||||
|
||||
self.pooling = GritLMMeanPool(model_config)
|
||||
self.activation = PoolerNormalize()
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return self.pooling.get_supported_tasks()
|
||||
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return self.pooling.get_pooling_updates(task)
|
||||
|
||||
def head(
|
||||
self,
|
||||
pooled_data: TokenPoolingMethodOutput,
|
||||
pooled_data: SequencePoolingMethodOutput,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolerHeadOutput:
|
||||
) -> SequencePoolerHeadOutput:
|
||||
return self.activation(pooled_data)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||
return pooled_data
|
||||
|
||||
|
||||
@default_pooling_type("MEAN")
|
||||
class GritLM(LlamaForCausalLM):
|
||||
@@ -245,7 +234,7 @@ class GritLM(LlamaForCausalLM):
|
||||
if pooler_config is not None:
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"token_embed": pooler_for_token_embed(pooler_config),
|
||||
"embed": GritLMPooler(vllm_config.model_config),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -434,9 +434,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{"token_classify": Pooler.for_token_classify(pooler_config)}
|
||||
)
|
||||
self.pooler = pooler_for_token_classify(pooler_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
@@ -596,16 +596,4 @@ class JambaForSequenceClassification(JambaForCausalLM):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.score
|
||||
),
|
||||
"classify": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="classify"
|
||||
),
|
||||
"score": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="score"
|
||||
),
|
||||
}
|
||||
)
|
||||
self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
|
||||
|
||||
@@ -10,7 +10,7 @@ from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@@ -105,19 +105,7 @@ class JinaVLForSequenceClassification(
|
||||
self.score = JinaVLScorer(
|
||||
vllm_config.model_config, prefix=maybe_prefix(prefix, "score")
|
||||
)
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.score
|
||||
),
|
||||
"classify": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="classify"
|
||||
),
|
||||
"score": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="score"
|
||||
),
|
||||
}
|
||||
)
|
||||
self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable, Set
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -12,21 +12,18 @@ from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
ClassifierPooler,
|
||||
DispatchPooler,
|
||||
Pooler,
|
||||
PoolingMethod,
|
||||
PoolingParamsUpdate,
|
||||
TokenPoolerHeadOutput,
|
||||
TokenPoolingMethodOutput,
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.pooler.seqwise import (
|
||||
SequencePooler,
|
||||
SequencePoolerHeadOutput,
|
||||
SequencePoolingMethodOutput,
|
||||
get_seq_pooling_method,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import PoolingTask
|
||||
from vllm.v1.outputs import TokenPoolerOutput
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
from .interfaces import SupportsCrossEncoding
|
||||
@@ -282,12 +279,13 @@ class ModernBertModel(nn.Module):
|
||||
return norm_outputs
|
||||
|
||||
|
||||
class ModernBertPooler(Pooler):
|
||||
class ModernBertPooler(SequencePooler):
|
||||
def __init__(self, config: ModernBertConfig):
|
||||
super().__init__()
|
||||
super().__init__(
|
||||
pooling=get_seq_pooling_method(config.classifier_pooling.upper()),
|
||||
head=self.head,
|
||||
)
|
||||
|
||||
pooling_type = config.classifier_pooling.upper()
|
||||
self.pooling = PoolingMethod.from_pooling_type(pooling_type)
|
||||
self.dense = nn.Linear(
|
||||
config.hidden_size, config.hidden_size, config.classifier_bias
|
||||
)
|
||||
@@ -296,32 +294,17 @@ class ModernBertPooler(Pooler):
|
||||
config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
|
||||
)
|
||||
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return self.pooling.get_supported_tasks()
|
||||
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return self.pooling.get_pooling_updates(task)
|
||||
|
||||
def head(
|
||||
self,
|
||||
pooled_data: TokenPoolingMethodOutput,
|
||||
pooled_data: SequencePoolingMethodOutput,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolerHeadOutput:
|
||||
) -> SequencePoolerHeadOutput:
|
||||
if isinstance(pooled_data, list):
|
||||
pooled_data = torch.stack(pooled_data)
|
||||
|
||||
pooled_data = pooled_data.to(self.dense.weight.dtype)
|
||||
return self.norm(self.act(self.dense(pooled_data)))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> TokenPoolerOutput:
|
||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||
return pooled_data
|
||||
|
||||
|
||||
@default_pooling_type("CLS")
|
||||
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
@@ -344,18 +327,10 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.classifier
|
||||
),
|
||||
"classify": ClassifierPooler(
|
||||
pooling=self.pooling, classifier=self.classifier, act_fn="classify"
|
||||
),
|
||||
"score": ClassifierPooler(
|
||||
pooling=self.pooling, classifier=self.classifier, act_fn="score"
|
||||
),
|
||||
}
|
||||
self.pooler = DispatchPooler.for_seq_cls(
|
||||
pooler_config,
|
||||
pooling=self.pooling,
|
||||
classifier=self.classifier,
|
||||
)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
@@ -438,13 +413,7 @@ class ModernBertForTokenClassification(nn.Module):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config=pooler_config
|
||||
),
|
||||
}
|
||||
)
|
||||
self.pooler = pooler_for_token_classify(pooler_config)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
@@ -14,7 +14,8 @@ from torch import nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
from vllm.model_executor.layers.pooler import Pooler
|
||||
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
@@ -104,9 +105,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{"token_classify": Pooler.for_token_classify(pooler_config)}
|
||||
)
|
||||
self.pooler = pooler_for_token_classify(pooler_config)
|
||||
|
||||
|
||||
@default_pooling_type("STEP")
|
||||
@@ -118,6 +117,4 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{"token_classify": Pooler.for_token_classify(pooler_config)}
|
||||
)
|
||||
self.pooler = pooler_for_token_classify(pooler_config)
|
||||
|
||||
@@ -8,12 +8,8 @@ from torch import nn
|
||||
from transformers import RobertaConfig
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
ClassifierPooler,
|
||||
CLSPool,
|
||||
DispatchPooler,
|
||||
Pooler,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.pooler.seqwise import CLSPool
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.models.bert import (
|
||||
TOKEN_TYPE_SHIFT,
|
||||
@@ -196,18 +192,10 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config=pooler_config, classifier=self.classifier
|
||||
),
|
||||
"classify": ClassifierPooler(
|
||||
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
|
||||
),
|
||||
"score": ClassifierPooler(
|
||||
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
|
||||
),
|
||||
}
|
||||
self.pooler = DispatchPooler.for_seq_cls(
|
||||
pooler_config,
|
||||
pooling=CLSPool(),
|
||||
classifier=self.classifier,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
@@ -1050,12 +1050,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
||||
assert pooler_config is not None
|
||||
self.pooler_config = pooler_config
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
}
|
||||
)
|
||||
self.pooler = DispatchPooler.for_embedding(pooler_config)
|
||||
|
||||
self._is_text_input = True
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from transformers import BatchFeature
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, DummyPooler
|
||||
from vllm.model_executor.layers.pooler import IdentityPooler
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@@ -248,7 +248,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler({"plugin": DummyPooler()})
|
||||
self.pooler = IdentityPooler()
|
||||
|
||||
def embed_input_ids(
|
||||
self,
|
||||
|
||||
@@ -22,12 +22,8 @@ import torch
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
|
||||
from vllm.config.utils import getattr_iter
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
ClassifierPooler,
|
||||
CLSPool,
|
||||
DispatchPooler,
|
||||
Pooler,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.pooler.seqwise import CLSPool
|
||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||
from vllm.model_executor.models.interfaces_base import VllmModelForPooling
|
||||
|
||||
@@ -47,12 +43,7 @@ class EmbeddingMixin(VllmModelForPooling):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
}
|
||||
)
|
||||
self.pooler = DispatchPooler.for_embedding(pooler_config)
|
||||
|
||||
|
||||
class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
|
||||
@@ -104,16 +95,8 @@ class SequenceClassificationMixin(SupportsCrossEncoding, VllmModelForPooling):
|
||||
|
||||
self.classifier.__class__ = ClassifierWithReshape
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.classifier
|
||||
),
|
||||
"classify": ClassifierPooler(
|
||||
pooling=CLSPool(), classifier=self.classifier, act_fn="classify"
|
||||
),
|
||||
"score": ClassifierPooler(
|
||||
pooling=CLSPool(), classifier=self.classifier, act_fn="score"
|
||||
),
|
||||
}
|
||||
self.pooler = DispatchPooler.for_seq_cls(
|
||||
pooler_config,
|
||||
pooling=CLSPool(),
|
||||
classifier=self.classifier,
|
||||
)
|
||||
|
||||
@@ -91,9 +91,7 @@ class LogprobsTensors(NamedTuple):
|
||||
|
||||
# [num_reqs, <dynamic>]
|
||||
# The shape of each element depends on the pooler used
|
||||
TokenPoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
TokenwisePoolerOutput: TypeAlias = list[torch.Tensor] | list[torch.Tensor | None]
|
||||
PoolerOutput: TypeAlias = TokenPoolerOutput | TokenwisePoolerOutput
|
||||
PoolerOutput: TypeAlias = torch.Tensor | list[torch.Tensor] | list[torch.Tensor | None]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user