[Model][1/N] Support multiple poolers at model level (#21227)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterable, Set
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -17,7 +17,8 @@ from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
|
||||
from vllm.model_executor.layers.pooler import (ClassifierPooler,
|
||||
DispatchPooler, Pooler,
|
||||
PoolingMethod,
|
||||
PoolingParamsUpdate,
|
||||
PoolingType)
|
||||
@@ -92,20 +93,29 @@ class BertPooler(Pooler):
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def get_pooling_updates(
|
||||
self,
|
||||
task: PoolingTask,
|
||||
) -> Optional[PoolingParamsUpdate]:
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
return self.pooling.get_supported_tasks()
|
||||
|
||||
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
|
||||
return self.pooling.get_pooling_updates(task)
|
||||
|
||||
def _head(self, pooled_output: torch.Tensor):
|
||||
pooled_output = self.dense(pooled_output)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
pooled_output = self.pooling(hidden_states, pooling_metadata)
|
||||
pooled_output = self.dense(pooled_output)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
|
||||
if isinstance(pooled_output, list):
|
||||
pooled_output = [self._head(output) for output in pooled_output]
|
||||
else:
|
||||
pooled_output = self._head(pooled_output)
|
||||
|
||||
return pooled_output
|
||||
|
||||
|
||||
@@ -333,18 +343,19 @@ class BertModel(nn.Module, SupportsQuant):
|
||||
|
||||
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
embedding_class: type = BertEmbedding,
|
||||
add_pooling_layer: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
embedding_class: type[nn.Module] = BertEmbedding,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.embeddings = embedding_class(config)
|
||||
self.encoder = BertEncoder(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.encoder")
|
||||
self.pooler = BertPooler(config) if add_pooling_layer else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -366,8 +377,7 @@ class BertModel(nn.Module, SupportsQuant):
|
||||
token_type_ids=token_type_ids)
|
||||
return self.encoder(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "query", "q"),
|
||||
@@ -395,10 +405,43 @@ class BertModel(nn.Module, SupportsQuant):
|
||||
if name in params_dict:
|
||||
other_weights.append((name, loaded_weight))
|
||||
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["pooler."] if self.pooler is None else []),
|
||||
return other_weights, loaded_stacked_params
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
other_weights, loaded_stacked_params = self._load_weights(weights)
|
||||
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=["pooler."])
|
||||
loaded_params = loader.load_weights(other_weights)
|
||||
loaded_params.update(loaded_stacked_params)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class BertPoolingModel(BertModel):
|
||||
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
embedding_class: type[nn.Module] = BertEmbedding,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
embedding_class=embedding_class,
|
||||
)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.pooler = BertPooler(config)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
other_weights, loaded_stacked_params = self._load_weights(weights)
|
||||
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded_params = loader.load_weights(other_weights)
|
||||
loaded_params.update(loaded_stacked_params)
|
||||
return loaded_params
|
||||
@@ -421,6 +464,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
|
||||
super().__init__()
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.model = self._build_model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.pooler = self._build_pooler(pooler_config)
|
||||
@@ -456,10 +501,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
|
||||
embedding_class=BertEmbedding)
|
||||
|
||||
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||
return Pooler.from_config_with_defaults(pooler_config,
|
||||
pooling_type=PoolingType.CLS,
|
||||
normalize=True,
|
||||
softmax=False)
|
||||
return DispatchPooler({
|
||||
"encode":
|
||||
Pooler.for_encode(pooler_config),
|
||||
"embed":
|
||||
Pooler.for_embed(
|
||||
pooler_config,
|
||||
default_pooling_type=PoolingType.CLS,
|
||||
),
|
||||
})
|
||||
|
||||
|
||||
class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
@@ -481,16 +531,32 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.bert = BertModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "bert"),
|
||||
embedding_class=BertEmbedding,
|
||||
add_pooling_layer=True)
|
||||
self.bert = BertPoolingModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "bert"),
|
||||
embedding_class=BertEmbedding)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
self.pooler = ClassifierPooler(
|
||||
vllm_config.model_config,
|
||||
pooling=self.bert.pooler,
|
||||
classifier=self.classifier,
|
||||
)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler({
|
||||
"encode":
|
||||
Pooler.for_encode(pooler_config),
|
||||
"classify":
|
||||
ClassifierPooler(
|
||||
pooling=self.bert.pooler,
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||
vllm_config.model_config),
|
||||
),
|
||||
"score":
|
||||
ClassifierPooler(
|
||||
pooling=self.bert.pooler,
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||
vllm_config.model_config),
|
||||
),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
|
||||
Reference in New Issue
Block a user