[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -521,7 +521,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
|
||||
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
|
||||
return DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
}
|
||||
)
|
||||
@@ -724,7 +724,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
|
||||
|
||||
return DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": SPLADESparsePooler(
|
||||
mlm_head=self.mlm_head,
|
||||
cls_token_id=cls_id,
|
||||
@@ -821,20 +821,16 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.classifier
|
||||
),
|
||||
"classify": ClassifierPooler(
|
||||
pooling=self.bert.pooler,
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||
vllm_config.model_config
|
||||
),
|
||||
act_fn="classify",
|
||||
),
|
||||
"score": ClassifierPooler(
|
||||
pooling=self.bert.pooler,
|
||||
classifier=self.classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||
vllm_config.model_config
|
||||
),
|
||||
pooling=self.bert.pooler, classifier=self.classifier, act_fn="score"
|
||||
),
|
||||
}
|
||||
)
|
||||
@@ -891,7 +887,9 @@ class BertForTokenClassification(nn.Module):
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config=pooler_config
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user