[Model] Automatic conversion of classification and reward models (#11469)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-25 02:22:22 +08:00
committed by GitHub
parent 409475a827
commit 3f3e92e1f2
9 changed files with 206 additions and 161 deletions

View File

@@ -20,11 +20,10 @@ import torch.nn as nn
from vllm.logger import init_logger
from vllm.platforms import current_platform
from .adapters import as_embedding_model
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
supports_cross_encoding, supports_multimodal,
supports_pp)
from .interfaces_base import is_pooling_model, is_text_generation_model
from .interfaces_base import is_text_generation_model
logger = init_logger(__name__)
@@ -125,12 +124,13 @@ _EMBEDDING_MODELS = {
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
# [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
# [Auto-converted (see adapters.py)]
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
}
_CROSS_ENCODER_MODELS = {
@@ -226,19 +226,10 @@ class _ModelInfo:
@staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
is_pooling_model_ = is_pooling_model(model)
if not is_pooling_model_:
try:
as_embedding_model(model)
except Exception:
pass
else:
is_pooling_model_ = True
return _ModelInfo(
architecture=model.__name__,
is_text_generation_model=is_text_generation_model(model),
is_pooling_model=is_pooling_model_,
is_pooling_model=True, # Can convert any model into a pooling model
supports_cross_encoding=supports_cross_encoding(model),
supports_multimodal=supports_multimodal(model),
supports_pp=supports_pp(model),