[Model] Update pooling model interface (#21058)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-18 00:05:40 +08:00
committed by GitHub
parent 9fb2d22032
commit 90bd2ab6e3
17 changed files with 247 additions and 345 deletions

View File

@@ -40,9 +40,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from ..layers.pooler import Pooler, PoolingType
from .interfaces import SupportsPP
@@ -332,6 +331,8 @@ class GPT2ForSequenceClassification(nn.Module):
_pooler: An instance of Pooler used for pooling operations.
"""
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -339,7 +340,7 @@ class GPT2ForSequenceClassification(nn.Module):
prefix=maybe_prefix(prefix, "gpt2"))
self.score = nn.Linear(config.n_embd, config.num_labels, bias=False)
pooler_config = vllm_config.model_config.pooler_config
self._pooler = Pooler.from_config_with_defaults(
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
@@ -349,13 +350,6 @@ class GPT2ForSequenceClassification(nn.Module):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def forward(
self,
input_ids: torch.Tensor,