Support Cross encoder models (#10400)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Flavia Beo <flavia.beo@ibm.com> Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
committed by
GitHub
parent
49628fe13e
commit
214efc2c3c
@@ -9,6 +9,7 @@ from huggingface_hub import (file_exists, hf_hub_download,
|
||||
from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError)
|
||||
from torch import nn
|
||||
from transformers import GenerationConfig, PretrainedConfig
|
||||
from transformers.models.auto.image_processing_auto import (
|
||||
get_image_processor_config)
|
||||
@@ -31,6 +32,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
|
||||
UltravoxConfig)
|
||||
# yapf: enable
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.utils import resolve_obj_by_qualname
|
||||
|
||||
if VLLM_USE_MODELSCOPE:
|
||||
from modelscope import AutoConfig
|
||||
@@ -577,3 +579,16 @@ def try_get_generation_config(
|
||||
return GenerationConfig.from_model_config(config)
|
||||
except OSError: # Not found
|
||||
return None
|
||||
|
||||
|
||||
def get_cross_encoder_activation_function(config: PretrainedConfig):
|
||||
if (hasattr(config, "sbert_ce_default_activation_function")
|
||||
and config.sbert_ce_default_activation_function is not None):
|
||||
|
||||
function_name = config.sbert_ce_default_activation_function
|
||||
assert function_name.startswith("torch.nn.modules."), \
|
||||
"Loading of activation functions is restricted to " \
|
||||
"torch.nn.modules for security reasons"
|
||||
return resolve_obj_by_qualname(function_name)()
|
||||
else:
|
||||
return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
|
||||
|
||||
Reference in New Issue
Block a user