[Model] Consolidate pooler implementations (#20927)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-16 21:39:13 +08:00
committed by GitHub
parent 260127ea54
commit 1c3198b6c4
9 changed files with 558 additions and 372 deletions

View File

@@ -9,7 +9,7 @@ from torch import nn
from transformers import RobertaConfig
from vllm.config import VllmConfig
-from vllm.model_executor.layers.pooler import ClassifierPooler
+from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
@@ -106,8 +106,8 @@ class RobertaClassificationHead(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
-    def forward(self, features, **kwargs):
-        x = features[0, :]  # take <s> token (equiv. to [CLS])
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # CLSPool has already been applied in `pooling`
         x = self.dense(x)
         x = torch.tanh(x)
         x = self.out_proj(x)
@@ -188,8 +188,11 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
add_pooling_layer=False)
self.classifier = RobertaClassificationHead(config)
-        self._pooler = ClassifierPooler(vllm_config.model_config,
-                                        self.classifier)
+        self._pooler = ClassifierPooler(
+            vllm_config.model_config,
+            pooling=CLSPool(),
+            classifier=self.classifier,
+        )
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)