[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -250,7 +250,7 @@ def as_embedding_model(cls: _T) -> _T:
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
},
|
||||
)
|
||||
@@ -279,11 +279,8 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
# Lazy import
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
ClassifierPooler,
|
||||
DispatchPooler,
|
||||
Pooler,
|
||||
PoolingMethod,
|
||||
PoolingType,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -302,42 +299,29 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
model_config.hidden_size,
|
||||
config.num_labels,
|
||||
bias=False,
|
||||
params_dtype=torch.float32,
|
||||
params_dtype=vllm_config.model_config.head_dtype,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=maybe_prefix(prefix, "score"),
|
||||
)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
pooling_type_str = pooler_config.pooling_type
|
||||
assert pooling_type_str is not None
|
||||
pooling_type = PoolingType[pooling_type_str]
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"classify": ClassifierPooler(
|
||||
pooling=PoolingMethod.from_pooling_type(pooling_type),
|
||||
classifier=self._classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||
vllm_config.model_config
|
||||
),
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.score
|
||||
),
|
||||
"score": ClassifierPooler(
|
||||
pooling=PoolingMethod.from_pooling_type(pooling_type),
|
||||
classifier=self._classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||
vllm_config.model_config
|
||||
),
|
||||
"classify": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="classify"
|
||||
),
|
||||
"score": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="score"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def _classifier(self, x: torch.Tensor):
|
||||
x, _ = self.score(x.float())
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -393,7 +377,11 @@ def as_reward_model(cls: _T) -> _T:
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{"encode": Pooler.for_encode(pooler_config)},
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config=pooler_config
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
ModelForReward.__name__ = _get_pooling_model_name(cls.__name__, "ForReward")
|
||||
|
||||
Reference in New Issue
Block a user