[Model] Update pooling model interface (#21058)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -18,12 +18,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
|
||||
PoolingMethod, PoolingType)
|
||||
PoolingMethod, PoolingTask,
|
||||
PoolingType)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||
@@ -80,7 +82,7 @@ class BertEmbedding(nn.Module):
|
||||
return embeddings
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
|
||||
class BertPooler(Pooler):
|
||||
|
||||
def __init__(self, config: BertConfig):
|
||||
super().__init__()
|
||||
@@ -89,6 +91,9 @@ class BertPooler(nn.Module):
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
|
||||
return self.pooling.get_pooling_params(task)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
@@ -319,6 +324,9 @@ class BertOutput(nn.Module):
|
||||
|
||||
|
||||
class BertModel(nn.Module, SupportsQuant):
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
|
||||
|
||||
def __init__(self,
|
||||
@@ -403,12 +411,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self.model = self._build_model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self._pooler = self._build_pooler(pooler_config)
|
||||
self.pooler = self._build_pooler(pooler_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -422,13 +433,6 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
weights_list = list(weights)
|
||||
|
||||
@@ -466,6 +470,8 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
@@ -476,7 +482,7 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
embedding_class=BertEmbedding,
|
||||
add_pooling_layer=True)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
self._pooler = ClassifierPooler(
|
||||
self.pooler = ClassifierPooler(
|
||||
vllm_config.model_config,
|
||||
pooling=self.bert.pooler,
|
||||
classifier=self.classifier,
|
||||
@@ -487,13 +493,6 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
loaded_params = loader.load_weights(weights)
|
||||
return loaded_params
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
|
||||
Reference in New Issue
Block a user