[Model] Reorganize pooling layers (#31973)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -252,19 +252,14 @@ def as_embedding_model(cls: _T) -> _T:
|
||||
return cls
|
||||
|
||||
# Lazy import
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
|
||||
class ModelForEmbedding(_create_pooling_model_cls(cls)):
|
||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_embed": Pooler.for_token_embed(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
},
|
||||
)
|
||||
self.pooler = DispatchPooler.for_embedding(pooler_config)
|
||||
|
||||
ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")
|
||||
|
||||
@@ -289,10 +284,7 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
|
||||
# Lazy import
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
DispatchPooler,
|
||||
Pooler,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||
|
||||
from .utils import maybe_prefix
|
||||
@@ -318,18 +310,8 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
pooler_config, classifier=self.score
|
||||
),
|
||||
"classify": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="classify"
|
||||
),
|
||||
"score": Pooler.for_classify(
|
||||
pooler_config, classifier=self.score, act_fn="score"
|
||||
),
|
||||
}
|
||||
self.pooler = DispatchPooler.for_seq_cls(
|
||||
pooler_config, classifier=self.score
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
|
||||
Reference in New Issue
Block a user