[Model] Automatic conversion of classification and reward models (#11469)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-25 02:22:22 +08:00
committed by GitHub
parent 409475a827
commit 3f3e92e1f2
9 changed files with 206 additions and 161 deletions

View File

@@ -1,29 +1,48 @@
from collections.abc import Iterable
from typing import Any, TypeVar
from typing import TYPE_CHECKING, Any, Optional, TypeVar
import torch
import torch.nn as nn
from .interfaces_base import VllmModelForPooling, is_pooling_model
if TYPE_CHECKING:
from vllm.model_executor.layers.pooler import PoolingType
_T = TypeVar("_T", bound=type[nn.Module])
_GENERATE_SUFFIXES = [
"ForCausalLM",
"ForConditionalGeneration",
"ChatModel",
"LMHeadModel",
]
def as_embedding_model(cls: _T) -> _T:
"""Subclass an existing vLLM model to support embeddings."""
# Avoid modifying existing embedding models
if is_pooling_model(cls):
return cls
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
model_name = orig_model_name
for generate_suffix in _GENERATE_SUFFIXES:
model_name = model_name.removesuffix(generate_suffix)
return model_name + pooling_suffix
def _create_pooling_model_cls(
orig_cls: _T,
*,
default_pooling_type: "PoolingType",
default_normalize: bool,
default_softmax: bool,
) -> _T:
# Lazy import
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput,
PoolingType)
from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from .utils import AutoWeightsLoader, WeightsMapper
class ModelForEmbedding(cls, VllmModelForPooling):
class ModelForPooling(orig_cls, VllmModelForPooling):
def __init__(
self,
@@ -34,7 +53,7 @@ def as_embedding_model(cls: _T) -> _T:
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
# These are not used in embedding models
# These are not used in pooling models
for attr in ("lm_head", "logits_processor"):
if hasattr(self, attr):
delattr(self, attr)
@@ -46,9 +65,9 @@ def as_embedding_model(cls: _T) -> _T:
if not getattr(self, "_pooler", None):
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False,
pooling_type=default_pooling_type,
normalize=default_normalize,
softmax=default_softmax,
)
def pooler(
@@ -82,17 +101,148 @@ def as_embedding_model(cls: _T) -> _T:
return
# For most other models
if hasattr(cls, "load_weights"):
cls.load_weights(self, weights) # type: ignore
if hasattr(orig_cls, "load_weights"):
orig_cls.load_weights(self, weights) # type: ignore
# Fallback
else:
loader = AutoWeightsLoader(self)
loader.load_weights(weights)
ModelForEmbedding.__name__ = cls.__name__ \
.removesuffix("ForCausalLM") \
.removesuffix("ForConditionalGeneration") \
.removesuffix("ChatModel") \
.removesuffix("LMHeadModel") + "ForEmbedding"
return ModelForPooling # type: ignore
def as_embedding_model(cls: _T) -> _T:
"""
Subclass an existing vLLM model to support embeddings.
By default, the embeddings of the whole prompt are extracted from the
normalized hidden state corresponding to the last token.
Note:
We assume that no extra layers are added to the original model;
please implement your own model if this is not the case.
"""
# Avoid modifying existing embedding models
if is_pooling_model(cls):
return cls
# Lazy import
from vllm.model_executor.layers.pooler import PoolingType
ModelForEmbedding = _create_pooling_model_cls(
cls,
default_pooling_type=PoolingType.LAST,
default_normalize=True,
default_softmax=False,
)
ModelForEmbedding.__name__ = \
_get_pooling_model_name(cls.__name__, "ForEmbedding")
return ModelForEmbedding # type: ignore
def as_classification_model(cls: _T) -> _T:
"""
Subclass an existing vLLM model to support classification.
By default, the class probabilities are extracted from the softmaxed
hidden state corresponding to the last token.
Note:
We assume that the classification head is a single linear layer
stored as the attribute `score` of the top-level model;
please implement your own model if this is not the case.
"""
# Avoid modifying existing classification models
if is_pooling_model(cls):
return cls
# Lazy import
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import PoolingType
from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
ModelForPooling = _create_pooling_model_cls(
cls,
default_pooling_type=PoolingType.LAST,
default_normalize=False,
default_softmax=True,
)
class ModelForClassification(ModelForPooling):
def __init__(
self,
*,
vllm_config: "VllmConfig",
prefix: str = "",
**kwargs: Any,
) -> None:
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.score = RowParallelLinear(config.hidden_size,
config.num_labels,
quant_config=quant_config,
input_is_parallel=False,
bias=False,
prefix=maybe_prefix(
prefix, "score"))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = super().forward(input_ids, positions, kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds)
logits, _ = self.score(hidden_states)
return logits
ModelForClassification.__name__ = \
_get_pooling_model_name(cls.__name__, "ForClassification")
return ModelForClassification # type: ignore
def as_reward_model(cls: _T) -> _T:
"""
Subclass an existing vLLM model to support reward modeling.
By default, we return the hidden states of each token directly.
Note:
We assume that no extra layers are added to the original model;
please implement your own model if this is not the case.
"""
# Avoid modifying existing reward models
if is_pooling_model(cls):
return cls
# Lazy import
from vllm.model_executor.layers.pooler import PoolingType
ModelForReward = _create_pooling_model_cls(
cls,
default_pooling_type=PoolingType.ALL,
default_normalize=False,
default_softmax=False,
)
ModelForReward.__name__ = \
_get_pooling_model_name(cls.__name__, "ForReward")
return ModelForReward # type: ignore