[Model] Consolidate pooler implementations (#20927)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-16 21:39:13 +08:00
committed by GitHub
parent 260127ea54
commit 1c3198b6c4
9 changed files with 558 additions and 372 deletions

View File

@@ -58,22 +58,27 @@ def _create_pooling_model_cls(
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
self.vllm_config = vllm_config
# These are not used in pooling models
for attr in ("lm_head", "logits_processor"):
if hasattr(self, attr):
delattr(self, attr)
# If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "_pooler", None):
self._init_pooler(vllm_config, prefix=prefix)
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
# If the model already defines a pooler instance, don't overwrite it
if not getattr(self, "_pooler", None):
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)
def pooler(
self,
@@ -165,7 +170,9 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
from vllm.model_executor.layers.pooler import (ClassifierPooler,
PoolerOutput, PoolingType,
SimplePooler)
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
@@ -182,30 +189,40 @@ def as_seq_cls_model(cls: _T) -> _T:
class ModelForSequenceClassification(ModelForPooling,
SupportsCrossEncoding):
def __init__(
self,
*,
vllm_config: "VllmConfig",
prefix: str = "",
**kwargs: Any,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.vllm_config = vllm_config
self.task = vllm_config.model_config.task
self.pooling_type = (
vllm_config.model_config.pooler_config.pooling_type)
self.score = RowParallelLinear(
config.hidden_size,
config.num_labels,
input_is_parallel=False,
bias=False,
params_dtype=torch.float32,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "score"),
)
self.score = RowParallelLinear(config.hidden_size,
config.num_labels,
quant_config=quant_config,
input_is_parallel=False,
bias=False,
prefix=maybe_prefix(
prefix, "score"))
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
pooler = SimplePooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=False,
softmax=True,
)
self._pooler = ClassifierPooler(
vllm_config.model_config,
pooling=pooler.pooling,
classifier=self._classifier,
act_fn=pooler.head.activation,
)
def _classifier(self, x: torch.Tensor):
x, _ = self.score(x.float())
return x
def forward(
self,
@@ -222,27 +239,7 @@ def as_seq_cls_model(cls: _T) -> _T:
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
def get_logits(hidden_states):
if isinstance(hidden_states, list):
logits = [self.score(state)[0] for state in hidden_states]
else:
logits, _ = self.score(hidden_states)
return logits
if self.pooling_type == PoolingType.ALL:
logits = get_logits(hidden_states)
return self._pooler(logits, pooling_metadata)
else:
hidden_states = self._pooler.extract_states(
hidden_states, pooling_metadata)
logits = get_logits(hidden_states)
pooled_data = self._pooler.head(logits, pooling_metadata)
pooled_outputs = [
self._pooler.build_output(data) for data in pooled_data
]
return PoolerOutput(outputs=pooled_outputs)
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
tokens = getattr(self.config, "classifier_from_token", None)