[Frontend] Support using chat template as custom score template for reranking models (#30550)
Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com> Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io> Signed-off-by: wang.yuqi <noooop@126.com> Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
@@ -88,6 +88,26 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
|
||||
}
|
||||
|
||||
|
||||
class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
from vllm.config.pooler import PoolingTypeStr
|
||||
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
hf_config.is_causal = False
|
||||
|
||||
pooling_type_map: dict[str, PoolingTypeStr] = {
|
||||
"avg": "MEAN",
|
||||
"cls": "CLS",
|
||||
"last": "LAST",
|
||||
}
|
||||
|
||||
pooling_type = pooling_type_map.get(hf_config.pooling, None)
|
||||
if pooling_type is None:
|
||||
raise ValueError(f"pool_type {hf_config.pooling} not supported")
|
||||
vllm_config.model_config.pooler_config.pooling_type = pooling_type
|
||||
|
||||
|
||||
class NomicBertModelConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
@@ -509,6 +529,8 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"GteNewModel": GteNewModelConfig,
|
||||
"GteNewForSequenceClassification": GteNewModelConfig,
|
||||
"Gemma3TextModel": Gemma3TextModelConfig,
|
||||
"LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
|
||||
"LlamaBidirectionalModel": LlamaBidirectionalConfig,
|
||||
"NomicBertModel": NomicBertModelConfig,
|
||||
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
|
||||
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
|
||||
|
||||
@@ -57,7 +57,14 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .adapters import as_embedding_model, as_seq_cls_model
|
||||
from .interfaces import (
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
)
|
||||
from .interfaces_base import attn_type
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
@@ -698,3 +705,17 @@ class LlamaForCausalLM(
|
||||
name = name.replace(item, mapping[item])
|
||||
|
||||
return name, loaded_weight
|
||||
|
||||
|
||||
@attn_type("encoder_only")
|
||||
class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
|
||||
# This class sets the correct attention type and pooling type
|
||||
# through LlamaBidirectionalConfig.
|
||||
pass
|
||||
|
||||
|
||||
@attn_type("encoder_only")
|
||||
class LlamaBidirectionalModel(as_embedding_model(LlamaForCausalLM)):
|
||||
# This class sets the correct attention type and pooling type
|
||||
# through LlamaBidirectionalConfig.
|
||||
pass
|
||||
|
||||
@@ -203,6 +203,7 @@ _EMBEDDING_MODELS = {
|
||||
"GteNewModel": ("bert_with_rope", "GteNewModel"),
|
||||
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
|
||||
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
|
||||
"LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
|
||||
"LlamaModel": ("llama", "LlamaForCausalLM"),
|
||||
**{
|
||||
# Multiple models share the same architecture, so we include them all
|
||||
@@ -246,6 +247,11 @@ _CROSS_ENCODER_MODELS = {
|
||||
"bert_with_rope",
|
||||
"GteNewForSequenceClassification",
|
||||
),
|
||||
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"),
|
||||
"LlamaBidirectionalForSequenceClassification": (
|
||||
"llama",
|
||||
"LlamaBidirectionalForSequenceClassification",
|
||||
),
|
||||
"ModernBertForSequenceClassification": (
|
||||
"modernbert",
|
||||
"ModernBertForSequenceClassification",
|
||||
@@ -259,8 +265,6 @@ _CROSS_ENCODER_MODELS = {
|
||||
"roberta",
|
||||
"RobertaForSequenceClassification",
|
||||
),
|
||||
# [Auto-converted (see adapters.py)]
|
||||
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501,
|
||||
}
|
||||
|
||||
_MULTIMODAL_MODELS = {
|
||||
|
||||
Reference in New Issue
Block a user