[Model] Standardize pooling heads (#32148)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-01-13 01:01:49 +08:00
Committed by: GitHub
Parent: 3f72639d36
Commit: 8863c2b25c
9 changed files with 182 additions and 149 deletions

@@ -8,7 +8,7 @@ from torch import nn
 from transformers import BertConfig
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, PoolerConfig, VllmConfig
+from vllm.config import CacheConfig, ModelConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention.encoder_only_attention import (
@@ -24,11 +24,11 @@ from vllm.model_executor.layers.pooler import (
     Pooler,
     PoolingParamsUpdate,
 )
+from vllm.model_executor.layers.pooler.activations import LambdaPoolerActivation
 from vllm.model_executor.layers.pooler.seqwise import (
+    EmbeddingPoolerHead,
     SequencePooler,
-    SequencePoolerHeadOutput,
     SequencePoolerOutput,
-    SequencePoolingMethodOutput,
     get_seq_pooling_method,
 )
 from vllm.model_executor.layers.pooler.tokwise import (
@@ -94,26 +94,32 @@ class BertEmbedding(nn.Module):
 
 
 class BertPooler(SequencePooler):
-    def __init__(self, config: BertConfig, pooler_config: PoolerConfig):
+    def __init__(self, model_config: ModelConfig):
+        pooler_config = model_config.pooler_config
+        assert pooler_config is not None
+        config: BertConfig = model_config.hf_config
+
         super().__init__(
             pooling=get_seq_pooling_method(pooler_config.seq_pooling_type),
-            head=self.head,
+            # We set this dummy to avoid adding parameters to nn.Module too early
+            head=nn.Identity(),
         )
-        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        self.activation = nn.Tanh()
+        head_dtype = model_config.head_dtype
+        self.dense = nn.Linear(
+            config.hidden_size,
+            config.hidden_size,
+            dtype=head_dtype,
+        )
+        self.act_fn = nn.Tanh()
 
-    def head(
-        self,
-        pooled_data: SequencePoolingMethodOutput,
-        pooling_metadata: PoolingMetadata,
-    ) -> SequencePoolerHeadOutput:
-        if isinstance(pooled_data, list):
-            pooled_data = torch.stack(pooled_data)
-        pooled_data = self.dense(pooled_data)
-        pooled_data = self.activation(pooled_data)
-        return pooled_data
+        # Use lambdas so that weights are not registered under `self.head`
+        self.head = EmbeddingPoolerHead(
+            projector=lambda x: self.dense(x),
+            head_dtype=head_dtype,
+            activation=LambdaPoolerActivation(self.act_fn),
+        )
 
 
 class BertEncoder(nn.Module):
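Note on the `nn.Identity()` placeholder and the lambda-wrapped projector: both sidestep `nn.Module.__setattr__`, which auto-registers any module assigned to an attribute and thereby duplicates its weights in the checkpoint namespace. A minimal standalone sketch of that PyTorch behavior (illustrative names only, not vLLM's API):

```python
from torch import nn


class ToyPooler(nn.Module):
    """Stand-in for BertPooler; `dense` mirrors its projection layer."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)


direct = ToyPooler(4)
wrapped = ToyPooler(4)

# Assigning a module to an attribute registers it as a submodule, so the
# same weights appear twice in the state dict (`dense.*` and `head.*`).
direct.head = direct.dense
print(sorted(direct.state_dict()))
# ['dense.bias', 'dense.weight', 'head.bias', 'head.weight']

# A lambda is a plain attribute that nn.Module does not register, so the
# checkpoint keys stay `dense.*` only.
wrapped.head = lambda x: wrapped.dense(x)
print(sorted(wrapped.state_dict()))
# ['dense.bias', 'dense.weight']
```

This is why `projector=lambda x: self.dense(x)` keeps the weights under `dense.*` for loading while still letting the new `EmbeddingPoolerHead` call the projection.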
@@ -449,12 +455,7 @@ class BertPoolingModel(BertModel):
             embedding_class=embedding_class,
         )
 
-        config = vllm_config.model_config.hf_config
-        pooler_config = vllm_config.model_config.pooler_config
-        assert pooler_config is not None
-
-        self.pooler = BertPooler(config, pooler_config)
+        self.pooler = BertPooler(vllm_config.model_config)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         other_weights, loaded_stacked_params = self._load_weights(weights)
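The new `dtype=head_dtype` argument decouples the head's precision from the backbone's, so the projection can run in, say, float32 even when the model itself runs in half precision. A rough sketch of the intent, assuming the head casts its input before projecting (the cast is an assumption about `EmbeddingPoolerHead`, not something this diff shows):

```python
import torch
from torch import nn

# Backbone activations arrive in half precision; the head is constructed
# in its own dtype (here float32) for a numerically stabler projection.
backbone_dtype = torch.float16
head_dtype = torch.float32

dense = nn.Linear(8, 8, dtype=head_dtype)  # mirrors BertPooler.dense
act_fn = nn.Tanh()

hidden_states = torch.randn(2, 8, dtype=backbone_dtype)
# Hypothetical head forward: cast to the head dtype, project, activate.
pooled = act_fn(dense(hidden_states.to(head_dtype)))
print(pooled.dtype)  # torch.float32
```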