[Model] Consolidate pooler implementations (#20927)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -58,22 +58,27 @@ def _create_pooling_model_cls(
|
||||
) -> None:
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
# These are not used in pooling models
|
||||
for attr in ("lm_head", "logits_processor"):
|
||||
if hasattr(self, attr):
|
||||
delattr(self, attr)
|
||||
|
||||
# If the model already defines a pooler instance, don't overwrite it
|
||||
if not getattr(self, "_pooler", None):
|
||||
self._init_pooler(vllm_config, prefix=prefix)
|
||||
|
||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
# If the model already defines a pooler instance, don't overwrite it
|
||||
if not getattr(self, "_pooler", None):
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=default_pooling_type,
|
||||
normalize=default_normalize,
|
||||
softmax=default_softmax,
|
||||
)
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=default_pooling_type,
|
||||
normalize=default_normalize,
|
||||
softmax=default_softmax,
|
||||
)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
@@ -165,7 +170,9 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
|
||||
# Lazy import
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
|
||||
from vllm.model_executor.layers.pooler import (ClassifierPooler,
|
||||
PoolerOutput, PoolingType,
|
||||
SimplePooler)
|
||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -182,30 +189,40 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
class ModelForSequenceClassification(ModelForPooling,
|
||||
SupportsCrossEncoding):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: "VllmConfig",
|
||||
prefix: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
|
||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.task = vllm_config.model_config.task
|
||||
self.pooling_type = (
|
||||
vllm_config.model_config.pooler_config.pooling_type)
|
||||
self.score = RowParallelLinear(
|
||||
config.hidden_size,
|
||||
config.num_labels,
|
||||
input_is_parallel=False,
|
||||
bias=False,
|
||||
params_dtype=torch.float32,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "score"),
|
||||
)
|
||||
|
||||
self.score = RowParallelLinear(config.hidden_size,
|
||||
config.num_labels,
|
||||
quant_config=quant_config,
|
||||
input_is_parallel=False,
|
||||
bias=False,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "score"))
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
pooler = SimplePooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
normalize=False,
|
||||
softmax=True,
|
||||
)
|
||||
|
||||
self._pooler = ClassifierPooler(
|
||||
vllm_config.model_config,
|
||||
pooling=pooler.pooling,
|
||||
classifier=self._classifier,
|
||||
act_fn=pooler.head.activation,
|
||||
)
|
||||
|
||||
def _classifier(self, x: torch.Tensor):
|
||||
x, _ = self.score(x.float())
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -222,27 +239,7 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
|
||||
def get_logits(hidden_states):
|
||||
if isinstance(hidden_states, list):
|
||||
logits = [self.score(state)[0] for state in hidden_states]
|
||||
else:
|
||||
logits, _ = self.score(hidden_states)
|
||||
return logits
|
||||
|
||||
if self.pooling_type == PoolingType.ALL:
|
||||
logits = get_logits(hidden_states)
|
||||
return self._pooler(logits, pooling_metadata)
|
||||
else:
|
||||
hidden_states = self._pooler.extract_states(
|
||||
hidden_states, pooling_metadata)
|
||||
logits = get_logits(hidden_states)
|
||||
pooled_data = self._pooler.head(logits, pooling_metadata)
|
||||
|
||||
pooled_outputs = [
|
||||
self._pooler.build_output(data) for data in pooled_data
|
||||
]
|
||||
return PoolerOutput(outputs=pooled_outputs)
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
tokens = getattr(self.config, "classifier_from_token", None)
|
||||
|
||||
Reference in New Issue
Block a user