[Model][2/N] Automatic conversion of CrossEncoding model (#19978)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -145,9 +145,9 @@ def as_embedding_model(cls: _T) -> _T:
|
||||
return ModelForEmbedding # type: ignore
|
||||
|
||||
|
||||
def as_classification_model(cls: _T) -> _T:
|
||||
def as_seq_cls_model(cls: _T) -> _T:
|
||||
"""
|
||||
Subclass an existing vLLM model to support classification.
|
||||
Subclass an existing vLLM model to support classify and score tasks.
|
||||
|
||||
By default, the class probabilities are extracted from the softmaxed
|
||||
hidden state corresponding to the last token.
|
||||
@@ -164,7 +164,9 @@ def as_classification_model(cls: _T) -> _T:
|
||||
# Lazy import
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
from vllm.model_executor.layers.pooler import PoolingType
|
||||
from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType
|
||||
from vllm.model_executor.models.interfaces import SupportsCrossEncoding
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import maybe_prefix
|
||||
@@ -176,7 +178,8 @@ def as_classification_model(cls: _T) -> _T:
|
||||
default_softmax=True,
|
||||
)
|
||||
|
||||
class ModelForClassification(ModelForPooling):
|
||||
class ModelForSequenceClassification(ModelForPooling,
|
||||
SupportsCrossEncoding):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -190,6 +193,10 @@ def as_classification_model(cls: _T) -> _T:
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.task = vllm_config.model_config.task
|
||||
self.pooling_type = (
|
||||
vllm_config.model_config.pooler_config.pooling_type)
|
||||
|
||||
self.score = RowParallelLinear(config.hidden_size,
|
||||
config.num_labels,
|
||||
quant_config=quant_config,
|
||||
@@ -205,17 +212,41 @@ def as_classification_model(cls: _T) -> _T:
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = super().forward(input_ids, positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds)
|
||||
logits, _ = self.score(hidden_states)
|
||||
return logits
|
||||
return super().forward(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: Union[torch.Tensor, list[torch.Tensor]],
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> PoolerOutput:
|
||||
|
||||
def get_logits(hidden_states):
|
||||
if isinstance(hidden_states, list):
|
||||
logits = [self.score(state)[0] for state in hidden_states]
|
||||
else:
|
||||
logits, _ = self.score(hidden_states)
|
||||
return logits
|
||||
|
||||
if self.pooling_type == PoolingType.ALL:
|
||||
logits = get_logits(hidden_states)
|
||||
return self._pooler(logits, pooling_metadata)
|
||||
else:
|
||||
hidden_states = self._pooler.extract_states(
|
||||
hidden_states, pooling_metadata)
|
||||
logits = get_logits(hidden_states)
|
||||
pooled_data = self._pooler.head(logits, pooling_metadata)
|
||||
|
||||
pooled_outputs = [
|
||||
self._pooler.build_output(data) for data in pooled_data
|
||||
]
|
||||
return PoolerOutput(outputs=pooled_outputs)
|
||||
|
||||
|
||||
ModelForClassification.__name__ = \
|
||||
_get_pooling_model_name(cls.__name__, "ForClassification")
|
||||
ModelForSequenceClassification.__name__ = \
|
||||
_get_pooling_model_name(cls.__name__, "ForSequenceClassification")
|
||||
|
||||
return ModelForClassification # type: ignore
|
||||
return ModelForSequenceClassification # type: ignore
|
||||
|
||||
|
||||
def as_reward_model(cls: _T) -> _T:
|
||||
|
||||
Reference in New Issue
Block a user