[Model][1/N] Support multiple poolers at model level (#21227)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -13,7 +13,6 @@ from .interfaces_base import VllmModelForPooling, is_pooling_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import PoolingType
|
||||
|
||||
_T = TypeVar("_T", bound=type[nn.Module])
|
||||
|
||||
@@ -34,16 +33,8 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
|
||||
return model_name + pooling_suffix
|
||||
|
||||
|
||||
def _create_pooling_model_cls(
|
||||
orig_cls: _T,
|
||||
*,
|
||||
default_pooling_type: "PoolingType",
|
||||
default_normalize: bool,
|
||||
default_softmax: bool,
|
||||
) -> _T:
|
||||
def _create_pooling_model_cls(orig_cls: _T) -> _T:
|
||||
# Lazy import
|
||||
from vllm.model_executor.layers.pooler import Pooler
|
||||
|
||||
from .utils import AutoWeightsLoader, WeightsMapper
|
||||
|
||||
class ModelForPooling(orig_cls, VllmModelForPooling):
|
||||
@@ -71,15 +62,7 @@ def _create_pooling_model_cls(
|
||||
self._init_pooler(vllm_config, prefix=prefix)
|
||||
|
||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=default_pooling_type,
|
||||
normalize=default_normalize,
|
||||
softmax=default_softmax,
|
||||
)
|
||||
raise NotImplementedError
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
# TODO: Support uninitialized params tracking
|
||||
@@ -132,14 +115,20 @@ def as_embedding_model(cls: _T) -> _T:
|
||||
return cls
|
||||
|
||||
# Lazy import
|
||||
from vllm.model_executor.layers.pooler import PoolingType
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
|
||||
class ModelForEmbedding(_create_pooling_model_cls(cls)):
|
||||
|
||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"encode": Pooler.for_encode(pooler_config),
|
||||
"embed": Pooler.for_embed(pooler_config),
|
||||
}, )
|
||||
|
||||
ModelForEmbedding = _create_pooling_model_cls(
|
||||
cls,
|
||||
default_pooling_type=PoolingType.LAST,
|
||||
default_normalize=True,
|
||||
default_softmax=False,
|
||||
)
|
||||
ModelForEmbedding.__name__ = \
|
||||
_get_pooling_model_name(cls.__name__, "ForEmbedding")
|
||||
|
||||
@@ -165,20 +154,14 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
# Lazy import
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
from vllm.model_executor.layers.pooler import (ClassifierPooler,
|
||||
PoolingType, SimplePooler)
|
||||
DispatchPooler, Pooler,
|
||||
PoolingMethod, PoolingType)
|
||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import maybe_prefix
|
||||
|
||||
ModelForPooling = _create_pooling_model_cls(
|
||||
cls,
|
||||
default_pooling_type=PoolingType.LAST,
|
||||
default_normalize=False,
|
||||
default_softmax=True,
|
||||
)
|
||||
|
||||
class ModelForSequenceClassification(ModelForPooling,
|
||||
class ModelForSequenceClassification(_create_pooling_model_cls(cls),
|
||||
SupportsCrossEncoding):
|
||||
|
||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
@@ -198,19 +181,28 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
pooler = SimplePooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=False,
|
||||
softmax=True,
|
||||
)
|
||||
pooling_type_str = pooler_config.pooling_type
|
||||
pooling_type = (PoolingType.LAST if pooling_type_str is None else
|
||||
PoolingType[pooling_type_str])
|
||||
|
||||
self.pooler = ClassifierPooler(
|
||||
vllm_config.model_config,
|
||||
pooling=pooler.pooling,
|
||||
classifier=self._classifier,
|
||||
act_fn=pooler.head.activation,
|
||||
)
|
||||
self.pooler = DispatchPooler({
|
||||
"encode":
|
||||
Pooler.for_encode(pooler_config),
|
||||
"classify":
|
||||
ClassifierPooler(
|
||||
pooling=PoolingMethod.from_pooling_type(pooling_type),
|
||||
classifier=self._classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_seq_cls(
|
||||
vllm_config.model_config),
|
||||
),
|
||||
"score":
|
||||
ClassifierPooler(
|
||||
pooling=PoolingMethod.from_pooling_type(pooling_type),
|
||||
classifier=self._classifier,
|
||||
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
|
||||
vllm_config.model_config),
|
||||
),
|
||||
})
|
||||
|
||||
def _classifier(self, x: torch.Tensor):
|
||||
x, _ = self.score(x.float())
|
||||
@@ -259,14 +251,16 @@ def as_reward_model(cls: _T) -> _T:
|
||||
return cls
|
||||
|
||||
# Lazy import
|
||||
from vllm.model_executor.layers.pooler import PoolingType
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
||||
|
||||
ModelForReward = _create_pooling_model_cls(
|
||||
cls,
|
||||
default_pooling_type=PoolingType.ALL,
|
||||
default_normalize=False,
|
||||
default_softmax=False,
|
||||
)
|
||||
class ModelForReward(_create_pooling_model_cls(cls)):
|
||||
|
||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler(
|
||||
{"encode": Pooler.for_encode(pooler_config)}, )
|
||||
|
||||
ModelForReward.__name__ = \
|
||||
_get_pooling_model_name(cls.__name__, "ForReward")
|
||||
|
||||
Reference in New Issue
Block a user