[Model] Consolidate pooler implementations (#20927)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-16 21:39:13 +08:00
committed by GitHub
parent 260127ea54
commit 1c3198b6c4
9 changed files with 558 additions and 372 deletions

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional
from typing import Optional, Union
import torch
from torch import nn
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
PoolingType)
PoolingMethod, PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -84,14 +84,18 @@ class BertPooler(nn.Module):
def __init__(self, config: BertConfig):
super().__init__()
self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[0, :]
pooled_output = self.dense(first_token_tensor)
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> Union[torch.Tensor, list[torch.Tensor]]:
pooled_output = self.pooling(hidden_states, pooling_metadata)
pooled_output = self.dense(pooled_output)
pooled_output = self.activation(pooled_output)
return pooled_output
@@ -472,8 +476,11 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
embedding_class=BertEmbedding,
add_pooling_layer=True)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler(vllm_config.model_config,
self.classifier, self.bert.pooler)
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=self.bert.pooler,
classifier=self.classifier,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)