[Model] Consolidate pooler implementations (#20927)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -9,7 +9,7 @@ from torch import nn
|
||||
from transformers import RobertaConfig
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import ClassifierPooler
|
||||
from vllm.model_executor.layers.pooler import ClassifierPooler, CLSPool
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
|
||||
@@ -106,8 +106,8 @@ class RobertaClassificationHead(nn.Module):
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
def forward(self, features, **kwargs):
|
||||
x = features[0, :] # take <s> token (equiv. to [CLS])
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# CLSPool has already been applied in `pooling`
|
||||
x = self.dense(x)
|
||||
x = torch.tanh(x)
|
||||
x = self.out_proj(x)
|
||||
@@ -188,8 +188,11 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
|
||||
add_pooling_layer=False)
|
||||
self.classifier = RobertaClassificationHead(config)
|
||||
|
||||
self._pooler = ClassifierPooler(vllm_config.model_config,
|
||||
self.classifier)
|
||||
self._pooler = ClassifierPooler(
|
||||
vllm_config.model_config,
|
||||
pooling=CLSPool(),
|
||||
classifier=self.classifier,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
|
||||
Reference in New Issue
Block a user