[Model] Automatic conversion of classification and reward models (#11469)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung
2024-12-25 02:22:22 +08:00
committed by GitHub
parent 409475a827
commit 3f3e92e1f2
9 changed files with 206 additions and 161 deletions

View File

@@ -28,7 +28,7 @@ llm = LLM(model=..., task="generate") # Name or path of your model
 output = llm.generate("Hello, my name is")
 print(output)
-# For pooling models (task={embed,classify,reward}) only
+# For pooling models (task={embed,classify,reward,score}) only
 llm = LLM(model=..., task="embed") # Name or path of your model
 output = llm.encode("Hello, my name is")
 print(output)
@@ -59,7 +59,7 @@ llm = LLM(model=..., revision=..., task=..., trust_remote_code=True)
 output = llm.generate("Hello, my name is")
 print(output)
-# For pooling models (task={embed,classify,reward}) only
+# For pooling models (task={embed,classify,reward,score}) only
 output = llm.encode("Hello, my name is")
 print(output)
 ```
@@ -369,14 +369,6 @@ you should explicitly specify the task type to ensure that the model is used in
 #### Text Embedding (`--task embed`)
 
-Any text generation model can be converted into an embedding model by passing {code}`--task embed`.
-
-```{note}
-To get the best results, you should use pooling models that are specifically trained as such.
-```
-
-The following table lists those that are tested in vLLM.
-
 ```{eval-rst}
 .. list-table::
    :widths: 25 25 50 5 5
@@ -437,6 +429,10 @@ On the other hand, its 1.5B variant ({code}`Alibaba-NLP/gte-Qwen2-1.5B-instruct`
 despite being described otherwise on its model card.
 ```
 
+If your model is not in the above list, we will try to automatically convert the model using
+:func:`vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings
+of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
+
 #### Reward Modeling (`--task reward`)
 
 ```{eval-rst}
@@ -461,6 +457,9 @@ despite being described otherwise on its model card.
      - ✅︎
 ```
 
+If your model is not in the above list, we will try to automatically convert the model using
+:func:`vllm.model_executor.models.adapters.as_reward_model`. By default, we return the hidden states of each token directly.
+
 ```{important}
 For process-supervised reward models such as {code}`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
 e.g.: {code}`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
@@ -490,6 +489,9 @@ e.g.: {code}`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 1
      - ✅︎
 ```
 
+If your model is not in the above list, we will try to automatically convert the model using
+:func:`vllm.model_executor.models.adapters.as_classification_model`. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
+
 #### Sentence Pair Scoring (`--task score`)
 
 ```{eval-rst}
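As a hedged sketch of the user-facing effect of these docs changes (mirroring the snippets above, with the model path elided just as in the docs): requesting a pooling task on a generative checkpoint now triggers the corresponding adapter automatically.

```python
from vllm import LLM

# Any generative checkpoint can now be loaded with a pooling task; the
# loader wraps its model class via as_classification_model /
# as_reward_model / as_embedding_model under the hood.
llm = LLM(model=..., task="classify")  # or task="reward" / task="embed"
output = llm.encode("Hello, my name is")
print(output)
```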

View File

@@ -1,7 +1,4 @@
-"""Compare the outputs of HF and vLLM when using greedy sampling.
-
-This test only tests small models. Big models such as 7B should be tested from
-test_big_models.py because it could use a larger instance to run tests.
+"""Compare the classification outputs of HF and vLLM models.
 
 Run `pytest tests/models/test_cls_models.py`.
 """

View File

@@ -1,6 +1,6 @@
-"""Compare the embedding outputs of HF and vLLM models.
+"""Compare the scoring outputs of HF and vLLM models.
 
-Run `pytest tests/models/embedding/language/test_embedding.py`.
+Run `pytest tests/models/embedding/language/test_scoring.py`.
 """
 import math

View File

@@ -6,7 +6,9 @@ import torch.cuda
 from vllm.model_executor.models import (is_pooling_model,
                                         is_text_generation_model,
                                         supports_multimodal)
-from vllm.model_executor.models.adapters import as_embedding_model
+from vllm.model_executor.models.adapters import (as_classification_model,
+                                                 as_embedding_model,
+                                                 as_reward_model)
 from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS,
                                                  _SPECULATIVE_DECODING_MODELS,
                                                  _TEXT_GENERATION_MODELS,
@@ -29,9 +31,10 @@ def test_registry_imports(model_arch):
             or model_arch in _MULTIMODAL_MODELS):
         assert is_text_generation_model(model_cls)
 
-    # All vLLM models should be convertible to an embedding model
-    embed_model = as_embedding_model(model_cls)
-    assert is_pooling_model(embed_model)
+    # All vLLM models should be convertible to a pooling model
+    assert is_pooling_model(as_classification_model(model_cls))
+    assert is_pooling_model(as_embedding_model(model_cls))
+    assert is_pooling_model(as_reward_model(model_cls))
 
     if model_arch in _MULTIMODAL_MODELS:
         assert supports_multimodal(model_cls)

View File

@@ -7,7 +7,9 @@ from torch import nn
 from vllm.config import ModelConfig
 from vllm.model_executor.models import ModelRegistry
-from vllm.model_executor.models.adapters import as_embedding_model
+from vllm.model_executor.models.adapters import (as_classification_model,
+                                                 as_embedding_model,
+                                                 as_reward_model)
 
 @contextlib.contextmanager
@@ -35,8 +37,12 @@ def get_model_architecture(
         architectures = ["QuantMixtralForCausalLM"]
 
     model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
-    if model_config.runner_type == "pooling":
+    if model_config.task == "embed":
         model_cls = as_embedding_model(model_cls)
+    elif model_config.task == "classify":
+        model_cls = as_classification_model(model_cls)
+    elif model_config.task == "reward":
+        model_cls = as_reward_model(model_cls)
+
     return model_cls, arch
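For reference, a hedged sketch of driving the same conversion by hand, using only names that appear in this diff:

```python
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import as_classification_model

# Resolve the registered generative class, then apply the same wrapper
# the loader now selects for task="classify".
model_cls, arch = ModelRegistry.resolve_model_cls(["Qwen2ForCausalLM"])
Qwen2ForClassification = as_classification_model(model_cls)
print(Qwen2ForClassification.__name__)  # "Qwen2ForClassification"
```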

View File

@@ -1,29 +1,48 @@
 from collections.abc import Iterable
-from typing import Any, TypeVar
+from typing import TYPE_CHECKING, Any, Optional, TypeVar
 
 import torch
 import torch.nn as nn
 
 from .interfaces_base import VllmModelForPooling, is_pooling_model
 
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.pooler import PoolingType
+
 _T = TypeVar("_T", bound=type[nn.Module])
 
+_GENERATE_SUFFIXES = [
+    "ForCausalLM",
+    "ForConditionalGeneration",
+    "ChatModel",
+    "LMHeadModel",
+]
 
-def as_embedding_model(cls: _T) -> _T:
-    """Subclass an existing vLLM model to support embeddings."""
-    # Avoid modifying existing embedding models
-    if is_pooling_model(cls):
-        return cls
 
+def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
+    model_name = orig_model_name
+    for generate_suffix in _GENERATE_SUFFIXES:
+        model_name = model_name.removesuffix(generate_suffix)
+    return model_name + pooling_suffix
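A hedged illustration of the renaming helper just added (example names are ours, not from the diff):

```python
from vllm.model_executor.models.adapters import _get_pooling_model_name

# The generate suffix, if any, is stripped before the pooling suffix is added.
assert _get_pooling_model_name("LlamaForCausalLM", "ForEmbedding") == "LlamaForEmbedding"
assert _get_pooling_model_name("GPT2LMHeadModel", "ForClassification") == "GPT2ForClassification"
```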
+
+
+def _create_pooling_model_cls(
+    orig_cls: _T,
+    *,
+    default_pooling_type: "PoolingType",
+    default_normalize: bool,
+    default_softmax: bool,
+) -> _T:
     # Lazy import
     from vllm.config import VllmConfig
-    from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput,
-                                                   PoolingType)
+    from vllm.model_executor.layers.pooler import Pooler, PoolerOutput
     from vllm.model_executor.pooling_metadata import PoolingMetadata
 
     from .utils import AutoWeightsLoader, WeightsMapper
 
-    class ModelForEmbedding(cls, VllmModelForPooling):
+    class ModelForPooling(orig_cls, VllmModelForPooling):
 
         def __init__(
             self,
@@ -34,7 +53,7 @@ def as_embedding_model(cls: _T) -> _T:
         ) -> None:
             super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
-            # These are not used in embedding models
+            # These are not used in pooling models
             for attr in ("lm_head", "logits_processor"):
                 if hasattr(self, attr):
                     delattr(self, attr)
@@ -46,9 +65,9 @@ def as_embedding_model(cls: _T) -> _T:
             if not getattr(self, "_pooler", None):
                 self._pooler = Pooler.from_config_with_defaults(
                     pooler_config,
-                    pooling_type=PoolingType.LAST,
-                    normalize=True,
-                    softmax=False,
+                    pooling_type=default_pooling_type,
+                    normalize=default_normalize,
+                    softmax=default_softmax,
                 )
 
         def pooler(
@@ -82,17 +101,148 @@ def as_embedding_model(cls: _T) -> _T:
                 return
 
             # For most other models
-            if hasattr(cls, "load_weights"):
-                cls.load_weights(self, weights)  # type: ignore
+            if hasattr(orig_cls, "load_weights"):
+                orig_cls.load_weights(self, weights)  # type: ignore
             # Fallback
             else:
                 loader = AutoWeightsLoader(self)
                 loader.load_weights(weights)
 
-    ModelForEmbedding.__name__ = cls.__name__ \
-        .removesuffix("ForCausalLM") \
-        .removesuffix("ForConditionalGeneration") \
-        .removesuffix("ChatModel") \
-        .removesuffix("LMHeadModel") + "ForEmbedding"
+    return ModelForPooling  # type: ignore
+
+
+def as_embedding_model(cls: _T) -> _T:
+    """
+    Subclass an existing vLLM model to support embeddings.
+
+    By default, the embeddings of the whole prompt are extracted from the
+    normalized hidden state corresponding to the last token.
+
+    Note:
+        We assume that no extra layers are added to the original model;
+        please implement your own model if this is not the case.
+    """
+    # Avoid modifying existing embedding models
+    if is_pooling_model(cls):
+        return cls
+
+    # Lazy import
+    from vllm.model_executor.layers.pooler import PoolingType
+
+    ModelForEmbedding = _create_pooling_model_cls(
+        cls,
+        default_pooling_type=PoolingType.LAST,
+        default_normalize=True,
+        default_softmax=False,
+    )
+    ModelForEmbedding.__name__ = \
+        _get_pooling_model_name(cls.__name__, "ForEmbedding")
 
     return ModelForEmbedding  # type: ignore
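A hedged usage sketch of the refactored entry point (the Qwen2 class is used only as an example):

```python
from vllm.model_executor.models.adapters import as_embedding_model
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM

# Wrap a generative class; the result is a pooling model with LAST pooling
# and normalized outputs, renamed via _get_pooling_model_name().
Qwen2ForEmbedding = as_embedding_model(Qwen2ForCausalLM)
assert Qwen2ForEmbedding.__name__ == "Qwen2ForEmbedding"
```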
+
+
+def as_classification_model(cls: _T) -> _T:
+    """
+    Subclass an existing vLLM model to support classification.
+
+    By default, the class probabilities are extracted from the softmaxed
+    hidden state corresponding to the last token.
+
+    Note:
+        We assume that the classification head is a single linear layer
+        stored as the attribute `score` of the top-level model;
+        please implement your own model if this is not the case.
+    """
+    # Avoid modifying existing classification models
+    if is_pooling_model(cls):
+        return cls
+
+    # Lazy import
+    from vllm.attention import AttentionMetadata
+    from vllm.config import VllmConfig
+    from vllm.model_executor.layers.linear import RowParallelLinear
+    from vllm.model_executor.layers.pooler import PoolingType
+    from vllm.sequence import IntermediateTensors
+
+    from .utils import maybe_prefix
+
+    ModelForPooling = _create_pooling_model_cls(
+        cls,
+        default_pooling_type=PoolingType.LAST,
+        default_normalize=False,
+        default_softmax=True,
+    )
+
+    class ModelForClassification(ModelForPooling):
+
+        def __init__(
+            self,
+            *,
+            vllm_config: "VllmConfig",
+            prefix: str = "",
+            **kwargs: Any,
+        ) -> None:
+            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
+
+            config = vllm_config.model_config.hf_config
+            quant_config = vllm_config.quant_config
+
+            self.score = RowParallelLinear(config.hidden_size,
+                                           config.num_labels,
+                                           quant_config=quant_config,
+                                           input_is_parallel=False,
+                                           bias=False,
+                                           prefix=maybe_prefix(
+                                               prefix, "score"))
+
+        def forward(
+            self,
+            input_ids: torch.Tensor,
+            positions: torch.Tensor,
+            kv_caches: list[torch.Tensor],
+            attn_metadata: AttentionMetadata,
+            intermediate_tensors: Optional[IntermediateTensors] = None,
+            inputs_embeds: Optional[torch.Tensor] = None,
+        ) -> torch.Tensor:
+            hidden_states = super().forward(input_ids, positions, kv_caches,
+                                            attn_metadata,
+                                            intermediate_tensors,
+                                            inputs_embeds)
+            logits, _ = self.score(hidden_states)
+            return logits
+
+    ModelForClassification.__name__ = \
+        _get_pooling_model_name(cls.__name__, "ForClassification")
+
+    return ModelForClassification  # type: ignore
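The adapter assumes the checkpoint carries a single bias-free linear head named `score` (the HF `*ForSequenceClassification` convention). A minimal, self-contained sketch of the computation it adds, with hedged illustrative shapes:

```python
import torch

# Hedged sketch of the adapter's forward pass: a bias-free linear `score`
# head over the last-token hidden state; the pooler applies the softmax.
hidden_size, num_labels = 4096, 2  # illustrative values, not from the diff
score = torch.nn.Linear(hidden_size, num_labels, bias=False)
last_hidden = torch.randn(1, hidden_size)  # pooled last-token hidden state
probs = torch.softmax(score(last_hidden), dim=-1)
print(probs.shape)  # torch.Size([1, 2])
```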
+
+
+def as_reward_model(cls: _T) -> _T:
+    """
+    Subclass an existing vLLM model to support reward modeling.
+
+    By default, we return the hidden states of each token directly.
+
+    Note:
+        We assume that no extra layers are added to the original model;
+        please implement your own model if this is not the case.
+    """
+    # Avoid modifying existing reward models
+    if is_pooling_model(cls):
+        return cls
+
+    # Lazy import
+    from vllm.model_executor.layers.pooler import PoolingType
+
+    ModelForReward = _create_pooling_model_cls(
+        cls,
+        default_pooling_type=PoolingType.ALL,
+        default_normalize=False,
+        default_softmax=False,
+    )
+
+    ModelForReward.__name__ = \
+        _get_pooling_model_name(cls.__name__, "ForReward")
+
+    return ModelForReward  # type: ignore
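For quick reference, the three adapters differ only in the pooler defaults they pass to `_create_pooling_model_cls`; a summary sketch (values taken from this diff):

```python
# Pooler defaults wired in by each adapter:
ADAPTER_DEFAULTS = {
    "as_embedding_model": dict(pooling_type="LAST", normalize=True, softmax=False),
    "as_classification_model": dict(pooling_type="LAST", normalize=False, softmax=True),
    "as_reward_model": dict(pooling_type="ALL", normalize=False, softmax=False),
}
```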

View File

@@ -545,8 +545,8 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
         self.model = Qwen2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
 
-        # TODO: Replace this model class with for_embedding(Qwen2ForCausalLM),
-        # after changing the default pooling method
+        # TODO: Replace this model class with as_embedding_model(
+        # Qwen2ForCausalLM) after changing the default pooling method
         if pooler_config.pooling_type is None:
             logger.warning(
                 "This embedding model will default to last-token pooling in "

View File

@@ -1,104 +0,0 @@
-# Adapted from
-# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py
-# Copyright 2024 Kakao Corp. (Kanana-X Team)
-# Copyright 2024 The Qwen team.
-# Copyright 2023 The vLLM team.
-"""Inference-only Qwen2-Classification model compatible with HF weights."""
-from typing import Iterable, List, Optional, Set, Tuple
-
-import torch
-from torch import nn
-
-from vllm.attention import AttentionMetadata
-from vllm.config import VllmConfig
-from vllm.model_executor.layers.linear import RowParallelLinear
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
-from vllm.model_executor.models.qwen2 import Qwen2Model
-from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.sequence import IntermediateTensors, PoolerOutput
-
-from .interfaces import SupportsLoRA, SupportsPP
-from .utils import AutoWeightsLoader, maybe_prefix
-
-
-class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "gate_up_proj",
-        "down_proj",
-    ]
-
-    embedding_modules = {}
-    embedding_padding_modules = []
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        lora_config = vllm_config.lora_config
-        pooler_config = vllm_config.model_config.pooler_config
-
-        self.config = config
-        self.lora_config = lora_config
-
-        self.quant_config = quant_config
-        self.model = Qwen2Model(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
-
-        # hidden_states from Qwen2Model has been reduced,
-        # the input of score layer is not parallelized.
-        self.score = RowParallelLinear(config.hidden_size,
-                                       config.num_labels,
-                                       quant_config=quant_config,
-                                       input_is_parallel=False,
-                                       bias=False,
-                                       prefix=maybe_prefix(prefix, "score"))
-        self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.LAST,
-            normalize=False,
-            softmax=True)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors,
-                                   inputs_embeds)
-        logits, _ = self.score(hidden_states)
-        return logits
-
-    def pooler(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> Optional[PoolerOutput]:
-        return self._pooler(hidden_states, pooling_metadata)
-
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
-        loader = AutoWeightsLoader(self,
-                                   ignore_unexpected_prefixes=["lm_head."])
-        return loader.load_weights(weights)
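The hand-written class deleted above is now produced on the fly; a hedged sketch of the equivalent construction, matching the new registry entry below:

```python
from vllm.model_executor.models.adapters import as_classification_model
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM

# The registry now maps this architecture to Qwen2ForCausalLM, and the
# loader applies the adapter at load time.
Qwen2ForSequenceClassification = as_classification_model(Qwen2ForCausalLM)
```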

View File

@@ -20,11 +20,10 @@ import torch.nn as nn
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
-from .adapters import as_embedding_model
 from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
                          supports_cross_encoding, supports_multimodal,
                          supports_pp)
-from .interfaces_base import is_pooling_model, is_text_generation_model
+from .interfaces_base import is_text_generation_model
 
 logger = init_logger(__name__)
@@ -125,12 +124,13 @@ _EMBEDDING_MODELS = {
     "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
-    "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"),  # noqa: E501
     "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     # [Multimodal]
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
+    # [Auto-converted (see adapters.py)]
+    "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
 }
 
 _CROSS_ENCODER_MODELS = {
@@ -226,19 +226,10 @@ class _ModelInfo:
     @staticmethod
     def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
-        is_pooling_model_ = is_pooling_model(model)
-        if not is_pooling_model_:
-            try:
-                as_embedding_model(model)
-            except Exception:
-                pass
-            else:
-                is_pooling_model_ = True
-
         return _ModelInfo(
             architecture=model.__name__,
             is_text_generation_model=is_text_generation_model(model),
-            is_pooling_model=is_pooling_model_,
+            is_pooling_model=True,  # Can convert any model into a pooling model
             supports_cross_encoding=supports_cross_encoding(model),
             supports_multimodal=supports_multimodal(model),
             supports_pp=supports_pp(model),