[Model] Update pooling model interface (#21058)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-18 00:05:40 +08:00
committed by GitHub
parent 9fb2d22032
commit 90bd2ab6e3
17 changed files with 247 additions and 345 deletions

View File

@@ -18,12 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
PoolingMethod, PoolingType)
PoolingMethod, PoolingTask,
PoolingType)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.pooling_params import PoolingParams
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@@ -80,7 +82,7 @@ class BertEmbedding(nn.Module):
return embeddings
class BertPooler(nn.Module):
class BertPooler(Pooler):
def __init__(self, config: BertConfig):
super().__init__()
@@ -89,6 +91,9 @@ class BertPooler(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
return self.pooling.get_pooling_params(task)
def forward(
self,
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
@@ -319,6 +324,9 @@ class BertOutput(nn.Module):
class BertModel(nn.Module, SupportsQuant):
is_pooling_model = True
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
def __init__(self,
@@ -403,12 +411,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
_pooler: An instance of Pooler used for pooling operations.
"""
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
pooler_config = vllm_config.model_config.pooler_config
self.model = self._build_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = self._build_pooler(pooler_config)
self.pooler = self._build_pooler(pooler_config)
def forward(
self,
@@ -422,13 +433,6 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights_list = list(weights)
@@ -466,6 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
_pooler: An instance of Pooler used for pooling operations.
"""
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -476,7 +482,7 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
embedding_class=BertEmbedding,
add_pooling_layer=True)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self._pooler = ClassifierPooler(
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=self.bert.pooler,
classifier=self.classifier,
@@ -487,13 +493,6 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
loaded_params = loader.load_weights(weights)
return loaded_params
def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)
def forward(
self,
input_ids: Optional[torch.Tensor],