[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-10-15 19:14:41 +08:00
committed by GitHub
parent d4d1a6024f
commit f54f85129e
41 changed files with 786 additions and 399 deletions

View File

@@ -521,7 +521,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
return DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}
)
@@ -724,7 +724,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
return DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_embed": Pooler.for_token_embed(pooler_config),
"embed": SPLADESparsePooler(
mlm_head=self.mlm_head,
cls_token_id=cls_id,
@@ -821,20 +821,16 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_classify": Pooler.for_token_classify(
pooler_config, classifier=self.classifier
),
"classify": ClassifierPooler(
pooling=self.bert.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config
),
act_fn="classify",
),
"score": ClassifierPooler(
pooling=self.bert.pooler,
classifier=self.classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config
),
pooling=self.bert.pooler, classifier=self.classifier, act_fn="score"
),
}
)
@@ -891,7 +887,9 @@ class BertForTokenClassification(nn.Module):
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"token_classify": Pooler.for_token_classify(
pooler_config=pooler_config
),
}
)