[Model][1/N] Support multiple poolers at model level (#21227)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-21 17:22:21 +08:00
committed by GitHub
parent 378d33c392
commit 042af0c8d3
22 changed files with 549 additions and 413 deletions

View File

@@ -13,7 +13,6 @@ from .interfaces_base import VllmModelForPooling, is_pooling_model
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolingType
_T = TypeVar("_T", bound=type[nn.Module])
@@ -34,16 +33,8 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
return model_name + pooling_suffix
def _create_pooling_model_cls(
orig_cls: _T,
*,
default_pooling_type: "PoolingType",
default_normalize: bool,
default_softmax: bool,
) -> _T:
def _create_pooling_model_cls(orig_cls: _T) -> _T:
# Lazy import
from vllm.model_executor.layers.pooler import Pooler
from .utils import AutoWeightsLoader, WeightsMapper
class ModelForPooling(orig_cls, VllmModelForPooling):
@@ -71,15 +62,7 @@ def _create_pooling_model_cls(
self._init_pooler(vllm_config, prefix=prefix)
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)
raise NotImplementedError
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
# TODO: Support uninitialized params tracking
@@ -132,14 +115,20 @@ def as_embedding_model(cls: _T) -> _T:
return cls
# Lazy import
from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
class ModelForEmbedding(_create_pooling_model_cls(cls)):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{
"encode": Pooler.for_encode(pooler_config),
"embed": Pooler.for_embed(pooler_config),
}, )
ModelForEmbedding = _create_pooling_model_cls(
cls,
default_pooling_type=PoolingType.LAST,
default_normalize=True,
default_softmax=False,
)
ModelForEmbedding.__name__ = \
_get_pooling_model_name(cls.__name__, "ForEmbedding")
@@ -165,20 +154,14 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import (ClassifierPooler,
PoolingType, SimplePooler)
DispatchPooler, Pooler,
PoolingMethod, PoolingType)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
ModelForPooling = _create_pooling_model_cls(
cls,
default_pooling_type=PoolingType.LAST,
default_normalize=False,
default_softmax=True,
)
class ModelForSequenceClassification(ModelForPooling,
class ModelForSequenceClassification(_create_pooling_model_cls(cls),
SupportsCrossEncoding):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
@@ -198,19 +181,28 @@ def as_seq_cls_model(cls: _T) -> _T:
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
pooler = SimplePooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=True,
)
pooling_type_str = pooler_config.pooling_type
pooling_type = (PoolingType.LAST if pooling_type_str is None else
PoolingType[pooling_type_str])
self.pooler = ClassifierPooler(
vllm_config.model_config,
pooling=pooler.pooling,
classifier=self._classifier,
act_fn=pooler.head.activation,
)
self.pooler = DispatchPooler({
"encode":
Pooler.for_encode(pooler_config),
"classify":
ClassifierPooler(
pooling=PoolingMethod.from_pooling_type(pooling_type),
classifier=self._classifier,
act_fn=ClassifierPooler.act_fn_for_seq_cls(
vllm_config.model_config),
),
"score":
ClassifierPooler(
pooling=PoolingMethod.from_pooling_type(pooling_type),
classifier=self._classifier,
act_fn=ClassifierPooler.act_fn_for_cross_encoder(
vllm_config.model_config),
),
})
def _classifier(self, x: torch.Tensor):
x, _ = self.score(x.float())
@@ -259,14 +251,16 @@ def as_reward_model(cls: _T) -> _T:
return cls
# Lazy import
from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
ModelForReward = _create_pooling_model_cls(
cls,
default_pooling_type=PoolingType.ALL,
default_normalize=False,
default_softmax=False,
)
class ModelForReward(_create_pooling_model_cls(cls)):
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{"encode": Pooler.for_encode(pooler_config)}, )
ModelForReward.__name__ = \
_get_pooling_model_name(cls.__name__, "ForReward")