[Model] Standardize pooling heads (#32148)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -8,7 +8,7 @@ from torch import nn
 from transformers import BertConfig

 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, PoolerConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention.encoder_only_attention import (
@@ -24,11 +24,11 @@ from vllm.model_executor.layers.pooler import (
     Pooler,
     PoolingParamsUpdate,
 )
+from vllm.model_executor.layers.pooler.activations import LambdaPoolerActivation
 from vllm.model_executor.layers.pooler.seqwise import (
+    EmbeddingPoolerHead,
     SequencePooler,
-    SequencePoolerHeadOutput,
     SequencePoolerOutput,
-    SequencePoolingMethodOutput,
     get_seq_pooling_method,
 )
 from vllm.model_executor.layers.pooler.tokwise import (
@@ -94,26 +94,32 @@ class BertEmbedding(nn.Module):


 class BertPooler(SequencePooler):
-    def __init__(self, config: BertConfig, pooler_config: PoolerConfig):
+    def __init__(self, model_config: ModelConfig):
+        pooler_config = model_config.pooler_config
+        assert pooler_config is not None
+
+        config: BertConfig = model_config.hf_config
+
         super().__init__(
             pooling=get_seq_pooling_method(pooler_config.seq_pooling_type),
-            head=self.head,
+            # We set this dummy to avoid adding parameters to nn.Module too early
+            head=nn.Identity(),
         )

-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.activation = nn.Tanh()
+        head_dtype = model_config.head_dtype
+        self.dense = nn.Linear(
+            config.hidden_size,
+            config.hidden_size,
+            dtype=head_dtype,
+        )
+        self.act_fn = nn.Tanh()

-    def head(
-        self,
-        pooled_data: SequencePoolingMethodOutput,
-        pooling_metadata: PoolingMetadata,
-    ) -> SequencePoolerHeadOutput:
-        if isinstance(pooled_data, list):
-            pooled_data = torch.stack(pooled_data)
-
-        pooled_data = self.dense(pooled_data)
-        pooled_data = self.activation(pooled_data)
-        return pooled_data
+        # Use lambdas so that weights are not registered under `self.head`
+        self.head = EmbeddingPoolerHead(
+            projector=lambda x: self.dense(x),
+            head_dtype=head_dtype,
+            activation=LambdaPoolerActivation(self.act_fn),
+        )


 class BertEncoder(nn.Module):
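A note on the `nn.Identity()` dummy and the lambda in the hunk above. `SequencePooler` is an `nn.Module`: `self.dense` does not exist yet when `super().__init__()` runs, which appears to be why a placeholder head is passed first and the real head attached afterwards. Separately, any module assigned as an attribute of the head would be registered as a submodule, so its parameters would land in `state_dict()` under `head.*` rather than the `dense.*` keys the checkpoint uses; handing `EmbeddingPoolerHead` a lambda gives it a plain callable that PyTorch does not register. A minimal standalone sketch of that behavior, with toy classes (`ToyHead` and `ToyPooler` are illustrative names, not vLLM types):

import torch
from torch import nn


class ToyHead(nn.Module):
    # Toy stand-in for a pooler head that applies a projector callable.
    def __init__(self, projector):
        super().__init__()
        self.projector = projector  # plain callable: nothing is registered

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self.projector(x))


class ToyPooler(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        # The lambda captures self.dense without making it a submodule of
        # the head, so the weight keys stay `dense.*`.
        self.head = ToyHead(projector=lambda x: self.dense(x))


pooler = ToyPooler(8)
print(list(pooler.state_dict()))  # ['dense.weight', 'dense.bias']

# By contrast, assigning a module directly registers it under the attribute:
owning_head = ToyHead(projector=nn.Linear(8, 8))
print(list(owning_head.state_dict()))  # ['projector.weight', 'projector.bias']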
@@ -449,12 +455,7 @@ class BertPoolingModel(BertModel):
             embedding_class=embedding_class,
         )

-        config = vllm_config.model_config.hf_config
-
-        pooler_config = vllm_config.model_config.pooler_config
-        assert pooler_config is not None
-
-        self.pooler = BertPooler(config, pooler_config)
+        self.pooler = BertPooler(vllm_config.model_config)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         other_weights, loaded_stacked_params = self._load_weights(weights)
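The final hunk shows the call-site payoff: `BertPooler` now derives everything it needs (`hf_config`, `pooler_config`, `head_dtype`) from a single `ModelConfig`, so constructors stop unpacking and forwarding separate config objects. A hedged sketch of that pattern, with illustrative toy dataclasses rather than vLLM's real types:

from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class ToyPoolerConfig:
    seq_pooling_type: str = "CLS"


@dataclass
class ToyModelConfig:
    # Mirrors the fields BertPooler reads from vLLM's ModelConfig.
    hidden_size: int
    head_dtype: torch.dtype
    pooler_config: ToyPoolerConfig


def build_pooler(model_config: ToyModelConfig) -> nn.Module:
    # The pooler pulls its settings from the one config object, so the
    # call site shrinks to a single argument, as in the hunk above.
    assert model_config.pooler_config is not None
    return nn.Linear(
        model_config.hidden_size,
        model_config.hidden_size,
        dtype=model_config.head_dtype,
    )


pooler = build_pooler(ToyModelConfig(8, torch.float32, ToyPoolerConfig()))
print(pooler.weight.dtype)  # torch.float32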