[Model] Avoid hardcoding pooling type (#32119)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung (committed by GitHub)
Date: 2026-01-12 13:28:12 +08:00
Parent: 025a32f9ed
Commit: 9101dc756c
6 changed files with 47 additions and 22 deletions


@@ -25,11 +25,11 @@ from vllm.model_executor.layers.pooler import (
     PoolingParamsUpdate,
 )
 from vllm.model_executor.layers.pooler.seqwise import (
-    CLSPool,
     SequencePooler,
     SequencePoolerHeadOutput,
     SequencePoolerOutput,
     SequencePoolingMethodOutput,
+    get_seq_pooling_method,
 )
 from vllm.model_executor.layers.pooler.tokwise import (
     pooler_for_token_classify,
@@ -94,9 +94,9 @@ class BertEmbedding(nn.Module):
 class BertPooler(SequencePooler):
-    def __init__(self, config: BertConfig):
+    def __init__(self, config: BertConfig, pooler_config: PoolerConfig):
         super().__init__(
-            pooling=CLSPool(),
+            pooling=get_seq_pooling_method(pooler_config.seq_pooling_type),
             head=self.head,
         )
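
To make the intent of this hunk concrete, below is a minimal sketch of the kind of dispatch a helper like get_seq_pooling_method can perform, assuming it maps a configured pooling-type string to a pooling module instead of hardcoding CLSPool. The CLSPool/MeanPool classes and the dispatch table are illustrative stand-ins, not vLLM's actual implementation in vllm.model_executor.layers.pooler.seqwise.

# Illustrative sketch only -- not the actual vLLM implementation.
import torch
import torch.nn as nn


class CLSPool(nn.Module):
    """Pool a sequence by taking the first ([CLS]) token's hidden state."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (seq_len, hidden_size)
        return hidden_states[0]


class MeanPool(nn.Module):
    """Pool a sequence by averaging all token hidden states."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states.mean(dim=0)


def get_seq_pooling_method(seq_pooling_type: str) -> nn.Module:
    # Hypothetical dispatch table; the real set of supported types
    # lives in vllm.model_executor.layers.pooler.seqwise.
    methods = {"CLS": CLSPool, "MEAN": MeanPool}
    try:
        return methods[seq_pooling_type]()
    except KeyError:
        raise ValueError(f"Unsupported pooling type: {seq_pooling_type}")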
@@ -450,7 +450,11 @@ class BertPoolingModel(BertModel):
         )
         config = vllm_config.model_config.hf_config
-        self.pooler = BertPooler(config)
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+        self.pooler = BertPooler(config, pooler_config)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         other_weights, loaded_stacked_params = self._load_weights(weights)
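
As a usage sketch: with the guarded construction above, the pooling behavior is driven by the pooler config rather than the model code. The PoolerConfig dataclass below is a hypothetical stand-in (only the seq_pooling_type field appears in this diff), reusing the get_seq_pooling_method sketch shown earlier:

from dataclasses import dataclass


@dataclass
class PoolerConfig:
    # Hypothetical stand-in; only seq_pooling_type is taken from the diff.
    seq_pooling_type: str = "CLS"


# With this change, switching BERT pooling from CLS to mean becomes a
# configuration change instead of a code change:
pooler_config = PoolerConfig(seq_pooling_type="MEAN")
pooling = get_seq_pooling_method(pooler_config.seq_pooling_type)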
@@ -711,6 +715,8 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
             layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
         )
+        # None of vLLM's built-in sequence pooling types are
+        # applicable, so the pooler is overridden by SPLADESparsePooler.
         pooling_mode = getattr(self, "_splade_pooling", "max")
         cls_id = getattr(cfg, "cls_token_id", None)
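
For context on why no built-in sequence pooling type fits here: SPLADE derives a vocabulary-sized sparse vector from token-level MLM logits via log(1 + relu(logits)) followed by a max over the sequence, which CLS/mean/last pooling cannot express. The function below is an illustrative sketch of that computation, not vLLM's actual SPLADESparsePooler:

import torch


def splade_pool(mlm_logits: torch.Tensor) -> torch.Tensor:
    # mlm_logits: (seq_len, vocab_size) token-level MLM logits.
    # Saturate activations, then take the max over the sequence,
    # yielding one sparse weight per vocabulary term.
    return torch.log1p(torch.relu(mlm_logits)).amax(dim=0)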