[Model] Automatic conversion of classification and reward models (#11469)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -20,11 +20,10 @@ import torch.nn as nn
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .adapters import as_embedding_model
|
||||
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
|
||||
supports_cross_encoding, supports_multimodal,
|
||||
supports_pp)
|
||||
from .interfaces_base import is_pooling_model, is_text_generation_model
|
||||
from .interfaces_base import is_text_generation_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -125,12 +124,13 @@ _EMBEDDING_MODELS = {
|
||||
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
|
||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
||||
"Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501
|
||||
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
||||
# [Multimodal]
|
||||
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
|
||||
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
|
||||
# [Auto-converted (see adapters.py)]
|
||||
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
|
||||
}
|
||||
|
||||
_CROSS_ENCODER_MODELS = {
|
||||
@@ -226,19 +226,10 @@ class _ModelInfo:
|
||||
|
||||
@staticmethod
|
||||
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
|
||||
is_pooling_model_ = is_pooling_model(model)
|
||||
if not is_pooling_model_:
|
||||
try:
|
||||
as_embedding_model(model)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
is_pooling_model_ = True
|
||||
|
||||
return _ModelInfo(
|
||||
architecture=model.__name__,
|
||||
is_text_generation_model=is_text_generation_model(model),
|
||||
is_pooling_model=is_pooling_model_,
|
||||
is_pooling_model=True, # Can convert any model into a pooling model
|
||||
supports_cross_encoding=supports_cross_encoding(model),
|
||||
supports_multimodal=supports_multimodal(model),
|
||||
supports_pp=supports_pp(model),
|
||||
|
||||
Reference in New Issue
Block a user